diff --git a/.github/.release-please-manifest.json b/.github/.release-please-manifest.json index f16e9b1aea..9a3ece4da7 100644 --- a/.github/.release-please-manifest.json +++ b/.github/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "1.32.0" + ".": "1.33.0" } diff --git a/.github/release-please-config.json b/.github/release-please-config.json index 8122ea8f75..b25f273fed 100644 --- a/.github/release-please-config.json +++ b/.github/release-please-config.json @@ -1,6 +1,6 @@ { "$schema": "https://raw.githubusercontent.com/googleapis/release-please/main/schemas/config.json", - "last-release-sha": "5e49cfa6567a09e06409b0f380434f12f85a17c9", + "last-release-sha": "88421f80a0b008e90f18401abca4ceec3548f6cd", "packages": { ".": { "release-type": "python", diff --git a/.github/scripts/constant.js b/.github/scripts/constant.js index 3e90cb8842..142efb7838 100644 --- a/.github/scripts/constant.js +++ b/.github/scripts/constant.js @@ -17,7 +17,6 @@ limitations under the License. let CONSTANT_VALUES = { GLOBALS: { LABELS: { - STALE: 'stale', BUG: 'bug', CORE: 'core', TOOLS: 'tools', @@ -33,9 +32,7 @@ let CONSTANT_VALUES = { EVAL: 'eval', TRACING: 'tracing', WEB: 'web', - WORKFLOW: 'workflow', - REQUEST_CLARIFICATION: 'request clarification', - NEEDS_REVIEW: 'needs review' + WORKFLOW: 'workflow' }, STATE: { CLOSED: 'closed' } }, @@ -52,4 +49,4 @@ let CONSTANT_VALUES = { } }; -module.exports = CONSTANT_VALUES; \ No newline at end of file +module.exports = CONSTANT_VALUES; diff --git a/.github/scripts/csat.js b/.github/scripts/csat.js index 54356fd69d..0c0168f861 100644 --- a/.github/scripts/csat.js +++ b/.github/scripts/csat.js @@ -13,59 +13,47 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + const CONSTANT_VALUES = require('./constant'); /** - * Invoked from stale_csat.js and csat.yaml file to post survey link - * in closed issue. + * Invoked from csat.yml workflow file to post survey link + * in closed issues. * @param {!Object.} github contains pre defined functions. * context Information about the workflow run. * @return {null} */ module.exports = async ({ github, context }) => { const issue = context.payload.issue.html_url; - let baseUrl = ''; - // Loop over all ths label present in issue and check if specific label is - // present for survey link. - for (const label of context.payload.issue.labels) { - if (label.name.includes(CONSTANT_VALUES.GLOBALS.LABELS.BUG) || - label.name.includes(CONSTANT_VALUES.GLOBALS.LABELS.CORE) || - label.name.includes(CONSTANT_VALUES.GLOBALS.LABELS.TOOLS) || - label.name.includes(CONSTANT_VALUES.GLOBALS.LABELS.SERVICES) || - label.name.includes(CONSTANT_VALUES.GLOBALS.LABELS.MODELS) || - label.name.includes(CONSTANT_VALUES.GLOBALS.LABELS.MCP) || - label.name.includes(CONSTANT_VALUES.GLOBALS.LABELS.AUTH) || - label.name.includes(CONSTANT_VALUES.GLOBALS.LABELS.LIVE) || - label.name.includes(CONSTANT_VALUES.GLOBALS.LABELS.DOCUMENTATION) || - label.name.includes(CONSTANT_VALUES.GLOBALS.LABELS.GOOD_FIRST_ISSUE) || - label.name.includes(CONSTANT_VALUES.GLOBALS.LABELS.AGENT_ENGINE) || - label.name.includes(CONSTANT_VALUES.GLOBALS.LABELS.BQ) || - label.name.includes(CONSTANT_VALUES.GLOBALS.LABELS.EVAL) || - label.name.includes(CONSTANT_VALUES.GLOBALS.LABELS.TRACING) || - label.name.includes(CONSTANT_VALUES.GLOBALS.LABELS.WEB) || - label.name.includes(CONSTANT_VALUES.GLOBALS.LABELS.WORKFLOW) || - label.name.includes(CONSTANT_VALUES.GLOBALS.LABELS.REQUEST_CLARIFICATION) || - label.name.includes(CONSTANT_VALUES.GLOBALS.LABELS.NEEDS_REVIEW)) { - console.log( - `label-${label.name}, posting CSAT survey for issue =${issue}`); - baseUrl = CONSTANT_VALUES.MODULE.CSAT.BASE_URL; - const yesCsat = ` ${CONSTANT_VALUES.MODULE.CSAT.YES}`; + // Check if any label matches (case-insensitive) the supported CSAT labels. + const supportedLabels = Object.values(CONSTANT_VALUES.GLOBALS.LABELS); + const hasMatchingLabel = context.payload.issue.labels.some(label => { + const name = label.name.toLowerCase(); + return supportedLabels.some(supportedLabel => name.includes(supportedLabel)); + }); + + if (hasMatchingLabel) { + console.log(`Posting CSAT survey for issue =${issue}`); + const baseUrl = CONSTANT_VALUES.MODULE.CSAT.BASE_URL; + + const yesCsat = ` ${CONSTANT_VALUES.MODULE.CSAT.YES}`; + + const noCsat = ` ${CONSTANT_VALUES.MODULE.CSAT.NO}`; + + const comment = CONSTANT_VALUES.MODULE.CSAT.MSG + '\n' + yesCsat + '\n' + + noCsat + '\n'; + const issueNumber = context.issue.number ?? context.payload.issue.number; - const noCsat = ` ${CONSTANT_VALUES.MODULE.CSAT.NO}`; - const comment = CONSTANT_VALUES.MODULE.CSAT.MSG + '\n' + yesCsat + '\n' + - noCsat + '\n'; - let issueNumber = context.issue.number ?? context.payload.issue.number; - await github.rest.issues.createComment({ - issue_number: issueNumber, - owner: context.repo.owner, - repo: context.repo.repo, - body: comment - }); - } + await github.rest.issues.createComment({ + issue_number: issueNumber, + owner: context.repo.owner, + repo: context.repo.repo, + body: comment + }); } }; diff --git a/.github/workflows/csat.yml b/.github/workflows/csat.yml index 5da6ff909f..24569d6646 100644 --- a/.github/workflows/csat.yml +++ b/.github/workflows/csat.yml @@ -7,7 +7,6 @@ on: permissions: contents: read issues: write - pull-requests: write jobs: welcome: @@ -18,4 +17,4 @@ jobs: with: script: | const script = require('./.github/scripts/csat.js') - script({github, context}) \ No newline at end of file + script({github, context}) diff --git a/CHANGELOG.md b/CHANGELOG.md index a57085d100..baa2a92d68 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,38 @@ # Changelog +## [1.33.0](https://github.com/google/adk-python/compare/v1.32.0...v1.33.0) (2026-05-08) + + +### Features + +* add BufferableSessionService ([0bc767e](https://github.com/google/adk-python/commit/0bc767e6892742d6290d3445d028f95925187aed)) +* **apigee:** allow injecting credentials into ApigeeLlm ([ce578ff](https://github.com/google/adk-python/commit/ce578fffa0dc02b0033f7f5e705b9422cbd6c252)) +* Make ADK environment tools truncation limit configurable ([83ae405](https://github.com/google/adk-python/commit/83ae40525aa734f4a3b365614cce43831612a1ec)) +* **models:** add get_function_calls and get_function_responses to LlmResponse ([22fae7e](https://github.com/google/adk-python/commit/22fae7e9a09c581f433f3c51ea9a0ab26e689b92)) + + +### Bug Fixes + +* catch genai.ClientError when sandbox is missing ([69fa777](https://github.com/google/adk-python/commit/69fa777881b3cb161e5b3dcb005def9a2ad86904)), closes [#5480](https://github.com/google/adk-python/issues/5480) +* double append bug ([f8b4c59](https://github.com/google/adk-python/commit/f8b4c59350fea3319c9e53e29968c56c93c57c99)) +* Filter out video events with inline data from being stored in session ([88421f8](https://github.com/google/adk-python/commit/88421f80a0b008e90f18401abca4ceec3548f6cd)) +* fix fork detection, correct offload limits, and add response logging in BigQuery plugin ([9d1bb4b](https://github.com/google/adk-python/commit/9d1bb4b4870233e574f5c06ddd2b62a48272398f)) +* hot reload agents for adk web ([740557c](https://github.com/google/adk-python/commit/740557c8965305abc75752082bc3ee63d924742f)) +* Only append skills to system instruction if ListSkillsTool isn't available ([01f1fc9](https://github.com/google/adk-python/commit/01f1fc9c912a97ff27bb1332a28324f991eae77d)) +* prevent state_delta overwrite on function_response-only events ([fc27203](https://github.com/google/adk-python/commit/fc2720378e8997269d30f5439051f5e43d5fa028), [211e2ce](https://github.com/google/adk-python/commit/211e2ceb70ac6b61400559761d1d6548d906a79b)), closes [#3178](https://github.com/google/adk-python/issues/3178) +* Raise a clear actionable error when CustomAuthScheme lacks a registered AuthProvider ([83f9817](https://github.com/google/adk-python/commit/83f981761b963ca51a286cbd004c043567517a3c)) +* should use app_name instead of req.app_name ([8286066](https://github.com/google/adk-python/commit/8286066e71e5c07b5b28979b8327d4b330187ddd)) +* **simulation:** Add error message when LlmBackedUserSimulator returns empty response ([fb92aad](https://github.com/google/adk-python/commit/fb92aad9c53bb9f6706fb27751d71fcda2419500)) +* Update expressmode api call to include default api key param ([e833995](https://github.com/google/adk-python/commit/e8339953911a8b580cfc2d88c7008234a43beece)) +* use asyncio.sleep to avoid blocking event loop ([3a1eadc](https://github.com/google/adk-python/commit/3a1eadce66804db08f6520cc11f9c60e81bb9e30)) +* Use project and location instead of API key when deploying to agent engine ([398f28f](https://github.com/google/adk-python/commit/398f28feb47d87ec9c4c03dd3e0e7b87a1699e6e)) + + +### Code Refactoring + +* adjust computation of workflow.steps metric and add new unit tests ([03d6208](https://github.com/google/adk-python/commit/03d6208aacac8c19adec45ce0dd837f9e3a7f66f)) +* remove input.type and output.type attributes from adk metrics ([9559968](https://github.com/google/adk-python/commit/95599683230dd13e5792133f30ade3fe19358d52)) + ## [1.32.0](https://github.com/google/adk-python/compare/v1.31.0...v1.32.0) (2026-04-30) diff --git a/contributing/samples/hello_world/agent.py b/contributing/samples/hello_world/agent.py index 01def21ad1..d5cfc3ef5f 100755 --- a/contributing/samples/hello_world/agent.py +++ b/contributing/samples/hello_world/agent.py @@ -65,7 +65,7 @@ async def check_prime(nums: list[int]) -> str: root_agent = Agent( - model='projects/adk-cat/locations/us-central1/publishers/google/models/gemini-2.5-flash', + model='gemini-3-flash-preview', name='hello_world_agent', description=( 'hello world agent that can roll a dice of 8 sides and check prime' diff --git a/contributing/samples/session_state_agent/agent.py b/contributing/samples/session_state_agent/agent.py index 6c03de8e90..7b29a90c09 100644 --- a/contributing/samples/session_state_agent/agent.py +++ b/contributing/samples/session_state_agent/agent.py @@ -171,7 +171,7 @@ async def after_agent_callback(callback_context: CallbackContext): 'Log all users query with `log_query` tool. Must always remind user you' ' cannot answer second query because your setup.' ), - model='gemini-2.5-flash', + model='gemini-3-flash-preview', before_agent_callback=before_agent_callback, before_model_callback=before_model_callback, after_model_callback=after_model_callback, diff --git a/src/google/adk/a2a/converters/request_converter.py b/src/google/adk/a2a/converters/request_converter.py index 17989374d6..881487bd1d 100644 --- a/src/google/adk/a2a/converters/request_converter.py +++ b/src/google/adk/a2a/converters/request_converter.py @@ -27,6 +27,8 @@ from .part_converter import A2APartToGenAIPartConverter from .part_converter import convert_a2a_part_to_genai_part +A2A_METADATA_KEY = 'a2a_metadata' + @a2a_experimental class AgentRunRequest(BaseModel): @@ -97,7 +99,7 @@ def convert_a2a_request_to_agent_run_request( custom_metadata = {} if request.metadata: - custom_metadata['a2a_metadata'] = request.metadata + custom_metadata[A2A_METADATA_KEY] = request.metadata output_parts = [] for a2a_part in request.message.parts: diff --git a/src/google/adk/auth/auth_credential.py b/src/google/adk/auth/auth_credential.py index a71ae317f5..4a2add823c 100644 --- a/src/google/adk/auth/auth_credential.py +++ b/src/google/adk/auth/auth_credential.py @@ -19,7 +19,6 @@ from typing import Dict from typing import List from typing import Literal -from typing import Optional from pydantic import alias_generators from pydantic import BaseModel @@ -40,9 +39,9 @@ class BaseModelWithConfig(BaseModel): class HttpCredentials(BaseModelWithConfig): """Represents the secret token value for HTTP authentication, like user name, password, oauth token, etc.""" - username: Optional[str] = None - password: Optional[str] = None - token: Optional[str] = None + username: str | None = None + password: str | None = None + token: str | None = None @classmethod def model_validate(cls, data: Dict[str, Any]) -> "HttpCredentials": @@ -62,40 +61,43 @@ class HttpAuth(BaseModelWithConfig): # Examples: 'basic', 'bearer' scheme: str credentials: HttpCredentials - additional_headers: Optional[Dict[str, str]] = None + additional_headers: Dict[str, str] | None = None class OAuth2Auth(BaseModelWithConfig): """Represents credential value and its metadata for a OAuth2 credential.""" - client_id: Optional[str] = None - client_secret: Optional[str] = None + client_id: str | None = None + client_secret: str | None = None # tool or adk can generate the auth_uri with the state info thus client # can verify the state - auth_uri: Optional[str] = None + auth_uri: str | None = None # A unique value generated at the start of the OAuth flow to bind the user's # session to the authorization request. This value is typically stored with # user session and passed to backend for validation. - nonce: Optional[str] = None - state: Optional[str] = None + nonce: str | None = None + state: str | None = None # tool or adk can decide the redirect_uri if they don't want client to decide - redirect_uri: Optional[str] = None - auth_response_uri: Optional[str] = None - auth_code: Optional[str] = None - access_token: Optional[str] = None - refresh_token: Optional[str] = None - id_token: Optional[str] = None - expires_at: Optional[int] = None - expires_in: Optional[int] = None - audience: Optional[str] = None - token_endpoint_auth_method: Optional[ + redirect_uri: str | None = None + auth_response_uri: str | None = None + auth_code: str | None = None + access_token: str | None = None + refresh_token: str | None = None + id_token: str | None = None + expires_at: int | None = None + expires_in: int | None = None + audience: str | None = None + code_verifier: str | None = None + code_challenge_method: str | None = None + token_endpoint_auth_method: ( Literal[ "client_secret_basic", "client_secret_post", "client_secret_jwt", "private_key_jwt", ] - ] = "client_secret_basic" + | None + ) = "client_secret_basic" class ServiceAccountCredential(BaseModelWithConfig): @@ -166,11 +168,11 @@ class ServiceAccount(BaseModelWithConfig): when ``use_id_token`` is True. """ - service_account_credential: Optional[ServiceAccountCredential] = None - scopes: Optional[List[str]] = None - use_default_credential: Optional[bool] = False - use_id_token: Optional[bool] = False - audience: Optional[str] = None + service_account_credential: ServiceAccountCredential | None = None + scopes: List[str] | None = None + use_default_credential: bool | None = False + use_id_token: bool | None = False + audience: str | None = None @model_validator(mode="after") def _validate_config(self) -> ServiceAccount: @@ -275,9 +277,9 @@ class AuthCredential(BaseModelWithConfig): auth_type: AuthCredentialTypes # Resource reference for the credential. # This will be supported in the future. - resource_ref: Optional[str] = None + resource_ref: str | None = None - api_key: Optional[str] = None - http: Optional[HttpAuth] = None - service_account: Optional[ServiceAccount] = None - oauth2: Optional[OAuth2Auth] = None + api_key: str | None = None + http: HttpAuth | None = None + service_account: ServiceAccount | None = None + oauth2: OAuth2Auth | None = None diff --git a/src/google/adk/auth/auth_handler.py b/src/google/adk/auth/auth_handler.py index ec7c75716c..8e8f5d340b 100644 --- a/src/google/adk/auth/auth_handler.py +++ b/src/google/adk/auth/auth_handler.py @@ -28,6 +28,7 @@ from ..sessions.state import State try: + from authlib.common.security import generate_token from authlib.integrations.requests_client import OAuth2Session AUTHLIB_AVAILABLE = True @@ -158,6 +159,8 @@ def generate_auth_uri( auth_scheme = self.auth_config.auth_scheme auth_credential = self.auth_config.raw_auth_credential + if not auth_credential or not auth_credential.oauth2: + raise ValueError("raw_auth_credential or oauth2 is empty") if isinstance(auth_scheme, OpenIdConnectWithConfig): authorization_endpoint = auth_scheme.authorization_endpoint @@ -190,6 +193,7 @@ def generate_auth_uri( auth_credential.oauth2.client_secret, scope=" ".join(scopes), redirect_uri=auth_credential.oauth2.redirect_uri, + code_challenge_method=auth_credential.oauth2.code_challenge_method, ) params = { "access_type": "offline", @@ -197,12 +201,30 @@ def generate_auth_uri( } if auth_credential.oauth2.audience: params["audience"] = auth_credential.oauth2.audience + + # If using PKCE with S256, ensure a code_verifier exists. + # If not provided in the credential, generate a cryptographically secure + # random token of 48 characters (OAuth2 recommends 43-128 characters). + code_verifier = auth_credential.oauth2.code_verifier + method = auth_credential.oauth2.code_challenge_method + + if method: + if method != "S256": + raise ValueError( + f"Unsupported code_challenge_method: {method}. Only 'S256' is" + " supported." + ) + if not code_verifier: + code_verifier = generate_token(48) + uri, state = client.create_authorization_url( - url=authorization_endpoint, **params + url=authorization_endpoint, code_verifier=code_verifier, **params ) exchanged_auth_credential = auth_credential.model_copy(deep=True) exchanged_auth_credential.oauth2.auth_uri = uri exchanged_auth_credential.oauth2.state = state + if code_verifier: + exchanged_auth_credential.oauth2.code_verifier = code_verifier return exchanged_auth_credential diff --git a/src/google/adk/auth/exchanger/oauth2_credential_exchanger.py b/src/google/adk/auth/exchanger/oauth2_credential_exchanger.py index 76f0c67899..d3504bfff6 100644 --- a/src/google/adk/auth/exchanger/oauth2_credential_exchanger.py +++ b/src/google/adk/auth/exchanger/oauth2_credential_exchanger.py @@ -193,6 +193,12 @@ async def _exchange_authorization_code( return ExchangeResult(auth_credential, False) try: + kwargs = {} + # If a code_verifier is available (e.g. from PKCE), include it in the + # token exchange request. + if auth_credential.oauth2 and auth_credential.oauth2.code_verifier: + kwargs["code_verifier"] = auth_credential.oauth2.code_verifier + # Authlib already injects client_id for body-based client auth flows such # as client_secret_post, so passing it here would duplicate the field. tokens = client.fetch_token( @@ -202,6 +208,7 @@ async def _exchange_authorization_code( ), code=auth_credential.oauth2.auth_code, grant_type=OAuthGrantType.AUTHORIZATION_CODE, + **kwargs, ) update_credential_with_tokens(auth_credential, tokens) logger.debug("Successfully exchanged authorization code for access token") diff --git a/src/google/adk/auth/oauth2_credential_util.py b/src/google/adk/auth/oauth2_credential_util.py index df2f26c002..d0d1255fbe 100644 --- a/src/google/adk/auth/oauth2_credential_util.py +++ b/src/google/adk/auth/oauth2_credential_util.py @@ -92,6 +92,7 @@ def create_oauth2_session( redirect_uri=auth_credential.oauth2.redirect_uri, state=auth_credential.oauth2.state, token_endpoint_auth_method=auth_credential.oauth2.token_endpoint_auth_method, + code_challenge_method=auth_credential.oauth2.code_challenge_method, ), token_endpoint, ) diff --git a/src/google/adk/flows/llm_flows/contents.py b/src/google/adk/flows/llm_flows/contents.py index 027a79bdf4..f5d706b94c 100644 --- a/src/google/adk/flows/llm_flows/contents.py +++ b/src/google/adk/flows/llm_flows/contents.py @@ -57,7 +57,7 @@ async def run_async( # ids must survive replay. try: from ...models.anthropic_llm import AnthropicLlm - except ImportError: + except (ImportError, OSError): AnthropicLlm = None if AnthropicLlm is not None and isinstance( canonical_model, AnthropicLlm diff --git a/src/google/adk/integrations/agent_registry/__init__.py b/src/google/adk/integrations/agent_registry/__init__.py index 18a30b3211..3c3bd9b2f5 100644 --- a/src/google/adk/integrations/agent_registry/__init__.py +++ b/src/google/adk/integrations/agent_registry/__init__.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .agent_registry import _ProtocolType from .agent_registry import AgentRegistry __all__ = [ diff --git a/src/google/adk/integrations/agent_registry/agent_registry.py b/src/google/adk/integrations/agent_registry/agent_registry.py index 1726b1179e..887c894cd4 100644 --- a/src/google/adk/integrations/agent_registry/agent_registry.py +++ b/src/google/adk/integrations/agent_registry/agent_registry.py @@ -28,12 +28,7 @@ from typing import TypedDict from urllib.parse import urlparse -from a2a.types import AgentCapabilities -from a2a.types import AgentCard -from a2a.types import AgentSkill -from a2a.types import TransportProtocol as A2ATransport from google.adk.agents.readonly_context import ReadonlyContext -from google.adk.agents.remote_a2a_agent import RemoteA2aAgent from google.adk.auth.auth_credential import AuthCredential from google.adk.auth.auth_schemes import AuthScheme from google.adk.integrations.agent_identity.gcp_auth_provider_scheme import GcpAuthProviderScheme @@ -49,6 +44,20 @@ from mcp import StdioServerParameters from typing_extensions import override +# pylint: disable=g-import-not-at-top +try: + from a2a.types import AgentCapabilities + from a2a.types import AgentCard + from a2a.types import AgentSkill + from a2a.types import TransportProtocol as A2ATransport + from google.adk.agents.remote_a2a_agent import RemoteA2aAgent +except ImportError as e: + raise ImportError( + "AgentRegistry requires the 'a2a-sdk' package. " + "Please install it using 'pip install google-adk[a2a]'." + ) from e +# pylint: enable=g-import-not-at-top + logger = logging.getLogger("google_adk." + __name__) AGENT_REGISTRY_BASE_URL = "https://agentregistry.googleapis.com/v1alpha" diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py index 21337cfa51..0d78a5759b 100644 --- a/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py @@ -242,6 +242,12 @@ async def _get_existing_credential( existing_credential = await refresher.refresh( existing_credential, self.auth_scheme ) + # Persist the refreshed credential so the next invocation + # reads the new tokens instead of the stale pre-refresh ones. + # Without this, providers that rotate refresh_tokens on each + # refresh (e.g. Salesforce, many OIDC providers) will fail + # because the old refresh_token has already been invalidated. + self._store_credential(existing_credential) return existing_credential return None diff --git a/src/google/adk/utils/cache_performance_analyzer.py b/src/google/adk/utils/cache_performance_analyzer.py index 5bdf8653d0..5af3a07660 100644 --- a/src/google/adk/utils/cache_performance_analyzer.py +++ b/src/google/adk/utils/cache_performance_analyzer.py @@ -144,7 +144,11 @@ async def analyze_agent_cache_performance( total_cached_tokens / total_requests if total_requests > 0 else 0.0 ) - invocations_used = [c.invocations_used for c in cache_history] + invocations_used = [ + c.invocations_used + for c in cache_history + if c.invocations_used is not None + ] total_invocations = sum(invocations_used) return { @@ -156,7 +160,9 @@ async def analyze_agent_cache_performance( else 0 ), "latest_cache": cache_history[-1].cache_name, - "cache_refreshes": len(set(c.cache_name for c in cache_history)), + "cache_refreshes": len( + {c.cache_name for c in cache_history if c.cache_name is not None} + ), "total_invocations": total_invocations, "total_prompt_tokens": total_prompt_tokens, "total_cached_tokens": total_cached_tokens, diff --git a/src/google/adk/version.py b/src/google/adk/version.py index 3a7e8f81b4..91b8650e52 100644 --- a/src/google/adk/version.py +++ b/src/google/adk/version.py @@ -13,4 +13,4 @@ # limitations under the License. # version: major.minor.patch -__version__ = "1.32.0" +__version__ = "1.33.0" diff --git a/tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py b/tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py index 3a0a5647cb..25f9267452 100644 --- a/tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py +++ b/tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py @@ -135,6 +135,57 @@ async def test_exchange_success(self, mock_oauth2_session): assert exchange_result.was_exchanged mock_client.fetch_token.assert_called_once() + @patch("google.adk.auth.oauth2_credential_util.OAuth2Session") + async def test_exchange_success_pkce(self, mock_oauth2_session): + """Test successful token exchange with PKCE.""" + # Setup mock + mock_client = Mock() + mock_oauth2_session.return_value = mock_client + mock_tokens = OAuth2Token({ + "access_token": "new_access_token", + "refresh_token": "new_refresh_token", + "expires_at": int(time.time()) + 3600, + "expires_in": 3600, + }) + mock_client.fetch_token.return_value = mock_tokens + + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + auth_response_uri="https://example.com/callback?code=auth_code", + auth_code="auth_code", + code_verifier="mock_code_verifier", + ), + ) + + exchanger = OAuth2CredentialExchanger() + exchange_result = await exchanger.exchange(credential, scheme) + + # Verify token exchange was successful + assert exchange_result.credential.oauth2.access_token == "new_access_token" + assert ( + exchange_result.credential.oauth2.refresh_token == "new_refresh_token" + ) + assert exchange_result.was_exchanged + mock_client.fetch_token.assert_called_once_with( + "https://example.com/token", + authorization_response="https://example.com/callback?code=auth_code", + code="auth_code", + grant_type=OAuthGrantType.AUTHORIZATION_CODE, + code_verifier="mock_code_verifier", + ) + async def test_exchange_missing_auth_scheme(self): """Test exchange with missing auth_scheme raises ValueError.""" credential = AuthCredential( diff --git a/tests/unittests/auth/test_auth_handler.py b/tests/unittests/auth/test_auth_handler.py index 2faeeb158e..c19a5d93fd 100644 --- a/tests/unittests/auth/test_auth_handler.py +++ b/tests/unittests/auth/test_auth_handler.py @@ -53,12 +53,14 @@ def __init__( scope=None, redirect_uri=None, state=None, + **kwargs, ): self.client_id = client_id self.client_secret = client_secret self.scope = scope self.redirect_uri = redirect_uri self.state = state + self.extra_kwargs = kwargs def create_authorization_url(self, url, **kwargs): params = f"client_id={self.client_id}&scope={self.scope}" @@ -271,6 +273,54 @@ def test_generate_auth_uri_openid( assert "client_id=mock_client_id" in result.oauth2.auth_uri assert result.oauth2.state == "mock_state" + @patch("google.adk.auth.auth_handler.OAuth2Session") + def test_generate_auth_uri_pkce( + self, mock_oauth2_session, oauth2_auth_scheme, oauth2_credentials + ): + """Test generating an auth URI with PKCE.""" + oauth2_credentials.oauth2.code_challenge_method = "S256" + exchanged = oauth2_credentials.model_copy(deep=True) + + config = AuthConfig( + auth_scheme=oauth2_auth_scheme, + raw_auth_credential=oauth2_credentials, + exchanged_auth_credential=exchanged, + ) + + mock_client = Mock() + mock_oauth2_session.return_value = mock_client + mock_client.create_authorization_url.return_value = ( + "https://example.com/oauth2/authorize?code_challenge=...&code_challenge_method=S256", + "mock_state", + ) + + handler = AuthHandler(config) + result = handler.generate_auth_uri() + + assert result.oauth2.code_verifier is not None + assert len(result.oauth2.code_verifier) == 48 + mock_client.create_authorization_url.assert_called_once() + _, kwargs = mock_client.create_authorization_url.call_args + assert "code_verifier" in kwargs + assert kwargs["code_verifier"] == result.oauth2.code_verifier + + def test_generate_auth_uri_unsupported_pkce_method( + self, oauth2_auth_scheme, oauth2_credentials + ): + """Test generating an auth URI with unsupported PKCE method.""" + oauth2_credentials.oauth2.code_challenge_method = "plain" + exchanged = oauth2_credentials.model_copy(deep=True) + + config = AuthConfig( + auth_scheme=oauth2_auth_scheme, + raw_auth_credential=oauth2_credentials, + exchanged_auth_credential=exchanged, + ) + + handler = AuthHandler(config) + with pytest.raises(ValueError, match="Unsupported code_challenge_method"): + handler.generate_auth_uri() + class TestGenerateAuthRequest: """Tests for the generate_auth_request method.""" diff --git a/tests/unittests/integrations/agent_registry/test_agent_registry.py b/tests/unittests/integrations/agent_registry/test_agent_registry.py index a631e4a656..fd3f2a8ec4 100644 --- a/tests/unittests/integrations/agent_registry/test_agent_registry.py +++ b/tests/unittests/integrations/agent_registry/test_agent_registry.py @@ -22,8 +22,8 @@ from google.adk.agents.remote_a2a_agent import RemoteA2aAgent from google.adk.auth.auth_credential import AuthCredential from google.adk.auth.auth_credential import OAuth2Auth -from google.adk.integrations.agent_registry import _ProtocolType from google.adk.integrations.agent_registry import AgentRegistry +from google.adk.integrations.agent_registry.agent_registry import _ProtocolType from google.adk.telemetry.tracing import GCP_MCP_SERVER_DESTINATION_ID from google.adk.tools.mcp_tool.mcp_toolset import McpToolset import httpx diff --git a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py index a6babce651..d32fc132da 100644 --- a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py +++ b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Optional +from unittest.mock import AsyncMock from unittest.mock import MagicMock from unittest.mock import patch @@ -29,6 +30,7 @@ from google.adk.tools.openapi_tool.auth.auth_helpers import openid_dict_to_scheme_credential from google.adk.tools.openapi_tool.auth.auth_helpers import token_to_scheme_credential from google.adk.tools.openapi_tool.auth.credential_exchangers.auto_auth_credential_exchanger import OAuth2CredentialExchanger +from google.adk.tools.openapi_tool.openapi_spec_parser import tool_auth_handler from google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler import ToolAuthHandler from google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler import ToolContextCredentialStore from google.adk.tools.tool_context import ToolContext @@ -223,9 +225,7 @@ async def test_openid_connect_existing_token( assert result.auth_credential == existing_credential -@patch( - 'google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler.OAuth2CredentialRefresher' -) +@patch.object(tool_auth_handler, 'OAuth2CredentialRefresher') @pytest.mark.asyncio async def test_openid_connect_existing_oauth2_token_refresh( mock_oauth2_refresher, openid_connect_scheme, openid_connect_credential @@ -292,3 +292,64 @@ async def test_openid_connect_existing_oauth2_token_refresh( assert result.state == 'done' # The result should contain the refreshed credential after exchange assert result.auth_credential is not None + + +@patch.object(tool_auth_handler, 'OAuth2CredentialRefresher') +@pytest.mark.asyncio +async def test_refreshed_credential_is_persisted_to_store( + mock_oauth2_refresher, openid_connect_scheme, openid_connect_credential +): + """Test that refreshed OAuth2 credentials are persisted back to the store.""" + # Create existing OAuth2 credential with an "old" refresh token. + existing_credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id='test_client_id', + client_secret='test_client_secret', + access_token='old_access_token', + refresh_token='old_refresh_token', + ), + ) + + # The refresher will return a credential with rotated tokens. + refreshed_credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id='test_client_id', + client_secret='test_client_secret', + access_token='new_access_token', + refresh_token='new_refresh_token', + ), + ) + + mock_refresher_instance = MagicMock() + mock_refresher_instance.is_refresh_needed = AsyncMock(return_value=True) + mock_refresher_instance.refresh = AsyncMock(return_value=refreshed_credential) + mock_oauth2_refresher.return_value = mock_refresher_instance + + tool_context = create_mock_tool_context() + credential_store = ToolContextCredentialStore(tool_context=tool_context) + + # Store the existing (stale) credential. + key = credential_store.get_credential_key( + openid_connect_scheme, openid_connect_credential + ) + credential_store.store_credential(key, existing_credential) + + handler = ToolAuthHandler( + tool_context, + openid_connect_scheme, + openid_connect_credential, + credential_store=credential_store, + ) + + await handler.prepare_auth_credentials() + + # The critical assertion: the *refreshed* credential must now be in the + # store so that the next invocation reads the new tokens, not the old ones. + persisted = credential_store.get_credential( + openid_connect_scheme, openid_connect_credential + ) + assert persisted is not None + assert persisted.oauth2.access_token == 'new_access_token' + assert persisted.oauth2.refresh_token == 'new_refresh_token' diff --git a/tests/unittests/utils/test_cache_performance_analyzer.py b/tests/unittests/utils/test_cache_performance_analyzer.py index cdeb7ebdb1..436c341b64 100644 --- a/tests/unittests/utils/test_cache_performance_analyzer.py +++ b/tests/unittests/utils/test_cache_performance_analyzer.py @@ -401,6 +401,50 @@ async def test_session_service_integration(self): assert result["status"] == "active" assert result["requests_with_cache"] == 1 + async def test_analyze_agent_cache_performance_with_fingerprint_only(self): + """Fingerprint-only entries (cache_name=None, invocations_used=None) don't crash.""" + fp_only = CacheMetadata(fingerprint="fp", contents_count=3) + active = self.create_cache_metadata(invocations_used=4, cache_name="active") + fp_usage = self.create_mock_usage_metadata( + prompt_tokens=1000, cached_tokens=0 + ) + active_usage = self.create_mock_usage_metadata( + prompt_tokens=1000, cached_tokens=800 + ) + + events = [ + self.create_mock_event( + author="test_agent", + cache_metadata=fp_only, + usage_metadata=fp_usage, + ), + self.create_mock_event( + author="test_agent", + cache_metadata=active, + usage_metadata=active_usage, + ), + ] + mock_session = Session( + id="test_session", + app_name="test_app", + user_id="test_user", + events=events, + ) + self.mock_session_service.get_session = AsyncMock(return_value=mock_session) + + result = await self.analyzer.analyze_agent_cache_performance( + "test_session", "test_user", "test_app", "test_agent" + ) + + assert result["status"] == "active" + assert result["total_requests"] == 2 + assert result["total_prompt_tokens"] == 2000 + assert result["total_cached_tokens"] == 800 + assert result["total_invocations"] == 4 + assert result["avg_invocations_used"] == 4.0 + assert result["cache_refreshes"] == 1 + assert result["requests_with_cache"] == 2 + async def test_mixed_agents_filtering(self): """Test that analysis correctly filters by agent name.""" target_cache = self.create_cache_metadata(