From bfebfbc42779cf371e306b64af08950949584956 Mon Sep 17 00:00:00 2001 From: Salman Muin Kayser Chishti <13schishti@gmail.com> Date: Fri, 23 Jan 2026 08:49:52 +0000 Subject: [PATCH 01/40] Upgrade GitHub Actions for Node 24 compatibility Signed-off-by: Salman Muin Kayser Chishti <13schishti@gmail.com> --- .github/workflows/pr-commit-check.yml | 2 +- .github/workflows/validation.yml | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/pr-commit-check.yml b/.github/workflows/pr-commit-check.yml index ec6644311..1e31e42f3 100644 --- a/.github/workflows/pr-commit-check.yml +++ b/.github/workflows/pr-commit-check.yml @@ -21,7 +21,7 @@ jobs: # Step 1: Check out the code # This action checks out your repository under $GITHUB_WORKSPACE, so your workflow can access it. - name: Checkout Code - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: # We need to fetch all commits to accurately count them. # '0' means fetch all history for all branches and tags. 
diff --git a/.github/workflows/validation.yml b/.github/workflows/validation.yml index eeb16e1ff..26a276f05 100644 --- a/.github/workflows/validation.yml +++ b/.github/workflows/validation.yml @@ -20,16 +20,16 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up Java ${{ matrix.java-version }} - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: distribution: temurin java-version: ${{ matrix.java-version }} - name: Cache Maven packages - uses: actions/cache@v3 + uses: actions/cache@v5 with: path: ~/.m2/repository key: ${{ runner.os }}-maven-${{ matrix.java-version }}-${{ hashFiles('**/pom.xml') }} From 5ea48cd6f69cb5221317e6660f15c88053565de7 Mon Sep 17 00:00:00 2001 From: OwenDavisBC Date: Mon, 9 Feb 2026 13:43:42 -0700 Subject: [PATCH 02/40] ISSUE-777: Ensure token usage metadata included with streaming responses --- .../java/com/google/adk/models/Gemini.java | 25 ++- .../com/google/adk/models/GeminiTest.java | 191 ++++++++++++++++++ 2 files changed, 211 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/com/google/adk/models/Gemini.java b/core/src/main/java/com/google/adk/models/Gemini.java index 74cf78b98..6f145e1de 100644 --- a/core/src/main/java/com/google/adk/models/Gemini.java +++ b/core/src/main/java/com/google/adk/models/Gemini.java @@ -239,7 +239,7 @@ public Flowable generateContent(LlmRequest llmRequest, boolean stre p -> p.functionCall().isPresent() || p.functionResponse().isPresent() - || p.text().map(t -> !t.isBlank()).orElse(false))) + || p.text().isPresent())) .orElse(false)); } else { logger.debug("Sending generateContent request to model {}", effectiveModelName); @@ -272,11 +272,17 @@ static Flowable processRawResponses(Flowable 0 @@ -316,11 +322,20 @@ static Flowable processRawResponses(Flowable finalResponses = new ArrayList<>(); if (accumulatedThoughtText.length() > 0) { finalResponses.add( - thinkingResponseFromText(accumulatedThoughtText.toString())); + 
thinkingResponseFromText(accumulatedThoughtText.toString()).toBuilder() + .usageMetadata( + accumulatedText.length() > 0 + ? null + : finalRawResp.usageMetadata().orElse(null)) + .build()); } if (accumulatedText.length() > 0) { - finalResponses.add(responseFromText(accumulatedText.toString())); + finalResponses.add( + responseFromText(accumulatedText.toString()).toBuilder() + .usageMetadata(finalRawResp.usageMetadata().orElse(null)) + .build()); } + return Flowable.fromIterable(finalResponses); } return Flowable.empty(); diff --git a/core/src/test/java/com/google/adk/models/GeminiTest.java b/core/src/test/java/com/google/adk/models/GeminiTest.java index 07dd675e5..c230f5f68 100644 --- a/core/src/test/java/com/google/adk/models/GeminiTest.java +++ b/core/src/test/java/com/google/adk/models/GeminiTest.java @@ -22,6 +22,7 @@ import com.google.genai.types.Content; import com.google.genai.types.FinishReason; import com.google.genai.types.GenerateContentResponse; +import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.functions.Predicate; @@ -123,6 +124,76 @@ public void processRawResponses_textThenEmpty_emitsPartialTextThenFullTextAndEmp isEmptyResponse()); } + @Test + public void processRawResponses_withTextChunks_partialResponsesIncludeUsageMetadata() { + GenerateContentResponseUsageMetadata metadata1 = createUsageMetadata(5, 10, 15); + GenerateContentResponseUsageMetadata metadata2 = createUsageMetadata(5, 20, 25); + Flowable rawResponses = + Flowable.just( + toResponseWithText("Hello", metadata1), toResponseWithText(" world", metadata2)); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isPartialTextResponseWithUsageMetadata("Hello", metadata1), + isPartialTextResponseWithUsageMetadata(" world", metadata2)); + } + + @Test + public void 
processRawResponses_textAndStopReason_finalResponseIncludesUsageMetadata() { + GenerateContentResponseUsageMetadata metadata = createUsageMetadata(10, 20, 30); + Flowable rawResponses = + Flowable.just( + toResponseWithText("Hello"), + toResponseWithText(" world", FinishReason.Known.STOP, metadata)); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isPartialTextResponse("Hello"), + isPartialTextResponseWithUsageMetadata(" world", metadata), + isFinalTextResponseWithUsageMetadata("Hello world", metadata)); + } + + @Test + public void processRawResponses_thoughtChunksAndStop_includeUsageMetadata() { + GenerateContentResponseUsageMetadata metadata1 = createUsageMetadata(5, 10, 15); + GenerateContentResponseUsageMetadata metadata2 = createUsageMetadata(5, 20, 25); + Flowable rawResponses = + Flowable.just( + toResponseWithThoughtText("Thinking", metadata1), + toResponseWithThoughtText(" deeply", FinishReason.Known.STOP, metadata2)); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isPartialThoughtResponseWithUsageMetadata("Thinking", metadata1), + isPartialThoughtResponseWithUsageMetadata(" deeply", metadata2), + isFinalThoughtResponseWithUsageMetadata("Thinking deeply", metadata2)); + } + + @Test + public void processRawResponses_thoughtAndTextWithStop_onlyFinalTextIncludesUsageMetadata() { + GenerateContentResponseUsageMetadata metadata1 = createUsageMetadata(5, 5, 10); + GenerateContentResponseUsageMetadata metadata2 = createUsageMetadata(10, 20, 30); + Flowable rawResponses = + Flowable.just( + toResponseWithThoughtText("Thinking", metadata1), + toResponseWithText("Answer", FinishReason.Known.STOP, metadata2)); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isPartialThoughtResponseWithUsageMetadata("Thinking", metadata1), + isPartialTextResponseWithUsageMetadata("Answer", 
metadata2), + isFinalThoughtResponseWithNoUsageMetadata("Thinking"), + isFinalTextResponseWithUsageMetadata("Answer", metadata2)); + } + // Helper methods for assertions private void assertLlmResponses( @@ -170,6 +241,67 @@ private static Predicate isEmptyResponse() { }; } + private static Predicate isPartialTextResponseWithUsageMetadata( + String expectedText, GenerateContentResponseUsageMetadata expectedMetadata) { + return response -> { + assertThat(response.partial()).hasValue(true); + assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse("")) + .isEqualTo(expectedText); + assertThat(response.usageMetadata()).hasValue(expectedMetadata); + return true; + }; + } + + private static Predicate isPartialThoughtResponseWithUsageMetadata( + String expectedText, GenerateContentResponseUsageMetadata expectedMetadata) { + return response -> { + assertThat(response.partial()).hasValue(true); + assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse("")) + .isEqualTo(expectedText); + assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::thought).orElse(false)) + .isTrue(); + assertThat(response.usageMetadata()).hasValue(expectedMetadata); + return true; + }; + } + + private static Predicate isFinalTextResponseWithUsageMetadata( + String expectedText, GenerateContentResponseUsageMetadata expectedMetadata) { + return response -> { + assertThat(response.partial()).isEmpty(); + assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse("")) + .isEqualTo(expectedText); + assertThat(response.usageMetadata()).hasValue(expectedMetadata); + return true; + }; + } + + private static Predicate isFinalThoughtResponseWithUsageMetadata( + String expectedText, GenerateContentResponseUsageMetadata expectedMetadata) { + return response -> { + assertThat(response.partial()).isEmpty(); + assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse("")) + 
.isEqualTo(expectedText); + assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::thought).orElse(false)) + .isTrue(); + assertThat(response.usageMetadata()).hasValue(expectedMetadata); + return true; + }; + } + + private static Predicate isFinalThoughtResponseWithNoUsageMetadata( + String expectedText) { + return response -> { + assertThat(response.partial()).isEmpty(); + assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse("")) + .isEqualTo(expectedText); + assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::thought).orElse(false)) + .isTrue(); + assertThat(response.usageMetadata()).isEmpty(); + return true; + }; + } + // Helper methods to create responses for testing private GenerateContentResponse toResponseWithText(String text) { @@ -191,4 +323,63 @@ private GenerateContentResponse toResponse(Part part) { private GenerateContentResponse toResponse(Candidate candidate) { return GenerateContentResponse.builder().candidates(candidate).build(); } + + private GenerateContentResponse toResponseWithText( + String text, GenerateContentResponseUsageMetadata usageMetadata) { + return GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content(Content.builder().parts(Part.fromText(text)).build()) + .build()) + .usageMetadata(usageMetadata) + .build(); + } + + private GenerateContentResponse toResponseWithText( + String text, + FinishReason.Known finishReason, + GenerateContentResponseUsageMetadata usageMetadata) { + return GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content(Content.builder().parts(Part.fromText(text)).build()) + .finishReason(new FinishReason(finishReason)) + .build()) + .usageMetadata(usageMetadata) + .build(); + } + + private GenerateContentResponse toResponseWithThoughtText( + String text, GenerateContentResponseUsageMetadata usageMetadata) { + Part thoughtPart = Part.fromText(text).toBuilder().thought(true).build(); + return 
GenerateContentResponse.builder() + .candidates( + Candidate.builder().content(Content.builder().parts(thoughtPart).build()).build()) + .usageMetadata(usageMetadata) + .build(); + } + + private GenerateContentResponse toResponseWithThoughtText( + String text, + FinishReason.Known finishReason, + GenerateContentResponseUsageMetadata usageMetadata) { + Part thoughtPart = Part.fromText(text).toBuilder().thought(true).build(); + return GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content(Content.builder().parts(thoughtPart).build()) + .finishReason(new FinishReason(finishReason)) + .build()) + .usageMetadata(usageMetadata) + .build(); + } + + private static GenerateContentResponseUsageMetadata createUsageMetadata( + int promptTokens, int candidateTokens, int totalTokens) { + return GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(promptTokens) + .candidatesTokenCount(candidateTokens) + .totalTokenCount(totalTokens) + .build(); + } } From 28a8cd04ca9348dbe51a15d2be3a2b5307394174 Mon Sep 17 00:00:00 2001 From: Mateusz Krawiec Date: Tue, 17 Mar 2026 01:48:10 -0700 Subject: [PATCH 03/40] chore!: remove deprecated Example processor PiperOrigin-RevId: 884881559 --- .../java/com/google/adk/agents/LlmAgent.java | 47 +++------ .../com/google/adk/examples/ExampleUtils.java | 3 + .../google/adk/flows/llmflows/Examples.java | 57 ----------- .../google/adk/flows/llmflows/SingleFlow.java | 1 - .../com/google/adk/tools/ExampleTool.java | 4 +- .../com/google/adk/agents/LlmAgentTest.java | 28 ++++++ .../google/adk/examples/ExampleUtilsTest.java | 12 +-- .../adk/flows/llmflows/ExamplesTest.java | 99 ------------------- .../com/google/adk/tools/ExampleToolTest.java | 26 +++++ 9 files changed, 73 insertions(+), 204 deletions(-) delete mode 100644 core/src/main/java/com/google/adk/flows/llmflows/Examples.java delete mode 100644 core/src/test/java/com/google/adk/flows/llmflows/ExamplesTest.java diff --git 
a/core/src/main/java/com/google/adk/agents/LlmAgent.java b/core/src/main/java/com/google/adk/agents/LlmAgent.java index d326d8154..89024a59b 100644 --- a/core/src/main/java/com/google/adk/agents/LlmAgent.java +++ b/core/src/main/java/com/google/adk/agents/LlmAgent.java @@ -45,8 +45,6 @@ import com.google.adk.agents.ConfigAgentUtils.ConfigurationException; import com.google.adk.codeexecutors.BaseCodeExecutor; import com.google.adk.events.Event; -import com.google.adk.examples.BaseExampleProvider; -import com.google.adk.examples.Example; import com.google.adk.flows.llmflows.AutoFlow; import com.google.adk.flows.llmflows.BaseLlmFlow; import com.google.adk.flows.llmflows.SingleFlow; @@ -97,8 +95,6 @@ public enum IncludeContents { private final List toolsUnion; private final ImmutableList toolsets; private final Optional generateContentConfig; - // TODO: Remove exampleProvider field - examples should only be provided via ExampleTool - private final Optional exampleProvider; private final IncludeContents includeContents; private final boolean planning; @@ -132,7 +128,6 @@ protected LlmAgent(Builder builder) { this.globalInstruction = requireNonNullElse(builder.globalInstruction, new Instruction.Static("")); this.generateContentConfig = Optional.ofNullable(builder.generateContentConfig); - this.exampleProvider = Optional.ofNullable(builder.exampleProvider); this.includeContents = requireNonNullElse(builder.includeContents, IncludeContents.DEFAULT); this.planning = builder.planning != null && builder.planning; this.maxSteps = Optional.ofNullable(builder.maxSteps); @@ -180,7 +175,6 @@ public static class Builder extends BaseAgent.Builder { private Instruction globalInstruction; private ImmutableList toolsUnion; private GenerateContentConfig generateContentConfig; - private BaseExampleProvider exampleProvider; private IncludeContents includeContents; private Boolean planning; private Integer maxSteps; @@ -253,26 +247,6 @@ public Builder 
generateContentConfig(GenerateContentConfig generateContentConfig return this; } - // TODO: Remove these example provider methods and only use ExampleTool for providing examples. - // Direct example methods should be deprecated in favor of using ExampleTool consistently. - @CanIgnoreReturnValue - public Builder exampleProvider(BaseExampleProvider exampleProvider) { - this.exampleProvider = exampleProvider; - return this; - } - - @CanIgnoreReturnValue - public Builder exampleProvider(List examples) { - this.exampleProvider = (unused) -> examples; - return this; - } - - @CanIgnoreReturnValue - public Builder exampleProvider(Example... examples) { - this.exampleProvider = (unused) -> ImmutableList.copyOf(examples); - return this; - } - @CanIgnoreReturnValue public Builder includeContents(IncludeContents includeContents) { this.includeContents = includeContents; @@ -640,10 +614,18 @@ protected void validate() { + " transfer."); } if (this.toolsUnion != null && !this.toolsUnion.isEmpty()) { - throw new IllegalArgumentException( - "Invalid config for agent " - + this.name - + ": if outputSchema is set, tools must be empty."); + boolean hasOtherTools = + this.toolsUnion.stream() + .anyMatch( + tool -> + !(tool instanceof BaseTool baseTool) + || !baseTool.name().equals("example_tool")); + if (hasOtherTools) { + throw new IllegalArgumentException( + "Invalid config for agent " + + this.name + + ": if outputSchema is set, tools must be empty."); + } } } } @@ -812,11 +794,6 @@ public Optional generateContentConfig() { return generateContentConfig; } - // TODO: Remove this getter - examples should only be provided via ExampleTool - public Optional exampleProvider() { - return exampleProvider; - } - public IncludeContents includeContents() { return includeContents; } diff --git a/core/src/main/java/com/google/adk/examples/ExampleUtils.java b/core/src/main/java/com/google/adk/examples/ExampleUtils.java index 9cce535dc..2f3927ece 100644 --- 
a/core/src/main/java/com/google/adk/examples/ExampleUtils.java +++ b/core/src/main/java/com/google/adk/examples/ExampleUtils.java @@ -64,6 +64,9 @@ public final class ExampleUtils { * @return string representation of the examples block. */ private static String convertExamplesToText(List examples) { + if (examples.isEmpty()) { + return ""; + } StringBuilder examplesStr = new StringBuilder(); // super header diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Examples.java b/core/src/main/java/com/google/adk/flows/llmflows/Examples.java deleted file mode 100644 index d9cee5fa0..000000000 --- a/core/src/main/java/com/google/adk/flows/llmflows/Examples.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.adk.flows.llmflows; - -import com.google.adk.agents.InvocationContext; -import com.google.adk.agents.LlmAgent; -import com.google.adk.examples.ExampleUtils; -import com.google.adk.models.LlmRequest; -import com.google.common.collect.ImmutableList; -import com.google.genai.types.Content; -import io.reactivex.rxjava3.core.Single; - -/** {@link RequestProcessor} that populates examples in LLM request. 
*/ -public final class Examples implements RequestProcessor { - - public Examples() {} - - @Override - public Single processRequest( - InvocationContext context, LlmRequest request) { - if (!(context.agent() instanceof LlmAgent)) { - throw new IllegalArgumentException("Agent in InvocationContext is not an instance of Agent."); - } - LlmAgent agent = (LlmAgent) context.agent(); - LlmRequest.Builder builder = request.toBuilder(); - - String query = - context - .userContent() - .flatMap(Content::parts) - .filter(parts -> !parts.isEmpty()) - .map(parts -> parts.get(0).text().orElse("")) - .orElse(""); - agent - .exampleProvider() - .ifPresent( - exampleProvider -> - builder.appendInstructions( - ImmutableList.of(ExampleUtils.buildExampleSi(exampleProvider, query)))); - return Single.just( - RequestProcessor.RequestProcessingResult.create(builder.build(), ImmutableList.of())); - } -} diff --git a/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java index de45ba702..f56cc61c3 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java @@ -32,7 +32,6 @@ public class SingleFlow extends BaseLlmFlow { new Identity(), new Compaction(), new Contents(), - new Examples(), CodeExecution.requestProcessor); protected static final ImmutableList RESPONSE_PROCESSORS = diff --git a/core/src/main/java/com/google/adk/tools/ExampleTool.java b/core/src/main/java/com/google/adk/tools/ExampleTool.java index d08481532..d03c2e4f1 100644 --- a/core/src/main/java/com/google/adk/tools/ExampleTool.java +++ b/core/src/main/java/com/google/adk/tools/ExampleTool.java @@ -85,7 +85,9 @@ public Completable processLlmRequest( return Completable.complete(); } - llmRequestBuilder.appendInstructions(ImmutableList.of(examplesBlock)); + if (!examplesBlock.isEmpty()) { + llmRequestBuilder.appendInstructions(ImmutableList.of(examplesBlock)); + } 
// Delegate to BaseTool to keep any declaration bookkeeping (none for this tool) return super.processLlmRequest(llmRequestBuilder, toolContext); } diff --git a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java index 594e47fd8..3524c7755 100644 --- a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java @@ -35,6 +35,7 @@ import com.google.adk.agents.Callbacks.OnModelErrorCallback; import com.google.adk.agents.Callbacks.OnToolErrorCallback; import com.google.adk.events.Event; +import com.google.adk.examples.Example; import com.google.adk.models.LlmRegistry; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; @@ -46,6 +47,7 @@ import com.google.adk.testing.TestUtils.EchoTool; import com.google.adk.tools.BaseTool; import com.google.adk.tools.BaseToolset; +import com.google.adk.tools.ExampleTool; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.errorprone.annotations.CanIgnoreReturnValue; @@ -649,4 +651,30 @@ private SpanData findSpanByName(List spans, String name) { .findFirst() .orElseThrow(() -> new AssertionError("Span not found: " + name)); } + + @Test + public void run_withExampleTool_doesNotAddFunctionDeclarations() { + ExampleTool tool = + ExampleTool.builder() + .addExample( + Example.builder() + .input(Content.fromParts(Part.fromText("qin"))) + .output(ImmutableList.of(Content.fromParts(Part.fromText("qout")))) + .build()) + .build(); + + Content modelContent = Content.fromParts(Part.fromText("Real LLM response")); + TestLlm testLlm = createTestLlm(createLlmResponse(modelContent)); + LlmAgent agent = createTestAgentBuilder(testLlm).tools(tool).build(); + InvocationContext invocationContext = createInvocationContext(agent); + + var unused = agent.runAsync(invocationContext).toList().blockingGet(); + + 
assertThat(testLlm.getRequests()).hasSize(1); + LlmRequest request = testLlm.getRequests().get(0); + + assertThat(request.config().isPresent()).isTrue(); + var config = request.config().get(); + assertThat(config.tools().isPresent()).isFalse(); + } } diff --git a/core/src/test/java/com/google/adk/examples/ExampleUtilsTest.java b/core/src/test/java/com/google/adk/examples/ExampleUtilsTest.java index 4a1dcf8e3..2d22ed3f1 100644 --- a/core/src/test/java/com/google/adk/examples/ExampleUtilsTest.java +++ b/core/src/test/java/com/google/adk/examples/ExampleUtilsTest.java @@ -49,17 +49,7 @@ public List getExamples(String query) { @Test public void buildFewShotFewShot_noExamples() { TestExampleProvider exampleProvider = new TestExampleProvider(ImmutableList.of()); - String expected = - """ - - Begin few-shot - The following are examples of user queries and model responses using the available tools. - - End few-shot - Now, try to follow these examples and complete the following conversation - \ - """; - assertThat(ExampleUtils.buildExampleSi(exampleProvider, "test query")).isEqualTo(expected); + assertThat(ExampleUtils.buildExampleSi(exampleProvider, "test query")).isEmpty(); } @Test diff --git a/core/src/test/java/com/google/adk/flows/llmflows/ExamplesTest.java b/core/src/test/java/com/google/adk/flows/llmflows/ExamplesTest.java deleted file mode 100644 index 7d1615dc2..000000000 --- a/core/src/test/java/com/google/adk/flows/llmflows/ExamplesTest.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Copyright 2026 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.adk.flows.llmflows; - -import static com.google.common.truth.Truth.assertThat; - -import com.google.adk.agents.InvocationContext; -import com.google.adk.agents.LlmAgent; -import com.google.adk.agents.RunConfig; -import com.google.adk.examples.BaseExampleProvider; -import com.google.adk.examples.Example; -import com.google.adk.models.LlmRequest; -import com.google.adk.sessions.InMemorySessionService; -import com.google.adk.sessions.Session; -import com.google.common.collect.ImmutableList; -import com.google.genai.types.Content; -import com.google.genai.types.Part; -import java.util.List; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -@RunWith(JUnit4.class) -public final class ExamplesTest { - - private static final InMemorySessionService sessionService = new InMemorySessionService(); - - private static class TestExampleProvider implements BaseExampleProvider { - @Override - public List getExamples(String query) { - return ImmutableList.of( - Example.builder() - .input(Content.fromParts(Part.fromText("input1"))) - .output( - ImmutableList.of( - Content.builder().parts(Part.fromText("output1")).role("model").build())) - .build()); - } - } - - @Test - public void processRequest_withExampleProvider_addsExamplesToInstructions() { - LlmAgent agent = - LlmAgent.builder().name("test-agent").exampleProvider(new TestExampleProvider()).build(); - InvocationContext context = - InvocationContext.builder() - .invocationId("invocation1") - .session(Session.builder("session1").build()) - .sessionService(sessionService) - .agent(agent) - .userContent(Content.fromParts(Part.fromText("what is up?"))) - .runConfig(RunConfig.builder().build()) - .build(); - LlmRequest request = LlmRequest.builder().build(); - Examples examplesProcessor = new Examples(); - - RequestProcessor.RequestProcessingResult result = - 
examplesProcessor.processRequest(context, request).blockingGet(); - - assertThat(result.updatedRequest().getSystemInstructions()).isNotEmpty(); - assertThat(result.updatedRequest().getSystemInstructions().get(0)) - .contains("[user]\ninput1\n\n[model]\noutput1\n"); - } - - @Test - public void processRequest_withoutExampleProvider_doesNotAddExamplesToInstructions() { - LlmAgent agent = LlmAgent.builder().name("test-agent").build(); - InvocationContext context = - InvocationContext.builder() - .invocationId("invocation1") - .session(Session.builder("session1").build()) - .sessionService(sessionService) - .agent(agent) - .userContent(Content.fromParts(Part.fromText("what is up?"))) - .runConfig(RunConfig.builder().build()) - .build(); - LlmRequest request = LlmRequest.builder().build(); - Examples examplesProcessor = new Examples(); - - RequestProcessor.RequestProcessingResult result = - examplesProcessor.processRequest(context, request).blockingGet(); - - assertThat(result.updatedRequest().getSystemInstructions()).isEmpty(); - } -} diff --git a/core/src/test/java/com/google/adk/tools/ExampleToolTest.java b/core/src/test/java/com/google/adk/tools/ExampleToolTest.java index 4e80ed0ff..55e5d8f93 100644 --- a/core/src/test/java/com/google/adk/tools/ExampleToolTest.java +++ b/core/src/test/java/com/google/adk/tools/ExampleToolTest.java @@ -305,4 +305,30 @@ static final class WrongTypeProviderHolder { private WrongTypeProviderHolder() {} } + + @Test + public void declaration_isEmpty() { + ExampleTool tool = ExampleTool.builder().build(); + assertThat(tool.declaration().isPresent()).isFalse(); + } + + @Test + public void processLlmRequest_doesNotAddFunctionDeclarations() { + ExampleTool tool = ExampleTool.builder().addExample(makeExample("qin", "qout")).build(); + InvocationContext ctx = buildInvocationContext(); + LlmRequest.Builder builder = LlmRequest.builder().model("gemini-2.0-flash"); + + tool.processLlmRequest(builder, 
ToolContext.builder(ctx).build()).blockingAwait(); + LlmRequest updated = builder.build(); + + if (updated.config().isPresent()) { + var config = updated.config().get(); + if (config.tools().isPresent()) { + var tools = config.tools().get(); + boolean hasFunctionDeclarations = + tools.stream().anyMatch(t -> t.functionDeclarations().isPresent()); + assertThat(hasFunctionDeclarations).isFalse(); + } + } + } } From 8556d4af16ff04c6e3b678dcfc3d4bb232abc550 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 17 Mar 2026 06:59:16 -0700 Subject: [PATCH 04/40] feat: Propagating the otel context This change ensures that the OpenTelemetry context is correctly propagated across asynchronous boundaries throughout the ADK, primarily within RxJava streams. ### Key Changes * **Context Propagation:** Replaces manual `Scope` management (which often fails in reactive code) with `.compose(Tracing.withContext(context))`. This ensures the OTel context is preserved when work moves between different threads or schedulers. * **`Runner` Refactoring:** * Adds a top-level `"invocation"` span to `runAsync` and `runLive` calls. * Captures the context at entry points and propagates it through the internal execution flow (`runAsyncImpl`, `runLiveImpl`, `runAgentWithFreshSession`). * **`BaseLlmFlow` & `Functions`:** Updates preprocessing, postprocessing, and tool execution logic to maintain context. This ensures that spans created within tools or processors are correctly parented. * **`PluginManager`:** Ensures that plugin callbacks (like `afterRunCallback` and `onEventCallback`) execute within the captured context. * **Testing:** Adds several unit tests across `BaseLlmFlowTest`, `FunctionsTest`, `PluginManagerTest`, and `RunnerTest` that specifically verify context propagation using `ContextKey` and `Schedulers.computation()`. ### Files Modified * **`BaseLlmFlow.java`**, **`Functions.java`**, **`PluginManager.java`**, **`Runner.java`**: Core logic updates for context propagation. 
* **`LlmAgentTest.java`**, **`BaseLlmFlowTest.java`**, **`FunctionsTest.java`**, **`PluginManagerTest.java`**, **`RunnerTest.java`**: New tests for OTel integration. * **`BUILD` files**: Updated dependencies for OpenTelemetry APIs and SDK testing. PiperOrigin-RevId: 884998997 --- .../adk/flows/llmflows/BaseLlmFlow.java | 208 +++++++++++------- .../google/adk/flows/llmflows/Functions.java | 138 ++++++------ ...equestConfirmationLlmRequestProcessor.java | 13 +- .../com/google/adk/plugins/PluginManager.java | 15 +- .../java/com/google/adk/runner/Runner.java | 163 ++++++++------ .../com/google/adk/agents/LlmAgentTest.java | 9 +- .../adk/flows/llmflows/BaseLlmFlowTest.java | 76 ++++++- .../google/adk/plugins/PluginManagerTest.java | 85 +++++++ .../com/google/adk/runner/RunnerTest.java | 128 +++++++++++ 9 files changed, 601 insertions(+), 234 deletions(-) diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index ab5f6567a..e00cf0cbf 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -91,8 +91,9 @@ public BaseLlmFlow( * RequestProcessor} transforming the provided {@code llmRequestRef} in-place, and emits the * events generated by them. */ - protected Flowable preprocess( + private Flowable preprocess( InvocationContext context, AtomicReference llmRequestRef) { + Context currentContext = Context.current(); LlmAgent agent = (LlmAgent) context.agent(); RequestProcessor toolsProcessor = @@ -114,6 +115,7 @@ protected Flowable preprocess( .concatMap( processor -> Single.defer(() -> processor.processRequest(context, llmRequestRef.get())) + .compose(Tracing.withContext(currentContext)) .doOnSuccess(result -> llmRequestRef.set(result.updatedRequest())) .flattenAsFlowable( result -> result.events() != null ? 
result.events() : ImmutableList.of())); @@ -128,7 +130,8 @@ protected Flowable postprocess( InvocationContext context, Event baseEventForLlmResponse, LlmRequest llmRequest, - LlmResponse llmResponse) { + LlmResponse llmResponse, + Context parentContext) { List> eventIterables = new ArrayList<>(); Single currentLlmResponse = Single.just(llmResponse); @@ -144,15 +147,16 @@ protected Flowable postprocess( }) .map(ResponseProcessingResult::updatedResponse); } - Context parentContext = Context.current(); - return currentLlmResponse.flatMapPublisher( - updatedResponse -> { - try (Scope scope = parentContext.makeCurrent()) { - return buildPostprocessingEvents( - updatedResponse, eventIterables, context, baseEventForLlmResponse, llmRequest); - } - }); + updatedResponse -> + buildPostprocessingEvents( + updatedResponse, + eventIterables, + context, + baseEventForLlmResponse, + llmRequest, + parentContext) + .compose(Tracing.withContext(parentContext))); } /** @@ -163,54 +167,80 @@ protected Flowable postprocess( * @param eventForCallbackUsage An Event object primarily for providing context (like actions) to * callbacks. Callbacks should not rely on its ID if they create their own separate events. */ - private Flowable callLlm( + private Flowable callLlm( Context spanContext, InvocationContext context, LlmRequest llmRequest, Event eventForCallbackUsage) { - LlmAgent agent = (LlmAgent) context.agent(); - LlmRequest.Builder llmRequestBuilder = llmRequest.toBuilder(); return handleBeforeModelCallback(context, llmRequestBuilder, eventForCallbackUsage) .toFlowable() + .concatMap( + llmResp -> + postprocess( + context, + eventForCallbackUsage, + llmRequestBuilder.build(), + llmResp, + spanContext)) .switchIfEmpty( Flowable.defer( () -> { + LlmAgent agent = (LlmAgent) context.agent(); BaseLlm llm = agent.resolvedModel().model().isPresent() ? 
agent.resolvedModel().model().get() : LlmRegistry.getLlm(agent.resolvedModel().modelName().get()); - return llm.generateContent( - llmRequestBuilder.build(), - context.runConfig().streamingMode() == StreamingMode.SSE) - .onErrorResumeNext( - exception -> - handleOnModelErrorCallback( - context, llmRequestBuilder, eventForCallbackUsage, exception) - .switchIfEmpty(Single.error(exception)) - .toFlowable()) - .doOnError( - error -> { - Span span = Span.current(); - span.setStatus(StatusCode.ERROR, error.getMessage()); - span.recordException(error); - }) - .compose( - Tracing.trace("call_llm") - .setParent(spanContext) - .onSuccess( - (span, llmResp) -> - Tracing.traceCallLlm( - span, + LlmRequest finalLlmRequest = llmRequestBuilder.build(); + + Span span = + Tracing.getTracer() + .spanBuilder("call_llm") + .setParent(spanContext) + .startSpan(); + Context callLlmContext = spanContext.with(span); + + Flowable flowable = + llm.generateContent( + finalLlmRequest, + context.runConfig().streamingMode() == StreamingMode.SSE) + .onErrorResumeNext( + exception -> + handleOnModelErrorCallback( + context, + llmRequestBuilder, + eventForCallbackUsage, + exception) + .switchIfEmpty(Single.error(exception)) + .toFlowable()) + .doOnError( + error -> { + span.setStatus(StatusCode.ERROR, error.getMessage()); + span.recordException(error); + }) + .concatMap( + llmResp -> + handleAfterModelCallback(context, llmResp, eventForCallbackUsage) + .toFlowable()) + .flatMap( + llmResp -> + postprocess( context, - eventForCallbackUsage.id(), - llmRequestBuilder.build(), - llmResp))) - .concatMap( - llmResp -> - handleAfterModelCallback(context, llmResp, eventForCallbackUsage) - .toFlowable()); + eventForCallbackUsage, + finalLlmRequest, + llmResp, + callLlmContext) + .doOnSubscribe( + s -> + Tracing.traceCallLlm( + span, + context, + eventForCallbackUsage.id(), + finalLlmRequest, + llmResp))); + + return Tracing.traceFlowable(callLlmContext, span, () -> flowable); })); } @@ -222,6 +252,7 @@ 
private Flowable callLlm( */ private Maybe handleBeforeModelCallback( InvocationContext context, LlmRequest.Builder llmRequestBuilder, Event modelResponseEvent) { + Context currentContext = Context.current(); Event callbackEvent = modelResponseEvent.toBuilder().build(); CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); @@ -240,7 +271,11 @@ private Maybe handleBeforeModelCallback( Maybe.defer( () -> Flowable.fromIterable(callbacks) - .concatMapMaybe(callback -> callback.call(callbackContext, llmRequestBuilder)) + .concatMapMaybe( + callback -> + callback + .call(callbackContext, llmRequestBuilder) + .compose(Tracing.withContext(currentContext))) .firstElement()); return pluginResult.switchIfEmpty(callbackResult); @@ -257,6 +292,7 @@ private Maybe handleOnModelErrorCallback( LlmRequest.Builder llmRequestBuilder, Event modelResponseEvent, Throwable throwable) { + Context currentContext = Context.current(); Event callbackEvent = modelResponseEvent.toBuilder().build(); CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); @@ -277,7 +313,11 @@ private Maybe handleOnModelErrorCallback( () -> { LlmRequest llmRequest = llmRequestBuilder.build(); return Flowable.fromIterable(callbacks) - .concatMapMaybe(callback -> callback.call(callbackContext, llmRequest, ex)) + .concatMapMaybe( + callback -> + callback + .call(callbackContext, llmRequest, ex) + .compose(Tracing.withContext(currentContext))) .firstElement(); }); @@ -292,6 +332,7 @@ private Maybe handleOnModelErrorCallback( */ private Single handleAfterModelCallback( InvocationContext context, LlmResponse llmResponse, Event modelResponseEvent) { + Context currentContext = Context.current(); Event callbackEvent = modelResponseEvent.toBuilder().build(); CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); @@ -310,7 +351,11 @@ private Single 
handleAfterModelCallback( Maybe.defer( () -> Flowable.fromIterable(callbacks) - .concatMapMaybe(callback -> callback.call(callbackContext, llmResponse)) + .concatMapMaybe( + callback -> + callback + .call(callbackContext, llmResponse) + .compose(Tracing.withContext(currentContext))) .firstElement()); return pluginResult.switchIfEmpty(callbackResult).defaultIfEmpty(llmResponse); @@ -330,7 +375,6 @@ private Flowable runOneStep(Context spanContext, InvocationContext contex return Flowable.defer( () -> { - Context currentContext = Context.current(); return preprocess(context, llmRequestRef) .concatWith( Flowable.defer( @@ -362,23 +406,12 @@ private Flowable runOneStep(Context spanContext, InvocationContext contex context, llmRequestAfterPreprocess, mutableEventTemplate) - .concatMap( - llmResponse -> { - try (Scope postScope = currentContext.makeCurrent()) { - return postprocess( - context, - mutableEventTemplate, - llmRequestAfterPreprocess, - llmResponse) - .doFinally( - () -> { - String oldId = mutableEventTemplate.id(); - String newId = Event.generateEventId(); - logger.debug( - "Resetting event ID from {} to {}", oldId, newId); - mutableEventTemplate.setId(newId); - }); - } + .doFinally( + () -> { + String oldId = mutableEventTemplate.id(); + String newId = Event.generateEventId(); + logger.debug("Resetting event ID from {} to {}", oldId, newId); + mutableEventTemplate.setId(newId); }) .concatMap( event -> { @@ -545,6 +578,10 @@ public void onError(Throwable e) { .author(invocationContext.agent().name()) .branch(invocationContext.branch().orElse(null)); + Span span = + Tracing.getTracer().spanBuilder("call_llm").setParent(spanContext).startSpan(); + Context callLlmContext = spanContext.with(span); + Flowable receiveFlow = connection .receive() @@ -556,7 +593,8 @@ public void onError(Throwable e) { invocationContext, baseEventForThisLlmResponse, llmRequestAfterPreprocess, - llmResponse); + llmResponse, + callLlmContext); }) .flatMap( event -> { @@ -592,7 +630,12 @@ 
public void onError(Throwable e) { } }); - return receiveFlow.takeWhile(event -> !event.actions().endInvocation().orElse(false)); + return Tracing.traceFlowable( + callLlmContext, + span, + () -> + receiveFlow.takeWhile( + event -> !event.actions().endInvocation().orElse(false))); })); } @@ -608,7 +651,8 @@ private Flowable buildPostprocessingEvents( List> eventIterables, InvocationContext context, Event baseEventForLlmResponse, - LlmRequest llmRequest) { + LlmRequest llmRequest, + Context parentContext) { Flowable processorEvents = Flowable.fromIterable(Iterables.concat(eventIterables)); if (updatedResponse.content().isEmpty() && updatedResponse.errorCode().isEmpty() @@ -624,21 +668,23 @@ private Flowable buildPostprocessingEvents( return processorEvents.concatWith(Flowable.just(modelResponseEvent)); } - Maybe maybeFunctionResponseEvent = - context.runConfig().streamingMode() == StreamingMode.BIDI - ? Functions.handleFunctionCallsLive(context, modelResponseEvent, llmRequest.tools()) - : Functions.handleFunctionCalls(context, modelResponseEvent, llmRequest.tools()); - - Flowable functionEvents = - maybeFunctionResponseEvent.flatMapPublisher( - functionResponseEvent -> { - Optional toolConfirmationEvent = - Functions.generateRequestConfirmationEvent( - context, modelResponseEvent, functionResponseEvent); - return toolConfirmationEvent.isPresent() - ? Flowable.just(toolConfirmationEvent.get(), functionResponseEvent) - : Flowable.just(functionResponseEvent); - }); + Flowable functionEvents; + try (Scope scope = parentContext.makeCurrent()) { + Maybe maybeFunctionResponseEvent = + context.runConfig().streamingMode() == StreamingMode.BIDI + ? 
Functions.handleFunctionCallsLive(context, modelResponseEvent, llmRequest.tools()) + : Functions.handleFunctionCalls(context, modelResponseEvent, llmRequest.tools()); + functionEvents = + maybeFunctionResponseEvent.flatMapPublisher( + functionResponseEvent -> { + Optional toolConfirmationEvent = + Functions.generateRequestConfirmationEvent( + context, modelResponseEvent, functionResponseEvent); + return toolConfirmationEvent.isPresent() + ? Flowable.just(toolConfirmationEvent.get(), functionResponseEvent) + : Flowable.just(functionResponseEvent); + }); + } return processorEvents.concatWith(Flowable.just(modelResponseEvent)).concatWith(functionEvents); } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index c1a996064..84a8141ea 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -42,7 +42,6 @@ import com.google.genai.types.Part; import io.opentelemetry.api.trace.Span; import io.opentelemetry.context.Context; -import io.opentelemetry.context.Scope; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Observable; @@ -163,7 +162,9 @@ public static Maybe handleFunctionCalls( } return functionResponseEventsObservable .toList() - .flatMapMaybe( + .toMaybe() + .compose(Tracing.withContext(parentContext)) + .flatMap( events -> { if (events.isEmpty()) { return Maybe.empty(); @@ -226,7 +227,9 @@ public static Maybe handleFunctionCallsLive( return responseEventsObservable .toList() - .flatMapMaybe( + .toMaybe() + .compose(Tracing.withContext(parentContext)) + .flatMap( events -> { if (events.isEmpty()) { return Maybe.empty(); @@ -243,47 +246,45 @@ private static Function> getFunctionCallMapper( Context parentContext) { return functionCall -> Maybe.defer( - () -> { - try (Scope scope = parentContext.makeCurrent()) { - BaseTool 
tool = tools.get(functionCall.name().get()); - ToolContext toolContext = - ToolContext.builder(invocationContext) - .functionCallId(functionCall.id().orElse("")) - .toolConfirmation( - functionCall.id().map(toolConfirmations::get).orElse(null)) - .build(); - - Map functionArgs = - functionCall.args().map(HashMap::new).orElse(new HashMap<>()); - - Maybe> maybeFunctionResult = - maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext) - .switchIfEmpty( - Maybe.defer( - () -> { - try (Scope innerScope = parentContext.makeCurrent()) { - return isLive - ? processFunctionLive( - invocationContext, - tool, - toolContext, - functionCall, - functionArgs, - parentContext) - : callTool(tool, functionArgs, toolContext, parentContext); - } - })); - - return postProcessFunctionResult( - maybeFunctionResult, - invocationContext, - tool, - functionArgs, - toolContext, - isLive, - parentContext); - } - }); + () -> { + BaseTool tool = tools.get(functionCall.name().get()); + ToolContext toolContext = + ToolContext.builder(invocationContext) + .functionCallId(functionCall.id().orElse("")) + .toolConfirmation( + functionCall.id().map(toolConfirmations::get).orElse(null)) + .build(); + + Map functionArgs = + functionCall.args().map(HashMap::new).orElse(new HashMap<>()); + + Maybe> maybeFunctionResult = + maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext) + .switchIfEmpty( + Maybe.defer( + () -> + isLive + ? 
processFunctionLive( + invocationContext, + tool, + toolContext, + functionCall, + functionArgs, + parentContext) + : callTool( + tool, functionArgs, toolContext, parentContext)) + .compose(Tracing.withContext(parentContext))); + + return postProcessFunctionResult( + maybeFunctionResult, + invocationContext, + tool, + functionArgs, + toolContext, + isLive, + parentContext); + }) + .compose(Tracing.withContext(parentContext)); } /** @@ -410,34 +411,27 @@ private static Maybe postProcessFunctionResult( }) .flatMapMaybe( optionalInitialResult -> { - try (Scope scope = parentContext.makeCurrent()) { - Map initialFunctionResult = optionalInitialResult.orElse(null); - - return maybeInvokeAfterToolCall( - invocationContext, tool, functionArgs, toolContext, initialFunctionResult) - .map(Optional::of) - .defaultIfEmpty(Optional.ofNullable(initialFunctionResult)) - .flatMapMaybe( - finalOptionalResult -> { - Map finalFunctionResult = - finalOptionalResult.orElse(null); - if (tool.longRunning() && finalFunctionResult == null) { - return Maybe.empty(); - } - return Maybe.fromCallable( - () -> - buildResponseEvent( - tool, - finalFunctionResult, - toolContext, - invocationContext)) - .compose( - Tracing.trace("tool_response [" + tool.name() + "]") - .setParent(parentContext)) - .doOnSuccess(event -> Tracing.traceToolResponse(event.id(), event)); - }); - } - }); + Map initialFunctionResult = optionalInitialResult.orElse(null); + + return maybeInvokeAfterToolCall( + invocationContext, tool, functionArgs, toolContext, initialFunctionResult) + .map(Optional::of) + .defaultIfEmpty(Optional.ofNullable(initialFunctionResult)) + .flatMapMaybe( + finalOptionalResult -> { + Map finalFunctionResult = finalOptionalResult.orElse(null); + if (tool.longRunning() && finalFunctionResult == null) { + return Maybe.empty(); + } + Event event = + buildResponseEvent( + tool, finalFunctionResult, toolContext, invocationContext); + Tracing.traceToolResponse(event.id(), event); + return 
Maybe.just(event); + }); + }) + .compose( + Tracing.trace("tool_response [" + tool.name() + "]").setParent(parentContext)); } private static Optional mergeParallelFunctionResponseEvents( diff --git a/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java b/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java index e00c0093d..a93eb3cb4 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java @@ -29,6 +29,7 @@ import com.google.adk.events.Event; import com.google.adk.events.ToolConfirmation; import com.google.adk.models.LlmRequest; +import com.google.adk.telemetry.Tracing; import com.google.adk.tools.BaseTool; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -37,6 +38,7 @@ import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; +import io.opentelemetry.context.Context; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; import java.util.Collection; @@ -216,10 +218,13 @@ private Maybe assembleEvent( .build()) .build(); - return toolsMapSingle.flatMapMaybe( - toolsMap -> - Functions.handleFunctionCalls( - invocationContext, functionCallEvent, toolsMap, toolConfirmations)); + Context parentContext = Context.current(); + return toolsMapSingle + .flatMapMaybe( + toolsMap -> + Functions.handleFunctionCalls( + invocationContext, functionCallEvent, toolsMap, toolConfirmations)) + .compose(Tracing.withContext(parentContext)); } private static Optional> maybeCreateToolConfirmationEntry( diff --git a/core/src/main/java/com/google/adk/plugins/PluginManager.java b/core/src/main/java/com/google/adk/plugins/PluginManager.java index e534da787..8d0366e9a 100644 --- 
a/core/src/main/java/com/google/adk/plugins/PluginManager.java +++ b/core/src/main/java/com/google/adk/plugins/PluginManager.java @@ -21,11 +21,13 @@ import com.google.adk.events.Event; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; +import com.google.adk.telemetry.Tracing; import com.google.adk.tools.BaseTool; import com.google.adk.tools.ToolContext; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; +import io.opentelemetry.context.Context; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; @@ -126,6 +128,7 @@ public Maybe beforeRunCallback(InvocationContext invocationContext) { @Override public Completable afterRunCallback(InvocationContext invocationContext) { + Context capturedContext = Context.current(); return Flowable.fromIterable(plugins) .concatMapCompletable( plugin -> @@ -136,11 +139,13 @@ public Completable afterRunCallback(InvocationContext invocationContext) { logger.error( "[{}] Error during callback 'afterRunCallback'", plugin.getName(), - e))); + e))) + .compose(Tracing.withContext(capturedContext)); } @Override public Completable close() { + Context capturedContext = Context.current(); return Flowable.fromIterable(plugins) .concatMapCompletableDelayError( plugin -> @@ -149,7 +154,8 @@ public Completable close() { .doOnError( e -> logger.error( - "[{}] Error during callback 'close'", plugin.getName(), e))); + "[{}] Error during callback 'close'", plugin.getName(), e))) + .compose(Tracing.withContext(capturedContext)); } @Override @@ -227,7 +233,7 @@ public Maybe> onToolErrorCallback( */ private Maybe runMaybeCallbacks( Function> callbackExecutor, String callbackName) { - + Context capturedContext = Context.current(); return Flowable.fromIterable(this.plugins) .concatMapMaybe( plugin -> @@ -247,6 +253,7 @@ private Maybe runMaybeCallbacks( 
plugin.getName(), callbackName, e))) - .firstElement(); + .firstElement() + .compose(Tracing.withContext(capturedContext)); } } diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 5859c4786..51e1b8f25 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -52,6 +52,7 @@ import com.google.genai.types.Part; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.StatusCode; +import io.opentelemetry.context.Context; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; @@ -375,20 +376,25 @@ public Flowable runAsync( Content newMessage, RunConfig runConfig, @Nullable Map stateDelta) { - Maybe maybeSession = - this.sessionService.getSession(appName, userId, sessionId, Optional.empty()); - return maybeSession - .switchIfEmpty( - Single.defer( - () -> { - if (runConfig.autoCreateSession()) { - return this.sessionService.createSession(appName, userId, null, sessionId); - } - return Single.error( - new IllegalArgumentException( - String.format("Session not found: %s for user %s", sessionId, userId))); - })) - .flatMapPublisher(session -> this.runAsyncImpl(session, newMessage, runConfig, stateDelta)); + return Flowable.defer( + () -> + this.sessionService + .getSession(appName, userId, sessionId, Optional.empty()) + .switchIfEmpty( + Single.defer( + () -> { + if (runConfig.autoCreateSession()) { + return this.sessionService.createSession( + appName, userId, (Map) null, sessionId); + } + return Single.error( + new IllegalArgumentException( + String.format( + "Session not found: %s for user %s", sessionId, userId))); + })) + .flatMapPublisher( + session -> this.runAsyncImpl(session, newMessage, runConfig, stateDelta))) + .compose(Tracing.trace("invocation")); } /** See {@link #runAsync(String, String, Content, RunConfig, Map)}. 
*/ @@ -441,7 +447,8 @@ public Flowable runAsync( Content newMessage, RunConfig runConfig, @Nullable Map stateDelta) { - return runAsyncImpl(session, newMessage, runConfig, stateDelta); + return runAsyncImpl(session, newMessage, runConfig, stateDelta) + .compose(Tracing.trace("invocation")); } /** @@ -461,6 +468,7 @@ protected Flowable runAsyncImpl( Preconditions.checkNotNull(session, "session cannot be null"); Preconditions.checkNotNull(newMessage, "newMessage cannot be null"); Preconditions.checkNotNull(runConfig, "runConfig cannot be null"); + Context capturedContext = Context.current(); return Flowable.defer( () -> { BaseAgent rootAgent = this.agent; @@ -476,6 +484,7 @@ protected Flowable runAsyncImpl( return this.pluginManager .onUserMessageCallback(initialContext, newMessage) + .compose(Tracing.withContext(capturedContext)) .defaultIfEmpty(newMessage) .flatMap( content -> @@ -500,7 +509,8 @@ protected Flowable runAsyncImpl( event, invocationId, runConfig, - rootAgent)); + rootAgent)) + .compose(Tracing.withContext(capturedContext)); }); }) .doOnError( @@ -508,8 +518,7 @@ protected Flowable runAsyncImpl( Span span = Span.current(); span.setStatus(StatusCode.ERROR, "Error in runAsync Flowable execution"); span.recordException(throwable); - }) - .compose(Tracing.trace("invocation")); + }); } private Flowable runAgentWithFreshSession( @@ -562,12 +571,14 @@ private Flowable runAgentWithFreshSession( .toFlowable()); // If beforeRunCallback returns content, emit it and skip agent + Context capturedContext = Context.current(); return beforeRunEvent .toFlowable() .switchIfEmpty(agentEvents) .concatWith( Completable.defer(() -> pluginManager.afterRunCallback(contextWithUpdatedSession))) - .concatWith(Completable.defer(() -> compactEvents(updatedSession))); + .concatWith(Completable.defer(() -> compactEvents(updatedSession))) + .compose(Tracing.withContext(capturedContext)); } private Completable compactEvents(Session session) { @@ -632,46 +643,9 @@ private 
InvocationContext.Builder newInvocationContextBuilder(Session session) { .agent(this.findAgentToRun(session, rootAgent)); } - /** - * Runs the agent in live mode, appending generated events to the session. - * - * @return stream of events from the agent. - */ public Flowable runLive( Session session, LiveRequestQueue liveRequestQueue, RunConfig runConfig) { - return Flowable.defer( - () -> { - InvocationContext invocationContext = - newInvocationContextForLive(session, liveRequestQueue, runConfig); - - Single invocationContextSingle; - if (invocationContext.agent() instanceof LlmAgent agent) { - invocationContextSingle = - agent - .tools() - .map( - tools -> { - this.addActiveStreamingTools(invocationContext, tools); - return invocationContext; - }); - } else { - invocationContextSingle = Single.just(invocationContext); - } - return invocationContextSingle - .flatMapPublisher( - updatedInvocationContext -> - updatedInvocationContext - .agent() - .runLive(updatedInvocationContext) - .doOnNext(event -> this.sessionService.appendEvent(session, event))) - .doOnError( - throwable -> { - Span span = Span.current(); - span.setStatus(StatusCode.ERROR, "Error in runLive Flowable execution"); - span.recordException(throwable); - }); - }) - .compose(Tracing.trace("invocation")); + return runLiveImpl(session, liveRequestQueue, runConfig).compose(Tracing.trace("invocation")); } /** @@ -682,19 +656,25 @@ public Flowable runLive( */ public Flowable runLive( String userId, String sessionId, LiveRequestQueue liveRequestQueue, RunConfig runConfig) { - return this.sessionService - .getSession(appName, userId, sessionId, Optional.empty()) - .switchIfEmpty( - Single.defer( - () -> { - if (runConfig.autoCreateSession()) { - return this.sessionService.createSession(appName, userId, null, sessionId); - } - return Single.error( - new IllegalArgumentException( - String.format("Session not found: %s for user %s", sessionId, userId))); - })) - .flatMapPublisher(session -> 
this.runLive(session, liveRequestQueue, runConfig)); + return Flowable.defer( + () -> + this.sessionService + .getSession(appName, userId, sessionId, Optional.empty()) + .switchIfEmpty( + Single.defer( + () -> { + if (runConfig.autoCreateSession()) { + return this.sessionService.createSession( + appName, userId, (Map) null, sessionId); + } + return Single.error( + new IllegalArgumentException( + String.format( + "Session not found: %s for user %s", sessionId, userId))); + })) + .flatMapPublisher( + session -> this.runLiveImpl(session, liveRequestQueue, runConfig))) + .compose(Tracing.trace("invocation")); } /** @@ -708,6 +688,49 @@ public Flowable runLive( return runLive(sessionKey.userId(), sessionKey.id(), liveRequestQueue, runConfig); } + /** + * Runs the agent in live mode, appending generated events to the session. + * + * @return stream of events from the agent. + */ + protected Flowable runLiveImpl( + Session session, @Nullable LiveRequestQueue liveRequestQueue, RunConfig runConfig) { + return Flowable.defer( + () -> { + Context capturedContext = Context.current(); + InvocationContext invocationContext = + newInvocationContextForLive(session, liveRequestQueue, runConfig); + + Single invocationContextSingle; + if (invocationContext.agent() instanceof LlmAgent agent) { + invocationContextSingle = + agent + .tools() + .map( + tools -> { + this.addActiveStreamingTools(invocationContext, tools); + return invocationContext; + }); + } else { + invocationContextSingle = Single.just(invocationContext); + } + return invocationContextSingle + .flatMapPublisher( + updatedInvocationContext -> + updatedInvocationContext + .agent() + .runLive(updatedInvocationContext) + .doOnNext(event -> this.sessionService.appendEvent(session, event))) + .doOnError( + throwable -> { + Span span = Span.current(); + span.setStatus(StatusCode.ERROR, "Error in runLive Flowable execution"); + span.recordException(throwable); + }) + .compose(Tracing.withContext(capturedContext)); + }); + } + 
/** * Runs the agent asynchronously with a default user ID. * diff --git a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java index 3524c7755..c193e4a65 100644 --- a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java @@ -574,8 +574,13 @@ public void runAsync_withTools_createsToolSpans() throws InterruptedException { String agentSpanId = agentSpan.getSpanContext().getSpanId(); llmSpans.forEach(s -> assertEquals(agentSpanId, s.getParentSpanContext().getSpanId())); - toolCallSpans.forEach(s -> assertEquals(agentSpanId, s.getParentSpanContext().getSpanId())); - toolResponseSpans.forEach(s -> assertEquals(agentSpanId, s.getParentSpanContext().getSpanId())); + + // The tool calls and responses are children of the first LLM call that produced the function + // call. + String firstLlmSpanId = llmSpans.get(0).getSpanContext().getSpanId(); + toolCallSpans.forEach(s -> assertEquals(firstLlmSpanId, s.getParentSpanContext().getSpanId())); + toolResponseSpans.forEach( + s -> assertEquals(firstLlmSpanId, s.getParentSpanContext().getSpanId())); } @Test diff --git a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java index 4a0b345c6..6cae6c88a 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java @@ -43,9 +43,13 @@ import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.ContextKey; +import io.opentelemetry.context.Scope; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; +import 
io.reactivex.rxjava3.schedulers.Schedulers; import java.util.List; import java.util.Map; import java.util.Optional; @@ -572,6 +576,71 @@ public Single> runAsync(Map args, ToolContex } } + @Test + public void run_contextPropagation() { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.current().with(testKey, "test-value"); + + Content content = Content.fromParts(Part.fromText("LLM response")); + TestLlm testLlm = createTestLlm(createLlmResponse(content)); + + RequestProcessor requestProcessor = + (ctx, request) -> { + return Single.just(RequestProcessingResult.create(request, ImmutableList.of())) + .subscribeOn(Schedulers.computation()); + }; + + ResponseProcessor responseProcessor = + (ctx, response) -> { + return Single.just(ResponseProcessingResult.create(response, ImmutableList.of())) + .subscribeOn(Schedulers.computation()); + }; + + Callbacks.BeforeModelCallback beforeCallback = + (ctx, req) -> { + return Maybe.empty().subscribeOn(Schedulers.computation()); + }; + + Callbacks.AfterModelCallback afterCallback = + (ctx, resp) -> { + return Maybe.just(resp).subscribeOn(Schedulers.computation()); + }; + + Callbacks.OnModelErrorCallback onErrorCallback = + (ctx, req, err) -> { + return Maybe.just( + LlmResponse.builder().content(Content.fromParts(Part.fromText("error"))).build()) + .subscribeOn(Schedulers.computation()); + }; + + InvocationContext invocationContext = + createInvocationContext( + createTestAgentBuilder(testLlm) + .beforeModelCallback(beforeCallback) + .afterModelCallback(afterCallback) + .onModelErrorCallback(onErrorCallback) + .build()); + + BaseLlmFlow baseLlmFlow = + createBaseLlmFlow(ImmutableList.of(requestProcessor), ImmutableList.of(responseProcessor)); + + List events; + try (Scope scope = testContext.makeCurrent()) { + events = + baseLlmFlow + .run(invocationContext) + .doOnNext( + event -> { + assertThat(Context.current().get(testKey)).isEqualTo("test-value"); + }) + .toList() + .blockingGet(); + } + + 
assertThat(events).hasSize(1); + assertThat(events.get(0).content()).hasValue(content); + } + @Test public void postprocess_noResponseProcessors_onlyUsageMetadata_returnsEvent() { GenerateContentResponseUsageMetadata usageMetadata = @@ -588,7 +657,12 @@ public void postprocess_noResponseProcessors_onlyUsageMetadata_returnsEvent() { List events = baseLlmFlow - .postprocess(invocationContext, baseEvent, LlmRequest.builder().build(), llmResponse) + .postprocess( + invocationContext, + baseEvent, + LlmRequest.builder().build(), + llmResponse, + Context.current()) .toList() .blockingGet(); diff --git a/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java b/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java index 4ae856fc7..3771143cf 100644 --- a/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java +++ b/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java @@ -37,8 +37,12 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.ContextKey; +import io.opentelemetry.context.Scope; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.schedulers.Schedulers; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -144,6 +148,87 @@ public void onUserMessageCallback_pluginOrderRespected() { inOrder.verify(plugin2).onUserMessageCallback(mockInvocationContext, content); } + @Test + public void contextPropagation_runMaybeCallbacks() throws Exception { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.current().with(testKey, "test-value"); + + Content expectedContent = Content.builder().build(); + when(plugin1.onUserMessageCallback(any(), any())) + .thenReturn(Maybe.just(expectedContent).subscribeOn(Schedulers.computation())); + pluginManager.registerPlugin(plugin1); 
+ + Maybe resultMaybe; + try (Scope scope = testContext.makeCurrent()) { + resultMaybe = pluginManager.onUserMessageCallback(mockInvocationContext, content); + } + + // Assert downstream operators have the propagated context + resultMaybe + .doOnSuccess( + result -> { + assertThat(Context.current().get(testKey)).isEqualTo("test-value"); + }) + .test() + .await() + .assertResult(expectedContent); + + verify(plugin1).onUserMessageCallback(mockInvocationContext, content); + } + + @Test + public void contextPropagation_afterRunCallback() throws Exception { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.current().with(testKey, "test-value"); + + when(plugin1.afterRunCallback(any())) + .thenReturn(Completable.complete().subscribeOn(Schedulers.computation())); + pluginManager.registerPlugin(plugin1); + + Completable resultCompletable; + try (Scope scope = testContext.makeCurrent()) { + resultCompletable = pluginManager.afterRunCallback(mockInvocationContext); + } + + // Assert downstream operators have the propagated context + resultCompletable + .doOnComplete( + () -> { + assertThat(Context.current().get(testKey)).isEqualTo("test-value"); + }) + .test() + .await() + .assertResult(); + + verify(plugin1).afterRunCallback(mockInvocationContext); + } + + @Test + public void contextPropagation_close() throws Exception { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.current().with(testKey, "test-value"); + + when(plugin1.close()).thenReturn(Completable.complete().subscribeOn(Schedulers.computation())); + pluginManager.registerPlugin(plugin1); + + Completable resultCompletable; + try (Scope scope = testContext.makeCurrent()) { + resultCompletable = pluginManager.close(); + } + + // Assert downstream operators have the propagated context + resultCompletable + .doOnComplete( + () -> { + assertThat(Context.current().get(testKey)).isEqualTo("test-value"); + }) + .test() + .await() + .assertResult(); + + 
verify(plugin1).close(); + } + @Test public void afterRunCallback_allComplete() { when(plugin1.afterRunCallback(any())).thenReturn(Completable.complete()); diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index 8a0a84b08..2eb515fa2 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -57,6 +57,9 @@ import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.ContextKey; +import io.opentelemetry.context.Scope; import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; import io.opentelemetry.sdk.trace.data.SpanData; import io.reactivex.rxjava3.core.Completable; @@ -977,6 +980,84 @@ public void runLive_createsInvocationSpan() { assertThat(invocationSpan.get().hasEnded()).isTrue(); } + @Test + public void runAsync_createsToolSpansWithCorrectParent() { + LlmAgent agentWithTool = + createTestAgentBuilder(testLlmWithFunctionCall).tools(ImmutableList.of(echoTool)).build(); + Runner runnerWithTool = + Runner.builder().app(App.builder().name("test").rootAgent(agentWithTool).build()).build(); + Session sessionWithTool = + runnerWithTool.sessionService().createSession("test", "user").blockingGet(); + + var unused = + runnerWithTool + .runAsync( + sessionWithTool.sessionKey(), + createContent("from user"), + RunConfig.builder().build()) + .toList() + .blockingGet(); + + List spans = openTelemetryRule.getSpans(); + List llmSpans = spans.stream().filter(s -> s.getName().equals("call_llm")).toList(); + List toolCallSpans = + spans.stream().filter(s -> s.getName().equals("tool_call [echo_tool]")).toList(); + List toolResponseSpans = + spans.stream().filter(s -> s.getName().equals("tool_response [echo_tool]")).toList(); + + assertThat(llmSpans).hasSize(2); + 
assertThat(toolCallSpans).hasSize(1); + assertThat(toolResponseSpans).hasSize(1); + + List llmSpanIds = llmSpans.stream().map(s -> s.getSpanContext().getSpanId()).toList(); + String toolCallParentId = toolCallSpans.get(0).getParentSpanContext().getSpanId(); + String toolResponseParentId = toolResponseSpans.get(0).getParentSpanContext().getSpanId(); + + assertThat(toolCallParentId).isEqualTo(toolResponseParentId); + assertThat(llmSpanIds).contains(toolCallParentId); + } + + @Test + public void runLive_createsToolSpansWithCorrectParent() throws Exception { + LlmAgent agentWithTool = + createTestAgentBuilder(testLlmWithFunctionCall).tools(ImmutableList.of(echoTool)).build(); + Runner runnerWithTool = + Runner.builder().app(App.builder().name("test").rootAgent(agentWithTool).build()).build(); + Session sessionWithTool = + runnerWithTool.sessionService().createSession("test", "user").blockingGet(); + LiveRequestQueue liveRequestQueue = new LiveRequestQueue(); + + TestSubscriber testSubscriber = + runnerWithTool + .runLive(sessionWithTool.sessionKey(), liveRequestQueue, RunConfig.builder().build()) + .test(); + + liveRequestQueue.content(createContent("from user")); + liveRequestQueue.close(); + + testSubscriber.await(); + testSubscriber.assertComplete(); + + List spans = openTelemetryRule.getSpans(); + List llmSpans = spans.stream().filter(s -> s.getName().equals("call_llm")).toList(); + List toolCallSpans = + spans.stream().filter(s -> s.getName().equals("tool_call [echo_tool]")).toList(); + List toolResponseSpans = + spans.stream().filter(s -> s.getName().equals("tool_response [echo_tool]")).toList(); + + // In runLive, there is one call_llm span for the execution + assertThat(llmSpans).hasSize(1); + assertThat(toolCallSpans).hasSize(1); + assertThat(toolResponseSpans).hasSize(1); + + List llmSpanIds = llmSpans.stream().map(s -> s.getSpanContext().getSpanId()).toList(); + String toolCallParentId = toolCallSpans.get(0).getParentSpanContext().getSpanId(); + String 
toolResponseParentId = toolResponseSpans.get(0).getParentSpanContext().getSpanId(); + + assertThat(toolCallParentId).isEqualTo(toolResponseParentId); + assertThat(llmSpanIds).contains(toolCallParentId); + } + @Test public void runAsync_withoutSessionAndAutoCreateSessionTrue_createsSession() { RunConfig runConfig = RunConfig.builder().setAutoCreateSession(true).build(); @@ -1188,6 +1269,53 @@ public void close_closesPluginsAndCodeExecutors() { verify(plugin).close(); } + @Test + public void runAsync_contextPropagation() { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.current().with(testKey, "test-value"); + + List events; + try (Scope scope = testContext.makeCurrent()) { + events = + runner + .runAsync("user", session.id(), createContent("test message")) + .doOnNext( + event -> { + assertThat(Context.current().get(testKey)).isEqualTo("test-value"); + }) + .toList() + .blockingGet(); + } + + assertThat(simplifyEvents(events)).containsExactly("test agent: from llm"); + } + + @Test + public void runLive_contextPropagation() throws Exception { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.current().with(testKey, "test-value"); + LiveRequestQueue liveRequestQueue = new LiveRequestQueue(); + + TestSubscriber testSubscriber; + try (Scope scope = testContext.makeCurrent()) { + testSubscriber = + runner + .runLive(session, liveRequestQueue, RunConfig.builder().build()) + .doOnNext( + event -> { + assertThat(Context.current().get(testKey)).isEqualTo("test-value"); + }) + .test(); + } + + liveRequestQueue.content(createContent("from user")); + liveRequestQueue.close(); + + testSubscriber.await(); + testSubscriber.assertComplete(); + assertThat(simplifyEvents(testSubscriber.values())).containsExactly("test agent: from llm"); + } + @Test public void buildRunnerWithPlugins_success() { BasePlugin plugin1 = mockPlugin("test1"); From 2fcff3c30f5d0af4b4007e821af1f204e555fef9 Mon Sep 17 00:00:00 2001 From: 
Taylor Lanclos Date: Wed, 4 Mar 2026 19:12:16 +0000 Subject: [PATCH 05/40] Extract timestamp as double for InMemorySessionService events InMemorySessionService sets a Session's last modified time based on the timestamp of the last appended event. The timestamp in an event is recorded in millis while the Session's timestamp is an Instant. During the transformation, Events perform this conversion using division. Before this change, the timestamp was truncated to the second, yet the code was trying to extract nanos which were always 0. This fixes that bug with a simple type change. I've also added a test to prevent regressions. --- .../adk/sessions/InMemorySessionService.java | 13 ++----------- .../sessions/InMemorySessionServiceTest.java | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java b/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java index b2a584b11..d9bb047a3 100644 --- a/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java +++ b/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java @@ -154,19 +154,13 @@ public Maybe getSession( if (config.numRecentEvents().isEmpty() && config.afterTimestamp().isPresent()) { Instant threshold = config.afterTimestamp().get(); - eventsInCopy.removeIf( - event -> getEventTimestampEpochSeconds(event) < threshold.getEpochSecond()); + eventsInCopy.removeIf(event -> getInstantFromEvent(event).isBefore(threshold)); } // Merge state into the potentially filtered copy and return return Maybe.just(mergeWithGlobalState(appName, userId, sessionCopy)); } - // Helper to get event timestamp as epoch seconds - private long getEventTimestampEpochSeconds(Event event) { - return event.timestamp() / 1000L; - } - @Override public Single listSessions(String appName, String userId) { Objects.requireNonNull(appName, "appName cannot be null"); @@ -294,10 +288,7 @@ public Single 
appendEvent(Session session, Event event) { /** Converts an event's timestamp to an Instant. Adapt based on actual Event structure. */ // TODO: have Event.timestamp() return Instant directly private Instant getInstantFromEvent(Event event) { - double epochSeconds = getEventTimestampEpochSeconds(event); - long seconds = (long) epochSeconds; - long nanos = (long) ((epochSeconds % 1.0) * 1_000_000_000L); - return Instant.ofEpochSecond(seconds, nanos); + return Instant.ofEpochMilli(event.timestamp()); } /** diff --git a/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java b/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java index 41e156ffd..0d9235b1b 100644 --- a/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java +++ b/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java @@ -20,6 +20,7 @@ import com.google.adk.events.Event; import com.google.adk.events.EventActions; import io.reactivex.rxjava3.core.Single; +import java.time.Instant; import java.util.HashMap; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; @@ -214,6 +215,24 @@ public void appendEvent_removesState() { assertThat(retrievedSessionRemove.state()).doesNotContainKey("temp:tempKey"); } + @Test + public void appendEvent_updatesSessionTimestampWithFractionalSeconds() { + InMemorySessionService sessionService = new InMemorySessionService(); + Session session = + sessionService.createSession("app", "user", new HashMap<>(), "session1").blockingGet(); + + // Add an event with a timestamp that contains a fractional second + Event eventAdd = Event.builder().timestamp(5500).build(); + var unused = sessionService.appendEvent(session, eventAdd).blockingGet(); + + // Verify the last modified timestamp contains a fractional second + Session retrievedSession = + sessionService + .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) + .blockingGet(); + 
assertThat(retrievedSession.lastUpdateTime()).isEqualTo(Instant.ofEpochSecond(5, 500000000L)); + } + @Test public void sequentialAgents_shareTempState() { InMemorySessionService sessionService = new InMemorySessionService(); From 4eb3613b65cb1334e9432960d0f864ef09829c23 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 17 Mar 2026 10:27:33 -0700 Subject: [PATCH 06/40] fix: improve processRequest_concurrentReadAndWrite_noException test case PiperOrigin-RevId: 885091550 --- .../adk/flows/llmflows/ContentsTest.java | 89 ++++++++++--------- 1 file changed, 48 insertions(+), 41 deletions(-) diff --git a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java index 7164991f3..1e6267dde 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java @@ -36,13 +36,15 @@ import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; -import java.util.HashMap; +import java.util.ArrayList; +import java.util.ConcurrentModificationException; +import java.util.Iterator; import java.util.List; +import java.util.ListIterator; import java.util.Map; import java.util.Objects; import java.util.Optional; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Stream; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -787,15 +789,49 @@ public void processRequest_notEmptyContent() { public void processRequest_concurrentReadAndWrite_noException() throws Exception { LlmAgent agent = LlmAgent.builder().name(AGENT).includeContents(LlmAgent.IncludeContents.DEFAULT).build(); + List customEvents = + new ArrayList() { + private void checkLock() { + if (!Thread.holdsLock(this)) { + throw new ConcurrentModificationException("Unsynchronized iteration 
detected!"); + } + } + + @Override + public Iterator iterator() { + checkLock(); + return super.iterator(); + } + + @Override + public ListIterator listIterator() { + checkLock(); + return super.listIterator(); + } + + @Override + public ListIterator listIterator(int index) { + checkLock(); + return super.listIterator(index); + } + + @Override + public Stream stream() { + checkLock(); + return super.stream(); + } + }; + Session session = - sessionService - .createSession("test-app", "test-user", new HashMap<>(), "test-session") - .blockingGet(); + Session.builder("test-session") + .appName("test-app") + .userId("test-user") + .events(customEvents) + .build(); - // Seed with dummy events to widen the race capability - for (int i = 0; i < 5000; i++) { - session.events().add(createUserEvent("dummy" + i, "dummy")); - } + // The list must have at least one element so that operations interacting with events trigger + // iteration. + customEvents.add(createUserEvent("dummy", "dummy")); InvocationContext context = InvocationContext.builder() @@ -807,37 +843,8 @@ public void processRequest_concurrentReadAndWrite_noException() throws Exception LlmRequest initialRequest = LlmRequest.builder().build(); - AtomicReference writerError = new AtomicReference<>(); - CountDownLatch startLatch = new CountDownLatch(1); - - Thread writerThread = - new Thread( - () -> { - startLatch.countDown(); - try { - for (int i = 0; i < 2000; i++) { - session.events().add(createUserEvent("writer" + i, "new data")); - } - } catch (Throwable t) { - writerError.set(t); - } - }); - - writerThread.start(); - startLatch.await(); // wait for writer to be ready - - // Process (read) requests concurrently to trigger race conditions - for (int i = 0; i < 200; i++) { - var unused = contentsProcessor.processRequest(context, initialRequest).blockingGet(); - if (writerError.get() != null) { - throw new RuntimeException("Writer failed", writerError.get()); - } - } - - writerThread.join(); - if (writerError.get() 
!= null) { - throw new RuntimeException("Writer failed", writerError.get()); - } + // This single call will throw the exception if the list is accessed insecurely. + var unused = contentsProcessor.processRequest(context, initialRequest).blockingGet(); } private static Event createUserEvent(String id, String text) { From c8ab0f96b09a6c9636728d634c62695fcd622246 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 17 Mar 2026 11:50:04 -0700 Subject: [PATCH 07/40] feat: Implement basic version of BigQuery Agent Analytics Plugin This change introduces a new plugin for the Agent Development Kit (ADK) that logs agent execution events to BigQuery. It includes: - `BigQueryAgentAnalyticsPlugin`: A plugin that captures various agent lifecycle events (user messages, tool calls, model invocations) and sends them to BigQuery. - `BigQueryLoggerConfig`: Configuration options for the plugin, including project/dataset/table IDs, batching, and retry settings. - `BigQuerySchema`: Defines the BigQuery and Arrow schemas used for the event table. - `BatchProcessor`: Handles batching of events and writing them to BigQuery using the Storage Write API with Arrow format. - `JsonFormatter`: Utility for safely formatting JSON content for BigQuery. 
PiperOrigin-RevId: 885133967 --- core/pom.xml | 20 + .../agentanalytics/BatchProcessor.java | 270 +++++++++++ .../BigQueryAgentAnalyticsPlugin.java | 436 +++++++++++++++++ .../agentanalytics/BigQueryLoggerConfig.java | 204 ++++++++ .../agentanalytics/BigQuerySchema.java | 304 ++++++++++++ .../plugins/agentanalytics/JsonFormatter.java | 111 +++++ .../agentanalytics/BatchProcessorTest.java | 367 ++++++++++++++ .../BigQueryAgentAnalyticsPluginTest.java | 457 ++++++++++++++++++ pom.xml | 11 + 9 files changed, 2180 insertions(+) create mode 100644 core/src/main/java/com/google/adk/plugins/agentanalytics/BatchProcessor.java create mode 100644 core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java create mode 100644 core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java create mode 100644 core/src/main/java/com/google/adk/plugins/agentanalytics/BigQuerySchema.java create mode 100644 core/src/main/java/com/google/adk/plugins/agentanalytics/JsonFormatter.java create mode 100644 core/src/test/java/com/google/adk/plugins/agentanalytics/BatchProcessorTest.java create mode 100644 core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java diff --git a/core/pom.xml b/core/pom.xml index 8c3c2069c..b3f2f5fd8 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -197,6 +197,26 @@ opentelemetry-sdk-testing test + + com.google.cloud + google-cloud-bigquery + 2.40.0 + + + org.apache.arrow + arrow-vector + 17.0.0 + + + org.apache.arrow + arrow-memory-core + 17.0.0 + + + org.apache.arrow + arrow-memory-netty + 17.0.0 + diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BatchProcessor.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BatchProcessor.java new file mode 100644 index 000000000..ef826fb56 --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BatchProcessor.java @@ -0,0 +1,270 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed 
under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.NANOSECONDS; +import static java.util.concurrent.TimeUnit.SECONDS; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.cloud.bigquery.storage.v1.AppendRowsResponse; +import com.google.cloud.bigquery.storage.v1.Exceptions.AppendSerializationError; +import com.google.cloud.bigquery.storage.v1.StreamWriter; +import com.google.common.annotations.VisibleForTesting; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.logging.Level; +import java.util.logging.Logger; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.TimeStampVector; +import org.apache.arrow.vector.VarCharVector; +import 
org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; + +/** Handles asynchronous batching and writing of events to BigQuery. */ +class BatchProcessor implements AutoCloseable { + private static final Logger logger = Logger.getLogger(BatchProcessor.class.getName()); + + private final StreamWriter writer; + private final int batchSize; + private final Duration flushInterval; + @VisibleForTesting final BlockingQueue> queue; + private final ScheduledExecutorService executor; + @VisibleForTesting final BufferAllocator allocator; + final AtomicBoolean flushLock = new AtomicBoolean(false); + private final Schema arrowSchema; + private final VectorSchemaRoot root; + + public BatchProcessor( + StreamWriter writer, + int batchSize, + Duration flushInterval, + int queueMaxSize, + ScheduledExecutorService executor) { + this.writer = writer; + this.batchSize = batchSize; + this.flushInterval = flushInterval; + this.queue = new LinkedBlockingQueue<>(queueMaxSize); + this.executor = executor; + // It's safe to use Long.MAX_VALUE here as this is a top-level RootAllocator, + // and memory is properly managed via try-with-resources in the flush() method. + // The actual memory usage is bounded by the batchSize and individual row sizes. 
+ this.allocator = new RootAllocator(Long.MAX_VALUE); + this.arrowSchema = BigQuerySchema.getArrowSchema(); + this.root = VectorSchemaRoot.create(arrowSchema, allocator); + } + + public void start() { + @SuppressWarnings("unused") + var unused = + executor.scheduleWithFixedDelay( + () -> { + try { + flush(); + } catch (RuntimeException e) { + logger.log(Level.SEVERE, "Error in background flush", e); + } + }, + flushInterval.toMillis(), + flushInterval.toMillis(), + MILLISECONDS); + } + + public void append(Map row) { + if (!queue.offer(row)) { + logger.warning("BigQuery event queue is full, dropping event."); + return; + } + if (queue.size() >= batchSize && !flushLock.get()) { + executor.execute(this::flush); + } + } + + public void flush() { + // Acquire the flushLock. If another flush is already in progress, return immediately. + if (!flushLock.compareAndSet(false, true)) { + return; + } + try { + if (queue.isEmpty()) { + return; + } + List> batch = new ArrayList<>(); + queue.drainTo(batch, batchSize); + if (batch.isEmpty()) { + return; + } + try { + root.allocateNew(); + for (int i = 0; i < batch.size(); i++) { + Map row = batch.get(i); + for (Field field : arrowSchema.getFields()) { + populateVector(root.getVector(field.getName()), i, row.get(field.getName())); + } + } + root.setRowCount(batch.size()); + try (ArrowRecordBatch recordBatch = new VectorUnloader(root).getRecordBatch()) { + AppendRowsResponse result = writer.append(recordBatch).get(); + if (result.hasError()) { + logger.severe("BigQuery append error: " + result.getError().getMessage()); + for (var error : result.getRowErrorsList()) { + logger.severe( + String.format("Row error at index %d: %s", error.getIndex(), error.getMessage())); + } + } else { + logger.fine("Successfully wrote " + batch.size() + " rows to BigQuery."); + } + } catch (AppendSerializationError ase) { + logger.log( + Level.SEVERE, "Failed to write batch to BigQuery due to serialization error", ase); + Map rowIndexToErrorMessage = 
ase.getRowIndexToErrorMessage(); + if (rowIndexToErrorMessage != null && !rowIndexToErrorMessage.isEmpty()) { + logger.severe("Row-level errors found:"); + for (Map.Entry entry : rowIndexToErrorMessage.entrySet()) { + logger.severe( + String.format("Row error at index %d: %s", entry.getKey(), entry.getValue())); + } + } else { + logger.severe( + "AppendSerializationError occurred, but no row-specific errors were provided."); + } + } + } catch (Exception e) { + if (e instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + logger.log(Level.SEVERE, "Failed to write batch to BigQuery", e); + } finally { + // Clear the vectors to release the memory. + root.clear(); + } + } finally { + flushLock.set(false); + if (queue.size() >= batchSize && !flushLock.get()) { + executor.execute(this::flush); + } + } + } + + private void populateVector(FieldVector vector, int index, Object value) { + if (value == null || (value instanceof JsonNode jsonNode && jsonNode.isNull())) { + vector.setNull(index); + return; + } + if (vector instanceof VarCharVector varCharVector) { + String strValue = (value instanceof JsonNode jsonNode) ? jsonNode.asText() : value.toString(); + varCharVector.setSafe(index, strValue.getBytes(UTF_8)); + } else if (vector instanceof BigIntVector bigIntVector) { + long longValue; + if (value instanceof JsonNode jsonNode) { + longValue = jsonNode.asLong(); + } else if (value instanceof Number number) { + longValue = number.longValue(); + } else { + longValue = Long.parseLong(value.toString()); + } + bigIntVector.setSafe(index, longValue); + } else if (vector instanceof BitVector bitVector) { + boolean boolValue = + (value instanceof JsonNode jsonNode) ? jsonNode.asBoolean() : (Boolean) value; + bitVector.setSafe(index, boolValue ? 
1 : 0); + } else if (vector instanceof TimeStampVector timeStampVector) { + if (value instanceof Instant instant) { + long micros = + SECONDS.toMicros(instant.getEpochSecond()) + NANOSECONDS.toMicros(instant.getNano()); + timeStampVector.setSafe(index, micros); + } else if (value instanceof JsonNode jsonNode) { + timeStampVector.setSafe(index, jsonNode.asLong()); + } else if (value instanceof Long longValue) { + timeStampVector.setSafe(index, longValue); + } + } else if (vector instanceof ListVector listVector) { + int start = listVector.startNewValue(index); + if (value instanceof ArrayNode arrayNode) { + for (int i = 0; i < arrayNode.size(); i++) { + populateVector(listVector.getDataVector(), start + i, arrayNode.get(i)); + } + listVector.endValue(index, arrayNode.size()); + } else if (value instanceof List) { + List list = (List) value; + for (int i = 0; i < list.size(); i++) { + populateVector(listVector.getDataVector(), start + i, list.get(i)); + } + listVector.endValue(index, list.size()); + } + } else if (vector instanceof StructVector structVector) { + structVector.setIndexDefined(index); + if (value instanceof ObjectNode objectNode) { + for (FieldVector child : structVector.getChildrenFromFields()) { + populateVector(child, index, objectNode.get(child.getName())); + } + } else if (value instanceof Map) { + Map map = (Map) value; + for (FieldVector child : structVector.getChildrenFromFields()) { + populateVector(child, index, map.get(child.getName())); + } + } + } + } + + @Override + public void close() { + if (this.queue != null && !this.queue.isEmpty()) { + this.flush(); + } + if (this.allocator != null) { + try { + this.allocator.close(); + } catch (RuntimeException e) { + logger.log(Level.SEVERE, "Failed to close Buffer allocator", e); + } + } + if (this.root != null) { + try { + this.root.close(); + } catch (RuntimeException e) { + logger.log(Level.SEVERE, "Failed to close VectorSchemaRoot", e); + } + } + if (this.writer != null) { + try { + 
this.writer.close(); + } catch (RuntimeException e) { + logger.log(Level.SEVERE, "Failed to close BigQuery writer", e); + } + } + } +} diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java new file mode 100644 index 000000000..68b5fb5a1 --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java @@ -0,0 +1,436 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.adk.plugins.agentanalytics; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; + +import com.google.adk.agents.BaseAgent; +import com.google.adk.agents.CallbackContext; +import com.google.adk.agents.InvocationContext; +import com.google.adk.events.Event; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.adk.plugins.BasePlugin; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.ToolContext; +import com.google.api.gax.core.FixedCredentialsProvider; +import com.google.api.gax.retrying.RetrySettings; +import com.google.auth.oauth2.GoogleCredentials; +import com.google.cloud.bigquery.BigQuery; +import com.google.cloud.bigquery.BigQueryException; +import com.google.cloud.bigquery.BigQueryOptions; +import com.google.cloud.bigquery.Clustering; +import com.google.cloud.bigquery.Schema; +import com.google.cloud.bigquery.StandardTableDefinition; +import com.google.cloud.bigquery.Table; +import com.google.cloud.bigquery.TableId; +import com.google.cloud.bigquery.TableInfo; +import com.google.cloud.bigquery.storage.v1.BigQueryWriteClient; +import com.google.cloud.bigquery.storage.v1.BigQueryWriteSettings; +import com.google.cloud.bigquery.storage.v1.StreamWriter; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.VerifyException; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.Content; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.SpanContext; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Maybe; +import java.io.IOException; +import java.time.Instant; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.atomic.AtomicLong; +import 
java.util.logging.Level; +import java.util.logging.Logger; +import org.threeten.bp.Duration; + +/** + * BigQuery Agent Analytics Plugin for Java. + * + *

Logs agent execution events directly to a BigQuery table using the Storage Write API. + */ +public class BigQueryAgentAnalyticsPlugin extends BasePlugin { + private static final Logger logger = + Logger.getLogger(BigQueryAgentAnalyticsPlugin.class.getName()); + private static final ImmutableList DEFAULT_AUTH_SCOPES = + ImmutableList.of("https://www.googleapis.com/auth/cloud-platform"); + private static final AtomicLong threadCounter = new AtomicLong(0); + + private final BigQueryLoggerConfig config; + private final BigQuery bigQuery; + private final BigQueryWriteClient writeClient; + private final ScheduledExecutorService executor; + private final Object tableEnsuredLock = new Object(); + @VisibleForTesting final BatchProcessor batchProcessor; + private volatile boolean tableEnsured = false; + + public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config) throws IOException { + this(config, createBigQuery(config)); + } + + public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config, BigQuery bigQuery) + throws IOException { + super("bigquery_agent_analytics"); + this.config = config; + this.bigQuery = bigQuery; + ThreadFactory threadFactory = + r -> new Thread(r, "bq-analytics-plugin-" + threadCounter.getAndIncrement()); + this.executor = Executors.newScheduledThreadPool(1, threadFactory); + this.writeClient = createWriteClient(config); + + if (config.enabled()) { + StreamWriter writer = createWriter(config); + this.batchProcessor = + new BatchProcessor( + writer, + config.batchSize(), + config.batchFlushInterval(), + config.queueMaxSize(), + executor); + this.batchProcessor.start(); + } else { + this.batchProcessor = null; + } + } + + private static BigQuery createBigQuery(BigQueryLoggerConfig config) throws IOException { + BigQueryOptions.Builder builder = BigQueryOptions.newBuilder(); + if (config.credentials() != null) { + builder.setCredentials(config.credentials()); + } else { + builder.setCredentials( + 
GoogleCredentials.getApplicationDefault().createScoped(DEFAULT_AUTH_SCOPES)); + } + return builder.build().getService(); + } + + private void ensureTableExistsOnce() { + if (!tableEnsured) { + synchronized (tableEnsuredLock) { + if (!tableEnsured) { + // Table creation is expensive, so we only do it once per plugin instance. + tableEnsured = true; + ensureTableExists(bigQuery, config); + } + } + } + } + + private void ensureTableExists(BigQuery bigQuery, BigQueryLoggerConfig config) { + TableId tableId = TableId.of(config.projectId(), config.datasetId(), config.tableName()); + Schema schema = BigQuerySchema.getEventsSchema(); + try { + Table table = bigQuery.getTable(tableId); + logger.info("BigQuery table: " + tableId); + if (table == null) { + logger.info("Creating BigQuery table: " + tableId); + StandardTableDefinition.Builder tableDefinitionBuilder = + StandardTableDefinition.newBuilder().setSchema(schema); + if (!config.clusteringFields().isEmpty()) { + tableDefinitionBuilder.setClustering( + Clustering.newBuilder().setFields(config.clusteringFields()).build()); + } + TableInfo tableInfo = TableInfo.newBuilder(tableId, tableDefinitionBuilder.build()).build(); + bigQuery.create(tableInfo); + } else if (config.autoSchemaUpgrade()) { + // TODO(b/491851868): Implement auto-schema upgrade. + logger.info("BigQuery table already exists and auto-schema upgrade is enabled: " + tableId); + logger.info("Auto-schema upgrade is not implemented yet."); + } + } catch (BigQueryException e) { + if (e.getMessage().contains("invalid_grant")) { + logger.log( + Level.SEVERE, + "Failed to authenticate with BigQuery. 
Please run 'gcloud auth application-default" + + " login' to refresh your credentials or provide valid credentials in" + + " BigQueryLoggerConfig.", + e); + } else { + logger.log( + Level.WARNING, "Failed to check or create/upgrade BigQuery table: " + tableId, e); + } + } catch (RuntimeException e) { + logger.log(Level.WARNING, "Failed to check or create/upgrade BigQuery table: " + tableId, e); + } + } + + protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) throws IOException { + if (config.credentials() != null) { + return BigQueryWriteClient.create( + BigQueryWriteSettings.newBuilder() + .setCredentialsProvider(FixedCredentialsProvider.create(config.credentials())) + .build()); + } + return BigQueryWriteClient.create(); + } + + protected String getStreamName(BigQueryLoggerConfig config) { + return String.format( + "projects/%s/datasets/%s/tables/%s/streams/_default", + config.projectId(), config.datasetId(), config.tableName()); + } + + protected StreamWriter createWriter(BigQueryLoggerConfig config) { + BigQueryLoggerConfig.RetryConfig retryConfig = config.retryConfig(); + RetrySettings retrySettings = + RetrySettings.newBuilder() + .setMaxAttempts(retryConfig.maxRetries()) + .setInitialRetryDelay(Duration.ofMillis(retryConfig.initialDelay().toMillis())) + .setRetryDelayMultiplier(retryConfig.multiplier()) + .setMaxRetryDelay(Duration.ofMillis(retryConfig.maxDelay().toMillis())) + .build(); + + String streamName = getStreamName(config); + try { + return StreamWriter.newBuilder(streamName, writeClient) + .setRetrySettings(retrySettings) + .setWriterSchema(BigQuerySchema.getArrowSchema()) + .build(); + } catch (Exception e) { + throw new VerifyException("Failed to create StreamWriter for " + streamName, e); + } + } + + private void logEvent( + String eventType, + InvocationContext invocationContext, + Optional callbackContext, + Object content, + Map extraAttributes) { + if (batchProcessor == null) { + return; + } + + 
ensureTableExistsOnce(); + + Map row = new HashMap<>(); + row.put("timestamp", Instant.now()); + row.put("event_type", eventType); + row.put( + "agent", + callbackContext.map(CallbackContext::agentName).orElse(invocationContext.agent().name())); + row.put("session_id", invocationContext.session().id()); + row.put("invocation_id", invocationContext.invocationId()); + row.put("user_id", invocationContext.userId()); + + if (content instanceof Content contentParts) { + row.put( + "content_parts", + JsonFormatter.formatContentParts(Optional.of(contentParts), config.maxContentLength())); + row.put( + "content", JsonFormatter.smartTruncate(content, config.maxContentLength()).toString()); + } else if (content != null) { + row.put( + "content", JsonFormatter.smartTruncate(content, config.maxContentLength()).toString()); + } + + Map attributes = new HashMap<>(config.customTags()); + if (extraAttributes != null) { + attributes.putAll(extraAttributes); + } + row.put( + "attributes", + JsonFormatter.smartTruncate(attributes, config.maxContentLength()).toString()); + + addTraceDetails(row); + batchProcessor.append(row); + } + + // TODO(b/491849911): Implement own trace management functionality. 
+ private void addTraceDetails(Map row) { + SpanContext spanContext = Span.current().getSpanContext(); + if (spanContext.isValid()) { + row.put("trace_id", spanContext.getTraceId()); + row.put("span_id", spanContext.getSpanId()); + } + } + + @Override + public Completable close() { + if (batchProcessor != null) { + batchProcessor.close(); + } + if (writeClient != null) { + writeClient.close(); + } + try { + executor.shutdown(); + if (!executor.awaitTermination(config.shutdownTimeout().toMillis(), MILLISECONDS)) { + executor.shutdownNow(); + } + } catch (InterruptedException e) { + executor.shutdownNow(); + Thread.currentThread().interrupt(); + } + return Completable.complete(); + } + + @Override + public Maybe onUserMessageCallback( + InvocationContext invocationContext, Content userMessage) { + return Maybe.fromAction( + () -> logEvent("USER_MESSAGE", invocationContext, Optional.empty(), userMessage, null)); + } + + @Override + public Maybe beforeRunCallback(InvocationContext invocationContext) { + return Maybe.fromAction( + () -> logEvent("INVOCATION_START", invocationContext, Optional.empty(), null, null)); + } + + @Override + public Maybe onEventCallback(InvocationContext invocationContext, Event event) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + attrs.put("event_author", event.author()); + logEvent( + "EVENT", invocationContext, Optional.empty(), event.content().orElse(null), attrs); + }); + } + + @Override + public Completable afterRunCallback(InvocationContext invocationContext) { + return Completable.fromAction( + () -> { + logEvent("INVOCATION_END", invocationContext, Optional.empty(), null, null); + batchProcessor.flush(); + }); + } + + @Override + public Maybe beforeAgentCallback(BaseAgent agent, CallbackContext callbackContext) { + return Maybe.fromAction( + () -> + logEvent( + "AGENT_START", + callbackContext.invocationContext(), + Optional.of(callbackContext), + null, + null)); + } + + @Override + public Maybe 
afterAgentCallback(BaseAgent agent, CallbackContext callbackContext) { + return Maybe.fromAction( + () -> + logEvent( + "AGENT_END", + callbackContext.invocationContext(), + Optional.of(callbackContext), + null, + null)); + } + + @Override + public Maybe beforeModelCallback( + CallbackContext callbackContext, LlmRequest.Builder llmRequest) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + LlmRequest req = llmRequest.build(); + attrs.put("model", req.model().orElse("unknown")); + logEvent( + "MODEL_REQUEST", + callbackContext.invocationContext(), + Optional.of(callbackContext), + req, + attrs); + }); + } + + @Override + public Maybe afterModelCallback( + CallbackContext callbackContext, LlmResponse llmResponse) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + llmResponse.usageMetadata().ifPresent(u -> attrs.put("usage_metadata", u)); + logEvent( + "MODEL_RESPONSE", + callbackContext.invocationContext(), + Optional.of(callbackContext), + llmResponse, + attrs); + }); + } + + @Override + public Maybe onModelErrorCallback( + CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + attrs.put("error_message", error.getMessage()); + logEvent( + "MODEL_ERROR", + callbackContext.invocationContext(), + Optional.of(callbackContext), + null, + attrs); + }); + } + + @Override + public Maybe> beforeToolCallback( + BaseTool tool, Map toolArgs, ToolContext toolContext) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + attrs.put("tool_name", tool.name()); + logEvent( + "TOOL_START", + toolContext.invocationContext(), + Optional.of(toolContext), + toolArgs, + attrs); + }); + } + + @Override + public Maybe> afterToolCallback( + BaseTool tool, + Map toolArgs, + ToolContext toolContext, + Map result) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + attrs.put("tool_name", tool.name()); + logEvent( + 
"TOOL_END", toolContext.invocationContext(), Optional.of(toolContext), result, attrs); + }); + } + + @Override + public Maybe> onToolErrorCallback( + BaseTool tool, Map toolArgs, ToolContext toolContext, Throwable error) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + attrs.put("tool_name", tool.name()); + attrs.put("error_message", error.getMessage()); + logEvent( + "TOOL_ERROR", toolContext.invocationContext(), Optional.of(toolContext), null, attrs); + }); + } +} diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java new file mode 100644 index 000000000..aa5bf37de --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java @@ -0,0 +1,204 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import com.google.auth.Credentials; +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.BiFunction; +import javax.annotation.Nullable; + +/** Configuration for the BigQueryAgentAnalyticsPlugin. 
*/ +@AutoValue +public abstract class BigQueryLoggerConfig { + // Whether the plugin is enabled. + public abstract boolean enabled(); + + // List of event types to log. If None, all are allowed + // TODO(b/491852782): Implement allowlist/denylist for event types. + @Nullable + public abstract ImmutableList eventAllowlist(); + + // List of event types to ignore. + // TODO(b/491852782): Implement allowlist/denylist for event types. + @Nullable + public abstract ImmutableList eventDenylist(); + + // Max length for text content before truncation. + public abstract int maxContentLength(); + + // Project ID for the BigQuery table. + public abstract String projectId(); + + // Dataset ID for the BigQuery table. + public abstract String datasetId(); + + // Table name for the BigQuery table. + public abstract String tableName(); + + // Fields to cluster the table by. + public abstract ImmutableList clusteringFields(); + + // Whether to log multi-modal content. + // TODO(b/491852782): Implement logging of multi-modal content. + public abstract boolean logMultiModalContent(); + + // Retry configuration for BigQuery writes. + public abstract RetryConfig retryConfig(); + + // Number of rows to batch before flushing. + public abstract int batchSize(); + + // Duration to wait before flushing the queue. + public abstract Duration batchFlushInterval(); + + // Max time to wait for shutdown. + public abstract Duration shutdownTimeout(); + + // Max size of the batch processor queue. + public abstract int queueMaxSize(); + + // Optional custom formatter for content. + // TODO(b/491852782): Implement content formatter. + @Nullable + public abstract BiFunction contentFormatter(); + + // TODO(b/491852782): Implement connection id. + public abstract Optional connectionId(); + + // Toggle for session metadata (e.g. gchat thread-id). + // TODO(b/491852782): Implement logging of session metadata. + public abstract boolean logSessionMetadata(); + + // Static custom tags (e.g. 
{"agent_role": "sales"}). + // TODO(b/491852782): Implement custom tags. + public abstract ImmutableMap customTags(); + + // Automatically add new columns to existing tables when the plugin + // schema evolves. Only additive changes are made (columns are never + // dropped or altered). + // TODO(b/491852782): Implement auto-schema upgrade. + public abstract boolean autoSchemaUpgrade(); + + @Nullable + public abstract Credentials credentials(); + + public static Builder builder() { + return new AutoValue_BigQueryLoggerConfig.Builder() + .setEnabled(true) + .setMaxContentLength(500 * 1024) + .setDatasetId("agent_analytics") + .setTableName("events") + .setClusteringFields(ImmutableList.of("event_type", "agent", "user_id")) + .setLogMultiModalContent(true) + .setRetryConfig(RetryConfig.builder().build()) + .setBatchSize(1) + .setBatchFlushInterval(Duration.ofSeconds(1)) + .setShutdownTimeout(Duration.ofSeconds(10)) + .setQueueMaxSize(10000) + .setLogSessionMetadata(true) + .setCustomTags(ImmutableMap.of()) + // TODO(b/491851868): Enable auto-schema upgrade once implemented. + .setAutoSchemaUpgrade(false); + } + + /** Builder for {@link BigQueryLoggerConfig}. 
*/ + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setEnabled(boolean enabled); + + public abstract Builder setEventAllowlist(@Nullable List eventAllowlist); + + public abstract Builder setEventDenylist(@Nullable List eventDenylist); + + public abstract Builder setMaxContentLength(int maxContentLength); + + public abstract Builder setProjectId(String projectId); + + public abstract Builder setDatasetId(String datasetId); + + public abstract Builder setTableName(String tableName); + + public abstract Builder setClusteringFields(List clusteringFields); + + public abstract Builder setLogMultiModalContent(boolean logMultiModalContent); + + public abstract Builder setRetryConfig(RetryConfig retryConfig); + + public abstract Builder setBatchSize(int batchSize); + + public abstract Builder setBatchFlushInterval(Duration batchFlushInterval); + + public abstract Builder setShutdownTimeout(Duration shutdownTimeout); + + public abstract Builder setQueueMaxSize(int queueMaxSize); + + public abstract Builder setContentFormatter( + @Nullable BiFunction contentFormatter); + + public abstract Builder setConnectionId(String connectionId); + + public abstract Builder setLogSessionMetadata(boolean logSessionMetadata); + + public abstract Builder setCustomTags(Map customTags); + + public abstract Builder setAutoSchemaUpgrade(boolean autoSchemaUpgrade); + + public abstract Builder setCredentials(Credentials credentials); + + public abstract BigQueryLoggerConfig build(); + } + + /** Retry configuration for BigQuery writes. 
*/ + @AutoValue + public abstract static class RetryConfig { + public abstract int maxRetries(); + + public abstract Duration initialDelay(); + + public abstract double multiplier(); + + public abstract Duration maxDelay(); + + public static Builder builder() { + return new AutoValue_BigQueryLoggerConfig_RetryConfig.Builder() + .setMaxRetries(3) + .setInitialDelay(Duration.ofSeconds(1)) + .setMultiplier(2.0) + .setMaxDelay(Duration.ofSeconds(10)); + } + + /** Builder for {@link RetryConfig}. */ + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setMaxRetries(int maxRetries); + + public abstract Builder setInitialDelay(Duration initialDelay); + + public abstract Builder setMultiplier(double multiplier); + + public abstract Builder setMaxDelay(Duration maxDelay); + + public abstract RetryConfig build(); + } + } +} diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQuerySchema.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQuerySchema.java new file mode 100644 index 000000000..81181a1e0 --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQuerySchema.java @@ -0,0 +1,304 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.adk.plugins.agentanalytics; + +import static com.google.common.collect.ImmutableList.toImmutableList; + +import com.google.cloud.bigquery.Field; +import com.google.cloud.bigquery.FieldList; +import com.google.cloud.bigquery.Schema; +import com.google.cloud.bigquery.StandardSQLTypeName; +import com.google.cloud.bigquery.storage.v1.TableFieldSchema; +import com.google.cloud.bigquery.storage.v1.TableFieldSchema.Mode; +import com.google.cloud.bigquery.storage.v1.TableSchema; +import com.google.common.base.VerifyException; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.ByteString; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.channels.Channels; +import org.apache.arrow.vector.ipc.WriteChannel; +import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.types.DateUnit; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.FieldType; + +/** Utility for defining the BigQuery events table schema. */ +public final class BigQuerySchema { + + private BigQuerySchema() {} + + private static final ImmutableMap> + FIELD_TYPE_TO_ARROW_FIELD_METADATA = + ImmutableMap.of( + StandardSQLTypeName.JSON, + ImmutableMap.of("ARROW:extension:name", "google:sqlType:json"), + StandardSQLTypeName.DATETIME, + ImmutableMap.of("ARROW:extension:name", "google:sqlType:datetime"), + StandardSQLTypeName.GEOGRAPHY, + ImmutableMap.of( + "ARROW:extension:name", + "google:sqlType:geography", + "ARROW:extension:metadata", + "{\"encoding\": \"WKT\"}")); + + /** Returns the BigQuery schema for the events table. */ + // TODO(b/491848381): Rely on the same schema defined for python plugin. 
+ public static Schema getEventsSchema() { + return Schema.of( + Field.newBuilder("timestamp", StandardSQLTypeName.TIMESTAMP) + .setMode(Field.Mode.REQUIRED) + .setDescription("The UTC timestamp when the event occurred.") + .build(), + Field.newBuilder("event_type", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("The category of the event.") + .build(), + Field.newBuilder("agent", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("The name of the agent that generated this event.") + .build(), + Field.newBuilder("session_id", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("A unique identifier for the entire conversation session.") + .build(), + Field.newBuilder("invocation_id", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("A unique identifier for a single turn or execution.") + .build(), + Field.newBuilder("user_id", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("The identifier of the end-user.") + .build(), + Field.newBuilder("trace_id", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("OpenTelemetry trace ID.") + .build(), + Field.newBuilder("span_id", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("OpenTelemetry span ID.") + .build(), + Field.newBuilder("parent_span_id", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("OpenTelemetry parent span ID.") + .build(), + Field.newBuilder("content", StandardSQLTypeName.JSON) + .setMode(Field.Mode.NULLABLE) + .setDescription("The primary payload of the event.") + .build(), + Field.newBuilder( + "content_parts", + StandardSQLTypeName.STRUCT, + FieldList.of( + Field.newBuilder("mime_type", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("The MIME type of the content part.") + .build(), + Field.newBuilder("uri", StandardSQLTypeName.STRING) + 
.setMode(Field.Mode.NULLABLE) + .setDescription("The URI of the content part if stored externally.") + .build(), + Field.newBuilder( + "object_ref", + StandardSQLTypeName.STRUCT, + FieldList.of( + Field.newBuilder("uri", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .build(), + Field.newBuilder("version", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .build(), + Field.newBuilder("authorizer", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .build(), + Field.newBuilder("details", StandardSQLTypeName.JSON) + .setMode(Field.Mode.NULLABLE) + .build())) + .setMode(Field.Mode.NULLABLE) + .setDescription("The ObjectRef of the content part if stored externally.") + .build(), + Field.newBuilder("text", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("The raw text content.") + .build(), + Field.newBuilder("part_index", StandardSQLTypeName.INT64) + .setMode(Field.Mode.NULLABLE) + .setDescription("The zero-based index of this part.") + .build(), + Field.newBuilder("part_attributes", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("Additional metadata as a JSON object string.") + .build(), + Field.newBuilder("storage_mode", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("Indicates how the content part is stored.") + .build())) + .setMode(Field.Mode.REPEATED) + .setDescription("Multi-modal events content parts.") + .build(), + Field.newBuilder("attributes", StandardSQLTypeName.JSON) + .setMode(Field.Mode.NULLABLE) + .setDescription("A JSON object containing arbitrary key-value pairs.") + .build(), + Field.newBuilder("latency_ms", StandardSQLTypeName.JSON) + .setMode(Field.Mode.NULLABLE) + .setDescription("A JSON object containing latency measurements.") + .build(), + Field.newBuilder("status", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("The outcome of the event.") + .build(), + 
Field.newBuilder("error_message", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("Detailed error message if the status is 'ERROR'.") + .build(), + Field.newBuilder("is_truncated", StandardSQLTypeName.BOOL) + .setMode(Field.Mode.NULLABLE) + .setDescription("Indicates if the 'content' field was truncated.") + .build()); + } + + /** Returns the Arrow schema for the events table. */ + public static org.apache.arrow.vector.types.pojo.Schema getArrowSchema() { + return new org.apache.arrow.vector.types.pojo.Schema( + getEventsSchema().getFields().stream() + .map(BigQuerySchema::convertToArrowField) + .collect(toImmutableList())); + } + + /** Returns the serialized Arrow schema for the events table. */ + public static ByteString getSerializedArrowSchema() { + try (ByteArrayOutputStream out = new ByteArrayOutputStream()) { + MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), getArrowSchema()); + return ByteString.copyFrom(out.toByteArray()); + } catch (IOException e) { + throw new VerifyException("Failed to serialize arrow schema", e); + } + } + + private static org.apache.arrow.vector.types.pojo.Field convertToArrowField(Field field) { + ArrowType arrowType = convertTypeToArrow(field.getType().getStandardType()); + ImmutableList children = null; + if (field.getSubFields() != null) { + children = + field.getSubFields().stream() + .map(BigQuerySchema::convertToArrowField) + .collect(toImmutableList()); + } + + ImmutableMap metadata = + FIELD_TYPE_TO_ARROW_FIELD_METADATA.get(field.getType().getStandardType()); + + FieldType fieldType = + new FieldType(field.getMode() != Field.Mode.REQUIRED, arrowType, null, metadata); + org.apache.arrow.vector.types.pojo.Field arrowField = + new org.apache.arrow.vector.types.pojo.Field(field.getName(), fieldType, children); + + if (field.getMode() == Field.Mode.REPEATED) { + return new org.apache.arrow.vector.types.pojo.Field( + field.getName(), + new FieldType(false, new ArrowType.List(), 
null), + ImmutableList.of( + new org.apache.arrow.vector.types.pojo.Field( + "element", arrowField.getFieldType(), arrowField.getChildren()))); + } + return arrowField; + } + + private static ArrowType convertTypeToArrow(StandardSQLTypeName type) { + return switch (type) { + case BOOL -> new ArrowType.Bool(); + case BYTES -> new ArrowType.Binary(); + case DATE -> new ArrowType.Date(DateUnit.DAY); + case DATETIME -> + // Arrow doesn't have a direct DATETIME, often mapped to Timestamp or Utf8 + new ArrowType.Utf8(); + case FLOAT64 -> new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE); + case INT64 -> new ArrowType.Int(64, true); + case NUMERIC, BIGNUMERIC -> new ArrowType.Decimal(38, 9, 128); + case GEOGRAPHY, STRING, JSON -> new ArrowType.Utf8(); + case STRUCT -> new ArrowType.Struct(); + case TIME -> new ArrowType.Time(TimeUnit.MICROSECOND, 64); + case TIMESTAMP -> new ArrowType.Timestamp(TimeUnit.MICROSECOND, "UTC"); + default -> new ArrowType.Null(); + }; + } + + /** Returns names of fields to cluster by default. */ + public static ImmutableList getDefaultClusteringFields() { + return ImmutableList.of("event_type", "agent", "user_id"); + } + + /** Returns the BigQuery TableSchema for the events table (Storage Write API). */ + public static TableSchema getEventsTableSchema() { + return convertTableSchema(getEventsSchema()); + } + + private static TableSchema convertTableSchema(Schema schema) { + TableSchema.Builder result = TableSchema.newBuilder(); + for (int i = 0; i < schema.getFields().size(); i++) { + result.addFields(i, convertFieldSchema(schema.getFields().get(i))); + } + return result.build(); + } + + private static TableFieldSchema convertFieldSchema(Field field) { + TableFieldSchema.Builder result = TableFieldSchema.newBuilder(); + Field.Mode mode = field.getMode() != null ? 
field.getMode() : Field.Mode.NULLABLE; + + Mode resultMode = Mode.valueOf(mode.name()); + result.setMode(resultMode).setName(field.getName()); + + StandardSQLTypeName standardType = field.getType().getStandardType(); + TableFieldSchema.Type resultType = convertType(standardType); + result.setType(resultType); + + if (field.getDescription() != null) { + result.setDescription(field.getDescription()); + } + if (field.getSubFields() != null) { + for (int i = 0; i < field.getSubFields().size(); i++) { + result.addFields(i, convertFieldSchema(field.getSubFields().get(i))); + } + } + return result.build(); + } + + private static TableFieldSchema.Type convertType(StandardSQLTypeName type) { + return switch (type) { + case BOOL -> TableFieldSchema.Type.BOOL; + case BYTES -> TableFieldSchema.Type.BYTES; + case DATE -> TableFieldSchema.Type.DATE; + case DATETIME -> TableFieldSchema.Type.DATETIME; + case FLOAT64 -> TableFieldSchema.Type.DOUBLE; + case GEOGRAPHY -> TableFieldSchema.Type.GEOGRAPHY; + case INT64 -> TableFieldSchema.Type.INT64; + case NUMERIC -> TableFieldSchema.Type.NUMERIC; + case STRING -> TableFieldSchema.Type.STRING; + case STRUCT -> TableFieldSchema.Type.STRUCT; + case TIME -> TableFieldSchema.Type.TIME; + case TIMESTAMP -> TableFieldSchema.Type.TIMESTAMP; + case BIGNUMERIC -> TableFieldSchema.Type.BIGNUMERIC; + case JSON -> TableFieldSchema.Type.JSON; + case INTERVAL -> TableFieldSchema.Type.INTERVAL; + default -> TableFieldSchema.Type.TYPE_UNSPECIFIED; + }; + } +} diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/JsonFormatter.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/JsonFormatter.java new file mode 100644 index 000000000..b4b4a1049 --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/JsonFormatter.java @@ -0,0 +1,111 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the 
License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.Blob; +import com.google.genai.types.Content; +import com.google.genai.types.FileData; +import com.google.genai.types.Part; +import java.util.List; +import java.util.Optional; + +/** Utility for formatting and truncating content for BigQuery logging. */ +final class JsonFormatter { + private static final ObjectMapper mapper = new ObjectMapper().findAndRegisterModules(); + + private JsonFormatter() {} + + /** Formats Content parts into an ArrayNode for BigQuery logging. 
*/ + public static ArrayNode formatContentParts(Optional content, int maxLength) { + ArrayNode partsArray = mapper.createArrayNode(); + if (content.isEmpty() || content.get().parts() == null) { + return partsArray; + } + + List parts = content.get().parts().orElse(ImmutableList.of()); + + for (int i = 0; i < parts.size(); i++) { + Part part = parts.get(i); + ObjectNode partObj = mapper.createObjectNode(); + partObj.put("part_index", i); + partObj.put("storage_mode", "INLINE"); + + if (part.text().isPresent()) { + partObj.put("mime_type", "text/plain"); + partObj.put("text", truncateString(part.text().get(), maxLength)); + } else if (part.inlineData().isPresent()) { + Blob blob = part.inlineData().get(); + partObj.put("mime_type", blob.mimeType().orElse("")); + partObj.put("text", "[BINARY DATA]"); + } else if (part.fileData().isPresent()) { + FileData fileData = part.fileData().get(); + partObj.put("mime_type", fileData.mimeType().orElse("")); + partObj.put("uri", fileData.fileUri().orElse("")); + partObj.put("storage_mode", "EXTERNAL_URI"); + } + partsArray.add(partObj); + } + return partsArray; + } + + /** Recursively truncates long strings inside an object and returns a Jackson JsonNode. 
*/ + public static JsonNode smartTruncate(Object obj, int maxLength) { + if (obj == null) { + return mapper.nullNode(); + } + try { + return recursiveSmartTruncate(mapper.valueToTree(obj), maxLength); + } catch (IllegalArgumentException e) { + // Fallback for types that mapper can't handle directly as a tree + return mapper.valueToTree(String.valueOf(obj)); + } + } + + private static JsonNode recursiveSmartTruncate(JsonNode node, int maxLength) { + if (node.isTextual()) { + return mapper.valueToTree(truncateString(node.asText(), maxLength)); + } else if (node.isObject()) { + ObjectNode newNode = mapper.createObjectNode(); + node.properties() + .iterator() + .forEachRemaining( + entry -> { + newNode.set(entry.getKey(), recursiveSmartTruncate(entry.getValue(), maxLength)); + }); + return newNode; + } else if (node.isArray()) { + ArrayNode newNode = mapper.createArrayNode(); + for (JsonNode element : node) { + newNode.add(recursiveSmartTruncate(element, maxLength)); + } + return newNode; + } + return node; + } + + private static String truncateString(String s, int maxLength) { + if (s == null || s.length() <= maxLength) { + return s; + } + return s.substring(0, maxLength) + "...[truncated]"; + } +} diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/BatchProcessorTest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/BatchProcessorTest.java new file mode 100644 index 000000000..4f4350d1a --- /dev/null +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/BatchProcessorTest.java @@ -0,0 +1,367 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.api.core.ApiFutures; +import com.google.cloud.bigquery.storage.v1.AppendRowsResponse; +import com.google.cloud.bigquery.storage.v1.RowError; +import com.google.cloud.bigquery.storage.v1.StreamWriter; +import com.google.rpc.Status; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.logging.Handler; +import java.util.logging.Level; +import java.util.logging.LogRecord; +import java.util.logging.Logger; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.TimeStampMicroTZVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.Schema; +import 
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule;

/**
 * Unit tests for {@code BatchProcessor}: row-to-Arrow conversion, flush behavior, and error
 * handling around the BigQuery {@code StreamWriter}.
 *
 * <p>Several tests inspect the {@code ArrowRecordBatch} handed to the mocked writer by reloading
 * it into a {@code VectorSchemaRoot} built from the shared Arrow schema. Results of those checks
 * are smuggled out of the Mockito answer via single-element arrays, since the lambda cannot
 * assign to locals directly.
 *
 * <p>NOTE(review): these tests reach into {@code batchProcessor.allocator} — assumes the field is
 * package-visible for testing; confirm against BatchProcessor's declaration.
 */
@RunWith(JUnit4.class)
public class BatchProcessorTest {
  @Rule public MockitoRule mockitoRule = MockitoJUnit.rule();

  @Mock private StreamWriter mockWriter;
  // Real single-thread scheduler; individual tests may substitute a mock executor.
  private ScheduledExecutorService executor;
  private BatchProcessor batchProcessor;
  private Schema schema;
  // Captures java.util.logging output from BatchProcessor for the log-assertion tests.
  private Handler mockHandler;

  @Before
  public void setUp() {
    executor = Executors.newScheduledThreadPool(1);
    batchProcessor = new BatchProcessor(mockWriter, 10, Duration.ofMinutes(1), 100, executor);
    schema = BigQuerySchema.getArrowSchema();

    // Default stub: every append succeeds immediately.
    when(mockWriter.append(any(ArrowRecordBatch.class)))
        .thenReturn(ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance()));

    Logger logger = Logger.getLogger(BatchProcessor.class.getName());
    mockHandler = mock(Handler.class);
    logger.addHandler(mockHandler);
  }

  @After
  public void tearDown() {
    batchProcessor.close();
    executor.shutdown();
  }

  // Verifies a java.time.Instant row value lands in the timestamp vector as epoch microseconds.
  @Test
  public void flush_populatesTimestampFieldCorrectly() throws Exception {
    Instant now = Instant.parse("2026-03-02T19:11:49.631Z");
    Map<String, Object> row = new HashMap<>();
    row.put("timestamp", now);
    row.put("event_type", "TEST_EVENT");

    final boolean[] checksPassed = {false};
    final String[] failureMessage = {null};

    when(mockWriter.append(any(ArrowRecordBatch.class)))
        .thenAnswer(
            invocation -> {
              ArrowRecordBatch recordedBatch = invocation.getArgument(0);
              try (VectorSchemaRoot root =
                  VectorSchemaRoot.create(schema, batchProcessor.allocator)) {
                VectorLoader loader = new VectorLoader(root);
                loader.load(recordedBatch);

                if (root.getRowCount() != 1) {
                  failureMessage[0] = "Expected 1 row, got " + root.getRowCount();
                  return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance());
                }

                var timestampVector = root.getVector("timestamp");
                if (!(timestampVector instanceof TimeStampMicroTZVector tzVector)) {
                  failureMessage[0] = "Vector should be an instance of TimeStampMicroTZVector";
                  return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance());
                }
                if (tzVector.isNull(0)) {
                  failureMessage[0] = "Timestamp should NOT be null";
                } else if (tzVector.get(0) != now.toEpochMilli() * 1000) {
                  // Vector stores microseconds; the expected value is millis * 1000.
                  failureMessage[0] =
                      "Expected " + (now.toEpochMilli() * 1000) + ", got " + tzVector.get(0);
                } else {
                  checksPassed[0] = true;
                }
              } catch (RuntimeException e) {
                failureMessage[0] = "Exception during check: " + e.getMessage();
              }
              return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance());
            });

    batchProcessor.append(row);
    batchProcessor.flush();

    verify(mockWriter).append(any(ArrowRecordBatch.class));
    assertTrue(failureMessage[0], checksPassed[0]);
  }

  // Verifies scalar string/boolean fields are written into their vectors.
  @Test
  public void flush_populatesAllBasicFields() throws Exception {
    Map<String, Object> row = new HashMap<>();
    row.put("timestamp", Instant.now());
    row.put("event_type", "BASIC_EVENT");
    row.put("is_truncated", true);

    when(mockWriter.append(any(ArrowRecordBatch.class)))
        .thenAnswer(
            invocation -> {
              ArrowRecordBatch recordedBatch = invocation.getArgument(0);
              try (VectorSchemaRoot root =
                  VectorSchemaRoot.create(schema, batchProcessor.allocator)) {
                VectorLoader loader = new VectorLoader(root);
                loader.load(recordedBatch);

                assertEquals("BASIC_EVENT", root.getVector("event_type").getObject(0).toString());
                // BitVector encodes true as 1.
                assertEquals(1, ((BitVector) root.getVector("is_truncated")).get(0));
              }
              return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance());
            });

    batchProcessor.append(row);
    batchProcessor.flush();

    verify(mockWriter).append(any(ArrowRecordBatch.class));
  }

  // Verifies JSON-typed columns keep their raw JSON string payloads.
  @Test
  public void flush_populatesJsonFields() throws Exception {
    Map<String, Object> row = new HashMap<>();
    row.put("timestamp", Instant.now());
    row.put("content", "{\"key\": \"value\"}");
    row.put("attributes", "{\"attr\": 123}");

    when(mockWriter.append(any(ArrowRecordBatch.class)))
        .thenAnswer(
            invocation -> {
              ArrowRecordBatch recordedBatch = invocation.getArgument(0);
              try (VectorSchemaRoot root =
                  VectorSchemaRoot.create(schema, batchProcessor.allocator)) {
                VectorLoader loader = new VectorLoader(root);
                loader.load(recordedBatch);

                assertEquals(
                    "{\"key\": \"value\"}", root.getVector("content").getObject(0).toString());
                assertEquals(
                    "{\"attr\": 123}", root.getVector("attributes").getObject(0).toString());
              }
              return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance());
            });

    batchProcessor.append(row);
    batchProcessor.flush();

    verify(mockWriter).append(any(ArrowRecordBatch.class));
  }

  // Verifies a list-of-maps row value becomes a ListVector of structs with matching children.
  @Test
  public void flush_populatesNestedStructs() throws Exception {
    Map<String, Object> row = new HashMap<>();
    row.put("timestamp", Instant.now());

    List<Map<String, Object>> contentParts = new ArrayList<>();
    Map<String, Object> part = new HashMap<>();
    part.put("mime_type", "text/plain");
    part.put("text", "hello world");
    part.put("part_index", 0L);
    contentParts.add(part);
    row.put("content_parts", contentParts);

    when(mockWriter.append(any(ArrowRecordBatch.class)))
        .thenAnswer(
            invocation -> {
              ArrowRecordBatch recordedBatch = invocation.getArgument(0);
              try (VectorSchemaRoot root =
                  VectorSchemaRoot.create(schema, batchProcessor.allocator)) {
                VectorLoader loader = new VectorLoader(root);
                loader.load(recordedBatch);

                ListVector contentPartsVector = (ListVector) root.getVector("content_parts");
                StructVector structVector = (StructVector) contentPartsVector.getDataVector();

                assertEquals(1, ((List<?>) contentPartsVector.getObject(0)).size());
                VarCharVector mimeTypeVector = (VarCharVector) structVector.getChild("mime_type");
                assertEquals("text/plain", mimeTypeVector.getObject(0).toString());

                VarCharVector textVector = (VarCharVector) structVector.getChild("text");
                assertEquals("hello world", textVector.getObject(0).toString());

                BigIntVector partIndexVector = (BigIntVector) structVector.getChild("part_index");
                assertEquals(0L, partIndexVector.get(0));
              }
              return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance());
            });

    batchProcessor.append(row);
    batchProcessor.flush();

    verify(mockWriter).append(any(ArrowRecordBatch.class));
  }

  // Verifies flush survives an AppendRowsResponse carrying both a global and a row-level error.
  @Test
  public void flush_handlesBigQueryErrorResponse() throws Exception {
    Map<String, Object> row = new HashMap<>();
    row.put("event_type", "ERROR_EVENT");

    AppendRowsResponse responseWithError =
        AppendRowsResponse.newBuilder()
            .setError(Status.newBuilder().setMessage("Global error").build())
            .addRowErrors(RowError.newBuilder().setIndex(0).setMessage("Row error").build())
            .build();

    when(mockWriter.append(any(ArrowRecordBatch.class)))
        .thenReturn(ApiFutures.immediateFuture(responseWithError));

    batchProcessor.append(row);
    batchProcessor.flush();

    verify(mockWriter).append(any(ArrowRecordBatch.class));
  }

  // Verifies flush does not propagate a RuntimeException thrown by the writer.
  @Test
  public void flush_handlesGenericExceptionDuringAppend() throws Exception {
    Map<String, Object> row = new HashMap<>();
    row.put("event_type", "EXCEPTION_EVENT");

    when(mockWriter.append(any(ArrowRecordBatch.class)))
        .thenThrow(new RuntimeException("Simulated failure"));

    batchProcessor.append(row);
    batchProcessor.flush();

    verify(mockWriter).append(any(ArrowRecordBatch.class));
  }

  // Verifies reaching batchSize schedules an asynchronous flush on the executor.
  @Test
  public void append_triggersFlushWhenBatchSizeReached() {
    ScheduledExecutorService mockExecutor = mock(ScheduledExecutorService.class);
    BatchProcessor bp = new BatchProcessor(mockWriter, 2, Duration.ofMinutes(1), 10, mockExecutor);

    Map<String, Object> row = new HashMap<>();
    bp.append(row);
    // First append is below the threshold of 2 — no flush scheduled yet.
    verify(mockExecutor, never()).execute(any(Runnable.class));

    bp.append(row);
    verify(mockExecutor).execute(any(Runnable.class));
  }

  // Verifies an empty queue produces no write at all.
  @Test
  public void flush_doesNothingWhenQueueIsEmpty() throws Exception {
    batchProcessor.flush();
    verify(mockWriter, never()).append(any(ArrowRecordBatch.class));
  }

  // Verifies null row values become null vector entries rather than failing.
  @Test
  public void flush_handlesNullValues() throws Exception {
    Map<String, Object> row = new HashMap<>();
    row.put("timestamp", Instant.now());
    row.put("event_type", null);
    row.put("is_truncated", null);

    final boolean[] checksPassed = {false};
    when(mockWriter.append(any(ArrowRecordBatch.class)))
        .thenAnswer(
            invocation -> {
              ArrowRecordBatch recordedBatch = invocation.getArgument(0);
              try (VectorSchemaRoot root =
                  VectorSchemaRoot.create(schema, batchProcessor.allocator)) {
                VectorLoader loader = new VectorLoader(root);
                loader.load(recordedBatch);

                assertTrue(root.getVector("event_type").isNull(0));
                assertTrue(root.getVector("is_truncated").isNull(0));
                checksPassed[0] = true;
              }
              return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance());
            });

    batchProcessor.append(row);
    batchProcessor.flush();

    verify(mockWriter).append(any(ArrowRecordBatch.class));
    assertTrue("Null checks failed", checksPassed[0]);
  }

  // Verifies an Arrow allocation failure skips the write and logs a SEVERE error.
  @Test
  public void flush_handlesAllocationFailure() throws Exception {
    Map<String, Object> row = new HashMap<>();
    row.put("event_type", "ALLOC_FAIL_EVENT");
    batchProcessor.append(row);
    // A 1-byte limit makes any vector allocation fail.
    batchProcessor.allocator.setLimit(1);

    batchProcessor.flush();

    verify(mockWriter, never()).append(any(ArrowRecordBatch.class));
    ArgumentCaptor<LogRecord> captor = ArgumentCaptor.forClass(LogRecord.class);
    verify(mockHandler, atLeastOnce()).publish(captor.capture());
    boolean foundError = false;
    for (LogRecord record : captor.getAllValues()) {
      if (record.getLevel().equals(Level.SEVERE)
          && record.getMessage().contains("Failed to write batch to BigQuery")) {
        foundError = true;
        break;
      }
    }
    assertTrue("Expected SEVERE error log not found", foundError);
  }

  // Verifies close() (via try-with-resources) flushes pending rows and closes the writer.
  @Test
  public void close_flushesAndClosesResources() throws Exception {
    try (BatchProcessor bp =
        new BatchProcessor(mockWriter, 10, Duration.ofMinutes(1), 100, executor)) {
      Map<String, Object> row = new HashMap<>();
      row.put("event_type", "CLOSE_EVENT");
      bp.append(row);
    }

    verify(mockWriter).append(any(ArrowRecordBatch.class));
    verify(mockWriter).close();
  }
}
+ */ + +package com.google.adk.plugins.agentanalytics; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.adk.agents.BaseAgent; +import com.google.adk.agents.CallbackContext; +import com.google.adk.agents.InvocationContext; +import com.google.adk.events.Event; +import com.google.adk.models.LlmRequest; +import com.google.adk.sessions.Session; +import com.google.api.core.ApiFutures; +import com.google.auth.Credentials; +import com.google.cloud.bigquery.BigQuery; +import com.google.cloud.bigquery.BigQueryOptions; +import com.google.cloud.bigquery.Table; +import com.google.cloud.bigquery.TableId; +import com.google.cloud.bigquery.storage.v1.AppendRowsResponse; +import com.google.cloud.bigquery.storage.v1.BigQueryWriteClient; +import com.google.cloud.bigquery.storage.v1.StreamWriter; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.SpanContext; +import io.opentelemetry.context.Scope; +import io.reactivex.rxjava3.core.Flowable; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.logging.Handler; +import java.util.logging.Level; +import java.util.logging.LogRecord; +import java.util.logging.Logger; +import org.apache.arrow.vector.TimeStampMicroTZVector; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; 
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule;

/**
 * Unit tests for {@code BigQueryAgentAnalyticsPlugin}: plugin callbacks, Arrow schema shape,
 * stream-name construction, and error handling around table creation and batch writes.
 *
 * <p>The plugin under test overrides its factory methods so all BigQuery write traffic goes to
 * mocks; some tests inspect {@code plugin.batchProcessor.queue} directly instead of flushing.
 *
 * <p>NOTE(review): assumes {@code batchProcessor}, its {@code queue}, and {@code allocator} are
 * package-visible for testing — confirm against the production declarations.
 */
@RunWith(JUnit4.class)
public class BigQueryAgentAnalyticsPluginTest {
  @Rule public MockitoRule mockitoRule = MockitoJUnit.rule();

  @Mock private BigQuery mockBigQuery;
  @Mock private StreamWriter mockWriter;
  @Mock private BigQueryWriteClient mockWriteClient;
  @Mock private InvocationContext mockInvocationContext;
  private BaseAgent fakeAgent;

  private BigQueryLoggerConfig config;
  private BigQueryAgentAnalyticsPlugin plugin;
  // Captures BatchProcessor's java.util.logging output for log-assertion tests.
  private Handler mockHandler;

  @Before
  public void setUp() throws Exception {
    fakeAgent = new FakeAgent("agent_name");
    config =
        BigQueryLoggerConfig.builder()
            .setEnabled(true)
            .setProjectId("project")
            .setDatasetId("dataset")
            .setTableName("table")
            .setBatchSize(10)
            .setBatchFlushInterval(Duration.ofSeconds(10))
            .setAutoSchemaUpgrade(false)
            .setCredentials(mock(Credentials.class))
            .setCustomTags(ImmutableMap.of("global_tag", "global_value"))
            .build();

    when(mockBigQuery.getOptions())
        .thenReturn(BigQueryOptions.newBuilder().setProjectId("test-project").build());
    // Table exists by default, so ensureTableExists succeeds without creating anything.
    when(mockBigQuery.getTable(any(TableId.class))).thenReturn(mock(Table.class));
    when(mockWriter.append(any(ArrowRecordBatch.class)))
        .thenReturn(ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance()));

    // Override the factory methods so the plugin writes to mocks instead of real BigQuery.
    plugin =
        new BigQueryAgentAnalyticsPlugin(config, mockBigQuery) {
          @Override
          protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) {
            return mockWriteClient;
          }

          @Override
          protected StreamWriter createWriter(BigQueryLoggerConfig config) {
            return mockWriter;
          }
        };

    Session session = Session.builder("session_id").build();
    when(mockInvocationContext.session()).thenReturn(session);
    when(mockInvocationContext.invocationId()).thenReturn("invocation_id");
    when(mockInvocationContext.agent()).thenReturn(fakeAgent);
    when(mockInvocationContext.userId()).thenReturn("user_id");

    Logger logger = Logger.getLogger(BatchProcessor.class.getName());
    mockHandler = mock(Handler.class);
    logger.addHandler(mockHandler);
  }

  // Each of the next three tests verifies a lifecycle callback enqueues a row for the writer.
  @Test
  public void onUserMessageCallback_appendsToWriter() throws Exception {
    Content content = Content.builder().build();

    plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe();
    plugin.batchProcessor.flush();

    verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class));
  }

  @Test
  public void beforeRunCallback_appendsToWriter() throws Exception {
    plugin.beforeRunCallback(mockInvocationContext).blockingSubscribe();
    plugin.batchProcessor.flush();

    verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class));
  }

  @Test
  public void afterRunCallback_flushesAndAppends() throws Exception {
    plugin.afterRunCallback(mockInvocationContext).blockingSubscribe();
    plugin.batchProcessor.flush();

    verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class));
  }

  // Verifies the default-stream resource name is built from project/dataset/table.
  @Test
  public void getStreamName_returnsCorrectFormat() {
    // Local config deliberately shadows the field to use distinct identifiers.
    BigQueryLoggerConfig config =
        BigQueryLoggerConfig.builder()
            .setProjectId("test-project")
            .setDatasetId("test-dataset")
            .setTableName("test-table")
            .build();

    String streamName = plugin.getStreamName(config);

    assertEquals(
        "projects/test-project/datasets/test-dataset/tables/test-table/streams/_default",
        streamName);
  }

  // Verifies JsonFormatter emits index, storage mode, text, and MIME type for a text part.
  @Test
  public void formatContentParts_populatesCorrectFields() {
    Content content = Content.fromParts(Part.fromText("hello"));
    ArrayNode nodes = JsonFormatter.formatContentParts(Optional.of(content), 100);
    assertEquals(1, nodes.size());
    ObjectNode node = (ObjectNode) nodes.get(0);
    assertEquals(0, node.get("part_index").asInt());
    assertEquals("INLINE", node.get("storage_mode").asText());
    assertEquals("hello", node.get("text").asText());
    assertEquals("text/plain", node.get("mime_type").asText());
  }

  // Verifies the content column carries the BigQuery JSON extension metadata.
  @Test
  public void arrowSchema_hasJsonMetadata() {
    Schema schema = BigQuerySchema.getArrowSchema();
    Field contentField = schema.findField("content");
    assertNotNull(contentField);
    assertEquals("google:sqlType:json", contentField.getMetadata().get("ARROW:extension:name"));
  }

  // Verifies a table-lookup failure is logged as a WARNING and does not break the callback.
  @Test
  public void onUserMessageCallback_handlesTableCreationFailure() throws Exception {
    Logger logger = Logger.getLogger(BigQueryAgentAnalyticsPlugin.class.getName());
    Handler mockHandler = mock(Handler.class);
    logger.addHandler(mockHandler);
    try {
      when(mockBigQuery.getTable(any(TableId.class)))
          .thenThrow(new RuntimeException("Table check failed"));
      Content content = Content.builder().build();

      // Should not throw exception
      plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe();

      plugin.batchProcessor.flush();

      ArgumentCaptor<LogRecord> captor = ArgumentCaptor.forClass(LogRecord.class);
      verify(mockHandler, atLeastOnce()).publish(captor.capture());
      assertTrue(
          captor
              .getValue()
              .getMessage()
              .contains("Failed to check or create/upgrade BigQuery table"));
      assertEquals(Level.WARNING, captor.getValue().getLevel());
    } finally {
      logger.removeHandler(mockHandler);
    }
  }

  // Verifies a failed append future is logged as SEVERE by the batch processor.
  @Test
  public void onUserMessageCallback_handlesAppendFailure() throws Exception {
    when(mockWriter.append(any(ArrowRecordBatch.class)))
        .thenReturn(ApiFutures.immediateFailedFuture(new RuntimeException("Append failed")));
    Content content = Content.builder().build();

    plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe();

    // Flush should handle the failed future from writer.append()
    plugin.batchProcessor.flush();

    verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class));
    ArgumentCaptor<LogRecord> captor = ArgumentCaptor.forClass(LogRecord.class);
    verify(mockHandler, atLeastOnce()).publish(captor.capture());
    assertTrue(captor.getValue().getMessage().contains("Failed to write batch to BigQuery"));
    assertEquals(Level.SEVERE, captor.getValue().getLevel());
  }

  // Verifies the table-existence check is cached across multiple logEvent calls.
  @Test
  public void ensureTableExists_calledOnlyOnce() throws Exception {
    Content content = Content.builder().build();

    // Multiple calls to logEvent via different callbacks
    plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe();
    plugin.beforeRunCallback(mockInvocationContext).blockingSubscribe();
    plugin.afterRunCallback(mockInvocationContext).blockingSubscribe();

    // Verify getting table was only done once. Using fully qualified name to avoid ambiguity.
    verify(mockBigQuery).getTable(any(TableId.class));
  }

  // Verifies content_parts maps to a List<Struct> with a nested object_ref struct.
  @Test
  public void arrowSchema_handlesNestedFields() {
    Schema schema = BigQuerySchema.getArrowSchema();
    Field contentPartsField = schema.findField("content_parts");
    assertNotNull(contentPartsField);
    // Repeated struct becomes a List of Structs
    assertTrue(contentPartsField.getType() instanceof ArrowType.List);

    Field element = contentPartsField.getChildren().get(0);
    assertEquals("element", element.getName());

    // Check object_ref which is a nested STRUCT
    Field objectRef =
        element.getChildren().stream()
            .filter(f -> f.getName().equals("object_ref"))
            .findFirst()
            .orElse(null);
    assertNotNull(objectRef);
    assertTrue(objectRef.getType() instanceof ArrowType.Struct);
    assertFalse(objectRef.getChildren().isEmpty());
  }

  // Verifies REQUIRED/NULLABLE modes in the BigQuery schema map to Arrow nullability.
  @Test
  public void arrowSchema_handlesFieldNullability() {
    Schema schema = BigQuerySchema.getArrowSchema();

    // timestamp is REQUIRED in BigQuerySchema.getEventsSchema()
    Field timestampField = schema.findField("timestamp");
    assertNotNull(timestampField);
    assertFalse(timestampField.isNullable());

    // event_type is NULLABLE in BigQuerySchema.getEventsSchema()
    Field eventTypeField = schema.findField("event_type");
    assertNotNull(eventTypeField);
    assertTrue(eventTypeField.isNullable());
  }

  // Deep inspection of one written row: event metadata, content, content_parts, and custom tags.
  @Test
  public void logEvent_populatesCommonFields() throws Exception {
    final boolean[] checksPassed = {false};
    final String[] failureMessage = {null};

    when(mockWriter.append(any(ArrowRecordBatch.class)))
        .thenAnswer(
            invocation -> {
              ArrowRecordBatch recordedBatch = invocation.getArgument(0);
              Schema schema = BigQuerySchema.getArrowSchema();
              try (VectorSchemaRoot root =
                  VectorSchemaRoot.create(schema, plugin.batchProcessor.allocator)) {
                VectorLoader loader = new VectorLoader(root);
                loader.load(recordedBatch);

                if (root.getRowCount() != 1) {
                  failureMessage[0] = "Expected 1 row, got " + root.getRowCount();
                } else if (!Objects.equals(
                    root.getVector("event_type").getObject(0).toString(), "USER_MESSAGE")) {
                  failureMessage[0] =
                      "Wrong event_type: " + root.getVector("event_type").getObject(0);
                } else if (!root.getVector("agent").getObject(0).toString().equals("agent_name")) {
                  failureMessage[0] = "Wrong agent: " + root.getVector("agent").getObject(0);
                } else if (!root.getVector("session_id")
                    .getObject(0)
                    .toString()
                    .equals("session_id")) {
                  failureMessage[0] =
                      "Wrong session_id: " + root.getVector("session_id").getObject(0);
                } else if (!root.getVector("invocation_id")
                    .getObject(0)
                    .toString()
                    .equals("invocation_id")) {
                  failureMessage[0] =
                      "Wrong invocation_id: " + root.getVector("invocation_id").getObject(0);
                } else if (!root.getVector("user_id").getObject(0).toString().equals("user_id")) {
                  failureMessage[0] = "Wrong user_id: " + root.getVector("user_id").getObject(0);
                } else if (((TimeStampMicroTZVector) root.getVector("timestamp")).get(0) <= 0) {
                  failureMessage[0] = "Timestamp not populated";
                } else {
                  // Check content and content_parts
                  String contentJson = root.getVector("content").getObject(0).toString();
                  if (!contentJson.contains("test message")) {
                    failureMessage[0] = "Wrong content: " + contentJson;
                  } else {
                    ListVector contentPartsVector = (ListVector) root.getVector("content_parts");
                    if (((List<?>) contentPartsVector.getObject(0)).isEmpty()) {
                      failureMessage[0] = "content_parts is empty";
                    } else {
                      // Check attributes
                      String attributesJson = root.getVector("attributes").getObject(0).toString();
                      if (!attributesJson.contains("global_tag")
                          || !attributesJson.contains("global_value")) {
                        failureMessage[0] = "Wrong attributes: " + attributesJson;
                      } else {
                        checksPassed[0] = true;
                      }
                    }
                  }
                }
              } catch (RuntimeException e) {
                failureMessage[0] = "Exception during inspection: " + e.getMessage();
              }
              return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance());
            });

    Content content = Content.fromParts(Part.fromText("test message"));
    plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe();
    plugin.batchProcessor.flush();

    assertTrue(failureMessage[0], checksPassed[0]);
  }

  // Verifies the active OpenTelemetry span's trace/span ids are copied onto the queued row.
  @Test
  public void logEvent_populatesTraceDetails() throws Exception {
    String traceId = "4bf92f3577b34da6a3ce929d0e0e4736";
    String spanId = "00f067aa0ba902b7";

    SpanContext mockSpanContext = mock(SpanContext.class);
    when(mockSpanContext.isValid()).thenReturn(true);
    when(mockSpanContext.getTraceId()).thenReturn(traceId);
    when(mockSpanContext.getSpanId()).thenReturn(spanId);

    Span mockSpan = Span.wrap(mockSpanContext);

    try (Scope scope = mockSpan.makeCurrent()) {
      Content content = Content.builder().build();
      plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe();

      Map<String, Object> row = plugin.batchProcessor.queue.poll();
      assertNotNull("Row not found in queue", row);
      assertEquals(traceId, row.get("trace_id"));
      assertEquals(spanId, row.get("span_id"));
    }
  }

  // Verifies content with a real Part still serializes and reaches the writer.
  @Test
  public void complexType_appendsToWriter() throws Exception {
    Part part = Part.fromText("test text");
    Content content = Content.fromParts(part);
    plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe();

    plugin.batchProcessor.flush();

    verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class));
  }

  // Verifies onEventCallback rows record event type, agent, author, and content.
  @Test
  public void onEventCallback_populatesCorrectFields() throws Exception {
    Event event =
        Event.builder()
            .author("agent_author")
            .content(Content.fromParts(Part.fromText("event content")))
            .build();

    plugin.onEventCallback(mockInvocationContext, event).blockingSubscribe();

    Map<String, Object> row = plugin.batchProcessor.queue.poll();
    assertNotNull("Row not found in queue", row);
    assertEquals("EVENT", row.get("event_type"));
    assertEquals("agent_name", row.get("agent"));
    assertTrue(row.get("attributes").toString().contains("agent_author"));
    assertTrue(row.get("content").toString().contains("event content"));
  }

  // Verifies model errors are logged with the callback context's agent and the error message.
  @Test
  public void onModelErrorCallback_populatesCorrectFields() throws Exception {
    CallbackContext mockCallbackContext = mock(CallbackContext.class);
    when(mockCallbackContext.invocationContext()).thenReturn(mockInvocationContext);
    when(mockCallbackContext.agentName()).thenReturn("agent_in_context");
    LlmRequest.Builder mockLlmRequestBuilder = mock(LlmRequest.Builder.class);
    Throwable error = new RuntimeException("model error message");

    plugin
        .onModelErrorCallback(mockCallbackContext, mockLlmRequestBuilder, error)
        .blockingSubscribe();

    Map<String, Object> row = plugin.batchProcessor.queue.poll();
    assertNotNull("Row not found in queue", row);
    assertEquals("MODEL_ERROR", row.get("event_type"));
    assertEquals("agent_in_context", row.get("agent"));
    assertTrue(row.get("attributes").toString().contains("model error message"));
  }

  /** Minimal agent stub that produces no events; used only to supply an agent name. */
  private static class FakeAgent extends BaseAgent {
    FakeAgent(String name) {
      super(name, "description", null, null, null);
    }

    @Override
    protected Flowable<Event> runAsyncImpl(InvocationContext invocationContext) {
      return Flowable.empty();
    }

    @Override
    protected Flowable<Event> runLiveImpl(InvocationContext invocationContext) {
      return Flowable.empty();
    }
  }
}
artifacts (if applicable) in the user message and replaces // the artifact data with a file name placeholder. @@ -322,9 +323,11 @@ private Single appendNewMessageToSession( continue; } String fileName = "artifact_" + invocationContext.invocationId() + "_" + i; - var unused = - this.artifactService.saveArtifact( - this.appName, session.userId(), session.id(), fileName, part); + saveArtifactsFlow = + saveArtifactsFlow.andThen( + this.artifactService + .saveArtifact(this.appName, session.userId(), session.id(), fileName, part) + .ignoreElement()); newMessage .parts() @@ -349,7 +352,8 @@ private Single appendNewMessageToSession( EventActions.builder().stateDelta(new ConcurrentHashMap<>(stateDelta)).build()); } - return this.sessionService.appendEvent(session, eventBuilder.build()); + return saveArtifactsFlow.andThen( + this.sessionService.appendEvent(session, eventBuilder.build())); } /** See {@link #runAsync(String, String, Content, RunConfig, Map)}. */ diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index 2eb515fa2..a3e21cb73 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -24,6 +24,8 @@ import static com.google.adk.testing.TestUtils.createTextLlmResponse; import static com.google.adk.testing.TestUtils.simplifyEvents; import static com.google.common.truth.Truth.assertThat; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.Arrays.stream; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.CALLS_REAL_METHODS; import static org.mockito.Mockito.mock; @@ -36,6 +38,7 @@ import com.google.adk.agents.LlmAgent; import com.google.adk.agents.RunConfig; import com.google.adk.apps.App; +import com.google.adk.artifacts.BaseArtifactService; import com.google.adk.events.Event; import com.google.adk.flows.llmflows.Functions; import 
com.google.adk.models.LlmResponse; @@ -65,12 +68,14 @@ import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.subscribers.TestSubscriber; import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -78,6 +83,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; @RunWith(JUnit4.class) public final class RunnerTest { @@ -849,6 +855,19 @@ private Content createContent(String text) { return Content.builder().parts(Part.builder().text(text).build()).build(); } + private static Content createInlineDataContent(byte[]... data) { + return Content.builder() + .parts( + stream(data) + .map(dataBytes -> Part.fromBytes(dataBytes, "example/octet-stream")) + .toArray(Part[]::new)) + .build(); + } + + private static Content createInlineDataContent(String... 
data) { + return createInlineDataContent(stream(data).map(d -> d.getBytes(UTF_8)).toArray(byte[][]::new)); + } + @Test public void runAsync_createsInvocationSpan() { var unused = @@ -1331,4 +1350,40 @@ public static ImmutableMap echoTool(String message) { return ImmutableMap.of("message", message); } } + + @Test + public void runner_executesSaveArtifactFlow() { + // arrange + final AtomicInteger artifactsSavedCounter = new AtomicInteger(); + BaseArtifactService mockArtifactService = Mockito.mock(BaseArtifactService.class); + when(mockArtifactService.saveArtifact(any(), any(), any(), any(), any())) + .thenReturn( + Single.defer( + () -> { + // we want to assert not only that the saveArtifact method was + // called, but also that the flow that it returned was run, so + // we need to record the call in a counter + artifactsSavedCounter.incrementAndGet(); + return Single.just(42); + })); + Runner runner = + Runner.builder() + .app(App.builder().name("test").rootAgent(agent).build()) + .artifactService(mockArtifactService) + .build(); + session = runner.sessionService().createSession("test", "user").blockingGet(); + // each inline data will be saved using our mock artifact service + Content content = createInlineDataContent("test data", "test data 2"); + RunConfig runConfig = RunConfig.builder().setSaveInputBlobsAsArtifacts(true).build(); + + // act + var events = runner.runAsync("user", session.id(), content, runConfig).test(); + + // assert + events.assertComplete(); + // artifacts were saved + assertThat(artifactsSavedCounter.get()).isEqualTo(2); + // agent was run + assertThat(simplifyEvents(events.values())).containsExactly("test agent: from llm"); + } } From e51f9112050955657da0dfc3aedc00f90ad739ec Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 18 Mar 2026 05:37:53 -0700 Subject: [PATCH 09/40] feat: add handling the a2a metadata in the RemoteA2AAgent; Add the enum type for the metadata keys PiperOrigin-RevId: 885539894 --- 
.../adk/a2a/converters/A2AMetadataKey.java | 40 ++++ .../adk/a2a/converters/AdkMetadataKey.java | 35 ++++ .../adk/a2a/converters/PartConverter.java | 19 +- .../adk/a2a/converters/ResponseConverter.java | 131 +++++++++++-- .../adk/a2a/converters/PartConverterTest.java | 50 ++++- .../a2a/converters/ResponseConverterTest.java | 175 +++++++++++++++++- 6 files changed, 408 insertions(+), 42 deletions(-) create mode 100644 a2a/src/main/java/com/google/adk/a2a/converters/A2AMetadataKey.java create mode 100644 a2a/src/main/java/com/google/adk/a2a/converters/AdkMetadataKey.java diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/A2AMetadataKey.java b/a2a/src/main/java/com/google/adk/a2a/converters/A2AMetadataKey.java new file mode 100644 index 000000000..d4f1fef58 --- /dev/null +++ b/a2a/src/main/java/com/google/adk/a2a/converters/A2AMetadataKey.java @@ -0,0 +1,40 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.a2a.converters; + +/** + * Enum for the type of A2A metadata. Adds a prefix used to differentiate ADK-related values stored + * in metadata of an A2A event.
+ */ +public enum A2AMetadataKey { + TYPE("type"), + IS_LONG_RUNNING("is_long_running"), + PARTIAL("partial"), + GROUNDING_METADATA("grounding_metadata"), + USAGE_METADATA("usage_metadata"), + CUSTOM_METADATA("custom_metadata"), + ERROR_CODE("error_code"); + + private final String type; + + private A2AMetadataKey(String type) { + this.type = "adk_" + type; + } + + public String getType() { + return type; + } +} diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/AdkMetadataKey.java b/a2a/src/main/java/com/google/adk/a2a/converters/AdkMetadataKey.java new file mode 100644 index 000000000..e38f28828 --- /dev/null +++ b/a2a/src/main/java/com/google/adk/a2a/converters/AdkMetadataKey.java @@ -0,0 +1,35 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.a2a.converters; + +/** + * Enum for the type of ADK metadata. Adds a prefix used to differentiate A2A-related values stored + * in custom metadata of an ADK session event. 
+ */ +public enum AdkMetadataKey { + TASK_ID("task_id"), + CONTEXT_ID("context_id"); + + private final String type; + + private AdkMetadataKey(String type) { + this.type = "a2a:" + type; + } + + public String getType() { + return type; + } +} diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java index 61f24fa21..714a79736 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java @@ -52,11 +52,7 @@ public final class PartConverter { private static final Logger logger = LoggerFactory.getLogger(PartConverter.class); private static final ObjectMapper objectMapper = new ObjectMapper(); - // Constants for metadata types. By convention metadata keys are prefixed with "adk_" to align - // with the Python and Golang libraries. - public static final String A2A_DATA_PART_METADATA_TYPE_KEY = "adk_type"; - public static final String A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY = "adk_is_long_running"; - public static final String A2A_DATA_PART_METADATA_IS_PARTIAL_KEY = "adk_partial"; + // Constants for metadata types. 
public static final String LANGUAGE_KEY = "language"; public static final String OUTCOME_KEY = "outcome"; public static final String CODE_KEY = "code"; @@ -135,7 +131,7 @@ private static com.google.genai.types.Part convertDataPartToGenAiPart(DataPart d Map metadata = Optional.ofNullable(dataPart.getMetadata()).map(HashMap::new).orElseGet(HashMap::new); - String metadataType = metadata.getOrDefault(A2A_DATA_PART_METADATA_TYPE_KEY, "").toString(); + String metadataType = metadata.getOrDefault(A2AMetadataKey.TYPE.getType(), "").toString(); if ((data.containsKey(NAME_KEY) && data.containsKey(ARGS_KEY)) || metadataType.equals(A2ADataPartMetadataType.FUNCTION_CALL.getType())) { @@ -218,7 +214,7 @@ private static DataPart createDataPartFromFunctionCall( addValueIfPresent(data, WILL_CONTINUE_KEY, functionCall.willContinue()); addValueIfPresent(data, PARTIAL_ARGS_KEY, functionCall.partialArgs()); - metadata.put(A2A_DATA_PART_METADATA_TYPE_KEY, A2ADataPartMetadataType.FUNCTION_CALL.getType()); + metadata.put(A2AMetadataKey.TYPE.getType(), A2ADataPartMetadataType.FUNCTION_CALL.getType()); return new DataPart(data.buildOrThrow(), metadata.buildOrThrow()); } @@ -245,7 +241,7 @@ private static DataPart createDataPartFromFunctionResponse( addValueIfPresent(data, PARTS_KEY, functionResponse.parts()); metadata.put( - A2A_DATA_PART_METADATA_TYPE_KEY, A2ADataPartMetadataType.FUNCTION_RESPONSE.getType()); + A2AMetadataKey.TYPE.getType(), A2ADataPartMetadataType.FUNCTION_RESPONSE.getType()); return new DataPart(data.buildOrThrow(), metadata.buildOrThrow()); } @@ -268,7 +264,7 @@ private static DataPart createDataPartFromCodeExecutionResult( addValueIfPresent(data, OUTPUT_KEY, codeExecutionResult.output()); metadata.put( - A2A_DATA_PART_METADATA_TYPE_KEY, A2ADataPartMetadataType.CODE_EXECUTION_RESULT.getType()); + A2AMetadataKey.TYPE.getType(), A2ADataPartMetadataType.CODE_EXECUTION_RESULT.getType()); return new DataPart(data.buildOrThrow(), metadata.buildOrThrow()); } @@ -290,8 +286,7 
@@ private static DataPart createDataPartFromExecutableCode( .orElse(Language.Known.LANGUAGE_UNSPECIFIED.toString())); addValueIfPresent(data, CODE_KEY, executableCode.code()); - metadata.put( - A2A_DATA_PART_METADATA_TYPE_KEY, A2ADataPartMetadataType.EXECUTABLE_CODE.getType()); + metadata.put(A2AMetadataKey.TYPE.getType(), A2ADataPartMetadataType.EXECUTABLE_CODE.getType()); return new DataPart(data.buildOrThrow(), metadata.buildOrThrow()); } @@ -305,7 +300,7 @@ public static io.a2a.spec.Part fromGenaiPart(Part part, boolean isPartial) { } ImmutableMap.Builder metadata = ImmutableMap.builder(); if (isPartial) { - metadata.put(A2A_DATA_PART_METADATA_IS_PARTIAL_KEY, true); + metadata.put(A2AMetadataKey.PARTIAL.getType(), true); } if (part.text().isPresent()) { diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java index 503432a30..cffd76983 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java @@ -19,12 +19,20 @@ import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Streams.zip; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.adk.agents.InvocationContext; import com.google.adk.events.Event; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.genai.types.Content; +import com.google.genai.types.CustomMetadata; +import com.google.genai.types.FinishReason; +import com.google.genai.types.GenerateContentResponseUsageMetadata; +import com.google.genai.types.GroundingMetadata; import com.google.genai.types.Part; import 
io.a2a.client.ClientEvent; import io.a2a.client.MessageEvent; @@ -43,11 +51,13 @@ import java.util.Objects; import java.util.Optional; import java.util.UUID; +import org.jspecify.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** Utility for converting ADK events to A2A spec messages (and back). */ public final class ResponseConverter { + private static final ObjectMapper objectMapper = new ObjectMapper(); private static final Logger logger = LoggerFactory.getLogger(ResponseConverter.class); private static final ImmutableSet PENDING_STATES = ImmutableSet.of(TaskState.WORKING, TaskState.SUBMITTED); @@ -74,12 +84,11 @@ public static Optional clientEventToEvent( throw new IllegalArgumentException("Unsupported ClientEvent type: " + event.getClass()); } - private static boolean isPartial(Map metadata) { + private static boolean isPartial(@Nullable Map metadata) { if (metadata == null) { return false; } - return Objects.equals( - metadata.getOrDefault(PartConverter.A2A_DATA_PART_METADATA_IS_PARTIAL_KEY, false), true); + return Objects.equals(metadata.getOrDefault(A2AMetadataKey.PARTIAL.getType(), false), true); } /** @@ -110,7 +119,12 @@ private static Optional handleTaskUpdate( // append=false, lastChunk=false: emit as partial, reset aggregation // append=true, lastChunk=true: emit as partial, update aggregation and emit as non-partial // append=false, lastChunk=true: emit as non-partial, drop aggregation - return Optional.of(eventPart); + return Optional.of( + updateEventMetadata( + eventPart, + artifactEvent.getMetadata(), + artifactEvent.getTaskId(), + artifactEvent.getContextId())); } if (updateEvent instanceof TaskStatusUpdateEvent statusEvent) { @@ -128,14 +142,21 @@ private static Optional handleTaskUpdate( }); if (statusEvent.isFinal()) { - return messageEvent - .map(Event::toBuilder) - .or(() -> Optional.of(remoteAgentEventBuilder(context))) - .map(builder -> builder.turnComplete(true)) - .map(builder -> builder.partial(false)) - 
.map(Event.Builder::build); + messageEvent = + messageEvent + .map(Event::toBuilder) + .or(() -> Optional.of(remoteAgentEventBuilder(context))) + .map(builder -> builder.turnComplete(true)) + .map(builder -> builder.partial(false)) + .map(Event.Builder::build); } - return messageEvent; + return messageEvent.map( + finalMessageEvent -> + updateEventMetadata( + finalMessageEvent, + statusEvent.getMetadata(), + statusEvent.getTaskId(), + statusEvent.getContextId())); } throw new IllegalArgumentException( "Unsupported TaskUpdateEvent type: " + updateEvent.getClass()); @@ -163,9 +184,13 @@ public static Event messageToFailedEvent(Message message, InvocationContext invo /** Converts an A2A message back to ADK events. */ public static Event messageToEvent(Message message, InvocationContext invocationContext) { - return remoteAgentEventBuilder(invocationContext) - .content(fromModelParts(PartConverter.toGenaiParts(message.getParts()))) - .build(); + return updateEventMetadata( + remoteAgentEventBuilder(invocationContext) + .content(fromModelParts(PartConverter.toGenaiParts(message.getParts()))) + .build(), + message.getMetadata(), + message.getTaskId(), + message.getContextId()); } /** @@ -228,7 +253,8 @@ public static Event taskToEvent(Task task, InvocationContext invocationContext) eventBuilder.longRunningToolIds(longRunningToolIds.build()); } eventBuilder.turnComplete(isFinal); - return eventBuilder.build(); + return updateEventMetadata( + eventBuilder.build(), task.getMetadata(), task.getId(), task.getContextId()); } private static ImmutableSet getLongRunningToolIds( @@ -241,9 +267,7 @@ private static ImmutableSet getLongRunningToolIds( return Optional.empty(); } Object isLongRunning = - dataPart - .getMetadata() - .get(PartConverter.A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY); + dataPart.getMetadata().get(A2AMetadataKey.IS_LONG_RUNNING.getType()); if (!Objects.equals(isLongRunning, true)) { return Optional.empty(); } @@ -256,6 +280,77 @@ private static ImmutableSet 
getLongRunningToolIds( .collect(toImmutableSet()); } + private static Event updateEventMetadata( + Event event, + @Nullable Map clientMetadata, + @Nullable String taskId, + @Nullable String contextId) { + if (taskId == null || contextId == null) { + logger.warn("Task ID or context ID is null, skipping metadata update."); + return event; + } + + if (clientMetadata == null) { + clientMetadata = ImmutableMap.of(); + } + Event.Builder eventBuilder = event.toBuilder(); + Object groundingMetadata = clientMetadata.get(A2AMetadataKey.GROUNDING_METADATA.getType()); + // if groundingMetadata is null, parseMetadata will return null as well. + eventBuilder.groundingMetadata(parseMetadata(groundingMetadata, GroundingMetadata.class)); + Object usageMetadata = clientMetadata.get(A2AMetadataKey.USAGE_METADATA.getType()); + // if usageMetadata is null, parseMetadata will return null as well. + eventBuilder.usageMetadata( + parseMetadata(usageMetadata, GenerateContentResponseUsageMetadata.class)); + + ImmutableList.Builder customMetadataList = ImmutableList.builder(); + customMetadataList + .add( + CustomMetadata.builder() + .key(AdkMetadataKey.TASK_ID.getType()) + .stringValue(taskId) + .build()) + .add( + CustomMetadata.builder() + .key(AdkMetadataKey.CONTEXT_ID.getType()) + .stringValue(contextId) + .build()); + Object customMetadata = clientMetadata.get(A2AMetadataKey.CUSTOM_METADATA.getType()); + if (customMetadata != null) { + customMetadataList.addAll( + parseMetadata(customMetadata, new TypeReference>() {})); + } + eventBuilder.customMetadata(customMetadataList.build()); + + Object errorCode = clientMetadata.get(A2AMetadataKey.ERROR_CODE.getType()); + eventBuilder.errorCode(parseMetadata(errorCode, FinishReason.class)); + + return eventBuilder.build(); + } + + private static @Nullable T parseMetadata(@Nullable Object metadata, Class type) { + try { + if (metadata instanceof String jsonString) { + return objectMapper.readValue(jsonString, type); + } else { + return 
objectMapper.convertValue(metadata, type); + } + } catch (IllegalArgumentException | JsonProcessingException e) { + throw new IllegalArgumentException("Failed to parse metadata of type " + type, e); + } + } + + private static @Nullable T parseMetadata(@Nullable Object metadata, TypeReference type) { + try { + if (metadata instanceof String jsonString) { + return objectMapper.readValue(jsonString, type); + } else { + return objectMapper.convertValue(metadata, type); + } + } catch (IllegalArgumentException | JsonProcessingException e) { + throw new IllegalArgumentException("Failed to parse metadata of type " + type.getType(), e); + } + } + private static Event emptyEvent(InvocationContext invocationContext) { Event.Builder builder = Event.builder() diff --git a/a2a/src/test/java/com/google/adk/a2a/converters/PartConverterTest.java b/a2a/src/test/java/com/google/adk/a2a/converters/PartConverterTest.java index d93466dd2..4a0828c43 100644 --- a/a2a/src/test/java/com/google/adk/a2a/converters/PartConverterTest.java +++ b/a2a/src/test/java/com/google/adk/a2a/converters/PartConverterTest.java @@ -8,9 +8,13 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Blob; +import com.google.genai.types.CodeExecutionResult; +import com.google.genai.types.ExecutableCode; import com.google.genai.types.FileData; import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; +import com.google.genai.types.Language; +import com.google.genai.types.Outcome; import com.google.genai.types.Part; import io.a2a.spec.DataPart; import io.a2a.spec.FilePart; @@ -86,8 +90,7 @@ public void toGenaiPart_withDataPartFunctionCall_returnsGenaiFunctionCallPart() new DataPart( data, ImmutableMap.of( - PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY, - A2ADataPartMetadataType.FUNCTION_CALL.getType())); + A2AMetadataKey.TYPE.getType(), A2ADataPartMetadataType.FUNCTION_CALL.getType())); Part result = 
PartConverter.toGenaiPart(dataPart); @@ -121,7 +124,7 @@ public void toGenaiPart_withDataPartFunctionResponse_returnsGenaiFunctionRespons new DataPart( data, ImmutableMap.of( - PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY, + A2AMetadataKey.TYPE.getType(), A2ADataPartMetadataType.FUNCTION_RESPONSE.getType())); Part result = PartConverter.toGenaiPart(dataPart); @@ -188,7 +191,7 @@ public void fromGenaiPart_withTextPart_returnsTextPart() { assertThat(((TextPart) result).getText()).isEqualTo("text"); assertThat(((TextPart) result).getMetadata()).containsEntry("thought", true); assertThat(((TextPart) result).getMetadata()) - .containsEntry(PartConverter.A2A_DATA_PART_METADATA_IS_PARTIAL_KEY, true); + .containsEntry(A2AMetadataKey.PARTIAL.getType(), true); } @Test @@ -226,6 +229,39 @@ public void fromGenaiPart_withInlineDataPart_returnsFilePartWithBytes() { assertThat(Base64.getDecoder().decode(fileWithBytes.bytes())).isEqualTo(bytes); } + @Test + public void fromGenaiPart_dataPart_executableCode_returnsDataPart() { + ExecutableCode executableCode = + ExecutableCode.builder().code("print('hello')").language(new Language("python")).build(); + Part part = Part.builder().executableCode(executableCode).build(); + io.a2a.spec.Part result = PartConverter.fromGenaiPart(part, false); + + assertThat(result).isInstanceOf(DataPart.class); + DataPart dataPart = (DataPart) result; + assertThat(dataPart.getData().get("code")).isEqualTo("print('hello')"); + assertThat(dataPart.getData().get("language")).isEqualTo("python"); + assertThat(dataPart.getMetadata().get(A2AMetadataKey.TYPE.getType())) + .isEqualTo("executable_code"); + } + + @Test + public void fromGenaiPart_dataPart_codeExecutionResult_returnsDataPart() { + CodeExecutionResult codeExecutionResult = + CodeExecutionResult.builder() + .outcome(new Outcome("OUTCOME_OK")) + .output("print('hello')") + .build(); + Part part = Part.builder().codeExecutionResult(codeExecutionResult).build(); + io.a2a.spec.Part result = 
PartConverter.fromGenaiPart(part, false); + + assertThat(result).isInstanceOf(DataPart.class); + DataPart dataPart = (DataPart) result; + assertThat(dataPart.getData().get("outcome")).isEqualTo("OUTCOME_OK"); + assertThat(dataPart.getData().get("output")).isEqualTo("print('hello')"); + assertThat(dataPart.getMetadata().get(A2AMetadataKey.TYPE.getType())) + .isEqualTo("code_execution_result"); + } + @Test public void fromGenaiPart_withFunctionCallPart_returnsDataPart() { Part part = @@ -255,8 +291,7 @@ public void fromGenaiPart_withFunctionCallPart_returnsDataPart() { true); assertThat(dataPart.getMetadata()) .containsEntry( - PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY, - A2ADataPartMetadataType.FUNCTION_CALL.getType()); + A2AMetadataKey.TYPE.getType(), A2ADataPartMetadataType.FUNCTION_CALL.getType()); } @Test @@ -275,8 +310,7 @@ public void fromGenaiPart_withFunctionResponsePart_returnsDataPart() { .containsExactly("name", "func", "id", "1", "response", ImmutableMap.of()); assertThat(dataPart.getMetadata()) .containsEntry( - PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY, - A2ADataPartMetadataType.FUNCTION_RESPONSE.getType()); + A2AMetadataKey.TYPE.getType(), A2ADataPartMetadataType.FUNCTION_RESPONSE.getType()); } @Test diff --git a/a2a/src/test/java/com/google/adk/a2a/converters/ResponseConverterTest.java b/a2a/src/test/java/com/google/adk/a2a/converters/ResponseConverterTest.java index d84dc42cd..b61b00e1a 100644 --- a/a2a/src/test/java/com/google/adk/a2a/converters/ResponseConverterTest.java +++ b/a2a/src/test/java/com/google/adk/a2a/converters/ResponseConverterTest.java @@ -1,6 +1,8 @@ package com.google.adk.a2a.converters; import static com.google.common.truth.Truth.assertThat; +import static java.util.stream.Collectors.joining; +import static org.junit.Assert.assertThrows; import com.google.adk.agents.BaseAgent; import com.google.adk.agents.InvocationContext; @@ -13,6 +15,10 @@ import com.google.common.collect.ImmutableList; import 
com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; +import com.google.genai.types.CustomMetadata; +import com.google.genai.types.FinishReason; +import com.google.genai.types.GenerateContentResponseUsageMetadata; +import com.google.genai.types.GroundingMetadata; import io.a2a.client.MessageEvent; import io.a2a.client.TaskUpdateEvent; import io.a2a.spec.Artifact; @@ -136,6 +142,74 @@ public void taskToEvent_withStatusMessage_returnsEvent() { assertThat(event.content().get().parts().get().get(0).text()).hasValue("Status message"); } + @Test + public void taskToEvent_withGroundingMetadata_returnsEvent() { + GroundingMetadata groundingMetadata = + GroundingMetadata.builder().webSearchQueries("test-query").build(); + Message statusMessage = + new Message.Builder() + .role(Message.Role.AGENT) + .parts(ImmutableList.of(new TextPart("Status message"))) + .build(); + TaskStatus status = new TaskStatus(TaskState.WORKING, statusMessage, null); + Task task = + testTask() + .status(status) + .artifacts(null) + .metadata( + ImmutableMap.of( + A2AMetadataKey.GROUNDING_METADATA.getType(), groundingMetadata.toJson())) + .build(); + Event event = ResponseConverter.taskToEvent(task, invocationContext); + assertThat(event).isNotNull(); + assertThat(event.content().get().parts().get().get(0).text()).hasValue("Status message"); + assertThat(event.groundingMetadata()).hasValue(groundingMetadata); + } + + @Test + public void taskToEvent_withCustomMetadata_returnsEvent() { + ImmutableList customMetadataList = + ImmutableList.of( + CustomMetadata.builder().key("test-key").stringValue("test-value").build()); + String customMetadataJson = + customMetadataList.stream().map(CustomMetadata::toJson).collect(joining(",", "[", "]")); + Message statusMessage = + new Message.Builder() + .role(Message.Role.AGENT) + .parts(ImmutableList.of(new TextPart("Status message"))) + .build(); + TaskStatus status = new TaskStatus(TaskState.WORKING, statusMessage, null); + Task task = + 
testTask() + .status(status) + .artifacts(null) + .metadata(ImmutableMap.of(A2AMetadataKey.CUSTOM_METADATA.getType(), customMetadataJson)) + .build(); + Event event = ResponseConverter.taskToEvent(task, invocationContext); + assertThat(event).isNotNull(); + assertThat(event.content().get().parts().get().get(0).text()).hasValue("Status message"); + assertThat(event.customMetadata().get()) + .containsExactly( + CustomMetadata.builder().key("a2a:task_id").stringValue("task-1").build(), + CustomMetadata.builder().key("a2a:context_id").stringValue("context-1").build(), + CustomMetadata.builder().key("test-key").stringValue("test-value").build()) + .inOrder(); + } + + @Test + public void messageToEvent_withMissingTaskId_returnsEvent() { + Message a2aMessage = + new Message.Builder() + .messageId("msg-1") + .role(Message.Role.USER) + .taskId("task-1") + .parts(ImmutableList.of(new TextPart("test-message"))) + .build(); + Event event = ResponseConverter.messageToEvent(a2aMessage, invocationContext); + assertThat(event).isNotNull(); + assertThat(event.customMetadata()).isEmpty(); + } + @Test public void taskToEvent_withNoMessage_returnsEmptyEvent() { TaskStatus status = new TaskStatus(TaskState.WORKING, null, null); @@ -152,18 +226,18 @@ public void taskToEvent_withInputRequired_parsesLongRunningToolIds() { ImmutableMap.of("name", "myTool", "id", "call_123", "args", ImmutableMap.of()); ImmutableMap metadata = ImmutableMap.of( - PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY, + A2AMetadataKey.TYPE.getType(), "function_call", - PartConverter.A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY, + A2AMetadataKey.IS_LONG_RUNNING.getType(), true); DataPart dataPart = new DataPart(data, metadata); ImmutableMap statusData = ImmutableMap.of("name", "messageTools", "id", "msg_123", "args", ImmutableMap.of()); ImmutableMap statusMetadata = ImmutableMap.of( - PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY, + A2AMetadataKey.TYPE.getType(), "function_call", - 
PartConverter.A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY, + A2AMetadataKey.IS_LONG_RUNNING.getType(), true); DataPart statusDataPart = new DataPart(statusData, statusMetadata); Message statusMessage = @@ -361,6 +435,99 @@ public void clientEventToEvent_withFailedTaskStatusUpdateEvent_returnsErrorEvent assertThat(resultEvent.turnComplete()).hasValue(true); } + @Test + public void taskToEvent_withInvalidMetadata_throwsException() { + Message statusMessage = + new Message.Builder() + .role(Message.Role.AGENT) + .parts(ImmutableList.of(new TextPart("Status message"))) + .build(); + TaskStatus status = new TaskStatus(TaskState.WORKING, statusMessage, null); + Task task = + testTask() + .status(status) + .artifacts(null) + .metadata( + ImmutableMap.of(A2AMetadataKey.GROUNDING_METADATA.getType(), "{ invalid json ]")) + .build(); + + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> ResponseConverter.taskToEvent(task, invocationContext)); + assertThat(exception).hasMessageThat().contains("Failed to parse metadata"); + assertThat(exception).hasMessageThat().contains("GroundingMetadata"); + } + + @Test + public void taskToEvent_withErrorCode_returnsEvent() { + Message statusMessage = + new Message.Builder() + .role(Message.Role.AGENT) + .parts(ImmutableList.of(new TextPart("Status message"))) + .build(); + TaskStatus status = new TaskStatus(TaskState.WORKING, statusMessage, null); + Task task = + testTask() + .status(status) + .artifacts(null) + .metadata(ImmutableMap.of(A2AMetadataKey.ERROR_CODE.getType(), "\"STOP\"")) + .build(); + Event event = ResponseConverter.taskToEvent(task, invocationContext); + assertThat(event).isNotNull(); + assertThat(event.errorCode()).hasValue(new FinishReason(FinishReason.Known.STOP)); + } + + @Test + public void taskToEvent_withUsageMetadata_returnsEvent() { + GenerateContentResponseUsageMetadata usageMetadata = + GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(10) + 
.candidatesTokenCount(20) + .totalTokenCount(30) + .build(); + Message statusMessage = + new Message.Builder() + .role(Message.Role.AGENT) + .parts(ImmutableList.of(new TextPart("Status message"))) + .build(); + TaskStatus status = new TaskStatus(TaskState.WORKING, statusMessage, null); + Task task = + testTask() + .status(status) + .artifacts(null) + .metadata( + ImmutableMap.of(A2AMetadataKey.USAGE_METADATA.getType(), usageMetadata.toJson())) + .build(); + Event event = ResponseConverter.taskToEvent(task, invocationContext); + assertThat(event).isNotNull(); + assertThat(event.usageMetadata()).hasValue(usageMetadata); + } + + @Test + public void clientEventToEvent_withTaskArtifactUpdateEventAndPartialTrue_returnsEmpty() { + io.a2a.spec.Part a2aPart = new TextPart("Artifact content"); + Artifact artifact = + new Artifact.Builder().artifactId("artifact-1").parts(ImmutableList.of(a2aPart)).build(); + Task task = + testTask() + .status(new TaskStatus(TaskState.COMPLETED)) + .artifacts(ImmutableList.of(artifact)) + .build(); + TaskArtifactUpdateEvent updateEvent = + new TaskArtifactUpdateEvent.Builder() + .lastChunk(true) + .metadata(ImmutableMap.of(A2AMetadataKey.PARTIAL.getType(), true)) + .contextId("context-1") + .artifact(artifact) + .taskId("task-id-1") + .build(); + TaskUpdateEvent event = new TaskUpdateEvent(task, updateEvent); + + Optional optionalEvent = ResponseConverter.clientEventToEvent(event, invocationContext); + assertThat(optionalEvent).isEmpty(); + } + private static final class TestAgent extends BaseAgent { TestAgent() { super("test_agent", "test", ImmutableList.of(), null, null); From 0d1e5c7b0c42cea66b178cf8fedf08a8c20f7fd0 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 18 Mar 2026 07:02:14 -0700 Subject: [PATCH 10/40] feat: update stateDelta builder input to Map from ConcurrentMap PiperOrigin-RevId: 885570460 --- .../main/java/com/google/adk/events/EventActions.java | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) 
diff --git a/core/src/main/java/com/google/adk/events/EventActions.java b/core/src/main/java/com/google/adk/events/EventActions.java index 1ca856b45..3565c3e99 100644 --- a/core/src/main/java/com/google/adk/events/EventActions.java +++ b/core/src/main/java/com/google/adk/events/EventActions.java @@ -287,8 +287,14 @@ public Builder skipSummarization(boolean skipSummarization) { @CanIgnoreReturnValue @JsonProperty("stateDelta") - public Builder stateDelta(ConcurrentMap value) { - this.stateDelta = value; + public Builder stateDelta(@Nullable Map value) { + if (value == null) { + this.stateDelta = new ConcurrentHashMap<>(); + } else if (value instanceof ConcurrentMap) { + this.stateDelta = (ConcurrentMap) value; + } else { + this.stateDelta = new ConcurrentHashMap<>(value); + } return this; } From de3b2767748436b07f55e7d00034d77d7d940579 Mon Sep 17 00:00:00 2001 From: Greg Brail Date: Mon, 12 Jan 2026 15:26:20 -0800 Subject: [PATCH 11/40] Remove ADK dependency for langchain4j module --- contrib/langchain4j/pom.xml | 5 ----- 1 file changed, 5 deletions(-) diff --git a/contrib/langchain4j/pom.xml b/contrib/langchain4j/pom.xml index c2326fa0a..3dd2d1132 100644 --- a/contrib/langchain4j/pom.xml +++ b/contrib/langchain4j/pom.xml @@ -58,11 +58,6 @@ google-adk ${project.version} - - com.google.adk - google-adk-dev - ${project.version} - com.google.genai google-genai From 3ba04d33dc8f2ef8b151abe1be4d1c8b7afcc25a Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 18 Mar 2026 07:59:31 -0700 Subject: [PATCH 12/40] fix: workaround for the client config streaming settings are not respected (#983) PiperOrigin-RevId: 885595843 --- .../com/google/adk/a2a/agent/RemoteA2AAgent.java | 13 ++++++++++++- .../google/adk/a2a/agent/RemoteA2AAgentTest.java | 16 +++++++++++++++- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java b/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java index 
ccb662b7c..4a375980c 100644 --- a/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java +++ b/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java @@ -117,7 +117,7 @@ private RemoteA2AAgent(Builder builder) { if (this.description.isEmpty() && this.agentCard.description() != null) { this.description = this.agentCard.description(); } - this.streaming = this.agentCard.capabilities().streaming(); + this.streaming = builder.streaming && this.agentCard.capabilities().streaming(); } public static Builder builder() { @@ -133,6 +133,13 @@ public static class Builder { private List subAgents; private List beforeAgentCallback; private List afterAgentCallback; + private boolean streaming; + + @CanIgnoreReturnValue + public Builder streaming(boolean streaming) { + this.streaming = streaming; + return this; + } @CanIgnoreReturnValue public Builder name(String name) { @@ -181,6 +188,10 @@ public RemoteA2AAgent build() { } } + public boolean isStreaming() { + return streaming; + } + private Message.Builder newA2AMessage(Message.Role role, List> parts) { return new Message.Builder().messageId(UUID.randomUUID().toString()).role(role).parts(parts); } diff --git a/a2a/src/test/java/com/google/adk/a2a/agent/RemoteA2AAgentTest.java b/a2a/src/test/java/com/google/adk/a2a/agent/RemoteA2AAgentTest.java index b1ffa248a..0609c3b04 100644 --- a/a2a/src/test/java/com/google/adk/a2a/agent/RemoteA2AAgentTest.java +++ b/a2a/src/test/java/com/google/adk/a2a/agent/RemoteA2AAgentTest.java @@ -113,6 +113,20 @@ public void setUp() { .build(); } + @Test + public void createAgent_streaming_false_returnsNonStreamingAgent() { + // With streaming false, the agent should not stream even if the AgentCard supports streaming. 
+ RemoteA2AAgent agent = getAgentBuilder().streaming(false).build(); + assertThat(agent.isStreaming()).isFalse(); + } + + @Test + public void createAgent_streaming_true_returnsStreamingAgent() { + // With streaming true, the agent should support streaming if the AgentCard supports streaming. + RemoteA2AAgent agent = getAgentBuilder().streaming(true).build(); + assertThat(agent.isStreaming()).isTrue(); + } + @Test public void runAsync_aggregatesPartialEvents() { RemoteA2AAgent agent = createAgent(); @@ -763,7 +777,7 @@ private RemoteA2AAgent.Builder getAgentBuilder() { } private RemoteA2AAgent createAgent() { - return getAgentBuilder().build(); + return getAgentBuilder().streaming(true).build(); } @SuppressWarnings("unchecked") // cast for Mockito From 94de7f199f86b39bdb7cce6e9800eb05008a8953 Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Wed, 18 Mar 2026 09:37:12 -0700 Subject: [PATCH 13/40] fix: Use ConcurrentHashMap in InvocationReplayState fixes #1009 PiperOrigin-RevId: 885641755 --- .../java/com/google/adk/plugins/InvocationReplayState.java | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/dev/src/main/java/com/google/adk/plugins/InvocationReplayState.java b/dev/src/main/java/com/google/adk/plugins/InvocationReplayState.java index 7d70a0efb..eab293de0 100644 --- a/dev/src/main/java/com/google/adk/plugins/InvocationReplayState.java +++ b/dev/src/main/java/com/google/adk/plugins/InvocationReplayState.java @@ -16,8 +16,8 @@ package com.google.adk.plugins; import com.google.adk.plugins.recordings.Recordings; -import java.util.HashMap; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; /** Per-invocation replay state to isolate concurrent runs. 
*/ class InvocationReplayState { @@ -33,7 +33,7 @@ public InvocationReplayState(String testCasePath, int userMessageIndex, Recordin this.testCasePath = testCasePath; this.userMessageIndex = userMessageIndex; this.recordings = recordings; - this.agentReplayIndices = new HashMap<>(); + this.agentReplayIndices = new ConcurrentHashMap<>(); } public String getTestCasePath() { @@ -57,7 +57,6 @@ public void setAgentReplayIndex(String agentName, int index) { } public void incrementAgentReplayIndex(String agentName) { - int currentIndex = getAgentReplayIndex(agentName); - setAgentReplayIndex(agentName, currentIndex + 1); + agentReplayIndices.merge(agentName, 1, Integer::sum); } } From 2c71ba1332e052189115cd4644b7a473c31ed414 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fatih=20C=C3=BCre?= Date: Wed, 18 Mar 2026 10:06:35 +0300 Subject: [PATCH 14/40] feat: Enhance LangChain4j to support MCP tools with parametersJsonSchema --- .../adk/models/langchain4j/LangChain4j.java | 23 +++- .../models/langchain4j/LangChain4jTest.java | 124 ++++++++++++++++++ .../com/google/adk/models/LlmRequest.java | 2 +- 3 files changed, 144 insertions(+), 5 deletions(-) diff --git a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java index 80c25610d..3ccb1e029 100644 --- a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java +++ b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java @@ -18,6 +18,7 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.adk.JsonBaseModel; import com.google.adk.models.BaseLlm; import com.google.adk.models.BaseLlmConnection; import com.google.adk.models.LlmRequest; @@ -428,8 +429,24 @@ private List toToolSpecifications(LlmRequest llmRequest) { baseTool -> { if 
(baseTool.declaration().isPresent()) { FunctionDeclaration functionDeclaration = baseTool.declaration().get(); - if (functionDeclaration.parameters().isPresent()) { - Schema schema = functionDeclaration.parameters().get(); + Schema schema = null; + if (functionDeclaration.parametersJsonSchema().isPresent()) { + Object jsonSchemaObj = functionDeclaration.parametersJsonSchema().get(); + try { + if (jsonSchemaObj instanceof Schema) { + schema = (Schema) jsonSchemaObj; + } else { + schema = JsonBaseModel.getMapper().convertValue(jsonSchemaObj, Schema.class); + } + } catch (Exception e) { + throw new IllegalStateException( + "Failed to convert parametersJsonSchema to Schema: " + e.getMessage(), e); + } + } else if (functionDeclaration.parameters().isPresent()) { + schema = functionDeclaration.parameters().get(); + } + + if (schema != null) { ToolSpecification toolSpecification = ToolSpecification.builder() .name(baseTool.name()) @@ -438,11 +455,9 @@ private List toToolSpecifications(LlmRequest llmRequest) { .build(); toolSpecifications.add(toolSpecification); } else { - // TODO exception or something else? throw new IllegalStateException("Tool lacking parameters: " + baseTool); } } else { - // TODO exception or something else? 
throw new IllegalStateException("Tool lacking declaration: " + baseTool); } }); diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java index 428a5660c..076bb79a3 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java @@ -688,4 +688,128 @@ void testGenerateContentWithStructuredResponseJsonSchema() { final UserMessage userMessage = (UserMessage) capturedRequest.messages().get(0); assertThat(userMessage.singleText()).isEqualTo("Give me information about John Doe"); } + + @Test + @DisplayName("Should handle MCP tools with parametersJsonSchema") + void testGenerateContentWithMcpToolParametersJsonSchema() { + // Given + // Create a mock BaseTool for MCP tool + final com.google.adk.tools.BaseTool mcpTool = mock(com.google.adk.tools.BaseTool.class); + when(mcpTool.name()).thenReturn("mcpTool"); + when(mcpTool.description()).thenReturn("An MCP tool"); + + // Create a mock FunctionDeclaration + final FunctionDeclaration functionDeclaration = mock(FunctionDeclaration.class); + when(mcpTool.declaration()).thenReturn(Optional.of(functionDeclaration)); + + // MCP tools use parametersJsonSchema() instead of parameters() + // Create a JSON schema object (Map representation) + final Map jsonSchemaMap = + Map.of( + "type", + "object", + "properties", + Map.of("city", Map.of("type", "string", "description", "City name")), + "required", + List.of("city")); + + // Mock parametersJsonSchema() to return the JSON schema object + when(functionDeclaration.parametersJsonSchema()).thenReturn(Optional.of(jsonSchemaMap)); + when(functionDeclaration.parameters()).thenReturn(Optional.empty()); + + // Create a LlmRequest with the MCP tool + final LlmRequest llmRequest = + LlmRequest.builder() + 
.contents(List.of(Content.fromParts(Part.fromText("Use the MCP tool")))) + .tools(Map.of("mcpTool", mcpTool)) + .build(); + + // Mock the AI response + final AiMessage aiMessage = AiMessage.from("Tool executed successfully"); + + final ChatResponse chatResponse = mock(ChatResponse.class); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = langChain4j.generateContent(llmRequest, false).blockingFirst(); + + // Then + // Verify the response + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().text()).isEqualTo("Tool executed successfully"); + + // Verify the request was built correctly with the tool specification + final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(ChatRequest.class); + verify(chatModel).chat(requestCaptor.capture()); + final ChatRequest capturedRequest = requestCaptor.getValue(); + + // Verify tool specifications were created from parametersJsonSchema + assertThat(capturedRequest.toolSpecifications()).isNotEmpty(); + assertThat(capturedRequest.toolSpecifications().get(0).name()).isEqualTo("mcpTool"); + assertThat(capturedRequest.toolSpecifications().get(0).description()).isEqualTo("An MCP tool"); + } + + @Test + @DisplayName("Should handle MCP tools with parametersJsonSchema when it's already a Schema") + void testGenerateContentWithMcpToolParametersJsonSchemaAsSchema() { + // Given + // Create a mock BaseTool for MCP tool + final com.google.adk.tools.BaseTool mcpTool = mock(com.google.adk.tools.BaseTool.class); + when(mcpTool.name()).thenReturn("mcpTool"); + when(mcpTool.description()).thenReturn("An MCP tool"); + + // Create a mock FunctionDeclaration + final FunctionDeclaration functionDeclaration = mock(FunctionDeclaration.class); + when(mcpTool.declaration()).thenReturn(Optional.of(functionDeclaration)); + + // Create a Schema object directly (when 
parametersJsonSchema returns Schema) + final Schema cityPropertySchema = + Schema.builder().type("STRING").description("City name").build(); + + final Schema objectSchema = + Schema.builder() + .type("OBJECT") + .properties(Map.of("city", cityPropertySchema)) + .required(List.of("city")) + .build(); + + // Mock parametersJsonSchema() to return Schema directly + when(functionDeclaration.parametersJsonSchema()).thenReturn(Optional.of(objectSchema)); + when(functionDeclaration.parameters()).thenReturn(Optional.empty()); + + // Create a LlmRequest with the MCP tool + final LlmRequest llmRequest = + LlmRequest.builder() + .contents(List.of(Content.fromParts(Part.fromText("Use the MCP tool")))) + .tools(Map.of("mcpTool", mcpTool)) + .build(); + + // Mock the AI response + final AiMessage aiMessage = AiMessage.from("Tool executed successfully"); + + final ChatResponse chatResponse = mock(ChatResponse.class); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = langChain4j.generateContent(llmRequest, false).blockingFirst(); + + // Then + // Verify the response + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().text()).isEqualTo("Tool executed successfully"); + + // Verify the request was built correctly with the tool specification + final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(ChatRequest.class); + verify(chatModel).chat(requestCaptor.capture()); + final ChatRequest capturedRequest = requestCaptor.getValue(); + + // Verify tool specifications were created from parametersJsonSchema + assertThat(capturedRequest.toolSpecifications()).isNotEmpty(); + assertThat(capturedRequest.toolSpecifications().get(0).name()).isEqualTo("mcpTool"); + assertThat(capturedRequest.toolSpecifications().get(0).description()).isEqualTo("An MCP tool"); + } } diff --git 
a/core/src/main/java/com/google/adk/models/LlmRequest.java b/core/src/main/java/com/google/adk/models/LlmRequest.java index 1a45c3a95..760a7c1c6 100644 --- a/core/src/main/java/com/google/adk/models/LlmRequest.java +++ b/core/src/main/java/com/google/adk/models/LlmRequest.java @@ -150,7 +150,7 @@ private static Builder create() { abstract LiveConnectConfig liveConnectConfig(); @CanIgnoreReturnValue - abstract Builder tools(Map tools); + public abstract Builder tools(Map tools); abstract Map tools(); From fa67101fe0555e8bbed5cf304d00550a56308222 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 18 Mar 2026 14:05:17 -0700 Subject: [PATCH 15/40] ADK changes PiperOrigin-RevId: 885777704 --- .../google/adk/agents/InvocationContext.java | 20 ------------------- 1 file changed, 20 deletions(-) diff --git a/core/src/main/java/com/google/adk/agents/InvocationContext.java b/core/src/main/java/com/google/adk/agents/InvocationContext.java index 91ce13a87..365f4f8c1 100644 --- a/core/src/main/java/com/google/adk/agents/InvocationContext.java +++ b/core/src/main/java/com/google/adk/agents/InvocationContext.java @@ -90,16 +90,6 @@ public Builder toBuilder() { return new Builder(this); } - /** - * Creates a shallow copy of the given {@link InvocationContext}. - * - * @deprecated Use {@code other.toBuilder().build()} instead. - */ - @Deprecated(forRemoval = true) - public static InvocationContext copyOf(InvocationContext other) { - return other.toBuilder().build(); - } - /** Returns the session service for managing session state. */ public BaseSessionService sessionService() { return sessionService; @@ -156,16 +146,6 @@ public BaseAgent agent() { return agent; } - /** - * Sets the [agent] being invoked. This is useful when delegating to a sub-agent. - * - * @deprecated Use {@link #toBuilder()} and {@link Builder#agent(BaseAgent)} instead. 
- */ - @Deprecated(forRemoval = true) - public void agent(BaseAgent agent) { - this.agent = agent; - } - /** Returns the session associated with this invocation. */ public Session session() { return session; From d7e03eeb067b83abd2afa3ea9bb5fc1c16143245 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 19 Mar 2026 01:55:07 -0700 Subject: [PATCH 16/40] fix: Relaxing constraints for output schema These changes are now in sync with Python ADK PiperOrigin-RevId: 886040294 --- .../java/com/google/adk/agents/LlmAgent.java | 34 ------- .../com/google/adk/agents/LlmAgentTest.java | 98 +++++-------------- 2 files changed, 26 insertions(+), 106 deletions(-) diff --git a/core/src/main/java/com/google/adk/agents/LlmAgent.java b/core/src/main/java/com/google/adk/agents/LlmAgent.java index 89024a59b..b387aee34 100644 --- a/core/src/main/java/com/google/adk/agents/LlmAgent.java +++ b/core/src/main/java/com/google/adk/agents/LlmAgent.java @@ -594,40 +594,6 @@ protected void validate() { this.disallowTransferToParent != null && this.disallowTransferToParent; this.disallowTransferToPeers = this.disallowTransferToPeers != null && this.disallowTransferToPeers; - - if (this.outputSchema != null) { - if (!this.disallowTransferToParent || !this.disallowTransferToPeers) { - logger.warn( - "Invalid config for agent {}: outputSchema cannot co-exist with agent transfer" - + " configurations. 
Setting disallowTransferToParent=true and" - + " disallowTransferToPeers=true.", - this.name); - this.disallowTransferToParent = true; - this.disallowTransferToPeers = true; - } - - if (this.subAgents != null && !this.subAgents.isEmpty()) { - throw new IllegalArgumentException( - "Invalid config for agent " - + this.name - + ": if outputSchema is set, subAgents must be empty to disable agent" - + " transfer."); - } - if (this.toolsUnion != null && !this.toolsUnion.isEmpty()) { - boolean hasOtherTools = - this.toolsUnion.stream() - .anyMatch( - tool -> - !(tool instanceof BaseTool baseTool) - || !baseTool.name().equals("example_tool")); - if (hasOtherTools) { - throw new IllegalArgumentException( - "Invalid config for agent " - + this.name - + ": if outputSchema is set, tools must be empty."); - } - } - } } @Override diff --git a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java index c193e4a65..a9e7a6f8d 100644 --- a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java @@ -26,7 +26,6 @@ import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertThrows; import com.google.adk.agents.Callbacks.AfterModelCallback; import com.google.adk.agents.Callbacks.AfterToolCallback; @@ -52,9 +51,9 @@ import com.google.common.collect.ImmutableMap; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.genai.types.Content; -import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.Part; import com.google.genai.types.Schema; +import com.google.genai.types.Type; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; @@ -63,7 +62,6 @@ import 
io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; import java.util.List; -import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; import org.junit.After; @@ -213,75 +211,6 @@ public void run_withToolsAndMaxSteps_stopsAfterMaxSteps() { assertEqualIgnoringFunctionIds(events.get(3).content().get(), expectedFunctionResponseContent); } - @Test - public void build_withOutputSchemaAndTools_throwsIllegalArgumentException() { - BaseTool tool = - new BaseTool("test_tool", "test_description") { - @Override - public Optional declaration() { - return Optional.empty(); - } - }; - - Schema outputSchema = - Schema.builder() - .type("OBJECT") - .properties(ImmutableMap.of("status", Schema.builder().type("STRING").build())) - .required(ImmutableList.of("status")) - .build(); - - // Expecting an IllegalArgumentException when building the agent - IllegalArgumentException exception = - assertThrows( - IllegalArgumentException.class, - () -> - LlmAgent.builder() // Use the agent builder directly - .name("agent with invalid tool config") - .outputSchema(outputSchema) // Set the output schema - .tools(ImmutableList.of(tool)) // Set tools (this should cause the error) - .build()); // Attempt to build the agent - - assertThat(exception) - .hasMessageThat() - .contains( - "Invalid config for agent agent with invalid tool config: if outputSchema is set, tools" - + " must be empty"); - } - - @Test - public void build_withOutputSchemaAndSubAgents_throwsIllegalArgumentException() { - ImmutableList subAgents = - ImmutableList.of( - createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) - .name("test_sub_agent") - .description("test_sub_agent_description") - .build()); - - Schema outputSchema = - Schema.builder() - .type("OBJECT") - .properties(ImmutableMap.of("status", Schema.builder().type("STRING").build())) - .required(ImmutableList.of("status")) - .build(); - - // Expecting an 
IllegalArgumentException when building the agent - IllegalArgumentException exception = - assertThrows( - IllegalArgumentException.class, - () -> - LlmAgent.builder() // Use the agent builder directly - .name("agent with invalid tool config") - .outputSchema(outputSchema) // Set the output schema - .subAgents(subAgents) // Set subAgents (this should cause the error) - .build()); // Attempt to build the agent - - assertThat(exception) - .hasMessageThat() - .contains( - "Invalid config for agent agent with invalid tool config: if outputSchema is set," - + " subAgents must be empty to disable agent transfer."); - } - @Test public void testBuild_withNullInstruction_setsInstructionToEmptyString() { LlmAgent agent = @@ -645,6 +574,31 @@ public void runAsync_withSubAgents_createsSpans() throws InterruptedException { assertThat(llmSpans).hasSize(2); // One for main agent, one for sub agent } + @Test + public void run_outputSchemaWithTools_allowed() { + Schema personSchema = + Schema.builder() + .type(Type.Known.OBJECT) + .properties( + ImmutableMap.of( + "name", Schema.builder().type(Type.Known.STRING).build(), + "age", Schema.builder().type(Type.Known.INTEGER).build(), + "city", Schema.builder().type(Type.Known.STRING).build())) + .build(); + LlmAgent agent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .outputSchema(personSchema) + .tools(new EchoTool()) + .build(); + assertThat(agent.outputSchema()).hasValue(personSchema); + assertThat( + agent + .canonicalTools(new ReadonlyContext(createInvocationContext(agent))) + .count() + .blockingGet()) + .isEqualTo(1); + } + private List findSpansByName(List spans, String name) { return spans.stream().filter(s -> s.getName().equals(name)).toList(); } From e534f12bd5c7cadb8a6100b00ac2ae771a868ab0 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 19 Mar 2026 05:59:19 -0700 Subject: [PATCH 17/40] refactor: Update map handling in EventActions to always use defensive copy and add null handling for
`artifactDelta` in the Builder PiperOrigin-RevId: 886130618 --- .../sessions/FirestoreSessionServiceTest.java | 99 ------------------- .../com/google/adk/events/EventActions.java | 17 ++-- .../google/adk/events/EventActionsTest.java | 11 --- 3 files changed, 6 insertions(+), 121 deletions(-) diff --git a/contrib/firestore-session-service/src/test/java/com/google/adk/sessions/FirestoreSessionServiceTest.java b/contrib/firestore-session-service/src/test/java/com/google/adk/sessions/FirestoreSessionServiceTest.java index ffcbcf8f5..43ca6889f 100644 --- a/contrib/firestore-session-service/src/test/java/com/google/adk/sessions/FirestoreSessionServiceTest.java +++ b/contrib/firestore-session-service/src/test/java/com/google/adk/sessions/FirestoreSessionServiceTest.java @@ -47,15 +47,11 @@ import com.google.genai.types.Part; import io.reactivex.rxjava3.observers.TestObserver; import java.time.Instant; -import java.util.AbstractMap; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.Set; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -530,44 +526,6 @@ void appendAndGet_withAllPartTypes_serializesAndDeserializesCorrectly() { }); } - /** - * A wrapper class that implements ConcurrentMap but delegates to a HashMap. This is a workaround - * to allow putting null values, which ConcurrentHashMap forbids, for testing state removal logic. 
- */ - private static class HashMapAsConcurrentMap extends AbstractMap - implements ConcurrentMap { - private final HashMap map; - - public HashMapAsConcurrentMap(Map map) { - this.map = new HashMap<>(map); - } - - @Override - public Set> entrySet() { - return map.entrySet(); - } - - @Override - public V putIfAbsent(K key, V value) { - return map.putIfAbsent(key, value); - } - - @Override - public boolean remove(Object key, Object value) { - return map.remove(key, value); - } - - @Override - public boolean replace(K key, V oldValue, V newValue) { - return map.replace(key, oldValue, newValue); - } - - @Override - public V replace(K key, V value) { - return map.replace(key, value); - } - } - /** Tests that appendEvent with only app state deltas updates the correct stores. */ @Test void appendEvent_withAppOnlyStateDeltas_updatesCorrectStores() { @@ -662,63 +620,6 @@ void appendEvent_withUserOnlyStateDeltas_updatesCorrectStores() { verify(mockSessionDocRef, never()).update(eq(Constants.KEY_STATE), any()); } - /** - * Tests that appendEvent with all types of state deltas updates the correct stores and session - * state. 
- */ - @Test - void appendEvent_withAllStateDeltas_updatesCorrectStores() { - // Arrange - Session session = - Session.builder(SESSION_ID) - .appName(APP_NAME) - .userId(USER_ID) - .state(new ConcurrentHashMap<>()) // The session state itself must be concurrent - .build(); - session.state().put("keyToRemove", "someValue"); - - Map stateDeltaMap = new HashMap<>(); - stateDeltaMap.put("sessionKey", "sessionValue"); - stateDeltaMap.put("_app_appKey", "appValue"); - stateDeltaMap.put("_user_userKey", "userValue"); - stateDeltaMap.put("keyToRemove", null); - - // Use the wrapper to satisfy the ConcurrentMap interface for the builder - EventActions actions = - EventActions.builder().stateDelta(new HashMapAsConcurrentMap<>(stateDeltaMap)).build(); - - Event event = - Event.builder() - .author("model") - .content(Content.builder().parts(List.of(Part.fromText("..."))).build()) - .actions(actions) - .build(); - - when(mockSessionsCollection.document(SESSION_ID)).thenReturn(mockSessionDocRef); - when(mockEventsCollection.document()).thenReturn(mockEventDocRef); - when(mockEventDocRef.getId()).thenReturn(EVENT_ID); - // THIS IS THE MISSING MOCK: Stub the call to get the document by its specific ID. 
- when(mockEventsCollection.document(EVENT_ID)).thenReturn(mockEventDocRef); - // Add the missing mock for the final session update call - when(mockSessionDocRef.update(anyMap())) - .thenReturn(ApiFutures.immediateFuture(mockWriteResult)); - - // Act - sessionService.appendEvent(session, event).test().assertComplete(); - - // Assert - assertThat(session.state()).containsEntry("sessionKey", "sessionValue"); - assertThat(session.state()).doesNotContainKey("keyToRemove"); - - ArgumentCaptor> appStateCaptor = ArgumentCaptor.forClass(Map.class); - verify(mockAppStateDocRef).set(appStateCaptor.capture(), any(SetOptions.class)); - assertThat(appStateCaptor.getValue()).containsEntry("appKey", "appValue"); - - ArgumentCaptor> userStateCaptor = ArgumentCaptor.forClass(Map.class); - verify(mockUserStateUserDocRef).set(userStateCaptor.capture(), any(SetOptions.class)); - assertThat(userStateCaptor.getValue()).containsEntry("userKey", "userValue"); - } - /** Tests that getSession skips malformed events and returns only the well-formed ones. 
*/ @Test @SuppressWarnings("unchecked") diff --git a/core/src/main/java/com/google/adk/events/EventActions.java b/core/src/main/java/com/google/adk/events/EventActions.java index 3565c3e99..83fd60e54 100644 --- a/core/src/main/java/com/google/adk/events/EventActions.java +++ b/core/src/main/java/com/google/adk/events/EventActions.java @@ -157,9 +157,6 @@ public void setRequestedToolConfirmations( Map requestedToolConfirmations) { if (requestedToolConfirmations == null) { this.requestedToolConfirmations = new ConcurrentHashMap<>(); - } else if (requestedToolConfirmations instanceof ConcurrentMap) { - this.requestedToolConfirmations = - (ConcurrentMap) requestedToolConfirmations; } else { this.requestedToolConfirmations = new ConcurrentHashMap<>(requestedToolConfirmations); } @@ -290,8 +287,6 @@ public Builder skipSummarization(boolean skipSummarization) { public Builder stateDelta(@Nullable Map value) { if (value == null) { this.stateDelta = new ConcurrentHashMap<>(); - } else if (value instanceof ConcurrentMap) { - this.stateDelta = (ConcurrentMap) value; } else { this.stateDelta = new ConcurrentHashMap<>(value); } @@ -300,8 +295,12 @@ public Builder stateDelta(@Nullable Map value) { @CanIgnoreReturnValue @JsonProperty("artifactDelta") - public Builder artifactDelta(Map value) { - this.artifactDelta = new ConcurrentHashMap<>(value); + public Builder artifactDelta(@Nullable Map value) { + if (value == null) { + this.artifactDelta = new ConcurrentHashMap<>(); + } else { + this.artifactDelta = new ConcurrentHashMap<>(value); + } return this; } @@ -339,10 +338,6 @@ public Builder requestedAuthConfigs( public Builder requestedToolConfirmations(@Nullable Map value) { if (value == null) { this.requestedToolConfirmations = new ConcurrentHashMap<>(); - return this; - } - if (value instanceof ConcurrentMap) { - this.requestedToolConfirmations = (ConcurrentMap) value; } else { this.requestedToolConfirmations = new ConcurrentHashMap<>(value); } diff --git 
a/core/src/test/java/com/google/adk/events/EventActionsTest.java b/core/src/test/java/com/google/adk/events/EventActionsTest.java index 22bb94e64..b1e645e1a 100644 --- a/core/src/test/java/com/google/adk/events/EventActionsTest.java +++ b/core/src/test/java/com/google/adk/events/EventActionsTest.java @@ -177,17 +177,6 @@ public void merge_failsOnMismatchedKeyTypesNestedInStateDelta() { IllegalArgumentException.class, () -> eventActions1.toBuilder().merge(eventActions2)); } - @Test - public void setRequestedToolConfirmations_withConcurrentMap_usesSameInstance() { - ConcurrentHashMap map = new ConcurrentHashMap<>(); - map.put("tool", TOOL_CONFIRMATION); - - EventActions actions = new EventActions(); - actions.setRequestedToolConfirmations(map); - - assertThat(actions.requestedToolConfirmations()).isSameInstanceAs(map); - } - @Test public void setRequestedToolConfirmations_withRegularMap_createsConcurrentMap() { ImmutableMap map = ImmutableMap.of("tool", TOOL_CONFIRMATION); From cd56902b803d4f7a1f3c718529842823d9e4370a Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 19 Mar 2026 06:35:06 -0700 Subject: [PATCH 18/40] feat: Update return type of toolsets() from ImmutableList to List PiperOrigin-RevId: 886145022 --- core/src/main/java/com/google/adk/agents/LlmAgent.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/java/com/google/adk/agents/LlmAgent.java b/core/src/main/java/com/google/adk/agents/LlmAgent.java index b387aee34..077068283 100644 --- a/core/src/main/java/com/google/adk/agents/LlmAgent.java +++ b/core/src/main/java/com/google/adk/agents/LlmAgent.java @@ -772,7 +772,7 @@ public List toolsUnion() { return toolsUnion; } - public ImmutableList toolsets() { + public List toolsets() { return toolsets; } From 9a080763d83c319f539d1bacac4595d13b299e7e Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 19 Mar 2026 07:11:45 -0700 Subject: [PATCH 19/40] feat: fixing context propagation for agent transfers 
PiperOrigin-RevId: 886159283 --- .../adk/flows/llmflows/BaseLlmFlow.java | 16 +- .../java/com/google/adk/telemetry/README.md | 156 ++++++++++ .../adk/telemetry/ContextPropagationTest.java | 269 +++++++++++++----- 3 files changed, 368 insertions(+), 73 deletions(-) create mode 100644 core/src/main/java/com/google/adk/telemetry/README.md diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index e00cf0cbf..8fabc978d 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -430,7 +430,12 @@ private Flowable runOneStep(Context spanContext, InvocationContext contex "Agent not found: " + agentToTransfer))); } return postProcessedEvents.concatWith( - Flowable.defer(() -> nextAgent.get().runAsync(context))); + Flowable.defer( + () -> { + try (Scope s = spanContext.makeCurrent()) { + return nextAgent.get().runAsync(context); + } + })); } return postProcessedEvents; }); @@ -488,6 +493,8 @@ private Flowable run( public Flowable runLive(InvocationContext invocationContext) { AtomicReference llmRequestRef = new AtomicReference<>(LlmRequest.builder().build()); Flowable preprocessEvents = preprocess(invocationContext, llmRequestRef); + // Capture agent context at assembly time to use as parent for agent transfer at subscription + // time. See Flowable.defer() usages below. 
Context spanContext = Context.current(); return preprocessEvents.concatWith( @@ -608,7 +615,12 @@ public void onError(Throwable e) { "Agent not found: " + event.actions().transferToAgent().get()); } Flowable nextAgentEvents = - nextAgent.get().runLive(invocationContext); + Flowable.defer( + () -> { + try (Scope s = spanContext.makeCurrent()) { + return nextAgent.get().runLive(invocationContext); + } + }); events = Flowable.concat(events, nextAgentEvents); } return events; diff --git a/core/src/main/java/com/google/adk/telemetry/README.md b/core/src/main/java/com/google/adk/telemetry/README.md new file mode 100644 index 000000000..8665b3352 --- /dev/null +++ b/core/src/main/java/com/google/adk/telemetry/README.md @@ -0,0 +1,156 @@ +# ADK Telemetry and Tracing + +This package contains classes for capturing and reporting telemetry data within +the ADK, primarily for tracing agent execution leveraging OpenTelemetry. + +## Overview + +The `Tracing` utility class provides methods to trace various aspects of an +agent's execution, including: + +* Agent invocations +* LLM requests and responses +* Tool calls and responses + +These traces can be exported and visualized in telemetry backends like Google +Cloud Trace or Zipkin, or viewed through the ADK Dev Server UI, providing +observability into agent behavior. + +## How Tracing is Used + +Tracing is deeply integrated into the ADK's RxJava-based asynchronous workflows. + +### Agent Invocations + +Every agent's `runAsync` or `runLive` execution is wrapped in a span named +`invoke_agent `. The top-level agent invocation initiated by +`Runner.runAsync` or `Runner.runLive` is captured in a span named `invocation`. +Agent-specific metadata like name and description are added as span attributes, +following OpenTelemetry semantic conventions (e.g., `gen_ai.agent.name`). + +### LLM Calls + +Calls to Large Language Models (LLMs) are traced within a `call_llm` span. 
The +`traceCallLlm` method attaches detailed attributes to this span, including: + +* The LLM request (excluding large data like images) and response. +* Model name (`gen_ai.request.model`). +* Token usage (`gen_ai.usage.input_tokens`, `gen_ai.usage.output_tokens`). +* Configuration parameters (`gen_ai.request.top_p`, + `gen_ai.request.max_tokens`). +* Response finish reason (`gen_ai.response.finish_reasons`). + +### Tool Calls and Responses + +Tool executions triggered by the LLM are traced using `tool_call []` +and `tool_response []` spans. + +* `traceToolCall` records tool arguments in the + `gcp.vertex.agent.tool_call_args` attribute. +* `traceToolResponse` records tool output in the + `gcp.vertex.agent.tool_response` attribute. +* If multiple tools are called in parallel, a single `tool_response` span may + be created for the merged result. + +### Context Propagation + +ADK is built on RxJava and heavily uses asynchronous processing, which means +that work is often handed off between different threads. For tracing to work +correctly in such an environment, it's crucial that the active span's context +is propagated across these thread boundaries. If context is not propagated, +new spans may be orphaned or attached to the wrong parent, making traces +difficult to interpret. + +OpenTelemetry stores the currently active span in a thread-local variable. +When an asynchronous operation switches threads, this thread-local context is +lost. To solve this, ADK's `Tracing` class provides functionality to capture +the context on one thread and restore it on another when an asynchronous +operation resumes. This ensures that spans created on different threads are +correctly parented under the same trace. + +The primary mechanism for this is the `Tracing.withContext(context)` method, +which returns an RxJava transformer. 
When applied to an RxJava stream via +`.compose()`, this transformer ensures that the provided `Context` (containing +the parent span) is re-activated before any `onNext`, `onError`, `onComplete`, +or `onSuccess` signals are propagated downstream. It achieves this by wrapping +the downstream observer with a `TracingObserver`, which uses +`context.makeCurrent()` in a try-with-resources block around each callback, +guaranteeing that the correct span is active when downstream operators execute, +regardless of the thread. + +### RxJava Integration + +ADK integrates OpenTelemetry with RxJava streams to simplify span creation and +ensure context propagation: + +* **Span Creation**: The `Tracing.trace(spanName)` method returns an RxJava + transformer that can be applied to a `Flowable`, `Single`, `Maybe`, or + `Completable` using `.compose()`. This transformer wraps the stream's + execution in a new OpenTelemetry span. +* **Context Propagation**: The `Tracing.withContext(context)` transformer is + used with `.compose()` to ensure that the correct OpenTelemetry `Context` + (and thus the correct parent span) is active when stream operators or + subscriptions are executed, even across thread boundaries. + +## Trace Hierarchy Example + +A typical agent interaction might produce a trace hierarchy like the following: + +``` +invocation +└── invoke_agent my_agent + ├── call_llm + │ ├── tool_call [search_flights] + │ └── tool_response [search_flights] + └── call_llm +``` + +This shows: + +1. The overall `invocation` started by the `Runner`. +2. The invocation of `my_agent`. +3. The first `call_llm` made by `my_agent`. +4. A `tool_call` to `search_flights` and its corresponding `tool_response`. +5. A second `call_llm` made by `my_agent` to generate the final user response. + +### Nested Agents + +ADK supports nested agents, where one agent invokes another. If an agent has +sub-agents, it can transfer control to one of them using the built-in +`transfer_to_agent` tool. 
When `AgentA` calls `transfer_to_agent` to transfer +control to `AgentB`, the `invoke_agent AgentB` span will appear as a child of +the `invoke_agent AgentA` span, like so: + +``` +invocation +└── invoke_agent AgentA + ├── call_llm + │ ├── tool_call [transfer_to_agent] + │ └── tool_response [transfer_to_agent] + └── invoke_agent AgentB + ├── call_llm + └── ... +``` + +This structure allows you to see how `AgentA` delegated work to `AgentB`. + +## Span Creation References + +The following classes are the primary places where spans are created: + +* **`com.google.adk.runner.Runner`**: Initiates the top-level `invocation` + span for `runAsync` and `runLive`. +* **`com.google.adk.agents.BaseAgent`**: Creates the `invoke_agent + ` span for each agent execution. +* **`com.google.adk.flows.llmflows.BaseLlmFlow`**: Creates the `call_llm` span + when the LLM is invoked. +* **`com.google.adk.flows.llmflows.Functions`**: Creates `tool_call [...]` and + `tool_response [...]` spans when handling tool calls and responses. + +## Configuration + +**ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS**: This environment variable controls +whether LLM request/response content and tool arguments/responses are captured +in span attributes. It defaults to `true`. Set to `false` to exclude potentially +large or sensitive data from traces, in which case a `{}` JSON object will be +recorded instead. 
diff --git a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java index e5795d61f..1ee018848 100644 --- a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java +++ b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java @@ -16,6 +16,7 @@ package com.google.adk.telemetry; +import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -32,10 +33,15 @@ import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; import com.google.adk.sessions.SessionKey; +import com.google.adk.testing.TestLlm; +import com.google.adk.testing.TestUtils; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.ToolContext; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; import com.google.genai.types.FinishReason; +import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.FunctionResponse; import com.google.genai.types.GenerateContentConfig; import com.google.genai.types.GenerateContentResponseUsageMetadata; @@ -54,6 +60,7 @@ import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.schedulers.Schedulers; +import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.Optional; @@ -123,10 +130,7 @@ public void testToolCallSpanLinksToParent() { parentSpanData.getSpanContext().getTraceId(), toolCallSpanData.getSpanContext().getTraceId()); - assertEquals( - "Tool call's parent should be the parent span", - parentSpanData.getSpanContext().getSpanId(), - toolCallSpanData.getParentSpanContext().getSpanId()); + assertParent(parentSpanData, toolCallSpanData); } @Test @@ -146,7 +150,7 @@ public void 
testToolCallWithoutParentCreatesRootSpan() { // Then: Should create root span (backward compatible) List spans = openTelemetryRule.getSpans(); - assertEquals("Should have exactly 1 span", 1, spans.size()); + assertThat(spans).hasSize(1); SpanData toolCallSpanData = spans.get(0); assertFalse( @@ -193,7 +197,7 @@ public void testNestedSpanHierarchy() { List spans = openTelemetryRule.getSpans(); // The 4 spans are: "parent", "invocation", "tool_call [testTool]", and "tool_response // [testTool]". - assertEquals("Should have 4 spans in the hierarchy", 4, spans.size()); + assertThat(spans).hasSize(4); SpanData parentSpanData = findSpanByName("parent"); String parentTraceId = parentSpanData.getSpanContext().getTraceId(); @@ -210,22 +214,13 @@ public void testNestedSpanHierarchy() { SpanData toolResponseSpanData = findSpanByName("tool_response [testTool]"); // invocation should be child of parent - assertEquals( - "Invocation should be child of parent", - parentSpanData.getSpanContext().getSpanId(), - invocationSpanData.getParentSpanContext().getSpanId()); + assertParent(parentSpanData, invocationSpanData); // tool_call should be child of invocation - assertEquals( - "Tool call should be child of invocation", - invocationSpanData.getSpanContext().getSpanId(), - toolCallSpanData.getParentSpanContext().getSpanId()); + assertParent(invocationSpanData, toolCallSpanData); // tool_response should be child of tool_call - assertEquals( - "Tool response should be child of tool call", - toolCallSpanData.getSpanContext().getSpanId(), - toolResponseSpanData.getParentSpanContext().getSpanId()); + assertParent(toolCallSpanData, toolResponseSpanData); } @Test @@ -253,7 +248,6 @@ public void testMultipleSpansInParallel() { // Verify all tool calls link to same parent SpanData parentSpanData = findSpanByName("parent"); String parentTraceId = parentSpanData.getSpanContext().getTraceId(); - String parentSpanId = parentSpanData.getSpanContext().getSpanId(); // All tool calls should have same 
trace ID and parent span ID List toolCallSpans = @@ -261,7 +255,7 @@ public void testMultipleSpansInParallel() { .filter(s -> s.getName().startsWith("tool_call")) .toList(); - assertEquals("Should have 3 tool call spans", 3, toolCallSpans.size()); + assertThat(toolCallSpans).hasSize(3); toolCallSpans.forEach( span -> { @@ -269,10 +263,7 @@ public void testMultipleSpansInParallel() { "Tool call should have same trace ID as parent", parentTraceId, span.getSpanContext().getTraceId()); - assertEquals( - "Tool call should have parent as parent span", - parentSpanId, - span.getParentSpanContext().getSpanId()); + assertParent(parentSpanData, span); }); } @@ -298,10 +289,7 @@ public void testInvokeAgentSpanLinksToInvocation() { SpanData invocationSpanData = findSpanByName("invocation"); SpanData invokeAgentSpanData = findSpanByName("invoke_agent test-agent"); - assertEquals( - "Agent run should be child of invocation", - invocationSpanData.getSpanContext().getSpanId(), - invokeAgentSpanData.getParentSpanContext().getSpanId()); + assertParent(invocationSpanData, invokeAgentSpanData); } @Test @@ -323,15 +311,12 @@ public void testCallLlmSpanLinksToAgentRun() { } List spans = openTelemetryRule.getSpans(); - assertEquals("Should have 2 spans", 2, spans.size()); + assertThat(spans).hasSize(2); SpanData invokeAgentSpanData = findSpanByName("invoke_agent test-agent"); SpanData callLlmSpanData = findSpanByName("call_llm"); - assertEquals( - "Call LLM should be child of agent run", - invokeAgentSpanData.getSpanContext().getSpanId(), - callLlmSpanData.getParentSpanContext().getSpanId()); + assertParent(invokeAgentSpanData, callLlmSpanData); } @Test @@ -349,10 +334,7 @@ public void testSpanCreatedWithinParentScopeIsCorrectlyParented() { SpanData parentSpanData = findSpanByName("invocation"); SpanData agentSpanData = findSpanByName("invoke_agent"); - assertEquals( - "Agent span should be a child of the invocation span", - parentSpanData.getSpanContext().getSpanId(), - 
agentSpanData.getParentSpanContext().getSpanId()); + assertParent(parentSpanData, agentSpanData); } @Test @@ -380,9 +362,7 @@ public void testTraceFlowable() throws InterruptedException { SpanData parentSpanData = findSpanByName("parent"); SpanData flowableSpanData = findSpanByName("flowable"); - assertEquals( - parentSpanData.getSpanContext().getSpanId(), - flowableSpanData.getParentSpanContext().getSpanId()); + assertParent(parentSpanData, flowableSpanData); assertTrue(flowableSpanData.hasEnded()); } @@ -469,9 +449,7 @@ public void testTraceTransformer() throws InterruptedException { SpanData parentSpanData = findSpanByName("parent"); SpanData transformerSpanData = findSpanByName("transformer"); - assertEquals( - parentSpanData.getSpanContext().getSpanId(), - transformerSpanData.getParentSpanContext().getSpanId()); + assertParent(parentSpanData, transformerSpanData); assertTrue(transformerSpanData.hasEnded()); } @@ -485,7 +463,7 @@ public void testTraceAgentInvocation() { span.end(); } List spans = openTelemetryRule.getSpans(); - assertEquals(1, spans.size()); + assertThat(spans).hasSize(1); SpanData spanData = spans.get(0); Attributes attrs = spanData.getAttributes(); assertEquals("invoke_agent", attrs.get(AttributeKey.stringKey("gen_ai.operation.name"))); @@ -504,7 +482,7 @@ public void testTraceToolCall() { span.end(); } List spans = openTelemetryRule.getSpans(); - assertEquals(1, spans.size()); + assertThat(spans).hasSize(1); SpanData spanData = spans.get(0); Attributes attrs = spanData.getAttributes(); assertEquals("execute_tool", attrs.get(AttributeKey.stringKey("gen_ai.operation.name"))); @@ -541,7 +519,7 @@ public void testTraceToolResponse() { span.end(); } List spans = openTelemetryRule.getSpans(); - assertEquals(1, spans.size()); + assertThat(spans).hasSize(1); SpanData spanData = spans.get(0); Attributes attrs = spanData.getAttributes(); assertEquals("execute_tool", attrs.get(AttributeKey.stringKey("gen_ai.operation.name"))); @@ -578,7 +556,7 @@ 
public void testTraceCallLlm() { span.end(); } List spans = openTelemetryRule.getSpans(); - assertEquals(1, spans.size()); + assertThat(spans).hasSize(1); SpanData spanData = spans.get(0); Attributes attrs = spanData.getAttributes(); assertEquals("gcp.vertex.agent", attrs.get(AttributeKey.stringKey("gen_ai.system"))); @@ -606,12 +584,12 @@ public void testTraceSendData() { Tracing.traceSendData( buildInvocationContext(), "event-1", - ImmutableList.of(Content.builder().role("user").parts(Part.fromText("hello")).build())); + ImmutableList.of(Content.fromParts(Part.fromText("hello")))); } finally { span.end(); } List spans = openTelemetryRule.getSpans(); - assertEquals(1, spans.size()); + assertThat(spans).hasSize(1); SpanData spanData = spans.get(0); Attributes attrs = spanData.getAttributes(); assertEquals( @@ -653,37 +631,23 @@ public void baseAgentRunAsync_propagatesContext() throws InterruptedException { } SpanData parent = findSpanByName("parent"); SpanData agentSpan = findSpanByName("invoke_agent test-agent"); - assertEquals(parent.getSpanContext().getSpanId(), agentSpan.getParentSpanContext().getSpanId()); + assertParent(parent, agentSpan); } @Test public void runnerRunAsync_propagatesContext() throws InterruptedException { BaseAgent agent = new TestAgent(); - Runner runner = Runner.builder().agent(agent).appName("test-app").build(); Span parentSpan = tracer.spanBuilder("parent").startSpan(); try (Scope s = parentSpan.makeCurrent()) { - Session session = - runner - .sessionService() - .createSession(new SessionKey("test-app", "test-user", "test-session")) - .blockingGet(); - Content newMessage = Content.fromParts(Part.fromText("hi")); - RunConfig runConfig = RunConfig.builder().build(); - runner - .runAsync(session.userId(), session.id(), newMessage, runConfig, null) - .test() - .await() - .assertComplete(); + runAgent(agent); } finally { parentSpan.end(); } SpanData parent = findSpanByName("parent"); SpanData invocation = findSpanByName("invocation"); 
SpanData agentSpan = findSpanByName("invoke_agent test-agent"); - assertEquals( - parent.getSpanContext().getSpanId(), invocation.getParentSpanContext().getSpanId()); - assertEquals( - invocation.getSpanContext().getSpanId(), agentSpan.getParentSpanContext().getSpanId()); + assertParent(parent, invocation); + assertParent(invocation, agentSpan); } @Test @@ -713,10 +677,173 @@ public void runnerRunLive_propagatesContext() throws InterruptedException { SpanData parent = findSpanByName("parent"); SpanData invocation = findSpanByName("invocation"); SpanData agentSpan = findSpanByName("invoke_agent test-agent"); - assertEquals( - parent.getSpanContext().getSpanId(), invocation.getParentSpanContext().getSpanId()); - assertEquals( - invocation.getSpanContext().getSpanId(), agentSpan.getParentSpanContext().getSpanId()); + assertParent(parent, invocation); + assertParent(invocation, agentSpan); + } + + @Test + public void testAgentWithToolCallTraceHierarchy() throws InterruptedException { + // This test verifies the trace hierarchy created when an agent calls an LLM, + // which then invokes a tool. 
The expected hierarchy is: + // invocation + // └── invoke_agent test_agent + // ├── call_llm + // │ ├── tool_call [search_flights] + // │ └── tool_response [search_flights] + // └── call_llm + + SearchFlightsTool searchFlightsTool = new SearchFlightsTool(); + + TestLlm testLlm = + TestUtils.createTestLlm( + TestUtils.createLlmResponse( + Content.builder() + .role("model") + .parts( + Part.fromFunctionCall( + searchFlightsTool.name(), ImmutableMap.of("destination", "SFO"))) + .build()), + TestUtils.createLlmResponse(Content.fromParts(Part.fromText("done")))); + + LlmAgent agentWithTool = + LlmAgent.builder() + .name("test_agent") + .description("description") + .model(testLlm) + .tools(ImmutableList.of(searchFlightsTool)) + .build(); + + runAgent(agentWithTool); + + SpanData invocation = findSpanByName("invocation"); + SpanData invokeAgent = findSpanByName("invoke_agent test_agent"); + SpanData toolCall = findSpanByName("tool_call [search_flights]"); + SpanData toolResponse = findSpanByName("tool_response [search_flights]"); + List callLlmSpans = + openTelemetryRule.getSpans().stream() + .filter(s -> s.getName().equals("call_llm")) + .sorted(Comparator.comparing(SpanData::getStartEpochNanos)) + .toList(); + assertThat(callLlmSpans).hasSize(2); + SpanData callLlm1 = callLlmSpans.get(0); + SpanData callLlm2 = callLlmSpans.get(1); + + // Assert hierarchy: + // invocation + // └── invoke_agent test_agent + assertParent(invocation, invokeAgent); + // ├── call_llm 1 + assertParent(invokeAgent, callLlm1); + // │ ├── tool_call [search_flights] + assertParent(callLlm1, toolCall); + // │ └── tool_response [search_flights] + assertParent(callLlm1, toolResponse); + // └── call_llm 2 + assertParent(invokeAgent, callLlm2); + } + + @Test + public void testNestedAgentTraceHierarchy() throws InterruptedException { + // This test verifies the trace hierarchy created when AgentA transfers to AgentB. 
+ // The expected hierarchy is: + // invocation + // └── invoke_agent AgentA + // ├── call_llm + // │ ├── tool_call [transfer_to_agent] + // │ └── tool_response [transfer_to_agent] + // └── invoke_agent AgentB + // └── call_llm + TestLlm llm = + TestUtils.createTestLlm( + TestUtils.createLlmResponse( + Content.builder() + .role("model") + .parts( + Part.fromFunctionCall( + "transfer_to_agent", ImmutableMap.of("agent_name", "AgentB"))) + .build()), + TestUtils.createLlmResponse(Content.fromParts(Part.fromText("agent b response")))); + LlmAgent agentB = LlmAgent.builder().name("AgentB").description("Agent B").model(llm).build(); + + LlmAgent agentA = + LlmAgent.builder() + .name("AgentA") + .description("Agent A") + .model(llm) + .subAgents(ImmutableList.of(agentB)) + .build(); + + runAgent(agentA); + + SpanData invocation = findSpanByName("invocation"); + SpanData agentASpan = findSpanByName("invoke_agent AgentA"); + SpanData toolCall = findSpanByName("tool_call [transfer_to_agent]"); + SpanData agentBSpan = findSpanByName("invoke_agent AgentB"); + SpanData toolResponse = findSpanByName("tool_response [transfer_to_agent]"); + + List callLlmSpans = + openTelemetryRule.getSpans().stream() + .filter(s -> s.getName().equals("call_llm")) + .sorted(Comparator.comparing(SpanData::getStartEpochNanos)) + .toList(); + assertThat(callLlmSpans).hasSize(2); + + SpanData agentACallLlm1 = callLlmSpans.get(0); + SpanData agentBCallLlm = callLlmSpans.get(1); + + assertParent(invocation, agentASpan); + assertParent(agentASpan, agentACallLlm1); + assertParent(agentACallLlm1, toolCall); + assertParent(agentACallLlm1, toolResponse); + assertParent(agentASpan, agentBSpan); + assertParent(agentBSpan, agentBCallLlm); + } + + private void runAgent(BaseAgent agent) throws InterruptedException { + Runner runner = Runner.builder().agent(agent).appName("test-app").build(); + Session session = + runner + .sessionService() + .createSession(new SessionKey("test-app", "test-user", "test-session")) 
+ .blockingGet(); + Content newMessage = Content.fromParts(Part.fromText("hi")); + RunConfig runConfig = RunConfig.builder().build(); + runner + .runAsync(session.sessionKey(), newMessage, runConfig, null) + .test() + .await() + .assertComplete(); + } + + /** Tool for testing. */ + public static class SearchFlightsTool extends BaseTool { + public SearchFlightsTool() { + super("search_flights", "Search for flights tool"); + } + + @Override + public Single> runAsync(Map args, ToolContext context) { + return Single.just(ImmutableMap.of("result", args)); + } + + @Override + public Optional declaration() { + return Optional.of( + FunctionDeclaration.builder() + .name("search_flights") + .description("Search for flights tool") + .build()); + } + } + + /** + * Asserts that the parent span is the parent of the child span. + * + * @param parent The parent span. + * @param child The child span. + */ + private void assertParent(SpanData parent, SpanData child) { + assertEquals(parent.getSpanContext().getSpanId(), child.getParentSpanContext().getSpanId()); } /** From 0af82e61a3c0dbbd95166a10b450cb507115ab60 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 19 Mar 2026 07:29:15 -0700 Subject: [PATCH 20/40] fix: Removing deprecated methods in Runner PiperOrigin-RevId: 886166671 --- .../java/com/google/adk/runner/Runner.java | 42 ------------------- 1 file changed, 42 deletions(-) diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 1f7d924ab..849a3cd04 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -425,36 +425,6 @@ public Flowable runAsync(String userId, String sessionId, Content newMess return runAsync(userId, sessionId, newMessage, RunConfig.builder().build()); } - /** - * See {@link #runAsync(Session, Content, RunConfig, Map)}. - * - * @deprecated Use runAsync with sessionId. 
- */ - @Deprecated(since = "0.4.0", forRemoval = true) - public Flowable runAsync(Session session, Content newMessage, RunConfig runConfig) { - return runAsync(session, newMessage, runConfig, /* stateDelta= */ null); - } - - /** - * Runs the agent asynchronously using a provided Session object. - * - * @param session The session to run the agent in. - * @param newMessage The new message from the user to process. - * @param runConfig Configuration for the agent run. - * @param stateDelta Optional map of state updates to merge into the session for this run. - * @return A Flowable stream of {@link Event} objects generated by the agent during execution. - * @deprecated Use runAsync with sessionId. - */ - @Deprecated(since = "0.4.0", forRemoval = true) - public Flowable runAsync( - Session session, - Content newMessage, - RunConfig runConfig, - @Nullable Map stateDelta) { - return runAsyncImpl(session, newMessage, runConfig, stateDelta) - .compose(Tracing.trace("invocation")); - } - /** * Runs the agent asynchronously using a provided Session object. * @@ -735,18 +705,6 @@ protected Flowable runLiveImpl( }); } - /** - * Runs the agent asynchronously with a default user ID. - * - * @return stream of generated events. - */ - @Deprecated(since = "0.5.0", forRemoval = true) - public Flowable runWithSessionId( - String sessionId, Content newMessage, RunConfig runConfig) { - // TODO(b/410859954): Add user_id to getter or method signature. Assuming "tmp-user" for now. - return this.runAsync("tmp-user", sessionId, newMessage, runConfig); - } - /** * Checks if the agent and its parent chain allow transfer up the tree. 
* From dc5d794c066571c7d87f006767bd32298e2a3ba8 Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Thu, 19 Mar 2026 08:39:07 -0700 Subject: [PATCH 21/40] chore: set version to 1.0.0-rc.1 Release-As: 1.0.0-rc.1 PiperOrigin-RevId: 886198912 --- .release-please-manifest.json | 1 - 1 file changed, 1 deletion(-) diff --git a/.release-please-manifest.json b/.release-please-manifest.json index b0f3ba770..6db3039d0 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,4 +1,3 @@ { ".": "0.9.0" } - From dfbab955314a428ccb17855a69a77386e924c92b Mon Sep 17 00:00:00 2001 From: ddobrin Date: Tue, 17 Mar 2026 09:30:34 -0400 Subject: [PATCH 22/40] Updated tests in Spring AI, Langchain4j, dependency for Spering AI and GenAI SDK --- .../LangChain4jIntegrationTest.java | 16 ++--- contrib/spring-ai/pom.xml | 2 +- .../AnthropicApiIntegrationTest.java | 62 ++++++++++++------- pom.xml | 2 +- 4 files changed, 48 insertions(+), 34 deletions(-) diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java index 3fafb046d..191e48017 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java @@ -44,7 +44,7 @@ class LangChain4jIntegrationTest { - public static final String CLAUDE_3_7_SONNET_20250219 = "claude-3-7-sonnet-20250219"; + public static final String CLAUDE_4_6_SONNET = "claude-sonnet-4-6"; public static final String GEMINI_2_0_FLASH = "gemini-2.0-flash"; public static final String GPT_4_O_MINI = "gpt-4o-mini"; @@ -55,14 +55,14 @@ void testSimpleAgent() { AnthropicChatModel claudeModel = AnthropicChatModel.builder() .apiKey(System.getenv("ANTHROPIC_API_KEY")) - .modelName(CLAUDE_3_7_SONNET_20250219) + .modelName(CLAUDE_4_6_SONNET) .build(); 
LlmAgent agent = LlmAgent.builder() .name("science-app") .description("Science teacher agent") - .model(new LangChain4j(claudeModel, CLAUDE_3_7_SONNET_20250219)) + .model(new LangChain4j(claudeModel, CLAUDE_4_6_SONNET)) .instruction( """ You are a helpful science teacher that explains science concepts @@ -91,14 +91,14 @@ void testSingleAgentWithTools() { AnthropicChatModel claudeModel = AnthropicChatModel.builder() .apiKey(System.getenv("ANTHROPIC_API_KEY")) - .modelName(CLAUDE_3_7_SONNET_20250219) + .modelName(CLAUDE_4_6_SONNET) .build(); BaseAgent agent = LlmAgent.builder() .name("friendly-weather-app") .description("Friend agent that knows about the weather") - .model(new LangChain4j(claudeModel, CLAUDE_3_7_SONNET_20250219)) + .model(new LangChain4j(claudeModel, CLAUDE_4_6_SONNET)) .instruction( """ You are a friendly assistant. @@ -155,7 +155,7 @@ void testSingleAgentWithTools() { List partsThree = contentThree.parts().get(); assertEquals(1, partsThree.size()); assertTrue(partsThree.get(0).text().isPresent()); - assertTrue(partsThree.get(0).text().get().contains("beautiful")); + assertTrue(partsThree.get(0).text().get().contains("sunny")); } @Test @@ -352,10 +352,10 @@ void testSimpleStreamingResponse() { AnthropicStreamingChatModel claudeStreamingModel = AnthropicStreamingChatModel.builder() .apiKey(System.getenv("ANTHROPIC_API_KEY")) - .modelName(CLAUDE_3_7_SONNET_20250219) + .modelName(CLAUDE_4_6_SONNET) .build(); - LangChain4j lc4jClaude = new LangChain4j(claudeStreamingModel, CLAUDE_3_7_SONNET_20250219); + LangChain4j lc4jClaude = new LangChain4j(claudeStreamingModel, CLAUDE_4_6_SONNET); // when Flowable responses = diff --git a/contrib/spring-ai/pom.xml b/contrib/spring-ai/pom.xml index b24fa4b63..08d237ab5 100644 --- a/contrib/spring-ai/pom.xml +++ b/contrib/spring-ai/pom.xml @@ -29,7 +29,7 @@ Spring AI integration for the Agent Development Kit. 
- 2.0.0-M2 + 2.0.0-M3 1.21.3 diff --git a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/AnthropicApiIntegrationTest.java b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/AnthropicApiIntegrationTest.java index f21b07ae9..c59a94f82 100644 --- a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/AnthropicApiIntegrationTest.java +++ b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/AnthropicApiIntegrationTest.java @@ -34,7 +34,6 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.anthropic.AnthropicChatModel; import org.springframework.ai.anthropic.AnthropicChatOptions; -import org.springframework.ai.anthropic.api.AnthropicApi; /** * Integration tests with real Anthropic API. @@ -53,10 +52,14 @@ void testSimpleAgentWithRealAnthropicApi() throws InterruptedException { Thread.sleep(2000); // Create Anthropic model using Spring AI's builder pattern - AnthropicApi anthropicApi = - AnthropicApi.builder().apiKey(System.getenv("ANTHROPIC_API_KEY")).build(); - AnthropicChatModel anthropicModel = - AnthropicChatModel.builder().anthropicApi(anthropicApi).build(); + var options = + AnthropicChatOptions.builder() + .model(CLAUDE_MODEL) + .maxTokens(1024) + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .build(); + + AnthropicChatModel anthropicModel = AnthropicChatModel.builder().options(options).build(); // Wrap with SpringAI SpringAI springAI = new SpringAI(anthropicModel, CLAUDE_MODEL); @@ -92,10 +95,14 @@ void testStreamingWithRealAnthropicApi() throws InterruptedException { // Add delay to avoid rapid requests Thread.sleep(2000); - AnthropicApi anthropicApi = - AnthropicApi.builder().apiKey(System.getenv("ANTHROPIC_API_KEY")).build(); - AnthropicChatModel anthropicModel = - AnthropicChatModel.builder().anthropicApi(anthropicApi).build(); + var options = + AnthropicChatOptions.builder() + .model(CLAUDE_MODEL) + 
.maxTokens(1024) + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .build(); + + AnthropicChatModel anthropicModel = AnthropicChatModel.builder().options(options).build(); SpringAI springAI = new SpringAI(anthropicModel, CLAUDE_MODEL); @@ -134,10 +141,14 @@ void testStreamingWithRealAnthropicApi() throws InterruptedException { @Test void testAgentWithToolsAndRealApi() { - AnthropicApi anthropicApi = - AnthropicApi.builder().apiKey(System.getenv("ANTHROPIC_API_KEY")).build(); - AnthropicChatModel anthropicModel = - AnthropicChatModel.builder().anthropicApi(anthropicApi).build(); + var options = + AnthropicChatOptions.builder() + .model(CLAUDE_MODEL) + .maxTokens(1024) + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .build(); + + AnthropicChatModel anthropicModel = AnthropicChatModel.builder().options(options).build(); LlmAgent agent = LlmAgent.builder() @@ -175,10 +186,13 @@ void testAgentWithToolsAndRealApi() { @Test void testDirectComparisonNonStreamingVsStreaming() throws InterruptedException { // Test both non-streaming and streaming with the same model to compare behavior - AnthropicApi anthropicApi = - AnthropicApi.builder().apiKey(System.getenv("ANTHROPIC_API_KEY")).build(); - AnthropicChatModel anthropicModel = - AnthropicChatModel.builder().anthropicApi(anthropicApi).build(); + var options = + AnthropicChatOptions.builder() + .model(CLAUDE_MODEL) + .maxTokens(1024) + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .build(); + AnthropicChatModel anthropicModel = AnthropicChatModel.builder().options(options).build(); SpringAI springAI = new SpringAI(anthropicModel, CLAUDE_MODEL); @@ -271,13 +285,13 @@ void testDirectComparisonNonStreamingVsStreaming() throws InterruptedException { @Test void testConfigurationOptions() { // Test with custom configuration - AnthropicChatOptions options = - AnthropicChatOptions.builder().model(CLAUDE_MODEL).temperature(0.7).maxTokens(100).build(); - - AnthropicApi anthropicApi = - 
AnthropicApi.builder().apiKey(System.getenv("ANTHROPIC_API_KEY")).build(); - AnthropicChatModel anthropicModel = - AnthropicChatModel.builder().anthropicApi(anthropicApi).defaultOptions(options).build(); + var options = + AnthropicChatOptions.builder() + .model(CLAUDE_MODEL) + .maxTokens(1024) + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .build(); + AnthropicChatModel anthropicModel = AnthropicChatModel.builder().options(options).build(); SpringAI springAI = new SpringAI(anthropicModel, CLAUDE_MODEL); diff --git a/pom.xml b/pom.xml index 62082cfc9..0be05a629 100644 --- a/pom.xml +++ b/pom.xml @@ -51,7 +51,7 @@ 1.51.0 0.17.2 2.47.0 - 1.41.0 + 1.43.0 4.33.5 5.11.4 5.20.0 From dbb139439d38157b4b9af38c52824b1e8405a495 Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Thu, 19 Mar 2026 10:14:14 -0700 Subject: [PATCH 23/40] feat!: remove McpToolset constructors taking Optional parameters PiperOrigin-RevId: 886244600 --- .../com/google/adk/tools/mcp/McpToolset.java | 174 ++++++++++++------ .../google/adk/tools/mcp/McpToolsetTest.java | 9 +- 2 files changed, 120 insertions(+), 63 deletions(-) diff --git a/core/src/main/java/com/google/adk/tools/mcp/McpToolset.java b/core/src/main/java/com/google/adk/tools/mcp/McpToolset.java index 207243ceb..4cafb9681 100644 --- a/core/src/main/java/com/google/adk/tools/mcp/McpToolset.java +++ b/core/src/main/java/com/google/adk/tools/mcp/McpToolset.java @@ -24,6 +24,8 @@ import com.google.adk.agents.ReadonlyContext; import com.google.adk.tools.BaseTool; import com.google.adk.tools.BaseToolset; +import com.google.adk.tools.ToolPredicate; +import com.google.common.collect.ImmutableList; import com.google.common.primitives.Booleans; import io.modelcontextprotocol.client.McpSyncClient; import io.modelcontextprotocol.client.transport.ServerParameters; @@ -32,6 +34,7 @@ import java.util.List; import java.util.Objects; import java.util.Optional; +import org.jspecify.annotations.Nullable; import org.slf4j.Logger; import 
org.slf4j.LoggerFactory; @@ -51,7 +54,7 @@ public class McpToolset implements BaseToolset { private final McpSessionManager mcpSessionManager; private McpSyncClient mcpSession; private final ObjectMapper objectMapper; - private final Optional toolFilter; + private final @Nullable Object toolFilter; private static final int MAX_RETRIES = 3; private static final long RETRY_DELAY_MILLIS = 100; @@ -62,17 +65,29 @@ public class McpToolset implements BaseToolset { * * @param connectionParams The SSE connection parameters to the MCP server. * @param objectMapper An ObjectMapper instance for parsing schemas. - * @param toolFilter An Optional containing either a ToolPredicate or a List of tool names. + * @param toolPredicate A {@link ToolPredicate} */ public McpToolset( SseServerParameters connectionParams, ObjectMapper objectMapper, - Optional toolFilter) { - Objects.requireNonNull(connectionParams); - Objects.requireNonNull(objectMapper); - this.objectMapper = objectMapper; - this.mcpSessionManager = new McpSessionManager(connectionParams); - this.toolFilter = toolFilter; + ToolPredicate toolPredicate) { + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = Objects.requireNonNull(toolPredicate); + } + + /** + * Initializes the McpToolset with SSE server parameters. + * + * @param connectionParams The SSE connection parameters to the MCP server. + * @param objectMapper An ObjectMapper instance for parsing schemas. 
+ * @param toolNames A list of tool names + */ + public McpToolset( + SseServerParameters connectionParams, ObjectMapper objectMapper, List toolNames) { + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = ImmutableList.copyOf(toolNames); } /** @@ -82,7 +97,9 @@ public McpToolset( * @param objectMapper An ObjectMapper instance for parsing schemas. */ public McpToolset(SseServerParameters connectionParams, ObjectMapper objectMapper) { - this(connectionParams, objectMapper, Optional.empty()); + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = null; } /** @@ -90,36 +107,39 @@ public McpToolset(SseServerParameters connectionParams, ObjectMapper objectMappe * * @param connectionParams The local server connection parameters to the MCP server. * @param objectMapper An ObjectMapper instance for parsing schemas. - * @param toolFilter An Optional containing either a ToolPredicate or a List of tool names. + * @param toolPredicate A {@link ToolPredicate} */ public McpToolset( - ServerParameters connectionParams, ObjectMapper objectMapper, Optional toolFilter) { - Objects.requireNonNull(connectionParams); - Objects.requireNonNull(objectMapper); - this.objectMapper = objectMapper; - this.mcpSessionManager = new McpSessionManager(connectionParams); - this.toolFilter = toolFilter; + ServerParameters connectionParams, ObjectMapper objectMapper, ToolPredicate toolPredicate) { + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = Objects.requireNonNull(toolPredicate); } /** - * Initializes the McpToolset with local server parameters and no tool filter. + * Initializes the McpToolset with local server parameters. 
* * @param connectionParams The local server connection parameters to the MCP server. * @param objectMapper An ObjectMapper instance for parsing schemas. + * @param toolNames A list of tool names */ - public McpToolset(ServerParameters connectionParams, ObjectMapper objectMapper) { - this(connectionParams, objectMapper, Optional.empty()); + public McpToolset( + ServerParameters connectionParams, ObjectMapper objectMapper, List toolNames) { + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = ImmutableList.copyOf(toolNames); } /** - * Initializes the McpToolset with SSE server parameters, using the ObjectMapper used across the - * ADK. + * Initializes the McpToolset with local server parameters and no tool filter. * - * @param connectionParams The SSE connection parameters to the MCP server. - * @param toolFilter An Optional containing either a ToolPredicate or a List of tool names. + * @param connectionParams The local server connection parameters to the MCP server. + * @param objectMapper An ObjectMapper instance for parsing schemas. */ - public McpToolset(SseServerParameters connectionParams, Optional toolFilter) { - this(connectionParams, JsonBaseModel.getMapper(), toolFilter); + public McpToolset(ServerParameters connectionParams, ObjectMapper objectMapper) { + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = null; } /** @@ -129,28 +149,31 @@ public McpToolset(SseServerParameters connectionParams, Optional toolFil * @param connectionParams The SSE connection parameters to the MCP server. 
*/ public McpToolset(SseServerParameters connectionParams) { - this(connectionParams, JsonBaseModel.getMapper(), Optional.empty()); + this(connectionParams, JsonBaseModel.getMapper()); } /** * Initializes the McpToolset with local server parameters, using the ObjectMapper used across the - * ADK. + * ADK and no tool filter. * * @param connectionParams The local server connection parameters to the MCP server. - * @param toolFilter An Optional containing either a ToolPredicate or a List of tool names. */ - public McpToolset(ServerParameters connectionParams, Optional toolFilter) { - this(connectionParams, JsonBaseModel.getMapper(), toolFilter); + public McpToolset(ServerParameters connectionParams) { + this(connectionParams, JsonBaseModel.getMapper()); } /** - * Initializes the McpToolset with local server parameters, using the ObjectMapper used across the - * ADK and no tool filter. + * Initializes the McpToolset with an McpSessionManager. * - * @param connectionParams The local server connection parameters to the MCP server. + * @param mcpSessionManager A McpSessionManager instance for testing. + * @param objectMapper An ObjectMapper instance for parsing schemas. + * @param toolPredicate A {@link ToolPredicate} */ - public McpToolset(ServerParameters connectionParams) { - this(connectionParams, JsonBaseModel.getMapper(), Optional.empty()); + public McpToolset( + McpSessionManager mcpSessionManager, ObjectMapper objectMapper, ToolPredicate toolPredicate) { + this.mcpSessionManager = Objects.requireNonNull(mcpSessionManager); + this.objectMapper = Objects.requireNonNull(objectMapper); + this.toolFilter = Objects.requireNonNull(toolPredicate); } /** @@ -158,33 +181,69 @@ public McpToolset(ServerParameters connectionParams) { * * @param mcpSessionManager A McpSessionManager instance for testing. * @param objectMapper An ObjectMapper instance for parsing schemas. - * @param toolFilter An Optional containing either a ToolPredicate or a List of tool names. 
+ * @param toolNames A list of tool names */ public McpToolset( - McpSessionManager mcpSessionManager, ObjectMapper objectMapper, Optional toolFilter) { - Objects.requireNonNull(mcpSessionManager); - Objects.requireNonNull(objectMapper); - this.mcpSessionManager = mcpSessionManager; - this.objectMapper = objectMapper; - this.toolFilter = toolFilter; + McpSessionManager mcpSessionManager, ObjectMapper objectMapper, List toolNames) { + this.mcpSessionManager = Objects.requireNonNull(mcpSessionManager); + this.objectMapper = Objects.requireNonNull(objectMapper); + this.toolFilter = ImmutableList.copyOf(toolNames); + } + + /** + * Initializes the McpToolset with an McpSessionManager and no tool filter. + * + * @param mcpSessionManager A McpSessionManager instance for testing. + * @param objectMapper An ObjectMapper instance for parsing schemas. + */ + public McpToolset(McpSessionManager mcpSessionManager, ObjectMapper objectMapper) { + this.mcpSessionManager = Objects.requireNonNull(mcpSessionManager); + this.objectMapper = Objects.requireNonNull(objectMapper); + this.toolFilter = null; } /** - * Initializes the McpToolset with Steamable HTTP server parameters. + * Initializes the McpToolset with Streamable HTTP server parameters. * * @param connectionParams The Streamable HTTP connection parameters to the MCP server. * @param objectMapper An ObjectMapper instance for parsing schemas. - * @param toolFilter An Optional containing either a ToolPredicate or a List of tool names. 
+ * @param toolPredicate A {@link ToolPredicate} */ public McpToolset( StreamableHttpServerParameters connectionParams, ObjectMapper objectMapper, - Optional toolFilter) { - Objects.requireNonNull(connectionParams); - Objects.requireNonNull(objectMapper); - this.objectMapper = objectMapper; - this.mcpSessionManager = new McpSessionManager(connectionParams); - this.toolFilter = toolFilter; + ToolPredicate toolPredicate) { + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = Objects.requireNonNull(toolPredicate); + } + + /** + * Initializes the McpToolset with Streamable HTTP server parameters. + * + * @param connectionParams The Streamable HTTP connection parameters to the MCP server. + * @param objectMapper An ObjectMapper instance for parsing schemas. + * @param toolNames A list of tool names + */ + public McpToolset( + StreamableHttpServerParameters connectionParams, + ObjectMapper objectMapper, + List toolNames) { + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = ImmutableList.copyOf(toolNames); + } + + /** + * Initializes the McpToolset with Streamable HTTP server parameters and no tool filter. + * + * @param connectionParams The Streamable HTTP connection parameters to the MCP server. + * @param objectMapper An ObjectMapper instance for parsing schemas. + */ + public McpToolset(StreamableHttpServerParameters connectionParams, ObjectMapper objectMapper) { + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = null; } /** @@ -194,7 +253,7 @@ public McpToolset( * @param connectionParams The Streamable HTTP connection parameters to the MCP server. 
*/ public McpToolset(StreamableHttpServerParameters connectionParams) { - this(connectionParams, JsonBaseModel.getMapper(), Optional.empty()); + this(connectionParams, JsonBaseModel.getMapper()); } @Override @@ -215,8 +274,7 @@ public Flowable getTools(ReadonlyContext readonlyContext) { tool -> new McpTool( tool, this.mcpSession, this.mcpSessionManager, this.objectMapper)) - .filter( - tool -> isToolSelected(tool, toolFilter.orElse(null), readonlyContext))); + .filter(tool -> isToolSelected(tool, toolFilter, readonlyContext))); }) .retryWhen( errorObservable -> @@ -357,16 +415,18 @@ public static McpToolset fromConfig(BaseTool.ToolConfig config, String configAbs + " for McpToolset"); } - // Convert tool filter to Optional - Optional toolFilter = Optional.ofNullable(mcpToolsetConfig.toolFilter()); - + List toolNames = mcpToolsetConfig.toolFilter(); Object connectionParameters = Optional.ofNullable(mcpToolsetConfig.stdioConnectionParams()) .or(() -> Optional.ofNullable(mcpToolsetConfig.sseServerParams())) .orElse(mcpToolsetConfig.stdioConnectionParams()); // Create McpToolset with McpSessionManager having appropriate connection parameters - return new McpToolset(new McpSessionManager(connectionParameters), mapper, toolFilter); + if (toolNames != null) { + return new McpToolset(new McpSessionManager(connectionParameters), mapper, toolNames); + } else { + return new McpToolset(new McpSessionManager(connectionParameters), mapper); + } } catch (IllegalArgumentException e) { throw new ConfigurationException("Failed to parse McpToolsetConfig from ToolArgsConfig", e); } diff --git a/core/src/test/java/com/google/adk/tools/mcp/McpToolsetTest.java b/core/src/test/java/com/google/adk/tools/mcp/McpToolsetTest.java index 3d322b73f..0db218347 100644 --- a/core/src/test/java/com/google/adk/tools/mcp/McpToolsetTest.java +++ b/core/src/test/java/com/google/adk/tools/mcp/McpToolsetTest.java @@ -34,7 +34,6 @@ import io.modelcontextprotocol.json.McpJsonMapper; import 
io.modelcontextprotocol.spec.McpSchema; import java.util.List; -import java.util.Optional; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -324,7 +323,7 @@ public void getTools_withToolFilter_returnsFilteredTools() { when(mockMcpSyncClient.listTools()).thenReturn(mockResult); McpToolset toolset = - new McpToolset(mockMcpSessionManager, JsonBaseModel.getMapper(), Optional.of(toolFilter)); + new McpToolset(mockMcpSessionManager, JsonBaseModel.getMapper(), toolFilter); List tools = toolset.getTools(mockReadonlyContext).toList().blockingGet(); @@ -340,8 +339,7 @@ public void getTools_retriesAndFailsAfterMaxRetries() { when(mockMcpSessionManager.createSession()).thenReturn(mockMcpSyncClient); when(mockMcpSyncClient.listTools()).thenThrow(new RuntimeException("Test Exception")); - McpToolset toolset = - new McpToolset(mockMcpSessionManager, JsonBaseModel.getMapper(), Optional.empty()); + McpToolset toolset = new McpToolset(mockMcpSessionManager, JsonBaseModel.getMapper()); toolset .getTools(mockReadonlyContext) @@ -362,8 +360,7 @@ public void getTools_succeedsOnLastRetryAttempt() { .thenThrow(new RuntimeException("Attempt 2 failed")) .thenReturn(mockResult); - McpToolset toolset = - new McpToolset(mockMcpSessionManager, JsonBaseModel.getMapper(), Optional.empty()); + McpToolset toolset = new McpToolset(mockMcpSessionManager, JsonBaseModel.getMapper()); List tools = toolset.getTools(mockReadonlyContext).toList().blockingGet(); From 897f9d9776b75d66b6e7e01c98427d4d36d4dd5a Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Thu, 19 Mar 2026 10:16:25 -0700 Subject: [PATCH 24/40] chore: add test-jar goal in core sub-project PiperOrigin-RevId: 886245655 --- core/pom.xml | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/core/pom.xml b/core/pom.xml index b3f2f5fd8..02c75f88b 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -229,6 +229,16 @@ maven-compiler-plugin + + maven-jar-plugin + + + + test-jar + + + + maven-surefire-plugin From 
4b9b99ae7149a465ba2ae9b7496e01f669786553 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 20 Mar 2026 02:28:15 -0700 Subject: [PATCH 25/40] feat: update Session.state() and its builder to be of general Map types PiperOrigin-RevId: 886659065 --- core/src/main/java/com/google/adk/sessions/Session.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/com/google/adk/sessions/Session.java b/core/src/main/java/com/google/adk/sessions/Session.java index f8376589a..94504fd96 100644 --- a/core/src/main/java/com/google/adk/sessions/Session.java +++ b/core/src/main/java/com/google/adk/sessions/Session.java @@ -27,8 +27,8 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; /** A {@link Session} object that encapsulates the {@link State} and {@link Event}s of a session. */ @JsonDeserialize(builder = Session.Builder.class) @@ -101,7 +101,7 @@ public Builder state(State state) { @CanIgnoreReturnValue @JsonProperty("state") - public Builder state(ConcurrentMap state) { + public Builder state(Map state) { this.state = new State(state); return this; } @@ -162,7 +162,7 @@ public String id() { } @JsonProperty("state") - public ConcurrentMap state() { + public Map state() { return state; } From 8ba4bfed3fa7045f3344329de7a39acddc64ee30 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 20 Mar 2026 05:09:58 -0700 Subject: [PATCH 26/40] feat: Update return type of App.plugins() from ImmutableList to List PiperOrigin-RevId: 886722180 --- core/src/main/java/com/google/adk/apps/App.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/java/com/google/adk/apps/App.java b/core/src/main/java/com/google/adk/apps/App.java index 897e24490..a087b738c 100644 --- a/core/src/main/java/com/google/adk/apps/App.java +++ 
b/core/src/main/java/com/google/adk/apps/App.java @@ -64,7 +64,7 @@ public BaseAgent rootAgent() { return rootAgent; } - public ImmutableList plugins() { + public List plugins() { return plugins; } From bf5ca82d5ad8adb3bdeb576d79f77eff6f9111d9 Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Fri, 20 Mar 2026 05:14:14 -0700 Subject: [PATCH 27/40] chore: update VersionTest to allow rc versions PiperOrigin-RevId: 886723823 --- core/src/test/java/com/google/adk/VersionTest.java | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/core/src/test/java/com/google/adk/VersionTest.java b/core/src/test/java/com/google/adk/VersionTest.java index ff7939165..4b6f55c9b 100644 --- a/core/src/test/java/com/google/adk/VersionTest.java +++ b/core/src/test/java/com/google/adk/VersionTest.java @@ -18,6 +18,7 @@ import static com.google.common.truth.Truth.assertThat; +import java.util.regex.Pattern; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -25,6 +26,11 @@ @RunWith(JUnit4.class) public class VersionTest { + // from semver.org + private static final Pattern SEM_VER = + Pattern.compile( + "^(0|[1-9]\\d*)\\.(0|[1-9]\\d*)\\.(0|[1-9]\\d*)(?:-((?:0|[1-9]\\d*|\\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\\.(?:0|[1-9]\\d*|\\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\\+([0-9a-zA-Z-]+(?:\\.[0-9a-zA-Z-]+)*))?$"); + @Test public void versionShouldMatchProjectVersion() { assertThat(Version.JAVA_ADK_VERSION).isNotNull(); @@ -32,6 +38,11 @@ public void versionShouldMatchProjectVersion() { assertThat(Version.JAVA_ADK_VERSION).isNotEqualTo("unknown"); assertThat(Version.JAVA_ADK_VERSION).isNotEqualTo("${project.version}"); - assertThat(Version.JAVA_ADK_VERSION).matches("\\d+\\.\\d+\\.\\d+(-SNAPSHOT)?"); + assertThat(Version.JAVA_ADK_VERSION).matches("\\d+\\.\\d+\\.\\d+(-SNAPSHOT|-rc\\.\\d+)?"); + } + + @Test + public void versionShouldFollowSemanticVersioning() { + assertThat(Version.JAVA_ADK_VERSION).matches(SEM_VER); } } From 
8af5e03811dfd548830df43103c81a592c8bf361 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 20 Mar 2026 05:17:08 -0700 Subject: [PATCH 28/40] feat: Return List instead of ImmutableList in CallbackUtil methods PiperOrigin-RevId: 886725038 --- core/src/main/java/com/google/adk/agents/BaseAgent.java | 6 ++++-- core/src/main/java/com/google/adk/agents/CallbackUtil.java | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index ed6631c50..95fe838cc 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -529,7 +529,8 @@ public B beforeAgentCallback(BeforeAgentCallback beforeAgentCallback) { @CanIgnoreReturnValue public B beforeAgentCallback(List beforeAgentCallback) { - this.beforeAgentCallback = CallbackUtil.getBeforeAgentCallbacks(beforeAgentCallback); + this.beforeAgentCallback = + ImmutableList.copyOf(CallbackUtil.getBeforeAgentCallbacks(beforeAgentCallback)); return self(); } @@ -541,7 +542,8 @@ public B afterAgentCallback(AfterAgentCallback afterAgentCallback) { @CanIgnoreReturnValue public B afterAgentCallback(List afterAgentCallback) { - this.afterAgentCallback = CallbackUtil.getAfterAgentCallbacks(afterAgentCallback); + this.afterAgentCallback = + ImmutableList.copyOf(CallbackUtil.getAfterAgentCallbacks(afterAgentCallback)); return self(); } diff --git a/core/src/main/java/com/google/adk/agents/CallbackUtil.java b/core/src/main/java/com/google/adk/agents/CallbackUtil.java index 11740ae9c..4eb8704b6 100644 --- a/core/src/main/java/com/google/adk/agents/CallbackUtil.java +++ b/core/src/main/java/com/google/adk/agents/CallbackUtil.java @@ -42,7 +42,7 @@ public final class CallbackUtil { * @return normalized async callbacks, or empty list if input is null. 
*/ @CanIgnoreReturnValue - public static ImmutableList getBeforeAgentCallbacks( + public static List getBeforeAgentCallbacks( List beforeAgentCallbacks) { return getCallbacks( beforeAgentCallbacks, @@ -59,7 +59,7 @@ public static ImmutableList getBeforeAgentCallbacks( * @return normalized async callbacks, or empty list if input is null. */ @CanIgnoreReturnValue - public static ImmutableList getAfterAgentCallbacks( + public static List getAfterAgentCallbacks( List afterAgentCallback) { return getCallbacks( afterAgentCallback, From f145c744482b6b25f29a0b718bd452065e39d930 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 20 Mar 2026 05:23:16 -0700 Subject: [PATCH 29/40] feat: update requestedAuthConfigs and its builder to be of general Map types PiperOrigin-RevId: 886727168 --- .../com/google/adk/events/EventActions.java | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/core/src/main/java/com/google/adk/events/EventActions.java b/core/src/main/java/com/google/adk/events/EventActions.java index 83fd60e54..105300aa0 100644 --- a/core/src/main/java/com/google/adk/events/EventActions.java +++ b/core/src/main/java/com/google/adk/events/EventActions.java @@ -41,7 +41,7 @@ public class EventActions extends JsonBaseModel { private Set deletedArtifactIds; private @Nullable String transferToAgent; private @Nullable Boolean escalate; - private ConcurrentMap> requestedAuthConfigs; + private ConcurrentMap> requestedAuthConfigs; private ConcurrentMap requestedToolConfirmations; private boolean endOfAgent; private @Nullable EventCompaction compaction; @@ -139,13 +139,17 @@ public void setEscalate(@Nullable Boolean escalate) { } @JsonProperty("requestedAuthConfigs") - public ConcurrentMap> requestedAuthConfigs() { + public Map> requestedAuthConfigs() { return requestedAuthConfigs; } public void setRequestedAuthConfigs( - ConcurrentMap> requestedAuthConfigs) { - this.requestedAuthConfigs = requestedAuthConfigs; + Map> 
requestedAuthConfigs) { + if (requestedAuthConfigs == null) { + this.requestedAuthConfigs = new ConcurrentHashMap<>(); + } else { + this.requestedAuthConfigs = new ConcurrentHashMap<>(requestedAuthConfigs); + } } @JsonProperty("requestedToolConfirmations") @@ -248,7 +252,7 @@ public static class Builder { private Set deletedArtifactIds; private @Nullable String transferToAgent; private @Nullable Boolean escalate; - private ConcurrentMap> requestedAuthConfigs; + private ConcurrentMap> requestedAuthConfigs; private ConcurrentMap requestedToolConfirmations; private boolean endOfAgent = false; private @Nullable EventCompaction compaction; @@ -328,8 +332,12 @@ public Builder escalate(boolean escalate) { @CanIgnoreReturnValue @JsonProperty("requestedAuthConfigs") public Builder requestedAuthConfigs( - ConcurrentMap> value) { - this.requestedAuthConfigs = value; + @Nullable Map> value) { + if (value == null) { + this.requestedAuthConfigs = new ConcurrentHashMap<>(); + } else { + this.requestedAuthConfigs = new ConcurrentHashMap<>(value); + } return this; } From f59215d94fa6732e275def543c68c23247b4b718 Mon Sep 17 00:00:00 2001 From: adk-java-releases-bot Date: Fri, 20 Mar 2026 13:24:24 +0100 Subject: [PATCH 30/40] chore(main): release 1.0.0-rc.1 --- .release-please-manifest.json | 2 +- CHANGELOG.md | 43 +++++++++++++++++++ README.md | 4 +- a2a/pom.xml | 2 +- contrib/firestore-session-service/pom.xml | 2 +- contrib/langchain4j/pom.xml | 2 +- contrib/samples/a2a_basic/pom.xml | 2 +- contrib/samples/a2a_server/pom.xml | 2 +- contrib/samples/configagent/pom.xml | 2 +- contrib/samples/helloworld/pom.xml | 2 +- contrib/samples/mcpfilesystem/pom.xml | 2 +- contrib/samples/pom.xml | 2 +- contrib/spring-ai/pom.xml | 2 +- core/pom.xml | 2 +- .../src/main/java/com/google/adk/Version.java | 2 +- dev/pom.xml | 2 +- maven_plugin/examples/custom_tools/pom.xml | 2 +- maven_plugin/examples/simple-agent/pom.xml | 2 +- maven_plugin/pom.xml | 2 +- pom.xml | 2 +- 
tutorials/city-time-weather/pom.xml | 2 +- tutorials/live-audio-single-agent/pom.xml | 2 +- 22 files changed, 65 insertions(+), 22 deletions(-) diff --git a/.release-please-manifest.json b/.release-please-manifest.json index 6db3039d0..802e9d13f 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "0.9.0" + ".": "1.0.0-rc.1" } diff --git a/CHANGELOG.md b/CHANGELOG.md index ab111e90c..4ef000794 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,48 @@ # Changelog +## [1.0.0-rc.1](https://github.com/google/adk-java/compare/v0.9.0...v1.0.0-rc.1) (2026-03-20) + + +### ⚠ BREAKING CHANGES + +* remove McpToolset constructors taking Optional parameters +* remove deprecated Example processor + +### Features + +* add handling the a2a metadata in the RemoteA2AAgent; Add the enum type for the metadata keys ([e51f911](https://github.com/google/adk-java/commit/e51f9112050955657da0dfc3aedc00f90ad739ec)) +* add type-safe runAsync methods to BaseTool ([b8cb7e2](https://github.com/google/adk-java/commit/b8cb7e2db6d5ce20f4d7a1b237bdc155563cf4bd)) +* Enhance LangChain4j to support MCP tools with parametersJsonSchema ([2c71ba1](https://github.com/google/adk-java/commit/2c71ba1332e052189115cd4644b7a473c31ed414)) +* fixing context propagation for agent transfers ([9a08076](https://github.com/google/adk-java/commit/9a080763d83c319f539d1bacac4595d13b299e7e)) +* Implement basic version of BigQuery Agent Analytics Plugin ([c8ab0f9](https://github.com/google/adk-java/commit/c8ab0f96b09a6c9636728d634c62695fcd622246)) +* init AGENTS.md file ([7ebeb07](https://github.com/google/adk-java/commit/7ebeb07bf2ee72475484d8a31ccf7b4c601dda96)) +* Propagating the otel context ([8556d4a](https://github.com/google/adk-java/commit/8556d4af16ff04c6e3b678dcfc3d4bb232abc550)) +* remove McpToolset constructors taking Optional parameters ([dbb1394](https://github.com/google/adk-java/commit/dbb139439d38157b4b9af38c52824b1e8405a495)) +* Return List instead of 
ImmutableList in CallbackUtil methods ([8af5e03](https://github.com/google/adk-java/commit/8af5e03811dfd548830df43103c81a592c8bf361)) +* update requestedAuthConfigs and its builder to be of general Map types ([f145c74](https://github.com/google/adk-java/commit/f145c744482b6b25f29a0b718bd452065e39d930)) +* Update return type of App.plugins() from ImmutableList to List ([8ba4bfe](https://github.com/google/adk-java/commit/8ba4bfed3fa7045f3344329de7a39acddc64ee30)) +* Update return type of toolsets() from ImmutableList to List ([cd56902](https://github.com/google/adk-java/commit/cd56902b803d4f7a1f3c718529842823d9e4370a)) +* update Session.state() and its builder to be of general Map types ([4b9b99a](https://github.com/google/adk-java/commit/4b9b99ae7149a465ba2ae9b7496e01f669786553)) +* update stateDelta builder input to Map from ConcurrentMap ([0d1e5c7](https://github.com/google/adk-java/commit/0d1e5c7b0c42cea66b178cf8fedf08a8c20f7fd0)) + + +### Bug Fixes + +* fix null handling in runAsyncImpl ([567fdf0](https://github.com/google/adk-java/commit/567fdf048fee49afc86ca5d7d35f55424a6016ba)) +* improve processRequest_concurrentReadAndWrite_noException test case ([4eb3613](https://github.com/google/adk-java/commit/4eb3613b65cb1334e9432960d0f864ef09829c23)) +* include saveArtifact invocations in event chain ([551c31f](https://github.com/google/adk-java/commit/551c31f495aafde8568461cc0aa0973d7df7e5ac)) +* prevent ConcurrentModificationException when session events are modified by another thread during iteration ([fca43fb](https://github.com/google/adk-java/commit/fca43fbb9684ec8d080e437761f6bb4e38adf255)) +* Relaxing constraints for output schema ([d7e03ee](https://github.com/google/adk-java/commit/d7e03eeb067b83abd2afa3ea9bb5fc1c16143245)) +* Removing deprecated methods in Runner ([0af82e6](https://github.com/google/adk-java/commit/0af82e61a3c0dbbd95166a10b450cb507115ab60)) +* Use ConcurrentHashMap in InvocationReplayState 
([94de7f1](https://github.com/google/adk-java/commit/94de7f199f86b39bdb7cce6e9800eb05008a8953)), closes [#1009](https://github.com/google/adk-java/issues/1009) +* workaround for the client config streaming settings are not respected ([#983](https://github.com/google/adk-java/issues/983)) ([3ba04d3](https://github.com/google/adk-java/commit/3ba04d33dc8f2ef8b151abe1be4d1c8b7afcc25a)) + + +### Miscellaneous Chores + +* remove deprecated Example processor ([28a8cd0](https://github.com/google/adk-java/commit/28a8cd04ca9348dbe51a15d2be3a2b5307394174)) +* set version to 1.0.0-rc.1 ([dc5d794](https://github.com/google/adk-java/commit/dc5d794c066571c7d87f006767bd32298e2a3ba8)) + ## [0.9.0](https://github.com/google/adk-java/compare/v0.8.0...v0.9.0) (2026-03-13) diff --git a/README.md b/README.md index de1cfbef7..169c78564 100644 --- a/README.md +++ b/README.md @@ -50,13 +50,13 @@ If you're using Maven, add the following to your dependencies: com.google.adk google-adk - 0.9.0 + 1.0.0-rc.1 com.google.adk google-adk-dev - 0.9.0 + 1.0.0-rc.1 ``` diff --git a/a2a/pom.xml b/a2a/pom.xml index a2f9d9456..bb9d19328 100644 --- a/a2a/pom.xml +++ b/a2a/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-parent - 0.9.1-SNAPSHOT + 1.0.0-rc.1 google-adk-a2a diff --git a/contrib/firestore-session-service/pom.xml b/contrib/firestore-session-service/pom.xml index 0079dce24..e01ebfeae 100644 --- a/contrib/firestore-session-service/pom.xml +++ b/contrib/firestore-session-service/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.9.1-SNAPSHOT + 1.0.0-rc.1 ../../pom.xml diff --git a/contrib/langchain4j/pom.xml b/contrib/langchain4j/pom.xml index 3dd2d1132..10d5d9eb9 100644 --- a/contrib/langchain4j/pom.xml +++ b/contrib/langchain4j/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.9.1-SNAPSHOT + 1.0.0-rc.1 ../../pom.xml diff --git a/contrib/samples/a2a_basic/pom.xml b/contrib/samples/a2a_basic/pom.xml index 0eccb733b..22d146c03 100644 --- 
a/contrib/samples/a2a_basic/pom.xml +++ b/contrib/samples/a2a_basic/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-samples - 0.9.1-SNAPSHOT + 1.0.0-rc.1 .. diff --git a/contrib/samples/a2a_server/pom.xml b/contrib/samples/a2a_server/pom.xml index 0677ad718..8cebc7ef0 100644 --- a/contrib/samples/a2a_server/pom.xml +++ b/contrib/samples/a2a_server/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-samples - 0.9.1-SNAPSHOT + 1.0.0-rc.1 .. diff --git a/contrib/samples/configagent/pom.xml b/contrib/samples/configagent/pom.xml index 059bd8a38..0486fc639 100644 --- a/contrib/samples/configagent/pom.xml +++ b/contrib/samples/configagent/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-samples - 0.9.1-SNAPSHOT + 1.0.0-rc.1 .. diff --git a/contrib/samples/helloworld/pom.xml b/contrib/samples/helloworld/pom.xml index df5d5e709..39191782b 100644 --- a/contrib/samples/helloworld/pom.xml +++ b/contrib/samples/helloworld/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-samples - 0.9.1-SNAPSHOT + 1.0.0-rc.1 .. diff --git a/contrib/samples/mcpfilesystem/pom.xml b/contrib/samples/mcpfilesystem/pom.xml index 16b139d35..aa1a76333 100644 --- a/contrib/samples/mcpfilesystem/pom.xml +++ b/contrib/samples/mcpfilesystem/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.9.1-SNAPSHOT + 1.0.0-rc.1 ../../.. diff --git a/contrib/samples/pom.xml b/contrib/samples/pom.xml index 4a415113f..cd96853cb 100644 --- a/contrib/samples/pom.xml +++ b/contrib/samples/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-parent - 0.9.1-SNAPSHOT + 1.0.0-rc.1 ../.. 
diff --git a/contrib/spring-ai/pom.xml b/contrib/spring-ai/pom.xml index 08d237ab5..7e99da1c2 100644 --- a/contrib/spring-ai/pom.xml +++ b/contrib/spring-ai/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.9.1-SNAPSHOT + 1.0.0-rc.1 ../../pom.xml diff --git a/core/pom.xml b/core/pom.xml index 02c75f88b..b2776f2be 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.9.1-SNAPSHOT + 1.0.0-rc.1 google-adk diff --git a/core/src/main/java/com/google/adk/Version.java b/core/src/main/java/com/google/adk/Version.java index a7aeb8b1f..d6e18c5ad 100644 --- a/core/src/main/java/com/google/adk/Version.java +++ b/core/src/main/java/com/google/adk/Version.java @@ -22,7 +22,7 @@ */ public final class Version { // Don't touch this, release-please should keep it up to date. - public static final String JAVA_ADK_VERSION = "0.9.0"; // x-release-please-released-version + public static final String JAVA_ADK_VERSION = "1.0.0-rc.1"; // x-release-please-released-version private Version() {} } diff --git a/dev/pom.xml b/dev/pom.xml index 6cabcba7c..7f9408ae7 100644 --- a/dev/pom.xml +++ b/dev/pom.xml @@ -18,7 +18,7 @@ com.google.adk google-adk-parent - 0.9.1-SNAPSHOT + 1.0.0-rc.1 google-adk-dev diff --git a/maven_plugin/examples/custom_tools/pom.xml b/maven_plugin/examples/custom_tools/pom.xml index f2118f9cc..b76d04309 100644 --- a/maven_plugin/examples/custom_tools/pom.xml +++ b/maven_plugin/examples/custom_tools/pom.xml @@ -4,7 +4,7 @@ com.example custom-tools-example - 0.9.1-SNAPSHOT + 1.0.0-rc.1 jar ADK Custom Tools Example diff --git a/maven_plugin/examples/simple-agent/pom.xml b/maven_plugin/examples/simple-agent/pom.xml index 5c0f4462d..e65412ea8 100644 --- a/maven_plugin/examples/simple-agent/pom.xml +++ b/maven_plugin/examples/simple-agent/pom.xml @@ -4,7 +4,7 @@ com.example simple-adk-agent - 0.9.1-SNAPSHOT + 1.0.0-rc.1 jar Simple ADK Agent Example diff --git a/maven_plugin/pom.xml b/maven_plugin/pom.xml index 
c48331f72..381e03a9e 100644 --- a/maven_plugin/pom.xml +++ b/maven_plugin/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-parent - 0.9.1-SNAPSHOT + 1.0.0-rc.1 ../pom.xml diff --git a/pom.xml b/pom.xml index 0be05a629..368bf67c8 100644 --- a/pom.xml +++ b/pom.xml @@ -17,7 +17,7 @@ com.google.adk google-adk-parent - 0.9.1-SNAPSHOT + 1.0.0-rc.1 pom Google Agent Development Kit Maven Parent POM diff --git a/tutorials/city-time-weather/pom.xml b/tutorials/city-time-weather/pom.xml index 76b7331f3..e2e179e48 100644 --- a/tutorials/city-time-weather/pom.xml +++ b/tutorials/city-time-weather/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.9.1-SNAPSHOT + 1.0.0-rc.1 ../../pom.xml diff --git a/tutorials/live-audio-single-agent/pom.xml b/tutorials/live-audio-single-agent/pom.xml index a330cf4bd..8b5e97fb5 100644 --- a/tutorials/live-audio-single-agent/pom.xml +++ b/tutorials/live-audio-single-agent/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.9.1-SNAPSHOT + 1.0.0-rc.1 ../../pom.xml From f3eb936772740b7dc7a803a40d0d39fdbccc4af4 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 20 Mar 2026 06:59:40 -0700 Subject: [PATCH 31/40] fix: Using App conformant agent names PiperOrigin-RevId: 886764077 --- .../adk/agents/AgentWithMemoryTest.java | 4 +-- .../adk/telemetry/ContextPropagationTest.java | 21 +++++++------- .../com/google/adk/testing/TestUtils.java | 2 +- .../com/google/adk/tools/AgentToolTest.java | 28 +++++++++---------- 4 files changed, 27 insertions(+), 28 deletions(-) diff --git a/core/src/test/java/com/google/adk/agents/AgentWithMemoryTest.java b/core/src/test/java/com/google/adk/agents/AgentWithMemoryTest.java index d5edfc876..361c5eb6b 100644 --- a/core/src/test/java/com/google/adk/agents/AgentWithMemoryTest.java +++ b/core/src/test/java/com/google/adk/agents/AgentWithMemoryTest.java @@ -41,7 +41,7 @@ public final class AgentWithMemoryTest { @Test public void agentRemembersUserNameWithMemoryTool() throws Exception { String 
userId = "test-user"; - String agentName = "test-agent"; + String agentName = "test_agent"; Part functionCall = Part.builder() @@ -101,7 +101,7 @@ public void agentRemembersUserNameWithMemoryTool() throws Exception { Session updatedSession = runner .sessionService() - .getSession("test-agent", userId, sessionId, Optional.empty()) + .getSession("test_agent", userId, sessionId, Optional.empty()) .blockingGet(); // Save the updated session to memory so we can bring it up on the next request. diff --git a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java index 1ee018848..b13904934 100644 --- a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java +++ b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java @@ -32,7 +32,6 @@ import com.google.adk.runner.Runner; import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; -import com.google.adk.sessions.SessionKey; import com.google.adk.testing.TestLlm; import com.google.adk.testing.TestUtils; import com.google.adk.tools.BaseTool; @@ -653,13 +652,13 @@ public void runnerRunAsync_propagatesContext() throws InterruptedException { @Test public void runnerRunLive_propagatesContext() throws InterruptedException { BaseAgent agent = new TestAgent(); - Runner runner = Runner.builder().agent(agent).appName("test-app").build(); + Runner runner = + Runner.builder().agent(agent).appName("test_app").sessionService(sessionService).build(); Span parentSpan = tracer.spanBuilder("parent").startSpan(); try (Scope s = parentSpan.makeCurrent()) { Session session = - runner - .sessionService() - .createSession("test-app", "test-user", (Map) null, "test-session") + sessionService + .createSession("test_app", "test-user", (Map) null, "test-session") .blockingGet(); Content newMessage = Content.fromParts(Part.fromText("hi")); RunConfig runConfig = RunConfig.builder().build(); @@ 
-800,12 +799,10 @@ public void testNestedAgentTraceHierarchy() throws InterruptedException { } private void runAgent(BaseAgent agent) throws InterruptedException { - Runner runner = Runner.builder().agent(agent).appName("test-app").build(); + Runner runner = + Runner.builder().agent(agent).appName("test_app").sessionService(sessionService).build(); Session session = - runner - .sessionService() - .createSession(new SessionKey("test-app", "test-user", "test-session")) - .blockingGet(); + sessionService.createSession("test_app", "test-user", null, "test-session").blockingGet(); Content newMessage = Content.fromParts(Part.fromText("hi")); RunConfig runConfig = RunConfig.builder().build(); runner @@ -871,7 +868,9 @@ private SpanData findSpanByName(String name) { private InvocationContext buildInvocationContext() { Session session = - sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet(); + sessionService + .createSession("test_app", "test-user", (Map) null, "test-session") + .blockingGet(); return InvocationContext.builder() .sessionService(sessionService) .session(session) diff --git a/core/src/test/java/com/google/adk/testing/TestUtils.java b/core/src/test/java/com/google/adk/testing/TestUtils.java index 70ae14bf1..daed8d2e4 100644 --- a/core/src/test/java/com/google/adk/testing/TestUtils.java +++ b/core/src/test/java/com/google/adk/testing/TestUtils.java @@ -61,7 +61,7 @@ public static InvocationContext createInvocationContext(BaseAgent agent, RunConf .artifactService(new InMemoryArtifactService()) .invocationId("invocationId") .agent(agent) - .session(sessionService.createSession("test-app", "test-user").blockingGet()) + .session(sessionService.createSession("test_app", "test-user").blockingGet()) .userContent(Content.fromParts(Part.fromText("user content"))) .runConfig(runConfig) .build(); diff --git a/core/src/test/java/com/google/adk/tools/AgentToolTest.java b/core/src/test/java/com/google/adk/tools/AgentToolTest.java index 
3a5390027..0f168c5df 100644 --- a/core/src/test/java/com/google/adk/tools/AgentToolTest.java +++ b/core/src/test/java/com/google/adk/tools/AgentToolTest.java @@ -143,7 +143,7 @@ public void declaration_withInputSchema_returnsDeclarationWithSchema() { AgentTool agentTool = AgentTool.create( createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) - .name("agent name") + .name("agent_name") .description("agent description") .inputSchema(inputSchema) .build()); @@ -153,7 +153,7 @@ public void declaration_withInputSchema_returnsDeclarationWithSchema() { assertThat(declaration) .isEqualTo( FunctionDeclaration.builder() - .name("agent name") + .name("agent_name") .description("agent description") .parameters(inputSchema) .build()); @@ -164,7 +164,7 @@ public void declaration_withoutInputSchema_returnsDeclarationWithRequestParamete AgentTool agentTool = AgentTool.create( createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) - .name("agent name") + .name("agent_name") .description("agent description") .build()); @@ -173,7 +173,7 @@ public void declaration_withoutInputSchema_returnsDeclarationWithRequestParamete assertThat(declaration) .isEqualTo( FunctionDeclaration.builder() - .name("agent name") + .name("agent_name") .description("agent description") .parameters( Schema.builder() @@ -200,7 +200,7 @@ public void call_withInputSchema_invalidInput_throwsException() throws Exception .build(); LlmAgent testAgent = createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) - .name("agent name") + .name("agent_name") .description("agent description") .inputSchema(inputSchema) .build(); @@ -256,7 +256,7 @@ public void call_withOutputSchema_invalidOutput_throwsException() throws Excepti "{\"is_valid\": \"invalid type\", " + "\"message\": \"success\"}"))) .build())) - .name("agent name") + .name("agent_name") .description("agent description") .outputSchema(outputSchema) .build(); @@ -301,7 +301,7 @@ public void 
call_withInputAndOutputSchema_successful() throws Exception { Part.fromText( "{\"is_valid\": true, " + "\"message\": \"success\"}"))) .build())) - .name("agent name") + .name("agent_name") .description("agent description") .inputSchema(inputSchema) .outputSchema(outputSchema) @@ -332,7 +332,7 @@ public void call_withoutSchema_returnsConcatenatedTextFromLastEvent() throws Exc Part.fromText("First text part. "), Part.fromText("Second text part."))) .build()))) - .name("agent name") + .name("agent_name") .description("agent description") .build(); AgentTool agentTool = AgentTool.create(testAgent); @@ -358,7 +358,7 @@ public void call_withThoughts_returnsOnlyNonThoughtText() throws Exception { .build()) .build()); LlmAgent testAgent = - createTestAgentBuilder(testLlm).name("agent name").description("agent description").build(); + createTestAgentBuilder(testLlm).name("agent_name").description("agent description").build(); AgentTool agentTool = AgentTool.create(testAgent); ToolContext toolContext = createToolContext(testAgent); @@ -373,7 +373,7 @@ public void call_emptyModelResponse_returnsEmptyMap() throws Exception { LlmAgent testAgent = createTestAgentBuilder( createTestLlm(LlmResponse.builder().content(Content.builder().build()).build())) - .name("agent name") + .name("agent_name") .description("agent description") .build(); AgentTool agentTool = AgentTool.create(testAgent); @@ -394,7 +394,7 @@ public void call_withInputSchema_argsAreSentToAgent() throws Exception { .build()); LlmAgent testAgent = createTestAgentBuilder(testLlm) - .name("agent name") + .name("agent_name") .description("agent description") .inputSchema( Schema.builder() @@ -422,7 +422,7 @@ public void call_withoutInputSchema_requestIsSentToAgent() throws Exception { .content(Content.fromParts(Part.fromText("test response"))) .build()); LlmAgent testAgent = - createTestAgentBuilder(testLlm).name("agent name").description("agent description").build(); + 
createTestAgentBuilder(testLlm).name("agent_name").description("agent description").build(); AgentTool agentTool = AgentTool.create(testAgent); ToolContext toolContext = createToolContext(testAgent); @@ -447,7 +447,7 @@ public void call_withStateDeltaInResponse_propagatesStateDelta() throws Exceptio .build()); LlmAgent testAgent = createTestAgentBuilder(testLlm) - .name("agent name") + .name("agent_name") .description("agent description") .afterAgentCallback(afterAgentCallback) .build(); @@ -477,7 +477,7 @@ public void call_withSkipSummarizationAndStateDelta_propagatesStateAndSetsSkipSu .build()); LlmAgent testAgent = createTestAgentBuilder(testLlm) - .name("agent name") + .name("agent_name") .description("agent description") .afterAgentCallback(afterAgentCallback) .build(); From 8ea81f7575580a7316752adc2908efe442511b16 Mon Sep 17 00:00:00 2001 From: adk-java-releases-bot Date: Fri, 20 Mar 2026 07:00:31 -0700 Subject: [PATCH 32/40] chore(main): release 1.0.1-rc.1-SNAPSHOT Merge https://github.com/google/adk-java/pull/1067 :robot: I have created a release *beep* *boop* --- ### Updating meta-information for bleeding-edge SNAPSHOT release. --- This PR was generated with [Release Please](https://github.com/googleapis/release-please). See [documentation](https://github.com/googleapis/release-please#release-please). 
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-java/pull/1067 from google:release-please--branches--main 0b85f5baba4423b6e73b6dfe2982d8fe22411a98 PiperOrigin-RevId: 886764460 --- a2a/pom.xml | 2 +- contrib/firestore-session-service/pom.xml | 2 +- contrib/langchain4j/pom.xml | 2 +- contrib/samples/a2a_basic/pom.xml | 2 +- contrib/samples/a2a_server/pom.xml | 2 +- contrib/samples/configagent/pom.xml | 2 +- contrib/samples/helloworld/pom.xml | 2 +- contrib/samples/mcpfilesystem/pom.xml | 2 +- contrib/samples/pom.xml | 2 +- contrib/spring-ai/pom.xml | 2 +- core/pom.xml | 2 +- dev/pom.xml | 2 +- maven_plugin/examples/custom_tools/pom.xml | 2 +- maven_plugin/examples/simple-agent/pom.xml | 2 +- maven_plugin/pom.xml | 2 +- pom.xml | 2 +- tutorials/city-time-weather/pom.xml | 2 +- tutorials/live-audio-single-agent/pom.xml | 2 +- 18 files changed, 18 insertions(+), 18 deletions(-) diff --git a/a2a/pom.xml b/a2a/pom.xml index bb9d19328..97c186606 100644 --- a/a2a/pom.xml +++ b/a2a/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-parent - 1.0.0-rc.1 + 1.0.1-rc.1-SNAPSHOT google-adk-a2a diff --git a/contrib/firestore-session-service/pom.xml b/contrib/firestore-session-service/pom.xml index e01ebfeae..ed1ecd09b 100644 --- a/contrib/firestore-session-service/pom.xml +++ b/contrib/firestore-session-service/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 1.0.0-rc.1 + 1.0.1-rc.1-SNAPSHOT ../../pom.xml diff --git a/contrib/langchain4j/pom.xml b/contrib/langchain4j/pom.xml index 10d5d9eb9..e88174849 100644 --- a/contrib/langchain4j/pom.xml +++ b/contrib/langchain4j/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 1.0.0-rc.1 + 1.0.1-rc.1-SNAPSHOT ../../pom.xml diff --git a/contrib/samples/a2a_basic/pom.xml b/contrib/samples/a2a_basic/pom.xml index 22d146c03..80881842b 100644 --- a/contrib/samples/a2a_basic/pom.xml +++ b/contrib/samples/a2a_basic/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-samples - 1.0.0-rc.1 + 1.0.1-rc.1-SNAPSHOT .. 
diff --git a/contrib/samples/a2a_server/pom.xml b/contrib/samples/a2a_server/pom.xml index 8cebc7ef0..61c44bc97 100644 --- a/contrib/samples/a2a_server/pom.xml +++ b/contrib/samples/a2a_server/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-samples - 1.0.0-rc.1 + 1.0.1-rc.1-SNAPSHOT .. diff --git a/contrib/samples/configagent/pom.xml b/contrib/samples/configagent/pom.xml index 0486fc639..8f57b7f9e 100644 --- a/contrib/samples/configagent/pom.xml +++ b/contrib/samples/configagent/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-samples - 1.0.0-rc.1 + 1.0.1-rc.1-SNAPSHOT .. diff --git a/contrib/samples/helloworld/pom.xml b/contrib/samples/helloworld/pom.xml index 39191782b..676a2bc96 100644 --- a/contrib/samples/helloworld/pom.xml +++ b/contrib/samples/helloworld/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-samples - 1.0.0-rc.1 + 1.0.1-rc.1-SNAPSHOT .. diff --git a/contrib/samples/mcpfilesystem/pom.xml b/contrib/samples/mcpfilesystem/pom.xml index aa1a76333..7275313ab 100644 --- a/contrib/samples/mcpfilesystem/pom.xml +++ b/contrib/samples/mcpfilesystem/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 1.0.0-rc.1 + 1.0.1-rc.1-SNAPSHOT ../../.. diff --git a/contrib/samples/pom.xml b/contrib/samples/pom.xml index cd96853cb..ff48d6bd3 100644 --- a/contrib/samples/pom.xml +++ b/contrib/samples/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-parent - 1.0.0-rc.1 + 1.0.1-rc.1-SNAPSHOT ../.. 
diff --git a/contrib/spring-ai/pom.xml b/contrib/spring-ai/pom.xml index 7e99da1c2..5f7300896 100644 --- a/contrib/spring-ai/pom.xml +++ b/contrib/spring-ai/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 1.0.0-rc.1 + 1.0.1-rc.1-SNAPSHOT ../../pom.xml diff --git a/core/pom.xml b/core/pom.xml index b2776f2be..8559f396d 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 1.0.0-rc.1 + 1.0.1-rc.1-SNAPSHOT google-adk diff --git a/dev/pom.xml b/dev/pom.xml index 7f9408ae7..5468a1187 100644 --- a/dev/pom.xml +++ b/dev/pom.xml @@ -18,7 +18,7 @@ com.google.adk google-adk-parent - 1.0.0-rc.1 + 1.0.1-rc.1-SNAPSHOT google-adk-dev diff --git a/maven_plugin/examples/custom_tools/pom.xml b/maven_plugin/examples/custom_tools/pom.xml index b76d04309..aa273d732 100644 --- a/maven_plugin/examples/custom_tools/pom.xml +++ b/maven_plugin/examples/custom_tools/pom.xml @@ -4,7 +4,7 @@ com.example custom-tools-example - 1.0.0-rc.1 + 1.0.1-rc.1-SNAPSHOT jar ADK Custom Tools Example diff --git a/maven_plugin/examples/simple-agent/pom.xml b/maven_plugin/examples/simple-agent/pom.xml index e65412ea8..34aeb8c1c 100644 --- a/maven_plugin/examples/simple-agent/pom.xml +++ b/maven_plugin/examples/simple-agent/pom.xml @@ -4,7 +4,7 @@ com.example simple-adk-agent - 1.0.0-rc.1 + 1.0.1-rc.1-SNAPSHOT jar Simple ADK Agent Example diff --git a/maven_plugin/pom.xml b/maven_plugin/pom.xml index 381e03a9e..d0feb41e3 100644 --- a/maven_plugin/pom.xml +++ b/maven_plugin/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-parent - 1.0.0-rc.1 + 1.0.1-rc.1-SNAPSHOT ../pom.xml diff --git a/pom.xml b/pom.xml index 368bf67c8..cbeca1b72 100644 --- a/pom.xml +++ b/pom.xml @@ -17,7 +17,7 @@ com.google.adk google-adk-parent - 1.0.0-rc.1 + 1.0.1-rc.1-SNAPSHOT pom Google Agent Development Kit Maven Parent POM diff --git a/tutorials/city-time-weather/pom.xml b/tutorials/city-time-weather/pom.xml index e2e179e48..aeb110cf6 100644 --- 
a/tutorials/city-time-weather/pom.xml +++ b/tutorials/city-time-weather/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 1.0.0-rc.1 + 1.0.1-rc.1-SNAPSHOT ../../pom.xml diff --git a/tutorials/live-audio-single-agent/pom.xml b/tutorials/live-audio-single-agent/pom.xml index 8b5e97fb5..99243893b 100644 --- a/tutorials/live-audio-single-agent/pom.xml +++ b/tutorials/live-audio-single-agent/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 1.0.0-rc.1 + 1.0.1-rc.1-SNAPSHOT ../../pom.xml From 40ca6a7c5163f711e02a54163d6066f7cd86e64d Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 20 Mar 2026 07:11:54 -0700 Subject: [PATCH 33/40] feat: enabling output_schema and tools to coexist This CL enables the simultaneous use of `output_schema` (structured output) and `tools` for models that do not natively support both features at once (specifically Gemini 1.x and 2.x on Vertex AI). ### Core Logic The CL implements a workaround for models with this limitation: 1. **Synthetic Tooling**: Instead of passing the `output_schema` directly to the model's configuration, it introduces a synthetic tool called `set_model_response`. 2. **Schema Injection**: The parameters of this tool are set to the requested `output_schema`. 3. **Instruction Prompting**: System instructions are appended, directing the model to provide its final response using this specific tool in the required format. 4. **Response Interception**: The `BaseLlmFlow` is updated to check if `set_model_response` was called. If so, it extracts the JSON arguments and converts them into a standard model response event. ### Key Changes * **`OutputSchema.java` (New)**: A new `RequestProcessor` that detects when the workaround is needed, adds the `SetModelResponseTool`, and provides utilities for extracting the structured response. * **`SetModelResponseTool.java` (New)**: A marker tool that simply returns its input arguments, used to "capture" the structured output from the model. 
* **`ModelNameUtils.java`**: Added logic to identify Gemini 1.x and 2.x models and determine if they can handle native `output_schema` alongside tools. * **`BaseLlmFlow.java`**: Updated the flow logic to detect the synthetic tool response and generate the final output event. * **`Basic.java`**: Updated to prevent native `outputSchema` configuration when the workaround is active. * **`SingleFlow.java`**: Registered the new `OutputSchema` processor. PiperOrigin-RevId: 886769688 --- .../adk/flows/llmflows/BaseLlmFlow.java | 11 +- .../com/google/adk/flows/llmflows/Basic.java | 11 +- .../adk/flows/llmflows/OutputSchema.java | 119 +++++++++++ .../google/adk/flows/llmflows/SingleFlow.java | 1 + .../adk/tools/SetModelResponseTool.java | 63 ++++++ .../com/google/adk/utils/ModelNameUtils.java | 18 +- .../adk/flows/llmflows/OutputSchemaTest.java | 192 ++++++++++++++++++ 7 files changed, 410 insertions(+), 5 deletions(-) create mode 100644 core/src/main/java/com/google/adk/flows/llmflows/OutputSchema.java create mode 100644 core/src/main/java/com/google/adk/tools/SetModelResponseTool.java create mode 100644 core/src/test/java/com/google/adk/flows/llmflows/OutputSchemaTest.java diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index 8fabc978d..d4fe1b838 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -692,9 +692,14 @@ private Flowable buildPostprocessingEvents( Optional toolConfirmationEvent = Functions.generateRequestConfirmationEvent( context, modelResponseEvent, functionResponseEvent); - return toolConfirmationEvent.isPresent() - ? 
Flowable.just(toolConfirmationEvent.get(), functionResponseEvent) - : Flowable.just(functionResponseEvent); + List events = new ArrayList<>(); + toolConfirmationEvent.ifPresent(events::add); + events.add(functionResponseEvent); + OutputSchema.getStructuredModelResponse(functionResponseEvent) + .ifPresent( + json -> + events.add(OutputSchema.createFinalModelResponseEvent(context, json))); + return Flowable.fromIterable(events); }); } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Basic.java b/core/src/main/java/com/google/adk/flows/llmflows/Basic.java index 0876a26e8..5aa970be6 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Basic.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Basic.java @@ -19,6 +19,7 @@ import com.google.adk.agents.InvocationContext; import com.google.adk.agents.LlmAgent; import com.google.adk.models.LlmRequest; +import com.google.adk.utils.ModelNameUtils; import com.google.common.collect.ImmutableList; import com.google.genai.types.GenerateContentConfig; import com.google.genai.types.LiveConnectConfig; @@ -60,7 +61,15 @@ public Single processRequest( .orElseGet(() -> GenerateContentConfig.builder().build())) .liveConnectConfig(liveConnectConfigBuilder.build()); - agent.outputSchema().ifPresent(builder::outputSchema); + agent + .outputSchema() + .ifPresent( + outputSchema -> { + if (agent.toolsUnion().isEmpty() + || ModelNameUtils.canUseOutputSchemaWithTools(modelName)) { + builder.outputSchema(outputSchema); + } + }); return Single.just( RequestProcessor.RequestProcessingResult.create(builder.build(), ImmutableList.of())); } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/OutputSchema.java b/core/src/main/java/com/google/adk/flows/llmflows/OutputSchema.java new file mode 100644 index 000000000..d1f322f18 --- /dev/null +++ b/core/src/main/java/com/google/adk/flows/llmflows/OutputSchema.java @@ -0,0 +1,119 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 
2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.flows.llmflows; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.google.adk.JsonBaseModel; +import com.google.adk.agents.InvocationContext; +import com.google.adk.agents.LlmAgent; +import com.google.adk.events.Event; +import com.google.adk.models.LlmRequest; +import com.google.adk.tools.SetModelResponseTool; +import com.google.adk.tools.ToolContext; +import com.google.adk.utils.ModelNameUtils; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.Content; +import com.google.genai.types.FunctionResponse; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Single; +import java.util.Objects; +import java.util.Optional; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Processor that handles output schema for agents with tools. 
*/ +public final class OutputSchema implements RequestProcessor { + + private static final Logger logger = LoggerFactory.getLogger(OutputSchema.class); + + public OutputSchema() {} + + @Override + public Single processRequest( + InvocationContext context, LlmRequest request) { + if (!(context.agent() instanceof LlmAgent)) { + return Single.just(RequestProcessingResult.create(request, ImmutableList.of())); + } + LlmAgent agent = (LlmAgent) context.agent(); + String modelName = request.model().orElse(""); + + if (agent.outputSchema().isEmpty() + || agent.toolsUnion().isEmpty() + || ModelNameUtils.canUseOutputSchemaWithTools(modelName)) { + return Single.just(RequestProcessingResult.create(request, ImmutableList.of())); + } + + // Add the set_model_response tool to handle structured output + SetModelResponseTool setResponseTool = new SetModelResponseTool(agent.outputSchema().get()); + LlmRequest.Builder builder = request.toBuilder(); + + return setResponseTool + .processLlmRequest(builder, ToolContext.builder(context).build()) + .andThen( + Single.fromCallable( + () -> { + builder.appendInstructions( + ImmutableList.of( + "IMPORTANT: You have access to other tools, but you must provide your" + + " final response using the set_model_response tool with the" + + " required structured format. After using any other tools needed" + + " to complete the task, always call set_model_response with your" + + " final answer in the specified schema format.")); + return RequestProcessingResult.create(builder.build(), ImmutableList.of()); + })); + } + + /** + * Check if function response contains set_model_response and extract JSON. + * + * @param functionResponseEvent The function response event to check. + * @return JSON response string if set_model_response was called, Optional.empty() otherwise. 
+ */ + public static Optional getStructuredModelResponse(Event functionResponseEvent) { + for (FunctionResponse funcResponse : functionResponseEvent.functionResponses()) { + if (Objects.equals(funcResponse.name().orElse(""), SetModelResponseTool.NAME)) { + Object response = funcResponse.response(); + // The tool returns the args map directly. + try { + return Optional.of(JsonBaseModel.getMapper().writeValueAsString(response)); + } catch (JsonProcessingException e) { + logger.error("Failed to serialize set_model_response result", e); + return Optional.empty(); + } + } + } + return Optional.empty(); + } + + /** + * Create a final model response event from set_model_response JSON. + * + * @param context The invocation context. + * @param jsonResponse The JSON response from set_model_response tool. + * @return A new Event that looks like a normal model response. + */ + public static Event createFinalModelResponseEvent( + InvocationContext context, String jsonResponse) { + return Event.builder() + .id(Event.generateEventId()) + .invocationId(context.invocationId()) + .author(context.agent().name()) + .branch(context.branch().orElse(null)) + .content(Content.builder().role("model").parts(Part.fromText(jsonResponse)).build()) + .build(); + } +} diff --git a/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java index f56cc61c3..41dff3b96 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java @@ -27,6 +27,7 @@ public class SingleFlow extends BaseLlmFlow { protected static final ImmutableList REQUEST_PROCESSORS = ImmutableList.of( new Basic(), + new OutputSchema(), new RequestConfirmationLlmRequestProcessor(), new Instructions(), new Identity(), diff --git a/core/src/main/java/com/google/adk/tools/SetModelResponseTool.java b/core/src/main/java/com/google/adk/tools/SetModelResponseTool.java new file mode 
100644 index 000000000..e23d6414a --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/SetModelResponseTool.java @@ -0,0 +1,63 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools; + +import com.google.genai.types.FunctionDeclaration; +import com.google.genai.types.Schema; +import io.reactivex.rxjava3.core.Single; +import java.util.Map; +import java.util.Optional; +import javax.annotation.Nonnull; + +/** + * Internal tool used for output schema workaround. + * + *

This tool allows the model to set its final response when output_schema is configured + * alongside other tools. The model should use this tool to provide its final structured response + * instead of outputting text directly. + */ +public class SetModelResponseTool extends BaseTool { + public static final String NAME = "set_model_response"; + + private final Schema outputSchema; + + public SetModelResponseTool(@Nonnull Schema outputSchema) { + super( + NAME, + "Set your final response using the required output schema. " + + "After using any other tools needed to complete the task, always call" + + " set_model_response with your final answer in the specified schema format."); + this.outputSchema = outputSchema; + } + + @Override + public Optional declaration() { + return Optional.of( + FunctionDeclaration.builder() + .name(name()) + .description(description()) + .parameters(outputSchema) + .build()); + } + + @Override + public Single> runAsync(Map args, ToolContext toolContext) { + // This tool is a marker for the final response, it doesn't do anything but return its arguments + // which will be captured as the final result. + return Single.just(args); + } +} diff --git a/core/src/main/java/com/google/adk/utils/ModelNameUtils.java b/core/src/main/java/com/google/adk/utils/ModelNameUtils.java index c46f6e3a8..cf0f2221e 100644 --- a/core/src/main/java/com/google/adk/utils/ModelNameUtils.java +++ b/core/src/main/java/com/google/adk/utils/ModelNameUtils.java @@ -21,6 +21,7 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; +/** Utility class for model names. 
*/ public final class ModelNameUtils { private static final String GEMINI_PREFIX = "gemini-"; private static final Pattern GEMINI_2_PATTERN = Pattern.compile("^gemini-2\\..*"); @@ -35,11 +36,15 @@ public static boolean isGeminiModel(String modelString) { } public static boolean isGemini2Model(String modelString) { + return matchesModelPattern(modelString, GEMINI_2_PATTERN); + } + + private static boolean matchesModelPattern(String modelString, Pattern pattern) { if (modelString == null) { return false; } String modelName = extractModelName(modelString); - return GEMINI_2_PATTERN.matcher(modelName).matches(); + return pattern.matcher(modelName).matches(); } /** @@ -65,6 +70,17 @@ public static boolean isInstanceOfGemini(Object o) { return false; } + /** + * Returns true if the model supports using output schema together with tools. + * + * @param modelString The model name or path. + * @return true if output schema with tools is supported, false otherwise. + */ + public static boolean canUseOutputSchemaWithTools(String modelString) { + // Current limitation for Vertex AI 2.x models. + return !isGemini2Model(modelString); + } + /** * Extract the actual model name from either simple or path-based format. * diff --git a/core/src/test/java/com/google/adk/flows/llmflows/OutputSchemaTest.java b/core/src/test/java/com/google/adk/flows/llmflows/OutputSchemaTest.java new file mode 100644 index 000000000..ffd56de6c --- /dev/null +++ b/core/src/test/java/com/google/adk/flows/llmflows/OutputSchemaTest.java @@ -0,0 +1,192 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.flows.llmflows; + +import static com.google.adk.testing.TestUtils.createInvocationContext; +import static com.google.adk.testing.TestUtils.createTestLlm; +import static com.google.common.truth.Truth.assertThat; + +import com.google.adk.agents.InvocationContext; +import com.google.adk.agents.LlmAgent; +import com.google.adk.events.Event; +import com.google.adk.flows.llmflows.RequestProcessor.RequestProcessingResult; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.adk.testing.TestLlm; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.SetModelResponseTool; +import com.google.adk.tools.ToolContext; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.Content; +import com.google.genai.types.FunctionResponse; +import com.google.genai.types.Part; +import com.google.genai.types.Schema; +import io.reactivex.rxjava3.core.Single; +import java.util.Map; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class OutputSchemaTest { + + private static final Schema TEST_OUTPUT_SCHEMA = + Schema.builder() + .type("OBJECT") + .properties(ImmutableMap.of("field1", Schema.builder().type("STRING").build())) + .required(ImmutableList.of("field1")) + .build(); + + private OutputSchema outputSchemaProcessor; + private TestLlm testLlm; + private LlmRequest initialRequest; + + @Before + public 
void setUp() { + outputSchemaProcessor = new OutputSchema(); + testLlm = createTestLlm(LlmResponse.builder().build()); + initialRequest = LlmRequest.builder().model("gemini-2.0-pro").build(); + } + + public static class TestTool extends BaseTool { + public TestTool() { + super("test_tool", "test description"); + } + + @Override + public Single> runAsync(Map args, ToolContext toolContext) { + return Single.just(ImmutableMap.of()); + } + } + + @Test + public void processRequest_noOutputSchema_doesNothing() { + LlmAgent agent = + LlmAgent.builder() + .name("agent") + .model(testLlm) + .tools(ImmutableList.of(new TestTool())) + .build(); + InvocationContext context = createInvocationContext(agent); + + RequestProcessingResult result = + outputSchemaProcessor.processRequest(context, initialRequest).blockingGet(); + + assertThat(result.updatedRequest()).isEqualTo(initialRequest); + assertThat(result.events()).isEmpty(); + } + + @Test + public void processRequest_noTools_doesNothing() { + LlmAgent agent = + LlmAgent.builder().name("agent").model(testLlm).outputSchema(TEST_OUTPUT_SCHEMA).build(); + InvocationContext context = createInvocationContext(agent); + + RequestProcessingResult result = + outputSchemaProcessor.processRequest(context, initialRequest).blockingGet(); + + assertThat(result.updatedRequest()).isEqualTo(initialRequest); + assertThat(result.events()).isEmpty(); + } + + @Test + public void processRequest_withOutputSchemaAndTools_addsSetModelResponseTool() { + LlmAgent agent = + LlmAgent.builder() + .name("agent") + .model(testLlm) + .outputSchema(TEST_OUTPUT_SCHEMA) + .tools(ImmutableList.of(new TestTool())) + .build(); + InvocationContext context = createInvocationContext(agent); + LlmRequest requestWithTools = + LlmRequest.builder() + .model("gemini-2.5-pro") + .tools(ImmutableMap.of("test_tool", new TestTool())) + .build(); + + RequestProcessingResult result = + outputSchemaProcessor.processRequest(context, requestWithTools).blockingGet(); + + LlmRequest 
updatedRequest = result.updatedRequest(); + assertThat(updatedRequest.tools()).hasSize(2); + assertThat( + updatedRequest.tools().values().stream() + .anyMatch(t -> t instanceof SetModelResponseTool)) + .isTrue(); + assertThat(updatedRequest.tools().values().stream().anyMatch(t -> t.name().equals("test_tool"))) + .isTrue(); + assertThat(updatedRequest.getSystemInstructions()).isNotEmpty(); + assertThat(updatedRequest.getSystemInstructions().get(0)) + .contains("you must provide your final response using the set_model_response tool"); + assertThat(result.events()).isEmpty(); + } + + @Test + public void getStructuredModelResponse_withSetModelResponse_returnsJson() { + FunctionResponse fr = + FunctionResponse.builder() + .name(SetModelResponseTool.NAME) + .response(ImmutableMap.of("field1", "value1")) + .build(); + Event event = + Event.builder() + .content( + Content.builder() + .parts(Part.builder().functionResponse(fr).build()) + .role("model") + .build()) + .build(); + + assertThat(OutputSchema.getStructuredModelResponse(event)).hasValue("{\"field1\":\"value1\"}"); + } + + @Test + public void getStructuredModelResponse_withoutSetModelResponse_returnsEmpty() { + FunctionResponse fr = + FunctionResponse.builder() + .name("other_tool") + .response(ImmutableMap.of("field1", "value1")) + .build(); + Event event = + Event.builder() + .content( + Content.builder() + .parts(Part.builder().functionResponse(fr).build()) + .role("model") + .build()) + .build(); + + assertThat(OutputSchema.getStructuredModelResponse(event)).isEmpty(); + } + + @Test + public void createFinalModelResponseEvent_createsModelResponseEvent() { + LlmAgent agent = LlmAgent.builder().name("agent").model(testLlm).build(); + InvocationContext context = createInvocationContext(agent); + String jsonResponse = "{\"field1\":\"value1\"}"; + + Event event = OutputSchema.createFinalModelResponseEvent(context, jsonResponse); + + assertThat(event.invocationId()).isEqualTo(context.invocationId()); + 
assertThat(event.author()).isEqualTo("agent"); + assertThat(event.content().get().role()).hasValue("model"); + assertThat(event.content().get().parts().get()).containsExactly(Part.fromText(jsonResponse)); + } +} From 70056707f42281772bd737e2c7fd5878181c7c37 Mon Sep 17 00:00:00 2001 From: Guillaume Laforge Date: Fri, 20 Mar 2026 16:47:40 +0100 Subject: [PATCH 34/40] refactor: migrate LangChain4j to builder pattern, enhance token usage, and use JSpecify Nullable - Migrate LangChain4j to a builder pattern - Enhance token usage handling with TokenCountEstimator (from PR #623) - Upgrade to latest version of LangChain4j - Replace javax.annotation.Nullable with org.jspecify.annotations.Nullable --- .../adk/models/langchain4j/LangChain4j.java | 230 ++++++++++++------ .../LangChain4jIntegrationTest.java | 24 +- .../models/langchain4j/LangChain4jTest.java | 162 +++++++++++- pom.xml | 2 +- 4 files changed, 327 insertions(+), 91 deletions(-) diff --git a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java index 3ccb1e029..8279dc21a 100644 --- a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java +++ b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java @@ -23,6 +23,7 @@ import com.google.adk.models.BaseLlmConnection; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; +import com.google.auto.value.AutoValue; import com.google.genai.types.Blob; import com.google.genai.types.Content; import com.google.genai.types.FunctionCall; @@ -30,11 +31,11 @@ import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.FunctionResponse; import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; import com.google.genai.types.Schema; import 
com.google.genai.types.ToolConfig; import com.google.genai.types.Type; -import dev.langchain4j.Experimental; import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.audio.Audio; @@ -52,6 +53,7 @@ import dev.langchain4j.data.pdf.PdfFile; import dev.langchain4j.data.video.Video; import dev.langchain4j.exception.UnsupportedFeatureException; +import dev.langchain4j.model.TokenCountEstimator; import dev.langchain4j.model.chat.ChatModel; import dev.langchain4j.model.chat.StreamingChatModel; import dev.langchain4j.model.chat.request.ChatRequest; @@ -65,6 +67,7 @@ import dev.langchain4j.model.chat.request.json.JsonStringSchema; import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; +import dev.langchain4j.model.output.TokenUsage; import io.reactivex.rxjava3.core.BackpressureStrategy; import io.reactivex.rxjava3.core.Flowable; import java.util.ArrayList; @@ -72,66 +75,101 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.UUID; +import org.jspecify.annotations.Nullable; -@Experimental -public class LangChain4j extends BaseLlm { +@AutoValue +public abstract class LangChain4j extends BaseLlm { private static final TypeReference> MAP_TYPE_REFERENCE = new TypeReference<>() {}; - private final ChatModel chatModel; - private final StreamingChatModel streamingChatModel; - private final ObjectMapper objectMapper; + LangChain4j() { + super(""); + } + + @Nullable + public abstract ChatModel chatModel(); + + @Nullable + public abstract StreamingChatModel streamingChatModel(); + + public abstract ObjectMapper objectMapper(); + + public abstract String modelName(); + + @Nullable + public abstract TokenCountEstimator tokenCountEstimator(); + + @Override + public String model() { + return modelName(); + } + + public static Builder builder() { + return new 
AutoValue_LangChain4j.Builder().objectMapper(new ObjectMapper()); + } + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder chatModel(ChatModel chatModel); + + public abstract Builder streamingChatModel(StreamingChatModel streamingChatModel); + + public abstract Builder tokenCountEstimator(TokenCountEstimator tokenCountEstimator); + + public abstract Builder objectMapper(ObjectMapper objectMapper); + + public abstract Builder modelName(String modelName); + + public abstract LangChain4j build(); + } public LangChain4j(ChatModel chatModel) { - super( - Objects.requireNonNull( - chatModel.defaultRequestParameters().modelName(), "chat model name cannot be null")); - this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null"); - this.streamingChatModel = null; - this.objectMapper = new ObjectMapper(); + this(chatModel, null, null, chatModel.defaultRequestParameters().modelName(), null); } public LangChain4j(ChatModel chatModel, String modelName) { - super(Objects.requireNonNull(modelName, "chat model name cannot be null")); - this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null"); - this.streamingChatModel = null; - this.objectMapper = new ObjectMapper(); + this(chatModel, null, null, modelName, null); } public LangChain4j(StreamingChatModel streamingChatModel) { - super( - Objects.requireNonNull( - streamingChatModel.defaultRequestParameters().modelName(), - "streaming chat model name cannot be null")); - this.chatModel = null; - this.streamingChatModel = - Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null"); - this.objectMapper = new ObjectMapper(); + this( + null, + streamingChatModel, + null, + streamingChatModel.defaultRequestParameters().modelName(), + null); } public LangChain4j(StreamingChatModel streamingChatModel, String modelName) { - super(Objects.requireNonNull(modelName, "streaming chat model name cannot be null")); - this.chatModel = null; - 
this.streamingChatModel = - Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null"); - this.objectMapper = new ObjectMapper(); + this(null, streamingChatModel, null, modelName, null); } public LangChain4j(ChatModel chatModel, StreamingChatModel streamingChatModel, String modelName) { - super(Objects.requireNonNull(modelName, "model name cannot be null")); - this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null"); - this.streamingChatModel = - Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null"); - this.objectMapper = new ObjectMapper(); + this(chatModel, streamingChatModel, null, modelName, null); + } + + private LangChain4j( + ChatModel chatModel, + StreamingChatModel streamingChatModel, + ObjectMapper objectMapper, + String modelName, + TokenCountEstimator tokenCountEstimator) { + this(); + LangChain4j.builder() + .chatModel(chatModel) + .streamingChatModel(streamingChatModel) + .objectMapper(objectMapper) + .modelName(modelName) + .tokenCountEstimator(tokenCountEstimator) + .build(); } @Override public Flowable generateContent(LlmRequest llmRequest, boolean stream) { if (stream) { - if (this.streamingChatModel == null) { + if (this.streamingChatModel() == null) { return Flowable.error(new IllegalStateException("StreamingChatModel is not configured")); } @@ -139,54 +177,57 @@ public Flowable generateContent(LlmRequest llmRequest, boolean stre return Flowable.create( emitter -> { - streamingChatModel.chat( - chatRequest, - new StreamingChatResponseHandler() { - @Override - public void onPartialResponse(String s) { - emitter.onNext( - LlmResponse.builder().content(Content.fromParts(Part.fromText(s))).build()); - } - - @Override - public void onCompleteResponse(ChatResponse chatResponse) { - if (chatResponse.aiMessage().hasToolExecutionRequests()) { - AiMessage aiMessage = chatResponse.aiMessage(); - toParts(aiMessage).stream() - .map(Part::functionCall) - .forEach( - functionCall -> { - 
functionCall.ifPresent( - function -> { - emitter.onNext( - LlmResponse.builder() - .content( - Content.fromParts( - Part.fromFunctionCall( - function.name().orElse(""), - function.args().orElse(Map.of())))) - .build()); - }); - }); - } - emitter.onComplete(); - } - - @Override - public void onError(Throwable throwable) { - emitter.onError(throwable); - } - }); + streamingChatModel() + .chat( + chatRequest, + new StreamingChatResponseHandler() { + @Override + public void onPartialResponse(String s) { + emitter.onNext( + LlmResponse.builder() + .content(Content.fromParts(Part.fromText(s))) + .build()); + } + + @Override + public void onCompleteResponse(ChatResponse chatResponse) { + if (chatResponse.aiMessage().hasToolExecutionRequests()) { + AiMessage aiMessage = chatResponse.aiMessage(); + toParts(aiMessage).stream() + .map(Part::functionCall) + .forEach( + functionCall -> { + functionCall.ifPresent( + function -> { + emitter.onNext( + LlmResponse.builder() + .content( + Content.fromParts( + Part.fromFunctionCall( + function.name().orElse(""), + function.args().orElse(Map.of())))) + .build()); + }); + }); + } + emitter.onComplete(); + } + + @Override + public void onError(Throwable throwable) { + emitter.onError(throwable); + } + }); }, BackpressureStrategy.BUFFER); } else { - if (this.chatModel == null) { + if (this.chatModel() == null) { return Flowable.error(new IllegalStateException("ChatModel is not configured")); } ChatRequest chatRequest = toChatRequest(llmRequest); - ChatResponse chatResponse = chatModel.chat(chatRequest); - LlmResponse llmResponse = toLlmResponse(chatResponse); + ChatResponse chatResponse = chatModel().chat(chatRequest); + LlmResponse llmResponse = toLlmResponse(chatResponse, chatRequest); return Flowable.just(llmResponse); } @@ -413,7 +454,7 @@ private AiMessage toAiMessage(Content content) { private String toJson(Object object) { try { - return objectMapper.writeValueAsString(object); + return objectMapper().writeValueAsString(object); 
} catch (JsonProcessingException e) { throw new RuntimeException(e); } @@ -511,11 +552,38 @@ private JsonSchemaElement toJsonSchemaElement(Schema schema) { } } - private LlmResponse toLlmResponse(ChatResponse chatResponse) { + private LlmResponse toLlmResponse(ChatResponse chatResponse, ChatRequest chatRequest) { Content content = Content.builder().role("model").parts(toParts(chatResponse.aiMessage())).build(); - return LlmResponse.builder().content(content).build(); + LlmResponse.Builder builder = LlmResponse.builder().content(content); + TokenUsage tokenUsage = chatResponse.tokenUsage(); + if (tokenCountEstimator() != null) { + try { + int estimatedInput = + tokenCountEstimator().estimateTokenCountInMessages(chatRequest.messages()); + int estimatedOutput = + tokenCountEstimator().estimateTokenCountInText(chatResponse.aiMessage().text()); + int estimatedTotal = estimatedInput + estimatedOutput; + builder.usageMetadata( + GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(estimatedInput) + .candidatesTokenCount(estimatedOutput) + .totalTokenCount(estimatedTotal) + .build()); + } catch (Exception e) { + e.printStackTrace(); + } + } else if (tokenUsage != null) { + builder.usageMetadata( + GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(tokenUsage.inputTokenCount()) + .candidatesTokenCount(tokenUsage.outputTokenCount()) + .totalTokenCount(tokenUsage.totalTokenCount()) + .build()); + } + + return builder.build(); } private List toParts(AiMessage aiMessage) { @@ -546,7 +614,7 @@ private List toParts(AiMessage aiMessage) { private Map toArgs(ToolExecutionRequest toolExecutionRequest) { try { - return objectMapper.readValue(toolExecutionRequest.arguments(), MAP_TYPE_REFERENCE); + return objectMapper().readValue(toolExecutionRequest.arguments(), MAP_TYPE_REFERENCE); } catch (JsonProcessingException e) { throw new RuntimeException(e); } diff --git 
a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java index 191e48017..5b6d3f3ad 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java @@ -62,7 +62,8 @@ void testSimpleAgent() { LlmAgent.builder() .name("science-app") .description("Science teacher agent") - .model(new LangChain4j(claudeModel, CLAUDE_4_6_SONNET)) + .model( + LangChain4j.builder().chatModel(claudeModel).modelName(CLAUDE_4_6_SONNET).build()) .instruction( """ You are a helpful science teacher that explains science concepts @@ -98,7 +99,8 @@ void testSingleAgentWithTools() { LlmAgent.builder() .name("friendly-weather-app") .description("Friend agent that knows about the weather") - .model(new LangChain4j(claudeModel, CLAUDE_4_6_SONNET)) + .model( + LangChain4j.builder().chatModel(claudeModel).modelName(CLAUDE_4_6_SONNET).build()) .instruction( """ You are a friendly assistant. @@ -183,7 +185,7 @@ void testAgentTool() { LlmAgent.builder() .name("friendly-weather-app") .description("Friend agent that knows about the weather") - .model(new LangChain4j(gptModel)) + .model(LangChain4j.builder().chatModel(gptModel).modelName(GPT_4_O_MINI).build()) .instruction( """ You are a friendly assistant. @@ -246,7 +248,7 @@ void testSubAgent() { LlmAgent.builder() .name("greeterAgent") .description("Friendly agent that greets users") - .model(new LangChain4j(gptModel)) + .model(LangChain4j.builder().chatModel(gptModel).modelName(GPT_4_O_MINI).build()) .instruction( """ You are a friendly that greets users. 
@@ -257,7 +259,7 @@ void testSubAgent() { LlmAgent.builder() .name("farewellAgent") .description("Friendly agent that says goodbye to users") - .model(new LangChain4j(gptModel)) + .model(LangChain4j.builder().chatModel(gptModel).modelName(GPT_4_O_MINI).build()) .instruction( """ You are a friendly that says goodbye to users. @@ -355,7 +357,11 @@ void testSimpleStreamingResponse() { .modelName(CLAUDE_4_6_SONNET) .build(); - LangChain4j lc4jClaude = new LangChain4j(claudeStreamingModel, CLAUDE_4_6_SONNET); + LangChain4j lc4jClaude = + LangChain4j.builder() + .streamingChatModel(claudeStreamingModel) + .modelName(CLAUDE_4_6_SONNET) + .build(); // when Flowable responses = @@ -413,7 +419,11 @@ void testStreamingRunConfig() { When someone greets you, respond with "Hello". If someone asks about the weather, call the `getWeather` function. """) - .model(new LangChain4j(streamingModel, "GPT_4_O_MINI")) + .model( + LangChain4j.builder() + .streamingChatModel(streamingModel) + .modelName("GPT_4_O_MINI") + .build()) // .model(new LangChain4j(streamingModel, // CLAUDE_3_7_SONNET_20250219)) .tools(FunctionTool.create(ToolExample.class, "getWeather")) diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java index 076bb79a3..f88237ff1 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java @@ -19,6 +19,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.*; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; import com.google.adk.tools.FunctionTool; @@ -26,6 +27,7 @@ import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.data.message.AiMessage; import 
dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.TokenCountEstimator; import dev.langchain4j.model.chat.ChatModel; import dev.langchain4j.model.chat.StreamingChatModel; import dev.langchain4j.model.chat.request.ChatRequest; @@ -33,6 +35,7 @@ import dev.langchain4j.model.chat.request.json.JsonStringSchema; import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; +import dev.langchain4j.model.output.TokenUsage; import io.reactivex.rxjava3.core.Flowable; import java.util.ArrayList; import java.util.List; @@ -57,8 +60,26 @@ void setUp() { chatModel = mock(ChatModel.class); streamingChatModel = mock(StreamingChatModel.class); - langChain4j = new LangChain4j(chatModel, MODEL_NAME); - streamingLangChain4j = new LangChain4j(streamingChatModel, MODEL_NAME); + langChain4j = LangChain4j.builder().chatModel(chatModel).modelName(MODEL_NAME).build(); + streamingLangChain4j = + LangChain4j.builder().streamingChatModel(streamingChatModel).modelName(MODEL_NAME).build(); + } + + @Test + void testBuilder() { + ObjectMapper customMapper = new ObjectMapper(); + LangChain4j customLc4j = + LangChain4j.builder() + .chatModel(chatModel) + .streamingChatModel(streamingChatModel) + .objectMapper(customMapper) + .modelName("custom-model") + .build(); + + assertThat(customLc4j.chatModel()).isEqualTo(chatModel); + assertThat(customLc4j.streamingChatModel()).isEqualTo(streamingChatModel); + assertThat(customLc4j.objectMapper()).isEqualTo(customMapper); + assertThat(customLc4j.modelName()).isEqualTo("custom-model"); } @Test @@ -812,4 +833,141 @@ void testGenerateContentWithMcpToolParametersJsonSchemaAsSchema() { assertThat(capturedRequest.toolSpecifications().get(0).name()).isEqualTo("mcpTool"); assertThat(capturedRequest.toolSpecifications().get(0).description()).isEqualTo("An MCP tool"); } + + @Test + @DisplayName( + "Should use TokenCountEstimator to estimate token usage when TokenUsage is not 
available") + void testTokenCountEstimatorFallback() { + // Given + // Create a mock TokenCountEstimator + final TokenCountEstimator tokenCountEstimator = mock(TokenCountEstimator.class); + when(tokenCountEstimator.estimateTokenCountInMessages(any())).thenReturn(50); // Input tokens + when(tokenCountEstimator.estimateTokenCountInText(any())).thenReturn(20); // Output tokens + + // Create LangChain4j with the TokenCountEstimator using Builder + final LangChain4j langChain4jWithEstimator = + LangChain4j.builder() + .chatModel(chatModel) + .modelName(MODEL_NAME) + .tokenCountEstimator(tokenCountEstimator) + .build(); + + // Create a LlmRequest + final LlmRequest llmRequest = + LlmRequest.builder() + .contents(List.of(Content.fromParts(Part.fromText("What is the weather today?")))) + .build(); + + // Mock ChatResponse WITHOUT TokenUsage (simulating when LLM doesn't provide token counts) + final ChatResponse chatResponse = mock(ChatResponse.class); + final AiMessage aiMessage = AiMessage.from("The weather is sunny today."); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatResponse.tokenUsage()).thenReturn(null); // No token usage from LLM + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = + langChain4jWithEstimator.generateContent(llmRequest, false).blockingFirst(); + + // Then + // Verify the response has usage metadata estimated by TokenCountEstimator + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().text()).isEqualTo("The weather is sunny today."); + + // IMPORTANT: Verify that token usage was estimated via the TokenCountEstimator + assertThat(response.usageMetadata()).isPresent(); + final GenerateContentResponseUsageMetadata usageMetadata = response.usageMetadata().get(); + assertThat(usageMetadata.promptTokenCount()).isEqualTo(Optional.of(50)); // From estimator + 
assertThat(usageMetadata.candidatesTokenCount()).isEqualTo(Optional.of(20)); // From estimator + assertThat(usageMetadata.totalTokenCount()).isEqualTo(Optional.of(70)); // 50 + 20 + + // Verify the estimator was actually called + verify(tokenCountEstimator).estimateTokenCountInMessages(any()); + verify(tokenCountEstimator).estimateTokenCountInText("The weather is sunny today."); + } + + @Test + @DisplayName("Should prioritize TokenCountEstimator over TokenUsage when estimator is provided") + void testTokenCountEstimatorPriority() { + // Given + // Create a mock TokenCountEstimator + final TokenCountEstimator tokenCountEstimator = mock(TokenCountEstimator.class); + when(tokenCountEstimator.estimateTokenCountInMessages(any())).thenReturn(100); // From estimator + when(tokenCountEstimator.estimateTokenCountInText(any())).thenReturn(50); // From estimator + + // Create LangChain4j with the TokenCountEstimator using Builder + final LangChain4j langChain4jWithEstimator = + LangChain4j.builder() + .chatModel(chatModel) + .modelName(MODEL_NAME) + .tokenCountEstimator(tokenCountEstimator) + .build(); + + // Create a LlmRequest + final LlmRequest llmRequest = + LlmRequest.builder() + .contents(List.of(Content.fromParts(Part.fromText("What is the weather today?")))) + .build(); + + // Mock ChatResponse WITH actual TokenUsage from the LLM + final ChatResponse chatResponse = mock(ChatResponse.class); + final AiMessage aiMessage = AiMessage.from("The weather is sunny today."); + final TokenUsage actualTokenUsage = new TokenUsage(30, 15, 45); // Actual token counts from LLM + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatResponse.tokenUsage()).thenReturn(actualTokenUsage); // LLM provides token usage + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = + langChain4jWithEstimator.generateContent(llmRequest, false).blockingFirst(); + + // Then + // IMPORTANT: When TokenCountEstimator is present, it 
takes priority over TokenUsage + assertThat(response).isNotNull(); + assertThat(response.usageMetadata()).isPresent(); + final GenerateContentResponseUsageMetadata usageMetadata = response.usageMetadata().get(); + assertThat(usageMetadata.promptTokenCount()).isEqualTo(Optional.of(100)); // From estimator + assertThat(usageMetadata.candidatesTokenCount()).isEqualTo(Optional.of(50)); // From estimator + assertThat(usageMetadata.totalTokenCount()).isEqualTo(Optional.of(150)); // 100 + 50 + + // Verify the estimator was called (it takes priority) + verify(tokenCountEstimator).estimateTokenCountInMessages(any()); + verify(tokenCountEstimator).estimateTokenCountInText("The weather is sunny today."); + } + + @Test + @DisplayName("Should not include usageMetadata when TokenUsage is null and no estimator provided") + void testNoUsageMetadataWithoutEstimator() { + // Given + // Create LangChain4j WITHOUT TokenCountEstimator (default behavior) + final LangChain4j langChain4jNoEstimator = + LangChain4j.builder().chatModel(chatModel).modelName(MODEL_NAME).build(); + + // Create a LlmRequest + final LlmRequest llmRequest = + LlmRequest.builder() + .contents(List.of(Content.fromParts(Part.fromText("Hello, world!")))) + .build(); + + // Mock ChatResponse WITHOUT TokenUsage + final ChatResponse chatResponse = mock(ChatResponse.class); + final AiMessage aiMessage = AiMessage.from("Hello! How can I help you?"); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatResponse.tokenUsage()).thenReturn(null); // No token usage from LLM + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = + langChain4jNoEstimator.generateContent(llmRequest, false).blockingFirst(); + + // Then + // Verify the response does NOT have usage metadata + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().text()).isEqualTo("Hello! 
How can I help you?"); + + // IMPORTANT: usageMetadata should be empty when no TokenUsage and no estimator + assertThat(response.usageMetadata()).isEmpty(); + } } diff --git a/pom.xml b/pom.xml index cbeca1b72..40332472f 100644 --- a/pom.xml +++ b/pom.xml @@ -62,7 +62,7 @@ 0.18.1 3.41.0 3.9.0 - 1.11.0 + 1.12.2 2.0.17 1.4.5 1.0.0 From 3633a7dd071265087ea2ff148d419969b0c888ef Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 20 Mar 2026 13:15:38 -0700 Subject: [PATCH 35/40] fix: Removing deprecated methods from Runner PiperOrigin-RevId: 886942637 --- .../java/com/google/adk/runner/Runner.java | 42 ------------------- 1 file changed, 42 deletions(-) diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 1f7d924ab..849a3cd04 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -425,36 +425,6 @@ public Flowable runAsync(String userId, String sessionId, Content newMess return runAsync(userId, sessionId, newMessage, RunConfig.builder().build()); } - /** - * See {@link #runAsync(Session, Content, RunConfig, Map)}. - * - * @deprecated Use runAsync with sessionId. - */ - @Deprecated(since = "0.4.0", forRemoval = true) - public Flowable runAsync(Session session, Content newMessage, RunConfig runConfig) { - return runAsync(session, newMessage, runConfig, /* stateDelta= */ null); - } - - /** - * Runs the agent asynchronously using a provided Session object. - * - * @param session The session to run the agent in. - * @param newMessage The new message from the user to process. - * @param runConfig Configuration for the agent run. - * @param stateDelta Optional map of state updates to merge into the session for this run. - * @return A Flowable stream of {@link Event} objects generated by the agent during execution. - * @deprecated Use runAsync with sessionId. 
- */ - @Deprecated(since = "0.4.0", forRemoval = true) - public Flowable runAsync( - Session session, - Content newMessage, - RunConfig runConfig, - @Nullable Map stateDelta) { - return runAsyncImpl(session, newMessage, runConfig, stateDelta) - .compose(Tracing.trace("invocation")); - } - /** * Runs the agent asynchronously using a provided Session object. * @@ -735,18 +705,6 @@ protected Flowable runLiveImpl( }); } - /** - * Runs the agent asynchronously with a default user ID. - * - * @return stream of generated events. - */ - @Deprecated(since = "0.5.0", forRemoval = true) - public Flowable runWithSessionId( - String sessionId, Content newMessage, RunConfig runConfig) { - // TODO(b/410859954): Add user_id to getter or method signature. Assuming "tmp-user" for now. - return this.runAsync("tmp-user", sessionId, newMessage, runConfig); - } - /** * Checks if the agent and its parent chain allow transfer up the tree. * From 8e9fb085354f8148e00cbd236e8f29e82de56d6e Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 20 Mar 2026 13:56:07 -0700 Subject: [PATCH 36/40] refactor: Use concatMap for sequential event persistence in Runner Ensure sequential event processing and persistence in ADK Runner. This ensures that events are appended in order and returned from runAsync in order. This aligns better with the Python implementation. 
PiperOrigin-RevId: 886961696 --- .../java/com/google/adk/runner/Runner.java | 2 +- .../com/google/adk/runner/RunnerTest.java | 42 +++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 849a3cd04..2bfbca881 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -529,7 +529,7 @@ private Flowable runAgentWithFreshSession( contextWithUpdatedSession .agent() .runAsync(contextWithUpdatedSession) - .flatMap( + .concatMap( agentEvent -> this.sessionService .appendEvent(updatedSession, agentEvent) diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index a3e21cb73..efd565c16 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -26,6 +26,7 @@ import static com.google.common.truth.Truth.assertThat; import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Arrays.stream; +import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.CALLS_REAL_METHODS; import static org.mockito.Mockito.mock; @@ -33,6 +34,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.adk.agents.BaseAgent; import com.google.adk.agents.InvocationContext; import com.google.adk.agents.LiveRequestQueue; import com.google.adk.agents.LlmAgent; @@ -43,6 +45,7 @@ import com.google.adk.flows.llmflows.Functions; import com.google.adk.models.LlmResponse; import com.google.adk.plugins.BasePlugin; +import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.Session; import com.google.adk.sessions.SessionKey; import com.google.adk.summarizer.EventsCompactionConfig; @@ -851,6 
+854,45 @@ public void beforeRunCallback_withStateDelta_seesMergedState() { assertThat(sessionInCallback.state()).containsEntry("number", 123); } + @Test + public void runAsync_ensureEventsAreAppendedInOrder() throws Exception { + Event event1 = TestUtils.createEvent("1"); + Event event2 = TestUtils.createEvent("2"); + BaseAgent mockAgent = TestUtils.createSubAgent("test agent", event1, event2); + + BaseSessionService mockSessionService = mock(BaseSessionService.class); + + when(mockSessionService.getSession(any(), any(), any(), any())).thenReturn(Maybe.just(session)); + when(mockSessionService.appendEvent(any(), any())) + .thenAnswer( + invocation -> { + Event eventArg = invocation.getArgument(1); + Single result = Single.just(eventArg); + if (eventArg.id().equals("1")) { + // Artificially delay the first event to ensure it is appended first. + return result.delay(100, MILLISECONDS); + } + return result; + }); + + Runner mockRunner = + Runner.builder() + .agent(mockAgent) + .appName("test") + .sessionService(mockSessionService) + .build(); + + List results = + mockRunner + .runAsync("user", session.id(), createContent("user message")) + .toList() + .blockingGet(); + + assertThat(simplifyEvents(results)) + .containsExactly("author: content for event 1", "author: content for event 2") + .inOrder(); + } + private Content createContent(String text) { return Content.builder().parts(Part.builder().text(text).build()).build(); } From 3e21e7ac46b634341819b3543388a38caef85516 Mon Sep 17 00:00:00 2001 From: Guillaume Laforge Date: Sat, 21 Mar 2026 20:11:12 +0100 Subject: [PATCH 37/40] fix: handle null `AiMessage.text()` to prevent NPE and add unit test (PR #1035) --- .../adk/models/langchain4j/LangChain4j.java | 7 ++++-- .../models/langchain4j/LangChain4jTest.java | 23 +++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java 
b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java index 8279dc21a..97331e7b4 100644 --- a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java +++ b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java @@ -607,8 +607,11 @@ private List toParts(AiMessage aiMessage) { }); return parts; } else { - Part part = Part.builder().text(aiMessage.text()).build(); - return List.of(part); + String text = aiMessage.text(); + if (text == null) { + return List.of(); + } + return List.of(Part.builder().text(text).build()); } } diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java index f88237ff1..a1ec7a3c2 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java @@ -970,4 +970,27 @@ void testNoUsageMetadataWithoutEstimator() { // IMPORTANT: usageMetadata should be empty when no TokenUsage and no estimator assertThat(response.usageMetadata()).isEmpty(); } + + @Test + @DisplayName("Should handle null AiMessage text without throwing NPE") + void testGenerateContentWithNullAiMessageText() { + // Given + final LlmRequest llmRequest = + LlmRequest.builder().contents(List.of(Content.fromParts(Part.fromText("Hello")))).build(); + + final ChatResponse chatResponse = mock(ChatResponse.class); + final AiMessage aiMessage = mock(AiMessage.class); + when(aiMessage.text()).thenReturn(null); + when(aiMessage.hasToolExecutionRequests()).thenReturn(false); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final Flowable responseFlowable = langChain4j.generateContent(llmRequest, false); + final LlmResponse response = 
responseFlowable.blockingFirst(); + // Then - no NPE thrown, and content has no text parts + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().parts().orElse(List.of())).isEmpty(); + } } From cdc5199eb0f92cb95db2ee7ff139d67317968457 Mon Sep 17 00:00:00 2001 From: Guillaume Laforge Date: Mon, 23 Mar 2026 13:43:34 +0100 Subject: [PATCH 38/40] fix: add schema validation to SetModelResponseTool (issue #587 already implemented, but adding tests from PR #603) --- .../adk/tools/SetModelResponseTool.java | 7 +- .../adk/tools/SetModelResponseToolTest.java | 123 ++++++++++++++++++ 2 files changed, 129 insertions(+), 1 deletion(-) create mode 100644 core/src/test/java/com/google/adk/tools/SetModelResponseToolTest.java diff --git a/core/src/main/java/com/google/adk/tools/SetModelResponseTool.java b/core/src/main/java/com/google/adk/tools/SetModelResponseTool.java index e23d6414a..3b0e411b4 100644 --- a/core/src/main/java/com/google/adk/tools/SetModelResponseTool.java +++ b/core/src/main/java/com/google/adk/tools/SetModelResponseTool.java @@ -16,6 +16,7 @@ package com.google.adk.tools; +import com.google.adk.SchemaUtils; import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.Schema; import io.reactivex.rxjava3.core.Single; @@ -58,6 +59,10 @@ public Optional declaration() { public Single> runAsync(Map args, ToolContext toolContext) { // This tool is a marker for the final response, it doesn't do anything but return its arguments // which will be captured as the final result. 
- return Single.just(args); + return Single.fromCallable( + () -> { + SchemaUtils.validateMapOnSchema(args, outputSchema, /* isInput= */ false); + return args; + }); } } diff --git a/core/src/test/java/com/google/adk/tools/SetModelResponseToolTest.java b/core/src/test/java/com/google/adk/tools/SetModelResponseToolTest.java new file mode 100644 index 000000000..64b600af9 --- /dev/null +++ b/core/src/test/java/com/google/adk/tools/SetModelResponseToolTest.java @@ -0,0 +1,123 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.adk.tools; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.FunctionDeclaration; +import com.google.genai.types.Schema; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class SetModelResponseToolTest { + + @Test + public void declaration_returnsCorrectFunctionDeclaration() { + Schema outputSchema = + Schema.builder() + .type("OBJECT") + .properties(ImmutableMap.of("field1", Schema.builder().type("STRING").build())) + .required(ImmutableList.of("field1")) + .build(); + + SetModelResponseTool tool = new SetModelResponseTool(outputSchema); + FunctionDeclaration declaration = tool.declaration().get(); + + assertThat(declaration.name()).hasValue("set_model_response"); + assertThat(declaration.description()).isPresent(); + assertThat(declaration.description().get()).contains("Set your final response"); + assertThat(declaration.parameters()).hasValue(outputSchema); + } + + @Test + public void runAsync_returnsArgs() { + Schema outputSchema = + Schema.builder() + .type("OBJECT") + .properties(ImmutableMap.of("field1", Schema.builder().type("STRING").build())) + .build(); + + SetModelResponseTool tool = new SetModelResponseTool(outputSchema); + Map args = ImmutableMap.of("field1", "value1"); + + Map result = tool.runAsync(args, null).blockingGet(); + + assertThat(result).isEqualTo(args); + } + + @Test + public void runAsync_validatesArgs() { + Schema outputSchema = + Schema.builder() + .type("OBJECT") + .properties(ImmutableMap.of("field1", Schema.builder().type("STRING").build())) + .required(ImmutableList.of("field1")) + .build(); + + SetModelResponseTool tool = new SetModelResponseTool(outputSchema); + Map invalidArgs = ImmutableMap.of("field2", "value2"); + 
+ // Should throw validation error + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, () -> tool.runAsync(invalidArgs, null).blockingGet()); + + assertThat(exception).hasMessageThat().contains("does not match agent output schema"); + } + + @Test + public void runAsync_validatesComplexArgs() { + Schema complexSchema = + Schema.builder() + .type("OBJECT") + .properties( + ImmutableMap.of( + "id", + Schema.builder().type("INTEGER").build(), + "tags", + Schema.builder() + .type("ARRAY") + .items(Schema.builder().type("STRING").build()) + .build(), + "metadata", + Schema.builder() + .type("OBJECT") + .properties(ImmutableMap.of("key", Schema.builder().type("STRING").build())) + .build())) + .required(ImmutableList.of("id", "tags", "metadata")) + .build(); + + SetModelResponseTool tool = new SetModelResponseTool(complexSchema); + Map complexArgs = + ImmutableMap.of( + "id", 123, + "tags", ImmutableList.of("tag1", "tag2"), + "metadata", ImmutableMap.of("key", "value")); + + Map result = tool.runAsync(complexArgs, null).blockingGet(); + + assertThat(result).containsEntry("id", 123); + assertThat(result).containsEntry("tags", ImmutableList.of("tag1", "tag2")); + assertThat(result).containsEntry("metadata", ImmutableMap.of("key", "value")); + } +} From e9df447f1445044552e8710713ab5a76c2ae5093 Mon Sep 17 00:00:00 2001 From: "Michael Vorburger.ch" Date: Mon, 23 Mar 2026 08:42:56 -0700 Subject: [PATCH 39/40] Remove explicit SLF4J binding from city-time-weather ADK tutorial. The `slf4j-simple` dependency and the exclusion of `logback-classic` are removed, allowing the default logging implementation provided by `google-adk-dev` to be used. 
PiperOrigin-RevId: 888114465 --- tutorials/city-time-weather/pom.xml | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/tutorials/city-time-weather/pom.xml b/tutorials/city-time-weather/pom.xml index aeb110cf6..19ef08a2d 100644 --- a/tutorials/city-time-weather/pom.xml +++ b/tutorials/city-time-weather/pom.xml @@ -36,16 +36,6 @@ com.google.adk google-adk-dev ${project.version} - - - ch.qos.logback - logback-classic - - - - - org.slf4j - slf4j-simple From 5e9824fe1f74b8c6e9de5484bc92a95e7f832bbf Mon Sep 17 00:00:00 2001 From: Sandeep Belgavi Date: Mon, 23 Mar 2026 23:38:50 +0530 Subject: [PATCH 40/40] Adapt fork code to upstream API after google/adk-java merge - BaseArtifactService.loadArtifact now uses @Nullable Integer version - EventActions.stateDelta() exposed as Map; fix PostgresSessionService - PartConverter.toGenaiPart returns Part; update RequestConverter and tests - A2AMetadataKey.TYPE replaces removed PartConverter constant - Runner.runAsync(Session,...) removed; A2aService uses SessionKey - MediaSupportTest and artifact IT/unit tests updated accordingly Made-with: Cursor --- .../adk/a2a/converters/RequestConverter.java | 19 ++++-- .../com/google/adk/a2a/grpc/A2aService.java | 5 +- .../google/adk/a2a/grpc/MediaSupportTest.java | 66 ++++++++----------- .../artifacts/CassandraArtifactService.java | 12 ++-- .../adk/artifacts/MapDbArtifactService.java | 8 +-- .../adk/artifacts/MongoDbArtifactService.java | 4 +- .../artifacts/PostgresArtifactService.java | 7 +- .../adk/artifacts/RedisArtifactService.java | 8 +-- .../adk/sessions/PostgresSessionService.java | 5 +- .../artifacts/CassandraArtifactServiceIT.java | 11 +--- .../CassandraArtifactServiceTest.java | 5 +- .../artifacts/PostgresArtifactServiceIT.java | 44 +++++-------- .../PostgresArtifactServiceTest.java | 13 +--- .../adk/artifacts/RedisArtifactServiceIT.java | 9 +-- 14 files changed, 87 insertions(+), 129 deletions(-) diff --git 
a/a2a/src/main/java/com/google/adk/a2a/converters/RequestConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/RequestConverter.java index b289ced6c..ae9faadfa 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/RequestConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/RequestConverter.java @@ -61,8 +61,11 @@ public static Optional convertA2aMessageToAdkEvent(Message message, Strin // Convert each A2A Part to GenAI Part if (message.getParts() != null) { for (Part a2aPart : message.getParts()) { - Optional genaiPart = PartConverter.toGenaiPart(a2aPart); - genaiPart.ifPresent(genaiParts::add); + try { + genaiParts.add(PartConverter.toGenaiPart(a2aPart)); + } catch (IllegalArgumentException e) { + logger.debug("Skipping unconvertible A2A part: {}", e.getMessage()); + } } } @@ -125,15 +128,18 @@ public static ImmutableList convertAggregatedA2aMessageToAdkEvents( // Emit exactly one ADK Event per A2A Part, preserving order. for (Part a2aPart : message.getParts()) { - Optional genaiPart = PartConverter.toGenaiPart(a2aPart); - if (genaiPart.isEmpty()) { + com.google.genai.types.Part genaiPart; + try { + genaiPart = PartConverter.toGenaiPart(a2aPart); + } catch (IllegalArgumentException e) { + logger.debug("Skipping unconvertible A2A part in aggregate: {}", e.getMessage()); continue; } String author = extractAuthorFromMetadata(a2aPart); String role = determineRoleFromAuthor(author); - events.add(createEvent(ImmutableList.of(genaiPart.get()), author, role, invocationId)); + events.add(createEvent(ImmutableList.of(genaiPart), author, role, invocationId)); } if (events.isEmpty()) { @@ -162,8 +168,7 @@ private static String extractAuthorFromMetadata(Part a2aPart) { if (a2aPart instanceof DataPart dataPart) { Map metadata = Optional.ofNullable(dataPart.getMetadata()).orElse(ImmutableMap.of()); - String type = - metadata.getOrDefault(PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY, "").toString(); + String type = 
metadata.getOrDefault(A2AMetadataKey.TYPE.getType(), "").toString(); if (type.equals(A2ADataPartMetadataType.FUNCTION_CALL.getType())) { return "model"; } diff --git a/a2a/src/main/java/com/google/adk/a2a/grpc/A2aService.java b/a2a/src/main/java/com/google/adk/a2a/grpc/A2aService.java index e6658e6c0..0e4881ab1 100644 --- a/a2a/src/main/java/com/google/adk/a2a/grpc/A2aService.java +++ b/a2a/src/main/java/com/google/adk/a2a/grpc/A2aService.java @@ -11,6 +11,7 @@ import com.google.adk.runner.Runner; import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; +import com.google.adk.sessions.SessionKey; import com.google.genai.types.Content; import com.google.genai.types.Part; import io.grpc.stub.StreamObserver; @@ -167,8 +168,8 @@ public void sendMessage( .setMaxLlmCalls(20) .build(); - // Execute the agent using Runner with Session object - Flowable eventStream = runner.runAsync(session, userContent, runConfig); + SessionKey sessionKey = new SessionKey(session.appName(), session.userId(), session.id()); + Flowable eventStream = runner.runAsync(sessionKey, userContent, runConfig); // Collect all events and aggregate into a single response // Since sendMessage is unary RPC, we need to send a single response diff --git a/a2a/src/test/java/com/google/adk/a2a/grpc/MediaSupportTest.java b/a2a/src/test/java/com/google/adk/a2a/grpc/MediaSupportTest.java index a8019be30..e7bc6eb8f 100644 --- a/a2a/src/test/java/com/google/adk/a2a/grpc/MediaSupportTest.java +++ b/a2a/src/test/java/com/google/adk/a2a/grpc/MediaSupportTest.java @@ -14,7 +14,6 @@ import io.a2a.spec.FileWithUri; import io.a2a.spec.TextPart; import java.util.Base64; -import java.util.Optional; import org.junit.jupiter.api.Test; /** Tests for image, audio, and video support in A2A. 
*/ @@ -36,11 +35,10 @@ class MediaSupportTest { void testTextPart_conversion() { // A2A TextPart to GenAI Part TextPart textPart = new TextPart("Hello, world!"); - Optional genaiPart = PartConverter.toGenaiPart(textPart); + Part genaiPart = PartConverter.toGenaiPart(textPart); - assertThat(genaiPart).isPresent(); - assertThat(genaiPart.get().text()).isPresent(); - assertThat(genaiPart.get().text().get()).isEqualTo("Hello, world!"); + assertThat(genaiPart.text()).isPresent(); + assertThat(genaiPart.text().get()).isEqualTo("Hello, world!"); // GenAI Part to A2A TextPart Part genaiTextPart = Part.builder().text("Hello, world!").build(); @@ -57,11 +55,10 @@ void testImageFilePart_withUri() { FilePart imagePart = new FilePart(new FileWithUri("image/png", "test.png", "https://example.com/image.png")); - Optional genaiPart = PartConverter.toGenaiPart(imagePart); + Part genaiPart = PartConverter.toGenaiPart(imagePart); - assertThat(genaiPart).isPresent(); - assertThat(genaiPart.get().fileData()).isPresent(); - FileData fileData = genaiPart.get().fileData().get(); + assertThat(genaiPart.fileData()).isPresent(); + FileData fileData = genaiPart.fileData().get(); assertThat(fileData.fileUri()).isPresent(); assertThat(fileData.fileUri().get()).isEqualTo("https://example.com/image.png"); assertThat(fileData.mimeType()).isPresent(); @@ -74,11 +71,10 @@ void testImageFilePart_withBytes() { FilePart imagePart = new FilePart(new FileWithBytes("image/png", "test.png", SAMPLE_IMAGE_BASE64)); - Optional genaiPart = PartConverter.toGenaiPart(imagePart); + Part genaiPart = PartConverter.toGenaiPart(imagePart); - assertThat(genaiPart).isPresent(); - assertThat(genaiPart.get().inlineData()).isPresent(); - Blob blob = genaiPart.get().inlineData().get(); + assertThat(genaiPart.inlineData()).isPresent(); + Blob blob = genaiPart.inlineData().get(); assertThat(blob.mimeType()).isPresent(); assertThat(blob.mimeType().get()).isEqualTo("image/png"); assertThat(blob.data()).isPresent(); @@ -91,11 
+87,10 @@ void testAudioFilePart_withUri() { FilePart audioPart = new FilePart(new FileWithUri("audio/mpeg", "test.mp3", "https://example.com/audio.mp3")); - Optional genaiPart = PartConverter.toGenaiPart(audioPart); + Part genaiPart = PartConverter.toGenaiPart(audioPart); - assertThat(genaiPart).isPresent(); - assertThat(genaiPart.get().fileData()).isPresent(); - FileData fileData = genaiPart.get().fileData().get(); + assertThat(genaiPart.fileData()).isPresent(); + FileData fileData = genaiPart.fileData().get(); assertThat(fileData.mimeType()).isPresent(); assertThat(fileData.mimeType().get()).isEqualTo("audio/mpeg"); } @@ -106,11 +101,10 @@ void testAudioFilePart_withBytes() { FilePart audioPart = new FilePart(new FileWithBytes("audio/wav", "test.wav", SAMPLE_AUDIO_BASE64)); - Optional genaiPart = PartConverter.toGenaiPart(audioPart); + Part genaiPart = PartConverter.toGenaiPart(audioPart); - assertThat(genaiPart).isPresent(); - assertThat(genaiPart.get().inlineData()).isPresent(); - Blob blob = genaiPart.get().inlineData().get(); + assertThat(genaiPart.inlineData()).isPresent(); + Blob blob = genaiPart.inlineData().get(); assertThat(blob.mimeType()).isPresent(); assertThat(blob.mimeType().get()).isEqualTo("audio/wav"); } @@ -121,11 +115,10 @@ void testVideoFilePart_withUri() { FilePart videoPart = new FilePart(new FileWithUri("video/mp4", "test.mp4", "https://example.com/video.mp4")); - Optional genaiPart = PartConverter.toGenaiPart(videoPart); + Part genaiPart = PartConverter.toGenaiPart(videoPart); - assertThat(genaiPart).isPresent(); - assertThat(genaiPart.get().fileData()).isPresent(); - FileData fileData = genaiPart.get().fileData().get(); + assertThat(genaiPart.fileData()).isPresent(); + FileData fileData = genaiPart.fileData().get(); assertThat(fileData.mimeType()).isPresent(); assertThat(fileData.mimeType().get()).isEqualTo("video/mp4"); } @@ -136,11 +129,10 @@ void testVideoFilePart_withBytes() { FilePart videoPart = new FilePart(new 
FileWithBytes("video/mp4", "test.mp4", SAMPLE_VIDEO_BASE64)); - Optional genaiPart = PartConverter.toGenaiPart(videoPart); + Part genaiPart = PartConverter.toGenaiPart(videoPart); - assertThat(genaiPart).isPresent(); - assertThat(genaiPart.get().inlineData()).isPresent(); - Blob blob = genaiPart.get().inlineData().get(); + assertThat(genaiPart.inlineData()).isPresent(); + Blob blob = genaiPart.inlineData().get(); assertThat(blob.mimeType()).isPresent(); assertThat(blob.mimeType().get()).isEqualTo("video/mp4"); } @@ -227,16 +219,12 @@ void testMultipleMediaTypes_inMessage() { FilePart videoPart = new FilePart(new FileWithUri("video/mp4", "movie.mp4", "https://example.com/movie.mp4")); - Optional imageGenai = PartConverter.toGenaiPart(imagePart); - Optional audioGenai = PartConverter.toGenaiPart(audioPart); - Optional videoGenai = PartConverter.toGenaiPart(videoPart); + Part imageGenai = PartConverter.toGenaiPart(imagePart); + Part audioGenai = PartConverter.toGenaiPart(audioPart); + Part videoGenai = PartConverter.toGenaiPart(videoPart); - assertThat(imageGenai).isPresent(); - assertThat(audioGenai).isPresent(); - assertThat(videoGenai).isPresent(); - - assertThat(imageGenai.get().fileData().get().mimeType().get()).isEqualTo("image/jpeg"); - assertThat(audioGenai.get().fileData().get().mimeType().get()).isEqualTo("audio/mpeg"); - assertThat(videoGenai.get().fileData().get().mimeType().get()).isEqualTo("video/mp4"); + assertThat(imageGenai.fileData().get().mimeType().get()).isEqualTo("image/jpeg"); + assertThat(audioGenai.fileData().get().mimeType().get()).isEqualTo("audio/mpeg"); + assertThat(videoGenai.fileData().get().mimeType().get()).isEqualTo("video/mp4"); } } diff --git a/core/src/main/java/com/google/adk/artifacts/CassandraArtifactService.java b/core/src/main/java/com/google/adk/artifacts/CassandraArtifactService.java index 9977381e8..e30656942 100644 --- a/core/src/main/java/com/google/adk/artifacts/CassandraArtifactService.java +++ 
b/core/src/main/java/com/google/adk/artifacts/CassandraArtifactService.java @@ -29,7 +29,7 @@ import io.reactivex.rxjava3.core.Single; import java.util.ArrayList; import java.util.List; -import java.util.Optional; +import org.jspecify.annotations.Nullable; /** * A Cassandra-backed implementation of the {@link BaseArtifactService}. @@ -74,11 +74,11 @@ public Single saveArtifact( @Override public Maybe loadArtifact( - String appName, String userId, String sessionId, String filename, Optional version) { + String appName, String userId, String sessionId, String filename, @Nullable Integer version) { return Maybe.fromCallable( () -> { Row row; - if (version.isPresent()) { + if (version != null) { row = session .execute( @@ -87,7 +87,7 @@ public Maybe loadArtifact( userId, sessionId, filename, - version.get()) + version) .one(); } else { row = @@ -197,9 +197,7 @@ public static void main(String[] args) { // Load the artifact Part loadedArtifact = - artifactService - .loadArtifact(appName, userId, sessionId, filename, Optional.of(version)) - .blockingGet(); + artifactService.loadArtifact(appName, userId, sessionId, filename, version).blockingGet(); System.out.println("Loaded artifact content: " + loadedArtifact.text().get()); CassandraHelper.close(); diff --git a/core/src/main/java/com/google/adk/artifacts/MapDbArtifactService.java b/core/src/main/java/com/google/adk/artifacts/MapDbArtifactService.java index a2a087b9a..7d88d5b00 100644 --- a/core/src/main/java/com/google/adk/artifacts/MapDbArtifactService.java +++ b/core/src/main/java/com/google/adk/artifacts/MapDbArtifactService.java @@ -19,10 +19,10 @@ import io.reactivex.rxjava3.core.Single; import java.io.File; import java.util.NavigableMap; -import java.util.Optional; import java.util.Set; import java.util.logging.Level; import java.util.stream.Collectors; +import org.jspecify.annotations.Nullable; import org.mapdb.BTreeMap; // BTreeMap is suitable for range queries import org.mapdb.DB; import org.mapdb.DBMaker; @@ 
-167,15 +167,15 @@ public Single saveArtifact( */ @Override public Maybe loadArtifact( - String appName, String userId, String sessionId, String filename, Optional version) { + String appName, String userId, String sessionId, String filename, @Nullable Integer version) { // The Callable should return the item (Part) or null. // Maybe.fromCallable will wrap the non-null item in a Maybe or emit empty if null. return Maybe.fromCallable( () -> { String key; - if (version.isPresent()) { + if (version != null) { // Load specific version - int v = version.get(); + int v = version; if (v < 0) { // Version numbers must be non-negative return null; // Return null for empty Maybe } diff --git a/core/src/main/java/com/google/adk/artifacts/MongoDbArtifactService.java b/core/src/main/java/com/google/adk/artifacts/MongoDbArtifactService.java index 0259a3b93..8313b57be 100644 --- a/core/src/main/java/com/google/adk/artifacts/MongoDbArtifactService.java +++ b/core/src/main/java/com/google/adk/artifacts/MongoDbArtifactService.java @@ -5,7 +5,7 @@ import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; -import java.util.Optional; +import org.jspecify.annotations.Nullable; /** * @author Harshavardhan A @@ -22,7 +22,7 @@ public Single saveArtifact( @Override public Maybe loadArtifact( - String appName, String userId, String sessionId, String filename, Optional version) { + String appName, String userId, String sessionId, String filename, @Nullable Integer version) { return null; } diff --git a/core/src/main/java/com/google/adk/artifacts/PostgresArtifactService.java b/core/src/main/java/com/google/adk/artifacts/PostgresArtifactService.java index a47e13fb1..74c60bd8e 100644 --- a/core/src/main/java/com/google/adk/artifacts/PostgresArtifactService.java +++ b/core/src/main/java/com/google/adk/artifacts/PostgresArtifactService.java @@ -26,7 +26,7 @@ import io.reactivex.rxjava3.schedulers.Schedulers; import 
java.sql.SQLException; import java.util.List; -import java.util.Optional; +import org.jspecify.annotations.Nullable; /** * A PostgreSQL-backed implementation of the {@link BaseArtifactService}. @@ -166,14 +166,13 @@ public Single saveArtifact( @Override public Maybe loadArtifact( - String appName, String userId, String sessionId, String filename, Optional version) { + String appName, String userId, String sessionId, String filename, @Nullable Integer version) { return Maybe.fromCallable( () -> { try { // Load from database ArtifactData artifactData = - dbHelper.loadArtifact( - appName, userId, sessionId, filename, version.orElse(null)); + dbHelper.loadArtifact(appName, userId, sessionId, filename, version); if (artifactData == null) { return null; diff --git a/core/src/main/java/com/google/adk/artifacts/RedisArtifactService.java b/core/src/main/java/com/google/adk/artifacts/RedisArtifactService.java index 87c58cdf7..c8553715a 100644 --- a/core/src/main/java/com/google/adk/artifacts/RedisArtifactService.java +++ b/core/src/main/java/com/google/adk/artifacts/RedisArtifactService.java @@ -25,7 +25,7 @@ import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; -import java.util.Optional; +import org.jspecify.annotations.Nullable; import reactor.adapter.rxjava.RxJava3Adapter; /** @@ -72,11 +72,11 @@ public Single saveArtifact( @Override public Maybe loadArtifact( - String appName, String userId, String sessionId, String filename, Optional version) { + String appName, String userId, String sessionId, String filename, @Nullable Integer version) { String key = artifactKey(appName, userId, sessionId, filename); Single data; - if (version.isPresent()) { - data = RxJava3Adapter.monoToSingle(commands.lindex(key, version.get())); + if (version != null) { + data = RxJava3Adapter.monoToSingle(commands.lindex(key, version)); } else { data = RxJava3Adapter.monoToSingle(commands.lindex(key, -1)); } diff --git 
a/core/src/main/java/com/google/adk/sessions/PostgresSessionService.java b/core/src/main/java/com/google/adk/sessions/PostgresSessionService.java index c5d5c94c4..771b77d4d 100644 --- a/core/src/main/java/com/google/adk/sessions/PostgresSessionService.java +++ b/core/src/main/java/com/google/adk/sessions/PostgresSessionService.java @@ -15,6 +15,7 @@ import java.time.Instant; import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.UUID; @@ -255,7 +256,7 @@ public Single appendEvent(Session session, Event event) { // Apply state delta from event actions EventActions actions = event.actions(); if (actions != null) { - ConcurrentMap stateDelta = actions.stateDelta(); + Map stateDelta = actions.stateDelta(); if (stateDelta != null && !stateDelta.isEmpty()) { stateDelta.forEach( (key, value) -> { @@ -339,7 +340,7 @@ private void trimTempDeltaState(Event event) { if (event == null || event.actions() == null || event.actions().stateDelta() == null) { return; } - ConcurrentMap stateDelta = event.actions().stateDelta(); + Map stateDelta = event.actions().stateDelta(); stateDelta.entrySet().removeIf(entry -> entry.getKey().startsWith(State.TEMP_PREFIX)); } diff --git a/core/src/test/java/com/google/adk/artifacts/CassandraArtifactServiceIT.java b/core/src/test/java/com/google/adk/artifacts/CassandraArtifactServiceIT.java index 55080c2f0..54faffbc6 100644 --- a/core/src/test/java/com/google/adk/artifacts/CassandraArtifactServiceIT.java +++ b/core/src/test/java/com/google/adk/artifacts/CassandraArtifactServiceIT.java @@ -25,7 +25,6 @@ import io.reactivex.rxjava3.core.Maybe; import java.net.InetSocketAddress; import java.util.List; -import java.util.Optional; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; @@ -79,9 +78,7 @@ public void testSaveAndLoadArtifact() { assertThat(version).isEqualTo(0); Part loadedArtifact = - 
artifactService - .loadArtifact(appName, userId, sessionId, filename, Optional.of(version)) - .blockingGet(); + artifactService.loadArtifact(appName, userId, sessionId, filename, version).blockingGet(); assertThat(loadedArtifact.text().get()).isEqualTo("hello world"); } @@ -100,7 +97,7 @@ public void testDeleteArtifact() { artifactService.deleteArtifact(appName, userId, sessionId, filename).blockingAwait(); Maybe loadedArtifact = - artifactService.loadArtifact(appName, userId, sessionId, filename, Optional.of(version)); + artifactService.loadArtifact(appName, userId, sessionId, filename, version); assertThat(loadedArtifact.blockingGet()).isNull(); } @@ -135,9 +132,7 @@ public void testSaveAndLoadBinaryArtifact() { assertThat(version).isEqualTo(0); Part loadedBinaryArtifact = - artifactService - .loadArtifact(appName, userId, sessionId, filename, Optional.of(version)) - .blockingGet(); + artifactService.loadArtifact(appName, userId, sessionId, filename, version).blockingGet(); assertThat(loadedBinaryArtifact.inlineData().get().data().get()).isEqualTo(binaryData); assertThat(loadedBinaryArtifact.inlineData().get().mimeType().get()) .isEqualTo("application/octet-stream"); diff --git a/core/src/test/java/com/google/adk/artifacts/CassandraArtifactServiceTest.java b/core/src/test/java/com/google/adk/artifacts/CassandraArtifactServiceTest.java index 8c4f9ca1b..78b528fe1 100644 --- a/core/src/test/java/com/google/adk/artifacts/CassandraArtifactServiceTest.java +++ b/core/src/test/java/com/google/adk/artifacts/CassandraArtifactServiceTest.java @@ -32,7 +32,6 @@ import java.nio.ByteBuffer; import java.util.Collections; import java.util.List; -import java.util.Optional; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -108,9 +107,7 @@ public void testSaveAndLoadArtifact() throws Exception { when(mockObjectMapper.readValue(artifactData, Part.class)).thenReturn(artifact); Part loadedArtifact = - 
artifactService - .loadArtifact(appName, userId, sessionId, filename, Optional.of(version)) - .blockingGet(); + artifactService.loadArtifact(appName, userId, sessionId, filename, version).blockingGet(); assertThat(loadedArtifact).isEqualTo(artifact); } diff --git a/core/src/test/java/com/google/adk/artifacts/PostgresArtifactServiceIT.java b/core/src/test/java/com/google/adk/artifacts/PostgresArtifactServiceIT.java index ef600fd9e..70ecc3d9f 100644 --- a/core/src/test/java/com/google/adk/artifacts/PostgresArtifactServiceIT.java +++ b/core/src/test/java/com/google/adk/artifacts/PostgresArtifactServiceIT.java @@ -21,7 +21,6 @@ import com.google.genai.types.Part; import java.util.List; -import java.util.Optional; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; @@ -122,7 +121,7 @@ public void testSaveAndLoadArtifact_Success() { // Act - Load Part loadedArtifact = artifactService - .loadArtifact(testAppName, testUserId, testSessionId, filename, Optional.empty()) + .loadArtifact(testAppName, testUserId, testSessionId, filename, null) .blockingGet(); // Assert - Content matches @@ -161,15 +160,15 @@ public void testVersioning_MultipleVersions() { // Act - Load specific versions Part loaded0 = artifactService - .loadArtifact(testAppName, testUserId, testSessionId, filename, Optional.of(0)) + .loadArtifact(testAppName, testUserId, testSessionId, filename, 0) .blockingGet(); Part loaded1 = artifactService - .loadArtifact(testAppName, testUserId, testSessionId, filename, Optional.of(1)) + .loadArtifact(testAppName, testUserId, testSessionId, filename, 1) .blockingGet(); Part loaded2 = artifactService - .loadArtifact(testAppName, testUserId, testSessionId, filename, Optional.of(2)) + .loadArtifact(testAppName, testUserId, testSessionId, filename, 2) .blockingGet(); // Assert - Each version contains correct content @@ -180,7 +179,7 @@ public void testVersioning_MultipleVersions() { // Act - Load latest 
(should be version 2) Part loadedLatest = artifactService - .loadArtifact(testAppName, testUserId, testSessionId, filename, Optional.empty()) + .loadArtifact(testAppName, testUserId, testSessionId, filename, null) .blockingGet(); // Assert - Latest is version 2 @@ -224,7 +223,7 @@ public void testDeleteArtifact() { // Verify it exists Part beforeDelete = artifactService - .loadArtifact(testAppName, testUserId, testSessionId, filename, Optional.empty()) + .loadArtifact(testAppName, testUserId, testSessionId, filename, null) .blockingGet(); assertThat(beforeDelete).isNotNull(); @@ -236,7 +235,7 @@ public void testDeleteArtifact() { // Assert - No longer exists Part afterDelete = artifactService - .loadArtifact(testAppName, testUserId, testSessionId, filename, Optional.empty()) + .loadArtifact(testAppName, testUserId, testSessionId, filename, null) .blockingGet(); assertThat(afterDelete).isNull(); } @@ -284,13 +283,9 @@ public void testMultiTenancy_AppNameIsolation() { // Act - Load from each app Part fromApp1 = - artifactService - .loadArtifact(app1, userId, sessionId, filename, Optional.empty()) - .blockingGet(); + artifactService.loadArtifact(app1, userId, sessionId, filename, null).blockingGet(); Part fromApp2 = - artifactService - .loadArtifact(app2, userId, sessionId, filename, Optional.empty()) - .blockingGet(); + artifactService.loadArtifact(app2, userId, sessionId, filename, null).blockingGet(); // Assert - Content is isolated assertThat(fromApp1.text()).isEqualTo("App1 content"); @@ -320,13 +315,9 @@ public void testMultiTenancy_UserIdIsolation() { // Act - Load from each user Part fromUser1 = - artifactService - .loadArtifact(appName, user1, sessionId, filename, Optional.empty()) - .blockingGet(); + artifactService.loadArtifact(appName, user1, sessionId, filename, null).blockingGet(); Part fromUser2 = - artifactService - .loadArtifact(appName, user2, sessionId, filename, Optional.empty()) - .blockingGet(); + artifactService.loadArtifact(appName, user2, 
sessionId, filename, null).blockingGet(); // Assert - Content is isolated assertThat(fromUser1.text()).isEqualTo("User1 content"); @@ -356,13 +347,9 @@ public void testMultiTenancy_SessionIdIsolation() { // Act - Load from each session Part fromSession1 = - artifactService - .loadArtifact(appName, userId, session1, filename, Optional.empty()) - .blockingGet(); + artifactService.loadArtifact(appName, userId, session1, filename, null).blockingGet(); Part fromSession2 = - artifactService - .loadArtifact(appName, userId, session2, filename, Optional.empty()) - .blockingGet(); + artifactService.loadArtifact(appName, userId, session2, filename, null).blockingGet(); // Assert - Content is isolated assertThat(fromSession1.text()).isEqualTo("Session1 content"); @@ -398,8 +385,7 @@ public void testLoadArtifact_NonExistent() { // Act Part result = artifactService - .loadArtifact( - testAppName, testUserId, testSessionId, "nonexistent.txt", Optional.empty()) + .loadArtifact(testAppName, testUserId, testSessionId, "nonexistent.txt", null) .blockingGet(); // Assert @@ -417,7 +403,7 @@ public void testLoadArtifact_NonExistentVersion() { // Act - Try to load non-existent version 99 Part result = artifactService - .loadArtifact(testAppName, testUserId, testSessionId, filename, Optional.of(99)) + .loadArtifact(testAppName, testUserId, testSessionId, filename, 99) .blockingGet(); // Assert diff --git a/core/src/test/java/com/google/adk/artifacts/PostgresArtifactServiceTest.java b/core/src/test/java/com/google/adk/artifacts/PostgresArtifactServiceTest.java index 509f72d9f..53e6d1e20 100644 --- a/core/src/test/java/com/google/adk/artifacts/PostgresArtifactServiceTest.java +++ b/core/src/test/java/com/google/adk/artifacts/PostgresArtifactServiceTest.java @@ -33,7 +33,6 @@ import java.sql.Timestamp; import java.util.Arrays; import java.util.List; -import java.util.Optional; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; 
@@ -176,9 +175,7 @@ public void testLoadArtifact_LatestVersion() throws Exception { // Act Part loadedArtifact = - artifactService - .loadArtifact(appName, userId, sessionId, filename, Optional.empty()) - .blockingGet(); + artifactService.loadArtifact(appName, userId, sessionId, filename, null).blockingGet(); // Assert assertThat(loadedArtifact).isNotNull(); @@ -207,9 +204,7 @@ public void testLoadArtifact_SpecificVersion() throws Exception { // Act Part loadedArtifact = - artifactService - .loadArtifact(appName, userId, sessionId, filename, Optional.of(version)) - .blockingGet(); + artifactService.loadArtifact(appName, userId, sessionId, filename, version).blockingGet(); // Assert assertThat(loadedArtifact).isNotNull(); @@ -231,9 +226,7 @@ public void testLoadArtifact_NotFound() throws Exception { // Act Part loadedArtifact = - artifactService - .loadArtifact(appName, userId, sessionId, filename, Optional.empty()) - .blockingGet(); + artifactService.loadArtifact(appName, userId, sessionId, filename, null).blockingGet(); // Assert assertThat(loadedArtifact).isNull(); diff --git a/core/src/test/java/com/google/adk/artifacts/RedisArtifactServiceIT.java b/core/src/test/java/com/google/adk/artifacts/RedisArtifactServiceIT.java index 92e1772bf..cdd3a0f3b 100644 --- a/core/src/test/java/com/google/adk/artifacts/RedisArtifactServiceIT.java +++ b/core/src/test/java/com/google/adk/artifacts/RedisArtifactServiceIT.java @@ -20,7 +20,6 @@ import com.google.adk.store.RedisHelper; import com.google.genai.types.Part; -import java.util.Optional; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; @@ -64,9 +63,7 @@ public void testSaveAndLoadArtifact() { assertThat(version).isEqualTo(0); Part loadedArtifact = - artifactService - .loadArtifact(appName, userId, sessionId, filename, Optional.of(version)) - .blockingGet(); + artifactService.loadArtifact(appName, userId, sessionId, filename, version).blockingGet(); 
assertThat(loadedArtifact.text().get()).isEqualTo("hello world"); } @@ -84,9 +81,7 @@ public void testSaveAndLoadBinaryArtifact() { assertThat(version).isEqualTo(0); Part loadedArtifact = - artifactService - .loadArtifact(appName, userId, sessionId, filename, Optional.of(version)) - .blockingGet(); + artifactService.loadArtifact(appName, userId, sessionId, filename, version).blockingGet(); assertThat(loadedArtifact.inlineData().get().data().get()).isEqualTo(binaryData); assertThat(loadedArtifact.inlineData().get().mimeType().get()) .isEqualTo("application/octet-stream");