From 412546e3d6966ed2a5f936c339095f27a16a0e7f Mon Sep 17 00:00:00 2001 From: Thiago Oliveira Date: Fri, 20 Mar 2026 18:08:11 -0300 Subject: [PATCH] feat: add cascaded interrupt handling for sub-agent tools --- src/strands/event_loop/event_loop.py | 2 +- src/strands/interrupt.py | 8 +++ src/strands/tools/decorator.py | 6 +- src/strands/types/tools.py | 58 ++++++++++++++++- tests/strands/event_loop/test_event_loop.py | 72 +++++++++++++++++++++ tests/strands/types/test_tools.py | 64 ++++++++++++++++++ 6 files changed, 207 insertions(+), 3 deletions(-) create mode 100644 tests/strands/types/test_tools.py diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 2e8e4a660..a15186f0c 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -526,7 +526,7 @@ async def _handle_tool_execution( if interrupts: # Session state stored on AfterInvocationEvent. - agent._interrupt_state.context = {"tool_use_message": message, "tool_results": tool_results} + agent._interrupt_state.context.update({"tool_use_message": message, "tool_results": tool_results}) agent._interrupt_state.activate() agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) diff --git a/src/strands/interrupt.py b/src/strands/interrupt.py index 7d02b50ff..f369c01fc 100644 --- a/src/strands/interrupt.py +++ b/src/strands/interrupt.py @@ -37,6 +37,14 @@ def __init__(self, interrupt: Interrupt) -> None: self.interrupt = interrupt +class CascadedInterruptException(Exception): + """Exception raised when a tool cascades one or more interrupts.""" + + def __init__(self, interrupts: list[Interrupt]) -> None: + """Set the cascaded interrupts.""" + self.interrupts = interrupts + + @dataclass class _InterruptState: """Track the state of interrupt events raised by the user. diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 9207df9b8..bcf081ec9 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -65,7 +65,7 @@ def my_tool(param1: str, param2: int = 42) -> dict: from pydantic_core import PydanticSerializationError from typing_extensions import override -from ..interrupt import InterruptException +from ..interrupt import CascadedInterruptException, InterruptException from ..types._events import ToolInterruptEvent, ToolResultEvent, ToolStreamEvent from ..types.tools import AgentTool, JSONSchema, ToolContext, ToolGenerator, ToolResult, ToolSpec, ToolUse @@ -637,6 +637,10 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw yield ToolInterruptEvent(tool_use, [e.interrupt]) return + except CascadedInterruptException as e: + yield ToolInterruptEvent(tool_use, e.interrupts) + return + except ValueError as e: # Special handling for validation errors error_msg = str(e) diff --git a/src/strands/types/tools.py b/src/strands/types/tools.py index 088c83bdb..b1f5e2672 100644 --- a/src/strands/types/tools.py +++ b/src/strands/types/tools.py @@ -13,7 +13,8 @@ from typing_extensions import NotRequired, TypedDict -from .interrupt import _Interruptible +from ..interrupt import CascadedInterruptException, Interrupt +from .interrupt import InterruptResponseContent, _Interruptible from .media import DocumentContent, ImageContent JSONSchema = dict @@ -161,6 +162,61 @@ def _interrupt_id(self, name: str) -> str: """ return f"v1:tool_call:{self.tool_use['toolUseId']}:{uuid.uuid5(uuid.NAMESPACE_OID, name)}" + def cascade_interrupts(self, interrupts: list[Interrupt]) -> None: + """Cascade interrupts from a sub-agent tool to the orchestrator. + + This method stores interrupts on the orchestrator interrupt state and tracks + which interrupt IDs were raised for the current tool use. It always raises + a `CascadedInterruptException` so the tool execution pipeline can emit a + `ToolInterruptEvent`. + + Args: + interrupts: Interrupts raised by a sub-agent. + + Raises: + CascadedInterruptException: Always raised to stop tool execution and + return interrupts to the orchestrator. + """ + interrupt_state = self.agent._interrupt_state + for interrupt in interrupts: + interrupt_state.interrupts[interrupt.id] = interrupt + + context_key = f"cascaded:{self.tool_use['toolUseId']}" + interrupt_state.context[context_key] = [interrupt.id for interrupt in interrupts] + raise CascadedInterruptException(interrupts) + + def get_cascaded_interrupt_responses(self) -> list[InterruptResponseContent] | None: + """Get interrupt responses for cascaded sub-agent interrupts on resume. + + Returns: + Formatted interrupt response contents when resuming from a cascaded + interrupt, otherwise None. + + Raises: + KeyError: If a tracked cascaded interrupt ID is missing from state. + """ + context_key = f"cascaded:{self.tool_use['toolUseId']}" + cascaded_interrupt_ids = self.agent._interrupt_state.context.get(context_key) + if not cascaded_interrupt_ids: + return None + + responses: list[InterruptResponseContent] = [] + for interrupt_id in cascaded_interrupt_ids: + if interrupt_id not in self.agent._interrupt_state.interrupts: + raise KeyError(f"interrupt_id=<{interrupt_id}> | no interrupt found") + + interrupt = self.agent._interrupt_state.interrupts[interrupt_id] + responses.append( + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": interrupt.response, + } + } + ) + + return responses + # Individual ToolChoice type aliases ToolChoiceAutoDict = dict[Literal["auto"], ToolChoiceAuto] diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index cedca269b..78b6fffa2 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -28,6 +28,7 @@ MaxTokensReachedException, ModelThrottledException, ) +from strands.types.tools import ToolContext from tests.fixtures.mock_hook_provider import MockHookProvider from tests.fixtures.mocked_model_provider import MockedModelProvider @@ -1037,6 +1038,77 @@ def interrupt_callback(event): assert tru_state == exp_state +@pytest.mark.asyncio +async def test_event_loop_cycle_cascaded_interrupt_resume(agent, model, tool_registry, agenerator, alist): + sub_agent_interrupt = Interrupt( + id="sub_agent:interrupt:approval", + name="approval", + reason="approval required", + ) + sub_agent_responses = {} + + @strands.tool(context=True) + def run_sub_agent(query: str, tool_context: ToolContext) -> str: + responses = tool_context.get_cascaded_interrupt_responses() + if responses is None: + tool_context.cascade_interrupts([sub_agent_interrupt]) + + sub_agent_responses["query"] = query + sub_agent_responses["responses"] = responses + return "approved" + + tool_registry.register_tool(run_sub_agent) + + model.stream.side_effect = [ + agenerator( + [ + { + "contentBlockStart": { + "start": { + "toolUse": { + "toolUseId": "t1", + "name": run_sub_agent.tool_spec["name"], + } + }, + } + }, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"query": "delete X"}'}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use"}}, + ] + ), + agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "completed"}}}, + {"contentBlockStop": {}}, + ] + ), + ] + + first_events = await alist(strands.event_loop.event_loop.event_loop_cycle(agent, invocation_state={})) + first_stop_reason, _, _, _, first_interrupts, _ = first_events[-1]["stop"] + assert first_stop_reason == "interrupt" + assert first_interrupts == [sub_agent_interrupt] + assert agent._interrupt_state.context["cascaded:t1"] == [sub_agent_interrupt.id] + + agent._interrupt_state.resume( + [{"interruptResponse": {"interruptId": sub_agent_interrupt.id, "response": "approved"}}] + ) + + second_events = await alist(strands.event_loop.event_loop.event_loop_cycle(agent, invocation_state={})) + second_stop_reason, _, _, _, _, _ = second_events[-1]["stop"] + assert second_stop_reason == "end_turn" + assert sub_agent_responses["responses"] == [ + {"interruptResponse": {"interruptId": sub_agent_interrupt.id, "response": "approved"}} + ] + assert agent.messages[-2]["content"][0]["toolResult"]["content"][0]["text"] == "approved" + assert agent._interrupt_state.to_dict() == { + "activated": False, + "context": {}, + "interrupts": {}, + } + + @pytest.mark.asyncio async def test_invalid_tool_names_adds_tool_uses(agent, model, alist): model.stream = MockedModelProvider( diff --git a/tests/strands/types/test_tools.py b/tests/strands/types/test_tools.py new file mode 100644 index 000000000..a8bc168a6 --- /dev/null +++ b/tests/strands/types/test_tools.py @@ -0,0 +1,64 @@ +import unittest.mock + +import pytest + +from strands.interrupt import CascadedInterruptException, Interrupt, _InterruptState +from strands.types.tools import ToolContext + + +@pytest.fixture +def agent(): + instance = unittest.mock.Mock() + instance._interrupt_state = _InterruptState() + return instance + + +def test_tool_context_cascade_interrupts_stores_interrupts_and_raises(agent): + context = ToolContext( + tool_use={"toolUseId": "tool_1", "name": "sub_agent_tool", "input": {}}, + agent=agent, + invocation_state={}, + ) + interrupts = [ + Interrupt(id="sub-1", name="approval_1", reason="first"), + Interrupt(id="sub-2", name="approval_2", reason="second"), + ] + + with pytest.raises(CascadedInterruptException) as exc_info: + context.cascade_interrupts(interrupts) + + assert exc_info.value.interrupts == interrupts + assert agent._interrupt_state.interrupts["sub-1"] == interrupts[0] + assert agent._interrupt_state.interrupts["sub-2"] == interrupts[1] + assert agent._interrupt_state.context["cascaded:tool_1"] == ["sub-1", "sub-2"] + + +def test_tool_context_get_cascaded_interrupt_responses(agent): + interrupt_1 = Interrupt(id="sub-1", name="approval_1", reason="first", response="approved") + interrupt_2 = Interrupt(id="sub-2", name="approval_2", reason="second", response={"allow": True}) + agent._interrupt_state.interrupts = { + interrupt_1.id: interrupt_1, + interrupt_2.id: interrupt_2, + } + agent._interrupt_state.context = {"cascaded:tool_1": [interrupt_1.id, interrupt_2.id]} + + context = ToolContext( + tool_use={"toolUseId": "tool_1", "name": "sub_agent_tool", "input": {}}, + agent=agent, + invocation_state={}, + ) + + assert context.get_cascaded_interrupt_responses() == [ + {"interruptResponse": {"interruptId": "sub-1", "response": "approved"}}, + {"interruptResponse": {"interruptId": "sub-2", "response": {"allow": True}}}, + ] + + +def test_tool_context_get_cascaded_interrupt_responses_when_not_resuming(agent): + context = ToolContext( + tool_use={"toolUseId": "tool_1", "name": "sub_agent_tool", "input": {}}, + agent=agent, + invocation_state={}, + ) + + assert context.get_cascaded_interrupt_responses() is None