From b6399d7e5630d30d3c2fc1a36e8f5f92df11e5ae Mon Sep 17 00:00:00 2001 From: Amy Wu Date: Mon, 16 Mar 2026 15:17:25 -0700 Subject: [PATCH] feat: Autoenable mTLS in environment with bound token (Agent Engine with AgentAuthority) through google-auth migration (except custom client args, custom client or custom ClientSession) PiperOrigin-RevId: 884659552 --- google/genai/_api_client.py | 266 +++++++++++++----- google/genai/client.py | 5 +- google/genai/errors.py | 59 +++- .../genai/tests/client/test_client_close.py | 8 + .../client/test_client_initialization.py | 33 ++- google/genai/tests/client/test_retries.py | 107 ++++++- pyproject.toml | 3 +- requirements.txt | 1 + 8 files changed, 378 insertions(+), 104 deletions(-) diff --git a/google/genai/_api_client.py b/google/genai/_api_client.py index 751211e51..6d723f3ab 100644 --- a/google/genai/_api_client.py +++ b/google/genai/_api_client.py @@ -1,4 +1,4 @@ -# Copyright 2025 Google LLC +# Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -34,7 +34,7 @@ import sys import threading import time -from typing import Any, AsyncIterator, Iterator, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, AsyncIterator, Iterator, Optional, TYPE_CHECKING, Tuple, Union from urllib.parse import urlparse from urllib.parse import urlunparse import warnings @@ -44,9 +44,13 @@ import google.auth import google.auth.credentials from google.auth.credentials import Credentials +from google.auth.transport import mtls +from google.auth.transport.requests import AuthorizedSession import httpx from pydantic import BaseModel from pydantic import ValidationError +import requests +from requests.structures import CaseInsensitiveDict import tenacity from . import _common @@ -65,6 +69,7 @@ # This try/except is for TAP, mypy complains about it which is why we have the type: ignore from websockets.client import connect as ws_connect # type: ignore + has_aiohttp = False try: import aiohttp @@ -76,6 +81,7 @@ if TYPE_CHECKING: from multidict import CIMultiDictProxy + from google.auth.aio.transport.sessions import AsyncAuthorizedSession logger = logging.getLogger('google_genai._api_client') @@ -181,13 +187,6 @@ def join_url_path(base_url: str, path: str) -> str: def load_auth(*, project: Union[str, None]) -> Tuple[Credentials, str]: """Loads google auth credentials and project id.""" - - ## Set GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES to false - ## to disable bound token sharing. Tracking on - ## https://github.com/googleapis/python-genai/issues/1956 - os.environ['GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES'] = ( - 'false' - ) credentials, loaded_project_id = google.auth.default( # type: ignore[no-untyped-call] scopes=['https://www.googleapis.com/auth/cloud-platform'], ) @@ -235,7 +234,12 @@ class HttpResponse: def __init__( self, - headers: Union[dict[str, str], httpx.Headers, 'CIMultiDictProxy[str]'], + headers: Union[ + dict[str, str], + httpx.Headers, + 'CIMultiDictProxy[str]', + CaseInsensitiveDict, + ], response_stream: Union[Any, str] = None, byte_stream: Union[Any, bytes] = None, ): @@ -245,6 +249,8 @@ def __init__( self.headers = { key: ', '.join(headers.get_list(key)) for key in headers.keys() } + elif isinstance(headers, CaseInsensitiveDict): + self.headers = {key: value for key, value in headers.items()} elif type(headers).__name__ == 'CIMultiDictProxy': self.headers = { key: ', '.join(headers.getall(key)) for key in headers.keys() @@ -321,7 +327,10 @@ def _copy_to_dict(self, response_payload: dict[str, object]) -> None: def _iter_response_stream(self) -> Iterator[str]: """Iterates over chunks retrieved from the API.""" - if not isinstance(self.response_stream, httpx.Response): + if not ( + isinstance(self.response_stream, httpx.Response) + or isinstance(self.response_stream, requests.Response) + ): raise TypeError( 'Expected self.response_stream to be an httpx.Response object, ' f'but got {type(self.response_stream).__name__}.' @@ -329,7 +338,15 @@ def _iter_response_stream(self) -> Iterator[str]: chunk = '' balance = 0 - for line in self.response_stream.iter_lines(): + if isinstance(self.response_stream, httpx.Response): + response_stream = self.response_stream.iter_lines() + else: + response_stream = self.response_stream.iter_lines(decode_unicode=True) + for line in response_stream: + if not isinstance(line, (str, bytes)): + print(f'Unexpected type yielded: {type(line)}') + continue + if not line: continue @@ -594,7 +611,10 @@ def __init__( elif http_options and _common.is_duck_type_of(http_options, HttpOptions): validated_http_options = http_options - if validated_http_options.base_url_resource_scope and not validated_http_options.base_url: + if ( + validated_http_options.base_url_resource_scope + and not validated_http_options.base_url + ): # base_url_resource_scope is only valid when base_url is set. raise ValueError( 'base_url must be set when base_url_resource_scope is set.' @@ -730,11 +750,22 @@ def __init__( self._http_options ) self._async_httpx_client_args = async_client_args + self._authorized_session: Optional[AuthorizedSession] = None - if self._http_options.httpx_client: + if self._use_google_auth_sync(): + self._httpx_client = None + elif self._http_options.httpx_client: self._httpx_client = self._http_options.httpx_client else: self._httpx_client = SyncHttpxClient(**client_args) + + if self._use_google_auth_async(): + self._async_httpx_client = None + elif self._http_options.httpx_async_client: + self._async_httpx_client = self._http_options.httpx_async_client + else: + self._async_httpx_client = AsyncHttpxClient(**async_client_args) + if self._http_options.httpx_async_client: self._async_httpx_client = self._http_options.httpx_async_client else: @@ -752,6 +783,14 @@ def __init__( self._async_client_session_request_args = ( self._ensure_aiohttp_ssl_ctx(self._http_options) ) + if self._use_google_auth_async(): + self._async_client_session_request_args['ssl'] = True # type: ignore[no-untyped-call] + self._async_client_session_request_args['max_allowed_time'] = float( + 'inf' + ) if self._http_options.timeout is None else float( + self._http_options.timeout + ) + self._async_client_session_request_args['total_attempts'] = 1 except ImportError: pass @@ -760,13 +799,53 @@ def __init__( self._retry = tenacity.Retrying(**retry_kwargs) self._async_retry = tenacity.AsyncRetrying(**retry_kwargs) - async def _get_aiohttp_session(self) -> 'aiohttp.ClientSession': + def _use_google_auth_sync(self) -> bool: + if not hasattr(mtls, 'should_use_client_cert'): + return False + return bool( + self.vertexai + and mtls.should_use_client_cert() # type: ignore[no-untyped-call] + and mtls.has_default_client_cert_source() # type: ignore[no-untyped-call] + and not ( + self._http_options.httpx_client or self._http_options.client_args + ) + ) + + def _use_google_auth_async(self) -> bool: + try: + from google.auth.aio.credentials import StaticCredentials + from google.auth.aio.transport.sessions import AsyncAuthorizedSession + except ImportError: + return False + return bool( + has_aiohttp + and self.vertexai + and mtls.should_use_client_cert() # type: ignore[no-untyped-call] + and mtls.has_default_client_cert_source() # type: ignore[no-untyped-call] + and not self._http_options.httpx_async_client + ) + + async def _get_aiohttp_session( + self, + ) -> Union['aiohttp.ClientSession', 'AsyncAuthorizedSession']: """Returns the aiohttp client session.""" - if ( + + if self._aiohttp_session is None and self._use_google_auth_async(): + try: + from google.auth.aio.credentials import StaticCredentials + from google.auth.aio.transport.sessions import AsyncAuthorizedSession + + async_creds = StaticCredentials(token=self._access_token()) # type: ignore[no-untyped-call] + self._aiohttp_session = AsyncAuthorizedSession(async_creds) # type: ignore[no-untyped-call] + return self._aiohttp_session + except ImportError: + pass + + if not self._use_google_auth_async() and ( self._aiohttp_session is None or self._aiohttp_session.closed - or self._aiohttp_session._loop.is_closed() # pylint: disable=protected-access - ): + or self._aiohttp_session._loop.is_closed() + ): # pylint: disable=protected-access # Initialize the aiohttp client session if it's not set up or closed. class AiohttpClientSession(aiohttp.ClientSession): # type: ignore[misc] @@ -803,6 +882,7 @@ def __del__(self, _warnings: Any = warnings) -> None: trust_env=True, read_bufsize=READ_BUFFER_SIZE, ) + return self._aiohttp_session @staticmethod @@ -883,16 +963,11 @@ def _ensure_aiohttp_ssl_ctx(options: HttpOptions) -> _common.StringDict: Returns: An async aiohttp ClientSession._request args. """ - verify = 'ssl' # keep it consistent with httpx. + verify = 'ssl' # keep it consistent with aiohttp. async_args = options.async_client_args ctx = async_args.get(verify) if async_args else None if not ctx: - # Initialize the SSL context for the httpx client. - # Unlike requests, the aiohttp package does not automatically pull in the - # environment variables SSL_CERT_FILE or SSL_CERT_DIR. They need to be - # enabled explicitly. Instead of 'verify' at client level in httpx, - # aiohttp uses 'ssl' at request level. ctx = ssl.create_default_context( cafile=os.environ.get('SSL_CERT_FILE', certifi.where()), capath=os.environ.get('SSL_CERT_DIR'), @@ -1191,31 +1266,46 @@ def _request_once( else: data = http_request.data - if stream: - httpx_request = self._httpx_client.build_request( - method=http_request.method, - url=http_request.url, - content=data, + if self._use_google_auth_sync(): + url = str(http_request.url) + if self._authorized_session is None: + self._authorized_session = AuthorizedSession( # type: ignore[no-untyped-call] + self._credentials, + max_refresh_attempts=1, + ) + client_cert_source = mtls.default_client_cert_source() # type: ignore[no-untyped-call] + self._authorized_session.configure_mtls_channel( + client_cert_source + ) # type: ignore[no-untyped-call] + if self._authorized_session._is_mtls and 'googleapis.com' in url: + if 'sandbox' in url: + url = url.replace( + 'sandbox.googleapis.com', 'mtls.sandbox.googleapis.com' + ) + else: + url = url.replace('googleapis.com', 'mtls.googleapis.com') + print('request.url: %s' % url) + response = self._authorized_session.request( # type: ignore[no-untyped-call] + method=http_request.method.upper(), + url=url, + data=data, headers=http_request.headers, timeout=http_request.timeout, - ) - response = self._httpx_client.send(httpx_request, stream=stream) - errors.APIError.raise_for_response(response) - return HttpResponse( - response.headers, response if stream else [response.text] + stream=stream, ) else: - response = self._httpx_client.request( + httpx_request = self._httpx_client.build_request( # type: ignore[union-attr] method=http_request.method, url=http_request.url, - headers=http_request.headers, content=data, + headers=http_request.headers, timeout=http_request.timeout, ) - errors.APIError.raise_for_response(response) - return HttpResponse( - response.headers, response if stream else [response.text] - ) + response = self._httpx_client.send(httpx_request, stream=stream) # type: ignore[union-attr] + errors.APIError.raise_for_response(response) + return HttpResponse( + response.headers, response if stream else [response.text] + ) def _request( self, @@ -1240,7 +1330,7 @@ def _request( async def _async_request_once( self, http_request: HttpRequest, stream: bool = False ) -> HttpResponse: - data: Optional[Union[str, bytes]] = None + data: Optional[bytes] = None # If using proj/location, fetch ADC if self.vertexai and (self.project or self.location): @@ -1251,21 +1341,33 @@ async def _async_request_once( http_request.headers['x-goog-user-project'] = ( self._credentials.quota_project_id ) - data = json.dumps(http_request.data) if http_request.data else None - else: - if http_request.data: - if not isinstance(http_request.data, bytes): - data = json.dumps(http_request.data) if http_request.data else None - else: - data = http_request.data + if http_request.data: + if not isinstance(http_request.data, bytes): + data = json.dumps(http_request.data).encode('utf-8') + else: + data = http_request.data if stream: if self._use_aiohttp(): self._aiohttp_session = await self._get_aiohttp_session() + url = http_request.url + if self._use_google_auth_async(): + client_cert_source = mtls.default_client_cert_source() # type: ignore[no-untyped-call] + await self._aiohttp_session.configure_mtls_channel( # type: ignore[union-attr] + client_cert_source + ) + if self._aiohttp_session._is_mtls and 'googleapis.com' in url: # type: ignore[union-attr] + if 'sandbox' in url: + url = url.replace( + 'sandbox.googleapis.com', 'mtls.sandbox.googleapis.com' + ) + else: + url = url.replace('googleapis.com', 'mtls.googleapis.com') try: + print('async request.url: %s' % url) response = await self._aiohttp_session.request( method=http_request.method, - url=http_request.url, + url=url, headers=http_request.headers, data=data, timeout=aiohttp.ClientTimeout(total=http_request.timeout), @@ -1287,7 +1389,7 @@ async def _async_request_once( self._aiohttp_session = await self._get_aiohttp_session() response = await self._aiohttp_session.request( method=http_request.method, - url=http_request.url, + url=url, headers=http_request.headers, data=data, timeout=aiohttp.ClientTimeout(total=http_request.timeout), @@ -1295,17 +1397,21 @@ async def _async_request_once( ) await errors.APIError.raise_for_async_response(response) + if hasattr(response, '_response'): + # Extract the underlying aiohttp.ClientResponse from the + # AsyncAuthorizedSession Response. + response = response._response return HttpResponse(response.headers, response) else: # aiohttp is not available. Fall back to httpx. - httpx_request = self._async_httpx_client.build_request( + httpx_request = self._async_httpx_client.build_request( # type: ignore[union-attr] method=http_request.method, url=http_request.url, content=data, headers=http_request.headers, timeout=http_request.timeout, ) - client_response = await self._async_httpx_client.send( + client_response = await self._async_httpx_client.send( # type: ignore[union-attr] httpx_request, stream=stream, ) @@ -1314,17 +1420,34 @@ async def _async_request_once( else: if self._use_aiohttp(): self._aiohttp_session = await self._get_aiohttp_session() + url = http_request.url + if self._use_google_auth_async(): + await self._aiohttp_session.configure_mtls_channel() # type: ignore[union-attr] + if self._aiohttp_session._is_mtls and 'googleapis.com' in url: # type: ignore[union-attr] + if 'sandbox' in url: + url = url.replace( + 'sandbox.googleapis.com', 'mtls.sandbox.googleapis.com' + ) + else: + url = url.replace('googleapis.com', 'mtls.googleapis.com') try: + print('request.url: %s' % url) response = await self._aiohttp_session.request( method=http_request.method, - url=http_request.url, + url=url, headers=http_request.headers, data=data, timeout=aiohttp.ClientTimeout(total=http_request.timeout), **self._async_client_session_request_args, ) await errors.APIError.raise_for_async_response(response) - return HttpResponse(response.headers, [await response.text()]) + unwrapped_response: Any = response + if hasattr(unwrapped_response, '_response'): + unwrapped_response = unwrapped_response._response + + return HttpResponse( + unwrapped_response.headers, [await unwrapped_response.text()] + ) except ( aiohttp.ClientConnectorError, aiohttp.ClientConnectorDNSError, @@ -1341,17 +1464,24 @@ async def _async_request_once( self._aiohttp_session = await self._get_aiohttp_session() response = await self._aiohttp_session.request( method=http_request.method, - url=http_request.url, + url=url, headers=http_request.headers, data=data, timeout=aiohttp.ClientTimeout(total=http_request.timeout), **self._async_client_session_request_args, ) await errors.APIError.raise_for_async_response(response) - return HttpResponse(response.headers, [await response.text()]) + unwrapped_retry_response: Any = response + if hasattr(unwrapped_retry_response, '_response'): + unwrapped_retry_response = unwrapped_retry_response._response + + return HttpResponse( + unwrapped_retry_response.headers, + [await unwrapped_retry_response.text()], + ) else: # aiohttp is not available. Fall back to httpx. - client_response = await self._async_httpx_client.request( + client_response = await self._async_httpx_client.request( # type: ignore[union-attr] method=http_request.method, url=http_request.url, headers=http_request.headers, @@ -1591,7 +1721,7 @@ def _upload_fd( populate_server_timeout_header(upload_headers, timeout_in_seconds) retry_count = 0 while retry_count < MAX_RETRY_COUNT: - response = self._httpx_client.request( + response = self._httpx_client.request( # type: ignore[union-attr] method='POST', url=upload_url, headers=upload_headers, @@ -1643,7 +1773,7 @@ def download_file( else: data = http_request.data - response = self._httpx_client.request( + response = self._httpx_client.request( # type: ignore[union-attr] method=http_request.method, url=http_request.url, headers=http_request.headers, @@ -1805,7 +1935,7 @@ async def _async_upload_fd( 'Failed to upload file: Upload status is not finalized.' ) return HttpResponse( - response.headers, response_stream=[await response.text()] + response.headers, response_stream=[await response.text()] # type: ignore[union-attr] ) else: # aiohttp is not available. Fall back to httpx. @@ -1850,7 +1980,7 @@ async def _async_upload_fd( retry_count = 0 client_response = None while retry_count < MAX_RETRY_COUNT: - client_response = await self._async_httpx_client.request( + client_response = await self._async_httpx_client.request( # type: ignore[union-attr] method='POST', url=upload_url, content=file_chunk, @@ -1911,10 +2041,10 @@ async def async_download_file( 'get', path=path, request_dict={}, http_options=http_options ) - data: Optional[Union[str, bytes]] = None + data: Optional[bytes] = None if http_request.data: if not isinstance(http_request.data, bytes): - data = json.dumps(http_request.data) + data = json.dumps(http_request.data).encode('utf-8') else: data = http_request.data @@ -1934,7 +2064,7 @@ async def async_download_file( ).byte_stream[0] else: # aiohttp is not available. Fall back to httpx. - client_response = await self._async_httpx_client.request( + client_response = await self._async_httpx_client.request( # type: ignore[union-attr] method=http_request.method, url=http_request.url, headers=http_request.headers, @@ -1957,15 +2087,17 @@ def close(self) -> None: """Closes the API client.""" # Let users close the custom client explicitly by themselves. Otherwise, # close the client when the object is garbage collected. - if not self._http_options.httpx_client: + if not self._http_options.httpx_client and self._httpx_client: self._httpx_client.close() + if self._authorized_session: + self._authorized_session.close() # type: ignore[no-untyped-call] async def aclose(self) -> None: """Closes the API async client.""" # Let users close the custom client explicitly by themselves. Otherwise, # close the client when the object is garbage collected. if not self._http_options.httpx_async_client: - await self._async_httpx_client.aclose() + await self._async_httpx_client.aclose() # type: ignore[union-attr] if self._aiohttp_session and not self._http_options.aiohttp_client: await self._aiohttp_session.close() @@ -1987,6 +2119,7 @@ def __del__(self) -> None: except Exception: # pylint: disable=broad-except pass + def get_token_from_credentials( client: 'BaseApiClient', credentials: google.auth.credentials.Credentials @@ -1999,6 +2132,7 @@ def get_token_from_credentials( raise RuntimeError('Could not resolve API token from the environment') return credentials.token # type: ignore[no-any-return] + async def async_get_token_from_credentials( client: 'BaseApiClient', credentials: google.auth.credentials.Credentials diff --git a/google/genai/client.py b/google/genai/client.py index cec5cef4c..66ef247a7 100644 --- a/google/genai/client.py +++ b/google/genai/client.py @@ -148,7 +148,9 @@ def _nextgen_client(self) -> AsyncGeminiNextGenAPIClient: stacklevel=5, ) - http_client: httpx.AsyncClient = self._api_client._async_httpx_client + http_client: Optional[httpx.AsyncClient] = ( + self._api_client._async_httpx_client + ) async_client_args = self._api_client._http_options.async_client_args or {} has_custom_transport = 'transport' in async_client_args @@ -308,7 +310,6 @@ class DebugConfig(pydantic.BaseModel): ) - class Client: """Client for making synchronous requests. diff --git a/google/genai/errors.py b/google/genai/errors.py index 63d9334b9..48bf1b131 100644 --- a/google/genai/errors.py +++ b/google/genai/errors.py @@ -18,19 +18,25 @@ from typing import Any, Callable, Optional, TYPE_CHECKING, Union import httpx import json -import websockets +import requests from . import _common if TYPE_CHECKING: from .replay_api_client import ReplayResponse import aiohttp + from google.auth.aio.transport.aiohttp import Response as AsyncAuthorizedSessionResponse class APIError(Exception): """General errors raised by the GenAI API.""" code: int - response: Union['ReplayResponse', httpx.Response, 'aiohttp.ClientResponse'] + response: Union[ + requests.Response, + 'ReplayResponse', + httpx.Response, + 'AsyncAuthorizedSessionResponse', + ] status: Optional[str] = None message: Optional[str] = None @@ -40,7 +46,12 @@ def __init__( code: int, response_json: Any, response: Optional[ - Union['ReplayResponse', httpx.Response, 'aiohttp.ClientResponse'] + Union[ + requests.Response, + 'ReplayResponse', + httpx.Response, + 'AsyncAuthorizedSessionResponse', + ] ] = None, ): if isinstance(response_json, list) and len(response_json) == 1: @@ -112,7 +123,7 @@ def _to_replay_record(self) -> _common.StringDict: @classmethod def raise_for_response( - cls, response: Union['ReplayResponse', httpx.Response] + cls, response: Union['ReplayResponse', httpx.Response, requests.Response] ) -> None: """Raises an error with detailed error message if the response has an error status.""" if response.status_code == 200: @@ -128,6 +139,16 @@ def raise_for_response( 'message': message, 'status': response.reason_phrase, } + elif isinstance(response, requests.Response): + try: + # do not do any extra muanipulation on the response. + # return the raw response json as is. + response_json = response.json() + except requests.exceptions.JSONDecodeError: + response_json = { + 'message': response.text, + 'status': response.reason, + } else: response_json = response.body_segments[0].get('error', {}) @@ -139,7 +160,11 @@ def raise_error( status_code: int, response_json: Any, response: Optional[ - Union['ReplayResponse', httpx.Response, 'aiohttp.ClientResponse'] + Union[ + 'ReplayResponse', + httpx.Response, + requests.Response, + ] ], ) -> None: """Raises an appropriate APIError subclass based on the status code. @@ -166,12 +191,13 @@ def raise_error( async def raise_for_async_response( cls, response: Union[ - 'ReplayResponse', httpx.Response, 'aiohttp.ClientResponse' + 'ReplayResponse', + httpx.Response, + 'aiohttp.ClientResponse', + 'AsyncAuthorizedSessionResponse', ], ) -> None: """Raises an error with detailed error message if the response has an error status.""" - status_code = 0 - response_json = None if isinstance(response, httpx.Response): if response.status_code == 200: return @@ -196,18 +222,23 @@ async def raise_for_async_response( try: import aiohttp # pylint: disable=g-import-not-at-top - if isinstance(response, aiohttp.ClientResponse): - if response.status == 200: + # Use a local variable to help Mypy handle the unwrapped response + unwrapped_response: Any = response + if hasattr(unwrapped_response, '_response'): + unwrapped_response = unwrapped_response._response + + if isinstance(unwrapped_response, aiohttp.ClientResponse): + if unwrapped_response.status == 200: return try: - response_json = await response.json() + response_json = await unwrapped_response.json() except aiohttp.client_exceptions.ContentTypeError: - message = await response.text() + message = await unwrapped_response.text() response_json = { 'message': message, - 'status': response.reason, + 'status': unwrapped_response.reason, } - status_code = response.status + status_code = unwrapped_response.status else: raise ValueError(f'Unsupported response type: {type(response)}') except ImportError: diff --git a/google/genai/tests/client/test_client_close.py b/google/genai/tests/client/test_client_close.py index 2beaf7ea8..00e2bc6bc 100644 --- a/google/genai/tests/client/test_client_close.py +++ b/google/genai/tests/client/test_client_close.py @@ -43,6 +43,7 @@ def test_close_httpx_client(): vertexai=True, project='test_project', location='global', + http_options=api_client.HttpOptions(client_args={'max_redirects': 10}), ) client.close() assert client._api_client._httpx_client.is_closed @@ -55,6 +56,7 @@ def test_httpx_client_context_manager(): vertexai=True, project='test_project', location='global', + http_options=api_client.HttpOptions(client_args={'max_redirects': 10}), ) as client: pass assert not client._api_client._httpx_client.is_closed @@ -135,6 +137,9 @@ async def run(): vertexai=True, project='test_project', location='global', + http_options=api_client.HttpOptions( + async_client_args={'trust_env': False} + ), ).aio # aiohttp session is created in the first request instead of client # initialization. @@ -176,6 +181,9 @@ async def run(): vertexai=True, project='test_project', location='global', + http_options=api_client.HttpOptions( + async_client_args={'trust_env': False} + ), ).aio as async_client: # aiohttp session is created in the first request instead of client # initialization. diff --git a/google/genai/tests/client/test_client_initialization.py b/google/genai/tests/client/test_client_initialization.py index 7b0136044..4b9e5c98f 100644 --- a/google/genai/tests/client/test_client_initialization.py +++ b/google/genai/tests/client/test_client_initialization.py @@ -20,6 +20,7 @@ import concurrent.futures import logging import os +import requests import ssl from unittest import mock @@ -1331,18 +1332,32 @@ def refresh_side_effect(request): mock_refresh = mock.Mock(side_effect=refresh_side_effect) mock_creds.refresh = mock_refresh - # Mock the actual request to avoid network calls - mock_httpx_response = httpx.Response( - status_code=200, - headers={}, - text='{"candidates": [{"content": {"parts": [{"text": "response"}]}}]}', - ) - mock_request = mock.Mock(return_value=mock_httpx_response) - monkeypatch.setattr(api_client.SyncHttpxClient, "request", mock_request) - client = Client( vertexai=True, project="fake_project_id", location="fake-location" ) + # Mock the actual request to avoid network calls + if client._api_client._use_google_auth_sync(): + # Cloud environment enables mTLS and uses requests.Response + mock_http_response = requests.Response() + mock_http_response.status_code = 200 + mock_http_response.headers = {} + mock_http_response._content = ( + b'{"candidates": [{"content": {"parts": [{"text": "response"}]}}]}' + ) + mock_request = mock.Mock(return_value=mock_http_response) + monkeypatch.setattr( + google.auth.transport.requests.AuthorizedSession, "request", mock_request + ) + else: + # Non-cloud environment w/o certificates uses httpx.Response + mock_httpx_response = httpx.Response( + status_code=200, + headers={}, + text='{"candidates": [{"content": {"parts": [{"text": "response"}]}}]}', + ) + mock_request = mock.Mock(return_value=mock_httpx_response) + monkeypatch.setattr(api_client.SyncHttpxClient, "send", mock_request) + # Reset credentials to test initialization to ensure the sync lock is tested. client._api_client._credentials = None diff --git a/google/genai/tests/client/test_retries.py b/google/genai/tests/client/test_retries.py index 125354771..49af5264b 100644 --- a/google/genai/tests/client/test_retries.py +++ b/google/genai/tests/client/test_retries.py @@ -20,12 +20,20 @@ import datetime from unittest import mock import pytest + try: - import aiohttp - AIOHTTP_NOT_INSTALLED = False + import aiohttp + from google.auth.aio.transport.aiohttp import Response as AsyncAuthorizedSessionResponse + from google.auth.aio.transport.sessions import AsyncAuthorizedSession + from google.auth.aio.credentials import StaticCredentials + + AIOHTTP_NOT_INSTALLED = False except ImportError: - AIOHTTP_NOT_INSTALLED = True - aiohttp = mock.MagicMock() + AIOHTTP_NOT_INSTALLED = True + aiohttp = mock.MagicMock() + AsyncAuthorizedSessionResponse = mock.MagicMock() + StaticCredentials = mock.MagicMock() + AsyncAuthorizedSession = mock.MagicMock() from google.oauth2 import credentials import httpx @@ -614,7 +622,6 @@ async def run(): # Async aiohttp - async def _aiohttp_async_response(status: int, streamable: bool = False): """Has to return a coroutine hence async.""" response = mock.Mock(spec=aiohttp.ClientResponse) @@ -725,7 +732,10 @@ async def run(): project='test_project', location='global', http_options=_transport_options( - http_options=types.HttpOptions(retry_options=_RETRY_OPTIONS), + http_options=types.HttpOptions( + retry_options=_RETRY_OPTIONS, + async_client_args={'trust_env': False}, + ), ), ) @@ -740,7 +750,11 @@ async def run(): @requires_aiohttp -@mock.patch.object(aiohttp.ClientSession, 'request', autospec=True) +@mock.patch.object( + aiohttp.ClientSession, + 'request', + autospec=True, +) def test_aiohttp_retries_failed_request_retries_successfully_at_request_level( mock_request, ): @@ -789,7 +803,10 @@ async def run(): project='test_project', location='global', http_options=_transport_options( - http_options=types.HttpOptions(retry_options=_RETRY_OPTIONS), + http_options=types.HttpOptions( + retry_options=_RETRY_OPTIONS, + async_client_args={'trust_env': False}, + ), ), ) @@ -807,7 +824,11 @@ async def run(): @requires_aiohttp -@mock.patch.object(aiohttp.ClientSession, 'request', autospec=True) +@mock.patch.object( + aiohttp.ClientSession, + 'request', + autospec=True, +) def test_aiohttp_retries_failed_request_retries_unsuccessfully_at_request_level( mock_request, ): @@ -1160,7 +1181,10 @@ async def run(): http_method='GET', path='path', request_dict={}, - http_options=types.HttpOptions(retry_options=_RETRY_OPTIONS), + http_options=types.HttpOptions( + retry_options=_RETRY_OPTIONS, + async_client_args={'trust_env': False}, + ), ) async for _ in stream: pass @@ -1171,7 +1195,11 @@ async def run(): @requires_aiohttp -@mock.patch.object(aiohttp.ClientSession, 'request', autospec=True) +@mock.patch.object( + aiohttp.ClientSession, + 'request', + autospec=True, +) def test_aiohttp_retries_streamed_failed_request_retries_unsuccessfully( mock_request, ): @@ -1224,6 +1252,9 @@ async def run(): vertexai=True, project='test_project', location='global', + http_options=types.HttpOptions( + async_client_args={'trust_env': False}, + ), ) with _patch_auth_default(): @@ -1245,7 +1276,11 @@ async def run(): @requires_aiohttp -@mock.patch.object(aiohttp.ClientSession, 'request', autospec=True) +@mock.patch.object( + aiohttp.ClientSession, + 'request', + autospec=True, +) def test_aiohttp_retries_client_connector_error_retries_successfully( mock_request, ): @@ -1278,3 +1313,51 @@ async def run(): assert response.headers['status-code'] == '200' asyncio.run(run()) + + +@requires_aiohttp +@mock.patch.object(AsyncAuthorizedSession, 'request', autospec=True) +def test_aiohttp_retries_failed_request_retries_unsuccessfully_mtls( + mock_request, +): + api_client.has_aiohttp = True + + async def run(): + # 1. Setup mocked aiohttp responses + res429 = await _aiohttp_async_response(429) + res504 = await _aiohttp_async_response(504) + + # 2. Wrap them in the AsyncAuthorizedSessionResponse expected by the SDK + mock_auth_res429 = mock.Mock(spec=AsyncAuthorizedSessionResponse) + mock_auth_res429._response = res429 + + mock_auth_res504 = mock.Mock(spec=AsyncAuthorizedSessionResponse) + mock_auth_res504._response = res504 + + # AsyncAuthorizedSession.request is an async method + mock_request.side_effect = [mock_auth_res429, mock_auth_res504] + + client = api_client.BaseApiClient( + vertexai=True, + project='test_project', + location='global', + http_options=types.HttpOptions( + retry_options=_RETRY_OPTIONS, + ), + ) + + # Force the mTLS path to be active for this test + with mock.patch( + 'google.auth.transport.mtls.should_use_client_cert', return_value=True + ): + with _patch_auth_default(): + try: + await client.async_request( + http_method='GET', path='path', request_dict={} + ) + assert False, 'Expected APIError to be raised.' + except errors.APIError as e: + assert e.code == 504 + mock_request.assert_called() + + asyncio.run(run()) diff --git a/pyproject.toml b/pyproject.toml index 8ff01c1fa..df423733f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ ] dependencies = [ "anyio>=4.8.0, <5.0.0", - "google-auth[requests]>=2.47.0, <3.0.0", + "google-auth[requests]>=2.48.1, <3.0.0", "httpx>=0.28.1, <1.0.0", "pydantic>=2.9.0, <3.0.0", "requests>=2.28.1, <3.0.0", @@ -40,6 +40,7 @@ dependencies = [ [project.optional-dependencies] aiohttp = ["aiohttp>=3.10.11, <4.0.0"] local-tokenizer = ["sentencepiece>=0.2.0", "protobuf"] +pyopenssl = ["pyopenssl"] [project.urls] Homepage = "https://github.com/googleapis/python-genai" diff --git a/requirements.txt b/requirements.txt index b26403dc2..322d9fe27 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,6 +18,7 @@ pyasn1==0.6.1 pyasn1_modules==0.4.1 pydantic==2.12.0 pydantic_core==2.41.1 +pyopenssl==24.2.1 pytest==8.3.4 pytest-asyncio==0.25.0 pytest-cov==6.0.0