Skip to content

Commit 482171f

Browse files
fix: detect context on callable tool instances
1 parent 161834d commit 482171f

2 files changed

Lines changed: 26 additions & 1 deletion

File tree

src/mcp/server/mcpserver/utilities/context_injection.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,13 @@ def find_context_parameter(fn: Callable[..., Any]) -> str | None:
2222
Returns:
2323
The name of the context parameter, or None if not found
2424
"""
25+
target = fn
26+
if callable(fn) and not inspect.isfunction(fn) and not inspect.ismethod(fn) and not inspect.isclass(fn):
27+
target = fn.__call__
28+
2529
# Get type hints to properly resolve string annotations
2630
try:
27-
hints = typing.get_type_hints(fn)
31+
hints = typing.get_type_hints(target)
2832
except Exception: # pragma: lax no cover
2933
# If we can't resolve type hints, we can't find the context parameter
3034
return None

tests/server/mcpserver/test_tool_manager.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,27 @@ async def __call__(self, x: int) -> int: # pragma: no cover
137137
assert tool.is_async is True
138138
assert tool.parameters["properties"]["x"]["type"] == "integer"
139139

140+
@pytest.mark.anyio
141+
async def test_add_callable_object_with_context(self):
142+
"""Test registering a callable object with an injected context parameter."""
143+
144+
class MyTool:
145+
def __init__(self):
146+
self.__name__ = "MyTool"
147+
148+
async def __call__(self, query: str, ctx: Context) -> str:
149+
assert isinstance(ctx, Context)
150+
return f"Results for: {query}"
151+
152+
manager = ToolManager()
153+
tool = manager.add_tool(MyTool())
154+
155+
assert tool.context_kwarg == "ctx"
156+
assert "query" in tool.parameters["properties"]
157+
assert "ctx" not in tool.parameters["properties"]
158+
assert tool.parameters["required"] == ["query"]
159+
assert await tool.run({"query": "books"}, Context()) == "Results for: books"
160+
140161
def test_add_invalid_tool(self):
141162
manager = ToolManager()
142163
with pytest.raises(AttributeError):

0 commit comments

Comments
 (0)