diff --git a/pkg/model/provider/openai/ws_stream.go b/pkg/model/provider/openai/ws_stream.go index 4b2fc63bc..40303b570 100644 --- a/pkg/model/provider/openai/ws_stream.go +++ b/pkg/model/provider/openai/ws_stream.go @@ -1,6 +1,7 @@ package openai import ( + "bytes" "context" "encoding/json" "fmt" @@ -124,30 +125,39 @@ func (s *wsStream) Next() bool { return false } - _, data, err := s.conn.ReadMessage() - if err != nil { - if websocket.IsCloseError(err, - websocket.CloseNormalClosure, - websocket.CloseGoingAway, - websocket.CloseNoStatusReceived, - ) { + for { + _, data, err := s.conn.ReadMessage() + if err != nil { + if websocket.IsCloseError(err, + websocket.CloseNormalClosure, + websocket.CloseGoingAway, + websocket.CloseNoStatusReceived, + ) { + s.done = true + return false + } + s.err = fmt.Errorf("websocket read: %w", err) s.done = true return false } - s.err = fmt.Errorf("websocket read: %w", err) - s.done = true - return false - } - var event responses.ResponseStreamEventUnion - if err := json.Unmarshal(data, &event); err != nil { - s.err = fmt.Errorf("websocket unmarshal event: %w", err) - s.done = true - return false - } + if len(bytes.TrimSpace(data)) == 0 { + slog.Debug("Ignoring empty WebSocket frame") + continue + } - s.current = event + var event responses.ResponseStreamEventUnion + if err := json.Unmarshal(data, &event); err != nil { + s.err = fmt.Errorf("websocket unmarshal event: %w", err) + s.done = true + return false + } + + s.current = event + break + } + event := s.current slog.Debug("WebSocket event received", "type", event.Type) // Check for server-side error events. 
diff --git a/pkg/model/provider/openai/ws_stream_test.go b/pkg/model/provider/openai/ws_stream_test.go
index 90f6e5427..a95a9c8b8 100644
--- a/pkg/model/provider/openai/ws_stream_test.go
+++ b/pkg/model/provider/openai/ws_stream_test.go
@@ -134,6 +134,85 @@ func TestWSStream_TextDelta(t *testing.T) {
 	assert.ErrorIs(t, err, io.EOF)
 }
 
+// TestWSStream_IgnoresEmptyFrames checks that a whitespace-only WebSocket
+// frame sent before the real events is skipped: the adapter must still yield
+// the text delta, then the completion chunk with usage, then io.EOF.
+func TestWSStream_IgnoresEmptyFrames(t *testing.T) {
+	t.Parallel()
+
+	upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		// The handler runs on a server goroutine, not the test goroutine, so
+		// failures are reported with t.Errorf/assert (never require, whose
+		// t.FailNow must only be called from the test goroutine).
+		conn, err := upgrader.Upgrade(w, r, nil)
+		if err != nil {
+			t.Errorf("WebSocket upgrade failed: %v", err)
+			return
+		}
+		defer conn.Close()
+
+		// Read the client's first message before replying.
+		_, _, err = conn.ReadMessage()
+		if err != nil {
+			t.Errorf("Failed to read response.create: %v", err)
+			return
+		}
+
+		// A whitespace-only frame that the stream is expected to skip.
+		if !assert.NoError(t, conn.WriteMessage(websocket.TextMessage, []byte(" "))) {
+			return
+		}
+		if !assert.NoError(t, conn.WriteJSON(map[string]any{
+			"type":    "response.output_text.delta",
+			"delta":   "Hi from OVH",
+			"item_id": "item_1",
+		})) {
+			return
+		}
+		assert.NoError(t, conn.WriteJSON(map[string]any{
+			"type": "response.completed",
+			"response": map[string]any{
+				"id":     "resp_ovh",
+				"output": []any{},
+				"usage": map[string]any{
+					"input_tokens":          4,
+					"output_tokens":         3,
+					"total_tokens":          7,
+					"input_tokens_details":  map[string]any{"cached_tokens": 0},
+					"output_tokens_details": map[string]any{"reasoning_tokens": 0},
+				},
+			},
+		}))
+	}))
+	defer srv.Close()
+
+	wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
+	stream, err := dialWebSocket(t.Context(), wsURL, http.Header{}, defaultTestParams())
+	require.NoError(t, err)
+	defer stream.Close()
+
+	adapter := newResponseStreamAdapter(stream, true)
+
+	// First chunk: the text delta (the empty frame must not surface).
+	resp, err := adapter.Recv()
+	require.NoError(t, err)
+	require.Len(t, resp.Choices, 1)
+	assert.Equal(t, "Hi from OVH", resp.Choices[0].Delta.Content)
+
+	// Second chunk: completion with stop reason and usage.
+	resp, err = adapter.Recv()
+	require.NoError(t, err)
+	require.Len(t, resp.Choices, 1)
+	assert.Equal(t, chat.FinishReasonStop, resp.Choices[0].FinishReason)
+	assert.NotNil(t, resp.Usage)
+	assert.Equal(t, int64(3), resp.Usage.OutputTokens)
+
+	// The stream then ends.
+	_, err = adapter.Recv()
+	assert.ErrorIs(t, err, io.EOF)
+}
+
 func TestWSStream_ToolCall(t *testing.T) {
 	t.Parallel()
 