diff --git a/core/pom.xml b/core/pom.xml
index 165a7b3fa..d31a2691b 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -201,6 +201,59 @@
maven-compiler-plugin
+
+ maven-surefire-plugin
+
+
+ basic
+
+ test
+
+
+
+ apigee-llm
+
+ test
+
+
+ ApigeeLlmTest
+
+
+ api-key
+ false
+
+
+
+
+ apigee-llm-vertex-ai
+
+ test
+
+
+ ApigeeLlmTest#generateContent_setsVertexAiFlagCorrectly_withOrWithoutVertexAi
+
+ api-key
+
+ true
+
+
+
+
+ apigee-llm-proxy-url
+
+ test
+
+
+ ApigeeLlmTest#build_withoutProxyUrl_readsFromEnvironment
+
+ api-key
+
+ proxy-url
+
+
+
+
+
diff --git a/core/src/main/java/com/google/adk/codeexecutors/ContainerCodeExecutor.java b/core/src/main/java/com/google/adk/codeexecutors/ContainerCodeExecutor.java
index 1d1202ead..4e75dab75 100644
--- a/core/src/main/java/com/google/adk/codeexecutors/ContainerCodeExecutor.java
+++ b/core/src/main/java/com/google/adk/codeexecutors/ContainerCodeExecutor.java
@@ -17,6 +17,8 @@
package com.google.adk.codeexecutors;
+import static java.util.Objects.requireNonNullElse;
+
import com.github.dockerjava.api.DockerClient;
import com.github.dockerjava.api.command.ExecCreateCmdResponse;
import com.github.dockerjava.api.model.Container;
@@ -32,7 +34,6 @@
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Paths;
-import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -41,37 +42,68 @@ public class ContainerCodeExecutor extends BaseCodeExecutor {
private static final Logger logger = LoggerFactory.getLogger(ContainerCodeExecutor.class);
private static final String DEFAULT_IMAGE_TAG = "adk-code-executor:latest";
- private final Optional baseUrl;
+ private final String baseUrl;
private final String image;
- private final Optional dockerPath;
+ private final String dockerPath;
private final DockerClient dockerClient;
private Container container;
/**
- * Initializes the ContainerCodeExecutor.
+ * Creates a ContainerCodeExecutor from an image.
+ *
+ * @param baseUrl The base url of the user hosted Docker client.
+ * @param image The tag of the predefined image or custom image to run on the container.
+ */
+ public static ContainerCodeExecutor fromImage(String baseUrl, String image) {
+ return new ContainerCodeExecutor(baseUrl, image, null);
+ }
+
+ /**
+ * Creates a ContainerCodeExecutor from an image.
+ *
+ * @param image The tag of the predefined image or custom image to run on the container.
+ */
+ public static ContainerCodeExecutor fromImage(String image) {
+ return new ContainerCodeExecutor(null, image, null);
+ }
+
+ /**
+ * Creates a ContainerCodeExecutor from a Dockerfile path.
+ *
+ * @param baseUrl The base url of the user hosted Docker client.
+ * @param dockerPath The path to the directory containing the Dockerfile.
+ */
+ public static ContainerCodeExecutor fromDockerPath(String baseUrl, String dockerPath) {
+ return new ContainerCodeExecutor(baseUrl, null, dockerPath);
+ }
+
+ /**
+ * Creates a ContainerCodeExecutor from a Dockerfile path.
+ *
+ * @param dockerPath The path to the directory containing the Dockerfile.
+ */
+ public static ContainerCodeExecutor fromDockerPath(String dockerPath) {
+ return new ContainerCodeExecutor(null, null, dockerPath);
+ }
+
+ /**
+ * Initializes the ContainerCodeExecutor. Either dockerPath or image must be set.
*
- * @param baseUrl Optional. The base url of the user hosted Docker client.
- * @param image The tag of the predefined image or custom image to run on the container. Either
- * dockerPath or image must be set.
- * @param dockerPath The path to the directory containing the Dockerfile. If set, build the image
- * from the dockerfile path instead of using the predefined image. Either dockerPath or image
- * must be set.
+ * @deprecated Use one of the static factory methods instead.
*/
- public ContainerCodeExecutor(
- Optional baseUrl, Optional image, Optional dockerPath) {
- if (image.isEmpty() && dockerPath.isEmpty()) {
+ @Deprecated
+ public ContainerCodeExecutor(String baseUrl, String image, String dockerPath) {
+ if (image == null && dockerPath == null) {
throw new IllegalArgumentException(
"Either image or dockerPath must be set for ContainerCodeExecutor.");
}
this.baseUrl = baseUrl;
- this.image = image.orElse(DEFAULT_IMAGE_TAG);
- this.dockerPath = dockerPath.map(p -> Paths.get(p).toAbsolutePath().toString());
+ this.image = requireNonNullElse(image, DEFAULT_IMAGE_TAG);
+ this.dockerPath = dockerPath == null ? null : Paths.get(dockerPath).toAbsolutePath().toString();
- if (baseUrl.isPresent()) {
+ if (baseUrl != null) {
var config =
- DefaultDockerClientConfig.createDefaultConfigBuilder()
- .withDockerHost(baseUrl.get())
- .build();
+ DefaultDockerClientConfig.createDefaultConfigBuilder().withDockerHost(baseUrl).build();
this.dockerClient = DockerClientBuilder.getInstance(config).build();
} else {
this.dockerClient = DockerClientBuilder.getInstance().build();
@@ -121,12 +153,12 @@ public CodeExecutionResult executeCode(
}
private void buildDockerImage() {
- if (dockerPath.isEmpty()) {
+ if (dockerPath == null) {
throw new IllegalStateException("Docker path is not set.");
}
- File dockerfile = new File(dockerPath.get());
+ File dockerfile = new File(dockerPath);
if (!dockerfile.exists()) {
- throw new UncheckedIOException(new IOException("Invalid Docker path: " + dockerPath.get()));
+ throw new UncheckedIOException(new IOException("Invalid Docker path: " + dockerPath));
}
logger.info("Building Docker image...");
@@ -158,7 +190,7 @@ private void initContainer() {
if (dockerClient == null) {
throw new IllegalStateException("Docker client is not initialized.");
}
- if (dockerPath.isPresent()) {
+ if (dockerPath != null) {
buildDockerImage();
} else {
// If a dockerPath is not provided, always pull the image to ensure it's up-to-date.
diff --git a/core/src/test/java/com/google/adk/models/ApigeeLlmTest.java b/core/src/test/java/com/google/adk/models/ApigeeLlmTest.java
index 65364e7b4..6ba2832c0 100644
--- a/core/src/test/java/com/google/adk/models/ApigeeLlmTest.java
+++ b/core/src/test/java/com/google/adk/models/ApigeeLlmTest.java
@@ -31,6 +31,7 @@
import com.google.genai.types.Part;
import io.reactivex.rxjava3.core.Flowable;
import java.util.Map;
+import java.util.Objects;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
@@ -57,11 +58,11 @@ public void checkApiKey() {
@Test
public void build_withValidModelStrings_succeeds() {
String[] validModelStrings = {
- "apigee/gemini-1.5-flash",
- "apigee/v1/gemini-1.5-flash",
- "apigee/vertex_ai/gemini-1.5-flash",
- "apigee/gemini/v1/gemini-1.5-flash",
- "apigee/vertex_ai/v1beta/gemini-1.5-flash"
+ "apigee/whatever-model",
+ "apigee/v1/whatever-model",
+ "apigee/vertex_ai/whatever-model",
+ "apigee/gemini/v1/whatever-model",
+ "apigee/vertex_ai/v1beta/whatever-model"
};
for (String modelName : validModelStrings) {
@@ -93,18 +94,18 @@ public void build_withInvalidModelStrings_throwsException() {
public void generateContent_stripsApigeePrefixAndSendsToDelegate() {
when(mockGeminiDelegate.generateContent(any(), anyBoolean())).thenReturn(Flowable.empty());
- ApigeeLlm llm = new ApigeeLlm("apigee/gemini/v1/gemini-1.5-flash", mockGeminiDelegate);
+ ApigeeLlm llm = new ApigeeLlm("apigee/gemini/v1/whatever-model", mockGeminiDelegate);
LlmRequest request =
LlmRequest.builder()
- .model("apigee/gemini/v1/gemini-1.5-flash")
+ .model("apigee/gemini/v1/whatever-model")
.contents(ImmutableList.of(Content.builder().parts(Part.fromText("hi")).build()))
.build();
llm.generateContent(request, true).test().assertNoErrors();
ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(LlmRequest.class);
verify(mockGeminiDelegate).generateContent(requestCaptor.capture(), eq(true));
- assertThat(requestCaptor.getValue().model()).hasValue("gemini-1.5-flash");
+ assertThat(requestCaptor.getValue().model()).hasValue("whatever-model");
}
// Add a test to verify the vertexAI flag is set correctly.
@@ -112,7 +113,7 @@ public void generateContent_stripsApigeePrefixAndSendsToDelegate() {
public void generateContent_setsVertexAiFlagCorrectly_withVertexAi() {
ApigeeLlm llm =
ApigeeLlm.builder()
- .modelName("apigee/vertex_ai/gemini-1.5-flash")
+ .modelName("apigee/vertex_ai/whatever-model")
.proxyUrl(PROXY_URL)
.build();
assertThat(llm.getApiClient().vertexAI()).isTrue();
@@ -122,8 +123,10 @@ public void generateContent_setsVertexAiFlagCorrectly_withVertexAi() {
public void generateContent_setsVertexAiFlagCorrectly_withOrWithoutVertexAi() {
ApigeeLlm llm =
- ApigeeLlm.builder().modelName("apigee/gemini-1.5-flash").proxyUrl(PROXY_URL).build();
- if (System.getenv("GOOGLE_GENAI_USE_VERTEXAI") != null) {
+ ApigeeLlm.builder().modelName("apigee/whatever-model").proxyUrl(PROXY_URL).build();
+ String useVertexAi = System.getenv("GOOGLE_GENAI_USE_VERTEXAI");
+
+ if (Objects.equals(useVertexAi, "true") || Objects.equals(useVertexAi, "1")) {
assertThat(llm.getApiClient().vertexAI()).isTrue();
} else {
assertThat(llm.getApiClient().vertexAI()).isFalse();
@@ -133,7 +136,7 @@ public void generateContent_setsVertexAiFlagCorrectly_withOrWithoutVertexAi() {
@Test
public void generateContent_setsVertexAiFlagCorrectly_withGemini() {
ApigeeLlm llm =
- ApigeeLlm.builder().modelName("apigee/gemini/gemini-1.5-flash").proxyUrl(PROXY_URL).build();
+ ApigeeLlm.builder().modelName("apigee/gemini/whatever-model").proxyUrl(PROXY_URL).build();
assertThat(llm.getApiClient().vertexAI()).isFalse();
}
@@ -142,11 +145,11 @@ public void generateContent_setsVertexAiFlagCorrectly_withGemini() {
public void generateContent_setsApiVersionCorrectly() {
ImmutableMap modelToApiVersion =
ImmutableMap.of(
- "apigee/gemini-1.5-flash", "",
- "apigee/v1/gemini-1.5-flash", "v1",
- "apigee/vertex_ai/gemini-1.5-flash", "",
- "apigee/gemini/v1/gemini-1.5-flash", "v1",
- "apigee/vertex_ai/v1beta/gemini-1.5-flash", "v1beta");
+ "apigee/whatever-model", "",
+ "apigee/v1/whatever-model", "v1",
+ "apigee/vertex_ai/whatever-model", "",
+ "apigee/gemini/v1/whatever-model", "v1",
+ "apigee/vertex_ai/v1beta/whatever-model", "v1beta");
for (Map.Entry entry : modelToApiVersion.entrySet()) {
String modelName = entry.getKey();
@@ -165,7 +168,7 @@ public void build_withCustomHeaders_setsHeadersInHttpOptions() {
ImmutableMap customHeaders = ImmutableMap.of("X-Test-Header", "TestValue");
ApigeeLlm llm =
ApigeeLlm.builder()
- .modelName("apigee/gemini-1.5-flash")
+ .modelName("apigee/whatever-model")
.proxyUrl(PROXY_URL)
.customHeaders(customHeaders)
.build();
@@ -192,14 +195,14 @@ public void build_withTrailingSlashInModel_parsesVersionAndModelId() {
public void build_withoutProxyUrl_readsFromEnvironment() {
String envProxyUrl = System.getenv("APIGEE_PROXY_URL");
if (envProxyUrl != null) {
- ApigeeLlm llm = ApigeeLlm.builder().modelName("apigee/gemini-1.5-flash").build();
+ ApigeeLlm llm = ApigeeLlm.builder().modelName("apigee/whatever-model").build();
assertThat(llm.getHttpOptions().baseUrl()).hasValue(envProxyUrl);
} else {
assertThrows(
IllegalArgumentException.class,
- () -> ApigeeLlm.builder().modelName("apigee/gemini-1.5-flash").build());
+ () -> ApigeeLlm.builder().modelName("apigee/whatever-model").build());
ApigeeLlm llm =
- ApigeeLlm.builder().proxyUrl(PROXY_URL).modelName("apigee/gemini-1.5-flash").build();
+ ApigeeLlm.builder().proxyUrl(PROXY_URL).modelName("apigee/whatever-model").build();
assertThat(llm.getHttpOptions().baseUrl()).hasValue(PROXY_URL);
}
}