From 2418ad251f523e828db32735c70d636f95a7f96c Mon Sep 17 00:00:00 2001 From: Owen Kaplan Date: Wed, 18 Mar 2026 15:55:29 -0400 Subject: [PATCH 1/8] feat: add AgentAsTool --- src/strands/agent/__init__.py | 2 + src/strands/agent/agent.py | 29 +++ src/strands/agent/agent_as_tool.py | 173 ++++++++++++++++++ tests/strands/agent/test_agent.py | 50 +++++ tests/strands/agent/test_agent_as_tool.py | 213 ++++++++++++++++++++++ tests_integ/test_agent_as_tool.py | 36 ++++ 6 files changed, 503 insertions(+) create mode 100644 src/strands/agent/agent_as_tool.py create mode 100644 tests/strands/agent/test_agent_as_tool.py create mode 100644 tests_integ/test_agent_as_tool.py diff --git a/src/strands/agent/__init__.py b/src/strands/agent/__init__.py index c901e800f..a4911b34e 100644 --- a/src/strands/agent/__init__.py +++ b/src/strands/agent/__init__.py @@ -11,6 +11,7 @@ from ..event_loop._retry import ModelRetryStrategy from .agent import Agent +from .agent_as_tool import AgentAsTool from .agent_result import AgentResult from .base import AgentBase from .conversation_manager import ( @@ -24,6 +25,7 @@ "Agent", "AgentBase", "AgentResult", + "AgentAsTool", "ConversationManager", "NullConversationManager", "SlidingWindowConversationManager", diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index f378a886a..ad2b7a14b 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -62,6 +62,7 @@ from ..types.content import ContentBlock, Message, Messages, SystemContentBlock from ..types.exceptions import ConcurrencyException, ContextWindowOverflowException from ..types.traces import AttributeValue +from .agent_as_tool import AgentAsTool from .agent_result import AgentResult from .base import AgentBase from .conversation_manager import ( @@ -612,6 +613,34 @@ async def structured_output_async(self, output_model: type[T], prompt: AgentInpu finally: await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self, invocation_state={})) + def as_tool( + self, + name: str | None = None, + description: str | None = None, + ) -> AgentAsTool: + r"""Convert this agent into a tool for use by another agent. + + Args: + name: Tool name. Must match the pattern ``[a-zA-Z0-9_\\-]{1,64}``. + Defaults to the agent's name. + description: Tool description. Defaults to the agent's description. + + Returns: + An AgentAsTool wrapping this agent. + + Example: + ```python + researcher = Agent(name="researcher", description="Finds information") + writer = Agent(name="writer", tools=[researcher.as_tool()]) + writer("Write about AI agents") + ``` + """ + if not name: + name = self.name + if not description: + description = self.description or f"Use the {name} tool to invoke this agent as a tool" + return AgentAsTool(self, name=name, description=description) + def cleanup(self) -> None: """Clean up resources used by the agent. diff --git a/src/strands/agent/agent_as_tool.py b/src/strands/agent/agent_as_tool.py new file mode 100644 index 000000000..713456b91 --- /dev/null +++ b/src/strands/agent/agent_as_tool.py @@ -0,0 +1,173 @@ +"""Agent-as-tool adapter. + +This module provides the AgentAsTool class that wraps an Agent (or any AgentBase) as a tool +so it can be passed to another agent's tool list. +""" + +import logging +from typing import Any + +from typing_extensions import override + +from ..types._events import ToolResultEvent +from ..types.tools import AgentTool, ToolGenerator, ToolSpec, ToolUse +from .base import AgentBase + +logger = logging.getLogger(__name__) + + +class AgentAsTool(AgentTool): + """Adapter that exposes an Agent as a tool for use by other agents. + + The tool accepts a single ``input`` string parameter, invokes the wrapped + agent, and returns the text response. + + Example: + ```python + from strands import Agent + from strands.agent.agent_as_tool import AgentAsTool + + researcher = Agent(name="researcher", description="Finds information") + + # Use directly + tool = AgentAsTool(researcher, name="researcher", description="Finds information") + + # Or via convenience method + tool = researcher.as_tool() + + writer = Agent(name="writer", tools=[tool]) + writer("Write about AI agents") + ``` + """ + + def __init__( + self, + agent: AgentBase, + *, + name: str, + description: str, + ) -> None: + r"""Initialize the agent-as-tool adapter. + + Args: + agent: The agent to wrap as a tool. + name: Tool name. Must match the pattern ``[a-zA-Z0-9_\\-]{1,64}``. + description: Tool description. + """ + super().__init__() + self._agent = agent + self._tool_name = name + self._description = description + + @property + def agent(self) -> AgentBase: + """The wrapped agent instance.""" + return self._agent + + @property + def tool_name(self) -> str: + """Get the tool name.""" + return self._tool_name + + @property + def tool_spec(self) -> ToolSpec: + """Get the tool specification.""" + return { + "name": self._tool_name, + "description": self._description, + "inputSchema": { + "json": { + "type": "object", + "properties": { + "input": { + "type": "string", + "description": "The input to send to the agent tool.", + } + }, + "required": ["input"], + } + }, + } + + @property + def tool_type(self) -> str: + """Get the tool type.""" + return "agent" + + @override + async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any) -> ToolGenerator: + """Invoke the wrapped agent via streaming and yield events. + + Intermediate agent events are forwarded as ToolStreamEvents so the parent + agent's callback handler can display sub-agent progress. The final + AgentResult is yielded as a ToolResultEvent. + + Args: + tool_use: The tool use request containing the input parameter. + invocation_state: Context for the tool invocation. + **kwargs: Additional keyword arguments. + + Yields: + ToolStreamEvent for intermediate events, then ToolResultEvent with the final response. + """ + prompt = tool_use["input"].get("input", "") if isinstance(tool_use["input"], dict) else tool_use["input"] + tool_use_id = tool_use["toolUseId"] + + logger.debug("tool_name=<%s>, tool_use_id=<%s> | invoking agent", self._tool_name, tool_use_id) + + try: + result = None + async for event in self._agent.stream_async(prompt): + if "result" in event: + result = event["result"] + else: + yield event + + if result is None: + yield ToolResultEvent( + { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": "Agent did not produce a result"}], + } + ) + return + + if result.structured_output: + yield ToolResultEvent( + { + "toolUseId": tool_use_id, + "status": "success", + "content": [{"json": result.structured_output.model_dump()}], + } + ) + else: + yield ToolResultEvent( + { + "toolUseId": tool_use_id, + "status": "success", + "content": [{"text": str(result)}], + } + ) + + except Exception as e: + logger.warning( + "tool_name=<%s>, tool_use_id=<%s> | agent invocation failed: %s", + self._tool_name, + tool_use_id, + e, + ) + yield ToolResultEvent( + { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": f"Agent error: {e}"}], + } + ) + + @override + def get_display_properties(self) -> dict[str, str]: + """Get properties for UI display.""" + properties = super().get_display_properties() + properties["Agent"] = getattr(self._agent, "name", "unknown") + return properties diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 967a0dafb..6bb64f870 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -16,6 +16,7 @@ import strands from strands import Agent, Plugin, ToolContext from strands.agent import AgentResult +from strands.agent.agent_as_tool import AgentAsTool from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.agent.state import AgentState @@ -2699,3 +2700,52 @@ def hook_callback(event: BeforeModelCallEvent): agent("test") assert len(hook_called) == 1 + + +def test_as_tool_returns_agent_tool(): + """Test that as_tool returns an AgentAsTool wrapping the agent.""" + agent = Agent(name="researcher", description="Finds information") + tool = agent.as_tool() + + assert isinstance(tool, AgentAsTool) + assert tool.agent is agent + + +def test_as_tool_defaults_name_from_agent(): + """Test that as_tool defaults the tool name to the agent's name.""" + agent = Agent(name="researcher") + tool = agent.as_tool() + + assert tool.tool_name == "researcher" + + +def test_as_tool_defaults_description_from_agent(): + """Test that as_tool defaults the description to the agent's description.""" + agent = Agent(name="researcher", description="Finds information") + tool = agent.as_tool() + + assert tool.tool_spec["description"] == "Finds information" + + +def test_as_tool_custom_name(): + """Test that as_tool accepts a custom name.""" + agent = Agent(name="researcher") + tool = agent.as_tool(name="custom_name") + + assert tool.tool_name == "custom_name" + + +def test_as_tool_custom_description(): + """Test that as_tool accepts a custom description.""" + agent = Agent(name="researcher", description="Original") + tool = agent.as_tool(description="Custom description") + + assert tool.tool_spec["description"] == "Custom description" + + +def test_as_tool_defaults_description_when_agent_has_none(): + """Test that as_tool generates a default description when agent has none.""" + agent = Agent(name="researcher") + tool = agent.as_tool() + + assert tool.tool_spec["description"] == "Use the researcher tool to invoke this agent as a tool" diff --git a/tests/strands/agent/test_agent_as_tool.py b/tests/strands/agent/test_agent_as_tool.py new file mode 100644 index 000000000..2b2fb9ca6 --- /dev/null +++ b/tests/strands/agent/test_agent_as_tool.py @@ -0,0 +1,213 @@ +"""Tests for AgentAsTool - the agent-as-tool adapter.""" + +from unittest.mock import MagicMock + +import pytest + +from strands.agent.agent_as_tool import AgentAsTool +from strands.agent.agent_result import AgentResult +from strands.telemetry.metrics import EventLoopMetrics +from strands.types._events import ToolResultEvent + + +async def _mock_stream_async(result, intermediate_events=None): + """Helper that yields intermediate events then the final result event.""" + for event in intermediate_events or []: + yield event + yield {"result": result} + + +@pytest.fixture +def mock_agent(): + agent = MagicMock() + agent.name = "test_agent" + agent.description = "A test agent" + return agent + + +@pytest.fixture +def tool(mock_agent): + return AgentAsTool(mock_agent, name="test_agent", description="A test agent") + + +@pytest.fixture +def tool_use(): + return { + "toolUseId": "tool-123", + "name": "test_agent", + "input": {"input": "hello"}, + } + + +@pytest.fixture +def agent_result(): + return AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "response text"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + + +# --- init --- + + +def test_init_sets_name(mock_agent): + tool = AgentAsTool(mock_agent, name="my_tool", description="desc") + assert tool.tool_name == "my_tool" + + +def test_init_sets_description(mock_agent): + tool = AgentAsTool(mock_agent, name="my_tool", description="custom desc") + assert tool._description == "custom desc" + + +def test_init_stores_agent_reference(mock_agent, tool): + assert tool.agent is mock_agent + + +# --- properties --- + + +def test_tool_name(tool): + assert tool.tool_name == "test_agent" + + +def test_tool_type(tool): + assert tool.tool_type == "agent" + + +def test_tool_spec_name(tool): + assert tool.tool_spec["name"] == "test_agent" + + +def test_tool_spec_description(tool): + assert tool.tool_spec["description"] == "A test agent" + + +def test_tool_spec_input_schema(tool): + schema = tool.tool_spec["inputSchema"]["json"] + assert schema["type"] == "object" + assert "input" in schema["properties"] + assert schema["properties"]["input"]["type"] == "string" + assert schema["required"] == ["input"] + + +def test_display_properties(tool): + props = tool.get_display_properties() + assert props["Agent"] == "test_agent" + assert props["Type"] == "agent" + + +# --- stream --- + + +@pytest.mark.asyncio +async def test_stream_success(tool, mock_agent, tool_use, agent_result): + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + events = [event async for event in tool.stream(tool_use, {})] + + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + assert len(result_events) == 1 + assert result_events[0]["tool_result"]["status"] == "success" + assert result_events[0]["tool_result"]["content"][0]["text"] == "response text\n" + + +@pytest.mark.asyncio +async def test_stream_passes_input_to_agent(tool, mock_agent, tool_use, agent_result): + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + async for _ in tool.stream(tool_use, {}): + pass + + mock_agent.stream_async.assert_called_once_with("hello") + + +@pytest.mark.asyncio +async def test_stream_empty_input(tool, mock_agent, agent_result): + empty_tool_use = { + "toolUseId": "tool-123", + "name": "test_agent", + "input": {}, + } + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + async for _ in tool.stream(empty_tool_use, {}): + pass + + mock_agent.stream_async.assert_called_once_with("") + + +@pytest.mark.asyncio +async def test_stream_error(tool, mock_agent, tool_use): + mock_agent.stream_async.side_effect = RuntimeError("boom") + + events = [event async for event in tool.stream(tool_use, {})] + + assert len(events) == 1 + assert events[0]["tool_result"]["status"] == "error" + assert "boom" in events[0]["tool_result"]["content"][0]["text"] + + +@pytest.mark.asyncio +async def test_stream_propagates_tool_use_id(tool, mock_agent, tool_use, agent_result): + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + events = [event async for event in tool.stream(tool_use, {})] + + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + assert result_events[0]["tool_result"]["toolUseId"] == "tool-123" + + +@pytest.mark.asyncio +async def test_stream_forwards_intermediate_events(tool, mock_agent, tool_use, agent_result): + intermediate = [{"data": "partial"}, {"data": "more"}] + mock_agent.stream_async.return_value = _mock_stream_async(agent_result, intermediate) + + events = [event async for event in tool.stream(tool_use, {})] + + # Intermediate events are yielded as-is (raw dicts); wrapping in ToolStreamEvent happens in the caller + non_result_events = [e for e in events if not isinstance(e, ToolResultEvent)] + assert len(non_result_events) == 2 + assert non_result_events[0]["data"] == "partial" + assert non_result_events[1]["data"] == "more" + + +@pytest.mark.asyncio +async def test_stream_no_result_yields_error(tool, mock_agent, tool_use): + async def _empty_stream(): + return + yield # noqa: RET504 - make it an async generator + + mock_agent.stream_async.return_value = _empty_stream() + + events = [event async for event in tool.stream(tool_use, {})] + + assert len(events) == 1 + assert events[0]["tool_result"]["status"] == "error" + assert "did not produce a result" in events[0]["tool_result"]["content"][0]["text"] + + +@pytest.mark.asyncio +async def test_stream_structured_output(tool, mock_agent, tool_use): + from pydantic import BaseModel + + class MyOutput(BaseModel): + answer: str + + structured = MyOutput(answer="42") + result = AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ignored"}]}, + metrics=EventLoopMetrics(), + state={}, + structured_output=structured, + ) + mock_agent.stream_async.return_value = _mock_stream_async(result) + + events = [event async for event in tool.stream(tool_use, {})] + + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + assert result_events[0]["tool_result"]["status"] == "success" + assert result_events[0]["tool_result"]["content"][0]["json"] == {"answer": "42"} diff --git a/tests_integ/test_agent_as_tool.py b/tests_integ/test_agent_as_tool.py new file mode 100644 index 000000000..a808fcd23 --- /dev/null +++ b/tests_integ/test_agent_as_tool.py @@ -0,0 +1,36 @@ +import pytest + +from strands import Agent, tool + + +@tool +def get_tiger_height() -> int: + """Returns the height of a tiger in centimeters.""" + return 100 + + +@pytest.mark.asyncio +async def test_stream_async_with_agent_tool(): + inner_agent = Agent( + name="myAgentTool", + description="An agent tool knowledgeable about tigers", + tools=[get_tiger_height], + ) + agent_tool = inner_agent.as_tool() + agent = Agent( + name="myOtherAgent", + tools=[agent_tool], + ) + + result = await agent.invoke_async( + prompt="Invoke the myAgentTool and ask about the height of tigers.", + ) + + # Outer agent completed and called the agent tool + assert result.stop_reason == "end_turn" + assert "myAgentTool" in result.metrics.tool_metrics + assert result.metrics.tool_metrics["myAgentTool"].success_count >= 1 + + # Inner agent called get_tiger_height + assert "get_tiger_height" in inner_agent.event_loop_metrics.tool_metrics + assert inner_agent.event_loop_metrics.tool_metrics["get_tiger_height"].success_count >= 1 From 2826d268c01008c87ef93659f6e6208df59c9869 Mon Sep 17 00:00:00 2001 From: Owen Kaplan Date: Thu, 19 Mar 2026 14:58:55 -0400 Subject: [PATCH 2/8] feat: add preserve_context to AgentAsTool --- src/strands/agent/agent_as_tool.py | 46 ++++++- tests/strands/agent/test_agent_as_tool.py | 139 ++++++++++++++++++++++ 2 files changed, 183 insertions(+), 2 deletions(-) diff --git a/src/strands/agent/agent_as_tool.py b/src/strands/agent/agent_as_tool.py index 713456b91..c2aefe021 100644 --- a/src/strands/agent/agent_as_tool.py +++ b/src/strands/agent/agent_as_tool.py @@ -82,7 +82,14 @@ def tool_spec(self) -> ToolSpec: "input": { "type": "string", "description": "The input to send to the agent tool.", - } + }, + "preserve_context": { + "type": "boolean", + "description": ( + "Whether to preserve the agent's conversation context across invocations. " + "Defaults to true. Set to false to clear conversation history before this call." + ), + }, }, "required": ["input"], } @@ -110,9 +117,44 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw Yields: ToolStreamEvent for intermediate events, then ToolResultEvent with the final response. """ - prompt = tool_use["input"].get("input", "") if isinstance(tool_use["input"], dict) else tool_use["input"] + tool_input = tool_use["input"] + if isinstance(tool_input, dict): + prompt = tool_input.get("input", "") + preserve_context = tool_input.get("preserve_context", True) + elif isinstance(tool_input, str): + prompt = tool_input + preserve_context = True + else: + logger.warning( + "tool_name=<%s> | unexpected input type: %s", + self._tool_name, + type(tool_input), + ) + prompt = str(tool_input) + preserve_context = True + tool_use_id = tool_use["toolUseId"] + if not preserve_context: + # AgentBase is a protocol and does not guarantee a messages attribute. + # We check for it at runtime to support Agent and other implementations + # that expose a mutable messages list. + messages = getattr(self._agent, "messages", None) + if isinstance(messages, list): + logger.debug( + "tool_name=<%s>, tool_use_id=<%s> | clearing agent conversation context", + self._tool_name, + tool_use_id, + ) + messages.clear() + else: + logger.warning( + "tool_name=<%s>, tool_use_id=<%s> | preserve_context=false requested" + " but agent does not expose a messages list", + self._tool_name, + tool_use_id, + ) + logger.debug("tool_name=<%s>, tool_use_id=<%s> | invoking agent", self._tool_name, tool_use_id) try: diff --git a/tests/strands/agent/test_agent_as_tool.py b/tests/strands/agent/test_agent_as_tool.py index 2b2fb9ca6..f142090a8 100644 --- a/tests/strands/agent/test_agent_as_tool.py +++ b/tests/strands/agent/test_agent_as_tool.py @@ -1,5 +1,6 @@ """Tests for AgentAsTool - the agent-as-tool adapter.""" +import logging from unittest.mock import MagicMock import pytest @@ -90,6 +91,8 @@ def test_tool_spec_input_schema(tool): assert schema["type"] == "object" assert "input" in schema["properties"] assert schema["properties"]["input"]["type"] == "string" + assert "preserve_context" in schema["properties"] + assert schema["properties"]["preserve_context"]["type"] == "boolean" assert schema["required"] == ["input"] @@ -211,3 +214,139 @@ class MyOutput(BaseModel): result_events = [e for e in events if isinstance(e, ToolResultEvent)] assert result_events[0]["tool_result"]["status"] == "success" assert result_events[0]["tool_result"]["content"][0]["json"] == {"answer": "42"} + + +@pytest.mark.asyncio +async def test_stream_string_input(tool, mock_agent, agent_result): + """When tool_use input is a plain string rather than a dict.""" + tool_use = { + "toolUseId": "tool-123", + "name": "test_agent", + "input": "direct string", + } + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + async for _ in tool.stream(tool_use, {}): + pass + + mock_agent.stream_async.assert_called_once_with("direct string") + + +# --- preserve_context --- + + +class _FakeAgent: + """Minimal fake agent with a real messages list for preserve_context tests.""" + + def __init__(self): + self.name = "fake_agent" + self.messages: list = [] + + async def invoke_async(self, prompt=None, **kwargs): + pass + + def __call__(self, prompt=None, **kwargs): + pass + + def stream_async(self, prompt=None, **kwargs): + return _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) + + +@pytest.mark.asyncio +async def test_stream_clears_context_when_preserve_context_false(): + agent = _FakeAgent() + agent.messages = [{"role": "user", "content": [{"text": "old"}]}] + tool = AgentAsTool(agent, name="fake_agent", description="desc") + + tool_use = { + "toolUseId": "tool-123", + "name": "fake_agent", + "input": {"input": "hello", "preserve_context": False}, + } + + async for _ in tool.stream(tool_use, {}): + pass + + assert agent.messages == [] + + +@pytest.mark.asyncio +async def test_stream_preserves_context_by_default(): + agent = _FakeAgent() + agent.messages = [{"role": "user", "content": [{"text": "old"}]}] + tool = AgentAsTool(agent, name="fake_agent", description="desc") + + tool_use = { + "toolUseId": "tool-123", + "name": "fake_agent", + "input": {"input": "hello"}, + } + + async for _ in tool.stream(tool_use, {}): + pass + + assert len(agent.messages) >= 1 + + +@pytest.mark.asyncio +async def test_stream_preserves_context_when_explicitly_true(): + agent = _FakeAgent() + agent.messages = [{"role": "user", "content": [{"text": "old"}]}] + tool = AgentAsTool(agent, name="fake_agent", description="desc") + + tool_use = { + "toolUseId": "tool-123", + "name": "fake_agent", + "input": {"input": "hello", "preserve_context": True}, + } + + async for _ in tool.stream(tool_use, {}): + pass + + assert len(agent.messages) >= 1 + + +@pytest.mark.asyncio +async def test_stream_preserve_context_false_warns_when_no_messages_attr(caplog): + """Agent without a messages attribute should log a warning.""" + + class _NoMessagesAgent: + name = "bare_agent" + + async def invoke_async(self, prompt=None, **kwargs): + pass + + def __call__(self, prompt=None, **kwargs): + pass + + def stream_async(self, prompt=None, **kwargs): + return _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) + + agent = _NoMessagesAgent() + tool = AgentAsTool(agent, name="bare_agent", description="desc") + + tool_use = { + "toolUseId": "tool-123", + "name": "bare_agent", + "input": {"input": "hello", "preserve_context": False}, + } + + with caplog.at_level(logging.WARNING, logger="strands.agent.agent_as_tool"): + async for _ in tool.stream(tool_use, {}): + pass + + assert "preserve_context=false requested" in caplog.text From 9ca227c1dad9b092b00f6cd527f5b01c3e487dd0 Mon Sep 17 00:00:00 2001 From: Owen Kaplan Date: Fri, 20 Mar 2026 14:22:41 -0400 Subject: [PATCH 3/8] fix: move preserve_context to class var; reset to original agent state; yield AgentAsToolStreamEvents; small fixes --- src/strands/__init__.py | 2 + src/strands/agent/__init__.py | 2 +- .../{agent_as_tool.py => _agent_as_tool.py} | 102 ++++--- src/strands/agent/agent.py | 12 +- src/strands/types/_events.py | 26 ++ tests/strands/agent/test_agent.py | 5 +- tests/strands/agent/test_agent_as_tool.py | 279 +++++++++++------- tests/strands/types/test__events.py | 37 +++ 8 files changed, 317 insertions(+), 148 deletions(-) rename src/strands/agent/{agent_as_tool.py => _agent_as_tool.py} (63%) diff --git a/src/strands/__init__.py b/src/strands/__init__.py index 2078f16ce..dc3a0c7ff 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -1,6 +1,7 @@ """A framework for building, deploying, and managing AI agents.""" from . import agent, models, telemetry, types +from .agent._agent_as_tool import AgentAsTool from .agent.agent import Agent from .agent.base import AgentBase from .event_loop._retry import ModelRetryStrategy @@ -11,6 +12,7 @@ __all__ = [ "Agent", + "AgentAsTool", "AgentBase", "AgentSkills", "agent", diff --git a/src/strands/agent/__init__.py b/src/strands/agent/__init__.py index a4911b34e..d0254852d 100644 --- a/src/strands/agent/__init__.py +++ b/src/strands/agent/__init__.py @@ -10,8 +10,8 @@ from typing import Any from ..event_loop._retry import ModelRetryStrategy +from ._agent_as_tool import AgentAsTool from .agent import Agent -from .agent_as_tool import AgentAsTool from .agent_result import AgentResult from .base import AgentBase from .conversation_manager import ( diff --git a/src/strands/agent/agent_as_tool.py b/src/strands/agent/_agent_as_tool.py similarity index 63% rename from src/strands/agent/agent_as_tool.py rename to src/strands/agent/_agent_as_tool.py index c2aefe021..534be5d9a 100644 --- a/src/strands/agent/agent_as_tool.py +++ b/src/strands/agent/_agent_as_tool.py @@ -4,12 +4,15 @@ so it can be passed to another agent's tool list. """ +import copy import logging from typing import Any from typing_extensions import override -from ..types._events import ToolResultEvent +from ..agent.state import AgentState +from ..types._events import AgentAsToolStreamEvent, ToolResultEvent +from ..types.content import Messages from ..types.tools import AgentTool, ToolGenerator, ToolSpec, ToolUse from .base import AgentBase @@ -25,7 +28,7 @@ class AgentAsTool(AgentTool): Example: ```python from strands import Agent - from strands.agent.agent_as_tool import AgentAsTool + from strands.agent import AgentAsTool researcher = Agent(name="researcher", description="Finds information") @@ -35,6 +38,9 @@ class AgentAsTool(AgentTool): # Or via convenience method tool = researcher.as_tool() + # Start each invocation with a fresh conversation + tool = researcher.as_tool(preserve_context=False) + writer = Agent(name="writer", tools=[tool]) writer("Write about AI agents") ``` @@ -46,6 +52,7 @@ def __init__( *, name: str, description: str, + preserve_context: bool = True, ) -> None: r"""Initialize the agent-as-tool adapter. @@ -53,11 +60,34 @@ def __init__( agent: The agent to wrap as a tool. name: Tool name. Must match the pattern ``[a-zA-Z0-9_\\-]{1,64}``. description: Tool description. + preserve_context: Whether to preserve the agent's conversation history across + invocations. When False, the agent's messages and state are reset to the + values they had at construction time before each call, ensuring every + invocation starts from the same baseline regardless of any external + interactions with the agent. Defaults to True. Only effective when the + wrapped agent exposes a mutable ``messages`` list and/or an ``AgentState`` + (e.g. ``strands.agent.Agent``). """ super().__init__() self._agent = agent self._tool_name = name self._description = description + self._preserve_context = preserve_context + + # When preserve_context=False, we snapshot the agent's initial state so we can + # restore it before each invocation. This mirrors GraphNode.reset_executor_state(). + # We require an Agent instance for this since AgentBase doesn't guarantee + # messages/state attributes. + self._initial_messages: Messages = [] + self._initial_state: AgentState = AgentState() + + if not preserve_context: + from .agent import Agent + + if not isinstance(agent, Agent): + raise TypeError(f"preserve_context=False requires an Agent instance, got {type(agent).__name__}") + self._initial_messages = copy.deepcopy(agent.messages) + self._initial_state = AgentState(agent.state.get()) @property def agent(self) -> AgentBase: @@ -83,13 +113,6 @@ def tool_spec(self) -> ToolSpec: "type": "string", "description": "The input to send to the agent tool.", }, - "preserve_context": { - "type": "boolean", - "description": ( - "Whether to preserve the agent's conversation context across invocations. " - "Defaults to true. Set to false to clear conversation history before this call." - ), - }, }, "required": ["input"], } @@ -105,8 +128,8 @@ def tool_type(self) -> str: async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any) -> ToolGenerator: """Invoke the wrapped agent via streaming and yield events. - Intermediate agent events are forwarded as ToolStreamEvents so the parent - agent's callback handler can display sub-agent progress. The final + Intermediate agent events are wrapped in AgentAsToolStreamEvent so the caller + can distinguish sub-agent progress from regular tool events. The final AgentResult is yielded as a ToolResultEvent. Args: @@ -115,45 +138,21 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw **kwargs: Additional keyword arguments. Yields: - ToolStreamEvent for intermediate events, then ToolResultEvent with the final response. + AgentAsToolStreamEvent for intermediate events, then ToolResultEvent with the final response. """ tool_input = tool_use["input"] if isinstance(tool_input, dict): prompt = tool_input.get("input", "") - preserve_context = tool_input.get("preserve_context", True) elif isinstance(tool_input, str): prompt = tool_input - preserve_context = True else: - logger.warning( - "tool_name=<%s> | unexpected input type: %s", - self._tool_name, - type(tool_input), - ) + logger.warning("tool_name=<%s> | unexpected input type: %s", self._tool_name, type(tool_input)) prompt = str(tool_input) - preserve_context = True tool_use_id = tool_use["toolUseId"] - if not preserve_context: - # AgentBase is a protocol and does not guarantee a messages attribute. - # We check for it at runtime to support Agent and other implementations - # that expose a mutable messages list. - messages = getattr(self._agent, "messages", None) - if isinstance(messages, list): - logger.debug( - "tool_name=<%s>, tool_use_id=<%s> | clearing agent conversation context", - self._tool_name, - tool_use_id, - ) - messages.clear() - else: - logger.warning( - "tool_name=<%s>, tool_use_id=<%s> | preserve_context=false requested" - " but agent does not expose a messages list", - self._tool_name, - tool_use_id, - ) + if not self._preserve_context: + self._reset_agent_state(tool_use_id) logger.debug("tool_name=<%s>, tool_use_id=<%s> | invoking agent", self._tool_name, tool_use_id) @@ -163,7 +162,7 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw if "result" in event: result = event["result"] else: - yield event + yield AgentAsToolStreamEvent(tool_use, event, self) if result is None: yield ToolResultEvent( @@ -207,6 +206,29 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw } ) + def _reset_agent_state(self, tool_use_id: str) -> None: + """Reset the wrapped agent to its initial state. + + Restores messages and state to the values captured at construction time. + This mirrors the pattern used by ``GraphNode.reset_executor_state()``. + + Args: + tool_use_id: Tool use ID for logging context. + """ + from .agent import Agent + + # isinstance narrows the type for mypy; __init__ guarantees this when preserve_context=False + if not isinstance(self._agent, Agent): + return + + logger.debug( + "tool_name=<%s>, tool_use_id=<%s> | resetting agent to initial state", + self._tool_name, + tool_use_id, + ) + self._agent.messages = copy.deepcopy(self._initial_messages) + self._agent.state = AgentState(self._initial_state.get()) + @override def get_display_properties(self) -> dict[str, str]: """Get properties for UI display.""" diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index ad2b7a14b..8d94de45b 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -62,7 +62,7 @@ from ..types.content import ContentBlock, Message, Messages, SystemContentBlock from ..types.exceptions import ConcurrencyException, ContextWindowOverflowException from ..types.traces import AttributeValue -from .agent_as_tool import AgentAsTool +from ._agent_as_tool import AgentAsTool from .agent_result import AgentResult from .base import AgentBase from .conversation_manager import ( @@ -617,6 +617,7 @@ def as_tool( self, name: str | None = None, description: str | None = None, + preserve_context: bool = True, ) -> AgentAsTool: r"""Convert this agent into a tool for use by another agent. @@ -624,6 +625,11 @@ def as_tool( name: Tool name. Must match the pattern ``[a-zA-Z0-9_\\-]{1,64}``. Defaults to the agent's name. description: Tool description. Defaults to the agent's description. + preserve_context: Whether to preserve the agent's conversation history across + invocations. When False, the agent's messages and state are reset to the + values they had at construction time before each call, ensuring every + invocation starts from the same baseline regardless of any external + interactions with the agent. Defaults to True. Returns: An AgentAsTool wrapping this agent. @@ -638,8 +644,8 @@ def as_tool( if not name: name = self.name if not description: - description = self.description or f"Use the {name} tool to invoke this agent as a tool" - return AgentAsTool(self, name=name, description=description) + description = self.description or f"Use the {name} agent as a tool by providing a natural language input" + return AgentAsTool(self, name=name, description=description, preserve_context=preserve_context) def cleanup(self) -> None: """Clean up resources used by the agent. diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index 5b0ae78f6..5603aedfb 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -21,6 +21,7 @@ if TYPE_CHECKING: from ..agent import AgentResult + from ..agent._agent_as_tool import AgentAsTool from ..multiagent.base import MultiAgentResult, NodeResult @@ -323,6 +324,31 @@ def tool_use_id(self) -> str: return cast(ToolUse, cast(dict, self.get("tool_stream_event")).get("tool_use"))["toolUseId"] +class AgentAsToolStreamEvent(ToolStreamEvent): + """Event emitted when an agent-as-tool yields intermediate events during execution. + + Extends ToolStreamEvent with a reference to the originating AgentAsTool so callers + can distinguish sub-agent stream events from regular tool stream events and access + the wrapped agent, tool name, description, etc. + """ + + def __init__(self, tool_use: ToolUse, tool_stream_data: Any, agent_as_tool: "AgentAsTool") -> None: + """Initialize with tool streaming data and agent-tool reference. + + Args: + tool_use: The tool invocation producing the stream. + tool_stream_data: The yielded event from the sub-agent execution. + agent_as_tool: The AgentAsTool instance that produced this event. + """ + super().__init__(tool_use, tool_stream_data) + self._agent_as_tool = agent_as_tool + + @property + def agent_as_tool(self) -> "AgentAsTool": + """The AgentAsTool instance that produced this event.""" + return self._agent_as_tool + + class ToolCancelEvent(TypedEvent): """Event emitted when a user cancels a tool call from their BeforeToolCallEvent hook.""" diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 6bb64f870..c089ba808 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -15,8 +15,7 @@ import strands from strands import Agent, Plugin, ToolContext -from strands.agent import AgentResult -from strands.agent.agent_as_tool import AgentAsTool +from strands.agent import AgentAsTool, AgentResult from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.agent.state import AgentState @@ -2748,4 +2747,4 @@ def test_as_tool_defaults_description_when_agent_has_none(): agent = Agent(name="researcher") tool = agent.as_tool() - assert tool.tool_spec["description"] == "Use the researcher tool to invoke this agent as a tool" + assert tool.tool_spec["description"] == "Use the researcher agent as a tool by providing a natural language input" diff --git a/tests/strands/agent/test_agent_as_tool.py b/tests/strands/agent/test_agent_as_tool.py index f142090a8..a6cc28d51 100644 --- a/tests/strands/agent/test_agent_as_tool.py +++ b/tests/strands/agent/test_agent_as_tool.py @@ -1,14 +1,13 @@ """Tests for AgentAsTool - the agent-as-tool adapter.""" -import logging from unittest.mock import MagicMock import pytest -from strands.agent.agent_as_tool import AgentAsTool +from strands.agent import AgentAsTool from strands.agent.agent_result import AgentResult from strands.telemetry.metrics import EventLoopMetrics -from strands.types._events import ToolResultEvent +from strands.types._events import AgentAsToolStreamEvent, ToolResultEvent, ToolStreamEvent async def _mock_stream_async(result, intermediate_events=None): @@ -53,50 +52,40 @@ def agent_result(): # --- init --- -def test_init_sets_name(mock_agent): - tool = AgentAsTool(mock_agent, name="my_tool", description="desc") +def test_init(mock_agent): + tool = AgentAsTool(mock_agent, name="my_tool", description="custom desc") assert tool.tool_name == "my_tool" + assert tool._description == "custom desc" + assert tool.agent is mock_agent -def test_init_sets_description(mock_agent): - tool = AgentAsTool(mock_agent, name="my_tool", description="custom desc") - assert tool._description == "custom desc" +def test_init_preserve_context_defaults_true(mock_agent): + tool = AgentAsTool(mock_agent, name="t", description="d") + assert tool._preserve_context is True -def test_init_stores_agent_reference(mock_agent, tool): - assert tool.agent is mock_agent +def test_init_preserve_context_false(fake_agent): + tool = AgentAsTool(fake_agent, name="t", description="d", preserve_context=False) + assert tool._preserve_context is False # --- properties --- -def test_tool_name(tool): +def test_tool_properties(tool): assert tool.tool_name == "test_agent" - - -def test_tool_type(tool): assert tool.tool_type == "agent" + spec = tool.tool_spec + assert spec["name"] == "test_agent" + assert spec["description"] == "A test agent" -def test_tool_spec_name(tool): - assert tool.tool_spec["name"] == "test_agent" - - -def test_tool_spec_description(tool): - assert tool.tool_spec["description"] == "A test agent" - - -def test_tool_spec_input_schema(tool): - schema = tool.tool_spec["inputSchema"]["json"] + schema = spec["inputSchema"]["json"] assert schema["type"] == "object" assert "input" in schema["properties"] assert schema["properties"]["input"]["type"] == "string" - assert "preserve_context" in schema["properties"] - assert schema["properties"]["preserve_context"]["type"] == "boolean" assert schema["required"] == ["input"] - -def test_display_properties(tool): props = tool.get_display_properties() assert props["Agent"] == "test_agent" assert props["Type"] == "agent" @@ -142,6 +131,21 @@ async def test_stream_empty_input(tool, mock_agent, agent_result): mock_agent.stream_async.assert_called_once_with("") +@pytest.mark.asyncio +async def test_stream_string_input(tool, mock_agent, agent_result): + tool_use = { + "toolUseId": "tool-123", + "name": "test_agent", + "input": "direct string", + } + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + async for _ in tool.stream(tool_use, {}): + pass + + mock_agent.stream_async.assert_called_once_with("direct string") + + @pytest.mark.asyncio async def test_stream_error(tool, mock_agent, tool_use): mock_agent.stream_async.side_effect = RuntimeError("boom") @@ -170,11 +174,32 @@ async def test_stream_forwards_intermediate_events(tool, mock_agent, tool_use, a events = [event async for event in tool.stream(tool_use, {})] - # Intermediate events are yielded as-is (raw dicts); wrapping in ToolStreamEvent happens in the caller - non_result_events = [e for e in events if not isinstance(e, ToolResultEvent)] - assert len(non_result_events) == 2 - assert non_result_events[0]["data"] == "partial" - assert non_result_events[1]["data"] == "more" + stream_events = [e for e in events if isinstance(e, AgentAsToolStreamEvent)] + assert len(stream_events) == 2 + assert stream_events[0]["tool_stream_event"]["data"]["data"] == "partial" + assert stream_events[1]["tool_stream_event"]["data"]["data"] == "more" + assert stream_events[0].agent_as_tool is tool + assert stream_events[0].tool_use_id == "tool-123" + + +@pytest.mark.asyncio +async def test_stream_events_not_double_wrapped_by_executor(tool, mock_agent, tool_use, agent_result): + """AgentAsToolStreamEvent is a ToolStreamEvent subclass, so the executor should pass it through directly.""" + intermediate = [{"data": "chunk"}] + mock_agent.stream_async.return_value = _mock_stream_async(agent_result, intermediate) + + events = [event async for event in tool.stream(tool_use, {})] + + stream_events = [e for e in events if isinstance(e, AgentAsToolStreamEvent)] + assert len(stream_events) == 1 + + event = stream_events[0] + # It's a ToolStreamEvent (so the executor yields it directly) + assert isinstance(event, ToolStreamEvent) + # But it's specifically an AgentAsToolStreamEvent (not re-wrapped) + assert type(event) is AgentAsToolStreamEvent + # And it references the originating AgentAsTool + assert event.agent_as_tool is tool @pytest.mark.asyncio @@ -216,72 +241,133 @@ class MyOutput(BaseModel): assert result_events[0]["tool_result"]["content"][0]["json"] == {"answer": "42"} +# --- preserve_context --- + + +@pytest.fixture +def fake_agent(): + """A real Agent instance for preserve_context tests.""" + from strands.agent.agent import Agent + + return Agent(name="fake_agent", callback_handler=None) + + @pytest.mark.asyncio -async def test_stream_string_input(tool, mock_agent, agent_result): - """When tool_use input is a plain string rather than a dict.""" +async def test_stream_resets_to_initial_state_when_preserve_context_false(fake_agent): + fake_agent.messages = [{"role": "user", "content": [{"text": "initial"}]}] + fake_agent.state.set("counter", 0) + + tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) + + # Mutate agent state as if a previous invocation happened + fake_agent.messages.append({"role": "assistant", "content": [{"text": "reply"}]}) + fake_agent.state.set("counter", 5) + + # Mock stream_async so we don't need a real model + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) + tool_use = { "toolUseId": "tool-123", - "name": "test_agent", - "input": "direct string", + "name": "fake_agent", + "input": {"input": "hello"}, } - mock_agent.stream_async.return_value = _mock_stream_async(agent_result) async for _ in tool.stream(tool_use, {}): pass - mock_agent.stream_async.assert_called_once_with("direct string") - + assert fake_agent.messages == [{"role": "user", "content": [{"text": "initial"}]}] + assert fake_agent.state.get("counter") == 0 -# --- preserve_context --- +@pytest.mark.asyncio +async def test_stream_resets_on_every_invocation(fake_agent): + """Each call should reset to the same initial snapshot, not to the previous call's state.""" + fake_agent.messages = [{"role": "user", "content": [{"text": "seed"}]}] + fake_agent.state.set("count", 1) + + tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) + + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) -class _FakeAgent: - """Minimal fake agent with a real messages list for preserve_context tests.""" - - def __init__(self): - self.name = "fake_agent" - self.messages: list = [] + tool_use = { + "toolUseId": "tool-1", + "name": "fake_agent", + "input": {"input": "first"}, + } - async def invoke_async(self, prompt=None, **kwargs): + async for _ in tool.stream(tool_use, {}): pass + fake_agent.messages.append({"role": "assistant", "content": [{"text": "added"}]}) + fake_agent.state.set("count", 99) - def __call__(self, prompt=None, **kwargs): + tool_use["toolUseId"] = "tool-2" + async for _ in tool.stream(tool_use, {}): pass - def stream_async(self, prompt=None, **kwargs): - return _mock_stream_async( - AgentResult( - stop_reason="end_turn", - message={"role": "assistant", "content": [{"text": "ok"}]}, - metrics=EventLoopMetrics(), - state={}, - ) - ) + assert fake_agent.messages == [{"role": "user", "content": [{"text": "seed"}]}] + assert fake_agent.state.get("count") == 1 @pytest.mark.asyncio -async def test_stream_clears_context_when_preserve_context_false(): - agent = _FakeAgent() - agent.messages = [{"role": "user", "content": [{"text": "old"}]}] - tool = AgentAsTool(agent, name="fake_agent", description="desc") +async def test_stream_initial_snapshot_is_deep_copy(fake_agent): + """Mutating the agent's messages after construction should not affect the snapshot.""" + fake_agent.messages = [{"role": "user", "content": [{"text": "original"}]}] + + tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) + + fake_agent.messages[0]["content"][0]["text"] = "mutated" + fake_agent.messages.append({"role": "assistant", "content": [{"text": "extra"}]}) + + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) tool_use = { "toolUseId": "tool-123", "name": "fake_agent", - "input": {"input": "hello", "preserve_context": False}, + "input": {"input": "hello"}, } async for _ in tool.stream(tool_use, {}): pass - assert agent.messages == [] + assert fake_agent.messages == [{"role": "user", "content": [{"text": "original"}]}] @pytest.mark.asyncio -async def test_stream_preserves_context_by_default(): - agent = _FakeAgent() - agent.messages = [{"role": "user", "content": [{"text": "old"}]}] - tool = AgentAsTool(agent, name="fake_agent", description="desc") +async def test_stream_resets_empty_initial_state_when_preserve_context_false(fake_agent): + tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) + + fake_agent.messages = [{"role": "user", "content": [{"text": "old"}]}] + fake_agent.state.set("key", "value") + + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) tool_use = { "toolUseId": "tool-123", @@ -292,33 +378,43 @@ async def test_stream_preserves_context_by_default(): async for _ in tool.stream(tool_use, {}): pass - assert len(agent.messages) >= 1 + assert fake_agent.messages == [] + assert fake_agent.state.get() == {} @pytest.mark.asyncio -async def test_stream_preserves_context_when_explicitly_true(): - agent = _FakeAgent() - agent.messages = [{"role": "user", "content": [{"text": "old"}]}] - tool = AgentAsTool(agent, name="fake_agent", description="desc") +async def test_stream_preserves_context_by_default(fake_agent): + fake_agent.messages = [{"role": "user", "content": [{"text": "old"}]}] + fake_agent.state.set("key", "value") + tool = AgentAsTool(fake_agent, name="fake_agent", description="desc") + + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) tool_use = { "toolUseId": "tool-123", "name": "fake_agent", - "input": {"input": "hello", "preserve_context": True}, + "input": {"input": "hello"}, } async for _ in tool.stream(tool_use, {}): pass - assert len(agent.messages) >= 1 + assert len(fake_agent.messages) >= 1 + assert fake_agent.state.get("key") == "value" -@pytest.mark.asyncio -async def test_stream_preserve_context_false_warns_when_no_messages_attr(caplog): - """Agent without a messages attribute should log a warning.""" +def test_preserve_context_false_requires_agent_instance(): + """preserve_context=False should raise TypeError for non-Agent instances.""" - class _NoMessagesAgent: - name = "bare_agent" + class _NotAnAgent: + name = "not_agent" async def invoke_async(self, prompt=None, **kwargs): pass @@ -327,26 +423,7 @@ def __call__(self, prompt=None, **kwargs): pass def stream_async(self, prompt=None, **kwargs): - return _mock_stream_async( - AgentResult( - stop_reason="end_turn", - message={"role": "assistant", "content": [{"text": "ok"}]}, - metrics=EventLoopMetrics(), - state={}, - ) - ) - - agent = _NoMessagesAgent() - tool = AgentAsTool(agent, name="bare_agent", description="desc") - - tool_use = { - "toolUseId": "tool-123", - "name": "bare_agent", - "input": {"input": "hello", "preserve_context": False}, - } - - with caplog.at_level(logging.WARNING, logger="strands.agent.agent_as_tool"): - async for _ in tool.stream(tool_use, {}): pass - assert "preserve_context=false requested" in caplog.text + with pytest.raises(TypeError, match="requires an Agent instance"): + AgentAsTool(_NotAnAgent(), name="bad", description="desc", preserve_context=False) diff --git a/tests/strands/types/test__events.py b/tests/strands/types/test__events.py index 6163faeb6..48465e1f6 100644 --- a/tests/strands/types/test__events.py +++ b/tests/strands/types/test__events.py @@ -6,6 +6,7 @@ from strands.telemetry import EventLoopMetrics from strands.types._events import ( + AgentAsToolStreamEvent, AgentResultEvent, CitationStreamEvent, EventLoopStopEvent, @@ -465,3 +466,39 @@ def test_event_inheritance(self): assert hasattr(event, "is_callback_event") assert hasattr(event, "as_dict") assert hasattr(event, "prepare") + + +class TestAgentAsToolStreamEvent: + """Tests for AgentAsToolStreamEvent.""" + + def test_initialization(self): + """Test AgentAsToolStreamEvent initialization with agent-tool reference.""" + tool_use: ToolUse = { + "toolUseId": "agent_tool_123", + "name": "researcher", + "input": {"input": "hello"}, + } + agent_event = {"data": "partial response"} + mock_agent_as_tool = MagicMock() + mock_agent_as_tool.tool_name = "researcher" + + event = AgentAsToolStreamEvent(tool_use, agent_event, mock_agent_as_tool) + + assert event["tool_stream_event"]["tool_use"] == tool_use + assert event["tool_stream_event"]["data"] == agent_event + assert event.agent_as_tool is mock_agent_as_tool + assert event.tool_use_id == "agent_tool_123" + + def test_is_tool_stream_event_subclass(self): + """Test that AgentAsToolStreamEvent is a ToolStreamEvent subclass.""" + tool_use: ToolUse = { + "toolUseId": "id_123", + "name": "tool", + "input": {}, + } + mock_agent_as_tool = MagicMock() + event = AgentAsToolStreamEvent(tool_use, {}, mock_agent_as_tool) + + assert isinstance(event, ToolStreamEvent) + assert isinstance(event, TypedEvent) + assert type(event) is AgentAsToolStreamEvent From 415a55d627a298eef583c4d95aaaa26f514a440c Mon Sep 17 00:00:00 2001 From: Owen Kaplan Date: Fri, 20 Mar 2026 14:36:35 -0400 Subject: [PATCH 4/8] fix: make preserve_context default to true --- src/strands/agent/_agent_as_tool.py | 4 +- src/strands/agent/agent.py | 4 +- tests/strands/agent/test_agent_as_tool.py | 72 +++++++++++++++++------ 3 files changed, 57 insertions(+), 23 deletions(-) diff --git a/src/strands/agent/_agent_as_tool.py b/src/strands/agent/_agent_as_tool.py index 534be5d9a..a54b67df9 100644 --- a/src/strands/agent/_agent_as_tool.py +++ b/src/strands/agent/_agent_as_tool.py @@ -52,7 +52,7 @@ def __init__( *, name: str, description: str, - preserve_context: bool = True, + preserve_context: bool = False, ) -> None: r"""Initialize the agent-as-tool adapter. @@ -64,7 +64,7 @@ def __init__( invocations. When False, the agent's messages and state are reset to the values they had at construction time before each call, ensuring every invocation starts from the same baseline regardless of any external - interactions with the agent. Defaults to True. Only effective when the + interactions with the agent. Defaults to False. Only effective when the wrapped agent exposes a mutable ``messages`` list and/or an ``AgentState`` (e.g. ``strands.agent.Agent``). """ diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 8d94de45b..f09399dbf 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -617,7 +617,7 @@ def as_tool( self, name: str | None = None, description: str | None = None, - preserve_context: bool = True, + preserve_context: bool = False, ) -> AgentAsTool: r"""Convert this agent into a tool for use by another agent. @@ -629,7 +629,7 @@ def as_tool( invocations. When False, the agent's messages and state are reset to the values they had at construction time before each call, ensuring every invocation starts from the same baseline regardless of any external - interactions with the agent. Defaults to True. + interactions with the agent. Defaults to False. Returns: An AgentAsTool wrapping this agent. diff --git a/tests/strands/agent/test_agent_as_tool.py b/tests/strands/agent/test_agent_as_tool.py index a6cc28d51..68128e6e5 100644 --- a/tests/strands/agent/test_agent_as_tool.py +++ b/tests/strands/agent/test_agent_as_tool.py @@ -25,9 +25,17 @@ def mock_agent(): return agent +@pytest.fixture +def fake_agent(): + """A real Agent instance for tests that need Agent-specific features.""" + from strands.agent.agent import Agent + + return Agent(name="fake_agent", callback_handler=None) + + @pytest.fixture def tool(mock_agent): - return AgentAsTool(mock_agent, name="test_agent", description="A test agent") + return AgentAsTool(mock_agent, name="test_agent", description="A test agent", preserve_context=True) @pytest.fixture @@ -53,20 +61,20 @@ def agent_result(): def test_init(mock_agent): - tool = AgentAsTool(mock_agent, name="my_tool", description="custom desc") + tool = AgentAsTool(mock_agent, name="my_tool", description="custom desc", preserve_context=True) assert tool.tool_name == "my_tool" assert tool._description == "custom desc" assert tool.agent is mock_agent -def test_init_preserve_context_defaults_true(mock_agent): - tool = AgentAsTool(mock_agent, name="t", description="d") - assert tool._preserve_context is True +def test_init_preserve_context_defaults_false(fake_agent): + tool = AgentAsTool(fake_agent, name="t", description="d") + assert tool._preserve_context is False -def test_init_preserve_context_false(fake_agent): - tool = AgentAsTool(fake_agent, name="t", description="d", preserve_context=False) - assert tool._preserve_context is False +def test_init_preserve_context_true(mock_agent): + tool = AgentAsTool(mock_agent, name="t", description="d", preserve_context=True) + assert tool._preserve_context is True # --- properties --- @@ -244,14 +252,6 @@ class MyOutput(BaseModel): # --- preserve_context --- -@pytest.fixture -def fake_agent(): - """A real Agent instance for preserve_context tests.""" - from strands.agent.agent import Agent - - return Agent(name="fake_agent", callback_handler=None) - - @pytest.mark.asyncio async def test_stream_resets_to_initial_state_when_preserve_context_false(fake_agent): fake_agent.messages = [{"role": "user", "content": [{"text": "initial"}]}] @@ -383,11 +383,45 @@ async def test_stream_resets_empty_initial_state_when_preserve_context_false(fak @pytest.mark.asyncio -async def test_stream_preserves_context_by_default(fake_agent): +async def test_stream_resets_context_by_default(fake_agent): + """Default preserve_context=False means each invocation starts fresh.""" fake_agent.messages = [{"role": "user", "content": [{"text": "old"}]}] fake_agent.state.set("key", "value") tool = AgentAsTool(fake_agent, name="fake_agent", description="desc") + # Mutate after construction + fake_agent.messages.append({"role": "assistant", "content": [{"text": "extra"}]}) + fake_agent.state.set("key", "changed") + + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) + + tool_use = { + "toolUseId": "tool-123", + "name": "fake_agent", + "input": {"input": "hello"}, + } + + async for _ in tool.stream(tool_use, {}): + pass + + # Should reset to construction-time snapshot + assert fake_agent.messages == [{"role": "user", "content": [{"text": "old"}]}] + assert fake_agent.state.get("key") == "value" + + +@pytest.mark.asyncio +async def test_stream_preserves_context_when_explicitly_true(fake_agent): + fake_agent.messages = [{"role": "user", "content": [{"text": "old"}]}] + fake_agent.state.set("key", "value") + tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=True) + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( AgentResult( stop_reason="end_turn", @@ -411,7 +445,7 @@ async def test_stream_preserves_context_by_default(fake_agent): def test_preserve_context_false_requires_agent_instance(): - """preserve_context=False should raise TypeError for non-Agent instances.""" + """Default preserve_context=False should raise TypeError for non-Agent instances.""" class _NotAnAgent: name = "not_agent" @@ -426,4 +460,4 @@ def stream_async(self, prompt=None, **kwargs): pass with pytest.raises(TypeError, match="requires an Agent instance"): - AgentAsTool(_NotAnAgent(), name="bad", description="desc", preserve_context=False) + AgentAsTool(_NotAnAgent(), name="bad", description="desc") From 4d79d322e1cd5eac7cf6f98ab7e2b823740ec138 Mon Sep 17 00:00:00 2001 From: Owen Kaplan Date: Tue, 24 Mar 2026 12:17:21 -0400 Subject: [PATCH 5/8] feat: propigate interrupts to parent agent in AgentAsTool --- src/strands/agent/_agent_as_tool.py | 54 +++++- src/strands/tools/executors/_executor.py | 6 + tests/strands/agent/test_agent_as_tool.py | 156 +++++++++++++++++- .../strands/tools/executors/test_executor.py | 51 ++++++ 4 files changed, 263 insertions(+), 4 deletions(-) diff --git a/src/strands/agent/_agent_as_tool.py b/src/strands/agent/_agent_as_tool.py index a54b67df9..8954be8a1 100644 --- a/src/strands/agent/_agent_as_tool.py +++ b/src/strands/agent/_agent_as_tool.py @@ -11,8 +11,9 @@ from typing_extensions import override from ..agent.state import AgentState -from ..types._events import AgentAsToolStreamEvent, ToolResultEvent +from ..types._events import AgentAsToolStreamEvent, ToolInterruptEvent, ToolResultEvent from ..types.content import Messages +from ..types.interrupt import InterruptResponseContent from ..types.tools import AgentTool, ToolGenerator, ToolSpec, ToolUse from .base import AgentBase @@ -132,13 +133,18 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw can distinguish sub-agent progress from regular tool events. The final AgentResult is yielded as a ToolResultEvent. + When the sub-agent encounters a hook interrupt (e.g. from BeforeToolCallEvent), + the interrupts are propagated to the parent agent via ToolInterruptEvent. On + resume, interrupt responses are forwarded to the sub-agent automatically. + Args: tool_use: The tool use request containing the input parameter. invocation_state: Context for the tool invocation. **kwargs: Additional keyword arguments. Yields: - AgentAsToolStreamEvent for intermediate events, then ToolResultEvent with the final response. + AgentAsToolStreamEvent for intermediate events, ToolInterruptEvent if the + sub-agent is interrupted, or ToolResultEvent with the final response. """ tool_input = tool_use["input"] if isinstance(tool_input, dict): @@ -151,7 +157,15 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw tool_use_id = tool_use["toolUseId"] - if not self._preserve_context: + # Determine if we are resuming the sub-agent from an interrupt. + if self._is_sub_agent_interrupted(): + prompt = self._build_interrupt_responses() + logger.debug( + "tool_name=<%s>, tool_use_id=<%s> | resuming sub-agent from interrupt", + self._tool_name, + tool_use_id, + ) + elif not self._preserve_context: self._reset_agent_state(tool_use_id) logger.debug("tool_name=<%s>, tool_use_id=<%s> | invoking agent", self._tool_name, tool_use_id) @@ -174,6 +188,11 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw ) return + # Propagate sub-agent interrupts to the parent agent. + if result.stop_reason == "interrupt" and result.interrupts: + yield ToolInterruptEvent(tool_use, list(result.interrupts)) + return + if result.structured_output: yield ToolResultEvent( { @@ -229,6 +248,35 @@ def _reset_agent_state(self, tool_use_id: str) -> None: self._agent.messages = copy.deepcopy(self._initial_messages) self._agent.state = AgentState(self._initial_state.get()) + def _is_sub_agent_interrupted(self) -> bool: + """Check whether the wrapped agent is in an activated interrupt state.""" + from .agent import Agent + + if not isinstance(self._agent, Agent): + return False + return self._agent._interrupt_state.activated + + def _build_interrupt_responses(self) -> list[InterruptResponseContent]: + """Build interrupt response payloads from the sub-agent's interrupt state. + + The parent agent's ``_interrupt_state.resume()`` sets ``.response`` on the shared + ``Interrupt`` objects (registered by the executor), so we re-package them in the + format expected by ``Agent.stream_async``. + + Returns: + List of interrupt response content blocks for resuming the sub-agent. + """ + from .agent import Agent + + if not isinstance(self._agent, Agent): + return [] + + return [ + {"interruptResponse": {"interruptId": interrupt.id, "response": interrupt.response}} + for interrupt in self._agent._interrupt_state.interrupts.values() + if interrupt.response is not None + ] + @override def get_display_properties(self) -> dict[str, str]: """Get properties for UI display.""" diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index 0da6b5715..e8f21ca7c 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -226,6 +226,12 @@ async def _stream( # ToolStreamEvent and the last event is just the result. if isinstance(event, ToolInterruptEvent): + # Register any interrupts not already in the agent's state. + # For normal hooks this is a no-op (already registered by _Interruptible.interrupt()). + # For sub-agent interrupts propagated via AgentAsTool, this is where they get + # registered so that _interrupt_state.resume() can locate them by ID. + for interrupt in event.interrupts: + agent._interrupt_state.interrupts.setdefault(interrupt.id, interrupt) yield event return diff --git a/tests/strands/agent/test_agent_as_tool.py b/tests/strands/agent/test_agent_as_tool.py index 68128e6e5..abae6dffd 100644 --- a/tests/strands/agent/test_agent_as_tool.py +++ b/tests/strands/agent/test_agent_as_tool.py @@ -6,8 +6,9 @@ from strands.agent import AgentAsTool from strands.agent.agent_result import AgentResult +from strands.interrupt import Interrupt from strands.telemetry.metrics import EventLoopMetrics -from strands.types._events import AgentAsToolStreamEvent, ToolResultEvent, ToolStreamEvent +from strands.types._events import AgentAsToolStreamEvent, ToolInterruptEvent, ToolResultEvent, ToolStreamEvent async def _mock_stream_async(result, intermediate_events=None): @@ -461,3 +462,156 @@ def stream_async(self, prompt=None, **kwargs): with pytest.raises(TypeError, match="requires an Agent instance"): AgentAsTool(_NotAnAgent(), name="bad", description="desc") + + +# --- interrupt propagation --- + + +@pytest.fixture +def interrupt_result(): + interrupt = Interrupt(id="interrupt-1", name="approval", reason="need approval") + return AgentResult( + stop_reason="interrupt", + message={"role": "assistant", "content": [{"text": "pending"}]}, + metrics=EventLoopMetrics(), + state={}, + interrupts=[interrupt], + ) + + +@pytest.mark.asyncio +async def test_stream_interrupt_yields_tool_interrupt_event(tool, mock_agent, tool_use, interrupt_result): + """When the sub-agent returns an interrupt result, AgentAsTool should yield ToolInterruptEvent.""" + mock_agent.stream_async.return_value = _mock_stream_async(interrupt_result) + + events = [event async for event in tool.stream(tool_use, {})] + + assert len(events) == 1 + assert isinstance(events[0], ToolInterruptEvent) + assert events[0].interrupts == interrupt_result.interrupts + assert events[0].tool_use_id == "tool-123" + + +@pytest.mark.asyncio +async def test_stream_interrupt_no_tool_result_appended(tool, mock_agent, tool_use, interrupt_result): + """ToolInterruptEvent should not produce a ToolResultEvent.""" + mock_agent.stream_async.return_value = _mock_stream_async(interrupt_result) + + events = [event async for event in tool.stream(tool_use, {})] + + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + assert result_events == [] + + +@pytest.mark.asyncio +async def test_stream_interrupt_forwards_intermediate_events(tool, mock_agent, tool_use, interrupt_result): + """Intermediate events should still be yielded before the interrupt.""" + intermediate = [{"data": "partial"}] + mock_agent.stream_async.return_value = _mock_stream_async(interrupt_result, intermediate) + + events = [event async for event in tool.stream(tool_use, {})] + + stream_events = [e for e in events if isinstance(e, AgentAsToolStreamEvent)] + interrupt_events = [e for e in events if isinstance(e, ToolInterruptEvent)] + assert len(stream_events) == 1 + assert len(interrupt_events) == 1 + + +@pytest.mark.asyncio +async def test_stream_interrupt_resume_forwards_responses(fake_agent): + """On resume, AgentAsTool should forward interrupt responses to the sub-agent.""" + interrupt = Interrupt(id="interrupt-1", name="approval", reason="need approval", response="APPROVE") + + # Put the sub-agent in an activated interrupt state with the response already set + fake_agent._interrupt_state.interrupts["interrupt-1"] = interrupt + fake_agent._interrupt_state.activate() + + normal_result = AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "approved"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + fake_agent.stream_async = MagicMock(return_value=_mock_stream_async(normal_result)) + + tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=True) + tool_use = {"toolUseId": "tool-123", "name": "fake_agent", "input": {"input": "do something"}} + + events = [event async for event in tool.stream(tool_use, {})] + + # Should have called stream_async with interrupt responses, not the original prompt + call_args = fake_agent.stream_async.call_args + agent_input = call_args[0][0] + assert isinstance(agent_input, list) + assert len(agent_input) == 1 + assert agent_input[0]["interruptResponse"]["interruptId"] == "interrupt-1" + assert agent_input[0]["interruptResponse"]["response"] == "APPROVE" + + # Should produce a normal result + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + assert len(result_events) == 1 + assert result_events[0]["tool_result"]["status"] == "success" + + +@pytest.mark.asyncio +async def test_stream_interrupt_resume_skips_state_reset(fake_agent): + """When resuming from interrupt with preserve_context=False, state reset should be skipped.""" + fake_agent.messages = [{"role": "user", "content": [{"text": "initial"}]}] + fake_agent.state.set("key", "value") + + tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) + + # Simulate the sub-agent being in interrupt state after a previous invocation + interrupt = Interrupt(id="interrupt-1", name="approval", reason="need approval", response="APPROVE") + fake_agent._interrupt_state.interrupts["interrupt-1"] = interrupt + fake_agent._interrupt_state.activate() + + # Mutate messages to simulate sub-agent progress before interrupt + fake_agent.messages.append({"role": "assistant", "content": [{"text": "working on it"}]}) + + normal_result = AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "done"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + fake_agent.stream_async = MagicMock(return_value=_mock_stream_async(normal_result)) + + tool_use = {"toolUseId": "tool-123", "name": "fake_agent", "input": {"input": "do something"}} + async for _ in tool.stream(tool_use, {}): + pass + + # Messages should NOT have been reset — the sub-agent needs its conversation history intact + assert len(fake_agent.messages) == 2 + + +@pytest.mark.asyncio +async def test_is_sub_agent_interrupted_false_for_mock(tool): + """_is_sub_agent_interrupted returns False for non-Agent instances.""" + assert tool._is_sub_agent_interrupted() is False + + +@pytest.mark.asyncio +async def test_is_sub_agent_interrupted_true_when_activated(fake_agent): + """_is_sub_agent_interrupted returns True when the sub-agent's interrupt state is activated.""" + tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=True) + assert tool._is_sub_agent_interrupted() is False + + fake_agent._interrupt_state.activate() + assert tool._is_sub_agent_interrupted() is True + + +@pytest.mark.asyncio +async def test_build_interrupt_responses(fake_agent): + """_build_interrupt_responses packages sub-agent interrupts into response content blocks.""" + tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=True) + + interrupt_a = Interrupt(id="id-a", name="a", reason="r", response="yes") + interrupt_b = Interrupt(id="id-b", name="b", reason="r", response=None) + fake_agent._interrupt_state.interrupts = {"id-a": interrupt_a, "id-b": interrupt_b} + + responses = tool._build_interrupt_responses() + + # Only interrupt_a has a response + assert len(responses) == 1 + assert responses[0] == {"interruptResponse": {"interruptId": "id-a", "response": "yes"}} diff --git a/tests/strands/tools/executors/test_executor.py b/tests/strands/tools/executors/test_executor.py index 4a5479503..c7ad10232 100644 --- a/tests/strands/tools/executors/test_executor.py +++ b/tests/strands/tools/executors/test_executor.py @@ -464,6 +464,57 @@ async def test_executor_stream_tool_interrupt_resume(executor, agent, tool_resul assert tru_results == exp_results +@pytest.mark.asyncio +async def test_executor_stream_tool_interrupt_registers_on_agent( + executor, agent, tool_results, invocation_state, alist +): + """ToolInterruptEvent from a tool should register interrupts in the agent's _interrupt_state.""" + # Create a tool that yields a ToolInterruptEvent with an interrupt NOT pre-registered on the agent + # (simulates AgentAsTool propagating sub-agent interrupts). + foreign_interrupt = Interrupt(id="sub-agent-interrupt-1", name="approval", reason="need approval") + + @strands.tool(name="agent_tool") + def agent_tool_func(): + return "unused" + + async def mock_stream(_tool_use, _invocation_state, **_kwargs): + yield ToolInterruptEvent(_tool_use, [foreign_interrupt]) + + agent_tool_func.stream = mock_stream + agent.tool_registry.register_tool(agent_tool_func) + + tool_use: ToolUse = {"name": "agent_tool", "toolUseId": "test_tool_id", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + events = await alist(stream) + + # Should yield the interrupt event + assert len(events) == 1 + assert isinstance(events[0], ToolInterruptEvent) + + # The interrupt should now be registered on the agent's _interrupt_state + assert "sub-agent-interrupt-1" in agent._interrupt_state.interrupts + assert agent._interrupt_state.interrupts["sub-agent-interrupt-1"] is foreign_interrupt + + +@pytest.mark.asyncio +async def test_executor_stream_tool_interrupt_does_not_overwrite_existing( + executor, agent, tool_results, invocation_state, alist +): + """setdefault should not overwrite interrupts already in the agent's state (normal hook case).""" + tool_use = {"name": "interrupt_tool", "toolUseId": "test_tool_id", "input": {}} + + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + await alist(stream) + + # The interrupt_tool hook registered the interrupt via _Interruptible.interrupt(). + # The executor's setdefault should have been a no-op for this pre-registered interrupt. + registered = agent._interrupt_state.interrupts + assert len(registered) == 1 + interrupt = next(iter(registered.values())) + assert interrupt.name == "test_name" + assert interrupt.reason == "test reason" + + @pytest.mark.asyncio async def test_executor_stream_updates_invocation_state_with_agent( executor, agent, tool_results, invocation_state, weather_tool, alist From 7cf436c9030a15c94477cb76fac84b97f99f3e32 Mon Sep 17 00:00:00 2001 From: Owen Kaplan Date: Wed, 25 Mar 2026 12:44:56 -0400 Subject: [PATCH 6/8] feat: add threading.lock on async stream; expand interrupts to other AgentBase instances --- src/strands/agent/_agent_as_tool.py | 57 ++++++++++++++++------- tests/strands/agent/test_agent_as_tool.py | 55 ++++++++++++++++++++++ 2 files changed, 94 insertions(+), 18 deletions(-) diff --git a/src/strands/agent/_agent_as_tool.py b/src/strands/agent/_agent_as_tool.py index 8954be8a1..e0882c5af 100644 --- a/src/strands/agent/_agent_as_tool.py +++ b/src/strands/agent/_agent_as_tool.py @@ -6,6 +6,7 @@ import copy import logging +import threading from typing import Any from typing_extensions import override @@ -81,6 +82,10 @@ def __init__( # messages/state attributes. self._initial_messages: Messages = [] self._initial_state: AgentState = AgentState() + # Serialize access so _reset_agent_state + stream_async are atomic. + # threading.Lock (not asyncio.Lock) because run_async() may create + # separate event loops in different threads. + self._lock = threading.Lock() if not preserve_context: from .agent import Agent @@ -157,20 +162,38 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw tool_use_id = tool_use["toolUseId"] - # Determine if we are resuming the sub-agent from an interrupt. - if self._is_sub_agent_interrupted(): - prompt = self._build_interrupt_responses() - logger.debug( - "tool_name=<%s>, tool_use_id=<%s> | resuming sub-agent from interrupt", + # Serialize access to the underlying agent. _reset_agent_state() mutates + # the agent before stream_async acquires its own lock, so a concurrent + # call would corrupt an in-flight invocation. + if not self._lock.acquire(blocking=False): + logger.warning( + "tool_name=<%s>, tool_use_id=<%s> | agent is already processing a request", self._tool_name, tool_use_id, ) - elif not self._preserve_context: - self._reset_agent_state(tool_use_id) - - logger.debug("tool_name=<%s>, tool_use_id=<%s> | invoking agent", self._tool_name, tool_use_id) + yield ToolResultEvent( + { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": f"Agent '{self._tool_name}' is already processing a request"}], + } + ) + return try: + # Determine if we are resuming the sub-agent from an interrupt. + if self._is_sub_agent_interrupted(): + prompt = self._build_interrupt_responses() + logger.debug( + "tool_name=<%s>, tool_use_id=<%s> | resuming sub-agent from interrupt", + self._tool_name, + tool_use_id, + ) + elif not self._preserve_context: + self._reset_agent_state(tool_use_id) + + logger.debug("tool_name=<%s>, tool_use_id=<%s> | invoking agent", self._tool_name, tool_use_id) + result = None async for event in self._agent.stream_async(prompt): if "result" in event: @@ -224,6 +247,8 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw "content": [{"text": f"Agent error: {e}"}], } ) + finally: + self._lock.release() def _reset_agent_state(self, tool_use_id: str) -> None: """Reset the wrapped agent to its initial state. @@ -250,11 +275,8 @@ def _reset_agent_state(self, tool_use_id: str) -> None: def _is_sub_agent_interrupted(self) -> bool: """Check whether the wrapped agent is in an activated interrupt state.""" - from .agent import Agent - - if not isinstance(self._agent, Agent): - return False - return self._agent._interrupt_state.activated + interrupt_state = getattr(self._agent, "_interrupt_state", None) + return interrupt_state is not None and interrupt_state.activated def _build_interrupt_responses(self) -> list[InterruptResponseContent]: """Build interrupt response payloads from the sub-agent's interrupt state. @@ -266,14 +288,13 @@ def _build_interrupt_responses(self) -> list[InterruptResponseContent]: Returns: List of interrupt response content blocks for resuming the sub-agent. """ - from .agent import Agent - - if not isinstance(self._agent, Agent): + interrupt_state = getattr(self._agent, "_interrupt_state", None) + if interrupt_state is None: return [] return [ {"interruptResponse": {"interruptId": interrupt.id, "response": interrupt.response}} - for interrupt in self._agent._interrupt_state.interrupts.values() + for interrupt in interrupt_state.interrupts.values() if interrupt.response is not None ] diff --git a/tests/strands/agent/test_agent_as_tool.py b/tests/strands/agent/test_agent_as_tool.py index abae6dffd..50b9084ae 100644 --- a/tests/strands/agent/test_agent_as_tool.py +++ b/tests/strands/agent/test_agent_as_tool.py @@ -23,6 +23,9 @@ def mock_agent(): agent = MagicMock() agent.name = "test_agent" agent.description = "A test agent" + # Prevent MagicMock from auto-creating _interrupt_state on access, + # so getattr checks in AgentAsTool correctly detect its absence. + agent._interrupt_state = None return agent @@ -615,3 +618,55 @@ async def test_build_interrupt_responses(fake_agent): # Only interrupt_a has a response assert len(responses) == 1 assert responses[0] == {"interruptResponse": {"interruptId": "id-a", "response": "yes"}} + + +# --- concurrency --- + + +@pytest.mark.asyncio +async def test_stream_rejects_concurrent_call(tool, mock_agent, tool_use, agent_result): + """A second concurrent call should get an error ToolResultEvent.""" + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + # Simulate the lock already being held by another invocation + tool._lock.acquire() + try: + events = [event async for event in tool.stream(tool_use, {})] + + assert len(events) == 1 + assert isinstance(events[0], ToolResultEvent) + assert events[0]["tool_result"]["status"] == "error" + assert "already processing" in events[0]["tool_result"]["content"][0]["text"] + mock_agent.stream_async.assert_not_called() + finally: + tool._lock.release() + + +@pytest.mark.asyncio +async def test_stream_releases_lock_after_completion(tool, mock_agent, tool_use, agent_result): + """Lock should be released after stream completes, allowing subsequent calls.""" + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + async for _ in tool.stream(tool_use, {}): + pass + + assert not tool._lock.locked() + + # A second call should succeed + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + events = [event async for event in tool.stream(tool_use, {})] + + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + assert len(result_events) == 1 + assert result_events[0]["tool_result"]["status"] == "success" + + +@pytest.mark.asyncio +async def test_stream_releases_lock_after_error(tool, mock_agent, tool_use): + """Lock should be released even when the agent raises an exception.""" + mock_agent.stream_async.side_effect = RuntimeError("boom") + + async for _ in tool.stream(tool_use, {}): + pass + + assert not tool._lock.locked() From b9b6048560d45e85e8af4cd2a452728a1af36d90 Mon Sep 17 00:00:00 2001 From: Owen Kaplan Date: Wed, 25 Mar 2026 16:04:45 -0400 Subject: [PATCH 7/8] feat: mark _AgentAsTool as private; take Agent instead of AgentBase in the constructor --- src/strands/__init__.py | 4 +- src/strands/agent/__init__.py | 4 +- src/strands/agent/_agent_as_tool.py | 66 ++++++-------- src/strands/agent/agent.py | 13 ++- src/strands/tools/executors/_executor.py | 2 +- src/strands/types/_events.py | 12 +-- tests/strands/agent/test_agent.py | 6 +- tests/strands/agent/test_agent_as_tool.py | 86 ++++++++++--------- .../strands/tools/executors/test_executor.py | 2 +- 9 files changed, 94 insertions(+), 101 deletions(-) diff --git a/src/strands/__init__.py b/src/strands/__init__.py index dc3a0c7ff..3db6f4758 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -1,7 +1,7 @@ """A framework for building, deploying, and managing AI agents.""" from . import agent, models, telemetry, types -from .agent._agent_as_tool import AgentAsTool +from .agent._agent_as_tool import _AgentAsTool from .agent.agent import Agent from .agent.base import AgentBase from .event_loop._retry import ModelRetryStrategy @@ -12,7 +12,7 @@ __all__ = [ "Agent", - "AgentAsTool", + "_AgentAsTool", "AgentBase", "AgentSkills", "agent", diff --git a/src/strands/agent/__init__.py b/src/strands/agent/__init__.py index d0254852d..66dc56168 100644 --- a/src/strands/agent/__init__.py +++ b/src/strands/agent/__init__.py @@ -10,7 +10,7 @@ from typing import Any from ..event_loop._retry import ModelRetryStrategy -from ._agent_as_tool import AgentAsTool +from ._agent_as_tool import _AgentAsTool from .agent import Agent from .agent_result import AgentResult from .base import AgentBase @@ -25,7 +25,7 @@ "Agent", "AgentBase", "AgentResult", - "AgentAsTool", + "_AgentAsTool", "ConversationManager", "NullConversationManager", "SlidingWindowConversationManager", diff --git a/src/strands/agent/_agent_as_tool.py b/src/strands/agent/_agent_as_tool.py index e0882c5af..11b536789 100644 --- a/src/strands/agent/_agent_as_tool.py +++ b/src/strands/agent/_agent_as_tool.py @@ -1,13 +1,15 @@ """Agent-as-tool adapter. -This module provides the AgentAsTool class that wraps an Agent (or any AgentBase) as a tool +This module provides the _AgentAsTool class that wraps an Agent as a tool so it can be passed to another agent's tool list. """ +from __future__ import annotations + import copy import logging import threading -from typing import Any +from typing import TYPE_CHECKING, Any from typing_extensions import override @@ -16,12 +18,14 @@ from ..types.content import Messages from ..types.interrupt import InterruptResponseContent from ..types.tools import AgentTool, ToolGenerator, ToolSpec, ToolUse -from .base import AgentBase + +if TYPE_CHECKING: + from .agent import Agent logger = logging.getLogger(__name__) -class AgentAsTool(AgentTool): +class _AgentAsTool(AgentTool): """Adapter that exposes an Agent as a tool for use by other agents. The tool accepts a single ``input`` string parameter, invokes the wrapped @@ -30,18 +34,14 @@ class AgentAsTool(AgentTool): Example: ```python from strands import Agent - from strands.agent import AgentAsTool researcher = Agent(name="researcher", description="Finds information") - # Use directly - tool = AgentAsTool(researcher, name="researcher", description="Finds information") - - # Or via convenience method + # Use via convenience method (default: fresh conversation each call) tool = researcher.as_tool() - # Start each invocation with a fresh conversation - tool = researcher.as_tool(preserve_context=False) + # Preserve context across invocations + tool = researcher.as_tool(preserve_context=True) writer = Agent(name="writer", tools=[tool]) writer("Write about AI agents") @@ -50,10 +50,10 @@ class AgentAsTool(AgentTool): def __init__( self, - agent: AgentBase, + agent: Agent, *, name: str, - description: str, + description: str | None = None, preserve_context: bool = False, ) -> None: r"""Initialize the agent-as-tool adapter. @@ -61,25 +61,24 @@ def __init__( Args: agent: The agent to wrap as a tool. name: Tool name. Must match the pattern ``[a-zA-Z0-9_\\-]{1,64}``. - description: Tool description. + description: Tool description. Defaults to the agent's description, or a + generic description if the agent has no description set. preserve_context: Whether to preserve the agent's conversation history across invocations. When False, the agent's messages and state are reset to the values they had at construction time before each call, ensuring every invocation starts from the same baseline regardless of any external - interactions with the agent. Defaults to False. Only effective when the - wrapped agent exposes a mutable ``messages`` list and/or an ``AgentState`` - (e.g. ``strands.agent.Agent``). + interactions with the agent. Defaults to False. """ super().__init__() self._agent = agent self._tool_name = name - self._description = description + self._description = ( + description or agent.description or f"Use the {name} agent as a tool by providing a natural language input" + ) self._preserve_context = preserve_context # When preserve_context=False, we snapshot the agent's initial state so we can # restore it before each invocation. This mirrors GraphNode.reset_executor_state(). - # We require an Agent instance for this since AgentBase doesn't guarantee - # messages/state attributes. self._initial_messages: Messages = [] self._initial_state: AgentState = AgentState() # Serialize access so _reset_agent_state + stream_async are atomic. @@ -88,15 +87,17 @@ def __init__( self._lock = threading.Lock() if not preserve_context: - from .agent import Agent - - if not isinstance(agent, Agent): - raise TypeError(f"preserve_context=False requires an Agent instance, got {type(agent).__name__}") + if getattr(agent, "_session_manager", None) is not None: + raise ValueError( + "preserve_context=False cannot be used with an agent that has a session manager. " + "The session manager persists conversation history externally, which conflicts with " + "resetting the agent's state between invocations." + ) self._initial_messages = copy.deepcopy(agent.messages) self._initial_state = AgentState(agent.state.get()) @property - def agent(self) -> AgentBase: + def agent(self) -> Agent: """The wrapped agent instance.""" return self._agent @@ -259,12 +260,6 @@ def _reset_agent_state(self, tool_use_id: str) -> None: Args: tool_use_id: Tool use ID for logging context. """ - from .agent import Agent - - # isinstance narrows the type for mypy; __init__ guarantees this when preserve_context=False - if not isinstance(self._agent, Agent): - return - logger.debug( "tool_name=<%s>, tool_use_id=<%s> | resetting agent to initial state", self._tool_name, @@ -275,8 +270,7 @@ def _reset_agent_state(self, tool_use_id: str) -> None: def _is_sub_agent_interrupted(self) -> bool: """Check whether the wrapped agent is in an activated interrupt state.""" - interrupt_state = getattr(self._agent, "_interrupt_state", None) - return interrupt_state is not None and interrupt_state.activated + return self._agent._interrupt_state.activated def _build_interrupt_responses(self) -> list[InterruptResponseContent]: """Build interrupt response payloads from the sub-agent's interrupt state. @@ -288,13 +282,9 @@ def _build_interrupt_responses(self) -> list[InterruptResponseContent]: Returns: List of interrupt response content blocks for resuming the sub-agent. """ - interrupt_state = getattr(self._agent, "_interrupt_state", None) - if interrupt_state is None: - return [] - return [ {"interruptResponse": {"interruptId": interrupt.id, "response": interrupt.response}} - for interrupt in interrupt_state.interrupts.values() + for interrupt in self._agent._interrupt_state.interrupts.values() if interrupt.response is not None ] diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index f09399dbf..c58debb35 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -62,7 +62,7 @@ from ..types.content import ContentBlock, Message, Messages, SystemContentBlock from ..types.exceptions import ConcurrencyException, ContextWindowOverflowException from ..types.traces import AttributeValue -from ._agent_as_tool import AgentAsTool +from ._agent_as_tool import _AgentAsTool from .agent_result import AgentResult from .base import AgentBase from .conversation_manager import ( @@ -618,13 +618,14 @@ def as_tool( name: str | None = None, description: str | None = None, preserve_context: bool = False, - ) -> AgentAsTool: + ) -> _AgentAsTool: r"""Convert this agent into a tool for use by another agent. Args: name: Tool name. Must match the pattern ``[a-zA-Z0-9_\\-]{1,64}``. Defaults to the agent's name. - description: Tool description. Defaults to the agent's description. + description: Tool description. Defaults to the agent's description, or a + generic description if the agent has no description set. preserve_context: Whether to preserve the agent's conversation history across invocations. When False, the agent's messages and state are reset to the values they had at construction time before each call, ensuring every @@ -632,7 +633,7 @@ def as_tool( interactions with the agent. Defaults to False. Returns: - An AgentAsTool wrapping this agent. + An _AgentAsTool wrapping this agent. Example: ```python @@ -643,9 +644,7 @@ def as_tool( """ if not name: name = self.name - if not description: - description = self.description or f"Use the {name} agent as a tool by providing a natural language input" - return AgentAsTool(self, name=name, description=description, preserve_context=preserve_context) + return _AgentAsTool(self, name=name, description=description, preserve_context=preserve_context) def cleanup(self) -> None: """Clean up resources used by the agent. diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index e8f21ca7c..5825b3cdb 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -228,7 +228,7 @@ async def _stream( if isinstance(event, ToolInterruptEvent): # Register any interrupts not already in the agent's state. # For normal hooks this is a no-op (already registered by _Interruptible.interrupt()). - # For sub-agent interrupts propagated via AgentAsTool, this is where they get + # For sub-agent interrupts propagated via _AgentAsTool, this is where they get # registered so that _interrupt_state.resume() can locate them by ID. for interrupt in event.interrupts: agent._interrupt_state.interrupts.setdefault(interrupt.id, interrupt) diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index 5603aedfb..1d5a5de79 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -21,7 +21,7 @@ if TYPE_CHECKING: from ..agent import AgentResult - from ..agent._agent_as_tool import AgentAsTool + from ..agent._agent_as_tool import _AgentAsTool from ..multiagent.base import MultiAgentResult, NodeResult @@ -327,25 +327,25 @@ def tool_use_id(self) -> str: class AgentAsToolStreamEvent(ToolStreamEvent): """Event emitted when an agent-as-tool yields intermediate events during execution. - Extends ToolStreamEvent with a reference to the originating AgentAsTool so callers + Extends ToolStreamEvent with a reference to the originating _AgentAsTool so callers can distinguish sub-agent stream events from regular tool stream events and access the wrapped agent, tool name, description, etc. """ - def __init__(self, tool_use: ToolUse, tool_stream_data: Any, agent_as_tool: "AgentAsTool") -> None: + def __init__(self, tool_use: ToolUse, tool_stream_data: Any, agent_as_tool: "_AgentAsTool") -> None: """Initialize with tool streaming data and agent-tool reference. Args: tool_use: The tool invocation producing the stream. tool_stream_data: The yielded event from the sub-agent execution. - agent_as_tool: The AgentAsTool instance that produced this event. + agent_as_tool: The _AgentAsTool instance that produced this event. """ super().__init__(tool_use, tool_stream_data) self._agent_as_tool = agent_as_tool @property - def agent_as_tool(self) -> "AgentAsTool": - """The AgentAsTool instance that produced this event.""" + def agent_as_tool(self) -> "_AgentAsTool": + """The _AgentAsTool instance that produced this event.""" return self._agent_as_tool diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index c089ba808..cdfb609aa 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -15,7 +15,7 @@ import strands from strands import Agent, Plugin, ToolContext -from strands.agent import AgentAsTool, AgentResult +from strands.agent import AgentResult, _AgentAsTool from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.agent.state import AgentState @@ -2702,11 +2702,11 @@ def hook_callback(event: BeforeModelCallEvent): def test_as_tool_returns_agent_tool(): - """Test that as_tool returns an AgentAsTool wrapping the agent.""" + """Test that as_tool returns an _AgentAsTool wrapping the agent.""" agent = Agent(name="researcher", description="Finds information") tool = agent.as_tool() - assert isinstance(tool, AgentAsTool) + assert isinstance(tool, _AgentAsTool) assert tool.agent is agent diff --git a/tests/strands/agent/test_agent_as_tool.py b/tests/strands/agent/test_agent_as_tool.py index 50b9084ae..0dc5af3d4 100644 --- a/tests/strands/agent/test_agent_as_tool.py +++ b/tests/strands/agent/test_agent_as_tool.py @@ -1,12 +1,12 @@ -"""Tests for AgentAsTool - the agent-as-tool adapter.""" +"""Tests for _AgentAsTool - the agent-as-tool adapter.""" from unittest.mock import MagicMock import pytest -from strands.agent import AgentAsTool +from strands.agent import _AgentAsTool from strands.agent.agent_result import AgentResult -from strands.interrupt import Interrupt +from strands.interrupt import Interrupt, _InterruptState from strands.telemetry.metrics import EventLoopMetrics from strands.types._events import AgentAsToolStreamEvent, ToolInterruptEvent, ToolResultEvent, ToolStreamEvent @@ -23,9 +23,7 @@ def mock_agent(): agent = MagicMock() agent.name = "test_agent" agent.description = "A test agent" - # Prevent MagicMock from auto-creating _interrupt_state on access, - # so getattr checks in AgentAsTool correctly detect its absence. - agent._interrupt_state = None + agent._interrupt_state = _InterruptState() return agent @@ -39,7 +37,7 @@ def fake_agent(): @pytest.fixture def tool(mock_agent): - return AgentAsTool(mock_agent, name="test_agent", description="A test agent", preserve_context=True) + return _AgentAsTool(mock_agent, name="test_agent", description="A test agent", preserve_context=True) @pytest.fixture @@ -65,19 +63,36 @@ def agent_result(): def test_init(mock_agent): - tool = AgentAsTool(mock_agent, name="my_tool", description="custom desc", preserve_context=True) + tool = _AgentAsTool(mock_agent, name="my_tool", description="custom desc", preserve_context=True) assert tool.tool_name == "my_tool" assert tool._description == "custom desc" assert tool.agent is mock_agent +def test_init_description_defaults_to_agent_description(fake_agent): + fake_agent.description = "Agent that researches topics" + tool = _AgentAsTool(fake_agent, name="researcher", preserve_context=True) + assert tool._description == "Agent that researches topics" + + +def test_init_description_defaults_to_generic_when_agent_has_none(fake_agent): + tool = _AgentAsTool(fake_agent, name="researcher", preserve_context=True) + assert tool._description == "Use the researcher agent as a tool by providing a natural language input" + + +def test_init_description_explicit_overrides_agent_description(fake_agent): + fake_agent.description = "Agent that researches topics" + tool = _AgentAsTool(fake_agent, name="researcher", description="custom", preserve_context=True) + assert tool._description == "custom" + + def test_init_preserve_context_defaults_false(fake_agent): - tool = AgentAsTool(fake_agent, name="t", description="d") + tool = _AgentAsTool(fake_agent, name="t", description="d") assert tool._preserve_context is False def test_init_preserve_context_true(mock_agent): - tool = AgentAsTool(mock_agent, name="t", description="d", preserve_context=True) + tool = _AgentAsTool(mock_agent, name="t", description="d", preserve_context=True) assert tool._preserve_context is True @@ -210,7 +225,7 @@ async def test_stream_events_not_double_wrapped_by_executor(tool, mock_agent, to assert isinstance(event, ToolStreamEvent) # But it's specifically an AgentAsToolStreamEvent (not re-wrapped) assert type(event) is AgentAsToolStreamEvent - # And it references the originating AgentAsTool + # And it references the originating _AgentAsTool assert event.agent_as_tool is tool @@ -261,7 +276,7 @@ async def test_stream_resets_to_initial_state_when_preserve_context_false(fake_a fake_agent.messages = [{"role": "user", "content": [{"text": "initial"}]}] fake_agent.state.set("counter", 0) - tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) + tool = _AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) # Mutate agent state as if a previous invocation happened fake_agent.messages.append({"role": "assistant", "content": [{"text": "reply"}]}) @@ -296,7 +311,7 @@ async def test_stream_resets_on_every_invocation(fake_agent): fake_agent.messages = [{"role": "user", "content": [{"text": "seed"}]}] fake_agent.state.set("count", 1) - tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) + tool = _AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( AgentResult( @@ -331,7 +346,7 @@ async def test_stream_initial_snapshot_is_deep_copy(fake_agent): """Mutating the agent's messages after construction should not affect the snapshot.""" fake_agent.messages = [{"role": "user", "content": [{"text": "original"}]}] - tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) + tool = _AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) fake_agent.messages[0]["content"][0]["text"] = "mutated" fake_agent.messages.append({"role": "assistant", "content": [{"text": "extra"}]}) @@ -359,7 +374,7 @@ async def test_stream_initial_snapshot_is_deep_copy(fake_agent): @pytest.mark.asyncio async def test_stream_resets_empty_initial_state_when_preserve_context_false(fake_agent): - tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) + tool = _AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) fake_agent.messages = [{"role": "user", "content": [{"text": "old"}]}] fake_agent.state.set("key", "value") @@ -391,7 +406,7 @@ async def test_stream_resets_context_by_default(fake_agent): """Default preserve_context=False means each invocation starts fresh.""" fake_agent.messages = [{"role": "user", "content": [{"text": "old"}]}] fake_agent.state.set("key", "value") - tool = AgentAsTool(fake_agent, name="fake_agent", description="desc") + tool = _AgentAsTool(fake_agent, name="fake_agent", description="desc") # Mutate after construction fake_agent.messages.append({"role": "assistant", "content": [{"text": "extra"}]}) @@ -424,7 +439,7 @@ async def test_stream_resets_context_by_default(fake_agent): async def test_stream_preserves_context_when_explicitly_true(fake_agent): fake_agent.messages = [{"role": "user", "content": [{"text": "old"}]}] fake_agent.state.set("key", "value") - tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=True) + tool = _AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=True) fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( AgentResult( @@ -448,23 +463,12 @@ async def test_stream_preserves_context_when_explicitly_true(fake_agent): assert fake_agent.state.get("key") == "value" -def test_preserve_context_false_requires_agent_instance(): - """Default preserve_context=False should raise TypeError for non-Agent instances.""" - - class _NotAnAgent: - name = "not_agent" - - async def invoke_async(self, prompt=None, **kwargs): - pass - - def __call__(self, prompt=None, **kwargs): - pass - - def stream_async(self, prompt=None, **kwargs): - pass +def test_preserve_context_false_rejects_session_manager(fake_agent): + """preserve_context=False should raise ValueError when agent has a session manager.""" + fake_agent._session_manager = MagicMock() - with pytest.raises(TypeError, match="requires an Agent instance"): - AgentAsTool(_NotAnAgent(), name="bad", description="desc") + with pytest.raises(ValueError, match="cannot be used with an agent that has a session manager"): + _AgentAsTool(fake_agent, name="t", description="d", preserve_context=False) # --- interrupt propagation --- @@ -484,7 +488,7 @@ def interrupt_result(): @pytest.mark.asyncio async def test_stream_interrupt_yields_tool_interrupt_event(tool, mock_agent, tool_use, interrupt_result): - """When the sub-agent returns an interrupt result, AgentAsTool should yield ToolInterruptEvent.""" + """When the sub-agent returns an interrupt result, _AgentAsTool should yield ToolInterruptEvent.""" mock_agent.stream_async.return_value = _mock_stream_async(interrupt_result) events = [event async for event in tool.stream(tool_use, {})] @@ -522,7 +526,7 @@ async def test_stream_interrupt_forwards_intermediate_events(tool, mock_agent, t @pytest.mark.asyncio async def test_stream_interrupt_resume_forwards_responses(fake_agent): - """On resume, AgentAsTool should forward interrupt responses to the sub-agent.""" + """On resume, _AgentAsTool should forward interrupt responses to the sub-agent.""" interrupt = Interrupt(id="interrupt-1", name="approval", reason="need approval", response="APPROVE") # Put the sub-agent in an activated interrupt state with the response already set @@ -537,7 +541,7 @@ async def test_stream_interrupt_resume_forwards_responses(fake_agent): ) fake_agent.stream_async = MagicMock(return_value=_mock_stream_async(normal_result)) - tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=True) + tool = _AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=True) tool_use = {"toolUseId": "tool-123", "name": "fake_agent", "input": {"input": "do something"}} events = [event async for event in tool.stream(tool_use, {})] @@ -562,7 +566,7 @@ async def test_stream_interrupt_resume_skips_state_reset(fake_agent): fake_agent.messages = [{"role": "user", "content": [{"text": "initial"}]}] fake_agent.state.set("key", "value") - tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) + tool = _AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) # Simulate the sub-agent being in interrupt state after a previous invocation interrupt = Interrupt(id="interrupt-1", name="approval", reason="need approval", response="APPROVE") @@ -589,15 +593,15 @@ async def test_stream_interrupt_resume_skips_state_reset(fake_agent): @pytest.mark.asyncio -async def test_is_sub_agent_interrupted_false_for_mock(tool): - """_is_sub_agent_interrupted returns False for non-Agent instances.""" +async def test_is_sub_agent_interrupted_false_by_default(tool): + """_is_sub_agent_interrupted returns False when no interrupts are active.""" assert tool._is_sub_agent_interrupted() is False @pytest.mark.asyncio async def test_is_sub_agent_interrupted_true_when_activated(fake_agent): """_is_sub_agent_interrupted returns True when the sub-agent's interrupt state is activated.""" - tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=True) + tool = _AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=True) assert tool._is_sub_agent_interrupted() is False fake_agent._interrupt_state.activate() @@ -607,7 +611,7 @@ async def test_is_sub_agent_interrupted_true_when_activated(fake_agent): @pytest.mark.asyncio async def test_build_interrupt_responses(fake_agent): """_build_interrupt_responses packages sub-agent interrupts into response content blocks.""" - tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=True) + tool = _AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=True) interrupt_a = Interrupt(id="id-a", name="a", reason="r", response="yes") interrupt_b = Interrupt(id="id-b", name="b", reason="r", response=None) diff --git a/tests/strands/tools/executors/test_executor.py b/tests/strands/tools/executors/test_executor.py index c7ad10232..297aa66f3 100644 --- a/tests/strands/tools/executors/test_executor.py +++ b/tests/strands/tools/executors/test_executor.py @@ -470,7 +470,7 @@ async def test_executor_stream_tool_interrupt_registers_on_agent( ): """ToolInterruptEvent from a tool should register interrupts in the agent's _interrupt_state.""" # Create a tool that yields a ToolInterruptEvent with an interrupt NOT pre-registered on the agent - # (simulates AgentAsTool propagating sub-agent interrupts). + # (simulates _AgentAsTool propagating sub-agent interrupts). foreign_interrupt = Interrupt(id="sub-agent-interrupt-1", name="approval", reason="need approval") @strands.tool(name="agent_tool") From 962ac556eadce3f8ee3bbc5e5783453d1c08f211 Mon Sep 17 00:00:00 2001 From: Owen Kaplan Date: Wed, 25 Mar 2026 16:50:55 -0400 Subject: [PATCH 8/8] fix: remove exports; update as_tool signature to return an AgentTool --- src/strands/__init__.py | 2 -- src/strands/agent/__init__.py | 2 -- src/strands/agent/agent.py | 5 +++-- tests/strands/agent/test_agent.py | 3 ++- tests/strands/agent/test_agent_as_tool.py | 2 +- 5 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/strands/__init__.py b/src/strands/__init__.py index 3db6f4758..2078f16ce 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -1,7 +1,6 @@ """A framework for building, deploying, and managing AI agents.""" from . import agent, models, telemetry, types -from .agent._agent_as_tool import _AgentAsTool from .agent.agent import Agent from .agent.base import AgentBase from .event_loop._retry import ModelRetryStrategy @@ -12,7 +11,6 @@ __all__ = [ "Agent", - "_AgentAsTool", "AgentBase", "AgentSkills", "agent", diff --git a/src/strands/agent/__init__.py b/src/strands/agent/__init__.py index 66dc56168..c901e800f 100644 --- a/src/strands/agent/__init__.py +++ b/src/strands/agent/__init__.py @@ -10,7 +10,6 @@ from typing import Any from ..event_loop._retry import ModelRetryStrategy -from ._agent_as_tool import _AgentAsTool from .agent import Agent from .agent_result import AgentResult from .base import AgentBase @@ -25,7 +24,6 @@ "Agent", "AgentBase", "AgentResult", - "_AgentAsTool", "ConversationManager", "NullConversationManager", "SlidingWindowConversationManager", diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index c58debb35..a29ea31eb 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -61,6 +61,7 @@ from ..types.agent import AgentInput, ConcurrentInvocationMode from ..types.content import ContentBlock, Message, Messages, SystemContentBlock from ..types.exceptions import ConcurrencyException, ContextWindowOverflowException +from ..types.tools import AgentTool from ..types.traces import AttributeValue from ._agent_as_tool import _AgentAsTool from .agent_result import AgentResult @@ -618,7 +619,7 @@ def as_tool( name: str | None = None, description: str | None = None, preserve_context: bool = False, - ) -> _AgentAsTool: + ) -> AgentTool: r"""Convert this agent into a tool for use by another agent. Args: @@ -633,7 +634,7 @@ def as_tool( interactions with the agent. Defaults to False. Returns: - An _AgentAsTool wrapping this agent. + A tool wrapping this agent. Example: ```python diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index cdfb609aa..2ce9ff245 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -15,7 +15,8 @@ import strands from strands import Agent, Plugin, ToolContext -from strands.agent import AgentResult, _AgentAsTool +from strands.agent import AgentResult +from strands.agent._agent_as_tool import _AgentAsTool from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.agent.state import AgentState diff --git a/tests/strands/agent/test_agent_as_tool.py b/tests/strands/agent/test_agent_as_tool.py index 0dc5af3d4..f5848b315 100644 --- a/tests/strands/agent/test_agent_as_tool.py +++ b/tests/strands/agent/test_agent_as_tool.py @@ -4,7 +4,7 @@ import pytest -from strands.agent import _AgentAsTool +from strands.agent._agent_as_tool import _AgentAsTool from strands.agent.agent_result import AgentResult from strands.interrupt import Interrupt, _InterruptState from strands.telemetry.metrics import EventLoopMetrics