Source code for code_execution.execution

"""Module for executing code."""

import concurrent.futures
import functools
import logging
import math
import multiprocessing as mp
import pathlib
import subprocess
import time
from datetime import datetime
from typing import Callable, Dict, List, Optional, Tuple

import numpy as np
import psutil
from tqdm import tqdm

from code_execution.configs import ExecutionConfig
from code_execution.data_structures import Command
from code_execution.data_structures import CommandResult
from code_execution.data_structures import CommandsToRun
from code_execution.data_structures import ExecutionResult
from code_execution.utils import get_results_from_generator

logger = logging.getLogger(__name__)

LOGGING_IS_CONFIGURED = logging.getLogger().hasHandlers()


def seconds_to_human(seconds):
    """Converts seconds to a human readable format."""
    hours, seconds = divmod(seconds, 3600)
    minutes, seconds = divmod(seconds, 60)
    return f"{int(hours):02d}:{int(minutes):02d}:{seconds:05.2f}"

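# Usage sketch (illustrative, not part of the module): 3723.5 seconds is
# 1 hour, 2 minutes, and 3.5 seconds.
#
#   >>> seconds_to_human(3723.5)
#   '01:02:03.50'
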
def _execute(
    command_to_run: List[str],
    working_dir: pathlib.Path,
    timeout: int,
    stdin: Optional[str | List[str]] = None,
) -> Dict:
    """Executes a single command."""
    timed_out = False
    return_code = -1
    runtime = timeout
    stderr = None
    stdout = None
    had_unexpected_error = False
    start_time = time.time()
    execution_process = subprocess.Popen(
        command_to_run,
        cwd=str(working_dir),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        stdin=subprocess.PIPE,
    )
    if stdin:
        if isinstance(stdin, list):
            stdin = "\n".join(stdin)
        stdin = stdin.encode("utf-8")
    else:
        stdin = None
    try:
        try:
            outputs = execution_process.communicate(
                input=stdin, timeout=timeout
            )
            t1 = time.time()
            stdout = outputs[0].decode("utf-8")
            stderr = outputs[1].decode("utf-8")
            runtime = t1 - start_time
            return_code = execution_process.returncode
        except subprocess.TimeoutExpired:
            stdout = stderr = ""
            runtime = timeout
            # The timed_out flag, not the return code, signals this failure.
            return_code = 0
            timed_out = True
            execution_process.kill()
    # pylint: disable=broad-except
    except Exception as e:
        stderr = str(e)
        stdout = ""
        return_code = -1
        runtime = -1
        timed_out = False
        had_unexpected_error = True
        execution_process.kill()
    return dict(
        return_code=return_code,
        runtime=runtime,
        stderr=stderr,
        stdout=stdout,
        timed_out=timed_out,
        had_unexpected_error=had_unexpected_error,
    )

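# The returned dict mirrors the fields of CommandResult (safe_execute, below,
# unpacks it into that type). Illustrative shape only:
#
#   {"return_code": 0, "runtime": 0.05, "stderr": "", "stdout": "hi\n",
#    "timed_out": False, "had_unexpected_error": False}
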
def safe_execute(
    command_to_run: List[str],
    working_dir: pathlib.Path,
    timeout: int = 10,
    num_times: int = 1,
    stdin: Optional[str] = None,
    stdout_postprocessor: Optional[Callable] = None,
) -> CommandResult:
    """Executes a single command safely.

    Args:
        command_to_run: The command to run.
        working_dir: The working directory to run it in.
        timeout: The timeout in seconds.
        num_times: Number of times to execute the command. Useful for getting
            runtime and memory means.
        stdin: The stdin for the command.
        stdout_postprocessor: A postprocessor for the stdout.

    Returns:
        The result of executing the command.
    """
    times = []
    had_error = False
    for _ in range(num_times):
        res = _execute(command_to_run, working_dir, timeout, stdin=stdin)
        times.append(res["runtime"])
        if (
            res["return_code"] != 0
            or res["timed_out"]
            or res["had_unexpected_error"]
        ):
            had_error = True
            break
    if num_times == 1 or had_error:
        res["runtime"] = times[0]
    else:
        res["runtime"] = float(np.mean(times))
    if stdout_postprocessor:
        res["stdout"] = stdout_postprocessor(res["stdout"])
    return CommandResult(**res)

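# Usage sketch (hypothetical values; the CommandResult fields referenced are
# the ones built by _execute above):
#
#   result = safe_execute(
#       ["python", "-c", "print('hi')"],
#       working_dir=pathlib.Path("."),
#       timeout=5,
#       num_times=3,
#   )
#   if not result.timed_out and result.return_code == 0:
#       print(result.stdout)  # runtime is averaged over the 3 runs
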
def serial_execute_code(key, sample: CommandsToRun) -> ExecutionResult:
    """Executes the commands of a single sample serially.

    Args:
        key: The key identifying the sample.
        sample: The sample to run.

    Returns:
        The execution result.
    """
    file_path = sample.cwd
    working_dir_for_execution = (
        file_path.parent if file_path.is_file() else file_path
    )
    working_dir_for_execution = working_dir_for_execution.resolve().absolute()
    results = []
    t0 = time.time()
    for cidx, command in enumerate(sample.commands):
        command: Command
        res = safe_execute(
            command.command,
            working_dir=working_dir_for_execution,
            timeout=command.timeout,
            num_times=command.num_times,
            stdin=command.stdin,
            stdout_postprocessor=sample.stdout_postprocessor,
        )
        results.append(res)
        if sample.should_early_stop(cidx, res):
            break
    file_contents = {}
    for fn in sample.tracked_files:
        fp = file_path.joinpath(fn)
        if fp.exists():
            file_contents[fn] = fp.read_text(encoding="utf-8")
        else:
            file_contents[fn] = None
    elapsed = time.time() - t0
    return ExecutionResult(
        key=key,
        command_results=results,
        elapsed=elapsed,
        cwd=str(working_dir_for_execution),
        tracked_files=file_contents,
        expected_num_commands=len(sample.commands),
    )

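# Usage sketch. The attribute names below are the ones this module reads from
# Command and CommandsToRun; the actual constructor signatures may differ, and
# the path is hypothetical:
#
#   sample = CommandsToRun(
#       cwd=pathlib.Path("/tmp/solution"),
#       commands=[
#           Command(
#               command=["python", "main.py"],
#               timeout=10,
#               num_times=1,
#               stdin=None,
#           )
#       ],
#       tracked_files=["output.txt"],
#   )
#   result = serial_execute_code(key="0.0", sample=sample)
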
def execute_single(execution_dict: Dict) -> Tuple[Tuple, ExecutionResult]:
    """Executes a single program."""
    key = execution_dict["key"]
    executable = execution_dict["executable"]
    return key, serial_execute_code(
        key=".".join(map(str, key)), sample=executable
    )

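# Illustrative input shape: "key" is a tuple (joined with "." to form the
# ExecutionResult key) and "executable" is a CommandsToRun (`sample` is the
# hypothetical object from the sketch above):
#
#   key, result = execute_single({"key": (0, 1), "executable": sample})
#   # key == (0, 1), result.key == "0.1"
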
def batched_execute_code(
    to_run: List[Dict],
) -> List[Tuple[Tuple, ExecutionResult]]:
    """Executes a batch of commands."""
    results = [None] * len(to_run)
    for i, command in enumerate(to_run):
        results[i] = execute_single(command)
    return results

def sizeof_fmt(num, suffix="B"):
    """Human readable file size."""
    for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"):
        if abs(num) < 1024.0:
            return f"{num:3.1f}{unit}{suffix}"
        num /= 1024.0
    return f"{num:.1f}Yi{suffix}"

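# Illustrative: 10,000,000 bytes is divided by 1024 twice before falling
# under the threshold, landing in the "Mi" bucket.
#
#   >>> sizeof_fmt(10_000_000)
#   '9.5MiB'
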
def threaded_execution(
    to_run, execution_fn, max_threads, is_batched: bool = False
):
    """Executes a list of commands in parallel."""
    num_threads = min(len(to_run), max_threads)
    out = []
    if max_threads > 1:
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=num_threads
        ) as executor:
            for result in executor.map(execution_fn, to_run):
                if is_batched:
                    out.extend(result)
                else:
                    out.append(result)
    else:
        for command in to_run:
            if is_batched:
                out.extend(execution_fn(command))
            else:
                out.append(execution_fn(command))
    return out

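# Usage sketch: fan a list of execution dicts out over four threads, with
# execute_single (defined above) as the per-item worker. Inputs are
# hypothetical:
#
#   results = threaded_execution(
#       to_run=[{"key": (i,), "executable": sample} for i in range(8)],
#       execution_fn=execute_single,
#       max_threads=4,
#   )
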
def _parallel_execute_code(
    to_run: List,
    max_processes: int,
    num_executors: int,
    log_freq: int,
    is_batched: bool = False,
    execute_batch_size: int = 100,
) -> Tuple[float, List[ExecutionResult]]:
    """Executes a list of commands in parallel.

    Args:
        to_run: The list of commands to run.
        max_processes: The maximum number of processes to run.
        num_executors: The number of executors to run.
        log_freq: The frequency to log progress.
        is_batched: Whether the commands are batched.
        execute_batch_size: The size of execution batches.

    Returns:
        The elapsed execution time and the list of results.
    """
    logger.info(
        "Starting parallel execution (max_processes=%d num_executors=%d)",
        max_processes,
        num_executors,
    )
    # Each executor process runs its chunk with this many threads.
    max_threads = max(max_processes // num_executors, 1)
    logger.debug("max_threads=%d", max_threads)
    logger.debug("execute_batch_size=%d", execute_batch_size)
    logger.debug("len(to_run)//max_threads=%d", len(to_run) // max_threads)
    manager_process = psutil.Process()
    logger.debug("manager_process=%s", manager_process)
    # Guard against a chunk size of 0 when len(to_run) < max_threads.
    chunk_size = max(min(execute_batch_size, len(to_run) // max_threads), 1)
    logger.debug("chunk_size=%d", chunk_size)
    # Initialize cpu percent.
    psutil.getloadavg()
    chunks = []
    if is_batched:
        total_commands = sum(map(len, to_run))
    else:
        total_commands = len(to_run)
    for i in range(0, len(to_run), chunk_size):
        chunks.append(to_run[i : i + chunk_size])
    logger.info(
        f"Executing {total_commands:,} command(s) in {len(chunks):,} chunk(s)"
    )
    results = []
    start_time = time.time()
    interval_start = start_time
    interval_received = 0
    executor_fn = batched_execute_code if is_batched else execute_single
    threaded_fn = functools.partial(
        threaded_execution,
        execution_fn=executor_fn,
        max_threads=max_threads,
        is_batched=is_batched,
    )
    last_log = last_chunk_log = 0
    chunks_finished = last_pct = 0
    chunks_completed = 0
    with mp.Pool(processes=num_executors) as pool:
        for result in pool.imap_unordered(threaded_fn, chunks):
            results.extend(result)
            chunks_completed += 1
            # Print progress either every `log_freq` results or whenever
            # another ~10% of the chunks has finished.
            chunks_pct = math.floor(chunks_completed / len(chunks) * 10)
            if len(results) - last_log >= log_freq or chunks_pct > last_pct:
                last_pct = chunks_pct
                chunks_finished += 1
                if (
                    len(results) - last_log >= log_freq
                    or math.floor(10 * chunks_finished / len(chunks))
                    > last_chunk_log
                ):
                    last_chunk_log = math.floor(
                        10 * chunks_finished / len(chunks)
                    )
                    last_log = len(results)
                    t1 = time.time()
                    interval_elapsed = t1 - interval_start
                    elapsed = t1 - start_time
                    interval_completed = len(results) - interval_received
                    prog = len(results) / total_commands
                    if interval_elapsed == 0:
                        rate = 1
                    else:
                        rate = interval_completed / interval_elapsed
                    eta = seconds_to_human(
                        (total_commands - len(results)) / rate
                    )
                    interval_received = len(results)
                    rate_str = f"{rate:0.2f} P/S"
                    prog_str = f"{prog:0.2%}"
                    print(
                        f"[{datetime.now().isoformat(' ', 'seconds')}] "
                        f"Finished {prog_str:<6} @ {rate_str:<12} "
                        f"in {seconds_to_human(elapsed)} ETA: {eta}"
                    )
                    interval_start = time.time()
            if chunks_completed % 10 == 0:
                logger.debug(
                    f"{len(results):>9,}/{total_commands:<9,} total finished "
                    f"in {seconds_to_human(time.time() - start_time)}"
                )
                one_min_cpu, _, fifteen_min_cpu = [
                    x / psutil.cpu_count() for x in psutil.getloadavg()
                ]
                logger.debug(
                    f"Memory={sizeof_fmt(manager_process.memory_info().rss)} "
                    f"CPU: 1Min={one_min_cpu:0.2%} 15Min={fifteen_min_cpu:0.2%}"
                )
        pool.close()
        pool.terminate()
    elapsed = time.time() - start_time
    logger.info(
        f"Finished executing {len(results):,} in {seconds_to_human(elapsed)}"
    )
    if len(results) != total_commands:
        raise ValueError(
            f"Expected {total_commands:,} results, got {len(results):,}"
        )
    return elapsed, results

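# Worked example of the chunking arithmetic (hypothetical numbers): with
# max_processes=8 and num_executors=2, each executor process runs its chunk
# with max_threads = 8 // 2 = 4 threads. For 1,000 items and
# execute_batch_size=100, chunk_size = min(100, 1000 // 4) = 100, so
# imap_unordered receives 10 chunks spread across the 2 executor processes.
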
def execute_commands(
    predictions,
    config: ExecutionConfig,
) -> Tuple[float, float, List[ExecutionResult]]:
    """Executes a list of commands."""
    start = datetime.now()
    if not LOGGING_IS_CONFIGURED:
        print(f"Executing {len(predictions):,} predictions")
    else:
        logger.debug("Executing %d predictions", len(predictions))
    if config.batched:
        logger.debug(
            "Running in batched mode with batch_size=%d", config.batch_size
        )
        executor_fn = batched_execute_code
        to_run = []
        for i in range(0, len(predictions), config.batch_size):
            to_run.append(predictions[i : i + config.batch_size])
    else:
        logger.debug("Running in non-batched mode")
        executor_fn = execute_single
        to_run = predictions
    num_workers = min(len(to_run), config.num_workers)
    # Yes, this is not entirely parallel, but it makes debugging so much
    # easier.
    if num_workers > 1:
        pure_exec_time, results = _parallel_execute_code(
            to_run=to_run,
            max_processes=num_workers,
            num_executors=config.num_executors,
            is_batched=config.batched,
            log_freq=config.log_freq,
            execute_batch_size=config.buffer_size,
        )
    else:
        logger.debug("Running in serial as num_workers=1")
        pbar_generator = tqdm(
            map(executor_fn, to_run),
            total=len(to_run),
            desc="Executing Predictions",
            disable=config.disable_tqdm,
        )
        pure_exec_time, results = get_results_from_generator(
            generator=pbar_generator,
            total=len(to_run),
            target_returns_multiple=config.batched,
            garbage_collect_freq=500,
            log_freq=500,
        )
        pbar_generator.close()
    elapsed = (datetime.now() - start).total_seconds()
    return elapsed, pure_exec_time, results
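
# Usage sketch for the top-level entry point. The config fields shown are the
# ones this function reads; ExecutionConfig's full signature may differ, and
# the predictions are hypothetical:
#
#   config = ExecutionConfig(
#       num_workers=8,
#       num_executors=2,
#       batched=False,
#       log_freq=100,
#       buffer_size=100,
#       disable_tqdm=False,
#   )
#   total_elapsed, pure_exec_time, results = execute_commands(
#       predictions=[{"key": (i,), "executable": sample} for i in range(100)],
#       config=config,
#   )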