diff --git a/TASK_QUEUE.json b/TASK_QUEUE.json new file mode 100644 index 000000000..38901a5a5 --- /dev/null +++ b/TASK_QUEUE.json @@ -0,0 +1,129 @@ +{ + "project": "temporal-spring-ai", + "tasks": [ + { + "id": "T13", + "title": "Remove includeBuild from samples-java", + "description": "Once temporal-spring-ai is published to Maven Central, remove the includeBuild('../sdk-java') block from samples-java/settings.gradle and the grpc-util workaround from core/build.gradle.", + "severity": "low", + "category": "cleanup", + "depends_on": [], + "status": "blocked", + "notes": "Blocked on SDK release. Not actionable yet." + }, + { + "id": "T15", + "title": "Change default tool execution to run in workflow context", + "description": "Currently unannotated tools passed to defaultTools() are rejected. Change so they execute directly in workflow context by default \u2014 user is responsible for determinism. Remove @DeterministicTool annotation (no longer needed since direct execution is the default). Remove SandboxingAdvisor, LocalActivityToolCallbackWrapper, ExecuteToolLocalActivity, and ExecuteToolLocalActivityImpl. Remove ExecuteToolLocalActivityImpl registration from SpringAiPlugin. Keep @SideEffectTool as a convenience for wrapping in Workflow.sideEffect(). Keep activity stub / nexus stub auto-detection as shortcuts.", + "severity": "high", + "category": "refactor", + "depends_on": [], + "status": "completed", + "notes": "Blocked on PR review discussion with tconley1428. Agreed on direction but need to finalize details before implementing. Proposed design: https://github.com/temporalio/sdk-java/pull/2829#discussion_r3060711651" + }, + { + "id": "T22", + "title": "Discuss: temporal-spring-ai-starter artifact for easier onboarding", + "description": "Create a temporal-spring-ai-starter artifact (POM-only, no code) that transitively pulls in temporal-spring-ai, temporal-sdk, and temporal-spring-boot-starter. Prevents version mismatches and ClassNotFoundException for users starting from scratch.", + "severity": "medium", + "category": "discussion", + "depends_on": [], + "status": "blocked", + "notes": "DABH review comment: https://github.com/temporalio/sdk-java/pull/2829/files#r3053808755. Do after merging the main PR \u2014 not urgent for initial landing." + }, + { + "id": "T23", + "title": "Discuss: ActivityMcpClient capability caching and replay", + "description": "ActivityMcpClient caches getServerCapabilities() after first call. DABH asks if stale cache is a replay concern. Probably fine (activity result is in history, so replay uses the original value \u2014 which is correct). But worth confirming the design intent.", + "severity": "low", + "category": "discussion", + "depends_on": [], + "status": "completed", + "notes": "Replied to DABH: cache prevents non-determinism (live vs replay would diverge without it). Standard MCP practice." + }, + { + "id": "T24", + "title": "Discuss: Change ChatModelTypes.rawContent from Object to String", + "description": "Change ChatModelTypes.Message rawContent from Object to String. Spring AI's Content.getText() returns String. We always cast to String on both sides anyway. Object type gives false flexibility that would ClassCastException at runtime.", + "severity": "low", + "category": "bugfix", + "depends_on": [], + "status": "completed", + "notes": "DABH review comment: https://github.com/temporalio/sdk-java/pull/2829/files#r3054049714. Verified Spring AI Message interface uses String, not Object." + }, + { + "id": "T25", + "title": "Reply: compatibility matrix in docs", + "description": "DABH suggests documenting the compatibility matrix (Java version, Spring Boot version, Spring AI version). Acknowledge and defer to a docs PR.", + "severity": "low", + "category": "reply", + "depends_on": [], + "status": "completed", + "notes": "Replied to DABH acknowledging. Follow-up task T29 created." + }, + { + "id": "T26", + "title": "Reply: SandboxingAdvisor lacks tests", + "description": "DABH notes SandboxingAdvisor has no tests. Likely moot if T15 removes it. Reply explaining that.", + "severity": "low", + "category": "reply", + "depends_on": [ + "T15" + ], + "status": "superseded", + "notes": "DABH review comment: https://github.com/temporalio/sdk-java/pull/2829/files#r3053836427 Moot once T15 removes these classes." + }, + { + "id": "T27", + "title": "Reply: ToolContext silently dropped in LocalActivityToolCallbackWrapper", + "description": "DABH notes ToolContext is silently ignored. Likely moot if T15 removes LocalActivityToolCallbackWrapper. Reply explaining that.", + "severity": "low", + "category": "reply", + "depends_on": [ + "T15" + ], + "status": "superseded", + "notes": "DABH review comment: https://github.com/temporalio/sdk-java/pull/2829/files#r3054036773 Moot once T15 removes these classes." + }, + { + "id": "T28", + "title": "Restore VectorStorePlugin and EmbeddingModelPlugin as public classes", + "description": "Restore the plugin subclasses so users not using auto-config can create them manually (e.g. new VectorStorePlugin(vectorStore)). Auto-config uses them too. Reverts the builder-inline approach from the earlier refactor.", + "severity": "medium", + "category": "refactor", + "depends_on": [], + "status": "completed", + "notes": "tconley1428 review comments on SpringAiPlugin.java and SpringAiVectorStoreAutoConfiguration.java" + }, + { + "id": "T29", + "title": "Add README with compatibility matrix to temporal-spring-ai module", + "description": "Document supported versions: Java 17+, Spring Boot 3.x+, Spring AI 1.1.0, Temporal SDK 1.33.0+.", + "severity": "high", + "category": "docs", + "depends_on": [], + "status": "completed", + "notes": "Do in this PR, not a follow-up." + }, + { + "id": "T30", + "title": "Fix CI: Edge build fails with Java version mismatch", + "description": "Edge CI sets edgeDepsTest which compiles temporal-sdk at Java 21. Our module hardcodes Java 17, causing Gradle to reject the dependency. Fix: use 21 when edgeDepsTest is set, 17 otherwise.", + "severity": "high", + "category": "bugfix", + "depends_on": [], + "status": "completed", + "notes": "Edge CI log showed: Dependency resolution is looking for a library compatible with JVM runtime version 17, but temporal-sdk is only compatible with JVM runtime version 21 or newer." + }, + { + "id": "T31", + "title": "Fix CI: Docker build fails with Java 11 (release 17 not supported)", + "description": "Docker CI runs Java 11 which cannot compile --release 17. Conditionally exclude temporal-spring-ai from settings.gradle and BOM when JDK < 17.", + "severity": "high", + "category": "bugfix", + "depends_on": [], + "status": "completed", + "notes": "Docker CI log showed: error: release version 17 not supported. JAVA_HOME was Java 11." + } + ] +} diff --git a/settings.gradle b/settings.gradle index 918ceaa28..1cb980a89 100644 --- a/settings.gradle +++ b/settings.gradle @@ -6,6 +6,11 @@ include 'temporal-testing' include 'temporal-test-server' include 'temporal-opentracing' include 'temporal-kotlin' +// temporal-spring-ai requires Java 17+ (Spring AI dependency). +// Exclude from builds running on older JDKs. +if (JavaVersion.current().isCompatibleWith(JavaVersion.VERSION_17)) { + include 'temporal-spring-ai' +} include 'temporal-spring-boot-autoconfigure' include 'temporal-spring-boot-starter' include 'temporal-remote-data-encoder' diff --git a/temporal-bom/build.gradle b/temporal-bom/build.gradle index 8f5a8971d..12ccaacea 100644 --- a/temporal-bom/build.gradle +++ b/temporal-bom/build.gradle @@ -12,6 +12,9 @@ dependencies { api project(':temporal-sdk') api project(':temporal-serviceclient') api project(':temporal-shaded') + if (JavaVersion.current().isCompatibleWith(JavaVersion.VERSION_17)) { + api project(':temporal-spring-ai') + } api project(':temporal-spring-boot-autoconfigure') api project(':temporal-spring-boot-starter') api project(':temporal-test-server') diff --git a/temporal-spring-ai/README.md b/temporal-spring-ai/README.md new file mode 100644 index 000000000..04ef610a2 --- /dev/null +++ b/temporal-spring-ai/README.md @@ -0,0 +1,54 @@ +# temporal-spring-ai + +Integrates [Spring AI](https://docs.spring.io/spring-ai/reference/) with [Temporal](https://temporal.io/) workflows, making AI model calls, tool execution, vector store operations, embeddings, and MCP tool calls durable Temporal primitives. + +## Compatibility + +| Dependency | Minimum Version | +|---|---| +| Java | 17 | +| Spring Boot | 3.x | +| Spring AI | 1.1.0 | +| Temporal Java SDK | 1.33.0 | + +## Quick Start + +Add the dependency (Maven): + +```xml + + io.temporal + temporal-spring-ai + ${temporal-sdk.version} + +``` + +You also need `temporal-spring-boot-starter` and a Spring AI model starter (e.g. `spring-ai-starter-model-openai`). + +The plugin auto-registers `ChatModelActivity` with all Temporal workers. In your workflow: + +```java +@WorkflowInit +public MyWorkflowImpl(String goal) { + ActivityChatModel chatModel = ActivityChatModel.forDefault(); + this.chatClient = TemporalChatClient.builder(chatModel) + .defaultSystem("You are a helpful assistant.") + .defaultTools(myActivityStub) + .build(); +} + +@Override +public String run(String goal) { + return chatClient.prompt().user(goal).call().content(); +} +``` + +## Optional Integrations + +These are auto-configured when their dependencies are on the classpath: + +| Feature | Dependency | What it registers | +|---|---|---| +| Vector Store | `spring-ai-rag` | `VectorStoreActivity` | +| Embeddings | `spring-ai-rag` | `EmbeddingModelActivity` | +| MCP | `spring-ai-mcp` | `McpClientActivity` | diff --git a/temporal-spring-ai/build.gradle b/temporal-spring-ai/build.gradle new file mode 100644 index 000000000..21cc09c9b --- /dev/null +++ b/temporal-spring-ai/build.gradle @@ -0,0 +1,62 @@ +description = '''Temporal Java SDK Spring AI Plugin''' + +ext { + springAiVersion = '1.1.0' + // Spring AI requires Spring Boot 3.x / Java 17+ + springBootVersionForSpringAi = "$springBoot3Version" +} + +// Spring AI requires Java 17+, override the default Java 8 target from java.gradle. +// When edgeDepsTest is set, use 21 to match other modules (avoids Gradle JVM compatibility rejection). +ext { + springAiJavaVersion = project.hasProperty("edgeDepsTest") ? JavaVersion.VERSION_21 : JavaVersion.VERSION_17 + springAiRelease = project.hasProperty("edgeDepsTest") ? '21' : '17' +} + +java { + sourceCompatibility = springAiJavaVersion + targetCompatibility = springAiJavaVersion +} + +compileJava { + options.compilerArgs.removeAll(['--release', '8']) + options.compilerArgs.addAll(['--release', springAiRelease]) +} + +compileTestJava { + options.compilerArgs.removeAll(['--release', '8']) + options.compilerArgs.addAll(['--release', springAiRelease]) +} + +dependencies { + api(platform("org.springframework.boot:spring-boot-dependencies:$springBootVersionForSpringAi")) + api(platform("org.springframework.ai:spring-ai-bom:$springAiVersion")) + + // this module shouldn't carry temporal-sdk with it, especially for situations when users may be using a shaded artifact + compileOnly project(':temporal-sdk') + compileOnly project(':temporal-spring-boot-autoconfigure') + + api 'org.springframework.boot:spring-boot-autoconfigure' + api 'org.springframework.ai:spring-ai-client-chat' + + implementation 'org.springframework.boot:spring-boot-starter' + + // Optional: Vector store support + compileOnly 'org.springframework.ai:spring-ai-rag' + + // Optional: MCP (Model Context Protocol) support + compileOnly 'org.springframework.ai:spring-ai-mcp' + + testImplementation project(':temporal-sdk') + testImplementation project(':temporal-testing') + testImplementation "org.mockito:mockito-core:${mockitoVersion}" + testImplementation 'org.springframework.boot:spring-boot-starter-test' + testImplementation 'org.springframework.ai:spring-ai-rag' + + testRuntimeOnly group: 'ch.qos.logback', name: 'logback-classic', version: "${logbackVersion}" + testRuntimeOnly "org.junit.platform:junit-platform-launcher" +} + +tasks.test { + useJUnitPlatform() +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/activity/ChatModelActivity.java b/temporal-spring-ai/src/main/java/io/temporal/springai/activity/ChatModelActivity.java new file mode 100644 index 000000000..19caf9a54 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/activity/ChatModelActivity.java @@ -0,0 +1,25 @@ +package io.temporal.springai.activity; + +import io.temporal.activity.ActivityInterface; +import io.temporal.activity.ActivityMethod; +import io.temporal.springai.model.ChatModelTypes; + +/** + * Temporal activity interface for calling Spring AI chat models. + * + *

This activity wraps a Spring AI {@link org.springframework.ai.chat.model.ChatModel} and makes + * it callable from within Temporal workflows. The activity handles serialization of prompts and + * responses, enabling durable AI conversations with automatic retries and timeout handling. + */ +@ActivityInterface +public interface ChatModelActivity { + + /** + * Calls the chat model with the given input. + * + * @param input the chat model input containing messages, options, and tool definitions + * @return the chat model output containing generated responses and metadata + */ + @ActivityMethod + ChatModelTypes.ChatModelActivityOutput callChatModel(ChatModelTypes.ChatModelActivityInput input); +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/activity/ChatModelActivityImpl.java b/temporal-spring-ai/src/main/java/io/temporal/springai/activity/ChatModelActivityImpl.java new file mode 100644 index 000000000..4eca09e67 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/activity/ChatModelActivityImpl.java @@ -0,0 +1,276 @@ +package io.temporal.springai.activity; + +import io.temporal.springai.model.ChatModelTypes; +import io.temporal.springai.model.ChatModelTypes.Message; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.springframework.ai.chat.messages.*; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.content.Media; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.core.io.ByteArrayResource; +import org.springframework.util.CollectionUtils; +import org.springframework.util.MimeType; + +/** + * Implementation of {@link ChatModelActivity} that delegates to a Spring AI {@link ChatModel}. + * + *

This implementation handles the conversion between Temporal-serializable types ({@link + * ChatModelTypes}) and Spring AI types. + * + *

Supports multiple chat models. The model to use is determined by the {@code modelName} field + * in the input. If no model name is specified, the default model is used. + */ +public class ChatModelActivityImpl implements ChatModelActivity { + + private final Map chatModels; + private final String defaultModelName; + + /** + * Creates an activity implementation with a single chat model. + * + * @param chatModel the chat model to use + */ + public ChatModelActivityImpl(ChatModel chatModel) { + this.chatModels = Map.of("default", chatModel); + this.defaultModelName = "default"; + } + + /** + * Creates an activity implementation with multiple chat models. + * + * @param chatModels map of model names to chat models + * @param defaultModelName the name of the default model to use when none is specified + */ + public ChatModelActivityImpl(Map chatModels, String defaultModelName) { + this.chatModels = chatModels; + this.defaultModelName = defaultModelName; + } + + @Override + public ChatModelTypes.ChatModelActivityOutput callChatModel( + ChatModelTypes.ChatModelActivityInput input) { + ChatModel chatModel = resolveChatModel(input.modelName()); + Prompt prompt = createPrompt(input); + ChatResponse response = chatModel.call(prompt); + return toOutput(response); + } + + private ChatModel resolveChatModel(String modelName) { + String name = (modelName != null && !modelName.isEmpty()) ? modelName : defaultModelName; + ChatModel model = chatModels.get(name); + if (model == null) { + throw new IllegalArgumentException( + "No chat model with name '" + name + "'. Available models: " + chatModels.keySet()); + } + return model; + } + + private Prompt createPrompt(ChatModelTypes.ChatModelActivityInput input) { + List messages = + input.messages().stream().map(this::toSpringMessage).collect(Collectors.toList()); + + ToolCallingChatOptions.Builder optionsBuilder = + ToolCallingChatOptions.builder() + .internalToolExecutionEnabled(false); // Let workflow handle tool execution + + if (input.modelOptions() != null) { + ChatModelTypes.ModelOptions opts = input.modelOptions(); + if (opts.model() != null) optionsBuilder.model(opts.model()); + if (opts.temperature() != null) optionsBuilder.temperature(opts.temperature()); + if (opts.maxTokens() != null) optionsBuilder.maxTokens(opts.maxTokens()); + if (opts.topP() != null) optionsBuilder.topP(opts.topP()); + if (opts.topK() != null) optionsBuilder.topK(opts.topK()); + if (opts.frequencyPenalty() != null) optionsBuilder.frequencyPenalty(opts.frequencyPenalty()); + if (opts.presencePenalty() != null) optionsBuilder.presencePenalty(opts.presencePenalty()); + if (opts.stopSequences() != null) optionsBuilder.stopSequences(opts.stopSequences()); + } + + // Add tool callbacks (stubs that provide definitions but won't be executed + // since internalToolExecutionEnabled is false) + if (!CollectionUtils.isEmpty(input.tools())) { + List toolCallbacks = + input.tools().stream() + .map( + tool -> + createStubToolCallback( + tool.function().name(), + tool.function().description(), + tool.function().jsonSchema())) + .collect(Collectors.toList()); + optionsBuilder.toolCallbacks(toolCallbacks); + } + + ToolCallingChatOptions chatOptions = optionsBuilder.build(); + + return Prompt.builder().messages(messages).chatOptions(chatOptions).build(); + } + + private org.springframework.ai.chat.messages.Message toSpringMessage(Message message) { + return switch (message.role()) { + case SYSTEM -> new SystemMessage(message.rawContent()); + case USER -> { + UserMessage.Builder builder = UserMessage.builder().text(message.rawContent()); + if (!CollectionUtils.isEmpty(message.mediaContents())) { + builder.media( + message.mediaContents().stream().map(this::toMedia).collect(Collectors.toList())); + } + yield builder.build(); + } + case ASSISTANT -> + AssistantMessage.builder() + .content(message.rawContent()) + .properties(Map.of()) + .toolCalls( + message.toolCalls() != null + ? message.toolCalls().stream() + .map( + tc -> + new AssistantMessage.ToolCall( + tc.id(), + tc.type(), + tc.function().name(), + tc.function().arguments())) + .collect(Collectors.toList()) + : List.of()) + .media( + message.mediaContents() != null + ? message.mediaContents().stream() + .map(this::toMedia) + .collect(Collectors.toList()) + : List.of()) + .build(); + case TOOL -> + ToolResponseMessage.builder() + .responses( + List.of( + new ToolResponseMessage.ToolResponse( + message.toolCallId(), message.name(), message.rawContent()))) + .build(); + }; + } + + private Media toMedia(ChatModelTypes.MediaContent mediaContent) { + MimeType mimeType = MimeType.valueOf(mediaContent.mimeType()); + if (mediaContent.uri() != null) { + try { + return new Media(mimeType, new URI(mediaContent.uri())); + } catch (URISyntaxException e) { + throw new RuntimeException("Invalid media URI: " + mediaContent.uri(), e); + } + } else if (mediaContent.data() != null) { + return new Media(mimeType, new ByteArrayResource(mediaContent.data())); + } + throw new IllegalArgumentException("Media content must have either uri or data"); + } + + private ChatModelTypes.ChatModelActivityOutput toOutput(ChatResponse response) { + List generations = + response.getResults().stream() + .map( + gen -> + new ChatModelTypes.ChatModelActivityOutput.Generation( + fromAssistantMessage(gen.getOutput()))) + .collect(Collectors.toList()); + + ChatModelTypes.ChatModelActivityOutput.ChatResponseMetadata metadata = null; + if (response.getMetadata() != null) { + var rateLimit = response.getMetadata().getRateLimit(); + var usage = response.getMetadata().getUsage(); + + metadata = + new ChatModelTypes.ChatModelActivityOutput.ChatResponseMetadata( + response.getMetadata().getModel(), + rateLimit != null + ? new ChatModelTypes.ChatModelActivityOutput.ChatResponseMetadata.RateLimit( + rateLimit.getRequestsLimit(), + rateLimit.getRequestsRemaining(), + rateLimit.getRequestsReset(), + rateLimit.getTokensLimit(), + rateLimit.getTokensRemaining(), + rateLimit.getTokensReset()) + : null, + usage != null + ? new ChatModelTypes.ChatModelActivityOutput.ChatResponseMetadata.Usage( + usage.getPromptTokens() != null ? usage.getPromptTokens().intValue() : null, + usage.getCompletionTokens() != null + ? usage.getCompletionTokens().intValue() + : null, + usage.getTotalTokens() != null ? usage.getTotalTokens().intValue() : null) + : null); + } + + return new ChatModelTypes.ChatModelActivityOutput(generations, metadata); + } + + private Message fromAssistantMessage(AssistantMessage assistantMessage) { + List toolCalls = null; + if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) { + toolCalls = + assistantMessage.getToolCalls().stream() + .map( + tc -> + new Message.ToolCall( + tc.id(), + tc.type(), + new Message.ChatCompletionFunction(tc.name(), tc.arguments()))) + .collect(Collectors.toList()); + } + + List mediaContents = null; + if (!CollectionUtils.isEmpty(assistantMessage.getMedia())) { + mediaContents = + assistantMessage.getMedia().stream().map(this::fromMedia).collect(Collectors.toList()); + } + + return new Message( + assistantMessage.getText(), Message.Role.ASSISTANT, null, null, toolCalls, mediaContents); + } + + private ChatModelTypes.MediaContent fromMedia(Media media) { + String mimeType = media.getMimeType().toString(); + if (media.getData() instanceof String uri) { + return new ChatModelTypes.MediaContent(mimeType, uri); + } else if (media.getData() instanceof byte[] data) { + return new ChatModelTypes.MediaContent(mimeType, data); + } + throw new IllegalArgumentException( + "Unsupported media data type: " + media.getData().getClass()); + } + + /** + * Creates a stub ToolCallback that provides a tool definition but throws if called. This is used + * because Spring AI's ChatModel API requires ToolCallbacks, but we only need to inform the model + * about available tools - actual execution happens in the workflow (since + * internalToolExecutionEnabled is false). + */ + private ToolCallback createStubToolCallback(String name, String description, String inputSchema) { + ToolDefinition toolDefinition = + ToolDefinition.builder() + .name(name) + .description(description) + .inputSchema(inputSchema) + .build(); + + return new ToolCallback() { + @Override + public ToolDefinition getToolDefinition() { + return toolDefinition; + } + + @Override + public String call(String toolInput) { + throw new UnsupportedOperationException( + "Tool execution should be handled by the workflow, not the activity. " + + "Ensure internalToolExecutionEnabled is set to false."); + } + }; + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/activity/EmbeddingModelActivity.java b/temporal-spring-ai/src/main/java/io/temporal/springai/activity/EmbeddingModelActivity.java new file mode 100644 index 000000000..8deed81f2 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/activity/EmbeddingModelActivity.java @@ -0,0 +1,62 @@ +package io.temporal.springai.activity; + +import io.temporal.activity.ActivityInterface; +import io.temporal.activity.ActivityMethod; +import io.temporal.springai.model.EmbeddingModelTypes; + +/** + * Temporal activity interface for Spring AI EmbeddingModel operations. + * + *

This activity wraps Spring AI's {@link org.springframework.ai.embedding.EmbeddingModel}, + * making embedding generation durable and retriable within Temporal workflows. + * + *

Example usage in a workflow: + * + *

{@code
+ * EmbeddingModelActivity embeddingModel = Workflow.newActivityStub(
+ *     EmbeddingModelActivity.class,
+ *     ActivityOptions.newBuilder()
+ *         .setStartToCloseTimeout(Duration.ofMinutes(2))
+ *         .build());
+ *
+ * // Embed single text
+ * EmbedOutput result = embeddingModel.embed(new EmbedTextInput("Hello world"));
+ * List vector = result.embedding();
+ *
+ * // Embed batch
+ * EmbedBatchOutput batchResult = embeddingModel.embedBatch(
+ *     new EmbedBatchInput(List.of("text1", "text2", "text3")));
+ * }
+ */ +@ActivityInterface +public interface EmbeddingModelActivity { + + /** + * Generates an embedding for a single text. + * + * @param input the text to embed + * @return the embedding vector + */ + @ActivityMethod + EmbeddingModelTypes.EmbedOutput embed(EmbeddingModelTypes.EmbedTextInput input); + + /** + * Generates embeddings for multiple texts in a single request. + * + *

This is more efficient than calling {@link #embed} multiple times when you have multiple + * texts to embed. + * + * @param input the texts to embed + * @return the embedding vectors with metadata + */ + @ActivityMethod + EmbeddingModelTypes.EmbedBatchOutput embedBatch(EmbeddingModelTypes.EmbedBatchInput input); + + /** + * Returns the dimensionality of the embedding vectors produced by this model. + * + * @return the number of dimensions + */ + @ActivityMethod + EmbeddingModelTypes.DimensionsOutput dimensions(); +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/activity/EmbeddingModelActivityImpl.java b/temporal-spring-ai/src/main/java/io/temporal/springai/activity/EmbeddingModelActivityImpl.java new file mode 100644 index 000000000..d0c082381 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/activity/EmbeddingModelActivityImpl.java @@ -0,0 +1,64 @@ +package io.temporal.springai.activity; + +import io.temporal.springai.model.EmbeddingModelTypes; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingResponse; + +/** + * Implementation of {@link EmbeddingModelActivity} that delegates to a Spring AI {@link + * EmbeddingModel}. + * + *

This implementation handles the conversion between Temporal-serializable types ({@link + * EmbeddingModelTypes}) and Spring AI types. + */ +public class EmbeddingModelActivityImpl implements EmbeddingModelActivity { + + private final EmbeddingModel embeddingModel; + + public EmbeddingModelActivityImpl(EmbeddingModel embeddingModel) { + this.embeddingModel = embeddingModel; + } + + @Override + public EmbeddingModelTypes.EmbedOutput embed(EmbeddingModelTypes.EmbedTextInput input) { + float[] embedding = embeddingModel.embed(input.text()); + return new EmbeddingModelTypes.EmbedOutput(embedding); + } + + @Override + public EmbeddingModelTypes.EmbedBatchOutput embedBatch( + EmbeddingModelTypes.EmbedBatchInput input) { + EmbeddingResponse response = embeddingModel.embedForResponse(input.texts()); + + List results = + IntStream.range(0, response.getResults().size()) + .mapToObj( + i -> { + var embedding = response.getResults().get(i); + return new EmbeddingModelTypes.EmbeddingResult(i, embedding.getOutput()); + }) + .collect(Collectors.toList()); + + EmbeddingModelTypes.EmbeddingMetadata metadata = null; + if (response.getMetadata() != null) { + var usage = response.getMetadata().getUsage(); + metadata = + new EmbeddingModelTypes.EmbeddingMetadata( + response.getMetadata().getModel(), + usage != null && usage.getTotalTokens() != null + ? usage.getTotalTokens().intValue() + : null, + embeddingModel.dimensions()); + } + + return new EmbeddingModelTypes.EmbedBatchOutput(results, metadata); + } + + @Override + public EmbeddingModelTypes.DimensionsOutput dimensions() { + return new EmbeddingModelTypes.DimensionsOutput(embeddingModel.dimensions()); + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/activity/VectorStoreActivity.java b/temporal-spring-ai/src/main/java/io/temporal/springai/activity/VectorStoreActivity.java new file mode 100644 index 000000000..51747e645 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/activity/VectorStoreActivity.java @@ -0,0 +1,59 @@ +package io.temporal.springai.activity; + +import io.temporal.activity.ActivityInterface; +import io.temporal.activity.ActivityMethod; +import io.temporal.springai.model.VectorStoreTypes; + +/** + * Temporal activity interface for Spring AI VectorStore operations. + * + *

This activity wraps Spring AI's {@link org.springframework.ai.vectorstore.VectorStore}, making + * vector database operations durable and retriable within Temporal workflows. + * + *

Example usage in a workflow: + * + *

{@code
+ * VectorStoreActivity vectorStore = Workflow.newActivityStub(
+ *     VectorStoreActivity.class,
+ *     ActivityOptions.newBuilder()
+ *         .setStartToCloseTimeout(Duration.ofMinutes(5))
+ *         .build());
+ *
+ * // Add documents
+ * vectorStore.addDocuments(new AddDocumentsInput(documents));
+ *
+ * // Search
+ * SearchOutput results = vectorStore.similaritySearch(new SearchInput("query", 10));
+ * }
+ */ +@ActivityInterface +public interface VectorStoreActivity { + + /** + * Adds documents to the vector store. + * + *

If the documents don't have pre-computed embeddings, the vector store will use its + * configured EmbeddingModel to generate them. + * + * @param input the documents to add + */ + @ActivityMethod + void addDocuments(VectorStoreTypes.AddDocumentsInput input); + + /** + * Deletes documents from the vector store by their IDs. + * + * @param input the IDs of documents to delete + */ + @ActivityMethod + void deleteByIds(VectorStoreTypes.DeleteByIdsInput input); + + /** + * Performs a similarity search in the vector store. + * + * @param input the search parameters + * @return the search results with similarity scores + */ + @ActivityMethod + VectorStoreTypes.SearchOutput similaritySearch(VectorStoreTypes.SearchInput input); +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/activity/VectorStoreActivityImpl.java b/temporal-spring-ai/src/main/java/io/temporal/springai/activity/VectorStoreActivityImpl.java new file mode 100644 index 000000000..80ce75518 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/activity/VectorStoreActivityImpl.java @@ -0,0 +1,98 @@ +package io.temporal.springai.activity; + +import io.temporal.springai.model.VectorStoreTypes; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.springframework.ai.document.Document; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.VectorStore; +import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser; + +/** + * Implementation of {@link VectorStoreActivity} that delegates to a Spring AI {@link VectorStore}. + * + *

This implementation handles the conversion between Temporal-serializable types ({@link + * VectorStoreTypes}) and Spring AI types. + */ +public class VectorStoreActivityImpl implements VectorStoreActivity { + + private final VectorStore vectorStore; + private final FilterExpressionTextParser filterParser = new FilterExpressionTextParser(); + + public VectorStoreActivityImpl(VectorStore vectorStore) { + this.vectorStore = vectorStore; + } + + @Override + public void addDocuments(VectorStoreTypes.AddDocumentsInput input) { + List documents = + input.documents().stream().map(this::toSpringDocument).collect(Collectors.toList()); + vectorStore.add(documents); + } + + @Override + public void deleteByIds(VectorStoreTypes.DeleteByIdsInput input) { + vectorStore.delete(input.ids()); + } + + @Override + public VectorStoreTypes.SearchOutput similaritySearch(VectorStoreTypes.SearchInput input) { + SearchRequest.Builder requestBuilder = + SearchRequest.builder().query(input.query()).topK(input.topK()); + + if (input.similarityThreshold() != null) { + requestBuilder.similarityThreshold(input.similarityThreshold()); + } + + if (input.filterExpression() != null && !input.filterExpression().isBlank()) { + requestBuilder.filterExpression(filterParser.parse(input.filterExpression())); + } + + List results = vectorStore.similaritySearch(requestBuilder.build()); + + List searchResults = + results.stream() + .map(doc -> new VectorStoreTypes.SearchResult(fromSpringDocument(doc), doc.getScore())) + .collect(Collectors.toList()); + + return new VectorStoreTypes.SearchOutput(searchResults); + } + + private Document toSpringDocument(VectorStoreTypes.Document doc) { + Document.Builder builder = Document.builder().id(doc.id()).text(doc.text()); + + if (doc.metadata() != null && !doc.metadata().isEmpty()) { + builder.metadata(new HashMap<>(doc.metadata())); + } + + return builder.build(); + } + + private VectorStoreTypes.Document fromSpringDocument(Document doc) { + // Convert metadata, handling potential non-serializable values + Map metadata = new HashMap<>(); + if (doc.getMetadata() != null) { + for (Map.Entry entry : doc.getMetadata().entrySet()) { + Object value = entry.getValue(); + // Only include serializable primitive types + if (value == null + || value instanceof String + || value instanceof Number + || value instanceof Boolean) { + metadata.put(entry.getKey(), value); + } else { + metadata.put(entry.getKey(), value.toString()); + } + } + } + + return new VectorStoreTypes.Document( + doc.getId(), + doc.getText(), + metadata, + null // Don't include embedding in results to reduce payload size + ); + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/autoconfigure/SpringAiEmbeddingAutoConfiguration.java b/temporal-spring-ai/src/main/java/io/temporal/springai/autoconfigure/SpringAiEmbeddingAutoConfiguration.java new file mode 100644 index 000000000..286392ed7 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/autoconfigure/SpringAiEmbeddingAutoConfiguration.java @@ -0,0 +1,25 @@ +package io.temporal.springai.autoconfigure; + +import io.temporal.springai.plugin.EmbeddingModelPlugin; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.context.annotation.Bean; + +/** + * Auto-configuration for EmbeddingModel integration with Temporal. + * + *

Conditionally creates an {@link EmbeddingModelPlugin} when {@code spring-ai-rag} is on the + * classpath and an {@link EmbeddingModel} bean is available. + */ +@AutoConfiguration(after = SpringAiTemporalAutoConfiguration.class) +@ConditionalOnClass(name = "org.springframework.ai.embedding.EmbeddingModel") +@ConditionalOnBean(EmbeddingModel.class) +public class SpringAiEmbeddingAutoConfiguration { + + @Bean + public EmbeddingModelPlugin embeddingModelPlugin(EmbeddingModel embeddingModel) { + return new EmbeddingModelPlugin(embeddingModel); + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/autoconfigure/SpringAiMcpAutoConfiguration.java b/temporal-spring-ai/src/main/java/io/temporal/springai/autoconfigure/SpringAiMcpAutoConfiguration.java new file mode 100644 index 000000000..0fa299f85 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/autoconfigure/SpringAiMcpAutoConfiguration.java @@ -0,0 +1,22 @@ +package io.temporal.springai.autoconfigure; + +import io.temporal.springai.plugin.McpPlugin; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.context.annotation.Bean; + +/** + * Auto-configuration for MCP (Model Context Protocol) integration with Temporal. + * + *

Conditionally creates a {@link McpPlugin} when {@code spring-ai-mcp} and the MCP client + * library are on the classpath. + */ +@AutoConfiguration(after = SpringAiTemporalAutoConfiguration.class) +@ConditionalOnClass(name = "io.modelcontextprotocol.client.McpSyncClient") +public class SpringAiMcpAutoConfiguration { + + @Bean + public McpPlugin mcpPlugin() { + return new McpPlugin(); + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/autoconfigure/SpringAiTemporalAutoConfiguration.java b/temporal-spring-ai/src/main/java/io/temporal/springai/autoconfigure/SpringAiTemporalAutoConfiguration.java new file mode 100644 index 000000000..688a88c68 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/autoconfigure/SpringAiTemporalAutoConfiguration.java @@ -0,0 +1,36 @@ +package io.temporal.springai.autoconfigure; + +import io.temporal.springai.plugin.SpringAiPlugin; +import java.util.Map; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.beans.factory.ObjectProvider; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.context.annotation.Bean; + +/** + * Core auto-configuration for the Spring AI Temporal plugin. + * + *

Creates the {@link SpringAiPlugin} bean which registers {@link + * io.temporal.springai.activity.ChatModelActivity} with all Temporal workers. + * + *

Optional integrations are handled by separate auto-configuration classes: + * + *

+ */ +@AutoConfiguration +@ConditionalOnClass( + name = {"org.springframework.ai.chat.model.ChatModel", "io.temporal.worker.Worker"}) +public class SpringAiTemporalAutoConfiguration { + + @Bean + public SpringAiPlugin springAiPlugin( + @Autowired Map chatModels, ObjectProvider primaryChatModel) { + return new SpringAiPlugin(chatModels, primaryChatModel.getIfUnique()); + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/autoconfigure/SpringAiVectorStoreAutoConfiguration.java b/temporal-spring-ai/src/main/java/io/temporal/springai/autoconfigure/SpringAiVectorStoreAutoConfiguration.java new file mode 100644 index 000000000..bf2cf1ff8 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/autoconfigure/SpringAiVectorStoreAutoConfiguration.java @@ -0,0 +1,25 @@ +package io.temporal.springai.autoconfigure; + +import io.temporal.springai.plugin.VectorStorePlugin; +import org.springframework.ai.vectorstore.VectorStore; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.context.annotation.Bean; + +/** + * Auto-configuration for VectorStore integration with Temporal. + * + *

Conditionally creates a {@link VectorStorePlugin} when {@code spring-ai-rag} is on the + * classpath and a {@link VectorStore} bean is available. + */ +@AutoConfiguration(after = SpringAiTemporalAutoConfiguration.class) +@ConditionalOnClass(name = "org.springframework.ai.vectorstore.VectorStore") +@ConditionalOnBean(VectorStore.class) +public class SpringAiVectorStoreAutoConfiguration { + + @Bean + public VectorStorePlugin vectorStorePlugin(VectorStore vectorStore) { + return new VectorStorePlugin(vectorStore); + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/chat/TemporalChatClient.java b/temporal-spring-ai/src/main/java/io/temporal/springai/chat/TemporalChatClient.java new file mode 100644 index 000000000..94517502e --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/chat/TemporalChatClient.java @@ -0,0 +1,176 @@ +package io.temporal.springai.chat; + +import io.micrometer.observation.ObservationRegistry; +import io.temporal.springai.util.TemporalToolUtil; +import java.util.Map; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.client.DefaultChatClient; +import org.springframework.ai.chat.client.DefaultChatClientBuilder; +import org.springframework.ai.chat.client.observation.ChatClientObservationConvention; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * A Temporal-aware implementation of Spring AI's {@link ChatClient} that understands Temporal + * primitives like activity stubs and deterministic tools. + * + *

This client extends Spring AI's {@link DefaultChatClient} to add support for Temporal-specific + * features: + * + *

+ * + *

Example usage in a workflow: + * + *

{@code
+ * @WorkflowInit
+ * public MyWorkflowImpl() {
+ *     // Create the activity-backed chat model
+ *     ChatModelActivity chatModelActivity = Workflow.newActivityStub(
+ *             ChatModelActivity.class, activityOptions);
+ *     ActivityChatModel activityChatModel = new ActivityChatModel(chatModelActivity);
+ *
+ *     // Create tools
+ *     WeatherActivity weatherTool = Workflow.newActivityStub(WeatherActivity.class, opts);
+ *
+ *     // Build the Temporal-aware chat client
+ *     this.chatClient = TemporalChatClient.builder(activityChatModel)
+ *             .defaultSystem("You are a helpful assistant.")
+ *             .defaultTools(weatherTool, mathTools)
+ *             .build();
+ * }
+ *
+ * @Override
+ * public String chat(String message) {
+ *     return chatClient.prompt()
+ *             .user(message)
+ *             .call()
+ *             .content();
+ * }
+ * }
+ * + * @see Builder + * @see io.temporal.springai.model.ActivityChatModel + */ +public class TemporalChatClient extends DefaultChatClient { + + /** + * Creates a new TemporalChatClient with the given request specification. + * + * @param defaultChatClientRequest the default request specification + */ + public TemporalChatClient(DefaultChatClientRequestSpec defaultChatClientRequest) { + super(defaultChatClientRequest); + } + + /** + * Creates a builder for constructing a TemporalChatClient. + * + * @param chatModel the chat model to use (typically an {@code ActivityChatModel}) + * @return a new builder + */ + public static Builder builder(ChatModel chatModel) { + return builder(chatModel, ObservationRegistry.NOOP, null); + } + + /** + * Creates a builder with observation support. + * + * @param chatModel the chat model to use + * @param observationRegistry the observation registry for metrics + * @param customObservationConvention optional custom observation convention + * @return a new builder + */ + public static Builder builder( + ChatModel chatModel, + ObservationRegistry observationRegistry, + @Nullable ChatClientObservationConvention customObservationConvention) { + Assert.notNull(chatModel, "chatModel cannot be null"); + Assert.notNull(observationRegistry, "observationRegistry cannot be null"); + return new Builder(chatModel, observationRegistry, customObservationConvention); + } + + /** + * A builder for creating {@link TemporalChatClient} instances that understand Temporal + * primitives. + * + *

This builder extends Spring AI's {@link DefaultChatClientBuilder} to add support for + * Temporal-specific tool types. When you call {@link #defaultTools(Object...)}, the builder + * automatically detects and converts: + * + *

+ * + * @see TemporalToolUtil + */ + public static class Builder extends DefaultChatClientBuilder { + + /** + * Creates a new builder for the given chat model. + * + * @param chatModel the chat model to use + */ + public Builder(ChatModel chatModel) { + super(chatModel, ObservationRegistry.NOOP, null, null); + } + + /** + * Creates a new builder with observation support. + * + * @param chatModel the chat model to use + * @param observationRegistry the observation registry for metrics + * @param customObservationConvention optional custom observation convention + */ + public Builder( + ChatModel chatModel, + ObservationRegistry observationRegistry, + @Nullable ChatClientObservationConvention customObservationConvention) { + super(chatModel, observationRegistry, customObservationConvention, null); + } + + /** + * Sets the default tools for all requests. + * + *

Activity stubs and Nexus stubs are auto-detected and executed as durable operations. + * {@code @SideEffectTool} classes are wrapped in {@code Workflow.sideEffect()}. Everything else + * executes directly in workflow context — the user is responsible for determinism. + * + * @param toolObjects the tool objects (activity stubs, {@code @SideEffectTool} instances, plain + * {@code @Tool} objects, etc.) + * @return this builder + */ + @Override + public ChatClient.Builder defaultTools(Object... toolObjects) { + Assert.notNull(toolObjects, "toolObjects cannot be null"); + Assert.noNullElements(toolObjects, "toolObjects cannot contain null elements"); + this.defaultRequest.toolCallbacks(TemporalToolUtil.convertTools(toolObjects)); + return this; + } + + /** + * Tool context is not supported in Temporal workflows. + * + *

Tool context requires mutable state that cannot be safely passed through Temporal's + * serialization boundaries. Use activity parameters or workflow state instead. + * + * @param toolContext ignored + * @return never returns + * @throws UnsupportedOperationException always + */ + @Override + public ChatClient.Builder defaultToolContext(Map toolContext) { + throw new UnsupportedOperationException( + "defaultToolContext is not supported in TemporalChatClient. " + + "Tool context cannot be safely serialized through Temporal activities. " + + "Consider passing required context as activity parameters or workflow state."); + } + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/mcp/ActivityMcpClient.java b/temporal-spring-ai/src/main/java/io/temporal/springai/mcp/ActivityMcpClient.java new file mode 100644 index 000000000..360412a83 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/mcp/ActivityMcpClient.java @@ -0,0 +1,141 @@ +package io.temporal.springai.mcp; + +import io.modelcontextprotocol.spec.McpSchema; +import io.temporal.activity.ActivityOptions; +import io.temporal.common.RetryOptions; +import io.temporal.workflow.Workflow; +import java.time.Duration; +import java.util.Map; + +/** + * A workflow-safe wrapper for MCP (Model Context Protocol) client operations. + * + *

This class provides access to MCP tools within Temporal workflows. All MCP operations are + * executed as activities, providing durability, automatic retries, and timeout handling. + * + *

Usage in Workflows

+ * + *
{@code
+ * @WorkflowInit
+ * public MyWorkflowImpl() {
+ *     // Create an MCP client with default options
+ *     ActivityMcpClient mcpClient = ActivityMcpClient.create();
+ *
+ *     // Get tools from all connected MCP servers
+ *     List mcpTools = McpToolCallback.fromMcpClient(mcpClient);
+ *
+ *     // Use with TemporalChatClient
+ *     this.chatClient = TemporalChatClient.builder(chatModel)
+ *             .defaultToolCallbacks(mcpTools)
+ *             .build();
+ * }
+ * }
+ * + *

MCP Server Configuration

+ * + *

MCP servers are configured in the worker's Spring context using Spring AI's MCP client + * configuration. See the Spring AI MCP documentation for details. + * + * @see McpClientActivity + * @see McpToolCallback + */ +public class ActivityMcpClient { + + /** Default timeout for MCP activity calls (30 seconds). */ + public static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(30); + + /** Default maximum retry attempts for MCP activity calls. */ + public static final int DEFAULT_MAX_ATTEMPTS = 3; + + private final McpClientActivity activity; + private Map serverCapabilities; + private Map clientInfo; + + /** + * Creates a new ActivityMcpClient with the given activity stub. + * + * @param activity the activity stub for MCP operations + */ + public ActivityMcpClient(McpClientActivity activity) { + this.activity = activity; + } + + /** + * Creates an ActivityMcpClient with default options. + * + *

Must be called from workflow code. + * + * @return a new ActivityMcpClient + */ + public static ActivityMcpClient create() { + return create(DEFAULT_TIMEOUT, DEFAULT_MAX_ATTEMPTS); + } + + /** + * Creates an ActivityMcpClient with custom options. + * + *

Must be called from workflow code. + * + * @param timeout the activity start-to-close timeout + * @param maxAttempts the maximum number of retry attempts + * @return a new ActivityMcpClient + */ + public static ActivityMcpClient create(Duration timeout, int maxAttempts) { + McpClientActivity activity = + Workflow.newActivityStub( + McpClientActivity.class, + ActivityOptions.newBuilder() + .setStartToCloseTimeout(timeout) + .setRetryOptions(RetryOptions.newBuilder().setMaximumAttempts(maxAttempts).build()) + .build()); + return new ActivityMcpClient(activity); + } + + /** + * Gets the server capabilities for all connected MCP clients. + * + *

Results are cached after the first call. + * + * @return map of client name to server capabilities + */ + public Map getServerCapabilities() { + if (serverCapabilities == null) { + serverCapabilities = activity.getServerCapabilities(); + } + return serverCapabilities; + } + + /** + * Gets client info for all connected MCP clients. + * + *

Results are cached after the first call. + * + * @return map of client name to client implementation info + */ + public Map getClientInfo() { + if (clientInfo == null) { + clientInfo = activity.getClientInfo(); + } + return clientInfo; + } + + /** + * Calls a tool on a specific MCP client. + * + * @param clientName the name of the MCP client + * @param request the tool call request + * @return the tool call result + */ + public McpSchema.CallToolResult callTool(String clientName, McpSchema.CallToolRequest request) { + return activity.callTool(clientName, request); + } + + /** + * Lists all available tools from all connected MCP clients. + * + * @return map of client name to list of tools + */ + public Map listTools() { + return activity.listTools(); + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/mcp/McpClientActivity.java b/temporal-spring-ai/src/main/java/io/temporal/springai/mcp/McpClientActivity.java new file mode 100644 index 000000000..5c17ce7d3 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/mcp/McpClientActivity.java @@ -0,0 +1,56 @@ +package io.temporal.springai.mcp; + +import io.modelcontextprotocol.spec.McpSchema; +import io.temporal.activity.ActivityInterface; +import io.temporal.activity.ActivityMethod; +import java.util.Map; + +/** + * Activity interface for interacting with MCP (Model Context Protocol) clients. + * + *

This activity provides durable access to MCP servers, allowing workflows to discover and call + * MCP tools as Temporal activities with full retry and timeout support. + * + *

The activity implementation ({@link McpClientActivityImpl}) is automatically registered by the + * plugin when MCP clients are available in the Spring context. + * + * @see ActivityMcpClient + * @see McpToolCallback + */ +@ActivityInterface(namePrefix = "MCP-Client-") +public interface McpClientActivity { + + /** + * Gets the server capabilities for all connected MCP clients. + * + * @return map of client name to server capabilities + */ + @ActivityMethod + Map getServerCapabilities(); + + /** + * Gets client info for all connected MCP clients. + * + * @return map of client name to client implementation info + */ + @ActivityMethod + Map getClientInfo(); + + /** + * Calls a tool on a specific MCP client. + * + * @param clientName the name of the MCP client + * @param request the tool call request + * @return the tool call result + */ + @ActivityMethod + McpSchema.CallToolResult callTool(String clientName, McpSchema.CallToolRequest request); + + /** + * Lists all available tools from all connected MCP clients. + * + * @return map of client name to list of tools + */ + @ActivityMethod + Map listTools(); +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/mcp/McpClientActivityImpl.java b/temporal-spring-ai/src/main/java/io/temporal/springai/mcp/McpClientActivityImpl.java new file mode 100644 index 000000000..439196d04 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/mcp/McpClientActivityImpl.java @@ -0,0 +1,74 @@ +package io.temporal.springai.mcp; + +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.spec.McpSchema; +import io.temporal.failure.ApplicationFailure; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * Implementation of {@link McpClientActivity} that delegates to Spring AI MCP clients. + * + *

This activity provides durable access to MCP servers. It is automatically registered by the + * plugin when MCP clients are available in the Spring context. + */ +public class McpClientActivityImpl implements McpClientActivity { + + private final Map mcpClients; + + /** + * Creates an activity implementation with the given MCP clients. + * + * @param mcpClients list of MCP sync clients from Spring context + */ + public McpClientActivityImpl(List mcpClients) { + this.mcpClients = + mcpClients.stream() + .collect( + Collectors.toMap( + c -> c.getClientInfo().name(), + c -> c, + (existing, duplicate) -> { + throw new IllegalArgumentException( + "Duplicate MCP client name: '" + + existing.getClientInfo().name() + + "'. Each MCP client must have a unique name."); + })); + } + + @Override + public Map getServerCapabilities() { + return mcpClients.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getServerCapabilities())); + } + + @Override + public Map getClientInfo() { + return mcpClients.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getClientInfo())); + } + + @Override + public McpSchema.CallToolResult callTool(String clientName, McpSchema.CallToolRequest request) { + McpSyncClient client = mcpClients.get(clientName); + if (client == null) { + throw ApplicationFailure.newBuilder() + .setType("ClientNotFound") + .setMessage( + "MCP client '" + + clientName + + "' not found. Available clients: " + + mcpClients.keySet()) + .setNonRetryable(true) + .build(); + } + return client.callTool(request); + } + + @Override + public Map listTools() { + return mcpClients.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().listTools())); + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/mcp/McpToolCallback.java b/temporal-spring-ai/src/main/java/io/temporal/springai/mcp/McpToolCallback.java new file mode 100644 index 000000000..9cf821aae --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/mcp/McpToolCallback.java @@ -0,0 +1,133 @@ +package io.temporal.springai.mcp; + +import io.modelcontextprotocol.spec.McpSchema; +import java.util.List; +import java.util.Map; +import org.springframework.ai.mcp.McpToolUtils; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.DefaultToolDefinition; +import org.springframework.ai.tool.definition.ToolDefinition; + +/** + * A {@link ToolCallback} implementation that executes MCP tools via Temporal activities. + * + *

This class bridges MCP tools with Spring AI's tool calling system, allowing AI models to call + * MCP server tools through durable Temporal activities. + * + *

Usage in Workflows

+ * + *
{@code
+ * @WorkflowInit
+ * public MyWorkflowImpl() {
+ *     // Create an MCP client
+ *     ActivityMcpClient mcpClient = ActivityMcpClient.create();
+ *
+ *     // Convert MCP tools to ToolCallbacks
+ *     List mcpTools = McpToolCallback.fromMcpClient(mcpClient);
+ *
+ *     // Use with TemporalChatClient
+ *     this.chatClient = TemporalChatClient.builder(chatModel)
+ *             .defaultToolCallbacks(mcpTools)
+ *             .build();
+ * }
+ * }
+ * + * @see ActivityMcpClient + * @see McpClientActivity + */ +public class McpToolCallback implements ToolCallback { + + private final ActivityMcpClient client; + private final String clientName; + private final McpSchema.Tool tool; + private final ToolDefinition toolDefinition; + + /** + * Creates a new McpToolCallback for a specific MCP tool. + * + * @param client the MCP client to use for tool calls + * @param clientName the name of the MCP client that provides this tool + * @param tool the tool definition + * @param toolNamePrefix the prefix to use for the tool name (usually the MCP server name) + */ + public McpToolCallback( + ActivityMcpClient client, String clientName, McpSchema.Tool tool, String toolNamePrefix) { + this.client = client; + this.clientName = clientName; + this.tool = tool; + + // Cache the tool definition at construction time to avoid activity calls in queries + String prefixedName = McpToolUtils.prefixedToolName(toolNamePrefix, tool.name()); + this.toolDefinition = + DefaultToolDefinition.builder() + .name(prefixedName) + .description(tool.description()) + .inputSchema(ModelOptionsUtils.toJsonString(tool.inputSchema())) + .build(); + } + + /** + * Creates ToolCallbacks for all tools from all MCP clients. + * + *

This method discovers all available tools from the MCP clients and wraps them as + * ToolCallbacks that execute through Temporal activities. + * + * @param client the MCP client + * @return list of ToolCallbacks for all discovered tools + */ + public static List fromMcpClient(ActivityMcpClient client) { + // Get client info upfront for tool name prefixes + Map clientInfo = client.getClientInfo(); + + Map toolsMap = client.listTools(); + return toolsMap.entrySet().stream() + .flatMap( + entry -> { + String clientName = entry.getKey(); + McpSchema.Implementation impl = clientInfo.get(clientName); + String prefix = impl != null ? impl.name() : clientName; + + return entry.getValue().tools().stream() + .map( + tool -> (ToolCallback) new McpToolCallback(client, clientName, tool, prefix)); + }) + .toList(); + } + + @Override + public ToolDefinition getToolDefinition() { + return toolDefinition; + } + + @Override + public String call(String toolInput) { + Map arguments = ModelOptionsUtils.jsonToMap(toolInput); + + // Use the original tool name (not prefixed) when calling the MCP server + McpSchema.CallToolRequest request = new McpSchema.CallToolRequest(tool.name(), arguments); + McpSchema.CallToolResult result = client.callTool(clientName, request); + + // Return the result as-is (including errors) so the AI can handle them. + // For example, an "access denied" error lets the AI suggest a valid path. + return ModelOptionsUtils.toJsonString(result.content()); + } + + /** + * Returns the name of the MCP client that provides this tool. + * + * @return the client name + */ + public String getClientName() { + return clientName; + } + + /** + * Returns the original tool definition from the MCP server. + * + * @return the tool definition + */ + public McpSchema.Tool getMcpTool() { + return tool; + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/model/ActivityChatModel.java b/temporal-spring-ai/src/main/java/io/temporal/springai/model/ActivityChatModel.java new file mode 100644 index 000000000..c446db1d8 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/model/ActivityChatModel.java @@ -0,0 +1,392 @@ +package io.temporal.springai.model; + +import io.temporal.activity.ActivityOptions; +import io.temporal.common.RetryOptions; +import io.temporal.springai.activity.ChatModelActivity; +import io.temporal.workflow.Workflow; +import java.net.URI; +import java.net.URISyntaxException; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.springframework.ai.chat.messages.*; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.content.Media; +import org.springframework.ai.model.tool.*; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.core.io.ByteArrayResource; +import org.springframework.util.CollectionUtils; +import org.springframework.util.MimeType; +import reactor.core.publisher.Flux; + +/** + * A {@link ChatModel} implementation that delegates to a Temporal activity. + * + *

This class enables Spring AI chat clients to be used within Temporal workflows. AI model calls + * are executed as activities, providing durability, automatic retries, and timeout handling. + * + *

Tool execution is handled locally in the workflow (not in the activity), allowing tools to be + * implemented as activities, local activities, or other Temporal primitives. + * + *

Usage

+ * + *

For a single chat model, use the constructor directly: + * + *

{@code
+ * @WorkflowInit
+ * public MyWorkflowImpl() {
+ *     ChatModelActivity chatModelActivity = Workflow.newActivityStub(
+ *         ChatModelActivity.class,
+ *         ActivityOptions.newBuilder()
+ *             .setStartToCloseTimeout(Duration.ofMinutes(2))
+ *             .build());
+ *
+ *     ActivityChatModel chatModel = new ActivityChatModel(chatModelActivity);
+ *     this.chatClient = ChatClient.builder(chatModel).build();
+ * }
+ * }
+ * + *

Multiple Chat Models

+ * + *

For applications with multiple chat models, use the static factory methods: + * + *

{@code
+ * @WorkflowInit
+ * public MyWorkflowImpl() {
+ *     // Use the default model (first or @Primary bean)
+ *     ActivityChatModel defaultModel = ActivityChatModel.forDefault();
+ *
+ *     // Use a specific model by bean name
+ *     ActivityChatModel openAiModel = ActivityChatModel.forModel("openAiChatModel");
+ *     ActivityChatModel anthropicModel = ActivityChatModel.forModel("anthropicChatModel");
+ *
+ *     // Use different models for different purposes
+ *     this.fastClient = TemporalChatClient.builder(openAiModel).build();
+ *     this.smartClient = TemporalChatClient.builder(anthropicModel).build();
+ * }
+ * }
+ * + * @see #forDefault() + * @see #forModel(String) + */ +public class ActivityChatModel implements ChatModel { + + /** Default timeout for chat model activity calls (2 minutes). */ + public static final Duration DEFAULT_TIMEOUT = Duration.ofMinutes(2); + + /** Default maximum retry attempts for chat model activity calls. */ + public static final int DEFAULT_MAX_ATTEMPTS = 3; + + private final ChatModelActivity chatModelActivity; + private final String modelName; + private final ToolCallingManager toolCallingManager; + private final ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate; + + /** + * Creates a new ActivityChatModel that uses the default chat model. + * + * @param chatModelActivity the activity stub for calling the chat model + */ + public ActivityChatModel(ChatModelActivity chatModelActivity) { + this(chatModelActivity, null); + } + + /** + * Creates a new ActivityChatModel that uses a specific chat model. + * + * @param chatModelActivity the activity stub for calling the chat model + * @param modelName the name of the chat model to use, or null for default + */ + public ActivityChatModel(ChatModelActivity chatModelActivity, String modelName) { + this.chatModelActivity = chatModelActivity; + this.modelName = modelName; + this.toolCallingManager = ToolCallingManager.builder().build(); + this.toolExecutionEligibilityPredicate = new DefaultToolExecutionEligibilityPredicate(); + } + + /** + * Creates an ActivityChatModel for the default chat model. + * + *

This factory method creates the activity stub internally with default timeout and retry + * options. + * + *

Must be called from workflow code. + * + * @return an ActivityChatModel for the default chat model + */ + public static ActivityChatModel forDefault() { + return forModel(null, DEFAULT_TIMEOUT, DEFAULT_MAX_ATTEMPTS); + } + + /** + * Creates an ActivityChatModel for a specific chat model by bean name. + * + *

This factory method creates the activity stub internally with default timeout and retry + * options. + * + *

Must be called from workflow code. + * + * @param modelName the bean name of the chat model + * @return an ActivityChatModel for the specified chat model + * @throws IllegalArgumentException if no model with that name exists (at activity runtime) + */ + public static ActivityChatModel forModel(String modelName) { + return forModel(modelName, DEFAULT_TIMEOUT, DEFAULT_MAX_ATTEMPTS); + } + + /** + * Creates an ActivityChatModel for a specific chat model with custom options. + * + *

Must be called from workflow code. + * + * @param modelName the bean name of the chat model, or null for default + * @param timeout the activity start-to-close timeout + * @param maxAttempts the maximum number of retry attempts + * @return an ActivityChatModel for the specified chat model + */ + public static ActivityChatModel forModel(String modelName, Duration timeout, int maxAttempts) { + ChatModelActivity activity = + Workflow.newActivityStub( + ChatModelActivity.class, + ActivityOptions.newBuilder() + .setStartToCloseTimeout(timeout) + .setRetryOptions(RetryOptions.newBuilder().setMaximumAttempts(maxAttempts).build()) + .build()); + return new ActivityChatModel(activity, modelName); + } + + /** + * Returns the name of the chat model this instance uses. + * + * @return the model name, or null if using the default model + */ + public String getModelName() { + return modelName; + } + + /** + * Streaming is not supported through Temporal activities. + * + * @throws UnsupportedOperationException always + */ + @Override + public Flux stream(Prompt prompt) { + throw new UnsupportedOperationException("Streaming is not supported in ActivityChatModel."); + } + + @Override + public ChatOptions getDefaultOptions() { + return ToolCallingChatOptions.builder().build(); + } + + @Override + public ChatResponse call(Prompt prompt) { + return internalCall(prompt); + } + + private ChatResponse internalCall(Prompt prompt) { + // Convert prompt to activity input and call the activity + ChatModelTypes.ChatModelActivityInput input = createActivityInput(prompt); + ChatModelTypes.ChatModelActivityOutput output = chatModelActivity.callChatModel(input); + + // Convert activity output to ChatResponse + ChatResponse response = toResponse(output); + + // Handle tool calls if the model requested them + if (prompt.getOptions() != null + && toolExecutionEligibilityPredicate.isToolExecutionRequired( + prompt.getOptions(), response)) { + var toolExecutionResult = toolCallingManager.executeToolCalls(prompt, response); + + if (toolExecutionResult.returnDirect()) { + return ChatResponse.builder() + .from(response) + .generations(ToolExecutionResult.buildGenerations(toolExecutionResult)) + .build(); + } + + // Send tool results back to the model + return internalCall( + new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions())); + } + + return response; + } + + private ChatModelTypes.ChatModelActivityInput createActivityInput(Prompt prompt) { + // Convert messages + List messages = + prompt.getInstructions().stream() + .flatMap(msg -> toActivityMessages(msg).stream()) + .collect(Collectors.toList()); + + // Convert options + ChatModelTypes.ModelOptions modelOptions = null; + if (prompt.getOptions() != null) { + ChatOptions opts = prompt.getOptions(); + modelOptions = + new ChatModelTypes.ModelOptions( + opts.getModel(), + opts.getFrequencyPenalty(), + opts.getMaxTokens(), + opts.getPresencePenalty(), + opts.getStopSequences(), + opts.getTemperature(), + opts.getTopK(), + opts.getTopP()); + } + + // Convert tool definitions + List tools = List.of(); + if (prompt.getOptions() instanceof ToolCallingChatOptions toolOptions) { + List toolDefinitions = toolCallingManager.resolveToolDefinitions(toolOptions); + if (!CollectionUtils.isEmpty(toolDefinitions)) { + tools = + toolDefinitions.stream() + .map( + td -> + new ChatModelTypes.FunctionTool( + new ChatModelTypes.FunctionTool.Function( + td.name(), td.description(), td.inputSchema()))) + .collect(Collectors.toList()); + } + } + + return new ChatModelTypes.ChatModelActivityInput(modelName, messages, modelOptions, tools); + } + + private List toActivityMessages(Message message) { + return switch (message.getMessageType()) { + case SYSTEM -> + List.of( + new ChatModelTypes.Message(message.getText(), ChatModelTypes.Message.Role.SYSTEM)); + case USER -> { + List mediaContents = null; + if (message instanceof UserMessage userMessage + && !CollectionUtils.isEmpty(userMessage.getMedia())) { + mediaContents = + userMessage.getMedia().stream() + .map(this::toMediaContent) + .collect(Collectors.toList()); + } + yield List.of( + new ChatModelTypes.Message( + message.getText(), mediaContents, ChatModelTypes.Message.Role.USER)); + } + case ASSISTANT -> { + AssistantMessage assistantMessage = (AssistantMessage) message; + List toolCalls = null; + if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) { + toolCalls = + assistantMessage.getToolCalls().stream() + .map( + tc -> + new ChatModelTypes.Message.ToolCall( + tc.id(), + tc.type(), + new ChatModelTypes.Message.ChatCompletionFunction( + tc.name(), tc.arguments()))) + .collect(Collectors.toList()); + } + List mediaContents = null; + if (!CollectionUtils.isEmpty(assistantMessage.getMedia())) { + mediaContents = + assistantMessage.getMedia().stream() + .map(this::toMediaContent) + .collect(Collectors.toList()); + } + yield List.of( + new ChatModelTypes.Message( + assistantMessage.getText(), + ChatModelTypes.Message.Role.ASSISTANT, + null, + null, + toolCalls, + mediaContents)); + } + case TOOL -> { + ToolResponseMessage toolMessage = (ToolResponseMessage) message; + yield toolMessage.getResponses().stream() + .map( + tr -> + new ChatModelTypes.Message( + tr.responseData(), + ChatModelTypes.Message.Role.TOOL, + tr.name(), + tr.id(), + null, + null)) + .collect(Collectors.toList()); + } + }; + } + + private ChatModelTypes.MediaContent toMediaContent(Media media) { + String mimeType = media.getMimeType().toString(); + if (media.getData() instanceof String uri) { + return new ChatModelTypes.MediaContent(mimeType, uri); + } else if (media.getData() instanceof byte[] data) { + return new ChatModelTypes.MediaContent(mimeType, data); + } + throw new IllegalArgumentException( + "Unsupported media data type: " + media.getData().getClass()); + } + + private ChatResponse toResponse(ChatModelTypes.ChatModelActivityOutput output) { + List generations = + output.generations().stream() + .map(gen -> new Generation(toAssistantMessage(gen.message()))) + .collect(Collectors.toList()); + + var builder = ChatResponse.builder().generations(generations); + if (output.metadata() != null) { + builder.metadata(ChatResponseMetadata.builder().model(output.metadata().model()).build()); + } + return builder.build(); + } + + private AssistantMessage toAssistantMessage(ChatModelTypes.Message message) { + List toolCalls = List.of(); + if (!CollectionUtils.isEmpty(message.toolCalls())) { + toolCalls = + message.toolCalls().stream() + .map( + tc -> + new AssistantMessage.ToolCall( + tc.id(), tc.type(), tc.function().name(), tc.function().arguments())) + .collect(Collectors.toList()); + } + + List media = List.of(); + if (!CollectionUtils.isEmpty(message.mediaContents())) { + media = message.mediaContents().stream().map(this::toMedia).collect(Collectors.toList()); + } + + return AssistantMessage.builder() + .content(message.rawContent()) + .properties(Map.of()) + .toolCalls(toolCalls) + .media(media) + .build(); + } + + private Media toMedia(ChatModelTypes.MediaContent mediaContent) { + MimeType mimeType = MimeType.valueOf(mediaContent.mimeType()); + if (mediaContent.uri() != null) { + try { + return new Media(mimeType, new URI(mediaContent.uri())); + } catch (URISyntaxException e) { + throw new RuntimeException("Invalid media URI: " + mediaContent.uri(), e); + } + } else if (mediaContent.data() != null) { + return new Media(mimeType, new ByteArrayResource(mediaContent.data())); + } + throw new IllegalArgumentException("Media content must have either uri or data"); + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/model/ChatModelTypes.java b/temporal-spring-ai/src/main/java/io/temporal/springai/model/ChatModelTypes.java new file mode 100644 index 000000000..7dca5e316 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/model/ChatModelTypes.java @@ -0,0 +1,192 @@ +package io.temporal.springai.model; + +import com.fasterxml.jackson.annotation.JsonFormat; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import java.time.Duration; +import java.util.List; + +/** + * Serializable types for chat model activity requests and responses. + * + *

These records are designed to be serialized by Temporal's data converter and passed between + * workflows and activities. + */ +public final class ChatModelTypes { + + private ChatModelTypes() {} + + /** + * Input to the chat model activity. + * + * @param modelName the name of the chat model bean to use (null for default) + * @param messages the conversation messages + * @param modelOptions options for the chat model (temperature, max tokens, etc.) + * @param tools tool definitions the model may call + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ChatModelActivityInput( + @JsonProperty("model_name") String modelName, + @JsonProperty("messages") List messages, + @JsonProperty("model_options") ModelOptions modelOptions, + @JsonProperty("tools") List tools) { + /** Creates input for the default chat model. */ + public ChatModelActivityInput( + List messages, ModelOptions modelOptions, List tools) { + this(null, messages, modelOptions, tools); + } + } + + /** + * Output from the chat model activity. + * + * @param generations the generated responses + * @param metadata response metadata (model, usage, rate limits) + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ChatModelActivityOutput( + @JsonProperty("generations") List generations, + @JsonProperty("metadata") ChatResponseMetadata metadata) { + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record Generation(@JsonProperty("message") Message message) {} + + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ChatResponseMetadata( + @JsonProperty("model") String model, + @JsonProperty("rate_limit") RateLimit rateLimit, + @JsonProperty("usage") Usage usage) { + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record RateLimit( + @JsonProperty("request_limit") Long requestLimit, + @JsonProperty("request_remaining") Long requestRemaining, + @JsonProperty("request_reset") Duration requestReset, + @JsonProperty("token_limit") Long tokenLimit, + @JsonProperty("token_remaining") Long tokenRemaining, + @JsonProperty("token_reset") Duration tokenReset) {} + + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record Usage( + @JsonProperty("prompt_tokens") Integer promptTokens, + @JsonProperty("completion_tokens") Integer completionTokens, + @JsonProperty("total_tokens") Integer totalTokens) {} + } + } + + /** + * A message in the conversation. + * + * @param rawContent the message text content + * @param role the role of the message author + * @param name optional name for the participant + * @param toolCallId tool call ID this message responds to (for TOOL role) + * @param toolCalls tool calls requested by the model (for ASSISTANT role) + * @param mediaContents optional media attachments + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record Message( + @JsonProperty("content") String rawContent, + @JsonProperty("role") Role role, + @JsonProperty("name") String name, + @JsonProperty("tool_call_id") String toolCallId, + @JsonProperty("tool_calls") + @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) + List toolCalls, + @JsonProperty("media") @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) + List mediaContents) { + public Message(String content, Role role) { + this(content, role, null, null, null, null); + } + + public Message(String content, List mediaContents, Role role) { + this(content, role, null, null, null, mediaContents); + } + + public enum Role { + @JsonProperty("system") + SYSTEM, + @JsonProperty("user") + USER, + @JsonProperty("assistant") + ASSISTANT, + @JsonProperty("tool") + TOOL + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ToolCall( + @JsonProperty("index") Integer index, + @JsonProperty("id") String id, + @JsonProperty("type") String type, + @JsonProperty("function") ChatCompletionFunction function) { + public ToolCall(String id, String type, ChatCompletionFunction function) { + this(null, id, type, function); + } + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ChatCompletionFunction( + @JsonProperty("name") String name, @JsonProperty("arguments") String arguments) {} + } + + /** + * Media content within a message. + * + * @param mimeType the MIME type (e.g., "image/png") + * @param uri optional URI to the content + * @param data optional raw data bytes + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record MediaContent( + @JsonProperty("mime_type") String mimeType, + @JsonProperty("uri") String uri, + @JsonProperty("data") byte[] data) { + public MediaContent(String mimeType, String uri) { + this(mimeType, uri, null); + } + + public MediaContent(String mimeType, byte[] data) { + this(mimeType, null, data); + } + } + + /** A tool the model may call. */ + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record FunctionTool( + @JsonProperty("type") String type, @JsonProperty("function") Function function) { + public FunctionTool(Function function) { + this("function", function); + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record Function( + @JsonProperty("name") String name, + @JsonProperty("description") String description, + @JsonProperty("json_schema") String jsonSchema) {} + } + + /** Model options for the chat request. */ + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ModelOptions( + @JsonProperty("model") String model, + @JsonProperty("frequency_penalty") Double frequencyPenalty, + @JsonProperty("max_tokens") Integer maxTokens, + @JsonProperty("presence_penalty") Double presencePenalty, + @JsonProperty("stop_sequences") List stopSequences, + @JsonProperty("temperature") Double temperature, + @JsonProperty("top_k") Integer topK, + @JsonProperty("top_p") Double topP) {} +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/model/EmbeddingModelTypes.java b/temporal-spring-ai/src/main/java/io/temporal/springai/model/EmbeddingModelTypes.java new file mode 100644 index 000000000..4691bc0be --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/model/EmbeddingModelTypes.java @@ -0,0 +1,67 @@ +package io.temporal.springai.model; + +import java.util.List; + +/** + * Serializable types for EmbeddingModel activity communication. + * + *

These records are used to pass data between workflows and the EmbeddingModelActivity, ensuring + * all data can be serialized by Temporal's data converter. + */ +public final class EmbeddingModelTypes { + + private EmbeddingModelTypes() {} + + /** + * Input for embedding a single text. + * + * @param text the text to embed + */ + public record EmbedTextInput(String text) {} + + /** + * Input for embedding multiple texts. + * + * @param texts the texts to embed + */ + public record EmbedBatchInput(List texts) {} + + /** + * Output containing a single embedding vector. + * + * @param embedding the embedding vector + */ + public record EmbedOutput(float[] embedding) {} + + /** + * Output containing multiple embedding vectors. + * + * @param embeddings the embedding vectors, one per input text + * @param metadata additional metadata about the embeddings + */ + public record EmbedBatchOutput(List embeddings, EmbeddingMetadata metadata) {} + + /** + * A single embedding result. + * + * @param index the index in the original input list + * @param embedding the embedding vector + */ + public record EmbeddingResult(int index, float[] embedding) {} + + /** + * Metadata about the embedding operation. + * + * @param model the model used for embedding + * @param totalTokens total tokens processed + * @param dimensions the dimensionality of the embeddings + */ + public record EmbeddingMetadata(String model, Integer totalTokens, Integer dimensions) {} + + /** + * Output containing embedding model dimensions. + * + * @param dimensions the number of dimensions in the embedding vectors + */ + public record DimensionsOutput(int dimensions) {} +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/model/VectorStoreTypes.java b/temporal-spring-ai/src/main/java/io/temporal/springai/model/VectorStoreTypes.java new file mode 100644 index 000000000..0eadd932e --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/model/VectorStoreTypes.java @@ -0,0 +1,82 @@ +package io.temporal.springai.model; + +import java.util.List; +import java.util.Map; + +/** + * Serializable types for VectorStore activity communication. + * + *

These records are used to pass data between workflows and the VectorStoreActivity, ensuring + * all data can be serialized by Temporal's data converter. + */ +public final class VectorStoreTypes { + + private VectorStoreTypes() {} + + /** + * Serializable representation of a document for vector storage. + * + * @param id unique identifier for the document + * @param text the text content of the document + * @param metadata additional metadata associated with the document + * @param embedding pre-computed embedding vector (optional, may be computed by the store) + */ + public record Document( + String id, String text, Map metadata, List embedding) { + public Document(String id, String text, Map metadata) { + this(id, text, metadata, null); + } + + public Document(String id, String text) { + this(id, text, Map.of(), null); + } + } + + /** + * Input for adding documents to the vector store. + * + * @param documents the documents to add + */ + public record AddDocumentsInput(List documents) {} + + /** + * Input for deleting documents by ID. + * + * @param ids the document IDs to delete + */ + public record DeleteByIdsInput(List ids) {} + + /** + * Input for similarity search. + * + * @param query the search query text + * @param topK maximum number of results to return + * @param similarityThreshold minimum similarity score (0.0 to 1.0) + * @param filterExpression optional filter expression for metadata filtering + */ + public record SearchInput( + String query, int topK, Double similarityThreshold, String filterExpression) { + public SearchInput(String query, int topK) { + this(query, topK, null, null); + } + + public SearchInput(String query) { + this(query, 4, null, null); + } + } + + /** + * Output from similarity search. + * + * @param documents the matching documents with their similarity scores + */ + public record SearchOutput(List documents) {} + + /** + * A single search result with similarity score. + * + * @param document the matched document + * @param score the similarity score + */ + public record SearchResult(Document document, Double score) {} +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/plugin/EmbeddingModelPlugin.java b/temporal-spring-ai/src/main/java/io/temporal/springai/plugin/EmbeddingModelPlugin.java new file mode 100644 index 000000000..c84bf2a71 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/plugin/EmbeddingModelPlugin.java @@ -0,0 +1,22 @@ +package io.temporal.springai.plugin; + +import io.temporal.common.SimplePlugin; +import io.temporal.springai.activity.EmbeddingModelActivityImpl; +import org.springframework.ai.embedding.EmbeddingModel; + +/** + * Temporal plugin that registers {@link io.temporal.springai.activity.EmbeddingModelActivity} with + * workers. + * + *

This plugin is conditionally created by auto-configuration when Spring AI's {@link + * EmbeddingModel} is on the classpath and an EmbeddingModel bean is available. It can also be + * created manually for non-auto-configured setups. + */ +public class EmbeddingModelPlugin extends SimplePlugin { + + public EmbeddingModelPlugin(EmbeddingModel embeddingModel) { + super( + SimplePlugin.newBuilder("io.temporal.spring-ai-embedding") + .registerActivitiesImplementations(new EmbeddingModelActivityImpl(embeddingModel))); + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/plugin/McpPlugin.java b/temporal-spring-ai/src/main/java/io/temporal/springai/plugin/McpPlugin.java new file mode 100644 index 000000000..2f3635cfd --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/plugin/McpPlugin.java @@ -0,0 +1,95 @@ +package io.temporal.springai.plugin; + +import io.modelcontextprotocol.client.McpSyncClient; +import io.temporal.common.SimplePlugin; +import io.temporal.springai.mcp.McpClientActivityImpl; +import io.temporal.worker.Worker; +import java.util.ArrayList; +import java.util.List; +import javax.annotation.Nonnull; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.SmartInitializingSingleton; +import org.springframework.context.ApplicationContext; +import org.springframework.context.ApplicationContextAware; + +/** + * Temporal plugin that registers {@link io.temporal.springai.mcp.McpClientActivity} with workers. + * + *

This plugin is conditionally created by auto-configuration when MCP classes are on the + * classpath. MCP clients may be created late by Spring AI's auto-configuration, so this plugin + * supports deferred registration via {@link SmartInitializingSingleton}. + */ +public class McpPlugin extends SimplePlugin + implements ApplicationContextAware, SmartInitializingSingleton { + + private static final Logger log = LoggerFactory.getLogger(McpPlugin.class); + + private List mcpClients = List.of(); + private ApplicationContext applicationContext; + private final List pendingWorkers = new ArrayList<>(); + + public McpPlugin() { + super("io.temporal.spring-ai-mcp"); + } + + @Override + public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { + this.applicationContext = applicationContext; + } + + @SuppressWarnings("unchecked") + private List getMcpClients() { + if (!mcpClients.isEmpty()) { + return mcpClients; + } + + if (applicationContext != null && applicationContext.containsBean("mcpSyncClients")) { + try { + Object bean = applicationContext.getBean("mcpSyncClients"); + if (bean instanceof List clientList && !clientList.isEmpty()) { + mcpClients = (List) clientList; + log.info("Found {} MCP client(s) in ApplicationContext", mcpClients.size()); + } + } catch (Exception e) { + log.debug("Failed to get mcpSyncClients bean: {}", e.getMessage()); + } + } + + return mcpClients; + } + + @Override + public void initializeWorker(@Nonnull String taskQueue, @Nonnull Worker worker) { + List clients = getMcpClients(); + if (!clients.isEmpty()) { + worker.registerActivitiesImplementations(new McpClientActivityImpl(clients)); + log.info( + "Registered McpClientActivity ({} clients) for task queue {}", clients.size(), taskQueue); + } else { + pendingWorkers.add(worker); + log.debug("MCP clients not yet available; will attempt registration after initialization"); + } + } + + @Override + public void afterSingletonsInstantiated() { + if (pendingWorkers.isEmpty()) { + return; + } + + List clients = getMcpClients(); + if (clients.isEmpty()) { + log.debug("No MCP clients found after all beans initialized"); + pendingWorkers.clear(); + return; + } + + for (Worker worker : pendingWorkers) { + worker.registerActivitiesImplementations(new McpClientActivityImpl(clients)); + log.info("Registered deferred McpClientActivity ({} clients)", clients.size()); + } + pendingWorkers.clear(); + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/plugin/SpringAiPlugin.java b/temporal-spring-ai/src/main/java/io/temporal/springai/plugin/SpringAiPlugin.java new file mode 100644 index 000000000..0438ff558 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/plugin/SpringAiPlugin.java @@ -0,0 +1,151 @@ +package io.temporal.springai.plugin; + +import io.temporal.common.SimplePlugin; +import io.temporal.springai.activity.ChatModelActivityImpl; +import io.temporal.worker.Worker; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; +import javax.annotation.Nonnull; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.lang.Nullable; + +/** + * Core Temporal plugin that registers {@link io.temporal.springai.activity.ChatModelActivity} with + * Temporal workers. + * + *

This plugin handles the required ChatModel integration. Optional integrations (VectorStore, + * EmbeddingModel, MCP) are handled by separate plugins that are conditionally created by + * auto-configuration: + * + *

+ * + *

In Workflows

+ * + *
{@code
+ * @WorkflowInit
+ * public MyWorkflowImpl() {
+ *     ActivityChatModel chatModel = ActivityChatModel.forDefault();
+ *     this.chatClient = TemporalChatClient.builder(chatModel).build();
+ * }
+ * }
+ * + * @see io.temporal.springai.activity.ChatModelActivity + * @see io.temporal.springai.model.ActivityChatModel + */ +public class SpringAiPlugin extends SimplePlugin { + + private static final Logger log = LoggerFactory.getLogger(SpringAiPlugin.class); + + /** The name used for the default chat model when none is specified. */ + public static final String DEFAULT_MODEL_NAME = "default"; + + private final Map chatModels; + private final String defaultModelName; + + /** + * Creates a new SpringAiPlugin with the given ChatModel. + * + * @param chatModel the Spring AI chat model to wrap as an activity + */ + public SpringAiPlugin(ChatModel chatModel) { + super("io.temporal.spring-ai"); + this.chatModels = Map.of(DEFAULT_MODEL_NAME, chatModel); + this.defaultModelName = DEFAULT_MODEL_NAME; + } + + /** + * Creates a new SpringAiPlugin with multiple ChatModels. + * + * @param chatModels map of bean names to ChatModel instances + * @param primaryChatModel the primary chat model (used to determine default), or null + */ + public SpringAiPlugin(Map chatModels, @Nullable ChatModel primaryChatModel) { + super("io.temporal.spring-ai"); + + if (chatModels == null || chatModels.isEmpty()) { + throw new IllegalArgumentException("At least one ChatModel bean is required"); + } + + this.chatModels = new LinkedHashMap<>(chatModels); + + if (primaryChatModel != null) { + String primaryName = + chatModels.entrySet().stream() + .filter(e -> e.getValue() == primaryChatModel) + .map(Map.Entry::getKey) + .findFirst() + .orElse(chatModels.keySet().iterator().next()); + this.defaultModelName = primaryName; + } else { + this.defaultModelName = chatModels.keySet().iterator().next(); + } + + if (chatModels.size() > 1) { + log.info( + "Registered {} chat models: {} (default: {})", + chatModels.size(), + chatModels.keySet(), + defaultModelName); + } + } + + @Override + public void initializeWorker(@Nonnull String taskQueue, @Nonnull Worker worker) { + ChatModelActivityImpl chatModelActivityImpl = + new ChatModelActivityImpl(chatModels, defaultModelName); + worker.registerActivitiesImplementations(chatModelActivityImpl); + + String modelInfo = chatModels.size() > 1 ? " (" + chatModels.size() + " models)" : ""; + log.info("Registered ChatModelActivity{} for task queue {}", modelInfo, taskQueue); + } + + /** + * Returns the default ChatModel wrapped by this plugin. + * + * @return the default chat model + */ + public ChatModel getChatModel() { + return chatModels.get(defaultModelName); + } + + /** + * Returns a specific ChatModel by bean name. + * + * @param modelName the bean name of the chat model + * @return the chat model + * @throws IllegalArgumentException if no model with that name exists + */ + public ChatModel getChatModel(String modelName) { + ChatModel model = chatModels.get(modelName); + if (model == null) { + throw new IllegalArgumentException( + "No chat model with name '" + modelName + "'. Available models: " + chatModels.keySet()); + } + return model; + } + + /** + * Returns all ChatModels wrapped by this plugin, keyed by bean name. + * + * @return unmodifiable map of chat models + */ + public Map getChatModels() { + return Collections.unmodifiableMap(chatModels); + } + + /** + * Returns the name of the default chat model. + * + * @return the default model name + */ + public String getDefaultModelName() { + return defaultModelName; + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/plugin/VectorStorePlugin.java b/temporal-spring-ai/src/main/java/io/temporal/springai/plugin/VectorStorePlugin.java new file mode 100644 index 000000000..a1d9d5b28 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/plugin/VectorStorePlugin.java @@ -0,0 +1,22 @@ +package io.temporal.springai.plugin; + +import io.temporal.common.SimplePlugin; +import io.temporal.springai.activity.VectorStoreActivityImpl; +import org.springframework.ai.vectorstore.VectorStore; + +/** + * Temporal plugin that registers {@link io.temporal.springai.activity.VectorStoreActivity} with + * workers. + * + *

This plugin is conditionally created by auto-configuration when Spring AI's {@link + * VectorStore} is on the classpath and a VectorStore bean is available. It can also be created + * manually for non-auto-configured setups. + */ +public class VectorStorePlugin extends SimplePlugin { + + public VectorStorePlugin(VectorStore vectorStore) { + super( + SimplePlugin.newBuilder("io.temporal.spring-ai-vectorstore") + .registerActivitiesImplementations(new VectorStoreActivityImpl(vectorStore))); + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/tool/ActivityToolCallback.java b/temporal-spring-ai/src/main/java/io/temporal/springai/tool/ActivityToolCallback.java new file mode 100644 index 000000000..6f2dfe21b --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/tool/ActivityToolCallback.java @@ -0,0 +1,61 @@ +package io.temporal.springai.tool; + +import org.springframework.ai.chat.model.ToolContext; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.metadata.ToolMetadata; + +/** + * A wrapper for {@link ToolCallback} that indicates the underlying tool is backed by a Temporal + * activity stub. + * + *

This wrapper delegates all operations to the underlying callback while serving as a marker to + * indicate that tool invocations will execute as Temporal activities, providing durability, + * automatic retries, and timeout handling. + * + *

This class is primarily used internally by {@link ActivityToolUtil} when converting activity + * stubs to tool callbacks. Users typically don't need to create instances directly. + * + * @see ActivityToolUtil#fromActivityStub(Object...) + */ +public class ActivityToolCallback implements ToolCallback { + private final ToolCallback delegate; + + /** + * Creates a new ActivityToolCallback wrapping the given callback. + * + * @param delegate the underlying tool callback to wrap + */ + public ActivityToolCallback(ToolCallback delegate) { + this.delegate = delegate; + } + + @Override + public ToolDefinition getToolDefinition() { + return delegate.getToolDefinition(); + } + + @Override + public ToolMetadata getToolMetadata() { + return delegate.getToolMetadata(); + } + + @Override + public String call(String toolInput) { + return delegate.call(toolInput); + } + + @Override + public String call(String toolInput, ToolContext toolContext) { + return delegate.call(toolInput, toolContext); + } + + /** + * Returns the underlying delegate callback. + * + * @return the wrapped callback + */ + public ToolCallback getDelegate() { + return delegate; + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/tool/ActivityToolUtil.java b/temporal-spring-ai/src/main/java/io/temporal/springai/tool/ActivityToolUtil.java new file mode 100644 index 000000000..e168bcd86 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/tool/ActivityToolUtil.java @@ -0,0 +1,135 @@ +package io.temporal.springai.tool; + +import io.temporal.activity.ActivityInterface; +import io.temporal.common.metadata.POJOActivityInterfaceMetadata; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.annotation.Tool; +import org.springframework.ai.tool.metadata.ToolMetadata; +import org.springframework.ai.tool.method.MethodToolCallback; +import org.springframework.ai.tool.support.ToolDefinitions; +import org.springframework.ai.tool.support.ToolUtils; +import org.springframework.util.ClassUtils; +import org.springframework.util.ReflectionUtils; + +/** + * Utility class for extracting tool definitions from Temporal activity interfaces. + * + *

This class bridges Spring AI's {@link Tool} annotation with Temporal's {@link + * ActivityInterface} annotation, allowing activity methods to be used as AI tools within workflows. + * + *

Example: + * + *

{@code
+ * @ActivityInterface
+ * public interface WeatherActivity {
+ *     @Tool(description = "Get the current weather for a city")
+ *     String getWeather(String city);
+ * }
+ *
+ * // In workflow:
+ * WeatherActivity weatherTool = Workflow.newActivityStub(WeatherActivity.class, opts);
+ * ToolCallback[] callbacks = ActivityToolUtil.fromActivityStub(weatherTool);
+ * }
+ */ +public final class ActivityToolUtil { + + private ActivityToolUtil() { + // Utility class + } + + /** + * Extracts {@link Tool} annotations from the given activity stub object. + * + *

Scans all interfaces implemented by the stub that are annotated with {@link + * ActivityInterface}, and returns a map of activity type names to their {@link Tool} annotations. + * + * @param activityStub the activity stub to extract annotations from + * @return a map of activity type names to Tool annotations + */ + public static Map getToolAnnotations(Object activityStub) { + return Stream.of(activityStub.getClass().getInterfaces()) + .filter(iface -> iface.isAnnotationPresent(ActivityInterface.class)) + .map(POJOActivityInterfaceMetadata::newInstance) + .flatMap(metadata -> metadata.getMethodsMetadata().stream()) + .filter(methodMetadata -> methodMetadata.getMethod().isAnnotationPresent(Tool.class)) + .collect( + Collectors.toMap( + methodMetadata -> methodMetadata.getActivityTypeName(), + methodMetadata -> methodMetadata.getMethod().getAnnotation(Tool.class))); + } + + /** + * Creates {@link ToolCallback} instances from activity stub objects. + * + *

For each activity stub, this method: + * + *

    + *
  1. Finds all interfaces annotated with {@link ActivityInterface} + *
  2. Extracts methods annotated with {@link Tool} + *
  3. Creates {@link MethodToolCallback} instances for each method + *
  4. Wraps them in {@link ActivityToolCallback} to mark their origin + *
+ * + *

Methods that return functional types (Function, Supplier, Consumer) are excluded as they are + * not supported as tools. + * + * @param toolObjects the activity stub objects to convert + * @return an array of ToolCallback instances + */ + public static ToolCallback[] fromActivityStub(Object... toolObjects) { + List callbacks = new ArrayList<>(); + + for (Object toolObject : toolObjects) { + Stream.of(toolObject.getClass().getInterfaces()) + .filter(iface -> iface.isAnnotationPresent(ActivityInterface.class)) + .flatMap(iface -> Stream.of(ReflectionUtils.getDeclaredMethods(iface))) + .filter(method -> method.isAnnotationPresent(Tool.class)) + .filter(method -> !isFunctionalType(method)) + .map(method -> createToolCallback(method, toolObject)) + .map(ActivityToolCallback::new) + .forEach(callbacks::add); + } + + return callbacks.toArray(new ToolCallback[0]); + } + + /** + * Checks if any interfaces implemented by the object are annotated with {@link ActivityInterface} + * and contain methods annotated with {@link Tool}. + * + * @param object the object to check + * @return true if the object has tool-annotated activity methods + */ + public static boolean hasToolAnnotations(Object object) { + return Stream.of(object.getClass().getInterfaces()) + .filter(iface -> iface.isAnnotationPresent(ActivityInterface.class)) + .flatMap(iface -> Stream.of(ReflectionUtils.getDeclaredMethods(iface))) + .anyMatch(method -> method.isAnnotationPresent(Tool.class)); + } + + private static MethodToolCallback createToolCallback(Method method, Object toolObject) { + return MethodToolCallback.builder() + .toolDefinition(ToolDefinitions.from(method)) + .toolMetadata(ToolMetadata.from(method)) + .toolMethod(method) + .toolObject(toolObject) + .toolCallResultConverter(ToolUtils.getToolCallResultConverter(method)) + .build(); + } + + private static boolean isFunctionalType(Method method) { + Class returnType = method.getReturnType(); + return ClassUtils.isAssignable(returnType, Function.class) + || ClassUtils.isAssignable(returnType, Supplier.class) + || ClassUtils.isAssignable(returnType, Consumer.class); + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/tool/NexusToolCallback.java b/temporal-spring-ai/src/main/java/io/temporal/springai/tool/NexusToolCallback.java new file mode 100644 index 000000000..a010dcd2d --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/tool/NexusToolCallback.java @@ -0,0 +1,61 @@ +package io.temporal.springai.tool; + +import org.springframework.ai.chat.model.ToolContext; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.metadata.ToolMetadata; + +/** + * A wrapper for {@link ToolCallback} that indicates the underlying tool is backed by a Temporal + * Nexus service stub. + * + *

This wrapper delegates all operations to the underlying callback while serving as a marker to + * indicate that tool invocations will execute as Nexus operations, providing cross-namespace + * communication and durability. + * + *

This class is primarily used internally by {@link NexusToolUtil} when converting Nexus service + * stubs to tool callbacks. Users typically don't need to create instances directly. + * + * @see NexusToolUtil#fromNexusServiceStub(Object...) + */ +public class NexusToolCallback implements ToolCallback { + private final ToolCallback delegate; + + /** + * Creates a new NexusToolCallback wrapping the given callback. + * + * @param delegate the underlying tool callback to wrap + */ + public NexusToolCallback(ToolCallback delegate) { + this.delegate = delegate; + } + + @Override + public ToolDefinition getToolDefinition() { + return delegate.getToolDefinition(); + } + + @Override + public ToolMetadata getToolMetadata() { + return delegate.getToolMetadata(); + } + + @Override + public String call(String toolInput) { + return delegate.call(toolInput); + } + + @Override + public String call(String toolInput, ToolContext toolContext) { + return delegate.call(toolInput, toolContext); + } + + /** + * Returns the underlying delegate callback. + * + * @return the wrapped callback + */ + public ToolCallback getDelegate() { + return delegate; + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/tool/NexusToolUtil.java b/temporal-spring-ai/src/main/java/io/temporal/springai/tool/NexusToolUtil.java new file mode 100644 index 000000000..b2aa4a6a2 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/tool/NexusToolUtil.java @@ -0,0 +1,111 @@ +package io.temporal.springai.tool; + +import io.nexusrpc.Service; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Stream; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.annotation.Tool; +import org.springframework.ai.tool.metadata.ToolMetadata; +import org.springframework.ai.tool.method.MethodToolCallback; +import org.springframework.ai.tool.support.ToolDefinitions; +import org.springframework.ai.tool.support.ToolUtils; +import org.springframework.util.ClassUtils; +import org.springframework.util.ReflectionUtils; + +/** + * Utility class for extracting tool definitions from Temporal Nexus service interfaces. + * + *

This class bridges Spring AI's {@link Tool} annotation with Nexus RPC's {@link Service} + * annotation, allowing Nexus service methods to be used as AI tools within workflows. + * + *

Example: + * + *

{@code
+ * @Service
+ * public interface WeatherService {
+ *     @Tool(description = "Get the current weather for a city")
+ *     String getWeather(String city);
+ * }
+ *
+ * // In workflow:
+ * WeatherService weatherTool = Workflow.newNexusServiceStub(WeatherService.class, opts);
+ * ToolCallback[] callbacks = NexusToolUtil.fromNexusServiceStub(weatherTool);
+ * }
+ */ +public final class NexusToolUtil { + + private NexusToolUtil() { + // Utility class + } + + /** + * Creates {@link ToolCallback} instances from Nexus service stub objects. + * + *

For each Nexus service stub, this method: + * + *

    + *
  1. Finds all interfaces annotated with {@link Service} + *
  2. Extracts methods annotated with {@link Tool} + *
  3. Creates {@link MethodToolCallback} instances for each method + *
  4. Wraps them in {@link NexusToolCallback} to mark their origin + *
+ * + *

Methods that return functional types (Function, Supplier, Consumer) are excluded as they are + * not supported as tools. + * + * @param toolObjects the Nexus service stub objects to convert + * @return an array of ToolCallback instances + */ + public static ToolCallback[] fromNexusServiceStub(Object... toolObjects) { + List callbacks = new ArrayList<>(); + + for (Object toolObject : toolObjects) { + Stream.of(toolObject.getClass().getInterfaces()) + .filter(iface -> iface.isAnnotationPresent(Service.class)) + .flatMap(iface -> Stream.of(ReflectionUtils.getDeclaredMethods(iface))) + .filter(method -> method.isAnnotationPresent(Tool.class)) + .filter(method -> !isFunctionalType(method)) + .map(method -> createToolCallback(method, toolObject)) + .map(NexusToolCallback::new) + .forEach(callbacks::add); + } + + return callbacks.toArray(new ToolCallback[0]); + } + + /** + * Checks if any interfaces implemented by the object are annotated with {@link Service} and + * contain methods annotated with {@link Tool}. + * + * @param object the object to check + * @return true if the object has tool-annotated Nexus service methods + */ + public static boolean hasToolAnnotations(Object object) { + return Stream.of(object.getClass().getInterfaces()) + .filter(iface -> iface.isAnnotationPresent(Service.class)) + .flatMap(iface -> Stream.of(ReflectionUtils.getDeclaredMethods(iface))) + .anyMatch(method -> method.isAnnotationPresent(Tool.class)); + } + + private static MethodToolCallback createToolCallback(Method method, Object toolObject) { + return MethodToolCallback.builder() + .toolDefinition(ToolDefinitions.from(method)) + .toolMetadata(ToolMetadata.from(method)) + .toolMethod(method) + .toolObject(toolObject) + .toolCallResultConverter(ToolUtils.getToolCallResultConverter(method)) + .build(); + } + + private static boolean isFunctionalType(Method method) { + Class returnType = method.getReturnType(); + return ClassUtils.isAssignable(returnType, Function.class) + || ClassUtils.isAssignable(returnType, Supplier.class) + || ClassUtils.isAssignable(returnType, Consumer.class); + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/tool/SideEffectTool.java b/temporal-spring-ai/src/main/java/io/temporal/springai/tool/SideEffectTool.java new file mode 100644 index 000000000..b3520534f --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/tool/SideEffectTool.java @@ -0,0 +1,57 @@ +package io.temporal.springai.tool; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Marks a tool class as a side-effect tool, meaning its methods will be wrapped in {@code + * Workflow.sideEffect()} for safe execution in a Temporal workflow. + * + *

Side-effect tools are useful for operations that: + * + *

    + *
  • Are non-deterministic (e.g., reading current time, generating UUIDs) + *
  • Are cheap and don't need the full durability of an activity + *
  • Don't have external side effects that need to be retried on failure + *
+ * + *

The result of a side-effect tool is recorded in the workflow history, so on replay the same + * result is returned without re-executing the tool. + * + *

Example usage: + * + *

{@code
+ * @SideEffectTool
+ * public class TimestampTools {
+ *     @Tool(description = "Get the current timestamp")
+ *     public long currentTimeMillis() {
+ *         return System.currentTimeMillis();  // Non-deterministic, but recorded
+ *     }
+ *
+ *     @Tool(description = "Generate a random UUID")
+ *     public String randomUuid() {
+ *         return UUID.randomUUID().toString();
+ *     }
+ * }
+ *
+ * // In workflow:
+ * this.chatClient = TemporalChatClient.builder(activityChatModel)
+ *         .defaultTools(new TimestampTools())  // Wrapped in sideEffect()
+ *         .build();
+ * }
+ * + *

When to use which annotation: + * + *

    + *
  • {@code @SideEffectTool} - Non-deterministic but cheap operations (timestamps, random + * values) + *
  • Activity stub - Operations with external side effects or that need retry/durability + *
+ * + * @see io.temporal.workflow.Workflow#sideEffect(Class, io.temporal.workflow.Functions.Func) + */ +@Target({ElementType.TYPE}) +@Retention(RetentionPolicy.RUNTIME) +public @interface SideEffectTool {} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/tool/SideEffectToolCallback.java b/temporal-spring-ai/src/main/java/io/temporal/springai/tool/SideEffectToolCallback.java new file mode 100644 index 000000000..561b5b057 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/tool/SideEffectToolCallback.java @@ -0,0 +1,66 @@ +package io.temporal.springai.tool; + +import io.temporal.workflow.Workflow; +import org.springframework.ai.chat.model.ToolContext; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.metadata.ToolMetadata; + +/** + * A wrapper for {@link ToolCallback} that executes the tool within {@code Workflow.sideEffect()}, + * making it safe for non-deterministic operations. + * + *

When a tool is wrapped in this callback: + * + *

    + *
  • The first execution records the result in workflow history + *
  • On replay, the recorded result is returned without re-execution + *
  • This ensures deterministic replay even for non-deterministic tools + *
+ * + *

This is used internally when processing tools marked with {@link SideEffectTool}. + * + * @see SideEffectTool + * @see io.temporal.workflow.Workflow#sideEffect(Class, io.temporal.workflow.Functions.Func) + */ +public class SideEffectToolCallback implements ToolCallback { + private final ToolCallback delegate; + + /** + * Creates a new SideEffectToolCallback wrapping the given callback. + * + * @param delegate the underlying tool callback to wrap + */ + public SideEffectToolCallback(ToolCallback delegate) { + this.delegate = delegate; + } + + @Override + public ToolDefinition getToolDefinition() { + return delegate.getToolDefinition(); + } + + @Override + public ToolMetadata getToolMetadata() { + return delegate.getToolMetadata(); + } + + @Override + public String call(String toolInput) { + return Workflow.sideEffect(String.class, () -> delegate.call(toolInput)); + } + + @Override + public String call(String toolInput, ToolContext toolContext) { + return Workflow.sideEffect(String.class, () -> delegate.call(toolInput, toolContext)); + } + + /** + * Returns the underlying delegate callback. + * + * @return the wrapped callback + */ + public ToolCallback getDelegate() { + return delegate; + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/util/TemporalStubUtil.java b/temporal-spring-ai/src/main/java/io/temporal/springai/util/TemporalStubUtil.java new file mode 100644 index 000000000..573cccb3b --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/util/TemporalStubUtil.java @@ -0,0 +1,87 @@ +package io.temporal.springai.util; + +import io.temporal.internal.sync.ActivityInvocationHandler; +import io.temporal.internal.sync.LocalActivityInvocationHandler; +import io.temporal.internal.sync.NexusServiceInvocationHandler; +import java.lang.reflect.Proxy; + +/** + * Utility class for detecting Temporal stub types. + * + *

Temporal creates dynamic proxies for various stub types (activities, local activities, child + * workflows, Nexus services). This utility provides methods to detect what type of stub an object + * is, which is useful for determining how to handle tool calls. + * + *

This class uses direct {@code instanceof} checks against the SDK's internal invocation handler + * classes. Since the {@code temporal-spring-ai} module lives in the SDK repo, this coupling is + * intentional and will be caught by compilation if the handler classes are renamed or moved. + */ +public final class TemporalStubUtil { + + private TemporalStubUtil() {} + + /** + * Checks if the given object is an activity stub created by {@code Workflow.newActivityStub()}. + * + * @param object the object to check + * @return true if the object is an activity stub (but not a local activity stub) + */ + public static boolean isActivityStub(Object object) { + if (object == null || !Proxy.isProxyClass(object.getClass())) { + return false; + } + var handler = Proxy.getInvocationHandler(object); + return handler instanceof ActivityInvocationHandler; + } + + /** + * Checks if the given object is a local activity stub created by {@code + * Workflow.newLocalActivityStub()}. + * + * @param object the object to check + * @return true if the object is a local activity stub + */ + public static boolean isLocalActivityStub(Object object) { + if (object == null || !Proxy.isProxyClass(object.getClass())) { + return false; + } + var handler = Proxy.getInvocationHandler(object); + return handler instanceof LocalActivityInvocationHandler; + } + + /** + * Checks if the given object is a child workflow stub created by {@code + * Workflow.newChildWorkflowStub()}. + * + *

Note: {@code ChildWorkflowInvocationHandler} is package-private in the SDK, so we check via + * the class name. This is safe because the module lives in the SDK repo — any rename would break + * compilation of this module's tests. + * + * @param object the object to check + * @return true if the object is a child workflow stub + */ + public static boolean isChildWorkflowStub(Object object) { + if (object == null || !Proxy.isProxyClass(object.getClass())) { + return false; + } + var handler = Proxy.getInvocationHandler(object); + // ChildWorkflowInvocationHandler is package-private, so we use class name check. + // This is the only handler where instanceof is not possible. + return handler.getClass().getName().endsWith("ChildWorkflowInvocationHandler"); + } + + /** + * Checks if the given object is a Nexus service stub created by {@code + * Workflow.newNexusServiceStub()}. + * + * @param object the object to check + * @return true if the object is a Nexus service stub + */ + public static boolean isNexusServiceStub(Object object) { + if (object == null || !Proxy.isProxyClass(object.getClass())) { + return false; + } + var handler = Proxy.getInvocationHandler(object); + return handler instanceof NexusServiceInvocationHandler; + } +} diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/util/TemporalToolUtil.java b/temporal-spring-ai/src/main/java/io/temporal/springai/util/TemporalToolUtil.java new file mode 100644 index 000000000..44b499ee5 --- /dev/null +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/util/TemporalToolUtil.java @@ -0,0 +1,82 @@ +package io.temporal.springai.util; + +import io.temporal.springai.tool.ActivityToolUtil; +import io.temporal.springai.tool.NexusToolUtil; +import io.temporal.springai.tool.SideEffectTool; +import io.temporal.springai.tool.SideEffectToolCallback; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.springframework.ai.support.ToolCallbacks; +import org.springframework.ai.tool.ToolCallback; + +/** + * Utility class for converting tool objects to appropriate {@link ToolCallback} instances based on + * their type. + * + *

Each tool object is detected and handled as follows: + * + *

    + *
  • Activity stubs - Executed as durable Temporal activities + *
  • Local activity stubs - Executed as local activities + *
  • Nexus service stubs - Executed as Nexus operations + *
  • {@link SideEffectTool} classes - Wrapped in {@code Workflow.sideEffect()} + *
  • Plain objects with {@code @Tool} methods - Executed directly in workflow context. + * User is responsible for determinism. + *
  • Child workflow stubs - Not supported (use a plain tool that starts a child workflow) + *
+ * + * @see SideEffectTool + * @see SideEffectToolCallback + */ +public final class TemporalToolUtil { + + private TemporalToolUtil() {} + + /** + * Converts an array of tool objects to appropriate {@link ToolCallback} instances. + * + * @param toolObjects the tool objects to convert + * @return a list of ToolCallback instances + * @throws UnsupportedOperationException if a child workflow stub is passed + */ + public static List convertTools(Object... toolObjects) { + List toolCallbacks = new ArrayList<>(); + + for (Object toolObject : toolObjects) { + if (toolObject == null) { + throw new IllegalArgumentException("Tool object cannot be null"); + } + + if (TemporalStubUtil.isActivityStub(toolObject)) { + toolCallbacks.addAll(List.of(ActivityToolUtil.fromActivityStub(toolObject))); + + } else if (TemporalStubUtil.isLocalActivityStub(toolObject)) { + toolCallbacks.addAll(List.of(ActivityToolUtil.fromActivityStub(toolObject))); + + } else if (TemporalStubUtil.isNexusServiceStub(toolObject)) { + toolCallbacks.addAll(List.of(NexusToolUtil.fromNexusServiceStub(toolObject))); + + } else if (TemporalStubUtil.isChildWorkflowStub(toolObject)) { + throw new UnsupportedOperationException( + "Child workflow stubs are not supported as tools. " + + "Use a plain tool method that starts a child workflow instead."); + + } else if (toolObject.getClass().isAnnotationPresent(SideEffectTool.class)) { + ToolCallback[] rawCallbacks = ToolCallbacks.from(toolObject); + toolCallbacks.addAll( + Arrays.stream(rawCallbacks) + .map(SideEffectToolCallback::new) + .map(tc -> (ToolCallback) tc) + .toList()); + + } else { + // Plain tool — executes directly in workflow context. + // User is responsible for determinism. + toolCallbacks.addAll(List.of(ToolCallbacks.from(toolObject))); + } + } + + return toolCallbacks; + } +} diff --git a/temporal-spring-ai/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/temporal-spring-ai/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports new file mode 100644 index 000000000..7f86436f4 --- /dev/null +++ b/temporal-spring-ai/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -0,0 +1,4 @@ +io.temporal.springai.autoconfigure.SpringAiTemporalAutoConfiguration +io.temporal.springai.autoconfigure.SpringAiVectorStoreAutoConfiguration +io.temporal.springai.autoconfigure.SpringAiEmbeddingAutoConfiguration +io.temporal.springai.autoconfigure.SpringAiMcpAutoConfiguration diff --git a/temporal-spring-ai/src/test/java/io/temporal/springai/WorkflowDeterminismTest.java b/temporal-spring-ai/src/test/java/io/temporal/springai/WorkflowDeterminismTest.java new file mode 100644 index 000000000..b79c07551 --- /dev/null +++ b/temporal-spring-ai/src/test/java/io/temporal/springai/WorkflowDeterminismTest.java @@ -0,0 +1,199 @@ +package io.temporal.springai; + +import static org.junit.jupiter.api.Assertions.*; + +import io.temporal.client.WorkflowClient; +import io.temporal.client.WorkflowOptions; +import io.temporal.client.WorkflowStub; +import io.temporal.common.WorkflowExecutionHistory; +import io.temporal.springai.activity.ChatModelActivityImpl; +import io.temporal.springai.chat.TemporalChatClient; +import io.temporal.springai.model.ActivityChatModel; +import io.temporal.springai.tool.SideEffectTool; +import io.temporal.testing.TestWorkflowEnvironment; +import io.temporal.testing.WorkflowReplayer; +import io.temporal.worker.Worker; +import io.temporal.workflow.WorkflowInterface; +import io.temporal.workflow.WorkflowMethod; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.tool.annotation.Tool; + +/** + * Verifies that workflows using ActivityChatModel with tools are deterministic by running them to + * completion and then replaying from the captured history. + */ +class WorkflowDeterminismTest { + + private static final String TASK_QUEUE = "test-spring-ai"; + + private TestWorkflowEnvironment testEnv; + private WorkflowClient client; + + @BeforeEach + void setUp() { + testEnv = TestWorkflowEnvironment.newInstance(); + client = testEnv.getWorkflowClient(); + } + + @AfterEach + void tearDown() { + testEnv.close(); + } + + @Test + void workflowWithChatModel_replaysDeterministically() throws Exception { + Worker worker = testEnv.newWorker(TASK_QUEUE); + worker.registerWorkflowImplementationTypes(ChatWorkflowImpl.class); + worker.registerActivitiesImplementations( + new ChatModelActivityImpl(new StubChatModel("Hello from the model!"))); + + testEnv.start(); + + TestChatWorkflow workflow = + client.newWorkflowStub( + TestChatWorkflow.class, WorkflowOptions.newBuilder().setTaskQueue(TASK_QUEUE).build()); + + String result = workflow.chat("Hi"); + assertEquals("Hello from the model!", result); + + WorkflowExecutionHistory history = + client.fetchHistory(WorkflowStub.fromTyped(workflow).getExecution().getWorkflowId()); + WorkflowReplayer.replayWorkflowExecution(history, ChatWorkflowImpl.class); + } + + @Test + void workflowWithTools_replaysDeterministically() throws Exception { + Worker worker = testEnv.newWorker(TASK_QUEUE); + worker.registerWorkflowImplementationTypes(ChatWithToolsWorkflowImpl.class); + + // First call: model requests the "add" tool. Second call: model returns final text. + ChatModel toolCallingModel = new ToolCallingStubChatModel(); + worker.registerActivitiesImplementations(new ChatModelActivityImpl(toolCallingModel)); + + testEnv.start(); + + TestChatWorkflow workflow = + client.newWorkflowStub( + TestChatWorkflow.class, WorkflowOptions.newBuilder().setTaskQueue(TASK_QUEUE).build()); + + String result = workflow.chat("What is 2+3?"); + assertEquals("The answer is 5", result); + + WorkflowExecutionHistory history = + client.fetchHistory(WorkflowStub.fromTyped(workflow).getExecution().getWorkflowId()); + WorkflowReplayer.replayWorkflowExecution(history, ChatWithToolsWorkflowImpl.class); + } + + // --- Workflow interfaces and implementations --- + + @WorkflowInterface + public interface TestChatWorkflow { + @WorkflowMethod + String chat(String message); + } + + public static class ChatWorkflowImpl implements TestChatWorkflow { + @Override + public String chat(String message) { + ActivityChatModel chatModel = ActivityChatModel.forDefault(); + ChatClient chatClient = TemporalChatClient.builder(chatModel).build(); + return chatClient.prompt().user(message).call().content(); + } + } + + public static class ChatWithToolsWorkflowImpl implements TestChatWorkflow { + @Override + public String chat(String message) { + ActivityChatModel chatModel = ActivityChatModel.forDefault(); + TestPlainTools plainTools = new TestPlainTools(); + TestSideEffectTools sideEffectTools = new TestSideEffectTools(); + ChatClient chatClient = + TemporalChatClient.builder(chatModel).defaultTools(plainTools, sideEffectTools).build(); + return chatClient.prompt().user(message).call().content(); + } + } + + // --- Test tool classes --- + + public static class TestPlainTools { + @Tool(description = "Add two numbers") + public int add(int a, int b) { + return a + b; + } + } + + @SideEffectTool + public static class TestSideEffectTools { + @Tool(description = "Get a timestamp") + public String timestamp() { + return "2025-01-01T00:00:00Z"; + } + } + + // --- Stub ChatModels --- + + /** Always returns a final text response, no tool calls. */ + private static class StubChatModel implements ChatModel { + private final String response; + + StubChatModel(String response) { + this.response = response; + } + + @Override + public ChatResponse call(Prompt prompt) { + return ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage(response)))) + .build(); + } + + @Override + public reactor.core.publisher.Flux stream(Prompt prompt) { + throw new UnsupportedOperationException(); + } + } + + /** + * First call: returns a tool call request for "add(2, 3)". Second call (after tool response): + * returns final text "The answer is 5". + */ + private static class ToolCallingStubChatModel implements ChatModel { + private final AtomicInteger callCount = new AtomicInteger(0); + + @Override + public ChatResponse call(Prompt prompt) { + if (callCount.getAndIncrement() == 0) { + // First call: request a tool call + AssistantMessage toolRequest = + AssistantMessage.builder() + .content("") + .toolCalls( + List.of( + new AssistantMessage.ToolCall( + "call_1", "function", "add", "{\"a\":2,\"b\":3}"))) + .build(); + return ChatResponse.builder().generations(List.of(new Generation(toolRequest))).build(); + } else { + // Second call: return final response + return ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage("The answer is 5")))) + .build(); + } + } + + @Override + public reactor.core.publisher.Flux stream(Prompt prompt) { + throw new UnsupportedOperationException(); + } + } +} diff --git a/temporal-spring-ai/src/test/java/io/temporal/springai/activity/ChatModelActivityImplTest.java b/temporal-spring-ai/src/test/java/io/temporal/springai/activity/ChatModelActivityImplTest.java new file mode 100644 index 000000000..300fe7dd7 --- /dev/null +++ b/temporal-spring-ai/src/test/java/io/temporal/springai/activity/ChatModelActivityImplTest.java @@ -0,0 +1,297 @@ +package io.temporal.springai.activity; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import io.temporal.springai.model.ChatModelTypes.*; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.tool.ToolCallingChatOptions; + +class ChatModelActivityImplTest { + + @Test + void systemMessage_roundTrip() { + ChatModel mockModel = mock(ChatModel.class); + when(mockModel.call(any(Prompt.class))) + .thenReturn( + ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage("reply")))) + .build()); + + ChatModelActivityImpl impl = new ChatModelActivityImpl(mockModel); + + ChatModelActivityInput input = + new ChatModelActivityInput( + null, List.of(new Message("You are helpful", Message.Role.SYSTEM)), null, List.of()); + + ChatModelActivityOutput output = impl.callChatModel(input); + + assertNotNull(output); + assertEquals(1, output.generations().size()); + assertEquals("reply", output.generations().get(0).message().rawContent()); + assertEquals(Message.Role.ASSISTANT, output.generations().get(0).message().role()); + + // Verify the prompt was constructed with a SystemMessage + ArgumentCaptor captor = ArgumentCaptor.forClass(Prompt.class); + verify(mockModel).call(captor.capture()); + Prompt prompt = captor.getValue(); + assertEquals(1, prompt.getInstructions().size()); + assertInstanceOf( + org.springframework.ai.chat.messages.SystemMessage.class, prompt.getInstructions().get(0)); + assertEquals("You are helpful", prompt.getInstructions().get(0).getText()); + } + + @Test + void userMessage_roundTrip() { + ChatModel mockModel = mock(ChatModel.class); + when(mockModel.call(any(Prompt.class))) + .thenReturn( + ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage("hi")))) + .build()); + + ChatModelActivityImpl impl = new ChatModelActivityImpl(mockModel); + + ChatModelActivityInput input = + new ChatModelActivityInput( + null, List.of(new Message("hello", Message.Role.USER)), null, List.of()); + + ChatModelActivityOutput output = impl.callChatModel(input); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Prompt.class); + verify(mockModel).call(captor.capture()); + Prompt prompt = captor.getValue(); + assertInstanceOf( + org.springframework.ai.chat.messages.UserMessage.class, prompt.getInstructions().get(0)); + } + + @Test + void assistantMessageWithToolCalls_roundTrip() { + ChatModel mockModel = mock(ChatModel.class); + + // Model returns a response with tool calls + AssistantMessage assistantWithTools = + AssistantMessage.builder() + .content("I'll check the weather") + .toolCalls( + List.of( + new AssistantMessage.ToolCall( + "call_123", "function", "getWeather", "{\"city\":\"Seattle\"}"))) + .build(); + + when(mockModel.call(any(Prompt.class))) + .thenReturn( + ChatResponse.builder() + .generations(List.of(new Generation(assistantWithTools))) + .build()); + + ChatModelActivityImpl impl = new ChatModelActivityImpl(mockModel); + + ChatModelActivityInput input = + new ChatModelActivityInput( + null, List.of(new Message("What's the weather?", Message.Role.USER)), null, List.of()); + + ChatModelActivityOutput output = impl.callChatModel(input); + + // Verify tool calls are preserved in output + Message outputMsg = output.generations().get(0).message(); + assertNotNull(outputMsg.toolCalls()); + assertEquals(1, outputMsg.toolCalls().size()); + assertEquals("call_123", outputMsg.toolCalls().get(0).id()); + assertEquals("function", outputMsg.toolCalls().get(0).type()); + assertEquals("getWeather", outputMsg.toolCalls().get(0).function().name()); + assertEquals("{\"city\":\"Seattle\"}", outputMsg.toolCalls().get(0).function().arguments()); + } + + @Test + void toolResponseMessage_roundTrip() { + ChatModel mockModel = mock(ChatModel.class); + when(mockModel.call(any(Prompt.class))) + .thenReturn( + ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage("It's 55F")))) + .build()); + + ChatModelActivityImpl impl = new ChatModelActivityImpl(mockModel); + + ChatModelActivityInput input = + new ChatModelActivityInput( + null, + List.of( + new Message( + "Weather: 55F", Message.Role.TOOL, "getWeather", "call_123", null, null)), + null, + List.of()); + + ChatModelActivityOutput output = impl.callChatModel(input); + + // Verify tool response was passed to model + ArgumentCaptor captor = ArgumentCaptor.forClass(Prompt.class); + verify(mockModel).call(captor.capture()); + Prompt prompt = captor.getValue(); + assertInstanceOf( + org.springframework.ai.chat.messages.ToolResponseMessage.class, + prompt.getInstructions().get(0)); + } + + @Test + void modelOptions_passedThrough() { + ChatModel mockModel = mock(ChatModel.class); + when(mockModel.call(any(Prompt.class))) + .thenReturn( + ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage("ok")))) + .build()); + + ChatModelActivityImpl impl = new ChatModelActivityImpl(mockModel); + + ModelOptions opts = new ModelOptions("gpt-4", null, 100, null, null, 0.5, null, 0.9); + + ChatModelActivityInput input = + new ChatModelActivityInput( + null, List.of(new Message("hi", Message.Role.USER)), opts, List.of()); + + impl.callChatModel(input); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Prompt.class); + verify(mockModel).call(captor.capture()); + Prompt prompt = captor.getValue(); + assertNotNull(prompt.getOptions()); + assertEquals("gpt-4", prompt.getOptions().getModel()); + assertEquals(0.5, prompt.getOptions().getTemperature()); + assertEquals(0.9, prompt.getOptions().getTopP()); + assertEquals(100, prompt.getOptions().getMaxTokens()); + } + + @Test + void toolDefinitions_passedAsStubs() { + ChatModel mockModel = mock(ChatModel.class); + when(mockModel.call(any(Prompt.class))) + .thenReturn( + ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage("ok")))) + .build()); + + ChatModelActivityImpl impl = new ChatModelActivityImpl(mockModel); + + FunctionTool tool = + new FunctionTool( + new FunctionTool.Function( + "getWeather", "Get weather for a city", "{\"type\":\"object\"}")); + + ChatModelActivityInput input = + new ChatModelActivityInput( + null, List.of(new Message("hi", Message.Role.USER)), null, List.of(tool)); + + impl.callChatModel(input); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Prompt.class); + verify(mockModel).call(captor.capture()); + Prompt prompt = captor.getValue(); + + // Verify tool execution is disabled (workflow handles it) + assertInstanceOf(ToolCallingChatOptions.class, prompt.getOptions()); + assertFalse(ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions())); + } + + @Test + void multipleModels_resolvedByName() { + ChatModel openAi = mock(ChatModel.class); + ChatModel anthropic = mock(ChatModel.class); + when(openAi.call(any(Prompt.class))) + .thenReturn( + ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage("openai")))) + .build()); + when(anthropic.call(any(Prompt.class))) + .thenReturn( + ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage("anthropic")))) + .build()); + + ChatModelActivityImpl impl = + new ChatModelActivityImpl(Map.of("openai", openAi, "anthropic", anthropic), "openai"); + + // Call with specific model + ChatModelActivityInput input = + new ChatModelActivityInput( + "anthropic", List.of(new Message("hi", Message.Role.USER)), null, List.of()); + + ChatModelActivityOutput output = impl.callChatModel(input); + assertEquals("anthropic", output.generations().get(0).message().rawContent()); + verify(anthropic).call(any(Prompt.class)); + verify(openAi, never()).call(any(Prompt.class)); + } + + @Test + void multipleModels_defaultUsedWhenNameNull() { + ChatModel openAi = mock(ChatModel.class); + ChatModel anthropic = mock(ChatModel.class); + when(openAi.call(any(Prompt.class))) + .thenReturn( + ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage("openai")))) + .build()); + + ChatModelActivityImpl impl = + new ChatModelActivityImpl(Map.of("openai", openAi, "anthropic", anthropic), "openai"); + + ChatModelActivityInput input = + new ChatModelActivityInput( + null, List.of(new Message("hi", Message.Role.USER)), null, List.of()); + + ChatModelActivityOutput output = impl.callChatModel(input); + assertEquals("openai", output.generations().get(0).message().rawContent()); + verify(openAi).call(any(Prompt.class)); + } + + @Test + void unknownModelName_throwsIllegalArgument() { + ChatModel model = mock(ChatModel.class); + ChatModelActivityImpl impl = new ChatModelActivityImpl(model); + + ChatModelActivityInput input = + new ChatModelActivityInput( + "nonexistent", List.of(new Message("hi", Message.Role.USER)), null, List.of()); + + assertThrows(IllegalArgumentException.class, () -> impl.callChatModel(input)); + } + + @Test + void multipleMessages_allConverted() { + ChatModel mockModel = mock(ChatModel.class); + when(mockModel.call(any(Prompt.class))) + .thenReturn( + ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage("ok")))) + .build()); + + ChatModelActivityImpl impl = new ChatModelActivityImpl(mockModel); + + ChatModelActivityInput input = + new ChatModelActivityInput( + null, + List.of( + new Message("You are helpful", Message.Role.SYSTEM), + new Message("Hello", Message.Role.USER), + new Message("Hi there", Message.Role.ASSISTANT, null, null, null, null), + new Message("What's up?", Message.Role.USER)), + null, + List.of()); + + impl.callChatModel(input); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Prompt.class); + verify(mockModel).call(captor.capture()); + assertEquals(4, captor.getValue().getInstructions().size()); + } +} diff --git a/temporal-spring-ai/src/test/java/io/temporal/springai/plugin/SpringAiPluginTest.java b/temporal-spring-ai/src/test/java/io/temporal/springai/plugin/SpringAiPluginTest.java new file mode 100644 index 000000000..869ae7e5c --- /dev/null +++ b/temporal-spring-ai/src/test/java/io/temporal/springai/plugin/SpringAiPluginTest.java @@ -0,0 +1,135 @@ +package io.temporal.springai.plugin; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +import io.temporal.springai.activity.ChatModelActivityImpl; +import io.temporal.springai.activity.EmbeddingModelActivityImpl; +import io.temporal.springai.activity.VectorStoreActivityImpl; +import io.temporal.worker.Worker; +import java.util.*; +import java.util.stream.Collectors; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.vectorstore.VectorStore; + +class SpringAiPluginTest { + + private List captureRegisteredActivities(Worker worker) { + ArgumentCaptor captor = ArgumentCaptor.forClass(Object.class); + verify(worker, atLeastOnce()).registerActivitiesImplementations(captor.capture()); + return captor.getAllValues(); + } + + private Set> activityTypes(List activities) { + return activities.stream().map(Object::getClass).collect(Collectors.toSet()); + } + + // --- Core SpringAiPlugin tests --- + + @Test + void singleModel_registersChatModelAndExecuteToolLocal() { + ChatModel chatModel = mock(ChatModel.class); + Worker worker = mock(Worker.class); + + SpringAiPlugin plugin = new SpringAiPlugin(chatModel); + plugin.initializeWorker("test-queue", worker); + + Set> types = activityTypes(captureRegisteredActivities(worker)); + assertTrue(types.contains(ChatModelActivityImpl.class)); + // No VectorStore or EmbeddingModel — those are separate plugins now + assertFalse(types.contains(VectorStoreActivityImpl.class)); + assertFalse(types.contains(EmbeddingModelActivityImpl.class)); + } + + @Test + void multipleModels_allExposed() { + ChatModel model1 = mock(ChatModel.class); + ChatModel model2 = mock(ChatModel.class); + Map models = new LinkedHashMap<>(); + models.put("openai", model1); + models.put("anthropic", model2); + + SpringAiPlugin plugin = new SpringAiPlugin(models, model1); + + assertEquals(2, plugin.getChatModels().size()); + assertSame(model1, plugin.getChatModel("openai")); + assertSame(model2, plugin.getChatModel("anthropic")); + } + + @Test + void primaryModel_usedAsDefault() { + ChatModel model1 = mock(ChatModel.class); + ChatModel model2 = mock(ChatModel.class); + Map models = new LinkedHashMap<>(); + models.put("openai", model1); + models.put("anthropic", model2); + + SpringAiPlugin plugin = new SpringAiPlugin(models, model2); + + assertEquals("anthropic", plugin.getDefaultModelName()); + assertSame(model2, plugin.getChatModel()); + } + + @Test + void noPrimaryModel_firstEntryIsDefault() { + ChatModel model1 = mock(ChatModel.class); + ChatModel model2 = mock(ChatModel.class); + Map models = new LinkedHashMap<>(); + models.put("openai", model1); + models.put("anthropic", model2); + + SpringAiPlugin plugin = new SpringAiPlugin(models, null); + + assertEquals("openai", plugin.getDefaultModelName()); + assertSame(model1, plugin.getChatModel()); + } + + @Test + void singleModelConstructor_usesDefaultModelName() { + ChatModel chatModel = mock(ChatModel.class); + SpringAiPlugin plugin = new SpringAiPlugin(chatModel); + + assertEquals(SpringAiPlugin.DEFAULT_MODEL_NAME, plugin.getDefaultModelName()); + assertSame(chatModel, plugin.getChatModel()); + } + + @Test + void nullChatModelsMap_throwsIllegalArgument() { + assertThrows(IllegalArgumentException.class, () -> new SpringAiPlugin(null, null)); + } + + @Test + void emptyChatModelsMap_throwsIllegalArgument() { + assertThrows( + IllegalArgumentException.class, () -> new SpringAiPlugin(new LinkedHashMap<>(), null)); + } + + // --- Optional plugin tests --- + + @Test + void vectorStorePlugin_registersActivity() { + VectorStore vectorStore = mock(VectorStore.class); + Worker worker = mock(Worker.class); + + VectorStorePlugin plugin = new VectorStorePlugin(vectorStore); + plugin.initializeWorker("test-queue", worker); + + Set> types = activityTypes(captureRegisteredActivities(worker)); + assertTrue(types.contains(VectorStoreActivityImpl.class)); + } + + @Test + void embeddingModelPlugin_registersActivity() { + EmbeddingModel embeddingModel = mock(EmbeddingModel.class); + Worker worker = mock(Worker.class); + + EmbeddingModelPlugin plugin = new EmbeddingModelPlugin(embeddingModel); + plugin.initializeWorker("test-queue", worker); + + Set> types = activityTypes(captureRegisteredActivities(worker)); + assertTrue(types.contains(EmbeddingModelActivityImpl.class)); + } +} diff --git a/temporal-spring-ai/src/test/java/io/temporal/springai/util/TemporalToolUtilTest.java b/temporal-spring-ai/src/test/java/io/temporal/springai/util/TemporalToolUtilTest.java new file mode 100644 index 000000000..1034863b3 --- /dev/null +++ b/temporal-spring-ai/src/test/java/io/temporal/springai/util/TemporalToolUtilTest.java @@ -0,0 +1,205 @@ +package io.temporal.springai.util; + +import static org.junit.jupiter.api.Assertions.*; + +import io.temporal.springai.tool.SideEffectTool; +import io.temporal.springai.tool.SideEffectToolCallback; +import java.util.List; +import org.junit.jupiter.api.Test; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.annotation.Tool; + +class TemporalToolUtilTest { + + // --- Test fixture classes --- + + static class MathTools { + @Tool(description = "Add two numbers") + public int add(int a, int b) { + return a + b; + } + + @Tool(description = "Multiply two numbers") + public int multiply(int a, int b) { + return a * b; + } + } + + @SideEffectTool + static class TimestampTools { + @Tool(description = "Get the current timestamp") + public long currentTimeMillis() { + return System.currentTimeMillis(); + } + } + + @SideEffectTool + static class RandomTools { + @Tool(description = "Generate a random number") + public double random() { + return Math.random(); + } + } + + static class UnannotatedTools { + @Tool(description = "Some tool") + public String doSomething() { + return "result"; + } + } + + // --- Tests for plain tools (execute in workflow context) --- + + @Test + void convertTools_plainTool_producesStandardCallbacks() { + List callbacks = TemporalToolUtil.convertTools(new MathTools()); + + assertEquals(2, callbacks.size()); + for (ToolCallback cb : callbacks) { + assertFalse( + cb instanceof SideEffectToolCallback, + "Plain tool should not produce SideEffectToolCallback"); + } + } + + @Test + void convertTools_plainTool_hasCorrectToolNames() { + List callbacks = TemporalToolUtil.convertTools(new MathTools()); + + List toolNames = + callbacks.stream().map(cb -> cb.getToolDefinition().name()).sorted().toList(); + assertEquals(List.of("add", "multiply"), toolNames); + } + + @Test + void convertTools_unannotatedTool_producesStandardCallbacks() { + List callbacks = TemporalToolUtil.convertTools(new UnannotatedTools()); + + assertEquals(1, callbacks.size()); + assertEquals("doSomething", callbacks.get(0).getToolDefinition().name()); + } + + @Test + void convertTools_plainString_throwsIllegalState() { + // String has no @Tool methods — Spring AI's ToolCallbacks.from() throws + assertThrows(IllegalStateException.class, () -> TemporalToolUtil.convertTools("not a tool")); + } + + // --- Tests for @SideEffectTool --- + + @Test + void convertTools_sideEffectTool_producesSideEffectCallbackWrappers() { + List callbacks = TemporalToolUtil.convertTools(new TimestampTools()); + + assertEquals(1, callbacks.size()); + assertInstanceOf(SideEffectToolCallback.class, callbacks.get(0)); + } + + @Test + void convertTools_sideEffectTool_hasCorrectToolName() { + List callbacks = TemporalToolUtil.convertTools(new TimestampTools()); + + assertEquals("currentTimeMillis", callbacks.get(0).getToolDefinition().name()); + } + + @Test + void convertTools_sideEffectTool_delegateIsPreserved() { + List callbacks = TemporalToolUtil.convertTools(new TimestampTools()); + + SideEffectToolCallback wrapper = (SideEffectToolCallback) callbacks.get(0); + assertNotNull(wrapper.getDelegate()); + assertEquals("currentTimeMillis", wrapper.getDelegate().getToolDefinition().name()); + } + + // --- Tests for null handling --- + + @Test + void convertTools_nullObject_throwsIllegalArgumentException() { + assertThrows( + IllegalArgumentException.class, () -> TemporalToolUtil.convertTools((Object) null)); + } + + @Test + void convertTools_nullInArray_throwsIllegalArgumentException() { + assertThrows( + IllegalArgumentException.class, () -> TemporalToolUtil.convertTools(new MathTools(), null)); + } + + // --- Tests for empty input --- + + @Test + void convertTools_emptyArray_returnsEmptyList() { + List callbacks = TemporalToolUtil.convertTools(); + assertTrue(callbacks.isEmpty()); + } + + // --- Tests for mixed tool types --- + + @Test + void convertTools_mixedPlainAndSideEffect_allConvertCorrectly() { + List callbacks = + TemporalToolUtil.convertTools(new MathTools(), new TimestampTools(), new RandomTools()); + + assertEquals(4, callbacks.size()); + + long sideEffectCount = + callbacks.stream().filter(cb -> cb instanceof SideEffectToolCallback).count(); + long standardCount = + callbacks.stream().filter(cb -> !(cb instanceof SideEffectToolCallback)).count(); + + assertEquals(2, sideEffectCount); + assertEquals(2, standardCount); + } + + @Test + void convertTools_mixedWithUnannotated_allSucceed() { + List callbacks = + TemporalToolUtil.convertTools(new MathTools(), new UnannotatedTools()); + + assertEquals(3, callbacks.size()); // 2 from MathTools + 1 from UnannotatedTools + } + + // --- Tests for TemporalStubUtil negative cases --- + + @Test + void stubUtil_isActivityStub_nonProxy_returnsFalse() { + assertFalse(TemporalStubUtil.isActivityStub(new MathTools())); + assertFalse(TemporalStubUtil.isActivityStub("not a stub")); + assertFalse(TemporalStubUtil.isActivityStub(null)); + } + + @Test + void stubUtil_isLocalActivityStub_nonProxy_returnsFalse() { + assertFalse(TemporalStubUtil.isLocalActivityStub(new MathTools())); + assertFalse(TemporalStubUtil.isLocalActivityStub("not a stub")); + assertFalse(TemporalStubUtil.isLocalActivityStub(null)); + } + + @Test + void stubUtil_isChildWorkflowStub_nonProxy_returnsFalse() { + assertFalse(TemporalStubUtil.isChildWorkflowStub(new MathTools())); + assertFalse(TemporalStubUtil.isChildWorkflowStub("not a stub")); + assertFalse(TemporalStubUtil.isChildWorkflowStub(null)); + } + + @Test + void stubUtil_isNexusServiceStub_nonProxy_returnsFalse() { + assertFalse(TemporalStubUtil.isNexusServiceStub(new MathTools())); + assertFalse(TemporalStubUtil.isNexusServiceStub("not a stub")); + assertFalse(TemporalStubUtil.isNexusServiceStub(null)); + } + + @Test + void stubUtil_nonTemporalProxy_returnsFalse() { + Object proxy = + java.lang.reflect.Proxy.newProxyInstance( + getClass().getClassLoader(), + new Class[] {Runnable.class}, + (p, method, args) -> null); + + assertFalse(TemporalStubUtil.isActivityStub(proxy)); + assertFalse(TemporalStubUtil.isLocalActivityStub(proxy)); + assertFalse(TemporalStubUtil.isChildWorkflowStub(proxy)); + assertFalse(TemporalStubUtil.isNexusServiceStub(proxy)); + } +}