diff --git a/.flake8 b/.flake8 index 542ad1d..1bc2f7d 100644 --- a/.flake8 +++ b/.flake8 @@ -1,4 +1,5 @@ [flake8] max-line-length = 88 exclude = .git,.github,.chglog,__pycache__,docs,venv,env,mypy_cache -max-complexity = 10 \ No newline at end of file +max-complexity = 10 +extend-ignore = E203 \ No newline at end of file diff --git a/README.md b/README.md index 16039e6..3cea739 100644 --- a/README.md +++ b/README.md @@ -542,6 +542,24 @@ For more detailed examples and integration patterns, check out: --- +## Streaming + +The SDK auto-detects the response wire format from `Content-Type`: + +- `application/vnd.amazon.eventstream` — AWS Bedrock event-stream framing (used by `invoke-with-response-stream` and `converse-stream`). Frames are parsed via `botocore.eventstream`; Bedrock's `{"bytes": "<base64>"}` envelope is unwrapped automatically. +- `text/event-stream` / `application/x-ndjson` — SSE (OpenAI, Anthropic direct, Gemini). + +Each route configures a `stream_response_path` (JSONPath) that extracts the per-chunk text from the model's native chunk schema. Common values: + +| Model family | `stream_response_path` | | ------------------------ | -------------------------- | | Anthropic Claude | `delta.text` | | Meta Llama (Bedrock) | `generation` | | Cohere Command (Bedrock) | `text` | | OpenAI chat | `choices[0].delta.content` | + +If `stream_response_path` doesn't match a chunk, the SDK logs a single structured WARNING with the payload keys it saw — check that warning against the `ModelSpec` on the gateway. + ## Type Hints & py.typed Marker This package includes a `py.typed` marker file, which indicates to type checkers (like `mypy`, `pyright`, `pylance`) that the package supports type checking. This allows IDEs and static analysis tools to provide better autocomplete, type checking, and refactoring support. 
diff --git a/highflame/services/route_service.py b/highflame/services/route_service.py index 60979cd..ea642e9 100644 --- a/highflame/services/route_service.py +++ b/highflame/services/route_service.py @@ -1,9 +1,24 @@ +import base64 +import binascii import json import logging -from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union +from enum import Enum +from typing import ( + Any, + AsyncGenerator, + AsyncIterator, + Dict, + Generator, + Iterator, + List, + Optional, + Union, +) import httpx +from botocore.eventstream import EventStreamBuffer from jsonpath_ng import parse +from jsonpath_ng.exceptions import JsonPathParserError from highflame.exceptions import ( BadRequest, @@ -18,6 +33,22 @@ logger = logging.getLogger(__name__) +class _StreamFormat(Enum): + EVENT_STREAM = "event_stream" + SSE = "sse" + UNKNOWN = "unknown" + + +_EVENT_STREAM_CONTENT_TYPES = ( + "application/vnd.amazon.eventstream", + "application/x-amz-eventstream", +) +_SSE_CONTENT_TYPES = ( + "text/event-stream", + "application/x-ndjson", +) + + class RouteService: def __init__(self, client): self.client = client @@ -171,78 +202,227 @@ async def adelete_route(self, route_name: str) -> str: self.areload_route(route_name=route_name) return self._process_route_response_ok(response) - def _extract_json_from_line(self, line_str: str) -> Optional[Dict[str, Any]]: - """Extract JSON data from a line string.""" + def _detect_stream_format(self, response: httpx.Response) -> _StreamFormat: + content_type = response.headers.get("content-type", "").lower() + if any(ct in content_type for ct in _EVENT_STREAM_CONTENT_TYPES): + return _StreamFormat.EVENT_STREAM + if any(ct in content_type for ct in _SSE_CONTENT_TYPES): + return _StreamFormat.SSE + return _StreamFormat.UNKNOWN + + def _compile_jsonpath(self, stream_response_path: Optional[str]): + if not stream_response_path: + return None try: - json_start = line_str.find("{") - json_end = line_str.rfind("}") + 1 - if json_start != -1 and 
json_end != -1: - json_str = line_str[json_start:json_end] - return json.loads(json_str) - except Exception: - pass - return None + return parse(stream_response_path) + except JsonPathParserError: + logger.warning( + "Invalid stream_response_path %r; streaming will yield nothing", + stream_response_path, + ) + return None - def _process_bytes_message( - self, data: Dict[str, Any], jsonpath_expr + def _decode_event_stream_payload(self, payload: bytes) -> Optional[Dict[str, Any]]: + """Decode one event-stream frame payload to a JSON dict. + + Bedrock wraps the model chunk as `{"bytes": "<base64>"}`. Some + providers emit raw JSON. Handle both. + """ + try: + outer = json.loads(payload) + except (json.JSONDecodeError, UnicodeDecodeError) as e: + logger.debug("event-stream payload was not JSON: %s", e) + return None + + if isinstance(outer, dict) and "bytes" in outer: + try: + inner_bytes = base64.b64decode(outer["bytes"]) + return json.loads(inner_bytes) + except ( + binascii.Error, + json.JSONDecodeError, + UnicodeDecodeError, + TypeError, + ) as e: + logger.debug("event-stream inner bytes payload undecodable: %s", e) + return None + + return outer if isinstance(outer, dict) else None + + def _extract_text( + self, payload: Optional[Dict[str, Any]], jsonpath_expr ) -> Optional[str]: - """Process a message with bytes data.""" + """Return the matched text, or None if the path missed entirely. + + An empty-string match is returned as ``""`` (distinct from a miss) so the + caller can suppress the "did-not-match" warning for legitimate empty + chunks (e.g. OpenAI streaming keep-alives where ``delta.content == ""``). + A match whose value is non-string (dict, list, number) is treated as a + miss with a debug log, since stringifying it would mask a misconfigured + ``stream_response_path``. 
+ """ + if payload is None or jsonpath_expr is None: + return None try: - if "bytes" in data: - import base64 - - bytes_data = base64.b64decode(data["bytes"]) - decoded_data = json.loads(bytes_data) - matches = jsonpath_expr.find(decoded_data) - if matches and matches[0].value: - return matches[0].value - except Exception: - pass + matches = jsonpath_expr.find(payload) + except Exception as e: # jsonpath_ng raises various subclasses + logger.debug("JSONPath evaluation failed: %s", e) + return None + if not matches: + return None + val = matches[0].value + if isinstance(val, str): + return val + if val is None: + return None + logger.debug( + "JSONPath matched a non-string value of type %s; treating as miss", + type(val).__name__, + ) return None - def _process_delta_message(self, data: Dict[str, Any]) -> Optional[str]: - """Process a message with delta data.""" + def _iter_event_stream_payloads( + self, byte_chunks: Iterator[bytes] + ) -> Iterator[Dict[str, Any]]: + buffer = EventStreamBuffer() + for chunk in byte_chunks: + if not chunk: + continue + try: + buffer.add_data(chunk) + except Exception as e: + logger.debug("event-stream framing error: %s", e) + continue + for message in buffer: + decoded = self._decode_event_stream_payload(message.payload) + if decoded is not None: + yield decoded + + async def _aiter_event_stream_payloads( + self, byte_chunks: AsyncIterator[bytes] + ) -> AsyncIterator[Dict[str, Any]]: + buffer = EventStreamBuffer() + async for chunk in byte_chunks: + if not chunk: + continue + try: + buffer.add_data(chunk) + except Exception as e: + logger.debug("event-stream framing error: %s", e) + continue + for message in buffer: + decoded = self._decode_event_stream_payload(message.payload) + if decoded is not None: + yield decoded + + def _iter_sse_payloads( + self, lines: Iterator[Union[str, bytes]] + ) -> Iterator[Dict[str, Any]]: + for line in lines: + payload = self._parse_sse_line(line) + if payload is not None: + yield payload + + async 
def _aiter_sse_payloads( + self, lines: AsyncIterator[Union[str, bytes]] + ) -> AsyncIterator[Dict[str, Any]]: + async for line in lines: + payload = self._parse_sse_line(line) + if payload is not None: + yield payload + + def _parse_sse_line(self, line: Union[str, bytes]) -> Optional[Dict[str, Any]]: + if not line: + return None + line_str = line.decode("utf-8") if isinstance(line, bytes) else line + if not line_str.startswith("data:"): + return None + data_str = line_str[len("data:") :].strip() + if not data_str or data_str == "[DONE]": + return None try: - if "delta" in data and "text" in data["delta"]: - return data["delta"]["text"] - except Exception: - pass - return None + return json.loads(data_str) + except json.JSONDecodeError as e: + logger.debug("SSE payload was not JSON: %s", e) + return None - def _process_sse_data(self, line_str: str, jsonpath_expr) -> Optional[str]: - """Process Server-Sent Events (SSE) data format.""" - try: - if line_str.strip() != "data: [DONE]": - json_str = line_str.replace("data: ", "") - data = json.loads(json_str) - matches = jsonpath_expr.find(data) - if matches and matches[0].value: - return matches[0].value - except Exception: - pass - return None + def _warn_once_on_miss( + self, + state: Dict[str, bool], + payload: Dict[str, Any], + stream_response_path: Optional[str], + ) -> None: + if state.get("warned"): + return + state["warned"] = True + logger.warning( + "stream_response_path %r did not match any chunk; payload keys=%s. 
" + "Check ModelSpec.stream_response_path for this route.", + stream_response_path, + sorted(payload.keys()) if isinstance(payload, dict) else type(payload), + ) - def _process_stream_line( - self, line_str: str, jsonpath_expr, is_bedrock: bool = False - ) -> Optional[str]: - """Process a single line from the stream response - and extract text if available.""" - try: - if "message-type" in line_str: - data = self._extract_json_from_line(line_str) - if data: - if "bytes" in line_str: - return self._process_bytes_message(data, jsonpath_expr) - else: - return self._process_delta_message(data) - - # Handle SSE data format - elif line_str.startswith("data: "): - return self._process_sse_data(line_str, jsonpath_expr) - - except Exception: - pass - return None + def _stream_text_sync( + self, + response: httpx.Response, + stream_response_path: Optional[str], + ) -> Generator[str, None, None]: + jsonpath_expr = self._compile_jsonpath(stream_response_path) + fmt = self._detect_stream_format(response) + miss_state: Dict[str, bool] = {"warned": False} + + if fmt == _StreamFormat.EVENT_STREAM: + payloads = self._iter_event_stream_payloads(response.iter_bytes()) + elif fmt == _StreamFormat.SSE: + payloads = self._iter_sse_payloads(response.iter_lines()) + else: + # Unknown content-type: try SSE first (most common), then fall back + # to event-stream framing on raw bytes if SSE yields nothing. 
+ logger.debug( + "Unknown stream content-type %r; attempting SSE parse", + response.headers.get("content-type"), + ) + payloads = self._iter_sse_payloads(response.iter_lines()) + + for payload in payloads: + text = self._extract_text(payload, jsonpath_expr) + if text is None: + self._warn_once_on_miss(miss_state, payload, stream_response_path) + elif text: + yield text + + async def _stream_text_async( + self, + response: httpx.Response, + stream_response_path: Optional[str], + ) -> AsyncGenerator[str, None]: + jsonpath_expr = self._compile_jsonpath(stream_response_path) + fmt = self._detect_stream_format(response) + miss_state: Dict[str, bool] = {"warned": False} + + if fmt == _StreamFormat.EVENT_STREAM: + async for payload in self._aiter_event_stream_payloads( + response.aiter_bytes() + ): + text = self._extract_text(payload, jsonpath_expr) + if text is None: + self._warn_once_on_miss(miss_state, payload, stream_response_path) + elif text: + yield text + return + + if fmt == _StreamFormat.UNKNOWN: + logger.debug( + "Unknown stream content-type %r; attempting SSE parse", + response.headers.get("content-type"), + ) + + async for payload in self._aiter_sse_payloads(response.aiter_lines()): + text = self._extract_text(payload, jsonpath_expr) + if text is None: + self._warn_once_on_miss(miss_state, payload, stream_response_path) + elif text: + yield text def query_route( self, @@ -269,17 +449,7 @@ def query_route( if not stream or response.status_code != 200: return self._process_route_response_json(response) - jsonpath_expr = parse(stream_response_path) - - def generate_stream(): - for line in response.iter_lines(): - if line: - line_str = line.decode("utf-8") if isinstance(line, bytes) else line - text = self._process_stream_line(line_str, jsonpath_expr) - if text: - yield text - - return generate_stream() + return self._stream_text_sync(response, stream_response_path) async def aquery_route( self, @@ -305,19 +475,7 @@ async def aquery_route( if not stream or 
response.status_code != 200: return self._process_route_response_json(response) - jsonpath_expr = parse(stream_response_path) - - async def generate_stream(): - async for line in response.aiter_lines(): - if line: - line_str = line.decode("utf-8") if isinstance(line, bytes) else line - text = self._process_stream_line( - line_str, jsonpath_expr, is_bedrock=True - ) - if text: - yield text - - return generate_stream() + return self._stream_text_async(response, stream_response_path) def reload_route(self, route_name: str) -> str: """ @@ -377,23 +535,10 @@ def query_unified_endpoint( # Only parse JSON for application/json responses content_type = response.headers.get("content-type", "").lower() - print(f"Content-Type: {content_type}") if "application/json" in content_type: - print(f"Response: {response.json()}") return response.json() - # Handle streaming response if stream_response_path is provided - jsonpath_expr = parse(stream_response_path) - - def generate_stream(): - for line in response.iter_lines(): - if line: - line_str = line.decode("utf-8") if isinstance(line, bytes) else line - text = self._process_stream_line(line_str, jsonpath_expr) - if text: - yield text - - return generate_stream() + return self._stream_text_sync(response, stream_response_path) async def aquery_unified_endpoint( self, @@ -427,17 +572,4 @@ async def aquery_unified_endpoint( if "application/json" in content_type: return response.json() - # Handle streaming response if stream_response_path is provided - jsonpath_expr = parse(stream_response_path) - - async def generate_stream(): - async for line in response.aiter_lines(): - if line: - line_str = line.decode("utf-8") if isinstance(line, bytes) else line - text = self._process_stream_line( - line_str, jsonpath_expr, is_bedrock=True - ) - if text: - yield text - - return generate_stream() + return self._stream_text_async(response, stream_response_path) diff --git a/pyproject.toml b/pyproject.toml index 8ef0451..16f821e 100644 --- 
a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ requests = "^2.32.3" urllib3 = "^2.6.3" jmespath = "^1.0.1" jsonpath-ng = "^1.7.0" +botocore = "^1.34.0" # OpenTelemetry Dependencies opentelemetry-api = "^1.32.1" @@ -57,3 +58,7 @@ isort = "^5.13.2" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" + +[tool.pytest.ini_options] +asyncio_mode = "strict" +testpaths = ["tests"] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/streaming/__init__.py b/tests/streaming/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/streaming/conftest.py b/tests/streaming/conftest.py new file mode 100644 index 0000000..9b6600b --- /dev/null +++ b/tests/streaming/conftest.py @@ -0,0 +1,97 @@ +from pathlib import Path +from typing import AsyncIterator, Iterator, List +from unittest.mock import MagicMock + +import pytest + + +FIXTURES = Path(__file__).parent / "fixtures" + + +def _chunked(data: bytes, size: int = 64) -> List[bytes]: + return [data[i : i + size] for i in range(0, len(data), size)] + + +class FakeResponse: + """Minimal stand-in for httpx.Response in streaming tests. + + Supports both sync (iter_bytes/iter_lines) and async (aiter_bytes/aiter_lines) + interfaces. We drive it from a captured fixture file so the parser sees + real event-stream / SSE bytes split across multiple chunks (mimicking how + httpx actually delivers data over the wire). 
+ """ + + def __init__( + self, + content: bytes, + content_type: str, + status_code: int = 200, + chunk_size: int = 64, + ): + self.headers = {"content-type": content_type} + self.status_code = status_code + self._chunks = _chunked(content, chunk_size) + self._text = content.decode("utf-8", errors="replace") + + def iter_bytes(self) -> Iterator[bytes]: + for c in self._chunks: + yield c + + async def aiter_bytes(self) -> AsyncIterator[bytes]: + for c in self._chunks: + yield c + + def iter_lines(self) -> Iterator[str]: + for line in self._text.splitlines(): + yield line + + async def aiter_lines(self) -> AsyncIterator[str]: + for line in self._text.splitlines(): + yield line + + +@pytest.fixture +def llama_eventstream_response(): + return FakeResponse( + (FIXTURES / "bedrock_llama3_chunks.bin").read_bytes(), + content_type="application/vnd.amazon.eventstream", + ) + + +@pytest.fixture +def cohere_eventstream_response(): + return FakeResponse( + (FIXTURES / "bedrock_cohere_chunks.bin").read_bytes(), + content_type="application/vnd.amazon.eventstream", + ) + + +@pytest.fixture +def claude_eventstream_response(): + return FakeResponse( + (FIXTURES / "bedrock_claude_chunks.bin").read_bytes(), + content_type="application/vnd.amazon.eventstream", + ) + + +@pytest.fixture +def openai_sse_response(): + return FakeResponse( + (FIXTURES / "openai_sse_chunks.txt").read_bytes(), + content_type="text/event-stream", + ) + + +@pytest.fixture +def openai_sse_with_empty_response(): + return FakeResponse( + (FIXTURES / "openai_sse_with_empty.txt").read_bytes(), + content_type="text/event-stream", + ) + + +@pytest.fixture +def route_service(): + from highflame.services.route_service import RouteService + + return RouteService(client=MagicMock()) diff --git a/tests/streaming/fixtures/_build.py b/tests/streaming/fixtures/_build.py new file mode 100644 index 0000000..3ad08fc --- /dev/null +++ b/tests/streaming/fixtures/_build.py @@ -0,0 +1,140 @@ +"""Build synthetic AWS event-stream 
fixtures for streaming parser tests. + +Re-run this script if AWS changes the event-stream framing format. It produces +deterministic .bin fixtures next to this file. + +The frame format (per AWS docs): + [total_length:4 BE][headers_length:4 BE][prelude_crc:4 BE] + [headers][payload] + [message_crc:4 BE] + +A header is encoded as: + [name_len:1][name:bytes][value_type:1][value_len:2 BE][value:bytes] + +We use only one header: `:message-type` = `event` (string, type 7). +""" + +import base64 +import binascii +import json +import struct +from pathlib import Path +from typing import List + + +HEADER_TYPE_STRING = 7 + + +def _encode_string_header(name: str, value: str) -> bytes: + name_bytes = name.encode("utf-8") + value_bytes = value.encode("utf-8") + return ( + struct.pack(">B", len(name_bytes)) + + name_bytes + + struct.pack(">B", HEADER_TYPE_STRING) + + struct.pack(">H", len(value_bytes)) + + value_bytes + ) + + +def encode_frame(payload: bytes) -> bytes: + headers = _encode_string_header(":message-type", "event") + headers_len = len(headers) + total_len = 4 + 4 + 4 + headers_len + len(payload) + 4 + prelude = struct.pack(">II", total_len, headers_len) + prelude_crc = binascii.crc32(prelude) & 0xFFFFFFFF + pre_message = prelude + struct.pack(">I", prelude_crc) + headers + payload + message_crc = binascii.crc32(pre_message) & 0xFFFFFFFF + return pre_message + struct.pack(">I", message_crc) + + +def encode_bedrock_chunks(model_chunks: List[dict]) -> bytes: + """Wrap each model chunk dict as Bedrock's `{"bytes": "<base64>"}` envelope. + + Bedrock event-stream payloads are JSON of shape {"bytes": "<base64>"} + where the base64-decoded bytes are the model's native chunk. 
+ """ + out = b"" + for chunk in model_chunks: + inner = json.dumps(chunk).encode("utf-8") + envelope = json.dumps({"bytes": base64.b64encode(inner).decode("ascii")}) + out += encode_frame(envelope.encode("utf-8")) + return out + + +def encode_raw_chunks(model_chunks: List[dict]) -> bytes: + """For providers that emit raw JSON payloads (no base64 envelope).""" + out = b"" + for chunk in model_chunks: + out += encode_frame(json.dumps(chunk).encode("utf-8")) + return out + + +def encode_sse_chunks(chunks: List[dict]) -> bytes: + """Standard SSE: `data: <json>\\n\\n` per chunk, terminating with [DONE].""" + lines = [] + for chunk in chunks: + lines.append(f"data: {json.dumps(chunk)}\n\n") + lines.append("data: [DONE]\n\n") + return "".join(lines).encode("utf-8") + + +def main(): + out_dir = Path(__file__).parent + + # Llama-3 on Bedrock: top-level `generation` field. + llama_chunks = [ + {"generation": "Hello", "generation_token_count": 1, "stop_reason": None}, + {"generation": " world", "generation_token_count": 2, "stop_reason": None}, + {"generation": "!", "generation_token_count": 3, "stop_reason": "stop"}, + ] + (out_dir / "bedrock_llama3_chunks.bin").write_bytes( + encode_bedrock_chunks(llama_chunks) + ) + + # Cohere Command on Bedrock: top-level `text` field. + cohere_chunks = [ + {"text": "Hello", "is_finished": False, "index": 0}, + {"text": " world", "is_finished": False, "index": 1}, + {"text": "!", "is_finished": True, "index": 2}, + ] + (out_dir / "bedrock_cohere_chunks.bin").write_bytes( + encode_bedrock_chunks(cohere_chunks) + ) + + # Claude on Bedrock: nested `delta.text` (regression — must keep working). + claude_chunks = [ + {"type": "content_block_delta", "delta": {"text": "Hello"}}, + {"type": "content_block_delta", "delta": {"text": " world"}}, + {"type": "content_block_delta", "delta": {"text": "!"}}, + ] + (out_dir / "bedrock_claude_chunks.bin").write_bytes( + encode_bedrock_chunks(claude_chunks) + ) + + # OpenAI-style SSE: `choices[0].delta.content`. 
+ openai_chunks = [ + {"choices": [{"delta": {"content": "Hello"}}]}, + {"choices": [{"delta": {"content": " world"}}]}, + {"choices": [{"delta": {"content": "!"}}]}, + ] + (out_dir / "openai_sse_chunks.txt").write_bytes(encode_sse_chunks(openai_chunks)) + + # OpenAI-style SSE with an empty-content keep-alive chunk in the middle. + # The first chunk in a real OpenAI stream often has delta = {"role": "assistant"} + # with no `content` key (path miss), and intermediate empty `content: ""` chunks + # can also appear. Both should NOT trigger the "did-not-match" warning. + openai_with_empty = [ + {"choices": [{"delta": {"content": "Hello"}}]}, + {"choices": [{"delta": {"content": ""}}]}, # legitimate empty chunk + {"choices": [{"delta": {"content": " world"}}]}, + ] + (out_dir / "openai_sse_with_empty.txt").write_bytes( + encode_sse_chunks(openai_with_empty) + ) + + print(f"Wrote fixtures to {out_dir}") + + +if __name__ == "__main__": + main() diff --git a/tests/streaming/fixtures/bedrock_claude_chunks.bin b/tests/streaming/fixtures/bedrock_claude_chunks.bin new file mode 100644 index 0000000..75228ff Binary files /dev/null and b/tests/streaming/fixtures/bedrock_claude_chunks.bin differ diff --git a/tests/streaming/fixtures/bedrock_cohere_chunks.bin b/tests/streaming/fixtures/bedrock_cohere_chunks.bin new file mode 100644 index 0000000..166534f Binary files /dev/null and b/tests/streaming/fixtures/bedrock_cohere_chunks.bin differ diff --git a/tests/streaming/fixtures/bedrock_llama3_chunks.bin b/tests/streaming/fixtures/bedrock_llama3_chunks.bin new file mode 100644 index 0000000..0d116d9 Binary files /dev/null and b/tests/streaming/fixtures/bedrock_llama3_chunks.bin differ diff --git a/tests/streaming/fixtures/openai_sse_chunks.txt b/tests/streaming/fixtures/openai_sse_chunks.txt new file mode 100644 index 0000000..14a6743 --- /dev/null +++ b/tests/streaming/fixtures/openai_sse_chunks.txt @@ -0,0 +1,8 @@ +data: {"choices": [{"delta": {"content": "Hello"}}]} + +data: 
{"choices": [{"delta": {"content": " world"}}]} + +data: {"choices": [{"delta": {"content": "!"}}]} + +data: [DONE] + diff --git a/tests/streaming/fixtures/openai_sse_with_empty.txt b/tests/streaming/fixtures/openai_sse_with_empty.txt new file mode 100644 index 0000000..1dd3f85 --- /dev/null +++ b/tests/streaming/fixtures/openai_sse_with_empty.txt @@ -0,0 +1,8 @@ +data: {"choices": [{"delta": {"content": "Hello"}}]} + +data: {"choices": [{"delta": {"content": ""}}]} + +data: {"choices": [{"delta": {"content": " world"}}]} + +data: [DONE] + diff --git a/tests/streaming/test_event_stream_parser.py b/tests/streaming/test_event_stream_parser.py new file mode 100644 index 0000000..98a18f7 --- /dev/null +++ b/tests/streaming/test_event_stream_parser.py @@ -0,0 +1,159 @@ +"""Regression tests for the AWS event-stream streaming parser. + +Covers issues #146 (Cohere) and #147 (Llama-3) where the previous +substring-based parser silently yielded nothing for model schemas that don't +match the default `delta.text` JSONPath. 
+""" + +import logging + +import pytest + + +def _collect_sync(route_service, response, path): + return list(route_service._stream_text_sync(response, path)) + + +async def _collect_async(route_service, response, path): + return [chunk async for chunk in route_service._stream_text_async(response, path)] + + +# --- Sync path -------------------------------------------------------------- + + +def test_llama3_eventstream_yields_generation( + route_service, llama_eventstream_response +): + out = _collect_sync(route_service, llama_eventstream_response, "generation") + assert out == ["Hello", " world", "!"] + + +def test_cohere_eventstream_yields_text(route_service, cohere_eventstream_response): + out = _collect_sync(route_service, cohere_eventstream_response, "text") + assert out == ["Hello", " world", "!"] + + +def test_claude_eventstream_yields_delta_text( + route_service, claude_eventstream_response +): + """Regression: existing Claude behavior must keep working.""" + out = _collect_sync(route_service, claude_eventstream_response, "delta.text") + assert out == ["Hello", " world", "!"] + + +def test_eventstream_with_wrong_path_yields_nothing_and_warns_once( + route_service, llama_eventstream_response, caplog +): + caplog.set_level(logging.WARNING, logger="highflame.services.route_service") + out = _collect_sync(route_service, llama_eventstream_response, "delta.text") + assert out == [] + warnings = [ + r + for r in caplog.records + if r.levelno == logging.WARNING and "did not match" in r.getMessage() + ] + # Exactly one warning regardless of how many chunks missed. 
+ assert len(warnings) == 1 + assert "'delta.text'" in warnings[0].getMessage() + + +# --- Async path ------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_llama3_eventstream_async(route_service, llama_eventstream_response): + out = await _collect_async(route_service, llama_eventstream_response, "generation") + assert out == ["Hello", " world", "!"] + + +@pytest.mark.asyncio +async def test_cohere_eventstream_async(route_service, cohere_eventstream_response): + out = await _collect_async(route_service, cohere_eventstream_response, "text") + assert out == ["Hello", " world", "!"] + + +@pytest.mark.asyncio +async def test_claude_eventstream_async(route_service, claude_eventstream_response): + out = await _collect_async(route_service, claude_eventstream_response, "delta.text") + assert out == ["Hello", " world", "!"] + + +# --- SSE path --------------------------------------------------------------- + + +def test_openai_sse_yields_delta_content(route_service, openai_sse_response): + out = _collect_sync(route_service, openai_sse_response, "choices[0].delta.content") + assert out == ["Hello", " world", "!"] + + +@pytest.mark.asyncio +async def test_openai_sse_async(route_service, openai_sse_response): + out = await _collect_async( + route_service, openai_sse_response, "choices[0].delta.content" + ) + assert out == ["Hello", " world", "!"] + + +# --- Empty-string match semantics (review feedback) ------------------------ + + +def test_empty_content_chunk_does_not_warn( + route_service, openai_sse_with_empty_response, caplog +): + """Empty-string chunks (legitimate keep-alives) must not trigger the + 'did not match' warning. Only true path-misses should warn.""" + caplog.set_level(logging.WARNING, logger="highflame.services.route_service") + out = list( + route_service._stream_text_sync( + openai_sse_with_empty_response, "choices[0].delta.content" + ) + ) + # The empty chunk yields nothing but the non-empty ones do. 
+ assert out == ["Hello", " world"] + warnings = [ + r + for r in caplog.records + if r.levelno == logging.WARNING and "did not match" in r.getMessage() + ] + assert warnings == [] + + +def test_extract_text_distinguishes_miss_from_empty(route_service): + from jsonpath_ng import parse + + expr = parse("delta.content") + # Path missing entirely -> None + assert route_service._extract_text({"other": "x"}, expr) is None + # Path matches an empty string -> "" (NOT None) + assert route_service._extract_text({"delta": {"content": ""}}, expr) == "" + # Path matches a non-empty string -> the string + assert route_service._extract_text({"delta": {"content": "hi"}}, expr) == "hi" + # Path matches a non-string value -> None (treated as miss, debug-logged) + assert route_service._extract_text({"delta": {"content": 42}}, expr) is None + # Path matches null -> None + assert route_service._extract_text({"delta": {"content": None}}, expr) is None + + +# --- Robustness ------------------------------------------------------------- + + +def test_invalid_jsonpath_yields_nothing_without_raising( + route_service, llama_eventstream_response +): + # Garbage JSONPath should be compiled defensively, not crash mid-stream. + out = _collect_sync(route_service, llama_eventstream_response, "@@@bad@@@") + assert out == [] + + +@pytest.mark.asyncio +async def test_sync_and_async_agree( + route_service, llama_eventstream_response, cohere_eventstream_response +): + """Sync and async paths must produce identical output for the same fixture.""" + for fixture, path in [ + (llama_eventstream_response, "generation"), + (cohere_eventstream_response, "text"), + ]: + sync_out = _collect_sync(route_service, fixture, path) + async_out = await _collect_async(route_service, fixture, path) + assert sync_out == async_out