Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,59 @@
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
</plugin>
<plugin>
<artifactId>maven-surefire-plugin</artifactId>
<executions>
<execution>
<id>basic</id>
<goals>
<goal>test</goal>
</goals>
</execution>
<execution>
<id>apigee-llm</id>
<goals>
<goal>test</goal>
</goals>
<configuration>
<test>ApigeeLlmTest</test>
<!-- this test requires the following env -->
<environmentVariables>
<GOOGLE_API_KEY>api-key</GOOGLE_API_KEY>
<GOOGLE_GENAI_USE_VERTEXAI>false</GOOGLE_GENAI_USE_VERTEXAI>
</environmentVariables>
</configuration>
</execution>
<execution>
<id>apigee-llm-vertex-ai</id>
<goals>
<goal>test</goal>
</goals>
<configuration>
<test>ApigeeLlmTest#generateContent_setsVertexAiFlagCorrectly_withOrWithoutVertexAi</test>
<environmentVariables>
<GOOGLE_API_KEY>api-key</GOOGLE_API_KEY>
<!-- runs a second variant of the test method -->
<GOOGLE_GENAI_USE_VERTEXAI>true</GOOGLE_GENAI_USE_VERTEXAI>
</environmentVariables>
</configuration>
</execution>
<execution>
<id>apigee-llm-proxy-url</id>
<goals>
<goal>test</goal>
</goals>
<configuration>
<test>ApigeeLlmTest#build_withoutProxyUrl_readsFromEnvironment</test>
<environmentVariables>
<GOOGLE_API_KEY>api-key</GOOGLE_API_KEY>
<!-- runs a second variant of the test method -->
<APIGEE_PROXY_URL>proxy-url</APIGEE_PROXY_URL>
</environmentVariables>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package com.google.adk.codeexecutors;

import static java.util.Objects.requireNonNullElse;

import com.github.dockerjava.api.DockerClient;
import com.github.dockerjava.api.command.ExecCreateCmdResponse;
import com.github.dockerjava.api.model.Container;
Expand All @@ -32,7 +34,6 @@
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Paths;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -41,37 +42,68 @@ public class ContainerCodeExecutor extends BaseCodeExecutor {
private static final Logger logger = LoggerFactory.getLogger(ContainerCodeExecutor.class);
private static final String DEFAULT_IMAGE_TAG = "adk-code-executor:latest";

private final Optional<String> baseUrl;
private final String baseUrl;
private final String image;
private final Optional<String> dockerPath;
private final String dockerPath;
private final DockerClient dockerClient;
private Container container;

/**
* Initializes the ContainerCodeExecutor.
* Creates a ContainerCodeExecutor from an image.
*
* @param baseUrl The base url of the user hosted Docker client.
* @param image The tag of the predefined image or custom image to run on the container.
*/
public static ContainerCodeExecutor fromImage(String baseUrl, String image) {
return new ContainerCodeExecutor(baseUrl, image, null);
}

/**
* Creates a ContainerCodeExecutor from an image.
*
* @param image The tag of the predefined image or custom image to run on the container.
*/
public static ContainerCodeExecutor fromImage(String image) {
return new ContainerCodeExecutor(null, image, null);
}

/**
* Creates a ContainerCodeExecutor from a Dockerfile path.
*
* @param baseUrl The base url of the user hosted Docker client.
* @param dockerPath The path to the directory containing the Dockerfile.
*/
public static ContainerCodeExecutor fromDockerPath(String baseUrl, String dockerPath) {
return new ContainerCodeExecutor(baseUrl, null, dockerPath);
}

/**
* Creates a ContainerCodeExecutor from a Dockerfile path.
*
* @param dockerPath The path to the directory containing the Dockerfile.
*/
public static ContainerCodeExecutor fromDockerPath(String dockerPath) {
return new ContainerCodeExecutor(null, null, dockerPath);
}

/**
* Initializes the ContainerCodeExecutor. Either dockerPath or image must be set.
*
* @param baseUrl Optional. The base url of the user hosted Docker client.
* @param image The tag of the predefined image or custom image to run on the container. Either
* dockerPath or image must be set.
* @param dockerPath The path to the directory containing the Dockerfile. If set, build the image
* from the dockerfile path instead of using the predefined image. Either dockerPath or image
* must be set.
* @deprecated Use one of the static factory methods instead.
*/
public ContainerCodeExecutor(
Optional<String> baseUrl, Optional<String> image, Optional<String> dockerPath) {
if (image.isEmpty() && dockerPath.isEmpty()) {
@Deprecated
public ContainerCodeExecutor(String baseUrl, String image, String dockerPath) {
if (image == null && dockerPath == null) {
throw new IllegalArgumentException(
"Either image or dockerPath must be set for ContainerCodeExecutor.");
}
this.baseUrl = baseUrl;
this.image = image.orElse(DEFAULT_IMAGE_TAG);
this.dockerPath = dockerPath.map(p -> Paths.get(p).toAbsolutePath().toString());
this.image = requireNonNullElse(image, DEFAULT_IMAGE_TAG);
this.dockerPath = dockerPath == null ? null : Paths.get(dockerPath).toAbsolutePath().toString();

if (baseUrl.isPresent()) {
if (baseUrl != null) {
var config =
DefaultDockerClientConfig.createDefaultConfigBuilder()
.withDockerHost(baseUrl.get())
.build();
DefaultDockerClientConfig.createDefaultConfigBuilder().withDockerHost(baseUrl).build();
this.dockerClient = DockerClientBuilder.getInstance(config).build();
} else {
this.dockerClient = DockerClientBuilder.getInstance().build();
Expand Down Expand Up @@ -121,12 +153,12 @@ public CodeExecutionResult executeCode(
}

private void buildDockerImage() {
if (dockerPath.isEmpty()) {
if (dockerPath == null) {
throw new IllegalStateException("Docker path is not set.");
}
File dockerfile = new File(dockerPath.get());
File dockerfile = new File(dockerPath);
if (!dockerfile.exists()) {
throw new UncheckedIOException(new IOException("Invalid Docker path: " + dockerPath.get()));
throw new UncheckedIOException(new IOException("Invalid Docker path: " + dockerPath));
}

logger.info("Building Docker image...");
Expand Down Expand Up @@ -158,7 +190,7 @@ private void initContainer() {
if (dockerClient == null) {
throw new IllegalStateException("Docker client is not initialized.");
}
if (dockerPath.isPresent()) {
if (dockerPath != null) {
buildDockerImage();
} else {
// If a dockerPath is not provided, always pull the image to ensure it's up-to-date.
Expand Down
45 changes: 24 additions & 21 deletions core/src/test/java/com/google/adk/models/ApigeeLlmTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import com.google.genai.types.Part;
import io.reactivex.rxjava3.core.Flowable;
import java.util.Map;
import java.util.Objects;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
Expand All @@ -57,11 +58,11 @@ public void checkApiKey() {
@Test
public void build_withValidModelStrings_succeeds() {
String[] validModelStrings = {
"apigee/gemini-1.5-flash",
"apigee/v1/gemini-1.5-flash",
"apigee/vertex_ai/gemini-1.5-flash",
"apigee/gemini/v1/gemini-1.5-flash",
"apigee/vertex_ai/v1beta/gemini-1.5-flash"
"apigee/whatever-model",
"apigee/v1/whatever-model",
"apigee/vertex_ai/whatever-model",
"apigee/gemini/v1/whatever-model",
"apigee/vertex_ai/v1beta/whatever-model"
};

for (String modelName : validModelStrings) {
Expand Down Expand Up @@ -93,26 +94,26 @@ public void build_withInvalidModelStrings_throwsException() {
public void generateContent_stripsApigeePrefixAndSendsToDelegate() {
when(mockGeminiDelegate.generateContent(any(), anyBoolean())).thenReturn(Flowable.empty());

ApigeeLlm llm = new ApigeeLlm("apigee/gemini/v1/gemini-1.5-flash", mockGeminiDelegate);
ApigeeLlm llm = new ApigeeLlm("apigee/gemini/v1/whatever-model", mockGeminiDelegate);

LlmRequest request =
LlmRequest.builder()
.model("apigee/gemini/v1/gemini-1.5-flash")
.model("apigee/gemini/v1/whatever-model")
.contents(ImmutableList.of(Content.builder().parts(Part.fromText("hi")).build()))
.build();
llm.generateContent(request, true).test().assertNoErrors();

ArgumentCaptor<LlmRequest> requestCaptor = ArgumentCaptor.forClass(LlmRequest.class);
verify(mockGeminiDelegate).generateContent(requestCaptor.capture(), eq(true));
assertThat(requestCaptor.getValue().model()).hasValue("gemini-1.5-flash");
assertThat(requestCaptor.getValue().model()).hasValue("whatever-model");
}

// Add a test to verify the vertexAI flag is set correctly.
@Test
public void generateContent_setsVertexAiFlagCorrectly_withVertexAi() {
ApigeeLlm llm =
ApigeeLlm.builder()
.modelName("apigee/vertex_ai/gemini-1.5-flash")
.modelName("apigee/vertex_ai/whatever-model")
.proxyUrl(PROXY_URL)
.build();
assertThat(llm.getApiClient().vertexAI()).isTrue();
Expand All @@ -122,8 +123,10 @@ public void generateContent_setsVertexAiFlagCorrectly_withVertexAi() {
public void generateContent_setsVertexAiFlagCorrectly_withOrWithoutVertexAi() {

ApigeeLlm llm =
ApigeeLlm.builder().modelName("apigee/gemini-1.5-flash").proxyUrl(PROXY_URL).build();
if (System.getenv("GOOGLE_GENAI_USE_VERTEXAI") != null) {
ApigeeLlm.builder().modelName("apigee/whatever-model").proxyUrl(PROXY_URL).build();
String useVertexAi = System.getenv("GOOGLE_GENAI_USE_VERTEXAI");

if (Objects.equals(useVertexAi, "true") || Objects.equals(useVertexAi, "1")) {
assertThat(llm.getApiClient().vertexAI()).isTrue();
} else {
assertThat(llm.getApiClient().vertexAI()).isFalse();
Expand All @@ -133,7 +136,7 @@ public void generateContent_setsVertexAiFlagCorrectly_withOrWithoutVertexAi() {
@Test
public void generateContent_setsVertexAiFlagCorrectly_withGemini() {
ApigeeLlm llm =
ApigeeLlm.builder().modelName("apigee/gemini/gemini-1.5-flash").proxyUrl(PROXY_URL).build();
ApigeeLlm.builder().modelName("apigee/gemini/whatever-model").proxyUrl(PROXY_URL).build();
assertThat(llm.getApiClient().vertexAI()).isFalse();
}

Expand All @@ -142,11 +145,11 @@ public void generateContent_setsVertexAiFlagCorrectly_withGemini() {
public void generateContent_setsApiVersionCorrectly() {
ImmutableMap<String, String> modelToApiVersion =
ImmutableMap.of(
"apigee/gemini-1.5-flash", "",
"apigee/v1/gemini-1.5-flash", "v1",
"apigee/vertex_ai/gemini-1.5-flash", "",
"apigee/gemini/v1/gemini-1.5-flash", "v1",
"apigee/vertex_ai/v1beta/gemini-1.5-flash", "v1beta");
"apigee/whatever-model", "",
"apigee/v1/whatever-model", "v1",
"apigee/vertex_ai/whatever-model", "",
"apigee/gemini/v1/whatever-model", "v1",
"apigee/vertex_ai/v1beta/whatever-model", "v1beta");

for (Map.Entry<String, String> entry : modelToApiVersion.entrySet()) {
String modelName = entry.getKey();
Expand All @@ -165,7 +168,7 @@ public void build_withCustomHeaders_setsHeadersInHttpOptions() {
ImmutableMap<String, String> customHeaders = ImmutableMap.of("X-Test-Header", "TestValue");
ApigeeLlm llm =
ApigeeLlm.builder()
.modelName("apigee/gemini-1.5-flash")
.modelName("apigee/whatever-model")
.proxyUrl(PROXY_URL)
.customHeaders(customHeaders)
.build();
Expand All @@ -192,14 +195,14 @@ public void build_withTrailingSlashInModel_parsesVersionAndModelId() {
public void build_withoutProxyUrl_readsFromEnvironment() {
String envProxyUrl = System.getenv("APIGEE_PROXY_URL");
if (envProxyUrl != null) {
ApigeeLlm llm = ApigeeLlm.builder().modelName("apigee/gemini-1.5-flash").build();
ApigeeLlm llm = ApigeeLlm.builder().modelName("apigee/whatever-model").build();
assertThat(llm.getHttpOptions().baseUrl()).hasValue(envProxyUrl);
} else {
assertThrows(
IllegalArgumentException.class,
() -> ApigeeLlm.builder().modelName("apigee/gemini-1.5-flash").build());
() -> ApigeeLlm.builder().modelName("apigee/whatever-model").build());
ApigeeLlm llm =
ApigeeLlm.builder().proxyUrl(PROXY_URL).modelName("apigee/gemini-1.5-flash").build();
ApigeeLlm.builder().proxyUrl(PROXY_URL).modelName("apigee/whatever-model").build();
assertThat(llm.getHttpOptions().baseUrl()).hasValue(PROXY_URL);
}
}
Expand Down
Loading