diff --git a/.github/workflows/gemini-dispatch.yml b/.github/workflows/gemini-dispatch.yml new file mode 100644 index 0000000000..9c2bf8ec9d --- /dev/null +++ b/.github/workflows/gemini-dispatch.yml @@ -0,0 +1,189 @@ +name: '🔀 Gemini Dispatch' + +on: + pull_request_review_comment: + types: + - 'created' + pull_request_review: + types: + - 'submitted' + issue_comment: + types: + - 'created' + +defaults: + run: + shell: 'bash' + +jobs: + debugger: + if: |- + ${{ fromJSON(vars.GEMINI_DEBUG || vars.ACTIONS_STEP_DEBUG || false) }} + runs-on: 'ubuntu-latest' + permissions: + contents: 'read' + steps: + - name: 'Print context for debugging' + env: + DEBUG_event_name: '${{ github.event_name }}' + DEBUG_event__action: '${{ github.event.action }}' + DEBUG_event__comment__author_association: '${{ github.event.comment.author_association }}' + DEBUG_event__issue__author_association: '${{ github.event.issue.author_association }}' + DEBUG_event__pull_request__author_association: '${{ github.event.pull_request.author_association }}' + DEBUG_event__review__author_association: '${{ github.event.review.author_association }}' + DEBUG_event: '${{ toJSON(github.event) }}' + run: |- + env | grep '^DEBUG_' + + dispatch: + # Only trigger if user types @gemini-cli and author association is OWNER, MEMBER, or COLLABORATOR + if: |- + github.event.sender.type == 'User' && + startsWith(github.event.comment.body || github.event.review.body, '@gemini-cli') && + contains(fromJSON('["OWNER", "MEMBER", "COLLABORATOR"]'), github.event.comment.author_association || github.event.review.author_association) + runs-on: 'ubuntu-latest' + permissions: + contents: 'read' + issues: 'write' + pull-requests: 'write' + outputs: + command: '${{ steps.extract_command.outputs.command }}' + request: '${{ steps.extract_command.outputs.request }}' + additional_context: '${{ steps.extract_command.outputs.additional_context }}' + issue_number: '${{ github.event.pull_request.number || github.event.issue.number }}' + 
steps: + - name: 'Mint identity token' + id: 'mint_identity_token' + if: |- + ${{ vars.APP_ID }} + uses: 'actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf' # ratchet:actions/create-github-app-token@v2 + with: + app-id: '${{ vars.APP_ID }}' + private-key: '${{ secrets.APP_PRIVATE_KEY }}' + permission-contents: 'read' + permission-issues: 'write' + permission-pull-requests: 'write' + + - name: 'Extract command' + id: 'extract_command' + uses: 'actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd' # ratchet:actions/github-script@v8.0.0 + env: + REQUEST: '${{ github.event.comment.body || github.event.review.body }}' + IS_PR: '${{ !!(github.event.pull_request || github.event.issue.pull_request) }}' + with: + script: | + const request = process.env.REQUEST; + const isPr = process.env.IS_PR === 'true'; + core.setOutput('request', request); + + // Ensure request is on a PR targeting the main branch + let baseRef = ''; + if (context.eventName === 'pull_request_review' || context.eventName === 'pull_request_review_comment') { + baseRef = context.payload.pull_request.base.ref; + } else if (context.eventName === 'issue_comment' && context.payload.issue.pull_request) { + const pr = await github.rest.pulls.get({ + owner: context.repo.owner, + repo: context.repo.repo, + pull_number: context.payload.issue.number + }); + baseRef = pr.data.base.ref; + } + + if (isPr && baseRef !== 'main') { + console.log(`Skipping: PR targets '${baseRef}', but only 'main' is allowed.`); + core.setOutput('command', 'fallthrough'); + return; + } + + if (request.startsWith("@gemini-cli /review")) { + if (isPr) { + core.setOutput('command', 'review'); + const additionalContext = request.replace(/^@gemini-cli \/review/, '').trim(); + core.setOutput('additional_context', additionalContext); + } else { + core.setOutput('command', 'fallthrough'); + } + } else if (request.startsWith("@gemini-cli")) { + const additionalContext = request.replace(/^@gemini-cli/, '').trim(); + 
core.setOutput('command', 'invoke'); + core.setOutput('additional_context', additionalContext); + } else { + core.setOutput('command', 'fallthrough'); + } + + - name: 'Acknowledge request' + env: + GITHUB_TOKEN: '${{ steps.mint_identity_token.outputs.token || secrets.GITHUB_TOKEN || github.token }}' + ISSUE_NUMBER: '${{ github.event.pull_request.number || github.event.issue.number }}' + MESSAGE: |- + 🤖 Hi @${{ github.actor }}, I've received your request, and I'm working on it now! You can track my progress [in the logs](${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}) for more details. + REPOSITORY: '${{ github.repository }}' + run: |- + gh issue comment "${ISSUE_NUMBER}" \ + --body "${MESSAGE}" \ + --repo "${REPOSITORY}" + + review: + needs: 'dispatch' + if: |- + ${{ needs.dispatch.outputs.command == 'review' }} + uses: './.github/workflows/gemini-review.yml' + permissions: + contents: 'read' + id-token: 'write' + issues: 'write' + pull-requests: 'write' + with: + additional_context: '${{ needs.dispatch.outputs.additional_context }}' + secrets: 'inherit' + + invoke: + needs: 'dispatch' + if: |- + ${{ needs.dispatch.outputs.command == 'invoke' }} + uses: './.github/workflows/gemini-invoke.yml' + permissions: + contents: 'read' + id-token: 'write' + issues: 'write' + pull-requests: 'write' + with: + additional_context: '${{ needs.dispatch.outputs.additional_context }}' + secrets: 'inherit' + + fallthrough: + needs: + - 'dispatch' + - 'review' + - 'invoke' + if: |- + ${{ always() && !cancelled() && (failure() || needs.dispatch.outputs.command == 'fallthrough') }} + runs-on: 'ubuntu-latest' + permissions: + contents: 'read' + issues: 'write' + pull-requests: 'write' + steps: + - name: 'Mint identity token' + id: 'mint_identity_token' + if: |- + ${{ vars.APP_ID }} + uses: 'actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf' # ratchet:actions/create-github-app-token@v2 + with: + app-id: '${{ vars.APP_ID }}' + 
private-key: '${{ secrets.APP_PRIVATE_KEY }}' + permission-contents: 'read' + permission-issues: 'write' + permission-pull-requests: 'write' + + - name: 'Send failure comment' + env: + GITHUB_TOKEN: '${{ steps.mint_identity_token.outputs.token || secrets.GITHUB_TOKEN || github.token }}' + ISSUE_NUMBER: '${{ github.event.pull_request.number || github.event.issue.number }}' + MESSAGE: |- + 🤖 I'm sorry @${{ github.actor }}, but I was unable to process your request. Please [see the logs](${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}) for more details. + REPOSITORY: '${{ github.repository }}' + run: |- + gh issue comment "${ISSUE_NUMBER}" \ + --body "${MESSAGE}" \ + --repo "${REPOSITORY}" diff --git a/.github/workflows/gemini-invoke.yml b/.github/workflows/gemini-invoke.yml new file mode 100644 index 0000000000..5138d6f729 --- /dev/null +++ b/.github/workflows/gemini-invoke.yml @@ -0,0 +1,104 @@ +name: '▶️ Gemini Invoke' + +on: + workflow_call: + inputs: + additional_context: + type: 'string' + description: 'Any additional context from the request' + required: false + +concurrency: + group: '${{ github.workflow }}-invoke-${{ github.event_name }}-${{ github.event.pull_request.number || github.event.issue.number }}' + cancel-in-progress: false + +defaults: + run: + shell: 'bash' + +jobs: + invoke: + runs-on: 'ubuntu-latest' + permissions: + contents: 'read' + id-token: 'write' + issues: 'write' + pull-requests: 'write' + steps: + - name: 'Mint identity token' + id: 'mint_identity_token' + if: |- + ${{ vars.APP_ID }} + uses: 'actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf' # ratchet:actions/create-github-app-token@v2 + with: + app-id: '${{ vars.APP_ID }}' + private-key: '${{ secrets.APP_PRIVATE_KEY }}' + permission-contents: 'read' + permission-issues: 'write' + permission-pull-requests: 'write' + + - name: 'Checkout Code' + uses: 'actions/checkout@v4' # ratchet:exclude + + - name: 'Run Gemini CLI' + id: 
'run_gemini' + uses: 'google-github-actions/run-gemini-cli@v0' # ratchet:exclude + env: + TITLE: '${{ github.event.pull_request.title || github.event.issue.title }}' + DESCRIPTION: '${{ github.event.pull_request.body || github.event.issue.body }}' + EVENT_NAME: '${{ github.event_name }}' + GITHUB_TOKEN: '${{ steps.mint_identity_token.outputs.token || secrets.GITHUB_TOKEN || github.token }}' + IS_PULL_REQUEST: '${{ !!github.event.pull_request }}' + ISSUE_NUMBER: '${{ github.event.pull_request.number || github.event.issue.number }}' + REPOSITORY: '${{ github.repository }}' + ADDITIONAL_CONTEXT: '${{ inputs.additional_context }}' + # Required to allow the Gemini CLI to process files in the ephemeral GitHub Actions runner + GEMINI_CLI_TRUST_WORKSPACE: 'true' + with: + gcp_location: '${{ vars.GOOGLE_CLOUD_LOCATION }}' + gcp_project_id: '${{ vars.GOOGLE_CLOUD_PROJECT }}' + gcp_service_account: '${{ vars.SERVICE_ACCOUNT_EMAIL }}' + gcp_workload_identity_provider: '${{ vars.GCP_WIF_PROVIDER }}' + gemini_api_key: '${{ secrets.GOOGLE_API_KEY }}' + gemini_cli_version: '${{ vars.GEMINI_CLI_VERSION }}' + gemini_debug: '${{ fromJSON(vars.GEMINI_DEBUG || vars.ACTIONS_STEP_DEBUG || false) }}' + gemini_model: '${{ vars.GEMINI_MODEL }}' + google_api_key: '${{ secrets.GOOGLE_API_KEY }}' + use_gemini_code_assist: '${{ vars.GOOGLE_GENAI_USE_GCA }}' + use_vertex_ai: '${{ vars.GOOGLE_GENAI_USE_VERTEXAI }}' + upload_artifacts: '${{ vars.UPLOAD_ARTIFACTS }}' + workflow_name: 'gemini-invoke' + # Assistant workflows can be triggered by comments on either Issues or PRs. + # We explicitly map both fields so the CLI can correctly categorize the interaction. 
+ github_pr_number: '${{ github.event.pull_request.number }}' + github_issue_number: '${{ github.event.issue.number }}' + settings: |- + { + "model": { + "maxSessionTurns": 25 + }, + "telemetry": { + "enabled": true, + "target": "local", + "outfile": ".gemini/telemetry.log" + }, + "mcpServers": { + "github": { + "command": "docker", + "args": [ + "run", + "-i", + "--rm", + "-e", + "GITHUB_PERSONAL_ACCESS_TOKEN", + "ghcr.io/github/github-mcp-server:v0.27.0" + ], + "env": { + "GITHUB_PERSONAL_ACCESS_TOKEN": "${{ steps.mint_identity_token.outputs.token || secrets.GITHUB_TOKEN || github.token }}" + } + } + } + } + prompt: |- + /gemini-invoke + [IMPORTANT] Do not generate execution plans and do not ask for approval (such as suggesting `@gemini-cli /approve`). Perform the requested task or answer the question directly and immediately. diff --git a/.github/workflows/gemini-review.yml b/.github/workflows/gemini-review.yml new file mode 100644 index 0000000000..9c1b1bf442 --- /dev/null +++ b/.github/workflows/gemini-review.yml @@ -0,0 +1,100 @@ +name: '🔎 Gemini Review' + +on: + workflow_call: + inputs: + additional_context: + type: 'string' + description: 'Any additional context from the request' + required: false + +concurrency: + group: '${{ github.workflow }}-review-${{ github.event_name }}-${{ github.event.pull_request.number || github.event.issue.number }}' + cancel-in-progress: true + +defaults: + run: + shell: 'bash' + +jobs: + review: + runs-on: 'ubuntu-latest' + timeout-minutes: 7 + permissions: + contents: 'read' + id-token: 'write' + issues: 'write' + pull-requests: 'write' + steps: + - name: 'Mint identity token' + id: 'mint_identity_token' + if: |- + ${{ vars.APP_ID }} + uses: 'actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf' # ratchet:actions/create-github-app-token@v2 + with: + app-id: '${{ vars.APP_ID }}' + private-key: '${{ secrets.APP_PRIVATE_KEY }}' + permission-contents: 'read' + permission-issues: 'write' + 
permission-pull-requests: 'write' + + - name: 'Checkout repository' + uses: 'actions/checkout@v4' # ratchet:exclude + + - name: 'Run Gemini pull request review' + uses: 'google-github-actions/run-gemini-cli@v0' # ratchet:exclude + id: 'gemini_pr_review' + env: + GITHUB_TOKEN: '${{ steps.mint_identity_token.outputs.token || secrets.GITHUB_TOKEN || github.token }}' + ISSUE_TITLE: '${{ github.event.pull_request.title || github.event.issue.title }}' + ISSUE_BODY: '${{ github.event.pull_request.body || github.event.issue.body }}' + PULL_REQUEST_NUMBER: '${{ github.event.pull_request.number || github.event.issue.number }}' + REPOSITORY: '${{ github.repository }}' + ADDITIONAL_CONTEXT: '${{ inputs.additional_context }}' + GEMINI_API_KEY: '${{ secrets.GOOGLE_API_KEY }}' + # Required to allow the Gemini CLI to process files in the ephemeral GitHub Actions runner + GEMINI_CLI_TRUST_WORKSPACE: 'true' + with: + gcp_location: '${{ vars.GOOGLE_CLOUD_LOCATION }}' + gcp_project_id: '${{ vars.GOOGLE_CLOUD_PROJECT }}' + gcp_service_account: '${{ vars.SERVICE_ACCOUNT_EMAIL }}' + gcp_workload_identity_provider: '${{ vars.GCP_WIF_PROVIDER }}' + gemini_api_key: '${{ secrets.GOOGLE_API_KEY }}' + gemini_cli_version: '${{ vars.GEMINI_CLI_VERSION }}' + gemini_debug: '${{ fromJSON(vars.GEMINI_DEBUG || vars.ACTIONS_STEP_DEBUG || false) }}' + gemini_model: '${{ vars.GEMINI_MODEL }}' + google_api_key: '${{ secrets.GOOGLE_API_KEY }}' + use_gemini_code_assist: '${{ vars.GOOGLE_GENAI_USE_GCA }}' + use_vertex_ai: '${{ vars.GOOGLE_GENAI_USE_VERTEXAI }}' + upload_artifacts: '${{ vars.UPLOAD_ARTIFACTS }}' + workflow_name: 'gemini-review' + # Explicitly set the PR number to handle `issue_comment` triggers (which GitHub treats as issues, not PRs) + github_pr_number: '${{ github.event.pull_request.number || github.event.issue.number }}' + settings: |- + { + "model": { + "maxSessionTurns": 25 + }, + "telemetry": { + "enabled": true, + "target": "local", + "outfile": ".gemini/telemetry.log" + }, + 
"mcpServers": { + "github": { + "command": "docker", + "args": [ + "run", + "-i", + "--rm", + "-e", + "GITHUB_PERSONAL_ACCESS_TOKEN", + "ghcr.io/github/github-mcp-server:v0.27.0" + ], + "env": { + "GITHUB_PERSONAL_ACCESS_TOKEN": "${{ steps.mint_identity_token.outputs.token || secrets.GITHUB_TOKEN || github.token }}" + } + } + } + } + prompt: 'Please use the pull_request_read tool to read pull request #${{ github.event.pull_request.number || github.event.issue.number }}. Analyze the code for bugs, security issues, and best practices. Then, use the add_comment_to_pending_review and pull_request_review_write tools to post your review directly on pull request #${{ github.event.pull_request.number || github.event.issue.number }}.' diff --git a/contributing/samples/gcp_skill_registry_agent/__init__.py b/contributing/samples/gcp_skill_registry_agent/__init__.py new file mode 100644 index 0000000000..4015e47d6e --- /dev/null +++ b/contributing/samples/gcp_skill_registry_agent/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . 
import agent diff --git a/contributing/samples/gcp_skill_registry_agent/agent.py b/contributing/samples/gcp_skill_registry_agent/agent.py new file mode 100644 index 0000000000..558b6ecfff --- /dev/null +++ b/contributing/samples/gcp_skill_registry_agent/agent.py @@ -0,0 +1,40 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sample agent demonstrating the use of GCPSkillRegistry.""" + +from google.adk import Agent +from google.adk.integrations.skill_registry import GCPSkillRegistry +from google.adk.tools.skill_toolset import SkillToolset + +# Initialize GCP Skill Registry +registry = GCPSkillRegistry( + project_id="your-project-id", location="us-central1" +) + +# Initialize SkillToolset with registry +skill_toolset = SkillToolset(skills=[], registry=registry) + +root_agent = Agent( + model="gemini-2.5-flash", + name="skill_registry_agent", + description=( + "An agent that can discover and load skills from GCP Skill Registry." + ), + instruction=( + "Use search_skills to find skills and load_skill to load them if" + " needed." 
+ ), + tools=[skill_toolset], +) diff --git a/pyproject.toml b/pyproject.toml index d7bacd8a10..10975105ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -166,7 +166,7 @@ scripts.adk = "google.adk.cli:main" [tool.flit.sdist] include = [ 'src/**/*', 'README.md', 'pyproject.toml', 'LICENSE' ] -exclude = [ 'src/**/*.sh' ] +exclude = [ 'src/**/*.sh', 'src/**/README.md' ] [tool.flit.module] name = "google.adk" diff --git a/src/google/adk/a2a/converters/to_adk_event.py b/src/google/adk/a2a/converters/to_adk_event.py index 26ae95e1b4..eab89c20f5 100644 --- a/src/google/adk/a2a/converters/to_adk_event.py +++ b/src/google/adk/a2a/converters/to_adk_event.py @@ -43,6 +43,10 @@ # Logger logger = logging.getLogger("google_adk." + __name__) +MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT = ( + "mock_function_call_for_required_user_input" +) + A2AMessageToEventConverter = Callable[ [ Message, @@ -276,6 +280,36 @@ def _merge_event_actions( return EventActions.model_validate(merged_actions_data) +def _create_mock_function_call_for_required_user_input( + state: TaskState, + output_parts: list[genai_types.Part], + long_running_function_ids: set[str], +) -> tuple[list[genai_types.Part], set[str]]: + """Creates a mock function call for input/auth-required if applicable. + + This solution allows to unblock the A2A integration with non-ADK agents from + ADK side by replacing the last text part with a synthetic function call. All + other parts are preserved. + """ + if ( + state == TaskState.input_required or state == TaskState.auth_required + ) and (not long_running_function_ids or len(long_running_function_ids) == 0): + # Find the last text part from the bottom to replace it with a function call. + # In case of input-required events, the LLM should stop the production of other parts. 
+ for i in range(len(output_parts) - 1, -1, -1): + if output_parts[i].text: + function_call = genai_types.FunctionCall( + id=str(uuid.uuid4()), + name=MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT, + args={"input_required": output_parts[i].text}, + ) + long_running_function_ids = set() + long_running_function_ids.add(function_call.id) + output_parts[i] = genai_types.Part(function_call=function_call) + break + return output_parts, long_running_function_ids + + @a2a_experimental def convert_a2a_task_to_event( a2a_task: Task, @@ -317,9 +351,9 @@ def convert_a2a_task_to_event( output_parts, _ = _convert_a2a_parts_to_adk_parts( artifact_parts, part_converter ) - if ( - a2a_task.status.message - and a2a_task.status.state == TaskState.input_required + if a2a_task.status.message and ( + a2a_task.status.state == TaskState.input_required + or a2a_task.status.state == TaskState.auth_required ): event_actions = _merge_event_actions( event_actions, @@ -331,6 +365,12 @@ def convert_a2a_task_to_event( output_parts.extend(parts) long_running_function_ids.update(ids) + output_parts, long_running_function_ids = ( + _create_mock_function_call_for_required_user_input( + a2a_task.status.state, output_parts, long_running_function_ids + ) + ) + return _create_event( output_parts, invocation_context, @@ -422,6 +462,14 @@ def convert_a2a_status_update_to_event( output_parts.extend(parts) long_running_function_ids.update(ids) + output_parts, long_running_function_ids = ( + _create_mock_function_call_for_required_user_input( + a2a_status_update.status.state, + output_parts, + long_running_function_ids, + ) + ) + return _create_event( output_parts, invocation_context, diff --git a/src/google/adk/a2a/utils/agent_to_a2a.py b/src/google/adk/a2a/utils/agent_to_a2a.py index 2744477e3c..48454a632a 100644 --- a/src/google/adk/a2a/utils/agent_to_a2a.py +++ b/src/google/adk/a2a/utils/agent_to_a2a.py @@ -35,6 +35,7 @@ from ...runners import Runner from ...sessions.in_memory_session_service import 
InMemorySessionService from ..executor.a2a_agent_executor import A2aAgentExecutor +from ..executor.config import A2aAgentExecutorConfig from ..experimental import a2a_experimental from .agent_card_builder import AgentCardBuilder @@ -86,6 +87,7 @@ def to_a2a( task_store: TaskStore | None = None, runner: Runner | None = None, lifespan: Callable[[Starlette], AsyncIterator[None]] | None = None, + agent_executor_factory: Callable[[Runner], A2aAgentExecutor] | None = None, ) -> Starlette: """Convert an ADK agent to a A2A Starlette application. @@ -95,20 +97,21 @@ def to_a2a( port: The port for the A2A RPC URL (default: 8000) protocol: The protocol for the A2A RPC URL (default: "http") agent_card: Optional pre-built AgentCard object or path to agent card - JSON. If not provided, will be built automatically from the - agent. + JSON. If not provided, will be built automatically from the agent. push_config_store: Optional A2A push notification config store. If not - provided, an in-memory store will be created so push-notification - config RPC methods are supported. + provided, an in-memory store will be created so push-notification config + RPC methods are supported. task_store: Optional A2A task store for persisting task state. If not provided, an in-memory store will be created. runner: Optional pre-built Runner object. If not provided, a default - runner will be created using in-memory services. - lifespan: Optional async context manager for Starlette lifespan - events. Use this to run startup/shutdown logic (e.g. initializing - database connections or loading resources). The context manager - receives the Starlette app instance and can set state on - ``app.state``. + runner will be created using in-memory services. + lifespan: Optional async context manager for Starlette lifespan events. + Use this to run startup/shutdown logic (e.g. initializing database + connections or loading resources). 
The context manager receives the + Starlette app instance and can set state on ``app.state``. + agent_executor_factory: Optional factory function that creates an instance + of A2aAgentExecutor. If not provided, a default A2aAgentExecutor will be + created. Returns: A Starlette application that can be run with uvicorn @@ -148,7 +151,7 @@ async def lifespan(app): adk_logger = logging.getLogger("google_adk") adk_logger.setLevel(logging.INFO) - async def create_runner() -> Runner: + def create_runner() -> Runner: """Create a runner for the agent.""" return Runner( app_name=agent.name or "adk_agent", @@ -164,8 +167,10 @@ async def create_runner() -> Runner: if task_store is None: task_store = InMemoryTaskStore() - agent_executor = A2aAgentExecutor( - runner=runner or create_runner, + agent_executor = ( + agent_executor_factory(runner or create_runner()) + if agent_executor_factory is not None + else A2aAgentExecutor(runner=runner or create_runner) ) if push_config_store is None: diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py index 7fdbaee89b..8dc1a01af2 100644 --- a/src/google/adk/agents/invocation_context.py +++ b/src/google/adk/agents/invocation_context.py @@ -25,8 +25,8 @@ from pydantic import Field from pydantic import PrivateAttr -from ..apps.app import EventsCompactionConfig -from ..apps.app import ResumabilityConfig +from ..apps._configs import EventsCompactionConfig +from ..apps._configs import ResumabilityConfig from ..artifacts.base_artifact_service import BaseArtifactService from ..auth.auth_credential import AuthCredential from ..auth.credential_service.base_credential_service import BaseCredentialService diff --git a/src/google/adk/agents/remote_a2a_agent.py b/src/google/adk/agents/remote_a2a_agent.py index 55f57aff05..495a715d76 100644 --- a/src/google/adk/agents/remote_a2a_agent.py +++ b/src/google/adk/agents/remote_a2a_agent.py @@ -65,6 +65,8 @@ from ..a2a.converters.part_converter import 
convert_a2a_part_to_genai_part from ..a2a.converters.part_converter import convert_genai_part_to_a2a_part from ..a2a.converters.part_converter import GenAIPartToA2APartConverter +from ..a2a.converters.to_adk_event import _create_mock_function_call_for_required_user_input +from ..a2a.converters.to_adk_event import MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT from ..a2a.converters.utils import _get_adk_metadata_key from ..a2a.experimental import a2a_experimental from ..a2a.logs.log_utils import build_a2a_request_log @@ -105,6 +107,22 @@ class A2AClientError(Exception): pass +def _add_mock_function_call(event: Event, state: TaskState) -> None: + """Generates a mock function call for input-required events if applicable.""" + if event.content is None: + return + + output_parts, long_running_tool_ids = ( + _create_mock_function_call_for_required_user_input( + state, + event.content.parts, + event.long_running_tool_ids, + ) + ) + event.content.parts = output_parts + event.long_running_tool_ids = long_running_tool_ids + + @a2a_experimental class RemoteA2aAgent(BaseAgent): """Agent that communicates with a remote A2A agent via A2A client. @@ -360,8 +378,40 @@ def _create_a2a_request_for_user_function_response( if not function_call_event: return None + event = ctx.session.events[-1] + # If the user function_response replies to a function_call for non-ADK + # input-required events (fc.name = MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT), + # the function_response part is replaced with text extracted from the + # function response. + # The implementation is based on the assumption that the user function_response + # event will contain a function_response with the name + # MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT and the response will + # contain a "result" field with the user input as a string text. 
+ mock_function_call = [ + fc + for fc in function_call_event.get_function_calls() + if fc.name == MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT + ] + if mock_function_call: + new_parts = [] + for function_response in event.get_function_responses(): + if ( + function_response.name == MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT + and function_response.response + and "result" in function_response.response + ): + text_value = function_response.response.get("result") + new_parts.append( + genai_types.Part( + text=str(text_value), + ) + ) + new_event = event.model_copy(deep=True) + new_event.content.parts = new_parts + event = new_event + a2a_message = convert_event_to_a2a_message( - ctx.session.events[-1], ctx, Role.user, self._genai_part_converter + event, ctx, Role.user, self._genai_part_converter ) if function_call_event.custom_metadata: metadata = function_call_event.custom_metadata @@ -472,6 +522,7 @@ async def _handle_a2a_response( ): for part in event.content.parts: part.thought = True + _add_mock_function_call(event, task.status.state) elif ( isinstance(update, A2ATaskStatusUpdateEvent) and update.status @@ -487,6 +538,7 @@ async def _handle_a2a_response( ): for part in event.content.parts: part.thought = True + _add_mock_function_call(event, update.status.state) elif isinstance(update, A2ATaskArtifactUpdateEvent) and ( not update.append or update.last_chunk ): diff --git a/src/google/adk/apps/__init__.py b/src/google/adk/apps/__init__.py index 3a5d0b0643..319293967b 100644 --- a/src/google/adk/apps/__init__.py +++ b/src/google/adk/apps/__init__.py @@ -12,10 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. 
def __getattr__(name: str):
  """Resolves a lazily exported package attribute on access.

  Looks `name` up in `_LAZY_MEMBERS`, imports the mapped submodule of this
  package, and returns the attribute of the same name from that submodule.

  Raises:
    AttributeError: If `name` is not listed in `_LAZY_MEMBERS`.
  """
  submodule = _LAZY_MEMBERS.get(name)
  if submodule is None:
    raise AttributeError(f'module {__name__!r} has no attribute {name!r}')
  module = importlib.import_module(f'{__name__}.{submodule}')
  return vars(module)[name]
@experimental
class EventsCompactionConfig(BaseModel):
  """Configures event compaction for an application.

  `token_threshold` and `event_retention_size` form a pair: validation fails
  unless both are provided or both are left as `None`.
  """

  model_config = ConfigDict(
      extra="forbid",
      arbitrary_types_allowed=True,
  )

  summarizer: Optional[BaseEventsSummarizer] = None
  """Summarizer used to compact events."""

  compaction_interval: int
  """How many *new* user-initiated invocations must be fully represented in
  the session's events before a compaction is triggered."""

  overlap_size: int
  """How many preceding invocations, counted from the end of the last
  compacted range, are included again. The overlap keeps context across
  consecutive compacted summaries."""

  token_threshold: Optional[int] = Field(default=None, gt=0)
  """Token threshold for the post-invocation trigger.

  When set, ADK attempts a post-invocation compaction once the most recently
  observed prompt token count meets or exceeds this value.
  """

  event_retention_size: Optional[int] = Field(default=None, ge=0)
  """Number of raw events retained by post-invocation compaction.

  When token-based post-invocation compaction is triggered, the last N raw
  events are kept un-compacted.
  """

  @model_validator(mode="after")
  def _validate_token_params(self) -> EventsCompactionConfig:
    """Rejects configs where only one of the token-based options is set."""
    threshold_missing = self.token_threshold is None
    retention_missing = self.event_retention_size is None
    if threshold_missing is not retention_missing:
      raise ValueError(
          "token_threshold and event_retention_size must be set together."
      )
    return self
- """ - - -@experimental -class EventsCompactionConfig(BaseModel): - """The config of event compaction for an application.""" - - model_config = ConfigDict( - arbitrary_types_allowed=True, - extra="forbid", - ) - - summarizer: Optional[BaseEventsSummarizer] = None - """The event summarizer to use for compaction.""" - - compaction_interval: int - """The number of *new* user-initiated invocations that, once - fully represented in the session's events, will trigger a compaction.""" - - overlap_size: int - """The number of preceding invocations to include from the - end of the last compacted range. This creates an overlap between consecutive - compacted summaries, maintaining context.""" - - token_threshold: Optional[int] = Field( - default=None, - gt=0, - ) - """Post-invocation token threshold trigger. - - If set, ADK will attempt a post-invocation compaction when the most recently - observed prompt token count meets or exceeds this threshold. - """ - - event_retention_size: Optional[int] = Field(default=None, ge=0) - """Post-invocation raw event retention size. - - If token-based post-invocation compaction is triggered, this keeps the last N - raw events un-compacted. - """ - - @model_validator(mode="after") - def _validate_token_params(self) -> EventsCompactionConfig: - token_threshold_set = self.token_threshold is not None - retention_size_set = self.event_retention_size is not None - if token_threshold_set != retention_size_set: - raise ValueError( - "token_threshold and event_retention_size must be set together." - ) - return self - - class App(BaseModel): """Represents an LLM-backed agentic application. 
diff --git a/src/google/adk/apps/compaction.py b/src/google/adk/apps/compaction.py index b3b49bbd0b..1e33003a95 100644 --- a/src/google/adk/apps/compaction.py +++ b/src/google/adk/apps/compaction.py @@ -270,6 +270,9 @@ def _events_to_compact_for_token_threshold( events_to_compact = _truncate_events_before_pending_function_call( events_to_compact, pending_ids ) + events_to_compact = _truncate_events_before_hitl_signal( + events_to_compact, _resolved_hitl_call_ids(events) + ) if not events_to_compact: return [] @@ -344,6 +347,45 @@ def _truncate_events_before_pending_function_call( return events +def _resolved_hitl_call_ids(events: list[Event]) -> set[str]: + """Returns HITL call ids resolved by a later function_response in `events`.""" + hitl_position: dict[str, int] = {} + resolved: set[str] = set() + for index, event in enumerate(events): + if event.actions: + for call_id in event.actions.requested_tool_confirmations: + hitl_position.setdefault(call_id, index) + for call_id in event.actions.requested_auth_configs: + hitl_position.setdefault(call_id, index) + for resp_id in _event_function_response_ids(event): + hitl_pos = hitl_position.get(resp_id) + if hitl_pos is not None and index > hitl_pos: + resolved.add(resp_id) + return resolved + + +def _is_pending_hitl(event: Event, resolved_call_ids: set[str]) -> bool: + """Returns True if the event has an HITL request not in `resolved_call_ids`.""" + if not event.actions: + return False + requested = set(event.actions.requested_tool_confirmations) | set( + event.actions.requested_auth_configs + ) + if not requested: + return False + return bool(requested - resolved_call_ids) + + +def _truncate_events_before_hitl_signal( + events: list[Event], resolved_call_ids: set[str] +) -> list[Event]: + """Returns the leading contiguous events before any pending HITL request.""" + for index, event in enumerate(events): + if _is_pending_hitl(event, resolved_call_ids): + return events[:index] + return events + + def 
_safe_token_compaction_split_index( *, candidate_events: list[Event], @@ -631,6 +673,9 @@ async def _run_compaction_for_sliding_window( events_to_compact = _truncate_events_before_pending_function_call( events_to_compact, pending_ids ) + events_to_compact = _truncate_events_before_hitl_signal( + events_to_compact, _resolved_hitl_call_ids(events) + ) if not events_to_compact: return None diff --git a/src/google/adk/artifacts/__init__.py b/src/google/adk/artifacts/__init__.py index 5e56ffc737..af7912e617 100644 --- a/src/google/adk/artifacts/__init__.py +++ b/src/google/adk/artifacts/__init__.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +import importlib +from typing import TYPE_CHECKING + from .base_artifact_service import BaseArtifactService -from .file_artifact_service import FileArtifactService -from .gcs_artifact_service import GcsArtifactService -from .in_memory_artifact_service import InMemoryArtifactService + +if TYPE_CHECKING: + from .file_artifact_service import FileArtifactService + from .gcs_artifact_service import GcsArtifactService + from .in_memory_artifact_service import InMemoryArtifactService __all__ = [ 'BaseArtifactService', @@ -23,3 +30,16 @@ 'GcsArtifactService', 'InMemoryArtifactService', ] + +_LAZY_MEMBERS: dict[str, str] = { + 'FileArtifactService': 'file_artifact_service', + 'GcsArtifactService': 'gcs_artifact_service', + 'InMemoryArtifactService': 'in_memory_artifact_service', +} + + +def __getattr__(name: str): + if name in _LAZY_MEMBERS: + module = importlib.import_module(f'{__name__}.{_LAZY_MEMBERS[name]}') + return vars(module)[name] + raise AttributeError(f'module {__name__!r} has no attribute {name!r}') diff --git a/src/google/adk/cli/cli_deploy.py b/src/google/adk/cli/cli_deploy.py index 19a7dde9ae..4350ae17ce 100644 --- a/src/google/adk/cli/cli_deploy.py +++ b/src/google/adk/cli/cli_deploy.py @@ -598,18 +598,12 @@ 
def _get_service_option_by_adk_version( parsed_version = parse(adk_version) options: list[str] = [] - if parsed_version >= parse('1.3.0'): - if session_uri: - options.append(f'--session_service_uri={session_uri}') - if artifact_uri: - options.append(f'--artifact_service_uri={artifact_uri}') - if memory_uri: - options.append(f'--memory_service_uri={memory_uri}') - else: - if session_uri: - options.append(f'--session_db_url={session_uri}') - if parsed_version >= parse('1.2.0') and artifact_uri: - options.append(f'--artifact_storage_uri={artifact_uri}') + if session_uri: + options.append(f'--session_service_uri={session_uri}') + if artifact_uri: + options.append(f'--artifact_service_uri={artifact_uri}') + if memory_uri: + options.append(f'--memory_service_uri={memory_uri}') if use_local_storage is not None and parsed_version >= parse( _LOCAL_STORAGE_FLAG_MIN_VERSION diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 07ccc15892..fda251da10 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -1411,43 +1411,6 @@ def _deprecate_staging_bucket(ctx, param, value): return value -def deprecated_adk_services_options(): - """Deprecated ADK services options.""" - - def warn(alternative_param, ctx, param, value): - if value: - click.echo( - click.style( - f"WARNING: Deprecated option --{param.name} is used. Please use" - f" {alternative_param} instead.", - fg="yellow", - ), - err=True, - ) - return value - - def decorator(func): - @click.option( - "--session_db_url", - help="Deprecated. Use --session_service_uri instead.", - callback=functools.partial(warn, "--session_service_uri"), - ) - @click.option( - "--artifact_storage_uri", - type=str, - help="Deprecated. 
Use --artifact_service_uri instead.", - callback=functools.partial(warn, "--artifact_service_uri"), - default=None, - ) - @functools.wraps(func) - def wrapper(*args, **kwargs): - return func(*args, **kwargs) - - return wrapper - - return decorator - - def fast_api_common_options(): """Decorator to add common fast api options to click commands.""" @@ -1598,7 +1561,6 @@ def wrapper(ctx, *args, **kwargs): @fast_api_common_options() @web_options() @adk_services_options(default_use_local_storage=True) -@deprecated_adk_services_options() @click.argument( "agents_dir", type=click.Path( @@ -1621,8 +1583,6 @@ def cli_web( artifact_service_uri: Optional[str] = None, memory_service_uri: Optional[str] = None, use_local_storage: bool = True, - session_db_url: Optional[str] = None, # Deprecated - artifact_storage_uri: Optional[str] = None, # Deprecated a2a: bool = False, reload_agents: bool = False, extra_plugins: Optional[list[str]] = None, @@ -1639,8 +1599,6 @@ def cli_web( adk web --session_service_uri=[uri] --port=[port] path/to/agents_dir """ - session_service_uri = session_service_uri or session_db_url - artifact_service_uri = artifact_service_uri or artifact_storage_uri logs.setup_adk_logger(getattr(logging, log_level.upper())) @asynccontextmanager @@ -1711,7 +1669,6 @@ async def _lifespan(app: FastAPI): ) @fast_api_common_options() @adk_services_options(default_use_local_storage=True) -@deprecated_adk_services_options() @click.option( "--auto_create_session", is_flag=True, @@ -1735,8 +1692,6 @@ def cli_api_server( artifact_service_uri: Optional[str] = None, memory_service_uri: Optional[str] = None, use_local_storage: bool = True, - session_db_url: Optional[str] = None, # Deprecated - artifact_storage_uri: Optional[str] = None, # Deprecated a2a: bool = False, reload_agents: bool = False, extra_plugins: Optional[list[str]] = None, @@ -1752,8 +1707,6 @@ def cli_api_server( adk api_server --session_service_uri=[uri] --port=[port] path/to/agents_dir """ - session_service_uri 
= session_service_uri or session_db_url - artifact_service_uri = artifact_service_uri or artifact_storage_uri logs.setup_adk_logger(getattr(logging, log_level.upper())) config = uvicorn.Config( @@ -1882,11 +1835,6 @@ def cli_api_server( default="INFO", help="Optional. Set the logging level", ) -@click.option( - "--verbosity", - type=LOG_LEVELS, - help="Deprecated. Use --log_level instead.", -) @click.argument( "agent", type=click.Path( @@ -1932,7 +1880,6 @@ def cli_api_server( ) # TODO: Add eval_storage_uri option back when evals are supported in Cloud Run. @adk_services_options(default_use_local_storage=False) -@deprecated_adk_services_options() @click.pass_context def cli_deploy_cloud_run( ctx, @@ -1948,14 +1895,11 @@ def cli_deploy_cloud_run( with_ui: bool, adk_version: str, log_level: str, - verbosity: Optional[str], allow_origins: Optional[list[str]] = None, session_service_uri: Optional[str] = None, artifact_service_uri: Optional[str] = None, memory_service_uri: Optional[str] = None, use_local_storage: bool = False, - session_db_url: Optional[str] = None, # Deprecated - artifact_storage_uri: Optional[str] = None, # Deprecated a2a: bool = False, trigger_sources: Optional[str] = None, ): @@ -1972,19 +1916,9 @@ def cli_deploy_cloud_run( adk deploy cloud_run --project=[project] --region=[region] path/to/my_agent -- --no-allow-unauthenticated --min-instances=2 """ - if verbosity: - click.secho( - "WARNING: The --verbosity option is deprecated. 
Use --log_level" - " instead.", - fg="yellow", - err=True, - ) _warn_if_with_ui(with_ui) - session_service_uri = session_service_uri or session_db_url - artifact_service_uri = artifact_service_uri or artifact_storage_uri - # Parse arguments to separate gcloud args (after --) from regular args gcloud_args = [] if "--" in ctx.args: @@ -2028,7 +1962,7 @@ def cli_deploy_cloud_run( allow_origins=allow_origins, with_ui=with_ui, log_level=log_level, - verbosity=verbosity, + verbosity=log_level, adk_version=adk_version, session_service_uri=session_service_uri, artifact_service_uri=artifact_service_uri, diff --git a/src/google/adk/evaluation/base_eval_service.py b/src/google/adk/evaluation/base_eval_service.py index bafe5846c2..927dd8cd04 100644 --- a/src/google/adk/evaluation/base_eval_service.py +++ b/src/google/adk/evaluation/base_eval_service.py @@ -25,6 +25,7 @@ from pydantic import ConfigDict from pydantic import Field +from .constants import DEFAULT_LIVE_TIMEOUT_SECONDS from .eval_case import Invocation from .eval_metrics import EvalMetric from .eval_result import EvalCaseResult @@ -81,6 +82,18 @@ class InferenceConfig(BaseModel): could also overwhelm those tools.""", ) + use_live: bool = Field( + default=False, + description="""Whether to use live (bidirectional streaming) mode for +inference. 
This is required for Live API models (e.g., gemini-*-live-*).""", + ) + + live_timeout_seconds: int = Field( + default=DEFAULT_LIVE_TIMEOUT_SECONDS, + description="""Timeout in seconds for waiting for model turn completion in +live mode.""", + ) + class InferenceRequest(BaseModel): """Represent a request to perform inferences for the eval cases in an eval set.""" diff --git a/src/google/adk/evaluation/constants.py b/src/google/adk/evaluation/constants.py index 5aed2b101e..e7ee1f24d2 100644 --- a/src/google/adk/evaluation/constants.py +++ b/src/google/adk/evaluation/constants.py @@ -18,3 +18,5 @@ 'Eval module is not installed, please install via `pip install' ' "google-adk[eval]"`.' ) + +DEFAULT_LIVE_TIMEOUT_SECONDS = 300 diff --git a/src/google/adk/evaluation/evaluation_generator.py b/src/google/adk/evaluation/evaluation_generator.py index f8fb6795aa..d5a6629366 100644 --- a/src/google/adk/evaluation/evaluation_generator.py +++ b/src/google/adk/evaluation/evaluation_generator.py @@ -14,6 +14,7 @@ from __future__ import annotations +import asyncio import copy import importlib import logging @@ -22,15 +23,26 @@ from typing import Optional import uuid +from google.genai import errors +from google.genai import types from google.genai.types import Content from pydantic import BaseModel +from websockets.exceptions import ConnectionClosed +from websockets.exceptions import ConnectionClosedOK +from ..agents.callback_context import CallbackContext +from ..agents.invocation_context import InvocationContext +from ..agents.live_request_queue import LiveRequestQueue from ..agents.llm_agent import Agent +from ..agents.run_config import RunConfig +from ..agents.run_config import StreamingMode from ..artifacts.base_artifact_service import BaseArtifactService from ..artifacts.in_memory_artifact_service import InMemoryArtifactService from ..events.event import Event +from ..flows.llm_flows.functions import handle_function_calls_live from ..memory.base_memory_service import 
BaseMemoryService from ..memory.in_memory_memory_service import InMemoryMemoryService +from ..models.llm_request import LlmRequest from ..runners import Runner from ..sessions.base_session_service import BaseSessionService from ..sessions.in_memory_session_service import InMemorySessionService @@ -39,6 +51,7 @@ from ._retry_options_utils import EnsureRetryOptionsPlugin from .app_details import AgentDetails from .app_details import AppDetails +from .constants import DEFAULT_LIVE_TIMEOUT_SECONDS from .eval_case import EvalCase from .eval_case import Invocation from .eval_case import InvocationEvent @@ -66,6 +79,181 @@ class EvalCaseResponses(BaseModel): responses: list[list[Invocation]] +class _LiveSession: + """Manages the background task and state for a live session.""" + + def __init__( + self, + runner: Runner, + session: Session, + user_id: str, + session_id: str, + ): + self.runner = runner + self.session = session + self.user_id = user_id + self.session_id = session_id + self.live_request_queue = LiveRequestQueue() + self.event_queue = asyncio.Queue() + self.turn_complete_event = asyncio.Event() + self.live_finished = asyncio.Event() + self.current_invocation_id = Event.new_id() + self.consume_task = None + + async def __aenter__(self) -> _LiveSession: + """Starts the background task.""" + self.consume_task = asyncio.create_task(self._consume_events()) + return self + + async def _consume_events(self) -> None: + """Background task: consume events from run_live.""" + try: + run_config = RunConfig( + streaming_mode=StreamingMode.BIDI, + response_modalities=["AUDIO"], + output_audio_transcription=types.AudioTranscriptionConfig(), + input_audio_transcription=types.AudioTranscriptionConfig(), + ) + + invocation_context = self.runner._new_invocation_context_for_live( + self.session, + live_request_queue=self.live_request_queue, + run_config=run_config, + ) + invocation_context.agent = self.runner._find_agent_to_run( + self.session, self.runner.agent + ) + + 
callback_context = None + llm_request = LlmRequest() + + async with Aclosing( + invocation_context.agent._llm_flow._preprocess_async( + invocation_context, llm_request + ) + ) as agen: + async for _ in agen: + pass + + callback_context = CallbackContext(invocation_context) + # By default, live API calls do not include before_model_callback and + # after_model_callback. These callbacks are needed by the plugins to + # include the agent instructions and tool declarations in the eval + # invocations for autorater evaluation. + await invocation_context.plugin_manager.run_before_model_callback( + callback_context=callback_context, + llm_request=llm_request, + ) + + in_function_call_loop = False + async with Aclosing( + invocation_context.agent.run_live(invocation_context) + ) as agen: + async for event in agen: + assert event is not None + event.invocation_id = self.current_invocation_id + if callback_context: + await invocation_context.plugin_manager.run_after_model_callback( + callback_context=callback_context, + llm_response=event, + ) + await self.event_queue.put(event) + if not event.partial: + await self.runner.session_service.append_event( + session=self.session, event=event + ) + function_calls = event.get_function_calls() + if function_calls: + in_function_call_loop = True + inv_context = InvocationContext( + session_service=self.runner.session_service, + invocation_id=event.invocation_id, + agent=self.runner.agent, + session=self.session, + run_config=run_config, + ) + + if isinstance(self.runner.agent, Agent): + resolved_tools = await self.runner.agent.canonical_tools( + inv_context + ) + tools_dict = {t.name: t for t in resolved_tools} + else: + tools_dict = {} + + try: + response_event = await handle_function_calls_live( + invocation_context=inv_context, + function_call_event=event, + tools_dict=tools_dict, + ) + + if ( + response_event + and response_event.content + and response_event.content.parts + ): + for part in response_event.content.parts: + if 
part.function_response: + tool_content = types.Content( + role="tool", + parts=[part], + ) + self.live_request_queue.send_content(tool_content) + except (ValueError, RuntimeError, KeyError, TypeError) as e: + logger.error( + "Failed to handle function calls: %s", + e, + exc_info=True, + ) + for fc in function_calls: + response_content = types.FunctionResponse( + name=fc.name, + id=fc.id, + response={"error": str(e)}, + ) + tool_content = types.Content( + role="tool", + parts=[types.Part(function_response=response_content)], + ) + self.live_request_queue.send_content(tool_content) + if event.turn_complete and event.author != _USER_AUTHOR: + if not in_function_call_loop: + self.turn_complete_event.set() + else: + in_function_call_loop = False + finally: + self.live_finished.set() + self.turn_complete_event.set() # Unblock any waiters + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + """Closes the queue and waits for the background task to finish.""" + self.live_request_queue.close() + try: + await asyncio.wait_for(self.consume_task, timeout=30) + except asyncio.TimeoutError: + logger.warning("Timed out waiting for run_live to finish.") + assert self.consume_task is not None + self.consume_task.cancel() + try: + await self.consume_task + except asyncio.CancelledError: + pass + except (ConnectionClosed, errors.APIError) as e: + # The Gemini Live API uses WebSockets. When the session ends normally, the + # connection is closed with code 1000. Some client libraries may raise an + # exception rather than handling it silently. We log this as INFO to + # avoid false-positive error reports for expected behavior. 
+ is_normal_closure = isinstance(e, ConnectionClosedOK) or ( + isinstance(e, errors.APIError) and e.code == 1000 + ) + + if is_normal_closure: + logger.info("Ignored WebSocket normal closure exception: %s", e) + else: + raise + + class EvaluationGenerator: """Generates evaluation responses for agents.""" @@ -187,6 +375,168 @@ async def _generate_inferences_for_single_user_invocation( yield event + @staticmethod + async def _generate_inferences_for_single_user_invocation_live( + live_request_queue: LiveRequestQueue, + event_queue: asyncio.Queue[Event], + user_message: Content, + current_invocation_id: str, + turn_complete_event: asyncio.Event, + live_timeout_seconds: int, + ) -> AsyncGenerator[Event, None]: + """Generates inferences for a single user invocation in live mode.""" + yield Event( + content=user_message, + author=_USER_AUTHOR, + invocation_id=current_invocation_id, + ) + + live_request_queue.send_content(user_message) + + try: + await asyncio.wait_for( + turn_complete_event.wait(), + timeout=live_timeout_seconds, + ) + except asyncio.TimeoutError: + logger.warning( + "Timed out waiting for model turn completion in live mode." 
+ ) + raise + + while not event_queue.empty(): + event = await event_queue.get() + if event.invocation_id == current_invocation_id: + yield event + + @staticmethod + async def _generate_inferences_from_root_agent_live( + root_agent: Agent, + user_simulator: UserSimulator, + reset_func: Optional[Any] = None, + initial_session: Optional[SessionInput] = None, + session_id: Optional[str] = None, + session_service: Optional[BaseSessionService] = None, + artifact_service: Optional[BaseArtifactService] = None, + memory_service: Optional[BaseMemoryService] = None, + live_timeout_seconds: int = DEFAULT_LIVE_TIMEOUT_SECONDS, + ) -> list[Invocation]: + """Scrapes the root agent in coordination with the user simulator in live mode.""" + if not session_service: + session_service = InMemorySessionService() + + if not memory_service: + memory_service = InMemoryMemoryService() + + app_name = ( + initial_session.app_name if initial_session else "EvaluationGenerator" + ) + user_id = initial_session.user_id if initial_session else "test_user_id" + session_id = session_id if session_id else str(uuid.uuid4()) + + session = await session_service.create_session( + app_name=app_name, + user_id=user_id, + state=initial_session.state if initial_session else {}, + session_id=session_id, + ) + + if not artifact_service: + artifact_service = InMemoryArtifactService() + + # Reset agent state for each query + if callable(reset_func): + reset_func() + + # We ensure that there is some kind of retries on the llm_requests that are + # generated from the Agent. This is done to make inferencing step of evals + # more resilient to temporary model failures. 
+ ensure_retry_options_plugin = EnsureRetryOptionsPlugin( + name="ensure_retry_options" + ) + request_intercepter_plugin = _RequestIntercepterPlugin( + name="request_intercepter_plugin" + ) + async with Runner( + app_name=app_name, + agent=root_agent, + artifact_service=artifact_service, + session_service=session_service, + memory_service=memory_service, + plugins=[request_intercepter_plugin, ensure_retry_options_plugin], + ) as runner: + events = [] + + # `_LiveSession` is a runtime connection manager wrapping the `Session` + # data model (which stores conversation history/state). It manages the + # active bidirectional WebSocket stream and background consumer tasks. + live_session = _LiveSession(runner, session, user_id, session_id) + await live_session.__aenter__() + + try: + turn_idx = 0 + while True: + turn_idx += 1 + next_user_message = await user_simulator.get_next_user_message( + copy.deepcopy(events) + ) + if next_user_message.status == UserSimulatorStatus.SUCCESS: + live_session.current_invocation_id = Event.new_id() + live_session.turn_complete_event.clear() + + logger.info("Waiting for model to complete turn %d...", turn_idx) + + async for ( + event + ) in EvaluationGenerator._generate_inferences_for_single_user_invocation_live( + live_request_queue=live_session.live_request_queue, + event_queue=live_session.event_queue, + user_message=next_user_message.user_message, + current_invocation_id=live_session.current_invocation_id, + turn_complete_event=live_session.turn_complete_event, + live_timeout_seconds=live_timeout_seconds, + ): + events.append(event) + + turn_transcription = "" + for evt in events: + if ( + evt.invocation_id == live_session.current_invocation_id + and evt.author != _USER_AUTHOR + and evt.output_transcription + ): + if not evt.partial and evt.output_transcription.text: + turn_transcription = evt.output_transcription.text + else: + turn_transcription += evt.output_transcription.text + if turn_transcription: + synthetic_event = Event( + 
content=Content( + role="model", + parts=[types.Part(text=turn_transcription)], + ), + author=runner.agent.name, + invocation_id=live_session.current_invocation_id, + ) + events.append(synthetic_event) + + if live_session.live_finished.is_set(): + logger.info("Live session finished signal detected.") + break + else: # no message generated + break + finally: + await live_session.__aexit__(None, None, None) + + app_details_by_invocation_id = ( + EvaluationGenerator._get_app_details_by_invocation_id( + events, request_intercepter_plugin + ) + ) + return EvaluationGenerator.convert_events_to_eval_invocations( + events, app_details_by_invocation_id + ) + @staticmethod async def _generate_inferences_from_root_agent( root_agent: Agent, @@ -308,7 +658,12 @@ def convert_events_to_eval_invocations( final_event = event for p in event.content.parts: - if p.function_call or p.function_response or p.text: + if ( + p.function_call + or p.function_response + or p.text + or p.inline_data + ): events_to_add.append(event) break diff --git a/src/google/adk/evaluation/local_eval_service.py b/src/google/adk/evaluation/local_eval_service.py index 2426204ca0..b749b7e8f1 100644 --- a/src/google/adk/evaluation/local_eval_service.py +++ b/src/google/adk/evaluation/local_eval_service.py @@ -182,6 +182,8 @@ async def run_inference(eval_case): eval_set_id=inference_request.eval_set_id, eval_case=eval_case, root_agent=self._root_agent, + use_live=inference_request.inference_config.use_live, + live_timeout_seconds=inference_request.inference_config.live_timeout_seconds, ) inference_results = [run_inference(eval_case) for eval_case in eval_cases] @@ -470,6 +472,8 @@ async def _perform_inference_single_eval_item( eval_set_id: str, eval_case: EvalCase, root_agent: BaseAgent, + use_live: bool, + live_timeout_seconds: int, ) -> InferenceResult: initial_session = eval_case.session_input session_id = self._session_id_supplier() @@ -482,17 +486,31 @@ async def _perform_inference_single_eval_item( try: 
with client_label_context(EVAL_CLIENT_LABEL): - inferences = ( - await EvaluationGenerator._generate_inferences_from_root_agent( - root_agent=root_agent, - user_simulator=self._user_simulator_provider.provide(eval_case), - initial_session=initial_session, - session_id=session_id, - session_service=self._session_service, - artifact_service=self._artifact_service, - memory_service=self._memory_service, - ) - ) + if use_live: + inferences = await EvaluationGenerator._generate_inferences_from_root_agent_live( + root_agent=root_agent, + user_simulator=self._user_simulator_provider.provide(eval_case), + initial_session=initial_session, + session_id=session_id, + session_service=self._session_service, + artifact_service=self._artifact_service, + memory_service=self._memory_service, + live_timeout_seconds=live_timeout_seconds, + ) + else: + inferences = ( + await EvaluationGenerator._generate_inferences_from_root_agent( + root_agent=root_agent, + user_simulator=self._user_simulator_provider.provide( + eval_case + ), + initial_session=initial_session, + session_id=session_id, + session_service=self._session_service, + artifact_service=self._artifact_service, + memory_service=self._memory_service, + ) + ) inference_result.inferences = inferences inference_result.status = InferenceStatus.SUCCESS diff --git a/src/google/adk/evaluation/simulation/per_turn_user_simulator_quality_v1.py b/src/google/adk/evaluation/simulation/per_turn_user_simulator_quality_v1.py index a95eb87d88..239fa31a71 100644 --- a/src/google/adk/evaluation/simulation/per_turn_user_simulator_quality_v1.py +++ b/src/google/adk/evaluation/simulation/per_turn_user_simulator_quality_v1.py @@ -325,6 +325,10 @@ async def _evaluate_intermediate_turn( previous_invocations=invocation_history, ) + config = ( + self._llm_options.judge_model_config + or genai_types.GenerateContentConfig() + ) llm_request = LlmRequest( model=self._llm_options.judge_model, contents=[ @@ -333,7 +337,7 @@ async def 
_evaluate_intermediate_turn( role="user", ) ], - config=self._llm_options.judge_model_config, + config=config, ) add_default_retry_options_if_not_present(llm_request) num_samples = self._llm_options.num_samples diff --git a/src/google/adk/flows/llm_flows/single_flow.py b/src/google/adk/flows/llm_flows/single_flow.py index 932a265ed1..cc3fc9e6fa 100644 --- a/src/google/adk/flows/llm_flows/single_flow.py +++ b/src/google/adk/flows/llm_flows/single_flow.py @@ -22,7 +22,6 @@ from . import _nl_planning from . import _output_schema_processor from . import basic -from . import compaction from . import contents from . import context_cache_processor from . import identity @@ -36,6 +35,7 @@ def _create_request_processors(): """Create the standard request processor list for a single-agent flow.""" + from . import compaction from ...auth import auth_preprocessor return [ diff --git a/src/google/adk/integrations/agent_registry/agent_registry.py b/src/google/adk/integrations/agent_registry/agent_registry.py index 887c894cd4..a486215151 100644 --- a/src/google/adk/integrations/agent_registry/agent_registry.py +++ b/src/google/adk/integrations/agent_registry/agent_registry.py @@ -208,7 +208,10 @@ def _get_auth_headers(self) -> Dict[str, str]: "Authorization": f"Bearer {self._credentials.token}", "Content-Type": "application/json", } - quota_project_id = getattr(self._credentials, "quota_project_id", None) + quota_project_id = ( + getattr(self._credentials, "quota_project_id", None) + or self.project_id + ) if quota_project_id: headers["x-goog-user-project"] = quota_project_id return headers diff --git a/src/google/adk/integrations/skill_registry/__init__.py b/src/google/adk/integrations/skill_registry/__init__.py new file mode 100644 index 0000000000..5cfd76a29d --- /dev/null +++ b/src/google/adk/integrations/skill_registry/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file 
class GCPSkillRegistry(SkillRegistry):
  """GCP implementation of SkillRegistry using GCP Skill Registry API.

  All calls go through the async surface (`.aio`) of a `vertexai.Client`
  configured for the `v1beta1` API version.
  """

  def __init__(
      self, *, project_id: str | None = None, location: str | None = None
  ):
    """Initializes the GCP Skill Registry.

    Args:
      project_id: Optional GCP project ID. If omitted, read from the
        `GOOGLE_CLOUD_PROJECT` environment variable.
      location: Optional GCP location. If omitted, read from the
        `GOOGLE_CLOUD_LOCATION` environment variable.
    """
    self.project_id = project_id or os.environ.get("GOOGLE_CLOUD_PROJECT")
    self.location = location or os.environ.get("GOOGLE_CLOUD_LOCATION")
    self._client = vertexai.Client(
        project=self.project_id,
        location=self.location,
        http_options={
            "api_version": "v1beta1",
        },
    ).aio

  async def get_skill(self, *, name: str) -> models.Skill:
    """Fetches a skill from the registry.

    Args:
      name: The short name of the skill; it is appended to the
        `projects/<project>/locations/<location>/skills/` resource prefix.

    Returns:
      A Skill object.

    Raises:
      ValueError: If the project ID or location could not be resolved, or if
        the fetched skill resource carries no zipped filesystem.
    """
    # Without both values the resource name would literally contain "None"
    # and the request could only fail remotely with an opaque error.
    if not self.project_id or not self.location:
      raise ValueError(
          "GCPSkillRegistry needs a project ID and a location to fetch a"
          " skill. Pass `project_id` and `location`, or set"
          " GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION."
      )

    full_name = (
        f"projects/{self.project_id}/locations/{self.location}/skills/{name}"
    )
    skill_resource = await self._client.skills.get(name=full_name)

    zip_bytes_base64 = skill_resource.zipped_filesystem
    if not zip_bytes_base64:
      raise ValueError(f"Skill '{name}' does not contain zipped filesystem.")

    zip_bytes = base64.b64decode(zip_bytes_base64)

    # Unzipping and parsing is blocking work; keep it off the event loop.
    return await asyncio.to_thread(_utils._load_skill_from_zip_bytes, zip_bytes)

  async def search_skills(self, *, query: str) -> list[models.Frontmatter]:
    """Searches for skills in the registry.

    Entries returned without a skill name or without a description are
    skipped (and logged) instead of aborting the whole search.

    Args:
      query: The search query.

    Returns:
      A list of Frontmatter objects for discovery.
    """
    import logging

    response = await self._client.skills.retrieve(query=query)

    results = []
    for retrieved in response.retrieved_skills or ():
      # The registry returns a full resource name; keep the last segment.
      short_name = (retrieved.skill_name or "").split("/")[-1]
      description = retrieved.description or ""
      if not short_name or not description:
        # The description validator in skills/models.py rejects empty values
        # (presumably the name validator does too -- TODO confirm), so
        # building a Frontmatter here would raise and lose every other hit.
        logging.getLogger("google_adk." + __name__).warning(
            "Skipping retrieved skill with missing name or description: %r",
            retrieved.skill_name,
        )
        continue
      results.append(
          models.Frontmatter(name=short_name, description=description)
      )
    return results
# See the License for the specific language governing permissions and # limitations under the License. -import logging + +from __future__ import annotations + +import importlib +from typing import TYPE_CHECKING from .base_memory_service import BaseMemoryService -from .in_memory_memory_service import InMemoryMemoryService -from .vertex_ai_memory_bank_service import VertexAiMemoryBankService -logger = logging.getLogger('google_adk.' + __name__) +if TYPE_CHECKING: + from .in_memory_memory_service import InMemoryMemoryService + from .vertex_ai_memory_bank_service import VertexAiMemoryBankService + from .vertex_ai_rag_memory_service import VertexAiRagMemoryService __all__ = [ 'BaseMemoryService', 'InMemoryMemoryService', 'VertexAiMemoryBankService', + 'VertexAiRagMemoryService', ] -try: - from .vertex_ai_rag_memory_service import VertexAiRagMemoryService +_LAZY_MEMBERS: dict[str, str] = { + 'InMemoryMemoryService': 'in_memory_memory_service', + 'VertexAiMemoryBankService': 'vertex_ai_memory_bank_service', + 'VertexAiRagMemoryService': 'vertex_ai_rag_memory_service', +} + - __all__.append('VertexAiRagMemoryService') -except ImportError: - logger.debug( - 'The Vertex SDK is not installed. If you want to use the' - ' VertexAiRagMemoryService please install it. If not, you can ignore this' - ' warning.' 
def __getattr__(name: str):
  """Resolves lazily exported members on first attribute access.

  Python calls this module-level hook only when normal attribute lookup on
  the package fails. Names listed in `_LAZY_MEMBERS` are served from the
  submodule they map to, which is imported at that moment.

  Args:
    name: The attribute being looked up on this package.

  Returns:
    The object called `name` defined in the mapped submodule.

  Raises:
    AttributeError: If `name` is not a lazily exported member.
  """
  submodule_name = _LAZY_MEMBERS.get(name)
  if submodule_name is None:
    raise AttributeError(f'module {__name__!r} has no attribute {name!r}')

  submodule = importlib.import_module(f'{__name__}.{submodule_name}')
  return vars(submodule)[name]
interactions API expects result to be either a string or items list + # Pass the function response through to the interactions API. + # Dict and list values are passed directly — the Interactions API handles + # JSON serialization internally. Pre-serializing with json.dumps() would + # cause double-escaping. result = part.function_response.response - if isinstance(result, dict): - result = json.dumps(result) - elif not isinstance(result, str): + if not isinstance(result, (dict, str, list)): result = str(result) logger.debug( 'Converting function_response: name=%s, call_id=%s', diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index 3a6c36624d..5fa26261d7 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -234,10 +234,6 @@ def _get_provider_from_model(model: str) -> str: return "" -# Default MIME type when none can be inferred -_DEFAULT_MIME_TYPE = "application/octet-stream" - - def _infer_mime_type_from_uri(uri: str) -> Optional[str]: """Attempts to infer MIME type from a URI's path extension. @@ -1103,33 +1099,33 @@ async def _get_content( }) continue - # Determine MIME type: use explicit value, infer from URI, or use default. + # Resolve MIME type early: needed before the media-URL shortcut below, + # which must run before the generic text-fallback check. The raise is + # deferred until after all early-continue paths so that providers which + # always fall back to text (anthropic, non-Gemini Vertex AI) are never + # asked for a MIME type they cannot supply. mime_type = part.file_data.mime_type if not mime_type: mime_type = _infer_mime_type_from_uri(part.file_data.file_uri) if not mime_type and part.file_data.display_name: guessed_mime_type, _ = mimetypes.guess_type(part.file_data.display_name) mime_type = guessed_mime_type - if not mime_type: - # LiteLLM's Vertex AI backend requires format for GCS URIs. 
- mime_type = _DEFAULT_MIME_TYPE - logger.debug( - "Could not determine MIME type for file_uri %s, using default: %s", - part.file_data.file_uri, - mime_type, - ) - mime_type = _normalize_mime_type(mime_type) + if mime_type: + mime_type = _normalize_mime_type(mime_type) + # For OpenAI/Azure: HTTP media URLs (image, video, audio) are sent as + # typed URL blocks and must be handled before the generic text fallback. if provider in _FILE_ID_REQUIRED_PROVIDERS and _is_http_url( part.file_data.file_uri ): - url_content_type = _media_url_content_type(mime_type) - if url_content_type: - content_objects.append({ - "type": url_content_type, - url_content_type: {"url": part.file_data.file_uri}, - }) - continue + if mime_type: + url_content_type = _media_url_content_type(mime_type) + if url_content_type: + content_objects.append({ + "type": url_content_type, + url_content_type: {"url": part.file_data.file_uri}, + }) + continue if _requires_file_uri_fallback(provider, model, part.file_data.file_uri): logger.debug( @@ -1147,6 +1143,19 @@ async def _get_content( }) continue + # All remaining providers (e.g. Vertex AI + Gemini) require a specific + # MIME type in the file object. Both a missing type and + # 'application/octet-stream' cause a downstream ValueError from LiteLLM + # regardless of whether the value was set explicitly by the caller or + # arrived via a default fallback; raise early with an actionable message. + if not mime_type or mime_type == "application/octet-stream": + type_label = mime_type or "(unknown)" + raise ValueError( + f"Cannot process file_uri {part.file_data.file_uri!r}: MIME type" + f" {type_label!r} is not supported. Please set a specific MIME" + " type on `file_data.mime_type`." 
+ ) + file_object: ChatCompletionFileUrlObject = { "file_id": part.file_data.file_uri, } diff --git a/src/google/adk/models/llm_response.py b/src/google/adk/models/llm_response.py index 10ab946455..c921f197c3 100644 --- a/src/google/adk/models/llm_response.py +++ b/src/google/adk/models/llm_response.py @@ -14,6 +14,7 @@ from __future__ import annotations +import logging from typing import Any from typing import Optional @@ -216,9 +217,15 @@ def create( model_version=generate_content_response.model_version, ) else: + # Some model backends can legitimately complete a turn without + # candidates (for example, tool-driven UI turns with no text). Treat + # this as an empty successful response rather than an unknown error. + logging.warning( + 'Received empty candidates and no prompt feedback in model ' + 'response. Treating as a successful empty response.' + ) return LlmResponse( - error_code='UNKNOWN_ERROR', - error_message='Unknown error.', + content=types.Content(role='model', parts=[]), usage_metadata=usage_metadata, model_version=generate_content_response.model_version, ) diff --git a/src/google/adk/plugins/__init__.py b/src/google/adk/plugins/__init__.py index 45caf16038..70347fd25e 100644 --- a/src/google/adk/plugins/__init__.py +++ b/src/google/adk/plugins/__init__.py @@ -1,8 +1,7 @@ # Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may in obtain a copy of the License at +# you may in obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # @@ -12,11 +11,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
def __getattr__(name: str):
  """Imports a lazily exported plugin class the first time it is requested.

  This module-level hook runs only when `name` is not already an attribute
  of the package. Entries of `_LAZY_MEMBERS` map an exported name to the
  submodule that defines it.

  Args:
    name: The attribute being looked up on this package.

  Returns:
    The object called `name` taken from the mapped submodule's namespace.

  Raises:
    AttributeError: If `name` is not one of the lazily exported members.
  """
  target = _LAZY_MEMBERS.get(name)
  if target is None:
    raise AttributeError(f'module {__name__!r} has no attribute {name!r}')

  loaded = importlib.import_module(f'{__name__}.{target}')
  return vars(loaded)[name]
.code_executors.built_in_code_executor import BuiltInCodeExecutor from .errors.session_not_found_error import SessionNotFoundError @@ -51,19 +48,21 @@ from .flows.llm_flows.functions import find_event_by_function_call_id from .flows.llm_flows.functions import find_matching_function_call from .memory.base_memory_service import BaseMemoryService -from .memory.in_memory_memory_service import InMemoryMemoryService from .platform.thread import create_thread from .plugins.base_plugin import BasePlugin from .plugins.plugin_manager import PluginManager from .sessions.base_session_service import BaseSessionService from .sessions.base_session_service import GetSessionConfig -from .sessions.in_memory_session_service import InMemorySessionService from .sessions.session import Session from .telemetry.tracing import tracer from .tools.base_toolset import BaseToolset from .utils._debug_output import print_event from .utils.context_utils import Aclosing +if TYPE_CHECKING: + from .apps.app import App + from .apps.app import ResumabilityConfig + logger = logging.getLogger('google_adk.' + __name__) @@ -608,7 +607,7 @@ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: async with Aclosing( self._exec_with_plugin( invocation_context=invocation_context, - session=session, + session=invocation_context.session, execute_fn=execute, is_live_call=False, ) @@ -620,9 +619,11 @@ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: # the end of an invocation.) 
if self.app and self.app.events_compaction_config: logger.debug('Running event compactor.') + from google.adk.apps.compaction import _run_compaction_for_sliding_window + await _run_compaction_for_sliding_window( self.app, - session, + invocation_context.session, self.session_service, skip_token_compaction=invocation_context.token_compaction_checked, ) @@ -841,7 +842,7 @@ async def _exec_with_plugin( Args: invocation_context: The invocation context - session: The current session + session: The current session (ignored, kept for backward compatibility) execute_fn: A callable that returns an AsyncGenerator of Events is_live_call: Whether this is a live call @@ -866,7 +867,7 @@ async def _exec_with_plugin( ) if self._should_append_event(early_exit_event, is_live_call): await self.session_service.append_event( - session=session, + session=invocation_context.session, event=early_exit_event, ) yield early_exit_event @@ -931,13 +932,13 @@ async def _exec_with_plugin( ) if self._should_append_event(event, is_live_call): await self.session_service.append_event( - session=session, event=output_event + session=invocation_context.session, event=output_event ) for buffered_event in buffered_events: logger.debug('Appending buffered event: %s', buffered_event) await self.session_service.append_event( - session=session, event=buffered_event + session=invocation_context.session, event=buffered_event ) yield buffered_event # yield buffered events to caller buffered_events = [] @@ -947,12 +948,12 @@ async def _exec_with_plugin( if self._should_append_event(event, is_live_call): logger.debug('Appending non-buffered event: %s', event) await self.session_service.append_event( - session=session, event=output_event + session=invocation_context.session, event=output_event ) else: if event.partial is not True: await self.session_service.append_event( - session=session, event=output_event + session=invocation_context.session, event=output_event ) yield output_event @@ -1004,8 +1005,8 @@ async 
def _append_new_message_to_session( file_name = f'artifact_{invocation_context.invocation_id}_{i}' await self.artifact_service.save_artifact( app_name=self.app_name, - user_id=session.user_id, - session_id=session.id, + user_id=invocation_context.session.user_id, + session_id=invocation_context.session.id, filename=file_name, artifact=part, ) @@ -1032,7 +1033,9 @@ async def _append_new_message_to_session( if function_call := invocation_context._find_matching_function_call(event): event.branch = function_call.branch - await self.session_service.append_event(session=session, event=event) + await self.session_service.append_event( + session=invocation_context.session, event=event + ) async def run_live( self, @@ -1127,7 +1130,9 @@ async def run_live( ) root_agent = self.agent - invocation_context.agent = self._find_agent_to_run(session, root_agent) + invocation_context.agent = self._find_agent_to_run( + invocation_context.session, root_agent + ) async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: async with Aclosing(ctx.agent.run_live(ctx)) as agen: @@ -1137,7 +1142,7 @@ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: async with Aclosing( self._exec_with_plugin( invocation_context=invocation_context, - session=session, + session=invocation_context.session, execute_fn=execute, is_live_call=True, ) @@ -1355,14 +1360,16 @@ async def _setup_context_for_new_invocation( # Step 2: Handle new message, by running callbacks and appending to # session. await self._handle_new_message( - session=session, + session=invocation_context.session, new_message=new_message, invocation_context=invocation_context, run_config=run_config, state_delta=state_delta, ) # Step 3: Set agent to run for the invocation. 
- invocation_context.agent = self._find_agent_to_run(session, self.agent) + invocation_context.agent = self._find_agent_to_run( + invocation_context.session, self.agent + ) return invocation_context async def _setup_context_for_resumed_invocation( @@ -1411,7 +1418,7 @@ async def _setup_context_for_resumed_invocation( # Step 3: Maybe handle new message. if new_message: await self._handle_new_message( - session=session, + session=invocation_context.session, new_message=user_message, invocation_context=invocation_context, run_config=run_config, @@ -1425,7 +1432,9 @@ async def _setup_context_for_resumed_invocation( # started from a sub-agent and paused on a sub-agent. # We should find the appropriate agent to run to continue the invocation. if self.agent.name not in invocation_context.end_of_agents: - invocation_context.agent = self._find_agent_to_run(session, self.agent) + invocation_context.agent = self._find_agent_to_run( + invocation_context.session, self.agent + ) return invocation_context def _find_user_message_for_invocation( @@ -1559,7 +1568,7 @@ async def _handle_new_message( if 'save_input_blobs_as_artifacts' in run_config.model_fields_set: deprecated_save_blobs = run_config.save_input_blobs_as_artifacts await self._append_new_message_to_session( - session=session, + session=invocation_context.session, new_message=new_message, invocation_context=invocation_context, save_input_blobs_as_artifacts=deprecated_save_blobs, @@ -1669,6 +1678,10 @@ def __init__( app: Optional App instance. plugin_close_timeout: The timeout in seconds for plugin close methods. 
""" + from .artifacts.in_memory_artifact_service import InMemoryArtifactService + from .memory.in_memory_memory_service import InMemoryMemoryService + from .sessions.in_memory_session_service import InMemorySessionService + if app is None and app_name is None: app_name = 'InMemoryRunner' super().__init__( diff --git a/src/google/adk/sessions/__init__.py b/src/google/adk/sessions/__init__.py index 7505eda346..db983f96f1 100644 --- a/src/google/adk/sessions/__init__.py +++ b/src/google/adk/sessions/__init__.py @@ -11,11 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from __future__ import annotations + +import importlib +from typing import TYPE_CHECKING + from .base_session_service import BaseSessionService -from .in_memory_session_service import InMemorySessionService from .session import Session from .state import State -from .vertex_ai_session_service import VertexAiSessionService + +if TYPE_CHECKING: + from .database_session_service import DatabaseSessionService + from .in_memory_session_service import InMemorySessionService + from .vertex_ai_session_service import VertexAiSessionService __all__ = [ 'BaseSessionService', @@ -26,16 +35,23 @@ 'VertexAiSessionService', ] +_LAZY_MEMBERS: dict[str, str] = { + 'InMemorySessionService': 'in_memory_session_service', + 'VertexAiSessionService': 'vertex_ai_session_service', +} + def __getattr__(name: str): + if name in _LAZY_MEMBERS: + module = importlib.import_module(f'{__name__}.{_LAZY_MEMBERS[name]}') + return vars(module)[name] if name == 'DatabaseSessionService': try: - from .database_session_service import DatabaseSessionService - - return DatabaseSessionService + module = importlib.import_module(f'{__name__}.database_session_service') except ImportError as e: raise ImportError( - 'DatabaseSessionService requires sqlalchemy>=2.0, please ensure it is' - ' installed correctly.' 
def _load_skill_from_zip_bytes(zip_bytes: bytes) -> models.Skill:
  """Load a complete skill directly from in-memory zip file bytes.

  Nothing is extracted to disk: every member is read into memory. `SKILL.md`
  (or `skill.md`) must sit at the archive root; text files below
  `references/`, `assets/` and `scripts/` become the skill's resources.

  Args:
    zip_bytes: The raw bytes of the zip file containing the skill.

  Returns:
    Skill object with all components loaded.

  Raises:
    FileNotFoundError: If SKILL.md is not found in the archive.
    ValueError: If SKILL.md is invalid or the archive contains dangerous
      paths (absolute paths or any `..` path component).
    zipfile.BadZipFile: If `zip_bytes` is not a valid zip archive.
  """
  with zipfile.ZipFile(io.BytesIO(zip_bytes)) as z:
    # Security check for zip slip. Entry names become resource keys, so
    # reject the whole archive if any name is absolute or has a `..`
    # component anywhere: leading, embedded, trailing ("a/..") or bare
    # (".."). Backslashes are treated as separators as well, so
    # Windows-style names ("..\\x") cannot slip through.
    for member in z.infolist():
      filename = member.filename
      normalized = filename.replace("\\", "/")
      if normalized.startswith("/") or ".." in normalized.split("/"):
        raise ValueError(f"Dangerous zip entry ignored: {filename}")

    # Find SKILL.md or skill.md
    skill_md_content = None
    for name in ("SKILL.md", "skill.md"):
      try:
        skill_md_content = z.read(name).decode("utf-8")
        break
      except KeyError:
        # ZipFile.read raises KeyError for a missing member; try the next
        # spelling.
        continue

    if skill_md_content is None:
      raise FileNotFoundError("SKILL.md not found in zipped filesystem.")

    parsed, body = _parse_skill_md_content(skill_md_content)
    skill_name = parsed.get("name")
    if not skill_name:
      raise ValueError("SKILL.md frontmatter must contain 'name'")
    # The name must be a single path component (no separators).
    if (
        not isinstance(skill_name, str)
        or pathlib.Path(skill_name).name != skill_name
    ):
      raise ValueError(f"Invalid skill name in SKILL.md: {skill_name}")

    frontmatter = models.Frontmatter.model_validate(parsed)

    # Helper to load files under a directory prefix inside the zip
    def _load_zip_dir(prefix: str) -> dict[str, str]:
      result = {}
      if not prefix.endswith("/"):
        prefix += "/"
      for info in z.infolist():
        if info.is_dir():
          continue
        if info.filename.startswith(prefix):
          # Avoid cache files or similar
          if "__pycache__" in info.filename:
            continue
          relative_path = info.filename[len(prefix) :]
          if not relative_path:
            continue
          try:
            result[relative_path] = z.read(info).decode("utf-8")
          except UnicodeDecodeError:
            # Files that are not valid UTF-8 text are left out of the skill.
            continue
      return result

    references = _load_zip_dir("references")
    assets = _load_zip_dir("assets")
    raw_scripts = _load_zip_dir("scripts")
    scripts = {
        name: models.Script(src=content)
        for name, content in raw_scripts.items()
    }

    resources = models.Resources(
        references=references,
        assets=assets,
        scripts=scripts,
    )

    return models.Skill(
        frontmatter=frontmatter,
        instructions=body,
        resources=resources,
    )
@enum.unique
class _MtlsEndpoint(enum.Enum):
  """Accepted values of the `GOOGLE_API_USE_MTLS_ENDPOINT` setting.

  The member values are the lower-case strings the environment variable is
  parsed against.
  """

  # Use the mTLS endpoint only when a client certificate source is available.
  AUTO = 'auto'
  # Always use the mTLS endpoint.
  ALWAYS = 'always'
  # Always use the regular endpoint.
  NEVER = 'never'
SpanProcessor: from google.auth.transport.requests import AuthorizedSession from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter + session = AuthorizedSession(credentials=credentials) + + use_client_cert = _use_client_cert_effective() + if use_client_cert: + client_cert_source = ( + mtls.default_client_cert_source() + if mtls.has_default_client_cert_source() + else None + ) + session.configure_mtls_channel() + endpoint = _get_api_endpoint(client_cert_source) + else: + endpoint = _DEFAULT_TELEMETRY_TRACES_ENPOINT + return BatchSpanProcessor( OTLPSpanExporter( - session=AuthorizedSession(credentials=credentials), - endpoint='https://telemetry.googleapis.com/v1/traces', + session=session, + endpoint=endpoint, ) ) @@ -158,3 +189,62 @@ def get_gcp_resource(project_id: Optional[str] = None) -> Resource: ' GCE, GKE or CloudRun related resource attributes may be missing' ) return resource + + +def _get_api_endpoint( + client_cert_source: Callable[[], tuple[bytes, bytes]] | None = None, +) -> str: + """Returns API endpoint based on mTLS configuration and cert availability. + + Args: + client_cert_source: A callable that returns the client certificate and + key, or None. + + Returns: + str: The API endpoint to be used. + """ + use_mtls_endpoint_str = os.getenv( + 'GOOGLE_API_USE_MTLS_ENDPOINT', _MtlsEndpoint.AUTO.value + ).lower() + + try: + use_mtls_endpoint = _MtlsEndpoint(use_mtls_endpoint_str) + except ValueError: + logger.warning( + 'Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be one of ' + '%s. 
def _part_to_text(part: types.Part) -> str:
  """Extracts the user-visible text carried by a single Part.

  The first non-empty source wins, in this order: plain text, code execution
  output (with trailing newlines removed), then executable code.

  Args:
    part: The content part to read.

  Returns:
    The selected text, or an empty string when the part carries none of them.
  """
  plain_text = part.text
  if plain_text:
    return plain_text

  execution_result = part.code_execution_result
  if execution_result and execution_result.output:
    return execution_result.output.rstrip('\n')

  code_block = part.executable_code
  if code_block and code_block.code:
    return code_block.code

  return ''
@@ -269,9 +280,8 @@ async def run_async( if last_content is None or last_content.parts is None: return '' - merged_text = '\n'.join( - p.text for p in last_content.parts if p.text and not p.thought - ) + parts_text = (_part_to_text(p) for p in last_content.parts if not p.thought) + merged_text = '\n'.join(t for t in parts_text if t) output_schema = _get_output_schema(self.agent) if output_schema: tool_result = validate_schema(output_schema, merged_text) diff --git a/src/google/adk/tools/apihub_tool/apihub_toolset.py b/src/google/adk/tools/apihub_tool/apihub_toolset.py index 5560c7ce52..19c6d709e3 100644 --- a/src/google/adk/tools/apihub_tool/apihub_toolset.py +++ b/src/google/adk/tools/apihub_tool/apihub_toolset.py @@ -199,13 +199,3 @@ def _prepare_toolset(self) -> None: async def close(self): if self._openapi_toolset: await self._openapi_toolset.close() - - @override - def get_auth_config(self) -> Optional[AuthConfig]: - """Returns the auth config for this toolset. - - ADK will populate exchanged_auth_credential on this config before calling - get_tools(). The toolset can then access the ready-to-use credential via - self._auth_config.exchanged_auth_credential. - """ - return self._auth_config diff --git a/src/google/adk/tools/application_integration_tool/application_integration_toolset.py b/src/google/adk/tools/application_integration_tool/application_integration_toolset.py index 5e068fba1d..e4e2c5dde7 100644 --- a/src/google/adk/tools/application_integration_tool/application_integration_toolset.py +++ b/src/google/adk/tools/application_integration_tool/application_integration_toolset.py @@ -289,13 +289,3 @@ async def get_tools( async def close(self) -> None: if self._openapi_toolset: await self._openapi_toolset.close() - - @override - def get_auth_config(self) -> Optional[AuthConfig]: - """Returns the auth config for this toolset. - - ADK will populate exchanged_auth_credential on this config before calling - get_tools(). 
The toolset can then access the ready-to-use credential via - self._auth_config.exchanged_auth_credential. - """ - return self._auth_config diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py index 1a14fe1972..a4b45cf16b 100644 --- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py +++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py @@ -542,10 +542,13 @@ async def create_session( sampling_capabilities=self._sampling_capabilities, ) - session = await asyncio.wait_for( - exit_stack.enter_async_context(session_context), - timeout=timeout_in_seconds, - ) + if is_feature_enabled(FeatureName._MCP_GRACEFUL_ERROR_HANDLING): # pylint: disable=protected-access + session = await exit_stack.enter_async_context(session_context) + else: + session = await asyncio.wait_for( + exit_stack.enter_async_context(session_context), + timeout=timeout_in_seconds, + ) # Store session, exit stack, and loop in the pool. The pool storage # remains a tuple for backward-compatibility with downstream tests diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py index 235ae7c350..99e649d9c9 100644 --- a/src/google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py @@ -245,14 +245,3 @@ def _parse(self, openapi_spec_dict: Dict[str, Any]) -> List[RestApiTool]: @override async def close(self): pass - - @override - def get_auth_config(self) -> Optional[AuthConfig]: - """Returns the auth config for this toolset. - - Note: This returns a copy so any exchanged credentials populated by the ADK - framework do not persist on the toolset instance across invocations. 
- """ - return ( - self._auth_config.model_copy(deep=True) if self._auth_config else None - ) diff --git a/src/google/adk/tools/skill_toolset.py b/src/google/adk/tools/skill_toolset.py index 3c60bd5918..ef579d8256 100644 --- a/src/google/adk/tools/skill_toolset.py +++ b/src/google/adk/tools/skill_toolset.py @@ -19,6 +19,7 @@ from __future__ import annotations import asyncio +import collections import json import logging import mimetypes @@ -895,6 +896,8 @@ def __init__( script_timeout: Timeout in seconds for shell script execution via subprocess.run. Defaults to 300 seconds. Does not apply to Python scripts executed via exec(). + additional_tools: Optional list of `BaseTool` or `BaseToolset` instances + to be made available to the agent when certain skills are activated. """ super().__init__() @@ -911,10 +914,13 @@ def __init__( self._registry = registry self._code_executor = code_executor self._script_timeout = script_timeout - self._invocation_cache: dict[ + # Needed for mid-turn reloading of skill tools. 
+ self._use_invocation_cache = False + # Cache fetched remote skill definitions per turn to reduce requests to registry + self._fetched_skill_cache: collections.OrderedDict[ str, dict[str, models.Skill | asyncio.Future[models.Skill | None] | None], - ] = {} + ] = collections.OrderedDict() self._max_cache_turns = 16 self._provided_tools_by_name = {} @@ -1019,14 +1025,13 @@ async def _get_or_fetch_skill( return None if invocation_id: - if invocation_id not in self._invocation_cache: + if invocation_id not in self._fetched_skill_cache: # Enforce bounded cache (FIFO eviction) - if len(self._invocation_cache) >= self._max_cache_turns: - oldest = next(iter(self._invocation_cache)) - self._invocation_cache.pop(oldest) - self._invocation_cache[invocation_id] = {} + if len(self._fetched_skill_cache) >= self._max_cache_turns: + self._fetched_skill_cache.popitem(last=False) + self._fetched_skill_cache[invocation_id] = {} - turn_cache = self._invocation_cache[invocation_id] + turn_cache = self._fetched_skill_cache[invocation_id] if skill_name in turn_cache: cached = turn_cache[skill_name] if isinstance(cached, asyncio.Future): @@ -1076,6 +1081,16 @@ async def process_llm_request( llm_request.append_instructions(instructions) + @override + async def close(self) -> None: + """Performs cleanup and releases resources held by the toolset.""" + for turn_cache in self._fetched_skill_cache.values(): + for cached in turn_cache.values(): + if isinstance(cached, asyncio.Future) and not cached.done(): + cached.cancel() + self._fetched_skill_cache.clear() + await super().close() + def __getattr__(name: str) -> Any: if name == "DEFAULT_SKILL_SYSTEM_INSTRUCTION": diff --git a/src/google/adk/utils/context_utils.py b/src/google/adk/utils/context_utils.py index c0fde29ae4..bd80fa2ff3 100644 --- a/src/google/adk/utils/context_utils.py +++ b/src/google/adk/utils/context_utils.py @@ -21,6 +21,7 @@ from __future__ import annotations from contextlib import aclosing +import functools import inspect 
import typing from typing import Any @@ -62,6 +63,7 @@ def _is_context_type(annotation: Any) -> bool: return annotation is Context +@functools.lru_cache(maxsize=1024) def find_context_parameter(func: Callable[..., Any]) -> str | None: """Find the parameter name that has a Context type annotation. diff --git a/tests/unittests/a2a/converters/test_to_adk.py b/tests/unittests/a2a/converters/test_to_adk.py index 12eaf2a75a..3ab60f097d 100644 --- a/tests/unittests/a2a/converters/test_to_adk.py +++ b/tests/unittests/a2a/converters/test_to_adk.py @@ -30,6 +30,7 @@ from google.adk.a2a.converters.to_adk_event import convert_a2a_message_to_event from google.adk.a2a.converters.to_adk_event import convert_a2a_status_update_to_event from google.adk.a2a.converters.to_adk_event import convert_a2a_task_to_event +from google.adk.a2a.converters.to_adk_event import MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT from google.adk.a2a.converters.utils import _get_adk_metadata_key from google.adk.agents.invocation_context import InvocationContext from google.genai import types as genai_types @@ -330,12 +331,95 @@ def test_convert_a2a_task_to_event_merges_status_and_artifact_actions(self): assert event.actions.state_delta == {"saved_key": "saved-value"} assert event.actions.transfer_to_agent == "agent-2" assert event.content is not None - assert event.content.parts == [mock_genai_part] + assert ( + event.content.parts[0].function_call.name + == MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT + ) + assert ( + event.content.parts[0].function_call.args["input_required"] + == "need input" + ) + + def test_convert_a2a_task_to_event_multiple_parts_replaces_last_text(self): + """Test converting A2A task with multiple text parts, only replacing the last text.""" + part1 = Mock(spec=A2APart) + part1.root = Mock(spec=TextPart) + part1.root.metadata = {} + part2 = Mock(spec=A2APart) + part2.root = Mock(spec=TextPart) + part2.root.metadata = {} + + task = Task( + id="task-1", + context_id="context-1", + 
kind="task", + status=TaskStatus( + state=TaskState.input_required, + timestamp="now", + message=Message( + message_id="m1", + role="agent", + parts=[part1, part2], + ), + ), + ) + + mock_genai_part_1 = genai_types.Part.from_text(text="Part 1") + mock_genai_part_2 = genai_types.Part.from_text(text="Part 2") - def test_convert_a2a_task_to_event_none(self): - """Test convert_a2a_task_to_event with None.""" - with pytest.raises(ValueError, match="A2A task cannot be None"): - convert_a2a_task_to_event(None) + part_converter_mock = Mock() + part_converter_mock.side_effect = [[mock_genai_part_1], [mock_genai_part_2]] + + event = convert_a2a_task_to_event( + task, + author="test-author", + invocation_context=self.mock_context, + part_converter=part_converter_mock, + ) + + assert event is not None + assert event.content is not None + assert len(event.content.parts) == 2 + assert event.content.parts[0].text == "Part 1" + assert ( + event.content.parts[1].function_call.name + == MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT + ) + + def test_convert_a2a_task_to_event_no_text_parts(self): + """Test converting A2A task with no text parts should not inject function call.""" + part1 = Mock(spec=A2APart) + part1.root = Mock() # Not a TextPart + part1.root.metadata = {} + + task = Task( + id="task-1", + context_id="context-1", + kind="task", + status=TaskStatus( + state=TaskState.input_required, + timestamp="now", + message=Message( + message_id="m1", + role="agent", + parts=[part1], + ), + ), + ) + mock_image_part = genai_types.Part( + inline_data=genai_types.Blob(mime_type="image/jpeg", data=b"fake") + ) + + event = convert_a2a_task_to_event( + task, + author="test-author", + invocation_context=self.mock_context, + part_converter=Mock(return_value=[mock_image_part]), + ) + + assert event is not None + assert event.content is not None + assert event.content.parts == [mock_image_part] def test_convert_a2a_status_update_to_event_success(self): """Test successful conversion of A2A 
status update to Event.""" diff --git a/tests/unittests/a2a/integration/test_client_server.py b/tests/unittests/a2a/integration/test_client_server.py index bd1d72f617..18b13d05d2 100644 --- a/tests/unittests/a2a/integration/test_client_server.py +++ b/tests/unittests/a2a/integration/test_client_server.py @@ -14,11 +14,19 @@ """Integration tests for A2A client-server interaction.""" +import logging +from unittest.mock import AsyncMock + +from a2a.server.apps.jsonrpc.fastapi_app import A2AFastAPIApplication +from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types import Message as A2AMessage from a2a.types import Part as A2APart from a2a.types import Task from a2a.types import TaskState +from a2a.types import TaskStatus from a2a.types import TextPart +from google.adk.a2a.agent.interceptors.new_integration_extension import _NEW_A2A_ADK_INTEGRATION_EXTENSION +from google.adk.a2a.converters.to_adk_event import MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT from google.adk.a2a.executor.config import A2aAgentExecutorConfig from google.adk.a2a.executor.interceptors.include_artifacts_in_a2a_event import include_artifacts_in_a2a_event_interceptor from google.adk.agents.remote_a2a_agent import A2A_METADATA_PREFIX @@ -32,8 +40,11 @@ from .client import create_a2a_client from .client import create_client +from .server import agent_card from .server import create_server_app +logger = logging.getLogger("google_adk." 
+ __name__) + def create_streaming_mock_run_async(received_requests: list): """Creates a mock_run_async that streams multiple chunks.""" @@ -636,3 +647,161 @@ async def mock_run_async(**kwargs): assert task.artifacts[2].artifact_id == "artifact2_1" assert task.artifacts[2].name == "artifact2" assert task.artifacts[2].parts[0].root.text == "artifact content" + + +@pytest.mark.asyncio +async def test_user_follow_up_sends_task_id_with_input_required(): + """Test that client follow-up sends the same task_id.""" + + task_id = "mocked-task-id-123" + context_id = "mocked-context-id-456" + mock_task = Task( + id=task_id, + context_id=context_id, + kind="task", + status=TaskStatus( + state=TaskState.input_required, + message=A2AMessage( + message_id="mocked-message-id-789", + role="user", + parts=[A2APart(root=TextPart(text="Input required"))], + ), + ), + metadata={_NEW_A2A_ADK_INTEGRATION_EXTENSION: True}, + ) + + mock_handler = AsyncMock(spec=RequestHandler) + # First call returns input_required, second call completes + mock_handler.on_message_send.side_effect = [ + mock_task, + Task( + id=task_id, + context_id=context_id, + kind="task", + status=TaskStatus(state=TaskState.completed), + metadata={_NEW_A2A_ADK_INTEGRATION_EXTENSION: True}, + ), + ] + + app = A2AFastAPIApplication( + agent_card=agent_card, http_handler=mock_handler + ).build() + agent = create_client(app, streaming=False) + + session_service = InMemorySessionService() + await session_service.create_session( + app_name="ClientApp", user_id="test_user", session_id="test_session" + ) + client_runner = Runner( + app_name="ClientApp", agent=agent, session_service=session_service + ) + + # First Turn + new_message_1 = types.Content(parts=[types.Part(text="Turn 1")], role="user") + found_call_id = None + async for event in client_runner.run_async( + user_id="test_user", session_id="test_session", new_message=new_message_1 + ): + for call in event.get_function_calls(): + if call.name == 
MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT: + found_call_id = call.id + + assert found_call_id is not None + + # Second Turn (Follow-up) + function_response = types.FunctionResponse( + id=found_call_id, + name=MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT, + response={"result": "Turn 2"}, + ) + new_message_2 = types.Content( + parts=[types.Part(function_response=function_response)], role="user" + ) + async for _ in client_runner.run_async( + user_id="test_user", session_id="test_session", new_message=new_message_2 + ): + pass + + assert mock_handler.on_message_send.call_count == 2 + # Second call args + call_args_2 = mock_handler.on_message_send.call_args_list[1] + params_2 = call_args_2[0][0] + assert params_2.message.task_id == task_id + + +@pytest.mark.asyncio +async def test_user_follow_up_sends_task_id_with_input_required_legacy_impl(): + """Test that client follow-up sends the same task_id.""" + + task_id = "mocked-task-id-123" + context_id = "mocked-context-id-456" + mock_task = Task( + id=task_id, + context_id=context_id, + kind="task", + status=TaskStatus( + state=TaskState.input_required, + message=A2AMessage( + message_id="mocked-message-id-789", + role="user", + parts=[A2APart(root=TextPart(text="Input required"))], + ), + ), + ) + + mock_handler = AsyncMock(spec=RequestHandler) + # First call returns input_required, second call completes + mock_handler.on_message_send.side_effect = [ + mock_task, + Task( + id=task_id, + context_id=context_id, + kind="task", + status=TaskStatus(state=TaskState.completed), + ), + ] + + app = A2AFastAPIApplication( + agent_card=agent_card, http_handler=mock_handler + ).build() + agent = create_client(app, streaming=False) + + session_service = InMemorySessionService() + await session_service.create_session( + app_name="ClientApp", user_id="test_user", session_id="test_session" + ) + client_runner = Runner( + app_name="ClientApp", agent=agent, session_service=session_service + ) + + # First Turn + new_message_1 = 
types.Content(parts=[types.Part(text="Turn 1")], role="user") + found_call_id = None + async for event in client_runner.run_async( + user_id="test_user", session_id="test_session", new_message=new_message_1 + ): + for call in event.get_function_calls(): + if call.name == MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT: + found_call_id = call.id + + assert found_call_id is not None + + # Second Turn (Follow-up) + function_response = types.FunctionResponse( + id=found_call_id, + name=MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT, + response={"result": "Turn 2"}, + ) + new_message_2 = types.Content( + parts=[types.Part(function_response=function_response)], role="user" + ) + async for _ in client_runner.run_async( + user_id="test_user", session_id="test_session", new_message=new_message_2 + ): + pass + + assert mock_handler.on_message_send.call_count == 2 + # Second call args + call_args_2 = mock_handler.on_message_send.call_args_list[1] + params_2 = call_args_2[0][0] + assert params_2.message.task_id == task_id diff --git a/tests/unittests/a2a/utils/test_agent_to_a2a.py b/tests/unittests/a2a/utils/test_agent_to_a2a.py index b5012a6535..752df0ded3 100644 --- a/tests/unittests/a2a/utils/test_agent_to_a2a.py +++ b/tests/unittests/a2a/utils/test_agent_to_a2a.py @@ -357,7 +357,7 @@ def test_to_a2a_creates_runner_with_correct_services( @patch("google.adk.a2a.utils.agent_to_a2a.AgentCardBuilder") @patch("google.adk.a2a.utils.agent_to_a2a.Starlette") @patch("google.adk.a2a.utils.agent_to_a2a.Runner") - async def test_create_runner_function_creates_runner_correctly( + def test_create_runner_function_creates_runner_correctly( self, mock_runner_class, mock_starlette_class, @@ -391,7 +391,7 @@ async def test_create_runner_function_creates_runner_correctly( runner_func = call_args[1]["runner"] # Call the runner function to verify it creates Runner correctly - runner_result = await runner_func() + runner_result = runner_func() # Verify Runner was created with correct parameters 
mock_runner_class.assert_called_once_with( @@ -420,7 +420,7 @@ async def test_create_runner_function_creates_runner_correctly( @patch("google.adk.a2a.utils.agent_to_a2a.AgentCardBuilder") @patch("google.adk.a2a.utils.agent_to_a2a.Starlette") @patch("google.adk.a2a.utils.agent_to_a2a.Runner") - async def test_create_runner_function_with_agent_without_name( + def test_create_runner_function_with_agent_without_name( self, mock_runner_class, mock_starlette_class, @@ -455,7 +455,7 @@ async def test_create_runner_function_with_agent_without_name( runner_func = call_args[1]["runner"] # Call the runner function to verify it creates Runner correctly - await runner_func() + runner_func() # Verify Runner was created with default app_name when agent has no name mock_runner_class.assert_called_once_with( diff --git a/tests/unittests/agents/test_output_key_visibility.py b/tests/unittests/agents/test_output_key_visibility.py new file mode 100644 index 0000000000..e20a676241 --- /dev/null +++ b/tests/unittests/agents/test_output_key_visibility.py @@ -0,0 +1,180 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for LlmAgent output_key visibility in callbacks.""" + +from google.adk.agents.callback_context import CallbackContext +from google.adk.agents.live_request_queue import LiveRequestQueue +from google.adk.agents.llm_agent import LlmAgent +from google.adk.agents.sequential_agent import SequentialAgent +from google.adk.events.event import Event +from google.adk.flows.llm_flows.auto_flow import AutoFlow +from google.genai import types +import pytest +from pytest_mock import MockerFixture + +from .. import testing_utils + +# Standard MockModel will be used instead of SafeMockModel + + +@pytest.mark.asyncio +async def test_output_key_visibility_in_after_agent_callback(): + """Test that output_key state delta is visible in after_agent_callback.""" + mock_response = "Hello! How can I help you?" + mock_model = testing_utils.MockModel.create(responses=[mock_response]) + + callback_called = False + captured_state_value = None + captured_session_state_value = None + + async def check_output_key(callback_context: CallbackContext): + nonlocal callback_called, captured_state_value, captured_session_state_value + callback_called = True + captured_state_value = callback_context.state.get("result", "NOT_FOUND") + captured_session_state_value = callback_context.session.state.get( + "result", "NOT_IN_RAW" + ) + + agent = LlmAgent( + name="my_agent", + model=mock_model, + instruction="Reply with a short greeting.", + output_key="result", + after_agent_callback=check_output_key, + ) + + runner = testing_utils.InMemoryRunner(root_agent=agent) + + events = await runner.run_async(new_message="hello") + + assert callback_called, "Callback was not called" + + assert ( + captured_state_value == mock_response + ), f"Expected {mock_response}, got {captured_state_value}" + assert ( + captured_session_state_value == mock_response + ), f"Expected {mock_response}, got {captured_session_state_value}" + + +@pytest.mark.asyncio +async def test_output_key_visibility_in_run_live(mocker: 
MockerFixture): + """Test that output_key state delta is visible in after_agent_callback in run_live.""" + mock_response = "Hello! How can I help you?" + mock_model = testing_utils.MockModel.create(responses=[mock_response]) + + callback_called = False + captured_state_value = None + captured_session_state_value = None + + async def check_output_key(callback_context: CallbackContext): + nonlocal callback_called, captured_state_value, captured_session_state_value + callback_called = True + captured_state_value = callback_context.state.get("result", "NOT_FOUND") + captured_session_state_value = callback_context.session.state.get( + "result", "NOT_IN_RAW" + ) + + agent = LlmAgent( + name="my_agent", + model=mock_model, + instruction="Reply with a short greeting.", + output_key="result", + after_agent_callback=check_output_key, + ) + + async def mock_auto_flow_run_live(self, ctx): + yield Event( + id=Event.new_id(), + invocation_id=ctx.invocation_id, + author=ctx.agent.name, + content=types.Content(parts=[types.Part(text=mock_response)]), + ) + + mocker.patch.object(AutoFlow, "run_live", mock_auto_flow_run_live) + + runner = testing_utils.InMemoryRunner(root_agent=agent) + live_queue = LiveRequestQueue() + + agen = runner.runner.run_live( + user_id="test_user", + session_id=runner.session.id, + live_request_queue=live_queue, + ) + + # Send a message to trigger the agent + live_queue.send_content( + types.Content(role="user", parts=[types.Part(text="hello")]) + ) + + live_queue.close() + + async for event in agen: + pass + + assert callback_called, "Callback was not called" + assert ( + captured_state_value == mock_response + ), f"Expected {mock_response}, got {captured_state_value}" + assert ( + captured_session_state_value == mock_response + ), f"Expected {mock_response}, got {captured_session_state_value}" + + +@pytest.mark.asyncio +async def test_output_key_visibility_in_sequential_agent(): + """Test that output_key state delta is visible in next agent's 
before_agent_callback.""" + mock_response = "Hello from agent 1!" + mock_model = testing_utils.MockModel.create( + responses=[mock_response, "Hello from agent 2!"] + ) + + callback_called = False + captured_session_state_value = None + + async def check_before_agent(callback_context: CallbackContext): + nonlocal callback_called, captured_session_state_value + callback_called = True + captured_session_state_value = callback_context.session.state.get( + "result", "NOT_FOUND" + ) + + agent_1 = LlmAgent( + name="agent_1", + model=mock_model, + instruction="Reply with a short greeting.", + output_key="result", + ) + + agent_2 = LlmAgent( + name="agent_2", + model=mock_model, + instruction="Reply with a short greeting.", + before_agent_callback=check_before_agent, + ) + + sequential_agent = SequentialAgent( + name="seq_agent", + sub_agents=[agent_1, agent_2], + ) + + runner = testing_utils.InMemoryRunner(root_agent=sequential_agent) + + events = await runner.run_async(new_message="hello") + + assert callback_called, "Callback was not called" + assert ( + captured_session_state_value == mock_response + ), f"Expected {mock_response}, got {captured_session_state_value}" diff --git a/tests/unittests/agents/test_remote_a2a_agent.py b/tests/unittests/agents/test_remote_a2a_agent.py index f5e92bc46d..073ad36d9f 100644 --- a/tests/unittests/agents/test_remote_a2a_agent.py +++ b/tests/unittests/agents/test_remote_a2a_agent.py @@ -573,6 +573,9 @@ def test_create_a2a_request_for_user_function_response_success(self): mock_function_event.custom_metadata = { A2A_METADATA_PREFIX + "task_id": "task-123" } + mock_function_event.content = Mock() + mock_function_event.content.parts = [Mock()] + mock_function_event.get_function_calls.return_value = [] # Mock latest event with function response - set proper author mock_latest_event = Mock() @@ -1372,6 +1375,9 @@ def test_create_a2a_request_for_user_function_response_success(self): mock_function_event.custom_metadata = { A2A_METADATA_PREFIX + 
"task_id": "task-123" } + mock_function_event.content = Mock() + mock_function_event.content.parts = [Mock()] + mock_function_event.get_function_calls.return_value = [] # Mock latest event with function response - set proper author mock_latest_event = Mock() diff --git a/tests/unittests/apps/test_compaction.py b/tests/unittests/apps/test_compaction.py index ba6d99ef38..dde8a51b4b 100644 --- a/tests/unittests/apps/test_compaction.py +++ b/tests/unittests/apps/test_compaction.py @@ -23,12 +23,15 @@ from google.adk.apps.compaction import _run_compaction_for_sliding_window import google.adk.apps.compaction as compaction_module from google.adk.apps.llm_event_summarizer import LlmEventSummarizer +from google.adk.auth.auth_schemes import CustomAuthScheme +from google.adk.auth.auth_tool import AuthConfig from google.adk.events.event import Event from google.adk.events.event_actions import EventActions from google.adk.events.event_actions import EventCompaction from google.adk.flows.llm_flows import contents from google.adk.sessions.base_session_service import BaseSessionService from google.adk.sessions.session import Session +from google.adk.tools.tool_confirmation import ToolConfirmation from google.genai import types from google.genai.types import Content from google.genai.types import Part @@ -1319,3 +1322,455 @@ async def test_completed_function_call_pair_is_still_compacted(self): self.assertIn('inv1', compacted_inv_ids) self.assertEqual(compacted_inv_ids.count('inv2'), 2) self.assertIn('inv3', compacted_inv_ids) + + def _create_hitl_confirmation_event( + self, + timestamp: float, + invocation_id: str, + function_call_id: str, + ) -> Event: + """Creates a function response event with a tool confirmation request.""" + return Event( + timestamp=timestamp, + invocation_id=invocation_id, + author='agent', + content=Content( + role='user', + parts=[ + Part( + function_response=types.FunctionResponse( + id=function_call_id, + name='tool', + response={ + 'error': 'This tool 
call requires confirmation.' + }, + ) + ) + ], + ), + actions=EventActions( + requested_tool_confirmations={ + function_call_id: ToolConfirmation( + hint='Please confirm this action.' + ) + }, + ), + ) + + def _create_hitl_auth_event( + self, + timestamp: float, + invocation_id: str, + function_call_id: str, + ) -> Event: + """Creates a function response event with an auth credential request.""" + return Event( + timestamp=timestamp, + invocation_id=invocation_id, + author='agent', + content=Content( + role='user', + parts=[ + Part( + function_response=types.FunctionResponse( + id=function_call_id, + name='tool', + response={'error': 'Auth required.'}, + ) + ) + ], + ), + actions=EventActions( + requested_auth_configs={ + function_call_id: AuthConfig( + auth_scheme=CustomAuthScheme(type='custom'), + ) + }, + ), + ) + + async def test_sliding_window_excludes_hitl_confirmation_events(self): + """Sliding-window compaction stops before tool confirmation events.""" + app = App( + name='test', + root_agent=Mock(spec=BaseAgent), + events_compaction_config=EventsCompactionConfig( + summarizer=self.mock_compactor, + compaction_interval=2, + overlap_size=0, + ), + ) + # inv1: text, inv2: call + HITL confirmation response, inv3: text + # The HITL event (confirmation response) blocks compaction at that point. + # The preceding function call event is not HITL itself and gets compacted. 
+ events = [ + self._create_event(1.0, 'inv1', 'e1'), + self._create_function_call_event(2.0, 'inv2', 'call-1'), + self._create_hitl_confirmation_event(3.0, 'inv2', 'call-1'), + self._create_event(4.0, 'inv3', 'e3'), + ] + session = Session(app_name='test', user_id='u1', id='s1', events=events) + + mock_compacted_event = self._create_compacted_event( + 1.0, 2.0, 'Summary before hitl' + ) + self.mock_compactor.maybe_summarize_events.return_value = ( + mock_compacted_event + ) + + await _run_compaction_for_sliding_window( + app, session, self.mock_session_service + ) + + compacted_events_arg = self.mock_compactor.maybe_summarize_events.call_args[ + 1 + ]['events'] + compacted_inv_ids = [e.invocation_id for e in compacted_events_arg] + # inv1 text + inv2 function call are compacted; HITL response is protected. + self.assertEqual(compacted_inv_ids, ['inv1', 'inv2']) + + async def test_sliding_window_excludes_hitl_auth_events(self): + """Sliding-window compaction stops before auth credential events.""" + app = App( + name='test', + root_agent=Mock(spec=BaseAgent), + events_compaction_config=EventsCompactionConfig( + summarizer=self.mock_compactor, + compaction_interval=2, + overlap_size=0, + ), + ) + events = [ + self._create_event(1.0, 'inv1', 'e1'), + self._create_function_call_event(2.0, 'inv2', 'call-1'), + self._create_hitl_auth_event(3.0, 'inv2', 'call-1'), + self._create_event(4.0, 'inv3', 'e3'), + ] + session = Session(app_name='test', user_id='u1', id='s1', events=events) + + mock_compacted_event = self._create_compacted_event( + 1.0, 2.0, 'Summary before auth' + ) + self.mock_compactor.maybe_summarize_events.return_value = ( + mock_compacted_event + ) + + await _run_compaction_for_sliding_window( + app, session, self.mock_session_service + ) + + compacted_events_arg = self.mock_compactor.maybe_summarize_events.call_args[ + 1 + ]['events'] + compacted_inv_ids = [e.invocation_id for e in compacted_events_arg] + self.assertEqual(compacted_inv_ids, ['inv1', 
'inv2']) + + async def test_token_threshold_excludes_hitl_confirmation_events(self): + """Token-threshold compaction stops before tool confirmation events.""" + app = App( + name='test', + root_agent=Mock(spec=BaseAgent), + events_compaction_config=EventsCompactionConfig( + summarizer=self.mock_compactor, + compaction_interval=999, + overlap_size=0, + token_threshold=50, + event_retention_size=0, + ), + ) + events = [ + self._create_event(1.0, 'inv1', 'e1'), + self._create_function_call_event(2.0, 'inv2', 'call-1'), + self._create_hitl_confirmation_event(3.0, 'inv2', 'call-1'), + self._create_event(4.0, 'inv3', 'e3', prompt_token_count=100), + ] + session = Session(app_name='test', user_id='u1', id='s1', events=events) + + mock_compacted_event = self._create_compacted_event( + 1.0, 2.0, 'Summary inv1-inv2' + ) + self.mock_compactor.maybe_summarize_events.return_value = ( + mock_compacted_event + ) + + await _run_compaction_for_sliding_window( + app, session, self.mock_session_service + ) + + compacted_events_arg = self.mock_compactor.maybe_summarize_events.call_args[ + 1 + ]['events'] + compacted_inv_ids = [e.invocation_id for e in compacted_events_arg] + self.assertEqual(compacted_inv_ids, ['inv1', 'inv2']) + + async def test_token_threshold_excludes_hitl_auth_events(self): + """Token-threshold compaction stops before auth credential events.""" + app = App( + name='test', + root_agent=Mock(spec=BaseAgent), + events_compaction_config=EventsCompactionConfig( + summarizer=self.mock_compactor, + compaction_interval=999, + overlap_size=0, + token_threshold=50, + event_retention_size=0, + ), + ) + events = [ + self._create_event(1.0, 'inv1', 'e1'), + self._create_function_call_event(2.0, 'inv2', 'call-1'), + self._create_hitl_auth_event(3.0, 'inv2', 'call-1'), + self._create_event(4.0, 'inv3', 'e3', prompt_token_count=100), + ] + session = Session(app_name='test', user_id='u1', id='s1', events=events) + + mock_compacted_event = self._create_compacted_event( + 1.0, 2.0, 
'Summary inv1-inv2' + ) + self.mock_compactor.maybe_summarize_events.return_value = ( + mock_compacted_event + ) + + await _run_compaction_for_sliding_window( + app, session, self.mock_session_service + ) + + compacted_events_arg = self.mock_compactor.maybe_summarize_events.call_args[ + 1 + ]['events'] + compacted_inv_ids = [e.invocation_id for e in compacted_events_arg] + self.assertEqual(compacted_inv_ids, ['inv1', 'inv2']) + + async def test_hitl_event_at_start_blocks_all_compaction(self): + """If the first candidate event has HITL, nothing is compacted.""" + app = App( + name='test', + root_agent=Mock(spec=BaseAgent), + events_compaction_config=EventsCompactionConfig( + summarizer=self.mock_compactor, + compaction_interval=2, + overlap_size=0, + ), + ) + # The very first event is an HITL confirmation (no preceding function call). + events = [ + self._create_hitl_confirmation_event(1.0, 'inv1', 'call-1'), + self._create_event(2.0, 'inv2', 'e2'), + self._create_event(3.0, 'inv3', 'e3'), + ] + session = Session(app_name='test', user_id='u1', id='s1', events=events) + + await _run_compaction_for_sliding_window( + app, session, self.mock_session_service + ) + + self.mock_compactor.maybe_summarize_events.assert_not_called() + self.mock_session_service.append_event.assert_not_called() + + async def test_events_before_hitl_are_still_compacted(self): + """Events before the HITL event are compacted normally.""" + app = App( + name='test', + root_agent=Mock(spec=BaseAgent), + events_compaction_config=EventsCompactionConfig( + summarizer=self.mock_compactor, + compaction_interval=2, + overlap_size=0, + ), + ) + # inv1, inv2: text events, inv3: call + HITL confirmation, inv4: text + # The HITL event at index 3 blocks compaction; events before it are safe. 
+ events = [ + self._create_event(1.0, 'inv1', 'e1'), + self._create_event(2.0, 'inv2', 'e2'), + self._create_function_call_event(3.0, 'inv3', 'call-1'), + self._create_hitl_confirmation_event(4.0, 'inv3', 'call-1'), + self._create_event(5.0, 'inv4', 'e4'), + ] + session = Session(app_name='test', user_id='u1', id='s1', events=events) + + mock_compacted_event = self._create_compacted_event( + 1.0, 3.0, 'Summary inv1-inv3' + ) + self.mock_compactor.maybe_summarize_events.return_value = ( + mock_compacted_event + ) + + await _run_compaction_for_sliding_window( + app, session, self.mock_session_service + ) + + compacted_events_arg = self.mock_compactor.maybe_summarize_events.call_args[ + 1 + ]['events'] + compacted_inv_ids = [e.invocation_id for e in compacted_events_arg] + # inv1, inv2 (text) + inv3 function call compact; HITL response is not. + self.assertEqual(compacted_inv_ids, ['inv1', 'inv2', 'inv3']) + + async def test_resolved_hitl_confirmation_is_compactable(self): + """A HITL confirmation followed by a resolved tool response is compactable.""" + app = App( + name='test', + root_agent=Mock(spec=BaseAgent), + events_compaction_config=EventsCompactionConfig( + summarizer=self.mock_compactor, + compaction_interval=2, + overlap_size=0, + ), + ) + # inv1: text, inv2: call + HITL request + resolved response (same call-1 + # id), inv3: text. The resolved HITL is safe to compact. 
+ events = [ + self._create_event(1.0, 'inv1', 'e1'), + self._create_function_call_event(2.0, 'inv2', 'call-1'), + self._create_hitl_confirmation_event(3.0, 'inv2', 'call-1'), + self._create_function_response_event(4.0, 'inv2', 'call-1'), + self._create_event(5.0, 'inv3', 'e3'), + ] + session = Session(app_name='test', user_id='u1', id='s1', events=events) + + mock_compacted_event = self._create_compacted_event( + 1.0, 5.0, 'Summary including resolved hitl' + ) + self.mock_compactor.maybe_summarize_events.return_value = ( + mock_compacted_event + ) + + await _run_compaction_for_sliding_window( + app, session, self.mock_session_service + ) + + compacted_events_arg = self.mock_compactor.maybe_summarize_events.call_args[ + 1 + ]['events'] + compacted_inv_ids = [e.invocation_id for e in compacted_events_arg] + # Resolved HITL doesn't block; all events through inv3 compact together. + self.assertEqual( + compacted_inv_ids, ['inv1', 'inv2', 'inv2', 'inv2', 'inv3'] + ) + + async def test_resolved_hitl_auth_is_compactable(self): + """A HITL auth request followed by a resolved tool response is compactable.""" + app = App( + name='test', + root_agent=Mock(spec=BaseAgent), + events_compaction_config=EventsCompactionConfig( + summarizer=self.mock_compactor, + compaction_interval=2, + overlap_size=0, + ), + ) + events = [ + self._create_event(1.0, 'inv1', 'e1'), + self._create_function_call_event(2.0, 'inv2', 'call-1'), + self._create_hitl_auth_event(3.0, 'inv2', 'call-1'), + self._create_function_response_event(4.0, 'inv2', 'call-1'), + self._create_event(5.0, 'inv3', 'e3'), + ] + session = Session(app_name='test', user_id='u1', id='s1', events=events) + + mock_compacted_event = self._create_compacted_event( + 1.0, 5.0, 'Summary including resolved auth' + ) + self.mock_compactor.maybe_summarize_events.return_value = ( + mock_compacted_event + ) + + await _run_compaction_for_sliding_window( + app, session, self.mock_session_service + ) + + compacted_events_arg = 
self.mock_compactor.maybe_summarize_events.call_args[ + 1 + ]['events'] + compacted_inv_ids = [e.invocation_id for e in compacted_events_arg] + self.assertEqual( + compacted_inv_ids, ['inv1', 'inv2', 'inv2', 'inv2', 'inv3'] + ) + + async def test_sliding_window_resolved_hitl_outside_window_is_compactable( + self, + ): + """A HITL whose resolver lives past the truncation point is compactable.""" + app = App( + name='test', + root_agent=Mock(spec=BaseAgent), + events_compaction_config=EventsCompactionConfig( + summarizer=self.mock_compactor, + compaction_interval=2, + overlap_size=0, + ), + ) + # inv1 text, inv2 call_a, inv3 HITL_a, inv4 call_b (unanswered), + # inv5 resolver_a. _truncate_events_before_pending_function_call prunes + # at inv4 because call_b has no response in the session, leaving + # resolver_a outside events_to_compact. The HITL still has to be + # recognized as resolved via the full-session lookup. + events = [ + self._create_event(1.0, 'inv1', 'e1'), + self._create_function_call_event(2.0, 'inv2', 'call-a'), + self._create_hitl_confirmation_event(3.0, 'inv3', 'call-a'), + self._create_function_call_event(4.0, 'inv4', 'call-b'), + self._create_function_response_event(5.0, 'inv5', 'call-a'), + ] + session = Session(app_name='test', user_id='u1', id='s1', events=events) + + mock_compacted_event = self._create_compacted_event( + 1.0, 3.0, 'Summary including resolved hitl' + ) + self.mock_compactor.maybe_summarize_events.return_value = ( + mock_compacted_event + ) + + await _run_compaction_for_sliding_window( + app, session, self.mock_session_service + ) + + compacted_events_arg = self.mock_compactor.maybe_summarize_events.call_args[ + 1 + ]['events'] + compacted_inv_ids = [e.invocation_id for e in compacted_events_arg] + self.assertEqual(compacted_inv_ids, ['inv1', 'inv2', 'inv3']) + + async def test_token_threshold_resolved_hitl_outside_window_is_compactable( + self, + ): + """Token-threshold: HITL with resolver past the truncation point compacts.""" 
+ app = App( + name='test', + root_agent=Mock(spec=BaseAgent), + events_compaction_config=EventsCompactionConfig( + summarizer=self.mock_compactor, + compaction_interval=999, + overlap_size=0, + token_threshold=50, + event_retention_size=0, + ), + ) + events = [ + self._create_event(1.0, 'inv1', 'e1'), + self._create_function_call_event(2.0, 'inv2', 'call-a'), + self._create_hitl_confirmation_event(3.0, 'inv3', 'call-a'), + self._create_function_call_event(4.0, 'inv4', 'call-b'), + self._create_function_response_event( + 5.0, 'inv5', 'call-a', prompt_token_count=100 + ), + ] + session = Session(app_name='test', user_id='u1', id='s1', events=events) + + mock_compacted_event = self._create_compacted_event( + 1.0, 3.0, 'Summary including resolved hitl' + ) + self.mock_compactor.maybe_summarize_events.return_value = ( + mock_compacted_event + ) + + await _run_compaction_for_sliding_window( + app, session, self.mock_session_service + ) + + compacted_events_arg = self.mock_compactor.maybe_summarize_events.call_args[ + 1 + ]['events'] + compacted_inv_ids = [e.invocation_id for e in compacted_events_arg] + self.assertEqual(compacted_inv_ids, ['inv1', 'inv2', 'inv3']) diff --git a/tests/unittests/cli/utils/test_cli_deploy.py b/tests/unittests/cli/utils/test_cli_deploy.py index a1e66139f8..506e274845 100644 --- a/tests/unittests/cli/utils/test_cli_deploy.py +++ b/tests/unittests/cli/utils/test_cli_deploy.py @@ -147,7 +147,10 @@ def test_resolve_project_from_gcloud_fails( "gs://a", "rag://m", None, - "--session_db_url=sqlite://s --artifact_storage_uri=gs://a", + ( + "--session_service_uri=sqlite://s --artifact_service_uri=gs://a" + " --memory_service_uri=rag://m" + ), ), ( "0.5.0", @@ -155,7 +158,10 @@ def test_resolve_project_from_gcloud_fails( "gs://a", "rag://m", None, - "--session_db_url=sqlite://s", + ( + "--session_service_uri=sqlite://s --artifact_service_uri=gs://a" + " --memory_service_uri=rag://m" + ), ), ( "1.3.0", @@ -179,7 +185,7 @@ def 
test_resolve_project_from_gcloud_fails( "gs://a", None, None, - "--artifact_storage_uri=gs://a", + "--artifact_service_uri=gs://a", ), ( "1.21.0", diff --git a/tests/unittests/cli/utils/test_cli_tools_click.py b/tests/unittests/cli/utils/test_cli_tools_click.py index 7c642dbbe9..ad47c9ecbe 100644 --- a/tests/unittests/cli/utils/test_cli_tools_click.py +++ b/tests/unittests/cli/utils/test_cli_tools_click.py @@ -596,41 +596,6 @@ def test_cli_web_passes_service_uris( assert called_kwargs.get("memory_service_uri") == "rag://mycorpus" -@pytest.mark.unmute_click -def test_cli_web_warns_and_maps_deprecated_uris( - tmp_path: Path, - _patch_uvicorn: _Recorder, - monkeypatch: pytest.MonkeyPatch, -) -> None: - """`adk web` should accept deprecated URI flags with warnings.""" - agents_dir = tmp_path / "agents" - agents_dir.mkdir() - - mock_get_app = _Recorder() - monkeypatch.setattr(cli_tools_click, "get_fast_api_app", mock_get_app) - - runner = CliRunner() - result = runner.invoke( - cli_tools_click.main, - [ - "web", - str(agents_dir), - "--session_db_url", - "sqlite:///deprecated.db", - "--artifact_storage_uri", - "gs://deprecated", - ], - ) - - assert result.exit_code == 0 - called_kwargs = mock_get_app.calls[0][1] - assert called_kwargs.get("session_service_uri") == "sqlite:///deprecated.db" - assert called_kwargs.get("artifact_service_uri") == "gs://deprecated" - # Check output for deprecation warnings (CliRunner captures both stdout and stderr) - assert "--session_db_url" in result.output - assert "--artifact_storage_uri" in result.output - - def test_cli_eval_with_eval_set_file_path( mock_load_eval_set_from_file, mock_get_root_agent, diff --git a/tests/unittests/evaluation/simulation/test_per_turn_user_simulation_quality_v1.py b/tests/unittests/evaluation/simulation/test_per_turn_user_simulation_quality_v1.py index 6798143df3..be0a0394c5 100644 --- a/tests/unittests/evaluation/simulation/test_per_turn_user_simulation_quality_v1.py +++ 
b/tests/unittests/evaluation/simulation/test_per_turn_user_simulation_quality_v1.py @@ -634,10 +634,10 @@ def test_aggregate_conversation_percentage_below_threshold_produces_failure(): async def test_evaluate_invocations_all_pass(): evaluator = _create_test_evaluator() - async def sample_llm_valid(*args, **kwargs): + async def sample_llm_valid(*args, **kwargs): # pylint: disable=unused-argument return AutoRaterScore(score=1.0) - evaluator._sample_llm = sample_llm_valid + evaluator._sample_llm = sample_llm_valid # pylint: disable=protected-access starting_prompt = "first user prompt." conversation_scenario = _create_test_conversation_scenario( starting_prompt=starting_prompt @@ -656,3 +656,43 @@ async def sample_llm_valid(*args, **kwargs): assert len(result.per_invocation_results) == 2 assert result.per_invocation_results[0].score == 1.0 assert result.per_invocation_results[1].score == 1.0 + + +@pytest.mark.asyncio +async def test_evaluate_invocations_none_judge_model_config(): + """Tests evaluation when judge_model_config is None.""" + evaluator = PerTurnUserSimulatorQualityV1( + EvalMetric( + metric_name="test_per_turn_user_simulator_quality_v1", + threshold=1.0, + criterion=LlmBackedUserSimulatorCriterion( + threshold=1.0, + stop_signal="test stop signal", + judge_model_options=JudgeModelOptions( + judge_model="gemini-2.5-flash", + judge_model_config=None, + num_samples=1, + ), + ), + ), + ) + + async def sample_llm_valid(*args, **kwargs): # pylint: disable=unused-argument + return AutoRaterScore(score=1.0) + + evaluator._sample_llm = sample_llm_valid # pylint: disable=protected-access + starting_prompt = "first user prompt." 
+ conversation_scenario = _create_test_conversation_scenario( + starting_prompt=starting_prompt + ) + invocations = _create_test_invocations( + [starting_prompt, "model 1.", "user 2.", "model 2."] + ) + result = await evaluator.evaluate_invocations( + actual_invocations=invocations, + expected_invocations=None, + conversation_scenario=conversation_scenario, + ) + + assert result.overall_score == 1.0 + assert result.overall_eval_status == EvalStatus.PASSED diff --git a/tests/unittests/evaluation/test_evaluation_generator.py b/tests/unittests/evaluation/test_evaluation_generator.py index a4aa8691fd..508b6f5c9c 100644 --- a/tests/unittests/evaluation/test_evaluation_generator.py +++ b/tests/unittests/evaluation/test_evaluation_generator.py @@ -14,8 +14,11 @@ from __future__ import annotations +import asyncio + from google.adk.evaluation.app_details import AgentDetails from google.adk.evaluation.app_details import AppDetails +from google.adk.evaluation.evaluation_generator import _LiveSession from google.adk.evaluation.evaluation_generator import EvaluationGenerator from google.adk.evaluation.request_intercepter_plugin import _RequestIntercepterPlugin from google.adk.evaluation.simulation.user_simulator import NextUserMessage @@ -396,6 +399,61 @@ async def mock_run_async(*args, **kwargs): ) +class TestGenerateInferencesForSingleUserInvocationLive: + """Test cases for EvaluationGenerator._generate_inferences_for_single_user_invocation_live method.""" + + @pytest.mark.asyncio + async def test_generate_inferences_live(self, mocker): + """Tests live inference generation.""" + mock_live_request_queue = mocker.MagicMock() + event_queue = asyncio.Queue() + turn_complete_event = asyncio.Event() + + user_content = types.Content(parts=[types.Part(text="User query")]) + invocation_id = "inv1" + + agent_event = _build_event( + "agent", [types.Part(text="Agent response")], invocation_id + ) + other_event = _build_event( + "agent", [types.Part(text="Other response")], "inv2" + ) + + 
gen = EvaluationGenerator._generate_inferences_for_single_user_invocation_live( + live_request_queue=mock_live_request_queue, + event_queue=event_queue, + user_message=user_content, + current_invocation_id=invocation_id, + turn_complete_event=turn_complete_event, + live_timeout_seconds=300, + ) + + # First yield should be the user message + first_event = await gen.__anext__() + assert first_event.author == "user" + assert first_event.content == user_content + assert first_event.invocation_id == invocation_id + + # Mock turn_complete_event.wait to avoid blocking + turn_complete_event.wait = mocker.AsyncMock() + + # Put events in queue BEFORE advancing + await event_queue.put(agent_event) + await event_queue.put(other_event) + + # Now advance to get the next event + second_event = await gen.__anext__() + + assert mock_live_request_queue.send_content.called + mock_live_request_queue.send_content.assert_called_once_with(user_content) + + assert second_event == agent_event + + # The generator should be exhausted now because other_event doesn't match invocation_id + with pytest.raises(StopAsyncIteration): + await gen.__anext__() + + @pytest.fixture def mock_runner(mocker): """Provides a mock Runner for testing.""" @@ -479,3 +537,270 @@ async def mock_generate_inferences_side_effect( mock_generate_inferences.assert_called_once() called_with_content = mock_generate_inferences.call_args.args[3] assert called_with_content.parts[0].text == "message 1" + + @pytest.mark.asyncio + async def test_generates_inferences_with_user_simulator_live( + self, mocker, mock_runner, mock_session_service + ): + """Tests that inferences are generated by interacting with a user simulator in live mode.""" + mock_agent = mocker.MagicMock() + mock_user_sim = mocker.MagicMock(spec=UserSimulator) + + # Mock user simulator will produce one message, then stop. 
+ async def get_next_user_message_side_effect(*args, **kwargs): + if mock_user_sim.get_next_user_message.call_count == 1: + return NextUserMessage( + status=UserSimulatorStatus.SUCCESS, + user_message=types.Content(parts=[types.Part(text="message 1")]), + ) + return NextUserMessage(status=UserSimulatorStatus.STOP_SIGNAL_DETECTED) + + mock_user_sim.get_next_user_message = mocker.AsyncMock( + side_effect=get_next_user_message_side_effect + ) + + mock_generate_inferences_live = mocker.patch( + "google.adk.evaluation.evaluation_generator.EvaluationGenerator._generate_inferences_for_single_user_invocation_live" + ) + mocker.patch( + "google.adk.evaluation.evaluation_generator.EvaluationGenerator._get_app_details_by_invocation_id" + ) + mocker.patch( + "google.adk.evaluation.evaluation_generator.EvaluationGenerator.convert_events_to_eval_invocations" + ) + + # Mock _LiveSession context manager + mock_live_session = mocker.MagicMock() + mock_live_session.__aenter__ = mocker.AsyncMock( + return_value=mock_live_session + ) + mock_live_session.__aexit__ = mocker.AsyncMock(return_value=None) + mock_live_session.live_request_queue = mocker.MagicMock() + mock_live_session.event_queue = asyncio.Queue() + mock_live_session.turn_complete_event = asyncio.Event() + mock_live_session.live_finished = asyncio.Event() + + mock_live_session_cls = mocker.patch( + "google.adk.evaluation.evaluation_generator._LiveSession", + return_value=mock_live_session, + ) + + # Each call to _generate_inferences_for_single_user_invocation_live will + # yield one user and one agent event. 
+ async def mock_generate_inferences_live_side_effect(*args, **kwargs): + yield _build_event("user", [types.Part(text="message 1")], "inv1") + yield _build_event("agent", [types.Part(text="agent_response")], "inv1") + + mock_generate_inferences_live.side_effect = ( + mock_generate_inferences_live_side_effect + ) + + await EvaluationGenerator._generate_inferences_from_root_agent_live( + root_agent=mock_agent, + user_simulator=mock_user_sim, + live_timeout_seconds=600, + ) + + # Check that user simulator was called until it stopped. + assert mock_user_sim.get_next_user_message.call_count == 2 + + # Check that we generated inferences for each user message. + mock_generate_inferences_live.assert_called_once() + called_with_content = mock_generate_inferences_live.call_args.kwargs[ + "user_message" + ] + assert called_with_content.parts[0].text == "message 1" + assert ( + mock_generate_inferences_live.call_args.kwargs["live_timeout_seconds"] + == 600 + ) + + # Verify that the agent response was collected + mock_convert = EvaluationGenerator.convert_events_to_eval_invocations + mock_convert.assert_called_once() + events_passed = mock_convert.call_args.args[0] + + agent_events = [e for e in events_passed if e.author == "agent"] + assert len(agent_events) == 1 + assert agent_events[0].content.parts[0].text == "agent_response" + + # Verify that the _LiveSession constructor was called + mock_live_session_cls.assert_called_once() + + +class TestLiveSessionCallbacks: + """Unit tests verifying that _LiveSession manually triggers callbacks.""" + + @pytest.mark.asyncio + async def test_live_session_manually_triggers_callbacks(self, mocker): + from google.adk.agents.callback_context import CallbackContext + from google.adk.agents.llm_agent import Agent + from google.adk.models.llm_request import LlmRequest + + # 1. 
Setup mock runner, agent, and session + mock_runner = mocker.MagicMock() + mock_runner.session_service.append_event = mocker.AsyncMock() + mock_session = mocker.MagicMock() + mock_agent = mocker.MagicMock(spec=Agent) + mock_runner.agent = mock_agent + mock_runner._find_agent_to_run.return_value = mock_agent + + mock_agent.name = "test_agent" + + # Mock _llm_flow._preprocess_async to set dummy instruction + async def mock_preprocess_async(invocation_context, llm_request): + llm_request.config.system_instruction = "mock instruction" + return + yield # make it an async generator + + mock_flow = mocker.MagicMock() + mock_flow._preprocess_async = mock_preprocess_async + mock_agent._llm_flow = mock_flow + + # Mock run_live stream yielding one event + mock_event = Event( + author="agent", + content=types.Content(parts=[types.Part(text="Hello")]), + invocation_id="test_invocation_id", + ) + + async def mock_run_live(*args, **kwargs): + yield mock_event + + mock_agent.run_live.return_value = mock_run_live() + + # Mock plugin_manager on invocation context + mock_plugin_manager = mocker.MagicMock() + mock_plugin_manager.run_before_model_callback = mocker.AsyncMock() + mock_plugin_manager.run_after_model_callback = mocker.AsyncMock() + mock_runner._new_invocation_context_for_live.return_value.plugin_manager = ( + mock_plugin_manager + ) + mock_runner._new_invocation_context_for_live.return_value.agent = mock_agent + + # 2. Instantiate and enter _LiveSession + live_session = _LiveSession( + runner=mock_runner, + session=mock_session, + user_id="test_user", + session_id="test_session", + ) + + # Directly run _consume_events as a coroutine for synchronous-style testing + await live_session._consume_events() + + # 3. 
Assertions + mock_plugin_manager.run_before_model_callback.assert_called_once() + called_before_args = mock_plugin_manager.run_before_model_callback.call_args + assert isinstance( + called_before_args.kwargs["callback_context"], CallbackContext + ) + assert isinstance(called_before_args.kwargs["llm_request"], LlmRequest) + assert ( + called_before_args.kwargs["llm_request"].config.system_instruction + == "mock instruction" + ) + + mock_plugin_manager.run_after_model_callback.assert_called_once() + called_after_args = mock_plugin_manager.run_after_model_callback.call_args + assert isinstance( + called_after_args.kwargs["callback_context"], CallbackContext + ) + assert isinstance(called_after_args.kwargs["llm_response"], Event) + assert called_after_args.kwargs["llm_response"] == mock_event + + @pytest.mark.asyncio + async def test_live_session_manually_triggers_callbacks_with_tools( + self, mocker + ): + from google.adk.agents.callback_context import CallbackContext + from google.adk.agents.llm_agent import Agent + from google.adk.models.llm_request import LlmRequest + + # 1. 
Setup mock runner, agent, and session + mock_runner = mocker.MagicMock() + mock_runner.session_service.append_event = mocker.AsyncMock() + mock_session = mocker.MagicMock() + mock_agent = mocker.MagicMock(spec=Agent) + mock_runner.agent = mock_agent + mock_runner._find_agent_to_run.return_value = mock_agent + + mock_agent.name = "test_agent" + + # Set up a mock tool + mock_tool = mocker.MagicMock() + mock_tool.name = "get_weather" + mock_decl = types.FunctionDeclaration( + name="get_weather", + description="Get weather details", + ) + mock_tool._get_declaration.return_value = mock_decl + + # Mock _llm_flow._preprocess_async to set instruction and append tool + async def mock_preprocess_async(invocation_context, llm_request): + llm_request.config.system_instruction = "mock instruction" + llm_request.append_tools([mock_tool]) + return + yield # make it an async generator + + mock_flow = mocker.MagicMock() + mock_flow._preprocess_async = mock_preprocess_async + mock_agent._llm_flow = mock_flow + + # Mock run_live stream yielding one event + mock_event = Event( + author="agent", + content=types.Content(parts=[types.Part(text="Hello")]), + invocation_id="test_invocation_id", + ) + + async def mock_run_live(*args, **kwargs): + yield mock_event + + mock_agent.run_live.return_value = mock_run_live() + + # Mock plugin_manager on invocation context + mock_plugin_manager = mocker.MagicMock() + mock_plugin_manager.run_before_model_callback = mocker.AsyncMock() + mock_plugin_manager.run_after_model_callback = mocker.AsyncMock() + mock_runner._new_invocation_context_for_live.return_value.plugin_manager = ( + mock_plugin_manager + ) + mock_runner._new_invocation_context_for_live.return_value.agent = mock_agent + + # 2. Instantiate and enter _LiveSession + live_session = _LiveSession( + runner=mock_runner, + session=mock_session, + user_id="test_user", + session_id="test_session", + ) + + # Directly run _consume_events as a coroutine + await live_session._consume_events() + + # 3. 
Assertions + mock_plugin_manager.run_before_model_callback.assert_called_once() + called_before_args = mock_plugin_manager.run_before_model_callback.call_args + assert isinstance( + called_before_args.kwargs["callback_context"], CallbackContext + ) + + llm_request = called_before_args.kwargs["llm_request"] + assert isinstance(llm_request, LlmRequest) + assert llm_request.config.system_instruction == "mock instruction" + + # Assert that tool was correctly wrapped under types.Tool format + assert len(llm_request.config.tools) == 1 + wrapped_tool = llm_request.config.tools[0] + assert isinstance(wrapped_tool, types.Tool) + assert len(wrapped_tool.function_declarations) == 1 + assert wrapped_tool.function_declarations[0].name == "get_weather" + + mock_plugin_manager.run_after_model_callback.assert_called_once() + called_after_args = mock_plugin_manager.run_after_model_callback.call_args + assert isinstance( + called_after_args.kwargs["callback_context"], CallbackContext + ) + assert isinstance(called_after_args.kwargs["llm_response"], Event) + assert called_after_args.kwargs["llm_response"] == mock_event diff --git a/tests/unittests/evaluation/test_local_eval_service.py b/tests/unittests/evaluation/test_local_eval_service.py index 386c1fd07a..894fc27c07 100644 --- a/tests/unittests/evaluation/test_local_eval_service.py +++ b/tests/unittests/evaluation/test_local_eval_service.py @@ -247,12 +247,59 @@ async def test_perform_inference_with_case_ids( eval_set_id="test_eval_set", eval_case=eval_set.eval_cases[0], root_agent=dummy_agent, + use_live=False, + live_timeout_seconds=300, ) eval_service._perform_inference_single_eval_item.assert_any_call( app_name="test_app", eval_set_id="test_eval_set", eval_case=eval_set.eval_cases[2], root_agent=dummy_agent, + use_live=False, + live_timeout_seconds=300, + ) + + +@pytest.mark.asyncio +async def test_perform_inference_with_use_live( + eval_service, + dummy_agent, + mock_eval_sets_manager, + mocker, +): + eval_set = EvalSet( + 
eval_set_id="test_eval_set", + eval_cases=[ + EvalCase(eval_id="case1", conversation=[], session_input=None), + ], + ) + mock_eval_sets_manager.get_eval_set.return_value = eval_set + + mock_inference_result = mocker.MagicMock() + eval_service._perform_inference_single_eval_item = mocker.AsyncMock( + return_value=mock_inference_result + ) + + inference_request = InferenceRequest( + app_name="test_app", + eval_set_id="test_eval_set", + inference_config=InferenceConfig( + parallelism=1, use_live=True, live_timeout_seconds=600 + ), + ) + + results = [] + async for result in eval_service.perform_inference(inference_request): + results.append(result) + + assert len(results) == 1 + eval_service._perform_inference_single_eval_item.assert_called_once_with( + app_name="test_app", + eval_set_id="test_eval_set", + eval_case=eval_set.eval_cases[0], + root_agent=dummy_agent, + use_live=True, + live_timeout_seconds=600, ) @@ -791,3 +838,80 @@ def test_copy_invocation_rubrics_to_actual_invocations(): _copy_invocation_rubrics_to_actual_invocations(expected, actual) assert actual[0].rubrics == [rubric1] assert actual[1].rubrics == [rubric2] + + +@pytest.mark.asyncio +async def test_perform_inference_single_eval_item_live( + eval_service, dummy_agent, mocker +): + eval_case = EvalCase(eval_id="case1", conversation=[], session_input=None) + mock_generate_live = mocker.patch( + "google.adk.evaluation.evaluation_generator.EvaluationGenerator._generate_inferences_from_root_agent_live" + ) + mock_generate_live.return_value = [] + + eval_service._session_id_supplier = mocker.MagicMock( + return_value="test_session_id" + ) + mock_user_sim = mocker.MagicMock() + eval_service._user_simulator_provider.provide = mocker.MagicMock( + return_value=mock_user_sim + ) + + await eval_service._perform_inference_single_eval_item( + app_name="test_app", + eval_set_id="test_eval_set", + eval_case=eval_case, + root_agent=dummy_agent, + use_live=True, + live_timeout_seconds=600, + ) + + 
mock_generate_live.assert_called_once_with( + root_agent=dummy_agent, + user_simulator=mock_user_sim, + initial_session=None, + session_id="test_session_id", + session_service=eval_service._session_service, + artifact_service=eval_service._artifact_service, + memory_service=eval_service._memory_service, + live_timeout_seconds=600, + ) + + +@pytest.mark.asyncio +async def test_perform_inference_single_eval_item_non_live( + eval_service, dummy_agent, mocker +): + eval_case = EvalCase(eval_id="case1", conversation=[], session_input=None) + mock_generate = mocker.patch( + "google.adk.evaluation.evaluation_generator.EvaluationGenerator._generate_inferences_from_root_agent" + ) + mock_generate.return_value = [] + + eval_service._session_id_supplier = mocker.MagicMock( + return_value="test_session_id" + ) + mock_user_sim = mocker.MagicMock() + eval_service._user_simulator_provider.provide = mocker.MagicMock( + return_value=mock_user_sim + ) + + await eval_service._perform_inference_single_eval_item( + app_name="test_app", + eval_set_id="test_eval_set", + eval_case=eval_case, + root_agent=dummy_agent, + use_live=False, + live_timeout_seconds=300, + ) + + mock_generate.assert_called_once_with( + root_agent=dummy_agent, + user_simulator=mock_user_sim, + initial_session=None, + session_id="test_session_id", + session_service=eval_service._session_service, + artifact_service=eval_service._artifact_service, + memory_service=eval_service._memory_service, + ) diff --git a/tests/unittests/integrations/agent_registry/test_agent_registry.py b/tests/unittests/integrations/agent_registry/test_agent_registry.py index fd3f2a8ec4..f4ba47cf25 100644 --- a/tests/unittests/integrations/agent_registry/test_agent_registry.py +++ b/tests/unittests/integrations/agent_registry/test_agent_registry.py @@ -607,6 +607,15 @@ def test_get_auth_headers(self, registry): assert headers["Authorization"] == "Bearer fake-token" assert headers["x-goog-user-project"] == "quota-project" + def 
test_get_auth_headers_fallback_to_project_id(self, registry): + registry._credentials.token = "fake-token" + registry._credentials.refresh = MagicMock() + registry._credentials.quota_project_id = None + + headers = registry._get_auth_headers() + assert headers["Authorization"] == "Bearer fake-token" + assert headers["x-goog-user-project"] == "test-project" + @patch("httpx.Client") def test_make_request_raises_http_status_error(self, mock_httpx, registry): mock_response = MagicMock() diff --git a/tests/unittests/integrations/skill_registry/test_gcp_skill_registry.py b/tests/unittests/integrations/skill_registry/test_gcp_skill_registry.py new file mode 100644 index 0000000000..bf410456e9 --- /dev/null +++ b/tests/unittests/integrations/skill_registry/test_gcp_skill_registry.py @@ -0,0 +1,184 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for GCP Skill Registry.""" + +import base64 +import os +from unittest import mock +import zipfile + +from google.adk.integrations.skill_registry.gcp_skill_registry import GCPSkillRegistry +import pytest + + +@pytest.fixture(autouse=True) +def mock_env(): + """Fixture to mock environment variables.""" + with mock.patch.dict( + os.environ, + { + "GOOGLE_CLOUD_PROJECT": "test-project", + "GOOGLE_CLOUD_LOCATION": "us-central1", + }, + ): + yield + + +@pytest.fixture +def mock_vertex_client(): + """Fixture to mock vertexai.Client.""" + with mock.patch("vertexai.Client") as mock_client_class: + mock_client = mock_client_class.return_value + yield mock_client + + +def _create_fake_zip_bytes(): + """Creates a fake zip file in memory and returns its bytes.""" + import io + + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "w") as z: + z.writestr( + "SKILL.md", "---\nname: my-skill\ndescription: test\n---\n# My Skill\n" + ) + return zip_buffer.getvalue() + + +@pytest.mark.asyncio +async def test_get_skill_success(mock_vertex_client): + """Verifies that get_skill successfully fetches and loads a skill in memory.""" + registry = GCPSkillRegistry() + + fake_zip = _create_fake_zip_bytes() + fake_zip_base64 = base64.b64encode(fake_zip).decode("utf-8") + + mock_skill_resource = mock.MagicMock() + mock_skill_resource.zipped_filesystem = fake_zip_base64 + + mock_vertex_client.aio.skills.get = mock.AsyncMock( + return_value=mock_skill_resource + ) + + skill = await registry.get_skill(name="my-skill") + + assert skill.frontmatter.name == "my-skill" + assert skill.frontmatter.description == "test" + assert skill.instructions == "# My Skill" + mock_vertex_client.aio.skills.get.assert_called_once_with( + name="projects/test-project/locations/us-central1/skills/my-skill" + ) + + +@pytest.mark.asyncio +async def test_search_skills_success(mock_vertex_client): + """Verifies that search_skills successfully returns frontmatter list.""" + registry = GCPSkillRegistry() 
+ + mock_skill1 = mock.MagicMock() + mock_skill1.skill_name = ( + "projects/test-project/locations/us-central1/skills/skill1" + ) + mock_skill1.description = "Description 1" + + mock_skill2 = mock.MagicMock() + mock_skill2.skill_name = ( + "projects/test-project/locations/us-central1/skills/skill2" + ) + mock_skill2.description = "Description 2" + + mock_response = mock.MagicMock() + mock_response.retrieved_skills = [mock_skill1, mock_skill2] + + mock_vertex_client.aio.skills.retrieve = mock.AsyncMock( + return_value=mock_response + ) + + results = await registry.search_skills(query="query") + + assert len(results) == 2 + assert results[0].name == "skill1" + assert results[0].description == "Description 1" + assert results[1].name == "skill2" + assert results[1].description == "Description 2" + mock_vertex_client.aio.skills.retrieve.assert_called_once_with(query="query") + + +@pytest.mark.asyncio +async def test_get_skill_raises_on_missing_zip(mock_vertex_client): + """Verifies that get_skill raises error if zip filesystem is missing.""" + registry = GCPSkillRegistry() + + mock_skill_resource = mock.MagicMock() + mock_skill_resource.zipped_filesystem = None + + mock_vertex_client.aio.skills.get = mock.AsyncMock( + return_value=mock_skill_resource + ) + + with pytest.raises(ValueError, match="does not contain zipped filesystem"): + await registry.get_skill(name="my-skill") + + +@pytest.mark.asyncio +async def test_get_skill_raises_on_zip_slip(mock_vertex_client): + """Verifies that get_skill raises error if zip contains dangerous paths.""" + registry = GCPSkillRegistry() + + import io + + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "w") as z: + z.writestr("../evil.txt", "malicious content") + z.writestr( + "SKILL.md", "---\nname: my-skill\ndescription: test\n---\n# My Skill\n" + ) + fake_zip = zip_buffer.getvalue() + fake_zip_base64 = base64.b64encode(fake_zip).decode("utf-8") + + mock_skill_resource = mock.MagicMock() + 
mock_skill_resource.zipped_filesystem = fake_zip_base64 + + mock_vertex_client.aio.skills.get = mock.AsyncMock( + return_value=mock_skill_resource + ) + + with pytest.raises(ValueError, match="Dangerous zip entry ignored"): + await registry.get_skill(name="my-skill") + + +@pytest.mark.asyncio +async def test_get_skill_raises_on_invalid_skill_name(mock_vertex_client): + """Verifies that get_skill raises error if skill name is invalid.""" + registry = GCPSkillRegistry() + + import io + + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "w") as z: + z.writestr( + "SKILL.md", "---\nname: ../evil\ndescription: test\n---\n# My Skill\n" + ) + fake_zip = zip_buffer.getvalue() + fake_zip_base64 = base64.b64encode(fake_zip).decode("utf-8") + + mock_skill_resource = mock.MagicMock() + mock_skill_resource.zipped_filesystem = fake_zip_base64 + + mock_vertex_client.aio.skills.get = mock.AsyncMock( + return_value=mock_skill_resource + ) + + with pytest.raises(ValueError, match="Invalid skill name in SKILL.md"): + await registry.get_skill(name="my-skill") diff --git a/tests/unittests/models/test_anthropic_llm.py b/tests/unittests/models/test_anthropic_llm.py index bf173c5cc3..d6daf39dc6 100644 --- a/tests/unittests/models/test_anthropic_llm.py +++ b/tests/unittests/models/test_anthropic_llm.py @@ -1096,6 +1096,65 @@ def test_part_to_message_block_empty_response_stays_empty(): assert result["content"] == "" +def test_part_to_message_block_string_content_passes_through(): + """A scalar string `content` value must not be iterated char-by-char.""" + response_part = types.Part.from_function_response( + name="some_tool", + response={"content": "Hello"}, + ) + response_part.function_response.id = "test_id_str_content" + + result = part_to_message_block(response_part) + + assert result["content"] == "Hello" + + +def test_part_to_message_block_load_skill_resource_response(): + """LoadSkillResourceTool returns {content: } as a string.""" + file_text = "Line one\nLine two\nLine 
three" + response_part = types.Part.from_function_response( + name="load_skill_resource", + response={ + "skill_name": "my-skill", + "file_path": "references/doc.md", + "content": file_text, + }, + ) + response_part.function_response.id = "test_id_load_skill" + + result = part_to_message_block(response_part) + + assert result["content"] == file_text + + +def test_part_to_message_block_empty_string_content_falls_through(): + """`{"content": ""}` falls through to the JSON-dump fallback, not a crash.""" + response_part = types.Part.from_function_response( + name="some_tool", + response={"content": ""}, + ) + response_part.function_response.id = "test_id_empty_content_only" + + result = part_to_message_block(response_part) + + assert json.loads(result["content"]) == {"content": ""} + + +def test_part_to_message_block_empty_content_with_metadata_keeps_metadata(): + """`content: ""` is falsy; sibling keys still reach the model via JSON dump.""" + response_part = types.Part.from_function_response( + name="some_tool", + response={"content": "", "extra": "keep me"}, + ) + response_part.function_response.id = "test_id_empty_content_with_meta" + + result = part_to_message_block(response_part) + + parsed = json.loads(result["content"]) + assert parsed["content"] == "" + assert parsed["extra"] == "keep me" + + # --- Tests for Bug #1: Streaming support --- diff --git a/tests/unittests/models/test_interactions_utils.py b/tests/unittests/models/test_interactions_utils.py index 3a6b964ac1..118a925ab6 100644 --- a/tests/unittests/models/test_interactions_utils.py +++ b/tests/unittests/models/test_interactions_utils.py @@ -280,11 +280,9 @@ def test_function_response_dict(self): assert result['type'] == 'function_result' assert result['call_id'] == 'call_123' assert result['name'] == 'get_weather' - # Dict should be JSON serialized - assert json.loads(result['result']) == { - 'temperature': 20, - 'condition': 'sunny', - } + # Dict should be passed through directly (not 
JSON-serialized). + assert result['result'] == {'temperature': 20, 'condition': 'sunny'} + assert isinstance(result['result'], dict) def test_function_response_simple(self): """Test converting a function response Part with simple response.""" @@ -299,8 +297,37 @@ def test_function_response_simple(self): assert result['type'] == 'function_result' assert result['call_id'] == 'call_123' assert result['name'] == 'check_weather' - # Dict should be JSON serialized - assert json.loads(result['result']) == {'message': 'Weather is sunny'} + # Dict should be passed through directly (not JSON-serialized). + assert result['result'] == {'message': 'Weather is sunny'} + + def test_function_response_dict_not_double_serialized(self): + """Regression test: avoid double-serializing bash tool outputs. + + Bash tool responses contain JSON structures (stdout/stderr). When these + dict responses were json.dumps()'d before being sent to the Interactions + API, the API's own serialization would escape the already-escaped content, + producing unreadable output like: + {"result":"\\\"{\\\\\\\"error\\\\\\\":\\\\\\\"...\\\\\\\"}\\\"" + """ + bash_response = { + 'stdout': '{"name": "test", "version": "1.0"}\n', + 'stderr': '', + } + part = types.Part( + function_response=types.FunctionResponse( + id='call_bash', + name='bash', + response=bash_response, + ) + ) + result = interactions_utils.convert_part_to_interaction_content(part) + # The result value must be the dict itself, NOT a JSON string. + assert isinstance(result['result'], dict) + assert result['result'] == bash_response + # Verify there's no double-escaping: if result were a JSON string, + # serializing it again would add backslashes before the internal quotes. 
+ wire_json = json.dumps(result) + assert '\\\\' not in wire_json def test_inline_data_image(self): """Test converting an inline image Part.""" diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index c195076349..c8bf5be010 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -1908,14 +1908,14 @@ async def test_content_to_message_param_user_message_file_uri_only( @pytest.mark.asyncio async def test_content_to_message_param_user_message_file_uri_without_mime_type(): - """Test handling of file_data without mime_type (GcsArtifactService scenario). + """Test that file_data without an inferable mime_type raises ValueError. When using GcsArtifactService, artifacts may have file_uri (gs://...) but - without mime_type set. LiteLLM's Vertex AI backend requires the format - field to be present, so we infer MIME type from the URI extension or use - a default fallback to ensure compatibility. + without mime_type set. When the MIME type cannot be determined from the URI + extension or display_name, ADK raises a clear ValueError rather than + forwarding an unsupported 'application/octet-stream' to LiteLLM. 
- See: https://github.com/google/adk-python/issues/3787 + See: https://github.com/google/adk-python/issues/5022 """ file_part = types.Part( file_data=types.FileData( @@ -1930,22 +1930,36 @@ async def test_content_to_message_param_user_message_file_uri_without_mime_type( ], ) - message = await _content_to_message_param(content) - assert message == { - "role": "user", - "content": [ - {"type": "text", "text": "Analyze this file."}, - { - "type": "file", - "file": { - "file_id": ( - "gs://agent-artifact-bucket/app/user/session/artifact/0" - ), - "format": "application/octet-stream", - }, - }, + with pytest.raises(ValueError, match="Cannot process file_uri"): + await _content_to_message_param(content) + + +@pytest.mark.asyncio +async def test_content_to_message_param_user_message_file_uri_explicit_octet_stream(): + """Test that an explicit application/octet-stream MIME type raises ValueError. + + Upstream callers may explicitly set mime_type to 'application/octet-stream' + when the true type is unknown. ADK treats this identically to a missing MIME + type and raises early rather than forwarding the unsupported type to LiteLLM. + + See: https://github.com/google/adk-python/issues/5022 + """ + file_part = types.Part( + file_data=types.FileData( + file_uri="gs://agent-artifact-bucket/app/user/session/artifact/0", + mime_type="application/octet-stream", + ) + ) + content = types.Content( + role="user", + parts=[ + types.Part.from_text(text="Analyze this file."), + file_part, ], - } + ) + + with pytest.raises(ValueError, match="application/octet-stream"): + await _content_to_message_param(content) @pytest.mark.asyncio @@ -1955,7 +1969,7 @@ async def test_content_to_message_param_user_message_file_uri_infer_mime_type(): When file_data has a file_uri with a recognizable extension but no explicit mime_type, the MIME type should be inferred from the extension. 
- See: https://github.com/google/adk-python/issues/3787 + See: https://github.com/google/adk-python/issues/5022 """ file_part = types.Part( file_data=types.FileData( @@ -3067,7 +3081,7 @@ async def test_get_content_file_uri_infer_mime_type(): When file_data has a file_uri with a recognizable extension but no explicit mime_type, the MIME type should be inferred from the extension. - See: https://github.com/google/adk-python/issues/3787 + See: https://github.com/google/adk-python/issues/5022 """ # Use Part constructor directly to test MIME type inference in _get_content # (types.Part.from_uri does its own inference, so we bypass it) @@ -3117,27 +3131,42 @@ async def test_get_content_file_uri_infers_from_display_name(): @pytest.mark.asyncio async def test_get_content_file_uri_default_mime_type(): - """Test that file_uri without extension uses default MIME type. + """Test that file_uri without an inferable extension raises ValueError. When file_data has a file_uri without a recognizable extension and no explicit - mime_type, a default MIME type should be used to ensure compatibility with - LiteLLM backends. + mime_type, ADK raises a clear ValueError instead of forwarding the unsupported + 'application/octet-stream' MIME type to LiteLLM. 
- See: https://github.com/google/adk-python/issues/3787 + See: https://github.com/google/adk-python/issues/5022 """ - # Use Part constructor directly to create file_data without mime_type - # (types.Part.from_uri requires a valid mime_type when it can't infer) parts = [ types.Part(file_data=types.FileData(file_uri="gs://bucket/artifact/0")) ] - content = await _get_content(parts) - assert content[0] == { - "type": "file", - "file": { - "file_id": "gs://bucket/artifact/0", - "format": "application/octet-stream", - }, - } + with pytest.raises(ValueError, match="Cannot process file_uri"): + await _get_content(parts) + + +@pytest.mark.asyncio +async def test_get_content_file_uri_explicit_octet_stream_raises(): + """Test that an explicit application/octet-stream MIME type raises ValueError. + + 'application/octet-stream' is semantically equivalent to an unknown type and + causes the same downstream ValueError from LiteLLM whether it arrives as a + default fallback or is set explicitly by the caller. ADK raises early with + an actionable message in both cases. 
+ + See: https://github.com/google/adk-python/issues/5022 + """ + parts = [ + types.Part( + file_data=types.FileData( + file_uri="gs://bucket/artifact/0", + mime_type="application/octet-stream", + ) + ) + ] + with pytest.raises(ValueError, match="application/octet-stream"): + await _get_content(parts) @pytest.mark.asyncio diff --git a/tests/unittests/models/test_llm_response.py b/tests/unittests/models/test_llm_response.py index e2cbe4c286..02b7126ab5 100644 --- a/tests/unittests/models/test_llm_response.py +++ b/tests/unittests/models/test_llm_response.py @@ -107,6 +107,31 @@ def test_llm_response_create_no_candidates(): assert response.error_message == 'Prompt blocked for safety' +def test_llm_response_create_no_candidates_without_prompt_feedback(): + """Test LlmResponse.create() for empty successful model responses.""" + usage_metadata = types.GenerateContentResponseUsageMetadata( + prompt_token_count=10, + candidates_token_count=0, + total_token_count=10, + ) + generate_content_response = types.GenerateContentResponse( + candidates=[], + usage_metadata=usage_metadata, + model_version='gemini-2.5-flash', + ) + + response = LlmResponse.create(generate_content_response) + + assert response.error_code is None + assert response.error_message is None + assert response.finish_reason is None + assert response.content is not None + assert response.content.role == 'model' + assert not response.content.parts + assert response.usage_metadata == usage_metadata + assert response.model_version == 'gemini-2.5-flash' + + def test_llm_response_create_with_concrete_logprobs_result(): """Test LlmResponse.create() with detailed logprobs_result containing actual token data.""" # Create realistic logprobs data diff --git a/tests/unittests/skills/test__utils.py b/tests/unittests/skills/test__utils.py index ae5df00c1c..1975c0a494 100644 --- a/tests/unittests/skills/test__utils.py +++ b/tests/unittests/skills/test__utils.py @@ -14,12 +14,15 @@ """Unit tests for skill utilities.""" 
+import io from unittest import mock +import zipfile from google.adk.skills import list_skills_in_dir from google.adk.skills import list_skills_in_gcs_dir as _list_skills_in_gcs_dir from google.adk.skills import load_skill_from_dir as _load_skill_from_dir from google.adk.skills import load_skill_from_gcs_dir as _load_skill_from_gcs_dir +from google.adk.skills._utils import _load_skill_from_zip_bytes from google.adk.skills._utils import _read_skill_properties from google.adk.skills._utils import _validate_skill_dir import pytest @@ -340,3 +343,23 @@ def test_list_skills_in_dir_missing_base_path(tmp_path): skills = list_skills_in_dir(tmp_path / "nonexistent") assert skills == {} + + +def test__load_skill_from_zip_bytes(): + """Tests loading a skill directly from in-memory zip file bytes.""" + + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "w") as z: + z.writestr( + "SKILL.md", + "---\nname: my-skill\ndescription: A skill\n---\nBody instructions", + ) + z.writestr("references/ref1.md", "ref1 content") + z.writestr("scripts/script1.sh", "echo hello") + + skill = _load_skill_from_zip_bytes(zip_buffer.getvalue()) + assert skill.frontmatter.name == "my-skill" + assert skill.frontmatter.description == "A skill" + assert skill.instructions == "Body instructions" + assert skill.resources.get_reference("ref1.md") == "ref1 content" + assert skill.resources.get_script("script1.sh").src == "echo hello" diff --git a/tests/unittests/skills/test_models.py b/tests/unittests/skills/test_models.py index 73136be0b9..ffbbb2dd50 100644 --- a/tests/unittests/skills/test_models.py +++ b/tests/unittests/skills/test_models.py @@ -153,7 +153,10 @@ def test_description_empty(): def test_description_too_long(): - with pytest.raises(ValidationError, match="at most 1024 characters"): + with pytest.raises( + ValidationError, + match="at most 1024 characters. 
Description length: 1025", + ): models.Frontmatter(name="my-skill", description="x" * 1025) diff --git a/tests/unittests/telemetry/test_google_cloud.py b/tests/unittests/telemetry/test_google_cloud.py index 0199e7b4b6..284a51539e 100644 --- a/tests/unittests/telemetry/test_google_cloud.py +++ b/tests/unittests/telemetry/test_google_cloud.py @@ -16,8 +16,18 @@ from typing import Optional from unittest import mock +from google.adk.telemetry import google_cloud +from google.adk.telemetry.google_cloud import _DEFAULT_MTLS_TELEMETRY_TRACES_ENPOINT +from google.adk.telemetry.google_cloud import _DEFAULT_TELEMETRY_TRACES_ENPOINT +from google.adk.telemetry.google_cloud import _get_api_endpoint +from google.adk.telemetry.google_cloud import _get_gcp_span_exporter +from google.adk.telemetry.google_cloud import _use_client_cert_effective from google.adk.telemetry.google_cloud import get_gcp_exporters from google.adk.telemetry.google_cloud import get_gcp_resource +import google.auth.credentials +from google.auth.transport import mtls +from google.auth.transport import requests +from opentelemetry.exporter.otlp.proto.http import trace_exporter import pytest @@ -89,3 +99,108 @@ def test_get_gcp_resource( otel_resource.attributes.get("gcp.project_id", None) == expected_project_id ) + + +@mock.patch.object(mtls, "should_use_client_cert", autospec=True) +def test_use_client_cert_effective_from_mtls(mock_should_use): + mock_should_use.return_value = True + assert _use_client_cert_effective() + + mock_should_use.return_value = False + assert not _use_client_cert_effective() + + +def test_use_client_cert_effective_from_env( + monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture +): + with mock.patch.object( + mtls, + "should_use_client_cert", + autospec=True, + side_effect=AttributeError, + ): + monkeypatch.setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "true") + assert _use_client_cert_effective() + + monkeypatch.setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + assert 
not _use_client_cert_effective() + + # Test invalid value defaults to False + monkeypatch.setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "maybe") + assert not _use_client_cert_effective() + assert ( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be" + " either `true` or `false`" + in caplog.text + ) + + +@pytest.mark.parametrize( + "env_val, cert_source, expected", + [ + ("auto", lambda: b"cert", _DEFAULT_MTLS_TELEMETRY_TRACES_ENPOINT), + ("auto", None, _DEFAULT_TELEMETRY_TRACES_ENPOINT), + ("always", None, _DEFAULT_MTLS_TELEMETRY_TRACES_ENPOINT), + ("never", lambda: b"cert", _DEFAULT_TELEMETRY_TRACES_ENPOINT), + ("invalid", None, _DEFAULT_TELEMETRY_TRACES_ENPOINT), + ], +) +def test_get_api_endpoint( + env_val, + cert_source, + expected, + monkeypatch: pytest.MonkeyPatch, + caplog: pytest.LogCaptureFixture, +): + monkeypatch.setenv("GOOGLE_API_USE_MTLS_ENDPOINT", env_val) + if env_val == "invalid": + assert _get_api_endpoint(cert_source) == expected + assert ( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be one of" + in caplog.text + ) + else: + assert _get_api_endpoint(cert_source) == expected + + +@mock.patch.object(requests, "AuthorizedSession", autospec=True) +@mock.patch( + "opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter", + autospec=True, +) +@mock.patch( + "google.adk.telemetry.google_cloud.BatchSpanProcessor", autospec=True +) +@mock.patch( + "google.adk.telemetry.google_cloud._use_client_cert_effective", + autospec=True, +) +@mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", autospec=True +) +@mock.patch( + "google.auth.transport.mtls.default_client_cert_source", autospec=True +) +def test_get_gcp_span_exporter_mtls( + mock_default_cert: mock.MagicMock, + mock_has_cert: mock.MagicMock, + mock_use_cert: mock.MagicMock, + mock_batch: mock.MagicMock, + mock_exporter: mock.MagicMock, + mock_session: mock.MagicMock, +): + credentials = mock.create_autospec( + 
google.auth.credentials.Credentials, instance=True + ) + mock_use_cert.return_value = True + mock_has_cert.return_value = True + mock_default_cert.return_value = b"cert" + + _get_gcp_span_exporter(credentials) + + mock_session.assert_called_once_with(credentials=credentials) + mock_session.return_value.configure_mtls_channel.assert_called_once() + mock_exporter.assert_called_once_with( + session=mock_session.return_value, + endpoint=_DEFAULT_MTLS_TELEMETRY_TRACES_ENPOINT, + ) diff --git a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py index 27df785277..a94b2eb885 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py @@ -1043,3 +1043,35 @@ def test_env_var_disable_acts_as_kill_switch(self): os.environ[disable] = saved_disable if saved_enable is not None: os.environ[enable] = saved_enable + + @pytest.mark.asyncio + @patch("google.adk.tools.mcp_tool.mcp_session_manager.asyncio.wait_for") + async def test_create_session_does_not_use_wait_for_when_ge_is_enabled( + self, mock_wait_for + ): + """create_session must not wrap enter_async_context in asyncio.wait_for when GE is enabled.""" + from google.adk.features import FeatureName + from google.adk.features._feature_registry import temporary_feature_override + + manager = MCPSessionManager( + StdioConnectionParams( + server_params=StdioServerParameters(command="dummy", args=[]), + timeout=5.0, + ) + ) + with temporary_feature_override( + FeatureName._MCP_GRACEFUL_ERROR_HANDLING, True + ): + with patch( + "google.adk.tools.mcp_tool.mcp_session_manager.AsyncExitStack" + ) as mock_stack: + mock_stack.return_value.enter_async_context = AsyncMock() + with patch( + "google.adk.tools.mcp_tool.mcp_session_manager.SessionContext" + ): + with patch( + "google.adk.tools.mcp_tool.mcp_session_manager.stdio_client" + ): + await manager.create_session() + + 
mock_wait_for.assert_not_called() diff --git a/tests/unittests/tools/test_agent_tool.py b/tests/unittests/tools/test_agent_tool.py index cc6c0b2134..67a2674ff7 100644 --- a/tests/unittests/tools/test_agent_tool.py +++ b/tests/unittests/tools/test_agent_tool.py @@ -15,12 +15,14 @@ from typing import Any from typing import Optional +from google.adk.agents.base_agent import BaseAgent from google.adk.agents.callback_context import CallbackContext from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.llm_agent import Agent from google.adk.agents.run_config import RunConfig from google.adk.agents.sequential_agent import SequentialAgent from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.events.event import Event from google.adk.features import FeatureName from google.adk.features._feature_registry import temporary_feature_override from google.adk.memory.in_memory_memory_service import InMemoryMemoryService @@ -985,6 +987,109 @@ async def test_run_async_handles_none_parts_in_response(): assert tool_result == '' +async def _run_agent_tool_with_parts(parts: list[types.Part]) -> Any: + """Drives AgentTool with an inner agent whose final event content is `parts`.""" + + class _StaticAgent(BaseAgent): + + async def _run_async_impl(self, ctx): + yield Event( + invocation_id=ctx.invocation_id, + author=self.name, + content=types.Content(role='model', parts=parts), + ) + + inner = _StaticAgent(name='inner_agent', description='static') + agent_tool = AgentTool(agent=inner) + + session_service = InMemorySessionService() + session = await session_service.create_session( + app_name='test_app', user_id='test_user' + ) + invocation_context = InvocationContext( + invocation_id='invocation_id', + agent=inner, + session=session, + session_service=session_service, + ) + tool_context = ToolContext(invocation_context=invocation_context) + + return await agent_tool.run_async( + args={'request': 'test 
request'}, tool_context=tool_context + ) + + +@mark.asyncio +async def test_run_async_extracts_text_only(): + """Plain text parts pass through unchanged.""" + result = await _run_agent_tool_with_parts([types.Part(text='hello world')]) + assert result == 'hello world' + + +@mark.asyncio +async def test_run_async_extracts_code_execution_result_only(): + """code_execution_result.output and executable_code.code are returned.""" + result = await _run_agent_tool_with_parts([ + types.Part( + executable_code=types.ExecutableCode( + language=types.Language.PYTHON, code='print(2 ** 10)' + ) + ), + types.Part( + code_execution_result=types.CodeExecutionResult( + outcome=types.Outcome.OUTCOME_OK, output='1024\n' + ) + ), + ]) + assert result == 'print(2 ** 10)\n1024' + + +@mark.asyncio +async def test_run_async_extracts_text_and_code_execution_result(): + """Mixed text + code parts are concatenated in order.""" + result = await _run_agent_tool_with_parts([ + types.Part(text='Here is the answer:'), + types.Part( + executable_code=types.ExecutableCode( + language=types.Language.PYTHON, code='print(2 ** 10)' + ) + ), + types.Part( + code_execution_result=types.CodeExecutionResult( + outcome=types.Outcome.OUTCOME_OK, output='1024\n' + ) + ), + ]) + assert result == 'Here is the answer:\nprint(2 ** 10)\n1024' + + +@mark.asyncio +async def test_run_async_extracts_executable_code_only(): + """executable_code.code alone is returned when no result part follows.""" + result = await _run_agent_tool_with_parts([ + types.Part( + executable_code=types.ExecutableCode( + language=types.Language.PYTHON, code='print("hi")' + ) + ), + ]) + assert result == 'print("hi")' + + +@mark.asyncio +async def test_run_async_skips_thought_parts(): + """Parts marked thought=True are dropped regardless of kind.""" + result = await _run_agent_tool_with_parts([ + types.Part(text='thinking out loud', thought=True), + types.Part( + code_execution_result=types.CodeExecutionResult( + 
outcome=types.Outcome.OUTCOME_OK, output='42\n' + ) + ), + ]) + assert result == '42' + + class TestAgentToolWithCompositeAgents: """Tests for AgentTool wrapping composite agents (SequentialAgent, etc.).""" diff --git a/tests/unittests/tools/test_skill_toolset.py b/tests/unittests/tools/test_skill_toolset.py index b58c01b91b..f637377513 100644 --- a/tests/unittests/tools/test_skill_toolset.py +++ b/tests/unittests/tools/test_skill_toolset.py @@ -13,6 +13,7 @@ # limitations under the License. import asyncio +import collections import logging import sys from unittest import mock @@ -1859,14 +1860,14 @@ async def test_turn_scoped_skill_cache_eviction(mock_registry, mock_skill1): for i in range(16): await toolset._get_or_fetch_skill("skill1", f"turn-{i}") - assert len(toolset._invocation_cache) == 16 - assert "turn-0" in toolset._invocation_cache + assert len(toolset._fetched_skill_cache) == 16 + assert "turn-0" in toolset._fetched_skill_cache # Next turn should evict oldest (turn-0) await toolset._get_or_fetch_skill("skill1", "turn-16") - assert len(toolset._invocation_cache) == 16 - assert "turn-0" not in toolset._invocation_cache - assert "turn-1" in toolset._invocation_cache + assert len(toolset._fetched_skill_cache) == 16 + assert "turn-0" not in toolset._fetched_skill_cache + assert "turn-1" in toolset._fetched_skill_cache @pytest.mark.asyncio @@ -1891,3 +1892,36 @@ async def delayed_get_skill(name): # Registry should have been called exactly once mock_registry.get_skill.assert_called_once_with(name="skill1") + + +def test_skill_toolset_disables_invocation_cache(): + """Verify SkillToolset disables tool invocation caching to allow dynamic tools.""" + toolset = skill_toolset.SkillToolset() + assert toolset._use_invocation_cache is False + + +@pytest.mark.asyncio +async def test_close_cancels_futures_and_clears_cache(): + # pylint: disable=protected-access + toolset = skill_toolset.SkillToolset() + + # Create mock futures for testing close() behavior + loop = 
asyncio.get_running_loop() + fut1 = loop.create_future() + fut2 = loop.create_future() + fut2.set_result(None) # Already done future + + toolset._fetched_skill_cache = collections.OrderedDict( + { + "turn1": { + "skill1": fut1, + "skill2": fut2, + } + } + ) + + await toolset.close() + + assert fut1.cancelled() + assert not fut2.cancelled() # Done futures shouldn't/can't be cancelled + assert not toolset._fetched_skill_cache diff --git a/tests/unittests/utils/test_context_utils.py b/tests/unittests/utils/test_context_utils.py index 56583b9c6d..b8173be4b0 100644 --- a/tests/unittests/utils/test_context_utils.py +++ b/tests/unittests/utils/test_context_utils.py @@ -15,10 +15,12 @@ """Tests for context_utils module.""" from typing import Optional +from unittest import mock from google.adk.agents.callback_context import CallbackContext from google.adk.agents.context import Context from google.adk.tools.tool_context import ToolContext +from google.adk.utils import context_utils from google.adk.utils.context_utils import find_context_parameter @@ -129,3 +131,25 @@ def my_tool( return query assert find_context_parameter(my_tool) == 'ctx' + + +class TestFindContextParameterCaching: + """Tests for find_context_parameter caching behavior.""" + + def test_repeated_calls_inspect_signature_once(self): + """Repeated calls with the same function reuse the cached result.""" + + def my_tool(ctx: Context) -> str: + return 'ok' + + find_context_parameter.cache_clear() + + with mock.patch.object( + context_utils.inspect, + 'signature', + wraps=context_utils.inspect.signature, + ) as spy: + for _ in range(10): + assert find_context_parameter(my_tool) == 'ctx' + + assert spy.call_count == 1