Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ async def _handle_tool_execution(

if interrupts:
# Session state stored on AfterInvocationEvent.
agent._interrupt_state.context = {"tool_use_message": message, "tool_results": tool_results}
agent._interrupt_state.context.update({"tool_use_message": message, "tool_results": tool_results})
agent._interrupt_state.activate()

agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace)
Expand Down
8 changes: 8 additions & 0 deletions src/strands/interrupt.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ def __init__(self, interrupt: Interrupt) -> None:
self.interrupt = interrupt


class CascadedInterruptException(Exception):
"""Exception raised when a tool cascades one or more interrupts."""

def __init__(self, interrupts: list[Interrupt]) -> None:
"""Set the cascaded interrupts."""
self.interrupts = interrupts


@dataclass
class _InterruptState:
"""Track the state of interrupt events raised by the user.
Expand Down
6 changes: 5 additions & 1 deletion src/strands/tools/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def my_tool(param1: str, param2: int = 42) -> dict:
from pydantic_core import PydanticSerializationError
from typing_extensions import override

from ..interrupt import InterruptException
from ..interrupt import CascadedInterruptException, InterruptException
from ..types._events import ToolInterruptEvent, ToolResultEvent, ToolStreamEvent
from ..types.tools import AgentTool, JSONSchema, ToolContext, ToolGenerator, ToolResult, ToolSpec, ToolUse

Expand Down Expand Up @@ -637,6 +637,10 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw
yield ToolInterruptEvent(tool_use, [e.interrupt])
return

except CascadedInterruptException as e:
yield ToolInterruptEvent(tool_use, e.interrupts)
return

except ValueError as e:
# Special handling for validation errors
error_msg = str(e)
Expand Down
58 changes: 57 additions & 1 deletion src/strands/types/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@

from typing_extensions import NotRequired, TypedDict

from .interrupt import _Interruptible
from ..interrupt import CascadedInterruptException, Interrupt
from .interrupt import InterruptResponseContent, _Interruptible
from .media import DocumentContent, ImageContent

JSONSchema = dict
Expand Down Expand Up @@ -161,6 +162,61 @@ def _interrupt_id(self, name: str) -> str:
"""
return f"v1:tool_call:{self.tool_use['toolUseId']}:{uuid.uuid5(uuid.NAMESPACE_OID, name)}"

def cascade_interrupts(self, interrupts: list[Interrupt]) -> None:
"""Cascade interrupts from a sub-agent tool to the orchestrator.

This method stores interrupts on the orchestrator interrupt state and tracks
which interrupt IDs were raised for the current tool use. It always raises
a `CascadedInterruptException` so the tool execution pipeline can emit a
`ToolInterruptEvent`.

Args:
interrupts: Interrupts raised by a sub-agent.

Raises:
CascadedInterruptException: Always raised to stop tool execution and
return interrupts to the orchestrator.
"""
interrupt_state = self.agent._interrupt_state
for interrupt in interrupts:
interrupt_state.interrupts[interrupt.id] = interrupt

context_key = f"cascaded:{self.tool_use['toolUseId']}"
interrupt_state.context[context_key] = [interrupt.id for interrupt in interrupts]
raise CascadedInterruptException(interrupts)

def get_cascaded_interrupt_responses(self) -> list[InterruptResponseContent] | None:
"""Get interrupt responses for cascaded sub-agent interrupts on resume.

Returns:
Formatted interrupt response contents when resuming from a cascaded
interrupt, otherwise None.

Raises:
KeyError: If a tracked cascaded interrupt ID is missing from state.
"""
context_key = f"cascaded:{self.tool_use['toolUseId']}"
cascaded_interrupt_ids = self.agent._interrupt_state.context.get(context_key)
if not cascaded_interrupt_ids:
return None

responses: list[InterruptResponseContent] = []
for interrupt_id in cascaded_interrupt_ids:
if interrupt_id not in self.agent._interrupt_state.interrupts:
raise KeyError(f"interrupt_id=<{interrupt_id}> | no interrupt found")

interrupt = self.agent._interrupt_state.interrupts[interrupt_id]
responses.append(
{
"interruptResponse": {
"interruptId": interrupt.id,
"response": interrupt.response,
}
}
)

return responses


# Individual ToolChoice type aliases
ToolChoiceAutoDict = dict[Literal["auto"], ToolChoiceAuto]
Expand Down
72 changes: 72 additions & 0 deletions tests/strands/event_loop/test_event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
MaxTokensReachedException,
ModelThrottledException,
)
from strands.types.tools import ToolContext
from tests.fixtures.mock_hook_provider import MockHookProvider
from tests.fixtures.mocked_model_provider import MockedModelProvider

Expand Down Expand Up @@ -1037,6 +1038,77 @@ def interrupt_callback(event):
assert tru_state == exp_state


@pytest.mark.asyncio
async def test_event_loop_cycle_cascaded_interrupt_resume(agent, model, tool_registry, agenerator, alist):
sub_agent_interrupt = Interrupt(
id="sub_agent:interrupt:approval",
name="approval",
reason="approval required",
)
sub_agent_responses = {}

@strands.tool(context=True)
def run_sub_agent(query: str, tool_context: ToolContext) -> str:
responses = tool_context.get_cascaded_interrupt_responses()
if responses is None:
tool_context.cascade_interrupts([sub_agent_interrupt])

sub_agent_responses["query"] = query
sub_agent_responses["responses"] = responses
return "approved"

tool_registry.register_tool(run_sub_agent)

model.stream.side_effect = [
agenerator(
[
{
"contentBlockStart": {
"start": {
"toolUse": {
"toolUseId": "t1",
"name": run_sub_agent.tool_spec["name"],
}
},
}
},
{"contentBlockDelta": {"delta": {"toolUse": {"input": '{"query": "delete X"}'}}}},
{"contentBlockStop": {}},
{"messageStop": {"stopReason": "tool_use"}},
]
),
agenerator(
[
{"contentBlockDelta": {"delta": {"text": "completed"}}},
{"contentBlockStop": {}},
]
),
]

first_events = await alist(strands.event_loop.event_loop.event_loop_cycle(agent, invocation_state={}))
first_stop_reason, _, _, _, first_interrupts, _ = first_events[-1]["stop"]
assert first_stop_reason == "interrupt"
assert first_interrupts == [sub_agent_interrupt]
assert agent._interrupt_state.context["cascaded:t1"] == [sub_agent_interrupt.id]

agent._interrupt_state.resume(
[{"interruptResponse": {"interruptId": sub_agent_interrupt.id, "response": "approved"}}]
)

second_events = await alist(strands.event_loop.event_loop.event_loop_cycle(agent, invocation_state={}))
second_stop_reason, _, _, _, _, _ = second_events[-1]["stop"]
assert second_stop_reason == "end_turn"
assert sub_agent_responses["responses"] == [
{"interruptResponse": {"interruptId": sub_agent_interrupt.id, "response": "approved"}}
]
assert agent.messages[-2]["content"][0]["toolResult"]["content"][0]["text"] == "approved"
assert agent._interrupt_state.to_dict() == {
"activated": False,
"context": {},
"interrupts": {},
}


@pytest.mark.asyncio
async def test_invalid_tool_names_adds_tool_uses(agent, model, alist):
model.stream = MockedModelProvider(
Expand Down
64 changes: 64 additions & 0 deletions tests/strands/types/test_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import unittest.mock

import pytest

from strands.interrupt import CascadedInterruptException, Interrupt, _InterruptState
from strands.types.tools import ToolContext


@pytest.fixture
def agent():
instance = unittest.mock.Mock()
instance._interrupt_state = _InterruptState()
return instance


def test_tool_context_cascade_interrupts_stores_interrupts_and_raises(agent):
context = ToolContext(
tool_use={"toolUseId": "tool_1", "name": "sub_agent_tool", "input": {}},
agent=agent,
invocation_state={},
)
interrupts = [
Interrupt(id="sub-1", name="approval_1", reason="first"),
Interrupt(id="sub-2", name="approval_2", reason="second"),
]

with pytest.raises(CascadedInterruptException) as exc_info:
context.cascade_interrupts(interrupts)

assert exc_info.value.interrupts == interrupts
assert agent._interrupt_state.interrupts["sub-1"] == interrupts[0]
assert agent._interrupt_state.interrupts["sub-2"] == interrupts[1]
assert agent._interrupt_state.context["cascaded:tool_1"] == ["sub-1", "sub-2"]


def test_tool_context_get_cascaded_interrupt_responses(agent):
interrupt_1 = Interrupt(id="sub-1", name="approval_1", reason="first", response="approved")
interrupt_2 = Interrupt(id="sub-2", name="approval_2", reason="second", response={"allow": True})
agent._interrupt_state.interrupts = {
interrupt_1.id: interrupt_1,
interrupt_2.id: interrupt_2,
}
agent._interrupt_state.context = {"cascaded:tool_1": [interrupt_1.id, interrupt_2.id]}

context = ToolContext(
tool_use={"toolUseId": "tool_1", "name": "sub_agent_tool", "input": {}},
agent=agent,
invocation_state={},
)

assert context.get_cascaded_interrupt_responses() == [
{"interruptResponse": {"interruptId": "sub-1", "response": "approved"}},
{"interruptResponse": {"interruptId": "sub-2", "response": {"allow": True}}},
]


def test_tool_context_get_cascaded_interrupt_responses_when_not_resuming(agent):
context = ToolContext(
tool_use={"toolUseId": "tool_1", "name": "sub_agent_tool", "input": {}},
agent=agent,
invocation_state={},
)

assert context.get_cascaded_interrupt_responses() is None