Source code for junjo.graph

from __future__ import annotations

import json

from .edge import Edge
from .node import Node
from .run_concurrent import RunConcurrent
from .store import BaseStore
from .workflow import _NestableWorkflow


class Graph:
    """
    Represents a directed graph of nodes and edges.

    Attributes:
        source: The entry node (or nested workflow) of the graph.
        sink: The terminal node (or nested workflow) of the graph.
        edges: The edges connecting nodes. Order matters: when several edges
            leave the same node, the first one that resolves wins.
    """

    def __init__(self, source: Node | _NestableWorkflow, sink: Node | _NestableWorkflow, edges: list[Edge]):
        self.source = source
        self.sink = sink
        self.edges = edges

    def _unique_nodes(self) -> dict[str, Node | _NestableWorkflow]:
        """
        Collects the top-level nodes of this graph keyed by their id.

        Nodes are gathered from the source, the sink, every edge tail and every
        edge head (in that order). Duplicated ids keep their first position.
        Contents of RunConcurrent / subflow nodes are not expanded here.
        """
        candidates = [self.source, self.sink]
        candidates += [edge.tail for edge in self.edges]
        candidates += [edge.head for edge in self.edges]
        return {node.id: node for node in candidates}

    async def get_next_node(self, store: BaseStore, current_node: Node | _NestableWorkflow) -> Node | _NestableWorkflow:
        """
        Retrieves the next node (or workflow / subflow) in the graph for the given current node.

        Edges whose tail equals ``current_node`` are tried in declaration order and
        the first one whose ``next_node(store)`` call yields a non-None value wins.
        Each edge is evaluated at most once, and edges after the winning one are
        not evaluated at all, so the returned node always matches the evaluation
        that selected it.

        Args:
            store (BaseStore): The store instance to use for resolving the next node.
            current_node (Node | _NestableWorkflow): The current node or subflow in the graph.

        Returns:
            Node | _NestableWorkflow: The next node or subflow in the graph.

        Raises:
            ValueError: If no edge leaving ``current_node`` resolves to a next node.
        """
        for edge in self.edges:
            if edge.tail == current_node:
                next_node = await edge.next_node(store)
                if next_node is not None:
                    return next_node

        raise ValueError(
            "Check your Graph. No resolved edges. "
            f"No valid transition found for node or subflow: '{current_node}'."
        )

    def serialize_to_json_string(self) -> str:  # noqa: C901
        """
        Converts the graph to a neutral serialized JSON string,
        representing RunConcurrent instances as subgraphs and including Subflow graphs as well.

        Output shape: ``{"v": 1, "nodes": [...], "edges": [...]}``. Main-graph edges
        have ``type == "explicit"``; edges found inside a subflow have
        ``type == "subflow"`` and carry the owning subflow's id in ``subflowId``.

        Returns:
            str: A JSON string containing the graph structure. If the structure
            cannot be serialized (``TypeError``), a JSON object with ``error`` and
            ``detail`` keys is returned instead.
        """
        all_nodes_dict: dict[str, Node | _NestableWorkflow] = {}  # Unique nodes found, keyed by id
        # edge_id -> (edge, owning subflow id, or None for a main-graph edge).
        # The owner is stored explicitly instead of being parsed back out of the
        # edge id, which would break for node ids containing "_edge_" or "subflow_".
        all_edges_dict: dict[str, tuple[Edge, str | None]] = {}
        processed_subflows: set[str] = set()  # Track processed subflows to avoid recursion loops

        # Recursive helper to find all nodes, including those inside RunConcurrent and Subflows
        def collect_nodes(node: Node | _NestableWorkflow | None) -> None:
            if node is None:
                return

            # Skip if not a Node or _NestableWorkflow or doesn't have an ID
            if not isinstance(node, (Node, _NestableWorkflow)) or not hasattr(node, "id"):
                print(f"Warning: Item '{node}' is not a valid Node or Workflow with an id, skipping collection.")
                return

            if node.id in all_nodes_dict:
                return  # Already collected (together with anything nested inside it)
            all_nodes_dict[node.id] = node

            # If it's a RunConcurrent, recursively collect the items it contains
            if isinstance(node, RunConcurrent) and hasattr(node, "items"):
                for run_concurrent_item in node.items:
                    collect_nodes(run_concurrent_item)

            # If it's a Subflow (inherits from _NestableWorkflow), recursively collect its graph
            elif (
                isinstance(node, _NestableWorkflow)
                and hasattr(node, "graph")
                and node.id not in processed_subflows
            ):
                processed_subflows.add(node.id)  # Mark as processed to avoid cycles

                subflow_graph = node.graph
                collect_nodes(subflow_graph.source)
                collect_nodes(subflow_graph.sink)

                for index, edge in enumerate(subflow_graph.edges):
                    edge_id = f"subflow_{node.id}_edge_{edge.tail.id}_{edge.head.id}"
                    # Parallel edges between the same two nodes would otherwise
                    # overwrite each other; only then is the index appended, so
                    # ids of non-colliding edges stay exactly as before.
                    if edge_id in all_edges_dict:
                        edge_id = f"{edge_id}_{index}"
                    all_edges_dict[edge_id] = (edge, str(node.id))
                    collect_nodes(edge.tail)
                    collect_nodes(edge.head)

        # Collect edges from the main graph first so they precede subflow edges in the output
        for i, edge in enumerate(self.edges):
            all_edges_dict[f"edge_{edge.tail.id}_{edge.head.id}_{i}"] = (edge, None)

        # Start node collection
        collect_nodes(self.source)
        collect_nodes(self.sink)
        for edge in self.edges:
            collect_nodes(edge.tail)
            collect_nodes(edge.head)

        # Create nodes list for JSON output
        nodes_json = []
        for node in all_nodes_dict.values():
            # Determine label: prioritize 'label', then 'name', then class name
            label = getattr(node, "label", None) or getattr(node, "name", None) or node.__class__.__name__

            node_info = {
                "id": node.id,
                "type": node.__class__.__name__,
                "label": label,
            }

            # Add subgraph representation for RunConcurrent
            if isinstance(node, RunConcurrent):
                node_info["isSubgraph"] = True
                node_info["children"] = [
                    item.id
                    for item in node.items
                    if isinstance(item, (Node, _NestableWorkflow)) and hasattr(item, "id")
                ]

            # Add subflow representation for Subflows
            elif isinstance(node, _NestableWorkflow) and hasattr(node, "graph"):
                node_info["isSubflow"] = True
                node_info["subflowSourceId"] = node.graph.source.id
                node_info["subflowSinkId"] = node.graph.sink.id

            nodes_json.append(node_info)

        # Create explicit edges list for JSON output
        edges_json = [
            {
                "id": edge_id,
                "source": str(edge.tail.id),
                "target": str(edge.head.id),
                "condition": str(edge.condition) if edge.condition else None,
                "type": "explicit" if subflow_id is None else "subflow",
                "subflowId": subflow_id,
            }
            for edge_id, (edge, subflow_id) in all_edges_dict.items()
        ]

        # Final graph dictionary structure
        graph_dict = {
            "v": 1,  # Schema version
            "nodes": nodes_json,
            "edges": edges_json,
        }

        try:
            return json.dumps(graph_dict, indent=2)
        except TypeError as e:
            # Best effort: report the failure as JSON rather than raising
            print(f"Error serializing graph to JSON: {e}")
            error_info = {
                "error": "Failed to serialize graph",
                "detail": str(e),
            }
            return json.dumps(error_info, indent=2)

    def to_mermaid(self) -> str:
        """
        Generates a basic Mermaid flowchart string from the top-level graph.

        Conditional edges are rendered with a quoted link label
        (``tail -->|"condition"| head``); unconditional edges as ``tail --> head``.
        RunConcurrent and subflow contents are not expanded.

        The junjo-server telemetry server will produce a proper mermaid diagram for
        the workflow executions.

        NOTE(review): node ids are emitted verbatim. Ids containing characters that
        Mermaid does not accept in identifiers may still fail to render - confirm
        against the id format used by Node.

        Returns:
            str: The Mermaid diagram source, terminated by a newline.
        """
        lines = ["graph LR"]

        # Add nodes
        for node_id, node in self._unique_nodes().items():
            node_label = node.__class__.__name__  # Or a custom label from node.name
            lines.append(f" {node_id}[{node_label}]")

        # Add edges
        for edge in self.edges:
            if edge.condition:
                # Quote the label and escape embedded quotes with Mermaid's entity code
                edge_label = str(edge.condition).replace('"', "#quot;")
                lines.append(f' {edge.tail.id} -->|"{edge_label}"| {edge.head.id}')
            else:
                lines.append(f" {edge.tail.id} --> {edge.head.id}")

        return "\n".join(lines) + "\n"

    def to_dot_notation(self) -> str:
        """
        Converts the top-level graph to DOT notation.

        Edges with a condition are drawn dashed and labelled with ``str(condition)``;
        edges without a condition are drawn solid with an empty label.
        RunConcurrent and subflow contents are not expanded.

        Returns:
            str: The DOT source for the graph, terminated by a newline.
        """
        lines = [
            "digraph G {",  # Start of DOT graph
            " node [shape=box, style=\"rounded\", fontsize=10];",  # Node styling
            " ranksep=0.5; nodesep=1.0;",  # Spacing between ranks and nodes
            " margin=1.0;",  # Graph margin
        ]

        # Add nodes
        for node_id, node in self._unique_nodes().items():
            node_label = node.__class__.__name__  # Or a custom label from node.name
            lines.append(f' "{node_id}" [label="{node_label}"];')

        # Add edges
        for edge in self.edges:
            # str(None) is the truthy string "None", so test the condition itself
            condition_str = str(edge.condition) if edge.condition else ""
            style = "dashed" if condition_str else "solid"  # Dashed for conditional, solid otherwise
            label = condition_str.replace('"', '\\"')  # Keep the quoted DOT string well-formed
            lines.append(f' "{edge.tail.id}" -> "{edge.head.id}" [label="{label}", style="{style}"];')

        lines.append("}")  # End of DOT graph
        return "\n".join(lines) + "\n"

    def to_graphviz(self) -> str:
        """
        Converts the graph to Graphviz format.
        This is a placeholder for future implementation.

        Raises:
            NotImplementedError: Always; the conversion is not implemented yet.
        """
        raise NotImplementedError("Graphviz conversion is not implemented yet.")