Source code for junjo.graph

from __future__ import annotations

import html
import json
import re
import subprocess
from collections.abc import Callable
from pathlib import Path

from .edge import Edge
from .node import Node
from .run_concurrent import RunConcurrent
from .store import BaseStore
from .workflow import _NestableWorkflow


[docs] class Graph: """ Represents a directed graph of nodes and edges. """ def __init__(self, source: Node | _NestableWorkflow, sink: Node | _NestableWorkflow, edges: list[Edge]): self.source = source self.sink = sink self.edges = edges
[docs] async def get_next_node(self, store: BaseStore, current_node: Node | _NestableWorkflow) -> Node | _NestableWorkflow: """ Retrieves the next node (or workflow / subflow) in the graph for the given current node. This method checks the edges connected to the current node and resolves the next node based on the conditions defined in the edges. Args: store (BaseStore): The store instance to use for resolving the next node. current_node (Node | _NestableWorkflow): The current node or subflow in the graph. Returns: Node | _NestableWorkflow: The next node or subflow in the graph. """ matching_edges = [edge for edge in self.edges if edge.tail == current_node] resolved_edges = [edge for edge in matching_edges if await edge.next_node(store) is not None] if len(resolved_edges) == 0: raise ValueError("Check your Graph. No resolved edges. " f"No valid transition found for node or subflow: '{current_node}'.") else: resolved_edge = await resolved_edges[0].next_node(store) if resolved_edge is None: raise ValueError("Check your Graph. Resolved edge is None. " f"No valid transition found for node or subflow: '{current_node}'") return resolved_edge
[docs] def serialize_to_json_string(self) -> str: # noqa: C901 """ Converts the graph to a neutral serialized JSON string, representing RunConcurrent instances as subgraphs and includes Subflow graphs as well. Returns: str: A JSON string containing the graph structure. """ all_nodes_dict: dict[str, Node | _NestableWorkflow] = {} # Dictionary to store unique nodes found all_edges_dict: dict[str, Edge] = {} # Dictionary to store all edges including subflow edges processed_subflows: set[str] = set() # Track processed subflows to avoid recursion loops # Recursive helper function to find all nodes, including those inside RunConcurrent and Subflows def collect_nodes(node: Node | _NestableWorkflow | None): if node is None: return # Skip if not a Node or _NestableWorkflow or doesn't have an ID if not (isinstance(node, Node) or isinstance(node, _NestableWorkflow)) or not hasattr(node, 'id'): print(f"Warning: Item '{node}' is not a valid Node or Workflow with an id, skipping collection.") return if node.id not in all_nodes_dict: all_nodes_dict[node.id] = node # If it's a RunConcurrent, recursively collect the items it contains if isinstance(node, RunConcurrent) and hasattr(node, 'items'): for run_concurrent_item in node.items: collect_nodes(run_concurrent_item) # If it's a Subflow (inherits from _NestableWorkflow), recursively collect its graph elif ( isinstance(node, _NestableWorkflow) and hasattr(node, 'graph') and node.id not in processed_subflows ): processed_subflows.add(node.id) # Mark as processed to avoid cycles # Collect subflow's source, sink and all nodes connected by edges subflow_graph = node.graph collect_nodes(subflow_graph.source) collect_nodes(subflow_graph.sink) # Collect all edges from the subflow for edge in subflow_graph.edges: # Create a unique ID for the subflow edge edge_id = f"subflow_{node.id}_edge_{edge.tail.id}_{edge.head.id}" all_edges_dict[edge_id] = edge collect_nodes(edge.tail) collect_nodes(edge.head) # Collect edges from the main graph for i, edge in enumerate(self.edges): edge_id = f"edge_{edge.tail.id}_{edge.head.id}_{i}" all_edges_dict[edge_id] = edge # Start node collection collect_nodes(self.source) collect_nodes(self.sink) for edge in self.edges: collect_nodes(edge.tail) collect_nodes(edge.head) # Create nodes list for JSON output nodes_json = [] for node_id, node in all_nodes_dict.items(): # Determine Label: Prioritize 'label', then 'name', then class name label = getattr(node, 'label', None) or \ getattr(node, 'name', None) or \ node.__class__.__name__ node_info = { "id": node.id, "type": node.__class__.__name__, "label": label } # Add subgraph representation for RunConcurrent if isinstance(node, RunConcurrent): node_info["isSubgraph"] = True children_ids = [ n.id for n in node.items if (isinstance(n, Node) or isinstance(n, _NestableWorkflow)) and hasattr(n, 'id') ] node_info["children"] = children_ids # Add subflow representation for Subflows elif isinstance(node, _NestableWorkflow) and hasattr(node, 'graph'): node_info["isSubflow"] = True node_info["subflowSourceId"] = node.graph.source.id node_info["subflowSinkId"] = node.graph.sink.id nodes_json.append(node_info) # Create explicit edges list for JSON output edges_json = [] for edge_id, edge in all_edges_dict.items(): # Determine if this is a subflow edge is_subflow_edge = edge_id.startswith("subflow_") subflow_id = None if is_subflow_edge: # Extract the subflow ID from the edge_id (between "subflow_" and "_edge_") subflow_id = edge_id.split("_edge_")[0].replace("subflow_", "") edges_json.append({ "id": edge_id, "source": str(edge.tail.id), "target": str(edge.head.id), "condition": str(edge.condition) if edge.condition else None, "type": "subflow" if is_subflow_edge else "explicit", "subflowId": subflow_id if is_subflow_edge else None }) # Final graph dictionary structure graph_dict = { "v": 1, # Schema version "nodes": nodes_json, "edges": edges_json } try: # Serialize the dictionary to a JSON string return json.dumps(graph_dict, indent=2) except TypeError as e: print(f"Error serializing graph to JSON: {e}") error_info = { "error": "Failed to serialize graph", "detail": str(e), } return json.dumps(error_info, indent=2)
[docs] def to_mermaid(self) -> str: """ Converts the graph to Mermaid syntax. This is a placeholder for future implementation. """ raise NotImplementedError("Mermaid conversion is not implemented yet.")
[docs] def to_dot_notation(self) -> str: # noqa: C901 (complexity fine for helper) """ Render the Junjo graph as a *main* overview digraph plus one additional digraph for **each Subflow**. Strategy -------- • In the **overview** we treat every Subflow node as an atomic component (shape=component, fillcolour light-yellow). • Any `RunConcurrent` node is rendered as a cluster, exactly like before. • For every Subflow we emit a *second* `digraph subflow_<id>` that expands its internal graph, again treating nested Subflows as atomic macro nodes (so the drill-down is recursive). """ graph = json.loads(self.serialize_to_json_string()) # ----------------------------------------------------------------------- # # Render helpers # # ----------------------------------------------------------------------- # def render_graph( graph_name: str, edge_filter: Callable[[dict], bool], ) -> str: """ Build a single digraph string with: * RunConcurrent → clusters * Subflow → macro node (no cluster) Only edges for which `edge_filter(edge)` is True are included. """ nodes_by_id = {n["id"]: n for n in graph["nodes"]} edges = [e for e in graph["edges"] if edge_filter(e)] # ---- gather node‑ids that really participate in *this* drawing ---- # node_ids: set[str] = set() for e in edges: node_ids.update((e["source"], e["target"])) # if a RunConcurrent appears, also pull in its children for nid in list(node_ids): n = nodes_by_id.get(nid) if n and n.get("isSubgraph"): node_ids.update(n["children"]) # --------------- meta for RunConcurrent clusters ------------------- # clusters = { n["id"]: n for n in graph["nodes"] if n.get("isSubgraph") and n["id"] in node_ids } entry_anchor = {cid: f"{cid}__entry" for cid in clusters} exit_anchor = {cid: f"{cid}__exit" for cid in clusters} def _anchor(nid: str, *, is_src: bool) -> str: if nid in clusters: return exit_anchor[nid] if is_src else entry_anchor[nid] return nid # ------------------------------------------------------------------- # out: list[str] = [] a = out.append a(f'digraph "{graph_name}" {{') a(" rankdir=LR;") a(" compound=true;") a(' node [shape=box, style="rounded,filled", fillcolor="#EFEFEF", ' 'fontname="Helvetica", fontsize=10];') a(' edge [fontname="Helvetica", fontsize=9];') # ---------------------- RunConcurrent clusters --------------------- # for cid, n in clusters.items(): a(f' subgraph "cluster_{cid}" {{') a(f' label="{self._safe_label(n["label"])} (Concurrent)";') a(' style="filled"; fillcolor="lightblue"; color="blue";') a(' node [fillcolor="lightblue", style="filled,rounded"];') # invisible entry/exit points a(f' "{entry_anchor[cid]}" [label="", shape=point, width=0.01, ' 'style=invis];') a(f' "{exit_anchor[cid]}" [label="", shape=point, width=0.01, ' 'style=invis];') for child_id in n["children"]: child = nodes_by_id[child_id] a(f' {self._q(child_id)} ' f'[label="{self._safe_label(child["label"])}"];') a(" }") # ------------------- ordinary & Subflow macro nodes ---------------- # for nid in node_ids: if nid in clusters: continue # already rendered inside its cluster n = nodes_by_id[nid] if n.get("isSubflow"): # macro representation a(f' {self._q(nid)} [label="{self._safe_label(n["label"])}", ' 'shape=component, style="filled,rounded", ' 'fillcolor="lightyellow"];') else: a(f' {self._q(nid)} [label="{self._safe_label(n["label"])}"];') # ------------------------------ edges ------------------------------ # for e in edges: src = _anchor(e["source"], is_src=True) tgt = _anchor(e["target"], is_src=False) attrs: list[str] = [] if e["source"] in clusters: attrs.append(f'ltail="cluster_{e["source"]}"') if e["target"] in clusters: attrs.append(f'lhead="cluster_{e["target"]}"') if e.get("condition"): attrs.extend( ('style="dashed"', f'label="{self._safe_label(e["condition"])}"') ) else: attrs.append('style="solid"') a(f' {self._q(src)} -> {self._q(tgt)} [{", ".join(attrs)}];') a("}") return "\n".join(out) # ----------------------------------------------------------------------- # # 1) overview graph (explicit edges only) # # ----------------------------------------------------------------------- # dot_parts: list[str] = [ render_graph( graph_name="G", edge_filter=lambda e: e["type"] == "explicit" ) ] # ----------------------------------------------------------------------- # # 2) one digraph per sub‑flow # # ----------------------------------------------------------------------- # subflows = [ n for n in graph["nodes"] if n.get("isSubflow") ] for sf in subflows: dot_parts.append( render_graph( graph_name=f'subflow_{sf["id"]}', edge_filter=lambda e, sid=sf["id"]: ( e["type"] == "subflow" and e["subflowId"] == sid ), ) ) # join with blank lines so graphviz treats them as separate digraphs return "\n\n".join(dot_parts)
[docs] def export_graphviz_assets( self, out_dir: str | Path = "graphviz_out", fmt: str = "svg", dot_cmd: str = "dot", open_html: bool = False, clean: bool = True, ) -> dict[str, Path]: """ Render every digraph produced by `to_dot_notation()` and build a gallery HTML page whose headings use the *human* labels (e.g. “SampleSubflow”) instead of raw digraph identifiers. Returns ------- Ordered mapping digraph_name → rendered file path, **in encounter order**. """ out_dir = Path(out_dir) out_dir.mkdir(parents=True, exist_ok=True) # --- delete old artefacts -------------------------------------- if clean: for p in out_dir.iterdir(): if p.suffix in (".dot", f".{fmt}") and p.is_file(): p.unlink() # ------------------------------------------------------------------ # # 1. Build "digraph name" → human‑readable label lookup # # ------------------------------------------------------------------ # label_lookup = {"G": "Overview"} # first graph json_graph = json.loads(self.serialize_to_json_string()) for node in json_graph["nodes"]: if node.get("isSubflow"): label_lookup[f"subflow_{node['id']}"] = node["label"] # ------------------------------------------------------------------ # # 2. Split the combined DOT text into individual blocks # # (order preserved) # # ------------------------------------------------------------------ # dot_text = self.to_dot_notation().lstrip() blocks: list[str] = re.split(r"\n(?=digraph )", dot_text) # ------------------------------------------------------------------ # # choose a safe filename stem # # ------------------------------------------------------------------ # def _fname(s: str) -> str: """Turn any string into a filesystem‑friendly stem.""" return re.sub(r"[^A-Za-z0-9_.-]", "_", s) top_stem = _fname(type(self).__name__ or "Overview") # e.g. MyWorkflow digraph_files: dict[str, Path] = {} # preserves order for block in blocks: m = re.match(r'digraph\s+"?([A-Za-z0-9_]+)"?', block) if not m: continue dgraph_id = m.group(1) # raw identifier: "G", "subflow_<id>", … # --------------------------------------------- # # decide the output filename stem # # --------------------------------------------- # if dgraph_id == "G": # the primary overview graph stem = top_stem else: # keep sub‑flow identifier for others stem = dgraph_id dot_path = out_dir / f"{stem}.dot" img_path = out_dir / f"{stem}.{fmt}" dot_path.write_text(block, encoding="utf-8") subprocess.run( [dot_cmd, "-T", fmt, str(dot_path), "-o", str(img_path)], check=True, ) # keep mapping by digraph‑identifier, not by filename digraph_files[dgraph_id] = img_path # ------------------------------------------------------------------ # # 3. Build index.html (respect encounter order) # # ------------------------------------------------------------------ # html_path = out_dir / "index.html" html_parts = [ "<!doctype html><html><head>", '<meta charset="utf-8"><title>Junjo Graphs</title>', "<style>body{font-family:Helvetica,Arial,sans-serif}" "img{max-width:100%;border:1px solid #ccc;margin-bottom:2rem}</style>", "</head><body>", "<h1>Junjo workflow diagrams</h1>", ] for name, img in digraph_files.items(): heading = html.escape(label_lookup.get(name, name)) html_parts.append(f"<h2>{heading}</h2>") html_parts.append(f'<img src="{img.name}" alt="{heading} diagram">') html_parts.append("</body></html>") html_path.write_text("\n".join(html_parts), encoding="utf-8") if open_html: import webbrowser webbrowser.open(html_path.as_uri()) return digraph_files
# --------------------------------------------------------------------------- # # Utility helpers # # --------------------------------------------------------------------------- # _ID_RX = re.compile(r"[A-Za-z_][A-Za-z0-9_]*") def _q(self, id_: str) -> str: """Quote a Graphviz identifier when needed.""" return id_ if self._ID_RX.fullmatch(id_) else f'"{id_}"' def _safe_label(self, text: str) -> str: """Escape quotes so they stay intact in dot files.""" return html.escape(str(text)).replace('"', r"\"")