From 8624d59de41205800e5538e51e2b1416b79bca48 Mon Sep 17 00:00:00 2001 From: Michael Vorburger Date: Tue, 3 Mar 2026 12:44:46 +0100 Subject: [PATCH 01/50] dev: Introduce initial AGENTS.md Intentionally named AGENTS.md instead of e.g. GEMINI.md to be fully model neutral; see https://agents.md for background. --- AGENTS.md | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 AGENTS.md diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 000000000..5d33d2172 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,3 @@ +# AGENTS.md + +Validate changes by running `./mvnw test`. From 82ef5ac2689e01676aa95d2616e3b4d8463e573e Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Mon, 9 Mar 2026 03:45:34 -0700 Subject: [PATCH 02/50] feat!: remove McpAsyncToolset constructors PiperOrigin-RevId: 880768893 --- .../google/adk/tools/mcp/McpAsyncToolset.java | 46 +++++++------------ 1 file changed, 16 insertions(+), 30 deletions(-) diff --git a/core/src/main/java/com/google/adk/tools/mcp/McpAsyncToolset.java b/core/src/main/java/com/google/adk/tools/mcp/McpAsyncToolset.java index 73af9cc6a..bcc786d69 100644 --- a/core/src/main/java/com/google/adk/tools/mcp/McpAsyncToolset.java +++ b/core/src/main/java/com/google/adk/tools/mcp/McpAsyncToolset.java @@ -22,6 +22,8 @@ import com.google.adk.tools.BaseTool; import com.google.adk.tools.BaseToolset; import com.google.adk.tools.NamedToolPredicate; +import com.google.adk.tools.ToolPredicate; +import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.CanIgnoreReturnValue; import io.modelcontextprotocol.client.McpAsyncClient; @@ -32,8 +34,8 @@ import java.time.Duration; import java.util.List; import java.util.Objects; -import java.util.Optional; import java.util.concurrent.atomic.AtomicReference; +import javax.annotation.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Mono; @@ -59,14 +61,14 @@ public class McpAsyncToolset implements BaseToolset { private final McpSessionManager mcpSessionManager; private final ObjectMapper objectMapper; - private final Optional toolFilter; + private final @Nullable Object toolFilter; private final AtomicReference>> mcpTools = new AtomicReference<>(); /** Builder for McpAsyncToolset */ public static class Builder { private Object connectionParams = null; private ObjectMapper objectMapper = null; - private Optional toolFilter = null; + private @Nullable Object toolFilter = null; @CanIgnoreReturnValue public Builder connectionParams(ServerParameters connectionParams) { @@ -87,14 +89,14 @@ public Builder objectMapper(ObjectMapper objectMapper) { } @CanIgnoreReturnValue - public Builder toolFilter(Optional toolFilter) { - this.toolFilter = toolFilter; + public Builder toolFilter(List toolNames) { + this.toolFilter = new NamedToolPredicate(Preconditions.checkNotNull(toolNames)); return this; } @CanIgnoreReturnValue - public Builder toolFilter(List toolNames) { - this.toolFilter = Optional.of(new NamedToolPredicate(toolNames)); + public Builder toolFilter(@Nullable ToolPredicate toolPredicate) { + this.toolFilter = toolPredicate; return this; } @@ -118,12 +120,12 @@ public McpAsyncToolset build() { * * @param connectionParams The SSE connection parameters to the MCP server. * @param objectMapper An ObjectMapper instance for parsing schemas. - * @param toolFilter An Optional containing either a ToolPredicate or a List of tool names. + * @param toolFilter Either a ToolPredicate or a List of tool names. */ - public McpAsyncToolset( + McpAsyncToolset( SseServerParameters connectionParams, ObjectMapper objectMapper, - Optional toolFilter) { + @Nullable Object toolFilter) { Objects.requireNonNull(connectionParams); Objects.requireNonNull(objectMapper); this.objectMapper = objectMapper; @@ -136,10 +138,10 @@ public McpAsyncToolset( * * @param connectionParams The local server connection parameters to the MCP server. * @param objectMapper An ObjectMapper instance for parsing schemas. - * @param toolFilter An Optional containing either a ToolPredicate or a List of tool names. + * @param toolFilter Either a ToolPredicate or a List of tool names or null. */ - public McpAsyncToolset( - ServerParameters connectionParams, ObjectMapper objectMapper, Optional toolFilter) { + McpAsyncToolset( + ServerParameters connectionParams, ObjectMapper objectMapper, @Nullable Object toolFilter) { Objects.requireNonNull(connectionParams); Objects.requireNonNull(objectMapper); this.objectMapper = objectMapper; @@ -147,22 +149,6 @@ public McpAsyncToolset( this.toolFilter = toolFilter; } - /** - * Initializes the McpAsyncToolset with a provided McpSessionManager. - * - * @param mcpSessionManager The session manager for MCP connections. - * @param objectMapper An ObjectMapper instance for parsing schemas. - * @param toolFilter An Optional containing either a ToolPredicate or a List of tool names. - */ - public McpAsyncToolset( - McpSessionManager mcpSessionManager, ObjectMapper objectMapper, Optional toolFilter) { - Objects.requireNonNull(mcpSessionManager); - Objects.requireNonNull(objectMapper); - this.objectMapper = objectMapper; - this.mcpSessionManager = mcpSessionManager; - this.toolFilter = toolFilter; - } - @Override public Flowable getTools(ReadonlyContext readonlyContext) { return Maybe.defer(() -> Maybe.fromCompletionStage(this.initAndGetTools().toFuture())) @@ -170,7 +156,7 @@ public Flowable getTools(ReadonlyContext readonlyContext) { .map( tools -> tools.stream() - .filter(tool -> isToolSelected(tool, toolFilter.orElse(null), readonlyContext)) + .filter(tool -> isToolSelected(tool, toolFilter, readonlyContext)) .toList()) .onErrorResumeNext( err -> { From 5e4eaa4805f10e03e88c9433c30dbe844c0161ec Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Mon, 9 Mar 2026 04:26:08 -0700 Subject: [PATCH 03/50] refactor: remove use of Optional params in Contents class PiperOrigin-RevId: 880784641 --- .../google/adk/flows/llmflows/Contents.java | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Contents.java b/core/src/main/java/com/google/adk/flows/llmflows/Contents.java index ca8e0a051..6ebd39a9c 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Contents.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Contents.java @@ -25,6 +25,7 @@ import com.google.adk.events.Event; import com.google.adk.events.EventCompaction; import com.google.adk.models.LlmRequest; +import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; @@ -41,6 +42,7 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import javax.annotation.Nullable; /** {@link RequestProcessor} that populates content in request for LLM flows. */ public final class Contents implements RequestProcessor { @@ -68,7 +70,7 @@ public Single processRequest( request.toBuilder() .contents( getCurrentTurnContents( - context.branch(), + context.branch().orElse(null), context.session().events(), context.agent().name(), modelName)) @@ -78,7 +80,10 @@ public Single processRequest( ImmutableList contents = getContents( - context.branch(), context.session().events(), context.agent().name(), modelName); + context.branch().orElse(null), + context.session().events(), + context.agent().name(), + modelName); return Single.just( RequestProcessor.RequestProcessingResult.create( @@ -87,7 +92,7 @@ public Single processRequest( /** Gets contents for the current turn only (no conversation history). */ private ImmutableList getCurrentTurnContents( - Optional currentBranch, List events, String agentName, String modelName) { + @Nullable String currentBranch, List events, String agentName, String modelName) { // Find the latest event that starts the current turn and process from there. for (int i = events.size() - 1; i >= 0; i--) { Event event = events.get(i); @@ -99,7 +104,7 @@ private ImmutableList getCurrentTurnContents( } private ImmutableList getContents( - Optional currentBranch, List events, String agentName, String modelName) { + @Nullable String currentBranch, List events, String agentName, String modelName) { List filteredEvents = new ArrayList<>(); boolean hasCompactEvent = false; @@ -414,16 +419,12 @@ private static String convertMapToJson(Map struct) { } } - private static boolean isEventBelongsToBranch(Optional invocationBranchOpt, Event event) { - Optional eventBranchOpt = event.branch(); + private static boolean isEventBelongsToBranch(@Nullable String invocationBranch, Event event) { + @Nullable String eventBranch = event.branch().orElse(null); - if (invocationBranchOpt.isEmpty() || invocationBranchOpt.get().isEmpty()) { - return true; - } - if (eventBranchOpt.isEmpty() || eventBranchOpt.get().isEmpty()) { - return true; - } - return invocationBranchOpt.get().startsWith(eventBranchOpt.get()); + return Strings.isNullOrEmpty(invocationBranch) + || Strings.isNullOrEmpty(eventBranch) + || invocationBranch.startsWith(eventBranch); } /** From 1cb4d431ca62573a8071b244e64971e6d60a42ad Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 9 Mar 2026 05:28:00 -0700 Subject: [PATCH 04/50] refactor: Move RemoteA2AAgent to agent package; remove EXPERIMENTAL annotation; add LICENSE header PiperOrigin-RevId: 880804505 --- .../adk/a2a/{ => agent}/RemoteA2AAgent.java | 40 ++++++++++++------- .../google/adk/a2a/common/A2AClientError.java | 15 +++++++ .../google/adk/a2a/common/A2AMetadata.java | 15 +++++++ .../common/GenAiFieldMissingException.java | 15 +++++++ .../converters/A2ADataPartMetadataType.java | 15 +++++++ .../adk/a2a/converters/EventConverter.java | 22 +++++++--- .../adk/a2a/converters/PartConverter.java | 22 +++++++--- .../adk/a2a/converters/ResponseConverter.java | 22 +++++++--- .../adk/a2a/executor/AgentExecutor.java | 22 +++++++--- .../adk/a2a/executor/AgentExecutorConfig.java | 15 +++++++ .../google/adk/a2a/executor/Callbacks.java | 15 +++++++ .../a2a/{ => agent}/RemoteA2AAgentTest.java | 2 +- contrib/samples/a2a_basic/A2AAgent.java | 2 +- 13 files changed, 181 insertions(+), 41 deletions(-) rename a2a/src/main/java/com/google/adk/a2a/{ => agent}/RemoteA2AAgent.java (93%) rename a2a/src/test/java/com/google/adk/a2a/{ => agent}/RemoteA2AAgentTest.java (99%) diff --git a/a2a/src/main/java/com/google/adk/a2a/RemoteA2AAgent.java b/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java similarity index 93% rename from a2a/src/main/java/com/google/adk/a2a/RemoteA2AAgent.java rename to a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java index b391f2985..021786162 100644 --- a/a2a/src/main/java/com/google/adk/a2a/RemoteA2AAgent.java +++ b/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java @@ -1,4 +1,19 @@ -package com.google.adk.a2a; +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.a2a.agent; import static com.google.common.base.Strings.nullToEmpty; @@ -44,26 +59,21 @@ import org.slf4j.LoggerFactory; /** - * Agent that communicates with a remote A2A agent via A2A client. - * - *

This agent supports multiple ways to specify the remote agent: + * Agent that communicates with a remote A2A agent via an A2A client. * - *

    - *
  1. Direct AgentCard object - *
  2. URL to agent card JSON - *
  3. File path to agent card JSON - *
+ *

The remote agent can be specified directly by providing an {@link AgentCard} to the builder, + * or it can be resolved automatically using the provided A2A client. * - *

The agent handles: + *

Key responsibilities of this agent include: * *

    *
  • Agent card resolution and validation - *
  • A2A message conversion and error handling - *
  • Session state management across requests + *
  • Converting ADK session history events into A2A requests ({@link io.a2a.spec.Message}) + *
  • Handling streaming and non-streaming responses from the A2A client + *
  • Buffering and aggregating streamed response chunks into ADK {@link + * com.google.adk.events.Event}s + *
  • Converting A2A client responses back into ADK format *
- * - *

**EXPERIMENTAL:** Subject to change, rename, or removal in any future patch release. Do not - * use in production code. */ public class RemoteA2AAgent extends BaseAgent { diff --git a/a2a/src/main/java/com/google/adk/a2a/common/A2AClientError.java b/a2a/src/main/java/com/google/adk/a2a/common/A2AClientError.java index 8e8282742..466c89223 100644 --- a/a2a/src/main/java/com/google/adk/a2a/common/A2AClientError.java +++ b/a2a/src/main/java/com/google/adk/a2a/common/A2AClientError.java @@ -1,3 +1,18 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.google.adk.a2a.common; /** Exception thrown when the A2A client encounters an error. */ diff --git a/a2a/src/main/java/com/google/adk/a2a/common/A2AMetadata.java b/a2a/src/main/java/com/google/adk/a2a/common/A2AMetadata.java index 5c75faeac..a5faeff2a 100644 --- a/a2a/src/main/java/com/google/adk/a2a/common/A2AMetadata.java +++ b/a2a/src/main/java/com/google/adk/a2a/common/A2AMetadata.java @@ -1,3 +1,18 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.google.adk.a2a.common; /** Constants and utilities for A2A metadata keys. */ diff --git a/a2a/src/main/java/com/google/adk/a2a/common/GenAiFieldMissingException.java b/a2a/src/main/java/com/google/adk/a2a/common/GenAiFieldMissingException.java index a5947dcb8..0ac56fc01 100644 --- a/a2a/src/main/java/com/google/adk/a2a/common/GenAiFieldMissingException.java +++ b/a2a/src/main/java/com/google/adk/a2a/common/GenAiFieldMissingException.java @@ -1,3 +1,18 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.google.adk.a2a.common; /** Exception thrown when the the genai class has an empty field. */ diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/A2ADataPartMetadataType.java b/a2a/src/main/java/com/google/adk/a2a/converters/A2ADataPartMetadataType.java index b5b53c49a..e0e97c8e9 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/A2ADataPartMetadataType.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/A2ADataPartMetadataType.java @@ -1,3 +1,18 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.google.adk.a2a.converters; /** Enum for the type of A2A DataPart metadata. */ diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/EventConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/EventConverter.java index 1a49b0070..d823e3817 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/EventConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/EventConverter.java @@ -1,3 +1,18 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.google.adk.a2a.converters; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -13,12 +28,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * Converter for ADK Events to A2A Messages. - * - *

**EXPERIMENTAL:** Subject to change, rename, or removal in any future patch release. Do not - * use in production code. - */ +/** Converter for ADK Events to A2A Messages. */ public final class EventConverter { private static final Logger logger = LoggerFactory.getLogger(EventConverter.class); diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java index 05125d170..96ef66bc8 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java @@ -1,3 +1,18 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.google.adk.a2a.converters; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -32,12 +47,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * Utility class for converting between Google GenAI Parts and A2A DataParts. - * - *

**EXPERIMENTAL:** Subject to change, rename, or removal in any future patch release. Do not - * use in production code. - */ +/** Utility class for converting between Google GenAI Parts and A2A DataParts. */ public final class PartConverter { private static final Logger logger = LoggerFactory.getLogger(PartConverter.class); diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java index 57a84b58f..f3be48c1b 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java @@ -1,3 +1,18 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.google.adk.a2a.converters; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -27,12 +42,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * Utility for converting ADK events to A2A spec messages (and back). - * - *

**EXPERIMENTAL:** Subject to change, rename, or removal in any future patch release. Do not - * use in production code. - */ +/** Utility for converting ADK events to A2A spec messages (and back). */ public final class ResponseConverter { private static final Logger logger = LoggerFactory.getLogger(ResponseConverter.class); private static final ImmutableSet PENDING_STATES = diff --git a/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutor.java b/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutor.java index b7b4e9953..7252cdec1 100644 --- a/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutor.java +++ b/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutor.java @@ -1,3 +1,18 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.google.adk.a2a.executor; import static java.util.Objects.requireNonNull; @@ -44,12 +59,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * Implementation of the A2A AgentExecutor interface that uses ADK to execute agent tasks. - * - *

**EXPERIMENTAL:** Subject to change, rename, or removal in any future patch release. Do not - * use in production code. - */ +/** Implementation of the A2A AgentExecutor interface that uses ADK to execute agent tasks. */ public class AgentExecutor implements io.a2a.server.agentexecution.AgentExecutor { private static final Logger logger = LoggerFactory.getLogger(AgentExecutor.class); private static final String USER_ID_PREFIX = "A2A_USER_"; diff --git a/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutorConfig.java b/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutorConfig.java index ba0177dc4..3ee8656d2 100644 --- a/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutorConfig.java +++ b/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutorConfig.java @@ -1,3 +1,18 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.google.adk.a2a.executor; import com.google.adk.a2a.executor.Callbacks.AfterEventCallback; diff --git a/a2a/src/main/java/com/google/adk/a2a/executor/Callbacks.java b/a2a/src/main/java/com/google/adk/a2a/executor/Callbacks.java index 666f1d8a0..3483c527f 100644 --- a/a2a/src/main/java/com/google/adk/a2a/executor/Callbacks.java +++ b/a2a/src/main/java/com/google/adk/a2a/executor/Callbacks.java @@ -1,3 +1,18 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.google.adk.a2a.executor; import com.google.adk.events.Event; diff --git a/a2a/src/test/java/com/google/adk/a2a/RemoteA2AAgentTest.java b/a2a/src/test/java/com/google/adk/a2a/agent/RemoteA2AAgentTest.java similarity index 99% rename from a2a/src/test/java/com/google/adk/a2a/RemoteA2AAgentTest.java rename to a2a/src/test/java/com/google/adk/a2a/agent/RemoteA2AAgentTest.java index 87eaa2321..e75da64ba 100644 --- a/a2a/src/test/java/com/google/adk/a2a/RemoteA2AAgentTest.java +++ b/a2a/src/test/java/com/google/adk/a2a/agent/RemoteA2AAgentTest.java @@ -1,4 +1,4 @@ -package com.google.adk.a2a; +package com.google.adk.a2a.agent; import static com.google.common.truth.Truth.assertThat; import static java.util.concurrent.TimeUnit.SECONDS; diff --git a/contrib/samples/a2a_basic/A2AAgent.java b/contrib/samples/a2a_basic/A2AAgent.java index e4e79a4eb..e08a87a67 100644 --- a/contrib/samples/a2a_basic/A2AAgent.java +++ b/contrib/samples/a2a_basic/A2AAgent.java @@ -1,6 +1,6 @@ package com.example.a2a_basic; -import com.google.adk.a2a.RemoteA2AAgent; +import com.google.adk.a2a.agent.RemoteA2AAgent; import com.google.adk.agents.BaseAgent; import com.google.adk.agents.LlmAgent; import com.google.adk.tools.FunctionTool; From 5e1e1d434fa1f3931af30194422800757de96cb6 Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Mon, 9 Mar 2026 06:57:20 -0700 Subject: [PATCH 05/50] feat!: Remove deprecated create method in ResponseProcessor PiperOrigin-RevId: 880834921 --- .../com/google/adk/flows/llmflows/CodeExecution.java | 5 ++--- .../google/adk/flows/llmflows/ResponseProcessor.java | 11 ++++++++--- .../google/adk/flows/llmflows/BaseLlmFlowTest.java | 9 ++------- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java b/core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java index f7c3c51ef..f2cbe967e 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java @@ -159,8 +159,7 @@ public Single processResponse( InvocationContext invocationContext, LlmResponse llmResponse) { if (llmResponse.partial().orElse(false)) { return Single.just( - ResponseProcessor.ResponseProcessingResult.create( - llmResponse, ImmutableList.of(), Optional.empty())); + ResponseProcessor.ResponseProcessingResult.create(llmResponse, ImmutableList.of())); } var llmResponseBuilder = llmResponse.toBuilder(); return runPostProcessor(invocationContext, llmResponseBuilder) @@ -168,7 +167,7 @@ public Single processResponse( .map( events -> ResponseProcessor.ResponseProcessingResult.create( - llmResponseBuilder.build(), events, Optional.empty())); + llmResponseBuilder.build(), events)); } } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/ResponseProcessor.java b/core/src/main/java/com/google/adk/flows/llmflows/ResponseProcessor.java index 4baa29523..d8e5ce3ab 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/ResponseProcessor.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/ResponseProcessor.java @@ -50,11 +50,16 @@ public abstract static class ResponseProcessingResult { */ public abstract Optional transferToAgent(); - /** Creates a new {@link ResponseProcessingResult}. */ public static ResponseProcessingResult create( - LlmResponse updatedResponse, Iterable events, Optional transferToAgent) { + LlmResponse updatedResponse, Iterable events, String transferToAgent) { return new AutoValue_ResponseProcessor_ResponseProcessingResult( - updatedResponse, events, transferToAgent); + updatedResponse, events, Optional.of(transferToAgent)); + } + + public static ResponseProcessingResult create( + LlmResponse updatedResponse, Iterable events) { + return new AutoValue_ResponseProcessor_ResponseProcessingResult( + updatedResponse, events, /* transferToAgent= */ Optional.empty()); } } diff --git a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java index ff151a0b2..4a0b345c6 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java @@ -524,19 +524,14 @@ private static RequestProcessor createRequestProcessor( private static ResponseProcessor createResponseProcessor() { return (context, response) -> - Single.just( - ResponseProcessingResult.create( - response, ImmutableList.of(), /* transferToAgent= */ Optional.empty())); + Single.just(ResponseProcessingResult.create(response, ImmutableList.of())); } private static ResponseProcessor createResponseProcessor( Function responseUpdater) { return (context, response) -> Single.just( - ResponseProcessingResult.create( - responseUpdater.apply(response), - ImmutableList.of(), - /* transferToAgent= */ Optional.empty())); + ResponseProcessingResult.create(responseUpdater.apply(response), ImmutableList.of())); } private static class TestTool extends BaseTool { From a86ede007c3442ed73ee08a5c6ad0e2efa12998a Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Mon, 9 Mar 2026 07:08:54 -0700 Subject: [PATCH 06/50] feat!: remove deprecated url method in ComputerState.Builder PiperOrigin-RevId: 880839286 --- .../adk/tools/computeruse/ComputerState.java | 17 +++----- .../computeruse/ComputerUseToolTest.java | 9 +--- .../computeruse/ComputerUseToolsetTest.java | 42 +++++++------------ 3 files changed, 22 insertions(+), 46 deletions(-) diff --git a/core/src/main/java/com/google/adk/tools/computeruse/ComputerState.java b/core/src/main/java/com/google/adk/tools/computeruse/ComputerState.java index 4f3be46c2..b3d0f73bb 100644 --- a/core/src/main/java/com/google/adk/tools/computeruse/ComputerState.java +++ b/core/src/main/java/com/google/adk/tools/computeruse/ComputerState.java @@ -22,6 +22,7 @@ import java.util.Arrays; import java.util.Objects; import java.util.Optional; +import org.jspecify.annotations.Nullable; /** * Represents the current state of the computer environment. @@ -31,11 +32,11 @@ */ public final class ComputerState { private final byte[] screenshot; - private final Optional url; + private final @Nullable String url; @JsonCreator private ComputerState( - @JsonProperty("screenshot") byte[] screenshot, @JsonProperty("url") Optional url) { + @JsonProperty("screenshot") byte[] screenshot, @JsonProperty("url") @Nullable String url) { this.screenshot = screenshot.clone(); this.url = url; } @@ -47,7 +48,7 @@ public byte[] screenshot() { @JsonProperty("url") public Optional url() { - return url; + return Optional.ofNullable(url); } public static Builder builder() { @@ -57,7 +58,7 @@ public static Builder builder() { /** Builder for {@link ComputerState}. */ public static final class Builder { private byte[] screenshot; - private Optional url = Optional.empty(); + private @Nullable String url; @CanIgnoreReturnValue public Builder screenshot(byte[] screenshot) { @@ -66,17 +67,11 @@ public Builder screenshot(byte[] screenshot) { } @CanIgnoreReturnValue - public Builder url(Optional url) { + public Builder url(@Nullable String url) { this.url = url; return this; } - @CanIgnoreReturnValue - public Builder url(String url) { - this.url = Optional.ofNullable(url); - return this; - } - public ComputerState build() { return new ComputerState(screenshot, url); } diff --git a/core/src/test/java/com/google/adk/tools/computeruse/ComputerUseToolTest.java b/core/src/test/java/com/google/adk/tools/computeruse/ComputerUseToolTest.java index 20fb146cf..236172b27 100644 --- a/core/src/test/java/com/google/adk/tools/computeruse/ComputerUseToolTest.java +++ b/core/src/test/java/com/google/adk/tools/computeruse/ComputerUseToolTest.java @@ -30,7 +30,6 @@ import java.lang.reflect.Method; import java.util.Base64; import java.util.Map; -import java.util.Optional; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -128,10 +127,7 @@ public void testNormalizeDragAndDrop() throws NoSuchMethodException { public void testResultFormatting() throws NoSuchMethodException { byte[] screenshot = new byte[] {1, 2, 3}; computerMock.nextState = - ComputerState.builder() - .screenshot(screenshot) - .url(Optional.of("https://example.com")) - .build(); + ComputerState.builder().screenshot(screenshot).url("https://example.com").build(); Method method = ComputerMock.class.getMethod("clickAt", int.class, int.class); ComputerUseTool tool = @@ -226,8 +222,7 @@ public static class ComputerMock { public int lastY; public int lastDestX; public int lastDestY; - public ComputerState nextState = - ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build(); + public ComputerState nextState = ComputerState.builder().screenshot(new byte[0]).build(); public Single clickAt(@Schema(name = "x") int x, @Schema(name = "y") int y) { this.lastX = x; diff --git a/core/src/test/java/com/google/adk/tools/computeruse/ComputerUseToolsetTest.java b/core/src/test/java/com/google/adk/tools/computeruse/ComputerUseToolsetTest.java index 1ed49419e..8051a018d 100644 --- a/core/src/test/java/com/google/adk/tools/computeruse/ComputerUseToolsetTest.java +++ b/core/src/test/java/com/google/adk/tools/computeruse/ComputerUseToolsetTest.java @@ -173,87 +173,73 @@ public Single environment() { @Override public Single openWebBrowser() { - return Single.just( - ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + return Single.just(ComputerState.builder().screenshot(new byte[0]).build()); } @Override public Single clickAt(int x, int y) { - return Single.just( - ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + return Single.just(ComputerState.builder().screenshot(new byte[0]).build()); } @Override public Single hoverAt(int x, int y) { - return Single.just( - ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + return Single.just(ComputerState.builder().screenshot(new byte[0]).build()); } @Override public Single typeTextAt( int x, int y, String text, Boolean pressEnter, Boolean clearBeforeTyping) { - return Single.just( - ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + return Single.just(ComputerState.builder().screenshot(new byte[0]).build()); } @Override public Single scrollDocument(String direction) { - return Single.just( - ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + return Single.just(ComputerState.builder().screenshot(new byte[0]).build()); } @Override public Single scrollAt(int x, int y, String direction, int magnitude) { - return Single.just( - ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + return Single.just(ComputerState.builder().screenshot(new byte[0]).build()); } @Override public Single wait(Duration duration) { - return Single.just( - ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + return Single.just(ComputerState.builder().screenshot(new byte[0]).build()); } @Override public Single goBack() { - return Single.just( - ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + return Single.just(ComputerState.builder().screenshot(new byte[0]).build()); } @Override public Single goForward() { - return Single.just( - ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + return Single.just(ComputerState.builder().screenshot(new byte[0]).build()); } @Override public Single search() { - return Single.just( - ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + return Single.just(ComputerState.builder().screenshot(new byte[0]).build()); } @Override public Single navigate(String url) { - return Single.just( - ComputerState.builder().screenshot(new byte[0]).url(Optional.of(url)).build()); + return Single.just(ComputerState.builder().screenshot(new byte[0]).url(url).build()); } @Override public Single keyCombination(List keys) { - return Single.just( - ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + return Single.just(ComputerState.builder().screenshot(new byte[0]).build()); } @Override public Single dragAndDrop(int x, int y, int destinationX, int destinationY) { - return Single.just( - ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + return Single.just(ComputerState.builder().screenshot(new byte[0]).build()); } @Override public Single currentState() { - return Single.just( - ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + return Single.just(ComputerState.builder().screenshot(new byte[0]).build()); } @Override From 143b656949d61363d135e0b74ef5696e78eb270a Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 9 Mar 2026 09:47:49 -0700 Subject: [PATCH 07/50] feat: update return type for requestedToolConfirmations getter and setter to Map from ConcurrentMap PiperOrigin-RevId: 880903703 --- .../com/google/adk/events/EventActions.java | 25 +++++++++++++++---- .../google/adk/events/EventActionsTest.java | 24 ++++++++++++++++++ 2 files changed, 44 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/com/google/adk/events/EventActions.java b/core/src/main/java/com/google/adk/events/EventActions.java index bf25acfc7..31b096930 100644 --- a/core/src/main/java/com/google/adk/events/EventActions.java +++ b/core/src/main/java/com/google/adk/events/EventActions.java @@ -165,13 +165,20 @@ public void setRequestedAuthConfigs( } @JsonProperty("requestedToolConfirmations") - public ConcurrentMap requestedToolConfirmations() { + public Map requestedToolConfirmations() { return requestedToolConfirmations; } public void setRequestedToolConfirmations( - ConcurrentMap requestedToolConfirmations) { - this.requestedToolConfirmations = requestedToolConfirmations; + Map requestedToolConfirmations) { + if (requestedToolConfirmations == null) { + this.requestedToolConfirmations = new ConcurrentHashMap<>(); + } else if (requestedToolConfirmations instanceof ConcurrentMap) { + this.requestedToolConfirmations = + (ConcurrentMap) requestedToolConfirmations; + } else { + this.requestedToolConfirmations = new ConcurrentHashMap<>(requestedToolConfirmations); + } } @JsonProperty("endOfAgent") @@ -351,8 +358,16 @@ public Builder requestedAuthConfigs( @CanIgnoreReturnValue @JsonProperty("requestedToolConfirmations") - public Builder requestedToolConfirmations(ConcurrentMap value) { - this.requestedToolConfirmations = value; + public Builder requestedToolConfirmations(@Nullable Map value) { + if (value == null) { + this.requestedToolConfirmations = new ConcurrentHashMap<>(); + return this; + } + if (value instanceof ConcurrentMap) { + this.requestedToolConfirmations = (ConcurrentMap) value; + } else { + this.requestedToolConfirmations = new ConcurrentHashMap<>(value); + } return this; } diff --git a/core/src/test/java/com/google/adk/events/EventActionsTest.java b/core/src/test/java/com/google/adk/events/EventActionsTest.java index 28123bab8..2975ca83f 100644 --- a/core/src/test/java/com/google/adk/events/EventActionsTest.java +++ b/core/src/test/java/com/google/adk/events/EventActionsTest.java @@ -26,6 +26,7 @@ import com.google.genai.types.Part; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -165,4 +166,27 @@ public void merge_failsOnMismatchedKeyTypesNestedInStateDelta() { assertThrows( IllegalArgumentException.class, () -> eventActions1.toBuilder().merge(eventActions2)); } + + @Test + public void setRequestedToolConfirmations_withConcurrentMap_usesSameInstance() { + ConcurrentHashMap map = new ConcurrentHashMap<>(); + map.put("tool", TOOL_CONFIRMATION); + + EventActions actions = new EventActions(); + actions.setRequestedToolConfirmations(map); + + assertThat(actions.requestedToolConfirmations()).isSameInstanceAs(map); + } + + @Test + public void setRequestedToolConfirmations_withRegularMap_createsConcurrentMap() { + ImmutableMap map = ImmutableMap.of("tool", TOOL_CONFIRMATION); + + EventActions actions = new EventActions(); + actions.setRequestedToolConfirmations(map); + + assertThat(actions.requestedToolConfirmations()).isNotSameInstanceAs(map); + assertThat(actions.requestedToolConfirmations()).isInstanceOf(ConcurrentMap.class); + assertThat(actions.requestedToolConfirmations()).containsExactly("tool", TOOL_CONFIRMATION); + } } From 973f88743cabebcd2e6e7a8d5f141142b596dbbb Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 9 Mar 2026 13:13:51 -0700 Subject: [PATCH 08/50] feat: Fixing the spans produced by agent calls to have the right parent spans PiperOrigin-RevId: 881003835 --- .../java/com/google/adk/agents/BaseAgent.java | 107 ++-- .../adk/flows/llmflows/BaseLlmFlow.java | 390 ++++++++------ .../google/adk/flows/llmflows/Functions.java | 227 ++++---- .../com/google/adk/plugins/PluginManager.java | 16 +- .../java/com/google/adk/runner/Runner.java | 167 +++--- .../com/google/adk/telemetry/Tracing.java | 488 +++++++++++++----- 6 files changed, 887 insertions(+), 508 deletions(-) diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index d74ba9ca5..c527eeab3 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -29,10 +29,10 @@ import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.errorprone.annotations.DoNotCall; import com.google.genai.types.Content; +import io.opentelemetry.context.Context; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; -import io.reactivex.rxjava3.core.Single; import java.util.ArrayList; import java.util.HashSet; import java.util.List; @@ -312,38 +312,47 @@ public Flowable runAsync(InvocationContext parentContext) { private Flowable run( InvocationContext parentContext, Function> runImplementation) { + Context otelParentContext = Context.current(); + InvocationContext invocationContext = createInvocationContext(parentContext); + return Flowable.defer( - () -> { - InvocationContext invocationContext = createInvocationContext(parentContext); - - return callCallback( - beforeCallbacksToFunctions( - invocationContext.pluginManager(), beforeAgentCallback), - invocationContext) - .flatMapPublisher( - beforeEventOpt -> { - if (invocationContext.endInvocation()) { - return Flowable.fromOptional(beforeEventOpt); - } - - Flowable beforeEvents = Flowable.fromOptional(beforeEventOpt); - Flowable mainEvents = - Flowable.defer(() -> runImplementation.apply(invocationContext)); - Flowable afterEvents = - Flowable.defer( - () -> - callCallback( - afterCallbacksToFunctions( - invocationContext.pluginManager(), afterAgentCallback), - invocationContext) - .flatMapPublisher(Flowable::fromOptional)); - - return Flowable.concat(beforeEvents, mainEvents, afterEvents); - }) - .compose( - Tracing.traceAgent( - "invoke_agent " + name(), name(), description(), invocationContext)); - }); + () -> { + return callCallback( + beforeCallbacksToFunctions( + invocationContext.pluginManager(), beforeAgentCallback), + invocationContext) + .flatMapPublisher( + beforeEvent -> { + if (invocationContext.endInvocation()) { + return Flowable.just(beforeEvent); + } + + return Flowable.just(beforeEvent) + .concatWith(runMainAndAfter(invocationContext, runImplementation)); + }) + .switchIfEmpty( + Flowable.defer(() -> runMainAndAfter(invocationContext, runImplementation))); + }) + .compose( + Tracing.traceAgent( + otelParentContext, + "invoke_agent " + name(), + name(), + description(), + invocationContext)); + } + + private Flowable runMainAndAfter( + InvocationContext invocationContext, + Function> runImplementation) { + Flowable mainEvents = runImplementation.apply(invocationContext); + Flowable afterEvents = + callCallback( + afterCallbacksToFunctions(invocationContext.pluginManager(), afterAgentCallback), + invocationContext) + .flatMapPublisher(Flowable::just); + + return Flowable.concat(mainEvents, afterEvents); } /** @@ -383,13 +392,13 @@ private ImmutableList>> callbacksTo * * @param agentCallbacks Callback functions. * @param invocationContext Current invocation context. - * @return single emitting first event, or empty if none. + * @return Maybe emitting first event, or empty if none. */ - private Single> callCallback( + private Maybe callCallback( List>> agentCallbacks, InvocationContext invocationContext) { if (agentCallbacks.isEmpty()) { - return Single.just(Optional.empty()); + return Maybe.empty(); } CallbackContext callbackContext = @@ -398,27 +407,25 @@ private Single> callCallback( return Flowable.fromIterable(agentCallbacks) .concatMap( callback -> { - Maybe maybeContent = callback.apply(callbackContext); - - return maybeContent + return callback + .apply(callbackContext) .map( content -> { invocationContext.setEndInvocation(true); - return Optional.of( - Event.builder() - .id(Event.generateEventId()) - .invocationId(invocationContext.invocationId()) - .author(name()) - .branch(invocationContext.branch().orElse(null)) - .actions(callbackContext.eventActions()) - .content(content) - .build()); + return Event.builder() + .id(Event.generateEventId()) + .invocationId(invocationContext.invocationId()) + .author(name()) + .branch(invocationContext.branch().orElse(null)) + .actions(callbackContext.eventActions()) + .content(content) + .build(); }) .toFlowable(); }) .firstElement() .switchIfEmpty( - Single.defer( + Maybe.defer( () -> { if (callbackContext.state().hasDelta()) { Event.Builder eventBuilder = @@ -429,9 +436,9 @@ private Single> callCallback( .branch(invocationContext.branch().orElse(null)) .actions(callbackContext.eventActions()); - return Single.just(Optional.of(eventBuilder.build())); + return Maybe.just(eventBuilder.build()); } else { - return Single.just(Optional.empty()); + return Maybe.empty(); } })); } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index 6ed9ccaa3..fba7f10e0 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -92,7 +92,9 @@ public BaseLlmFlow( * events generated by them. */ protected Flowable preprocess( - InvocationContext context, AtomicReference llmRequestRef) { + InvocationContext context, + AtomicReference llmRequestRef, + Context otelParentContext) { LlmAgent agent = (LlmAgent) context.agent(); RequestProcessor toolsProcessor = @@ -104,7 +106,8 @@ protected Flowable preprocess( tool -> tool.processLlmRequest(builder, ToolContext.builder(ctx).build())) .andThen( Single.fromCallable( - () -> RequestProcessingResult.create(builder.build(), ImmutableList.of()))); + () -> RequestProcessingResult.create(builder.build(), ImmutableList.of()))) + .compose(Tracing.withContext(otelParentContext)); }; Iterable allProcessors = @@ -113,7 +116,9 @@ protected Flowable preprocess( return Flowable.fromIterable(allProcessors) .concatMap( processor -> - Single.defer(() -> processor.processRequest(context, llmRequestRef.get())) + processor + .processRequest(context, llmRequestRef.get()) + .compose(Tracing.withContext(otelParentContext)) .doOnSuccess(result -> llmRequestRef.set(result.updatedRequest())) .flattenAsFlowable( result -> result.events() != null ? result.events() : ImmutableList.of())); @@ -129,13 +134,32 @@ protected Flowable postprocess( Event baseEventForLlmResponse, LlmRequest llmRequest, LlmResponse llmResponse) { + return postprocess( + context, baseEventForLlmResponse, llmRequest, llmResponse, Context.current()); + } + + /** + * Post-processes the LLM response after receiving it from the LLM. Executes all registered {@link + * ResponseProcessor} instances. Emits events for the model response and any subsequent function + * calls. + */ + private Flowable postprocess( + InvocationContext context, + Event baseEventForLlmResponse, + LlmRequest llmRequest, + LlmResponse llmResponse, + Context otelParentContext) { List> eventIterables = new ArrayList<>(); Single currentLlmResponse = Single.just(llmResponse); for (ResponseProcessor processor : responseProcessors) { currentLlmResponse = currentLlmResponse - .flatMap(response -> processor.processResponse(context, response)) + .flatMap( + response -> + processor + .processResponse(context, response) + .compose(Tracing.withContext(otelParentContext))) .doOnSuccess( result -> { if (result.events() != null) { @@ -144,15 +168,16 @@ protected Flowable postprocess( }) .map(ResponseProcessingResult::updatedResponse); } - Context parentContext = Context.current(); return currentLlmResponse.flatMapPublisher( - updatedResponse -> { - try (Scope scope = parentContext.makeCurrent()) { - return buildPostprocessingEvents( - updatedResponse, eventIterables, context, baseEventForLlmResponse, llmRequest); - } - }); + updatedResponse -> + buildPostprocessingEvents( + updatedResponse, + eventIterables, + context, + baseEventForLlmResponse, + llmRequest, + otelParentContext)); } /** @@ -164,84 +189,100 @@ protected Flowable postprocess( * callbacks. Callbacks should not rely on its ID if they create their own separate events. */ private Flowable callLlm( - InvocationContext context, LlmRequest llmRequest, Event eventForCallbackUsage) { + InvocationContext context, + LlmRequest llmRequest, + Event eventForCallbackUsage, + Context otelParentContext) { LlmAgent agent = (LlmAgent) context.agent(); LlmRequest.Builder llmRequestBuilder = llmRequest.toBuilder(); - return handleBeforeModelCallback(context, llmRequestBuilder, eventForCallbackUsage) - .flatMapPublisher( - beforeResponse -> { - if (beforeResponse.isPresent()) { - return Flowable.just(beforeResponse.get()); - } - BaseLlm llm = - agent.resolvedModel().model().isPresent() - ? agent.resolvedModel().model().get() - : LlmRegistry.getLlm(agent.resolvedModel().modelName().get()); - return llm.generateContent( - llmRequestBuilder.build(), - context.runConfig().streamingMode() == StreamingMode.SSE) - .onErrorResumeNext( - exception -> - handleOnModelErrorCallback( - context, llmRequestBuilder, eventForCallbackUsage, exception) - .switchIfEmpty(Single.error(exception)) - .toFlowable()) - .doOnNext( - llmResp -> - Tracing.traceCallLlm( - context, - eventForCallbackUsage.id(), - llmRequestBuilder.build(), - llmResp)) - .doOnError( - error -> { - Span span = Span.current(); - span.setStatus(StatusCode.ERROR, error.getMessage()); - span.recordException(error); - }) - .compose(Tracing.trace("call_llm")) - .concatMap( - llmResp -> - handleAfterModelCallback(context, llmResp, eventForCallbackUsage) - .toFlowable()); - }); + return handleBeforeModelCallback( + context, llmRequestBuilder, eventForCallbackUsage, otelParentContext) + .flatMapPublisher(Flowable::just) + .switchIfEmpty( + Flowable.defer( + () -> { + BaseLlm llm = + agent.resolvedModel().model().isPresent() + ? agent.resolvedModel().model().get() + : LlmRegistry.getLlm(agent.resolvedModel().modelName().get()); + return llm.generateContent( + llmRequestBuilder.build(), + context.runConfig().streamingMode() == StreamingMode.SSE) + .onErrorResumeNext( + exception -> + handleOnModelErrorCallback( + context, + llmRequestBuilder, + eventForCallbackUsage, + exception, + otelParentContext) + .switchIfEmpty(Single.error(exception)) + .toFlowable()) + .compose( + Tracing.trace("call_llm", otelParentContext) + .onSuccess( + (span, llmResp) -> + Tracing.traceCallLlm( + span, + context, + eventForCallbackUsage.id(), + llmRequestBuilder.build(), + llmResp))) + .doOnError( + error -> { + Span span = Span.current(); + span.setStatus(StatusCode.ERROR, error.getMessage()); + span.recordException(error); + }) + .concatMap( + llmResp -> + handleAfterModelCallback( + context, llmResp, eventForCallbackUsage, otelParentContext) + .toFlowable()); + })); } /** * Invokes {@link BeforeModelCallback}s. If any returns a response, it's used instead of calling * the LLM. * - * @return A {@link Single} with the callback result or {@link Optional#empty()}. + * @return A {@link Maybe} with the callback result. */ - private Single> handleBeforeModelCallback( - InvocationContext context, LlmRequest.Builder llmRequestBuilder, Event modelResponseEvent) { - Event callbackEvent = modelResponseEvent.toBuilder().build(); - CallbackContext callbackContext = - new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); + private Maybe handleBeforeModelCallback( + InvocationContext context, + LlmRequest.Builder llmRequestBuilder, + Event modelResponseEvent, + Context otelParentContext) { + try (Scope scope = otelParentContext.makeCurrent()) { + Event callbackEvent = modelResponseEvent.toBuilder().build(); + CallbackContext callbackContext = + new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); - Maybe pluginResult = - context.pluginManager().beforeModelCallback(callbackContext, llmRequestBuilder); + Maybe pluginResult = + context.pluginManager().beforeModelCallback(callbackContext, llmRequestBuilder); - LlmAgent agent = (LlmAgent) context.agent(); + LlmAgent agent = (LlmAgent) context.agent(); - List callbacks = agent.canonicalBeforeModelCallbacks(); - if (callbacks.isEmpty()) { - return pluginResult.map(Optional::of).defaultIfEmpty(Optional.empty()); - } + List callbacks = agent.canonicalBeforeModelCallbacks(); + if (callbacks.isEmpty()) { + return pluginResult; + } - Maybe callbackResult = - Maybe.defer( - () -> - Flowable.fromIterable(callbacks) - .concatMapMaybe(callback -> callback.call(callbackContext, llmRequestBuilder)) - .firstElement()); - - return pluginResult - .switchIfEmpty(callbackResult) - .map(Optional::of) - .defaultIfEmpty(Optional.empty()); + Maybe callbackResult = + Maybe.defer( + () -> + Flowable.fromIterable(callbacks) + .concatMapMaybe( + callback -> + callback + .call(callbackContext, llmRequestBuilder) + .compose(Tracing.withContext(otelParentContext))) + .firstElement()); + + return pluginResult.switchIfEmpty(callbackResult); + } } /** @@ -254,32 +295,41 @@ private Maybe handleOnModelErrorCallback( InvocationContext context, LlmRequest.Builder llmRequestBuilder, Event modelResponseEvent, - Throwable throwable) { - Event callbackEvent = modelResponseEvent.toBuilder().build(); - CallbackContext callbackContext = - new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); - Exception ex = throwable instanceof Exception e ? e : new Exception(throwable); - - Maybe pluginResult = - context.pluginManager().onModelErrorCallback(callbackContext, llmRequestBuilder, throwable); - - LlmAgent agent = (LlmAgent) context.agent(); - List callbacks = agent.canonicalOnModelErrorCallbacks(); + Throwable throwable, + Context otelParentContext) { + + try (Scope scope = otelParentContext.makeCurrent()) { + Event callbackEvent = modelResponseEvent.toBuilder().build(); + CallbackContext callbackContext = + new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); + Exception ex = throwable instanceof Exception e ? e : new Exception(throwable); + Maybe pluginResult = + context + .pluginManager() + .onModelErrorCallback(callbackContext, llmRequestBuilder, throwable); + + LlmAgent agent = (LlmAgent) context.agent(); + List callbacks = agent.canonicalOnModelErrorCallbacks(); + + if (callbacks.isEmpty()) { + return pluginResult; + } - if (callbacks.isEmpty()) { - return pluginResult; + Maybe callbackResult = + Maybe.defer( + () -> { + LlmRequest llmRequest = llmRequestBuilder.build(); + return Flowable.fromIterable(callbacks) + .concatMapMaybe( + callback -> + callback + .call(callbackContext, llmRequest, ex) + .compose(Tracing.withContext(otelParentContext))) + .firstElement(); + }); + + return pluginResult.switchIfEmpty(callbackResult); } - - Maybe callbackResult = - Maybe.defer( - () -> { - LlmRequest llmRequest = llmRequestBuilder.build(); - return Flowable.fromIterable(callbacks) - .concatMapMaybe(callback -> callback.call(callbackContext, llmRequest, ex)) - .firstElement(); - }); - - return pluginResult.switchIfEmpty(callbackResult); } /** @@ -289,29 +339,39 @@ private Maybe handleOnModelErrorCallback( * @return A {@link Single} with the final {@link LlmResponse}. */ private Single handleAfterModelCallback( - InvocationContext context, LlmResponse llmResponse, Event modelResponseEvent) { - Event callbackEvent = modelResponseEvent.toBuilder().build(); - CallbackContext callbackContext = - new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); + InvocationContext context, + LlmResponse llmResponse, + Event modelResponseEvent, + Context otelParentContext) { - Maybe pluginResult = - context.pluginManager().afterModelCallback(callbackContext, llmResponse); + try (Scope scope = otelParentContext.makeCurrent()) { + Event callbackEvent = modelResponseEvent.toBuilder().build(); + CallbackContext callbackContext = + new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); - LlmAgent agent = (LlmAgent) context.agent(); - List callbacks = agent.canonicalAfterModelCallbacks(); + Maybe pluginResult = + context.pluginManager().afterModelCallback(callbackContext, llmResponse); - if (callbacks.isEmpty()) { - return pluginResult.defaultIfEmpty(llmResponse); - } + LlmAgent agent = (LlmAgent) context.agent(); + List callbacks = agent.canonicalAfterModelCallbacks(); - Maybe callbackResult = - Maybe.defer( - () -> - Flowable.fromIterable(callbacks) - .concatMapMaybe(callback -> callback.call(callbackContext, llmResponse)) - .firstElement()); + if (callbacks.isEmpty()) { + return pluginResult.defaultIfEmpty(llmResponse); + } - return pluginResult.switchIfEmpty(callbackResult).defaultIfEmpty(llmResponse); + Maybe callbackResult = + Maybe.defer( + () -> + Flowable.fromIterable(callbacks) + .concatMapMaybe( + callback -> + callback + .call(callbackContext, llmResponse) + .compose(Tracing.withContext(otelParentContext))) + .firstElement()); + + return pluginResult.switchIfEmpty(callbackResult).defaultIfEmpty(llmResponse); + } } /** @@ -323,13 +383,12 @@ private Single handleAfterModelCallback( * @throws LlmCallsLimitExceededException if the agent exceeds allowed LLM invocations. * @throws IllegalStateException if a transfer agent is specified but not found. */ - private Flowable runOneStep(InvocationContext context) { + private Flowable runOneStep(InvocationContext context, Context otelParentContext) { AtomicReference llmRequestRef = new AtomicReference<>(LlmRequest.builder().build()); return Flowable.defer( () -> { - Context currentContext = Context.current(); - return preprocess(context, llmRequestRef) + return preprocess(context, llmRequestRef, otelParentContext) .concatWith( Flowable.defer( () -> { @@ -355,15 +414,19 @@ private Flowable runOneStep(InvocationContext context) { .build(); mutableEventTemplate.setTimestamp(0L); - return callLlm(context, llmRequestAfterPreprocess, mutableEventTemplate) + return callLlm( + context, + llmRequestAfterPreprocess, + mutableEventTemplate, + otelParentContext) .concatMap( - llmResponse -> { - try (Scope postScope = currentContext.makeCurrent()) { - return postprocess( + llmResponse -> + postprocess( context, mutableEventTemplate, llmRequestAfterPreprocess, - llmResponse) + llmResponse, + otelParentContext) .doFinally( () -> { String oldId = mutableEventTemplate.id(); @@ -371,9 +434,7 @@ private Flowable runOneStep(InvocationContext context) { logger.debug( "Resetting event ID from {} to {}", oldId, newId); mutableEventTemplate.setId(newId); - }); - } - }) + })) .concatMap( event -> { Flowable postProcessedEvents = Flowable.just(event); @@ -407,11 +468,12 @@ private Flowable runOneStep(InvocationContext context) { */ @Override public Flowable run(InvocationContext invocationContext) { - return run(invocationContext, 0); + return run(invocationContext, Context.current(), 0); } - private Flowable run(InvocationContext invocationContext, int stepsCompleted) { - Flowable currentStepEvents = runOneStep(invocationContext).cache(); + private Flowable run( + InvocationContext invocationContext, Context otelParentContext, int stepsCompleted) { + Flowable currentStepEvents = runOneStep(invocationContext, otelParentContext).cache(); if (stepsCompleted + 1 >= maxSteps) { logger.debug("Ending flow execution because max steps reached."); return currentStepEvents; @@ -431,7 +493,7 @@ private Flowable run(InvocationContext invocationContext, int stepsComple return Flowable.empty(); } else { logger.debug("Continuing to next step of the flow."); - return run(invocationContext, stepsCompleted + 1); + return run(invocationContext, otelParentContext, stepsCompleted + 1); } })); } @@ -446,8 +508,10 @@ private Flowable run(InvocationContext invocationContext, int stepsComple */ @Override public Flowable runLive(InvocationContext invocationContext) { + Context otelParentContext = Context.current(); AtomicReference llmRequestRef = new AtomicReference<>(LlmRequest.builder().build()); - Flowable preprocessEvents = preprocess(invocationContext, llmRequestRef); + Flowable preprocessEvents = + preprocess(invocationContext, llmRequestRef, otelParentContext); return preprocessEvents.concatWith( Flowable.defer( @@ -469,6 +533,7 @@ public Flowable runLive(InvocationContext invocationContext) { ? Completable.complete() : connection .sendHistory(llmRequestAfterPreprocess.contents()) + .compose(Tracing.trace("send_data", otelParentContext)) .doOnComplete( () -> Tracing.traceSendData( @@ -484,8 +549,7 @@ public Flowable runLive(InvocationContext invocationContext) { invocationContext, eventIdForSendData, llmRequestAfterPreprocess.contents()); - }) - .compose(Tracing.trace("send_data")); + }); Flowable liveRequests = invocationContext @@ -542,13 +606,16 @@ public void onError(Throwable e) { .receive() .flatMap( llmResponse -> { - Event baseEventForThisLlmResponse = - liveEventBuilderTemplate.id(Event.generateEventId()).build(); - return postprocess( - invocationContext, - baseEventForThisLlmResponse, - llmRequestAfterPreprocess, - llmResponse); + try (Scope scope = otelParentContext.makeCurrent()) { + Event baseEventForLlmResponse = + liveEventBuilderTemplate.id(Event.generateEventId()).build(); + return postprocess( + invocationContext, + baseEventForLlmResponse, + llmRequestAfterPreprocess, + llmResponse, + otelParentContext); + } }) .flatMap( event -> { @@ -600,7 +667,8 @@ private Flowable buildPostprocessingEvents( List> eventIterables, InvocationContext context, Event baseEventForLlmResponse, - LlmRequest llmRequest) { + LlmRequest llmRequest, + Context otelParentContext) { Flowable processorEvents = Flowable.fromIterable(Iterables.concat(eventIterables)); if (updatedResponse.content().isEmpty() && updatedResponse.errorCode().isEmpty() @@ -616,23 +684,27 @@ private Flowable buildPostprocessingEvents( return processorEvents.concatWith(Flowable.just(modelResponseEvent)); } - Maybe maybeFunctionResponseEvent = - context.runConfig().streamingMode() == StreamingMode.BIDI - ? Functions.handleFunctionCallsLive(context, modelResponseEvent, llmRequest.tools()) - : Functions.handleFunctionCalls(context, modelResponseEvent, llmRequest.tools()); - - Flowable functionEvents = - maybeFunctionResponseEvent.flatMapPublisher( - functionResponseEvent -> { - Optional toolConfirmationEvent = - Functions.generateRequestConfirmationEvent( - context, modelResponseEvent, functionResponseEvent); - return toolConfirmationEvent.isPresent() - ? Flowable.just(toolConfirmationEvent.get(), functionResponseEvent) - : Flowable.just(functionResponseEvent); - }); - - return processorEvents.concatWith(Flowable.just(modelResponseEvent)).concatWith(functionEvents); + try (Scope scope = otelParentContext.makeCurrent()) { + Maybe maybeFunctionResponseEvent = + context.runConfig().streamingMode() == StreamingMode.BIDI + ? Functions.handleFunctionCallsLive(context, modelResponseEvent, llmRequest.tools()) + : Functions.handleFunctionCalls(context, modelResponseEvent, llmRequest.tools()); + + Flowable functionEvents = + maybeFunctionResponseEvent.flatMapPublisher( + functionResponseEvent -> { + Optional toolConfirmationEvent = + Functions.generateRequestConfirmationEvent( + context, modelResponseEvent, functionResponseEvent); + return toolConfirmationEvent.isPresent() + ? Flowable.just(toolConfirmationEvent.get(), functionResponseEvent) + : Flowable.just(functionResponseEvent); + }); + + return processorEvents + .concatWith(Flowable.just(modelResponseEvent)) + .concatWith(functionEvents); + } } private Event buildModelResponseEvent( diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index ecc2bb412..f8b9e180d 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -257,7 +257,8 @@ private static Function> getFunctionCallMapper( functionCall.args().map(HashMap::new).orElse(new HashMap<>()); Maybe> maybeFunctionResult = - maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext) + maybeInvokeBeforeToolCall( + invocationContext, tool, functionArgs, toolContext, parentContext) .switchIfEmpty( Maybe.defer( () -> { @@ -395,48 +396,49 @@ private static Maybe postProcessFunctionResult( .defaultIfEmpty(Optional.empty()) .onErrorResumeNext( t -> { - Maybe> errorCallbackResult = - handleOnToolErrorCallback(invocationContext, tool, functionArgs, toolContext, t); - Maybe>> mappedResult; - if (isLive) { - // In live mode, handle null results from the error callback gracefully. - mappedResult = errorCallbackResult.map(Optional::ofNullable); - } else { - // In non-live mode, a null result from the error callback will cause an NPE - // when wrapped with Optional.of(), potentially matching prior behavior. - mappedResult = errorCallbackResult.map(Optional::of); + try (Scope scope = parentContext.makeCurrent()) { + Maybe> errorCallbackResult = + handleOnToolErrorCallback( + invocationContext, tool, functionArgs, toolContext, t, parentContext); + Maybe>> mappedResult; + if (isLive) { + // In live mode, handle null results from the error callback gracefully. + mappedResult = errorCallbackResult.map(Optional::ofNullable); + } else { + // In non-live mode, a null result from the error callback will cause an NPE + // when wrapped with Optional.of(), potentially matching prior behavior. + mappedResult = errorCallbackResult.map(Optional::of); + } + return mappedResult.switchIfEmpty(Single.error(t)); } - return mappedResult.switchIfEmpty(Single.error(t)); }) .flatMapMaybe( optionalInitialResult -> { - try (Scope scope = parentContext.makeCurrent()) { - Map initialFunctionResult = optionalInitialResult.orElse(null); - - return maybeInvokeAfterToolCall( - invocationContext, tool, functionArgs, toolContext, initialFunctionResult) - .map(Optional::of) - .defaultIfEmpty(Optional.ofNullable(initialFunctionResult)) - .flatMapMaybe( - finalOptionalResult -> { - Map finalFunctionResult = - finalOptionalResult.orElse(null); - if (tool.longRunning() && finalFunctionResult == null) { - return Maybe.empty(); - } - return Maybe.fromCallable( - () -> - buildResponseEvent( - tool, - finalFunctionResult, - toolContext, - invocationContext)) - .compose( - Tracing.trace( - "tool_response [" + tool.name() + "]", parentContext)) - .doOnSuccess(event -> Tracing.traceToolResponse(event.id(), event)); - }); - } + Map initialFunctionResult = optionalInitialResult.orElse(null); + + return maybeInvokeAfterToolCall( + invocationContext, + tool, + functionArgs, + toolContext, + initialFunctionResult, + parentContext) + .map(Optional::of) + .defaultIfEmpty(Optional.ofNullable(initialFunctionResult)) + .flatMapMaybe( + finalOptionalResult -> { + Map finalFunctionResult = finalOptionalResult.orElse(null); + if (tool.longRunning() && finalFunctionResult == null) { + return Maybe.empty(); + } + return Maybe.fromCallable( + () -> + buildResponseEvent( + tool, finalFunctionResult, toolContext, invocationContext)) + .compose( + Tracing.trace("tool_response [" + tool.name() + "]", parentContext)) + .doOnSuccess(event -> Tracing.traceToolResponse(event.id(), event)); + }); }); } @@ -479,28 +481,32 @@ private static Maybe> maybeInvokeBeforeToolCall( InvocationContext invocationContext, BaseTool tool, Map functionArgs, - ToolContext toolContext) { - if (invocationContext.agent() instanceof LlmAgent) { - LlmAgent agent = (LlmAgent) invocationContext.agent(); + ToolContext toolContext, + Context parentContext) { + if (invocationContext.agent() instanceof LlmAgent agent) { + try (Scope scope = parentContext.makeCurrent()) { - Maybe> pluginResult = - invocationContext.pluginManager().beforeToolCallback(tool, functionArgs, toolContext); + Maybe> pluginResult = + invocationContext.pluginManager().beforeToolCallback(tool, functionArgs, toolContext); - List callbacks = agent.canonicalBeforeToolCallbacks(); - if (callbacks.isEmpty()) { - return pluginResult; - } - - Maybe> callbackResult = - Maybe.defer( - () -> - Flowable.fromIterable(callbacks) - .concatMapMaybe( - callback -> - callback.call(invocationContext, tool, functionArgs, toolContext)) - .firstElement()); + List callbacks = agent.canonicalBeforeToolCallbacks(); + if (callbacks.isEmpty()) { + return pluginResult; + } - return pluginResult.switchIfEmpty(callbackResult); + Maybe> callbackResult = + Maybe.defer( + () -> + Flowable.fromIterable(callbacks) + .concatMapMaybe( + callback -> + callback + .call(invocationContext, tool, functionArgs, toolContext) + .compose(Tracing.withContext(parentContext))) + .firstElement()); + + return pluginResult.switchIfEmpty(callbackResult); + } } return Maybe.empty(); } @@ -516,34 +522,39 @@ private static Maybe> handleOnToolErrorCallback( BaseTool tool, Map functionArgs, ToolContext toolContext, - Throwable throwable) { + Throwable throwable, + Context parentContext) { Exception ex = throwable instanceof Exception exception ? exception : new Exception(throwable); - Maybe> pluginResult = - invocationContext - .pluginManager() - .onToolErrorCallback(tool, functionArgs, toolContext, throwable); - - if (invocationContext.agent() instanceof LlmAgent) { - LlmAgent agent = (LlmAgent) invocationContext.agent(); + try (Scope scope = parentContext.makeCurrent()) { + Maybe> pluginResult = + invocationContext + .pluginManager() + .onToolErrorCallback(tool, functionArgs, toolContext, throwable); - List callbacks = agent.canonicalOnToolErrorCallbacks(); - if (callbacks.isEmpty()) { - return pluginResult; - } + if (invocationContext.agent() instanceof LlmAgent) { + LlmAgent agent = (LlmAgent) invocationContext.agent(); - Maybe> callbackResult = - Maybe.defer( - () -> - Flowable.fromIterable(callbacks) - .concatMapMaybe( - callback -> - callback.call(invocationContext, tool, functionArgs, toolContext, ex)) - .firstElement()); + List callbacks = agent.canonicalOnToolErrorCallbacks(); + if (callbacks.isEmpty()) { + return pluginResult; + } - return pluginResult.switchIfEmpty(callbackResult); + Maybe> callbackResult = + Maybe.defer( + () -> + Flowable.fromIterable(callbacks) + .concatMapMaybe( + callback -> + callback + .call(invocationContext, tool, functionArgs, toolContext, ex) + .compose(Tracing.withContext(parentContext))) + .firstElement()); + + return pluginResult.switchIfEmpty(callbackResult); + } + return pluginResult; } - return pluginResult; } private static Maybe> maybeInvokeAfterToolCall( @@ -551,35 +562,39 @@ private static Maybe> maybeInvokeAfterToolCall( BaseTool tool, Map functionArgs, ToolContext toolContext, - Map functionResult) { - if (invocationContext.agent() instanceof LlmAgent) { - LlmAgent agent = (LlmAgent) invocationContext.agent(); + Map functionResult, + Context parentContext) { + if (invocationContext.agent() instanceof LlmAgent agent) { - Maybe> pluginResult = - invocationContext - .pluginManager() - .afterToolCallback(tool, functionArgs, toolContext, functionResult); + try (Scope scope = parentContext.makeCurrent()) { + Maybe> pluginResult = + invocationContext + .pluginManager() + .afterToolCallback(tool, functionArgs, toolContext, functionResult); - List callbacks = agent.canonicalAfterToolCallbacks(); - if (callbacks.isEmpty()) { - return pluginResult; - } + List callbacks = agent.canonicalAfterToolCallbacks(); + if (callbacks.isEmpty()) { + return pluginResult; + } - Maybe> callbackResult = - Maybe.defer( - () -> - Flowable.fromIterable(callbacks) - .concatMapMaybe( - callback -> - callback.call( - invocationContext, - tool, - functionArgs, - toolContext, - functionResult)) - .firstElement()); - - return pluginResult.switchIfEmpty(callbackResult); + Maybe> callbackResult = + Maybe.defer( + () -> + Flowable.fromIterable(callbacks) + .concatMapMaybe( + callback -> + callback + .call( + invocationContext, + tool, + functionArgs, + toolContext, + functionResult) + .compose(Tracing.withContext(parentContext))) + .firstElement()); + + return pluginResult.switchIfEmpty(callbackResult); + } } return Maybe.empty(); } diff --git a/core/src/main/java/com/google/adk/plugins/PluginManager.java b/core/src/main/java/com/google/adk/plugins/PluginManager.java index 56dea936a..4d90ca7b5 100644 --- a/core/src/main/java/com/google/adk/plugins/PluginManager.java +++ b/core/src/main/java/com/google/adk/plugins/PluginManager.java @@ -21,11 +21,13 @@ import com.google.adk.events.Event; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; +import com.google.adk.telemetry.Tracing; import com.google.adk.tools.BaseTool; import com.google.adk.tools.ToolContext; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; +import io.opentelemetry.context.Context; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; @@ -129,6 +131,7 @@ public Completable runAfterRunCallback(InvocationContext invocationContext) { @Override public Completable afterRunCallback(InvocationContext invocationContext) { + Context capturedContext = Context.current(); return Flowable.fromIterable(plugins) .concatMapCompletable( plugin -> @@ -139,11 +142,13 @@ public Completable afterRunCallback(InvocationContext invocationContext) { logger.error( "[{}] Error during callback 'afterRunCallback'", plugin.getName(), - e))); + e)) + .compose(Tracing.withContext(capturedContext))); } @Override public Completable close() { + Context capturedContext = Context.current(); return Flowable.fromIterable(plugins) .concatMapCompletableDelayError( plugin -> @@ -151,8 +156,8 @@ public Completable close() { .close() .doOnError( e -> - logger.error( - "[{}] Error during callback 'close'", plugin.getName(), e))); + logger.error("[{}] Error during callback 'close'", plugin.getName(), e)) + .compose(Tracing.withContext(capturedContext))); } public Maybe runOnEventCallback(InvocationContext invocationContext, Event event) { @@ -275,7 +280,7 @@ public Maybe> onToolErrorCallback( */ private Maybe runMaybeCallbacks( Function> callbackExecutor, String callbackName) { - + Context capturedContext = Context.current(); return Flowable.fromIterable(this.plugins) .concatMapMaybe( plugin -> @@ -294,7 +299,8 @@ private Maybe runMaybeCallbacks( "[{}] Error during callback '{}'", plugin.getName(), callbackName, - e))) + e)) + .compose(Tracing.withContext(capturedContext))) .firstElement(); } } diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 4371300fb..e35f5c33d 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -52,6 +52,7 @@ import com.google.genai.types.Part; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.StatusCode; +import io.opentelemetry.context.Context; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; @@ -375,20 +376,25 @@ public Flowable runAsync( Content newMessage, RunConfig runConfig, @Nullable Map stateDelta) { - Maybe maybeSession = - this.sessionService.getSession(appName, userId, sessionId, Optional.empty()); - return maybeSession - .switchIfEmpty( - Single.defer( - () -> { - if (runConfig.autoCreateSession()) { - return this.sessionService.createSession(appName, userId, null, sessionId); - } - return Single.error( - new IllegalArgumentException( - String.format("Session not found: %s for user %s", sessionId, userId))); - })) - .flatMapPublisher(session -> this.runAsyncImpl(session, newMessage, runConfig, stateDelta)); + return Flowable.defer( + () -> + this.sessionService + .getSession(appName, userId, sessionId, Optional.empty()) + .switchIfEmpty( + Single.defer( + () -> { + if (runConfig.autoCreateSession()) { + return this.sessionService.createSession( + appName, userId, (Map) null, sessionId); + } + return Single.error( + new IllegalArgumentException( + String.format( + "Session not found: %s for user %s", sessionId, userId))); + })) + .flatMapPublisher( + session -> this.runAsyncImpl(session, newMessage, runConfig, stateDelta))) + .compose(Tracing.trace("invocation")); } /** See {@link #runAsync(String, String, Content, RunConfig, Map)}. */ @@ -441,7 +447,8 @@ public Flowable runAsync( Content newMessage, RunConfig runConfig, @Nullable Map stateDelta) { - return runAsyncImpl(session, newMessage, runConfig, stateDelta); + return runAsyncImpl(session, newMessage, runConfig, stateDelta) + .compose(Tracing.trace("invocation")); } /** @@ -460,6 +467,7 @@ protected Flowable runAsyncImpl( @Nullable Map stateDelta) { return Flowable.defer( () -> { + Context capturedContext = Context.current(); BaseAgent rootAgent = this.agent; String invocationId = InvocationContext.newInvocationContextId(); @@ -473,6 +481,7 @@ protected Flowable runAsyncImpl( return this.pluginManager .onUserMessageCallback(initialContext, newMessage) + .compose(Tracing.withContext(capturedContext)) .defaultIfEmpty(newMessage) .flatMap( content -> @@ -484,6 +493,7 @@ protected Flowable runAsyncImpl( runConfig.saveInputBlobsAsArtifacts(), stateDelta) : Single.just(null)) + .compose(Tracing.withContext(capturedContext)) .flatMapPublisher( event -> { if (event == null) { @@ -494,15 +504,17 @@ protected Flowable runAsyncImpl( return this.sessionService .getSession( session.appName(), session.userId(), session.id(), Optional.empty()) + .compose(Tracing.withContext(capturedContext)) .flatMapPublisher( updatedSession -> runAgentWithFreshSession( - session, - updatedSession, - event, - invocationId, - runConfig, - rootAgent)); + session, + updatedSession, + event, + invocationId, + runConfig, + rootAgent) + .compose(Tracing.withContext(capturedContext))); }); }) .doOnError( @@ -510,8 +522,7 @@ protected Flowable runAsyncImpl( Span span = Span.current(); span.setStatus(StatusCode.ERROR, "Error in runAsync Flowable execution"); span.recordException(throwable); - }) - .compose(Tracing.trace("invocation")); + }); } private Flowable runAgentWithFreshSession( @@ -568,7 +579,7 @@ private Flowable runAgentWithFreshSession( .toFlowable() .switchIfEmpty(agentEvents) .concatWith( - Completable.defer(() -> pluginManager.runAfterRunCallback(contextWithUpdatedSession))) + Completable.defer(() -> pluginManager.afterRunCallback(contextWithUpdatedSession))) .concatWith(Completable.defer(() -> compactEvents(updatedSession))); } @@ -641,39 +652,51 @@ private InvocationContext.Builder newInvocationContextBuilder(Session session) { */ public Flowable runLive( Session session, LiveRequestQueue liveRequestQueue, RunConfig runConfig) { + return runLiveImpl(session, liveRequestQueue, runConfig).compose(Tracing.trace("invocation")); + } + + /** + * Runs the agent in live mode, appending generated events to the session. + * + * @return stream of events from the agent. + */ + protected Flowable runLiveImpl( + Session session, @Nullable LiveRequestQueue liveRequestQueue, RunConfig runConfig) { return Flowable.defer( - () -> { - InvocationContext invocationContext = - newInvocationContextForLive(session, liveRequestQueue, runConfig); - - Single invocationContextSingle; - if (invocationContext.agent() instanceof LlmAgent agent) { - invocationContextSingle = - agent - .tools() - .map( - tools -> { - this.addActiveStreamingTools(invocationContext, tools); - return invocationContext; - }); - } else { - invocationContextSingle = Single.just(invocationContext); - } - return invocationContextSingle - .flatMapPublisher( - updatedInvocationContext -> - updatedInvocationContext - .agent() - .runLive(updatedInvocationContext) - .doOnNext(event -> this.sessionService.appendEvent(session, event))) - .doOnError( - throwable -> { - Span span = Span.current(); - span.setStatus(StatusCode.ERROR, "Error in runLive Flowable execution"); - span.recordException(throwable); - }); - }) - .compose(Tracing.trace("invocation")); + () -> { + Context capturedContext = Context.current(); + InvocationContext invocationContext = + newInvocationContextForLive(session, liveRequestQueue, runConfig); + + Single invocationContextSingle; + if (invocationContext.agent() instanceof LlmAgent agent) { + invocationContextSingle = + agent + .tools() + .map( + tools -> { + this.addActiveStreamingTools(invocationContext, tools); + return invocationContext; + }); + } else { + invocationContextSingle = Single.just(invocationContext); + } + return invocationContextSingle + .compose(Tracing.withContext(capturedContext)) + .flatMapPublisher( + updatedInvocationContext -> + updatedInvocationContext + .agent() + .runLive(updatedInvocationContext) + .compose(Tracing.withContext(capturedContext)) + .doOnNext(event -> this.sessionService.appendEvent(session, event))) + .doOnError( + throwable -> { + Span span = Span.current(); + span.setStatus(StatusCode.ERROR, "Error in runLive Flowable execution"); + span.recordException(throwable); + }); + }); } /** @@ -684,19 +707,25 @@ public Flowable runLive( */ public Flowable runLive( String userId, String sessionId, LiveRequestQueue liveRequestQueue, RunConfig runConfig) { - return this.sessionService - .getSession(appName, userId, sessionId, Optional.empty()) - .switchIfEmpty( - Single.defer( - () -> { - if (runConfig.autoCreateSession()) { - return this.sessionService.createSession(appName, userId, null, sessionId); - } - return Single.error( - new IllegalArgumentException( - String.format("Session not found: %s for user %s", sessionId, userId))); - })) - .flatMapPublisher(session -> this.runLive(session, liveRequestQueue, runConfig)); + return Flowable.defer( + () -> + this.sessionService + .getSession(appName, userId, sessionId, Optional.empty()) + .switchIfEmpty( + Single.defer( + () -> { + if (runConfig.autoCreateSession()) { + return this.sessionService.createSession( + appName, userId, (Map) null, sessionId); + } + return Single.error( + new IllegalArgumentException( + String.format( + "Session not found: %s for user %s", sessionId, userId))); + })) + .flatMapPublisher( + session -> this.runLiveImpl(session, liveRequestQueue, runConfig))) + .compose(Tracing.trace("invocation")); } /** diff --git a/core/src/main/java/com/google/adk/telemetry/Tracing.java b/core/src/main/java/com/google/adk/telemetry/Tracing.java index 07a640c37..9fa68ee00 100644 --- a/core/src/main/java/com/google/adk/telemetry/Tracing.java +++ b/core/src/main/java/com/google/adk/telemetry/Tracing.java @@ -37,16 +37,20 @@ import io.opentelemetry.context.Context; import io.opentelemetry.context.Scope; import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.CompletableObserver; import io.reactivex.rxjava3.core.CompletableSource; import io.reactivex.rxjava3.core.CompletableTransformer; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.FlowableTransformer; import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.MaybeObserver; import io.reactivex.rxjava3.core.MaybeSource; import io.reactivex.rxjava3.core.MaybeTransformer; import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.core.SingleObserver; import io.reactivex.rxjava3.core.SingleSource; import io.reactivex.rxjava3.core.SingleTransformer; +import io.reactivex.rxjava3.disposables.Disposable; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -54,9 +58,12 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.function.Supplier; import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -136,6 +143,10 @@ private static Optional getValidCurrentSpan(String methodName) { return Optional.of(span); } + private static void traceWithSpan(String methodName, Consumer action) { + getValidCurrentSpan(methodName).ifPresent(action); + } + private static void setInvocationAttributes( Span span, InvocationContext invocationContext, String eventId) { span.setAttribute(ADK_INVOCATION_ID, invocationContext.invocationId()); @@ -206,16 +217,16 @@ public static void traceAgentInvocation( */ public static void traceToolCall( String toolName, String toolDescription, String toolType, Map args) { - getValidCurrentSpan("traceToolCall") - .ifPresent( - span -> { - setToolExecutionAttributes(span); - span.setAttribute(GEN_AI_TOOL_NAME, toolName); - span.setAttribute(GEN_AI_TOOL_DESCRIPTION, toolDescription); - span.setAttribute(GEN_AI_TOOL_TYPE, toolType); - - setJsonAttribute(span, ADK_TOOL_CALL_ARGS, args); - }); + traceWithSpan( + "traceToolCall", + span -> { + setToolExecutionAttributes(span); + span.setAttribute(GEN_AI_TOOL_NAME, toolName); + span.setAttribute(GEN_AI_TOOL_DESCRIPTION, toolDescription); + span.setAttribute(GEN_AI_TOOL_TYPE, toolType); + + setJsonAttribute(span, ADK_TOOL_CALL_ARGS, args); + }); } /** @@ -225,33 +236,32 @@ public static void traceToolCall( * @param functionResponseEvent The function response event. */ public static void traceToolResponse(String eventId, Event functionResponseEvent) { - getValidCurrentSpan("traceToolResponse") - .ifPresent( - span -> { - setToolExecutionAttributes(span); - span.setAttribute(ADK_EVENT_ID, eventId); - - FunctionResponse functionResponse = - functionResponseEvent.functionResponses().stream().findFirst().orElse(null); - - String toolCallId = ""; - Object toolResponse = ""; - if (functionResponse != null) { - toolCallId = functionResponse.id().orElse(toolCallId); - if (functionResponse.response().isPresent()) { - toolResponse = functionResponse.response().get(); - } - } - - span.setAttribute(GEN_AI_TOOL_CALL_ID, toolCallId); - - Object finalToolResponse = - (toolResponse instanceof Map) - ? toolResponse - : ImmutableMap.of("result", toolResponse); - - setJsonAttribute(span, ADK_TOOL_RESPONSE, finalToolResponse); - }); + traceWithSpan( + "traceToolResponse", + span -> { + setToolExecutionAttributes(span); + span.setAttribute(ADK_EVENT_ID, eventId); + + Optional functionResponse = + functionResponseEvent.functionResponses().stream().findFirst(); + + String toolCallId = + functionResponse.flatMap(FunctionResponse::id).orElse(""); + Object toolResponse = + functionResponse + .flatMap(FunctionResponse::response) + .map(Object.class::cast) + .orElse(""); + + span.setAttribute(GEN_AI_TOOL_CALL_ID, toolCallId); + + Object finalToolResponse = + (toolResponse instanceof Map) + ? toolResponse + : ImmutableMap.of("result", toolResponse); + + setJsonAttribute(span, ADK_TOOL_RESPONSE, finalToolResponse); + }); } /** @@ -296,58 +306,63 @@ public static void traceCallLlm( String eventId, LlmRequest llmRequest, LlmResponse llmResponse) { - getValidCurrentSpan("traceCallLlm") - .ifPresent( - span -> { - span.setAttribute(GEN_AI_SYSTEM, "gcp.vertex.agent"); - llmRequest - .model() - .ifPresent(modelName -> span.setAttribute(GEN_AI_REQUEST_MODEL, modelName)); - - setInvocationAttributes(span, invocationContext, eventId); + traceWithSpan( + "traceCallLlm", + span -> traceCallLlm(span, invocationContext, eventId, llmRequest, llmResponse)); + } - setJsonAttribute(span, ADK_LLM_REQUEST, buildLlmRequestForTrace(llmRequest)); - setJsonAttribute(span, ADK_LLM_RESPONSE, llmResponse); + /** + * Traces a call to the LLM. + * + * @param span The span to end when the stream completes + * @param invocationContext The invocation context. + * @param eventId The ID of the event associated with this LLM call/response. + * @param llmRequest The LLM request object. + * @param llmResponse The LLM response object. + */ + public static void traceCallLlm( + Span span, + InvocationContext invocationContext, + String eventId, + LlmRequest llmRequest, + LlmResponse llmResponse) { + span.setAttribute(GEN_AI_OPERATION_NAME, "call_llm"); + span.setAttribute(GEN_AI_SYSTEM, "gcp.vertex.agent"); + llmRequest.model().ifPresent(modelName -> span.setAttribute(GEN_AI_REQUEST_MODEL, modelName)); + + setInvocationAttributes(span, invocationContext, eventId); + + setJsonAttribute(span, ADK_LLM_REQUEST, buildLlmRequestForTrace(llmRequest)); + setJsonAttribute(span, ADK_LLM_RESPONSE, llmResponse); + + llmRequest + .config() + .flatMap(config -> config.topP()) + .ifPresent(topP -> span.setAttribute(GEN_AI_REQUEST_TOP_P, topP.doubleValue())); + llmRequest + .config() + .flatMap(config -> config.maxOutputTokens()) + .ifPresent( + maxTokens -> span.setAttribute(GEN_AI_REQUEST_MAX_TOKENS, maxTokens.longValue())); - llmRequest - .config() - .ifPresent( - config -> { - config - .topP() - .ifPresent( - topP -> - span.setAttribute(GEN_AI_REQUEST_TOP_P, topP.doubleValue())); - config - .maxOutputTokens() - .ifPresent( - maxTokens -> - span.setAttribute( - GEN_AI_REQUEST_MAX_TOKENS, maxTokens.longValue())); - }); - llmResponse - .usageMetadata() - .ifPresent( - usage -> { - usage - .promptTokenCount() - .ifPresent( - tokens -> - span.setAttribute(GEN_AI_USAGE_INPUT_TOKENS, (long) tokens)); - usage - .candidatesTokenCount() - .ifPresent( - tokens -> - span.setAttribute(GEN_AI_USAGE_OUTPUT_TOKENS, (long) tokens)); - }); - llmResponse - .finishReason() - .map(reason -> reason.knownEnum().name().toLowerCase(Locale.ROOT)) + llmResponse + .usageMetadata() + .ifPresent( + usage -> { + usage + .promptTokenCount() + .ifPresent(tokens -> span.setAttribute(GEN_AI_USAGE_INPUT_TOKENS, (long) tokens)); + usage + .candidatesTokenCount() .ifPresent( - reason -> - span.setAttribute( - GEN_AI_RESPONSE_FINISH_REASONS, ImmutableList.of(reason))); + tokens -> span.setAttribute(GEN_AI_USAGE_OUTPUT_TOKENS, (long) tokens)); }); + + llmResponse + .finishReason() + .map(reason -> reason.knownEnum().name().toLowerCase(Locale.ROOT)) + .ifPresent( + reason -> span.setAttribute(GEN_AI_RESPONSE_FINISH_REASONS, ImmutableList.of(reason))); } /** @@ -359,17 +374,18 @@ public static void traceCallLlm( */ public static void traceSendData( InvocationContext invocationContext, String eventId, List data) { - getValidCurrentSpan("traceSendData") - .ifPresent( - span -> { - setInvocationAttributes(span, invocationContext, eventId); - - ImmutableList safeData = - Optional.ofNullable(data).orElse(ImmutableList.of()).stream() - .filter(Objects::nonNull) - .collect(toImmutableList()); - setJsonAttribute(span, ADK_DATA, safeData); - }); + traceWithSpan( + "traceSendData", + span -> { + span.setAttribute(GEN_AI_OPERATION_NAME, "send_data"); + setInvocationAttributes(span, invocationContext, eventId); + + ImmutableList safeData = + Optional.ofNullable(data).orElse(ImmutableList.of()).stream() + .filter(Objects::nonNull) + .collect(toImmutableList()); + setJsonAttribute(span, ADK_DATA, safeData); + }); } /** @@ -405,14 +421,17 @@ public static Tracer getTracer() { @SuppressWarnings("MustBeClosedChecker") // Scope lifecycle managed by RxJava doFinally public static Flowable traceFlowable( Context spanContext, Span span, Supplier> flowableSupplier) { - Scope scope = spanContext.makeCurrent(); - return flowableSupplier - .get() - .doFinally( - () -> { - scope.close(); - span.end(); - }); + return Flowable.defer( + () -> { + Scope scope = spanContext.makeCurrent(); + return flowableSupplier + .get() + .doFinally( + () -> { + scope.close(); + span.end(); + }); + }); } /** @@ -450,15 +469,66 @@ public static TracerProvider trace(String spanName, Context parentContext * @return A TracerProvider configured for agent invocation. */ public static TracerProvider traceAgent( + Context parent, String spanName, String agentName, String agentDescription, InvocationContext invocationContext) { return new TracerProvider(spanName) + .setParent(parent) .configure( span -> traceAgentInvocation(span, agentName, agentDescription, invocationContext)); } + /** + * Returns a transformer that re-activates a given context for the duration of the stream's + * subscription. + * + * @param context The context to re-activate. + * @param The type of the stream. + * @return A transformer that re-activates the context. + */ + public static ContextTransformer withContext(Context context) { + return new ContextTransformer<>(context); + } + + /** + * A transformer that re-activates a given context for the duration of the stream's subscription. + * + * @param The type of the stream. + */ + public static final class ContextTransformer + implements FlowableTransformer, + SingleTransformer, + MaybeTransformer, + CompletableTransformer { + private final Context context; + + private ContextTransformer(Context context) { + this.context = context; + } + + @Override + public Publisher apply(Flowable upstream) { + return upstream.lift(subscriber -> TracingObserver.wrap(context, subscriber)); + } + + @Override + public SingleSource apply(Single upstream) { + return upstream.lift(observer -> TracingObserver.wrap(context, observer)); + } + + @Override + public MaybeSource apply(Maybe upstream) { + return upstream.lift(observer -> TracingObserver.wrap(context, observer)); + } + + @Override + public CompletableSource apply(Completable upstream) { + return upstream.lift(observer -> TracingObserver.wrap(context, observer)); + } + } + /** * A transformer that manages an OpenTelemetry span and scope for RxJava streams. * @@ -472,6 +542,7 @@ public static final class TracerProvider private final String spanName; private Context explicitParentContext; private final List> spanConfigurers = new ArrayList<>(); + private BiConsumer onSuccessConsumer; private TracerProvider(String spanName) { this.spanName = spanName; @@ -491,27 +562,38 @@ public TracerProvider setParent(Context parentContext) { return this; } + /** + * Registers a callback to be executed with the span and the result item when the stream emits a + * success value. + */ + @CanIgnoreReturnValue + public TracerProvider onSuccess(BiConsumer consumer) { + this.onSuccessConsumer = consumer; + return this; + } + private Context getParentContext() { return explicitParentContext != null ? explicitParentContext : Context.current(); } private final class TracingLifecycle { - private Span span; - private Scope scope; + private final Span span; + private final Context context; - @SuppressWarnings("MustBeClosedChecker") - void start() { - span = tracer.spanBuilder(spanName).setParent(getParentContext()).startSpan(); + TracingLifecycle() { + Context parentContext = getParentContext(); + span = tracer.spanBuilder(spanName).setParent(parentContext).startSpan(); spanConfigurers.forEach(c -> c.accept(span)); - scope = span.makeCurrent(); + context = parentContext.with(span); } void end() { - if (scope != null) { - scope.close(); - } - if (span != null) { - span.end(); + span.end(); + } + + void run(O observer, Consumer subscribeAction) { + try (Scope scope = context.makeCurrent()) { + subscribeAction.accept(observer); } } } @@ -521,7 +603,18 @@ public Publisher apply(Flowable upstream) { return Flowable.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); + return Flowable.fromPublisher( + observer -> + lifecycle.run( + observer, + o -> { + Flowable chain = upstream.compose(withContext(lifecycle.context)); + if (onSuccessConsumer != null) { + chain = + chain.doOnNext(t -> onSuccessConsumer.accept(lifecycle.span, t)); + } + chain.doFinally(lifecycle::end).subscribe(o); + })); }); } @@ -530,7 +623,18 @@ public SingleSource apply(Single upstream) { return Single.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); + return Single.wrap( + observer -> + lifecycle.run( + observer, + o -> { + Single chain = upstream.compose(withContext(lifecycle.context)); + if (onSuccessConsumer != null) { + chain = + chain.doOnSuccess(t -> onSuccessConsumer.accept(lifecycle.span, t)); + } + chain.doFinally(lifecycle::end).subscribe(o); + })); }); } @@ -539,7 +643,18 @@ public MaybeSource apply(Maybe upstream) { return Maybe.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); + return Maybe.wrap( + observer -> + lifecycle.run( + observer, + o -> { + Maybe chain = upstream.compose(withContext(lifecycle.context)); + if (onSuccessConsumer != null) { + chain = + chain.doOnSuccess(t -> onSuccessConsumer.accept(lifecycle.span, t)); + } + chain.doFinally(lifecycle::end).subscribe(o); + })); }); } @@ -548,7 +663,142 @@ public CompletableSource apply(Completable upstream) { return Completable.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); + return Completable.wrap( + observer -> + lifecycle.run( + observer, + o -> { + Completable chain = upstream.compose(withContext(lifecycle.context)); + // Completable does not emit items, so onSuccessConsumer is not + // applicable. + chain.doFinally(lifecycle::end).subscribe(o); + })); + }); + } + } + + /** + * An observer that wraps another observer and ensures that the OpenTelemetry context is active + * during all callback methods. + * + * @param The type of the items emitted by the stream. + */ + private static final class TracingObserver + implements Subscriber, SingleObserver, MaybeObserver, CompletableObserver { + private final Context context; + private final Subscriber subscriber; + private final SingleObserver singleObserver; + private final MaybeObserver maybeObserver; + private final CompletableObserver completableObserver; + + private TracingObserver( + Context context, + Subscriber subscriber, + SingleObserver singleObserver, + MaybeObserver maybeObserver, + CompletableObserver completableObserver) { + this.context = context; + this.subscriber = subscriber; + this.singleObserver = singleObserver; + this.maybeObserver = maybeObserver; + this.completableObserver = completableObserver; + } + + static TracingObserver wrap(Context context, Subscriber subscriber) { + return new TracingObserver<>(context, subscriber, null, null, null); + } + + static TracingObserver wrap(Context context, SingleObserver observer) { + return new TracingObserver<>(context, null, observer, null, null); + } + + static TracingObserver wrap(Context context, MaybeObserver observer) { + return new TracingObserver<>(context, null, null, observer, null); + } + + static TracingObserver wrap(Context context, CompletableObserver observer) { + return new TracingObserver<>(context, null, null, null, observer); + } + + private void runInContext(Runnable action) { + try (Scope scope = context.makeCurrent()) { + action.run(); + } + } + + @Override + public void onSubscribe(Subscription s) { + runInContext( + () -> { + if (subscriber != null) { + subscriber.onSubscribe(s); + } + }); + } + + @Override + public void onSubscribe(Disposable d) { + runInContext( + () -> { + if (singleObserver != null) { + singleObserver.onSubscribe(d); + } else if (maybeObserver != null) { + maybeObserver.onSubscribe(d); + } else if (completableObserver != null) { + completableObserver.onSubscribe(d); + } + }); + } + + @Override + public void onNext(T t) { + runInContext( + () -> { + if (subscriber != null) { + subscriber.onNext(t); + } + }); + } + + @Override + public void onSuccess(T t) { + runInContext( + () -> { + if (singleObserver != null) { + singleObserver.onSuccess(t); + } else if (maybeObserver != null) { + maybeObserver.onSuccess(t); + } + }); + } + + @Override + public void onError(Throwable t) { + runInContext( + () -> { + if (subscriber != null) { + subscriber.onError(t); + } else if (singleObserver != null) { + singleObserver.onError(t); + } else if (maybeObserver != null) { + maybeObserver.onError(t); + } else if (completableObserver != null) { + completableObserver.onError(t); + } + }); + } + + @Override + public void onComplete() { + runInContext( + () -> { + if (subscriber != null) { + subscriber.onComplete(); + } else if (maybeObserver != null) { + maybeObserver.onComplete(); + } else if (completableObserver != null) { + completableObserver.onComplete(); + } }); } } From 3c8f4886f0e4c76abdbeb64a348bfccd5c16120e Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 9 Mar 2026 18:45:28 -0700 Subject: [PATCH 09/50] feat: Fixing the spans produced by agent calls to have the right parent spans PiperOrigin-RevId: 881142814 --- .../java/com/google/adk/agents/BaseAgent.java | 107 ++-- .../adk/flows/llmflows/BaseLlmFlow.java | 390 ++++++-------- .../google/adk/flows/llmflows/Functions.java | 227 ++++---- .../com/google/adk/plugins/PluginManager.java | 16 +- .../java/com/google/adk/runner/Runner.java | 167 +++--- .../com/google/adk/telemetry/Tracing.java | 488 +++++------------- 6 files changed, 508 insertions(+), 887 deletions(-) diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index c527eeab3..d74ba9ca5 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -29,10 +29,10 @@ import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.errorprone.annotations.DoNotCall; import com.google.genai.types.Content; -import io.opentelemetry.context.Context; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; import java.util.ArrayList; import java.util.HashSet; import java.util.List; @@ -312,47 +312,38 @@ public Flowable runAsync(InvocationContext parentContext) { private Flowable run( InvocationContext parentContext, Function> runImplementation) { - Context otelParentContext = Context.current(); - InvocationContext invocationContext = createInvocationContext(parentContext); - return Flowable.defer( - () -> { - return callCallback( - beforeCallbacksToFunctions( - invocationContext.pluginManager(), beforeAgentCallback), - invocationContext) - .flatMapPublisher( - beforeEvent -> { - if (invocationContext.endInvocation()) { - return Flowable.just(beforeEvent); - } - - return Flowable.just(beforeEvent) - .concatWith(runMainAndAfter(invocationContext, runImplementation)); - }) - .switchIfEmpty( - Flowable.defer(() -> runMainAndAfter(invocationContext, runImplementation))); - }) - .compose( - Tracing.traceAgent( - otelParentContext, - "invoke_agent " + name(), - name(), - description(), - invocationContext)); - } - - private Flowable runMainAndAfter( - InvocationContext invocationContext, - Function> runImplementation) { - Flowable mainEvents = runImplementation.apply(invocationContext); - Flowable afterEvents = - callCallback( - afterCallbacksToFunctions(invocationContext.pluginManager(), afterAgentCallback), - invocationContext) - .flatMapPublisher(Flowable::just); - - return Flowable.concat(mainEvents, afterEvents); + () -> { + InvocationContext invocationContext = createInvocationContext(parentContext); + + return callCallback( + beforeCallbacksToFunctions( + invocationContext.pluginManager(), beforeAgentCallback), + invocationContext) + .flatMapPublisher( + beforeEventOpt -> { + if (invocationContext.endInvocation()) { + return Flowable.fromOptional(beforeEventOpt); + } + + Flowable beforeEvents = Flowable.fromOptional(beforeEventOpt); + Flowable mainEvents = + Flowable.defer(() -> runImplementation.apply(invocationContext)); + Flowable afterEvents = + Flowable.defer( + () -> + callCallback( + afterCallbacksToFunctions( + invocationContext.pluginManager(), afterAgentCallback), + invocationContext) + .flatMapPublisher(Flowable::fromOptional)); + + return Flowable.concat(beforeEvents, mainEvents, afterEvents); + }) + .compose( + Tracing.traceAgent( + "invoke_agent " + name(), name(), description(), invocationContext)); + }); } /** @@ -392,13 +383,13 @@ private ImmutableList>> callbacksTo * * @param agentCallbacks Callback functions. * @param invocationContext Current invocation context. - * @return Maybe emitting first event, or empty if none. + * @return single emitting first event, or empty if none. */ - private Maybe callCallback( + private Single> callCallback( List>> agentCallbacks, InvocationContext invocationContext) { if (agentCallbacks.isEmpty()) { - return Maybe.empty(); + return Single.just(Optional.empty()); } CallbackContext callbackContext = @@ -407,25 +398,27 @@ private Maybe callCallback( return Flowable.fromIterable(agentCallbacks) .concatMap( callback -> { - return callback - .apply(callbackContext) + Maybe maybeContent = callback.apply(callbackContext); + + return maybeContent .map( content -> { invocationContext.setEndInvocation(true); - return Event.builder() - .id(Event.generateEventId()) - .invocationId(invocationContext.invocationId()) - .author(name()) - .branch(invocationContext.branch().orElse(null)) - .actions(callbackContext.eventActions()) - .content(content) - .build(); + return Optional.of( + Event.builder() + .id(Event.generateEventId()) + .invocationId(invocationContext.invocationId()) + .author(name()) + .branch(invocationContext.branch().orElse(null)) + .actions(callbackContext.eventActions()) + .content(content) + .build()); }) .toFlowable(); }) .firstElement() .switchIfEmpty( - Maybe.defer( + Single.defer( () -> { if (callbackContext.state().hasDelta()) { Event.Builder eventBuilder = @@ -436,9 +429,9 @@ private Maybe callCallback( .branch(invocationContext.branch().orElse(null)) .actions(callbackContext.eventActions()); - return Maybe.just(eventBuilder.build()); + return Single.just(Optional.of(eventBuilder.build())); } else { - return Maybe.empty(); + return Single.just(Optional.empty()); } })); } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index fba7f10e0..6ed9ccaa3 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -92,9 +92,7 @@ public BaseLlmFlow( * events generated by them. */ protected Flowable preprocess( - InvocationContext context, - AtomicReference llmRequestRef, - Context otelParentContext) { + InvocationContext context, AtomicReference llmRequestRef) { LlmAgent agent = (LlmAgent) context.agent(); RequestProcessor toolsProcessor = @@ -106,8 +104,7 @@ protected Flowable preprocess( tool -> tool.processLlmRequest(builder, ToolContext.builder(ctx).build())) .andThen( Single.fromCallable( - () -> RequestProcessingResult.create(builder.build(), ImmutableList.of()))) - .compose(Tracing.withContext(otelParentContext)); + () -> RequestProcessingResult.create(builder.build(), ImmutableList.of()))); }; Iterable allProcessors = @@ -116,9 +113,7 @@ protected Flowable preprocess( return Flowable.fromIterable(allProcessors) .concatMap( processor -> - processor - .processRequest(context, llmRequestRef.get()) - .compose(Tracing.withContext(otelParentContext)) + Single.defer(() -> processor.processRequest(context, llmRequestRef.get())) .doOnSuccess(result -> llmRequestRef.set(result.updatedRequest())) .flattenAsFlowable( result -> result.events() != null ? result.events() : ImmutableList.of())); @@ -134,32 +129,13 @@ protected Flowable postprocess( Event baseEventForLlmResponse, LlmRequest llmRequest, LlmResponse llmResponse) { - return postprocess( - context, baseEventForLlmResponse, llmRequest, llmResponse, Context.current()); - } - - /** - * Post-processes the LLM response after receiving it from the LLM. Executes all registered {@link - * ResponseProcessor} instances. Emits events for the model response and any subsequent function - * calls. - */ - private Flowable postprocess( - InvocationContext context, - Event baseEventForLlmResponse, - LlmRequest llmRequest, - LlmResponse llmResponse, - Context otelParentContext) { List> eventIterables = new ArrayList<>(); Single currentLlmResponse = Single.just(llmResponse); for (ResponseProcessor processor : responseProcessors) { currentLlmResponse = currentLlmResponse - .flatMap( - response -> - processor - .processResponse(context, response) - .compose(Tracing.withContext(otelParentContext))) + .flatMap(response -> processor.processResponse(context, response)) .doOnSuccess( result -> { if (result.events() != null) { @@ -168,16 +144,15 @@ private Flowable postprocess( }) .map(ResponseProcessingResult::updatedResponse); } + Context parentContext = Context.current(); return currentLlmResponse.flatMapPublisher( - updatedResponse -> - buildPostprocessingEvents( - updatedResponse, - eventIterables, - context, - baseEventForLlmResponse, - llmRequest, - otelParentContext)); + updatedResponse -> { + try (Scope scope = parentContext.makeCurrent()) { + return buildPostprocessingEvents( + updatedResponse, eventIterables, context, baseEventForLlmResponse, llmRequest); + } + }); } /** @@ -189,100 +164,84 @@ private Flowable postprocess( * callbacks. Callbacks should not rely on its ID if they create their own separate events. */ private Flowable callLlm( - InvocationContext context, - LlmRequest llmRequest, - Event eventForCallbackUsage, - Context otelParentContext) { + InvocationContext context, LlmRequest llmRequest, Event eventForCallbackUsage) { LlmAgent agent = (LlmAgent) context.agent(); LlmRequest.Builder llmRequestBuilder = llmRequest.toBuilder(); - return handleBeforeModelCallback( - context, llmRequestBuilder, eventForCallbackUsage, otelParentContext) - .flatMapPublisher(Flowable::just) - .switchIfEmpty( - Flowable.defer( - () -> { - BaseLlm llm = - agent.resolvedModel().model().isPresent() - ? agent.resolvedModel().model().get() - : LlmRegistry.getLlm(agent.resolvedModel().modelName().get()); - return llm.generateContent( - llmRequestBuilder.build(), - context.runConfig().streamingMode() == StreamingMode.SSE) - .onErrorResumeNext( - exception -> - handleOnModelErrorCallback( - context, - llmRequestBuilder, - eventForCallbackUsage, - exception, - otelParentContext) - .switchIfEmpty(Single.error(exception)) - .toFlowable()) - .compose( - Tracing.trace("call_llm", otelParentContext) - .onSuccess( - (span, llmResp) -> - Tracing.traceCallLlm( - span, - context, - eventForCallbackUsage.id(), - llmRequestBuilder.build(), - llmResp))) - .doOnError( - error -> { - Span span = Span.current(); - span.setStatus(StatusCode.ERROR, error.getMessage()); - span.recordException(error); - }) - .concatMap( - llmResp -> - handleAfterModelCallback( - context, llmResp, eventForCallbackUsage, otelParentContext) - .toFlowable()); - })); + return handleBeforeModelCallback(context, llmRequestBuilder, eventForCallbackUsage) + .flatMapPublisher( + beforeResponse -> { + if (beforeResponse.isPresent()) { + return Flowable.just(beforeResponse.get()); + } + BaseLlm llm = + agent.resolvedModel().model().isPresent() + ? agent.resolvedModel().model().get() + : LlmRegistry.getLlm(agent.resolvedModel().modelName().get()); + return llm.generateContent( + llmRequestBuilder.build(), + context.runConfig().streamingMode() == StreamingMode.SSE) + .onErrorResumeNext( + exception -> + handleOnModelErrorCallback( + context, llmRequestBuilder, eventForCallbackUsage, exception) + .switchIfEmpty(Single.error(exception)) + .toFlowable()) + .doOnNext( + llmResp -> + Tracing.traceCallLlm( + context, + eventForCallbackUsage.id(), + llmRequestBuilder.build(), + llmResp)) + .doOnError( + error -> { + Span span = Span.current(); + span.setStatus(StatusCode.ERROR, error.getMessage()); + span.recordException(error); + }) + .compose(Tracing.trace("call_llm")) + .concatMap( + llmResp -> + handleAfterModelCallback(context, llmResp, eventForCallbackUsage) + .toFlowable()); + }); } /** * Invokes {@link BeforeModelCallback}s. If any returns a response, it's used instead of calling * the LLM. * - * @return A {@link Maybe} with the callback result. + * @return A {@link Single} with the callback result or {@link Optional#empty()}. */ - private Maybe handleBeforeModelCallback( - InvocationContext context, - LlmRequest.Builder llmRequestBuilder, - Event modelResponseEvent, - Context otelParentContext) { - try (Scope scope = otelParentContext.makeCurrent()) { - Event callbackEvent = modelResponseEvent.toBuilder().build(); - CallbackContext callbackContext = - new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); + private Single> handleBeforeModelCallback( + InvocationContext context, LlmRequest.Builder llmRequestBuilder, Event modelResponseEvent) { + Event callbackEvent = modelResponseEvent.toBuilder().build(); + CallbackContext callbackContext = + new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); - Maybe pluginResult = - context.pluginManager().beforeModelCallback(callbackContext, llmRequestBuilder); + Maybe pluginResult = + context.pluginManager().beforeModelCallback(callbackContext, llmRequestBuilder); - LlmAgent agent = (LlmAgent) context.agent(); - - List callbacks = agent.canonicalBeforeModelCallbacks(); - if (callbacks.isEmpty()) { - return pluginResult; - } + LlmAgent agent = (LlmAgent) context.agent(); - Maybe callbackResult = - Maybe.defer( - () -> - Flowable.fromIterable(callbacks) - .concatMapMaybe( - callback -> - callback - .call(callbackContext, llmRequestBuilder) - .compose(Tracing.withContext(otelParentContext))) - .firstElement()); - - return pluginResult.switchIfEmpty(callbackResult); + List callbacks = agent.canonicalBeforeModelCallbacks(); + if (callbacks.isEmpty()) { + return pluginResult.map(Optional::of).defaultIfEmpty(Optional.empty()); } + + Maybe callbackResult = + Maybe.defer( + () -> + Flowable.fromIterable(callbacks) + .concatMapMaybe(callback -> callback.call(callbackContext, llmRequestBuilder)) + .firstElement()); + + return pluginResult + .switchIfEmpty(callbackResult) + .map(Optional::of) + .defaultIfEmpty(Optional.empty()); } /** @@ -295,41 +254,32 @@ private Maybe handleOnModelErrorCallback( InvocationContext context, LlmRequest.Builder llmRequestBuilder, Event modelResponseEvent, - Throwable throwable, - Context otelParentContext) { - - try (Scope scope = otelParentContext.makeCurrent()) { - Event callbackEvent = modelResponseEvent.toBuilder().build(); - CallbackContext callbackContext = - new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); - Exception ex = throwable instanceof Exception e ? e : new Exception(throwable); - Maybe pluginResult = - context - .pluginManager() - .onModelErrorCallback(callbackContext, llmRequestBuilder, throwable); - - LlmAgent agent = (LlmAgent) context.agent(); - List callbacks = agent.canonicalOnModelErrorCallbacks(); - - if (callbacks.isEmpty()) { - return pluginResult; - } + Throwable throwable) { + Event callbackEvent = modelResponseEvent.toBuilder().build(); + CallbackContext callbackContext = + new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); + Exception ex = throwable instanceof Exception e ? e : new Exception(throwable); + + Maybe pluginResult = + context.pluginManager().onModelErrorCallback(callbackContext, llmRequestBuilder, throwable); - Maybe callbackResult = - Maybe.defer( - () -> { - LlmRequest llmRequest = llmRequestBuilder.build(); - return Flowable.fromIterable(callbacks) - .concatMapMaybe( - callback -> - callback - .call(callbackContext, llmRequest, ex) - .compose(Tracing.withContext(otelParentContext))) - .firstElement(); - }); - - return pluginResult.switchIfEmpty(callbackResult); + LlmAgent agent = (LlmAgent) context.agent(); + List callbacks = agent.canonicalOnModelErrorCallbacks(); + + if (callbacks.isEmpty()) { + return pluginResult; } + + Maybe callbackResult = + Maybe.defer( + () -> { + LlmRequest llmRequest = llmRequestBuilder.build(); + return Flowable.fromIterable(callbacks) + .concatMapMaybe(callback -> callback.call(callbackContext, llmRequest, ex)) + .firstElement(); + }); + + return pluginResult.switchIfEmpty(callbackResult); } /** @@ -339,39 +289,29 @@ private Maybe handleOnModelErrorCallback( * @return A {@link Single} with the final {@link LlmResponse}. */ private Single handleAfterModelCallback( - InvocationContext context, - LlmResponse llmResponse, - Event modelResponseEvent, - Context otelParentContext) { + InvocationContext context, LlmResponse llmResponse, Event modelResponseEvent) { + Event callbackEvent = modelResponseEvent.toBuilder().build(); + CallbackContext callbackContext = + new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); - try (Scope scope = otelParentContext.makeCurrent()) { - Event callbackEvent = modelResponseEvent.toBuilder().build(); - CallbackContext callbackContext = - new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); + Maybe pluginResult = + context.pluginManager().afterModelCallback(callbackContext, llmResponse); - Maybe pluginResult = - context.pluginManager().afterModelCallback(callbackContext, llmResponse); + LlmAgent agent = (LlmAgent) context.agent(); + List callbacks = agent.canonicalAfterModelCallbacks(); - LlmAgent agent = (LlmAgent) context.agent(); - List callbacks = agent.canonicalAfterModelCallbacks(); + if (callbacks.isEmpty()) { + return pluginResult.defaultIfEmpty(llmResponse); + } - if (callbacks.isEmpty()) { - return pluginResult.defaultIfEmpty(llmResponse); - } + Maybe callbackResult = + Maybe.defer( + () -> + Flowable.fromIterable(callbacks) + .concatMapMaybe(callback -> callback.call(callbackContext, llmResponse)) + .firstElement()); - Maybe callbackResult = - Maybe.defer( - () -> - Flowable.fromIterable(callbacks) - .concatMapMaybe( - callback -> - callback - .call(callbackContext, llmResponse) - .compose(Tracing.withContext(otelParentContext))) - .firstElement()); - - return pluginResult.switchIfEmpty(callbackResult).defaultIfEmpty(llmResponse); - } + return pluginResult.switchIfEmpty(callbackResult).defaultIfEmpty(llmResponse); } /** @@ -383,12 +323,13 @@ private Single handleAfterModelCallback( * @throws LlmCallsLimitExceededException if the agent exceeds allowed LLM invocations. * @throws IllegalStateException if a transfer agent is specified but not found. */ - private Flowable runOneStep(InvocationContext context, Context otelParentContext) { + private Flowable runOneStep(InvocationContext context) { AtomicReference llmRequestRef = new AtomicReference<>(LlmRequest.builder().build()); return Flowable.defer( () -> { - return preprocess(context, llmRequestRef, otelParentContext) + Context currentContext = Context.current(); + return preprocess(context, llmRequestRef) .concatWith( Flowable.defer( () -> { @@ -414,19 +355,15 @@ private Flowable runOneStep(InvocationContext context, Context otelParent .build(); mutableEventTemplate.setTimestamp(0L); - return callLlm( - context, - llmRequestAfterPreprocess, - mutableEventTemplate, - otelParentContext) + return callLlm(context, llmRequestAfterPreprocess, mutableEventTemplate) .concatMap( - llmResponse -> - postprocess( + llmResponse -> { + try (Scope postScope = currentContext.makeCurrent()) { + return postprocess( context, mutableEventTemplate, llmRequestAfterPreprocess, - llmResponse, - otelParentContext) + llmResponse) .doFinally( () -> { String oldId = mutableEventTemplate.id(); @@ -434,7 +371,9 @@ private Flowable runOneStep(InvocationContext context, Context otelParent logger.debug( "Resetting event ID from {} to {}", oldId, newId); mutableEventTemplate.setId(newId); - })) + }); + } + }) .concatMap( event -> { Flowable postProcessedEvents = Flowable.just(event); @@ -468,12 +407,11 @@ private Flowable runOneStep(InvocationContext context, Context otelParent */ @Override public Flowable run(InvocationContext invocationContext) { - return run(invocationContext, Context.current(), 0); + return run(invocationContext, 0); } - private Flowable run( - InvocationContext invocationContext, Context otelParentContext, int stepsCompleted) { - Flowable currentStepEvents = runOneStep(invocationContext, otelParentContext).cache(); + private Flowable run(InvocationContext invocationContext, int stepsCompleted) { + Flowable currentStepEvents = runOneStep(invocationContext).cache(); if (stepsCompleted + 1 >= maxSteps) { logger.debug("Ending flow execution because max steps reached."); return currentStepEvents; @@ -493,7 +431,7 @@ private Flowable run( return Flowable.empty(); } else { logger.debug("Continuing to next step of the flow."); - return run(invocationContext, otelParentContext, stepsCompleted + 1); + return run(invocationContext, stepsCompleted + 1); } })); } @@ -508,10 +446,8 @@ private Flowable run( */ @Override public Flowable runLive(InvocationContext invocationContext) { - Context otelParentContext = Context.current(); AtomicReference llmRequestRef = new AtomicReference<>(LlmRequest.builder().build()); - Flowable preprocessEvents = - preprocess(invocationContext, llmRequestRef, otelParentContext); + Flowable preprocessEvents = preprocess(invocationContext, llmRequestRef); return preprocessEvents.concatWith( Flowable.defer( @@ -533,7 +469,6 @@ public Flowable runLive(InvocationContext invocationContext) { ? Completable.complete() : connection .sendHistory(llmRequestAfterPreprocess.contents()) - .compose(Tracing.trace("send_data", otelParentContext)) .doOnComplete( () -> Tracing.traceSendData( @@ -549,7 +484,8 @@ public Flowable runLive(InvocationContext invocationContext) { invocationContext, eventIdForSendData, llmRequestAfterPreprocess.contents()); - }); + }) + .compose(Tracing.trace("send_data")); Flowable liveRequests = invocationContext @@ -606,16 +542,13 @@ public void onError(Throwable e) { .receive() .flatMap( llmResponse -> { - try (Scope scope = otelParentContext.makeCurrent()) { - Event baseEventForLlmResponse = - liveEventBuilderTemplate.id(Event.generateEventId()).build(); - return postprocess( - invocationContext, - baseEventForLlmResponse, - llmRequestAfterPreprocess, - llmResponse, - otelParentContext); - } + Event baseEventForThisLlmResponse = + liveEventBuilderTemplate.id(Event.generateEventId()).build(); + return postprocess( + invocationContext, + baseEventForThisLlmResponse, + llmRequestAfterPreprocess, + llmResponse); }) .flatMap( event -> { @@ -667,8 +600,7 @@ private Flowable buildPostprocessingEvents( List> eventIterables, InvocationContext context, Event baseEventForLlmResponse, - LlmRequest llmRequest, - Context otelParentContext) { + LlmRequest llmRequest) { Flowable processorEvents = Flowable.fromIterable(Iterables.concat(eventIterables)); if (updatedResponse.content().isEmpty() && updatedResponse.errorCode().isEmpty() @@ -684,27 +616,23 @@ private Flowable buildPostprocessingEvents( return processorEvents.concatWith(Flowable.just(modelResponseEvent)); } - try (Scope scope = otelParentContext.makeCurrent()) { - Maybe maybeFunctionResponseEvent = - context.runConfig().streamingMode() == StreamingMode.BIDI - ? Functions.handleFunctionCallsLive(context, modelResponseEvent, llmRequest.tools()) - : Functions.handleFunctionCalls(context, modelResponseEvent, llmRequest.tools()); - - Flowable functionEvents = - maybeFunctionResponseEvent.flatMapPublisher( - functionResponseEvent -> { - Optional toolConfirmationEvent = - Functions.generateRequestConfirmationEvent( - context, modelResponseEvent, functionResponseEvent); - return toolConfirmationEvent.isPresent() - ? Flowable.just(toolConfirmationEvent.get(), functionResponseEvent) - : Flowable.just(functionResponseEvent); - }); - - return processorEvents - .concatWith(Flowable.just(modelResponseEvent)) - .concatWith(functionEvents); - } + Maybe maybeFunctionResponseEvent = + context.runConfig().streamingMode() == StreamingMode.BIDI + ? Functions.handleFunctionCallsLive(context, modelResponseEvent, llmRequest.tools()) + : Functions.handleFunctionCalls(context, modelResponseEvent, llmRequest.tools()); + + Flowable functionEvents = + maybeFunctionResponseEvent.flatMapPublisher( + functionResponseEvent -> { + Optional toolConfirmationEvent = + Functions.generateRequestConfirmationEvent( + context, modelResponseEvent, functionResponseEvent); + return toolConfirmationEvent.isPresent() + ? Flowable.just(toolConfirmationEvent.get(), functionResponseEvent) + : Flowable.just(functionResponseEvent); + }); + + return processorEvents.concatWith(Flowable.just(modelResponseEvent)).concatWith(functionEvents); } private Event buildModelResponseEvent( diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index f8b9e180d..ecc2bb412 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -257,8 +257,7 @@ private static Function> getFunctionCallMapper( functionCall.args().map(HashMap::new).orElse(new HashMap<>()); Maybe> maybeFunctionResult = - maybeInvokeBeforeToolCall( - invocationContext, tool, functionArgs, toolContext, parentContext) + maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext) .switchIfEmpty( Maybe.defer( () -> { @@ -396,49 +395,48 @@ private static Maybe postProcessFunctionResult( .defaultIfEmpty(Optional.empty()) .onErrorResumeNext( t -> { - try (Scope scope = parentContext.makeCurrent()) { - Maybe> errorCallbackResult = - handleOnToolErrorCallback( - invocationContext, tool, functionArgs, toolContext, t, parentContext); - Maybe>> mappedResult; - if (isLive) { - // In live mode, handle null results from the error callback gracefully. - mappedResult = errorCallbackResult.map(Optional::ofNullable); - } else { - // In non-live mode, a null result from the error callback will cause an NPE - // when wrapped with Optional.of(), potentially matching prior behavior. - mappedResult = errorCallbackResult.map(Optional::of); - } - return mappedResult.switchIfEmpty(Single.error(t)); + Maybe> errorCallbackResult = + handleOnToolErrorCallback(invocationContext, tool, functionArgs, toolContext, t); + Maybe>> mappedResult; + if (isLive) { + // In live mode, handle null results from the error callback gracefully. + mappedResult = errorCallbackResult.map(Optional::ofNullable); + } else { + // In non-live mode, a null result from the error callback will cause an NPE + // when wrapped with Optional.of(), potentially matching prior behavior. + mappedResult = errorCallbackResult.map(Optional::of); } + return mappedResult.switchIfEmpty(Single.error(t)); }) .flatMapMaybe( optionalInitialResult -> { - Map initialFunctionResult = optionalInitialResult.orElse(null); - - return maybeInvokeAfterToolCall( - invocationContext, - tool, - functionArgs, - toolContext, - initialFunctionResult, - parentContext) - .map(Optional::of) - .defaultIfEmpty(Optional.ofNullable(initialFunctionResult)) - .flatMapMaybe( - finalOptionalResult -> { - Map finalFunctionResult = finalOptionalResult.orElse(null); - if (tool.longRunning() && finalFunctionResult == null) { - return Maybe.empty(); - } - return Maybe.fromCallable( - () -> - buildResponseEvent( - tool, finalFunctionResult, toolContext, invocationContext)) - .compose( - Tracing.trace("tool_response [" + tool.name() + "]", parentContext)) - .doOnSuccess(event -> Tracing.traceToolResponse(event.id(), event)); - }); + try (Scope scope = parentContext.makeCurrent()) { + Map initialFunctionResult = optionalInitialResult.orElse(null); + + return maybeInvokeAfterToolCall( + invocationContext, tool, functionArgs, toolContext, initialFunctionResult) + .map(Optional::of) + .defaultIfEmpty(Optional.ofNullable(initialFunctionResult)) + .flatMapMaybe( + finalOptionalResult -> { + Map finalFunctionResult = + finalOptionalResult.orElse(null); + if (tool.longRunning() && finalFunctionResult == null) { + return Maybe.empty(); + } + return Maybe.fromCallable( + () -> + buildResponseEvent( + tool, + finalFunctionResult, + toolContext, + invocationContext)) + .compose( + Tracing.trace( + "tool_response [" + tool.name() + "]", parentContext)) + .doOnSuccess(event -> Tracing.traceToolResponse(event.id(), event)); + }); + } }); } @@ -481,32 +479,28 @@ private static Maybe> maybeInvokeBeforeToolCall( InvocationContext invocationContext, BaseTool tool, Map functionArgs, - ToolContext toolContext, - Context parentContext) { - if (invocationContext.agent() instanceof LlmAgent agent) { - try (Scope scope = parentContext.makeCurrent()) { - - Maybe> pluginResult = - invocationContext.pluginManager().beforeToolCallback(tool, functionArgs, toolContext); + ToolContext toolContext) { + if (invocationContext.agent() instanceof LlmAgent) { + LlmAgent agent = (LlmAgent) invocationContext.agent(); - List callbacks = agent.canonicalBeforeToolCallbacks(); - if (callbacks.isEmpty()) { - return pluginResult; - } + Maybe> pluginResult = + invocationContext.pluginManager().beforeToolCallback(tool, functionArgs, toolContext); - Maybe> callbackResult = - Maybe.defer( - () -> - Flowable.fromIterable(callbacks) - .concatMapMaybe( - callback -> - callback - .call(invocationContext, tool, functionArgs, toolContext) - .compose(Tracing.withContext(parentContext))) - .firstElement()); - - return pluginResult.switchIfEmpty(callbackResult); + List callbacks = agent.canonicalBeforeToolCallbacks(); + if (callbacks.isEmpty()) { + return pluginResult; } + + Maybe> callbackResult = + Maybe.defer( + () -> + Flowable.fromIterable(callbacks) + .concatMapMaybe( + callback -> + callback.call(invocationContext, tool, functionArgs, toolContext)) + .firstElement()); + + return pluginResult.switchIfEmpty(callbackResult); } return Maybe.empty(); } @@ -522,39 +516,34 @@ private static Maybe> handleOnToolErrorCallback( BaseTool tool, Map functionArgs, ToolContext toolContext, - Throwable throwable, - Context parentContext) { + Throwable throwable) { Exception ex = throwable instanceof Exception exception ? exception : new Exception(throwable); - try (Scope scope = parentContext.makeCurrent()) { - Maybe> pluginResult = - invocationContext - .pluginManager() - .onToolErrorCallback(tool, functionArgs, toolContext, throwable); - - if (invocationContext.agent() instanceof LlmAgent) { - LlmAgent agent = (LlmAgent) invocationContext.agent(); + Maybe> pluginResult = + invocationContext + .pluginManager() + .onToolErrorCallback(tool, functionArgs, toolContext, throwable); - List callbacks = agent.canonicalOnToolErrorCallbacks(); - if (callbacks.isEmpty()) { - return pluginResult; - } + if (invocationContext.agent() instanceof LlmAgent) { + LlmAgent agent = (LlmAgent) invocationContext.agent(); - Maybe> callbackResult = - Maybe.defer( - () -> - Flowable.fromIterable(callbacks) - .concatMapMaybe( - callback -> - callback - .call(invocationContext, tool, functionArgs, toolContext, ex) - .compose(Tracing.withContext(parentContext))) - .firstElement()); - - return pluginResult.switchIfEmpty(callbackResult); + List callbacks = agent.canonicalOnToolErrorCallbacks(); + if (callbacks.isEmpty()) { + return pluginResult; } - return pluginResult; + + Maybe> callbackResult = + Maybe.defer( + () -> + Flowable.fromIterable(callbacks) + .concatMapMaybe( + callback -> + callback.call(invocationContext, tool, functionArgs, toolContext, ex)) + .firstElement()); + + return pluginResult.switchIfEmpty(callbackResult); } + return pluginResult; } private static Maybe> maybeInvokeAfterToolCall( @@ -562,39 +551,35 @@ private static Maybe> maybeInvokeAfterToolCall( BaseTool tool, Map functionArgs, ToolContext toolContext, - Map functionResult, - Context parentContext) { - if (invocationContext.agent() instanceof LlmAgent agent) { + Map functionResult) { + if (invocationContext.agent() instanceof LlmAgent) { + LlmAgent agent = (LlmAgent) invocationContext.agent(); - try (Scope scope = parentContext.makeCurrent()) { - Maybe> pluginResult = - invocationContext - .pluginManager() - .afterToolCallback(tool, functionArgs, toolContext, functionResult); - - List callbacks = agent.canonicalAfterToolCallbacks(); - if (callbacks.isEmpty()) { - return pluginResult; - } + Maybe> pluginResult = + invocationContext + .pluginManager() + .afterToolCallback(tool, functionArgs, toolContext, functionResult); - Maybe> callbackResult = - Maybe.defer( - () -> - Flowable.fromIterable(callbacks) - .concatMapMaybe( - callback -> - callback - .call( - invocationContext, - tool, - functionArgs, - toolContext, - functionResult) - .compose(Tracing.withContext(parentContext))) - .firstElement()); - - return pluginResult.switchIfEmpty(callbackResult); + List callbacks = agent.canonicalAfterToolCallbacks(); + if (callbacks.isEmpty()) { + return pluginResult; } + + Maybe> callbackResult = + Maybe.defer( + () -> + Flowable.fromIterable(callbacks) + .concatMapMaybe( + callback -> + callback.call( + invocationContext, + tool, + functionArgs, + toolContext, + functionResult)) + .firstElement()); + + return pluginResult.switchIfEmpty(callbackResult); } return Maybe.empty(); } diff --git a/core/src/main/java/com/google/adk/plugins/PluginManager.java b/core/src/main/java/com/google/adk/plugins/PluginManager.java index 4d90ca7b5..56dea936a 100644 --- a/core/src/main/java/com/google/adk/plugins/PluginManager.java +++ b/core/src/main/java/com/google/adk/plugins/PluginManager.java @@ -21,13 +21,11 @@ import com.google.adk.events.Event; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; -import com.google.adk.telemetry.Tracing; import com.google.adk.tools.BaseTool; import com.google.adk.tools.ToolContext; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; -import io.opentelemetry.context.Context; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; @@ -131,7 +129,6 @@ public Completable runAfterRunCallback(InvocationContext invocationContext) { @Override public Completable afterRunCallback(InvocationContext invocationContext) { - Context capturedContext = Context.current(); return Flowable.fromIterable(plugins) .concatMapCompletable( plugin -> @@ -142,13 +139,11 @@ public Completable afterRunCallback(InvocationContext invocationContext) { logger.error( "[{}] Error during callback 'afterRunCallback'", plugin.getName(), - e)) - .compose(Tracing.withContext(capturedContext))); + e))); } @Override public Completable close() { - Context capturedContext = Context.current(); return Flowable.fromIterable(plugins) .concatMapCompletableDelayError( plugin -> @@ -156,8 +151,8 @@ public Completable close() { .close() .doOnError( e -> - logger.error("[{}] Error during callback 'close'", plugin.getName(), e)) - .compose(Tracing.withContext(capturedContext))); + logger.error( + "[{}] Error during callback 'close'", plugin.getName(), e))); } public Maybe runOnEventCallback(InvocationContext invocationContext, Event event) { @@ -280,7 +275,7 @@ public Maybe> onToolErrorCallback( */ private Maybe runMaybeCallbacks( Function> callbackExecutor, String callbackName) { - Context capturedContext = Context.current(); + return Flowable.fromIterable(this.plugins) .concatMapMaybe( plugin -> @@ -299,8 +294,7 @@ private Maybe runMaybeCallbacks( "[{}] Error during callback '{}'", plugin.getName(), callbackName, - e)) - .compose(Tracing.withContext(capturedContext))) + e))) .firstElement(); } } diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index e35f5c33d..4371300fb 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -52,7 +52,6 @@ import com.google.genai.types.Part; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.StatusCode; -import io.opentelemetry.context.Context; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; @@ -376,25 +375,20 @@ public Flowable runAsync( Content newMessage, RunConfig runConfig, @Nullable Map stateDelta) { - return Flowable.defer( - () -> - this.sessionService - .getSession(appName, userId, sessionId, Optional.empty()) - .switchIfEmpty( - Single.defer( - () -> { - if (runConfig.autoCreateSession()) { - return this.sessionService.createSession( - appName, userId, (Map) null, sessionId); - } - return Single.error( - new IllegalArgumentException( - String.format( - "Session not found: %s for user %s", sessionId, userId))); - })) - .flatMapPublisher( - session -> this.runAsyncImpl(session, newMessage, runConfig, stateDelta))) - .compose(Tracing.trace("invocation")); + Maybe maybeSession = + this.sessionService.getSession(appName, userId, sessionId, Optional.empty()); + return maybeSession + .switchIfEmpty( + Single.defer( + () -> { + if (runConfig.autoCreateSession()) { + return this.sessionService.createSession(appName, userId, null, sessionId); + } + return Single.error( + new IllegalArgumentException( + String.format("Session not found: %s for user %s", sessionId, userId))); + })) + .flatMapPublisher(session -> this.runAsyncImpl(session, newMessage, runConfig, stateDelta)); } /** See {@link #runAsync(String, String, Content, RunConfig, Map)}. */ @@ -447,8 +441,7 @@ public Flowable runAsync( Content newMessage, RunConfig runConfig, @Nullable Map stateDelta) { - return runAsyncImpl(session, newMessage, runConfig, stateDelta) - .compose(Tracing.trace("invocation")); + return runAsyncImpl(session, newMessage, runConfig, stateDelta); } /** @@ -467,7 +460,6 @@ protected Flowable runAsyncImpl( @Nullable Map stateDelta) { return Flowable.defer( () -> { - Context capturedContext = Context.current(); BaseAgent rootAgent = this.agent; String invocationId = InvocationContext.newInvocationContextId(); @@ -481,7 +473,6 @@ protected Flowable runAsyncImpl( return this.pluginManager .onUserMessageCallback(initialContext, newMessage) - .compose(Tracing.withContext(capturedContext)) .defaultIfEmpty(newMessage) .flatMap( content -> @@ -493,7 +484,6 @@ protected Flowable runAsyncImpl( runConfig.saveInputBlobsAsArtifacts(), stateDelta) : Single.just(null)) - .compose(Tracing.withContext(capturedContext)) .flatMapPublisher( event -> { if (event == null) { @@ -504,17 +494,15 @@ protected Flowable runAsyncImpl( return this.sessionService .getSession( session.appName(), session.userId(), session.id(), Optional.empty()) - .compose(Tracing.withContext(capturedContext)) .flatMapPublisher( updatedSession -> runAgentWithFreshSession( - session, - updatedSession, - event, - invocationId, - runConfig, - rootAgent) - .compose(Tracing.withContext(capturedContext))); + session, + updatedSession, + event, + invocationId, + runConfig, + rootAgent)); }); }) .doOnError( @@ -522,7 +510,8 @@ protected Flowable runAsyncImpl( Span span = Span.current(); span.setStatus(StatusCode.ERROR, "Error in runAsync Flowable execution"); span.recordException(throwable); - }); + }) + .compose(Tracing.trace("invocation")); } private Flowable runAgentWithFreshSession( @@ -579,7 +568,7 @@ private Flowable runAgentWithFreshSession( .toFlowable() .switchIfEmpty(agentEvents) .concatWith( - Completable.defer(() -> pluginManager.afterRunCallback(contextWithUpdatedSession))) + Completable.defer(() -> pluginManager.runAfterRunCallback(contextWithUpdatedSession))) .concatWith(Completable.defer(() -> compactEvents(updatedSession))); } @@ -652,51 +641,39 @@ private InvocationContext.Builder newInvocationContextBuilder(Session session) { */ public Flowable runLive( Session session, LiveRequestQueue liveRequestQueue, RunConfig runConfig) { - return runLiveImpl(session, liveRequestQueue, runConfig).compose(Tracing.trace("invocation")); - } - - /** - * Runs the agent in live mode, appending generated events to the session. - * - * @return stream of events from the agent. - */ - protected Flowable runLiveImpl( - Session session, @Nullable LiveRequestQueue liveRequestQueue, RunConfig runConfig) { return Flowable.defer( - () -> { - Context capturedContext = Context.current(); - InvocationContext invocationContext = - newInvocationContextForLive(session, liveRequestQueue, runConfig); - - Single invocationContextSingle; - if (invocationContext.agent() instanceof LlmAgent agent) { - invocationContextSingle = - agent - .tools() - .map( - tools -> { - this.addActiveStreamingTools(invocationContext, tools); - return invocationContext; - }); - } else { - invocationContextSingle = Single.just(invocationContext); - } - return invocationContextSingle - .compose(Tracing.withContext(capturedContext)) - .flatMapPublisher( - updatedInvocationContext -> - updatedInvocationContext - .agent() - .runLive(updatedInvocationContext) - .compose(Tracing.withContext(capturedContext)) - .doOnNext(event -> this.sessionService.appendEvent(session, event))) - .doOnError( - throwable -> { - Span span = Span.current(); - span.setStatus(StatusCode.ERROR, "Error in runLive Flowable execution"); - span.recordException(throwable); - }); - }); + () -> { + InvocationContext invocationContext = + newInvocationContextForLive(session, liveRequestQueue, runConfig); + + Single invocationContextSingle; + if (invocationContext.agent() instanceof LlmAgent agent) { + invocationContextSingle = + agent + .tools() + .map( + tools -> { + this.addActiveStreamingTools(invocationContext, tools); + return invocationContext; + }); + } else { + invocationContextSingle = Single.just(invocationContext); + } + return invocationContextSingle + .flatMapPublisher( + updatedInvocationContext -> + updatedInvocationContext + .agent() + .runLive(updatedInvocationContext) + .doOnNext(event -> this.sessionService.appendEvent(session, event))) + .doOnError( + throwable -> { + Span span = Span.current(); + span.setStatus(StatusCode.ERROR, "Error in runLive Flowable execution"); + span.recordException(throwable); + }); + }) + .compose(Tracing.trace("invocation")); } /** @@ -707,25 +684,19 @@ protected Flowable runLiveImpl( */ public Flowable runLive( String userId, String sessionId, LiveRequestQueue liveRequestQueue, RunConfig runConfig) { - return Flowable.defer( - () -> - this.sessionService - .getSession(appName, userId, sessionId, Optional.empty()) - .switchIfEmpty( - Single.defer( - () -> { - if (runConfig.autoCreateSession()) { - return this.sessionService.createSession( - appName, userId, (Map) null, sessionId); - } - return Single.error( - new IllegalArgumentException( - String.format( - "Session not found: %s for user %s", sessionId, userId))); - })) - .flatMapPublisher( - session -> this.runLiveImpl(session, liveRequestQueue, runConfig))) - .compose(Tracing.trace("invocation")); + return this.sessionService + .getSession(appName, userId, sessionId, Optional.empty()) + .switchIfEmpty( + Single.defer( + () -> { + if (runConfig.autoCreateSession()) { + return this.sessionService.createSession(appName, userId, null, sessionId); + } + return Single.error( + new IllegalArgumentException( + String.format("Session not found: %s for user %s", sessionId, userId))); + })) + .flatMapPublisher(session -> this.runLive(session, liveRequestQueue, runConfig)); } /** diff --git a/core/src/main/java/com/google/adk/telemetry/Tracing.java b/core/src/main/java/com/google/adk/telemetry/Tracing.java index 9fa68ee00..07a640c37 100644 --- a/core/src/main/java/com/google/adk/telemetry/Tracing.java +++ b/core/src/main/java/com/google/adk/telemetry/Tracing.java @@ -37,20 +37,16 @@ import io.opentelemetry.context.Context; import io.opentelemetry.context.Scope; import io.reactivex.rxjava3.core.Completable; -import io.reactivex.rxjava3.core.CompletableObserver; import io.reactivex.rxjava3.core.CompletableSource; import io.reactivex.rxjava3.core.CompletableTransformer; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.FlowableTransformer; import io.reactivex.rxjava3.core.Maybe; -import io.reactivex.rxjava3.core.MaybeObserver; import io.reactivex.rxjava3.core.MaybeSource; import io.reactivex.rxjava3.core.MaybeTransformer; import io.reactivex.rxjava3.core.Single; -import io.reactivex.rxjava3.core.SingleObserver; import io.reactivex.rxjava3.core.SingleSource; import io.reactivex.rxjava3.core.SingleTransformer; -import io.reactivex.rxjava3.disposables.Disposable; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -58,12 +54,9 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; -import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.function.Supplier; import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -143,10 +136,6 @@ private static Optional getValidCurrentSpan(String methodName) { return Optional.of(span); } - private static void traceWithSpan(String methodName, Consumer action) { - getValidCurrentSpan(methodName).ifPresent(action); - } - private static void setInvocationAttributes( Span span, InvocationContext invocationContext, String eventId) { span.setAttribute(ADK_INVOCATION_ID, invocationContext.invocationId()); @@ -217,16 +206,16 @@ public static void traceAgentInvocation( */ public static void traceToolCall( String toolName, String toolDescription, String toolType, Map args) { - traceWithSpan( - "traceToolCall", - span -> { - setToolExecutionAttributes(span); - span.setAttribute(GEN_AI_TOOL_NAME, toolName); - span.setAttribute(GEN_AI_TOOL_DESCRIPTION, toolDescription); - span.setAttribute(GEN_AI_TOOL_TYPE, toolType); - - setJsonAttribute(span, ADK_TOOL_CALL_ARGS, args); - }); + getValidCurrentSpan("traceToolCall") + .ifPresent( + span -> { + setToolExecutionAttributes(span); + span.setAttribute(GEN_AI_TOOL_NAME, toolName); + span.setAttribute(GEN_AI_TOOL_DESCRIPTION, toolDescription); + span.setAttribute(GEN_AI_TOOL_TYPE, toolType); + + setJsonAttribute(span, ADK_TOOL_CALL_ARGS, args); + }); } /** @@ -236,32 +225,33 @@ public static void traceToolCall( * @param functionResponseEvent The function response event. */ public static void traceToolResponse(String eventId, Event functionResponseEvent) { - traceWithSpan( - "traceToolResponse", - span -> { - setToolExecutionAttributes(span); - span.setAttribute(ADK_EVENT_ID, eventId); - - Optional functionResponse = - functionResponseEvent.functionResponses().stream().findFirst(); - - String toolCallId = - functionResponse.flatMap(FunctionResponse::id).orElse(""); - Object toolResponse = - functionResponse - .flatMap(FunctionResponse::response) - .map(Object.class::cast) - .orElse(""); - - span.setAttribute(GEN_AI_TOOL_CALL_ID, toolCallId); - - Object finalToolResponse = - (toolResponse instanceof Map) - ? toolResponse - : ImmutableMap.of("result", toolResponse); - - setJsonAttribute(span, ADK_TOOL_RESPONSE, finalToolResponse); - }); + getValidCurrentSpan("traceToolResponse") + .ifPresent( + span -> { + setToolExecutionAttributes(span); + span.setAttribute(ADK_EVENT_ID, eventId); + + FunctionResponse functionResponse = + functionResponseEvent.functionResponses().stream().findFirst().orElse(null); + + String toolCallId = ""; + Object toolResponse = ""; + if (functionResponse != null) { + toolCallId = functionResponse.id().orElse(toolCallId); + if (functionResponse.response().isPresent()) { + toolResponse = functionResponse.response().get(); + } + } + + span.setAttribute(GEN_AI_TOOL_CALL_ID, toolCallId); + + Object finalToolResponse = + (toolResponse instanceof Map) + ? toolResponse + : ImmutableMap.of("result", toolResponse); + + setJsonAttribute(span, ADK_TOOL_RESPONSE, finalToolResponse); + }); } /** @@ -306,63 +296,58 @@ public static void traceCallLlm( String eventId, LlmRequest llmRequest, LlmResponse llmResponse) { - traceWithSpan( - "traceCallLlm", - span -> traceCallLlm(span, invocationContext, eventId, llmRequest, llmResponse)); - } - - /** - * Traces a call to the LLM. - * - * @param span The span to end when the stream completes - * @param invocationContext The invocation context. - * @param eventId The ID of the event associated with this LLM call/response. - * @param llmRequest The LLM request object. - * @param llmResponse The LLM response object. - */ - public static void traceCallLlm( - Span span, - InvocationContext invocationContext, - String eventId, - LlmRequest llmRequest, - LlmResponse llmResponse) { - span.setAttribute(GEN_AI_OPERATION_NAME, "call_llm"); - span.setAttribute(GEN_AI_SYSTEM, "gcp.vertex.agent"); - llmRequest.model().ifPresent(modelName -> span.setAttribute(GEN_AI_REQUEST_MODEL, modelName)); - - setInvocationAttributes(span, invocationContext, eventId); - - setJsonAttribute(span, ADK_LLM_REQUEST, buildLlmRequestForTrace(llmRequest)); - setJsonAttribute(span, ADK_LLM_RESPONSE, llmResponse); - - llmRequest - .config() - .flatMap(config -> config.topP()) - .ifPresent(topP -> span.setAttribute(GEN_AI_REQUEST_TOP_P, topP.doubleValue())); - llmRequest - .config() - .flatMap(config -> config.maxOutputTokens()) + getValidCurrentSpan("traceCallLlm") .ifPresent( - maxTokens -> span.setAttribute(GEN_AI_REQUEST_MAX_TOKENS, maxTokens.longValue())); + span -> { + span.setAttribute(GEN_AI_SYSTEM, "gcp.vertex.agent"); + llmRequest + .model() + .ifPresent(modelName -> span.setAttribute(GEN_AI_REQUEST_MODEL, modelName)); - llmResponse - .usageMetadata() - .ifPresent( - usage -> { - usage - .promptTokenCount() - .ifPresent(tokens -> span.setAttribute(GEN_AI_USAGE_INPUT_TOKENS, (long) tokens)); - usage - .candidatesTokenCount() + setInvocationAttributes(span, invocationContext, eventId); + + setJsonAttribute(span, ADK_LLM_REQUEST, buildLlmRequestForTrace(llmRequest)); + setJsonAttribute(span, ADK_LLM_RESPONSE, llmResponse); + + llmRequest + .config() + .ifPresent( + config -> { + config + .topP() + .ifPresent( + topP -> + span.setAttribute(GEN_AI_REQUEST_TOP_P, topP.doubleValue())); + config + .maxOutputTokens() + .ifPresent( + maxTokens -> + span.setAttribute( + GEN_AI_REQUEST_MAX_TOKENS, maxTokens.longValue())); + }); + llmResponse + .usageMetadata() + .ifPresent( + usage -> { + usage + .promptTokenCount() + .ifPresent( + tokens -> + span.setAttribute(GEN_AI_USAGE_INPUT_TOKENS, (long) tokens)); + usage + .candidatesTokenCount() + .ifPresent( + tokens -> + span.setAttribute(GEN_AI_USAGE_OUTPUT_TOKENS, (long) tokens)); + }); + llmResponse + .finishReason() + .map(reason -> reason.knownEnum().name().toLowerCase(Locale.ROOT)) .ifPresent( - tokens -> span.setAttribute(GEN_AI_USAGE_OUTPUT_TOKENS, (long) tokens)); + reason -> + span.setAttribute( + GEN_AI_RESPONSE_FINISH_REASONS, ImmutableList.of(reason))); }); - - llmResponse - .finishReason() - .map(reason -> reason.knownEnum().name().toLowerCase(Locale.ROOT)) - .ifPresent( - reason -> span.setAttribute(GEN_AI_RESPONSE_FINISH_REASONS, ImmutableList.of(reason))); } /** @@ -374,18 +359,17 @@ public static void traceCallLlm( */ public static void traceSendData( InvocationContext invocationContext, String eventId, List data) { - traceWithSpan( - "traceSendData", - span -> { - span.setAttribute(GEN_AI_OPERATION_NAME, "send_data"); - setInvocationAttributes(span, invocationContext, eventId); - - ImmutableList safeData = - Optional.ofNullable(data).orElse(ImmutableList.of()).stream() - .filter(Objects::nonNull) - .collect(toImmutableList()); - setJsonAttribute(span, ADK_DATA, safeData); - }); + getValidCurrentSpan("traceSendData") + .ifPresent( + span -> { + setInvocationAttributes(span, invocationContext, eventId); + + ImmutableList safeData = + Optional.ofNullable(data).orElse(ImmutableList.of()).stream() + .filter(Objects::nonNull) + .collect(toImmutableList()); + setJsonAttribute(span, ADK_DATA, safeData); + }); } /** @@ -421,17 +405,14 @@ public static Tracer getTracer() { @SuppressWarnings("MustBeClosedChecker") // Scope lifecycle managed by RxJava doFinally public static Flowable traceFlowable( Context spanContext, Span span, Supplier> flowableSupplier) { - return Flowable.defer( - () -> { - Scope scope = spanContext.makeCurrent(); - return flowableSupplier - .get() - .doFinally( - () -> { - scope.close(); - span.end(); - }); - }); + Scope scope = spanContext.makeCurrent(); + return flowableSupplier + .get() + .doFinally( + () -> { + scope.close(); + span.end(); + }); } /** @@ -469,66 +450,15 @@ public static TracerProvider trace(String spanName, Context parentContext * @return A TracerProvider configured for agent invocation. */ public static TracerProvider traceAgent( - Context parent, String spanName, String agentName, String agentDescription, InvocationContext invocationContext) { return new TracerProvider(spanName) - .setParent(parent) .configure( span -> traceAgentInvocation(span, agentName, agentDescription, invocationContext)); } - /** - * Returns a transformer that re-activates a given context for the duration of the stream's - * subscription. - * - * @param context The context to re-activate. - * @param The type of the stream. - * @return A transformer that re-activates the context. - */ - public static ContextTransformer withContext(Context context) { - return new ContextTransformer<>(context); - } - - /** - * A transformer that re-activates a given context for the duration of the stream's subscription. - * - * @param The type of the stream. - */ - public static final class ContextTransformer - implements FlowableTransformer, - SingleTransformer, - MaybeTransformer, - CompletableTransformer { - private final Context context; - - private ContextTransformer(Context context) { - this.context = context; - } - - @Override - public Publisher apply(Flowable upstream) { - return upstream.lift(subscriber -> TracingObserver.wrap(context, subscriber)); - } - - @Override - public SingleSource apply(Single upstream) { - return upstream.lift(observer -> TracingObserver.wrap(context, observer)); - } - - @Override - public MaybeSource apply(Maybe upstream) { - return upstream.lift(observer -> TracingObserver.wrap(context, observer)); - } - - @Override - public CompletableSource apply(Completable upstream) { - return upstream.lift(observer -> TracingObserver.wrap(context, observer)); - } - } - /** * A transformer that manages an OpenTelemetry span and scope for RxJava streams. * @@ -542,7 +472,6 @@ public static final class TracerProvider private final String spanName; private Context explicitParentContext; private final List> spanConfigurers = new ArrayList<>(); - private BiConsumer onSuccessConsumer; private TracerProvider(String spanName) { this.spanName = spanName; @@ -562,38 +491,27 @@ public TracerProvider setParent(Context parentContext) { return this; } - /** - * Registers a callback to be executed with the span and the result item when the stream emits a - * success value. - */ - @CanIgnoreReturnValue - public TracerProvider onSuccess(BiConsumer consumer) { - this.onSuccessConsumer = consumer; - return this; - } - private Context getParentContext() { return explicitParentContext != null ? explicitParentContext : Context.current(); } private final class TracingLifecycle { - private final Span span; - private final Context context; + private Span span; + private Scope scope; - TracingLifecycle() { - Context parentContext = getParentContext(); - span = tracer.spanBuilder(spanName).setParent(parentContext).startSpan(); + @SuppressWarnings("MustBeClosedChecker") + void start() { + span = tracer.spanBuilder(spanName).setParent(getParentContext()).startSpan(); spanConfigurers.forEach(c -> c.accept(span)); - context = parentContext.with(span); + scope = span.makeCurrent(); } void end() { - span.end(); - } - - void run(O observer, Consumer subscribeAction) { - try (Scope scope = context.makeCurrent()) { - subscribeAction.accept(observer); + if (scope != null) { + scope.close(); + } + if (span != null) { + span.end(); } } } @@ -603,18 +521,7 @@ public Publisher apply(Flowable upstream) { return Flowable.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - return Flowable.fromPublisher( - observer -> - lifecycle.run( - observer, - o -> { - Flowable chain = upstream.compose(withContext(lifecycle.context)); - if (onSuccessConsumer != null) { - chain = - chain.doOnNext(t -> onSuccessConsumer.accept(lifecycle.span, t)); - } - chain.doFinally(lifecycle::end).subscribe(o); - })); + return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); }); } @@ -623,18 +530,7 @@ public SingleSource apply(Single upstream) { return Single.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - return Single.wrap( - observer -> - lifecycle.run( - observer, - o -> { - Single chain = upstream.compose(withContext(lifecycle.context)); - if (onSuccessConsumer != null) { - chain = - chain.doOnSuccess(t -> onSuccessConsumer.accept(lifecycle.span, t)); - } - chain.doFinally(lifecycle::end).subscribe(o); - })); + return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); }); } @@ -643,18 +539,7 @@ public MaybeSource apply(Maybe upstream) { return Maybe.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - return Maybe.wrap( - observer -> - lifecycle.run( - observer, - o -> { - Maybe chain = upstream.compose(withContext(lifecycle.context)); - if (onSuccessConsumer != null) { - chain = - chain.doOnSuccess(t -> onSuccessConsumer.accept(lifecycle.span, t)); - } - chain.doFinally(lifecycle::end).subscribe(o); - })); + return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); }); } @@ -663,142 +548,7 @@ public CompletableSource apply(Completable upstream) { return Completable.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - return Completable.wrap( - observer -> - lifecycle.run( - observer, - o -> { - Completable chain = upstream.compose(withContext(lifecycle.context)); - // Completable does not emit items, so onSuccessConsumer is not - // applicable. - chain.doFinally(lifecycle::end).subscribe(o); - })); - }); - } - } - - /** - * An observer that wraps another observer and ensures that the OpenTelemetry context is active - * during all callback methods. - * - * @param The type of the items emitted by the stream. - */ - private static final class TracingObserver - implements Subscriber, SingleObserver, MaybeObserver, CompletableObserver { - private final Context context; - private final Subscriber subscriber; - private final SingleObserver singleObserver; - private final MaybeObserver maybeObserver; - private final CompletableObserver completableObserver; - - private TracingObserver( - Context context, - Subscriber subscriber, - SingleObserver singleObserver, - MaybeObserver maybeObserver, - CompletableObserver completableObserver) { - this.context = context; - this.subscriber = subscriber; - this.singleObserver = singleObserver; - this.maybeObserver = maybeObserver; - this.completableObserver = completableObserver; - } - - static TracingObserver wrap(Context context, Subscriber subscriber) { - return new TracingObserver<>(context, subscriber, null, null, null); - } - - static TracingObserver wrap(Context context, SingleObserver observer) { - return new TracingObserver<>(context, null, observer, null, null); - } - - static TracingObserver wrap(Context context, MaybeObserver observer) { - return new TracingObserver<>(context, null, null, observer, null); - } - - static TracingObserver wrap(Context context, CompletableObserver observer) { - return new TracingObserver<>(context, null, null, null, observer); - } - - private void runInContext(Runnable action) { - try (Scope scope = context.makeCurrent()) { - action.run(); - } - } - - @Override - public void onSubscribe(Subscription s) { - runInContext( - () -> { - if (subscriber != null) { - subscriber.onSubscribe(s); - } - }); - } - - @Override - public void onSubscribe(Disposable d) { - runInContext( - () -> { - if (singleObserver != null) { - singleObserver.onSubscribe(d); - } else if (maybeObserver != null) { - maybeObserver.onSubscribe(d); - } else if (completableObserver != null) { - completableObserver.onSubscribe(d); - } - }); - } - - @Override - public void onNext(T t) { - runInContext( - () -> { - if (subscriber != null) { - subscriber.onNext(t); - } - }); - } - - @Override - public void onSuccess(T t) { - runInContext( - () -> { - if (singleObserver != null) { - singleObserver.onSuccess(t); - } else if (maybeObserver != null) { - maybeObserver.onSuccess(t); - } - }); - } - - @Override - public void onError(Throwable t) { - runInContext( - () -> { - if (subscriber != null) { - subscriber.onError(t); - } else if (singleObserver != null) { - singleObserver.onError(t); - } else if (maybeObserver != null) { - maybeObserver.onError(t); - } else if (completableObserver != null) { - completableObserver.onError(t); - } - }); - } - - @Override - public void onComplete() { - runInContext( - () -> { - if (subscriber != null) { - subscriber.onComplete(); - } else if (maybeObserver != null) { - maybeObserver.onComplete(); - } else if (completableObserver != null) { - completableObserver.onComplete(); - } + return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); }); } } From 305299fd3a009b24d415ae8f3f052e1bf0d477e4 Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Tue, 10 Mar 2026 02:18:44 -0700 Subject: [PATCH 10/50] refactor: remove the Optional param in VertexAiCodeExecutor method PiperOrigin-RevId: 881301261 --- .../adk/codeexecutors/VertexAiCodeExecutor.java | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/com/google/adk/codeexecutors/VertexAiCodeExecutor.java b/core/src/main/java/com/google/adk/codeexecutors/VertexAiCodeExecutor.java index 5268edf39..af2219d18 100644 --- a/core/src/main/java/com/google/adk/codeexecutors/VertexAiCodeExecutor.java +++ b/core/src/main/java/com/google/adk/codeexecutors/VertexAiCodeExecutor.java @@ -36,7 +36,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.Optional; +import org.jspecify.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -140,7 +140,7 @@ public CodeExecutionResult executeCode( executeCodeInterpreter( getCodeWithImports(codeExecutionInput.code()), codeExecutionInput.inputFiles(), - codeExecutionInput.executionId()); + codeExecutionInput.executionId().orElse(null)); // Save output file as artifacts. List savedFiles = new ArrayList<>(); @@ -173,7 +173,7 @@ public CodeExecutionResult executeCode( } private Map executeCodeInterpreter( - String code, List inputFiles, Optional sessionId) { + String code, List inputFiles, @Nullable String sessionId) { ExtensionExecutionServiceClient codeInterpreterExtension = getCodeInterpreterExtension(); if (codeInterpreterExtension == null) { logger.warn("Vertex AI Code Interpreter execution is not available. Returning empty result."); @@ -196,8 +196,9 @@ private Map executeCodeInterpreter( paramsBuilder.putFields( "files", Value.newBuilder().setListValue(listBuilder.build()).build()); } - sessionId.ifPresent( - s -> paramsBuilder.putFields("session_id", Value.newBuilder().setStringValue(s).build())); + if (sessionId != null) { + paramsBuilder.putFields("session_id", Value.newBuilder().setStringValue(sessionId).build()); + } ExecuteExtensionRequest request = ExecuteExtensionRequest.newBuilder() From b71900f08cbab1a89b775356db2294440945a605 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 10 Mar 2026 05:56:51 -0700 Subject: [PATCH 11/50] refactor: Use Maybe instead of Single PiperOrigin-RevId: 881382533 --- .../java/com/google/adk/agents/BaseAgent.java | 61 +++++++------- .../adk/flows/llmflows/BaseLlmFlow.java | 84 +++++++++---------- 2 files changed, 69 insertions(+), 76 deletions(-) diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index d74ba9ca5..00676ec31 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -32,7 +32,6 @@ import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; -import io.reactivex.rxjava3.core.Single; import java.util.ArrayList; import java.util.HashSet; import java.util.List; @@ -316,30 +315,29 @@ private Flowable run( () -> { InvocationContext invocationContext = createInvocationContext(parentContext); + Flowable mainAndAfterEvents = + Flowable.defer(() -> runImplementation.apply(invocationContext)) + .concatWith( + Flowable.defer( + () -> + callCallback( + afterCallbacksToFunctions( + invocationContext.pluginManager(), afterAgentCallback), + invocationContext) + .toFlowable())); + return callCallback( beforeCallbacksToFunctions( invocationContext.pluginManager(), beforeAgentCallback), invocationContext) .flatMapPublisher( - beforeEventOpt -> { + beforeEvent -> { if (invocationContext.endInvocation()) { - return Flowable.fromOptional(beforeEventOpt); + return Flowable.just(beforeEvent); } - - Flowable beforeEvents = Flowable.fromOptional(beforeEventOpt); - Flowable mainEvents = - Flowable.defer(() -> runImplementation.apply(invocationContext)); - Flowable afterEvents = - Flowable.defer( - () -> - callCallback( - afterCallbacksToFunctions( - invocationContext.pluginManager(), afterAgentCallback), - invocationContext) - .flatMapPublisher(Flowable::fromOptional)); - - return Flowable.concat(beforeEvents, mainEvents, afterEvents); + return Flowable.just(beforeEvent).concatWith(mainAndAfterEvents); }) + .switchIfEmpty(mainAndAfterEvents) .compose( Tracing.traceAgent( "invoke_agent " + name(), name(), description(), invocationContext)); @@ -383,13 +381,13 @@ private ImmutableList>> callbacksTo * * @param agentCallbacks Callback functions. * @param invocationContext Current invocation context. - * @return single emitting first event, or empty if none. + * @return maybe emitting first event, or empty if none. */ - private Single> callCallback( + private Maybe callCallback( List>> agentCallbacks, InvocationContext invocationContext) { if (agentCallbacks.isEmpty()) { - return Single.just(Optional.empty()); + return Maybe.empty(); } CallbackContext callbackContext = @@ -404,21 +402,20 @@ private Single> callCallback( .map( content -> { invocationContext.setEndInvocation(true); - return Optional.of( - Event.builder() - .id(Event.generateEventId()) - .invocationId(invocationContext.invocationId()) - .author(name()) - .branch(invocationContext.branch().orElse(null)) - .actions(callbackContext.eventActions()) - .content(content) - .build()); + return Event.builder() + .id(Event.generateEventId()) + .invocationId(invocationContext.invocationId()) + .author(name()) + .branch(invocationContext.branch().orElse(null)) + .actions(callbackContext.eventActions()) + .content(content) + .build(); }) .toFlowable(); }) .firstElement() .switchIfEmpty( - Single.defer( + Maybe.defer( () -> { if (callbackContext.state().hasDelta()) { Event.Builder eventBuilder = @@ -429,9 +426,9 @@ private Single> callCallback( .branch(invocationContext.branch().orElse(null)) .actions(callbackContext.eventActions()); - return Single.just(Optional.of(eventBuilder.build())); + return Maybe.just(eventBuilder.build()); } else { - return Single.just(Optional.empty()); + return Maybe.empty(); } })); } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index 6ed9ccaa3..e1afca2b1 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -170,52 +170,51 @@ private Flowable callLlm( LlmRequest.Builder llmRequestBuilder = llmRequest.toBuilder(); return handleBeforeModelCallback(context, llmRequestBuilder, eventForCallbackUsage) - .flatMapPublisher( - beforeResponse -> { - if (beforeResponse.isPresent()) { - return Flowable.just(beforeResponse.get()); - } - BaseLlm llm = - agent.resolvedModel().model().isPresent() - ? agent.resolvedModel().model().get() - : LlmRegistry.getLlm(agent.resolvedModel().modelName().get()); - return llm.generateContent( - llmRequestBuilder.build(), - context.runConfig().streamingMode() == StreamingMode.SSE) - .onErrorResumeNext( - exception -> - handleOnModelErrorCallback( - context, llmRequestBuilder, eventForCallbackUsage, exception) - .switchIfEmpty(Single.error(exception)) - .toFlowable()) - .doOnNext( - llmResp -> - Tracing.traceCallLlm( - context, - eventForCallbackUsage.id(), - llmRequestBuilder.build(), - llmResp)) - .doOnError( - error -> { - Span span = Span.current(); - span.setStatus(StatusCode.ERROR, error.getMessage()); - span.recordException(error); - }) - .compose(Tracing.trace("call_llm")) - .concatMap( - llmResp -> - handleAfterModelCallback(context, llmResp, eventForCallbackUsage) - .toFlowable()); - }); + .toFlowable() + .switchIfEmpty( + Flowable.defer( + () -> { + BaseLlm llm = + agent.resolvedModel().model().isPresent() + ? agent.resolvedModel().model().get() + : LlmRegistry.getLlm(agent.resolvedModel().modelName().get()); + return llm.generateContent( + llmRequestBuilder.build(), + context.runConfig().streamingMode() == StreamingMode.SSE) + .onErrorResumeNext( + exception -> + handleOnModelErrorCallback( + context, llmRequestBuilder, eventForCallbackUsage, exception) + .switchIfEmpty(Single.error(exception)) + .toFlowable()) + .doOnNext( + llmResp -> + Tracing.traceCallLlm( + context, + eventForCallbackUsage.id(), + llmRequestBuilder.build(), + llmResp)) + .doOnError( + error -> { + Span span = Span.current(); + span.setStatus(StatusCode.ERROR, error.getMessage()); + span.recordException(error); + }) + .compose(Tracing.trace("call_llm")) + .concatMap( + llmResp -> + handleAfterModelCallback(context, llmResp, eventForCallbackUsage) + .toFlowable()); + })); } /** * Invokes {@link BeforeModelCallback}s. If any returns a response, it's used instead of calling * the LLM. * - * @return A {@link Single} with the callback result or {@link Optional#empty()}. + * @return A {@link Maybe} with the callback result. */ - private Single> handleBeforeModelCallback( + private Maybe handleBeforeModelCallback( InvocationContext context, LlmRequest.Builder llmRequestBuilder, Event modelResponseEvent) { Event callbackEvent = modelResponseEvent.toBuilder().build(); CallbackContext callbackContext = @@ -228,7 +227,7 @@ private Single> handleBeforeModelCallback( List callbacks = agent.canonicalBeforeModelCallbacks(); if (callbacks.isEmpty()) { - return pluginResult.map(Optional::of).defaultIfEmpty(Optional.empty()); + return pluginResult; } Maybe callbackResult = @@ -238,10 +237,7 @@ private Single> handleBeforeModelCallback( .concatMapMaybe(callback -> callback.call(callbackContext, llmRequestBuilder)) .firstElement()); - return pluginResult - .switchIfEmpty(callbackResult) - .map(Optional::of) - .defaultIfEmpty(Optional.empty()); + return pluginResult.switchIfEmpty(callbackResult); } /** From d1d5539ef763b6bfd5057c6ea0f2591225a98535 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 10 Mar 2026 06:15:15 -0700 Subject: [PATCH 12/50] feat: update return type for artifactDelta getter and setter to Map from ConcurrentMap PiperOrigin-RevId: 881389219 --- .../main/java/com/google/adk/events/EventActions.java | 10 +++++----- .../java/com/google/adk/events/EventActionsTest.java | 10 ++++++++++ 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/com/google/adk/events/EventActions.java b/core/src/main/java/com/google/adk/events/EventActions.java index 31b096930..4873a5f49 100644 --- a/core/src/main/java/com/google/adk/events/EventActions.java +++ b/core/src/main/java/com/google/adk/events/EventActions.java @@ -110,12 +110,12 @@ public void removeStateByKey(String key) { } @JsonProperty("artifactDelta") - public ConcurrentMap artifactDelta() { + public Map artifactDelta() { return artifactDelta; } - public void setArtifactDelta(ConcurrentMap artifactDelta) { - this.artifactDelta = artifactDelta; + public void setArtifactDelta(Map artifactDelta) { + this.artifactDelta = new ConcurrentHashMap<>(artifactDelta); } @JsonProperty("deletedArtifactIds") @@ -322,8 +322,8 @@ public Builder stateDelta(ConcurrentMap value) { @CanIgnoreReturnValue @JsonProperty("artifactDelta") - public Builder artifactDelta(ConcurrentMap value) { - this.artifactDelta = value; + public Builder artifactDelta(Map value) { + this.artifactDelta = new ConcurrentHashMap<>(value); return this; } diff --git a/core/src/test/java/com/google/adk/events/EventActionsTest.java b/core/src/test/java/com/google/adk/events/EventActionsTest.java index 2975ca83f..22bb94e64 100644 --- a/core/src/test/java/com/google/adk/events/EventActionsTest.java +++ b/core/src/test/java/com/google/adk/events/EventActionsTest.java @@ -110,6 +110,16 @@ public void merge_mergesAllFields() { assertThat(merged.compaction()).hasValue(COMPACTION); } + @Test + public void setArtifactDelta_copiesRegularMap() { + EventActions eventActions = new EventActions(); + ImmutableMap artifactDelta = ImmutableMap.of("artifact1", 1); + + eventActions.setArtifactDelta(artifactDelta); + + assertThat(eventActions.artifactDelta()).containsExactly("artifact1", 1); + } + @Test public void removeStateByKey_marksKeyAsRemoved() { EventActions eventActions = new EventActions(); From b8316b1944ce17cc9208963cc09d900c379444c6 Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Tue, 10 Mar 2026 07:12:02 -0700 Subject: [PATCH 13/50] feat!: Remove Optional parameters in EventActions PiperOrigin-RevId: 881410767 --- .../com/google/adk/events/EventActions.java | 84 +++++++------------ 1 file changed, 28 insertions(+), 56 deletions(-) diff --git a/core/src/main/java/com/google/adk/events/EventActions.java b/core/src/main/java/com/google/adk/events/EventActions.java index 4873a5f49..0b167de93 100644 --- a/core/src/main/java/com/google/adk/events/EventActions.java +++ b/core/src/main/java/com/google/adk/events/EventActions.java @@ -28,36 +28,32 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; -import javax.annotation.Nullable; +import org.jspecify.annotations.Nullable; /** Represents the actions attached to an event. */ // TODO - b/414081262 make json wire camelCase @JsonDeserialize(builder = EventActions.Builder.class) public class EventActions extends JsonBaseModel { - private Optional skipSummarization; + private @Nullable Boolean skipSummarization; private ConcurrentMap stateDelta; private ConcurrentMap artifactDelta; private Set deletedArtifactIds; - private Optional transferToAgent; - private Optional escalate; + private @Nullable String transferToAgent; + private @Nullable Boolean escalate; private ConcurrentMap> requestedAuthConfigs; private ConcurrentMap requestedToolConfirmations; private boolean endOfAgent; - private Optional compaction; + private @Nullable EventCompaction compaction; /** Default constructor for Jackson. */ public EventActions() { - this.skipSummarization = Optional.empty(); this.stateDelta = new ConcurrentHashMap<>(); this.artifactDelta = new ConcurrentHashMap<>(); this.deletedArtifactIds = new HashSet<>(); - this.transferToAgent = Optional.empty(); - this.escalate = Optional.empty(); this.requestedAuthConfigs = new ConcurrentHashMap<>(); this.requestedToolConfirmations = new ConcurrentHashMap<>(); this.endOfAgent = false; - this.compaction = Optional.empty(); } private EventActions(Builder builder) { @@ -75,19 +71,15 @@ private EventActions(Builder builder) { @JsonProperty("skipSummarization") public Optional skipSummarization() { - return skipSummarization; + return Optional.ofNullable(skipSummarization); } public void setSkipSummarization(@Nullable Boolean skipSummarization) { - this.skipSummarization = Optional.ofNullable(skipSummarization); - } - - public void setSkipSummarization(Optional skipSummarization) { this.skipSummarization = skipSummarization; } public void setSkipSummarization(boolean skipSummarization) { - this.skipSummarization = Optional.of(skipSummarization); + this.skipSummarization = skipSummarization; } @JsonProperty("stateDelta") @@ -130,30 +122,22 @@ public void setDeletedArtifactIds(Set deletedArtifactIds) { @JsonProperty("transferToAgent") public Optional transferToAgent() { - return transferToAgent; + return Optional.ofNullable(transferToAgent); } - public void setTransferToAgent(Optional transferToAgent) { + public void setTransferToAgent(@Nullable String transferToAgent) { this.transferToAgent = transferToAgent; } - public void setTransferToAgent(String transferToAgent) { - this.transferToAgent = Optional.ofNullable(transferToAgent); - } - @JsonProperty("escalate") public Optional escalate() { - return escalate; + return Optional.ofNullable(escalate); } - public void setEscalate(Optional escalate) { + public void setEscalate(@Nullable Boolean escalate) { this.escalate = escalate; } - public void setEscalate(boolean escalate) { - this.escalate = Optional.of(escalate); - } - @JsonProperty("requestedAuthConfigs") public ConcurrentMap> requestedAuthConfigs() { return requestedAuthConfigs; @@ -199,14 +183,6 @@ public Optional endInvocation() { return endOfAgent ? Optional.of(true) : Optional.empty(); } - /** - * @deprecated Use {@link #setEndOfAgent(boolean)} instead. - */ - @Deprecated - public void setEndInvocation(Optional endInvocation) { - this.endOfAgent = endInvocation.orElse(false); - } - /** * @deprecated Use {@link #setEndOfAgent(boolean)} instead. */ @@ -217,10 +193,10 @@ public void setEndInvocation(boolean endInvocation) { @JsonProperty("compaction") public Optional compaction() { - return compaction; + return Optional.ofNullable(compaction); } - public void setCompaction(Optional compaction) { + public void setCompaction(@Nullable EventCompaction compaction) { this.compaction = compaction; } @@ -269,47 +245,43 @@ public int hashCode() { /** Builder for {@link EventActions}. */ public static class Builder { - private Optional skipSummarization; + private @Nullable Boolean skipSummarization; private ConcurrentMap stateDelta; private ConcurrentMap artifactDelta; private Set deletedArtifactIds; - private Optional transferToAgent; - private Optional escalate; + private @Nullable String transferToAgent; + private @Nullable Boolean escalate; private ConcurrentMap> requestedAuthConfigs; private ConcurrentMap requestedToolConfirmations; private boolean endOfAgent = false; - private Optional compaction; + private @Nullable EventCompaction compaction; public Builder() { - this.skipSummarization = Optional.empty(); this.stateDelta = new ConcurrentHashMap<>(); this.artifactDelta = new ConcurrentHashMap<>(); this.deletedArtifactIds = new HashSet<>(); - this.transferToAgent = Optional.empty(); - this.escalate = Optional.empty(); this.requestedAuthConfigs = new ConcurrentHashMap<>(); this.requestedToolConfirmations = new ConcurrentHashMap<>(); - this.compaction = Optional.empty(); } private Builder(EventActions eventActions) { - this.skipSummarization = eventActions.skipSummarization(); + this.skipSummarization = eventActions.skipSummarization; this.stateDelta = new ConcurrentHashMap<>(eventActions.stateDelta()); this.artifactDelta = new ConcurrentHashMap<>(eventActions.artifactDelta()); this.deletedArtifactIds = new HashSet<>(eventActions.deletedArtifactIds()); - this.transferToAgent = eventActions.transferToAgent(); - this.escalate = eventActions.escalate(); + this.transferToAgent = eventActions.transferToAgent; + this.escalate = eventActions.escalate; this.requestedAuthConfigs = new ConcurrentHashMap<>(eventActions.requestedAuthConfigs()); this.requestedToolConfirmations = new ConcurrentHashMap<>(eventActions.requestedToolConfirmations()); - this.endOfAgent = eventActions.endOfAgent(); - this.compaction = eventActions.compaction(); + this.endOfAgent = eventActions.endOfAgent; + this.compaction = eventActions.compaction; } @CanIgnoreReturnValue @JsonProperty("skipSummarization") public Builder skipSummarization(boolean skipSummarization) { - this.skipSummarization = Optional.of(skipSummarization); + this.skipSummarization = skipSummarization; return this; } @@ -336,15 +308,15 @@ public Builder deletedArtifactIds(Set value) { @CanIgnoreReturnValue @JsonProperty("transferToAgent") - public Builder transferToAgent(String agentId) { - this.transferToAgent = Optional.ofNullable(agentId); + public Builder transferToAgent(@Nullable String agentId) { + this.transferToAgent = agentId; return this; } @CanIgnoreReturnValue @JsonProperty("escalate") public Builder escalate(boolean escalate) { - this.escalate = Optional.of(escalate); + this.escalate = escalate; return this; } @@ -391,8 +363,8 @@ public Builder endInvocation(boolean endInvocation) { @CanIgnoreReturnValue @JsonProperty("compaction") - public Builder compaction(EventCompaction value) { - this.compaction = Optional.ofNullable(value); + public Builder compaction(@Nullable EventCompaction value) { + this.compaction = value; return this; } From 14ee28ba593a9f6f5f7b9bb6003441539fe33a18 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 10 Mar 2026 08:01:55 -0700 Subject: [PATCH 14/50] fix: Make sure that `InvocationContext.callbackContextData` remains the same instance `InvocationContext.callbackContextData` is used by plugins to keep track of things like invocation start times in `before` and then read in `after` callbacks PiperOrigin-RevId: 881432816 --- .../java/com/google/adk/agents/InvocationContext.java | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/com/google/adk/agents/InvocationContext.java b/core/src/main/java/com/google/adk/agents/InvocationContext.java index 7602ca9f2..7f0e49d0c 100644 --- a/core/src/main/java/com/google/adk/agents/InvocationContext.java +++ b/core/src/main/java/com/google/adk/agents/InvocationContext.java @@ -75,7 +75,10 @@ protected InvocationContext(Builder builder) { this.eventsCompactionConfig = builder.eventsCompactionConfig; this.contextCacheConfig = builder.contextCacheConfig; this.invocationCostManager = builder.invocationCostManager; - this.callbackContextData = new ConcurrentHashMap<>(builder.callbackContextData); + // Don't copy the callback context data. This should be the same instance for the full + // invocation invocation so that Plugins can access the same data it during the invocation + // across all types of callbacks. + this.callbackContextData = builder.callbackContextData; } /** @@ -345,7 +348,10 @@ private Builder(InvocationContext context) { this.eventsCompactionConfig = context.eventsCompactionConfig; this.contextCacheConfig = context.contextCacheConfig; this.invocationCostManager = context.invocationCostManager; - this.callbackContextData = new ConcurrentHashMap<>(context.callbackContextData); + // Don't copy the callback context data. This should be the same instance for the full + // invocation invocation so that Plugins can access the same data it during the invocation + // across all types of callbacks. + this.callbackContextData = context.callbackContextData; } private BaseSessionService sessionService; From d66c31d5e7ef75b994c5ee3efbc3cf89f393776c Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 10 Mar 2026 08:51:34 -0700 Subject: [PATCH 15/50] refactor: Removing unnecessary PluginManager.runX() methods PiperOrigin-RevId: 881456806 --- .../com/google/adk/plugins/PluginManager.java | 52 +------------------ .../java/com/google/adk/runner/Runner.java | 2 +- 2 files changed, 3 insertions(+), 51 deletions(-) diff --git a/core/src/main/java/com/google/adk/plugins/PluginManager.java b/core/src/main/java/com/google/adk/plugins/PluginManager.java index 56dea936a..e534da787 100644 --- a/core/src/main/java/com/google/adk/plugins/PluginManager.java +++ b/core/src/main/java/com/google/adk/plugins/PluginManager.java @@ -34,6 +34,7 @@ import java.util.Map; import java.util.Optional; import java.util.function.Function; +import org.jspecify.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -47,7 +48,7 @@ public class PluginManager extends BasePlugin { private static final Logger logger = LoggerFactory.getLogger(PluginManager.class); private final List plugins = new ArrayList<>(); - public PluginManager(List plugins) { + public PluginManager(@Nullable List plugins) { super("PluginManager"); if (plugins != null) { plugins.forEach(this::registerPlugin); @@ -123,10 +124,6 @@ public Maybe beforeRunCallback(InvocationContext invocationContext) { plugin -> plugin.beforeRunCallback(invocationContext), "beforeRunCallback"); } - public Completable runAfterRunCallback(InvocationContext invocationContext) { - return afterRunCallback(invocationContext); - } - @Override public Completable afterRunCallback(InvocationContext invocationContext) { return Flowable.fromIterable(plugins) @@ -155,41 +152,24 @@ public Completable close() { "[{}] Error during callback 'close'", plugin.getName(), e))); } - public Maybe runOnEventCallback(InvocationContext invocationContext, Event event) { - return onEventCallback(invocationContext, event); - } - @Override public Maybe onEventCallback(InvocationContext invocationContext, Event event) { return runMaybeCallbacks( plugin -> plugin.onEventCallback(invocationContext, event), "onEventCallback"); } - public Maybe runBeforeAgentCallback(BaseAgent agent, CallbackContext callbackContext) { - return beforeAgentCallback(agent, callbackContext); - } - @Override public Maybe beforeAgentCallback(BaseAgent agent, CallbackContext callbackContext) { return runMaybeCallbacks( plugin -> plugin.beforeAgentCallback(agent, callbackContext), "beforeAgentCallback"); } - public Maybe runAfterAgentCallback(BaseAgent agent, CallbackContext callbackContext) { - return afterAgentCallback(agent, callbackContext); - } - @Override public Maybe afterAgentCallback(BaseAgent agent, CallbackContext callbackContext) { return runMaybeCallbacks( plugin -> plugin.afterAgentCallback(agent, callbackContext), "afterAgentCallback"); } - public Maybe runBeforeModelCallback( - CallbackContext callbackContext, LlmRequest.Builder llmRequest) { - return beforeModelCallback(callbackContext, llmRequest); - } - @Override public Maybe beforeModelCallback( CallbackContext callbackContext, LlmRequest.Builder llmRequest) { @@ -197,11 +177,6 @@ public Maybe beforeModelCallback( plugin -> plugin.beforeModelCallback(callbackContext, llmRequest), "beforeModelCallback"); } - public Maybe runAfterModelCallback( - CallbackContext callbackContext, LlmResponse llmResponse) { - return afterModelCallback(callbackContext, llmResponse); - } - @Override public Maybe afterModelCallback( CallbackContext callbackContext, LlmResponse llmResponse) { @@ -209,11 +184,6 @@ public Maybe afterModelCallback( plugin -> plugin.afterModelCallback(callbackContext, llmResponse), "afterModelCallback"); } - public Maybe runOnModelErrorCallback( - CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) { - return onModelErrorCallback(callbackContext, llmRequest, error); - } - @Override public Maybe onModelErrorCallback( CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) { @@ -222,11 +192,6 @@ public Maybe onModelErrorCallback( "onModelErrorCallback"); } - public Maybe> runBeforeToolCallback( - BaseTool tool, Map toolArgs, ToolContext toolContext) { - return beforeToolCallback(tool, toolArgs, toolContext); - } - @Override public Maybe> beforeToolCallback( BaseTool tool, Map toolArgs, ToolContext toolContext) { @@ -234,14 +199,6 @@ public Maybe> beforeToolCallback( plugin -> plugin.beforeToolCallback(tool, toolArgs, toolContext), "beforeToolCallback"); } - public Maybe> runAfterToolCallback( - BaseTool tool, - Map toolArgs, - ToolContext toolContext, - Map result) { - return afterToolCallback(tool, toolArgs, toolContext, result); - } - @Override public Maybe> afterToolCallback( BaseTool tool, @@ -253,11 +210,6 @@ public Maybe> afterToolCallback( "afterToolCallback"); } - public Maybe> runOnToolErrorCallback( - BaseTool tool, Map toolArgs, ToolContext toolContext, Throwable error) { - return onToolErrorCallback(tool, toolArgs, toolContext, error); - } - @Override public Maybe> onToolErrorCallback( BaseTool tool, Map toolArgs, ToolContext toolContext, Throwable error) { diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 4371300fb..29b2b76d3 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -568,7 +568,7 @@ private Flowable runAgentWithFreshSession( .toFlowable() .switchIfEmpty(agentEvents) .concatWith( - Completable.defer(() -> pluginManager.runAfterRunCallback(contextWithUpdatedSession))) + Completable.defer(() -> pluginManager.afterRunCallback(contextWithUpdatedSession))) .concatWith(Completable.defer(() -> compactEvents(updatedSession))); } From 72c98045088ff49d4b45b645b3ad31100e6d1fa8 Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Tue, 10 Mar 2026 08:55:40 -0700 Subject: [PATCH 16/50] refactor: suppress warnings for Optional param in LlmRequest.addInstructions PiperOrigin-RevId: 881458806 --- .../com/google/adk/models/LlmRequest.java | 33 ++++++++++--------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/core/src/main/java/com/google/adk/models/LlmRequest.java b/core/src/main/java/com/google/adk/models/LlmRequest.java index e35969147..1a45c3a95 100644 --- a/core/src/main/java/com/google/adk/models/LlmRequest.java +++ b/core/src/main/java/com/google/adk/models/LlmRequest.java @@ -17,7 +17,6 @@ package com.google.adk.models; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; @@ -172,29 +171,33 @@ public final Builder appendInstructions(List instructions) { return liveConnectConfig(liveCfg.toBuilder().systemInstruction(newLiveSi).build()); } + // In this particular case we can keep the Optional as a type of a + // parameter, since the function is private and used in only one place while + // the Optional type plays nicely with flatMaps in the code (if we had a + // nullable here, we'd wrap it in the Optional anyway) private Content addInstructions( - Optional currentSystemInstruction, List additionalInstructions) { + @SuppressWarnings("checkstyle:IllegalType") Optional currentSystemInstruction, + List additionalInstructions) { checkArgument( - currentSystemInstruction.isEmpty() - || currentSystemInstruction.get().parts().map(parts -> parts.size()).orElse(0) <= 1, + currentSystemInstruction.flatMap(Content::parts).map(parts -> parts.size()).orElse(0) + <= 1, "At most one instruction is supported."); // Either append to the existing instruction, or create a new one. String instructions = String.join("\n\n", additionalInstructions); - Optional part = - currentSystemInstruction - .flatMap(Content::parts) - .flatMap(parts -> parts.stream().findFirst()); - if (part.isEmpty() || part.get().text().isEmpty()) { - part = Optional.of(Part.fromText(instructions)); - } else { - part = Optional.of(Part.fromText(part.get().text().get() + "\n\n" + instructions)); - } - checkState(part.isPresent(), "Failed to create instruction."); + Part part = + Part.fromText( + currentSystemInstruction + .flatMap(Content::parts) + .flatMap(parts -> parts.stream().findFirst()) + .flatMap(Part::text) + .map(text -> text + "\n\n" + instructions) + .orElse(instructions)); String role = currentSystemInstruction.flatMap(Content::role).orElse("user"); - return Content.builder().parts(part.get()).role(role).build(); + + return Content.builder().parts(part).role(role).build(); } @CanIgnoreReturnValue From aa0e06c535eb65c9008564f177e2590b6d29d30e Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Tue, 10 Mar 2026 11:25:07 -0700 Subject: [PATCH 17/50] refactor: update ApiClient.createHttpClient Optional timeout param to @Nullable PiperOrigin-RevId: 881536165 --- .../main/java/com/google/adk/sessions/ApiClient.java | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/com/google/adk/sessions/ApiClient.java b/core/src/main/java/com/google/adk/sessions/ApiClient.java index 6bf69ee47..e850199e9 100644 --- a/core/src/main/java/com/google/adk/sessions/ApiClient.java +++ b/core/src/main/java/com/google/adk/sessions/ApiClient.java @@ -67,7 +67,7 @@ abstract class ApiClient { applyHttpOptions(customHttpOptions.get()); } - this.httpClient = createHttpClient(httpOptions.timeout()); + this.httpClient = createHttpClient(httpOptions.timeout().orElse(null)); } ApiClient( @@ -113,13 +113,13 @@ abstract class ApiClient { } this.apiKey = Optional.empty(); this.vertexAI = true; - this.httpClient = createHttpClient(httpOptions.timeout()); + this.httpClient = createHttpClient(httpOptions.timeout().orElse(null)); } - private OkHttpClient createHttpClient(Optional timeout) { + private OkHttpClient createHttpClient(@Nullable Integer timeout) { OkHttpClient.Builder builder = new OkHttpClient().newBuilder(); - if (timeout.isPresent()) { - builder.connectTimeout(Duration.ofMillis(timeout.get())); + if (timeout != null) { + builder.connectTimeout(Duration.ofMillis(timeout)); } return builder.build(); } From 444e0f0b4d02481bdb82213f8a87828188fbb89c Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Tue, 10 Mar 2026 11:26:24 -0700 Subject: [PATCH 18/50] refactor: delete SessionJsonConverter.putIfEmpty method with Optional collection param PiperOrigin-RevId: 881536906 --- .../java/com/google/adk/sessions/SessionJsonConverter.java | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java b/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java index 97cc0f56d..0c2b33704 100644 --- a/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java +++ b/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java @@ -73,7 +73,7 @@ static String convertEventToJson(Event event, boolean useIsoString) { event.turnComplete().ifPresent(v -> metadataJson.put("turnComplete", v)); event.interrupted().ifPresent(v -> metadataJson.put("interrupted", v)); event.branch().ifPresent(v -> metadataJson.put("branch", v)); - putIfNotEmpty(metadataJson, "longRunningToolIds", event.longRunningToolIds()); + event.longRunningToolIds().ifPresent(v -> putIfNotEmpty(metadataJson, "longRunningToolIds", v)); event.groundingMetadata().ifPresent(v -> metadataJson.put("groundingMetadata", v)); event.usageMetadata().ifPresent(v -> metadataJson.put("usageMetadata", v)); Map eventJson = new HashMap<>(); @@ -355,11 +355,6 @@ private static void putIfNotEmpty(Map map, String key, Map } } - private static void putIfNotEmpty( - Map map, String key, Optional> values) { - values.ifPresent(v -> putIfNotEmpty(map, key, v)); - } - private static void putIfNotEmpty( Map map, String key, @Nullable Collection values) { if (values != null && !values.isEmpty()) { From b857f010a0f51df0eb25ecdc364465ffdd9fef65 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 10 Mar 2026 13:55:10 -0700 Subject: [PATCH 19/50] fix: Removing deprecated methods in Runner PiperOrigin-RevId: 881608694 --- .../adk/models/langchain4j/RunLoop.java | 3 +- .../springai/SpringAIIntegrationTest.java | 21 ++++++---- .../google/adk/models/springai/TestUtils.java | 17 +++++--- .../java/com/google/adk/runner/Runner.java | 41 ------------------- 4 files changed, 26 insertions(+), 56 deletions(-) diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/RunLoop.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/RunLoop.java index 04a2aa585..ede7300fe 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/RunLoop.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/RunLoop.java @@ -53,7 +53,8 @@ public static List runLoop(BaseAgent agent, boolean streaming, Object... allEvents.addAll( runner .runAsync( - session, + session.userId(), + session.id(), messageContent, RunConfig.builder() .setStreamingMode( diff --git a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/SpringAIIntegrationTest.java b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/SpringAIIntegrationTest.java index 6843c8eaa..11b17ebf1 100644 --- a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/SpringAIIntegrationTest.java +++ b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/SpringAIIntegrationTest.java @@ -18,6 +18,7 @@ import static org.junit.jupiter.api.Assertions.*; import com.google.adk.agents.LlmAgent; +import com.google.adk.agents.RunConfig; import com.google.adk.events.Event; import com.google.adk.models.springai.integrations.tools.WeatherTool; import com.google.adk.runner.InMemoryRunner; @@ -73,14 +74,15 @@ public ChatResponse call(Prompt prompt) { // when Runner runner = new InMemoryRunner(agent); - Session session = runner.sessionService().createSession("test-app", "test-user").blockingGet(); + Session session = + runner.sessionService().createSession(agent.name(), "test-user").blockingGet(); Content userMessage = Content.builder().role("user").parts(List.of(Part.fromText("What is a qubit?"))).build(); List events = runner - .runAsync(session, userMessage, com.google.adk.agents.RunConfig.builder().build()) + .runAsync(session.userId(), session.id(), userMessage, RunConfig.builder().build()) .toList() .blockingGet(); @@ -149,7 +151,8 @@ public ChatResponse call(Prompt prompt) { // when Runner runner = new InMemoryRunner(agent); - Session session = runner.sessionService().createSession("test-app", "test-user").blockingGet(); + Session session = + runner.sessionService().createSession(agent.name(), "test-user").blockingGet(); Content userMessage = Content.builder() @@ -159,7 +162,7 @@ public ChatResponse call(Prompt prompt) { List events = runner - .runAsync(session, userMessage, com.google.adk.agents.RunConfig.builder().build()) + .runAsync(session.userId(), session.id(), userMessage, RunConfig.builder().build()) .toList() .blockingGet(); @@ -217,7 +220,8 @@ public Flux stream(Prompt prompt) { // when Runner runner = new InMemoryRunner(agent); - Session session = runner.sessionService().createSession("test-app", "test-user").blockingGet(); + Session session = + runner.sessionService().createSession(agent.name(), "test-user").blockingGet(); Content userMessage = Content.builder() @@ -228,11 +232,10 @@ public Flux stream(Prompt prompt) { List events = runner .runAsync( - session, + session.userId(), + session.id(), userMessage, - com.google.adk.agents.RunConfig.builder() - .setStreamingMode(com.google.adk.agents.RunConfig.StreamingMode.SSE) - .build()) + RunConfig.builder().setStreamingMode(RunConfig.StreamingMode.SSE).build()) .toList() .blockingGet(); diff --git a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/TestUtils.java b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/TestUtils.java index f18ded055..891dcd62d 100644 --- a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/TestUtils.java +++ b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/TestUtils.java @@ -46,7 +46,8 @@ public static List askAgent(BaseAgent agent, boolean streaming, Object... allEvents.addAll( runner .runAsync( - session, + session.userId(), + session.id(), messageContent, RunConfig.builder() .setStreamingMode( @@ -67,13 +68,17 @@ public static List askBlockingAgent(BaseAgent agent, Object... messages) } Runner runner = new InMemoryRunner(agent); - Session session = runner.sessionService().createSession("test-app", "test-user").blockingGet(); + Session session = + runner.sessionService().createSession(agent.name(), "test-user").blockingGet(); List events = new ArrayList<>(); for (Content content : contents) { List batchEvents = - runner.runAsync(session, content, RunConfig.builder().build()).toList().blockingGet(); + runner + .runAsync(session.userId(), session.id(), content, RunConfig.builder().build()) + .toList() + .blockingGet(); events.addAll(batchEvents); } @@ -88,7 +93,8 @@ public static List askAgentStreaming(BaseAgent agent, Object... messages) } Runner runner = new InMemoryRunner(agent); - Session session = runner.sessionService().createSession("test-app", "test-user").blockingGet(); + Session session = + runner.sessionService().createSession(agent.name(), "test-user").blockingGet(); List events = new ArrayList<>(); @@ -96,7 +102,8 @@ public static List askAgentStreaming(BaseAgent agent, Object... messages) List batchEvents = runner .runAsync( - session, + session.userId(), + session.id(), content, RunConfig.builder().setStreamingMode(RunConfig.StreamingMode.SSE).build()) .toList() diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 29b2b76d3..0fd8bb92e 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -415,35 +415,6 @@ public Flowable runAsync(String userId, String sessionId, Content newMess return runAsync(userId, sessionId, newMessage, RunConfig.builder().build()); } - /** - * See {@link #runAsync(Session, Content, RunConfig, Map)}. - * - * @deprecated Use runAsync with sessionId. - */ - @Deprecated(since = "0.4.0", forRemoval = true) - public Flowable runAsync(Session session, Content newMessage, RunConfig runConfig) { - return runAsync(session, newMessage, runConfig, /* stateDelta= */ null); - } - - /** - * Runs the agent asynchronously using a provided Session object. - * - * @param session The session to run the agent in. - * @param newMessage The new message from the user to process. - * @param runConfig Configuration for the agent run. - * @param stateDelta Optional map of state updates to merge into the session for this run. - * @return A Flowable stream of {@link Event} objects generated by the agent during execution. - * @deprecated Use runAsync with sessionId. - */ - @Deprecated(since = "0.4.0", forRemoval = true) - public Flowable runAsync( - Session session, - Content newMessage, - RunConfig runConfig, - @Nullable Map stateDelta) { - return runAsyncImpl(session, newMessage, runConfig, stateDelta); - } - /** * Runs the agent asynchronously using a provided Session object. * @@ -710,18 +681,6 @@ public Flowable runLive( return runLive(sessionKey.userId(), sessionKey.id(), liveRequestQueue, runConfig); } - /** - * Runs the agent asynchronously with a default user ID. - * - * @return stream of generated events. - */ - @Deprecated(since = "0.5.0", forRemoval = true) - public Flowable runWithSessionId( - String sessionId, Content newMessage, RunConfig runConfig) { - // TODO(b/410859954): Add user_id to getter or method signature. Assuming "tmp-user" for now. - return this.runAsync("tmp-user", sessionId, newMessage, runConfig); - } - /** * Checks if the agent and its parent chain allow transfer up the tree. * From 0d8e22d6e9fe4e8d29c87d485915ba51a22eb350 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 10 Mar 2026 14:58:35 -0700 Subject: [PATCH 20/50] fix: Removing deprecated methods in Runner PiperOrigin-RevId: 881637295 --- .../adk/models/langchain4j/RunLoop.java | 3 +- .../springai/SpringAIIntegrationTest.java | 21 ++++------ .../google/adk/models/springai/TestUtils.java | 17 +++----- .../java/com/google/adk/runner/Runner.java | 41 +++++++++++++++++++ 4 files changed, 56 insertions(+), 26 deletions(-) diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/RunLoop.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/RunLoop.java index ede7300fe..04a2aa585 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/RunLoop.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/RunLoop.java @@ -53,8 +53,7 @@ public static List runLoop(BaseAgent agent, boolean streaming, Object... allEvents.addAll( runner .runAsync( - session.userId(), - session.id(), + session, messageContent, RunConfig.builder() .setStreamingMode( diff --git a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/SpringAIIntegrationTest.java b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/SpringAIIntegrationTest.java index 11b17ebf1..6843c8eaa 100644 --- a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/SpringAIIntegrationTest.java +++ b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/SpringAIIntegrationTest.java @@ -18,7 +18,6 @@ import static org.junit.jupiter.api.Assertions.*; import com.google.adk.agents.LlmAgent; -import com.google.adk.agents.RunConfig; import com.google.adk.events.Event; import com.google.adk.models.springai.integrations.tools.WeatherTool; import com.google.adk.runner.InMemoryRunner; @@ -74,15 +73,14 @@ public ChatResponse call(Prompt prompt) { // when Runner runner = new InMemoryRunner(agent); - Session session = - runner.sessionService().createSession(agent.name(), "test-user").blockingGet(); + Session session = runner.sessionService().createSession("test-app", "test-user").blockingGet(); Content userMessage = Content.builder().role("user").parts(List.of(Part.fromText("What is a qubit?"))).build(); List events = runner - .runAsync(session.userId(), session.id(), userMessage, RunConfig.builder().build()) + .runAsync(session, userMessage, com.google.adk.agents.RunConfig.builder().build()) .toList() .blockingGet(); @@ -151,8 +149,7 @@ public ChatResponse call(Prompt prompt) { // when Runner runner = new InMemoryRunner(agent); - Session session = - runner.sessionService().createSession(agent.name(), "test-user").blockingGet(); + Session session = runner.sessionService().createSession("test-app", "test-user").blockingGet(); Content userMessage = Content.builder() @@ -162,7 +159,7 @@ public ChatResponse call(Prompt prompt) { List events = runner - .runAsync(session.userId(), session.id(), userMessage, RunConfig.builder().build()) + .runAsync(session, userMessage, com.google.adk.agents.RunConfig.builder().build()) .toList() .blockingGet(); @@ -220,8 +217,7 @@ public Flux stream(Prompt prompt) { // when Runner runner = new InMemoryRunner(agent); - Session session = - runner.sessionService().createSession(agent.name(), "test-user").blockingGet(); + Session session = runner.sessionService().createSession("test-app", "test-user").blockingGet(); Content userMessage = Content.builder() @@ -232,10 +228,11 @@ public Flux stream(Prompt prompt) { List events = runner .runAsync( - session.userId(), - session.id(), + session, userMessage, - RunConfig.builder().setStreamingMode(RunConfig.StreamingMode.SSE).build()) + com.google.adk.agents.RunConfig.builder() + .setStreamingMode(com.google.adk.agents.RunConfig.StreamingMode.SSE) + .build()) .toList() .blockingGet(); diff --git a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/TestUtils.java b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/TestUtils.java index 891dcd62d..f18ded055 100644 --- a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/TestUtils.java +++ b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/TestUtils.java @@ -46,8 +46,7 @@ public static List askAgent(BaseAgent agent, boolean streaming, Object... allEvents.addAll( runner .runAsync( - session.userId(), - session.id(), + session, messageContent, RunConfig.builder() .setStreamingMode( @@ -68,17 +67,13 @@ public static List askBlockingAgent(BaseAgent agent, Object... messages) } Runner runner = new InMemoryRunner(agent); - Session session = - runner.sessionService().createSession(agent.name(), "test-user").blockingGet(); + Session session = runner.sessionService().createSession("test-app", "test-user").blockingGet(); List events = new ArrayList<>(); for (Content content : contents) { List batchEvents = - runner - .runAsync(session.userId(), session.id(), content, RunConfig.builder().build()) - .toList() - .blockingGet(); + runner.runAsync(session, content, RunConfig.builder().build()).toList().blockingGet(); events.addAll(batchEvents); } @@ -93,8 +88,7 @@ public static List askAgentStreaming(BaseAgent agent, Object... messages) } Runner runner = new InMemoryRunner(agent); - Session session = - runner.sessionService().createSession(agent.name(), "test-user").blockingGet(); + Session session = runner.sessionService().createSession("test-app", "test-user").blockingGet(); List events = new ArrayList<>(); @@ -102,8 +96,7 @@ public static List askAgentStreaming(BaseAgent agent, Object... messages) List batchEvents = runner .runAsync( - session.userId(), - session.id(), + session, content, RunConfig.builder().setStreamingMode(RunConfig.StreamingMode.SSE).build()) .toList() diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 0fd8bb92e..29b2b76d3 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -415,6 +415,35 @@ public Flowable runAsync(String userId, String sessionId, Content newMess return runAsync(userId, sessionId, newMessage, RunConfig.builder().build()); } + /** + * See {@link #runAsync(Session, Content, RunConfig, Map)}. + * + * @deprecated Use runAsync with sessionId. + */ + @Deprecated(since = "0.4.0", forRemoval = true) + public Flowable runAsync(Session session, Content newMessage, RunConfig runConfig) { + return runAsync(session, newMessage, runConfig, /* stateDelta= */ null); + } + + /** + * Runs the agent asynchronously using a provided Session object. + * + * @param session The session to run the agent in. + * @param newMessage The new message from the user to process. + * @param runConfig Configuration for the agent run. + * @param stateDelta Optional map of state updates to merge into the session for this run. + * @return A Flowable stream of {@link Event} objects generated by the agent during execution. + * @deprecated Use runAsync with sessionId. + */ + @Deprecated(since = "0.4.0", forRemoval = true) + public Flowable runAsync( + Session session, + Content newMessage, + RunConfig runConfig, + @Nullable Map stateDelta) { + return runAsyncImpl(session, newMessage, runConfig, stateDelta); + } + /** * Runs the agent asynchronously using a provided Session object. * @@ -681,6 +710,18 @@ public Flowable runLive( return runLive(sessionKey.userId(), sessionKey.id(), liveRequestQueue, runConfig); } + /** + * Runs the agent asynchronously with a default user ID. + * + * @return stream of generated events. + */ + @Deprecated(since = "0.5.0", forRemoval = true) + public Flowable runWithSessionId( + String sessionId, Content newMessage, RunConfig runConfig) { + // TODO(b/410859954): Add user_id to getter or method signature. Assuming "tmp-user" for now. + return this.runAsync("tmp-user", sessionId, newMessage, runConfig); + } + /** * Checks if the agent and its parent chain allow transfer up the tree. * From 20f863f716f653979551c481d85d4e7fa56a35da Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 10 Mar 2026 17:01:00 -0700 Subject: [PATCH 21/50] fix: Explicitly setting the otel parent spans in agents, llm flow and function calls PiperOrigin-RevId: 881688036 --- .../java/com/google/adk/agents/BaseAgent.java | 10 +++++-- .../adk/flows/llmflows/BaseLlmFlow.java | 27 ++++++++++++------- .../google/adk/flows/llmflows/Functions.java | 10 ++++--- .../com/google/adk/telemetry/Tracing.java | 13 --------- 4 files changed, 32 insertions(+), 28 deletions(-) diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index 00676ec31..ed6631c50 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -29,6 +29,7 @@ import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.errorprone.annotations.DoNotCall; import com.google.genai.types.Content; +import io.opentelemetry.context.Context; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; @@ -311,6 +312,7 @@ public Flowable runAsync(InvocationContext parentContext) { private Flowable run( InvocationContext parentContext, Function> runImplementation) { + Context parentSpanContext = Context.current(); return Flowable.defer( () -> { InvocationContext invocationContext = createInvocationContext(parentContext); @@ -339,8 +341,12 @@ private Flowable run( }) .switchIfEmpty(mainAndAfterEvents) .compose( - Tracing.traceAgent( - "invoke_agent " + name(), name(), description(), invocationContext)); + Tracing.trace("invoke_agent " + name()) + .setParent(parentSpanContext) + .configure( + span -> + Tracing.traceAgentInvocation( + span, name(), description(), invocationContext))); }); } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index e1afca2b1..79066b213 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -164,7 +164,10 @@ protected Flowable postprocess( * callbacks. Callbacks should not rely on its ID if they create their own separate events. */ private Flowable callLlm( - InvocationContext context, LlmRequest llmRequest, Event eventForCallbackUsage) { + Context spanContext, + InvocationContext context, + LlmRequest llmRequest, + Event eventForCallbackUsage) { LlmAgent agent = (LlmAgent) context.agent(); LlmRequest.Builder llmRequestBuilder = llmRequest.toBuilder(); @@ -200,7 +203,7 @@ private Flowable callLlm( span.setStatus(StatusCode.ERROR, error.getMessage()); span.recordException(error); }) - .compose(Tracing.trace("call_llm")) + .compose(Tracing.trace("call_llm").setParent(spanContext)) .concatMap( llmResp -> handleAfterModelCallback(context, llmResp, eventForCallbackUsage) @@ -319,7 +322,7 @@ private Single handleAfterModelCallback( * @throws LlmCallsLimitExceededException if the agent exceeds allowed LLM invocations. * @throws IllegalStateException if a transfer agent is specified but not found. */ - private Flowable runOneStep(InvocationContext context) { + private Flowable runOneStep(Context spanContext, InvocationContext context) { AtomicReference llmRequestRef = new AtomicReference<>(LlmRequest.builder().build()); return Flowable.defer( @@ -351,7 +354,11 @@ private Flowable runOneStep(InvocationContext context) { .build(); mutableEventTemplate.setTimestamp(0L); - return callLlm(context, llmRequestAfterPreprocess, mutableEventTemplate) + return callLlm( + spanContext, + context, + llmRequestAfterPreprocess, + mutableEventTemplate) .concatMap( llmResponse -> { try (Scope postScope = currentContext.makeCurrent()) { @@ -403,11 +410,12 @@ private Flowable runOneStep(InvocationContext context) { */ @Override public Flowable run(InvocationContext invocationContext) { - return run(invocationContext, 0); + return run(Context.current(), invocationContext, 0); } - private Flowable run(InvocationContext invocationContext, int stepsCompleted) { - Flowable currentStepEvents = runOneStep(invocationContext).cache(); + private Flowable run( + Context spanContext, InvocationContext invocationContext, int stepsCompleted) { + Flowable currentStepEvents = runOneStep(spanContext, invocationContext).cache(); if (stepsCompleted + 1 >= maxSteps) { logger.debug("Ending flow execution because max steps reached."); return currentStepEvents; @@ -427,7 +435,7 @@ private Flowable run(InvocationContext invocationContext, int stepsComple return Flowable.empty(); } else { logger.debug("Continuing to next step of the flow."); - return run(invocationContext, stepsCompleted + 1); + return run(spanContext, invocationContext, stepsCompleted + 1); } })); } @@ -444,6 +452,7 @@ private Flowable run(InvocationContext invocationContext, int stepsComple public Flowable runLive(InvocationContext invocationContext) { AtomicReference llmRequestRef = new AtomicReference<>(LlmRequest.builder().build()); Flowable preprocessEvents = preprocess(invocationContext, llmRequestRef); + Context spanContext = Context.current(); return preprocessEvents.concatWith( Flowable.defer( @@ -481,7 +490,7 @@ public Flowable runLive(InvocationContext invocationContext) { eventIdForSendData, llmRequestAfterPreprocess.contents()); }) - .compose(Tracing.trace("send_data")); + .compose(Tracing.trace("send_data").setParent(spanContext)); Flowable liveRequests = invocationContext diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index ecc2bb412..c1a996064 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -178,7 +178,7 @@ public static Maybe handleFunctionCalls( if (events.size() > 1) { return Maybe.just(mergedEvent) .doOnSuccess(event -> Tracing.traceToolResponse(event.id(), event)) - .compose(Tracing.trace("tool_response", parentContext)); + .compose(Tracing.trace("tool_response").setParent(parentContext)); } return Maybe.just(mergedEvent); }); @@ -432,8 +432,8 @@ private static Maybe postProcessFunctionResult( toolContext, invocationContext)) .compose( - Tracing.trace( - "tool_response [" + tool.name() + "]", parentContext)) + Tracing.trace("tool_response [" + tool.name() + "]") + .setParent(parentContext)) .doOnSuccess(event -> Tracing.traceToolResponse(event.id(), event)); }); } @@ -593,7 +593,9 @@ private static Maybe> callTool( Tracing.traceToolCall( tool.name(), tool.description(), tool.getClass().getSimpleName(), args)) .doOnError(t -> Span.current().recordException(t)) - .compose(Tracing.trace("tool_call [" + tool.name() + "]", parentContext)) + .compose( + Tracing.>trace("tool_call [" + tool.name() + "]") + .setParent(parentContext)) .onErrorResumeNext( e -> Maybe.error( diff --git a/core/src/main/java/com/google/adk/telemetry/Tracing.java b/core/src/main/java/com/google/adk/telemetry/Tracing.java index 07a640c37..fc2ca3abf 100644 --- a/core/src/main/java/com/google/adk/telemetry/Tracing.java +++ b/core/src/main/java/com/google/adk/telemetry/Tracing.java @@ -426,19 +426,6 @@ public static TracerProvider trace(String spanName) { return new TracerProvider<>(spanName); } - /** - * Returns a transformer that traces the execution of an RxJava stream with an explicit parent - * context. - * - * @param spanName The name of the span to create. - * @param parentContext The explicit parent context for the span. - * @param The type of the stream. - * @return A TracerProvider that can be used with .compose(). - */ - public static TracerProvider trace(String spanName, Context parentContext) { - return new TracerProvider(spanName).setParent(parentContext); - } - /** * Returns a transformer that traces an agent invocation. * From c6fdb63c92e2f3481a01cfeafa946b6dce728c51 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 11 Mar 2026 02:33:45 -0700 Subject: [PATCH 22/50] feat: update State constructors to accept general Map types PiperOrigin-RevId: 881890323 --- .../java/com/google/adk/sessions/State.java | 24 ++++++--- .../com/google/adk/sessions/StateTest.java | 50 +++++++++++++++++++ 2 files changed, 67 insertions(+), 7 deletions(-) create mode 100644 core/src/test/java/com/google/adk/sessions/StateTest.java diff --git a/core/src/main/java/com/google/adk/sessions/State.java b/core/src/main/java/com/google/adk/sessions/State.java index ec23857d9..70d2dfbf2 100644 --- a/core/src/main/java/com/google/adk/sessions/State.java +++ b/core/src/main/java/com/google/adk/sessions/State.java @@ -24,6 +24,7 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import javax.annotation.Nullable; /** A {@link State} object that also keeps track of the changes to the state. */ @SuppressWarnings("ShouldNotSubclass") @@ -39,13 +40,22 @@ public final class State implements ConcurrentMap { private final ConcurrentMap state; private final ConcurrentMap delta; - public State(ConcurrentMap state) { - this(state, new ConcurrentHashMap<>()); - } - - public State(ConcurrentMap state, ConcurrentMap delta) { - this.state = Objects.requireNonNull(state); - this.delta = delta; + public State(Map state) { + this(state, null); + } + + public State(Map state, @Nullable Map delta) { + Objects.requireNonNull(state, "state is null"); + this.state = + state instanceof ConcurrentMap + ? (ConcurrentMap) state + : new ConcurrentHashMap<>(state); + this.delta = + delta == null + ? new ConcurrentHashMap<>() + : delta instanceof ConcurrentMap + ? (ConcurrentMap) delta + : new ConcurrentHashMap<>(delta); } @Override diff --git a/core/src/test/java/com/google/adk/sessions/StateTest.java b/core/src/test/java/com/google/adk/sessions/StateTest.java new file mode 100644 index 000000000..e1fcaeadc --- /dev/null +++ b/core/src/test/java/com/google/adk/sessions/StateTest.java @@ -0,0 +1,50 @@ +package com.google.adk.sessions; + +import static com.google.common.truth.Truth.assertThat; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class StateTest { + @Test + public void constructor_nullDelta_createsEmptyConcurrentHashMap() { + ConcurrentMap stateMap = new ConcurrentHashMap<>(); + State state = new State(stateMap, null); + assertThat(state.hasDelta()).isFalse(); + state.put("key", "value"); + assertThat(state.hasDelta()).isTrue(); + } + + @Test + public void constructor_nullState_throwsException() { + Assert.assertThrows(NullPointerException.class, () -> new State(null, new HashMap<>())); + } + + @Test + public void constructor_regularMapState() { + Map stateMap = new HashMap<>(); + stateMap.put("initial", "val"); + State state = new State(stateMap, null); + // It should have copied the contents + assertThat(state).containsEntry("initial", "val"); + state.put("key", "value"); + // The original map should NOT be updated because a copy was created + assertThat(stateMap).doesNotContainKey("key"); + } + + @Test + public void constructor_singleArgument() { + ConcurrentMap stateMap = new ConcurrentHashMap<>(); + State state = new State(stateMap); + assertThat(state.hasDelta()).isFalse(); + state.put("key", "value"); + assertThat(state.hasDelta()).isTrue(); + } +} From e0d833b337e958e299d0d11a03f6bfa1468731bc Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Wed, 11 Mar 2026 04:25:00 -0700 Subject: [PATCH 23/50] feat!: update LoopAgent's maxIteration field and methods to be @Nullable instead of Optional PiperOrigin-RevId: 881935126 --- .../java/com/google/adk/agents/LoopAgent.java | 22 +++++++++---------- .../com/google/adk/agents/LoopAgentTest.java | 7 +----- 2 files changed, 11 insertions(+), 18 deletions(-) diff --git a/core/src/main/java/com/google/adk/agents/LoopAgent.java b/core/src/main/java/com/google/adk/agents/LoopAgent.java index d9d049f80..743d569b9 100644 --- a/core/src/main/java/com/google/adk/agents/LoopAgent.java +++ b/core/src/main/java/com/google/adk/agents/LoopAgent.java @@ -21,7 +21,7 @@ import com.google.errorprone.annotations.CanIgnoreReturnValue; import io.reactivex.rxjava3.core.Flowable; import java.util.List; -import java.util.Optional; +import org.jspecify.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -34,7 +34,7 @@ public class LoopAgent extends BaseAgent { private static final Logger logger = LoggerFactory.getLogger(LoopAgent.class); - private final Optional maxIterations; + private final @Nullable Integer maxIterations; /** * Constructor for LoopAgent. @@ -50,7 +50,7 @@ private LoopAgent( String name, String description, List subAgents, - Optional maxIterations, + @Nullable Integer maxIterations, List beforeAgentCallback, List afterAgentCallback) { @@ -60,16 +60,10 @@ private LoopAgent( /** Builder for {@link LoopAgent}. */ public static class Builder extends BaseAgent.Builder { - private Optional maxIterations = Optional.empty(); + private @Nullable Integer maxIterations; @CanIgnoreReturnValue - public Builder maxIterations(int maxIterations) { - this.maxIterations = Optional.of(maxIterations); - return this; - } - - @CanIgnoreReturnValue - public Builder maxIterations(Optional maxIterations) { + public Builder maxIterations(@Nullable Integer maxIterations) { this.maxIterations = maxIterations; return this; } @@ -124,7 +118,7 @@ protected Flowable runAsyncImpl(InvocationContext invocationContext) { return Flowable.fromIterable(subAgents) .concatMap(subAgent -> subAgent.runAsync(invocationContext)) - .repeat(maxIterations.orElse(Integer.MAX_VALUE)) + .repeat(maxIterations != null ? maxIterations : Integer.MAX_VALUE) .takeUntil(LoopAgent::hasEscalateAction); } @@ -137,4 +131,8 @@ protected Flowable runLiveImpl(InvocationContext invocationContext) { private static boolean hasEscalateAction(Event event) { return event.actions().escalate().orElse(false); } + + public @Nullable Integer maxIterations() { + return maxIterations; + } } diff --git a/core/src/test/java/com/google/adk/agents/LoopAgentTest.java b/core/src/test/java/com/google/adk/agents/LoopAgentTest.java index 5c04ac74b..b2d0778c6 100644 --- a/core/src/test/java/com/google/adk/agents/LoopAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/LoopAgentTest.java @@ -33,7 +33,6 @@ import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import java.util.List; -import java.util.Optional; import java.util.concurrent.atomic.AtomicInteger; import org.junit.Test; import org.junit.runner.RunWith; @@ -165,11 +164,7 @@ public void runAsync_withNoMaxIterations_keepsLooping() { Event event2 = createEvent("event2"); TestBaseAgent subAgent = createSubAgent("subAgent", () -> Flowable.just(event1, event2)); LoopAgent loopAgent = - LoopAgent.builder() - .name("loopAgent") - .subAgents(ImmutableList.of(subAgent)) - .maxIterations(Optional.empty()) - .build(); + LoopAgent.builder().name("loopAgent").subAgents(ImmutableList.of(subAgent)).build(); InvocationContext invocationContext = createInvocationContext(loopAgent); Iterable result = loopAgent.runAsync(invocationContext).blockingIterable(); From 1a871141d1c4e659ec90fe4e9f29342bea305255 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 11 Mar 2026 05:34:14 -0700 Subject: [PATCH 24/50] refactor: Simplifying Tracing code Replacing `getValidCurrentSpan.ifPresent` with a simpler `traceWithSpan` method. PiperOrigin-RevId: 881960190 --- .../com/google/adk/telemetry/Tracing.java | 202 +++++++++--------- 1 file changed, 99 insertions(+), 103 deletions(-) diff --git a/core/src/main/java/com/google/adk/telemetry/Tracing.java b/core/src/main/java/com/google/adk/telemetry/Tracing.java index fc2ca3abf..7f338fdcf 100644 --- a/core/src/main/java/com/google/adk/telemetry/Tracing.java +++ b/core/src/main/java/com/google/adk/telemetry/Tracing.java @@ -127,13 +127,13 @@ public class Tracing { private Tracing() {} - private static Optional getValidCurrentSpan(String methodName) { + private static void traceWithSpan(String methodName, Consumer traceAction) { Span span = Span.current(); if (!span.getSpanContext().isValid()) { log.trace("{}: No valid span in current context.", methodName); - return Optional.empty(); + return; } - return Optional.of(span); + traceAction.accept(span); } private static void setInvocationAttributes( @@ -206,16 +206,16 @@ public static void traceAgentInvocation( */ public static void traceToolCall( String toolName, String toolDescription, String toolType, Map args) { - getValidCurrentSpan("traceToolCall") - .ifPresent( - span -> { - setToolExecutionAttributes(span); - span.setAttribute(GEN_AI_TOOL_NAME, toolName); - span.setAttribute(GEN_AI_TOOL_DESCRIPTION, toolDescription); - span.setAttribute(GEN_AI_TOOL_TYPE, toolType); - - setJsonAttribute(span, ADK_TOOL_CALL_ARGS, args); - }); + traceWithSpan( + "traceToolCall", + span -> { + setToolExecutionAttributes(span); + span.setAttribute(GEN_AI_TOOL_NAME, toolName); + span.setAttribute(GEN_AI_TOOL_DESCRIPTION, toolDescription); + span.setAttribute(GEN_AI_TOOL_TYPE, toolType); + + setJsonAttribute(span, ADK_TOOL_CALL_ARGS, args); + }); } /** @@ -225,33 +225,33 @@ public static void traceToolCall( * @param functionResponseEvent The function response event. */ public static void traceToolResponse(String eventId, Event functionResponseEvent) { - getValidCurrentSpan("traceToolResponse") - .ifPresent( - span -> { - setToolExecutionAttributes(span); - span.setAttribute(ADK_EVENT_ID, eventId); - - FunctionResponse functionResponse = - functionResponseEvent.functionResponses().stream().findFirst().orElse(null); - - String toolCallId = ""; - Object toolResponse = ""; - if (functionResponse != null) { - toolCallId = functionResponse.id().orElse(toolCallId); - if (functionResponse.response().isPresent()) { - toolResponse = functionResponse.response().get(); - } - } - - span.setAttribute(GEN_AI_TOOL_CALL_ID, toolCallId); - - Object finalToolResponse = - (toolResponse instanceof Map) - ? toolResponse - : ImmutableMap.of("result", toolResponse); - - setJsonAttribute(span, ADK_TOOL_RESPONSE, finalToolResponse); - }); + traceWithSpan( + "traceToolResponse", + span -> { + setToolExecutionAttributes(span); + span.setAttribute(ADK_EVENT_ID, eventId); + + FunctionResponse functionResponse = + functionResponseEvent.functionResponses().stream().findFirst().orElse(null); + + String toolCallId = ""; + Object toolResponse = ""; + if (functionResponse != null) { + toolCallId = functionResponse.id().orElse(toolCallId); + if (functionResponse.response().isPresent()) { + toolResponse = functionResponse.response().get(); + } + } + + span.setAttribute(GEN_AI_TOOL_CALL_ID, toolCallId); + + Object finalToolResponse = + (toolResponse instanceof Map) + ? toolResponse + : ImmutableMap.of("result", toolResponse); + + setJsonAttribute(span, ADK_TOOL_RESPONSE, finalToolResponse); + }); } /** @@ -296,58 +296,54 @@ public static void traceCallLlm( String eventId, LlmRequest llmRequest, LlmResponse llmResponse) { - getValidCurrentSpan("traceCallLlm") - .ifPresent( - span -> { - span.setAttribute(GEN_AI_SYSTEM, "gcp.vertex.agent"); - llmRequest - .model() - .ifPresent(modelName -> span.setAttribute(GEN_AI_REQUEST_MODEL, modelName)); - - setInvocationAttributes(span, invocationContext, eventId); - - setJsonAttribute(span, ADK_LLM_REQUEST, buildLlmRequestForTrace(llmRequest)); - setJsonAttribute(span, ADK_LLM_RESPONSE, llmResponse); - - llmRequest - .config() - .ifPresent( - config -> { - config - .topP() - .ifPresent( - topP -> - span.setAttribute(GEN_AI_REQUEST_TOP_P, topP.doubleValue())); - config - .maxOutputTokens() - .ifPresent( - maxTokens -> - span.setAttribute( - GEN_AI_REQUEST_MAX_TOKENS, maxTokens.longValue())); - }); - llmResponse - .usageMetadata() - .ifPresent( - usage -> { - usage - .promptTokenCount() - .ifPresent( - tokens -> - span.setAttribute(GEN_AI_USAGE_INPUT_TOKENS, (long) tokens)); - usage - .candidatesTokenCount() - .ifPresent( - tokens -> - span.setAttribute(GEN_AI_USAGE_OUTPUT_TOKENS, (long) tokens)); - }); - llmResponse - .finishReason() - .map(reason -> reason.knownEnum().name().toLowerCase(Locale.ROOT)) - .ifPresent( - reason -> - span.setAttribute( - GEN_AI_RESPONSE_FINISH_REASONS, ImmutableList.of(reason))); - }); + traceWithSpan( + "traceCallLlm", + span -> { + span.setAttribute(GEN_AI_SYSTEM, "gcp.vertex.agent"); + llmRequest + .model() + .ifPresent(modelName -> span.setAttribute(GEN_AI_REQUEST_MODEL, modelName)); + + setInvocationAttributes(span, invocationContext, eventId); + + setJsonAttribute(span, ADK_LLM_REQUEST, buildLlmRequestForTrace(llmRequest)); + setJsonAttribute(span, ADK_LLM_RESPONSE, llmResponse); + + llmRequest + .config() + .ifPresent( + config -> { + config + .topP() + .ifPresent( + topP -> span.setAttribute(GEN_AI_REQUEST_TOP_P, topP.doubleValue())); + config + .maxOutputTokens() + .ifPresent( + maxTokens -> + span.setAttribute( + GEN_AI_REQUEST_MAX_TOKENS, maxTokens.longValue())); + }); + llmResponse + .usageMetadata() + .ifPresent( + usage -> { + usage + .promptTokenCount() + .ifPresent( + tokens -> span.setAttribute(GEN_AI_USAGE_INPUT_TOKENS, (long) tokens)); + usage + .candidatesTokenCount() + .ifPresent( + tokens -> span.setAttribute(GEN_AI_USAGE_OUTPUT_TOKENS, (long) tokens)); + }); + llmResponse + .finishReason() + .map(reason -> reason.knownEnum().name().toLowerCase(Locale.ROOT)) + .ifPresent( + reason -> + span.setAttribute(GEN_AI_RESPONSE_FINISH_REASONS, ImmutableList.of(reason))); + }); } /** @@ -359,17 +355,17 @@ public static void traceCallLlm( */ public static void traceSendData( InvocationContext invocationContext, String eventId, List data) { - getValidCurrentSpan("traceSendData") - .ifPresent( - span -> { - setInvocationAttributes(span, invocationContext, eventId); - - ImmutableList safeData = - Optional.ofNullable(data).orElse(ImmutableList.of()).stream() - .filter(Objects::nonNull) - .collect(toImmutableList()); - setJsonAttribute(span, ADK_DATA, safeData); - }); + traceWithSpan( + "traceSendData", + span -> { + setInvocationAttributes(span, invocationContext, eventId); + + ImmutableList safeData = + Optional.ofNullable(data).orElse(ImmutableList.of()).stream() + .filter(Objects::nonNull) + .collect(toImmutableList()); + setJsonAttribute(span, ADK_DATA, safeData); + }); } /** From bc385589057a6daf0209a335280bf19d20b2126b Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Wed, 11 Mar 2026 07:11:59 -0700 Subject: [PATCH 25/50] feat!: remove deprecated LoadArtifactsTool.loadArtifacts method PiperOrigin-RevId: 881996833 --- .../google/adk/agents/CallbackContext.java | 28 ++++++++++--------- .../google/adk/tools/LoadArtifactsTool.java | 2 +- .../adk/tools/LoadArtifactsToolTest.java | 6 ++-- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/core/src/main/java/com/google/adk/agents/CallbackContext.java b/core/src/main/java/com/google/adk/agents/CallbackContext.java index a29783769..da5b0d794 100644 --- a/core/src/main/java/com/google/adk/agents/CallbackContext.java +++ b/core/src/main/java/com/google/adk/agents/CallbackContext.java @@ -19,12 +19,12 @@ import com.google.adk.artifacts.ListArtifactsResponse; import com.google.adk.events.EventActions; import com.google.adk.sessions.State; +import com.google.common.base.Preconditions; import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; import java.util.List; -import java.util.Optional; /** The context of various callbacks for an agent invocation. */ public class CallbackContext extends ReadonlyContext { @@ -94,22 +94,19 @@ public Single> listArtifacts() { /** Loads the latest version of an artifact from the service. */ public Maybe loadArtifact(String filename) { - return loadArtifact(filename, Optional.empty()); + checkArtifactServiceInitialized(); + return invocationContext + .artifactService() + .loadArtifact( + invocationContext.appName(), + invocationContext.userId(), + invocationContext.session().id(), + filename); } /** Loads a specific version of an artifact from the service. */ public Maybe loadArtifact(String filename, int version) { - return loadArtifact(filename, Optional.of(version)); - } - - /** - * @deprecated Use {@link #loadArtifact(String)} or {@link #loadArtifact(String, int)} instead. - */ - @Deprecated - public Maybe loadArtifact(String filename, Optional version) { - if (invocationContext.artifactService() == null) { - throw new IllegalStateException("Artifact service is not initialized."); - } + checkArtifactServiceInitialized(); return invocationContext .artifactService() .loadArtifact( @@ -120,6 +117,11 @@ public Maybe loadArtifact(String filename, Optional version) { version); } + private void checkArtifactServiceInitialized() { + Preconditions.checkState( + invocationContext.artifactService() != null, "Artifact service is not initialized."); + } + /** * Saves an artifact and records it as a delta for the current session. * diff --git a/core/src/main/java/com/google/adk/tools/LoadArtifactsTool.java b/core/src/main/java/com/google/adk/tools/LoadArtifactsTool.java index c5ae8af37..399079af5 100644 --- a/core/src/main/java/com/google/adk/tools/LoadArtifactsTool.java +++ b/core/src/main/java/com/google/adk/tools/LoadArtifactsTool.java @@ -169,7 +169,7 @@ private Completable loadAndAppendIndividualArtifact( LlmRequest.Builder llmRequestBuilder, ToolContext toolContext, String artifactName) { return toolContext - .loadArtifact(artifactName, Optional.empty()) + .loadArtifact(artifactName) .flatMapCompletable( actualArtifact -> Completable.fromAction( diff --git a/core/src/test/java/com/google/adk/tools/LoadArtifactsToolTest.java b/core/src/test/java/com/google/adk/tools/LoadArtifactsToolTest.java index 89014175d..5ed7a1f40 100644 --- a/core/src/test/java/com/google/adk/tools/LoadArtifactsToolTest.java +++ b/core/src/test/java/com/google/adk/tools/LoadArtifactsToolTest.java @@ -3,7 +3,6 @@ import static com.google.common.truth.Truth.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.nullable; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -163,12 +162,11 @@ public void processLlmRequest_artifactsInContext_withLoadArtifactsFunctionCall_l Part loadedArtifactPart = Part.fromText("This is the content of doc1.txt"); ToolContext spiedToolContext = spy(ToolContext.builder(mockInvocationContext).build()); when(spiedToolContext.listArtifacts()).thenReturn(Single.just(availableArtifacts)); - when(spiedToolContext.loadArtifact(eq("doc1.txt"), eq(Optional.empty()))) - .thenReturn(Maybe.just(loadedArtifactPart)); + when(spiedToolContext.loadArtifact("doc1.txt")).thenReturn(Maybe.just(loadedArtifactPart)); loadArtifactsTool.processLlmRequest(llmRequestBuilder, spiedToolContext).blockingAwait(); - verify(spiedToolContext).loadArtifact(eq("doc1.txt"), eq(Optional.empty())); + verify(spiedToolContext).loadArtifact("doc1.txt"); LlmRequest finalRequest = llmRequestBuilder.build(); List finalContents = finalRequest.contents(); From 0d6dd55f4870007e79db23e21bd261879dbfba79 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 11 Mar 2026 07:52:37 -0700 Subject: [PATCH 26/50] feat: add formatting to the RemoteA2A agent so it filters out the previous agent responses and updates the context of the function calls and responses PiperOrigin-RevId: 882012279 --- .../google/adk/a2a/agent/RemoteA2AAgent.java | 28 ++- .../adk/a2a/converters/EventConverter.java | 191 ++++++++++++--- .../adk/a2a/converters/PartConverter.java | 44 ++++ .../adk/a2a/agent/RemoteA2AAgentTest.java | 7 +- .../a2a/converters/EventConverterTest.java | 227 ++++++++++++------ 5 files changed, 372 insertions(+), 125 deletions(-) diff --git a/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java b/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java index 021786162..ccb662b7c 100644 --- a/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java +++ b/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java @@ -181,6 +181,25 @@ public RemoteA2AAgent build() { } } + private Message.Builder newA2AMessage(Message.Role role, List> parts) { + return new Message.Builder().messageId(UUID.randomUUID().toString()).role(role).parts(parts); + } + + private Message prepareMessage(InvocationContext invocationContext) { + Event userCall = EventConverter.findUserFunctionCall(invocationContext.session().events()); + if (userCall != null) { + ImmutableList> parts = + EventConverter.contentToParts(userCall.content(), userCall.partial().orElse(false)); + return newA2AMessage(Message.Role.USER, parts) + .taskId(EventConverter.taskId(userCall)) + .contextId(EventConverter.contextId(userCall)) + .build(); + } + return newA2AMessage( + Message.Role.USER, EventConverter.messagePartsFromContext(invocationContext)) + .build(); + } + @Override protected Flowable runAsyncImpl(InvocationContext invocationContext) { // Construct A2A Message from the last ADK event @@ -191,14 +210,7 @@ protected Flowable runAsyncImpl(InvocationContext invocationContext) { return Flowable.empty(); } - Optional a2aMessageOpt = EventConverter.convertEventsToA2AMessage(invocationContext); - - if (a2aMessageOpt.isEmpty()) { - logger.warn("Failed to convert event to A2A message."); - return Flowable.empty(); - } - - Message originalMessage = a2aMessageOpt.get(); + Message originalMessage = prepareMessage(invocationContext); String requestJson = serializeMessageToJson(originalMessage); return Flowable.create( diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/EventConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/EventConverter.java index d823e3817..71573070e 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/EventConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/EventConverter.java @@ -18,66 +18,107 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import com.google.adk.agents.InvocationContext; +import com.google.adk.events.Event; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; import com.google.genai.types.Content; -import io.a2a.spec.Message; +import com.google.genai.types.FunctionResponse; import io.a2a.spec.Part; import java.util.Collection; +import java.util.List; import java.util.Optional; import java.util.UUID; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import org.jspecify.annotations.Nullable; /** Converter for ADK Events to A2A Messages. */ public final class EventConverter { - private static final Logger logger = LoggerFactory.getLogger(EventConverter.class); + public static final String ADK_TASK_ID_KEY = "adk_task_id"; + public static final String ADK_CONTEXT_ID_KEY = "adk_context_id"; private EventConverter() {} /** - * Converts an ADK InvocationContext to an A2A Message. + * Returns the task ID from the event. * - *

It combines all the events in the session, plus the user content, converted into A2A Parts, - * into a single A2A Message. + *

Task ID is stored in the event's custom metadata with the key {@link #ADK_TASK_ID_KEY}. * - *

If the context has no events, or no suitable content to build the message, an empty optional - * is returned. - * - * @param context The ADK InvocationContext to convert. - * @return The converted A2A Message. + * @param event The event to get the task ID from. + * @return The task ID, or an empty string if not found. */ - public static Optional convertEventsToA2AMessage(InvocationContext context) { - if (context.session().events().isEmpty()) { - logger.warn("No events in session, cannot convert to A2A message."); - return Optional.empty(); - } - - ImmutableList.Builder> partsBuilder = ImmutableList.builder(); + public static String taskId(Event event) { + return metadataValue(event, ADK_TASK_ID_KEY); + } - context - .session() - .events() - .forEach( - event -> - partsBuilder.addAll( - contentToParts(event.content(), event.partial().orElse(false)))); - partsBuilder.addAll(contentToParts(context.userContent(), false)); + /** + * Returns the context ID from the event. + * + *

Context ID is stored in the event's custom metadata with the key {@link + * #ADK_CONTEXT_ID_KEY}. + * + * @param event The event to get the context ID from. + * @return The context ID, or an empty string if not found. + */ + public static String contextId(Event event) { + return metadataValue(event, ADK_CONTEXT_ID_KEY); + } - ImmutableList> parts = partsBuilder.build(); + /** + * Returns the last user function call event from the list of events. + * + * @param events The list of events to find the user function call event from. + * @return The user function call event, or null if not found. + */ + public static @Nullable Event findUserFunctionCall(List events) { + Event candidate = Iterables.getLast(events); + if (!candidate.author().equals("user")) { + return null; + } + FunctionResponse functionResponse = findUserFunctionResponse(candidate); + if (functionResponse == null || functionResponse.id().isEmpty()) { + return null; + } + for (int i = events.size() - 2; i >= 0; i--) { + Event event = events.get(i); + if (isUserFunctionCall(event, functionResponse.id().get())) { + return event; + } + } + return null; + } - if (parts.isEmpty()) { - logger.warn("No suitable content found to build A2A request message."); - return Optional.empty(); + private static @Nullable FunctionResponse findUserFunctionResponse(Event candidate) { + if (candidate.content().isEmpty() || candidate.content().get().parts().isEmpty()) { + return null; } + return candidate.content().get().parts().get().stream() + .filter(part -> part.functionResponse().isPresent()) + .findFirst() + .map(part -> part.functionResponse().get()) + .orElse(null); + } - return Optional.of( - new Message.Builder() - .messageId(UUID.randomUUID().toString()) - .parts(parts) - .role(Message.Role.USER) - .build()); + private static boolean isUserFunctionCall(Event event, String functionResponseId) { + if (event.content().isEmpty()) { + return false; + } + return event.content().get().parts().get().stream() + .anyMatch( + part -> + part.functionCall().isPresent() + && part.functionCall() + .get() + .id() + .map(id -> id.equals(functionResponseId)) + .orElse(false)); } + /** + * Converts a GenAI Content object to a list of A2A Parts. + * + * @param content The GenAI Content object to convert. + * @param isPartial Whether the content is partial. + * @return A list of A2A Parts. + */ public static ImmutableList> contentToParts( Optional content, boolean isPartial) { return content.flatMap(Content::parts).stream() @@ -85,4 +126,80 @@ public static ImmutableList> contentToParts( .map(part -> PartConverter.fromGenaiPart(part, isPartial)) .collect(toImmutableList()); } + + /** + * Returns the parts from the context events that should be sent to the agent. + * + *

All session events from the previous remote agent response (or the beginning of the session + * in case of the first agent invocation) are included into the A2A message. Events from other + * agents are presented as user messages and rephased as if a user was telling what happened in + * the session up to the point. + * + * @param context The invocation context to get the parts from. + * @return A list of A2A Parts. + */ + public static ImmutableList> messagePartsFromContext(InvocationContext context) { + if (context.session().events().isEmpty()) { + return ImmutableList.of(); + } + List events = context.session().events(); + int lastResponseIndex = -1; + String contextId = ""; + for (int i = events.size() - 1; i >= 0; i--) { + Event event = events.get(i); + if (event.author().equals(context.agent().name())) { + lastResponseIndex = i; + contextId = contextId(event); + break; + } + } + ImmutableList.Builder> partsBuilder = ImmutableList.builder(); + for (int i = lastResponseIndex + 1; i < events.size(); i++) { + Event event = events.get(i); + if (!event.author().equals("user") && !event.author().equals(context.agent().name())) { + event = presentAsUserMessage(event, contextId); + } + contentToParts(event.content(), event.partial().orElse(false)).forEach(partsBuilder::add); + } + return partsBuilder.build(); + } + + private static Event presentAsUserMessage(Event event, String contextId) { + Event.Builder userEvent = + new Event.Builder().id(UUID.randomUUID().toString()).invocationId(contextId).author("user"); + ImmutableList parts = + event.content().flatMap(Content::parts).stream() + .flatMap(Collection::stream) + // convert only non-thought parts to user message parts, skip thought parts as they are + // not meant to be shown to the user + .filter(part -> !part.thought().orElse(false)) + .map(part -> PartConverter.remoteCallAsUserPart(event.author(), part)) + .collect(toImmutableList()); + if (parts.isEmpty()) { + return userEvent.build(); + } + com.google.genai.types.Part forContext = + com.google.genai.types.Part.builder().text("For context:").build(); + return userEvent + .content( + Content.builder() + .parts( + ImmutableList.builder() + .add(forContext) + .addAll(parts) + .build()) + .build()) + .build(); + } + + private static String metadataValue(Event event, String key) { + if (event.customMetadata().isEmpty()) { + return ""; + } + return event.customMetadata().get().stream() + .filter(m -> m.key().map(k -> k.equals(key)).orElse(false)) + .findFirst() + .flatMap(m -> m.stringValue()) + .orElse(""); + } } diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java index 96ef66bc8..36af6cc8b 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java @@ -384,6 +384,50 @@ private static FilePart filePartToA2A(Part part, ImmutableMap.BuilderEvents are rephrased as if a user was telling what happened in the session up to the point. + * E.g. + * + *

{@code
+   * For context:
+   * User said: Now help me with Z
+   * Agent A said: Agent B can help you with it!
+   * Agent B said: Agent C might know better.*
+   * }
+ * + * @param author The author of the part. + * @param part The part to convert. + * @return The converted part. + */ + public static Part remoteCallAsUserPart(String author, Part part) { + if (part.text().isPresent()) { + String partText = String.format("[%s] said: %s", author, part.text().get()); + return Part.builder().text(partText).build(); + } else if (part.functionCall().isPresent()) { + FunctionCall functionCall = part.functionCall().get(); + String partText = + String.format( + "[%s] called tool %s with parameters: %s", + author, + functionCall.name().orElse(""), + functionCall.args().orElse(ImmutableMap.of())); + return Part.builder().text(partText).build(); + } else if (part.functionResponse().isPresent()) { + FunctionResponse functionResponse = part.functionResponse().get(); + String partText = + String.format( + "[%s] %s tool returned result: %s", + author, + functionResponse.name().orElse(""), + functionResponse.response().orElse(ImmutableMap.of())); + return Part.builder().text(partText).build(); + } else { + return part; + } + } + @SuppressWarnings("unchecked") // safe conversion from objectMapper.readValue private static Map coerceToMap(Object value) { if (value == null) { diff --git a/a2a/src/test/java/com/google/adk/a2a/agent/RemoteA2AAgentTest.java b/a2a/src/test/java/com/google/adk/a2a/agent/RemoteA2AAgentTest.java index e75da64ba..b1ffa248a 100644 --- a/a2a/src/test/java/com/google/adk/a2a/agent/RemoteA2AAgentTest.java +++ b/a2a/src/test/java/com/google/adk/a2a/agent/RemoteA2AAgentTest.java @@ -412,10 +412,11 @@ public void runAsync_constructsRequestWithHistory() { .sendMessage(messageCaptor.capture(), any(List.class), any(Consumer.class), any()); Message message = messageCaptor.getValue(); assertThat(message.getRole()).isEqualTo(Message.Role.USER); - assertThat(message.getParts()).hasSize(3); + assertThat(message.getParts()).hasSize(4); assertThat(((TextPart) message.getParts().get(0)).getText()).isEqualTo("hello"); - assertThat(((TextPart) message.getParts().get(1)).getText()).isEqualTo("hi"); - assertThat(((TextPart) message.getParts().get(2)).getText()).isEqualTo("how are you?"); + assertThat(((TextPart) message.getParts().get(1)).getText()).isEqualTo("For context:"); + assertThat(((TextPart) message.getParts().get(2)).getText()).isEqualTo("[model] said: hi"); + assertThat(((TextPart) message.getParts().get(3)).getText()).isEqualTo("how are you?"); } @Test diff --git a/a2a/src/test/java/com/google/adk/a2a/converters/EventConverterTest.java b/a2a/src/test/java/com/google/adk/a2a/converters/EventConverterTest.java index 8d460c457..207019199 100644 --- a/a2a/src/test/java/com/google/adk/a2a/converters/EventConverterTest.java +++ b/a2a/src/test/java/com/google/adk/a2a/converters/EventConverterTest.java @@ -4,23 +4,17 @@ import com.google.adk.agents.BaseAgent; import com.google.adk.agents.InvocationContext; -import com.google.adk.artifacts.InMemoryArtifactService; import com.google.adk.events.Event; -import com.google.adk.plugins.PluginManager; import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; +import com.google.genai.types.CustomMetadata; import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; -import io.a2a.spec.DataPart; -import io.a2a.spec.Message; import io.a2a.spec.TextPart; import io.reactivex.rxjava3.core.Flowable; -import java.util.ArrayList; -import java.util.List; import java.util.Optional; import org.junit.Test; import org.junit.runner.RunWith; @@ -30,101 +24,180 @@ public final class EventConverterTest { @Test - public void convertEventsToA2AMessage_preservesFunctionCallAndResponseParts() { - // Arrange session events: user text, function call, function response. - Part userTextPart = Part.builder().text("Roll a die").build(); - Event userEvent = + public void testTaskId() { + Event e = + Event.builder() + .customMetadata( + ImmutableList.of( + CustomMetadata.builder() + .key(EventConverter.ADK_TASK_ID_KEY) + .stringValue("task-123") + .build())) + .build(); + assertThat(EventConverter.taskId(e)).isEqualTo("task-123"); + } + + @Test + public void testTaskId_empty() { + Event e = Event.builder().build(); + assertThat(EventConverter.taskId(e)).isEmpty(); + } + + @Test + public void testContextId() { + Event e = + Event.builder() + .customMetadata( + ImmutableList.of( + CustomMetadata.builder() + .key(EventConverter.ADK_CONTEXT_ID_KEY) + .stringValue("context-456") + .build())) + .build(); + assertThat(EventConverter.contextId(e)).isEqualTo("context-456"); + } + + @Test + public void testContextId_empty() { + Event e = Event.builder().build(); + assertThat(EventConverter.contextId(e)).isEmpty(); + } + + @Test + public void testFindUserFunctionCall_success() { + Event agentEvent = Event.builder().author("agent").build(); + FunctionCall fc = FunctionCall.builder().name("my-func").id("fc-id").build(); + Event userEventWithCall = Event.builder() - .id("event-user") .author("user") - .content(Content.builder().role("user").parts(ImmutableList.of(userTextPart)).build()) + .content( + Content.builder() + .parts(ImmutableList.of(Part.builder().functionCall(fc).build())) + .build()) .build(); - Part functionCallPart = - Part.builder() - .functionCall( - FunctionCall.builder() - .name("roll_die") - .id("adk-call-1") - .args(ImmutableMap.of("sides", 6)) + FunctionResponse fr = FunctionResponse.builder().name("my-func").id("fc-id").build(); + Event userEventWithResponse = + Event.builder() + .author("user") + .content( + Content.builder() + .parts(ImmutableList.of(Part.builder().functionResponse(fr).build())) .build()) .build(); - Event callEvent = + + ImmutableList events = + ImmutableList.of(userEventWithCall, agentEvent, userEventWithResponse); + assertThat(EventConverter.findUserFunctionCall(events)).isEqualTo(userEventWithCall); + } + + @Test + public void testFindUserFunctionCall_noMatchingCall() { + Event agentEvent = Event.builder().author("agent").build(); + FunctionCall fc = FunctionCall.builder().name("my-func").id("other-id").build(); + Event userEventWithCall = Event.builder() - .id("event-call") - .author("root_agent") + .author("user") .content( Content.builder() - .role("assistant") - .parts(ImmutableList.of(functionCallPart)) + .parts(ImmutableList.of(Part.builder().functionCall(fc).build())) .build()) .build(); - Part functionResponsePart = - Part.builder() - .functionResponse( - FunctionResponse.builder() - .name("roll_die") - .id("adk-call-1") - .response(ImmutableMap.of("result", 3)) + FunctionResponse fr = FunctionResponse.builder().name("my-func").id("fc-id").build(); + Event userEventWithResponse = + Event.builder() + .author("user") + .content( + Content.builder() + .parts(ImmutableList.of(Part.builder().functionResponse(fr).build())) + .build()) + .build(); + + ImmutableList events = + ImmutableList.of(userEventWithCall, agentEvent, userEventWithResponse); + assertThat(EventConverter.findUserFunctionCall(events)).isNull(); + } + + @Test + public void testFindUserFunctionCall_lastEventNotUser() { + Event agentEvent = Event.builder().author("agent").build(); + FunctionCall fc = FunctionCall.builder().name("my-func").id("fc-id").build(); + Event userEventWithCall = + Event.builder() + .author("user") + .content( + Content.builder() + .parts(ImmutableList.of(Part.builder().functionCall(fc).build())) .build()) .build(); - Event responseEvent = + FunctionResponse fr = FunctionResponse.builder().name("my-func").id("fc-id").build(); + // Last event is not a user event, so should return null. + Event agentEventWithResponse = Event.builder() - .id("event-response") - .author("roll_agent") + .author("agent") .content( Content.builder() - .role("tool") - .parts(ImmutableList.of(functionResponsePart)) + .parts(ImmutableList.of(Part.builder().functionResponse(fr).build())) .build()) .build(); - List events = new ArrayList<>(ImmutableList.of(userEvent, callEvent, responseEvent)); - Session session = - Session.builder("session-1").appName("demo").userId("user").events(events).build(); + ImmutableList events = + ImmutableList.of(userEventWithCall, agentEvent, agentEventWithResponse); - InvocationContext context = + assertThat(EventConverter.findUserFunctionCall(events)).isNull(); + } + + @Test + public void testContentToParts() { + Part textPart = Part.builder().text("hello").build(); + Content content = Content.builder().parts(ImmutableList.of(textPart)).build(); + ImmutableList> list = + EventConverter.contentToParts(Optional.of(content), false); + assertThat(list).hasSize(1); + assertThat(((TextPart) list.get(0)).getText()).isEqualTo("hello"); + } + + @Test + public void testMessagePartsFromContext() { + Session session = + Session.builder("session1") + .events( + ImmutableList.of( + Event.builder() + .author("user") + .content( + Content.builder() + .parts(ImmutableList.of(Part.builder().text("hello").build())) + .build()) + .build(), + Event.builder() + .author("test_agent") + .content( + Content.builder() + .parts(ImmutableList.of(Part.builder().text("hi").build())) + .build()) + .build(), + Event.builder() + .author("other_agent") + .content( + Content.builder() + .parts(ImmutableList.of(Part.builder().text("hey").build())) + .build()) + .build())) + .build(); + BaseAgent agent = new TestAgent(); + InvocationContext ctx = InvocationContext.builder() - .sessionService(new InMemorySessionService()) - .artifactService(new InMemoryArtifactService()) - .pluginManager(new PluginManager()) - .invocationId("invocation-1") - .agent(new TestAgent()) .session(session) - .userContent( - Content.builder().role("user").parts(ImmutableList.of(userTextPart)).build()) - .endInvocation(false) + .sessionService(new InMemorySessionService()) + .agent(agent) .build(); + ImmutableList> parts = EventConverter.messagePartsFromContext(ctx); - // Act - Optional maybeMessage = EventConverter.convertEventsToA2AMessage(context); - - // Assert - assertThat(maybeMessage).isPresent(); - Message message = maybeMessage.get(); - assertThat(message.getParts()).hasSize(4); - assertThat(message.getParts().get(0)).isInstanceOf(TextPart.class); - assertThat(message.getParts().get(1)).isInstanceOf(DataPart.class); - assertThat(message.getParts().get(2)).isInstanceOf(DataPart.class); - assertThat(message.getParts().get(3)).isInstanceOf(TextPart.class); - - DataPart callDataPart = (DataPart) message.getParts().get(1); - assertThat(callDataPart.getMetadata().get(PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY)) - .isEqualTo(A2ADataPartMetadataType.FUNCTION_CALL.getType()); - assertThat(callDataPart.getData()).containsEntry("name", "roll_die"); - assertThat(callDataPart.getData()).containsEntry("id", "adk-call-1"); - assertThat(callDataPart.getData()).containsEntry("args", ImmutableMap.of("sides", 6)); - - DataPart responseDataPart = (DataPart) message.getParts().get(2); - assertThat(responseDataPart.getMetadata().get(PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY)) - .isEqualTo(A2ADataPartMetadataType.FUNCTION_RESPONSE.getType()); - assertThat(responseDataPart.getData()).containsEntry("name", "roll_die"); - assertThat(responseDataPart.getData()).containsEntry("id", "adk-call-1"); - assertThat(responseDataPart.getData()).containsEntry("response", ImmutableMap.of("result", 3)); - - TextPart lastTextPart = (TextPart) message.getParts().get(3); - assertThat(lastTextPart.getText()).isEqualTo("Roll a die"); + assertThat(parts).hasSize(2); + assertThat(((TextPart) parts.get(0)).getText()).isEqualTo("For context:"); + assertThat(((TextPart) parts.get(1)).getText()).isEqualTo("[other_agent] said: hey"); } private static final class TestAgent extends BaseAgent { From d9d84ee67406cce8eeb66abcf1be24fad9c58e29 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 11 Mar 2026 08:19:14 -0700 Subject: [PATCH 27/50] feat: Trigger traceCallLlm to set call_llm attributes before span ends PiperOrigin-RevId: 882023326 --- .../adk/flows/llmflows/BaseLlmFlow.java | 19 +-- .../com/google/adk/telemetry/Tracing.java | 117 ++++++++++-------- .../adk/telemetry/ContextPropagationTest.java | 2 +- 3 files changed, 78 insertions(+), 60 deletions(-) diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index 79066b213..ab5f6567a 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -190,20 +190,23 @@ private Flowable callLlm( context, llmRequestBuilder, eventForCallbackUsage, exception) .switchIfEmpty(Single.error(exception)) .toFlowable()) - .doOnNext( - llmResp -> - Tracing.traceCallLlm( - context, - eventForCallbackUsage.id(), - llmRequestBuilder.build(), - llmResp)) .doOnError( error -> { Span span = Span.current(); span.setStatus(StatusCode.ERROR, error.getMessage()); span.recordException(error); }) - .compose(Tracing.trace("call_llm").setParent(spanContext)) + .compose( + Tracing.trace("call_llm") + .setParent(spanContext) + .onSuccess( + (span, llmResp) -> + Tracing.traceCallLlm( + span, + context, + eventForCallbackUsage.id(), + llmRequestBuilder.build(), + llmResp))) .concatMap( llmResp -> handleAfterModelCallback(context, llmResp, eventForCallbackUsage) diff --git a/core/src/main/java/com/google/adk/telemetry/Tracing.java b/core/src/main/java/com/google/adk/telemetry/Tracing.java index 7f338fdcf..35bf3cc96 100644 --- a/core/src/main/java/com/google/adk/telemetry/Tracing.java +++ b/core/src/main/java/com/google/adk/telemetry/Tracing.java @@ -54,6 +54,7 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.function.Supplier; import org.reactivestreams.Publisher; @@ -292,58 +293,49 @@ private static Map buildLlmRequestForTrace(LlmRequest llmRequest * @param llmResponse The LLM response object. */ public static void traceCallLlm( + Span span, InvocationContext invocationContext, String eventId, LlmRequest llmRequest, LlmResponse llmResponse) { - traceWithSpan( - "traceCallLlm", - span -> { - span.setAttribute(GEN_AI_SYSTEM, "gcp.vertex.agent"); - llmRequest - .model() - .ifPresent(modelName -> span.setAttribute(GEN_AI_REQUEST_MODEL, modelName)); - - setInvocationAttributes(span, invocationContext, eventId); - - setJsonAttribute(span, ADK_LLM_REQUEST, buildLlmRequestForTrace(llmRequest)); - setJsonAttribute(span, ADK_LLM_RESPONSE, llmResponse); - - llmRequest - .config() - .ifPresent( - config -> { - config - .topP() - .ifPresent( - topP -> span.setAttribute(GEN_AI_REQUEST_TOP_P, topP.doubleValue())); - config - .maxOutputTokens() - .ifPresent( - maxTokens -> - span.setAttribute( - GEN_AI_REQUEST_MAX_TOKENS, maxTokens.longValue())); - }); - llmResponse - .usageMetadata() - .ifPresent( - usage -> { - usage - .promptTokenCount() - .ifPresent( - tokens -> span.setAttribute(GEN_AI_USAGE_INPUT_TOKENS, (long) tokens)); - usage - .candidatesTokenCount() - .ifPresent( - tokens -> span.setAttribute(GEN_AI_USAGE_OUTPUT_TOKENS, (long) tokens)); - }); - llmResponse - .finishReason() - .map(reason -> reason.knownEnum().name().toLowerCase(Locale.ROOT)) - .ifPresent( - reason -> - span.setAttribute(GEN_AI_RESPONSE_FINISH_REASONS, ImmutableList.of(reason))); - }); + span.setAttribute(GEN_AI_SYSTEM, "gcp.vertex.agent"); + llmRequest.model().ifPresent(modelName -> span.setAttribute(GEN_AI_REQUEST_MODEL, modelName)); + + setInvocationAttributes(span, invocationContext, eventId); + + setJsonAttribute(span, ADK_LLM_REQUEST, buildLlmRequestForTrace(llmRequest)); + setJsonAttribute(span, ADK_LLM_RESPONSE, llmResponse); + + llmRequest + .config() + .ifPresent( + config -> { + config + .topP() + .ifPresent(topP -> span.setAttribute(GEN_AI_REQUEST_TOP_P, topP.doubleValue())); + config + .maxOutputTokens() + .ifPresent( + maxTokens -> + span.setAttribute(GEN_AI_REQUEST_MAX_TOKENS, maxTokens.longValue())); + }); + llmResponse + .usageMetadata() + .ifPresent( + usage -> { + usage + .promptTokenCount() + .ifPresent(tokens -> span.setAttribute(GEN_AI_USAGE_INPUT_TOKENS, (long) tokens)); + usage + .candidatesTokenCount() + .ifPresent( + tokens -> span.setAttribute(GEN_AI_USAGE_OUTPUT_TOKENS, (long) tokens)); + }); + llmResponse + .finishReason() + .map(reason -> reason.knownEnum().name().toLowerCase(Locale.ROOT)) + .ifPresent( + reason -> span.setAttribute(GEN_AI_RESPONSE_FINISH_REASONS, ImmutableList.of(reason))); } /** @@ -455,6 +447,7 @@ public static final class TracerProvider private final String spanName; private Context explicitParentContext; private final List> spanConfigurers = new ArrayList<>(); + private BiConsumer onSuccessConsumer; private TracerProvider(String spanName) { this.spanName = spanName; @@ -474,6 +467,16 @@ public TracerProvider setParent(Context parentContext) { return this; } + /** + * Registers a callback to be executed with the span and the result item when the stream emits a + * success value. + */ + @CanIgnoreReturnValue + public TracerProvider onSuccess(BiConsumer consumer) { + this.onSuccessConsumer = consumer; + return this; + } + private Context getParentContext() { return explicitParentContext != null ? explicitParentContext : Context.current(); } @@ -504,7 +507,11 @@ public Publisher apply(Flowable upstream) { return Flowable.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); + Flowable pipeline = upstream.doOnSubscribe(s -> lifecycle.start()); + if (onSuccessConsumer != null) { + pipeline = pipeline.doOnNext(t -> onSuccessConsumer.accept(lifecycle.span, t)); + } + return pipeline.doFinally(lifecycle::end); }); } @@ -513,7 +520,11 @@ public SingleSource apply(Single upstream) { return Single.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); + Single pipeline = upstream.doOnSubscribe(s -> lifecycle.start()); + if (onSuccessConsumer != null) { + pipeline = pipeline.doOnSuccess(t -> onSuccessConsumer.accept(lifecycle.span, t)); + } + return pipeline.doFinally(lifecycle::end); }); } @@ -522,7 +533,11 @@ public MaybeSource apply(Maybe upstream) { return Maybe.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); + Maybe pipeline = upstream.doOnSubscribe(s -> lifecycle.start()); + if (onSuccessConsumer != null) { + pipeline = pipeline.doOnSuccess(t -> onSuccessConsumer.accept(lifecycle.span, t)); + } + return pipeline.doFinally(lifecycle::end); }); } diff --git a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java index 9439fe718..f809193cf 100644 --- a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java +++ b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java @@ -503,7 +503,7 @@ public void testTraceCallLlm() { .totalTokenCount(30) .build()) .build(); - Tracing.traceCallLlm(buildInvocationContext(), "event-1", llmRequest, llmResponse); + Tracing.traceCallLlm(span, buildInvocationContext(), "event-1", llmRequest, llmResponse); } finally { span.end(); } From 9cef81368fe87aa8a9841e8936804735152f2109 Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Wed, 11 Mar 2026 12:34:29 -0700 Subject: [PATCH 28/50] refactor: modify SessionUtils.toContent method to accept Nullable PiperOrigin-RevId: 882142571 --- .../java/com/google/adk/sessions/SessionUtils.java | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/core/src/main/java/com/google/adk/sessions/SessionUtils.java b/core/src/main/java/com/google/adk/sessions/SessionUtils.java index 7a795be4c..1aeca98c9 100644 --- a/core/src/main/java/com/google/adk/sessions/SessionUtils.java +++ b/core/src/main/java/com/google/adk/sessions/SessionUtils.java @@ -24,6 +24,7 @@ import java.util.Base64; import java.util.List; import java.util.Optional; +import org.jspecify.annotations.Nullable; /** Utility functions for session service. */ public final class SessionUtils { @@ -53,7 +54,7 @@ public static Content encodeContent(Content content) { encodedParts.add(part); } } - return toContent(encodedParts, content.role()); + return toContent(encodedParts, content.role().orElse(null)); } /** Decodes Base64-encoded inline blobs in content. */ @@ -79,13 +80,15 @@ public static Content decodeContent(Content content) { decodedParts.add(part); } } - return toContent(decodedParts, content.role()); + return toContent(decodedParts, content.role().orElse(null)); } /** Builds content from parts and optional role. */ - private static Content toContent(List parts, Optional role) { + private static Content toContent(List parts, @Nullable String role) { Content.Builder contentBuilder = Content.builder().parts(parts); - role.ifPresent(contentBuilder::role); + if (role != null) { + contentBuilder.role(role); + } return contentBuilder.build(); } } From aabf15a526ba525cdb47c74c246c178eff1851d5 Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Wed, 11 Mar 2026 12:38:49 -0700 Subject: [PATCH 29/50] feat!: remove deprecated LlmAgent.canonicalTools method PiperOrigin-RevId: 882144767 --- core/src/main/java/com/google/adk/agents/LlmAgent.java | 8 -------- 1 file changed, 8 deletions(-) diff --git a/core/src/main/java/com/google/adk/agents/LlmAgent.java b/core/src/main/java/com/google/adk/agents/LlmAgent.java index bbed217f4..d326d8154 100644 --- a/core/src/main/java/com/google/adk/agents/LlmAgent.java +++ b/core/src/main/java/com/google/adk/agents/LlmAgent.java @@ -757,14 +757,6 @@ public Single> canonicalGlobalInstruction(ReadonlyCon throw new IllegalStateException("Unknown Instruction subtype: " + globalInstruction.getClass()); } - /** - * @deprecated Use {@link #canonicalTools(ReadonlyContext)} instead. - */ - @Deprecated - public Flowable canonicalTools(Optional context) { - return canonicalTools(context.orElse(null)); - } - /** * Constructs the list of tools for this agent based on the {@link #tools} field. * From 4864287f680a38ff5e5d25c908b6a53c5430660d Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 11 Mar 2026 17:17:58 -0700 Subject: [PATCH 30/50] chore: Update actions/checkout to v6 in GitHub workflows V4 still uses deprecated Node.js 20. PiperOrigin-RevId: 882275971 --- .github/workflows/pr-commit-check.yml | 2 +- .github/workflows/validation.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pr-commit-check.yml b/.github/workflows/pr-commit-check.yml index ec6644311..1e31e42f3 100644 --- a/.github/workflows/pr-commit-check.yml +++ b/.github/workflows/pr-commit-check.yml @@ -21,7 +21,7 @@ jobs: # Step 1: Check out the code # This action checks out your repository under $GITHUB_WORKSPACE, so your workflow can access it. - name: Checkout Code - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: # We need to fetch all commits to accurately count them. # '0' means fetch all history for all branches and tags. diff --git a/.github/workflows/validation.yml b/.github/workflows/validation.yml index d9035a579..65e66f8fd 100644 --- a/.github/workflows/validation.yml +++ b/.github/workflows/validation.yml @@ -20,7 +20,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up Java ${{ matrix.java-version }} uses: actions/setup-java@v4 From bdfb7a72188ce6e72c12c16c0abedb824b846160 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 12 Mar 2026 07:18:40 -0700 Subject: [PATCH 31/50] feat: add multiple LLM responses to LLM recordings for conformance tests PiperOrigin-RevId: 882574863 --- .../com/google/adk/plugins/ReplayPlugin.java | 6 ++++- .../adk/plugins/recordings/LlmRecording.java | 7 +++--- .../google/adk/plugins/ReplayPluginTest.java | 10 ++++---- .../recordings/RecordingsLoaderTest.java | 23 ++++++++++--------- 4 files changed, 26 insertions(+), 20 deletions(-) diff --git a/dev/src/main/java/com/google/adk/plugins/ReplayPlugin.java b/dev/src/main/java/com/google/adk/plugins/ReplayPlugin.java index 5571b8d57..89032082c 100644 --- a/dev/src/main/java/com/google/adk/plugins/ReplayPlugin.java +++ b/dev/src/main/java/com/google/adk/plugins/ReplayPlugin.java @@ -90,7 +90,11 @@ public Maybe beforeModelCallback( logger.debug("Verified and replaying LLM response for agent {}", agentName); // Return the recorded response - return recording.llmResponse().map(Maybe::just).orElse(Maybe.empty()); + return recording + .llmResponses() + .filter(responses -> !responses.isEmpty()) + .map(responses -> Maybe.just(responses.get(0))) + .orElse(Maybe.empty()); } @Override diff --git a/dev/src/main/java/com/google/adk/plugins/recordings/LlmRecording.java b/dev/src/main/java/com/google/adk/plugins/recordings/LlmRecording.java index fe17aac0d..701b1e7ae 100644 --- a/dev/src/main/java/com/google/adk/plugins/recordings/LlmRecording.java +++ b/dev/src/main/java/com/google/adk/plugins/recordings/LlmRecording.java @@ -20,6 +20,7 @@ import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; import com.google.auto.value.AutoValue; +import java.util.List; import java.util.Optional; import javax.annotation.Nullable; @@ -31,8 +32,8 @@ public abstract class LlmRecording { /** The LLM request. */ public abstract Optional llmRequest(); - /** The LLM response. */ - public abstract Optional llmResponse(); + /** The LLM responses. */ + public abstract Optional> llmResponses(); public static Builder builder() { return new AutoValue_LlmRecording.Builder(); @@ -44,7 +45,7 @@ public static Builder builder() { public abstract static class Builder { public abstract Builder llmRequest(@Nullable LlmRequest llmRequest); - public abstract Builder llmResponse(@Nullable LlmResponse llmResponse); + public abstract Builder llmResponses(@Nullable List llmResponses); public abstract LlmRecording build(); } diff --git a/dev/src/test/java/com/google/adk/plugins/ReplayPluginTest.java b/dev/src/test/java/com/google/adk/plugins/ReplayPluginTest.java index f29298bce..8e89c2567 100644 --- a/dev/src/test/java/com/google/adk/plugins/ReplayPluginTest.java +++ b/dev/src/test/java/com/google/adk/plugins/ReplayPluginTest.java @@ -83,11 +83,11 @@ void beforeModelCallback_withMatchingRecording_returnsRecordedResponse() throws - role: "user" parts: - text: "Hello" - llm_response: - content: - role: "model" - parts: - - text: "Recorded response" + llm_responses: + - content: + role: "model" + parts: + - text: "Recorded response" """); // Step 1: Setup replay config diff --git a/dev/src/test/java/com/google/adk/plugins/recordings/RecordingsLoaderTest.java b/dev/src/test/java/com/google/adk/plugins/recordings/RecordingsLoaderTest.java index 92d12bc6d..ee115644c 100644 --- a/dev/src/test/java/com/google/adk/plugins/recordings/RecordingsLoaderTest.java +++ b/dev/src/test/java/com/google/adk/plugins/recordings/RecordingsLoaderTest.java @@ -58,16 +58,16 @@ void testLoadCRecording() throws Exception { - function_declarations: - name: validate_email description: Validates email format - llm_response: - content: - parts: - - thought_signature: Cq0EAR_MhbYyfIgI1M5KlVyG9HzjQ_CvZiHb_RQ2KR0H_UkDj-LDdxdVayqSpG8F6wPq4aGB6lZlqjZIGvA5H2zX2RQ_Iu8Wb8t_wKoEpW4XcwzzU9Org_ZvTNx4TZHll5cH5ebo1LPRWfTqVn7cC1N5KwDZtS2XLwCmitucAAKGzGH4c-tM0dgj57NoMFa63iaHizzi2zupKoGPBB-ZmakNHAHRspkl85hKaq8m4fELHNNMnyi596jcGRHxTDBiqHmNG8PyRiOXRM9VOkNnPU8l2DN7b6CvaBPmH84t0MaHxFMmrMjTQaNTBw92lXT7LZfwYJrDxf1ZpVHjztpbIhfZyYyZmxhIDNcVlb5i4Xoe8Rcva51NgBJN-UAm9cXWBSvr2_EdQbWs7Tz57niquyLpD6fhnTPOWBN6PU2Nz5nMgq-SUyM7srg2Ta6OV9uwOYFAFl0klSBouZ44YTM-T-voCin7EobkTzzXcllDPJ5TPretD_mpkeATlJ3Gi3nPfFLuU2DqFb8fLZjovY5oseSkEvf6NYnGt26r290QzG0cFsZbpJdtysBL-lH-yOwKEl-26IjiWztk0wAxnIdrmILlD9hgXRuyudXI0hx4gH1KTIH7njNNyLMNevUYVGC4cGxa1IpCh4EevhfCT9PQYM-QPyRT4dRBNzoG_y_lZERctUNHAfp80ObBClHEvDjElC2H6kWlO_jBeDiyJpezO7OeYjmDipvKFk3rQgNP87A= - function_call: - name: validate_email - args: - email: test@example.com - role: model - finish_reason: STOP + llm_responses: + - content: + parts: + - thought_signature: Cq0EAR_MhbYyfIgI1M5KlVyG9HzjQ_CvZiHb_RQ2KR0H_UkDj-LDdxdVayqSpG8F6wPq4aGB6lZlqjZIGvA5H2zX2RQ_Iu8Wb8t_wKoEpW4XcwzzU9Org_ZvTNx4TZHll5cH5ebo1LPRWfTqVn7cC1N5KwDZtS2XLwCmitucAAKGzGH4c-tM0dgj57NoMFa63iaHizzi2zupKoGPBB-ZmakNHAHRspkl85hKaq8m4fELHNNMnyi596jcGRHxTDBiqHmNG8PyRiOXRM9VOkNnPU8l2DN7b6CvaBPmH84t0MaHxFMmrMjTQaNTBw92lXT7LZfwYJrDxf1ZpVHjztpbIhfZyYyZmxhIDNcVlb5i4Xoe8Rcva51NgBJN-UAm9cXWBSvr2_EdQbWs7Tz57niquyLpD6fhnTPOWBN6PU2Nz5nMgq-SUyM7srg2Ta6OV9uwOYFAFl0klSBouZ44YTM-T-voCin7EobkTzzXcllDPJ5TPretD_mpkeATlJ3Gi3nPfFLuU2DqFb8fLZjovY5oseSkEvf6NYnGt26r290QzG0cFsZbpJdtysBL-lH-yOwKEl-26IjiWztk0wAxnIdrmILlD9hgXRuyudXI0hx4gH1KTIH7njNNyLMNevUYVGC4cGxa1IpCh4EevhfCT9PQYM-QPyRT4dRBNzoG_y_lZERctUNHAfp80ObBClHEvDjElC2H6kWlO_jBeDiyJpezO7OeYjmDipvKFk3rQgNP87A= + function_call: + name: validate_email + args: + email: test@example.com + role: model + finish_reason: STOP - user_message_index: 0 agent_name: booking_assistant tool_recording: @@ -108,7 +108,8 @@ void testLoadCRecording() throws Exception { assertThat(systemInstructionText.get()).contains("booking assistant"); // Verify URL-safe Base64 deserialization (thought_signature with '_' and '-' characters) - var responseContent = firstRecording.llmRecording().get().llmResponse().get().content().get(); + var responseContent = + firstRecording.llmRecording().get().llmResponses().get().get(0).content().get(); var thoughtSignature = getOnlyPart(responseContent).thoughtSignature(); assertThat(thoughtSignature).isPresent(); assertThat(thoughtSignature.get()).isNotEmpty(); From 3f6504e9416f9f644ef431e612ec983b9a2edd9d Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 12 Mar 2026 07:40:50 -0700 Subject: [PATCH 32/50] feat: update return type for stateDelta() to Map from ConcurrentMap PiperOrigin-RevId: 882583892 --- core/src/main/java/com/google/adk/events/EventActions.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/java/com/google/adk/events/EventActions.java b/core/src/main/java/com/google/adk/events/EventActions.java index 0b167de93..38e7d1d96 100644 --- a/core/src/main/java/com/google/adk/events/EventActions.java +++ b/core/src/main/java/com/google/adk/events/EventActions.java @@ -83,7 +83,7 @@ public void setSkipSummarization(boolean skipSummarization) { } @JsonProperty("stateDelta") - public ConcurrentMap stateDelta() { + public Map stateDelta() { return stateDelta; } From 32759f9eb1cc88dc106b8828e84d47d3abc8c09d Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 12 Mar 2026 07:45:31 -0700 Subject: [PATCH 33/50] refactor: Replacing use of deprecated Runner constructor and Runner.runAsync methods PiperOrigin-RevId: 882585695 --- .../adk/models/langchain4j/RunLoop.java | 2 +- .../springai/SpringAIIntegrationTest.java | 21 +++++++++++-------- .../google/adk/models/springai/TestUtils.java | 16 +++++++++----- .../google/adk/web/service/RunnerService.java | 15 ++++++------- 4 files changed, 32 insertions(+), 22 deletions(-) diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/RunLoop.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/RunLoop.java index 04a2aa585..2dca5c49c 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/RunLoop.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/RunLoop.java @@ -53,7 +53,7 @@ public static List runLoop(BaseAgent agent, boolean streaming, Object... allEvents.addAll( runner .runAsync( - session, + session.sessionKey(), messageContent, RunConfig.builder() .setStreamingMode( diff --git a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/SpringAIIntegrationTest.java b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/SpringAIIntegrationTest.java index 6843c8eaa..328df0415 100644 --- a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/SpringAIIntegrationTest.java +++ b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/SpringAIIntegrationTest.java @@ -18,6 +18,7 @@ import static org.junit.jupiter.api.Assertions.*; import com.google.adk.agents.LlmAgent; +import com.google.adk.agents.RunConfig; import com.google.adk.events.Event; import com.google.adk.models.springai.integrations.tools.WeatherTool; import com.google.adk.runner.InMemoryRunner; @@ -73,14 +74,15 @@ public ChatResponse call(Prompt prompt) { // when Runner runner = new InMemoryRunner(agent); - Session session = runner.sessionService().createSession("test-app", "test-user").blockingGet(); + Session session = + runner.sessionService().createSession(agent.name(), "test-user").blockingGet(); Content userMessage = Content.builder().role("user").parts(List.of(Part.fromText("What is a qubit?"))).build(); List events = runner - .runAsync(session, userMessage, com.google.adk.agents.RunConfig.builder().build()) + .runAsync(session.sessionKey(), userMessage, RunConfig.builder().build()) .toList() .blockingGet(); @@ -149,7 +151,8 @@ public ChatResponse call(Prompt prompt) { // when Runner runner = new InMemoryRunner(agent); - Session session = runner.sessionService().createSession("test-app", "test-user").blockingGet(); + Session session = + runner.sessionService().createSession(agent.name(), "test-user").blockingGet(); Content userMessage = Content.builder() @@ -159,7 +162,7 @@ public ChatResponse call(Prompt prompt) { List events = runner - .runAsync(session, userMessage, com.google.adk.agents.RunConfig.builder().build()) + .runAsync(session.userId(), session.id(), userMessage, RunConfig.builder().build()) .toList() .blockingGet(); @@ -217,7 +220,8 @@ public Flux stream(Prompt prompt) { // when Runner runner = new InMemoryRunner(agent); - Session session = runner.sessionService().createSession("test-app", "test-user").blockingGet(); + Session session = + runner.sessionService().createSession(agent.name(), "test-user").blockingGet(); Content userMessage = Content.builder() @@ -228,11 +232,10 @@ public Flux stream(Prompt prompt) { List events = runner .runAsync( - session, + session.userId(), + session.id(), userMessage, - com.google.adk.agents.RunConfig.builder() - .setStreamingMode(com.google.adk.agents.RunConfig.StreamingMode.SSE) - .build()) + RunConfig.builder().setStreamingMode(RunConfig.StreamingMode.SSE).build()) .toList() .blockingGet(); diff --git a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/TestUtils.java b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/TestUtils.java index f18ded055..c23e68eae 100644 --- a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/TestUtils.java +++ b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/TestUtils.java @@ -46,7 +46,7 @@ public static List askAgent(BaseAgent agent, boolean streaming, Object... allEvents.addAll( runner .runAsync( - session, + session.sessionKey(), messageContent, RunConfig.builder() .setStreamingMode( @@ -67,13 +67,17 @@ public static List askBlockingAgent(BaseAgent agent, Object... messages) } Runner runner = new InMemoryRunner(agent); - Session session = runner.sessionService().createSession("test-app", "test-user").blockingGet(); + Session session = + runner.sessionService().createSession(agent.name(), "test-user").blockingGet(); List events = new ArrayList<>(); for (Content content : contents) { List batchEvents = - runner.runAsync(session, content, RunConfig.builder().build()).toList().blockingGet(); + runner + .runAsync(session.userId(), session.id(), content, RunConfig.builder().build()) + .toList() + .blockingGet(); events.addAll(batchEvents); } @@ -88,7 +92,8 @@ public static List askAgentStreaming(BaseAgent agent, Object... messages) } Runner runner = new InMemoryRunner(agent); - Session session = runner.sessionService().createSession("test-app", "test-user").blockingGet(); + Session session = + runner.sessionService().createSession(agent.name(), "test-user").blockingGet(); List events = new ArrayList<>(); @@ -96,7 +101,8 @@ public static List askAgentStreaming(BaseAgent agent, Object... messages) List batchEvents = runner .runAsync( - session, + session.userId(), + session.id(), content, RunConfig.builder().setStreamingMode(RunConfig.StreamingMode.SSE).build()) .toList() diff --git a/dev/src/main/java/com/google/adk/web/service/RunnerService.java b/dev/src/main/java/com/google/adk/web/service/RunnerService.java index 7297af833..480c68472 100644 --- a/dev/src/main/java/com/google/adk/web/service/RunnerService.java +++ b/dev/src/main/java/com/google/adk/web/service/RunnerService.java @@ -78,13 +78,14 @@ public Runner getRunner(String appName) { "RunnerService: Creating Runner for appName: {}, using agent definition: {}", appName, agent.name()); - return new Runner( - agent, - appName, - this.artifactService, - this.sessionService, - this.memoryService, - this.extraPlugins); + return Runner.builder() + .agent(agent) + .appName(appName) + .artifactService(this.artifactService) + .sessionService(this.sessionService) + .memoryService(this.memoryService) + .plugins(this.extraPlugins) + .build(); } catch (java.util.NoSuchElementException e) { log.error( "Agent/App named '{}' not found in registry. Available apps: {}", From be3b3f8360888ea1f13796969bb19893c32727e0 Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Thu, 12 Mar 2026 07:52:21 -0700 Subject: [PATCH 34/50] feat: remove executionId method that takes Optional param from CodeExecutionUtils PiperOrigin-RevId: 882588456 --- .../adk/codeexecutors/CodeExecutionUtils.java | 6 +-- .../adk/flows/llmflows/CodeExecution.java | 7 +++- .../adk/flows/llmflows/CodeExecutionTest.java | 40 +++++++++++++++++++ 3 files changed, 48 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/com/google/adk/codeexecutors/CodeExecutionUtils.java b/core/src/main/java/com/google/adk/codeexecutors/CodeExecutionUtils.java index b9afdcaff..a4d3771c3 100644 --- a/core/src/main/java/com/google/adk/codeexecutors/CodeExecutionUtils.java +++ b/core/src/main/java/com/google/adk/codeexecutors/CodeExecutionUtils.java @@ -34,6 +34,7 @@ import java.util.Optional; import java.util.regex.Matcher; import java.util.regex.Pattern; +import org.jspecify.annotations.Nullable; /** Utility functions for code execution. */ public final class CodeExecutionUtils { @@ -237,8 +238,7 @@ public abstract static class CodeExecutionInput extends JsonBaseModel { public static Builder builder() { return new AutoValue_CodeExecutionUtils_CodeExecutionInput.Builder() - .inputFiles(ImmutableList.of()) - .executionId(Optional.empty()); + .inputFiles(ImmutableList.of()); } /** Builder for {@link CodeExecutionInput}. */ @@ -248,7 +248,7 @@ public abstract static class Builder { public abstract Builder inputFiles(List inputFiles); - public abstract Builder executionId(Optional executionId); + public abstract Builder executionId(@Nullable String executionId); public abstract CodeExecutionInput build(); } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java b/core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java index f2cbe967e..d76cd1a04 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java @@ -240,7 +240,8 @@ private static Flowable runPreProcessor( .code(codeStr) .inputFiles(ImmutableList.of(file)) .executionId( - getOrSetExecutionId(invocationContext, codeExecutorContext)) + getOrSetExecutionId(invocationContext, codeExecutorContext) + .orElse(null)) .build()); codeExecutorContext.updateCodeExecutionResult( @@ -320,7 +321,9 @@ private static Flowable runPostProcessor( CodeExecutionInput.builder() .code(codeStr) .inputFiles(codeExecutorContext.getInputFiles()) - .executionId(getOrSetExecutionId(invocationContext, codeExecutorContext)) + .executionId( + getOrSetExecutionId(invocationContext, codeExecutorContext) + .orElse(null)) .build()); codeExecutorContext.updateCodeExecutionResult( invocationContext.invocationId(), diff --git a/core/src/test/java/com/google/adk/flows/llmflows/CodeExecutionTest.java b/core/src/test/java/com/google/adk/flows/llmflows/CodeExecutionTest.java index 353504dac..1485ca2c4 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/CodeExecutionTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/CodeExecutionTest.java @@ -20,6 +20,7 @@ import static com.google.adk.testing.TestUtils.createTestAgentBuilder; import static com.google.adk.testing.TestUtils.createTestLlm; import static com.google.common.truth.Truth.assertThat; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.verify; @@ -32,14 +33,19 @@ import com.google.adk.codeexecutors.CodeExecutionUtils.CodeExecutionInput; import com.google.adk.codeexecutors.CodeExecutionUtils.CodeExecutionResult; import com.google.adk.events.Event; +import com.google.adk.flows.llmflows.RequestProcessor.RequestProcessingResult; +import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; import com.google.adk.testing.TestLlm; import com.google.common.collect.ImmutableList; +import com.google.genai.types.Blob; import com.google.genai.types.Content; import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.observers.TestObserver; +import java.util.ArrayList; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -115,4 +121,38 @@ public void testResponseProcessor_withCode_executesCode() { assertThat(executionResultPart.codeExecutionResult().get().output()) .hasValue("Code execution result:\nhello\n\n"); } + + @Test + public void testRequestProcessor_withCode_hasNoErrors() throws Exception { + // arrange + LlmRequest.Builder llmReqBuilder = LlmRequest.builder(); + when(mockCodeExecutor.codeBlockDelimiters()) + .thenReturn(ImmutableList.of(ImmutableList.of("```tool_code", "\n```"))); + when(mockCodeExecutor.optimizeDataFile()).thenReturn(true); + when(mockCodeExecutor.errorRetryAttempts()).thenReturn(2); + CodeExecutionResult executionResult = CodeExecutionResult.builder().stdout("hello\n").build(); + when(mockCodeExecutor.executeCode(any(), any())).thenReturn(executionResult); + llmReqBuilder.contents( + new ArrayList<>( + ImmutableList.of( + Content.builder() + .role("user") + .parts( + ImmutableList.of( + Part.builder() + .inlineData( + Blob.builder() + .mimeType("text/csv") + .data("1,2,3\n".getBytes(UTF_8))) + .build())) + .build()))); + + // act + Single result = + CodeExecution.requestProcessor.processRequest(invocationContext, llmReqBuilder.build()); + TestObserver testObserver = result.test(); + + // assert + testObserver.assertNoErrors(); + } } From 9ce78d7c3e1b0fb6d8d4fdce9052a572ffb9e515 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 12 Mar 2026 07:59:06 -0700 Subject: [PATCH 35/50] feat: Update converters for task and artifact events; add long running tools ids PiperOrigin-RevId: 882591822 --- .../adk/a2a/converters/PartConverter.java | 113 +++++++--------- .../adk/a2a/converters/ResponseConverter.java | 123 ++++++++++++++---- .../adk/a2a/converters/PartConverterTest.java | 77 +++++------ .../a2a/converters/ResponseConverterTest.java | 77 +++++++++++ 4 files changed, 249 insertions(+), 141 deletions(-) diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java index 36af6cc8b..61f24fa21 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java @@ -78,13 +78,13 @@ public static Optional toTextPart(io.a2a.spec.Part part) { } /** Convert an A2A JSON part into a Google GenAI part representation. */ - public static Optional toGenaiPart(io.a2a.spec.Part a2aPart) { + public static com.google.genai.types.Part toGenaiPart(io.a2a.spec.Part a2aPart) { if (a2aPart == null) { - return Optional.empty(); + throw new IllegalArgumentException("A2A part cannot be null"); } if (a2aPart instanceof TextPart textPart) { - return Optional.of(com.google.genai.types.Part.builder().text(textPart.getText()).build()); + return com.google.genai.types.Part.builder().text(textPart.getText()).build(); } if (a2aPart instanceof FilePart filePart) { @@ -95,56 +95,41 @@ public static Optional toGenaiPart(io.a2a.spec.Part return convertDataPartToGenAiPart(dataPart); } - logger.warn("Unsupported A2A part type: {}", a2aPart.getClass()); - return Optional.empty(); + throw new IllegalArgumentException("Unsupported A2A part type: " + a2aPart.getClass()); } public static ImmutableList toGenaiParts( List> a2aParts) { - return a2aParts.stream() - .map(PartConverter::toGenaiPart) - .flatMap(Optional::stream) - .collect(toImmutableList()); + return a2aParts.stream().map(PartConverter::toGenaiPart).collect(toImmutableList()); } - private static Optional convertFilePartToGenAiPart( - FilePart filePart) { + private static com.google.genai.types.Part convertFilePartToGenAiPart(FilePart filePart) { FileContent fileContent = filePart.getFile(); if (fileContent instanceof FileWithUri fileWithUri) { - return Optional.of( - com.google.genai.types.Part.builder() - .fileData( - FileData.builder() - .fileUri(fileWithUri.uri()) - .mimeType(fileWithUri.mimeType()) - .build()) - .build()); + return com.google.genai.types.Part.builder() + .fileData( + FileData.builder() + .fileUri(fileWithUri.uri()) + .mimeType(fileWithUri.mimeType()) + .build()) + .build(); } if (fileContent instanceof FileWithBytes fileWithBytes) { String bytesString = fileWithBytes.bytes(); if (bytesString == null) { - logger.warn("FileWithBytes missing byte content"); - return Optional.empty(); - } - try { - byte[] decoded = Base64.getDecoder().decode(bytesString); - return Optional.of( - com.google.genai.types.Part.builder() - .inlineData(Blob.builder().data(decoded).mimeType(fileWithBytes.mimeType()).build()) - .build()); - } catch (IllegalArgumentException e) { - logger.warn("Failed to decode base64 file content", e); - return Optional.empty(); + throw new GenAiFieldMissingException("FileWithBytes missing byte content"); } + byte[] decoded = Base64.getDecoder().decode(bytesString); + return com.google.genai.types.Part.builder() + .inlineData(Blob.builder().data(decoded).mimeType(fileWithBytes.mimeType()).build()) + .build(); } - logger.warn("Unsupported FilePart content: {}", fileContent.getClass()); - return Optional.empty(); + throw new IllegalArgumentException("Unsupported FilePart content: " + fileContent.getClass()); } - private static Optional convertDataPartToGenAiPart( - DataPart dataPart) { + private static com.google.genai.types.Part convertDataPartToGenAiPart(DataPart dataPart) { Map data = Optional.ofNullable(dataPart.getData()).map(HashMap::new).orElseGet(HashMap::new); Map metadata = @@ -154,14 +139,12 @@ private static Optional convertDataPartToGenAiPart( if ((data.containsKey(NAME_KEY) && data.containsKey(ARGS_KEY)) || metadataType.equals(A2ADataPartMetadataType.FUNCTION_CALL.getType())) { - String functionName = String.valueOf(data.getOrDefault(NAME_KEY, null)); - String functionId = String.valueOf(data.getOrDefault(ID_KEY, null)); + String functionName = String.valueOf(data.getOrDefault(NAME_KEY, "")); + String functionId = String.valueOf(data.getOrDefault(ID_KEY, "")); Map args = coerceToMap(data.get(ARGS_KEY)); - return Optional.of( - com.google.genai.types.Part.builder() - .functionCall( - FunctionCall.builder().name(functionName).id(functionId).args(args).build()) - .build()); + return com.google.genai.types.Part.builder() + .functionCall(FunctionCall.builder().name(functionName).id(functionId).args(args).build()) + .build(); } if ((data.containsKey(NAME_KEY) && data.containsKey(RESPONSE_KEY)) @@ -169,15 +152,14 @@ private static Optional convertDataPartToGenAiPart( String functionName = String.valueOf(data.getOrDefault(NAME_KEY, "")); String functionId = String.valueOf(data.getOrDefault(ID_KEY, "")); Map response = coerceToMap(data.get(RESPONSE_KEY)); - return Optional.of( - com.google.genai.types.Part.builder() - .functionResponse( - FunctionResponse.builder() - .name(functionName) - .id(functionId) - .response(response) - .build()) - .build()); + return com.google.genai.types.Part.builder() + .functionResponse( + FunctionResponse.builder() + .name(functionName) + .id(functionId) + .response(response) + .build()) + .build(); } if ((data.containsKey(CODE_KEY) && data.containsKey(LANGUAGE_KEY)) @@ -185,13 +167,11 @@ private static Optional convertDataPartToGenAiPart( String code = String.valueOf(data.getOrDefault(CODE_KEY, "")); String language = String.valueOf( - data.getOrDefault(LANGUAGE_KEY, Language.Known.LANGUAGE_UNSPECIFIED.toString()) - .toString()); - return Optional.of( - com.google.genai.types.Part.builder() - .executableCode( - ExecutableCode.builder().code(code).language(new Language(language)).build()) - .build()); + data.getOrDefault(LANGUAGE_KEY, Language.Known.LANGUAGE_UNSPECIFIED.toString())); + return com.google.genai.types.Part.builder() + .executableCode( + ExecutableCode.builder().code(code).language(new Language(language)).build()) + .build(); } if ((data.containsKey(OUTCOME_KEY) && data.containsKey(OUTPUT_KEY)) @@ -199,22 +179,17 @@ private static Optional convertDataPartToGenAiPart( String outcome = String.valueOf(data.getOrDefault(OUTCOME_KEY, Outcome.Known.OUTCOME_OK).toString()); String output = String.valueOf(data.getOrDefault(OUTPUT_KEY, "")); - return Optional.of( - com.google.genai.types.Part.builder() - .codeExecutionResult( - CodeExecutionResult.builder() - .outcome(new Outcome(outcome)) - .output(output) - .build()) - .build()); + return com.google.genai.types.Part.builder() + .codeExecutionResult( + CodeExecutionResult.builder().outcome(new Outcome(outcome)).output(output).build()) + .build(); } try { String json = objectMapper.writeValueAsString(data); - return Optional.of(com.google.genai.types.Part.builder().text(json).build()); + return com.google.genai.types.Part.builder().text(json).build(); } catch (JsonProcessingException e) { - logger.warn("Failed to serialize DataPart payload", e); - return Optional.empty(); + throw new IllegalArgumentException("Failed to serialize DataPart payload", e); } } diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java index f3be48c1b..503432a30 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java @@ -16,6 +16,8 @@ package com.google.adk.a2a.converters; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Streams.zip; import com.google.adk.agents.InvocationContext; import com.google.adk.events.Event; @@ -29,6 +31,7 @@ import io.a2a.client.TaskEvent; import io.a2a.client.TaskUpdateEvent; import io.a2a.spec.Artifact; +import io.a2a.spec.DataPart; import io.a2a.spec.Message; import io.a2a.spec.Task; import io.a2a.spec.TaskArtifactUpdateEvent; @@ -36,6 +39,7 @@ import io.a2a.spec.TaskStatusUpdateEvent; import java.time.Instant; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.UUID; @@ -70,6 +74,14 @@ public static Optional clientEventToEvent( throw new IllegalArgumentException("Unsupported ClientEvent type: " + event.getClass()); } + private static boolean isPartial(Map metadata) { + if (metadata == null) { + return false; + } + return Objects.equals( + metadata.getOrDefault(PartConverter.A2A_DATA_PART_METADATA_IS_PARTIAL_KEY, false), true); + } + /** * Converts a A2A {@link TaskUpdateEvent} to an ADK {@link Event}, if applicable. Returns null if * the event is not a final update for TaskArtifactUpdateEvent or if the message is empty for @@ -85,7 +97,14 @@ private static Optional handleTaskUpdate( boolean isAppend = Objects.equals(artifactEvent.isAppend(), true); boolean isLastChunk = Objects.equals(artifactEvent.isLastChunk(), true); + if (isLastChunk && isPartial(artifactEvent.getMetadata())) { + return Optional.empty(); + } + Event eventPart = artifactToEvent(artifactEvent.getArtifact(), context); + if (eventPart.content().flatMap(Content::parts).orElse(ImmutableList.of()).isEmpty()) { + return Optional.empty(); + } eventPart.setPartial(isAppend || !isLastChunk); // append=true, lastChunk=false: emit as partial, update aggregation // append=false, lastChunk=false: emit as partial, reset aggregation @@ -115,9 +134,8 @@ private static Optional handleTaskUpdate( .map(builder -> builder.turnComplete(true)) .map(builder -> builder.partial(false)) .map(Event.Builder::build); - } else { - return messageEvent; } + return messageEvent; } throw new IllegalArgumentException( "Unsupported TaskUpdateEvent type: " + updateEvent.getClass()); @@ -125,16 +143,12 @@ private static Optional handleTaskUpdate( /** Converts an artifact to an ADK event. */ public static Event artifactToEvent(Artifact artifact, InvocationContext invocationContext) { - Message message = - new Message.Builder().role(Message.Role.AGENT).parts(artifact.parts()).build(); - return messageToEvent(message, invocationContext); - } - - /** Converts an A2A message back to ADK events. */ - public static Event messageToEvent(Message message, InvocationContext invocationContext) { - return remoteAgentEventBuilder(invocationContext) - .content(fromModelParts(PartConverter.toGenaiParts(message.getParts()))) - .build(); + Event.Builder eventBuilder = remoteAgentEventBuilder(invocationContext); + ImmutableList genaiParts = PartConverter.toGenaiParts(artifact.parts()); + eventBuilder + .content(fromModelParts(genaiParts)) + .longRunningToolIds(getLongRunningToolIds(artifact.parts(), genaiParts)); + return eventBuilder.build(); } /** Converts an A2A message for a failed task to ADK event filling in the error message. */ @@ -147,6 +161,13 @@ public static Event messageToFailedEvent(Message message, InvocationContext invo return builder.build(); } + /** Converts an A2A message back to ADK events. */ + public static Event messageToEvent(Message message, InvocationContext invocationContext) { + return remoteAgentEventBuilder(invocationContext) + .content(fromModelParts(PartConverter.toGenaiParts(message.getParts()))) + .build(); + } + /** * Converts an A2A message back to ADK events. For streaming task in pending state it sets the * thought field to true, to mark them as thought updates. @@ -168,25 +189,71 @@ public static Event messageToEvent( * If none of these are present, an empty event is returned. */ public static Event taskToEvent(Task task, InvocationContext invocationContext) { - Message taskMessage = null; - - if (!task.getArtifacts().isEmpty()) { - taskMessage = - new Message.Builder() - .messageId("") - .role(Message.Role.AGENT) - .parts(Iterables.getLast(task.getArtifacts()).parts()) - .build(); - } else if (task.getStatus().message() != null) { - taskMessage = task.getStatus().message(); - } else if (!task.getHistory().isEmpty()) { - taskMessage = Iterables.getLast(task.getHistory()); + ImmutableList.Builder genaiParts = ImmutableList.builder(); + ImmutableSet.Builder longRunningToolIds = ImmutableSet.builder(); + + for (Artifact artifact : task.getArtifacts()) { + ImmutableList converted = PartConverter.toGenaiParts(artifact.parts()); + longRunningToolIds.addAll(getLongRunningToolIds(artifact.parts(), converted)); + genaiParts.addAll(converted); + } + + Event.Builder eventBuilder = remoteAgentEventBuilder(invocationContext); + + if (task.getStatus().message() != null) { + ImmutableList msgParts = + PartConverter.toGenaiParts(task.getStatus().message().getParts()); + longRunningToolIds.addAll( + getLongRunningToolIds(task.getStatus().message().getParts(), msgParts)); + if (task.getStatus().state() == TaskState.FAILED + && msgParts.size() == 1 + && msgParts.get(0).text().isPresent()) { + eventBuilder.errorMessage(msgParts.get(0).text().get()); + } else { + genaiParts.addAll(msgParts); + } } - if (taskMessage != null) { - return messageToEvent(taskMessage, invocationContext); + ImmutableList finalParts = genaiParts.build(); + boolean isFinal = + task.getStatus().state().isFinal() || task.getStatus().state() == TaskState.INPUT_REQUIRED; + + if (finalParts.isEmpty() && !isFinal) { + return emptyEvent(invocationContext); } - return emptyEvent(invocationContext); + if (!finalParts.isEmpty()) { + eventBuilder.content(fromModelParts(finalParts)); + } + if (task.getStatus().state() == TaskState.INPUT_REQUIRED) { + eventBuilder.longRunningToolIds(longRunningToolIds.build()); + } + eventBuilder.turnComplete(isFinal); + return eventBuilder.build(); + } + + private static ImmutableSet getLongRunningToolIds( + List> parts, List convertedParts) { + return zip( + parts.stream(), + convertedParts.stream(), + (part, convertedPart) -> { + if (!(part instanceof DataPart dataPart)) { + return Optional.empty(); + } + Object isLongRunning = + dataPart + .getMetadata() + .get(PartConverter.A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY); + if (!Objects.equals(isLongRunning, true)) { + return Optional.empty(); + } + if (convertedPart.functionCall().isEmpty()) { + return Optional.empty(); + } + return convertedPart.functionCall().get().id(); + }) + .flatMap(Optional::stream) + .collect(toImmutableSet()); } private static Event emptyEvent(InvocationContext invocationContext) { diff --git a/a2a/src/test/java/com/google/adk/a2a/converters/PartConverterTest.java b/a2a/src/test/java/com/google/adk/a2a/converters/PartConverterTest.java index 8e8982ffa..d93466dd2 100644 --- a/a2a/src/test/java/com/google/adk/a2a/converters/PartConverterTest.java +++ b/a2a/src/test/java/com/google/adk/a2a/converters/PartConverterTest.java @@ -18,7 +18,6 @@ import io.a2a.spec.FileWithUri; import io.a2a.spec.TextPart; import java.util.Base64; -import java.util.Optional; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -27,29 +26,27 @@ public class PartConverterTest { @Test - public void toGenaiPart_withNullPart_returnsEmpty() { - assertThat(PartConverter.toGenaiPart(null)).isEmpty(); + public void toGenaiPart_withNullPart_throwsException() { + assertThrows(IllegalArgumentException.class, () -> PartConverter.toGenaiPart(null)); } @Test public void toGenaiPart_withTextPart_returnsGenaiTextPart() { TextPart textPart = new TextPart("Hello"); - Optional result = PartConverter.toGenaiPart(textPart); + Part result = PartConverter.toGenaiPart(textPart); - assertThat(result).isPresent(); - assertThat(result.get().text()).hasValue("Hello"); + assertThat(result.text()).hasValue("Hello"); } @Test public void toGenaiPart_withFilePartUri_returnsGenaiFilePart() { FilePart filePart = new FilePart(new FileWithUri("text/plain", "file.txt", "http://file.txt")); - Optional result = PartConverter.toGenaiPart(filePart); + Part result = PartConverter.toGenaiPart(filePart); - assertThat(result).isPresent(); - assertThat(result.get().fileData()).isPresent(); - FileData fileData = result.get().fileData().get(); + assertThat(result.fileData()).isPresent(); + FileData fileData = result.fileData().get(); assertThat(fileData.mimeType()).hasValue("text/plain"); assertThat(fileData.fileUri()).hasValue("http://file.txt"); } @@ -60,26 +57,25 @@ public void toGenaiPart_withFilePartBytes_returnsGenaiBlobPart() { String encoded = Base64.getEncoder().encodeToString(bytes); FilePart filePart = new FilePart(new FileWithBytes("text/plain", "file.txt", encoded)); - Optional result = PartConverter.toGenaiPart(filePart); + Part result = PartConverter.toGenaiPart(filePart); - assertThat(result).isPresent(); - assertThat(result.get().inlineData()).isPresent(); - Blob blob = result.get().inlineData().get(); + assertThat(result.inlineData()).isPresent(); + Blob blob = result.inlineData().get(); assertThat(blob.mimeType()).hasValue("text/plain"); assertThat(blob.data().get()).isEqualTo(bytes); } @Test - public void toGenaiPart_withFilePartBytes_handlesNullBytes() { + public void toGenaiPart_withFilePartBytes_handlesNullBytes_throwsException() { FilePart filePart = new FilePart(new FileWithBytes("text/plain", "file.txt", null)); - assertThat(PartConverter.toGenaiPart(filePart)).isEmpty(); + assertThrows(GenAiFieldMissingException.class, () -> PartConverter.toGenaiPart(filePart)); } @Test public void toGenaiPart_withFilePartBytes_handlesInvalidBase64() { FilePart filePart = new FilePart(new FileWithBytes("text/plain", "file.txt", "invalid-base64!")); - assertThat(PartConverter.toGenaiPart(filePart)).isEmpty(); + assertThrows(IllegalArgumentException.class, () -> PartConverter.toGenaiPart(filePart)); } @Test @@ -93,11 +89,10 @@ public void toGenaiPart_withDataPartFunctionCall_returnsGenaiFunctionCallPart() PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY, A2ADataPartMetadataType.FUNCTION_CALL.getType())); - Optional result = PartConverter.toGenaiPart(dataPart); + Part result = PartConverter.toGenaiPart(dataPart); - assertThat(result).isPresent(); - assertThat(result.get().functionCall()).isPresent(); - FunctionCall functionCall = result.get().functionCall().get(); + assertThat(result.functionCall()).isPresent(); + FunctionCall functionCall = result.functionCall().get(); assertThat(functionCall.name()).hasValue("func"); assertThat(functionCall.id()).hasValue("1"); assertThat(functionCall.args()).hasValue(ImmutableMap.of()); @@ -109,11 +104,10 @@ public void toGenaiPart_withDataPartFunctionCallByNameAndArgs_returnsGenaiFuncti ImmutableMap.of("name", "func", "id", "1", "args", ImmutableMap.of("param", "value")); DataPart dataPart = new DataPart(data, null); - Optional result = PartConverter.toGenaiPart(dataPart); + Part result = PartConverter.toGenaiPart(dataPart); - assertThat(result).isPresent(); - assertThat(result.get().functionCall()).isPresent(); - FunctionCall functionCall = result.get().functionCall().get(); + assertThat(result.functionCall()).isPresent(); + FunctionCall functionCall = result.functionCall().get(); assertThat(functionCall.name()).hasValue("func"); assertThat(functionCall.id()).hasValue("1"); assertThat(functionCall.args()).hasValue(ImmutableMap.of("param", "value")); @@ -130,11 +124,10 @@ public void toGenaiPart_withDataPartFunctionResponse_returnsGenaiFunctionRespons PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY, A2ADataPartMetadataType.FUNCTION_RESPONSE.getType())); - Optional result = PartConverter.toGenaiPart(dataPart); + Part result = PartConverter.toGenaiPart(dataPart); - assertThat(result).isPresent(); - assertThat(result.get().functionResponse()).isPresent(); - FunctionResponse functionResponse = result.get().functionResponse().get(); + assertThat(result.functionResponse()).isPresent(); + FunctionResponse functionResponse = result.functionResponse().get(); assertThat(functionResponse.name()).hasValue("func"); assertThat(functionResponse.id()).hasValue("1"); assertThat(functionResponse.response()).hasValue(ImmutableMap.of()); @@ -147,11 +140,10 @@ public void toGenaiPart_withDataPartFunctionResponse_returnsGenaiFunctionRespons ImmutableMap.of("name", "func", "id", "1", "response", ImmutableMap.of("result", "value")); DataPart dataPart = new DataPart(data, null); - Optional result = PartConverter.toGenaiPart(dataPart); + Part result = PartConverter.toGenaiPart(dataPart); - assertThat(result).isPresent(); - assertThat(result.get().functionResponse()).isPresent(); - FunctionResponse functionResponse = result.get().functionResponse().get(); + assertThat(result.functionResponse()).isPresent(); + FunctionResponse functionResponse = result.functionResponse().get(); assertThat(functionResponse.name()).hasValue("func"); assertThat(functionResponse.id()).hasValue("1"); assertThat(functionResponse.response()).hasValue(ImmutableMap.of("result", "value")); @@ -162,10 +154,9 @@ public void toGenaiPart_withOtherDataPart_returnsGenaiTextPartWithJson() { ImmutableMap data = ImmutableMap.of("key", "value"); DataPart dataPart = new DataPart(data, null); - Optional result = PartConverter.toGenaiPart(dataPart); + Part result = PartConverter.toGenaiPart(dataPart); - assertThat(result).isPresent(); - assertThat(result.get().text()).hasValue("{\"key\":\"value\"}"); + assertThat(result.text()).hasValue("{\"key\":\"value\"}"); } @Test @@ -293,11 +284,10 @@ public void toGenaiPart_dataPartWithEmptyStringCoercedToEmptyMap() { ImmutableMap data = ImmutableMap.of("name", "func", "id", "1", "args", ""); DataPart dataPart = new DataPart(data, null); - Optional result = PartConverter.toGenaiPart(dataPart); + Part result = PartConverter.toGenaiPart(dataPart); - assertThat(result).isPresent(); - assertThat(result.get().functionCall()).isPresent(); - assertThat(result.get().functionCall().get().args()).hasValue(ImmutableMap.of()); + assertThat(result.functionCall()).isPresent(); + assertThat(result.functionCall().get().args()).hasValue(ImmutableMap.of()); } @Test @@ -305,10 +295,9 @@ public void toGenaiPart_dataPartWithNonMapCoercedToMap() { ImmutableMap data = ImmutableMap.of("name", "func", "id", "1", "args", 123); DataPart dataPart = new DataPart(data, null); - Optional result = PartConverter.toGenaiPart(dataPart); + Part result = PartConverter.toGenaiPart(dataPart); - assertThat(result).isPresent(); - assertThat(result.get().functionCall()).isPresent(); - assertThat(result.get().functionCall().get().args()).hasValue(ImmutableMap.of("value", 123)); + assertThat(result.functionCall()).isPresent(); + assertThat(result.functionCall().get().args()).hasValue(ImmutableMap.of("value", 123)); } } diff --git a/a2a/src/test/java/com/google/adk/a2a/converters/ResponseConverterTest.java b/a2a/src/test/java/com/google/adk/a2a/converters/ResponseConverterTest.java index 5378bdd7b..d84dc42cd 100644 --- a/a2a/src/test/java/com/google/adk/a2a/converters/ResponseConverterTest.java +++ b/a2a/src/test/java/com/google/adk/a2a/converters/ResponseConverterTest.java @@ -11,10 +11,12 @@ import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; import io.a2a.client.MessageEvent; import io.a2a.client.TaskUpdateEvent; import io.a2a.spec.Artifact; +import io.a2a.spec.DataPart; import io.a2a.spec.Message; import io.a2a.spec.Task; import io.a2a.spec.TaskArtifactUpdateEvent; @@ -144,6 +146,81 @@ public void taskToEvent_withNoMessage_returnsEmptyEvent() { assertThat(event.invocationId()).isEqualTo(invocationContext.invocationId()); } + @Test + public void taskToEvent_withInputRequired_parsesLongRunningToolIds() { + ImmutableMap data = + ImmutableMap.of("name", "myTool", "id", "call_123", "args", ImmutableMap.of()); + ImmutableMap metadata = + ImmutableMap.of( + PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY, + "function_call", + PartConverter.A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY, + true); + DataPart dataPart = new DataPart(data, metadata); + ImmutableMap statusData = + ImmutableMap.of("name", "messageTools", "id", "msg_123", "args", ImmutableMap.of()); + ImmutableMap statusMetadata = + ImmutableMap.of( + PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY, + "function_call", + PartConverter.A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY, + true); + DataPart statusDataPart = new DataPart(statusData, statusMetadata); + Message statusMessage = + new Message.Builder() + .role(Message.Role.AGENT) + .parts(ImmutableList.of(statusDataPart)) + .build(); + TaskStatus status = new TaskStatus(TaskState.INPUT_REQUIRED, statusMessage, null); + + Artifact artifact = + new Artifact.Builder().artifactId("artifact-1").parts(ImmutableList.of(dataPart)).build(); + Task task = testTask().status(status).artifacts(ImmutableList.of(artifact)).build(); + + Event event = ResponseConverter.taskToEvent(task, invocationContext); + assertThat(event).isNotNull(); + assertThat(event.longRunningToolIds().get()).containsExactly("call_123", "msg_123"); + } + + @Test + public void taskToEvent_withFailedState_setsErrorCode() { + Message statusMessage = + new Message.Builder() + .role(Message.Role.AGENT) + .parts(ImmutableList.of(new TextPart("Task failed"))) + .build(); + TaskStatus status = new TaskStatus(TaskState.FAILED, statusMessage, null); + Task task = testTask().status(status).artifacts(ImmutableList.of()).build(); + + Event event = ResponseConverter.taskToEvent(task, invocationContext); + assertThat(event).isNotNull(); + assertThat(event.errorMessage()).hasValue("Task failed"); + } + + @Test + public void taskToEvent_withFinalEvent_returnsEmptyEvent() { + TaskStatus status = new TaskStatus(TaskState.COMPLETED); + Task task = testTask().status(status).artifacts(ImmutableList.of()).build(); + + Event event = ResponseConverter.taskToEvent(task, invocationContext); + assertThat(event).isNotNull(); + assertThat(event.invocationId()).isEqualTo(invocationContext.invocationId()); + assertThat(event.turnComplete()).hasValue(true); + assertThat(event.content().flatMap(Content::parts).orElse(ImmutableList.of())).isEmpty(); + } + + @Test + public void taskToEvent_withEmptyParts_returnsEmptyEvent() { + TaskStatus status = new TaskStatus(TaskState.SUBMITTED); + Task task = testTask().status(status).artifacts(ImmutableList.of()).build(); + + Event event = ResponseConverter.taskToEvent(task, invocationContext); + assertThat(event).isNotNull(); + assertThat(event.invocationId()).isEqualTo(invocationContext.invocationId()); + assertThat(event.content()).isPresent(); + assertThat(event.content().get().parts().orElse(ImmutableList.of())).isEmpty(); + } + @Test public void clientEventToEvent_withTaskUpdateEventAndThought_returnsThoughtEvent() { Message statusMessage = From 41f5af0dceb78501ca8b94e434e4d751f608a699 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 12 Mar 2026 08:36:18 -0700 Subject: [PATCH 36/50] fix: Removing deprecated InvocationContext methods PiperOrigin-RevId: 882608653 --- .../google/adk/agents/InvocationContext.java | 57 ------------------- 1 file changed, 57 deletions(-) diff --git a/core/src/main/java/com/google/adk/agents/InvocationContext.java b/core/src/main/java/com/google/adk/agents/InvocationContext.java index 7f0e49d0c..91ce13a87 100644 --- a/core/src/main/java/com/google/adk/agents/InvocationContext.java +++ b/core/src/main/java/com/google/adk/agents/InvocationContext.java @@ -27,7 +27,6 @@ import com.google.adk.sessions.Session; import com.google.adk.summarizer.EventsCompactionConfig; import com.google.errorprone.annotations.CanIgnoreReturnValue; -import com.google.errorprone.annotations.InlineMe; import com.google.genai.types.Content; import java.util.Map; import java.util.Objects; @@ -81,62 +80,6 @@ protected InvocationContext(Builder builder) { this.callbackContextData = builder.callbackContextData; } - /** - * @deprecated Use {@link #builder()} instead. - */ - @InlineMe( - replacement = - "InvocationContext.builder()" - + ".sessionService(sessionService)" - + ".artifactService(artifactService)" - + ".invocationId(invocationId)" - + ".agent(agent)" - + ".session(session)" - + ".userContent(userContent)" - + ".runConfig(runConfig)" - + ".build()", - imports = {"com.google.adk.agents.InvocationContext"}) - @Deprecated(forRemoval = true) - public static InvocationContext create( - BaseSessionService sessionService, - BaseArtifactService artifactService, - String invocationId, - BaseAgent agent, - Session session, - Content userContent, - RunConfig runConfig) { - return builder() - .sessionService(sessionService) - .artifactService(artifactService) - .invocationId(invocationId) - .agent(agent) - .session(session) - .userContent(userContent) - .runConfig(runConfig) - .build(); - } - - /** - * @deprecated Use {@link #builder()} instead. - */ - @Deprecated(forRemoval = true) - public static InvocationContext create( - BaseSessionService sessionService, - BaseArtifactService artifactService, - BaseAgent agent, - Session session, - LiveRequestQueue liveRequestQueue, - RunConfig runConfig) { - return builder() - .sessionService(sessionService) - .artifactService(artifactService) - .agent(agent) - .session(session) - .liveRequestQueue(liveRequestQueue) - .runConfig(runConfig) - .build(); - } - /** Returns a new {@link Builder} for creating {@link InvocationContext} instances. */ public static Builder builder() { return new Builder(); From f9d013bdc09eaf29dfff7f6abc3e57a986c1f08d Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Thu, 12 Mar 2026 09:38:23 -0700 Subject: [PATCH 37/50] refactor: update LoggingPlugin.formatContent to accept @Nullable instead of Optional PiperOrigin-RevId: 882637228 --- .../com/google/adk/plugins/LoggingPlugin.java | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/core/src/main/java/com/google/adk/plugins/LoggingPlugin.java b/core/src/main/java/com/google/adk/plugins/LoggingPlugin.java index 573f5048d..7daf13b11 100644 --- a/core/src/main/java/com/google/adk/plugins/LoggingPlugin.java +++ b/core/src/main/java/com/google/adk/plugins/LoggingPlugin.java @@ -29,7 +29,7 @@ import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Maybe; import java.util.Map; -import java.util.Optional; +import org.jspecify.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -57,7 +57,7 @@ private void log(String message) { @Override public Maybe onUserMessageCallback( - InvocationContext invocationContext, Content userMessage) { + InvocationContext invocationContext, @Nullable Content userMessage) { return Maybe.fromAction( () -> { log("🚀 USER MESSAGE RECEIVED"); @@ -66,7 +66,7 @@ public Maybe onUserMessageCallback( log(" User ID: " + invocationContext.userId()); log(" App Name: " + invocationContext.appName()); log(" Root Agent: " + invocationContext.agent().name()); - log(" User Content: " + formatContent(Optional.ofNullable(userMessage))); + log(" User Content: " + formatContent(userMessage)); invocationContext.branch().ifPresent(branch -> log(" Branch: " + branch)); }); } @@ -88,7 +88,7 @@ public Maybe onEventCallback(InvocationContext invocationContext, Event e log("📢 EVENT YIELDED"); log(" Event ID: " + event.id()); log(" Author: " + event.author()); - log(" Content: " + formatContent(event.content())); + log(" Content: " + formatContent(event.content().orElse(null))); log(" Final Response: " + event.finalResponse()); if (!event.functionCalls().isEmpty()) { @@ -190,7 +190,7 @@ public Maybe afterModelCallback( log(" ❌ ERROR - Code: " + llmResponse.errorCode().get()); log(" Error Message: " + llmResponse.errorMessage().orElse("None")); } else { - log(" Content: " + formatContent(llmResponse.content())); + log(" Content: " + formatContent(llmResponse.content().orElse(null))); llmResponse.partial().ifPresent(partial -> log(" Partial: " + partial)); llmResponse .turnComplete() @@ -265,12 +265,8 @@ public Maybe> onToolErrorCallback( }); } - private String formatContent(Optional contentOptional) { - if (contentOptional.isEmpty()) { - return "None"; - } - Content content = contentOptional.get(); - if (content.parts().isEmpty() || content.parts().get().isEmpty()) { + private String formatContent(@Nullable Content content) { + if (content == null || content.parts().isEmpty() || content.parts().get().isEmpty()) { return "None"; } return content.parts().get().stream() From dc51aec88b04ffa3bd8b6021107db398fea257f6 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 13 Mar 2026 02:47:26 -0700 Subject: [PATCH 38/50] chore: update deprecation comment for setStateDelta PiperOrigin-RevId: 883051607 --- core/src/main/java/com/google/adk/events/EventActions.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/java/com/google/adk/events/EventActions.java b/core/src/main/java/com/google/adk/events/EventActions.java index 38e7d1d96..1ca856b45 100644 --- a/core/src/main/java/com/google/adk/events/EventActions.java +++ b/core/src/main/java/com/google/adk/events/EventActions.java @@ -87,7 +87,7 @@ public Map stateDelta() { return stateDelta; } - @Deprecated // Use stateDelta(), addState() and removeStateByKey() instead. + @Deprecated // Use stateDelta() and removeStateByKey() instead. public void setStateDelta(ConcurrentMap stateDelta) { this.stateDelta = stateDelta; } From 5fd4c53c88e977d004b9eee8fa3697625ec85f47 Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Fri, 13 Mar 2026 03:47:16 -0700 Subject: [PATCH 39/50] feat: replace Optional type of version in BaseArtifactService.loadArtifact with Nullable PiperOrigin-RevId: 883074383 --- .../google/adk/artifacts/BaseArtifactService.java | 13 ++++--------- .../google/adk/artifacts/GcsArtifactService.java | 5 +++-- .../adk/artifacts/InMemoryArtifactService.java | 11 +++++------ .../java/com/google/adk/utils/InstructionUtils.java | 8 +------- .../adk/artifacts/GcsArtifactServiceTest.java | 13 ++++++------- .../adk/artifacts/InMemoryArtifactServiceTest.java | 2 +- .../google/adk/flows/llmflows/InstructionsTest.java | 7 +------ .../com/google/adk/tools/LoadArtifactsToolTest.java | 8 ++++---- .../adk/web/controller/ArtifactController.java | 7 ++----- 9 files changed, 27 insertions(+), 47 deletions(-) diff --git a/core/src/main/java/com/google/adk/artifacts/BaseArtifactService.java b/core/src/main/java/com/google/adk/artifacts/BaseArtifactService.java index a9bb6ba4d..acf5979c2 100644 --- a/core/src/main/java/com/google/adk/artifacts/BaseArtifactService.java +++ b/core/src/main/java/com/google/adk/artifacts/BaseArtifactService.java @@ -22,7 +22,7 @@ import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; -import java.util.Optional; +import org.jspecify.annotations.Nullable; /** Base interface for artifact services. */ public interface BaseArtifactService { @@ -75,7 +75,7 @@ default Single saveAndReloadArtifact( /** Loads the latest version of an artifact from the service. */ default Maybe loadArtifact( String appName, String userId, String sessionId, String filename) { - return loadArtifact(appName, userId, sessionId, filename, Optional.empty()); + return loadArtifact(appName, userId, sessionId, filename, /* version= */ (Integer) null); } /** Loads the latest version of an artifact from the service. */ @@ -86,7 +86,7 @@ default Maybe loadArtifact(SessionKey sessionKey, String filename) { /** Loads a specific version of an artifact from the service. */ default Maybe loadArtifact( String appName, String userId, String sessionId, String filename, int version) { - return loadArtifact(appName, userId, sessionId, filename, Optional.of(version)); + return loadArtifact(appName, userId, sessionId, filename, Integer.valueOf(version)); } default Maybe loadArtifact(SessionKey sessionKey, String filename, int version) { @@ -94,13 +94,8 @@ default Maybe loadArtifact(SessionKey sessionKey, String filename, int ver sessionKey.appName(), sessionKey.userId(), sessionKey.id(), filename, version); } - /** - * @deprecated Use {@link #loadArtifact(String, String, String, String)} or {@link - * #loadArtifact(String, String, String, String, int)} instead. - */ - @Deprecated Maybe loadArtifact( - String appName, String userId, String sessionId, String filename, Optional version); + String appName, String userId, String sessionId, String filename, @Nullable Integer version); /** * Lists all the artifact filenames within a session. diff --git a/core/src/main/java/com/google/adk/artifacts/GcsArtifactService.java b/core/src/main/java/com/google/adk/artifacts/GcsArtifactService.java index e31d50327..977153828 100644 --- a/core/src/main/java/com/google/adk/artifacts/GcsArtifactService.java +++ b/core/src/main/java/com/google/adk/artifacts/GcsArtifactService.java @@ -38,6 +38,7 @@ import java.util.List; import java.util.Optional; import java.util.Set; +import org.jspecify.annotations.Nullable; /** An artifact service implementation using Google Cloud Storage (GCS). */ public final class GcsArtifactService implements BaseArtifactService { @@ -126,8 +127,8 @@ public Single saveArtifact( */ @Override public Maybe loadArtifact( - String appName, String userId, String sessionId, String filename, Optional version) { - return version + String appName, String userId, String sessionId, String filename, @Nullable Integer version) { + return Optional.ofNullable(version) .map(Maybe::just) .orElseGet( () -> diff --git a/core/src/main/java/com/google/adk/artifacts/InMemoryArtifactService.java b/core/src/main/java/com/google/adk/artifacts/InMemoryArtifactService.java index 8c8ec2af8..510c96c2e 100644 --- a/core/src/main/java/com/google/adk/artifacts/InMemoryArtifactService.java +++ b/core/src/main/java/com/google/adk/artifacts/InMemoryArtifactService.java @@ -28,8 +28,8 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.stream.IntStream; +import org.jspecify.annotations.Nullable; /** An in-memory implementation of the {@link BaseArtifactService}. */ public final class InMemoryArtifactService implements BaseArtifactService { @@ -61,7 +61,7 @@ public Single saveArtifact( */ @Override public Maybe loadArtifact( - String appName, String userId, String sessionId, String filename, Optional version) { + String appName, String userId, String sessionId, String filename, @Nullable Integer version) { List versions = getArtifactsMap(appName, userId, sessionId) .computeIfAbsent(filename, unused -> new ArrayList<>()); @@ -69,10 +69,9 @@ public Maybe loadArtifact( if (versions.isEmpty()) { return Maybe.empty(); } - if (version.isPresent()) { - int v = version.get(); - if (v >= 0 && v < versions.size()) { - return Maybe.just(versions.get(v)); + if (version != null) { + if (version >= 0 && version < versions.size()) { + return Maybe.just(versions.get(version)); } else { return Maybe.empty(); } diff --git a/core/src/main/java/com/google/adk/utils/InstructionUtils.java b/core/src/main/java/com/google/adk/utils/InstructionUtils.java index ea118b362..ff2a7b8bd 100644 --- a/core/src/main/java/com/google/adk/utils/InstructionUtils.java +++ b/core/src/main/java/com/google/adk/utils/InstructionUtils.java @@ -25,7 +25,6 @@ import io.reactivex.rxjava3.core.Single; import java.util.ArrayList; import java.util.List; -import java.util.Optional; import java.util.regex.MatchResult; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -167,12 +166,7 @@ private static Single resolveMatchAsync(InvocationContext context, Match Maybe artifactMaybe = context .artifactService() - .loadArtifact( - session.appName(), - session.userId(), - session.id(), - artifactName, - Optional.empty()); + .loadArtifact(session.appName(), session.userId(), session.id(), artifactName); return artifactMaybe .map(Part::toJson) diff --git a/core/src/test/java/com/google/adk/artifacts/GcsArtifactServiceTest.java b/core/src/test/java/com/google/adk/artifacts/GcsArtifactServiceTest.java index 88abd60c4..3b3c8c402 100644 --- a/core/src/test/java/com/google/adk/artifacts/GcsArtifactServiceTest.java +++ b/core/src/test/java/com/google/adk/artifacts/GcsArtifactServiceTest.java @@ -158,7 +158,7 @@ public void load_latestVersion_loadsCorrectly() { when(mockStorage.get(blobIdV1)).thenReturn(blobV1); Optional loadedArtifact = - asOptional(service.loadArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, Optional.empty())); + asOptional(service.loadArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME)); assertThat(loadedArtifact).isPresent(); Optional actualDataOptional = loadedArtifact.get().inlineData().get().data(); @@ -177,7 +177,7 @@ public void load_specificVersion_loadsCorrectly() { when(mockStorage.get(blobIdV0)).thenReturn(blobV0); Optional loadedArtifact = - asOptional(service.loadArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, Optional.of(0))); + asOptional(service.loadArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, 0)); assertThat(loadedArtifact).isPresent(); Optional actualDataOptional = loadedArtifact.get().inlineData().get().data(); @@ -197,8 +197,7 @@ public void load_userNamespace_loadsCorrectly() { when(mockStorage.get(blobIdV0)).thenReturn(blobV0); Optional loadedArtifact = - asOptional( - service.loadArtifact(APP_NAME, USER_ID, SESSION_ID, USER_FILENAME, Optional.empty())); + asOptional(service.loadArtifact(APP_NAME, USER_ID, SESSION_ID, USER_FILENAME)); assertThat(loadedArtifact).isPresent(); Optional actualDataOptional = loadedArtifact.get().inlineData().get().data(); @@ -216,7 +215,7 @@ public void load_versionNotFound_returnsEmpty() { when(mockStorage.get(blobIdV0)).thenReturn(null); Optional loadedArtifact = - asOptional(service.loadArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, Optional.of(0))); + asOptional(service.loadArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, 0)); assertThat(loadedArtifact).isEmpty(); verify(mockStorage).get(blobIdV0); @@ -227,7 +226,7 @@ public void load_noVersionsExist_returnsEmpty() { when(mockBlobPage.iterateAll()).thenReturn(ImmutableList.of()); Optional loadedArtifact = - asOptional(service.loadArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, Optional.empty())); + asOptional(service.loadArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME)); assertThat(loadedArtifact).isEmpty(); } @@ -400,7 +399,7 @@ public void load_storageException_returnsEmpty() { when(mockStorage.get(blobIdV0)).thenThrow(new StorageException(500, "Induced error")); Optional loadedArtifact = - asOptional(service.loadArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, Optional.of(0))); + asOptional(service.loadArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, 0)); assertThat(loadedArtifact).isEmpty(); } diff --git a/core/src/test/java/com/google/adk/artifacts/InMemoryArtifactServiceTest.java b/core/src/test/java/com/google/adk/artifacts/InMemoryArtifactServiceTest.java index 4cb493277..124a5e9d8 100644 --- a/core/src/test/java/com/google/adk/artifacts/InMemoryArtifactServiceTest.java +++ b/core/src/test/java/com/google/adk/artifacts/InMemoryArtifactServiceTest.java @@ -59,7 +59,7 @@ public void loadArtifact_loadsLatest() { var unused2 = service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact2).blockingGet(); Optional result = - asOptional(service.loadArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, Optional.empty())); + asOptional(service.loadArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME)); assertThat(result).hasValue(artifact2); } diff --git a/core/src/test/java/com/google/adk/flows/llmflows/InstructionsTest.java b/core/src/test/java/com/google/adk/flows/llmflows/InstructionsTest.java index 2ac9e454d..90f710856 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/InstructionsTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/InstructionsTest.java @@ -32,7 +32,6 @@ import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; -import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import org.junit.Before; import org.junit.Rule; @@ -122,11 +121,7 @@ public void processRequest_agentInstructionString_noPlaceholders_appendsInstruct Session session = createSession(); Part artifactPart = Part.fromText("Artifact content"); when(mockArtifactService.loadArtifact( - eq(session.appName()), - eq(session.userId()), - eq(session.id()), - eq("file.txt"), - eq(Optional.empty()))) + eq(session.appName()), eq(session.userId()), eq(session.id()), eq("file.txt"))) .thenReturn(Maybe.just(artifactPart)); LlmAgent agent = LlmAgent.builder().name("agent").instruction("File content: {artifact.file.txt}").build(); diff --git a/core/src/test/java/com/google/adk/tools/LoadArtifactsToolTest.java b/core/src/test/java/com/google/adk/tools/LoadArtifactsToolTest.java index 5ed7a1f40..8405cfc42 100644 --- a/core/src/test/java/com/google/adk/tools/LoadArtifactsToolTest.java +++ b/core/src/test/java/com/google/adk/tools/LoadArtifactsToolTest.java @@ -1,7 +1,7 @@ package com.google.adk.tools; import static com.google.common.truth.Truth.assertThat; -import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.nullable; import static org.mockito.Mockito.mock; @@ -105,7 +105,7 @@ public void processLlmRequest_noArtifactsInContext_completesWithoutLoading() { assertThat(finalRequest.config()).isPresent(); assertThat(finalRequest.config().get().systemInstruction()).isEmpty(); verify(mockArtifactService, never()) - .loadArtifact(anyString(), anyString(), anyString(), anyString(), any()); + .loadArtifact(anyString(), anyString(), anyString(), anyString(), anyInt()); } @Test @@ -130,7 +130,7 @@ public void processLlmRequest_artifactsInContext_noFunctionCall_appendsInstructi assertThat(appendedInstruction).contains("call the `load_artifacts` function"); verify(mockArtifactService, never()) - .loadArtifact(anyString(), anyString(), anyString(), anyString(), any()); + .loadArtifact(anyString(), anyString(), anyString(), anyString(), anyInt()); } @Test @@ -215,7 +215,7 @@ public void processLlmRequest_artifactsInContext_withOtherFunctionCall_doesNotLo .contains("You have a list of artifacts:"); verify(mockArtifactService, never()) - .loadArtifact(anyString(), anyString(), anyString(), anyString(), any()); + .loadArtifact(anyString(), anyString(), anyString(), anyString(), anyInt()); assertThat(finalRequest.contents()).containsExactly(functionCallContent); } } diff --git a/dev/src/main/java/com/google/adk/web/controller/ArtifactController.java b/dev/src/main/java/com/google/adk/web/controller/ArtifactController.java index 27164f216..c181ab558 100644 --- a/dev/src/main/java/com/google/adk/web/controller/ArtifactController.java +++ b/dev/src/main/java/com/google/adk/web/controller/ArtifactController.java @@ -24,7 +24,6 @@ import io.reactivex.rxjava3.core.Single; import java.util.Collections; import java.util.List; -import java.util.Optional; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; @@ -78,8 +77,7 @@ public Part loadArtifact( versionStr); Maybe artifactMaybe = - artifactService.loadArtifact( - appName, userId, sessionId, artifactName, Optional.ofNullable(version)); + artifactService.loadArtifact(appName, userId, sessionId, artifactName, version); Part artifact = artifactMaybe.blockingGet(); @@ -126,8 +124,7 @@ public Part loadArtifactVersion( versionId); Maybe artifactMaybe = - artifactService.loadArtifact( - appName, userId, sessionId, artifactName, Optional.of(versionId)); + artifactService.loadArtifact(appName, userId, sessionId, artifactName, versionId); Part artifact = artifactMaybe.blockingGet(); From 7cce374286cf7b5bf7672aef36eed7faf4a5b045 Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Fri, 13 Mar 2026 05:25:56 -0700 Subject: [PATCH 40/50] chore: add explicit jackson-annotations dependency in core PiperOrigin-RevId: 883110228 --- core/pom.xml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/core/pom.xml b/core/pom.xml index a0f843f56..36ab783dd 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -92,6 +92,10 @@ com.google.errorprone error_prone_annotations + + com.fasterxml.jackson.core + jackson-annotations + com.fasterxml.jackson.core jackson-databind From 910d727f1981498151dea4cb91b9e5836f91e3ba Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Fri, 13 Mar 2026 05:39:37 -0700 Subject: [PATCH 41/50] feat!: refactor ApiClient constructors hierarchy to remove Optional parameters PiperOrigin-RevId: 883114754 --- .../com/google/adk/sessions/ApiClient.java | 94 ++++++++----------- .../google/adk/sessions/HttpApiClient.java | 25 +++-- .../google/adk/sessions/VertexAiClient.java | 11 +-- .../adk/sessions/VertexAiSessionService.java | 6 +- 4 files changed, 59 insertions(+), 77 deletions(-) diff --git a/core/src/main/java/com/google/adk/sessions/ApiClient.java b/core/src/main/java/com/google/adk/sessions/ApiClient.java index e850199e9..1b0485dd2 100644 --- a/core/src/main/java/com/google/adk/sessions/ApiClient.java +++ b/core/src/main/java/com/google/adk/sessions/ApiClient.java @@ -16,11 +16,11 @@ package com.google.adk.sessions; -import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.StandardSystemProperty.JAVA_VERSION; import com.google.auth.oauth2.GoogleCredentials; import com.google.common.base.Ascii; +import com.google.common.base.Strings; import com.google.common.collect.ImmutableMap; import com.google.genai.errors.GenAiIOException; import com.google.genai.types.HttpOptions; @@ -35,83 +35,69 @@ abstract class ApiClient { OkHttpClient httpClient; // For Google AI APIs - final Optional apiKey; + final @Nullable String apiKey; // For Vertex AI APIs - final Optional project; - final Optional location; - final Optional credentials; + final @Nullable String project; + final @Nullable String location; + final @Nullable GoogleCredentials credentials; HttpOptions httpOptions; final boolean vertexAI; /** Constructs an ApiClient for Google AI APIs. */ - ApiClient(Optional apiKey, Optional customHttpOptions) { - checkNotNull(apiKey, "API Key cannot be null"); - checkNotNull(customHttpOptions, "customHttpOptions cannot be null"); + ApiClient(@Nullable String apiKey, @Nullable HttpOptions customHttpOptions) { - try { - this.apiKey = Optional.of(apiKey.orElseGet(() -> System.getenv("GOOGLE_API_KEY"))); - } catch (NullPointerException e) { + this.apiKey = apiKey != null ? apiKey : System.getenv("GOOGLE_API_KEY"); + + if (Strings.isNullOrEmpty(this.apiKey)) { throw new IllegalArgumentException( - "API key must either be provided or set in the environment variable" + " GOOGLE_API_KEY.", - e); + "API key must either be provided or set in the environment variable" + + " GOOGLE_API_KEY."); } - this.project = Optional.empty(); - this.location = Optional.empty(); - this.credentials = Optional.empty(); + this.project = null; + this.location = null; + this.credentials = null; this.vertexAI = false; this.httpOptions = defaultHttpOptions(/* vertexAI= */ false, this.location); - if (customHttpOptions.isPresent()) { - applyHttpOptions(customHttpOptions.get()); + if (customHttpOptions != null) { + applyHttpOptions(customHttpOptions); } this.httpClient = createHttpClient(httpOptions.timeout().orElse(null)); } ApiClient( - Optional project, - Optional location, - Optional credentials, - Optional customHttpOptions) { - checkNotNull(project, "project cannot be null"); - checkNotNull(location, "location cannot be null"); - checkNotNull(credentials, "credentials cannot be null"); - checkNotNull(customHttpOptions, "customHttpOptions cannot be null"); + @Nullable String project, + @Nullable String location, + @Nullable GoogleCredentials credentials, + @Nullable HttpOptions customHttpOptions) { - try { - this.project = Optional.of(project.orElseGet(() -> System.getenv("GOOGLE_CLOUD_PROJECT"))); - } catch (NullPointerException e) { + this.project = project != null ? project : System.getenv("GOOGLE_CLOUD_PROJECT"); + + if (Strings.isNullOrEmpty(this.project)) { throw new IllegalArgumentException( "Project must either be provided or set in the environment variable" - + " GOOGLE_CLOUD_PROJECT.", - e); - } - if (this.project.get().isEmpty()) { - throw new IllegalArgumentException("Project must not be empty."); + + " GOOGLE_CLOUD_PROJECT."); } - try { - this.location = Optional.of(location.orElse(System.getenv("GOOGLE_CLOUD_LOCATION"))); - } catch (NullPointerException e) { + this.location = location != null ? location : System.getenv("GOOGLE_CLOUD_LOCATION"); + + if (Strings.isNullOrEmpty(this.location)) { throw new IllegalArgumentException( "Location must either be provided or set in the environment variable" - + " GOOGLE_CLOUD_LOCATION.", - e); - } - if (this.location.get().isEmpty()) { - throw new IllegalArgumentException("Location must not be empty."); + + " GOOGLE_CLOUD_LOCATION."); } - this.credentials = Optional.of(credentials.orElseGet(this::defaultCredentials)); + this.credentials = credentials != null ? credentials : defaultCredentials(); this.httpOptions = defaultHttpOptions(/* vertexAI= */ true, this.location); - if (customHttpOptions.isPresent()) { - applyHttpOptions(customHttpOptions.get()); + if (customHttpOptions != null) { + applyHttpOptions(customHttpOptions); } - this.apiKey = Optional.empty(); + this.apiKey = null; this.vertexAI = true; this.httpClient = createHttpClient(httpOptions.timeout().orElse(null)); } @@ -142,17 +128,17 @@ public boolean vertexAI() { /** Returns the project ID for Vertex AI APIs. */ public @Nullable String project() { - return project.orElse(null); + return project; } /** Returns the location for Vertex AI APIs. */ public @Nullable String location() { - return location.orElse(null); + return location; } /** Returns the API key for Google AI APIs. */ public @Nullable String apiKey() { - return apiKey.orElse(null); + return apiKey; } /** Returns the HttpClient for API calls. */ @@ -192,7 +178,7 @@ private void applyHttpOptions(HttpOptions httpOptionsToApply) { this.httpOptions = mergedHttpOptionsBuilder.build(); } - static HttpOptions defaultHttpOptions(boolean vertexAI, Optional location) { + static HttpOptions defaultHttpOptions(boolean vertexAI, @Nullable String location) { ImmutableMap.Builder defaultHeaders = ImmutableMap.builder(); defaultHeaders .put("Content-Type", "application/json") @@ -202,14 +188,14 @@ static HttpOptions defaultHttpOptions(boolean vertexAI, Optional locatio HttpOptions.Builder defaultHttpOptionsBuilder = HttpOptions.builder().headers(defaultHeaders.buildOrThrow()); - if (vertexAI && location.isPresent()) { + if (vertexAI && location != null) { defaultHttpOptionsBuilder .baseUrl( - Ascii.equalsIgnoreCase(location.get(), "global") + Ascii.equalsIgnoreCase(location, "global") ? "https://aiplatform.googleapis.com" - : String.format("https://%s-aiplatform.googleapis.com", location.get())) + : String.format("https://%s-aiplatform.googleapis.com", location)) .apiVersion("v1beta1"); - } else if (vertexAI && location.isEmpty()) { + } else if (vertexAI && Strings.isNullOrEmpty(location)) { throw new IllegalArgumentException("Location must be provided for Vertex AI APIs."); } else { defaultHttpOptionsBuilder diff --git a/core/src/main/java/com/google/adk/sessions/HttpApiClient.java b/core/src/main/java/com/google/adk/sessions/HttpApiClient.java index bba39da89..3ddb97bda 100644 --- a/core/src/main/java/com/google/adk/sessions/HttpApiClient.java +++ b/core/src/main/java/com/google/adk/sessions/HttpApiClient.java @@ -18,16 +18,17 @@ import com.google.auth.oauth2.GoogleCredentials; import com.google.common.base.Ascii; +import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableMap; import com.google.genai.errors.GenAiIOException; import com.google.genai.types.HttpOptions; import java.io.IOException; import java.util.Map; -import java.util.Optional; import okhttp3.MediaType; import okhttp3.Request; import okhttp3.RequestBody; import okhttp3.Response; +import org.jspecify.annotations.Nullable; /** Base client for the HTTP APIs. */ public class HttpApiClient extends ApiClient { @@ -35,16 +36,16 @@ public class HttpApiClient extends ApiClient { MediaType.parse("application/json; charset=utf-8"); /** Constructs an ApiClient for Google AI APIs. */ - HttpApiClient(Optional apiKey, Optional httpOptions) { + HttpApiClient(@Nullable String apiKey, @Nullable HttpOptions httpOptions) { super(apiKey, httpOptions); } /** Constructs an ApiClient for Vertex AI APIs. */ HttpApiClient( - Optional project, - Optional location, - Optional credentials, - Optional httpOptions) { + @Nullable String project, + @Nullable String location, + @Nullable GoogleCredentials credentials, + @Nullable HttpOptions httpOptions) { super(project, location, credentials, httpOptions); } @@ -54,9 +55,7 @@ public ApiResponse request(String httpMethod, String path, String requestJson) { boolean queryBaseModel = Ascii.equalsIgnoreCase(httpMethod, "GET") && path.startsWith("publishers/google/models/"); if (this.vertexAI() && !path.startsWith("projects/") && !queryBaseModel) { - path = - String.format("projects/%s/locations/%s/", this.project.get(), this.location.get()) - + path; + path = String.format("projects/%s/locations/%s/", this.project, this.location) + path; } String requestUrl = String.format( @@ -85,11 +84,11 @@ private void setHeaders(Request.Builder requestBuilder) { requestBuilder.header(header.getKey(), header.getValue()); } - if (apiKey.isPresent()) { - requestBuilder.header("x-goog-api-key", apiKey.get()); + if (apiKey != null) { + requestBuilder.header("x-goog-api-key", apiKey); } else { - GoogleCredentials cred = - credentials.orElseThrow(() -> new IllegalStateException("credentials is required")); + Preconditions.checkState(credentials != null, "credentials is required"); + GoogleCredentials cred = credentials; try { cred.refreshIfExpired(); } catch (IOException e) { diff --git a/core/src/main/java/com/google/adk/sessions/VertexAiClient.java b/core/src/main/java/com/google/adk/sessions/VertexAiClient.java index 718738b92..1168d1166 100644 --- a/core/src/main/java/com/google/adk/sessions/VertexAiClient.java +++ b/core/src/main/java/com/google/adk/sessions/VertexAiClient.java @@ -17,7 +17,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.concurrent.TimeoutException; import javax.annotation.Nullable; import okhttp3.ResponseBody; @@ -37,17 +36,15 @@ final class VertexAiClient { } VertexAiClient() { - this.apiClient = - new HttpApiClient(Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); + this.apiClient = new HttpApiClient((String) null, null, null, null); } VertexAiClient( String project, String location, - Optional credentials, - Optional httpOptions) { - this.apiClient = - new HttpApiClient(Optional.of(project), Optional.of(location), credentials, httpOptions); + @Nullable GoogleCredentials credentials, + @Nullable HttpOptions httpOptions) { + this.apiClient = new HttpApiClient(project, location, credentials, httpOptions); } Maybe createSession( diff --git a/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java b/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java index 2fff7a752..4336f96c9 100644 --- a/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java +++ b/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java @@ -40,7 +40,7 @@ import java.util.concurrent.ConcurrentMap; import java.util.regex.Matcher; import java.util.regex.Pattern; -import javax.annotation.Nullable; +import org.jspecify.annotations.Nullable; /** Connects to the managed Vertex AI Session Service. */ // TODO: Use the genai HttpApiClient and ApiResponse methods once they are public. @@ -65,8 +65,8 @@ public VertexAiSessionService() { public VertexAiSessionService( String project, String location, - Optional credentials, - Optional httpOptions) { + @Nullable GoogleCredentials credentials, + @Nullable HttpOptions httpOptions) { this.client = new VertexAiClient(project, location, credentials, httpOptions); } From a47b651b5c4868a603fd79df164b70bc712c3a80 Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Fri, 13 Mar 2026 07:06:00 -0700 Subject: [PATCH 42/50] chore: override new version to 0.9.0 Release-As: 0.9.0 PiperOrigin-RevId: 883145624 --- .release-please-manifest.json | 1 + 1 file changed, 1 insertion(+) diff --git a/.release-please-manifest.json b/.release-please-manifest.json index d6a5f76bd..989b42066 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,4 @@ { ".": "0.8.0" } + From 94e671560f68b3a34ccc64d96544d2feed4a952d Mon Sep 17 00:00:00 2001 From: adk-java-releases-bot Date: Fri, 13 Mar 2026 15:08:29 +0100 Subject: [PATCH 43/50] chore(main): release 0.9.0 --- .release-please-manifest.json | 2 +- CHANGELOG.md | 56 +++++++++++++++++++ README.md | 4 +- a2a/pom.xml | 2 +- contrib/firestore-session-service/pom.xml | 2 +- contrib/langchain4j/pom.xml | 2 +- contrib/samples/a2a_basic/pom.xml | 2 +- contrib/samples/a2a_server/pom.xml | 2 +- contrib/samples/configagent/pom.xml | 2 +- contrib/samples/helloworld/pom.xml | 2 +- contrib/samples/mcpfilesystem/pom.xml | 2 +- contrib/samples/pom.xml | 2 +- contrib/spring-ai/pom.xml | 2 +- core/pom.xml | 2 +- .../src/main/java/com/google/adk/Version.java | 2 +- dev/pom.xml | 2 +- maven_plugin/examples/custom_tools/pom.xml | 2 +- maven_plugin/examples/simple-agent/pom.xml | 2 +- maven_plugin/pom.xml | 2 +- pom.xml | 2 +- tutorials/city-time-weather/pom.xml | 2 +- tutorials/live-audio-single-agent/pom.xml | 2 +- 22 files changed, 78 insertions(+), 22 deletions(-) diff --git a/.release-please-manifest.json b/.release-please-manifest.json index 989b42066..b0f3ba770 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,4 +1,4 @@ { - ".": "0.8.0" + ".": "0.9.0" } diff --git a/CHANGELOG.md b/CHANGELOG.md index 6d5b9e5eb..ab111e90c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,61 @@ # Changelog +## [0.9.0](https://github.com/google/adk-java/compare/v0.8.0...v0.9.0) (2026-03-13) + + +### ⚠ BREAKING CHANGES + +* refactor ApiClient constructors hierarchy to remove Optional parameters +* remove deprecated LlmAgent.canonicalTools method +* remove deprecated LoadArtifactsTool.loadArtifacts method +* update LoopAgent's maxIteration field and methods to be @Nullable instead of Optional +* Remove Optional parameters in EventActions +* remove deprecated url method in ComputerState.Builder +* Remove deprecated create method in ResponseProcessor +* remove McpAsyncToolset constructors +* use @Nullable fields in Event class +* remove methods with Optional params from VertexCredential.Builder + +### Features + +* add formatting to the RemoteA2A agent so it filters out the previous agent responses and updates the context of the function calls and responses ([0d6dd55](https://github.com/google/adk-java/commit/0d6dd55f4870007e79db23e21bd261879dbfba79)) +* add multiple LLM responses to LLM recordings for conformance tests ([bdfb7a7](https://github.com/google/adk-java/commit/bdfb7a72188ce6e72c12c16c0abedb824b846160)) +* add support for gemini models in VertexAiRagRetrieval ([924fb71](https://github.com/google/adk-java/commit/924fb7174855b46a58be43373c1a29284c47dfa8)) +* Fixing the spans produced by agent calls to have the right parent spans ([3c8f488](https://github.com/google/adk-java/commit/3c8f4886f0e4c76abdbeb64a348bfccd5c16120e)) +* Fixing the spans produced by agent calls to have the right parent spans ([973f887](https://github.com/google/adk-java/commit/973f88743cabebcd2e6e7a8d5f141142b596dbbb)) +* refactor ApiClient constructors hierarchy to remove Optional parameters ([910d727](https://github.com/google/adk-java/commit/910d727f1981498151dea4cb91b9e5836f91e3ba)) +* Remove deprecated create method in ResponseProcessor ([5e1e1d4](https://github.com/google/adk-java/commit/5e1e1d434fa1f3931af30194422800757de96cb6)) +* remove deprecated LlmAgent.canonicalTools method ([aabf15a](https://github.com/google/adk-java/commit/aabf15a526ba525cdb47c74c246c178eff1851d5)) +* remove deprecated LoadArtifactsTool.loadArtifacts method ([bc38558](https://github.com/google/adk-java/commit/bc385589057a6daf0209a335280bf19d20b2126b)) +* remove deprecated url method in ComputerState.Builder ([a86ede0](https://github.com/google/adk-java/commit/a86ede007c3442ed73ee08a5c6ad0e2efa12998a)) +* remove executionId method that takes Optional param from CodeExecutionUtils ([be3b3f8](https://github.com/google/adk-java/commit/be3b3f8360888ea1f13796969bb19893c32727e0)) +* remove McpAsyncToolset constructors ([82ef5ac](https://github.com/google/adk-java/commit/82ef5ac2689e01676aa95d2616e3b4d8463e573e)) +* remove methods with Optional params from VertexCredential.Builder ([0b9057c](https://github.com/google/adk-java/commit/0b9057c9ccab98ea58597ec55b8168e32ac7c9a6)) +* Remove Optional parameters in EventActions ([b8316b1](https://github.com/google/adk-java/commit/b8316b1944ce17cc9208963cc09d900c379444c6)) +* replace Optional type of version in BaseArtifactService.loadArtifact with Nullable ([5fd4c53](https://github.com/google/adk-java/commit/5fd4c53c88e977d004b9eee8fa3697625ec85f47)) +* Trigger traceCallLlm to set call_llm attributes before span ends ([d9d84ee](https://github.com/google/adk-java/commit/d9d84ee67406cce8eeb66abcf1be24fad9c58e29)) +* Update converters for task and artifact events; add long running tools ids ([9ce78d7](https://github.com/google/adk-java/commit/9ce78d7c3e1b0fb6d8d4fdce9052a572ffb9e515)) +* update LoopAgent's maxIteration field and methods to be @Nullable instead of Optional ([e0d833b](https://github.com/google/adk-java/commit/e0d833b337e958e299d0d11a03f6bfa1468731bc)) +* update return type for artifactDelta getter and setter to Map from ConcurrentMap ([d1d5539](https://github.com/google/adk-java/commit/d1d5539ef763b6bfd5057c6ea0f2591225a98535)) +* update return type for requestedToolConfirmations getter and setter to Map from ConcurrentMap ([143b656](https://github.com/google/adk-java/commit/143b656949d61363d135e0b74ef5696e78eb270a)) +* update return type for stateDelta() to Map from ConcurrentMap ([3f6504e](https://github.com/google/adk-java/commit/3f6504e9416f9f644ef431e612ec983b9a2edd9d)) +* update State constructors to accept general Map types ([c6fdb63](https://github.com/google/adk-java/commit/c6fdb63c92e2f3481a01cfeafa946b6dce728c51)) +* use @Nullable fields in Event class ([67b602f](https://github.com/google/adk-java/commit/67b602f245f564238ea22298a37bf70049e56a12)) + + +### Bug Fixes + +* Explicitly setting the otel parent spans in agents, llm flow and function calls ([20f863f](https://github.com/google/adk-java/commit/20f863f716f653979551c481d85d4e7fa56a35da)) +* Make sure that `InvocationContext.callbackContextData` remains the same instance ([14ee28b](https://github.com/google/adk-java/commit/14ee28ba593a9f6f5f7b9bb6003441539fe33a18)) +* Removing deprecated InvocationContext methods ([41f5af0](https://github.com/google/adk-java/commit/41f5af0dceb78501ca8b94e434e4d751f608a699)) +* Removing deprecated methods in Runner ([0d8e22d](https://github.com/google/adk-java/commit/0d8e22d6e9fe4e8d29c87d485915ba51a22eb350)) +* Removing deprecated methods in Runner ([b857f01](https://github.com/google/adk-java/commit/b857f010a0f51df0eb25ecdc364465ffdd9fef65)) + + +### Miscellaneous Chores + +* override new version to 0.9.0 ([a47b651](https://github.com/google/adk-java/commit/a47b651b5c4868a603fd79df164b70bc712c3a80)) + ## [0.8.0](https://github.com/google/adk-java/compare/v0.7.0...v0.8.0) (2026-03-06) diff --git a/README.md b/README.md index 4a5dab81f..de1cfbef7 100644 --- a/README.md +++ b/README.md @@ -50,13 +50,13 @@ If you're using Maven, add the following to your dependencies: com.google.adk google-adk - 0.8.0 + 0.9.0 com.google.adk google-adk-dev - 0.8.0 + 0.9.0 ``` diff --git a/a2a/pom.xml b/a2a/pom.xml index 5857720fd..1a6da6e41 100644 --- a/a2a/pom.xml +++ b/a2a/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-parent - 0.8.1-SNAPSHOT + 0.9.0 google-adk-a2a diff --git a/contrib/firestore-session-service/pom.xml b/contrib/firestore-session-service/pom.xml index a62bff5b6..f0e1f0d0e 100644 --- a/contrib/firestore-session-service/pom.xml +++ b/contrib/firestore-session-service/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.8.1-SNAPSHOT + 0.9.0 ../../pom.xml diff --git a/contrib/langchain4j/pom.xml b/contrib/langchain4j/pom.xml index e2ba4a7fb..a180c2e69 100644 --- a/contrib/langchain4j/pom.xml +++ b/contrib/langchain4j/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.8.1-SNAPSHOT + 0.9.0 ../../pom.xml diff --git a/contrib/samples/a2a_basic/pom.xml b/contrib/samples/a2a_basic/pom.xml index 82b11b96f..708debeb6 100644 --- a/contrib/samples/a2a_basic/pom.xml +++ b/contrib/samples/a2a_basic/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-samples - 0.8.1-SNAPSHOT + 0.9.0 .. diff --git a/contrib/samples/a2a_server/pom.xml b/contrib/samples/a2a_server/pom.xml index 84023e260..3fc0141fe 100644 --- a/contrib/samples/a2a_server/pom.xml +++ b/contrib/samples/a2a_server/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-samples - 0.8.1-SNAPSHOT + 0.9.0 .. diff --git a/contrib/samples/configagent/pom.xml b/contrib/samples/configagent/pom.xml index 6f7bfff83..61f9f3010 100644 --- a/contrib/samples/configagent/pom.xml +++ b/contrib/samples/configagent/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-samples - 0.8.1-SNAPSHOT + 0.9.0 .. diff --git a/contrib/samples/helloworld/pom.xml b/contrib/samples/helloworld/pom.xml index 36d12eaf0..d0a2cf686 100644 --- a/contrib/samples/helloworld/pom.xml +++ b/contrib/samples/helloworld/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-samples - 0.8.1-SNAPSHOT + 0.9.0 .. diff --git a/contrib/samples/mcpfilesystem/pom.xml b/contrib/samples/mcpfilesystem/pom.xml index 935aa6531..ab606a392 100644 --- a/contrib/samples/mcpfilesystem/pom.xml +++ b/contrib/samples/mcpfilesystem/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.8.1-SNAPSHOT + 0.9.0 ../../.. diff --git a/contrib/samples/pom.xml b/contrib/samples/pom.xml index 905f8e711..7ce31df73 100644 --- a/contrib/samples/pom.xml +++ b/contrib/samples/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-parent - 0.8.1-SNAPSHOT + 0.9.0 ../.. diff --git a/contrib/spring-ai/pom.xml b/contrib/spring-ai/pom.xml index f49c3faae..c2672d359 100644 --- a/contrib/spring-ai/pom.xml +++ b/contrib/spring-ai/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.8.1-SNAPSHOT + 0.9.0 ../../pom.xml diff --git a/core/pom.xml b/core/pom.xml index 36ab783dd..059e481a9 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.8.1-SNAPSHOT + 0.9.0 google-adk diff --git a/core/src/main/java/com/google/adk/Version.java b/core/src/main/java/com/google/adk/Version.java index 1dc0282c3..a7aeb8b1f 100644 --- a/core/src/main/java/com/google/adk/Version.java +++ b/core/src/main/java/com/google/adk/Version.java @@ -22,7 +22,7 @@ */ public final class Version { // Don't touch this, release-please should keep it up to date. - public static final String JAVA_ADK_VERSION = "0.8.0"; // x-release-please-released-version + public static final String JAVA_ADK_VERSION = "0.9.0"; // x-release-please-released-version private Version() {} } diff --git a/dev/pom.xml b/dev/pom.xml index 57aa808c2..ac820d4f6 100644 --- a/dev/pom.xml +++ b/dev/pom.xml @@ -18,7 +18,7 @@ com.google.adk google-adk-parent - 0.8.1-SNAPSHOT + 0.9.0 google-adk-dev diff --git a/maven_plugin/examples/custom_tools/pom.xml b/maven_plugin/examples/custom_tools/pom.xml index abd3c60f2..054fb5ff3 100644 --- a/maven_plugin/examples/custom_tools/pom.xml +++ b/maven_plugin/examples/custom_tools/pom.xml @@ -4,7 +4,7 @@ com.example custom-tools-example - 0.8.1-SNAPSHOT + 0.9.0 jar ADK Custom Tools Example diff --git a/maven_plugin/examples/simple-agent/pom.xml b/maven_plugin/examples/simple-agent/pom.xml index 309fe9364..25c783909 100644 --- a/maven_plugin/examples/simple-agent/pom.xml +++ b/maven_plugin/examples/simple-agent/pom.xml @@ -4,7 +4,7 @@ com.example simple-adk-agent - 0.8.1-SNAPSHOT + 0.9.0 jar Simple ADK Agent Example diff --git a/maven_plugin/pom.xml b/maven_plugin/pom.xml index 6ff3404f3..9c13f2110 100644 --- a/maven_plugin/pom.xml +++ b/maven_plugin/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-parent - 0.8.1-SNAPSHOT + 0.9.0 ../pom.xml diff --git a/pom.xml b/pom.xml index ffe904d74..14bc61131 100644 --- a/pom.xml +++ b/pom.xml @@ -17,7 +17,7 @@ com.google.adk google-adk-parent - 0.8.1-SNAPSHOT + 0.9.0 pom Google Agent Development Kit Maven Parent POM diff --git a/tutorials/city-time-weather/pom.xml b/tutorials/city-time-weather/pom.xml index f4e8bdb52..9c6fde79e 100644 --- a/tutorials/city-time-weather/pom.xml +++ b/tutorials/city-time-weather/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.8.1-SNAPSHOT + 0.9.0 ../../pom.xml diff --git a/tutorials/live-audio-single-agent/pom.xml b/tutorials/live-audio-single-agent/pom.xml index b6e649222..3a663874c 100644 --- a/tutorials/live-audio-single-agent/pom.xml +++ b/tutorials/live-audio-single-agent/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.8.1-SNAPSHOT + 0.9.0 ../../pom.xml From 0d2c37c19044c1b8bf86125e7e4deec4b1293028 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 13 Mar 2026 07:44:44 -0700 Subject: [PATCH 44/50] feat: Introducing Tracing.withContext() Tracing.withContext() will improve how RxJava + Tracing works. Here's an example of how it will be used: ``` this.pluginManager .onUserMessageCallback(initialContext, newMessage) .compose(Tracing.withContext(capturedContext)) ``` The `.compose()` is a standin for calling `capturedContext.makeCurrent()` in the event handler. PiperOrigin-RevId: 883160268 --- .../com/google/adk/telemetry/Tracing.java | 187 ++++++++++++++++++ .../adk/telemetry/ContextPropagationTest.java | 83 +++++++- 2 files changed, 267 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/com/google/adk/telemetry/Tracing.java b/core/src/main/java/com/google/adk/telemetry/Tracing.java index 35bf3cc96..215e317e1 100644 --- a/core/src/main/java/com/google/adk/telemetry/Tracing.java +++ b/core/src/main/java/com/google/adk/telemetry/Tracing.java @@ -37,16 +37,20 @@ import io.opentelemetry.context.Context; import io.opentelemetry.context.Scope; import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.CompletableObserver; import io.reactivex.rxjava3.core.CompletableSource; import io.reactivex.rxjava3.core.CompletableTransformer; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.FlowableTransformer; import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.MaybeObserver; import io.reactivex.rxjava3.core.MaybeSource; import io.reactivex.rxjava3.core.MaybeTransformer; import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.core.SingleObserver; import io.reactivex.rxjava3.core.SingleSource; import io.reactivex.rxjava3.core.SingleTransformer; +import io.reactivex.rxjava3.disposables.Disposable; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -58,6 +62,8 @@ import java.util.function.Consumer; import java.util.function.Supplier; import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -550,4 +556,185 @@ public CompletableSource apply(Completable upstream) { }); } } + + /** + * Returns a transformer that re-activates a given context for the duration of the stream's + * subscription. + * + * @param context The context to re-activate. + * @param The type of the stream. + * @return A transformer that re-activates the context. + */ + public static ContextTransformer withContext(Context context) { + return new ContextTransformer<>(context); + } + + /** + * A transformer that re-activates a given context for the duration of the stream's subscription. + * + * @param The type of the stream. + */ + public static final class ContextTransformer + implements FlowableTransformer, + SingleTransformer, + MaybeTransformer, + CompletableTransformer { + private final Context context; + + private ContextTransformer(Context context) { + this.context = context; + } + + @Override + public Publisher apply(Flowable upstream) { + return upstream.lift(subscriber -> TracingObserver.wrap(context, subscriber)); + } + + @Override + public SingleSource apply(Single upstream) { + return upstream.lift(observer -> TracingObserver.wrap(context, observer)); + } + + @Override + public MaybeSource apply(Maybe upstream) { + return upstream.lift(observer -> TracingObserver.wrap(context, observer)); + } + + @Override + public CompletableSource apply(Completable upstream) { + return upstream.lift(observer -> TracingObserver.wrap(context, observer)); + } + } + + /** + * An observer that wraps another observer and ensures that the OpenTelemetry context is active + * during all callback methods. + * + *

This implementation only wraps the data-flow callbacks (`onNext`, `onSuccess`, etc.). The + * `Subscription.request/cancel` and `Disposable.dispose` calls are not wrapped in the context. If + * the upstream logic depends on the context during these signals, they might lose trace + * information. Given this is a manual `withContext` utility, this might be an acceptable + * trade-off for simplicity/performance, but worth keeping in mind. + * + * @param The type of the items emitted by the stream. + */ + private static final class TracingObserver + implements Subscriber, SingleObserver, MaybeObserver, CompletableObserver { + private final Context context; + private final Subscriber subscriber; + private final SingleObserver singleObserver; + private final MaybeObserver maybeObserver; + private final CompletableObserver completableObserver; + + private TracingObserver( + Context context, + Subscriber subscriber, + SingleObserver singleObserver, + MaybeObserver maybeObserver, + CompletableObserver completableObserver) { + this.context = context; + this.subscriber = subscriber; + this.singleObserver = singleObserver; + this.maybeObserver = maybeObserver; + this.completableObserver = completableObserver; + } + + static TracingObserver wrap(Context context, Subscriber subscriber) { + return new TracingObserver<>(context, subscriber, null, null, null); + } + + static TracingObserver wrap(Context context, SingleObserver observer) { + return new TracingObserver<>(context, null, observer, null, null); + } + + static TracingObserver wrap(Context context, MaybeObserver observer) { + return new TracingObserver<>(context, null, null, observer, null); + } + + static TracingObserver wrap(Context context, CompletableObserver observer) { + return new TracingObserver<>(context, null, null, null, observer); + } + + private void runInContext(Runnable action) { + try (Scope scope = context.makeCurrent()) { + action.run(); + } + } + + @Override + public void onSubscribe(Subscription s) { + runInContext( + () -> { + if (subscriber != null) { + subscriber.onSubscribe(s); + } + }); + } + + @Override + public void onSubscribe(Disposable d) { + runInContext( + () -> { + if (singleObserver != null) { + singleObserver.onSubscribe(d); + } else if (maybeObserver != null) { + maybeObserver.onSubscribe(d); + } else if (completableObserver != null) { + completableObserver.onSubscribe(d); + } + }); + } + + @Override + public void onNext(T t) { + runInContext( + () -> { + if (subscriber != null) { + subscriber.onNext(t); + } + }); + } + + @Override + public void onSuccess(T t) { + runInContext( + () -> { + if (singleObserver != null) { + singleObserver.onSuccess(t); + } else if (maybeObserver != null) { + maybeObserver.onSuccess(t); + } + }); + } + + @Override + public void onError(Throwable t) { + runInContext( + () -> { + if (subscriber != null) { + subscriber.onError(t); + } else if (singleObserver != null) { + singleObserver.onError(t); + } else if (maybeObserver != null) { + maybeObserver.onError(t); + } else if (completableObserver != null) { + completableObserver.onError(t); + } + }); + } + + @Override + public void onComplete() { + runInContext( + () -> { + if (subscriber != null) { + subscriber.onComplete(); + } else if (maybeObserver != null) { + maybeObserver.onComplete(); + } else if (completableObserver != null) { + completableObserver.onComplete(); + } + }); + } + } } diff --git a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java index f809193cf..e5795d61f 100644 --- a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java +++ b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java @@ -31,6 +31,7 @@ import com.google.adk.runner.Runner; import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; +import com.google.adk.sessions.SessionKey; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; @@ -44,12 +45,17 @@ import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.context.Context; +import io.opentelemetry.context.ContextKey; import io.opentelemetry.context.Scope; import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; import io.opentelemetry.sdk.trace.data.SpanData; +import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.schedulers.Schedulers; import java.util.List; +import java.util.Map; import java.util.Optional; import org.junit.After; import org.junit.Before; @@ -380,6 +386,70 @@ public void testTraceFlowable() throws InterruptedException { assertTrue(flowableSpanData.hasEnded()); } + @Test + public void testWithContextFlowable() throws InterruptedException { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.root().with(testKey, "test-value"); + + Flowable flowable = + Flowable.just(1, 2, 3) + .compose(Tracing.withContext(testContext)) + .subscribeOn(Schedulers.computation()) + .doOnNext( + i -> { + assertEquals("test-value", Context.current().get(testKey)); + }); + flowable.test().await().assertComplete(); + } + + @Test + public void testWithContextSingle() throws InterruptedException { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.root().with(testKey, "test-value"); + + Single single = + Single.just(1) + .compose(Tracing.withContext(testContext)) + .subscribeOn(Schedulers.computation()) + .doOnSuccess( + i -> { + assertEquals("test-value", Context.current().get(testKey)); + }); + single.test().await().assertComplete(); + } + + @Test + public void testWithContextMaybe() throws InterruptedException { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.root().with(testKey, "test-value"); + + Maybe maybe = + Maybe.just(1) + .compose(Tracing.withContext(testContext)) + .subscribeOn(Schedulers.computation()) + .doOnSuccess( + i -> { + assertEquals("test-value", Context.current().get(testKey)); + }); + maybe.test().await().assertComplete(); + } + + @Test + public void testWithContextCompletable() throws InterruptedException { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.root().with(testKey, "test-value"); + + Completable completable = + Completable.complete() + .compose(Tracing.withContext(testContext)) + .subscribeOn(Schedulers.computation()) + .doOnComplete( + () -> { + assertEquals("test-value", Context.current().get(testKey)); + }); + completable.test().await().assertComplete(); + } + @Test public void testTraceTransformer() throws InterruptedException { Span parentSpan = tracer.spanBuilder("parent").startSpan(); @@ -595,7 +665,7 @@ public void runnerRunAsync_propagatesContext() throws InterruptedException { Session session = runner .sessionService() - .createSession("test-app", "test-user", null, "test-session") + .createSession(new SessionKey("test-app", "test-user", "test-session")) .blockingGet(); Content newMessage = Content.fromParts(Part.fromText("hi")); RunConfig runConfig = RunConfig.builder().build(); @@ -623,13 +693,20 @@ public void runnerRunLive_propagatesContext() throws InterruptedException { Span parentSpan = tracer.spanBuilder("parent").startSpan(); try (Scope s = parentSpan.makeCurrent()) { Session session = - Session.builder("test-session").userId("test-user").appName("test-app").build(); + runner + .sessionService() + .createSession("test-app", "test-user", (Map) null, "test-session") + .blockingGet(); Content newMessage = Content.fromParts(Part.fromText("hi")); RunConfig runConfig = RunConfig.builder().build(); LiveRequestQueue liveRequestQueue = new LiveRequestQueue(); liveRequestQueue.content(newMessage); liveRequestQueue.close(); - runner.runLive(session, liveRequestQueue, runConfig).test().await().assertComplete(); + runner + .runLive(session.userId(), session.id(), liveRequestQueue, runConfig) + .test() + .await() + .assertComplete(); } finally { parentSpan.end(); } From 743ee78e4fbd9b0f275e494d5549b9340675740b Mon Sep 17 00:00:00 2001 From: adk-java-releases-bot Date: Fri, 13 Mar 2026 15:52:18 +0100 Subject: [PATCH 45/50] chore(main): release 0.9.1-SNAPSHOT --- a2a/pom.xml | 2 +- contrib/firestore-session-service/pom.xml | 2 +- contrib/langchain4j/pom.xml | 2 +- contrib/samples/a2a_basic/pom.xml | 2 +- contrib/samples/a2a_server/pom.xml | 2 +- contrib/samples/configagent/pom.xml | 2 +- contrib/samples/helloworld/pom.xml | 2 +- contrib/samples/mcpfilesystem/pom.xml | 2 +- contrib/samples/pom.xml | 2 +- contrib/spring-ai/pom.xml | 2 +- core/pom.xml | 2 +- dev/pom.xml | 2 +- maven_plugin/examples/custom_tools/pom.xml | 2 +- maven_plugin/examples/simple-agent/pom.xml | 2 +- maven_plugin/pom.xml | 2 +- pom.xml | 2 +- tutorials/city-time-weather/pom.xml | 2 +- tutorials/live-audio-single-agent/pom.xml | 2 +- 18 files changed, 18 insertions(+), 18 deletions(-) diff --git a/a2a/pom.xml b/a2a/pom.xml index 1a6da6e41..a2f9d9456 100644 --- a/a2a/pom.xml +++ b/a2a/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-parent - 0.9.0 + 0.9.1-SNAPSHOT google-adk-a2a diff --git a/contrib/firestore-session-service/pom.xml b/contrib/firestore-session-service/pom.xml index f0e1f0d0e..0079dce24 100644 --- a/contrib/firestore-session-service/pom.xml +++ b/contrib/firestore-session-service/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.9.0 + 0.9.1-SNAPSHOT ../../pom.xml diff --git a/contrib/langchain4j/pom.xml b/contrib/langchain4j/pom.xml index a180c2e69..c2326fa0a 100644 --- a/contrib/langchain4j/pom.xml +++ b/contrib/langchain4j/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.9.0 + 0.9.1-SNAPSHOT ../../pom.xml diff --git a/contrib/samples/a2a_basic/pom.xml b/contrib/samples/a2a_basic/pom.xml index 708debeb6..0eccb733b 100644 --- a/contrib/samples/a2a_basic/pom.xml +++ b/contrib/samples/a2a_basic/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-samples - 0.9.0 + 0.9.1-SNAPSHOT .. diff --git a/contrib/samples/a2a_server/pom.xml b/contrib/samples/a2a_server/pom.xml index 3fc0141fe..0677ad718 100644 --- a/contrib/samples/a2a_server/pom.xml +++ b/contrib/samples/a2a_server/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-samples - 0.9.0 + 0.9.1-SNAPSHOT .. diff --git a/contrib/samples/configagent/pom.xml b/contrib/samples/configagent/pom.xml index 61f9f3010..059bd8a38 100644 --- a/contrib/samples/configagent/pom.xml +++ b/contrib/samples/configagent/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-samples - 0.9.0 + 0.9.1-SNAPSHOT .. diff --git a/contrib/samples/helloworld/pom.xml b/contrib/samples/helloworld/pom.xml index d0a2cf686..df5d5e709 100644 --- a/contrib/samples/helloworld/pom.xml +++ b/contrib/samples/helloworld/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-samples - 0.9.0 + 0.9.1-SNAPSHOT .. diff --git a/contrib/samples/mcpfilesystem/pom.xml b/contrib/samples/mcpfilesystem/pom.xml index ab606a392..16b139d35 100644 --- a/contrib/samples/mcpfilesystem/pom.xml +++ b/contrib/samples/mcpfilesystem/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.9.0 + 0.9.1-SNAPSHOT ../../.. diff --git a/contrib/samples/pom.xml b/contrib/samples/pom.xml index 7ce31df73..4a415113f 100644 --- a/contrib/samples/pom.xml +++ b/contrib/samples/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-parent - 0.9.0 + 0.9.1-SNAPSHOT ../.. diff --git a/contrib/spring-ai/pom.xml b/contrib/spring-ai/pom.xml index c2672d359..b24fa4b63 100644 --- a/contrib/spring-ai/pom.xml +++ b/contrib/spring-ai/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.9.0 + 0.9.1-SNAPSHOT ../../pom.xml diff --git a/core/pom.xml b/core/pom.xml index 059e481a9..8c3c2069c 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.9.0 + 0.9.1-SNAPSHOT google-adk diff --git a/dev/pom.xml b/dev/pom.xml index ac820d4f6..6cabcba7c 100644 --- a/dev/pom.xml +++ b/dev/pom.xml @@ -18,7 +18,7 @@ com.google.adk google-adk-parent - 0.9.0 + 0.9.1-SNAPSHOT google-adk-dev diff --git a/maven_plugin/examples/custom_tools/pom.xml b/maven_plugin/examples/custom_tools/pom.xml index 054fb5ff3..f2118f9cc 100644 --- a/maven_plugin/examples/custom_tools/pom.xml +++ b/maven_plugin/examples/custom_tools/pom.xml @@ -4,7 +4,7 @@ com.example custom-tools-example - 0.9.0 + 0.9.1-SNAPSHOT jar ADK Custom Tools Example diff --git a/maven_plugin/examples/simple-agent/pom.xml b/maven_plugin/examples/simple-agent/pom.xml index 25c783909..5c0f4462d 100644 --- a/maven_plugin/examples/simple-agent/pom.xml +++ b/maven_plugin/examples/simple-agent/pom.xml @@ -4,7 +4,7 @@ com.example simple-adk-agent - 0.9.0 + 0.9.1-SNAPSHOT jar Simple ADK Agent Example diff --git a/maven_plugin/pom.xml b/maven_plugin/pom.xml index 9c13f2110..c48331f72 100644 --- a/maven_plugin/pom.xml +++ b/maven_plugin/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-parent - 0.9.0 + 0.9.1-SNAPSHOT ../pom.xml diff --git a/pom.xml b/pom.xml index 14bc61131..11696db73 100644 --- a/pom.xml +++ b/pom.xml @@ -17,7 +17,7 @@ com.google.adk google-adk-parent - 0.9.0 + 0.9.1-SNAPSHOT pom Google Agent Development Kit Maven Parent POM diff --git a/tutorials/city-time-weather/pom.xml b/tutorials/city-time-weather/pom.xml index 9c6fde79e..76b7331f3 100644 --- a/tutorials/city-time-weather/pom.xml +++ b/tutorials/city-time-weather/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.9.0 + 0.9.1-SNAPSHOT ../../pom.xml diff --git a/tutorials/live-audio-single-agent/pom.xml b/tutorials/live-audio-single-agent/pom.xml index 3a663874c..a330cf4bd 100644 --- a/tutorials/live-audio-single-agent/pom.xml +++ b/tutorials/live-audio-single-agent/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.9.0 + 0.9.1-SNAPSHOT ../../pom.xml From b6356d27c4dfbafdaa5803cb766b57ec5f09091a Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Mon, 16 Mar 2026 06:03:42 -0700 Subject: [PATCH 46/50] chore: update mcp dependency version to 0.17.2 PiperOrigin-RevId: 884394270 --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 11696db73..bd0caca0d 100644 --- a/pom.xml +++ b/pom.xml @@ -49,7 +49,7 @@ cloud libraries. Once they update their otel dependencies we can consider updating ours here as well --> 1.51.0 - 0.14.0 + 0.17.2 2.47.0 1.41.0 4.33.5 From 7ebeb07bf2ee72475484d8a31ccf7b4c601dda96 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 16 Mar 2026 07:00:53 -0700 Subject: [PATCH 47/50] feat: init AGENTS.md file PiperOrigin-RevId: 884415542 --- AGENTS.md | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 AGENTS.md diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 000000000..e69de29bb From 567fdf048fee49afc86ca5d7d35f55424a6016ba Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Mon, 16 Mar 2026 09:11:14 -0700 Subject: [PATCH 48/50] fix: fix null handling in runAsyncImpl PiperOrigin-RevId: 884472852 --- .../java/com/google/adk/runner/Runner.java | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 29b2b76d3..5859c4786 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -458,6 +458,9 @@ protected Flowable runAsyncImpl( Content newMessage, RunConfig runConfig, @Nullable Map stateDelta) { + Preconditions.checkNotNull(session, "session cannot be null"); + Preconditions.checkNotNull(newMessage, "newMessage cannot be null"); + Preconditions.checkNotNull(runConfig, "runConfig cannot be null"); return Flowable.defer( () -> { BaseAgent rootAgent = this.agent; @@ -476,19 +479,14 @@ protected Flowable runAsyncImpl( .defaultIfEmpty(newMessage) .flatMap( content -> - (content != null) - ? appendNewMessageToSession( - session, - content, - initialContext, - runConfig.saveInputBlobsAsArtifacts(), - stateDelta) - : Single.just(null)) + appendNewMessageToSession( + session, + content, + initialContext, + runConfig.saveInputBlobsAsArtifacts(), + stateDelta)) .flatMapPublisher( event -> { - if (event == null) { - return Flowable.empty(); - } // Get the updated session after the message and state delta are // applied return this.sessionService From b8cb7e2db6d5ce20f4d7a1b237bdc155563cf4bd Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Mon, 16 Mar 2026 09:50:32 -0700 Subject: [PATCH 49/50] feat: add type-safe runAsync methods to BaseTool PiperOrigin-RevId: 884493553 --- .../java/com/google/adk/tools/BaseTool.java | 81 +++++++++++++ .../com/google/adk/tools/BaseToolTest.java | 108 ++++++++++++++++++ 2 files changed, 189 insertions(+) diff --git a/core/src/main/java/com/google/adk/tools/BaseTool.java b/core/src/main/java/com/google/adk/tools/BaseTool.java index 1ea2808a1..01a399920 100644 --- a/core/src/main/java/com/google/adk/tools/BaseTool.java +++ b/core/src/main/java/com/google/adk/tools/BaseTool.java @@ -22,6 +22,7 @@ import com.fasterxml.jackson.annotation.JsonAnySetter; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.adk.JsonBaseModel; import com.google.adk.agents.ConfigAgentUtils.ConfigurationException; import com.google.adk.models.LlmRequest; @@ -38,6 +39,7 @@ import java.util.HashMap; import java.util.Map; import java.util.Optional; +import java.util.function.Function; import javax.annotation.Nonnull; import org.jspecify.annotations.Nullable; import org.slf4j.Logger; @@ -93,6 +95,85 @@ public Single> runAsync(Map args, ToolContex throw new UnsupportedOperationException("This method is not implemented."); } + /** + * Calls a tool with generic arguments and returns a map of results. The args type {@code T} need + * to be serializable with {@link JsonBaseModel#getMapper()} + */ + public final Single> runAsync(T args, ToolContext toolContext) { + return runAsync(args, toolContext, JsonBaseModel.getMapper()); + } + + /** + * Calls a tool with generic arguments using a custom {@link ObjectMapper} and returns a map of + * results. The args type {@code T} needs to be serializable with the provided {@link + * ObjectMapper}. + */ + public final Single> runAsync( + T args, ToolContext toolContext, ObjectMapper objectMapper) { + return runAsync(args, toolContext, objectMapper, output -> output); + } + + /** + * Calls a tool with generic arguments and a custom {@link ObjectMapper}, returning the results + * converted to a specified class. The input type {@code I} needs to be serializable and the + * output type {@code O} needs to be deserializable with the provided {@link ObjectMapper}. + */ + public final Single runAsync( + I args, ToolContext toolContext, ObjectMapper objectMapper, Class oClass) { + return runAsync( + args, toolContext, objectMapper, output -> objectMapper.convertValue(output, oClass)); + } + + /** + * Calls a tool with generic arguments and a custom {@link ObjectMapper}, returning the results + * converted to a specified type reference. The input type {@code I} needs to be serializable and + * the output type {@code O} needs to be deserializable with the provided {@link ObjectMapper}. + */ + public final Single runAsync( + I args, + ToolContext toolContext, + ObjectMapper objectMapper, + TypeReference typeReference) { + return runAsync( + args, + toolContext, + objectMapper, + output -> objectMapper.convertValue(output, typeReference)); + } + + /** + * Calls a tool with generic arguments, returning the results converted to a specified class. The + * input type {@code I} needs to be serializable and the output type {@code O} needs to be + * deserializable with {@link JsonBaseModel#getMapper()} + */ + public final Single runAsync( + I args, ToolContext toolContext, Class oClass) { + return runAsync(args, toolContext, JsonBaseModel.getMapper(), oClass); + } + + /** + * Calls a tool with generic arguments, returning the results converted to a specified type + * reference. The input type needs to be serializable and the output type needs to be + * deserializable with {@link JsonBaseModel#getMapper()} + */ + public final Single runAsync( + I args, ToolContext toolContext, TypeReference typeReference) { + return runAsync(args, toolContext, JsonBaseModel.getMapper(), typeReference); + } + + private Single runAsync( + I args, + ToolContext toolContext, + ObjectMapper objectMapper, + Function, ? extends O> deserializer) { + return Single.defer( + () -> + Single.just( + objectMapper.convertValue(args, new TypeReference>() {}))) + .flatMap(argsMap -> runAsync(argsMap, toolContext)) + .map(deserializer::apply); + } + /** * Processes the outgoing {@link LlmRequest.Builder}. * diff --git a/core/src/test/java/com/google/adk/tools/BaseToolTest.java b/core/src/test/java/com/google/adk/tools/BaseToolTest.java index 2a07e7a44..d3c8da5aa 100644 --- a/core/src/test/java/com/google/adk/tools/BaseToolTest.java +++ b/core/src/test/java/com/google/adk/tools/BaseToolTest.java @@ -2,12 +2,15 @@ import static com.google.common.truth.Truth.assertThat; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.adk.agents.InvocationContext; import com.google.adk.agents.LlmAgent; import com.google.adk.models.Gemini; import com.google.adk.models.LlmRequest; import com.google.adk.sessions.InMemorySessionService; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.GenerateContentConfig; import com.google.genai.types.GoogleMaps; @@ -17,6 +20,7 @@ import com.google.genai.types.UrlContext; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.observers.TestObserver; import java.util.Map; import java.util.Optional; import org.junit.Test; @@ -27,6 +31,20 @@ @RunWith(JUnit4.class) public final class BaseToolTest { + private final BaseTool doublingBaseTool = + new BaseTool("doubling-test-tool", "returns doubled args") { + @Override + public Single> runAsync( + Map args, ToolContext toolContext) { + String sArg = (String) args.get("s"); + Integer iArg = (Integer) args.get("i"); + return Single.just( + ImmutableMap.of( + "s", sArg + sArg, + "i", iArg + iArg)); + } + }; + @Test public void processLlmRequestNoDeclarationReturnsSameRequest() { BaseTool tool = @@ -247,4 +265,94 @@ public void processLlmRequestWithGoogleMapsToolAddsToolToConfig() { assertThat(updatedLlmRequest.config().get().tools().get()) .containsExactly(Tool.builder().googleMaps(GoogleMaps.builder().build()).build()); } + + @Test + public void runAsync_withTypeReference_convertsArguments() throws Exception { + TestToolArgs testToolArgs = new TestToolArgs(42, "foo"); + + Single out = + doublingBaseTool.runAsync( + testToolArgs, /* toolContext= */ null, new TypeReference() {}); + TestObserver testObserver = out.test(); + + testObserver.assertComplete(); + TestToolArgs expected = new TestToolArgs(84, "foofoo"); + testObserver.assertValue(expected); + } + + @Test + public void runAsync_withClass_convertsArguments() throws Exception { + TestToolArgs testToolArgs = new TestToolArgs(21, "bar"); + + Single out = + doublingBaseTool.runAsync(testToolArgs, /* toolContext= */ null, TestToolArgs.class); + TestObserver testObserver = out.test(); + + testObserver.assertComplete(); + TestToolArgs expected = new TestToolArgs(42, "barbar"); + testObserver.assertValue(expected); + } + + @Test + public void runAsync_withObjectOnly_convertsArguments() throws Exception { + TestToolArgs testToolArgs = new TestToolArgs(11, "baz"); + + Single> out = + doublingBaseTool.runAsync(testToolArgs, /* toolContext= */ null); + TestObserver> testObserver = out.test(); + + testObserver.assertComplete(); + ImmutableMap expected = ImmutableMap.of("i", 22, "s", "bazbaz"); + testObserver.assertValue(expected); + } + + @Test + public void runAsync_withObjectMapperAndObjectOnly_convertsArguments() throws Exception { + TestToolArgs testToolArgs = new TestToolArgs(11, "baz"); + ObjectMapper objectMapper = new ObjectMapper(); + + Single> out = + doublingBaseTool.runAsync(testToolArgs, /* toolContext= */ null, objectMapper); + TestObserver> testObserver = out.test(); + + testObserver.assertComplete(); + ImmutableMap expected = ImmutableMap.of("i", 22, "s", "bazbaz"); + testObserver.assertValue(expected); + } + + @Test + public void runAsync_withTypeReferenceAndObjectMapper_convertsArguments() throws Exception { + TestToolArgs testToolArgs = new TestToolArgs(42, "foo"); + ObjectMapper objectMapper = new ObjectMapper(); + + Single out = + doublingBaseTool.runAsync( + testToolArgs, + /* toolContext= */ null, + objectMapper, + new TypeReference() {}); + + TestObserver testObserver = out.test(); + + testObserver.assertComplete(); + TestToolArgs expected = new TestToolArgs(84, "foofoo"); + testObserver.assertValue(expected); + } + + @Test + public void runAsync_withClassAndObjectMapper_convertsArguments() throws Exception { + TestToolArgs testToolArgs = new TestToolArgs(21, "bar"); + ObjectMapper objectMapper = new ObjectMapper(); + + Single out = + doublingBaseTool.runAsync( + testToolArgs, /* toolContext= */ null, objectMapper, TestToolArgs.class); + TestObserver testObserver = out.test(); + + testObserver.assertComplete(); + TestToolArgs expected = new TestToolArgs(42, "barbar"); + testObserver.assertValue(expected); + } + + public record TestToolArgs(int i, String s) {} } From fca43fbb9684ec8d080e437761f6bb4e38adf255 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 16 Mar 2026 12:38:53 -0700 Subject: [PATCH 50/50] fix: prevent ConcurrentModificationException when session events are modified by another thread during iteration PiperOrigin-RevId: 884587639 --- .../google/adk/flows/llmflows/Contents.java | 12 ++-- .../adk/flows/llmflows/ContentsTest.java | 60 +++++++++++++++++++ 2 files changed, 67 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Contents.java b/core/src/main/java/com/google/adk/flows/llmflows/Contents.java index 6ebd39a9c..840a370c6 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Contents.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Contents.java @@ -64,6 +64,11 @@ public Single processRequest( modelName = ""; } + ImmutableList sessionEvents; + synchronized (context.session().events()) { + sessionEvents = ImmutableList.copyOf(context.session().events()); + } + if (llmAgent.includeContents() == LlmAgent.IncludeContents.NONE) { return Single.just( RequestProcessor.RequestProcessingResult.create( @@ -71,7 +76,7 @@ public Single processRequest( .contents( getCurrentTurnContents( context.branch().orElse(null), - context.session().events(), + sessionEvents, context.agent().name(), modelName)) .build(), @@ -80,10 +85,7 @@ public Single processRequest( ImmutableList contents = getContents( - context.branch().orElse(null), - context.session().events(), - context.agent().name(), - modelName); + context.branch().orElse(null), sessionEvents, context.agent().name(), modelName); return Single.just( RequestProcessor.RequestProcessingResult.create( diff --git a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java index 85e78666d..7164991f3 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java @@ -36,10 +36,13 @@ import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -780,6 +783,63 @@ public void processRequest_notEmptyContent() { assertThat(contents).containsExactly(e.content().get()); } + @Test + public void processRequest_concurrentReadAndWrite_noException() throws Exception { + LlmAgent agent = + LlmAgent.builder().name(AGENT).includeContents(LlmAgent.IncludeContents.DEFAULT).build(); + Session session = + sessionService + .createSession("test-app", "test-user", new HashMap<>(), "test-session") + .blockingGet(); + + // Seed with dummy events to widen the race capability + for (int i = 0; i < 5000; i++) { + session.events().add(createUserEvent("dummy" + i, "dummy")); + } + + InvocationContext context = + InvocationContext.builder() + .invocationId("test-invocation") + .agent(agent) + .session(session) + .sessionService(sessionService) + .build(); + + LlmRequest initialRequest = LlmRequest.builder().build(); + + AtomicReference writerError = new AtomicReference<>(); + CountDownLatch startLatch = new CountDownLatch(1); + + Thread writerThread = + new Thread( + () -> { + startLatch.countDown(); + try { + for (int i = 0; i < 2000; i++) { + session.events().add(createUserEvent("writer" + i, "new data")); + } + } catch (Throwable t) { + writerError.set(t); + } + }); + + writerThread.start(); + startLatch.await(); // wait for writer to be ready + + // Process (read) requests concurrently to trigger race conditions + for (int i = 0; i < 200; i++) { + var unused = contentsProcessor.processRequest(context, initialRequest).blockingGet(); + if (writerError.get() != null) { + throw new RuntimeException("Writer failed", writerError.get()); + } + } + + writerThread.join(); + if (writerError.get() != null) { + throw new RuntimeException("Writer failed", writerError.get()); + } + } + private static Event createUserEvent(String id, String text) { return Event.builder() .id(id)