From f7536b848c7fb8c6bb208d420c0d1bc86b59ab22 Mon Sep 17 00:00:00 2001 From: Sandeep Belgavi Date: Wed, 11 Feb 2026 16:35:03 +0530 Subject: [PATCH 01/11] feat: Integrate Sarvam AI and fix build issues --- contrib/sarvam-ai/pom.xml | 119 +++++++ .../google/adk/models/sarvamai/SarvamAi.java | 165 +++++++++ .../adk/models/sarvamai/SarvamAiChoice.java | 37 ++ .../adk/models/sarvamai/SarvamAiConfig.java | 37 ++ .../adk/models/sarvamai/SarvamAiMessage.java | 43 +++ .../adk/models/sarvamai/SarvamAiRequest.java | 54 +++ .../adk/models/sarvamai/SarvamAiResponse.java | 39 +++ .../sarvamai/SarvamAiResponseMessage.java | 37 ++ .../adk/models/sarvamai/SarvamAiTest.java | 153 +++++++++ .../java/com/google/adk/models/GptOssLlm.java | 25 +- .../java/com/google/adk/models/SarvamLlm.java | 324 ++++++++++++++++++ .../google/adk/transcription/ServiceType.java | 5 +- .../config/TranscriptionConfigLoader.java | 83 ++--- .../strategy/SarvamTranscriptionService.java | 168 +++++++++ .../strategy/TranscriptionServiceFactory.java | 3 + dev/src/main/resources/application.properties | 24 +- pom.xml | 1 + 17 files changed, 1237 insertions(+), 80 deletions(-) create mode 100644 contrib/sarvam-ai/pom.xml create mode 100644 contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAi.java create mode 100644 contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiChoice.java create mode 100644 contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiConfig.java create mode 100644 contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiMessage.java create mode 100644 contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiRequest.java create mode 100644 contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiResponse.java create mode 100644 contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiResponseMessage.java create mode 100644 
contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamAiTest.java create mode 100644 core/src/main/java/com/google/adk/models/SarvamLlm.java create mode 100644 core/src/main/java/com/google/adk/transcription/strategy/SarvamTranscriptionService.java diff --git a/contrib/sarvam-ai/pom.xml b/contrib/sarvam-ai/pom.xml new file mode 100644 index 000000000..ab4e8eb59 --- /dev/null +++ b/contrib/sarvam-ai/pom.xml @@ -0,0 +1,119 @@ + + + + 4.0.0 + + + com.google.adk + google-adk-parent + 0.5.1-SNAPSHOT + ../../pom.xml + + + google-adk-sarvam-ai + Agent Development Kit - Sarvam AI + Sarvam AI integration for the Agent Development Kit. + + + + + com.google.adk + google-adk + ${project.version} + + + com.google.adk + google-adk-dev + ${project.version} + + + com.squareup.okhttp3 + okhttp + ${okhttp.version} + + + + + org.junit.jupiter + junit-jupiter-api + test + + + org.junit.jupiter + junit-jupiter-params + test + + + org.junit.jupiter + junit-jupiter-engine + test + + + com.google.truth + truth + test + + + org.assertj + assertj-core + test + + + org.mockito + mockito-junit-jupiter + ${mockito.version} + test + + + + + + maven-surefire-plugin + 3.5.2 + + + me.fabriciorby + maven-surefire-junit5-tree-reporter + 0.1.0 + + + + org.junit.jupiter + junit-jupiter-engine + ${junit.version} + + + + org.mockito + mockito-junit-jupiter + ${mockito.version} + + + + plain + + + **/*Test.java + + + ${project.basedir}/src/test/java + + + + + diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAi.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAi.java new file mode 100644 index 000000000..2108b9848 --- /dev/null +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAi.java @@ -0,0 +1,165 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// MODIFIED BY Sandeep Belgavi, 2026-02-11 +package com.google.adk.models.sarvamai; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.adk.models.BaseLlm; +import com.google.adk.models.BaseLlmConnection; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Flowable; +import java.io.BufferedReader; +import java.io.IOException; +import okhttp3.Call; +import okhttp3.Callback; +import okhttp3.MediaType; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.RequestBody; +import okhttp3.Response; +import okhttp3.ResponseBody; + +/** + * This class is the main entry point for the Sarvam AI API. 
+ * + * @author Sandeep Belgavi + * @since 2026-02-11 + */ +public class SarvamAi extends BaseLlm { + + private static final String API_ENDPOINT = "https://api.sarvam.ai/v1/chat/completions"; + private final OkHttpClient httpClient; + private final String apiKey; + private final ObjectMapper objectMapper; + + public SarvamAi(String modelName, SarvamAiConfig config) { + super(modelName); + this.httpClient = new OkHttpClient(); + this.apiKey = config.getApiKey(); + this.objectMapper = new ObjectMapper(); + } + + @Override + public Flowable<LlmResponse> generateContent(LlmRequest llmRequest, boolean stream) { + if (stream) { + return stream(llmRequest); + } else { + return Flowable.fromCallable( + () -> { + String requestBody = + objectMapper.writeValueAsString(new SarvamAiRequest(this.model(), llmRequest)); + Request request = + new Request.Builder() + .url(API_ENDPOINT) + .addHeader("Authorization", "Bearer " + apiKey) + .post(RequestBody.create(requestBody, MediaType.get("application/json"))) + .build(); + Response response = httpClient.newCall(request).execute(); + if (!response.isSuccessful()) { + throw new IOException("Unexpected code " + response); + } + ResponseBody responseBody = response.body(); + if (responseBody != null) { + String responseBodyString = responseBody.string(); + SarvamAiResponse sarvamAiResponse = + objectMapper.readValue(responseBodyString, SarvamAiResponse.class); + return toLlmResponse(sarvamAiResponse); + } else { + throw new IOException("Response body is null"); + } + }); + } + } + + private Flowable<LlmResponse> stream(LlmRequest llmRequest) { + return Flowable.create( + emitter -> { + try { + String requestBody = + objectMapper.writeValueAsString(new SarvamAiRequest(this.model(), llmRequest)); + Request request = + new Request.Builder() + .url(API_ENDPOINT) + .addHeader("Authorization", "Bearer " + apiKey) + .post(RequestBody.create(requestBody, MediaType.get("application/json"))) + .build(); + httpClient + .newCall(request) + .enqueue( + new Callback() { + 
@Override + public void onFailure(Call call, IOException e) { + emitter.onError(e); + } + + @Override + public void onResponse(Call call, Response response) throws IOException { + if (!response.isSuccessful()) { + emitter.onError(new IOException("Unexpected code " + response)); + return; + } + ResponseBody responseBody = response.body(); + if (responseBody != null) { + try (var reader = new BufferedReader(responseBody.charStream())) { + String line; + while ((line = reader.readLine()) != null) { + if (line.startsWith("data: ")) { + String data = line.substring(6); + if (data.equals("[DONE]")) { + emitter.onComplete(); + return; + } + SarvamAiResponse sarvamAiResponse = + objectMapper.readValue(data, SarvamAiResponse.class); + emitter.onNext(toLlmResponse(sarvamAiResponse)); + } + } + emitter.onComplete(); + } + } else { + emitter.onError(new IOException("Response body is null")); + } + } + }); + } catch (IOException e) { + emitter.onError(e); + } + }, + io.reactivex.rxjava3.core.BackpressureStrategy.BUFFER); + } + + private LlmResponse toLlmResponse(SarvamAiResponse sarvamAiResponse) { + Content content = + Content.builder() + .role("model") + .parts( + java.util.Collections.singletonList( + Part.fromText(sarvamAiResponse.getChoices().get(0).getMessage().getContent()))) + .build(); + return LlmResponse.builder().content(content).build(); + } + + @Override + public BaseLlmConnection connect(LlmRequest llmRequest) { + // TODO: Implement this method + throw new UnsupportedOperationException( + "Live connection is not supported for Sarvam AI models."); + } +} diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiChoice.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiChoice.java new file mode 100644 index 000000000..ff31c5b66 --- /dev/null +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiChoice.java @@ -0,0 +1,37 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache 
License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// MODIFIED BY Sandeep Belgavi, 2026-02-11 +package com.google.adk.models.sarvamai; + +/** + * This class is used to represent a choice from the Sarvam AI API. + * + * @author Sandeep Belgavi + * @since 2026-02-11 + */ +public class SarvamAiChoice { + + private SarvamAiResponseMessage message; + + public SarvamAiResponseMessage getMessage() { + return message; + } + + public void setMessage(SarvamAiResponseMessage message) { + this.message = message; + } +} diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiConfig.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiConfig.java new file mode 100644 index 000000000..edf846396 --- /dev/null +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiConfig.java @@ -0,0 +1,37 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// MODIFIED BY Sandeep Belgavi, 2026-02-11 +package com.google.adk.models.sarvamai; + +/** + * This class is used to configure the Sarvam AI API. + * + * @author Sandeep Belgavi + * @since 2026-02-11 + */ +public class SarvamAiConfig { + + private final String apiKey; + + public SarvamAiConfig(String apiKey) { + this.apiKey = apiKey; + } + + public String getApiKey() { + return apiKey; + } +} diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiMessage.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiMessage.java new file mode 100644 index 000000000..82969ffb8 --- /dev/null +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiMessage.java @@ -0,0 +1,43 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// MODIFIED BY Sandeep Belgavi, 2026-02-11 +package com.google.adk.models.sarvamai; + +/** + * This class is used to represent a message from the Sarvam AI API. 
+ * + * @author Sandeep Belgavi + * @since 2026-02-11 + */ +public class SarvamAiMessage { + + private String role; + private String content; + + public SarvamAiMessage(String role, String content) { + this.role = role; + this.content = content; + } + + public String getRole() { + return role; + } + + public String getContent() { + return content; + } +} diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiRequest.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiRequest.java new file mode 100644 index 000000000..21a6da172 --- /dev/null +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiRequest.java @@ -0,0 +1,54 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// MODIFIED BY Sandeep Belgavi, 2026-02-11 +package com.google.adk.models.sarvamai; + +import com.google.adk.models.LlmRequest; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.util.ArrayList; +import java.util.List; + +/** + * This class is used to create a request to the Sarvam AI API. 
+ * + * @author Sandeep Belgavi + * @since 2026-02-11 + */ +public class SarvamAiRequest { + + private String model; + private List<SarvamAiMessage> messages; + + public SarvamAiRequest(String model, LlmRequest llmRequest) { + this.model = model; + this.messages = new ArrayList<>(); + for (Content content : llmRequest.contents()) { + for (Part part : content.parts().get()) { + this.messages.add(new SarvamAiMessage(content.role().get(), part.text().get())); + } + } + } + + public String getModel() { + return model; + } + + public List<SarvamAiMessage> getMessages() { + return messages; + } +} diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiResponse.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiResponse.java new file mode 100644 index 000000000..2de224a9b --- /dev/null +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiResponse.java @@ -0,0 +1,39 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// MODIFIED BY Sandeep Belgavi, 2026-02-11 +package com.google.adk.models.sarvamai; + +import java.util.List; + +/** + * This class is used to represent a response from the Sarvam AI API. 
+ * + * @author Sandeep Belgavi + * @since 2026-02-11 + */ +public class SarvamAiResponse { + + private List<SarvamAiChoice> choices; + + public List<SarvamAiChoice> getChoices() { + return choices; + } + + public void setChoices(List<SarvamAiChoice> choices) { + this.choices = choices; + } +} diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiResponseMessage.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiResponseMessage.java new file mode 100644 index 000000000..dcc8ffe6c --- /dev/null +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiResponseMessage.java @@ -0,0 +1,37 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// MODIFIED BY Sandeep Belgavi, 2026-02-11 +package com.google.adk.models.sarvamai; + +/** + * This class is used to represent a response message from the Sarvam AI API. 
+ * + * @author Sandeep Belgavi + * @since 2026-02-11 + */ +public class SarvamAiResponseMessage { + + private String content; + + public String getContent() { + return content; + } + + public void setContent(String content) { + this.content = content; + } +} diff --git a/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamAiTest.java b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamAiTest.java new file mode 100644 index 000000000..a5cb7a1db --- /dev/null +++ b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamAiTest.java @@ -0,0 +1,153 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// MODIFIED BY Sandeep Belgavi, 2026-02-11 +package com.google.adk.models.sarvamai; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.Mockito.when; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Flowable; +import java.io.IOException; +import java.util.List; +import okhttp3.Call; +import okhttp3.Callback; +import okhttp3.MediaType; +import okhttp3.OkHttpClient; +import okhttp3.Protocol; +import okhttp3.Request; +import okhttp3.Response; +import okhttp3.ResponseBody; +import org.junit.Before; +import org.junit.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +public class SarvamAiTest { + + private static final String API_KEY = "test-api-key"; + private static final String MODEL_NAME = "test-model"; + private static final String COMPLETION_TEXT = "Hello, world!"; + private static final String STREAMING_CHUNK_1 = + "data: {\"choices\": [{\"message\": {\"content\": \"Hello,\"}}]}"; + private static final String STREAMING_CHUNK_2 = + "data: {\"choices\": [{\"message\": {\"content\": \" world!\"}}]}"; + private static final String STREAMING_DONE = "data: [DONE]"; + + @Mock private OkHttpClient mockHttpClient; + @Mock private Call mockCall; + @Mock private SarvamAiConfig mockConfig; + + @Captor private ArgumentCaptor<Request> requestCaptor; + @Captor private ArgumentCaptor<Callback> callbackCaptor; + + private SarvamAi sarvamAi; + private ObjectMapper objectMapper; + + @Before + public void setUp() { + when(mockConfig.getApiKey()).thenReturn(API_KEY); + sarvamAi = new SarvamAi(MODEL_NAME, mockConfig); + objectMapper = new ObjectMapper(); + + 
when(mockHttpClient.newCall(requestCaptor.capture())).thenReturn(mockCall); + } + + @Test + public void generateContent_blockingCall_returnsLlmResponse() throws IOException { + String mockResponseBody = createMockSarvamAiResponseBody(COMPLETION_TEXT); + Response mockResponse = + new Response.Builder() + .request(new Request.Builder().url("http://localhost").build()) + .protocol(Protocol.HTTP_1_1) + .code(200) + .message("OK") + .body(ResponseBody.create(mockResponseBody, MediaType.get("application/json"))) + .build(); + + when(mockCall.execute()).thenReturn(mockResponse); + + LlmRequest llmRequest = + LlmRequest.builder() + .contents( + List.of(Content.builder().parts(List.of(Part.fromText("test query"))).build())) + .build(); + Flowable<LlmResponse> responseFlowable = sarvamAi.generateContent(llmRequest, false); + + LlmResponse llmResponse = responseFlowable.blockingFirst(); + + assertThat(llmResponse.content().get().parts().get().get(0).text()).isEqualTo(COMPLETION_TEXT); + } + + @Test + public void generateContent_streamingCall_returnsLlmResponses() throws IOException { + ResponseBody mockStreamingBody = + createMockStreamingResponseBody( + List.of(STREAMING_CHUNK_1, STREAMING_CHUNK_2, STREAMING_DONE)); + + Response mockResponse = + new Response.Builder() + .request(new Request.Builder().url("http://localhost").build()) + .protocol(Protocol.HTTP_1_1) + .code(200) + .message("OK") + .body(mockStreamingBody) + .build(); + + when(mockCall.execute()) + .thenThrow(new IllegalStateException("Should not be called for streaming")); + + LlmRequest llmRequest = + LlmRequest.builder() + .contents( + List.of(Content.builder().parts(List.of(Part.fromText("test query"))).build())) + .build(); + Flowable<LlmResponse> responseFlowable = sarvamAi.generateContent(llmRequest, true); + + // Simulate the asynchronous callback + Callback capturedCallback = callbackCaptor.getValue(); + capturedCallback.onResponse(mockCall, mockResponse); + + List<LlmResponse> responses = responseFlowable.toList().blockingGet(); + + 
assertThat(responses).hasSize(2); + assertThat(responses.get(0).content().get().parts().get().get(0).text()).isEqualTo("Hello,"); + assertThat(responses.get(1).content().get().parts().get().get(0).text()).isEqualTo(" world!"); + } + + // Helper method to create a mock SarvamAi response body + private String createMockSarvamAiResponseBody(String text) { + return String.format("{\"choices\": [{\"message\": {\"content\": \"%s\"}}]}", text); + } + + // Helper method to create a mock streaming response body + private ResponseBody createMockStreamingResponseBody(List<String> chunks) { + StringBuilder bodyBuilder = new StringBuilder(); + for (String chunk : chunks) { + bodyBuilder.append(chunk).append("\n\n"); + } + return ResponseBody.create(bodyBuilder.toString(), MediaType.get("text/event-stream")); + } +} diff --git a/core/src/main/java/com/google/adk/models/GptOssLlm.java b/core/src/main/java/com/google/adk/models/GptOssLlm.java index 331203ac6..895aba540 100644 --- a/core/src/main/java/com/google/adk/models/GptOssLlm.java +++ b/core/src/main/java/com/google/adk/models/GptOssLlm.java @@ -100,16 +100,16 @@ public GptOssLlm(String modelName) { * @param modelName The name of the GPT OSS model to use (e.g., "gpt-oss-4"). * @param vertexCredentials The Vertex AI credentials to access the model. 
*/ -// public GptOssLlm(String modelName, VertexCredentials vertexCredentials) { -// super(modelName); -// Objects.requireNonNull(vertexCredentials, "vertexCredentials cannot be null"); -// Client.Builder apiClientBuilder = -// Client.builder().httpOptions(HttpOptions.builder().headers(TRACKING_HEADERS).build()); -// vertexCredentials.project().ifPresent(apiClientBuilder::project); -// vertexCredentials.location().ifPresent(apiClientBuilder::location); -// vertexCredentials.credentials().ifPresent(apiClientBuilder::credentials); -// this.apiClient = apiClientBuilder.build(); -// } + // public GptOssLlm(String modelName, VertexCredentials vertexCredentials) { + // super(modelName); + // Objects.requireNonNull(vertexCredentials, "vertexCredentials cannot be null"); + // Client.Builder apiClientBuilder = + // Client.builder().httpOptions(HttpOptions.builder().headers(TRACKING_HEADERS).build()); + // vertexCredentials.project().ifPresent(apiClientBuilder::project); + // vertexCredentials.location().ifPresent(apiClientBuilder::location); + // vertexCredentials.credentials().ifPresent(apiClientBuilder::credentials); + // this.apiClient = apiClientBuilder.build(); + // } /** * Returns a new Builder instance for constructing GptOssLlm objects. 
Note that when building a @@ -165,8 +165,7 @@ public GptOssLlm build() { if (apiClient != null) { return new GptOssLlm(modelName, apiClient); - } - else { + } else { return new GptOssLlm( modelName, Client.builder() @@ -354,4 +353,4 @@ public BaseLlmConnection connect(LlmRequest llmRequest) { return new GeminiLlmConnection(apiClient, effectiveModelName, liveConnectConfig); } -} \ No newline at end of file +} diff --git a/core/src/main/java/com/google/adk/models/SarvamLlm.java b/core/src/main/java/com/google/adk/models/SarvamLlm.java new file mode 100644 index 000000000..6987c1347 --- /dev/null +++ b/core/src/main/java/com/google/adk/models/SarvamLlm.java @@ -0,0 +1,324 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.adk.models; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.adk.tools.BaseTool; +import com.google.common.base.Strings; +import com.google.genai.types.Blob; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.BackpressureStrategy; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.subjects.PublishSubject; +import java.io.BufferedReader; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import okhttp3.MediaType; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.RequestBody; +import okhttp3.Response; + +/** Sarvam AI LLM implementation. Uses the OpenAI-compatible chat completion endpoint. */ +public class SarvamLlm extends BaseLlm { + + private static final String API_URL = "https://api.sarvam.ai/chat/completions"; + private static final MediaType JSON = MediaType.get("application/json; charset=utf-8"); + + private final String apiKey; + private final OkHttpClient client; + private final ObjectMapper objectMapper; + + public SarvamLlm(String model) { + this(model, null); + } + + public SarvamLlm(String model, String apiKey) { + super(model); + if (Strings.isNullOrEmpty(apiKey)) { + this.apiKey = System.getenv("SARVAM_API_KEY"); + } else { + this.apiKey = apiKey; + } + + if (Strings.isNullOrEmpty(this.apiKey)) { + throw new IllegalArgumentException( + "Sarvam API key is required. 
Set SARVAM_API_KEY env variable or pass it to constructor."); + } + + this.client = new OkHttpClient(); + this.objectMapper = new ObjectMapper(); + } + + @Override + public Flowable<LlmResponse> generateContent(LlmRequest llmRequest, boolean stream) { + return Flowable.create( + emitter -> { + try { + ObjectNode jsonBody = objectMapper.createObjectNode(); + jsonBody.put("model", model()); + jsonBody.put("stream", stream); + + ArrayNode messages = jsonBody.putArray("messages"); + + // Add system instructions if present + for (String instruction : llmRequest.getSystemInstructions()) { + ObjectNode systemMsg = messages.addObject(); + systemMsg.put("role", "system"); + systemMsg.put("content", instruction); + } + + // Add conversation history + for (Content content : llmRequest.contents()) { + ObjectNode message = messages.addObject(); + String role = content.role().orElse("user"); + // Map "model" to "assistant" for OpenAI compatibility + if ("model".equals(role)) { + role = "assistant"; + } + message.put("role", role); + + StringBuilder textBuilder = new StringBuilder(); + content + .parts() + .ifPresent( + parts -> { + for (Part part : parts) { + part.text().ifPresent(textBuilder::append); + } + }); + message.put("content", textBuilder.toString()); + } + + // Add tool definitions if present + if (llmRequest.tools() != null && !llmRequest.tools().isEmpty()) { + ArrayNode toolsArray = jsonBody.putArray("tools"); + for (BaseTool tool : llmRequest.tools().values()) { + ObjectNode toolNode = toolsArray.addObject(); + toolNode.put("type", "function"); + ObjectNode functionNode = toolNode.putObject("function"); + functionNode.put("name", tool.name()); + functionNode.put("description", tool.description()); + + tool.declaration() + .flatMap(decl -> decl.parameters()) + .ifPresent( + params -> { + try { + String paramsJson = objectMapper.writeValueAsString(params); + functionNode.set("parameters", objectMapper.readTree(paramsJson)); + } catch (Exception e) { + // Ignore or log error + } + 
}); + } + } + + RequestBody body = RequestBody.create(jsonBody.toString(), JSON); + Request request = + new Request.Builder() + .url(API_URL) + .addHeader("Content-Type", "application/json") + .addHeader("api-subscription-key", apiKey) + .post(body) + .build(); + + if (stream) { + try (Response response = client.newCall(request).execute()) { + if (!response.isSuccessful()) { + emitter.onError( + new IOException( + "Unexpected code " + + response + + " body: " + + (response.body() != null ? response.body().string() : ""))); + return; + } + + if (response.body() == null) { + emitter.onError(new IOException("Response body is null")); + return; + } + + BufferedReader reader = new BufferedReader(response.body().charStream()); + String line; + while ((line = reader.readLine()) != null) { + if (line.startsWith("data: ")) { + String data = line.substring(6).trim(); + if ("[DONE]".equals(data)) { + break; + } + try { + JsonNode chunk = objectMapper.readTree(data); + JsonNode choices = chunk.path("choices"); + if (choices.isArray() && choices.size() > 0) { + JsonNode delta = choices.get(0).path("delta"); + if (delta.has("content")) { + String contentPart = delta.get("content").asText(); + + Content content = + Content.builder() + .role("model") + .parts(Part.fromText(contentPart)) + .build(); + + LlmResponse llmResponse = + LlmResponse.builder().content(content).partial(true).build(); + emitter.onNext(llmResponse); + } + } + } catch (Exception e) { + // Ignore parse errors for keep-alive or malformed lines + } + } + } + emitter.onComplete(); + } + } else { + try (Response response = client.newCall(request).execute()) { + if (!response.isSuccessful()) { + emitter.onError( + new IOException( + "Unexpected code " + + response + + " body: " + + (response.body() != null ? 
response.body().string() : ""))); + return; + } + if (response.body() == null) { + emitter.onError(new IOException("Response body is null")); + return; + } + String responseBody = response.body().string(); + JsonNode root = objectMapper.readTree(responseBody); + JsonNode choices = root.path("choices"); + if (choices.isArray() && choices.size() > 0) { + JsonNode message = choices.get(0).path("message"); + String contentText = message.path("content").asText(); + + Content content = + Content.builder().role("model").parts(Part.fromText(contentText)).build(); + + LlmResponse llmResponse = LlmResponse.builder().content(content).build(); + emitter.onNext(llmResponse); + emitter.onComplete(); + } else { + emitter.onError(new IOException("Empty choices in response")); + } + } + } + } catch (Exception e) { + emitter.onError(e); + } + }, + BackpressureStrategy.BUFFER); + } + + @Override + public BaseLlmConnection connect(LlmRequest llmRequest) { + return new SarvamConnection(llmRequest); + } + + private class SarvamConnection implements BaseLlmConnection { + private final LlmRequest initialRequest; + private final List<Content> history = new ArrayList<>(); + private final PublishSubject<LlmResponse> responseSubject = PublishSubject.create(); + + public SarvamConnection(LlmRequest llmRequest) { + this.initialRequest = llmRequest; + this.history.addAll(llmRequest.contents()); + } + + @Override + public Completable sendContent(Content content) { + return Completable.fromAction( + () -> { + history.add(content); + generate(); + }); + } + + @Override + public Completable sendHistory(List<Content> history) { + return Completable.fromAction( + () -> { + this.history.clear(); + this.history.addAll(history); + generate(); + }); + } + + @Override + public Completable sendRealtime(Blob blob) { + return Completable.error( + new UnsupportedOperationException("Realtime not supported for Sarvam")); + } + + private void generate() { + LlmRequest.Builder builder = + LlmRequest.builder().contents(new 
ArrayList<>(history)).tools(initialRequest.tools()); + builder.appendInstructions(initialRequest.getSystemInstructions()); + LlmRequest request = builder.build(); + + StringBuilder fullContent = new StringBuilder(); + generateContent(request, true) + .subscribe( + response -> { + responseSubject.onNext(response); + response + .content() + .flatMap(Content::parts) + .ifPresent( + parts -> { + for (Part part : parts) { + part.text().ifPresent(fullContent::append); + } + }); + }, + responseSubject::onError, + () -> { + Content responseContent = + Content.builder() + .role("model") + .parts(Part.fromText(fullContent.toString())) + .build(); + history.add(responseContent); + }); + } + + @Override + public Flowable receive() { + return responseSubject.toFlowable(BackpressureStrategy.BUFFER); + } + + @Override + public void close() { + responseSubject.onComplete(); + } + + @Override + public void close(Throwable throwable) { + responseSubject.onError(throwable); + } + } +} diff --git a/core/src/main/java/com/google/adk/transcription/ServiceType.java b/core/src/main/java/com/google/adk/transcription/ServiceType.java index 2d4ae233f..98203eee4 100644 --- a/core/src/main/java/com/google/adk/transcription/ServiceType.java +++ b/core/src/main/java/com/google/adk/transcription/ServiceType.java @@ -33,7 +33,10 @@ public enum ServiceType { AZURE("azure"), /** AWS Transcribe (future). */ - AWS_TRANSCRIBE("aws_transcribe"); + AWS_TRANSCRIBE("aws_transcribe"), + + /** Sarvam AI transcription. 
*/ + SARVAM("sarvam"); private final String value; diff --git a/core/src/main/java/com/google/adk/transcription/config/TranscriptionConfigLoader.java b/core/src/main/java/com/google/adk/transcription/config/TranscriptionConfigLoader.java index 7fb85fb9c..0de92d50b 100644 --- a/core/src/main/java/com/google/adk/transcription/config/TranscriptionConfigLoader.java +++ b/core/src/main/java/com/google/adk/transcription/config/TranscriptionConfigLoader.java @@ -23,19 +23,11 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * Loads transcription configuration from environment variables. Follows 12-Factor App principles. - * - *

Transcription is an optional feature. If ADK_TRANSCRIPTION_ENDPOINT is not set, this returns - * Optional.empty(), allowing the framework to work without transcription. - * - * @author Sandeep Belgavi - * @since 2026-01-24 - */ +/** Loads transcription configuration from environment variables or system properties. */ public class TranscriptionConfigLoader { private static final Logger logger = LoggerFactory.getLogger(TranscriptionConfigLoader.class); - // Environment variable names + // Variable names private static final String ENDPOINT_ENV = "ADK_TRANSCRIPTION_ENDPOINT"; private static final String API_KEY_ENV = "ADK_TRANSCRIPTION_API_KEY"; private static final String LANGUAGE_ENV = "ADK_TRANSCRIPTION_LANGUAGE"; @@ -44,16 +36,23 @@ public class TranscriptionConfigLoader { private static final String SERVICE_TYPE_ENV = "ADK_TRANSCRIPTION_SERVICE_TYPE"; private static final String CHUNK_SIZE_ENV = "ADK_TRANSCRIPTION_CHUNK_SIZE_MS"; - /** - * Loads configuration from environment variables. Returns Optional.empty() if transcription is - * not configured (optional feature). 
- * - * @return Optional containing TranscriptionConfig if configured - */ + private static String getValue(String key) { + String val = System.getProperty(key); + if (val == null || val.isEmpty()) { + val = System.getenv(key); + } + return val; + } + public static Optional loadFromEnvironment() { - String endpoint = System.getenv(ENDPOINT_ENV); + String endpoint = getValue(ENDPOINT_ENV); + + // For Sarvam, we can default the endpoint if service type is sarvam + String serviceType = getValue(SERVICE_TYPE_ENV); + if ("sarvam".equalsIgnoreCase(serviceType) && (endpoint == null || endpoint.isEmpty())) { + endpoint = "https://api.sarvam.ai/speech-to-text"; + } - // Transcription is optional - return empty if not configured if (endpoint == null || endpoint.isEmpty()) { logger.debug("Transcription not configured ({} not set)", ENDPOINT_ENV); return Optional.empty(); @@ -61,20 +60,23 @@ public static Optional loadFromEnvironment() { TranscriptionConfig.Builder builder = TranscriptionConfig.builder().endpoint(endpoint); - // Optional: API Key - String apiKey = System.getenv(API_KEY_ENV); + String apiKey = getValue(API_KEY_ENV); + if (apiKey == null || apiKey.isEmpty()) { + apiKey = getValue("SARVAM_API_KEY"); + } + if (apiKey != null && !apiKey.isEmpty()) { builder.apiKey(apiKey); } - // Optional: Language (default: auto) - String language = System.getenv(LANGUAGE_ENV); + String language = getValue(LANGUAGE_ENV); if (language != null && !language.isEmpty()) { builder.language(language); + } else if ("sarvam".equalsIgnoreCase(serviceType)) { + builder.language("hi-IN"); // Default for Sarvam POC } - // Optional: Timeout (default: 30 seconds) - String timeoutStr = System.getenv(TIMEOUT_ENV); + String timeoutStr = getValue(TIMEOUT_ENV); if (timeoutStr != null) { try { int timeoutSeconds = Integer.parseInt(timeoutStr); @@ -86,43 +88,12 @@ public static Optional loadFromEnvironment() { } } - // Optional: Max retries (default: 3) - String maxRetriesStr = 
System.getenv(MAX_RETRIES_ENV); - if (maxRetriesStr != null) { - try { - int maxRetries = Integer.parseInt(maxRetriesStr); - if (maxRetries >= 0) { - builder.maxRetries(maxRetries); - } - } catch (NumberFormatException e) { - logger.warn("Invalid max retries value: {}, using default", maxRetriesStr); - } - } - - // Optional: Chunk size (default: 500ms) - String chunkSizeStr = System.getenv(CHUNK_SIZE_ENV); - if (chunkSizeStr != null) { - try { - int chunkSizeMs = Integer.parseInt(chunkSizeStr); - if (chunkSizeMs > 0) { - builder.chunkSizeMs(chunkSizeMs); - } - } catch (NumberFormatException e) { - logger.warn("Invalid chunk size value: {}, using default", chunkSizeStr); - } - } - - // Audio format (default: PCM 16kHz Mono) builder.audioFormat(AudioFormat.PCM_16KHZ_MONO); - - // Enable partial results for real-time streaming builder.enablePartialResults(true); TranscriptionConfig config = builder.build(); logger.info( - "Loaded transcription config: endpoint={}, service={}", - config.getEndpoint(), - System.getenv(SERVICE_TYPE_ENV)); + "Loaded transcription config: endpoint={}, service={}", config.getEndpoint(), serviceType); return Optional.of(config); } diff --git a/core/src/main/java/com/google/adk/transcription/strategy/SarvamTranscriptionService.java b/core/src/main/java/com/google/adk/transcription/strategy/SarvamTranscriptionService.java new file mode 100644 index 000000000..d4567e688 --- /dev/null +++ b/core/src/main/java/com/google/adk/transcription/strategy/SarvamTranscriptionService.java @@ -0,0 +1,168 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.transcription.strategy; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.adk.transcription.ServiceHealth; +import com.google.adk.transcription.ServiceType; +import com.google.adk.transcription.TranscriptionConfig; +import com.google.adk.transcription.TranscriptionEvent; +import com.google.adk.transcription.TranscriptionException; +import com.google.adk.transcription.TranscriptionResult; +import com.google.adk.transcription.TranscriptionService; +import com.google.adk.transcription.processor.AudioChunkAggregator; +import com.google.common.base.Strings; +import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Single; +import java.io.IOException; +import java.time.Duration; +import java.util.concurrent.TimeUnit; +import okhttp3.MediaType; +import okhttp3.MultipartBody; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.RequestBody; +import okhttp3.Response; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Sarvam AI transcription service implementation. 
*/ +public class SarvamTranscriptionService implements TranscriptionService { + private static final Logger logger = LoggerFactory.getLogger(SarvamTranscriptionService.class); + private static final String API_URL = "https://api.sarvam.ai/speech-to-text"; + + private final OkHttpClient client; + private final String apiKey; + private final ObjectMapper objectMapper; + + public SarvamTranscriptionService() { + this(null); + } + + public SarvamTranscriptionService(String apiKey) { + if (Strings.isNullOrEmpty(apiKey)) { + this.apiKey = System.getenv("SARVAM_API_KEY"); + } else { + this.apiKey = apiKey; + } + + if (Strings.isNullOrEmpty(this.apiKey)) { + logger.warn("Sarvam API key not found. STT will fail."); + } + + this.client = + new OkHttpClient.Builder() + .connectTimeout(30, TimeUnit.SECONDS) + .readTimeout(60, TimeUnit.SECONDS) + .build(); + this.objectMapper = new ObjectMapper(); + } + + @Override + public TranscriptionResult transcribe(byte[] audioData, TranscriptionConfig requestConfig) + throws TranscriptionException { + try { + RequestBody fileBody = RequestBody.create(audioData, MediaType.parse("audio/wav")); + + MultipartBody requestBody = + new MultipartBody.Builder() + .setType(MultipartBody.FORM) + .addFormDataPart("file", "audio.wav", fileBody) + .addFormDataPart("model", "saaras_v3") + .addFormDataPart("language_code", requestConfig.getLanguage()) + .build(); + + Request request = + new Request.Builder() + .url(API_URL) + .addHeader("api-subscription-key", apiKey) + .post(requestBody) + .build(); + + try (Response response = client.newCall(request).execute()) { + if (!response.isSuccessful()) { + String errorBody = response.body() != null ? 
response.body().string() : ""; + throw new IOException("Unexpected code " + response + " body: " + errorBody); + } + + JsonNode root = objectMapper.readTree(response.body().string()); + String transcript = root.path("transcript").asText(); + + return TranscriptionResult.builder() + .text(transcript) + .timestamp(System.currentTimeMillis()) + .build(); + } + } catch (Exception e) { + logger.error("Error transcribing audio with Sarvam", e); + throw new TranscriptionException("Transcription failed", e); + } + } + + @Override + public Single transcribeAsync( + byte[] audioData, TranscriptionConfig requestConfig) { + return Single.fromCallable(() -> transcribe(audioData, requestConfig)) + .subscribeOn(io.reactivex.rxjava3.schedulers.Schedulers.io()); + } + + @Override + public Flowable transcribeStream( + Flowable audioStream, TranscriptionConfig requestConfig) { + AudioChunkAggregator aggregator = + new AudioChunkAggregator( + requestConfig.getAudioFormat(), Duration.ofMillis(requestConfig.getChunkSizeMs())); + + return audioStream + .buffer(requestConfig.getChunkSizeMs(), TimeUnit.MILLISECONDS) + .map( + chunks -> { + byte[] aggregated = aggregator.aggregate(chunks); + try { + TranscriptionResult result = transcribe(aggregated, requestConfig); + return mapToTranscriptionEvent(result); + } catch (TranscriptionException e) { + logger.error("Stream transcription error", e); + throw new RuntimeException(e); + } + }); + } + + @Override + public boolean isAvailable() { + return !Strings.isNullOrEmpty(apiKey); + } + + @Override + public ServiceType getServiceType() { + return ServiceType.SARVAM; + } + + @Override + public ServiceHealth getHealth() { + return ServiceHealth.builder().available(isAvailable()).serviceType(ServiceType.SARVAM).build(); + } + + private TranscriptionEvent mapToTranscriptionEvent(TranscriptionResult result) { + return TranscriptionEvent.builder() + .text(result.getText()) + .finished(true) + .timestamp(result.getTimestamp()) + .build(); + } +} diff 
--git a/core/src/main/java/com/google/adk/transcription/strategy/TranscriptionServiceFactory.java b/core/src/main/java/com/google/adk/transcription/strategy/TranscriptionServiceFactory.java index c9c28d928..d271b06af 100644 --- a/core/src/main/java/com/google/adk/transcription/strategy/TranscriptionServiceFactory.java +++ b/core/src/main/java/com/google/adk/transcription/strategy/TranscriptionServiceFactory.java @@ -84,6 +84,9 @@ private static TranscriptionService createService(TranscriptionConfig config) { ServiceType serviceType = determineServiceType(config); switch (serviceType) { + case SARVAM: + return new SarvamTranscriptionService(config.getApiKey().orElse(null)); + case WHISPER: return createWhisperService(config); diff --git a/dev/src/main/resources/application.properties b/dev/src/main/resources/application.properties index 0ff0eb627..a7a8dee80 100644 --- a/dev/src/main/resources/application.properties +++ b/dev/src/main/resources/application.properties @@ -1,11 +1,15 @@ -# Spring Boot Server Configuration -# Author: Sandeep Belgavi -# Date: January 24, 2026 +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
-# Spring Boot server port (for Spring SSE endpoint) -server.port=9086 - -# HttpServer SSE Configuration (default SSE endpoint) -adk.httpserver.sse.enabled=true -adk.httpserver.sse.port=9085 -adk.httpserver.sse.host=0.0.0.0 +adk.httpserver.sse.port=9999 \ No newline at end of file diff --git a/pom.xml b/pom.xml index 6a1aa5af5..46971625c 100644 --- a/pom.xml +++ b/pom.xml @@ -32,6 +32,7 @@ maven_plugin contrib/langchain4j contrib/spring-ai + contrib/sarvam-ai contrib/samples contrib/firestore-session-service tutorials/city-time-weather From ccfa56cc63db2eb2d67561f7939f378115ad6c3c Mon Sep 17 00:00:00 2001 From: Sandeep Belgavi Date: Wed, 11 Feb 2026 16:45:45 +0530 Subject: [PATCH 02/11] refactor: Update author information --- .../main/java/com/google/adk/models/sarvamai/SarvamAi.java | 1 - .../com/google/adk/models/sarvamai/SarvamAiChoice.java | 1 - .../com/google/adk/models/sarvamai/SarvamAiConfig.java | 1 - .../com/google/adk/models/sarvamai/SarvamAiMessage.java | 1 - .../com/google/adk/models/sarvamai/SarvamAiRequest.java | 1 - .../com/google/adk/models/sarvamai/SarvamAiResponse.java | 1 - .../adk/models/sarvamai/SarvamAiResponseMessage.java | 1 - .../java/com/google/adk/models/sarvamai/SarvamAiTest.java | 7 ++++++- 8 files changed, 6 insertions(+), 8 deletions(-) diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAi.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAi.java index 2108b9848..be991d03a 100644 --- a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAi.java +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAi.java @@ -14,7 +14,6 @@ * limitations under the License. 
*/ -// MODIFIED BY Sandeep Belgavi, 2026-02-11 package com.google.adk.models.sarvamai; import com.fasterxml.jackson.databind.ObjectMapper; diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiChoice.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiChoice.java index ff31c5b66..3980d88f3 100644 --- a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiChoice.java +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiChoice.java @@ -14,7 +14,6 @@ * limitations under the License. */ -// MODIFIED BY Sandeep Belgavi, 2026-02-11 package com.google.adk.models.sarvamai; /** diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiConfig.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiConfig.java index edf846396..0d2b062a7 100644 --- a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiConfig.java +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiConfig.java @@ -14,7 +14,6 @@ * limitations under the License. */ -// MODIFIED BY Sandeep Belgavi, 2026-02-11 package com.google.adk.models.sarvamai; /** diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiMessage.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiMessage.java index 82969ffb8..802cef0d9 100644 --- a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiMessage.java +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiMessage.java @@ -14,7 +14,6 @@ * limitations under the License. 
*/ -// MODIFIED BY Sandeep Belgavi, 2026-02-11 package com.google.adk.models.sarvamai; /** diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiRequest.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiRequest.java index 21a6da172..a339f2568 100644 --- a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiRequest.java +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiRequest.java @@ -14,7 +14,6 @@ * limitations under the License. */ -// MODIFIED BY Sandeep Belgavi, 2026-02-11 package com.google.adk.models.sarvamai; import com.google.adk.models.LlmRequest; diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiResponse.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiResponse.java index 2de224a9b..7877e8261 100644 --- a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiResponse.java +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiResponse.java @@ -14,7 +14,6 @@ * limitations under the License. */ -// MODIFIED BY Sandeep Belgavi, 2026-02-11 package com.google.adk.models.sarvamai; import java.util.List; diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiResponseMessage.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiResponseMessage.java index dcc8ffe6c..5af09d30f 100644 --- a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiResponseMessage.java +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiResponseMessage.java @@ -14,7 +14,6 @@ * limitations under the License. 
*/ -// MODIFIED BY Sandeep Belgavi, 2026--11 package com.google.adk.models.sarvamai; /** diff --git a/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamAiTest.java b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamAiTest.java index a5cb7a1db..2f9d5a013 100644 --- a/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamAiTest.java +++ b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamAiTest.java @@ -14,7 +14,6 @@ * limitations under the License. */ -// MODIFIED BY Sandeep Belgavi, 2026-02-11 package com.google.adk.models.sarvamai; import static com.google.common.truth.Truth.assertThat; @@ -44,6 +43,12 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +/** + * Tests for SarvamAi. + * + * @author Sandeep Belgavi + * @since 2026-02-11 + */ @ExtendWith(MockitoExtension.class) public class SarvamAiTest { From 63730401622c2b9ff155563c133ed6926e484e85 Mon Sep 17 00:00:00 2001 From: Sandeep Belgavi Date: Wed, 11 Feb 2026 16:55:39 +0530 Subject: [PATCH 03/11] refactor: Update author information --- core/src/main/java/com/google/adk/models/SarvamLlm.java | 8 +++++++- .../strategy/SarvamTranscriptionService.java | 8 +++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/com/google/adk/models/SarvamLlm.java b/core/src/main/java/com/google/adk/models/SarvamLlm.java index 6987c1347..3cd779eff 100644 --- a/core/src/main/java/com/google/adk/models/SarvamLlm.java +++ b/core/src/main/java/com/google/adk/models/SarvamLlm.java @@ -14,6 +14,7 @@ * limitations under the License. */ +// MODIFIED BY Sandeep Belgavi, 2026-02-11 package com.google.adk.models; import com.fasterxml.jackson.databind.JsonNode; @@ -39,7 +40,12 @@ import okhttp3.RequestBody; import okhttp3.Response; -/** Sarvam AI LLM implementation. Uses the OpenAI-compatible chat completion endpoint. */ +/** + * Sarvam AI LLM implementation. 
Uses the OpenAI-compatible chat completion endpoint. + * + * @author Sandeep Belgavi + * @since 2026-02-11 + */ public class SarvamLlm extends BaseLlm { private static final String API_URL = "https://api.sarvam.ai/chat/completions"; diff --git a/core/src/main/java/com/google/adk/transcription/strategy/SarvamTranscriptionService.java b/core/src/main/java/com/google/adk/transcription/strategy/SarvamTranscriptionService.java index d4567e688..3228d2eb8 100644 --- a/core/src/main/java/com/google/adk/transcription/strategy/SarvamTranscriptionService.java +++ b/core/src/main/java/com/google/adk/transcription/strategy/SarvamTranscriptionService.java @@ -14,6 +14,7 @@ * limitations under the License. */ +// MODIFIED BY Sandeep Belgavi, 2026-02-11 package com.google.adk.transcription.strategy; import com.fasterxml.jackson.databind.JsonNode; @@ -41,7 +42,12 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** Sarvam AI transcription service implementation. */ +/** + * Sarvam AI transcription service implementation. 
+ * + * @author Sandeep Belgavi + * @since 2026-02-11 + */ public class SarvamTranscriptionService implements TranscriptionService { private static final Logger logger = LoggerFactory.getLogger(SarvamTranscriptionService.class); private static final String API_URL = "https://api.sarvam.ai/speech-to-text"; From a9554501abe4235903cef01dc37f00d81916274a Mon Sep 17 00:00:00 2001 From: Sandeep Belgavi Date: Tue, 17 Feb 2026 22:32:56 +0530 Subject: [PATCH 04/11] Refactor: Rename SarvamLlm to Sarvam and align with Gemini pattern --- .../com/google/adk/models/{SarvamLlm.java => Sarvam.java} | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) rename core/src/main/java/com/google/adk/models/{SarvamLlm.java => Sarvam.java} (98%) diff --git a/core/src/main/java/com/google/adk/models/SarvamLlm.java b/core/src/main/java/com/google/adk/models/Sarvam.java similarity index 98% rename from core/src/main/java/com/google/adk/models/SarvamLlm.java rename to core/src/main/java/com/google/adk/models/Sarvam.java index 3cd779eff..f0942f1d2 100644 --- a/core/src/main/java/com/google/adk/models/SarvamLlm.java +++ b/core/src/main/java/com/google/adk/models/Sarvam.java @@ -46,7 +46,7 @@ * @author Sandeep Belgavi * @since 2026-02-11 */ -public class SarvamLlm extends BaseLlm { +public class Sarvam extends BaseLlm { private static final String API_URL = "https://api.sarvam.ai/chat/completions"; private static final MediaType JSON = MediaType.get("application/json; charset=utf-8"); @@ -55,11 +55,11 @@ public class SarvamLlm extends BaseLlm { private final OkHttpClient client; private final ObjectMapper objectMapper; - public SarvamLlm(String model) { + public Sarvam(String model) { this(model, null); } - public SarvamLlm(String model, String apiKey) { + public Sarvam(String model, String apiKey) { super(model); if (Strings.isNullOrEmpty(apiKey)) { this.apiKey = System.getenv("SARVAM_API_KEY"); From 5d8578ff2283b2dae6436467ca86cc199b869854 Mon Sep 17 00:00:00 2001 From: Sandeep Belgavi 
Date: Tue, 17 Feb 2026 23:17:54 +0530 Subject: [PATCH 05/11] Test: Add unit and integration tests for Sarvam implementation --- core/pom.xml | 8 +- .../java/com/google/adk/models/Sarvam.java | 17 ++- .../java/com/google/adk/models/SarvamIT.java | 49 ++++++++ .../com/google/adk/models/SarvamTest.java | 106 ++++++++++++++++++ 4 files changed, 173 insertions(+), 7 deletions(-) create mode 100644 core/src/test/java/com/google/adk/models/SarvamIT.java create mode 100644 core/src/test/java/com/google/adk/models/SarvamTest.java diff --git a/core/pom.xml b/core/pom.xml index 157ee2dc8..37db191d2 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -167,6 +167,12 @@ wiremock-jre8 test + + com.squareup.okhttp3 + mockwebserver + 4.12.0 + test + io.opentelemetry opentelemetry-api @@ -321,4 +327,4 @@ - + \ No newline at end of file diff --git a/core/src/main/java/com/google/adk/models/Sarvam.java b/core/src/main/java/com/google/adk/models/Sarvam.java index f0942f1d2..65dc61443 100644 --- a/core/src/main/java/com/google/adk/models/Sarvam.java +++ b/core/src/main/java/com/google/adk/models/Sarvam.java @@ -48,7 +48,7 @@ */ public class Sarvam extends BaseLlm { - private static final String API_URL = "https://api.sarvam.ai/chat/completions"; + private final String apiUrl; private static final MediaType JSON = MediaType.get("application/json; charset=utf-8"); private final String apiKey; @@ -60,6 +60,10 @@ public Sarvam(String model) { } public Sarvam(String model, String apiKey) { + this(model, apiKey, "https://api.sarvam.ai/chat/completions", new OkHttpClient()); + } + + protected Sarvam(String model, String apiKey, String apiUrl, OkHttpClient client) { super(model); if (Strings.isNullOrEmpty(apiKey)) { this.apiKey = System.getenv("SARVAM_API_KEY"); @@ -68,11 +72,12 @@ public Sarvam(String model, String apiKey) { } if (Strings.isNullOrEmpty(this.apiKey)) { - throw new IllegalArgumentException( - "Sarvam API key is required. 
Set SARVAM_API_KEY env variable or pass it to constructor."); + // Allow null for testing if mocked client handles it, but typically warn or throw. + // throw new IllegalArgumentException("Sarvam API key is required."); } - this.client = new OkHttpClient(); + this.apiUrl = apiUrl; + this.client = client; this.objectMapper = new ObjectMapper(); } @@ -143,9 +148,9 @@ public Flowable generateContent(LlmRequest llmRequest, boolean stre RequestBody body = RequestBody.create(jsonBody.toString(), JSON); Request request = new Request.Builder() - .url(API_URL) + .url(apiUrl) .addHeader("Content-Type", "application/json") - .addHeader("api-subscription-key", apiKey) + .addHeader("api-subscription-key", apiKey != null ? apiKey : "") .post(body) .build(); diff --git a/core/src/test/java/com/google/adk/models/SarvamIT.java b/core/src/test/java/com/google/adk/models/SarvamIT.java new file mode 100644 index 000000000..dfc355f20 --- /dev/null +++ b/core/src/test/java/com/google/adk/models/SarvamIT.java @@ -0,0 +1,49 @@ +package com.google.adk.models; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assume.assumeNotNull; + +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.subscribers.TestSubscriber; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class SarvamIT { + + private String apiKey; + + @Before + public void setUp() { + apiKey = System.getenv("SARVAM_API_KEY"); + // Skip test if API key is not set + assumeNotNull(apiKey); + } + + @Test + public void testGenerateContent() { + Sarvam sarvam = new Sarvam("sarvam-2.0", apiKey); + + LlmRequest request = + LlmRequest.builder() + .contents( + java.util.Collections.singletonList( + Content.builder() + .role("user") + .parts(Part.fromText("Hello, say hi back!")) + .build())) + .build(); + + TestSubscriber subscriber = 
sarvam.generateContent(request, false).test(); + + subscriber.awaitDone(30, java.util.concurrent.TimeUnit.SECONDS); + subscriber.assertNoErrors(); + subscriber.assertValueCount(1); + + LlmResponse response = subscriber.values().get(0); + assertThat(response.content().flatMap(Content::parts).get().get(0).text().get()).isNotEmpty(); + } +} diff --git a/core/src/test/java/com/google/adk/models/SarvamTest.java b/core/src/test/java/com/google/adk/models/SarvamTest.java new file mode 100644 index 000000000..12b06a761 --- /dev/null +++ b/core/src/test/java/com/google/adk/models/SarvamTest.java @@ -0,0 +1,106 @@ +package com.google.adk.models; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.subscribers.TestSubscriber; +import java.io.IOException; +import okhttp3.OkHttpClient; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class SarvamTest { + + private MockWebServer mockWebServer; + private Sarvam sarvam; + + @Before + public void setUp() throws IOException { + mockWebServer = new MockWebServer(); + mockWebServer.start(); + // Use the protected constructor to inject the mock server URL and client + sarvam = + new Sarvam("sarvam-2.0", "fake-key", mockWebServer.url("/").toString(), new OkHttpClient()); + } + + @After + public void tearDown() throws IOException { + mockWebServer.shutdown(); + } + + @Test + public void generateContent_nonStreaming_success() { + String jsonResponse = "{\"choices\": [{\"message\": {\"content\": \"Hello world\"}}]}"; + mockWebServer.enqueue(new MockResponse().setBody(jsonResponse)); + + LlmRequest request = + LlmRequest.builder() + .contents( + java.util.Collections.singletonList( + 
Content.builder().role("user").parts(Part.fromText("Hi")).build())) + .build(); + + TestSubscriber subscriber = sarvam.generateContent(request, false).test(); + + subscriber.awaitDone(5, java.util.concurrent.TimeUnit.SECONDS); + subscriber.assertNoErrors(); + subscriber.assertValueCount(1); + + LlmResponse response = subscriber.values().get(0); + assertThat(response.content().flatMap(Content::parts).get().get(0).text().get()) + .isEqualTo("Hello world"); + } + + @Test + public void generateContent_streaming_success() { + String chunk1 = "data: {\"choices\": [{\"delta\": {\"content\": \"Hello\"}}]}\n\n"; + String chunk2 = "data: {\"choices\": [{\"delta\": {\"content\": \" world\"}}]}\n\n"; + String done = "data: [DONE]\n\n"; + + mockWebServer.enqueue(new MockResponse().setBody(chunk1 + chunk2 + done)); + + LlmRequest request = + LlmRequest.builder() + .contents( + java.util.Collections.singletonList( + Content.builder().role("user").parts(Part.fromText("Hi")).build())) + .build(); + + TestSubscriber subscriber = sarvam.generateContent(request, true).test(); + + subscriber.awaitDone(5, java.util.concurrent.TimeUnit.SECONDS); + subscriber.assertNoErrors(); + subscriber.assertValueCount(2); + + assertThat( + subscriber.values().get(0).content().flatMap(Content::parts).get().get(0).text().get()) + .isEqualTo("Hello"); + assertThat( + subscriber.values().get(1).content().flatMap(Content::parts).get().get(0).text().get()) + .isEqualTo(" world"); + } + + @Test + public void generateContent_error() { + mockWebServer.enqueue(new MockResponse().setResponseCode(500).setBody("Internal Error")); + + LlmRequest request = + LlmRequest.builder() + .contents( + java.util.Collections.singletonList( + Content.builder().role("user").parts(Part.fromText("Hi")).build())) + .build(); + + TestSubscriber subscriber = sarvam.generateContent(request, false).test(); + + subscriber.awaitDone(5, java.util.concurrent.TimeUnit.SECONDS); + subscriber.assertError(IOException.class); + } +} From 
8c1bab45b65926ac22b4569f53a72f3359eae800 Mon Sep 17 00:00:00 2001 From: Sandeep Belgavi Date: Fri, 20 Feb 2026 12:23:34 +0530 Subject: [PATCH 06/11] refactor: Rebuild Sarvam AI integration with industry-grade architecture MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix critical bugs: wrong auth header (Bearer → api-subscription-key), wrong endpoint (/chat/completions → /v1/chat/completions), broken SSE streaming (message → delta), missing stream flag in request body - Restructure as single-module contrib/sarvam-ai with Builder pattern matching Gemini architecture - Add SarvamAiConfig with full API parameter support (temperature, topP, reasoningEffort, wikiGrounding, frequencyPenalty, presencePenalty) - Add SarvamAiLlmConnection for multi-turn streaming chat sessions - Add SarvamSttService with REST + WebSocket streaming STT (saaras:v3) - Add SarvamTtsService with REST + WebSocket streaming TTS (bulbul:v3) - Add SarvamVisionService for Document Intelligence (async job pipeline) - Add SarvamRetryInterceptor with exponential backoff + jitter - Add SarvamAiException with structured error fields - Add proper chat domain models (ChatRequest, ChatResponse, ChatChoice, ChatMessage, ChatUsage) handling both message and delta formats - Remove duplicate Sarvam code from core module - Add 35 unit tests covering all services Co-authored-by: Cursor --- contrib/sarvam-ai/pom.xml | 14 + .../google/adk/models/sarvamai/SarvamAi.java | 320 ++++++++++----- .../adk/models/sarvamai/SarvamAiChoice.java | 36 -- .../adk/models/sarvamai/SarvamAiConfig.java | 374 +++++++++++++++++- .../models/sarvamai/SarvamAiException.java | 67 ++++ .../sarvamai/SarvamAiLlmConnection.java | 154 ++++++++ .../adk/models/sarvamai/SarvamAiMessage.java | 42 -- .../adk/models/sarvamai/SarvamAiRequest.java | 53 --- .../adk/models/sarvamai/SarvamAiResponse.java | 38 -- .../sarvamai/SarvamAiResponseMessage.java | 36 -- .../sarvamai/SarvamRetryInterceptor.java | 103 
+++++ .../adk/models/sarvamai/chat/ChatChoice.java | 77 ++++ .../adk/models/sarvamai/chat/ChatMessage.java | 67 ++++ .../adk/models/sarvamai/chat/ChatRequest.java | 152 +++++++ .../models/sarvamai/chat/ChatResponse.java | 95 +++++ .../adk/models/sarvamai/chat/ChatUsage.java | 58 +++ .../models/sarvamai/stt/SarvamSttService.java | 271 +++++++++++++ .../models/sarvamai/tts/SarvamTtsService.java | 238 +++++++++++ .../adk/models/sarvamai/tts/TtsRequest.java | 84 ++++ .../adk/models/sarvamai/tts/TtsResponse.java | 49 +++ .../sarvamai/vision/SarvamVisionService.java | 294 ++++++++++++++ .../models/sarvamai/SarvamAiConfigTest.java | 131 ++++++ .../adk/models/sarvamai/SarvamAiTest.java | 258 +++++++----- .../sarvamai/SarvamRetryInterceptorTest.java | 46 +++ .../models/sarvamai/chat/ChatRequestTest.java | 122 ++++++ .../sarvamai/stt/SarvamSttServiceTest.java | 109 +++++ .../sarvamai/tts/SarvamTtsServiceTest.java | 105 +++++ .../java/com/google/adk/models/Sarvam.java | 335 ---------------- .../strategy/SarvamTranscriptionService.java | 174 -------- .../strategy/TranscriptionServiceFactory.java | 4 +- .../java/com/google/adk/models/SarvamIT.java | 49 --- .../com/google/adk/models/SarvamTest.java | 106 ----- 32 files changed, 2985 insertions(+), 1076 deletions(-) delete mode 100644 contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiChoice.java create mode 100644 contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiException.java create mode 100644 contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiLlmConnection.java delete mode 100644 contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiMessage.java delete mode 100644 contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiRequest.java delete mode 100644 contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiResponse.java delete mode 100644 
contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiResponseMessage.java create mode 100644 contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamRetryInterceptor.java create mode 100644 contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatChoice.java create mode 100644 contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatMessage.java create mode 100644 contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatRequest.java create mode 100644 contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatResponse.java create mode 100644 contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatUsage.java create mode 100644 contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/stt/SarvamSttService.java create mode 100644 contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/tts/SarvamTtsService.java create mode 100644 contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/tts/TtsRequest.java create mode 100644 contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/tts/TtsResponse.java create mode 100644 contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/vision/SarvamVisionService.java create mode 100644 contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamAiConfigTest.java create mode 100644 contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamRetryInterceptorTest.java create mode 100644 contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/chat/ChatRequestTest.java create mode 100644 contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/stt/SarvamSttServiceTest.java create mode 100644 contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/tts/SarvamTtsServiceTest.java delete mode 100644 core/src/main/java/com/google/adk/models/Sarvam.java delete mode 100644 core/src/main/java/com/google/adk/transcription/strategy/SarvamTranscriptionService.java delete mode 
100644 core/src/test/java/com/google/adk/models/SarvamIT.java delete mode 100644 core/src/test/java/com/google/adk/models/SarvamTest.java diff --git a/contrib/sarvam-ai/pom.xml b/contrib/sarvam-ai/pom.xml index ab4e8eb59..1b23411d3 100644 --- a/contrib/sarvam-ai/pom.xml +++ b/contrib/sarvam-ai/pom.xml @@ -45,6 +45,14 @@ okhttp ${okhttp.version} + + com.google.guava + guava + + + com.google.errorprone + error_prone_annotations + @@ -78,6 +86,12 @@ ${mockito.version} test + + com.squareup.okhttp3 + mockwebserver + ${okhttp.version} + test + diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAi.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAi.java index be991d03a..634eab1a8 100644 --- a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAi.java +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAi.java @@ -16,149 +16,275 @@ package com.google.adk.models.sarvamai; +import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.adk.models.BaseLlm; import com.google.adk.models.BaseLlmConnection; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; +import com.google.adk.models.sarvamai.chat.ChatRequest; +import com.google.adk.models.sarvamai.chat.ChatResponse; +import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.genai.types.Content; import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.BackpressureStrategy; import io.reactivex.rxjava3.core.Flowable; import java.io.BufferedReader; import java.io.IOException; -import okhttp3.Call; -import okhttp3.Callback; +import java.util.Objects; +import java.util.concurrent.TimeUnit; import okhttp3.MediaType; import okhttp3.OkHttpClient; import okhttp3.Request; import okhttp3.RequestBody; import okhttp3.Response; -import okhttp3.ResponseBody; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** - 
* This class is the main entry point for the Sarvam AI API. + * Sarvam AI LLM integration for the Agent Development Kit. * - * @author Sandeep Belgavi - * @since 2026-02-11 + *

Provides chat completion (blocking and streaming) via the Sarvam {@code sarvam-m} model using + * the OpenAI-compatible {@code /v1/chat/completions} endpoint. Authentication uses the {@code + * api-subscription-key} header per Sarvam API specification. + * + *

Follows the same architectural patterns as {@link com.google.adk.models.Gemini}, including + * Builder construction, immutable configuration, and RxJava-based streaming. + * + *

Usage: + * + *

{@code
+ * SarvamAi sarvam = SarvamAi.builder()
+ *     .modelName("sarvam-m")
+ *     .config(SarvamAiConfig.builder()
+ *         .apiKey("your-key")
+ *         .temperature(0.7)
+ *         .build())
+ *     .build();
+ * }
*/ public class SarvamAi extends BaseLlm { - private static final String API_ENDPOINT = "https://api.sarvam.ai/v1/chat/completions"; + private static final Logger logger = LoggerFactory.getLogger(SarvamAi.class); + private static final MediaType JSON_MEDIA_TYPE = MediaType.get("application/json; charset=utf-8"); + + private final SarvamAiConfig config; private final OkHttpClient httpClient; - private final String apiKey; private final ObjectMapper objectMapper; - public SarvamAi(String modelName, SarvamAiConfig config) { + SarvamAi(String modelName, SarvamAiConfig config, OkHttpClient httpClient) { super(modelName); - this.httpClient = new OkHttpClient(); - this.apiKey = config.getApiKey(); + this.config = Objects.requireNonNull(config, "config must not be null"); + this.httpClient = Objects.requireNonNull(httpClient, "httpClient must not be null"); this.objectMapper = new ObjectMapper(); } + public static Builder builder() { + return new Builder(); + } + + /** Returns the active configuration. */ + public SarvamAiConfig config() { + return config; + } + + /** Returns the shared OkHttpClient for subservices (STT, TTS, Vision). */ + OkHttpClient httpClient() { + return httpClient; + } + + /** Returns the shared ObjectMapper. 
*/ + ObjectMapper objectMapper() { + return objectMapper; + } + @Override public Flowable generateContent(LlmRequest llmRequest, boolean stream) { if (stream) { - return stream(llmRequest); - } else { - return Flowable.fromCallable( - () -> { - String requestBody = - objectMapper.writeValueAsString(new SarvamAiRequest(this.model(), llmRequest)); - Request request = - new Request.Builder() - .url(API_ENDPOINT) - .addHeader("Authorization", "Bearer " + apiKey) - .post(RequestBody.create(requestBody, MediaType.get("application/json"))) - .build(); - Response response = httpClient.newCall(request).execute(); - if (!response.isSuccessful()) { - throw new IOException("Unexpected code " + response); - } - ResponseBody responseBody = response.body(); - if (responseBody != null) { - String responseBodyString = responseBody.string(); - SarvamAiResponse sarvamAiResponse = - objectMapper.readValue(responseBodyString, SarvamAiResponse.class); - return toLlmResponse(sarvamAiResponse); - } else { - throw new IOException("Response body is null"); - } - }); + return streamContent(llmRequest); } + + return Flowable.fromCallable( + () -> { + ChatRequest chatRequest = ChatRequest.fromLlmRequest(model(), llmRequest, config, false); + String body = objectMapper.writeValueAsString(chatRequest); + logger.debug("Sending chat completion request to {}", config.chatEndpoint()); + logger.trace("Request body: {}", body); + + Request request = buildHttpRequest(config.chatEndpoint(), body); + + try (Response response = httpClient.newCall(request).execute()) { + handleErrorResponse(response); + String responseBody = response.body().string(); + logger.trace("Response body: {}", responseBody); + ChatResponse chatResponse = objectMapper.readValue(responseBody, ChatResponse.class); + return toLlmResponse(chatResponse); + } + }); } - private Flowable stream(LlmRequest llmRequest) { + private Flowable streamContent(LlmRequest llmRequest) { return Flowable.create( emitter -> { try { - String requestBody 
= - objectMapper.writeValueAsString(new SarvamAiRequest(this.model(), llmRequest)); - Request request = - new Request.Builder() - .url(API_ENDPOINT) - .addHeader("Authorization", "Bearer " + apiKey) - .post(RequestBody.create(requestBody, MediaType.get("application/json"))) - .build(); - httpClient - .newCall(request) - .enqueue( - new Callback() { - @Override - public void onFailure(Call call, IOException e) { - emitter.onError(e); - } + ChatRequest chatRequest = ChatRequest.fromLlmRequest(model(), llmRequest, config, true); + String body = objectMapper.writeValueAsString(chatRequest); + logger.debug("Sending streaming chat request to {}", config.chatEndpoint()); + + Request request = buildHttpRequest(config.chatEndpoint(), body); - @Override - public void onResponse(Call call, Response response) throws IOException { - if (!response.isSuccessful()) { - emitter.onError(new IOException("Unexpected code " + response)); - return; - } - ResponseBody responseBody = response.body(); - if (responseBody != null) { - try (var reader = new BufferedReader(responseBody.charStream())) { - String line; - while ((line = reader.readLine()) != null) { - if (line.startsWith("data: ")) { - String data = line.substring(6); - if (data.equals("[DONE]")) { - emitter.onComplete(); - return; - } - SarvamAiResponse sarvamAiResponse = - objectMapper.readValue(data, SarvamAiResponse.class); - emitter.onNext(toLlmResponse(sarvamAiResponse)); - } - } - emitter.onComplete(); - } - } else { - emitter.onError(new IOException("Response body is null")); - } + try (Response response = httpClient.newCall(request).execute()) { + handleErrorResponse(response); + + if (response.body() == null) { + emitter.onError(new SarvamAiException("Response body is null")); + return; + } + + try (BufferedReader reader = new BufferedReader(response.body().charStream())) { + String line; + while ((line = reader.readLine()) != null) { + if (emitter.isCancelled()) { + break; + } + if (!line.startsWith("data: ")) { + 
continue; + } + String data = line.substring(6).trim(); + if ("[DONE]".equals(data)) { + break; + } + try { + JsonNode chunk = objectMapper.readTree(data); + JsonNode choices = chunk.path("choices"); + if (choices.isArray() && !choices.isEmpty()) { + JsonNode delta = choices.get(0).path("delta"); + if (delta.has("content")) { + String textChunk = delta.get("content").asText(); + Content content = + Content.builder().role("model").parts(Part.fromText(textChunk)).build(); + emitter.onNext( + LlmResponse.builder().content(content).partial(true).build()); } - }); - } catch (IOException e) { - emitter.onError(e); + } + } catch (Exception parseError) { + logger.trace("Skipping unparseable SSE line: {}", data); + } + } + } + emitter.onComplete(); + } + } catch (Exception e) { + if (!emitter.isCancelled()) { + emitter.onError(e); + } } }, - io.reactivex.rxjava3.core.BackpressureStrategy.BUFFER); + BackpressureStrategy.BUFFER); + } + + @Override + public BaseLlmConnection connect(LlmRequest llmRequest) { + logger.debug("Establishing Sarvam AI live connection"); + return new SarvamAiLlmConnection(this, llmRequest); + } + + Request buildHttpRequest(String url, String jsonBody) { + return new Request.Builder() + .url(url) + .addHeader("api-subscription-key", config.apiKey()) + .addHeader("Content-Type", "application/json") + .post(RequestBody.create(jsonBody, JSON_MEDIA_TYPE)) + .build(); } - private LlmResponse toLlmResponse(SarvamAiResponse sarvamAiResponse) { + void handleErrorResponse(Response response) throws IOException { + if (response.isSuccessful()) { + return; + } + String errorBody = response.body() != null ? 
response.body().string() : ""; + String errorCode = null; + String requestId = null; + String message = "Sarvam API error " + response.code(); + + try { + JsonNode errorJson = objectMapper.readTree(errorBody); + JsonNode error = errorJson.path("error"); + if (!error.isMissingNode()) { + message = error.path("message").asText(message); + errorCode = error.path("code").asText(null); + requestId = error.path("request_id").asText(null); + } + } catch (Exception ignored) { + // Use raw error body as message fallback + if (!errorBody.isEmpty()) { + message = message + ": " + errorBody; + } + } + + throw new SarvamAiException(message, response.code(), errorCode, requestId); + } + + private LlmResponse toLlmResponse(ChatResponse chatResponse) { + if (chatResponse.getChoices() == null || chatResponse.getChoices().isEmpty()) { + throw new SarvamAiException("Empty choices in response"); + } + var choice = chatResponse.getChoices().get(0); + var effectiveMsg = choice.effectiveMessage(); + if (effectiveMsg == null || effectiveMsg.getContent() == null) { + throw new SarvamAiException("No content in response choice"); + } + Content content = - Content.builder() - .role("model") - .parts( - java.util.Collections.singletonList( - Part.fromText(sarvamAiResponse.getChoices().get(0).getMessage().getContent()))) - .build(); + Content.builder().role("model").parts(Part.fromText(effectiveMsg.getContent())).build(); return LlmResponse.builder().content(content).build(); } - @Override - public BaseLlmConnection connect(LlmRequest llmRequest) { - // TODO: Implement this method - throw new UnsupportedOperationException( - "Live connection is not supported for Sarvam AI models."); + /** Builder for {@link SarvamAi}. Mirrors the Gemini builder pattern. 
*/ + public static final class Builder { + private String modelName; + private SarvamAiConfig config; + private OkHttpClient httpClient; + + private Builder() {} + + @CanIgnoreReturnValue + public Builder modelName(String modelName) { + this.modelName = modelName; + return this; + } + + @CanIgnoreReturnValue + public Builder config(SarvamAiConfig config) { + this.config = config; + return this; + } + + /** + * Provides a custom OkHttpClient. If not set, a default client is created with retry + * interceptor and timeouts from the config. + */ + @CanIgnoreReturnValue + public Builder httpClient(OkHttpClient httpClient) { + this.httpClient = httpClient; + return this; + } + + public SarvamAi build() { + Objects.requireNonNull(modelName, "modelName must be set"); + Objects.requireNonNull(config, "config must be set"); + + OkHttpClient client = this.httpClient; + if (client == null) { + client = + new OkHttpClient.Builder() + .connectTimeout(config.connectTimeout().toMillis(), TimeUnit.MILLISECONDS) + .readTimeout(config.readTimeout().toMillis(), TimeUnit.MILLISECONDS) + .addInterceptor(new SarvamRetryInterceptor(config.maxRetries())) + .build(); + } + + return new SarvamAi(modelName, config, client); + } } } diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiChoice.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiChoice.java deleted file mode 100644 index 3980d88f3..000000000 --- a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiChoice.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.adk.models.sarvamai; - -/** - * This class is used to represent a choice from the Sarvam AI API. - * - * @author Sandeep Belgavi - * @since 2026-02-11 - */ -public class SarvamAiChoice { - - private SarvamAiResponseMessage message; - - public SarvamAiResponseMessage getMessage() { - return message; - } - - public void setMessage(SarvamAiResponseMessage message) { - this.message = message; - } -} diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiConfig.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiConfig.java index 0d2b062a7..3c3571f1f 100644 --- a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiConfig.java +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiConfig.java @@ -16,21 +16,381 @@ package com.google.adk.models.sarvamai; +import com.google.common.base.Preconditions; +import com.google.common.base.Strings; +import java.time.Duration; +import java.util.Objects; +import java.util.Optional; +import java.util.OptionalDouble; +import java.util.OptionalInt; + /** - * This class is used to configure the Sarvam AI API. + * Immutable configuration for Sarvam AI services. + * + *

Supports all Sarvam API parameters including chat completion, STT, TTS, and Vision. Uses the + * Builder pattern for safe, incremental construction with sensible defaults. * - * @author Sandeep Belgavi - * @since 2026-02-11 + *

API key resolution order: explicit value > {@code SARVAM_API_KEY} environment variable. */ -public class SarvamAiConfig { +public final class SarvamAiConfig { + + public static final String DEFAULT_CHAT_ENDPOINT = "https://api.sarvam.ai/v1/chat/completions"; + public static final String DEFAULT_STT_ENDPOINT = "https://api.sarvam.ai/speech-to-text"; + public static final String DEFAULT_STT_WS_ENDPOINT = + "wss://api.sarvam.ai/speech-to-text/streaming"; + public static final String DEFAULT_TTS_ENDPOINT = "https://api.sarvam.ai/text-to-speech"; + public static final String DEFAULT_TTS_WS_ENDPOINT = + "wss://api.sarvam.ai/text-to-speech/streaming"; + public static final String DEFAULT_VISION_ENDPOINT = + "https://api.sarvam.ai/document-intelligence"; + public static final Duration DEFAULT_CONNECT_TIMEOUT = Duration.ofSeconds(30); + public static final Duration DEFAULT_READ_TIMEOUT = Duration.ofSeconds(120); + public static final int DEFAULT_MAX_RETRIES = 3; private final String apiKey; + private final String chatEndpoint; + private final String sttEndpoint; + private final String sttWsEndpoint; + private final String ttsEndpoint; + private final String ttsWsEndpoint; + private final String visionEndpoint; + private final Duration connectTimeout; + private final Duration readTimeout; + private final int maxRetries; + + // Chat-specific parameters + private final OptionalDouble temperature; + private final OptionalDouble topP; + private final OptionalInt maxTokens; + private final Optional reasoningEffort; + private final Optional wikiGrounding; + private final OptionalDouble frequencyPenalty; + private final OptionalDouble presencePenalty; + + // TTS-specific parameters + private final Optional ttsSpeaker; + private final Optional ttsModel; + private final OptionalDouble ttsPace; + private final OptionalInt ttsSampleRate; + + // STT-specific parameters + private final Optional sttModel; + private final Optional sttMode; + private final Optional sttLanguageCode; - public 
SarvamAiConfig(String apiKey) { - this.apiKey = apiKey; + private SarvamAiConfig(Builder builder) { + String resolvedKey = builder.apiKey; + if (Strings.isNullOrEmpty(resolvedKey)) { + resolvedKey = System.getenv("SARVAM_API_KEY"); + } + Preconditions.checkArgument( + !Strings.isNullOrEmpty(resolvedKey), + "Sarvam API key is required. Set via builder or SARVAM_API_KEY environment variable."); + this.apiKey = resolvedKey; + + this.chatEndpoint = Objects.requireNonNullElse(builder.chatEndpoint, DEFAULT_CHAT_ENDPOINT); + this.sttEndpoint = Objects.requireNonNullElse(builder.sttEndpoint, DEFAULT_STT_ENDPOINT); + this.sttWsEndpoint = Objects.requireNonNullElse(builder.sttWsEndpoint, DEFAULT_STT_WS_ENDPOINT); + this.ttsEndpoint = Objects.requireNonNullElse(builder.ttsEndpoint, DEFAULT_TTS_ENDPOINT); + this.ttsWsEndpoint = Objects.requireNonNullElse(builder.ttsWsEndpoint, DEFAULT_TTS_WS_ENDPOINT); + this.visionEndpoint = + Objects.requireNonNullElse(builder.visionEndpoint, DEFAULT_VISION_ENDPOINT); + this.connectTimeout = + Objects.requireNonNullElse(builder.connectTimeout, DEFAULT_CONNECT_TIMEOUT); + this.readTimeout = Objects.requireNonNullElse(builder.readTimeout, DEFAULT_READ_TIMEOUT); + this.maxRetries = builder.maxRetries; + this.temperature = builder.temperature; + this.topP = builder.topP; + this.maxTokens = builder.maxTokens; + this.reasoningEffort = Optional.ofNullable(builder.reasoningEffort); + this.wikiGrounding = Optional.ofNullable(builder.wikiGrounding); + this.frequencyPenalty = builder.frequencyPenalty; + this.presencePenalty = builder.presencePenalty; + this.ttsSpeaker = Optional.ofNullable(builder.ttsSpeaker); + this.ttsModel = Optional.ofNullable(builder.ttsModel); + this.ttsPace = builder.ttsPace; + this.ttsSampleRate = builder.ttsSampleRate; + this.sttModel = Optional.ofNullable(builder.sttModel); + this.sttMode = Optional.ofNullable(builder.sttMode); + this.sttLanguageCode = Optional.ofNullable(builder.sttLanguageCode); + } + + public static 
Builder builder() { + return new Builder(); } - public String getApiKey() { + public String apiKey() { return apiKey; } + + public String chatEndpoint() { + return chatEndpoint; + } + + public String sttEndpoint() { + return sttEndpoint; + } + + public String sttWsEndpoint() { + return sttWsEndpoint; + } + + public String ttsEndpoint() { + return ttsEndpoint; + } + + public String ttsWsEndpoint() { + return ttsWsEndpoint; + } + + public String visionEndpoint() { + return visionEndpoint; + } + + public Duration connectTimeout() { + return connectTimeout; + } + + public Duration readTimeout() { + return readTimeout; + } + + public int maxRetries() { + return maxRetries; + } + + public OptionalDouble temperature() { + return temperature; + } + + public OptionalDouble topP() { + return topP; + } + + public OptionalInt maxTokens() { + return maxTokens; + } + + public Optional reasoningEffort() { + return reasoningEffort; + } + + public Optional wikiGrounding() { + return wikiGrounding; + } + + public OptionalDouble frequencyPenalty() { + return frequencyPenalty; + } + + public OptionalDouble presencePenalty() { + return presencePenalty; + } + + public Optional ttsSpeaker() { + return ttsSpeaker; + } + + public Optional ttsModel() { + return ttsModel; + } + + public OptionalDouble ttsPace() { + return ttsPace; + } + + public OptionalInt ttsSampleRate() { + return ttsSampleRate; + } + + public Optional sttModel() { + return sttModel; + } + + public Optional sttMode() { + return sttMode; + } + + public Optional sttLanguageCode() { + return sttLanguageCode; + } + + /** Builder for {@link SarvamAiConfig}. 
*/ + public static final class Builder { + private String apiKey; + private String chatEndpoint; + private String sttEndpoint; + private String sttWsEndpoint; + private String ttsEndpoint; + private String ttsWsEndpoint; + private String visionEndpoint; + private Duration connectTimeout; + private Duration readTimeout; + private int maxRetries = DEFAULT_MAX_RETRIES; + private OptionalDouble temperature = OptionalDouble.empty(); + private OptionalDouble topP = OptionalDouble.empty(); + private OptionalInt maxTokens = OptionalInt.empty(); + private String reasoningEffort; + private Boolean wikiGrounding; + private OptionalDouble frequencyPenalty = OptionalDouble.empty(); + private OptionalDouble presencePenalty = OptionalDouble.empty(); + private String ttsSpeaker; + private String ttsModel; + private OptionalDouble ttsPace = OptionalDouble.empty(); + private OptionalInt ttsSampleRate = OptionalInt.empty(); + private String sttModel; + private String sttMode; + private String sttLanguageCode; + + private Builder() {} + + public Builder apiKey(String apiKey) { + this.apiKey = apiKey; + return this; + } + + public Builder chatEndpoint(String chatEndpoint) { + this.chatEndpoint = chatEndpoint; + return this; + } + + public Builder sttEndpoint(String sttEndpoint) { + this.sttEndpoint = sttEndpoint; + return this; + } + + public Builder sttWsEndpoint(String sttWsEndpoint) { + this.sttWsEndpoint = sttWsEndpoint; + return this; + } + + public Builder ttsEndpoint(String ttsEndpoint) { + this.ttsEndpoint = ttsEndpoint; + return this; + } + + public Builder ttsWsEndpoint(String ttsWsEndpoint) { + this.ttsWsEndpoint = ttsWsEndpoint; + return this; + } + + public Builder visionEndpoint(String visionEndpoint) { + this.visionEndpoint = visionEndpoint; + return this; + } + + public Builder connectTimeout(Duration connectTimeout) { + this.connectTimeout = connectTimeout; + return this; + } + + public Builder readTimeout(Duration readTimeout) { + this.readTimeout = readTimeout; + 
return this; + } + + public Builder maxRetries(int maxRetries) { + Preconditions.checkArgument(maxRetries >= 0, "maxRetries must be >= 0"); + this.maxRetries = maxRetries; + return this; + } + + public Builder temperature(double temperature) { + Preconditions.checkArgument( + temperature >= 0 && temperature <= 2, "temperature must be between 0 and 2"); + this.temperature = OptionalDouble.of(temperature); + return this; + } + + public Builder topP(double topP) { + Preconditions.checkArgument(topP >= 0 && topP <= 1, "topP must be between 0 and 1"); + this.topP = OptionalDouble.of(topP); + return this; + } + + public Builder maxTokens(int maxTokens) { + Preconditions.checkArgument(maxTokens > 0, "maxTokens must be > 0"); + this.maxTokens = OptionalInt.of(maxTokens); + return this; + } + + public Builder reasoningEffort(String reasoningEffort) { + Preconditions.checkArgument( + "low".equals(reasoningEffort) + || "medium".equals(reasoningEffort) + || "high".equals(reasoningEffort), + "reasoningEffort must be one of: low, medium, high"); + this.reasoningEffort = reasoningEffort; + return this; + } + + public Builder wikiGrounding(boolean wikiGrounding) { + this.wikiGrounding = wikiGrounding; + return this; + } + + public Builder frequencyPenalty(double frequencyPenalty) { + Preconditions.checkArgument( + frequencyPenalty >= -2 && frequencyPenalty <= 2, + "frequencyPenalty must be between -2 and 2"); + this.frequencyPenalty = OptionalDouble.of(frequencyPenalty); + return this; + } + + public Builder presencePenalty(double presencePenalty) { + Preconditions.checkArgument( + presencePenalty >= -2 && presencePenalty <= 2, + "presencePenalty must be between -2 and 2"); + this.presencePenalty = OptionalDouble.of(presencePenalty); + return this; + } + + public Builder ttsSpeaker(String ttsSpeaker) { + this.ttsSpeaker = ttsSpeaker; + return this; + } + + public Builder ttsModel(String ttsModel) { + this.ttsModel = ttsModel; + return this; + } + + public Builder ttsPace(double 
ttsPace) { + Preconditions.checkArgument( + ttsPace >= 0.5 && ttsPace <= 2.0, "ttsPace must be between 0.5 and 2.0"); + this.ttsPace = OptionalDouble.of(ttsPace); + return this; + } + + public Builder ttsSampleRate(int ttsSampleRate) { + this.ttsSampleRate = OptionalInt.of(ttsSampleRate); + return this; + } + + public Builder sttModel(String sttModel) { + this.sttModel = sttModel; + return this; + } + + public Builder sttMode(String sttMode) { + Preconditions.checkArgument( + "transcribe".equals(sttMode) + || "translate".equals(sttMode) + || "verbatim".equals(sttMode) + || "translit".equals(sttMode) + || "codemix".equals(sttMode), + "sttMode must be one of: transcribe, translate, verbatim, translit, codemix"); + this.sttMode = sttMode; + return this; + } + + public Builder sttLanguageCode(String sttLanguageCode) { + this.sttLanguageCode = sttLanguageCode; + return this; + } + + public SarvamAiConfig build() { + return new SarvamAiConfig(this); + } + } } diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiException.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiException.java new file mode 100644 index 000000000..bbd3c4a46 --- /dev/null +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiException.java @@ -0,0 +1,67 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.adk.models.sarvamai; + +import java.util.Optional; + +/** + * Domain exception for Sarvam AI API errors. Carries structured error information from the API + * response for programmatic error handling. + */ +public class SarvamAiException extends RuntimeException { + + private final int statusCode; + private final String errorCode; + private final String requestId; + + public SarvamAiException(String message, int statusCode, String errorCode, String requestId) { + super(message); + this.statusCode = statusCode; + this.errorCode = errorCode; + this.requestId = requestId; + } + + public SarvamAiException(String message, Throwable cause) { + super(message, cause); + this.statusCode = 0; + this.errorCode = null; + this.requestId = null; + } + + public SarvamAiException(String message) { + super(message); + this.statusCode = 0; + this.errorCode = null; + this.requestId = null; + } + + public int statusCode() { + return statusCode; + } + + public Optional errorCode() { + return Optional.ofNullable(errorCode); + } + + public Optional requestId() { + return Optional.ofNullable(requestId); + } + + public boolean isRetryable() { + return statusCode == 429 || statusCode == 503 || statusCode >= 500; + } +} diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiLlmConnection.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiLlmConnection.java new file mode 100644 index 000000000..bbaa2f1da --- /dev/null +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiLlmConnection.java @@ -0,0 +1,154 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models.sarvamai; + +import com.google.adk.models.BaseLlmConnection; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.genai.types.Blob; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.BackpressureStrategy; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.schedulers.Schedulers; +import io.reactivex.rxjava3.subjects.PublishSubject; +import java.util.ArrayList; +import java.util.List; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Live bidirectional connection to Sarvam AI, implementing multi-turn streaming conversations. + * + *

Maintains conversation history and streams responses token-by-token using SSE. Accumulates the
+ * full model response into history after each turn to support multi-turn context.
+ *
+ * <p>Each send runs on {@link Schedulers#io()}, so several turns may be in flight at once. The
+ * history list is guarded by its own monitor and the response subject is serialized so that
+ * overlapping turns cannot signal it concurrently.
+ */
+final class SarvamAiLlmConnection implements BaseLlmConnection {
+
+  private static final Logger logger = LoggerFactory.getLogger(SarvamAiLlmConnection.class);
+
+  private final SarvamAi sarvamAi;
+  private final LlmRequest initialRequest;
+
+  // Guarded by synchronized (history).
+  private final List<Content> history;
+
+  // RxJava subjects require serialized onNext/onError/onComplete calls; toSerialized() provides
+  // that when sends overlap on different IO threads.
+  private final io.reactivex.rxjava3.subjects.Subject<LlmResponse> responseSubject =
+      PublishSubject.<LlmResponse>create().toSerialized();
+
+  /**
+   * Creates a connection seeded with the contents of {@code llmRequest}.
+   *
+   * @param sarvamAi model used to generate each turn
+   * @param llmRequest initial request; its contents seed the history, and its tools, config and
+   *     system instructions are reused for every turn
+   */
+  SarvamAiLlmConnection(SarvamAi sarvamAi, LlmRequest llmRequest) {
+    this.sarvamAi = sarvamAi;
+    this.initialRequest = llmRequest;
+    this.history = new ArrayList<>(llmRequest.contents());
+  }
+
+  /** Replaces the whole history with {@code newHistory} and starts a new turn. */
+  @Override
+  public Completable sendHistory(List<Content> newHistory) {
+    return Completable.fromAction(
+            () -> {
+              synchronized (history) {
+                history.clear();
+                history.addAll(newHistory);
+              }
+              generateAndStream();
+            })
+        .subscribeOn(Schedulers.io());
+  }
+
+  /** Appends {@code content} to the history and starts a new turn. */
+  @Override
+  public Completable sendContent(Content content) {
+    return Completable.fromAction(
+            () -> {
+              synchronized (history) {
+                history.add(content);
+              }
+              generateAndStream();
+            })
+        .subscribeOn(Schedulers.io());
+  }
+
+  /** Always fails: this connection only carries chat content. */
+  @Override
+  public Completable sendRealtime(Blob blob) {
+    return Completable.error(
+        new UnsupportedOperationException(
+            "Realtime audio/video blobs are not supported on the chat connection. "
+                + "Use SarvamSttService for STT and SarvamTtsService for TTS."));
+  }
+
+  /** Returns the stream of responses for all turns, buffering when the consumer is slow. */
+  @Override
+  public Flowable<LlmResponse> receive() {
+    return responseSubject.toFlowable(BackpressureStrategy.BUFFER);
+  }
+
+  /** Completes the response stream. */
+  @Override
+  public void close() {
+    responseSubject.onComplete();
+  }
+
+  /** Terminates the response stream with {@code throwable}. */
+  @Override
+  public void close(Throwable throwable) {
+    responseSubject.onError(throwable);
+  }
+
+  /**
+   * Runs one streaming turn over a snapshot of the history. Every response is forwarded to the
+   * subject; on completion the concatenated text is appended to the history as a model turn.
+   */
+  private void generateAndStream() {
+    LlmRequest turnRequest = buildTurnRequest();
+    StringBuilder fullText = new StringBuilder();
+
+    // NOTE(review): the returned Disposable is not retained, so close() does not cancel a turn
+    // that is still streaming - confirm whether that is intended.
+    sarvamAi
+        .generateContent(turnRequest, true)
+        .subscribe(
+            response -> {
+              responseSubject.onNext(response);
+              appendText(response, fullText);
+            },
+            error -> {
+              logger.error("Error during Sarvam streaming turn", error);
+              responseSubject.onError(error);
+            },
+            () -> recordModelTurn(fullText));
+  }
+
+  /** Builds the request for one turn from a history snapshot plus the initial request settings. */
+  private LlmRequest buildTurnRequest() {
+    List<Content> snapshot;
+    synchronized (history) {
+      snapshot = new ArrayList<>(history);
+    }
+
+    LlmRequest.Builder turnBuilder =
+        LlmRequest.builder()
+            .contents(snapshot)
+            .appendTools(new ArrayList<>(initialRequest.tools().values()));
+
+    initialRequest.config().ifPresent(turnBuilder::config);
+    turnBuilder.appendInstructions(initialRequest.getSystemInstructions());
+    return turnBuilder.build();
+  }
+
+  /** Appends the text of every part of {@code response} to {@code sink}. */
+  private static void appendText(LlmResponse response, StringBuilder sink) {
+    response
+        .content()
+        .flatMap(Content::parts)
+        .ifPresent(
+            parts -> {
+              for (Part part : parts) {
+                part.text().ifPresent(sink::append);
+              }
+            });
+  }
+
+  /** Stores the accumulated text as a {@code model} turn; does nothing when no text arrived. */
+  private void recordModelTurn(StringBuilder fullText) {
+    if (fullText.length() == 0) {
+      return;
+    }
+    Content responseContent =
+        Content.builder().role("model").parts(Part.fromText(fullText.toString())).build();
+    synchronized (history) {
+      history.add(responseContent);
+    }
+  }
+}
diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiMessage.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiMessage.java
deleted file mode 100644
index 802cef0d9..000000000
--- a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiMessage.java
+++ /dev/null
@@ -1,42 +0,0 @@
-/*
- * Copyright 2025 Google LLC
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file
except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.adk.models.sarvamai; - -/** - * This class is used to represent a message from the Sarvam AI API. - * - * @author Sandeep Belgavi - * @since 2026-02-11 - */ -public class SarvamAiMessage { - - private String role; - private String content; - - public SarvamAiMessage(String role, String content) { - this.role = role; - this.content = content; - } - - public String getRole() { - return role; - } - - public String getContent() { - return content; - } -} diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiRequest.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiRequest.java deleted file mode 100644 index a339f2568..000000000 --- a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiRequest.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.google.adk.models.sarvamai; - -import com.google.adk.models.LlmRequest; -import com.google.genai.types.Content; -import com.google.genai.types.Part; -import java.util.ArrayList; -import java.util.List; - -/** - * This class is used to create a request to the Sarvam AI API. - * - * @author Sandeep Belgavi - * @since 2026-02-11 - */ -public class SarvamAiRequest { - - private String model; - private List messages; - - public SarvamAiRequest(String model, LlmRequest llmRequest) { - this.model = model; - this.messages = new ArrayList<>(); - for (Content content : llmRequest.contents()) { - for (Part part : content.parts().get()) { - this.messages.add(new SarvamAiMessage(content.role().get(), part.text().get())); - } - } - } - - public String getModel() { - return model; - } - - public List getMessages() { - return messages; - } -} diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiResponse.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiResponse.java deleted file mode 100644 index 7877e8261..000000000 --- a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiResponse.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.adk.models.sarvamai; - -import java.util.List; - -/** - * This class is used to represent a response from the Sarvam AI API. 
- * - * @author Sandeep Belgavi - * @since 2026-02-11 - */ -public class SarvamAiResponse { - - private List choices; - - public List getChoices() { - return choices; - } - - public void setChoices(List choices) { - this.choices = choices; - } -} diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiResponseMessage.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiResponseMessage.java deleted file mode 100644 index 5af09d30f..000000000 --- a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiResponseMessage.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.adk.models.sarvamai; - -/** - * This class is used to represent a response message from the Sarvam AI API. 
- * - * @author Sandeep Belgavi - * @since 2026-02-11 - */ -public class SarvamAiResponseMessage { - - private String content; - - public String getContent() { - return content; - } - - public void setContent(String content) { - this.content = content; - } -} diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamRetryInterceptor.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamRetryInterceptor.java new file mode 100644 index 000000000..da0874ac5 --- /dev/null +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamRetryInterceptor.java @@ -0,0 +1,103 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models.sarvamai; + +import java.io.IOException; +import okhttp3.Interceptor; +import okhttp3.Request; +import okhttp3.Response; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * OkHttp interceptor that implements exponential backoff with jitter for retryable Sarvam API + * errors (429 rate limit, 5xx server errors). 
+ */
+final class SarvamRetryInterceptor implements Interceptor {
+
+  private static final Logger logger = LoggerFactory.getLogger(SarvamRetryInterceptor.class);
+  private static final long BASE_DELAY_MS = 500;
+  private static final long MAX_DELAY_MS = 30_000;
+
+  // BASE_DELAY_MS * 2^16 already exceeds MAX_DELAY_MS, so capping the exponent here never changes
+  // a capped result but keeps the shift from overflowing for very large attempt numbers.
+  private static final int MAX_BACKOFF_SHIFT = 16;
+
+  private final int maxRetries;
+
+  /**
+   * @param maxRetries number of retries after the first attempt; negative values are treated as
+   *     zero so that the request is always sent at least once
+   */
+  SarvamRetryInterceptor(int maxRetries) {
+    this.maxRetries = Math.max(0, maxRetries);
+  }
+
+  /**
+   * Proceeds with the request, retrying after a back-off delay when the response status is
+   * retryable or the attempt fails with an {@link IOException}.
+   *
+   * @return the first successful or non-retryable response, or the last response once the retry
+   *     budget is spent
+   * @throws IOException the failure of the last attempt, the failure of a cancelled call, or an
+   *     {@link java.io.InterruptedIOException} if the thread is interrupted while backing off
+   */
+  @Override
+  public Response intercept(Chain chain) throws IOException {
+    Request request = chain.request();
+
+    for (int attempt = 0; ; attempt++) {
+      long delay = calculateDelay(attempt);
+      try {
+        Response response = chain.proceed(request);
+
+        if (response.isSuccessful() || !isRetryable(response.code()) || attempt >= maxRetries) {
+          return response;
+        }
+
+        // The body must be released before the chain can proceed again.
+        response.close();
+        logger.warn(
+            "Sarvam API returned {} for {}. Retrying in {}ms (attempt {}/{})",
+            response.code(),
+            request.url(),
+            delay,
+            attempt + 1,
+            maxRetries);
+      } catch (IOException e) {
+        // A cancelled call fails immediately on every further attempt, so give up at once.
+        if (attempt >= maxRetries || chain.call().isCanceled()) {
+          throw e;
+        }
+        logger.warn(
+            "Sarvam API request failed: {}. Retrying in {}ms (attempt {}/{})",
+            e.getMessage(),
+            delay,
+            attempt + 1,
+            maxRetries);
+      }
+
+      // Sleeping outside the try block keeps an interruption from being caught and retried.
+      sleep(delay);
+    }
+  }
+
+  /** Returns whether the status is 429 or any 5xx value (503 is covered by {@code >= 500}). */
+  private static boolean isRetryable(int statusCode) {
+    return statusCode == 429 || statusCode >= 500;
+  }
+
+  /**
+   * Computes the back-off for a zero-based attempt: {@code BASE_DELAY_MS * 2^attempt}, capped at
+   * {@code MAX_DELAY_MS}, plus up to 20% random jitter.
+   *
+   * @param attempt zero-based attempt number; negative values are treated as zero
+   * @return delay in milliseconds, always positive
+   */
+  static long calculateDelay(int attempt) {
+    int shift = Math.min(Math.max(attempt, 0), MAX_BACKOFF_SHIFT);
+    long delay = Math.min(BASE_DELAY_MS * (1L << shift), MAX_DELAY_MS);
+    long jitter = (long) (delay * 0.2 * Math.random());
+    return delay + jitter;
+  }
+
+  /**
+   * Sleeps for the back-off period.
+   *
+   * @throws java.io.InterruptedIOException if the thread is interrupted; the interrupt flag is
+   *     restored first so callers further up can still observe it
+   */
+  private static void sleep(long millis) throws IOException {
+    try {
+      Thread.sleep(millis);
+    } catch (InterruptedException e) {
+      Thread.currentThread().interrupt();
+      java.io.InterruptedIOException interrupted =
+          new java.io.InterruptedIOException("Interrupted while waiting to retry request");
+      interrupted.initCause(e);
+      throw interrupted;
+    }
+  }
+}
diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatChoice.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatChoice.java
new file mode 100644
index 000000000..5aff17c63
--- /dev/null
+++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatChoice.java
@@ -0,0 +1,77 @@
+/*
+ * Copyright 2025 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.adk.models.sarvamai.chat;
+
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+/**
+ * A choice in the Sarvam AI chat completion response. Handles both non-streaming ({@code message})
+ * and streaming ({@code delta}) response formats.
+ */
+@JsonIgnoreProperties(ignoreUnknown = true)
+public final class ChatChoice {
+
+  @JsonProperty("index")
+  private int index;
+
+  // Present in non-streaming responses.
+  @JsonProperty("message")
+  private ChatMessage message;
+
+  // Present in streaming responses.
+  @JsonProperty("delta")
+  private ChatMessage delta;
+
+  @JsonProperty("finish_reason")
+  private String finishReason;
+
+  /** Returns the {@code index} value of this choice. */
+  public int getIndex() {
+    return this.index;
+  }
+
+  /** Sets the {@code index} value of this choice. */
+  public void setIndex(int index) {
+    this.index = index;
+  }
+
+  /** Returns the {@code message} payload, or {@code null} when none was set. */
+  public ChatMessage getMessage() {
+    return this.message;
+  }
+
+  /** Sets the {@code message} payload. */
+  public void setMessage(ChatMessage message) {
+    this.message = message;
+  }
+
+  /** Returns the {@code delta} payload, or {@code null} when none was set. */
+  public ChatMessage getDelta() {
+    return this.delta;
+  }
+
+  /** Sets the {@code delta} payload. */
+  public void setDelta(ChatMessage delta) {
+    this.delta = delta;
+  }
+
+  /** Returns the {@code finish_reason} value, or {@code null} when none was set. */
+  public String getFinishReason() {
+    return this.finishReason;
+  }
+
+  /** Sets the {@code finish_reason} value. */
+  public void setFinishReason(String finishReason) {
+    this.finishReason = finishReason;
+  }
+
+  /** Returns {@code delta} when it is set, otherwise {@code message} (which may be null). */
+  public ChatMessage effectiveMessage() {
+    if (this.delta == null) {
+      return this.message;
+    }
+    return this.delta;
+  }
+}
diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatMessage.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatMessage.java
new file mode 100644
index 000000000..c84336cd7
--- /dev/null
+++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatMessage.java
@@ -0,0 +1,67 @@
+/*
+ * Copyright 2025 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.adk.models.sarvamai.chat;
+
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+/**
+ * One chat message exchanged with the Sarvam AI chat completion API, used for both requests and
+ * responses. Null properties are left out when serialized and unknown properties are ignored when
+ * parsed.
+ */
+@JsonIgnoreProperties(ignoreUnknown = true)
+@JsonInclude(JsonInclude.Include.NON_NULL)
+public final class ChatMessage {
+
+  @JsonProperty("role")
+  private String role;
+
+  @JsonProperty("content")
+  private String content;
+
+  @JsonProperty("reasoning_content")
+  private String reasoningContent;
+
+  /** Creates an empty message; properties are filled in through the setters. */
+  public ChatMessage() {}
+
+  /** Creates a message with the given role and content and no reasoning content. */
+  public ChatMessage(String role, String content) {
+    this.role = role;
+    this.content = content;
+  }
+
+  /** Returns the {@code role} value, or {@code null} when unset. */
+  public String getRole() {
+    return this.role;
+  }
+
+  /** Sets the {@code role} value. */
+  public void setRole(String role) {
+    this.role = role;
+  }
+
+  /** Returns the {@code content} value, or {@code null} when unset. */
+  public String getContent() {
+    return this.content;
+  }
+
+  /** Sets the {@code content} value. */
+  public void setContent(String content) {
+    this.content = content;
+  }
+
+  /** Returns the {@code reasoning_content} value, or {@code null} when unset. */
+  public String getReasoningContent() {
+    return this.reasoningContent;
+  }
+
+  /** Sets the {@code reasoning_content} value. */
+  public void setReasoningContent(String reasoningContent) {
+    this.reasoningContent = reasoningContent;
+  }
+}
diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatRequest.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatRequest.java
new file mode 100644
index 000000000..d63d57d1d
--- /dev/null
+++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatRequest.java
@@ -0,0 +1,152 @@
+/*
+ * Copyright 2025 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.adk.models.sarvamai.chat;
+
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.google.adk.models.LlmRequest;
+import com.google.adk.models.sarvamai.SarvamAiConfig;
+import com.google.genai.types.Content;
+import com.google.genai.types.Part;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Request body for the Sarvam AI chat completions endpoint. Constructed from the ADK {@link
+ * LlmRequest} and {@link SarvamAiConfig}.
+ *
+ * <p>Properties left {@code null} are omitted from the serialized JSON.
+ */
+@JsonInclude(JsonInclude.Include.NON_NULL)
+public final class ChatRequest {
+
+  @JsonProperty("model")
+  private String model;
+
+  @JsonProperty("messages")
+  private List<ChatMessage> messages;
+
+  @JsonProperty("stream")
+  private Boolean stream;
+
+  @JsonProperty("temperature")
+  private Double temperature;
+
+  @JsonProperty("top_p")
+  private Double topP;
+
+  @JsonProperty("max_tokens")
+  private Integer maxTokens;
+
+  @JsonProperty("reasoning_effort")
+  private String reasoningEffort;
+
+  @JsonProperty("wiki_grounding")
+  private Boolean wikiGrounding;
+
+  @JsonProperty("frequency_penalty")
+  private Double frequencyPenalty;
+
+  @JsonProperty("presence_penalty")
+  private Double presencePenalty;
+
+  @JsonProperty("n")
+  private Integer n;
+
+  @JsonProperty("seed")
+  private Integer seed;
+
+  @JsonProperty("stop")
+  private Object stop;
+
+  public ChatRequest() {}
+
+  /**
+   * Converts an ADK {@link LlmRequest} into a Sarvam-native {@link ChatRequest}, applying config
+   * defaults and mapping ADK roles to OpenAI-compatible roles.
+   *
+   * <p>System instructions come first, one {@code system} message each, followed by one message
+   * per content entry that carries text. Entries without any text part are skipped.
+   *
+   * @param modelName value for the {@code model} property
+   * @param llmRequest source of system instructions and conversation contents
+   * @param config source of the optional sampling settings
+   * @param stream whether to request streaming; {@code false} omits the property altogether
+   * @return the populated request
+   */
+  public static ChatRequest fromLlmRequest(
+      String modelName, LlmRequest llmRequest, SarvamAiConfig config, boolean stream) {
+    ChatRequest request = new ChatRequest();
+    request.model = modelName;
+    // null rather than false, so NON_NULL drops the property for non-streaming calls.
+    request.stream = stream ? Boolean.TRUE : null;
+    request.messages = new ArrayList<>();
+
+    for (String instruction : llmRequest.getSystemInstructions()) {
+      request.messages.add(new ChatMessage("system", instruction));
+    }
+
+    for (Content content : llmRequest.contents()) {
+      String text = concatenateText(content);
+      if (!text.isEmpty()) {
+        request.messages.add(new ChatMessage(toChatRole(content), text));
+      }
+    }
+
+    config.temperature().ifPresent(v -> request.temperature = v);
+    config.topP().ifPresent(v -> request.topP = v);
+    config.maxTokens().ifPresent(v -> request.maxTokens = v);
+    config.reasoningEffort().ifPresent(v -> request.reasoningEffort = v);
+    config.wikiGrounding().ifPresent(v -> request.wikiGrounding = v);
+    config.frequencyPenalty().ifPresent(v -> request.frequencyPenalty = v);
+    config.presencePenalty().ifPresent(v -> request.presencePenalty = v);
+
+    return request;
+  }
+
+  /** Maps the content role to a chat role: missing becomes "user", "model" becomes "assistant". */
+  private static String toChatRole(Content content) {
+    String role = content.role().orElse("user");
+    return "model".equals(role) ? "assistant" : role;
+  }
+
+  /** Joins the text of all parts in order, with no separator; non-text parts are ignored. */
+  private static String concatenateText(Content content) {
+    StringBuilder textBuilder = new StringBuilder();
+    content
+        .parts()
+        .ifPresent(
+            parts -> {
+              for (Part part : parts) {
+                part.text().ifPresent(textBuilder::append);
+              }
+            });
+    return textBuilder.toString();
+  }
+
+  public String getModel() {
+    return model;
+  }
+
+  public List<ChatMessage> getMessages() {
+    return messages;
+  }
+
+  public Boolean getStream() {
+    return stream;
+  }
+
+  public Double getTemperature() {
+    return temperature;
+  }
+
+  public Double getTopP() {
+    return topP;
+  }
+
+  public Integer getMaxTokens() {
+    return maxTokens;
+  }
+
+  public String getReasoningEffort() {
+    return reasoningEffort;
+  }
+
+  public Boolean getWikiGrounding() {
+    return wikiGrounding;
+  }
+}
diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatResponse.java
b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatResponse.java new file mode 100644 index 000000000..6be3efaef --- /dev/null +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatResponse.java @@ -0,0 +1,95 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models.sarvamai.chat; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; + +/** + * Response from the Sarvam AI chat completions endpoint. Supports both non-streaming and streaming + * (SSE chunk) formats. 
+ */
+@JsonIgnoreProperties(ignoreUnknown = true)
+public final class ChatResponse {
+
+  @JsonProperty("id")
+  private String id;
+
+  @JsonProperty("object")
+  private String object;
+
+  @JsonProperty("created")
+  private long created;
+
+  @JsonProperty("model")
+  private String model;
+
+  // Element type is declared so Jackson binds entries to ChatChoice instead of generic maps.
+  @JsonProperty("choices")
+  private List<ChatChoice> choices;
+
+  @JsonProperty("usage")
+  private ChatUsage usage;
+
+  /** Returns the {@code id} value, or {@code null} when unset. */
+  public String getId() {
+    return id;
+  }
+
+  public void setId(String id) {
+    this.id = id;
+  }
+
+  /** Returns the {@code object} value, or {@code null} when unset. */
+  public String getObject() {
+    return object;
+  }
+
+  public void setObject(String object) {
+    this.object = object;
+  }
+
+  /** Returns the {@code created} value; {@code 0} when unset. */
+  public long getCreated() {
+    return created;
+  }
+
+  public void setCreated(long created) {
+    this.created = created;
+  }
+
+  /** Returns the {@code model} value, or {@code null} when unset. */
+  public String getModel() {
+    return model;
+  }
+
+  public void setModel(String model) {
+    this.model = model;
+  }
+
+  /** Returns the {@code choices} list, or {@code null} when the property was absent. */
+  public List<ChatChoice> getChoices() {
+    return choices;
+  }
+
+  public void setChoices(List<ChatChoice> choices) {
+    this.choices = choices;
+  }
+
+  /** Returns the {@code usage} value, or {@code null} when the property was absent. */
+  public ChatUsage getUsage() {
+    return usage;
+  }
+
+  public void setUsage(ChatUsage usage) {
+    this.usage = usage;
+  }
+}
diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatUsage.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatUsage.java
new file mode 100644
index 000000000..120dd3314
--- /dev/null
+++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatUsage.java
@@ -0,0 +1,58 @@
+/*
+ * Copyright 2025 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.adk.models.sarvamai.chat;
+
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+/**
+ * Token counts reported in the {@code usage} object of a Sarvam AI API response. Unknown
+ * properties are ignored when parsing.
+ */
+@JsonIgnoreProperties(ignoreUnknown = true)
+public final class ChatUsage {
+
+  @JsonProperty("prompt_tokens")
+  private int promptTokens;
+
+  @JsonProperty("completion_tokens")
+  private int completionTokens;
+
+  @JsonProperty("total_tokens")
+  private int totalTokens;
+
+  /** Returns the {@code prompt_tokens} value; {@code 0} when unset. */
+  public int getPromptTokens() {
+    return this.promptTokens;
+  }
+
+  /** Sets the {@code prompt_tokens} value. */
+  public void setPromptTokens(int promptTokens) {
+    this.promptTokens = promptTokens;
+  }
+
+  /** Returns the {@code completion_tokens} value; {@code 0} when unset. */
+  public int getCompletionTokens() {
+    return this.completionTokens;
+  }
+
+  /** Sets the {@code completion_tokens} value. */
+  public void setCompletionTokens(int completionTokens) {
+    this.completionTokens = completionTokens;
+  }
+
+  /** Returns the {@code total_tokens} value; {@code 0} when unset. */
+  public int getTotalTokens() {
+    return this.totalTokens;
+  }
+
+  /** Sets the {@code total_tokens} value. */
+  public void setTotalTokens(int totalTokens) {
+    this.totalTokens = totalTokens;
+  }
+}
diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/stt/SarvamSttService.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/stt/SarvamSttService.java
new file mode 100644
index 000000000..ceec7483b
--- /dev/null
+++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/stt/SarvamSttService.java
@@ -0,0 +1,271 @@
+/*
+ * Copyright 2025 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models.sarvamai.stt; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.adk.models.sarvamai.SarvamAiConfig; +import com.google.adk.models.sarvamai.SarvamAiException; +import com.google.adk.transcription.ServiceHealth; +import com.google.adk.transcription.ServiceType; +import com.google.adk.transcription.TranscriptionConfig; +import com.google.adk.transcription.TranscriptionEvent; +import com.google.adk.transcription.TranscriptionException; +import com.google.adk.transcription.TranscriptionResult; +import com.google.adk.transcription.TranscriptionService; +import io.reactivex.rxjava3.core.BackpressureStrategy; +import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.schedulers.Schedulers; +import java.util.Base64; +import java.util.Objects; +import okhttp3.MediaType; +import okhttp3.MultipartBody; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.RequestBody; +import okhttp3.Response; +import okhttp3.WebSocket; +import okhttp3.WebSocketListener; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Sarvam AI Speech-to-Text service implementing the ADK {@link TranscriptionService} interface. + * + *

Supports three modes of operation: + * + *

    + *
  • REST synchronous ({@link #transcribe}): Single-shot transcription via {@code POST + * /speech-to-text} using model {@code saaras:v3}. + *
  • REST async ({@link #transcribeAsync}): Same as above, executed on an IO scheduler. + *
  • WebSocket streaming ({@link #transcribeStream}): Real-time streaming via WebSocket + * with VAD support, delivering partial and final transcription events. + *
+ */ +public final class SarvamSttService implements TranscriptionService { + + private static final Logger logger = LoggerFactory.getLogger(SarvamSttService.class); + + private final SarvamAiConfig config; + private final OkHttpClient httpClient; + private final ObjectMapper objectMapper; + + public SarvamSttService(SarvamAiConfig config, OkHttpClient httpClient) { + this.config = Objects.requireNonNull(config); + this.httpClient = Objects.requireNonNull(httpClient); + this.objectMapper = new ObjectMapper(); + } + + @Override + public TranscriptionResult transcribe(byte[] audioData, TranscriptionConfig requestConfig) + throws TranscriptionException { + try { + String sttModel = config.sttModel().orElse("saaras:v3"); + String mode = config.sttMode().orElse("transcribe"); + String languageCode = config.sttLanguageCode().orElse(requestConfig.getLanguage()); + + RequestBody fileBody = RequestBody.create(audioData, MediaType.parse("audio/wav")); + + MultipartBody.Builder bodyBuilder = + new MultipartBody.Builder() + .setType(MultipartBody.FORM) + .addFormDataPart("file", "audio.wav", fileBody) + .addFormDataPart("model", sttModel) + .addFormDataPart("mode", mode); + + if (languageCode != null && !"auto".equals(languageCode)) { + bodyBuilder.addFormDataPart("language_code", languageCode); + } + + Request request = + new Request.Builder() + .url(config.sttEndpoint()) + .addHeader("api-subscription-key", config.apiKey()) + .post(bodyBuilder.build()) + .build(); + + logger.debug( + "Sending STT request to {} with model={}, mode={}", config.sttEndpoint(), sttModel, mode); + + try (Response response = httpClient.newCall(request).execute()) { + if (!response.isSuccessful()) { + String errorBody = response.body() != null ? 
response.body().string() : ""; + throw new TranscriptionException( + "STT request failed with status " + response.code() + ": " + errorBody); + } + + String responseBody = response.body().string(); + JsonNode root = objectMapper.readTree(responseBody); + String transcript = root.path("transcript").asText(""); + String detectedLang = root.path("language_code").asText(null); + + TranscriptionResult.Builder resultBuilder = + TranscriptionResult.builder().text(transcript).timestamp(System.currentTimeMillis()); + + if (detectedLang != null) { + resultBuilder.language(detectedLang); + } + + return resultBuilder.build(); + } + } catch (TranscriptionException e) { + throw e; + } catch (Exception e) { + throw new TranscriptionException("STT transcription failed", e); + } + } + + @Override + public Single transcribeAsync( + byte[] audioData, TranscriptionConfig requestConfig) { + return Single.fromCallable(() -> transcribe(audioData, requestConfig)) + .subscribeOn(Schedulers.io()); + } + + /** + * Streams audio data to Sarvam's WebSocket STT endpoint for real-time transcription. + * + *

Audio chunks are base64-encoded and sent as JSON frames. The server responds with transcript + * events including partial results and VAD signals (speech_start, speech_end). + */ + @Override + public Flowable transcribeStream( + Flowable audioStream, TranscriptionConfig requestConfig) { + + return Flowable.create( + emitter -> { + String sttModel = config.sttModel().orElse("saaras:v3"); + String mode = config.sttMode().orElse("transcribe"); + String languageCode = config.sttLanguageCode().orElse(requestConfig.getLanguage()); + + StringBuilder wsUrl = new StringBuilder(config.sttWsEndpoint()); + wsUrl.append("?model=").append(sttModel); + wsUrl.append("&mode=").append(mode); + if (languageCode != null && !"auto".equals(languageCode)) { + wsUrl.append("&language_code=").append(languageCode); + } + wsUrl.append("&high_vad_sensitivity=true"); + wsUrl.append("&vad_signals=true"); + + Request wsRequest = + new Request.Builder() + .url(wsUrl.toString()) + .addHeader("api-subscription-key", config.apiKey()) + .build(); + + logger.debug("Opening STT WebSocket to {}", wsUrl); + + WebSocket webSocket = + httpClient.newWebSocket( + wsRequest, + new WebSocketListener() { + @Override + public void onOpen(WebSocket ws, Response response) { + logger.debug("STT WebSocket connected"); + audioStream.subscribe( + chunk -> { + String base64Audio = Base64.getEncoder().encodeToString(chunk); + String frame = + String.format( + "{\"audio\":\"%s\",\"encoding\":\"audio/wav\",\"sample_rate\":16000}", + base64Audio); + ws.send(frame); + }, + error -> { + logger.error("Audio stream error", error); + ws.close(1000, "Audio stream error"); + }, + () -> { + logger.debug("Audio stream completed, closing WebSocket"); + ws.close(1000, "Stream complete"); + }); + } + + @Override + public void onMessage(WebSocket ws, String text) { + try { + JsonNode node = objectMapper.readTree(text); + String type = node.path("type").asText(""); + + switch (type) { + case "transcript": + case "translation": + String 
transcript = node.path("text").asText(""); + emitter.onNext( + TranscriptionEvent.builder() + .text(transcript) + .finished(true) + .timestamp(System.currentTimeMillis()) + .build()); + break; + case "speech_start": + logger.trace("VAD: speech started"); + break; + case "speech_end": + logger.trace("VAD: speech ended"); + break; + default: + logger.trace("Received STT WS message type: {}", type); + } + } catch (Exception e) { + logger.warn("Failed to parse STT WS message: {}", text, e); + } + } + + @Override + public void onClosing(WebSocket ws, int code, String reason) { + logger.debug("STT WebSocket closing: {} {}", code, reason); + ws.close(code, reason); + } + + @Override + public void onClosed(WebSocket ws, int code, String reason) { + logger.debug("STT WebSocket closed: {} {}", code, reason); + emitter.onComplete(); + } + + @Override + public void onFailure(WebSocket ws, Throwable t, Response response) { + logger.error("STT WebSocket failure", t); + if (!emitter.isCancelled()) { + emitter.onError( + new SarvamAiException("STT WebSocket connection failed", t)); + } + } + }); + + emitter.setCancellable(() -> webSocket.close(1000, "Cancelled")); + }, + BackpressureStrategy.BUFFER); + } + + @Override + public boolean isAvailable() { + return config.apiKey() != null && !config.apiKey().isEmpty(); + } + + @Override + public ServiceType getServiceType() { + return ServiceType.SARVAM; + } + + @Override + public ServiceHealth getHealth() { + return ServiceHealth.builder().available(isAvailable()).serviceType(ServiceType.SARVAM).build(); + } +} diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/tts/SarvamTtsService.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/tts/SarvamTtsService.java new file mode 100644 index 000000000..fc68608b0 --- /dev/null +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/tts/SarvamTtsService.java @@ -0,0 +1,238 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the 
Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models.sarvamai.tts; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.adk.models.sarvamai.SarvamAiConfig; +import com.google.adk.models.sarvamai.SarvamAiException; +import io.reactivex.rxjava3.core.BackpressureStrategy; +import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.schedulers.Schedulers; +import java.util.Base64; +import java.util.Objects; +import okhttp3.MediaType; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.RequestBody; +import okhttp3.Response; +import okhttp3.WebSocket; +import okhttp3.WebSocketListener; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Sarvam AI Text-to-Speech service with both REST and WebSocket streaming support. + * + *

REST mode ({@link #synthesize}): Sends text and returns the complete audio as a byte array + * (decoded from base64). Uses the Bulbul v3 model with 30+ speaker voices. + * + *

WebSocket streaming mode ({@link #synthesizeStream}): Opens a persistent WebSocket connection + * for progressive audio chunk delivery with low latency. Audio chunks are emitted as they are + * synthesized, enabling real-time playback. + */ +public final class SarvamTtsService { + + private static final Logger logger = LoggerFactory.getLogger(SarvamTtsService.class); + private static final MediaType JSON_MEDIA_TYPE = MediaType.get("application/json; charset=utf-8"); + + private final SarvamAiConfig config; + private final OkHttpClient httpClient; + private final ObjectMapper objectMapper; + + public SarvamTtsService(SarvamAiConfig config, OkHttpClient httpClient) { + this.config = Objects.requireNonNull(config); + this.httpClient = Objects.requireNonNull(httpClient); + this.objectMapper = new ObjectMapper(); + } + + /** + * Synthesizes speech from text synchronously via the REST endpoint. + * + * @param text the text to convert to speech (max 2500 chars for bulbul:v3) + * @param targetLanguageCode BCP-47 language code (e.g., "en-IN", "hi-IN") + * @return decoded audio bytes (WAV format by default) + */ + public byte[] synthesize(String text, String targetLanguageCode) { + Objects.requireNonNull(text, "text must not be null"); + Objects.requireNonNull(targetLanguageCode, "targetLanguageCode must not be null"); + + String model = config.ttsModel().orElse("bulbul:v3"); + String speaker = config.ttsSpeaker().orElse("shubh"); + Double pace = config.ttsPace().isPresent() ? config.ttsPace().getAsDouble() : null; + Integer sampleRate = + config.ttsSampleRate().isPresent() ? 
config.ttsSampleRate().getAsInt() : null; + + TtsRequest ttsRequest = + new TtsRequest(text, targetLanguageCode, model, speaker, pace, sampleRate); + + try { + String body = objectMapper.writeValueAsString(ttsRequest); + + Request request = + new Request.Builder() + .url(config.ttsEndpoint()) + .addHeader("api-subscription-key", config.apiKey()) + .addHeader("Content-Type", "application/json") + .post(RequestBody.create(body, JSON_MEDIA_TYPE)) + .build(); + + logger.debug( + "Sending TTS request to {} with model={}, speaker={}", + config.ttsEndpoint(), + model, + speaker); + + try (Response response = httpClient.newCall(request).execute()) { + if (!response.isSuccessful()) { + String errorBody = response.body() != null ? response.body().string() : ""; + throw new SarvamAiException( + "TTS request failed: " + response.code() + " " + errorBody, + response.code(), + null, + null); + } + + TtsResponse ttsResponse = + objectMapper.readValue(response.body().string(), TtsResponse.class); + if (ttsResponse.getAudios() == null || ttsResponse.getAudios().isEmpty()) { + throw new SarvamAiException("TTS response contained no audio data"); + } + + String combinedBase64 = String.join("", ttsResponse.getAudios()); + return Base64.getDecoder().decode(combinedBase64); + } + } catch (SarvamAiException e) { + throw e; + } catch (Exception e) { + throw new SarvamAiException("TTS synthesis failed", e); + } + } + + /** Async version of {@link #synthesize}. */ + public Single synthesizeAsync(String text, String targetLanguageCode) { + return Single.fromCallable(() -> synthesize(text, targetLanguageCode)) + .subscribeOn(Schedulers.io()); + } + + /** + * Streams TTS audio via WebSocket for low-latency, progressive playback. + * + *

Opens a WebSocket to Sarvam's streaming TTS endpoint, sends config + text, and emits decoded + * audio chunks as they arrive. Each chunk is a raw audio byte array ready for playback. + * + * @param text the text to synthesize + * @param targetLanguageCode BCP-47 language code + * @return a Flowable of audio byte[] chunks + */ + public Flowable synthesizeStream(String text, String targetLanguageCode) { + Objects.requireNonNull(text, "text must not be null"); + Objects.requireNonNull(targetLanguageCode, "targetLanguageCode must not be null"); + + return Flowable.create( + emitter -> { + String model = config.ttsModel().orElse("bulbul:v3"); + String speaker = config.ttsSpeaker().orElse("shubh"); + + String wsUrl = config.ttsWsEndpoint() + "?model=" + model; + + Request wsRequest = + new Request.Builder() + .url(wsUrl) + .addHeader("api-subscription-key", config.apiKey()) + .build(); + + logger.debug("Opening TTS WebSocket to {}", wsUrl); + + WebSocket webSocket = + httpClient.newWebSocket( + wsRequest, + new WebSocketListener() { + @Override + public void onOpen(WebSocket ws, Response response) { + logger.debug("TTS WebSocket connected"); + + String configMsg = + String.format( + "{\"type\":\"config\",\"data\":{\"speaker\":\"%s\"," + + "\"target_language_code\":\"%s\"}}", + speaker, targetLanguageCode); + ws.send(configMsg); + + String textMsg = + String.format( + "{\"type\":\"text\",\"data\":{\"text\":\"%s\"}}", + text.replace("\"", "\\\"")); + ws.send(textMsg); + + ws.send("{\"type\":\"flush\"}"); + } + + @Override + public void onMessage(WebSocket ws, String messageText) { + try { + JsonNode node = objectMapper.readTree(messageText); + String type = node.path("type").asText(""); + + if ("audio".equals(type)) { + String audioBase64 = node.path("data").path("audio").asText(""); + if (!audioBase64.isEmpty()) { + byte[] audioChunk = Base64.getDecoder().decode(audioBase64); + emitter.onNext(audioChunk); + } + } else if ("event".equals(type)) { + String eventType = 
node.path("data").path("event_type").asText(""); + if ("final".equals(eventType)) { + ws.close(1000, "Synthesis complete"); + } + } + } catch (Exception e) { + logger.warn("Failed to parse TTS WS message", e); + } + } + + @Override + public void onClosing(WebSocket ws, int code, String reason) { + ws.close(code, reason); + } + + @Override + public void onClosed(WebSocket ws, int code, String reason) { + logger.debug("TTS WebSocket closed: {} {}", code, reason); + emitter.onComplete(); + } + + @Override + public void onFailure(WebSocket ws, Throwable t, Response response) { + logger.error("TTS WebSocket failure", t); + if (!emitter.isCancelled()) { + emitter.onError( + new SarvamAiException("TTS WebSocket connection failed", t)); + } + } + }); + + emitter.setCancellable(() -> webSocket.close(1000, "Cancelled")); + }, + BackpressureStrategy.BUFFER); + } + + public boolean isAvailable() { + return config.apiKey() != null && !config.apiKey().isEmpty(); + } +} diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/tts/TtsRequest.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/tts/TtsRequest.java new file mode 100644 index 000000000..152b84fc6 --- /dev/null +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/tts/TtsRequest.java @@ -0,0 +1,84 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.adk.models.sarvamai.tts; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** Request body for the Sarvam AI text-to-speech REST endpoint. */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public final class TtsRequest { + + @JsonProperty("text") + private String text; + + @JsonProperty("target_language_code") + private String targetLanguageCode; + + @JsonProperty("model") + private String model; + + @JsonProperty("speaker") + private String speaker; + + @JsonProperty("pace") + private Double pace; + + @JsonProperty("speech_sample_rate") + private Integer speechSampleRate; + + public TtsRequest() {} + + public TtsRequest( + String text, + String targetLanguageCode, + String model, + String speaker, + Double pace, + Integer speechSampleRate) { + this.text = text; + this.targetLanguageCode = targetLanguageCode; + this.model = model; + this.speaker = speaker; + this.pace = pace; + this.speechSampleRate = speechSampleRate; + } + + public String getText() { + return text; + } + + public String getTargetLanguageCode() { + return targetLanguageCode; + } + + public String getModel() { + return model; + } + + public String getSpeaker() { + return speaker; + } + + public Double getPace() { + return pace; + } + + public Integer getSpeechSampleRate() { + return speechSampleRate; + } +} diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/tts/TtsResponse.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/tts/TtsResponse.java new file mode 100644 index 000000000..61a6e9f37 --- /dev/null +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/tts/TtsResponse.java @@ -0,0 +1,49 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models.sarvamai.tts; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; + +/** Response from the Sarvam AI text-to-speech REST endpoint. */ +@JsonIgnoreProperties(ignoreUnknown = true) +public final class TtsResponse { + + @JsonProperty("request_id") + private String requestId; + + @JsonProperty("audios") + private List audios; + + public String getRequestId() { + return requestId; + } + + public void setRequestId(String requestId) { + this.requestId = requestId; + } + + /** Returns base64-encoded audio strings. Each element corresponds to an input text segment. */ + public List getAudios() { + return audios; + } + + public void setAudios(List audios) { + this.audios = audios; + } +} diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/vision/SarvamVisionService.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/vision/SarvamVisionService.java new file mode 100644 index 000000000..a451d5d0b --- /dev/null +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/vision/SarvamVisionService.java @@ -0,0 +1,294 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models.sarvamai.vision; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.adk.models.sarvamai.SarvamAiConfig; +import com.google.adk.models.sarvamai.SarvamAiException; +import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.schedulers.Schedulers; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Objects; +import java.util.Optional; +import okhttp3.MediaType; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.RequestBody; +import okhttp3.Response; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Sarvam Vision Document Intelligence service. + * + *

Powered by the Sarvam Vision 3B VLM for extracting structured text from documents across 23 + * languages (22 Indian + English). Supports PDF, PNG, JPG, and ZIP inputs with HTML or Markdown + * output. + * + *

The workflow follows Sarvam's async job pattern: + * + *

    + *
  1. {@link #createJob} - Initialize a document processing job + *
  2. {@link #uploadDocument} - Upload the document to the job's presigned URL + *
  3. {@link #startJob} - Begin processing + *
  4. {@link #getJobStatus} - Poll for completion + *
  5. {@link #downloadResults} - Retrieve the processed output + *
+ */ +public final class SarvamVisionService { + + private static final Logger logger = LoggerFactory.getLogger(SarvamVisionService.class); + private static final MediaType JSON_MEDIA_TYPE = MediaType.get("application/json; charset=utf-8"); + + private final SarvamAiConfig config; + private final OkHttpClient httpClient; + private final ObjectMapper objectMapper; + + public SarvamVisionService(SarvamAiConfig config, OkHttpClient httpClient) { + this.config = Objects.requireNonNull(config); + this.httpClient = Objects.requireNonNull(httpClient); + this.objectMapper = new ObjectMapper(); + } + + /** Result of a job creation request. */ + public record JobInfo(String jobId, String uploadUrl) {} + + /** Current status of a document intelligence job. */ + public record JobStatus(String jobId, String state, Optional downloadUrl) {} + + /** + * Creates a new document intelligence job. + * + * @param languageCode BCP-47 code (e.g., "hi-IN", "en-IN") + * @param outputFormat "html" or "md" + * @return job info with ID and upload URL + */ + public JobInfo createJob(String languageCode, String outputFormat) { + Objects.requireNonNull(languageCode); + Objects.requireNonNull(outputFormat); + + try { + String body = + objectMapper.writeValueAsString( + new java.util.HashMap() { + { + put("language", languageCode); + put("output_format", outputFormat); + } + }); + + Request request = + new Request.Builder() + .url(config.visionEndpoint() + "/create") + .addHeader("api-subscription-key", config.apiKey()) + .addHeader("Content-Type", "application/json") + .post(RequestBody.create(body, JSON_MEDIA_TYPE)) + .build(); + + logger.debug("Creating vision job: lang={}, format={}", languageCode, outputFormat); + + try (Response response = httpClient.newCall(request).execute()) { + ensureSuccess(response, "Create job"); + JsonNode root = objectMapper.readTree(response.body().string()); + String jobId = root.path("job_id").asText(); + String uploadUrl = root.path("upload_url").asText(); + 
return new JobInfo(jobId, uploadUrl); + } + } catch (SarvamAiException e) { + throw e; + } catch (Exception e) { + throw new SarvamAiException("Failed to create vision job", e); + } + } + + /** + * Uploads a document to the presigned upload URL. + * + * @param uploadUrl the presigned URL from {@link #createJob} + * @param filePath path to the document file + */ + public void uploadDocument(String uploadUrl, Path filePath) { + Objects.requireNonNull(uploadUrl); + Objects.requireNonNull(filePath); + + try { + byte[] fileBytes = Files.readAllBytes(filePath); + String contentType = Files.probeContentType(filePath); + if (contentType == null) { + contentType = "application/octet-stream"; + } + + Request request = + new Request.Builder() + .url(uploadUrl) + .put(RequestBody.create(fileBytes, MediaType.parse(contentType))) + .build(); + + logger.debug("Uploading document {} ({} bytes)", filePath, fileBytes.length); + + try (Response response = httpClient.newCall(request).execute()) { + if (!response.isSuccessful()) { + throw new SarvamAiException( + "Document upload failed: " + response.code(), response.code(), null, null); + } + } + } catch (SarvamAiException e) { + throw e; + } catch (Exception e) { + throw new SarvamAiException("Failed to upload document", e); + } + } + + /** Starts processing a previously created and uploaded job. 
*/ + public void startJob(String jobId) { + Objects.requireNonNull(jobId); + + try { + String body = objectMapper.writeValueAsString(java.util.Map.of("job_id", jobId)); + Request request = + new Request.Builder() + .url(config.visionEndpoint() + "/start") + .addHeader("api-subscription-key", config.apiKey()) + .addHeader("Content-Type", "application/json") + .post(RequestBody.create(body, JSON_MEDIA_TYPE)) + .build(); + + logger.debug("Starting vision job {}", jobId); + + try (Response response = httpClient.newCall(request).execute()) { + ensureSuccess(response, "Start job"); + } + } catch (SarvamAiException e) { + throw e; + } catch (Exception e) { + throw new SarvamAiException("Failed to start vision job", e); + } + } + + /** Gets the current status of a document processing job. */ + public JobStatus getJobStatus(String jobId) { + Objects.requireNonNull(jobId); + + try { + Request request = + new Request.Builder() + .url(config.visionEndpoint() + "/status?job_id=" + jobId) + .addHeader("api-subscription-key", config.apiKey()) + .get() + .build(); + + try (Response response = httpClient.newCall(request).execute()) { + ensureSuccess(response, "Get job status"); + JsonNode root = objectMapper.readTree(response.body().string()); + String state = root.path("job_state").asText("unknown"); + String downloadUrl = root.path("download_url").asText(null); + return new JobStatus(jobId, state, Optional.ofNullable(downloadUrl)); + } + } catch (SarvamAiException e) { + throw e; + } catch (Exception e) { + throw new SarvamAiException("Failed to get job status", e); + } + } + + /** + * Downloads the processed results. 
+ * + * @param downloadUrl the URL from {@link JobStatus#downloadUrl()} + * @return the result bytes (typically a ZIP file containing HTML/Markdown) + */ + public byte[] downloadResults(String downloadUrl) { + Objects.requireNonNull(downloadUrl); + + try { + Request request = new Request.Builder().url(downloadUrl).get().build(); + + logger.debug("Downloading vision results from {}", downloadUrl); + + try (Response response = httpClient.newCall(request).execute()) { + if (!response.isSuccessful()) { + throw new SarvamAiException( + "Download failed: " + response.code(), response.code(), null, null); + } + return response.body().bytes(); + } + } catch (SarvamAiException e) { + throw e; + } catch (Exception e) { + throw new SarvamAiException("Failed to download results", e); + } + } + + /** + * Convenience method: runs the full pipeline (create -> upload -> start -> poll -> download) + * asynchronously. + */ + public Single processDocument(Path filePath, String languageCode, String outputFormat) { + return Single.create( + emitter -> { + try { + JobInfo job = createJob(languageCode, outputFormat); + uploadDocument(job.uploadUrl(), filePath); + startJob(job.jobId()); + + // Poll with backoff + int maxPolls = 60; + long pollIntervalMs = 2000; + for (int i = 0; i < maxPolls; i++) { + Thread.sleep(pollIntervalMs); + JobStatus status = getJobStatus(job.jobId()); + + if ("completed".equalsIgnoreCase(status.state())) { + if (status.downloadUrl().isPresent()) { + byte[] result = downloadResults(status.downloadUrl().get()); + emitter.onSuccess(result); + return; + } + emitter.onError( + new SarvamAiException("Job completed but no download URL provided")); + return; + } else if ("failed".equalsIgnoreCase(status.state())) { + emitter.onError(new SarvamAiException("Vision job failed: " + job.jobId())); + return; + } + + // Adaptive backoff + pollIntervalMs = Math.min(pollIntervalMs * 2, 10_000); + } + emitter.onError(new SarvamAiException("Vision job timed out: " + job.jobId())); 
+ } catch (Exception e) { + emitter.onError(e); + } + }) + .subscribeOn(Schedulers.io()); + } + + public boolean isAvailable() { + return config.apiKey() != null && !config.apiKey().isEmpty(); + } + + private void ensureSuccess(Response response, String operation) throws IOException { + if (!response.isSuccessful()) { + String errorBody = response.body() != null ? response.body().string() : ""; + throw new SarvamAiException( + operation + " failed: " + response.code() + " " + errorBody, response.code(), null, null); + } + } +} diff --git a/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamAiConfigTest.java b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamAiConfigTest.java new file mode 100644 index 000000000..2d7cb4770 --- /dev/null +++ b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamAiConfigTest.java @@ -0,0 +1,131 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.adk.models.sarvamai; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import org.junit.jupiter.api.Test; + +class SarvamAiConfigTest { + + @Test + void builder_withApiKey_succeeds() { + SarvamAiConfig config = SarvamAiConfig.builder().apiKey("test-key").build(); + assertThat(config.apiKey()).isEqualTo("test-key"); + } + + @Test + void builder_withoutApiKey_throwsIfEnvNotSet() { + if (System.getenv("SARVAM_API_KEY") != null) { + return; + } + assertThrows(IllegalArgumentException.class, () -> SarvamAiConfig.builder().build()); + } + + @Test + void builder_setsDefaultEndpoints() { + SarvamAiConfig config = SarvamAiConfig.builder().apiKey("key").build(); + assertThat(config.chatEndpoint()).isEqualTo(SarvamAiConfig.DEFAULT_CHAT_ENDPOINT); + assertThat(config.sttEndpoint()).isEqualTo(SarvamAiConfig.DEFAULT_STT_ENDPOINT); + assertThat(config.ttsEndpoint()).isEqualTo(SarvamAiConfig.DEFAULT_TTS_ENDPOINT); + assertThat(config.visionEndpoint()).isEqualTo(SarvamAiConfig.DEFAULT_VISION_ENDPOINT); + } + + @Test + void builder_customEndpoints() { + SarvamAiConfig config = + SarvamAiConfig.builder() + .apiKey("key") + .chatEndpoint("http://custom/chat") + .sttEndpoint("http://custom/stt") + .ttsEndpoint("http://custom/tts") + .build(); + + assertThat(config.chatEndpoint()).isEqualTo("http://custom/chat"); + assertThat(config.sttEndpoint()).isEqualTo("http://custom/stt"); + assertThat(config.ttsEndpoint()).isEqualTo("http://custom/tts"); + } + + @Test + void builder_temperatureValidation() { + assertThrows( + IllegalArgumentException.class, + () -> SarvamAiConfig.builder().apiKey("key").temperature(3.0).build()); + assertThrows( + IllegalArgumentException.class, + () -> SarvamAiConfig.builder().apiKey("key").temperature(-1.0).build()); + + SarvamAiConfig config = SarvamAiConfig.builder().apiKey("key").temperature(0.7).build(); + 
assertThat(config.temperature().getAsDouble()).isWithin(0.001).of(0.7); + } + + @Test + void builder_reasoningEffortValidation() { + assertThrows( + IllegalArgumentException.class, + () -> SarvamAiConfig.builder().apiKey("key").reasoningEffort("invalid").build()); + + SarvamAiConfig config = SarvamAiConfig.builder().apiKey("key").reasoningEffort("high").build(); + assertThat(config.reasoningEffort()).hasValue("high"); + } + + @Test + void builder_sttModeValidation() { + assertThrows( + IllegalArgumentException.class, + () -> SarvamAiConfig.builder().apiKey("key").sttMode("invalid").build()); + + SarvamAiConfig config = SarvamAiConfig.builder().apiKey("key").sttMode("translate").build(); + assertThat(config.sttMode()).hasValue("translate"); + } + + @Test + void builder_ttsPaceValidation() { + assertThrows( + IllegalArgumentException.class, + () -> SarvamAiConfig.builder().apiKey("key").ttsPace(0.1).build()); + assertThrows( + IllegalArgumentException.class, + () -> SarvamAiConfig.builder().apiKey("key").ttsPace(3.0).build()); + + SarvamAiConfig config = SarvamAiConfig.builder().apiKey("key").ttsPace(1.5).build(); + assertThat(config.ttsPace().getAsDouble()).isWithin(0.001).of(1.5); + } + + @Test + void builder_maxRetriesDefault() { + SarvamAiConfig config = SarvamAiConfig.builder().apiKey("key").build(); + assertThat(config.maxRetries()).isEqualTo(SarvamAiConfig.DEFAULT_MAX_RETRIES); + } + + @Test + void builder_wikiGrounding() { + SarvamAiConfig config = SarvamAiConfig.builder().apiKey("key").wikiGrounding(true).build(); + assertThat(config.wikiGrounding()).hasValue(true); + } + + @Test + void builder_chatParametersOptionalByDefault() { + SarvamAiConfig config = SarvamAiConfig.builder().apiKey("key").build(); + assertThat(config.temperature().isEmpty()).isTrue(); + assertThat(config.topP().isEmpty()).isTrue(); + assertThat(config.maxTokens().isEmpty()).isTrue(); + assertThat(config.reasoningEffort().isEmpty()).isTrue(); + 
assertThat(config.wikiGrounding().isEmpty()).isTrue(); + } +} diff --git a/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamAiTest.java b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamAiTest.java index 2f9d5a013..c04cf3add 100644 --- a/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamAiTest.java +++ b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamAiTest.java @@ -17,142 +17,196 @@ package com.google.adk.models.sarvamai; import static com.google.common.truth.Truth.assertThat; -import static org.mockito.Mockito.when; +import static org.junit.jupiter.api.Assertions.assertThrows; -import com.fasterxml.jackson.databind.ObjectMapper; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; import com.google.genai.types.Content; import com.google.genai.types.Part; -import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.subscribers.TestSubscriber; import java.io.IOException; -import java.util.List; -import okhttp3.Call; -import okhttp3.Callback; -import okhttp3.MediaType; +import java.util.Collections; +import java.util.concurrent.TimeUnit; import okhttp3.OkHttpClient; -import okhttp3.Protocol; -import okhttp3.Request; -import okhttp3.Response; -import okhttp3.ResponseBody; -import org.junit.Before; -import org.junit.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.ArgumentCaptor; -import org.mockito.Captor; -import org.mockito.Mock; -import org.mockito.junit.jupiter.MockitoExtension; - -/** - * Tests for SarvamAi. 
- * - * @author Sandeep Belgavi - * @since 2026-02-11 - */ -@ExtendWith(MockitoExtension.class) -public class SarvamAiTest { +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; - private static final String API_KEY = "test-api-key"; - private static final String MODEL_NAME = "test-model"; - private static final String COMPLETION_TEXT = "Hello, world!"; - private static final String STREAMING_CHUNK_1 = - "data: {\"choices\": [{\"message\": {\"content\": \"Hello,\"}}]}"; - private static final String STREAMING_CHUNK_2 = - "data: {\"choices\": [{\"message\": {\"content\": \" world!\"}}]}"; - private static final String STREAMING_DONE = "data: [DONE]"; +class SarvamAiTest { - @Mock private OkHttpClient mockHttpClient; - @Mock private Call mockCall; - @Mock private SarvamAiConfig mockConfig; + private MockWebServer server; + private SarvamAi sarvamAi; - @Captor private ArgumentCaptor requestCaptor; - @Captor private ArgumentCaptor callbackCaptor; + @BeforeEach + void setUp() throws IOException { + server = new MockWebServer(); + server.start(); - private SarvamAi sarvamAi; - private ObjectMapper objectMapper; + SarvamAiConfig config = + SarvamAiConfig.builder() + .apiKey("test-api-key") + .chatEndpoint(server.url("/v1/chat/completions").toString()) + .build(); - @Before - public void setUp() { - when(mockConfig.getApiKey()).thenReturn(API_KEY); - sarvamAi = new SarvamAi(MODEL_NAME, mockConfig); - objectMapper = new ObjectMapper(); + sarvamAi = + SarvamAi.builder() + .modelName("sarvam-m") + .config(config) + .httpClient(new OkHttpClient()) + .build(); + } - when(mockHttpClient.newCall(requestCaptor.capture())).thenReturn(mockCall); + @AfterEach + void tearDown() throws IOException { + server.shutdown(); } @Test - public void 
generateContent_blockingCall_returnsLlmResponse() throws IOException { - String mockResponseBody = createMockSarvamAiResponseBody(COMPLETION_TEXT); - Response mockResponse = - new Response.Builder() - .request(new Request.Builder().url("http://localhost").build()) - .protocol(Protocol.HTTP_1_1) - .code(200) - .message("OK") - .body(ResponseBody.create(mockResponseBody, MediaType.get("application/json"))) - .build(); + void generateContent_nonStreaming_returnsContent() { + String jsonResponse = + "{\"id\":\"chatcmpl-abc\",\"object\":\"chat.completion\",\"created\":1699000000," + + "\"model\":\"sarvam-m\",\"choices\":[{\"index\":0," + + "\"message\":{\"role\":\"assistant\",\"content\":\"Hello world\"}," + + "\"finish_reason\":\"stop\"}]," + + "\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":5,\"total_tokens\":15}}"; + server.enqueue(new MockResponse().setBody(jsonResponse)); + + LlmRequest request = buildUserRequest("Hi"); + TestSubscriber subscriber = sarvamAi.generateContent(request, false).test(); + + subscriber.awaitDone(5, TimeUnit.SECONDS); + subscriber.assertNoErrors(); + subscriber.assertValueCount(1); + + LlmResponse response = subscriber.values().get(0); + assertThat(response.content().flatMap(Content::parts).get().get(0).text().get()) + .isEqualTo("Hello world"); + } - when(mockCall.execute()).thenReturn(mockResponse); + @Test + void generateContent_streaming_returnsChunks() { + String chunk1 = "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n"; + String chunk2 = "data: {\"choices\":[{\"delta\":{\"content\":\" world\"}}]}\n\n"; + String done = "data: [DONE]\n\n"; + + server.enqueue(new MockResponse().setBody(chunk1 + chunk2 + done)); + + LlmRequest request = buildUserRequest("Hi"); + TestSubscriber subscriber = sarvamAi.generateContent(request, true).test(); + + subscriber.awaitDone(5, TimeUnit.SECONDS); + subscriber.assertNoErrors(); + subscriber.assertValueCount(2); + + assertThat( + 
subscriber.values().get(0).content().flatMap(Content::parts).get().get(0).text().get()) + .isEqualTo("Hello"); + assertThat( + subscriber.values().get(1).content().flatMap(Content::parts).get().get(0).text().get()) + .isEqualTo(" world"); + } - LlmRequest llmRequest = - LlmRequest.builder() - .contents( - List.of(Content.builder().parts(List.of(Part.fromText("test query"))).build())) - .build(); - Flowable responseFlowable = sarvamAi.generateContent(llmRequest, false); + @Test + void generateContent_streamingChunksAreMarkedPartial() { + server.enqueue( + new MockResponse() + .setBody( + "data: {\"choices\":[{\"delta\":{\"content\":\"test\"}}]}\n\ndata: [DONE]\n\n")); + + LlmRequest request = buildUserRequest("Hi"); + TestSubscriber subscriber = sarvamAi.generateContent(request, true).test(); + + subscriber.awaitDone(5, TimeUnit.SECONDS); + subscriber.assertNoErrors(); + LlmResponse response = subscriber.values().get(0); + assertThat(response.partial().orElse(false)).isTrue(); + } - LlmResponse llmResponse = responseFlowable.blockingFirst(); + @Test + void generateContent_serverError_propagatesException() { + server.enqueue( + new MockResponse() + .setResponseCode(500) + .setBody( + "{\"error\":{\"message\":\"Internal error\",\"code\":\"internal_server_error\"}}")); + + LlmRequest request = buildUserRequest("Hi"); + TestSubscriber subscriber = sarvamAi.generateContent(request, false).test(); + + subscriber.awaitDone(5, TimeUnit.SECONDS); + subscriber.assertError(SarvamAiException.class); + } - assertThat(llmResponse.content().get().parts().get().get(0).text()).isEqualTo(COMPLETION_TEXT); + @Test + void generateContent_usesCorrectAuthHeader() throws InterruptedException { + server.enqueue( + new MockResponse() + .setBody("{\"choices\":[{\"message\":{\"role\":\"assistant\",\"content\":\"ok\"}}]}")); + + sarvamAi.generateContent(buildUserRequest("Hi"), false).blockingSubscribe(); + + RecordedRequest recorded = server.takeRequest(5, TimeUnit.SECONDS); + 
assertThat(recorded).isNotNull(); + assertThat(recorded.getHeader("api-subscription-key")).isEqualTo("test-api-key"); } @Test - public void generateContent_streamingCall_returnsLlmResponses() throws IOException { - ResponseBody mockStreamingBody = - createMockStreamingResponseBody( - List.of(STREAMING_CHUNK_1, STREAMING_CHUNK_2, STREAMING_DONE)); - - Response mockResponse = - new Response.Builder() - .request(new Request.Builder().url("http://localhost").build()) - .protocol(Protocol.HTTP_1_1) - .code(200) - .message("OK") - .body(mockStreamingBody) - .build(); + void generateContent_setsStreamFlagInBody() throws InterruptedException { + String chunk = "data: {\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\ndata: [DONE]\n\n"; + server.enqueue(new MockResponse().setBody(chunk)); - when(mockCall.execute()) - .thenThrow(new IllegalStateException("Should not be called for streaming")); + sarvamAi.generateContent(buildUserRequest("Hello"), true).blockingSubscribe(); - LlmRequest llmRequest = + RecordedRequest recorded = server.takeRequest(5, TimeUnit.SECONDS); + assertThat(recorded).isNotNull(); + String body = recorded.getBody().readUtf8(); + assertThat(body).contains("\"stream\":true"); + } + + @Test + void generateContent_mapsModelRoleToAssistant() throws InterruptedException { + server.enqueue( + new MockResponse() + .setBody("{\"choices\":[{\"message\":{\"role\":\"assistant\",\"content\":\"ok\"}}]}")); + + LlmRequest request = LlmRequest.builder() .contents( - List.of(Content.builder().parts(List.of(Part.fromText("test query"))).build())) + java.util.List.of( + Content.builder().role("user").parts(Part.fromText("Hi")).build(), + Content.builder().role("model").parts(Part.fromText("Hello")).build(), + Content.builder().role("user").parts(Part.fromText("How?")).build())) .build(); - Flowable responseFlowable = sarvamAi.generateContent(llmRequest, true); - // Simulate the asynchronous callback - Callback capturedCallback = callbackCaptor.getValue(); - 
capturedCallback.onResponse(mockCall, mockResponse); + sarvamAi.generateContent(request, false).blockingSubscribe(); - List responses = responseFlowable.toList().blockingGet(); + RecordedRequest recorded = server.takeRequest(5, TimeUnit.SECONDS); + String body = recorded.getBody().readUtf8(); + assertThat(body).contains("\"role\":\"assistant\""); + assertThat(body).doesNotContain("\"role\":\"model\""); + } - assertThat(responses).hasSize(2); - assertThat(responses.get(0).content().get().parts().get().get(0).text()).isEqualTo("Hello,"); - assertThat(responses.get(1).content().get().parts().get().get(0).text()).isEqualTo(" world!"); + @Test + void builder_requiresModelName() { + assertThrows( + NullPointerException.class, + () -> SarvamAi.builder().config(SarvamAiConfig.builder().apiKey("key").build()).build()); } - // Helper method to create a mock SarvamAi response body - private String createMockSarvamAiResponseBody(String text) { - return String.format("{\"choices\": [{\"message\": {\"content\": \"%s\"}}]}", text); + @Test + void builder_requiresConfig() { + assertThrows( + NullPointerException.class, () -> SarvamAi.builder().modelName("sarvam-m").build()); } - // Helper method to create a mock streaming response body - private ResponseBody createMockStreamingResponseBody(List chunks) { - StringBuilder bodyBuilder = new StringBuilder(); - for (String chunk : chunks) { - bodyBuilder.append(chunk).append("\n\n"); - } - return ResponseBody.create(bodyBuilder.toString(), MediaType.get("text/event-stream")); + private LlmRequest buildUserRequest(String text) { + return LlmRequest.builder() + .contents( + Collections.singletonList( + Content.builder().role("user").parts(Part.fromText(text)).build())) + .build(); } } diff --git a/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamRetryInterceptorTest.java b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamRetryInterceptorTest.java new file mode 100644 index 000000000..ff8614bd7 --- 
/dev/null +++ b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamRetryInterceptorTest.java @@ -0,0 +1,46 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models.sarvamai; + +import static com.google.common.truth.Truth.assertThat; + +import org.junit.jupiter.api.Test; + +class SarvamRetryInterceptorTest { + + @Test + void calculateDelay_exponentiallyIncreases() { + long delay0 = SarvamRetryInterceptor.calculateDelay(0); + long delay1 = SarvamRetryInterceptor.calculateDelay(1); + long delay2 = SarvamRetryInterceptor.calculateDelay(2); + + assertThat(delay0).isAtLeast(500); + assertThat(delay0).isAtMost(700); + + assertThat(delay1).isAtLeast(1000); + assertThat(delay1).isAtMost(1400); + + assertThat(delay2).isAtLeast(2000); + assertThat(delay2).isAtMost(2800); + } + + @Test + void calculateDelay_respectsMaxCap() { + long delay10 = SarvamRetryInterceptor.calculateDelay(10); + assertThat(delay10).isAtMost(36_000); + } +} diff --git a/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/chat/ChatRequestTest.java b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/chat/ChatRequestTest.java new file mode 100644 index 000000000..36cb04d96 --- /dev/null +++ b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/chat/ChatRequestTest.java @@ -0,0 +1,122 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the 
"License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models.sarvamai.chat; + +import static com.google.common.truth.Truth.assertThat; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.sarvamai.SarvamAiConfig; +import com.google.genai.types.Content; +import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.Part; +import java.util.List; +import org.junit.jupiter.api.Test; + +class ChatRequestTest { + + private final ObjectMapper objectMapper = new ObjectMapper(); + + @Test + void fromLlmRequest_mapsUserAndAssistantMessages() throws Exception { + LlmRequest llmRequest = + LlmRequest.builder() + .contents( + List.of( + Content.builder().role("user").parts(Part.fromText("Hello")).build(), + Content.builder().role("model").parts(Part.fromText("Hi there")).build(), + Content.builder().role("user").parts(Part.fromText("How?")).build())) + .build(); + + SarvamAiConfig config = SarvamAiConfig.builder().apiKey("key").temperature(0.5).build(); + ChatRequest request = ChatRequest.fromLlmRequest("sarvam-m", llmRequest, config, false); + + assertThat(request.getModel()).isEqualTo("sarvam-m"); + assertThat(request.getMessages()).hasSize(3); + assertThat(request.getMessages().get(0).getRole()).isEqualTo("user"); + assertThat(request.getMessages().get(1).getRole()).isEqualTo("assistant"); + assertThat(request.getMessages().get(2).getRole()).isEqualTo("user"); + 
assertThat(request.getTemperature()).isWithin(0.001).of(0.5); + assertThat(request.getStream()).isNull(); + } + + @Test + void fromLlmRequest_includesSystemInstructions() throws Exception { + LlmRequest llmRequest = + LlmRequest.builder() + .contents(List.of(Content.builder().role("user").parts(Part.fromText("Hello")).build())) + .config( + GenerateContentConfig.builder() + .systemInstruction( + Content.builder() + .parts(Part.fromText("You are a helpful assistant")) + .build()) + .build()) + .build(); + + SarvamAiConfig config = SarvamAiConfig.builder().apiKey("key").build(); + ChatRequest request = ChatRequest.fromLlmRequest("sarvam-m", llmRequest, config, true); + + assertThat(request.getMessages().get(0).getRole()).isEqualTo("system"); + assertThat(request.getMessages().get(0).getContent()).isEqualTo("You are a helpful assistant"); + assertThat(request.getStream()).isTrue(); + } + + @Test + void fromLlmRequest_appliesConfigParameters() throws Exception { + SarvamAiConfig config = + SarvamAiConfig.builder() + .apiKey("key") + .temperature(0.7) + .topP(0.9) + .maxTokens(100) + .reasoningEffort("high") + .wikiGrounding(true) + .build(); + + LlmRequest llmRequest = + LlmRequest.builder() + .contents(List.of(Content.builder().role("user").parts(Part.fromText("test")).build())) + .build(); + + ChatRequest request = ChatRequest.fromLlmRequest("sarvam-m", llmRequest, config, false); + + assertThat(request.getTemperature()).isWithin(0.001).of(0.7); + assertThat(request.getTopP()).isWithin(0.001).of(0.9); + assertThat(request.getMaxTokens()).isEqualTo(100); + assertThat(request.getReasoningEffort()).isEqualTo("high"); + assertThat(request.getWikiGrounding()).isTrue(); + } + + @Test + void serialization_excludesNullFields() throws Exception { + SarvamAiConfig config = SarvamAiConfig.builder().apiKey("key").build(); + LlmRequest llmRequest = + LlmRequest.builder() + .contents(List.of(Content.builder().role("user").parts(Part.fromText("Hi")).build())) + .build(); + + 
ChatRequest request = ChatRequest.fromLlmRequest("sarvam-m", llmRequest, config, false); + String json = objectMapper.writeValueAsString(request); + + assertThat(json).doesNotContain("temperature"); + assertThat(json).doesNotContain("stream"); + assertThat(json).doesNotContain("wiki_grounding"); + assertThat(json).contains("\"model\":\"sarvam-m\""); + assertThat(json).contains("\"messages\""); + } +} diff --git a/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/stt/SarvamSttServiceTest.java b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/stt/SarvamSttServiceTest.java new file mode 100644 index 000000000..b69529368 --- /dev/null +++ b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/stt/SarvamSttServiceTest.java @@ -0,0 +1,109 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.adk.models.sarvamai.stt; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.google.adk.models.sarvamai.SarvamAiConfig; +import com.google.adk.transcription.TranscriptionConfig; +import com.google.adk.transcription.TranscriptionException; +import com.google.adk.transcription.TranscriptionResult; +import java.io.IOException; +import java.util.concurrent.TimeUnit; +import okhttp3.OkHttpClient; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +class SarvamSttServiceTest { + + private MockWebServer server; + private SarvamSttService sttService; + + @BeforeEach + void setUp() throws IOException { + server = new MockWebServer(); + server.start(); + + SarvamAiConfig config = + SarvamAiConfig.builder() + .apiKey("test-stt-key") + .sttEndpoint(server.url("/speech-to-text").toString()) + .sttModel("saaras:v3") + .sttMode("transcribe") + .sttLanguageCode("hi-IN") + .build(); + + sttService = new SarvamSttService(config, new OkHttpClient()); + } + + @AfterEach + void tearDown() throws IOException { + server.shutdown(); + } + + @Test + void transcribe_success() throws TranscriptionException, InterruptedException { + server.enqueue( + new MockResponse() + .setBody( + "{\"request_id\":\"req-123\",\"transcript\":\"नमस्ते\",\"language_code\":\"hi-IN\"}")); + + TranscriptionConfig requestConfig = + TranscriptionConfig.builder() + .endpoint(server.url("/speech-to-text").toString()) + .language("hi-IN") + .build(); + + TranscriptionResult result = sttService.transcribe(new byte[] {1, 2, 3}, requestConfig); + + assertThat(result.getText()).isEqualTo("नमस्ते"); + assertThat(result.getLanguage().orElse("")).isEqualTo("hi-IN"); + + RecordedRequest recorded = 
server.takeRequest(5, TimeUnit.SECONDS); + assertThat(recorded.getHeader("api-subscription-key")).isEqualTo("test-stt-key"); + } + + @Test + void transcribe_serverError_throwsException() { + server.enqueue(new MockResponse().setResponseCode(500).setBody("Server error")); + + TranscriptionConfig requestConfig = + TranscriptionConfig.builder() + .endpoint(server.url("/speech-to-text").toString()) + .language("hi-IN") + .build(); + + assertThrows( + TranscriptionException.class, + () -> sttService.transcribe(new byte[] {1, 2, 3}, requestConfig)); + } + + @Test + void isAvailable_returnsTrue() { + assertThat(sttService.isAvailable()).isTrue(); + } + + @Test + void getServiceType_returnsSarvam() { + assertThat(sttService.getServiceType().getValue()).isEqualTo("sarvam"); + } +} diff --git a/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/tts/SarvamTtsServiceTest.java b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/tts/SarvamTtsServiceTest.java new file mode 100644 index 000000000..bcae3e3d4 --- /dev/null +++ b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/tts/SarvamTtsServiceTest.java @@ -0,0 +1,105 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.adk.models.sarvamai.tts; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.google.adk.models.sarvamai.SarvamAiConfig; +import com.google.adk.models.sarvamai.SarvamAiException; +import java.io.IOException; +import java.util.Base64; +import java.util.concurrent.TimeUnit; +import okhttp3.OkHttpClient; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +class SarvamTtsServiceTest { + + private MockWebServer server; + private SarvamTtsService ttsService; + + @BeforeEach + void setUp() throws IOException { + server = new MockWebServer(); + server.start(); + + SarvamAiConfig config = + SarvamAiConfig.builder() + .apiKey("test-tts-key") + .ttsEndpoint(server.url("/text-to-speech").toString()) + .ttsModel("bulbul:v3") + .ttsSpeaker("shubh") + .build(); + + ttsService = new SarvamTtsService(config, new OkHttpClient()); + } + + @AfterEach + void tearDown() throws IOException { + server.shutdown(); + } + + @Test + void synthesize_success() throws InterruptedException { + byte[] expectedAudio = "fake-audio-data".getBytes(); + String base64Audio = Base64.getEncoder().encodeToString(expectedAudio); + String responseBody = + String.format("{\"request_id\":\"req-456\",\"audios\":[\"%s\"]}", base64Audio); + + server.enqueue(new MockResponse().setBody(responseBody)); + + byte[] audio = ttsService.synthesize("Hello world", "en-IN"); + + assertThat(audio).isEqualTo(expectedAudio); + + RecordedRequest recorded = server.takeRequest(5, TimeUnit.SECONDS); + assertThat(recorded.getHeader("api-subscription-key")).isEqualTo("test-tts-key"); + String body = recorded.getBody().readUtf8(); + assertThat(body).contains("\"model\":\"bulbul:v3\""); + 
assertThat(body).contains("\"speaker\":\"shubh\""); + assertThat(body).contains("\"target_language_code\":\"en-IN\""); + } + + @Test + void synthesize_serverError_throwsException() { + server.enqueue(new MockResponse().setResponseCode(500).setBody("Server error")); + + assertThrows(SarvamAiException.class, () -> ttsService.synthesize("Hello", "en-IN")); + } + + @Test + void synthesize_emptyAudio_throwsException() { + server.enqueue(new MockResponse().setBody("{\"request_id\":\"req-789\",\"audios\":[]}")); + + assertThrows(SarvamAiException.class, () -> ttsService.synthesize("Hello", "en-IN")); + } + + @Test + void synthesize_nullText_throwsNpe() { + assertThrows(NullPointerException.class, () -> ttsService.synthesize(null, "en-IN")); + } + + @Test + void isAvailable_returnsTrue() { + assertThat(ttsService.isAvailable()).isTrue(); + } +} diff --git a/core/src/main/java/com/google/adk/models/Sarvam.java b/core/src/main/java/com/google/adk/models/Sarvam.java deleted file mode 100644 index 65dc61443..000000000 --- a/core/src/main/java/com/google/adk/models/Sarvam.java +++ /dev/null @@ -1,335 +0,0 @@ -/* - * Copyright 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -// MODIFIED BY Sandeep Belgavi, 2026-02-11 -package com.google.adk.models; - -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.node.ArrayNode; -import com.fasterxml.jackson.databind.node.ObjectNode; -import com.google.adk.tools.BaseTool; -import com.google.common.base.Strings; -import com.google.genai.types.Blob; -import com.google.genai.types.Content; -import com.google.genai.types.Part; -import io.reactivex.rxjava3.core.BackpressureStrategy; -import io.reactivex.rxjava3.core.Completable; -import io.reactivex.rxjava3.core.Flowable; -import io.reactivex.rxjava3.subjects.PublishSubject; -import java.io.BufferedReader; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import okhttp3.MediaType; -import okhttp3.OkHttpClient; -import okhttp3.Request; -import okhttp3.RequestBody; -import okhttp3.Response; - -/** - * Sarvam AI LLM implementation. Uses the OpenAI-compatible chat completion endpoint. - * - * @author Sandeep Belgavi - * @since 2026-02-11 - */ -public class Sarvam extends BaseLlm { - - private final String apiUrl; - private static final MediaType JSON = MediaType.get("application/json; charset=utf-8"); - - private final String apiKey; - private final OkHttpClient client; - private final ObjectMapper objectMapper; - - public Sarvam(String model) { - this(model, null); - } - - public Sarvam(String model, String apiKey) { - this(model, apiKey, "https://api.sarvam.ai/chat/completions", new OkHttpClient()); - } - - protected Sarvam(String model, String apiKey, String apiUrl, OkHttpClient client) { - super(model); - if (Strings.isNullOrEmpty(apiKey)) { - this.apiKey = System.getenv("SARVAM_API_KEY"); - } else { - this.apiKey = apiKey; - } - - if (Strings.isNullOrEmpty(this.apiKey)) { - // Allow null for testing if mocked client handles it, but typically warn or throw. 
- // throw new IllegalArgumentException("Sarvam API key is required."); - } - - this.apiUrl = apiUrl; - this.client = client; - this.objectMapper = new ObjectMapper(); - } - - @Override - public Flowable generateContent(LlmRequest llmRequest, boolean stream) { - return Flowable.create( - emitter -> { - try { - ObjectNode jsonBody = objectMapper.createObjectNode(); - jsonBody.put("model", model()); - jsonBody.put("stream", stream); - - ArrayNode messages = jsonBody.putArray("messages"); - - // Add system instructions if present - for (String instruction : llmRequest.getSystemInstructions()) { - ObjectNode systemMsg = messages.addObject(); - systemMsg.put("role", "system"); - systemMsg.put("content", instruction); - } - - // Add conversation history - for (Content content : llmRequest.contents()) { - ObjectNode message = messages.addObject(); - String role = content.role().orElse("user"); - // Map "model" to "assistant" for OpenAI compatibility - if ("model".equals(role)) { - role = "assistant"; - } - message.put("role", role); - - StringBuilder textBuilder = new StringBuilder(); - content - .parts() - .ifPresent( - parts -> { - for (Part part : parts) { - part.text().ifPresent(textBuilder::append); - } - }); - message.put("content", textBuilder.toString()); - } - - // Add tool definitions if present - if (llmRequest.tools() != null && !llmRequest.tools().isEmpty()) { - ArrayNode toolsArray = jsonBody.putArray("tools"); - for (BaseTool tool : llmRequest.tools().values()) { - ObjectNode toolNode = toolsArray.addObject(); - toolNode.put("type", "function"); - ObjectNode functionNode = toolNode.putObject("function"); - functionNode.put("name", tool.name()); - functionNode.put("description", tool.description()); - - tool.declaration() - .flatMap(decl -> decl.parameters()) - .ifPresent( - params -> { - try { - String paramsJson = objectMapper.writeValueAsString(params); - functionNode.set("parameters", objectMapper.readTree(paramsJson)); - } catch (Exception e) { - // 
Ignore or log error - } - }); - } - } - - RequestBody body = RequestBody.create(jsonBody.toString(), JSON); - Request request = - new Request.Builder() - .url(apiUrl) - .addHeader("Content-Type", "application/json") - .addHeader("api-subscription-key", apiKey != null ? apiKey : "") - .post(body) - .build(); - - if (stream) { - try (Response response = client.newCall(request).execute()) { - if (!response.isSuccessful()) { - emitter.onError( - new IOException( - "Unexpected code " - + response - + " body: " - + (response.body() != null ? response.body().string() : ""))); - return; - } - - if (response.body() == null) { - emitter.onError(new IOException("Response body is null")); - return; - } - - BufferedReader reader = new BufferedReader(response.body().charStream()); - String line; - while ((line = reader.readLine()) != null) { - if (line.startsWith("data: ")) { - String data = line.substring(6).trim(); - if ("[DONE]".equals(data)) { - break; - } - try { - JsonNode chunk = objectMapper.readTree(data); - JsonNode choices = chunk.path("choices"); - if (choices.isArray() && choices.size() > 0) { - JsonNode delta = choices.get(0).path("delta"); - if (delta.has("content")) { - String contentPart = delta.get("content").asText(); - - Content content = - Content.builder() - .role("model") - .parts(Part.fromText(contentPart)) - .build(); - - LlmResponse llmResponse = - LlmResponse.builder().content(content).partial(true).build(); - emitter.onNext(llmResponse); - } - } - } catch (Exception e) { - // Ignore parse errors for keep-alive or malformed lines - } - } - } - emitter.onComplete(); - } - } else { - try (Response response = client.newCall(request).execute()) { - if (!response.isSuccessful()) { - emitter.onError( - new IOException( - "Unexpected code " - + response - + " body: " - + (response.body() != null ? 
response.body().string() : ""))); - return; - } - if (response.body() == null) { - emitter.onError(new IOException("Response body is null")); - return; - } - String responseBody = response.body().string(); - JsonNode root = objectMapper.readTree(responseBody); - JsonNode choices = root.path("choices"); - if (choices.isArray() && choices.size() > 0) { - JsonNode message = choices.get(0).path("message"); - String contentText = message.path("content").asText(); - - Content content = - Content.builder().role("model").parts(Part.fromText(contentText)).build(); - - LlmResponse llmResponse = LlmResponse.builder().content(content).build(); - emitter.onNext(llmResponse); - emitter.onComplete(); - } else { - emitter.onError(new IOException("Empty choices in response")); - } - } - } - } catch (Exception e) { - emitter.onError(e); - } - }, - BackpressureStrategy.BUFFER); - } - - @Override - public BaseLlmConnection connect(LlmRequest llmRequest) { - return new SarvamConnection(llmRequest); - } - - private class SarvamConnection implements BaseLlmConnection { - private final LlmRequest initialRequest; - private final List history = new ArrayList<>(); - private final PublishSubject responseSubject = PublishSubject.create(); - - public SarvamConnection(LlmRequest llmRequest) { - this.initialRequest = llmRequest; - this.history.addAll(llmRequest.contents()); - } - - @Override - public Completable sendContent(Content content) { - return Completable.fromAction( - () -> { - history.add(content); - generate(); - }); - } - - @Override - public Completable sendHistory(List history) { - return Completable.fromAction( - () -> { - this.history.clear(); - this.history.addAll(history); - generate(); - }); - } - - @Override - public Completable sendRealtime(Blob blob) { - return Completable.error( - new UnsupportedOperationException("Realtime not supported for Sarvam")); - } - - private void generate() { - LlmRequest.Builder builder = - LlmRequest.builder().contents(new 
ArrayList<>(history)).tools(initialRequest.tools()); - builder.appendInstructions(initialRequest.getSystemInstructions()); - LlmRequest request = builder.build(); - - StringBuilder fullContent = new StringBuilder(); - generateContent(request, true) - .subscribe( - response -> { - responseSubject.onNext(response); - response - .content() - .flatMap(Content::parts) - .ifPresent( - parts -> { - for (Part part : parts) { - part.text().ifPresent(fullContent::append); - } - }); - }, - responseSubject::onError, - () -> { - Content responseContent = - Content.builder() - .role("model") - .parts(Part.fromText(fullContent.toString())) - .build(); - history.add(responseContent); - }); - } - - @Override - public Flowable receive() { - return responseSubject.toFlowable(BackpressureStrategy.BUFFER); - } - - @Override - public void close() { - responseSubject.onComplete(); - } - - @Override - public void close(Throwable throwable) { - responseSubject.onError(throwable); - } - } -} diff --git a/core/src/main/java/com/google/adk/transcription/strategy/SarvamTranscriptionService.java b/core/src/main/java/com/google/adk/transcription/strategy/SarvamTranscriptionService.java deleted file mode 100644 index 3228d2eb8..000000000 --- a/core/src/main/java/com/google/adk/transcription/strategy/SarvamTranscriptionService.java +++ /dev/null @@ -1,174 +0,0 @@ -/* - * Copyright 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -// MODIFIED BY Sandeep Belgavi, 2026-02-11 -package com.google.adk.transcription.strategy; - -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.adk.transcription.ServiceHealth; -import com.google.adk.transcription.ServiceType; -import com.google.adk.transcription.TranscriptionConfig; -import com.google.adk.transcription.TranscriptionEvent; -import com.google.adk.transcription.TranscriptionException; -import com.google.adk.transcription.TranscriptionResult; -import com.google.adk.transcription.TranscriptionService; -import com.google.adk.transcription.processor.AudioChunkAggregator; -import com.google.common.base.Strings; -import io.reactivex.rxjava3.core.Flowable; -import io.reactivex.rxjava3.core.Single; -import java.io.IOException; -import java.time.Duration; -import java.util.concurrent.TimeUnit; -import okhttp3.MediaType; -import okhttp3.MultipartBody; -import okhttp3.OkHttpClient; -import okhttp3.Request; -import okhttp3.RequestBody; -import okhttp3.Response; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Sarvam AI transcription service implementation. - * - * @author Sandeep Belgavi - * @since 2026-02-11 - */ -public class SarvamTranscriptionService implements TranscriptionService { - private static final Logger logger = LoggerFactory.getLogger(SarvamTranscriptionService.class); - private static final String API_URL = "https://api.sarvam.ai/speech-to-text"; - - private final OkHttpClient client; - private final String apiKey; - private final ObjectMapper objectMapper; - - public SarvamTranscriptionService() { - this(null); - } - - public SarvamTranscriptionService(String apiKey) { - if (Strings.isNullOrEmpty(apiKey)) { - this.apiKey = System.getenv("SARVAM_API_KEY"); - } else { - this.apiKey = apiKey; - } - - if (Strings.isNullOrEmpty(this.apiKey)) { - logger.warn("Sarvam API key not found. 
STT will fail."); - } - - this.client = - new OkHttpClient.Builder() - .connectTimeout(30, TimeUnit.SECONDS) - .readTimeout(60, TimeUnit.SECONDS) - .build(); - this.objectMapper = new ObjectMapper(); - } - - @Override - public TranscriptionResult transcribe(byte[] audioData, TranscriptionConfig requestConfig) - throws TranscriptionException { - try { - RequestBody fileBody = RequestBody.create(audioData, MediaType.parse("audio/wav")); - - MultipartBody requestBody = - new MultipartBody.Builder() - .setType(MultipartBody.FORM) - .addFormDataPart("file", "audio.wav", fileBody) - .addFormDataPart("model", "saaras_v3") - .addFormDataPart("language_code", requestConfig.getLanguage()) - .build(); - - Request request = - new Request.Builder() - .url(API_URL) - .addHeader("api-subscription-key", apiKey) - .post(requestBody) - .build(); - - try (Response response = client.newCall(request).execute()) { - if (!response.isSuccessful()) { - String errorBody = response.body() != null ? response.body().string() : ""; - throw new IOException("Unexpected code " + response + " body: " + errorBody); - } - - JsonNode root = objectMapper.readTree(response.body().string()); - String transcript = root.path("transcript").asText(); - - return TranscriptionResult.builder() - .text(transcript) - .timestamp(System.currentTimeMillis()) - .build(); - } - } catch (Exception e) { - logger.error("Error transcribing audio with Sarvam", e); - throw new TranscriptionException("Transcription failed", e); - } - } - - @Override - public Single transcribeAsync( - byte[] audioData, TranscriptionConfig requestConfig) { - return Single.fromCallable(() -> transcribe(audioData, requestConfig)) - .subscribeOn(io.reactivex.rxjava3.schedulers.Schedulers.io()); - } - - @Override - public Flowable transcribeStream( - Flowable audioStream, TranscriptionConfig requestConfig) { - AudioChunkAggregator aggregator = - new AudioChunkAggregator( - requestConfig.getAudioFormat(), 
Duration.ofMillis(requestConfig.getChunkSizeMs())); - - return audioStream - .buffer(requestConfig.getChunkSizeMs(), TimeUnit.MILLISECONDS) - .map( - chunks -> { - byte[] aggregated = aggregator.aggregate(chunks); - try { - TranscriptionResult result = transcribe(aggregated, requestConfig); - return mapToTranscriptionEvent(result); - } catch (TranscriptionException e) { - logger.error("Stream transcription error", e); - throw new RuntimeException(e); - } - }); - } - - @Override - public boolean isAvailable() { - return !Strings.isNullOrEmpty(apiKey); - } - - @Override - public ServiceType getServiceType() { - return ServiceType.SARVAM; - } - - @Override - public ServiceHealth getHealth() { - return ServiceHealth.builder().available(isAvailable()).serviceType(ServiceType.SARVAM).build(); - } - - private TranscriptionEvent mapToTranscriptionEvent(TranscriptionResult result) { - return TranscriptionEvent.builder() - .text(result.getText()) - .finished(true) - .timestamp(result.getTimestamp()) - .build(); - } -} diff --git a/core/src/main/java/com/google/adk/transcription/strategy/TranscriptionServiceFactory.java b/core/src/main/java/com/google/adk/transcription/strategy/TranscriptionServiceFactory.java index d271b06af..9260e7849 100644 --- a/core/src/main/java/com/google/adk/transcription/strategy/TranscriptionServiceFactory.java +++ b/core/src/main/java/com/google/adk/transcription/strategy/TranscriptionServiceFactory.java @@ -85,7 +85,9 @@ private static TranscriptionService createService(TranscriptionConfig config) { switch (serviceType) { case SARVAM: - return new SarvamTranscriptionService(config.getApiKey().orElse(null)); + throw new UnsupportedOperationException( + "Sarvam STT has moved to the contrib/sarvam-ai module. 
" + + "Use SarvamSttService from com.google.adk.models.sarvamai.stt instead."); case WHISPER: return createWhisperService(config); diff --git a/core/src/test/java/com/google/adk/models/SarvamIT.java b/core/src/test/java/com/google/adk/models/SarvamIT.java deleted file mode 100644 index dfc355f20..000000000 --- a/core/src/test/java/com/google/adk/models/SarvamIT.java +++ /dev/null @@ -1,49 +0,0 @@ -package com.google.adk.models; - -import static com.google.common.truth.Truth.assertThat; -import static org.junit.Assume.assumeNotNull; - -import com.google.genai.types.Content; -import com.google.genai.types.Part; -import io.reactivex.rxjava3.subscribers.TestSubscriber; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -@RunWith(JUnit4.class) -public class SarvamIT { - - private String apiKey; - - @Before - public void setUp() { - apiKey = System.getenv("SARVAM_API_KEY"); - // Skip test if API key is not set - assumeNotNull(apiKey); - } - - @Test - public void testGenerateContent() { - Sarvam sarvam = new Sarvam("sarvam-2.0", apiKey); - - LlmRequest request = - LlmRequest.builder() - .contents( - java.util.Collections.singletonList( - Content.builder() - .role("user") - .parts(Part.fromText("Hello, say hi back!")) - .build())) - .build(); - - TestSubscriber subscriber = sarvam.generateContent(request, false).test(); - - subscriber.awaitDone(30, java.util.concurrent.TimeUnit.SECONDS); - subscriber.assertNoErrors(); - subscriber.assertValueCount(1); - - LlmResponse response = subscriber.values().get(0); - assertThat(response.content().flatMap(Content::parts).get().get(0).text().get()).isNotEmpty(); - } -} diff --git a/core/src/test/java/com/google/adk/models/SarvamTest.java b/core/src/test/java/com/google/adk/models/SarvamTest.java deleted file mode 100644 index 12b06a761..000000000 --- a/core/src/test/java/com/google/adk/models/SarvamTest.java +++ /dev/null @@ -1,106 +0,0 @@ -package 
com.google.adk.models; - -import static com.google.common.truth.Truth.assertThat; - -import com.google.genai.types.Content; -import com.google.genai.types.Part; -import io.reactivex.rxjava3.subscribers.TestSubscriber; -import java.io.IOException; -import okhttp3.OkHttpClient; -import okhttp3.mockwebserver.MockResponse; -import okhttp3.mockwebserver.MockWebServer; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -@RunWith(JUnit4.class) -public class SarvamTest { - - private MockWebServer mockWebServer; - private Sarvam sarvam; - - @Before - public void setUp() throws IOException { - mockWebServer = new MockWebServer(); - mockWebServer.start(); - // Use the protected constructor to inject the mock server URL and client - sarvam = - new Sarvam("sarvam-2.0", "fake-key", mockWebServer.url("/").toString(), new OkHttpClient()); - } - - @After - public void tearDown() throws IOException { - mockWebServer.shutdown(); - } - - @Test - public void generateContent_nonStreaming_success() { - String jsonResponse = "{\"choices\": [{\"message\": {\"content\": \"Hello world\"}}]}"; - mockWebServer.enqueue(new MockResponse().setBody(jsonResponse)); - - LlmRequest request = - LlmRequest.builder() - .contents( - java.util.Collections.singletonList( - Content.builder().role("user").parts(Part.fromText("Hi")).build())) - .build(); - - TestSubscriber subscriber = sarvam.generateContent(request, false).test(); - - subscriber.awaitDone(5, java.util.concurrent.TimeUnit.SECONDS); - subscriber.assertNoErrors(); - subscriber.assertValueCount(1); - - LlmResponse response = subscriber.values().get(0); - assertThat(response.content().flatMap(Content::parts).get().get(0).text().get()) - .isEqualTo("Hello world"); - } - - @Test - public void generateContent_streaming_success() { - String chunk1 = "data: {\"choices\": [{\"delta\": {\"content\": \"Hello\"}}]}\n\n"; - String chunk2 = "data: {\"choices\": 
[{\"delta\": {\"content\": \" world\"}}]}\n\n"; - String done = "data: [DONE]\n\n"; - - mockWebServer.enqueue(new MockResponse().setBody(chunk1 + chunk2 + done)); - - LlmRequest request = - LlmRequest.builder() - .contents( - java.util.Collections.singletonList( - Content.builder().role("user").parts(Part.fromText("Hi")).build())) - .build(); - - TestSubscriber subscriber = sarvam.generateContent(request, true).test(); - - subscriber.awaitDone(5, java.util.concurrent.TimeUnit.SECONDS); - subscriber.assertNoErrors(); - subscriber.assertValueCount(2); - - assertThat( - subscriber.values().get(0).content().flatMap(Content::parts).get().get(0).text().get()) - .isEqualTo("Hello"); - assertThat( - subscriber.values().get(1).content().flatMap(Content::parts).get().get(0).text().get()) - .isEqualTo(" world"); - } - - @Test - public void generateContent_error() { - mockWebServer.enqueue(new MockResponse().setResponseCode(500).setBody("Internal Error")); - - LlmRequest request = - LlmRequest.builder() - .contents( - java.util.Collections.singletonList( - Content.builder().role("user").parts(Part.fromText("Hi")).build())) - .build(); - - TestSubscriber subscriber = sarvam.generateContent(request, false).test(); - - subscriber.awaitDone(5, java.util.concurrent.TimeUnit.SECONDS); - subscriber.assertError(IOException.class); - } -} From 7c7eb0b00ecf0c298367355860bc0d649e0215c8 Mon Sep 17 00:00:00 2001 From: Sandeep Belgavi Date: Fri, 20 Feb 2026 15:03:51 +0530 Subject: [PATCH 07/11] Add SarvamBaseLM for Sarvam AI model integration Implements BaseLlm for Sarvam AI's OpenAI-compatible chat completions API with support for both streaming (SSE) and non-streaming modes, tool/function calling, and token usage tracking. Configurable via SARVAM_API_BASE and SARVAM_API_KEY env vars. 
Co-authored-by: Cursor --- .../com/google/adk/models/SarvamBaseLM.java | 652 ++++++++++++++++++ 1 file changed, 652 insertions(+) create mode 100644 core/src/main/java/com/google/adk/models/SarvamBaseLM.java diff --git a/core/src/main/java/com/google/adk/models/SarvamBaseLM.java b/core/src/main/java/com/google/adk/models/SarvamBaseLM.java new file mode 100644 index 000000000..ab9cb6f85 --- /dev/null +++ b/core/src/main/java/com/google/adk/models/SarvamBaseLM.java @@ -0,0 +1,652 @@ +package com.google.adk.models; + +import static com.google.adk.models.RedbusADG.cleanForIdentifierPattern; +import static com.google.common.collect.ImmutableList.toImmutableList; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.datatype.jdk8.Jdk8Module; +import com.google.adk.tools.BaseTool; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.genai.types.Content; +import com.google.genai.types.FunctionCall; +import com.google.genai.types.FunctionDeclaration; +import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.GenerateContentResponseUsageMetadata; +import com.google.genai.types.Part; +import com.google.genai.types.Schema; +import io.reactivex.rxjava3.core.Flowable; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.io.OutputStreamWriter; +import java.net.HttpURLConnection; +import java.net.URL; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.json.JSONArray; +import org.json.JSONObject; +import org.slf4j.Logger; +import 
org.slf4j.LoggerFactory; + +/** + * BaseLlm implementation for Sarvam AI models. + * + *

Sarvam AI exposes an OpenAI-compatible chat completions API. The base URL is read from the + * {@code SARVAM_API_BASE} environment variable (default {@code https://api.sarvam.ai/v1}) and the + * API key from {@code SARVAM_API_KEY}. + * + * @author Sandeep Belgavi + */ +public class SarvamBaseLM extends BaseLlm { + + public static final String SARVAM_API_BASE_ENV = "SARVAM_API_BASE"; + public static final String SARVAM_API_KEY_ENV = "SARVAM_API_KEY"; + private static final String DEFAULT_BASE_URL = "https://api.sarvam.ai/v1"; + + private final String baseUrl; + private static final Logger logger = LoggerFactory.getLogger(SarvamBaseLM.class); + + private static final String CONTINUE_OUTPUT_MESSAGE = + "Continue output. DO NOT look at this line. ONLY look at the content before this line and" + + " system instruction."; + + public SarvamBaseLM(String model) { + super(model); + this.baseUrl = null; + } + + public SarvamBaseLM(String model, String baseUrl) { + super(model); + this.baseUrl = baseUrl; + } + + private String resolveBaseUrl() { + if (baseUrl != null) { + return baseUrl; + } + String envUrl = System.getenv(SARVAM_API_BASE_ENV); + return envUrl != null ? 
envUrl : DEFAULT_BASE_URL; + } + + private String resolveApiKey() { + return System.getenv(SARVAM_API_KEY_ENV); + } + + @Override + public Flowable generateContent(LlmRequest llmRequest, boolean stream) { + if (stream) { + return generateContentStream(llmRequest); + } + + List contents = llmRequest.contents(); + if (contents.isEmpty() || !Iterables.getLast(contents).role().orElse("").equals("user")) { + Content userContent = Content.fromParts(Part.fromText(CONTINUE_OUTPUT_MESSAGE)); + contents = + Stream.concat(contents.stream(), Stream.of(userContent)).collect(toImmutableList()); + } + + String systemText = extractSystemText(llmRequest); + JSONArray messages = buildMessages(systemText, llmRequest.contents()); + JSONArray functions = buildTools(llmRequest); + + boolean lastRespToolExecuted = + Iterables.getLast(Iterables.getLast(contents).parts().get()).functionResponse().isPresent(); + + float temperature = + llmRequest.config().flatMap(GenerateContentConfig::temperature).orElse(0.7f); + + JSONObject response = + callChatCompletions( + this.model(), + messages, + lastRespToolExecuted ? null : (functions.length() > 0 ? 
functions : null), + temperature, + false); + + GenerateContentResponseUsageMetadata usageMetadata = extractUsageMetadata(response); + + JSONArray choices = response.optJSONArray("choices"); + if (choices == null || choices.length() == 0) { + logger.error("Sarvam API returned no choices: {}", response); + return Flowable.just( + LlmResponse.builder() + .content(Content.builder().role("model").parts(Part.fromText("")).build()) + .build()); + } + + JSONObject message = choices.getJSONObject(0).getJSONObject("message"); + Part part = openAiMessageToPart(message); + + LlmResponse.Builder responseBuilder = LlmResponse.builder(); + + if (part.functionCall().isPresent()) { + responseBuilder.content( + Content.builder() + .role("model") + .parts(ImmutableList.of(Part.builder().functionCall(part.functionCall().get()).build())) + .build()); + } else { + responseBuilder.content( + Content.builder().role("model").parts(ImmutableList.of(part)).build()); + } + + if (usageMetadata != null) { + responseBuilder.usageMetadata(usageMetadata); + } + + return Flowable.just(responseBuilder.build()); + } + + private Flowable generateContentStream(LlmRequest llmRequest) { + List contents = llmRequest.contents(); + if (contents.isEmpty() || !Iterables.getLast(contents).role().orElse("").equals("user")) { + Content userContent = Content.fromParts(Part.fromText(CONTINUE_OUTPUT_MESSAGE)); + contents = + Stream.concat(contents.stream(), Stream.of(userContent)).collect(toImmutableList()); + } + + String systemText = extractSystemText(llmRequest); + JSONArray messages = buildMessages(systemText, llmRequest.contents()); + JSONArray functions = buildTools(llmRequest); + + final List finalContents = contents; + boolean lastRespToolExecuted = + Iterables.getLast(Iterables.getLast(finalContents).parts().get()) + .functionResponse() + .isPresent(); + + float temperature = + llmRequest.config().flatMap(GenerateContentConfig::temperature).orElse(0.7f); + + final StringBuilder accumulatedText = new 
StringBuilder(); + final StringBuilder functionCallName = new StringBuilder(); + final StringBuilder functionCallArgs = new StringBuilder(); + final AtomicBoolean inFunctionCall = new AtomicBoolean(false); + final AtomicBoolean streamCompleted = new AtomicBoolean(false); + final AtomicInteger inputTokens = new AtomicInteger(0); + final AtomicInteger outputTokens = new AtomicInteger(0); + + return Flowable.generate( + () -> + callChatCompletionsStream( + this.model(), + messages, + lastRespToolExecuted ? null : (functions.length() > 0 ? functions : null), + temperature), + (reader, emitter) -> { + try { + if (reader == null || streamCompleted.get()) { + emitter.onComplete(); + return; + } + + String line = reader.readLine(); + if (line == null) { + if (accumulatedText.length() > 0) { + emitter.onNext(createTextResponse(accumulatedText.toString(), false)); + } + emitter.onComplete(); + return; + } + + if (line.isEmpty() || line.equals("data: [DONE]")) { + if (line.equals("data: [DONE]")) { + streamCompleted.set(true); + GenerateContentResponseUsageMetadata usageMetadata = + buildUsageMetadata(inputTokens.get(), outputTokens.get()); + + if (inFunctionCall.get() && functionCallName.length() > 0) { + try { + Map args = new JSONObject(functionCallArgs.toString()).toMap(); + FunctionCall fc = + FunctionCall.builder().name(functionCallName.toString()).args(args).build(); + Part part = Part.builder().functionCall(fc).build(); + + LlmResponse.Builder funcResponseBuilder = + LlmResponse.builder() + .content( + Content.builder() + .role("model") + .parts(ImmutableList.of(part)) + .build()); + if (usageMetadata != null) { + funcResponseBuilder.usageMetadata(usageMetadata); + } + emitter.onNext(funcResponseBuilder.build()); + } catch (Exception funcEx) { + logger.error("Error creating function call response", funcEx); + } + } else if (accumulatedText.length() > 0) { + LlmResponse.Builder finalBuilder = + LlmResponse.builder() + .content( + Content.builder() + .role("model") + 
.parts(Part.fromText(accumulatedText.toString())) + .build()) + .partial(false); + if (usageMetadata != null) { + finalBuilder.usageMetadata(usageMetadata); + } + emitter.onNext(finalBuilder.build()); + } + emitter.onComplete(); + } + return; + } + + if (!line.startsWith("data: ")) { + return; + } + + String jsonStr = line.substring(6); + JSONObject chunk; + try { + chunk = new JSONObject(jsonStr); + } catch (Exception parseEx) { + logger.warn("Failed to parse Sarvam SSE chunk: {}", jsonStr, parseEx); + return; + } + + if (chunk.has("usage")) { + JSONObject usage = chunk.getJSONObject("usage"); + inputTokens.set(usage.optInt("prompt_tokens", 0)); + outputTokens.set(usage.optInt("completion_tokens", 0)); + } + + JSONArray choices = chunk.optJSONArray("choices"); + if (choices == null || choices.length() == 0) { + return; + } + + JSONObject choice = choices.getJSONObject(0); + JSONObject delta = choice.optJSONObject("delta"); + if (delta == null) { + return; + } + + if (delta.has("content") && !delta.isNull("content")) { + String text = delta.getString("content"); + if (!text.isEmpty()) { + accumulatedText.append(text); + emitter.onNext(createTextResponse(text, true)); + } + } + + if (delta.has("tool_calls")) { + inFunctionCall.set(true); + JSONArray toolCalls = delta.getJSONArray("tool_calls"); + if (toolCalls.length() > 0) { + JSONObject toolCall = toolCalls.getJSONObject(0); + JSONObject function = toolCall.optJSONObject("function"); + if (function != null) { + if (function.has("name") && !function.isNull("name")) { + functionCallName.append(function.getString("name")); + } + if (function.has("arguments") && !function.isNull("arguments")) { + functionCallArgs.append(function.getString("arguments")); + } + } + } + } + } catch (Exception e) { + logger.error("Error in Sarvam streaming", e); + emitter.onError(e); + } + }, + reader -> { + try { + if (reader != null) { + reader.close(); + } + } catch (IOException e) { + logger.error("Error closing stream reader", e); + 
} + }); + } + + private String extractSystemText(LlmRequest llmRequest) { + Optional configOpt = llmRequest.config(); + if (configOpt.isPresent()) { + Optional systemInstructionOpt = configOpt.get().systemInstruction(); + if (systemInstructionOpt.isPresent()) { + String text = + systemInstructionOpt.get().parts().orElse(ImmutableList.of()).stream() + .filter(p -> p.text().isPresent()) + .map(p -> p.text().get()) + .collect(Collectors.joining("\n")); + if (!text.isEmpty()) { + return text; + } + } + } + return ""; + } + + private JSONArray buildMessages(String systemText, List contents) { + JSONArray messages = new JSONArray(); + + if (!systemText.isEmpty()) { + JSONObject systemMsg = new JSONObject(); + systemMsg.put("role", "system"); + systemMsg.put("content", systemText); + messages.put(systemMsg); + } + + for (Content item : contents) { + JSONObject msg = new JSONObject(); + String role = item.role().orElse("user"); + msg.put("role", role.equals("model") ? "assistant" : role); + + if (item.parts().isPresent() && !item.parts().get().isEmpty()) { + Part firstPart = item.parts().get().get(0); + if (firstPart.functionResponse().isPresent()) { + msg.put( + "content", + new JSONObject(firstPart.functionResponse().get().response().get()).toString()); + msg.put("role", "tool"); + msg.put("tool_call_id", firstPart.functionResponse().get().name().orElse("unknown")); + } else { + msg.put("content", item.text()); + } + } else { + msg.put("content", item.text()); + } + messages.put(msg); + } + return messages; + } + + private JSONArray buildTools(LlmRequest llmRequest) { + JSONArray functions = new JSONArray(); + llmRequest + .tools() + .forEach( + (name, baseTool) -> { + Optional declOpt = baseTool.declaration(); + if (declOpt.isEmpty()) { + logger.warn("Skipping tool '{}' with missing declaration.", baseTool.name()); + return; + } + + FunctionDeclaration decl = declOpt.get(); + Map funcMap = new HashMap<>(); + funcMap.put("name", 
cleanForIdentifierPattern(decl.name().get())); + funcMap.put("description", cleanForIdentifierPattern(decl.description().orElse(""))); + + Optional paramsOpt = decl.parameters(); + if (paramsOpt.isPresent()) { + Schema paramsSchema = paramsOpt.get(); + Map paramsMap = new HashMap<>(); + paramsMap.put("type", "object"); + + Optional> propsOpt = paramsSchema.properties(); + if (propsOpt.isPresent()) { + Map propsMap = new HashMap<>(); + ObjectMapper mapper = new ObjectMapper(); + mapper.registerModule(new Jdk8Module()); + + propsOpt + .get() + .forEach( + (key, schema) -> { + Map schemaMap = + mapper.convertValue( + schema, new TypeReference>() {}); + normalizeTypeStrings(schemaMap); + propsMap.put(key, schemaMap); + }); + paramsMap.put("properties", propsMap); + } + + paramsSchema + .required() + .ifPresent(requiredList -> paramsMap.put("required", requiredList)); + funcMap.put("parameters", paramsMap); + } + + JSONObject toolWrapper = new JSONObject(); + toolWrapper.put("type", "function"); + toolWrapper.put("function", new JSONObject(funcMap)); + functions.put(toolWrapper); + }); + return functions; + } + + private JSONObject callChatCompletions( + String model, JSONArray messages, JSONArray tools, float temperature, boolean stream) { + try { + String apiUrl = resolveBaseUrl() + "/chat/completions"; + String apiKey = resolveApiKey(); + + JSONObject payload = new JSONObject(); + payload.put("model", model); + payload.put("messages", messages); + payload.put("temperature", temperature); + payload.put("stream", stream); + + if (tools != null && tools.length() > 0) { + payload.put("tools", tools); + payload.put("tool_choice", "auto"); + } + + String jsonString = payload.toString(); + logger.debug("Sarvam request: {}", jsonString); + + URL url = new URL(apiUrl); + HttpURLConnection conn = (HttpURLConnection) url.openConnection(); + conn.setRequestMethod("POST"); + conn.setRequestProperty("Content-Type", "application/json; charset=UTF-8"); + if (apiKey != null && 
!apiKey.isEmpty()) { + conn.setRequestProperty("Authorization", "Bearer " + apiKey); + } + conn.setDoOutput(true); + conn.setFixedLengthStreamingMode(jsonString.getBytes("UTF-8").length); + + try (OutputStream os = conn.getOutputStream(); + OutputStreamWriter writer = new OutputStreamWriter(os, "UTF-8")) { + writer.write(jsonString); + writer.flush(); + } + + int responseCode = conn.getResponseCode(); + logger.info("Sarvam response code: {} for model: {}", responseCode, model); + + InputStream inputStream = + (responseCode < 400) ? conn.getInputStream() : conn.getErrorStream(); + try (BufferedReader reader = + new BufferedReader(new InputStreamReader(inputStream, "UTF-8"))) { + StringBuilder sb = new StringBuilder(); + String line; + while ((line = reader.readLine()) != null) { + sb.append(line); + } + JSONObject responseJson = new JSONObject(sb.toString()); + conn.disconnect(); + return responseJson; + } + } catch (Exception ex) { + logger.error("Error calling Sarvam chat completions API", ex); + return new JSONObject(); + } + } + + private BufferedReader callChatCompletionsStream( + String model, JSONArray messages, JSONArray tools, float temperature) { + try { + String apiUrl = resolveBaseUrl() + "/chat/completions"; + String apiKey = resolveApiKey(); + + JSONObject payload = new JSONObject(); + payload.put("model", model); + payload.put("messages", messages); + payload.put("temperature", temperature); + payload.put("stream", true); + + if (tools != null && tools.length() > 0) { + payload.put("tools", tools); + payload.put("tool_choice", "auto"); + } + + String jsonString = payload.toString(); + + URL url = new URL(apiUrl); + HttpURLConnection conn = (HttpURLConnection) url.openConnection(); + conn.setRequestMethod("POST"); + conn.setRequestProperty("Content-Type", "application/json; charset=UTF-8"); + conn.setRequestProperty("Accept", "text/event-stream"); + if (apiKey != null && !apiKey.isEmpty()) { + conn.setRequestProperty("Authorization", "Bearer " + 
apiKey); + } + conn.setDoOutput(true); + conn.setFixedLengthStreamingMode(jsonString.getBytes("UTF-8").length); + + try (OutputStream os = conn.getOutputStream(); + OutputStreamWriter writer = new OutputStreamWriter(os, "UTF-8")) { + writer.write(jsonString); + writer.flush(); + } + + int responseCode = conn.getResponseCode(); + logger.info("Sarvam streaming response code: {} for model: {}", responseCode, model); + + if (responseCode >= 200 && responseCode < 300) { + return new BufferedReader(new InputStreamReader(conn.getInputStream(), "UTF-8")); + } else { + try (InputStream errorStream = conn.getErrorStream(); + BufferedReader errorReader = + new BufferedReader(new InputStreamReader(errorStream, "UTF-8"))) { + StringBuilder errorResponse = new StringBuilder(); + String errorLine; + while ((errorLine = errorReader.readLine()) != null) { + errorResponse.append(errorLine); + } + logger.error( + "Sarvam streaming request failed: status={} body={}", + responseCode, + errorResponse); + } + conn.disconnect(); + return null; + } + } catch (IOException ex) { + logger.error("Error in Sarvam streaming request", ex); + return null; + } + } + + private LlmResponse createTextResponse(String text, boolean partial) { + return LlmResponse.builder() + .content(Content.builder().role("model").parts(Part.fromText(text)).build()) + .partial(partial) + .build(); + } + + private GenerateContentResponseUsageMetadata extractUsageMetadata(JSONObject response) { + if (response == null || !response.has("usage")) { + return null; + } + try { + JSONObject usage = response.getJSONObject("usage"); + int promptTokens = usage.optInt("prompt_tokens", 0); + int completionTokens = usage.optInt("completion_tokens", 0); + int totalTokens = usage.optInt("total_tokens", promptTokens + completionTokens); + + if (totalTokens > 0 || promptTokens > 0 || completionTokens > 0) { + logger.info( + "Sarvam token counts: prompt={}, completion={}, total={}", + promptTokens, + completionTokens, + totalTokens); + 
return GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(promptTokens) + .candidatesTokenCount(completionTokens) + .totalTokenCount(totalTokens) + .build(); + } + } catch (Exception e) { + logger.warn("Failed to parse token usage from Sarvam response", e); + } + return null; + } + + private GenerateContentResponseUsageMetadata buildUsageMetadata( + int promptTokens, int completionTokens) { + int totalTokens = promptTokens + completionTokens; + if (totalTokens > 0 || promptTokens > 0 || completionTokens > 0) { + return GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(promptTokens) + .candidatesTokenCount(completionTokens) + .totalTokenCount(totalTokens) + .build(); + } + return null; + } + + static Part openAiMessageToPart(JSONObject message) { + if (message.has("tool_calls")) { + JSONArray toolCalls = message.optJSONArray("tool_calls"); + if (toolCalls != null && toolCalls.length() > 0) { + JSONObject toolCall = toolCalls.getJSONObject(0); + JSONObject function = toolCall.optJSONObject("function"); + if (function != null) { + String name = function.optString("name", null); + String argsStr = function.optString("arguments", "{}"); + if (name != null) { + Map args = new JSONObject(argsStr).toMap(); + FunctionCall fc = FunctionCall.builder().name(name).args(args).build(); + return Part.builder().functionCall(fc).build(); + } + } + } + } + + if (message.has("content") && !message.isNull("content")) { + return Part.builder().text(message.getString("content")).build(); + } + + return Part.builder().text("").build(); + } + + private void normalizeTypeStrings(Map valueDict) { + if (valueDict == null) { + return; + } + if (valueDict.containsKey("type")) { + valueDict.put("type", ((String) valueDict.get("type")).toLowerCase()); + } + if (valueDict.containsKey("items")) { + Object items = valueDict.get("items"); + if (items instanceof Map) { + normalizeTypeStrings((Map) items); + Map itemsMap = (Map) items; + if 
(itemsMap.containsKey("properties")) { + Map properties = (Map) itemsMap.get("properties"); + if (properties != null) { + for (Object value : properties.values()) { + if (value instanceof Map) { + normalizeTypeStrings((Map) value); + } + } + } + } + } + } + } + + @Override + public BaseLlmConnection connect(LlmRequest llmRequest) { + return new GenericLlmConnection(this, llmRequest); + } +} From 900fb0da8334916d62f498501f0dfaeca27e415d Mon Sep 17 00:00:00 2001 From: Sandeep Belgavi Date: Mon, 23 Feb 2026 16:38:45 +0530 Subject: [PATCH 08/11] Harden SarvamBaseLM with production-grade improvements - Add connect/read timeouts (30s/120s) to prevent hanging connections - Add stream_options.include_usage for token tracking in streaming - Fix function call history: serialize assistant tool_calls as proper OpenAI tool_calls array instead of plain text - Forward max_tokens from GenerateContentConfig - Extract shared HTTP connection setup into openConnection() - Make ObjectMapper a static singleton instead of per-tool allocation - Extract streaming finalization into emitFinalStreamResponse() - Log and surface error responses from non-streaming calls - Add proper instanceof checks in normalizeTypeStrings Co-authored-by: Cursor --- .../com/google/adk/models/SarvamBaseLM.java | 366 +++++++++++------- 1 file changed, 224 insertions(+), 142 deletions(-) diff --git a/core/src/main/java/com/google/adk/models/SarvamBaseLM.java b/core/src/main/java/com/google/adk/models/SarvamBaseLM.java index ab9cb6f85..0de9490c6 100644 --- a/core/src/main/java/com/google/adk/models/SarvamBaseLM.java +++ b/core/src/main/java/com/google/adk/models/SarvamBaseLM.java @@ -53,6 +53,11 @@ public class SarvamBaseLM extends BaseLlm { public static final String SARVAM_API_BASE_ENV = "SARVAM_API_BASE"; public static final String SARVAM_API_KEY_ENV = "SARVAM_API_KEY"; private static final String DEFAULT_BASE_URL = "https://api.sarvam.ai/v1"; + private static final int CONNECT_TIMEOUT_MS = 30_000; + private 
static final int READ_TIMEOUT_MS = 120_000; + + private static final ObjectMapper OBJECT_MAPPER = + new ObjectMapper().registerModule(new Jdk8Module()); private final String baseUrl; private static final Logger logger = LoggerFactory.getLogger(SarvamBaseLM.class); @@ -89,15 +94,10 @@ public Flowable generateContent(LlmRequest llmRequest, boolean stre return generateContentStream(llmRequest); } - List contents = llmRequest.contents(); - if (contents.isEmpty() || !Iterables.getLast(contents).role().orElse("").equals("user")) { - Content userContent = Content.fromParts(Part.fromText(CONTINUE_OUTPUT_MESSAGE)); - contents = - Stream.concat(contents.stream(), Stream.of(userContent)).collect(toImmutableList()); - } + List contents = ensureLastContentIsUser(llmRequest.contents()); String systemText = extractSystemText(llmRequest); - JSONArray messages = buildMessages(systemText, llmRequest.contents()); + JSONArray messages = buildMessages(systemText, contents); JSONArray functions = buildTools(llmRequest); boolean lastRespToolExecuted = @@ -105,6 +105,8 @@ public Flowable generateContent(LlmRequest llmRequest, boolean stre float temperature = llmRequest.config().flatMap(GenerateContentConfig::temperature).orElse(0.7f); + Optional maxTokens = + llmRequest.config().flatMap(GenerateContentConfig::maxOutputTokens); JSONObject response = callChatCompletions( @@ -112,6 +114,7 @@ public Flowable generateContent(LlmRequest llmRequest, boolean stre messages, lastRespToolExecuted ? null : (functions.length() > 0 ? 
functions : null), temperature, + maxTokens.orElse(-1), false); GenerateContentResponseUsageMetadata usageMetadata = extractUsageMetadata(response); @@ -126,19 +129,21 @@ public Flowable generateContent(LlmRequest llmRequest, boolean stre } JSONObject message = choices.getJSONObject(0).getJSONObject("message"); - Part part = openAiMessageToPart(message); + List parts = openAiMessageToParts(message); LlmResponse.Builder responseBuilder = LlmResponse.builder(); - if (part.functionCall().isPresent()) { + boolean hasFunctionCall = parts.stream().anyMatch(p -> p.functionCall().isPresent()); + if (hasFunctionCall) { + Part fcPart = parts.stream().filter(p -> p.functionCall().isPresent()).findFirst().get(); responseBuilder.content( Content.builder() .role("model") - .parts(ImmutableList.of(Part.builder().functionCall(part.functionCall().get()).build())) + .parts(ImmutableList.of(fcPart)) .build()); } else { responseBuilder.content( - Content.builder().role("model").parts(ImmutableList.of(part)).build()); + Content.builder().role("model").parts(ImmutableList.copyOf(parts)).build()); } if (usageMetadata != null) { @@ -149,25 +154,21 @@ public Flowable generateContent(LlmRequest llmRequest, boolean stre } private Flowable generateContentStream(LlmRequest llmRequest) { - List contents = llmRequest.contents(); - if (contents.isEmpty() || !Iterables.getLast(contents).role().orElse("").equals("user")) { - Content userContent = Content.fromParts(Part.fromText(CONTINUE_OUTPUT_MESSAGE)); - contents = - Stream.concat(contents.stream(), Stream.of(userContent)).collect(toImmutableList()); - } + List contents = ensureLastContentIsUser(llmRequest.contents()); String systemText = extractSystemText(llmRequest); - JSONArray messages = buildMessages(systemText, llmRequest.contents()); + JSONArray messages = buildMessages(systemText, contents); JSONArray functions = buildTools(llmRequest); - final List finalContents = contents; boolean lastRespToolExecuted = - 
Iterables.getLast(Iterables.getLast(finalContents).parts().get()) + Iterables.getLast(Iterables.getLast(contents).parts().get()) .functionResponse() .isPresent(); float temperature = llmRequest.config().flatMap(GenerateContentConfig::temperature).orElse(0.7f); + Optional maxTokens = + llmRequest.config().flatMap(GenerateContentConfig::maxOutputTokens); final StringBuilder accumulatedText = new StringBuilder(); final StringBuilder functionCallName = new StringBuilder(); @@ -183,7 +184,8 @@ private Flowable generateContentStream(LlmRequest llmRequest) { this.model(), messages, lastRespToolExecuted ? null : (functions.length() > 0 ? functions : null), - temperature), + temperature, + maxTokens.orElse(-1)), (reader, emitter) -> { try { if (reader == null || streamCompleted.get()) { @@ -193,56 +195,23 @@ private Flowable generateContentStream(LlmRequest llmRequest) { String line = reader.readLine(); if (line == null) { - if (accumulatedText.length() > 0) { - emitter.onNext(createTextResponse(accumulatedText.toString(), false)); - } + emitFinalStreamResponse( + emitter, accumulatedText, inFunctionCall, functionCallName, functionCallArgs, + inputTokens.get(), outputTokens.get()); emitter.onComplete(); return; } - if (line.isEmpty() || line.equals("data: [DONE]")) { - if (line.equals("data: [DONE]")) { - streamCompleted.set(true); - GenerateContentResponseUsageMetadata usageMetadata = - buildUsageMetadata(inputTokens.get(), outputTokens.get()); - - if (inFunctionCall.get() && functionCallName.length() > 0) { - try { - Map args = new JSONObject(functionCallArgs.toString()).toMap(); - FunctionCall fc = - FunctionCall.builder().name(functionCallName.toString()).args(args).build(); - Part part = Part.builder().functionCall(fc).build(); - - LlmResponse.Builder funcResponseBuilder = - LlmResponse.builder() - .content( - Content.builder() - .role("model") - .parts(ImmutableList.of(part)) - .build()); - if (usageMetadata != null) { - 
funcResponseBuilder.usageMetadata(usageMetadata); - } - emitter.onNext(funcResponseBuilder.build()); - } catch (Exception funcEx) { - logger.error("Error creating function call response", funcEx); - } - } else if (accumulatedText.length() > 0) { - LlmResponse.Builder finalBuilder = - LlmResponse.builder() - .content( - Content.builder() - .role("model") - .parts(Part.fromText(accumulatedText.toString())) - .build()) - .partial(false); - if (usageMetadata != null) { - finalBuilder.usageMetadata(usageMetadata); - } - emitter.onNext(finalBuilder.build()); - } - emitter.onComplete(); - } + if (line.isEmpty()) { + return; + } + + if (line.equals("data: [DONE]")) { + streamCompleted.set(true); + emitFinalStreamResponse( + emitter, accumulatedText, inFunctionCall, functionCallName, functionCallArgs, + inputTokens.get(), outputTokens.get()); + emitter.onComplete(); return; } @@ -259,7 +228,7 @@ private Flowable generateContentStream(LlmRequest llmRequest) { return; } - if (chunk.has("usage")) { + if (chunk.has("usage") && !chunk.isNull("usage")) { JSONObject usage = chunk.getJSONObject("usage"); inputTokens.set(usage.optInt("prompt_tokens", 0)); outputTokens.set(usage.optInt("completion_tokens", 0)); @@ -316,22 +285,76 @@ private Flowable generateContentStream(LlmRequest llmRequest) { }); } - private String extractSystemText(LlmRequest llmRequest) { - Optional configOpt = llmRequest.config(); - if (configOpt.isPresent()) { - Optional systemInstructionOpt = configOpt.get().systemInstruction(); - if (systemInstructionOpt.isPresent()) { - String text = - systemInstructionOpt.get().parts().orElse(ImmutableList.of()).stream() - .filter(p -> p.text().isPresent()) - .map(p -> p.text().get()) - .collect(Collectors.joining("\n")); - if (!text.isEmpty()) { - return text; + private void emitFinalStreamResponse( + io.reactivex.rxjava3.core.Emitter emitter, + StringBuilder accumulatedText, + AtomicBoolean inFunctionCall, + StringBuilder functionCallName, + StringBuilder 
functionCallArgs, + int promptTokens, + int completionTokens) { + + GenerateContentResponseUsageMetadata usageMetadata = + buildUsageMetadata(promptTokens, completionTokens); + + if (inFunctionCall.get() && functionCallName.length() > 0) { + try { + String argsString = functionCallArgs.length() > 0 ? functionCallArgs.toString() : "{}"; + Map args = new JSONObject(argsString).toMap(); + FunctionCall fc = + FunctionCall.builder().name(functionCallName.toString()).args(args).build(); + Part part = Part.builder().functionCall(fc).build(); + + LlmResponse.Builder builder = + LlmResponse.builder() + .content( + Content.builder().role("model").parts(ImmutableList.of(part)).build()); + if (usageMetadata != null) { + builder.usageMetadata(usageMetadata); } + emitter.onNext(builder.build()); + } catch (Exception funcEx) { + logger.error("Error creating function call response from stream", funcEx); } + } else if (accumulatedText.length() > 0) { + LlmResponse.Builder builder = + LlmResponse.builder() + .content( + Content.builder() + .role("model") + .parts(Part.fromText(accumulatedText.toString())) + .build()) + .partial(false); + if (usageMetadata != null) { + builder.usageMetadata(usageMetadata); + } + emitter.onNext(builder.build()); + } + } + + // ========== Request Building ========== + + private List ensureLastContentIsUser(List contents) { + if (contents.isEmpty() || !Iterables.getLast(contents).role().orElse("").equals("user")) { + Content userContent = Content.fromParts(Part.fromText(CONTINUE_OUTPUT_MESSAGE)); + return Stream.concat(contents.stream(), Stream.of(userContent)).collect(toImmutableList()); } - return ""; + return contents; + } + + private String extractSystemText(LlmRequest llmRequest) { + return llmRequest + .config() + .flatMap(GenerateContentConfig::systemInstruction) + .flatMap(Content::parts) + .map( + parts -> + parts.stream() + .filter(p -> p.text().isPresent()) + .map(p -> p.text().get()) + .collect(Collectors.joining("\n"))) + .filter(text -> 
!text.isEmpty()) + .orElse(""); } private JSONArray buildMessages(String systemText, List contents) { @@ -345,25 +368,54 @@ private JSONArray buildMessages(String systemText, List contents) { } for (Content item : contents) { - JSONObject msg = new JSONObject(); String role = item.role().orElse("user"); - msg.put("role", role.equals("model") ? "assistant" : role); - - if (item.parts().isPresent() && !item.parts().get().isEmpty()) { - Part firstPart = item.parts().get().get(0); - if (firstPart.functionResponse().isPresent()) { - msg.put( - "content", - new JSONObject(firstPart.functionResponse().get().response().get()).toString()); - msg.put("role", "tool"); - msg.put("tool_call_id", firstPart.functionResponse().get().name().orElse("unknown")); - } else { - msg.put("content", item.text()); - } + List parts = item.parts().orElse(ImmutableList.of()); + + if (parts.isEmpty()) { + JSONObject msg = new JSONObject(); + msg.put("role", role.equals("model") ? "assistant" : role); + msg.put("content", item.text()); + messages.put(msg); + continue; + } + + Part firstPart = parts.get(0); + + if (firstPart.functionResponse().isPresent()) { + JSONObject msg = new JSONObject(); + msg.put("role", "tool"); + msg.put( + "tool_call_id", + firstPart.functionResponse().get().name().orElse("call_unknown")); + msg.put( + "content", + new JSONObject(firstPart.functionResponse().get().response().get()).toString()); + messages.put(msg); + } else if (firstPart.functionCall().isPresent()) { + // Assistant message that previously requested a tool call + FunctionCall fc = firstPart.functionCall().get(); + JSONObject msg = new JSONObject(); + msg.put("role", "assistant"); + msg.put("content", JSONObject.NULL); + + JSONArray toolCalls = new JSONArray(); + JSONObject toolCall = new JSONObject(); + toolCall.put("id", "call_" + fc.name().orElse("unknown")); + toolCall.put("type", "function"); + JSONObject function = new JSONObject(); + function.put("name", fc.name().orElse("")); + 
function.put("arguments", new JSONObject(fc.args().orElse(Map.of())).toString()); + toolCall.put("function", function); + toolCalls.put(toolCall); + msg.put("tool_calls", toolCalls); + + messages.put(msg); } else { + JSONObject msg = new JSONObject(); + msg.put("role", role.equals("model") ? "assistant" : role); msg.put("content", item.text()); + messages.put(msg); } - messages.put(msg); } return messages; } @@ -394,15 +446,12 @@ private JSONArray buildTools(LlmRequest llmRequest) { Optional> propsOpt = paramsSchema.properties(); if (propsOpt.isPresent()) { Map propsMap = new HashMap<>(); - ObjectMapper mapper = new ObjectMapper(); - mapper.registerModule(new Jdk8Module()); - propsOpt .get() .forEach( (key, schema) -> { Map schemaMap = - mapper.convertValue( + OBJECT_MAPPER.convertValue( schema, new TypeReference>() {}); normalizeTypeStrings(schemaMap); propsMap.put(key, schemaMap); @@ -424,8 +473,15 @@ private JSONArray buildTools(LlmRequest llmRequest) { return functions; } + // ========== HTTP Transport ========== + private JSONObject callChatCompletions( - String model, JSONArray messages, JSONArray tools, float temperature, boolean stream) { + String model, + JSONArray messages, + JSONArray tools, + float temperature, + int maxTokens, + boolean stream) { try { String apiUrl = resolveBaseUrl() + "/chat/completions"; String apiKey = resolveApiKey(); @@ -436,22 +492,19 @@ private JSONObject callChatCompletions( payload.put("temperature", temperature); payload.put("stream", stream); + if (maxTokens > 0) { + payload.put("max_tokens", maxTokens); + } + if (tools != null && tools.length() > 0) { payload.put("tools", tools); payload.put("tool_choice", "auto"); } String jsonString = payload.toString(); - logger.debug("Sarvam request: {}", jsonString); - - URL url = new URL(apiUrl); - HttpURLConnection conn = (HttpURLConnection) url.openConnection(); - conn.setRequestMethod("POST"); - conn.setRequestProperty("Content-Type", "application/json; charset=UTF-8"); - if 
(apiKey != null && !apiKey.isEmpty()) { - conn.setRequestProperty("Authorization", "Bearer " + apiKey); - } - conn.setDoOutput(true); + logger.debug("Sarvam request payload size: {} bytes", jsonString.length()); + + HttpURLConnection conn = openConnection(apiUrl, apiKey); conn.setFixedLengthStreamingMode(jsonString.getBytes("UTF-8").length); try (OutputStream os = conn.getOutputStream(); @@ -465,6 +518,7 @@ private JSONObject callChatCompletions( InputStream inputStream = (responseCode < 400) ? conn.getInputStream() : conn.getErrorStream(); + try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, "UTF-8"))) { StringBuilder sb = new StringBuilder(); @@ -472,6 +526,13 @@ private JSONObject callChatCompletions( while ((line = reader.readLine()) != null) { sb.append(line); } + + if (responseCode >= 400) { + logger.error( + "Sarvam API error: status={} body={}", responseCode, sb); + return new JSONObject().put("error", sb.toString()); + } + JSONObject responseJson = new JSONObject(sb.toString()); conn.disconnect(); return responseJson; @@ -483,7 +544,7 @@ private JSONObject callChatCompletions( } private BufferedReader callChatCompletionsStream( - String model, JSONArray messages, JSONArray tools, float temperature) { + String model, JSONArray messages, JSONArray tools, float temperature, int maxTokens) { try { String apiUrl = resolveBaseUrl() + "/chat/completions"; String apiKey = resolveApiKey(); @@ -494,6 +555,15 @@ private BufferedReader callChatCompletionsStream( payload.put("temperature", temperature); payload.put("stream", true); + // Request token usage in streaming responses + JSONObject streamOptions = new JSONObject(); + streamOptions.put("include_usage", true); + payload.put("stream_options", streamOptions); + + if (maxTokens > 0) { + payload.put("max_tokens", maxTokens); + } + if (tools != null && tools.length() > 0) { payload.put("tools", tools); payload.put("tool_choice", "auto"); @@ -501,15 +571,8 @@ private BufferedReader 
callChatCompletionsStream( String jsonString = payload.toString(); - URL url = new URL(apiUrl); - HttpURLConnection conn = (HttpURLConnection) url.openConnection(); - conn.setRequestMethod("POST"); - conn.setRequestProperty("Content-Type", "application/json; charset=UTF-8"); + HttpURLConnection conn = openConnection(apiUrl, apiKey); conn.setRequestProperty("Accept", "text/event-stream"); - if (apiKey != null && !apiKey.isEmpty()) { - conn.setRequestProperty("Authorization", "Bearer " + apiKey); - } - conn.setDoOutput(true); conn.setFixedLengthStreamingMode(jsonString.getBytes("UTF-8").length); try (OutputStream os = conn.getOutputStream(); @@ -533,9 +596,7 @@ private BufferedReader callChatCompletionsStream( errorResponse.append(errorLine); } logger.error( - "Sarvam streaming request failed: status={} body={}", - responseCode, - errorResponse); + "Sarvam streaming failed: status={} body={}", responseCode, errorResponse); } conn.disconnect(); return null; @@ -546,6 +607,22 @@ private BufferedReader callChatCompletionsStream( } } + private HttpURLConnection openConnection(String apiUrl, String apiKey) throws IOException { + URL url = new URL(apiUrl); + HttpURLConnection conn = (HttpURLConnection) url.openConnection(); + conn.setRequestMethod("POST"); + conn.setRequestProperty("Content-Type", "application/json; charset=UTF-8"); + conn.setConnectTimeout(CONNECT_TIMEOUT_MS); + conn.setReadTimeout(READ_TIMEOUT_MS); + conn.setDoOutput(true); + if (apiKey != null && !apiKey.isEmpty()) { + conn.setRequestProperty("Authorization", "Bearer " + apiKey); + } + return conn; + } + + // ========== Response Parsing ========== + private LlmResponse createTextResponse(String text, boolean partial) { return LlmResponse.builder() .content(Content.builder().role("model").parts(Part.fromText(text)).build()) @@ -565,7 +642,7 @@ private GenerateContentResponseUsageMetadata extractUsageMetadata(JSONObject res if (totalTokens > 0 || promptTokens > 0 || completionTokens > 0) { logger.info( - 
"Sarvam token counts: prompt={}, completion={}, total={}", + "Sarvam token usage: prompt={}, completion={}, total={}", promptTokens, completionTokens, totalTokens); @@ -594,7 +671,13 @@ private GenerateContentResponseUsageMetadata buildUsageMetadata( return null; } - static Part openAiMessageToPart(JSONObject message) { + /** + * Converts an OpenAI-format message JSON to ADK Part(s). + * Handles both text content and tool_calls in a single message. + */ + static List openAiMessageToParts(JSONObject message) { + List parts = new ArrayList<>(); + if (message.has("tool_calls")) { JSONArray toolCalls = message.optJSONArray("tool_calls"); if (toolCalls != null && toolCalls.length() > 0) { @@ -606,39 +689,38 @@ static Part openAiMessageToPart(JSONObject message) { if (name != null) { Map args = new JSONObject(argsStr).toMap(); FunctionCall fc = FunctionCall.builder().name(name).args(args).build(); - return Part.builder().functionCall(fc).build(); + parts.add(Part.builder().functionCall(fc).build()); + return parts; } } } } if (message.has("content") && !message.isNull("content")) { - return Part.builder().text(message.getString("content")).build(); + parts.add(Part.builder().text(message.getString("content")).build()); + } else { + parts.add(Part.builder().text("").build()); } - return Part.builder().text("").build(); + return parts; } + @SuppressWarnings("unchecked") private void normalizeTypeStrings(Map valueDict) { if (valueDict == null) { return; } - if (valueDict.containsKey("type")) { + if (valueDict.containsKey("type") && valueDict.get("type") instanceof String) { valueDict.put("type", ((String) valueDict.get("type")).toLowerCase()); } - if (valueDict.containsKey("items")) { - Object items = valueDict.get("items"); - if (items instanceof Map) { - normalizeTypeStrings((Map) items); - Map itemsMap = (Map) items; - if (itemsMap.containsKey("properties")) { - Map properties = (Map) itemsMap.get("properties"); - if (properties != null) { - for (Object value : 
properties.values()) { - if (value instanceof Map) { - normalizeTypeStrings((Map) value); - } - } + if (valueDict.containsKey("items") && valueDict.get("items") instanceof Map) { + Map itemsMap = (Map) valueDict.get("items"); + normalizeTypeStrings(itemsMap); + if (itemsMap.containsKey("properties") && itemsMap.get("properties") instanceof Map) { + Map properties = (Map) itemsMap.get("properties"); + for (Object value : properties.values()) { + if (value instanceof Map) { + normalizeTypeStrings((Map) value); } } } From 36b4989e4dd5cfe6a2f4e06fc4d7600c35a9984a Mon Sep 17 00:00:00 2001 From: Sandeep Belgavi Date: Mon, 23 Feb 2026 17:50:22 +0530 Subject: [PATCH 09/11] Add SarvamBaseLM unit tests and API key validation warning - 10 unit tests covering openAiMessageToParts (text, null, tool calls, empty args, priority, fallback), constructor, and connect() - Warn at construction time if SARVAM_API_KEY env var is missing Co-authored-by: Cursor --- .../com/google/adk/models/SarvamBaseLM.java | 57 +++--- .../google/adk/models/SarvamBaseLMTest.java | 174 ++++++++++++++++++ 2 files changed, 208 insertions(+), 23 deletions(-) create mode 100644 core/src/test/java/com/google/adk/models/SarvamBaseLMTest.java diff --git a/core/src/main/java/com/google/adk/models/SarvamBaseLM.java b/core/src/main/java/com/google/adk/models/SarvamBaseLM.java index 0de9490c6..487dad652 100644 --- a/core/src/main/java/com/google/adk/models/SarvamBaseLM.java +++ b/core/src/main/java/com/google/adk/models/SarvamBaseLM.java @@ -6,7 +6,6 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.datatype.jdk8.Jdk8Module; -import com.google.adk.tools.BaseTool; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import com.google.genai.types.Content; @@ -69,11 +68,23 @@ public class SarvamBaseLM extends BaseLlm { public SarvamBaseLM(String model) { super(model); this.baseUrl = null; + 
warnIfApiKeyMissing(); } public SarvamBaseLM(String model, String baseUrl) { super(model); this.baseUrl = baseUrl; + warnIfApiKeyMissing(); + } + + private void warnIfApiKeyMissing() { + String apiKey = System.getenv(SARVAM_API_KEY_ENV); + if (apiKey == null || apiKey.isBlank()) { + logger.warn( + "SARVAM_API_KEY environment variable is not set. " + + "Sarvam API calls for model '{}' will fail with 401 Unauthorized.", + model()); + } } private String resolveBaseUrl() { @@ -137,10 +148,7 @@ public Flowable generateContent(LlmRequest llmRequest, boolean stre if (hasFunctionCall) { Part fcPart = parts.stream().filter(p -> p.functionCall().isPresent()).findFirst().get(); responseBuilder.content( - Content.builder() - .role("model") - .parts(ImmutableList.of(fcPart)) - .build()); + Content.builder().role("model").parts(ImmutableList.of(fcPart)).build()); } else { responseBuilder.content( Content.builder().role("model").parts(ImmutableList.copyOf(parts)).build()); @@ -161,9 +169,7 @@ private Flowable generateContentStream(LlmRequest llmRequest) { JSONArray functions = buildTools(llmRequest); boolean lastRespToolExecuted = - Iterables.getLast(Iterables.getLast(contents).parts().get()) - .functionResponse() - .isPresent(); + Iterables.getLast(Iterables.getLast(contents).parts().get()).functionResponse().isPresent(); float temperature = llmRequest.config().flatMap(GenerateContentConfig::temperature).orElse(0.7f); @@ -196,8 +202,13 @@ private Flowable generateContentStream(LlmRequest llmRequest) { String line = reader.readLine(); if (line == null) { emitFinalStreamResponse( - emitter, accumulatedText, inFunctionCall, functionCallName, functionCallArgs, - inputTokens.get(), outputTokens.get()); + emitter, + accumulatedText, + inFunctionCall, + functionCallName, + functionCallArgs, + inputTokens.get(), + outputTokens.get()); emitter.onComplete(); return; } @@ -209,8 +220,13 @@ private Flowable generateContentStream(LlmRequest llmRequest) { if (line.equals("data: [DONE]")) { 
streamCompleted.set(true); emitFinalStreamResponse( - emitter, accumulatedText, inFunctionCall, functionCallName, functionCallArgs, - inputTokens.get(), outputTokens.get()); + emitter, + accumulatedText, + inFunctionCall, + functionCallName, + functionCallArgs, + inputTokens.get(), + outputTokens.get()); emitter.onComplete(); return; } @@ -307,8 +323,7 @@ private void emitFinalStreamResponse( LlmResponse.Builder builder = LlmResponse.builder() - .content( - Content.builder().role("model").parts(ImmutableList.of(part)).build()); + .content(Content.builder().role("model").parts(ImmutableList.of(part)).build()); if (usageMetadata != null) { builder.usageMetadata(usageMetadata); } @@ -384,9 +399,7 @@ private JSONArray buildMessages(String systemText, List contents) { if (firstPart.functionResponse().isPresent()) { JSONObject msg = new JSONObject(); msg.put("role", "tool"); - msg.put( - "tool_call_id", - firstPart.functionResponse().get().name().orElse("call_unknown")); + msg.put("tool_call_id", firstPart.functionResponse().get().name().orElse("call_unknown")); msg.put( "content", new JSONObject(firstPart.functionResponse().get().response().get()).toString()); @@ -528,8 +541,7 @@ private JSONObject callChatCompletions( } if (responseCode >= 400) { - logger.error( - "Sarvam API error: status={} body={}", responseCode, sb); + logger.error("Sarvam API error: status={} body={}", responseCode, sb); return new JSONObject().put("error", sb.toString()); } @@ -595,8 +607,7 @@ private BufferedReader callChatCompletionsStream( while ((errorLine = errorReader.readLine()) != null) { errorResponse.append(errorLine); } - logger.error( - "Sarvam streaming failed: status={} body={}", responseCode, errorResponse); + logger.error("Sarvam streaming failed: status={} body={}", responseCode, errorResponse); } conn.disconnect(); return null; @@ -672,8 +683,8 @@ private GenerateContentResponseUsageMetadata buildUsageMetadata( } /** - * Converts an OpenAI-format message JSON to ADK Part(s). 
- * Handles both text content and tool_calls in a single message. + * Converts an OpenAI-format message JSON to ADK Part(s). Handles both text content and tool_calls + * in a single message. */ static List openAiMessageToParts(JSONObject message) { List parts = new ArrayList<>(); diff --git a/core/src/test/java/com/google/adk/models/SarvamBaseLMTest.java b/core/src/test/java/com/google/adk/models/SarvamBaseLMTest.java new file mode 100644 index 000000000..972253b97 --- /dev/null +++ b/core/src/test/java/com/google/adk/models/SarvamBaseLMTest.java @@ -0,0 +1,174 @@ +package com.google.adk.models; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.genai.types.FunctionCall; +import com.google.genai.types.Part; +import java.util.List; +import org.json.JSONArray; +import org.json.JSONObject; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class SarvamBaseLMTest { + + // ========== openAiMessageToParts tests ========== + + @Test + public void openAiMessageToParts_textContent_returnsTextPart() { + JSONObject message = new JSONObject(); + message.put("role", "assistant"); + message.put("content", "Hello world"); + + List parts = SarvamBaseLM.openAiMessageToParts(message); + + assertThat(parts).hasSize(1); + assertThat(parts.get(0).text()).hasValue("Hello world"); + assertThat(parts.get(0).functionCall()).isEmpty(); + } + + @Test + public void openAiMessageToParts_nullContent_returnsEmptyTextPart() { + JSONObject message = new JSONObject(); + message.put("role", "assistant"); + message.put("content", JSONObject.NULL); + + List parts = SarvamBaseLM.openAiMessageToParts(message); + + assertThat(parts).hasSize(1); + assertThat(parts.get(0).text()).hasValue(""); + } + + @Test + public void openAiMessageToParts_missingContent_returnsEmptyTextPart() { + JSONObject message = new JSONObject(); + message.put("role", "assistant"); + + List parts = 
SarvamBaseLM.openAiMessageToParts(message); + + assertThat(parts).hasSize(1); + assertThat(parts.get(0).text()).hasValue(""); + } + + @Test + public void openAiMessageToParts_toolCall_returnsFunctionCallPart() { + JSONObject function = new JSONObject(); + function.put("name", "getBusSearch"); + function.put("arguments", "{\"source\":\"Bangalore\",\"dest\":\"Chennai\"}"); + + JSONObject toolCall = new JSONObject(); + toolCall.put("id", "call_abc123"); + toolCall.put("type", "function"); + toolCall.put("function", function); + + JSONArray toolCalls = new JSONArray(); + toolCalls.put(toolCall); + + JSONObject message = new JSONObject(); + message.put("role", "assistant"); + message.put("content", JSONObject.NULL); + message.put("tool_calls", toolCalls); + + List parts = SarvamBaseLM.openAiMessageToParts(message); + + assertThat(parts).hasSize(1); + assertThat(parts.get(0).functionCall()).isPresent(); + + FunctionCall fc = parts.get(0).functionCall().get(); + assertThat(fc.name()).hasValue("getBusSearch"); + assertThat(fc.args()).isPresent(); + assertThat(fc.args().get()).containsEntry("source", "Bangalore"); + assertThat(fc.args().get()).containsEntry("dest", "Chennai"); + } + + @Test + public void openAiMessageToParts_toolCallWithEmptyArgs_returnsFunctionCallWithEmptyMap() { + JSONObject function = new JSONObject(); + function.put("name", "getOffers"); + function.put("arguments", "{}"); + + JSONObject toolCall = new JSONObject(); + toolCall.put("id", "call_xyz"); + toolCall.put("type", "function"); + toolCall.put("function", function); + + JSONArray toolCalls = new JSONArray(); + toolCalls.put(toolCall); + + JSONObject message = new JSONObject(); + message.put("role", "assistant"); + message.put("tool_calls", toolCalls); + + List parts = SarvamBaseLM.openAiMessageToParts(message); + + assertThat(parts).hasSize(1); + assertThat(parts.get(0).functionCall()).isPresent(); + assertThat(parts.get(0).functionCall().get().name()).hasValue("getOffers"); + 
assertThat(parts.get(0).functionCall().get().args().get()).isEmpty(); + } + + @Test + public void openAiMessageToParts_toolCallTakesPriorityOverContent() { + JSONObject function = new JSONObject(); + function.put("name", "search"); + function.put("arguments", "{}"); + + JSONObject toolCall = new JSONObject(); + toolCall.put("id", "call_1"); + toolCall.put("type", "function"); + toolCall.put("function", function); + + JSONArray toolCalls = new JSONArray(); + toolCalls.put(toolCall); + + JSONObject message = new JSONObject(); + message.put("role", "assistant"); + message.put("content", "I'll search for you"); + message.put("tool_calls", toolCalls); + + List parts = SarvamBaseLM.openAiMessageToParts(message); + + assertThat(parts).hasSize(1); + assertThat(parts.get(0).functionCall()).isPresent(); + assertThat(parts.get(0).functionCall().get().name()).hasValue("search"); + } + + @Test + public void openAiMessageToParts_emptyToolCalls_fallsBackToContent() { + JSONObject message = new JSONObject(); + message.put("role", "assistant"); + message.put("content", "Here are the results"); + message.put("tool_calls", new JSONArray()); + + List parts = SarvamBaseLM.openAiMessageToParts(message); + + assertThat(parts).hasSize(1); + assertThat(parts.get(0).text()).hasValue("Here are the results"); + } + + // ========== Constructor / config tests ========== + + @Test + public void constructor_setsModelName() { + SarvamBaseLM llm = new SarvamBaseLM("sarvam-m"); + assertThat(llm.model()).isEqualTo("sarvam-m"); + } + + @Test + public void constructor_withBaseUrl_setsModelName() { + SarvamBaseLM llm = new SarvamBaseLM("sarvam-m", "https://custom.api.com/v1"); + assertThat(llm.model()).isEqualTo("sarvam-m"); + } + + @Test + public void connect_returnsGenericLlmConnection() { + SarvamBaseLM llm = new SarvamBaseLM("sarvam-m"); + LlmRequest request = LlmRequest.builder().build(); + + BaseLlmConnection connection = llm.connect(request); + + 
assertThat(connection).isInstanceOf(GenericLlmConnection.class); + } +} From 67de51c5d1a949cbe146d7c9eb65c1c5953702bf Mon Sep 17 00:00:00 2001 From: Sandeep Belgavi Date: Mon, 23 Feb 2026 21:25:14 +0530 Subject: [PATCH 10/11] Add CAPABILITIES.md documenting all Sarvam AI integration features Covers Chat (LLM), STT, TTS, Vision, Live Connections, retry logic, configuration, authentication, test coverage, and RAE integration. Co-authored-by: Cursor --- contrib/sarvam-ai/CAPABILITIES.md | 221 ++++++++++++++++++++++++++++++ 1 file changed, 221 insertions(+) create mode 100644 contrib/sarvam-ai/CAPABILITIES.md diff --git a/contrib/sarvam-ai/CAPABILITIES.md b/contrib/sarvam-ai/CAPABILITIES.md new file mode 100644 index 000000000..19e765c84 --- /dev/null +++ b/contrib/sarvam-ai/CAPABILITIES.md @@ -0,0 +1,221 @@ +# Sarvam AI - ADK Integration Capabilities + +## Overview + +The Sarvam AI module provides a comprehensive, production-grade integration of Sarvam AI services into the Google Agent Development Kit (ADK) for Java. It spans five service domains -- Chat, Speech-to-Text, Text-to-Speech, Vision, and Live Connections -- covering both REST and WebSocket protocols with full observability, resilience, and multi-turn agentic support. + +**Module path:** `contrib/sarvam-ai` +**Package:** `com.google.adk.models.sarvamai` +**Branch:** `sarvam-ai` + +--- + +## 1. 
Chat Completions (LLM) + +**Class:** `SarvamAi` extends `BaseLlm` +**Endpoint:** `POST /v1/chat/completions` (OpenAI-compatible) + +| Capability | Details | +|---|---| +| Blocking (non-streaming) | Full request/response cycle via `generateContent(request, false)` | +| SSE Streaming | Real-time token-by-token delivery via `generateContent(request, true)` with backpressure (RxJava `Flowable`) | +| Function / Tool Calling | ADK `FunctionDeclaration` serialized to OpenAI `tools` JSON with `tool_choice: auto` | +| Multi-turn Tool History | Prior `tool_calls` correctly formatted as assistant messages with `tool_call_id`, `function.name`, `function.arguments`; tool responses sent as `role: tool` | +| Streaming Function Calls | Chunked `name` and `arguments` accumulated across SSE deltas, emitted as final `FunctionCall` Part | +| Token Usage Tracking | `prompt_tokens`, `completion_tokens`, `total_tokens` extracted for both blocking and streaming modes. Streaming uses `stream_options: {"include_usage": true}` | +| System Instructions | ADK `GenerateContentConfig.systemInstruction` mapped to OpenAI `system` role message | +| Temperature Control | Forwarded from `GenerateContentConfig.temperature` (default 0.7) | +| Max Output Tokens | `GenerateContentConfig.maxOutputTokens` forwarded as `max_tokens` | +| Top-P Sampling | Configurable via `SarvamAiConfig.topP()` | +| Frequency / Presence Penalty | Configurable via `SarvamAiConfig` builder | +| Reasoning Effort | Sarvam-specific `reasoning_effort` parameter (low / medium / high) | +| Wiki Grounding | Sarvam-specific `wiki_grounding` toggle for factual grounding | +| Role Translation | ADK `model` -> OpenAI `assistant`, `user` -> `user`, `functionResponse` -> `tool` | +| Schema Normalization | Type strings lowercased, nested `items.properties` recursively normalized for OpenAI schema compatibility | +| Graceful Degradation | Empty choices return empty text response instead of crashing | + +### Dual Implementation + +| 
Implementation | Location | Use Case | +|---|---|---| +| `SarvamBaseLM` | `core/src/main/java/.../models/SarvamBaseLM.java` | Lightweight, env-var driven. Used by `AgentModelConfig` and `LlmRegistry` for `Sarvam\|model` config strings | +| `SarvamAi` | `contrib/sarvam-ai/src/.../SarvamAi.java` | Full-featured, Builder-pattern, OkHttp-based. Supports all chat parameters plus subservice access | + +--- + +## 2. Speech-to-Text (STT) + +**Class:** `SarvamSttService` implements `TranscriptionService` +**Model:** `saaras:v3` + +| Capability | Details | +|---|---| +| REST Synchronous | `transcribe(byte[] audioData, TranscriptionConfig)` via `POST /speech-to-text` with multipart/form-data | +| REST Async | `transcribeAsync()` executes on RxJava IO scheduler | +| WebSocket Streaming | Real-time streaming via `wss://api.sarvam.ai/speech-to-text/streaming` with VAD (Voice Activity Detection) signals | +| Transcription Modes | `transcribe`, `translate`, `verbatim`, `translit`, `codemix` | +| Language Detection | Auto-detection supported; explicit BCP-47 codes (e.g., `hi-IN`, `en-IN`) also accepted | +| VAD Signals | `speech_start` and `speech_end` events for voice activity boundaries | +| ADK TranscriptionService | Full implementation of ADK's `TranscriptionService` interface including `isAvailable()`, `getServiceType()`, `getHealth()` | + +--- + +## 3. 
Text-to-Speech (TTS) + +**Class:** `SarvamTtsService` +**Model:** `bulbul:v3` + +| Capability | Details | +|---|---| +| REST Synchronous | `synthesize(text, languageCode)` returns decoded WAV audio bytes | +| REST Async | `synthesizeAsync()` on IO scheduler | +| WebSocket Streaming | `synthesizeStream()` via `wss://api.sarvam.ai/text-to-speech/streaming` for low-latency progressive audio chunk delivery | +| 30+ Speaker Voices | Configurable via `SarvamAiConfig.ttsSpeaker()` (default: `shubh`) | +| Pace Control | Adjustable speech pace (0.5x to 2.0x) | +| Sample Rate | Configurable output sample rate | +| Base64 Decoding | Audio chunks automatically decoded from base64 to raw bytes | +| WebSocket Lifecycle | Config frame -> text frame -> flush frame -> audio chunks -> final event -> close | + +--- + +## 4. Vision / Document Intelligence + +**Class:** `SarvamVisionService` +**Model:** Sarvam Vision 3B VLM + +| Capability | Details | +|---|---| +| Multi-Language OCR | 23 languages (22 Indian + English) | +| Input Formats | PDF, PNG, JPG, ZIP | +| Output Formats | HTML or Markdown | +| Async Job Pipeline | `createJob` -> `uploadDocument` (presigned URL) -> `startJob` -> `getJobStatus` (poll) -> `downloadResults` | +| Convenience Method | `processDocument(filePath, languageCode, outputFormat)` runs the full pipeline with adaptive exponential backoff polling | +| Polling Backoff | Starts at 2s, doubles up to 10s cap, max 60 polls (~2 min timeout) | + +--- + +## 5. 
Live Bidirectional Connection + +**Class:** `SarvamAiLlmConnection` implements `BaseLlmConnection` + +| Capability | Details | +|---|---| +| Multi-Turn Context | Maintains conversation history across turns, accumulates full model responses | +| sendHistory | Replace full conversation context | +| sendContent | Append a single turn and trigger streaming response | +| receive | Returns `Flowable` via `PublishSubject` for reactive consumers | +| Thread Safety | History list synchronized for concurrent access | +| Realtime Guard | `sendRealtime(Blob)` throws `UnsupportedOperationException` with guidance to use STT/TTS services | + +--- + +## 6. Resilience & Configuration + +### Retry with Exponential Backoff + +**Class:** `SarvamRetryInterceptor` (OkHttp `Interceptor`) + +| Parameter | Value | +|---|---| +| Retryable codes | 429 (rate limit), 503, 5xx (server errors) | +| Base delay | 500ms | +| Max delay | 30s | +| Strategy | Exponential backoff with 20% jitter | +| Default max retries | 3 | + +### Immutable Configuration + +**Class:** `SarvamAiConfig` (Builder pattern) + +| Parameter | Default | +|---|---| +| Chat endpoint | `https://api.sarvam.ai/v1/chat/completions` | +| STT endpoint | `https://api.sarvam.ai/speech-to-text` | +| STT WebSocket | `wss://api.sarvam.ai/speech-to-text/streaming` | +| TTS endpoint | `https://api.sarvam.ai/text-to-speech` | +| TTS WebSocket | `wss://api.sarvam.ai/text-to-speech/streaming` | +| Vision endpoint | `https://api.sarvam.ai/document-intelligence` | +| Connect timeout | 30s | +| Read timeout | 120s | +| Max retries | 3 | +| API key resolution | Explicit value > `SARVAM_API_KEY` env var | + +### Structured Error Handling + +**Class:** `SarvamAiException` extends `RuntimeException` + +| Field | Purpose | +|---|---| +| `statusCode` | HTTP status code from API | +| `errorCode` | Sarvam-specific error code | +| `requestId` | Sarvam request ID for support tracing | +| `isRetryable()` | Programmatic check (429, 503, 5xx) | + +--- + +## 
7. Authentication + +| Method | Header | Used By | +|---|---|---| +| API Subscription Key | `api-subscription-key: <API_KEY>` | `SarvamAi`, STT, TTS, Vision (contrib module) | +| Bearer Token | `Authorization: Bearer <API_KEY>` | `SarvamBaseLM` (core module, OpenAI-compatible) | +| Key Resolution | `SARVAM_API_KEY` env var or explicit via Builder | Both | +| Missing-Key Warning | Warning logged at construction if key is missing | `SarvamBaseLM` | + +--- + +## 8. Test Coverage + +| Test Class | Tests | Scope | +|---|---|---| +| `SarvamBaseLMTest` | 10 | Response parsing (text, null, tool calls), construction, connection type | +| `SarvamAiTest` | - | Chat completion blocking and streaming | +| `SarvamAiConfigTest` | - | Config builder validation, defaults, env var resolution | +| `ChatRequestTest` | - | Request serialization from LlmRequest | +| `SarvamSttServiceTest` | - | STT REST and WebSocket transcription | +| `SarvamTtsServiceTest` | - | TTS REST and WebSocket synthesis | +| `SarvamRetryInterceptorTest` | - | Retry logic, delay calculation, jitter | +| `SarvamIntegrationTest` (rae) | 20 | End-to-end config wiring across properties, YAML, LlmRegistry | + +--- + +## 
RAE Integration (Consumer Project) + +| Integration Point | Mechanism | File | +|---|---|---| +| Code-based agents | `AgentModelConfig` recognizes `Sarvam\|` prefix, instantiates `SarvamBaseLM` | `AgentModelConfig.java` | +| YAML-based agents | `LlmRegistry.registerLlm("Sarvam\\|.*", ...)` factory | `ApplicationRegistry.java` | +| Model metadata | `sarvam:` provider in `models.yaml` with feature declarations | `models.yaml` | +| Config format | `Sarvam\|sarvam-m` -- single string works across both paths | `agent-models.properties` + `*.yaml` | +| Global coverage | 43 code-based + 28 YAML agent configs switched to Sarvam | All agent config files | + +--- + +## Architecture Summary + +``` +contrib/sarvam-ai/ + src/main/java/com/google/adk/models/sarvamai/ + SarvamAi.java # BaseLlm (chat, Builder pattern, OkHttp) + SarvamAiConfig.java # Immutable config for all services + SarvamAiException.java # Structured error with status/code/requestId + SarvamAiLlmConnection.java # Live bidirectional multi-turn connection + SarvamRetryInterceptor.java # Exponential backoff with jitter + chat/ + ChatRequest.java # OpenAI-compatible request model + ChatResponse.java # Response deserialization + ChatChoice.java # Choice wrapper + ChatMessage.java # Message model + ChatUsage.java # Token usage tracking + stt/ + SarvamSttService.java # REST + WebSocket STT (TranscriptionService) + tts/ + SarvamTtsService.java # REST + WebSocket TTS + TtsRequest.java # TTS request model + TtsResponse.java # TTS response model + vision/ + SarvamVisionService.java # Async job pipeline for document OCR + +core/src/main/java/com/google/adk/models/ + SarvamBaseLM.java # Lightweight BaseLlm for agent config integration +``` From 8802ce80b28ad2b58d26a7306a40afc29c878b6d Mon Sep 17 00:00:00 2001 From: Sandeep Belgavi Date: Tue, 24 Feb 2026 17:22:42 +0530 Subject: [PATCH 11/11] Add @author Sandeep Belgavi to all Sarvam AI source and test files Tags added to 23 files across contrib/sarvam-ai (main + test) and 
core SarvamBaseLM/SarvamBaseLMTest. Co-authored-by: Cursor --- .../main/java/com/google/adk/models/sarvamai/SarvamAi.java | 2 ++ .../java/com/google/adk/models/sarvamai/SarvamAiConfig.java | 2 ++ .../com/google/adk/models/sarvamai/SarvamAiException.java | 2 ++ .../google/adk/models/sarvamai/SarvamAiLlmConnection.java | 2 ++ .../google/adk/models/sarvamai/SarvamRetryInterceptor.java | 2 ++ .../com/google/adk/models/sarvamai/chat/ChatChoice.java | 2 ++ .../com/google/adk/models/sarvamai/chat/ChatMessage.java | 6 +++++- .../com/google/adk/models/sarvamai/chat/ChatRequest.java | 2 ++ .../com/google/adk/models/sarvamai/chat/ChatResponse.java | 2 ++ .../java/com/google/adk/models/sarvamai/chat/ChatUsage.java | 6 +++++- .../google/adk/models/sarvamai/stt/SarvamSttService.java | 2 ++ .../google/adk/models/sarvamai/tts/SarvamTtsService.java | 2 ++ .../java/com/google/adk/models/sarvamai/tts/TtsRequest.java | 6 +++++- .../com/google/adk/models/sarvamai/tts/TtsResponse.java | 6 +++++- .../adk/models/sarvamai/vision/SarvamVisionService.java | 2 ++ .../com/google/adk/models/sarvamai/SarvamAiConfigTest.java | 1 + .../java/com/google/adk/models/sarvamai/SarvamAiTest.java | 1 + .../adk/models/sarvamai/SarvamRetryInterceptorTest.java | 1 + .../google/adk/models/sarvamai/chat/ChatRequestTest.java | 1 + .../adk/models/sarvamai/stt/SarvamSttServiceTest.java | 1 + .../adk/models/sarvamai/tts/SarvamTtsServiceTest.java | 1 + .../test/java/com/google/adk/models/SarvamBaseLMTest.java | 3 +++ 22 files changed, 51 insertions(+), 4 deletions(-) diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAi.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAi.java index 634eab1a8..4ced7f6c9 100644 --- a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAi.java +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAi.java @@ -62,6 +62,8 @@ * .build()) * .build(); * } + * + * @author Sandeep Belgavi */ public 
class SarvamAi extends BaseLlm { diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiConfig.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiConfig.java index 3c3571f1f..061bf4818 100644 --- a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiConfig.java +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiConfig.java @@ -31,6 +31,8 @@ * Builder pattern for safe, incremental construction with sensible defaults. * *

API key resolution order: explicit value > {@code SARVAM_API_KEY} environment variable. + * + * @author Sandeep Belgavi */ public final class SarvamAiConfig { diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiException.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiException.java index bbd3c4a46..7c52f76c5 100644 --- a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiException.java +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiException.java @@ -21,6 +21,8 @@ /** * Domain exception for Sarvam AI API errors. Carries structured error information from the API * response for programmatic error handling. + * + * @author Sandeep Belgavi */ public class SarvamAiException extends RuntimeException { diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiLlmConnection.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiLlmConnection.java index bbaa2f1da..e4348a0e7 100644 --- a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiLlmConnection.java +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamAiLlmConnection.java @@ -37,6 +37,8 @@ * *

Maintains conversation history and streams responses token-by-token using SSE. Accumulates the * full model response into history after each turn to support multi-turn context. + * + * @author Sandeep Belgavi */ final class SarvamAiLlmConnection implements BaseLlmConnection { diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamRetryInterceptor.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamRetryInterceptor.java index da0874ac5..8f0d9bda5 100644 --- a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamRetryInterceptor.java +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/SarvamRetryInterceptor.java @@ -26,6 +26,8 @@ /** * OkHttp interceptor that implements exponential backoff with jitter for retryable Sarvam API * errors (429 rate limit, 5xx server errors). + * + * @author Sandeep Belgavi */ final class SarvamRetryInterceptor implements Interceptor { diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatChoice.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatChoice.java index 5aff17c63..0dd907812 100644 --- a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatChoice.java +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatChoice.java @@ -22,6 +22,8 @@ /** * A choice in the Sarvam AI chat completion response. Handles both non-streaming ({@code message}) * and streaming ({@code delta}) response formats. 
+ * + * @author Sandeep Belgavi */ @JsonIgnoreProperties(ignoreUnknown = true) public final class ChatChoice { diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatMessage.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatMessage.java index c84336cd7..a820ac47e 100644 --- a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatMessage.java +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatMessage.java @@ -20,7 +20,11 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; -/** A message in the Sarvam AI chat completion API (request or response). */ +/** + * A message in the Sarvam AI chat completion API (request or response). + * + * @author Sandeep Belgavi + */ @JsonIgnoreProperties(ignoreUnknown = true) @JsonInclude(JsonInclude.Include.NON_NULL) public final class ChatMessage { diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatRequest.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatRequest.java index d63d57d1d..3faefa2e9 100644 --- a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatRequest.java +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatRequest.java @@ -28,6 +28,8 @@ /** * Request body for the Sarvam AI chat completions endpoint. Constructed from the ADK {@link * LlmRequest} and {@link SarvamAiConfig}. 
+ * + * @author Sandeep Belgavi */ @JsonInclude(JsonInclude.Include.NON_NULL) public final class ChatRequest { diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatResponse.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatResponse.java index 6be3efaef..b3a215475 100644 --- a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatResponse.java +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatResponse.java @@ -23,6 +23,8 @@ /** * Response from the Sarvam AI chat completions endpoint. Supports both non-streaming and streaming * (SSE chunk) formats. + * + * @author Sandeep Belgavi */ @JsonIgnoreProperties(ignoreUnknown = true) public final class ChatResponse { diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatUsage.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatUsage.java index 120dd3314..11812cf1b 100644 --- a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatUsage.java +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/chat/ChatUsage.java @@ -19,7 +19,11 @@ import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; -/** Token usage metadata from Sarvam AI API response. */ +/** + * Token usage metadata from Sarvam AI API response. + * + * @author Sandeep Belgavi + */ @JsonIgnoreProperties(ignoreUnknown = true) public final class ChatUsage { diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/stt/SarvamSttService.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/stt/SarvamSttService.java index ceec7483b..0398f5ef7 100644 --- a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/stt/SarvamSttService.java +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/stt/SarvamSttService.java @@ -56,6 +56,8 @@ *

  • WebSocket streaming ({@link #transcribeStream}): Real-time streaming via WebSocket * with VAD support, delivering partial and final transcription events. * + * + * @author Sandeep Belgavi */ public final class SarvamSttService implements TranscriptionService { diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/tts/SarvamTtsService.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/tts/SarvamTtsService.java index fc68608b0..414a8b5b6 100644 --- a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/tts/SarvamTtsService.java +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/tts/SarvamTtsService.java @@ -45,6 +45,8 @@ *

    WebSocket streaming mode ({@link #synthesizeStream}): Opens a persistent WebSocket connection * for progressive audio chunk delivery with low latency. Audio chunks are emitted as they are * synthesized, enabling real-time playback. + * + * @author Sandeep Belgavi */ public final class SarvamTtsService { diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/tts/TtsRequest.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/tts/TtsRequest.java index 152b84fc6..b387cec08 100644 --- a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/tts/TtsRequest.java +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/tts/TtsRequest.java @@ -19,7 +19,11 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; -/** Request body for the Sarvam AI text-to-speech REST endpoint. */ +/** + * Request body for the Sarvam AI text-to-speech REST endpoint. + * + * @author Sandeep Belgavi + */ @JsonInclude(JsonInclude.Include.NON_NULL) public final class TtsRequest { diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/tts/TtsResponse.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/tts/TtsResponse.java index 61a6e9f37..3712bdad6 100644 --- a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/tts/TtsResponse.java +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/tts/TtsResponse.java @@ -20,7 +20,11 @@ import com.fasterxml.jackson.annotation.JsonProperty; import java.util.List; -/** Response from the Sarvam AI text-to-speech REST endpoint. */ +/** + * Response from the Sarvam AI text-to-speech REST endpoint. 
+ * + * @author Sandeep Belgavi + */ @JsonIgnoreProperties(ignoreUnknown = true) public final class TtsResponse { diff --git a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/vision/SarvamVisionService.java b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/vision/SarvamVisionService.java index a451d5d0b..420d491ce 100644 --- a/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/vision/SarvamVisionService.java +++ b/contrib/sarvam-ai/src/main/java/com/google/adk/models/sarvamai/vision/SarvamVisionService.java @@ -51,6 +51,8 @@ *

  • {@link #getJobStatus} - Poll for completion *
  • {@link #downloadResults} - Retrieve the processed output * + * + * @author Sandeep Belgavi */ public final class SarvamVisionService { diff --git a/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamAiConfigTest.java b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamAiConfigTest.java index 2d7cb4770..b1a5243a0 100644 --- a/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamAiConfigTest.java +++ b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamAiConfigTest.java @@ -21,6 +21,7 @@ import org.junit.jupiter.api.Test; +/** @author Sandeep Belgavi */ class SarvamAiConfigTest { @Test diff --git a/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamAiTest.java b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamAiTest.java index c04cf3add..9fb79c8f6 100644 --- a/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamAiTest.java +++ b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamAiTest.java @@ -35,6 +35,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +/** @author Sandeep Belgavi */ class SarvamAiTest { private MockWebServer server; diff --git a/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamRetryInterceptorTest.java b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamRetryInterceptorTest.java index ff8614bd7..f62907cde 100644 --- a/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamRetryInterceptorTest.java +++ b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/SarvamRetryInterceptorTest.java @@ -20,6 +20,7 @@ import org.junit.jupiter.api.Test; +/** @author Sandeep Belgavi */ class SarvamRetryInterceptorTest { @Test diff --git a/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/chat/ChatRequestTest.java b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/chat/ChatRequestTest.java index 
36cb04d96..aa39eb743 100644 --- a/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/chat/ChatRequestTest.java +++ b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/chat/ChatRequestTest.java @@ -27,6 +27,7 @@ import java.util.List; import org.junit.jupiter.api.Test; +/** @author Sandeep Belgavi */ class ChatRequestTest { private final ObjectMapper objectMapper = new ObjectMapper(); diff --git a/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/stt/SarvamSttServiceTest.java b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/stt/SarvamSttServiceTest.java index b69529368..8fca0ee6f 100644 --- a/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/stt/SarvamSttServiceTest.java +++ b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/stt/SarvamSttServiceTest.java @@ -33,6 +33,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +/** @author Sandeep Belgavi */ class SarvamSttServiceTest { private MockWebServer server; diff --git a/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/tts/SarvamTtsServiceTest.java b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/tts/SarvamTtsServiceTest.java index bcae3e3d4..922cc8572 100644 --- a/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/tts/SarvamTtsServiceTest.java +++ b/contrib/sarvam-ai/src/test/java/com/google/adk/models/sarvamai/tts/SarvamTtsServiceTest.java @@ -32,6 +32,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +/** @author Sandeep Belgavi */ class SarvamTtsServiceTest { private MockWebServer server; diff --git a/core/src/test/java/com/google/adk/models/SarvamBaseLMTest.java b/core/src/test/java/com/google/adk/models/SarvamBaseLMTest.java index 972253b97..ef3d6edb5 100644 --- a/core/src/test/java/com/google/adk/models/SarvamBaseLMTest.java +++ b/core/src/test/java/com/google/adk/models/SarvamBaseLMTest.java @@ -12,6 +12,9 @@ import 
org.junit.runners.JUnit4; +/** + * @author Sandeep Belgavi + */ @RunWith(JUnit4.class) public final class SarvamBaseLMTest { // ========== openAiMessageToParts tests ==========