From fd8b49295d628075cf70acabb2c52eedf62dd5bd Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Tue, 12 May 2026 21:34:13 -0700 Subject: [PATCH 01/28] feat(ci): add Gemini auto review and invoke workflows Merge https://github.com/google/adk-python/pull/5679 Same config as in adk-python-community COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/5679 from DeanChensj:main d33c368d56bb7f7be0e07ff557dc14e557d7d45d PiperOrigin-RevId: 914634367 --- .github/workflows/gemini-dispatch.yml | 184 ++++++++++++++++++++++++++ .github/workflows/gemini-invoke.yml | 122 +++++++++++++++++ .github/workflows/gemini-review.yml | 104 +++++++++++++++ 3 files changed, 410 insertions(+) create mode 100644 .github/workflows/gemini-dispatch.yml create mode 100644 .github/workflows/gemini-invoke.yml create mode 100644 .github/workflows/gemini-review.yml diff --git a/.github/workflows/gemini-dispatch.yml b/.github/workflows/gemini-dispatch.yml new file mode 100644 index 0000000000..1f8c243fda --- /dev/null +++ b/.github/workflows/gemini-dispatch.yml @@ -0,0 +1,184 @@ +name: '🔀 Gemini Dispatch' + +on: + pull_request_review_comment: + types: + - 'created' + pull_request_review: + types: + - 'submitted' + pull_request: + types: + - 'opened' + - 'ready_for_review' + issues: + types: + - 'opened' + - 'reopened' + issue_comment: + types: + - 'created' + +defaults: + run: + shell: 'bash' + +jobs: + debugger: + if: |- + ${{ fromJSON(vars.GEMINI_DEBUG || vars.ACTIONS_STEP_DEBUG || false) }} + runs-on: 'ubuntu-latest' + permissions: + contents: 'read' + steps: + - name: 'Print context for debugging' + env: + DEBUG_event_name: '${{ github.event_name }}' + DEBUG_event__action: '${{ github.event.action }}' + DEBUG_event__comment__author_association: '${{ github.event.comment.author_association }}' + DEBUG_event__issue__author_association: '${{ github.event.issue.author_association }}' + DEBUG_event__pull_request__author_association: '${{ github.event.pull_request.author_association }}' + DEBUG_event__review__author_association: '${{ github.event.review.author_association }}' + DEBUG_event: '${{ toJSON(github.event) }}' + run: |- + env | grep '^DEBUG_' + + dispatch: + # For PRs: only if not from a fork + # For issues: only on open/reopen + # For comments: only if user types @gemini-cli and is OWNER/MEMBER/COLLABORATOR + if: |- + ( + github.event_name == 'pull_request' && + github.event.pull_request.head.repo.fork == false && + github.event.pull_request.draft == false + ) || ( + github.event.sender.type == 'User' && + startsWith(github.event.comment.body || github.event.review.body || github.event.issue.body, '@gemini-cli') && + contains(fromJSON('["OWNER", "MEMBER", "COLLABORATOR"]'), github.event.comment.author_association || github.event.review.author_association || github.event.issue.author_association) + ) + runs-on: 'ubuntu-latest' + permissions: + contents: 'read' + issues: 'write' + pull-requests: 'write' + outputs: + command: '${{ steps.extract_command.outputs.command }}' + request: '${{ steps.extract_command.outputs.request }}' + additional_context: '${{ steps.extract_command.outputs.additional_context }}' + issue_number: '${{ github.event.pull_request.number || github.event.issue.number }}' + steps: + - name: 'Mint identity token' + id: 'mint_identity_token' + if: |- + ${{ vars.APP_ID }} + uses: 'actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf' # ratchet:actions/create-github-app-token@v2 + with: + app-id: '${{ vars.APP_ID }}' + private-key: '${{ secrets.APP_PRIVATE_KEY }}' + permission-contents: 'read' + permission-issues: 'write' + permission-pull-requests: 'write' + + - name: 'Extract command' + id: 'extract_command' + uses: 'actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea' # ratchet:actions/github-script@v7 + env: + EVENT_TYPE: '${{ github.event_name }}.${{ github.event.action }}' + REQUEST: '${{ github.event.comment.body || github.event.review.body || github.event.issue.body }}' + with: + script: | + const eventType = process.env.EVENT_TYPE; + const request = process.env.REQUEST; + core.setOutput('request', request); + + if (eventType === 'pull_request.opened' || eventType === 'pull_request.ready_for_review') { + core.setOutput('command', 'review'); + } else if (request.startsWith("@gemini-cli /review")) { + core.setOutput('command', 'review'); + const additionalContext = request.replace(/^@gemini-cli \/review/, '').trim(); + core.setOutput('additional_context', additionalContext); + } else if (request.startsWith("@gemini-cli")) { + const additionalContext = request.replace(/^@gemini-cli/, '').trim(); + core.setOutput('command', 'invoke'); + core.setOutput('additional_context', additionalContext); + } else { + core.setOutput('command', 'fallthrough'); + } + + - name: 'Acknowledge request' + env: + GITHUB_TOKEN: '${{ steps.mint_identity_token.outputs.token || secrets.GITHUB_TOKEN || github.token }}' + ISSUE_NUMBER: '${{ github.event.pull_request.number || github.event.issue.number }}' + MESSAGE: |- + 🤖 Hi @${{ github.actor }}, I've received your request, and I'm working on it now! You can track my progress [in the logs](${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}) for more details. + REPOSITORY: '${{ github.repository }}' + run: |- + gh issue comment "${ISSUE_NUMBER}" \ + --body "${MESSAGE}" \ + --repo "${REPOSITORY}" + + review: + needs: 'dispatch' + if: |- + ${{ needs.dispatch.outputs.command == 'review' }} + uses: './.github/workflows/gemini-review.yml' + permissions: + contents: 'read' + id-token: 'write' + issues: 'write' + pull-requests: 'write' + with: + additional_context: '${{ needs.dispatch.outputs.additional_context }}' + secrets: 'inherit' + + invoke: + needs: 'dispatch' + if: |- + ${{ needs.dispatch.outputs.command == 'invoke' }} + uses: './.github/workflows/gemini-invoke.yml' + permissions: + contents: 'read' + id-token: 'write' + issues: 'write' + pull-requests: 'write' + with: + additional_context: '${{ needs.dispatch.outputs.additional_context }}' + secrets: 'inherit' + + fallthrough: + needs: + - 'dispatch' + - 'review' + - 'invoke' + if: |- + ${{ always() && !cancelled() && (failure() || needs.dispatch.outputs.command == 'fallthrough') }} + runs-on: 'ubuntu-latest' + permissions: + contents: 'read' + issues: 'write' + pull-requests: 'write' + steps: + - name: 'Mint identity token' + id: 'mint_identity_token' + if: |- + ${{ vars.APP_ID }} + uses: 'actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf' # ratchet:actions/create-github-app-token@v2 + with: + app-id: '${{ vars.APP_ID }}' + private-key: '${{ secrets.APP_PRIVATE_KEY }}' + permission-contents: 'read' + permission-issues: 'write' + permission-pull-requests: 'write' + + - name: 'Send failure comment' + env: + GITHUB_TOKEN: '${{ steps.mint_identity_token.outputs.token || secrets.GITHUB_TOKEN || github.token }}' + ISSUE_NUMBER: '${{ github.event.pull_request.number || github.event.issue.number }}' + MESSAGE: |- + 🤖 I'm sorry @${{ github.actor }}, but I was unable to process your request. Please [see the logs](${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}) for more details. + REPOSITORY: '${{ github.repository }}' + run: |- + gh issue comment "${ISSUE_NUMBER}" \ + --body "${MESSAGE}" \ + --repo "${REPOSITORY}" diff --git a/.github/workflows/gemini-invoke.yml b/.github/workflows/gemini-invoke.yml new file mode 100644 index 0000000000..81b669a327 --- /dev/null +++ b/.github/workflows/gemini-invoke.yml @@ -0,0 +1,122 @@ +name: '▶️ Gemini Invoke' + +on: + workflow_call: + inputs: + additional_context: + type: 'string' + description: 'Any additional context from the request' + required: false + +concurrency: + group: '${{ github.workflow }}-invoke-${{ github.event_name }}-${{ github.event.pull_request.number || github.event.issue.number }}' + cancel-in-progress: false + +defaults: + run: + shell: 'bash' + +jobs: + invoke: + runs-on: 'ubuntu-latest' + permissions: + contents: 'read' + id-token: 'write' + issues: 'write' + pull-requests: 'write' + steps: + - name: 'Mint identity token' + id: 'mint_identity_token' + if: |- + ${{ vars.APP_ID }} + uses: 'actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf' # ratchet:actions/create-github-app-token@v2 + with: + app-id: '${{ vars.APP_ID }}' + private-key: '${{ secrets.APP_PRIVATE_KEY }}' + permission-contents: 'read' + permission-issues: 'write' + permission-pull-requests: 'write' + + - name: 'Checkout Code' + uses: 'actions/checkout@v4' # ratchet:exclude + + - name: 'Run Gemini CLI' + id: 'run_gemini' + uses: 'google-github-actions/run-gemini-cli@v0' # ratchet:exclude + env: + TITLE: '${{ github.event.pull_request.title || github.event.issue.title }}' + DESCRIPTION: '${{ github.event.pull_request.body || github.event.issue.body }}' + EVENT_NAME: '${{ github.event_name }}' + GITHUB_TOKEN: '${{ steps.mint_identity_token.outputs.token || secrets.GITHUB_TOKEN || github.token }}' + IS_PULL_REQUEST: '${{ !!github.event.pull_request }}' + ISSUE_NUMBER: '${{ github.event.pull_request.number || github.event.issue.number }}' + REPOSITORY: '${{ github.repository }}' + ADDITIONAL_CONTEXT: '${{ inputs.additional_context }}' + with: + gcp_location: '${{ vars.GOOGLE_CLOUD_LOCATION }}' + gcp_project_id: '${{ vars.GOOGLE_CLOUD_PROJECT }}' + gcp_service_account: '${{ vars.SERVICE_ACCOUNT_EMAIL }}' + gcp_workload_identity_provider: '${{ vars.GCP_WIF_PROVIDER }}' + gemini_api_key: '${{ secrets.GEMINI_API_KEY }}' + gemini_cli_version: '${{ vars.GEMINI_CLI_VERSION }}' + gemini_debug: '${{ fromJSON(vars.GEMINI_DEBUG || vars.ACTIONS_STEP_DEBUG || false) }}' + gemini_model: '${{ vars.GEMINI_MODEL }}' + google_api_key: '${{ secrets.GOOGLE_API_KEY }}' + use_gemini_code_assist: '${{ vars.GOOGLE_GENAI_USE_GCA }}' + use_vertex_ai: '${{ vars.GOOGLE_GENAI_USE_VERTEXAI }}' + upload_artifacts: '${{ vars.UPLOAD_ARTIFACTS }}' + workflow_name: 'gemini-invoke' + # Assistant workflows can be triggered by comments on either Issues or PRs. + # We explicitly map both fields so the CLI can correctly categorize the interaction. + github_pr_number: '${{ github.event.pull_request.number }}' + github_issue_number: '${{ github.event.issue.number }}' + settings: |- + { + "model": { + "maxSessionTurns": 25 + }, + "telemetry": { + "enabled": true, + "target": "local", + "outfile": ".gemini/telemetry.log" + }, + "mcpServers": { + "github": { + "command": "docker", + "args": [ + "run", + "-i", + "--rm", + "-e", + "GITHUB_PERSONAL_ACCESS_TOKEN", + "ghcr.io/github/github-mcp-server:v0.27.0" + ], + "includeTools": [ + "add_issue_comment", + "issue_read", + "list_issues", + "search_issues", + "pull_request_read", + "list_pull_requests", + "search_pull_requests", + "get_commit", + "get_file_contents", + "list_commits", + "search_code" + ], + "env": { + "GITHUB_PERSONAL_ACCESS_TOKEN": "${{ steps.mint_identity_token.outputs.token || secrets.GITHUB_TOKEN || github.token }}" + } + } + }, + "tools": { + "core": [ + "run_shell_command(cat)", + "run_shell_command(echo)", + "run_shell_command(grep)", + "run_shell_command(head)", + "run_shell_command(tail)" + ] + } + } + prompt: '/gemini-invoke' diff --git a/.github/workflows/gemini-review.yml b/.github/workflows/gemini-review.yml new file mode 100644 index 0000000000..d29b030b5d --- /dev/null +++ b/.github/workflows/gemini-review.yml @@ -0,0 +1,104 @@ +name: '🔎 Gemini Review' + +on: + workflow_call: + inputs: + additional_context: + type: 'string' + description: 'Any additional context from the request' + required: false + +concurrency: + group: '${{ github.workflow }}-review-${{ github.event_name }}-${{ github.event.pull_request.number || github.event.issue.number }}' + cancel-in-progress: true + +defaults: + run: + shell: 'bash' + +jobs: + review: + runs-on: 'ubuntu-latest' + timeout-minutes: 7 + permissions: + contents: 'read' + id-token: 'write' + issues: 'write' + pull-requests: 'write' + steps: + - name: 'Mint identity token' + id: 'mint_identity_token' + if: |- + ${{ vars.APP_ID }} + uses: 'actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf' # ratchet:actions/create-github-app-token@v2 + with: + app-id: '${{ vars.APP_ID }}' + private-key: '${{ secrets.APP_PRIVATE_KEY }}' + permission-contents: 'read' + permission-issues: 'write' + permission-pull-requests: 'write' + + - name: 'Checkout repository' + uses: 'actions/checkout@v4' # ratchet:exclude + + - name: 'Run Gemini pull request review' + uses: 'google-github-actions/run-gemini-cli@v0' # ratchet:exclude + id: 'gemini_pr_review' + env: + GITHUB_TOKEN: '${{ steps.mint_identity_token.outputs.token || secrets.GITHUB_TOKEN || github.token }}' + ISSUE_TITLE: '${{ github.event.pull_request.title || github.event.issue.title }}' + ISSUE_BODY: '${{ github.event.pull_request.body || github.event.issue.body }}' + PULL_REQUEST_NUMBER: '${{ github.event.pull_request.number || github.event.issue.number }}' + REPOSITORY: '${{ github.repository }}' + ADDITIONAL_CONTEXT: '${{ inputs.additional_context }}' + GEMINI_API_KEY: '${{ secrets.GEMINI_API_KEY }}' + GEMINI_CLI_TRUST_WORKSPACE: 'true' + with: + gcp_location: '${{ vars.GOOGLE_CLOUD_LOCATION }}' + gcp_project_id: '${{ vars.GOOGLE_CLOUD_PROJECT }}' + gcp_service_account: '${{ vars.SERVICE_ACCOUNT_EMAIL }}' + gcp_workload_identity_provider: '${{ vars.GCP_WIF_PROVIDER }}' + gemini_api_key: '${{ secrets.GEMINI_API_KEY }}' + gemini_cli_version: '${{ vars.GEMINI_CLI_VERSION }}' + gemini_debug: '${{ fromJSON(vars.GEMINI_DEBUG || vars.ACTIONS_STEP_DEBUG || false) }}' + gemini_model: '${{ vars.GEMINI_MODEL }}' + google_api_key: '${{ secrets.GOOGLE_API_KEY }}' + use_gemini_code_assist: '${{ vars.GOOGLE_GENAI_USE_GCA }}' + use_vertex_ai: '${{ vars.GOOGLE_GENAI_USE_VERTEXAI }}' + upload_artifacts: '${{ vars.UPLOAD_ARTIFACTS }}' + workflow_name: 'gemini-review' + # Explicitly set the PR number to handle `issue_comment` triggers (which GitHub treats as issues, not PRs) + github_pr_number: '${{ github.event.pull_request.number || github.event.issue.number }}' + settings: |- + { + "model": { + "maxSessionTurns": 25 + }, + "telemetry": { + "enabled": true, + "target": "local", + "outfile": ".gemini/telemetry.log" + }, + "mcpServers": { + "github": { + "command": "docker", + "args": [ + "run", + "-i", + "--rm", + "-e", + "GITHUB_PERSONAL_ACCESS_TOKEN", + "ghcr.io/github/github-mcp-server:v0.27.0" + ], + "includeTools": [ + "pull_request_read", + "add_comment_to_pending_review", + "pull_request_review_write" + ], + "env": { + "GITHUB_PERSONAL_ACCESS_TOKEN": "${{ steps.mint_identity_token.outputs.token || secrets.GITHUB_TOKEN || github.token }}" + } + } + } + } + prompt: 'Please use the pull_request_read tool to read pull request #${{ github.event.pull_request.number || github.event.issue.number }}. Analyze the code for bugs, security issues, and best practices. Then, use the add_comment_to_pending_review and pull_request_review_write tools to post your review directly on pull request #${{ github.event.pull_request.number || github.event.issue.number }}.' From 6e534723dd6be938e6fb1b6f55b06de8ac4d27d8 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 13 May 2026 08:24:47 -0700 Subject: [PATCH 02/28] feat: add support for non-ADK produced input-required events This CL introduces logic to handle scenarios where a non-ADK agent transitions to the `TaskState.input_required` or `TaskState.auth_required` states. It intercepts these events and converts them into a synthetic ADK `FunctionCall` event. PiperOrigin-RevId: 914877123 --- src/google/adk/a2a/converters/to_adk_event.py | 54 +++++- src/google/adk/agents/remote_a2a_agent.py | 54 +++++- tests/unittests/a2a/converters/test_to_adk.py | 94 +++++++++- .../a2a/integration/test_client_server.py | 169 ++++++++++++++++++ .../unittests/agents/test_remote_a2a_agent.py | 6 + 5 files changed, 368 insertions(+), 9 deletions(-) diff --git a/src/google/adk/a2a/converters/to_adk_event.py b/src/google/adk/a2a/converters/to_adk_event.py index 26ae95e1b4..eab89c20f5 100644 --- a/src/google/adk/a2a/converters/to_adk_event.py +++ b/src/google/adk/a2a/converters/to_adk_event.py @@ -43,6 +43,10 @@ # Logger logger = logging.getLogger("google_adk." + __name__) +MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT = ( + "mock_function_call_for_required_user_input" +) + A2AMessageToEventConverter = Callable[ [ Message, @@ -276,6 +280,36 @@ def _merge_event_actions( return EventActions.model_validate(merged_actions_data) +def _create_mock_function_call_for_required_user_input( + state: TaskState, + output_parts: list[genai_types.Part], + long_running_function_ids: set[str], +) -> tuple[list[genai_types.Part], set[str]]: + """Creates a mock function call for input/auth-required if applicable. + + This solution allows to unblock the A2A integration with non-ADK agents from + ADK side by replacing the last text part with a synthetic function call. All + other parts are preserved. + """ + if ( + state == TaskState.input_required or state == TaskState.auth_required + ) and (not long_running_function_ids or len(long_running_function_ids) == 0): + # Find the last text part from the bottom to replace it with a function call. + # In case of input-required events, the LLM should stop the production of other parts. + for i in range(len(output_parts) - 1, -1, -1): + if output_parts[i].text: + function_call = genai_types.FunctionCall( + id=str(uuid.uuid4()), + name=MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT, + args={"input_required": output_parts[i].text}, + ) + long_running_function_ids = set() + long_running_function_ids.add(function_call.id) + output_parts[i] = genai_types.Part(function_call=function_call) + break + return output_parts, long_running_function_ids + + @a2a_experimental def convert_a2a_task_to_event( a2a_task: Task, @@ -317,9 +351,9 @@ def convert_a2a_task_to_event( output_parts, _ = _convert_a2a_parts_to_adk_parts( artifact_parts, part_converter ) - if ( - a2a_task.status.message - and a2a_task.status.state == TaskState.input_required + if a2a_task.status.message and ( + a2a_task.status.state == TaskState.input_required + or a2a_task.status.state == TaskState.auth_required ): event_actions = _merge_event_actions( event_actions, @@ -331,6 +365,12 @@ def convert_a2a_task_to_event( output_parts.extend(parts) long_running_function_ids.update(ids) + output_parts, long_running_function_ids = ( + _create_mock_function_call_for_required_user_input( + a2a_task.status.state, output_parts, long_running_function_ids + ) + ) + return _create_event( output_parts, invocation_context, @@ -422,6 +462,14 @@ def convert_a2a_status_update_to_event( output_parts.extend(parts) long_running_function_ids.update(ids) + output_parts, long_running_function_ids = ( + _create_mock_function_call_for_required_user_input( + a2a_status_update.status.state, + output_parts, + long_running_function_ids, + ) + ) + return _create_event( output_parts, invocation_context, diff --git a/src/google/adk/agents/remote_a2a_agent.py b/src/google/adk/agents/remote_a2a_agent.py index 55f57aff05..495a715d76 100644 --- a/src/google/adk/agents/remote_a2a_agent.py +++ b/src/google/adk/agents/remote_a2a_agent.py @@ -65,6 +65,8 @@ from ..a2a.converters.part_converter import convert_a2a_part_to_genai_part from ..a2a.converters.part_converter import convert_genai_part_to_a2a_part from ..a2a.converters.part_converter import GenAIPartToA2APartConverter +from ..a2a.converters.to_adk_event import _create_mock_function_call_for_required_user_input +from ..a2a.converters.to_adk_event import MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT from ..a2a.converters.utils import _get_adk_metadata_key from ..a2a.experimental import a2a_experimental from ..a2a.logs.log_utils import build_a2a_request_log @@ -105,6 +107,22 @@ class A2AClientError(Exception): pass +def _add_mock_function_call(event: Event, state: TaskState) -> None: + """Generates a mock function call for input-required events if applicable.""" + if event.content is None: + return + + output_parts, long_running_tool_ids = ( + _create_mock_function_call_for_required_user_input( + state, + event.content.parts, + event.long_running_tool_ids, + ) + ) + event.content.parts = output_parts + event.long_running_tool_ids = long_running_tool_ids + + @a2a_experimental class RemoteA2aAgent(BaseAgent): """Agent that communicates with a remote A2A agent via A2A client. @@ -360,8 +378,40 @@ def _create_a2a_request_for_user_function_response( if not function_call_event: return None + event = ctx.session.events[-1] + # If the user function_response replies to a function_call for non-ADK + # input-required events (fc.name = MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT), + # the function_response part is replaced with text extracted from the + # function response. + # The implementation is based on the assumption that the user function_response + # event will contain a function_response with the name + # MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT and the response will + # contain a "result" field with the user input as a string text. + mock_function_call = [ + fc + for fc in function_call_event.get_function_calls() + if fc.name == MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT + ] + if mock_function_call: + new_parts = [] + for function_response in event.get_function_responses(): + if ( + function_response.name == MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT + and function_response.response + and "result" in function_response.response + ): + text_value = function_response.response.get("result") + new_parts.append( + genai_types.Part( + text=str(text_value), + ) + ) + new_event = event.model_copy(deep=True) + new_event.content.parts = new_parts + event = new_event + a2a_message = convert_event_to_a2a_message( - ctx.session.events[-1], ctx, Role.user, self._genai_part_converter + event, ctx, Role.user, self._genai_part_converter ) if function_call_event.custom_metadata: metadata = function_call_event.custom_metadata @@ -472,6 +522,7 @@ async def _handle_a2a_response( ): for part in event.content.parts: part.thought = True + _add_mock_function_call(event, task.status.state) elif ( isinstance(update, A2ATaskStatusUpdateEvent) and update.status @@ -487,6 +538,7 @@ async def _handle_a2a_response( ): for part in event.content.parts: part.thought = True + _add_mock_function_call(event, update.status.state) elif isinstance(update, A2ATaskArtifactUpdateEvent) and ( not update.append or update.last_chunk ): diff --git a/tests/unittests/a2a/converters/test_to_adk.py b/tests/unittests/a2a/converters/test_to_adk.py index 12eaf2a75a..3ab60f097d 100644 --- a/tests/unittests/a2a/converters/test_to_adk.py +++ b/tests/unittests/a2a/converters/test_to_adk.py @@ -30,6 +30,7 @@ from google.adk.a2a.converters.to_adk_event import convert_a2a_message_to_event from google.adk.a2a.converters.to_adk_event import convert_a2a_status_update_to_event from google.adk.a2a.converters.to_adk_event import convert_a2a_task_to_event +from google.adk.a2a.converters.to_adk_event import MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT from google.adk.a2a.converters.utils import _get_adk_metadata_key from google.adk.agents.invocation_context import InvocationContext from google.genai import types as genai_types @@ -330,12 +331,95 @@ def test_convert_a2a_task_to_event_merges_status_and_artifact_actions(self): assert event.actions.state_delta == {"saved_key": "saved-value"} assert event.actions.transfer_to_agent == "agent-2" assert event.content is not None - assert event.content.parts == [mock_genai_part] + assert ( + event.content.parts[0].function_call.name + == MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT + ) + assert ( + event.content.parts[0].function_call.args["input_required"] + == "need input" + ) + + def test_convert_a2a_task_to_event_multiple_parts_replaces_last_text(self): + """Test converting A2A task with multiple text parts, only replacing the last text.""" + part1 = Mock(spec=A2APart) + part1.root = Mock(spec=TextPart) + part1.root.metadata = {} + part2 = Mock(spec=A2APart) + part2.root = Mock(spec=TextPart) + part2.root.metadata = {} + + task = Task( + id="task-1", + context_id="context-1", + kind="task", + status=TaskStatus( + state=TaskState.input_required, + timestamp="now", + message=Message( + message_id="m1", + role="agent", + parts=[part1, part2], + ), + ), + ) + + mock_genai_part_1 = genai_types.Part.from_text(text="Part 1") + mock_genai_part_2 = genai_types.Part.from_text(text="Part 2") - def test_convert_a2a_task_to_event_none(self): - """Test convert_a2a_task_to_event with None.""" - with pytest.raises(ValueError, match="A2A task cannot be None"): - convert_a2a_task_to_event(None) + part_converter_mock = Mock() + part_converter_mock.side_effect = [[mock_genai_part_1], [mock_genai_part_2]] + + event = convert_a2a_task_to_event( + task, + author="test-author", + invocation_context=self.mock_context, + part_converter=part_converter_mock, + ) + + assert event is not None + assert event.content is not None + assert len(event.content.parts) == 2 + assert event.content.parts[0].text == "Part 1" + assert ( + event.content.parts[1].function_call.name + == MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT + ) + + def test_convert_a2a_task_to_event_no_text_parts(self): + """Test converting A2A task with no text parts should not inject function call.""" + part1 = Mock(spec=A2APart) + part1.root = Mock() # Not a TextPart + part1.root.metadata = {} + + task = Task( + id="task-1", + context_id="context-1", + kind="task", + status=TaskStatus( + state=TaskState.input_required, + timestamp="now", + message=Message( + message_id="m1", + role="agent", + parts=[part1], + ), + ), + ) + mock_image_part = genai_types.Part( + inline_data=genai_types.Blob(mime_type="image/jpeg", data=b"fake") + ) + + event = convert_a2a_task_to_event( + task, + author="test-author", + invocation_context=self.mock_context, + part_converter=Mock(return_value=[mock_image_part]), + ) + + assert event is not None + assert event.content is not None + assert event.content.parts == [mock_image_part] def test_convert_a2a_status_update_to_event_success(self): """Test successful conversion of A2A status update to Event.""" diff --git a/tests/unittests/a2a/integration/test_client_server.py b/tests/unittests/a2a/integration/test_client_server.py index bd1d72f617..18b13d05d2 100644 --- a/tests/unittests/a2a/integration/test_client_server.py +++ b/tests/unittests/a2a/integration/test_client_server.py @@ -14,11 +14,19 @@ """Integration tests for A2A client-server interaction.""" +import logging +from unittest.mock import AsyncMock + +from a2a.server.apps.jsonrpc.fastapi_app import A2AFastAPIApplication +from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types import Message as A2AMessage from a2a.types import Part as A2APart from a2a.types import Task from a2a.types import TaskState +from a2a.types import TaskStatus from a2a.types import TextPart +from google.adk.a2a.agent.interceptors.new_integration_extension import _NEW_A2A_ADK_INTEGRATION_EXTENSION +from google.adk.a2a.converters.to_adk_event import MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT from google.adk.a2a.executor.config import A2aAgentExecutorConfig from google.adk.a2a.executor.interceptors.include_artifacts_in_a2a_event import include_artifacts_in_a2a_event_interceptor from google.adk.agents.remote_a2a_agent import A2A_METADATA_PREFIX @@ -32,8 +40,11 @@ from .client import create_a2a_client from .client import create_client +from .server import agent_card from .server import create_server_app +logger = logging.getLogger("google_adk." + __name__) + def create_streaming_mock_run_async(received_requests: list): """Creates a mock_run_async that streams multiple chunks.""" @@ -636,3 +647,161 @@ async def mock_run_async(**kwargs): assert task.artifacts[2].artifact_id == "artifact2_1" assert task.artifacts[2].name == "artifact2" assert task.artifacts[2].parts[0].root.text == "artifact content" + + +@pytest.mark.asyncio +async def test_user_follow_up_sends_task_id_with_input_required(): + """Test that client follow-up sends the same task_id.""" + + task_id = "mocked-task-id-123" + context_id = "mocked-context-id-456" + mock_task = Task( + id=task_id, + context_id=context_id, + kind="task", + status=TaskStatus( + state=TaskState.input_required, + message=A2AMessage( + message_id="mocked-message-id-789", + role="user", + parts=[A2APart(root=TextPart(text="Input required"))], + ), + ), + metadata={_NEW_A2A_ADK_INTEGRATION_EXTENSION: True}, + ) + + mock_handler = AsyncMock(spec=RequestHandler) + # First call returns input_required, second call completes + mock_handler.on_message_send.side_effect = [ + mock_task, + Task( + id=task_id, + context_id=context_id, + kind="task", + status=TaskStatus(state=TaskState.completed), + metadata={_NEW_A2A_ADK_INTEGRATION_EXTENSION: True}, + ), + ] + + app = A2AFastAPIApplication( + agent_card=agent_card, http_handler=mock_handler + ).build() + agent = create_client(app, streaming=False) + + session_service = InMemorySessionService() + await session_service.create_session( + app_name="ClientApp", user_id="test_user", session_id="test_session" + ) + client_runner = Runner( + app_name="ClientApp", agent=agent, session_service=session_service + ) + + # First Turn + new_message_1 = types.Content(parts=[types.Part(text="Turn 1")], role="user") + found_call_id = None + async for event in client_runner.run_async( + user_id="test_user", session_id="test_session", new_message=new_message_1 + ): + for call in event.get_function_calls(): + if call.name == MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT: + found_call_id = call.id + + assert found_call_id is not None + + # Second Turn (Follow-up) + function_response = types.FunctionResponse( + id=found_call_id, + name=MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT, + response={"result": "Turn 2"}, + ) + new_message_2 = types.Content( + parts=[types.Part(function_response=function_response)], role="user" + ) + async for _ in client_runner.run_async( + user_id="test_user", session_id="test_session", new_message=new_message_2 + ): + pass + + assert mock_handler.on_message_send.call_count == 2 + # Second call args + call_args_2 = mock_handler.on_message_send.call_args_list[1] + params_2 = call_args_2[0][0] + assert params_2.message.task_id == task_id + + +@pytest.mark.asyncio +async def test_user_follow_up_sends_task_id_with_input_required_legacy_impl(): + """Test that client follow-up sends the same task_id.""" + + task_id = "mocked-task-id-123" + context_id = "mocked-context-id-456" + mock_task = Task( + id=task_id, + context_id=context_id, + kind="task", + status=TaskStatus( + state=TaskState.input_required, + message=A2AMessage( + message_id="mocked-message-id-789", + role="user", + parts=[A2APart(root=TextPart(text="Input required"))], + ), + ), + ) + + mock_handler = AsyncMock(spec=RequestHandler) + # First call returns input_required, second call completes + mock_handler.on_message_send.side_effect = [ + mock_task, + Task( + id=task_id, + context_id=context_id, + kind="task", + status=TaskStatus(state=TaskState.completed), + ), + ] + + app = A2AFastAPIApplication( + agent_card=agent_card, http_handler=mock_handler + ).build() + agent = create_client(app, streaming=False) + + session_service = InMemorySessionService() + await session_service.create_session( + app_name="ClientApp", user_id="test_user", session_id="test_session" + ) + client_runner = Runner( + app_name="ClientApp", agent=agent, session_service=session_service + ) + + # First Turn + new_message_1 = types.Content(parts=[types.Part(text="Turn 1")], role="user") + found_call_id = None + async for event in client_runner.run_async( + user_id="test_user", session_id="test_session", new_message=new_message_1 + ): + for call in event.get_function_calls(): + if call.name == MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT: + found_call_id = call.id + + assert found_call_id is not None + + # Second Turn (Follow-up) + function_response = types.FunctionResponse( + id=found_call_id, + name=MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT, + response={"result": "Turn 2"}, + ) + new_message_2 = types.Content( + parts=[types.Part(function_response=function_response)], role="user" + ) + async for _ in client_runner.run_async( + user_id="test_user", session_id="test_session", new_message=new_message_2 + ): + pass + + assert mock_handler.on_message_send.call_count == 2 + # Second call args + call_args_2 = mock_handler.on_message_send.call_args_list[1] + params_2 = call_args_2[0][0] + assert params_2.message.task_id == task_id diff --git a/tests/unittests/agents/test_remote_a2a_agent.py b/tests/unittests/agents/test_remote_a2a_agent.py index f5e92bc46d..073ad36d9f 100644 --- a/tests/unittests/agents/test_remote_a2a_agent.py +++ b/tests/unittests/agents/test_remote_a2a_agent.py @@ -573,6 +573,9 @@ def test_create_a2a_request_for_user_function_response_success(self): mock_function_event.custom_metadata = { A2A_METADATA_PREFIX + "task_id": "task-123" } + mock_function_event.content = Mock() + mock_function_event.content.parts = [Mock()] + mock_function_event.get_function_calls.return_value = [] # Mock latest event with function response - set proper author mock_latest_event = Mock() @@ -1372,6 +1375,9 @@ def test_create_a2a_request_for_user_function_response_success(self): mock_function_event.custom_metadata = { A2A_METADATA_PREFIX + "task_id": "task-123" } + mock_function_event.content = Mock() + mock_function_event.content.parts = [Mock()] + mock_function_event.get_function_calls.return_value = [] # Mock latest event with function response - set proper author mock_latest_event = Mock() From e377cb5ec057ed4176f2714f368c45e730053eb0 Mon Sep 17 00:00:00 2001 From: Kathy Wu Date: Wed, 13 May 2026 10:25:12 -0700 Subject: [PATCH 03/28] fix: fallback to project id if crendetials don't contain quota project Co-authored-by: Kathy Wu PiperOrigin-RevId: 914934690 --- .../adk/integrations/agent_registry/agent_registry.py | 5 ++++- .../integrations/agent_registry/test_agent_registry.py | 9 +++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/google/adk/integrations/agent_registry/agent_registry.py b/src/google/adk/integrations/agent_registry/agent_registry.py index 887c894cd4..a486215151 100644 --- a/src/google/adk/integrations/agent_registry/agent_registry.py +++ b/src/google/adk/integrations/agent_registry/agent_registry.py @@ -208,7 +208,10 @@ def _get_auth_headers(self) -> Dict[str, str]: "Authorization": f"Bearer {self._credentials.token}", "Content-Type": "application/json", } - quota_project_id = getattr(self._credentials, "quota_project_id", None) + quota_project_id = ( + getattr(self._credentials, "quota_project_id", None) + or self.project_id + ) if quota_project_id: headers["x-goog-user-project"] = quota_project_id return headers diff --git a/tests/unittests/integrations/agent_registry/test_agent_registry.py b/tests/unittests/integrations/agent_registry/test_agent_registry.py index fd3f2a8ec4..f4ba47cf25 100644 --- a/tests/unittests/integrations/agent_registry/test_agent_registry.py +++ b/tests/unittests/integrations/agent_registry/test_agent_registry.py @@ -607,6 +607,15 @@ def test_get_auth_headers(self, registry): assert headers["Authorization"] == "Bearer fake-token" assert headers["x-goog-user-project"] == "quota-project" + def test_get_auth_headers_fallback_to_project_id(self, registry): + registry._credentials.token = "fake-token" + registry._credentials.refresh = MagicMock() + registry._credentials.quota_project_id = None + + headers = registry._get_auth_headers() + assert headers["Authorization"] == "Bearer fake-token" + assert headers["x-goog-user-project"] == "test-project" + @patch("httpx.Client") def test_make_request_raises_http_status_error(self, mock_httpx, registry): mock_response = MagicMock() From f9097cbf7b64b78da894e482480fc22a9603e429 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 13 May 2026 12:31:33 -0700 Subject: [PATCH 04/28] fix: Fix missing dynamically loaded tools in SkillToolset during the same invocation Currently, BaseToolset caches its tools per invocation_id. Because SkillToolset dynamically resolves additional tools from the state when a skill is loaded, the cache prevents new tools from being picked up in the same invocation right after a load_skill call. This change sets `_use_invocation_cache = False` in SkillToolset so that it correctly re-evaluates the state-dependent tools at each step of the LLM generation loop within an invocation, preventing "Tool not found" errors. PiperOrigin-RevId: 914997555 --- src/google/adk/tools/skill_toolset.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/google/adk/tools/skill_toolset.py b/src/google/adk/tools/skill_toolset.py index 3c60bd5918..2de30c8f12 100644 --- a/src/google/adk/tools/skill_toolset.py +++ b/src/google/adk/tools/skill_toolset.py @@ -895,6 +895,8 @@ def __init__( script_timeout: Timeout in seconds for shell script execution via subprocess.run. Defaults to 300 seconds. Does not apply to Python scripts executed via exec(). + additional_tools: Optional list of `BaseTool` or `BaseToolset` instances + to be made available to the agent when certain skills are activated. """ super().__init__() @@ -911,6 +913,8 @@ def __init__( self._registry = registry self._code_executor = code_executor self._script_timeout = script_timeout + # Needed for mid-turn reloading of skill tools. + self._use_invocation_cache = False self._invocation_cache: dict[ str, dict[str, models.Skill | asyncio.Future[models.Skill | None] | None], From e04a4683f46a1c956bdd7f607c55ea9017bc5e5d Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Wed, 13 May 2026 15:48:07 -0700 Subject: [PATCH 05/28] chore: Remove deprecated CLI flags and version-based service URI handling This change removes support for the deprecated `--session_db_url`, `--artifact_storage_uri`, and `--verbosity` flags from the ADK CLI. It also simplifies the service URI handling in `cli_deploy.py` by always using the new `--session_service_uri`, `--artifact_service_uri`, and `--memory_service_uri` flags, regardless of the ADK version. The deprecated flags has been more than 1 year Co-authored-by: Shangjie Chen PiperOrigin-RevId: 915100881 --- src/google/adk/cli/cli_deploy.py | 18 ++--- src/google/adk/cli/cli_tools_click.py | 68 +------------------ tests/unittests/cli/utils/test_cli_deploy.py | 12 +++- .../cli/utils/test_cli_tools_click.py | 35 ---------- 4 files changed, 16 insertions(+), 117 deletions(-) diff --git a/src/google/adk/cli/cli_deploy.py b/src/google/adk/cli/cli_deploy.py index 19a7dde9ae..4350ae17ce 100644 --- a/src/google/adk/cli/cli_deploy.py +++ b/src/google/adk/cli/cli_deploy.py @@ -598,18 +598,12 @@ def _get_service_option_by_adk_version( parsed_version = parse(adk_version) options: list[str] = [] - if parsed_version >= parse('1.3.0'): - if session_uri: - options.append(f'--session_service_uri={session_uri}') - if artifact_uri: - options.append(f'--artifact_service_uri={artifact_uri}') - if memory_uri: - options.append(f'--memory_service_uri={memory_uri}') - else: - if session_uri: - options.append(f'--session_db_url={session_uri}') - if parsed_version >= parse('1.2.0') and artifact_uri: - options.append(f'--artifact_storage_uri={artifact_uri}') + if session_uri: + options.append(f'--session_service_uri={session_uri}') + if artifact_uri: + options.append(f'--artifact_service_uri={artifact_uri}') + if memory_uri: + options.append(f'--memory_service_uri={memory_uri}') if use_local_storage is not None and parsed_version >= parse( _LOCAL_STORAGE_FLAG_MIN_VERSION diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 07ccc15892..fda251da10 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -1411,43 +1411,6 @@ def _deprecate_staging_bucket(ctx, param, value): return value -def deprecated_adk_services_options(): - """Deprecated ADK services options.""" - - def warn(alternative_param, ctx, param, value): - if value: - click.echo( - click.style( - f"WARNING: Deprecated option --{param.name} is used. Please use" - f" {alternative_param} instead.", - fg="yellow", - ), - err=True, - ) - return value - - def decorator(func): - @click.option( - "--session_db_url", - help="Deprecated. Use --session_service_uri instead.", - callback=functools.partial(warn, "--session_service_uri"), - ) - @click.option( - "--artifact_storage_uri", - type=str, - help="Deprecated. Use --artifact_service_uri instead.", - callback=functools.partial(warn, "--artifact_service_uri"), - default=None, - ) - @functools.wraps(func) - def wrapper(*args, **kwargs): - return func(*args, **kwargs) - - return wrapper - - return decorator - - def fast_api_common_options(): """Decorator to add common fast api options to click commands.""" @@ -1598,7 +1561,6 @@ def wrapper(ctx, *args, **kwargs): @fast_api_common_options() @web_options() @adk_services_options(default_use_local_storage=True) -@deprecated_adk_services_options() @click.argument( "agents_dir", type=click.Path( @@ -1621,8 +1583,6 @@ def cli_web( artifact_service_uri: Optional[str] = None, memory_service_uri: Optional[str] = None, use_local_storage: bool = True, - session_db_url: Optional[str] = None, # Deprecated - artifact_storage_uri: Optional[str] = None, # Deprecated a2a: bool = False, reload_agents: bool = False, extra_plugins: Optional[list[str]] = None, @@ -1639,8 +1599,6 @@ def cli_web( adk web --session_service_uri=[uri] --port=[port] path/to/agents_dir """ - session_service_uri = session_service_uri or session_db_url - artifact_service_uri = artifact_service_uri or artifact_storage_uri logs.setup_adk_logger(getattr(logging, log_level.upper())) @asynccontextmanager @@ -1711,7 +1669,6 @@ async def _lifespan(app: FastAPI): ) @fast_api_common_options() @adk_services_options(default_use_local_storage=True) -@deprecated_adk_services_options() @click.option( "--auto_create_session", is_flag=True, @@ -1735,8 +1692,6 @@ def cli_api_server( artifact_service_uri: Optional[str] = None, memory_service_uri: Optional[str] = None, use_local_storage: bool = True, - session_db_url: Optional[str] = None, # Deprecated - artifact_storage_uri: Optional[str] = None, # Deprecated a2a: bool = False, reload_agents: bool = False, extra_plugins: Optional[list[str]] = None, @@ -1752,8 +1707,6 @@ def cli_api_server( adk api_server --session_service_uri=[uri] --port=[port] path/to/agents_dir """ - session_service_uri = session_service_uri or session_db_url - artifact_service_uri = artifact_service_uri or artifact_storage_uri logs.setup_adk_logger(getattr(logging, log_level.upper())) config = uvicorn.Config( @@ -1882,11 +1835,6 @@ def cli_api_server( default="INFO", help="Optional. Set the logging level", ) -@click.option( - "--verbosity", - type=LOG_LEVELS, - help="Deprecated. Use --log_level instead.", -) @click.argument( "agent", type=click.Path( @@ -1932,7 +1880,6 @@ def cli_api_server( ) # TODO: Add eval_storage_uri option back when evals are supported in Cloud Run. @adk_services_options(default_use_local_storage=False) -@deprecated_adk_services_options() @click.pass_context def cli_deploy_cloud_run( ctx, @@ -1948,14 +1895,11 @@ def cli_deploy_cloud_run( with_ui: bool, adk_version: str, log_level: str, - verbosity: Optional[str], allow_origins: Optional[list[str]] = None, session_service_uri: Optional[str] = None, artifact_service_uri: Optional[str] = None, memory_service_uri: Optional[str] = None, use_local_storage: bool = False, - session_db_url: Optional[str] = None, # Deprecated - artifact_storage_uri: Optional[str] = None, # Deprecated a2a: bool = False, trigger_sources: Optional[str] = None, ): @@ -1972,19 +1916,9 @@ def cli_deploy_cloud_run( adk deploy cloud_run --project=[project] --region=[region] path/to/my_agent -- --no-allow-unauthenticated --min-instances=2 """ - if verbosity: - click.secho( - "WARNING: The --verbosity option is deprecated. Use --log_level" - " instead.", - fg="yellow", - err=True, - ) _warn_if_with_ui(with_ui) - session_service_uri = session_service_uri or session_db_url - artifact_service_uri = artifact_service_uri or artifact_storage_uri - # Parse arguments to separate gcloud args (after --) from regular args gcloud_args = [] if "--" in ctx.args: @@ -2028,7 +1962,7 @@ def cli_deploy_cloud_run( allow_origins=allow_origins, with_ui=with_ui, log_level=log_level, - verbosity=verbosity, + verbosity=log_level, adk_version=adk_version, session_service_uri=session_service_uri, artifact_service_uri=artifact_service_uri, diff --git a/tests/unittests/cli/utils/test_cli_deploy.py b/tests/unittests/cli/utils/test_cli_deploy.py index a1e66139f8..506e274845 100644 --- a/tests/unittests/cli/utils/test_cli_deploy.py +++ b/tests/unittests/cli/utils/test_cli_deploy.py @@ -147,7 +147,10 @@ def test_resolve_project_from_gcloud_fails( "gs://a", "rag://m", None, - "--session_db_url=sqlite://s --artifact_storage_uri=gs://a", + ( + "--session_service_uri=sqlite://s --artifact_service_uri=gs://a" + " --memory_service_uri=rag://m" + ), ), ( "0.5.0", @@ -155,7 +158,10 @@ def test_resolve_project_from_gcloud_fails( "gs://a", "rag://m", None, - "--session_db_url=sqlite://s", + ( + "--session_service_uri=sqlite://s --artifact_service_uri=gs://a" + " --memory_service_uri=rag://m" + ), ), ( "1.3.0", @@ -179,7 +185,7 @@ def test_resolve_project_from_gcloud_fails( "gs://a", None, None, - "--artifact_storage_uri=gs://a", + "--artifact_service_uri=gs://a", ), ( "1.21.0", diff --git a/tests/unittests/cli/utils/test_cli_tools_click.py b/tests/unittests/cli/utils/test_cli_tools_click.py index 7c642dbbe9..ad47c9ecbe 100644 --- a/tests/unittests/cli/utils/test_cli_tools_click.py +++ b/tests/unittests/cli/utils/test_cli_tools_click.py @@ -596,41 +596,6 @@ def test_cli_web_passes_service_uris( assert called_kwargs.get("memory_service_uri") == "rag://mycorpus" -@pytest.mark.unmute_click -def test_cli_web_warns_and_maps_deprecated_uris( - tmp_path: Path, - _patch_uvicorn: _Recorder, - monkeypatch: pytest.MonkeyPatch, -) -> None: - """`adk web` should accept deprecated URI flags with warnings.""" - agents_dir = tmp_path / "agents" - agents_dir.mkdir() - - mock_get_app = _Recorder() - monkeypatch.setattr(cli_tools_click, "get_fast_api_app", mock_get_app) - - runner = CliRunner() - result = runner.invoke( - cli_tools_click.main, - [ - "web", - str(agents_dir), - "--session_db_url", - "sqlite:///deprecated.db", - "--artifact_storage_uri", - "gs://deprecated", - ], - ) - - assert result.exit_code == 0 - called_kwargs = mock_get_app.calls[0][1] - assert called_kwargs.get("session_service_uri") == "sqlite:///deprecated.db" - assert called_kwargs.get("artifact_service_uri") == "gs://deprecated" - # Check output for deprecation warnings (CliRunner captures both stdout and stderr) - assert "--session_db_url" in result.output - assert "--artifact_storage_uri" in result.output - - def test_cli_eval_with_eval_set_file_path( mock_load_eval_set_from_file, mock_get_root_agent, From c35a57969d70cb98356297beb36fdf79ab7c00f6 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 13 May 2026 16:49:15 -0700 Subject: [PATCH 06/28] fix(auth): remove unneeded OAuth flows This is related to https://github.com/google/adk-python/issues/5327. `BaseToolset.get_auth_config()` returns `None` by default and removing its overrides in toolsets that don't need OAuth flows to list tools does the job. No regressions in unit tests (`pytest tests/unittests/auth`): ``` ================================================================= 181 passed, 508 warnings in 4.23s ================================================================= ``` PiperOrigin-RevId: 915128706 --- src/google/adk/tools/apihub_tool/apihub_toolset.py | 10 ---------- .../application_integration_toolset.py | 10 ---------- .../openapi_spec_parser/openapi_toolset.py | 11 ----------- 3 files changed, 31 deletions(-) diff --git a/src/google/adk/tools/apihub_tool/apihub_toolset.py b/src/google/adk/tools/apihub_tool/apihub_toolset.py index 5560c7ce52..19c6d709e3 100644 --- a/src/google/adk/tools/apihub_tool/apihub_toolset.py +++ b/src/google/adk/tools/apihub_tool/apihub_toolset.py @@ -199,13 +199,3 @@ def _prepare_toolset(self) -> None: async def close(self): if self._openapi_toolset: await self._openapi_toolset.close() - - @override - def get_auth_config(self) -> Optional[AuthConfig]: - """Returns the auth config for this toolset. - - ADK will populate exchanged_auth_credential on this config before calling - get_tools(). The toolset can then access the ready-to-use credential via - self._auth_config.exchanged_auth_credential. - """ - return self._auth_config diff --git a/src/google/adk/tools/application_integration_tool/application_integration_toolset.py b/src/google/adk/tools/application_integration_tool/application_integration_toolset.py index 5e068fba1d..e4e2c5dde7 100644 --- a/src/google/adk/tools/application_integration_tool/application_integration_toolset.py +++ b/src/google/adk/tools/application_integration_tool/application_integration_toolset.py @@ -289,13 +289,3 @@ async def get_tools( async def close(self) -> None: if self._openapi_toolset: await self._openapi_toolset.close() - - @override - def get_auth_config(self) -> Optional[AuthConfig]: - """Returns the auth config for this toolset. - - ADK will populate exchanged_auth_credential on this config before calling - get_tools(). The toolset can then access the ready-to-use credential via - self._auth_config.exchanged_auth_credential. - """ - return self._auth_config diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py index 235ae7c350..99e649d9c9 100644 --- a/src/google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py @@ -245,14 +245,3 @@ def _parse(self, openapi_spec_dict: Dict[str, Any]) -> List[RestApiTool]: @override async def close(self): pass - - @override - def get_auth_config(self) -> Optional[AuthConfig]: - """Returns the auth config for this toolset. - - Note: This returns a copy so any exchanged credentials populated by the ADK - framework do not persist on the toolset instance across invocations. - """ - return ( - self._auth_config.model_copy(deep=True) if self._auth_config else None - ) From cfe8d2cc2b29e392886f997be4d77d4cced9959e Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 14 May 2026 11:54:28 -0700 Subject: [PATCH 07/28] feat: Add mTLS support to Google Cloud Telemetry exporter This change enables the Google Cloud Telemetry exporter to use mTLS endpoints. It checks for the availability of client certificates and respects the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variables to determine whether to use the mTLS-specific endpoint and configure the session accordingly. PiperOrigin-RevId: 915541335 --- src/google/adk/telemetry/google_cloud.py | 94 +++++++++++++- .../unittests/telemetry/test_google_cloud.py | 115 ++++++++++++++++++ 2 files changed, 207 insertions(+), 2 deletions(-) diff --git a/src/google/adk/telemetry/google_cloud.py b/src/google/adk/telemetry/google_cloud.py index dee8f3f554..c34cdba90c 100644 --- a/src/google/adk/telemetry/google_cloud.py +++ b/src/google/adk/telemetry/google_cloud.py @@ -14,13 +14,17 @@ from __future__ import annotations +import enum import logging import os +from typing import Any +from typing import Callable from typing import cast from typing import Optional from typing import TYPE_CHECKING import google.auth +from google.auth.transport import mtls from opentelemetry.sdk._logs import LogRecordProcessor from opentelemetry.sdk._logs.export import BatchLogRecordProcessor from opentelemetry.sdk.metrics.export import MetricReader @@ -40,6 +44,19 @@ _GCP_LOG_NAME_ENV_VARIABLE_NAME = 'GOOGLE_CLOUD_DEFAULT_LOG_NAME' _DEFAULT_LOG_NAME = 'adk-otel' +_DEFAULT_TELEMETRY_TRACES_ENPOINT = 'https://telemetry.googleapis.com/v1/traces' +_DEFAULT_MTLS_TELEMETRY_TRACES_ENPOINT = ( + 'https://telemetry.mtls.googleapis.com/v1/traces' +) + + +class _MtlsEndpoint(enum.Enum): + """The mTLS endpoint setting.""" + + AUTO = 'auto' + ALWAYS = 'always' + NEVER = 'never' + def get_gcp_exporters( enable_cloud_tracing: bool = False, @@ -100,10 +117,24 @@ def _get_gcp_span_exporter(credentials: Credentials) -> SpanProcessor: from google.auth.transport.requests import AuthorizedSession from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter + session = AuthorizedSession(credentials=credentials) + + use_client_cert = _use_client_cert_effective() + if use_client_cert: + client_cert_source = ( + mtls.default_client_cert_source() + if mtls.has_default_client_cert_source() + else None + ) + session.configure_mtls_channel() + endpoint = _get_api_endpoint(client_cert_source) + else: + endpoint = _DEFAULT_TELEMETRY_TRACES_ENPOINT + return BatchSpanProcessor( OTLPSpanExporter( - session=AuthorizedSession(credentials=credentials), - endpoint='https://telemetry.googleapis.com/v1/traces', + session=session, + endpoint=endpoint, ) ) @@ -158,3 +189,62 @@ def get_gcp_resource(project_id: Optional[str] = None) -> Resource: ' GCE, GKE or CloudRun related resource attributes may be missing' ) return resource + + +def _get_api_endpoint( + client_cert_source: Callable[[], tuple[bytes, bytes]] | None = None, +) -> str: + """Returns API endpoint based on mTLS configuration and cert availability. + + Args: + client_cert_source: A callable that returns the client certificate and + key, or None. + + Returns: + str: The API endpoint to be used. + """ + use_mtls_endpoint_str = os.getenv( + 'GOOGLE_API_USE_MTLS_ENDPOINT', _MtlsEndpoint.AUTO.value + ).lower() + + try: + use_mtls_endpoint = _MtlsEndpoint(use_mtls_endpoint_str) + except ValueError: + logger.warning( + 'Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be one of ' + '%s. Defaulting to %s.', + [e.value for e in _MtlsEndpoint], + _MtlsEndpoint.AUTO.value, + ) + use_mtls_endpoint = _MtlsEndpoint.AUTO + + if (use_mtls_endpoint is _MtlsEndpoint.ALWAYS) or ( + use_mtls_endpoint is _MtlsEndpoint.AUTO and client_cert_source + ): + return _DEFAULT_MTLS_TELEMETRY_TRACES_ENPOINT + + return _DEFAULT_TELEMETRY_TRACES_ENPOINT + + +def _use_client_cert_effective() -> bool: + """Returns whether client certificate should be used for mTLS. + + This checks if the google-auth version supports should_use_client_cert + automatic mTLS enablement. Alternatively, it reads from the + GOOGLE_API_USE_CLIENT_CERTIFICATE env var. + + Returns: + bool: whether client certificate should be used for mTLS. + """ + try: + return bool(mtls.should_use_client_cert()) + except (ImportError, AttributeError): + use_client_cert_str = os.getenv( + 'GOOGLE_API_USE_CLIENT_CERTIFICATE', 'false' + ).lower() + if use_client_cert_str not in ('true', 'false'): + logger.warning( + 'Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be' + ' either `true` or `false`' + ) + return use_client_cert_str == 'true' diff --git a/tests/unittests/telemetry/test_google_cloud.py b/tests/unittests/telemetry/test_google_cloud.py index 0199e7b4b6..284a51539e 100644 --- a/tests/unittests/telemetry/test_google_cloud.py +++ b/tests/unittests/telemetry/test_google_cloud.py @@ -16,8 +16,18 @@ from typing import Optional from unittest import mock +from google.adk.telemetry import google_cloud +from google.adk.telemetry.google_cloud import _DEFAULT_MTLS_TELEMETRY_TRACES_ENPOINT +from google.adk.telemetry.google_cloud import _DEFAULT_TELEMETRY_TRACES_ENPOINT +from google.adk.telemetry.google_cloud import _get_api_endpoint +from google.adk.telemetry.google_cloud import _get_gcp_span_exporter +from google.adk.telemetry.google_cloud import _use_client_cert_effective from google.adk.telemetry.google_cloud import get_gcp_exporters from google.adk.telemetry.google_cloud import get_gcp_resource +import google.auth.credentials +from google.auth.transport import mtls +from google.auth.transport import requests +from opentelemetry.exporter.otlp.proto.http import trace_exporter import pytest @@ -89,3 +99,108 @@ def test_get_gcp_resource( otel_resource.attributes.get("gcp.project_id", None) == expected_project_id ) + + +@mock.patch.object(mtls, "should_use_client_cert", autospec=True) +def test_use_client_cert_effective_from_mtls(mock_should_use): + mock_should_use.return_value = True + assert _use_client_cert_effective() + + mock_should_use.return_value = False + assert not _use_client_cert_effective() + + +def test_use_client_cert_effective_from_env( + monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture +): + with mock.patch.object( + mtls, + "should_use_client_cert", + autospec=True, + side_effect=AttributeError, + ): + monkeypatch.setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "true") + assert _use_client_cert_effective() + + monkeypatch.setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + assert not _use_client_cert_effective() + + # Test invalid value defaults to False + monkeypatch.setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "maybe") + assert not _use_client_cert_effective() + assert ( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be" + " either `true` or `false`" + in caplog.text + ) + + +@pytest.mark.parametrize( + "env_val, cert_source, expected", + [ + ("auto", lambda: b"cert", _DEFAULT_MTLS_TELEMETRY_TRACES_ENPOINT), + ("auto", None, _DEFAULT_TELEMETRY_TRACES_ENPOINT), + ("always", None, _DEFAULT_MTLS_TELEMETRY_TRACES_ENPOINT), + ("never", lambda: b"cert", _DEFAULT_TELEMETRY_TRACES_ENPOINT), + ("invalid", None, _DEFAULT_TELEMETRY_TRACES_ENPOINT), + ], +) +def test_get_api_endpoint( + env_val, + cert_source, + expected, + monkeypatch: pytest.MonkeyPatch, + caplog: pytest.LogCaptureFixture, +): + monkeypatch.setenv("GOOGLE_API_USE_MTLS_ENDPOINT", env_val) + if env_val == "invalid": + assert _get_api_endpoint(cert_source) == expected + assert ( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be one of" + in caplog.text + ) + else: + assert _get_api_endpoint(cert_source) == expected + + +@mock.patch.object(requests, "AuthorizedSession", autospec=True) +@mock.patch( + "opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter", + autospec=True, +) +@mock.patch( + "google.adk.telemetry.google_cloud.BatchSpanProcessor", autospec=True +) +@mock.patch( + "google.adk.telemetry.google_cloud._use_client_cert_effective", + autospec=True, +) +@mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", autospec=True +) +@mock.patch( + "google.auth.transport.mtls.default_client_cert_source", autospec=True +) +def test_get_gcp_span_exporter_mtls( + mock_default_cert: mock.MagicMock, + mock_has_cert: mock.MagicMock, + mock_use_cert: mock.MagicMock, + mock_batch: mock.MagicMock, + mock_exporter: mock.MagicMock, + mock_session: mock.MagicMock, +): + credentials = mock.create_autospec( + google.auth.credentials.Credentials, instance=True + ) + mock_use_cert.return_value = True + mock_has_cert.return_value = True + mock_default_cert.return_value = b"cert" + + _get_gcp_span_exporter(credentials) + + mock_session.assert_called_once_with(credentials=credentials) + mock_session.return_value.configure_mtls_channel.assert_called_once() + mock_exporter.assert_called_once_with( + session=mock_session.return_value, + endpoint=_DEFAULT_MTLS_TELEMETRY_TRACES_ENPOINT, + ) From 8de1ae8f8054c5b82a633c4412a8c5685b763851 Mon Sep 17 00:00:00 2001 From: Kathy Wu Date: Thu, 14 May 2026 13:09:58 -0700 Subject: [PATCH 08/28] chore: change name of skill cache to "fetched_skill_cache" to reduce confusion from BaseToolset's invocation_cache Co-authored-by: Kathy Wu PiperOrigin-RevId: 915578052 --- src/google/adk/tools/skill_toolset.py | 27 +++++++++---- tests/unittests/tools/test_skill_toolset.py | 44 ++++++++++++++++++--- 2 files changed, 58 insertions(+), 13 deletions(-) diff --git a/src/google/adk/tools/skill_toolset.py b/src/google/adk/tools/skill_toolset.py index 2de30c8f12..ef579d8256 100644 --- a/src/google/adk/tools/skill_toolset.py +++ b/src/google/adk/tools/skill_toolset.py @@ -19,6 +19,7 @@ from __future__ import annotations import asyncio +import collections import json import logging import mimetypes @@ -915,10 +916,11 @@ def __init__( self._script_timeout = script_timeout # Needed for mid-turn reloading of skill tools. self._use_invocation_cache = False - self._invocation_cache: dict[ + # Cache fetched remote skill definitions per turn to reduce requests to registry + self._fetched_skill_cache: collections.OrderedDict[ str, dict[str, models.Skill | asyncio.Future[models.Skill | None] | None], - ] = {} + ] = collections.OrderedDict() self._max_cache_turns = 16 self._provided_tools_by_name = {} @@ -1023,14 +1025,13 @@ async def _get_or_fetch_skill( return None if invocation_id: - if invocation_id not in self._invocation_cache: + if invocation_id not in self._fetched_skill_cache: # Enforce bounded cache (FIFO eviction) - if len(self._invocation_cache) >= self._max_cache_turns: - oldest = next(iter(self._invocation_cache)) - self._invocation_cache.pop(oldest) - self._invocation_cache[invocation_id] = {} + if len(self._fetched_skill_cache) >= self._max_cache_turns: + self._fetched_skill_cache.popitem(last=False) + self._fetched_skill_cache[invocation_id] = {} - turn_cache = self._invocation_cache[invocation_id] + turn_cache = self._fetched_skill_cache[invocation_id] if skill_name in turn_cache: cached = turn_cache[skill_name] if isinstance(cached, asyncio.Future): @@ -1080,6 +1081,16 @@ async def process_llm_request( llm_request.append_instructions(instructions) + @override + async def close(self) -> None: + """Performs cleanup and releases resources held by the toolset.""" + for turn_cache in self._fetched_skill_cache.values(): + for cached in turn_cache.values(): + if isinstance(cached, asyncio.Future) and not cached.done(): + cached.cancel() + self._fetched_skill_cache.clear() + await super().close() + def __getattr__(name: str) -> Any: if name == "DEFAULT_SKILL_SYSTEM_INSTRUCTION": diff --git a/tests/unittests/tools/test_skill_toolset.py b/tests/unittests/tools/test_skill_toolset.py index b58c01b91b..f637377513 100644 --- a/tests/unittests/tools/test_skill_toolset.py +++ b/tests/unittests/tools/test_skill_toolset.py @@ -13,6 +13,7 @@ # limitations under the License. import asyncio +import collections import logging import sys from unittest import mock @@ -1859,14 +1860,14 @@ async def test_turn_scoped_skill_cache_eviction(mock_registry, mock_skill1): for i in range(16): await toolset._get_or_fetch_skill("skill1", f"turn-{i}") - assert len(toolset._invocation_cache) == 16 - assert "turn-0" in toolset._invocation_cache + assert len(toolset._fetched_skill_cache) == 16 + assert "turn-0" in toolset._fetched_skill_cache # Next turn should evict oldest (turn-0) await toolset._get_or_fetch_skill("skill1", "turn-16") - assert len(toolset._invocation_cache) == 16 - assert "turn-0" not in toolset._invocation_cache - assert "turn-1" in toolset._invocation_cache + assert len(toolset._fetched_skill_cache) == 16 + assert "turn-0" not in toolset._fetched_skill_cache + assert "turn-1" in toolset._fetched_skill_cache @pytest.mark.asyncio @@ -1891,3 +1892,36 @@ async def delayed_get_skill(name): # Registry should have been called exactly once mock_registry.get_skill.assert_called_once_with(name="skill1") + + +def test_skill_toolset_disables_invocation_cache(): + """Verify SkillToolset disables tool invocation caching to allow dynamic tools.""" + toolset = skill_toolset.SkillToolset() + assert toolset._use_invocation_cache is False + + +@pytest.mark.asyncio +async def test_close_cancels_futures_and_clears_cache(): + # pylint: disable=protected-access + toolset = skill_toolset.SkillToolset() + + # Create mock futures for testing close() behavior + loop = asyncio.get_running_loop() + fut1 = loop.create_future() + fut2 = loop.create_future() + fut2.set_result(None) # Already done future + + toolset._fetched_skill_cache = collections.OrderedDict( + { + "turn1": { + "skill1": fut1, + "skill2": fut2, + } + } + ) + + await toolset.close() + + assert fut1.cancelled() + assert not fut2.cancelled() # Done futures shouldn't/can't be cancelled + assert not toolset._fetched_skill_cache From 9f38973081aacf1999f707dac9778b72b5ce75fd Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 14 May 2026 14:24:15 -0700 Subject: [PATCH 09/28] feat: Make Agent Skill description validation more informative PiperOrigin-RevId: 915616175 --- src/google/adk/skills/models.py | 8 ++++++-- tests/unittests/skills/test_models.py | 5 ++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/google/adk/skills/models.py b/src/google/adk/skills/models.py index 1cd443b2ca..9e9b378a97 100644 --- a/src/google/adk/skills/models.py +++ b/src/google/adk/skills/models.py @@ -108,8 +108,12 @@ def _validate_name(cls, v: str) -> str: def _validate_description(cls, v: str) -> str: if not v: raise ValueError("description must not be empty") - if len(v) > 1024: - raise ValueError("description must be at most 1024 characters") + description_len = len(v) + if description_len > 1024: + raise ValueError( + "description must be at most 1024 characters. Description length:" + f" {description_len}" + ) return v @field_validator("compatibility") diff --git a/tests/unittests/skills/test_models.py b/tests/unittests/skills/test_models.py index 73136be0b9..ffbbb2dd50 100644 --- a/tests/unittests/skills/test_models.py +++ b/tests/unittests/skills/test_models.py @@ -153,7 +153,10 @@ def test_description_empty(): def test_description_too_long(): - with pytest.raises(ValidationError, match="at most 1024 characters"): + with pytest.raises( + ValidationError, + match="at most 1024 characters. Description length: 1025", + ): models.Frontmatter(name="my-skill", description="x" * 1025) From 88ebd426beaec9564bec1fe98ad0096bba519e3d Mon Sep 17 00:00:00 2001 From: Kathy Wu Date: Thu, 14 May 2026 14:47:10 -0700 Subject: [PATCH 10/28] feat: Implement GCPSkillRegistry in ADK This CL implements the class in the integrations folder, used specifically for the Skill Registry API. Co-authored-by: Kathy Wu PiperOrigin-RevId: 915627057 --- .../gcp_skill_registry_agent/__init__.py | 15 ++ .../samples/gcp_skill_registry_agent/agent.py | 40 ++++ .../integrations/skill_registry/__init__.py | 19 ++ .../skill_registry/gcp_skill_registry.py | 93 +++++++++ src/google/adk/skills/_utils.py | 93 +++++++++ .../skill_registry/test_gcp_skill_registry.py | 184 ++++++++++++++++++ tests/unittests/skills/test__utils.py | 23 +++ 7 files changed, 467 insertions(+) create mode 100644 contributing/samples/gcp_skill_registry_agent/__init__.py create mode 100644 contributing/samples/gcp_skill_registry_agent/agent.py create mode 100644 src/google/adk/integrations/skill_registry/__init__.py create mode 100644 src/google/adk/integrations/skill_registry/gcp_skill_registry.py create mode 100644 tests/unittests/integrations/skill_registry/test_gcp_skill_registry.py diff --git a/contributing/samples/gcp_skill_registry_agent/__init__.py b/contributing/samples/gcp_skill_registry_agent/__init__.py new file mode 100644 index 0000000000..4015e47d6e --- /dev/null +++ b/contributing/samples/gcp_skill_registry_agent/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/gcp_skill_registry_agent/agent.py b/contributing/samples/gcp_skill_registry_agent/agent.py new file mode 100644 index 0000000000..558b6ecfff --- /dev/null +++ b/contributing/samples/gcp_skill_registry_agent/agent.py @@ -0,0 +1,40 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sample agent demonstrating the use of GCPSkillRegistry.""" + +from google.adk import Agent +from google.adk.integrations.skill_registry import GCPSkillRegistry +from google.adk.tools.skill_toolset import SkillToolset + +# Initialize GCP Skill Registry +registry = GCPSkillRegistry( + project_id="your-project-id", location="us-central1" +) + +# Initialize SkillToolset with registry +skill_toolset = SkillToolset(skills=[], registry=registry) + +root_agent = Agent( + model="gemini-2.5-flash", + name="skill_registry_agent", + description=( + "An agent that can discover and load skills from GCP Skill Registry." + ), + instruction=( + "Use search_skills to find skills and load_skill to load them if" + " needed." + ), + tools=[skill_toolset], +) diff --git a/src/google/adk/integrations/skill_registry/__init__.py b/src/google/adk/integrations/skill_registry/__init__.py new file mode 100644 index 0000000000..5cfd76a29d --- /dev/null +++ b/src/google/adk/integrations/skill_registry/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Skill Registry integrations.""" + +from .gcp_skill_registry import GCPSkillRegistry + +__all__ = ["GCPSkillRegistry"] diff --git a/src/google/adk/integrations/skill_registry/gcp_skill_registry.py b/src/google/adk/integrations/skill_registry/gcp_skill_registry.py new file mode 100644 index 0000000000..277913c1b4 --- /dev/null +++ b/src/google/adk/integrations/skill_registry/gcp_skill_registry.py @@ -0,0 +1,93 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GCP Skill Registry implementation.""" + +from __future__ import annotations + +import asyncio +import base64 +import os + +from google.adk.skills import _utils +from google.adk.skills import models +from google.adk.skills.skill_registry import SkillRegistry +import vertexai + + +class GCPSkillRegistry(SkillRegistry): + """GCP implementation of SkillRegistry using GCP Skill Registry API.""" + + def __init__( + self, *, project_id: str | None = None, location: str | None = None + ): + """Initializes the GCP Skill Registry. + + Args: + project_id: Optional GCP project ID. If omitted, loads from environment. + location: Optional GCP location. If omitted, loads from environment. + """ + self.project_id = project_id or os.environ.get("GOOGLE_CLOUD_PROJECT") + self.location = location or os.environ.get("GOOGLE_CLOUD_LOCATION") + self._client = vertexai.Client( + project=self.project_id, + location=self.location, + http_options={ + "api_version": "v1beta1", + }, + ).aio + + async def get_skill(self, *, name: str) -> models.Skill: + """Fetches a skill from the registry. + + Args: + name: The name of the skill. + + Returns: + A Skill object. + """ + full_name = ( + f"projects/{self.project_id}/locations/{self.location}/skills/{name}" + ) + skill_resource = await self._client.skills.get(name=full_name) + + zip_bytes_base64 = skill_resource.zipped_filesystem + if not zip_bytes_base64: + raise ValueError(f"Skill '{name}' does not contain zipped filesystem.") + + zip_bytes = base64.b64decode(zip_bytes_base64) + + return await asyncio.to_thread(_utils._load_skill_from_zip_bytes, zip_bytes) + + async def search_skills(self, *, query: str) -> list[models.Frontmatter]: + """Searches for skills in the registry. + + Args: + query: The search query. + + Returns: + A list of Frontmatter objects for discovery. + """ + response = await self._client.skills.retrieve(query=query) + + results = [] + if response.retrieved_skills: + for s in response.retrieved_skills: + results.append( + models.Frontmatter( + name=s.skill_name.split("/")[-1] if s.skill_name else "", + description=s.description or "", + ) + ) + return results diff --git a/src/google/adk/skills/_utils.py b/src/google/adk/skills/_utils.py index cab70a8d4b..3270025e01 100644 --- a/src/google/adk/skills/_utils.py +++ b/src/google/adk/skills/_utils.py @@ -16,9 +16,12 @@ from __future__ import annotations +import io import logging import pathlib +from typing import Dict from typing import Union +import zipfile from google.auth import credentials as auth from google.cloud import storage @@ -177,6 +180,96 @@ def _load_skill_from_dir(skill_dir: Union[str, pathlib.Path]) -> models.Skill: ) +def _load_skill_from_zip_bytes(zip_bytes: bytes) -> models.Skill: + """Load a complete skill directly from in-memory zip file bytes. + + Args: + zip_bytes: The raw bytes of the zip file containing the skill. + + Returns: + Skill object with all components loaded. + + Raises: + FileNotFoundError: If SKILL.md is not found in the archive. + ValueError: If SKILL.md is invalid or contains dangerous paths. + """ + with zipfile.ZipFile(io.BytesIO(zip_bytes)) as z: + # Security check for zip slip + for member in z.infolist(): + filename = member.filename + if ( + filename.startswith("/") + or filename.startswith("../") + or "/../" in filename + ): + raise ValueError(f"Dangerous zip entry ignored: {filename}") + + # Find SKILL.md or skill.md + skill_md_content = None + for name in ("SKILL.md", "skill.md"): + try: + skill_md_content = z.read(name).decode("utf-8") + break + except KeyError: + continue + + if skill_md_content is None: + raise FileNotFoundError("SKILL.md not found in zipped filesystem.") + + parsed, body = _parse_skill_md_content(skill_md_content) + skill_name = parsed.get("name") + if not skill_name: + raise ValueError("SKILL.md frontmatter must contain 'name'") + if ( + not isinstance(skill_name, str) + or pathlib.Path(skill_name).name != skill_name + ): + raise ValueError(f"Invalid skill name in SKILL.md: {skill_name}") + + frontmatter = models.Frontmatter.model_validate(parsed) + + # Helper to load files under a directory prefix inside the zip + def _load_zip_dir(prefix: str) -> dict[str, str]: + result = {} + if not prefix.endswith("/"): + prefix += "/" + for info in z.infolist(): + if info.is_dir(): + continue + if info.filename.startswith(prefix): + # Avoid cache files or similar + if "__pycache__" in info.filename: + continue + relative_path = info.filename[len(prefix) :] + if not relative_path: + continue + try: + result[relative_path] = z.read(info).decode("utf-8") + except UnicodeDecodeError: + continue + return result + + references = _load_zip_dir("references") + assets = _load_zip_dir("assets") + raw_scripts = _load_zip_dir("scripts") + scripts = { + name: models.Script(src=content) + for name, content in raw_scripts.items() + } + + resources = models.Resources( + references=references, + assets=assets, + scripts=scripts, + ) + + return models.Skill( + frontmatter=frontmatter, + instructions=body, + resources=resources, + ) + + def _validate_skill_dir( skill_dir: Union[str, pathlib.Path], ) -> list[str]: diff --git a/tests/unittests/integrations/skill_registry/test_gcp_skill_registry.py b/tests/unittests/integrations/skill_registry/test_gcp_skill_registry.py new file mode 100644 index 0000000000..bf410456e9 --- /dev/null +++ b/tests/unittests/integrations/skill_registry/test_gcp_skill_registry.py @@ -0,0 +1,184 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for GCP Skill Registry.""" + +import base64 +import os +from unittest import mock +import zipfile + +from google.adk.integrations.skill_registry.gcp_skill_registry import GCPSkillRegistry +import pytest + + +@pytest.fixture(autouse=True) +def mock_env(): + """Fixture to mock environment variables.""" + with mock.patch.dict( + os.environ, + { + "GOOGLE_CLOUD_PROJECT": "test-project", + "GOOGLE_CLOUD_LOCATION": "us-central1", + }, + ): + yield + + +@pytest.fixture +def mock_vertex_client(): + """Fixture to mock vertexai.Client.""" + with mock.patch("vertexai.Client") as mock_client_class: + mock_client = mock_client_class.return_value + yield mock_client + + +def _create_fake_zip_bytes(): + """Creates a fake zip file in memory and returns its bytes.""" + import io + + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "w") as z: + z.writestr( + "SKILL.md", "---\nname: my-skill\ndescription: test\n---\n# My Skill\n" + ) + return zip_buffer.getvalue() + + +@pytest.mark.asyncio +async def test_get_skill_success(mock_vertex_client): + """Verifies that get_skill successfully fetches and loads a skill in memory.""" + registry = GCPSkillRegistry() + + fake_zip = _create_fake_zip_bytes() + fake_zip_base64 = base64.b64encode(fake_zip).decode("utf-8") + + mock_skill_resource = mock.MagicMock() + mock_skill_resource.zipped_filesystem = fake_zip_base64 + + mock_vertex_client.aio.skills.get = mock.AsyncMock( + return_value=mock_skill_resource + ) + + skill = await registry.get_skill(name="my-skill") + + assert skill.frontmatter.name == "my-skill" + assert skill.frontmatter.description == "test" + assert skill.instructions == "# My Skill" + mock_vertex_client.aio.skills.get.assert_called_once_with( + name="projects/test-project/locations/us-central1/skills/my-skill" + ) + + +@pytest.mark.asyncio +async def test_search_skills_success(mock_vertex_client): + """Verifies that search_skills successfully returns frontmatter list.""" + registry = GCPSkillRegistry() + + mock_skill1 = mock.MagicMock() + mock_skill1.skill_name = ( + "projects/test-project/locations/us-central1/skills/skill1" + ) + mock_skill1.description = "Description 1" + + mock_skill2 = mock.MagicMock() + mock_skill2.skill_name = ( + "projects/test-project/locations/us-central1/skills/skill2" + ) + mock_skill2.description = "Description 2" + + mock_response = mock.MagicMock() + mock_response.retrieved_skills = [mock_skill1, mock_skill2] + + mock_vertex_client.aio.skills.retrieve = mock.AsyncMock( + return_value=mock_response + ) + + results = await registry.search_skills(query="query") + + assert len(results) == 2 + assert results[0].name == "skill1" + assert results[0].description == "Description 1" + assert results[1].name == "skill2" + assert results[1].description == "Description 2" + mock_vertex_client.aio.skills.retrieve.assert_called_once_with(query="query") + + +@pytest.mark.asyncio +async def test_get_skill_raises_on_missing_zip(mock_vertex_client): + """Verifies that get_skill raises error if zip filesystem is missing.""" + registry = GCPSkillRegistry() + + mock_skill_resource = mock.MagicMock() + mock_skill_resource.zipped_filesystem = None + + mock_vertex_client.aio.skills.get = mock.AsyncMock( + return_value=mock_skill_resource + ) + + with pytest.raises(ValueError, match="does not contain zipped filesystem"): + await registry.get_skill(name="my-skill") + + +@pytest.mark.asyncio +async def test_get_skill_raises_on_zip_slip(mock_vertex_client): + """Verifies that get_skill raises error if zip contains dangerous paths.""" + registry = GCPSkillRegistry() + + import io + + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "w") as z: + z.writestr("../evil.txt", "malicious content") + z.writestr( + "SKILL.md", "---\nname: my-skill\ndescription: test\n---\n# My Skill\n" + ) + fake_zip = zip_buffer.getvalue() + fake_zip_base64 = base64.b64encode(fake_zip).decode("utf-8") + + mock_skill_resource = mock.MagicMock() + mock_skill_resource.zipped_filesystem = fake_zip_base64 + + mock_vertex_client.aio.skills.get = mock.AsyncMock( + return_value=mock_skill_resource + ) + + with pytest.raises(ValueError, match="Dangerous zip entry ignored"): + await registry.get_skill(name="my-skill") + + +@pytest.mark.asyncio +async def test_get_skill_raises_on_invalid_skill_name(mock_vertex_client): + """Verifies that get_skill raises error if skill name is invalid.""" + registry = GCPSkillRegistry() + + import io + + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "w") as z: + z.writestr( + "SKILL.md", "---\nname: ../evil\ndescription: test\n---\n# My Skill\n" + ) + fake_zip = zip_buffer.getvalue() + fake_zip_base64 = base64.b64encode(fake_zip).decode("utf-8") + + mock_skill_resource = mock.MagicMock() + mock_skill_resource.zipped_filesystem = fake_zip_base64 + + mock_vertex_client.aio.skills.get = mock.AsyncMock( + return_value=mock_skill_resource + ) + + with pytest.raises(ValueError, match="Invalid skill name in SKILL.md"): + await registry.get_skill(name="my-skill") diff --git a/tests/unittests/skills/test__utils.py b/tests/unittests/skills/test__utils.py index ae5df00c1c..1975c0a494 100644 --- a/tests/unittests/skills/test__utils.py +++ b/tests/unittests/skills/test__utils.py @@ -14,12 +14,15 @@ """Unit tests for skill utilities.""" +import io from unittest import mock +import zipfile from google.adk.skills import list_skills_in_dir from google.adk.skills import list_skills_in_gcs_dir as _list_skills_in_gcs_dir from google.adk.skills import load_skill_from_dir as _load_skill_from_dir from google.adk.skills import load_skill_from_gcs_dir as _load_skill_from_gcs_dir +from google.adk.skills._utils import _load_skill_from_zip_bytes from google.adk.skills._utils import _read_skill_properties from google.adk.skills._utils import _validate_skill_dir import pytest @@ -340,3 +343,23 @@ def test_list_skills_in_dir_missing_base_path(tmp_path): skills = list_skills_in_dir(tmp_path / "nonexistent") assert skills == {} + + +def test__load_skill_from_zip_bytes(): + """Tests loading a skill directly from in-memory zip file bytes.""" + + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "w") as z: + z.writestr( + "SKILL.md", + "---\nname: my-skill\ndescription: A skill\n---\nBody instructions", + ) + z.writestr("references/ref1.md", "ref1 content") + z.writestr("scripts/script1.sh", "echo hello") + + skill = _load_skill_from_zip_bytes(zip_buffer.getvalue()) + assert skill.frontmatter.name == "my-skill" + assert skill.frontmatter.description == "A skill" + assert skill.instructions == "Body instructions" + assert skill.resources.get_reference("ref1.md") == "ref1 content" + assert skill.resources.get_script("script1.sh").src == "echo hello" From 27e71f3bcff5dd0cf5f49ced6a10a8da76772ba9 Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Thu, 14 May 2026 15:42:58 -0700 Subject: [PATCH 11/28] chore: Update Gemini Actions workflows for enhanced security and community alignment - Restrict invoke and review triggers purely to explicit user comments. - Enforce strict author association verification (OWNER, MEMBER, COLLABORATOR). - Enforce strict targeting assertion to ensure pull requests act on the main branch. - Synchronize prompt constraints and GitHub action tools with the community catalog. - Refine action API key options to uniformly target secrets.GOOGLE_API_KEY. Co-authored-by: Shangjie Chen PiperOrigin-RevId: 915654346 --- .github/workflows/gemini-dispatch.yml | 65 ++++++++++++++------------- .github/workflows/gemini-invoke.yml | 30 +++---------- .github/workflows/gemini-review.yml | 10 ++--- 3 files changed, 44 insertions(+), 61 deletions(-) diff --git a/.github/workflows/gemini-dispatch.yml b/.github/workflows/gemini-dispatch.yml index 1f8c243fda..9c2bf8ec9d 100644 --- a/.github/workflows/gemini-dispatch.yml +++ b/.github/workflows/gemini-dispatch.yml @@ -7,14 +7,6 @@ on: pull_request_review: types: - 'submitted' - pull_request: - types: - - 'opened' - - 'ready_for_review' - issues: - types: - - 'opened' - - 'reopened' issue_comment: types: - 'created' @@ -44,19 +36,11 @@ jobs: env | grep '^DEBUG_' dispatch: - # For PRs: only if not from a fork - # For issues: only on open/reopen - # For comments: only if user types @gemini-cli and is OWNER/MEMBER/COLLABORATOR + # Only trigger if user types @gemini-cli and author association is OWNER, MEMBER, or COLLABORATOR if: |- - ( - github.event_name == 'pull_request' && - github.event.pull_request.head.repo.fork == false && - github.event.pull_request.draft == false - ) || ( - github.event.sender.type == 'User' && - startsWith(github.event.comment.body || github.event.review.body || github.event.issue.body, '@gemini-cli') && - contains(fromJSON('["OWNER", "MEMBER", "COLLABORATOR"]'), github.event.comment.author_association || github.event.review.author_association || github.event.issue.author_association) - ) + github.event.sender.type == 'User' && + startsWith(github.event.comment.body || github.event.review.body, '@gemini-cli') && + contains(fromJSON('["OWNER", "MEMBER", "COLLABORATOR"]'), github.event.comment.author_association || github.event.review.author_association) runs-on: 'ubuntu-latest' permissions: contents: 'read' @@ -82,22 +66,43 @@ jobs: - name: 'Extract command' id: 'extract_command' - uses: 'actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea' # ratchet:actions/github-script@v7 + uses: 'actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd' # ratchet:actions/github-script@v8.0.0 env: - EVENT_TYPE: '${{ github.event_name }}.${{ github.event.action }}' - REQUEST: '${{ github.event.comment.body || github.event.review.body || github.event.issue.body }}' + REQUEST: '${{ github.event.comment.body || github.event.review.body }}' + IS_PR: '${{ !!(github.event.pull_request || github.event.issue.pull_request) }}' with: script: | - const eventType = process.env.EVENT_TYPE; const request = process.env.REQUEST; + const isPr = process.env.IS_PR === 'true'; core.setOutput('request', request); - if (eventType === 'pull_request.opened' || eventType === 'pull_request.ready_for_review') { - core.setOutput('command', 'review'); - } else if (request.startsWith("@gemini-cli /review")) { - core.setOutput('command', 'review'); - const additionalContext = request.replace(/^@gemini-cli \/review/, '').trim(); - core.setOutput('additional_context', additionalContext); + // Ensure request is on a PR targeting the main branch + let baseRef = ''; + if (context.eventName === 'pull_request_review' || context.eventName === 'pull_request_review_comment') { + baseRef = context.payload.pull_request.base.ref; + } else if (context.eventName === 'issue_comment' && context.payload.issue.pull_request) { + const pr = await github.rest.pulls.get({ + owner: context.repo.owner, + repo: context.repo.repo, + pull_number: context.payload.issue.number + }); + baseRef = pr.data.base.ref; + } + + if (isPr && baseRef !== 'main') { + console.log(`Skipping: PR targets '${baseRef}', but only 'main' is allowed.`); + core.setOutput('command', 'fallthrough'); + return; + } + + if (request.startsWith("@gemini-cli /review")) { + if (isPr) { + core.setOutput('command', 'review'); + const additionalContext = request.replace(/^@gemini-cli \/review/, '').trim(); + core.setOutput('additional_context', additionalContext); + } else { + core.setOutput('command', 'fallthrough'); + } } else if (request.startsWith("@gemini-cli")) { const additionalContext = request.replace(/^@gemini-cli/, '').trim(); core.setOutput('command', 'invoke'); diff --git a/.github/workflows/gemini-invoke.yml b/.github/workflows/gemini-invoke.yml index 81b669a327..5138d6f729 100644 --- a/.github/workflows/gemini-invoke.yml +++ b/.github/workflows/gemini-invoke.yml @@ -52,12 +52,14 @@ jobs: ISSUE_NUMBER: '${{ github.event.pull_request.number || github.event.issue.number }}' REPOSITORY: '${{ github.repository }}' ADDITIONAL_CONTEXT: '${{ inputs.additional_context }}' + # Required to allow the Gemini CLI to process files in the ephemeral GitHub Actions runner + GEMINI_CLI_TRUST_WORKSPACE: 'true' with: gcp_location: '${{ vars.GOOGLE_CLOUD_LOCATION }}' gcp_project_id: '${{ vars.GOOGLE_CLOUD_PROJECT }}' gcp_service_account: '${{ vars.SERVICE_ACCOUNT_EMAIL }}' gcp_workload_identity_provider: '${{ vars.GCP_WIF_PROVIDER }}' - gemini_api_key: '${{ secrets.GEMINI_API_KEY }}' + gemini_api_key: '${{ secrets.GOOGLE_API_KEY }}' gemini_cli_version: '${{ vars.GEMINI_CLI_VERSION }}' gemini_debug: '${{ fromJSON(vars.GEMINI_DEBUG || vars.ACTIONS_STEP_DEBUG || false) }}' gemini_model: '${{ vars.GEMINI_MODEL }}' @@ -91,32 +93,12 @@ jobs: "GITHUB_PERSONAL_ACCESS_TOKEN", "ghcr.io/github/github-mcp-server:v0.27.0" ], - "includeTools": [ - "add_issue_comment", - "issue_read", - "list_issues", - "search_issues", - "pull_request_read", - "list_pull_requests", - "search_pull_requests", - "get_commit", - "get_file_contents", - "list_commits", - "search_code" - ], "env": { "GITHUB_PERSONAL_ACCESS_TOKEN": "${{ steps.mint_identity_token.outputs.token || secrets.GITHUB_TOKEN || github.token }}" } } - }, - "tools": { - "core": [ - "run_shell_command(cat)", - "run_shell_command(echo)", - "run_shell_command(grep)", - "run_shell_command(head)", - "run_shell_command(tail)" - ] } } - prompt: '/gemini-invoke' + prompt: |- + /gemini-invoke + [IMPORTANT] Do not generate execution plans and do not ask for approval (such as suggesting `@gemini-cli /approve`). Perform the requested task or answer the question directly and immediately. diff --git a/.github/workflows/gemini-review.yml b/.github/workflows/gemini-review.yml index d29b030b5d..9c1b1bf442 100644 --- a/.github/workflows/gemini-review.yml +++ b/.github/workflows/gemini-review.yml @@ -51,14 +51,15 @@ jobs: PULL_REQUEST_NUMBER: '${{ github.event.pull_request.number || github.event.issue.number }}' REPOSITORY: '${{ github.repository }}' ADDITIONAL_CONTEXT: '${{ inputs.additional_context }}' - GEMINI_API_KEY: '${{ secrets.GEMINI_API_KEY }}' + GEMINI_API_KEY: '${{ secrets.GOOGLE_API_KEY }}' + # Required to allow the Gemini CLI to process files in the ephemeral GitHub Actions runner GEMINI_CLI_TRUST_WORKSPACE: 'true' with: gcp_location: '${{ vars.GOOGLE_CLOUD_LOCATION }}' gcp_project_id: '${{ vars.GOOGLE_CLOUD_PROJECT }}' gcp_service_account: '${{ vars.SERVICE_ACCOUNT_EMAIL }}' gcp_workload_identity_provider: '${{ vars.GCP_WIF_PROVIDER }}' - gemini_api_key: '${{ secrets.GEMINI_API_KEY }}' + gemini_api_key: '${{ secrets.GOOGLE_API_KEY }}' gemini_cli_version: '${{ vars.GEMINI_CLI_VERSION }}' gemini_debug: '${{ fromJSON(vars.GEMINI_DEBUG || vars.ACTIONS_STEP_DEBUG || false) }}' gemini_model: '${{ vars.GEMINI_MODEL }}' @@ -90,11 +91,6 @@ jobs: "GITHUB_PERSONAL_ACCESS_TOKEN", "ghcr.io/github/github-mcp-server:v0.27.0" ], - "includeTools": [ - "pull_request_read", - "add_comment_to_pending_review", - "pull_request_review_write" - ], "env": { "GITHUB_PERSONAL_ACCESS_TOKEN": "${{ steps.mint_identity_token.outputs.token || secrets.GITHUB_TOKEN || github.token }}" } From 115124cdf413859c7f634ce995113e4de6cf5ff7 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 15 May 2026 08:22:03 -0700 Subject: [PATCH 12/28] feat: add support for A2aAgentExecutor factory in to_a2a() function PiperOrigin-RevId: 916015512 --- src/google/adk/a2a/utils/agent_to_a2a.py | 31 +++++++++++-------- .../unittests/a2a/utils/test_agent_to_a2a.py | 8 ++--- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/src/google/adk/a2a/utils/agent_to_a2a.py b/src/google/adk/a2a/utils/agent_to_a2a.py index 2744477e3c..48454a632a 100644 --- a/src/google/adk/a2a/utils/agent_to_a2a.py +++ b/src/google/adk/a2a/utils/agent_to_a2a.py @@ -35,6 +35,7 @@ from ...runners import Runner from ...sessions.in_memory_session_service import InMemorySessionService from ..executor.a2a_agent_executor import A2aAgentExecutor +from ..executor.config import A2aAgentExecutorConfig from ..experimental import a2a_experimental from .agent_card_builder import AgentCardBuilder @@ -86,6 +87,7 @@ def to_a2a( task_store: TaskStore | None = None, runner: Runner | None = None, lifespan: Callable[[Starlette], AsyncIterator[None]] | None = None, + agent_executor_factory: Callable[[Runner], A2aAgentExecutor] | None = None, ) -> Starlette: """Convert an ADK agent to a A2A Starlette application. @@ -95,20 +97,21 @@ def to_a2a( port: The port for the A2A RPC URL (default: 8000) protocol: The protocol for the A2A RPC URL (default: "http") agent_card: Optional pre-built AgentCard object or path to agent card - JSON. If not provided, will be built automatically from the - agent. + JSON. If not provided, will be built automatically from the agent. push_config_store: Optional A2A push notification config store. If not - provided, an in-memory store will be created so push-notification - config RPC methods are supported. + provided, an in-memory store will be created so push-notification config + RPC methods are supported. task_store: Optional A2A task store for persisting task state. If not provided, an in-memory store will be created. runner: Optional pre-built Runner object. If not provided, a default - runner will be created using in-memory services. - lifespan: Optional async context manager for Starlette lifespan - events. Use this to run startup/shutdown logic (e.g. initializing - database connections or loading resources). The context manager - receives the Starlette app instance and can set state on - ``app.state``. + runner will be created using in-memory services. + lifespan: Optional async context manager for Starlette lifespan events. + Use this to run startup/shutdown logic (e.g. initializing database + connections or loading resources). The context manager receives the + Starlette app instance and can set state on ``app.state``. + agent_executor_factory: Optional factory function that creates an instance + of A2aAgentExecutor. If not provided, a default A2aAgentExecutor will be + created. Returns: A Starlette application that can be run with uvicorn @@ -148,7 +151,7 @@ async def lifespan(app): adk_logger = logging.getLogger("google_adk") adk_logger.setLevel(logging.INFO) - async def create_runner() -> Runner: + def create_runner() -> Runner: """Create a runner for the agent.""" return Runner( app_name=agent.name or "adk_agent", @@ -164,8 +167,10 @@ async def create_runner() -> Runner: if task_store is None: task_store = InMemoryTaskStore() - agent_executor = A2aAgentExecutor( - runner=runner or create_runner, + agent_executor = ( + agent_executor_factory(runner or create_runner()) + if agent_executor_factory is not None + else A2aAgentExecutor(runner=runner or create_runner) ) if push_config_store is None: diff --git a/tests/unittests/a2a/utils/test_agent_to_a2a.py b/tests/unittests/a2a/utils/test_agent_to_a2a.py index b5012a6535..752df0ded3 100644 --- a/tests/unittests/a2a/utils/test_agent_to_a2a.py +++ b/tests/unittests/a2a/utils/test_agent_to_a2a.py @@ -357,7 +357,7 @@ def test_to_a2a_creates_runner_with_correct_services( @patch("google.adk.a2a.utils.agent_to_a2a.AgentCardBuilder") @patch("google.adk.a2a.utils.agent_to_a2a.Starlette") @patch("google.adk.a2a.utils.agent_to_a2a.Runner") - async def test_create_runner_function_creates_runner_correctly( + def test_create_runner_function_creates_runner_correctly( self, mock_runner_class, mock_starlette_class, @@ -391,7 +391,7 @@ async def test_create_runner_function_creates_runner_correctly( runner_func = call_args[1]["runner"] # Call the runner function to verify it creates Runner correctly - runner_result = await runner_func() + runner_result = runner_func() # Verify Runner was created with correct parameters mock_runner_class.assert_called_once_with( @@ -420,7 +420,7 @@ async def test_create_runner_function_creates_runner_correctly( @patch("google.adk.a2a.utils.agent_to_a2a.AgentCardBuilder") @patch("google.adk.a2a.utils.agent_to_a2a.Starlette") @patch("google.adk.a2a.utils.agent_to_a2a.Runner") - async def test_create_runner_function_with_agent_without_name( + def test_create_runner_function_with_agent_without_name( self, mock_runner_class, mock_starlette_class, @@ -455,7 +455,7 @@ async def test_create_runner_function_with_agent_without_name( runner_func = call_args[1]["runner"] # Call the runner function to verify it creates Runner correctly - await runner_func() + runner_func() # Verify Runner was created with default app_name when agent has no name mock_runner_class.assert_called_once_with( From 85f397d20f8b32cdfd074463ff505a06c8535ddf Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 15 May 2026 09:16:07 -0700 Subject: [PATCH 13/28] fix: avoid pre-serializing dict values in Interactions API to prevent double-escaping PiperOrigin-RevId: 916037238 --- src/google/adk/models/interactions_utils.py | 10 ++--- .../models/test_interactions_utils.py | 41 +++++++++++++++---- 2 files changed, 39 insertions(+), 12 deletions(-) diff --git a/src/google/adk/models/interactions_utils.py b/src/google/adk/models/interactions_utils.py index be1c2a6ab2..89ffe6be71 100644 --- a/src/google/adk/models/interactions_utils.py +++ b/src/google/adk/models/interactions_utils.py @@ -111,12 +111,12 @@ def convert_part_to_interaction_content(part: types.Part) -> Optional[dict]: ).decode('utf-8') return result elif part.function_response is not None: - # Convert the function response to a string for the interactions API - # The interactions API expects result to be either a string or items list + # Pass the function response through to the interactions API. + # Dict and list values are passed directly — the Interactions API handles + # JSON serialization internally. Pre-serializing with json.dumps() would + # cause double-escaping. result = part.function_response.response - if isinstance(result, dict): - result = json.dumps(result) - elif not isinstance(result, str): + if not isinstance(result, (dict, str, list)): result = str(result) logger.debug( 'Converting function_response: name=%s, call_id=%s', diff --git a/tests/unittests/models/test_interactions_utils.py b/tests/unittests/models/test_interactions_utils.py index 3a6b964ac1..118a925ab6 100644 --- a/tests/unittests/models/test_interactions_utils.py +++ b/tests/unittests/models/test_interactions_utils.py @@ -280,11 +280,9 @@ def test_function_response_dict(self): assert result['type'] == 'function_result' assert result['call_id'] == 'call_123' assert result['name'] == 'get_weather' - # Dict should be JSON serialized - assert json.loads(result['result']) == { - 'temperature': 20, - 'condition': 'sunny', - } + # Dict should be passed through directly (not JSON-serialized). + assert result['result'] == {'temperature': 20, 'condition': 'sunny'} + assert isinstance(result['result'], dict) def test_function_response_simple(self): """Test converting a function response Part with simple response.""" @@ -299,8 +297,37 @@ def test_function_response_simple(self): assert result['type'] == 'function_result' assert result['call_id'] == 'call_123' assert result['name'] == 'check_weather' - # Dict should be JSON serialized - assert json.loads(result['result']) == {'message': 'Weather is sunny'} + # Dict should be passed through directly (not JSON-serialized). + assert result['result'] == {'message': 'Weather is sunny'} + + def test_function_response_dict_not_double_serialized(self): + """Regression test: avoid double-serializing bash tool outputs. + + Bash tool responses contain JSON structures (stdout/stderr). When these + dict responses were json.dumps()'d before being sent to the Interactions + API, the API's own serialization would escape the already-escaped content, + producing unreadable output like: + {"result":"\\\"{\\\\\\\"error\\\\\\\":\\\\\\\"...\\\\\\\"}\\\"" + """ + bash_response = { + 'stdout': '{"name": "test", "version": "1.0"}\n', + 'stderr': '', + } + part = types.Part( + function_response=types.FunctionResponse( + id='call_bash', + name='bash', + response=bash_response, + ) + ) + result = interactions_utils.convert_part_to_interaction_content(part) + # The result value must be the dict itself, NOT a JSON string. + assert isinstance(result['result'], dict) + assert result['result'] == bash_response + # Verify there's no double-escaping: if result were a JSON string, + # serializing it again would add backslashes before the internal quotes. + wire_json = json.dumps(result) + assert '\\\\' not in wire_json def test_inline_data_image(self): """Test converting an inline image Part.""" From f5b765d608fa22b492a9ccf5a8ab9b99e27fbe68 Mon Sep 17 00:00:00 2001 From: George Weale Date: Fri, 15 May 2026 10:36:16 -0700 Subject: [PATCH 14/28] chore(build): exclude nested README.md files from sdist packaging This ensures that adding README.md files to subdirectories (as discussed for new folders and integrations) won't result in them being included in the published package. Co-authored-by: George Weale PiperOrigin-RevId: 916075206 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d7bacd8a10..10975105ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -166,7 +166,7 @@ scripts.adk = "google.adk.cli:main" [tool.flit.sdist] include = [ 'src/**/*', 'README.md', 'pyproject.toml', 'LICENSE' ] -exclude = [ 'src/**/*.sh' ] +exclude = [ 'src/**/*.sh', 'src/**/README.md' ] [tool.flit.module] name = "google.adk" From eed9bd319ffc398fae14c2362c93f986ffe25f67 Mon Sep 17 00:00:00 2001 From: George Weale Date: Fri, 15 May 2026 11:00:24 -0700 Subject: [PATCH 15/28] fix(evaluation): handle none config in per_turn_user_simulator_quality When judge_model_config is None, LlmRequest raises a ValidationError because it requires a config. We now construct a default GenerateContentConfig if one isn't provided. Close #5677 Co-authored-by: George Weale PiperOrigin-RevId: 916087055 --- .../per_turn_user_simulator_quality_v1.py | 6 ++- ...est_per_turn_user_simulation_quality_v1.py | 44 ++++++++++++++++++- 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/src/google/adk/evaluation/simulation/per_turn_user_simulator_quality_v1.py b/src/google/adk/evaluation/simulation/per_turn_user_simulator_quality_v1.py index a95eb87d88..239fa31a71 100644 --- a/src/google/adk/evaluation/simulation/per_turn_user_simulator_quality_v1.py +++ b/src/google/adk/evaluation/simulation/per_turn_user_simulator_quality_v1.py @@ -325,6 +325,10 @@ async def _evaluate_intermediate_turn( previous_invocations=invocation_history, ) + config = ( + self._llm_options.judge_model_config + or genai_types.GenerateContentConfig() + ) llm_request = LlmRequest( model=self._llm_options.judge_model, contents=[ @@ -333,7 +337,7 @@ async def _evaluate_intermediate_turn( role="user", ) ], - config=self._llm_options.judge_model_config, + config=config, ) add_default_retry_options_if_not_present(llm_request) num_samples = self._llm_options.num_samples diff --git a/tests/unittests/evaluation/simulation/test_per_turn_user_simulation_quality_v1.py b/tests/unittests/evaluation/simulation/test_per_turn_user_simulation_quality_v1.py index 6798143df3..be0a0394c5 100644 --- a/tests/unittests/evaluation/simulation/test_per_turn_user_simulation_quality_v1.py +++ b/tests/unittests/evaluation/simulation/test_per_turn_user_simulation_quality_v1.py @@ -634,10 +634,10 @@ def test_aggregate_conversation_percentage_below_threshold_produces_failure(): async def test_evaluate_invocations_all_pass(): evaluator = _create_test_evaluator() - async def sample_llm_valid(*args, **kwargs): + async def sample_llm_valid(*args, **kwargs): # pylint: disable=unused-argument return AutoRaterScore(score=1.0) - evaluator._sample_llm = sample_llm_valid + evaluator._sample_llm = sample_llm_valid # pylint: disable=protected-access starting_prompt = "first user prompt." conversation_scenario = _create_test_conversation_scenario( starting_prompt=starting_prompt @@ -656,3 +656,43 @@ async def sample_llm_valid(*args, **kwargs): assert len(result.per_invocation_results) == 2 assert result.per_invocation_results[0].score == 1.0 assert result.per_invocation_results[1].score == 1.0 + + +@pytest.mark.asyncio +async def test_evaluate_invocations_none_judge_model_config(): + """Tests evaluation when judge_model_config is None.""" + evaluator = PerTurnUserSimulatorQualityV1( + EvalMetric( + metric_name="test_per_turn_user_simulator_quality_v1", + threshold=1.0, + criterion=LlmBackedUserSimulatorCriterion( + threshold=1.0, + stop_signal="test stop signal", + judge_model_options=JudgeModelOptions( + judge_model="gemini-2.5-flash", + judge_model_config=None, + num_samples=1, + ), + ), + ), + ) + + async def sample_llm_valid(*args, **kwargs): # pylint: disable=unused-argument + return AutoRaterScore(score=1.0) + + evaluator._sample_llm = sample_llm_valid # pylint: disable=protected-access + starting_prompt = "first user prompt." + conversation_scenario = _create_test_conversation_scenario( + starting_prompt=starting_prompt + ) + invocations = _create_test_invocations( + [starting_prompt, "model 1.", "user 2.", "model 2."] + ) + result = await evaluator.evaluate_invocations( + actual_invocations=invocations, + expected_invocations=None, + conversation_scenario=conversation_scenario, + ) + + assert result.overall_score == 1.0 + assert result.overall_eval_status == EvalStatus.PASSED From bb2efb6bd234e3235c47b3245676581f6022b458 Mon Sep 17 00:00:00 2001 From: George Weale Date: Fri, 15 May 2026 11:43:14 -0700 Subject: [PATCH 16/28] fix: Prevent compaction of events involved in Human-in-the-Loop interactions This change introduces logic to identify events containing requests for tool confirmation or auth credentials. The compaction process will now stop before any such "Human-in-the-Loop" (HITL) events, ensuring that the full context of the interaction is preserved and not summarized away. This applies to both sliding window and token threshold compaction strategies. Co-authored-by: George Weale PiperOrigin-RevId: 916108771 --- src/google/adk/apps/compaction.py | 45 +++ tests/unittests/apps/test_compaction.py | 455 ++++++++++++++++++++++++ 2 files changed, 500 insertions(+) diff --git a/src/google/adk/apps/compaction.py b/src/google/adk/apps/compaction.py index b3b49bbd0b..1e33003a95 100644 --- a/src/google/adk/apps/compaction.py +++ b/src/google/adk/apps/compaction.py @@ -270,6 +270,9 @@ def _events_to_compact_for_token_threshold( events_to_compact = _truncate_events_before_pending_function_call( events_to_compact, pending_ids ) + events_to_compact = _truncate_events_before_hitl_signal( + events_to_compact, _resolved_hitl_call_ids(events) + ) if not events_to_compact: return [] @@ -344,6 +347,45 @@ def _truncate_events_before_pending_function_call( return events +def _resolved_hitl_call_ids(events: list[Event]) -> set[str]: + """Returns HITL call ids resolved by a later function_response in `events`.""" + hitl_position: dict[str, int] = {} + resolved: set[str] = set() + for index, event in enumerate(events): + if event.actions: + for call_id in event.actions.requested_tool_confirmations: + hitl_position.setdefault(call_id, index) + for call_id in event.actions.requested_auth_configs: + hitl_position.setdefault(call_id, index) + for resp_id in _event_function_response_ids(event): + hitl_pos = hitl_position.get(resp_id) + if hitl_pos is not None and index > hitl_pos: + resolved.add(resp_id) + return resolved + + +def _is_pending_hitl(event: Event, resolved_call_ids: set[str]) -> bool: + """Returns True if the event has an HITL request not in `resolved_call_ids`.""" + if not event.actions: + return False + requested = set(event.actions.requested_tool_confirmations) | set( + event.actions.requested_auth_configs + ) + if not requested: + return False + return bool(requested - resolved_call_ids) + + +def _truncate_events_before_hitl_signal( + events: list[Event], resolved_call_ids: set[str] +) -> list[Event]: + """Returns the leading contiguous events before any pending HITL request.""" + for index, event in enumerate(events): + if _is_pending_hitl(event, resolved_call_ids): + return events[:index] + return events + + def _safe_token_compaction_split_index( *, candidate_events: list[Event], @@ -631,6 +673,9 @@ async def _run_compaction_for_sliding_window( events_to_compact = _truncate_events_before_pending_function_call( events_to_compact, pending_ids ) + events_to_compact = _truncate_events_before_hitl_signal( + events_to_compact, _resolved_hitl_call_ids(events) + ) if not events_to_compact: return None diff --git a/tests/unittests/apps/test_compaction.py b/tests/unittests/apps/test_compaction.py index ba6d99ef38..dde8a51b4b 100644 --- a/tests/unittests/apps/test_compaction.py +++ b/tests/unittests/apps/test_compaction.py @@ -23,12 +23,15 @@ from google.adk.apps.compaction import _run_compaction_for_sliding_window import google.adk.apps.compaction as compaction_module from google.adk.apps.llm_event_summarizer import LlmEventSummarizer +from google.adk.auth.auth_schemes import CustomAuthScheme +from google.adk.auth.auth_tool import AuthConfig from google.adk.events.event import Event from google.adk.events.event_actions import EventActions from google.adk.events.event_actions import EventCompaction from google.adk.flows.llm_flows import contents from google.adk.sessions.base_session_service import BaseSessionService from google.adk.sessions.session import Session +from google.adk.tools.tool_confirmation import ToolConfirmation from google.genai import types from google.genai.types import Content from google.genai.types import Part @@ -1319,3 +1322,455 @@ async def test_completed_function_call_pair_is_still_compacted(self): self.assertIn('inv1', compacted_inv_ids) self.assertEqual(compacted_inv_ids.count('inv2'), 2) self.assertIn('inv3', compacted_inv_ids) + + def _create_hitl_confirmation_event( + self, + timestamp: float, + invocation_id: str, + function_call_id: str, + ) -> Event: + """Creates a function response event with a tool confirmation request.""" + return Event( + timestamp=timestamp, + invocation_id=invocation_id, + author='agent', + content=Content( + role='user', + parts=[ + Part( + function_response=types.FunctionResponse( + id=function_call_id, + name='tool', + response={ + 'error': 'This tool call requires confirmation.' + }, + ) + ) + ], + ), + actions=EventActions( + requested_tool_confirmations={ + function_call_id: ToolConfirmation( + hint='Please confirm this action.' + ) + }, + ), + ) + + def _create_hitl_auth_event( + self, + timestamp: float, + invocation_id: str, + function_call_id: str, + ) -> Event: + """Creates a function response event with an auth credential request.""" + return Event( + timestamp=timestamp, + invocation_id=invocation_id, + author='agent', + content=Content( + role='user', + parts=[ + Part( + function_response=types.FunctionResponse( + id=function_call_id, + name='tool', + response={'error': 'Auth required.'}, + ) + ) + ], + ), + actions=EventActions( + requested_auth_configs={ + function_call_id: AuthConfig( + auth_scheme=CustomAuthScheme(type='custom'), + ) + }, + ), + ) + + async def test_sliding_window_excludes_hitl_confirmation_events(self): + """Sliding-window compaction stops before tool confirmation events.""" + app = App( + name='test', + root_agent=Mock(spec=BaseAgent), + events_compaction_config=EventsCompactionConfig( + summarizer=self.mock_compactor, + compaction_interval=2, + overlap_size=0, + ), + ) + # inv1: text, inv2: call + HITL confirmation response, inv3: text + # The HITL event (confirmation response) blocks compaction at that point. + # The preceding function call event is not HITL itself and gets compacted. + events = [ + self._create_event(1.0, 'inv1', 'e1'), + self._create_function_call_event(2.0, 'inv2', 'call-1'), + self._create_hitl_confirmation_event(3.0, 'inv2', 'call-1'), + self._create_event(4.0, 'inv3', 'e3'), + ] + session = Session(app_name='test', user_id='u1', id='s1', events=events) + + mock_compacted_event = self._create_compacted_event( + 1.0, 2.0, 'Summary before hitl' + ) + self.mock_compactor.maybe_summarize_events.return_value = ( + mock_compacted_event + ) + + await _run_compaction_for_sliding_window( + app, session, self.mock_session_service + ) + + compacted_events_arg = self.mock_compactor.maybe_summarize_events.call_args[ + 1 + ]['events'] + compacted_inv_ids = [e.invocation_id for e in compacted_events_arg] + # inv1 text + inv2 function call are compacted; HITL response is protected. + self.assertEqual(compacted_inv_ids, ['inv1', 'inv2']) + + async def test_sliding_window_excludes_hitl_auth_events(self): + """Sliding-window compaction stops before auth credential events.""" + app = App( + name='test', + root_agent=Mock(spec=BaseAgent), + events_compaction_config=EventsCompactionConfig( + summarizer=self.mock_compactor, + compaction_interval=2, + overlap_size=0, + ), + ) + events = [ + self._create_event(1.0, 'inv1', 'e1'), + self._create_function_call_event(2.0, 'inv2', 'call-1'), + self._create_hitl_auth_event(3.0, 'inv2', 'call-1'), + self._create_event(4.0, 'inv3', 'e3'), + ] + session = Session(app_name='test', user_id='u1', id='s1', events=events) + + mock_compacted_event = self._create_compacted_event( + 1.0, 2.0, 'Summary before auth' + ) + self.mock_compactor.maybe_summarize_events.return_value = ( + mock_compacted_event + ) + + await _run_compaction_for_sliding_window( + app, session, self.mock_session_service + ) + + compacted_events_arg = self.mock_compactor.maybe_summarize_events.call_args[ + 1 + ]['events'] + compacted_inv_ids = [e.invocation_id for e in compacted_events_arg] + self.assertEqual(compacted_inv_ids, ['inv1', 'inv2']) + + async def test_token_threshold_excludes_hitl_confirmation_events(self): + """Token-threshold compaction stops before tool confirmation events.""" + app = App( + name='test', + root_agent=Mock(spec=BaseAgent), + events_compaction_config=EventsCompactionConfig( + summarizer=self.mock_compactor, + compaction_interval=999, + overlap_size=0, + token_threshold=50, + event_retention_size=0, + ), + ) + events = [ + self._create_event(1.0, 'inv1', 'e1'), + self._create_function_call_event(2.0, 'inv2', 'call-1'), + self._create_hitl_confirmation_event(3.0, 'inv2', 'call-1'), + self._create_event(4.0, 'inv3', 'e3', prompt_token_count=100), + ] + session = Session(app_name='test', user_id='u1', id='s1', events=events) + + mock_compacted_event = self._create_compacted_event( + 1.0, 2.0, 'Summary inv1-inv2' + ) + self.mock_compactor.maybe_summarize_events.return_value = ( + mock_compacted_event + ) + + await _run_compaction_for_sliding_window( + app, session, self.mock_session_service + ) + + compacted_events_arg = self.mock_compactor.maybe_summarize_events.call_args[ + 1 + ]['events'] + compacted_inv_ids = [e.invocation_id for e in compacted_events_arg] + self.assertEqual(compacted_inv_ids, ['inv1', 'inv2']) + + async def test_token_threshold_excludes_hitl_auth_events(self): + """Token-threshold compaction stops before auth credential events.""" + app = App( + name='test', + root_agent=Mock(spec=BaseAgent), + events_compaction_config=EventsCompactionConfig( + summarizer=self.mock_compactor, + compaction_interval=999, + overlap_size=0, + token_threshold=50, + event_retention_size=0, + ), + ) + events = [ + self._create_event(1.0, 'inv1', 'e1'), + self._create_function_call_event(2.0, 'inv2', 'call-1'), + self._create_hitl_auth_event(3.0, 'inv2', 'call-1'), + self._create_event(4.0, 'inv3', 'e3', prompt_token_count=100), + ] + session = Session(app_name='test', user_id='u1', id='s1', events=events) + + mock_compacted_event = self._create_compacted_event( + 1.0, 2.0, 'Summary inv1-inv2' + ) + self.mock_compactor.maybe_summarize_events.return_value = ( + mock_compacted_event + ) + + await _run_compaction_for_sliding_window( + app, session, self.mock_session_service + ) + + compacted_events_arg = self.mock_compactor.maybe_summarize_events.call_args[ + 1 + ]['events'] + compacted_inv_ids = [e.invocation_id for e in compacted_events_arg] + self.assertEqual(compacted_inv_ids, ['inv1', 'inv2']) + + async def test_hitl_event_at_start_blocks_all_compaction(self): + """If the first candidate event has HITL, nothing is compacted.""" + app = App( + name='test', + root_agent=Mock(spec=BaseAgent), + events_compaction_config=EventsCompactionConfig( + summarizer=self.mock_compactor, + compaction_interval=2, + overlap_size=0, + ), + ) + # The very first event is an HITL confirmation (no preceding function call). + events = [ + self._create_hitl_confirmation_event(1.0, 'inv1', 'call-1'), + self._create_event(2.0, 'inv2', 'e2'), + self._create_event(3.0, 'inv3', 'e3'), + ] + session = Session(app_name='test', user_id='u1', id='s1', events=events) + + await _run_compaction_for_sliding_window( + app, session, self.mock_session_service + ) + + self.mock_compactor.maybe_summarize_events.assert_not_called() + self.mock_session_service.append_event.assert_not_called() + + async def test_events_before_hitl_are_still_compacted(self): + """Events before the HITL event are compacted normally.""" + app = App( + name='test', + root_agent=Mock(spec=BaseAgent), + events_compaction_config=EventsCompactionConfig( + summarizer=self.mock_compactor, + compaction_interval=2, + overlap_size=0, + ), + ) + # inv1, inv2: text events, inv3: call + HITL confirmation, inv4: text + # The HITL event at index 3 blocks compaction; events before it are safe. + events = [ + self._create_event(1.0, 'inv1', 'e1'), + self._create_event(2.0, 'inv2', 'e2'), + self._create_function_call_event(3.0, 'inv3', 'call-1'), + self._create_hitl_confirmation_event(4.0, 'inv3', 'call-1'), + self._create_event(5.0, 'inv4', 'e4'), + ] + session = Session(app_name='test', user_id='u1', id='s1', events=events) + + mock_compacted_event = self._create_compacted_event( + 1.0, 3.0, 'Summary inv1-inv3' + ) + self.mock_compactor.maybe_summarize_events.return_value = ( + mock_compacted_event + ) + + await _run_compaction_for_sliding_window( + app, session, self.mock_session_service + ) + + compacted_events_arg = self.mock_compactor.maybe_summarize_events.call_args[ + 1 + ]['events'] + compacted_inv_ids = [e.invocation_id for e in compacted_events_arg] + # inv1, inv2 (text) + inv3 function call compact; HITL response is not. + self.assertEqual(compacted_inv_ids, ['inv1', 'inv2', 'inv3']) + + async def test_resolved_hitl_confirmation_is_compactable(self): + """A HITL confirmation followed by a resolved tool response is compactable.""" + app = App( + name='test', + root_agent=Mock(spec=BaseAgent), + events_compaction_config=EventsCompactionConfig( + summarizer=self.mock_compactor, + compaction_interval=2, + overlap_size=0, + ), + ) + # inv1: text, inv2: call + HITL request + resolved response (same call-1 + # id), inv3: text. The resolved HITL is safe to compact. + events = [ + self._create_event(1.0, 'inv1', 'e1'), + self._create_function_call_event(2.0, 'inv2', 'call-1'), + self._create_hitl_confirmation_event(3.0, 'inv2', 'call-1'), + self._create_function_response_event(4.0, 'inv2', 'call-1'), + self._create_event(5.0, 'inv3', 'e3'), + ] + session = Session(app_name='test', user_id='u1', id='s1', events=events) + + mock_compacted_event = self._create_compacted_event( + 1.0, 5.0, 'Summary including resolved hitl' + ) + self.mock_compactor.maybe_summarize_events.return_value = ( + mock_compacted_event + ) + + await _run_compaction_for_sliding_window( + app, session, self.mock_session_service + ) + + compacted_events_arg = self.mock_compactor.maybe_summarize_events.call_args[ + 1 + ]['events'] + compacted_inv_ids = [e.invocation_id for e in compacted_events_arg] + # Resolved HITL doesn't block; all events through inv3 compact together. + self.assertEqual( + compacted_inv_ids, ['inv1', 'inv2', 'inv2', 'inv2', 'inv3'] + ) + + async def test_resolved_hitl_auth_is_compactable(self): + """A HITL auth request followed by a resolved tool response is compactable.""" + app = App( + name='test', + root_agent=Mock(spec=BaseAgent), + events_compaction_config=EventsCompactionConfig( + summarizer=self.mock_compactor, + compaction_interval=2, + overlap_size=0, + ), + ) + events = [ + self._create_event(1.0, 'inv1', 'e1'), + self._create_function_call_event(2.0, 'inv2', 'call-1'), + self._create_hitl_auth_event(3.0, 'inv2', 'call-1'), + self._create_function_response_event(4.0, 'inv2', 'call-1'), + self._create_event(5.0, 'inv3', 'e3'), + ] + session = Session(app_name='test', user_id='u1', id='s1', events=events) + + mock_compacted_event = self._create_compacted_event( + 1.0, 5.0, 'Summary including resolved auth' + ) + self.mock_compactor.maybe_summarize_events.return_value = ( + mock_compacted_event + ) + + await _run_compaction_for_sliding_window( + app, session, self.mock_session_service + ) + + compacted_events_arg = self.mock_compactor.maybe_summarize_events.call_args[ + 1 + ]['events'] + compacted_inv_ids = [e.invocation_id for e in compacted_events_arg] + self.assertEqual( + compacted_inv_ids, ['inv1', 'inv2', 'inv2', 'inv2', 'inv3'] + ) + + async def test_sliding_window_resolved_hitl_outside_window_is_compactable( + self, + ): + """A HITL whose resolver lives past the truncation point is compactable.""" + app = App( + name='test', + root_agent=Mock(spec=BaseAgent), + events_compaction_config=EventsCompactionConfig( + summarizer=self.mock_compactor, + compaction_interval=2, + overlap_size=0, + ), + ) + # inv1 text, inv2 call_a, inv3 HITL_a, inv4 call_b (unanswered), + # inv5 resolver_a. _truncate_events_before_pending_function_call prunes + # at inv4 because call_b has no response in the session, leaving + # resolver_a outside events_to_compact. The HITL still has to be + # recognized as resolved via the full-session lookup. + events = [ + self._create_event(1.0, 'inv1', 'e1'), + self._create_function_call_event(2.0, 'inv2', 'call-a'), + self._create_hitl_confirmation_event(3.0, 'inv3', 'call-a'), + self._create_function_call_event(4.0, 'inv4', 'call-b'), + self._create_function_response_event(5.0, 'inv5', 'call-a'), + ] + session = Session(app_name='test', user_id='u1', id='s1', events=events) + + mock_compacted_event = self._create_compacted_event( + 1.0, 3.0, 'Summary including resolved hitl' + ) + self.mock_compactor.maybe_summarize_events.return_value = ( + mock_compacted_event + ) + + await _run_compaction_for_sliding_window( + app, session, self.mock_session_service + ) + + compacted_events_arg = self.mock_compactor.maybe_summarize_events.call_args[ + 1 + ]['events'] + compacted_inv_ids = [e.invocation_id for e in compacted_events_arg] + self.assertEqual(compacted_inv_ids, ['inv1', 'inv2', 'inv3']) + + async def test_token_threshold_resolved_hitl_outside_window_is_compactable( + self, + ): + """Token-threshold: HITL with resolver past the truncation point compacts.""" + app = App( + name='test', + root_agent=Mock(spec=BaseAgent), + events_compaction_config=EventsCompactionConfig( + summarizer=self.mock_compactor, + compaction_interval=999, + overlap_size=0, + token_threshold=50, + event_retention_size=0, + ), + ) + events = [ + self._create_event(1.0, 'inv1', 'e1'), + self._create_function_call_event(2.0, 'inv2', 'call-a'), + self._create_hitl_confirmation_event(3.0, 'inv3', 'call-a'), + self._create_function_call_event(4.0, 'inv4', 'call-b'), + self._create_function_response_event( + 5.0, 'inv5', 'call-a', prompt_token_count=100 + ), + ] + session = Session(app_name='test', user_id='u1', id='s1', events=events) + + mock_compacted_event = self._create_compacted_event( + 1.0, 3.0, 'Summary including resolved hitl' + ) + self.mock_compactor.maybe_summarize_events.return_value = ( + mock_compacted_event + ) + + await _run_compaction_for_sliding_window( + app, session, self.mock_session_service + ) + + compacted_events_arg = self.mock_compactor.maybe_summarize_events.call_args[ + 1 + ]['events'] + compacted_inv_ids = [e.invocation_id for e in compacted_events_arg] + self.assertEqual(compacted_inv_ids, ['inv1', 'inv2', 'inv3']) From 9a1e75f24256cfe54766c69691247df90dc5558f Mon Sep 17 00:00:00 2001 From: George Weale Date: Fri, 15 May 2026 11:44:08 -0700 Subject: [PATCH 17/28] fix(models): preserve string content in Anthropic tool_result blocks `part_to_message_block` iterated `content` char-by-char when a tool returned it as a plain string (e.g. `LoadSkillResourceTool`'s `{"content": }`), producing `"H\ne\nl\nl\no"` instead of `"Hello"`. Guard the list branch with `isinstance(..., list)` and add a sibling branch that passes a scalar string through directly, matching Anthropic's `content: str | list[ContentBlockParam]` shape. Close #5358 Co-authored-by: George Weale PiperOrigin-RevId: 916109239 --- src/google/adk/models/anthropic_llm.py | 15 +++-- tests/unittests/models/test_anthropic_llm.py | 59 ++++++++++++++++++++ 2 files changed, 70 insertions(+), 4 deletions(-) diff --git a/src/google/adk/models/anthropic_llm.py b/src/google/adk/models/anthropic_llm.py index a175ce5f1a..ec1f38a895 100644 --- a/src/google/adk/models/anthropic_llm.py +++ b/src/google/adk/models/anthropic_llm.py @@ -219,20 +219,27 @@ def _part_to_message_block( content = "" response_data = part.function_response.response - # Handle response with content array - if "content" in response_data and response_data["content"]: + if ( + "content" in response_data + and isinstance(response_data["content"], list) + and response_data["content"] + ): content_items = [] for item in response_data["content"]: if isinstance(item, dict): - # Handle text content blocks if item.get("type") == "text" and "text" in item: content_items.append(item["text"]) else: - # Handle other structured content content_items.append(str(item)) else: content_items.append(str(item)) content = "\n".join(content_items) if content_items else "" + elif ( + "content" in response_data + and isinstance(response_data["content"], str) + and response_data["content"] + ): + content = response_data["content"] # We serialize to str here # SDK ref: anthropic.types.tool_result_block_param # https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/types/tool_result_block_param.py diff --git a/tests/unittests/models/test_anthropic_llm.py b/tests/unittests/models/test_anthropic_llm.py index bf173c5cc3..d6daf39dc6 100644 --- a/tests/unittests/models/test_anthropic_llm.py +++ b/tests/unittests/models/test_anthropic_llm.py @@ -1096,6 +1096,65 @@ def test_part_to_message_block_empty_response_stays_empty(): assert result["content"] == "" +def test_part_to_message_block_string_content_passes_through(): + """A scalar string `content` value must not be iterated char-by-char.""" + response_part = types.Part.from_function_response( + name="some_tool", + response={"content": "Hello"}, + ) + response_part.function_response.id = "test_id_str_content" + + result = part_to_message_block(response_part) + + assert result["content"] == "Hello" + + +def test_part_to_message_block_load_skill_resource_response(): + """LoadSkillResourceTool returns {content: } as a string.""" + file_text = "Line one\nLine two\nLine three" + response_part = types.Part.from_function_response( + name="load_skill_resource", + response={ + "skill_name": "my-skill", + "file_path": "references/doc.md", + "content": file_text, + }, + ) + response_part.function_response.id = "test_id_load_skill" + + result = part_to_message_block(response_part) + + assert result["content"] == file_text + + +def test_part_to_message_block_empty_string_content_falls_through(): + """`{"content": ""}` falls through to the JSON-dump fallback, not a crash.""" + response_part = types.Part.from_function_response( + name="some_tool", + response={"content": ""}, + ) + response_part.function_response.id = "test_id_empty_content_only" + + result = part_to_message_block(response_part) + + assert json.loads(result["content"]) == {"content": ""} + + +def test_part_to_message_block_empty_content_with_metadata_keeps_metadata(): + """`content: ""` is falsy; sibling keys still reach the model via JSON dump.""" + response_part = types.Part.from_function_response( + name="some_tool", + response={"content": "", "extra": "keep me"}, + ) + response_part.function_response.id = "test_id_empty_content_with_meta" + + result = part_to_message_block(response_part) + + parsed = json.loads(result["content"]) + assert parsed["content"] == "" + assert parsed["extra"] == "keep me" + + # --- Tests for Bug #1: Streaming support --- From 0524797ac75ddd13b1c01cac91e507ba2c42cef0 Mon Sep 17 00:00:00 2001 From: Amaad Martin Date: Fri, 15 May 2026 11:50:59 -0700 Subject: [PATCH 18/28] fix(agents): fix visibility of output_key state delta in callbacks Co-authored-by: Amaad Martin PiperOrigin-RevId: 916112779 --- src/google/adk/runners.py | 44 +++-- .../agents/test_output_key_visibility.py | 180 ++++++++++++++++++ 2 files changed, 206 insertions(+), 18 deletions(-) create mode 100644 tests/unittests/agents/test_output_key_visibility.py diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index f52f07abb3..f352e6eb5f 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -608,7 +608,7 @@ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: async with Aclosing( self._exec_with_plugin( invocation_context=invocation_context, - session=session, + session=invocation_context.session, execute_fn=execute, is_live_call=False, ) @@ -622,7 +622,7 @@ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: logger.debug('Running event compactor.') await _run_compaction_for_sliding_window( self.app, - session, + invocation_context.session, self.session_service, skip_token_compaction=invocation_context.token_compaction_checked, ) @@ -841,7 +841,7 @@ async def _exec_with_plugin( Args: invocation_context: The invocation context - session: The current session + session: The current session (ignored, kept for backward compatibility) execute_fn: A callable that returns an AsyncGenerator of Events is_live_call: Whether this is a live call @@ -866,7 +866,7 @@ async def _exec_with_plugin( ) if self._should_append_event(early_exit_event, is_live_call): await self.session_service.append_event( - session=session, + session=invocation_context.session, event=early_exit_event, ) yield early_exit_event @@ -931,13 +931,13 @@ async def _exec_with_plugin( ) if self._should_append_event(event, is_live_call): await self.session_service.append_event( - session=session, event=output_event + session=invocation_context.session, event=output_event ) for buffered_event in buffered_events: logger.debug('Appending buffered event: %s', buffered_event) await self.session_service.append_event( - session=session, event=buffered_event + session=invocation_context.session, event=buffered_event ) yield buffered_event # yield buffered events to caller buffered_events = [] @@ -947,12 +947,12 @@ async def _exec_with_plugin( if self._should_append_event(event, is_live_call): logger.debug('Appending non-buffered event: %s', event) await self.session_service.append_event( - session=session, event=output_event + session=invocation_context.session, event=output_event ) else: if event.partial is not True: await self.session_service.append_event( - session=session, event=output_event + session=invocation_context.session, event=output_event ) yield output_event @@ -1004,8 +1004,8 @@ async def _append_new_message_to_session( file_name = f'artifact_{invocation_context.invocation_id}_{i}' await self.artifact_service.save_artifact( app_name=self.app_name, - user_id=session.user_id, - session_id=session.id, + user_id=invocation_context.session.user_id, + session_id=invocation_context.session.id, filename=file_name, artifact=part, ) @@ -1032,7 +1032,9 @@ async def _append_new_message_to_session( if function_call := invocation_context._find_matching_function_call(event): event.branch = function_call.branch - await self.session_service.append_event(session=session, event=event) + await self.session_service.append_event( + session=invocation_context.session, event=event + ) async def run_live( self, @@ -1127,7 +1129,9 @@ async def run_live( ) root_agent = self.agent - invocation_context.agent = self._find_agent_to_run(session, root_agent) + invocation_context.agent = self._find_agent_to_run( + invocation_context.session, root_agent + ) async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: async with Aclosing(ctx.agent.run_live(ctx)) as agen: @@ -1137,7 +1141,7 @@ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: async with Aclosing( self._exec_with_plugin( invocation_context=invocation_context, - session=session, + session=invocation_context.session, execute_fn=execute, is_live_call=True, ) @@ -1355,14 +1359,16 @@ async def _setup_context_for_new_invocation( # Step 2: Handle new message, by running callbacks and appending to # session. await self._handle_new_message( - session=session, + session=invocation_context.session, new_message=new_message, invocation_context=invocation_context, run_config=run_config, state_delta=state_delta, ) # Step 3: Set agent to run for the invocation. - invocation_context.agent = self._find_agent_to_run(session, self.agent) + invocation_context.agent = self._find_agent_to_run( + invocation_context.session, self.agent + ) return invocation_context async def _setup_context_for_resumed_invocation( @@ -1411,7 +1417,7 @@ async def _setup_context_for_resumed_invocation( # Step 3: Maybe handle new message. if new_message: await self._handle_new_message( - session=session, + session=invocation_context.session, new_message=user_message, invocation_context=invocation_context, run_config=run_config, @@ -1425,7 +1431,9 @@ async def _setup_context_for_resumed_invocation( # started from a sub-agent and paused on a sub-agent. # We should find the appropriate agent to run to continue the invocation. if self.agent.name not in invocation_context.end_of_agents: - invocation_context.agent = self._find_agent_to_run(session, self.agent) + invocation_context.agent = self._find_agent_to_run( + invocation_context.session, self.agent + ) return invocation_context def _find_user_message_for_invocation( @@ -1559,7 +1567,7 @@ async def _handle_new_message( if 'save_input_blobs_as_artifacts' in run_config.model_fields_set: deprecated_save_blobs = run_config.save_input_blobs_as_artifacts await self._append_new_message_to_session( - session=session, + session=invocation_context.session, new_message=new_message, invocation_context=invocation_context, save_input_blobs_as_artifacts=deprecated_save_blobs, diff --git a/tests/unittests/agents/test_output_key_visibility.py b/tests/unittests/agents/test_output_key_visibility.py new file mode 100644 index 0000000000..e20a676241 --- /dev/null +++ b/tests/unittests/agents/test_output_key_visibility.py @@ -0,0 +1,180 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for LlmAgent output_key visibility in callbacks.""" + +from google.adk.agents.callback_context import CallbackContext +from google.adk.agents.live_request_queue import LiveRequestQueue +from google.adk.agents.llm_agent import LlmAgent +from google.adk.agents.sequential_agent import SequentialAgent +from google.adk.events.event import Event +from google.adk.flows.llm_flows.auto_flow import AutoFlow +from google.genai import types +import pytest +from pytest_mock import MockerFixture + +from .. import testing_utils + +# Standard MockModel will be used instead of SafeMockModel + + +@pytest.mark.asyncio +async def test_output_key_visibility_in_after_agent_callback(): + """Test that output_key state delta is visible in after_agent_callback.""" + mock_response = "Hello! How can I help you?" + mock_model = testing_utils.MockModel.create(responses=[mock_response]) + + callback_called = False + captured_state_value = None + captured_session_state_value = None + + async def check_output_key(callback_context: CallbackContext): + nonlocal callback_called, captured_state_value, captured_session_state_value + callback_called = True + captured_state_value = callback_context.state.get("result", "NOT_FOUND") + captured_session_state_value = callback_context.session.state.get( + "result", "NOT_IN_RAW" + ) + + agent = LlmAgent( + name="my_agent", + model=mock_model, + instruction="Reply with a short greeting.", + output_key="result", + after_agent_callback=check_output_key, + ) + + runner = testing_utils.InMemoryRunner(root_agent=agent) + + events = await runner.run_async(new_message="hello") + + assert callback_called, "Callback was not called" + + assert ( + captured_state_value == mock_response + ), f"Expected {mock_response}, got {captured_state_value}" + assert ( + captured_session_state_value == mock_response + ), f"Expected {mock_response}, got {captured_session_state_value}" + + +@pytest.mark.asyncio +async def test_output_key_visibility_in_run_live(mocker: MockerFixture): + """Test that output_key state delta is visible in after_agent_callback in run_live.""" + mock_response = "Hello! How can I help you?" + mock_model = testing_utils.MockModel.create(responses=[mock_response]) + + callback_called = False + captured_state_value = None + captured_session_state_value = None + + async def check_output_key(callback_context: CallbackContext): + nonlocal callback_called, captured_state_value, captured_session_state_value + callback_called = True + captured_state_value = callback_context.state.get("result", "NOT_FOUND") + captured_session_state_value = callback_context.session.state.get( + "result", "NOT_IN_RAW" + ) + + agent = LlmAgent( + name="my_agent", + model=mock_model, + instruction="Reply with a short greeting.", + output_key="result", + after_agent_callback=check_output_key, + ) + + async def mock_auto_flow_run_live(self, ctx): + yield Event( + id=Event.new_id(), + invocation_id=ctx.invocation_id, + author=ctx.agent.name, + content=types.Content(parts=[types.Part(text=mock_response)]), + ) + + mocker.patch.object(AutoFlow, "run_live", mock_auto_flow_run_live) + + runner = testing_utils.InMemoryRunner(root_agent=agent) + live_queue = LiveRequestQueue() + + agen = runner.runner.run_live( + user_id="test_user", + session_id=runner.session.id, + live_request_queue=live_queue, + ) + + # Send a message to trigger the agent + live_queue.send_content( + types.Content(role="user", parts=[types.Part(text="hello")]) + ) + + live_queue.close() + + async for event in agen: + pass + + assert callback_called, "Callback was not called" + assert ( + captured_state_value == mock_response + ), f"Expected {mock_response}, got {captured_state_value}" + assert ( + captured_session_state_value == mock_response + ), f"Expected {mock_response}, got {captured_session_state_value}" + + +@pytest.mark.asyncio +async def test_output_key_visibility_in_sequential_agent(): + """Test that output_key state delta is visible in next agent's before_agent_callback.""" + mock_response = "Hello from agent 1!" + mock_model = testing_utils.MockModel.create( + responses=[mock_response, "Hello from agent 2!"] + ) + + callback_called = False + captured_session_state_value = None + + async def check_before_agent(callback_context: CallbackContext): + nonlocal callback_called, captured_session_state_value + callback_called = True + captured_session_state_value = callback_context.session.state.get( + "result", "NOT_FOUND" + ) + + agent_1 = LlmAgent( + name="agent_1", + model=mock_model, + instruction="Reply with a short greeting.", + output_key="result", + ) + + agent_2 = LlmAgent( + name="agent_2", + model=mock_model, + instruction="Reply with a short greeting.", + before_agent_callback=check_before_agent, + ) + + sequential_agent = SequentialAgent( + name="seq_agent", + sub_agents=[agent_1, agent_2], + ) + + runner = testing_utils.InMemoryRunner(root_agent=sequential_agent) + + events = await runner.run_async(new_message="hello") + + assert callback_called, "Callback was not called" + assert ( + captured_session_state_value == mock_response + ), f"Expected {mock_response}, got {captured_session_state_value}" From 0cb9ae94b30ac2cff120b2c4ccab77e6b85cbf45 Mon Sep 17 00:00:00 2001 From: George Weale Date: Fri, 15 May 2026 11:55:21 -0700 Subject: [PATCH 19/28] fix(models): treat empty GenerateContentResponse without prompt feedback as successful Previously, an empty `candidates` list without `prompt_feedback` resulted in an `UNKNOWN_ERROR`. This change updates the logic to handle such cases as a successful completion with no generated content, which is valid for certain model interactions like tool-driven turns. Co-authored-by: George Weale PiperOrigin-RevId: 916115022 --- src/google/adk/models/llm_response.py | 11 +++++++-- tests/unittests/models/test_llm_response.py | 25 +++++++++++++++++++++ 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/src/google/adk/models/llm_response.py b/src/google/adk/models/llm_response.py index 10ab946455..c921f197c3 100644 --- a/src/google/adk/models/llm_response.py +++ b/src/google/adk/models/llm_response.py @@ -14,6 +14,7 @@ from __future__ import annotations +import logging from typing import Any from typing import Optional @@ -216,9 +217,15 @@ def create( model_version=generate_content_response.model_version, ) else: + # Some model backends can legitimately complete a turn without + # candidates (for example, tool-driven UI turns with no text). Treat + # this as an empty successful response rather than an unknown error. + logging.warning( + 'Received empty candidates and no prompt feedback in model ' + 'response. Treating as a successful empty response.' + ) return LlmResponse( - error_code='UNKNOWN_ERROR', - error_message='Unknown error.', + content=types.Content(role='model', parts=[]), usage_metadata=usage_metadata, model_version=generate_content_response.model_version, ) diff --git a/tests/unittests/models/test_llm_response.py b/tests/unittests/models/test_llm_response.py index e2cbe4c286..02b7126ab5 100644 --- a/tests/unittests/models/test_llm_response.py +++ b/tests/unittests/models/test_llm_response.py @@ -107,6 +107,31 @@ def test_llm_response_create_no_candidates(): assert response.error_message == 'Prompt blocked for safety' +def test_llm_response_create_no_candidates_without_prompt_feedback(): + """Test LlmResponse.create() for empty successful model responses.""" + usage_metadata = types.GenerateContentResponseUsageMetadata( + prompt_token_count=10, + candidates_token_count=0, + total_token_count=10, + ) + generate_content_response = types.GenerateContentResponse( + candidates=[], + usage_metadata=usage_metadata, + model_version='gemini-2.5-flash', + ) + + response = LlmResponse.create(generate_content_response) + + assert response.error_code is None + assert response.error_message is None + assert response.finish_reason is None + assert response.content is not None + assert response.content.role == 'model' + assert not response.content.parts + assert response.usage_metadata == usage_metadata + assert response.model_version == 'gemini-2.5-flash' + + def test_llm_response_create_with_concrete_logprobs_result(): """Test LlmResponse.create() with detailed logprobs_result containing actual token data.""" # Create realistic logprobs data From 6ca6a14187f9c65982ae2f4d506a659171ee58ce Mon Sep 17 00:00:00 2001 From: George Weale Date: Fri, 15 May 2026 14:53:45 -0700 Subject: [PATCH 20/28] perf: lazy-load service registries and split apps.app to cut cold start ~8% Co-authored-by: George Weale PiperOrigin-RevId: 916195791 --- src/google/adk/agents/invocation_context.py | 4 +- src/google/adk/apps/__init__.py | 22 ++++- src/google/adk/apps/_configs.py | 95 +++++++++++++++++++ src/google/adk/apps/app.py | 81 ++-------------- src/google/adk/artifacts/__init__.py | 26 ++++- src/google/adk/flows/llm_flows/single_flow.py | 2 +- src/google/adk/memory/__init__.py | 34 ++++--- src/google/adk/plugins/__init__.py | 29 +++++- src/google/adk/runners.py | 17 ++-- src/google/adk/sessions/__init__.py | 30 ++++-- 10 files changed, 229 insertions(+), 111 deletions(-) create mode 100644 src/google/adk/apps/_configs.py diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py index 7fdbaee89b..8dc1a01af2 100644 --- a/src/google/adk/agents/invocation_context.py +++ b/src/google/adk/agents/invocation_context.py @@ -25,8 +25,8 @@ from pydantic import Field from pydantic import PrivateAttr -from ..apps.app import EventsCompactionConfig -from ..apps.app import ResumabilityConfig +from ..apps._configs import EventsCompactionConfig +from ..apps._configs import ResumabilityConfig from ..artifacts.base_artifact_service import BaseArtifactService from ..auth.auth_credential import AuthCredential from ..auth.credential_service.base_credential_service import BaseCredentialService diff --git a/src/google/adk/apps/__init__.py b/src/google/adk/apps/__init__.py index 3a5d0b0643..319293967b 100644 --- a/src/google/adk/apps/__init__.py +++ b/src/google/adk/apps/__init__.py @@ -12,10 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .app import App -from .app import ResumabilityConfig +from __future__ import annotations + +import importlib +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ._configs import ResumabilityConfig + from .app import App __all__ = [ 'App', 'ResumabilityConfig', ] + +_LAZY_MEMBERS: dict[str, str] = { + 'App': 'app', + 'ResumabilityConfig': '_configs', +} + + +def __getattr__(name: str): + if name in _LAZY_MEMBERS: + module = importlib.import_module(f'{__name__}.{_LAZY_MEMBERS[name]}') + return vars(module)[name] + raise AttributeError(f'module {__name__!r} has no attribute {name!r}') diff --git a/src/google/adk/apps/_configs.py b/src/google/adk/apps/_configs.py new file mode 100644 index 0000000000..87f3666ebd --- /dev/null +++ b/src/google/adk/apps/_configs.py @@ -0,0 +1,95 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Optional + +from pydantic import BaseModel +from pydantic import ConfigDict +from pydantic import Field +from pydantic import model_validator + +from ..utils.feature_decorator import experimental +from .base_events_summarizer import BaseEventsSummarizer + + +@experimental +class ResumabilityConfig(BaseModel): + """The config of the resumability for an application. + + The "resumability" in ADK refers to the ability to: + 1. pause an invocation upon a long-running function call. + 2. resume an invocation from the last event, if it's paused or failed midway + through. + + Note: ADK resumes the invocation in a best-effort manner: + 1. Tool call to resume needs to be idempotent because we only guarantee + an at-least-once behavior once resumed. + 2. Any temporary / in-memory state will be lost upon resumption. + """ + + is_resumable: bool = False + """Whether the app supports agent resumption. + If enabled, the feature will be enabled for all agents in the app. + """ + + +@experimental +class EventsCompactionConfig(BaseModel): + """The config of event compaction for an application.""" + + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) + + summarizer: Optional[BaseEventsSummarizer] = None + """The event summarizer to use for compaction.""" + + compaction_interval: int + """The number of *new* user-initiated invocations that, once + fully represented in the session's events, will trigger a compaction.""" + + overlap_size: int + """The number of preceding invocations to include from the + end of the last compacted range. This creates an overlap between consecutive + compacted summaries, maintaining context.""" + + token_threshold: Optional[int] = Field( + default=None, + gt=0, + ) + """Post-invocation token threshold trigger. + + If set, ADK will attempt a post-invocation compaction when the most recently + observed prompt token count meets or exceeds this threshold. + """ + + event_retention_size: Optional[int] = Field(default=None, ge=0) + """Post-invocation raw event retention size. + + If token-based post-invocation compaction is triggered, this keeps the last N + raw events un-compacted. + """ + + @model_validator(mode="after") + def _validate_token_params(self) -> EventsCompactionConfig: + token_threshold_set = self.token_threshold is not None + retention_size_set = self.event_retention_size is not None + if token_threshold_set != retention_size_set: + raise ValueError( + "token_threshold and event_retention_size must be set together." + ) + return self diff --git a/src/google/adk/apps/app.py b/src/google/adk/apps/app.py index c20d581d9b..9bde128b7a 100644 --- a/src/google/adk/apps/app.py +++ b/src/google/adk/apps/app.py @@ -22,9 +22,16 @@ from ..agents.base_agent import BaseAgent from ..agents.context_cache_config import ContextCacheConfig -from ..apps.base_events_summarizer import BaseEventsSummarizer from ..plugins.base_plugin import BasePlugin -from ..utils.feature_decorator import experimental +from ._configs import EventsCompactionConfig +from ._configs import ResumabilityConfig + +__all__ = [ + "App", + "EventsCompactionConfig", + "ResumabilityConfig", + "validate_app_name", +] def validate_app_name(name: str) -> None: @@ -38,76 +45,6 @@ def validate_app_name(name: str) -> None: raise ValueError("App name cannot be 'user'; reserved for end-user input.") -@experimental -class ResumabilityConfig(BaseModel): - """The config of the resumability for an application. - - The "resumability" in ADK refers to the ability to: - 1. pause an invocation upon a long-running function call. - 2. resume an invocation from the last event, if it's paused or failed midway - through. - - Note: ADK resumes the invocation in a best-effort manner: - 1. Tool call to resume needs to be idempotent because we only guarantee - an at-least-once behavior once resumed. - 2. Any temporary / in-memory state will be lost upon resumption. - """ - - is_resumable: bool = False - """Whether the app supports agent resumption. - If enabled, the feature will be enabled for all agents in the app. - """ - - -@experimental -class EventsCompactionConfig(BaseModel): - """The config of event compaction for an application.""" - - model_config = ConfigDict( - arbitrary_types_allowed=True, - extra="forbid", - ) - - summarizer: Optional[BaseEventsSummarizer] = None - """The event summarizer to use for compaction.""" - - compaction_interval: int - """The number of *new* user-initiated invocations that, once - fully represented in the session's events, will trigger a compaction.""" - - overlap_size: int - """The number of preceding invocations to include from the - end of the last compacted range. This creates an overlap between consecutive - compacted summaries, maintaining context.""" - - token_threshold: Optional[int] = Field( - default=None, - gt=0, - ) - """Post-invocation token threshold trigger. - - If set, ADK will attempt a post-invocation compaction when the most recently - observed prompt token count meets or exceeds this threshold. - """ - - event_retention_size: Optional[int] = Field(default=None, ge=0) - """Post-invocation raw event retention size. - - If token-based post-invocation compaction is triggered, this keeps the last N - raw events un-compacted. - """ - - @model_validator(mode="after") - def _validate_token_params(self) -> EventsCompactionConfig: - token_threshold_set = self.token_threshold is not None - retention_size_set = self.event_retention_size is not None - if token_threshold_set != retention_size_set: - raise ValueError( - "token_threshold and event_retention_size must be set together." - ) - return self - - class App(BaseModel): """Represents an LLM-backed agentic application. diff --git a/src/google/adk/artifacts/__init__.py b/src/google/adk/artifacts/__init__.py index 5e56ffc737..af7912e617 100644 --- a/src/google/adk/artifacts/__init__.py +++ b/src/google/adk/artifacts/__init__.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +import importlib +from typing import TYPE_CHECKING + from .base_artifact_service import BaseArtifactService -from .file_artifact_service import FileArtifactService -from .gcs_artifact_service import GcsArtifactService -from .in_memory_artifact_service import InMemoryArtifactService + +if TYPE_CHECKING: + from .file_artifact_service import FileArtifactService + from .gcs_artifact_service import GcsArtifactService + from .in_memory_artifact_service import InMemoryArtifactService __all__ = [ 'BaseArtifactService', @@ -23,3 +30,16 @@ 'GcsArtifactService', 'InMemoryArtifactService', ] + +_LAZY_MEMBERS: dict[str, str] = { + 'FileArtifactService': 'file_artifact_service', + 'GcsArtifactService': 'gcs_artifact_service', + 'InMemoryArtifactService': 'in_memory_artifact_service', +} + + +def __getattr__(name: str): + if name in _LAZY_MEMBERS: + module = importlib.import_module(f'{__name__}.{_LAZY_MEMBERS[name]}') + return vars(module)[name] + raise AttributeError(f'module {__name__!r} has no attribute {name!r}') diff --git a/src/google/adk/flows/llm_flows/single_flow.py b/src/google/adk/flows/llm_flows/single_flow.py index 932a265ed1..cc3fc9e6fa 100644 --- a/src/google/adk/flows/llm_flows/single_flow.py +++ b/src/google/adk/flows/llm_flows/single_flow.py @@ -22,7 +22,6 @@ from . import _nl_planning from . import _output_schema_processor from . import basic -from . import compaction from . import contents from . import context_cache_processor from . import identity @@ -36,6 +35,7 @@ def _create_request_processors(): """Create the standard request processor list for a single-agent flow.""" + from . import compaction from ...auth import auth_preprocessor return [ diff --git a/src/google/adk/memory/__init__.py b/src/google/adk/memory/__init__.py index c47fb8ec40..d40f3bf7d9 100644 --- a/src/google/adk/memory/__init__.py +++ b/src/google/adk/memory/__init__.py @@ -11,27 +11,35 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging + +from __future__ import annotations + +import importlib +from typing import TYPE_CHECKING from .base_memory_service import BaseMemoryService -from .in_memory_memory_service import InMemoryMemoryService -from .vertex_ai_memory_bank_service import VertexAiMemoryBankService -logger = logging.getLogger('google_adk.' + __name__) +if TYPE_CHECKING: + from .in_memory_memory_service import InMemoryMemoryService + from .vertex_ai_memory_bank_service import VertexAiMemoryBankService + from .vertex_ai_rag_memory_service import VertexAiRagMemoryService __all__ = [ 'BaseMemoryService', 'InMemoryMemoryService', 'VertexAiMemoryBankService', + 'VertexAiRagMemoryService', ] -try: - from .vertex_ai_rag_memory_service import VertexAiRagMemoryService +_LAZY_MEMBERS: dict[str, str] = { + 'InMemoryMemoryService': 'in_memory_memory_service', + 'VertexAiMemoryBankService': 'vertex_ai_memory_bank_service', + 'VertexAiRagMemoryService': 'vertex_ai_rag_memory_service', +} + - __all__.append('VertexAiRagMemoryService') -except ImportError: - logger.debug( - 'The Vertex SDK is not installed. If you want to use the' - ' VertexAiRagMemoryService please install it. If not, you can ignore this' - ' warning.' - ) +def __getattr__(name: str): + if name in _LAZY_MEMBERS: + module = importlib.import_module(f'{__name__}.{_LAZY_MEMBERS[name]}') + return vars(module)[name] + raise AttributeError(f'module {__name__!r} has no attribute {name!r}') diff --git a/src/google/adk/plugins/__init__.py b/src/google/adk/plugins/__init__.py index 45caf16038..70347fd25e 100644 --- a/src/google/adk/plugins/__init__.py +++ b/src/google/adk/plugins/__init__.py @@ -1,8 +1,7 @@ # Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may in obtain a copy of the License at +# you may in obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # @@ -12,11 +11,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +import importlib +from typing import TYPE_CHECKING + from .base_plugin import BasePlugin -from .debug_logging_plugin import DebugLoggingPlugin -from .logging_plugin import LoggingPlugin from .plugin_manager import PluginManager -from .reflect_retry_tool_plugin import ReflectAndRetryToolPlugin + +if TYPE_CHECKING: + from .debug_logging_plugin import DebugLoggingPlugin + from .logging_plugin import LoggingPlugin + from .reflect_retry_tool_plugin import ReflectAndRetryToolPlugin __all__ = [ 'BasePlugin', @@ -25,3 +31,16 @@ 'PluginManager', 'ReflectAndRetryToolPlugin', ] + +_LAZY_MEMBERS: dict[str, str] = { + 'DebugLoggingPlugin': 'debug_logging_plugin', + 'LoggingPlugin': 'logging_plugin', + 'ReflectAndRetryToolPlugin': 'reflect_retry_tool_plugin', +} + + +def __getattr__(name: str): + if name in _LAZY_MEMBERS: + module = importlib.import_module(f'{__name__}.{_LAZY_MEMBERS[name]}') + return vars(module)[name] + raise AttributeError(f'module {__name__!r} has no attribute {name!r}') diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index f352e6eb5f..850c26bbba 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -26,9 +26,9 @@ from typing import Generator from typing import List from typing import Optional +from typing import TYPE_CHECKING import warnings -from google.adk.apps.compaction import _run_compaction_for_sliding_window from google.genai import types from .agents.base_agent import BaseAgent @@ -38,10 +38,7 @@ from .agents.invocation_context import new_invocation_context_id from .agents.live_request_queue import LiveRequestQueue from .agents.run_config import RunConfig -from .apps.app import App -from .apps.app import ResumabilityConfig from .artifacts.base_artifact_service import BaseArtifactService -from .artifacts.in_memory_artifact_service import InMemoryArtifactService from .auth.credential_service.base_credential_service import BaseCredentialService from .code_executors.built_in_code_executor import BuiltInCodeExecutor from .errors.session_not_found_error import SessionNotFoundError @@ -51,19 +48,21 @@ from .flows.llm_flows.functions import find_event_by_function_call_id from .flows.llm_flows.functions import find_matching_function_call from .memory.base_memory_service import BaseMemoryService -from .memory.in_memory_memory_service import InMemoryMemoryService from .platform.thread import create_thread from .plugins.base_plugin import BasePlugin from .plugins.plugin_manager import PluginManager from .sessions.base_session_service import BaseSessionService from .sessions.base_session_service import GetSessionConfig -from .sessions.in_memory_session_service import InMemorySessionService from .sessions.session import Session from .telemetry.tracing import tracer from .tools.base_toolset import BaseToolset from .utils._debug_output import print_event from .utils.context_utils import Aclosing +if TYPE_CHECKING: + from .apps.app import App + from .apps.app import ResumabilityConfig + logger = logging.getLogger('google_adk.' + __name__) @@ -620,6 +619,8 @@ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: # the end of an invocation.) if self.app and self.app.events_compaction_config: logger.debug('Running event compactor.') + from google.adk.apps.compaction import _run_compaction_for_sliding_window + await _run_compaction_for_sliding_window( self.app, invocation_context.session, @@ -1677,6 +1678,10 @@ def __init__( app: Optional App instance. plugin_close_timeout: The timeout in seconds for plugin close methods. """ + from .artifacts.in_memory_artifact_service import InMemoryArtifactService + from .memory.in_memory_memory_service import InMemoryMemoryService + from .sessions.in_memory_session_service import InMemorySessionService + if app is None and app_name is None: app_name = 'InMemoryRunner' super().__init__( diff --git a/src/google/adk/sessions/__init__.py b/src/google/adk/sessions/__init__.py index 7505eda346..db983f96f1 100644 --- a/src/google/adk/sessions/__init__.py +++ b/src/google/adk/sessions/__init__.py @@ -11,11 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from __future__ import annotations + +import importlib +from typing import TYPE_CHECKING + from .base_session_service import BaseSessionService -from .in_memory_session_service import InMemorySessionService from .session import Session from .state import State -from .vertex_ai_session_service import VertexAiSessionService + +if TYPE_CHECKING: + from .database_session_service import DatabaseSessionService + from .in_memory_session_service import InMemorySessionService + from .vertex_ai_session_service import VertexAiSessionService __all__ = [ 'BaseSessionService', @@ -26,16 +35,23 @@ 'VertexAiSessionService', ] +_LAZY_MEMBERS: dict[str, str] = { + 'InMemorySessionService': 'in_memory_session_service', + 'VertexAiSessionService': 'vertex_ai_session_service', +} + def __getattr__(name: str): + if name in _LAZY_MEMBERS: + module = importlib.import_module(f'{__name__}.{_LAZY_MEMBERS[name]}') + return vars(module)[name] if name == 'DatabaseSessionService': try: - from .database_session_service import DatabaseSessionService - - return DatabaseSessionService + module = importlib.import_module(f'{__name__}.database_session_service') except ImportError as e: raise ImportError( - 'DatabaseSessionService requires sqlalchemy>=2.0, please ensure it is' - ' installed correctly.' + 'DatabaseSessionService requires sqlalchemy>=2.0, please ensure it' + ' is installed correctly.' ) from e + return vars(module)['DatabaseSessionService'] raise AttributeError(f'module {__name__!r} has no attribute {name!r}') From 7e61b517027a23c640b7b636a87e04a0a02c392c Mon Sep 17 00:00:00 2001 From: George Weale Date: Fri, 15 May 2026 14:55:41 -0700 Subject: [PATCH 21/28] fix(tools): preserve code_execution_result and executable_code in AgentTool AgentTool.run_async only extracted text parts from the inner agent's response, silently dropping code_execution_result.output and executable_code.code. Outer agents using an inner agent with a code executor saw nothing. Close #5481 Co-authored-by: George Weale PiperOrigin-RevId: 916196604 --- src/google/adk/tools/agent_tool.py | 16 +++- tests/unittests/tools/test_agent_tool.py | 105 +++++++++++++++++++++++ 2 files changed, 118 insertions(+), 3 deletions(-) diff --git a/src/google/adk/tools/agent_tool.py b/src/google/adk/tools/agent_tool.py index 8a43eb6311..2b8c083f23 100644 --- a/src/google/adk/tools/agent_tool.py +++ b/src/google/adk/tools/agent_tool.py @@ -41,6 +41,17 @@ from ..agents.base_agent import BaseAgent +def _part_to_text(part: types.Part) -> str: + """Returns user-visible text from a Part, including code execution output.""" + if part.text: + return part.text + if part.code_execution_result and part.code_execution_result.output: + return part.code_execution_result.output.rstrip('\n') + if part.executable_code and part.executable_code.code: + return part.executable_code.code + return '' + + def _get_input_schema(agent: BaseAgent) -> Optional[type[BaseModel]]: """Extracts the input_schema from an agent. @@ -269,9 +280,8 @@ async def run_async( if last_content is None or last_content.parts is None: return '' - merged_text = '\n'.join( - p.text for p in last_content.parts if p.text and not p.thought - ) + parts_text = (_part_to_text(p) for p in last_content.parts if not p.thought) + merged_text = '\n'.join(t for t in parts_text if t) output_schema = _get_output_schema(self.agent) if output_schema: tool_result = validate_schema(output_schema, merged_text) diff --git a/tests/unittests/tools/test_agent_tool.py b/tests/unittests/tools/test_agent_tool.py index cc6c0b2134..67a2674ff7 100644 --- a/tests/unittests/tools/test_agent_tool.py +++ b/tests/unittests/tools/test_agent_tool.py @@ -15,12 +15,14 @@ from typing import Any from typing import Optional +from google.adk.agents.base_agent import BaseAgent from google.adk.agents.callback_context import CallbackContext from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.llm_agent import Agent from google.adk.agents.run_config import RunConfig from google.adk.agents.sequential_agent import SequentialAgent from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.events.event import Event from google.adk.features import FeatureName from google.adk.features._feature_registry import temporary_feature_override from google.adk.memory.in_memory_memory_service import InMemoryMemoryService @@ -985,6 +987,109 @@ async def test_run_async_handles_none_parts_in_response(): assert tool_result == '' +async def _run_agent_tool_with_parts(parts: list[types.Part]) -> Any: + """Drives AgentTool with an inner agent whose final event content is `parts`.""" + + class _StaticAgent(BaseAgent): + + async def _run_async_impl(self, ctx): + yield Event( + invocation_id=ctx.invocation_id, + author=self.name, + content=types.Content(role='model', parts=parts), + ) + + inner = _StaticAgent(name='inner_agent', description='static') + agent_tool = AgentTool(agent=inner) + + session_service = InMemorySessionService() + session = await session_service.create_session( + app_name='test_app', user_id='test_user' + ) + invocation_context = InvocationContext( + invocation_id='invocation_id', + agent=inner, + session=session, + session_service=session_service, + ) + tool_context = ToolContext(invocation_context=invocation_context) + + return await agent_tool.run_async( + args={'request': 'test request'}, tool_context=tool_context + ) + + +@mark.asyncio +async def test_run_async_extracts_text_only(): + """Plain text parts pass through unchanged.""" + result = await _run_agent_tool_with_parts([types.Part(text='hello world')]) + assert result == 'hello world' + + +@mark.asyncio +async def test_run_async_extracts_code_execution_result_only(): + """code_execution_result.output and executable_code.code are returned.""" + result = await _run_agent_tool_with_parts([ + types.Part( + executable_code=types.ExecutableCode( + language=types.Language.PYTHON, code='print(2 ** 10)' + ) + ), + types.Part( + code_execution_result=types.CodeExecutionResult( + outcome=types.Outcome.OUTCOME_OK, output='1024\n' + ) + ), + ]) + assert result == 'print(2 ** 10)\n1024' + + +@mark.asyncio +async def test_run_async_extracts_text_and_code_execution_result(): + """Mixed text + code parts are concatenated in order.""" + result = await _run_agent_tool_with_parts([ + types.Part(text='Here is the answer:'), + types.Part( + executable_code=types.ExecutableCode( + language=types.Language.PYTHON, code='print(2 ** 10)' + ) + ), + types.Part( + code_execution_result=types.CodeExecutionResult( + outcome=types.Outcome.OUTCOME_OK, output='1024\n' + ) + ), + ]) + assert result == 'Here is the answer:\nprint(2 ** 10)\n1024' + + +@mark.asyncio +async def test_run_async_extracts_executable_code_only(): + """executable_code.code alone is returned when no result part follows.""" + result = await _run_agent_tool_with_parts([ + types.Part( + executable_code=types.ExecutableCode( + language=types.Language.PYTHON, code='print("hi")' + ) + ), + ]) + assert result == 'print("hi")' + + +@mark.asyncio +async def test_run_async_skips_thought_parts(): + """Parts marked thought=True are dropped regardless of kind.""" + result = await _run_agent_tool_with_parts([ + types.Part(text='thinking out loud', thought=True), + types.Part( + code_execution_result=types.CodeExecutionResult( + outcome=types.Outcome.OUTCOME_OK, output='42\n' + ) + ), + ]) + assert result == '42' + + class TestAgentToolWithCompositeAgents: """Tests for AgentTool wrapping composite agents (SequentialAgent, etc.).""" From 2388090cd26d09180d011e6138dd17f43ba6c7a2 Mon Sep 17 00:00:00 2001 From: Kathy Wu Date: Fri, 15 May 2026 14:59:49 -0700 Subject: [PATCH 22/28] chore: Remove experimental tag from SkillToolset Co-authored-by: Kathy Wu PiperOrigin-RevId: 916198410 --- src/google/adk/features/_feature_registry.py | 2 +- src/google/adk/tools/bash_tool.py | 2 -- src/google/adk/tools/skill_toolset.py | 26 ++------------------ tests/unittests/tools/test_skill_toolset.py | 8 ------ 4 files changed, 3 insertions(+), 35 deletions(-) diff --git a/src/google/adk/features/_feature_registry.py b/src/google/adk/features/_feature_registry.py index b5f51f2825..a99a51b430 100644 --- a/src/google/adk/features/_feature_registry.py +++ b/src/google/adk/features/_feature_registry.py @@ -150,7 +150,7 @@ class FeatureConfig: FeatureStage.EXPERIMENTAL, default_on=True ), FeatureName.SKILL_TOOLSET: FeatureConfig( - FeatureStage.EXPERIMENTAL, default_on=True + FeatureStage.STABLE, default_on=True ), FeatureName.SPANNER_ADMIN_TOOLSET: FeatureConfig( FeatureStage.EXPERIMENTAL, default_on=True diff --git a/src/google/adk/tools/bash_tool.py b/src/google/adk/tools/bash_tool.py index 89de3bf34f..85575656af 100644 --- a/src/google/adk/tools/bash_tool.py +++ b/src/google/adk/tools/bash_tool.py @@ -29,7 +29,6 @@ from google.genai import types -from .. import features from .base_tool import BaseTool from .tool_context import ToolContext @@ -99,7 +98,6 @@ def _set_resource_limits(policy: BashToolPolicy) -> None: logger.warning("Failed to set resource limits: %s", e) -@features.experimental(features.FeatureName.SKILL_TOOLSET) class ExecuteBashTool(BaseTool): """Tool to execute a validated bash command within a workspace directory.""" diff --git a/src/google/adk/tools/skill_toolset.py b/src/google/adk/tools/skill_toolset.py index ef579d8256..18a894cdc2 100644 --- a/src/google/adk/tools/skill_toolset.py +++ b/src/google/adk/tools/skill_toolset.py @@ -25,7 +25,6 @@ import mimetypes from typing import Any from typing import TYPE_CHECKING -import warnings from google.genai import types from typing_extensions import override @@ -33,8 +32,6 @@ from ..agents.readonly_context import ReadonlyContext from ..code_executors.base_code_executor import BaseCodeExecutor from ..code_executors.code_execution_utils import CodeExecutionInput -from ..features import experimental -from ..features import FeatureName from ..skills import models from ..skills import prompt from ..skills import SkillRegistry @@ -58,7 +55,7 @@ " conversation history for you to analyze." ) -_DEFAULT_SKILL_SYSTEM_INSTRUCTION = ( +DEFAULT_SKILL_SYSTEM_INSTRUCTION = ( "You can use specialized 'skills' to help you with complex tasks. " "You MUST use the skill tools to interact with these skills.\n\n" "Skills are folders of instructions and resources that extend your " @@ -88,7 +85,6 @@ ) -@experimental(FeatureName.SKILL_TOOLSET) class ListSkillsTool(BaseTool): """Tool to list all available skills.""" @@ -118,7 +114,6 @@ async def run_async( return prompt.format_skills_as_xml(skills) -@experimental(FeatureName.SKILL_TOOLSET) class SearchSkillsTool(BaseTool): """Tool to search for relevant skills in the registry.""" @@ -181,7 +176,6 @@ async def run_async( } -@experimental(FeatureName.SKILL_TOOLSET) class LoadSkillTool(BaseTool): """Tool to load a skill's instructions.""" @@ -250,7 +244,6 @@ async def run_async( } -@experimental(FeatureName.SKILL_TOOLSET) class LoadSkillResourceTool(BaseTool): """Tool to load resources (references, assets, or scripts) from a skill.""" @@ -710,7 +703,6 @@ def _build_wrapper_code( return "\n".join(code_lines) -@experimental(FeatureName.SKILL_TOOLSET) class RunSkillScriptTool(BaseTool): """Tool to execute scripts from a skill's scripts/ directory.""" @@ -874,7 +866,6 @@ async def run_async( ) -@experimental(FeatureName.SKILL_TOOLSET) class SkillToolset(BaseToolset): """A toolset for managing and interacting with agent skills.""" @@ -1063,7 +1054,7 @@ async def process_llm_request( self, *, tool_context: ToolContext, llm_request: LlmRequest ) -> None: """Processes the outgoing LLM request to include available skills.""" - instructions = [_DEFAULT_SKILL_SYSTEM_INSTRUCTION] + instructions = [DEFAULT_SKILL_SYSTEM_INSTRUCTION] has_list_skills = any(isinstance(t, ListSkillsTool) for t in self._tools) @@ -1090,16 +1081,3 @@ async def close(self) -> None: cached.cancel() self._fetched_skill_cache.clear() await super().close() - - -def __getattr__(name: str) -> Any: - if name == "DEFAULT_SKILL_SYSTEM_INSTRUCTION": - warnings.warn( - "DEFAULT_SKILL_SYSTEM_INSTRUCTION is experimental. Its content " - "is internal implementation and will change in minor/patch releases " - "to tune agent performance.", - UserWarning, - stacklevel=2, - ) - return _DEFAULT_SKILL_SYSTEM_INSTRUCTION - raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/tests/unittests/tools/test_skill_toolset.py b/tests/unittests/tools/test_skill_toolset.py index f637377513..aca34c2523 100644 --- a/tests/unittests/tools/test_skill_toolset.py +++ b/tests/unittests/tools/test_skill_toolset.py @@ -471,14 +471,6 @@ async def test_process_llm_request_without_list_skills_tool( assert "skill2" in instructions[1] -def test_default_skill_system_instruction_warning(): - with pytest.warns( - UserWarning, match="DEFAULT_SKILL_SYSTEM_INSTRUCTION is experimental" - ): - instruction = skill_toolset.DEFAULT_SKILL_SYSTEM_INSTRUCTION - assert "specialized 'skills'" in instruction - - def test_duplicate_skill_name_raises(mock_skill1): skill_dup = mock.create_autospec(models.Skill, instance=True) skill_dup.name = "skill1" From ec54bd439e31c99a32d773ace04b73cb3a275675 Mon Sep 17 00:00:00 2001 From: George Weale Date: Fri, 15 May 2026 15:14:11 -0700 Subject: [PATCH 23/28] perf(utils): cache find_context_parameter introspection Adds @functools.lru_cache to find_context_parameter so the inspect.signature + typing.get_type_hints lookup runs once per function, not on every MCP confirmation callback or declaration build. No public surface change. Co-authored-by: George Weale PiperOrigin-RevId: 916204929 --- src/google/adk/utils/context_utils.py | 2 ++ tests/unittests/utils/test_context_utils.py | 24 +++++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/src/google/adk/utils/context_utils.py b/src/google/adk/utils/context_utils.py index c0fde29ae4..bd80fa2ff3 100644 --- a/src/google/adk/utils/context_utils.py +++ b/src/google/adk/utils/context_utils.py @@ -21,6 +21,7 @@ from __future__ import annotations from contextlib import aclosing +import functools import inspect import typing from typing import Any @@ -62,6 +63,7 @@ def _is_context_type(annotation: Any) -> bool: return annotation is Context +@functools.lru_cache(maxsize=1024) def find_context_parameter(func: Callable[..., Any]) -> str | None: """Find the parameter name that has a Context type annotation. diff --git a/tests/unittests/utils/test_context_utils.py b/tests/unittests/utils/test_context_utils.py index 56583b9c6d..b8173be4b0 100644 --- a/tests/unittests/utils/test_context_utils.py +++ b/tests/unittests/utils/test_context_utils.py @@ -15,10 +15,12 @@ """Tests for context_utils module.""" from typing import Optional +from unittest import mock from google.adk.agents.callback_context import CallbackContext from google.adk.agents.context import Context from google.adk.tools.tool_context import ToolContext +from google.adk.utils import context_utils from google.adk.utils.context_utils import find_context_parameter @@ -129,3 +131,25 @@ def my_tool( return query assert find_context_parameter(my_tool) == 'ctx' + + +class TestFindContextParameterCaching: + """Tests for find_context_parameter caching behavior.""" + + def test_repeated_calls_inspect_signature_once(self): + """Repeated calls with the same function reuse the cached result.""" + + def my_tool(ctx: Context) -> str: + return 'ok' + + find_context_parameter.cache_clear() + + with mock.patch.object( + context_utils.inspect, + 'signature', + wraps=context_utils.inspect.signature, + ) as spy: + for _ in range(10): + assert find_context_parameter(my_tool) == 'ctx' + + assert spy.call_count == 1 From 430915970062a4ff926a65e5884cc5bc2912c48c Mon Sep 17 00:00:00 2001 From: Sasha Sobran Date: Fri, 15 May 2026 15:49:42 -0700 Subject: [PATCH 24/28] fix(tools): Prevent AnyIO CancelScope task boundary violations during MCP session creation failure Co-authored-by: Sasha Sobran PiperOrigin-RevId: 916220631 --- .../adk/tools/mcp_tool/mcp_session_manager.py | 11 ++++--- .../mcp_tool/test_mcp_session_manager.py | 32 +++++++++++++++++++ 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py index 1a14fe1972..a4b45cf16b 100644 --- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py +++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py @@ -542,10 +542,13 @@ async def create_session( sampling_capabilities=self._sampling_capabilities, ) - session = await asyncio.wait_for( - exit_stack.enter_async_context(session_context), - timeout=timeout_in_seconds, - ) + if is_feature_enabled(FeatureName._MCP_GRACEFUL_ERROR_HANDLING): # pylint: disable=protected-access + session = await exit_stack.enter_async_context(session_context) + else: + session = await asyncio.wait_for( + exit_stack.enter_async_context(session_context), + timeout=timeout_in_seconds, + ) # Store session, exit stack, and loop in the pool. The pool storage # remains a tuple for backward-compatibility with downstream tests diff --git a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py index 27df785277..a94b2eb885 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py @@ -1043,3 +1043,35 @@ def test_env_var_disable_acts_as_kill_switch(self): os.environ[disable] = saved_disable if saved_enable is not None: os.environ[enable] = saved_enable + + @pytest.mark.asyncio + @patch("google.adk.tools.mcp_tool.mcp_session_manager.asyncio.wait_for") + async def test_create_session_does_not_use_wait_for_when_ge_is_enabled( + self, mock_wait_for + ): + """create_session must not wrap enter_async_context in asyncio.wait_for when GE is enabled.""" + from google.adk.features import FeatureName + from google.adk.features._feature_registry import temporary_feature_override + + manager = MCPSessionManager( + StdioConnectionParams( + server_params=StdioServerParameters(command="dummy", args=[]), + timeout=5.0, + ) + ) + with temporary_feature_override( + FeatureName._MCP_GRACEFUL_ERROR_HANDLING, True + ): + with patch( + "google.adk.tools.mcp_tool.mcp_session_manager.AsyncExitStack" + ) as mock_stack: + mock_stack.return_value.enter_async_context = AsyncMock() + with patch( + "google.adk.tools.mcp_tool.mcp_session_manager.SessionContext" + ): + with patch( + "google.adk.tools.mcp_tool.mcp_session_manager.stdio_client" + ): + await manager.create_session() + + mock_wait_for.assert_not_called() From a5cddb8e86f51dbf84c732dc3257bf4515db33e6 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 15 May 2026 16:05:29 -0700 Subject: [PATCH 25/28] chore: Remove experimental tag from SkillToolset PiperOrigin-RevId: 916227202 --- src/google/adk/features/_feature_registry.py | 2 +- src/google/adk/tools/bash_tool.py | 2 ++ src/google/adk/tools/skill_toolset.py | 26 ++++++++++++++++++-- tests/unittests/tools/test_skill_toolset.py | 8 ++++++ 4 files changed, 35 insertions(+), 3 deletions(-) diff --git a/src/google/adk/features/_feature_registry.py b/src/google/adk/features/_feature_registry.py index a99a51b430..b5f51f2825 100644 --- a/src/google/adk/features/_feature_registry.py +++ b/src/google/adk/features/_feature_registry.py @@ -150,7 +150,7 @@ class FeatureConfig: FeatureStage.EXPERIMENTAL, default_on=True ), FeatureName.SKILL_TOOLSET: FeatureConfig( - FeatureStage.STABLE, default_on=True + FeatureStage.EXPERIMENTAL, default_on=True ), FeatureName.SPANNER_ADMIN_TOOLSET: FeatureConfig( FeatureStage.EXPERIMENTAL, default_on=True diff --git a/src/google/adk/tools/bash_tool.py b/src/google/adk/tools/bash_tool.py index 85575656af..89de3bf34f 100644 --- a/src/google/adk/tools/bash_tool.py +++ b/src/google/adk/tools/bash_tool.py @@ -29,6 +29,7 @@ from google.genai import types +from .. import features from .base_tool import BaseTool from .tool_context import ToolContext @@ -98,6 +99,7 @@ def _set_resource_limits(policy: BashToolPolicy) -> None: logger.warning("Failed to set resource limits: %s", e) +@features.experimental(features.FeatureName.SKILL_TOOLSET) class ExecuteBashTool(BaseTool): """Tool to execute a validated bash command within a workspace directory.""" diff --git a/src/google/adk/tools/skill_toolset.py b/src/google/adk/tools/skill_toolset.py index 18a894cdc2..ef579d8256 100644 --- a/src/google/adk/tools/skill_toolset.py +++ b/src/google/adk/tools/skill_toolset.py @@ -25,6 +25,7 @@ import mimetypes from typing import Any from typing import TYPE_CHECKING +import warnings from google.genai import types from typing_extensions import override @@ -32,6 +33,8 @@ from ..agents.readonly_context import ReadonlyContext from ..code_executors.base_code_executor import BaseCodeExecutor from ..code_executors.code_execution_utils import CodeExecutionInput +from ..features import experimental +from ..features import FeatureName from ..skills import models from ..skills import prompt from ..skills import SkillRegistry @@ -55,7 +58,7 @@ " conversation history for you to analyze." ) -DEFAULT_SKILL_SYSTEM_INSTRUCTION = ( +_DEFAULT_SKILL_SYSTEM_INSTRUCTION = ( "You can use specialized 'skills' to help you with complex tasks. " "You MUST use the skill tools to interact with these skills.\n\n" "Skills are folders of instructions and resources that extend your " @@ -85,6 +88,7 @@ ) +@experimental(FeatureName.SKILL_TOOLSET) class ListSkillsTool(BaseTool): """Tool to list all available skills.""" @@ -114,6 +118,7 @@ async def run_async( return prompt.format_skills_as_xml(skills) +@experimental(FeatureName.SKILL_TOOLSET) class SearchSkillsTool(BaseTool): """Tool to search for relevant skills in the registry.""" @@ -176,6 +181,7 @@ async def run_async( } +@experimental(FeatureName.SKILL_TOOLSET) class LoadSkillTool(BaseTool): """Tool to load a skill's instructions.""" @@ -244,6 +250,7 @@ async def run_async( } +@experimental(FeatureName.SKILL_TOOLSET) class LoadSkillResourceTool(BaseTool): """Tool to load resources (references, assets, or scripts) from a skill.""" @@ -703,6 +710,7 @@ def _build_wrapper_code( return "\n".join(code_lines) +@experimental(FeatureName.SKILL_TOOLSET) class RunSkillScriptTool(BaseTool): """Tool to execute scripts from a skill's scripts/ directory.""" @@ -866,6 +874,7 @@ async def run_async( ) +@experimental(FeatureName.SKILL_TOOLSET) class SkillToolset(BaseToolset): """A toolset for managing and interacting with agent skills.""" @@ -1054,7 +1063,7 @@ async def process_llm_request( self, *, tool_context: ToolContext, llm_request: LlmRequest ) -> None: """Processes the outgoing LLM request to include available skills.""" - instructions = [DEFAULT_SKILL_SYSTEM_INSTRUCTION] + instructions = [_DEFAULT_SKILL_SYSTEM_INSTRUCTION] has_list_skills = any(isinstance(t, ListSkillsTool) for t in self._tools) @@ -1081,3 +1090,16 @@ async def close(self) -> None: cached.cancel() self._fetched_skill_cache.clear() await super().close() + + +def __getattr__(name: str) -> Any: + if name == "DEFAULT_SKILL_SYSTEM_INSTRUCTION": + warnings.warn( + "DEFAULT_SKILL_SYSTEM_INSTRUCTION is experimental. Its content " + "is internal implementation and will change in minor/patch releases " + "to tune agent performance.", + UserWarning, + stacklevel=2, + ) + return _DEFAULT_SKILL_SYSTEM_INSTRUCTION + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/tests/unittests/tools/test_skill_toolset.py b/tests/unittests/tools/test_skill_toolset.py index aca34c2523..f637377513 100644 --- a/tests/unittests/tools/test_skill_toolset.py +++ b/tests/unittests/tools/test_skill_toolset.py @@ -471,6 +471,14 @@ async def test_process_llm_request_without_list_skills_tool( assert "skill2" in instructions[1] +def test_default_skill_system_instruction_warning(): + with pytest.warns( + UserWarning, match="DEFAULT_SKILL_SYSTEM_INSTRUCTION is experimental" + ): + instruction = skill_toolset.DEFAULT_SKILL_SYSTEM_INSTRUCTION + assert "specialized 'skills'" in instruction + + def test_duplicate_skill_name_raises(mock_skill1): skill_dup = mock.create_autospec(models.Skill, instance=True) skill_dup.name = "skill1" From 790c9bef9a336ea000d0cf68e63b025dfead5227 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 15 May 2026 16:14:00 -0700 Subject: [PATCH 26/28] feat: add general support for Gemini Live API in ADK evaluate * Implements live inference in evaluation_generator.py using Runner.run_live(). * Updates base_eval_service.py, local_eval_service.py to support live mode data structures and connection handling. Testing: Added unit tests: * test_generates_inferences_with_user_simulator_live * test_live_session_manually_triggers_callbacks * test_live_session_manually_triggers_callbacks_with_tools * test_perform_inference_with_use_live * test_perform_inference_single_eval_item_live * test_perform_inference_single_eval_item_non_live PiperOrigin-RevId: 916231029 --- .../adk/evaluation/base_eval_service.py | 13 + src/google/adk/evaluation/constants.py | 2 + .../adk/evaluation/evaluation_generator.py | 357 +++++++++++++++++- .../adk/evaluation/local_eval_service.py | 40 +- .../evaluation/test_evaluation_generator.py | 325 ++++++++++++++++ .../evaluation/test_local_eval_service.py | 124 ++++++ 6 files changed, 849 insertions(+), 12 deletions(-) diff --git a/src/google/adk/evaluation/base_eval_service.py b/src/google/adk/evaluation/base_eval_service.py index bafe5846c2..927dd8cd04 100644 --- a/src/google/adk/evaluation/base_eval_service.py +++ b/src/google/adk/evaluation/base_eval_service.py @@ -25,6 +25,7 @@ from pydantic import ConfigDict from pydantic import Field +from .constants import DEFAULT_LIVE_TIMEOUT_SECONDS from .eval_case import Invocation from .eval_metrics import EvalMetric from .eval_result import EvalCaseResult @@ -81,6 +82,18 @@ class InferenceConfig(BaseModel): could also overwhelm those tools.""", ) + use_live: bool = Field( + default=False, + description="""Whether to use live (bidirectional streaming) mode for +inference. This is required for Live API models (e.g., gemini-*-live-*).""", + ) + + live_timeout_seconds: int = Field( + default=DEFAULT_LIVE_TIMEOUT_SECONDS, + description="""Timeout in seconds for waiting for model turn completion in +live mode.""", + ) + class InferenceRequest(BaseModel): """Represent a request to perform inferences for the eval cases in an eval set.""" diff --git a/src/google/adk/evaluation/constants.py b/src/google/adk/evaluation/constants.py index 5aed2b101e..e7ee1f24d2 100644 --- a/src/google/adk/evaluation/constants.py +++ b/src/google/adk/evaluation/constants.py @@ -18,3 +18,5 @@ 'Eval module is not installed, please install via `pip install' ' "google-adk[eval]"`.' ) + +DEFAULT_LIVE_TIMEOUT_SECONDS = 300 diff --git a/src/google/adk/evaluation/evaluation_generator.py b/src/google/adk/evaluation/evaluation_generator.py index f8fb6795aa..d5a6629366 100644 --- a/src/google/adk/evaluation/evaluation_generator.py +++ b/src/google/adk/evaluation/evaluation_generator.py @@ -14,6 +14,7 @@ from __future__ import annotations +import asyncio import copy import importlib import logging @@ -22,15 +23,26 @@ from typing import Optional import uuid +from google.genai import errors +from google.genai import types from google.genai.types import Content from pydantic import BaseModel +from websockets.exceptions import ConnectionClosed +from websockets.exceptions import ConnectionClosedOK +from ..agents.callback_context import CallbackContext +from ..agents.invocation_context import InvocationContext +from ..agents.live_request_queue import LiveRequestQueue from ..agents.llm_agent import Agent +from ..agents.run_config import RunConfig +from ..agents.run_config import StreamingMode from ..artifacts.base_artifact_service import BaseArtifactService from ..artifacts.in_memory_artifact_service import InMemoryArtifactService from ..events.event import Event +from ..flows.llm_flows.functions import handle_function_calls_live from ..memory.base_memory_service import BaseMemoryService from ..memory.in_memory_memory_service import InMemoryMemoryService +from ..models.llm_request import LlmRequest from ..runners import Runner from ..sessions.base_session_service import BaseSessionService from ..sessions.in_memory_session_service import InMemorySessionService @@ -39,6 +51,7 @@ from ._retry_options_utils import EnsureRetryOptionsPlugin from .app_details import AgentDetails from .app_details import AppDetails +from .constants import DEFAULT_LIVE_TIMEOUT_SECONDS from .eval_case import EvalCase from .eval_case import Invocation from .eval_case import InvocationEvent @@ -66,6 +79,181 @@ class EvalCaseResponses(BaseModel): responses: list[list[Invocation]] +class _LiveSession: + """Manages the background task and state for a live session.""" + + def __init__( + self, + runner: Runner, + session: Session, + user_id: str, + session_id: str, + ): + self.runner = runner + self.session = session + self.user_id = user_id + self.session_id = session_id + self.live_request_queue = LiveRequestQueue() + self.event_queue = asyncio.Queue() + self.turn_complete_event = asyncio.Event() + self.live_finished = asyncio.Event() + self.current_invocation_id = Event.new_id() + self.consume_task = None + + async def __aenter__(self) -> _LiveSession: + """Starts the background task.""" + self.consume_task = asyncio.create_task(self._consume_events()) + return self + + async def _consume_events(self) -> None: + """Background task: consume events from run_live.""" + try: + run_config = RunConfig( + streaming_mode=StreamingMode.BIDI, + response_modalities=["AUDIO"], + output_audio_transcription=types.AudioTranscriptionConfig(), + input_audio_transcription=types.AudioTranscriptionConfig(), + ) + + invocation_context = self.runner._new_invocation_context_for_live( + self.session, + live_request_queue=self.live_request_queue, + run_config=run_config, + ) + invocation_context.agent = self.runner._find_agent_to_run( + self.session, self.runner.agent + ) + + callback_context = None + llm_request = LlmRequest() + + async with Aclosing( + invocation_context.agent._llm_flow._preprocess_async( + invocation_context, llm_request + ) + ) as agen: + async for _ in agen: + pass + + callback_context = CallbackContext(invocation_context) + # By default, live API calls do not include before_model_callback and + # after_model_callback. These callbacks are needed by the plugins to + # include the agent instructions and tool declarations in the eval + # invocations for autorater evaluation. + await invocation_context.plugin_manager.run_before_model_callback( + callback_context=callback_context, + llm_request=llm_request, + ) + + in_function_call_loop = False + async with Aclosing( + invocation_context.agent.run_live(invocation_context) + ) as agen: + async for event in agen: + assert event is not None + event.invocation_id = self.current_invocation_id + if callback_context: + await invocation_context.plugin_manager.run_after_model_callback( + callback_context=callback_context, + llm_response=event, + ) + await self.event_queue.put(event) + if not event.partial: + await self.runner.session_service.append_event( + session=self.session, event=event + ) + function_calls = event.get_function_calls() + if function_calls: + in_function_call_loop = True + inv_context = InvocationContext( + session_service=self.runner.session_service, + invocation_id=event.invocation_id, + agent=self.runner.agent, + session=self.session, + run_config=run_config, + ) + + if isinstance(self.runner.agent, Agent): + resolved_tools = await self.runner.agent.canonical_tools( + inv_context + ) + tools_dict = {t.name: t for t in resolved_tools} + else: + tools_dict = {} + + try: + response_event = await handle_function_calls_live( + invocation_context=inv_context, + function_call_event=event, + tools_dict=tools_dict, + ) + + if ( + response_event + and response_event.content + and response_event.content.parts + ): + for part in response_event.content.parts: + if part.function_response: + tool_content = types.Content( + role="tool", + parts=[part], + ) + self.live_request_queue.send_content(tool_content) + except (ValueError, RuntimeError, KeyError, TypeError) as e: + logger.error( + "Failed to handle function calls: %s", + e, + exc_info=True, + ) + for fc in function_calls: + response_content = types.FunctionResponse( + name=fc.name, + id=fc.id, + response={"error": str(e)}, + ) + tool_content = types.Content( + role="tool", + parts=[types.Part(function_response=response_content)], + ) + self.live_request_queue.send_content(tool_content) + if event.turn_complete and event.author != _USER_AUTHOR: + if not in_function_call_loop: + self.turn_complete_event.set() + else: + in_function_call_loop = False + finally: + self.live_finished.set() + self.turn_complete_event.set() # Unblock any waiters + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + """Closes the queue and waits for the background task to finish.""" + self.live_request_queue.close() + try: + await asyncio.wait_for(self.consume_task, timeout=30) + except asyncio.TimeoutError: + logger.warning("Timed out waiting for run_live to finish.") + assert self.consume_task is not None + self.consume_task.cancel() + try: + await self.consume_task + except asyncio.CancelledError: + pass + except (ConnectionClosed, errors.APIError) as e: + # The Gemini Live API uses WebSockets. When the session ends normally, the + # connection is closed with code 1000. Some client libraries may raise an + # exception rather than handling it silently. We log this as INFO to + # avoid false-positive error reports for expected behavior. + is_normal_closure = isinstance(e, ConnectionClosedOK) or ( + isinstance(e, errors.APIError) and e.code == 1000 + ) + + if is_normal_closure: + logger.info("Ignored WebSocket normal closure exception: %s", e) + else: + raise + + class EvaluationGenerator: """Generates evaluation responses for agents.""" @@ -187,6 +375,168 @@ async def _generate_inferences_for_single_user_invocation( yield event + @staticmethod + async def _generate_inferences_for_single_user_invocation_live( + live_request_queue: LiveRequestQueue, + event_queue: asyncio.Queue[Event], + user_message: Content, + current_invocation_id: str, + turn_complete_event: asyncio.Event, + live_timeout_seconds: int, + ) -> AsyncGenerator[Event, None]: + """Generates inferences for a single user invocation in live mode.""" + yield Event( + content=user_message, + author=_USER_AUTHOR, + invocation_id=current_invocation_id, + ) + + live_request_queue.send_content(user_message) + + try: + await asyncio.wait_for( + turn_complete_event.wait(), + timeout=live_timeout_seconds, + ) + except asyncio.TimeoutError: + logger.warning( + "Timed out waiting for model turn completion in live mode." + ) + raise + + while not event_queue.empty(): + event = await event_queue.get() + if event.invocation_id == current_invocation_id: + yield event + + @staticmethod + async def _generate_inferences_from_root_agent_live( + root_agent: Agent, + user_simulator: UserSimulator, + reset_func: Optional[Any] = None, + initial_session: Optional[SessionInput] = None, + session_id: Optional[str] = None, + session_service: Optional[BaseSessionService] = None, + artifact_service: Optional[BaseArtifactService] = None, + memory_service: Optional[BaseMemoryService] = None, + live_timeout_seconds: int = DEFAULT_LIVE_TIMEOUT_SECONDS, + ) -> list[Invocation]: + """Scrapes the root agent in coordination with the user simulator in live mode.""" + if not session_service: + session_service = InMemorySessionService() + + if not memory_service: + memory_service = InMemoryMemoryService() + + app_name = ( + initial_session.app_name if initial_session else "EvaluationGenerator" + ) + user_id = initial_session.user_id if initial_session else "test_user_id" + session_id = session_id if session_id else str(uuid.uuid4()) + + session = await session_service.create_session( + app_name=app_name, + user_id=user_id, + state=initial_session.state if initial_session else {}, + session_id=session_id, + ) + + if not artifact_service: + artifact_service = InMemoryArtifactService() + + # Reset agent state for each query + if callable(reset_func): + reset_func() + + # We ensure that there is some kind of retries on the llm_requests that are + # generated from the Agent. This is done to make inferencing step of evals + # more resilient to temporary model failures. + ensure_retry_options_plugin = EnsureRetryOptionsPlugin( + name="ensure_retry_options" + ) + request_intercepter_plugin = _RequestIntercepterPlugin( + name="request_intercepter_plugin" + ) + async with Runner( + app_name=app_name, + agent=root_agent, + artifact_service=artifact_service, + session_service=session_service, + memory_service=memory_service, + plugins=[request_intercepter_plugin, ensure_retry_options_plugin], + ) as runner: + events = [] + + # `_LiveSession` is a runtime connection manager wrapping the `Session` + # data model (which stores conversation history/state). It manages the + # active bidirectional WebSocket stream and background consumer tasks. + live_session = _LiveSession(runner, session, user_id, session_id) + await live_session.__aenter__() + + try: + turn_idx = 0 + while True: + turn_idx += 1 + next_user_message = await user_simulator.get_next_user_message( + copy.deepcopy(events) + ) + if next_user_message.status == UserSimulatorStatus.SUCCESS: + live_session.current_invocation_id = Event.new_id() + live_session.turn_complete_event.clear() + + logger.info("Waiting for model to complete turn %d...", turn_idx) + + async for ( + event + ) in EvaluationGenerator._generate_inferences_for_single_user_invocation_live( + live_request_queue=live_session.live_request_queue, + event_queue=live_session.event_queue, + user_message=next_user_message.user_message, + current_invocation_id=live_session.current_invocation_id, + turn_complete_event=live_session.turn_complete_event, + live_timeout_seconds=live_timeout_seconds, + ): + events.append(event) + + turn_transcription = "" + for evt in events: + if ( + evt.invocation_id == live_session.current_invocation_id + and evt.author != _USER_AUTHOR + and evt.output_transcription + ): + if not evt.partial and evt.output_transcription.text: + turn_transcription = evt.output_transcription.text + else: + turn_transcription += evt.output_transcription.text + if turn_transcription: + synthetic_event = Event( + content=Content( + role="model", + parts=[types.Part(text=turn_transcription)], + ), + author=runner.agent.name, + invocation_id=live_session.current_invocation_id, + ) + events.append(synthetic_event) + + if live_session.live_finished.is_set(): + logger.info("Live session finished signal detected.") + break + else: # no message generated + break + finally: + await live_session.__aexit__(None, None, None) + + app_details_by_invocation_id = ( + EvaluationGenerator._get_app_details_by_invocation_id( + events, request_intercepter_plugin + ) + ) + return EvaluationGenerator.convert_events_to_eval_invocations( + events, app_details_by_invocation_id + ) + @staticmethod async def _generate_inferences_from_root_agent( root_agent: Agent, @@ -308,7 +658,12 @@ def convert_events_to_eval_invocations( final_event = event for p in event.content.parts: - if p.function_call or p.function_response or p.text: + if ( + p.function_call + or p.function_response + or p.text + or p.inline_data + ): events_to_add.append(event) break diff --git a/src/google/adk/evaluation/local_eval_service.py b/src/google/adk/evaluation/local_eval_service.py index 2426204ca0..b749b7e8f1 100644 --- a/src/google/adk/evaluation/local_eval_service.py +++ b/src/google/adk/evaluation/local_eval_service.py @@ -182,6 +182,8 @@ async def run_inference(eval_case): eval_set_id=inference_request.eval_set_id, eval_case=eval_case, root_agent=self._root_agent, + use_live=inference_request.inference_config.use_live, + live_timeout_seconds=inference_request.inference_config.live_timeout_seconds, ) inference_results = [run_inference(eval_case) for eval_case in eval_cases] @@ -470,6 +472,8 @@ async def _perform_inference_single_eval_item( eval_set_id: str, eval_case: EvalCase, root_agent: BaseAgent, + use_live: bool, + live_timeout_seconds: int, ) -> InferenceResult: initial_session = eval_case.session_input session_id = self._session_id_supplier() @@ -482,17 +486,31 @@ async def _perform_inference_single_eval_item( try: with client_label_context(EVAL_CLIENT_LABEL): - inferences = ( - await EvaluationGenerator._generate_inferences_from_root_agent( - root_agent=root_agent, - user_simulator=self._user_simulator_provider.provide(eval_case), - initial_session=initial_session, - session_id=session_id, - session_service=self._session_service, - artifact_service=self._artifact_service, - memory_service=self._memory_service, - ) - ) + if use_live: + inferences = await EvaluationGenerator._generate_inferences_from_root_agent_live( + root_agent=root_agent, + user_simulator=self._user_simulator_provider.provide(eval_case), + initial_session=initial_session, + session_id=session_id, + session_service=self._session_service, + artifact_service=self._artifact_service, + memory_service=self._memory_service, + live_timeout_seconds=live_timeout_seconds, + ) + else: + inferences = ( + await EvaluationGenerator._generate_inferences_from_root_agent( + root_agent=root_agent, + user_simulator=self._user_simulator_provider.provide( + eval_case + ), + initial_session=initial_session, + session_id=session_id, + session_service=self._session_service, + artifact_service=self._artifact_service, + memory_service=self._memory_service, + ) + ) inference_result.inferences = inferences inference_result.status = InferenceStatus.SUCCESS diff --git a/tests/unittests/evaluation/test_evaluation_generator.py b/tests/unittests/evaluation/test_evaluation_generator.py index a4aa8691fd..508b6f5c9c 100644 --- a/tests/unittests/evaluation/test_evaluation_generator.py +++ b/tests/unittests/evaluation/test_evaluation_generator.py @@ -14,8 +14,11 @@ from __future__ import annotations +import asyncio + from google.adk.evaluation.app_details import AgentDetails from google.adk.evaluation.app_details import AppDetails +from google.adk.evaluation.evaluation_generator import _LiveSession from google.adk.evaluation.evaluation_generator import EvaluationGenerator from google.adk.evaluation.request_intercepter_plugin import _RequestIntercepterPlugin from google.adk.evaluation.simulation.user_simulator import NextUserMessage @@ -396,6 +399,61 @@ async def mock_run_async(*args, **kwargs): ) +class TestGenerateInferencesForSingleUserInvocationLive: + """Test cases for EvaluationGenerator._generate_inferences_for_single_user_invocation_live method.""" + + @pytest.mark.asyncio + async def test_generate_inferences_live(self, mocker): + """Tests live inference generation.""" + mock_live_request_queue = mocker.MagicMock() + event_queue = asyncio.Queue() + turn_complete_event = asyncio.Event() + + user_content = types.Content(parts=[types.Part(text="User query")]) + invocation_id = "inv1" + + agent_event = _build_event( + "agent", [types.Part(text="Agent response")], invocation_id + ) + other_event = _build_event( + "agent", [types.Part(text="Other response")], "inv2" + ) + + gen = EvaluationGenerator._generate_inferences_for_single_user_invocation_live( + live_request_queue=mock_live_request_queue, + event_queue=event_queue, + user_message=user_content, + current_invocation_id=invocation_id, + turn_complete_event=turn_complete_event, + live_timeout_seconds=300, + ) + + # First yield should be the user message + first_event = await gen.__anext__() + assert first_event.author == "user" + assert first_event.content == user_content + assert first_event.invocation_id == invocation_id + + # Mock turn_complete_event.wait to avoid blocking + turn_complete_event.wait = mocker.AsyncMock() + + # Put events in queue BEFORE advancing + await event_queue.put(agent_event) + await event_queue.put(other_event) + + # Now advance to get the next event + second_event = await gen.__anext__() + + assert mock_live_request_queue.send_content.called + mock_live_request_queue.send_content.assert_called_once_with(user_content) + + assert second_event == agent_event + + # The generator should be exhausted now because other_event doesn't match invocation_id + with pytest.raises(StopAsyncIteration): + await gen.__anext__() + + @pytest.fixture def mock_runner(mocker): """Provides a mock Runner for testing.""" @@ -479,3 +537,270 @@ async def mock_generate_inferences_side_effect( mock_generate_inferences.assert_called_once() called_with_content = mock_generate_inferences.call_args.args[3] assert called_with_content.parts[0].text == "message 1" + + @pytest.mark.asyncio + async def test_generates_inferences_with_user_simulator_live( + self, mocker, mock_runner, mock_session_service + ): + """Tests that inferences are generated by interacting with a user simulator in live mode.""" + mock_agent = mocker.MagicMock() + mock_user_sim = mocker.MagicMock(spec=UserSimulator) + + # Mock user simulator will produce one message, then stop. + async def get_next_user_message_side_effect(*args, **kwargs): + if mock_user_sim.get_next_user_message.call_count == 1: + return NextUserMessage( + status=UserSimulatorStatus.SUCCESS, + user_message=types.Content(parts=[types.Part(text="message 1")]), + ) + return NextUserMessage(status=UserSimulatorStatus.STOP_SIGNAL_DETECTED) + + mock_user_sim.get_next_user_message = mocker.AsyncMock( + side_effect=get_next_user_message_side_effect + ) + + mock_generate_inferences_live = mocker.patch( + "google.adk.evaluation.evaluation_generator.EvaluationGenerator._generate_inferences_for_single_user_invocation_live" + ) + mocker.patch( + "google.adk.evaluation.evaluation_generator.EvaluationGenerator._get_app_details_by_invocation_id" + ) + mocker.patch( + "google.adk.evaluation.evaluation_generator.EvaluationGenerator.convert_events_to_eval_invocations" + ) + + # Mock _LiveSession context manager + mock_live_session = mocker.MagicMock() + mock_live_session.__aenter__ = mocker.AsyncMock( + return_value=mock_live_session + ) + mock_live_session.__aexit__ = mocker.AsyncMock(return_value=None) + mock_live_session.live_request_queue = mocker.MagicMock() + mock_live_session.event_queue = asyncio.Queue() + mock_live_session.turn_complete_event = asyncio.Event() + mock_live_session.live_finished = asyncio.Event() + + mock_live_session_cls = mocker.patch( + "google.adk.evaluation.evaluation_generator._LiveSession", + return_value=mock_live_session, + ) + + # Each call to _generate_inferences_for_single_user_invocation_live will + # yield one user and one agent event. + async def mock_generate_inferences_live_side_effect(*args, **kwargs): + yield _build_event("user", [types.Part(text="message 1")], "inv1") + yield _build_event("agent", [types.Part(text="agent_response")], "inv1") + + mock_generate_inferences_live.side_effect = ( + mock_generate_inferences_live_side_effect + ) + + await EvaluationGenerator._generate_inferences_from_root_agent_live( + root_agent=mock_agent, + user_simulator=mock_user_sim, + live_timeout_seconds=600, + ) + + # Check that user simulator was called until it stopped. + assert mock_user_sim.get_next_user_message.call_count == 2 + + # Check that we generated inferences for each user message. + mock_generate_inferences_live.assert_called_once() + called_with_content = mock_generate_inferences_live.call_args.kwargs[ + "user_message" + ] + assert called_with_content.parts[0].text == "message 1" + assert ( + mock_generate_inferences_live.call_args.kwargs["live_timeout_seconds"] + == 600 + ) + + # Verify that the agent response was collected + mock_convert = EvaluationGenerator.convert_events_to_eval_invocations + mock_convert.assert_called_once() + events_passed = mock_convert.call_args.args[0] + + agent_events = [e for e in events_passed if e.author == "agent"] + assert len(agent_events) == 1 + assert agent_events[0].content.parts[0].text == "agent_response" + + # Verify that the _LiveSession constructor was called + mock_live_session_cls.assert_called_once() + + +class TestLiveSessionCallbacks: + """Unit tests verifying that _LiveSession manually triggers callbacks.""" + + @pytest.mark.asyncio + async def test_live_session_manually_triggers_callbacks(self, mocker): + from google.adk.agents.callback_context import CallbackContext + from google.adk.agents.llm_agent import Agent + from google.adk.models.llm_request import LlmRequest + + # 1. Setup mock runner, agent, and session + mock_runner = mocker.MagicMock() + mock_runner.session_service.append_event = mocker.AsyncMock() + mock_session = mocker.MagicMock() + mock_agent = mocker.MagicMock(spec=Agent) + mock_runner.agent = mock_agent + mock_runner._find_agent_to_run.return_value = mock_agent + + mock_agent.name = "test_agent" + + # Mock _llm_flow._preprocess_async to set dummy instruction + async def mock_preprocess_async(invocation_context, llm_request): + llm_request.config.system_instruction = "mock instruction" + return + yield # make it an async generator + + mock_flow = mocker.MagicMock() + mock_flow._preprocess_async = mock_preprocess_async + mock_agent._llm_flow = mock_flow + + # Mock run_live stream yielding one event + mock_event = Event( + author="agent", + content=types.Content(parts=[types.Part(text="Hello")]), + invocation_id="test_invocation_id", + ) + + async def mock_run_live(*args, **kwargs): + yield mock_event + + mock_agent.run_live.return_value = mock_run_live() + + # Mock plugin_manager on invocation context + mock_plugin_manager = mocker.MagicMock() + mock_plugin_manager.run_before_model_callback = mocker.AsyncMock() + mock_plugin_manager.run_after_model_callback = mocker.AsyncMock() + mock_runner._new_invocation_context_for_live.return_value.plugin_manager = ( + mock_plugin_manager + ) + mock_runner._new_invocation_context_for_live.return_value.agent = mock_agent + + # 2. Instantiate and enter _LiveSession + live_session = _LiveSession( + runner=mock_runner, + session=mock_session, + user_id="test_user", + session_id="test_session", + ) + + # Directly run _consume_events as a coroutine for synchronous-style testing + await live_session._consume_events() + + # 3. Assertions + mock_plugin_manager.run_before_model_callback.assert_called_once() + called_before_args = mock_plugin_manager.run_before_model_callback.call_args + assert isinstance( + called_before_args.kwargs["callback_context"], CallbackContext + ) + assert isinstance(called_before_args.kwargs["llm_request"], LlmRequest) + assert ( + called_before_args.kwargs["llm_request"].config.system_instruction + == "mock instruction" + ) + + mock_plugin_manager.run_after_model_callback.assert_called_once() + called_after_args = mock_plugin_manager.run_after_model_callback.call_args + assert isinstance( + called_after_args.kwargs["callback_context"], CallbackContext + ) + assert isinstance(called_after_args.kwargs["llm_response"], Event) + assert called_after_args.kwargs["llm_response"] == mock_event + + @pytest.mark.asyncio + async def test_live_session_manually_triggers_callbacks_with_tools( + self, mocker + ): + from google.adk.agents.callback_context import CallbackContext + from google.adk.agents.llm_agent import Agent + from google.adk.models.llm_request import LlmRequest + + # 1. Setup mock runner, agent, and session + mock_runner = mocker.MagicMock() + mock_runner.session_service.append_event = mocker.AsyncMock() + mock_session = mocker.MagicMock() + mock_agent = mocker.MagicMock(spec=Agent) + mock_runner.agent = mock_agent + mock_runner._find_agent_to_run.return_value = mock_agent + + mock_agent.name = "test_agent" + + # Set up a mock tool + mock_tool = mocker.MagicMock() + mock_tool.name = "get_weather" + mock_decl = types.FunctionDeclaration( + name="get_weather", + description="Get weather details", + ) + mock_tool._get_declaration.return_value = mock_decl + + # Mock _llm_flow._preprocess_async to set instruction and append tool + async def mock_preprocess_async(invocation_context, llm_request): + llm_request.config.system_instruction = "mock instruction" + llm_request.append_tools([mock_tool]) + return + yield # make it an async generator + + mock_flow = mocker.MagicMock() + mock_flow._preprocess_async = mock_preprocess_async + mock_agent._llm_flow = mock_flow + + # Mock run_live stream yielding one event + mock_event = Event( + author="agent", + content=types.Content(parts=[types.Part(text="Hello")]), + invocation_id="test_invocation_id", + ) + + async def mock_run_live(*args, **kwargs): + yield mock_event + + mock_agent.run_live.return_value = mock_run_live() + + # Mock plugin_manager on invocation context + mock_plugin_manager = mocker.MagicMock() + mock_plugin_manager.run_before_model_callback = mocker.AsyncMock() + mock_plugin_manager.run_after_model_callback = mocker.AsyncMock() + mock_runner._new_invocation_context_for_live.return_value.plugin_manager = ( + mock_plugin_manager + ) + mock_runner._new_invocation_context_for_live.return_value.agent = mock_agent + + # 2. Instantiate and enter _LiveSession + live_session = _LiveSession( + runner=mock_runner, + session=mock_session, + user_id="test_user", + session_id="test_session", + ) + + # Directly run _consume_events as a coroutine + await live_session._consume_events() + + # 3. Assertions + mock_plugin_manager.run_before_model_callback.assert_called_once() + called_before_args = mock_plugin_manager.run_before_model_callback.call_args + assert isinstance( + called_before_args.kwargs["callback_context"], CallbackContext + ) + + llm_request = called_before_args.kwargs["llm_request"] + assert isinstance(llm_request, LlmRequest) + assert llm_request.config.system_instruction == "mock instruction" + + # Assert that tool was correctly wrapped under types.Tool format + assert len(llm_request.config.tools) == 1 + wrapped_tool = llm_request.config.tools[0] + assert isinstance(wrapped_tool, types.Tool) + assert len(wrapped_tool.function_declarations) == 1 + assert wrapped_tool.function_declarations[0].name == "get_weather" + + mock_plugin_manager.run_after_model_callback.assert_called_once() + called_after_args = mock_plugin_manager.run_after_model_callback.call_args + assert isinstance( + called_after_args.kwargs["callback_context"], CallbackContext + ) + assert isinstance(called_after_args.kwargs["llm_response"], Event) + assert called_after_args.kwargs["llm_response"] == mock_event diff --git a/tests/unittests/evaluation/test_local_eval_service.py b/tests/unittests/evaluation/test_local_eval_service.py index 386c1fd07a..894fc27c07 100644 --- a/tests/unittests/evaluation/test_local_eval_service.py +++ b/tests/unittests/evaluation/test_local_eval_service.py @@ -247,12 +247,59 @@ async def test_perform_inference_with_case_ids( eval_set_id="test_eval_set", eval_case=eval_set.eval_cases[0], root_agent=dummy_agent, + use_live=False, + live_timeout_seconds=300, ) eval_service._perform_inference_single_eval_item.assert_any_call( app_name="test_app", eval_set_id="test_eval_set", eval_case=eval_set.eval_cases[2], root_agent=dummy_agent, + use_live=False, + live_timeout_seconds=300, + ) + + +@pytest.mark.asyncio +async def test_perform_inference_with_use_live( + eval_service, + dummy_agent, + mock_eval_sets_manager, + mocker, +): + eval_set = EvalSet( + eval_set_id="test_eval_set", + eval_cases=[ + EvalCase(eval_id="case1", conversation=[], session_input=None), + ], + ) + mock_eval_sets_manager.get_eval_set.return_value = eval_set + + mock_inference_result = mocker.MagicMock() + eval_service._perform_inference_single_eval_item = mocker.AsyncMock( + return_value=mock_inference_result + ) + + inference_request = InferenceRequest( + app_name="test_app", + eval_set_id="test_eval_set", + inference_config=InferenceConfig( + parallelism=1, use_live=True, live_timeout_seconds=600 + ), + ) + + results = [] + async for result in eval_service.perform_inference(inference_request): + results.append(result) + + assert len(results) == 1 + eval_service._perform_inference_single_eval_item.assert_called_once_with( + app_name="test_app", + eval_set_id="test_eval_set", + eval_case=eval_set.eval_cases[0], + root_agent=dummy_agent, + use_live=True, + live_timeout_seconds=600, ) @@ -791,3 +838,80 @@ def test_copy_invocation_rubrics_to_actual_invocations(): _copy_invocation_rubrics_to_actual_invocations(expected, actual) assert actual[0].rubrics == [rubric1] assert actual[1].rubrics == [rubric2] + + +@pytest.mark.asyncio +async def test_perform_inference_single_eval_item_live( + eval_service, dummy_agent, mocker +): + eval_case = EvalCase(eval_id="case1", conversation=[], session_input=None) + mock_generate_live = mocker.patch( + "google.adk.evaluation.evaluation_generator.EvaluationGenerator._generate_inferences_from_root_agent_live" + ) + mock_generate_live.return_value = [] + + eval_service._session_id_supplier = mocker.MagicMock( + return_value="test_session_id" + ) + mock_user_sim = mocker.MagicMock() + eval_service._user_simulator_provider.provide = mocker.MagicMock( + return_value=mock_user_sim + ) + + await eval_service._perform_inference_single_eval_item( + app_name="test_app", + eval_set_id="test_eval_set", + eval_case=eval_case, + root_agent=dummy_agent, + use_live=True, + live_timeout_seconds=600, + ) + + mock_generate_live.assert_called_once_with( + root_agent=dummy_agent, + user_simulator=mock_user_sim, + initial_session=None, + session_id="test_session_id", + session_service=eval_service._session_service, + artifact_service=eval_service._artifact_service, + memory_service=eval_service._memory_service, + live_timeout_seconds=600, + ) + + +@pytest.mark.asyncio +async def test_perform_inference_single_eval_item_non_live( + eval_service, dummy_agent, mocker +): + eval_case = EvalCase(eval_id="case1", conversation=[], session_input=None) + mock_generate = mocker.patch( + "google.adk.evaluation.evaluation_generator.EvaluationGenerator._generate_inferences_from_root_agent" + ) + mock_generate.return_value = [] + + eval_service._session_id_supplier = mocker.MagicMock( + return_value="test_session_id" + ) + mock_user_sim = mocker.MagicMock() + eval_service._user_simulator_provider.provide = mocker.MagicMock( + return_value=mock_user_sim + ) + + await eval_service._perform_inference_single_eval_item( + app_name="test_app", + eval_set_id="test_eval_set", + eval_case=eval_case, + root_agent=dummy_agent, + use_live=False, + live_timeout_seconds=300, + ) + + mock_generate.assert_called_once_with( + root_agent=dummy_agent, + user_simulator=mock_user_sim, + initial_session=None, + session_id="test_session_id", + session_service=eval_service._session_service, + artifact_service=eval_service._artifact_service, + memory_service=eval_service._memory_service, + ) From dc3cc2be7d1f3b6e438e8651327b01a840de85bd Mon Sep 17 00:00:00 2001 From: Raman369AI Date: Fri, 27 Mar 2026 00:36:28 -0500 Subject: [PATCH 27/28] fix: raise ValueError for unsupported MIME types in file_data URI path The file_data.file_uri path silently fell back to application/octet-stream when no MIME type could be determined, then passed it to LiteLLM which raised a cryptic internal ValueError. The inline_data path already had fail-fast behavior for unsupported types but the file_data path did not. This change removes the _DEFAULT_MIME_TYPE fallback and raises ValueError early with an actionable message for two cases: when no MIME type can be determined from the URI, display_name, or explicit field, and when the resolved type is application/octet-stream regardless of whether it was set by the caller or arrived via a library default. Both cases cause the same downstream failure. The logic order is also corrected so that providers which always produce a text fallback (anthropic, non-Gemini Vertex AI) and OpenAI/Azure HTTP media URLs are handled before the MIME type guard, keeping those paths unaffected. Tests are updated to assert the new ValueError and a new test covers the explicit application/octet-stream case. --- src/google/adk/models/lite_llm.py | 51 ++++++++++++-------- tests/unittests/models/test_litellm.py | 67 +++++++++++++------------- 2 files changed, 64 insertions(+), 54 deletions(-) diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index 3a6c36624d..5fa26261d7 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -234,10 +234,6 @@ def _get_provider_from_model(model: str) -> str: return "" -# Default MIME type when none can be inferred -_DEFAULT_MIME_TYPE = "application/octet-stream" - - def _infer_mime_type_from_uri(uri: str) -> Optional[str]: """Attempts to infer MIME type from a URI's path extension. @@ -1103,33 +1099,33 @@ async def _get_content( }) continue - # Determine MIME type: use explicit value, infer from URI, or use default. + # Resolve MIME type early: needed before the media-URL shortcut below, + # which must run before the generic text-fallback check. The raise is + # deferred until after all early-continue paths so that providers which + # always fall back to text (anthropic, non-Gemini Vertex AI) are never + # asked for a MIME type they cannot supply. mime_type = part.file_data.mime_type if not mime_type: mime_type = _infer_mime_type_from_uri(part.file_data.file_uri) if not mime_type and part.file_data.display_name: guessed_mime_type, _ = mimetypes.guess_type(part.file_data.display_name) mime_type = guessed_mime_type - if not mime_type: - # LiteLLM's Vertex AI backend requires format for GCS URIs. - mime_type = _DEFAULT_MIME_TYPE - logger.debug( - "Could not determine MIME type for file_uri %s, using default: %s", - part.file_data.file_uri, - mime_type, - ) - mime_type = _normalize_mime_type(mime_type) + if mime_type: + mime_type = _normalize_mime_type(mime_type) + # For OpenAI/Azure: HTTP media URLs (image, video, audio) are sent as + # typed URL blocks and must be handled before the generic text fallback. if provider in _FILE_ID_REQUIRED_PROVIDERS and _is_http_url( part.file_data.file_uri ): - url_content_type = _media_url_content_type(mime_type) - if url_content_type: - content_objects.append({ - "type": url_content_type, - url_content_type: {"url": part.file_data.file_uri}, - }) - continue + if mime_type: + url_content_type = _media_url_content_type(mime_type) + if url_content_type: + content_objects.append({ + "type": url_content_type, + url_content_type: {"url": part.file_data.file_uri}, + }) + continue if _requires_file_uri_fallback(provider, model, part.file_data.file_uri): logger.debug( @@ -1147,6 +1143,19 @@ async def _get_content( }) continue + # All remaining providers (e.g. Vertex AI + Gemini) require a specific + # MIME type in the file object. Both a missing type and + # 'application/octet-stream' cause a downstream ValueError from LiteLLM + # regardless of whether the value was set explicitly by the caller or + # arrived via a default fallback; raise early with an actionable message. + if not mime_type or mime_type == "application/octet-stream": + type_label = mime_type or "(unknown)" + raise ValueError( + f"Cannot process file_uri {part.file_data.file_uri!r}: MIME type" + f" {type_label!r} is not supported. Please set a specific MIME" + " type on `file_data.mime_type`." + ) + file_object: ChatCompletionFileUrlObject = { "file_id": part.file_data.file_uri, } diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index c195076349..3ebee8a2f4 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -1908,12 +1908,12 @@ async def test_content_to_message_param_user_message_file_uri_only( @pytest.mark.asyncio async def test_content_to_message_param_user_message_file_uri_without_mime_type(): - """Test handling of file_data without mime_type (GcsArtifactService scenario). + """Test that file_data without an inferable mime_type raises ValueError. When using GcsArtifactService, artifacts may have file_uri (gs://...) but - without mime_type set. LiteLLM's Vertex AI backend requires the format - field to be present, so we infer MIME type from the URI extension or use - a default fallback to ensure compatibility. + without mime_type set. When the MIME type cannot be determined from the URI + extension or display_name, ADK raises a clear ValueError rather than + forwarding an unsupported 'application/octet-stream' to LiteLLM. See: https://github.com/google/adk-python/issues/3787 """ @@ -1930,22 +1930,8 @@ async def test_content_to_message_param_user_message_file_uri_without_mime_type( ], ) - message = await _content_to_message_param(content) - assert message == { - "role": "user", - "content": [ - {"type": "text", "text": "Analyze this file."}, - { - "type": "file", - "file": { - "file_id": ( - "gs://agent-artifact-bucket/app/user/session/artifact/0" - ), - "format": "application/octet-stream", - }, - }, - ], - } + with pytest.raises(ValueError, match="Cannot process file_uri"): + await _content_to_message_param(content) @pytest.mark.asyncio @@ -3117,27 +3103,42 @@ async def test_get_content_file_uri_infers_from_display_name(): @pytest.mark.asyncio async def test_get_content_file_uri_default_mime_type(): - """Test that file_uri without extension uses default MIME type. + """Test that file_uri without an inferable extension raises ValueError. When file_data has a file_uri without a recognizable extension and no explicit - mime_type, a default MIME type should be used to ensure compatibility with - LiteLLM backends. + mime_type, ADK raises a clear ValueError instead of forwarding the unsupported + 'application/octet-stream' MIME type to LiteLLM. See: https://github.com/google/adk-python/issues/3787 """ - # Use Part constructor directly to create file_data without mime_type - # (types.Part.from_uri requires a valid mime_type when it can't infer) parts = [ types.Part(file_data=types.FileData(file_uri="gs://bucket/artifact/0")) ] - content = await _get_content(parts) - assert content[0] == { - "type": "file", - "file": { - "file_id": "gs://bucket/artifact/0", - "format": "application/octet-stream", - }, - } + with pytest.raises(ValueError, match="Cannot process file_uri"): + await _get_content(parts) + + +@pytest.mark.asyncio +async def test_get_content_file_uri_explicit_octet_stream_raises(): + """Test that an explicit application/octet-stream MIME type raises ValueError. + + 'application/octet-stream' is semantically equivalent to an unknown type and + causes the same downstream ValueError from LiteLLM whether it arrives as a + default fallback or is set explicitly by the caller. ADK raises early with + an actionable message in both cases. + + See: https://github.com/google/adk-python/issues/3787 + """ + parts = [ + types.Part( + file_data=types.FileData( + file_uri="gs://bucket/artifact/0", + mime_type="application/octet-stream", + ) + ) + ] + with pytest.raises(ValueError, match="application/octet-stream"): + await _get_content(parts) @pytest.mark.asyncio From c7eb917674a9bed14970215e6f2514afdd11a3b9 Mon Sep 17 00:00:00 2001 From: Raman369AI Date: Fri, 24 Apr 2026 23:57:09 -0500 Subject: [PATCH 28/28] test: cover both ValueError branches for file_uri MIME type guard Adds test_content_to_message_param_user_message_file_uri_explicit_octet_stream to confirm that an upstream caller passing mime_type='application/octet-stream' raises a clear ValueError, covering both branches of the combined guard. Fixes: https://github.com/google/adk-python/issues/5022 --- tests/unittests/models/test_litellm.py | 38 ++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index 3ebee8a2f4..c8bf5be010 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -1915,7 +1915,7 @@ async def test_content_to_message_param_user_message_file_uri_without_mime_type( extension or display_name, ADK raises a clear ValueError rather than forwarding an unsupported 'application/octet-stream' to LiteLLM. - See: https://github.com/google/adk-python/issues/3787 + See: https://github.com/google/adk-python/issues/5022 """ file_part = types.Part( file_data=types.FileData( @@ -1934,6 +1934,34 @@ async def test_content_to_message_param_user_message_file_uri_without_mime_type( await _content_to_message_param(content) +@pytest.mark.asyncio +async def test_content_to_message_param_user_message_file_uri_explicit_octet_stream(): + """Test that an explicit application/octet-stream MIME type raises ValueError. + + Upstream callers may explicitly set mime_type to 'application/octet-stream' + when the true type is unknown. ADK treats this identically to a missing MIME + type and raises early rather than forwarding the unsupported type to LiteLLM. + + See: https://github.com/google/adk-python/issues/5022 + """ + file_part = types.Part( + file_data=types.FileData( + file_uri="gs://agent-artifact-bucket/app/user/session/artifact/0", + mime_type="application/octet-stream", + ) + ) + content = types.Content( + role="user", + parts=[ + types.Part.from_text(text="Analyze this file."), + file_part, + ], + ) + + with pytest.raises(ValueError, match="application/octet-stream"): + await _content_to_message_param(content) + + @pytest.mark.asyncio async def test_content_to_message_param_user_message_file_uri_infer_mime_type(): """Test MIME type inference from file_uri extension. @@ -1941,7 +1969,7 @@ async def test_content_to_message_param_user_message_file_uri_infer_mime_type(): When file_data has a file_uri with a recognizable extension but no explicit mime_type, the MIME type should be inferred from the extension. - See: https://github.com/google/adk-python/issues/3787 + See: https://github.com/google/adk-python/issues/5022 """ file_part = types.Part( file_data=types.FileData( @@ -3053,7 +3081,7 @@ async def test_get_content_file_uri_infer_mime_type(): When file_data has a file_uri with a recognizable extension but no explicit mime_type, the MIME type should be inferred from the extension. - See: https://github.com/google/adk-python/issues/3787 + See: https://github.com/google/adk-python/issues/5022 """ # Use Part constructor directly to test MIME type inference in _get_content # (types.Part.from_uri does its own inference, so we bypass it) @@ -3109,7 +3137,7 @@ async def test_get_content_file_uri_default_mime_type(): mime_type, ADK raises a clear ValueError instead of forwarding the unsupported 'application/octet-stream' MIME type to LiteLLM. - See: https://github.com/google/adk-python/issues/3787 + See: https://github.com/google/adk-python/issues/5022 """ parts = [ types.Part(file_data=types.FileData(file_uri="gs://bucket/artifact/0")) @@ -3127,7 +3155,7 @@ async def test_get_content_file_uri_explicit_octet_stream_raises(): default fallback or is set explicitly by the caller. ADK raises early with an actionable message in both cases. - See: https://github.com/google/adk-python/issues/3787 + See: https://github.com/google/adk-python/issues/5022 """ parts = [ types.Part(