From 62c0e63f4d5091180ffc3cce0ca0fb735d185f0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Ahlert?= Date: Wed, 18 Mar 2026 07:57:42 -0300 Subject: [PATCH 1/4] fix: allow streaming actions to gracefully handle raised exceptions When a streaming action catches an exception and yields a final state in a try/except/finally block, the stream now completes gracefully instead of propagating the exception and killing the connection. If the generator yields a valid state_update before the exception propagates, the exception is suppressed and the stream terminates normally. If no state was yielded, the exception propagates as before. Closes #581 --- burr/core/application.py | 205 +++++++++++++++++++++------------------ 1 file changed, 113 insertions(+), 92 deletions(-) diff --git a/burr/core/application.py b/burr/core/application.py index dc8067c4b..44d4a462f 100644 --- a/burr/core/application.py +++ b/burr/core/application.py @@ -338,31 +338,35 @@ def _run_single_step_streaming_action( result = None state_update = None count = 0 - for item in generator: - if not isinstance(item, tuple): - # TODO -- consider adding support for just returning a result. - raise ValueError( - f"Action {action.name} must yield a tuple of (result, state_update). " - f"For all non-final results (intermediate)," - f"the state update must be None" - ) - result, state_update = item - count += 1 + try: + for item in generator: + if not isinstance(item, tuple): + # TODO -- consider adding support for just returning a result. + raise ValueError( + f"Action {action.name} must yield a tuple of (result, state_update). " + f"For all non-final results (intermediate)," + f"the state update must be None" + ) + result, state_update = item + count += 1 + if state_update is None: + if first_stream_start_time is None: + first_stream_start_time = system.now() + lifecycle_adapters.call_all_lifecycle_hooks_sync( + "post_stream_item", + item=result, + item_index=count, + stream_initialize_time=stream_initialize_time, + first_stream_item_start_time=first_stream_start_time, + action=action.name, + app_id=app_id, + partition_key=partition_key, + sequence_id=sequence_id, + ) + yield result, None + except Exception: if state_update is None: - if first_stream_start_time is None: - first_stream_start_time = system.now() - lifecycle_adapters.call_all_lifecycle_hooks_sync( - "post_stream_item", - item=result, - item_index=count, - stream_initialize_time=stream_initialize_time, - first_stream_item_start_time=first_stream_start_time, - action=action.name, - app_id=app_id, - partition_key=partition_key, - sequence_id=sequence_id, - ) - yield result, None + raise if state_update is None: raise ValueError( @@ -391,31 +395,36 @@ async def _arun_single_step_streaming_action( result = None state_update = None count = 0 - async for item in generator: - if not isinstance(item, tuple): - # TODO -- consider adding support for just returning a result. - raise ValueError( - f"Action {action.name} must yield a tuple of (result, state_update). " - f"For all non-final results (intermediate)," - f"the state update must be None" - ) - result, state_update = item + try: + async for item in generator: + if not isinstance(item, tuple): + # TODO -- consider adding support for just returning a result. + raise ValueError( + f"Action {action.name} must yield a tuple of (result, state_update). " + f"For all non-final results (intermediate)," + f"the state update must be None" + ) + result, state_update = item + if state_update is None: + if first_stream_start_time is None: + first_stream_start_time = system.now() + await lifecycle_adapters.call_all_lifecycle_hooks_sync_and_async( + "post_stream_item", + item=result, + item_index=count, + stream_initialize_time=stream_initialize_time, + first_stream_item_start_time=first_stream_start_time, + action=action.name, + app_id=app_id, + partition_key=partition_key, + sequence_id=sequence_id, + ) + count += 1 + yield result, None + except Exception: if state_update is None: - if first_stream_start_time is None: - first_stream_start_time = system.now() - await lifecycle_adapters.call_all_lifecycle_hooks_sync_and_async( - "post_stream_item", - item=result, - item_index=count, - stream_initialize_time=stream_initialize_time, - first_stream_item_start_time=first_stream_start_time, - action=action.name, - app_id=app_id, - partition_key=partition_key, - sequence_id=sequence_id, - ) - count += 1 - yield result, None + raise + if state_update is None: raise ValueError( f"Action {action.name} did not return a state update. For async actions, the last yield " @@ -450,28 +459,34 @@ def _run_multi_step_streaming_action( result = None first_stream_start_time = None count = 0 - for item in generator: - # We want to peek ahead so we can return the last one - # This is slightly eager, but only in the case in which we - # are using a multi-step streaming action - next_result = result - result = item - if next_result is not None: - if first_stream_start_time is None: - first_stream_start_time = system.now() - lifecycle_adapters.call_all_lifecycle_hooks_sync( - "post_stream_item", - item=next_result, - item_index=count, - stream_initialize_time=stream_initialize_time, - first_stream_item_start_time=first_stream_start_time, - action=action.name, - app_id=app_id, - partition_key=partition_key, - sequence_id=sequence_id, - ) - count += 1 - yield next_result, None + caught_exc = None + try: + for item in generator: + # We want to peek ahead so we can return the last one + # This is slightly eager, but only in the case in which we + # are using a multi-step streaming action + next_result = result + result = item + if next_result is not None: + if first_stream_start_time is None: + first_stream_start_time = system.now() + lifecycle_adapters.call_all_lifecycle_hooks_sync( + "post_stream_item", + item=next_result, + item_index=count, + stream_initialize_time=stream_initialize_time, + first_stream_item_start_time=first_stream_start_time, + action=action.name, + app_id=app_id, + partition_key=partition_key, + sequence_id=sequence_id, + ) + count += 1 + yield next_result, None + except Exception as e: + if result is None: + raise + caught_exc = e state_update = _run_reducer(action, state, result, action.name) _validate_result(result, action.name, action.schema) _validate_reducer_writes(action, state_update, action.name) @@ -494,28 +509,34 @@ async def _arun_multi_step_streaming_action( result = None first_stream_start_time = None count = 0 - async for item in generator: - # We want to peek ahead so we can return the last one - # This is slightly eager, but only in the case in which we - # are using a multi-step streaming action - next_result = result - result = item - if next_result is not None: - if first_stream_start_time is None: - first_stream_start_time = system.now() - await lifecycle_adapters.call_all_lifecycle_hooks_sync_and_async( - "post_stream_item", - item=next_result, - stream_initialize_time=stream_initialize_time, - item_index=count, - first_stream_item_start_time=first_stream_start_time, - action=action.name, - app_id=app_id, - partition_key=partition_key, - sequence_id=sequence_id, - ) - count += 1 - yield next_result, None + caught_exc = None + try: + async for item in generator: + # We want to peek ahead so we can return the last one + # This is slightly eager, but only in the case in which we + # are using a multi-step streaming action + next_result = result + result = item + if next_result is not None: + if first_stream_start_time is None: + first_stream_start_time = system.now() + await lifecycle_adapters.call_all_lifecycle_hooks_sync_and_async( + "post_stream_item", + item=next_result, + stream_initialize_time=stream_initialize_time, + item_index=count, + first_stream_item_start_time=first_stream_start_time, + action=action.name, + app_id=app_id, + partition_key=partition_key, + sequence_id=sequence_id, + ) + count += 1 + yield next_result, None + except Exception as e: + if result is None: + raise + caught_exc = e state_update = _run_reducer(action, state, result, action.name) _validate_result(result, action.name, action.schema) _validate_reducer_writes(action, state_update, action.name) From e66b67e5d89c6aa70f7707bc493b38391fb5e75b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Ahlert?= Date: Tue, 24 Mar 2026 09:36:41 -0300 Subject: [PATCH 2/4] fix: add logging for suppressed exceptions and tests for streaming graceful shutdown - Add logger.warning with exc_info in all 4 streaming except blocks - Remove dead caught_exc variable in multi-step functions - Fix count increment asymmetry between sync and async single-step - Add 8 tests covering graceful exception handling and propagation --- burr/core/application.py | 38 ++++- tests/core/test_application.py | 290 +++++++++++++++++++++++++++++++++ 2 files changed, 320 insertions(+), 8 deletions(-) diff --git a/burr/core/application.py b/burr/core/application.py index 44d4a462f..296faf36a 100644 --- a/burr/core/application.py +++ b/burr/core/application.py @@ -348,8 +348,8 @@ def _run_single_step_streaming_action( f"the state update must be None" ) result, state_update = item - count += 1 if state_update is None: + count += 1 if first_stream_start_time is None: first_stream_start_time = system.now() lifecycle_adapters.call_all_lifecycle_hooks_sync( @@ -364,9 +364,15 @@ def _run_single_step_streaming_action( sequence_id=sequence_id, ) yield result, None - except Exception: + except Exception as e: if state_update is None: raise + logger.warning( + "Streaming action '%s' raised %s after yielding %d items. " + "Proceeding with final state from generator cleanup. Original error: %s", + action.name, type(e).__name__, count, e, + exc_info=True, + ) if state_update is None: raise ValueError( @@ -406,6 +412,7 @@ async def _arun_single_step_streaming_action( ) result, state_update = item if state_update is None: + count += 1 if first_stream_start_time is None: first_stream_start_time = system.now() await lifecycle_adapters.call_all_lifecycle_hooks_sync_and_async( @@ -419,11 +426,16 @@ async def _arun_single_step_streaming_action( partition_key=partition_key, sequence_id=sequence_id, ) - count += 1 yield result, None - except Exception: + except Exception as e: if state_update is None: raise + logger.warning( + "Streaming action '%s' raised %s after yielding %d items. " + "Proceeding with final state from generator cleanup. Original error: %s", + action.name, type(e).__name__, count, e, + exc_info=True, + ) if state_update is None: raise ValueError( @@ -459,7 +471,6 @@ def _run_multi_step_streaming_action( result = None first_stream_start_time = None count = 0 - caught_exc = None try: for item in generator: # We want to peek ahead so we can return the last one @@ -486,7 +497,13 @@ def _run_multi_step_streaming_action( except Exception as e: if result is None: raise - caught_exc = e + logger.warning( + "Streaming action '%s' raised %s after yielding %d items. " + "Proceeding with last yielded result for reducer. " + "Note: the reducer will run on potentially partial data. Original error: %s", + action.name, type(e).__name__, count, e, + exc_info=True, + ) state_update = _run_reducer(action, state, result, action.name) _validate_result(result, action.name, action.schema) _validate_reducer_writes(action, state_update, action.name) @@ -509,7 +526,6 @@ async def _arun_multi_step_streaming_action( result = None first_stream_start_time = None count = 0 - caught_exc = None try: async for item in generator: # We want to peek ahead so we can return the last one @@ -536,7 +552,13 @@ async def _arun_multi_step_streaming_action( except Exception as e: if result is None: raise - caught_exc = e + logger.warning( + "Streaming action '%s' raised %s after yielding %d items. " + "Proceeding with last yielded result for reducer. " + "Note: the reducer will run on potentially partial data. Original error: %s", + action.name, type(e).__name__, count, e, + exc_info=True, + ) state_update = _run_reducer(action, state, result, action.name) _validate_result(result, action.name, action.schema) _validate_reducer_writes(action, state_update, action.name) diff --git a/tests/core/test_application.py b/tests/core/test_application.py index c90c40676..587c151f1 100644 --- a/tests/core/test_application.py +++ b/tests/core/test_application.py @@ -1275,6 +1275,296 @@ async def post_stream_item(self, item: Any, **future_kwargs: Any): assert len(hook.items) == 10 # one for each streaming callback +class SingleStepStreamingCounterWithException(SingleStepStreamingAction): + """Yields intermediate items, raises, then yields final state in finally block.""" + + def stream_run_and_update( + self, state: State, **run_kwargs + ) -> Generator[Tuple[dict, Optional[State]], None, None]: + count = state["count"] + try: + for i in range(3): + yield {"count": count + ((i + 1) / 10)}, None + raise RuntimeError("simulated failure") + finally: + yield {"count": count + 1}, state.update(count=count + 1).append(tracker=count + 1) + + @property + def reads(self) -> list[str]: + return ["count"] + + @property + def writes(self) -> list[str]: + return ["count", "tracker"] + + +class SingleStepStreamingCounterWithExceptionNoState(SingleStepStreamingAction): + """Raises without ever yielding a final state update.""" + + def stream_run_and_update( + self, state: State, **run_kwargs + ) -> Generator[Tuple[dict, Optional[State]], None, None]: + count = state["count"] + for i in range(3): + yield {"count": count + ((i + 1) / 10)}, None + raise RuntimeError("simulated failure with no state") + + @property + def reads(self) -> list[str]: + return ["count"] + + @property + def writes(self) -> list[str]: + return ["count", "tracker"] + + +class SingleStepStreamingCounterWithExceptionAsync(SingleStepStreamingAction): + """Async variant: yields intermediate items, raises, then yields final state in finally.""" + + async def stream_run_and_update( + self, state: State, **run_kwargs + ) -> AsyncGenerator[Tuple[dict, Optional[State]], None]: + count = state["count"] + try: + for i in range(3): + yield {"count": count + ((i + 1) / 10)}, None + raise RuntimeError("simulated failure") + finally: + yield {"count": count + 1}, state.update(count=count + 1).append(tracker=count + 1) + + @property + def reads(self) -> list[str]: + return ["count"] + + @property + def writes(self) -> list[str]: + return ["count", "tracker"] + + +class SingleStepStreamingCounterWithExceptionNoStateAsync(SingleStepStreamingAction): + """Async variant: raises without ever yielding a final state update.""" + + async def stream_run_and_update( + self, state: State, **run_kwargs + ) -> AsyncGenerator[Tuple[dict, Optional[State]], None]: + count = state["count"] + for i in range(3): + yield {"count": count + ((i + 1) / 10)}, None + raise RuntimeError("simulated failure with no state") + + @property + def reads(self) -> list[str]: + return ["count"] + + @property + def writes(self) -> list[str]: + return ["count", "tracker"] + + +class MultiStepStreamingCounterWithException(StreamingAction): + """Yields intermediate items, raises, then yields final result in finally block.""" + + def stream_run(self, state: State, **run_kwargs) -> Generator[dict, None, None]: + count = state["count"] + try: + for i in range(3): + yield {"count": count + ((i + 1) / 10)} + raise RuntimeError("simulated failure") + finally: + yield {"count": count + 1} + + @property + def reads(self) -> list[str]: + return ["count"] + + @property + def writes(self) -> list[str]: + return ["count", "tracker"] + + def update(self, result: dict, state: State) -> State: + return state.update(**result).append(tracker=result["count"]) + + +class MultiStepStreamingCounterWithExceptionNoResult(StreamingAction): + """Raises without ever yielding any item.""" + + def stream_run(self, state: State, **run_kwargs) -> Generator[dict, None, None]: + raise RuntimeError("simulated failure with no result") + yield # make this a generator function + + @property + def reads(self) -> list[str]: + return ["count"] + + @property + def writes(self) -> list[str]: + return ["count", "tracker"] + + def update(self, result: dict, state: State) -> State: + return state.update(**result).append(tracker=result["count"]) + + +class MultiStepStreamingCounterWithExceptionAsync(AsyncStreamingAction): + """Async variant: yields intermediate items, raises, then yields final result in finally.""" + + async def stream_run(self, state: State, **run_kwargs) -> AsyncGenerator[dict, None]: + count = state["count"] + try: + for i in range(3): + yield {"count": count + ((i + 1) / 10)} + raise RuntimeError("simulated failure") + finally: + yield {"count": count + 1} + + @property + def reads(self) -> list[str]: + return ["count"] + + @property + def writes(self) -> list[str]: + return ["count", "tracker"] + + def update(self, result: dict, state: State) -> State: + return state.update(**result).append(tracker=result["count"]) + + +class MultiStepStreamingCounterWithExceptionNoResultAsync(AsyncStreamingAction): + """Async variant: raises without ever yielding any item.""" + + async def stream_run(self, state: State, **run_kwargs) -> AsyncGenerator[dict, None]: + raise RuntimeError("simulated failure with no result") + yield # make this an async generator + + @property + def reads(self) -> list[str]: + return ["count"] + + @property + def writes(self) -> list[str]: + return ["count", "tracker"] + + def update(self, result: dict, state: State) -> State: + return state.update(**result).append(tracker=result["count"]) + + +def test__run_single_step_streaming_action_graceful_exception(): + """When the generator raises but yields a final state in finally, stream completes gracefully.""" + action = SingleStepStreamingCounterWithException().with_name("counter") + state = State({"count": 0, "tracker": []}) + generator = _run_single_step_streaming_action( + action, state, inputs={}, sequence_id=0, partition_key="pk", app_id="app" + ) + results = list(generator) + intermediate = [(r, s) for r, s in results if s is None] + final = [(r, s) for r, s in results if s is not None] + assert len(intermediate) == 3 + assert len(final) == 1 + assert final[0][0] == {"count": 1} + assert final[0][1].subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} + + +def test__run_single_step_streaming_action_exception_propagates(): + """When the generator raises without yielding a final state, exception propagates.""" + action = SingleStepStreamingCounterWithExceptionNoState().with_name("counter") + state = State({"count": 0, "tracker": []}) + generator = _run_single_step_streaming_action( + action, state, inputs={}, sequence_id=0, partition_key="pk", app_id="app" + ) + with pytest.raises(RuntimeError, match="simulated failure with no state"): + list(generator) + + +async def test__run_single_step_streaming_action_graceful_exception_async(): + """Async: when the generator raises but yields a final state in finally, stream completes.""" + action = SingleStepStreamingCounterWithExceptionAsync().with_name("counter") + state = State({"count": 0, "tracker": []}) + generator = _arun_single_step_streaming_action( + action=action, state=state, inputs={}, sequence_id=0, app_id="app", partition_key="pk", + lifecycle_adapters=LifecycleAdapterSet(), + ) + results = [] + async for item in generator: + results.append(item) + intermediate = [(r, s) for r, s in results if s is None] + final = [(r, s) for r, s in results if s is not None] + assert len(intermediate) == 3 + assert len(final) == 1 + assert final[0][0] == {"count": 1} + assert final[0][1].subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} + + +async def test__run_single_step_streaming_action_exception_propagates_async(): + """Async: when the generator raises without yielding a final state, exception propagates.""" + action = SingleStepStreamingCounterWithExceptionNoStateAsync().with_name("counter") + state = State({"count": 0, "tracker": []}) + generator = _arun_single_step_streaming_action( + action=action, state=state, inputs={}, sequence_id=0, app_id="app", partition_key="pk", + lifecycle_adapters=LifecycleAdapterSet(), + ) + with pytest.raises(RuntimeError, match="simulated failure with no state"): + async for _ in generator: + pass + + +def test__run_multi_step_streaming_action_graceful_exception(): + """When the generator raises but yields a final result in finally, stream completes.""" + action = MultiStepStreamingCounterWithException().with_name("counter") + state = State({"count": 0, "tracker": []}) + generator = _run_multi_step_streaming_action( + action, state, inputs={}, sequence_id=0, partition_key="pk", app_id="app" + ) + results = list(generator) + intermediate = [(r, s) for r, s in results if s is None] + final = [(r, s) for r, s in results if s is not None] + assert len(intermediate) == 3 + assert len(final) == 1 + assert final[0][0] == {"count": 1} + assert final[0][1].subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} + + +def test__run_multi_step_streaming_action_exception_propagates(): + """When the generator raises without yielding any result, exception propagates.""" + action = MultiStepStreamingCounterWithExceptionNoResult().with_name("counter") + state = State({"count": 0, "tracker": []}) + generator = _run_multi_step_streaming_action( + action, state, inputs={}, sequence_id=0, partition_key="pk", app_id="app" + ) + with pytest.raises(RuntimeError, match="simulated failure with no result"): + list(generator) + + +async def test__run_multi_step_streaming_action_graceful_exception_async(): + """Async: when the generator raises but yields a final result in finally, stream completes.""" + action = MultiStepStreamingCounterWithExceptionAsync().with_name("counter") + state = State({"count": 0, "tracker": []}) + generator = _arun_multi_step_streaming_action( + action=action, state=state, inputs={}, sequence_id=0, app_id="app", partition_key="pk", + lifecycle_adapters=LifecycleAdapterSet(), + ) + results = [] + async for item in generator: + results.append(item) + intermediate = [(r, s) for r, s in results if s is None] + final = [(r, s) for r, s in results if s is not None] + assert len(intermediate) == 3 + assert len(final) == 1 + assert final[0][0] == {"count": 1} + assert final[0][1].subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} + + +async def test__run_multi_step_streaming_action_exception_propagates_async(): + """Async: when the generator raises without yielding any result, exception propagates.""" + action = MultiStepStreamingCounterWithExceptionNoResultAsync().with_name("counter") + state = State({"count": 0, "tracker": []}) + generator = _arun_multi_step_streaming_action( + action=action, state=state, inputs={}, sequence_id=0, app_id="app", partition_key="pk", + lifecycle_adapters=LifecycleAdapterSet(), + ) + with pytest.raises(RuntimeError, match="simulated failure with no result"): + async for _ in generator: + pass + + class SingleStepActionWithDeletionAsync(SingleStepActionWithDeletion): async def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]: return {}, state.wipe(delete=["to_delete"]) From 84b439fade37c54775f1f2d2fcefb1aa8003931b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Ahlert?= Date: Sat, 28 Mar 2026 16:06:26 -0300 Subject: [PATCH 3/4] style: apply black formatting --- burr/core/application.py | 186 ++++++++++---- tests/core/test_application.py | 444 ++++++++++++++++++++++++++------- 2 files changed, 488 insertions(+), 142 deletions(-) diff --git a/burr/core/application.py b/burr/core/application.py index 296faf36a..cca6feccc 100644 --- a/burr/core/application.py +++ b/burr/core/application.py @@ -70,7 +70,12 @@ from burr.core.state import State from burr.core.typing import ActionSchema, DictBasedTypingSystem, TypingSystem from burr.core.validation import BASE_ERROR_MESSAGE -from burr.lifecycle.base import ExecuteMethod, LifecycleAdapter, PostRunStepHook, PreRunStepHook +from burr.lifecycle.base import ( + ExecuteMethod, + LifecycleAdapter, + PostRunStepHook, + PreRunStepHook, +) from burr.lifecycle.internal import LifecycleAdapterSet from burr.visibility import tracing from burr.visibility.tracing import tracer_factory_context_var @@ -89,7 +94,9 @@ StateTypeToSet = TypeVar("StateTypeToSet") -def _validate_result(result: Any, name: str, schema: ActionSchema = DEFAULT_SCHEMA) -> None: +def _validate_result( + result: Any, name: str, schema: ActionSchema = DEFAULT_SCHEMA +) -> None: # TODO -- split out the action schema into input/output schema types # Currently they're tied together, but this doesn't make as much sense for single-step actions result_type = schema.intermediate_result_type() @@ -158,7 +165,9 @@ def _remap_dunder_parameters( return inputs -def _run_function(function: Function, state: State, inputs: Dict[str, Any], name: str) -> dict: +def _run_function( + function: Function, state: State, inputs: Dict[str, Any], name: str +) -> dict: """Runs a function, returning the result of running the function. Note this restricts the keys in the state to only those that the function reads. @@ -178,7 +187,9 @@ def _run_function(function: Function, state: State, inputs: Dict[str, Any], name function.validate_inputs(inputs) if "__context" in inputs or "__tracer" in inputs: # potentially need to remap the __context & __tracer variables - inputs = _remap_dunder_parameters(function.run, inputs, ["__context", "__tracer"]) + inputs = _remap_dunder_parameters( + function.run, inputs, ["__context", "__tracer"] + ) result = function.run(state_to_use, **inputs) _validate_result(result, name) return result @@ -223,11 +234,14 @@ def _state_update(state_to_modify: State, modified_state: State) -> State: deleted_keys = [ item for item in ( - set(state_to_modify.keys()) - set(modified_state_without_private_fields.keys()) + set(state_to_modify.keys()) + - set(modified_state_without_private_fields.keys()) ) if not item.startswith("__") ] - return state_to_modify.merge(modified_state_without_private_fields).wipe(delete=deleted_keys) + return state_to_modify.merge(modified_state_without_private_fields).wipe( + delete=deleted_keys + ) def _validate_reducer_writes(reducer: Reducer, state: State, name: str) -> None: @@ -292,7 +306,9 @@ def _format_BASE_ERROR_MESSAGE(action: Action, input_state: State, inputs: dict) message += f"> Action: `{action.name}` encountered an error!" padding = " " * (80 - len(message) - 1) message += padding + "<" - message += "\n> State (at time of action):\n" + _create_dict_string(input_state.get_all()) + message += "\n> State (at time of action):\n" + _create_dict_string( + input_state.get_all() + ) message += "\n> Inputs (at time of action):\n" + _create_dict_string(inputs) border = "*" * 80 return "\n" + border + "\n" + message + "\n" + border @@ -370,7 +386,10 @@ def _run_single_step_streaming_action( logger.warning( "Streaming action '%s' raised %s after yielding %d items. " "Proceeding with final state from generator cleanup. Original error: %s", - action.name, type(e).__name__, count, e, + action.name, + type(e).__name__, + count, + e, exc_info=True, ) @@ -433,7 +452,10 @@ async def _arun_single_step_streaming_action( logger.warning( "Streaming action '%s' raised %s after yielding %d items. " "Proceeding with final state from generator cleanup. Original error: %s", - action.name, type(e).__name__, count, e, + action.name, + type(e).__name__, + count, + e, exc_info=True, ) @@ -501,7 +523,10 @@ def _run_multi_step_streaming_action( "Streaming action '%s' raised %s after yielding %d items. " "Proceeding with last yielded result for reducer. " "Note: the reducer will run on potentially partial data. Original error: %s", - action.name, type(e).__name__, count, e, + action.name, + type(e).__name__, + count, + e, exc_info=True, ) state_update = _run_reducer(action, state, result, action.name) @@ -556,7 +581,10 @@ async def _arun_multi_step_streaming_action( "Streaming action '%s' raised %s after yielding %d items. " "Proceeding with last yielded result for reducer. " "Note: the reducer will run on potentially partial data. Original error: %s", - action.name, type(e).__name__, count, e, + action.name, + type(e).__name__, + count, + e, exc_info=True, ) state_update = _run_reducer(action, state, result, action.name) @@ -686,9 +714,12 @@ def call_pre(self, app) -> bool: def call_post(self, app, exc) -> bool: if should_run_hooks := ( - app.uid in _run_call_var.get(dict) and _run_call_var.get()[app.uid] == self.method + app.uid in _run_call_var.get(dict) + and _run_call_var.get()[app.uid] == self.method ): - _run_call_var.set({k: v for k, v in _run_call_var.get().items() if k != app.uid}) + _run_call_var.set( + {k: v for k, v in _run_call_var.get().items() if k != app.uid} + ) app._adapter_set.call_all_lifecycle_hooks_sync( "post_run_execute_call", app_id=app.uid, @@ -713,9 +744,12 @@ async def acall_pre(self, app) -> bool: async def acall_post(self, app, exc) -> bool: if should_run_hooks := ( - app.uid in _run_call_var.get(dict) and _run_call_var.get()[app.uid] == self.method + app.uid in _run_call_var.get(dict) + and _run_call_var.get()[app.uid] == self.method ): - _run_call_var.set({k: v for k, v in _run_call_var.get().items() if k != app.uid}) + _run_call_var.set( + {k: v for k, v in _run_call_var.get().items() if k != app.uid} + ) await app._adapter_set.call_all_lifecycle_hooks_sync_and_async( "post_run_execute_call", app_id=app._uid, @@ -796,7 +830,9 @@ def post_run_step( **future_kwargs: Any, ): try: - tracer_factory_context_var.reset(self.token_pointer_map[(app_id, sequence_id)]) + tracer_factory_context_var.reset( + self.token_pointer_map[(app_id, sequence_id)] + ) except ValueError: # Token var = ContextVar created in a different context # This occurs when we're at the finally block of an async streaming action logger.debug( @@ -859,7 +895,9 @@ def __init__( ) self._state = state adapter_set = adapter_set if adapter_set is not None else LifecycleAdapterSet() - self._adapter_set = adapter_set.with_new_adapters(TracerFactoryContextHook(adapter_set)) + self._adapter_set = adapter_set.with_new_adapters( + TracerFactoryContextHook(adapter_set) + ) # TODO -- consider adding global inputs + global input factories to the builder self._tracker = tracker @@ -897,7 +935,9 @@ def __init__( # @telemetry.capture_function_usage # todo -- capture usage when we break this up into one that isn't called internally # This will be doable when we move sequence ID to the beginning of the function https://github.com/DAGWorks-Inc/burr/pull/73 @_call_execute_method_pre_post(ExecuteMethod.step) - def step(self, inputs: Optional[Dict[str, Any]] = None) -> Optional[Tuple[Action, dict, State]]: + def step( + self, inputs: Optional[Dict[str, Any]] = None + ) -> Optional[Tuple[Action, dict, State]]: """Performs a single step, advancing the state machine along. This returns a tuple of the action that was run, the result of running the action, and the new state. @@ -964,13 +1004,17 @@ def _step( result = _run_function( next_action, self._state, action_inputs, name=next_action.name ) - new_state = _run_reducer(next_action, self._state, result, next_action.name) + new_state = _run_reducer( + next_action, self._state, result, next_action.name + ) new_state = self._update_internal_state_value(new_state, next_action) self._set_state(new_state) except Exception as e: exc = e - logger.exception(_format_BASE_ERROR_MESSAGE(next_action, self._state, inputs)) + logger.exception( + _format_BASE_ERROR_MESSAGE(next_action, self._state, inputs) + ) raise e finally: if _run_hooks: @@ -1004,7 +1048,9 @@ def _update_internal_state_value( def _process_inputs(self, inputs: Dict[str, Any], action: Action) -> Dict[str, Any]: """Processes inputs, injecting the common inputs and ensuring that all required inputs are present.""" - starting_with_double_underscore = {key for key in inputs.keys() if key.startswith("__")} + starting_with_double_underscore = { + key for key in inputs.keys() if key.startswith("__") + } if len(starting_with_double_underscore) > 0: raise ValueError( BASE_ERROR_MESSAGE @@ -1117,12 +1163,16 @@ async def _astep(self, inputs: Optional[Dict[str, Any]], _run_hooks: bool = True inputs=action_inputs, name=next_action.name, ) - new_state = _run_reducer(next_action, self._state, result, next_action.name) + new_state = _run_reducer( + next_action, self._state, result, next_action.name + ) new_state = self._update_internal_state_value(new_state, next_action) self._set_state(new_state) except Exception as e: exc = e - logger.exception(_format_BASE_ERROR_MESSAGE(next_action, self._state, inputs)) + logger.exception( + _format_BASE_ERROR_MESSAGE(next_action, self._state, inputs) + ) raise e finally: if _run_hooks: @@ -1170,13 +1220,19 @@ def _process_control_flow_params( halt_after_actions, halt_after_tags = self._parse_action_list(halt_after) halt_before_expanded = set(halt_before_actions) for tag in halt_before_tags: - halt_before_expanded.update([item.name for item in self.graph.get_actions_by_tag(tag)]) + halt_before_expanded.update( + [item.name for item in self.graph.get_actions_by_tag(tag)] + ) halt_after_expanded = set(halt_after_actions) for tag in halt_after_tags: - halt_after_expanded.update([item.name for item in self.graph.get_actions_by_tag(tag)]) + halt_after_expanded.update( + [item.name for item in self.graph.get_actions_by_tag(tag)] + ) return list(halt_before_expanded), list(halt_after_expanded), inputs - def _validate_halt_conditions(self, halt_before: list[str], halt_after: list[str]) -> None: + def _validate_halt_conditions( + self, halt_before: list[str], halt_after: list[str] + ) -> None: """Utility function to validate halt conditions""" missing_actions = set(halt_before + halt_after) - set( action.name for action in self.graph.actions @@ -1341,7 +1397,9 @@ def run( :return: The final state, and the results of running the actions in the order that they were specified. """ self.validate_correct_async_use() - gen = self.iterate(halt_before=halt_before, halt_after=halt_after, inputs=inputs) + gen = self.iterate( + halt_before=halt_before, halt_after=halt_after, inputs=inputs + ) while True: try: next(gen) @@ -1387,7 +1445,9 @@ def stream_result( halt_after: list[str], halt_before: Optional[list[str]] = None, inputs: Optional[Dict[str, Any]] = None, - ) -> Tuple[Action, StreamingResultContainer[ApplicationStateType, Union[dict, Any]]]: + ) -> Tuple[ + Action, StreamingResultContainer[ApplicationStateType, Union[dict, Any]] + ]: """Streams a result out. :param halt_after: The list of actions/tags to halt after execution of. It will halt on the first one. @@ -1489,7 +1549,9 @@ def streaming_response(state: State, prompt: str) -> Generator[dict, None, Tuple print(format(result['response'], color)) """ self.validate_correct_async_use() - call_execute_method_wrapper = _call_execute_method_pre_post(ExecuteMethod.stream_result) + call_execute_method_wrapper = _call_execute_method_pre_post( + ExecuteMethod.stream_result + ) call_execute_method_wrapper.call_pre(self) halt_before, halt_after, inputs = self._process_control_flow_params( halt_before, halt_after, inputs @@ -1639,7 +1701,9 @@ async def astream_result( halt_after: list[str], halt_before: Optional[list[str]] = None, inputs: Optional[Dict[str, Any]] = None, - ) -> Tuple[Action, AsyncStreamingResultContainer[ApplicationStateType, Union[dict, Any]]]: + ) -> Tuple[ + Action, AsyncStreamingResultContainer[ApplicationStateType, Union[dict, Any]] + ]: """Streams a result out in an asynchronous manner. :param halt_after: The list of actions/tags to halt after execution of. It will halt on the first one. @@ -1743,7 +1807,9 @@ async def streaming_response(state: State, prompt: str) -> Generator[dict, None, result, state = await output.get() print(format(result['response'], color)) """ - call_execute_method_wrapper = _call_execute_method_pre_post(ExecuteMethod.stream_result) + call_execute_method_wrapper = _call_execute_method_pre_post( + ExecuteMethod.stream_result + ) await call_execute_method_wrapper.acall_pre(self) halt_before, halt_after, inputs = self._process_control_flow_params( halt_before, halt_after, inputs @@ -1816,7 +1882,9 @@ async def callback( if not next_action.streaming: # In this case we are halting at a non-streaming condition # This is allowed as we want to maintain a more consistent API - action, result, state = await self._astep(inputs=inputs, _run_hooks=False) + action, result, state = await self._astep( + inputs=inputs, _run_hooks=False + ) await self._adapter_set.call_all_lifecycle_hooks_sync_and_async( "post_run_step", app_id=self._uid, @@ -1914,7 +1982,9 @@ def stream_iterate( halt_before: Optional[Union[str, List[str]]] = None, inputs: Optional[Dict[str, Any]] = None, ) -> Generator[ - Tuple[Action, StreamingResultContainer[ApplicationStateType, Union[dict, Any]]], None, None + Tuple[Action, StreamingResultContainer[ApplicationStateType, Union[dict, Any]]], + None, + None, ]: """Produces an iterator that iterates through intermediate streams. You may want to use this in something like deep research mode in which: @@ -1958,7 +2028,11 @@ async def astream_iterate( halt_before: Optional[Union[str, List[str]]] = None, inputs: Optional[Dict[str, Any]] = None, ) -> AsyncGenerator[ - Tuple[Action, AsyncStreamingResultContainer[ApplicationStateType, Union[dict, Any]]], None + Tuple[ + Action, + AsyncStreamingResultContainer[ApplicationStateType, Union[dict, Any]], + ], + None, ]: """Async version of stream_iterate. Produces an async generator that iterates through intermediate streams. See stream_iterate for more details. @@ -2021,7 +2095,9 @@ def _set_state(self, new_state: State[ApplicationStateType]): self._state = new_state def get_next_action(self) -> Optional[Action]: - return self._graph.get_next_node(self._state.get(PRIOR_STEP), self._state, self.entrypoint) + return self._graph.get_next_node( + self._state.get(PRIOR_STEP), self._state, self.entrypoint + ) def update_state(self, new_state: State[ApplicationStateType]): """Updates state -- this is meant to be called if you need to do @@ -2373,7 +2449,9 @@ def with_actions( :return: The application builder for future chaining. """ self._initialize_graph_builder() - self.graph_builder = self.graph_builder.with_actions(*action_list, **action_dict) + self.graph_builder = self.graph_builder.with_actions( + *action_list, **action_dict + ) return self def with_transitions( @@ -2398,7 +2476,9 @@ def with_transitions( self.graph_builder = self.graph_builder.with_transitions(*transitions) return self - def with_hooks(self, *adapters: LifecycleAdapter) -> "ApplicationBuilder[StateType]": + def with_hooks( + self, *adapters: LifecycleAdapter + ) -> "ApplicationBuilder[StateType]": """Adds a lifecycle adapter to the application. This is a way to add hooks to the application so that they are run at the appropriate times. You can use this to synchronize state out, log results, etc... @@ -2453,7 +2533,9 @@ def with_tracker( if use_otel_tracing: from burr.integrations.opentelemetry import OpenTelemetryTracker - instantiated_tracker = OpenTelemetryTracker(burr_tracker=instantiated_tracker) + instantiated_tracker = OpenTelemetryTracker( + burr_tracker=instantiated_tracker + ) self.lifecycle_adapters.append(instantiated_tracker) self.tracker = instantiated_tracker return self @@ -2527,7 +2609,9 @@ def with_state_persister( if on_every != "step": raise ValueError(f"on_every {on_every} not supported") - self.state_persister = persister # tracks for later; validates in build / abuild + self.state_persister = ( + persister # tracks for later; validates in build / abuild + ) return self def with_spawning_parent( @@ -2574,7 +2658,9 @@ def _set_sync_state_persister(self): ) except NotImplementedError: pass - self.lifecycle_adapters.append(persistence.PersisterHook(self.state_persister)) + self.lifecycle_adapters.append( + persistence.PersisterHook(self.state_persister) + ) async def _set_async_state_persister(self): """Inits the asynchronous with_state_persister to save the state (local/DB/custom implementations). @@ -2599,7 +2685,9 @@ async def _set_async_state_persister(self): ) except NotImplementedError: pass - self.lifecycle_adapters.append(persistence.PersisterHookAsync(self.state_persister)) + self.lifecycle_adapters.append( + persistence.PersisterHookAsync(self.state_persister) + ) def _identify_state_to_load(self): """Helper to determine which state to load.""" @@ -2697,7 +2785,9 @@ def _load_from_sync_persister(self): # load state from persister load_result = self.state_initializer.load(_partition_key, _app_id, _sequence_id) - self._init_state_from_persister(load_result, _partition_key, _app_id, _sequence_id) + self._init_state_from_persister( + load_result, _partition_key, _app_id, _sequence_id + ) async def _load_from_async_persister(self): """Loads from the set async persister and into this current object. @@ -2718,8 +2808,12 @@ async def _load_from_async_persister(self): _partition_key, _app_id, _sequence_id = self._identify_state_to_load() # load state from persister - load_result = await self.state_initializer.load(_partition_key, _app_id, _sequence_id) - self._init_state_from_persister(load_result, _partition_key, _app_id, _sequence_id) + load_result = await self.state_initializer.load( + _partition_key, _app_id, _sequence_id + ) + self._init_state_from_persister( + load_result, _partition_key, _app_id, _sequence_id + ) def reset_to_entrypoint(self): self.state = self.state.wipe(delete=[PRIOR_STEP]) @@ -2737,7 +2831,9 @@ def _build_common(self) -> Application: graph = self._get_built_graph() _validate_start(self.start, {action.name for action in graph.actions}) typing_system: TypingSystem[StateType] = ( - self.typing_system if self.typing_system is not None else DictBasedTypingSystem() + self.typing_system + if self.typing_system is not None + else DictBasedTypingSystem() ) # type: ignore self.state = self.state.with_typing_system(typing_system=typing_system) return Application( diff --git a/tests/core/test_application.py b/tests/core/test_application.py index 587c151f1..d025f7f29 100644 --- a/tests/core/test_application.py +++ b/tests/core/test_application.py @@ -21,7 +21,17 @@ import logging import typing import uuid -from typing import Any, Awaitable, Callable, Dict, Generator, Literal, Optional, Tuple, Union +from typing import ( + Any, + Awaitable, + Callable, + Dict, + Generator, + Literal, + Optional, + Tuple, + Union, +) import pytest @@ -163,7 +173,9 @@ async def run(self, state: State, **run_kwargs) -> dict: ) -class StreamEventCaptureTracker(PreStartStreamHook, PostStreamItemHook, PostEndStreamHook): +class StreamEventCaptureTracker( + PreStartStreamHook, PostStreamItemHook, PostEndStreamHook +): def post_end_stream( self, *, @@ -440,7 +452,9 @@ def test__run_function_with_inputs(): """Tests that we can run a function""" action = base_counter_action_with_inputs state = State({}) - result = _run_function(action, state, inputs={"additional_increment": 1}, name=action.name) + result = _run_function( + action, state, inputs={"additional_increment": 1}, name=action.name + ) assert result == {"count": 2} @@ -596,7 +610,9 @@ class BrokenAction(SingleStepAction): def reads(self) -> list[str]: return [] - async def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]: + async def run_and_update( + self, state: State, **run_kwargs + ) -> Tuple[dict, State]: await asyncio.sleep(0.0001) # just so we can make this *truly* async return {"present_value": 1}, state.update(present_value=1) @@ -630,7 +646,12 @@ def writes(self) -> list[str]: state = State() with pytest.raises(ValueError, match="missing_value"): gen = _run_single_step_streaming_action( - action, state, inputs={}, sequence_id=0, partition_key="partition_key", app_id="app_id" + action, + state, + inputs={}, + sequence_id=0, + partition_key="partition_key", + app_id="app_id", ) collections.deque(gen, maxlen=0) # exhaust the generator @@ -687,7 +708,12 @@ def writes(self) -> list[str]: state = State() with pytest.raises(ValueError, match="missing_value"): gen = _run_multi_step_streaming_action( - action, state, inputs={}, sequence_id=0, partition_key="partition_key", app_id="app_id" + action, + state, + inputs={}, + sequence_id=0, + partition_key="partition_key", + app_id="app_id", ) collections.deque(gen, maxlen=0) # exhaust the generator @@ -774,7 +800,9 @@ def update(self, result: dict, state: State) -> State: class AsyncStreamingCounter(AsyncStreamingAction): - async def stream_run(self, state: State, **run_kwargs) -> AsyncGenerator[dict, None]: + async def stream_run( + self, state: State, **run_kwargs + ) -> AsyncGenerator[dict, None]: if "steps_per_count" in run_kwargs: steps_per_count = run_kwargs["granularity"] else: @@ -806,7 +834,9 @@ def stream_run_and_update( count = state["count"] for i in range(steps_per_count): yield {"count": count + ((i + 1) / 10)}, None - yield {"count": count + 1}, state.update(count=count + 1).append(tracker=count + 1) + yield {"count": count + 1}, state.update(count=count + 1).append( + tracker=count + 1 + ) @property def reads(self) -> list[str]: @@ -827,7 +857,9 @@ async def stream_run_and_update( await asyncio.sleep(0.01) yield {"count": count + ((i + 1) / 10)}, None await asyncio.sleep(0.01) - yield {"count": count + 1}, state.update(count=count + 1).append(tracker=count + 1) + yield {"count": count + 1}, state.update(count=count + 1).append( + tracker=count + 1 + ) @property def reads(self) -> list[str]: @@ -856,7 +888,9 @@ def update(self, result: dict, state: State) -> State: class StreamingActionIncorrectResultTypeAsync(AsyncStreamingAction): - async def stream_run(self, state: State, **run_kwargs) -> AsyncGenerator[dict, None]: + async def stream_run( + self, state: State, **run_kwargs + ) -> AsyncGenerator[dict, None]: yield {} yield "not a dict" @@ -916,7 +950,9 @@ def writes(self) -> list[str]: base_streaming_single_step_counter_async = SingleStepStreamingCounterAsync() base_single_step_action_incorrect_result_type = SingleStepActionIncorrectResultType() -base_single_step_action_incorrect_result_type_async = SingleStepActionIncorrectResultTypeAsync() +base_single_step_action_incorrect_result_type_async = ( + SingleStepActionIncorrectResultTypeAsync() +) def test__run_single_step_action(): @@ -947,10 +983,14 @@ async def test__arun_single_step_action_incorrect_result_type(): def test__run_single_step_action_with_inputs(): action = base_single_step_counter_with_inputs.with_name("counter") state = State({"count": 0, "tracker": []}) - result, state = _run_single_step_action(action, state, inputs={"additional_increment": 1}) + result, state = _run_single_step_action( + action, state, inputs={"additional_increment": 1} + ) assert result == {"count": 2} assert state.subset("count", "tracker").get_all() == {"count": 2, "tracker": [2]} - result, state = _run_single_step_action(action, state, inputs={"additional_increment": 1}) + result, state = _run_single_step_action( + action, state, inputs={"additional_increment": 1} + ) assert result == {"count": 4} assert state.subset("count", "tracker").get_all() == {"count": 4, "tracker": [2, 4]} @@ -1005,7 +1045,12 @@ def test__run_multistep_streaming_action(): action = base_streaming_counter.with_name("counter") state = State({"count": 0, "tracker": []}) generator = _run_multi_step_streaming_action( - action, state, inputs={}, sequence_id=0, partition_key="partition_key", app_id="app_id" + action, + state, + inputs={}, + sequence_id=0, + partition_key="partition_key", + app_id="app_id", ) last_result = -1 result = None @@ -1111,7 +1156,12 @@ def test__run_streaming_action_incorrect_result_type(): state = State() with pytest.raises(ValueError, match="returned a non-dict"): gen = _run_multi_step_streaming_action( - action, state, inputs={}, sequence_id=0, partition_key="partition_key", app_id="app_id" + action, + state, + inputs={}, + sequence_id=0, + partition_key="partition_key", + app_id="app_id", ) collections.deque(gen, maxlen=0) # exhaust the generator @@ -1168,7 +1218,12 @@ def test__run_single_step_streaming_action(): action = base_streaming_single_step_counter.with_name("counter") state = State({"count": 0, "tracker": []}) generator = _run_single_step_streaming_action( - action, state, inputs={}, sequence_id=0, partition_key="partition_key", app_id="app_id" + action, + state, + inputs={}, + sequence_id=0, + partition_key="partition_key", + app_id="app_id", ) last_result = -1 result, state = None, None @@ -1287,7 +1342,9 @@ def stream_run_and_update( yield {"count": count + ((i + 1) / 10)}, None raise RuntimeError("simulated failure") finally: - yield {"count": count + 1}, state.update(count=count + 1).append(tracker=count + 1) + yield {"count": count + 1}, state.update(count=count + 1).append( + tracker=count + 1 + ) @property def reads(self) -> list[str]: @@ -1330,7 +1387,9 @@ async def stream_run_and_update( yield {"count": count + ((i + 1) / 10)}, None raise RuntimeError("simulated failure") finally: - yield {"count": count + 1}, state.update(count=count + 1).append(tracker=count + 1) + yield {"count": count + 1}, state.update(count=count + 1).append( + tracker=count + 1 + ) @property def reads(self) -> list[str]: @@ -1407,7 +1466,9 @@ def update(self, result: dict, state: State) -> State: class MultiStepStreamingCounterWithExceptionAsync(AsyncStreamingAction): """Async variant: yields intermediate items, raises, then yields final result in finally.""" - async def stream_run(self, state: State, **run_kwargs) -> AsyncGenerator[dict, None]: + async def stream_run( + self, state: State, **run_kwargs + ) -> AsyncGenerator[dict, None]: count = state["count"] try: for i in range(3): @@ -1431,7 +1492,9 @@ def update(self, result: dict, state: State) -> State: class MultiStepStreamingCounterWithExceptionNoResultAsync(AsyncStreamingAction): """Async variant: raises without ever yielding any item.""" - async def stream_run(self, state: State, **run_kwargs) -> AsyncGenerator[dict, None]: + async def stream_run( + self, state: State, **run_kwargs + ) -> AsyncGenerator[dict, None]: raise RuntimeError("simulated failure with no result") yield # make this an async generator @@ -1460,7 +1523,10 @@ def test__run_single_step_streaming_action_graceful_exception(): assert len(intermediate) == 3 assert len(final) == 1 assert final[0][0] == {"count": 1} - assert final[0][1].subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} + assert final[0][1].subset("count", "tracker").get_all() == { + "count": 1, + "tracker": [1], + } def test__run_single_step_streaming_action_exception_propagates(): @@ -1479,7 +1545,12 @@ async def test__run_single_step_streaming_action_graceful_exception_async(): action = SingleStepStreamingCounterWithExceptionAsync().with_name("counter") state = State({"count": 0, "tracker": []}) generator = _arun_single_step_streaming_action( - action=action, state=state, inputs={}, sequence_id=0, app_id="app", partition_key="pk", + action=action, + state=state, + inputs={}, + sequence_id=0, + app_id="app", + partition_key="pk", lifecycle_adapters=LifecycleAdapterSet(), ) results = [] @@ -1490,7 +1561,10 @@ async def test__run_single_step_streaming_action_graceful_exception_async(): assert len(intermediate) == 3 assert len(final) == 1 assert final[0][0] == {"count": 1} - assert final[0][1].subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} + assert final[0][1].subset("count", "tracker").get_all() == { + "count": 1, + "tracker": [1], + } async def test__run_single_step_streaming_action_exception_propagates_async(): @@ -1498,7 +1572,12 @@ async def test__run_single_step_streaming_action_exception_propagates_async(): action = SingleStepStreamingCounterWithExceptionNoStateAsync().with_name("counter") state = State({"count": 0, "tracker": []}) generator = _arun_single_step_streaming_action( - action=action, state=state, inputs={}, sequence_id=0, app_id="app", partition_key="pk", + action=action, + state=state, + inputs={}, + sequence_id=0, + app_id="app", + partition_key="pk", lifecycle_adapters=LifecycleAdapterSet(), ) with pytest.raises(RuntimeError, match="simulated failure with no state"): @@ -1519,7 +1598,10 @@ def test__run_multi_step_streaming_action_graceful_exception(): assert len(intermediate) == 3 assert len(final) == 1 assert final[0][0] == {"count": 1} - assert final[0][1].subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} + assert final[0][1].subset("count", "tracker").get_all() == { + "count": 1, + "tracker": [1], + } def test__run_multi_step_streaming_action_exception_propagates(): @@ -1538,7 +1620,12 @@ async def test__run_multi_step_streaming_action_graceful_exception_async(): action = MultiStepStreamingCounterWithExceptionAsync().with_name("counter") state = State({"count": 0, "tracker": []}) generator = _arun_multi_step_streaming_action( - action=action, state=state, inputs={}, sequence_id=0, app_id="app", partition_key="pk", + action=action, + state=state, + inputs={}, + sequence_id=0, + app_id="app", + partition_key="pk", lifecycle_adapters=LifecycleAdapterSet(), ) results = [] @@ -1549,7 +1636,10 @@ async def test__run_multi_step_streaming_action_graceful_exception_async(): assert len(intermediate) == 3 assert len(final) == 1 assert final[0][0] == {"count": 1} - assert final[0][1].subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} + assert final[0][1].subset("count", "tracker").get_all() == { + "count": 1, + "tracker": [1], + } async def test__run_multi_step_streaming_action_exception_propagates_async(): @@ -1557,7 +1647,12 @@ async def test__run_multi_step_streaming_action_exception_propagates_async(): action = MultiStepStreamingCounterWithExceptionNoResultAsync().with_name("counter") state = State({"count": 0, "tracker": []}) generator = _arun_multi_step_streaming_action( - action=action, state=state, inputs={}, sequence_id=0, app_id="app", partition_key="pk", + action=action, + state=state, + inputs={}, + sequence_id=0, + app_id="app", + partition_key="pk", lifecycle_adapters=LifecycleAdapterSet(), ) with pytest.raises(RuntimeError, match="simulated failure with no result"): @@ -1595,7 +1690,9 @@ def test_app_step(): assert app.sequence_id == 1 assert action.name == "counter" assert result == {"count": 1} - assert state[PRIOR_STEP] == "counter" # internal contract, not part of the public API + assert ( + state[PRIOR_STEP] == "counter" + ) # internal contract, not part of the public API def test_app_step_with_inputs(): @@ -1650,7 +1747,9 @@ def test_app_step_broken(caplog): transitions=[Transition(broken_action, broken_action, default)], ), ) - with caplog.at_level(logging.ERROR): # it should say the name, that's the only contract for now + with caplog.at_level( + logging.ERROR + ): # it should say the name, that's the only contract for now with pytest.raises(BrokenStepException): app.step() assert "broken_action_unique_name" in caplog.text @@ -1692,7 +1791,9 @@ async def test_app_astep(): assert app.sequence_id == 1 assert action.name == "counter_async" assert result == {"count": 1} - assert state[PRIOR_STEP] == "counter_async" # internal contract, not part of the public API + assert ( + state[PRIOR_STEP] == "counter_async" + ) # internal contract, not part of the public API def test_app_step_context(): @@ -1750,7 +1851,9 @@ def test_action(state: State, __context: ApplicationContext) -> State: async def test_app_astep_with_inputs(): """Tests that we can run an async step in an app""" - counter_action = base_single_step_counter_with_inputs_async.with_name("counter_async") + counter_action = base_single_step_counter_with_inputs_async.with_name( + "counter_async" + ) app = Application( state=State({"count": 0, "tracker": []}), entrypoint="counter_async", @@ -1770,7 +1873,9 @@ async def test_app_astep_with_inputs(): async def test_app_astep_with_inputs_missing(): """Tests that we can run an async step in an app""" - counter_action = base_single_step_counter_with_inputs_async.with_name("counter_async") + counter_action = base_single_step_counter_with_inputs_async.with_name( + "counter_async" + ) app = Application( state=State({"count": 0, "tracker": []}), entrypoint="counter_async", @@ -1800,7 +1905,9 @@ async def test_app_astep_broken(caplog): transitions=[Transition(broken_action, broken_action, default)], ), ) - with caplog.at_level(logging.ERROR): # it should say the name, that's the only contract for now + with caplog.at_level( + logging.ERROR + ): # it should say the name, that's the only contract for now with pytest.raises(BrokenStepException): await app.astep() assert "broken_action_unique_name" in caplog.text @@ -1877,7 +1984,9 @@ def test_iterate(): graph=Graph( actions=[counter_action, result_action], transitions=[ - Transition(counter_action, counter_action, Condition.expr("count < 10")), + Transition( + counter_action, counter_action, Condition.expr("count < 10") + ), Transition(counter_action, result_action, default), ], ), @@ -1946,7 +2055,9 @@ async def test_aiterate(): graph=Graph( actions=[counter_action, result_action], transitions=[ - Transition(counter_action, counter_action, Condition.expr("count < 10")), + Transition( + counter_action, counter_action, Condition.expr("count < 10") + ), Transition(counter_action, result_action, default), ], ), @@ -1978,7 +2089,9 @@ async def test_aiterate_halt_before(): graph=Graph( actions=[counter_action, result_action], transitions=[ - Transition(counter_action, counter_action, Condition.expr("count < 10")), + Transition( + counter_action, counter_action, Condition.expr("count < 10") + ), Transition(counter_action, result_action, default), ], ), @@ -2008,7 +2121,9 @@ async def test_app_aiterate_with_inputs(): graph=Graph( actions=[counter_action, result_action], transitions=[ - Transition(counter_action, counter_action, Condition.expr("count < 10")), + Transition( + counter_action, counter_action, Condition.expr("count < 10") + ), Transition(counter_action, result_action, default), ], ), @@ -2033,7 +2148,9 @@ def test_run(): graph=Graph( actions=[counter_action, result_action], transitions=[ - Transition(counter_action, counter_action, Condition.expr("count < 10")), + Transition( + counter_action, counter_action, Condition.expr("count < 10") + ), Transition(counter_action, result_action, default), ], ), @@ -2055,7 +2172,9 @@ def test_run_halt_before(): graph=Graph( actions=[counter_action, result_action], transitions=[ - Transition(counter_action, counter_action, Condition.expr("count < 10")), + Transition( + counter_action, counter_action, Condition.expr("count < 10") + ), Transition(counter_action, result_action, default), ], ), @@ -2078,12 +2197,16 @@ def test_run_with_inputs(): graph=Graph( actions=[counter_action, result_action], transitions=[ - Transition(counter_action, counter_action, Condition.expr("count < 10")), + Transition( + counter_action, counter_action, Condition.expr("count < 10") + ), Transition(counter_action, result_action, default), ], ), ) - action_, result, state = app.run(halt_after=["result"], inputs={"additional_increment": 10}) + action_, result, state = app.run( + halt_after=["result"], inputs={"additional_increment": 10} + ) assert action_.name == "result" assert state["count"] == result["count"] == 11 @@ -2102,14 +2225,22 @@ def test_run_with_inputs_multiple_actions(): graph=Graph( actions=[counter_action1, counter_action2, result_action], transitions=[ - Transition(counter_action1, counter_action1, Condition.expr("count < 10")), - Transition(counter_action1, counter_action2, Condition.expr("count >= 10")), - Transition(counter_action2, counter_action2, Condition.expr("count < 20")), + Transition( + counter_action1, counter_action1, Condition.expr("count < 10") + ), + Transition( + counter_action1, counter_action2, Condition.expr("count >= 10") + ), + Transition( + counter_action2, counter_action2, Condition.expr("count < 20") + ), Transition(counter_action2, result_action, default), ], ), ) - action_, result, state = app.run(halt_after=["result"], inputs={"additional_increment": 8}) + action_, result, state = app.run( + halt_after=["result"], inputs={"additional_increment": 8} + ) assert action_.name == "result" assert state["count"] == result["count"] == 27 assert state["__SEQUENCE_ID"] == 4 @@ -2127,7 +2258,9 @@ async def test_arun(): graph=Graph( actions=[counter_action, result_action], transitions=[ - Transition(counter_action, counter_action, Condition.expr("count < 10")), + Transition( + counter_action, counter_action, Condition.expr("count < 10") + ), Transition(counter_action, result_action, default), ], ), @@ -2149,7 +2282,9 @@ async def test_arun_halt_before(): graph=Graph( actions=[counter_action, result_action], transitions=[ - Transition(counter_action, counter_action, Condition.expr("count < 10")), + Transition( + counter_action, counter_action, Condition.expr("count < 10") + ), Transition(counter_action, result_action, default), ], ), @@ -2172,7 +2307,9 @@ async def test_arun_with_inputs(): graph=Graph( actions=[counter_action, result_action], transitions=[ - Transition(counter_action, counter_action, Condition.expr("count < 10")), + Transition( + counter_action, counter_action, Condition.expr("count < 10") + ), Transition(counter_action, result_action, default), ], ), @@ -2197,9 +2334,15 @@ async def test_arun_with_inputs_multiple_actions(): graph=Graph( actions=[counter_action1, counter_action2, result_action], transitions=[ - Transition(counter_action1, counter_action1, Condition.expr("count < 10")), - Transition(counter_action1, counter_action2, Condition.expr("count >= 10")), - Transition(counter_action2, counter_action2, Condition.expr("count < 20")), + Transition( + counter_action1, counter_action1, Condition.expr("count < 10") + ), + Transition( + counter_action1, counter_action2, Condition.expr("count >= 10") + ), + Transition( + counter_action2, counter_action2, Condition.expr("count < 20") + ), Transition(counter_action2, result_action, default), ], ), @@ -2225,7 +2368,11 @@ async def test_app_a_run_async_and_sync(): graph=Graph( actions=[counter_action_sync, counter_action_async, result_action], transitions=[ - Transition(counter_action_sync, counter_action_async, Condition.expr("count < 20")), + Transition( + counter_action_sync, + counter_action_async, + Condition.expr("count < 20"), + ), Transition(counter_action_async, counter_action_sync, default), Transition(counter_action_sync, result_action, default), ], @@ -2304,7 +2451,9 @@ async def test_astream_result_halt_after_unique_ordered_sequence_id(): ], ), ) - action_, streaming_async_container = await app.astream_result(halt_after=["counter_2"]) + action_, streaming_async_container = await app.astream_result( + halt_after=["counter_2"] + ) results = [ item async for item in streaming_async_container ] # this should just have the intermediate results @@ -2350,7 +2499,8 @@ async def test_astream_result_halt_after_unique_ordered_sequence_id(): def test_stream_result_halt_after_run_through_streaming(): """Tests that we can pass through streaming results, - fully realize them, then get to the streaming results at the end and return the stream""" + fully realize them, then get to the streaming results at the end and return the stream + """ action_tracker = CallCaptureTracker() stream_event_tracker = StreamEventCaptureTracker() counter_action = base_streaming_single_step_counter.with_name("counter") @@ -2577,7 +2727,9 @@ def test_stream_result_halt_after_run_through_non_streaming(): action_tracker = CallCaptureTracker() stream_event_tracker = StreamEventCaptureTracker() counter_non_streaming = base_counter_action.with_name("counter_non_streaming") - counter_streaming = base_streaming_single_step_counter.with_name("counter_streaming") + counter_streaming = base_streaming_single_step_counter.with_name( + "counter_streaming" + ) app = Application( state=State({"count": 0}), @@ -2588,7 +2740,9 @@ def test_stream_result_halt_after_run_through_non_streaming(): graph=Graph( actions=[counter_non_streaming, counter_streaming], transitions=[ - Transition(counter_non_streaming, counter_non_streaming, expr("count < 10")), + Transition( + counter_non_streaming, counter_non_streaming, expr("count < 10") + ), Transition(counter_non_streaming, counter_streaming, default), ], ), @@ -2630,7 +2784,9 @@ async def test_astream_result_halt_after_run_through_non_streaming(): stream_event_tracker = StreamEventCaptureTrackerAsync() stream_event_tracker_sync = StreamEventCaptureTracker() counter_non_streaming = base_counter_action_async.with_name("counter_non_streaming") - counter_streaming = base_streaming_single_step_counter_async.with_name("counter_streaming") + counter_streaming = base_streaming_single_step_counter_async.with_name( + "counter_streaming" + ) app = Application( state=State({"count": 0}), @@ -2643,12 +2799,16 @@ async def test_astream_result_halt_after_run_through_non_streaming(): graph=Graph( actions=[counter_non_streaming, counter_streaming], transitions=[ - Transition(counter_non_streaming, counter_non_streaming, expr("count < 10")), + Transition( + counter_non_streaming, counter_non_streaming, expr("count < 10") + ), Transition(counter_non_streaming, counter_streaming, default), ], ), ) - action_, async_streaming_container = await app.astream_result(halt_after=["counter_streaming"]) + action_, async_streaming_container = await app.astream_result( + halt_after=["counter_streaming"] + ) results = [ item async for item in async_streaming_container ] # this should just have the intermediate results @@ -2698,7 +2858,9 @@ def test_stream_result_halt_after_run_through_final_non_streaming(): """Tests that we can pass through non-streaming results when streaming is called""" action_tracker = CallCaptureTracker() counter_non_streaming = base_counter_action.with_name("counter_non_streaming") - counter_final_non_streaming = base_counter_action.with_name("counter_final_non_streaming") + counter_final_non_streaming = base_counter_action.with_name( + "counter_final_non_streaming" + ) app = Application( state=State({"count": 0}), @@ -2709,12 +2871,16 @@ def test_stream_result_halt_after_run_through_final_non_streaming(): graph=Graph( actions=[counter_non_streaming, counter_final_non_streaming], transitions=[ - Transition(counter_non_streaming, counter_non_streaming, expr("count < 10")), + Transition( + counter_non_streaming, counter_non_streaming, expr("count < 10") + ), Transition(counter_non_streaming, counter_final_non_streaming, default), ], ), ) - action, streaming_container = app.stream_result(halt_after=["counter_final_non_streaming"]) + action, streaming_container = app.stream_result( + halt_after=["counter_final_non_streaming"] + ) results = list(streaming_container) assert len(results) == 0 # nothing to steram result, state = streaming_container.get() @@ -2744,7 +2910,9 @@ async def test_astream_result_halt_after_run_through_final_streaming(): action_tracker = CallCaptureTracker() counter_non_streaming = base_counter_action_async.with_name("counter_non_streaming") - counter_final_non_streaming = base_counter_action_async.with_name("counter_final_non_streaming") + counter_final_non_streaming = base_counter_action_async.with_name( + "counter_final_non_streaming" + ) app = Application( state=State({"count": 0}), @@ -2755,7 +2923,9 @@ async def test_astream_result_halt_after_run_through_final_streaming(): graph=Graph( actions=[counter_non_streaming, counter_final_non_streaming], transitions=[ - Transition(counter_non_streaming, counter_non_streaming, expr("count < 10")), + Transition( + counter_non_streaming, counter_non_streaming, expr("count < 10") + ), Transition(counter_non_streaming, counter_final_non_streaming, default), ], ), @@ -2803,12 +2973,16 @@ def test_stream_result_halt_before(): graph=Graph( actions=[counter_non_streaming, counter_streaming], transitions=[ - Transition(counter_non_streaming, counter_non_streaming, expr("count < 10")), + Transition( + counter_non_streaming, counter_non_streaming, expr("count < 10") + ), Transition(counter_non_streaming, counter_streaming, default), ], ), ) - action, streaming_container = app.stream_result(halt_after=[], halt_before=["counter_final"]) + action, streaming_container = app.stream_result( + halt_after=[], halt_before=["counter_final"] + ) results = list(streaming_container) assert len(results) == 0 # nothing to steram result, state = streaming_container.get() @@ -2828,7 +3002,9 @@ def test_stream_result_halt_before(): async def test_astream_result_halt_before(): action_tracker = CallCaptureTracker() counter_non_streaming = base_counter_action_async.with_name("counter_non_streaming") - counter_streaming = base_streaming_single_step_counter_async.with_name("counter_final") + counter_streaming = base_streaming_single_step_counter_async.with_name( + "counter_final" + ) app = Application( state=State({"count": 0}), @@ -2839,7 +3015,9 @@ async def test_astream_result_halt_before(): graph=Graph( actions=[counter_non_streaming, counter_streaming], transitions=[ - Transition(counter_non_streaming, counter_non_streaming, expr("count < 10")), + Transition( + counter_non_streaming, counter_non_streaming, expr("count < 10") + ), Transition(counter_non_streaming, counter_streaming, default), ], ), @@ -2947,7 +3125,9 @@ def test__validate_start_not_found(): def test__adjust_single_step_output_result_and_state(): state = State({"count": 1}) result = {"count": 1} - assert _adjust_single_step_output((result, state), "test_action", DEFAULT_SCHEMA) == ( + assert _adjust_single_step_output( + (result, state), "test_action", DEFAULT_SCHEMA + ) == ( result, state, ) @@ -2955,7 +3135,10 @@ def test__adjust_single_step_output_result_and_state(): def test__adjust_single_step_output_just_state(): state = State({"count": 1}) - assert _adjust_single_step_output(state, "test_action", DEFAULT_SCHEMA) == ({}, state) + assert _adjust_single_step_output(state, "test_action", DEFAULT_SCHEMA) == ( + {}, + state, + ) def test__adjust_single_step_output_errors_incorrect_type(): @@ -2990,7 +3173,9 @@ def test_application_run_step_hooks_sync(): graph=Graph( actions=[counter_action, result_action], transitions=[ - Transition(counter_action, result_action, Condition.expr("count >= 10")), + Transition( + counter_action, result_action, Condition.expr("count >= 10") + ), Transition(counter_action, counter_action, default), ], ), @@ -3041,7 +3226,9 @@ async def test_application_run_step_hooks_async(): graph=Graph( actions=[counter_action, result_action], transitions=[ - Transition(counter_action, result_action, Condition.expr("count >= 10")), + Transition( + counter_action, result_action, Condition.expr("count >= 10") + ), Transition(counter_action, counter_action, default), ], ), @@ -3155,7 +3342,9 @@ def post_application_create(self, **kwargs): graph=Graph( actions=[counter_action, result_action], transitions=[ - Transition(counter_action, result_action, Condition.expr("count >= 10")), + Transition( + counter_action, result_action, Condition.expr("count >= 10") + ), Transition(counter_action, counter_action, default), ], ), @@ -3177,7 +3366,9 @@ async def test_application_gives_graph(): graph=Graph( actions=[counter_action, result_action], transitions=[ - Transition(counter_action, result_action, Condition.expr("count >= 10")), + Transition( + counter_action, result_action, Condition.expr("count >= 10") + ), Transition(counter_action, counter_action, default), ], ), @@ -3190,7 +3381,9 @@ async def test_application_gives_graph(): def test_application_builder_initialize_does_not_allow_state_setting(): with pytest.raises(ValueError, match="Cannot call initialize_from"): - ApplicationBuilder().with_entrypoint("foo").with_state(**{"foo": "bar"}).initialize_from( + ApplicationBuilder().with_entrypoint("foo").with_state( + **{"foo": "bar"} + ).initialize_from( DevNullPersister(), resume_at_next_action=True, default_state={}, @@ -3202,7 +3395,11 @@ class BrokenPersister(BaseStatePersister): """Broken persistor.""" def load( - self, partition_key: str, app_id: Optional[str], sequence_id: Optional[int] = None, **kwargs + self, + partition_key: str, + app_id: Optional[str], + sequence_id: Optional[int] = None, + **kwargs, ) -> Optional[PersistedStateData]: return dict( partition_key="key", @@ -3258,7 +3455,8 @@ def test_load_from_sync_cannot_have_async_persistor_error(): default_entrypoint="foo", ) with pytest.raises( - ValueError, match="are building the sync application, but have used an async initializer." + ValueError, + match="are building the sync application, but have used an async initializer.", ): # we have not initialized builder._load_from_sync_persister() @@ -3274,7 +3472,8 @@ async def test_load_from_async_cannot_have_sync_persistor_error(): default_entrypoint="foo", ) with pytest.raises( - ValueError, match="are building the async application, but have used an sync initializer." + ValueError, + match="are building the async application, but have used an sync initializer.", ): # we have not initialized await builder._load_from_async_persister() @@ -3317,7 +3516,9 @@ def test_action(state: State) -> State: .build() ) with pytest.raises(ValueError, match="(?=.*no_exist_1)(?=.*no_exist_2)"): - app._validate_halt_conditions(halt_after=["no_exist_1"], halt_before=["no_exist_2"]) + app._validate_halt_conditions( + halt_after=["no_exist_1"], halt_before=["no_exist_2"] + ) def test_application_builder_initialize_raises_on_fork_app_id_not_provided(): @@ -3345,7 +3546,11 @@ class DummyPersister(BaseStatePersister): """Dummy persistor.""" def load( - self, partition_key: str, app_id: Optional[str], sequence_id: Optional[int] = None, **kwargs + self, + partition_key: str, + app_id: Optional[str], + sequence_id: Optional[int] = None, + **kwargs, ) -> Optional[PersistedStateData]: return PersistedStateData( partition_key="user123", @@ -3416,7 +3621,9 @@ def test_application_builder_initialize_fork_app_id_happy_pth(): .build() ) assert app.uid != old_app_id - assert app.state == State({"count": 5, "__PRIOR_STEP": "counter", "__SEQUENCE_ID": 5}) + assert app.state == State( + {"count": 5, "__PRIOR_STEP": "counter", "__SEQUENCE_ID": 5} + ) assert app.parent_pointer.app_id == old_app_id @@ -3569,7 +3776,9 @@ def context_counter(state: State, __context: ApplicationContext) -> State: app = ( ApplicationBuilder() .with_actions(counter=context_counter, result=result_action) - .with_transitions(("counter", "counter", expr("count < 10")), ("counter", "result")) + .with_transitions( + ("counter", "counter", expr("count < 10")), ("counter", "result") + ) .with_tracker(NoOpTracker("unique_tracker_name")) .with_identifiers(app_id="test123", partition_key="user123", sequence_id=5) .with_entrypoint("counter") @@ -3600,7 +3809,9 @@ def context_counter(state: State, __context: ApplicationContext = None) -> State app = ( ApplicationBuilder() .with_actions(counter=context_counter, result=result_action) - .with_transitions(("counter", "counter", expr("count < 10")), ("counter", "result")) + .with_transitions( + ("counter", "counter", expr("count < 10")), ("counter", "result") + ) .with_tracker(NoOpTracker("unique_tracker_name")) .with_identifiers(app_id="test123", partition_key="user123", sequence_id=5) .with_entrypoint("counter") @@ -3706,7 +3917,10 @@ def post_run_execute_call( hook = TestingHook() foo = [] - @action(reads=["recursion_count", "total_count"], writes=["recursion_count", "total_count"]) + @action( + reads=["recursion_count", "total_count"], + writes=["recursion_count", "total_count"], + ) def recursive_action(state: State) -> State: foo.append(1) recursion_count = state["recursion_count"] @@ -3735,7 +3949,9 @@ def recursive_action(state: State) -> State: result = recursive_action(State({"recursion_count": 0, "total_count": 0})) # Basic sanity checks to demonstrate assert result["recursion_count"] == 5 - assert result["total_count"] == 63 # One for each of the calls (sum(2**n for n in range(6))) + assert ( + result["total_count"] == 63 + ) # One for each of the calls (sum(2**n for n in range(6))) assert ( len(hook.pre_called) == 62 ) # 63 - the initial one from the call to recursive_action outside the application @@ -3788,7 +4004,8 @@ def test_set_sync_state_persister_cannot_have_async_error(): persister = AsyncDevNullPersister() builder.with_state_persister(persister) with pytest.raises( - ValueError, match="are building the sync application, but have used an async persister." + ValueError, + match="are building the sync application, but have used an async persister.", ): # we have not initialized builder._set_sync_state_persister() @@ -3809,7 +4026,8 @@ async def test_set_async_state_persister_cannot_have_sync_error(): persister = DevNullPersister() builder.with_state_persister(persister) with pytest.raises( - ValueError, match="are building the async application, but have used an sync persister." + ValueError, + match="are building the async application, but have used an sync persister.", ): # we have not initialized await builder._set_async_state_persister() @@ -3910,21 +4128,39 @@ def inputs(self) -> Union[list[str], tuple[list[str], list[str]]]: def test_remap_context_variable_with_mangled_context_kwargs(): _action = ActionWithKwargs() - inputs = {"__context": "context_value", "other_key": "other_value", "foo": "foo_value"} - expected = {"__context": "context_value", "other_key": "other_value", "foo": "foo_value"} - assert _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) == expected + inputs = { + "__context": "context_value", + "other_key": "other_value", + "foo": "foo_value", + } + expected = { + "__context": "context_value", + "other_key": "other_value", + "foo": "foo_value", + } + assert ( + _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) + == expected + ) def test_remap_context_variable_with_mangled_context(): _action = ActionWithContext() - inputs = {"__context": "context_value", "other_key": "other_value", "foo": "foo_value"} + inputs = { + "__context": "context_value", + "other_key": "other_value", + "foo": "foo_value", + } expected = { f"_{ActionWithContext.__name__}__context": "context_value", "other_key": "other_value", "foo": "foo_value", } - assert _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) == expected + assert ( + _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) + == expected + ) def test_remap_context_variable_with_mangled_contexttracer(): @@ -3942,14 +4178,28 @@ def test_remap_context_variable_with_mangled_contexttracer(): "foo": "foo_value", f"_{ActionWithContextTracer.__name__}__tracer": "tracer_value", } - assert _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) == expected + assert ( + _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) + == expected + ) def test_remap_context_variable_without_mangled_context(): _action = ActionWithoutContext() - inputs = {"__context": "context_value", "other_key": "other_value", "foo": "foo_value"} - expected = {"__context": "context_value", "other_key": "other_value", "foo": "foo_value"} - assert _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) == expected + inputs = { + "__context": "context_value", + "other_key": "other_value", + "foo": "foo_value", + } + expected = { + "__context": "context_value", + "other_key": "other_value", + "foo": "foo_value", + } + assert ( + _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) + == expected + ) async def test_async_application_builder_initialize_raises_on_broken_persistor(): From bc13e383823eb9c55c5b8fe8a74154228c2f3044 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Ahlert?= Date: Sat, 28 Mar 2026 16:09:32 -0300 Subject: [PATCH 4/4] style: apply pre-commit formatting (black 23.11.0, isort 5.12.0) --- burr/core/application.py | 156 +++++------------- tests/core/test_application.py | 288 ++++++++------------------------- 2 files changed, 108 insertions(+), 336 deletions(-) diff --git a/burr/core/application.py b/burr/core/application.py index cca6feccc..fb460035c 100644 --- a/burr/core/application.py +++ b/burr/core/application.py @@ -70,12 +70,7 @@ from burr.core.state import State from burr.core.typing import ActionSchema, DictBasedTypingSystem, TypingSystem from burr.core.validation import BASE_ERROR_MESSAGE -from burr.lifecycle.base import ( - ExecuteMethod, - LifecycleAdapter, - PostRunStepHook, - PreRunStepHook, -) +from burr.lifecycle.base import ExecuteMethod, LifecycleAdapter, PostRunStepHook, PreRunStepHook from burr.lifecycle.internal import LifecycleAdapterSet from burr.visibility import tracing from burr.visibility.tracing import tracer_factory_context_var @@ -94,9 +89,7 @@ StateTypeToSet = TypeVar("StateTypeToSet") -def _validate_result( - result: Any, name: str, schema: ActionSchema = DEFAULT_SCHEMA -) -> None: +def _validate_result(result: Any, name: str, schema: ActionSchema = DEFAULT_SCHEMA) -> None: # TODO -- split out the action schema into input/output schema types # Currently they're tied together, but this doesn't make as much sense for single-step actions result_type = schema.intermediate_result_type() @@ -165,9 +158,7 @@ def _remap_dunder_parameters( return inputs -def _run_function( - function: Function, state: State, inputs: Dict[str, Any], name: str -) -> dict: +def _run_function(function: Function, state: State, inputs: Dict[str, Any], name: str) -> dict: """Runs a function, returning the result of running the function. Note this restricts the keys in the state to only those that the function reads. @@ -187,9 +178,7 @@ def _run_function( function.validate_inputs(inputs) if "__context" in inputs or "__tracer" in inputs: # potentially need to remap the __context & __tracer variables - inputs = _remap_dunder_parameters( - function.run, inputs, ["__context", "__tracer"] - ) + inputs = _remap_dunder_parameters(function.run, inputs, ["__context", "__tracer"]) result = function.run(state_to_use, **inputs) _validate_result(result, name) return result @@ -234,14 +223,11 @@ def _state_update(state_to_modify: State, modified_state: State) -> State: deleted_keys = [ item for item in ( - set(state_to_modify.keys()) - - set(modified_state_without_private_fields.keys()) + set(state_to_modify.keys()) - set(modified_state_without_private_fields.keys()) ) if not item.startswith("__") ] - return state_to_modify.merge(modified_state_without_private_fields).wipe( - delete=deleted_keys - ) + return state_to_modify.merge(modified_state_without_private_fields).wipe(delete=deleted_keys) def _validate_reducer_writes(reducer: Reducer, state: State, name: str) -> None: @@ -306,9 +292,7 @@ def _format_BASE_ERROR_MESSAGE(action: Action, input_state: State, inputs: dict) message += f"> Action: `{action.name}` encountered an error!" padding = " " * (80 - len(message) - 1) message += padding + "<" - message += "\n> State (at time of action):\n" + _create_dict_string( - input_state.get_all() - ) + message += "\n> State (at time of action):\n" + _create_dict_string(input_state.get_all()) message += "\n> Inputs (at time of action):\n" + _create_dict_string(inputs) border = "*" * 80 return "\n" + border + "\n" + message + "\n" + border @@ -714,12 +698,9 @@ def call_pre(self, app) -> bool: def call_post(self, app, exc) -> bool: if should_run_hooks := ( - app.uid in _run_call_var.get(dict) - and _run_call_var.get()[app.uid] == self.method + app.uid in _run_call_var.get(dict) and _run_call_var.get()[app.uid] == self.method ): - _run_call_var.set( - {k: v for k, v in _run_call_var.get().items() if k != app.uid} - ) + _run_call_var.set({k: v for k, v in _run_call_var.get().items() if k != app.uid}) app._adapter_set.call_all_lifecycle_hooks_sync( "post_run_execute_call", app_id=app.uid, @@ -744,12 +725,9 @@ async def acall_pre(self, app) -> bool: async def acall_post(self, app, exc) -> bool: if should_run_hooks := ( - app.uid in _run_call_var.get(dict) - and _run_call_var.get()[app.uid] == self.method + app.uid in _run_call_var.get(dict) and _run_call_var.get()[app.uid] == self.method ): - _run_call_var.set( - {k: v for k, v in _run_call_var.get().items() if k != app.uid} - ) + _run_call_var.set({k: v for k, v in _run_call_var.get().items() if k != app.uid}) await app._adapter_set.call_all_lifecycle_hooks_sync_and_async( "post_run_execute_call", app_id=app._uid, @@ -830,9 +808,7 @@ def post_run_step( **future_kwargs: Any, ): try: - tracer_factory_context_var.reset( - self.token_pointer_map[(app_id, sequence_id)] - ) + tracer_factory_context_var.reset(self.token_pointer_map[(app_id, sequence_id)]) except ValueError: # Token var = ContextVar created in a different context # This occurs when we're at the finally block of an async streaming action logger.debug( @@ -895,9 +871,7 @@ def __init__( ) self._state = state adapter_set = adapter_set if adapter_set is not None else LifecycleAdapterSet() - self._adapter_set = adapter_set.with_new_adapters( - TracerFactoryContextHook(adapter_set) - ) + self._adapter_set = adapter_set.with_new_adapters(TracerFactoryContextHook(adapter_set)) # TODO -- consider adding global inputs + global input factories to the builder self._tracker = tracker @@ -935,9 +909,7 @@ def __init__( # @telemetry.capture_function_usage # todo -- capture usage when we break this up into one that isn't called internally # This will be doable when we move sequence ID to the beginning of the function https://github.com/DAGWorks-Inc/burr/pull/73 @_call_execute_method_pre_post(ExecuteMethod.step) - def step( - self, inputs: Optional[Dict[str, Any]] = None - ) -> Optional[Tuple[Action, dict, State]]: + def step(self, inputs: Optional[Dict[str, Any]] = None) -> Optional[Tuple[Action, dict, State]]: """Performs a single step, advancing the state machine along. This returns a tuple of the action that was run, the result of running the action, and the new state. @@ -1004,17 +976,13 @@ def _step( result = _run_function( next_action, self._state, action_inputs, name=next_action.name ) - new_state = _run_reducer( - next_action, self._state, result, next_action.name - ) + new_state = _run_reducer(next_action, self._state, result, next_action.name) new_state = self._update_internal_state_value(new_state, next_action) self._set_state(new_state) except Exception as e: exc = e - logger.exception( - _format_BASE_ERROR_MESSAGE(next_action, self._state, inputs) - ) + logger.exception(_format_BASE_ERROR_MESSAGE(next_action, self._state, inputs)) raise e finally: if _run_hooks: @@ -1048,9 +1016,7 @@ def _update_internal_state_value( def _process_inputs(self, inputs: Dict[str, Any], action: Action) -> Dict[str, Any]: """Processes inputs, injecting the common inputs and ensuring that all required inputs are present.""" - starting_with_double_underscore = { - key for key in inputs.keys() if key.startswith("__") - } + starting_with_double_underscore = {key for key in inputs.keys() if key.startswith("__")} if len(starting_with_double_underscore) > 0: raise ValueError( BASE_ERROR_MESSAGE @@ -1163,16 +1129,12 @@ async def _astep(self, inputs: Optional[Dict[str, Any]], _run_hooks: bool = True inputs=action_inputs, name=next_action.name, ) - new_state = _run_reducer( - next_action, self._state, result, next_action.name - ) + new_state = _run_reducer(next_action, self._state, result, next_action.name) new_state = self._update_internal_state_value(new_state, next_action) self._set_state(new_state) except Exception as e: exc = e - logger.exception( - _format_BASE_ERROR_MESSAGE(next_action, self._state, inputs) - ) + logger.exception(_format_BASE_ERROR_MESSAGE(next_action, self._state, inputs)) raise e finally: if _run_hooks: @@ -1220,19 +1182,13 @@ def _process_control_flow_params( halt_after_actions, halt_after_tags = self._parse_action_list(halt_after) halt_before_expanded = set(halt_before_actions) for tag in halt_before_tags: - halt_before_expanded.update( - [item.name for item in self.graph.get_actions_by_tag(tag)] - ) + halt_before_expanded.update([item.name for item in self.graph.get_actions_by_tag(tag)]) halt_after_expanded = set(halt_after_actions) for tag in halt_after_tags: - halt_after_expanded.update( - [item.name for item in self.graph.get_actions_by_tag(tag)] - ) + halt_after_expanded.update([item.name for item in self.graph.get_actions_by_tag(tag)]) return list(halt_before_expanded), list(halt_after_expanded), inputs - def _validate_halt_conditions( - self, halt_before: list[str], halt_after: list[str] - ) -> None: + def _validate_halt_conditions(self, halt_before: list[str], halt_after: list[str]) -> None: """Utility function to validate halt conditions""" missing_actions = set(halt_before + halt_after) - set( action.name for action in self.graph.actions @@ -1397,9 +1353,7 @@ def run( :return: The final state, and the results of running the actions in the order that they were specified. """ self.validate_correct_async_use() - gen = self.iterate( - halt_before=halt_before, halt_after=halt_after, inputs=inputs - ) + gen = self.iterate(halt_before=halt_before, halt_after=halt_after, inputs=inputs) while True: try: next(gen) @@ -1445,9 +1399,7 @@ def stream_result( halt_after: list[str], halt_before: Optional[list[str]] = None, inputs: Optional[Dict[str, Any]] = None, - ) -> Tuple[ - Action, StreamingResultContainer[ApplicationStateType, Union[dict, Any]] - ]: + ) -> Tuple[Action, StreamingResultContainer[ApplicationStateType, Union[dict, Any]]]: """Streams a result out. :param halt_after: The list of actions/tags to halt after execution of. It will halt on the first one. @@ -1549,9 +1501,7 @@ def streaming_response(state: State, prompt: str) -> Generator[dict, None, Tuple print(format(result['response'], color)) """ self.validate_correct_async_use() - call_execute_method_wrapper = _call_execute_method_pre_post( - ExecuteMethod.stream_result - ) + call_execute_method_wrapper = _call_execute_method_pre_post(ExecuteMethod.stream_result) call_execute_method_wrapper.call_pre(self) halt_before, halt_after, inputs = self._process_control_flow_params( halt_before, halt_after, inputs @@ -1701,9 +1651,7 @@ async def astream_result( halt_after: list[str], halt_before: Optional[list[str]] = None, inputs: Optional[Dict[str, Any]] = None, - ) -> Tuple[ - Action, AsyncStreamingResultContainer[ApplicationStateType, Union[dict, Any]] - ]: + ) -> Tuple[Action, AsyncStreamingResultContainer[ApplicationStateType, Union[dict, Any]]]: """Streams a result out in an asynchronous manner. :param halt_after: The list of actions/tags to halt after execution of. It will halt on the first one. @@ -1807,9 +1755,7 @@ async def streaming_response(state: State, prompt: str) -> Generator[dict, None, result, state = await output.get() print(format(result['response'], color)) """ - call_execute_method_wrapper = _call_execute_method_pre_post( - ExecuteMethod.stream_result - ) + call_execute_method_wrapper = _call_execute_method_pre_post(ExecuteMethod.stream_result) await call_execute_method_wrapper.acall_pre(self) halt_before, halt_after, inputs = self._process_control_flow_params( halt_before, halt_after, inputs @@ -1882,9 +1828,7 @@ async def callback( if not next_action.streaming: # In this case we are halting at a non-streaming condition # This is allowed as we want to maintain a more consistent API - action, result, state = await self._astep( - inputs=inputs, _run_hooks=False - ) + action, result, state = await self._astep(inputs=inputs, _run_hooks=False) await self._adapter_set.call_all_lifecycle_hooks_sync_and_async( "post_run_step", app_id=self._uid, @@ -2095,9 +2039,7 @@ def _set_state(self, new_state: State[ApplicationStateType]): self._state = new_state def get_next_action(self) -> Optional[Action]: - return self._graph.get_next_node( - self._state.get(PRIOR_STEP), self._state, self.entrypoint - ) + return self._graph.get_next_node(self._state.get(PRIOR_STEP), self._state, self.entrypoint) def update_state(self, new_state: State[ApplicationStateType]): """Updates state -- this is meant to be called if you need to do @@ -2449,9 +2391,7 @@ def with_actions( :return: The application builder for future chaining. """ self._initialize_graph_builder() - self.graph_builder = self.graph_builder.with_actions( - *action_list, **action_dict - ) + self.graph_builder = self.graph_builder.with_actions(*action_list, **action_dict) return self def with_transitions( @@ -2476,9 +2416,7 @@ def with_transitions( self.graph_builder = self.graph_builder.with_transitions(*transitions) return self - def with_hooks( - self, *adapters: LifecycleAdapter - ) -> "ApplicationBuilder[StateType]": + def with_hooks(self, *adapters: LifecycleAdapter) -> "ApplicationBuilder[StateType]": """Adds a lifecycle adapter to the application. This is a way to add hooks to the application so that they are run at the appropriate times. You can use this to synchronize state out, log results, etc... @@ -2533,9 +2471,7 @@ def with_tracker( if use_otel_tracing: from burr.integrations.opentelemetry import OpenTelemetryTracker - instantiated_tracker = OpenTelemetryTracker( - burr_tracker=instantiated_tracker - ) + instantiated_tracker = OpenTelemetryTracker(burr_tracker=instantiated_tracker) self.lifecycle_adapters.append(instantiated_tracker) self.tracker = instantiated_tracker return self @@ -2609,9 +2545,7 @@ def with_state_persister( if on_every != "step": raise ValueError(f"on_every {on_every} not supported") - self.state_persister = ( - persister # tracks for later; validates in build / abuild - ) + self.state_persister = persister # tracks for later; validates in build / abuild return self def with_spawning_parent( @@ -2658,9 +2592,7 @@ def _set_sync_state_persister(self): ) except NotImplementedError: pass - self.lifecycle_adapters.append( - persistence.PersisterHook(self.state_persister) - ) + self.lifecycle_adapters.append(persistence.PersisterHook(self.state_persister)) async def _set_async_state_persister(self): """Inits the asynchronous with_state_persister to save the state (local/DB/custom implementations). @@ -2685,9 +2617,7 @@ async def _set_async_state_persister(self): ) except NotImplementedError: pass - self.lifecycle_adapters.append( - persistence.PersisterHookAsync(self.state_persister) - ) + self.lifecycle_adapters.append(persistence.PersisterHookAsync(self.state_persister)) def _identify_state_to_load(self): """Helper to determine which state to load.""" @@ -2785,9 +2715,7 @@ def _load_from_sync_persister(self): # load state from persister load_result = self.state_initializer.load(_partition_key, _app_id, _sequence_id) - self._init_state_from_persister( - load_result, _partition_key, _app_id, _sequence_id - ) + self._init_state_from_persister(load_result, _partition_key, _app_id, _sequence_id) async def _load_from_async_persister(self): """Loads from the set async persister and into this current object. @@ -2808,12 +2736,8 @@ async def _load_from_async_persister(self): _partition_key, _app_id, _sequence_id = self._identify_state_to_load() # load state from persister - load_result = await self.state_initializer.load( - _partition_key, _app_id, _sequence_id - ) - self._init_state_from_persister( - load_result, _partition_key, _app_id, _sequence_id - ) + load_result = await self.state_initializer.load(_partition_key, _app_id, _sequence_id) + self._init_state_from_persister(load_result, _partition_key, _app_id, _sequence_id) def reset_to_entrypoint(self): self.state = self.state.wipe(delete=[PRIOR_STEP]) @@ -2831,9 +2755,7 @@ def _build_common(self) -> Application: graph = self._get_built_graph() _validate_start(self.start, {action.name for action in graph.actions}) typing_system: TypingSystem[StateType] = ( - self.typing_system - if self.typing_system is not None - else DictBasedTypingSystem() + self.typing_system if self.typing_system is not None else DictBasedTypingSystem() ) # type: ignore self.state = self.state.with_typing_system(typing_system=typing_system) return Application( diff --git a/tests/core/test_application.py b/tests/core/test_application.py index d025f7f29..7383583f0 100644 --- a/tests/core/test_application.py +++ b/tests/core/test_application.py @@ -21,17 +21,7 @@ import logging import typing import uuid -from typing import ( - Any, - Awaitable, - Callable, - Dict, - Generator, - Literal, - Optional, - Tuple, - Union, -) +from typing import Any, Awaitable, Callable, Dict, Generator, Literal, Optional, Tuple, Union import pytest @@ -173,9 +163,7 @@ async def run(self, state: State, **run_kwargs) -> dict: ) -class StreamEventCaptureTracker( - PreStartStreamHook, PostStreamItemHook, PostEndStreamHook -): +class StreamEventCaptureTracker(PreStartStreamHook, PostStreamItemHook, PostEndStreamHook): def post_end_stream( self, *, @@ -452,9 +440,7 @@ def test__run_function_with_inputs(): """Tests that we can run a function""" action = base_counter_action_with_inputs state = State({}) - result = _run_function( - action, state, inputs={"additional_increment": 1}, name=action.name - ) + result = _run_function(action, state, inputs={"additional_increment": 1}, name=action.name) assert result == {"count": 2} @@ -610,9 +596,7 @@ class BrokenAction(SingleStepAction): def reads(self) -> list[str]: return [] - async def run_and_update( - self, state: State, **run_kwargs - ) -> Tuple[dict, State]: + async def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]: await asyncio.sleep(0.0001) # just so we can make this *truly* async return {"present_value": 1}, state.update(present_value=1) @@ -800,9 +784,7 @@ def update(self, result: dict, state: State) -> State: class AsyncStreamingCounter(AsyncStreamingAction): - async def stream_run( - self, state: State, **run_kwargs - ) -> AsyncGenerator[dict, None]: + async def stream_run(self, state: State, **run_kwargs) -> AsyncGenerator[dict, None]: if "steps_per_count" in run_kwargs: steps_per_count = run_kwargs["granularity"] else: @@ -834,9 +816,7 @@ def stream_run_and_update( count = state["count"] for i in range(steps_per_count): yield {"count": count + ((i + 1) / 10)}, None - yield {"count": count + 1}, state.update(count=count + 1).append( - tracker=count + 1 - ) + yield {"count": count + 1}, state.update(count=count + 1).append(tracker=count + 1) @property def reads(self) -> list[str]: @@ -857,9 +837,7 @@ async def stream_run_and_update( await asyncio.sleep(0.01) yield {"count": count + ((i + 1) / 10)}, None await asyncio.sleep(0.01) - yield {"count": count + 1}, state.update(count=count + 1).append( - tracker=count + 1 - ) + yield {"count": count + 1}, state.update(count=count + 1).append(tracker=count + 1) @property def reads(self) -> list[str]: @@ -888,9 +866,7 @@ def update(self, result: dict, state: State) -> State: class StreamingActionIncorrectResultTypeAsync(AsyncStreamingAction): - async def stream_run( - self, state: State, **run_kwargs - ) -> AsyncGenerator[dict, None]: + async def stream_run(self, state: State, **run_kwargs) -> AsyncGenerator[dict, None]: yield {} yield "not a dict" @@ -950,9 +926,7 @@ def writes(self) -> list[str]: base_streaming_single_step_counter_async = SingleStepStreamingCounterAsync() base_single_step_action_incorrect_result_type = SingleStepActionIncorrectResultType() -base_single_step_action_incorrect_result_type_async = ( - SingleStepActionIncorrectResultTypeAsync() -) +base_single_step_action_incorrect_result_type_async = SingleStepActionIncorrectResultTypeAsync() def test__run_single_step_action(): @@ -983,14 +957,10 @@ async def test__arun_single_step_action_incorrect_result_type(): def test__run_single_step_action_with_inputs(): action = base_single_step_counter_with_inputs.with_name("counter") state = State({"count": 0, "tracker": []}) - result, state = _run_single_step_action( - action, state, inputs={"additional_increment": 1} - ) + result, state = _run_single_step_action(action, state, inputs={"additional_increment": 1}) assert result == {"count": 2} assert state.subset("count", "tracker").get_all() == {"count": 2, "tracker": [2]} - result, state = _run_single_step_action( - action, state, inputs={"additional_increment": 1} - ) + result, state = _run_single_step_action(action, state, inputs={"additional_increment": 1}) assert result == {"count": 4} assert state.subset("count", "tracker").get_all() == {"count": 4, "tracker": [2, 4]} @@ -1342,9 +1312,7 @@ def stream_run_and_update( yield {"count": count + ((i + 1) / 10)}, None raise RuntimeError("simulated failure") finally: - yield {"count": count + 1}, state.update(count=count + 1).append( - tracker=count + 1 - ) + yield {"count": count + 1}, state.update(count=count + 1).append(tracker=count + 1) @property def reads(self) -> list[str]: @@ -1387,9 +1355,7 @@ async def stream_run_and_update( yield {"count": count + ((i + 1) / 10)}, None raise RuntimeError("simulated failure") finally: - yield {"count": count + 1}, state.update(count=count + 1).append( - tracker=count + 1 - ) + yield {"count": count + 1}, state.update(count=count + 1).append(tracker=count + 1) @property def reads(self) -> list[str]: @@ -1466,9 +1432,7 @@ def update(self, result: dict, state: State) -> State: class MultiStepStreamingCounterWithExceptionAsync(AsyncStreamingAction): """Async variant: yields intermediate items, raises, then yields final result in finally.""" - async def stream_run( - self, state: State, **run_kwargs - ) -> AsyncGenerator[dict, None]: + async def stream_run(self, state: State, **run_kwargs) -> AsyncGenerator[dict, None]: count = state["count"] try: for i in range(3): @@ -1492,9 +1456,7 @@ def update(self, result: dict, state: State) -> State: class MultiStepStreamingCounterWithExceptionNoResultAsync(AsyncStreamingAction): """Async variant: raises without ever yielding any item.""" - async def stream_run( - self, state: State, **run_kwargs - ) -> AsyncGenerator[dict, None]: + async def stream_run(self, state: State, **run_kwargs) -> AsyncGenerator[dict, None]: raise RuntimeError("simulated failure with no result") yield # make this an async generator @@ -1690,9 +1652,7 @@ def test_app_step(): assert app.sequence_id == 1 assert action.name == "counter" assert result == {"count": 1} - assert ( - state[PRIOR_STEP] == "counter" - ) # internal contract, not part of the public API + assert state[PRIOR_STEP] == "counter" # internal contract, not part of the public API def test_app_step_with_inputs(): @@ -1747,9 +1707,7 @@ def test_app_step_broken(caplog): transitions=[Transition(broken_action, broken_action, default)], ), ) - with caplog.at_level( - logging.ERROR - ): # it should say the name, that's the only contract for now + with caplog.at_level(logging.ERROR): # it should say the name, that's the only contract for now with pytest.raises(BrokenStepException): app.step() assert "broken_action_unique_name" in caplog.text @@ -1791,9 +1749,7 @@ async def test_app_astep(): assert app.sequence_id == 1 assert action.name == "counter_async" assert result == {"count": 1} - assert ( - state[PRIOR_STEP] == "counter_async" - ) # internal contract, not part of the public API + assert state[PRIOR_STEP] == "counter_async" # internal contract, not part of the public API def test_app_step_context(): @@ -1851,9 +1807,7 @@ def test_action(state: State, __context: ApplicationContext) -> State: async def test_app_astep_with_inputs(): """Tests that we can run an async step in an app""" - counter_action = base_single_step_counter_with_inputs_async.with_name( - "counter_async" - ) + counter_action = base_single_step_counter_with_inputs_async.with_name("counter_async") app = Application( state=State({"count": 0, "tracker": []}), entrypoint="counter_async", @@ -1873,9 +1827,7 @@ async def test_app_astep_with_inputs(): async def test_app_astep_with_inputs_missing(): """Tests that we can run an async step in an app""" - counter_action = base_single_step_counter_with_inputs_async.with_name( - "counter_async" - ) + counter_action = base_single_step_counter_with_inputs_async.with_name("counter_async") app = Application( state=State({"count": 0, "tracker": []}), entrypoint="counter_async", @@ -1905,9 +1857,7 @@ async def test_app_astep_broken(caplog): transitions=[Transition(broken_action, broken_action, default)], ), ) - with caplog.at_level( - logging.ERROR - ): # it should say the name, that's the only contract for now + with caplog.at_level(logging.ERROR): # it should say the name, that's the only contract for now with pytest.raises(BrokenStepException): await app.astep() assert "broken_action_unique_name" in caplog.text @@ -1984,9 +1934,7 @@ def test_iterate(): graph=Graph( actions=[counter_action, result_action], transitions=[ - Transition( - counter_action, counter_action, Condition.expr("count < 10") - ), + Transition(counter_action, counter_action, Condition.expr("count < 10")), Transition(counter_action, result_action, default), ], ), @@ -2055,9 +2003,7 @@ async def test_aiterate(): graph=Graph( actions=[counter_action, result_action], transitions=[ - Transition( - counter_action, counter_action, Condition.expr("count < 10") - ), + Transition(counter_action, counter_action, Condition.expr("count < 10")), Transition(counter_action, result_action, default), ], ), @@ -2089,9 +2035,7 @@ async def test_aiterate_halt_before(): graph=Graph( actions=[counter_action, result_action], transitions=[ - Transition( - counter_action, counter_action, Condition.expr("count < 10") - ), + Transition(counter_action, counter_action, Condition.expr("count < 10")), Transition(counter_action, result_action, default), ], ), @@ -2121,9 +2065,7 @@ async def test_app_aiterate_with_inputs(): graph=Graph( actions=[counter_action, result_action], transitions=[ - Transition( - counter_action, counter_action, Condition.expr("count < 10") - ), + Transition(counter_action, counter_action, Condition.expr("count < 10")), Transition(counter_action, result_action, default), ], ), @@ -2148,9 +2090,7 @@ def test_run(): graph=Graph( actions=[counter_action, result_action], transitions=[ - Transition( - counter_action, counter_action, Condition.expr("count < 10") - ), + Transition(counter_action, counter_action, Condition.expr("count < 10")), Transition(counter_action, result_action, default), ], ), @@ -2172,9 +2112,7 @@ def test_run_halt_before(): graph=Graph( actions=[counter_action, result_action], transitions=[ - Transition( - counter_action, counter_action, Condition.expr("count < 10") - ), + Transition(counter_action, counter_action, Condition.expr("count < 10")), Transition(counter_action, result_action, default), ], ), @@ -2197,16 +2135,12 @@ def test_run_with_inputs(): graph=Graph( actions=[counter_action, result_action], transitions=[ - Transition( - counter_action, counter_action, Condition.expr("count < 10") - ), + Transition(counter_action, counter_action, Condition.expr("count < 10")), Transition(counter_action, result_action, default), ], ), ) - action_, result, state = app.run( - halt_after=["result"], inputs={"additional_increment": 10} - ) + action_, result, state = app.run(halt_after=["result"], inputs={"additional_increment": 10}) assert action_.name == "result" assert state["count"] == result["count"] == 11 @@ -2225,22 +2159,14 @@ def test_run_with_inputs_multiple_actions(): graph=Graph( actions=[counter_action1, counter_action2, result_action], transitions=[ - Transition( - counter_action1, counter_action1, Condition.expr("count < 10") - ), - Transition( - counter_action1, counter_action2, Condition.expr("count >= 10") - ), - Transition( - counter_action2, counter_action2, Condition.expr("count < 20") - ), + Transition(counter_action1, counter_action1, Condition.expr("count < 10")), + Transition(counter_action1, counter_action2, Condition.expr("count >= 10")), + Transition(counter_action2, counter_action2, Condition.expr("count < 20")), Transition(counter_action2, result_action, default), ], ), ) - action_, result, state = app.run( - halt_after=["result"], inputs={"additional_increment": 8} - ) + action_, result, state = app.run(halt_after=["result"], inputs={"additional_increment": 8}) assert action_.name == "result" assert state["count"] == result["count"] == 27 assert state["__SEQUENCE_ID"] == 4 @@ -2258,9 +2184,7 @@ async def test_arun(): graph=Graph( actions=[counter_action, result_action], transitions=[ - Transition( - counter_action, counter_action, Condition.expr("count < 10") - ), + Transition(counter_action, counter_action, Condition.expr("count < 10")), Transition(counter_action, result_action, default), ], ), @@ -2282,9 +2206,7 @@ async def test_arun_halt_before(): graph=Graph( actions=[counter_action, result_action], transitions=[ - Transition( - counter_action, counter_action, Condition.expr("count < 10") - ), + Transition(counter_action, counter_action, Condition.expr("count < 10")), Transition(counter_action, result_action, default), ], ), @@ -2307,9 +2229,7 @@ async def test_arun_with_inputs(): graph=Graph( actions=[counter_action, result_action], transitions=[ - Transition( - counter_action, counter_action, Condition.expr("count < 10") - ), + Transition(counter_action, counter_action, Condition.expr("count < 10")), Transition(counter_action, result_action, default), ], ), @@ -2334,15 +2254,9 @@ async def test_arun_with_inputs_multiple_actions(): graph=Graph( actions=[counter_action1, counter_action2, result_action], transitions=[ - Transition( - counter_action1, counter_action1, Condition.expr("count < 10") - ), - Transition( - counter_action1, counter_action2, Condition.expr("count >= 10") - ), - Transition( - counter_action2, counter_action2, Condition.expr("count < 20") - ), + Transition(counter_action1, counter_action1, Condition.expr("count < 10")), + Transition(counter_action1, counter_action2, Condition.expr("count >= 10")), + Transition(counter_action2, counter_action2, Condition.expr("count < 20")), Transition(counter_action2, result_action, default), ], ), @@ -2451,9 +2365,7 @@ async def test_astream_result_halt_after_unique_ordered_sequence_id(): ], ), ) - action_, streaming_async_container = await app.astream_result( - halt_after=["counter_2"] - ) + action_, streaming_async_container = await app.astream_result(halt_after=["counter_2"]) results = [ item async for item in streaming_async_container ] # this should just have the intermediate results @@ -2727,9 +2639,7 @@ def test_stream_result_halt_after_run_through_non_streaming(): action_tracker = CallCaptureTracker() stream_event_tracker = StreamEventCaptureTracker() counter_non_streaming = base_counter_action.with_name("counter_non_streaming") - counter_streaming = base_streaming_single_step_counter.with_name( - "counter_streaming" - ) + counter_streaming = base_streaming_single_step_counter.with_name("counter_streaming") app = Application( state=State({"count": 0}), @@ -2740,9 +2650,7 @@ def test_stream_result_halt_after_run_through_non_streaming(): graph=Graph( actions=[counter_non_streaming, counter_streaming], transitions=[ - Transition( - counter_non_streaming, counter_non_streaming, expr("count < 10") - ), + Transition(counter_non_streaming, counter_non_streaming, expr("count < 10")), Transition(counter_non_streaming, counter_streaming, default), ], ), @@ -2784,9 +2692,7 @@ async def test_astream_result_halt_after_run_through_non_streaming(): stream_event_tracker = StreamEventCaptureTrackerAsync() stream_event_tracker_sync = StreamEventCaptureTracker() counter_non_streaming = base_counter_action_async.with_name("counter_non_streaming") - counter_streaming = base_streaming_single_step_counter_async.with_name( - "counter_streaming" - ) + counter_streaming = base_streaming_single_step_counter_async.with_name("counter_streaming") app = Application( state=State({"count": 0}), @@ -2799,16 +2705,12 @@ async def test_astream_result_halt_after_run_through_non_streaming(): graph=Graph( actions=[counter_non_streaming, counter_streaming], transitions=[ - Transition( - counter_non_streaming, counter_non_streaming, expr("count < 10") - ), + Transition(counter_non_streaming, counter_non_streaming, expr("count < 10")), Transition(counter_non_streaming, counter_streaming, default), ], ), ) - action_, async_streaming_container = await app.astream_result( - halt_after=["counter_streaming"] - ) + action_, async_streaming_container = await app.astream_result(halt_after=["counter_streaming"]) results = [ item async for item in async_streaming_container ] # this should just have the intermediate results @@ -2858,9 +2760,7 @@ def test_stream_result_halt_after_run_through_final_non_streaming(): """Tests that we can pass through non-streaming results when streaming is called""" action_tracker = CallCaptureTracker() counter_non_streaming = base_counter_action.with_name("counter_non_streaming") - counter_final_non_streaming = base_counter_action.with_name( - "counter_final_non_streaming" - ) + counter_final_non_streaming = base_counter_action.with_name("counter_final_non_streaming") app = Application( state=State({"count": 0}), @@ -2871,16 +2771,12 @@ def test_stream_result_halt_after_run_through_final_non_streaming(): graph=Graph( actions=[counter_non_streaming, counter_final_non_streaming], transitions=[ - Transition( - counter_non_streaming, counter_non_streaming, expr("count < 10") - ), + Transition(counter_non_streaming, counter_non_streaming, expr("count < 10")), Transition(counter_non_streaming, counter_final_non_streaming, default), ], ), ) - action, streaming_container = app.stream_result( - halt_after=["counter_final_non_streaming"] - ) + action, streaming_container = app.stream_result(halt_after=["counter_final_non_streaming"]) results = list(streaming_container) assert len(results) == 0 # nothing to steram result, state = streaming_container.get() @@ -2910,9 +2806,7 @@ async def test_astream_result_halt_after_run_through_final_streaming(): action_tracker = CallCaptureTracker() counter_non_streaming = base_counter_action_async.with_name("counter_non_streaming") - counter_final_non_streaming = base_counter_action_async.with_name( - "counter_final_non_streaming" - ) + counter_final_non_streaming = base_counter_action_async.with_name("counter_final_non_streaming") app = Application( state=State({"count": 0}), @@ -2923,9 +2817,7 @@ async def test_astream_result_halt_after_run_through_final_streaming(): graph=Graph( actions=[counter_non_streaming, counter_final_non_streaming], transitions=[ - Transition( - counter_non_streaming, counter_non_streaming, expr("count < 10") - ), + Transition(counter_non_streaming, counter_non_streaming, expr("count < 10")), Transition(counter_non_streaming, counter_final_non_streaming, default), ], ), @@ -2973,16 +2865,12 @@ def test_stream_result_halt_before(): graph=Graph( actions=[counter_non_streaming, counter_streaming], transitions=[ - Transition( - counter_non_streaming, counter_non_streaming, expr("count < 10") - ), + Transition(counter_non_streaming, counter_non_streaming, expr("count < 10")), Transition(counter_non_streaming, counter_streaming, default), ], ), ) - action, streaming_container = app.stream_result( - halt_after=[], halt_before=["counter_final"] - ) + action, streaming_container = app.stream_result(halt_after=[], halt_before=["counter_final"]) results = list(streaming_container) assert len(results) == 0 # nothing to steram result, state = streaming_container.get() @@ -3002,9 +2890,7 @@ def test_stream_result_halt_before(): async def test_astream_result_halt_before(): action_tracker = CallCaptureTracker() counter_non_streaming = base_counter_action_async.with_name("counter_non_streaming") - counter_streaming = base_streaming_single_step_counter_async.with_name( - "counter_final" - ) + counter_streaming = base_streaming_single_step_counter_async.with_name("counter_final") app = Application( state=State({"count": 0}), @@ -3015,9 +2901,7 @@ async def test_astream_result_halt_before(): graph=Graph( actions=[counter_non_streaming, counter_streaming], transitions=[ - Transition( - counter_non_streaming, counter_non_streaming, expr("count < 10") - ), + Transition(counter_non_streaming, counter_non_streaming, expr("count < 10")), Transition(counter_non_streaming, counter_streaming, default), ], ), @@ -3125,9 +3009,7 @@ def test__validate_start_not_found(): def test__adjust_single_step_output_result_and_state(): state = State({"count": 1}) result = {"count": 1} - assert _adjust_single_step_output( - (result, state), "test_action", DEFAULT_SCHEMA - ) == ( + assert _adjust_single_step_output((result, state), "test_action", DEFAULT_SCHEMA) == ( result, state, ) @@ -3173,9 +3055,7 @@ def test_application_run_step_hooks_sync(): graph=Graph( actions=[counter_action, result_action], transitions=[ - Transition( - counter_action, result_action, Condition.expr("count >= 10") - ), + Transition(counter_action, result_action, Condition.expr("count >= 10")), Transition(counter_action, counter_action, default), ], ), @@ -3226,9 +3106,7 @@ async def test_application_run_step_hooks_async(): graph=Graph( actions=[counter_action, result_action], transitions=[ - Transition( - counter_action, result_action, Condition.expr("count >= 10") - ), + Transition(counter_action, result_action, Condition.expr("count >= 10")), Transition(counter_action, counter_action, default), ], ), @@ -3342,9 +3220,7 @@ def post_application_create(self, **kwargs): graph=Graph( actions=[counter_action, result_action], transitions=[ - Transition( - counter_action, result_action, Condition.expr("count >= 10") - ), + Transition(counter_action, result_action, Condition.expr("count >= 10")), Transition(counter_action, counter_action, default), ], ), @@ -3366,9 +3242,7 @@ async def test_application_gives_graph(): graph=Graph( actions=[counter_action, result_action], transitions=[ - Transition( - counter_action, result_action, Condition.expr("count >= 10") - ), + Transition(counter_action, result_action, Condition.expr("count >= 10")), Transition(counter_action, counter_action, default), ], ), @@ -3381,9 +3255,7 @@ async def test_application_gives_graph(): def test_application_builder_initialize_does_not_allow_state_setting(): with pytest.raises(ValueError, match="Cannot call initialize_from"): - ApplicationBuilder().with_entrypoint("foo").with_state( - **{"foo": "bar"} - ).initialize_from( + ApplicationBuilder().with_entrypoint("foo").with_state(**{"foo": "bar"}).initialize_from( DevNullPersister(), resume_at_next_action=True, default_state={}, @@ -3516,9 +3388,7 @@ def test_action(state: State) -> State: .build() ) with pytest.raises(ValueError, match="(?=.*no_exist_1)(?=.*no_exist_2)"): - app._validate_halt_conditions( - halt_after=["no_exist_1"], halt_before=["no_exist_2"] - ) + app._validate_halt_conditions(halt_after=["no_exist_1"], halt_before=["no_exist_2"]) def test_application_builder_initialize_raises_on_fork_app_id_not_provided(): @@ -3621,9 +3491,7 @@ def test_application_builder_initialize_fork_app_id_happy_pth(): .build() ) assert app.uid != old_app_id - assert app.state == State( - {"count": 5, "__PRIOR_STEP": "counter", "__SEQUENCE_ID": 5} - ) + assert app.state == State({"count": 5, "__PRIOR_STEP": "counter", "__SEQUENCE_ID": 5}) assert app.parent_pointer.app_id == old_app_id @@ -3776,9 +3644,7 @@ def context_counter(state: State, __context: ApplicationContext) -> State: app = ( ApplicationBuilder() .with_actions(counter=context_counter, result=result_action) - .with_transitions( - ("counter", "counter", expr("count < 10")), ("counter", "result") - ) + .with_transitions(("counter", "counter", expr("count < 10")), ("counter", "result")) .with_tracker(NoOpTracker("unique_tracker_name")) .with_identifiers(app_id="test123", partition_key="user123", sequence_id=5) .with_entrypoint("counter") @@ -3809,9 +3675,7 @@ def context_counter(state: State, __context: ApplicationContext = None) -> State app = ( ApplicationBuilder() .with_actions(counter=context_counter, result=result_action) - .with_transitions( - ("counter", "counter", expr("count < 10")), ("counter", "result") - ) + .with_transitions(("counter", "counter", expr("count < 10")), ("counter", "result")) .with_tracker(NoOpTracker("unique_tracker_name")) .with_identifiers(app_id="test123", partition_key="user123", sequence_id=5) .with_entrypoint("counter") @@ -3949,9 +3813,7 @@ def recursive_action(state: State) -> State: result = recursive_action(State({"recursion_count": 0, "total_count": 0})) # Basic sanity checks to demonstrate assert result["recursion_count"] == 5 - assert ( - result["total_count"] == 63 - ) # One for each of the calls (sum(2**n for n in range(6))) + assert result["total_count"] == 63 # One for each of the calls (sum(2**n for n in range(6))) assert ( len(hook.pre_called) == 62 ) # 63 - the initial one from the call to recursive_action outside the application @@ -4138,10 +4000,7 @@ def test_remap_context_variable_with_mangled_context_kwargs(): "other_key": "other_value", "foo": "foo_value", } - assert ( - _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) - == expected - ) + assert _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) == expected def test_remap_context_variable_with_mangled_context(): @@ -4157,10 +4016,7 @@ def test_remap_context_variable_with_mangled_context(): "other_key": "other_value", "foo": "foo_value", } - assert ( - _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) - == expected - ) + assert _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) == expected def test_remap_context_variable_with_mangled_contexttracer(): @@ -4178,10 +4034,7 @@ def test_remap_context_variable_with_mangled_contexttracer(): "foo": "foo_value", f"_{ActionWithContextTracer.__name__}__tracer": "tracer_value", } - assert ( - _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) - == expected - ) + assert _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) == expected def test_remap_context_variable_without_mangled_context(): @@ -4196,10 +4049,7 @@ def test_remap_context_variable_without_mangled_context(): "other_key": "other_value", "foo": "foo_value", } - assert ( - _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) - == expected - ) + assert _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) == expected async def test_async_application_builder_initialize_raises_on_broken_persistor():