Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 43 additions & 25 deletions pkg/aiusechat/openai/openai-backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,11 @@ func (m *OpenAIChatMessage) GetUsage() *uctypes.AIUsage {
return nil
}
return &uctypes.AIUsage{
APIType: "openai",
Model: m.Usage.Model,
InputTokens: m.Usage.InputTokens,
OutputTokens: m.Usage.OutputTokens,
APIType: "openai",
Model: m.Usage.Model,
InputTokens: m.Usage.InputTokens,
OutputTokens: m.Usage.OutputTokens,
NativeWebSearchCount: m.Usage.NativeWebSearchCount,
}
}

Expand Down Expand Up @@ -281,12 +282,13 @@ type openaiTextFormat struct {
}

type OpenAIUsage struct {
InputTokens int `json:"input_tokens,omitempty"`
OutputTokens int `json:"output_tokens,omitempty"`
TotalTokens int `json:"total_tokens,omitempty"`
InputTokensDetails *openaiInputTokensDetails `json:"input_tokens_details,omitempty"`
OutputTokensDetails *openaiOutputTokensDetails `json:"output_tokens_details,omitempty"`
Model string `json:"model,omitempty"` // internal field (not from OpenAI API)
InputTokens int `json:"input_tokens,omitempty"`
OutputTokens int `json:"output_tokens,omitempty"`
TotalTokens int `json:"total_tokens,omitempty"`
InputTokensDetails *openaiInputTokensDetails `json:"input_tokens_details,omitempty"`
OutputTokensDetails *openaiOutputTokensDetails `json:"output_tokens_details,omitempty"`
Model string `json:"model,omitempty"` // internal field (not from OpenAI API)
NativeWebSearchCount int `json:"nativewebsearchcount,omitempty"` // internal field (not from OpenAI API)
}

type openaiInputTokensDetails struct {
Expand Down Expand Up @@ -323,12 +325,13 @@ type openaiBlockState struct {
}

type openaiStreamingState struct {
blockMap map[string]*openaiBlockState // Use item_id as key for UI streaming
toolUseData map[string]*uctypes.UIMessageDataToolUse // Use toolCallId as key
msgID string
model string
stepStarted bool
chatOpts uctypes.WaveChatOpts
blockMap map[string]*openaiBlockState // Use item_id as key for UI streaming
toolUseData map[string]*uctypes.UIMessageDataToolUse // Use toolCallId as key
msgID string
model string
stepStarted bool
chatOpts uctypes.WaveChatOpts
webSearchCount int
}

// ---------- Public entrypoint ----------
Expand Down Expand Up @@ -759,7 +762,7 @@ func handleOpenAIEvent(
}

// Extract partial message if available
finalMessages, _ := extractMessageAndToolsFromResponse(ev.Response, state.toolUseData)
finalMessages, _ := extractMessageAndToolsFromResponse(ev.Response, state)

_ = sse.AiMsgError(errorMsg)
return &uctypes.WaveStopReason{
Expand All @@ -772,7 +775,7 @@ func handleOpenAIEvent(
}

// Extract the final message and tool calls from the response output
finalMessages, toolCalls := extractMessageAndToolsFromResponse(ev.Response, state.toolUseData)
finalMessages, toolCalls := extractMessageAndToolsFromResponse(ev.Response, state)

stopKind := uctypes.StopKindDone
if len(toolCalls) > 0 {
Expand Down Expand Up @@ -820,6 +823,19 @@ func handleOpenAIEvent(
}
return nil, nil

case "response.web_search_call.in_progress":
return nil, nil

case "response.web_search_call.searching":
return nil, nil

case "response.web_search_call.completed":
state.webSearchCount++
return nil, nil

case "response.output_text.annotation.added":
return nil, nil

default:
// log unknown events for debugging
log.Printf("OpenAI: unknown event: %s, data: %s", eventName, data)
Expand Down Expand Up @@ -857,9 +873,8 @@ func createToolUseData(toolCallID, toolName string, toolDef *uctypes.ToolDefinit
return toolUseData
}


// extractMessageAndToolsFromResponse extracts the final OpenAI message and tool calls from the completed response
func extractMessageAndToolsFromResponse(resp openaiResponse, toolUseData map[string]*uctypes.UIMessageDataToolUse) ([]*OpenAIChatMessage, []uctypes.WaveToolCall) {
func extractMessageAndToolsFromResponse(resp openaiResponse, state *openaiStreamingState) ([]*OpenAIChatMessage, []uctypes.WaveToolCall) {
var messageContent []OpenAIMessageContent
var toolCalls []uctypes.WaveToolCall
var messages []*OpenAIChatMessage
Expand Down Expand Up @@ -893,7 +908,7 @@ func extractMessageAndToolsFromResponse(resp openaiResponse, toolUseData map[str
}

// Attach UIToolUseData if available
if data, ok := toolUseData[outputItem.CallId]; ok {
if data, ok := state.toolUseData[outputItem.CallId]; ok {
toolCall.ToolUseData = data
} else {
log.Printf("AI no data-tooluse for %s (callid: %s)\n", outputItem.Id, outputItem.CallId)
Expand All @@ -907,7 +922,7 @@ func extractMessageAndToolsFromResponse(resp openaiResponse, toolUseData map[str
argsStr = outputItem.Arguments
}
var toolUseDataPtr *uctypes.UIMessageDataToolUse
if data, ok := toolUseData[outputItem.CallId]; ok {
if data, ok := state.toolUseData[outputItem.CallId]; ok {
toolUseDataPtr = data
}
functionCallMsg := &OpenAIChatMessage{
Expand All @@ -925,17 +940,20 @@ func extractMessageAndToolsFromResponse(resp openaiResponse, toolUseData map[str
}

// Create OpenAIChatMessage with assistant message (first in slice)
if resp.Usage != nil {
usage := resp.Usage
if usage != nil {
resp.Usage.Model = resp.Model
if state.webSearchCount > 0 {
usage.NativeWebSearchCount = state.webSearchCount
}
}

assistantMessage := &OpenAIChatMessage{
MessageId: uuid.New().String(),
Message: &OpenAIMessage{
Role: "assistant",
Content: messageContent,
},
Usage: resp.Usage,
Usage: usage,
}

// Return assistant message first, followed by function call messages
Expand Down
20 changes: 14 additions & 6 deletions pkg/aiusechat/openai/openai-convertmessage.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,11 @@ type OpenAIRequest struct {
}

type OpenAIRequestTool struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Parameters any `json:"parameters"`
Strict bool `json:"strict"`
Type string `json:"type"`
Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"`
Parameters any `json:"parameters,omitempty"`
Strict bool `json:"strict,omitempty"`
}

// ConvertToolDefinitionToOpenAI converts a generic ToolDefinition to OpenAI format
Expand Down Expand Up @@ -113,13 +113,13 @@ func debugPrintReq(req *OpenAIRequest, endpoint string) {
// buildOpenAIHTTPRequest creates a complete HTTP request for the OpenAI API
func buildOpenAIHTTPRequest(ctx context.Context, inputs []any, chatOpts uctypes.WaveChatOpts, cont *uctypes.WaveContinueResponse) (*http.Request, error) {
opts := chatOpts.Config

// If continuing from premium rate limit, downgrade to default model and low thinking
if cont != nil && cont.ContinueFromKind == uctypes.StopKindPremiumRateLimit {
opts.Model = uctypes.DefaultOpenAIModel
opts.ThinkingLevel = uctypes.ThinkingLevelLow
}

if opts.Model == "" {
return nil, errors.New("opts.model is required")
}
Expand Down Expand Up @@ -183,6 +183,14 @@ func buildOpenAIHTTPRequest(ctx context.Context, inputs []any, chatOpts uctypes.
reqBody.Tools = append(reqBody.Tools, convertedTool)
}

// Add native web search tool if enabled
if chatOpts.AllowNativeWebSearch {
webSearchTool := OpenAIRequestTool{
Type: "web_search",
}
reqBody.Tools = append(reqBody.Tools, webSearchTool)
}

// Set reasoning based on thinking level
if opts.ThinkingLevel != "" {
reqBody.Reasoning = &ReasoningType{
Expand Down
10 changes: 6 additions & 4 deletions pkg/aiusechat/uctypes/usechat-types.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,11 @@ type AIChat struct {
}

type AIUsage struct {
APIType string `json:"apitype"`
Model string `json:"model"`
InputTokens int `json:"inputtokens,omitempty"`
OutputTokens int `json:"outputtokens,omitempty"`
APIType string `json:"apitype"`
Model string `json:"model"`
InputTokens int `json:"inputtokens,omitempty"`
OutputTokens int `json:"outputtokens,omitempty"`
NativeWebSearchCount int `json:"nativewebsearchcount,omitempty"`
}

type AIMetrics struct {
Expand Down Expand Up @@ -424,6 +425,7 @@ type WaveChatOpts struct {
TabStateGenerator func() (string, []ToolDefinition, error)
WidgetAccess bool
RegisterToolApproval func(string)
AllowNativeWebSearch bool

// emphemeral to the step
TabState string
Expand Down
42 changes: 23 additions & 19 deletions pkg/aiusechat/usechat.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ func getUsage(msgs []uctypes.GenAIMessage) uctypes.AIUsage {
} else {
rtn.InputTokens += usage.InputTokens
rtn.OutputTokens += usage.OutputTokens
rtn.NativeWebSearchCount += usage.NativeWebSearchCount
}
}
}
Expand Down Expand Up @@ -369,9 +370,10 @@ func RunAIChat(ctx context.Context, sseHandler *sse.SSEHandlerCh, chatOpts uctyp
}
if len(rtnMessage) > 0 {
usage := getUsage(rtnMessage)
log.Printf("usage: input=%d output=%d\n", usage.InputTokens, usage.OutputTokens)
log.Printf("usage: input=%d output=%d websearch=%d\n", usage.InputTokens, usage.OutputTokens, usage.NativeWebSearchCount)
metrics.Usage.InputTokens += usage.InputTokens
metrics.Usage.OutputTokens += usage.OutputTokens
metrics.Usage.NativeWebSearchCount += usage.NativeWebSearchCount
if usage.Model != "" && metrics.Usage.Model != usage.Model {
metrics.Usage.Model = "mixed"
}
Expand Down Expand Up @@ -526,24 +528,25 @@ func WaveAIPostMessageWrap(ctx context.Context, sseHandler *sse.SSEHandlerCh, me

func sendAIMetricsTelemetry(ctx context.Context, metrics *uctypes.AIMetrics) {
event := telemetrydata.MakeTEvent("waveai:post", telemetrydata.TEventProps{
WaveAIAPIType: metrics.Usage.APIType,
WaveAIModel: metrics.Usage.Model,
WaveAIInputTokens: metrics.Usage.InputTokens,
WaveAIOutputTokens: metrics.Usage.OutputTokens,
WaveAIRequestCount: metrics.RequestCount,
WaveAIToolUseCount: metrics.ToolUseCount,
WaveAIToolUseErrorCount: metrics.ToolUseErrorCount,
WaveAIToolDetail: metrics.ToolDetail,
WaveAIPremiumReq: metrics.PremiumReqCount,
WaveAIProxyReq: metrics.ProxyReqCount,
WaveAIHadError: metrics.HadError,
WaveAIImageCount: metrics.ImageCount,
WaveAIPDFCount: metrics.PDFCount,
WaveAITextDocCount: metrics.TextDocCount,
WaveAITextLen: metrics.TextLen,
WaveAIFirstByteMs: metrics.FirstByteLatency,
WaveAIRequestDurMs: metrics.RequestDuration,
WaveAIWidgetAccess: metrics.WidgetAccess,
WaveAIAPIType: metrics.Usage.APIType,
WaveAIModel: metrics.Usage.Model,
WaveAIInputTokens: metrics.Usage.InputTokens,
WaveAIOutputTokens: metrics.Usage.OutputTokens,
WaveAINativeWebSearchCount: metrics.Usage.NativeWebSearchCount,
WaveAIRequestCount: metrics.RequestCount,
WaveAIToolUseCount: metrics.ToolUseCount,
WaveAIToolUseErrorCount: metrics.ToolUseErrorCount,
WaveAIToolDetail: metrics.ToolDetail,
WaveAIPremiumReq: metrics.PremiumReqCount,
WaveAIProxyReq: metrics.ProxyReqCount,
WaveAIHadError: metrics.HadError,
WaveAIImageCount: metrics.ImageCount,
WaveAIPDFCount: metrics.PDFCount,
WaveAITextDocCount: metrics.TextDocCount,
WaveAITextLen: metrics.TextLen,
WaveAIFirstByteMs: metrics.FirstByteLatency,
WaveAIRequestDurMs: metrics.RequestDuration,
WaveAIWidgetAccess: metrics.WidgetAccess,
})
_ = telemetry.RecordTEvent(ctx, event)
}
Expand Down Expand Up @@ -602,6 +605,7 @@ func WaveAIPostMessageHandler(w http.ResponseWriter, r *http.Request) {
Config: *aiOpts,
WidgetAccess: req.WidgetAccess,
RegisterToolApproval: RegisterToolApproval,
AllowNativeWebSearch: true,
}
if chatOpts.Config.APIType == APIType_OpenAI {
chatOpts.SystemPrompt = []string{SystemPromptText_OpenAI}
Expand Down
37 changes: 19 additions & 18 deletions pkg/telemetry/telemetrydata/telemetrydata.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,24 +101,25 @@ type TEventProps struct {
CountWSLConn int `json:"count:wslconn,omitempty"`
CountViews map[string]int `json:"count:views,omitempty"`

WaveAIAPIType string `json:"waveai:apitype,omitempty"`
WaveAIModel string `json:"waveai:model,omitempty"`
WaveAIInputTokens int `json:"waveai:inputtokens,omitempty"`
WaveAIOutputTokens int `json:"waveai:outputtokens,omitempty"`
WaveAIRequestCount int `json:"waveai:requestcount,omitempty"`
WaveAIToolUseCount int `json:"waveai:toolusecount,omitempty"`
WaveAIToolUseErrorCount int `json:"waveai:tooluseerrorcount,omitempty"`
WaveAIToolDetail map[string]int `json:"waveai:tooldetail,omitempty"`
WaveAIPremiumReq int `json:"waveai:premiumreq,omitempty"`
WaveAIProxyReq int `json:"waveai:proxyreq,omitempty"`
WaveAIHadError bool `json:"waveai:haderror,omitempty"`
WaveAIImageCount int `json:"waveai:imagecount,omitempty"`
WaveAIPDFCount int `json:"waveai:pdfcount,omitempty"`
WaveAITextDocCount int `json:"waveai:textdoccount,omitempty"`
WaveAITextLen int `json:"waveai:textlen,omitempty"`
WaveAIFirstByteMs int `json:"waveai:firstbytems,omitempty"` // ms
WaveAIRequestDurMs int `json:"waveai:requestdurms,omitempty"` // ms
WaveAIWidgetAccess bool `json:"waveai:widgetaccess,omitempty"`
WaveAIAPIType string `json:"waveai:apitype,omitempty"`
WaveAIModel string `json:"waveai:model,omitempty"`
WaveAIInputTokens int `json:"waveai:inputtokens,omitempty"`
WaveAIOutputTokens int `json:"waveai:outputtokens,omitempty"`
WaveAINativeWebSearchCount int `json:"waveai:nativewebsearchcount,omitempty"`
WaveAIRequestCount int `json:"waveai:requestcount,omitempty"`
WaveAIToolUseCount int `json:"waveai:toolusecount,omitempty"`
WaveAIToolUseErrorCount int `json:"waveai:tooluseerrorcount,omitempty"`
WaveAIToolDetail map[string]int `json:"waveai:tooldetail,omitempty"`
WaveAIPremiumReq int `json:"waveai:premiumreq,omitempty"`
WaveAIProxyReq int `json:"waveai:proxyreq,omitempty"`
WaveAIHadError bool `json:"waveai:haderror,omitempty"`
WaveAIImageCount int `json:"waveai:imagecount,omitempty"`
WaveAIPDFCount int `json:"waveai:pdfcount,omitempty"`
WaveAITextDocCount int `json:"waveai:textdoccount,omitempty"`
WaveAITextLen int `json:"waveai:textlen,omitempty"`
WaveAIFirstByteMs int `json:"waveai:firstbytems,omitempty"` // ms
WaveAIRequestDurMs int `json:"waveai:requestdurms,omitempty"` // ms
WaveAIWidgetAccess bool `json:"waveai:widgetaccess,omitempty"`

UserSet *TEventUserProps `json:"$set,omitempty"`
UserSetOnce *TEventUserProps `json:"$set_once,omitempty"`
Expand Down
Loading