# Source code for code_execution.data_structures

"""Data structures for code execution."""

import dataclasses
import datetime
from pathlib import Path
from typing import Callable, Dict, List, Optional


@dataclasses.dataclass(frozen=True)
class Command:
    """A single command to run, together with how it should be run.

    Args:
        command: The command to execute, as a list of string tokens.
        timeout: Per-command timeout. When left as ``None`` the default
            timeout is used instead.
        num_times: How many times the command should be executed.
        stdin: The stdin for the command, as a list of strings. Each
            instance gets its own fresh empty list by default.
    """

    command: List[str]
    timeout: Optional[float] = dataclasses.field(default=None)
    num_times: int = dataclasses.field(default=1)
    stdin: List[str] = dataclasses.field(default_factory=list)
@dataclasses.dataclass(frozen=True)
class CommandResult:
    """Dataclass for the result of executing a single command.

    Args:
        return_code: The return code.
        runtime: The runtime.
        stdout: The captured stdout.
        stderr: The captured stderr.
        timed_out: Whether the command timed out.
        had_unexpected_error: Whether the command had an unexpected error.
    """

    return_code: int
    runtime: float
    stdout: str
    stderr: str
    timed_out: bool
    had_unexpected_error: bool = False

    @property
    def had_error(self) -> bool:
        """Whether this command had an error.

        True when the return code is non-zero or an unexpected error was
        flagged on the result.
        """
        return self.return_code != 0 or self.had_unexpected_error

    def __repr__(self):
        # Compact repr: stdout/stderr can be arbitrarily long, so only a
        # short snippet of the output is shown.
        if self.timed_out:
            return (
                f"CommandResult(return_code={self.return_code}, "
                f"runtime={self.runtime}, timed_out={self.timed_out})"
            )
        # Prefer stdout; fall back to stderr when stdout is empty.
        output = self.stdout or self.stderr
        max_chars = 50
        snippet = output[:max_chars]
        # Only signal truncation when something was actually cut off.
        if len(output) > max_chars:
            snippet += "..."
        return (
            f"CommandResult(return_code={self.return_code}, "
            f"runtime={self.runtime}, output={snippet})"
        )
@dataclasses.dataclass(frozen=True)
class ExecutionResult:
    """Dataclass for the result of executing a list of commands.

    Args:
        key: The key for the result.
        command_results: The results of the commands that ran, in order.
        elapsed: The elapsed time.
        cwd: The working directory the commands ran in; ``None`` for
            results built by ``invalid_result``.
        tracked_files: The tracked files (NOTE(review): presumably maps file
            name to contents -- confirm against the executor).
        expected_num_commands: The expected number of commands ran.
        writing_time: Writing time (presumably time spent writing files);
            defaults to ``-1``.
        cleanup_time: Cleanup time; defaults to ``-1``.
        preprocess_time: Preprocessing time; defaults to ``-1``.

    Note:
        Because ``__len__`` is defined, an instance with no command results
        is falsy. Use an explicit ``is None`` check when testing whether a
        result exists.
    """

    key: str
    command_results: List[CommandResult]
    elapsed: float
    cwd: Optional[str]
    tracked_files: Dict[str, str]
    expected_num_commands: int
    writing_time: float = -1
    cleanup_time: float = -1
    preprocess_time: float = -1

    @property
    def timed_out(self) -> bool:
        """Whether the last command timed out (``False`` if none ran)."""
        if not self.command_results:
            return False
        return self.command_results[-1].timed_out

    @property
    def had_error(self) -> bool:
        """Whether the last command had an error (``True`` if none ran).

        Delegates to ``CommandResult.had_error`` so an unexpected error in
        the last command is reported even when its return code is 0.
        """
        if not self.command_results:
            return True
        return self.command_results[-1].had_error

    @property
    def last_cmd(self) -> Optional[CommandResult]:
        """The last command result, or ``None`` if no commands ran."""
        if not self.command_results:
            return None
        return self.command_results[-1]

    def all_had_return_code(self, return_code: int) -> bool:
        """Whether every command exited with ``return_code``.

        Vacuously ``True`` when no commands ran.
        """
        return all(r.return_code == return_code for r in self.command_results)

    def to_dict(self, include_command_results: bool = False) -> Dict:
        """Convert the result to a (recursively converted) dictionary.

        Args:
            include_command_results: Whether to keep the ``command_results``
                entry in the output.

        Returns:
            The ``dataclasses.asdict`` form of this result, with
            ``command_results`` removed unless requested.
        """
        out = dataclasses.asdict(self)
        if not include_command_results:
            out.pop("command_results")
        return out

    @classmethod
    def invalid_result(
        cls,
        key: str,
        num_commands: int = 1,
        runtime: float = 10.0,
        return_code: int = 1,
        stdout: str = "SyntaxError",
        stderr: str = "Invalid",
        elapsed: float = 10.0,
    ) -> "ExecutionResult":
        """Create a dummy ExecutionResult that represents an invalid result.

        Useful for when your preprocessor finds a program you want to skip
        execution for. The result holds ``num_commands`` identical, non-timed-out
        command results built from the given values. It has ``cwd=None``, no
        tracked files, and ``expected_num_commands=num_commands``.
        """
        return cls(
            key=key,
            command_results=[
                CommandResult(
                    return_code=return_code,
                    runtime=runtime,
                    stdout=stdout,
                    stderr=stderr,
                    timed_out=False,
                )
                for _ in range(num_commands)
            ],
            elapsed=elapsed,
            cwd=None,
            tracked_files={},
            expected_num_commands=num_commands,
        )

    def __getitem__(self, idx: int) -> CommandResult:
        return self.command_results[idx]

    def __len__(self) -> int:
        return len(self.command_results)
def default_should_early_stop(
    cmd_idx: int,
    res: CommandResult,
    expected_rtr_code: Optional[int] = 0,
    stop_for_timeout: bool = True,
    **_k,
) -> bool:
    """Default policy for deciding whether execution should stop early.

    Args:
        cmd_idx: Index of the command that just ran. Accepted but unused.
        res: The result of that command.
        expected_rtr_code: Return code required to keep going. ``None``
            disables the return-code check entirely.
        stop_for_timeout: Whether a timed-out command triggers a stop.
        **_k: Any extra keyword arguments; accepted and ignored.

    Returns:
        ``True`` if the command timed out (with ``stop_for_timeout`` set),
        returned something other than ``expected_rtr_code``, or had an
        unexpected error; otherwise ``False``.
    """
    # Part of the callback signature, but not needed by this policy.
    del cmd_idx, _k
    # Checks are evaluated in the same short-circuit order as listed above.
    return bool(
        (stop_for_timeout and res.timed_out)
        or (
            expected_rtr_code is not None
            and res.return_code != expected_rtr_code
        )
        or res.had_unexpected_error
    )
@dataclasses.dataclass(frozen=True)
class Executable:
    """Dataclass to represent the commands and setup needed to execute a prediction.

    Args:
        files: The files to write (NOTE(review): presumably maps file path to
            contents -- confirm).
        commands: The commands to run.
        tracked_files: The files to get contents of after execution.
        should_early_stop: A function that takes the index of the command and
            the result, returning a bool if the execution should stop early.
            THIS MUST BE PICKLEABLE.
        stdout_postprocessor: Optional ``str -> str`` callable, or ``None``.
            NOTE(review): presumably applied to command stdout by the
            executor -- confirm.

    Raises:
        ValueError: If ``should_early_stop`` is not callable, or if
            ``stdout_postprocessor`` is neither ``None`` nor callable.
    """

    files: Dict[str, str]
    commands: List[Command]
    tracked_files: List[str] = dataclasses.field(default_factory=list)
    should_early_stop: Callable[[int, CommandResult], bool] = (
        default_should_early_stop
    )
    stdout_postprocessor: Optional[Callable[[str], str]] = None

    def __post_init__(self):
        # Validate eagerly so a bad callback fails at construction time
        # rather than later, mid-execution.
        if not callable(self.should_early_stop):
            raise ValueError("should_early_stop must be callable")
        if self.stdout_postprocessor is not None and not callable(
            self.stdout_postprocessor
        ):
            raise ValueError("stdout_postprocessor must be callable or None")
@dataclasses.dataclass(frozen=True)
class CommandsToRun:
    """Everything needed to run a set of commands in a working directory.

    This class exists mainly so the raw files do not have to be passed
    around to every function that runs commands.

    Args:
        cwd: The working directory to run the commands in.
        commands: The commands to run.
        tracked_files: The files to read back after execution; empty by
            default.
        should_early_stop: Callback taking the command index and its result,
            returning whether execution should stop early. Defaults to
            ``default_should_early_stop``.
        stdout_postprocessor: Optional ``str -> str`` callable, or ``None``.
    """

    cwd: Path
    commands: List[Command]
    tracked_files: List[str] = dataclasses.field(default_factory=list)
    should_early_stop: Callable[[int, CommandResult], bool] = dataclasses.field(
        default=default_should_early_stop
    )
    stdout_postprocessor: Optional[Callable[[str], str]] = dataclasses.field(
        default=None
    )
@dataclasses.dataclass
class OverallExecutionResults:
    """Aggregate results and timing information for a whole execution run.

    Args:
        results: The per-item result dictionaries.
        net_time: The net time.
        pure_exec_time: The pure execution time.
        execution_time: The execution time.
        writing_time: The writing time.
        postprocessing_time: The postprocessing time.
        preprocessing_time: The preprocessing time.
        timestamp: ISO-8601 timestamp string. When omitted (``None``), it is
            set to the current local (naive) time at construction.
    """

    results: List[Dict]
    net_time: float
    pure_exec_time: float
    execution_time: float
    writing_time: float
    postprocessing_time: float
    preprocessing_time: float
    timestamp: Optional[str] = None

    def __post_init__(self):
        # Stamp creation time only when the caller did not provide one.
        if self.timestamp is None:
            self.timestamp = datetime.datetime.now().isoformat()

    @property
    def timing_dict(self) -> Dict:
        """The timing fields plus the timestamp as a plain dict (no ``results``)."""
        return {
            "net_time": self.net_time,
            "pure_exec_time": self.pure_exec_time,
            "execution_time": self.execution_time,
            "writing_time": self.writing_time,
            "postprocessing_time": self.postprocessing_time,
            "preprocessing_time": self.preprocessing_time,
            "timestamp": self.timestamp,
        }