# Source code for junjo.workflow

from __future__ import annotations

from abc import ABC, abstractmethod
from types import NoneType
from typing import TYPE_CHECKING, Generic

from opentelemetry import trace

from .node import Node
from .run_concurrent import RunConcurrent
from .store import ParentStateT, ParentStoreT, StateT, StoreT
from .telemetry.hook_manager import HookManager
from .telemetry.otel_schema import JUNJO_OTEL_MODULE_NAME, JunjoOtelSpanTypes
from .util import generate_safe_id

if TYPE_CHECKING:
    from .graph import Graph

class _NestableWorkflow(Generic[StateT, StoreT, ParentStateT, ParentStoreT]):
    """
    Shared execution engine for top-level workflows and subflows.

    Holds a graph and a store, walks the graph from its source to its sink,
    and records the run on an OpenTelemetry span. The behaviour that differs
    for subflows (pre/post-run actions, span type) is selected at runtime with
    ``isinstance(self, Subflow)``.
    """

    def __init__(
        self,
        graph: Graph,
        store: StoreT,
        max_iterations: int = 100,
        hook_manager: HookManager | None = None,
        name: str | None = None,
    ):
        """
        Args:
            graph: The graph whose ``source``/``sink`` and transitions drive execution.
            store: The store instance holding this workflow's state.
            max_iterations: The maximum number of times a single node may be
                executed before a ``ValueError`` is raised (defaults to 100).
            hook_manager: Optional hook manager. It is stored on the instance;
                the hook call sites in ``execute`` are not wired up yet.
            name: Optional display name; falls back to the class name.
        """
        self._id = generate_safe_id()
        self._name = name
        self.graph = graph
        self.max_iterations = max_iterations
        # Maps an executable's id to the number of times it has been executed.
        self.node_execution_counter: dict[str, int] = {}
        self.hook_manager = hook_manager

        # Private stores (immutable interactions only)
        self._store = store

    @property
    def store(self) -> StoreT:
        """Returns the store instance used by this workflow."""
        return self._store

    @property
    def id(self) -> str:
        """Returns the unique identifier generated for this workflow instance."""
        return self._id

    @property
    def name(self) -> str:
        """Returns the explicit name if one was given, otherwise the class name."""
        if self._name is not None:
            return self._name

        return self.__class__.__name__

    @property
    def span_type(self) -> JunjoOtelSpanTypes:
        """Returns ``SUBFLOW`` for ``Subflow`` instances and ``WORKFLOW`` otherwise."""
        if isinstance(self, Subflow):
            return JunjoOtelSpanTypes.SUBFLOW
        return JunjoOtelSpanTypes.WORKFLOW

    async def get_state(self) -> StateT:
        """Returns the current state, as provided by the store."""
        return await self._store.get_state()

    async def get_state_json(self) -> str:
        """Returns the current state as a JSON string, as provided by the store."""
        return await self._store.get_state_json()

    def _record_execution(self, executable) -> None:
        """
        Counts one execution of ``executable`` and enforces ``max_iterations``.

        Args:
            executable: Any graph item exposing an ``id`` attribute.

        Raises:
            ValueError: If the item has now been executed more than
                ``max_iterations`` times.
        """
        count = self.node_execution_counter.get(executable.id, 0) + 1
        self.node_execution_counter[executable.id] = count
        if count > self.max_iterations:
            # Adjacent literals (not a backslash continuation) so the message
            # does not pick up the source indentation.
            raise ValueError(
                f"Node '{executable}' exceeded maximum execution count. "
                "Check for loops in your graph. Ensure it transitions to the sink node."
            )

    async def execute(  # noqa: C901
        self,
        parent_store: ParentStoreT | None = None,
        parent_id: str | None = None,
    ):
        """
        Executes the workflow from the graph's source until its sink has run.

        Args:
            parent_store: The store of the enclosing workflow. Required when
                this instance is a ``Subflow``.
            parent_id: The id of the enclosing workflow, recorded on the span
                when provided.

        Raises:
            ValueError: If a ``Subflow`` is executed without ``parent_store``,
                or if any node exceeds ``max_iterations`` executions.
            Exception: Any error raised while executing the graph is recorded
                on the span and re-raised to the caller.
        """
        print(f"Executing workflow: {self.name} with ID: {self.id}")

        # TODO: Test that the sink node can be reached
        # TODO: run the hook manager's before-workflow-execute hooks here.

        # Acquire a tracer (will be a real tracer if configured, otherwise no-op)
        tracer = trace.get_tracer(JUNJO_OTEL_MODULE_NAME)

        # Start a new span and keep a reference to the span object
        with tracer.start_as_current_span(self.name) as span:
            # Set span attributes
            span.set_attribute("junjo.workflow.state.start", await self.get_state_json())
            span.set_attribute("junjo.workflow.graph_structure", self.graph.serialize_to_json_string())
            span.set_attribute("junjo.workflow.store.id", self.store.id)
            span.set_attribute("junjo.span_type", self.span_type)
            span.set_attribute("junjo.id", self.id)

            # Set the parent ID and store ID if available (for subflows)
            if parent_id is not None:
                span.set_attribute("junjo.parent_id", parent_id)

            if parent_store is not None and parent_store.id is not None:
                span.set_attribute("junjo.workflow.parent_store.id", parent_store.id)

            # If executing a subflow, run pre-run actions
            if isinstance(self, Subflow):
                if parent_store is None:
                    raise ValueError("Subflow requires a parent store to execute pre_run_actions.")
                await self.pre_run_actions(parent_store)

            # Loop to execute the nodes inside this workflow
            current_executable = self.graph.source
            try:
                while True:
                    # TODO: run the hook manager's before-node-execute hooks here.

                    # If executing a subflow
                    if isinstance(current_executable, Subflow):
                        print("Executing subflow:", current_executable.name)

                        # Pass the current store as the parent store for the sub-flow
                        await current_executable.execute(self.store, self.id)

                        # Incorporate the Subflow's node count into the parent
                        # workflow's node execution counter. This is an assignment,
                        # not an increment: the subflow keeps its own running totals.
                        self.node_execution_counter[current_executable.id] = sum(
                            current_executable.node_execution_counter.values()
                        )

                    # If executing a node
                    if isinstance(current_executable, Node):
                        print("Executing node:", current_executable.name)
                        await current_executable.execute(self.store, self.id)

                        # TODO: run the hook manager's after-node-execute hooks here.

                        if isinstance(current_executable, RunConcurrent):
                            # Count each item of the concurrent group individually.
                            for item in current_executable.items:
                                self._record_execution(item)
                        else:
                            self._record_execution(current_executable)

                    # Break the loop if the current node is the final node.
                    if current_executable == self.graph.sink:
                        print("Sink has executed. Exiting loop.")
                        break

                    # Get the next executable in the workflow.
                    current_executable = await self.graph.get_next_node(self.store, current_executable)

                print(f"Completed workflow: {self.name} with ID: {self.id}")

                # Perform subflow post-run actions
                if isinstance(self, Subflow):
                    if parent_store is None:
                        raise ValueError("Subflow requires a parent store to execute post_run_actions.")
                    print("Performing post-run actions for subflow:", self.name)
                    await self.post_run_actions(parent_store)

            except Exception as e:
                print(f"Error executing workflow: {e}")
                span.set_status(trace.StatusCode.ERROR, str(e))
                span.record_exception(e)

                # Re-raise the error to be handled by the caller
                raise

            finally:
                execution_sum = sum(self.node_execution_counter.values())

                # Update attributes *after* the workflow loop completes (or errors)
                span.set_attribute("junjo.workflow.state.end", await self.get_state_json())
                span.set_attribute("junjo.workflow.node.count", execution_sum)

            # TODO: run the hook manager's after-workflow-execute hooks here.

# Class variations: the public top-level Workflow and the nestable Subflow.
class Workflow(_NestableWorkflow[StateT, StoreT, NoneType, NoneType]):
    """
    Represents a top level workflow that can be executed.

    Generic Type Parameters:
        | StateT: The type of state managed by this workflow
        | StoreT: The type of store used by this workflow

    A workflow is a collection of nodes and edges as a graph that can be executed.

    .. code-block:: python

        workflow = Workflow[MyGraphState, MyGraphStore](
            name="demo_base_workflow",
            graph=graph,
            store=graph_store,
            hook_manager=HookManager(verbose_logging=False, open_telemetry=True),
        )
        await workflow.execute()
    """
class Subflow(_NestableWorkflow[StateT, StoreT, ParentStateT, ParentStoreT], ABC):
    """
    Represents a subflow execution that can interact with a parent workflow.

    Generic Type Parameters:
        | StateT: The type of state managed by this subflow
        | StoreT: The type of store used by this subflow
        | ParentStateT: The type of state managed by the parent workflow
        | ParentStoreT: The type of store used by the parent workflow

    A subflow is a workflow that:
        | 1. Executes within a parent workflow
        | 2. Has its own isolated state and store
        | 3. Can interact with the parent workflow's state before and after execution

    .. code-block:: python

        class ExampleSubFlow(Subflow[SubflowState, SubflowStore, ParentState, ParentStore]):
            async def pre_run_actions(self, parent_store):
                parent_state = await parent_store.get_state()
                await self.store.set_parameter({
                    "parameter": parent_state.parameter
                })

            async def post_run_actions(self, parent_store):
                sub_flow_state = await self.get_state()
                await parent_store.set_subflow_result(self, sub_flow_state.result)
    """

    def __init__(
        self,
        graph: Graph,
        store: StoreT,
        max_iterations: int = 100,
        name: str | None = None,
    ):
        """
        Initializes the Subflow.

        Args:
            graph: The workflow graph.
            store: The store instance for this subflow.
            max_iterations: The maximum number of times a node can be executed
                before raising an exception (defaults to 100)
            name: Optional display name for the subflow; when omitted the
                class name is used.
        """
        # Subflows are always constructed without a hook manager.
        super().__init__(
            graph=graph,
            store=store,
            max_iterations=max_iterations,
            hook_manager=None,
            name=name,
        )

    @abstractmethod
    async def pre_run_actions(self, parent_store: ParentStoreT) -> None:
        """
        This method is called before the workflow has run.

        This is where you can pass initial state values from the parent
        workflow to the subflow state.

        In the class-level example, a parameter is passed from the parent
        store to the subflow store, using the subflow's `set_parameter`
        method, defined in the subflow's store.

        Args:
            parent_store: The parent store to interact with.
        """

    @abstractmethod
    async def post_run_actions(self, parent_store: ParentStoreT) -> None:
        """
        This method is called after the workflow has run.

        This is where you can update the parent store with the results of the
        workflow. This is useful for subflows that need to update the parent
        workflow store with their results.

        Args:
            parent_store: The parent store to update.
        """