diff --git a/burr/core/application.py b/burr/core/application.py index dc8067c4b..fb460035c 100644 --- a/burr/core/application.py +++ b/burr/core/application.py @@ -338,31 +338,44 @@ def _run_single_step_streaming_action( result = None state_update = None count = 0 - for item in generator: - if not isinstance(item, tuple): - # TODO -- consider adding support for just returning a result. - raise ValueError( - f"Action {action.name} must yield a tuple of (result, state_update). " - f"For all non-final results (intermediate)," - f"the state update must be None" - ) - result, state_update = item - count += 1 + try: + for item in generator: + if not isinstance(item, tuple): + # TODO -- consider adding support for just returning a result. + raise ValueError( + f"Action {action.name} must yield a tuple of (result, state_update). " + f"For all non-final results (intermediate)," + f"the state update must be None" + ) + result, state_update = item + if state_update is None: + count += 1 + if first_stream_start_time is None: + first_stream_start_time = system.now() + lifecycle_adapters.call_all_lifecycle_hooks_sync( + "post_stream_item", + item=result, + item_index=count, + stream_initialize_time=stream_initialize_time, + first_stream_item_start_time=first_stream_start_time, + action=action.name, + app_id=app_id, + partition_key=partition_key, + sequence_id=sequence_id, + ) + yield result, None + except Exception as e: if state_update is None: - if first_stream_start_time is None: - first_stream_start_time = system.now() - lifecycle_adapters.call_all_lifecycle_hooks_sync( - "post_stream_item", - item=result, - item_index=count, - stream_initialize_time=stream_initialize_time, - first_stream_item_start_time=first_stream_start_time, - action=action.name, - app_id=app_id, - partition_key=partition_key, - sequence_id=sequence_id, - ) - yield result, None + raise + logger.warning( + "Streaming action '%s' raised %s after yielding %d items. " + "Proceeding with final state from generator cleanup. Original error: %s", + action.name, + type(e).__name__, + count, + e, + exc_info=True, + ) if state_update is None: raise ValueError( @@ -391,31 +404,45 @@ async def _arun_single_step_streaming_action( result = None state_update = None count = 0 - async for item in generator: - if not isinstance(item, tuple): - # TODO -- consider adding support for just returning a result. - raise ValueError( - f"Action {action.name} must yield a tuple of (result, state_update). " - f"For all non-final results (intermediate)," - f"the state update must be None" - ) - result, state_update = item + try: + async for item in generator: + if not isinstance(item, tuple): + # TODO -- consider adding support for just returning a result. + raise ValueError( + f"Action {action.name} must yield a tuple of (result, state_update). " + f"For all non-final results (intermediate)," + f"the state update must be None" + ) + result, state_update = item + if state_update is None: + count += 1 + if first_stream_start_time is None: + first_stream_start_time = system.now() + await lifecycle_adapters.call_all_lifecycle_hooks_sync_and_async( + "post_stream_item", + item=result, + item_index=count, + stream_initialize_time=stream_initialize_time, + first_stream_item_start_time=first_stream_start_time, + action=action.name, + app_id=app_id, + partition_key=partition_key, + sequence_id=sequence_id, + ) + yield result, None + except Exception as e: if state_update is None: - if first_stream_start_time is None: - first_stream_start_time = system.now() - await lifecycle_adapters.call_all_lifecycle_hooks_sync_and_async( - "post_stream_item", - item=result, - item_index=count, - stream_initialize_time=stream_initialize_time, - first_stream_item_start_time=first_stream_start_time, - action=action.name, - app_id=app_id, - partition_key=partition_key, - sequence_id=sequence_id, - ) - count += 1 - yield result, None + raise + logger.warning( + "Streaming action '%s' raised %s after yielding %d items. " + "Proceeding with final state from generator cleanup. Original error: %s", + action.name, + type(e).__name__, + count, + e, + exc_info=True, + ) + if state_update is None: raise ValueError( f"Action {action.name} did not return a state update. For async actions, the last yield " @@ -450,28 +477,42 @@ def _run_multi_step_streaming_action( result = None first_stream_start_time = None count = 0 - for item in generator: - # We want to peek ahead so we can return the last one - # This is slightly eager, but only in the case in which we - # are using a multi-step streaming action - next_result = result - result = item - if next_result is not None: - if first_stream_start_time is None: - first_stream_start_time = system.now() - lifecycle_adapters.call_all_lifecycle_hooks_sync( - "post_stream_item", - item=next_result, - item_index=count, - stream_initialize_time=stream_initialize_time, - first_stream_item_start_time=first_stream_start_time, - action=action.name, - app_id=app_id, - partition_key=partition_key, - sequence_id=sequence_id, - ) - count += 1 - yield next_result, None + try: + for item in generator: + # We want to peek ahead so we can return the last one + # This is slightly eager, but only in the case in which we + # are using a multi-step streaming action + next_result = result + result = item + if next_result is not None: + if first_stream_start_time is None: + first_stream_start_time = system.now() + lifecycle_adapters.call_all_lifecycle_hooks_sync( + "post_stream_item", + item=next_result, + item_index=count, + stream_initialize_time=stream_initialize_time, + first_stream_item_start_time=first_stream_start_time, + action=action.name, + app_id=app_id, + partition_key=partition_key, + sequence_id=sequence_id, + ) + count += 1 + yield next_result, None + except Exception as e: + if result is None: + raise + logger.warning( + "Streaming action '%s' raised %s after yielding %d items. " + "Proceeding with last yielded result for reducer. " + "Note: the reducer will run on potentially partial data. Original error: %s", + action.name, + type(e).__name__, + count, + e, + exc_info=True, + ) state_update = _run_reducer(action, state, result, action.name) _validate_result(result, action.name, action.schema) _validate_reducer_writes(action, state_update, action.name) @@ -494,28 +535,42 @@ async def _arun_multi_step_streaming_action( result = None first_stream_start_time = None count = 0 - async for item in generator: - # We want to peek ahead so we can return the last one - # This is slightly eager, but only in the case in which we - # are using a multi-step streaming action - next_result = result - result = item - if next_result is not None: - if first_stream_start_time is None: - first_stream_start_time = system.now() - await lifecycle_adapters.call_all_lifecycle_hooks_sync_and_async( - "post_stream_item", - item=next_result, - stream_initialize_time=stream_initialize_time, - item_index=count, - first_stream_item_start_time=first_stream_start_time, - action=action.name, - app_id=app_id, - partition_key=partition_key, - sequence_id=sequence_id, - ) - count += 1 - yield next_result, None + try: + async for item in generator: + # We want to peek ahead so we can return the last one + # This is slightly eager, but only in the case in which we + # are using a multi-step streaming action + next_result = result + result = item + if next_result is not None: + if first_stream_start_time is None: + first_stream_start_time = system.now() + await lifecycle_adapters.call_all_lifecycle_hooks_sync_and_async( + "post_stream_item", + item=next_result, + stream_initialize_time=stream_initialize_time, + item_index=count, + first_stream_item_start_time=first_stream_start_time, + action=action.name, + app_id=app_id, + partition_key=partition_key, + sequence_id=sequence_id, + ) + count += 1 + yield next_result, None + except Exception as e: + if result is None: + raise + logger.warning( + "Streaming action '%s' raised %s after yielding %d items. " + "Proceeding with last yielded result for reducer. " + "Note: the reducer will run on potentially partial data. Original error: %s", + action.name, + type(e).__name__, + count, + e, + exc_info=True, + ) state_update = _run_reducer(action, state, result, action.name) _validate_result(result, action.name, action.schema) _validate_reducer_writes(action, state_update, action.name) @@ -1871,7 +1926,9 @@ def stream_iterate( halt_before: Optional[Union[str, List[str]]] = None, inputs: Optional[Dict[str, Any]] = None, ) -> Generator[ - Tuple[Action, StreamingResultContainer[ApplicationStateType, Union[dict, Any]]], None, None + Tuple[Action, StreamingResultContainer[ApplicationStateType, Union[dict, Any]]], + None, + None, ]: """Produces an iterator that iterates through intermediate streams. You may want to use this in something like deep research mode in which: @@ -1915,7 +1972,11 @@ async def astream_iterate( halt_before: Optional[Union[str, List[str]]] = None, inputs: Optional[Dict[str, Any]] = None, ) -> AsyncGenerator[ - Tuple[Action, AsyncStreamingResultContainer[ApplicationStateType, Union[dict, Any]]], None + Tuple[ + Action, + AsyncStreamingResultContainer[ApplicationStateType, Union[dict, Any]], + ], + None, ]: """Async version of stream_iterate. Produces an async generator that iterates through intermediate streams. See stream_iterate for more details. diff --git a/tests/core/test_application.py b/tests/core/test_application.py index c90c40676..7383583f0 100644 --- a/tests/core/test_application.py +++ b/tests/core/test_application.py @@ -630,7 +630,12 @@ def writes(self) -> list[str]: state = State() with pytest.raises(ValueError, match="missing_value"): gen = _run_single_step_streaming_action( - action, state, inputs={}, sequence_id=0, partition_key="partition_key", app_id="app_id" + action, + state, + inputs={}, + sequence_id=0, + partition_key="partition_key", + app_id="app_id", ) collections.deque(gen, maxlen=0) # exhaust the generator @@ -687,7 +692,12 @@ def writes(self) -> list[str]: state = State() with pytest.raises(ValueError, match="missing_value"): gen = _run_multi_step_streaming_action( - action, state, inputs={}, sequence_id=0, partition_key="partition_key", app_id="app_id" + action, + state, + inputs={}, + sequence_id=0, + partition_key="partition_key", + app_id="app_id", ) collections.deque(gen, maxlen=0) # exhaust the generator @@ -1005,7 +1015,12 @@ def test__run_multistep_streaming_action(): action = base_streaming_counter.with_name("counter") state = State({"count": 0, "tracker": []}) generator = _run_multi_step_streaming_action( - action, state, inputs={}, sequence_id=0, partition_key="partition_key", app_id="app_id" + action, + state, + inputs={}, + sequence_id=0, + partition_key="partition_key", + app_id="app_id", ) last_result = -1 result = None @@ -1111,7 +1126,12 @@ def test__run_streaming_action_incorrect_result_type(): state = State() with pytest.raises(ValueError, match="returned a non-dict"): gen = _run_multi_step_streaming_action( - action, state, inputs={}, sequence_id=0, partition_key="partition_key", app_id="app_id" + action, + state, + inputs={}, + sequence_id=0, + partition_key="partition_key", + app_id="app_id", ) collections.deque(gen, maxlen=0) # exhaust the generator @@ -1168,7 +1188,12 @@ def test__run_single_step_streaming_action(): action = base_streaming_single_step_counter.with_name("counter") state = State({"count": 0, "tracker": []}) generator = _run_single_step_streaming_action( - action, state, inputs={}, sequence_id=0, partition_key="partition_key", app_id="app_id" + action, + state, + inputs={}, + sequence_id=0, + partition_key="partition_key", + app_id="app_id", ) last_result = -1 result, state = None, None @@ -1275,6 +1300,328 @@ async def post_stream_item(self, item: Any, **future_kwargs: Any): assert len(hook.items) == 10 # one for each streaming callback +class SingleStepStreamingCounterWithException(SingleStepStreamingAction): + """Yields intermediate items, raises, then yields final state in finally block.""" + + def stream_run_and_update( + self, state: State, **run_kwargs + ) -> Generator[Tuple[dict, Optional[State]], None, None]: + count = state["count"] + try: + for i in range(3): + yield {"count": count + ((i + 1) / 10)}, None + raise RuntimeError("simulated failure") + finally: + yield {"count": count + 1}, state.update(count=count + 1).append(tracker=count + 1) + + @property + def reads(self) -> list[str]: + return ["count"] + + @property + def writes(self) -> list[str]: + return ["count", "tracker"] + + +class SingleStepStreamingCounterWithExceptionNoState(SingleStepStreamingAction): + """Raises without ever yielding a final state update.""" + + def stream_run_and_update( + self, state: State, **run_kwargs + ) -> Generator[Tuple[dict, Optional[State]], None, None]: + count = state["count"] + for i in range(3): + yield {"count": count + ((i + 1) / 10)}, None + raise RuntimeError("simulated failure with no state") + + @property + def reads(self) -> list[str]: + return ["count"] + + @property + def writes(self) -> list[str]: + return ["count", "tracker"] + + +class SingleStepStreamingCounterWithExceptionAsync(SingleStepStreamingAction): + """Async variant: yields intermediate items, raises, then yields final state in finally.""" + + async def stream_run_and_update( + self, state: State, **run_kwargs + ) -> AsyncGenerator[Tuple[dict, Optional[State]], None]: + count = state["count"] + try: + for i in range(3): + yield {"count": count + ((i + 1) / 10)}, None + raise RuntimeError("simulated failure") + finally: + yield {"count": count + 1}, state.update(count=count + 1).append(tracker=count + 1) + + @property + def reads(self) -> list[str]: + return ["count"] + + @property + def writes(self) -> list[str]: + return ["count", "tracker"] + + +class SingleStepStreamingCounterWithExceptionNoStateAsync(SingleStepStreamingAction): + """Async variant: raises without ever yielding a final state update.""" + + async def stream_run_and_update( + self, state: State, **run_kwargs + ) -> AsyncGenerator[Tuple[dict, Optional[State]], None]: + count = state["count"] + for i in range(3): + yield {"count": count + ((i + 1) / 10)}, None + raise RuntimeError("simulated failure with no state") + + @property + def reads(self) -> list[str]: + return ["count"] + + @property + def writes(self) -> list[str]: + return ["count", "tracker"] + + +class MultiStepStreamingCounterWithException(StreamingAction): + """Yields intermediate items, raises, then yields final result in finally block.""" + + def stream_run(self, state: State, **run_kwargs) -> Generator[dict, None, None]: + count = state["count"] + try: + for i in range(3): + yield {"count": count + ((i + 1) / 10)} + raise RuntimeError("simulated failure") + finally: + yield {"count": count + 1} + + @property + def reads(self) -> list[str]: + return ["count"] + + @property + def writes(self) -> list[str]: + return ["count", "tracker"] + + def update(self, result: dict, state: State) -> State: + return state.update(**result).append(tracker=result["count"]) + + +class MultiStepStreamingCounterWithExceptionNoResult(StreamingAction): + """Raises without ever yielding any item.""" + + def stream_run(self, state: State, **run_kwargs) -> Generator[dict, None, None]: + raise RuntimeError("simulated failure with no result") + yield # make this a generator function + + @property + def reads(self) -> list[str]: + return ["count"] + + @property + def writes(self) -> list[str]: + return ["count", "tracker"] + + def update(self, result: dict, state: State) -> State: + return state.update(**result).append(tracker=result["count"]) + + +class MultiStepStreamingCounterWithExceptionAsync(AsyncStreamingAction): + """Async variant: yields intermediate items, raises, then yields final result in finally.""" + + async def stream_run(self, state: State, **run_kwargs) -> AsyncGenerator[dict, None]: + count = state["count"] + try: + for i in range(3): + yield {"count": count + ((i + 1) / 10)} + raise RuntimeError("simulated failure") + finally: + yield {"count": count + 1} + + @property + def reads(self) -> list[str]: + return ["count"] + + @property + def writes(self) -> list[str]: + return ["count", "tracker"] + + def update(self, result: dict, state: State) -> State: + return state.update(**result).append(tracker=result["count"]) + + +class MultiStepStreamingCounterWithExceptionNoResultAsync(AsyncStreamingAction): + """Async variant: raises without ever yielding any item.""" + + async def stream_run(self, state: State, **run_kwargs) -> AsyncGenerator[dict, None]: + raise RuntimeError("simulated failure with no result") + yield # make this an async generator + + @property + def reads(self) -> list[str]: + return ["count"] + + @property + def writes(self) -> list[str]: + return ["count", "tracker"] + + def update(self, result: dict, state: State) -> State: + return state.update(**result).append(tracker=result["count"]) + + +def test__run_single_step_streaming_action_graceful_exception(): + """When the generator raises but yields a final state in finally, stream completes gracefully.""" + action = SingleStepStreamingCounterWithException().with_name("counter") + state = State({"count": 0, "tracker": []}) + generator = _run_single_step_streaming_action( + action, state, inputs={}, sequence_id=0, partition_key="pk", app_id="app" + ) + results = list(generator) + intermediate = [(r, s) for r, s in results if s is None] + final = [(r, s) for r, s in results if s is not None] + assert len(intermediate) == 3 + assert len(final) == 1 + assert final[0][0] == {"count": 1} + assert final[0][1].subset("count", "tracker").get_all() == { + "count": 1, + "tracker": [1], + } + + +def test__run_single_step_streaming_action_exception_propagates(): + """When the generator raises without yielding a final state, exception propagates.""" + action = SingleStepStreamingCounterWithExceptionNoState().with_name("counter") + state = State({"count": 0, "tracker": []}) + generator = _run_single_step_streaming_action( + action, state, inputs={}, sequence_id=0, partition_key="pk", app_id="app" + ) + with pytest.raises(RuntimeError, match="simulated failure with no state"): + list(generator) + + +async def test__run_single_step_streaming_action_graceful_exception_async(): + """Async: when the generator raises but yields a final state in finally, stream completes.""" + action = SingleStepStreamingCounterWithExceptionAsync().with_name("counter") + state = State({"count": 0, "tracker": []}) + generator = _arun_single_step_streaming_action( + action=action, + state=state, + inputs={}, + sequence_id=0, + app_id="app", + partition_key="pk", + lifecycle_adapters=LifecycleAdapterSet(), + ) + results = [] + async for item in generator: + results.append(item) + intermediate = [(r, s) for r, s in results if s is None] + final = [(r, s) for r, s in results if s is not None] + assert len(intermediate) == 3 + assert len(final) == 1 + assert final[0][0] == {"count": 1} + assert final[0][1].subset("count", "tracker").get_all() == { + "count": 1, + "tracker": [1], + } + + +async def test__run_single_step_streaming_action_exception_propagates_async(): + """Async: when the generator raises without yielding a final state, exception propagates.""" + action = SingleStepStreamingCounterWithExceptionNoStateAsync().with_name("counter") + state = State({"count": 0, "tracker": []}) + generator = _arun_single_step_streaming_action( + action=action, + state=state, + inputs={}, + sequence_id=0, + app_id="app", + partition_key="pk", + lifecycle_adapters=LifecycleAdapterSet(), + ) + with pytest.raises(RuntimeError, match="simulated failure with no state"): + async for _ in generator: + pass + + +def test__run_multi_step_streaming_action_graceful_exception(): + """When the generator raises but yields a final result in finally, stream completes.""" + action = MultiStepStreamingCounterWithException().with_name("counter") + state = State({"count": 0, "tracker": []}) + generator = _run_multi_step_streaming_action( + action, state, inputs={}, sequence_id=0, partition_key="pk", app_id="app" + ) + results = list(generator) + intermediate = [(r, s) for r, s in results if s is None] + final = [(r, s) for r, s in results if s is not None] + assert len(intermediate) == 3 + assert len(final) == 1 + assert final[0][0] == {"count": 1} + assert final[0][1].subset("count", "tracker").get_all() == { + "count": 1, + "tracker": [1], + } + + +def test__run_multi_step_streaming_action_exception_propagates(): + """When the generator raises without yielding any result, exception propagates.""" + action = MultiStepStreamingCounterWithExceptionNoResult().with_name("counter") + state = State({"count": 0, "tracker": []}) + generator = _run_multi_step_streaming_action( + action, state, inputs={}, sequence_id=0, partition_key="pk", app_id="app" + ) + with pytest.raises(RuntimeError, match="simulated failure with no result"): + list(generator) + + +async def test__run_multi_step_streaming_action_graceful_exception_async(): + """Async: when the generator raises but yields a final result in finally, stream completes.""" + action = MultiStepStreamingCounterWithExceptionAsync().with_name("counter") + state = State({"count": 0, "tracker": []}) + generator = _arun_multi_step_streaming_action( + action=action, + state=state, + inputs={}, + sequence_id=0, + app_id="app", + partition_key="pk", + lifecycle_adapters=LifecycleAdapterSet(), + ) + results = [] + async for item in generator: + results.append(item) + intermediate = [(r, s) for r, s in results if s is None] + final = [(r, s) for r, s in results if s is not None] + assert len(intermediate) == 3 + assert len(final) == 1 + assert final[0][0] == {"count": 1} + assert final[0][1].subset("count", "tracker").get_all() == { + "count": 1, + "tracker": [1], + } + + +async def test__run_multi_step_streaming_action_exception_propagates_async(): + """Async: when the generator raises without yielding any result, exception propagates.""" + action = MultiStepStreamingCounterWithExceptionNoResultAsync().with_name("counter") + state = State({"count": 0, "tracker": []}) + generator = _arun_multi_step_streaming_action( + action=action, + state=state, + inputs={}, + sequence_id=0, + app_id="app", + partition_key="pk", + lifecycle_adapters=LifecycleAdapterSet(), + ) + with pytest.raises(RuntimeError, match="simulated failure with no result"): + async for _ in generator: + pass + + class SingleStepActionWithDeletionAsync(SingleStepActionWithDeletion): async def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]: return {}, state.wipe(delete=["to_delete"]) @@ -1935,7 +2282,11 @@ async def test_app_a_run_async_and_sync(): graph=Graph( actions=[counter_action_sync, counter_action_async, result_action], transitions=[ - Transition(counter_action_sync, counter_action_async, Condition.expr("count < 20")), + Transition( + counter_action_sync, + counter_action_async, + Condition.expr("count < 20"), + ), Transition(counter_action_async, counter_action_sync, default), Transition(counter_action_sync, result_action, default), ], @@ -2060,7 +2411,8 @@ async def test_astream_result_halt_after_unique_ordered_sequence_id(): def test_stream_result_halt_after_run_through_streaming(): """Tests that we can pass through streaming results, - fully realize them, then get to the streaming results at the end and return the stream""" + fully realize them, then get to the streaming results at the end and return the stream + """ action_tracker = CallCaptureTracker() stream_event_tracker = StreamEventCaptureTracker() counter_action = base_streaming_single_step_counter.with_name("counter") @@ -2665,7 +3017,10 @@ def test__adjust_single_step_output_result_and_state(): def test__adjust_single_step_output_just_state(): state = State({"count": 1}) - assert _adjust_single_step_output(state, "test_action", DEFAULT_SCHEMA) == ({}, state) + assert _adjust_single_step_output(state, "test_action", DEFAULT_SCHEMA) == ( + {}, + state, + ) def test__adjust_single_step_output_errors_incorrect_type(): @@ -2912,7 +3267,11 @@ class BrokenPersister(BaseStatePersister): """Broken persistor.""" def load( - self, partition_key: str, app_id: Optional[str], sequence_id: Optional[int] = None, **kwargs + self, + partition_key: str, + app_id: Optional[str], + sequence_id: Optional[int] = None, + **kwargs, ) -> Optional[PersistedStateData]: return dict( partition_key="key", @@ -2968,7 +3327,8 @@ def test_load_from_sync_cannot_have_async_persistor_error(): default_entrypoint="foo", ) with pytest.raises( - ValueError, match="are building the sync application, but have used an async initializer." + ValueError, + match="are building the sync application, but have used an async initializer.", ): # we have not initialized builder._load_from_sync_persister() @@ -2984,7 +3344,8 @@ async def test_load_from_async_cannot_have_sync_persistor_error(): default_entrypoint="foo", ) with pytest.raises( - ValueError, match="are building the async application, but have used an sync initializer." + ValueError, + match="are building the async application, but have used an sync initializer.", ): # we have not initialized await builder._load_from_async_persister() @@ -3055,7 +3416,11 @@ class DummyPersister(BaseStatePersister): """Dummy persistor.""" def load( - self, partition_key: str, app_id: Optional[str], sequence_id: Optional[int] = None, **kwargs + self, + partition_key: str, + app_id: Optional[str], + sequence_id: Optional[int] = None, + **kwargs, ) -> Optional[PersistedStateData]: return PersistedStateData( partition_key="user123", @@ -3416,7 +3781,10 @@ def post_run_execute_call( hook = TestingHook() foo = [] - @action(reads=["recursion_count", "total_count"], writes=["recursion_count", "total_count"]) + @action( + reads=["recursion_count", "total_count"], + writes=["recursion_count", "total_count"], + ) def recursive_action(state: State) -> State: foo.append(1) recursion_count = state["recursion_count"] @@ -3498,7 +3866,8 @@ def test_set_sync_state_persister_cannot_have_async_error(): persister = AsyncDevNullPersister() builder.with_state_persister(persister) with pytest.raises( - ValueError, match="are building the sync application, but have used an async persister." + ValueError, + match="are building the sync application, but have used an async persister.", ): # we have not initialized builder._set_sync_state_persister() @@ -3519,7 +3888,8 @@ async def test_set_async_state_persister_cannot_have_sync_error(): persister = DevNullPersister() builder.with_state_persister(persister) with pytest.raises( - ValueError, match="are building the async application, but have used an sync persister." + ValueError, + match="are building the async application, but have used an sync persister.", ): # we have not initialized await builder._set_async_state_persister() @@ -3620,15 +3990,27 @@ def inputs(self) -> Union[list[str], tuple[list[str], list[str]]]: def test_remap_context_variable_with_mangled_context_kwargs(): _action = ActionWithKwargs() - inputs = {"__context": "context_value", "other_key": "other_value", "foo": "foo_value"} - expected = {"__context": "context_value", "other_key": "other_value", "foo": "foo_value"} + inputs = { + "__context": "context_value", + "other_key": "other_value", + "foo": "foo_value", + } + expected = { + "__context": "context_value", + "other_key": "other_value", + "foo": "foo_value", + } assert _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) == expected def test_remap_context_variable_with_mangled_context(): _action = ActionWithContext() - inputs = {"__context": "context_value", "other_key": "other_value", "foo": "foo_value"} + inputs = { + "__context": "context_value", + "other_key": "other_value", + "foo": "foo_value", + } expected = { f"_{ActionWithContext.__name__}__context": "context_value", "other_key": "other_value", @@ -3657,8 +4039,16 @@ def test_remap_context_variable_with_mangled_contexttracer(): def test_remap_context_variable_without_mangled_context(): _action = ActionWithoutContext() - inputs = {"__context": "context_value", "other_key": "other_value", "foo": "foo_value"} - expected = {"__context": "context_value", "other_key": "other_value", "foo": "foo_value"} + inputs = { + "__context": "context_value", + "other_key": "other_value", + "foo": "foo_value", + } + expected = { + "__context": "context_value", + "other_key": "other_value", + "foo": "foo_value", + } assert _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) == expected