# Source code for executorlib.backend.cache_parallel

import pickle
import sys
import time

import cloudpickle

from executorlib.standalone.error import backend_write_error_file
from executorlib.task_scheduler.file.backend import (
    backend_load_file,
    backend_write_file,
)


def main() -> None:
    """
    Execute one cached task across all MPI ranks and record its outcome.

    The file name is taken from the first command-line argument
    (``sys.argv[1]``). Rank 0 loads the task dictionary (with the keys
    ``"fn"``, ``"args"`` and ``"kwargs"``) via ``backend_load_file`` and
    broadcasts it. Every rank then calls ``fn(*args, **kwargs)``.

    With more than one rank, the per-rank outputs are gathered on rank 0 as a
    list ordered by rank. With a single rank, the output is used unchanged.

    Rank 0 passes either ``{"result": ...}`` or ``{"error": ...}`` to
    ``backend_write_file``, together with ``file_name`` and the elapsed
    runtime. It writes the error form if loading failed or the call raised on
    any rank. In the error case, the first error in rank order is also handed
    to ``backend_write_error_file``.

    Every rank takes part in each collective call (``bcast``, ``gather``,
    ``Barrier``) whether or not an exception occurred. A failure on any rank
    is therefore reported instead of leaving the other ranks blocked in a
    mismatched collective.

    Args:
        None

    Returns:
        None
    """
    from mpi4py import MPI

    # Serialize objects exchanged between ranks with cloudpickle, which can
    # handle objects (e.g. lambdas) that the standard pickle module cannot.
    MPI.pickle.__init__(  # type: ignore
        cloudpickle.dumps,
        cloudpickle.loads,
        pickle.HIGHEST_PROTOCOL,
    )
    comm = MPI.COMM_WORLD
    mpi_rank_zero = comm.Get_rank() == 0
    mpi_size_larger_one = comm.Get_size() > 1
    file_name = sys.argv[1]

    time_start = time.time()
    apply_dict: dict = {}
    error = None
    try:
        # Stage 1: rank 0 loads the task. A failure is recorded rather than
        # raised so that rank 0 still joins the broadcast below.
        if mpi_rank_zero:
            try:
                apply_dict = backend_load_file(file_name=file_name)
            except Exception as exc:
                error = exc
        # Broadcast None after a failed load so the other ranks skip the call.
        task = comm.bcast(apply_dict if error is None else None, root=0)

        # Stage 2: every rank runs the task and records its own failure.
        output = None
        if task is not None:
            apply_dict = task
            try:
                # Call directly rather than via ``fn.__call__``, which differs
                # when ``fn`` is a class defining an instance ``__call__``.
                output = apply_dict["fn"](
                    *apply_dict["args"], **apply_dict["kwargs"]
                )
            except Exception as exc:
                error = exc

        # Stage 3: collect (output, error) from every rank. All ranks enter
        # gather, even those that failed, so no rank is left waiting.
        outcome = (output, error)
        outcomes = (
            comm.gather(outcome, root=0) if mpi_size_larger_one else [outcome]
        )

        # Stage 4: rank 0 writes either the first error or the result.
        if mpi_rank_zero:
            errors = [err for _, err in outcomes if err is not None]
            if errors:
                first_error = errors[0]
                backend_write_file(
                    file_name=file_name,
                    output={"error": first_error},
                    runtime=time.time() - time_start,
                )
                backend_write_error_file(
                    error=first_error,
                    apply_dict=apply_dict,
                )
            else:
                result = (
                    [out for out, _ in outcomes] if mpi_size_larger_one else output
                )
                backend_write_file(
                    file_name=file_name,
                    output={"result": result},
                    runtime=time.time() - time_start,
                )
    finally:
        # Reach the barrier on every rank, even if writing raised on rank 0,
        # so the other ranks are not left blocked here.
        comm.Barrier()
if __name__ == "__main__":
    # Run the backend only when this module is executed directly, not on import.
    main()