Source code for executorlib.task_scheduler.base

import contextlib
import queue
from concurrent.futures import (
    Executor as FutureExecutor,
)
from concurrent.futures import (
    Future,
)
from threading import Thread
from typing import Callable, Optional, Union

from executorlib.standalone.inputcheck import check_resource_dict
from executorlib.standalone.queue import cancel_items_in_queue
from executorlib.standalone.serialize import cloudpickle_register


[docs] def validate_resource_dict(resource_dict: dict): """ No-op resource dict validator used as the default when no validation is required. Args: resource_dict (dict): Dictionary of resource requirements (ignored). """ pass
[docs] class TaskSchedulerBase(FutureExecutor): """ Base class for the executor. Args: max_cores (int): defines the number cores which can be used in parallel """
[docs] def __init__( self, max_cores: Optional[int] = None, validator: Callable = validate_resource_dict, ): """ Initialize the TaskSchedulerBase. Args: max_cores (int, optional): Maximum number of cores available to the scheduler. Tasks requesting more cores than this will be rejected. Defaults to None (unlimited). validator (Callable): Function used to validate per-task resource dicts before submission. Defaults to the no-op validate_resource_dict. """ cloudpickle_register(ind=3) self._process_kwargs: dict = {} self._max_cores = max_cores self._future_queue: Optional[queue.Queue] = queue.Queue() self._process: Optional[Union[Thread, list[Thread]]] = None self._validator = validator
@property def max_workers(self) -> Optional[int]: """ Return the configured number of parallel workers, or None if unconstrained. Returns: Optional[int]: The max_workers value stored in process kwargs, or None. """ return self._process_kwargs.get("max_workers") @max_workers.setter def max_workers(self, max_workers: int): """ Setting max_workers after construction is not supported by the base scheduler. Raises: NotImplementedError: Always. """ raise NotImplementedError("The max_workers setter is not implemented.") @property def info(self) -> Optional[dict]: """ Get the information about the executor. Returns: Optional[dict]: Information about the executor. """ meta_data_dict = self._process_kwargs.copy() if "future_queue" in meta_data_dict: del meta_data_dict["future_queue"] if self._process is not None and isinstance(self._process, list): meta_data_dict["max_workers"] = len(self._process) return meta_data_dict elif self._process is not None: return meta_data_dict else: return None @property def future_queue(self) -> Optional[queue.Queue]: """ Get the future queue. Returns: queue.Queue: The future queue. """ return self._future_queue
[docs] def batched( self, iterable: list[Future], n: int, ) -> list[Future]: """ Batch futures from the iterable into tuples of length n. The last batch may be shorter than n. Args: iterable (list): list of future objects to batch based on which future objects finish first n (int): badge size Returns: list[Future]: list of future objects one for each batch """ raise NotImplementedError("The batched method is not implemented.")
[docs] def submit( # type: ignore self, fn: Callable, /, *args, resource_dict: Optional[dict] = None, **kwargs, ) -> Future: """ Submits a callable to be executed with the given arguments. Schedules the callable to be executed as fn(*args, **kwargs) and returns a Future instance representing the execution of the callable. Args: fn (callable): function to submit for execution args: arguments for the submitted function kwargs: keyword arguments for the submitted function resource_dict (dict): A dictionary of resources required by the task. With the following keys: - cores (int): number of MPI cores to be used for each function call - threads_per_core (int): number of OpenMP threads to be used for each function call - gpus_per_core (int): number of GPUs per worker - defaults to 0 - cwd (str/None): current working directory where the parallel python task is executed - openmpi_oversubscribe (bool): adds the `--oversubscribe` command line flag (OpenMPI and SLURM only) - default False - slurm_cmd_args (list): Additional command line arguments for the srun call (SLURM only) - error_log_file (str): Name of the error log file to use for storing exceptions raised by the Python functions submitted to the Executor. Returns: Future: A Future representing the given call. """ if resource_dict is None: resource_dict = {} self._validator(resource_dict=resource_dict) cores = resource_dict.get("cores") if ( cores is not None and self._max_cores is not None and cores > self._max_cores ): raise ValueError( "The specified number of cores is larger than the available number of cores." ) check_resource_dict(function=fn) f: Future = Future() if self._future_queue is not None: self._future_queue.put( { "fn": fn, "args": args, "kwargs": kwargs, "future": f, "resource_dict": resource_dict, } ) return f
[docs] def map( self, fn: Callable, *iterables, timeout: Optional[float] = None, chunksize: int = 1, ): """Returns an iterator equivalent to map(fn, iter). Args: fn: A callable that will take as many arguments as there are passed iterables. timeout: The maximum number of seconds to wait. If None, then there is no limit on the wait time. chunksize: The size of the chunks the iterable will be broken into before being passed to a child process. This argument is only used by ProcessPoolExecutor; it is ignored by ThreadPoolExecutor. Returns: An iterator equivalent to: map(func, *iterables) but the calls may be evaluated out-of-order. Raises: TimeoutError: If the entire result iterator could not be generated before the given timeout. Exception: If fn(*args) raises for any values. """ if isinstance(iterables, (list, tuple)) and any( isinstance(i, Future) for i in iterables ): iterables = tuple( i.result() if isinstance(i, Future) else i for i in iterables ) return super().map(fn, *iterables, timeout=timeout, chunksize=chunksize)
[docs] def shutdown(self, wait: bool = True, *, cancel_futures: bool = False): """ Clean-up the resources associated with the Executor. It is safe to call this method several times. Otherwise, no other methods can be called after this one. Args: wait (bool): If True then shutdown will not return until all running futures have finished executing and the resources used by the parallel_executors have been reclaimed. cancel_futures (bool): If True then shutdown will cancel all pending futures. Futures that are completed or running will not be cancelled. """ if cancel_futures and self._future_queue is not None: cancel_items_in_queue(que=self._future_queue) if self._process is not None and self._future_queue is not None: self._future_queue.put( {"shutdown": True, "wait": wait, "cancel_futures": cancel_futures} ) if isinstance(self._process, Thread): self._process.join() self._future_queue.join() self._process = None self._future_queue = None
def _set_process(self, process: Thread): """ Set the process for the executor. Args: process (RaisingThread): The process for the executor. """ self._process = process self._process.start() def __len__(self) -> int: """ Get the length of the executor. Returns: int: The length of the executor. """ queue_size = 0 if self._future_queue is not None: queue_size = self._future_queue.qsize() return queue_size def __del__(self): """ Clean-up the resources associated with the Executor. """ with contextlib.suppress(AttributeError, RuntimeError): self.shutdown(wait=False)