diff --git a/core/pom.xml b/core/pom.xml index 165a7b3fa..d31a2691b 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -201,6 +201,59 @@ maven-compiler-plugin + + maven-surefire-plugin + + + basic + + test + + + + apigee-llm + + test + + + ApigeeLlmTest + + + api-key + false + + + + + apigee-llm-vertex-ai + + test + + + ApigeeLlmTest#generateContent_setsVertexAiFlagCorrectly_withOrWithoutVertexAi + + api-key + + true + + + + + apigee-llm-proxy-url + + test + + + ApigeeLlmTest#build_withoutProxyUrl_readsFromEnvironment + + api-key + + proxy-url + + + + + diff --git a/core/src/main/java/com/google/adk/codeexecutors/ContainerCodeExecutor.java b/core/src/main/java/com/google/adk/codeexecutors/ContainerCodeExecutor.java index 1d1202ead..4e75dab75 100644 --- a/core/src/main/java/com/google/adk/codeexecutors/ContainerCodeExecutor.java +++ b/core/src/main/java/com/google/adk/codeexecutors/ContainerCodeExecutor.java @@ -17,6 +17,8 @@ package com.google.adk.codeexecutors; +import static java.util.Objects.requireNonNullElse; + import com.github.dockerjava.api.DockerClient; import com.github.dockerjava.api.command.ExecCreateCmdResponse; import com.github.dockerjava.api.model.Container; @@ -32,7 +34,6 @@ import java.io.UncheckedIOException; import java.nio.charset.StandardCharsets; import java.nio.file.Paths; -import java.util.Optional; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -41,37 +42,68 @@ public class ContainerCodeExecutor extends BaseCodeExecutor { private static final Logger logger = LoggerFactory.getLogger(ContainerCodeExecutor.class); private static final String DEFAULT_IMAGE_TAG = "adk-code-executor:latest"; - private final Optional baseUrl; + private final String baseUrl; private final String image; - private final Optional dockerPath; + private final String dockerPath; private final DockerClient dockerClient; private Container container; /** - * Initializes the ContainerCodeExecutor. + * Creates a ContainerCodeExecutor from an image. + * + * @param baseUrl The base url of the user hosted Docker client. + * @param image The tag of the predefined image or custom image to run on the container. + */ + public static ContainerCodeExecutor fromImage(String baseUrl, String image) { + return new ContainerCodeExecutor(baseUrl, image, null); + } + + /** + * Creates a ContainerCodeExecutor from an image. + * + * @param image The tag of the predefined image or custom image to run on the container. + */ + public static ContainerCodeExecutor fromImage(String image) { + return new ContainerCodeExecutor(null, image, null); + } + + /** + * Creates a ContainerCodeExecutor from a Dockerfile path. + * + * @param baseUrl The base url of the user hosted Docker client. + * @param dockerPath The path to the directory containing the Dockerfile. + */ + public static ContainerCodeExecutor fromDockerPath(String baseUrl, String dockerPath) { + return new ContainerCodeExecutor(baseUrl, null, dockerPath); + } + + /** + * Creates a ContainerCodeExecutor from a Dockerfile path. + * + * @param dockerPath The path to the directory containing the Dockerfile. + */ + public static ContainerCodeExecutor fromDockerPath(String dockerPath) { + return new ContainerCodeExecutor(null, null, dockerPath); + } + + /** + * Initializes the ContainerCodeExecutor. Either dockerPath or image must be set. * - * @param baseUrl Optional. The base url of the user hosted Docker client. - * @param image The tag of the predefined image or custom image to run on the container. Either - * dockerPath or image must be set. - * @param dockerPath The path to the directory containing the Dockerfile. If set, build the image - * from the dockerfile path instead of using the predefined image. Either dockerPath or image - * must be set. + * @deprecated Use one of the static factory methods instead. */ - public ContainerCodeExecutor( - Optional baseUrl, Optional image, Optional dockerPath) { - if (image.isEmpty() && dockerPath.isEmpty()) { + @Deprecated + public ContainerCodeExecutor(String baseUrl, String image, String dockerPath) { + if (image == null && dockerPath == null) { throw new IllegalArgumentException( "Either image or dockerPath must be set for ContainerCodeExecutor."); } this.baseUrl = baseUrl; - this.image = image.orElse(DEFAULT_IMAGE_TAG); - this.dockerPath = dockerPath.map(p -> Paths.get(p).toAbsolutePath().toString()); + this.image = requireNonNullElse(image, DEFAULT_IMAGE_TAG); + this.dockerPath = dockerPath == null ? null : Paths.get(dockerPath).toAbsolutePath().toString(); - if (baseUrl.isPresent()) { + if (baseUrl != null) { var config = - DefaultDockerClientConfig.createDefaultConfigBuilder() - .withDockerHost(baseUrl.get()) - .build(); + DefaultDockerClientConfig.createDefaultConfigBuilder().withDockerHost(baseUrl).build(); this.dockerClient = DockerClientBuilder.getInstance(config).build(); } else { this.dockerClient = DockerClientBuilder.getInstance().build(); @@ -121,12 +153,12 @@ public CodeExecutionResult executeCode( } private void buildDockerImage() { - if (dockerPath.isEmpty()) { + if (dockerPath == null) { throw new IllegalStateException("Docker path is not set."); } - File dockerfile = new File(dockerPath.get()); + File dockerfile = new File(dockerPath); if (!dockerfile.exists()) { - throw new UncheckedIOException(new IOException("Invalid Docker path: " + dockerPath.get())); + throw new UncheckedIOException(new IOException("Invalid Docker path: " + dockerPath)); } logger.info("Building Docker image..."); @@ -158,7 +190,7 @@ private void initContainer() { if (dockerClient == null) { throw new IllegalStateException("Docker client is not initialized."); } - if (dockerPath.isPresent()) { + if (dockerPath != null) { buildDockerImage(); } else { // If a dockerPath is not provided, always pull the image to ensure it's up-to-date. diff --git a/core/src/test/java/com/google/adk/models/ApigeeLlmTest.java b/core/src/test/java/com/google/adk/models/ApigeeLlmTest.java index 65364e7b4..6ba2832c0 100644 --- a/core/src/test/java/com/google/adk/models/ApigeeLlmTest.java +++ b/core/src/test/java/com/google/adk/models/ApigeeLlmTest.java @@ -31,6 +31,7 @@ import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Flowable; import java.util.Map; +import java.util.Objects; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -57,11 +58,11 @@ public void checkApiKey() { @Test public void build_withValidModelStrings_succeeds() { String[] validModelStrings = { - "apigee/gemini-1.5-flash", - "apigee/v1/gemini-1.5-flash", - "apigee/vertex_ai/gemini-1.5-flash", - "apigee/gemini/v1/gemini-1.5-flash", - "apigee/vertex_ai/v1beta/gemini-1.5-flash" + "apigee/whatever-model", + "apigee/v1/whatever-model", + "apigee/vertex_ai/whatever-model", + "apigee/gemini/v1/whatever-model", + "apigee/vertex_ai/v1beta/whatever-model" }; for (String modelName : validModelStrings) { @@ -93,18 +94,18 @@ public void build_withInvalidModelStrings_throwsException() { public void generateContent_stripsApigeePrefixAndSendsToDelegate() { when(mockGeminiDelegate.generateContent(any(), anyBoolean())).thenReturn(Flowable.empty()); - ApigeeLlm llm = new ApigeeLlm("apigee/gemini/v1/gemini-1.5-flash", mockGeminiDelegate); + ApigeeLlm llm = new ApigeeLlm("apigee/gemini/v1/whatever-model", mockGeminiDelegate); LlmRequest request = LlmRequest.builder() - .model("apigee/gemini/v1/gemini-1.5-flash") + .model("apigee/gemini/v1/whatever-model") .contents(ImmutableList.of(Content.builder().parts(Part.fromText("hi")).build())) .build(); llm.generateContent(request, true).test().assertNoErrors(); ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(LlmRequest.class); verify(mockGeminiDelegate).generateContent(requestCaptor.capture(), eq(true)); - assertThat(requestCaptor.getValue().model()).hasValue("gemini-1.5-flash"); + assertThat(requestCaptor.getValue().model()).hasValue("whatever-model"); } // Add a test to verify the vertexAI flag is set correctly. @@ -112,7 +113,7 @@ public void generateContent_stripsApigeePrefixAndSendsToDelegate() { public void generateContent_setsVertexAiFlagCorrectly_withVertexAi() { ApigeeLlm llm = ApigeeLlm.builder() - .modelName("apigee/vertex_ai/gemini-1.5-flash") + .modelName("apigee/vertex_ai/whatever-model") .proxyUrl(PROXY_URL) .build(); assertThat(llm.getApiClient().vertexAI()).isTrue(); @@ -122,8 +123,10 @@ public void generateContent_setsVertexAiFlagCorrectly_withVertexAi() { public void generateContent_setsVertexAiFlagCorrectly_withOrWithoutVertexAi() { ApigeeLlm llm = - ApigeeLlm.builder().modelName("apigee/gemini-1.5-flash").proxyUrl(PROXY_URL).build(); - if (System.getenv("GOOGLE_GENAI_USE_VERTEXAI") != null) { + ApigeeLlm.builder().modelName("apigee/whatever-model").proxyUrl(PROXY_URL).build(); + String useVertexAi = System.getenv("GOOGLE_GENAI_USE_VERTEXAI"); + + if (Objects.equals(useVertexAi, "true") || Objects.equals(useVertexAi, "1")) { assertThat(llm.getApiClient().vertexAI()).isTrue(); } else { assertThat(llm.getApiClient().vertexAI()).isFalse(); @@ -133,7 +136,7 @@ public void generateContent_setsVertexAiFlagCorrectly_withOrWithoutVertexAi() { @Test public void generateContent_setsVertexAiFlagCorrectly_withGemini() { ApigeeLlm llm = - ApigeeLlm.builder().modelName("apigee/gemini/gemini-1.5-flash").proxyUrl(PROXY_URL).build(); + ApigeeLlm.builder().modelName("apigee/gemini/whatever-model").proxyUrl(PROXY_URL).build(); assertThat(llm.getApiClient().vertexAI()).isFalse(); } @@ -142,11 +145,11 @@ public void generateContent_setsVertexAiFlagCorrectly_withGemini() { public void generateContent_setsApiVersionCorrectly() { ImmutableMap modelToApiVersion = ImmutableMap.of( - "apigee/gemini-1.5-flash", "", - "apigee/v1/gemini-1.5-flash", "v1", - "apigee/vertex_ai/gemini-1.5-flash", "", - "apigee/gemini/v1/gemini-1.5-flash", "v1", - "apigee/vertex_ai/v1beta/gemini-1.5-flash", "v1beta"); + "apigee/whatever-model", "", + "apigee/v1/whatever-model", "v1", + "apigee/vertex_ai/whatever-model", "", + "apigee/gemini/v1/whatever-model", "v1", + "apigee/vertex_ai/v1beta/whatever-model", "v1beta"); for (Map.Entry entry : modelToApiVersion.entrySet()) { String modelName = entry.getKey(); @@ -165,7 +168,7 @@ public void build_withCustomHeaders_setsHeadersInHttpOptions() { ImmutableMap customHeaders = ImmutableMap.of("X-Test-Header", "TestValue"); ApigeeLlm llm = ApigeeLlm.builder() - .modelName("apigee/gemini-1.5-flash") + .modelName("apigee/whatever-model") .proxyUrl(PROXY_URL) .customHeaders(customHeaders) .build(); @@ -192,14 +195,14 @@ public void build_withTrailingSlashInModel_parsesVersionAndModelId() { public void build_withoutProxyUrl_readsFromEnvironment() { String envProxyUrl = System.getenv("APIGEE_PROXY_URL"); if (envProxyUrl != null) { - ApigeeLlm llm = ApigeeLlm.builder().modelName("apigee/gemini-1.5-flash").build(); + ApigeeLlm llm = ApigeeLlm.builder().modelName("apigee/whatever-model").build(); assertThat(llm.getHttpOptions().baseUrl()).hasValue(envProxyUrl); } else { assertThrows( IllegalArgumentException.class, - () -> ApigeeLlm.builder().modelName("apigee/gemini-1.5-flash").build()); + () -> ApigeeLlm.builder().modelName("apigee/whatever-model").build()); ApigeeLlm llm = - ApigeeLlm.builder().proxyUrl(PROXY_URL).modelName("apigee/gemini-1.5-flash").build(); + ApigeeLlm.builder().proxyUrl(PROXY_URL).modelName("apigee/whatever-model").build(); assertThat(llm.getHttpOptions().baseUrl()).hasValue(PROXY_URL); } }