From ca0cc75c2b8d59e135d53307ef7aa74731393aca Mon Sep 17 00:00:00 2001 From: darshil3011 Date: Fri, 15 May 2026 16:20:29 -0700 Subject: [PATCH] feat(tools): expose httpx_client_factory on RestApiTool and OpenAPIToolset Mirror the pattern merged for MCP in #2997 (StreamableHTTPConnectionParams. httpx_client_factory) on the OpenAPI tool surface. Adds an optional httpx_client_factory parameter to: - RestApiTool.__init__ - RestApiTool.from_parsed_operation - OpenAPIToolset.__init__ OpenAPIToolset forwards the factory to every generated RestApiTool the same way ssl_verify and header_provider are already forwarded. When provided, the factory's client is used to issue each API call; when None (default), the existing httpx.AsyncClient(verify=..., timeout=None) construction is preserved exactly. This unlocks httpx.AsyncClient knobs that the narrower ssl_verify parameter can't reach: proxies, HTTP/2, custom transports (e.g. request signing), and shared connection pools. Closes #5681 --- .../openapi_spec_parser/openapi_toolset.py | 12 ++ .../openapi_spec_parser/rest_api_tool.py | 46 +++++++- .../test_openapi_toolset.py | 23 ++++ .../openapi_spec_parser/test_rest_api_tool.py | 107 ++++++++++++++++++ 4 files changed, 182 insertions(+), 6 deletions(-) diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py index 99e649d9c9..2a1d384f05 100644 --- a/src/google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py @@ -36,6 +36,7 @@ from ...base_toolset import BaseToolset from ...base_toolset import ToolPredicate from .openapi_spec_parser import OpenApiSpecParser +from .rest_api_tool import HttpxClientFactory from .rest_api_tool import RestApiTool logger = logging.getLogger("google_adk." + __name__) @@ -77,6 +78,7 @@ def __init__( header_provider: Optional[ Callable[[ReadonlyContext], Dict[str, str]] ] = None, + httpx_client_factory: Optional[HttpxClientFactory] = None, preserve_property_names: bool = False, ): """Initializes the OpenAPIToolset. @@ -130,6 +132,14 @@ def __init__( an argument, allowing dynamic header generation based on the current context. Useful for adding custom headers like correlation IDs, authentication tokens, or other request metadata. + httpx_client_factory: Optional zero-argument callable returning an + ``httpx.AsyncClient`` to use for every generated tool's API calls. + When provided, it takes precedence over the per-tool default client + construction and unlocks ``httpx.AsyncClient`` options that + ``ssl_verify`` can't reach (proxies, HTTP/2, custom transports such as + request signing, shared connection pools). Defaults to ``None``, which + preserves today's behaviour. Mirrors the pattern exposed for MCP by + ``StreamableHTTPConnectionParams.httpx_client_factory``. preserve_property_names: If True, preserve the original property names from the OpenAPI spec instead of converting them to snake_case. This is useful when calling APIs that expect camelCase or other @@ -155,6 +165,7 @@ def __init__( if not spec_dict: spec_dict = self._load_spec(spec_str, spec_str_type) self._ssl_verify = ssl_verify + self._httpx_client_factory = httpx_client_factory self._tools: Final[List[RestApiTool]] = list(self._parse(spec_dict)) if auth_scheme or auth_credential: self._configure_auth_all(auth_scheme, auth_credential) @@ -237,6 +248,7 @@ def _parse(self, openapi_spec_dict: Dict[str, Any]) -> List[RestApiTool]: o, ssl_verify=self._ssl_verify, header_provider=self._header_provider, + httpx_client_factory=self._httpx_client_factory, ) logger.info("Parsed tool: %s", tool.name) tools.append(tool) diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py index fa32ce932a..2139c1882e 100644 --- a/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py @@ -75,6 +75,17 @@ def snake_to_lower_camel(snake_case_string: str): AuthPreparationState = Literal["pending", "done"] +HttpxClientFactory = Callable[..., httpx.AsyncClient] +"""Type alias for a factory returning an ``httpx.AsyncClient``. + +When supplied to ``RestApiTool`` or ``OpenAPIToolset``, the factory is invoked +once per API call and its returned client is used (as an async context +manager) to issue the request, in place of the default +``httpx.AsyncClient(verify=..., timeout=None)``. This unlocks knobs that the +narrower ``ssl_verify`` parameter can't reach: proxies, HTTP/2, custom +transports (e.g. request-signing), shared connection pools, and so on. +""" + class RestApiTool(BaseTool): """A generic tool that interacts with a REST API. @@ -103,6 +114,7 @@ def __init__( header_provider: Optional[ Callable[[ReadonlyContext], Dict[str, str]] ] = None, + httpx_client_factory: Optional[HttpxClientFactory] = None, *, credential_key: Optional[str] = None, ): @@ -142,6 +154,15 @@ def __init__( an argument, allowing dynamic header generation based on the current context. Useful for adding custom headers like correlation IDs, authentication tokens, or other request metadata. + httpx_client_factory: Optional zero-argument callable returning an + ``httpx.AsyncClient``. When provided, the returned client is used to + issue the request, allowing callers to configure proxies, HTTP/2, + custom transports (e.g. request signing), shared connection pools, + or any other ``httpx.AsyncClient`` option that ``ssl_verify`` can't + reach. When ``None`` (default), behaviour is unchanged: a fresh + ``httpx.AsyncClient(verify=..., timeout=None)`` is created per + request. Mirrors the pattern exposed for MCP by + ``StreamableHTTPConnectionParams.httpx_client_factory``. credential_key: Optional stable key used for interactive auth and credential caching. """ @@ -169,6 +190,7 @@ def __init__( self._default_headers: Dict[str, str] = {} self._ssl_verify = ssl_verify self._header_provider = header_provider + self._httpx_client_factory = httpx_client_factory self._logger = logger if should_parse_operation: self._operation_parser = OperationParser(self.operation) @@ -181,6 +203,7 @@ def from_parsed_operation( header_provider: Optional[ Callable[[ReadonlyContext], Dict[str, str]] ] = None, + httpx_client_factory: Optional[HttpxClientFactory] = None, ) -> "RestApiTool": """Initializes the RestApiTool from a ParsedOperation object. @@ -192,6 +215,9 @@ def from_parsed_operation( an argument, allowing dynamic header generation based on the current context. Useful for adding custom headers like correlation IDs, authentication tokens, or other request metadata. + httpx_client_factory: Optional zero-argument callable returning an + ``httpx.AsyncClient`` to be used for the API call. See + ``RestApiTool.__init__`` for details. Returns: A RestApiTool object. @@ -212,6 +238,7 @@ def from_parsed_operation( auth_credential=parsed.auth_credential, ssl_verify=ssl_verify, header_provider=header_provider, + httpx_client_factory=httpx_client_factory, ) generated._operation_parser = operation_parser return generated @@ -520,7 +547,9 @@ async def call( if provider_headers: request_params.setdefault("headers", {}).update(provider_headers) - response = await _request(**request_params) + response = await _request( + httpx_client_factory=self._httpx_client_factory, **request_params + ) # Log the API response self._logger.debug( @@ -569,9 +598,14 @@ def __repr__(self): ) -async def _request(**request_params) -> httpx.Response: - async with httpx.AsyncClient( - verify=request_params.pop("verify", True), - timeout=None, - ) as client: +async def _request( + *, + httpx_client_factory: Optional[HttpxClientFactory] = None, + **request_params, +) -> httpx.Response: + verify = request_params.pop("verify", True) + if httpx_client_factory is not None: + async with httpx_client_factory() as client: + return await client.request(**request_params) + async with httpx.AsyncClient(verify=verify, timeout=None) as client: return await client.request(**request_params) diff --git a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_openapi_toolset.py b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_openapi_toolset.py index 81c35b6964..d49ff2db77 100644 --- a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_openapi_toolset.py +++ b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_openapi_toolset.py @@ -153,6 +153,29 @@ def test_openapi_toolset_verify_on_init( assert all(tool._ssl_verify == verify_value for tool in toolset._tools) +def test_openapi_toolset_httpx_client_factory_on_init( + openapi_spec: Dict[str, Any], +): + """The httpx_client_factory is forwarded to every generated tool.""" + custom_factory = lambda: None # noqa: E731 - placeholder, never invoked here + toolset = OpenAPIToolset( + spec_dict=openapi_spec, httpx_client_factory=custom_factory + ) + assert toolset._httpx_client_factory is custom_factory + assert all( + tool._httpx_client_factory is custom_factory for tool in toolset._tools + ) + + +def test_openapi_toolset_httpx_client_factory_none_by_default( + openapi_spec: Dict[str, Any], +): + """httpx_client_factory is None on the toolset and each tool by default.""" + toolset = OpenAPIToolset(spec_dict=openapi_spec) + assert toolset._httpx_client_factory is None + assert all(tool._httpx_client_factory is None for tool in toolset._tools) + + def test_openapi_toolset_configure_verify_all(openapi_spec: Dict[str, Any]): """Test configure_verify_all method.""" toolset = OpenAPIToolset(spec_dict=openapi_spec) diff --git a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py index fa21201488..412d16f64e 100644 --- a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py +++ b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py @@ -1339,6 +1339,113 @@ async def test_call_without_header_provider( assert result == {"result": "success"} + def test_init_httpx_client_factory_none_by_default( + self, + sample_endpoint, + sample_operation, + ): + """httpx_client_factory is None by default.""" + tool = RestApiTool( + name="test_tool", + description="Test Tool", + endpoint=sample_endpoint, + operation=sample_operation, + ) + assert tool._httpx_client_factory is None + + def test_init_with_httpx_client_factory( + self, + sample_endpoint, + sample_operation, + ): + """A user-supplied httpx_client_factory is stored on the tool.""" + custom_factory = MagicMock() + tool = RestApiTool( + name="test_tool", + description="Test Tool", + endpoint=sample_endpoint, + operation=sample_operation, + httpx_client_factory=custom_factory, + ) + assert tool._httpx_client_factory is custom_factory + + @pytest.mark.asyncio + async def test_call_uses_custom_httpx_client_factory( + self, + mock_tool_context, + sample_endpoint, + sample_operation, + sample_auth_scheme, + sample_auth_credential, + ): + """When a factory is provided, its client is used to issue the request.""" + mock_response = mock.create_autospec(requests.Response, instance=True) + mock_response.json.return_value = {"result": "success"} + mock_response.configure_mock(status_code=200) + + mock_client = mock.create_autospec( + httpx.AsyncClient, instance=True, spec_set=True + ) + mock_client.request = AsyncMock(return_value=mock_response) + # Make the mock client work as an async context manager. + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + custom_factory = MagicMock(return_value=mock_client) + + tool = RestApiTool( + name="test_tool", + description="Test Tool", + endpoint=sample_endpoint, + operation=sample_operation, + auth_scheme=sample_auth_scheme, + auth_credential=sample_auth_credential, + httpx_client_factory=custom_factory, + ) + + with patch.object(httpx, "AsyncClient", autospec=True) as mock_default: + result = await tool.call(args={}, tool_context=mock_tool_context) + + # Factory must be invoked once and the default client must not be built. + custom_factory.assert_called_once_with() + mock_default.assert_not_called() + mock_client.request.assert_awaited_once() + assert result == {"result": "success"} + + @pytest.mark.asyncio + async def test_call_without_httpx_client_factory_uses_default_client( + self, + mock_tool_context, + sample_endpoint, + sample_operation, + sample_auth_scheme, + sample_auth_credential, + ): + """When no factory is provided, the default httpx.AsyncClient is used.""" + mock_response = mock.create_autospec(requests.Response, instance=True) + mock_response.json.return_value = {"result": "success"} + mock_response.configure_mock(status_code=200) + + mock_client = mock.create_autospec( + httpx.AsyncClient, instance=True, spec_set=True + ) + mock_client.request = AsyncMock(return_value=mock_response) + + tool = RestApiTool( + name="test_tool", + description="Test Tool", + endpoint=sample_endpoint, + operation=sample_operation, + auth_scheme=sample_auth_scheme, + auth_credential=sample_auth_credential, + ) + + with patch.object( + httpx, "AsyncClient", return_value=mock_client, autospec=True + ) as mock_async_client: + await tool.call(args={}, tool_context=mock_tool_context) + assert mock_async_client.called + def test_prepare_request_params_extracts_embedded_query_params( self, sample_auth_credential, sample_auth_scheme ):