diff --git a/frontend/app/aipanel/aimessage.tsx b/frontend/app/aipanel/aimessage.tsx index 26386cb090..7ca8c9b091 100644 --- a/frontend/app/aipanel/aimessage.tsx +++ b/frontend/app/aipanel/aimessage.tsx @@ -2,9 +2,10 @@ // SPDX-License-Identifier: Apache-2.0 import { WaveStreamdown } from "@/app/element/streamdown"; +import { RpcApi } from "@/app/store/wshclientapi"; +import { TabRpcClient } from "@/app/store/wshrpcutil"; import { cn } from "@/util/util"; -import { useAtomValue } from "jotai"; -import { memo } from "react"; +import { memo, useEffect, useState } from "react"; import { getFileIcon } from "./ai-utils"; import { WaveUIMessage, WaveUIMessagePart } from "./aitypes"; import { WaveAIModel } from "./waveai-model"; @@ -67,6 +68,84 @@ const UserMessageFiles = memo(({ fileParts }: UserMessageFilesProps) => { UserMessageFiles.displayName = "UserMessageFiles"; +interface AIToolUseProps { + part: WaveUIMessagePart & { type: "data-tooluse" }; + isStreaming: boolean; +} + +const AIToolUse = memo(({ part }: AIToolUseProps) => { + const toolData = part.data; + const [userApprovalOverride, setUserApprovalOverride] = useState(null); + + const statusIcon = toolData.status === "completed" ? "✓" : toolData.status === "error" ? "✗" : "•"; + const statusColor = + toolData.status === "completed" + ? "text-green-400" + : toolData.status === "error" + ? "text-red-400" + : "text-gray-400"; + + const effectiveApproval = userApprovalOverride || toolData.approval; + + useEffect(() => { + if (effectiveApproval !== "needs-approval") return; + + const interval = setInterval(() => { + RpcApi.WaveAIToolApproveCommand(TabRpcClient, { + toolcallid: toolData.toolcallid, + keepalive: true, + }); + }, 4000); + + return () => clearInterval(interval); + }, [effectiveApproval, toolData.toolcallid]); + + const handleApprove = () => { + setUserApprovalOverride("user-approved"); + RpcApi.WaveAIToolApproveCommand(TabRpcClient, { + toolcallid: toolData.toolcallid, + approval: "user-approved", + }); + }; + + const handleDeny = () => { + setUserApprovalOverride("user-denied"); + RpcApi.WaveAIToolApproveCommand(TabRpcClient, { + toolcallid: toolData.toolcallid, + approval: "user-denied", + }); + }; + + return ( +
+ {statusIcon} +
+
{toolData.toolname}
+ {toolData.tooldesc &&
{toolData.tooldesc}
} + {toolData.errormessage &&
{toolData.errormessage}
} + {effectiveApproval === "needs-approval" && ( +
+ + +
+ )} +
+
+ ); +}); + +AIToolUse.displayName = "AIToolUse"; + interface AIMessagePartProps { part: WaveUIMessagePart; role: string; @@ -93,9 +172,8 @@ const AIMessagePart = memo(({ part, role, isStreaming }: AIMessagePartProps) => } } - if (part.type.startsWith("tool-") && "state" in part && part.state === "input-available") { - const toolName = part.type.substring(5); // Remove "tool-" prefix - return
Calling tool {toolName}
; + if (part.type === "data-tooluse" && part.data) { + return ; } return null; @@ -110,7 +188,9 @@ interface AIMessageProps { const isDisplayPart = (part: WaveUIMessagePart): boolean => { return ( - part.type === "text" || (part.type.startsWith("tool-") && "state" in part && part.state === "input-available") + part.type === "text" || + part.type === "data-tooluse" || + (part.type.startsWith("tool-") && "state" in part && part.state === "input-available") ); }; @@ -122,7 +202,10 @@ export const AIMessage = memo(({ message, isStreaming }: AIMessageProps) => { ); const hasContent = displayParts.length > 0 && - displayParts.some((part) => (part.type === "text" && part.text) || part.type.startsWith("tool-")); + displayParts.some( + (part) => + (part.type === "text" && part.text) || part.type.startsWith("tool-") || part.type === "data-tooluse" + ); const showThinkingOnly = !hasContent && isStreaming && message.role === "assistant"; const showThinkingInline = hasContent && isStreaming && message.role === "assistant"; diff --git a/frontend/app/aipanel/aitypes.ts b/frontend/app/aipanel/aitypes.ts index 8df0f2157a..f16c1f6e3c 100644 --- a/frontend/app/aipanel/aitypes.ts +++ b/frontend/app/aipanel/aitypes.ts @@ -10,6 +10,14 @@ type WaveUIDataTypes = { mimetype: string; previewurl?: string; }; + tooluse: { + toolcallid: string; + toolname: string; + tooldesc: string; + status: "pending" | "error" | "completed"; + errormessage?: string; + approval?: "needs-approval" | "user-approved" | "user-denied" | "auto-approved" | "timeout"; + }; }; export type WaveUIMessage = UIMessage; diff --git a/frontend/app/store/wshclientapi.ts b/frontend/app/store/wshclientapi.ts index 1bb2d71bc6..dc81b159ac 100644 --- a/frontend/app/store/wshclientapi.ts +++ b/frontend/app/store/wshclientapi.ts @@ -492,6 +492,11 @@ class RpcApiType { return client.wshRpcCall("waveaienabletelemetry", null, opts); } + // command "waveaitoolapprove" [call] + WaveAIToolApproveCommand(client: WshClient, data: CommandWaveAIToolApproveData, opts?: RpcOpts): Promise { + return client.wshRpcCall("waveaitoolapprove", data, opts); + } + // command "waveinfo" [call] WaveInfoCommand(client: WshClient, opts?: RpcOpts): Promise { return client.wshRpcCall("waveinfo", null, opts); diff --git a/frontend/app/tab/tabbar.tsx b/frontend/app/tab/tabbar.tsx index 481ee6d678..0f52c9dbd0 100644 --- a/frontend/app/tab/tabbar.tsx +++ b/frontend/app/tab/tabbar.tsx @@ -6,7 +6,7 @@ import { modalsModel } from "@/app/store/modalmodel"; import { WorkspaceLayoutModel } from "@/app/workspace/workspace-layout-model"; import { WindowDrag } from "@/element/windowdrag"; import { deleteLayoutModelForTab } from "@/layout/index"; -import { atoms, createTab, getApi, globalStore, isDev, setActiveTab } from "@/store/global"; +import { atoms, createTab, getApi, globalStore, setActiveTab } from "@/store/global"; import { PLATFORM, PlatformMacOS } from "@/util/platformutil"; import { fireAndForget } from "@/util/util"; import { useAtomValue } from "jotai"; @@ -640,7 +640,7 @@ const TabBar = memo(({ workspace }: TabBarProps) => { } const tabsWrapperWidth = tabIds.length * tabWidthRef.current; - const waveaiButton = isDev() ? ( + const waveaiButton = (
{ AI
- ) : undefined; + ); const appMenuButton = PLATFORM !== PlatformMacOS && !settings["window:showmenubar"] ? (
diff --git a/frontend/app/workspace/workspace-layout-model.ts b/frontend/app/workspace/workspace-layout-model.ts index a404202f1f..f2a429efd2 100644 --- a/frontend/app/workspace/workspace-layout-model.ts +++ b/frontend/app/workspace/workspace-layout-model.ts @@ -7,7 +7,7 @@ import * as WOS from "@/app/store/wos"; import { RpcApi } from "@/app/store/wshclientapi"; import { TabRpcClient } from "@/app/store/wshrpcutil"; import { getLayoutModelForStaticTab } from "@/layout/lib/layoutModelHooks"; -import { atoms, getApi, getTabMetaKeyAtom, isDev, recordTEvent, refocusNode } from "@/store/global"; +import { atoms, getApi, getTabMetaKeyAtom, recordTEvent, refocusNode } from "@/store/global"; import debug from "debug"; import * as jotai from "jotai"; import { debounce } from "lodash-es"; @@ -42,7 +42,7 @@ class WorkspaceLayoutModel { this.panelContainerRef = null; this.aiPanelWrapperRef = null; this.inResize = false; - this.aiPanelVisible = isDev(); + this.aiPanelVisible = true; this.aiPanelWidth = null; this.panelVisibleAtom = jotai.atom(this.aiPanelVisible); @@ -219,9 +219,6 @@ class WorkspaceLayoutModel { } setAIPanelVisible(visible: boolean): void { - if (!isDev() && visible) { - return; - } if (this.focusTimeoutRef != null) { clearTimeout(this.focusTimeoutRef); this.focusTimeoutRef = null; @@ -290,9 +287,6 @@ class WorkspaceLayoutModel { } handleAIPanelResize(width: number, windowWidth: number): void { - if (!isDev()) { - return; - } if (!this.getAIPanelVisible()) { return; } diff --git a/frontend/types/gotypes.d.ts b/frontend/types/gotypes.d.ts index d4dbfb6ab8..14ea692ffa 100644 --- a/frontend/types/gotypes.d.ts +++ b/frontend/types/gotypes.d.ts @@ -320,6 +320,13 @@ declare global { waitms: number; }; + // wshrpc.CommandWaveAIToolApproveData + type CommandWaveAIToolApproveData = { + toolcallid: string; + keepalive?: boolean; + approval?: string; + }; + // wshrpc.CommandWebSelectorData type CommandWebSelectorData = { workspaceid: string; @@ -944,6 +951,7 @@ declare global { "waveai:outputtokens"?: number; "waveai:requestcount"?: number; "waveai:toolusecount"?: number; + "waveai:tooluseerrorcount"?: number; "waveai:tooldetail"?: {[key: string]: number}; "waveai:premiumreq"?: number; "waveai:proxyreq"?: number; diff --git a/pkg/aiusechat/openai/openai-backend.go b/pkg/aiusechat/openai/openai-backend.go index 491e5c4621..3df718e267 100644 --- a/pkg/aiusechat/openai/openai-backend.go +++ b/pkg/aiusechat/openai/openai-backend.go @@ -39,11 +39,12 @@ type OpenAIMessage struct { } type OpenAIFunctionCallInput struct { - Type string `json:"type"` // Required: The type of the function tool call. Always function_call - CallId string `json:"call_id"` // Required: The unique ID of the function tool call generated by the model - Name string `json:"name"` // Required: The name of the function to run - Arguments string `json:"arguments"` // Required: A JSON string of the arguments to pass to the function - Status string `json:"status,omitempty"` // Optional: The status of the item. One of in_progress, completed, or incomplete + Type string `json:"type"` // Required: The type of the function tool call. Always function_call + CallId string `json:"call_id"` // Required: The unique ID of the function tool call generated by the model + Name string `json:"name"` // Required: The name of the function to run + Arguments string `json:"arguments"` // Required: A JSON string of the arguments to pass to the function + Status string `json:"status,omitempty"` // Optional: The status of the item. One of in_progress, completed, or incomplete + ToolUseData *uctypes.UIMessageDataToolUse `json:"toolusedata,omitempty"` // Internal field for UI tool use data (must be cleaned before sending to API) // removed the "id" field (optional to send back in inputs) } @@ -93,6 +94,15 @@ func (m *OpenAIMessage) CleanAndCopy() *OpenAIMessage { return rtn } +func (f *OpenAIFunctionCallInput) Clean() *OpenAIFunctionCallInput { + if f.ToolUseData == nil { + return f + } + rtn := *f + rtn.ToolUseData = nil + return &rtn +} + type openAIErrorResponse struct { Error openAIErrorType `json:"error"` } @@ -313,14 +323,47 @@ type openaiBlockState struct { } type openaiStreamingState struct { - blockMap map[string]*openaiBlockState // Use item_id as key for UI streaming + blockMap map[string]*openaiBlockState // Use item_id as key for UI streaming + toolUseData map[string]*uctypes.UIMessageDataToolUse // Use toolCallId as key msgID string model string stepStarted bool + chatOpts uctypes.WaveChatOpts } // ---------- Public entrypoint ---------- +func UpdateToolUseData(chatId string, callId string, newToolUseData *uctypes.UIMessageDataToolUse) error { + chat := chatstore.DefaultChatStore.Get(chatId) + if chat == nil { + return fmt.Errorf("chat not found: %s", chatId) + } + + for _, genMsg := range chat.NativeMessages { + chatMsg, ok := genMsg.(*OpenAIChatMessage) + if !ok { + continue + } + + if chatMsg.FunctionCall != nil && chatMsg.FunctionCall.CallId == callId { + updatedMsg := *chatMsg + updatedFunctionCall := *chatMsg.FunctionCall + updatedFunctionCall.ToolUseData = newToolUseData + updatedMsg.FunctionCall = &updatedFunctionCall + + aiOpts := &uctypes.AIOptsType{ + APIType: chat.APIType, + Model: chat.Model, + APIVersion: chat.APIVersion, + } + + return chatstore.DefaultChatStore.PostMessage(chatId, aiOpts, &updatedMsg) + } + } + + return fmt.Errorf("function call with callId %s not found in chat %s", callId, chatId) +} + func RunOpenAIChatStep( ctx context.Context, sse *sse.SSEHandlerCh, @@ -377,7 +420,8 @@ func RunOpenAIChatStep( cleanedMsg := chatMsg.Message.CleanAndCopy() inputs = append(inputs, *cleanedMsg) } else if chatMsg.FunctionCall != nil { - inputs = append(inputs, *chatMsg.FunctionCall) + cleanedFunctionCall := chatMsg.FunctionCall.Clean() + inputs = append(inputs, *cleanedFunctionCall) } else if chatMsg.FunctionCallOutput != nil { inputs = append(inputs, *chatMsg.FunctionCallOutput) } @@ -444,7 +488,7 @@ func RunOpenAIChatStep( // Use eventsource decoder for proper SSE parsing decoder := eventsource.NewDecoder(resp.Body) - stopReason, rtnMessages := handleOpenAIStreamingResp(ctx, sse, decoder, cont) + stopReason, rtnMessages := handleOpenAIStreamingResp(ctx, sse, decoder, cont, chatOpts) return stopReason, rtnMessages, rateLimitInfo, nil } @@ -473,10 +517,12 @@ func parseOpenAIHTTPError(resp *http.Response) error { } // handleOpenAIStreamingResp handles the OpenAI SSE streaming response -func handleOpenAIStreamingResp(ctx context.Context, sse *sse.SSEHandlerCh, decoder *eventsource.Decoder, cont *uctypes.WaveContinueResponse) (*uctypes.WaveStopReason, []*OpenAIChatMessage) { +func handleOpenAIStreamingResp(ctx context.Context, sse *sse.SSEHandlerCh, decoder *eventsource.Decoder, cont *uctypes.WaveContinueResponse, chatOpts uctypes.WaveChatOpts) (*uctypes.WaveStopReason, []*OpenAIChatMessage) { // Per-response state state := &openaiStreamingState{ - blockMap: map[string]*openaiBlockState{}, + blockMap: map[string]*openaiBlockState{}, + toolUseData: map[string]*uctypes.UIMessageDataToolUse{}, + chatOpts: chatOpts, } var rtnStopReason *uctypes.WaveStopReason @@ -602,7 +648,8 @@ func handleOpenAIEvent( toolCallID: ev.Item.CallId, toolName: ev.Item.Name, } - _ = sse.AiMsgToolInputStart(ev.Item.CallId, ev.Item.Name) + // no longer send tool inputs to FE + // _ = sse.AiMsgToolInputStart(ev.Item.CallId, ev.Item.Name) } return nil, nil @@ -712,7 +759,7 @@ func handleOpenAIEvent( } // Extract partial message if available - finalMessages, _ := extractMessageAndToolsFromResponse(ev.Response) + finalMessages, _ := extractMessageAndToolsFromResponse(ev.Response, state.toolUseData) _ = sse.AiMsgError(errorMsg) return &uctypes.WaveStopReason{ @@ -725,7 +772,7 @@ func handleOpenAIEvent( } // Extract the final message and tool calls from the response output - finalMessages, toolCalls := extractMessageAndToolsFromResponse(ev.Response) + finalMessages, toolCalls := extractMessageAndToolsFromResponse(ev.Response, state.toolUseData) stopKind := uctypes.StopKindDone if len(toolCalls) > 0 { @@ -758,8 +805,18 @@ func handleOpenAIEvent( // Get the function call info from the block state if st := state.blockMap[ev.ItemId]; st != nil && st.kind == openaiBlockToolUse { - raw := json.RawMessage(ev.Arguments) - _ = sse.AiMsgToolInputAvailable(st.toolCallID, st.toolName, raw) + // raw := json.RawMessage(ev.Arguments) + // no longer send tool inputs to fe + // _ = sse.AiMsgToolInputAvailable(st.toolCallID, st.toolName, raw) + + toolDef := state.chatOpts.GetToolDefinition(st.toolName) + toolUseData := createToolUseData(st.toolCallID, st.toolName, toolDef, ev.Arguments) + state.toolUseData[st.toolCallID] = toolUseData + if toolUseData.Approval == uctypes.ApprovalNeedsApproval && state.chatOpts.RegisterToolApproval != nil { + state.chatOpts.RegisterToolApproval(st.toolCallID) + } + log.Printf("AI data-tooluse %s\n", st.toolCallID) + _ = sse.AiMsgData("data-tooluse", st.toolCallID, *toolUseData) } return nil, nil @@ -769,9 +826,40 @@ func handleOpenAIEvent( return nil, nil } } +func createToolUseData(toolCallID, toolName string, toolDef *uctypes.ToolDefinition, arguments string) *uctypes.UIMessageDataToolUse { + toolUseData := &uctypes.UIMessageDataToolUse{ + ToolCallId: toolCallID, + ToolName: toolName, + Status: uctypes.ToolUseStatusPending, + } + + if toolDef == nil { + toolUseData.Status = uctypes.ToolUseStatusError + toolUseData.ErrorMessage = "tool not found" + return toolUseData + } + + var parsedArgs any + if err := json.Unmarshal([]byte(arguments), &parsedArgs); err != nil { + toolUseData.Status = uctypes.ToolUseStatusError + toolUseData.ErrorMessage = fmt.Sprintf("failed to parse tool arguments: %v", err) + return toolUseData + } + + if toolDef.ToolInputDesc != nil { + toolUseData.ToolDesc = toolDef.ToolInputDesc(parsedArgs) + } + + if toolDef.ToolApproval != nil { + toolUseData.Approval = toolDef.ToolApproval(parsedArgs) + } + + return toolUseData +} + // extractMessageAndToolsFromResponse extracts the final OpenAI message and tool calls from the completed response -func extractMessageAndToolsFromResponse(resp openaiResponse) ([]*OpenAIChatMessage, []uctypes.WaveToolCall) { +func extractMessageAndToolsFromResponse(resp openaiResponse, toolUseData map[string]*uctypes.UIMessageDataToolUse) ([]*OpenAIChatMessage, []uctypes.WaveToolCall) { var messageContent []OpenAIMessageContent var toolCalls []uctypes.WaveToolCall var messages []*OpenAIChatMessage @@ -792,7 +880,7 @@ func extractMessageAndToolsFromResponse(resp openaiResponse) ([]*OpenAIChatMessa case "function_call": // Extract tool call information toolCall := uctypes.WaveToolCall{ - ID: outputItem.Id, + ID: outputItem.CallId, Name: outputItem.Name, } @@ -804,6 +892,13 @@ func extractMessageAndToolsFromResponse(resp openaiResponse) ([]*OpenAIChatMessa } } + // Attach UIToolUseData if available + if data, ok := toolUseData[outputItem.CallId]; ok { + toolCall.ToolUseData = data + } else { + log.Printf("AI no data-tooluse for %s (callid: %s)\n", outputItem.Id, outputItem.CallId) + } + toolCalls = append(toolCalls, toolCall) // Create separate FunctionCall message @@ -811,13 +906,18 @@ func extractMessageAndToolsFromResponse(resp openaiResponse) ([]*OpenAIChatMessa if outputItem.Arguments != "" { argsStr = outputItem.Arguments } + var toolUseDataPtr *uctypes.UIMessageDataToolUse + if data, ok := toolUseData[outputItem.CallId]; ok { + toolUseDataPtr = data + } functionCallMsg := &OpenAIChatMessage{ MessageId: uuid.New().String(), FunctionCall: &OpenAIFunctionCallInput{ - Type: "function_call", - CallId: outputItem.Id, - Name: outputItem.Name, - Arguments: argsStr, + Type: "function_call", + CallId: outputItem.CallId, + Name: outputItem.Name, + Arguments: argsStr, + ToolUseData: toolUseDataPtr, }, } messages = append(messages, functionCallMsg) @@ -844,4 +944,3 @@ func extractMessageAndToolsFromResponse(resp openaiResponse) ([]*OpenAIChatMessa return allMessages, toolCalls } - diff --git a/pkg/aiusechat/openai/openai-convertmessage.go b/pkg/aiusechat/openai/openai-convertmessage.go index 351c28f185..98a60ca26f 100644 --- a/pkg/aiusechat/openai/openai-convertmessage.go +++ b/pkg/aiusechat/openai/openai-convertmessage.go @@ -463,23 +463,11 @@ func (m *OpenAIChatMessage) ConvertToUIMessage() *uctypes.UIMessage { } else if m.FunctionCall != nil { // Handle function call input role = "assistant" - if m.FunctionCall.Name != "" && m.FunctionCall.CallId != "" { - // Parse arguments JSON string to interface{} - var args interface{} - if m.FunctionCall.Arguments != "" { - if err := json.Unmarshal([]byte(m.FunctionCall.Arguments), &args); err != nil { - log.Printf("openai: failed to parse function call arguments: %v", err) - args = map[string]interface{}{} - } - } else { - args = map[string]interface{}{} - } - + if m.FunctionCall.ToolUseData != nil { parts = append(parts, uctypes.UIMessagePart{ - Type: "tool-" + m.FunctionCall.Name, - State: "input-available", - ToolCallID: m.FunctionCall.CallId, - Input: args, + Type: "data-tooluse", + ID: m.FunctionCall.CallId, + Data: *m.FunctionCall.ToolUseData, }) } } else if m.FunctionCallOutput != nil { diff --git a/pkg/aiusechat/toolapproval.go b/pkg/aiusechat/toolapproval.go new file mode 100644 index 0000000000..7c374a15b6 --- /dev/null +++ b/pkg/aiusechat/toolapproval.go @@ -0,0 +1,116 @@ +// Copyright 2025, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package aiusechat + +import ( + "sync" + "time" + + "github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes" +) + +const ( + InitialApprovalTimeout = 10 * time.Second + KeepAliveExtension = 10 * time.Second +) + +type ApprovalRequest struct { + approval string + done bool + doneChan chan struct{} + timer *time.Timer + mu sync.Mutex +} + +type ApprovalRegistry struct { + mu sync.Mutex + requests map[string]*ApprovalRequest +} + +var globalApprovalRegistry = &ApprovalRegistry{ + requests: make(map[string]*ApprovalRequest), +} + +func registerToolApprovalRequest(toolCallId string, req *ApprovalRequest) { + globalApprovalRegistry.mu.Lock() + defer globalApprovalRegistry.mu.Unlock() + globalApprovalRegistry.requests[toolCallId] = req +} + +func getToolApprovalRequest(toolCallId string) (*ApprovalRequest, bool) { + globalApprovalRegistry.mu.Lock() + defer globalApprovalRegistry.mu.Unlock() + req, exists := globalApprovalRegistry.requests[toolCallId] + return req, exists +} + +func RegisterToolApproval(toolCallId string) { + req := &ApprovalRequest{ + doneChan: make(chan struct{}), + } + + req.timer = time.AfterFunc(InitialApprovalTimeout, func() { + UpdateToolApproval(toolCallId, uctypes.ApprovalTimeout, false) + }) + + registerToolApprovalRequest(toolCallId, req) +} + +func UpdateToolApproval(toolCallId string, approval string, keepAlive bool) error { + req, exists := getToolApprovalRequest(toolCallId) + if !exists { + return nil + } + + req.mu.Lock() + defer req.mu.Unlock() + + if req.done { + return nil + } + + if keepAlive && approval == "" { + req.timer.Reset(KeepAliveExtension) + return nil + } + + req.approval = approval + req.done = true + + if req.timer != nil { + req.timer.Stop() + } + + close(req.doneChan) + return nil +} +func CurrentToolApprovalStatus(toolCallId string) string { + req, exists := getToolApprovalRequest(toolCallId) + if !exists { + return "" + } + + req.mu.Lock() + defer req.mu.Unlock() + return req.approval +} + +func WaitForToolApproval(toolCallId string) string { + req, exists := getToolApprovalRequest(toolCallId) + if !exists { + return "" + } + + <-req.doneChan + + req.mu.Lock() + approval := req.approval + req.mu.Unlock() + + globalApprovalRegistry.mu.Lock() + delete(globalApprovalRegistry.requests, toolCallId) + globalApprovalRegistry.mu.Unlock() + + return approval +} diff --git a/pkg/aiusechat/tools.go b/pkg/aiusechat/tools.go index c9ff05f0d4..02105a9155 100644 --- a/pkg/aiusechat/tools.go +++ b/pkg/aiusechat/tools.go @@ -126,6 +126,7 @@ func GenerateTabStateAndTools(ctx context.Context, tabid string, widgetAccess bo var tools []uctypes.ToolDefinition if widgetAccess { tools = append(tools, GetCaptureScreenshotToolDefinition(tabid)) + tools = append(tools, GetReadTextFileToolDefinition()) viewTypes := make(map[string]bool) for _, block := range blocks { if block.Meta == nil { diff --git a/pkg/aiusechat/tools_readfile.go b/pkg/aiusechat/tools_readfile.go index 5ac62f40e7..9f5e38b047 100644 --- a/pkg/aiusechat/tools_readfile.go +++ b/pkg/aiusechat/tools_readfile.go @@ -12,18 +12,71 @@ import ( "github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes" "github.com/wavetermdev/waveterm/pkg/util/readutil" "github.com/wavetermdev/waveterm/pkg/util/utilfn" + "github.com/wavetermdev/waveterm/pkg/wavebase" ) +const ReadFileDefaultLineCount = 100 +const ReadFileDefaultMaxBytes = 50 * 1024 const StopReasonMaxBytes = "max_bytes" type readTextFileParams struct { Filename string `json:"filename"` - Origin *string `json:"origin"` // "start" or "end", defaults to "start" - Offset *int `json:"offset"` // lines to skip, defaults to 0 - Count *int `json:"count"` // number of lines to read, defaults to DefaultLineCount + Origin *string `json:"origin"` // "start" or "end", defaults to "start" + Offset *int `json:"offset"` // lines to skip, defaults to 0 + Count *int `json:"count"` // number of lines to read, defaults to DefaultLineCount MaxBytes *int `json:"max_bytes"` } +func parseReadTextFileInput(input any) (*readTextFileParams, error) { + result := &readTextFileParams{} + + if input == nil { + return nil, fmt.Errorf("input is required") + } + + if err := utilfn.ReUnmarshal(result, input); err != nil { + return nil, fmt.Errorf("invalid input format: %w", err) + } + + if result.Filename == "" { + return nil, fmt.Errorf("missing filename parameter") + } + + if result.Origin == nil { + origin := "start" + result.Origin = &origin + } + + if *result.Origin != "start" && *result.Origin != "end" { + return nil, fmt.Errorf("invalid origin value '%s': must be 'start' or 'end'", *result.Origin) + } + + if result.Offset == nil { + offset := 0 + result.Offset = &offset + } + + if *result.Offset < 0 { + return nil, fmt.Errorf("offset must be non-negative, got %d", *result.Offset) + } + + if result.Count == nil { + count := ReadFileDefaultLineCount + result.Count = &count + } + + if *result.Count < 1 { + return nil, fmt.Errorf("count must be at least 1, got %d", *result.Count) + } + + if result.MaxBytes == nil { + maxBytes := ReadFileDefaultMaxBytes + result.MaxBytes = &maxBytes + } + + return result, nil +} + // truncateData truncates data to maxBytes while respecting line boundaries. // For origin "start", keeps the beginning and truncates at last newline before maxBytes. // For origin "end", keeps the end and truncates from beginning at first newline after removing excess. @@ -49,34 +102,32 @@ func truncateData(data string, origin string, maxBytes int) string { } func readTextFileCallback(input any) (any, error) { - const DefaultLineCount = 100 - const DefaultMaxBytes = 50 * 1024 const ReadLimit = 1024 * 1024 * 1024 - var params readTextFileParams - if err := utilfn.ReUnmarshal(¶ms, input); err != nil { - return nil, fmt.Errorf("invalid input format: %w", err) + params, err := parseReadTextFileInput(input) + if err != nil { + return nil, err } - if params.Filename == "" { - return nil, fmt.Errorf("missing filename parameter") + expandedPath, err := wavebase.ExpandHomeDir(params.Filename) + if err != nil { + return nil, fmt.Errorf("failed to expand path: %w", err) } - maxBytes := DefaultMaxBytes - if params.MaxBytes != nil { - maxBytes = *params.MaxBytes + fileInfo, err := os.Stat(expandedPath) + if err != nil { + return nil, fmt.Errorf("failed to stat file: %w", err) } - file, err := os.Open(params.Filename) - if err != nil { - return nil, fmt.Errorf("failed to open file: %w", err) + if fileInfo.IsDir() { + return nil, fmt.Errorf("path is a directory, cannot be read with the read_text_file tool. use the read_dir tool if available to read directories") } - defer file.Close() - fileInfo, err := file.Stat() + file, err := os.Open(expandedPath) if err != nil { - return nil, fmt.Errorf("failed to stat file: %w", err) + return nil, fmt.Errorf("failed to open file: %w", err) } + defer file.Close() totalSize := fileInfo.Size() modTime := fileInfo.ModTime() @@ -92,31 +143,10 @@ func readTextFileCallback(input any) (any, error) { return nil, fmt.Errorf("file appears to be binary content") } - origin := "start" - if params.Origin != nil { - origin = *params.Origin - } - - if origin != "start" && origin != "end" { - return nil, fmt.Errorf("invalid origin value '%s': must be 'start' or 'end'", origin) - } - - offset := 0 - if params.Offset != nil { - offset = *params.Offset - } - - count := DefaultLineCount - if params.Count != nil { - count = *params.Count - if count < 1 { - return nil, fmt.Errorf("count must be at least 1, got %d", count) - } - } - - if offset < 0 { - offset = 0 - } + origin := *params.Origin + offset := *params.Offset + count := *params.Count + maxBytes := *params.MaxBytes var lines []string var stopReason string @@ -150,6 +180,7 @@ func readTextFileCallback(input any) (any, error) { "data": data, "modified": utilfn.FormatRelativeTime(modTime), "modified_time": modTime.UTC().Format("2006-01-02 15:04:05 UTC"), + "mode": fileInfo.Mode().String(), } if stopReason != "" { result["truncated"] = stopReason @@ -162,7 +193,7 @@ func GetReadTextFileToolDefinition() uctypes.ToolDefinition { return uctypes.ToolDefinition{ Name: "read_text_file", DisplayName: "Read Text File", - Description: "Read a text file from the filesystem. Can read specific line ranges or from the end. Detects and rejects binary files.", + Description: "Read a text file from the filesystem. Can read specific line ranges or from the end. Detects and rejects binary files. Requires user approval.", ToolLogName: "gen:readfile", Strict: false, InputSchema: map[string]any{ @@ -200,6 +231,30 @@ func GetReadTextFileToolDefinition() uctypes.ToolDefinition { "required": []string{"filename"}, "additionalProperties": false, }, + ToolInputDesc: func(input any) string { + parsed, err := parseReadTextFileInput(input) + if err != nil { + return fmt.Sprintf("error parsing input: %v", err) + } + + origin := *parsed.Origin + offset := *parsed.Offset + count := *parsed.Count + + if origin == "start" && offset == 0 { + return fmt.Sprintf("reading %q (first %d lines)", parsed.Filename, count) + } + if origin == "end" && offset == 0 { + return fmt.Sprintf("reading %q (last %d lines)", parsed.Filename, count) + } + if origin == "end" { + return fmt.Sprintf("reading %q (from end: offset %d lines, count %d lines)", parsed.Filename, offset, count) + } + return fmt.Sprintf("reading %q (from start: offset %d lines, count %d lines)", parsed.Filename, offset, count) + }, ToolAnyCallback: readTextFileCallback, + ToolApproval: func(input any) string { + return uctypes.ApprovalNeedsApproval + }, } } diff --git a/pkg/aiusechat/tools_screenshot.go b/pkg/aiusechat/tools_screenshot.go index 5d89b25469..ad366cd20e 100644 --- a/pkg/aiusechat/tools_screenshot.go +++ b/pkg/aiusechat/tools_screenshot.go @@ -73,6 +73,17 @@ func GetCaptureScreenshotToolDefinition(tabId string) uctypes.ToolDefinition { "required": []string{"widget_id"}, "additionalProperties": false, }, + ToolInputDesc: func(input any) string { + inputMap, ok := input.(map[string]any) + if !ok { + return "error parsing input: invalid format" + } + widgetId, ok := inputMap["widget_id"].(string) + if !ok { + return "error parsing input: missing widget_id" + } + return fmt.Sprintf("capturing screenshot of widget %s", widgetId) + }, ToolTextCallback: makeTabCaptureBlockScreenshot(tabId), } } diff --git a/pkg/aiusechat/uctypes/usechat-types.go b/pkg/aiusechat/uctypes/usechat-types.go index bb5a077837..ff6ec62ebd 100644 --- a/pkg/aiusechat/uctypes/usechat-types.go +++ b/pkg/aiusechat/uctypes/usechat-types.go @@ -86,6 +86,7 @@ type ToolDefinition struct { ToolTextCallback func(any) (string, error) `json:"-"` ToolAnyCallback func(any) (any, error) `json:"-"` ToolInputDesc func(any) string `json:"-"` + ToolApproval func(any) string `json:"-"` } func (td *ToolDefinition) Clean() *ToolDefinition { @@ -98,6 +99,16 @@ func (td *ToolDefinition) Clean() *ToolDefinition { return &rtn } +func (td *ToolDefinition) Desc() string { + if td == nil { + return "" + } + if td.ShortDescription != "" { + return td.ShortDescription + } + return td.Description +} + //------------------ // Wave specific types, stop reasons, tool calls, config // these are used internally to coordinate the calls/steps @@ -108,6 +119,33 @@ const ( ThinkingLevelHigh = "high" ) +const ( + ToolUseStatusPending = "pending" + ToolUseStatusError = "error" + ToolUseStatusCompleted = "completed" +) + +const ( + ApprovalNeedsApproval = "needs-approval" + ApprovalUserApproved = "user-approved" + ApprovalUserDenied = "user-denied" + ApprovalTimeout = "timeout" + ApprovalAutoApproved = "auto-approved" +) + +type UIMessageDataToolUse struct { + ToolCallId string `json:"toolcallid"` + ToolName string `json:"toolname"` + ToolDesc string `json:"tooldesc"` + Status string `json:"status"` + ErrorMessage string `json:"errormessage,omitempty"` + Approval string `json:"approval,omitempty"` +} + +func (d *UIMessageDataToolUse) IsApproved() bool { + return d.Approval == "" || d.Approval == ApprovalUserApproved || d.Approval == ApprovalAutoApproved +} + type StopReasonKind string const ( @@ -123,9 +161,10 @@ const ( ) type WaveToolCall struct { - ID string `json:"id"` // Anthropic tool_use.id - Name string `json:"name,omitempty"` // tool name (if provided) - Input any `json:"input,omitempty"` // accumulated input JSON + ID string `json:"id"` // Anthropic tool_use.id + Name string `json:"name,omitempty"` // tool name (if provided) + Input any `json:"input,omitempty"` // accumulated input JSON + ToolUseData *UIMessageDataToolUse `json:"toolusedata,omitempty"` // UI tool use data } type WaveStopReason struct { @@ -193,6 +232,7 @@ type AIMetrics struct { Usage AIUsage `json:"usage"` RequestCount int `json:"requestcount"` ToolUseCount int `json:"toolusecount"` + ToolUseErrorCount int `json:"tooluseerrorcount"` ToolDetail map[string]int `json:"tooldetail,omitempty"` PremiumReqCount int `json:"premiumreqcount"` ProxyReqCount int `json:"proxyreqcount"` @@ -201,8 +241,8 @@ type AIMetrics struct { PDFCount int `json:"pdfcount"` TextDocCount int `json:"textdoccount"` TextLen int `json:"textlen"` - FirstByteLatency int `json:"firstbytelatency"` // ms - RequestDuration int `json:"requestduration"` // ms + FirstByteLatency int `json:"firstbytelatency"` // ms + RequestDuration int `json:"requestduration"` // ms WidgetAccess bool `json:"widgetaccess"` } @@ -376,19 +416,34 @@ func (m *UIMessage) GetContent() string { } type WaveChatOpts struct { - ChatId string - ClientId string - Config AIOptsType - Tools []ToolDefinition - SystemPrompt []string - TabStateGenerator func() (string, []ToolDefinition, error) - WidgetAccess bool + ChatId string + ClientId string + Config AIOptsType + Tools []ToolDefinition + SystemPrompt []string + TabStateGenerator func() (string, []ToolDefinition, error) + WidgetAccess bool + RegisterToolApproval func(string) // emphemeral to the step TabState string TabTools []ToolDefinition } +func (opts *WaveChatOpts) GetToolDefinition(toolName string) *ToolDefinition { + for _, tool := range opts.Tools { + if tool.Name == toolName { + return &tool + } + } + for _, tool := range opts.TabTools { + if tool.Name == toolName { + return &tool + } + } + return nil +} + type ProxyErrorResponse struct { Success bool `json:"success"` Error string `json:"error"` diff --git a/pkg/aiusechat/usechat.go b/pkg/aiusechat/usechat.go index 18f73c89d1..993c8a34cc 100644 --- a/pkg/aiusechat/usechat.go +++ b/pkg/aiusechat/usechat.go @@ -21,6 +21,7 @@ import ( "github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes" "github.com/wavetermdev/waveterm/pkg/telemetry" "github.com/wavetermdev/waveterm/pkg/telemetry/telemetrydata" + "github.com/wavetermdev/waveterm/pkg/util/ds" "github.com/wavetermdev/waveterm/pkg/util/utilfn" "github.com/wavetermdev/waveterm/pkg/waveobj" "github.com/wavetermdev/waveterm/pkg/web/sse" @@ -40,6 +41,9 @@ const DefaultMaxTokens = 4 * 1024 var ( globalRateLimitInfo = &uctypes.RateLimitInfo{Unknown: true} rateLimitLock sync.Mutex + + activeToolMap = ds.MakeSyncMap[bool]() // key is toolcallid + activeChats = ds.MakeSyncMap[bool]() // key is chatid ) var SystemPromptText = strings.Join([]string{ @@ -200,31 +204,121 @@ func GetChatUsage(chat *uctypes.AIChat) uctypes.AIUsage { return usage } -func processToolResults(stopReason *uctypes.WaveStopReason, chatOpts uctypes.WaveChatOpts, sseHandler *sse.SSEHandlerCh, metrics *uctypes.AIMetrics) { - var toolResults []uctypes.AIToolResult - for _, toolCall := range stopReason.ToolCalls { - inputJSON, _ := json.Marshal(toolCall.Input) - log.Printf("TOOLUSE name=%s id=%s input=%s\n", toolCall.Name, toolCall.ID, utilfn.TruncateString(string(inputJSON), 40)) - result := ResolveToolCall(toolCall, chatOpts) - toolResults = append(toolResults, result) +func updateToolUseDataInChat(chatOpts uctypes.WaveChatOpts, toolCallID string, toolUseData *uctypes.UIMessageDataToolUse) { + if chatOpts.Config.APIType == APIType_OpenAI { + if err := openai.UpdateToolUseData(chatOpts.ChatId, toolCallID, toolUseData); err != nil { + log.Printf("failed to update tool use data in chat: %v\n", err) + } + } else if chatOpts.Config.APIType == APIType_Anthropic { + log.Printf("warning: UpdateToolUseData not implemented for anthropic\n") + } +} + +func processToolCall(toolCall uctypes.WaveToolCall, chatOpts uctypes.WaveChatOpts, sseHandler *sse.SSEHandlerCh, metrics *uctypes.AIMetrics) uctypes.AIToolResult { - // Track tool usage by ToolLogName - toolDef := getToolDefinition(toolCall.Name, chatOpts) - if toolDef != nil && toolDef.ToolLogName != "" { - metrics.ToolDetail[toolDef.ToolLogName]++ + if toolCall.ToolUseData == nil { + errorMsg := "Invalid Tool Call" + log.Printf(" error=%s\n", errorMsg) + metrics.ToolUseErrorCount++ + return uctypes.AIToolResult{ + ToolName: toolCall.Name, + ToolUseID: toolCall.ID, + ErrorText: errorMsg, } - if result.ErrorText != "" { - log.Printf(" error=%s\n", result.ErrorText) - } else { - log.Printf(" result=%s\n", utilfn.TruncateString(result.Text, 40)) + } + + inputJSON, _ := json.Marshal(toolCall.Input) + log.Printf("TOOLUSE name=%s id=%s input=%s approval=%q\n", toolCall.Name, toolCall.ID, utilfn.TruncateString(string(inputJSON), 40), toolCall.ToolUseData.Approval) + + if toolCall.ToolUseData.Status == uctypes.ToolUseStatusError { + errorMsg := toolCall.ToolUseData.ErrorMessage + if errorMsg == "" { + errorMsg = "Unspecified Tool Error" + } + log.Printf(" error=%s\n", errorMsg) + metrics.ToolUseErrorCount++ + _ = sseHandler.AiMsgData("data-tooluse", toolCall.ID, *toolCall.ToolUseData) + updateToolUseDataInChat(chatOpts, toolCall.ID, toolCall.ToolUseData) + + return uctypes.AIToolResult{ + ToolName: toolCall.Name, + ToolUseID: toolCall.ID, + ErrorText: errorMsg, + } + } + + if toolCall.ToolUseData.Approval == uctypes.ApprovalNeedsApproval { + log.Printf(" waiting for approval...\n") + approval := WaitForToolApproval(toolCall.ID) + log.Printf(" approval result: %q\n", approval) + if approval != "" { + toolCall.ToolUseData.Approval = approval + } + + if !toolCall.ToolUseData.IsApproved() { + errorMsg := "Tool use denied or timed out" + if approval == uctypes.ApprovalUserDenied { + errorMsg = "Tool use denied by user" + } else if approval == uctypes.ApprovalTimeout { + errorMsg = "Tool approval timed out" + } + log.Printf(" error=%s\n", errorMsg) + metrics.ToolUseErrorCount++ + toolCall.ToolUseData.Status = uctypes.ToolUseStatusError + toolCall.ToolUseData.ErrorMessage = errorMsg + _ = sseHandler.AiMsgData("data-tooluse", toolCall.ID, *toolCall.ToolUseData) + updateToolUseDataInChat(chatOpts, toolCall.ID, toolCall.ToolUseData) + + return uctypes.AIToolResult{ + ToolName: toolCall.Name, + ToolUseID: toolCall.ID, + ErrorText: errorMsg, + } } + + _ = sseHandler.AiMsgData("data-tooluse", toolCall.ID, *toolCall.ToolUseData) + updateToolUseDataInChat(chatOpts, toolCall.ID, toolCall.ToolUseData) + } + + result := ResolveToolCall(toolCall, chatOpts) + + // Track tool usage by ToolLogName + toolDef := chatOpts.GetToolDefinition(toolCall.Name) + if toolDef != nil && toolDef.ToolLogName != "" { + metrics.ToolDetail[toolDef.ToolLogName]++ + } + + if result.ErrorText != "" { + toolCall.ToolUseData.Status = uctypes.ToolUseStatusError + toolCall.ToolUseData.ErrorMessage = result.ErrorText + log.Printf(" error=%s\n", result.ErrorText) + metrics.ToolUseErrorCount++ + } else { + toolCall.ToolUseData.Status = uctypes.ToolUseStatusCompleted + log.Printf(" result=%s\n", utilfn.TruncateString(result.Text, 40)) + } + _ = sseHandler.AiMsgData("data-tooluse", toolCall.ID, *toolCall.ToolUseData) + updateToolUseDataInChat(chatOpts, toolCall.ID, toolCall.ToolUseData) + + return result +} + +func processToolCalls(stopReason *uctypes.WaveStopReason, chatOpts uctypes.WaveChatOpts, sseHandler *sse.SSEHandlerCh, metrics *uctypes.AIMetrics) { + for _, toolCall := range stopReason.ToolCalls { + activeToolMap.Set(toolCall.ID, true) + defer activeToolMap.Delete(toolCall.ID) + } + + var toolResults []uctypes.AIToolResult + for _, toolCall := range stopReason.ToolCalls { + result := processToolCall(toolCall, chatOpts, sseHandler, metrics) + toolResults = append(toolResults, result) } if chatOpts.Config.APIType == APIType_OpenAI { toolResultMsgs, err := openai.ConvertToolResultsToOpenAIChatMessage(toolResults) if err != nil { - _ = sseHandler.AiMsgError(fmt.Sprintf("Failed to convert tool results to OpenAI messages: %v", err)) - _ = sseHandler.AiMsgFinish("", nil) + log.Printf("Failed to convert tool results to OpenAI messages: %v", err) } else { for _, msg := range toolResultMsgs { chatstore.DefaultChatStore.PostMessage(chatOpts.ChatId, &chatOpts.Config, msg) @@ -233,8 +327,7 @@ func processToolResults(stopReason *uctypes.WaveStopReason, chatOpts uctypes.Wav } else { toolResultMsg, err := anthropic.ConvertToolResultsToAnthropicChatMessage(toolResults) if err != nil { - _ = sseHandler.AiMsgError(fmt.Sprintf("Failed to convert tool results to Anthropic message: %v", err)) - _ = sseHandler.AiMsgFinish("", nil) + log.Printf("Failed to convert tool results to Anthropic message: %v", err) } else { chatstore.DefaultChatStore.PostMessage(chatOpts.ChatId, &chatOpts.Config, toolResultMsg) } @@ -243,6 +336,11 @@ func processToolResults(stopReason *uctypes.WaveStopReason, chatOpts uctypes.Wav func RunAIChat(ctx context.Context, sseHandler *sse.SSEHandlerCh, chatOpts uctypes.WaveChatOpts) (*uctypes.AIMetrics, error) { log.Printf("RunAIChat\n") + if !activeChats.SetUnless(chatOpts.ChatId, true) { + return nil, fmt.Errorf("chat %s is already running", chatOpts.ChatId) + } + defer activeChats.Delete(chatOpts.ChatId) + metrics := &uctypes.AIMetrics{ Usage: uctypes.AIUsage{ APIType: chatOpts.Config.APIType, @@ -306,7 +404,7 @@ func RunAIChat(ctx context.Context, sseHandler *sse.SSEHandlerCh, chatOpts uctyp } if stopReason != nil && stopReason.Kind == uctypes.StopKindToolUse { metrics.ToolUseCount += len(stopReason.ToolCalls) - processToolResults(stopReason, chatOpts, sseHandler, metrics) + processToolCalls(stopReason, chatOpts, sseHandler, metrics) var messageID string if len(rtnMessage) > 0 && rtnMessage[0] != nil { @@ -325,21 +423,6 @@ func RunAIChat(ctx context.Context, sseHandler *sse.SSEHandlerCh, chatOpts uctyp return metrics, nil } -// ResolveToolCall resolves a single tool call and returns an AIToolResult -func getToolDefinition(toolName string, chatOpts uctypes.WaveChatOpts) *uctypes.ToolDefinition { - for _, tool := range chatOpts.Tools { - if tool.Name == toolName { - return &tool - } - } - for _, tool := range chatOpts.TabTools { - if tool.Name == toolName { - return &tool - } - } - return nil -} - func ResolveToolCall(toolCall uctypes.WaveToolCall, chatOpts uctypes.WaveChatOpts) (result uctypes.AIToolResult) { result = uctypes.AIToolResult{ ToolName: toolCall.Name, @@ -353,7 +436,7 @@ func ResolveToolCall(toolCall uctypes.WaveToolCall, chatOpts uctypes.WaveChatOpt } }() - toolDef := getToolDefinition(toolCall.Name, chatOpts) + toolDef := chatOpts.GetToolDefinition(toolCall.Name) if toolDef == nil { result.ErrorText = fmt.Sprintf("tool '%s' not found", toolCall.Name) @@ -443,23 +526,24 @@ func WaveAIPostMessageWrap(ctx context.Context, sseHandler *sse.SSEHandlerCh, me func sendAIMetricsTelemetry(ctx context.Context, metrics *uctypes.AIMetrics) { event := telemetrydata.MakeTEvent("waveai:post", telemetrydata.TEventProps{ - WaveAIAPIType: metrics.Usage.APIType, - WaveAIModel: metrics.Usage.Model, - WaveAIInputTokens: metrics.Usage.InputTokens, - WaveAIOutputTokens: metrics.Usage.OutputTokens, - WaveAIRequestCount: metrics.RequestCount, - WaveAIToolUseCount: metrics.ToolUseCount, - WaveAIToolDetail: metrics.ToolDetail, - WaveAIPremiumReq: metrics.PremiumReqCount, - WaveAIProxyReq: metrics.ProxyReqCount, - WaveAIHadError: metrics.HadError, - WaveAIImageCount: metrics.ImageCount, - WaveAIPDFCount: metrics.PDFCount, - WaveAITextDocCount: metrics.TextDocCount, - WaveAITextLen: metrics.TextLen, - WaveAIFirstByteMs: metrics.FirstByteLatency, - WaveAIRequestDurMs: metrics.RequestDuration, - WaveAIWidgetAccess: metrics.WidgetAccess, + WaveAIAPIType: metrics.Usage.APIType, + WaveAIModel: metrics.Usage.Model, + WaveAIInputTokens: metrics.Usage.InputTokens, + WaveAIOutputTokens: metrics.Usage.OutputTokens, + WaveAIRequestCount: metrics.RequestCount, + WaveAIToolUseCount: metrics.ToolUseCount, + WaveAIToolUseErrorCount: metrics.ToolUseErrorCount, + WaveAIToolDetail: metrics.ToolDetail, + WaveAIPremiumReq: metrics.PremiumReqCount, + WaveAIProxyReq: metrics.ProxyReqCount, + WaveAIHadError: metrics.HadError, + WaveAIImageCount: metrics.ImageCount, + WaveAIPDFCount: metrics.PDFCount, + WaveAITextDocCount: metrics.TextDocCount, + WaveAITextLen: metrics.TextLen, + WaveAIFirstByteMs: metrics.FirstByteLatency, + WaveAIRequestDurMs: metrics.RequestDuration, + WaveAIWidgetAccess: metrics.WidgetAccess, }) _ = telemetry.RecordTEvent(ctx, event) } @@ -513,10 +597,11 @@ func WaveAIPostMessageHandler(w http.ResponseWriter, r *http.Request) { // Call the core WaveAIPostMessage function chatOpts := uctypes.WaveChatOpts{ - ChatId: req.ChatID, - ClientId: client.OID, - Config: *aiOpts, - WidgetAccess: req.WidgetAccess, + ChatId: req.ChatID, + ClientId: client.OID, + Config: *aiOpts, + WidgetAccess: req.WidgetAccess, + RegisterToolApproval: RegisterToolApproval, } if chatOpts.Config.APIType == APIType_OpenAI { chatOpts.SystemPrompt = []string{SystemPromptText_OpenAI} diff --git a/pkg/telemetry/telemetrydata/telemetrydata.go b/pkg/telemetry/telemetrydata/telemetrydata.go index 95e10d3c47..b65a416779 100644 --- a/pkg/telemetry/telemetrydata/telemetrydata.go +++ b/pkg/telemetry/telemetrydata/telemetrydata.go @@ -69,11 +69,11 @@ type TEventUserProps struct { type TEventProps struct { TEventUserProps `tstype:"-"` // generally don't need to set these since they will be automatically copied over - ActiveMinutes int `json:"activity:activeminutes,omitempty"` - FgMinutes int `json:"activity:fgminutes,omitempty"` - OpenMinutes int `json:"activity:openminutes,omitempty"` - WaveAIActiveMinutes int `json:"activity:waveaiactiveminutes,omitempty"` - WaveAIFgMinutes int `json:"activity:waveaifgminutes,omitempty"` + ActiveMinutes int `json:"activity:activeminutes,omitempty"` + FgMinutes int `json:"activity:fgminutes,omitempty"` + OpenMinutes int `json:"activity:openminutes,omitempty"` + WaveAIActiveMinutes int `json:"activity:waveaiactiveminutes,omitempty"` + WaveAIFgMinutes int `json:"activity:waveaifgminutes,omitempty"` AppFirstDay bool `json:"app:firstday,omitempty"` AppFirstLaunch bool `json:"app:firstlaunch,omitempty"` @@ -101,23 +101,24 @@ type TEventProps struct { CountWSLConn int `json:"count:wslconn,omitempty"` CountViews map[string]int `json:"count:views,omitempty"` - WaveAIAPIType string `json:"waveai:apitype,omitempty"` - WaveAIModel string `json:"waveai:model,omitempty"` - WaveAIInputTokens int `json:"waveai:inputtokens,omitempty"` - WaveAIOutputTokens int `json:"waveai:outputtokens,omitempty"` - WaveAIRequestCount int `json:"waveai:requestcount,omitempty"` - WaveAIToolUseCount int `json:"waveai:toolusecount,omitempty"` - WaveAIToolDetail map[string]int `json:"waveai:tooldetail,omitempty"` - WaveAIPremiumReq int `json:"waveai:premiumreq,omitempty"` - WaveAIProxyReq int `json:"waveai:proxyreq,omitempty"` - WaveAIHadError bool `json:"waveai:haderror,omitempty"` - WaveAIImageCount int `json:"waveai:imagecount,omitempty"` - WaveAIPDFCount int `json:"waveai:pdfcount,omitempty"` - WaveAITextDocCount int `json:"waveai:textdoccount,omitempty"` - WaveAITextLen int `json:"waveai:textlen,omitempty"` - WaveAIFirstByteMs int `json:"waveai:firstbytems,omitempty"` // ms - WaveAIRequestDurMs int `json:"waveai:requestdurms,omitempty"` // ms - WaveAIWidgetAccess bool `json:"waveai:widgetaccess,omitempty"` + WaveAIAPIType string `json:"waveai:apitype,omitempty"` + WaveAIModel string `json:"waveai:model,omitempty"` + WaveAIInputTokens int `json:"waveai:inputtokens,omitempty"` + WaveAIOutputTokens int `json:"waveai:outputtokens,omitempty"` + WaveAIRequestCount int `json:"waveai:requestcount,omitempty"` + WaveAIToolUseCount int `json:"waveai:toolusecount,omitempty"` + WaveAIToolUseErrorCount int `json:"waveai:tooluseerrorcount,omitempty"` + WaveAIToolDetail map[string]int `json:"waveai:tooldetail,omitempty"` + WaveAIPremiumReq int `json:"waveai:premiumreq,omitempty"` + WaveAIProxyReq int `json:"waveai:proxyreq,omitempty"` + WaveAIHadError bool `json:"waveai:haderror,omitempty"` + WaveAIImageCount int `json:"waveai:imagecount,omitempty"` + WaveAIPDFCount int `json:"waveai:pdfcount,omitempty"` + WaveAITextDocCount int `json:"waveai:textdoccount,omitempty"` + WaveAITextLen int `json:"waveai:textlen,omitempty"` + WaveAIFirstByteMs int `json:"waveai:firstbytems,omitempty"` // ms + WaveAIRequestDurMs int `json:"waveai:requestdurms,omitempty"` // ms + WaveAIWidgetAccess bool `json:"waveai:widgetaccess,omitempty"` UserSet *TEventUserProps `json:"$set,omitempty"` UserSetOnce *TEventUserProps `json:"$set_once,omitempty"` diff --git a/pkg/util/ds/syncmap.go b/pkg/util/ds/syncmap.go index daf3969b09..a422343ac5 100644 --- a/pkg/util/ds/syncmap.go +++ b/pkg/util/ds/syncmap.go @@ -41,3 +41,24 @@ func (sm *SyncMap[T]) Delete(key string) { defer sm.lock.Unlock() delete(sm.m, key) } + +func (sm *SyncMap[T]) SetUnless(key string, value T) bool { + sm.lock.Lock() + defer sm.lock.Unlock() + if _, exists := sm.m[key]; exists { + return false + } + sm.m[key] = value + return true +} + +func (sm *SyncMap[T]) TestAndSet(key string, newValue T, testFn func(T, bool) bool) bool { + sm.lock.Lock() + defer sm.lock.Unlock() + currentValue, exists := sm.m[key] + if testFn(currentValue, exists) { + sm.m[key] = newValue + return true + } + return false +} diff --git a/pkg/web/sse/ssehandler.go b/pkg/web/sse/ssehandler.go index 5742b7a642..70c4706d8e 100644 --- a/pkg/web/sse/ssehandler.go +++ b/pkg/web/sse/ssehandler.go @@ -8,6 +8,7 @@ import ( "encoding/json" "fmt" "net/http" + "strings" "sync" "time" ) @@ -459,3 +460,16 @@ func (h *SSEHandlerCh) AiMsgError(errText string) error { } return h.WriteJsonData(resp) } + + +func (h *SSEHandlerCh) AiMsgData(dataType string, id string, data interface{}) error { + if !strings.HasPrefix(dataType, "data-") { + panic(fmt.Sprintf("AiMsgData type must start with 'data-', got: %s", dataType)) + } + resp := map[string]interface{}{ + "type": dataType, + "id": id, + "data": data, + } + return h.WriteJsonData(resp) +} diff --git a/pkg/wshrpc/wshclient/wshclient.go b/pkg/wshrpc/wshclient/wshclient.go index f4aad41c5b..02d0c7242b 100644 --- a/pkg/wshrpc/wshclient/wshclient.go +++ b/pkg/wshrpc/wshclient/wshclient.go @@ -587,6 +587,12 @@ func WaveAIEnableTelemetryCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) error return err } +// command "waveaitoolapprove", wshserver.WaveAIToolApproveCommand +func WaveAIToolApproveCommand(w *wshutil.WshRpc, data wshrpc.CommandWaveAIToolApproveData, opts *wshrpc.RpcOpts) error { + _, err := sendRpcRequestCallHelper[any](w, "waveaitoolapprove", data, opts) + return err +} + // command "waveinfo", wshserver.WaveInfoCommand func WaveInfoCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) (*wshrpc.WaveInfoData, error) { resp, err := sendRpcRequestCallHelper[*wshrpc.WaveInfoData](w, "waveinfo", nil, opts) diff --git a/pkg/wshrpc/wshrpctypes.go b/pkg/wshrpc/wshrpctypes.go index 13e937b1bd..754b0fa4cf 100644 --- a/pkg/wshrpc/wshrpctypes.go +++ b/pkg/wshrpc/wshrpctypes.go @@ -142,6 +142,7 @@ const ( Command_WaveAIEnableTelemetry = "waveaienabletelemetry" Command_GetWaveAIChat = "getwaveaichat" Command_GetWaveAIRateLimit = "getwaveairatelimit" + Command_WaveAIToolApprove = "waveaitoolapprove" Command_CaptureBlockScreenshot = "captureblockscreenshot" @@ -271,6 +272,7 @@ type WshRpcInterface interface { WaveAIEnableTelemetryCommand(ctx context.Context) error GetWaveAIChatCommand(ctx context.Context, data CommandGetWaveAIChatData) (*uctypes.UIChat, error) GetWaveAIRateLimitCommand(ctx context.Context) (*uctypes.RateLimitInfo, error) + WaveAIToolApproveCommand(ctx context.Context, data CommandWaveAIToolApproveData) error // screenshot CaptureBlockScreenshotCommand(ctx context.Context, data CommandCaptureBlockScreenshotData) (string, error) @@ -726,6 +728,12 @@ type CommandGetWaveAIChatData struct { ChatId string `json:"chatid"` } +type CommandWaveAIToolApproveData struct { + ToolCallId string `json:"toolcallid"` + KeepAlive bool `json:"keepalive,omitempty"` + Approval string `json:"approval,omitempty"` +} + type CommandCaptureBlockScreenshotData struct { BlockId string `json:"blockid" wshcontext:"BlockId"` } diff --git a/pkg/wshrpc/wshserver/wshserver.go b/pkg/wshrpc/wshserver/wshserver.go index 554fffab4e..63a75a723a 100644 --- a/pkg/wshrpc/wshserver/wshserver.go +++ b/pkg/wshrpc/wshserver/wshserver.go @@ -960,6 +960,10 @@ func (ws *WshServer) GetWaveAIRateLimitCommand(ctx context.Context) (*uctypes.Ra return aiusechat.GetGlobalRateLimit(), nil } +func (ws *WshServer) WaveAIToolApproveCommand(ctx context.Context, data wshrpc.CommandWaveAIToolApproveData) error { + return aiusechat.UpdateToolApproval(data.ToolCallId, data.Approval, data.KeepAlive) +} + var wshActivityRe = regexp.MustCompile(`^[a-z:#]+$`) func (ws *WshServer) WshActivityCommand(ctx context.Context, data map[string]int) error {