Source code for code_execution.entrypoints

"""Module for entrypoints for code execution."""

import json
import logging
import os
import tempfile
from collections import defaultdict
from dataclasses import asdict
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union

from tqdm import tqdm

from code_execution.configs import ExecutionConfig
from code_execution.data_structures import Executable
from code_execution.data_structures import ExecutionResult
from code_execution.data_structures import OverallExecutionResults
from code_execution.execution import execute_commands
from code_execution.file_writing import cleanup
from code_execution.file_writing import write_executables
from code_execution.processing import PredTimingsCollection
from code_execution.processing import default_postprocessor
from code_execution.processing import postprocess_commands
from code_execution.processing import preprocess_commands

logger = logging.getLogger(__name__)


@dataclass
class ChunkExecutionResult:
    """Stores the execution result for a chunk of commands."""

    results: Dict[Tuple[int, int], ExecutionResult]
    write_elapsed: float
    exec_elapsed: float
    pure_exec_elapsed: float
    write_timings: Dict[str, float]
    cleanup_timings: Dict[str, float]

def _write_maybe_save_error_dir(
    config, files, raw_preds, exec_dir: Path, error_directory: Optional[Path]
):
    try:
        times = write_executables(
            files_to_write=files,
            write_rate_limit=config.write_rate_limit,
            enable_tqdm=config.display_write_progress,
            log_freq=config.write_log_freq,
        )
    except Exception as e:
        logger.exception("Error writing executables")
        if error_directory:
            error_directory.mkdir(parents=True, exist_ok=True)
            with error_directory.joinpath("exec_files.txt").open("w") as dir_f:
                for f in os.listdir(exec_dir.absolute()):
                    dir_f.write(f"{exec_dir / f}\n")
            error_file = error_directory.joinpath("errors.jsonl")
            with error_file.open("w") as error_f:
                for idx, files, pred_dir in files:
                    if isinstance(idx, str):
                        use_idx = int(idx.split(".")[0])
                    elif isinstance(idx, tuple):
                        use_idx = idx[0]
                        idx = f"{idx[0]}.{idx[1]}"
                    else:
                        use_idx = int(idx)
                    try:
                        to_write = json.dumps(
                            {
                                "use_idx": idx,
                                "pred": raw_preds[use_idx],
                                "files": files,
                                "pred_dir": str(pred_dir),
                            }
                        )
                    # json.dumps raises TypeError/ValueError for
                    # unserializable or circular objects.
                    except (TypeError, ValueError):
                        to_write = json.dumps(
                            {
                                "use_idx": idx,
                                "files": "Error decoding files",
                                "pred_dir": str(pred_dir),
                            }
                        )
                    error_f.write(to_write + "\n")
        raise e
    return times


def _preproc_step(
    config: ExecutionConfig,
    pred_list: List[Dict],
    preprocessor: Callable[[Dict], Union[Executable, List[Executable]]],
    execution_dir: Path,
    preproc_returns_list: bool = False,
    preproc_batch_size: int = 1,
) -> Tuple:
    """Preprocess commands and measure timing.

    Args:
        config: The config for execution.
        pred_list: The list of predictions to execute.
        preprocessor: The preprocessor function to create files and commands.
        execution_dir: Directory where execution will take place.
        preproc_returns_list: Whether the preprocessor returns one or many items.
        preproc_batch_size: The batch size for preprocessing.

    Returns:
        Tuple: ((files_to_write, commands_to_run, filtered_results),
        (preproc_timings, preproc_elapsed))
    """
    logger.debug("Preprocessing commands")
    preproc_start = datetime.now()
    *preproc_res, preproc_timings = preprocess_commands(
        config=config,
        dir_to_use=execution_dir,
        pred_list=pred_list,
        preprocessor=preprocessor,
        preproc_returns_list=preproc_returns_list,
        batch_size=preproc_batch_size,
    )
    preproc_elapsed = (datetime.now() - preproc_start).total_seconds()
    timings = (preproc_timings, preproc_elapsed)
    return preproc_res, timings


def _create_execution_chunks(
    config: ExecutionConfig,
    files_to_write: List,
    commands_to_run: List,
) -> tuple:
    """Split files and commands into manageable chunks.

    Args:
        config: The config for execution.
        files_to_write: List of files to write.
        commands_to_run: List of commands to run.

    Returns:
        tuple: (file_chunks, command_chunks)
    """
    file_chunks = []
    command_chunks = []
    if (
        config.max_execute_at_once > 0
        and len(commands_to_run) > config.max_execute_at_once
    ):
        logger.info(
            f"Executing {len(commands_to_run):,} commands "
            f"in chunks of {config.max_execute_at_once:,}"
        )
        if config.max_execute_at_once == 1:
            logger.warning(
                "max_execute_at_once is set to 1, this will be slow."
            )
            file_chunks = [[f] for f in files_to_write]
            command_chunks = [[c] for c in commands_to_run]
        else:
            for i in range(0, len(commands_to_run), config.max_execute_at_once):
                file_chunks.append(
                    files_to_write[i : i + config.max_execute_at_once]
                )
                command_chunks.append(
                    commands_to_run[i : i + config.max_execute_at_once]
                )
    else:
        logger.info(f"Executing {len(commands_to_run):,} commands")
        file_chunks.append(files_to_write)
        command_chunks.append(commands_to_run)

    assert len(file_chunks) == len(command_chunks)
    logger.debug(f"{len(file_chunks)} chunks to execute.")
    return file_chunks, command_chunks


def _process_single_chunk(
    chunk_idx: int,
    total_chunks: int,
    files: List,
    commands: List,
    config: ExecutionConfig,
    pred_list: List[Dict],
    execution_dir: Path,
    debug_dir: Optional[Path] = None,
    error_directory: Optional[Path] = None,
) -> ChunkExecutionResult:
    """Process a single chunk of files and commands.

    Args:
        chunk_idx: Index of the current chunk.
        total_chunks: Total number of chunks.
        files: List of files in this chunk.
        commands: List of commands in this chunk.
        config: The config for execution.
        pred_list: The list of predictions.
        execution_dir: Directory where execution will take place.
        debug_dir: Directory to save all files for debugging.
        error_directory: Directory to save errors to.

    Returns:
        ChunkExecutionResult: The results and timings for this chunk.
    """
    if total_chunks > 1:
        logger.info(f"Executing chunk {chunk_idx + 1}/{total_chunks}")

    # Write files
    write_start = datetime.now()
    write_timings = _write_maybe_save_error_dir(
        config=config,
        files=files,
        raw_preds=pred_list,
        exec_dir=execution_dir,
        error_directory=error_directory,
    )
    write_elapsed = (datetime.now() - write_start).total_seconds()

    # Execute commands
    exec_elapsed, pure_exec_elapsed, results = execute_commands(
        commands, config
    )

    # Clean up unless running in debug mode
    logger.debug("Finished execution, cleaning up...")
    cleanup_timings = {}
    if debug_dir is None:
        cleanup_timings = cleanup(
            files,
            rate_limit=config.write_rate_limit,
            enable_tqdm=config.display_write_progress,
        )

    return ChunkExecutionResult(
        results=results,
        write_elapsed=write_elapsed,
        exec_elapsed=exec_elapsed,
        pure_exec_elapsed=pure_exec_elapsed,
        write_timings=write_timings,
        cleanup_timings=cleanup_timings,
    )


def _run_execution_workflow(
    config: ExecutionConfig,
    pred_list: List[Dict],
    preprocessor: Callable[[Dict], Union[Executable, List[Executable]]],
    postprocessor: Callable[[Dict, Dict], Dict],
    execution_dir: Path,
    debug_dir: Optional[Path] = None,
    preproc_returns_list: bool = False,
    preproc_batch_size: int = 1,
    error_directory: Optional[Path] = None,
) -> OverallExecutionResults:
    """Run the execution workflow in a given directory.

    Args:
        config: The config for execution.
        pred_list: The list of predictions to execute.
        preprocessor: The preprocessor function to create files and commands.
        postprocessor: The postprocessor function for processing results.
        execution_dir: Directory where execution will take place.
        debug_dir: Directory to save all files for debugging.
        preproc_returns_list: Whether the preprocessor returns one or many items.
        preproc_batch_size: The batch size for preprocessing.
        error_directory: Directory to save errors to.

    Returns:
        OverallExecutionResults: The results of the execution.
    """
    logger.debug(f"Using {execution_dir} as execution directory")
    start_time = datetime.now()

    # Step 1: Preprocess commands
    preproc_res, (preproc_timings, preproc_elapsed) = _preproc_step(
        config=config,
        pred_list=pred_list,
        preprocessor=preprocessor,
        execution_dir=execution_dir,
        preproc_returns_list=preproc_returns_list,
        preproc_batch_size=preproc_batch_size,
    )
    files_to_write, commands_to_run, filtered_results = preproc_res

    # Step 2: Create execution chunks
    file_chunks, command_chunks = _create_execution_chunks(
        config=config,
        files_to_write=files_to_write,
        commands_to_run=commands_to_run,
    )

    # Step 3: Process each chunk
    chunk_results = []
    for chunk_idx, (files, commands) in enumerate(
        zip(file_chunks, command_chunks)
    ):
        chunk_results.append(
            _process_single_chunk(
                chunk_idx=chunk_idx,
                total_chunks=len(file_chunks),
                files=files,
                commands=commands,
                config=config,
                pred_list=pred_list,
                execution_dir=execution_dir,
                debug_dir=debug_dir,
                error_directory=error_directory,
            )
        )

    # Step 4: Combine results and timings from all chunks
    results_dict = {}
    write_elapsed = exec_elapsed = pure_exec_elapsed = 0
    write_timings = {}
    cleanup_timings = {}
    for result in chunk_results:
        results_dict.update({r[0]: r[1] for r in result.results})
        write_elapsed += result.write_elapsed
        exec_elapsed += result.exec_elapsed
        pure_exec_elapsed += result.pure_exec_elapsed
        write_timings.update(result.write_timings)
        cleanup_timings.update(result.cleanup_timings)

    results_dict.update(filtered_results)
    timings = PredTimingsCollection(
        preprocess_time=preproc_timings,
        writing_time=write_timings,
        cleanup_time=cleanup_timings,
    )

    # Step 5: Postprocess the results
    post_start = datetime.now()
    postprocessed = postprocess_commands(
        raw_preds=pred_list,
        results=results_dict,
        postprocessor=postprocessor,
        returned_multiple=preproc_returns_list,
        timings=timings,
        disable_tqdm=config.disable_tqdm,
        log_freq=config.log_freq,
    )
    post_elapsed = (datetime.now() - post_start).total_seconds()
    net_elapsed = (datetime.now() - start_time).total_seconds()

    return OverallExecutionResults(
        results=postprocessed,
        net_time=net_elapsed,
        pure_exec_time=pure_exec_elapsed,
        execution_time=exec_elapsed,
        writing_time=write_elapsed,
        postprocessing_time=post_elapsed,
        preprocessing_time=preproc_elapsed,
    )

def execute_predictions(
    config: ExecutionConfig,
    pred_list: List[Dict],
    preprocessor: Callable[[Dict], Union[Executable, List[Executable]]],
    postprocessor: Optional[Callable[[Dict, Dict], Dict]] = None,
    debug_dir: Optional[Path] = None,
    preproc_returns_list: bool = False,
    preproc_batch_size: int = 1,
    error_directory: Optional[Path] = None,
) -> OverallExecutionResults:
    """Executes the program predictions.

    First preprocesses the commands to run, writes them to disk, then
    executes them, and finally postprocesses the results.

    Args:
        config: The config for execution.
        pred_list: The list of predictions to execute.
        preprocessor: The preprocessor function to create files and commands.
        postprocessor: The postprocessor function for processing results.
        debug_dir: Directory to save all files for debugging.
        preproc_returns_list: Whether the preprocessor returns one or many items.
        preproc_batch_size: The batch size for preprocessing.
        error_directory: Directory to save errors to.

    Returns:
        OverallExecutionResults: The results of the execution.
    """
    # Use the default postprocessor if none is provided
    if postprocessor is None:
        logger.info("Using default postprocessor")
        postprocessor = default_postprocessor

    logger.debug(f"Starting execution with {len(pred_list):,} predictions")

    # Either use the debug directory or create a temporary one
    if debug_dir is None:
        # The EXEC_TMP_DIR environment variable can override where the
        # temporary directory is created.
        tmp_dir_loc = os.getenv("EXEC_TMP_DIR")
        if tmp_dir_loc is not None:
            logger.debug(f"Using {tmp_dir_loc} as temp dir")
            Path(tmp_dir_loc).mkdir(parents=True, exist_ok=True)

        # Create a temporary directory and run the workflow inside it
        with tempfile.TemporaryDirectory(dir=tmp_dir_loc) as temp_dir:
            return _run_execution_workflow(
                config=config,
                pred_list=pred_list,
                preprocessor=preprocessor,
                postprocessor=postprocessor,
                execution_dir=Path(temp_dir),
                debug_dir=None,
                preproc_returns_list=preproc_returns_list,
                preproc_batch_size=preproc_batch_size,
                error_directory=error_directory,
            )

    # Use the provided debug directory
    return _run_execution_workflow(
        config=config,
        pred_list=pred_list,
        preprocessor=preprocessor,
        postprocessor=postprocessor,
        execution_dir=debug_dir,
        debug_dir=debug_dir,
        preproc_returns_list=preproc_returns_list,
        preproc_batch_size=preproc_batch_size,
        error_directory=error_directory,
    )
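
A minimal usage sketch of execute_predictions follows. It is illustrative only: the Executable constructor arguments (files, commands), the ExecutionConfig keyword argument shown, the command format, and the (prediction, result) argument order assumed for the postprocessor are guesses about the surrounding library, not guarantees made by this module; check code_execution.data_structures and code_execution.configs for the real signatures.

    from code_execution.configs import ExecutionConfig
    from code_execution.data_structures import Executable
    from code_execution.entrypoints import execute_predictions


    def make_executable(pred: dict) -> Executable:
        # Hypothetical fields: assumes Executable takes a map of file names
        # to contents and a list of commands to run in that directory.
        return Executable(
            files={"main.py": pred["solution"]},
            commands=[["python", "main.py"]],
        )


    def attach_result(pred: dict, result: dict) -> dict:
        # Assumed argument order: (raw prediction, execution result).
        return {**pred, "execution": result}


    overall = execute_predictions(
        # The kwarg below is an assumption based on attributes used above;
        # see code_execution.configs for the actual constructor fields.
        config=ExecutionConfig(max_execute_at_once=1_000),
        pred_list=[{"solution": "print('hello world')"}],
        preprocessor=make_executable,
        postprocessor=attach_result,
    )
    print(overall.results)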