import queue
from concurrent.futures import Future
from threading import Thread
from time import sleep
from typing import Any, Callable, Optional
from executorlib.standalone.batched import batched_futures
from executorlib.standalone.interactive.arguments import (
check_exception_was_raised,
check_list_of_futures_is_done,
get_exception_lst,
get_future_objects_from_input,
update_futures_in_input,
)
from executorlib.task_scheduler.base import TaskSchedulerBase, validate_resource_dict
from executorlib.task_scheduler.interactive.dependency_plot import (
export_dependency_graph_function,
generate_nodes_and_edges_for_plotting,
generate_task_hash_for_plotting,
plot_dependency_graph_function,
)
class DependencyTaskScheduler(TaskSchedulerBase):
    """
    DependencyTaskScheduler is a class that extends TaskSchedulerBase and provides functionality for executing tasks
    with dependencies. Submitted tasks are handed to a background thread, which forwards them to the wrapped executor.

    Args:
        executor (TaskSchedulerBase): The task scheduler the tasks are forwarded to.
        max_cores (int, optional): Forwarded to the TaskSchedulerBase constructor. Defaults to None.
        refresh_rate (float, optional): The refresh rate for updating the executor queue. Defaults to 0.01.
        plot_dependency_graph (bool, optional): Whether to generate and plot the dependency graph. Defaults to False.
        plot_dependency_graph_filename (str, optional): Name of the file to store the plotted graph in.
        export_workflow_filename (str, optional): Name of the file to store the exported workflow graph in.
        validator (Callable): Called with the resource dictionary of every submitted task. Defaults to
            validate_resource_dict.

    Attributes:
        _future_hash_dict (Dict[str, Future]): A dictionary mapping task hash to future object.
        _task_hash_dict (Dict[str, Dict]): A dictionary mapping task hash to task dictionary.
        _generate_dependency_graph (bool): Whether to generate the dependency graph.
        _plot_dependency_graph_filename (str): Name of the file to store the plotted graph in.
        _export_workflow_filename (str): Name of the file to store the exported workflow graph in.
    """

    def __init__(
        self,
        executor: TaskSchedulerBase,
        max_cores: Optional[int] = None,
        refresh_rate: float = 0.01,
        plot_dependency_graph: bool = False,
        plot_dependency_graph_filename: Optional[str] = None,
        export_workflow_filename: Optional[str] = None,
        validator: Callable = validate_resource_dict,
    ) -> None:
        super().__init__(max_cores=max_cores, validator=validator)
        self._process_kwargs = {
            "future_queue": self._future_queue,
            "executor_queue": executor._future_queue,
            "executor": executor,
            "refresh_rate": refresh_rate,
        }
        self._set_process(
            Thread(
                target=_execute_tasks_with_dependencies,
                kwargs=self._process_kwargs,
            )
        )
        self._future_hash_dict: dict = {}
        self._task_hash_dict: dict = {}
        self._plot_dependency_graph_filename = plot_dependency_graph_filename
        self._export_workflow_filename = export_workflow_filename
        # Providing either output file name implies that the dependency graph has to be generated.
        if plot_dependency_graph_filename is None and export_workflow_filename is None:
            self._generate_dependency_graph = plot_dependency_graph
        else:
            self._generate_dependency_graph = True

    def _request_internal_task(self, task: str, **payload: Any) -> Optional[Future]:
        """
        Send an internal request to the background thread.

        Args:
            task (str): Name of the internal task, e.g. "get_info".
            **payload: Additional entries stored in the request dictionary.

        Returns:
            Optional[Future]: Future which receives the answer, or None if no queue is available.
        """
        if not isinstance(self._future_queue, queue.Queue):
            return None
        f: Future = Future()
        # The request is inserted at the front of the underlying deque rather than via put(), so it is picked up
        # before the tasks already waiting in the queue and does not change the unfinished-task counter.
        self._future_queue.queue.insert(
            0, {"internal": True, "task": task, **payload, "future": f}
        )
        return f

    @property
    def info(self) -> Optional[dict]:
        """
        Get the information about the executor.

        Blocks until the background thread has answered the request.

        Returns:
            Optional[dict]: Information about the executor, None if no queue is available.
        """
        f = self._request_internal_task("get_info")
        return None if f is None else f.result()

    @property
    def max_workers(self) -> Optional[int]:
        """
        Get the maximum number of workers of the executor.

        Blocks until the background thread has answered the request.

        Returns:
            Optional[int]: Maximum number of workers, None if no queue is available.
        """
        f = self._request_internal_task("get_max_workers")
        return None if f is None else f.result()

    @max_workers.setter
    def max_workers(self, max_workers: int):
        """
        Set the maximum number of workers of the executor.

        Args:
            max_workers (int): New maximum number of workers.

        Raises:
            NotImplementedError: If the request is answered with a falsy result.
        """
        f = self._request_internal_task("set_max_workers", max_workers=max_workers)
        if f is not None and not f.result():
            raise NotImplementedError("The max_workers setter is not implemented.")

    def submit(  # type: ignore
        self,
        fn: Callable[..., Any],
        *args: Any,
        resource_dict: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Future:
        """
        Submits a task to the executor.

        When the dependency graph is generated the task is not executed: it is only recorded for plotting and the
        returned future is already completed with the result None.

        Args:
            fn (Callable): The function to be executed.
            *args: Variable length argument list.
            resource_dict (dict, optional): A dictionary of resources required by the task. Defaults to {}.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            Future: A future object representing the result of the task.
        """
        if resource_dict is None:
            resource_dict = {}
        self._validator(resource_dict=resource_dict)
        if not self._generate_dependency_graph:
            f = super().submit(fn, *args, resource_dict=resource_dict, **kwargs)
        else:
            f = Future()
            f.set_result(None)
            task_dict = {
                "fn": fn,
                "args": args,
                "kwargs": kwargs,
                "future": f,
                "resource_dict": resource_dict,
            }
            task_hash = generate_task_hash_for_plotting(
                task_dict=task_dict,
                future_hash_dict=self._future_hash_dict,
            )
            self._future_hash_dict[task_hash] = f
            self._task_hash_dict[task_hash] = task_dict
        return f

    def batched(
        self,
        iterable: list[Future],
        n: int,
    ) -> list[Future]:
        """
        Batch futures from the iterable into batches of length n. The last batch may be shorter than n.

        Args:
            iterable (list): list of future objects to batch based on which future objects finish first
            n (int): batch size, has to be at least one

        Returns:
            list[Future]: list of future objects one for each batch

        Raises:
            ValueError: If n is smaller than one.
        """
        if n < 1:
            raise ValueError("n must be at least one")
        skip_lst: list[Future] = []
        future_lst: list[Future] = []
        number_of_batches = -(-len(iterable) // n)  # ceiling division
        for _ in range(number_of_batches):
            f: Future = Future()
            f_skip: Future = Future()
            if self._future_queue is not None:
                self._future_queue.put(
                    {
                        "fn": "batched",
                        "args": (),
                        "kwargs": {"lst": iterable, "n": n, "skip_lst": skip_lst},
                        "future": f,
                        "future_lst": iterable,
                        "future_skip": f_skip,
                        "resource_dict": {},
                    }
                )
            # Build a new list instead of appending in place: the task queued above keeps a reference to its own
            # skip_lst and must only see the skip futures of the batches created before it.
            skip_lst = [*skip_lst, f_skip]
            future_lst.append(f)
        return future_lst

    def __exit__(
        self,
        exc_type: Any,
        exc_val: Any,
        exc_tb: Any,
    ) -> None:
        """
        Exit method called when exiting the context manager.

        After the exit method of the parent class was called, the recorded dependency graph is exported to
        export_workflow_filename if it was provided and plotted otherwise. Nothing is generated when the dependency
        graph is disabled.

        Args:
            exc_type: The type of the exception.
            exc_val: The exception instance.
            exc_tb: The traceback object.
        """
        super().__exit__(exc_type=exc_type, exc_val=exc_val, exc_tb=exc_tb)  # type: ignore
        if not self._generate_dependency_graph:
            return None
        node_lst, edge_lst = generate_nodes_and_edges_for_plotting(
            task_hash_dict=self._task_hash_dict,
            future_hash_inverse_dict={v: k for k, v in self._future_hash_dict.items()},
        )
        if self._export_workflow_filename is not None:
            return export_dependency_graph_function(
                node_lst=node_lst,
                edge_lst=edge_lst,
                file_name=self._export_workflow_filename,
            )
        return plot_dependency_graph_function(
            node_lst=node_lst,
            edge_lst=edge_lst,
            filename=self._plot_dependency_graph_filename,
        )
def _execute_tasks_with_dependencies(
    future_queue: queue.Queue,
    executor_queue: queue.Queue,
    executor: "TaskSchedulerBase",
    refresh_rate: float = 0.01,
):
    """
    Resolve the dependencies of multiple tasks, by analysing which task requires concurrent.future.Futures objects from
    other tasks.

    The function loops until a shutdown message is received. Internal requests (get_info, get_max_workers,
    set_max_workers) are always answered through their future object: the result on success, the raised exception
    otherwise, so that the requesting thread is never left waiting for an answer.

    Args:
        future_queue (Queue): Queue for receiving new tasks.
        executor_queue (Queue): Queue for the internal executor.
        executor (TaskSchedulerBase): Executor to execute the tasks with after the dependencies are resolved.
        refresh_rate (float): Set the refresh rate in seconds, how frequently the input queue is checked.
    """
    future_dependency_lst: list = []
    while True:
        try:
            task_dict = future_queue.get_nowait()
        except queue.Empty:
            task_dict = None
        if task_dict is not None and task_dict.get("shutdown"):
            # Shutdown the executor, but first work off all tasks which are still waiting for their dependencies
            while len(future_dependency_lst) > 0:
                future_dependency_lst = _handle_future_dependencies(
                    future_dependency_lst=future_dependency_lst,
                    executor_queue=executor_queue,
                    refresh_rate=refresh_rate,
                )
            executor.shutdown(wait=task_dict["wait"])
            future_queue.task_done()
            future_queue.join()
            break
        elif task_dict is not None and task_dict.get("internal"):
            # Handle internal tasks for getting and setting information about the executor. No task_done() here:
            # these requests are inserted into the queue without put().
            try:
                result = _run_internal_task(task_dict=task_dict, executor=executor)
            except Exception as error:
                # Forward the failure to the requester instead of terminating this thread
                task_dict["future"].set_exception(error)
            else:
                task_dict["future"].set_result(result)
        elif task_dict is not None and "fn" in task_dict and "future" in task_dict:
            if task_dict["fn"] == "batched":
                # Batched tasks are always resolved from the wait list
                future_dependency_lst.append(task_dict)
            else:
                future_lst = get_future_objects_from_input(
                    args=task_dict["args"], kwargs=task_dict["kwargs"]
                )
                ready_flag = check_list_of_futures_is_done(future_lst=future_lst)
                exception_lst = get_exception_lst(future_lst=future_lst)
                if not check_exception_was_raised(future_obj=task_dict["future"]):
                    if len(exception_lst) > 0:
                        # A dependency failed, so the task fails with the same exception
                        task_dict["future"].set_exception(exception_lst[0])
                    elif len(future_lst) == 0 or ready_flag:
                        # No future objects are used in the input or all future objects are already done
                        task_dict["args"], task_dict["kwargs"] = update_futures_in_input(
                            args=task_dict["args"], kwargs=task_dict["kwargs"]
                        )
                        executor_queue.put(task_dict)
                    else:  # Otherwise add the function to the wait list
                        task_dict["future_lst"] = future_lst
                        future_dependency_lst.append(task_dict)
            future_queue.task_done()
        elif len(future_dependency_lst) > 0:
            # Check functions in the wait list and execute them if all future objects are now ready
            future_dependency_lst = _handle_future_dependencies(
                future_dependency_lst=future_dependency_lst,
                executor_queue=executor_queue,
                refresh_rate=refresh_rate,
            )
        else:
            # If there is nothing else to do, sleep for a moment
            sleep(refresh_rate)


def _run_internal_task(task_dict: dict, executor: "TaskSchedulerBase"):
    """
    Evaluate a single internal request on the executor.

    Args:
        task_dict (dict): Internal request, the entry "task" selects the operation.
        executor (TaskSchedulerBase): Executor the request is evaluated on.

    Returns:
        The executor info for "get_info", the number of workers for "get_max_workers" and for "set_max_workers" a
        bool, which is False when the executor raises NotImplementedError.

    Raises:
        ValueError: If the name of the internal task is unknown.
    """
    task_name = task_dict["task"]
    if task_name == "get_info":
        return executor.info
    if task_name == "get_max_workers":
        return executor.max_workers
    if task_name == "set_max_workers":
        try:
            executor.max_workers = task_dict["max_workers"]
        except NotImplementedError:
            return False
        return True
    raise ValueError(f"Unknown internal task: {task_name!r}")
def _handle_future_dependencies(
    future_dependency_lst: list[dict],
    executor_queue: queue.Queue,
    refresh_rate: float = 0.01,
) -> list:
    """
    Go once through the waiting tasks and forward those whose future inputs are completed to the executor.

    Regular tasks fail with the first exception of their dependencies, are submitted once all dependencies are done
    and keep waiting otherwise. Batched tasks are only evaluated once all of their skip futures are done. If no task
    left the wait list the function sleeps for refresh_rate seconds before returning.

    Args:
        future_dependency_lst (list): List of waiting tasks
        executor_queue (Queue): Queue of the internal executor
        refresh_rate (float): Set the refresh rate in seconds, how frequently the input queue is checked.

    Returns:
        list: list tasks which future inputs have not been completed
    """
    still_waiting_lst = []
    for entry in future_dependency_lst:
        failed_lst = get_exception_lst(future_lst=entry["future_lst"])
        if entry["fn"] != "batched":
            if len(failed_lst) > 0:
                # At least one dependency failed, propagate its exception to the waiting task
                entry["future"].set_exception(failed_lst[0])
            elif all(dependency.done() for dependency in entry["future_lst"]):
                # All dependencies are resolved: replace the futures by their results and submit the task
                del entry["future_lst"]
                resolved_args, resolved_kwargs = update_futures_in_input(
                    args=entry["args"], kwargs=entry["kwargs"]
                )
                entry["args"] = resolved_args
                entry["kwargs"] = resolved_kwargs
                executor_queue.put(entry)
            else:
                still_waiting_lst.append(entry)
            continue

        batch_kwargs = entry["kwargs"]
        if not all(skip.done() for skip in batch_kwargs["skip_lst"]):
            # The previous batches are not assigned yet, so this batch can not be selected
            still_waiting_lst.append(entry)
            continue

        success, done_lst = batched_futures(
            lst=batch_kwargs["lst"],
            n=batch_kwargs["n"],
            nested_skip_lst=batch_kwargs["skip_lst"],
        )
        if success and len(done_lst) == 0:
            # Nothing to hand out for this batch yet
            still_waiting_lst.append(entry)
        elif success:
            entry["future"].set_result([f.result() for f in done_lst])
            entry["future_skip"].set_result([id(f) for f in done_lst])
        else:
            entry["future"].set_exception(done_lst[0].exception())
            entry["future_skip"].set_result([id(f) for f in done_lst])

    # No progress in this pass: wait a moment rather than spinning
    if len(still_waiting_lst) == len(future_dependency_lst):
        sleep(refresh_rate)
    return still_waiting_lst