diff --git a/mellea/stdlib/frameworks/react.py b/mellea/stdlib/frameworks/react.py index 117af4866..921781cc0 100644 --- a/mellea/stdlib/frameworks/react.py +++ b/mellea/stdlib/frameworks/react.py @@ -15,6 +15,7 @@ from mellea.stdlib import functional as mfuncs # from mellea.stdlib.components.docs.document import Document +from mellea.stdlib.frameworks.react_compaction import CompactionStrategy from mellea.stdlib.components.chat import ToolMessage from mellea.stdlib.components.react import ( MELLEA_FINALIZER_TOOL, @@ -36,6 +37,7 @@ async def react( model_options: dict | None = None, tools: list[AbstractMelleaTool] | None, loop_budget: int = 10, + compaction: CompactionStrategy | None = None, ) -> tuple[ComputedModelOutputThunk[str], ChatContext]: """Asynchronous ReACT pattern (Think -> Act -> Observe -> Repeat Until Done); attempts to accomplish the provided goal given the provided tools. @@ -47,6 +49,9 @@ async def react( model_options: additional model options, which will upsert into the model/backend's defaults. tools: the list of tools to use loop_budget: the number of steps allowed; use -1 for unlimited + compaction: an optional ``CompactionStrategy`` to apply when the context + exceeds the strategy's configured threshold + (e.g. ``KeepLastN(keep_n=5, threshold=20)``). Returns: A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`. @@ -79,6 +84,7 @@ async def react( turn_num = 0 while (turn_num < loop_budget) or (loop_budget == -1): turn_num += 1 + MelleaLogger.get_logger().info(f"## ReACT TURN NUMBER {turn_num}") step, next_context = await mfuncs.aact( @@ -127,4 +133,10 @@ async def react( step._underlying_value = str(tool_responses[0].content) return step, context + # Compact after the final-answer check so terminal turns skip it. + if compaction is not None: + context = await compaction.maybe_compact( + context, backend=backend, goal=goal + ) + raise RuntimeError(f"could not complete react loop in {loop_budget} iterations") diff --git a/mellea/stdlib/frameworks/react_compaction.py b/mellea/stdlib/frameworks/react_compaction.py new file mode 100644 index 000000000..ccc312b5f --- /dev/null +++ b/mellea/stdlib/frameworks/react_compaction.py @@ -0,0 +1,397 @@ +"""Context compaction strategies for the ReACT framework. + +Provides modular, callable strategy objects to compact a ``ChatContext`` that +has grown too large during a react loop. Three strategies are available: + +- ``ClearAll`` — discard the entire conversation body, keeping only the prefix + (everything up to and including the ``ReactInitiator``). +- ``KeepLastN`` — keep the prefix plus the *n* most recent body components. +- ``LLMSummarize`` — ask the backend to summarize old body components into a + single ``Message``, then keep the last *n* body components verbatim. + +All strategies preserve the **prefix** (every component up to and including the +first ``ReactInitiator``) so the model retains its goal and tool definitions. + +Example:: + + from mellea.stdlib.frameworks.react_compaction import KeepLastN + from mellea.stdlib.frameworks.react import react + + # Compact once the most recent model call reports > 8000 prompt+completion tokens. + await react( + goal="...", + context=ChatContext(), + backend=m.backend, + tools=[search_tool], + compaction=KeepLastN(keep_n=5, threshold=8000), + ) +""" + +from __future__ import annotations + +import abc + +from mellea.core.backend import Backend +from mellea.core.base import CBlock, Component, ModelOutputThunk +from mellea.core.utils import MelleaLogger +from mellea.stdlib.components.chat import Message, ToolMessage +from mellea.stdlib.components.react import ReactInitiator +from mellea.stdlib.context import ChatContext + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def rebuild_chat_context( + components: list[Component | CBlock], *, window_size: int | None = None +) -> ChatContext: + """Build a fresh ``ChatContext`` from an ordered list of components. + + Args: + components: Components to add, in chronological order. + window_size: Optional sliding-window size for the new context. + + Returns: + A new ``ChatContext`` containing all *components*. + """ + ctx = ChatContext(window_size=window_size) + for c in components: + ctx = ctx.add(c) + return ctx + + +def _find_prefix_end(components: list[Component | CBlock]) -> int: + """Return the index *after* the first ``ReactInitiator``. + + Everything in ``components[:idx]`` is the prefix that must be preserved by + every compaction strategy. Returns 0 when no ``ReactInitiator`` is found. + """ + for i, c in enumerate(components): + if isinstance(c, ReactInitiator): + return i + 1 + return 0 + + +def _last_usage_tokens(context: ChatContext) -> int | None: + """Return ``total_tokens`` from the most recent ``ModelOutputThunk`` with usage. + + Walks *context* back-to-front looking for a ``ModelOutputThunk`` whose + ``usage`` dict has been populated by a backend's ``post_processing``. + Falls back to ``prompt_tokens + completion_tokens`` when ``total_tokens`` + is missing. Returns ``None`` if no usable token count can be recovered — + typically the case before the first model call completes. + """ + for c in reversed(context.as_list()): + if isinstance(c, ModelOutputThunk) and c.generation.usage is not None: + total = c.generation.usage.get("total_tokens") + if total is None: + pt = c.generation.usage.get("prompt_tokens") or 0 + ct = c.generation.usage.get("completion_tokens") or 0 + total = pt + ct + return total if total and total > 0 else None + return None + + +# --------------------------------------------------------------------------- +# Abstract base +# --------------------------------------------------------------------------- + + +class CompactionStrategy(abc.ABC): + """Abstract base class for context compaction strategies. + + Each strategy carries a ``threshold`` — the token count above which + compaction should fire. The :meth:`should_compact` helper reads the + most recent ``ModelOutputThunk.usage`` populated by the backend and + compares its total token count to ``threshold``. + + Because ``usage`` is recorded when a model call completes, the measured + token count reflects the context as of the *previous* turn — any + components appended since (e.g. a tool response) are not yet included. + In practice this one-turn lag is negligible unless a single tool call + adds a very large payload. + + Subclasses implement :meth:`compact` which receives the current + ``ChatContext`` and returns a compacted copy. The method is ``async`` + so that strategies requiring LLM calls (e.g. ``LLMSummarize``) work + transparently; synchronous strategies simply never ``await``. + + Args: + threshold (int): Trigger compaction when the most recent thunk's + total token usage exceeds this value. ``0`` disables compaction. + """ + + def __init__(self, *, threshold: int = 0) -> None: + """Initialize with the token-count threshold.""" + self.threshold = threshold + + def should_compact(self, context: ChatContext) -> bool: + """Return ``True`` when the last thunk's token usage exceeds ``threshold``. + + Reads ``total_tokens`` from the most recent ``ModelOutputThunk.usage`` + in *context*. Returns ``False`` when no thunk with usage is present + (e.g. before the first model call) or when ``threshold`` is not + positive. + + Args: + context: The context to check. + + Returns: + ``True`` if the recovered token count exceeds ``self.threshold`` + and ``self.threshold`` is greater than 0. + """ + if self.threshold <= 0: + return False + tokens = _last_usage_tokens(context) + if tokens is None: + return False + return tokens > self.threshold + + async def maybe_compact( + self, + context: ChatContext, + *, + backend: Backend | None = None, + goal: str | None = None, + ) -> ChatContext: + """Compact *context* only if it exceeds the threshold, otherwise return it unchanged. + + Args: + context: The context to check and potentially compact. + backend: The backend (forwarded to :meth:`compact`). + goal: The react goal string (forwarded to :meth:`compact`). + + Returns: + A compacted ``ChatContext`` if the threshold was exceeded, + or the original *context* unchanged. + """ + if self.should_compact(context): + return await self.compact(context, backend=backend, goal=goal) + return context + + @abc.abstractmethod + async def compact( + self, + context: ChatContext, + *, + backend: Backend | None = None, + goal: str | None = None, + ) -> ChatContext: + """Return a compacted copy of *context*. + + Args: + context: The context to compact. + backend: The backend (required by ``LLMSummarize``). + goal: The react goal string (required by ``LLMSummarize``). + + Returns: + A new, compacted ``ChatContext``. + """ + + +# --------------------------------------------------------------------------- +# Concrete strategies +# --------------------------------------------------------------------------- + + +class ClearAll(CompactionStrategy): + """Discard the entire conversation body, keeping only the prefix. + + The prefix is everything up to and including the first ``ReactInitiator``. + + Args: + threshold (int): Trigger compaction when the most recent thunk's total + token usage exceeds this value. + """ + + async def compact( + self, + context: ChatContext, + *, + backend: Backend | None = None, + goal: str | None = None, + ) -> ChatContext: + """Return a context containing only the prefix. + + Args: + context: The context to compact. + backend: Unused by this strategy; accepted for interface compatibility. + goal: Unused by this strategy; accepted for interface compatibility. + + Returns: + A new ``ChatContext`` containing only the prefix components. + """ + components = context.as_list() + prefix_end = _find_prefix_end(components) + compacted = components[:prefix_end] + + MelleaLogger.get_logger().info( + f"ClearAll: compacted context from {len(components)} to " + f"{len(compacted)} components" + ) + return rebuild_chat_context(compacted, window_size=context._window_size) + + +class KeepLastN(CompactionStrategy): + """Keep the prefix plus the last *keep_n* body components. + + Args: + keep_n (int): Number of recent body components to retain. + threshold (int): Trigger compaction when the most recent thunk's total + token usage exceeds this value. + """ + + def __init__(self, *, keep_n: int = 5, threshold: int = 0) -> None: + """Initialize with the number of recent body components to keep.""" + super().__init__(threshold=threshold) + self.keep_n = keep_n + + async def compact( + self, + context: ChatContext, + *, + backend: Backend | None = None, + goal: str | None = None, + ) -> ChatContext: + """Return a context with the prefix and the last *keep_n* body components. + + Args: + context: The context to compact. + backend: Unused by this strategy; accepted for interface compatibility. + goal: Unused by this strategy; accepted for interface compatibility. + + Returns: + A new ``ChatContext`` with the prefix plus the most recent *keep_n* + body components, or the original *context* if the body is already + at or below *keep_n* in length. + """ + components = context.as_list() + prefix_end = _find_prefix_end(components) + prefix = components[:prefix_end] + body = components[prefix_end:] + + if len(body) <= self.keep_n: + return context # nothing to compact + + compacted = prefix + body[-self.keep_n :] + + MelleaLogger.get_logger().info( + f"KeepLastN(keep_n={self.keep_n}): compacted context from " + f"{len(components)} to {len(compacted)} components" + ) + return rebuild_chat_context(compacted, window_size=context._window_size) + + +class LLMSummarize(CompactionStrategy): + """Summarize old body components with the LLM, keep last *keep_n* verbatim. + + Requires ``backend`` and ``goal`` to be passed to :meth:`compact`. + + Args: + keep_n (int): Number of recent body components to retain verbatim. + threshold (int): Trigger compaction when the most recent thunk's total + token usage exceeds this value. + """ + + def __init__(self, *, keep_n: int = 5, threshold: int = 0) -> None: + """Initialize with the number of recent body components to keep.""" + super().__init__(threshold=threshold) + self.keep_n = keep_n + + async def compact( + self, + context: ChatContext, + *, + backend: Backend | None = None, + goal: str | None = None, + ) -> ChatContext: + """Return a context with the prefix, an LLM summary, and recent body components. + + Args: + context: The context to compact. + backend: Backend used to generate the summary; required. + goal: The react goal string, included in the summary prompt; required. + + Returns: + A new ``ChatContext`` containing the prefix, a single summary + ``Message`` produced by the backend, and the most recent *keep_n* + body components verbatim. Returns the original *context* if the + body is already at or below *keep_n* in length. + + Raises: + ValueError: If *backend* or *goal* are not provided. + """ + if backend is None or goal is None: + raise ValueError( + "LLMSummarize requires both 'backend' and 'goal' arguments" + ) + + from mellea.stdlib import functional as mfuncs + from mellea.stdlib.context import SimpleContext + + components = context.as_list() + prefix_end = _find_prefix_end(components) + prefix = components[:prefix_end] + body = components[prefix_end:] + + if len(body) <= self.keep_n: + return context # nothing to compact + + old = body[: -self.keep_n] if self.keep_n > 0 else body + recent = body[-self.keep_n :] if self.keep_n > 0 else [] + + # Build a textual representation of old components for summarization. + context_lines: list[str] = [] + for c in old: + if isinstance(c, ToolMessage): + context_lines.append(f"tool ({c.name}): {c.content}") + elif isinstance(c, Message): + context_lines.append(f"{c.role}: {c.content}") + elif isinstance(c, ModelOutputThunk): + context_lines.append(f"assistant: {c.value}") + elif isinstance(c, CBlock): + context_lines.append(str(c)) + else: + context_lines.append(str(getattr(c, "content", c))) + + summary_prompt = ( + "You are summarizing research progress to maintain context " + "within token limits.\n\n" + f"GOAL: {goal}\n\n" + "Provide a comprehensive summary of the research context below. " + "Your summary should:\n" + "- Preserve ALL specific facts, numbers, names, URLs, and search " + "queries found\n" + "- Note which tools were called and what results were obtained\n" + "- Highlight key findings and any dead ends encountered\n" + "- Be structured clearly so the research can continue seamlessly" + "\n\nContext to summarize:\n" + f"{chr(10).join(context_lines)}" + ) + + summary_action = Message(role="user", content=summary_prompt) + result, _ = await mfuncs.aact( + action=summary_action, + context=SimpleContext(), + backend=backend, + requirements=[], + strategy=None, + await_result=True, + ) + + summary_text = result.value or "" + summary_message = Message( + role="user", + content=( + f"[CONTEXT SUMMARY]\n{summary_text}\n\nContinue working on: {goal}" + ), + ) + + compacted = [*prefix, summary_message, *recent] + + MelleaLogger.get_logger().info( + f"LLMSummarize(keep_n={self.keep_n}): compacted context from " + f"{len(components)} to {len(compacted)} components" + ) + return rebuild_chat_context(compacted, window_size=context._window_size) diff --git a/test/stdlib/frameworks/test_react_compaction.py b/test/stdlib/frameworks/test_react_compaction.py new file mode 100644 index 000000000..07e5e44ce --- /dev/null +++ b/test/stdlib/frameworks/test_react_compaction.py @@ -0,0 +1,395 @@ +"""Unit and integration tests for mellea.stdlib.frameworks.react_compaction.""" + +from collections.abc import Sequence +from dataclasses import dataclass + +import pytest + +from mellea.backends.tools import MelleaTool +from mellea.core.backend import Backend, BaseModelSubclass +from mellea.core.base import ( + C, + CBlock, + Component, + Context, + GenerateLog, + ModelOutputThunk, + ModelToolCall, +) +from mellea.stdlib.frameworks.react_compaction import ( + ClearAll, + KeepLastN, + LLMSummarize, + _find_prefix_end, + _last_usage_tokens, + rebuild_chat_context, +) +from mellea.stdlib.components.chat import Message +from mellea.stdlib.components.react import ( + MELLEA_FINALIZER_TOOL, + ReactInitiator, + _mellea_finalize_tool, +) +from mellea.stdlib.context import ChatContext +from mellea.stdlib.frameworks.react import react + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _build_context(components: list[Component | CBlock]) -> ChatContext: + """Build a ChatContext from a list of components.""" + ctx = ChatContext() + for c in components: + ctx = ctx.add(c) + return ctx + + +def _msg(role: Message.Role, content: str) -> Message: + return Message(role=role, content=content) + + +def _thunk(total_tokens: int, value: str = "") -> ModelOutputThunk: + """Build a ModelOutputThunk with a populated usage dict.""" + mot = ModelOutputThunk(value=value) + mot.generation.usage = { + "prompt_tokens": total_tokens, + "completion_tokens": 0, + "total_tokens": total_tokens, + } + return mot + + +# --------------------------------------------------------------------------- +# rebuild_chat_context +# --------------------------------------------------------------------------- + + +class TestRebuildChatContext: + def test_empty(self): + ctx = rebuild_chat_context([]) + assert ctx.as_list() == [] + + def test_round_trip(self): + components = [_msg("user", "hello"), _msg("assistant", "hi")] + ctx = rebuild_chat_context(components) + result = ctx.as_list() + assert len(result) == 2 + assert all(isinstance(c, Message) for c in result) + + def test_preserves_window_size(self): + ctx = rebuild_chat_context([_msg("user", "a")], window_size=3) + assert ctx._window_size == 3 + + +# --------------------------------------------------------------------------- +# _find_prefix_end +# --------------------------------------------------------------------------- + + +class TestFindPrefixEnd: + def test_no_initiator(self): + components = [_msg("user", "a"), _msg("assistant", "b")] + assert _find_prefix_end(components) == 0 + + def test_initiator_at_start(self): + components = [ReactInitiator("goal", []), _msg("user", "a")] + assert _find_prefix_end(components) == 1 + + def test_initiator_after_system_msg(self): + components = [ + _msg("system", "sys"), + ReactInitiator("goal", []), + _msg("user", "a"), + ] + assert _find_prefix_end(components) == 2 + + +# --------------------------------------------------------------------------- +# should_compact +# --------------------------------------------------------------------------- + + +class TestLastUsageTokens: + def test_no_thunk_returns_none(self): + ctx = _build_context([_msg("user", "a"), _msg("assistant", "b")]) + assert _last_usage_tokens(ctx) is None + + def test_thunk_without_usage_returns_none(self): + ctx = _build_context([_msg("user", "a"), ModelOutputThunk(value="b")]) + assert _last_usage_tokens(ctx) is None + + def test_reads_total_tokens(self): + ctx = _build_context([_msg("user", "a"), _thunk(total_tokens=150)]) + assert _last_usage_tokens(ctx) == 150 + + def test_falls_back_to_prompt_plus_completion(self): + mot = ModelOutputThunk(value="x") + mot.generation.usage = {"prompt_tokens": 40, "completion_tokens": 20} + ctx = _build_context([_msg("user", "a"), mot]) + assert _last_usage_tokens(ctx) == 60 + + def test_uses_most_recent_thunk(self): + ctx = _build_context([_thunk(100), _msg("user", "x"), _thunk(500)]) + assert _last_usage_tokens(ctx) == 500 + + +class TestShouldCompact: + def test_no_thunk_does_not_trigger(self): + ctx = _build_context([_msg("user", "a"), _msg("assistant", "b")]) + strategy = KeepLastN(keep_n=1, threshold=100) + assert strategy.should_compact(ctx) is False + + def test_below_threshold(self): + ctx = _build_context([_msg("user", "a"), _thunk(total_tokens=50)]) + strategy = KeepLastN(keep_n=1, threshold=100) + assert strategy.should_compact(ctx) is False + + def test_above_threshold(self): + ctx = _build_context([_msg("user", "a"), _thunk(total_tokens=500)]) + strategy = KeepLastN(keep_n=1, threshold=100) + assert strategy.should_compact(ctx) is True + + def test_zero_threshold_never_triggers(self): + ctx = _build_context([_msg("user", "a"), _thunk(total_tokens=10_000)]) + strategy = KeepLastN(keep_n=1, threshold=0) + assert strategy.should_compact(ctx) is False + + +# --------------------------------------------------------------------------- +# ClearAll +# --------------------------------------------------------------------------- + + +class TestClearAll: + @pytest.mark.asyncio + async def test_keeps_only_prefix(self): + initiator = ReactInitiator("find the answer", []) + components = [initiator, _msg("user", "a"), _msg("assistant", "b")] + ctx = _build_context(components) + + result = await ClearAll().compact(ctx) + result_list = result.as_list() + assert len(result_list) == 1 + assert isinstance(result_list[0], ReactInitiator) + + @pytest.mark.asyncio + async def test_empty_body_is_noop(self): + initiator = ReactInitiator("goal", []) + ctx = _build_context([initiator]) + + result = await ClearAll().compact(ctx) + assert len(result.as_list()) == 1 + + +# --------------------------------------------------------------------------- +# KeepLastN +# --------------------------------------------------------------------------- + + +class TestKeepLastN: + @pytest.mark.asyncio + async def test_keeps_prefix_and_last_n(self): + initiator = ReactInitiator("goal", []) + body = [_msg("user", str(i)) for i in range(10)] + ctx = _build_context([initiator, *body]) + + result = await KeepLastN(keep_n=3).compact(ctx) + result_list = result.as_list() + assert len(result_list) == 4 # 1 prefix + 3 body + assert isinstance(result_list[0], ReactInitiator) + # Last 3 body messages + for i, c in enumerate(result_list[1:]): + assert isinstance(c, Message) + assert c.content == str(7 + i) + + @pytest.mark.asyncio + async def test_fewer_than_n_is_noop(self): + initiator = ReactInitiator("goal", []) + body = [_msg("user", "a"), _msg("assistant", "b")] + ctx = _build_context([initiator, *body]) + + result = await KeepLastN(keep_n=5).compact(ctx) + # Should return original context unchanged + assert result is ctx + + @pytest.mark.asyncio + async def test_preserves_window_size(self): + initiator = ReactInitiator("goal", []) + body = [_msg("user", str(i)) for i in range(10)] + ctx = rebuild_chat_context([initiator, *body], window_size=7) + + result = await KeepLastN(keep_n=2).compact(ctx) + assert result._window_size == 7 + + +# --------------------------------------------------------------------------- +# LLMSummarize +# --------------------------------------------------------------------------- + + +@dataclass +class _ScriptedTurn: + """A single scripted backend response.""" + + value: str + tool_calls: dict[str, ModelToolCall] | None = None + total_tokens: int | None = None + + +class ScriptedBackend(Backend): + """Fake backend returning pre-scripted responses.""" + + def __init__(self, script: list[_ScriptedTurn]) -> None: + self._script = iter(script) + + async def _generate_from_context( + self, + action: Component[C] | CBlock, + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> tuple[ModelOutputThunk[C], Context]: + turn = next(self._script) + mot: ModelOutputThunk = ModelOutputThunk( + value=turn.value, tool_calls=turn.tool_calls + ) + mot._generate_log = GenerateLog(is_final_result=True) + if turn.total_tokens is not None: + mot.generation.usage = { + "prompt_tokens": turn.total_tokens, + "completion_tokens": 0, + "total_tokens": turn.total_tokens, + } + return mot, ctx.add(action).add(mot) + + async def generate_from_raw( + self, + actions: Sequence[Component[C] | CBlock], + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> list[ModelOutputThunk]: + raise NotImplementedError + + +class TestLLMSummarize: + @pytest.mark.asyncio + async def test_raises_without_backend(self): + ctx = _build_context([ReactInitiator("g", []), _msg("user", "a")]) + with pytest.raises(ValueError, match="backend"): + await LLMSummarize(keep_n=0).compact(ctx) + + @pytest.mark.asyncio + async def test_raises_without_goal(self): + ctx = _build_context([ReactInitiator("g", []), _msg("user", "a")]) + backend = ScriptedBackend([]) + with pytest.raises(ValueError, match="goal"): + await LLMSummarize(keep_n=0).compact(ctx, backend=backend) + + @pytest.mark.asyncio + async def test_summarizes_old_keeps_recent(self): + initiator = ReactInitiator("goal", []) + body = [_msg("user", f"msg-{i}") for i in range(6)] + ctx = _build_context([initiator, *body]) + + # The backend will return one summary when the summarization prompt is sent + backend = ScriptedBackend([_ScriptedTurn(value="Summary of old messages")]) + + result = await LLMSummarize(keep_n=2).compact(ctx, backend=backend, goal="goal") + result_list = result.as_list() + + # prefix (1) + summary message (1) + last 2 body = 4 + assert len(result_list) == 4 + assert isinstance(result_list[0], ReactInitiator) + # Summary message + assert isinstance(result_list[1], Message) + assert "[CONTEXT SUMMARY]" in result_list[1].content + # Recent messages preserved + assert result_list[2].content == "msg-4" + assert result_list[3].content == "msg-5" + + @pytest.mark.asyncio + async def test_fewer_than_n_is_noop(self): + initiator = ReactInitiator("goal", []) + body = [_msg("user", "a")] + ctx = _build_context([initiator, *body]) + backend = ScriptedBackend([]) + + result = await LLMSummarize(keep_n=5).compact(ctx, backend=backend, goal="goal") + assert result is ctx + + +# --------------------------------------------------------------------------- +# Integration: react() with compaction +# --------------------------------------------------------------------------- + + +def _make_tool(name: str, return_value: str = "tool_result") -> MelleaTool: + def _fn() -> str: + return return_value + + return MelleaTool.from_callable(_fn, name=name) + + +def _final_answer_call(answer: str = "42") -> _ScriptedTurn: + tool = MelleaTool.from_callable(_mellea_finalize_tool, MELLEA_FINALIZER_TOOL) + tc = ModelToolCall(name=MELLEA_FINALIZER_TOOL, func=tool, args={"answer": answer}) + return _ScriptedTurn(value="", tool_calls={MELLEA_FINALIZER_TOOL: tc}) + + +def _tool_call_turn( + tool_name: str, + tool: MelleaTool, + thought: str = "thinking...", + total_tokens: int | None = None, +) -> _ScriptedTurn: + tc = ModelToolCall(name=tool_name, func=tool, args={}) + return _ScriptedTurn( + value=thought, tool_calls={tool_name: tc}, total_tokens=total_tokens + ) + + +class TestReactWithCompaction: + @pytest.mark.asyncio + @pytest.mark.integration + async def test_compaction_triggers_during_react(self): + """Compaction fires when last thunk's token usage exceeds threshold.""" + search = _make_tool("search", "found it") + backend = ScriptedBackend( + [ + _tool_call_turn("search", search, "step 1", total_tokens=200), + _tool_call_turn("search", search, "step 2", total_tokens=200), + _tool_call_turn("search", search, "step 3", total_tokens=200), + _final_answer_call("done"), + ] + ) + + result, _ctx = await react( + goal="find info", + context=ChatContext(), + backend=backend, + tools=[search], + loop_budget=10, + compaction=KeepLastN(keep_n=3, threshold=100), + ) + assert result.value == "done" + + @pytest.mark.asyncio + @pytest.mark.integration + async def test_no_compaction_when_disabled(self): + """Without compaction params, react behaves identically to before.""" + backend = ScriptedBackend([_final_answer_call("42")]) + result, _ = await react( + goal="answer", + context=ChatContext(), + backend=backend, + tools=None, + loop_budget=5, + ) + assert result.value == "42"