Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions pkg/aiusechat/openaichat/openaichat-backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ func processChatStream(
) (*uctypes.WaveStopReason, *StoredChatMessage, error) {
decoder := eventsource.NewDecoder(body)
var textBuilder strings.Builder
var reasoningBuilder strings.Builder
reasoningStarted := false
msgID := uuid.New().String()
textID := uuid.New().String()
var finishReason string
Expand Down Expand Up @@ -128,7 +130,7 @@ func processChatStream(
break
}
if sseHandler.Err() != nil {
partialMsg := extractPartialTextMessage(msgID, textBuilder.String())
partialMsg := extractPartialTextMessage(msgID, textBuilder.String(), reasoningBuilder.String())
return &uctypes.WaveStopReason{
Kind: uctypes.StopKindCanceled,
ErrorType: "client_disconnect",
Expand Down Expand Up @@ -160,6 +162,10 @@ func processChatStream(

choice := chunk.Choices[0]
if choice.Delta.Content != "" {
if reasoningStarted {
reasoningStarted = false
_ = sseHandler.AiMsgReasoningEnd(msgID)
}
if !textStarted {
_ = sseHandler.AiMsgTextStart(textID)
textStarted = true
Expand All @@ -168,6 +174,15 @@ func processChatStream(
_ = sseHandler.AiMsgTextDelta(textID, choice.Delta.Content)
}

if choice.Delta.ReasoningContent != "" {
if !reasoningStarted {
reasoningStarted = true
_ = sseHandler.AiMsgReasoningStart(msgID)
}
reasoningBuilder.WriteString(choice.Delta.ReasoningContent)
_ = sseHandler.AiMsgReasoningDelta(msgID, choice.Delta.ReasoningContent)
}

if len(choice.Delta.ToolCalls) > 0 {
for _, tcDelta := range choice.Delta.ToolCalls {
idx := tcDelta.Index
Expand Down Expand Up @@ -239,7 +254,8 @@ func processChatStream(
assistantMsg := &StoredChatMessage{
MessageId: msgID,
Message: ChatRequestMessage{
Role: "assistant",
Role: "assistant",
ReasoningContent: reasoningBuilder.String(),
},
}

Expand All @@ -249,6 +265,11 @@ func processChatStream(
assistantMsg.Message.Content = textBuilder.String()
}

// reasoning-end is emitted inline on first content delta (if reasoning was active);
// if no content ever arrived (e.g. max_tokens during reasoning), close it here.
if reasoningStarted {
_ = sseHandler.AiMsgReasoningEnd(msgID)
}
if textStarted {
_ = sseHandler.AiMsgTextEnd(textID)
}
Expand All @@ -260,16 +281,17 @@ func processChatStream(
return stopReason, assistantMsg, nil
}

func extractPartialTextMessage(msgID string, text string) *StoredChatMessage {
if text == "" {
func extractPartialTextMessage(msgID string, text string, reasoning string) *StoredChatMessage {
if text == "" && reasoning == "" {
return nil
}

return &StoredChatMessage{
MessageId: msgID,
Message: ChatRequestMessage{
Role: "assistant",
Content: text,
Role: "assistant",
Content: text,
ReasoningContent: reasoning,
},
}
}
161 changes: 161 additions & 0 deletions pkg/aiusechat/openaichat/openaichat-backend_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
// Copyright 2026, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0

package openaichat

import (
"encoding/json"
"strings"
"testing"

"github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes"
)

func TestReasoningContentRoundTrip(t *testing.T) {
original := ChatRequestMessage{
Role: "assistant",
Content: "The answer is 42.",
ReasoningContent: "Let me think about this carefully...",
ToolCalls: []ToolCall{
{ID: "call_1", Type: "function", Function: ToolFunctionCall{Name: "search", Arguments: `{}`}},
},
}

data, err := json.Marshal(original)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}

var restored ChatRequestMessage
if err := json.Unmarshal(data, &restored); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}

if restored.Role != original.Role {
t.Errorf("Role: got %q, want %q", restored.Role, original.Role)
}
if restored.Content != original.Content {
t.Errorf("Content: got %q, want %q", restored.Content, original.Content)
}
if restored.ReasoningContent != original.ReasoningContent {
t.Errorf("ReasoningContent: got %q, want %q", restored.ReasoningContent, original.ReasoningContent)
}
if len(restored.ToolCalls) != len(original.ToolCalls) {
t.Fatalf("ToolCalls length: got %d, want %d", len(restored.ToolCalls), len(original.ToolCalls))
}
if restored.ToolCalls[0].ID != original.ToolCalls[0].ID {
t.Errorf("ToolCalls[0].ID: got %q, want %q", restored.ToolCalls[0].ID, original.ToolCalls[0].ID)
}
if restored.ToolCalls[0].Function.Name != original.ToolCalls[0].Function.Name {
t.Errorf("ToolCalls[0].Function.Name: got %q, want %q", restored.ToolCalls[0].Function.Name, original.ToolCalls[0].Function.Name)
}
}

func TestReasoningContentOmittedWhenEmpty(t *testing.T) {
msg := ChatRequestMessage{
Role: "user",
Content: "Hello",
}

data, err := json.Marshal(msg)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}

jsonStr := string(data)
if strings.Contains(jsonStr, "reasoning_content") {
t.Errorf("JSON should NOT contain 'reasoning_content' when empty, got: %s", jsonStr)
}
}

func TestStreamChunkWithReasoningContent(t *testing.T) {
chunkJSON := `{"choices":[{"delta":{"reasoning_content":"I need to search for this...","content":"Let me search."}}]}`

var chunk StreamChunk
if err := json.Unmarshal([]byte(chunkJSON), &chunk); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}

if len(chunk.Choices) == 0 {
t.Fatal("expected at least one choice")
}

delta := chunk.Choices[0].Delta
if delta.ReasoningContent != "I need to search for this..." {
t.Errorf("ReasoningContent: got %q, want %q", delta.ReasoningContent, "I need to search for this...")
}
if delta.Content != "Let me search." {
t.Errorf("Content: got %q, want %q", delta.Content, "Let me search.")
}
}

func TestCleanPreservesReasoningContent(t *testing.T) {
msg := &ChatRequestMessage{
Role: "assistant",
Content: "text",
ReasoningContent: "thinking",
ToolCalls: []ToolCall{
{ID: "call_1", Type: "function", Function: ToolFunctionCall{Name: "f", Arguments: "{}"}, ToolUseData: &uctypes.UIMessageDataToolUse{}},
},
}

cleaned := msg.clean()

if cleaned == msg {
t.Error("clean() should return a different pointer")
}

if cleaned.ReasoningContent != "thinking" {
t.Errorf("ReasoningContent: got %q, want %q", cleaned.ReasoningContent, "thinking")
}

if cleaned.Content != "text" {
t.Errorf("Content: got %q, want %q", cleaned.Content, "text")
}

if len(cleaned.ToolCalls) != 1 {
t.Fatalf("ToolCalls length: got %d, want 1", len(cleaned.ToolCalls))
}
if cleaned.ToolCalls[0].ToolUseData != nil {
t.Error("ToolCalls[0].ToolUseData should be nil after clean()")
}
}

func TestExtractPartialTextMessageWithReasoning(t *testing.T) {
msg := extractPartialTextMessage("msg-1", "partial text", "partial reasoning")
if msg == nil {
t.Fatal("expected non-nil message when text is present")
}
if msg.MessageId != "msg-1" {
t.Errorf("MessageId: got %q, want %q", msg.MessageId, "msg-1")
}
if msg.Message.Content != "partial text" {
t.Errorf("Content: got %q, want %q", msg.Message.Content, "partial text")
}
if msg.Message.ReasoningContent != "partial reasoning" {
t.Errorf("ReasoningContent: got %q, want %q", msg.Message.ReasoningContent, "partial reasoning")
}
if msg.Message.Role != "assistant" {
t.Errorf("Role: got %q, want %q", msg.Message.Role, "assistant")
}
}

func TestExtractPartialTextMessageWithOnlyReasoning(t *testing.T) {
msg := extractPartialTextMessage("msg-2", "", "some reasoning")
if msg == nil {
t.Fatal("expected non-nil message when reasoning is present")
}
if msg.Message.Content != "" {
t.Errorf("Content: got %q, want empty", msg.Message.Content)
}
if msg.Message.ReasoningContent != "some reasoning" {
t.Errorf("ReasoningContent: got %q, want %q", msg.Message.ReasoningContent, "some reasoning")
}
}

func TestExtractPartialTextMessageEmpty(t *testing.T) {
msg := extractPartialTextMessage("msg-3", "", "")
if msg != nil {
t.Fatal("expected nil when both text and reasoning are empty")
}
}
41 changes: 23 additions & 18 deletions pkg/aiusechat/openaichat/openaichat-types.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,29 +50,32 @@ type ChatImageUrl struct {
}

type ChatRequestMessage struct {
Role string `json:"role"` // "system","user","assistant","tool"
Content string `json:"-"` // plain text (used when ContentParts is nil)
ContentParts []ChatContentPart `json:"-"` // multimodal parts (used when images present)
ToolCalls []ToolCall `json:"tool_calls,omitempty"` // assistant tool-call message
ToolCallID string `json:"tool_call_id,omitempty"` // for role:"tool"
Name string `json:"name,omitempty"` // tool name on role:"tool"
Role string `json:"role"` // "system","user","assistant","tool"
ReasoningContent string `json:"-"` // DeepSeek/OpenAI reasoning_content (top-level string)
Content string `json:"-"` // plain text (used when ContentParts is nil)
ContentParts []ChatContentPart `json:"-"` // multimodal parts (used when images present)
ToolCalls []ToolCall `json:"tool_calls,omitempty"` // assistant tool-call message
ToolCallID string `json:"tool_call_id,omitempty"` // for role:"tool"
Name string `json:"name,omitempty"` // tool name on role:"tool"
}

// chatRequestMessageJSON is the wire format for ChatRequestMessage
type chatRequestMessageJSON struct {
Role string `json:"role"`
Content json.RawMessage `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
Name string `json:"name,omitempty"`
Role string `json:"role"`
ReasoningContent string `json:"reasoning_content,omitempty"`
Content json.RawMessage `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
Name string `json:"name,omitempty"`
}

func (cm ChatRequestMessage) MarshalJSON() ([]byte, error) {
raw := chatRequestMessageJSON{
Role: cm.Role,
ToolCalls: cm.ToolCalls,
ToolCallID: cm.ToolCallID,
Name: cm.Name,
Role: cm.Role,
ReasoningContent: cm.ReasoningContent,
ToolCalls: cm.ToolCalls,
ToolCallID: cm.ToolCallID,
Name: cm.Name,
}
if len(cm.ContentParts) > 0 {
b, err := json.Marshal(cm.ContentParts)
Expand All @@ -96,6 +99,7 @@ func (cm *ChatRequestMessage) UnmarshalJSON(data []byte) error {
return err
}
cm.Role = raw.Role
cm.ReasoningContent = raw.ReasoningContent
cm.ToolCalls = raw.ToolCalls
cm.ToolCallID = raw.ToolCallID
cm.Name = raw.Name
Expand Down Expand Up @@ -193,9 +197,10 @@ type StreamChoice struct {

// This is the important part:
type ContentDelta struct {
Role string `json:"role,omitempty"`
Content string `json:"content,omitempty"`
ToolCalls []ToolCallDelta `json:"tool_calls,omitempty"`
Role string `json:"role,omitempty"`
Content string `json:"content,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
ToolCalls []ToolCallDelta `json:"tool_calls,omitempty"`
}

type ToolCallDelta struct {
Expand Down