import abc
import asyncio
import inspect
from collections.abc import Awaitable, Callable
from types import NoneType
from typing import Generic, TypeVar

import jsonpatch
from opentelemetry import trace
from pydantic import ValidationError

from .state import BaseState
from .util import generate_safe_id

# ---------------------------------------------------------------------------
# Type variables for a store and the state model it manages.
# ---------------------------------------------------------------------------
StateT = TypeVar("StateT", bound=BaseState)
StoreT = TypeVar("StoreT", bound="BaseStore")

# ---------------------------------------------------------------------------
# Type variables for a parent state / store. The bounds admit NoneType, so
# these may also be parameterized with None.
# ---------------------------------------------------------------------------
ParentStateT = TypeVar("ParentStateT", bound="BaseState | NoneType")
ParentStoreT = TypeVar("ParentStoreT", bound="BaseStore | NoneType")

# A subscriber is called with the new state. It is either a plain function
# (returning None) or an async function (returning an Awaitable).
Subscriber = Callable[[StateT], None] | Callable[[StateT], Awaitable[None]]
class BaseStore(Generic[StateT], metaclass=abc.ABCMeta):
    """
    BaseStore represents a generic store for managing the state of a workflow.
    It is designed to be subclassed with a specific state type (Pydantic model).

    The store is responsible for:
    | - Managing the state of the workflow.
    | - Making immutable updates to the state safely in a concurrent environment.
    | - Validating state updates against the Pydantic model.
    | - Providing methods to subscribe to state changes.
    | - Notifying subscribers when the state changes.

    The store uses an asyncio.Lock so that state reads/updates and changes to the
    subscriber list are serialized across coroutines running on the same event
    loop. Note that asyncio.Lock is *not* thread-safe: a store instance must only
    be used from the event loop / thread it runs on.
    """

    def __init__(self, initial_state: StateT) -> None:
        """
        Args:
            initial_state: The initial state of the store, based on the Pydantic model.
        """
        # Serializes access to the state and the subscriber list between coroutines.
        self._lock = asyncio.Lock()
        # Unique ID for this store instance (reported as "junjo.store.id" in telemetry).
        self._id = generate_safe_id()
        # The current state of the store.
        self._state: StateT = initial_state
        # Each subscriber can be a synchronous or asynchronous callable.
        self._subscribers: list[Subscriber] = []

    @property
    def id(self) -> str:
        """Returns the unique identifier of a given store's implementation."""
        return self._id

    async def subscribe(self, listener: Subscriber) -> Callable[[], Awaitable[None]]:
        """
        Register a listener (sync or async callable) to be called whenever the state changes.

        Returns an *async* unsubscribe function that, when awaited, removes this listener.
        Awaiting it after the listener is already gone is a no-op. Each call removes at
        most one registration of the listener.
        """
        async with self._lock:
            self._subscribers.append(listener)

        async def unsubscribe() -> None:
            """
            Async function to remove the listener from the subscriber list.
            We lock again to ensure concurrency safety.
            """
            async with self._lock:
                if listener in self._subscribers:
                    self._subscribers.remove(listener)

        return unsubscribe

    async def get_state(self) -> StateT:
        """
        Return a shallow copy of the current state.

        The copy is made with ``model_copy()`` (no ``deep=True``). Reassigning a field on
        the returned object does not affect the store. Nested mutable values (lists,
        dicts, sub-models) are still shared with the store's state, so do not mutate
        them in place.
        """
        async with self._lock:
            return self._state.model_copy()

    async def get_state_json(self) -> str:
        """
        Return the current state as a JSON string.
        """
        async with self._lock:
            return self._state.model_dump_json()

    async def set_state(self, update: dict) -> None:
        """
        Update the store's state with a dictionary of changes.

        | - Validates the merged state (current state + `update`) against StateT.
        | - Applies the *validated* values for the updated fields, so Pydantic coercion
        |   (e.g. ``"5"`` -> ``5``, ``dict`` -> sub-model) is what gets stored.
        | - Immutable update: the new state is built from a deep copy of the current one.
        | - If there's a change, notifies subscribers outside the lock.
        | - Adds a ``set_state`` OpenTelemetry event (with a JSON patch of the change)
        |   to the current span when it is recording, even if nothing changed.

        Validation happens while holding the lock, against the state as it is at
        that moment. This prevents a concurrent update that lands while we wait
        for the lock from making the validation stale.

        Args:
            update: A dictionary of updates to apply to the state, keyed by field name.

        Raises:
            ValueError: If the merged state fails Pydantic validation. The original
                ``ValidationError`` is chained as the cause.

        .. code-block:: python

            class MessageWorkflowState(BaseState): # A pydantic model to represent the state
                received_message: Message

            class MessageWorkflowStore(BaseStore[MessageWorkflowState]): # A concrete store for MessageWorkflowState
                async def set_received_message(self, payload: Message) -> None:
                    await self.set_state({"received_message": payload})

            payload = Message(...)
            await store.set_received_message(payload) # Utilizes the set_state method to update a particular field
        """
        # Capture the caller's function and class names for telemetry/error messages.
        frame = inspect.currentframe()
        caller_frame = frame.f_back if frame is not None else None
        try:
            caller_function_name = (
                caller_frame.f_code.co_name if caller_frame is not None else "unknown action"
            )
            caller_class_name = "unknown store"
            if caller_frame is not None and "self" in caller_frame.f_locals:
                caller_class_name = caller_frame.f_locals["self"].__class__.__name__
        finally:
            # Drop frame references promptly to avoid keeping frames alive via cycles.
            del frame, caller_frame

        subscribers_to_notify: list[Subscriber] = []
        # JSON patch of the change; stays None when nothing changed.
        patch = None

        async with self._lock:
            state_cls = type(self._state)

            try:
                validated = state_cls.model_validate({**self._state.model_dump(), **update})
            except ValidationError as e:
                raise ValueError(
                    f"Invalid state update from caller {caller_class_name} -> {caller_function_name}.\n"
                    f"Check that you are updating a valid state property and type: {e}"
                ) from e

            # model_copy(update=...) does NOT validate its input, so hand it the
            # validated (coerced) values for declared fields. Keys that are not model
            # fields keep their raw value, matching the previous behavior.
            model_fields = state_cls.model_fields
            validated_update = {
                key: getattr(validated, key) if key in model_fields else value
                for key, value in update.items()
            }

            # deep=True copies the existing state, so previously handed-out snapshots
            # are never mutated by this update.
            new_state = self._state.model_copy(update=validated_update, deep=True)

            # Only commit and notify if something actually changed.
            if new_state != self._state:
                state_json_before = self._state.model_dump(mode="json")
                state_json_after = new_state.model_dump(mode="json")
                patch = jsonpatch.make_patch(state_json_before, state_json_after)

                self._state = new_state
                # Snapshot the subscriber list so it can be iterated after releasing the lock.
                subscribers_to_notify = list(self._subscribers)

        # --- OpenTelemetry Event (call even if nothing changed) --- #
        current_span = trace.get_current_span()
        if current_span.is_recording():
            current_span.add_event(
                name="set_state",
                attributes={
                    "id": generate_safe_id(),
                    "junjo.store.name": caller_class_name,
                    "junjo.store.id": self.id,
                    "junjo.store.action": caller_function_name,
                    "junjo.state_json_patch": patch.to_string() if patch else "{}",  # Empty if nothing changed
                },
            )
        # --- End OpenTelemetry Event --- #

        # Notify subscribers outside the lock so they may call back into the store.
        if subscribers_to_notify:
            await self._notify_subscribers(new_state, subscribers_to_notify)

    async def _notify_subscribers(self, new_state: StateT, subscribers: list[Subscriber]) -> None:
        """
        Private helper to call subscribers once the lock is released.

        Subscribers are called one at a time, in the order given. If a subscriber
        returns an awaitable (e.g. it is an async function), it is awaited before
        the next subscriber is called. An exception raised by a subscriber
        propagates to the caller, and the remaining subscribers are not called.
        """
        for subscriber in subscribers:
            result = subscriber(new_state)
            if inspect.isawaitable(result):
                await result