import inspect
import json
import os.path
from concurrent.futures import Future
from typing import Optional
import cloudpickle
from executorlib.standalone.select import FutureSelector
[docs]
def generate_nodes_and_edges_for_plotting(
    task_hash_dict: dict, future_hash_inverse_dict: dict
) -> tuple[list, list]:
    """
    Generate nodes and edges for visualization.

    One "function" node is created per task. Every bound argument of a task is
    then linked to that node: futures become edges starting at the task that
    produces them, lists / dicts holding futures get a helper ``get_list`` /
    ``get_dict`` node, and every other value becomes an "input" node.

    Args:
        task_hash_dict (dict): Dictionary mapping task hash to task information.
            Each entry is read through the keys "fn", "args" and "kwargs". The
            dictionary and its entries are left unmodified.
        future_hash_inverse_dict (dict): Dictionary mapping future object to the
            hash of the task producing it (the hash is looked up in
            ``task_hash_dict``'s key space).

    Returns:
        Tuple[list, list]: Tuple containing the list of nodes and the list of edges.
    """
    node_lst: list = []
    edge_lst: list = []
    hash_id_dict: dict = {}
    # str(value) -> id of the "input" node already created for that value.
    # Maintained incrementally instead of being rebuilt from node_lst on
    # every plain argument.
    input_id_dict: dict = {}

    def bind_arguments(funct_dict):
        """Map the positional and keyword arguments of a task to parameter names."""
        sig = inspect.signature(funct_dict["fn"])
        return sig.bind(*funct_dict["args"], **funct_dict["kwargs"]).arguments

    def add_element(arg, link_to, label=""):
        """
        Add element to the node and edge lists.

        Args:
            arg: Argument to be added.
            link_to: ID of the node to link the element to.
            label (str, optional): Label for the edge. Defaults to "".
        """
        if isinstance(arg, FutureSelector):
            # Edge starts at the task behind the wrapped future; the selector
            # is recorded as the label of the source side.
            edge_lst.append(
                {
                    "start": hash_id_dict[future_hash_inverse_dict[arg._future]],
                    "end": link_to,
                    "label": label + str(arg._selector),
                    "end_label": label,
                    "start_label": str(arg._selector),
                }
            )
        elif isinstance(arg, Future):
            edge_lst.append(
                {
                    "start": hash_id_dict[future_hash_inverse_dict[arg]],
                    "end": link_to,
                    "label": label,
                }
            )
        elif isinstance(arg, list) and any(isinstance(a, Future) for a in arg):
            # Futures are shown as "$" placeholders in the helper node's name.
            lst_no_future = [a if not isinstance(a, Future) else "$" for a in arg]
            node_id = len(node_lst)
            node_lst.append(
                {
                    "name": str(lst_no_future),
                    "value": "python_workflow_definition.shared.get_list",
                    "id": node_id,
                    "type": "function",
                    "shape": "box",
                }
            )
            edge_lst.append({"start": node_id, "end": link_to, "label": label})
            # Only the futures of the list are linked, labelled by their index.
            for i, a in enumerate(arg):
                if isinstance(a, Future):
                    add_element(arg=a, link_to=node_id, label=str(i))
        elif isinstance(arg, dict) and any(isinstance(a, Future) for a in arg.values()):
            dict_no_future = {
                kt: vt if not isinstance(vt, Future) else "$" for kt, vt in arg.items()
            }
            node_id = len(node_lst)
            node_lst.append(
                {
                    "name": str(dict_no_future),
                    "value": "python_workflow_definition.shared.get_dict",
                    "id": node_id,
                    "type": "function",
                    "shape": "box",
                }
            )
            edge_lst.append({"start": node_id, "end": link_to, "label": label})
            # Every value of the dictionary is linked, labelled by its key.
            for kt, vt in arg.items():
                add_element(arg=vt, link_to=node_id, label=kt)
        else:
            # Plain values are de-duplicated by their str() representation:
            # values with an identical string form share one input node, which
            # keeps the name (label) of its first occurrence.
            value_key = str(arg)
            node_id = input_id_dict.get(value_key)
            if node_id is None:
                node_id = len(node_lst)
                input_id_dict[value_key] = node_id
                node_lst.append(
                    {
                        "name": label,
                        "value": arg,
                        "id": node_id,
                        "type": "input",
                        "shape": "circle",
                    }
                )
            edge_lst.append({"start": node_id, "end": link_to, "label": label})

    # Bind all signatures up front and keep them in a local dictionary so the
    # caller's task dictionaries are not extended with a "signature" key.
    signature_dict = {
        k: bind_arguments(funct_dict=v) for k, v in task_hash_dict.items()
    }

    # All task nodes are created first, so an edge can reference any task
    # regardless of the order in which the tasks are listed.
    for k, v in task_hash_dict.items():
        hash_id_dict[k] = len(node_lst)
        node_lst.append(
            {
                "name": v["fn"].__name__,
                "type": "function",
                "value": v["fn"].__module__ + "." + v["fn"].__name__,
                "id": hash_id_dict[k],
                "shape": "box",
            }
        )

    for k, arguments in signature_dict.items():
        for kw, v in arguments.items():
            add_element(arg=v, link_to=hash_id_dict[k], label=str(kw))
    return node_lst, edge_lst
[docs]
def generate_task_hash_for_plotting(task_dict: dict, future_hash_dict: dict) -> bytes:
    """
    Generate a hash for a task dictionary.

    Futures among the arguments are replaced by their hashes before the task
    is serialized with cloudpickle; lists and dictionaries are converted
    element by element.

    Args:
        task_dict (dict): Dictionary containing task information, read through
            the keys "fn", "args" and "kwargs".
        future_hash_dict (dict): Dictionary mapping future hash to future object.
            A hash is added to it for every FutureSelector argument that is
            not registered yet.

    Returns:
        bytes: Hash generated for the task dictionary.
    """
    # Inverse lookup: future object -> hash.
    hash_by_future = {future: key for key, future in future_hash_dict.items()}

    def to_hashable(obj):
        """
        Convert an argument to its hash representation.

        Args:
            obj: Argument to be converted.

        Returns:
            The hash representation of the argument.
        """
        if isinstance(obj, FutureSelector):
            if obj not in hash_by_future:
                # The key insertion order ("args", "kwargs", "fn") is part of
                # the serialized bytes and therefore must stay as it is.
                selector_dict = {
                    "args": (),
                    "kwargs": {
                        "future": hash_by_future[obj._future],
                        "selector": obj._selector,
                    },
                }
                selector_dict["fn"] = (
                    "get_item_from_future"
                    if isinstance(obj._selector, str)
                    else "split_future"
                )
                selector_hash = cloudpickle.dumps(selector_dict)
                # Register the selector in both directions so later
                # occurrences reuse the same hash.
                future_hash_dict[selector_hash] = obj
                hash_by_future[obj] = selector_hash
            return hash_by_future[obj]
        if isinstance(obj, Future):
            return hash_by_future[obj]
        if isinstance(obj, list):
            return [to_hashable(item) for item in obj]
        if isinstance(obj, dict):
            return {key: to_hashable(value) for key, value in obj.items()}
        return obj

    positional = [to_hashable(item) for item in task_dict["args"]]
    keyword = {key: to_hashable(value) for key, value in task_dict["kwargs"].items()}
    return cloudpickle.dumps(
        {"fn": task_dict["fn"], "args": positional, "kwargs": keyword}
    )
[docs]
def plot_dependency_graph_function(
    node_lst: list, edge_lst: list, filename: Optional[str] = None
):
    """
    Draw the graph visualization of nodes and edges.

    The graph is laid out with the "dot" program. With a file name the
    rendered graph is written to that file, using the file extension as
    output format; without one it is rendered as SVG and shown through
    IPython's display.

    Args:
        node_lst (list): List of nodes.
        edge_lst (list): List of edges.
        filename (str): Name of the file to store the plotted graph in.
    """
    import networkx as nx  # noqa

    graph = nx.DiGraph()
    for entry in node_lst:
        if entry["type"] != "input":
            graph.add_node(
                entry["id"], label=str(entry["name"]), shape=entry["shape"]
            )
            continue
        # Input nodes are labelled with a shortened form of their value.
        graph.add_node(
            entry["id"],
            label=_short_object_name(node=entry["value"]),
            shape=entry["shape"],
        )

    for link in edge_lst:
        graph.add_edge(link["start"], link["end"], label=link["label"])

    if filename is None:
        from IPython.display import SVG, display  # noqa

        svg_data = nx.nx_agraph.to_agraph(graph).draw(prog="dot", format="svg")
        display(SVG(svg_data))
        return

    # Extension without the leading dot selects the output format.
    suffix = os.path.splitext(filename)[-1][1:]
    with open(filename, "wb") as f:
        rendered = nx.nx_agraph.to_agraph(graph).draw(prog="dot", format=suffix)
        f.write(rendered)
[docs]
def export_dependency_graph_function(
    node_lst: list, edge_lst: list, file_name: str = "workflow.json"
):
    """
    Export the graph visualization of nodes and edges as a JSON dictionary.

    The written dictionary has the keys "version", "nodes" and "edges". An
    additional "output" node named "result" is appended and connected to the
    node with the highest id among all edge targets. If there are no edges,
    the function node with the highest id is used instead; if there is no
    function node either, the result node stays unconnected.

    Args:
        node_lst (list): List of nodes.
        edge_lst (list): List of edges.
        file_name (str): Name of the file to store the exported graph in.
    """
    import numpy as np

    pwd_nodes_lst = []
    for n in node_lst:
        if n["type"] == "function":
            pwd_nodes_lst.append(
                {"id": n["id"], "type": n["type"], "value": n["value"]}
            )
            continue
        value = n["value"]
        # NumPy arrays and NumPy scalars are converted to plain Python
        # objects, as the json module cannot serialize them directly.
        if n["type"] == "input" and isinstance(value, (np.ndarray, np.generic)):
            value = value.tolist()
        pwd_nodes_lst.append(
            {
                "id": n["id"],
                "type": n["type"],
                "value": value,
                "name": n["name"],
            }
        )

    # Edges carrying a "start_label" use separate labels for both ends; all
    # other edges only label the target side.
    pwd_edges_lst = [
        {
            "target": e["end"],
            "targetPort": e["end_label"] if "start_label" in e else e["label"],
            "source": e["start"],
            "sourcePort": e.get("start_label"),
        }
        for e in edge_lst
    ]

    result_candidates = [e["target"] for e in pwd_edges_lst]
    if not result_candidates:
        # Without edges max() would fail on an empty list, so fall back to
        # the function nodes.
        result_candidates = [
            n["id"] for n in pwd_nodes_lst if n["type"] == "function"
        ]

    final_node = {"id": len(pwd_nodes_lst), "type": "output", "name": "result"}
    pwd_nodes_lst.append(final_node)
    if result_candidates:
        pwd_edges_lst.append(
            {
                "target": final_node["id"],
                "targetPort": None,
                "source": max(result_candidates),
                "sourcePort": None,
            }
        )

    pwd_dict = {
        "version": "0.1.0",
        "nodes": pwd_nodes_lst,
        "edges": pwd_edges_lst,
    }
    # Serialize before opening the file, so a value that is not JSON
    # serializable does not leave a truncated file behind.
    serialized = json.dumps(pwd_dict, indent=4)
    with open(file_name, "w") as f:
        f.write(serialized)
def _short_object_name(node):
    """
    Build a short, human-readable label for a value.

    Tuples, lists and dictionaries are shortened element by element. For all
    other values the decision is based on ``str(node)``: representations that
    look like objects, functions or calls are reduced to ``Name()``,
    multi-line representations to the type name, and long single-line
    representations are cut after 21 characters with "..." appended.

    Args:
        node: Value to create the label for.

    Returns:
        str: Shortened label.
    """
    # Maximum number of characters returned before "..." is appended.
    max_length = 21
    node_value_str = str(node)
    if isinstance(node, tuple):
        short_name = str(tuple(_short_object_name(node=el) for el in node))
    elif isinstance(node, list):
        short_name = str([_short_object_name(node=el) for el in node])
    elif isinstance(node, dict):
        short_name = str(
            {
                _short_object_name(node=key): _short_object_name(node=value)
                for key, value in node.items()
            }
        )
    elif "object at" in node_value_str:
        # "<package.module.Class object at 0x...>" -> "Class()"
        short_name = (
            node_value_str[1:-1].split(maxsplit=1)[0].rsplit(".", maxsplit=1)[-1] + "()"
        )
    elif "<function" in node_value_str:
        # "<function name at 0x...>" -> "name()"
        short_name = node_value_str.split()[1] + "()"
    elif "\n" in node_value_str:
        # Multi-line representation: fall back to the bare type name.
        short_name = str(type(node)).split("'")[1].split(".")[-1] + "()"
    elif "(" in node_value_str and ")" in node_value_str:
        # "Name(arguments)" -> "Name()"
        short_name = node_value_str.split("(", maxsplit=1)[0] + "()"
    elif len(node_value_str) > max_length:
        # Only mark a string as shortened when characters are really dropped:
        # a string of exactly max_length characters is returned unchanged.
        short_name = node_value_str[:max_length] + "..."
    else:
        short_name = node_value_str
    return short_name