diff --git a/pkg/aiusechat/openaichat/openaichat-backend.go b/pkg/aiusechat/openaichat/openaichat-backend.go index 635f334873..70b825e858 100644 --- a/pkg/aiusechat/openaichat/openaichat-backend.go +++ b/pkg/aiusechat/openaichat/openaichat-backend.go @@ -101,6 +101,8 @@ func processChatStream( ) (*uctypes.WaveStopReason, *StoredChatMessage, error) { decoder := eventsource.NewDecoder(body) var textBuilder strings.Builder + var reasoningBuilder strings.Builder + reasoningStarted := false msgID := uuid.New().String() textID := uuid.New().String() var finishReason string @@ -128,7 +130,7 @@ func processChatStream( break } if sseHandler.Err() != nil { - partialMsg := extractPartialTextMessage(msgID, textBuilder.String()) + partialMsg := extractPartialTextMessage(msgID, textBuilder.String(), reasoningBuilder.String()) return &uctypes.WaveStopReason{ Kind: uctypes.StopKindCanceled, ErrorType: "client_disconnect", @@ -160,6 +162,10 @@ func processChatStream( choice := chunk.Choices[0] if choice.Delta.Content != "" { + if reasoningStarted { + reasoningStarted = false + _ = sseHandler.AiMsgReasoningEnd(msgID) + } if !textStarted { _ = sseHandler.AiMsgTextStart(textID) textStarted = true @@ -168,6 +174,15 @@ func processChatStream( _ = sseHandler.AiMsgTextDelta(textID, choice.Delta.Content) } + if choice.Delta.ReasoningContent != "" { + if !reasoningStarted { + reasoningStarted = true + _ = sseHandler.AiMsgReasoningStart(msgID) + } + reasoningBuilder.WriteString(choice.Delta.ReasoningContent) + _ = sseHandler.AiMsgReasoningDelta(msgID, choice.Delta.ReasoningContent) + } + if len(choice.Delta.ToolCalls) > 0 { for _, tcDelta := range choice.Delta.ToolCalls { idx := tcDelta.Index @@ -239,7 +254,8 @@ func processChatStream( assistantMsg := &StoredChatMessage{ MessageId: msgID, Message: ChatRequestMessage{ - Role: "assistant", + Role: "assistant", + ReasoningContent: reasoningBuilder.String(), }, } @@ -249,6 +265,11 @@ func processChatStream( assistantMsg.Message.Content = textBuilder.String() } + // reasoning-end is emitted inline on first content delta (if reasoning was active); + // if no content ever arrived (e.g. max_tokens during reasoning), close it here. + if reasoningStarted { + _ = sseHandler.AiMsgReasoningEnd(msgID) + } if textStarted { _ = sseHandler.AiMsgTextEnd(textID) } @@ -260,16 +281,17 @@ func processChatStream( return stopReason, assistantMsg, nil } -func extractPartialTextMessage(msgID string, text string) *StoredChatMessage { - if text == "" { +func extractPartialTextMessage(msgID string, text string, reasoning string) *StoredChatMessage { + if text == "" && reasoning == "" { return nil } return &StoredChatMessage{ MessageId: msgID, Message: ChatRequestMessage{ - Role: "assistant", - Content: text, + Role: "assistant", + Content: text, + ReasoningContent: reasoning, }, } } diff --git a/pkg/aiusechat/openaichat/openaichat-backend_test.go b/pkg/aiusechat/openaichat/openaichat-backend_test.go new file mode 100644 index 0000000000..1f18f9e438 --- /dev/null +++ b/pkg/aiusechat/openaichat/openaichat-backend_test.go @@ -0,0 +1,161 @@ +// Copyright 2026, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package openaichat + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes" +) + +func TestReasoningContentRoundTrip(t *testing.T) { + original := ChatRequestMessage{ + Role: "assistant", + Content: "The answer is 42.", + ReasoningContent: "Let me think about this carefully...", + ToolCalls: []ToolCall{ + {ID: "call_1", Type: "function", Function: ToolFunctionCall{Name: "search", Arguments: `{}`}}, + }, + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + + var restored ChatRequestMessage + if err := json.Unmarshal(data, &restored); err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + + if restored.Role != original.Role { + t.Errorf("Role: got %q, want %q", restored.Role, original.Role) + } + if restored.Content != original.Content { + t.Errorf("Content: got %q, want %q", restored.Content, original.Content) + } + if restored.ReasoningContent != original.ReasoningContent { + t.Errorf("ReasoningContent: got %q, want %q", restored.ReasoningContent, original.ReasoningContent) + } + if len(restored.ToolCalls) != len(original.ToolCalls) { + t.Fatalf("ToolCalls length: got %d, want %d", len(restored.ToolCalls), len(original.ToolCalls)) + } + if restored.ToolCalls[0].ID != original.ToolCalls[0].ID { + t.Errorf("ToolCalls[0].ID: got %q, want %q", restored.ToolCalls[0].ID, original.ToolCalls[0].ID) + } + if restored.ToolCalls[0].Function.Name != original.ToolCalls[0].Function.Name { + t.Errorf("ToolCalls[0].Function.Name: got %q, want %q", restored.ToolCalls[0].Function.Name, original.ToolCalls[0].Function.Name) + } +} + +func TestReasoningContentOmittedWhenEmpty(t *testing.T) { + msg := ChatRequestMessage{ + Role: "user", + Content: "Hello", + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + + jsonStr := string(data) + if strings.Contains(jsonStr, "reasoning_content") { + t.Errorf("JSON should NOT contain 'reasoning_content' when empty, got: %s", jsonStr) + } +} + +func TestStreamChunkWithReasoningContent(t *testing.T) { + chunkJSON := `{"choices":[{"delta":{"reasoning_content":"I need to search for this...","content":"Let me search."}}]}` + + var chunk StreamChunk + if err := json.Unmarshal([]byte(chunkJSON), &chunk); err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + + if len(chunk.Choices) == 0 { + t.Fatal("expected at least one choice") + } + + delta := chunk.Choices[0].Delta + if delta.ReasoningContent != "I need to search for this..." { + t.Errorf("ReasoningContent: got %q, want %q", delta.ReasoningContent, "I need to search for this...") + } + if delta.Content != "Let me search." { + t.Errorf("Content: got %q, want %q", delta.Content, "Let me search.") + } +} + +func TestCleanPreservesReasoningContent(t *testing.T) { + msg := &ChatRequestMessage{ + Role: "assistant", + Content: "text", + ReasoningContent: "thinking", + ToolCalls: []ToolCall{ + {ID: "call_1", Type: "function", Function: ToolFunctionCall{Name: "f", Arguments: "{}"}, ToolUseData: &uctypes.UIMessageDataToolUse{}}, + }, + } + + cleaned := msg.clean() + + if cleaned == msg { + t.Error("clean() should return a different pointer") + } + + if cleaned.ReasoningContent != "thinking" { + t.Errorf("ReasoningContent: got %q, want %q", cleaned.ReasoningContent, "thinking") + } + + if cleaned.Content != "text" { + t.Errorf("Content: got %q, want %q", cleaned.Content, "text") + } + + if len(cleaned.ToolCalls) != 1 { + t.Fatalf("ToolCalls length: got %d, want 1", len(cleaned.ToolCalls)) + } + if cleaned.ToolCalls[0].ToolUseData != nil { + t.Error("ToolCalls[0].ToolUseData should be nil after clean()") + } +} + +func TestExtractPartialTextMessageWithReasoning(t *testing.T) { + msg := extractPartialTextMessage("msg-1", "partial text", "partial reasoning") + if msg == nil { + t.Fatal("expected non-nil message when text is present") + } + if msg.MessageId != "msg-1" { + t.Errorf("MessageId: got %q, want %q", msg.MessageId, "msg-1") + } + if msg.Message.Content != "partial text" { + t.Errorf("Content: got %q, want %q", msg.Message.Content, "partial text") + } + if msg.Message.ReasoningContent != "partial reasoning" { + t.Errorf("ReasoningContent: got %q, want %q", msg.Message.ReasoningContent, "partial reasoning") + } + if msg.Message.Role != "assistant" { + t.Errorf("Role: got %q, want %q", msg.Message.Role, "assistant") + } +} + +func TestExtractPartialTextMessageWithOnlyReasoning(t *testing.T) { + msg := extractPartialTextMessage("msg-2", "", "some reasoning") + if msg == nil { + t.Fatal("expected non-nil message when reasoning is present") + } + if msg.Message.Content != "" { + t.Errorf("Content: got %q, want empty", msg.Message.Content) + } + if msg.Message.ReasoningContent != "some reasoning" { + t.Errorf("ReasoningContent: got %q, want %q", msg.Message.ReasoningContent, "some reasoning") + } +} + +func TestExtractPartialTextMessageEmpty(t *testing.T) { + msg := extractPartialTextMessage("msg-3", "", "") + if msg != nil { + t.Fatal("expected nil when both text and reasoning are empty") + } +} diff --git a/pkg/aiusechat/openaichat/openaichat-types.go b/pkg/aiusechat/openaichat/openaichat-types.go index 18d28e3b20..9eacd54241 100644 --- a/pkg/aiusechat/openaichat/openaichat-types.go +++ b/pkg/aiusechat/openaichat/openaichat-types.go @@ -50,29 +50,32 @@ type ChatImageUrl struct { } type ChatRequestMessage struct { - Role string `json:"role"` // "system","user","assistant","tool" - Content string `json:"-"` // plain text (used when ContentParts is nil) - ContentParts []ChatContentPart `json:"-"` // multimodal parts (used when images present) - ToolCalls []ToolCall `json:"tool_calls,omitempty"` // assistant tool-call message - ToolCallID string `json:"tool_call_id,omitempty"` // for role:"tool" - Name string `json:"name,omitempty"` // tool name on role:"tool" + Role string `json:"role"` // "system","user","assistant","tool" + ReasoningContent string `json:"-"` // DeepSeek/OpenAI reasoning_content (top-level string) + Content string `json:"-"` // plain text (used when ContentParts is nil) + ContentParts []ChatContentPart `json:"-"` // multimodal parts (used when images present) + ToolCalls []ToolCall `json:"tool_calls,omitempty"` // assistant tool-call message + ToolCallID string `json:"tool_call_id,omitempty"` // for role:"tool" + Name string `json:"name,omitempty"` // tool name on role:"tool" } // chatRequestMessageJSON is the wire format for ChatRequestMessage type chatRequestMessageJSON struct { - Role string `json:"role"` - Content json.RawMessage `json:"content"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` - Name string `json:"name,omitempty"` + Role string `json:"role"` + ReasoningContent string `json:"reasoning_content,omitempty"` + Content json.RawMessage `json:"content"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + Name string `json:"name,omitempty"` } func (cm ChatRequestMessage) MarshalJSON() ([]byte, error) { raw := chatRequestMessageJSON{ - Role: cm.Role, - ToolCalls: cm.ToolCalls, - ToolCallID: cm.ToolCallID, - Name: cm.Name, + Role: cm.Role, + ReasoningContent: cm.ReasoningContent, + ToolCalls: cm.ToolCalls, + ToolCallID: cm.ToolCallID, + Name: cm.Name, } if len(cm.ContentParts) > 0 { b, err := json.Marshal(cm.ContentParts) @@ -96,6 +99,7 @@ func (cm *ChatRequestMessage) UnmarshalJSON(data []byte) error { return err } cm.Role = raw.Role + cm.ReasoningContent = raw.ReasoningContent cm.ToolCalls = raw.ToolCalls cm.ToolCallID = raw.ToolCallID cm.Name = raw.Name @@ -193,9 +197,10 @@ type StreamChoice struct { // This is the important part: type ContentDelta struct { - Role string `json:"role,omitempty"` - Content string `json:"content,omitempty"` - ToolCalls []ToolCallDelta `json:"tool_calls,omitempty"` + Role string `json:"role,omitempty"` + Content string `json:"content,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + ToolCalls []ToolCallDelta `json:"tool_calls,omitempty"` } type ToolCallDelta struct {