diff --git a/src/anthropic/lib/tools/_beta_functions.py b/src/anthropic/lib/tools/_beta_functions.py index 21780cc7..c3f65ea7 100644 --- a/src/anthropic/lib/tools/_beta_functions.py +++ b/src/anthropic/lib/tools/_beta_functions.py @@ -3,7 +3,7 @@ import logging from abc import ABC, abstractmethod from typing import Any, Union, Generic, TypeVar, Callable, Iterable, Coroutine, cast, overload -from inspect import iscoroutinefunction +from inspect import ismethod, isfunction, iscoroutinefunction from typing_extensions import TypeAlias, override import pydantic @@ -23,6 +23,33 @@ BetaFunctionToolResultType: TypeAlias = Union[str, Iterable[BetaContent]] + +def _normalize_callable(func: Callable[..., Any]) -> Callable[..., Any]: + """Normalize a callable to a function that can be used with pydantic.validate_call. + + If the callable is a class instance with a __call__ method (but not a function or method), + this extracts the bound __call__ method. This allows callable class instances to be used + as tools without requiring manual extraction of __call__. + + Args: + func: A function, method, or callable instance + + Returns: + A function or bound method suitable for use with pydantic.validate_call + """ + # If it's already a function or method, use it directly + if isfunction(func) or ismethod(func): + return func + + # If it's a callable instance (class with __call__), extract the bound __call__ method + if callable(func): + call_method = func.__call__ # pyright: ignore[reportFunctionMemberAccess] # noqa: B004 + if ismethod(call_method): + return call_method + + return func + + Function = Callable[..., BetaFunctionToolResultType] FunctionT = TypeVar("FunctionT", bound=Function) @@ -83,9 +110,11 @@ def __init__( if _compat.PYDANTIC_V1: raise RuntimeError("Tool functions are only supported with Pydantic v2") - self.func = func - self._func_with_validate = pydantic.validate_call(func) - self.name = name or func.__name__ + # Normalize callable instances to their __call__ method + normalized_func = _normalize_callable(func) + self.func = cast(CallableT, normalized_func) + self._func_with_validate = pydantic.validate_call(normalized_func) + self.name = name or normalized_func.__name__ self._defer_loading = defer_loading self.description = description or self._get_description_from_docstring() diff --git a/tests/lib/tools/test_functions.py b/tests/lib/tools/test_functions.py index 7f632246..868eead9 100644 --- a/tests/lib/tools/test_functions.py +++ b/tests/lib/tools/test_functions.py @@ -439,6 +439,80 @@ def simple_add(a: int, b: int) -> str: assert function_tool.input_schema == expected_schema + def test_callable_class_instance(self) -> None: + """Test that callable class instances can be used as tools.""" + + class FetchProduct: + def __init__(self, ctx: dict[str, str]) -> None: + self.ctx = ctx + + def __call__(self, product_id: int) -> str: + """Fetch a product by ID.""" + return f"Product {product_id} from {self.ctx['session']}" + + instance = FetchProduct({"session": "test-session"}) + tool = beta_tool(instance, name="fetch_product") + + assert tool.name == "fetch_product" + assert tool.description == "Fetch a product by ID." + assert tool.call({"product_id": 123}) == "Product 123 from test-session" + + # Check schema + expected_schema = { + "additionalProperties": False, + "type": "object", + "properties": { + "product_id": {"title": "Product Id", "type": "integer"}, + }, + "required": ["product_id"], + } + assert tool.input_schema == expected_schema + + def test_bound_method(self) -> None: + """Test that bound methods can be used as tools.""" + + class WeatherService: + def __init__(self, api_key: str) -> None: + self.api_key = api_key + + def get_weather(self, location: str) -> str: + """Get weather for a location.""" + return f"Weather in {location} (using key: {self.api_key[:4]}...)" + + service = WeatherService("secret-api-key") + tool = beta_tool(service.get_weather, name="weather") + + assert tool.name == "weather" + assert tool.description == "Get weather for a location." + assert tool.call({"location": "London"}) == "Weather in London (using key: secr...)" + + # Check schema + expected_schema = { + "additionalProperties": False, + "type": "object", + "properties": { + "location": {"title": "Location", "type": "string"}, + }, + "required": ["location"], + } + assert tool.input_schema == expected_schema + + def test_callable_class_without_explicit_name(self) -> None: + """Test that callable class instances infer name from __call__ method.""" + + class MyTool: + def __call__(self, x: int) -> str: + """Process x.""" + return str(x) + + instance = MyTool() + tool = beta_tool(instance) + + # Should use __call__ as the name since that's the actual method + assert tool.name == "__call__" + assert tool.description == "Process x." + + def _get_parameters_info(fn: BaseFunctionTool[Any]) -> dict[str, str]: param_info: dict[str, str] = {} for param in fn._parsed_docstring.params: diff --git a/tests/lib/tools/test_runners.py b/tests/lib/tools/test_runners.py index 7575deba..977eedba 100644 --- a/tests/lib/tools/test_runners.py +++ b/tests/lib/tools/test_runners.py @@ -547,6 +547,86 @@ def tool_runner(client: Anthropic) -> BetaToolRunner[None]: respx_mock=respx_mock, ) + @pytest.mark.respx(base_url=base_url) + def test_callable_class_instance_tool( + self, client: Anthropic, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch + ) -> None: + """Test that callable class instances work with tool_runner.""" + + class WeatherTool: + def __init__(self, default_units: str) -> None: + self.default_units = default_units + + def __call__(self, location: str, units: Literal["c", "f"] = "c") -> str: + """Lookup the weather for a given city in either celsius or fahrenheit + + Args: + location: The city and state, e.g. San Francisco, CA + units: Unit for the output, either 'c' for celsius or 'f' for fahrenheit + Returns: + A dictionary containing the location, temperature, and weather condition. + """ + actual_units = units or self.default_units + return json.dumps(_get_weather(location, actual_units)) + + weather_instance = WeatherTool(default_units="f") + weather_tool = beta_tool(weather_instance, name="get_weather") + + message = make_snapshot_request( + lambda c: c.beta.messages.tool_runner( + max_tokens=1024, + model="claude-haiku-4-5", + tools=[weather_tool], + messages=[{"role": "user", "content": "What is the weather in SF?"}], + ).until_done(), + content_snapshot=snapshots["basic"]["responses"], + path="/v1/messages", + mock_client=client, + respx_mock=respx_mock, + ) + + assert print_obj(message, monkeypatch) == snapshots["basic"]["result"] + + @pytest.mark.respx(base_url=base_url) + def test_bound_method_tool( + self, client: Anthropic, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch + ) -> None: + """Test that bound methods work with tool_runner.""" + + class WeatherService: + def __init__(self, api_key: str) -> None: + self.api_key = api_key + + def get_weather(self, location: str, units: Literal["c", "f"]) -> str: + """Lookup the weather for a given city in either celsius or fahrenheit + + Args: + location: The city and state, e.g. San Francisco, CA + units: Unit for the output, either 'c' for celsius or 'f' for fahrenheit + Returns: + A dictionary containing the location, temperature, and weather condition. + """ + # In a real scenario, self.api_key would be used + return json.dumps(_get_weather(location, units)) + + service = WeatherService(api_key="secret-key") + weather_tool = beta_tool(service.get_weather, name="get_weather") + + message = make_snapshot_request( + lambda c: c.beta.messages.tool_runner( + max_tokens=1024, + model="claude-haiku-4-5", + tools=[weather_tool], + messages=[{"role": "user", "content": "What is the weather in SF?"}], + ).until_done(), + content_snapshot=snapshots["basic"]["responses"], + path="/v1/messages", + mock_client=client, + respx_mock=respx_mock, + ) + + assert print_obj(message, monkeypatch) == snapshots["basic"]["result"] + @pytest.mark.skipif(PYDANTIC_V1, reason="tool runner not supported with pydantic v1") @pytest.mark.respx(base_url=base_url)