From 74a4760ed609c2a999e7ffb746a2d59083a9c770 Mon Sep 17 00:00:00 2001 From: Yousef El-Kurdi Date: Mon, 27 Apr 2026 18:36:49 +0000 Subject: [PATCH 1/5] feat: add context compaction strategies for react framework Adds CompactionStrategy abstraction and KeepLastN implementation to mellea/stdlib/compaction.py, wires an optional compaction parameter into the react() loop, and adds full test coverage in test/stdlib/test_compaction.py. Assisted-by: Claude Code Signed-off-by: Yousef El-Kurdi --- mellea/stdlib/compaction.py | 325 ++++++++++++++++++++++++++++ mellea/stdlib/frameworks/react.py | 13 ++ test/stdlib/test_compaction.py | 344 ++++++++++++++++++++++++++++++ 3 files changed, 682 insertions(+) create mode 100644 mellea/stdlib/compaction.py create mode 100644 test/stdlib/test_compaction.py diff --git a/mellea/stdlib/compaction.py b/mellea/stdlib/compaction.py new file mode 100644 index 000000000..2f9636fb7 --- /dev/null +++ b/mellea/stdlib/compaction.py @@ -0,0 +1,325 @@ +"""Context compaction strategies for the ReACT framework. + +Provides modular, callable strategy objects to compact a ``ChatContext`` that +has grown too large during a react loop. Three strategies are available: + +- ``ClearAll`` — discard the entire conversation body, keeping only the prefix + (everything up to and including the ``ReactInitiator``). +- ``KeepLastN`` — keep the prefix plus the *n* most recent body components. +- ``LLMSummarize`` — ask the backend to summarize old body components into a + single ``Message``, then keep the last *n* body components verbatim. + +All strategies preserve the **prefix** (every component up to and including the +first ``ReactInitiator``) so the model retains its goal and tool definitions. + +Example:: + + from mellea.stdlib.compaction import KeepLastN + from mellea.stdlib.frameworks.react import react + + await react( + goal="...", + context=ChatContext(), + backend=m.backend, + tools=[search_tool], + compaction=KeepLastN(keep_n=5, threshold=20), + ) +""" + +from __future__ import annotations + +import abc + +from mellea.core.backend import Backend +from mellea.core.base import CBlock, Component, ModelOutputThunk +from mellea.core.utils import MelleaLogger +from mellea.stdlib.components.chat import Message, ToolMessage +from mellea.stdlib.components.react import ReactInitiator +from mellea.stdlib.context import ChatContext + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def rebuild_chat_context( + components: list[Component | CBlock], *, window_size: int | None = None +) -> ChatContext: + """Build a fresh ``ChatContext`` from an ordered list of components. + + Args: + components: Components to add, in chronological order. + window_size: Optional sliding-window size for the new context. + + Returns: + A new ``ChatContext`` containing all *components*. + """ + ctx = ChatContext(window_size=window_size) + for c in components: + ctx = ctx.add(c) + return ctx + + +def _find_prefix_end(components: list[Component | CBlock]) -> int: + """Return the index *after* the first ``ReactInitiator``. + + Everything in ``components[:idx]`` is the prefix that must be preserved by + every compaction strategy. Returns 0 when no ``ReactInitiator`` is found. + """ + for i, c in enumerate(components): + if isinstance(c, ReactInitiator): + return i + 1 + return 0 + + +# --------------------------------------------------------------------------- +# Abstract base +# --------------------------------------------------------------------------- + + +class CompactionStrategy(abc.ABC): + """Abstract base class for context compaction strategies. + + Each strategy carries a ``threshold`` — the component count above which + compaction should fire. The :meth:`should_compact` helper checks this so + callers don't need to track the threshold separately. + + Subclasses implement :meth:`compact` which receives the current + ``ChatContext`` and returns a compacted copy. The method is ``async`` + so that strategies requiring LLM calls (e.g. ``LLMSummarize``) work + transparently; synchronous strategies simply never ``await``. + + Args: + threshold (int): Trigger compaction when the number of context + components exceeds this value. + """ + + def __init__(self, *, threshold: int = 0) -> None: + """Initialize with the component-count threshold.""" + self.threshold = threshold + + def should_compact(self, context: ChatContext) -> bool: + """Return ``True`` when *context* exceeds the configured threshold. + + Args: + context: The context to check. + + Returns: + ``True`` if the number of components exceeds ``self.threshold`` + and ``self.threshold`` is greater than 0. + """ + return self.threshold > 0 and len(context.as_list()) > self.threshold + + async def maybe_compact( + self, + context: ChatContext, + *, + backend: Backend | None = None, + goal: str | None = None, + ) -> ChatContext: + """Compact *context* only if it exceeds the threshold, otherwise return it unchanged. + + Args: + context: The context to check and potentially compact. + backend: The backend (forwarded to :meth:`compact`). + goal: The react goal string (forwarded to :meth:`compact`). + + Returns: + A compacted ``ChatContext`` if the threshold was exceeded, + or the original *context* unchanged. + """ + if self.should_compact(context): + return await self.compact(context, backend=backend, goal=goal) + return context + + @abc.abstractmethod + async def compact( + self, + context: ChatContext, + *, + backend: Backend | None = None, + goal: str | None = None, + ) -> ChatContext: + """Return a compacted copy of *context*. + + Args: + context: The context to compact. + backend: The backend (required by ``LLMSummarize``). + goal: The react goal string (required by ``LLMSummarize``). + + Returns: + A new, compacted ``ChatContext``. + """ + + +# --------------------------------------------------------------------------- +# Concrete strategies +# --------------------------------------------------------------------------- + + +class ClearAll(CompactionStrategy): + """Discard the entire conversation body, keeping only the prefix. + + The prefix is everything up to and including the first ``ReactInitiator``. + + Args: + threshold (int): Trigger compaction when context exceeds this many components. + """ + + async def compact( + self, + context: ChatContext, + *, + backend: Backend | None = None, + goal: str | None = None, + ) -> ChatContext: + """Return a context containing only the prefix.""" + components = context.as_list() + prefix_end = _find_prefix_end(components) + compacted = components[:prefix_end] + + MelleaLogger.get_logger().info( + f"ClearAll: compacted context from {len(components)} to " + f"{len(compacted)} components" + ) + return rebuild_chat_context(compacted, window_size=context._window_size) + + +class KeepLastN(CompactionStrategy): + """Keep the prefix plus the last *keep_n* body components. + + Args: + keep_n (int): Number of recent body components to retain. + threshold (int): Trigger compaction when context exceeds this many components. + """ + + def __init__(self, *, keep_n: int = 5, threshold: int = 0) -> None: + """Initialize with the number of recent body components to keep.""" + super().__init__(threshold=threshold) + self.keep_n = keep_n + + async def compact( + self, + context: ChatContext, + *, + backend: Backend | None = None, + goal: str | None = None, + ) -> ChatContext: + """Return a context with the prefix and the last *keep_n* body components.""" + components = context.as_list() + prefix_end = _find_prefix_end(components) + prefix = components[:prefix_end] + body = components[prefix_end:] + + if len(body) <= self.keep_n: + return context # nothing to compact + + compacted = prefix + body[-self.keep_n :] + + MelleaLogger.get_logger().info( + f"KeepLastN(keep_n={self.keep_n}): compacted context from " + f"{len(components)} to {len(compacted)} components" + ) + return rebuild_chat_context(compacted, window_size=context._window_size) + + +class LLMSummarize(CompactionStrategy): + """Summarize old body components with the LLM, keep last *keep_n* verbatim. + + Requires ``backend`` and ``goal`` to be passed to :meth:`compact`. + + Args: + keep_n (int): Number of recent body components to retain verbatim. + threshold (int): Trigger compaction when context exceeds this many components. + """ + + def __init__(self, *, keep_n: int = 5, threshold: int = 0) -> None: + """Initialize with the number of recent body components to keep.""" + super().__init__(threshold=threshold) + self.keep_n = keep_n + + async def compact( + self, + context: ChatContext, + *, + backend: Backend | None = None, + goal: str | None = None, + ) -> ChatContext: + """Return a context with the prefix, an LLM summary, and recent body components. + + Raises: + ValueError: If *backend* or *goal* are not provided. + """ + if backend is None or goal is None: + raise ValueError( + "LLMSummarize requires both 'backend' and 'goal' arguments" + ) + + from mellea.stdlib import functional as mfuncs + from mellea.stdlib.context import SimpleContext + + components = context.as_list() + prefix_end = _find_prefix_end(components) + prefix = components[:prefix_end] + body = components[prefix_end:] + + if len(body) <= self.keep_n: + return context # nothing to compact + + old = body[: -self.keep_n] if self.keep_n > 0 else body + recent = body[-self.keep_n :] if self.keep_n > 0 else [] + + # Build a textual representation of old components for summarization. + context_lines: list[str] = [] + for c in old: + if isinstance(c, ToolMessage): + context_lines.append(f"tool ({c.name}): {c.content}") + elif isinstance(c, Message): + context_lines.append(f"{c.role}: {c.content}") + elif isinstance(c, ModelOutputThunk): + context_lines.append(f"assistant: {c.value}") + elif isinstance(c, CBlock): + context_lines.append(str(c)) + else: + context_lines.append(str(getattr(c, "content", c))) + + summary_prompt = ( + "You are summarizing research progress to maintain context " + "within token limits.\n\n" + f"GOAL: {goal}\n\n" + "Provide a comprehensive summary of the research context below. " + "Your summary should:\n" + "- Preserve ALL specific facts, numbers, names, URLs, and search " + "queries found\n" + "- Note which tools were called and what results were obtained\n" + "- Highlight key findings and any dead ends encountered\n" + "- Be structured clearly so the research can continue seamlessly" + "\n\nContext to summarize:\n" + f"{chr(10).join(context_lines)}" + ) + + summary_action = Message(role="user", content=summary_prompt) + result, _ = await mfuncs.aact( + action=summary_action, + context=SimpleContext(), + backend=backend, + requirements=[], + strategy=None, + await_result=True, + ) + + summary_text = result.value or "" + summary_message = Message( + role="user", + content=( + f"[CONTEXT SUMMARY]\n{summary_text}\n\nContinue working on: {goal}" + ), + ) + + compacted = [*prefix, summary_message, *recent] + + MelleaLogger.get_logger().info( + f"LLMSummarize(keep_n={self.keep_n}): compacted context from " + f"{len(components)} to {len(compacted)} components" + ) + return rebuild_chat_context(compacted, window_size=context._window_size) diff --git a/mellea/stdlib/frameworks/react.py b/mellea/stdlib/frameworks/react.py index 117af4866..7f39bba27 100644 --- a/mellea/stdlib/frameworks/react.py +++ b/mellea/stdlib/frameworks/react.py @@ -15,6 +15,7 @@ from mellea.stdlib import functional as mfuncs # from mellea.stdlib.components.docs.document import Document +from mellea.stdlib.compaction import CompactionStrategy from mellea.stdlib.components.chat import ToolMessage from mellea.stdlib.components.react import ( MELLEA_FINALIZER_TOOL, @@ -36,6 +37,7 @@ async def react( model_options: dict | None = None, tools: list[AbstractMelleaTool] | None, loop_budget: int = 10, + compaction: CompactionStrategy | None = None, ) -> tuple[ComputedModelOutputThunk[str], ChatContext]: """Asynchronous ReACT pattern (Think -> Act -> Observe -> Repeat Until Done); attempts to accomplish the provided goal given the provided tools. @@ -47,6 +49,10 @@ async def react( model_options: additional model options, which will upsert into the model/backend's defaults. tools: the list of tools to use loop_budget: the number of steps allowed; use -1 for unlimited + compaction: an optional ``CompactionStrategy`` to apply when the context + exceeds the strategy's configured threshold + (e.g. ``KeepLastN(keep_n=5, threshold=20)``). + Returns: A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`. @@ -79,6 +85,13 @@ async def react( turn_num = 0 while (turn_num < loop_budget) or (loop_budget == -1): turn_num += 1 + + # -- Context compaction -- + if compaction is not None: + context = await compaction.maybe_compact( + context, backend=backend, goal=goal + ) + MelleaLogger.get_logger().info(f"## ReACT TURN NUMBER {turn_num}") step, next_context = await mfuncs.aact( diff --git a/test/stdlib/test_compaction.py b/test/stdlib/test_compaction.py new file mode 100644 index 000000000..9b2ff455d --- /dev/null +++ b/test/stdlib/test_compaction.py @@ -0,0 +1,344 @@ +"""Unit and integration tests for mellea.stdlib.compaction.""" + +from collections.abc import Sequence +from dataclasses import dataclass + +import pytest + +from mellea.core.backend import Backend, BaseModelSubclass +from mellea.core.base import ( + C, + CBlock, + Component, + Context, + GenerateLog, + ModelOutputThunk, + ModelToolCall, +) +from mellea.stdlib.compaction import ( + ClearAll, + KeepLastN, + LLMSummarize, + _find_prefix_end, + rebuild_chat_context, +) +from mellea.stdlib.components.chat import Message +from mellea.stdlib.components.react import ( + MELLEA_FINALIZER_TOOL, + ReactInitiator, + _mellea_finalize_tool, +) +from mellea.stdlib.context import ChatContext +from mellea.stdlib.frameworks.react import react + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _build_context(components: list[Component | CBlock]) -> ChatContext: + """Build a ChatContext from a list of components.""" + ctx = ChatContext() + for c in components: + ctx = ctx.add(c) + return ctx + + +def _msg(role: Message.Role, content: str) -> Message: + return Message(role=role, content=content) + + +# --------------------------------------------------------------------------- +# rebuild_chat_context +# --------------------------------------------------------------------------- + + +class TestRebuildChatContext: + def test_empty(self): + ctx = rebuild_chat_context([]) + assert ctx.as_list() == [] + + def test_round_trip(self): + components = [_msg("user", "hello"), _msg("assistant", "hi")] + ctx = rebuild_chat_context(components) + result = ctx.as_list() + assert len(result) == 2 + assert all(isinstance(c, Message) for c in result) + + def test_preserves_window_size(self): + ctx = rebuild_chat_context([_msg("user", "a")], window_size=3) + assert ctx._window_size == 3 + + +# --------------------------------------------------------------------------- +# _find_prefix_end +# --------------------------------------------------------------------------- + + +class TestFindPrefixEnd: + def test_no_initiator(self): + components = [_msg("user", "a"), _msg("assistant", "b")] + assert _find_prefix_end(components) == 0 + + def test_initiator_at_start(self): + components = [ReactInitiator("goal", []), _msg("user", "a")] + assert _find_prefix_end(components) == 1 + + def test_initiator_after_system_msg(self): + components = [ + _msg("system", "sys"), + ReactInitiator("goal", []), + _msg("user", "a"), + ] + assert _find_prefix_end(components) == 2 + + +# --------------------------------------------------------------------------- +# should_compact +# --------------------------------------------------------------------------- + + +class TestShouldCompact: + def test_below_threshold(self): + ctx = _build_context([_msg("user", "a"), _msg("assistant", "b")]) + strategy = KeepLastN(keep_n=1, threshold=5) + assert strategy.should_compact(ctx) is False + + def test_above_threshold(self): + ctx = _build_context([_msg("user", str(i)) for i in range(10)]) + strategy = KeepLastN(keep_n=1, threshold=5) + assert strategy.should_compact(ctx) is True + + def test_zero_threshold_never_triggers(self): + ctx = _build_context([_msg("user", str(i)) for i in range(10)]) + strategy = KeepLastN(keep_n=1, threshold=0) + assert strategy.should_compact(ctx) is False + + +# --------------------------------------------------------------------------- +# ClearAll +# --------------------------------------------------------------------------- + + +class TestClearAll: + @pytest.mark.asyncio + async def test_keeps_only_prefix(self): + initiator = ReactInitiator("find the answer", []) + components = [initiator, _msg("user", "a"), _msg("assistant", "b")] + ctx = _build_context(components) + + result = await ClearAll().compact(ctx) + result_list = result.as_list() + assert len(result_list) == 1 + assert isinstance(result_list[0], ReactInitiator) + + @pytest.mark.asyncio + async def test_empty_body_is_noop(self): + initiator = ReactInitiator("goal", []) + ctx = _build_context([initiator]) + + result = await ClearAll().compact(ctx) + assert len(result.as_list()) == 1 + + +# --------------------------------------------------------------------------- +# KeepLastN +# --------------------------------------------------------------------------- + + +class TestKeepLastN: + @pytest.mark.asyncio + async def test_keeps_prefix_and_last_n(self): + initiator = ReactInitiator("goal", []) + body = [_msg("user", str(i)) for i in range(10)] + ctx = _build_context([initiator, *body]) + + result = await KeepLastN(keep_n=3).compact(ctx) + result_list = result.as_list() + assert len(result_list) == 4 # 1 prefix + 3 body + assert isinstance(result_list[0], ReactInitiator) + # Last 3 body messages + for i, c in enumerate(result_list[1:]): + assert isinstance(c, Message) + assert c.content == str(7 + i) + + @pytest.mark.asyncio + async def test_fewer_than_n_is_noop(self): + initiator = ReactInitiator("goal", []) + body = [_msg("user", "a"), _msg("assistant", "b")] + ctx = _build_context([initiator, *body]) + + result = await KeepLastN(keep_n=5).compact(ctx) + # Should return original context unchanged + assert result is ctx + + @pytest.mark.asyncio + async def test_preserves_window_size(self): + initiator = ReactInitiator("goal", []) + body = [_msg("user", str(i)) for i in range(10)] + ctx = rebuild_chat_context([initiator, *body], window_size=7) + + result = await KeepLastN(keep_n=2).compact(ctx) + assert result._window_size == 7 + + +# --------------------------------------------------------------------------- +# LLMSummarize +# --------------------------------------------------------------------------- + + +@dataclass +class _ScriptedTurn: + """A single scripted backend response.""" + + value: str + tool_calls: dict[str, ModelToolCall] | None = None + + +class ScriptedBackend(Backend): + """Fake backend returning pre-scripted responses.""" + + def __init__(self, script: list[_ScriptedTurn]) -> None: + self._script = iter(script) + + async def _generate_from_context( + self, + action: Component[C] | CBlock, + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> tuple[ModelOutputThunk[C], Context]: + turn = next(self._script) + mot: ModelOutputThunk = ModelOutputThunk( + value=turn.value, tool_calls=turn.tool_calls + ) + mot._generate_log = GenerateLog(is_final_result=True) + return mot, ctx.add(action).add(mot) + + async def generate_from_raw( + self, + actions: Sequence[Component[C] | CBlock], + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> list[ModelOutputThunk]: + raise NotImplementedError + + +class TestLLMSummarize: + @pytest.mark.asyncio + async def test_raises_without_backend(self): + ctx = _build_context([ReactInitiator("g", []), _msg("user", "a")]) + with pytest.raises(ValueError, match="backend"): + await LLMSummarize(keep_n=0).compact(ctx) + + @pytest.mark.asyncio + async def test_raises_without_goal(self): + ctx = _build_context([ReactInitiator("g", []), _msg("user", "a")]) + backend = ScriptedBackend([]) + with pytest.raises(ValueError, match="goal"): + await LLMSummarize(keep_n=0).compact(ctx, backend=backend) + + @pytest.mark.asyncio + async def test_summarizes_old_keeps_recent(self): + initiator = ReactInitiator("goal", []) + body = [_msg("user", f"msg-{i}") for i in range(6)] + ctx = _build_context([initiator, *body]) + + # The backend will return one summary when the summarization prompt is sent + backend = ScriptedBackend([_ScriptedTurn(value="Summary of old messages")]) + + result = await LLMSummarize(keep_n=2).compact(ctx, backend=backend, goal="goal") + result_list = result.as_list() + + # prefix (1) + summary message (1) + last 2 body = 4 + assert len(result_list) == 4 + assert isinstance(result_list[0], ReactInitiator) + # Summary message + assert isinstance(result_list[1], Message) + assert "[CONTEXT SUMMARY]" in result_list[1].content + # Recent messages preserved + assert result_list[2].content == "msg-4" + assert result_list[3].content == "msg-5" + + @pytest.mark.asyncio + async def test_fewer_than_n_is_noop(self): + initiator = ReactInitiator("goal", []) + body = [_msg("user", "a")] + ctx = _build_context([initiator, *body]) + backend = ScriptedBackend([]) + + result = await LLMSummarize(keep_n=5).compact(ctx, backend=backend, goal="goal") + assert result is ctx + + +# --------------------------------------------------------------------------- +# Integration: react() with compaction +# --------------------------------------------------------------------------- + + +from mellea.backends.tools import MelleaTool + + +def _make_tool(name: str, return_value: str = "tool_result") -> MelleaTool: + def _fn() -> str: + return return_value + + return MelleaTool.from_callable(_fn, name=name) + + +def _final_answer_call(answer: str = "42") -> _ScriptedTurn: + tool = MelleaTool.from_callable(_mellea_finalize_tool, MELLEA_FINALIZER_TOOL) + tc = ModelToolCall(name=MELLEA_FINALIZER_TOOL, func=tool, args={"answer": answer}) + return _ScriptedTurn(value="", tool_calls={MELLEA_FINALIZER_TOOL: tc}) + + +def _tool_call_turn( + tool_name: str, tool: MelleaTool, thought: str = "thinking..." +) -> _ScriptedTurn: + tc = ModelToolCall(name=tool_name, func=tool, args={}) + return _ScriptedTurn(value=thought, tool_calls={tool_name: tc}) + + +class TestReactWithCompaction: + @pytest.mark.asyncio + @pytest.mark.integration + async def test_compaction_triggers_during_react(self): + """Compaction fires when context exceeds threshold, loop still completes.""" + search = _make_tool("search", "found it") + backend = ScriptedBackend( + [ + _tool_call_turn("search", search, "step 1"), + _tool_call_turn("search", search, "step 2"), + _tool_call_turn("search", search, "step 3"), + _final_answer_call("done"), + ] + ) + + result, _ctx = await react( + goal="find info", + context=ChatContext(), + backend=backend, + tools=[search], + loop_budget=10, + compaction=KeepLastN(keep_n=3, threshold=6), + ) + assert result.value == "done" + + @pytest.mark.asyncio + @pytest.mark.integration + async def test_no_compaction_when_disabled(self): + """Without compaction params, react behaves identically to before.""" + backend = ScriptedBackend([_final_answer_call("42")]) + result, _ = await react( + goal="answer", + context=ChatContext(), + backend=backend, + tools=None, + loop_budget=5, + ) + assert result.value == "42" From 455c93d970eed100cf234ba957094fd247849c38 Mon Sep 17 00:00:00 2001 From: Yousef El-Kurdi Date: Tue, 28 Apr 2026 18:58:21 +0000 Subject: [PATCH 2/5] refactor: express compaction threshold as token count Switches `CompactionStrategy.threshold` from a component-count trigger to a token-count trigger, read from the most recent `ModelOutputThunk.usage` populated by the backend. This aligns compaction with the real constraint (context size) and sidesteps per-backend tokenizer dependencies by using provider-reported usage; the trade-off is a one-turn lag since usage is recorded at the end of each model call. Also reorders the react loop so compaction runs after the final-answer check, skipping wasted work (and a wasted LLM call for LLMSummarize) on terminal turns. Assisted-by: Claude Code Signed-off-by: Yousef El-Kurdi --- mellea/stdlib/compaction.py | 67 +++++++++++++++++++++------ mellea/stdlib/frameworks/react.py | 12 ++--- test/stdlib/test_compaction.py | 77 ++++++++++++++++++++++++++----- 3 files changed, 125 insertions(+), 31 deletions(-) diff --git a/mellea/stdlib/compaction.py b/mellea/stdlib/compaction.py index 2f9636fb7..1f98d2eef 100644 --- a/mellea/stdlib/compaction.py +++ b/mellea/stdlib/compaction.py @@ -17,12 +17,13 @@ from mellea.stdlib.compaction import KeepLastN from mellea.stdlib.frameworks.react import react + # Compact once the most recent model call reports > 8000 prompt+completion tokens. await react( goal="...", context=ChatContext(), backend=m.backend, tools=[search_tool], - compaction=KeepLastN(keep_n=5, threshold=20), + compaction=KeepLastN(keep_n=5, threshold=8000), ) """ @@ -72,6 +73,26 @@ def _find_prefix_end(components: list[Component | CBlock]) -> int: return 0 +def _last_usage_tokens(context: ChatContext) -> int | None: + """Return ``total_tokens`` from the most recent ``ModelOutputThunk`` with usage. + + Walks *context* back-to-front looking for a ``ModelOutputThunk`` whose + ``usage`` dict has been populated by a backend's ``post_processing``. + Falls back to ``prompt_tokens + completion_tokens`` when ``total_tokens`` + is missing. Returns ``None`` if no usable token count can be recovered — + typically the case before the first model call completes. + """ + for c in reversed(context.as_list()): + if isinstance(c, ModelOutputThunk) and c.usage is not None: + total = c.usage.get("total_tokens") + if total is None: + pt = c.usage.get("prompt_tokens") or 0 + ct = c.usage.get("completion_tokens") or 0 + total = pt + ct + return total if total and total > 0 else None + return None + + # --------------------------------------------------------------------------- # Abstract base # --------------------------------------------------------------------------- @@ -80,9 +101,16 @@ def _find_prefix_end(components: list[Component | CBlock]) -> int: class CompactionStrategy(abc.ABC): """Abstract base class for context compaction strategies. - Each strategy carries a ``threshold`` — the component count above which - compaction should fire. The :meth:`should_compact` helper checks this so - callers don't need to track the threshold separately. + Each strategy carries a ``threshold`` — the token count above which + compaction should fire. The :meth:`should_compact` helper reads the + most recent ``ModelOutputThunk.usage`` populated by the backend and + compares its total token count to ``threshold``. + + Because ``usage`` is recorded when a model call completes, the measured + token count reflects the context as of the *previous* turn — any + components appended since (e.g. a tool response) are not yet included. + In practice this one-turn lag is negligible unless a single tool call + adds a very large payload. Subclasses implement :meth:`compact` which receives the current ``ChatContext`` and returns a compacted copy. The method is ``async`` @@ -90,25 +118,35 @@ class CompactionStrategy(abc.ABC): transparently; synchronous strategies simply never ``await``. Args: - threshold (int): Trigger compaction when the number of context - components exceeds this value. + threshold (int): Trigger compaction when the most recent thunk's + total token usage exceeds this value. ``0`` disables compaction. """ def __init__(self, *, threshold: int = 0) -> None: - """Initialize with the component-count threshold.""" + """Initialize with the token-count threshold.""" self.threshold = threshold def should_compact(self, context: ChatContext) -> bool: - """Return ``True`` when *context* exceeds the configured threshold. + """Return ``True`` when the last thunk's token usage exceeds ``threshold``. + + Reads ``total_tokens`` from the most recent ``ModelOutputThunk.usage`` + in *context*. Returns ``False`` when no thunk with usage is present + (e.g. before the first model call) or when ``threshold`` is not + positive. Args: context: The context to check. Returns: - ``True`` if the number of components exceeds ``self.threshold`` + ``True`` if the recovered token count exceeds ``self.threshold`` and ``self.threshold`` is greater than 0. """ - return self.threshold > 0 and len(context.as_list()) > self.threshold + if self.threshold <= 0: + return False + tokens = _last_usage_tokens(context) + if tokens is None: + return False + return tokens > self.threshold async def maybe_compact( self, @@ -163,7 +201,8 @@ class ClearAll(CompactionStrategy): The prefix is everything up to and including the first ``ReactInitiator``. Args: - threshold (int): Trigger compaction when context exceeds this many components. + threshold (int): Trigger compaction when the most recent thunk's total + token usage exceeds this value. """ async def compact( @@ -190,7 +229,8 @@ class KeepLastN(CompactionStrategy): Args: keep_n (int): Number of recent body components to retain. - threshold (int): Trigger compaction when context exceeds this many components. + threshold (int): Trigger compaction when the most recent thunk's total + token usage exceeds this value. """ def __init__(self, *, keep_n: int = 5, threshold: int = 0) -> None: @@ -230,7 +270,8 @@ class LLMSummarize(CompactionStrategy): Args: keep_n (int): Number of recent body components to retain verbatim. - threshold (int): Trigger compaction when context exceeds this many components. + threshold (int): Trigger compaction when the most recent thunk's total + token usage exceeds this value. """ def __init__(self, *, keep_n: int = 5, threshold: int = 0) -> None: diff --git a/mellea/stdlib/frameworks/react.py b/mellea/stdlib/frameworks/react.py index 7f39bba27..baf876ea0 100644 --- a/mellea/stdlib/frameworks/react.py +++ b/mellea/stdlib/frameworks/react.py @@ -86,12 +86,6 @@ async def react( while (turn_num < loop_budget) or (loop_budget == -1): turn_num += 1 - # -- Context compaction -- - if compaction is not None: - context = await compaction.maybe_compact( - context, backend=backend, goal=goal - ) - MelleaLogger.get_logger().info(f"## ReACT TURN NUMBER {turn_num}") step, next_context = await mfuncs.aact( @@ -140,4 +134,10 @@ async def react( step._underlying_value = str(tool_responses[0].content) return step, context + # Compact after the final-answer check so terminal turns skip it. + if compaction is not None: + context = await compaction.maybe_compact( + context, backend=backend, goal=goal + ) + raise RuntimeError(f"could not complete react loop in {loop_budget} iterations") diff --git a/test/stdlib/test_compaction.py b/test/stdlib/test_compaction.py index 9b2ff455d..3f4650e0d 100644 --- a/test/stdlib/test_compaction.py +++ b/test/stdlib/test_compaction.py @@ -20,6 +20,7 @@ KeepLastN, LLMSummarize, _find_prefix_end, + _last_usage_tokens, rebuild_chat_context, ) from mellea.stdlib.components.chat import Message @@ -48,6 +49,17 @@ def _msg(role: Message.Role, content: str) -> Message: return Message(role=role, content=content) +def _thunk(total_tokens: int, value: str = "") -> ModelOutputThunk: + """Build a ModelOutputThunk with a populated usage dict.""" + mot = ModelOutputThunk(value=value) + mot.usage = { + "prompt_tokens": total_tokens, + "completion_tokens": 0, + "total_tokens": total_tokens, + } + return mot + + # --------------------------------------------------------------------------- # rebuild_chat_context # --------------------------------------------------------------------------- @@ -98,19 +110,48 @@ def test_initiator_after_system_msg(self): # --------------------------------------------------------------------------- +class TestLastUsageTokens: + def test_no_thunk_returns_none(self): + ctx = _build_context([_msg("user", "a"), _msg("assistant", "b")]) + assert _last_usage_tokens(ctx) is None + + def test_thunk_without_usage_returns_none(self): + ctx = _build_context([_msg("user", "a"), ModelOutputThunk(value="b")]) + assert _last_usage_tokens(ctx) is None + + def test_reads_total_tokens(self): + ctx = _build_context([_msg("user", "a"), _thunk(total_tokens=150)]) + assert _last_usage_tokens(ctx) == 150 + + def test_falls_back_to_prompt_plus_completion(self): + mot = ModelOutputThunk(value="x") + mot.usage = {"prompt_tokens": 40, "completion_tokens": 20} + ctx = _build_context([_msg("user", "a"), mot]) + assert _last_usage_tokens(ctx) == 60 + + def test_uses_most_recent_thunk(self): + ctx = _build_context([_thunk(100), _msg("user", "x"), _thunk(500)]) + assert _last_usage_tokens(ctx) == 500 + + class TestShouldCompact: - def test_below_threshold(self): + def test_no_thunk_does_not_trigger(self): ctx = _build_context([_msg("user", "a"), _msg("assistant", "b")]) - strategy = KeepLastN(keep_n=1, threshold=5) + strategy = KeepLastN(keep_n=1, threshold=100) + assert strategy.should_compact(ctx) is False + + def test_below_threshold(self): + ctx = _build_context([_msg("user", "a"), _thunk(total_tokens=50)]) + strategy = KeepLastN(keep_n=1, threshold=100) assert strategy.should_compact(ctx) is False def test_above_threshold(self): - ctx = _build_context([_msg("user", str(i)) for i in range(10)]) - strategy = KeepLastN(keep_n=1, threshold=5) + ctx = _build_context([_msg("user", "a"), _thunk(total_tokens=500)]) + strategy = KeepLastN(keep_n=1, threshold=100) assert strategy.should_compact(ctx) is True def test_zero_threshold_never_triggers(self): - ctx = _build_context([_msg("user", str(i)) for i in range(10)]) + ctx = _build_context([_msg("user", "a"), _thunk(total_tokens=10_000)]) strategy = KeepLastN(keep_n=1, threshold=0) assert strategy.should_compact(ctx) is False @@ -193,6 +234,7 @@ class _ScriptedTurn: value: str tool_calls: dict[str, ModelToolCall] | None = None + total_tokens: int | None = None class ScriptedBackend(Backend): @@ -215,6 +257,12 @@ async def _generate_from_context( value=turn.value, tool_calls=turn.tool_calls ) mot._generate_log = GenerateLog(is_final_result=True) + if turn.total_tokens is not None: + mot.usage = { + "prompt_tokens": turn.total_tokens, + "completion_tokens": 0, + "total_tokens": turn.total_tokens, + } return mot, ctx.add(action).add(mot) async def generate_from_raw( @@ -298,23 +346,28 @@ def _final_answer_call(answer: str = "42") -> _ScriptedTurn: def _tool_call_turn( - tool_name: str, tool: MelleaTool, thought: str = "thinking..." + tool_name: str, + tool: MelleaTool, + thought: str = "thinking...", + total_tokens: int | None = None, ) -> _ScriptedTurn: tc = ModelToolCall(name=tool_name, func=tool, args={}) - return _ScriptedTurn(value=thought, tool_calls={tool_name: tc}) + return _ScriptedTurn( + value=thought, tool_calls={tool_name: tc}, total_tokens=total_tokens + ) class TestReactWithCompaction: @pytest.mark.asyncio @pytest.mark.integration async def test_compaction_triggers_during_react(self): - """Compaction fires when context exceeds threshold, loop still completes.""" + """Compaction fires when last thunk's token usage exceeds threshold.""" search = _make_tool("search", "found it") backend = ScriptedBackend( [ - _tool_call_turn("search", search, "step 1"), - _tool_call_turn("search", search, "step 2"), - _tool_call_turn("search", search, "step 3"), + _tool_call_turn("search", search, "step 1", total_tokens=200), + _tool_call_turn("search", search, "step 2", total_tokens=200), + _tool_call_turn("search", search, "step 3", total_tokens=200), _final_answer_call("done"), ] ) @@ -325,7 +378,7 @@ async def test_compaction_triggers_during_react(self): backend=backend, tools=[search], loop_budget=10, - compaction=KeepLastN(keep_n=3, threshold=6), + compaction=KeepLastN(keep_n=3, threshold=100), ) assert result.value == "done" From ca7bea1e9ee428c101aa739d2d4fe7f5f84a5c35 Mon Sep 17 00:00:00 2001 From: ramon-astudillo Date: Thu, 30 Apr 2026 13:18:18 -0400 Subject: [PATCH 3/5] Fix mot.generation.usage --- mellea/stdlib/compaction.py | 8 ++++---- mellea/stdlib/frameworks/react.py | 1 - test/stdlib/test_compaction.py | 10 ++++------ 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/mellea/stdlib/compaction.py b/mellea/stdlib/compaction.py index 1f98d2eef..20b60f336 100644 --- a/mellea/stdlib/compaction.py +++ b/mellea/stdlib/compaction.py @@ -83,11 +83,11 @@ def _last_usage_tokens(context: ChatContext) -> int | None: typically the case before the first model call completes. """ for c in reversed(context.as_list()): - if isinstance(c, ModelOutputThunk) and c.usage is not None: - total = c.usage.get("total_tokens") + if isinstance(c, ModelOutputThunk) and c.generation.usage is not None: + total = c.generation.usage.get("total_tokens") if total is None: - pt = c.usage.get("prompt_tokens") or 0 - ct = c.usage.get("completion_tokens") or 0 + pt = c.generation.usage.get("prompt_tokens") or 0 + ct = c.generation.usage.get("completion_tokens") or 0 total = pt + ct return total if total and total > 0 else None return None diff --git a/mellea/stdlib/frameworks/react.py b/mellea/stdlib/frameworks/react.py index baf876ea0..81dc04146 100644 --- a/mellea/stdlib/frameworks/react.py +++ b/mellea/stdlib/frameworks/react.py @@ -53,7 +53,6 @@ async def react( exceeds the strategy's configured threshold (e.g. ``KeepLastN(keep_n=5, threshold=20)``). - Returns: A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`. diff --git a/test/stdlib/test_compaction.py b/test/stdlib/test_compaction.py index 3f4650e0d..076faa7f6 100644 --- a/test/stdlib/test_compaction.py +++ b/test/stdlib/test_compaction.py @@ -5,6 +5,7 @@ import pytest +from mellea.backends.tools import MelleaTool from mellea.core.backend import Backend, BaseModelSubclass from mellea.core.base import ( C, @@ -52,7 +53,7 @@ def _msg(role: Message.Role, content: str) -> Message: def _thunk(total_tokens: int, value: str = "") -> ModelOutputThunk: """Build a ModelOutputThunk with a populated usage dict.""" mot = ModelOutputThunk(value=value) - mot.usage = { + mot.generation.usage = { "prompt_tokens": total_tokens, "completion_tokens": 0, "total_tokens": total_tokens, @@ -125,7 +126,7 @@ def test_reads_total_tokens(self): def test_falls_back_to_prompt_plus_completion(self): mot = ModelOutputThunk(value="x") - mot.usage = {"prompt_tokens": 40, "completion_tokens": 20} + mot.generation.usage = {"prompt_tokens": 40, "completion_tokens": 20} ctx = _build_context([_msg("user", "a"), mot]) assert _last_usage_tokens(ctx) == 60 @@ -258,7 +259,7 @@ async def _generate_from_context( ) mot._generate_log = GenerateLog(is_final_result=True) if turn.total_tokens is not None: - mot.usage = { + mot.generation.usage = { "prompt_tokens": turn.total_tokens, "completion_tokens": 0, "total_tokens": turn.total_tokens, @@ -329,9 +330,6 @@ async def test_fewer_than_n_is_noop(self): # --------------------------------------------------------------------------- -from mellea.backends.tools import MelleaTool - - def _make_tool(name: str, return_value: str = "tool_result") -> MelleaTool: def _fn() -> str: return return_value From 643f5a56cbc0a692683b9d5dde3d58d3ca543285 Mon Sep 17 00:00:00 2001 From: Yousef El-Kurdi Date: Thu, 30 Apr 2026 22:19:03 +0000 Subject: [PATCH 4/5] refactor: relocate compaction module into frameworks package Move the compaction strategies alongside the react framework they serve: - mellea/stdlib/compaction.py -> mellea/stdlib/frameworks/react_compaction.py - test/stdlib/test_compaction.py -> test/stdlib/frameworks/test_react_compaction.py Imports and module docstrings updated accordingly. Assisted-by: Claude Code Signed-off-by: Yousef El-Kurdi --- mellea/stdlib/frameworks/react.py | 2 +- .../stdlib/{compaction.py => frameworks/react_compaction.py} | 2 +- .../test_react_compaction.py} | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) rename mellea/stdlib/{compaction.py => frameworks/react_compaction.py} (99%) rename test/stdlib/{test_compaction.py => frameworks/test_react_compaction.py} (99%) diff --git a/mellea/stdlib/frameworks/react.py b/mellea/stdlib/frameworks/react.py index 81dc04146..921781cc0 100644 --- a/mellea/stdlib/frameworks/react.py +++ b/mellea/stdlib/frameworks/react.py @@ -15,7 +15,7 @@ from mellea.stdlib import functional as mfuncs # from mellea.stdlib.components.docs.document import Document -from mellea.stdlib.compaction import CompactionStrategy +from mellea.stdlib.frameworks.react_compaction import CompactionStrategy from mellea.stdlib.components.chat import ToolMessage from mellea.stdlib.components.react import ( MELLEA_FINALIZER_TOOL, diff --git a/mellea/stdlib/compaction.py b/mellea/stdlib/frameworks/react_compaction.py similarity index 99% rename from mellea/stdlib/compaction.py rename to mellea/stdlib/frameworks/react_compaction.py index 20b60f336..111b95524 100644 --- a/mellea/stdlib/compaction.py +++ b/mellea/stdlib/frameworks/react_compaction.py @@ -14,7 +14,7 @@ Example:: - from mellea.stdlib.compaction import KeepLastN + from mellea.stdlib.frameworks.react_compaction import KeepLastN from mellea.stdlib.frameworks.react import react # Compact once the most recent model call reports > 8000 prompt+completion tokens. diff --git a/test/stdlib/test_compaction.py b/test/stdlib/frameworks/test_react_compaction.py similarity index 99% rename from test/stdlib/test_compaction.py rename to test/stdlib/frameworks/test_react_compaction.py index 076faa7f6..07e5e44ce 100644 --- a/test/stdlib/test_compaction.py +++ b/test/stdlib/frameworks/test_react_compaction.py @@ -1,4 +1,4 @@ -"""Unit and integration tests for mellea.stdlib.compaction.""" +"""Unit and integration tests for mellea.stdlib.frameworks.react_compaction.""" from collections.abc import Sequence from dataclasses import dataclass @@ -16,7 +16,7 @@ ModelOutputThunk, ModelToolCall, ) -from mellea.stdlib.compaction import ( +from mellea.stdlib.frameworks.react_compaction import ( ClearAll, KeepLastN, LLMSummarize, From 16e7571a529169483afb82f44a3d61da0c8e8bda Mon Sep 17 00:00:00 2001 From: Yousef El-Kurdi Date: Wed, 6 May 2026 09:16:42 -0400 Subject: [PATCH 5/5] docs: add Args/Returns sections to react_compaction compact overrides MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The docstring quality gate (tooling/docs-autogen/audit_coverage.py --quality --threshold 100) requires each documented symbol to have its own Args/Returns sections — inheritance from the abstract parent is not consulted. Six issues were reported against the compact() overrides on ClearAll, KeepLastN, and LLMSummarize. Assisted-by: Claude Code Signed-off-by: Yousef El-Kurdi --- mellea/stdlib/frameworks/react_compaction.py | 35 ++++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/mellea/stdlib/frameworks/react_compaction.py b/mellea/stdlib/frameworks/react_compaction.py index 111b95524..ccc312b5f 100644 --- a/mellea/stdlib/frameworks/react_compaction.py +++ b/mellea/stdlib/frameworks/react_compaction.py @@ -212,7 +212,16 @@ async def compact( backend: Backend | None = None, goal: str | None = None, ) -> ChatContext: - """Return a context containing only the prefix.""" + """Return a context containing only the prefix. + + Args: + context: The context to compact. + backend: Unused by this strategy; accepted for interface compatibility. + goal: Unused by this strategy; accepted for interface compatibility. + + Returns: + A new ``ChatContext`` containing only the prefix components. + """ components = context.as_list() prefix_end = _find_prefix_end(components) compacted = components[:prefix_end] @@ -245,7 +254,18 @@ async def compact( backend: Backend | None = None, goal: str | None = None, ) -> ChatContext: - """Return a context with the prefix and the last *keep_n* body components.""" + """Return a context with the prefix and the last *keep_n* body components. + + Args: + context: The context to compact. + backend: Unused by this strategy; accepted for interface compatibility. + goal: Unused by this strategy; accepted for interface compatibility. + + Returns: + A new ``ChatContext`` with the prefix plus the most recent *keep_n* + body components, or the original *context* if the body is already + at or below *keep_n* in length. + """ components = context.as_list() prefix_end = _find_prefix_end(components) prefix = components[:prefix_end] @@ -288,6 +308,17 @@ async def compact( ) -> ChatContext: """Return a context with the prefix, an LLM summary, and recent body components. + Args: + context: The context to compact. + backend: Backend used to generate the summary; required. + goal: The react goal string, included in the summary prompt; required. + + Returns: + A new ``ChatContext`` containing the prefix, a single summary + ``Message`` produced by the backend, and the most recent *keep_n* + body components verbatim. Returns the original *context* if the + body is already at or below *keep_n* in length. + Raises: ValueError: If *backend* or *goal* are not provided. """