diff --git a/src/mcp/server/mcpserver/utilities/context_injection.py b/src/mcp/server/mcpserver/utilities/context_injection.py index ac7ab82d0..c861364f2 100644 --- a/src/mcp/server/mcpserver/utilities/context_injection.py +++ b/src/mcp/server/mcpserver/utilities/context_injection.py @@ -22,9 +22,13 @@ def find_context_parameter(fn: Callable[..., Any]) -> str | None: Returns: The name of the context parameter, or None if not found """ + target = fn + if callable(fn) and not inspect.isfunction(fn) and not inspect.ismethod(fn) and not inspect.isclass(fn): + target = fn.__call__ + # Get type hints to properly resolve string annotations try: - hints = typing.get_type_hints(fn) + hints = typing.get_type_hints(target) except Exception: # pragma: lax no cover # If we can't resolve type hints, we can't find the context parameter return None diff --git a/tests/server/mcpserver/test_tool_manager.py b/tests/server/mcpserver/test_tool_manager.py index e4dfd4ff9..658e02766 100644 --- a/tests/server/mcpserver/test_tool_manager.py +++ b/tests/server/mcpserver/test_tool_manager.py @@ -137,6 +137,27 @@ async def __call__(self, x: int) -> int: # pragma: no cover assert tool.is_async is True assert tool.parameters["properties"]["x"]["type"] == "integer" + @pytest.mark.anyio + async def test_add_callable_object_with_context(self): + """Test registering a callable object with an injected context parameter.""" + + class MyTool: + def __init__(self): + self.__name__ = "MyTool" + + async def __call__(self, query: str, ctx: Context) -> str: + assert isinstance(ctx, Context) + return f"Results for: {query}" + + manager = ToolManager() + tool = manager.add_tool(MyTool()) + + assert tool.context_kwarg == "ctx" + assert "query" in tool.parameters["properties"] + assert "ctx" not in tool.parameters["properties"] + assert tool.parameters["required"] == ["query"] + assert await tool.run({"query": "books"}, Context()) == "Results for: books" + def test_add_invalid_tool(self): manager = ToolManager() with pytest.raises(AttributeError):