Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 33 additions & 4 deletions src/anthropic/lib/tools/_beta_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
from abc import ABC, abstractmethod
from typing import Any, Union, Generic, TypeVar, Callable, Iterable, Coroutine, cast, overload
from inspect import iscoroutinefunction
from inspect import ismethod, isfunction, iscoroutinefunction
from typing_extensions import TypeAlias, override

import pydantic
Expand All @@ -23,6 +23,33 @@

BetaFunctionToolResultType: TypeAlias = Union[str, Iterable[BetaContent]]


def _normalize_callable(func: Callable[..., Any]) -> Callable[..., Any]:
"""Normalize a callable to a function that can be used with pydantic.validate_call.

If the callable is a class instance with a __call__ method (but not a function or method),
this extracts the bound __call__ method. This allows callable class instances to be used
as tools without requiring manual extraction of __call__.

Args:
func: A function, method, or callable instance

Returns:
A function or bound method suitable for use with pydantic.validate_call
"""
# If it's already a function or method, use it directly
if isfunction(func) or ismethod(func):
return func

# If it's a callable instance (class with __call__), extract the bound __call__ method
if callable(func):
call_method = func.__call__ # pyright: ignore[reportFunctionMemberAccess] # noqa: B004
if ismethod(call_method):
return call_method

return func


Function = Callable[..., BetaFunctionToolResultType]
FunctionT = TypeVar("FunctionT", bound=Function)

Expand Down Expand Up @@ -83,9 +110,11 @@ def __init__(
if _compat.PYDANTIC_V1:
raise RuntimeError("Tool functions are only supported with Pydantic v2")

self.func = func
self._func_with_validate = pydantic.validate_call(func)
self.name = name or func.__name__
# Normalize callable instances to their __call__ method
normalized_func = _normalize_callable(func)
self.func = cast(CallableT, normalized_func)
self._func_with_validate = pydantic.validate_call(normalized_func)
self.name = name or normalized_func.__name__
self._defer_loading = defer_loading

self.description = description or self._get_description_from_docstring()
Expand Down
74 changes: 74 additions & 0 deletions tests/lib/tools/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,80 @@ def simple_add(a: int, b: int) -> str:
assert function_tool.input_schema == expected_schema


def test_callable_class_instance(self) -> None:
"""Test that callable class instances can be used as tools."""

class FetchProduct:
def __init__(self, ctx: dict[str, str]) -> None:
self.ctx = ctx

def __call__(self, product_id: int) -> str:
"""Fetch a product by ID."""
return f"Product {product_id} from {self.ctx['session']}"

instance = FetchProduct({"session": "test-session"})
tool = beta_tool(instance, name="fetch_product")

assert tool.name == "fetch_product"
assert tool.description == "Fetch a product by ID."
assert tool.call({"product_id": 123}) == "Product 123 from test-session"

# Check schema
expected_schema = {
"additionalProperties": False,
"type": "object",
"properties": {
"product_id": {"title": "Product Id", "type": "integer"},
},
"required": ["product_id"],
}
assert tool.input_schema == expected_schema

def test_bound_method(self) -> None:
"""Test that bound methods can be used as tools."""

class WeatherService:
def __init__(self, api_key: str) -> None:
self.api_key = api_key

def get_weather(self, location: str) -> str:
"""Get weather for a location."""
return f"Weather in {location} (using key: {self.api_key[:4]}...)"

service = WeatherService("secret-api-key")
tool = beta_tool(service.get_weather, name="weather")

assert tool.name == "weather"
assert tool.description == "Get weather for a location."
assert tool.call({"location": "London"}) == "Weather in London (using key: secr...)"

# Check schema
expected_schema = {
"additionalProperties": False,
"type": "object",
"properties": {
"location": {"title": "Location", "type": "string"},
},
"required": ["location"],
}
assert tool.input_schema == expected_schema

def test_callable_class_without_explicit_name(self) -> None:
"""Test that callable class instances infer name from __call__ method."""

class MyTool:
def __call__(self, x: int) -> str:
"""Process x."""
return str(x)

instance = MyTool()
tool = beta_tool(instance)

# Should use __call__ as the name since that's the actual method
assert tool.name == "__call__"
assert tool.description == "Process x."


def _get_parameters_info(fn: BaseFunctionTool[Any]) -> dict[str, str]:
param_info: dict[str, str] = {}
for param in fn._parsed_docstring.params:
Expand Down
80 changes: 80 additions & 0 deletions tests/lib/tools/test_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,86 @@ def tool_runner(client: Anthropic) -> BetaToolRunner[None]:
respx_mock=respx_mock,
)

@pytest.mark.respx(base_url=base_url)
def test_callable_class_instance_tool(
self, client: Anthropic, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Test that callable class instances work with tool_runner."""

class WeatherTool:
def __init__(self, default_units: str) -> None:
self.default_units = default_units

def __call__(self, location: str, units: Literal["c", "f"] = "c") -> str:
"""Lookup the weather for a given city in either celsius or fahrenheit
Args:
location: The city and state, e.g. San Francisco, CA
units: Unit for the output, either 'c' for celsius or 'f' for fahrenheit
Returns:
A dictionary containing the location, temperature, and weather condition.
"""
actual_units = units or self.default_units
return json.dumps(_get_weather(location, actual_units))

weather_instance = WeatherTool(default_units="f")
weather_tool = beta_tool(weather_instance, name="get_weather")

message = make_snapshot_request(
lambda c: c.beta.messages.tool_runner(
max_tokens=1024,
model="claude-haiku-4-5",
tools=[weather_tool],
messages=[{"role": "user", "content": "What is the weather in SF?"}],
).until_done(),
content_snapshot=snapshots["basic"]["responses"],
path="/v1/messages",
mock_client=client,
respx_mock=respx_mock,
)

assert print_obj(message, monkeypatch) == snapshots["basic"]["result"]

@pytest.mark.respx(base_url=base_url)
def test_bound_method_tool(
self, client: Anthropic, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Test that bound methods work with tool_runner."""

class WeatherService:
def __init__(self, api_key: str) -> None:
self.api_key = api_key

def get_weather(self, location: str, units: Literal["c", "f"]) -> str:
"""Lookup the weather for a given city in either celsius or fahrenheit
Args:
location: The city and state, e.g. San Francisco, CA
units: Unit for the output, either 'c' for celsius or 'f' for fahrenheit
Returns:
A dictionary containing the location, temperature, and weather condition.
"""
# In a real scenario, self.api_key would be used
return json.dumps(_get_weather(location, units))

service = WeatherService(api_key="secret-key")
weather_tool = beta_tool(service.get_weather, name="get_weather")

message = make_snapshot_request(
lambda c: c.beta.messages.tool_runner(
max_tokens=1024,
model="claude-haiku-4-5",
tools=[weather_tool],
messages=[{"role": "user", "content": "What is the weather in SF?"}],
).until_done(),
content_snapshot=snapshots["basic"]["responses"],
path="/v1/messages",
mock_client=client,
respx_mock=respx_mock,
)

assert print_obj(message, monkeypatch) == snapshots["basic"]["result"]


@pytest.mark.skipif(PYDANTIC_V1, reason="tool runner not supported with pydantic v1")
@pytest.mark.respx(base_url=base_url)
Expand Down