Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from ...base_toolset import BaseToolset
from ...base_toolset import ToolPredicate
from .openapi_spec_parser import OpenApiSpecParser
from .rest_api_tool import HttpxClientFactory
from .rest_api_tool import RestApiTool

logger = logging.getLogger("google_adk." + __name__)
Expand Down Expand Up @@ -77,6 +78,7 @@ def __init__(
header_provider: Optional[
Callable[[ReadonlyContext], Dict[str, str]]
] = None,
httpx_client_factory: Optional[HttpxClientFactory] = None,
preserve_property_names: bool = False,
):
"""Initializes the OpenAPIToolset.
Expand Down Expand Up @@ -130,6 +132,14 @@ def __init__(
an argument, allowing dynamic header generation based on the current
context. Useful for adding custom headers like correlation IDs,
authentication tokens, or other request metadata.
httpx_client_factory: Optional zero-argument callable returning an
``httpx.AsyncClient`` to use for every generated tool's API calls.
When provided, it takes precedence over the per-tool default client
construction and unlocks ``httpx.AsyncClient`` options that
``ssl_verify`` can't reach (proxies, HTTP/2, custom transports such as
request signing, shared connection pools). Defaults to ``None``, which
preserves today's behaviour. Mirrors the pattern exposed for MCP by
``StreamableHTTPConnectionParams.httpx_client_factory``.
preserve_property_names: If True, preserve the original property names
from the OpenAPI spec instead of converting them to snake_case. This
is useful when calling APIs that expect camelCase or other
Expand All @@ -155,6 +165,7 @@ def __init__(
if not spec_dict:
spec_dict = self._load_spec(spec_str, spec_str_type)
self._ssl_verify = ssl_verify
self._httpx_client_factory = httpx_client_factory
self._tools: Final[List[RestApiTool]] = list(self._parse(spec_dict))
if auth_scheme or auth_credential:
self._configure_auth_all(auth_scheme, auth_credential)
Expand Down Expand Up @@ -237,6 +248,7 @@ def _parse(self, openapi_spec_dict: Dict[str, Any]) -> List[RestApiTool]:
o,
ssl_verify=self._ssl_verify,
header_provider=self._header_provider,
httpx_client_factory=self._httpx_client_factory,
)
logger.info("Parsed tool: %s", tool.name)
tools.append(tool)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,17 @@ def snake_to_lower_camel(snake_case_string: str):

AuthPreparationState = Literal["pending", "done"]

HttpxClientFactory = Callable[..., httpx.AsyncClient]
"""Type alias for a factory returning an ``httpx.AsyncClient``.

When supplied to ``RestApiTool`` or ``OpenAPIToolset``, the factory is invoked
once per API call and its returned client is used (as an async context
manager) to issue the request, in place of the default
``httpx.AsyncClient(verify=..., timeout=None)``. This unlocks knobs that the
narrower ``ssl_verify`` parameter can't reach: proxies, HTTP/2, custom
transports (e.g. request-signing), shared connection pools, and so on.
"""


class RestApiTool(BaseTool):
"""A generic tool that interacts with a REST API.
Expand Down Expand Up @@ -103,6 +114,7 @@ def __init__(
header_provider: Optional[
Callable[[ReadonlyContext], Dict[str, str]]
] = None,
httpx_client_factory: Optional[HttpxClientFactory] = None,
*,
credential_key: Optional[str] = None,
):
Expand Down Expand Up @@ -142,6 +154,15 @@ def __init__(
an argument, allowing dynamic header generation based on the current
context. Useful for adding custom headers like correlation IDs,
authentication tokens, or other request metadata.
httpx_client_factory: Optional zero-argument callable returning an
``httpx.AsyncClient``. When provided, the returned client is used to
issue the request, allowing callers to configure proxies, HTTP/2,
custom transports (e.g. request signing), shared connection pools,
or any other ``httpx.AsyncClient`` option that ``ssl_verify`` can't
reach. When ``None`` (default), behaviour is unchanged: a fresh
``httpx.AsyncClient(verify=..., timeout=None)`` is created per
request. Mirrors the pattern exposed for MCP by
``StreamableHTTPConnectionParams.httpx_client_factory``.
credential_key: Optional stable key used for interactive auth and
credential caching.
"""
Expand Down Expand Up @@ -169,6 +190,7 @@ def __init__(
self._default_headers: Dict[str, str] = {}
self._ssl_verify = ssl_verify
self._header_provider = header_provider
self._httpx_client_factory = httpx_client_factory
self._logger = logger
if should_parse_operation:
self._operation_parser = OperationParser(self.operation)
Expand All @@ -181,6 +203,7 @@ def from_parsed_operation(
header_provider: Optional[
Callable[[ReadonlyContext], Dict[str, str]]
] = None,
httpx_client_factory: Optional[HttpxClientFactory] = None,
) -> "RestApiTool":
"""Initializes the RestApiTool from a ParsedOperation object.

Expand All @@ -192,6 +215,9 @@ def from_parsed_operation(
an argument, allowing dynamic header generation based on the current
context. Useful for adding custom headers like correlation IDs,
authentication tokens, or other request metadata.
httpx_client_factory: Optional zero-argument callable returning an
``httpx.AsyncClient`` to be used for the API call. See
``RestApiTool.__init__`` for details.

Returns:
A RestApiTool object.
Expand All @@ -212,6 +238,7 @@ def from_parsed_operation(
auth_credential=parsed.auth_credential,
ssl_verify=ssl_verify,
header_provider=header_provider,
httpx_client_factory=httpx_client_factory,
)
generated._operation_parser = operation_parser
return generated
Expand Down Expand Up @@ -520,7 +547,9 @@ async def call(
if provider_headers:
request_params.setdefault("headers", {}).update(provider_headers)

response = await _request(**request_params)
response = await _request(
httpx_client_factory=self._httpx_client_factory, **request_params
)

# Log the API response
self._logger.debug(
Expand Down Expand Up @@ -569,9 +598,14 @@ def __repr__(self):
)


async def _request(**request_params) -> httpx.Response:
async with httpx.AsyncClient(
verify=request_params.pop("verify", True),
timeout=None,
) as client:
async def _request(
*,
httpx_client_factory: Optional[HttpxClientFactory] = None,
**request_params,
) -> httpx.Response:
verify = request_params.pop("verify", True)
if httpx_client_factory is not None:
async with httpx_client_factory() as client:
return await client.request(**request_params)
async with httpx.AsyncClient(verify=verify, timeout=None) as client:
return await client.request(**request_params)
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,29 @@ def test_openapi_toolset_verify_on_init(
assert all(tool._ssl_verify == verify_value for tool in toolset._tools)


def test_openapi_toolset_httpx_client_factory_on_init(
openapi_spec: Dict[str, Any],
):
"""The httpx_client_factory is forwarded to every generated tool."""
custom_factory = lambda: None # noqa: E731 - placeholder, never invoked here
toolset = OpenAPIToolset(
spec_dict=openapi_spec, httpx_client_factory=custom_factory
)
assert toolset._httpx_client_factory is custom_factory
assert all(
tool._httpx_client_factory is custom_factory for tool in toolset._tools
)


def test_openapi_toolset_httpx_client_factory_none_by_default(
openapi_spec: Dict[str, Any],
):
"""httpx_client_factory is None on the toolset and each tool by default."""
toolset = OpenAPIToolset(spec_dict=openapi_spec)
assert toolset._httpx_client_factory is None
assert all(tool._httpx_client_factory is None for tool in toolset._tools)


def test_openapi_toolset_configure_verify_all(openapi_spec: Dict[str, Any]):
"""Test configure_verify_all method."""
toolset = OpenAPIToolset(spec_dict=openapi_spec)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1339,6 +1339,113 @@ async def test_call_without_header_provider(

assert result == {"result": "success"}

def test_init_httpx_client_factory_none_by_default(
self,
sample_endpoint,
sample_operation,
):
"""httpx_client_factory is None by default."""
tool = RestApiTool(
name="test_tool",
description="Test Tool",
endpoint=sample_endpoint,
operation=sample_operation,
)
assert tool._httpx_client_factory is None

def test_init_with_httpx_client_factory(
self,
sample_endpoint,
sample_operation,
):
"""A user-supplied httpx_client_factory is stored on the tool."""
custom_factory = MagicMock()
tool = RestApiTool(
name="test_tool",
description="Test Tool",
endpoint=sample_endpoint,
operation=sample_operation,
httpx_client_factory=custom_factory,
)
assert tool._httpx_client_factory is custom_factory

@pytest.mark.asyncio
async def test_call_uses_custom_httpx_client_factory(
self,
mock_tool_context,
sample_endpoint,
sample_operation,
sample_auth_scheme,
sample_auth_credential,
):
"""When a factory is provided, its client is used to issue the request."""
mock_response = mock.create_autospec(requests.Response, instance=True)
mock_response.json.return_value = {"result": "success"}
mock_response.configure_mock(status_code=200)

mock_client = mock.create_autospec(
httpx.AsyncClient, instance=True, spec_set=True
)
mock_client.request = AsyncMock(return_value=mock_response)
# Make the mock client work as an async context manager.
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=None)

custom_factory = MagicMock(return_value=mock_client)

tool = RestApiTool(
name="test_tool",
description="Test Tool",
endpoint=sample_endpoint,
operation=sample_operation,
auth_scheme=sample_auth_scheme,
auth_credential=sample_auth_credential,
httpx_client_factory=custom_factory,
)

with patch.object(httpx, "AsyncClient", autospec=True) as mock_default:
result = await tool.call(args={}, tool_context=mock_tool_context)

# Factory must be invoked once and the default client must not be built.
custom_factory.assert_called_once_with()
mock_default.assert_not_called()
mock_client.request.assert_awaited_once()
assert result == {"result": "success"}

@pytest.mark.asyncio
async def test_call_without_httpx_client_factory_uses_default_client(
self,
mock_tool_context,
sample_endpoint,
sample_operation,
sample_auth_scheme,
sample_auth_credential,
):
"""When no factory is provided, the default httpx.AsyncClient is used."""
mock_response = mock.create_autospec(requests.Response, instance=True)
mock_response.json.return_value = {"result": "success"}
mock_response.configure_mock(status_code=200)

mock_client = mock.create_autospec(
httpx.AsyncClient, instance=True, spec_set=True
)
mock_client.request = AsyncMock(return_value=mock_response)

tool = RestApiTool(
name="test_tool",
description="Test Tool",
endpoint=sample_endpoint,
operation=sample_operation,
auth_scheme=sample_auth_scheme,
auth_credential=sample_auth_credential,
)

with patch.object(
httpx, "AsyncClient", return_value=mock_client, autospec=True
) as mock_async_client:
await tool.call(args={}, tool_context=mock_tool_context)
assert mock_async_client.called

def test_prepare_request_params_extracts_embedded_query_params(
self, sample_auth_credential, sample_auth_scheme
):
Expand Down
Loading