diff --git a/docs/examples/agents/react/react_using_mellea.py b/docs/examples/agents/react/react_using_mellea.py
index 4bd23328a..54d3fc108 100644
--- a/docs/examples/agents/react/react_using_mellea.py
+++ b/docs/examples/agents/react/react_using_mellea.py
@@ -8,6 +8,7 @@
 from langchain_community.tools import DuckDuckGoSearchResults
 
 from mellea.backends.tools import MelleaTool
+from mellea.stdlib import functional as mfuncs
 from mellea.stdlib.context import ChatContext
 from mellea.stdlib.frameworks.react import react
 from mellea.stdlib.session import start_session
@@ -28,15 +29,75 @@ class Email(pydantic.BaseModel):
     body: str
 
 
+class TrueOrFalse(pydantic.BaseModel):
+    """Response indicating whether the ReACT agent has completed its task."""
+
+    answer: bool = pydantic.Field(
+        description="True if you have enough information to answer the user's question, False if you need more tool calls"
+    )
+
+
+async def last_loop_completion_check(
+    goal, step, context, backend, model_options, turn_num, loop_budget
+):
+    """Completion check that asks the model if it has the answer on the last iteration.
+
+    Only checks on the last iteration (when turn_num reaches loop_budget) to avoid
+    unnecessary LLM calls. Returns False for all other iterations.
+
+    Note: step.value is guaranteed to exist when this is called.
+    """
+    # Only check on last iteration (and not for unlimited budget)
+    if loop_budget == -1 or turn_num < loop_budget:
+        return False
+
+    response, _ = await mfuncs.achat(
+        content=f"Do you know the answer to the user's original query ({goal})? If so, respond with True. If you need to take more actions, then respond False.",
+        context=context,
+        backend=backend,
+        format=TrueOrFalse,
+    )
+    have_answer = TrueOrFalse.model_validate_json(response.content).answer
+
+    return have_answer
+
+
+async def custom_completion_check(
+    goal, step, context, backend, model_options, turn_num, loop_budget
+):
+    """Custom completion check combining keyword detection and fallback to last-loop check.
+
+    This runs every iteration:
+    1. First checks if response contains "final answer" for early termination
+    2. On the last iteration, falls back to asking the model if it has the answer
+
+    Note: step.value is guaranteed to exist when this is called.
+    """
+    # Check every iteration for "final answer" keyword (early termination)
+    if "final answer" in step.value.lower():
+        return True
+
+    # On last iteration, fall back to asking the model if it has the answer
+    if loop_budget != -1 and turn_num >= loop_budget:
+        return await last_loop_completion_check(
+            goal, step, context, backend, model_options, turn_num, loop_budget
+        )
+
+    return False
+
+
 async def main():
     """Example."""
-    # Simple version that just searches for an answer.
+    # Version with custom answer check that terminates early
+    # when the model says "final answer" and queries the LLM
+    # if it reaches the loop_budget.
     out, _ = await react(
         goal="What is the Mellea python library?",
         context=ChatContext(),
         backend=m.backend,
         tools=[search_tool],
         loop_budget=12,
+        answer_check=custom_completion_check,
     )
     print(out)
 
@@ -46,6 +107,7 @@ async def main():
     #     context=ChatContext(),
     #     backend=m.backend,
     #     tools=[search_tool],
+    #     answer_check=custom_completion_check,
     #     format=Email
     # )
     # print(out)
diff --git a/mellea/stdlib/frameworks/react.py b/mellea/stdlib/frameworks/react.py
index 117af4866..3a8cda73f 100644
--- a/mellea/stdlib/frameworks/react.py
+++ b/mellea/stdlib/frameworks/react.py
@@ -7,6 +7,10 @@
 history tracking. Raises ``RuntimeError`` if the loop ends without a final answer.
 """
 
+from collections.abc import Awaitable, Callable
+
+import pydantic
+
 # from PIL import Image as PILImage
 from mellea.backends.model_options import ModelOption
 from mellea.core.backend import Backend, BaseModelSubclass
@@ -24,6 +28,14 @@
 from mellea.stdlib.context import ChatContext
 
 
+class TrueOrFalse(pydantic.BaseModel):
+    """Response indicating whether the ReACT agent has completed its task."""
+
+    answer: bool = pydantic.Field(
+        description="True if you have enough information to answer the user's question, False if you need more tool calls"
+    )
+
+
 async def react(
     goal: str,
     context: ChatContext,
@@ -36,6 +48,19 @@ async def react(
     model_options: dict | None = None,
     tools: list[AbstractMelleaTool] | None,
     loop_budget: int = 10,
+    answer_check: Callable[
+        [
+            str,
+            ComputedModelOutputThunk[str],
+            ChatContext,
+            Backend,
+            dict | None,
+            int,
+            int,
+        ],
+        Awaitable[bool],
+    ]
+    | None = None,
 ) -> tuple[ComputedModelOutputThunk[str], ChatContext]:
     """Asynchronous ReACT pattern (Think -> Act -> Observe -> Repeat Until Done); attempts to accomplish the provided goal given the provided tools.
 
@@ -47,6 +72,11 @@ async def react(
         model_options: additional model options, which will upsert into the model/backend's defaults.
         tools: the list of tools to use
         loop_budget: the number of steps allowed; use -1 for unlimited
+        answer_check: optional callable to determine if the agent has completed its task.
+            Called every iteration when no tool calls are made and step.value exists (if provided).
+            Receives (goal, step, context, backend, model_options, turn_num, loop_budget).
+            Returns bool indicating if the task is complete.
+            If None, no answer check is performed (loop continues until finalizer or budget exhausted).
 
     Returns:
         A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`.
@@ -106,9 +136,31 @@ async def react(
                 if tool_res.name == MELLEA_FINALIZER_TOOL:
                     is_final = True
 
+        # Check if the agent has completed its task (runs every iteration if answer_check is provided and there's a value)
+        # The answer_check function can decide when to actually check based on turn_num and loop_budget
+        elif not is_final and answer_check and step.value:
+            have_answer = await answer_check(
+                goal, step, context, backend, model_options, turn_num, loop_budget
+            )
+
+            if have_answer:
+                # Create a synthetic finalizer tool response to be consistent with normal loop
+                finalizer_response = ToolMessage(
+                    role="tool",
+                    content=step.value or "",
+                    tool_output=step.value or "",
+                    name=MELLEA_FINALIZER_TOOL,
+                    args={},
+                    tool=None,  # type: ignore
+                )
+                tool_responses = [finalizer_response]
+                context = context.add(finalizer_response)
+                is_final = True
+
         if is_final:
             assert len(tool_responses) == 1, "multiple tools were called with 'final'"
 
+            # Apply format if requested
             if format is not None:
                 step, next_context = await mfuncs.aact(
                     action=ReactThought(),
diff --git a/test/stdlib/test_react_direct_answer.py b/test/stdlib/test_react_direct_answer.py
new file mode 100644
index 000000000..3405cf2d1
--- /dev/null
+++ b/test/stdlib/test_react_direct_answer.py
@@ -0,0 +1,108 @@
+"""Test ReACT framework handling of direct answers without tool calls."""
+
+import pydantic
+import pytest
+
+from mellea.backends.tools import tool
+from mellea.stdlib import functional as mfuncs
+from mellea.stdlib.context import ChatContext
+from mellea.stdlib.frameworks.react import react
+from mellea.stdlib.session import start_session
+
+
+class TrueOrFalse(pydantic.BaseModel):
+    """Response indicating whether the ReACT agent has completed its task."""
+
+    answer: bool = pydantic.Field(
+        description="True if you have enough information to answer the user's question, False if you need more tool calls"
+    )
+
+
+async def last_loop_completion_check(
+    goal, step, context, backend, model_options, turn_num, loop_budget
+):
+    """Completion check that asks the model if it has the answer on the last iteration.
+
+    Note: step.value is guaranteed to exist when this is called.
+    """
+    # Only check on last iteration (and not for unlimited budget)
+    if loop_budget == -1 or turn_num < loop_budget:
+        return False
+
+    response, _ = await mfuncs.achat(
+        content=f"Do you know the answer to the user's original query ({goal})? If so, respond with True. If you need to take more actions, then respond False.",
+        context=context,
+        backend=backend,
+        format=TrueOrFalse,
+    )
+    have_answer = TrueOrFalse.model_validate_json(response.content).answer
+    return have_answer
+
+
+@pytest.mark.ollama
+@pytest.mark.llm
+async def test_react_direct_answer_without_tools():
+    """Test that ReACT handles direct answers when model doesn't call tools.
+
+    This tests the case where the model provides a direct answer in step.value
+    without making any tool calls. With ``answer_check`` the loop can accept that
+    answer on the final turn instead of ending without a final answer.
+    """
+    m = start_session()
+
+    # Ask a simple question that doesn't require tools
+    # The model should provide a direct answer without calling any tools
+    out, _ = await react(
+        goal="What is 2 + 2?",
+        context=ChatContext(),
+        backend=m.backend,
+        tools=[],  # No tools provided
+        loop_budget=3,  # answer_check only consults the model on the final turn
+        answer_check=last_loop_completion_check,
+    )
+
+    # Verify we got an answer
+    assert out.value is not None
+    assert len(out.value) > 0
+
+    # The answer should contain "4" or "four"
+    answer_lower = out.value.lower()
+    assert "4" in answer_lower or "four" in answer_lower
+
+
+@pytest.mark.ollama
+@pytest.mark.llm
+async def test_react_direct_answer_with_unused_tools():
+    """Test that ReACT handles direct answers even when tools are available.
+
+    This tests the case where tools are provided but the model chooses to
+    answer directly without using them.
+    """
+    m = start_session()
+
+    # Create a dummy tool that won't be needed
+    @tool
+    def search_web(query: str) -> str:
+        """Search the web for information."""
+        return "Search results"
+
+    # Ask a question that doesn't need the tool
+    out, _ = await react(
+        goal="What is the capital of France?",
+        context=ChatContext(),
+        backend=m.backend,
+        tools=[search_web],
+        loop_budget=3,
+        answer_check=last_loop_completion_check,
+    )
+
+    # Verify we got an answer
+    assert out.value is not None
+    assert len(out.value) > 0
+
+    # The answer should mention Paris
+    answer_lower = out.value.lower()
+    assert "paris" in answer_lower
+
+
+# Made with Bob