Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
296 changes: 296 additions & 0 deletions src/strands/agent/_agent_as_tool.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's rename this as _agent_as_tool.py and just export at the top level

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

renamed

Original file line number Diff line number Diff line change
@@ -0,0 +1,296 @@
"""Agent-as-tool adapter.
This module provides the _AgentAsTool class that wraps an Agent as a tool
so it can be passed to another agent's tool list.
"""

from __future__ import annotations

import copy
import logging
import threading
from typing import TYPE_CHECKING, Any

from typing_extensions import override

from ..agent.state import AgentState
from ..types._events import AgentAsToolStreamEvent, ToolInterruptEvent, ToolResultEvent
from ..types.content import Messages
from ..types.interrupt import InterruptResponseContent
from ..types.tools import AgentTool, ToolGenerator, ToolSpec, ToolUse

if TYPE_CHECKING:
from .agent import Agent

logger = logging.getLogger(__name__)


class _AgentAsTool(AgentTool):
"""Adapter that exposes an Agent as a tool for use by other agents.
The tool accepts a single ``input`` string parameter, invokes the wrapped
agent, and returns the text response.
Example:
```python
from strands import Agent
researcher = Agent(name="researcher", description="Finds information")
# Use via convenience method (default: fresh conversation each call)
tool = researcher.as_tool()
# Preserve context across invocations
tool = researcher.as_tool(preserve_context=True)
writer = Agent(name="writer", tools=[tool])
writer("Write about AI agents")
```
"""

def __init__(
self,
agent: Agent,
*,
name: str,
description: str | None = None,
preserve_context: bool = False,
) -> None:
r"""Initialize the agent-as-tool adapter.
Args:
agent: The agent to wrap as a tool.
name: Tool name. Must match the pattern ``[a-zA-Z0-9_\\-]{1,64}``.
description: Tool description. Defaults to the agent's description, or a
generic description if the agent has no description set.
preserve_context: Whether to preserve the agent's conversation history across
invocations. When False, the agent's messages and state are reset to the
values they had at construction time before each call, ensuring every
invocation starts from the same baseline regardless of any external
interactions with the agent. Defaults to False.
"""
super().__init__()
self._agent = agent
self._tool_name = name
self._description = (
description or agent.description or f"Use the {name} agent as a tool by providing a natural language input"
)
self._preserve_context = preserve_context

# When preserve_context=False, we snapshot the agent's initial state so we can
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should move this logic to conversation managers. same thing is duplicated in here and in graph (this state management)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

^ or some other place, where we can allow to easily create "Amnesiac agents" that only remember initial state

Also as we use it more and more, this seems like a pattern that we should expose

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Am I mistaken that @zastrowm's work on snapshots could be used on the future to make this logic re-usable?

# restore it before each invocation. This mirrors GraphNode.reset_executor_state().
self._initial_messages: Messages = []
self._initial_state: AgentState = AgentState()
# Serialize access so _reset_agent_state + stream_async are atomic.
# threading.Lock (not asyncio.Lock) because run_async() may create
# separate event loops in different threads.
self._lock = threading.Lock()

if not preserve_context:
if getattr(agent, "_session_manager", None) is not None:
raise ValueError(
"preserve_context=False cannot be used with an agent that has a session manager. "
"The session manager persists conversation history externally, which conflicts with "
"resetting the agent's state between invocations."
)
self._initial_messages = copy.deepcopy(agent.messages)
self._initial_state = AgentState(agent.state.get())

@property
def agent(self) -> Agent:
"""The wrapped agent instance."""
return self._agent

@property
def tool_name(self) -> str:
"""Get the tool name."""
return self._tool_name

@property
def tool_spec(self) -> ToolSpec:
"""Get the tool specification."""
return {
"name": self._tool_name,
"description": self._description,
"inputSchema": {
"json": {
"type": "object",
"properties": {
"input": {
"type": "string",
"description": "The input to send to the agent tool.",
},
},
"required": ["input"],
}
},
}

@property
def tool_type(self) -> str:
"""Get the tool type."""
return "agent"

@override
async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any) -> ToolGenerator:
"""Invoke the wrapped agent via streaming and yield events.
Intermediate agent events are wrapped in AgentAsToolStreamEvent so the caller
can distinguish sub-agent progress from regular tool events. The final
AgentResult is yielded as a ToolResultEvent.
When the sub-agent encounters a hook interrupt (e.g. from BeforeToolCallEvent),
the interrupts are propagated to the parent agent via ToolInterruptEvent. On
resume, interrupt responses are forwarded to the sub-agent automatically.
Args:
tool_use: The tool use request containing the input parameter.
invocation_state: Context for the tool invocation.
**kwargs: Additional keyword arguments.
Yields:
AgentAsToolStreamEvent for intermediate events, ToolInterruptEvent if the
sub-agent is interrupted, or ToolResultEvent with the final response.
"""
tool_input = tool_use["input"]
if isinstance(tool_input, dict):
prompt = tool_input.get("input", "")
elif isinstance(tool_input, str):
prompt = tool_input
else:
logger.warning("tool_name=<%s> | unexpected input type: %s", self._tool_name, type(tool_input))
prompt = str(tool_input)

tool_use_id = tool_use["toolUseId"]

# Serialize access to the underlying agent. _reset_agent_state() mutates
# the agent before stream_async acquires its own lock, so a concurrent
# call would corrupt an in-flight invocation.
if not self._lock.acquire(blocking=False):
logger.warning(
"tool_name=<%s>, tool_use_id=<%s> | agent is already processing a request",
self._tool_name,
tool_use_id,
)
yield ToolResultEvent(
{
"toolUseId": tool_use_id,
"status": "error",
"content": [{"text": f"Agent '{self._tool_name}' is already processing a request"}],
}
)
return

try:
# Determine if we are resuming the sub-agent from an interrupt.
if self._is_sub_agent_interrupted():
prompt = self._build_interrupt_responses()
logger.debug(
"tool_name=<%s>, tool_use_id=<%s> | resuming sub-agent from interrupt",
self._tool_name,
tool_use_id,
)
elif not self._preserve_context:
self._reset_agent_state(tool_use_id)

logger.debug("tool_name=<%s>, tool_use_id=<%s> | invoking agent", self._tool_name, tool_use_id)

result = None
async for event in self._agent.stream_async(prompt):
if "result" in event:
result = event["result"]
else:
yield AgentAsToolStreamEvent(tool_use, event, self)

if result is None:
yield ToolResultEvent(
{
"toolUseId": tool_use_id,
"status": "error",
"content": [{"text": "Agent did not produce a result"}],
}
)
return

# Propagate sub-agent interrupts to the parent agent.
if result.stop_reason == "interrupt" and result.interrupts:
yield ToolInterruptEvent(tool_use, list(result.interrupts))
return

if result.structured_output:
yield ToolResultEvent(
{
"toolUseId": tool_use_id,
"status": "success",
"content": [{"json": result.structured_output.model_dump()}],
}
)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue: With preserve_context=False, concurrent invocations of the same AgentAsTool instance create a race condition. Both calls will _reset_agent_state() which mutates the shared agent's messages and state, so overlapping calls can corrupt each other's state.

Suggestion: This follows GraphNode.reset_executor_state() pattern, so may be a known limitation. Consider adding a docstring note or log warning about thread safety, e.g.:

# Note: This method is not thread-safe. Concurrent invocations of the same
# AgentAsTool instance with preserve_context=False may produce incorrect results.

Alternatively, a future enhancement could use a lock or create per-invocation agent clones.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@notowen333 let's check this one. For example an agent might call the agent.as_tool in parallel. Would this break?

Normally we have locks for stream_async. do we need something here?

else:
yield ToolResultEvent(
{
"toolUseId": tool_use_id,
"status": "success",
"content": [{"text": str(result)}],
}
)

except Exception as e:
logger.warning(
"tool_name=<%s>, tool_use_id=<%s> | agent invocation failed: %s",
self._tool_name,
tool_use_id,
e,
)
yield ToolResultEvent(
{
"toolUseId": tool_use_id,
"status": "error",
"content": [{"text": f"Agent error: {e}"}],
}
)
finally:
self._lock.release()

def _reset_agent_state(self, tool_use_id: str) -> None:
"""Reset the wrapped agent to its initial state.
Restores messages and state to the values captured at construction time.
This mirrors the pattern used by ``GraphNode.reset_executor_state()``.
Args:
tool_use_id: Tool use ID for logging context.
"""
logger.debug(
"tool_name=<%s>, tool_use_id=<%s> | resetting agent to initial state",
self._tool_name,
tool_use_id,
)
self._agent.messages = copy.deepcopy(self._initial_messages)
self._agent.state = AgentState(self._initial_state.get())

def _is_sub_agent_interrupted(self) -> bool:
"""Check whether the wrapped agent is in an activated interrupt state."""
return self._agent._interrupt_state.activated

def _build_interrupt_responses(self) -> list[InterruptResponseContent]:
"""Build interrupt response payloads from the sub-agent's interrupt state.
The parent agent's ``_interrupt_state.resume()`` sets ``.response`` on the shared
``Interrupt`` objects (registered by the executor), so we re-package them in the
format expected by ``Agent.stream_async``.
Returns:
List of interrupt response content blocks for resuming the sub-agent.
"""
return [
{"interruptResponse": {"interruptId": interrupt.id, "response": interrupt.response}}
for interrupt in self._agent._interrupt_state.interrupts.values()
if interrupt.response is not None
]

@override
def get_display_properties(self) -> dict[str, str]:
"""Get properties for UI display."""
properties = super().get_display_properties()
properties["Agent"] = getattr(self._agent, "name", "unknown")
return properties
35 changes: 35 additions & 0 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@
from ..types.agent import AgentInput, ConcurrentInvocationMode
from ..types.content import ContentBlock, Message, Messages, SystemContentBlock
from ..types.exceptions import ConcurrencyException, ContextWindowOverflowException
from ..types.tools import AgentTool
from ..types.traces import AttributeValue
from ._agent_as_tool import _AgentAsTool
from .agent_result import AgentResult
from .base import AgentBase
from .conversation_manager import (
Expand Down Expand Up @@ -612,6 +614,39 @@ async def structured_output_async(self, output_model: type[T], prompt: AgentInpu
finally:
await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self, invocation_state={}))

def as_tool(
self,
name: str | None = None,
description: str | None = None,
preserve_context: bool = False,
) -> AgentTool:
r"""Convert this agent into a tool for use by another agent.

Args:
name: Tool name. Must match the pattern ``[a-zA-Z0-9_\\-]{1,64}``.
Defaults to the agent's name.
description: Tool description. Defaults to the agent's description, or a
generic description if the agent has no description set.
preserve_context: Whether to preserve the agent's conversation history across
invocations. When False, the agent's messages and state are reset to the
values they had at construction time before each call, ensuring every
invocation starts from the same baseline regardless of any external
interactions with the agent. Defaults to False.

Returns:
A tool wrapping this agent.

Example:
```python
researcher = Agent(name="researcher", description="Finds information")
writer = Agent(name="writer", tools=[researcher.as_tool()])
writer("Write about AI agents")
```
"""
if not name:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue: Empty string check uses falsy evaluation which is inconsistent.

Suggestion: Consider using explicit is None checks for clarity. Currently if not name would treat an empty string "" the same as None, which may not be the intended behavior:

if name is None:
    name = self.name
if description is None:
    description = self.description or f"Use the {name} tool to invoke this agent as a tool"

This makes the API contract clearer - None means "use default" while an explicit empty string could be considered user error.

name = self.name
return _AgentAsTool(self, name=name, description=description, preserve_context=preserve_context)

def cleanup(self) -> None:
"""Clean up resources used by the agent.

Expand Down
6 changes: 6 additions & 0 deletions src/strands/tools/executors/_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,12 @@ async def _stream(
# ToolStreamEvent and the last event is just the result.

if isinstance(event, ToolInterruptEvent):
# Register any interrupts not already in the agent's state.
# For normal hooks this is a no-op (already registered by _Interruptible.interrupt()).
# For sub-agent interrupts propagated via _AgentAsTool, this is where they get
# registered so that _interrupt_state.resume() can locate them by ID.
for interrupt in event.interrupts:
agent._interrupt_state.interrupts.setdefault(interrupt.id, interrupt)
yield event
return

Expand Down
26 changes: 26 additions & 0 deletions src/strands/types/_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

if TYPE_CHECKING:
from ..agent import AgentResult
from ..agent._agent_as_tool import _AgentAsTool
from ..multiagent.base import MultiAgentResult, NodeResult


Expand Down Expand Up @@ -323,6 +324,31 @@ def tool_use_id(self) -> str:
return cast(ToolUse, cast(dict, self.get("tool_stream_event")).get("tool_use"))["toolUseId"]


class AgentAsToolStreamEvent(ToolStreamEvent):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need a special stream event?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it is used to differentiate in what context the agent is being used for the customer: either as an agent directly or as a tool. Thread from before: #1932 (comment)

"""Event emitted when an agent-as-tool yields intermediate events during execution.

Extends ToolStreamEvent with a reference to the originating _AgentAsTool so callers
can distinguish sub-agent stream events from regular tool stream events and access
the wrapped agent, tool name, description, etc.
"""

def __init__(self, tool_use: ToolUse, tool_stream_data: Any, agent_as_tool: "_AgentAsTool") -> None:
"""Initialize with tool streaming data and agent-tool reference.

Args:
tool_use: The tool invocation producing the stream.
tool_stream_data: The yielded event from the sub-agent execution.
agent_as_tool: The _AgentAsTool instance that produced this event.
"""
super().__init__(tool_use, tool_stream_data)
self._agent_as_tool = agent_as_tool

@property
def agent_as_tool(self) -> "_AgentAsTool":
"""The _AgentAsTool instance that produced this event."""
return self._agent_as_tool


class ToolCancelEvent(TypedEvent):
"""Event emitted when a user cancels a tool call from their BeforeToolCallEvent hook."""

Expand Down
Loading
Loading