# Source code for the code_execution.code_trees module.

import ast
import functools
import logging
from typing import Callable, List, Optional, Tuple, Union

from .utils import ContextTimeLimitException
from .utils import run_in_parallel
from .utils import swallow_io
from .utils import time_limit

logger = logging.getLogger(__name__)


def safe_ast_parse(code: str, timeout: float = 5) -> Optional[ast.Module]:
    """Safely parse a string of code into an AST, if possible.

    Parsing runs inside ``swallow_io()`` (presumably suppressing any output
    produced while parsing) and under ``time_limit(timeout)``.

    Args:
        code: The source code to parse.
        timeout: Limit passed to ``time_limit`` (presumably seconds — confirm
            against ``time_limit``). Defaults to 5, the value previously
            hard-coded here.

    Returns:
        The parsed ``ast.Module``, or ``None`` if parsing raised
        ``SyntaxError``, ``ValueError``, ``RecursionError``, ``MemoryError``
        or exceeded the time limit (``ContextTimeLimitException``).
    """
    try:
        with swallow_io():
            with time_limit(timeout):
                return ast.parse(code)
    except (
        SyntaxError,
        ValueError,
        RecursionError,
        ContextTimeLimitException,
        MemoryError,
    ):
        return None
def is_valid_python(code):
    """Return True when ``code`` can be parsed by ``safe_ast_parse``.

    Every failure that makes ``safe_ast_parse`` return ``None`` (syntax
    errors, timeouts, excessive nesting, ...) is reported as invalid.
    """
    parsed = safe_ast_parse(code)
    return parsed is not None
def is_simple_test_case(tree):
    """Return True if the first statement asserts a comparison on a call.

    The first statement of ``tree`` must be an ``assert`` whose test is a
    comparison (e.g. ``assert f(x) == y``), and the left-hand operand of that
    comparison must contain at least one function call somewhere inside it.
    Only the first statement is inspected; an empty module yields False.
    """
    if not tree.body:
        return False
    first = tree.body[0]
    is_assert_compare = isinstance(first, ast.Assert) and isinstance(
        first.test, ast.Compare
    )
    if not is_assert_compare:
        return False
    return any(isinstance(node, ast.Call) for node in ast.walk(first.test.left))
def get_global_imports(tree: ast.Module) -> List[str]:
    """Return the module-level import statements of ``tree`` as source strings.

    Only direct children of the module body are considered, so imports nested
    inside functions, classes or conditionals are ignored. Each statement is
    rendered with ``ast.unparse`` and the source order is preserved.
    """
    import_types = (ast.Import, ast.ImportFrom)
    return [
        ast.unparse(stmt) for stmt in tree.body if isinstance(stmt, import_types)
    ]
def convert_call_to_assert(
    call: str,
    expected_output: str,
    requires_float=False,
    return_str: bool = False,
    tolerance: float = 1e-6,
) -> Union[ast.Module, str]:
    """Converts call code to an assertion with an expected output.

    The call code must end in an ``ast.Expr`` node, which is the node that
    will be converted to an assertion. If the call code already ends in an
    ``assert`` statement it is returned unchanged. The expected output must
    be an expression.

    Args:
        call: The code to be converted to an assertion.
        expected_output: The expected output of the call.
        requires_float: Whether the expected output is a float. If so, the
            assertion becomes ``abs(<call> - <expected>) < tolerance``.
        return_str: Whether to return the converted code as a string or as
            an AST tree.
        tolerance: Absolute tolerance used when ``requires_float`` is set.
            Defaults to 1e-6.

    Returns:
        The converted AST tree (with source locations filled in, so it can be
        passed straight to ``compile``) or the converted code as a string.
    """
    tree = ast.parse(call)
    if isinstance(tree.body[-1], ast.Assert):
        return ast.unparse(tree) if return_str else tree

    expected_expr = ast.parse(expected_output).body[0].value
    last_stmt = tree.body[-1]
    call_expr = last_stmt.value
    if requires_float:
        # abs(<call> - <expected>) < tolerance
        test = ast.Compare(
            left=ast.Call(
                func=ast.Name("abs", ctx=ast.Load()),
                args=[ast.BinOp(left=call_expr, op=ast.Sub(), right=expected_expr)],
                keywords=[],
            ),
            ops=[ast.Lt()],
            # ``kind`` must be None or a str; the former ``kind=float`` made
            # ``compile`` reject the tree with a TypeError.
            comparators=[ast.Constant(value=tolerance)],
        )
    else:
        test = ast.Compare(left=call_expr, ops=[ast.Eq()], comparators=[expected_expr])

    # Reuse the replaced statement's position, then fill in locations for the
    # synthesized nodes so the tree is compilable, not just unparsable.
    tree.body[-1] = ast.copy_location(ast.Assert(test=test, msg=None), last_stmt)
    ast.fix_missing_locations(tree)

    if return_str:
        return ast.unparse(tree)
    return tree
def convert_test_list_to_assert(
    test_list: List[Union[Tuple[str, str, bool], str]],
    timeout: float = -1.0,
    convert_to_string: bool = False,
) -> List[Union[ast.AST, str]]:
    """Converts a list of test cases to assertion nodes.

    Args:
        test_list: A list of test cases. Each test case is either a string,
            which is parsed as-is with ``safe_ast_parse``, or a tuple of
            ``(call, output[, requires_float])``, which is turned into an
            assertion with ``convert_call_to_assert``.
        timeout: The limit passed to ``time_limit`` for each test case.
            NOTE(review): the meaning of the default ``-1.0`` depends on
            ``time_limit``; presumably it disables the limit — confirm.
        convert_to_string: Whether to convert each resulting AST to a string
            with ``ast.unparse``.

    Returns:
        The converted test cases, in input order, as AST nodes or strings.
        Test cases that time out or cannot be parsed/converted (syntax errors,
        invalid values including malformed tuples, excessive nesting, out of
        memory) are skipped, matching how unparsable string cases are handled.
    """
    out = []
    for tc in test_list:
        try:
            if isinstance(tc, str):
                with time_limit(timeout):
                    tree = safe_ast_parse(tc)
            else:
                call, expected, *rest = tc
                requires_float = rest[0] if rest else False
                with time_limit(timeout):
                    tree = convert_call_to_assert(
                        call, expected, requires_float=requires_float
                    )
        except ContextTimeLimitException:
            tree = None
        except (SyntaxError, ValueError, RecursionError, MemoryError) as err:
            # Previously these escaped for tuple cases and aborted the whole
            # list, while the same failure on a string case was just skipped.
            logger.debug("Skipping unconvertible test case %r: %s", tc, err)
            tree = None
        if tree is None:
            continue
        out.append(ast.unparse(tree) if convert_to_string else tree)
    return out
def wrap_assert_in_try_print(
    idx: int,
    call: str,
    output: str,
    requires_float: bool,
    print_formatter: Callable[[int], Tuple[str, str, List[Tuple[str, str]]]],
) -> str:
    """Wrap a single test case in a try/except that prints its outcome.

    The call is first turned into an assertion with ``convert_call_to_assert``
    and the result is rendered as::

        try:
            <statements of the call, ending in the assertion>
            print(<pass_str>)
        except AssertionError:
            print(<fail_str>)
        except <exception_type> as e:
            print(<print_string>)
        ...

    Args:
        idx: The index of the test case; only forwarded to ``print_formatter``.
        call: The call code.
        output: The expected output code.
        requires_float: Whether the expected output is a float.
        print_formatter: Called with ``idx``; must return
            ``(pass_str, fail_str, error_strs)`` where ``error_strs`` is a list
            of ``(exception_type, print_string)`` pairs, one extra ``except``
            clause per pair. All strings are inserted verbatim as the argument
            of ``print(...)``, so they must be valid Python expressions.

    Returns:
        The generated try/except statement as source code.
    """
    assertion_module = convert_call_to_assert(
        call=call, expected_output=output, requires_float=requires_float
    )
    pass_str, fail_str, error_strs = print_formatter(idx)

    source_parts = [
        f"try:\n\tprint({pass_str})\n",
        f"except AssertionError:\n\tprint({fail_str})\n",
    ]
    source_parts.extend(
        f"except {e_type} as e:\n\tprint({e_str})\n" for e_type, e_str in error_strs
    )
    try_node = ast.parse("".join(source_parts)).body[0]

    # The test's statements run first; the "pass" print only follows them.
    try_node.body = [*assertion_module.body, *try_node.body]
    return ast.unparse(ast.fix_missing_locations(try_node))
def remove_deep_trees(
    code_lines: List[str], tree_process_func: Callable, timeout: float
) -> List[str]:
    """Keep only the code strings that ``tree_process_func`` processes cleanly.

    Each entry of ``code_lines`` is passed to ``tree_process_func`` under
    ``time_limit(timeout)``. An entry is dropped when processing times out,
    hits ``RecursionError``, fails to parse (``SyntaxError``/``ValueError``),
    raises ``MemoryError``, or when ``tree_process_func`` returns ``None``.

    Args:
        code_lines: The code strings to filter.
        tree_process_func: Callable applied to each code string, e.g.
            ``default_tree_process_func``.
        timeout: Limit passed to ``time_limit`` for each entry.

    Returns:
        The surviving entries of ``code_lines``, in their original order.
    """
    out = []
    for code in code_lines:
        try:
            with time_limit(timeout):
                tree = tree_process_func(code)
        except (
            ContextTimeLimitException,
            RecursionError,
            SyntaxError,
            # ``ast.parse`` raises ValueError for source containing null bytes
            # on Python < 3.12; treat it as unparsable input, as
            # ``safe_ast_parse`` already does.
            ValueError,
            MemoryError,
        ):
            tree = None
        if tree is not None:
            out.append(code)
    return out
def _remove_deep_trees_worker(batch, tree_process_func, timeout):
    """Apply ``remove_deep_trees`` to every record of a batch.

    Each record is a dict with ``"idx"`` and ``"code"`` keys. The result is a
    list of new dicts with the same ``"idx"`` and the filtered ``"code"``.
    """
    return [
        {
            "idx": record["idx"],
            "code": remove_deep_trees(
                record["code"], tree_process_func, timeout=timeout
            ),
        }
        for record in batch
    ]
def default_tree_process_func(code: str):
    """Parse ``code`` and collect every node of the resulting AST.

    All nodes are gathered once via ``ast.walk`` before the module is
    returned. Exceptions raised by ``ast.parse`` (e.g. ``SyntaxError``)
    propagate to the caller.
    """
    module = ast.parse(code)
    _all_nodes = list(ast.walk(module))
    return module
def remove_trees_from_lists(
    codes: List,
    tree_process_func: Callable = default_tree_process_func,
    timeout=2,
    num_workers=4,
    batch_size=100,
    **parallel_kwargs,
) -> List[List[str]]:
    """Removes deep trees from each list of code strings, in parallel.

    Every element of ``codes`` is tagged with its position, the tagged
    records are split into batches of ``batch_size`` and each batch is
    filtered by ``_remove_deep_trees_worker`` through ``run_in_parallel``.
    The results are flattened and sorted back into input order by position.

    Args:
        codes: The lists of code strings to filter.
        tree_process_func: Forwarded to ``remove_deep_trees``.
        timeout: Per-entry limit forwarded to ``remove_deep_trees``.
        num_workers: Forwarded to ``run_in_parallel``.
        batch_size: Number of elements of ``codes`` per batch.
        **parallel_kwargs: Extra keyword arguments for ``run_in_parallel``.

    Returns:
        One filtered list per element of ``codes``, in the original order.
    """
    indexed = [{"idx": idx, "code": code_list} for idx, code_list in enumerate(codes)]
    batches = [
        indexed[start : start + batch_size]
        for start in range(0, len(indexed), batch_size)
    ]
    worker = functools.partial(
        _remove_deep_trees_worker,
        tree_process_func=tree_process_func,
        timeout=timeout,
    )
    # NOTE(review): presumably returns one result list per batch — confirm
    # against run_in_parallel.
    batch_results = run_in_parallel(
        worker,
        args=batches,
        num_workers=num_workers,
        desc="Removing Deep Trees",
        **parallel_kwargs,
    )
    flattened = [record for batch in batch_results for record in batch]
    flattened.sort(key=lambda record: record["idx"])
    return [record["code"] for record in flattened]