Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
8128dfa
feat(core): add cancel_generation() to ModelOutputThunk
planetf1 Apr 27, 2026
f26cce7
feat(stdlib): add stream_with_chunking() with per-chunk validation (#…
planetf1 Apr 27, 2026
93e7587
test(stdlib): add StreamingMockBackend and streaming orchestration tests
planetf1 Apr 27, 2026
a5d358c
docs: add streaming_chunking example (#901)
planetf1 Apr 27, 2026
39f18a4
docs(stdlib): add Args section to StreamChunkingResult class docstring
planetf1 Apr 28, 2026
36173cb
docs(stdlib): add Raises section to stream_with_chunking docstring
planetf1 Apr 28, 2026
ea6bdb0
fix(stdlib): stream_with_chunking passes one chunk per stream_validat…
planetf1 Apr 28, 2026
35df77f
docs(stdlib): fix example for delta semantics and note validator latency
planetf1 Apr 28, 2026
61448a9
feat(stdlib): flush trailing chunk fragment at end of stream
planetf1 Apr 28, 2026
def10b6
fix(stdlib): address review feedback on streaming validation
planetf1 Apr 28, 2026
da41a06
fix(stdlib): address second-round review feedback
planetf1 Apr 28, 2026
74c009d
docs(stdlib): add Args and Returns sections to chunker flush overrides
planetf1 Apr 28, 2026
3fb501e
fix(stdlib): address third-round review feedback
planetf1 Apr 29, 2026
5850f92
fix(stdlib): stash orchestrator exception and narrow finally except
planetf1 May 1, 2026
4f508fd
feat(core): add cancelled flag on ModelOutputThunk
planetf1 May 5, 2026
5075a47
docs(stdlib): note ChunkingStrategy is text-only
planetf1 May 5, 2026
f0f93b3
test(stdlib): assert cancelled flag reflects cancellation state
planetf1 May 5, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions docs/examples/streaming/streaming_chunking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# pytest: ollama, e2e
Comment thread
planetf1 marked this conversation as resolved.

"""Streaming generation with per-chunk validation using stream_with_chunking().

Demonstrates:
- Subclassing Requirement to override stream_validate() for early-exit checks
- Calling stream_with_chunking() with sentence-level chunking
- Consuming validated chunks via astream() as they arrive
- Awaiting full completion with acomplete() to access final_validations and full_text
"""

import asyncio

from mellea.core.backend import Backend
from mellea.core.base import Context
from mellea.core.requirement import (
PartialValidationResult,
Requirement,
ValidationResult,
)
from mellea.stdlib.components import Instruction
from mellea.stdlib.streaming import stream_with_chunking


class MaxSentencesReq(Requirement):
    """Requirement that fails mid-stream once more than *limit* sentences arrive.

    The orchestrator calls ``stream_validate`` once per complete sentence
    emitted by the sentence chunker. A running tally is kept on the instance
    itself — the usual pattern for requirements whose verdict depends on
    more than a single chunk.
    """

    def __init__(self, limit: int) -> None:
        super().__init__()
        self._limit = limit
        self._count = 0

    def format_for_llm(self) -> str:
        return f"The response must be at most {self._limit} sentences long."

    async def stream_validate(
        self, chunk: str, *, backend: Backend, ctx: Context
    ) -> PartialValidationResult:
        # One call per complete sentence; bump the tally first, then decide.
        self._count += 1
        if self._count <= self._limit:
            return PartialValidationResult("unknown")
        return PartialValidationResult(
            "fail",
            reason=f"Response exceeded {self._limit} sentence limit mid-stream",
        )

    async def validate(
        self,
        backend: Backend,
        ctx: Context,
        *,
        format: type | None = None,
        model_options: dict | None = None,
    ) -> ValidationResult:
        # Final whole-output validation is a no-op for this demo requirement.
        return ValidationResult(result=True)


async def main() -> None:
    """Run the streaming-chunking demo end to end against a live session."""
    from mellea.stdlib.session import start_session

    session = start_session()
    backend = session.backend
    ctx = session.ctx

    prompt = Instruction(
        "Write a short paragraph about the water cycle in exactly two sentences."
    )
    requirement = MaxSentencesReq(limit=3)

    outcome = await stream_with_chunking(
        prompt,
        backend,
        ctx,
        quick_check_requirements=[requirement],
        chunking="sentence",
    )

    print("Streaming chunks as they arrive:")
    async for piece in outcome.astream():
        print(f"  CHUNK: {piece!r}")

    # Wait for the generation (and final validations) to finish.
    await outcome.acomplete()

    print(f"\nCompleted normally: {outcome.completed}")
    print(f"Full text: {outcome.full_text!r}")

    for _failed_req, partial in outcome.streaming_failures or []:
        print(f"Streaming failure: {partial.reason}")

    for verdict in outcome.final_validations or []:
        print(f"Final validation: {'PASS' if verdict.as_bool() else 'FAIL'}")


if __name__ == "__main__":
    # Guard the entry point so importing this example (e.g. by docs tooling
    # or test collection) does not kick off a live generation as a side effect.
    asyncio.run(main())
81 changes: 81 additions & 0 deletions mellea/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ def __init__(

# Set computed to True if a value is passed in.
self._computed: bool = True if value is not None else False
self._cancelled: bool = False

# Additional fields that should be standardized across apis.
self.tool_calls = tool_calls
Expand Down Expand Up @@ -364,6 +365,86 @@ def _record_ttfb(self) -> None:
).total_seconds() * 1000
self._first_chunk_received = True

async def cancel_generation(self, error: Exception | None = None) -> None:
"""Cancel an in-progress streaming generation, drain the queue, and close any open telemetry span.

Safe to call at any point during streaming. After this method returns,
``is_computed()`` is ``True`` and ``value`` contains whatever text was
accumulated before cancellation. Calling on an already-computed MOT
is a no-op.
Comment thread
planetf1 marked this conversation as resolved.

Draining the internal queue after cancellation is necessary to release
any ``asyncio.Queue.put()`` call that the generation task was blocked on
(queue maxsize=20).

Args:
error: Optional cause attributed to the open telemetry span. When
provided, this exception is recorded via ``set_span_error`` so
the span reflects the actual reason for cancellation (e.g. the
requirement failure or an unhandled exception from a streaming
validator). When ``None``, a generic
``RuntimeError("Generation cancelled")`` is recorded.
"""
if self._computed:
return

def _drain() -> None:
while not self._async_queue.empty():
try:
self._async_queue.get_nowait()
except asyncio.QueueEmpty:
break

if self._generate is not None and not self._generate.done():
self._generate.cancel()

if self._generate_extra is not None and not self._generate_extra.done():
self._generate_extra.cancel()

# Drain before awaiting — unblocks any put() the task is stuck on.
_drain()

if self._generate is not None:
try:
await self._generate
except (asyncio.CancelledError, Exception):
pass

if self._generate_extra is not None:
try:
await self._generate_extra
except (asyncio.CancelledError, Exception):
pass

# Drain again for any final item the task put before terminating.
_drain()

span = self._meta.pop("_telemetry_span", None)
if span is not None:
from ..telemetry import end_backend_span, set_span_error

recorded: Exception = (
error if error is not None else RuntimeError("Generation cancelled")
)
set_span_error(span, recorded)
end_backend_span(span)

if self._underlying_value is None:
self._underlying_value = ""
self._cancelled = True
self._computed = True

@property
def cancelled(self) -> bool:
    """``True`` when :meth:`cancel_generation` finalized this MOT.

    Normal completion never sets this flag — it flips only when an
    in-flight generation was actually cancelled. Holders of a computed
    MOT can therefore tell a genuine result apart from one that was cut
    short (for example by a streaming requirement failure).
    """
    return self._cancelled

def _copy_from(self, other: ModelOutputThunk) -> None:
"""Copy computed-output fields from *other* into *self*.

Expand Down
15 changes: 13 additions & 2 deletions mellea/stdlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,20 @@
``mellea.stdlib.session`` — for day-to-day use.

Streaming chunking strategies (for use with streaming validation) are available at
``mellea.stdlib.chunking`` and re-exported here for convenience.
``mellea.stdlib.chunking`` and re-exported here for convenience. The core streaming
orchestration primitive :func:`~mellea.stdlib.streaming.stream_with_chunking` and
its result type :class:`~mellea.stdlib.streaming.StreamChunkingResult` are also
re-exported here.
"""

from .chunking import ChunkingStrategy, ParagraphChunker, SentenceChunker, WordChunker
from .streaming import StreamChunkingResult, stream_with_chunking

__all__ = ["ChunkingStrategy", "ParagraphChunker", "SentenceChunker", "WordChunker"]
# Public API re-exported from the chunking and streaming submodules.
__all__ = [
    "ChunkingStrategy",
    "ParagraphChunker",
    "SentenceChunker",
    "StreamChunkingResult",
    "WordChunker",
    "stream_with_chunking",
]
107 changes: 107 additions & 0 deletions mellea/stdlib/chunking.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ class ChunkingStrategy(ABC):
take the full accumulated text, identify everything after the last returned
chunk boundary, and handle it appropriately (e.g. pass to a final validator
or discard).

Note: this ABC operates on text streams only. Multi-modal output (audio
segments, image regions) is not supported — the ``accumulated_text: str``
signatures on ``split`` and ``flush`` preclude it.
"""

@abstractmethod
Expand All @@ -35,6 +39,27 @@ def split(self, accumulated_text: str) -> list[str]:
"""
...

def flush(self, accumulated_text: str) -> list[str]:
    """Release any trailing fragment that ``split`` held back.

    Invoked exactly once by the orchestrator when the stream ends
    normally (never on early-exit cancellation), giving the chunker a
    chance to emit the final piece of text that never reached a
    terminator.

    This base implementation discards the fragment and returns an empty
    list; the built-in chunkers override it to surface the withheld
    fragment as a one-element list when it is non-empty.

    Args:
        accumulated_text: The full text accumulated when the stream ended.

    Returns:
        ``[fragment]`` when the trailing text should count as a final
        chunk, otherwise an empty list.
    """
    del accumulated_text  # unused by the default implementation
    return []


# Sentence boundary: sentence-ending punctuation, optionally followed by a closing
# quote or paren, then whitespace.
Expand Down Expand Up @@ -94,6 +119,36 @@ def split(self, accumulated_text: str) -> list[str]:

return chunks

def flush(self, accumulated_text: str) -> list[str]:
    """Return the trailing sentence fragment (if any) as a final chunk.

    Skips past every complete sentence boundary in the accumulated text,
    then returns whatever remains with surrounding whitespace stripped,
    consistent with how :meth:`split` returns sentences without trailing
    whitespace.

    Args:
        accumulated_text: The full accumulated text at stream end.

    Returns:
        A single-element list containing the trailing sentence fragment
        with leading and trailing whitespace stripped, or an empty list
        when there is no fragment (all content ended in a sentence
        boundary or the input is empty/whitespace-only).
    """
    if not accumulated_text:
        return []
    remaining = accumulated_text
    while True:
        match = _SENTENCE_BOUNDARY.search(remaining)
        if match is None:
            break
        remaining = remaining[match.end() :].lstrip()
    # strip() rather than rstrip(): when NO boundary ever matched, the
    # loop's lstrip never ran, and rstrip alone would leave leading
    # whitespace on the fragment — violating the documented contract
    # that the fragment is stripped on both sides.
    trailing = remaining.strip()
    return [trailing] if trailing else []


class WordChunker(ChunkingStrategy):
"""Splits accumulated text on whitespace boundaries.
Expand Down Expand Up @@ -134,6 +189,32 @@ def split(self, accumulated_text: str) -> list[str]:

return parts

def flush(self, accumulated_text: str) -> list[str]:
    """Return the trailing word fragment (if any) as a final chunk.

    When the accumulated text ends mid-word (its last character is not
    whitespace), the text after the final whitespace run is the withheld
    fragment. When the text ends with whitespace, every word was already
    complete and nothing is released.

    Args:
        accumulated_text: The full accumulated text at stream end.

    Returns:
        A single-element list containing the trailing word fragment, or
        an empty list when the input is empty or ends with whitespace
        (every word already complete).
    """
    if not accumulated_text or accumulated_text[-1].isspace():
        return []
    # The fragment is the last non-empty piece after splitting on whitespace.
    words = [piece for piece in _WHITESPACE.split(accumulated_text) if piece]
    return [words[-1]] if words else []


class ParagraphChunker(ChunkingStrategy):
r"""Splits accumulated text on double-newline paragraph boundaries.
Expand Down Expand Up @@ -168,3 +249,29 @@ def split(self, accumulated_text: str) -> list[str]:

# _PARA_BOUNDARY.split on leading \n\n produces an empty first element.
return [p for p in parts if p]

def flush(self, accumulated_text: str) -> list[str]:
    r"""Return the trailing paragraph fragment (if any) as a final chunk.

    Unlike :class:`SentenceChunker.flush`, the fragment is handed back
    byte-for-byte with no stripping: internal whitespace — including a
    trailing single ``\n`` — can be semantically meaningful inside a
    paragraph (a list item, a deliberate line break), so a consumer
    validating paragraph content sees the fragment exactly as withheld.

    Args:
        accumulated_text: The full accumulated text at stream end.

    Returns:
        A single-element list containing the trailing paragraph fragment
        byte-for-byte, or an empty list when the input is empty or ends
        with a paragraph boundary (``\n\n`` or more).
    """
    if not accumulated_text or _PARA_BOUNDARY_END.search(accumulated_text):
        return []
    tail = (_PARA_BOUNDARY.split(accumulated_text) or [""])[-1]
    return [tail] if tail else []
Loading
Loading