Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions burr/core/parallelism.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ def _create_task(key: str, action: Action, substate: State) -> SubGraphTask:
def _tasks() -> Generator[SubGraphTask, None, None]:
for i, action in enumerate(self.actions(state, context, inputs)):
for j, substate in enumerate(self.states(state, context, inputs)):
key = f"{i}-{j}" # this is a stable hash for now but will not handle caching
key = f"{context.sequence_id}-{i}-{j}"
yield _create_task(key, action, substate)

async def _atasks() -> AsyncGenerator[SubGraphTask, None]:
Expand All @@ -528,7 +528,7 @@ async def _atasks() -> AsyncGenerator[SubGraphTask, None]:
states = await async_utils.arealize(state_generator)
for i, action in enumerate(actions):
for j, substate in enumerate(states):
key = f"{i}-{j}"
key = f"{context.sequence_id}-{i}-{j}"
yield _create_task(key, action, substate)

return _atasks() if self.is_async() else _tasks()
Expand Down
73 changes: 72 additions & 1 deletion tests/core/test_parallelism.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,12 @@
_cascade_adapter,
map_reduce_action,
)
from burr.core.persistence import BaseStateLoader, BaseStateSaver, PersistedStateData
from burr.core.persistence import (
BaseStateLoader,
BaseStateSaver,
InMemoryPersister,
PersistedStateData,
)
from burr.tracking.base import SyncTrackingClient
from burr.visibility import ActionSpan

Expand Down Expand Up @@ -1227,3 +1232,69 @@ def reads(self) -> list[str]:
assert task.state_initializer is not None
assert task.tracker is not None
assert task.state_persister is task.state_initializer # This ensures they're the same


def test_map_states_no_stale_replay_on_repeated_invocation():
"""Regression test for #761.

When the parent app cascades a state initializer (via ``initialize_from``)
and a MapStates action is invoked more than once, each invocation must
spawn fresh sub-applications. Before the fix, sub-app ids were keyed only
on ``(i, j)``, so they collided across invocations and the cascaded
initializer replayed the prior call's persisted sub-state.
"""

@old_action(reads=["round"], writes=["output_number"])
def emit_round(state: State) -> State:
return state.update(output_number=state["round"])

@old_action(reads=["round"], writes=["round"])
def bump(state: State) -> State:
return state.update(round=state["round"] + 1)

class Fan(MapStates):
def action(self, state: State, inputs: Dict[str, Any]):
return emit_round

def states(
self, state: State, context: ApplicationContext, inputs: Dict[str, Any]
) -> Generator[State, None, None]:
for _ in range(3):
yield state

def reduce(self, state: State, states: Generator[State, None, None]) -> State:
new_state = state
for output_state in states:
new_state = new_state.append(outputs=output_state["output_number"])
return new_state

@property
def reads(self) -> list[str]:
return ["round"]

@property
def writes(self) -> list[str]:
return ["outputs"]

persister = InMemoryPersister()
app = (
ApplicationBuilder()
.with_actions(fan=Fan(), bump=bump)
.with_transitions(("fan", "bump"), ("bump", "fan"))
.with_state_persister(persister)
.with_identifiers(app_id="parent-app")
.initialize_from(
persister,
resume_at_next_action=True,
default_state={"round": 1, "outputs": []},
default_entrypoint="fan",
)
.build()
)

app.run(halt_after=["fan"]) # first fan invocation, round=1
app.run(halt_after=["fan"]) # bump runs, then second fan invocation, round=2

# Each fan invocation contributes 3 outputs. Fresh execution -> [1,1,1,2,2,2].
# Buggy behavior replays the first invocation's persisted sub-state -> [1,1,1,1,1,1].
assert list(app.state["outputs"]) == [1, 1, 1, 2, 2, 2]
Loading