"""Utility functions for code execution."""
import asyncio
import contextlib
import errno
import functools
import gc
import inspect
import io
import logging
import multiprocessing as mp
import os
import signal
import threading
import time
from pathlib import Path
from typing import Callable, Dict, Generator, List, Optional, Tuple
from tqdm import tqdm
from code_execution import utility_modules
logger = logging.getLogger(__name__)
def in_notebook():
    """Checks if the code is running in a notebook."""
    try:
        # pylint: disable=import-outside-toplevel
        # pylint: disable=import-error
        from IPython import get_ipython

        if "IPKernelApp" not in get_ipython().config:  # pragma: no cover
            return False
    except ImportError:
        return False
    except AttributeError:
        return False
    return True
class RunThread(threading.Thread):
    """Class that allows asyncio to run in a thread when called from Jupyter."""
    def __init__(self, func, *args, **kwargs):
        self.func = func
        self.args = args
        self.kwargs = kwargs
        self.result = None
        self.had_error = False
        super().__init__()
    def run(self):
        try:
            self.result = asyncio.run(self.func(*self.args, **self.kwargs))
        except Exception as e:
            self.result = e
            self.had_error = True
def notebook_safe_async_run(target, *args, **kwargs):
    """Run an async function in a thread."""
    if in_notebook():
        logger.info("Running in separate thread due to notebook.")
        thread = RunThread(target, *args, **kwargs)
        thread.start()
        thread.join()
        if thread.had_error:
            raise thread.result
        return thread.result
    logger.info("Running in main thread.")
    return asyncio.run(target(*args, **kwargs))
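
# Illustrative usage sketch (not part of the module): `notebook_safe_async_run`
# lets the same call work from a plain script and from Jupyter, where an event
# loop is already running. The coroutine below is a made-up example.
#
#     async def _double_later(x):
#         await asyncio.sleep(0.1)
#         return x * 2
#
#     value = notebook_safe_async_run(_double_later, 21)  # -> 42
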
def _batched_wrapper(batch, processor, proc_returns_list):
    """Wrapper for batched processing."""
    out = []
    for example in batch:
        idx = example["idx"]
        start = time.time()
        result = processor(*example["args"])
        elapsed = time.time() - start
        if proc_returns_list:
            out.extend([((idx, i), elapsed, r) for i, r in enumerate(result)])
        else:
            out.append(((idx, 0), elapsed, result))
    return out
def _normal_wrapper(arg_dict, processor, proc_returns_list):
    """Wrapper for normal processing."""
    start = time.time()
    result = processor(*arg_dict["args"])
    elapsed = time.time() - start
    if proc_returns_list:
        return [
            ((arg_dict["idx"], i), elapsed, r) for i, r in enumerate(result)
        ]
    return ((arg_dict["idx"], 0), elapsed, result)
def wrap_processor(
    processor_fn: Callable,
    batch_size: int,
    returns_list: bool,
) -> Callable:
    """Wraps a processor function to handle batching."""
    if batch_size > 1:
        logger.debug("Using batched processing with size %d", batch_size)
        wrapper = _batched_wrapper
    else:
        logger.debug("Using normal processing")
        wrapper = _normal_wrapper
    return functools.partial(
        wrapper, processor=processor_fn, proc_returns_list=returns_list
    )
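
# Illustrative usage sketch (not part of the module): the wrapped processor
# expects dicts of the form {"idx": ..., "args": ...} (or a list of them when
# batch_size > 1) and yields ((idx, sub_idx), elapsed_seconds, result) tuples.
# `square` here is a made-up processor.
#
#     def square(x):
#         return x * x
#
#     proc = wrap_processor(square, batch_size=1, returns_list=False)
#     proc({"idx": 3, "args": (4,)})  # -> ((3, 0), <elapsed>, 16)
#
#     batched = wrap_processor(square, batch_size=2, returns_list=False)
#     batched([{"idx": 0, "args": (2,)}, {"idx": 1, "args": (3,)}])
#     # -> [((0, 0), <elapsed>, 4), ((1, 0), <elapsed>, 9)]
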
def get_pred_dir(idx: int, parent: Path):
    """Gets the prediction directory for a prediction."""
    return parent.joinpath(f"pred{idx}")
class ContextTimeLimitException(Exception):
    """Timeout error for running commands."""
def timeout_signal_handler(signum, frame):
    raise ContextTimeLimitException(errno.ETIME)


# Timer-based timeout for Windows, where SIGALRM is not available.
class TimeoutContext:
    """Timeout context for platforms without SIGALRM (e.g. Windows).

    The timer only sets a flag; the timeout error is raised when the context
    exits rather than interrupting the running code.
    """
    def __init__(self, seconds, on_end=None):
        self.seconds = seconds
        self.timer = None
        self.error_raised = False
        self.on_end = on_end

    def _timeout_handler(self):
        self.error_raised = True

    def __enter__(self):
        self.timer = threading.Timer(self.seconds, self._timeout_handler)
        self.timer.start()

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.timer is not None:
            self.timer.cancel()
        if self.error_raised:
            raise ContextTimeLimitException(
                f"Operation timed out after {self.seconds} seconds"
            )
ON_WINDOWS = os.name == "nt"
def timeout_decorator(seconds: int = 10):
    """Decorator that enforces a time limit of ``seconds`` on the wrapped function."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if not ON_WINDOWS:
                signal.signal(signal.SIGALRM, timeout_signal_handler)
                signal.alarm(seconds)
                try:
                    result = func(*args, **kwargs)
                finally:
                    signal.alarm(0)
            else:
                with TimeoutContext(seconds):
                    result = func(*args, **kwargs)
            return result

        return wrapper

    return decorator
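
# Illustrative usage sketch (not part of the module): on POSIX the decorator
# relies on SIGALRM (main thread only) and interrupts the call; on Windows it
# falls back to TimeoutContext, which only raises once the wrapped call returns.
#
#     @timeout_decorator(seconds=2)
#     def slow():
#         time.sleep(10)
#
#     try:
#         slow()
#     except ContextTimeLimitException:
#         print("timed out")
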
@contextlib.contextmanager
def time_limit(seconds: float, on_end: Optional[Callable] = None):
    """Sets a time limit. Pass -1 to disable the limit."""
    if seconds == -1:
        yield
        return
    if ON_WINDOWS:
        with TimeoutContext(seconds, on_end):
            yield
    else:
        # Register the handler before starting the timer so the alarm cannot
        # fire without a handler in place.
        signal.signal(signal.SIGALRM, timeout_signal_handler)
        signal.setitimer(signal.ITIMER_REAL, seconds)
        try:
            yield
        finally:
            signal.setitimer(signal.ITIMER_REAL, 0)
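
# Illustrative usage sketch (not part of the module): on POSIX the block is
# interrupted when the limit expires; on Windows the flag-based fallback only
# raises after the block finishes, so a non-terminating body will still hang.
#
#     try:
#         with time_limit(1.5):
#             while True:
#                 pass
#     except ContextTimeLimitException:
#         print("timed out")
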
# Original code from https://github.com/openai/human-eval/
class WriteOnlyStringIO(io.StringIO):
    """StringIO that throws an exception when it's read from."""

    def read(self, *args, **kwargs):
        raise IOError

    def readline(self, *args, **kwargs):
        raise IOError

    def readlines(self, *args, **kwargs):
        raise IOError

    def readable(self, *args, **kwargs):
        """Returns False, since this stream cannot be read."""
        return False
class redirect_stdin(contextlib._RedirectStream):  # type: ignore
    _stream = "stdin"
@contextlib.contextmanager
def swallow_io():
    """Redirects stdout, stderr, and stdin to a write-only buffer."""
    stream = WriteOnlyStringIO()
    with contextlib.redirect_stdout(stream):
        with contextlib.redirect_stderr(stream):
            with redirect_stdin(stream):
                yield
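
# Illustrative usage sketch (not part of the module): any writes inside the
# context are discarded, and attempts to read from stdin raise IOError.
#
#     with swallow_io():
#         print("this output is never seen")
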
SUPPORTED_MODULES = {
    "safeguard": ("import safeguard", "safeguard.reliability_guard()"),
}
def get_module_and_call(module_name: str) -> Tuple[str, str, str]:
    """Gets the import code, call, and module source code for a module.

    Args:
        module_name (str): The name of the module to get.

    Returns:
        Tuple[str, str, str]: Import code, call, and module source code.
    """
    if module_name not in SUPPORTED_MODULES:
        raise ValueError(f"Module '{module_name}' is not supported.")
    import_call, call = SUPPORTED_MODULES[module_name]
    module_source = inspect.getsource(getattr(utility_modules, module_name))
    return import_call, call, module_source
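
# Illustrative usage sketch (not part of the module): one way the returned
# pieces for the "safeguard" module might be stitched into a program before
# execution. `user_code` is a hypothetical string of code to run.
#
#     import_code, call, module_source = get_module_and_call("safeguard")
#     # Write `module_source` to safeguard.py next to the program, then prepend:
#     program = f"{import_code}\n{call}\n" + user_code
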
def get_results_from_generator(
    generator: Generator,
    total: int,
    target_returns_multiple: bool,
    garbage_collect_freq: int,
    log_freq: int,
):
    """Gets the results from a generator.

    Args:
        generator (Generator): The generator to get results from.
        total (int): The total number of items in the generator.
        target_returns_multiple (bool): If the target returns multiple items
            per iteration, so that `.extend` is used instead of `.append`.
        garbage_collect_freq (int): How often to perform garbage collection.
        log_freq (int): How often to log if not using tqdm.

    Returns:
        Tuple[float, List]: The elapsed time in seconds and the results from
            the generator.
    """
    results = []
    # Track the number of completed items separately, since the size of
    # `results` does not always grow by one per iteration.
    start_time = time.time()
    num_completed = 0
    for r in generator:
        if target_returns_multiple:
            results.extend(r)
        else:
            results.append(r)
        num_completed += 1
        if num_completed % log_freq == 0:
            logger.debug(f"Finished {num_completed}/{total}")
        if num_completed % garbage_collect_freq == 0:
            gc.collect()
    elapsed = time.time() - start_time
    return elapsed, results
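
# Illustrative usage sketch (not part of the module): draining a plain
# generator; the frequencies are arbitrary values for the example.
#
#     elapsed, results = get_results_from_generator(
#         (x * x for x in range(5)),
#         total=5,
#         target_returns_multiple=False,
#         garbage_collect_freq=100,
#         log_freq=100,
#     )
#     # results -> [0, 1, 4, 9, 16]
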
def run_in_parallel(
    target: Callable,
    args: List,
    num_workers: int,
    desc: Optional[str] = None,
    max_tasks_per_process: Optional[int] = None,
    disable_tqdm: bool = False,
    garbage_collect_freq: int = 500,
    chunk_size: int = 1,
    log_freq: int = 500,
    target_returns_multiple: bool = False,
    tqdm_kwargs: Optional[Dict] = None,
) -> List:
    """Runs a function in parallel.

    Args:
        target (Callable): The function to run.
        args (List): The arguments to pass to the function.
        num_workers (int): The number of workers to use.
        desc (str): The description to use for the progress bar.
        max_tasks_per_process (Optional[int], optional): Maximum number of tasks
            before starting a new process. Defaults to None.
        disable_tqdm (bool, optional): Disable the progress bar. Defaults to False.
        garbage_collect_freq (int, optional): How often to perform garbage
            collection. Defaults to 500.
        chunk_size (int, optional): The chunk size to use for imap. Defaults to 1.
        log_freq (int, optional): How often to log if not using tqdm. Defaults
            to 500.
        target_returns_multiple (bool, optional): If the target returns multiple
            items, so that `.extend` is used instead of `.append`. Defaults to False.
        tqdm_kwargs (Optional[Dict], optional): Additional keyword arguments to
            pass to tqdm. Defaults to None.

    Returns:
        List: The results of `target(a)` for each `a` in `args`.
    """
    logger.debug(
        "Starting run_in_parallel for %s.",
        desc or getattr(target, "__name__", "Target"),
    )
    logger.debug("Will use %d/%d CPUs", num_workers, mp.cpu_count())
    generator_creator = functools.partial(
        tqdm,
        total=len(args),
        desc=desc,
        disable=disable_tqdm,
        **(tqdm_kwargs or {}),
    )
    num_workers = min(num_workers, len(args))
    # Yes, this is not entirely parallel, but it makes debugging so much easier.
    if num_workers > 1:
        logger.debug(
            "Running in parallel with %d workers (%s,%s)",
            num_workers,
            f"{max_tasks_per_process=}",
            f"{chunk_size=}",
        )
        with mp.Pool(
            processes=num_workers, maxtasksperchild=max_tasks_per_process
        ) as pool:
            pbar_generator = generator_creator(
                pool.imap(target, args, chunksize=chunk_size),
            )
            elapsed, results = get_results_from_generator(
                generator=pbar_generator,
                total=len(args),
                target_returns_multiple=target_returns_multiple,
                garbage_collect_freq=garbage_collect_freq,
                log_freq=log_freq,
            )
            # Cleanup pool
            pool.close()
            pool.terminate()
    else:
        logger.debug("Running in serial as num_workers=1")
        pbar_generator = generator_creator(map(target, args))
        elapsed, results = get_results_from_generator(
            generator=pbar_generator,
            total=len(args),
            target_returns_multiple=target_returns_multiple,
            garbage_collect_freq=garbage_collect_freq,
            log_freq=log_freq,
        )
    pbar_generator.close()
    logger.debug(f"Finished {desc} in {elapsed:.2f} seconds")
    return results
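
# Illustrative usage sketch (not part of the module): `target` must be
# picklable (a module-level function) since multiprocessing is used, and
# results come back in the same order as `args`. `cube` is a made-up target.
#
#     def cube(x):
#         return x ** 3
#
#     results = run_in_parallel(cube, args=[1, 2, 3], num_workers=2,
#                               desc="cubing")
#     # -> [1, 8, 27]
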
MEM_MAX_CODE = """__MAX_MEM = %%MEM_LIMIT%%
def _set_mem_limit():
import resource
import platform
if not __MAX_MEM:
return
resource.setrlimit(
resource.RLIMIT_AS, (__MAX_MEM, __MAX_MEM)
)
resource.setrlimit(
resource.RLIMIT_DATA, (__MAX_MEM, __MAX_MEM)
)
if not platform.uname().system == "Darwin":
resource.setrlimit(
resource.RLIMIT_STACK, (__MAX_MEM, __MAX_MEM)
)
_set_mem_limit()
"""
def get_mem_limit_code(mem_limit: Optional[str], trailing: str = "\n") -> str:
"""Gets the code to set the memory limit.
Args:
mem_limit (str): The memory limit value as a string. You can do
something like "4 * 1024" or "1024". If None, will return an
empty string.
trailing: The trailing characters to add to the code.
Returns:
str: The code to set the memory limit.
"""
if mem_limit is None:
return ""
out = MEM_MAX_CODE.replace("%%MEM_LIMIT%%", mem_limit)
return out + trailing
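
# Illustrative usage sketch (not part of the module): the returned snippet is
# meant to be prepended to generated code so the executed program caps its own
# address space before running. The 256 MiB value here is arbitrary.
#
#     header = get_mem_limit_code("256 * 1024 * 1024")
#     program = header + "print('hello')\n"
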