diff --git a/burr/core/parallelism.py b/burr/core/parallelism.py index 857fed33..3b03fc03 100644 --- a/burr/core/parallelism.py +++ b/burr/core/parallelism.py @@ -518,7 +518,7 @@ def _create_task(key: str, action: Action, substate: State) -> SubGraphTask: def _tasks() -> Generator[SubGraphTask, None, None]: for i, action in enumerate(self.actions(state, context, inputs)): for j, substate in enumerate(self.states(state, context, inputs)): - key = f"{i}-{j}" # this is a stable hash for now but will not handle caching + key = f"{context.sequence_id}-{i}-{j}" yield _create_task(key, action, substate) async def _atasks() -> AsyncGenerator[SubGraphTask, None]: @@ -528,7 +528,7 @@ async def _atasks() -> AsyncGenerator[SubGraphTask, None]: states = await async_utils.arealize(state_generator) for i, action in enumerate(actions): for j, substate in enumerate(states): - key = f"{i}-{j}" + key = f"{context.sequence_id}-{i}-{j}" yield _create_task(key, action, substate) return _atasks() if self.is_async() else _tasks() diff --git a/tests/core/test_parallelism.py b/tests/core/test_parallelism.py index 25d37cc2..1e3afdc3 100644 --- a/tests/core/test_parallelism.py +++ b/tests/core/test_parallelism.py @@ -45,7 +45,12 @@ _cascade_adapter, map_reduce_action, ) -from burr.core.persistence import BaseStateLoader, BaseStateSaver, PersistedStateData +from burr.core.persistence import ( + BaseStateLoader, + BaseStateSaver, + InMemoryPersister, + PersistedStateData, +) from burr.tracking.base import SyncTrackingClient from burr.visibility import ActionSpan @@ -1227,3 +1232,69 @@ def reads(self) -> list[str]: assert task.state_initializer is not None assert task.tracker is not None assert task.state_persister is task.state_initializer # This ensures they're the same + + +def test_map_states_no_stale_replay_on_repeated_invocation(): + """Regression test for #761. + + When the parent app cascades a state initializer (via ``initialize_from``) + and a MapStates action is invoked more than once, each invocation must + spawn fresh sub-applications. Before the fix, sub-app ids were keyed only + on ``(i, j)``, so they collided across invocations and the cascaded + initializer replayed the prior call's persisted sub-state. + """ + + @old_action(reads=["round"], writes=["output_number"]) + def emit_round(state: State) -> State: + return state.update(output_number=state["round"]) + + @old_action(reads=["round"], writes=["round"]) + def bump(state: State) -> State: + return state.update(round=state["round"] + 1) + + class Fan(MapStates): + def action(self, state: State, inputs: Dict[str, Any]): + return emit_round + + def states( + self, state: State, context: ApplicationContext, inputs: Dict[str, Any] + ) -> Generator[State, None, None]: + for _ in range(3): + yield state + + def reduce(self, state: State, states: Generator[State, None, None]) -> State: + new_state = state + for output_state in states: + new_state = new_state.append(outputs=output_state["output_number"]) + return new_state + + @property + def reads(self) -> list[str]: + return ["round"] + + @property + def writes(self) -> list[str]: + return ["outputs"] + + persister = InMemoryPersister() + app = ( + ApplicationBuilder() + .with_actions(fan=Fan(), bump=bump) + .with_transitions(("fan", "bump"), ("bump", "fan")) + .with_state_persister(persister) + .with_identifiers(app_id="parent-app") + .initialize_from( + persister, + resume_at_next_action=True, + default_state={"round": 1, "outputs": []}, + default_entrypoint="fan", + ) + .build() + ) + + app.run(halt_after=["fan"]) # first fan invocation, round=1 + app.run(halt_after=["fan"]) # bump runs, then second fan invocation, round=2 + + # Each fan invocation contributes 3 outputs. Fresh execution -> [1,1,1,2,2,2]. + # Buggy behavior replays the first invocation's persisted sub-state -> [1,1,1,1,1,1]. + assert list(app.state["outputs"]) == [1, 1, 1, 2, 2, 2]