diff --git a/cmd/testai/main-testai.go b/cmd/testai/main-testai.go index 04c3ae91f9..2684336e45 100644 --- a/cmd/testai/main-testai.go +++ b/cmd/testai/main-testai.go @@ -27,6 +27,7 @@ const ( DefaultAnthropicModel = "claude-sonnet-4-5" DefaultOpenAIModel = "gpt-5.1" DefaultOpenRouterModel = "mistralai/mistral-small-3.2-24b-instruct" + DefaultGeminiModel = "gemini-3-pro-preview" ) // TestResponseWriter implements http.ResponseWriter and additional interfaces for testing @@ -306,6 +307,57 @@ func testAnthropic(ctx context.Context, model, message string, tools []uctypes.T } } +func testGemini(ctx context.Context, model, message string, tools []uctypes.ToolDefinition) { + apiKey := os.Getenv("GOOGLE_APIKEY") + if apiKey == "" { + fmt.Println("Error: GOOGLE_APIKEY environment variable not set") + os.Exit(1) + } + + opts := &uctypes.AIOptsType{ + APIType: uctypes.APIType_GoogleGemini, + APIToken: apiKey, + Model: model, + MaxTokens: 8192, + Capabilities: []string{uctypes.AICapabilityTools, uctypes.AICapabilityImages, uctypes.AICapabilityPdfs}, + } + + // Generate a chat ID + chatID := uuid.New().String() + + // Convert to AIMessage format for WaveAIPostMessageWrap + aiMessage := &uctypes.AIMessage{ + MessageId: uuid.New().String(), + Parts: []uctypes.AIMessagePart{ + { + Type: uctypes.AIMessagePartTypeText, + Text: message, + }, + }, + } + + fmt.Printf("Testing Google Gemini streaming with WaveAIPostMessageWrap, model: %s\n", model) + fmt.Printf("Message: %s\n", message) + fmt.Printf("Chat ID: %s\n", chatID) + fmt.Println("---") + + testWriter := &TestResponseWriter{} + sseHandler := sse.MakeSSEHandlerCh(testWriter, ctx) + defer sseHandler.Close() + + chatOpts := uctypes.WaveChatOpts{ + ChatId: chatID, + ClientId: uuid.New().String(), + Config: *opts, + Tools: tools, + SystemPrompt: []string{"You are a helpful assistant. 
Be concise and clear in your responses."}, + } + err := aiusechat.WaveAIPostMessageWrap(ctx, sseHandler, aiMessage, chatOpts) + if err != nil { + fmt.Printf("Google Gemini streaming error: %v\n", err) + } +} + func testT1(ctx context.Context) { tool := aiusechat.GetAdderToolDefinition() tools := []uctypes.ToolDefinition{tool} @@ -322,8 +374,14 @@ func testT3(ctx context.Context) { testOpenAIComp(ctx, "gpt-4o", "what is 2+2? please be brief", nil) } +func testT4(ctx context.Context) { + tool := aiusechat.GetAdderToolDefinition() + tools := []uctypes.ToolDefinition{tool} + testGemini(ctx, DefaultGeminiModel, "what is 2+2+8, use the provider adder tool", tools) +} + func printUsage() { - fmt.Println("Usage: go run main-testai.go [--anthropic|--openaicomp|--openrouter] [--tools] [--model ] [message]") + fmt.Println("Usage: go run main-testai.go [--anthropic|--openaicomp|--openrouter|--gemini] [--tools] [--model ] [message]") fmt.Println("Examples:") fmt.Println(" go run main-testai.go 'What is 2+2?'") fmt.Println(" go run main-testai.go --model o4-mini 'What is 2+2?'") @@ -332,6 +390,8 @@ func printUsage() { fmt.Println(" go run main-testai.go --openaicomp --model gpt-4o 'What is 2+2?'") fmt.Println(" go run main-testai.go --openrouter 'What is 2+2?'") fmt.Println(" go run main-testai.go --openrouter --model anthropic/claude-3.5-sonnet 'What is 2+2?'") + fmt.Println(" go run main-testai.go --gemini 'What is 2+2?'") + fmt.Println(" go run main-testai.go --gemini --model gemini-1.5-pro 'What is 2+2?'") fmt.Println(" go run main-testai.go --tools 'Help me configure GitHub Actions monitoring'") fmt.Println("") fmt.Println("Default models:") @@ -339,25 +399,29 @@ func printUsage() { fmt.Printf(" Anthropic: %s\n", DefaultAnthropicModel) fmt.Printf(" OpenAI Completions: gpt-4o\n") fmt.Printf(" OpenRouter: %s\n", DefaultOpenRouterModel) + fmt.Printf(" Google Gemini: %s\n", DefaultGeminiModel) fmt.Println("") fmt.Println("Environment variables:") fmt.Println(" OPENAI_APIKEY 
(for OpenAI models)") fmt.Println(" ANTHROPIC_APIKEY (for Anthropic models)") fmt.Println(" OPENROUTER_APIKEY (for OpenRouter models)") + fmt.Println(" GOOGLE_APIKEY (for Google Gemini models)") } func main() { - var anthropic, openaicomp, openrouter, tools, help, t1, t2, t3 bool + var anthropic, openaicomp, openrouter, gemini, tools, help, t1, t2, t3, t4 bool var model string flag.BoolVar(&anthropic, "anthropic", false, "Use Anthropic API instead of OpenAI") flag.BoolVar(&openaicomp, "openaicomp", false, "Use OpenAI Completions API") flag.BoolVar(&openrouter, "openrouter", false, "Use OpenRouter API") + flag.BoolVar(&gemini, "gemini", false, "Use Google Gemini API") flag.BoolVar(&tools, "tools", false, "Enable GitHub Actions Monitor tools for testing") - flag.StringVar(&model, "model", "", fmt.Sprintf("AI model to use (defaults: %s for OpenAI, %s for Anthropic, %s for OpenRouter)", DefaultOpenAIModel, DefaultAnthropicModel, DefaultOpenRouterModel)) + flag.StringVar(&model, "model", "", fmt.Sprintf("AI model to use (defaults: %s for OpenAI, %s for Anthropic, %s for OpenRouter, %s for Gemini)", DefaultOpenAIModel, DefaultAnthropicModel, DefaultOpenRouterModel, DefaultGeminiModel)) flag.BoolVar(&help, "help", false, "Show usage information") flag.BoolVar(&t1, "t1", false, fmt.Sprintf("Run preset T1 test (%s with 'what is 2+2')", DefaultAnthropicModel)) flag.BoolVar(&t2, "t2", false, fmt.Sprintf("Run preset T2 test (%s with 'what is 2+2')", DefaultOpenAIModel)) - flag.BoolVar(&t3, "t3", false, "Run preset T3 test (OpenAI Completions API with gpt-4o)") + flag.BoolVar(&t3, "t3", false, "Run preset T3 test (OpenAI Completions API with gpt-4o)") + flag.BoolVar(&t4, "t4", false, "Run preset T4 test (Google Gemini API with gemini-3-pro-preview)") flag.Parse() if help { @@ -380,6 +444,10 @@ func main() { testT3(ctx) return } + if t4 { + testT4(ctx) + return + } // Set default model based on API type if not provided if model == "" { 
model = "gpt-4o" } else if openrouter { model = DefaultOpenRouterModel + } else if gemini { + model = DefaultGeminiModel } else { model = DefaultOpenAIModel } @@ -411,6 +481,8 @@ func main() { testOpenAIComp(ctx, model, message, toolDefs) } else if openrouter { testOpenRouter(ctx, model, message, toolDefs) + } else if gemini { + testGemini(ctx, model, message, toolDefs) } else { testOpenAI(ctx, model, message, toolDefs) } diff --git a/cmd/wsh/cmd/wshcmd-secret.go b/cmd/wsh/cmd/wshcmd-secret.go index cfd8daf407..f2c287579a 100644 --- a/cmd/wsh/cmd/wshcmd-secret.go +++ b/cmd/wsh/cmd/wshcmd-secret.go @@ -179,7 +179,8 @@ func secretUiRun(cmd *cobra.Command, args []string) (rtnErr error) { wshCmd := &wshrpc.CommandCreateBlockData{ BlockDef: &waveobj.BlockDef{ Meta: map[string]interface{}{ - waveobj.MetaKey_View: "secretstore", + waveobj.MetaKey_View: "waveconfig", + waveobj.MetaKey_File: "secrets", }, }, Magnified: secretUiMagnified, diff --git a/docs/docs/waveai-modes.mdx b/docs/docs/waveai-modes.mdx index 0794a61a3a..ccdcdfd7f4 100644 --- a/docs/docs/waveai-modes.mdx +++ b/docs/docs/waveai-modes.mdx @@ -1,7 +1,7 @@ --- sidebar_position: 1.6 id: "waveai-modes" -title: "Wave AI (Local Models)" +title: "Wave AI (Local Models + BYOK)" --- Wave AI supports custom AI modes that allow you to use local models, custom API endpoints, and alternative AI providers. This gives you complete control over which models and providers you use with Wave's AI features. 
@@ -37,10 +37,11 @@ Wave AI now supports provider-based configuration which automatically applies se ### Supported API Types -Wave AI supports two OpenAI-compatible API types: +Wave AI supports the following API types: - **`openai-chat`**: Uses the `/v1/chat/completions` endpoint (most common) - **`openai-responses`**: Uses the `/v1/responses` endpoint (modern API for GPT-5+ models) +- **`google-gemini`**: Google's Gemini API format (automatically set when using `ai:provider: "google"`, not typically used directly) ## Configuration Structure @@ -49,7 +50,7 @@ Wave AI supports two OpenAI-compatible API types: ```json { "mode-key": { - "display:name": "Display Name", + "display:name": "Qwen (OpenRouter)", "ai:provider": "openrouter", "ai:model": "qwen/qwen-2.5-coder-32b-instruct" } @@ -89,10 +90,10 @@ Wave AI supports two OpenAI-compatible API types: | `display:icon` | No | Icon identifier for the mode | | `display:description` | No | Full description of the mode | | `ai:provider` | No | Provider preset: `openai`, `openrouter`, `google`, `azure`, `azure-legacy`, `custom` | -| `ai:apitype` | No | API type: `openai-chat` or `openai-responses` (defaults to `openai-chat` if not specified) | +| `ai:apitype` | No | API type: `openai-chat`, `openai-responses`, or `google-gemini` (defaults to `openai-chat` if not specified) | | `ai:model` | No | Model identifier (required for most providers) | | `ai:thinkinglevel` | No | Thinking level: `low`, `medium`, or `high` | -| `ai:endpoint` | No | Full API endpoint URL (auto-set by provider when available) | +| `ai:endpoint` | No | *Full* API endpoint URL (auto-set by provider when available) | | `ai:azureapiversion` | No | Azure API version (for `azure-legacy` provider, defaults to `2025-04-01-preview`) | | `ai:apitoken` | No | API key/token (not recommended - use secrets instead) | | `ai:apitokensecretname` | No | Name of secret containing API token (auto-set by provider) | @@ -110,6 +111,14 @@ The `ai:capabilities` field specifies 
what features the AI mode supports: - **`images`** - Allows image attachments in chat (model can view uploaded images) - **`pdfs`** - Allows PDF file attachments in chat (model can read PDF content) +**Provider-specific behavior:** +- **OpenAI and Google providers**: Capabilities are automatically configured based on the model. You don't need to specify them. +- **OpenRouter, Azure, Azure-Legacy, and Custom providers**: You must manually specify capabilities based on your model's features. + +:::warning +If you don't include `"tools"` in the `ai:capabilities` array, the AI model will not be able to interact with your Wave terminal widgets, read/write files, or execute commands. Most AI modes should include `"tools"` for the best Wave experience. +::: + Most models support `tools` and can benefit from it. Vision-capable models should include `images`. Not all models support PDFs, so only include `pdfs` if your model can process them. ## Local Model Examples @@ -127,7 +136,7 @@ Most models support `tools` and can benefit from it. 
Vision-capable models shoul "display:description": "Local Llama 3.3 70B model via Ollama", "ai:apitype": "openai-chat", "ai:model": "llama3.3:70b", - "ai:thinkinglevel": "normal", + "ai:thinkinglevel": "medium", "ai:endpoint": "http://localhost:11434/v1/chat/completions", "ai:apitoken": "ollama" } @@ -151,28 +160,28 @@ The `ai:apitoken` field is required but Ollama ignores it - you can set it to an "display:description": "Local Qwen model via LM Studio", "ai:apitype": "openai-chat", "ai:model": "qwen/qwen-2.5-coder-32b-instruct", - "ai:thinkinglevel": "normal", + "ai:thinkinglevel": "medium", "ai:endpoint": "http://localhost:1234/v1/chat/completions", "ai:apitoken": "not-needed" } } ``` -### Jan +### vLLM -[Jan](https://jan.ai) is another local AI runtime with OpenAI API compatibility: +[vLLM](https://docs.vllm.ai) is a high-performance inference server with OpenAI API compatibility: ```json { - "jan-local": { - "display:name": "Jan", + "vllm-local": { + "display:name": "vLLM", "display:order": 3, "display:icon": "server", - "display:description": "Local model via Jan", + "display:description": "Local model via vLLM", "ai:apitype": "openai-chat", "ai:model": "your-model-name", - "ai:thinkinglevel": "normal", - "ai:endpoint": "http://localhost:1337/v1/chat/completions", + "ai:thinkinglevel": "medium", + "ai:endpoint": "http://localhost:8000/v1/chat/completions", "ai:apitoken": "not-needed" } } @@ -198,6 +207,7 @@ The provider automatically sets: - `ai:endpoint` to `https://api.openai.com/v1/chat/completions` - `ai:apitype` to `openai-chat` (or `openai-responses` for GPT-5+ models) - `ai:apitokensecretname` to `OPENAI_KEY` (store your OpenAI API key with this name) +- `ai:capabilities` to `["tools", "images", "pdfs"]` (automatically determined based on model) For newer models like GPT-4.1 or GPT-5, the API type is automatically determined: @@ -230,6 +240,40 @@ The provider automatically sets: - `ai:apitype` to `openai-chat` - `ai:apitokensecretname` to 
`OPENROUTER_KEY` (store your OpenRouter API key with this name) +:::note +For OpenRouter, you must manually specify `ai:capabilities` based on your model's features. Example: +```json +{ + "openrouter-qwen": { + "display:name": "OpenRouter - Qwen", + "ai:provider": "openrouter", + "ai:model": "qwen/qwen-2.5-coder-32b-instruct", + "ai:capabilities": ["tools"] + } +} +``` +::: + +### Google AI (Gemini) + +[Google AI](https://ai.google.dev) provides the Gemini family of models. Using the `google` provider simplifies configuration: + +```json +{ + "google-gemini": { + "display:name": "Gemini 3 Pro", + "ai:provider": "google", + "ai:model": "gemini-3-pro-preview" + } +} +``` + +The provider automatically sets: +- `ai:endpoint` to `https://generativelanguage.googleapis.com/v1beta/models/{model}:streamGenerateContent` +- `ai:apitype` to `google-gemini` +- `ai:apitokensecretname` to `GOOGLE_AI_KEY` (store your Google AI API key with this name) +- `ai:capabilities` to `["tools", "images", "pdfs"]` (automatically configured) + ### Azure OpenAI (Modern API) For the modern Azure OpenAI API, use the `azure` provider: @@ -250,6 +294,21 @@ The provider automatically sets: - `ai:apitype` based on the model - `ai:apitokensecretname` to `AZURE_OPENAI_KEY` (store your Azure OpenAI key with this name) +:::note +For Azure providers, you must manually specify `ai:capabilities` based on your model's features. Example: +```json +{ + "azure-gpt4": { + "display:name": "Azure GPT-4", + "ai:provider": "azure", + "ai:model": "gpt-4", + "ai:azureresourcename": "your-resource-name", + "ai:capabilities": ["tools", "images"] + } +} +``` +::: + ### Azure OpenAI (Legacy Deployment API) For legacy Azure deployments, use the `azure-legacy` provider: @@ -267,6 +326,10 @@ For legacy Azure deployments, use the `azure-legacy` provider: The provider automatically constructs the full endpoint URL and sets the API version (defaults to `2025-04-01-preview`). 
You can override the API version with `ai:azureapiversion` if needed. +:::note +For Azure Legacy provider, you must manually specify `ai:capabilities` based on your model's features. +::: + ## Using Secrets for API Keys Instead of storing API keys directly in the configuration, you should use Wave's secret store to keep your credentials secure. Secrets are stored encrypted using your system's native keychain. diff --git a/docs/docs/waveai.mdx b/docs/docs/waveai.mdx index e352865ef9..1d027177ef 100644 --- a/docs/docs/waveai.mdx +++ b/docs/docs/waveai.mdx @@ -34,7 +34,7 @@ Controls AI's access to your workspace: ## File Attachments -Drag files onto the AI panel to attach: +Drag files onto the AI panel to attach (not supported with all models): | Type | Formats | Size Limit | Notes | |------|---------|------------|-------| @@ -68,7 +68,7 @@ Supports text files, images, PDFs, and directories. Use `-n` for new chat, `-s` - **Navigate Web**: Changes URLs in web browser widgets ### All Widgets -- **Capture Screenshots**: Takes screenshots of any widget for visual analysis +- **Capture Screenshots**: Takes screenshots of any widget for visual analysis (not supported on all models) :::warning Security File system operations require explicit approval. You control all file access. 
diff --git a/go.mod b/go.mod index 9bb5b64197..ba6d1da584 100644 --- a/go.mod +++ b/go.mod @@ -79,6 +79,7 @@ require ( github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/outrigdev/goid v0.3.0 // indirect github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/sirupsen/logrus v1.9.3 // indirect diff --git a/go.sum b/go.sum index e18f442fee..d38e9d0cc2 100644 --- a/go.sum +++ b/go.sum @@ -142,6 +142,8 @@ github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuE github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/outrigdev/goid v0.3.0 h1:t/otQD3EXc45cLtQVPUnNgEyRaTQA4cPeu3qVcrsIws= +github.com/outrigdev/goid v0.3.0/go.mod h1:hEH7f27ypN/GHWt/7gvkRoFYR0LZizfUBIAbak4neVE= github.com/photostorm/pty v1.1.19-0.20230903182454-31354506054b h1:cLGKfKb1uk0hxI0Q8L83UAJPpeJ+gSpn3cCU/tjd3eg= github.com/photostorm/pty v1.1.19-0.20230903182454-31354506054b/go.mod h1:KO+FcPtyLAiRC0hJwreJVvfwc7vnNz77UxBTIGHdPVk= github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo= diff --git a/pkg/aiusechat/aiutil/aiutil.go b/pkg/aiusechat/aiutil/aiutil.go index d72d7e030e..8918d30037 100644 --- a/pkg/aiusechat/aiutil/aiutil.go +++ b/pkg/aiusechat/aiutil/aiutil.go @@ -209,6 +209,13 @@ func CheckModelSubPrefix(model string, prefix string) bool { return false } +// GeminiSupportsImageToolResults returns true if the model supports multimodal function responses (images in tool results) +// This is only supported by Gemini 3 Pro and later models +func 
GeminiSupportsImageToolResults(model string) bool { + m := strings.ToLower(model) + return strings.Contains(m, "gemini-3") || strings.Contains(m, "gemini-4") +} + // CreateToolUseData creates a UIMessageDataToolUse from tool call information func CreateToolUseData(toolCallID, toolName string, arguments string, chatOpts uctypes.WaveChatOpts) uctypes.UIMessageDataToolUse { toolUseData := uctypes.UIMessageDataToolUse{ diff --git a/pkg/aiusechat/gemini/doc.go b/pkg/aiusechat/gemini/doc.go new file mode 100644 index 0000000000..7fe0699c5e --- /dev/null +++ b/pkg/aiusechat/gemini/doc.go @@ -0,0 +1,99 @@ +// Copyright 2025, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package gemini implements the Google Gemini backend for WaveTerm's AI chat system. +// +// This package provides a complete implementation of the UseChatBackend interface +// for Google's Gemini API, including: +// - Streaming chat responses via Server-Sent Events (SSE) +// - Function calling (tool use) support +// - Multi-modal input support (text, images, PDFs) +// - Proper message conversion and state management +// +// # API Type +// +// The Gemini backend uses the API type constant: +// uctypes.APIType_GoogleGemini = "google-gemini" +// +// # Supported Features +// +// - Text messages +// - Image uploads (JPEG, PNG, etc.) - inline base64 encoding +// - PDF document uploads - inline base64 encoding +// - Text file attachments +// - Directory listings +// - Function/tool calling with structured arguments +// - Streaming responses with real-time token delivery +// +// # Usage +// +// The backend is automatically registered and can be obtained via: +// +// backend, err := aiusechat.GetBackendByAPIType(uctypes.APIType_GoogleGemini) +// +// To use the Gemini API, you need: +// 1. A Google AI API key +// 2. Configure the chat with APIType_GoogleGemini +// 3. Set the Model (e.g., "gemini-2.0-flash-exp") +// 4. 
Provide the API key in the Config.APIToken field +// +// # Configuration Example +// +// chatOpts := uctypes.WaveChatOpts{ +// ChatId: "my-chat-id", +// ClientId: "my-client-id", +// Config: uctypes.AIOptsType{ +// APIType: uctypes.APIType_GoogleGemini, +// Model: "gemini-2.0-flash-exp", +// APIToken: "your-google-api-key", +// MaxTokens: 8192, +// Capabilities: []string{ +// uctypes.AICapabilityTools, +// uctypes.AICapabilityImages, +// uctypes.AICapabilityPdfs, +// }, +// }, +// Tools: []uctypes.ToolDefinition{...}, +// SystemPrompt: []string{"You are a helpful assistant."}, +// } +// +// # Message Format +// +// The Gemini backend uses the GeminiChatMessage type internally, which stores: +// - MessageId: Unique identifier for idempotency +// - Role: "user" or "model" (model is Gemini's term for assistant) +// - Parts: Array of message parts (text, inline data, function calls/responses) +// - Usage: Token usage metadata +// +// # Function Calling +// +// Function calling is supported via Gemini's native function calling feature: +// - Tools are converted to Gemini's FunctionDeclaration format +// - Function calls are streamed with real-time argument updates +// - Function responses are sent back as user messages with FunctionResponse parts +// +// # API Endpoint +// +// By default, the backend uses: +// https://generativelanguage.googleapis.com/v1beta/models/{model}:streamGenerateContent +// +// You can override this by setting Config.BaseURL. +// +// # Error Handling +// +// The backend properly handles: +// - Content blocking/safety filters +// - Token limit errors +// - Network errors +// - Malformed responses +// - Context cancellation +// +// All errors are properly propagated through the SSE stream. 
+// +// # Limitations +// +// - File uploads must be provided as base64-encoded inline data +// - Images and PDFs use inline data, not file upload URIs +// - Multi-turn conversations require proper role alternation (user/model) +// - Some advanced Gemini features like caching are not yet implemented +package gemini diff --git a/pkg/aiusechat/gemini/gemini-backend.go b/pkg/aiusechat/gemini/gemini-backend.go new file mode 100644 index 0000000000..6455881468 --- /dev/null +++ b/pkg/aiusechat/gemini/gemini-backend.go @@ -0,0 +1,514 @@ +// Copyright 2025, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package gemini + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "net/http" + "net/url" + "strings" + "time" + + "github.com/google/uuid" + "github.com/launchdarkly/eventsource" + "github.com/wavetermdev/waveterm/pkg/aiusechat/aiutil" + "github.com/wavetermdev/waveterm/pkg/aiusechat/chatstore" + "github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes" + "github.com/wavetermdev/waveterm/pkg/util/utilfn" + "github.com/wavetermdev/waveterm/pkg/wavebase" + "github.com/wavetermdev/waveterm/pkg/web/sse" +) + +// ensureAltSse ensures the ?alt=sse query parameter is set on the endpoint +func ensureAltSse(endpoint string) (string, error) { + parsedURL, err := url.Parse(endpoint) + if err != nil { + return "", fmt.Errorf("invalid ai:endpoint URL: %w", err) + } + + query := parsedURL.Query() + if query.Get("alt") != "sse" { + query.Set("alt", "sse") + parsedURL.RawQuery = query.Encode() + return parsedURL.String(), nil + } + + return endpoint, nil +} + +// UpdateToolUseData updates the tool use data for a specific tool call in the chat +func UpdateToolUseData(chatId string, toolCallId string, toolUseData uctypes.UIMessageDataToolUse) error { + chat := chatstore.DefaultChatStore.Get(chatId) + if chat == nil { + return fmt.Errorf("chat not found: %s", chatId) + } + + for _, genMsg := range chat.NativeMessages { + chatMsg, ok := 
genMsg.(*GeminiChatMessage) + if !ok { + continue + } + + for i, part := range chatMsg.Parts { + if part.FunctionCall != nil && part.ToolUseData != nil && part.ToolUseData.ToolCallId == toolCallId { + // Update the message with new tool use data + updatedMsg := &GeminiChatMessage{ + MessageId: chatMsg.MessageId, + Role: chatMsg.Role, + Parts: make([]GeminiMessagePart, len(chatMsg.Parts)), + Usage: chatMsg.Usage, + } + copy(updatedMsg.Parts, chatMsg.Parts) + updatedMsg.Parts[i].ToolUseData = &toolUseData + + aiOpts := &uctypes.AIOptsType{ + APIType: chat.APIType, + Model: chat.Model, + APIVersion: chat.APIVersion, + } + + return chatstore.DefaultChatStore.PostMessage(chatId, aiOpts, updatedMsg) + } + } + } + + return fmt.Errorf("tool call with ID %s not found in chat %s", toolCallId, chatId) +} + +// appendPartToLastUserMessage appends a text part to the last user message in the contents slice +func appendPartToLastUserMessage(contents []GeminiContent, text string) { + for i := len(contents) - 1; i >= 0; i-- { + if contents[i].Role == "user" { + contents[i].Parts = append(contents[i].Parts, GeminiMessagePart{ + Text: text, + }) + break + } + } +} + +// buildGeminiHTTPRequest creates an HTTP request for the Gemini API +func buildGeminiHTTPRequest(ctx context.Context, contents []GeminiContent, chatOpts uctypes.WaveChatOpts) (*http.Request, error) { + opts := chatOpts.Config + + if opts.Model == "" { + return nil, errors.New("ai:model is required") + } + if opts.APIToken == "" { + return nil, errors.New("ai:apitoken is required") + } + if opts.Endpoint == "" { + return nil, errors.New("ai:endpoint is required") + } + + maxTokens := opts.MaxTokens + if maxTokens <= 0 { + maxTokens = GeminiDefaultMaxTokens + } + + // Build request body + reqBody := &GeminiRequest{ + Contents: contents, + GenerationConfig: &GeminiGenerationConfig{ + MaxOutputTokens: int32(maxTokens), + Temperature: 0.7, // Default temperature + }, + } + + // Map thinking level for Gemini 3+ models + if 
opts.ThinkingLevel != "" && strings.Contains(opts.Model, "gemini-3") { + geminiThinkingLevel := "high" + if opts.ThinkingLevel == uctypes.ThinkingLevelLow { + geminiThinkingLevel = "low" + } + reqBody.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ + ThinkingLevel: geminiThinkingLevel, + } + } + + // Add system instruction if provided + if len(chatOpts.SystemPrompt) > 0 { + systemText := strings.Join(chatOpts.SystemPrompt, "\n\n") + reqBody.SystemInstruction = &GeminiContent{ + Parts: []GeminiMessagePart{ + {Text: systemText}, + }, + } + } + + // Add tools if provided + var allTools []uctypes.ToolDefinition + allTools = append(allTools, chatOpts.Tools...) + allTools = append(allTools, chatOpts.TabTools...) + + if len(allTools) > 0 { + var functionDeclarations []GeminiFunctionDeclaration + for _, tool := range allTools { + // Only include tools whose capabilities are met + if !tool.HasRequiredCapabilities(opts.Capabilities) { + continue + } + functionDeclarations = append(functionDeclarations, ConvertToolDefinitionToGemini(tool)) + } + if len(functionDeclarations) > 0 { + reqBody.Tools = []GeminiTool{ + {FunctionDeclarations: functionDeclarations}, + } + reqBody.ToolConfig = &GeminiToolConfig{ + FunctionCallingConfig: &GeminiFunctionCallingConfig{ + Mode: "AUTO", + }, + } + } + } + + // Injected data - append to last user message as separate parts + if chatOpts.TabState != "" { + appendPartToLastUserMessage(reqBody.Contents, chatOpts.TabState) + } + if chatOpts.PlatformInfo != "" { + appendPartToLastUserMessage(reqBody.Contents, "\n"+chatOpts.PlatformInfo+"\n") + } + if chatOpts.AppStaticFiles != "" { + appendPartToLastUserMessage(reqBody.Contents, "\n"+chatOpts.AppStaticFiles+"\n") + } + if chatOpts.AppGoFile != "" { + appendPartToLastUserMessage(reqBody.Contents, "\n"+chatOpts.AppGoFile+"\n") + } + + if wavebase.IsDevMode() { + var toolNames []string + for _, tool := range allTools { + toolNames = append(toolNames, tool.Name) + } + log.Printf("gemini: 
model %s, messages: %d, tools: %s\n", opts.Model, len(contents), strings.Join(toolNames, ",")) + } + + // Encode request body + buf, err := aiutil.JsonEncodeRequestBody(reqBody) + if err != nil { + return nil, err + } + + // Build URL + endpoint, err := ensureAltSse(opts.Endpoint) + if err != nil { + return nil, err + } + + // Create HTTP request + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, &buf) + if err != nil { + return nil, err + } + + // Set headers + req.Header.Set("Content-Type", "application/json") + req.Header.Set("x-goog-api-key", opts.APIToken) + + return req, nil +} + +// RunGeminiChatStep executes a chat step using the Gemini API +func RunGeminiChatStep( + ctx context.Context, + sseHandler *sse.SSEHandlerCh, + chatOpts uctypes.WaveChatOpts, + cont *uctypes.WaveContinueResponse, +) (*uctypes.WaveStopReason, *GeminiChatMessage, *uctypes.RateLimitInfo, error) { + if sseHandler == nil { + return nil, nil, nil, errors.New("sse handler is nil") + } + + // Get chat from store + chat := chatstore.DefaultChatStore.Get(chatOpts.ChatId) + if chat == nil { + return nil, nil, nil, fmt.Errorf("chat not found: %s", chatOpts.ChatId) + } + + // Validate that chatOpts.Config match the chat's stored configuration + if chat.APIType != chatOpts.Config.APIType { + return nil, nil, nil, fmt.Errorf("API type mismatch: chat has %s, chatOpts has %s", chat.APIType, chatOpts.Config.APIType) + } + if chat.Model != chatOpts.Config.Model { + return nil, nil, nil, fmt.Errorf("model mismatch: chat has %s, chatOpts has %s", chat.Model, chatOpts.Config.Model) + } + + // Context with timeout if provided + if chatOpts.Config.TimeoutMs > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, time.Duration(chatOpts.Config.TimeoutMs)*time.Millisecond) + defer cancel() + } + + // Convert GenAIMessages to Gemini contents + var contents []GeminiContent + for _, genMsg := range chat.NativeMessages { + chatMsg, ok := genMsg.(*GeminiChatMessage) + 
if !ok { + return nil, nil, nil, fmt.Errorf("expected GeminiChatMessage, got %T", genMsg) + } + + content := GeminiContent{ + Role: chatMsg.Role, + Parts: make([]GeminiMessagePart, len(chatMsg.Parts)), + } + for i, part := range chatMsg.Parts { + content.Parts[i] = *part.Clean() + } + contents = append(contents, content) + } + + req, err := buildGeminiHTTPRequest(ctx, contents, chatOpts) + if err != nil { + return nil, nil, nil, err + } + + httpClient := &http.Client{ + Timeout: 0, // rely on ctx; streaming can be long + } + + resp, err := httpClient.Do(req) + if err != nil { + return nil, nil, nil, fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + + // Try to parse as Gemini error + var geminiErr GeminiErrorResponse + if err := json.Unmarshal(bodyBytes, &geminiErr); err == nil && geminiErr.Error != nil { + return nil, nil, nil, fmt.Errorf("Gemini API error (%d): %s", geminiErr.Error.Code, geminiErr.Error.Message) + } + + return nil, nil, nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, utilfn.TruncateString(string(bodyBytes), 120)) + } + + // Setup SSE if this is a new request (not a continuation) + if cont == nil { + if err := sseHandler.SetupSSE(); err != nil { + return nil, nil, nil, fmt.Errorf("failed to setup SSE: %w", err) + } + } + + // Stream processing + stopReason, assistantMsg, err := processGeminiStream(ctx, resp.Body, sseHandler, chatOpts, cont) + if err != nil { + return nil, nil, nil, err + } + + return stopReason, assistantMsg, nil, nil +} + +// processGeminiStream handles the streaming response from Gemini +func processGeminiStream( + ctx context.Context, + body io.Reader, + sseHandler *sse.SSEHandlerCh, + chatOpts uctypes.WaveChatOpts, + cont *uctypes.WaveContinueResponse, +) (*uctypes.WaveStopReason, *GeminiChatMessage, error) { + msgID := uuid.New().String() + textID := uuid.New().String() + textStarted := false + var 
textBuilder strings.Builder + var textThoughtSignature string + var finishReason string + var functionCalls []GeminiMessagePart + var usageMetadata *GeminiUsageMetadata + + if cont == nil { + _ = sseHandler.AiMsgStart(msgID) + } + _ = sseHandler.AiMsgStartStep() + + decoder := eventsource.NewDecoder(body) + + for { + if err := ctx.Err(); err != nil { + _ = sseHandler.AiMsgError("request cancelled") + return &uctypes.WaveStopReason{ + Kind: uctypes.StopKindCanceled, + ErrorType: "cancelled", + ErrorText: "request cancelled", + }, nil, err + } + + event, err := decoder.Decode() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + _ = sseHandler.AiMsgError(fmt.Sprintf("stream decode error: %v", err)) + return &uctypes.WaveStopReason{ + Kind: uctypes.StopKindError, + ErrorType: "stream", + ErrorText: err.Error(), + }, nil, fmt.Errorf("stream decode error: %w", err) + } + + data := event.Data() + if data == "" { + continue + } + + // Parse the JSON response + var chunk GeminiStreamResponse + if err := json.Unmarshal([]byte(data), &chunk); err != nil { + log.Printf("gemini: failed to parse chunk: %v\n", err) + continue + } + + // Check for prompt feedback (blocking) + if chunk.PromptFeedback != nil && chunk.PromptFeedback.BlockReason != "" { + errorMsg := fmt.Sprintf("Content blocked: %s", chunk.PromptFeedback.BlockReason) + _ = sseHandler.AiMsgError(errorMsg) + return &uctypes.WaveStopReason{ + Kind: uctypes.StopKindContent, + ErrorType: "blocked", + ErrorText: errorMsg, + }, nil, fmt.Errorf("%s", errorMsg) + } + + // Store usage metadata if present + if chunk.UsageMetadata != nil { + usageMetadata = chunk.UsageMetadata + } + + // Log grounding metadata (web search queries) + if chunk.GroundingMetadata != nil && len(chunk.GroundingMetadata.WebSearchQueries) > 0 { + if wavebase.IsDevMode() { + log.Printf("gemini: web search queries executed: %v\n", chunk.GroundingMetadata.WebSearchQueries) + } + } + + // Process candidates + if len(chunk.Candidates) == 0 { + 
continue + } + + candidate := chunk.Candidates[0] + + // Log candidate grounding metadata if present + if candidate.GroundingMetadata != nil && len(candidate.GroundingMetadata.WebSearchQueries) > 0 { + if wavebase.IsDevMode() { + log.Printf("gemini: candidate web search queries: %v\n", candidate.GroundingMetadata.WebSearchQueries) + } + } + + // Store finish reason + if candidate.FinishReason != "" { + finishReason = candidate.FinishReason + } + + if candidate.Content == nil { + continue + } + + // Process content parts + for _, part := range candidate.Content.Parts { + if part.Text != "" { + if !textStarted { + _ = sseHandler.AiMsgTextStart(textID) + textStarted = true + } + textBuilder.WriteString(part.Text) + _ = sseHandler.AiMsgTextDelta(textID, part.Text) + if part.ThoughtSignature != "" { + textThoughtSignature = part.ThoughtSignature + } + } + + if part.FunctionCall != nil { + toolCallId := uuid.New().String() + + argsBytes, _ := json.Marshal(part.FunctionCall.Args) + aiutil.SendToolProgress(toolCallId, part.FunctionCall.Name, argsBytes, chatOpts, sseHandler, false) + + // Preserve thought_signature exactly as received from API + // It can be at part level, FunctionCall level, or both + functionCalls = append(functionCalls, GeminiMessagePart{ + FunctionCall: part.FunctionCall, + ThoughtSignature: part.ThoughtSignature, + ToolUseData: &uctypes.UIMessageDataToolUse{ + ToolCallId: toolCallId, + ToolName: part.FunctionCall.Name, + }, + }) + } + } + } + + // Determine stop reason + stopKind := uctypes.StopKindDone + switch finishReason { + case "MAX_TOKENS": + stopKind = uctypes.StopKindMaxTokens + case "SAFETY": + stopKind = uctypes.StopKindContent + case "RECITATION": + stopKind = uctypes.StopKindContent + } + + // Build assistant message + var parts []GeminiMessagePart + if textBuilder.Len() > 0 { + parts = append(parts, GeminiMessagePart{ + Text: textBuilder.String(), + ThoughtSignature: textThoughtSignature, + }) + } + parts = append(parts, functionCalls...) 
+ + // Set usage metadata model + if usageMetadata != nil { + usageMetadata.Model = chatOpts.Config.Model + } + + assistantMsg := &GeminiChatMessage{ + MessageId: msgID, + Role: "model", + Parts: parts, + Usage: usageMetadata, + } + + // Build tool calls for stop reason + var waveToolCalls []uctypes.WaveToolCall + if len(functionCalls) > 0 { + stopKind = uctypes.StopKindToolUse + for _, fcPart := range functionCalls { + if fcPart.FunctionCall != nil && fcPart.ToolUseData != nil { + waveToolCalls = append(waveToolCalls, uctypes.WaveToolCall{ + ID: fcPart.ToolUseData.ToolCallId, + Name: fcPart.FunctionCall.Name, + Input: fcPart.FunctionCall.Args, + ToolUseData: fcPart.ToolUseData, + }) + } + } + } + + stopReason := &uctypes.WaveStopReason{ + Kind: stopKind, + RawReason: finishReason, + ToolCalls: waveToolCalls, + } + + if textStarted { + _ = sseHandler.AiMsgTextEnd(textID) + } + _ = sseHandler.AiMsgFinishStep() + if stopKind != uctypes.StopKindToolUse { + _ = sseHandler.AiMsgFinish(finishReason, nil) + } + + return stopReason, assistantMsg, nil +} diff --git a/pkg/aiusechat/gemini/gemini-convertmessage.go b/pkg/aiusechat/gemini/gemini-convertmessage.go new file mode 100644 index 0000000000..dc43da422f --- /dev/null +++ b/pkg/aiusechat/gemini/gemini-convertmessage.go @@ -0,0 +1,418 @@ +// Copyright 2025, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package gemini + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "log" + "strings" + + "github.com/google/uuid" + "github.com/wavetermdev/waveterm/pkg/aiusechat/aiutil" + "github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes" + "github.com/wavetermdev/waveterm/pkg/util/utilfn" +) + +// cleanSchemaForGemini removes fields from JSON Schema that Gemini doesn't accept +// Gemini uses a strict subset of JSON Schema and rejects fields like $schema, units, title, etc. 
+func cleanSchemaForGemini(schema map[string]any) map[string]any { + if schema == nil { + return nil + } + + cleaned := make(map[string]any) + + // Fields that Gemini accepts in the root schema + allowedRootFields := map[string]bool{ + "type": true, + "properties": true, + "required": true, + "description": true, + "items": true, + "enum": true, + "format": true, + "minimum": true, + "maximum": true, + "pattern": true, + "default": true, + } + + for key, value := range schema { + if !allowedRootFields[key] { + // Skip fields like $schema, title, units, definitions, $ref, etc. + continue + } + + // Recursively clean nested schemas + switch key { + case "properties": + if props, ok := value.(map[string]any); ok { + cleanedProps := make(map[string]any) + for propName, propValue := range props { + if propSchema, ok := propValue.(map[string]any); ok { + cleanedProps[propName] = cleanSchemaForGemini(propSchema) + } else { + // Preserve non-map property values + cleanedProps[propName] = propValue + } + } + cleaned[key] = cleanedProps + } + case "items": + if items, ok := value.(map[string]any); ok { + cleaned[key] = cleanSchemaForGemini(items) + } else { + cleaned[key] = value + } + default: + cleaned[key] = value + } + } + + return cleaned +} + +// ConvertToolDefinitionToGemini converts a Wave ToolDefinition to Gemini format +func ConvertToolDefinitionToGemini(tool uctypes.ToolDefinition) GeminiFunctionDeclaration { + // Clean the schema to remove fields that Gemini doesn't accept + cleanedSchema := cleanSchemaForGemini(tool.InputSchema) + + return GeminiFunctionDeclaration{ + Name: tool.Name, + Description: tool.Description, + Parameters: cleanedSchema, + } +} + +// convertFileAIMessagePart converts a file AIMessagePart to Gemini format +func convertFileAIMessagePart(part uctypes.AIMessagePart) (*GeminiMessagePart, error) { + if part.Type != uctypes.AIMessagePartTypeFile { + return nil, fmt.Errorf("convertFileAIMessagePart expects 'file' type, got '%s'", part.Type) + } 
+ if part.MimeType == "" { + return nil, fmt.Errorf("file part missing mimetype") + } + + // Handle different file types + switch { + case strings.HasPrefix(part.MimeType, "image/"): + // For images, we need base64 data + var base64Data string + if len(part.Data) > 0 { + base64Data = base64.StdEncoding.EncodeToString(part.Data) + } else if part.URL != "" { + // If URL is provided, it should be a data URL + if strings.HasPrefix(part.URL, "data:") { + // Extract base64 data from data URL + parts := strings.SplitN(part.URL, ",", 2) + if len(parts) == 2 { + base64Data = parts[1] + } else { + return nil, fmt.Errorf("invalid data URL format") + } + } else { + return nil, fmt.Errorf("dropping image with non-data URL (must be fetched and converted to base64)") + } + } else { + return nil, fmt.Errorf("image file part missing data") + } + + return &GeminiMessagePart{ + InlineData: &GeminiInlineData{ + MimeType: part.MimeType, + Data: base64Data, + }, + FileName: part.FileName, + PreviewUrl: part.PreviewUrl, + }, nil + + case part.MimeType == "application/pdf": + // Handle PDFs - Gemini supports base64 data for PDFs + if len(part.Data) == 0 { + if part.URL != "" { + return nil, fmt.Errorf("dropping PDF with URL (must be fetched and converted to base64 data)") + } + return nil, fmt.Errorf("PDF file part missing data") + } + + // Convert raw data to base64 + base64Data := base64.StdEncoding.EncodeToString(part.Data) + + return &GeminiMessagePart{ + InlineData: &GeminiInlineData{ + MimeType: "application/pdf", + Data: base64Data, + }, + FileName: part.FileName, + PreviewUrl: part.PreviewUrl, + }, nil + + case part.MimeType == "text/plain": + textData, err := aiutil.ExtractTextData(part.Data, part.URL) + if err != nil { + return nil, err + } + formattedText := aiutil.FormatAttachedTextFile(part.FileName, textData) + return &GeminiMessagePart{ + Text: formattedText, + }, nil + + case part.MimeType == "directory": + var jsonContent string + if len(part.Data) > 0 { + jsonContent = 
string(part.Data) + } else { + return nil, fmt.Errorf("directory listing part missing data") + } + + formattedText := aiutil.FormatAttachedDirectoryListing(part.FileName, jsonContent) + return &GeminiMessagePart{ + Text: formattedText, + }, nil + + default: + return nil, fmt.Errorf("dropping file with unsupported mimetype '%s' (Gemini supports images, PDFs, text/plain, and directories)", part.MimeType) + } +} + +// ConvertAIMessageToGeminiChatMessage converts an AIMessage to GeminiChatMessage +// These messages are ALWAYS role "user" +func ConvertAIMessageToGeminiChatMessage(aiMsg uctypes.AIMessage) (*GeminiChatMessage, error) { + if err := aiMsg.Validate(); err != nil { + return nil, fmt.Errorf("invalid AIMessage: %w", err) + } + + var parts []GeminiMessagePart + + for i, part := range aiMsg.Parts { + switch part.Type { + case uctypes.AIMessagePartTypeText: + if part.Text == "" { + return nil, fmt.Errorf("part %d: text type requires non-empty text field", i) + } + parts = append(parts, GeminiMessagePart{ + Text: part.Text, + }) + + case uctypes.AIMessagePartTypeFile: + geminiPart, err := convertFileAIMessagePart(part) + if err != nil { + log.Printf("gemini: %v", err) + continue + } + parts = append(parts, *geminiPart) + + default: + // Drop unknown part types + log.Printf("gemini: dropping unknown part type '%s'", part.Type) + continue + } + } + + return &GeminiChatMessage{ + MessageId: aiMsg.MessageId, + Role: "user", + Parts: parts, + }, nil +} + +// ConvertToolResultsToGeminiChatMessage converts AIToolResult slice to GeminiChatMessage +func ConvertToolResultsToGeminiChatMessage(toolResults []uctypes.AIToolResult) (*GeminiChatMessage, error) { + if len(toolResults) == 0 { + return nil, fmt.Errorf("toolResults cannot be empty") + } + + var parts []GeminiMessagePart + + for _, result := range toolResults { + if result.ToolUseID == "" { + return nil, fmt.Errorf("tool result missing ToolUseID") + } + + response := make(map[string]any) + var nestedParts 
[]GeminiMessagePart + + if result.ErrorText != "" { + response["ok"] = false + response["error"] = result.ErrorText + } else if strings.HasPrefix(result.Text, "data:") { + mimeType, base64Data, err := utilfn.DecodeDataURL(result.Text) + if err != nil { + log.Printf("gemini: failed to decode data URL in tool result: %v\n", err) + response["ok"] = false + response["error"] = fmt.Sprintf("failed to decode data URL: %v", err) + } else if strings.HasPrefix(mimeType, "image/") { + // For image data URLs, use multimodal function response (Gemini 3 Pro+) + displayName := fmt.Sprintf("result_%s.%s", result.ToolUseID[:8], strings.TrimPrefix(mimeType, "image/")) + response["ok"] = true + response["image"] = map[string]string{"$ref": displayName} + + // Add the image data as a nested part + nestedParts = append(nestedParts, GeminiMessagePart{ + InlineData: &GeminiInlineData{ + MimeType: mimeType, + Data: base64.StdEncoding.EncodeToString(base64Data), + DisplayName: displayName, + }, + }) + } else { + log.Printf("gemini: unsupported data URL mimetype in tool result: %s\n", mimeType) + response["ok"] = false + response["error"] = fmt.Sprintf("unsupported data URL mimetype: %s", mimeType) + } + } else { + response["ok"] = true + response["result"] = result.Text + } + + parts = append(parts, GeminiMessagePart{ + FunctionResponse: &GeminiFunctionResponse{ + Name: result.ToolName, + Response: response, + Parts: nestedParts, + }, + }) + } + + return &GeminiChatMessage{ + MessageId: uuid.New().String(), + Role: "user", // Function responses are sent as user messages + Parts: parts, + }, nil +} + +// convertContentPartToUIPart converts a Gemini content part to UIMessagePart +func convertContentPartToUIPart(part GeminiMessagePart, role string) []uctypes.UIMessagePart { + var uiParts []uctypes.UIMessagePart + + if part.Text != "" { + if found, dataPart := aiutil.ConvertDataUserFile(part.Text); found { + if dataPart != nil { + uiParts = append(uiParts, *dataPart) + } + } else { + uiParts 
= append(uiParts, uctypes.UIMessagePart{ + Type: "text", + Text: part.Text, + }) + } + } + + if part.InlineData != nil && role == "user" { + // Show uploaded files in user messages + var mimeType string + if strings.HasPrefix(part.InlineData.MimeType, "image/") { + mimeType = "image/*" + } else { + mimeType = part.InlineData.MimeType + } + + uiParts = append(uiParts, uctypes.UIMessagePart{ + Type: "data-userfile", + Data: uctypes.UIMessageDataUserFile{ + FileName: part.FileName, + MimeType: mimeType, + PreviewUrl: part.PreviewUrl, + }, + }) + } + + // Tool use parts are handled separately by the backend + if part.ToolUseData != nil { + uiParts = append(uiParts, uctypes.UIMessagePart{ + Type: "data-tooluse", + ID: part.ToolUseData.ToolCallId, + Data: *part.ToolUseData, + }) + } + + return uiParts +} + +// convertToUIMessage converts a GeminiChatMessage to a UIMessage +func (m *GeminiChatMessage) convertToUIMessage() *uctypes.UIMessage { + var parts []uctypes.UIMessagePart + + for _, part := range m.Parts { + // Skip function responses - they're not shown in UI + if part.FunctionResponse != nil { + continue + } + + partUIParts := convertContentPartToUIPart(part, m.Role) + parts = append(parts, partUIParts...) 
+ } + + if len(parts) == 0 { + return nil + } + + // Convert Gemini role to standard role + role := m.Role + if role == "model" { + role = "assistant" + } + + return &uctypes.UIMessage{ + ID: m.MessageId, + Role: role, + Parts: parts, + } +} + +// ConvertAIChatToUIChat converts an AIChat to a UIChat for Gemini +func ConvertAIChatToUIChat(aiChat uctypes.AIChat) (*uctypes.UIChat, error) { + if aiChat.APIType != uctypes.APIType_GoogleGemini { + return nil, fmt.Errorf("APIType must be '%s', got '%s'", uctypes.APIType_GoogleGemini, aiChat.APIType) + } + + uiMessages := make([]uctypes.UIMessage, 0, len(aiChat.NativeMessages)) + for i, nativeMsg := range aiChat.NativeMessages { + geminiMsg, ok := nativeMsg.(*GeminiChatMessage) + if !ok { + return nil, fmt.Errorf("message %d: expected *GeminiChatMessage, got %T", i, nativeMsg) + } + uiMsg := geminiMsg.convertToUIMessage() + if uiMsg != nil { + uiMessages = append(uiMessages, *uiMsg) + } + } + + return &uctypes.UIChat{ + ChatId: aiChat.ChatId, + APIType: aiChat.APIType, + Model: aiChat.Model, + APIVersion: aiChat.APIVersion, + Messages: uiMessages, + }, nil +} + +// GetFunctionCallInputByToolCallId returns the function call input associated with the given tool call ID +func GetFunctionCallInputByToolCallId(aiChat uctypes.AIChat, toolCallId string) *uctypes.AIFunctionCallInput { + for _, nativeMsg := range aiChat.NativeMessages { + geminiMsg, ok := nativeMsg.(*GeminiChatMessage) + if !ok { + continue + } + for _, part := range geminiMsg.Parts { + if part.FunctionCall != nil && part.ToolUseData != nil && part.ToolUseData.ToolCallId == toolCallId { + // Convert args map to JSON string + argsBytes, err := json.Marshal(part.FunctionCall.Args) + if err != nil { + log.Printf("gemini: error marshaling function call args: %v", err) + continue + } + return &uctypes.AIFunctionCallInput{ + CallId: toolCallId, + Name: part.FunctionCall.Name, + Arguments: string(argsBytes), + ToolUseData: part.ToolUseData, + } + } + } + } + return nil +} 
diff --git a/pkg/aiusechat/gemini/gemini-types.go b/pkg/aiusechat/gemini/gemini-types.go new file mode 100644 index 0000000000..e873abbb4c --- /dev/null +++ b/pkg/aiusechat/gemini/gemini-types.go @@ -0,0 +1,232 @@ +// Copyright 2025, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package gemini + +import ( + "github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes" +) + +const ( + GeminiDefaultMaxTokens = 8192 +) + +// GeminiChatMessage represents a stored chat message for Gemini backend +type GeminiChatMessage struct { + MessageId string `json:"messageid"` + Role string `json:"role"` // "user", "model" + Parts []GeminiMessagePart `json:"parts"` + Usage *GeminiUsageMetadata `json:"usage,omitempty"` +} + +func (m *GeminiChatMessage) GetMessageId() string { + return m.MessageId +} + +func (m *GeminiChatMessage) GetRole() string { + return m.Role +} + +func (m *GeminiChatMessage) GetUsage() *uctypes.AIUsage { + if m.Usage == nil { + return nil + } + return &uctypes.AIUsage{ + APIType: uctypes.APIType_GoogleGemini, + Model: m.Usage.Model, + InputTokens: m.Usage.PromptTokenCount, + OutputTokens: m.Usage.CandidatesTokenCount, + } +} + +// GeminiMessagePart represents different types of content in a message +type GeminiMessagePart struct { + // Text part + Text string `json:"text,omitempty"` + + // Inline data (images, PDFs, etc.) 
+ InlineData *GeminiInlineData `json:"inlineData,omitempty"` + + // File data (for uploaded files) + FileData *GeminiFileData `json:"fileData,omitempty"` + + // Function call (assistant calling a tool) + FunctionCall *GeminiFunctionCall `json:"functionCall,omitempty"` + + // Function response (result of tool execution) + FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"` + + // Thought signature (for thinking models - applies to text and function calls) + ThoughtSignature string `json:"thoughtSignature,omitempty"` + + // Internal fields (not sent to API) + PreviewUrl string `json:"previewurl,omitempty"` // internal field + FileName string `json:"filename,omitempty"` // internal field + ToolUseData *uctypes.UIMessageDataToolUse `json:"toolusedata,omitempty"` // internal field +} + +// Clean removes internal fields before sending to API +func (p *GeminiMessagePart) Clean() *GeminiMessagePart { + if p == nil { + return nil + } + cleaned := *p + cleaned.PreviewUrl = "" + cleaned.FileName = "" + cleaned.ToolUseData = nil + return &cleaned +} + +// GeminiInlineData represents inline binary data +type GeminiInlineData struct { + MimeType string `json:"mimeType"` + Data string `json:"data"` // base64 encoded + DisplayName string `json:"displayName,omitempty"` // for multimodal function responses +} + +// GeminiFileData represents uploaded file reference +type GeminiFileData struct { + MimeType string `json:"mimeType"` + FileUri string `json:"fileUri"` // gs:// URI from file upload + DisplayName string `json:"displayName,omitempty"` // for multimodal function responses +} + +// GeminiFunctionCall represents a function call from the model +type GeminiFunctionCall struct { + Name string `json:"name"` + Args map[string]any `json:"args,omitempty"` +} + +// GeminiFunctionResponse represents a function execution result +type GeminiFunctionResponse struct { + Name string `json:"name"` + Response map[string]any `json:"response"` + Parts []GeminiMessagePart 
`json:"parts,omitempty"` // nested parts for multimodal content (Gemini 3 Pro and later) +} + +// GeminiUsageMetadata represents token usage +type GeminiUsageMetadata struct { + Model string `json:"model,omitempty"` // internal field + PromptTokenCount int `json:"promptTokenCount"` + CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"` + CandidatesTokenCount int `json:"candidatesTokenCount"` + TotalTokenCount int `json:"totalTokenCount"` +} + +// GeminiThinkingConfig represents thinking configuration for Gemini 3+ models +type GeminiThinkingConfig struct { + ThinkingLevel string `json:"thinkingLevel,omitempty"` // "low" or "high" +} + +// GeminiGenerationConfig represents generation parameters +type GeminiGenerationConfig struct { + Temperature float32 `json:"temperature,omitempty"` + TopP float32 `json:"topP,omitempty"` + TopK int32 `json:"topK,omitempty"` + CandidateCount int32 `json:"candidateCount,omitempty"` + MaxOutputTokens int32 `json:"maxOutputTokens,omitempty"` + StopSequences []string `json:"stopSequences,omitempty"` + ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"` // for Gemini 3+ models +} + +// GeminiTool represents a function tool definition +type GeminiTool struct { + FunctionDeclarations []GeminiFunctionDeclaration `json:"functionDeclarations,omitempty"` + GoogleSearch *GeminiGoogleSearch `json:"googleSearch,omitempty"` +} + +// GeminiGoogleSearch represents Google Search configuration (empty for default) +type GeminiGoogleSearch struct{} + +// GeminiFunctionDeclaration represents a function schema +type GeminiFunctionDeclaration struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters map[string]any `json:"parameters,omitempty"` +} + +// GeminiToolConfig represents tool choice configuration +type GeminiToolConfig struct { + FunctionCallingConfig *GeminiFunctionCallingConfig `json:"functionCallingConfig,omitempty"` +} + +// GeminiFunctionCallingConfig represents function 
calling configuration +type GeminiFunctionCallingConfig struct { + Mode string `json:"mode,omitempty"` // "AUTO", "ANY", "NONE" +} + +// GeminiContent represents a content message for the API +type GeminiContent struct { + Role string `json:"role,omitempty"` + Parts []GeminiMessagePart `json:"parts"` +} + +// Clean removes internal fields from all parts +func (c *GeminiContent) Clean() *GeminiContent { + if c == nil { + return nil + } + cleaned := &GeminiContent{ + Role: c.Role, + Parts: make([]GeminiMessagePart, len(c.Parts)), + } + for i, part := range c.Parts { + cleaned.Parts[i] = *part.Clean() + } + return cleaned +} + +// GeminiRequest represents a request to the Gemini API +type GeminiRequest struct { + Contents []GeminiContent `json:"contents"` + SystemInstruction *GeminiContent `json:"systemInstruction,omitempty"` + GenerationConfig *GeminiGenerationConfig `json:"generationConfig,omitempty"` + Tools []GeminiTool `json:"tools,omitempty"` + ToolConfig *GeminiToolConfig `json:"toolConfig,omitempty"` +} + +// GeminiStreamResponse represents a streaming response chunk +type GeminiStreamResponse struct { + Candidates []GeminiCandidate `json:"candidates,omitempty"` + PromptFeedback *GeminiPromptFeedback `json:"promptFeedback,omitempty"` + UsageMetadata *GeminiUsageMetadata `json:"usageMetadata,omitempty"` + GroundingMetadata *GeminiGroundingMetadata `json:"groundingMetadata,omitempty"` +} + +// GeminiCandidate represents a candidate response +type GeminiCandidate struct { + Content *GeminiContent `json:"content,omitempty"` + FinishReason string `json:"finishReason,omitempty"` + Index int `json:"index,omitempty"` + SafetyRatings []GeminiSafetyRating `json:"safetyRatings,omitempty"` + GroundingMetadata *GeminiGroundingMetadata `json:"groundingMetadata,omitempty"` +} + +// GeminiSafetyRating represents a safety rating +type GeminiSafetyRating struct { + Category string `json:"category"` + Probability string `json:"probability"` +} + +// GeminiPromptFeedback 
represents feedback about the prompt +type GeminiPromptFeedback struct { + BlockReason string `json:"blockReason,omitempty"` + SafetyRatings []GeminiSafetyRating `json:"safetyRatings,omitempty"` +} + +// GeminiErrorResponse represents an error response +type GeminiErrorResponse struct { + Error *GeminiError `json:"error,omitempty"` +} + +// GeminiError represents an error +type GeminiError struct { + Code int `json:"code"` + Message string `json:"message"` + Status string `json:"status,omitempty"` +} + +// GeminiGroundingMetadata represents grounding metadata with web search results +type GeminiGroundingMetadata struct { + WebSearchQueries []string `json:"webSearchQueries,omitempty"` +} diff --git a/pkg/aiusechat/tools.go b/pkg/aiusechat/tools.go index e3f643faa3..550fec95cb 100644 --- a/pkg/aiusechat/tools.go +++ b/pkg/aiusechat/tools.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/google/uuid" + "github.com/wavetermdev/waveterm/pkg/aiusechat/aiutil" "github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes" "github.com/wavetermdev/waveterm/pkg/blockcontroller" "github.com/wavetermdev/waveterm/pkg/util/utilfn" @@ -132,7 +133,7 @@ func MakeBlockShortDesc(block *waveobj.Block) string { } } -func GenerateTabStateAndTools(ctx context.Context, tabid string, widgetAccess bool) (string, []uctypes.ToolDefinition, error) { +func GenerateTabStateAndTools(ctx context.Context, tabid string, widgetAccess bool, chatOpts *uctypes.WaveChatOpts) (string, []uctypes.ToolDefinition, error) { if tabid == "" { return "", nil, nil } @@ -160,7 +161,13 @@ func GenerateTabStateAndTools(ctx context.Context, tabid string, widgetAccess bo // log.Printf("TABPROMPT %s\n", tabState) var tools []uctypes.ToolDefinition if widgetAccess { - tools = append(tools, GetCaptureScreenshotToolDefinition(tabid)) + // Only add screenshot tool for: + // - openai-responses API type + // - google-gemini API type with Gemini 3+ models + if chatOpts.Config.APIType == uctypes.APIType_OpenAIResponses || + 
(chatOpts.Config.APIType == uctypes.APIType_GoogleGemini && aiutil.GeminiSupportsImageToolResults(chatOpts.Config.Model)) { + tools = append(tools, GetCaptureScreenshotToolDefinition(tabid)) + } tools = append(tools, GetReadTextFileToolDefinition()) tools = append(tools, GetReadDirToolDefinition()) tools = append(tools, GetWriteTextFileToolDefinition()) diff --git a/pkg/aiusechat/uctypes/uctypes.go b/pkg/aiusechat/uctypes/uctypes.go index 7f8f859979..7d3ea45c13 100644 --- a/pkg/aiusechat/uctypes/uctypes.go +++ b/pkg/aiusechat/uctypes/uctypes.go @@ -22,6 +22,7 @@ const ( APIType_AnthropicMessages = "anthropic-messages" APIType_OpenAIResponses = "openai-responses" APIType_OpenAIChat = "openai-chat" + APIType_GoogleGemini = "google-gemini" ) const ( diff --git a/pkg/aiusechat/usechat-backend.go b/pkg/aiusechat/usechat-backend.go index 528cd3af5c..6ae1d94663 100644 --- a/pkg/aiusechat/usechat-backend.go +++ b/pkg/aiusechat/usechat-backend.go @@ -8,6 +8,7 @@ import ( "fmt" "github.com/wavetermdev/waveterm/pkg/aiusechat/anthropic" + "github.com/wavetermdev/waveterm/pkg/aiusechat/gemini" "github.com/wavetermdev/waveterm/pkg/aiusechat/openai" "github.com/wavetermdev/waveterm/pkg/aiusechat/openaichat" "github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes" @@ -54,6 +55,7 @@ type UseChatBackend interface { var _ UseChatBackend = (*openaiResponsesBackend)(nil) var _ UseChatBackend = (*openaiCompletionsBackend)(nil) var _ UseChatBackend = (*anthropicBackend)(nil) +var _ UseChatBackend = (*geminiBackend)(nil) // GetBackendByAPIType returns the appropriate UseChatBackend implementation for the given API type func GetBackendByAPIType(apiType string) (UseChatBackend, error) { @@ -64,6 +66,8 @@ func GetBackendByAPIType(apiType string) (UseChatBackend, error) { return &openaiCompletionsBackend{}, nil case uctypes.APIType_AnthropicMessages: return &anthropicBackend{}, nil + case uctypes.APIType_GoogleGemini: + return &geminiBackend{}, nil default: return nil, fmt.Errorf("unsupported 
API type: %s", apiType) } @@ -196,3 +200,43 @@ func (b *anthropicBackend) GetFunctionCallInputByToolCallId(aiChat uctypes.AICha func (b *anthropicBackend) ConvertAIChatToUIChat(aiChat uctypes.AIChat) (*uctypes.UIChat, error) { return anthropic.ConvertAIChatToUIChat(aiChat) } + +// geminiBackend implements UseChatBackend for Google Gemini API +type geminiBackend struct{} + +func (b *geminiBackend) RunChatStep( + ctx context.Context, + sseHandler *sse.SSEHandlerCh, + chatOpts uctypes.WaveChatOpts, + cont *uctypes.WaveContinueResponse, +) (*uctypes.WaveStopReason, []uctypes.GenAIMessage, *uctypes.RateLimitInfo, error) { + stopReason, msg, rateLimitInfo, err := gemini.RunGeminiChatStep(ctx, sseHandler, chatOpts, cont) + if msg == nil { + return stopReason, nil, rateLimitInfo, err + } + return stopReason, []uctypes.GenAIMessage{msg}, rateLimitInfo, err +} + +func (b *geminiBackend) UpdateToolUseData(chatId string, toolCallId string, toolUseData uctypes.UIMessageDataToolUse) error { + return gemini.UpdateToolUseData(chatId, toolCallId, toolUseData) +} + +func (b *geminiBackend) ConvertToolResultsToNativeChatMessage(toolResults []uctypes.AIToolResult) ([]uctypes.GenAIMessage, error) { + msg, err := gemini.ConvertToolResultsToGeminiChatMessage(toolResults) + if err != nil { + return nil, err + } + return []uctypes.GenAIMessage{msg}, nil +} + +func (b *geminiBackend) ConvertAIMessageToNativeChatMessage(message uctypes.AIMessage) (uctypes.GenAIMessage, error) { + return gemini.ConvertAIMessageToGeminiChatMessage(message) +} + +func (b *geminiBackend) GetFunctionCallInputByToolCallId(aiChat uctypes.AIChat, toolCallId string) *uctypes.AIFunctionCallInput { + return gemini.GetFunctionCallInputByToolCallId(aiChat, toolCallId) +} + +func (b *geminiBackend) ConvertAIChatToUIChat(aiChat uctypes.AIChat) (*uctypes.UIChat, error) { + return gemini.ConvertAIChatToUIChat(aiChat) +} diff --git a/pkg/aiusechat/usechat-mode.go b/pkg/aiusechat/usechat-mode.go index b6472003bb..a92c4a3df2 
100644 --- a/pkg/aiusechat/usechat-mode.go +++ b/pkg/aiusechat/usechat-mode.go @@ -61,6 +61,13 @@ func applyProviderDefaults(config *wconfig.AIModeConfigType) { if config.APITokenSecretName == "" { config.APITokenSecretName = "OPENAI_KEY" } + if len(config.Capabilities) == 0 { + if isO1Model(config.Model) { + config.Capabilities = []string{} + } else { + config.Capabilities = []string{uctypes.AICapabilityTools, uctypes.AICapabilityImages, uctypes.AICapabilityPdfs} + } + } } if config.Provider == uctypes.AIProvider_OpenRouter { if config.Endpoint == "" { @@ -108,6 +115,20 @@ func applyProviderDefaults(config *wconfig.AIModeConfigType) { config.APITokenSecretName = "AZURE_OPENAI_KEY" } } + if config.Provider == uctypes.AIProvider_Google { + if config.APIType == "" { + config.APIType = uctypes.APIType_GoogleGemini + } + if config.Endpoint == "" && config.Model != "" { + config.Endpoint = fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:streamGenerateContent", config.Model) + } + if config.APITokenSecretName == "" { + config.APITokenSecretName = "GOOGLE_AI_KEY" + } + if len(config.Capabilities) == 0 { + config.Capabilities = []string{uctypes.AICapabilityTools, uctypes.AICapabilityImages, uctypes.AICapabilityPdfs} + } + } if config.APIType == "" { config.APIType = uctypes.APIType_OpenAIChat } @@ -162,6 +183,19 @@ func isLegacyOpenAIModel(model string) bool { return false } +func isO1Model(model string) bool { + if model == "" { + return false + } + o1Prefixes := []string{"o1", "o1-mini"} + for _, prefix := range o1Prefixes { + if aiutil.CheckModelPrefix(model, prefix) { + return true + } + } + return false +} + func isValidAzureResourceName(name string) bool { if name == "" || len(name) > 63 { return false diff --git a/pkg/aiusechat/usechat-prompts.go b/pkg/aiusechat/usechat-prompts.go index b8bcb7aa03..7aacc8f40a 100644 --- a/pkg/aiusechat/usechat-prompts.go +++ b/pkg/aiusechat/usechat-prompts.go @@ -5,13 +5,6 @@ package aiusechat import 
"strings" -var SystemPromptText = strings.Join([]string{ - `You are Wave AI, an intelligent assistant embedded within Wave Terminal, a modern terminal application with graphical widgets.`, - `You appear as a pull-out panel on the left side of a tab, with the tab's widgets laid out on the right.`, - `Widget context is provided as informational only.`, - `Do NOT assume any API access or ability to interact with the widgets except via tools provided (note that some widgets may expose NO tools, so their context is informational only).`, -}, " ") - var SystemPromptText_OpenAI = strings.Join([]string{ `You are Wave AI, an assistant embedded in Wave Terminal (a terminal with graphical widgets).`, `You appear as a pull-out panel on the left; widgets are on the right.`, diff --git a/pkg/aiusechat/usechat.go b/pkg/aiusechat/usechat.go index ae5e9d9796..0ec2fa7d34 100644 --- a/pkg/aiusechat/usechat.go +++ b/pkg/aiusechat/usechat.go @@ -645,7 +645,7 @@ func WaveAIPostMessageHandler(w http.ResponseWriter, r *http.Request) { if req.TabId != "" { chatOpts.TabStateGenerator = func() (string, []uctypes.ToolDefinition, string, error) { - tabState, tabTools, err := GenerateTabStateAndTools(r.Context(), req.TabId, req.WidgetAccess) + tabState, tabTools, err := GenerateTabStateAndTools(r.Context(), req.TabId, req.WidgetAccess, &chatOpts) return tabState, tabTools, req.TabId, err } } diff --git a/pkg/wconfig/settingsconfig.go b/pkg/wconfig/settingsconfig.go index 611899d08b..1d6adf1eda 100644 --- a/pkg/wconfig/settingsconfig.go +++ b/pkg/wconfig/settingsconfig.go @@ -268,7 +268,7 @@ type AIModeConfigType struct { DisplayIcon string `json:"display:icon,omitempty"` DisplayDescription string `json:"display:description,omitempty"` Provider string `json:"ai:provider,omitempty" jsonschema:"enum=wave,enum=google,enum=openrouter,enum=openai,enum=azure,enum=azure-legacy,enum=custom"` - APIType string `json:"ai:apitype,omitempty" 
jsonschema:"enum=anthropic-messages,enum=openai-responses,enum=openai-chat"` + APIType string `json:"ai:apitype,omitempty" jsonschema:"enum=anthropic-messages,enum=google-gemini,enum=openai-responses,enum=openai-chat"` Model string `json:"ai:model,omitempty"` ThinkingLevel string `json:"ai:thinkinglevel,omitempty" jsonschema:"enum=low,enum=medium,enum=high"` Endpoint string `json:"ai:endpoint,omitempty"`