diff --git a/.github/workflows/validation.yml b/.github/workflows/validation.yml index 65e66f8fd..241dc1ccb 100644 --- a/.github/workflows/validation.yml +++ b/.github/workflows/validation.yml @@ -23,13 +23,13 @@ jobs: uses: actions/checkout@v6 - name: Set up Java ${{ matrix.java-version }} - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: distribution: temurin java-version: ${{ matrix.java-version }} - name: Cache Maven packages - uses: actions/cache@v3 + uses: actions/cache@v5 with: path: ~/.m2/repository key: ${{ runner.os }}-maven-${{ matrix.java-version }}-${{ hashFiles('**/pom.xml') }} diff --git a/.release-please-manifest.json b/.release-please-manifest.json index b0f3ba770..802e9d13f 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,4 +1,3 @@ { - ".": "0.9.0" + ".": "1.0.0-rc.1" } - diff --git a/CHANGELOG.md b/CHANGELOG.md index ab111e90c..4ef000794 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,48 @@ # Changelog +## [1.0.0-rc.1](https://github.com/google/adk-java/compare/v0.9.0...v1.0.0-rc.1) (2026-03-20) + + +### ⚠ BREAKING CHANGES + +* remove McpToolset constructors taking Optional parameters +* remove deprecated Example processor + +### Features + +* add handling the a2a metadata in the RemoteA2AAgent; Add the enum type for the metadata keys ([e51f911](https://github.com/google/adk-java/commit/e51f9112050955657da0dfc3aedc00f90ad739ec)) +* add type-safe runAsync methods to BaseTool ([b8cb7e2](https://github.com/google/adk-java/commit/b8cb7e2db6d5ce20f4d7a1b237bdc155563cf4bd)) +* Enhance LangChain4j to support MCP tools with parametersJsonSchema ([2c71ba1](https://github.com/google/adk-java/commit/2c71ba1332e052189115cd4644b7a473c31ed414)) +* fixing context propagation for agent transfers ([9a08076](https://github.com/google/adk-java/commit/9a080763d83c319f539d1bacac4595d13b299e7e)) +* Implement basic version of BigQuery Agent Analytics Plugin ([c8ab0f9](https://github.com/google/adk-java/commit/c8ab0f96b09a6c9636728d634c62695fcd622246)) +* init AGENTS.md file ([7ebeb07](https://github.com/google/adk-java/commit/7ebeb07bf2ee72475484d8a31ccf7b4c601dda96)) +* Propagating the otel context ([8556d4a](https://github.com/google/adk-java/commit/8556d4af16ff04c6e3b678dcfc3d4bb232abc550)) +* remove McpToolset constructors taking Optional parameters ([dbb1394](https://github.com/google/adk-java/commit/dbb139439d38157b4b9af38c52824b1e8405a495)) +* Return List instead of ImmutableList in CallbackUtil methods ([8af5e03](https://github.com/google/adk-java/commit/8af5e03811dfd548830df43103c81a592c8bf361)) +* update requestedAuthConfigs and its builder to be of general Map types ([f145c74](https://github.com/google/adk-java/commit/f145c744482b6b25f29a0b718bd452065e39d930)) +* Update return type of App.plugins() from ImmutableList to List ([8ba4bfe](https://github.com/google/adk-java/commit/8ba4bfed3fa7045f3344329de7a39acddc64ee30)) +* Update return type of toolsets() from ImmutableList to List ([cd56902](https://github.com/google/adk-java/commit/cd56902b803d4f7a1f3c718529842823d9e4370a)) +* update Session.state() and its builder to be of general Map types ([4b9b99a](https://github.com/google/adk-java/commit/4b9b99ae7149a465ba2ae9b7496e01f669786553)) +* update stateDelta builder input to Map from ConcurrentMap ([0d1e5c7](https://github.com/google/adk-java/commit/0d1e5c7b0c42cea66b178cf8fedf08a8c20f7fd0)) + + +### Bug Fixes + +* fix null handling in runAsyncImpl ([567fdf0](https://github.com/google/adk-java/commit/567fdf048fee49afc86ca5d7d35f55424a6016ba)) +* improve processRequest_concurrentReadAndWrite_noException test case ([4eb3613](https://github.com/google/adk-java/commit/4eb3613b65cb1334e9432960d0f864ef09829c23)) +* include saveArtifact invocations in event chain ([551c31f](https://github.com/google/adk-java/commit/551c31f495aafde8568461cc0aa0973d7df7e5ac)) +* prevent ConcurrentModificationException when session events are modified by another thread during iteration ([fca43fb](https://github.com/google/adk-java/commit/fca43fbb9684ec8d080e437761f6bb4e38adf255)) +* Relaxing constraints for output schema ([d7e03ee](https://github.com/google/adk-java/commit/d7e03eeb067b83abd2afa3ea9bb5fc1c16143245)) +* Removing deprecated methods in Runner ([0af82e6](https://github.com/google/adk-java/commit/0af82e61a3c0dbbd95166a10b450cb507115ab60)) +* Use ConcurrentHashMap in InvocationReplayState ([94de7f1](https://github.com/google/adk-java/commit/94de7f199f86b39bdb7cce6e9800eb05008a8953)), closes [#1009](https://github.com/google/adk-java/issues/1009) +* workaround for the client config streaming settings are not respected ([#983](https://github.com/google/adk-java/issues/983)) ([3ba04d3](https://github.com/google/adk-java/commit/3ba04d33dc8f2ef8b151abe1be4d1c8b7afcc25a)) + + +### Miscellaneous Chores + +* remove deprecated Example processor ([28a8cd0](https://github.com/google/adk-java/commit/28a8cd04ca9348dbe51a15d2be3a2b5307394174)) +* set version to 1.0.0-rc.1 ([dc5d794](https://github.com/google/adk-java/commit/dc5d794c066571c7d87f006767bd32298e2a3ba8)) + ## [0.9.0](https://github.com/google/adk-java/compare/v0.8.0...v0.9.0) (2026-03-13) diff --git a/README.md b/README.md index de1cfbef7..169c78564 100644 --- a/README.md +++ b/README.md @@ -50,13 +50,13 @@ If you're using Maven, add the following to your dependencies: com.google.adk google-adk - 0.9.0 + 1.0.0-rc.1 com.google.adk google-adk-dev - 0.9.0 + 1.0.0-rc.1 ``` diff --git a/a2a/pom.xml b/a2a/pom.xml index d454d63cb..840e78984 100644 --- a/a2a/pom.xml +++ b/a2a/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-parent - 0.9.1-SNAPSHOT + 1.0.1-rc.1-SNAPSHOT google-adk-a2a diff --git a/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java b/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java index ccb662b7c..4a375980c 100644 --- a/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java +++ b/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java @@ -117,7 +117,7 @@ private RemoteA2AAgent(Builder builder) { if (this.description.isEmpty() && this.agentCard.description() != null) { this.description = this.agentCard.description(); } - this.streaming = this.agentCard.capabilities().streaming(); + this.streaming = builder.streaming && this.agentCard.capabilities().streaming(); } public static Builder builder() { @@ -133,6 +133,13 @@ public static class Builder { private List subAgents; private List beforeAgentCallback; private List afterAgentCallback; + private boolean streaming; + + @CanIgnoreReturnValue + public Builder streaming(boolean streaming) { + this.streaming = streaming; + return this; + } @CanIgnoreReturnValue public Builder name(String name) { @@ -181,6 +188,10 @@ public RemoteA2AAgent build() { } } + public boolean isStreaming() { + return streaming; + } + private Message.Builder newA2AMessage(Message.Role role, List> parts) { return new Message.Builder().messageId(UUID.randomUUID().toString()).role(role).parts(parts); } diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/A2AMetadataKey.java b/a2a/src/main/java/com/google/adk/a2a/converters/A2AMetadataKey.java new file mode 100644 index 000000000..d4f1fef58 --- /dev/null +++ b/a2a/src/main/java/com/google/adk/a2a/converters/A2AMetadataKey.java @@ -0,0 +1,40 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.a2a.converters; + +/** + * Enum for the type of A2A metadata. Adds a prefix used to differentiage ADK-related values stored + * in Metadata an A2A event. + */ +public enum A2AMetadataKey { + TYPE("type"), + IS_LONG_RUNNING("is_long_running"), + PARTIAL("partial"), + GROUNDING_METADATA("grounding_metadata"), + USAGE_METADATA("usage_metadata"), + CUSTOM_METADATA("custom_metadata"), + ERROR_CODE("error_code"); + + private final String type; + + private A2AMetadataKey(String type) { + this.type = "adk_" + type; + } + + public String getType() { + return type; + } +} diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/AdkMetadataKey.java b/a2a/src/main/java/com/google/adk/a2a/converters/AdkMetadataKey.java new file mode 100644 index 000000000..e38f28828 --- /dev/null +++ b/a2a/src/main/java/com/google/adk/a2a/converters/AdkMetadataKey.java @@ -0,0 +1,35 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.a2a.converters; + +/** + * Enum for the type of ADK metadata. Adds a prefix used to differentiate A2A-related values stored + * in custom metadata of an ADK session event. + */ +public enum AdkMetadataKey { + TASK_ID("task_id"), + CONTEXT_ID("context_id"); + + private final String type; + + private AdkMetadataKey(String type) { + this.type = "a2a:" + type; + } + + public String getType() { + return type; + } +} diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java index 61f24fa21..714a79736 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java @@ -52,11 +52,7 @@ public final class PartConverter { private static final Logger logger = LoggerFactory.getLogger(PartConverter.class); private static final ObjectMapper objectMapper = new ObjectMapper(); - // Constants for metadata types. By convention metadata keys are prefixed with "adk_" to align - // with the Python and Golang libraries. - public static final String A2A_DATA_PART_METADATA_TYPE_KEY = "adk_type"; - public static final String A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY = "adk_is_long_running"; - public static final String A2A_DATA_PART_METADATA_IS_PARTIAL_KEY = "adk_partial"; + // Constants for metadata types. public static final String LANGUAGE_KEY = "language"; public static final String OUTCOME_KEY = "outcome"; public static final String CODE_KEY = "code"; @@ -135,7 +131,7 @@ private static com.google.genai.types.Part convertDataPartToGenAiPart(DataPart d Map metadata = Optional.ofNullable(dataPart.getMetadata()).map(HashMap::new).orElseGet(HashMap::new); - String metadataType = metadata.getOrDefault(A2A_DATA_PART_METADATA_TYPE_KEY, "").toString(); + String metadataType = metadata.getOrDefault(A2AMetadataKey.TYPE.getType(), "").toString(); if ((data.containsKey(NAME_KEY) && data.containsKey(ARGS_KEY)) || metadataType.equals(A2ADataPartMetadataType.FUNCTION_CALL.getType())) { @@ -218,7 +214,7 @@ private static DataPart createDataPartFromFunctionCall( addValueIfPresent(data, WILL_CONTINUE_KEY, functionCall.willContinue()); addValueIfPresent(data, PARTIAL_ARGS_KEY, functionCall.partialArgs()); - metadata.put(A2A_DATA_PART_METADATA_TYPE_KEY, A2ADataPartMetadataType.FUNCTION_CALL.getType()); + metadata.put(A2AMetadataKey.TYPE.getType(), A2ADataPartMetadataType.FUNCTION_CALL.getType()); return new DataPart(data.buildOrThrow(), metadata.buildOrThrow()); } @@ -245,7 +241,7 @@ private static DataPart createDataPartFromFunctionResponse( addValueIfPresent(data, PARTS_KEY, functionResponse.parts()); metadata.put( - A2A_DATA_PART_METADATA_TYPE_KEY, A2ADataPartMetadataType.FUNCTION_RESPONSE.getType()); + A2AMetadataKey.TYPE.getType(), A2ADataPartMetadataType.FUNCTION_RESPONSE.getType()); return new DataPart(data.buildOrThrow(), metadata.buildOrThrow()); } @@ -268,7 +264,7 @@ private static DataPart createDataPartFromCodeExecutionResult( addValueIfPresent(data, OUTPUT_KEY, codeExecutionResult.output()); metadata.put( - A2A_DATA_PART_METADATA_TYPE_KEY, A2ADataPartMetadataType.CODE_EXECUTION_RESULT.getType()); + A2AMetadataKey.TYPE.getType(), A2ADataPartMetadataType.CODE_EXECUTION_RESULT.getType()); return new DataPart(data.buildOrThrow(), metadata.buildOrThrow()); } @@ -290,8 +286,7 @@ private static DataPart createDataPartFromExecutableCode( .orElse(Language.Known.LANGUAGE_UNSPECIFIED.toString())); addValueIfPresent(data, CODE_KEY, executableCode.code()); - metadata.put( - A2A_DATA_PART_METADATA_TYPE_KEY, A2ADataPartMetadataType.EXECUTABLE_CODE.getType()); + metadata.put(A2AMetadataKey.TYPE.getType(), A2ADataPartMetadataType.EXECUTABLE_CODE.getType()); return new DataPart(data.buildOrThrow(), metadata.buildOrThrow()); } @@ -305,7 +300,7 @@ public static io.a2a.spec.Part fromGenaiPart(Part part, boolean isPartial) { } ImmutableMap.Builder metadata = ImmutableMap.builder(); if (isPartial) { - metadata.put(A2A_DATA_PART_METADATA_IS_PARTIAL_KEY, true); + metadata.put(A2AMetadataKey.PARTIAL.getType(), true); } if (part.text().isPresent()) { diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/RequestConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/RequestConverter.java index b289ced6c..ae9faadfa 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/RequestConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/RequestConverter.java @@ -61,8 +61,11 @@ public static Optional convertA2aMessageToAdkEvent(Message message, Strin // Convert each A2A Part to GenAI Part if (message.getParts() != null) { for (Part a2aPart : message.getParts()) { - Optional genaiPart = PartConverter.toGenaiPart(a2aPart); - genaiPart.ifPresent(genaiParts::add); + try { + genaiParts.add(PartConverter.toGenaiPart(a2aPart)); + } catch (IllegalArgumentException e) { + logger.debug("Skipping unconvertible A2A part: {}", e.getMessage()); + } } } @@ -125,15 +128,18 @@ public static ImmutableList convertAggregatedA2aMessageToAdkEvents( // Emit exactly one ADK Event per A2A Part, preserving order. for (Part a2aPart : message.getParts()) { - Optional genaiPart = PartConverter.toGenaiPart(a2aPart); - if (genaiPart.isEmpty()) { + com.google.genai.types.Part genaiPart; + try { + genaiPart = PartConverter.toGenaiPart(a2aPart); + } catch (IllegalArgumentException e) { + logger.debug("Skipping unconvertible A2A part in aggregate: {}", e.getMessage()); continue; } String author = extractAuthorFromMetadata(a2aPart); String role = determineRoleFromAuthor(author); - events.add(createEvent(ImmutableList.of(genaiPart.get()), author, role, invocationId)); + events.add(createEvent(ImmutableList.of(genaiPart), author, role, invocationId)); } if (events.isEmpty()) { @@ -162,8 +168,7 @@ private static String extractAuthorFromMetadata(Part a2aPart) { if (a2aPart instanceof DataPart dataPart) { Map metadata = Optional.ofNullable(dataPart.getMetadata()).orElse(ImmutableMap.of()); - String type = - metadata.getOrDefault(PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY, "").toString(); + String type = metadata.getOrDefault(A2AMetadataKey.TYPE.getType(), "").toString(); if (type.equals(A2ADataPartMetadataType.FUNCTION_CALL.getType())) { return "model"; } diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java index c6ab896ea..b5733e9a9 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java @@ -19,12 +19,20 @@ import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Streams.zip; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.adk.agents.InvocationContext; import com.google.adk.events.Event; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.genai.types.Content; +import com.google.genai.types.CustomMetadata; +import com.google.genai.types.FinishReason; +import com.google.genai.types.GenerateContentResponseUsageMetadata; +import com.google.genai.types.GroundingMetadata; import com.google.genai.types.Part; import io.a2a.client.ClientEvent; import io.a2a.client.MessageEvent; @@ -45,11 +53,13 @@ import java.util.Objects; import java.util.Optional; import java.util.UUID; +import org.jspecify.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** Utility for converting ADK events to A2A spec messages (and back). */ public final class ResponseConverter { + private static final ObjectMapper objectMapper = new ObjectMapper(); private static final Logger logger = LoggerFactory.getLogger(ResponseConverter.class); private static final ImmutableSet PENDING_STATES = ImmutableSet.of(TaskState.WORKING, TaskState.SUBMITTED); @@ -76,12 +86,11 @@ public static Optional clientEventToEvent( throw new IllegalArgumentException("Unsupported ClientEvent type: " + event.getClass()); } - private static boolean isPartial(Map metadata) { + private static boolean isPartial(@Nullable Map metadata) { if (metadata == null) { return false; } - return Objects.equals( - metadata.getOrDefault(PartConverter.A2A_DATA_PART_METADATA_IS_PARTIAL_KEY, false), true); + return Objects.equals(metadata.getOrDefault(A2AMetadataKey.PARTIAL.getType(), false), true); } /** @@ -112,7 +121,12 @@ private static Optional handleTaskUpdate( // append=false, lastChunk=false: emit as partial, reset aggregation // append=true, lastChunk=true: emit as partial, update aggregation and emit as non-partial // append=false, lastChunk=true: emit as non-partial, drop aggregation - return Optional.of(eventPart); + return Optional.of( + updateEventMetadata( + eventPart, + artifactEvent.getMetadata(), + artifactEvent.getTaskId(), + artifactEvent.getContextId())); } if (updateEvent instanceof TaskStatusUpdateEvent statusEvent) { @@ -130,14 +144,21 @@ private static Optional handleTaskUpdate( }); if (statusEvent.isFinal()) { - return messageEvent - .map(Event::toBuilder) - .or(() -> Optional.of(remoteAgentEventBuilder(context))) - .map(builder -> builder.turnComplete(true)) - .map(builder -> builder.partial(false)) - .map(Event.Builder::build); + messageEvent = + messageEvent + .map(Event::toBuilder) + .or(() -> Optional.of(remoteAgentEventBuilder(context))) + .map(builder -> builder.turnComplete(true)) + .map(builder -> builder.partial(false)) + .map(Event.Builder::build); } - return messageEvent; + return messageEvent.map( + finalMessageEvent -> + updateEventMetadata( + finalMessageEvent, + statusEvent.getMetadata(), + statusEvent.getTaskId(), + statusEvent.getContextId())); } throw new IllegalArgumentException( "Unsupported TaskUpdateEvent type: " + updateEvent.getClass()); @@ -165,9 +186,13 @@ public static Event messageToFailedEvent(Message message, InvocationContext invo /** Converts an A2A message back to ADK events. */ public static Event messageToEvent(Message message, InvocationContext invocationContext) { - return remoteAgentEventBuilder(invocationContext) - .content(fromModelParts(PartConverter.toGenaiParts(message.getParts()))) - .build(); + return updateEventMetadata( + remoteAgentEventBuilder(invocationContext) + .content(fromModelParts(PartConverter.toGenaiParts(message.getParts()))) + .build(), + message.getMetadata(), + message.getTaskId(), + message.getContextId()); } /** @@ -230,7 +255,8 @@ public static Event taskToEvent(Task task, InvocationContext invocationContext) eventBuilder.longRunningToolIds(longRunningToolIds.build()); } eventBuilder.turnComplete(isFinal); - return eventBuilder.build(); + return updateEventMetadata( + eventBuilder.build(), task.getMetadata(), task.getId(), task.getContextId()); } private static ImmutableSet getLongRunningToolIds( @@ -243,9 +269,7 @@ private static ImmutableSet getLongRunningToolIds( return Optional.empty(); } Object isLongRunning = - dataPart - .getMetadata() - .get(PartConverter.A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY); + dataPart.getMetadata().get(A2AMetadataKey.IS_LONG_RUNNING.getType()); if (!Objects.equals(isLongRunning, true)) { return Optional.empty(); } @@ -258,6 +282,77 @@ private static ImmutableSet getLongRunningToolIds( .collect(toImmutableSet()); } + private static Event updateEventMetadata( + Event event, + @Nullable Map clientMetadata, + @Nullable String taskId, + @Nullable String contextId) { + if (taskId == null || contextId == null) { + logger.warn("Task ID or context ID is null, skipping metadata update."); + return event; + } + + if (clientMetadata == null) { + clientMetadata = ImmutableMap.of(); + } + Event.Builder eventBuilder = event.toBuilder(); + Object groundingMetadata = clientMetadata.get(A2AMetadataKey.GROUNDING_METADATA.getType()); + // if groundingMetadata is null, parseMetadata will return null as well. + eventBuilder.groundingMetadata(parseMetadata(groundingMetadata, GroundingMetadata.class)); + Object usageMetadata = clientMetadata.get(A2AMetadataKey.USAGE_METADATA.getType()); + // if usageMetadata is null, parseMetadata will return null as well. + eventBuilder.usageMetadata( + parseMetadata(usageMetadata, GenerateContentResponseUsageMetadata.class)); + + ImmutableList.Builder customMetadataList = ImmutableList.builder(); + customMetadataList + .add( + CustomMetadata.builder() + .key(AdkMetadataKey.TASK_ID.getType()) + .stringValue(taskId) + .build()) + .add( + CustomMetadata.builder() + .key(AdkMetadataKey.CONTEXT_ID.getType()) + .stringValue(contextId) + .build()); + Object customMetadata = clientMetadata.get(A2AMetadataKey.CUSTOM_METADATA.getType()); + if (customMetadata != null) { + customMetadataList.addAll( + parseMetadata(customMetadata, new TypeReference>() {})); + } + eventBuilder.customMetadata(customMetadataList.build()); + + Object errorCode = clientMetadata.get(A2AMetadataKey.ERROR_CODE.getType()); + eventBuilder.errorCode(parseMetadata(errorCode, FinishReason.class)); + + return eventBuilder.build(); + } + + private static @Nullable T parseMetadata(@Nullable Object metadata, Class type) { + try { + if (metadata instanceof String jsonString) { + return objectMapper.readValue(jsonString, type); + } else { + return objectMapper.convertValue(metadata, type); + } + } catch (IllegalArgumentException | JsonProcessingException e) { + throw new IllegalArgumentException("Failed to parse metadata of type " + type, e); + } + } + + private static @Nullable T parseMetadata(@Nullable Object metadata, TypeReference type) { + try { + if (metadata instanceof String jsonString) { + return objectMapper.readValue(jsonString, type); + } else { + return objectMapper.convertValue(metadata, type); + } + } catch (IllegalArgumentException | JsonProcessingException e) { + throw new IllegalArgumentException("Failed to parse metadata of type " + type.getType(), e); + } + } + private static Event emptyEvent(InvocationContext invocationContext) { Event.Builder builder = Event.builder() diff --git a/a2a/src/main/java/com/google/adk/a2a/grpc/A2aService.java b/a2a/src/main/java/com/google/adk/a2a/grpc/A2aService.java index e6658e6c0..0e4881ab1 100644 --- a/a2a/src/main/java/com/google/adk/a2a/grpc/A2aService.java +++ b/a2a/src/main/java/com/google/adk/a2a/grpc/A2aService.java @@ -11,6 +11,7 @@ import com.google.adk.runner.Runner; import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; +import com.google.adk.sessions.SessionKey; import com.google.genai.types.Content; import com.google.genai.types.Part; import io.grpc.stub.StreamObserver; @@ -167,8 +168,8 @@ public void sendMessage( .setMaxLlmCalls(20) .build(); - // Execute the agent using Runner with Session object - Flowable eventStream = runner.runAsync(session, userContent, runConfig); + SessionKey sessionKey = new SessionKey(session.appName(), session.userId(), session.id()); + Flowable eventStream = runner.runAsync(sessionKey, userContent, runConfig); // Collect all events and aggregate into a single response // Since sendMessage is unary RPC, we need to send a single response diff --git a/a2a/src/test/java/com/google/adk/a2a/agent/RemoteA2AAgentTest.java b/a2a/src/test/java/com/google/adk/a2a/agent/RemoteA2AAgentTest.java index b1ffa248a..0609c3b04 100644 --- a/a2a/src/test/java/com/google/adk/a2a/agent/RemoteA2AAgentTest.java +++ b/a2a/src/test/java/com/google/adk/a2a/agent/RemoteA2AAgentTest.java @@ -113,6 +113,20 @@ public void setUp() { .build(); } + @Test + public void createAgent_streaming_false_returnsNonStreamingAgent() { + // With streaming false, the agent should not stream even if the AgentCard supports streaming. + RemoteA2AAgent agent = getAgentBuilder().streaming(false).build(); + assertThat(agent.isStreaming()).isFalse(); + } + + @Test + public void createAgent_streaming_true_returnsStreamingAgent() { + // With streaming true, the agent should support streaming if the AgentCard supports streaming. + RemoteA2AAgent agent = getAgentBuilder().streaming(true).build(); + assertThat(agent.isStreaming()).isTrue(); + } + @Test public void runAsync_aggregatesPartialEvents() { RemoteA2AAgent agent = createAgent(); @@ -763,7 +777,7 @@ private RemoteA2AAgent.Builder getAgentBuilder() { } private RemoteA2AAgent createAgent() { - return getAgentBuilder().build(); + return getAgentBuilder().streaming(true).build(); } @SuppressWarnings("unchecked") // cast for Mockito diff --git a/a2a/src/test/java/com/google/adk/a2a/converters/PartConverterTest.java b/a2a/src/test/java/com/google/adk/a2a/converters/PartConverterTest.java index d93466dd2..4a0828c43 100644 --- a/a2a/src/test/java/com/google/adk/a2a/converters/PartConverterTest.java +++ b/a2a/src/test/java/com/google/adk/a2a/converters/PartConverterTest.java @@ -8,9 +8,13 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Blob; +import com.google.genai.types.CodeExecutionResult; +import com.google.genai.types.ExecutableCode; import com.google.genai.types.FileData; import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; +import com.google.genai.types.Language; +import com.google.genai.types.Outcome; import com.google.genai.types.Part; import io.a2a.spec.DataPart; import io.a2a.spec.FilePart; @@ -86,8 +90,7 @@ public void toGenaiPart_withDataPartFunctionCall_returnsGenaiFunctionCallPart() new DataPart( data, ImmutableMap.of( - PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY, - A2ADataPartMetadataType.FUNCTION_CALL.getType())); + A2AMetadataKey.TYPE.getType(), A2ADataPartMetadataType.FUNCTION_CALL.getType())); Part result = PartConverter.toGenaiPart(dataPart); @@ -121,7 +124,7 @@ public void toGenaiPart_withDataPartFunctionResponse_returnsGenaiFunctionRespons new DataPart( data, ImmutableMap.of( - PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY, + A2AMetadataKey.TYPE.getType(), A2ADataPartMetadataType.FUNCTION_RESPONSE.getType())); Part result = PartConverter.toGenaiPart(dataPart); @@ -188,7 +191,7 @@ public void fromGenaiPart_withTextPart_returnsTextPart() { assertThat(((TextPart) result).getText()).isEqualTo("text"); assertThat(((TextPart) result).getMetadata()).containsEntry("thought", true); assertThat(((TextPart) result).getMetadata()) - .containsEntry(PartConverter.A2A_DATA_PART_METADATA_IS_PARTIAL_KEY, true); + .containsEntry(A2AMetadataKey.PARTIAL.getType(), true); } @Test @@ -226,6 +229,39 @@ public void fromGenaiPart_withInlineDataPart_returnsFilePartWithBytes() { assertThat(Base64.getDecoder().decode(fileWithBytes.bytes())).isEqualTo(bytes); } + @Test + public void fromGenaiPart_dataPart_executableCode_returnsDataPart() { + ExecutableCode executableCode = + ExecutableCode.builder().code("print('hello')").language(new Language("python")).build(); + Part part = Part.builder().executableCode(executableCode).build(); + io.a2a.spec.Part result = PartConverter.fromGenaiPart(part, false); + + assertThat(result).isInstanceOf(DataPart.class); + DataPart dataPart = (DataPart) result; + assertThat(dataPart.getData().get("code")).isEqualTo("print('hello')"); + assertThat(dataPart.getData().get("language")).isEqualTo("python"); + assertThat(dataPart.getMetadata().get(A2AMetadataKey.TYPE.getType())) + .isEqualTo("executable_code"); + } + + @Test + public void fromGenaiPart_dataPart_codeExecutionResult_returnsDataPart() { + CodeExecutionResult codeExecutionResult = + CodeExecutionResult.builder() + .outcome(new Outcome("OUTCOME_OK")) + .output("print('hello')") + .build(); + Part part = Part.builder().codeExecutionResult(codeExecutionResult).build(); + io.a2a.spec.Part result = PartConverter.fromGenaiPart(part, false); + + assertThat(result).isInstanceOf(DataPart.class); + DataPart dataPart = (DataPart) result; + assertThat(dataPart.getData().get("outcome")).isEqualTo("OUTCOME_OK"); + assertThat(dataPart.getData().get("output")).isEqualTo("print('hello')"); + assertThat(dataPart.getMetadata().get(A2AMetadataKey.TYPE.getType())) + .isEqualTo("code_execution_result"); + } + @Test public void fromGenaiPart_withFunctionCallPart_returnsDataPart() { Part part = @@ -255,8 +291,7 @@ public void fromGenaiPart_withFunctionCallPart_returnsDataPart() { true); assertThat(dataPart.getMetadata()) .containsEntry( - PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY, - A2ADataPartMetadataType.FUNCTION_CALL.getType()); + A2AMetadataKey.TYPE.getType(), A2ADataPartMetadataType.FUNCTION_CALL.getType()); } @Test @@ -275,8 +310,7 @@ public void fromGenaiPart_withFunctionResponsePart_returnsDataPart() { .containsExactly("name", "func", "id", "1", "response", ImmutableMap.of()); assertThat(dataPart.getMetadata()) .containsEntry( - PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY, - A2ADataPartMetadataType.FUNCTION_RESPONSE.getType()); + A2AMetadataKey.TYPE.getType(), A2ADataPartMetadataType.FUNCTION_RESPONSE.getType()); } @Test diff --git a/a2a/src/test/java/com/google/adk/a2a/converters/ResponseConverterTest.java b/a2a/src/test/java/com/google/adk/a2a/converters/ResponseConverterTest.java index d84dc42cd..b61b00e1a 100644 --- a/a2a/src/test/java/com/google/adk/a2a/converters/ResponseConverterTest.java +++ b/a2a/src/test/java/com/google/adk/a2a/converters/ResponseConverterTest.java @@ -1,6 +1,8 @@ package com.google.adk.a2a.converters; import static com.google.common.truth.Truth.assertThat; +import static java.util.stream.Collectors.joining; +import static org.junit.Assert.assertThrows; import com.google.adk.agents.BaseAgent; import com.google.adk.agents.InvocationContext; @@ -13,6 +15,10 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; +import com.google.genai.types.CustomMetadata; +import com.google.genai.types.FinishReason; +import com.google.genai.types.GenerateContentResponseUsageMetadata; +import com.google.genai.types.GroundingMetadata; import io.a2a.client.MessageEvent; import io.a2a.client.TaskUpdateEvent; import io.a2a.spec.Artifact; @@ -136,6 +142,74 @@ public void taskToEvent_withStatusMessage_returnsEvent() { assertThat(event.content().get().parts().get().get(0).text()).hasValue("Status message"); } + @Test + public void taskToEvent_withGroundingMetadata_returnsEvent() { + GroundingMetadata groundingMetadata = + GroundingMetadata.builder().webSearchQueries("test-query").build(); + Message statusMessage = + new Message.Builder() + .role(Message.Role.AGENT) + .parts(ImmutableList.of(new TextPart("Status message"))) + .build(); + TaskStatus status = new TaskStatus(TaskState.WORKING, statusMessage, null); + Task task = + testTask() + .status(status) + .artifacts(null) + .metadata( + ImmutableMap.of( + A2AMetadataKey.GROUNDING_METADATA.getType(), groundingMetadata.toJson())) + .build(); + Event event = ResponseConverter.taskToEvent(task, invocationContext); + assertThat(event).isNotNull(); + assertThat(event.content().get().parts().get().get(0).text()).hasValue("Status message"); + assertThat(event.groundingMetadata()).hasValue(groundingMetadata); + } + + @Test + public void taskToEvent_withCustomMetadata_returnsEvent() { + ImmutableList customMetadataList = + ImmutableList.of( + CustomMetadata.builder().key("test-key").stringValue("test-value").build()); + String customMetadataJson = + customMetadataList.stream().map(CustomMetadata::toJson).collect(joining(",", "[", "]")); + Message statusMessage = + new Message.Builder() + .role(Message.Role.AGENT) + .parts(ImmutableList.of(new TextPart("Status message"))) + .build(); + TaskStatus status = new TaskStatus(TaskState.WORKING, statusMessage, null); + Task task = + testTask() + .status(status) + .artifacts(null) + .metadata(ImmutableMap.of(A2AMetadataKey.CUSTOM_METADATA.getType(), customMetadataJson)) + .build(); + Event event = ResponseConverter.taskToEvent(task, invocationContext); + assertThat(event).isNotNull(); + assertThat(event.content().get().parts().get().get(0).text()).hasValue("Status message"); + assertThat(event.customMetadata().get()) + .containsExactly( + CustomMetadata.builder().key("a2a:task_id").stringValue("task-1").build(), + CustomMetadata.builder().key("a2a:context_id").stringValue("context-1").build(), + CustomMetadata.builder().key("test-key").stringValue("test-value").build()) + .inOrder(); + } + + @Test + public void messageToEvent_withMissingTaskId_returnsEvent() { + Message a2aMessage = + new Message.Builder() + .messageId("msg-1") + .role(Message.Role.USER) + .taskId("task-1") + .parts(ImmutableList.of(new TextPart("test-message"))) + .build(); + Event event = ResponseConverter.messageToEvent(a2aMessage, invocationContext); + assertThat(event).isNotNull(); + assertThat(event.customMetadata()).isEmpty(); + } + @Test public void taskToEvent_withNoMessage_returnsEmptyEvent() { TaskStatus status = new TaskStatus(TaskState.WORKING, null, null); @@ -152,18 +226,18 @@ public void taskToEvent_withInputRequired_parsesLongRunningToolIds() { ImmutableMap.of("name", "myTool", "id", "call_123", "args", ImmutableMap.of()); ImmutableMap metadata = ImmutableMap.of( - PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY, + A2AMetadataKey.TYPE.getType(), "function_call", - PartConverter.A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY, + A2AMetadataKey.IS_LONG_RUNNING.getType(), true); DataPart dataPart = new DataPart(data, metadata); ImmutableMap statusData = ImmutableMap.of("name", "messageTools", "id", "msg_123", "args", ImmutableMap.of()); ImmutableMap statusMetadata = ImmutableMap.of( - PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY, + A2AMetadataKey.TYPE.getType(), "function_call", - PartConverter.A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY, + A2AMetadataKey.IS_LONG_RUNNING.getType(), true); DataPart statusDataPart = new DataPart(statusData, statusMetadata); Message statusMessage = @@ -361,6 +435,99 @@ public void clientEventToEvent_withFailedTaskStatusUpdateEvent_returnsErrorEvent assertThat(resultEvent.turnComplete()).hasValue(true); } + @Test + public void taskToEvent_withInvalidMetadata_throwsException() { + Message statusMessage = + new Message.Builder() + .role(Message.Role.AGENT) + .parts(ImmutableList.of(new TextPart("Status message"))) + .build(); + TaskStatus status = new TaskStatus(TaskState.WORKING, statusMessage, null); + Task task = + testTask() + .status(status) + .artifacts(null) + .metadata( + ImmutableMap.of(A2AMetadataKey.GROUNDING_METADATA.getType(), "{ invalid json ]")) + .build(); + + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> ResponseConverter.taskToEvent(task, invocationContext)); + assertThat(exception).hasMessageThat().contains("Failed to parse metadata"); + assertThat(exception).hasMessageThat().contains("GroundingMetadata"); + } + + @Test + public void taskToEvent_withErrorCode_returnsEvent() { + Message statusMessage = + new Message.Builder() + .role(Message.Role.AGENT) + .parts(ImmutableList.of(new TextPart("Status message"))) + .build(); + TaskStatus status = new TaskStatus(TaskState.WORKING, statusMessage, null); + Task task = + testTask() + .status(status) + .artifacts(null) + .metadata(ImmutableMap.of(A2AMetadataKey.ERROR_CODE.getType(), "\"STOP\"")) + .build(); + Event event = ResponseConverter.taskToEvent(task, invocationContext); + assertThat(event).isNotNull(); + assertThat(event.errorCode()).hasValue(new FinishReason(FinishReason.Known.STOP)); + } + + @Test + public void taskToEvent_withUsageMetadata_returnsEvent() { + GenerateContentResponseUsageMetadata usageMetadata = + GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(10) + .candidatesTokenCount(20) + .totalTokenCount(30) + .build(); + Message statusMessage = + new Message.Builder() + .role(Message.Role.AGENT) + .parts(ImmutableList.of(new TextPart("Status message"))) + .build(); + TaskStatus status = new TaskStatus(TaskState.WORKING, statusMessage, null); + Task task = + testTask() + .status(status) + .artifacts(null) + .metadata( + ImmutableMap.of(A2AMetadataKey.USAGE_METADATA.getType(), usageMetadata.toJson())) + .build(); + Event event = ResponseConverter.taskToEvent(task, invocationContext); + assertThat(event).isNotNull(); + assertThat(event.usageMetadata()).hasValue(usageMetadata); + } + + @Test + public void clientEventToEvent_withTaskArtifactUpdateEventAndPartialTrue_returnsEmpty() { + io.a2a.spec.Part a2aPart = new TextPart("Artifact content"); + Artifact artifact = + new Artifact.Builder().artifactId("artifact-1").parts(ImmutableList.of(a2aPart)).build(); + Task task = + testTask() + .status(new TaskStatus(TaskState.COMPLETED)) + .artifacts(ImmutableList.of(artifact)) + .build(); + TaskArtifactUpdateEvent updateEvent = + new TaskArtifactUpdateEvent.Builder() + .lastChunk(true) + .metadata(ImmutableMap.of(A2AMetadataKey.PARTIAL.getType(), true)) + .contextId("context-1") + .artifact(artifact) + .taskId("task-id-1") + .build(); + TaskUpdateEvent event = new TaskUpdateEvent(task, updateEvent); + + Optional optionalEvent = ResponseConverter.clientEventToEvent(event, invocationContext); + assertThat(optionalEvent).isEmpty(); + } + private static final class TestAgent extends BaseAgent { TestAgent() { super("test_agent", "test", ImmutableList.of(), null, null); diff --git a/a2a/src/test/java/com/google/adk/a2a/grpc/MediaSupportTest.java b/a2a/src/test/java/com/google/adk/a2a/grpc/MediaSupportTest.java index a8019be30..e7bc6eb8f 100644 --- a/a2a/src/test/java/com/google/adk/a2a/grpc/MediaSupportTest.java +++ b/a2a/src/test/java/com/google/adk/a2a/grpc/MediaSupportTest.java @@ -14,7 +14,6 @@ import io.a2a.spec.FileWithUri; import io.a2a.spec.TextPart; import java.util.Base64; -import java.util.Optional; import org.junit.jupiter.api.Test; /** Tests for image, audio, and video support in A2A. */ @@ -36,11 +35,10 @@ class MediaSupportTest { void testTextPart_conversion() { // A2A TextPart to GenAI Part TextPart textPart = new TextPart("Hello, world!"); - Optional genaiPart = PartConverter.toGenaiPart(textPart); + Part genaiPart = PartConverter.toGenaiPart(textPart); - assertThat(genaiPart).isPresent(); - assertThat(genaiPart.get().text()).isPresent(); - assertThat(genaiPart.get().text().get()).isEqualTo("Hello, world!"); + assertThat(genaiPart.text()).isPresent(); + assertThat(genaiPart.text().get()).isEqualTo("Hello, world!"); // GenAI Part to A2A TextPart Part genaiTextPart = Part.builder().text("Hello, world!").build(); @@ -57,11 +55,10 @@ void testImageFilePart_withUri() { FilePart imagePart = new FilePart(new FileWithUri("image/png", "test.png", "https://example.com/image.png")); - Optional genaiPart = PartConverter.toGenaiPart(imagePart); + Part genaiPart = PartConverter.toGenaiPart(imagePart); - assertThat(genaiPart).isPresent(); - assertThat(genaiPart.get().fileData()).isPresent(); - FileData fileData = genaiPart.get().fileData().get(); + assertThat(genaiPart.fileData()).isPresent(); + FileData fileData = genaiPart.fileData().get(); assertThat(fileData.fileUri()).isPresent(); assertThat(fileData.fileUri().get()).isEqualTo("https://example.com/image.png"); assertThat(fileData.mimeType()).isPresent(); @@ -74,11 +71,10 @@ void testImageFilePart_withBytes() { FilePart imagePart = new FilePart(new FileWithBytes("image/png", "test.png", SAMPLE_IMAGE_BASE64)); - Optional genaiPart = PartConverter.toGenaiPart(imagePart); + Part genaiPart = PartConverter.toGenaiPart(imagePart); - assertThat(genaiPart).isPresent(); - assertThat(genaiPart.get().inlineData()).isPresent(); - Blob blob = genaiPart.get().inlineData().get(); + assertThat(genaiPart.inlineData()).isPresent(); + Blob blob = genaiPart.inlineData().get(); assertThat(blob.mimeType()).isPresent(); assertThat(blob.mimeType().get()).isEqualTo("image/png"); assertThat(blob.data()).isPresent(); @@ -91,11 +87,10 @@ void testAudioFilePart_withUri() { FilePart audioPart = new FilePart(new FileWithUri("audio/mpeg", "test.mp3", "https://example.com/audio.mp3")); - Optional genaiPart = PartConverter.toGenaiPart(audioPart); + Part genaiPart = PartConverter.toGenaiPart(audioPart); - assertThat(genaiPart).isPresent(); - assertThat(genaiPart.get().fileData()).isPresent(); - FileData fileData = genaiPart.get().fileData().get(); + assertThat(genaiPart.fileData()).isPresent(); + FileData fileData = genaiPart.fileData().get(); assertThat(fileData.mimeType()).isPresent(); assertThat(fileData.mimeType().get()).isEqualTo("audio/mpeg"); } @@ -106,11 +101,10 @@ void testAudioFilePart_withBytes() { FilePart audioPart = new FilePart(new FileWithBytes("audio/wav", "test.wav", SAMPLE_AUDIO_BASE64)); - Optional genaiPart = PartConverter.toGenaiPart(audioPart); + Part genaiPart = PartConverter.toGenaiPart(audioPart); - assertThat(genaiPart).isPresent(); - assertThat(genaiPart.get().inlineData()).isPresent(); - Blob blob = genaiPart.get().inlineData().get(); + assertThat(genaiPart.inlineData()).isPresent(); + Blob blob = genaiPart.inlineData().get(); assertThat(blob.mimeType()).isPresent(); assertThat(blob.mimeType().get()).isEqualTo("audio/wav"); } @@ -121,11 +115,10 @@ void testVideoFilePart_withUri() { FilePart videoPart = new FilePart(new FileWithUri("video/mp4", "test.mp4", "https://example.com/video.mp4")); - Optional genaiPart = PartConverter.toGenaiPart(videoPart); + Part genaiPart = PartConverter.toGenaiPart(videoPart); - assertThat(genaiPart).isPresent(); - assertThat(genaiPart.get().fileData()).isPresent(); - FileData fileData = genaiPart.get().fileData().get(); + assertThat(genaiPart.fileData()).isPresent(); + FileData fileData = genaiPart.fileData().get(); assertThat(fileData.mimeType()).isPresent(); assertThat(fileData.mimeType().get()).isEqualTo("video/mp4"); } @@ -136,11 +129,10 @@ void testVideoFilePart_withBytes() { FilePart videoPart = new FilePart(new FileWithBytes("video/mp4", "test.mp4", SAMPLE_VIDEO_BASE64)); - Optional genaiPart = PartConverter.toGenaiPart(videoPart); + Part genaiPart = PartConverter.toGenaiPart(videoPart); - assertThat(genaiPart).isPresent(); - assertThat(genaiPart.get().inlineData()).isPresent(); - Blob blob = genaiPart.get().inlineData().get(); + assertThat(genaiPart.inlineData()).isPresent(); + Blob blob = genaiPart.inlineData().get(); assertThat(blob.mimeType()).isPresent(); assertThat(blob.mimeType().get()).isEqualTo("video/mp4"); } @@ -227,16 +219,12 @@ void testMultipleMediaTypes_inMessage() { FilePart videoPart = new FilePart(new FileWithUri("video/mp4", "movie.mp4", "https://example.com/movie.mp4")); - Optional imageGenai = PartConverter.toGenaiPart(imagePart); - Optional audioGenai = PartConverter.toGenaiPart(audioPart); - Optional videoGenai = PartConverter.toGenaiPart(videoPart); + Part imageGenai = PartConverter.toGenaiPart(imagePart); + Part audioGenai = PartConverter.toGenaiPart(audioPart); + Part videoGenai = PartConverter.toGenaiPart(videoPart); - assertThat(imageGenai).isPresent(); - assertThat(audioGenai).isPresent(); - assertThat(videoGenai).isPresent(); - - assertThat(imageGenai.get().fileData().get().mimeType().get()).isEqualTo("image/jpeg"); - assertThat(audioGenai.get().fileData().get().mimeType().get()).isEqualTo("audio/mpeg"); - assertThat(videoGenai.get().fileData().get().mimeType().get()).isEqualTo("video/mp4"); + assertThat(imageGenai.fileData().get().mimeType().get()).isEqualTo("image/jpeg"); + assertThat(audioGenai.fileData().get().mimeType().get()).isEqualTo("audio/mpeg"); + assertThat(videoGenai.fileData().get().mimeType().get()).isEqualTo("video/mp4"); } } diff --git a/contrib/firestore-session-service/pom.xml b/contrib/firestore-session-service/pom.xml index 0079dce24..ed1ecd09b 100644 --- a/contrib/firestore-session-service/pom.xml +++ b/contrib/firestore-session-service/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.9.1-SNAPSHOT + 1.0.1-rc.1-SNAPSHOT ../../pom.xml diff --git a/contrib/firestore-session-service/src/test/java/com/google/adk/sessions/FirestoreSessionServiceTest.java b/contrib/firestore-session-service/src/test/java/com/google/adk/sessions/FirestoreSessionServiceTest.java index ffcbcf8f5..43ca6889f 100644 --- a/contrib/firestore-session-service/src/test/java/com/google/adk/sessions/FirestoreSessionServiceTest.java +++ b/contrib/firestore-session-service/src/test/java/com/google/adk/sessions/FirestoreSessionServiceTest.java @@ -47,15 +47,11 @@ import com.google.genai.types.Part; import io.reactivex.rxjava3.observers.TestObserver; import java.time.Instant; -import java.util.AbstractMap; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.Set; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -530,44 +526,6 @@ void appendAndGet_withAllPartTypes_serializesAndDeserializesCorrectly() { }); } - /** - * A wrapper class that implements ConcurrentMap but delegates to a HashMap. This is a workaround - * to allow putting null values, which ConcurrentHashMap forbids, for testing state removal logic. - */ - private static class HashMapAsConcurrentMap extends AbstractMap - implements ConcurrentMap { - private final HashMap map; - - public HashMapAsConcurrentMap(Map map) { - this.map = new HashMap<>(map); - } - - @Override - public Set> entrySet() { - return map.entrySet(); - } - - @Override - public V putIfAbsent(K key, V value) { - return map.putIfAbsent(key, value); - } - - @Override - public boolean remove(Object key, Object value) { - return map.remove(key, value); - } - - @Override - public boolean replace(K key, V oldValue, V newValue) { - return map.replace(key, oldValue, newValue); - } - - @Override - public V replace(K key, V value) { - return map.replace(key, value); - } - } - /** Tests that appendEvent with only app state deltas updates the correct stores. */ @Test void appendEvent_withAppOnlyStateDeltas_updatesCorrectStores() { @@ -662,63 +620,6 @@ void appendEvent_withUserOnlyStateDeltas_updatesCorrectStores() { verify(mockSessionDocRef, never()).update(eq(Constants.KEY_STATE), any()); } - /** - * Tests that appendEvent with all types of state deltas updates the correct stores and session - * state. - */ - @Test - void appendEvent_withAllStateDeltas_updatesCorrectStores() { - // Arrange - Session session = - Session.builder(SESSION_ID) - .appName(APP_NAME) - .userId(USER_ID) - .state(new ConcurrentHashMap<>()) // The session state itself must be concurrent - .build(); - session.state().put("keyToRemove", "someValue"); - - Map stateDeltaMap = new HashMap<>(); - stateDeltaMap.put("sessionKey", "sessionValue"); - stateDeltaMap.put("_app_appKey", "appValue"); - stateDeltaMap.put("_user_userKey", "userValue"); - stateDeltaMap.put("keyToRemove", null); - - // Use the wrapper to satisfy the ConcurrentMap interface for the builder - EventActions actions = - EventActions.builder().stateDelta(new HashMapAsConcurrentMap<>(stateDeltaMap)).build(); - - Event event = - Event.builder() - .author("model") - .content(Content.builder().parts(List.of(Part.fromText("..."))).build()) - .actions(actions) - .build(); - - when(mockSessionsCollection.document(SESSION_ID)).thenReturn(mockSessionDocRef); - when(mockEventsCollection.document()).thenReturn(mockEventDocRef); - when(mockEventDocRef.getId()).thenReturn(EVENT_ID); - // THIS IS THE MISSING MOCK: Stub the call to get the document by its specific ID. - when(mockEventsCollection.document(EVENT_ID)).thenReturn(mockEventDocRef); - // Add the missing mock for the final session update call - when(mockSessionDocRef.update(anyMap())) - .thenReturn(ApiFutures.immediateFuture(mockWriteResult)); - - // Act - sessionService.appendEvent(session, event).test().assertComplete(); - - // Assert - assertThat(session.state()).containsEntry("sessionKey", "sessionValue"); - assertThat(session.state()).doesNotContainKey("keyToRemove"); - - ArgumentCaptor> appStateCaptor = ArgumentCaptor.forClass(Map.class); - verify(mockAppStateDocRef).set(appStateCaptor.capture(), any(SetOptions.class)); - assertThat(appStateCaptor.getValue()).containsEntry("appKey", "appValue"); - - ArgumentCaptor> userStateCaptor = ArgumentCaptor.forClass(Map.class); - verify(mockUserStateUserDocRef).set(userStateCaptor.capture(), any(SetOptions.class)); - assertThat(userStateCaptor.getValue()).containsEntry("userKey", "userValue"); - } - /** Tests that getSession skips malformed events and returns only the well-formed ones. */ @Test @SuppressWarnings("unchecked") diff --git a/contrib/langchain4j/pom.xml b/contrib/langchain4j/pom.xml index c2326fa0a..e88174849 100644 --- a/contrib/langchain4j/pom.xml +++ b/contrib/langchain4j/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.9.1-SNAPSHOT + 1.0.1-rc.1-SNAPSHOT ../../pom.xml @@ -58,11 +58,6 @@ google-adk ${project.version} - - com.google.adk - google-adk-dev - ${project.version} - com.google.genai google-genai diff --git a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java index 80c25610d..97331e7b4 100644 --- a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java +++ b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java @@ -18,10 +18,12 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.adk.JsonBaseModel; import com.google.adk.models.BaseLlm; import com.google.adk.models.BaseLlmConnection; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; +import com.google.auto.value.AutoValue; import com.google.genai.types.Blob; import com.google.genai.types.Content; import com.google.genai.types.FunctionCall; @@ -29,11 +31,11 @@ import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.FunctionResponse; import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; import com.google.genai.types.Schema; import com.google.genai.types.ToolConfig; import com.google.genai.types.Type; -import dev.langchain4j.Experimental; import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.audio.Audio; @@ -51,6 +53,7 @@ import dev.langchain4j.data.pdf.PdfFile; import dev.langchain4j.data.video.Video; import dev.langchain4j.exception.UnsupportedFeatureException; +import dev.langchain4j.model.TokenCountEstimator; import dev.langchain4j.model.chat.ChatModel; import dev.langchain4j.model.chat.StreamingChatModel; import dev.langchain4j.model.chat.request.ChatRequest; @@ -64,6 +67,7 @@ import dev.langchain4j.model.chat.request.json.JsonStringSchema; import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; +import dev.langchain4j.model.output.TokenUsage; import io.reactivex.rxjava3.core.BackpressureStrategy; import io.reactivex.rxjava3.core.Flowable; import java.util.ArrayList; @@ -71,66 +75,101 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.UUID; +import org.jspecify.annotations.Nullable; -@Experimental -public class LangChain4j extends BaseLlm { +@AutoValue +public abstract class LangChain4j extends BaseLlm { private static final TypeReference> MAP_TYPE_REFERENCE = new TypeReference<>() {}; - private final ChatModel chatModel; - private final StreamingChatModel streamingChatModel; - private final ObjectMapper objectMapper; + LangChain4j() { + super(""); + } + + @Nullable + public abstract ChatModel chatModel(); + + @Nullable + public abstract StreamingChatModel streamingChatModel(); + + public abstract ObjectMapper objectMapper(); + + public abstract String modelName(); + + @Nullable + public abstract TokenCountEstimator tokenCountEstimator(); + + @Override + public String model() { + return modelName(); + } + + public static Builder builder() { + return new AutoValue_LangChain4j.Builder().objectMapper(new ObjectMapper()); + } + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder chatModel(ChatModel chatModel); + + public abstract Builder streamingChatModel(StreamingChatModel streamingChatModel); + + public abstract Builder tokenCountEstimator(TokenCountEstimator tokenCountEstimator); + + public abstract Builder objectMapper(ObjectMapper objectMapper); + + public abstract Builder modelName(String modelName); + + public abstract LangChain4j build(); + } public LangChain4j(ChatModel chatModel) { - super( - Objects.requireNonNull( - chatModel.defaultRequestParameters().modelName(), "chat model name cannot be null")); - this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null"); - this.streamingChatModel = null; - this.objectMapper = new ObjectMapper(); + this(chatModel, null, null, chatModel.defaultRequestParameters().modelName(), null); } public LangChain4j(ChatModel chatModel, String modelName) { - super(Objects.requireNonNull(modelName, "chat model name cannot be null")); - this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null"); - this.streamingChatModel = null; - this.objectMapper = new ObjectMapper(); + this(chatModel, null, null, modelName, null); } public LangChain4j(StreamingChatModel streamingChatModel) { - super( - Objects.requireNonNull( - streamingChatModel.defaultRequestParameters().modelName(), - "streaming chat model name cannot be null")); - this.chatModel = null; - this.streamingChatModel = - Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null"); - this.objectMapper = new ObjectMapper(); + this( + null, + streamingChatModel, + null, + streamingChatModel.defaultRequestParameters().modelName(), + null); } public LangChain4j(StreamingChatModel streamingChatModel, String modelName) { - super(Objects.requireNonNull(modelName, "streaming chat model name cannot be null")); - this.chatModel = null; - this.streamingChatModel = - Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null"); - this.objectMapper = new ObjectMapper(); + this(null, streamingChatModel, null, modelName, null); } public LangChain4j(ChatModel chatModel, StreamingChatModel streamingChatModel, String modelName) { - super(Objects.requireNonNull(modelName, "model name cannot be null")); - this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null"); - this.streamingChatModel = - Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null"); - this.objectMapper = new ObjectMapper(); + this(chatModel, streamingChatModel, null, modelName, null); + } + + private LangChain4j( + ChatModel chatModel, + StreamingChatModel streamingChatModel, + ObjectMapper objectMapper, + String modelName, + TokenCountEstimator tokenCountEstimator) { + this(); + LangChain4j.builder() + .chatModel(chatModel) + .streamingChatModel(streamingChatModel) + .objectMapper(objectMapper) + .modelName(modelName) + .tokenCountEstimator(tokenCountEstimator) + .build(); } @Override public Flowable generateContent(LlmRequest llmRequest, boolean stream) { if (stream) { - if (this.streamingChatModel == null) { + if (this.streamingChatModel() == null) { return Flowable.error(new IllegalStateException("StreamingChatModel is not configured")); } @@ -138,54 +177,57 @@ public Flowable generateContent(LlmRequest llmRequest, boolean stre return Flowable.create( emitter -> { - streamingChatModel.chat( - chatRequest, - new StreamingChatResponseHandler() { - @Override - public void onPartialResponse(String s) { - emitter.onNext( - LlmResponse.builder().content(Content.fromParts(Part.fromText(s))).build()); - } - - @Override - public void onCompleteResponse(ChatResponse chatResponse) { - if (chatResponse.aiMessage().hasToolExecutionRequests()) { - AiMessage aiMessage = chatResponse.aiMessage(); - toParts(aiMessage).stream() - .map(Part::functionCall) - .forEach( - functionCall -> { - functionCall.ifPresent( - function -> { - emitter.onNext( - LlmResponse.builder() - .content( - Content.fromParts( - Part.fromFunctionCall( - function.name().orElse(""), - function.args().orElse(Map.of())))) - .build()); - }); - }); - } - emitter.onComplete(); - } - - @Override - public void onError(Throwable throwable) { - emitter.onError(throwable); - } - }); + streamingChatModel() + .chat( + chatRequest, + new StreamingChatResponseHandler() { + @Override + public void onPartialResponse(String s) { + emitter.onNext( + LlmResponse.builder() + .content(Content.fromParts(Part.fromText(s))) + .build()); + } + + @Override + public void onCompleteResponse(ChatResponse chatResponse) { + if (chatResponse.aiMessage().hasToolExecutionRequests()) { + AiMessage aiMessage = chatResponse.aiMessage(); + toParts(aiMessage).stream() + .map(Part::functionCall) + .forEach( + functionCall -> { + functionCall.ifPresent( + function -> { + emitter.onNext( + LlmResponse.builder() + .content( + Content.fromParts( + Part.fromFunctionCall( + function.name().orElse(""), + function.args().orElse(Map.of())))) + .build()); + }); + }); + } + emitter.onComplete(); + } + + @Override + public void onError(Throwable throwable) { + emitter.onError(throwable); + } + }); }, BackpressureStrategy.BUFFER); } else { - if (this.chatModel == null) { + if (this.chatModel() == null) { return Flowable.error(new IllegalStateException("ChatModel is not configured")); } ChatRequest chatRequest = toChatRequest(llmRequest); - ChatResponse chatResponse = chatModel.chat(chatRequest); - LlmResponse llmResponse = toLlmResponse(chatResponse); + ChatResponse chatResponse = chatModel().chat(chatRequest); + LlmResponse llmResponse = toLlmResponse(chatResponse, chatRequest); return Flowable.just(llmResponse); } @@ -412,7 +454,7 @@ private AiMessage toAiMessage(Content content) { private String toJson(Object object) { try { - return objectMapper.writeValueAsString(object); + return objectMapper().writeValueAsString(object); } catch (JsonProcessingException e) { throw new RuntimeException(e); } @@ -428,8 +470,24 @@ private List toToolSpecifications(LlmRequest llmRequest) { baseTool -> { if (baseTool.declaration().isPresent()) { FunctionDeclaration functionDeclaration = baseTool.declaration().get(); - if (functionDeclaration.parameters().isPresent()) { - Schema schema = functionDeclaration.parameters().get(); + Schema schema = null; + if (functionDeclaration.parametersJsonSchema().isPresent()) { + Object jsonSchemaObj = functionDeclaration.parametersJsonSchema().get(); + try { + if (jsonSchemaObj instanceof Schema) { + schema = (Schema) jsonSchemaObj; + } else { + schema = JsonBaseModel.getMapper().convertValue(jsonSchemaObj, Schema.class); + } + } catch (Exception e) { + throw new IllegalStateException( + "Failed to convert parametersJsonSchema to Schema: " + e.getMessage(), e); + } + } else if (functionDeclaration.parameters().isPresent()) { + schema = functionDeclaration.parameters().get(); + } + + if (schema != null) { ToolSpecification toolSpecification = ToolSpecification.builder() .name(baseTool.name()) @@ -438,11 +496,9 @@ private List toToolSpecifications(LlmRequest llmRequest) { .build(); toolSpecifications.add(toolSpecification); } else { - // TODO exception or something else? throw new IllegalStateException("Tool lacking parameters: " + baseTool); } } else { - // TODO exception or something else? throw new IllegalStateException("Tool lacking declaration: " + baseTool); } }); @@ -496,11 +552,38 @@ private JsonSchemaElement toJsonSchemaElement(Schema schema) { } } - private LlmResponse toLlmResponse(ChatResponse chatResponse) { + private LlmResponse toLlmResponse(ChatResponse chatResponse, ChatRequest chatRequest) { Content content = Content.builder().role("model").parts(toParts(chatResponse.aiMessage())).build(); - return LlmResponse.builder().content(content).build(); + LlmResponse.Builder builder = LlmResponse.builder().content(content); + TokenUsage tokenUsage = chatResponse.tokenUsage(); + if (tokenCountEstimator() != null) { + try { + int estimatedInput = + tokenCountEstimator().estimateTokenCountInMessages(chatRequest.messages()); + int estimatedOutput = + tokenCountEstimator().estimateTokenCountInText(chatResponse.aiMessage().text()); + int estimatedTotal = estimatedInput + estimatedOutput; + builder.usageMetadata( + GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(estimatedInput) + .candidatesTokenCount(estimatedOutput) + .totalTokenCount(estimatedTotal) + .build()); + } catch (Exception e) { + e.printStackTrace(); + } + } else if (tokenUsage != null) { + builder.usageMetadata( + GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(tokenUsage.inputTokenCount()) + .candidatesTokenCount(tokenUsage.outputTokenCount()) + .totalTokenCount(tokenUsage.totalTokenCount()) + .build()); + } + + return builder.build(); } private List toParts(AiMessage aiMessage) { @@ -524,14 +607,17 @@ private List toParts(AiMessage aiMessage) { }); return parts; } else { - Part part = Part.builder().text(aiMessage.text()).build(); - return List.of(part); + String text = aiMessage.text(); + if (text == null) { + return List.of(); + } + return List.of(Part.builder().text(text).build()); } } private Map toArgs(ToolExecutionRequest toolExecutionRequest) { try { - return objectMapper.readValue(toolExecutionRequest.arguments(), MAP_TYPE_REFERENCE); + return objectMapper().readValue(toolExecutionRequest.arguments(), MAP_TYPE_REFERENCE); } catch (JsonProcessingException e) { throw new RuntimeException(e); } diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java index 3fafb046d..5b6d3f3ad 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java @@ -44,7 +44,7 @@ class LangChain4jIntegrationTest { - public static final String CLAUDE_3_7_SONNET_20250219 = "claude-3-7-sonnet-20250219"; + public static final String CLAUDE_4_6_SONNET = "claude-sonnet-4-6"; public static final String GEMINI_2_0_FLASH = "gemini-2.0-flash"; public static final String GPT_4_O_MINI = "gpt-4o-mini"; @@ -55,14 +55,15 @@ void testSimpleAgent() { AnthropicChatModel claudeModel = AnthropicChatModel.builder() .apiKey(System.getenv("ANTHROPIC_API_KEY")) - .modelName(CLAUDE_3_7_SONNET_20250219) + .modelName(CLAUDE_4_6_SONNET) .build(); LlmAgent agent = LlmAgent.builder() .name("science-app") .description("Science teacher agent") - .model(new LangChain4j(claudeModel, CLAUDE_3_7_SONNET_20250219)) + .model( + LangChain4j.builder().chatModel(claudeModel).modelName(CLAUDE_4_6_SONNET).build()) .instruction( """ You are a helpful science teacher that explains science concepts @@ -91,14 +92,15 @@ void testSingleAgentWithTools() { AnthropicChatModel claudeModel = AnthropicChatModel.builder() .apiKey(System.getenv("ANTHROPIC_API_KEY")) - .modelName(CLAUDE_3_7_SONNET_20250219) + .modelName(CLAUDE_4_6_SONNET) .build(); BaseAgent agent = LlmAgent.builder() .name("friendly-weather-app") .description("Friend agent that knows about the weather") - .model(new LangChain4j(claudeModel, CLAUDE_3_7_SONNET_20250219)) + .model( + LangChain4j.builder().chatModel(claudeModel).modelName(CLAUDE_4_6_SONNET).build()) .instruction( """ You are a friendly assistant. @@ -155,7 +157,7 @@ void testSingleAgentWithTools() { List partsThree = contentThree.parts().get(); assertEquals(1, partsThree.size()); assertTrue(partsThree.get(0).text().isPresent()); - assertTrue(partsThree.get(0).text().get().contains("beautiful")); + assertTrue(partsThree.get(0).text().get().contains("sunny")); } @Test @@ -183,7 +185,7 @@ void testAgentTool() { LlmAgent.builder() .name("friendly-weather-app") .description("Friend agent that knows about the weather") - .model(new LangChain4j(gptModel)) + .model(LangChain4j.builder().chatModel(gptModel).modelName(GPT_4_O_MINI).build()) .instruction( """ You are a friendly assistant. @@ -246,7 +248,7 @@ void testSubAgent() { LlmAgent.builder() .name("greeterAgent") .description("Friendly agent that greets users") - .model(new LangChain4j(gptModel)) + .model(LangChain4j.builder().chatModel(gptModel).modelName(GPT_4_O_MINI).build()) .instruction( """ You are a friendly that greets users. @@ -257,7 +259,7 @@ void testSubAgent() { LlmAgent.builder() .name("farewellAgent") .description("Friendly agent that says goodbye to users") - .model(new LangChain4j(gptModel)) + .model(LangChain4j.builder().chatModel(gptModel).modelName(GPT_4_O_MINI).build()) .instruction( """ You are a friendly that says goodbye to users. @@ -352,10 +354,14 @@ void testSimpleStreamingResponse() { AnthropicStreamingChatModel claudeStreamingModel = AnthropicStreamingChatModel.builder() .apiKey(System.getenv("ANTHROPIC_API_KEY")) - .modelName(CLAUDE_3_7_SONNET_20250219) + .modelName(CLAUDE_4_6_SONNET) .build(); - LangChain4j lc4jClaude = new LangChain4j(claudeStreamingModel, CLAUDE_3_7_SONNET_20250219); + LangChain4j lc4jClaude = + LangChain4j.builder() + .streamingChatModel(claudeStreamingModel) + .modelName(CLAUDE_4_6_SONNET) + .build(); // when Flowable responses = @@ -413,7 +419,11 @@ void testStreamingRunConfig() { When someone greets you, respond with "Hello". If someone asks about the weather, call the `getWeather` function. """) - .model(new LangChain4j(streamingModel, "GPT_4_O_MINI")) + .model( + LangChain4j.builder() + .streamingChatModel(streamingModel) + .modelName("GPT_4_O_MINI") + .build()) // .model(new LangChain4j(streamingModel, // CLAUDE_3_7_SONNET_20250219)) .tools(FunctionTool.create(ToolExample.class, "getWeather")) diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java index 428a5660c..a1ec7a3c2 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java @@ -19,6 +19,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.*; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; import com.google.adk.tools.FunctionTool; @@ -26,6 +27,7 @@ import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.TokenCountEstimator; import dev.langchain4j.model.chat.ChatModel; import dev.langchain4j.model.chat.StreamingChatModel; import dev.langchain4j.model.chat.request.ChatRequest; @@ -33,6 +35,7 @@ import dev.langchain4j.model.chat.request.json.JsonStringSchema; import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; +import dev.langchain4j.model.output.TokenUsage; import io.reactivex.rxjava3.core.Flowable; import java.util.ArrayList; import java.util.List; @@ -57,8 +60,26 @@ void setUp() { chatModel = mock(ChatModel.class); streamingChatModel = mock(StreamingChatModel.class); - langChain4j = new LangChain4j(chatModel, MODEL_NAME); - streamingLangChain4j = new LangChain4j(streamingChatModel, MODEL_NAME); + langChain4j = LangChain4j.builder().chatModel(chatModel).modelName(MODEL_NAME).build(); + streamingLangChain4j = + LangChain4j.builder().streamingChatModel(streamingChatModel).modelName(MODEL_NAME).build(); + } + + @Test + void testBuilder() { + ObjectMapper customMapper = new ObjectMapper(); + LangChain4j customLc4j = + LangChain4j.builder() + .chatModel(chatModel) + .streamingChatModel(streamingChatModel) + .objectMapper(customMapper) + .modelName("custom-model") + .build(); + + assertThat(customLc4j.chatModel()).isEqualTo(chatModel); + assertThat(customLc4j.streamingChatModel()).isEqualTo(streamingChatModel); + assertThat(customLc4j.objectMapper()).isEqualTo(customMapper); + assertThat(customLc4j.modelName()).isEqualTo("custom-model"); } @Test @@ -688,4 +709,288 @@ void testGenerateContentWithStructuredResponseJsonSchema() { final UserMessage userMessage = (UserMessage) capturedRequest.messages().get(0); assertThat(userMessage.singleText()).isEqualTo("Give me information about John Doe"); } + + @Test + @DisplayName("Should handle MCP tools with parametersJsonSchema") + void testGenerateContentWithMcpToolParametersJsonSchema() { + // Given + // Create a mock BaseTool for MCP tool + final com.google.adk.tools.BaseTool mcpTool = mock(com.google.adk.tools.BaseTool.class); + when(mcpTool.name()).thenReturn("mcpTool"); + when(mcpTool.description()).thenReturn("An MCP tool"); + + // Create a mock FunctionDeclaration + final FunctionDeclaration functionDeclaration = mock(FunctionDeclaration.class); + when(mcpTool.declaration()).thenReturn(Optional.of(functionDeclaration)); + + // MCP tools use parametersJsonSchema() instead of parameters() + // Create a JSON schema object (Map representation) + final Map jsonSchemaMap = + Map.of( + "type", + "object", + "properties", + Map.of("city", Map.of("type", "string", "description", "City name")), + "required", + List.of("city")); + + // Mock parametersJsonSchema() to return the JSON schema object + when(functionDeclaration.parametersJsonSchema()).thenReturn(Optional.of(jsonSchemaMap)); + when(functionDeclaration.parameters()).thenReturn(Optional.empty()); + + // Create a LlmRequest with the MCP tool + final LlmRequest llmRequest = + LlmRequest.builder() + .contents(List.of(Content.fromParts(Part.fromText("Use the MCP tool")))) + .tools(Map.of("mcpTool", mcpTool)) + .build(); + + // Mock the AI response + final AiMessage aiMessage = AiMessage.from("Tool executed successfully"); + + final ChatResponse chatResponse = mock(ChatResponse.class); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = langChain4j.generateContent(llmRequest, false).blockingFirst(); + + // Then + // Verify the response + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().text()).isEqualTo("Tool executed successfully"); + + // Verify the request was built correctly with the tool specification + final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(ChatRequest.class); + verify(chatModel).chat(requestCaptor.capture()); + final ChatRequest capturedRequest = requestCaptor.getValue(); + + // Verify tool specifications were created from parametersJsonSchema + assertThat(capturedRequest.toolSpecifications()).isNotEmpty(); + assertThat(capturedRequest.toolSpecifications().get(0).name()).isEqualTo("mcpTool"); + assertThat(capturedRequest.toolSpecifications().get(0).description()).isEqualTo("An MCP tool"); + } + + @Test + @DisplayName("Should handle MCP tools with parametersJsonSchema when it's already a Schema") + void testGenerateContentWithMcpToolParametersJsonSchemaAsSchema() { + // Given + // Create a mock BaseTool for MCP tool + final com.google.adk.tools.BaseTool mcpTool = mock(com.google.adk.tools.BaseTool.class); + when(mcpTool.name()).thenReturn("mcpTool"); + when(mcpTool.description()).thenReturn("An MCP tool"); + + // Create a mock FunctionDeclaration + final FunctionDeclaration functionDeclaration = mock(FunctionDeclaration.class); + when(mcpTool.declaration()).thenReturn(Optional.of(functionDeclaration)); + + // Create a Schema object directly (when parametersJsonSchema returns Schema) + final Schema cityPropertySchema = + Schema.builder().type("STRING").description("City name").build(); + + final Schema objectSchema = + Schema.builder() + .type("OBJECT") + .properties(Map.of("city", cityPropertySchema)) + .required(List.of("city")) + .build(); + + // Mock parametersJsonSchema() to return Schema directly + when(functionDeclaration.parametersJsonSchema()).thenReturn(Optional.of(objectSchema)); + when(functionDeclaration.parameters()).thenReturn(Optional.empty()); + + // Create a LlmRequest with the MCP tool + final LlmRequest llmRequest = + LlmRequest.builder() + .contents(List.of(Content.fromParts(Part.fromText("Use the MCP tool")))) + .tools(Map.of("mcpTool", mcpTool)) + .build(); + + // Mock the AI response + final AiMessage aiMessage = AiMessage.from("Tool executed successfully"); + + final ChatResponse chatResponse = mock(ChatResponse.class); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = langChain4j.generateContent(llmRequest, false).blockingFirst(); + + // Then + // Verify the response + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().text()).isEqualTo("Tool executed successfully"); + + // Verify the request was built correctly with the tool specification + final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(ChatRequest.class); + verify(chatModel).chat(requestCaptor.capture()); + final ChatRequest capturedRequest = requestCaptor.getValue(); + + // Verify tool specifications were created from parametersJsonSchema + assertThat(capturedRequest.toolSpecifications()).isNotEmpty(); + assertThat(capturedRequest.toolSpecifications().get(0).name()).isEqualTo("mcpTool"); + assertThat(capturedRequest.toolSpecifications().get(0).description()).isEqualTo("An MCP tool"); + } + + @Test + @DisplayName( + "Should use TokenCountEstimator to estimate token usage when TokenUsage is not available") + void testTokenCountEstimatorFallback() { + // Given + // Create a mock TokenCountEstimator + final TokenCountEstimator tokenCountEstimator = mock(TokenCountEstimator.class); + when(tokenCountEstimator.estimateTokenCountInMessages(any())).thenReturn(50); // Input tokens + when(tokenCountEstimator.estimateTokenCountInText(any())).thenReturn(20); // Output tokens + + // Create LangChain4j with the TokenCountEstimator using Builder + final LangChain4j langChain4jWithEstimator = + LangChain4j.builder() + .chatModel(chatModel) + .modelName(MODEL_NAME) + .tokenCountEstimator(tokenCountEstimator) + .build(); + + // Create a LlmRequest + final LlmRequest llmRequest = + LlmRequest.builder() + .contents(List.of(Content.fromParts(Part.fromText("What is the weather today?")))) + .build(); + + // Mock ChatResponse WITHOUT TokenUsage (simulating when LLM doesn't provide token counts) + final ChatResponse chatResponse = mock(ChatResponse.class); + final AiMessage aiMessage = AiMessage.from("The weather is sunny today."); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatResponse.tokenUsage()).thenReturn(null); // No token usage from LLM + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = + langChain4jWithEstimator.generateContent(llmRequest, false).blockingFirst(); + + // Then + // Verify the response has usage metadata estimated by TokenCountEstimator + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().text()).isEqualTo("The weather is sunny today."); + + // IMPORTANT: Verify that token usage was estimated via the TokenCountEstimator + assertThat(response.usageMetadata()).isPresent(); + final GenerateContentResponseUsageMetadata usageMetadata = response.usageMetadata().get(); + assertThat(usageMetadata.promptTokenCount()).isEqualTo(Optional.of(50)); // From estimator + assertThat(usageMetadata.candidatesTokenCount()).isEqualTo(Optional.of(20)); // From estimator + assertThat(usageMetadata.totalTokenCount()).isEqualTo(Optional.of(70)); // 50 + 20 + + // Verify the estimator was actually called + verify(tokenCountEstimator).estimateTokenCountInMessages(any()); + verify(tokenCountEstimator).estimateTokenCountInText("The weather is sunny today."); + } + + @Test + @DisplayName("Should prioritize TokenCountEstimator over TokenUsage when estimator is provided") + void testTokenCountEstimatorPriority() { + // Given + // Create a mock TokenCountEstimator + final TokenCountEstimator tokenCountEstimator = mock(TokenCountEstimator.class); + when(tokenCountEstimator.estimateTokenCountInMessages(any())).thenReturn(100); // From estimator + when(tokenCountEstimator.estimateTokenCountInText(any())).thenReturn(50); // From estimator + + // Create LangChain4j with the TokenCountEstimator using Builder + final LangChain4j langChain4jWithEstimator = + LangChain4j.builder() + .chatModel(chatModel) + .modelName(MODEL_NAME) + .tokenCountEstimator(tokenCountEstimator) + .build(); + + // Create a LlmRequest + final LlmRequest llmRequest = + LlmRequest.builder() + .contents(List.of(Content.fromParts(Part.fromText("What is the weather today?")))) + .build(); + + // Mock ChatResponse WITH actual TokenUsage from the LLM + final ChatResponse chatResponse = mock(ChatResponse.class); + final AiMessage aiMessage = AiMessage.from("The weather is sunny today."); + final TokenUsage actualTokenUsage = new TokenUsage(30, 15, 45); // Actual token counts from LLM + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatResponse.tokenUsage()).thenReturn(actualTokenUsage); // LLM provides token usage + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = + langChain4jWithEstimator.generateContent(llmRequest, false).blockingFirst(); + + // Then + // IMPORTANT: When TokenCountEstimator is present, it takes priority over TokenUsage + assertThat(response).isNotNull(); + assertThat(response.usageMetadata()).isPresent(); + final GenerateContentResponseUsageMetadata usageMetadata = response.usageMetadata().get(); + assertThat(usageMetadata.promptTokenCount()).isEqualTo(Optional.of(100)); // From estimator + assertThat(usageMetadata.candidatesTokenCount()).isEqualTo(Optional.of(50)); // From estimator + assertThat(usageMetadata.totalTokenCount()).isEqualTo(Optional.of(150)); // 100 + 50 + + // Verify the estimator was called (it takes priority) + verify(tokenCountEstimator).estimateTokenCountInMessages(any()); + verify(tokenCountEstimator).estimateTokenCountInText("The weather is sunny today."); + } + + @Test + @DisplayName("Should not include usageMetadata when TokenUsage is null and no estimator provided") + void testNoUsageMetadataWithoutEstimator() { + // Given + // Create LangChain4j WITHOUT TokenCountEstimator (default behavior) + final LangChain4j langChain4jNoEstimator = + LangChain4j.builder().chatModel(chatModel).modelName(MODEL_NAME).build(); + + // Create a LlmRequest + final LlmRequest llmRequest = + LlmRequest.builder() + .contents(List.of(Content.fromParts(Part.fromText("Hello, world!")))) + .build(); + + // Mock ChatResponse WITHOUT TokenUsage + final ChatResponse chatResponse = mock(ChatResponse.class); + final AiMessage aiMessage = AiMessage.from("Hello! How can I help you?"); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatResponse.tokenUsage()).thenReturn(null); // No token usage from LLM + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = + langChain4jNoEstimator.generateContent(llmRequest, false).blockingFirst(); + + // Then + // Verify the response does NOT have usage metadata + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().text()).isEqualTo("Hello! How can I help you?"); + + // IMPORTANT: usageMetadata should be empty when no TokenUsage and no estimator + assertThat(response.usageMetadata()).isEmpty(); + } + + @Test + @DisplayName("Should handle null AiMessage text without throwing NPE") + void testGenerateContentWithNullAiMessageText() { + // Given + final LlmRequest llmRequest = + LlmRequest.builder().contents(List.of(Content.fromParts(Part.fromText("Hello")))).build(); + + final ChatResponse chatResponse = mock(ChatResponse.class); + final AiMessage aiMessage = mock(AiMessage.class); + when(aiMessage.text()).thenReturn(null); + when(aiMessage.hasToolExecutionRequests()).thenReturn(false); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final Flowable responseFlowable = langChain4j.generateContent(llmRequest, false); + final LlmResponse response = responseFlowable.blockingFirst(); + // Then - no NPE thrown, and content has no text parts + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().parts().orElse(List.of())).isEmpty(); + } } diff --git a/contrib/samples/a2a_basic/pom.xml b/contrib/samples/a2a_basic/pom.xml index 0eccb733b..80881842b 100644 --- a/contrib/samples/a2a_basic/pom.xml +++ b/contrib/samples/a2a_basic/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-samples - 0.9.1-SNAPSHOT + 1.0.1-rc.1-SNAPSHOT .. diff --git a/contrib/samples/a2a_server/pom.xml b/contrib/samples/a2a_server/pom.xml index 0677ad718..61c44bc97 100644 --- a/contrib/samples/a2a_server/pom.xml +++ b/contrib/samples/a2a_server/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-samples - 0.9.1-SNAPSHOT + 1.0.1-rc.1-SNAPSHOT .. diff --git a/contrib/samples/configagent/pom.xml b/contrib/samples/configagent/pom.xml index 059bd8a38..8f57b7f9e 100644 --- a/contrib/samples/configagent/pom.xml +++ b/contrib/samples/configagent/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-samples - 0.9.1-SNAPSHOT + 1.0.1-rc.1-SNAPSHOT .. diff --git a/contrib/samples/helloworld/pom.xml b/contrib/samples/helloworld/pom.xml index df5d5e709..676a2bc96 100644 --- a/contrib/samples/helloworld/pom.xml +++ b/contrib/samples/helloworld/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-samples - 0.9.1-SNAPSHOT + 1.0.1-rc.1-SNAPSHOT .. diff --git a/contrib/samples/mcpfilesystem/pom.xml b/contrib/samples/mcpfilesystem/pom.xml index 16b139d35..7275313ab 100644 --- a/contrib/samples/mcpfilesystem/pom.xml +++ b/contrib/samples/mcpfilesystem/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.9.1-SNAPSHOT + 1.0.1-rc.1-SNAPSHOT ../../.. diff --git a/contrib/samples/pom.xml b/contrib/samples/pom.xml index 4a415113f..ff48d6bd3 100644 --- a/contrib/samples/pom.xml +++ b/contrib/samples/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-parent - 0.9.1-SNAPSHOT + 1.0.1-rc.1-SNAPSHOT ../.. diff --git a/contrib/spring-ai/pom.xml b/contrib/spring-ai/pom.xml index b24fa4b63..5f7300896 100644 --- a/contrib/spring-ai/pom.xml +++ b/contrib/spring-ai/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.9.1-SNAPSHOT + 1.0.1-rc.1-SNAPSHOT ../../pom.xml @@ -29,7 +29,7 @@ Spring AI integration for the Agent Development Kit. - 2.0.0-M2 + 2.0.0-M3 1.21.3 diff --git a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/AnthropicApiIntegrationTest.java b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/AnthropicApiIntegrationTest.java index f21b07ae9..c59a94f82 100644 --- a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/AnthropicApiIntegrationTest.java +++ b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/AnthropicApiIntegrationTest.java @@ -34,7 +34,6 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.anthropic.AnthropicChatModel; import org.springframework.ai.anthropic.AnthropicChatOptions; -import org.springframework.ai.anthropic.api.AnthropicApi; /** * Integration tests with real Anthropic API. @@ -53,10 +52,14 @@ void testSimpleAgentWithRealAnthropicApi() throws InterruptedException { Thread.sleep(2000); // Create Anthropic model using Spring AI's builder pattern - AnthropicApi anthropicApi = - AnthropicApi.builder().apiKey(System.getenv("ANTHROPIC_API_KEY")).build(); - AnthropicChatModel anthropicModel = - AnthropicChatModel.builder().anthropicApi(anthropicApi).build(); + var options = + AnthropicChatOptions.builder() + .model(CLAUDE_MODEL) + .maxTokens(1024) + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .build(); + + AnthropicChatModel anthropicModel = AnthropicChatModel.builder().options(options).build(); // Wrap with SpringAI SpringAI springAI = new SpringAI(anthropicModel, CLAUDE_MODEL); @@ -92,10 +95,14 @@ void testStreamingWithRealAnthropicApi() throws InterruptedException { // Add delay to avoid rapid requests Thread.sleep(2000); - AnthropicApi anthropicApi = - AnthropicApi.builder().apiKey(System.getenv("ANTHROPIC_API_KEY")).build(); - AnthropicChatModel anthropicModel = - AnthropicChatModel.builder().anthropicApi(anthropicApi).build(); + var options = + AnthropicChatOptions.builder() + .model(CLAUDE_MODEL) + .maxTokens(1024) + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .build(); + + AnthropicChatModel anthropicModel = AnthropicChatModel.builder().options(options).build(); SpringAI springAI = new SpringAI(anthropicModel, CLAUDE_MODEL); @@ -134,10 +141,14 @@ void testStreamingWithRealAnthropicApi() throws InterruptedException { @Test void testAgentWithToolsAndRealApi() { - AnthropicApi anthropicApi = - AnthropicApi.builder().apiKey(System.getenv("ANTHROPIC_API_KEY")).build(); - AnthropicChatModel anthropicModel = - AnthropicChatModel.builder().anthropicApi(anthropicApi).build(); + var options = + AnthropicChatOptions.builder() + .model(CLAUDE_MODEL) + .maxTokens(1024) + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .build(); + + AnthropicChatModel anthropicModel = AnthropicChatModel.builder().options(options).build(); LlmAgent agent = LlmAgent.builder() @@ -175,10 +186,13 @@ void testAgentWithToolsAndRealApi() { @Test void testDirectComparisonNonStreamingVsStreaming() throws InterruptedException { // Test both non-streaming and streaming with the same model to compare behavior - AnthropicApi anthropicApi = - AnthropicApi.builder().apiKey(System.getenv("ANTHROPIC_API_KEY")).build(); - AnthropicChatModel anthropicModel = - AnthropicChatModel.builder().anthropicApi(anthropicApi).build(); + var options = + AnthropicChatOptions.builder() + .model(CLAUDE_MODEL) + .maxTokens(1024) + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .build(); + AnthropicChatModel anthropicModel = AnthropicChatModel.builder().options(options).build(); SpringAI springAI = new SpringAI(anthropicModel, CLAUDE_MODEL); @@ -271,13 +285,13 @@ void testDirectComparisonNonStreamingVsStreaming() throws InterruptedException { @Test void testConfigurationOptions() { // Test with custom configuration - AnthropicChatOptions options = - AnthropicChatOptions.builder().model(CLAUDE_MODEL).temperature(0.7).maxTokens(100).build(); - - AnthropicApi anthropicApi = - AnthropicApi.builder().apiKey(System.getenv("ANTHROPIC_API_KEY")).build(); - AnthropicChatModel anthropicModel = - AnthropicChatModel.builder().anthropicApi(anthropicApi).defaultOptions(options).build(); + var options = + AnthropicChatOptions.builder() + .model(CLAUDE_MODEL) + .maxTokens(1024) + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .build(); + AnthropicChatModel anthropicModel = AnthropicChatModel.builder().options(options).build(); SpringAI springAI = new SpringAI(anthropicModel, CLAUDE_MODEL); diff --git a/core/pom.xml b/core/pom.xml index 8f7cb0dda..c7febd65a 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.9.1-SNAPSHOT + 1.0.1-rc.1-SNAPSHOT google-adk @@ -202,6 +202,26 @@ opentelemetry-sdk-testing test + + com.google.cloud + google-cloud-bigquery + 2.40.0 + + + org.apache.arrow + arrow-vector + 17.0.0 + + + org.apache.arrow + arrow-memory-core + 17.0.0 + + + org.apache.arrow + arrow-memory-netty + 17.0.0 + com.zaxxer HikariCP @@ -297,6 +317,16 @@ maven-compiler-plugin + + maven-jar-plugin + + + + test-jar + + + + maven-surefire-plugin diff --git a/core/src/main/java/com/google/adk/Version.java b/core/src/main/java/com/google/adk/Version.java index a7aeb8b1f..d6e18c5ad 100644 --- a/core/src/main/java/com/google/adk/Version.java +++ b/core/src/main/java/com/google/adk/Version.java @@ -22,7 +22,7 @@ */ public final class Version { // Don't touch this, release-please should keep it up to date. - public static final String JAVA_ADK_VERSION = "0.9.0"; // x-release-please-released-version + public static final String JAVA_ADK_VERSION = "1.0.0-rc.1"; // x-release-please-released-version private Version() {} } diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index ed6631c50..95fe838cc 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -529,7 +529,8 @@ public B beforeAgentCallback(BeforeAgentCallback beforeAgentCallback) { @CanIgnoreReturnValue public B beforeAgentCallback(List beforeAgentCallback) { - this.beforeAgentCallback = CallbackUtil.getBeforeAgentCallbacks(beforeAgentCallback); + this.beforeAgentCallback = + ImmutableList.copyOf(CallbackUtil.getBeforeAgentCallbacks(beforeAgentCallback)); return self(); } @@ -541,7 +542,8 @@ public B afterAgentCallback(AfterAgentCallback afterAgentCallback) { @CanIgnoreReturnValue public B afterAgentCallback(List afterAgentCallback) { - this.afterAgentCallback = CallbackUtil.getAfterAgentCallbacks(afterAgentCallback); + this.afterAgentCallback = + ImmutableList.copyOf(CallbackUtil.getAfterAgentCallbacks(afterAgentCallback)); return self(); } diff --git a/core/src/main/java/com/google/adk/agents/CallbackUtil.java b/core/src/main/java/com/google/adk/agents/CallbackUtil.java index 11740ae9c..4eb8704b6 100644 --- a/core/src/main/java/com/google/adk/agents/CallbackUtil.java +++ b/core/src/main/java/com/google/adk/agents/CallbackUtil.java @@ -42,7 +42,7 @@ public final class CallbackUtil { * @return normalized async callbacks, or empty list if input is null. */ @CanIgnoreReturnValue - public static ImmutableList getBeforeAgentCallbacks( + public static List getBeforeAgentCallbacks( List beforeAgentCallbacks) { return getCallbacks( beforeAgentCallbacks, @@ -59,7 +59,7 @@ public static ImmutableList getBeforeAgentCallbacks( * @return normalized async callbacks, or empty list if input is null. */ @CanIgnoreReturnValue - public static ImmutableList getAfterAgentCallbacks( + public static List getAfterAgentCallbacks( List afterAgentCallback) { return getCallbacks( afterAgentCallback, diff --git a/core/src/main/java/com/google/adk/agents/InvocationContext.java b/core/src/main/java/com/google/adk/agents/InvocationContext.java index 91ce13a87..365f4f8c1 100644 --- a/core/src/main/java/com/google/adk/agents/InvocationContext.java +++ b/core/src/main/java/com/google/adk/agents/InvocationContext.java @@ -90,16 +90,6 @@ public Builder toBuilder() { return new Builder(this); } - /** - * Creates a shallow copy of the given {@link InvocationContext}. - * - * @deprecated Use {@code other.toBuilder().build()} instead. - */ - @Deprecated(forRemoval = true) - public static InvocationContext copyOf(InvocationContext other) { - return other.toBuilder().build(); - } - /** Returns the session service for managing session state. */ public BaseSessionService sessionService() { return sessionService; @@ -156,16 +146,6 @@ public BaseAgent agent() { return agent; } - /** - * Sets the [agent] being invoked. This is useful when delegating to a sub-agent. - * - * @deprecated Use {@link #toBuilder()} and {@link Builder#agent(BaseAgent)} instead. - */ - @Deprecated(forRemoval = true) - public void agent(BaseAgent agent) { - this.agent = agent; - } - /** Returns the session associated with this invocation. */ public Session session() { return session; diff --git a/core/src/main/java/com/google/adk/agents/LlmAgent.java b/core/src/main/java/com/google/adk/agents/LlmAgent.java index 6f6846bb8..530e3a2ff 100644 --- a/core/src/main/java/com/google/adk/agents/LlmAgent.java +++ b/core/src/main/java/com/google/adk/agents/LlmAgent.java @@ -45,8 +45,6 @@ import com.google.adk.agents.ConfigAgentUtils.ConfigurationException; import com.google.adk.codeexecutors.BaseCodeExecutor; import com.google.adk.events.Event; -import com.google.adk.examples.BaseExampleProvider; -import com.google.adk.examples.Example; import com.google.adk.flows.llmflows.AutoFlow; import com.google.adk.flows.llmflows.BaseLlmFlow; import com.google.adk.flows.llmflows.SingleFlow; @@ -98,8 +96,6 @@ public enum IncludeContents { private final List toolsUnion; private final ImmutableList toolsets; private final Optional generateContentConfig; - // TODO: Remove exampleProvider field - examples should only be provided via ExampleTool - private final Optional exampleProvider; private final IncludeContents includeContents; private final boolean planning; @@ -133,7 +129,6 @@ protected LlmAgent(Builder builder) { this.globalInstruction = requireNonNullElse(builder.globalInstruction, new Instruction.Static("")); this.generateContentConfig = Optional.ofNullable(builder.generateContentConfig); - this.exampleProvider = Optional.ofNullable(builder.exampleProvider); this.includeContents = requireNonNullElse(builder.includeContents, IncludeContents.DEFAULT); this.planning = builder.planning != null && builder.planning; this.maxSteps = Optional.ofNullable(builder.maxSteps); @@ -181,7 +176,6 @@ public static class Builder extends BaseAgent.Builder { private Instruction globalInstruction; private ImmutableList toolsUnion; private GenerateContentConfig generateContentConfig; - private BaseExampleProvider exampleProvider; private IncludeContents includeContents; private Boolean planning; private Integer maxSteps; @@ -254,26 +248,6 @@ public Builder generateContentConfig(GenerateContentConfig generateContentConfig return this; } - // TODO: Remove these example provider methods and only use ExampleTool for providing examples. - // Direct example methods should be deprecated in favor of using ExampleTool consistently. - @CanIgnoreReturnValue - public Builder exampleProvider(BaseExampleProvider exampleProvider) { - this.exampleProvider = exampleProvider; - return this; - } - - @CanIgnoreReturnValue - public Builder exampleProvider(List examples) { - this.exampleProvider = (unused) -> examples; - return this; - } - - @CanIgnoreReturnValue - public Builder exampleProvider(Example... examples) { - this.exampleProvider = (unused) -> ImmutableList.copyOf(examples); - return this; - } - @CanIgnoreReturnValue public Builder includeContents(IncludeContents includeContents) { this.includeContents = includeContents; @@ -621,32 +595,6 @@ protected void validate() { this.disallowTransferToParent != null && this.disallowTransferToParent; this.disallowTransferToPeers = this.disallowTransferToPeers != null && this.disallowTransferToPeers; - - if (this.outputSchema != null) { - if (!this.disallowTransferToParent || !this.disallowTransferToPeers) { - logger.warn( - "Invalid config for agent {}: outputSchema cannot co-exist with agent transfer" - + " configurations. Setting disallowTransferToParent=true and" - + " disallowTransferToPeers=true.", - this.name); - this.disallowTransferToParent = true; - this.disallowTransferToPeers = true; - } - - if (this.subAgents != null && !this.subAgents.isEmpty()) { - throw new IllegalArgumentException( - "Invalid config for agent " - + this.name - + ": if outputSchema is set, subAgents must be empty to disable agent" - + " transfer."); - } - if (this.toolsUnion != null && !this.toolsUnion.isEmpty()) { - throw new IllegalArgumentException( - "Invalid config for agent " - + this.name - + ": if outputSchema is set, tools must be empty."); - } - } } @Override @@ -894,11 +842,6 @@ public Optional generateContentConfig() { return generateContentConfig; } - // TODO: Remove this getter - examples should only be provided via ExampleTool - public Optional exampleProvider() { - return exampleProvider; - } - public IncludeContents includeContents() { return includeContents; } @@ -911,7 +854,7 @@ public List toolsUnion() { return toolsUnion; } - public ImmutableList toolsets() { + public List toolsets() { return toolsets; } diff --git a/core/src/main/java/com/google/adk/apps/App.java b/core/src/main/java/com/google/adk/apps/App.java index 897e24490..a087b738c 100644 --- a/core/src/main/java/com/google/adk/apps/App.java +++ b/core/src/main/java/com/google/adk/apps/App.java @@ -64,7 +64,7 @@ public BaseAgent rootAgent() { return rootAgent; } - public ImmutableList plugins() { + public List plugins() { return plugins; } diff --git a/core/src/main/java/com/google/adk/artifacts/CassandraArtifactService.java b/core/src/main/java/com/google/adk/artifacts/CassandraArtifactService.java index 9977381e8..e30656942 100644 --- a/core/src/main/java/com/google/adk/artifacts/CassandraArtifactService.java +++ b/core/src/main/java/com/google/adk/artifacts/CassandraArtifactService.java @@ -29,7 +29,7 @@ import io.reactivex.rxjava3.core.Single; import java.util.ArrayList; import java.util.List; -import java.util.Optional; +import org.jspecify.annotations.Nullable; /** * A Cassandra-backed implementation of the {@link BaseArtifactService}. @@ -74,11 +74,11 @@ public Single saveArtifact( @Override public Maybe loadArtifact( - String appName, String userId, String sessionId, String filename, Optional version) { + String appName, String userId, String sessionId, String filename, @Nullable Integer version) { return Maybe.fromCallable( () -> { Row row; - if (version.isPresent()) { + if (version != null) { row = session .execute( @@ -87,7 +87,7 @@ public Maybe loadArtifact( userId, sessionId, filename, - version.get()) + version) .one(); } else { row = @@ -197,9 +197,7 @@ public static void main(String[] args) { // Load the artifact Part loadedArtifact = - artifactService - .loadArtifact(appName, userId, sessionId, filename, Optional.of(version)) - .blockingGet(); + artifactService.loadArtifact(appName, userId, sessionId, filename, version).blockingGet(); System.out.println("Loaded artifact content: " + loadedArtifact.text().get()); CassandraHelper.close(); diff --git a/core/src/main/java/com/google/adk/artifacts/MapDbArtifactService.java b/core/src/main/java/com/google/adk/artifacts/MapDbArtifactService.java index a2a087b9a..7d88d5b00 100644 --- a/core/src/main/java/com/google/adk/artifacts/MapDbArtifactService.java +++ b/core/src/main/java/com/google/adk/artifacts/MapDbArtifactService.java @@ -19,10 +19,10 @@ import io.reactivex.rxjava3.core.Single; import java.io.File; import java.util.NavigableMap; -import java.util.Optional; import java.util.Set; import java.util.logging.Level; import java.util.stream.Collectors; +import org.jspecify.annotations.Nullable; import org.mapdb.BTreeMap; // BTreeMap is suitable for range queries import org.mapdb.DB; import org.mapdb.DBMaker; @@ -167,15 +167,15 @@ public Single saveArtifact( */ @Override public Maybe loadArtifact( - String appName, String userId, String sessionId, String filename, Optional version) { + String appName, String userId, String sessionId, String filename, @Nullable Integer version) { // The Callable should return the item (Part) or null. // Maybe.fromCallable will wrap the non-null item in a Maybe or emit empty if null. return Maybe.fromCallable( () -> { String key; - if (version.isPresent()) { + if (version != null) { // Load specific version - int v = version.get(); + int v = version; if (v < 0) { // Version numbers must be non-negative return null; // Return null for empty Maybe } diff --git a/core/src/main/java/com/google/adk/artifacts/MongoDbArtifactService.java b/core/src/main/java/com/google/adk/artifacts/MongoDbArtifactService.java index 0259a3b93..8313b57be 100644 --- a/core/src/main/java/com/google/adk/artifacts/MongoDbArtifactService.java +++ b/core/src/main/java/com/google/adk/artifacts/MongoDbArtifactService.java @@ -5,7 +5,7 @@ import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; -import java.util.Optional; +import org.jspecify.annotations.Nullable; /** * @author Harshavardhan A @@ -22,7 +22,7 @@ public Single saveArtifact( @Override public Maybe loadArtifact( - String appName, String userId, String sessionId, String filename, Optional version) { + String appName, String userId, String sessionId, String filename, @Nullable Integer version) { return null; } diff --git a/core/src/main/java/com/google/adk/artifacts/PostgresArtifactService.java b/core/src/main/java/com/google/adk/artifacts/PostgresArtifactService.java index a47e13fb1..74c60bd8e 100644 --- a/core/src/main/java/com/google/adk/artifacts/PostgresArtifactService.java +++ b/core/src/main/java/com/google/adk/artifacts/PostgresArtifactService.java @@ -26,7 +26,7 @@ import io.reactivex.rxjava3.schedulers.Schedulers; import java.sql.SQLException; import java.util.List; -import java.util.Optional; +import org.jspecify.annotations.Nullable; /** * A PostgreSQL-backed implementation of the {@link BaseArtifactService}. @@ -166,14 +166,13 @@ public Single saveArtifact( @Override public Maybe loadArtifact( - String appName, String userId, String sessionId, String filename, Optional version) { + String appName, String userId, String sessionId, String filename, @Nullable Integer version) { return Maybe.fromCallable( () -> { try { // Load from database ArtifactData artifactData = - dbHelper.loadArtifact( - appName, userId, sessionId, filename, version.orElse(null)); + dbHelper.loadArtifact(appName, userId, sessionId, filename, version); if (artifactData == null) { return null; diff --git a/core/src/main/java/com/google/adk/artifacts/RedisArtifactService.java b/core/src/main/java/com/google/adk/artifacts/RedisArtifactService.java index 87c58cdf7..c8553715a 100644 --- a/core/src/main/java/com/google/adk/artifacts/RedisArtifactService.java +++ b/core/src/main/java/com/google/adk/artifacts/RedisArtifactService.java @@ -25,7 +25,7 @@ import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; -import java.util.Optional; +import org.jspecify.annotations.Nullable; import reactor.adapter.rxjava.RxJava3Adapter; /** @@ -72,11 +72,11 @@ public Single saveArtifact( @Override public Maybe loadArtifact( - String appName, String userId, String sessionId, String filename, Optional version) { + String appName, String userId, String sessionId, String filename, @Nullable Integer version) { String key = artifactKey(appName, userId, sessionId, filename); Single data; - if (version.isPresent()) { - data = RxJava3Adapter.monoToSingle(commands.lindex(key, version.get())); + if (version != null) { + data = RxJava3Adapter.monoToSingle(commands.lindex(key, version)); } else { data = RxJava3Adapter.monoToSingle(commands.lindex(key, -1)); } diff --git a/core/src/main/java/com/google/adk/events/EventActions.java b/core/src/main/java/com/google/adk/events/EventActions.java index 1ca856b45..105300aa0 100644 --- a/core/src/main/java/com/google/adk/events/EventActions.java +++ b/core/src/main/java/com/google/adk/events/EventActions.java @@ -41,7 +41,7 @@ public class EventActions extends JsonBaseModel { private Set deletedArtifactIds; private @Nullable String transferToAgent; private @Nullable Boolean escalate; - private ConcurrentMap> requestedAuthConfigs; + private ConcurrentMap> requestedAuthConfigs; private ConcurrentMap requestedToolConfirmations; private boolean endOfAgent; private @Nullable EventCompaction compaction; @@ -139,13 +139,17 @@ public void setEscalate(@Nullable Boolean escalate) { } @JsonProperty("requestedAuthConfigs") - public ConcurrentMap> requestedAuthConfigs() { + public Map> requestedAuthConfigs() { return requestedAuthConfigs; } public void setRequestedAuthConfigs( - ConcurrentMap> requestedAuthConfigs) { - this.requestedAuthConfigs = requestedAuthConfigs; + Map> requestedAuthConfigs) { + if (requestedAuthConfigs == null) { + this.requestedAuthConfigs = new ConcurrentHashMap<>(); + } else { + this.requestedAuthConfigs = new ConcurrentHashMap<>(requestedAuthConfigs); + } } @JsonProperty("requestedToolConfirmations") @@ -157,9 +161,6 @@ public void setRequestedToolConfirmations( Map requestedToolConfirmations) { if (requestedToolConfirmations == null) { this.requestedToolConfirmations = new ConcurrentHashMap<>(); - } else if (requestedToolConfirmations instanceof ConcurrentMap) { - this.requestedToolConfirmations = - (ConcurrentMap) requestedToolConfirmations; } else { this.requestedToolConfirmations = new ConcurrentHashMap<>(requestedToolConfirmations); } @@ -251,7 +252,7 @@ public static class Builder { private Set deletedArtifactIds; private @Nullable String transferToAgent; private @Nullable Boolean escalate; - private ConcurrentMap> requestedAuthConfigs; + private ConcurrentMap> requestedAuthConfigs; private ConcurrentMap requestedToolConfirmations; private boolean endOfAgent = false; private @Nullable EventCompaction compaction; @@ -287,15 +288,23 @@ public Builder skipSummarization(boolean skipSummarization) { @CanIgnoreReturnValue @JsonProperty("stateDelta") - public Builder stateDelta(ConcurrentMap value) { - this.stateDelta = value; + public Builder stateDelta(@Nullable Map value) { + if (value == null) { + this.stateDelta = new ConcurrentHashMap<>(); + } else { + this.stateDelta = new ConcurrentHashMap<>(value); + } return this; } @CanIgnoreReturnValue @JsonProperty("artifactDelta") - public Builder artifactDelta(Map value) { - this.artifactDelta = new ConcurrentHashMap<>(value); + public Builder artifactDelta(@Nullable Map value) { + if (value == null) { + this.artifactDelta = new ConcurrentHashMap<>(); + } else { + this.artifactDelta = new ConcurrentHashMap<>(value); + } return this; } @@ -323,8 +332,12 @@ public Builder escalate(boolean escalate) { @CanIgnoreReturnValue @JsonProperty("requestedAuthConfigs") public Builder requestedAuthConfigs( - ConcurrentMap> value) { - this.requestedAuthConfigs = value; + @Nullable Map> value) { + if (value == null) { + this.requestedAuthConfigs = new ConcurrentHashMap<>(); + } else { + this.requestedAuthConfigs = new ConcurrentHashMap<>(value); + } return this; } @@ -333,10 +346,6 @@ public Builder requestedAuthConfigs( public Builder requestedToolConfirmations(@Nullable Map value) { if (value == null) { this.requestedToolConfirmations = new ConcurrentHashMap<>(); - return this; - } - if (value instanceof ConcurrentMap) { - this.requestedToolConfirmations = (ConcurrentMap) value; } else { this.requestedToolConfirmations = new ConcurrentHashMap<>(value); } diff --git a/core/src/main/java/com/google/adk/examples/ExampleUtils.java b/core/src/main/java/com/google/adk/examples/ExampleUtils.java index 9cce535dc..2f3927ece 100644 --- a/core/src/main/java/com/google/adk/examples/ExampleUtils.java +++ b/core/src/main/java/com/google/adk/examples/ExampleUtils.java @@ -64,6 +64,9 @@ public final class ExampleUtils { * @return string representation of the examples block. */ private static String convertExamplesToText(List examples) { + if (examples.isEmpty()) { + return ""; + } StringBuilder examplesStr = new StringBuilder(); // super header diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index ab5f6567a..d4fe1b838 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -91,8 +91,9 @@ public BaseLlmFlow( * RequestProcessor} transforming the provided {@code llmRequestRef} in-place, and emits the * events generated by them. */ - protected Flowable preprocess( + private Flowable preprocess( InvocationContext context, AtomicReference llmRequestRef) { + Context currentContext = Context.current(); LlmAgent agent = (LlmAgent) context.agent(); RequestProcessor toolsProcessor = @@ -114,6 +115,7 @@ protected Flowable preprocess( .concatMap( processor -> Single.defer(() -> processor.processRequest(context, llmRequestRef.get())) + .compose(Tracing.withContext(currentContext)) .doOnSuccess(result -> llmRequestRef.set(result.updatedRequest())) .flattenAsFlowable( result -> result.events() != null ? result.events() : ImmutableList.of())); @@ -128,7 +130,8 @@ protected Flowable postprocess( InvocationContext context, Event baseEventForLlmResponse, LlmRequest llmRequest, - LlmResponse llmResponse) { + LlmResponse llmResponse, + Context parentContext) { List> eventIterables = new ArrayList<>(); Single currentLlmResponse = Single.just(llmResponse); @@ -144,15 +147,16 @@ protected Flowable postprocess( }) .map(ResponseProcessingResult::updatedResponse); } - Context parentContext = Context.current(); - return currentLlmResponse.flatMapPublisher( - updatedResponse -> { - try (Scope scope = parentContext.makeCurrent()) { - return buildPostprocessingEvents( - updatedResponse, eventIterables, context, baseEventForLlmResponse, llmRequest); - } - }); + updatedResponse -> + buildPostprocessingEvents( + updatedResponse, + eventIterables, + context, + baseEventForLlmResponse, + llmRequest, + parentContext) + .compose(Tracing.withContext(parentContext))); } /** @@ -163,54 +167,80 @@ protected Flowable postprocess( * @param eventForCallbackUsage An Event object primarily for providing context (like actions) to * callbacks. Callbacks should not rely on its ID if they create their own separate events. */ - private Flowable callLlm( + private Flowable callLlm( Context spanContext, InvocationContext context, LlmRequest llmRequest, Event eventForCallbackUsage) { - LlmAgent agent = (LlmAgent) context.agent(); - LlmRequest.Builder llmRequestBuilder = llmRequest.toBuilder(); return handleBeforeModelCallback(context, llmRequestBuilder, eventForCallbackUsage) .toFlowable() + .concatMap( + llmResp -> + postprocess( + context, + eventForCallbackUsage, + llmRequestBuilder.build(), + llmResp, + spanContext)) .switchIfEmpty( Flowable.defer( () -> { + LlmAgent agent = (LlmAgent) context.agent(); BaseLlm llm = agent.resolvedModel().model().isPresent() ? agent.resolvedModel().model().get() : LlmRegistry.getLlm(agent.resolvedModel().modelName().get()); - return llm.generateContent( - llmRequestBuilder.build(), - context.runConfig().streamingMode() == StreamingMode.SSE) - .onErrorResumeNext( - exception -> - handleOnModelErrorCallback( - context, llmRequestBuilder, eventForCallbackUsage, exception) - .switchIfEmpty(Single.error(exception)) - .toFlowable()) - .doOnError( - error -> { - Span span = Span.current(); - span.setStatus(StatusCode.ERROR, error.getMessage()); - span.recordException(error); - }) - .compose( - Tracing.trace("call_llm") - .setParent(spanContext) - .onSuccess( - (span, llmResp) -> - Tracing.traceCallLlm( - span, + LlmRequest finalLlmRequest = llmRequestBuilder.build(); + + Span span = + Tracing.getTracer() + .spanBuilder("call_llm") + .setParent(spanContext) + .startSpan(); + Context callLlmContext = spanContext.with(span); + + Flowable flowable = + llm.generateContent( + finalLlmRequest, + context.runConfig().streamingMode() == StreamingMode.SSE) + .onErrorResumeNext( + exception -> + handleOnModelErrorCallback( + context, + llmRequestBuilder, + eventForCallbackUsage, + exception) + .switchIfEmpty(Single.error(exception)) + .toFlowable()) + .doOnError( + error -> { + span.setStatus(StatusCode.ERROR, error.getMessage()); + span.recordException(error); + }) + .concatMap( + llmResp -> + handleAfterModelCallback(context, llmResp, eventForCallbackUsage) + .toFlowable()) + .flatMap( + llmResp -> + postprocess( context, - eventForCallbackUsage.id(), - llmRequestBuilder.build(), - llmResp))) - .concatMap( - llmResp -> - handleAfterModelCallback(context, llmResp, eventForCallbackUsage) - .toFlowable()); + eventForCallbackUsage, + finalLlmRequest, + llmResp, + callLlmContext) + .doOnSubscribe( + s -> + Tracing.traceCallLlm( + span, + context, + eventForCallbackUsage.id(), + finalLlmRequest, + llmResp))); + + return Tracing.traceFlowable(callLlmContext, span, () -> flowable); })); } @@ -222,6 +252,7 @@ private Flowable callLlm( */ private Maybe handleBeforeModelCallback( InvocationContext context, LlmRequest.Builder llmRequestBuilder, Event modelResponseEvent) { + Context currentContext = Context.current(); Event callbackEvent = modelResponseEvent.toBuilder().build(); CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); @@ -240,7 +271,11 @@ private Maybe handleBeforeModelCallback( Maybe.defer( () -> Flowable.fromIterable(callbacks) - .concatMapMaybe(callback -> callback.call(callbackContext, llmRequestBuilder)) + .concatMapMaybe( + callback -> + callback + .call(callbackContext, llmRequestBuilder) + .compose(Tracing.withContext(currentContext))) .firstElement()); return pluginResult.switchIfEmpty(callbackResult); @@ -257,6 +292,7 @@ private Maybe handleOnModelErrorCallback( LlmRequest.Builder llmRequestBuilder, Event modelResponseEvent, Throwable throwable) { + Context currentContext = Context.current(); Event callbackEvent = modelResponseEvent.toBuilder().build(); CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); @@ -277,7 +313,11 @@ private Maybe handleOnModelErrorCallback( () -> { LlmRequest llmRequest = llmRequestBuilder.build(); return Flowable.fromIterable(callbacks) - .concatMapMaybe(callback -> callback.call(callbackContext, llmRequest, ex)) + .concatMapMaybe( + callback -> + callback + .call(callbackContext, llmRequest, ex) + .compose(Tracing.withContext(currentContext))) .firstElement(); }); @@ -292,6 +332,7 @@ private Maybe handleOnModelErrorCallback( */ private Single handleAfterModelCallback( InvocationContext context, LlmResponse llmResponse, Event modelResponseEvent) { + Context currentContext = Context.current(); Event callbackEvent = modelResponseEvent.toBuilder().build(); CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); @@ -310,7 +351,11 @@ private Single handleAfterModelCallback( Maybe.defer( () -> Flowable.fromIterable(callbacks) - .concatMapMaybe(callback -> callback.call(callbackContext, llmResponse)) + .concatMapMaybe( + callback -> + callback + .call(callbackContext, llmResponse) + .compose(Tracing.withContext(currentContext))) .firstElement()); return pluginResult.switchIfEmpty(callbackResult).defaultIfEmpty(llmResponse); @@ -330,7 +375,6 @@ private Flowable runOneStep(Context spanContext, InvocationContext contex return Flowable.defer( () -> { - Context currentContext = Context.current(); return preprocess(context, llmRequestRef) .concatWith( Flowable.defer( @@ -362,23 +406,12 @@ private Flowable runOneStep(Context spanContext, InvocationContext contex context, llmRequestAfterPreprocess, mutableEventTemplate) - .concatMap( - llmResponse -> { - try (Scope postScope = currentContext.makeCurrent()) { - return postprocess( - context, - mutableEventTemplate, - llmRequestAfterPreprocess, - llmResponse) - .doFinally( - () -> { - String oldId = mutableEventTemplate.id(); - String newId = Event.generateEventId(); - logger.debug( - "Resetting event ID from {} to {}", oldId, newId); - mutableEventTemplate.setId(newId); - }); - } + .doFinally( + () -> { + String oldId = mutableEventTemplate.id(); + String newId = Event.generateEventId(); + logger.debug("Resetting event ID from {} to {}", oldId, newId); + mutableEventTemplate.setId(newId); }) .concatMap( event -> { @@ -397,7 +430,12 @@ private Flowable runOneStep(Context spanContext, InvocationContext contex "Agent not found: " + agentToTransfer))); } return postProcessedEvents.concatWith( - Flowable.defer(() -> nextAgent.get().runAsync(context))); + Flowable.defer( + () -> { + try (Scope s = spanContext.makeCurrent()) { + return nextAgent.get().runAsync(context); + } + })); } return postProcessedEvents; }); @@ -455,6 +493,8 @@ private Flowable run( public Flowable runLive(InvocationContext invocationContext) { AtomicReference llmRequestRef = new AtomicReference<>(LlmRequest.builder().build()); Flowable preprocessEvents = preprocess(invocationContext, llmRequestRef); + // Capture agent context at assembly time to use as parent for agent transfer at subscription + // time. See Flowable.defer() usages below. Context spanContext = Context.current(); return preprocessEvents.concatWith( @@ -545,6 +585,10 @@ public void onError(Throwable e) { .author(invocationContext.agent().name()) .branch(invocationContext.branch().orElse(null)); + Span span = + Tracing.getTracer().spanBuilder("call_llm").setParent(spanContext).startSpan(); + Context callLlmContext = spanContext.with(span); + Flowable receiveFlow = connection .receive() @@ -556,7 +600,8 @@ public void onError(Throwable e) { invocationContext, baseEventForThisLlmResponse, llmRequestAfterPreprocess, - llmResponse); + llmResponse, + callLlmContext); }) .flatMap( event -> { @@ -570,7 +615,12 @@ public void onError(Throwable e) { "Agent not found: " + event.actions().transferToAgent().get()); } Flowable nextAgentEvents = - nextAgent.get().runLive(invocationContext); + Flowable.defer( + () -> { + try (Scope s = spanContext.makeCurrent()) { + return nextAgent.get().runLive(invocationContext); + } + }); events = Flowable.concat(events, nextAgentEvents); } return events; @@ -592,7 +642,12 @@ public void onError(Throwable e) { } }); - return receiveFlow.takeWhile(event -> !event.actions().endInvocation().orElse(false)); + return Tracing.traceFlowable( + callLlmContext, + span, + () -> + receiveFlow.takeWhile( + event -> !event.actions().endInvocation().orElse(false))); })); } @@ -608,7 +663,8 @@ private Flowable buildPostprocessingEvents( List> eventIterables, InvocationContext context, Event baseEventForLlmResponse, - LlmRequest llmRequest) { + LlmRequest llmRequest, + Context parentContext) { Flowable processorEvents = Flowable.fromIterable(Iterables.concat(eventIterables)); if (updatedResponse.content().isEmpty() && updatedResponse.errorCode().isEmpty() @@ -624,21 +680,28 @@ private Flowable buildPostprocessingEvents( return processorEvents.concatWith(Flowable.just(modelResponseEvent)); } - Maybe maybeFunctionResponseEvent = - context.runConfig().streamingMode() == StreamingMode.BIDI - ? Functions.handleFunctionCallsLive(context, modelResponseEvent, llmRequest.tools()) - : Functions.handleFunctionCalls(context, modelResponseEvent, llmRequest.tools()); - - Flowable functionEvents = - maybeFunctionResponseEvent.flatMapPublisher( - functionResponseEvent -> { - Optional toolConfirmationEvent = - Functions.generateRequestConfirmationEvent( - context, modelResponseEvent, functionResponseEvent); - return toolConfirmationEvent.isPresent() - ? Flowable.just(toolConfirmationEvent.get(), functionResponseEvent) - : Flowable.just(functionResponseEvent); - }); + Flowable functionEvents; + try (Scope scope = parentContext.makeCurrent()) { + Maybe maybeFunctionResponseEvent = + context.runConfig().streamingMode() == StreamingMode.BIDI + ? Functions.handleFunctionCallsLive(context, modelResponseEvent, llmRequest.tools()) + : Functions.handleFunctionCalls(context, modelResponseEvent, llmRequest.tools()); + functionEvents = + maybeFunctionResponseEvent.flatMapPublisher( + functionResponseEvent -> { + Optional toolConfirmationEvent = + Functions.generateRequestConfirmationEvent( + context, modelResponseEvent, functionResponseEvent); + List events = new ArrayList<>(); + toolConfirmationEvent.ifPresent(events::add); + events.add(functionResponseEvent); + OutputSchema.getStructuredModelResponse(functionResponseEvent) + .ifPresent( + json -> + events.add(OutputSchema.createFinalModelResponseEvent(context, json))); + return Flowable.fromIterable(events); + }); + } return processorEvents.concatWith(Flowable.just(modelResponseEvent)).concatWith(functionEvents); } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Basic.java b/core/src/main/java/com/google/adk/flows/llmflows/Basic.java index 0876a26e8..5aa970be6 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Basic.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Basic.java @@ -19,6 +19,7 @@ import com.google.adk.agents.InvocationContext; import com.google.adk.agents.LlmAgent; import com.google.adk.models.LlmRequest; +import com.google.adk.utils.ModelNameUtils; import com.google.common.collect.ImmutableList; import com.google.genai.types.GenerateContentConfig; import com.google.genai.types.LiveConnectConfig; @@ -60,7 +61,15 @@ public Single processRequest( .orElseGet(() -> GenerateContentConfig.builder().build())) .liveConnectConfig(liveConnectConfigBuilder.build()); - agent.outputSchema().ifPresent(builder::outputSchema); + agent + .outputSchema() + .ifPresent( + outputSchema -> { + if (agent.toolsUnion().isEmpty() + || ModelNameUtils.canUseOutputSchemaWithTools(modelName)) { + builder.outputSchema(outputSchema); + } + }); return Single.just( RequestProcessor.RequestProcessingResult.create(builder.build(), ImmutableList.of())); } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Examples.java b/core/src/main/java/com/google/adk/flows/llmflows/Examples.java deleted file mode 100644 index d9cee5fa0..000000000 --- a/core/src/main/java/com/google/adk/flows/llmflows/Examples.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.adk.flows.llmflows; - -import com.google.adk.agents.InvocationContext; -import com.google.adk.agents.LlmAgent; -import com.google.adk.examples.ExampleUtils; -import com.google.adk.models.LlmRequest; -import com.google.common.collect.ImmutableList; -import com.google.genai.types.Content; -import io.reactivex.rxjava3.core.Single; - -/** {@link RequestProcessor} that populates examples in LLM request. */ -public final class Examples implements RequestProcessor { - - public Examples() {} - - @Override - public Single processRequest( - InvocationContext context, LlmRequest request) { - if (!(context.agent() instanceof LlmAgent)) { - throw new IllegalArgumentException("Agent in InvocationContext is not an instance of Agent."); - } - LlmAgent agent = (LlmAgent) context.agent(); - LlmRequest.Builder builder = request.toBuilder(); - - String query = - context - .userContent() - .flatMap(Content::parts) - .filter(parts -> !parts.isEmpty()) - .map(parts -> parts.get(0).text().orElse("")) - .orElse(""); - agent - .exampleProvider() - .ifPresent( - exampleProvider -> - builder.appendInstructions( - ImmutableList.of(ExampleUtils.buildExampleSi(exampleProvider, query)))); - return Single.just( - RequestProcessor.RequestProcessingResult.create(builder.build(), ImmutableList.of())); - } -} diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index c1a996064..84a8141ea 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -42,7 +42,6 @@ import com.google.genai.types.Part; import io.opentelemetry.api.trace.Span; import io.opentelemetry.context.Context; -import io.opentelemetry.context.Scope; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Observable; @@ -163,7 +162,9 @@ public static Maybe handleFunctionCalls( } return functionResponseEventsObservable .toList() - .flatMapMaybe( + .toMaybe() + .compose(Tracing.withContext(parentContext)) + .flatMap( events -> { if (events.isEmpty()) { return Maybe.empty(); @@ -226,7 +227,9 @@ public static Maybe handleFunctionCallsLive( return responseEventsObservable .toList() - .flatMapMaybe( + .toMaybe() + .compose(Tracing.withContext(parentContext)) + .flatMap( events -> { if (events.isEmpty()) { return Maybe.empty(); @@ -243,47 +246,45 @@ private static Function> getFunctionCallMapper( Context parentContext) { return functionCall -> Maybe.defer( - () -> { - try (Scope scope = parentContext.makeCurrent()) { - BaseTool tool = tools.get(functionCall.name().get()); - ToolContext toolContext = - ToolContext.builder(invocationContext) - .functionCallId(functionCall.id().orElse("")) - .toolConfirmation( - functionCall.id().map(toolConfirmations::get).orElse(null)) - .build(); - - Map functionArgs = - functionCall.args().map(HashMap::new).orElse(new HashMap<>()); - - Maybe> maybeFunctionResult = - maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext) - .switchIfEmpty( - Maybe.defer( - () -> { - try (Scope innerScope = parentContext.makeCurrent()) { - return isLive - ? processFunctionLive( - invocationContext, - tool, - toolContext, - functionCall, - functionArgs, - parentContext) - : callTool(tool, functionArgs, toolContext, parentContext); - } - })); - - return postProcessFunctionResult( - maybeFunctionResult, - invocationContext, - tool, - functionArgs, - toolContext, - isLive, - parentContext); - } - }); + () -> { + BaseTool tool = tools.get(functionCall.name().get()); + ToolContext toolContext = + ToolContext.builder(invocationContext) + .functionCallId(functionCall.id().orElse("")) + .toolConfirmation( + functionCall.id().map(toolConfirmations::get).orElse(null)) + .build(); + + Map functionArgs = + functionCall.args().map(HashMap::new).orElse(new HashMap<>()); + + Maybe> maybeFunctionResult = + maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext) + .switchIfEmpty( + Maybe.defer( + () -> + isLive + ? processFunctionLive( + invocationContext, + tool, + toolContext, + functionCall, + functionArgs, + parentContext) + : callTool( + tool, functionArgs, toolContext, parentContext)) + .compose(Tracing.withContext(parentContext))); + + return postProcessFunctionResult( + maybeFunctionResult, + invocationContext, + tool, + functionArgs, + toolContext, + isLive, + parentContext); + }) + .compose(Tracing.withContext(parentContext)); } /** @@ -410,34 +411,27 @@ private static Maybe postProcessFunctionResult( }) .flatMapMaybe( optionalInitialResult -> { - try (Scope scope = parentContext.makeCurrent()) { - Map initialFunctionResult = optionalInitialResult.orElse(null); - - return maybeInvokeAfterToolCall( - invocationContext, tool, functionArgs, toolContext, initialFunctionResult) - .map(Optional::of) - .defaultIfEmpty(Optional.ofNullable(initialFunctionResult)) - .flatMapMaybe( - finalOptionalResult -> { - Map finalFunctionResult = - finalOptionalResult.orElse(null); - if (tool.longRunning() && finalFunctionResult == null) { - return Maybe.empty(); - } - return Maybe.fromCallable( - () -> - buildResponseEvent( - tool, - finalFunctionResult, - toolContext, - invocationContext)) - .compose( - Tracing.trace("tool_response [" + tool.name() + "]") - .setParent(parentContext)) - .doOnSuccess(event -> Tracing.traceToolResponse(event.id(), event)); - }); - } - }); + Map initialFunctionResult = optionalInitialResult.orElse(null); + + return maybeInvokeAfterToolCall( + invocationContext, tool, functionArgs, toolContext, initialFunctionResult) + .map(Optional::of) + .defaultIfEmpty(Optional.ofNullable(initialFunctionResult)) + .flatMapMaybe( + finalOptionalResult -> { + Map finalFunctionResult = finalOptionalResult.orElse(null); + if (tool.longRunning() && finalFunctionResult == null) { + return Maybe.empty(); + } + Event event = + buildResponseEvent( + tool, finalFunctionResult, toolContext, invocationContext); + Tracing.traceToolResponse(event.id(), event); + return Maybe.just(event); + }); + }) + .compose( + Tracing.trace("tool_response [" + tool.name() + "]").setParent(parentContext)); } private static Optional mergeParallelFunctionResponseEvents( diff --git a/core/src/main/java/com/google/adk/flows/llmflows/OutputSchema.java b/core/src/main/java/com/google/adk/flows/llmflows/OutputSchema.java new file mode 100644 index 000000000..d1f322f18 --- /dev/null +++ b/core/src/main/java/com/google/adk/flows/llmflows/OutputSchema.java @@ -0,0 +1,119 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.flows.llmflows; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.google.adk.JsonBaseModel; +import com.google.adk.agents.InvocationContext; +import com.google.adk.agents.LlmAgent; +import com.google.adk.events.Event; +import com.google.adk.models.LlmRequest; +import com.google.adk.tools.SetModelResponseTool; +import com.google.adk.tools.ToolContext; +import com.google.adk.utils.ModelNameUtils; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.Content; +import com.google.genai.types.FunctionResponse; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Single; +import java.util.Objects; +import java.util.Optional; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Processor that handles output schema for agents with tools. */ +public final class OutputSchema implements RequestProcessor { + + private static final Logger logger = LoggerFactory.getLogger(OutputSchema.class); + + public OutputSchema() {} + + @Override + public Single processRequest( + InvocationContext context, LlmRequest request) { + if (!(context.agent() instanceof LlmAgent)) { + return Single.just(RequestProcessingResult.create(request, ImmutableList.of())); + } + LlmAgent agent = (LlmAgent) context.agent(); + String modelName = request.model().orElse(""); + + if (agent.outputSchema().isEmpty() + || agent.toolsUnion().isEmpty() + || ModelNameUtils.canUseOutputSchemaWithTools(modelName)) { + return Single.just(RequestProcessingResult.create(request, ImmutableList.of())); + } + + // Add the set_model_response tool to handle structured output + SetModelResponseTool setResponseTool = new SetModelResponseTool(agent.outputSchema().get()); + LlmRequest.Builder builder = request.toBuilder(); + + return setResponseTool + .processLlmRequest(builder, ToolContext.builder(context).build()) + .andThen( + Single.fromCallable( + () -> { + builder.appendInstructions( + ImmutableList.of( + "IMPORTANT: You have access to other tools, but you must provide your" + + " final response using the set_model_response tool with the" + + " required structured format. After using any other tools needed" + + " to complete the task, always call set_model_response with your" + + " final answer in the specified schema format.")); + return RequestProcessingResult.create(builder.build(), ImmutableList.of()); + })); + } + + /** + * Check if function response contains set_model_response and extract JSON. + * + * @param functionResponseEvent The function response event to check. + * @return JSON response string if set_model_response was called, Optional.empty() otherwise. + */ + public static Optional getStructuredModelResponse(Event functionResponseEvent) { + for (FunctionResponse funcResponse : functionResponseEvent.functionResponses()) { + if (Objects.equals(funcResponse.name().orElse(""), SetModelResponseTool.NAME)) { + Object response = funcResponse.response(); + // The tool returns the args map directly. + try { + return Optional.of(JsonBaseModel.getMapper().writeValueAsString(response)); + } catch (JsonProcessingException e) { + logger.error("Failed to serialize set_model_response result", e); + return Optional.empty(); + } + } + } + return Optional.empty(); + } + + /** + * Create a final model response event from set_model_response JSON. + * + * @param context The invocation context. + * @param jsonResponse The JSON response from set_model_response tool. + * @return A new Event that looks like a normal model response. + */ + public static Event createFinalModelResponseEvent( + InvocationContext context, String jsonResponse) { + return Event.builder() + .id(Event.generateEventId()) + .invocationId(context.invocationId()) + .author(context.agent().name()) + .branch(context.branch().orElse(null)) + .content(Content.builder().role("model").parts(Part.fromText(jsonResponse)).build()) + .build(); + } +} diff --git a/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java b/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java index e00c0093d..a93eb3cb4 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java @@ -29,6 +29,7 @@ import com.google.adk.events.Event; import com.google.adk.events.ToolConfirmation; import com.google.adk.models.LlmRequest; +import com.google.adk.telemetry.Tracing; import com.google.adk.tools.BaseTool; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -37,6 +38,7 @@ import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; +import io.opentelemetry.context.Context; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; import java.util.Collection; @@ -216,10 +218,13 @@ private Maybe assembleEvent( .build()) .build(); - return toolsMapSingle.flatMapMaybe( - toolsMap -> - Functions.handleFunctionCalls( - invocationContext, functionCallEvent, toolsMap, toolConfirmations)); + Context parentContext = Context.current(); + return toolsMapSingle + .flatMapMaybe( + toolsMap -> + Functions.handleFunctionCalls( + invocationContext, functionCallEvent, toolsMap, toolConfirmations)) + .compose(Tracing.withContext(parentContext)); } private static Optional> maybeCreateToolConfirmationEntry( diff --git a/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java index de45ba702..41dff3b96 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java @@ -27,12 +27,12 @@ public class SingleFlow extends BaseLlmFlow { protected static final ImmutableList REQUEST_PROCESSORS = ImmutableList.of( new Basic(), + new OutputSchema(), new RequestConfirmationLlmRequestProcessor(), new Instructions(), new Identity(), new Compaction(), new Contents(), - new Examples(), CodeExecution.requestProcessor); protected static final ImmutableList RESPONSE_PROCESSORS = diff --git a/core/src/main/java/com/google/adk/models/Gemini.java b/core/src/main/java/com/google/adk/models/Gemini.java index 74cf78b98..6f145e1de 100644 --- a/core/src/main/java/com/google/adk/models/Gemini.java +++ b/core/src/main/java/com/google/adk/models/Gemini.java @@ -239,7 +239,7 @@ public Flowable generateContent(LlmRequest llmRequest, boolean stre p -> p.functionCall().isPresent() || p.functionResponse().isPresent() - || p.text().map(t -> !t.isBlank()).orElse(false))) + || p.text().isPresent())) .orElse(false)); } else { logger.debug("Sending generateContent request to model {}", effectiveModelName); @@ -272,11 +272,17 @@ static Flowable processRawResponses(Flowable 0 @@ -316,11 +322,20 @@ static Flowable processRawResponses(Flowable finalResponses = new ArrayList<>(); if (accumulatedThoughtText.length() > 0) { finalResponses.add( - thinkingResponseFromText(accumulatedThoughtText.toString())); + thinkingResponseFromText(accumulatedThoughtText.toString()).toBuilder() + .usageMetadata( + accumulatedText.length() > 0 + ? null + : finalRawResp.usageMetadata().orElse(null)) + .build()); } if (accumulatedText.length() > 0) { - finalResponses.add(responseFromText(accumulatedText.toString())); + finalResponses.add( + responseFromText(accumulatedText.toString()).toBuilder() + .usageMetadata(finalRawResp.usageMetadata().orElse(null)) + .build()); } + return Flowable.fromIterable(finalResponses); } return Flowable.empty(); diff --git a/core/src/main/java/com/google/adk/models/LlmRequest.java b/core/src/main/java/com/google/adk/models/LlmRequest.java index 1a45c3a95..760a7c1c6 100644 --- a/core/src/main/java/com/google/adk/models/LlmRequest.java +++ b/core/src/main/java/com/google/adk/models/LlmRequest.java @@ -150,7 +150,7 @@ private static Builder create() { abstract LiveConnectConfig liveConnectConfig(); @CanIgnoreReturnValue - abstract Builder tools(Map tools); + public abstract Builder tools(Map tools); abstract Map tools(); diff --git a/core/src/main/java/com/google/adk/plugins/PluginManager.java b/core/src/main/java/com/google/adk/plugins/PluginManager.java index e534da787..8d0366e9a 100644 --- a/core/src/main/java/com/google/adk/plugins/PluginManager.java +++ b/core/src/main/java/com/google/adk/plugins/PluginManager.java @@ -21,11 +21,13 @@ import com.google.adk.events.Event; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; +import com.google.adk.telemetry.Tracing; import com.google.adk.tools.BaseTool; import com.google.adk.tools.ToolContext; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; +import io.opentelemetry.context.Context; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; @@ -126,6 +128,7 @@ public Maybe beforeRunCallback(InvocationContext invocationContext) { @Override public Completable afterRunCallback(InvocationContext invocationContext) { + Context capturedContext = Context.current(); return Flowable.fromIterable(plugins) .concatMapCompletable( plugin -> @@ -136,11 +139,13 @@ public Completable afterRunCallback(InvocationContext invocationContext) { logger.error( "[{}] Error during callback 'afterRunCallback'", plugin.getName(), - e))); + e))) + .compose(Tracing.withContext(capturedContext)); } @Override public Completable close() { + Context capturedContext = Context.current(); return Flowable.fromIterable(plugins) .concatMapCompletableDelayError( plugin -> @@ -149,7 +154,8 @@ public Completable close() { .doOnError( e -> logger.error( - "[{}] Error during callback 'close'", plugin.getName(), e))); + "[{}] Error during callback 'close'", plugin.getName(), e))) + .compose(Tracing.withContext(capturedContext)); } @Override @@ -227,7 +233,7 @@ public Maybe> onToolErrorCallback( */ private Maybe runMaybeCallbacks( Function> callbackExecutor, String callbackName) { - + Context capturedContext = Context.current(); return Flowable.fromIterable(this.plugins) .concatMapMaybe( plugin -> @@ -247,6 +253,7 @@ private Maybe runMaybeCallbacks( plugin.getName(), callbackName, e))) - .firstElement(); + .firstElement() + .compose(Tracing.withContext(capturedContext)); } } diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BatchProcessor.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BatchProcessor.java new file mode 100644 index 000000000..ef826fb56 --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BatchProcessor.java @@ -0,0 +1,270 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.NANOSECONDS; +import static java.util.concurrent.TimeUnit.SECONDS; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.cloud.bigquery.storage.v1.AppendRowsResponse; +import com.google.cloud.bigquery.storage.v1.Exceptions.AppendSerializationError; +import com.google.cloud.bigquery.storage.v1.StreamWriter; +import com.google.common.annotations.VisibleForTesting; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.logging.Level; +import java.util.logging.Logger; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.TimeStampVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; + +/** Handles asynchronous batching and writing of events to BigQuery. */ +class BatchProcessor implements AutoCloseable { + private static final Logger logger = Logger.getLogger(BatchProcessor.class.getName()); + + private final StreamWriter writer; + private final int batchSize; + private final Duration flushInterval; + @VisibleForTesting final BlockingQueue> queue; + private final ScheduledExecutorService executor; + @VisibleForTesting final BufferAllocator allocator; + final AtomicBoolean flushLock = new AtomicBoolean(false); + private final Schema arrowSchema; + private final VectorSchemaRoot root; + + public BatchProcessor( + StreamWriter writer, + int batchSize, + Duration flushInterval, + int queueMaxSize, + ScheduledExecutorService executor) { + this.writer = writer; + this.batchSize = batchSize; + this.flushInterval = flushInterval; + this.queue = new LinkedBlockingQueue<>(queueMaxSize); + this.executor = executor; + // It's safe to use Long.MAX_VALUE here as this is a top-level RootAllocator, + // and memory is properly managed via try-with-resources in the flush() method. + // The actual memory usage is bounded by the batchSize and individual row sizes. + this.allocator = new RootAllocator(Long.MAX_VALUE); + this.arrowSchema = BigQuerySchema.getArrowSchema(); + this.root = VectorSchemaRoot.create(arrowSchema, allocator); + } + + public void start() { + @SuppressWarnings("unused") + var unused = + executor.scheduleWithFixedDelay( + () -> { + try { + flush(); + } catch (RuntimeException e) { + logger.log(Level.SEVERE, "Error in background flush", e); + } + }, + flushInterval.toMillis(), + flushInterval.toMillis(), + MILLISECONDS); + } + + public void append(Map row) { + if (!queue.offer(row)) { + logger.warning("BigQuery event queue is full, dropping event."); + return; + } + if (queue.size() >= batchSize && !flushLock.get()) { + executor.execute(this::flush); + } + } + + public void flush() { + // Acquire the flushLock. If another flush is already in progress, return immediately. + if (!flushLock.compareAndSet(false, true)) { + return; + } + try { + if (queue.isEmpty()) { + return; + } + List> batch = new ArrayList<>(); + queue.drainTo(batch, batchSize); + if (batch.isEmpty()) { + return; + } + try { + root.allocateNew(); + for (int i = 0; i < batch.size(); i++) { + Map row = batch.get(i); + for (Field field : arrowSchema.getFields()) { + populateVector(root.getVector(field.getName()), i, row.get(field.getName())); + } + } + root.setRowCount(batch.size()); + try (ArrowRecordBatch recordBatch = new VectorUnloader(root).getRecordBatch()) { + AppendRowsResponse result = writer.append(recordBatch).get(); + if (result.hasError()) { + logger.severe("BigQuery append error: " + result.getError().getMessage()); + for (var error : result.getRowErrorsList()) { + logger.severe( + String.format("Row error at index %d: %s", error.getIndex(), error.getMessage())); + } + } else { + logger.fine("Successfully wrote " + batch.size() + " rows to BigQuery."); + } + } catch (AppendSerializationError ase) { + logger.log( + Level.SEVERE, "Failed to write batch to BigQuery due to serialization error", ase); + Map rowIndexToErrorMessage = ase.getRowIndexToErrorMessage(); + if (rowIndexToErrorMessage != null && !rowIndexToErrorMessage.isEmpty()) { + logger.severe("Row-level errors found:"); + for (Map.Entry entry : rowIndexToErrorMessage.entrySet()) { + logger.severe( + String.format("Row error at index %d: %s", entry.getKey(), entry.getValue())); + } + } else { + logger.severe( + "AppendSerializationError occurred, but no row-specific errors were provided."); + } + } + } catch (Exception e) { + if (e instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + logger.log(Level.SEVERE, "Failed to write batch to BigQuery", e); + } finally { + // Clear the vectors to release the memory. + root.clear(); + } + } finally { + flushLock.set(false); + if (queue.size() >= batchSize && !flushLock.get()) { + executor.execute(this::flush); + } + } + } + + private void populateVector(FieldVector vector, int index, Object value) { + if (value == null || (value instanceof JsonNode jsonNode && jsonNode.isNull())) { + vector.setNull(index); + return; + } + if (vector instanceof VarCharVector varCharVector) { + String strValue = (value instanceof JsonNode jsonNode) ? jsonNode.asText() : value.toString(); + varCharVector.setSafe(index, strValue.getBytes(UTF_8)); + } else if (vector instanceof BigIntVector bigIntVector) { + long longValue; + if (value instanceof JsonNode jsonNode) { + longValue = jsonNode.asLong(); + } else if (value instanceof Number number) { + longValue = number.longValue(); + } else { + longValue = Long.parseLong(value.toString()); + } + bigIntVector.setSafe(index, longValue); + } else if (vector instanceof BitVector bitVector) { + boolean boolValue = + (value instanceof JsonNode jsonNode) ? jsonNode.asBoolean() : (Boolean) value; + bitVector.setSafe(index, boolValue ? 1 : 0); + } else if (vector instanceof TimeStampVector timeStampVector) { + if (value instanceof Instant instant) { + long micros = + SECONDS.toMicros(instant.getEpochSecond()) + NANOSECONDS.toMicros(instant.getNano()); + timeStampVector.setSafe(index, micros); + } else if (value instanceof JsonNode jsonNode) { + timeStampVector.setSafe(index, jsonNode.asLong()); + } else if (value instanceof Long longValue) { + timeStampVector.setSafe(index, longValue); + } + } else if (vector instanceof ListVector listVector) { + int start = listVector.startNewValue(index); + if (value instanceof ArrayNode arrayNode) { + for (int i = 0; i < arrayNode.size(); i++) { + populateVector(listVector.getDataVector(), start + i, arrayNode.get(i)); + } + listVector.endValue(index, arrayNode.size()); + } else if (value instanceof List) { + List list = (List) value; + for (int i = 0; i < list.size(); i++) { + populateVector(listVector.getDataVector(), start + i, list.get(i)); + } + listVector.endValue(index, list.size()); + } + } else if (vector instanceof StructVector structVector) { + structVector.setIndexDefined(index); + if (value instanceof ObjectNode objectNode) { + for (FieldVector child : structVector.getChildrenFromFields()) { + populateVector(child, index, objectNode.get(child.getName())); + } + } else if (value instanceof Map) { + Map map = (Map) value; + for (FieldVector child : structVector.getChildrenFromFields()) { + populateVector(child, index, map.get(child.getName())); + } + } + } + } + + @Override + public void close() { + if (this.queue != null && !this.queue.isEmpty()) { + this.flush(); + } + if (this.allocator != null) { + try { + this.allocator.close(); + } catch (RuntimeException e) { + logger.log(Level.SEVERE, "Failed to close Buffer allocator", e); + } + } + if (this.root != null) { + try { + this.root.close(); + } catch (RuntimeException e) { + logger.log(Level.SEVERE, "Failed to close VectorSchemaRoot", e); + } + } + if (this.writer != null) { + try { + this.writer.close(); + } catch (RuntimeException e) { + logger.log(Level.SEVERE, "Failed to close BigQuery writer", e); + } + } + } +} diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java new file mode 100644 index 000000000..68b5fb5a1 --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java @@ -0,0 +1,436 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; + +import com.google.adk.agents.BaseAgent; +import com.google.adk.agents.CallbackContext; +import com.google.adk.agents.InvocationContext; +import com.google.adk.events.Event; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.adk.plugins.BasePlugin; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.ToolContext; +import com.google.api.gax.core.FixedCredentialsProvider; +import com.google.api.gax.retrying.RetrySettings; +import com.google.auth.oauth2.GoogleCredentials; +import com.google.cloud.bigquery.BigQuery; +import com.google.cloud.bigquery.BigQueryException; +import com.google.cloud.bigquery.BigQueryOptions; +import com.google.cloud.bigquery.Clustering; +import com.google.cloud.bigquery.Schema; +import com.google.cloud.bigquery.StandardTableDefinition; +import com.google.cloud.bigquery.Table; +import com.google.cloud.bigquery.TableId; +import com.google.cloud.bigquery.TableInfo; +import com.google.cloud.bigquery.storage.v1.BigQueryWriteClient; +import com.google.cloud.bigquery.storage.v1.BigQueryWriteSettings; +import com.google.cloud.bigquery.storage.v1.StreamWriter; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.VerifyException; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.Content; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.SpanContext; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Maybe; +import java.io.IOException; +import java.time.Instant; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.atomic.AtomicLong; +import java.util.logging.Level; +import java.util.logging.Logger; +import org.threeten.bp.Duration; + +/** + * BigQuery Agent Analytics Plugin for Java. + * + *

Logs agent execution events directly to a BigQuery table using the Storage Write API. + */ +public class BigQueryAgentAnalyticsPlugin extends BasePlugin { + private static final Logger logger = + Logger.getLogger(BigQueryAgentAnalyticsPlugin.class.getName()); + private static final ImmutableList DEFAULT_AUTH_SCOPES = + ImmutableList.of("https://www.googleapis.com/auth/cloud-platform"); + private static final AtomicLong threadCounter = new AtomicLong(0); + + private final BigQueryLoggerConfig config; + private final BigQuery bigQuery; + private final BigQueryWriteClient writeClient; + private final ScheduledExecutorService executor; + private final Object tableEnsuredLock = new Object(); + @VisibleForTesting final BatchProcessor batchProcessor; + private volatile boolean tableEnsured = false; + + public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config) throws IOException { + this(config, createBigQuery(config)); + } + + public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config, BigQuery bigQuery) + throws IOException { + super("bigquery_agent_analytics"); + this.config = config; + this.bigQuery = bigQuery; + ThreadFactory threadFactory = + r -> new Thread(r, "bq-analytics-plugin-" + threadCounter.getAndIncrement()); + this.executor = Executors.newScheduledThreadPool(1, threadFactory); + this.writeClient = createWriteClient(config); + + if (config.enabled()) { + StreamWriter writer = createWriter(config); + this.batchProcessor = + new BatchProcessor( + writer, + config.batchSize(), + config.batchFlushInterval(), + config.queueMaxSize(), + executor); + this.batchProcessor.start(); + } else { + this.batchProcessor = null; + } + } + + private static BigQuery createBigQuery(BigQueryLoggerConfig config) throws IOException { + BigQueryOptions.Builder builder = BigQueryOptions.newBuilder(); + if (config.credentials() != null) { + builder.setCredentials(config.credentials()); + } else { + builder.setCredentials( + GoogleCredentials.getApplicationDefault().createScoped(DEFAULT_AUTH_SCOPES)); + } + return builder.build().getService(); + } + + private void ensureTableExistsOnce() { + if (!tableEnsured) { + synchronized (tableEnsuredLock) { + if (!tableEnsured) { + // Table creation is expensive, so we only do it once per plugin instance. + tableEnsured = true; + ensureTableExists(bigQuery, config); + } + } + } + } + + private void ensureTableExists(BigQuery bigQuery, BigQueryLoggerConfig config) { + TableId tableId = TableId.of(config.projectId(), config.datasetId(), config.tableName()); + Schema schema = BigQuerySchema.getEventsSchema(); + try { + Table table = bigQuery.getTable(tableId); + logger.info("BigQuery table: " + tableId); + if (table == null) { + logger.info("Creating BigQuery table: " + tableId); + StandardTableDefinition.Builder tableDefinitionBuilder = + StandardTableDefinition.newBuilder().setSchema(schema); + if (!config.clusteringFields().isEmpty()) { + tableDefinitionBuilder.setClustering( + Clustering.newBuilder().setFields(config.clusteringFields()).build()); + } + TableInfo tableInfo = TableInfo.newBuilder(tableId, tableDefinitionBuilder.build()).build(); + bigQuery.create(tableInfo); + } else if (config.autoSchemaUpgrade()) { + // TODO(b/491851868): Implement auto-schema upgrade. + logger.info("BigQuery table already exists and auto-schema upgrade is enabled: " + tableId); + logger.info("Auto-schema upgrade is not implemented yet."); + } + } catch (BigQueryException e) { + if (e.getMessage().contains("invalid_grant")) { + logger.log( + Level.SEVERE, + "Failed to authenticate with BigQuery. Please run 'gcloud auth application-default" + + " login' to refresh your credentials or provide valid credentials in" + + " BigQueryLoggerConfig.", + e); + } else { + logger.log( + Level.WARNING, "Failed to check or create/upgrade BigQuery table: " + tableId, e); + } + } catch (RuntimeException e) { + logger.log(Level.WARNING, "Failed to check or create/upgrade BigQuery table: " + tableId, e); + } + } + + protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) throws IOException { + if (config.credentials() != null) { + return BigQueryWriteClient.create( + BigQueryWriteSettings.newBuilder() + .setCredentialsProvider(FixedCredentialsProvider.create(config.credentials())) + .build()); + } + return BigQueryWriteClient.create(); + } + + protected String getStreamName(BigQueryLoggerConfig config) { + return String.format( + "projects/%s/datasets/%s/tables/%s/streams/_default", + config.projectId(), config.datasetId(), config.tableName()); + } + + protected StreamWriter createWriter(BigQueryLoggerConfig config) { + BigQueryLoggerConfig.RetryConfig retryConfig = config.retryConfig(); + RetrySettings retrySettings = + RetrySettings.newBuilder() + .setMaxAttempts(retryConfig.maxRetries()) + .setInitialRetryDelay(Duration.ofMillis(retryConfig.initialDelay().toMillis())) + .setRetryDelayMultiplier(retryConfig.multiplier()) + .setMaxRetryDelay(Duration.ofMillis(retryConfig.maxDelay().toMillis())) + .build(); + + String streamName = getStreamName(config); + try { + return StreamWriter.newBuilder(streamName, writeClient) + .setRetrySettings(retrySettings) + .setWriterSchema(BigQuerySchema.getArrowSchema()) + .build(); + } catch (Exception e) { + throw new VerifyException("Failed to create StreamWriter for " + streamName, e); + } + } + + private void logEvent( + String eventType, + InvocationContext invocationContext, + Optional callbackContext, + Object content, + Map extraAttributes) { + if (batchProcessor == null) { + return; + } + + ensureTableExistsOnce(); + + Map row = new HashMap<>(); + row.put("timestamp", Instant.now()); + row.put("event_type", eventType); + row.put( + "agent", + callbackContext.map(CallbackContext::agentName).orElse(invocationContext.agent().name())); + row.put("session_id", invocationContext.session().id()); + row.put("invocation_id", invocationContext.invocationId()); + row.put("user_id", invocationContext.userId()); + + if (content instanceof Content contentParts) { + row.put( + "content_parts", + JsonFormatter.formatContentParts(Optional.of(contentParts), config.maxContentLength())); + row.put( + "content", JsonFormatter.smartTruncate(content, config.maxContentLength()).toString()); + } else if (content != null) { + row.put( + "content", JsonFormatter.smartTruncate(content, config.maxContentLength()).toString()); + } + + Map attributes = new HashMap<>(config.customTags()); + if (extraAttributes != null) { + attributes.putAll(extraAttributes); + } + row.put( + "attributes", + JsonFormatter.smartTruncate(attributes, config.maxContentLength()).toString()); + + addTraceDetails(row); + batchProcessor.append(row); + } + + // TODO(b/491849911): Implement own trace management functionality. + private void addTraceDetails(Map row) { + SpanContext spanContext = Span.current().getSpanContext(); + if (spanContext.isValid()) { + row.put("trace_id", spanContext.getTraceId()); + row.put("span_id", spanContext.getSpanId()); + } + } + + @Override + public Completable close() { + if (batchProcessor != null) { + batchProcessor.close(); + } + if (writeClient != null) { + writeClient.close(); + } + try { + executor.shutdown(); + if (!executor.awaitTermination(config.shutdownTimeout().toMillis(), MILLISECONDS)) { + executor.shutdownNow(); + } + } catch (InterruptedException e) { + executor.shutdownNow(); + Thread.currentThread().interrupt(); + } + return Completable.complete(); + } + + @Override + public Maybe onUserMessageCallback( + InvocationContext invocationContext, Content userMessage) { + return Maybe.fromAction( + () -> logEvent("USER_MESSAGE", invocationContext, Optional.empty(), userMessage, null)); + } + + @Override + public Maybe beforeRunCallback(InvocationContext invocationContext) { + return Maybe.fromAction( + () -> logEvent("INVOCATION_START", invocationContext, Optional.empty(), null, null)); + } + + @Override + public Maybe onEventCallback(InvocationContext invocationContext, Event event) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + attrs.put("event_author", event.author()); + logEvent( + "EVENT", invocationContext, Optional.empty(), event.content().orElse(null), attrs); + }); + } + + @Override + public Completable afterRunCallback(InvocationContext invocationContext) { + return Completable.fromAction( + () -> { + logEvent("INVOCATION_END", invocationContext, Optional.empty(), null, null); + batchProcessor.flush(); + }); + } + + @Override + public Maybe beforeAgentCallback(BaseAgent agent, CallbackContext callbackContext) { + return Maybe.fromAction( + () -> + logEvent( + "AGENT_START", + callbackContext.invocationContext(), + Optional.of(callbackContext), + null, + null)); + } + + @Override + public Maybe afterAgentCallback(BaseAgent agent, CallbackContext callbackContext) { + return Maybe.fromAction( + () -> + logEvent( + "AGENT_END", + callbackContext.invocationContext(), + Optional.of(callbackContext), + null, + null)); + } + + @Override + public Maybe beforeModelCallback( + CallbackContext callbackContext, LlmRequest.Builder llmRequest) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + LlmRequest req = llmRequest.build(); + attrs.put("model", req.model().orElse("unknown")); + logEvent( + "MODEL_REQUEST", + callbackContext.invocationContext(), + Optional.of(callbackContext), + req, + attrs); + }); + } + + @Override + public Maybe afterModelCallback( + CallbackContext callbackContext, LlmResponse llmResponse) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + llmResponse.usageMetadata().ifPresent(u -> attrs.put("usage_metadata", u)); + logEvent( + "MODEL_RESPONSE", + callbackContext.invocationContext(), + Optional.of(callbackContext), + llmResponse, + attrs); + }); + } + + @Override + public Maybe onModelErrorCallback( + CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + attrs.put("error_message", error.getMessage()); + logEvent( + "MODEL_ERROR", + callbackContext.invocationContext(), + Optional.of(callbackContext), + null, + attrs); + }); + } + + @Override + public Maybe> beforeToolCallback( + BaseTool tool, Map toolArgs, ToolContext toolContext) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + attrs.put("tool_name", tool.name()); + logEvent( + "TOOL_START", + toolContext.invocationContext(), + Optional.of(toolContext), + toolArgs, + attrs); + }); + } + + @Override + public Maybe> afterToolCallback( + BaseTool tool, + Map toolArgs, + ToolContext toolContext, + Map result) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + attrs.put("tool_name", tool.name()); + logEvent( + "TOOL_END", toolContext.invocationContext(), Optional.of(toolContext), result, attrs); + }); + } + + @Override + public Maybe> onToolErrorCallback( + BaseTool tool, Map toolArgs, ToolContext toolContext, Throwable error) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + attrs.put("tool_name", tool.name()); + attrs.put("error_message", error.getMessage()); + logEvent( + "TOOL_ERROR", toolContext.invocationContext(), Optional.of(toolContext), null, attrs); + }); + } +} diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java new file mode 100644 index 000000000..aa5bf37de --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java @@ -0,0 +1,204 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import com.google.auth.Credentials; +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.BiFunction; +import javax.annotation.Nullable; + +/** Configuration for the BigQueryAgentAnalyticsPlugin. */ +@AutoValue +public abstract class BigQueryLoggerConfig { + // Whether the plugin is enabled. + public abstract boolean enabled(); + + // List of event types to log. If None, all are allowed + // TODO(b/491852782): Implement allowlist/denylist for event types. + @Nullable + public abstract ImmutableList eventAllowlist(); + + // List of event types to ignore. + // TODO(b/491852782): Implement allowlist/denylist for event types. + @Nullable + public abstract ImmutableList eventDenylist(); + + // Max length for text content before truncation. + public abstract int maxContentLength(); + + // Project ID for the BigQuery table. + public abstract String projectId(); + + // Dataset ID for the BigQuery table. + public abstract String datasetId(); + + // Table name for the BigQuery table. + public abstract String tableName(); + + // Fields to cluster the table by. + public abstract ImmutableList clusteringFields(); + + // Whether to log multi-modal content. + // TODO(b/491852782): Implement logging of multi-modal content. + public abstract boolean logMultiModalContent(); + + // Retry configuration for BigQuery writes. + public abstract RetryConfig retryConfig(); + + // Number of rows to batch before flushing. + public abstract int batchSize(); + + // Duration to wait before flushing the queue. + public abstract Duration batchFlushInterval(); + + // Max time to wait for shutdown. + public abstract Duration shutdownTimeout(); + + // Max size of the batch processor queue. + public abstract int queueMaxSize(); + + // Optional custom formatter for content. + // TODO(b/491852782): Implement content formatter. + @Nullable + public abstract BiFunction contentFormatter(); + + // TODO(b/491852782): Implement connection id. + public abstract Optional connectionId(); + + // Toggle for session metadata (e.g. gchat thread-id). + // TODO(b/491852782): Implement logging of session metadata. + public abstract boolean logSessionMetadata(); + + // Static custom tags (e.g. {"agent_role": "sales"}). + // TODO(b/491852782): Implement custom tags. + public abstract ImmutableMap customTags(); + + // Automatically add new columns to existing tables when the plugin + // schema evolves. Only additive changes are made (columns are never + // dropped or altered). + // TODO(b/491852782): Implement auto-schema upgrade. + public abstract boolean autoSchemaUpgrade(); + + @Nullable + public abstract Credentials credentials(); + + public static Builder builder() { + return new AutoValue_BigQueryLoggerConfig.Builder() + .setEnabled(true) + .setMaxContentLength(500 * 1024) + .setDatasetId("agent_analytics") + .setTableName("events") + .setClusteringFields(ImmutableList.of("event_type", "agent", "user_id")) + .setLogMultiModalContent(true) + .setRetryConfig(RetryConfig.builder().build()) + .setBatchSize(1) + .setBatchFlushInterval(Duration.ofSeconds(1)) + .setShutdownTimeout(Duration.ofSeconds(10)) + .setQueueMaxSize(10000) + .setLogSessionMetadata(true) + .setCustomTags(ImmutableMap.of()) + // TODO(b/491851868): Enable auto-schema upgrade once implemented. + .setAutoSchemaUpgrade(false); + } + + /** Builder for {@link BigQueryLoggerConfig}. */ + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setEnabled(boolean enabled); + + public abstract Builder setEventAllowlist(@Nullable List eventAllowlist); + + public abstract Builder setEventDenylist(@Nullable List eventDenylist); + + public abstract Builder setMaxContentLength(int maxContentLength); + + public abstract Builder setProjectId(String projectId); + + public abstract Builder setDatasetId(String datasetId); + + public abstract Builder setTableName(String tableName); + + public abstract Builder setClusteringFields(List clusteringFields); + + public abstract Builder setLogMultiModalContent(boolean logMultiModalContent); + + public abstract Builder setRetryConfig(RetryConfig retryConfig); + + public abstract Builder setBatchSize(int batchSize); + + public abstract Builder setBatchFlushInterval(Duration batchFlushInterval); + + public abstract Builder setShutdownTimeout(Duration shutdownTimeout); + + public abstract Builder setQueueMaxSize(int queueMaxSize); + + public abstract Builder setContentFormatter( + @Nullable BiFunction contentFormatter); + + public abstract Builder setConnectionId(String connectionId); + + public abstract Builder setLogSessionMetadata(boolean logSessionMetadata); + + public abstract Builder setCustomTags(Map customTags); + + public abstract Builder setAutoSchemaUpgrade(boolean autoSchemaUpgrade); + + public abstract Builder setCredentials(Credentials credentials); + + public abstract BigQueryLoggerConfig build(); + } + + /** Retry configuration for BigQuery writes. */ + @AutoValue + public abstract static class RetryConfig { + public abstract int maxRetries(); + + public abstract Duration initialDelay(); + + public abstract double multiplier(); + + public abstract Duration maxDelay(); + + public static Builder builder() { + return new AutoValue_BigQueryLoggerConfig_RetryConfig.Builder() + .setMaxRetries(3) + .setInitialDelay(Duration.ofSeconds(1)) + .setMultiplier(2.0) + .setMaxDelay(Duration.ofSeconds(10)); + } + + /** Builder for {@link RetryConfig}. */ + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setMaxRetries(int maxRetries); + + public abstract Builder setInitialDelay(Duration initialDelay); + + public abstract Builder setMultiplier(double multiplier); + + public abstract Builder setMaxDelay(Duration maxDelay); + + public abstract RetryConfig build(); + } + } +} diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQuerySchema.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQuerySchema.java new file mode 100644 index 000000000..81181a1e0 --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQuerySchema.java @@ -0,0 +1,304 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import static com.google.common.collect.ImmutableList.toImmutableList; + +import com.google.cloud.bigquery.Field; +import com.google.cloud.bigquery.FieldList; +import com.google.cloud.bigquery.Schema; +import com.google.cloud.bigquery.StandardSQLTypeName; +import com.google.cloud.bigquery.storage.v1.TableFieldSchema; +import com.google.cloud.bigquery.storage.v1.TableFieldSchema.Mode; +import com.google.cloud.bigquery.storage.v1.TableSchema; +import com.google.common.base.VerifyException; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.ByteString; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.channels.Channels; +import org.apache.arrow.vector.ipc.WriteChannel; +import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.types.DateUnit; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.FieldType; + +/** Utility for defining the BigQuery events table schema. */ +public final class BigQuerySchema { + + private BigQuerySchema() {} + + private static final ImmutableMap> + FIELD_TYPE_TO_ARROW_FIELD_METADATA = + ImmutableMap.of( + StandardSQLTypeName.JSON, + ImmutableMap.of("ARROW:extension:name", "google:sqlType:json"), + StandardSQLTypeName.DATETIME, + ImmutableMap.of("ARROW:extension:name", "google:sqlType:datetime"), + StandardSQLTypeName.GEOGRAPHY, + ImmutableMap.of( + "ARROW:extension:name", + "google:sqlType:geography", + "ARROW:extension:metadata", + "{\"encoding\": \"WKT\"}")); + + /** Returns the BigQuery schema for the events table. */ + // TODO(b/491848381): Rely on the same schema defined for python plugin. + public static Schema getEventsSchema() { + return Schema.of( + Field.newBuilder("timestamp", StandardSQLTypeName.TIMESTAMP) + .setMode(Field.Mode.REQUIRED) + .setDescription("The UTC timestamp when the event occurred.") + .build(), + Field.newBuilder("event_type", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("The category of the event.") + .build(), + Field.newBuilder("agent", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("The name of the agent that generated this event.") + .build(), + Field.newBuilder("session_id", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("A unique identifier for the entire conversation session.") + .build(), + Field.newBuilder("invocation_id", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("A unique identifier for a single turn or execution.") + .build(), + Field.newBuilder("user_id", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("The identifier of the end-user.") + .build(), + Field.newBuilder("trace_id", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("OpenTelemetry trace ID.") + .build(), + Field.newBuilder("span_id", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("OpenTelemetry span ID.") + .build(), + Field.newBuilder("parent_span_id", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("OpenTelemetry parent span ID.") + .build(), + Field.newBuilder("content", StandardSQLTypeName.JSON) + .setMode(Field.Mode.NULLABLE) + .setDescription("The primary payload of the event.") + .build(), + Field.newBuilder( + "content_parts", + StandardSQLTypeName.STRUCT, + FieldList.of( + Field.newBuilder("mime_type", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("The MIME type of the content part.") + .build(), + Field.newBuilder("uri", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("The URI of the content part if stored externally.") + .build(), + Field.newBuilder( + "object_ref", + StandardSQLTypeName.STRUCT, + FieldList.of( + Field.newBuilder("uri", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .build(), + Field.newBuilder("version", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .build(), + Field.newBuilder("authorizer", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .build(), + Field.newBuilder("details", StandardSQLTypeName.JSON) + .setMode(Field.Mode.NULLABLE) + .build())) + .setMode(Field.Mode.NULLABLE) + .setDescription("The ObjectRef of the content part if stored externally.") + .build(), + Field.newBuilder("text", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("The raw text content.") + .build(), + Field.newBuilder("part_index", StandardSQLTypeName.INT64) + .setMode(Field.Mode.NULLABLE) + .setDescription("The zero-based index of this part.") + .build(), + Field.newBuilder("part_attributes", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("Additional metadata as a JSON object string.") + .build(), + Field.newBuilder("storage_mode", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("Indicates how the content part is stored.") + .build())) + .setMode(Field.Mode.REPEATED) + .setDescription("Multi-modal events content parts.") + .build(), + Field.newBuilder("attributes", StandardSQLTypeName.JSON) + .setMode(Field.Mode.NULLABLE) + .setDescription("A JSON object containing arbitrary key-value pairs.") + .build(), + Field.newBuilder("latency_ms", StandardSQLTypeName.JSON) + .setMode(Field.Mode.NULLABLE) + .setDescription("A JSON object containing latency measurements.") + .build(), + Field.newBuilder("status", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("The outcome of the event.") + .build(), + Field.newBuilder("error_message", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("Detailed error message if the status is 'ERROR'.") + .build(), + Field.newBuilder("is_truncated", StandardSQLTypeName.BOOL) + .setMode(Field.Mode.NULLABLE) + .setDescription("Indicates if the 'content' field was truncated.") + .build()); + } + + /** Returns the Arrow schema for the events table. */ + public static org.apache.arrow.vector.types.pojo.Schema getArrowSchema() { + return new org.apache.arrow.vector.types.pojo.Schema( + getEventsSchema().getFields().stream() + .map(BigQuerySchema::convertToArrowField) + .collect(toImmutableList())); + } + + /** Returns the serialized Arrow schema for the events table. */ + public static ByteString getSerializedArrowSchema() { + try (ByteArrayOutputStream out = new ByteArrayOutputStream()) { + MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), getArrowSchema()); + return ByteString.copyFrom(out.toByteArray()); + } catch (IOException e) { + throw new VerifyException("Failed to serialize arrow schema", e); + } + } + + private static org.apache.arrow.vector.types.pojo.Field convertToArrowField(Field field) { + ArrowType arrowType = convertTypeToArrow(field.getType().getStandardType()); + ImmutableList children = null; + if (field.getSubFields() != null) { + children = + field.getSubFields().stream() + .map(BigQuerySchema::convertToArrowField) + .collect(toImmutableList()); + } + + ImmutableMap metadata = + FIELD_TYPE_TO_ARROW_FIELD_METADATA.get(field.getType().getStandardType()); + + FieldType fieldType = + new FieldType(field.getMode() != Field.Mode.REQUIRED, arrowType, null, metadata); + org.apache.arrow.vector.types.pojo.Field arrowField = + new org.apache.arrow.vector.types.pojo.Field(field.getName(), fieldType, children); + + if (field.getMode() == Field.Mode.REPEATED) { + return new org.apache.arrow.vector.types.pojo.Field( + field.getName(), + new FieldType(false, new ArrowType.List(), null), + ImmutableList.of( + new org.apache.arrow.vector.types.pojo.Field( + "element", arrowField.getFieldType(), arrowField.getChildren()))); + } + return arrowField; + } + + private static ArrowType convertTypeToArrow(StandardSQLTypeName type) { + return switch (type) { + case BOOL -> new ArrowType.Bool(); + case BYTES -> new ArrowType.Binary(); + case DATE -> new ArrowType.Date(DateUnit.DAY); + case DATETIME -> + // Arrow doesn't have a direct DATETIME, often mapped to Timestamp or Utf8 + new ArrowType.Utf8(); + case FLOAT64 -> new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE); + case INT64 -> new ArrowType.Int(64, true); + case NUMERIC, BIGNUMERIC -> new ArrowType.Decimal(38, 9, 128); + case GEOGRAPHY, STRING, JSON -> new ArrowType.Utf8(); + case STRUCT -> new ArrowType.Struct(); + case TIME -> new ArrowType.Time(TimeUnit.MICROSECOND, 64); + case TIMESTAMP -> new ArrowType.Timestamp(TimeUnit.MICROSECOND, "UTC"); + default -> new ArrowType.Null(); + }; + } + + /** Returns names of fields to cluster by default. */ + public static ImmutableList getDefaultClusteringFields() { + return ImmutableList.of("event_type", "agent", "user_id"); + } + + /** Returns the BigQuery TableSchema for the events table (Storage Write API). */ + public static TableSchema getEventsTableSchema() { + return convertTableSchema(getEventsSchema()); + } + + private static TableSchema convertTableSchema(Schema schema) { + TableSchema.Builder result = TableSchema.newBuilder(); + for (int i = 0; i < schema.getFields().size(); i++) { + result.addFields(i, convertFieldSchema(schema.getFields().get(i))); + } + return result.build(); + } + + private static TableFieldSchema convertFieldSchema(Field field) { + TableFieldSchema.Builder result = TableFieldSchema.newBuilder(); + Field.Mode mode = field.getMode() != null ? field.getMode() : Field.Mode.NULLABLE; + + Mode resultMode = Mode.valueOf(mode.name()); + result.setMode(resultMode).setName(field.getName()); + + StandardSQLTypeName standardType = field.getType().getStandardType(); + TableFieldSchema.Type resultType = convertType(standardType); + result.setType(resultType); + + if (field.getDescription() != null) { + result.setDescription(field.getDescription()); + } + if (field.getSubFields() != null) { + for (int i = 0; i < field.getSubFields().size(); i++) { + result.addFields(i, convertFieldSchema(field.getSubFields().get(i))); + } + } + return result.build(); + } + + private static TableFieldSchema.Type convertType(StandardSQLTypeName type) { + return switch (type) { + case BOOL -> TableFieldSchema.Type.BOOL; + case BYTES -> TableFieldSchema.Type.BYTES; + case DATE -> TableFieldSchema.Type.DATE; + case DATETIME -> TableFieldSchema.Type.DATETIME; + case FLOAT64 -> TableFieldSchema.Type.DOUBLE; + case GEOGRAPHY -> TableFieldSchema.Type.GEOGRAPHY; + case INT64 -> TableFieldSchema.Type.INT64; + case NUMERIC -> TableFieldSchema.Type.NUMERIC; + case STRING -> TableFieldSchema.Type.STRING; + case STRUCT -> TableFieldSchema.Type.STRUCT; + case TIME -> TableFieldSchema.Type.TIME; + case TIMESTAMP -> TableFieldSchema.Type.TIMESTAMP; + case BIGNUMERIC -> TableFieldSchema.Type.BIGNUMERIC; + case JSON -> TableFieldSchema.Type.JSON; + case INTERVAL -> TableFieldSchema.Type.INTERVAL; + default -> TableFieldSchema.Type.TYPE_UNSPECIFIED; + }; + } +} diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/JsonFormatter.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/JsonFormatter.java new file mode 100644 index 000000000..b4b4a1049 --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/JsonFormatter.java @@ -0,0 +1,111 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.Blob; +import com.google.genai.types.Content; +import com.google.genai.types.FileData; +import com.google.genai.types.Part; +import java.util.List; +import java.util.Optional; + +/** Utility for formatting and truncating content for BigQuery logging. */ +final class JsonFormatter { + private static final ObjectMapper mapper = new ObjectMapper().findAndRegisterModules(); + + private JsonFormatter() {} + + /** Formats Content parts into an ArrayNode for BigQuery logging. */ + public static ArrayNode formatContentParts(Optional content, int maxLength) { + ArrayNode partsArray = mapper.createArrayNode(); + if (content.isEmpty() || content.get().parts() == null) { + return partsArray; + } + + List parts = content.get().parts().orElse(ImmutableList.of()); + + for (int i = 0; i < parts.size(); i++) { + Part part = parts.get(i); + ObjectNode partObj = mapper.createObjectNode(); + partObj.put("part_index", i); + partObj.put("storage_mode", "INLINE"); + + if (part.text().isPresent()) { + partObj.put("mime_type", "text/plain"); + partObj.put("text", truncateString(part.text().get(), maxLength)); + } else if (part.inlineData().isPresent()) { + Blob blob = part.inlineData().get(); + partObj.put("mime_type", blob.mimeType().orElse("")); + partObj.put("text", "[BINARY DATA]"); + } else if (part.fileData().isPresent()) { + FileData fileData = part.fileData().get(); + partObj.put("mime_type", fileData.mimeType().orElse("")); + partObj.put("uri", fileData.fileUri().orElse("")); + partObj.put("storage_mode", "EXTERNAL_URI"); + } + partsArray.add(partObj); + } + return partsArray; + } + + /** Recursively truncates long strings inside an object and returns a Jackson JsonNode. */ + public static JsonNode smartTruncate(Object obj, int maxLength) { + if (obj == null) { + return mapper.nullNode(); + } + try { + return recursiveSmartTruncate(mapper.valueToTree(obj), maxLength); + } catch (IllegalArgumentException e) { + // Fallback for types that mapper can't handle directly as a tree + return mapper.valueToTree(String.valueOf(obj)); + } + } + + private static JsonNode recursiveSmartTruncate(JsonNode node, int maxLength) { + if (node.isTextual()) { + return mapper.valueToTree(truncateString(node.asText(), maxLength)); + } else if (node.isObject()) { + ObjectNode newNode = mapper.createObjectNode(); + node.properties() + .iterator() + .forEachRemaining( + entry -> { + newNode.set(entry.getKey(), recursiveSmartTruncate(entry.getValue(), maxLength)); + }); + return newNode; + } else if (node.isArray()) { + ArrayNode newNode = mapper.createArrayNode(); + for (JsonNode element : node) { + newNode.add(recursiveSmartTruncate(element, maxLength)); + } + return newNode; + } + return node; + } + + private static String truncateString(String s, int maxLength) { + if (s == null || s.length() <= maxLength) { + return s; + } + return s.substring(0, maxLength) + "...[truncated]"; + } +} diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 5859c4786..2bfbca881 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -52,6 +52,7 @@ import com.google.genai.types.Part; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.StatusCode; +import io.opentelemetry.context.Context; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; @@ -312,6 +313,7 @@ private Single appendNewMessageToSession( throw new IllegalArgumentException("No parts in the new_message."); } + Completable saveArtifactsFlow = Completable.complete(); if (this.artifactService != null && saveInputBlobsAsArtifacts) { // The runner directly saves the artifacts (if applicable) in the user message and replaces // the artifact data with a file name placeholder. @@ -321,9 +323,11 @@ private Single appendNewMessageToSession( continue; } String fileName = "artifact_" + invocationContext.invocationId() + "_" + i; - var unused = - this.artifactService.saveArtifact( - this.appName, session.userId(), session.id(), fileName, part); + saveArtifactsFlow = + saveArtifactsFlow.andThen( + this.artifactService + .saveArtifact(this.appName, session.userId(), session.id(), fileName, part) + .ignoreElement()); newMessage .parts() @@ -348,7 +352,8 @@ private Single appendNewMessageToSession( EventActions.builder().stateDelta(new ConcurrentHashMap<>(stateDelta)).build()); } - return this.sessionService.appendEvent(session, eventBuilder.build()); + return saveArtifactsFlow.andThen( + this.sessionService.appendEvent(session, eventBuilder.build())); } /** See {@link #runAsync(String, String, Content, RunConfig, Map)}. */ @@ -375,20 +380,25 @@ public Flowable runAsync( Content newMessage, RunConfig runConfig, @Nullable Map stateDelta) { - Maybe maybeSession = - this.sessionService.getSession(appName, userId, sessionId, Optional.empty()); - return maybeSession - .switchIfEmpty( - Single.defer( - () -> { - if (runConfig.autoCreateSession()) { - return this.sessionService.createSession(appName, userId, null, sessionId); - } - return Single.error( - new IllegalArgumentException( - String.format("Session not found: %s for user %s", sessionId, userId))); - })) - .flatMapPublisher(session -> this.runAsyncImpl(session, newMessage, runConfig, stateDelta)); + return Flowable.defer( + () -> + this.sessionService + .getSession(appName, userId, sessionId, Optional.empty()) + .switchIfEmpty( + Single.defer( + () -> { + if (runConfig.autoCreateSession()) { + return this.sessionService.createSession( + appName, userId, (Map) null, sessionId); + } + return Single.error( + new IllegalArgumentException( + String.format( + "Session not found: %s for user %s", sessionId, userId))); + })) + .flatMapPublisher( + session -> this.runAsyncImpl(session, newMessage, runConfig, stateDelta))) + .compose(Tracing.trace("invocation")); } /** See {@link #runAsync(String, String, Content, RunConfig, Map)}. */ @@ -415,35 +425,6 @@ public Flowable runAsync(String userId, String sessionId, Content newMess return runAsync(userId, sessionId, newMessage, RunConfig.builder().build()); } - /** - * See {@link #runAsync(Session, Content, RunConfig, Map)}. - * - * @deprecated Use runAsync with sessionId. - */ - @Deprecated(since = "0.4.0", forRemoval = true) - public Flowable runAsync(Session session, Content newMessage, RunConfig runConfig) { - return runAsync(session, newMessage, runConfig, /* stateDelta= */ null); - } - - /** - * Runs the agent asynchronously using a provided Session object. - * - * @param session The session to run the agent in. - * @param newMessage The new message from the user to process. - * @param runConfig Configuration for the agent run. - * @param stateDelta Optional map of state updates to merge into the session for this run. - * @return A Flowable stream of {@link Event} objects generated by the agent during execution. - * @deprecated Use runAsync with sessionId. - */ - @Deprecated(since = "0.4.0", forRemoval = true) - public Flowable runAsync( - Session session, - Content newMessage, - RunConfig runConfig, - @Nullable Map stateDelta) { - return runAsyncImpl(session, newMessage, runConfig, stateDelta); - } - /** * Runs the agent asynchronously using a provided Session object. * @@ -461,6 +442,7 @@ protected Flowable runAsyncImpl( Preconditions.checkNotNull(session, "session cannot be null"); Preconditions.checkNotNull(newMessage, "newMessage cannot be null"); Preconditions.checkNotNull(runConfig, "runConfig cannot be null"); + Context capturedContext = Context.current(); return Flowable.defer( () -> { BaseAgent rootAgent = this.agent; @@ -476,6 +458,7 @@ protected Flowable runAsyncImpl( return this.pluginManager .onUserMessageCallback(initialContext, newMessage) + .compose(Tracing.withContext(capturedContext)) .defaultIfEmpty(newMessage) .flatMap( content -> @@ -500,7 +483,8 @@ protected Flowable runAsyncImpl( event, invocationId, runConfig, - rootAgent)); + rootAgent)) + .compose(Tracing.withContext(capturedContext)); }); }) .doOnError( @@ -508,8 +492,7 @@ protected Flowable runAsyncImpl( Span span = Span.current(); span.setStatus(StatusCode.ERROR, "Error in runAsync Flowable execution"); span.recordException(throwable); - }) - .compose(Tracing.trace("invocation")); + }); } private Flowable runAgentWithFreshSession( @@ -546,7 +529,7 @@ private Flowable runAgentWithFreshSession( contextWithUpdatedSession .agent() .runAsync(contextWithUpdatedSession) - .flatMap( + .concatMap( agentEvent -> this.sessionService .appendEvent(updatedSession, agentEvent) @@ -562,12 +545,14 @@ private Flowable runAgentWithFreshSession( .toFlowable()); // If beforeRunCallback returns content, emit it and skip agent + Context capturedContext = Context.current(); return beforeRunEvent .toFlowable() .switchIfEmpty(agentEvents) .concatWith( Completable.defer(() -> pluginManager.afterRunCallback(contextWithUpdatedSession))) - .concatWith(Completable.defer(() -> compactEvents(updatedSession))); + .concatWith(Completable.defer(() -> compactEvents(updatedSession))) + .compose(Tracing.withContext(capturedContext)); } private Completable compactEvents(Session session) { @@ -632,46 +617,9 @@ private InvocationContext.Builder newInvocationContextBuilder(Session session) { .agent(this.findAgentToRun(session, rootAgent)); } - /** - * Runs the agent in live mode, appending generated events to the session. - * - * @return stream of events from the agent. - */ public Flowable runLive( Session session, LiveRequestQueue liveRequestQueue, RunConfig runConfig) { - return Flowable.defer( - () -> { - InvocationContext invocationContext = - newInvocationContextForLive(session, liveRequestQueue, runConfig); - - Single invocationContextSingle; - if (invocationContext.agent() instanceof LlmAgent agent) { - invocationContextSingle = - agent - .tools() - .map( - tools -> { - this.addActiveStreamingTools(invocationContext, tools); - return invocationContext; - }); - } else { - invocationContextSingle = Single.just(invocationContext); - } - return invocationContextSingle - .flatMapPublisher( - updatedInvocationContext -> - updatedInvocationContext - .agent() - .runLive(updatedInvocationContext) - .doOnNext(event -> this.sessionService.appendEvent(session, event))) - .doOnError( - throwable -> { - Span span = Span.current(); - span.setStatus(StatusCode.ERROR, "Error in runLive Flowable execution"); - span.recordException(throwable); - }); - }) - .compose(Tracing.trace("invocation")); + return runLiveImpl(session, liveRequestQueue, runConfig).compose(Tracing.trace("invocation")); } /** @@ -682,19 +630,25 @@ public Flowable runLive( */ public Flowable runLive( String userId, String sessionId, LiveRequestQueue liveRequestQueue, RunConfig runConfig) { - return this.sessionService - .getSession(appName, userId, sessionId, Optional.empty()) - .switchIfEmpty( - Single.defer( - () -> { - if (runConfig.autoCreateSession()) { - return this.sessionService.createSession(appName, userId, null, sessionId); - } - return Single.error( - new IllegalArgumentException( - String.format("Session not found: %s for user %s", sessionId, userId))); - })) - .flatMapPublisher(session -> this.runLive(session, liveRequestQueue, runConfig)); + return Flowable.defer( + () -> + this.sessionService + .getSession(appName, userId, sessionId, Optional.empty()) + .switchIfEmpty( + Single.defer( + () -> { + if (runConfig.autoCreateSession()) { + return this.sessionService.createSession( + appName, userId, (Map) null, sessionId); + } + return Single.error( + new IllegalArgumentException( + String.format( + "Session not found: %s for user %s", sessionId, userId))); + })) + .flatMapPublisher( + session -> this.runLiveImpl(session, liveRequestQueue, runConfig))) + .compose(Tracing.trace("invocation")); } /** @@ -709,15 +663,46 @@ public Flowable runLive( } /** - * Runs the agent asynchronously with a default user ID. + * Runs the agent in live mode, appending generated events to the session. * - * @return stream of generated events. + * @return stream of events from the agent. */ - @Deprecated(since = "0.5.0", forRemoval = true) - public Flowable runWithSessionId( - String sessionId, Content newMessage, RunConfig runConfig) { - // TODO(b/410859954): Add user_id to getter or method signature. Assuming "tmp-user" for now. - return this.runAsync("tmp-user", sessionId, newMessage, runConfig); + protected Flowable runLiveImpl( + Session session, @Nullable LiveRequestQueue liveRequestQueue, RunConfig runConfig) { + return Flowable.defer( + () -> { + Context capturedContext = Context.current(); + InvocationContext invocationContext = + newInvocationContextForLive(session, liveRequestQueue, runConfig); + + Single invocationContextSingle; + if (invocationContext.agent() instanceof LlmAgent agent) { + invocationContextSingle = + agent + .tools() + .map( + tools -> { + this.addActiveStreamingTools(invocationContext, tools); + return invocationContext; + }); + } else { + invocationContextSingle = Single.just(invocationContext); + } + return invocationContextSingle + .flatMapPublisher( + updatedInvocationContext -> + updatedInvocationContext + .agent() + .runLive(updatedInvocationContext) + .doOnNext(event -> this.sessionService.appendEvent(session, event))) + .doOnError( + throwable -> { + Span span = Span.current(); + span.setStatus(StatusCode.ERROR, "Error in runLive Flowable execution"); + span.recordException(throwable); + }) + .compose(Tracing.withContext(capturedContext)); + }); } /** diff --git a/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java b/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java index b2a584b11..d9bb047a3 100644 --- a/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java +++ b/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java @@ -154,19 +154,13 @@ public Maybe getSession( if (config.numRecentEvents().isEmpty() && config.afterTimestamp().isPresent()) { Instant threshold = config.afterTimestamp().get(); - eventsInCopy.removeIf( - event -> getEventTimestampEpochSeconds(event) < threshold.getEpochSecond()); + eventsInCopy.removeIf(event -> getInstantFromEvent(event).isBefore(threshold)); } // Merge state into the potentially filtered copy and return return Maybe.just(mergeWithGlobalState(appName, userId, sessionCopy)); } - // Helper to get event timestamp as epoch seconds - private long getEventTimestampEpochSeconds(Event event) { - return event.timestamp() / 1000L; - } - @Override public Single listSessions(String appName, String userId) { Objects.requireNonNull(appName, "appName cannot be null"); @@ -294,10 +288,7 @@ public Single appendEvent(Session session, Event event) { /** Converts an event's timestamp to an Instant. Adapt based on actual Event structure. */ // TODO: have Event.timestamp() return Instant directly private Instant getInstantFromEvent(Event event) { - double epochSeconds = getEventTimestampEpochSeconds(event); - long seconds = (long) epochSeconds; - long nanos = (long) ((epochSeconds % 1.0) * 1_000_000_000L); - return Instant.ofEpochSecond(seconds, nanos); + return Instant.ofEpochMilli(event.timestamp()); } /** diff --git a/core/src/main/java/com/google/adk/sessions/PostgresSessionService.java b/core/src/main/java/com/google/adk/sessions/PostgresSessionService.java index c5d5c94c4..771b77d4d 100644 --- a/core/src/main/java/com/google/adk/sessions/PostgresSessionService.java +++ b/core/src/main/java/com/google/adk/sessions/PostgresSessionService.java @@ -15,6 +15,7 @@ import java.time.Instant; import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.UUID; @@ -255,7 +256,7 @@ public Single appendEvent(Session session, Event event) { // Apply state delta from event actions EventActions actions = event.actions(); if (actions != null) { - ConcurrentMap stateDelta = actions.stateDelta(); + Map stateDelta = actions.stateDelta(); if (stateDelta != null && !stateDelta.isEmpty()) { stateDelta.forEach( (key, value) -> { @@ -339,7 +340,7 @@ private void trimTempDeltaState(Event event) { if (event == null || event.actions() == null || event.actions().stateDelta() == null) { return; } - ConcurrentMap stateDelta = event.actions().stateDelta(); + Map stateDelta = event.actions().stateDelta(); stateDelta.entrySet().removeIf(entry -> entry.getKey().startsWith(State.TEMP_PREFIX)); } diff --git a/core/src/main/java/com/google/adk/sessions/Session.java b/core/src/main/java/com/google/adk/sessions/Session.java index f8376589a..94504fd96 100644 --- a/core/src/main/java/com/google/adk/sessions/Session.java +++ b/core/src/main/java/com/google/adk/sessions/Session.java @@ -27,8 +27,8 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; /** A {@link Session} object that encapsulates the {@link State} and {@link Event}s of a session. */ @JsonDeserialize(builder = Session.Builder.class) @@ -101,7 +101,7 @@ public Builder state(State state) { @CanIgnoreReturnValue @JsonProperty("state") - public Builder state(ConcurrentMap state) { + public Builder state(Map state) { this.state = new State(state); return this; } @@ -162,7 +162,7 @@ public String id() { } @JsonProperty("state") - public ConcurrentMap state() { + public Map state() { return state; } diff --git a/core/src/main/java/com/google/adk/telemetry/README.md b/core/src/main/java/com/google/adk/telemetry/README.md new file mode 100644 index 000000000..8665b3352 --- /dev/null +++ b/core/src/main/java/com/google/adk/telemetry/README.md @@ -0,0 +1,156 @@ +# ADK Telemetry and Tracing + +This package contains classes for capturing and reporting telemetry data within +the ADK, primarily for tracing agent execution leveraging OpenTelemetry. + +## Overview + +The `Tracing` utility class provides methods to trace various aspects of an +agent's execution, including: + +* Agent invocations +* LLM requests and responses +* Tool calls and responses + +These traces can be exported and visualized in telemetry backends like Google +Cloud Trace or Zipkin, or viewed through the ADK Dev Server UI, providing +observability into agent behavior. + +## How Tracing is Used + +Tracing is deeply integrated into the ADK's RxJava-based asynchronous workflows. + +### Agent Invocations + +Every agent's `runAsync` or `runLive` execution is wrapped in a span named +`invoke_agent `. The top-level agent invocation initiated by +`Runner.runAsync` or `Runner.runLive` is captured in a span named `invocation`. +Agent-specific metadata like name and description are added as span attributes, +following OpenTelemetry semantic conventions (e.g., `gen_ai.agent.name`). + +### LLM Calls + +Calls to Large Language Models (LLMs) are traced within a `call_llm` span. The +`traceCallLlm` method attaches detailed attributes to this span, including: + +* The LLM request (excluding large data like images) and response. +* Model name (`gen_ai.request.model`). +* Token usage (`gen_ai.usage.input_tokens`, `gen_ai.usage.output_tokens`). +* Configuration parameters (`gen_ai.request.top_p`, + `gen_ai.request.max_tokens`). +* Response finish reason (`gen_ai.response.finish_reasons`). + +### Tool Calls and Responses + +Tool executions triggered by the LLM are traced using `tool_call []` +and `tool_response []` spans. + +* `traceToolCall` records tool arguments in the + `gcp.vertex.agent.tool_call_args` attribute. +* `traceToolResponse` records tool output in the + `gcp.vertex.agent.tool_response` attribute. +* If multiple tools are called in parallel, a single `tool_response` span may + be created for the merged result. + +### Context Propagation + +ADK is built on RxJava and heavily uses asynchronous processing, which means +that work is often handed off between different threads. For tracing to work +correctly in such an environment, it's crucial that the active span's context +is propagated across these thread boundaries. If context is not propagated, +new spans may be orphaned or attached to the wrong parent, making traces +difficult to interpret. + +OpenTelemetry stores the currently active span in a thread-local variable. +When an asynchronous operation switches threads, this thread-local context is +lost. To solve this, ADK's `Tracing` class provides functionality to capture +the context on one thread and restore it on another when an asynchronous +operation resumes. This ensures that spans created on different threads are +correctly parented under the same trace. + +The primary mechanism for this is the `Tracing.withContext(context)` method, +which returns an RxJava transformer. When applied to an RxJava stream via +`.compose()`, this transformer ensures that the provided `Context` (containing +the parent span) is re-activated before any `onNext`, `onError`, `onComplete`, +or `onSuccess` signals are propagated downstream. It achieves this by wrapping +the downstream observer with a `TracingObserver`, which uses +`context.makeCurrent()` in a try-with-resources block around each callback, +guaranteeing that the correct span is active when downstream operators execute, +regardless of the thread. + +### RxJava Integration + +ADK integrates OpenTelemetry with RxJava streams to simplify span creation and +ensure context propagation: + +* **Span Creation**: The `Tracing.trace(spanName)` method returns an RxJava + transformer that can be applied to a `Flowable`, `Single`, `Maybe`, or + `Completable` using `.compose()`. This transformer wraps the stream's + execution in a new OpenTelemetry span. +* **Context Propagation**: The `Tracing.withContext(context)` transformer is + used with `.compose()` to ensure that the correct OpenTelemetry `Context` + (and thus the correct parent span) is active when stream operators or + subscriptions are executed, even across thread boundaries. + +## Trace Hierarchy Example + +A typical agent interaction might produce a trace hierarchy like the following: + +``` +invocation +└── invoke_agent my_agent + ├── call_llm + │ ├── tool_call [search_flights] + │ └── tool_response [search_flights] + └── call_llm +``` + +This shows: + +1. The overall `invocation` started by the `Runner`. +2. The invocation of `my_agent`. +3. The first `call_llm` made by `my_agent`. +4. A `tool_call` to `search_flights` and its corresponding `tool_response`. +5. A second `call_llm` made by `my_agent` to generate the final user response. + +### Nested Agents + +ADK supports nested agents, where one agent invokes another. If an agent has +sub-agents, it can transfer control to one of them using the built-in +`transfer_to_agent` tool. When `AgentA` calls `transfer_to_agent` to transfer +control to `AgentB`, the `invoke_agent AgentB` span will appear as a child of +the `invoke_agent AgentA` span, like so: + +``` +invocation +└── invoke_agent AgentA + ├── call_llm + │ ├── tool_call [transfer_to_agent] + │ └── tool_response [transfer_to_agent] + └── invoke_agent AgentB + ├── call_llm + └── ... +``` + +This structure allows you to see how `AgentA` delegated work to `AgentB`. + +## Span Creation References + +The following classes are the primary places where spans are created: + +* **`com.google.adk.runner.Runner`**: Initiates the top-level `invocation` + span for `runAsync` and `runLive`. +* **`com.google.adk.agents.BaseAgent`**: Creates the `invoke_agent + ` span for each agent execution. +* **`com.google.adk.flows.llmflows.BaseLlmFlow`**: Creates the `call_llm` span + when the LLM is invoked. +* **`com.google.adk.flows.llmflows.Functions`**: Creates `tool_call [...]` and + `tool_response [...]` spans when handling tool calls and responses. + +## Configuration + +**ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS**: This environment variable controls +whether LLM request/response content and tool arguments/responses are captured +in span attributes. It defaults to `true`. Set to `false` to exclude potentially +large or sensitive data from traces, in which case a `{}` JSON object will be +recorded instead. diff --git a/core/src/main/java/com/google/adk/tools/ExampleTool.java b/core/src/main/java/com/google/adk/tools/ExampleTool.java index d08481532..d03c2e4f1 100644 --- a/core/src/main/java/com/google/adk/tools/ExampleTool.java +++ b/core/src/main/java/com/google/adk/tools/ExampleTool.java @@ -85,7 +85,9 @@ public Completable processLlmRequest( return Completable.complete(); } - llmRequestBuilder.appendInstructions(ImmutableList.of(examplesBlock)); + if (!examplesBlock.isEmpty()) { + llmRequestBuilder.appendInstructions(ImmutableList.of(examplesBlock)); + } // Delegate to BaseTool to keep any declaration bookkeeping (none for this tool) return super.processLlmRequest(llmRequestBuilder, toolContext); } diff --git a/core/src/main/java/com/google/adk/tools/SetModelResponseTool.java b/core/src/main/java/com/google/adk/tools/SetModelResponseTool.java new file mode 100644 index 000000000..3b0e411b4 --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/SetModelResponseTool.java @@ -0,0 +1,68 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools; + +import com.google.adk.SchemaUtils; +import com.google.genai.types.FunctionDeclaration; +import com.google.genai.types.Schema; +import io.reactivex.rxjava3.core.Single; +import java.util.Map; +import java.util.Optional; +import javax.annotation.Nonnull; + +/** + * Internal tool used for output schema workaround. + * + *

This tool allows the model to set its final response when output_schema is configured + * alongside other tools. The model should use this tool to provide its final structured response + * instead of outputting text directly. + */ +public class SetModelResponseTool extends BaseTool { + public static final String NAME = "set_model_response"; + + private final Schema outputSchema; + + public SetModelResponseTool(@Nonnull Schema outputSchema) { + super( + NAME, + "Set your final response using the required output schema. " + + "After using any other tools needed to complete the task, always call" + + " set_model_response with your final answer in the specified schema format."); + this.outputSchema = outputSchema; + } + + @Override + public Optional declaration() { + return Optional.of( + FunctionDeclaration.builder() + .name(name()) + .description(description()) + .parameters(outputSchema) + .build()); + } + + @Override + public Single> runAsync(Map args, ToolContext toolContext) { + // This tool is a marker for the final response, it doesn't do anything but return its arguments + // which will be captured as the final result. + return Single.fromCallable( + () -> { + SchemaUtils.validateMapOnSchema(args, outputSchema, /* isInput= */ false); + return args; + }); + } +} diff --git a/core/src/main/java/com/google/adk/tools/mcp/McpToolset.java b/core/src/main/java/com/google/adk/tools/mcp/McpToolset.java index 207243ceb..4cafb9681 100644 --- a/core/src/main/java/com/google/adk/tools/mcp/McpToolset.java +++ b/core/src/main/java/com/google/adk/tools/mcp/McpToolset.java @@ -24,6 +24,8 @@ import com.google.adk.agents.ReadonlyContext; import com.google.adk.tools.BaseTool; import com.google.adk.tools.BaseToolset; +import com.google.adk.tools.ToolPredicate; +import com.google.common.collect.ImmutableList; import com.google.common.primitives.Booleans; import io.modelcontextprotocol.client.McpSyncClient; import io.modelcontextprotocol.client.transport.ServerParameters; @@ -32,6 +34,7 @@ import java.util.List; import java.util.Objects; import java.util.Optional; +import org.jspecify.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -51,7 +54,7 @@ public class McpToolset implements BaseToolset { private final McpSessionManager mcpSessionManager; private McpSyncClient mcpSession; private final ObjectMapper objectMapper; - private final Optional toolFilter; + private final @Nullable Object toolFilter; private static final int MAX_RETRIES = 3; private static final long RETRY_DELAY_MILLIS = 100; @@ -62,17 +65,29 @@ public class McpToolset implements BaseToolset { * * @param connectionParams The SSE connection parameters to the MCP server. * @param objectMapper An ObjectMapper instance for parsing schemas. - * @param toolFilter An Optional containing either a ToolPredicate or a List of tool names. + * @param toolPredicate A {@link ToolPredicate} */ public McpToolset( SseServerParameters connectionParams, ObjectMapper objectMapper, - Optional toolFilter) { - Objects.requireNonNull(connectionParams); - Objects.requireNonNull(objectMapper); - this.objectMapper = objectMapper; - this.mcpSessionManager = new McpSessionManager(connectionParams); - this.toolFilter = toolFilter; + ToolPredicate toolPredicate) { + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = Objects.requireNonNull(toolPredicate); + } + + /** + * Initializes the McpToolset with SSE server parameters. + * + * @param connectionParams The SSE connection parameters to the MCP server. + * @param objectMapper An ObjectMapper instance for parsing schemas. + * @param toolNames A list of tool names + */ + public McpToolset( + SseServerParameters connectionParams, ObjectMapper objectMapper, List toolNames) { + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = ImmutableList.copyOf(toolNames); } /** @@ -82,7 +97,9 @@ public McpToolset( * @param objectMapper An ObjectMapper instance for parsing schemas. */ public McpToolset(SseServerParameters connectionParams, ObjectMapper objectMapper) { - this(connectionParams, objectMapper, Optional.empty()); + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = null; } /** @@ -90,36 +107,39 @@ public McpToolset(SseServerParameters connectionParams, ObjectMapper objectMappe * * @param connectionParams The local server connection parameters to the MCP server. * @param objectMapper An ObjectMapper instance for parsing schemas. - * @param toolFilter An Optional containing either a ToolPredicate or a List of tool names. + * @param toolPredicate A {@link ToolPredicate} */ public McpToolset( - ServerParameters connectionParams, ObjectMapper objectMapper, Optional toolFilter) { - Objects.requireNonNull(connectionParams); - Objects.requireNonNull(objectMapper); - this.objectMapper = objectMapper; - this.mcpSessionManager = new McpSessionManager(connectionParams); - this.toolFilter = toolFilter; + ServerParameters connectionParams, ObjectMapper objectMapper, ToolPredicate toolPredicate) { + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = Objects.requireNonNull(toolPredicate); } /** - * Initializes the McpToolset with local server parameters and no tool filter. + * Initializes the McpToolset with local server parameters. * * @param connectionParams The local server connection parameters to the MCP server. * @param objectMapper An ObjectMapper instance for parsing schemas. + * @param toolNames A list of tool names */ - public McpToolset(ServerParameters connectionParams, ObjectMapper objectMapper) { - this(connectionParams, objectMapper, Optional.empty()); + public McpToolset( + ServerParameters connectionParams, ObjectMapper objectMapper, List toolNames) { + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = ImmutableList.copyOf(toolNames); } /** - * Initializes the McpToolset with SSE server parameters, using the ObjectMapper used across the - * ADK. + * Initializes the McpToolset with local server parameters and no tool filter. * - * @param connectionParams The SSE connection parameters to the MCP server. - * @param toolFilter An Optional containing either a ToolPredicate or a List of tool names. + * @param connectionParams The local server connection parameters to the MCP server. + * @param objectMapper An ObjectMapper instance for parsing schemas. */ - public McpToolset(SseServerParameters connectionParams, Optional toolFilter) { - this(connectionParams, JsonBaseModel.getMapper(), toolFilter); + public McpToolset(ServerParameters connectionParams, ObjectMapper objectMapper) { + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = null; } /** @@ -129,28 +149,31 @@ public McpToolset(SseServerParameters connectionParams, Optional toolFil * @param connectionParams The SSE connection parameters to the MCP server. */ public McpToolset(SseServerParameters connectionParams) { - this(connectionParams, JsonBaseModel.getMapper(), Optional.empty()); + this(connectionParams, JsonBaseModel.getMapper()); } /** * Initializes the McpToolset with local server parameters, using the ObjectMapper used across the - * ADK. + * ADK and no tool filter. * * @param connectionParams The local server connection parameters to the MCP server. - * @param toolFilter An Optional containing either a ToolPredicate or a List of tool names. */ - public McpToolset(ServerParameters connectionParams, Optional toolFilter) { - this(connectionParams, JsonBaseModel.getMapper(), toolFilter); + public McpToolset(ServerParameters connectionParams) { + this(connectionParams, JsonBaseModel.getMapper()); } /** - * Initializes the McpToolset with local server parameters, using the ObjectMapper used across the - * ADK and no tool filter. + * Initializes the McpToolset with an McpSessionManager. * - * @param connectionParams The local server connection parameters to the MCP server. + * @param mcpSessionManager A McpSessionManager instance for testing. + * @param objectMapper An ObjectMapper instance for parsing schemas. + * @param toolPredicate A {@link ToolPredicate} */ - public McpToolset(ServerParameters connectionParams) { - this(connectionParams, JsonBaseModel.getMapper(), Optional.empty()); + public McpToolset( + McpSessionManager mcpSessionManager, ObjectMapper objectMapper, ToolPredicate toolPredicate) { + this.mcpSessionManager = Objects.requireNonNull(mcpSessionManager); + this.objectMapper = Objects.requireNonNull(objectMapper); + this.toolFilter = Objects.requireNonNull(toolPredicate); } /** @@ -158,33 +181,69 @@ public McpToolset(ServerParameters connectionParams) { * * @param mcpSessionManager A McpSessionManager instance for testing. * @param objectMapper An ObjectMapper instance for parsing schemas. - * @param toolFilter An Optional containing either a ToolPredicate or a List of tool names. + * @param toolNames A list of tool names */ public McpToolset( - McpSessionManager mcpSessionManager, ObjectMapper objectMapper, Optional toolFilter) { - Objects.requireNonNull(mcpSessionManager); - Objects.requireNonNull(objectMapper); - this.mcpSessionManager = mcpSessionManager; - this.objectMapper = objectMapper; - this.toolFilter = toolFilter; + McpSessionManager mcpSessionManager, ObjectMapper objectMapper, List toolNames) { + this.mcpSessionManager = Objects.requireNonNull(mcpSessionManager); + this.objectMapper = Objects.requireNonNull(objectMapper); + this.toolFilter = ImmutableList.copyOf(toolNames); + } + + /** + * Initializes the McpToolset with an McpSessionManager and no tool filter. + * + * @param mcpSessionManager A McpSessionManager instance for testing. + * @param objectMapper An ObjectMapper instance for parsing schemas. + */ + public McpToolset(McpSessionManager mcpSessionManager, ObjectMapper objectMapper) { + this.mcpSessionManager = Objects.requireNonNull(mcpSessionManager); + this.objectMapper = Objects.requireNonNull(objectMapper); + this.toolFilter = null; } /** - * Initializes the McpToolset with Steamable HTTP server parameters. + * Initializes the McpToolset with Streamable HTTP server parameters. * * @param connectionParams The Streamable HTTP connection parameters to the MCP server. * @param objectMapper An ObjectMapper instance for parsing schemas. - * @param toolFilter An Optional containing either a ToolPredicate or a List of tool names. + * @param toolPredicate A {@link ToolPredicate} */ public McpToolset( StreamableHttpServerParameters connectionParams, ObjectMapper objectMapper, - Optional toolFilter) { - Objects.requireNonNull(connectionParams); - Objects.requireNonNull(objectMapper); - this.objectMapper = objectMapper; - this.mcpSessionManager = new McpSessionManager(connectionParams); - this.toolFilter = toolFilter; + ToolPredicate toolPredicate) { + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = Objects.requireNonNull(toolPredicate); + } + + /** + * Initializes the McpToolset with Streamable HTTP server parameters. + * + * @param connectionParams The Streamable HTTP connection parameters to the MCP server. + * @param objectMapper An ObjectMapper instance for parsing schemas. + * @param toolNames A list of tool names + */ + public McpToolset( + StreamableHttpServerParameters connectionParams, + ObjectMapper objectMapper, + List toolNames) { + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = ImmutableList.copyOf(toolNames); + } + + /** + * Initializes the McpToolset with Streamable HTTP server parameters and no tool filter. + * + * @param connectionParams The Streamable HTTP connection parameters to the MCP server. + * @param objectMapper An ObjectMapper instance for parsing schemas. + */ + public McpToolset(StreamableHttpServerParameters connectionParams, ObjectMapper objectMapper) { + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = null; } /** @@ -194,7 +253,7 @@ public McpToolset( * @param connectionParams The Streamable HTTP connection parameters to the MCP server. */ public McpToolset(StreamableHttpServerParameters connectionParams) { - this(connectionParams, JsonBaseModel.getMapper(), Optional.empty()); + this(connectionParams, JsonBaseModel.getMapper()); } @Override @@ -215,8 +274,7 @@ public Flowable getTools(ReadonlyContext readonlyContext) { tool -> new McpTool( tool, this.mcpSession, this.mcpSessionManager, this.objectMapper)) - .filter( - tool -> isToolSelected(tool, toolFilter.orElse(null), readonlyContext))); + .filter(tool -> isToolSelected(tool, toolFilter, readonlyContext))); }) .retryWhen( errorObservable -> @@ -357,16 +415,18 @@ public static McpToolset fromConfig(BaseTool.ToolConfig config, String configAbs + " for McpToolset"); } - // Convert tool filter to Optional - Optional toolFilter = Optional.ofNullable(mcpToolsetConfig.toolFilter()); - + List toolNames = mcpToolsetConfig.toolFilter(); Object connectionParameters = Optional.ofNullable(mcpToolsetConfig.stdioConnectionParams()) .or(() -> Optional.ofNullable(mcpToolsetConfig.sseServerParams())) .orElse(mcpToolsetConfig.stdioConnectionParams()); // Create McpToolset with McpSessionManager having appropriate connection parameters - return new McpToolset(new McpSessionManager(connectionParameters), mapper, toolFilter); + if (toolNames != null) { + return new McpToolset(new McpSessionManager(connectionParameters), mapper, toolNames); + } else { + return new McpToolset(new McpSessionManager(connectionParameters), mapper); + } } catch (IllegalArgumentException e) { throw new ConfigurationException("Failed to parse McpToolsetConfig from ToolArgsConfig", e); } diff --git a/core/src/main/java/com/google/adk/utils/ModelNameUtils.java b/core/src/main/java/com/google/adk/utils/ModelNameUtils.java index c46f6e3a8..cf0f2221e 100644 --- a/core/src/main/java/com/google/adk/utils/ModelNameUtils.java +++ b/core/src/main/java/com/google/adk/utils/ModelNameUtils.java @@ -21,6 +21,7 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; +/** Utility class for model names. */ public final class ModelNameUtils { private static final String GEMINI_PREFIX = "gemini-"; private static final Pattern GEMINI_2_PATTERN = Pattern.compile("^gemini-2\\..*"); @@ -35,11 +36,15 @@ public static boolean isGeminiModel(String modelString) { } public static boolean isGemini2Model(String modelString) { + return matchesModelPattern(modelString, GEMINI_2_PATTERN); + } + + private static boolean matchesModelPattern(String modelString, Pattern pattern) { if (modelString == null) { return false; } String modelName = extractModelName(modelString); - return GEMINI_2_PATTERN.matcher(modelName).matches(); + return pattern.matcher(modelName).matches(); } /** @@ -65,6 +70,17 @@ public static boolean isInstanceOfGemini(Object o) { return false; } + /** + * Returns true if the model supports using output schema together with tools. + * + * @param modelString The model name or path. + * @return true if output schema with tools is supported, false otherwise. + */ + public static boolean canUseOutputSchemaWithTools(String modelString) { + // Current limitation for Vertex AI 2.x models. + return !isGemini2Model(modelString); + } + /** * Extract the actual model name from either simple or path-based format. * diff --git a/core/src/test/java/com/google/adk/VersionTest.java b/core/src/test/java/com/google/adk/VersionTest.java index ff7939165..4b6f55c9b 100644 --- a/core/src/test/java/com/google/adk/VersionTest.java +++ b/core/src/test/java/com/google/adk/VersionTest.java @@ -18,6 +18,7 @@ import static com.google.common.truth.Truth.assertThat; +import java.util.regex.Pattern; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -25,6 +26,11 @@ @RunWith(JUnit4.class) public class VersionTest { + // from semver.org + private static final Pattern SEM_VER = + Pattern.compile( + "^(0|[1-9]\\d*)\\.(0|[1-9]\\d*)\\.(0|[1-9]\\d*)(?:-((?:0|[1-9]\\d*|\\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\\.(?:0|[1-9]\\d*|\\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\\+([0-9a-zA-Z-]+(?:\\.[0-9a-zA-Z-]+)*))?$"); + @Test public void versionShouldMatchProjectVersion() { assertThat(Version.JAVA_ADK_VERSION).isNotNull(); @@ -32,6 +38,11 @@ public void versionShouldMatchProjectVersion() { assertThat(Version.JAVA_ADK_VERSION).isNotEqualTo("unknown"); assertThat(Version.JAVA_ADK_VERSION).isNotEqualTo("${project.version}"); - assertThat(Version.JAVA_ADK_VERSION).matches("\\d+\\.\\d+\\.\\d+(-SNAPSHOT)?"); + assertThat(Version.JAVA_ADK_VERSION).matches("\\d+\\.\\d+\\.\\d+(-SNAPSHOT|-rc\\.\\d+)?"); + } + + @Test + public void versionShouldFollowSemanticVersioning() { + assertThat(Version.JAVA_ADK_VERSION).matches(SEM_VER); } } diff --git a/core/src/test/java/com/google/adk/agents/AgentWithMemoryTest.java b/core/src/test/java/com/google/adk/agents/AgentWithMemoryTest.java index d5edfc876..361c5eb6b 100644 --- a/core/src/test/java/com/google/adk/agents/AgentWithMemoryTest.java +++ b/core/src/test/java/com/google/adk/agents/AgentWithMemoryTest.java @@ -41,7 +41,7 @@ public final class AgentWithMemoryTest { @Test public void agentRemembersUserNameWithMemoryTool() throws Exception { String userId = "test-user"; - String agentName = "test-agent"; + String agentName = "test_agent"; Part functionCall = Part.builder() @@ -101,7 +101,7 @@ public void agentRemembersUserNameWithMemoryTool() throws Exception { Session updatedSession = runner .sessionService() - .getSession("test-agent", userId, sessionId, Optional.empty()) + .getSession("test_agent", userId, sessionId, Optional.empty()) .blockingGet(); // Save the updated session to memory so we can bring it up on the next request. diff --git a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java index 594e47fd8..a9e7a6f8d 100644 --- a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java @@ -26,7 +26,6 @@ import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertThrows; import com.google.adk.agents.Callbacks.AfterModelCallback; import com.google.adk.agents.Callbacks.AfterToolCallback; @@ -35,6 +34,7 @@ import com.google.adk.agents.Callbacks.OnModelErrorCallback; import com.google.adk.agents.Callbacks.OnToolErrorCallback; import com.google.adk.events.Event; +import com.google.adk.examples.Example; import com.google.adk.models.LlmRegistry; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; @@ -46,13 +46,14 @@ import com.google.adk.testing.TestUtils.EchoTool; import com.google.adk.tools.BaseTool; import com.google.adk.tools.BaseToolset; +import com.google.adk.tools.ExampleTool; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.genai.types.Content; -import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.Part; import com.google.genai.types.Schema; +import com.google.genai.types.Type; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; @@ -61,7 +62,6 @@ import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; import java.util.List; -import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; import org.junit.After; @@ -211,75 +211,6 @@ public void run_withToolsAndMaxSteps_stopsAfterMaxSteps() { assertEqualIgnoringFunctionIds(events.get(3).content().get(), expectedFunctionResponseContent); } - @Test - public void build_withOutputSchemaAndTools_throwsIllegalArgumentException() { - BaseTool tool = - new BaseTool("test_tool", "test_description") { - @Override - public Optional declaration() { - return Optional.empty(); - } - }; - - Schema outputSchema = - Schema.builder() - .type("OBJECT") - .properties(ImmutableMap.of("status", Schema.builder().type("STRING").build())) - .required(ImmutableList.of("status")) - .build(); - - // Expecting an IllegalArgumentException when building the agent - IllegalArgumentException exception = - assertThrows( - IllegalArgumentException.class, - () -> - LlmAgent.builder() // Use the agent builder directly - .name("agent with invalid tool config") - .outputSchema(outputSchema) // Set the output schema - .tools(ImmutableList.of(tool)) // Set tools (this should cause the error) - .build()); // Attempt to build the agent - - assertThat(exception) - .hasMessageThat() - .contains( - "Invalid config for agent agent with invalid tool config: if outputSchema is set, tools" - + " must be empty"); - } - - @Test - public void build_withOutputSchemaAndSubAgents_throwsIllegalArgumentException() { - ImmutableList subAgents = - ImmutableList.of( - createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) - .name("test_sub_agent") - .description("test_sub_agent_description") - .build()); - - Schema outputSchema = - Schema.builder() - .type("OBJECT") - .properties(ImmutableMap.of("status", Schema.builder().type("STRING").build())) - .required(ImmutableList.of("status")) - .build(); - - // Expecting an IllegalArgumentException when building the agent - IllegalArgumentException exception = - assertThrows( - IllegalArgumentException.class, - () -> - LlmAgent.builder() // Use the agent builder directly - .name("agent with invalid tool config") - .outputSchema(outputSchema) // Set the output schema - .subAgents(subAgents) // Set subAgents (this should cause the error) - .build()); // Attempt to build the agent - - assertThat(exception) - .hasMessageThat() - .contains( - "Invalid config for agent agent with invalid tool config: if outputSchema is set," - + " subAgents must be empty to disable agent transfer."); - } - @Test public void testBuild_withNullInstruction_setsInstructionToEmptyString() { LlmAgent agent = @@ -572,8 +503,13 @@ public void runAsync_withTools_createsToolSpans() throws InterruptedException { String agentSpanId = agentSpan.getSpanContext().getSpanId(); llmSpans.forEach(s -> assertEquals(agentSpanId, s.getParentSpanContext().getSpanId())); - toolCallSpans.forEach(s -> assertEquals(agentSpanId, s.getParentSpanContext().getSpanId())); - toolResponseSpans.forEach(s -> assertEquals(agentSpanId, s.getParentSpanContext().getSpanId())); + + // The tool calls and responses are children of the first LLM call that produced the function + // call. + String firstLlmSpanId = llmSpans.get(0).getSpanContext().getSpanId(); + toolCallSpans.forEach(s -> assertEquals(firstLlmSpanId, s.getParentSpanContext().getSpanId())); + toolResponseSpans.forEach( + s -> assertEquals(firstLlmSpanId, s.getParentSpanContext().getSpanId())); } @Test @@ -638,6 +574,31 @@ public void runAsync_withSubAgents_createsSpans() throws InterruptedException { assertThat(llmSpans).hasSize(2); // One for main agent, one for sub agent } + @Test + public void run_outputSchemaWithTools_allowed() { + Schema personShema = + Schema.builder() + .type(Type.Known.OBJECT) + .properties( + ImmutableMap.of( + "name", Schema.builder().type(Type.Known.STRING).build(), + "age", Schema.builder().type(Type.Known.INTEGER).build(), + "city", Schema.builder().type(Type.Known.STRING).build())) + .build(); + LlmAgent agent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .outputSchema(personShema) + .tools(new EchoTool()) + .build(); + assertThat(agent.outputSchema()).hasValue(personShema); + assertThat( + agent + .canonicalTools(new ReadonlyContext(createInvocationContext(agent))) + .count() + .blockingGet()) + .isEqualTo(1); + } + private List findSpansByName(List spans, String name) { return spans.stream().filter(s -> s.getName().equals(name)).toList(); } @@ -649,4 +610,30 @@ private SpanData findSpanByName(List spans, String name) { .findFirst() .orElseThrow(() -> new AssertionError("Span not found: " + name)); } + + @Test + public void run_withExampleTool_doesNotAddFunctionDeclarations() { + ExampleTool tool = + ExampleTool.builder() + .addExample( + Example.builder() + .input(Content.fromParts(Part.fromText("qin"))) + .output(ImmutableList.of(Content.fromParts(Part.fromText("qout")))) + .build()) + .build(); + + Content modelContent = Content.fromParts(Part.fromText("Real LLM response")); + TestLlm testLlm = createTestLlm(createLlmResponse(modelContent)); + LlmAgent agent = createTestAgentBuilder(testLlm).tools(tool).build(); + InvocationContext invocationContext = createInvocationContext(agent); + + var unused = agent.runAsync(invocationContext).toList().blockingGet(); + + assertThat(testLlm.getRequests()).hasSize(1); + LlmRequest request = testLlm.getRequests().get(0); + + assertThat(request.config().isPresent()).isTrue(); + var config = request.config().get(); + assertThat(config.tools().isPresent()).isFalse(); + } } diff --git a/core/src/test/java/com/google/adk/artifacts/CassandraArtifactServiceIT.java b/core/src/test/java/com/google/adk/artifacts/CassandraArtifactServiceIT.java index 55080c2f0..54faffbc6 100644 --- a/core/src/test/java/com/google/adk/artifacts/CassandraArtifactServiceIT.java +++ b/core/src/test/java/com/google/adk/artifacts/CassandraArtifactServiceIT.java @@ -25,7 +25,6 @@ import io.reactivex.rxjava3.core.Maybe; import java.net.InetSocketAddress; import java.util.List; -import java.util.Optional; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; @@ -79,9 +78,7 @@ public void testSaveAndLoadArtifact() { assertThat(version).isEqualTo(0); Part loadedArtifact = - artifactService - .loadArtifact(appName, userId, sessionId, filename, Optional.of(version)) - .blockingGet(); + artifactService.loadArtifact(appName, userId, sessionId, filename, version).blockingGet(); assertThat(loadedArtifact.text().get()).isEqualTo("hello world"); } @@ -100,7 +97,7 @@ public void testDeleteArtifact() { artifactService.deleteArtifact(appName, userId, sessionId, filename).blockingAwait(); Maybe loadedArtifact = - artifactService.loadArtifact(appName, userId, sessionId, filename, Optional.of(version)); + artifactService.loadArtifact(appName, userId, sessionId, filename, version); assertThat(loadedArtifact.blockingGet()).isNull(); } @@ -135,9 +132,7 @@ public void testSaveAndLoadBinaryArtifact() { assertThat(version).isEqualTo(0); Part loadedBinaryArtifact = - artifactService - .loadArtifact(appName, userId, sessionId, filename, Optional.of(version)) - .blockingGet(); + artifactService.loadArtifact(appName, userId, sessionId, filename, version).blockingGet(); assertThat(loadedBinaryArtifact.inlineData().get().data().get()).isEqualTo(binaryData); assertThat(loadedBinaryArtifact.inlineData().get().mimeType().get()) .isEqualTo("application/octet-stream"); diff --git a/core/src/test/java/com/google/adk/artifacts/CassandraArtifactServiceTest.java b/core/src/test/java/com/google/adk/artifacts/CassandraArtifactServiceTest.java index 8c4f9ca1b..78b528fe1 100644 --- a/core/src/test/java/com/google/adk/artifacts/CassandraArtifactServiceTest.java +++ b/core/src/test/java/com/google/adk/artifacts/CassandraArtifactServiceTest.java @@ -32,7 +32,6 @@ import java.nio.ByteBuffer; import java.util.Collections; import java.util.List; -import java.util.Optional; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -108,9 +107,7 @@ public void testSaveAndLoadArtifact() throws Exception { when(mockObjectMapper.readValue(artifactData, Part.class)).thenReturn(artifact); Part loadedArtifact = - artifactService - .loadArtifact(appName, userId, sessionId, filename, Optional.of(version)) - .blockingGet(); + artifactService.loadArtifact(appName, userId, sessionId, filename, version).blockingGet(); assertThat(loadedArtifact).isEqualTo(artifact); } diff --git a/core/src/test/java/com/google/adk/artifacts/PostgresArtifactServiceIT.java b/core/src/test/java/com/google/adk/artifacts/PostgresArtifactServiceIT.java index ef600fd9e..70ecc3d9f 100644 --- a/core/src/test/java/com/google/adk/artifacts/PostgresArtifactServiceIT.java +++ b/core/src/test/java/com/google/adk/artifacts/PostgresArtifactServiceIT.java @@ -21,7 +21,6 @@ import com.google.genai.types.Part; import java.util.List; -import java.util.Optional; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; @@ -122,7 +121,7 @@ public void testSaveAndLoadArtifact_Success() { // Act - Load Part loadedArtifact = artifactService - .loadArtifact(testAppName, testUserId, testSessionId, filename, Optional.empty()) + .loadArtifact(testAppName, testUserId, testSessionId, filename, null) .blockingGet(); // Assert - Content matches @@ -161,15 +160,15 @@ public void testVersioning_MultipleVersions() { // Act - Load specific versions Part loaded0 = artifactService - .loadArtifact(testAppName, testUserId, testSessionId, filename, Optional.of(0)) + .loadArtifact(testAppName, testUserId, testSessionId, filename, 0) .blockingGet(); Part loaded1 = artifactService - .loadArtifact(testAppName, testUserId, testSessionId, filename, Optional.of(1)) + .loadArtifact(testAppName, testUserId, testSessionId, filename, 1) .blockingGet(); Part loaded2 = artifactService - .loadArtifact(testAppName, testUserId, testSessionId, filename, Optional.of(2)) + .loadArtifact(testAppName, testUserId, testSessionId, filename, 2) .blockingGet(); // Assert - Each version contains correct content @@ -180,7 +179,7 @@ public void testVersioning_MultipleVersions() { // Act - Load latest (should be version 2) Part loadedLatest = artifactService - .loadArtifact(testAppName, testUserId, testSessionId, filename, Optional.empty()) + .loadArtifact(testAppName, testUserId, testSessionId, filename, null) .blockingGet(); // Assert - Latest is version 2 @@ -224,7 +223,7 @@ public void testDeleteArtifact() { // Verify it exists Part beforeDelete = artifactService - .loadArtifact(testAppName, testUserId, testSessionId, filename, Optional.empty()) + .loadArtifact(testAppName, testUserId, testSessionId, filename, null) .blockingGet(); assertThat(beforeDelete).isNotNull(); @@ -236,7 +235,7 @@ public void testDeleteArtifact() { // Assert - No longer exists Part afterDelete = artifactService - .loadArtifact(testAppName, testUserId, testSessionId, filename, Optional.empty()) + .loadArtifact(testAppName, testUserId, testSessionId, filename, null) .blockingGet(); assertThat(afterDelete).isNull(); } @@ -284,13 +283,9 @@ public void testMultiTenancy_AppNameIsolation() { // Act - Load from each app Part fromApp1 = - artifactService - .loadArtifact(app1, userId, sessionId, filename, Optional.empty()) - .blockingGet(); + artifactService.loadArtifact(app1, userId, sessionId, filename, null).blockingGet(); Part fromApp2 = - artifactService - .loadArtifact(app2, userId, sessionId, filename, Optional.empty()) - .blockingGet(); + artifactService.loadArtifact(app2, userId, sessionId, filename, null).blockingGet(); // Assert - Content is isolated assertThat(fromApp1.text()).isEqualTo("App1 content"); @@ -320,13 +315,9 @@ public void testMultiTenancy_UserIdIsolation() { // Act - Load from each user Part fromUser1 = - artifactService - .loadArtifact(appName, user1, sessionId, filename, Optional.empty()) - .blockingGet(); + artifactService.loadArtifact(appName, user1, sessionId, filename, null).blockingGet(); Part fromUser2 = - artifactService - .loadArtifact(appName, user2, sessionId, filename, Optional.empty()) - .blockingGet(); + artifactService.loadArtifact(appName, user2, sessionId, filename, null).blockingGet(); // Assert - Content is isolated assertThat(fromUser1.text()).isEqualTo("User1 content"); @@ -356,13 +347,9 @@ public void testMultiTenancy_SessionIdIsolation() { // Act - Load from each session Part fromSession1 = - artifactService - .loadArtifact(appName, userId, session1, filename, Optional.empty()) - .blockingGet(); + artifactService.loadArtifact(appName, userId, session1, filename, null).blockingGet(); Part fromSession2 = - artifactService - .loadArtifact(appName, userId, session2, filename, Optional.empty()) - .blockingGet(); + artifactService.loadArtifact(appName, userId, session2, filename, null).blockingGet(); // Assert - Content is isolated assertThat(fromSession1.text()).isEqualTo("Session1 content"); @@ -398,8 +385,7 @@ public void testLoadArtifact_NonExistent() { // Act Part result = artifactService - .loadArtifact( - testAppName, testUserId, testSessionId, "nonexistent.txt", Optional.empty()) + .loadArtifact(testAppName, testUserId, testSessionId, "nonexistent.txt", null) .blockingGet(); // Assert @@ -417,7 +403,7 @@ public void testLoadArtifact_NonExistentVersion() { // Act - Try to load non-existent version 99 Part result = artifactService - .loadArtifact(testAppName, testUserId, testSessionId, filename, Optional.of(99)) + .loadArtifact(testAppName, testUserId, testSessionId, filename, 99) .blockingGet(); // Assert diff --git a/core/src/test/java/com/google/adk/artifacts/PostgresArtifactServiceTest.java b/core/src/test/java/com/google/adk/artifacts/PostgresArtifactServiceTest.java index 509f72d9f..53e6d1e20 100644 --- a/core/src/test/java/com/google/adk/artifacts/PostgresArtifactServiceTest.java +++ b/core/src/test/java/com/google/adk/artifacts/PostgresArtifactServiceTest.java @@ -33,7 +33,6 @@ import java.sql.Timestamp; import java.util.Arrays; import java.util.List; -import java.util.Optional; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -176,9 +175,7 @@ public void testLoadArtifact_LatestVersion() throws Exception { // Act Part loadedArtifact = - artifactService - .loadArtifact(appName, userId, sessionId, filename, Optional.empty()) - .blockingGet(); + artifactService.loadArtifact(appName, userId, sessionId, filename, null).blockingGet(); // Assert assertThat(loadedArtifact).isNotNull(); @@ -207,9 +204,7 @@ public void testLoadArtifact_SpecificVersion() throws Exception { // Act Part loadedArtifact = - artifactService - .loadArtifact(appName, userId, sessionId, filename, Optional.of(version)) - .blockingGet(); + artifactService.loadArtifact(appName, userId, sessionId, filename, version).blockingGet(); // Assert assertThat(loadedArtifact).isNotNull(); @@ -231,9 +226,7 @@ public void testLoadArtifact_NotFound() throws Exception { // Act Part loadedArtifact = - artifactService - .loadArtifact(appName, userId, sessionId, filename, Optional.empty()) - .blockingGet(); + artifactService.loadArtifact(appName, userId, sessionId, filename, null).blockingGet(); // Assert assertThat(loadedArtifact).isNull(); diff --git a/core/src/test/java/com/google/adk/artifacts/RedisArtifactServiceIT.java b/core/src/test/java/com/google/adk/artifacts/RedisArtifactServiceIT.java index 92e1772bf..cdd3a0f3b 100644 --- a/core/src/test/java/com/google/adk/artifacts/RedisArtifactServiceIT.java +++ b/core/src/test/java/com/google/adk/artifacts/RedisArtifactServiceIT.java @@ -20,7 +20,6 @@ import com.google.adk.store.RedisHelper; import com.google.genai.types.Part; -import java.util.Optional; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; @@ -64,9 +63,7 @@ public void testSaveAndLoadArtifact() { assertThat(version).isEqualTo(0); Part loadedArtifact = - artifactService - .loadArtifact(appName, userId, sessionId, filename, Optional.of(version)) - .blockingGet(); + artifactService.loadArtifact(appName, userId, sessionId, filename, version).blockingGet(); assertThat(loadedArtifact.text().get()).isEqualTo("hello world"); } @@ -84,9 +81,7 @@ public void testSaveAndLoadBinaryArtifact() { assertThat(version).isEqualTo(0); Part loadedArtifact = - artifactService - .loadArtifact(appName, userId, sessionId, filename, Optional.of(version)) - .blockingGet(); + artifactService.loadArtifact(appName, userId, sessionId, filename, version).blockingGet(); assertThat(loadedArtifact.inlineData().get().data().get()).isEqualTo(binaryData); assertThat(loadedArtifact.inlineData().get().mimeType().get()) .isEqualTo("application/octet-stream"); diff --git a/core/src/test/java/com/google/adk/events/EventActionsTest.java b/core/src/test/java/com/google/adk/events/EventActionsTest.java index 22bb94e64..b1e645e1a 100644 --- a/core/src/test/java/com/google/adk/events/EventActionsTest.java +++ b/core/src/test/java/com/google/adk/events/EventActionsTest.java @@ -177,17 +177,6 @@ public void merge_failsOnMismatchedKeyTypesNestedInStateDelta() { IllegalArgumentException.class, () -> eventActions1.toBuilder().merge(eventActions2)); } - @Test - public void setRequestedToolConfirmations_withConcurrentMap_usesSameInstance() { - ConcurrentHashMap map = new ConcurrentHashMap<>(); - map.put("tool", TOOL_CONFIRMATION); - - EventActions actions = new EventActions(); - actions.setRequestedToolConfirmations(map); - - assertThat(actions.requestedToolConfirmations()).isSameInstanceAs(map); - } - @Test public void setRequestedToolConfirmations_withRegularMap_createsConcurrentMap() { ImmutableMap map = ImmutableMap.of("tool", TOOL_CONFIRMATION); diff --git a/core/src/test/java/com/google/adk/examples/ExampleUtilsTest.java b/core/src/test/java/com/google/adk/examples/ExampleUtilsTest.java index 4a1dcf8e3..2d22ed3f1 100644 --- a/core/src/test/java/com/google/adk/examples/ExampleUtilsTest.java +++ b/core/src/test/java/com/google/adk/examples/ExampleUtilsTest.java @@ -49,17 +49,7 @@ public List getExamples(String query) { @Test public void buildFewShotFewShot_noExamples() { TestExampleProvider exampleProvider = new TestExampleProvider(ImmutableList.of()); - String expected = - """ - - Begin few-shot - The following are examples of user queries and model responses using the available tools. - - End few-shot - Now, try to follow these examples and complete the following conversation - \ - """; - assertThat(ExampleUtils.buildExampleSi(exampleProvider, "test query")).isEqualTo(expected); + assertThat(ExampleUtils.buildExampleSi(exampleProvider, "test query")).isEmpty(); } @Test diff --git a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java index 4a0b345c6..6cae6c88a 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java @@ -43,9 +43,13 @@ import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.ContextKey; +import io.opentelemetry.context.Scope; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.schedulers.Schedulers; import java.util.List; import java.util.Map; import java.util.Optional; @@ -572,6 +576,71 @@ public Single> runAsync(Map args, ToolContex } } + @Test + public void run_contextPropagation() { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.current().with(testKey, "test-value"); + + Content content = Content.fromParts(Part.fromText("LLM response")); + TestLlm testLlm = createTestLlm(createLlmResponse(content)); + + RequestProcessor requestProcessor = + (ctx, request) -> { + return Single.just(RequestProcessingResult.create(request, ImmutableList.of())) + .subscribeOn(Schedulers.computation()); + }; + + ResponseProcessor responseProcessor = + (ctx, response) -> { + return Single.just(ResponseProcessingResult.create(response, ImmutableList.of())) + .subscribeOn(Schedulers.computation()); + }; + + Callbacks.BeforeModelCallback beforeCallback = + (ctx, req) -> { + return Maybe.empty().subscribeOn(Schedulers.computation()); + }; + + Callbacks.AfterModelCallback afterCallback = + (ctx, resp) -> { + return Maybe.just(resp).subscribeOn(Schedulers.computation()); + }; + + Callbacks.OnModelErrorCallback onErrorCallback = + (ctx, req, err) -> { + return Maybe.just( + LlmResponse.builder().content(Content.fromParts(Part.fromText("error"))).build()) + .subscribeOn(Schedulers.computation()); + }; + + InvocationContext invocationContext = + createInvocationContext( + createTestAgentBuilder(testLlm) + .beforeModelCallback(beforeCallback) + .afterModelCallback(afterCallback) + .onModelErrorCallback(onErrorCallback) + .build()); + + BaseLlmFlow baseLlmFlow = + createBaseLlmFlow(ImmutableList.of(requestProcessor), ImmutableList.of(responseProcessor)); + + List events; + try (Scope scope = testContext.makeCurrent()) { + events = + baseLlmFlow + .run(invocationContext) + .doOnNext( + event -> { + assertThat(Context.current().get(testKey)).isEqualTo("test-value"); + }) + .toList() + .blockingGet(); + } + + assertThat(events).hasSize(1); + assertThat(events.get(0).content()).hasValue(content); + } + @Test public void postprocess_noResponseProcessors_onlyUsageMetadata_returnsEvent() { GenerateContentResponseUsageMetadata usageMetadata = @@ -588,7 +657,12 @@ public void postprocess_noResponseProcessors_onlyUsageMetadata_returnsEvent() { List events = baseLlmFlow - .postprocess(invocationContext, baseEvent, LlmRequest.builder().build(), llmResponse) + .postprocess( + invocationContext, + baseEvent, + LlmRequest.builder().build(), + llmResponse, + Context.current()) .toList() .blockingGet(); diff --git a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java index 7164991f3..1e6267dde 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java @@ -36,13 +36,15 @@ import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; -import java.util.HashMap; +import java.util.ArrayList; +import java.util.ConcurrentModificationException; +import java.util.Iterator; import java.util.List; +import java.util.ListIterator; import java.util.Map; import java.util.Objects; import java.util.Optional; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Stream; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -787,15 +789,49 @@ public void processRequest_notEmptyContent() { public void processRequest_concurrentReadAndWrite_noException() throws Exception { LlmAgent agent = LlmAgent.builder().name(AGENT).includeContents(LlmAgent.IncludeContents.DEFAULT).build(); + List customEvents = + new ArrayList() { + private void checkLock() { + if (!Thread.holdsLock(this)) { + throw new ConcurrentModificationException("Unsynchronized iteration detected!"); + } + } + + @Override + public Iterator iterator() { + checkLock(); + return super.iterator(); + } + + @Override + public ListIterator listIterator() { + checkLock(); + return super.listIterator(); + } + + @Override + public ListIterator listIterator(int index) { + checkLock(); + return super.listIterator(index); + } + + @Override + public Stream stream() { + checkLock(); + return super.stream(); + } + }; + Session session = - sessionService - .createSession("test-app", "test-user", new HashMap<>(), "test-session") - .blockingGet(); + Session.builder("test-session") + .appName("test-app") + .userId("test-user") + .events(customEvents) + .build(); - // Seed with dummy events to widen the race capability - for (int i = 0; i < 5000; i++) { - session.events().add(createUserEvent("dummy" + i, "dummy")); - } + // The list must have at least one element so that operations interacting with events trigger + // iteration. + customEvents.add(createUserEvent("dummy", "dummy")); InvocationContext context = InvocationContext.builder() @@ -807,37 +843,8 @@ public void processRequest_concurrentReadAndWrite_noException() throws Exception LlmRequest initialRequest = LlmRequest.builder().build(); - AtomicReference writerError = new AtomicReference<>(); - CountDownLatch startLatch = new CountDownLatch(1); - - Thread writerThread = - new Thread( - () -> { - startLatch.countDown(); - try { - for (int i = 0; i < 2000; i++) { - session.events().add(createUserEvent("writer" + i, "new data")); - } - } catch (Throwable t) { - writerError.set(t); - } - }); - - writerThread.start(); - startLatch.await(); // wait for writer to be ready - - // Process (read) requests concurrently to trigger race conditions - for (int i = 0; i < 200; i++) { - var unused = contentsProcessor.processRequest(context, initialRequest).blockingGet(); - if (writerError.get() != null) { - throw new RuntimeException("Writer failed", writerError.get()); - } - } - - writerThread.join(); - if (writerError.get() != null) { - throw new RuntimeException("Writer failed", writerError.get()); - } + // This single call will throw the exception if the list is accessed insecurely. + var unused = contentsProcessor.processRequest(context, initialRequest).blockingGet(); } private static Event createUserEvent(String id, String text) { diff --git a/core/src/test/java/com/google/adk/flows/llmflows/ExamplesTest.java b/core/src/test/java/com/google/adk/flows/llmflows/ExamplesTest.java deleted file mode 100644 index 7d1615dc2..000000000 --- a/core/src/test/java/com/google/adk/flows/llmflows/ExamplesTest.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Copyright 2026 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.adk.flows.llmflows; - -import static com.google.common.truth.Truth.assertThat; - -import com.google.adk.agents.InvocationContext; -import com.google.adk.agents.LlmAgent; -import com.google.adk.agents.RunConfig; -import com.google.adk.examples.BaseExampleProvider; -import com.google.adk.examples.Example; -import com.google.adk.models.LlmRequest; -import com.google.adk.sessions.InMemorySessionService; -import com.google.adk.sessions.Session; -import com.google.common.collect.ImmutableList; -import com.google.genai.types.Content; -import com.google.genai.types.Part; -import java.util.List; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -@RunWith(JUnit4.class) -public final class ExamplesTest { - - private static final InMemorySessionService sessionService = new InMemorySessionService(); - - private static class TestExampleProvider implements BaseExampleProvider { - @Override - public List getExamples(String query) { - return ImmutableList.of( - Example.builder() - .input(Content.fromParts(Part.fromText("input1"))) - .output( - ImmutableList.of( - Content.builder().parts(Part.fromText("output1")).role("model").build())) - .build()); - } - } - - @Test - public void processRequest_withExampleProvider_addsExamplesToInstructions() { - LlmAgent agent = - LlmAgent.builder().name("test-agent").exampleProvider(new TestExampleProvider()).build(); - InvocationContext context = - InvocationContext.builder() - .invocationId("invocation1") - .session(Session.builder("session1").build()) - .sessionService(sessionService) - .agent(agent) - .userContent(Content.fromParts(Part.fromText("what is up?"))) - .runConfig(RunConfig.builder().build()) - .build(); - LlmRequest request = LlmRequest.builder().build(); - Examples examplesProcessor = new Examples(); - - RequestProcessor.RequestProcessingResult result = - examplesProcessor.processRequest(context, request).blockingGet(); - - assertThat(result.updatedRequest().getSystemInstructions()).isNotEmpty(); - assertThat(result.updatedRequest().getSystemInstructions().get(0)) - .contains("[user]\ninput1\n\n[model]\noutput1\n"); - } - - @Test - public void processRequest_withoutExampleProvider_doesNotAddExamplesToInstructions() { - LlmAgent agent = LlmAgent.builder().name("test-agent").build(); - InvocationContext context = - InvocationContext.builder() - .invocationId("invocation1") - .session(Session.builder("session1").build()) - .sessionService(sessionService) - .agent(agent) - .userContent(Content.fromParts(Part.fromText("what is up?"))) - .runConfig(RunConfig.builder().build()) - .build(); - LlmRequest request = LlmRequest.builder().build(); - Examples examplesProcessor = new Examples(); - - RequestProcessor.RequestProcessingResult result = - examplesProcessor.processRequest(context, request).blockingGet(); - - assertThat(result.updatedRequest().getSystemInstructions()).isEmpty(); - } -} diff --git a/core/src/test/java/com/google/adk/flows/llmflows/OutputSchemaTest.java b/core/src/test/java/com/google/adk/flows/llmflows/OutputSchemaTest.java new file mode 100644 index 000000000..ffd56de6c --- /dev/null +++ b/core/src/test/java/com/google/adk/flows/llmflows/OutputSchemaTest.java @@ -0,0 +1,192 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.flows.llmflows; + +import static com.google.adk.testing.TestUtils.createInvocationContext; +import static com.google.adk.testing.TestUtils.createTestLlm; +import static com.google.common.truth.Truth.assertThat; + +import com.google.adk.agents.InvocationContext; +import com.google.adk.agents.LlmAgent; +import com.google.adk.events.Event; +import com.google.adk.flows.llmflows.RequestProcessor.RequestProcessingResult; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.adk.testing.TestLlm; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.SetModelResponseTool; +import com.google.adk.tools.ToolContext; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.Content; +import com.google.genai.types.FunctionResponse; +import com.google.genai.types.Part; +import com.google.genai.types.Schema; +import io.reactivex.rxjava3.core.Single; +import java.util.Map; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class OutputSchemaTest { + + private static final Schema TEST_OUTPUT_SCHEMA = + Schema.builder() + .type("OBJECT") + .properties(ImmutableMap.of("field1", Schema.builder().type("STRING").build())) + .required(ImmutableList.of("field1")) + .build(); + + private OutputSchema outputSchemaProcessor; + private TestLlm testLlm; + private LlmRequest initialRequest; + + @Before + public void setUp() { + outputSchemaProcessor = new OutputSchema(); + testLlm = createTestLlm(LlmResponse.builder().build()); + initialRequest = LlmRequest.builder().model("gemini-2.0-pro").build(); + } + + public static class TestTool extends BaseTool { + public TestTool() { + super("test_tool", "test description"); + } + + @Override + public Single> runAsync(Map args, ToolContext toolContext) { + return Single.just(ImmutableMap.of()); + } + } + + @Test + public void processRequest_noOutputSchema_doesNothing() { + LlmAgent agent = + LlmAgent.builder() + .name("agent") + .model(testLlm) + .tools(ImmutableList.of(new TestTool())) + .build(); + InvocationContext context = createInvocationContext(agent); + + RequestProcessingResult result = + outputSchemaProcessor.processRequest(context, initialRequest).blockingGet(); + + assertThat(result.updatedRequest()).isEqualTo(initialRequest); + assertThat(result.events()).isEmpty(); + } + + @Test + public void processRequest_noTools_doesNothing() { + LlmAgent agent = + LlmAgent.builder().name("agent").model(testLlm).outputSchema(TEST_OUTPUT_SCHEMA).build(); + InvocationContext context = createInvocationContext(agent); + + RequestProcessingResult result = + outputSchemaProcessor.processRequest(context, initialRequest).blockingGet(); + + assertThat(result.updatedRequest()).isEqualTo(initialRequest); + assertThat(result.events()).isEmpty(); + } + + @Test + public void processRequest_withOutputSchemaAndTools_addsSetModelResponseTool() { + LlmAgent agent = + LlmAgent.builder() + .name("agent") + .model(testLlm) + .outputSchema(TEST_OUTPUT_SCHEMA) + .tools(ImmutableList.of(new TestTool())) + .build(); + InvocationContext context = createInvocationContext(agent); + LlmRequest requestWithTools = + LlmRequest.builder() + .model("gemini-2.5-pro") + .tools(ImmutableMap.of("test_tool", new TestTool())) + .build(); + + RequestProcessingResult result = + outputSchemaProcessor.processRequest(context, requestWithTools).blockingGet(); + + LlmRequest updatedRequest = result.updatedRequest(); + assertThat(updatedRequest.tools()).hasSize(2); + assertThat( + updatedRequest.tools().values().stream() + .anyMatch(t -> t instanceof SetModelResponseTool)) + .isTrue(); + assertThat(updatedRequest.tools().values().stream().anyMatch(t -> t.name().equals("test_tool"))) + .isTrue(); + assertThat(updatedRequest.getSystemInstructions()).isNotEmpty(); + assertThat(updatedRequest.getSystemInstructions().get(0)) + .contains("you must provide your final response using the set_model_response tool"); + assertThat(result.events()).isEmpty(); + } + + @Test + public void getStructuredModelResponse_withSetModelResponse_returnsJson() { + FunctionResponse fr = + FunctionResponse.builder() + .name(SetModelResponseTool.NAME) + .response(ImmutableMap.of("field1", "value1")) + .build(); + Event event = + Event.builder() + .content( + Content.builder() + .parts(Part.builder().functionResponse(fr).build()) + .role("model") + .build()) + .build(); + + assertThat(OutputSchema.getStructuredModelResponse(event)).hasValue("{\"field1\":\"value1\"}"); + } + + @Test + public void getStructuredModelResponse_withoutSetModelResponse_returnsEmpty() { + FunctionResponse fr = + FunctionResponse.builder() + .name("other_tool") + .response(ImmutableMap.of("field1", "value1")) + .build(); + Event event = + Event.builder() + .content( + Content.builder() + .parts(Part.builder().functionResponse(fr).build()) + .role("model") + .build()) + .build(); + + assertThat(OutputSchema.getStructuredModelResponse(event)).isEmpty(); + } + + @Test + public void createFinalModelResponseEvent_createsModelResponseEvent() { + LlmAgent agent = LlmAgent.builder().name("agent").model(testLlm).build(); + InvocationContext context = createInvocationContext(agent); + String jsonResponse = "{\"field1\":\"value1\"}"; + + Event event = OutputSchema.createFinalModelResponseEvent(context, jsonResponse); + + assertThat(event.invocationId()).isEqualTo(context.invocationId()); + assertThat(event.author()).isEqualTo("agent"); + assertThat(event.content().get().role()).hasValue("model"); + assertThat(event.content().get().parts().get()).containsExactly(Part.fromText(jsonResponse)); + } +} diff --git a/core/src/test/java/com/google/adk/models/GeminiTest.java b/core/src/test/java/com/google/adk/models/GeminiTest.java index 07dd675e5..c230f5f68 100644 --- a/core/src/test/java/com/google/adk/models/GeminiTest.java +++ b/core/src/test/java/com/google/adk/models/GeminiTest.java @@ -22,6 +22,7 @@ import com.google.genai.types.Content; import com.google.genai.types.FinishReason; import com.google.genai.types.GenerateContentResponse; +import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.functions.Predicate; @@ -123,6 +124,76 @@ public void processRawResponses_textThenEmpty_emitsPartialTextThenFullTextAndEmp isEmptyResponse()); } + @Test + public void processRawResponses_withTextChunks_partialResponsesIncludeUsageMetadata() { + GenerateContentResponseUsageMetadata metadata1 = createUsageMetadata(5, 10, 15); + GenerateContentResponseUsageMetadata metadata2 = createUsageMetadata(5, 20, 25); + Flowable rawResponses = + Flowable.just( + toResponseWithText("Hello", metadata1), toResponseWithText(" world", metadata2)); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isPartialTextResponseWithUsageMetadata("Hello", metadata1), + isPartialTextResponseWithUsageMetadata(" world", metadata2)); + } + + @Test + public void processRawResponses_textAndStopReason_finalResponseIncludesUsageMetadata() { + GenerateContentResponseUsageMetadata metadata = createUsageMetadata(10, 20, 30); + Flowable rawResponses = + Flowable.just( + toResponseWithText("Hello"), + toResponseWithText(" world", FinishReason.Known.STOP, metadata)); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isPartialTextResponse("Hello"), + isPartialTextResponseWithUsageMetadata(" world", metadata), + isFinalTextResponseWithUsageMetadata("Hello world", metadata)); + } + + @Test + public void processRawResponses_thoughtChunksAndStop_includeUsageMetadata() { + GenerateContentResponseUsageMetadata metadata1 = createUsageMetadata(5, 10, 15); + GenerateContentResponseUsageMetadata metadata2 = createUsageMetadata(5, 20, 25); + Flowable rawResponses = + Flowable.just( + toResponseWithThoughtText("Thinking", metadata1), + toResponseWithThoughtText(" deeply", FinishReason.Known.STOP, metadata2)); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isPartialThoughtResponseWithUsageMetadata("Thinking", metadata1), + isPartialThoughtResponseWithUsageMetadata(" deeply", metadata2), + isFinalThoughtResponseWithUsageMetadata("Thinking deeply", metadata2)); + } + + @Test + public void processRawResponses_thoughtAndTextWithStop_onlyFinalTextIncludesUsageMetadata() { + GenerateContentResponseUsageMetadata metadata1 = createUsageMetadata(5, 5, 10); + GenerateContentResponseUsageMetadata metadata2 = createUsageMetadata(10, 20, 30); + Flowable rawResponses = + Flowable.just( + toResponseWithThoughtText("Thinking", metadata1), + toResponseWithText("Answer", FinishReason.Known.STOP, metadata2)); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isPartialThoughtResponseWithUsageMetadata("Thinking", metadata1), + isPartialTextResponseWithUsageMetadata("Answer", metadata2), + isFinalThoughtResponseWithNoUsageMetadata("Thinking"), + isFinalTextResponseWithUsageMetadata("Answer", metadata2)); + } + // Helper methods for assertions private void assertLlmResponses( @@ -170,6 +241,67 @@ private static Predicate isEmptyResponse() { }; } + private static Predicate isPartialTextResponseWithUsageMetadata( + String expectedText, GenerateContentResponseUsageMetadata expectedMetadata) { + return response -> { + assertThat(response.partial()).hasValue(true); + assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse("")) + .isEqualTo(expectedText); + assertThat(response.usageMetadata()).hasValue(expectedMetadata); + return true; + }; + } + + private static Predicate isPartialThoughtResponseWithUsageMetadata( + String expectedText, GenerateContentResponseUsageMetadata expectedMetadata) { + return response -> { + assertThat(response.partial()).hasValue(true); + assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse("")) + .isEqualTo(expectedText); + assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::thought).orElse(false)) + .isTrue(); + assertThat(response.usageMetadata()).hasValue(expectedMetadata); + return true; + }; + } + + private static Predicate isFinalTextResponseWithUsageMetadata( + String expectedText, GenerateContentResponseUsageMetadata expectedMetadata) { + return response -> { + assertThat(response.partial()).isEmpty(); + assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse("")) + .isEqualTo(expectedText); + assertThat(response.usageMetadata()).hasValue(expectedMetadata); + return true; + }; + } + + private static Predicate isFinalThoughtResponseWithUsageMetadata( + String expectedText, GenerateContentResponseUsageMetadata expectedMetadata) { + return response -> { + assertThat(response.partial()).isEmpty(); + assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse("")) + .isEqualTo(expectedText); + assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::thought).orElse(false)) + .isTrue(); + assertThat(response.usageMetadata()).hasValue(expectedMetadata); + return true; + }; + } + + private static Predicate isFinalThoughtResponseWithNoUsageMetadata( + String expectedText) { + return response -> { + assertThat(response.partial()).isEmpty(); + assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse("")) + .isEqualTo(expectedText); + assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::thought).orElse(false)) + .isTrue(); + assertThat(response.usageMetadata()).isEmpty(); + return true; + }; + } + // Helper methods to create responses for testing private GenerateContentResponse toResponseWithText(String text) { @@ -191,4 +323,63 @@ private GenerateContentResponse toResponse(Part part) { private GenerateContentResponse toResponse(Candidate candidate) { return GenerateContentResponse.builder().candidates(candidate).build(); } + + private GenerateContentResponse toResponseWithText( + String text, GenerateContentResponseUsageMetadata usageMetadata) { + return GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content(Content.builder().parts(Part.fromText(text)).build()) + .build()) + .usageMetadata(usageMetadata) + .build(); + } + + private GenerateContentResponse toResponseWithText( + String text, + FinishReason.Known finishReason, + GenerateContentResponseUsageMetadata usageMetadata) { + return GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content(Content.builder().parts(Part.fromText(text)).build()) + .finishReason(new FinishReason(finishReason)) + .build()) + .usageMetadata(usageMetadata) + .build(); + } + + private GenerateContentResponse toResponseWithThoughtText( + String text, GenerateContentResponseUsageMetadata usageMetadata) { + Part thoughtPart = Part.fromText(text).toBuilder().thought(true).build(); + return GenerateContentResponse.builder() + .candidates( + Candidate.builder().content(Content.builder().parts(thoughtPart).build()).build()) + .usageMetadata(usageMetadata) + .build(); + } + + private GenerateContentResponse toResponseWithThoughtText( + String text, + FinishReason.Known finishReason, + GenerateContentResponseUsageMetadata usageMetadata) { + Part thoughtPart = Part.fromText(text).toBuilder().thought(true).build(); + return GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content(Content.builder().parts(thoughtPart).build()) + .finishReason(new FinishReason(finishReason)) + .build()) + .usageMetadata(usageMetadata) + .build(); + } + + private static GenerateContentResponseUsageMetadata createUsageMetadata( + int promptTokens, int candidateTokens, int totalTokens) { + return GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(promptTokens) + .candidatesTokenCount(candidateTokens) + .totalTokenCount(totalTokens) + .build(); + } } diff --git a/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java b/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java index 4ae856fc7..3771143cf 100644 --- a/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java +++ b/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java @@ -37,8 +37,12 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.ContextKey; +import io.opentelemetry.context.Scope; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.schedulers.Schedulers; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -144,6 +148,87 @@ public void onUserMessageCallback_pluginOrderRespected() { inOrder.verify(plugin2).onUserMessageCallback(mockInvocationContext, content); } + @Test + public void contextPropagation_runMaybeCallbacks() throws Exception { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.current().with(testKey, "test-value"); + + Content expectedContent = Content.builder().build(); + when(plugin1.onUserMessageCallback(any(), any())) + .thenReturn(Maybe.just(expectedContent).subscribeOn(Schedulers.computation())); + pluginManager.registerPlugin(plugin1); + + Maybe resultMaybe; + try (Scope scope = testContext.makeCurrent()) { + resultMaybe = pluginManager.onUserMessageCallback(mockInvocationContext, content); + } + + // Assert downstream operators have the propagated context + resultMaybe + .doOnSuccess( + result -> { + assertThat(Context.current().get(testKey)).isEqualTo("test-value"); + }) + .test() + .await() + .assertResult(expectedContent); + + verify(plugin1).onUserMessageCallback(mockInvocationContext, content); + } + + @Test + public void contextPropagation_afterRunCallback() throws Exception { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.current().with(testKey, "test-value"); + + when(plugin1.afterRunCallback(any())) + .thenReturn(Completable.complete().subscribeOn(Schedulers.computation())); + pluginManager.registerPlugin(plugin1); + + Completable resultCompletable; + try (Scope scope = testContext.makeCurrent()) { + resultCompletable = pluginManager.afterRunCallback(mockInvocationContext); + } + + // Assert downstream operators have the propagated context + resultCompletable + .doOnComplete( + () -> { + assertThat(Context.current().get(testKey)).isEqualTo("test-value"); + }) + .test() + .await() + .assertResult(); + + verify(plugin1).afterRunCallback(mockInvocationContext); + } + + @Test + public void contextPropagation_close() throws Exception { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.current().with(testKey, "test-value"); + + when(plugin1.close()).thenReturn(Completable.complete().subscribeOn(Schedulers.computation())); + pluginManager.registerPlugin(plugin1); + + Completable resultCompletable; + try (Scope scope = testContext.makeCurrent()) { + resultCompletable = pluginManager.close(); + } + + // Assert downstream operators have the propagated context + resultCompletable + .doOnComplete( + () -> { + assertThat(Context.current().get(testKey)).isEqualTo("test-value"); + }) + .test() + .await() + .assertResult(); + + verify(plugin1).close(); + } + @Test public void afterRunCallback_allComplete() { when(plugin1.afterRunCallback(any())).thenReturn(Completable.complete()); diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/BatchProcessorTest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/BatchProcessorTest.java new file mode 100644 index 000000000..4f4350d1a --- /dev/null +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/BatchProcessorTest.java @@ -0,0 +1,367 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.api.core.ApiFutures; +import com.google.cloud.bigquery.storage.v1.AppendRowsResponse; +import com.google.cloud.bigquery.storage.v1.RowError; +import com.google.cloud.bigquery.storage.v1.StreamWriter; +import com.google.rpc.Status; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.logging.Handler; +import java.util.logging.Level; +import java.util.logging.LogRecord; +import java.util.logging.Logger; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.TimeStampMicroTZVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public class BatchProcessorTest { + @Rule public MockitoRule mockitoRule = MockitoJUnit.rule(); + + @Mock private StreamWriter mockWriter; + private ScheduledExecutorService executor; + private BatchProcessor batchProcessor; + private Schema schema; + private Handler mockHandler; + + @Before + public void setUp() { + executor = Executors.newScheduledThreadPool(1); + batchProcessor = new BatchProcessor(mockWriter, 10, Duration.ofMinutes(1), 100, executor); + schema = BigQuerySchema.getArrowSchema(); + + when(mockWriter.append(any(ArrowRecordBatch.class))) + .thenReturn(ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance())); + + Logger logger = Logger.getLogger(BatchProcessor.class.getName()); + mockHandler = mock(Handler.class); + logger.addHandler(mockHandler); + } + + @After + public void tearDown() { + batchProcessor.close(); + executor.shutdown(); + } + + @Test + public void flush_populatesTimestampFieldCorrectly() throws Exception { + Instant now = Instant.parse("2026-03-02T19:11:49.631Z"); + Map row = new HashMap<>(); + row.put("timestamp", now); + row.put("event_type", "TEST_EVENT"); + + final boolean[] checksPassed = {false}; + final String[] failureMessage = {null}; + + when(mockWriter.append(any(ArrowRecordBatch.class))) + .thenAnswer( + invocation -> { + ArrowRecordBatch recordedBatch = invocation.getArgument(0); + try (VectorSchemaRoot root = + VectorSchemaRoot.create(schema, batchProcessor.allocator)) { + VectorLoader loader = new VectorLoader(root); + loader.load(recordedBatch); + + if (root.getRowCount() != 1) { + failureMessage[0] = "Expected 1 row, got " + root.getRowCount(); + return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance()); + } + + var timestampVector = root.getVector("timestamp"); + if (!(timestampVector instanceof TimeStampMicroTZVector tzVector)) { + failureMessage[0] = "Vector should be an instance of TimeStampMicroTZVector"; + return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance()); + } + if (tzVector.isNull(0)) { + failureMessage[0] = "Timestamp should NOT be null"; + } else if (tzVector.get(0) != now.toEpochMilli() * 1000) { + failureMessage[0] = + "Expected " + (now.toEpochMilli() * 1000) + ", got " + tzVector.get(0); + } else { + checksPassed[0] = true; + } + } catch (RuntimeException e) { + failureMessage[0] = "Exception during check: " + e.getMessage(); + } + return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance()); + }); + + batchProcessor.append(row); + batchProcessor.flush(); + + verify(mockWriter).append(any(ArrowRecordBatch.class)); + assertTrue(failureMessage[0], checksPassed[0]); + } + + @Test + public void flush_populatesAllBasicFields() throws Exception { + Map row = new HashMap<>(); + row.put("timestamp", Instant.now()); + row.put("event_type", "BASIC_EVENT"); + row.put("is_truncated", true); + + when(mockWriter.append(any(ArrowRecordBatch.class))) + .thenAnswer( + invocation -> { + ArrowRecordBatch recordedBatch = invocation.getArgument(0); + try (VectorSchemaRoot root = + VectorSchemaRoot.create(schema, batchProcessor.allocator)) { + VectorLoader loader = new VectorLoader(root); + loader.load(recordedBatch); + + assertEquals("BASIC_EVENT", root.getVector("event_type").getObject(0).toString()); + assertEquals(1, ((BitVector) root.getVector("is_truncated")).get(0)); + } + return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance()); + }); + + batchProcessor.append(row); + batchProcessor.flush(); + + verify(mockWriter).append(any(ArrowRecordBatch.class)); + } + + @Test + public void flush_populatesJsonFields() throws Exception { + Map row = new HashMap<>(); + row.put("timestamp", Instant.now()); + row.put("content", "{\"key\": \"value\"}"); + row.put("attributes", "{\"attr\": 123}"); + + when(mockWriter.append(any(ArrowRecordBatch.class))) + .thenAnswer( + invocation -> { + ArrowRecordBatch recordedBatch = invocation.getArgument(0); + try (VectorSchemaRoot root = + VectorSchemaRoot.create(schema, batchProcessor.allocator)) { + VectorLoader loader = new VectorLoader(root); + loader.load(recordedBatch); + + assertEquals( + "{\"key\": \"value\"}", root.getVector("content").getObject(0).toString()); + assertEquals( + "{\"attr\": 123}", root.getVector("attributes").getObject(0).toString()); + } + return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance()); + }); + + batchProcessor.append(row); + batchProcessor.flush(); + + verify(mockWriter).append(any(ArrowRecordBatch.class)); + } + + @Test + public void flush_populatesNestedStructs() throws Exception { + Map row = new HashMap<>(); + row.put("timestamp", Instant.now()); + + List> contentParts = new ArrayList<>(); + Map part = new HashMap<>(); + part.put("mime_type", "text/plain"); + part.put("text", "hello world"); + part.put("part_index", 0L); + contentParts.add(part); + row.put("content_parts", contentParts); + + when(mockWriter.append(any(ArrowRecordBatch.class))) + .thenAnswer( + invocation -> { + ArrowRecordBatch recordedBatch = invocation.getArgument(0); + try (VectorSchemaRoot root = + VectorSchemaRoot.create(schema, batchProcessor.allocator)) { + VectorLoader loader = new VectorLoader(root); + loader.load(recordedBatch); + + ListVector contentPartsVector = (ListVector) root.getVector("content_parts"); + StructVector structVector = (StructVector) contentPartsVector.getDataVector(); + + assertEquals(1, ((List) contentPartsVector.getObject(0)).size()); + VarCharVector mimeTypeVector = (VarCharVector) structVector.getChild("mime_type"); + assertEquals("text/plain", mimeTypeVector.getObject(0).toString()); + + VarCharVector textVector = (VarCharVector) structVector.getChild("text"); + assertEquals("hello world", textVector.getObject(0).toString()); + + BigIntVector partIndexVector = (BigIntVector) structVector.getChild("part_index"); + assertEquals(0L, partIndexVector.get(0)); + } + return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance()); + }); + + batchProcessor.append(row); + batchProcessor.flush(); + + verify(mockWriter).append(any(ArrowRecordBatch.class)); + } + + @Test + public void flush_handlesBigQueryErrorResponse() throws Exception { + Map row = new HashMap<>(); + row.put("event_type", "ERROR_EVENT"); + + AppendRowsResponse responseWithError = + AppendRowsResponse.newBuilder() + .setError(Status.newBuilder().setMessage("Global error").build()) + .addRowErrors(RowError.newBuilder().setIndex(0).setMessage("Row error").build()) + .build(); + + when(mockWriter.append(any(ArrowRecordBatch.class))) + .thenReturn(ApiFutures.immediateFuture(responseWithError)); + + batchProcessor.append(row); + batchProcessor.flush(); + + verify(mockWriter).append(any(ArrowRecordBatch.class)); + } + + @Test + public void flush_handlesGenericExceptionDuringAppend() throws Exception { + Map row = new HashMap<>(); + row.put("event_type", "EXCEPTION_EVENT"); + + when(mockWriter.append(any(ArrowRecordBatch.class))) + .thenThrow(new RuntimeException("Simulated failure")); + + batchProcessor.append(row); + batchProcessor.flush(); + + verify(mockWriter).append(any(ArrowRecordBatch.class)); + } + + @Test + public void append_triggersFlushWhenBatchSizeReached() { + ScheduledExecutorService mockExecutor = mock(ScheduledExecutorService.class); + BatchProcessor bp = new BatchProcessor(mockWriter, 2, Duration.ofMinutes(1), 10, mockExecutor); + + Map row = new HashMap<>(); + bp.append(row); + verify(mockExecutor, never()).execute(any(Runnable.class)); + + bp.append(row); + verify(mockExecutor).execute(any(Runnable.class)); + } + + @Test + public void flush_doesNothingWhenQueueIsEmpty() throws Exception { + batchProcessor.flush(); + verify(mockWriter, never()).append(any(ArrowRecordBatch.class)); + } + + @Test + public void flush_handlesNullValues() throws Exception { + Map row = new HashMap<>(); + row.put("timestamp", Instant.now()); + row.put("event_type", null); + row.put("is_truncated", null); + + final boolean[] checksPassed = {false}; + when(mockWriter.append(any(ArrowRecordBatch.class))) + .thenAnswer( + invocation -> { + ArrowRecordBatch recordedBatch = invocation.getArgument(0); + try (VectorSchemaRoot root = + VectorSchemaRoot.create(schema, batchProcessor.allocator)) { + VectorLoader loader = new VectorLoader(root); + loader.load(recordedBatch); + + assertTrue(root.getVector("event_type").isNull(0)); + assertTrue(root.getVector("is_truncated").isNull(0)); + checksPassed[0] = true; + } + return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance()); + }); + + batchProcessor.append(row); + batchProcessor.flush(); + + verify(mockWriter).append(any(ArrowRecordBatch.class)); + assertTrue("Null checks failed", checksPassed[0]); + } + + @Test + public void flush_handlesAllocationFailure() throws Exception { + Map row = new HashMap<>(); + row.put("event_type", "ALLOC_FAIL_EVENT"); + batchProcessor.append(row); + batchProcessor.allocator.setLimit(1); + + batchProcessor.flush(); + + verify(mockWriter, never()).append(any(ArrowRecordBatch.class)); + ArgumentCaptor captor = ArgumentCaptor.forClass(LogRecord.class); + verify(mockHandler, atLeastOnce()).publish(captor.capture()); + boolean foundError = false; + for (LogRecord record : captor.getAllValues()) { + if (record.getLevel().equals(Level.SEVERE) + && record.getMessage().contains("Failed to write batch to BigQuery")) { + foundError = true; + break; + } + } + assertTrue("Expected SEVERE error log not found", foundError); + } + + @Test + public void close_flushesAndClosesResources() throws Exception { + try (BatchProcessor bp = + new BatchProcessor(mockWriter, 10, Duration.ofMinutes(1), 100, executor)) { + Map row = new HashMap<>(); + row.put("event_type", "CLOSE_EVENT"); + bp.append(row); + } + + verify(mockWriter).append(any(ArrowRecordBatch.class)); + verify(mockWriter).close(); + } +} diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java new file mode 100644 index 000000000..8147c5cc6 --- /dev/null +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java @@ -0,0 +1,457 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.adk.agents.BaseAgent; +import com.google.adk.agents.CallbackContext; +import com.google.adk.agents.InvocationContext; +import com.google.adk.events.Event; +import com.google.adk.models.LlmRequest; +import com.google.adk.sessions.Session; +import com.google.api.core.ApiFutures; +import com.google.auth.Credentials; +import com.google.cloud.bigquery.BigQuery; +import com.google.cloud.bigquery.BigQueryOptions; +import com.google.cloud.bigquery.Table; +import com.google.cloud.bigquery.TableId; +import com.google.cloud.bigquery.storage.v1.AppendRowsResponse; +import com.google.cloud.bigquery.storage.v1.BigQueryWriteClient; +import com.google.cloud.bigquery.storage.v1.StreamWriter; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.SpanContext; +import io.opentelemetry.context.Scope; +import io.reactivex.rxjava3.core.Flowable; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.logging.Handler; +import java.util.logging.Level; +import java.util.logging.LogRecord; +import java.util.logging.Logger; +import org.apache.arrow.vector.TimeStampMicroTZVector; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public class BigQueryAgentAnalyticsPluginTest { + @Rule public MockitoRule mockitoRule = MockitoJUnit.rule(); + + @Mock private BigQuery mockBigQuery; + @Mock private StreamWriter mockWriter; + @Mock private BigQueryWriteClient mockWriteClient; + @Mock private InvocationContext mockInvocationContext; + private BaseAgent fakeAgent; + + private BigQueryLoggerConfig config; + private BigQueryAgentAnalyticsPlugin plugin; + private Handler mockHandler; + + @Before + public void setUp() throws Exception { + fakeAgent = new FakeAgent("agent_name"); + config = + BigQueryLoggerConfig.builder() + .setEnabled(true) + .setProjectId("project") + .setDatasetId("dataset") + .setTableName("table") + .setBatchSize(10) + .setBatchFlushInterval(Duration.ofSeconds(10)) + .setAutoSchemaUpgrade(false) + .setCredentials(mock(Credentials.class)) + .setCustomTags(ImmutableMap.of("global_tag", "global_value")) + .build(); + + when(mockBigQuery.getOptions()) + .thenReturn(BigQueryOptions.newBuilder().setProjectId("test-project").build()); + when(mockBigQuery.getTable(any(TableId.class))).thenReturn(mock(Table.class)); + when(mockWriter.append(any(ArrowRecordBatch.class))) + .thenReturn(ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance())); + + plugin = + new BigQueryAgentAnalyticsPlugin(config, mockBigQuery) { + @Override + protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { + return mockWriteClient; + } + + @Override + protected StreamWriter createWriter(BigQueryLoggerConfig config) { + return mockWriter; + } + }; + + Session session = Session.builder("session_id").build(); + when(mockInvocationContext.session()).thenReturn(session); + when(mockInvocationContext.invocationId()).thenReturn("invocation_id"); + when(mockInvocationContext.agent()).thenReturn(fakeAgent); + when(mockInvocationContext.userId()).thenReturn("user_id"); + + Logger logger = Logger.getLogger(BatchProcessor.class.getName()); + mockHandler = mock(Handler.class); + logger.addHandler(mockHandler); + } + + @Test + public void onUserMessageCallback_appendsToWriter() throws Exception { + Content content = Content.builder().build(); + + plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + plugin.batchProcessor.flush(); + + verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); + } + + @Test + public void beforeRunCallback_appendsToWriter() throws Exception { + plugin.beforeRunCallback(mockInvocationContext).blockingSubscribe(); + plugin.batchProcessor.flush(); + + verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); + } + + @Test + public void afterRunCallback_flushesAndAppends() throws Exception { + plugin.afterRunCallback(mockInvocationContext).blockingSubscribe(); + plugin.batchProcessor.flush(); + + verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); + } + + @Test + public void getStreamName_returnsCorrectFormat() { + BigQueryLoggerConfig config = + BigQueryLoggerConfig.builder() + .setProjectId("test-project") + .setDatasetId("test-dataset") + .setTableName("test-table") + .build(); + + String streamName = plugin.getStreamName(config); + + assertEquals( + "projects/test-project/datasets/test-dataset/tables/test-table/streams/_default", + streamName); + } + + @Test + public void formatContentParts_populatesCorrectFields() { + Content content = Content.fromParts(Part.fromText("hello")); + ArrayNode nodes = JsonFormatter.formatContentParts(Optional.of(content), 100); + assertEquals(1, nodes.size()); + ObjectNode node = (ObjectNode) nodes.get(0); + assertEquals(0, node.get("part_index").asInt()); + assertEquals("INLINE", node.get("storage_mode").asText()); + assertEquals("hello", node.get("text").asText()); + assertEquals("text/plain", node.get("mime_type").asText()); + } + + @Test + public void arrowSchema_hasJsonMetadata() { + Schema schema = BigQuerySchema.getArrowSchema(); + Field contentField = schema.findField("content"); + assertNotNull(contentField); + assertEquals("google:sqlType:json", contentField.getMetadata().get("ARROW:extension:name")); + } + + @Test + public void onUserMessageCallback_handlesTableCreationFailure() throws Exception { + Logger logger = Logger.getLogger(BigQueryAgentAnalyticsPlugin.class.getName()); + Handler mockHandler = mock(Handler.class); + logger.addHandler(mockHandler); + try { + when(mockBigQuery.getTable(any(TableId.class))) + .thenThrow(new RuntimeException("Table check failed")); + Content content = Content.builder().build(); + + // Should not throw exception + plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + + plugin.batchProcessor.flush(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(LogRecord.class); + verify(mockHandler, atLeastOnce()).publish(captor.capture()); + assertTrue( + captor + .getValue() + .getMessage() + .contains("Failed to check or create/upgrade BigQuery table")); + assertEquals(Level.WARNING, captor.getValue().getLevel()); + } finally { + logger.removeHandler(mockHandler); + } + } + + @Test + public void onUserMessageCallback_handlesAppendFailure() throws Exception { + when(mockWriter.append(any(ArrowRecordBatch.class))) + .thenReturn(ApiFutures.immediateFailedFuture(new RuntimeException("Append failed"))); + Content content = Content.builder().build(); + + plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + + // Flush should handle the failed future from writer.append() + plugin.batchProcessor.flush(); + + verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); + ArgumentCaptor captor = ArgumentCaptor.forClass(LogRecord.class); + verify(mockHandler, atLeastOnce()).publish(captor.capture()); + assertTrue(captor.getValue().getMessage().contains("Failed to write batch to BigQuery")); + assertEquals(Level.SEVERE, captor.getValue().getLevel()); + } + + @Test + public void ensureTableExists_calledOnlyOnce() throws Exception { + Content content = Content.builder().build(); + + // Multiple calls to logEvent via different callbacks + plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + plugin.beforeRunCallback(mockInvocationContext).blockingSubscribe(); + plugin.afterRunCallback(mockInvocationContext).blockingSubscribe(); + + // Verify getting table was only done once. Using fully qualified name to avoid ambiguity. + verify(mockBigQuery).getTable(any(TableId.class)); + } + + @Test + public void arrowSchema_handlesNestedFields() { + Schema schema = BigQuerySchema.getArrowSchema(); + Field contentPartsField = schema.findField("content_parts"); + assertNotNull(contentPartsField); + // Repeated struct becomes a List of Structs + assertTrue(contentPartsField.getType() instanceof ArrowType.List); + + Field element = contentPartsField.getChildren().get(0); + assertEquals("element", element.getName()); + + // Check object_ref which is a nested STRUCT + Field objectRef = + element.getChildren().stream() + .filter(f -> f.getName().equals("object_ref")) + .findFirst() + .orElse(null); + assertNotNull(objectRef); + assertTrue(objectRef.getType() instanceof ArrowType.Struct); + assertFalse(objectRef.getChildren().isEmpty()); + } + + @Test + public void arrowSchema_handlesFieldNullability() { + Schema schema = BigQuerySchema.getArrowSchema(); + + // timestamp is REQUIRED in BigQuerySchema.getEventsSchema() + Field timestampField = schema.findField("timestamp"); + assertNotNull(timestampField); + assertFalse(timestampField.isNullable()); + + // event_type is NULLABLE in BigQuerySchema.getEventsSchema() + Field eventTypeField = schema.findField("event_type"); + assertNotNull(eventTypeField); + assertTrue(eventTypeField.isNullable()); + } + + @Test + public void logEvent_populatesCommonFields() throws Exception { + final boolean[] checksPassed = {false}; + final String[] failureMessage = {null}; + + when(mockWriter.append(any(ArrowRecordBatch.class))) + .thenAnswer( + invocation -> { + ArrowRecordBatch recordedBatch = invocation.getArgument(0); + Schema schema = BigQuerySchema.getArrowSchema(); + try (VectorSchemaRoot root = + VectorSchemaRoot.create(schema, plugin.batchProcessor.allocator)) { + VectorLoader loader = new VectorLoader(root); + loader.load(recordedBatch); + + if (root.getRowCount() != 1) { + failureMessage[0] = "Expected 1 row, got " + root.getRowCount(); + } else if (!Objects.equals( + root.getVector("event_type").getObject(0).toString(), "USER_MESSAGE")) { + failureMessage[0] = + "Wrong event_type: " + root.getVector("event_type").getObject(0); + } else if (!root.getVector("agent").getObject(0).toString().equals("agent_name")) { + failureMessage[0] = "Wrong agent: " + root.getVector("agent").getObject(0); + } else if (!root.getVector("session_id") + .getObject(0) + .toString() + .equals("session_id")) { + failureMessage[0] = + "Wrong session_id: " + root.getVector("session_id").getObject(0); + } else if (!root.getVector("invocation_id") + .getObject(0) + .toString() + .equals("invocation_id")) { + failureMessage[0] = + "Wrong invocation_id: " + root.getVector("invocation_id").getObject(0); + } else if (!root.getVector("user_id").getObject(0).toString().equals("user_id")) { + failureMessage[0] = "Wrong user_id: " + root.getVector("user_id").getObject(0); + } else if (((TimeStampMicroTZVector) root.getVector("timestamp")).get(0) <= 0) { + failureMessage[0] = "Timestamp not populated"; + } else { + // Check content and content_parts + String contentJson = root.getVector("content").getObject(0).toString(); + if (!contentJson.contains("test message")) { + failureMessage[0] = "Wrong content: " + contentJson; + } else { + ListVector contentPartsVector = (ListVector) root.getVector("content_parts"); + if (((List) contentPartsVector.getObject(0)).isEmpty()) { + failureMessage[0] = "content_parts is empty"; + } else { + // Check attributes + String attributesJson = root.getVector("attributes").getObject(0).toString(); + if (!attributesJson.contains("global_tag") + || !attributesJson.contains("global_value")) { + failureMessage[0] = "Wrong attributes: " + attributesJson; + } else { + checksPassed[0] = true; + } + } + } + } + } catch (RuntimeException e) { + failureMessage[0] = "Exception during inspection: " + e.getMessage(); + } + return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance()); + }); + + Content content = Content.fromParts(Part.fromText("test message")); + plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + plugin.batchProcessor.flush(); + + assertTrue(failureMessage[0], checksPassed[0]); + } + + @Test + public void logEvent_populatesTraceDetails() throws Exception { + String traceId = "4bf92f3577b34da6a3ce929d0e0e4736"; + String spanId = "00f067aa0ba902b7"; + + SpanContext mockSpanContext = mock(SpanContext.class); + when(mockSpanContext.isValid()).thenReturn(true); + when(mockSpanContext.getTraceId()).thenReturn(traceId); + when(mockSpanContext.getSpanId()).thenReturn(spanId); + + Span mockSpan = Span.wrap(mockSpanContext); + + try (Scope scope = mockSpan.makeCurrent()) { + Content content = Content.builder().build(); + plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + + Map row = plugin.batchProcessor.queue.poll(); + assertNotNull("Row not found in queue", row); + assertEquals(traceId, row.get("trace_id")); + assertEquals(spanId, row.get("span_id")); + } + } + + @Test + public void complexType_appendsToWriter() throws Exception { + Part part = Part.fromText("test text"); + Content content = Content.fromParts(part); + plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + + plugin.batchProcessor.flush(); + + verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); + } + + @Test + public void onEventCallback_populatesCorrectFields() throws Exception { + Event event = + Event.builder() + .author("agent_author") + .content(Content.fromParts(Part.fromText("event content"))) + .build(); + + plugin.onEventCallback(mockInvocationContext, event).blockingSubscribe(); + + Map row = plugin.batchProcessor.queue.poll(); + assertNotNull("Row not found in queue", row); + assertEquals("EVENT", row.get("event_type")); + assertEquals("agent_name", row.get("agent")); + assertTrue(row.get("attributes").toString().contains("agent_author")); + assertTrue(row.get("content").toString().contains("event content")); + } + + @Test + public void onModelErrorCallback_populatesCorrectFields() throws Exception { + CallbackContext mockCallbackContext = mock(CallbackContext.class); + when(mockCallbackContext.invocationContext()).thenReturn(mockInvocationContext); + when(mockCallbackContext.agentName()).thenReturn("agent_in_context"); + LlmRequest.Builder mockLlmRequestBuilder = mock(LlmRequest.Builder.class); + Throwable error = new RuntimeException("model error message"); + + plugin + .onModelErrorCallback(mockCallbackContext, mockLlmRequestBuilder, error) + .blockingSubscribe(); + + Map row = plugin.batchProcessor.queue.poll(); + assertNotNull("Row not found in queue", row); + assertEquals("MODEL_ERROR", row.get("event_type")); + assertEquals("agent_in_context", row.get("agent")); + assertTrue(row.get("attributes").toString().contains("model error message")); + } + + private static class FakeAgent extends BaseAgent { + FakeAgent(String name) { + super(name, "description", null, null, null); + } + + @Override + protected Flowable runAsyncImpl(InvocationContext invocationContext) { + return Flowable.empty(); + } + + @Override + protected Flowable runLiveImpl(InvocationContext invocationContext) { + return Flowable.empty(); + } + } +} diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index 8a0a84b08..efd565c16 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -24,6 +24,9 @@ import static com.google.adk.testing.TestUtils.createTextLlmResponse; import static com.google.adk.testing.TestUtils.simplifyEvents; import static com.google.common.truth.Truth.assertThat; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.Arrays.stream; +import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.CALLS_REAL_METHODS; import static org.mockito.Mockito.mock; @@ -31,15 +34,18 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.adk.agents.BaseAgent; import com.google.adk.agents.InvocationContext; import com.google.adk.agents.LiveRequestQueue; import com.google.adk.agents.LlmAgent; import com.google.adk.agents.RunConfig; import com.google.adk.apps.App; +import com.google.adk.artifacts.BaseArtifactService; import com.google.adk.events.Event; import com.google.adk.flows.llmflows.Functions; import com.google.adk.models.LlmResponse; import com.google.adk.plugins.BasePlugin; +import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.Session; import com.google.adk.sessions.SessionKey; import com.google.adk.summarizer.EventsCompactionConfig; @@ -57,17 +63,22 @@ import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.ContextKey; +import io.opentelemetry.context.Scope; import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; import io.opentelemetry.sdk.trace.data.SpanData; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.subscribers.TestSubscriber; import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -75,6 +86,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; @RunWith(JUnit4.class) public final class RunnerTest { @@ -842,10 +854,62 @@ public void beforeRunCallback_withStateDelta_seesMergedState() { assertThat(sessionInCallback.state()).containsEntry("number", 123); } + @Test + public void runAsync_ensureEventsAreAppendedInOrder() throws Exception { + Event event1 = TestUtils.createEvent("1"); + Event event2 = TestUtils.createEvent("2"); + BaseAgent mockAgent = TestUtils.createSubAgent("test agent", event1, event2); + + BaseSessionService mockSessionService = mock(BaseSessionService.class); + + when(mockSessionService.getSession(any(), any(), any(), any())).thenReturn(Maybe.just(session)); + when(mockSessionService.appendEvent(any(), any())) + .thenAnswer( + invocation -> { + Event eventArg = invocation.getArgument(1); + Single result = Single.just(eventArg); + if (eventArg.id().equals("1")) { + // Artificially delay the first event to ensure it is appended first. + return result.delay(100, MILLISECONDS); + } + return result; + }); + + Runner mockRunner = + Runner.builder() + .agent(mockAgent) + .appName("test") + .sessionService(mockSessionService) + .build(); + + List results = + mockRunner + .runAsync("user", session.id(), createContent("user message")) + .toList() + .blockingGet(); + + assertThat(simplifyEvents(results)) + .containsExactly("author: content for event 1", "author: content for event 2") + .inOrder(); + } + private Content createContent(String text) { return Content.builder().parts(Part.builder().text(text).build()).build(); } + private static Content createInlineDataContent(byte[]... data) { + return Content.builder() + .parts( + stream(data) + .map(dataBytes -> Part.fromBytes(dataBytes, "example/octet-stream")) + .toArray(Part[]::new)) + .build(); + } + + private static Content createInlineDataContent(String... data) { + return createInlineDataContent(stream(data).map(d -> d.getBytes(UTF_8)).toArray(byte[][]::new)); + } + @Test public void runAsync_createsInvocationSpan() { var unused = @@ -977,6 +1041,84 @@ public void runLive_createsInvocationSpan() { assertThat(invocationSpan.get().hasEnded()).isTrue(); } + @Test + public void runAsync_createsToolSpansWithCorrectParent() { + LlmAgent agentWithTool = + createTestAgentBuilder(testLlmWithFunctionCall).tools(ImmutableList.of(echoTool)).build(); + Runner runnerWithTool = + Runner.builder().app(App.builder().name("test").rootAgent(agentWithTool).build()).build(); + Session sessionWithTool = + runnerWithTool.sessionService().createSession("test", "user").blockingGet(); + + var unused = + runnerWithTool + .runAsync( + sessionWithTool.sessionKey(), + createContent("from user"), + RunConfig.builder().build()) + .toList() + .blockingGet(); + + List spans = openTelemetryRule.getSpans(); + List llmSpans = spans.stream().filter(s -> s.getName().equals("call_llm")).toList(); + List toolCallSpans = + spans.stream().filter(s -> s.getName().equals("tool_call [echo_tool]")).toList(); + List toolResponseSpans = + spans.stream().filter(s -> s.getName().equals("tool_response [echo_tool]")).toList(); + + assertThat(llmSpans).hasSize(2); + assertThat(toolCallSpans).hasSize(1); + assertThat(toolResponseSpans).hasSize(1); + + List llmSpanIds = llmSpans.stream().map(s -> s.getSpanContext().getSpanId()).toList(); + String toolCallParentId = toolCallSpans.get(0).getParentSpanContext().getSpanId(); + String toolResponseParentId = toolResponseSpans.get(0).getParentSpanContext().getSpanId(); + + assertThat(toolCallParentId).isEqualTo(toolResponseParentId); + assertThat(llmSpanIds).contains(toolCallParentId); + } + + @Test + public void runLive_createsToolSpansWithCorrectParent() throws Exception { + LlmAgent agentWithTool = + createTestAgentBuilder(testLlmWithFunctionCall).tools(ImmutableList.of(echoTool)).build(); + Runner runnerWithTool = + Runner.builder().app(App.builder().name("test").rootAgent(agentWithTool).build()).build(); + Session sessionWithTool = + runnerWithTool.sessionService().createSession("test", "user").blockingGet(); + LiveRequestQueue liveRequestQueue = new LiveRequestQueue(); + + TestSubscriber testSubscriber = + runnerWithTool + .runLive(sessionWithTool.sessionKey(), liveRequestQueue, RunConfig.builder().build()) + .test(); + + liveRequestQueue.content(createContent("from user")); + liveRequestQueue.close(); + + testSubscriber.await(); + testSubscriber.assertComplete(); + + List spans = openTelemetryRule.getSpans(); + List llmSpans = spans.stream().filter(s -> s.getName().equals("call_llm")).toList(); + List toolCallSpans = + spans.stream().filter(s -> s.getName().equals("tool_call [echo_tool]")).toList(); + List toolResponseSpans = + spans.stream().filter(s -> s.getName().equals("tool_response [echo_tool]")).toList(); + + // In runLive, there is one call_llm span for the execution + assertThat(llmSpans).hasSize(1); + assertThat(toolCallSpans).hasSize(1); + assertThat(toolResponseSpans).hasSize(1); + + List llmSpanIds = llmSpans.stream().map(s -> s.getSpanContext().getSpanId()).toList(); + String toolCallParentId = toolCallSpans.get(0).getParentSpanContext().getSpanId(); + String toolResponseParentId = toolResponseSpans.get(0).getParentSpanContext().getSpanId(); + + assertThat(toolCallParentId).isEqualTo(toolResponseParentId); + assertThat(llmSpanIds).contains(toolCallParentId); + } + @Test public void runAsync_withoutSessionAndAutoCreateSessionTrue_createsSession() { RunConfig runConfig = RunConfig.builder().setAutoCreateSession(true).build(); @@ -1188,6 +1330,53 @@ public void close_closesPluginsAndCodeExecutors() { verify(plugin).close(); } + @Test + public void runAsync_contextPropagation() { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.current().with(testKey, "test-value"); + + List events; + try (Scope scope = testContext.makeCurrent()) { + events = + runner + .runAsync("user", session.id(), createContent("test message")) + .doOnNext( + event -> { + assertThat(Context.current().get(testKey)).isEqualTo("test-value"); + }) + .toList() + .blockingGet(); + } + + assertThat(simplifyEvents(events)).containsExactly("test agent: from llm"); + } + + @Test + public void runLive_contextPropagation() throws Exception { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.current().with(testKey, "test-value"); + LiveRequestQueue liveRequestQueue = new LiveRequestQueue(); + + TestSubscriber testSubscriber; + try (Scope scope = testContext.makeCurrent()) { + testSubscriber = + runner + .runLive(session, liveRequestQueue, RunConfig.builder().build()) + .doOnNext( + event -> { + assertThat(Context.current().get(testKey)).isEqualTo("test-value"); + }) + .test(); + } + + liveRequestQueue.content(createContent("from user")); + liveRequestQueue.close(); + + testSubscriber.await(); + testSubscriber.assertComplete(); + assertThat(simplifyEvents(testSubscriber.values())).containsExactly("test agent: from llm"); + } + @Test public void buildRunnerWithPlugins_success() { BasePlugin plugin1 = mockPlugin("test1"); @@ -1203,4 +1392,40 @@ public static ImmutableMap echoTool(String message) { return ImmutableMap.of("message", message); } } + + @Test + public void runner_executesSaveArtifactFlow() { + // arrange + final AtomicInteger artifactsSavedCounter = new AtomicInteger(); + BaseArtifactService mockArtifactService = Mockito.mock(BaseArtifactService.class); + when(mockArtifactService.saveArtifact(any(), any(), any(), any(), any())) + .thenReturn( + Single.defer( + () -> { + // we want to assert not only that the saveArtifact method was + // called, but also that the flow that it returned was run, so + // we need to record the call in a counter + artifactsSavedCounter.incrementAndGet(); + return Single.just(42); + })); + Runner runner = + Runner.builder() + .app(App.builder().name("test").rootAgent(agent).build()) + .artifactService(mockArtifactService) + .build(); + session = runner.sessionService().createSession("test", "user").blockingGet(); + // each inline data will be saved using our mock artifact service + Content content = createInlineDataContent("test data", "test data 2"); + RunConfig runConfig = RunConfig.builder().setSaveInputBlobsAsArtifacts(true).build(); + + // act + var events = runner.runAsync("user", session.id(), content, runConfig).test(); + + // assert + events.assertComplete(); + // artifacts were saved + assertThat(artifactsSavedCounter.get()).isEqualTo(2); + // agent was run + assertThat(simplifyEvents(events.values())).containsExactly("test agent: from llm"); + } } diff --git a/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java b/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java index 41e156ffd..0d9235b1b 100644 --- a/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java +++ b/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java @@ -20,6 +20,7 @@ import com.google.adk.events.Event; import com.google.adk.events.EventActions; import io.reactivex.rxjava3.core.Single; +import java.time.Instant; import java.util.HashMap; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; @@ -214,6 +215,24 @@ public void appendEvent_removesState() { assertThat(retrievedSessionRemove.state()).doesNotContainKey("temp:tempKey"); } + @Test + public void appendEvent_updatesSessionTimestampWithFractionalSeconds() { + InMemorySessionService sessionService = new InMemorySessionService(); + Session session = + sessionService.createSession("app", "user", new HashMap<>(), "session1").blockingGet(); + + // Add an event with a timestamp that contains a fractional second + Event eventAdd = Event.builder().timestamp(5500).build(); + var unused = sessionService.appendEvent(session, eventAdd).blockingGet(); + + // Verify the last modified timestamp contains a fractional second + Session retrievedSession = + sessionService + .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) + .blockingGet(); + assertThat(retrievedSession.lastUpdateTime()).isEqualTo(Instant.ofEpochSecond(5, 500000000L)); + } + @Test public void sequentialAgents_shareTempState() { InMemorySessionService sessionService = new InMemorySessionService(); diff --git a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java index e5795d61f..b13904934 100644 --- a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java +++ b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java @@ -16,6 +16,7 @@ package com.google.adk.telemetry; +import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -31,11 +32,15 @@ import com.google.adk.runner.Runner; import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; -import com.google.adk.sessions.SessionKey; +import com.google.adk.testing.TestLlm; +import com.google.adk.testing.TestUtils; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.ToolContext; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; import com.google.genai.types.FinishReason; +import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.FunctionResponse; import com.google.genai.types.GenerateContentConfig; import com.google.genai.types.GenerateContentResponseUsageMetadata; @@ -54,6 +59,7 @@ import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.schedulers.Schedulers; +import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.Optional; @@ -123,10 +129,7 @@ public void testToolCallSpanLinksToParent() { parentSpanData.getSpanContext().getTraceId(), toolCallSpanData.getSpanContext().getTraceId()); - assertEquals( - "Tool call's parent should be the parent span", - parentSpanData.getSpanContext().getSpanId(), - toolCallSpanData.getParentSpanContext().getSpanId()); + assertParent(parentSpanData, toolCallSpanData); } @Test @@ -146,7 +149,7 @@ public void testToolCallWithoutParentCreatesRootSpan() { // Then: Should create root span (backward compatible) List spans = openTelemetryRule.getSpans(); - assertEquals("Should have exactly 1 span", 1, spans.size()); + assertThat(spans).hasSize(1); SpanData toolCallSpanData = spans.get(0); assertFalse( @@ -193,7 +196,7 @@ public void testNestedSpanHierarchy() { List spans = openTelemetryRule.getSpans(); // The 4 spans are: "parent", "invocation", "tool_call [testTool]", and "tool_response // [testTool]". - assertEquals("Should have 4 spans in the hierarchy", 4, spans.size()); + assertThat(spans).hasSize(4); SpanData parentSpanData = findSpanByName("parent"); String parentTraceId = parentSpanData.getSpanContext().getTraceId(); @@ -210,22 +213,13 @@ public void testNestedSpanHierarchy() { SpanData toolResponseSpanData = findSpanByName("tool_response [testTool]"); // invocation should be child of parent - assertEquals( - "Invocation should be child of parent", - parentSpanData.getSpanContext().getSpanId(), - invocationSpanData.getParentSpanContext().getSpanId()); + assertParent(parentSpanData, invocationSpanData); // tool_call should be child of invocation - assertEquals( - "Tool call should be child of invocation", - invocationSpanData.getSpanContext().getSpanId(), - toolCallSpanData.getParentSpanContext().getSpanId()); + assertParent(invocationSpanData, toolCallSpanData); // tool_response should be child of tool_call - assertEquals( - "Tool response should be child of tool call", - toolCallSpanData.getSpanContext().getSpanId(), - toolResponseSpanData.getParentSpanContext().getSpanId()); + assertParent(toolCallSpanData, toolResponseSpanData); } @Test @@ -253,7 +247,6 @@ public void testMultipleSpansInParallel() { // Verify all tool calls link to same parent SpanData parentSpanData = findSpanByName("parent"); String parentTraceId = parentSpanData.getSpanContext().getTraceId(); - String parentSpanId = parentSpanData.getSpanContext().getSpanId(); // All tool calls should have same trace ID and parent span ID List toolCallSpans = @@ -261,7 +254,7 @@ public void testMultipleSpansInParallel() { .filter(s -> s.getName().startsWith("tool_call")) .toList(); - assertEquals("Should have 3 tool call spans", 3, toolCallSpans.size()); + assertThat(toolCallSpans).hasSize(3); toolCallSpans.forEach( span -> { @@ -269,10 +262,7 @@ public void testMultipleSpansInParallel() { "Tool call should have same trace ID as parent", parentTraceId, span.getSpanContext().getTraceId()); - assertEquals( - "Tool call should have parent as parent span", - parentSpanId, - span.getParentSpanContext().getSpanId()); + assertParent(parentSpanData, span); }); } @@ -298,10 +288,7 @@ public void testInvokeAgentSpanLinksToInvocation() { SpanData invocationSpanData = findSpanByName("invocation"); SpanData invokeAgentSpanData = findSpanByName("invoke_agent test-agent"); - assertEquals( - "Agent run should be child of invocation", - invocationSpanData.getSpanContext().getSpanId(), - invokeAgentSpanData.getParentSpanContext().getSpanId()); + assertParent(invocationSpanData, invokeAgentSpanData); } @Test @@ -323,15 +310,12 @@ public void testCallLlmSpanLinksToAgentRun() { } List spans = openTelemetryRule.getSpans(); - assertEquals("Should have 2 spans", 2, spans.size()); + assertThat(spans).hasSize(2); SpanData invokeAgentSpanData = findSpanByName("invoke_agent test-agent"); SpanData callLlmSpanData = findSpanByName("call_llm"); - assertEquals( - "Call LLM should be child of agent run", - invokeAgentSpanData.getSpanContext().getSpanId(), - callLlmSpanData.getParentSpanContext().getSpanId()); + assertParent(invokeAgentSpanData, callLlmSpanData); } @Test @@ -349,10 +333,7 @@ public void testSpanCreatedWithinParentScopeIsCorrectlyParented() { SpanData parentSpanData = findSpanByName("invocation"); SpanData agentSpanData = findSpanByName("invoke_agent"); - assertEquals( - "Agent span should be a child of the invocation span", - parentSpanData.getSpanContext().getSpanId(), - agentSpanData.getParentSpanContext().getSpanId()); + assertParent(parentSpanData, agentSpanData); } @Test @@ -380,9 +361,7 @@ public void testTraceFlowable() throws InterruptedException { SpanData parentSpanData = findSpanByName("parent"); SpanData flowableSpanData = findSpanByName("flowable"); - assertEquals( - parentSpanData.getSpanContext().getSpanId(), - flowableSpanData.getParentSpanContext().getSpanId()); + assertParent(parentSpanData, flowableSpanData); assertTrue(flowableSpanData.hasEnded()); } @@ -469,9 +448,7 @@ public void testTraceTransformer() throws InterruptedException { SpanData parentSpanData = findSpanByName("parent"); SpanData transformerSpanData = findSpanByName("transformer"); - assertEquals( - parentSpanData.getSpanContext().getSpanId(), - transformerSpanData.getParentSpanContext().getSpanId()); + assertParent(parentSpanData, transformerSpanData); assertTrue(transformerSpanData.hasEnded()); } @@ -485,7 +462,7 @@ public void testTraceAgentInvocation() { span.end(); } List spans = openTelemetryRule.getSpans(); - assertEquals(1, spans.size()); + assertThat(spans).hasSize(1); SpanData spanData = spans.get(0); Attributes attrs = spanData.getAttributes(); assertEquals("invoke_agent", attrs.get(AttributeKey.stringKey("gen_ai.operation.name"))); @@ -504,7 +481,7 @@ public void testTraceToolCall() { span.end(); } List spans = openTelemetryRule.getSpans(); - assertEquals(1, spans.size()); + assertThat(spans).hasSize(1); SpanData spanData = spans.get(0); Attributes attrs = spanData.getAttributes(); assertEquals("execute_tool", attrs.get(AttributeKey.stringKey("gen_ai.operation.name"))); @@ -541,7 +518,7 @@ public void testTraceToolResponse() { span.end(); } List spans = openTelemetryRule.getSpans(); - assertEquals(1, spans.size()); + assertThat(spans).hasSize(1); SpanData spanData = spans.get(0); Attributes attrs = spanData.getAttributes(); assertEquals("execute_tool", attrs.get(AttributeKey.stringKey("gen_ai.operation.name"))); @@ -578,7 +555,7 @@ public void testTraceCallLlm() { span.end(); } List spans = openTelemetryRule.getSpans(); - assertEquals(1, spans.size()); + assertThat(spans).hasSize(1); SpanData spanData = spans.get(0); Attributes attrs = spanData.getAttributes(); assertEquals("gcp.vertex.agent", attrs.get(AttributeKey.stringKey("gen_ai.system"))); @@ -606,12 +583,12 @@ public void testTraceSendData() { Tracing.traceSendData( buildInvocationContext(), "event-1", - ImmutableList.of(Content.builder().role("user").parts(Part.fromText("hello")).build())); + ImmutableList.of(Content.fromParts(Part.fromText("hello")))); } finally { span.end(); } List spans = openTelemetryRule.getSpans(); - assertEquals(1, spans.size()); + assertThat(spans).hasSize(1); SpanData spanData = spans.get(0); Attributes attrs = spanData.getAttributes(); assertEquals( @@ -653,49 +630,35 @@ public void baseAgentRunAsync_propagatesContext() throws InterruptedException { } SpanData parent = findSpanByName("parent"); SpanData agentSpan = findSpanByName("invoke_agent test-agent"); - assertEquals(parent.getSpanContext().getSpanId(), agentSpan.getParentSpanContext().getSpanId()); + assertParent(parent, agentSpan); } @Test public void runnerRunAsync_propagatesContext() throws InterruptedException { BaseAgent agent = new TestAgent(); - Runner runner = Runner.builder().agent(agent).appName("test-app").build(); Span parentSpan = tracer.spanBuilder("parent").startSpan(); try (Scope s = parentSpan.makeCurrent()) { - Session session = - runner - .sessionService() - .createSession(new SessionKey("test-app", "test-user", "test-session")) - .blockingGet(); - Content newMessage = Content.fromParts(Part.fromText("hi")); - RunConfig runConfig = RunConfig.builder().build(); - runner - .runAsync(session.userId(), session.id(), newMessage, runConfig, null) - .test() - .await() - .assertComplete(); + runAgent(agent); } finally { parentSpan.end(); } SpanData parent = findSpanByName("parent"); SpanData invocation = findSpanByName("invocation"); SpanData agentSpan = findSpanByName("invoke_agent test-agent"); - assertEquals( - parent.getSpanContext().getSpanId(), invocation.getParentSpanContext().getSpanId()); - assertEquals( - invocation.getSpanContext().getSpanId(), agentSpan.getParentSpanContext().getSpanId()); + assertParent(parent, invocation); + assertParent(invocation, agentSpan); } @Test public void runnerRunLive_propagatesContext() throws InterruptedException { BaseAgent agent = new TestAgent(); - Runner runner = Runner.builder().agent(agent).appName("test-app").build(); + Runner runner = + Runner.builder().agent(agent).appName("test_app").sessionService(sessionService).build(); Span parentSpan = tracer.spanBuilder("parent").startSpan(); try (Scope s = parentSpan.makeCurrent()) { Session session = - runner - .sessionService() - .createSession("test-app", "test-user", (Map) null, "test-session") + sessionService + .createSession("test_app", "test-user", (Map) null, "test-session") .blockingGet(); Content newMessage = Content.fromParts(Part.fromText("hi")); RunConfig runConfig = RunConfig.builder().build(); @@ -713,10 +676,171 @@ public void runnerRunLive_propagatesContext() throws InterruptedException { SpanData parent = findSpanByName("parent"); SpanData invocation = findSpanByName("invocation"); SpanData agentSpan = findSpanByName("invoke_agent test-agent"); - assertEquals( - parent.getSpanContext().getSpanId(), invocation.getParentSpanContext().getSpanId()); - assertEquals( - invocation.getSpanContext().getSpanId(), agentSpan.getParentSpanContext().getSpanId()); + assertParent(parent, invocation); + assertParent(invocation, agentSpan); + } + + @Test + public void testAgentWithToolCallTraceHierarchy() throws InterruptedException { + // This test verifies the trace hierarchy created when an agent calls an LLM, + // which then invokes a tool. The expected hierarchy is: + // invocation + // └── invoke_agent test_agent + // ├── call_llm + // │ ├── tool_call [search_flights] + // │ └── tool_response [search_flights] + // └── call_llm + + SearchFlightsTool searchFlightsTool = new SearchFlightsTool(); + + TestLlm testLlm = + TestUtils.createTestLlm( + TestUtils.createLlmResponse( + Content.builder() + .role("model") + .parts( + Part.fromFunctionCall( + searchFlightsTool.name(), ImmutableMap.of("destination", "SFO"))) + .build()), + TestUtils.createLlmResponse(Content.fromParts(Part.fromText("done")))); + + LlmAgent agentWithTool = + LlmAgent.builder() + .name("test_agent") + .description("description") + .model(testLlm) + .tools(ImmutableList.of(searchFlightsTool)) + .build(); + + runAgent(agentWithTool); + + SpanData invocation = findSpanByName("invocation"); + SpanData invokeAgent = findSpanByName("invoke_agent test_agent"); + SpanData toolCall = findSpanByName("tool_call [search_flights]"); + SpanData toolResponse = findSpanByName("tool_response [search_flights]"); + List callLlmSpans = + openTelemetryRule.getSpans().stream() + .filter(s -> s.getName().equals("call_llm")) + .sorted(Comparator.comparing(SpanData::getStartEpochNanos)) + .toList(); + assertThat(callLlmSpans).hasSize(2); + SpanData callLlm1 = callLlmSpans.get(0); + SpanData callLlm2 = callLlmSpans.get(1); + + // Assert hierarchy: + // invocation + // └── invoke_agent test_agent + assertParent(invocation, invokeAgent); + // ├── call_llm 1 + assertParent(invokeAgent, callLlm1); + // │ ├── tool_call [search_flights] + assertParent(callLlm1, toolCall); + // │ └── tool_response [search_flights] + assertParent(callLlm1, toolResponse); + // └── call_llm 2 + assertParent(invokeAgent, callLlm2); + } + + @Test + public void testNestedAgentTraceHierarchy() throws InterruptedException { + // This test verifies the trace hierarchy created when AgentA transfers to AgentB. + // The expected hierarchy is: + // invocation + // └── invoke_agent AgentA + // ├── call_llm + // │ ├── tool_call [transfer_to_agent] + // │ └── tool_response [transfer_to_agent] + // └── invoke_agent AgentB + // └── call_llm + TestLlm llm = + TestUtils.createTestLlm( + TestUtils.createLlmResponse( + Content.builder() + .role("model") + .parts( + Part.fromFunctionCall( + "transfer_to_agent", ImmutableMap.of("agent_name", "AgentB"))) + .build()), + TestUtils.createLlmResponse(Content.fromParts(Part.fromText("agent b response")))); + LlmAgent agentB = LlmAgent.builder().name("AgentB").description("Agent B").model(llm).build(); + + LlmAgent agentA = + LlmAgent.builder() + .name("AgentA") + .description("Agent A") + .model(llm) + .subAgents(ImmutableList.of(agentB)) + .build(); + + runAgent(agentA); + + SpanData invocation = findSpanByName("invocation"); + SpanData agentASpan = findSpanByName("invoke_agent AgentA"); + SpanData toolCall = findSpanByName("tool_call [transfer_to_agent]"); + SpanData agentBSpan = findSpanByName("invoke_agent AgentB"); + SpanData toolResponse = findSpanByName("tool_response [transfer_to_agent]"); + + List callLlmSpans = + openTelemetryRule.getSpans().stream() + .filter(s -> s.getName().equals("call_llm")) + .sorted(Comparator.comparing(SpanData::getStartEpochNanos)) + .toList(); + assertThat(callLlmSpans).hasSize(2); + + SpanData agentACallLlm1 = callLlmSpans.get(0); + SpanData agentBCallLlm = callLlmSpans.get(1); + + assertParent(invocation, agentASpan); + assertParent(agentASpan, agentACallLlm1); + assertParent(agentACallLlm1, toolCall); + assertParent(agentACallLlm1, toolResponse); + assertParent(agentASpan, agentBSpan); + assertParent(agentBSpan, agentBCallLlm); + } + + private void runAgent(BaseAgent agent) throws InterruptedException { + Runner runner = + Runner.builder().agent(agent).appName("test_app").sessionService(sessionService).build(); + Session session = + sessionService.createSession("test_app", "test-user", null, "test-session").blockingGet(); + Content newMessage = Content.fromParts(Part.fromText("hi")); + RunConfig runConfig = RunConfig.builder().build(); + runner + .runAsync(session.sessionKey(), newMessage, runConfig, null) + .test() + .await() + .assertComplete(); + } + + /** Tool for testing. */ + public static class SearchFlightsTool extends BaseTool { + public SearchFlightsTool() { + super("search_flights", "Search for flights tool"); + } + + @Override + public Single> runAsync(Map args, ToolContext context) { + return Single.just(ImmutableMap.of("result", args)); + } + + @Override + public Optional declaration() { + return Optional.of( + FunctionDeclaration.builder() + .name("search_flights") + .description("Search for flights tool") + .build()); + } + } + + /** + * Asserts that the parent span is the parent of the child span. + * + * @param parent The parent span. + * @param child The child span. + */ + private void assertParent(SpanData parent, SpanData child) { + assertEquals(parent.getSpanContext().getSpanId(), child.getParentSpanContext().getSpanId()); } /** @@ -744,7 +868,9 @@ private SpanData findSpanByName(String name) { private InvocationContext buildInvocationContext() { Session session = - sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet(); + sessionService + .createSession("test_app", "test-user", (Map) null, "test-session") + .blockingGet(); return InvocationContext.builder() .sessionService(sessionService) .session(session) diff --git a/core/src/test/java/com/google/adk/testing/TestUtils.java b/core/src/test/java/com/google/adk/testing/TestUtils.java index 70ae14bf1..daed8d2e4 100644 --- a/core/src/test/java/com/google/adk/testing/TestUtils.java +++ b/core/src/test/java/com/google/adk/testing/TestUtils.java @@ -61,7 +61,7 @@ public static InvocationContext createInvocationContext(BaseAgent agent, RunConf .artifactService(new InMemoryArtifactService()) .invocationId("invocationId") .agent(agent) - .session(sessionService.createSession("test-app", "test-user").blockingGet()) + .session(sessionService.createSession("test_app", "test-user").blockingGet()) .userContent(Content.fromParts(Part.fromText("user content"))) .runConfig(runConfig) .build(); diff --git a/core/src/test/java/com/google/adk/tools/AgentToolTest.java b/core/src/test/java/com/google/adk/tools/AgentToolTest.java index 3a5390027..0f168c5df 100644 --- a/core/src/test/java/com/google/adk/tools/AgentToolTest.java +++ b/core/src/test/java/com/google/adk/tools/AgentToolTest.java @@ -143,7 +143,7 @@ public void declaration_withInputSchema_returnsDeclarationWithSchema() { AgentTool agentTool = AgentTool.create( createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) - .name("agent name") + .name("agent_name") .description("agent description") .inputSchema(inputSchema) .build()); @@ -153,7 +153,7 @@ public void declaration_withInputSchema_returnsDeclarationWithSchema() { assertThat(declaration) .isEqualTo( FunctionDeclaration.builder() - .name("agent name") + .name("agent_name") .description("agent description") .parameters(inputSchema) .build()); @@ -164,7 +164,7 @@ public void declaration_withoutInputSchema_returnsDeclarationWithRequestParamete AgentTool agentTool = AgentTool.create( createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) - .name("agent name") + .name("agent_name") .description("agent description") .build()); @@ -173,7 +173,7 @@ public void declaration_withoutInputSchema_returnsDeclarationWithRequestParamete assertThat(declaration) .isEqualTo( FunctionDeclaration.builder() - .name("agent name") + .name("agent_name") .description("agent description") .parameters( Schema.builder() @@ -200,7 +200,7 @@ public void call_withInputSchema_invalidInput_throwsException() throws Exception .build(); LlmAgent testAgent = createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) - .name("agent name") + .name("agent_name") .description("agent description") .inputSchema(inputSchema) .build(); @@ -256,7 +256,7 @@ public void call_withOutputSchema_invalidOutput_throwsException() throws Excepti "{\"is_valid\": \"invalid type\", " + "\"message\": \"success\"}"))) .build())) - .name("agent name") + .name("agent_name") .description("agent description") .outputSchema(outputSchema) .build(); @@ -301,7 +301,7 @@ public void call_withInputAndOutputSchema_successful() throws Exception { Part.fromText( "{\"is_valid\": true, " + "\"message\": \"success\"}"))) .build())) - .name("agent name") + .name("agent_name") .description("agent description") .inputSchema(inputSchema) .outputSchema(outputSchema) @@ -332,7 +332,7 @@ public void call_withoutSchema_returnsConcatenatedTextFromLastEvent() throws Exc Part.fromText("First text part. "), Part.fromText("Second text part."))) .build()))) - .name("agent name") + .name("agent_name") .description("agent description") .build(); AgentTool agentTool = AgentTool.create(testAgent); @@ -358,7 +358,7 @@ public void call_withThoughts_returnsOnlyNonThoughtText() throws Exception { .build()) .build()); LlmAgent testAgent = - createTestAgentBuilder(testLlm).name("agent name").description("agent description").build(); + createTestAgentBuilder(testLlm).name("agent_name").description("agent description").build(); AgentTool agentTool = AgentTool.create(testAgent); ToolContext toolContext = createToolContext(testAgent); @@ -373,7 +373,7 @@ public void call_emptyModelResponse_returnsEmptyMap() throws Exception { LlmAgent testAgent = createTestAgentBuilder( createTestLlm(LlmResponse.builder().content(Content.builder().build()).build())) - .name("agent name") + .name("agent_name") .description("agent description") .build(); AgentTool agentTool = AgentTool.create(testAgent); @@ -394,7 +394,7 @@ public void call_withInputSchema_argsAreSentToAgent() throws Exception { .build()); LlmAgent testAgent = createTestAgentBuilder(testLlm) - .name("agent name") + .name("agent_name") .description("agent description") .inputSchema( Schema.builder() @@ -422,7 +422,7 @@ public void call_withoutInputSchema_requestIsSentToAgent() throws Exception { .content(Content.fromParts(Part.fromText("test response"))) .build()); LlmAgent testAgent = - createTestAgentBuilder(testLlm).name("agent name").description("agent description").build(); + createTestAgentBuilder(testLlm).name("agent_name").description("agent description").build(); AgentTool agentTool = AgentTool.create(testAgent); ToolContext toolContext = createToolContext(testAgent); @@ -447,7 +447,7 @@ public void call_withStateDeltaInResponse_propagatesStateDelta() throws Exceptio .build()); LlmAgent testAgent = createTestAgentBuilder(testLlm) - .name("agent name") + .name("agent_name") .description("agent description") .afterAgentCallback(afterAgentCallback) .build(); @@ -477,7 +477,7 @@ public void call_withSkipSummarizationAndStateDelta_propagatesStateAndSetsSkipSu .build()); LlmAgent testAgent = createTestAgentBuilder(testLlm) - .name("agent name") + .name("agent_name") .description("agent description") .afterAgentCallback(afterAgentCallback) .build(); diff --git a/core/src/test/java/com/google/adk/tools/ExampleToolTest.java b/core/src/test/java/com/google/adk/tools/ExampleToolTest.java index 4e80ed0ff..55e5d8f93 100644 --- a/core/src/test/java/com/google/adk/tools/ExampleToolTest.java +++ b/core/src/test/java/com/google/adk/tools/ExampleToolTest.java @@ -305,4 +305,30 @@ static final class WrongTypeProviderHolder { private WrongTypeProviderHolder() {} } + + @Test + public void declaration_isEmpty() { + ExampleTool tool = ExampleTool.builder().build(); + assertThat(tool.declaration().isPresent()).isFalse(); + } + + @Test + public void processLlmRequest_doesNotAddFunctionDeclarations() { + ExampleTool tool = ExampleTool.builder().addExample(makeExample("qin", "qout")).build(); + InvocationContext ctx = buildInvocationContext(); + LlmRequest.Builder builder = LlmRequest.builder().model("gemini-2.0-flash"); + + tool.processLlmRequest(builder, ToolContext.builder(ctx).build()).blockingAwait(); + LlmRequest updated = builder.build(); + + if (updated.config().isPresent()) { + var config = updated.config().get(); + if (config.tools().isPresent()) { + var tools = config.tools().get(); + boolean hasFunctionDeclarations = + tools.stream().anyMatch(t -> t.functionDeclarations().isPresent()); + assertThat(hasFunctionDeclarations).isFalse(); + } + } + } } diff --git a/core/src/test/java/com/google/adk/tools/SetModelResponseToolTest.java b/core/src/test/java/com/google/adk/tools/SetModelResponseToolTest.java new file mode 100644 index 000000000..64b600af9 --- /dev/null +++ b/core/src/test/java/com/google/adk/tools/SetModelResponseToolTest.java @@ -0,0 +1,123 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.FunctionDeclaration; +import com.google.genai.types.Schema; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class SetModelResponseToolTest { + + @Test + public void declaration_returnsCorrectFunctionDeclaration() { + Schema outputSchema = + Schema.builder() + .type("OBJECT") + .properties(ImmutableMap.of("field1", Schema.builder().type("STRING").build())) + .required(ImmutableList.of("field1")) + .build(); + + SetModelResponseTool tool = new SetModelResponseTool(outputSchema); + FunctionDeclaration declaration = tool.declaration().get(); + + assertThat(declaration.name()).hasValue("set_model_response"); + assertThat(declaration.description()).isPresent(); + assertThat(declaration.description().get()).contains("Set your final response"); + assertThat(declaration.parameters()).hasValue(outputSchema); + } + + @Test + public void runAsync_returnsArgs() { + Schema outputSchema = + Schema.builder() + .type("OBJECT") + .properties(ImmutableMap.of("field1", Schema.builder().type("STRING").build())) + .build(); + + SetModelResponseTool tool = new SetModelResponseTool(outputSchema); + Map args = ImmutableMap.of("field1", "value1"); + + Map result = tool.runAsync(args, null).blockingGet(); + + assertThat(result).isEqualTo(args); + } + + @Test + public void runAsync_validatesArgs() { + Schema outputSchema = + Schema.builder() + .type("OBJECT") + .properties(ImmutableMap.of("field1", Schema.builder().type("STRING").build())) + .required(ImmutableList.of("field1")) + .build(); + + SetModelResponseTool tool = new SetModelResponseTool(outputSchema); + Map invalidArgs = ImmutableMap.of("field2", "value2"); + + // Should throw validation error + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, () -> tool.runAsync(invalidArgs, null).blockingGet()); + + assertThat(exception).hasMessageThat().contains("does not match agent output schema"); + } + + @Test + public void runAsync_validatesComplexArgs() { + Schema complexSchema = + Schema.builder() + .type("OBJECT") + .properties( + ImmutableMap.of( + "id", + Schema.builder().type("INTEGER").build(), + "tags", + Schema.builder() + .type("ARRAY") + .items(Schema.builder().type("STRING").build()) + .build(), + "metadata", + Schema.builder() + .type("OBJECT") + .properties(ImmutableMap.of("key", Schema.builder().type("STRING").build())) + .build())) + .required(ImmutableList.of("id", "tags", "metadata")) + .build(); + + SetModelResponseTool tool = new SetModelResponseTool(complexSchema); + Map complexArgs = + ImmutableMap.of( + "id", 123, + "tags", ImmutableList.of("tag1", "tag2"), + "metadata", ImmutableMap.of("key", "value")); + + Map result = tool.runAsync(complexArgs, null).blockingGet(); + + assertThat(result).containsEntry("id", 123); + assertThat(result).containsEntry("tags", ImmutableList.of("tag1", "tag2")); + assertThat(result).containsEntry("metadata", ImmutableMap.of("key", "value")); + } +} diff --git a/core/src/test/java/com/google/adk/tools/mcp/McpToolsetTest.java b/core/src/test/java/com/google/adk/tools/mcp/McpToolsetTest.java index 3d322b73f..0db218347 100644 --- a/core/src/test/java/com/google/adk/tools/mcp/McpToolsetTest.java +++ b/core/src/test/java/com/google/adk/tools/mcp/McpToolsetTest.java @@ -34,7 +34,6 @@ import io.modelcontextprotocol.json.McpJsonMapper; import io.modelcontextprotocol.spec.McpSchema; import java.util.List; -import java.util.Optional; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -324,7 +323,7 @@ public void getTools_withToolFilter_returnsFilteredTools() { when(mockMcpSyncClient.listTools()).thenReturn(mockResult); McpToolset toolset = - new McpToolset(mockMcpSessionManager, JsonBaseModel.getMapper(), Optional.of(toolFilter)); + new McpToolset(mockMcpSessionManager, JsonBaseModel.getMapper(), toolFilter); List tools = toolset.getTools(mockReadonlyContext).toList().blockingGet(); @@ -340,8 +339,7 @@ public void getTools_retriesAndFailsAfterMaxRetries() { when(mockMcpSessionManager.createSession()).thenReturn(mockMcpSyncClient); when(mockMcpSyncClient.listTools()).thenThrow(new RuntimeException("Test Exception")); - McpToolset toolset = - new McpToolset(mockMcpSessionManager, JsonBaseModel.getMapper(), Optional.empty()); + McpToolset toolset = new McpToolset(mockMcpSessionManager, JsonBaseModel.getMapper()); toolset .getTools(mockReadonlyContext) @@ -362,8 +360,7 @@ public void getTools_succeedsOnLastRetryAttempt() { .thenThrow(new RuntimeException("Attempt 2 failed")) .thenReturn(mockResult); - McpToolset toolset = - new McpToolset(mockMcpSessionManager, JsonBaseModel.getMapper(), Optional.empty()); + McpToolset toolset = new McpToolset(mockMcpSessionManager, JsonBaseModel.getMapper()); List tools = toolset.getTools(mockReadonlyContext).toList().blockingGet(); diff --git a/dev/pom.xml b/dev/pom.xml index 6cabcba7c..5468a1187 100644 --- a/dev/pom.xml +++ b/dev/pom.xml @@ -18,7 +18,7 @@ com.google.adk google-adk-parent - 0.9.1-SNAPSHOT + 1.0.1-rc.1-SNAPSHOT google-adk-dev diff --git a/dev/src/main/java/com/google/adk/plugins/InvocationReplayState.java b/dev/src/main/java/com/google/adk/plugins/InvocationReplayState.java index 7d70a0efb..eab293de0 100644 --- a/dev/src/main/java/com/google/adk/plugins/InvocationReplayState.java +++ b/dev/src/main/java/com/google/adk/plugins/InvocationReplayState.java @@ -16,8 +16,8 @@ package com.google.adk.plugins; import com.google.adk.plugins.recordings.Recordings; -import java.util.HashMap; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; /** Per-invocation replay state to isolate concurrent runs. */ class InvocationReplayState { @@ -33,7 +33,7 @@ public InvocationReplayState(String testCasePath, int userMessageIndex, Recordin this.testCasePath = testCasePath; this.userMessageIndex = userMessageIndex; this.recordings = recordings; - this.agentReplayIndices = new HashMap<>(); + this.agentReplayIndices = new ConcurrentHashMap<>(); } public String getTestCasePath() { @@ -57,7 +57,6 @@ public void setAgentReplayIndex(String agentName, int index) { } public void incrementAgentReplayIndex(String agentName) { - int currentIndex = getAgentReplayIndex(agentName); - setAgentReplayIndex(agentName, currentIndex + 1); + agentReplayIndices.merge(agentName, 1, Integer::sum); } } diff --git a/maven_plugin/examples/custom_tools/pom.xml b/maven_plugin/examples/custom_tools/pom.xml index f2118f9cc..aa273d732 100644 --- a/maven_plugin/examples/custom_tools/pom.xml +++ b/maven_plugin/examples/custom_tools/pom.xml @@ -4,7 +4,7 @@ com.example custom-tools-example - 0.9.1-SNAPSHOT + 1.0.1-rc.1-SNAPSHOT jar ADK Custom Tools Example diff --git a/maven_plugin/examples/simple-agent/pom.xml b/maven_plugin/examples/simple-agent/pom.xml index 5c0f4462d..34aeb8c1c 100644 --- a/maven_plugin/examples/simple-agent/pom.xml +++ b/maven_plugin/examples/simple-agent/pom.xml @@ -4,7 +4,7 @@ com.example simple-adk-agent - 0.9.1-SNAPSHOT + 1.0.1-rc.1-SNAPSHOT jar Simple ADK Agent Example diff --git a/maven_plugin/pom.xml b/maven_plugin/pom.xml index c48331f72..d0feb41e3 100644 --- a/maven_plugin/pom.xml +++ b/maven_plugin/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-parent - 0.9.1-SNAPSHOT + 1.0.1-rc.1-SNAPSHOT ../pom.xml diff --git a/pom.xml b/pom.xml index bd0caca0d..40332472f 100644 --- a/pom.xml +++ b/pom.xml @@ -17,7 +17,7 @@ com.google.adk google-adk-parent - 0.9.1-SNAPSHOT + 1.0.1-rc.1-SNAPSHOT pom Google Agent Development Kit Maven Parent POM @@ -51,7 +51,7 @@ 1.51.0 0.17.2 2.47.0 - 1.41.0 + 1.43.0 4.33.5 5.11.4 5.20.0 @@ -62,7 +62,7 @@ 0.18.1 3.41.0 3.9.0 - 1.11.0 + 1.12.2 2.0.17 1.4.5 1.0.0 @@ -73,6 +73,8 @@ 2.15.0 3.9.0 5.6 + 4.1.118.Final + @{jacoco.agent.argLine} --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens=java.base/sun.nio.ch=ALL-UNNAMED --add-opens=java.base/java.util.concurrent=ALL-UNNAMED --add-opens=java.base/java.text=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED --add-opens=java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.lang.reflect=ALL-UNNAMED --add-opens=java.base/jdk.internal.misc=ALL-UNNAMED -Dio.netty.tryReflectionSetAccessible=true @@ -85,6 +87,13 @@ pom import + + io.netty + netty-bom + ${netty.version} + pom + import + com.google.cloud libraries-bom @@ -338,6 +347,8 @@ + + ${surefire.argLine} plain diff --git a/tutorials/city-time-weather/pom.xml b/tutorials/city-time-weather/pom.xml index 76b7331f3..19ef08a2d 100644 --- a/tutorials/city-time-weather/pom.xml +++ b/tutorials/city-time-weather/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.9.1-SNAPSHOT + 1.0.1-rc.1-SNAPSHOT ../../pom.xml @@ -36,16 +36,6 @@ com.google.adk google-adk-dev ${project.version} - - - ch.qos.logback - logback-classic - - - - - org.slf4j - slf4j-simple diff --git a/tutorials/live-audio-single-agent/pom.xml b/tutorials/live-audio-single-agent/pom.xml index a330cf4bd..99243893b 100644 --- a/tutorials/live-audio-single-agent/pom.xml +++ b/tutorials/live-audio-single-agent/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.9.1-SNAPSHOT + 1.0.1-rc.1-SNAPSHOT ../../pom.xml