diff --git a/contrib/sarvam-ai/pom.xml b/contrib/sarvam-ai/pom.xml index 12eb49ac0..636289044 100644 --- a/contrib/sarvam-ai/pom.xml +++ b/contrib/sarvam-ai/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 1.2.0 + 0.9.1-SNAPSHOT ../../pom.xml diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAi.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAi.java index 4ced7f6c9..02b16c5c1 100644 --- a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAi.java +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAi.java @@ -16,6 +16,7 @@ package com.google.adk.models.sarvamai; +import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.adk.models.BaseLlm; @@ -24,13 +25,19 @@ import com.google.adk.models.LlmResponse; import com.google.adk.models.sarvamai.chat.ChatRequest; import com.google.adk.models.sarvamai.chat.ChatResponse; +import com.google.adk.models.sarvamai.chat.ChatToolCall; +import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.genai.types.Content; +import com.google.genai.types.FunctionCall; import com.google.genai.types.Part; import io.reactivex.rxjava3.core.BackpressureStrategy; import io.reactivex.rxjava3.core.Flowable; import java.io.BufferedReader; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.concurrent.TimeUnit; import okhttp3.MediaType; @@ -233,12 +240,48 @@ private LlmResponse toLlmResponse(ChatResponse chatResponse) { } var choice = chatResponse.getChoices().get(0); var effectiveMsg = choice.effectiveMessage(); - if (effectiveMsg == null || effectiveMsg.getContent() == null) { - throw new SarvamAiException("No content in response choice"); + if (effectiveMsg == null) { + throw new SarvamAiException("No message in response choice"); } - Content content = - Content.builder().role("model").parts(Part.fromText(effectiveMsg.getContent())).build(); + // Handle tool_calls in response (model requesting function execution) + if (effectiveMsg.getToolCalls() != null && !effectiveMsg.getToolCalls().isEmpty()) { + List parts = new ArrayList<>(); + for (ChatToolCall tc : effectiveMsg.getToolCalls()) { + if (tc.getFunction() == null || tc.getFunction().getName() == null) { + continue; + } + String argsStr = tc.getFunction().getArguments(); + if (argsStr == null) { + argsStr = "{}"; + } + Map args; + try { + args = objectMapper.readValue(argsStr, new TypeReference>() {}); + } catch (Exception e) { + args = Map.of(); + } + FunctionCall fc = + FunctionCall.builder() + .id(tc.getId()) + .name(tc.getFunction().getName()) + .args(args) + .build(); + parts.add(Part.builder().functionCall(fc).build()); + } + if (parts.isEmpty()) { + throw new SarvamAiException("Tool calls in response but no valid function call found"); + } + Content content = Content.builder().role("model").parts(ImmutableList.copyOf(parts)).build(); + return LlmResponse.builder().content(content).build(); + } + + // Handle text content + String textContent = effectiveMsg.getContent(); + if (textContent == null) { + textContent = ""; + } + Content content = Content.builder().role("model").parts(Part.fromText(textContent)).build(); return LlmResponse.builder().content(content).build(); } diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatMessage.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatMessage.java index a820ac47e..b01bc6a2d 100644 --- a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatMessage.java +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatMessage.java @@ -19,6 +19,7 @@ import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; /** * A message in the Sarvam AI chat completion API (request or response). @@ -38,6 +39,12 @@ public final class ChatMessage { @JsonProperty("reasoning_content") private String reasoningContent; + @JsonProperty("tool_calls") + private List toolCalls; + + @JsonProperty("tool_call_id") + private String toolCallId; + public ChatMessage() {} public ChatMessage(String role, String content) { @@ -68,4 +75,20 @@ public String getReasoningContent() { public void setReasoningContent(String reasoningContent) { this.reasoningContent = reasoningContent; } + + public List getToolCalls() { + return toolCalls; + } + + public void setToolCalls(List toolCalls) { + this.toolCalls = toolCalls; + } + + public String getToolCallId() { + return toolCallId; + } + + public void setToolCallId(String toolCallId) { + this.toolCallId = toolCallId; + } } diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatRequest.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatRequest.java index 3faefa2e9..7f67232d9 100644 --- a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatRequest.java +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatRequest.java @@ -18,12 +18,19 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.adk.models.LlmRequest; import com.google.adk.models.sarvamai.SarvamAiConfig; +import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; +import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.Part; +import com.google.genai.types.Schema; +import com.google.genai.types.Type; import java.util.ArrayList; import java.util.List; +import java.util.Map; +import java.util.Optional; /** * Request body for the Sarvam AI chat completions endpoint. Constructed from the ADK {@link @@ -73,6 +80,15 @@ public final class ChatRequest { @JsonProperty("stop") private Object stop; + @JsonProperty("tools") + private List> tools; + + @JsonProperty("tool_choice") + private String toolChoice; + + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + private static final String IDENTIFIER_REGEX = "[^a-zA-Z0-9_\\.-]"; + public ChatRequest() {} /** @@ -95,18 +111,68 @@ public static ChatRequest fromLlmRequest( if ("model".equals(role)) { role = "assistant"; } - StringBuilder textBuilder = new StringBuilder(); - content - .parts() - .ifPresent( - parts -> { - for (Part part : parts) { - part.text().ifPresent(textBuilder::append); - } - }); - if (textBuilder.length() > 0) { - request.messages.add(new ChatMessage(role, textBuilder.toString())); + List parts = content.parts().orElse(ImmutableList.of()); + if (parts.isEmpty()) { + continue; } + Part firstPart = parts.get(0); + + if (firstPart.functionResponse().isPresent()) { + var fr = firstPart.functionResponse().get(); + ChatMessage toolMsg = new ChatMessage(); + toolMsg.setRole("tool"); + toolMsg.setToolCallId(fr.id().orElse("call_" + fr.name().orElse("unknown"))); + toolMsg.setContent( + fr.response() + .map( + r -> { + try { + return OBJECT_MAPPER.writeValueAsString(r); + } catch (Exception e) { + return "{}"; + } + }) + .orElse("{}")); + request.messages.add(toolMsg); + } else if (firstPart.functionCall().isPresent()) { + var fc = firstPart.functionCall().get(); + ChatMessage assistantMsg = new ChatMessage(); + assistantMsg.setRole("assistant"); + assistantMsg.setContent(null); + ChatToolCall tc = new ChatToolCall(); + tc.setId(fc.id().orElse("call_" + fc.name().orElse("unknown"))); + tc.setType("function"); + ChatToolCall.ChatToolCallFunction tcf = new ChatToolCall.ChatToolCallFunction(); + tcf.setName(fc.name().orElse("")); + tcf.setArguments( + fc.args() + .map( + args -> { + try { + return OBJECT_MAPPER.writeValueAsString(args); + } catch (Exception e) { + return "{}"; + } + }) + .orElse("{}")); + tc.setFunction(tcf); + assistantMsg.setToolCalls(List.of(tc)); + request.messages.add(assistantMsg); + } else { + StringBuilder textBuilder = new StringBuilder(); + for (Part part : parts) { + part.text().ifPresent(textBuilder::append); + } + if (textBuilder.length() > 0) { + request.messages.add( + new ChatMessage(role.equals("model") ? "assistant" : role, textBuilder.toString())); + } + } + } + + if (!llmRequest.tools().isEmpty()) { + request.tools = buildTools(llmRequest); + request.toolChoice = "auto"; } config.temperature().ifPresent(v -> request.temperature = v); @@ -120,6 +186,96 @@ public static ChatRequest fromLlmRequest( return request; } + private static List> buildTools(LlmRequest llmRequest) { + List> toolsList = new ArrayList<>(); + llmRequest + .tools() + .forEach( + (name, baseTool) -> { + Optional declOpt = baseTool.declaration(); + if (declOpt.isEmpty()) { + return; + } + FunctionDeclaration decl = declOpt.get(); + Map funcMap = new java.util.HashMap<>(); + funcMap.put("name", cleanForIdentifier(decl.name().orElse(""))); + funcMap.put("description", cleanForIdentifier(decl.description().orElse(""))); + + decl.parameters() + .ifPresent( + paramsSchema -> { + Map paramsMap = new java.util.HashMap<>(); + paramsMap.put("type", "object"); + paramsSchema + .properties() + .ifPresent( + props -> { + Map propsMap = new java.util.HashMap<>(); + props.forEach( + (key, schema) -> { + Map schemaMap = schemaToMap(schema); + normalizeTypeStrings(schemaMap); + propsMap.put(key, schemaMap); + }); + paramsMap.put("properties", propsMap); + }); + paramsSchema.required().ifPresent(r -> paramsMap.put("required", r)); + funcMap.put("parameters", paramsMap); + }); + + Map toolWrapper = new java.util.HashMap<>(); + toolWrapper.put("type", "function"); + toolWrapper.put("function", funcMap); + toolsList.add(toolWrapper); + }); + return toolsList; + } + + /** Manually convert Schema to Map to avoid Jackson Optional serialization issues. */ + private static Map schemaToMap(Schema schema) { + Map map = new java.util.HashMap<>(); + schema.type().ifPresent(t -> map.put("type", schemaTypeToString(t))); + schema.description().ifPresent(d -> map.put("description", d)); + schema + .properties() + .ifPresent( + props -> { + Map propsMap = new java.util.HashMap<>(); + props.forEach((k, v) -> propsMap.put(k, schemaToMap(v))); + map.put("properties", propsMap); + }); + schema.required().ifPresent(r -> map.put("required", r)); + schema.items().ifPresent(i -> map.put("items", schemaToMap(i))); + return map; + } + + private static String schemaTypeToString(Type type) { + return switch (type.knownEnum()) { + case STRING -> "string"; + case NUMBER -> "number"; + case INTEGER -> "integer"; + case BOOLEAN -> "boolean"; + case ARRAY -> "array"; + case OBJECT -> "object"; + default -> "string"; + }; + } + + private static String cleanForIdentifier(String input) { + return input == null ? "" : input.replaceAll(IDENTIFIER_REGEX, ""); + } + + @SuppressWarnings("unchecked") + private static void normalizeTypeStrings(Map valueDict) { + if (valueDict == null) return; + if (valueDict.containsKey("type") && valueDict.get("type") instanceof String) { + valueDict.put("type", ((String) valueDict.get("type")).toLowerCase()); + } + if (valueDict.containsKey("items") && valueDict.get("items") instanceof Map) { + normalizeTypeStrings((Map) valueDict.get("items")); + } + } + public String getModel() { return model; } @@ -151,4 +307,12 @@ public String getReasoningEffort() { public Boolean getWikiGrounding() { return wikiGrounding; } + + public List> getTools() { + return tools; + } + + public String getToolChoice() { + return toolChoice; + } } diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatToolCall.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatToolCall.java new file mode 100644 index 000000000..cf13ac23d --- /dev/null +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatToolCall.java @@ -0,0 +1,87 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models.sarvamai.chat; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** OpenAI-format tool call (id, type, function with name and arguments). */ +@JsonIgnoreProperties(ignoreUnknown = true) +public final class ChatToolCall { + + @JsonProperty("id") + private String id; + + @JsonProperty("type") + private String type; + + @JsonProperty("function") + private ChatToolCallFunction function; + + public ChatToolCall() {} + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + public String getType() { + return type; + } + + public void setType(String type) { + this.type = type; + } + + public ChatToolCallFunction getFunction() { + return function; + } + + public void setFunction(ChatToolCallFunction function) { + this.function = function; + } + + /** Inner function object (name, arguments). */ + public static final class ChatToolCallFunction { + @JsonProperty("name") + private String name; + + @JsonProperty("arguments") + private String arguments; + + public ChatToolCallFunction() {} + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public String getArguments() { + return arguments; + } + + public void setArguments(String arguments) { + this.arguments = arguments; + } + } +} diff --git a/pom.xml b/pom.xml index bd0caca0d..991057c34 100644 --- a/pom.xml +++ b/pom.xml @@ -32,6 +32,7 @@ contrib/spring-ai contrib/samples contrib/firestore-session-service + contrib/sarvam-ai tutorials/city-time-weather tutorials/live-audio-single-agent a2a