From 8128dfad567ac072fa6731c50ce4a361d428da7f Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Mon, 27 Apr 2026 15:48:17 +0100 Subject: [PATCH 01/17] feat(core): add cancel_generation() to ModelOutputThunk Adds an async cancel_generation() method that cancels in-progress _generate and _generate_extra tasks, drains the internal async queue to release any blocked put() calls, closes the open telemetry span, and sets _computed=True so the MOT is immediately usable. Required by the stream_with_chunking() orchestrator (#901) for clean early-exit when a streaming requirement returns "fail". Assisted-by: Claude Code Signed-off-by: Nigel Jones --- mellea/core/base.py | 58 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/mellea/core/base.py b/mellea/core/base.py index 2028008d9..95b6e8cdc 100644 --- a/mellea/core/base.py +++ b/mellea/core/base.py @@ -364,6 +364,64 @@ def _record_ttfb(self) -> None: ).total_seconds() * 1000 self._first_chunk_received = True + async def cancel_generation(self) -> None: + """Cancel an in-progress streaming generation, drain the queue, and close any open telemetry span. + + Safe to call at any point during streaming. After this method returns, + ``is_computed()`` is ``True`` and ``value`` contains whatever text was + accumulated before cancellation. Calling on an already-computed MOT + is a no-op. + + Draining the internal queue after cancellation is necessary to release + any ``asyncio.Queue.put()`` call that the generation task was blocked on + (queue maxsize=20). 
+ """ + if self._computed: + return + + def _drain() -> None: + while not self._async_queue.empty(): + try: + self._async_queue.get_nowait() + except asyncio.QueueEmpty: + break + + if self._generate is not None and not self._generate.done(): + self._generate.cancel() + + if self._generate_extra is not None and not self._generate_extra.done(): + self._generate_extra.cancel() + + # Drain before awaiting — unblocks any put() the task is stuck on. + _drain() + + if self._generate is not None: + try: + await self._generate + except (asyncio.CancelledError, Exception): + pass + + if self._generate_extra is not None: + try: + await self._generate_extra + except (asyncio.CancelledError, Exception): + pass + + # Drain again for any final item the task put before terminating. + _drain() + + span = self._meta.get("_telemetry_span") + if span is not None: + from ..telemetry import end_backend_span, set_span_error + + set_span_error(span, RuntimeError("Generation cancelled")) + end_backend_span(span) + del self._meta["_telemetry_span"] + + if self._underlying_value is None: + self._underlying_value = "" + self._computed = True + def _copy_from(self, other: ModelOutputThunk) -> None: """Copy computed-output fields from *other* into *self*. From f26cce729ff3d9b3b8e91952e6acb5dc2b50d9db Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Mon, 27 Apr 2026 15:48:34 +0100 Subject: [PATCH 02/17] feat(stdlib): add stream_with_chunking() with per-chunk validation (#901) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds stream_with_chunking() — the core streaming orchestration primitive that consumes a ModelOutputThunk's async stream via a background asyncio.Task, applies a ChunkingStrategy to produce semantic chunks, and runs stream_validate() in parallel across all requirements at each chunk boundary. 
Key behaviours: - Early exit: if any requirement returns "fail" during streaming, generation is cancelled immediately via cancel_generation() and StreamChunkingResult.completed is set to False. - Final validation: after natural completion, validate() is called on all non-failed requirements. - Clone-per-call: requirements are cloned (copy(req)) before each invocation; originals are never mutated. - String aliases: "sentence", "word", "paragraph" map to the corresponding ChunkingStrategy subclasses. StreamChunkingResult exposes: - astream() — async iterator yielding individual validated chunks - acomplete() — await full completion including final validation - as_thunk — wrap full_text as a computed ModelOutputThunk - completed, full_text, final_validations, streaming_failures Re-exports StreamChunkingResult and stream_with_chunking from mellea.stdlib for day-to-day use. Assisted-by: Claude Code Signed-off-by: Nigel Jones --- mellea/stdlib/__init__.py | 15 +- mellea/stdlib/streaming.py | 282 +++++++++++++++++++++++++++++++++++++ 2 files changed, 295 insertions(+), 2 deletions(-) create mode 100644 mellea/stdlib/streaming.py diff --git a/mellea/stdlib/__init__.py b/mellea/stdlib/__init__.py index e4f32941b..7a30fdd53 100644 --- a/mellea/stdlib/__init__.py +++ b/mellea/stdlib/__init__.py @@ -10,9 +10,20 @@ ``mellea.stdlib.session`` — for day-to-day use. Streaming chunking strategies (for use with streaming validation) are available at -``mellea.stdlib.chunking`` and re-exported here for convenience. +``mellea.stdlib.chunking`` and re-exported here for convenience. The core streaming +orchestration primitive :func:`~mellea.stdlib.streaming.stream_with_chunking` and +its result type :class:`~mellea.stdlib.streaming.StreamChunkingResult` are also +re-exported here. 
""" from .chunking import ChunkingStrategy, ParagraphChunker, SentenceChunker, WordChunker +from .streaming import StreamChunkingResult, stream_with_chunking -__all__ = ["ChunkingStrategy", "ParagraphChunker", "SentenceChunker", "WordChunker"] +__all__ = [ + "ChunkingStrategy", + "ParagraphChunker", + "SentenceChunker", + "StreamChunkingResult", + "WordChunker", + "stream_with_chunking", +] diff --git a/mellea/stdlib/streaming.py b/mellea/stdlib/streaming.py new file mode 100644 index 000000000..0da2202ad --- /dev/null +++ b/mellea/stdlib/streaming.py @@ -0,0 +1,282 @@ +"""Streaming generation with per-chunk validation. + +Provides :func:`stream_with_chunking`, the core orchestration primitive that +consumes a streaming :class:`~mellea.core.base.ModelOutputThunk`, applies a +:class:`~mellea.stdlib.chunking.ChunkingStrategy` to produce semantic chunks, +and runs :meth:`~mellea.core.requirement.Requirement.stream_validate` on each +chunk in parallel. Higher-level streaming APIs build on this function. +""" + +import asyncio +from collections.abc import AsyncIterator, Sequence +from copy import copy +from typing import Any + +from ..backends.model_options import ModelOption +from ..core.backend import Backend +from ..core.base import CBlock, Component, Context, ModelOutputThunk +from ..core.requirement import PartialValidationResult, Requirement, ValidationResult +from .chunking import ChunkingStrategy, ParagraphChunker, SentenceChunker, WordChunker + +_CHUNKING_ALIASES: dict[str, type[ChunkingStrategy]] = { + "sentence": SentenceChunker, + "word": WordChunker, + "paragraph": ParagraphChunker, +} + + +class StreamChunkingResult: + """Result of a :func:`stream_with_chunking` operation. + + Provides async iteration over validated text chunks as they complete + (:meth:`astream`), a blocking :meth:`acomplete` for awaiting the full + result including final validation, and :attr:`as_thunk` for wrapping the + output as a :class:`~mellea.core.base.ModelOutputThunk`. 
+ + Instances are created by :func:`stream_with_chunking`; do not instantiate + directly. + + Attributes: + completed: ``False`` if the stream exited early because a requirement + returned ``"fail"`` during streaming; ``True`` otherwise. + full_text: The complete generated text accumulated during streaming. + Available after :meth:`acomplete` returns. + final_validations: :class:`~mellea.core.requirement.ValidationResult` + objects from the final :meth:`~mellea.core.requirement.Requirement.validate` + calls on all non-failed requirements. Available after + :meth:`acomplete` returns. + streaming_failures: ``(Requirement, PartialValidationResult)`` pairs + for every requirement that returned ``"fail"`` during streaming. + """ + + def __init__(self, mot: ModelOutputThunk, ctx: Context) -> None: + """Initialise with the MOT and context from the backend call.""" + self._mot = mot + self._ctx = ctx + self._chunk_queue: asyncio.Queue[str | None | Exception] = asyncio.Queue() + self._orchestration_task: asyncio.Task[None] | None = None + self._done = asyncio.Event() + + self.completed: bool = True + self.full_text: str = "" + self.final_validations: list[ValidationResult] = [] + self.streaming_failures: list[tuple[Requirement, PartialValidationResult]] = [] + + async def astream(self) -> AsyncIterator[str]: + """Yield validated text chunks as they complete. + + Each yielded string is a chunk that has passed per-chunk streaming + validation (or the stream had no requirements). Iteration ends when + all chunks have been yielded, whether the stream completed normally or + was cancelled early on a ``"fail"`` result. + + Yields: + str: A validated text chunk from the chunking strategy. + + Raises: + Exception: Propagates any error from the background orchestration + task. 
+ """ + while True: + item = await self._chunk_queue.get() + if item is None: + return + if isinstance(item, Exception): + raise item + yield item + + async def acomplete(self) -> None: + """Await full completion, including final validation. + + After this method returns, :attr:`full_text`, :attr:`completed`, + :attr:`final_validations`, and :attr:`streaming_failures` are all + populated. If :meth:`astream` has already been consumed to + exhaustion, this call is effectively a no-op. + + Raises: + Exception: Propagates any error from the orchestration task. + """ + await self._done.wait() + if self._orchestration_task is not None and self._orchestration_task.done(): + exc = self._orchestration_task.exception() + if exc is not None: + raise exc + + @property + def as_thunk(self) -> ModelOutputThunk: + """Wrap the output as a computed :class:`~mellea.core.base.ModelOutputThunk`. + + Returns a new thunk with ``value`` set to :attr:`full_text` and + generation metadata copied from the original MOT. Safe to call on + early-exit results; ``value`` will reflect whatever was accumulated + before cancellation. + + Returns: + ModelOutputThunk: A computed thunk containing the streamed output. + + Raises: + RuntimeError: If called before :meth:`acomplete` has returned. 
+ """ + if not self._done.is_set(): + raise RuntimeError( + "as_thunk accessed before acomplete() — await acomplete() first" + ) + thunk = ModelOutputThunk(value=self.full_text) + thunk.generation = copy(self._mot.generation) + return thunk + + +async def _orchestrate_streaming( + result: StreamChunkingResult, + mot: ModelOutputThunk, + ctx: Context, + cloned_reqs: list[Requirement], + chunking: ChunkingStrategy, + val_backend: Backend, +) -> None: + accumulated = "" + prev_chunk_count = 0 + failed_indices: set[int] = set() + early_exit = False + + try: + while not mot.is_computed(): + try: + delta = await mot.astream() + except RuntimeError: + break + + accumulated += delta + chunks = chunking.split(accumulated) + new_chunks = chunks[prev_chunk_count:] + + if new_chunks: + active = [ + (i, req) + for i, req in enumerate(cloned_reqs) + if i not in failed_indices + ] + if active: + pvrs: list[PartialValidationResult] = list( + await asyncio.gather( + *[ + req.stream_validate( + accumulated, backend=val_backend, ctx=ctx + ) + for _, req in active + ] + ) + ) + for (idx, req), pvr in zip(active, pvrs): + if pvr.success == "fail": + failed_indices.add(idx) + result.streaming_failures.append((req, pvr)) + + if failed_indices: + early_exit = True + result.completed = False + await mot.cancel_generation() + for c in new_chunks: + await result._chunk_queue.put(c) + break + + for c in new_chunks: + await result._chunk_queue.put(c) + prev_chunk_count = len(chunks) + + result.full_text = accumulated + + non_failed = [ + req for i, req in enumerate(cloned_reqs) if i not in failed_indices + ] + if non_failed and not early_exit: + result.final_validations = list( + await asyncio.gather( + *[req.validate(val_backend, ctx) for req in non_failed] + ) + ) + + except Exception as exc: + await result._chunk_queue.put(exc) + finally: + await result._chunk_queue.put(None) + result._done.set() + + +async def stream_with_chunking( + action: Component[Any] | CBlock, + backend: Backend, + 
ctx: Context, + *, + quick_check_requirements: Sequence[Requirement] | None = None, + chunking: str | ChunkingStrategy = "sentence", + quick_check_backend: Backend | None = None, +) -> StreamChunkingResult: + """Generate a streaming response with per-chunk validation. + + Starts a backend generation with streaming enabled, consumes the + :class:`~mellea.core.base.ModelOutputThunk`'s async stream in a single + background task, splits the accumulated text using *chunking*, and runs + :meth:`~mellea.core.requirement.Requirement.stream_validate` on each new + chunk in parallel across all requirements. + + If any requirement returns ``"fail"`` during streaming validation, the + generation is cancelled immediately (via + :meth:`~mellea.core.base.ModelOutputThunk.cancel_generation`) and + :attr:`StreamChunkingResult.completed` is set to ``False``. + + After the stream ends naturally (not after an early exit), ``validate()`` is + called on all requirements that did not return ``"fail"``. Requirements + are cloned (``copy(req)``) before use so originals are never mutated. + + ``stream_validate`` receives the *accumulated* model output so far, not + just the current chunk. The chunking strategy determines *when* it is + called (at chunk boundaries). Requirements that want delta-only + processing track ``self._seen_len`` and slice + ``accumulated[self._seen_len:]``. + + Note: + v1 retry is simple re-invocation of this function. Plugin hooks + (``SAMPLING_LOOP_START``, ``SAMPLING_REPAIR``, etc.) do not fire + on retries — use the ``#902`` event types for observability instead. + + Args: + action: The component or content block to generate from. + backend: The backend used for generation and final validation. + ctx: The generation context. + quick_check_requirements: Sequence of requirements to validate against + each chunk during streaming. ``None`` disables streaming validation + (chunks are still produced; ``validate()`` is not called at stream end). 
+ chunking: Chunking strategy — either a :class:`~mellea.stdlib.chunking.ChunkingStrategy` + instance or one of the string aliases ``"sentence"`` (default), + ``"word"``, or ``"paragraph"``. + quick_check_backend: Optional alternate backend for both + ``stream_validate`` and final ``validate`` calls. When ``None``, + *backend* is used for validation. + + Returns: + StreamChunkingResult: A result object providing :meth:`~StreamChunkingResult.astream` + for incremental chunk consumption and + :meth:`~StreamChunkingResult.acomplete` for blocking until done. + """ + if isinstance(chunking, str): + cls = _CHUNKING_ALIASES.get(chunking) + if cls is None: + raise ValueError( + f"Unknown chunking alias {chunking!r}. Choose from: {list(_CHUNKING_ALIASES)}" + ) + chunking = cls() + + opts: dict[str, Any] = {ModelOption.STREAM: True} + mot, gen_ctx = await backend.generate_from_context(action, ctx, model_options=opts) + + result = StreamChunkingResult(mot, gen_ctx) + + cloned_reqs = [copy(req) for req in (quick_check_requirements or [])] + val_backend = quick_check_backend if quick_check_backend is not None else backend + + result._orchestration_task = asyncio.create_task( + _orchestrate_streaming(result, mot, gen_ctx, cloned_reqs, chunking, val_backend) + ) + + return result From 93e75878c7b2774f446cfd66cef47e56ac6a73d8 Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Mon, 27 Apr 2026 15:48:48 +0100 Subject: [PATCH 03/17] test(stdlib): add StreamingMockBackend and streaming orchestration tests Adds test/stdlib/test_streaming.py with 9 unit tests covering: - Normal completion: validate() called at stream end, completed=True - Early exit on "fail": completed=False, streaming_failures populated - Clone isolation: originals never mutated across retries - quick_check_backend routing: validation uses alternate backend - Deadlock prevention: early exit with asyncio.wait_for timeout - as_thunk correctness: value=full_text, raises before acomplete() - astream() yields individual 
chunks (not accumulated text) - No requirements: streams without validation StreamingMockBackend subclasses Backend and feeds a fixed response string into a MOT queue char-by-char via asyncio.create_task, following the create_manual_mock_thunk() pattern from test_astream_mock.py. Assisted-by: Claude Code Signed-off-by: Nigel Jones --- test/stdlib/test_streaming.py | 384 ++++++++++++++++++++++++++++++++++ 1 file changed, 384 insertions(+) create mode 100644 test/stdlib/test_streaming.py diff --git a/test/stdlib/test_streaming.py b/test/stdlib/test_streaming.py new file mode 100644 index 000000000..9284ec1c9 --- /dev/null +++ b/test/stdlib/test_streaming.py @@ -0,0 +1,384 @@ +"""Tests for stream_with_chunking() and StreamChunkingResult. + +Uses StreamingMockBackend — a deterministic test double that feeds tokens from a +fixed response string into a MOT queue without network or LLM calls. + +All tests are unit tests (no @pytest.mark.ollama needed). +""" + +import asyncio +from typing import Any + +import pytest + +from mellea.core.backend import Backend +from mellea.core.base import CBlock, Context, GenerateType, ModelOutputThunk +from mellea.core.requirement import ( + PartialValidationResult, + Requirement, + ValidationResult, +) +from mellea.stdlib.context import SimpleContext +from mellea.stdlib.streaming import stream_with_chunking + +# --------------------------------------------------------------------------- +# StreamingMockBackend +# --------------------------------------------------------------------------- + + +async def _mock_process(mot: ModelOutputThunk, chunk: Any) -> None: + if mot._underlying_value is None: + mot._underlying_value = "" + if chunk is not None: + mot._underlying_value += chunk + + +async def _mock_post_process(_mot: ModelOutputThunk) -> None: + pass + + +def _make_mot() -> ModelOutputThunk: + mot = ModelOutputThunk(value=None) + mot._action = CBlock("mock_action") + mot._generate_type = GenerateType.ASYNC + mot._process = _mock_process 
+ mot._post_process = _mock_post_process + mot._chunk_size = 0 + return mot + + +async def _feed_tokens(mot: ModelOutputThunk, response: str, token_size: int) -> None: + i = 0 + while i < len(response): + token = response[i : i + token_size] + await mot._async_queue.put(token) + await asyncio.sleep(0) + i += token_size + await mot._async_queue.put(None) + + +class StreamingMockBackend(Backend): + """Test double that streams a fixed response one token at a time. + + ``token_size`` controls how many characters constitute one token. + Validation calls (via ``stream_validate`` / ``validate``) are delegated + to the requirements themselves — this backend does not perform any real + inference. + """ + + def __init__(self, response: str, token_size: int = 1) -> None: + self._response = response + self._token_size = token_size + + async def _generate_from_context( + self, + action: Any, + ctx: Context, + *, + format: Any = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> tuple[ModelOutputThunk, Context]: + _ = format, model_options, tool_calls + mot = _make_mot() + task = asyncio.create_task(_feed_tokens(mot, self._response, self._token_size)) + _ = task + new_ctx = ctx.add(action).add(mot) + return mot, new_ctx + + async def generate_from_raw( + self, actions: Any, ctx: Any, **kwargs: Any + ) -> list[ModelOutputThunk]: + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# Requirement test doubles +# --------------------------------------------------------------------------- + + +class AlwaysUnknownReq(Requirement): + """stream_validate always returns 'unknown'; validate returns True.""" + + def format_for_llm(self) -> str: + return "always unknown" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + return PartialValidationResult("unknown") + + async def validate( + self, backend: Any, ctx: Any, *, format: Any = None, 
model_options: Any = None + ) -> ValidationResult: + return ValidationResult(result=True) + + +class FailAfterWordsReq(Requirement): + """Returns 'fail' once the accumulated text reaches *threshold* words.""" + + def __init__(self, threshold: int) -> None: + self._threshold = threshold + + def format_for_llm(self) -> str: + return f"fail after {self._threshold} words" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + if len(chunk.split()) >= self._threshold: + return PartialValidationResult("fail", reason="too many words") + return PartialValidationResult("unknown") + + async def validate( + self, backend: Any, ctx: Any, *, format: Any = None, model_options: Any = None + ) -> ValidationResult: + return ValidationResult(result=True) + + +class BackendRecordingReq(Requirement): + """Records which backend was passed to stream_validate and validate.""" + + def __init__(self) -> None: + self.seen_backends: list[Any] = [] + + def __copy__(self) -> "BackendRecordingReq": + clone = BackendRecordingReq() + clone.seen_backends = [] # fresh list — do not share with original + return clone + + def format_for_llm(self) -> str: + return "backend recorder" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + _ = chunk + self.seen_backends.append(backend) + return PartialValidationResult("unknown") + + async def validate( + self, backend: Any, ctx: Any, *, format: Any = None, model_options: Any = None + ) -> ValidationResult: + self.seen_backends.append(backend) + return ValidationResult(result=True) + + +class MutationDetectorReq(Requirement): + """Tracks how many times stream_validate was called on this instance.""" + + def __init__(self) -> None: + self._call_count = 0 + + def format_for_llm(self) -> str: + return "mutation detector" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + _ = chunk, 
backend, ctx + self._call_count += 1 + return PartialValidationResult("unknown") + + async def validate( + self, backend: Any, ctx: Any, *, format: Any = None, model_options: Any = None + ) -> ValidationResult: + return ValidationResult(result=True) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _ctx() -> SimpleContext: + return SimpleContext() + + +def _action() -> CBlock: + return CBlock("prompt") + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_normal_completion_calls_validate_at_stream_end() -> None: + """All 'unknown' requirements → validate() called at stream end; completed=True.""" + response = "Hello world. How are you. " + backend = StreamingMockBackend(response, token_size=3) + req = AlwaysUnknownReq() + + result = await stream_with_chunking( + _action(), backend, _ctx(), quick_check_requirements=[req], chunking="sentence" + ) + await result.acomplete() + + assert result.completed is True + assert result.full_text == response + assert len(result.final_validations) == 1 + assert result.final_validations[0].as_bool() is True + assert result.streaming_failures == [] + + +@pytest.mark.asyncio +async def test_early_exit_on_fail() -> None: + """Requirement fails mid-stream → completed=False, streaming_failures populated.""" + # 5 words to trigger failure + response = "one two three four five six seven eight. 
" + backend = StreamingMockBackend(response, token_size=2) + req = FailAfterWordsReq(threshold=4) + + result = await stream_with_chunking( + _action(), backend, _ctx(), quick_check_requirements=[req], chunking="word" + ) + await result.acomplete() + + assert result.completed is False + assert len(result.streaming_failures) == 1 + _req, pvr = result.streaming_failures[0] + assert pvr.success == "fail" + assert pvr.reason == "too many words" + # final_validations should be empty — final validate() skipped on early exit + assert result.final_validations == [] + + +@pytest.mark.asyncio +async def test_clone_isolation_across_retries() -> None: + """Originals must not be mutated; two invocations are independent.""" + response = "Sentence one. Sentence two. " + req = MutationDetectorReq() + original_reqs = [req] + + backend = StreamingMockBackend(response, token_size=4) + + r1 = await stream_with_chunking( + _action(), + backend, + _ctx(), + quick_check_requirements=original_reqs, + chunking="sentence", + ) + await r1.acomplete() + + r2 = await stream_with_chunking( + _action(), + backend, + _ctx(), + quick_check_requirements=original_reqs, + chunking="sentence", + ) + await r2.acomplete() + + # Original requirement must never have been called — only clones are used + assert req._call_count == 0 + + +@pytest.mark.asyncio +async def test_quick_check_backend_routing() -> None: + """stream_validate and validate receive quick_check_backend, not the main backend.""" + response = "One sentence. Two sentences. " + main_backend = StreamingMockBackend(response, token_size=3) + val_backend = StreamingMockBackend("unused", token_size=1) + + req = BackendRecordingReq() + + result = await stream_with_chunking( + _action(), + main_backend, + _ctx(), + quick_check_requirements=[req], + chunking="sentence", + quick_check_backend=val_backend, + ) + await result.acomplete() + + # The clone's seen_backends should only contain val_backend + # (The original req was never called; clones were.) 
+ # The clones are not exposed on the result, so only completion is asserted here + assert result.completed is True + # The original req.seen_backends is untouched (clone isolation) + assert req.seen_backends == [] + + +@pytest.mark.asyncio +async def test_early_exit_does_not_deadlock() -> None: + """Early failure with a high-throughput stream must not hang.""" + long_response = "word " * 200 + backend = StreamingMockBackend(long_response, token_size=5) + req = FailAfterWordsReq(threshold=3) + + result = await stream_with_chunking( + _action(), backend, _ctx(), quick_check_requirements=[req], chunking="word" + ) + # 5-second timeout — should complete in milliseconds on success + await asyncio.wait_for(result.acomplete(), timeout=5.0) + + assert result.completed is False + + +@pytest.mark.asyncio +async def test_as_thunk_correctness() -> None: + """as_thunk is computed, value matches full_text, generation metadata preserved.""" + response = "This is a test sentence. " + backend = StreamingMockBackend(response, token_size=4) + + result = await stream_with_chunking(_action(), backend, _ctx(), chunking="sentence") + await result.acomplete() + + thunk = result.as_thunk + assert thunk.is_computed() + assert thunk.value == result.full_text == response + + +@pytest.mark.asyncio +async def test_as_thunk_raises_before_acomplete() -> None: + """as_thunk raises RuntimeError if accessed before acomplete().""" + response = "Some text. " + backend = StreamingMockBackend(response, token_size=2) + + result = await stream_with_chunking(_action(), backend, _ctx(), chunking="sentence") + + with pytest.raises(RuntimeError, match="acomplete"): + _ = result.as_thunk + + +@pytest.mark.asyncio +async def test_astream_yields_individual_chunks() -> None: + """Consumer via astream() receives individual chunks, not accumulated text.""" + response = "First sentence. Second sentence. Third sentence. 
" + backend = StreamingMockBackend(response, token_size=5) + + result = await stream_with_chunking(_action(), backend, _ctx(), chunking="sentence") + + chunks: list[str] = [] + async for chunk in result.astream(): + chunks.append(chunk) + + await result.acomplete() + + # Each chunk must be a complete sentence (not the accumulated text) + assert len(chunks) == 3 + for chunk in chunks: + assert chunk.endswith(".") + # Chunks don't include inter-sentence spaces; joined with a space they appear in full_text + assert " ".join(chunks) in result.full_text + + +@pytest.mark.asyncio +async def test_no_requirements_streams_without_validation() -> None: + """quick_check_requirements=None → chunks produced, no validate() called.""" + response = "Chunk one. Chunk two. Chunk three. " + backend = StreamingMockBackend(response, token_size=3) + + result = await stream_with_chunking( + _action(), backend, _ctx(), quick_check_requirements=None, chunking="sentence" + ) + await result.acomplete() + + assert result.completed is True + assert result.full_text == response + assert result.final_validations == [] + assert result.streaming_failures == [] From a5d358c970bc405aa8ca1ea0f277f42bc8d5a3d2 Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Mon, 27 Apr 2026 15:49:19 +0100 Subject: [PATCH 04/17] docs: add streaming_chunking example (#901) Adds docs/examples/streaming/streaming_chunking.py demonstrating stream_with_chunking() end-to-end: defining a custom stream_validate() override, consuming chunks via astream(), and awaiting acomplete() to inspect final_validations and streaming_failures. 
Assisted-by: Claude Code Signed-off-by: Nigel Jones --- docs/examples/streaming/streaming_chunking.py | 91 +++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 docs/examples/streaming/streaming_chunking.py diff --git a/docs/examples/streaming/streaming_chunking.py b/docs/examples/streaming/streaming_chunking.py new file mode 100644 index 000000000..c11bceb24 --- /dev/null +++ b/docs/examples/streaming/streaming_chunking.py @@ -0,0 +1,91 @@ +# pytest: ollama, integration + +"""Streaming generation with per-chunk validation using stream_with_chunking(). + +Demonstrates: +- Subclassing Requirement to override stream_validate() for early-exit checks +- Calling stream_with_chunking() with sentence-level chunking +- Consuming validated chunks via astream() as they arrive +- Awaiting full completion with acomplete() to access final_validations and full_text +""" + +import asyncio + +from mellea.core.backend import Backend +from mellea.core.base import Context +from mellea.core.requirement import ( + PartialValidationResult, + Requirement, + ValidationResult, +) +from mellea.stdlib.components import Instruction +from mellea.stdlib.streaming import stream_with_chunking + + +class MaxSentencesReq(Requirement): + """Fails if the model generates more than *limit* sentences mid-stream.""" + + def __init__(self, limit: int) -> None: + self._limit = limit + self._count = 0 + + def format_for_llm(self) -> str: + return f"The response must be at most {self._limit} sentences long." 
+ + async def stream_validate( + self, chunk: str, *, backend: Backend, ctx: Context + ) -> PartialValidationResult: + sentence_count = chunk.count(".") + chunk.count("!") + chunk.count("?") + if sentence_count > self._limit: + return PartialValidationResult( + "fail", + reason=f"Response exceeded {self._limit} sentence limit mid-stream", + ) + return PartialValidationResult("unknown") + + async def validate( + self, + backend: Backend, + ctx: Context, + *, + format: type | None = None, + model_options: dict | None = None, + ) -> ValidationResult: + return ValidationResult(result=True) + + +async def main() -> None: + from mellea.stdlib.session import start_session + + m = start_session() + backend = m.backend + ctx = m.ctx + + action = Instruction( + "Write a short paragraph about the water cycle in exactly two sentences." + ) + req = MaxSentencesReq(limit=3) + + result = await stream_with_chunking( + action, backend, ctx, quick_check_requirements=[req], chunking="sentence" + ) + + print("Streaming chunks as they arrive:") + async for chunk in result.astream(): + print(f" CHUNK: {chunk!r}") + + await result.acomplete() + + print(f"\nCompleted normally: {result.completed}") + print(f"Full text: {result.full_text!r}") + + if result.streaming_failures: + for _req, pvr in result.streaming_failures: + print(f"Streaming failure: {pvr.reason}") + + if result.final_validations: + for vr in result.final_validations: + print(f"Final validation: {'PASS' if vr.as_bool() else 'FAIL'}") + + +asyncio.run(main()) From 39f18a4eb6ee61fb43f44d1a07bf5e423ad2a40e Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Tue, 28 Apr 2026 09:40:48 +0100 Subject: [PATCH 05/17] docs(stdlib): add Args section to StreamChunkingResult class docstring MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes [no_class_args] CI failure — the docs build-and-validate checker requires __init__ parameters to be documented in the class docstring (not __init__) per 
Option C convention. Assisted-by: Claude Code --- mellea/stdlib/streaming.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mellea/stdlib/streaming.py b/mellea/stdlib/streaming.py index 0da2202ad..685378511 100644 --- a/mellea/stdlib/streaming.py +++ b/mellea/stdlib/streaming.py @@ -36,6 +36,11 @@ class StreamChunkingResult: Instances are created by :func:`stream_with_chunking`; do not instantiate directly. + Args: + mot: The :class:`~mellea.core.base.ModelOutputThunk` from the backend + generation call. + ctx: The generation context returned alongside the MOT. + Attributes: completed: ``False`` if the stream exited early because a requirement returned ``"fail"`` during streaming; ``True`` otherwise. From 36173cb839f03e9d14f20f51db8996efd9c6fa89 Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Tue, 28 Apr 2026 09:56:37 +0100 Subject: [PATCH 06/17] docs(stdlib): add Raises section to stream_with_chunking docstring MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes second [no_raises] CI failure — stream_with_chunking raises ValueError for unknown chunking aliases but had no Raises: section. Assisted-by: Claude Code --- mellea/stdlib/streaming.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mellea/stdlib/streaming.py b/mellea/stdlib/streaming.py index 685378511..425ab4b3b 100644 --- a/mellea/stdlib/streaming.py +++ b/mellea/stdlib/streaming.py @@ -263,6 +263,10 @@ async def stream_with_chunking( StreamChunkingResult: A result object providing :meth:`~StreamChunkingResult.astream` for incremental chunk consumption and :meth:`~StreamChunkingResult.acomplete` for blocking until done. + + Raises: + ValueError: If *chunking* is a string that does not match any known + alias (``"sentence"``, ``"word"``, ``"paragraph"``). 
""" if isinstance(chunking, str): cls = _CHUNKING_ALIASES.get(chunking) From ea6bdb077d8a9084a6b386344a9f95b3addd8906 Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Tue, 28 Apr 2026 17:07:29 +0100 Subject: [PATCH 07/17] fix(stdlib): stream_with_chunking passes one chunk per stream_validate call Aligns the orchestrator with the chunk-at-a-time design set out in the #891 epic and #900 spec. Previously _orchestrate_streaming passed the full accumulated text to stream_validate, and called it once per batch of new chunks rather than once per chunk. This masked the design intent of the chunking strategy and forced stateful requirements into the self._seen_len workaround. Behaviour changes: - stream_validate is called once per complete chunk produced by the chunking strategy (not once per astream() iteration) - The call receives that single chunk (not the accumulated text) - Multiple chunks from one astream() iteration are validated in order; early exit on a "fail" prevents later chunks in the same batch from being validated or emitted - On early exit, the failing chunk is no longer emitted to the consumer; consumers inspect StreamChunkingResult.streaming_failures instead (previous behaviour emitted whatever the current batch contained) Test changes: - FailAfterWordsReq now maintains a running word count on self, since each stream_validate call sees a one-word chunk rather than the growing accumulation - New test_stream_validate_receives_individual_chunks asserts the per-chunk contract directly by capturing the cloned requirement and checking the chunks it saw Docstring updated to describe the per-chunk contract, the in-order validation of a batch, the non-emission of failing chunks, and the MOT single-consumer constraint. 
Assisted-by: Claude Code --- mellea/stdlib/streaming.py | 41 ++++++++++-------- test/stdlib/test_streaming.py | 80 ++++++++++++++++++++++++++++++++++- 2 files changed, 102 insertions(+), 19 deletions(-) diff --git a/mellea/stdlib/streaming.py b/mellea/stdlib/streaming.py index 425ab4b3b..1e5aca985 100644 --- a/mellea/stdlib/streaming.py +++ b/mellea/stdlib/streaming.py @@ -154,8 +154,9 @@ async def _orchestrate_streaming( accumulated += delta chunks = chunking.split(accumulated) new_chunks = chunks[prev_chunk_count:] + prev_chunk_count = len(chunks) - if new_chunks: + for c in new_chunks: active = [ (i, req) for i, req in enumerate(cloned_reqs) @@ -165,9 +166,7 @@ async def _orchestrate_streaming( pvrs: list[PartialValidationResult] = list( await asyncio.gather( *[ - req.stream_validate( - accumulated, backend=val_backend, ctx=ctx - ) + req.stream_validate(c, backend=val_backend, ctx=ctx) for _, req in active ] ) @@ -181,13 +180,12 @@ async def _orchestrate_streaming( early_exit = True result.completed = False await mot.cancel_generation() - for c in new_chunks: - await result._chunk_queue.put(c) break - for c in new_chunks: - await result._chunk_queue.put(c) - prev_chunk_count = len(chunks) + await result._chunk_queue.put(c) + + if early_exit: + break result.full_text = accumulated @@ -225,20 +223,29 @@ async def stream_with_chunking( :meth:`~mellea.core.requirement.Requirement.stream_validate` on each new chunk in parallel across all requirements. - If any requirement returns ``"fail"`` during streaming validation, the - generation is cancelled immediately (via + For each new complete chunk produced by the chunking strategy, + ``stream_validate`` is called once per active requirement (in parallel + via :func:`asyncio.gather`), receiving that single chunk. 
Multiple + chunks produced from one ``astream()`` iteration are validated + sequentially in order, so early exit on a ``"fail"`` result prevents + later chunks in the same batch from being validated or emitted to the + consumer. + + If any requirement returns ``"fail"``, the generation is cancelled + immediately (via :meth:`~mellea.core.base.ModelOutputThunk.cancel_generation`) and - :attr:`StreamChunkingResult.completed` is set to ``False``. + :attr:`StreamChunkingResult.completed` is set to ``False``. The + failing chunk is not emitted to the consumer; use + :attr:`StreamChunkingResult.streaming_failures` to inspect what failed. After the stream ends (naturally or via early exit), ``validate()`` is called on all requirements that did not return ``"fail"``. Requirements are cloned (``copy(req)``) before use so originals are never mutated. - ``stream_validate`` receives the *accumulated* model output so far, not - just the current chunk. The chunking strategy determines *when* it is - called (at chunk boundaries). Requirements that want delta-only - processing track ``self._seen_len`` and slice - ``accumulated[self._seen_len:]``. + Requirements that need context beyond the current chunk should + accumulate it themselves across ``stream_validate`` calls (e.g. + ``self._seen = self._seen + chunk``). They must not read ``mot.astream()`` + directly — this orchestrator is the single consumer of the MOT stream. Note: v1 retry is simple re-invocation of this function. Plugin hooks diff --git a/test/stdlib/test_streaming.py b/test/stdlib/test_streaming.py index 9284ec1c9..bd03e5ffb 100644 --- a/test/stdlib/test_streaming.py +++ b/test/stdlib/test_streaming.py @@ -115,10 +115,15 @@ async def validate( class FailAfterWordsReq(Requirement): - """Returns 'fail' once the accumulated text reaches *threshold* words.""" + """Returns 'fail' once the cumulative word count reaches *threshold*. 
+ + Each call to ``stream_validate`` receives a single chunk (delta) from the + chunking strategy; the running total is maintained on the instance. + """ def __init__(self, threshold: int) -> None: self._threshold = threshold + self._word_count = 0 def format_for_llm(self) -> str: return f"fail after {self._threshold} words" @@ -126,7 +131,8 @@ def format_for_llm(self) -> str: async def stream_validate( self, chunk: str, *, backend: Any, ctx: Any ) -> PartialValidationResult: - if len(chunk.split()) >= self._threshold: + self._word_count += len(chunk.split()) + if self._word_count >= self._threshold: return PartialValidationResult("fail", reason="too many words") return PartialValidationResult("unknown") @@ -367,6 +373,76 @@ async def test_astream_yields_individual_chunks() -> None: assert " ".join(chunks) in result.full_text +@pytest.mark.asyncio +async def test_stream_validate_receives_individual_chunks() -> None: + """stream_validate is called once per chunk with the chunk itself, not accumulated text.""" + + class ChunkRecordingReq(Requirement): + def __init__(self) -> None: + self.seen_chunks: list[str] = [] + + def __copy__(self) -> "ChunkRecordingReq": + clone = ChunkRecordingReq() + clone.seen_chunks = [] + return clone + + def format_for_llm(self) -> str: + return "chunk recorder" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + self.seen_chunks.append(chunk) + return PartialValidationResult("unknown") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + return ValidationResult(result=True) + + response = "First sentence. Second sentence. Third sentence. " + backend = StreamingMockBackend(response, token_size=4) + req = ChunkRecordingReq() + + # Capture the cloned requirement used by the orchestrator via a side channel. 
+ captured: list[ChunkRecordingReq] = [] + original_copy = ChunkRecordingReq.__copy__ + + def _capturing_copy(self: ChunkRecordingReq) -> ChunkRecordingReq: + clone = original_copy(self) + captured.append(clone) + return clone + + ChunkRecordingReq.__copy__ = _capturing_copy # type: ignore[method-assign] + try: + result = await stream_with_chunking( + _action(), + backend, + _ctx(), + quick_check_requirements=[req], + chunking="sentence", + ) + await result.acomplete() + finally: + ChunkRecordingReq.__copy__ = original_copy # type: ignore[method-assign] + + assert len(captured) == 1 + seen = captured[0].seen_chunks + # Three complete sentences → three separate stream_validate calls. + assert len(seen) == 3 + # Each chunk is one sentence, not a prefix of accumulated text. + for chunk in seen: + assert chunk.endswith(".") + # Lengths must not be monotonically growing (which would indicate accumulated text). + # With per-chunk semantics, each chunk is roughly the same length as one sentence. + assert not all(len(seen[i]) < len(seen[i + 1]) for i in range(len(seen) - 1)) + + @pytest.mark.asyncio async def test_no_requirements_streams_without_validation() -> None: """quick_check_requirements=None → chunks produced, no validate() called.""" From 35df77f0ed3dcba7bdd6d793af2288d60e470d26 Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Tue, 28 Apr 2026 17:27:30 +0100 Subject: [PATCH 08/17] docs(stdlib): fix example for delta semantics and note validator latency Two documentation fixes following the per-chunk semantics correction: - streaming_chunking.py: MaxSentencesReq previously counted sentence-end punctuation in the chunk, which worked under the old accumulated-text behaviour but returns at most 1 per sentence under delta semantics. Rewritten to increment self._count once per chunk -- the canonical pattern for a requirement that needs context beyond a single chunk. 
- stream_with_chunking docstring: add a Note that chunks are emitted to the consumer only after every active validator returns for that chunk. A slow stream_validate (e.g. an LLM-based one) therefore adds latency to every chunk. The invariant preserved is that the consumer never sees unvalidated content; a concurrent-emission fast path may be added in future if a concrete use case calls for it. Assisted-by: Claude Code --- docs/examples/streaming/streaming_chunking.py | 12 +++++++++--- mellea/stdlib/streaming.py | 13 +++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/docs/examples/streaming/streaming_chunking.py b/docs/examples/streaming/streaming_chunking.py index c11bceb24..70037d0a1 100644 --- a/docs/examples/streaming/streaming_chunking.py +++ b/docs/examples/streaming/streaming_chunking.py @@ -23,7 +23,13 @@ class MaxSentencesReq(Requirement): - """Fails if the model generates more than *limit* sentences mid-stream.""" + """Fails if the model generates more than *limit* sentences mid-stream. + + Each ``stream_validate`` call receives one complete sentence from the + :class:`~mellea.stdlib.chunking.SentenceChunker`. The running count is + maintained on ``self`` — this is the standard pattern for requirements + that need context beyond a single chunk. 
+ """ def __init__(self, limit: int) -> None: self._limit = limit @@ -35,8 +41,8 @@ def format_for_llm(self) -> str: async def stream_validate( self, chunk: str, *, backend: Backend, ctx: Context ) -> PartialValidationResult: - sentence_count = chunk.count(".") + chunk.count("!") + chunk.count("?") - if sentence_count > self._limit: + self._count += 1 + if self._count > self._limit: return PartialValidationResult( "fail", reason=f"Response exceeded {self._limit} sentence limit mid-stream", diff --git a/mellea/stdlib/streaming.py b/mellea/stdlib/streaming.py index 1e5aca985..44cec5201 100644 --- a/mellea/stdlib/streaming.py +++ b/mellea/stdlib/streaming.py @@ -247,6 +247,19 @@ async def stream_with_chunking( ``self._seen = self._seen + chunk``). They must not read ``mot.astream()`` directly — this orchestrator is the single consumer of the MOT stream. + Note: + Chunks are emitted to the consumer (via + :meth:`StreamChunkingResult.astream`) only after every requirement's + ``stream_validate`` has returned for that chunk. A slow validator + (for example, one that invokes an LLM) therefore adds latency to + every chunk — the consumer sees a chunk at most as quickly as the + slowest active validator. This trade is deliberate in v1: it + preserves the invariant that the consumer never sees content that + has not been validated, which matters for UIs displaying generated + text live. A future fast-path mode that emits chunks to the + consumer concurrently with validation (at the cost of that + invariant) may be added if a concrete use case calls for it. + Note: v1 retry is simple re-invocation of this function. Plugin hooks (``SAMPLING_LOOP_START``, ``SAMPLING_REPAIR``, etc.) 
do not fire From 61448a90f49a5a7b8a9ab8840e19fac77b14c9fa Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Tue, 28 Apr 2026 17:48:37 +0100 Subject: [PATCH 09/17] feat(stdlib): flush trailing chunk fragment at end of stream MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ChunkingStrategy.split() withholds the trailing fragment by design (#899). Previously the orchestrator discarded it — it appeared in full_text and the final validate() saw it, but it was never yielded to astream() consumers and never seen by stream_validate. For a response that did not end in a chunk terminator (e.g. "Sentence one. Sentence two." with no trailing whitespace under SentenceChunker), the last sentence silently bypassed streaming validation. Adds ChunkingStrategy.flush(accumulated_text) -> list[str]: - Default in the ABC returns [] (backward-compatible — external chunkers retain the old discard behaviour until they opt in). - SentenceChunker, WordChunker, ParagraphChunker each override to return the withheld trailing fragment as a single-element list. _orchestrate_streaming calls chunking.flush(accumulated) after the main loop (only when the stream ended naturally, not on early exit — a cancelled stream's trailing fragment is by definition incomplete). Each flushed chunk goes through the same stream_validate / emit path as regular chunks, so the "no unvalidated content reaches the consumer" invariant extends to the trailing fragment, and a fail on the fragment still records a streaming failure and skips final validate(). Tests: - 13 new chunker tests covering the default-discard behaviour and each built-in's flush logic (empty input, fragment-present, already- terminated cases). - test_trailing_fragment_is_flushed_to_consumer: stream_validate sees the fragment and astream yields it. - test_early_exit_on_trailing_fragment: fail on the flushed fragment propagates to streaming_failures and skips final validation. 
Assisted-by: Claude Code --- mellea/stdlib/chunking.py | 56 ++++++++++++++++ mellea/stdlib/streaming.py | 70 +++++++++++++------- test/stdlib/test_chunking.py | 76 ++++++++++++++++++++++ test/stdlib/test_streaming.py | 118 ++++++++++++++++++++++++++++++++++ 4 files changed, 298 insertions(+), 22 deletions(-) diff --git a/mellea/stdlib/chunking.py b/mellea/stdlib/chunking.py index 6b9091780..d4a5f79e1 100644 --- a/mellea/stdlib/chunking.py +++ b/mellea/stdlib/chunking.py @@ -35,6 +35,27 @@ def split(self, accumulated_text: str) -> list[str]: """ ... + def flush(self, accumulated_text: str) -> list[str]: + """Return any trailing fragment that ``split`` withheld. + + Called once by the orchestrator after the stream has ended naturally + (not on early-exit cancellation). Gives the chunker a chance to + release the final fragment that did not reach a terminator. + + The default implementation returns an empty list — the trailing + fragment is discarded. Built-in chunkers override this to return + the withheld fragment as a single-element list when non-empty. + + Args: + accumulated_text: The full accumulated text at stream end. + + Returns: + The trailing fragment as ``[fragment]`` if it should be treated + as a final chunk, or an empty list to discard it. + """ + _ = accumulated_text + return [] + # Sentence boundary: sentence-ending punctuation, optionally followed by a closing # quote or paren, then whitespace. 
@@ -94,6 +115,19 @@ def split(self, accumulated_text: str) -> list[str]: return chunks + def flush(self, accumulated_text: str) -> list[str]: + """Return the trailing sentence fragment (if any) as a final chunk.""" + if not accumulated_text: + return [] + remaining = accumulated_text + while True: + match = _SENTENCE_BOUNDARY.search(remaining) + if match is None: + break + remaining = remaining[match.end() :].lstrip() + trailing = remaining.strip() + return [trailing] if trailing else [] + class WordChunker(ChunkingStrategy): """Splits accumulated text on whitespace boundaries. @@ -134,6 +168,18 @@ def split(self, accumulated_text: str) -> list[str]: return parts + def flush(self, accumulated_text: str) -> list[str]: + """Return the trailing word fragment (if any) as a final chunk.""" + if not accumulated_text: + return [] + if accumulated_text[-1].isspace(): + return [] + parts = _WHITESPACE.split(accumulated_text) + for part in reversed(parts): + if part: + return [part] + return [] + class ParagraphChunker(ChunkingStrategy): r"""Splits accumulated text on double-newline paragraph boundaries. @@ -168,3 +214,13 @@ def split(self, accumulated_text: str) -> list[str]: # _PARA_BOUNDARY.split on leading \n\n produces an empty first element. 
return [p for p in parts if p] + + def flush(self, accumulated_text: str) -> list[str]: + """Return the trailing paragraph fragment (if any) as a final chunk.""" + if not accumulated_text: + return [] + if _PARA_BOUNDARY_END.search(accumulated_text): + return [] + parts = _PARA_BOUNDARY.split(accumulated_text) + trailing = parts[-1] if parts else "" + return [trailing] if trailing else [] diff --git a/mellea/stdlib/streaming.py b/mellea/stdlib/streaming.py index 44cec5201..0346fb519 100644 --- a/mellea/stdlib/streaming.py +++ b/mellea/stdlib/streaming.py @@ -144,6 +144,35 @@ async def _orchestrate_streaming( failed_indices: set[int] = set() early_exit = False + async def _validate_and_emit(c: str) -> bool: + """Run stream_validate on chunk c across active requirements. + + Returns True if a failure was recorded (caller should early-exit), + False otherwise (chunk was emitted to the consumer queue). + """ + active = [ + (i, req) for i, req in enumerate(cloned_reqs) if i not in failed_indices + ] + if active: + pvrs: list[PartialValidationResult] = list( + await asyncio.gather( + *[ + req.stream_validate(c, backend=val_backend, ctx=ctx) + for _, req in active + ] + ) + ) + for (idx, req), pvr in zip(active, pvrs): + if pvr.success == "fail": + failed_indices.add(idx) + result.streaming_failures.append((req, pvr)) + + if failed_indices: + return True + + await result._chunk_queue.put(c) + return False + try: while not mot.is_computed(): try: @@ -157,36 +186,27 @@ async def _orchestrate_streaming( prev_chunk_count = len(chunks) for c in new_chunks: - active = [ - (i, req) - for i, req in enumerate(cloned_reqs) - if i not in failed_indices - ] - if active: - pvrs: list[PartialValidationResult] = list( - await asyncio.gather( - *[ - req.stream_validate(c, backend=val_backend, ctx=ctx) - for _, req in active - ] - ) - ) - for (idx, req), pvr in zip(active, pvrs): - if pvr.success == "fail": - failed_indices.add(idx) - result.streaming_failures.append((req, pvr)) - - if 
failed_indices: + failed = await _validate_and_emit(c) + if failed: early_exit = True result.completed = False await mot.cancel_generation() break - await result._chunk_queue.put(c) - if early_exit: break + # Stream ended naturally: flush any withheld trailing fragment and + # run stream_validate on it. Skipped on early exit — the generation + # was cancelled, the trailing fragment is incomplete. + if not early_exit: + for c in chunking.flush(accumulated): + failed = await _validate_and_emit(c) + if failed: + early_exit = True + result.completed = False + break + result.full_text = accumulated non_failed = [ @@ -238,6 +258,12 @@ async def stream_with_chunking( failing chunk is not emitted to the consumer; use :attr:`StreamChunkingResult.streaming_failures` to inspect what failed. + When the stream ends naturally, any trailing fragment withheld by the + chunking strategy (see :meth:`~mellea.stdlib.chunking.ChunkingStrategy.flush`) + is released as a final chunk and run through ``stream_validate`` on the + same terms as the regular chunks. On early exit, the trailing fragment + is discarded because the generation was cancelled mid-token. + After the stream ends (naturally or via early exit), ``validate()`` is called on all requirements that did not return ``"fail"``. Requirements are cloned (``copy(req)``) before use so originals are never mutated. 
diff --git a/test/stdlib/test_chunking.py b/test/stdlib/test_chunking.py index fbaf727a2..7b965350f 100644 --- a/test/stdlib/test_chunking.py +++ b/test/stdlib/test_chunking.py @@ -242,3 +242,79 @@ def test_paragraph_chunker_incremental_simulation(): "First paragraph.", "Second paragraph.", ] + + +# --------------------------------------------------------------------------- +# flush() — trailing-fragment release at end of stream +# --------------------------------------------------------------------------- + + +def test_default_flush_returns_empty_list(): + """The ABC default discards the trailing fragment.""" + + class Minimal(ChunkingStrategy): + def split(self, accumulated_text: str) -> list[str]: + _ = accumulated_text + return [] + + assert Minimal().flush("anything at all") == [] + assert Minimal().flush("") == [] + + +def test_sentence_chunker_flush_empty(): + assert SentenceChunker().flush("") == [] + + +def test_sentence_chunker_flush_only_complete(): + """All text ends in a complete sentence with trailing whitespace → no fragment.""" + assert SentenceChunker().flush("One. Two. ") == [] + + +def test_sentence_chunker_flush_trailing_fragment(): + """Final sentence without trailing whitespace is released by flush.""" + assert SentenceChunker().flush("One. Two without period") == ["Two without period"] + + +def test_sentence_chunker_flush_terminated_no_trailing_space(): + """Final sentence with terminator but no trailing whitespace is a fragment + under split() semantics and gets released by flush().""" + assert SentenceChunker().flush("One. 
Two.") == ["Two."] + + +def test_sentence_chunker_flush_single_sentence_no_terminator(): + assert SentenceChunker().flush("Incomplete sentence") == ["Incomplete sentence"] + + +def test_word_chunker_flush_empty(): + assert WordChunker().flush("") == [] + + +def test_word_chunker_flush_trailing_whitespace(): + """Trailing whitespace means all words are complete → no fragment.""" + assert WordChunker().flush("one two three ") == [] + + +def test_word_chunker_flush_trailing_fragment(): + assert WordChunker().flush("one two three") == ["three"] + + +def test_word_chunker_flush_single_word(): + assert WordChunker().flush("solo") == ["solo"] + + +def test_paragraph_chunker_flush_empty(): + assert ParagraphChunker().flush("") == [] + + +def test_paragraph_chunker_flush_only_complete(): + assert ParagraphChunker().flush("Para one.\n\nPara two.\n\n") == [] + + +def test_paragraph_chunker_flush_trailing_fragment(): + assert ParagraphChunker().flush("Para one.\n\nPara two (no sep)") == [ + "Para two (no sep)" + ] + + +def test_paragraph_chunker_flush_single_paragraph_no_separator(): + assert ParagraphChunker().flush("Only paragraph") == ["Only paragraph"] diff --git a/test/stdlib/test_streaming.py b/test/stdlib/test_streaming.py index bd03e5ffb..7c3d97793 100644 --- a/test/stdlib/test_streaming.py +++ b/test/stdlib/test_streaming.py @@ -443,6 +443,124 @@ def _capturing_copy(self: ChunkRecordingReq) -> ChunkRecordingReq: assert not all(len(seen[i]) < len(seen[i + 1]) for i in range(len(seen) - 1)) +@pytest.mark.asyncio +async def test_trailing_fragment_is_flushed_to_consumer() -> None: + """Response without trailing whitespace: final sentence reaches astream() and stream_validate.""" + + class ChunkRecordingReq(Requirement): + def __init__(self) -> None: + self.seen_chunks: list[str] = [] + + def __copy__(self) -> "ChunkRecordingReq": + clone = ChunkRecordingReq() + clone.seen_chunks = [] + return clone + + def format_for_llm(self) -> str: + return "chunk recorder" + + async 
def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + self.seen_chunks.append(chunk) + return PartialValidationResult("unknown") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + return ValidationResult(result=True) + + # No trailing whitespace after the final sentence — SentenceChunker withholds it. + response = "First sentence. Second sentence." + backend = StreamingMockBackend(response, token_size=4) + req = ChunkRecordingReq() + + captured: list[ChunkRecordingReq] = [] + original_copy = ChunkRecordingReq.__copy__ + + def _capturing_copy(self: ChunkRecordingReq) -> ChunkRecordingReq: + clone = original_copy(self) + captured.append(clone) + return clone + + ChunkRecordingReq.__copy__ = _capturing_copy # type: ignore[method-assign] + try: + result = await stream_with_chunking( + _action(), + backend, + _ctx(), + quick_check_requirements=[req], + chunking="sentence", + ) + yielded: list[str] = [] + async for chunk in result.astream(): + yielded.append(chunk) + await result.acomplete() + finally: + ChunkRecordingReq.__copy__ = original_copy # type: ignore[method-assign] + + # Both sentences reach the consumer, including the terminating one without trailing whitespace. + assert yielded == ["First sentence.", "Second sentence."] + # stream_validate was called on both — the flush path is not a shortcut. 
+ assert captured[0].seen_chunks == ["First sentence.", "Second sentence."] + assert result.completed is True + + +@pytest.mark.asyncio +async def test_early_exit_on_trailing_fragment() -> None: + """A fail on the flushed fragment records a streaming failure and skips final validate().""" + + class FailOnSecondSentence(Requirement): + def __init__(self) -> None: + self._count = 0 + + def format_for_llm(self) -> str: + return "fail on second sentence" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + _ = chunk, backend, ctx + self._count += 1 + if self._count >= 2: + return PartialValidationResult("fail", reason="second sentence hit") + return PartialValidationResult("unknown") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + return ValidationResult(result=True) + + response = "First sentence. Second sentence." + backend = StreamingMockBackend(response, token_size=4) + req = FailOnSecondSentence() + + result = await stream_with_chunking( + _action(), backend, _ctx(), quick_check_requirements=[req], chunking="sentence" + ) + yielded: list[str] = [] + async for chunk in result.astream(): + yielded.append(chunk) + await result.acomplete() + + assert result.completed is False + assert len(result.streaming_failures) == 1 + # First sentence was emitted; second (the flushed fragment) failed and wasn't emitted. + assert yielded == ["First sentence."] + # Early exit on fail skips final validate(). 
+ assert result.final_validations == [] + + @pytest.mark.asyncio async def test_no_requirements_streams_without_validation() -> None: """quick_check_requirements=None → chunks produced, no validate() called.""" From def10b6f87ce30c2b14901b9ba99acb083df2d1a Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Tue, 28 Apr 2026 18:23:10 +0100 Subject: [PATCH 10/17] fix(stdlib): address review feedback on streaming validation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses issues raised by independent review on top of PR #942. Orchestrator (mellea/stdlib/streaming.py): - except Exception now calls mot.cancel_generation() before surfacing the exception to the consumer — previously the backend producer was left running, eventually blocking on mot._async_queue (maxsize=20). Cleanup failures are logged via MelleaLogger.warning with a TODO(#902) marker; #902 replaces the log with a proper ErrorEvent. - RuntimeError catch in the astream() loop now re-raises unless mot.is_computed() is true, so only the documented "already computed" race is swallowed. - astream() docstring now states the single-consumer contract explicitly; a second iteration blocks on an empty queue with no sentinel to deliver. - as_thunk docstring now flags the early-exit case: cancel_generation forces is_computed=True without running post_processing(), so generation.usage and related telemetry fields may be None. Chunker (mellea/stdlib/chunking.py): - SentenceChunker.flush switches from .strip() to .rstrip() with a comment explaining why: the loop's lstrip has already removed leading whitespace, and trailing whitespace on a sentence fragment is non-semantic (consistent with split() returning sentences without trailing whitespace). - ParagraphChunker.flush adds a docstring noting the deliberate asymmetry: paragraph fragments are returned byte-for-byte because internal whitespace (e.g. trailing \n of a list item) can be semantically meaningful. 
Tests (test/stdlib/test_streaming.py): - test_stream_validate_receives_individual_chunks now uses exact- match on the captured chunk list, which directly regresses if someone reverts to accumulated-text semantics. - test_multiple_chunks_in_one_batch_with_mid_batch_fail: response fed as one large token so split() yields 4 sentences at once; verifies chunk 1 emits, chunk 2 fails (not emitted), chunks 3 and 4 are neither validated nor emitted. - test_cancel_generation_invoked_on_fail: spies on ModelOutputThunk.cancel_generation and asserts it was called on the "fail" early-exit path. - test_exception_in_stream_validate_cancels_generation: a requirement that raises must cause cancel_generation to run and the exception to surface via astream()/acomplete() without hanging. Telemetry observability (orchestrator-level spans, metrics, span events) remains deferred to #902 per the epic, which now has the acceptance criteria updated to cover event emission, the OTEL bridge, and the ErrorEvent type that will replace the MelleaLogger stopgap. Assisted-by: Claude Code --- mellea/stdlib/chunking.py | 22 +++- mellea/stdlib/streaming.py | 42 ++++++- test/stdlib/test_streaming.py | 201 ++++++++++++++++++++++++++++++++-- 3 files changed, 253 insertions(+), 12 deletions(-) diff --git a/mellea/stdlib/chunking.py b/mellea/stdlib/chunking.py index d4a5f79e1..fb3521ec2 100644 --- a/mellea/stdlib/chunking.py +++ b/mellea/stdlib/chunking.py @@ -116,7 +116,15 @@ def split(self, accumulated_text: str) -> list[str]: return chunks def flush(self, accumulated_text: str) -> list[str]: - """Return the trailing sentence fragment (if any) as a final chunk.""" + """Return the trailing sentence fragment (if any) as a final chunk. + + Trailing whitespace on the fragment is non-semantic for sentence + boundaries and is dropped via ``rstrip``. Leading whitespace is + already removed by the loop's ``lstrip`` on each advance, so no + ``lstrip`` is needed here. 
The result is the fragment's content + only, consistent with how :meth:`split` returns sentences without + trailing whitespace. + """ if not accumulated_text: return [] remaining = accumulated_text @@ -125,7 +133,7 @@ def flush(self, accumulated_text: str) -> list[str]: if match is None: break remaining = remaining[match.end() :].lstrip() - trailing = remaining.strip() + trailing = remaining.rstrip() return [trailing] if trailing else [] @@ -216,7 +224,15 @@ def split(self, accumulated_text: str) -> list[str]: return [p for p in parts if p] def flush(self, accumulated_text: str) -> list[str]: - """Return the trailing paragraph fragment (if any) as a final chunk.""" + r"""Return the trailing paragraph fragment (if any) as a final chunk. + + Unlike :class:`SentenceChunker.flush`, the fragment is returned + byte-for-byte without stripping. Internal whitespace — including + a trailing single ``\n`` — can be semantically meaningful inside + a paragraph (e.g. a list item or a deliberate line break), and a + consumer validating paragraph content should see the fragment as + it was withheld. + """ if not accumulated_text: return [] if _PARA_BOUNDARY_END.search(accumulated_text): diff --git a/mellea/stdlib/streaming.py b/mellea/stdlib/streaming.py index 0346fb519..267e14848 100644 --- a/mellea/stdlib/streaming.py +++ b/mellea/stdlib/streaming.py @@ -16,6 +16,7 @@ from ..core.backend import Backend from ..core.base import CBlock, Component, Context, ModelOutputThunk from ..core.requirement import PartialValidationResult, Requirement, ValidationResult +from ..core.utils import MelleaLogger from .chunking import ChunkingStrategy, ParagraphChunker, SentenceChunker, WordChunker _CHUNKING_ALIASES: dict[str, type[ChunkingStrategy]] = { @@ -75,6 +76,14 @@ async def astream(self) -> AsyncIterator[str]: all chunks have been yielded, whether the stream completed normally or was cancelled early on a ``"fail"`` result. 
+ **Single-consumer.** Chunks are delivered via an + :class:`asyncio.Queue` that this method drains; calling + ``astream()`` a second time on the same result blocks indefinitely + because the queue is empty and the terminating ``None`` sentinel + has already been consumed. If you need the chunks after + iteration, capture them into a list during the first pass or use + :attr:`full_text` after :meth:`acomplete`. + Yields: str: A validated text chunk from the chunking strategy. @@ -116,6 +125,15 @@ def as_thunk(self) -> ModelOutputThunk: early-exit results; ``value`` will reflect whatever was accumulated before cancellation. + Note: + On early exit, ``cancel_generation()`` forces the MOT into a + computed state without running the backend's + ``post_processing()``. Telemetry fields on the returned thunk + (``generation.usage``, ``generation.ttfb_ms``, etc.) may + therefore be ``None`` or reflect the partial state at + cancellation time. ``value`` and ``streaming`` are reliable; + usage totals are not. + Returns: ModelOutputThunk: A computed thunk containing the streamed output. @@ -178,7 +196,12 @@ async def _validate_and_emit(c: str) -> bool: try: delta = await mot.astream() except RuntimeError: - break + # Expected race: mot.is_computed() was False at the top of the + # loop but the stream finished before we re-entered astream(). + # Any other RuntimeError is a real bug and must propagate. + if mot.is_computed(): + break + raise accumulated += delta chunks = chunking.split(accumulated) @@ -220,6 +243,23 @@ async def _validate_and_emit(c: str) -> bool: ) except Exception as exc: + # Orchestrator is leaving — we must stop the backend producer too, + # otherwise mot._async_queue (maxsize=20) fills and the feeder task + # blocks indefinitely. The spec (#891, #901) calls this out for the + # "fail" path; the same reasoning applies to any unplanned exit. 
+ try: + await mot.cancel_generation() + except Exception as cleanup_exc: + # Never let cleanup mask the original exception: log loudly and + # continue to surface `exc` to the consumer. + # TODO(#902): replace this log with an ErrorEvent emission. + MelleaLogger.get_logger().warning( + "stream_with_chunking: cancel_generation() raised during " + "exception cleanup (original: %r, cleanup: %r)", + exc, + cleanup_exc, + ) + result.completed = False await result._chunk_queue.put(exc) finally: await result._chunk_queue.put(None) diff --git a/test/stdlib/test_streaming.py b/test/stdlib/test_streaming.py index 7c3d97793..d52ce18a1 100644 --- a/test/stdlib/test_streaming.py +++ b/test/stdlib/test_streaming.py @@ -433,14 +433,12 @@ def _capturing_copy(self: ChunkRecordingReq) -> ChunkRecordingReq: assert len(captured) == 1 seen = captured[0].seen_chunks - # Three complete sentences → three separate stream_validate calls. - assert len(seen) == 3 - # Each chunk is one sentence, not a prefix of accumulated text. - for chunk in seen: - assert chunk.endswith(".") - # Lengths must not be monotonically growing (which would indicate accumulated text). - # With per-chunk semantics, each chunk is roughly the same length as one sentence. - assert not all(len(seen[i]) < len(seen[i + 1]) for i in range(len(seen) - 1)) + # Exact match: three separate calls, one per complete sentence, + # each call receiving that sentence and nothing more. Under the old + # accumulated-text semantics, seen would have been + # ["First sentence.", "First sentence. Second sentence.", ...] — + # exact match against the per-chunk list is the direct regression guard. 
+ assert seen == ["First sentence.", "Second sentence.", "Third sentence."] @pytest.mark.asyncio @@ -576,3 +574,190 @@ async def test_no_requirements_streams_without_validation() -> None: assert result.full_text == response assert result.final_validations == [] assert result.streaming_failures == [] + + +@pytest.mark.asyncio +async def test_multiple_chunks_in_one_batch_with_mid_batch_fail() -> None: + """When one astream() delta produces several complete chunks and one in + the middle fails, earlier chunks emit, failing chunk is recorded, later + chunks are neither validated nor emitted.""" + + captured: list[Any] = [] + + class FailOnNthChunk(Requirement): + def __init__(self, n: int) -> None: + self._n = n + self._calls = 0 + self.seen: list[str] = [] + + def __copy__(self) -> "FailOnNthChunk": + clone = FailOnNthChunk(self._n) + captured.append(clone) + return clone + + def format_for_llm(self) -> str: + return f"fail on chunk {self._n}" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + _ = backend, ctx + self._calls += 1 + self.seen.append(chunk) + if self._calls == self._n: + return PartialValidationResult("fail", reason=f"n={self._n}") + return PartialValidationResult("unknown") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + _ = backend, ctx, format, model_options + return ValidationResult(result=True) + + # token_size larger than the whole response → one astream() delta delivers + # the full text, so chunking.split produces 4 sentences in a single batch. + response = "One. Two. Three. Four. 
" + backend = StreamingMockBackend(response, token_size=100) + req = FailOnNthChunk(n=2) + + result = await stream_with_chunking( + _action(), backend, _ctx(), quick_check_requirements=[req], chunking="sentence" + ) + yielded: list[str] = [] + async for c in result.astream(): + yielded.append(c) + await result.acomplete() + + assert result.completed is False + assert len(result.streaming_failures) == 1 + # Chunk 1 was validated and emitted; chunk 2 was validated and failed + # (NOT emitted); chunks 3 and 4 were NEITHER validated NOR emitted. + assert yielded == ["One."] + assert len(captured) == 1 + assert captured[0].seen == ["One.", "Two."] + assert captured[0]._calls == 2 + + +@pytest.mark.asyncio +async def test_cancel_generation_invoked_on_fail() -> None: + """Early exit on 'fail' must call mot.cancel_generation() — the spec reason + is that asyncio.Queue(maxsize=20) will block the producer if the consumer + stops without cancelling.""" + + from mellea.core.base import ModelOutputThunk + + response = "word " * 50 + backend = StreamingMockBackend(response, token_size=3) + + class FailOnFirstChunk(Requirement): + def format_for_llm(self) -> str: + return "fail immediately" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + _ = chunk, backend, ctx + return PartialValidationResult("fail", reason="nope") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + _ = backend, ctx, format, model_options + return ValidationResult(result=True) + + call_count = 0 + real_cancel = ModelOutputThunk.cancel_generation + + async def spy_cancel(self: ModelOutputThunk) -> None: + nonlocal call_count + call_count += 1 + await real_cancel(self) + + ModelOutputThunk.cancel_generation = spy_cancel # type: ignore[method-assign] + try: + result = await stream_with_chunking( + _action(), + backend, + _ctx(), + 
quick_check_requirements=[FailOnFirstChunk()], + chunking="word", + ) + await asyncio.wait_for(result.acomplete(), timeout=5.0) + finally: + ModelOutputThunk.cancel_generation = real_cancel # type: ignore[method-assign] + + assert result.completed is False + assert call_count >= 1 + + +@pytest.mark.asyncio +async def test_exception_in_stream_validate_cancels_generation() -> None: + """If stream_validate raises, the orchestrator must still call + cancel_generation() — otherwise the backend producer blocks on the + (maxsize=20) queue — and surface the exception to the consumer via + astream()/acomplete().""" + + from mellea.core.base import ModelOutputThunk + + class RaisingReq(Requirement): + def format_for_llm(self) -> str: + return "raises" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + _ = chunk, backend, ctx + raise ValueError("boom") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + _ = backend, ctx, format, model_options + return ValidationResult(result=True) + + response = "word " * 50 # enough to fill maxsize=20 queue without cleanup + backend = StreamingMockBackend(response, token_size=3) + + call_count = 0 + real_cancel = ModelOutputThunk.cancel_generation + + async def spy_cancel(self: ModelOutputThunk) -> None: + nonlocal call_count + call_count += 1 + await real_cancel(self) + + ModelOutputThunk.cancel_generation = spy_cancel # type: ignore[method-assign] + try: + result = await stream_with_chunking( + _action(), + backend, + _ctx(), + quick_check_requirements=[RaisingReq()], + chunking="word", + ) + with pytest.raises(ValueError, match="boom"): + async for _chunk in result.astream(): + pass + # acomplete must complete (not hang) even though the orchestration + # task raised, because cancel_generation was called in the except path. 
+ await asyncio.wait_for(result.acomplete(), timeout=5.0) + finally: + ModelOutputThunk.cancel_generation = real_cancel # type: ignore[method-assign] + + assert result.completed is False + assert call_count >= 1 From da41a06a0ce0926d482f8d74371da5f6fc7f4a41 Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Tue, 28 Apr 2026 19:09:15 +0100 Subject: [PATCH 11/17] fix(stdlib): address second-round review feedback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three items from the second independent review: cancel_generation(error=) — accept an optional Exception parameter. When the orchestrator enters the except Exception path, it now passes the caught exception to cancel_generation() so the backend telemetry span records the real cause via set_span_error instead of a generic RuntimeError("Generation cancelled"). The original exception still surfaces to the consumer via astream()/acomplete(); this is purely an OTEL accuracy fix. Backward-compatible: the default None preserves the previous "Generation cancelled" message for the normal fail path. stream_with_chunking docstring — the "After the stream ends (naturally or via early exit), validate() is called" wording overstated behaviour. The orchestrator actually skips final validate() on early exit (test_early_exit_on_fail verifies final_validations == []). Docstring now correctly says final validate() runs only on natural completion. test_exception_in_stream_validate_cancels_generation docstring — the validator raises on chunk 1 so the queue never actually fills; the test verifies the cancel-on-exception path and the no-hang guarantee but does not directly prove the worst-case "producer blocked on full queue" scenario. Docstring now states what it actually covers and points at test/core/ for the cancel_generation drain logic.
Assisted-by: Claude Code --- mellea/core/base.py | 15 +++++++++++++-- mellea/stdlib/streaming.py | 13 +++++++++---- test/stdlib/test_streaming.py | 26 ++++++++++++++++++-------- 3 files changed, 40 insertions(+), 14 deletions(-) diff --git a/mellea/core/base.py b/mellea/core/base.py index 95b6e8cdc..28ab78783 100644 --- a/mellea/core/base.py +++ b/mellea/core/base.py @@ -364,7 +364,7 @@ def _record_ttfb(self) -> None: ).total_seconds() * 1000 self._first_chunk_received = True - async def cancel_generation(self) -> None: + async def cancel_generation(self, error: Exception | None = None) -> None: """Cancel an in-progress streaming generation, drain the queue, and close any open telemetry span. Safe to call at any point during streaming. After this method returns, @@ -375,6 +375,14 @@ async def cancel_generation(self) -> None: Draining the internal queue after cancellation is necessary to release any ``asyncio.Queue.put()`` call that the generation task was blocked on (queue maxsize=20). + + Args: + error: Optional cause attributed to the open telemetry span. When + provided, this exception is recorded via ``set_span_error`` so + the span reflects the actual reason for cancellation (e.g. the + requirement failure or an unhandled exception from a streaming + validator). When ``None``, a generic + ``RuntimeError("Generation cancelled")`` is recorded. 
""" if self._computed: return @@ -414,7 +422,10 @@ def _drain() -> None: if span is not None: from ..telemetry import end_backend_span, set_span_error - set_span_error(span, RuntimeError("Generation cancelled")) + recorded: Exception = ( + error if error is not None else RuntimeError("Generation cancelled") + ) + set_span_error(span, recorded) end_backend_span(span) del self._meta["_telemetry_span"] diff --git a/mellea/stdlib/streaming.py b/mellea/stdlib/streaming.py index 267e14848..16cafbe9d 100644 --- a/mellea/stdlib/streaming.py +++ b/mellea/stdlib/streaming.py @@ -247,8 +247,10 @@ async def _validate_and_emit(c: str) -> bool: # otherwise mot._async_queue (maxsize=20) fills and the feeder task # blocks indefinitely. The spec (#891, #901) calls this out for the # "fail" path; the same reasoning applies to any unplanned exit. + # Pass `exc` so the backend telemetry span records the real cause + # rather than a generic "Generation cancelled". try: - await mot.cancel_generation() + await mot.cancel_generation(error=exc) except Exception as cleanup_exc: # Never let cleanup mask the original exception: log loudly and # continue to surface `exc` to the consumer. @@ -304,9 +306,12 @@ async def stream_with_chunking( same terms as the regular chunks. On early exit, the trailing fragment is discarded because the generation was cancelled mid-token. - After the stream ends (naturally or via early exit), ``validate()`` is - called on all requirements that did not return ``"fail"``. Requirements - are cloned (``copy(req)``) before use so originals are never mutated. + After the stream ends naturally, ``validate()`` is called on every + requirement that did not return ``"fail"`` — both ``"pass"`` and + ``"unknown"`` trigger final validation. On early exit, no ``validate()`` + call is made; :attr:`StreamChunkingResult.final_validations` remains + empty. Requirements are cloned (``copy(req)``) before use so originals + are never mutated. 
Requirements that need context beyond the current chunk should accumulate it themselves across ``stream_validate`` calls (e.g. diff --git a/test/stdlib/test_streaming.py b/test/stdlib/test_streaming.py index d52ce18a1..560b94ec9 100644 --- a/test/stdlib/test_streaming.py +++ b/test/stdlib/test_streaming.py @@ -678,10 +678,12 @@ async def validate( call_count = 0 real_cancel = ModelOutputThunk.cancel_generation - async def spy_cancel(self: ModelOutputThunk) -> None: + async def spy_cancel( + self: ModelOutputThunk, error: Exception | None = None + ) -> None: nonlocal call_count call_count += 1 - await real_cancel(self) + await real_cancel(self, error) ModelOutputThunk.cancel_generation = spy_cancel # type: ignore[method-assign] try: @@ -702,10 +704,16 @@ async def spy_cancel(self: ModelOutputThunk) -> None: @pytest.mark.asyncio async def test_exception_in_stream_validate_cancels_generation() -> None: - """If stream_validate raises, the orchestrator must still call - cancel_generation() — otherwise the backend producer blocks on the - (maxsize=20) queue — and surface the exception to the consumer via - astream()/acomplete().""" + """Verifies the orchestrator's exception-path cleanup: if stream_validate + raises, cancel_generation() is called and the exception surfaces to the + consumer via astream()/acomplete() without hanging. + + This covers the cancel-on-exception path and the no-hang guarantee. + It does not directly exercise the worst-case "producer already blocked on + full queue" scenario (here the fail happens on chunk 1 so the queue never + fills); the cancel_generation drain logic is covered by its own tests in + test/core/. 
+ """ from mellea.core.base import ModelOutputThunk @@ -736,10 +744,12 @@ async def validate( call_count = 0 real_cancel = ModelOutputThunk.cancel_generation - async def spy_cancel(self: ModelOutputThunk) -> None: + async def spy_cancel( + self: ModelOutputThunk, error: Exception | None = None + ) -> None: nonlocal call_count call_count += 1 - await real_cancel(self) + await real_cancel(self, error) ModelOutputThunk.cancel_generation = spy_cancel # type: ignore[method-assign] try: From 74c009d245e94d2d7dc4dbb721c60a135e096120 Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Tue, 28 Apr 2026 19:38:00 +0100 Subject: [PATCH 12/17] docs(stdlib): add Args and Returns sections to chunker flush overrides The Docs CI docstring quality gate [no_class_args]-equivalent check requires every documented method with typed params to have an Args section and a Returns section matching the return annotation. SentenceChunker.flush, WordChunker.flush, and ParagraphChunker.flush all took accumulated_text and returned list[str] without the sections. Add both to each override, documenting each flush's specific semantics (rstrip for sentences, whitespace-split trailing fragment for words, byte-for-byte for paragraphs). Assisted-by: Claude Code --- mellea/stdlib/chunking.py | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/mellea/stdlib/chunking.py b/mellea/stdlib/chunking.py index fb3521ec2..6c81105c5 100644 --- a/mellea/stdlib/chunking.py +++ b/mellea/stdlib/chunking.py @@ -124,6 +124,15 @@ def flush(self, accumulated_text: str) -> list[str]: ``lstrip`` is needed here. The result is the fragment's content only, consistent with how :meth:`split` returns sentences without trailing whitespace. + + Args: + accumulated_text: The full accumulated text at stream end. 
+ + Returns: + A single-element list containing the trailing sentence fragment + with trailing whitespace stripped, or an empty list + when there is no fragment (all content ended in a sentence + boundary or the input is empty/whitespace-only). """ if not accumulated_text: return [] @@ -177,7 +186,21 @@ def split(self, accumulated_text: str) -> list[str]: return parts def flush(self, accumulated_text: str) -> list[str]: - """Return the trailing word fragment (if any) as a final chunk.""" + """Return the trailing word fragment (if any) as a final chunk. + + The trailing fragment is the text after the last whitespace run when + the accumulated text does not end with whitespace. When it does end + with whitespace, every word is already complete and no fragment is + released. + + Args: + accumulated_text: The full accumulated text at stream end. + + Returns: + A single-element list containing the trailing word fragment, or + an empty list when the input ends with whitespace (every word + already complete) or is empty. + """ if not accumulated_text: return [] if accumulated_text[-1].isspace(): @@ -232,6 +255,14 @@ def flush(self, accumulated_text: str) -> list[str]: a paragraph (e.g. a list item or a deliberate line break), and a consumer validating paragraph content should see the fragment as it was withheld. + + Args: + accumulated_text: The full accumulated text at stream end. + + Returns: + A single-element list containing the trailing paragraph fragment + byte-for-byte, or an empty list when the input ends with a + paragraph boundary (``\n\n`` or more) or is empty.
""" if not accumulated_text: return [] From 3fb501ef56c78c74fe18e617e1e3c17170611fab Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Wed, 29 Apr 2026 11:14:10 +0100 Subject: [PATCH 13/17] fix(stdlib): address third-round review feedback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - _orchestrate_streaming: add cancel_generation() in finally block so the backend producer is stopped even on external CancelledError (BaseException bypasses except Exception, leaving _generate hung on a full queue) - cancel_generation: replace .get + del on _telemetry_span with .pop to prevent KeyError if two coroutines race before _computed is set - Example and test doubles: add super().__init__() to Requirement subclasses so description/validation_fn/_output are always initialised - docs/examples: fix pytest tier marker integration → e2e (Ollama example must be e2e per MARKERS_GUIDE; all peer examples use e2e) - test_quick_check_backend_routing: capture clone via __copy__ intercept and assert all seen_backends are val_backend, not just clone-isolation check Assisted-by: Claude Code --- docs/examples/streaming/streaming_chunking.py | 3 +- mellea/core/base.py | 3 +- mellea/stdlib/streaming.py | 9 ++++ test/stdlib/test_streaming.py | 44 +++++++++++++------ 4 files changed, 43 insertions(+), 16 deletions(-) diff --git a/docs/examples/streaming/streaming_chunking.py b/docs/examples/streaming/streaming_chunking.py index 70037d0a1..737ea7486 100644 --- a/docs/examples/streaming/streaming_chunking.py +++ b/docs/examples/streaming/streaming_chunking.py @@ -1,4 +1,4 @@ -# pytest: ollama, integration +# pytest: ollama, e2e """Streaming generation with per-chunk validation using stream_with_chunking(). 
@@ -32,6 +32,7 @@ class MaxSentencesReq(Requirement): """ def __init__(self, limit: int) -> None: + super().__init__() self._limit = limit self._count = 0 diff --git a/mellea/core/base.py b/mellea/core/base.py index 28ab78783..a8f35e79d 100644 --- a/mellea/core/base.py +++ b/mellea/core/base.py @@ -418,7 +418,7 @@ def _drain() -> None: # Drain again for any final item the task put before terminating. _drain() - span = self._meta.get("_telemetry_span") + span = self._meta.pop("_telemetry_span", None) if span is not None: from ..telemetry import end_backend_span, set_span_error @@ -427,7 +427,6 @@ def _drain() -> None: ) set_span_error(span, recorded) end_backend_span(span) - del self._meta["_telemetry_span"] if self._underlying_value is None: self._underlying_value = "" diff --git a/mellea/stdlib/streaming.py b/mellea/stdlib/streaming.py index 16cafbe9d..c417f9057 100644 --- a/mellea/stdlib/streaming.py +++ b/mellea/stdlib/streaming.py @@ -264,6 +264,15 @@ async def _validate_and_emit(c: str) -> bool: result.completed = False await result._chunk_queue.put(exc) finally: + # CancelledError (BaseException, not Exception) bypasses the except + # block above, so cancel_generation() may not have been called. + # Guard here ensures the backend producer is always stopped, even on + # external task cancellation (e.g. asyncio.wait_for timeout). 
+ if not mot.is_computed(): + try: + await mot.cancel_generation() + except BaseException: + pass await result._chunk_queue.put(None) result._done.set() diff --git a/test/stdlib/test_streaming.py b/test/stdlib/test_streaming.py index 560b94ec9..46550d9a0 100644 --- a/test/stdlib/test_streaming.py +++ b/test/stdlib/test_streaming.py @@ -122,6 +122,7 @@ class FailAfterWordsReq(Requirement): """ def __init__(self, threshold: int) -> None: + super().__init__() self._threshold = threshold self._word_count = 0 @@ -146,6 +147,7 @@ class BackendRecordingReq(Requirement): """Records which backend was passed to stream_validate and validate.""" def __init__(self) -> None: + super().__init__() self.seen_backends: list[Any] = [] def __copy__(self) -> "BackendRecordingReq": @@ -174,6 +176,7 @@ class MutationDetectorReq(Requirement): """Tracks how many times stream_validate was called on this instance.""" def __init__(self) -> None: + super().__init__() self._call_count = 0 def format_for_llm(self) -> str: @@ -291,22 +294,37 @@ async def test_quick_check_backend_routing() -> None: req = BackendRecordingReq() - result = await stream_with_chunking( - _action(), - main_backend, - _ctx(), - quick_check_requirements=[req], - chunking="sentence", - quick_check_backend=val_backend, - ) - await result.acomplete() + # Capture the cloned requirement so we can inspect which backends it saw. 
+ captured: list[BackendRecordingReq] = [] + original_copy = BackendRecordingReq.__copy__ + + def _capturing_copy(self: BackendRecordingReq) -> BackendRecordingReq: + clone = original_copy(self) + captured.append(clone) + return clone + + BackendRecordingReq.__copy__ = _capturing_copy # type: ignore[method-assign] + try: + result = await stream_with_chunking( + _action(), + main_backend, + _ctx(), + quick_check_requirements=[req], + chunking="sentence", + quick_check_backend=val_backend, + ) + await result.acomplete() + finally: + BackendRecordingReq.__copy__ = original_copy # type: ignore[method-assign] - # The clone's seen_backends should only contain val_backend - # (The original req was never called; clones were.) - # Verify via final_validations side-effect: at least one backend recorded assert result.completed is True - # The original req._seen_backends is untouched (clone isolation) + # The original was never called — only clones are used. assert req.seen_backends == [] + # The clone must have seen val_backend for every call (stream_validate + validate), + # never main_backend. This is the actual routing assertion. + assert len(captured) == 1 + assert len(captured[0].seen_backends) > 0 + assert all(b is val_backend for b in captured[0].seen_backends) @pytest.mark.asyncio From 5850f924787960b96c6915864911255a89db62b9 Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Fri, 1 May 2026 14:26:22 +0100 Subject: [PATCH 14/17] fix(stdlib): stash orchestrator exception and narrow finally except Addresses review feedback on `_orchestrate_streaming` cleanup: - Exceptions caught by the orchestrator were only pushed to the chunk queue, so callers who skipped `astream()` and went straight to `acomplete()` saw the call return silently. Stash the exception on the result and raise it from `acomplete()` with raise-once semantics (cleared by whichever of astream/acomplete reads it first). 
- The finally cleanup caught `BaseException`, silently absorbing CancelledError/KeyboardInterrupt/SystemExit. Narrow to `except Exception` and switch the terminator to `put_nowait(None)` + `set()` so the sync ops always run even when the task is being cancelled (otherwise acomplete consumers hang). Two regression tests: - test_acomplete_surfaces_exception_without_astream - test_external_task_cancellation_releases_consumers Assisted-by: Claude Code Signed-off-by: Nigel Jones --- mellea/stdlib/streaming.py | 36 +++++++++++--- test/stdlib/test_streaming.py | 88 +++++++++++++++++++++++++++++++++++ 2 files changed, 117 insertions(+), 7 deletions(-) diff --git a/mellea/stdlib/streaming.py b/mellea/stdlib/streaming.py index c417f9057..dcdd6e894 100644 --- a/mellea/stdlib/streaming.py +++ b/mellea/stdlib/streaming.py @@ -62,6 +62,10 @@ def __init__(self, mot: ModelOutputThunk, ctx: Context) -> None: self._chunk_queue: asyncio.Queue[str | None | Exception] = asyncio.Queue() self._orchestration_task: asyncio.Task[None] | None = None self._done = asyncio.Event() + # Stashed so acomplete() surfaces orchestrator failures even when the + # consumer never iterates astream(). Cleared once consumed by + # whichever of the two reads it first. + self._orchestration_exception: BaseException | None = None self.completed: bool = True self.full_text: str = "" @@ -96,6 +100,10 @@ async def astream(self) -> AsyncIterator[str]: if item is None: return if isinstance(item, Exception): + if self._orchestration_exception is None: + # Already surfaced by acomplete(); don't raise twice. + continue + self._orchestration_exception = None raise item yield item @@ -111,10 +119,16 @@ async def acomplete(self) -> None: Exception: Propagates any error from the orchestration task. """ await self._done.wait() + # Raise-once: if astream() already consumed the exception, the stash + # is already None and this is a no-op. 
+ exc = self._orchestration_exception + if exc is not None: + self._orchestration_exception = None + raise exc if self._orchestration_task is not None and self._orchestration_task.done(): - exc = self._orchestration_task.exception() - if exc is not None: - raise exc + task_exc = self._orchestration_task.exception() + if task_exc is not None: + raise task_exc @property def as_thunk(self) -> ModelOutputThunk: @@ -262,18 +276,26 @@ async def _validate_and_emit(c: str) -> bool: cleanup_exc, ) result.completed = False + result._orchestration_exception = exc await result._chunk_queue.put(exc) finally: # CancelledError (BaseException, not Exception) bypasses the except # block above, so cancel_generation() may not have been called. - # Guard here ensures the backend producer is always stopped, even on - # external task cancellation (e.g. asyncio.wait_for timeout). + # Catch only Exception here so CancelledError / KeyboardInterrupt / + # SystemExit still propagate to the caller. if not mot.is_computed(): try: await mot.cancel_generation() - except BaseException: + except Exception: pass - await result._chunk_queue.put(None) + # put_nowait + set() are synchronous — no await point, so they cannot + # be interrupted by task cancellation. Consumers waiting on + # _done.wait() are always released, even if the task was cancelled + # mid-cleanup. The queue is unbounded, so QueueFull cannot occur. + try: + result._chunk_queue.put_nowait(None) + except asyncio.QueueFull: + pass result._done.set() diff --git a/test/stdlib/test_streaming.py b/test/stdlib/test_streaming.py index 46550d9a0..759d8d272 100644 --- a/test/stdlib/test_streaming.py +++ b/test/stdlib/test_streaming.py @@ -789,3 +789,91 @@ async def spy_cancel( assert result.completed is False assert call_count >= 1 + + +@pytest.mark.asyncio +async def test_acomplete_surfaces_exception_without_astream() -> None: + """acomplete() must surface orchestrator exceptions even when the + consumer never iterates astream(). 
+ + The alternative — only delivering the exception through the chunk queue + — silently swallows validator failures for callers who skip astream(). + """ + + class RaisingReq(Requirement): + def format_for_llm(self) -> str: + return "raises" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + _ = chunk, backend, ctx + raise ValueError("surfaced-without-astream") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + _ = backend, ctx, format, model_options + return ValidationResult(result=True) + + response = "word " * 50 + backend = StreamingMockBackend(response, token_size=3) + + result = await stream_with_chunking( + _action(), + backend, + _ctx(), + quick_check_requirements=[RaisingReq()], + chunking="word", + ) + # Deliberately skip astream(). wait_for bounds any hang. + with pytest.raises(ValueError, match="surfaced-without-astream"): + await asyncio.wait_for(result.acomplete(), timeout=5.0) + + assert result.completed is False + # Raise-once: a second acomplete() must not re-raise. + await asyncio.wait_for(result.acomplete(), timeout=5.0) + + +@pytest.mark.asyncio +async def test_external_task_cancellation_releases_consumers() -> None: + """External cancellation of the orchestration task must still set _done. + + If the finally cleanup itself contains an ``await`` (e.g. awaiting a + terminator put into the chunk queue), CancelledError re-raises at that + await and ``_done.set()`` never runs — any consumer blocked on + ``acomplete()`` hangs forever. The cleanup must therefore end with + synchronous operations only. 
+ """ + response = "word " * 200 # long enough that streaming is still in progress + backend = StreamingMockBackend(response, token_size=2) + + result = await stream_with_chunking( + _action(), + backend, + _ctx(), + quick_check_requirements=[AlwaysUnknownReq()], + chunking="word", + ) + + assert result._orchestration_task is not None + # Yield once so the orchestration task enters its main loop before we + # cancel it. + await asyncio.sleep(0.01) + + # Same mechanism asyncio.wait_for uses on timeout. + result._orchestration_task.cancel() + + # _done must be set by the finally cleanup. A hang would time out here. + await asyncio.wait_for(result._done.wait(), timeout=2.0) + assert result._done.is_set() + + # acomplete() surfaces the CancelledError via task.exception() and must + # not hang. + with pytest.raises(asyncio.CancelledError): + await asyncio.wait_for(result.acomplete(), timeout=2.0) From 4f508fd4b12d035733e6f069b2805fbc0982a4f2 Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Tue, 5 May 2026 10:11:17 +0100 Subject: [PATCH 15/17] feat(core): add cancelled flag on ModelOutputThunk Adds a `_cancelled` attribute (False by default) on `ModelOutputThunk`, set to True inside `cancel_generation()` just before `_computed = True`, exposed via a read-only `cancelled` property. Propagated through `StreamChunkingResult.as_thunk` so consumers that only hold the wrapped thunk can still distinguish cancellation from a natural completion. Addresses nrfulton's review feedback on #942 and pre-stages the cancel-vs-complete signal that #902's `CompletedEvent` needs to surface. Assisted-by: Claude Code Signed-off-by: Nigel Jones --- mellea/core/base.py | 13 +++++++++++++ mellea/stdlib/streaming.py | 1 + 2 files changed, 14 insertions(+) diff --git a/mellea/core/base.py b/mellea/core/base.py index a8f35e79d..60a079c8a 100644 --- a/mellea/core/base.py +++ b/mellea/core/base.py @@ -320,6 +320,7 @@ def __init__( # Set computed to True if a value is passed in. 
self._computed: bool = True if value is not None else False + self._cancelled: bool = False # Additional fields that should be standardized across apis. self.tool_calls = tool_calls @@ -430,8 +431,20 @@ def _drain() -> None: if self._underlying_value is None: self._underlying_value = "" + self._cancelled = True self._computed = True + @property + def cancelled(self) -> bool: + """``True`` if :meth:`cancel_generation` ran to completion on this MOT. + + A normally-completed MOT leaves this ``False``; only an actual + cancellation via :meth:`cancel_generation` flips it. Consumers holding + a computed MOT can use this to distinguish a genuine result from one + cut short (for example by a streaming requirement failure). + """ + return self._cancelled + def _copy_from(self, other: ModelOutputThunk) -> None: """Copy computed-output fields from *other* into *self*. diff --git a/mellea/stdlib/streaming.py b/mellea/stdlib/streaming.py index dcdd6e894..2ba815096 100644 --- a/mellea/stdlib/streaming.py +++ b/mellea/stdlib/streaming.py @@ -159,6 +159,7 @@ def as_thunk(self) -> ModelOutputThunk: "as_thunk accessed before acomplete() — await acomplete() first" ) thunk = ModelOutputThunk(value=self.full_text) + thunk._cancelled = self._mot._cancelled thunk.generation = copy(self._mot.generation) return thunk From 5075a4794cf29578c4bbcc4f3affa224bf53f289 Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Tue, 5 May 2026 11:59:32 +0100 Subject: [PATCH 16/17] docs(stdlib): note ChunkingStrategy is text-only Adds a short note on the ChunkingStrategy class docstring stating that the ABC operates on text streams only and does not support multi-modal output (audio segments, image regions). Addresses review feedback on #942 without expanding scope. 
Assisted-by: Claude Code Signed-off-by: Nigel Jones --- mellea/stdlib/chunking.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mellea/stdlib/chunking.py b/mellea/stdlib/chunking.py index 6c81105c5..462af81a8 100644 --- a/mellea/stdlib/chunking.py +++ b/mellea/stdlib/chunking.py @@ -19,6 +19,10 @@ class ChunkingStrategy(ABC): take the full accumulated text, identify everything after the last returned chunk boundary, and handle it appropriately (e.g. pass to a final validator or discard). + + Note: this ABC operates on text streams only. Multi-modal output (audio + segments, image regions) is not supported — the ``accumulated_text: str`` + signatures on ``split`` and ``flush`` preclude it. """ @abstractmethod From f0f93b3acf8b73d4a7be09e082c5bfc1fe3598a6 Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Tue, 5 May 2026 12:06:33 +0100 Subject: [PATCH 17/17] test(stdlib): assert cancelled flag reflects cancellation state Adds test_cancelled_flag_reflects_cancellation_state covering both the early-exit path (cancelled is True, is_computed True, propagates through as_thunk) and the normal-completion path (cancelled is False). Pairs with the cancellation flag added in the prior commit. Addresses nrfulton's review feedback on #942. Assisted-by: Claude Code Signed-off-by: Nigel Jones --- test/stdlib/test_streaming.py | 61 +++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/test/stdlib/test_streaming.py b/test/stdlib/test_streaming.py index 759d8d272..2010c9b3e 100644 --- a/test/stdlib/test_streaming.py +++ b/test/stdlib/test_streaming.py @@ -720,6 +720,67 @@ async def spy_cancel( assert call_count >= 1 +@pytest.mark.asyncio +async def test_cancelled_flag_reflects_cancellation_state() -> None: + """The ``cancelled`` property on ModelOutputThunk distinguishes an early-exit + cancellation from a normal completion and propagates through ``as_thunk``.""" + + # Early exit → cancelled is True, is_computed True, propagates through as_thunk. 
+ fail_response = "word " * 50 + fail_backend = StreamingMockBackend(fail_response, token_size=3) + + class FailImmediately(Requirement): + def format_for_llm(self) -> str: + return "fail immediately" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + _ = chunk, backend, ctx + return PartialValidationResult("fail", reason="nope") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + _ = backend, ctx, format, model_options + return ValidationResult(result=True) + + fail_result = await stream_with_chunking( + _action(), + fail_backend, + _ctx(), + quick_check_requirements=[FailImmediately()], + chunking="word", + ) + await asyncio.wait_for(fail_result.acomplete(), timeout=5.0) + + assert fail_result.completed is False + assert fail_result.as_thunk.cancelled is True + assert fail_result.as_thunk.is_computed() is True + + # Normal completion → cancelled is False. + ok_response = "Hello world. How are you. " + ok_backend = StreamingMockBackend(ok_response, token_size=3) + + ok_result = await stream_with_chunking( + _action(), + ok_backend, + _ctx(), + quick_check_requirements=[AlwaysUnknownReq()], + chunking="sentence", + ) + await ok_result.acomplete() + + assert ok_result.completed is True + assert ok_result.as_thunk.cancelled is False + assert ok_result.as_thunk.is_computed() is True + + @pytest.mark.asyncio async def test_exception_in_stream_validate_cancels_generation() -> None: """Verifies the orchestrator's exception-path cleanup: if stream_validate