From c2113e3d56eada0882c0599c2f6bd277f6ff828a Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Mon, 13 Apr 2026 16:37:36 -0700 Subject: [PATCH 01/20] feat: add tool calling support to m serve Signed-off-by: Mark Sturdevant --- cli/serve/app.py | 43 ++- cli/serve/models.py | 26 ++ docs/examples/m_serve/client_tool_calling.py | 208 +++++++++++++ .../m_serve/m_serve_example_tool_calling.py | 177 +++++++++++ test/cli/test_serve.py | 16 +- test/cli/test_serve_tool_calling.py | 290 ++++++++++++++++++ 6 files changed, 750 insertions(+), 10 deletions(-) create mode 100644 docs/examples/m_serve/client_tool_calling.py create mode 100644 docs/examples/m_serve/m_serve_example_tool_calling.py create mode 100644 test/cli/test_serve_tool_calling.py diff --git a/cli/serve/app.py b/cli/serve/app.py index 583b28c01..3307978c5 100644 --- a/cli/serve/app.py +++ b/cli/serve/app.py @@ -3,10 +3,12 @@ import asyncio import importlib.util import inspect +import json import os import sys import time import uuid +from typing import Literal try: import typer @@ -26,10 +28,12 @@ from .models import ( ChatCompletion, ChatCompletionMessage, + ChatCompletionMessageToolCall, ChatCompletionRequest, Choice, OpenAIError, OpenAIErrorResponse, + ToolCallFunction, ) from .streaming import stream_chat_completion_chunks @@ -111,14 +115,14 @@ def _build_model_options(request: ChatCompletionRequest) -> dict: "response_format", # Response format (json_object) - not yet implemented "functions", # Legacy function calling - not yet implemented "function_call", # Legacy function calling - not yet implemented - "tools", # Tool calling - not yet implemented - "tool_choice", # Tool choice - not yet implemented + # Tool choice is passed through as-is (not a ModelOption sentinel) } openai_to_model_option = { "temperature": ModelOption.TEMPERATURE, "max_tokens": ModelOption.MAX_NEW_TOKENS, "seed": ModelOption.SEED, "stream": ModelOption.STREAM, + "tools": ModelOption.TOOLS, } # Get all non-None fields @@ -171,6 +175,35 @@ async def endpoint(request: ChatCompletionRequest): model_options=model_options, ) + # Extract tool calls from the ModelOutputThunk if available + tool_calls = None + finish_reason: Literal[ + "stop", "length", "content_filter", "tool_calls", "function_call" + ] = "stop" + if ( + hasattr(output, "tool_calls") + and output.tool_calls is not None + and isinstance(output.tool_calls, dict) + ): + tool_calls = [] + for tool_name, model_tool_call in output.tool_calls.items(): + # Generate a unique ID for this tool call + tool_call_id = f"call_{uuid.uuid4().hex[:24]}" + + # Serialize the arguments to JSON string + args_json = json.dumps(model_tool_call.args) + + tool_calls.append( + ChatCompletionMessageToolCall( + id=tool_call_id, + type="function", + function=ToolCallFunction( + name=model_tool_call.name, arguments=args_json + ), + ) + ) + finish_reason = "tool_calls" + # system_fingerprint represents backend config hash, not model name # The model name is already in response.model (line 73) # Leave as None since we don't track backend config fingerprints yet @@ -198,9 +231,11 @@ async def endpoint(request: ChatCompletionRequest): Choice( index=0, message=ChatCompletionMessage( - content=output.value, role="assistant" + content=output.value, + role="assistant", + tool_calls=tool_calls, ), - finish_reason="stop", + finish_reason=finish_reason, ) ], object="chat.completion", # type: ignore diff --git a/cli/serve/models.py b/cli/serve/models.py index 7e738730e..ba0bd8cca 100644 --- a/cli/serve/models.py +++ b/cli/serve/models.py @@ -80,6 +80,29 @@ class ChatCompletionRequest(BaseModel): extra: dict[str, Any] = Field(default_factory=dict) +class ToolCallFunction(BaseModel): + """Function details for a tool call.""" + + name: str + """The name of the function to call.""" + + arguments: str + """The arguments to call the function with, as a JSON string.""" + + +class ChatCompletionMessageToolCall(BaseModel): + """A tool call generated by the model.""" + + id: str + """The ID of the tool call.""" + + type: Literal["function"] + """The type of the tool. Currently, only 'function' is supported.""" + + function: ToolCallFunction + """The function that the model called.""" + + # Taking this from OpenAI types https://github.com/openai/openai-python/blob/main/src/openai/types/chat/chat_completion.py, class ChatCompletionMessage(BaseModel): content: str | None = None @@ -91,6 +114,9 @@ class ChatCompletionMessage(BaseModel): role: Literal["assistant"] """The role of the author of this message.""" + tool_calls: list[ChatCompletionMessageToolCall] | None = None + """The tool calls generated by the model, such as function calls.""" + class Choice(BaseModel): index: int diff --git a/docs/examples/m_serve/client_tool_calling.py b/docs/examples/m_serve/client_tool_calling.py new file mode 100644 index 000000000..fca522c76 --- /dev/null +++ b/docs/examples/m_serve/client_tool_calling.py @@ -0,0 +1,208 @@ +"""Client example for testing tool calling with m serve. + +This script demonstrates how to interact with an m serve server +that supports tool calling using the OpenAI-compatible API. + +Usage: + 1. Start the server: + uv run m serve docs/examples/m_serve/m_serve_example_tool_calling.py + + 2. Run this client: + uv run python docs/examples/m_serve/client_tool_calling.py +""" + +import json + +import requests + +# Server configuration +BASE_URL = "http://localhost:8080" +ENDPOINT = f"{BASE_URL}/v1/chat/completions" + +# Define tools in OpenAI format +tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city name, e.g. San Francisco", + }, + "units": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "Temperature units", + }, + }, + "required": ["location"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "calculator", + "description": "Evaluate a mathematical expression", + "parameters": { + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "The mathematical expression to evaluate", + } + }, + "required": ["expression"], + }, + }, + }, +] + + +def make_request(messages: list[dict], tools: list[dict] | None = None) -> dict: + """Make a request to the m serve API. + + Args: + messages: List of message dictionaries + tools: Optional list of tool definitions + + Returns: + Response dictionary from the API + """ + payload = { + "model": "gpt-3.5-turbo", # Model name (not used by m serve) + "messages": messages, + "temperature": 0.7, + } + + if tools: + payload["tools"] = tools + payload["tool_choice"] = "auto" + + response = requests.post(ENDPOINT, json=payload, timeout=30) + response.raise_for_status() + return response.json() + + +def main(): + """Run example tool calling interactions.""" + print("=" * 60) + print("Tool Calling Example with m serve") + print("=" * 60) + + # Example 1: Request that should trigger weather tool + print("\n1. Weather Query") + print("-" * 60) + messages = [{"role": "user", "content": "What's the weather like in Tokyo?"}] + + print(f"User: {messages[0]['content']}") + response = make_request(messages, tools=tools) + + choice = response["choices"][0] + print(f"\nFinish Reason: {choice['finish_reason']}") + + if choice.get("message", {}).get("tool_calls"): + print("\nTool Calls:") + for tool_call in choice["message"]["tool_calls"]: + func = tool_call["function"] + args = json.loads(func["arguments"]) + print(f" - {func['name']}({json.dumps(args)})") + else: + print(f"Assistant: {choice['message']['content']}") + + # Example 2: Request that should trigger calculator tool + print("\n\n2. Math Query") + print("-" * 60) + messages = [{"role": "user", "content": "What is 15 * 23 + 7?"}] + + print(f"User: {messages[0]['content']}") + response = make_request(messages, tools=tools) + + choice = response["choices"][0] + print(f"\nFinish Reason: {choice['finish_reason']}") + + if choice.get("message", {}).get("tool_calls"): + print("\nTool Calls:") + for tool_call in choice["message"]["tool_calls"]: + func = tool_call["function"] + args = json.loads(func["arguments"]) + print(f" - {func['name']}({json.dumps(args)})") + else: + print(f"Assistant: {choice['message']['content']}") + + # Example 3: Request without tools (normal chat) + print("\n\n3. Normal Chat (No Tools)") + print("-" * 60) + messages = [{"role": "user", "content": "Hello! How are you?"}] + + print(f"User: {messages[0]['content']}") + response = make_request(messages, tools=None) + + choice = response["choices"][0] + print(f"\nFinish Reason: {choice['finish_reason']}") + print(f"Assistant: {choice['message']['content']}") + + # Example 4: Multi-turn conversation with tool use + print("\n\n4. Multi-turn Conversation") + print("-" * 60) + messages = [{"role": "user", "content": "What's the weather in Paris?"}] + + print(f"User: {messages[0]['content']}") + response = make_request(messages, tools=tools) + + choice = response["choices"][0] + assistant_message = choice["message"] + + if assistant_message.get("tool_calls"): + print("\nAssistant requested tool calls:") + for tool_call in assistant_message["tool_calls"]: + func = tool_call["function"] + args = json.loads(func["arguments"]) + print(f" - {func['name']}({json.dumps(args)})") + + # Simulate tool execution + if func["name"] == "get_weather": + tool_result = f"The weather in {args['location']} is sunny and 22°C" + else: + tool_result = "Tool result" + + # Add tool response to conversation + messages.append( + { + "role": "assistant", + "content": assistant_message.get("content"), + "tool_calls": assistant_message["tool_calls"], + } + ) + messages.append( + { + "role": "tool", + "tool_call_id": tool_call["id"], + "content": tool_result, + } + ) + + # Get final response after tool execution + print("\nGetting final response after tool execution...") + response = make_request(messages, tools=tools) + choice = response["choices"][0] + print(f"Assistant: {choice['message']['content']}") + + print("\n" + "=" * 60) + print("Examples completed!") + print("=" * 60) + + +if __name__ == "__main__": + try: + main() + except requests.exceptions.ConnectionError: + print("Error: Could not connect to server.") + print("Make sure the server is running:") + print(" uv run m serve docs/examples/m_serve/m_serve_example_tool_calling.py") + except Exception as e: + print(f"Error: {e}") diff --git a/docs/examples/m_serve/m_serve_example_tool_calling.py b/docs/examples/m_serve/m_serve_example_tool_calling.py new file mode 100644 index 000000000..e7dbedd29 --- /dev/null +++ b/docs/examples/m_serve/m_serve_example_tool_calling.py @@ -0,0 +1,177 @@ +# pytest: ollama, e2e + +"""Example demonstrating tool calling with m serve. + +This example shows how to use the OpenAI-compatible tool calling API +with m serve. The server will accept tool definitions and return tool +calls in the response when the model decides to use them. +""" + +from typing import Any + +import mellea +from cli.serve.models import ChatMessage +from mellea.core import ModelOutputThunk, Requirement +from mellea.core.base import AbstractMelleaTool +from mellea.stdlib.context import ChatContext + +session = mellea.start_session(ctx=ChatContext()) + + +class GetWeatherTool(AbstractMelleaTool): + """Tool for getting weather information.""" + + name = "get_weather" + + def run(self, location: str, units: str = "celsius") -> str: + """Get the current weather for a location. + + Args: + location: The city name + units: Temperature units (celsius or fahrenheit) + + Returns: + Weather information as a string + """ + # In a real implementation, this would call a weather API + return f"The weather in {location} is sunny and 22°{units[0].upper()}" + + @property + def as_json_tool(self) -> dict[str, Any]: + """Return JSON schema for this tool.""" + return { + "type": "function", + "function": { + "name": self.name, + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city name, e.g. San Francisco", + }, + "units": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "Temperature units", + }, + }, + "required": ["location"], + }, + }, + } + + +class CalculatorTool(AbstractMelleaTool): + """Tool for performing calculations.""" + + name = "calculator" + + def run(self, expression: str) -> str: + """Evaluate a mathematical expression. + + Args: + expression: A mathematical expression to evaluate + + Returns: + The result of the calculation + """ + try: + # In a real implementation, use a safe expression evaluator + result = eval(expression) # noqa: S307 + return f"The result is {result}" + except Exception as e: + return f"Error evaluating expression: {e}" + + @property + def as_json_tool(self) -> dict[str, Any]: + """Return JSON schema for this tool.""" + return { + "type": "function", + "function": { + "name": self.name, + "description": "Evaluate a mathematical expression", + "parameters": { + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "The mathematical expression to evaluate, e.g. '2 + 2'", + } + }, + "required": ["expression"], + }, + }, + } + + +# Create tool instances +weather_tool = GetWeatherTool() +calculator_tool = CalculatorTool() + +# Map tool names to instances for easy lookup +TOOLS = {weather_tool.name: weather_tool, calculator_tool.name: calculator_tool} + + +def serve( + input: list[ChatMessage], + requirements: list[str] | None = None, + model_options: None | dict = None, +) -> ModelOutputThunk: + """Serve function that handles tool calling. + + This function demonstrates how to use tools with m serve. The tools + are passed via model_options and the model can request to call them. + + Args: + input: List of chat messages + requirements: Optional list of requirement strings + model_options: Model options including tools and tool_choice + + Returns: + ModelOutputThunk with potential tool calls + """ + requirements = requirements if requirements else [] + message = input[-1].content + + # Extract tools from model_options if provided + tools = None + if model_options and "@@@tools@@@" in model_options: + # Convert OpenAI tool format to Mellea tool format + openai_tools = model_options["@@@tools@@@"] + tools = {} + for tool_def in openai_tools: + tool_name = tool_def["function"]["name"] + if tool_name in TOOLS: + tools[tool_name] = TOOLS[tool_name] + + # Build model options with tools + final_model_options = model_options or {} + if tools: + final_model_options["@@@tools@@@"] = tools + + # Use instruct to generate response with potential tool calls + result = session.instruct( + description=message, # type: ignore + requirements=[Requirement(req) for req in requirements], # type: ignore + model_options=final_model_options, + ) + + return result + + +if __name__ == "__main__": + # Example usage (for testing purposes) + test_messages = [ChatMessage(role="user", content="What's the weather in Paris?")] + + # Simulate tool definitions being passed + test_model_options = { + "@@@tools@@@": [weather_tool.as_json_tool, calculator_tool.as_json_tool] + } + + response = serve(input=test_messages, model_options=test_model_options) + + print(f"Response: {response.value}") + if response.tool_calls: + print(f"Tool calls requested: {list(response.tool_calls.keys())}") diff --git a/test/cli/test_serve.py b/test/cli/test_serve.py index 2e626e6d5..e20deb6e3 100644 --- a/test/cli/test_serve.py +++ b/test/cli/test_serve.py @@ -451,18 +451,19 @@ async def test_unsupported_params_excluded_from_model_options(self, mock_module) assert "logit_bias" not in model_options @pytest.mark.asyncio - async def test_tool_params_excluded_from_model_options(self, mock_module): - """Test that tool-related parameters are excluded from model_options.""" + async def test_tool_params_passed_to_model_options(self, mock_module): + """Test that tool-related parameters are passed to model_options.""" from cli.serve.models import ( FunctionDefinition, FunctionParameters, ToolFunction, ) + from mellea.backends.model_options import ModelOption request = ChatCompletionRequest( model="test-model", messages=[ChatMessage(role="user", content="Hello")], - # Tool-related parameters that should be excluded + # Tool-related parameters tools=[ ToolFunction( type="function", @@ -498,9 +499,12 @@ async def test_tool_params_excluded_from_model_options(self, mock_module): assert call_args is not None model_options = call_args.kwargs["model_options"] - # Tool-related parameters should NOT be in model_options - assert "tools" not in model_options - assert "tool_choice" not in model_options + # Tools should be passed with ModelOption.TOOLS key + assert ModelOption.TOOLS in model_options + # tool_choice should be passed through as-is + assert "tool_choice" in model_options + assert model_options["tool_choice"] == "auto" + # Legacy function calling parameters should still be excluded assert "functions" not in model_options assert "function_call" not in model_options diff --git a/test/cli/test_serve_tool_calling.py b/test/cli/test_serve_tool_calling.py new file mode 100644 index 000000000..ffdb8f3e3 --- /dev/null +++ b/test/cli/test_serve_tool_calling.py @@ -0,0 +1,290 @@ +"""Tests for tool calling support in m serve OpenAI-compatible API server.""" + +import json +from typing import Any +from unittest.mock import Mock + +import pytest + +from cli.serve.app import make_chat_endpoint +from cli.serve.models import ( + ChatCompletion, + ChatCompletionRequest, + ChatMessage, + FunctionDefinition, + FunctionParameters, + ToolFunction, +) +from mellea.core.base import AbstractMelleaTool, ModelOutputThunk, ModelToolCall + + +class MockTool(AbstractMelleaTool): + """Mock tool for testing.""" + + name = "get_weather" + + def run(self, location: str) -> str: + """Mock run method.""" + return f"Weather in {location} is sunny" + + @property + def as_json_tool(self) -> dict[str, Any]: + """Return JSON schema for this tool.""" + return { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + "required": ["location"], + }, + }, + } + + +@pytest.fixture +def mock_module(): + """Create a mock module with a serve function.""" + module = Mock() + module.__name__ = "test_module" + return module + + +@pytest.fixture +def sample_tool_request(): + """Create a sample ChatCompletionRequest with tools.""" + return ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="What's the weather in Paris?")], + tools=[ + ToolFunction( + type="function", + function=FunctionDefinition( + name="get_weather", + description="Get the current weather in a location", + parameters=FunctionParameters( + RootModel={ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city name", + } + }, + "required": ["location"], + } + ), + ), + ) + ], + tool_choice="auto", + ) + + +class TestToolCalling: + """Tests for tool calling functionality.""" + + @pytest.mark.asyncio + async def test_tool_calls_in_response(self, mock_module, sample_tool_request): + """Test that tool calls are properly formatted in the response.""" + # Setup mock output with tool calls + mock_output = ModelOutputThunk("I'll check the weather for you.") + mock_tool = MockTool() + mock_output.tool_calls = { + "get_weather": ModelToolCall( + name="get_weather", func=mock_tool, args={"location": "Paris"} + ) + } + mock_module.serve.return_value = mock_output + + # Create endpoint and call it + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(sample_tool_request) + + # Verify response structure + assert isinstance(response, ChatCompletion) + assert response.choices[0].finish_reason == "tool_calls" + assert response.choices[0].message.tool_calls is not None + assert len(response.choices[0].message.tool_calls) == 1 + + # Verify tool call details + tool_call = response.choices[0].message.tool_calls[0] + assert tool_call.type == "function" + assert tool_call.function.name == "get_weather" + + # Parse and verify arguments + args = json.loads(tool_call.function.arguments) + assert args == {"location": "Paris"} + + # Verify tool call ID format + assert tool_call.id.startswith("call_") + assert len(tool_call.id) > len("call_") + + @pytest.mark.asyncio + async def test_multiple_tool_calls(self, mock_module, sample_tool_request): + """Test handling multiple tool calls in a single response.""" + mock_output = ModelOutputThunk("I'll check multiple locations.") + mock_tool = MockTool() + mock_output.tool_calls = { + "get_weather_paris": ModelToolCall( + name="get_weather", func=mock_tool, args={"location": "Paris"} + ), + "get_weather_london": ModelToolCall( + name="get_weather", func=mock_tool, args={"location": "London"} + ), + } + mock_module.serve.return_value = mock_output + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(sample_tool_request) + + # Verify multiple tool calls + assert response.choices[0].finish_reason == "tool_calls" + assert len(response.choices[0].message.tool_calls) == 2 + + # Verify each tool call has unique ID + ids = [tc.id for tc in response.choices[0].message.tool_calls] + assert len(ids) == len(set(ids)), "Tool call IDs should be unique" + + @pytest.mark.asyncio + async def test_no_tool_calls_finish_reason_stop( + self, mock_module, sample_tool_request + ): + """Test that finish_reason is 'stop' when no tool calls are made.""" + mock_output = ModelOutputThunk("The weather is sunny.") + # No tool_calls set + mock_module.serve.return_value = mock_output + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(sample_tool_request) + + assert response.choices[0].finish_reason == "stop" + assert response.choices[0].message.tool_calls is None + + @pytest.mark.asyncio + async def test_tools_passed_to_model_options( + self, mock_module, sample_tool_request + ): + """Test that tools are passed to serve function in model_options.""" + from mellea.backends.model_options import ModelOption + + mock_output = ModelOutputThunk("Test response") + mock_module.serve.return_value = mock_output + + endpoint = make_chat_endpoint(mock_module) + await endpoint(sample_tool_request) + + # Verify serve was called with tools in model_options + call_args = mock_module.serve.call_args + assert call_args is not None + model_options = call_args.kwargs["model_options"] + + # Tools should be in model_options with the ModelOption.TOOLS key + assert ModelOption.TOOLS in model_options + assert model_options[ModelOption.TOOLS] is not None + + @pytest.mark.asyncio + async def test_tool_choice_passed_to_model_options( + self, mock_module, sample_tool_request + ): + """Test that tool_choice is passed to serve function in model_options.""" + mock_output = ModelOutputThunk("Test response") + mock_module.serve.return_value = mock_output + + endpoint = make_chat_endpoint(mock_module) + await endpoint(sample_tool_request) + + # Verify serve was called with tool_choice in model_options + call_args = mock_module.serve.call_args + assert call_args is not None + model_options = call_args.kwargs["model_options"] + + # tool_choice should be passed through as-is + assert "tool_choice" in model_options + assert model_options["tool_choice"] == "auto" + + @pytest.mark.asyncio + async def test_tool_calls_with_complex_arguments( + self, mock_module, sample_tool_request + ): + """Test tool calls with complex nested arguments.""" + mock_output = ModelOutputThunk("Processing complex request.") + mock_tool = MockTool() + mock_output.tool_calls = { + "complex_tool": ModelToolCall( + name="complex_function", + func=mock_tool, + args={ + "location": "Paris", + "options": { + "units": "celsius", + "include_forecast": True, + "days": 5, + }, + "tags": ["weather", "forecast"], + }, + ) + } + mock_module.serve.return_value = mock_output + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(sample_tool_request) + + # Verify complex arguments are properly serialized + tool_call = response.choices[0].message.tool_calls[0] + args = json.loads(tool_call.function.arguments) + + assert args["location"] == "Paris" + assert args["options"]["units"] == "celsius" + assert args["options"]["include_forecast"] is True + assert args["options"]["days"] == 5 + assert args["tags"] == ["weather", "forecast"] + + @pytest.mark.asyncio + async def test_tool_calls_with_usage_info(self, mock_module, sample_tool_request): + """Test that usage info is included alongside tool calls.""" + mock_output = ModelOutputThunk("Calling tool.") + mock_tool = MockTool() + mock_output.tool_calls = { + "get_weather": ModelToolCall( + name="get_weather", func=mock_tool, args={"location": "Paris"} + ) + } + mock_output.usage = { + "prompt_tokens": 50, + "completion_tokens": 20, + "total_tokens": 70, + } + mock_module.serve.return_value = mock_output + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(sample_tool_request) + + # Verify both tool calls and usage are present + assert response.choices[0].finish_reason == "tool_calls" + assert response.choices[0].message.tool_calls is not None + assert response.usage is not None + assert response.usage.total_tokens == 70 + + @pytest.mark.asyncio + async def test_request_without_tools(self, mock_module): + """Test that requests without tools still work normally.""" + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Hello")], + # No tools specified + ) + + mock_output = ModelOutputThunk("Hello! How can I help?") + mock_module.serve.return_value = mock_output + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + # Should work normally without tool-related fields + assert isinstance(response, ChatCompletion) + assert response.choices[0].finish_reason == "stop" + assert response.choices[0].message.tool_calls is None + assert response.choices[0].message.content == "Hello! How can I help?" From 82a8fb7c9de8f1c663add386f8fd74ec93b75f8e Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Tue, 14 Apr 2026 14:08:54 -0700 Subject: [PATCH 02/20] fix: fixed the bug in m serve where finish_reason=tool_calls for empty dict Fixed the bug where an empty tool_calls dict ({}) incorrectly produced finish_reason="tool_calls" with an empty array instead of finish_reason="stop" with tool_calls=None. Signed-off-by: Mark Sturdevant --- cli/serve/app.py | 1 + test/cli/test_serve_tool_calling.py | 22 ++++++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/cli/serve/app.py b/cli/serve/app.py index 3307978c5..20ddad051 100644 --- a/cli/serve/app.py +++ b/cli/serve/app.py @@ -184,6 +184,7 @@ async def endpoint(request: ChatCompletionRequest): hasattr(output, "tool_calls") and output.tool_calls is not None and isinstance(output.tool_calls, dict) + and output.tool_calls # Check dict is not empty ): tool_calls = [] for tool_name, model_tool_call in output.tool_calls.items(): diff --git a/test/cli/test_serve_tool_calling.py b/test/cli/test_serve_tool_calling.py index ffdb8f3e3..0b31b7837 100644 --- a/test/cli/test_serve_tool_calling.py +++ b/test/cli/test_serve_tool_calling.py @@ -163,6 +163,28 @@ async def test_no_tool_calls_finish_reason_stop( assert response.choices[0].finish_reason == "stop" assert response.choices[0].message.tool_calls is None + @pytest.mark.asyncio + async def test_empty_tool_calls_dict_finish_reason_stop( + self, mock_module, sample_tool_request + ): + """Test that finish_reason is 'stop' when tool_calls is an empty dict. + + Regression test for bug where empty tool_calls dict {} produces + finish_reason='tool_calls' with an empty array instead of + finish_reason='stop' with tool_calls=None. + """ + mock_output = ModelOutputThunk("Hello! How can I help?") + # Set tool_calls to empty dict (the bug case) + mock_output.tool_calls = {} + mock_module.serve.return_value = mock_output + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(sample_tool_request) + + # Should behave like no tool calls at all + assert response.choices[0].finish_reason == "stop" + assert response.choices[0].message.tool_calls is None + @pytest.mark.asyncio async def test_tools_passed_to_model_options( self, mock_module, sample_tool_request From d897e4151e0b8ba9a47aa24b7e8a5802caae27fd Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Tue, 14 Apr 2026 14:12:15 -0700 Subject: [PATCH 03/20] fix: move message add to outside the loop in client_tool_calling.py example Issue: The assistant message was being added inside the loop for each tool call, causing duplication when multiple tool calls were present. Fix: Moved the assistant message append outside the loop (before processing tool calls), so it's only added once. Now the loop only adds tool responses. Signed-off-by: Mark Sturdevant --- docs/examples/m_serve/client_tool_calling.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/docs/examples/m_serve/client_tool_calling.py b/docs/examples/m_serve/client_tool_calling.py index fca522c76..6ffc3cc24 100644 --- a/docs/examples/m_serve/client_tool_calling.py +++ b/docs/examples/m_serve/client_tool_calling.py @@ -159,6 +159,17 @@ def main(): if assistant_message.get("tool_calls"): print("\nAssistant requested tool calls:") + + # Add assistant message once before processing tool calls + messages.append( + { + "role": "assistant", + "content": assistant_message.get("content"), + "tool_calls": assistant_message["tool_calls"], + } + ) + + # Process each tool call and add tool responses for tool_call in assistant_message["tool_calls"]: func = tool_call["function"] args = json.loads(func["arguments"]) @@ -171,13 +182,6 @@ def main(): tool_result = "Tool result" # Add tool response to conversation - messages.append( - { - "role": "assistant", - "content": assistant_message.get("content"), - "tool_calls": assistant_message["tool_calls"], - } - ) messages.append( { "role": "tool", From 128a9c6fab0e1dc67979b2e4a51bef045a5405e8 Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Tue, 14 Apr 2026 14:18:51 -0700 Subject: [PATCH 04/20] fix: cli app.py loop variable tool_name is never used MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The dict key tool_name is never used — the function name comes from model_tool_call.name. Using .values() instead. Signed-off-by: Mark Sturdevant --- cli/serve/app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cli/serve/app.py b/cli/serve/app.py index 20ddad051..cb37d93ce 100644 --- a/cli/serve/app.py +++ b/cli/serve/app.py @@ -187,7 +187,7 @@ async def endpoint(request: ChatCompletionRequest): and output.tool_calls # Check dict is not empty ): tool_calls = [] - for tool_name, model_tool_call in output.tool_calls.items(): + for model_tool_call in output.tool_calls.values(): # Generate a unique ID for this tool call tool_call_id = f"call_{uuid.uuid4().hex[:24]}" From 0e23d929425ca97bd74a5d45f56b8f4c9135e1f6 Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Tue, 14 Apr 2026 14:34:29 -0700 Subject: [PATCH 05/20] fix: fix test_mot_init_typing() hasattr was always true Replaced hasattr() with direct __dict__ membership tests to correctly distinguish: 1. Typed instances (ModelOutputThunk[float](...)) - have __orig_class__ in their instance dict 2. Untyped instances (ModelOutputThunk(...)) - do NOT have __orig_class__ in their instance dict Signed-off-by: Mark Sturdevant --- test/core/test_component_typing.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/core/test_component_typing.py b/test/core/test_component_typing.py index f8d3d411d..829fb9eaa 100644 --- a/test/core/test_component_typing.py +++ b/test/core/test_component_typing.py @@ -78,16 +78,16 @@ def session(backend) -> MelleaSession: def test_mot_init_typing(): mot = ModelOutputThunk[float](value="1") - assert hasattr(mot, "__orig_class__"), ( - "mots are generics and should have this field" + assert "__orig_class__" in mot.__dict__, ( + "mots are generics and should have this field in instance dict" ) assert get_args(mot.__orig_class__)[0] is float, ( # type: ignore f"expected float, got {get_args(mot.__orig_class__)[0]} as mot type" # type: ignore ) # type: ignore unknown_mot = ModelOutputThunk(value="2") - assert not hasattr(unknown_mot, "__orig_class__"), ( - "unknown mots / mots with no type defined at instantiate don't have this attribute" + assert "__orig_class__" not in unknown_mot.__dict__, ( + "unknown mots / mots with no type defined at instantiate don't have this attribute in instance dict" ) From 0f894d82f3eff00ba013ebf44ef03e3e42134293 Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Tue, 14 Apr 2026 14:42:14 -0700 Subject: [PATCH 06/20] fix: update m_serve_example_tool_calling.py to use safer example tool MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Security issue resolved in `m_serve_example_tool_calling.py`: **Changes made:** - Replaced `CalculatorTool` (which used unsafe `eval()` with `# noqa: S307`) with `GetStockPriceTool` - New tool demonstrates API-calling pattern with mock stock prices (AAPL, GOOGL, MSFT, TSLA) - Updated all references: `calculator_tool` → `stock_price_tool` - Maintains the same tool calling demonstration with two tools (weather + stock price) **Why this is better:** - Eliminates security risk entirely (no `eval()` or suppressed lints) - Still demonstrates multiple tools effectively - Uses safe, realistic API-calling pattern that users can copy - No dangerous code that could be copy-pasted into production Signed-off-by: Mark Sturdevant --- .../m_serve/m_serve_example_tool_calling.py | 43 ++++++++++--------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/docs/examples/m_serve/m_serve_example_tool_calling.py b/docs/examples/m_serve/m_serve_example_tool_calling.py index e7dbedd29..1aa041eee 100644 --- a/docs/examples/m_serve/m_serve_example_tool_calling.py +++ b/docs/examples/m_serve/m_serve_example_tool_calling.py @@ -63,26 +63,29 @@ def as_json_tool(self) -> dict[str, Any]: } -class CalculatorTool(AbstractMelleaTool): - """Tool for performing calculations.""" +class GetStockPriceTool(AbstractMelleaTool): + """Tool for getting stock price information.""" - name = "calculator" + name = "get_stock_price" - def run(self, expression: str) -> str: - """Evaluate a mathematical expression. + def run(self, symbol: str) -> str: + """Get the current stock price for a symbol. Args: - expression: A mathematical expression to evaluate + symbol: The stock ticker symbol (e.g., AAPL, GOOGL) Returns: - The result of the calculation + Stock price information as a string """ - try: - # In a real implementation, use a safe expression evaluator - result = eval(expression) # noqa: S307 - return f"The result is {result}" - except Exception as e: - return f"Error evaluating expression: {e}" + # In a real implementation, this would call a stock market API + mock_prices = { + "AAPL": "$175.43", + "GOOGL": "$142.87", + "MSFT": "$378.91", + "TSLA": "$242.15", + } + price = mock_prices.get(symbol.upper(), "$100.00") + return f"The current price of {symbol.upper()} is {price}" @property def as_json_tool(self) -> dict[str, Any]: @@ -91,16 +94,16 @@ def as_json_tool(self) -> dict[str, Any]: "type": "function", "function": { "name": self.name, - "description": "Evaluate a mathematical expression", + "description": "Get the current stock price for a given ticker symbol", "parameters": { "type": "object", "properties": { - "expression": { + "symbol": { "type": "string", - "description": "The mathematical expression to evaluate, e.g. '2 + 2'", + "description": "The stock ticker symbol, e.g. AAPL, GOOGL", } }, - "required": ["expression"], + "required": ["symbol"], }, }, } @@ -108,10 +111,10 @@ def as_json_tool(self) -> dict[str, Any]: # Create tool instances weather_tool = GetWeatherTool() -calculator_tool = CalculatorTool() +stock_price_tool = GetStockPriceTool() # Map tool names to instances for easy lookup -TOOLS = {weather_tool.name: weather_tool, calculator_tool.name: calculator_tool} +TOOLS = {weather_tool.name: weather_tool, stock_price_tool.name: stock_price_tool} def serve( @@ -167,7 +170,7 @@ def serve( # Simulate tool definitions being passed test_model_options = { - "@@@tools@@@": [weather_tool.as_json_tool, calculator_tool.as_json_tool] + "@@@tools@@@": [weather_tool.as_json_tool, stock_price_tool.as_json_tool] } response = serve(input=test_messages, model_options=test_model_options) From 7de99e46273375a3fee735f350dba7ea38157125 Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Thu, 16 Apr 2026 19:25:55 -0700 Subject: [PATCH 07/20] fix: replace repeated hard-coded string with constant Signed-off-by: Mark Sturdevant --- docs/examples/m_serve/m_serve_example_tool_calling.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/docs/examples/m_serve/m_serve_example_tool_calling.py b/docs/examples/m_serve/m_serve_example_tool_calling.py index 1aa041eee..30f6ee12f 100644 --- a/docs/examples/m_serve/m_serve_example_tool_calling.py +++ b/docs/examples/m_serve/m_serve_example_tool_calling.py @@ -11,6 +11,7 @@ import mellea from cli.serve.models import ChatMessage +from mellea.backends import ModelOption from mellea.core import ModelOutputThunk, Requirement from mellea.core.base import AbstractMelleaTool from mellea.stdlib.context import ChatContext @@ -140,9 +141,9 @@ def serve( # Extract tools from model_options if provided tools = None - if model_options and "@@@tools@@@" in model_options: + if model_options and ModelOption.TOOLS in model_options: # Convert OpenAI tool format to Mellea tool format - openai_tools = model_options["@@@tools@@@"] + openai_tools = model_options[ModelOption.TOOLS] tools = {} for tool_def in openai_tools: tool_name = tool_def["function"]["name"] @@ -152,7 +153,7 @@ def serve( # Build model options with tools final_model_options = model_options or {} if tools: - final_model_options["@@@tools@@@"] = tools + final_model_options[ModelOption.TOOLS] = tools # Use instruct to generate response with potential tool calls result = session.instruct( @@ -170,7 +171,7 @@ def serve( # Simulate tool definitions being passed test_model_options = { - "@@@tools@@@": [weather_tool.as_json_tool, stock_price_tool.as_json_tool] + ModelOption.TOOLS: [weather_tool.as_json_tool, stock_price_tool.as_json_tool] } response = serve(input=test_messages, model_options=test_model_options) From 98ede722c4493d8c495177f9968e6db505d77e46 Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Thu, 16 Apr 2026 19:36:12 -0700 Subject: [PATCH 08/20] fix: add TOOL_CHOICE to ModelOptions like TEMPERATURE not a sentinel The pass-thru behavior was not clear enough, so adding it to ModelOptions where important options are known. Most of these are sentinels which are removed (because @@@) but this will be like TEMPERATURE which is passed through to the backends. No behavior change, but give a handly constant and a place to look for these. This does not address all the other possible pass through args. Signed-off-by: Mark Sturdevant Assisted-by: IBM Bob --- cli/serve/app.py | 2 +- docs/examples/m_serve/m_serve_example_tool_calling.py | 10 ++++++---- mellea/backends/model_options.py | 4 ++++ test/cli/test_serve.py | 6 +++--- test/cli/test_serve_tool_calling.py | 7 ++++--- 5 files changed, 18 insertions(+), 11 deletions(-) diff --git a/cli/serve/app.py b/cli/serve/app.py index cb37d93ce..83676fedd 100644 --- a/cli/serve/app.py +++ b/cli/serve/app.py @@ -115,7 +115,6 @@ def _build_model_options(request: ChatCompletionRequest) -> dict: "response_format", # Response format (json_object) - not yet implemented "functions", # Legacy function calling - not yet implemented "function_call", # Legacy function calling - not yet implemented - # Tool choice is passed through as-is (not a ModelOption sentinel) } openai_to_model_option = { "temperature": ModelOption.TEMPERATURE, @@ -123,6 +122,7 @@ def _build_model_options(request: ChatCompletionRequest) -> dict: "seed": ModelOption.SEED, "stream": ModelOption.STREAM, "tools": ModelOption.TOOLS, + "tool_choice": ModelOption.TOOL_CHOICE, } # Get all non-None fields diff --git a/docs/examples/m_serve/m_serve_example_tool_calling.py b/docs/examples/m_serve/m_serve_example_tool_calling.py index 30f6ee12f..944ed56bf 100644 --- a/docs/examples/m_serve/m_serve_example_tool_calling.py +++ b/docs/examples/m_serve/m_serve_example_tool_calling.py @@ -126,12 +126,13 @@ def serve( """Serve function that handles tool calling. This function demonstrates how to use tools with m serve. The tools - are passed via model_options and the model can request to call them. + are passed via model_options using ModelOption.TOOLS, and tool_choice + can be specified using ModelOption.TOOL_CHOICE. Args: input: List of chat messages requirements: Optional list of requirement strings - model_options: Model options including tools and tool_choice + model_options: Model options including ModelOption.TOOLS and ModelOption.TOOL_CHOICE Returns: ModelOutputThunk with potential tool calls @@ -169,9 +170,10 @@ def serve( # Example usage (for testing purposes) test_messages = [ChatMessage(role="user", content="What's the weather in Paris?")] - # Simulate tool definitions being passed + # Simulate tool definitions being passed with tool_choice test_model_options = { - ModelOption.TOOLS: [weather_tool.as_json_tool, stock_price_tool.as_json_tool] + ModelOption.TOOLS: [weather_tool.as_json_tool, stock_price_tool.as_json_tool], + ModelOption.TOOL_CHOICE: "auto", # Can be "none", "auto", or specific tool } response = serve(input=test_messages, model_options=test_model_options) diff --git a/mellea/backends/model_options.py b/mellea/backends/model_options.py index decc8c34b..a03e8625c 100644 --- a/mellea/backends/model_options.py +++ b/mellea/backends/model_options.py @@ -22,6 +22,7 @@ class ModelOption: Attributes: TOOLS (str): Sentinel key for a list or dict of ``MelleaTool`` instances to expose for tool calling. + TOOL_CHOICE (str): Key for tool choice strategy (passed through to the backend). MAX_NEW_TOKENS (str): Sentinel key for the maximum number of new tokens to generate. SYSTEM_PROMPT (str): Sentinel key for the system prompt string. TEMPERATURE (str): Key for the sampling temperature (passed through to the backend). @@ -34,6 +35,9 @@ class ModelOption: TOOLS = "@@@tools@@@" """Must be a list[MelleaTool] or a dict[str, MelleaTool]. Use ``MelleaTool.from_callable()`` or the ``@tool`` decorator to wrap plain callables.""" + TOOL_CHOICE = "tool_choice" + """Controls which tool the model should use. Can be "none", "auto", or a specific tool name.""" + MAX_NEW_TOKENS = "@@@max_new_tokens@@@" SYSTEM_PROMPT = "@@@system_prompt@@@" TEMPERATURE = "temperature" diff --git a/test/cli/test_serve.py b/test/cli/test_serve.py index e20deb6e3..6aae713dd 100644 --- a/test/cli/test_serve.py +++ b/test/cli/test_serve.py @@ -501,9 +501,9 @@ async def test_tool_params_passed_to_model_options(self, mock_module): # Tools should be passed with ModelOption.TOOLS key assert ModelOption.TOOLS in model_options - # tool_choice should be passed through as-is - assert "tool_choice" in model_options - assert model_options["tool_choice"] == "auto" + # tool_choice should be passed through using ModelOption.TOOL_CHOICE + assert ModelOption.TOOL_CHOICE in model_options + assert model_options[ModelOption.TOOL_CHOICE] == "auto" # Legacy function calling parameters should still be excluded assert "functions" not in model_options assert "function_call" not in model_options diff --git a/test/cli/test_serve_tool_calling.py b/test/cli/test_serve_tool_calling.py index 0b31b7837..d1e76cd09 100644 --- a/test/cli/test_serve_tool_calling.py +++ b/test/cli/test_serve_tool_calling.py @@ -15,6 +15,7 @@ FunctionParameters, ToolFunction, ) +from mellea.backends import ModelOption from mellea.core.base import AbstractMelleaTool, ModelOutputThunk, ModelToolCall @@ -223,9 +224,9 @@ async def test_tool_choice_passed_to_model_options( assert call_args is not None model_options = call_args.kwargs["model_options"] - # tool_choice should be passed through as-is - assert "tool_choice" in model_options - assert model_options["tool_choice"] == "auto" + # tool_choice should be passed through using ModelOption.TOOL_CHOICE + assert ModelOption.TOOL_CHOICE in model_options + assert model_options[ModelOption.TOOL_CHOICE] == "auto" @pytest.mark.asyncio async def test_tool_calls_with_complex_arguments( From 6a812f23ca90460316cd753126d4301d3967c7ba Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Fri, 17 Apr 2026 11:41:12 -0700 Subject: [PATCH 09/20] fix: fix m serve tool-calling examples - switch server example to OpenAIBackend - align tool-calling example with tested Granite model setup - narrow advertised tools when `tool_choice` selects a specific function - enable `tool_calls=True` in the serve path - replace calculator example with stock-price tool - examples 1/2 as tool-call-only demos - example 4 as the full tool execution round-trip - improve client diagnostics for empty/no-tool responses Signed-off-by: Mark Sturdevant Assisted-by: IBM Bob --- docs/examples/m_serve/client_tool_calling.py | 161 ++++++++++++----- .../m_serve/m_serve_example_tool_calling.py | 168 ++++++++++++++---- 2 files changed, 249 insertions(+), 80 deletions(-) diff --git a/docs/examples/m_serve/client_tool_calling.py b/docs/examples/m_serve/client_tool_calling.py index 6ffc3cc24..d68e5d238 100644 --- a/docs/examples/m_serve/client_tool_calling.py +++ b/docs/examples/m_serve/client_tool_calling.py @@ -27,48 +27,55 @@ "name": "get_weather", "description": "Get the current weather in a given location", "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city name, e.g. San Francisco", + "RootModel": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city name, e.g. San Francisco", + }, + "units": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "Temperature units", + }, }, - "units": { - "type": "string", - "enum": ["celsius", "fahrenheit"], - "description": "Temperature units", - }, - }, - "required": ["location"], + "required": ["location"], + } }, }, }, { "type": "function", "function": { - "name": "calculator", - "description": "Evaluate a mathematical expression", + "name": "get_stock_price", + "description": "Get the current stock price for a given ticker symbol", "parameters": { - "type": "object", - "properties": { - "expression": { - "type": "string", - "description": "The mathematical expression to evaluate", - } - }, - "required": ["expression"], + "RootModel": { + "type": "object", + "properties": { + "symbol": { + "type": "string", + "description": "The stock ticker symbol, e.g. AAPL, GOOGL", + } + }, + "required": ["symbol"], + } }, }, }, ] -def make_request(messages: list[dict], tools: list[dict] | None = None) -> dict: +def make_request( + messages: list[dict], tools: list[dict] | None = None, tool_name: str | None = None +) -> dict: """Make a request to the m serve API. Args: messages: List of message dictionaries tools: Optional list of tool definitions + tool_name: Optional tool name to request explicitly Returns: Response dictionary from the API @@ -81,13 +88,53 @@ def make_request(messages: list[dict], tools: list[dict] | None = None) -> dict: if tools: payload["tools"] = tools - payload["tool_choice"] = "auto" + if tool_name is not None: + # m serve forwards tool_choice to compatible backends, but the + # downstream provider/model may ignore it or treat it as a weak + # preference rather than a guarantee. Use an explicit function + # selection in this client so the example demonstrates the API + # contract even when the model would otherwise decline to call tools. + payload["tool_choice"] = { + "type": "function", + "function": {"name": tool_name}, + } + else: + payload["tool_choice"] = "auto" response = requests.post(ENDPOINT, json=payload, timeout=30) - response.raise_for_status() + + if response.status_code >= 400: + try: + error_payload = response.json() + except ValueError: + error_payload = {"error": {"message": response.text}} + + error_message = error_payload.get("error", {}).get("message", response.text) + raise requests.HTTPError( + f"{response.status_code} Server Error: {error_message}", response=response + ) + return response.json() +def _run_local_tool(tool_name: str, args: dict) -> str: + """Simulate local execution of the example tools.""" + if tool_name == "get_weather": + units = args.get("units") or "celsius" + unit_suffix = "C" if units == "celsius" else "F" + return f"The weather in {args['location']} is sunny and 22°{unit_suffix}" + if tool_name == "get_stock_price": + mock_prices = { + "AAPL": "$175.43", + "GOOGL": "$142.87", + "MSFT": "$378.91", + "TSLA": "$242.15", + } + symbol = args["symbol"].upper() + return f"The current price of {symbol} is {mock_prices.get(symbol, '$100.00')}" + return "Tool result" + + def main(): """Run example tool calling interactions.""" print("=" * 60) @@ -100,7 +147,7 @@ def main(): messages = [{"role": "user", "content": "What's the weather like in Tokyo?"}] print(f"User: {messages[0]['content']}") - response = make_request(messages, tools=tools) + response = make_request(messages, tools=tools, tool_name="get_weather") choice = response["choices"][0] print(f"\nFinish Reason: {choice['finish_reason']}") @@ -111,16 +158,18 @@ def main(): func = tool_call["function"] args = json.loads(func["arguments"]) print(f" - {func['name']}({json.dumps(args)})") - else: + elif choice.get("message", {}).get("content"): print(f"Assistant: {choice['message']['content']}") + else: + print("Assistant returned no content and no tool calls.") - # Example 2: Request that should trigger calculator tool - print("\n\n2. Math Query") + # Example 2: Request that should trigger stock price tool + print("\n\n2. Stock Price Query") print("-" * 60) - messages = [{"role": "user", "content": "What is 15 * 23 + 7?"}] + messages = [{"role": "user", "content": "What's the current stock price of AAPL?"}] print(f"User: {messages[0]['content']}") - response = make_request(messages, tools=tools) + response = make_request(messages, tools=tools, tool_name="get_stock_price") choice = response["choices"][0] print(f"\nFinish Reason: {choice['finish_reason']}") @@ -131,8 +180,10 @@ def main(): func = tool_call["function"] args = json.loads(func["arguments"]) print(f" - {func['name']}({json.dumps(args)})") - else: + elif choice.get("message", {}).get("content"): print(f"Assistant: {choice['message']['content']}") + else: + print("Assistant returned no content and no tool calls.") # Example 3: Request without tools (normal chat) print("\n\n3. Normal Chat (No Tools)") @@ -152,7 +203,7 @@ def main(): messages = [{"role": "user", "content": "What's the weather in Paris?"}] print(f"User: {messages[0]['content']}") - response = make_request(messages, tools=tools) + response = make_request(messages, tools=tools, tool_name="get_weather") choice = response["choices"][0] assistant_message = choice["message"] @@ -169,17 +220,17 @@ def main(): } ) + tool_results: list[str] = [] + # Process each tool call and add tool responses for tool_call in assistant_message["tool_calls"]: func = tool_call["function"] args = json.loads(func["arguments"]) print(f" - {func['name']}({json.dumps(args)})") - # Simulate tool execution - if func["name"] == "get_weather": - tool_result = f"The weather in {args['location']} is sunny and 22°C" - else: - tool_result = "Tool result" + tool_result = _run_local_tool(func["name"], args) + tool_results.append(tool_result) + print(f" Result: {tool_result}") # Add tool response to conversation messages.append( @@ -190,11 +241,32 @@ def main(): } ) - # Get final response after tool execution + # Get final response after tool execution. + # Ask for a concise answer that explicitly uses the tool result so the + # example output includes the actual weather/price instead of only a + # conversational acknowledgement. + messages.append( + { + "role": "user", + "content": ( + f"Original question: {messages[0]['content']}\n" + f"Tool result: {'; '.join(tool_results)}\n" + "Answer the original question directly using only that tool " + "result. Do not mention unrelated topics or other tools." + ), + } + ) print("\nGetting final response after tool execution...") - response = make_request(messages, tools=tools) + response = make_request(messages, tools=None) choice = response["choices"][0] - print(f"Assistant: {choice['message']['content']}") + if choice.get("message", {}).get("content"): + print(f"Assistant: {choice['message']['content']}") + else: + print("Assistant returned no content after tool execution.") + elif assistant_message.get("content"): + print(f"Assistant: {assistant_message['content']}") + else: + print("Assistant returned no content and no tool calls.") print("\n" + "=" * 60) print("Examples completed!") @@ -208,5 +280,12 @@ def main(): print("Error: Could not connect to server.") print("Make sure the server is running:") print(" uv run m serve docs/examples/m_serve/m_serve_example_tool_calling.py") + except requests.exceptions.HTTPError as e: + print(f"Error: {e}") + if e.response is not None: + try: + print("Server response:", json.dumps(e.response.json(), indent=2)) + except ValueError: + print("Server response:", e.response.text) except Exception as e: print(f"Error: {e}") diff --git a/docs/examples/m_serve/m_serve_example_tool_calling.py b/docs/examples/m_serve/m_serve_example_tool_calling.py index 944ed56bf..aac3e68c1 100644 --- a/docs/examples/m_serve/m_serve_example_tool_calling.py +++ b/docs/examples/m_serve/m_serve_example_tool_calling.py @@ -2,21 +2,45 @@ """Example demonstrating tool calling with m serve. -This example shows how to use the OpenAI-compatible tool calling API -with m serve. The server will accept tool definitions and return tool -calls in the response when the model decides to use them. +This file supports two distinct usage patterns: + +1. Running it directly with ``uv run python ...`` performs a local smoke test + using native Mellea tool calling. +2. Serving it with ``m serve`` exposes an OpenAI-compatible endpoint that + accepts OpenAI-style tool definitions in the request. + +The direct ``__main__`` smoke test is intentionally separate from the +OpenAI-compatible server flow because local ``session.instruct(...)`` calls +should use ``MelleaTool`` objects directly. """ +import os from typing import Any import mellea from cli.serve.models import ChatMessage from mellea.backends import ModelOption +from mellea.backends.model_ids import IBM_GRANITE_4_HYBRID_MICRO +from mellea.backends.openai import OpenAIBackend +from mellea.backends.tools import MelleaTool from mellea.core import ModelOutputThunk, Requirement from mellea.core.base import AbstractMelleaTool +from mellea.formatters import TemplateFormatter from mellea.stdlib.context import ChatContext +from mellea.stdlib.requirements.tool_reqs import uses_tool +from mellea.stdlib.session import MelleaSession + +_ollama_host = os.environ.get("OLLAMA_HOST", "localhost:11434") +if not _ollama_host.startswith(("http://", "https://")): + _ollama_host = f"http://{_ollama_host}" -session = mellea.start_session(ctx=ChatContext()) +backend = OpenAIBackend( + model_id=IBM_GRANITE_4_HYBRID_MICRO.ollama_name, # type: ignore[arg-type] + formatter=TemplateFormatter(model_id=IBM_GRANITE_4_HYBRID_MICRO.hf_model_name), # type: ignore[arg-type] + base_url=f"{_ollama_host}/v1", + api_key="ollama", +) +session = MelleaSession(backend, ctx=ChatContext()) class GetWeatherTool(AbstractMelleaTool): @@ -24,7 +48,7 @@ class GetWeatherTool(AbstractMelleaTool): name = "get_weather" - def run(self, location: str, units: str = "celsius") -> str: + def run(self, location: str, units: str | None = "celsius") -> str: """Get the current weather for a location. Args: @@ -34,8 +58,10 @@ def run(self, location: str, units: str = "celsius") -> str: Returns: Weather information as a string """ + # Models sometimes emit optional arguments explicitly as null/None. + resolved_units = units or "celsius" # In a real implementation, this would call a weather API - return f"The weather in {location} is sunny and 22°{units[0].upper()}" + return f"The weather in {location} is sunny and 22°{resolved_units[0].upper()}" @property def as_json_tool(self) -> dict[str, Any]: @@ -110,12 +136,56 @@ def as_json_tool(self) -> dict[str, Any]: } -# Create tool instances -weather_tool = GetWeatherTool() -stock_price_tool = GetStockPriceTool() +# Create tool instances for server-side lookup +weather_tool_impl = GetWeatherTool() +stock_price_tool_impl = GetStockPriceTool() + +# Native MelleaTool wrappers are only needed for the direct ``__main__`` path. +# The backend helper used by local ``session.instruct(..., ModelOption.TOOLS=[...])`` +# expects ``MelleaTool`` instances in a list, while the server path below uses the +# class-based implementations via the ``TOOLS`` lookup. +weather_tool = MelleaTool( + name=weather_tool_impl.name, + tool_call=weather_tool_impl.run, + as_json_tool=weather_tool_impl.as_json_tool, +) +stock_price_tool = MelleaTool( + name=stock_price_tool_impl.name, + tool_call=stock_price_tool_impl.run, + as_json_tool=stock_price_tool_impl.as_json_tool, +) + +# Map tool names to server-side tool implementations for easy lookup +TOOLS = { + weather_tool_impl.name: weather_tool_impl, + stock_price_tool_impl.name: stock_price_tool_impl, +} + + +def _extract_mellea_tools_from_model_options( + model_options: dict | None, +) -> dict[str, AbstractMelleaTool]: + """Normalize example tool inputs to native tool instances. + + This example supports only two shapes: + - OpenAI-style JSON tool definitions from the server path + - native tool objects from the direct ``__main__`` path + """ + if model_options is None or ModelOption.TOOLS not in model_options: + return {} + + provided_tools = model_options[ModelOption.TOOLS] + tools: dict[str, AbstractMelleaTool] = {} + + for tool_def in provided_tools: + if isinstance(tool_def, AbstractMelleaTool): + tools[tool_def.name] = tool_def + else: + tool_name = tool_def["function"]["name"] + if tool_name in TOOLS: + tools[tool_name] = TOOLS[tool_name] -# Map tool names to instances for easy lookup -TOOLS = {weather_tool.name: weather_tool, stock_price_tool.name: stock_price_tool} + return tools def serve( @@ -127,7 +197,9 @@ def serve( This function demonstrates how to use tools with m serve. The tools are passed via model_options using ModelOption.TOOLS, and tool_choice - can be specified using ModelOption.TOOL_CHOICE. + can be specified using ModelOption.TOOL_CHOICE. Mellea forwards that + setting to compatible backends, but the downstream provider/model may + still ignore it or treat it as a weak preference. Args: input: List of chat messages @@ -141,43 +213,61 @@ def serve( message = input[-1].content # Extract tools from model_options if provided - tools = None - if model_options and ModelOption.TOOLS in model_options: - # Convert OpenAI tool format to Mellea tool format - openai_tools = model_options[ModelOption.TOOLS] - tools = {} - for tool_def in openai_tools: - tool_name = tool_def["function"]["name"] - if tool_name in TOOLS: - tools[tool_name] = TOOLS[tool_name] - - # Build model options with tools - final_model_options = model_options or {} + tools = _extract_mellea_tools_from_model_options(model_options) + + # Build model options with tools. + # If the caller explicitly selected a single function via tool_choice, + # narrow the advertised tool set to that one tool so the backend/model + # is not asked to choose among unrelated tools. + final_model_options = dict(model_options or {}) + selected_tool_name: str | None = None if tools: - final_model_options[ModelOption.TOOLS] = tools - - # Use instruct to generate response with potential tool calls + selected_tools = tools + if model_options is not None and ModelOption.TOOL_CHOICE in model_options: + tool_choice = model_options[ModelOption.TOOL_CHOICE] + if isinstance(tool_choice, dict): + selected_tool_name = tool_choice.get("function", {}).get("name") + if selected_tool_name in tools: + selected_tools = {selected_tool_name: tools[selected_tool_name]} + final_model_options[ModelOption.TOOLS] = selected_tools + + # Keep the serve path deterministic for the client example by retrying only + # at the request level. Enforcing uses_tool(...) inside session.instruct() + # caused noisy server-side failures when the model ignored the tool request + # on a particular sample. result = session.instruct( description=message, # type: ignore requirements=[Requirement(req) for req in requirements], # type: ignore model_options=final_model_options, + tool_calls=True, + strategy=None, ) return result if __name__ == "__main__": - # Example usage (for testing purposes) - test_messages = [ChatMessage(role="user", content="What's the weather in Paris?")] - - # Simulate tool definitions being passed with tool_choice - test_model_options = { - ModelOption.TOOLS: [weather_tool.as_json_tool, stock_price_tool.as_json_tool], - ModelOption.TOOL_CHOICE: "auto", # Can be "none", "auto", or specific tool - } - - response = serve(input=test_messages, model_options=test_model_options) + response = session.instruct( + "What's the weather in Boston?", + model_options={ + ModelOption.TOOLS: [weather_tool], + # This direct path now uses the OpenAI backend against Ollama's + # OpenAI-compatible endpoint, so TOOL_CHOICE is forwarded by + # Mellea. Ollama and/or the selected model may still ignore it + # or not enforce it strictly in practice. + ModelOption.TOOL_CHOICE: "auto", + ModelOption.MAX_NEW_TOKENS: 1000, + }, + strategy=None, + tool_calls=True, + ) print(f"Response: {response.value}") - if response.tool_calls: - print(f"Tool calls requested: {list(response.tool_calls.keys())}") + print( + "Tool calls requested:", + None if response.tool_calls is None else list(response.tool_calls.keys()), + ) + + if response.tool_calls and weather_tool.name in response.tool_calls: + tool_result = response.tool_calls[weather_tool.name].call_func() + print(f"Tool result: {tool_result}") From 8bd74fdaa91ba052c9fd1b57f59b4046ce34b0e5 Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Fri, 17 Apr 2026 13:32:39 -0700 Subject: [PATCH 10/20] fix: remove unused imports in example Signed-off-by: Mark Sturdevant --- docs/examples/m_serve/m_serve_example_tool_calling.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/examples/m_serve/m_serve_example_tool_calling.py b/docs/examples/m_serve/m_serve_example_tool_calling.py index aac3e68c1..839c91b1b 100644 --- a/docs/examples/m_serve/m_serve_example_tool_calling.py +++ b/docs/examples/m_serve/m_serve_example_tool_calling.py @@ -17,7 +17,6 @@ import os from typing import Any -import mellea from cli.serve.models import ChatMessage from mellea.backends import ModelOption from mellea.backends.model_ids import IBM_GRANITE_4_HYBRID_MICRO @@ -27,7 +26,6 @@ from mellea.core.base import AbstractMelleaTool from mellea.formatters import TemplateFormatter from mellea.stdlib.context import ChatContext -from mellea.stdlib.requirements.tool_reqs import uses_tool from mellea.stdlib.session import MelleaSession _ollama_host = os.environ.get("OLLAMA_HOST", "localhost:11434") From 9a82f5f6aea8dd356546872d732182f00bce51e0 Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Thu, 23 Apr 2026 17:35:30 -0700 Subject: [PATCH 11/20] feat: cli support for OpenAI API tool calling with streaming Signed-off-by: Mark Sturdevant Assisted-by: IBM Bob --- cli/serve/app.py | 56 ++- cli/serve/models.py | 3 + cli/serve/streaming.py | 50 ++- .../m_serve/client_streaming_tool_calling.py | 323 ++++++++++++++++++ mellea/helpers/openai_compatible_helpers.py | 56 ++- test/cli/test_serve.py | 87 +++++ 6 files changed, 543 insertions(+), 32 deletions(-) create mode 100644 docs/examples/m_serve/client_streaming_tool_calling.py diff --git a/cli/serve/app.py b/cli/serve/app.py index 83676fedd..f4c2efd50 100644 --- a/cli/serve/app.py +++ b/cli/serve/app.py @@ -3,7 +3,6 @@ import asyncio import importlib.util import inspect -import json import os import sys import time @@ -23,7 +22,10 @@ ) from e from mellea.backends.model_options import ModelOption -from mellea.helpers.openai_compatible_helpers import build_completion_usage +from mellea.helpers.openai_compatible_helpers import ( + build_completion_usage, + build_tool_calls, +) from .models import ( ChatCompletion, @@ -176,34 +178,30 @@ async def endpoint(request: ChatCompletionRequest): ) # Extract tool calls from the ModelOutputThunk if available - tool_calls = None - finish_reason: Literal[ - "stop", "length", "content_filter", "tool_calls", "function_call" - ] = "stop" - if ( - hasattr(output, "tool_calls") - and output.tool_calls is not None - and isinstance(output.tool_calls, dict) - and output.tool_calls # Check dict is not empty - ): - tool_calls = [] - for model_tool_call in output.tool_calls.values(): - # Generate a unique ID for this tool call - tool_call_id = f"call_{uuid.uuid4().hex[:24]}" - - # Serialize the arguments to JSON string - args_json = json.dumps(model_tool_call.args) - - tool_calls.append( - ChatCompletionMessageToolCall( - id=tool_call_id, - type="function", - function=ToolCallFunction( - name=model_tool_call.name, arguments=args_json - ), - ) + tool_calls_list = build_tool_calls(output) + tool_calls = ( + [ + ChatCompletionMessageToolCall( + id=tc["id"], + type=tc["type"], + function=ToolCallFunction( + name=tc["function"]["name"], + arguments=tc["function"]["arguments"], + ), ) - finish_reason = "tool_calls" + for tc in tool_calls_list + ] + if tool_calls_list + else None + ) + + # Determine finish_reason based on tool calls + finish_reason: ( + Literal[ + "stop", "length", "content_filter", "tool_calls", "function_call" + ] + | None + ) = "tool_calls" if tool_calls else "stop" # system_fingerprint represents backend config hash, not model name # The model name is already in response.model (line 73) diff --git a/cli/serve/models.py b/cli/serve/models.py index ba0bd8cca..5e994011f 100644 --- a/cli/serve/models.py +++ b/cli/serve/models.py @@ -170,6 +170,9 @@ class ChatCompletionChunkDelta(BaseModel): refusal: str | None = None """The refusal message fragment, if any.""" + tool_calls: list[ChatCompletionMessageToolCall] | None = None + """The tool calls generated by the model (only in tool call chunks).""" + class ChatCompletionChunkChoice(BaseModel): """A choice in a streaming chunk.""" diff --git a/cli/serve/streaming.py b/cli/serve/streaming.py index 51ff33c3c..e837006cc 100644 --- a/cli/serve/streaming.py +++ b/cli/serve/streaming.py @@ -1,18 +1,24 @@ """Streaming utilities for OpenAI-compatible server responses.""" from collections.abc import AsyncGenerator +from typing import Literal from mellea.core.base import ModelOutputThunk from mellea.core.utils import MelleaLogger -from mellea.helpers.openai_compatible_helpers import build_completion_usage +from mellea.helpers.openai_compatible_helpers import ( + build_completion_usage, + build_tool_calls, +) from .models import ( ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionChunkDelta, + ChatCompletionMessageToolCall, OpenAIError, OpenAIErrorResponse, StreamOptions, + ToolCallFunction, ) @@ -98,6 +104,46 @@ async def stream_chat_completion_chunks( ) yield f"data: {chunk.model_dump_json()}\n\n" + # Extract tool calls from the ModelOutputThunk if available + tool_calls_list = build_tool_calls(output) + + if tool_calls_list: + # Convert to ChatCompletionMessageToolCall objects + tool_calls = [ + ChatCompletionMessageToolCall( + id=tc["id"], + type=tc["type"], + function=ToolCallFunction( + name=tc["function"]["name"], + arguments=tc["function"]["arguments"], + ), + ) + for tc in tool_calls_list + ] + + # Emit tool calls in a separate chunk before the final chunk + tool_call_chunk = ChatCompletionChunk( + id=completion_id, + model=model, + created=created, + choices=[ + ChatCompletionChunkChoice( + index=0, + delta=ChatCompletionChunkDelta(tool_calls=tool_calls), + finish_reason=None, + ) + ], + object="chat.completion.chunk", + system_fingerprint=system_fingerprint, + ) + yield f"data: {tool_call_chunk.model_dump_json()}\n\n" + + # Determine finish_reason based on tool calls + finish_reason: ( + Literal["stop", "length", "content_filter", "tool_calls", "function_call"] + | None + ) = "tool_calls" if tool_calls_list else "stop" + # Include usage in final chunk only if explicitly requested via stream_options # Per OpenAI spec: usage is only included when stream_options.include_usage=True include_usage = stream_options is not None and stream_options.include_usage @@ -112,7 +158,7 @@ async def stream_chat_completion_chunks( ChatCompletionChunkChoice( index=0, delta=ChatCompletionChunkDelta(content=None), - finish_reason="stop", + finish_reason=finish_reason, ) ], object="chat.completion.chunk", diff --git a/docs/examples/m_serve/client_streaming_tool_calling.py b/docs/examples/m_serve/client_streaming_tool_calling.py new file mode 100644 index 000000000..2a406564e --- /dev/null +++ b/docs/examples/m_serve/client_streaming_tool_calling.py @@ -0,0 +1,323 @@ +"""Client example for testing streaming with tool calling. + +This script demonstrates how to use streaming responses with tool calls +from an m serve server using the OpenAI-compatible API. + +Usage: + 1. Start the server: + uv run m serve docs/examples/m_serve/m_serve_example_tool_calling.py + + 2. Run this client: + uv run python docs/examples/m_serve/client_streaming_tool_calling.py +""" + +import json +from typing import Any + +import requests + +# Server configuration +BASE_URL = "http://localhost:8080" +ENDPOINT = f"{BASE_URL}/v1/chat/completions" + +# Define tools in OpenAI format +tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather in a given location", + "parameters": { + "RootModel": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city name, e.g. San Francisco", + }, + "units": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "Temperature units", + }, + }, + "required": ["location"], + } + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_stock_price", + "description": "Get the current stock price for a given ticker symbol", + "parameters": { + "RootModel": { + "type": "object", + "properties": { + "symbol": { + "type": "string", + "description": "The stock ticker symbol, e.g. AAPL, GOOGL", + } + }, + "required": ["symbol"], + } + }, + }, + }, +] + + +def make_streaming_request( + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + tool_name: str | None = None, +) -> tuple[str, list[dict[str, Any]] | None, str]: + """Make a streaming request to the m serve API. + + Args: + messages: List of message dictionaries + tools: Optional list of tool definitions + tool_name: Optional tool name to request explicitly + + Returns: + Tuple of (content, tool_calls, finish_reason) + """ + payload: dict[str, Any] = { + "model": "gpt-3.5-turbo", # Model name (not used by m serve) + "messages": messages, + "temperature": 0.7, + "stream": True, + } + + if tools: + payload["tools"] = tools + if tool_name is not None: + payload["tool_choice"] = { + "type": "function", + "function": {"name": tool_name}, + } + else: + payload["tool_choice"] = "auto" + + response = requests.post(ENDPOINT, json=payload, stream=True, timeout=30) + + if response.status_code >= 400: + try: + error_payload = response.json() + except ValueError: + error_payload = {"error": {"message": response.text}} + + error_message = error_payload.get("error", {}).get("message", response.text) + raise requests.HTTPError( + f"{response.status_code} Server Error: {error_message}", response=response + ) + + content_chunks = [] + tool_calls: list[dict[str, Any]] | None = None + finish_reason = "stop" + + for line in response.iter_lines(): + if line: + line_str = line.decode("utf-8") + if line_str.startswith("data: "): + data_str = line_str[6:] + if data_str == "[DONE]": + break + + chunk = json.loads(data_str) + choice = chunk["choices"][0] + delta = choice.get("delta", {}) + + # Collect content + if delta.get("content"): + content_chunks.append(delta["content"]) + print(delta["content"], end="", flush=True) + + # Collect tool calls + if delta.get("tool_calls"): + tool_calls = delta["tool_calls"] + + # Get finish reason + if choice.get("finish_reason"): + finish_reason = choice["finish_reason"] + + content = "".join(content_chunks) + return content, tool_calls, finish_reason + + +def _run_local_tool(tool_name: str, args: dict) -> str: + """Simulate local execution of the example tools.""" + if tool_name == "get_weather": + units = args.get("units") or "celsius" + unit_suffix = "C" if units == "celsius" else "F" + return f"The weather in {args['location']} is sunny and 22°{unit_suffix}" + if tool_name == "get_stock_price": + mock_prices = { + "AAPL": "$175.43", + "GOOGL": "$142.87", + "MSFT": "$378.91", + "TSLA": "$242.15", + } + symbol = args["symbol"].upper() + return f"The current price of {symbol} is {mock_prices.get(symbol, '$100.00')}" + return "Tool result" + + +def main(): + """Run example streaming tool calling interactions.""" + print("=" * 60) + print("Streaming Tool Calling Example with m serve") + print("=" * 60) + + # Example 1: Request that should trigger weather tool + print("\n1. Weather Query (Streaming)") + print("-" * 60) + messages = [{"role": "user", "content": "What's the weather like in Tokyo?"}] + + print(f"User: {messages[0]['content']}") + print("\nAssistant: ", end="", flush=True) + content, tool_calls, finish_reason = make_streaming_request( + messages, tools=tools, tool_name="get_weather" + ) + + print(f"\n\nFinish Reason: {finish_reason}") + + if tool_calls: + print("\nTool Calls:") + for tool_call in tool_calls: + func = tool_call["function"] + args = json.loads(func["arguments"]) + print(f" - {func['name']}({json.dumps(args)})") + elif content: + print("(Content already displayed above)") + else: + print("Assistant returned no content and no tool calls.") + + # Example 2: Request that should trigger stock price tool + print("\n\n2. Stock Price Query (Streaming)") + print("-" * 60) + messages = [{"role": "user", "content": "What's the current stock price of AAPL?"}] + + print(f"User: {messages[0]['content']}") + print("\nAssistant: ", end="", flush=True) + content, tool_calls, finish_reason = make_streaming_request( + messages, tools=tools, tool_name="get_stock_price" + ) + + print(f"\n\nFinish Reason: {finish_reason}") + + if tool_calls: + print("\nTool Calls:") + for tool_call in tool_calls: + func = tool_call["function"] + args = json.loads(func["arguments"]) + print(f" - {func['name']}({json.dumps(args)})") + elif content: + print("(Content already displayed above)") + else: + print("Assistant returned no content and no tool calls.") + + # Example 3: Request without tools (normal chat) + print("\n\n3. Normal Chat (No Tools, Streaming)") + print("-" * 60) + messages = [{"role": "user", "content": "Hello! How are you?"}] + + print(f"User: {messages[0]['content']}") + print("\nAssistant: ", end="", flush=True) + content, tool_calls, finish_reason = make_streaming_request(messages, tools=None) + + print(f"\n\nFinish Reason: {finish_reason}") + + # Example 4: Multi-turn conversation with tool use + print("\n\n4. Multi-turn Conversation (Streaming)") + print("-" * 60) + messages = [{"role": "user", "content": "What's the weather in Paris?"}] + + print(f"User: {messages[0]['content']}") + print("\nAssistant: ", end="", flush=True) + content, tool_calls, finish_reason = make_streaming_request( + messages, tools=tools, tool_name="get_weather" + ) + print() # New line after streaming + + if tool_calls: + print("\nAssistant requested tool calls:") + + # Add assistant message once before processing tool calls + messages.append( + { + "role": "assistant", + "content": content if content else None, + "tool_calls": tool_calls, + } + ) + + tool_results: list[str] = [] + + # Process each tool call and add tool responses + for tool_call in tool_calls: + func = tool_call["function"] + args = json.loads(func["arguments"]) + print(f" - {func['name']}({json.dumps(args)})") + + tool_result = _run_local_tool(func["name"], args) + tool_results.append(tool_result) + print(f" Result: {tool_result}") + + # Add tool response to conversation + messages.append( + { + "role": "tool", + "tool_call_id": tool_call["id"], + "content": tool_result, + } + ) + + # Get final response after tool execution + messages.append( + { + "role": "user", + "content": ( + f"Original question: {messages[0]['content']}\n" + f"Tool result: {'; '.join(tool_results)}\n" + "Answer the original question directly using only that tool " + "result. Do not mention unrelated topics or other tools." + ), + } + ) + print("\nGetting final response after tool execution...") + print("Assistant: ", end="", flush=True) + content, tool_calls, finish_reason = make_streaming_request( + messages, tools=None + ) + print() # New line after streaming + if not content: + print("Assistant returned no content after tool execution.") + elif content: + print("(Content already displayed above)") + else: + print("Assistant returned no content and no tool calls.") + + print("\n" + "=" * 60) + print("Examples completed!") + print("=" * 60) + + +if __name__ == "__main__": + try: + main() + except requests.exceptions.ConnectionError: + print("Error: Could not connect to server.") + print("Make sure the server is running:") + print(" uv run m serve docs/examples/m_serve/m_serve_example_tool_calling.py") + except requests.exceptions.HTTPError as e: + print(f"Error: {e}") + if e.response is not None: + try: + print("Server response:", json.dumps(e.response.json(), indent=2)) + except ValueError: + print("Server response:", e.response.text) + except Exception as e: + print(f"Unexpected error: {e}") + raise diff --git a/mellea/helpers/openai_compatible_helpers.py b/mellea/helpers/openai_compatible_helpers.py index dfa9dd122..002dd195d 100644 --- a/mellea/helpers/openai_compatible_helpers.py +++ b/mellea/helpers/openai_compatible_helpers.py @@ -1,7 +1,7 @@ """A file for helper functions that deal with OpenAI API compatible helpers.""" import json -from typing import Any +from typing import Any, Literal, TypedDict from pydantic import BaseModel @@ -11,6 +11,21 @@ from ..stdlib.components import Document, Message +class ToolCallFunction(TypedDict): + """Function details in a tool call.""" + + name: str + arguments: str + + +class ToolCallDict(TypedDict): + """OpenAI-compatible tool call dictionary with ID and function.""" + + id: str + type: Literal["function"] + function: ToolCallFunction + + class CompletionUsage(BaseModel): """Token usage statistics for a completion request.""" @@ -251,3 +266,42 @@ def build_completion_usage(output: ModelOutputThunk) -> CompletionUsage | None: completion_tokens=completion_tokens, total_tokens=total_tokens, ) + + +def build_tool_calls(output: ModelOutputThunk) -> list[ToolCallDict] | None: + """Build OpenAI-compatible tool calls from a model output, if available. + + Args: + output: Model output thunk that may expose a ``tool_calls`` mapping. + + Returns: + List of ``ToolCallDict`` objects when tool calls are present, + otherwise ``None``. + """ + import json + import uuid + + # Check for tool calls - ModelOutputThunk always has tool_calls attribute + if ( + output.tool_calls is None + or not isinstance(output.tool_calls, dict) + or not output.tool_calls + ): + return None + + tool_calls: list[ToolCallDict] = [] + for model_tool_call in output.tool_calls.values(): + # Generate a unique ID for this tool call + tool_call_id = f"call_{uuid.uuid4().hex[:24]}" + + # Serialize the arguments to JSON string + args_json = json.dumps(model_tool_call.args) + + tool_call: ToolCallDict = { + "id": tool_call_id, + "type": "function", + "function": {"name": model_tool_call.name, "arguments": args_json}, + } + tool_calls.append(tool_call) + + return tool_calls diff --git a/test/cli/test_serve.py b/test/cli/test_serve.py index 6aae713dd..48aebcf2e 100644 --- a/test/cli/test_serve.py +++ b/test/cli/test_serve.py @@ -535,3 +535,90 @@ async def test_response_format_excluded_from_model_options(self, mock_module): # response_format should NOT be in model_options assert "response_format" not in model_options + + @pytest.mark.asyncio + async def test_streaming_with_tool_calls(self, mock_module): + """Test that tool calls are properly emitted in streaming responses.""" + import json + from unittest.mock import Mock + + from fastapi.responses import StreamingResponse + + from mellea.core.base import ModelToolCall + + # Create a mock tool + mock_tool = Mock() + mock_tool.name = "get_weather" + + # Create a mock output with tool calls + # Real backends may return content alongside tool calls (e.g., "I'll check that for you") + mock_output = ModelOutputThunk("I'll check the weather for you.") + mock_output.tool_calls = { + "get_weather": ModelToolCall( + name="get_weather", + func=mock_tool, + args={"location": "San Francisco", "units": "celsius"}, + ) + } + mock_module.serve.return_value = mock_output + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="What's the weather?")], + stream=True, + ) + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + # Verify it's a streaming response + assert isinstance(response, StreamingResponse) + + # Collect all chunks + chunks = [] + async for chunk_data in response.body_iterator: + # Convert to string for parsing + if isinstance(chunk_data, (bytes, memoryview)): + chunk_str = ( + bytes(chunk_data).decode("utf-8") + if isinstance(chunk_data, memoryview) + else chunk_data.decode("utf-8") + ) + else: + chunk_str = chunk_data + + # Parse SSE format: "data: {json}\n\n" + if chunk_str.startswith("data: "): + json_str = chunk_str[6:].strip() + if json_str and json_str != "[DONE]": + chunks.append(json.loads(json_str)) + + # Verify we have the expected chunk sequence + # Expected: initial (role), content, tool_calls, final = 4 chunks + assert len(chunks) == 4, f"Should have exactly 4 chunks, got {len(chunks)}" + + # Chunk 0: Initial chunk with role + initial_chunk = chunks[0] + assert initial_chunk["choices"][0]["delta"].get("role") == "assistant" + assert initial_chunk["choices"][0]["finish_reason"] is None + + # Chunk 1: Content chunk + content_chunk = chunks[1] + assert ( + content_chunk["choices"][0]["delta"].get("content") + == "I'll check the weather for you." + ) + assert content_chunk["choices"][0]["finish_reason"] is None + + # Chunk 2: Tool call chunk + tool_call_chunk = chunks[2] + tool_calls = tool_call_chunk["choices"][0]["delta"]["tool_calls"] + assert len(tool_calls) == 1 + assert tool_calls[0]["function"]["name"] == "get_weather" + assert "location" in tool_calls[0]["function"]["arguments"] + assert tool_call_chunk["choices"][0]["finish_reason"] is None + + # Chunk 3: Final chunk has finish_reason="tool_calls" + final_chunk = chunks[3] + assert final_chunk["choices"][0]["delta"].get("content") is None + assert final_chunk["choices"][0]["finish_reason"] == "tool_calls" From e12f2a7ec69b97b345da6417589dd1d59dabd9cb Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Wed, 29 Apr 2026 13:18:20 -0700 Subject: [PATCH 12/20] fix: add required index field to streaming tool call deltas The OpenAI streaming spec requires each item in delta.tool_calls to carry an index field. Clients including the openai Python SDK, LangChain, and LiteLLM key their delta-reassembly state machine on this field. Without it, they silently drop tool calls, coalesce them incorrectly, or raise a TypeError depending on version. Changes: - Add ChatCompletionMessageToolCallDelta model with required index field - Add ToolCallFunctionDelta model for streaming function deltas - Update ChatCompletionChunkDelta to use delta models - Update streaming.py to populate index field using enumerate() - Add comprehensive tests verifying index field presence - Update existing test to check for index field The bundled client_streaming_tool_calling.py example masked this issue because it reads delta.tool_calls verbatim rather than going through SDK delta reassembly. Fixes compatibility with OpenAI SDK, LangChain, and LiteLLM streaming tool call consumers. Signed-off-by: Mark Sturdevant Assisted-by: IBM Bob --- cli/serve/models.py | 49 +++++- cli/serve/streaming.py | 13 +- test/cli/test_serve.py | 3 + test/cli/test_tool_call_index_verification.py | 151 ++++++++++++++++++ 4 files changed, 207 insertions(+), 9 deletions(-) create mode 100644 test/cli/test_tool_call_index_verification.py diff --git a/cli/serve/models.py b/cli/serve/models.py index 5e994011f..1130e62a9 100644 --- a/cli/serve/models.py +++ b/cli/serve/models.py @@ -91,7 +91,7 @@ class ToolCallFunction(BaseModel): class ChatCompletionMessageToolCall(BaseModel): - """A tool call generated by the model.""" + """A tool call generated by the model (non-streaming).""" id: str """The ID of the tool call.""" @@ -103,6 +103,44 @@ class ChatCompletionMessageToolCall(BaseModel): """The function that the model called.""" +class ToolCallFunctionDelta(BaseModel): + """Function details for a streaming tool call delta. + + In streaming responses, function name and arguments may arrive across + multiple chunks, so both fields are optional. + """ + + name: str | None = None + """The name of the function to call (may be None in delta chunks).""" + + arguments: str | None = None + """The arguments fragment for this delta (may be None in delta chunks).""" + + +class ChatCompletionMessageToolCallDelta(BaseModel): + """A tool call delta in a streaming response. + + Per OpenAI streaming spec, each delta must include an index field that + clients use to reassemble tool calls across chunks. The id, type, and + function fields are optional since they may arrive incrementally. + """ + + index: int + """The index of this tool call in the tool_calls array. + + Required for delta reassembly in OpenAI SDK and compatible clients. + """ + + id: str | None = None + """The ID of the tool call (may be None in subsequent delta chunks).""" + + type: Literal["function"] | None = None + """The type of the tool (may be None in subsequent delta chunks).""" + + function: ToolCallFunctionDelta | None = None + """The function delta for this chunk (may be None in some chunks).""" + + # Taking this from OpenAI types https://github.com/openai/openai-python/blob/main/src/openai/types/chat/chat_completion.py, class ChatCompletionMessage(BaseModel): content: str | None = None @@ -170,8 +208,13 @@ class ChatCompletionChunkDelta(BaseModel): refusal: str | None = None """The refusal message fragment, if any.""" - tool_calls: list[ChatCompletionMessageToolCall] | None = None - """The tool calls generated by the model (only in tool call chunks).""" + tool_calls: list[ChatCompletionMessageToolCallDelta] | None = None + """The tool call deltas in this chunk. + + Each delta includes a required index field for reassembly by OpenAI SDK + and compatible clients. The id, type, and function fields are optional + since they may arrive incrementally across multiple chunks. + """ class ChatCompletionChunkChoice(BaseModel): diff --git a/cli/serve/streaming.py b/cli/serve/streaming.py index e837006cc..6dea035a4 100644 --- a/cli/serve/streaming.py +++ b/cli/serve/streaming.py @@ -14,11 +14,11 @@ ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionChunkDelta, - ChatCompletionMessageToolCall, + ChatCompletionMessageToolCallDelta, OpenAIError, OpenAIErrorResponse, StreamOptions, - ToolCallFunction, + ToolCallFunctionDelta, ) @@ -108,17 +108,18 @@ async def stream_chat_completion_chunks( tool_calls_list = build_tool_calls(output) if tool_calls_list: - # Convert to ChatCompletionMessageToolCall objects + # Convert to ChatCompletionMessageToolCallDelta objects with required index tool_calls = [ - ChatCompletionMessageToolCall( + ChatCompletionMessageToolCallDelta( + index=idx, id=tc["id"], type=tc["type"], - function=ToolCallFunction( + function=ToolCallFunctionDelta( name=tc["function"]["name"], arguments=tc["function"]["arguments"], ), ) - for tc in tool_calls_list + for idx, tc in enumerate(tool_calls_list) ] # Emit tool calls in a separate chunk before the final chunk diff --git a/test/cli/test_serve.py b/test/cli/test_serve.py index 48aebcf2e..add2682e5 100644 --- a/test/cli/test_serve.py +++ b/test/cli/test_serve.py @@ -614,6 +614,9 @@ async def test_streaming_with_tool_calls(self, mock_module): tool_call_chunk = chunks[2] tool_calls = tool_call_chunk["choices"][0]["delta"]["tool_calls"] assert len(tool_calls) == 1 + # Verify required index field is present (OpenAI streaming spec requirement) + assert "index" in tool_calls[0], "tool_calls delta must include index field" + assert tool_calls[0]["index"] == 0 assert tool_calls[0]["function"]["name"] == "get_weather" assert "location" in tool_calls[0]["function"]["arguments"] assert tool_call_chunk["choices"][0]["finish_reason"] is None diff --git a/test/cli/test_tool_call_index_verification.py b/test/cli/test_tool_call_index_verification.py new file mode 100644 index 000000000..ad1232609 --- /dev/null +++ b/test/cli/test_tool_call_index_verification.py @@ -0,0 +1,151 @@ +"""Verification that streaming tool call deltas include required index field. + +This test demonstrates that our streaming implementation is compatible with +OpenAI SDK delta reassembly logic, which requires the index field. +""" + +import json +from unittest.mock import Mock + +import pytest + +from cli.serve.app import make_chat_endpoint +from cli.serve.models import ChatCompletionRequest, ChatMessage +from mellea.core.base import ModelOutputThunk, ModelToolCall + + +@pytest.mark.asyncio +async def test_tool_call_delta_has_required_index_field(): + """Verify that streaming tool call deltas include the required index field. + + The OpenAI streaming spec requires each item in delta.tool_calls to carry + an index field. Clients including the openai Python SDK, LangChain, and + LiteLLM key their delta-reassembly state machine on this field. + + Without it, they silently drop tool calls, coalesce them incorrectly, or + raise a TypeError depending on version. + """ + # Create a mock module with a serve function + mock_module = Mock() + mock_module.__name__ = "test_module" + + # Create a mock tool + mock_tool = Mock() + mock_tool.name = "get_weather" + + # Create a mock output with multiple tool calls to test indexing + mock_output = ModelOutputThunk("I'll check the weather for you.") + mock_output.tool_calls = { + "get_weather": ModelToolCall( + name="get_weather", + func=mock_tool, + args={"location": "San Francisco", "units": "celsius"}, + ), + "get_forecast": ModelToolCall( + name="get_forecast", + func=mock_tool, + args={"location": "San Francisco", "days": 3}, + ), + } + mock_module.serve.return_value = mock_output + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="What's the weather?")], + stream=True, + ) + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + # Collect all chunks + chunks = [] + async for chunk_data in response.body_iterator: + chunk_str = ( + bytes(chunk_data).decode("utf-8") + if isinstance(chunk_data, (bytes, memoryview)) + else chunk_data + ) + + if chunk_str.startswith("data: "): + json_str = chunk_str[6:].strip() + if json_str and json_str != "[DONE]": + chunks.append(json.loads(json_str)) + + # Find the tool call chunk + tool_call_chunk = None + for chunk in chunks: + if chunk["choices"][0]["delta"].get("tool_calls"): + tool_call_chunk = chunk + break + + assert tool_call_chunk is not None, "Should have a tool call chunk" + + tool_calls = tool_call_chunk["choices"][0]["delta"]["tool_calls"] + assert len(tool_calls) == 2, "Should have 2 tool calls" + + # Verify REQUIRED index field is present on each tool call delta + for i, tc in enumerate(tool_calls): + assert "index" in tc, f"tool_calls[{i}] must include index field" + assert isinstance(tc["index"], int), "index must be an integer" + assert tc["index"] == i, f"tool_calls[{i}] should have index={i}" + + # Verify other fields are present (id, type, function) + assert "id" in tc, f"tool_calls[{i}] should have id" + assert "type" in tc, f"tool_calls[{i}] should have type" + assert tc["type"] == "function", f"tool_calls[{i}] type should be 'function'" + assert "function" in tc, f"tool_calls[{i}] should have function" + assert "name" in tc["function"], f"tool_calls[{i}].function should have name" + assert "arguments" in tc["function"], ( + f"tool_calls[{i}].function should have arguments" + ) + + +@pytest.mark.asyncio +async def test_single_tool_call_has_index_zero(): + """Verify that a single tool call has index=0.""" + mock_module = Mock() + mock_module.__name__ = "test_module" + + mock_tool = Mock() + mock_tool.name = "search" + + mock_output = ModelOutputThunk("Searching...") + mock_output.tool_calls = { + "search": ModelToolCall(name="search", func=mock_tool, args={"query": "test"}) + } + mock_module.serve.return_value = mock_output + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Search for test")], + stream=True, + ) + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + chunks = [] + async for chunk_data in response.body_iterator: + chunk_str = ( + bytes(chunk_data).decode("utf-8") + if isinstance(chunk_data, (bytes, memoryview)) + else chunk_data + ) + + if chunk_str.startswith("data: "): + json_str = chunk_str[6:].strip() + if json_str and json_str != "[DONE]": + chunks.append(json.loads(json_str)) + + # Find the tool call chunk + tool_call_chunk = None + for chunk in chunks: + if chunk["choices"][0]["delta"].get("tool_calls"): + tool_call_chunk = chunk + break + + assert tool_call_chunk is not None + tool_calls = tool_call_chunk["choices"][0]["delta"]["tool_calls"] + assert len(tool_calls) == 1 + assert tool_calls[0]["index"] == 0, "Single tool call should have index=0" From 017701094fec20e683e9dca3eb09b77649243fa4 Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Wed, 29 Apr 2026 13:42:50 -0700 Subject: [PATCH 13/20] fix: move build_tool_calls invocation build_tool_calls was called before streaming block and then not used in case of streaming. Rearrange condition and call to avoid wasted call. Signed-off-by: Mark Sturdevant --- cli/serve/app.py | 36 +++++++++++++++++------------------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/cli/serve/app.py b/cli/serve/app.py index f4c2efd50..e92d648e2 100644 --- a/cli/serve/app.py +++ b/cli/serve/app.py @@ -177,6 +177,23 @@ async def endpoint(request: ChatCompletionRequest): model_options=model_options, ) + # Leave as None since we don't track backend config fingerprints yet + system_fingerprint = None + + # Handle streaming response + if request.stream: + return StreamingResponse( + stream_chat_completion_chunks( + output=output, + completion_id=completion_id, + model=request.model, + created=created_timestamp, + stream_options=request.stream_options, + system_fingerprint=system_fingerprint, + ), + media_type="text/event-stream", + ) + # Extract tool calls from the ModelOutputThunk if available tool_calls_list = build_tool_calls(output) tool_calls = ( @@ -203,25 +220,6 @@ async def endpoint(request: ChatCompletionRequest): | None ) = "tool_calls" if tool_calls else "stop" - # system_fingerprint represents backend config hash, not model name - # The model name is already in response.model (line 73) - # Leave as None since we don't track backend config fingerprints yet - system_fingerprint = None - - # Handle streaming response - if request.stream: - return StreamingResponse( - stream_chat_completion_chunks( - output=output, - completion_id=completion_id, - model=request.model, - created=created_timestamp, - stream_options=request.stream_options, - system_fingerprint=system_fingerprint, - ), - media_type="text/event-stream", - ) - return ChatCompletion( id=completion_id, model=request.model, From 43cb8b8190d1307852567a28a95bae8cfc93cf8f Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Wed, 29 Apr 2026 15:20:18 -0700 Subject: [PATCH 14/20] test: add integration test for cli/serve using TestClient with streaming and tool calling Signed-off-by: Mark Sturdevant --- test/cli/test_serve_integration.py | 572 +++++++++++++++++++++++++++++ 1 file changed, 572 insertions(+) create mode 100644 test/cli/test_serve_integration.py diff --git a/test/cli/test_serve_integration.py b/test/cli/test_serve_integration.py new file mode 100644 index 000000000..79f1a2416 --- /dev/null +++ b/test/cli/test_serve_integration.py @@ -0,0 +1,572 @@ +"""Integration tests for m serve using FastAPI TestClient. + +Tests the full HTTP request/response cycle including: +- Streaming responses (SSE format, headers, chunking) +- Tool calling responses via HTTP +- Error handling at the HTTP layer +""" + +import json +from typing import Any +from unittest.mock import Mock + +import pytest +from fastapi import FastAPI +from fastapi.exceptions import RequestValidationError +from fastapi.testclient import TestClient + +from cli.serve.app import make_chat_endpoint, validation_exception_handler +from cli.serve.models import FunctionDefinition, FunctionParameters, ToolFunction +from mellea.core.base import AbstractMelleaTool, ModelOutputThunk, ModelToolCall + +# Mark all tests in this module as integration tests +pytestmark = pytest.mark.integration + + +class MockWeatherTool(AbstractMelleaTool): + """Mock weather tool for testing.""" + + name = "get_weather" + + def run(self, location: str, units: str = "celsius") -> str: + """Mock run method.""" + return f"Weather in {location} is 22°{units[0].upper()}" + + @property + def as_json_tool(self) -> dict[str, Any]: + """Return JSON schema for this tool.""" + return { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather in a location", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string", "description": "City name"}, + "units": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "Temperature units", + }, + }, + "required": ["location"], + }, + }, + } + + @property + def as_tool_function(self) -> ToolFunction: + """Return ToolFunction model for HTTP requests.""" + return ToolFunction( + type="function", + function=FunctionDefinition( + name="get_weather", + description="Get the current weather in a location", + parameters=FunctionParameters( + RootModel={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "City name"}, + "units": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "Temperature units", + }, + }, + "required": ["location"], + } + ), + ), + ) + + +@pytest.fixture +def mock_module(): + """Create a mock module with a serve function.""" + module = Mock() + module.__name__ = "test_integration_module" + return module + + +@pytest.fixture +def test_app(mock_module): + """Create a FastAPI test app with the chat endpoint.""" + app = FastAPI() + app.add_exception_handler(RequestValidationError, validation_exception_handler) + app.add_api_route( + "/v1/chat/completions", make_chat_endpoint(mock_module), methods=["POST"] + ) + return app + + +@pytest.fixture +def client(test_app): + """Create a TestClient for the app.""" + return TestClient(test_app) + + +class TestStreamingIntegration: + """Integration tests for streaming responses via HTTP.""" + + def test_streaming_response_headers(self, client, mock_module): + """Test that streaming responses have correct HTTP headers.""" + mock_output = ModelOutputThunk("Hello, streaming world!") + mock_module.serve.return_value = mock_output + + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Hello"}], + "stream": True, + }, + ) + + # Verify streaming headers + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + + def test_streaming_sse_format(self, client, mock_module): + """Test that streaming responses follow SSE format.""" + mock_output = ModelOutputThunk("Test response") + mock_module.serve.return_value = mock_output + + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Hello"}], + "stream": True, + }, + ) + + # Parse SSE chunks + chunks = [] + for line in response.text.split("\n\n"): + if line.startswith("data: "): + data = line[6:].strip() + if data != "[DONE]": + chunks.append(json.loads(data)) + + # Verify chunk structure + assert len(chunks) > 0 + for chunk in chunks: + assert chunk["object"] == "chat.completion.chunk" + assert "id" in chunk + assert "model" in chunk + assert "created" in chunk + assert "choices" in chunk + assert len(chunk["choices"]) == 1 + + # Verify final chunk has finish_reason + assert chunks[-1]["choices"][0]["finish_reason"] == "stop" + + def test_streaming_content_chunks(self, client, mock_module): + """Test that content is properly chunked in streaming response.""" + mock_output = ModelOutputThunk("Hello world!") + mock_module.serve.return_value = mock_output + + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Say hello"}], + "stream": True, + }, + ) + + # Parse chunks + chunks = [] + for line in response.text.split("\n\n"): + if line.startswith("data: "): + data = line[6:].strip() + if data != "[DONE]": + chunks.append(json.loads(data)) + + # First chunk should have role + assert chunks[0]["choices"][0]["delta"].get("role") == "assistant" + + # Second chunk should have content + assert chunks[1]["choices"][0]["delta"].get("content") == "Hello world!" + + # Final chunk should have finish_reason + assert chunks[-1]["choices"][0]["finish_reason"] == "stop" + + def test_streaming_with_usage_field(self, client, mock_module): + """Test streaming response includes usage when stream_options.include_usage=True.""" + mock_output = ModelOutputThunk("Response") + mock_output.usage = { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15, + } + mock_module.serve.return_value = mock_output + + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Hello"}], + "stream": True, + "stream_options": {"include_usage": True}, + }, + ) + + # Parse chunks + chunks = [] + for line in response.text.split("\n\n"): + if line.startswith("data: "): + data = line[6:].strip() + if data != "[DONE]": + chunks.append(json.loads(data)) + + # Final chunk should include usage + final_chunk = chunks[-1] + assert "usage" in final_chunk + assert final_chunk["usage"]["prompt_tokens"] == 10 + assert final_chunk["usage"]["completion_tokens"] == 5 + assert final_chunk["usage"]["total_tokens"] == 15 + + def test_streaming_done_marker(self, client, mock_module): + """Test that streaming response ends with [DONE] marker.""" + mock_output = ModelOutputThunk("Test") + mock_module.serve.return_value = mock_output + + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Hello"}], + "stream": True, + }, + ) + + # Verify [DONE] marker is present + assert "data: [DONE]" in response.text + assert response.text.strip().endswith("data: [DONE]") + + +class TestToolCallingIntegration: + """Integration tests for tool calling via HTTP.""" + + def test_tool_call_response_structure(self, client, mock_module): + """Test that tool calls are properly formatted in HTTP response.""" + mock_output = ModelOutputThunk("I'll check the weather.") + mock_tool = MockWeatherTool() + mock_output.tool_calls = { + "get_weather": ModelToolCall( + name="get_weather", + func=mock_tool, + args={"location": "Paris", "units": "celsius"}, + ) + } + mock_module.serve.return_value = mock_output + + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [ + {"role": "user", "content": "What's the weather in Paris?"} + ], + "tools": [mock_tool.as_tool_function.model_dump(mode="json")], + }, + ) + + assert response.status_code == 200 + data = response.json() + + # Verify response structure + assert data["object"] == "chat.completion" + assert data["choices"][0]["finish_reason"] == "tool_calls" + assert data["choices"][0]["message"]["tool_calls"] is not None + assert len(data["choices"][0]["message"]["tool_calls"]) == 1 + + # Verify tool call details + tool_call = data["choices"][0]["message"]["tool_calls"][0] + assert tool_call["type"] == "function" + assert tool_call["function"]["name"] == "get_weather" + assert tool_call["id"].startswith("call_") + + # Verify arguments + args = json.loads(tool_call["function"]["arguments"]) + assert args["location"] == "Paris" + assert args["units"] == "celsius" + + def test_multiple_tool_calls_via_http(self, client, mock_module): + """Test multiple tool calls in a single HTTP response.""" + mock_output = ModelOutputThunk("Checking multiple locations.") + mock_tool = MockWeatherTool() + mock_output.tool_calls = { + "weather_paris": ModelToolCall( + name="get_weather", func=mock_tool, args={"location": "Paris"} + ), + "weather_london": ModelToolCall( + name="get_weather", func=mock_tool, args={"location": "London"} + ), + "weather_tokyo": ModelToolCall( + name="get_weather", func=mock_tool, args={"location": "Tokyo"} + ), + } + mock_module.serve.return_value = mock_output + + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Weather in multiple cities"}], + "tools": [mock_tool.as_tool_function.model_dump(mode="json")], + }, + ) + + assert response.status_code == 200 + data = response.json() + + # Verify multiple tool calls + tool_calls = data["choices"][0]["message"]["tool_calls"] + assert len(tool_calls) == 3 + + # Verify each has unique ID + ids = [tc["id"] for tc in tool_calls] + assert len(ids) == len(set(ids)), "Tool call IDs should be unique" + + # Verify locations + locations = [ + json.loads(tc["function"]["arguments"])["location"] for tc in tool_calls + ] + assert set(locations) == {"Paris", "London", "Tokyo"} + + def test_tool_calls_with_usage_info(self, client, mock_module): + """Test that usage info is included with tool calls.""" + mock_output = ModelOutputThunk("Calling tool.") + mock_tool = MockWeatherTool() + mock_output.tool_calls = { + "get_weather": ModelToolCall( + name="get_weather", func=mock_tool, args={"location": "Paris"} + ) + } + mock_output.usage = { + "prompt_tokens": 50, + "completion_tokens": 20, + "total_tokens": 70, + } + mock_module.serve.return_value = mock_output + + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Weather?"}], + "tools": [mock_tool.as_tool_function.model_dump(mode="json")], + }, + ) + + assert response.status_code == 200 + data = response.json() + + # Verify both tool calls and usage + assert data["choices"][0]["finish_reason"] == "tool_calls" + assert data["usage"] is not None + assert data["usage"]["total_tokens"] == 70 + + +class TestStreamingWithToolCalls: + """Integration tests for streaming responses with tool calls.""" + + def test_streaming_tool_call_response(self, client, mock_module): + """Test streaming response with tool calls via HTTP.""" + mock_output = ModelOutputThunk("I'll check that for you.") + mock_tool = MockWeatherTool() + mock_output.tool_calls = { + "get_weather": ModelToolCall( + name="get_weather", func=mock_tool, args={"location": "Paris"} + ) + } + mock_module.serve.return_value = mock_output + + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Weather in Paris?"}], + "tools": [mock_tool.as_tool_function.model_dump(mode="json")], + "stream": True, + }, + ) + + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + + # Parse chunks + chunks = [] + for line in response.text.split("\n\n"): + if line.startswith("data: "): + data = line[6:].strip() + if data != "[DONE]": + chunks.append(json.loads(data)) + + # Should have: initial (role), content, tool_calls, final + assert len(chunks) == 4 + + # Verify chunk sequence + assert chunks[0]["choices"][0]["delta"].get("role") == "assistant" + assert ( + chunks[1]["choices"][0]["delta"].get("content") + == "I'll check that for you." + ) + assert "tool_calls" in chunks[2]["choices"][0]["delta"] + assert chunks[3]["choices"][0]["finish_reason"] == "tool_calls" + + # Verify tool call structure in streaming chunk + tool_calls = chunks[2]["choices"][0]["delta"]["tool_calls"] + assert len(tool_calls) == 1 + assert "index" in tool_calls[0], "Streaming tool calls must include index" + assert tool_calls[0]["index"] == 0 + assert tool_calls[0]["type"] == "function" + assert tool_calls[0]["function"]["name"] == "get_weather" + + def test_streaming_multiple_tool_calls(self, client, mock_module): + """Test streaming with multiple tool calls via HTTP.""" + mock_output = ModelOutputThunk("Checking multiple locations.") + mock_tool = MockWeatherTool() + mock_output.tool_calls = { + "weather_1": ModelToolCall( + name="get_weather", func=mock_tool, args={"location": "Paris"} + ), + "weather_2": ModelToolCall( + name="get_weather", func=mock_tool, args={"location": "London"} + ), + } + mock_module.serve.return_value = mock_output + + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Weather?"}], + "tools": [mock_tool.as_tool_function.model_dump(mode="json")], + "stream": True, + }, + ) + + # Parse chunks + chunks = [] + for line in response.text.split("\n\n"): + if line.startswith("data: "): + data = line[6:].strip() + if data != "[DONE]": + chunks.append(json.loads(data)) + + # Find tool call chunk (with non-None tool_calls) + tool_call_chunk = next( + c + for c in chunks + if "tool_calls" in c["choices"][0]["delta"] + and c["choices"][0]["delta"]["tool_calls"] is not None + ) + tool_calls = tool_call_chunk["choices"][0]["delta"]["tool_calls"] + + # Verify multiple tool calls with indices + assert len(tool_calls) == 2 + indices = [tc["index"] for tc in tool_calls] + assert indices == [0, 1] + + def test_streaming_tool_calls_with_usage(self, client, mock_module): + """Test streaming tool calls with usage info via HTTP.""" + mock_output = ModelOutputThunk("Calling tool.") + mock_tool = MockWeatherTool() + mock_output.tool_calls = { + "get_weather": ModelToolCall( + name="get_weather", func=mock_tool, args={"location": "Paris"} + ) + } + mock_output.usage = { + "prompt_tokens": 30, + "completion_tokens": 15, + "total_tokens": 45, + } + mock_module.serve.return_value = mock_output + + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Weather?"}], + "tools": [mock_tool.as_tool_function.model_dump(mode="json")], + "stream": True, + "stream_options": {"include_usage": True}, + }, + ) + + # Parse chunks + chunks = [] + for line in response.text.split("\n\n"): + if line.startswith("data: "): + data = line[6:].strip() + if data != "[DONE]": + chunks.append(json.loads(data)) + + # Final chunk should have both finish_reason and usage + final_chunk = chunks[-1] + assert final_chunk["choices"][0]["finish_reason"] == "tool_calls" + assert "usage" in final_chunk + assert final_chunk["usage"]["total_tokens"] == 45 + + +class TestHTTPErrorHandling: + """Integration tests for error handling at HTTP layer.""" + + def test_invalid_request_returns_400(self, client, mock_module): + """Test that invalid requests return 400 with OpenAI error format.""" + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Hello"}], + "n": 0, # Invalid: must be >= 1 + }, + ) + + assert response.status_code == 400 + data = response.json() + assert "error" in data + assert data["error"]["type"] == "invalid_request_error" + assert data["error"]["param"] == "n" + + def test_unsupported_n_parameter(self, client, mock_module): + """Test that n > 1 is rejected with proper error.""" + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Hello"}], + "n": 2, + }, + ) + + assert response.status_code == 400 + data = response.json() + assert "error" in data + assert data["error"]["type"] == "invalid_request_error" + assert data["error"]["param"] == "n" + assert "not supported" in data["error"]["message"].lower() + + def test_server_error_returns_500(self, client, mock_module): + """Test that server errors return 500 with OpenAI error format.""" + # Make serve raise an exception + mock_module.serve.side_effect = RuntimeError("Internal error") + + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Hello"}], + }, + ) + + assert response.status_code == 500 + data = response.json() + assert "error" in data + assert data["error"]["type"] == "server_error" + assert "Internal error" in data["error"]["message"] From e68d50cd2a48dd4d501708ff95441cc7d3eeeb21 Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Wed, 29 Apr 2026 15:47:54 -0700 Subject: [PATCH 15/20] fix: use fallback for json.dumps in build_tool_calls Use str to non-serializable types. This should effectively avoid TypeError (in normal situations). Signed-off-by: Mark Sturdevant --- mellea/helpers/openai_compatible_helpers.py | 4 +-- .../helpers/test_openai_compatible_helpers.py | 30 ++++++++++++++++++- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/mellea/helpers/openai_compatible_helpers.py b/mellea/helpers/openai_compatible_helpers.py index 002dd195d..71ecec888 100644 --- a/mellea/helpers/openai_compatible_helpers.py +++ b/mellea/helpers/openai_compatible_helpers.py @@ -294,8 +294,8 @@ def build_tool_calls(output: ModelOutputThunk) -> list[ToolCallDict] | None: # Generate a unique ID for this tool call tool_call_id = f"call_{uuid.uuid4().hex[:24]}" - # Serialize the arguments to JSON string - args_json = json.dumps(model_tool_call.args) + # Serialize arguments to JSON with str fallback for non-serializable types + args_json = json.dumps(model_tool_call.args, default=str) tool_call: ToolCallDict = { "id": tool_call_id, diff --git a/test/helpers/test_openai_compatible_helpers.py b/test/helpers/test_openai_compatible_helpers.py index 3963c3d8f..2f29a554f 100644 --- a/test/helpers/test_openai_compatible_helpers.py +++ b/test/helpers/test_openai_compatible_helpers.py @@ -2,12 +2,15 @@ import base64 import json +from datetime import datetime +from decimal import Decimal import pytest from mellea.backends.tools import MelleaTool -from mellea.core.base import ImageBlock +from mellea.core.base import ImageBlock, ModelOutputThunk, ModelToolCall from mellea.helpers.openai_compatible_helpers import ( + build_tool_calls, chat_completion_delta_merge, extract_model_tool_requests, message_to_openai_message, @@ -359,5 +362,30 @@ def test_docs_across_messages(self): assert result[1]["text"] == "b" +# --- build_tool_calls --- + + +class TestBuildToolCalls: + def test_with_non_json_serializable_args(self): + """Non-JSON-serializable values (datetime, Decimal) are converted to strings.""" + tool = _make_tool("test_tool") + tool_call = ModelToolCall( + name="test_tool", + func=tool, + args={"timestamp": datetime(2024, 1, 15), "amount": Decimal("123.45")}, + ) + output = ModelOutputThunk(value="test", tool_calls={"test_tool": tool_call}) + + result = build_tool_calls(output) + + assert result is not None + assert len(result) == 1 + # Verify arguments are valid JSON and values were converted to strings + args = json.loads(result[0]["function"]["arguments"]) + assert isinstance(args["timestamp"], str) + assert "2024-01-15" in args["timestamp"] + assert args["amount"] == "123.45" + + if __name__ == "__main__": pytest.main([__file__]) From f3c9d854effaa00b5e655b4c5c0fe4431df8a3dd Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Wed, 29 Apr 2026 18:36:47 -0700 Subject: [PATCH 16/20] test: restore cli streaming tests to fix conflicts New tests for tooling improved coverage, but the significant rewrite caused too much diverging from main. Keeping the old tests in places while adding new tests in new file will help sort this out. Signed-off-by: Mark Sturdevant --- test/cli/test_serve_streaming_tool_calls.py | 503 ++++++++++++++++++++ 1 file changed, 503 insertions(+) create mode 100644 test/cli/test_serve_streaming_tool_calls.py diff --git a/test/cli/test_serve_streaming_tool_calls.py b/test/cli/test_serve_streaming_tool_calls.py new file mode 100644 index 000000000..6aa691ab5 --- /dev/null +++ b/test/cli/test_serve_streaming_tool_calls.py @@ -0,0 +1,503 @@ +"""Unit tests for streaming with tool calls, usage fields, and error handling. + +This file contains new tests added in the tool-calling PR. The main streaming +tests (from main branch) are in test_serve_streaming.py. +""" + +import json +from unittest.mock import AsyncMock, Mock + +import pytest + +from cli.serve.models import StreamOptions +from cli.serve.streaming import stream_chat_completion_chunks +from mellea.core.base import ModelOutputThunk, ModelToolCall + + +class TestStreamingToolCalls: + """Tests for streaming responses with tool calls.""" + + @pytest.mark.asyncio + async def test_streaming_tool_call_chunk_structure(self): + """Test that tool call chunks have correct structure with index field.""" + # Create a mock tool + mock_tool = Mock() + mock_tool.name = "get_weather" + + # Create output with tool calls + output = ModelOutputThunk("Checking weather...") + output.tool_calls = { + "get_weather": ModelToolCall( + name="get_weather", + func=mock_tool, + args={"location": "San Francisco", "units": "celsius"}, + ) + } + + # Stream chunks + chunks = [] + async for chunk_data in stream_chat_completion_chunks( + output=output, + completion_id="chatcmpl-test123", + model="test-model", + created=1234567890, + ): + if chunk_data.startswith("data: ") and chunk_data.strip() != "data: [DONE]": + json_str = chunk_data[6:].strip() + chunks.append(json.loads(json_str)) + + # Should have: initial (role), content, tool_calls, final = 4 chunks + assert len(chunks) == 4 + + # Verify tool call chunk structure + tool_call_chunk = chunks[2] + tool_calls = tool_call_chunk["choices"][0]["delta"]["tool_calls"] + assert len(tool_calls) == 1 + + # Critical: index field must be present (OpenAI streaming spec) + assert "index" in tool_calls[0], "tool_calls delta must include index field" + assert tool_calls[0]["index"] == 0 + assert tool_calls[0]["id"] is not None + assert tool_calls[0]["type"] == "function" + assert tool_calls[0]["function"]["name"] == "get_weather" + assert "location" in tool_calls[0]["function"]["arguments"] + + @pytest.mark.asyncio + async def test_finish_reason_tool_calls(self): + """Test that finish_reason is 'tool_calls' when tool calls are present.""" + mock_tool = Mock() + mock_tool.name = "test_func" + + output = ModelOutputThunk("Response") + output.tool_calls = { + "test_func": ModelToolCall( + name="test_func", func=mock_tool, args={"arg": "value"} + ) + } + + chunks = [] + async for chunk_data in stream_chat_completion_chunks( + output=output, + completion_id="chatcmpl-test", + model="test-model", + created=1234567890, + ): + if chunk_data.startswith("data: ") and chunk_data.strip() != "data: [DONE]": + json_str = chunk_data[6:].strip() + chunks.append(json.loads(json_str)) + + # Final chunk should have finish_reason="tool_calls" + final_chunk = chunks[-1] + assert final_chunk["choices"][0]["finish_reason"] == "tool_calls" + + @pytest.mark.asyncio + async def test_finish_reason_stop_without_tool_calls(self): + """Test that finish_reason is 'stop' when no tool calls are present.""" + output = ModelOutputThunk("Simple response") + + chunks = [] + async for chunk_data in stream_chat_completion_chunks( + output=output, + completion_id="chatcmpl-test", + model="test-model", + created=1234567890, + ): + if chunk_data.startswith("data: ") and chunk_data.strip() != "data: [DONE]": + json_str = chunk_data[6:].strip() + chunks.append(json.loads(json_str)) + + # Final chunk should have finish_reason="stop" + final_chunk = chunks[-1] + assert final_chunk["choices"][0]["finish_reason"] == "stop" + + @pytest.mark.asyncio + async def test_multiple_tool_calls_with_indices(self): + """Test that multiple tool calls each get correct index values.""" + mock_tool1 = Mock() + mock_tool1.name = "func1" + mock_tool2 = Mock() + mock_tool2.name = "func2" + mock_tool3 = Mock() + mock_tool3.name = "func3" + + output = ModelOutputThunk("Calling multiple functions") + output.tool_calls = { + "func1": ModelToolCall(name="func1", func=mock_tool1, args={"a": 1}), + "func2": ModelToolCall(name="func2", func=mock_tool2, args={"b": 2}), + "func3": ModelToolCall(name="func3", func=mock_tool3, args={"c": 3}), + } + + chunks = [] + async for chunk_data in stream_chat_completion_chunks( + output=output, + completion_id="chatcmpl-test", + model="test-model", + created=1234567890, + ): + if chunk_data.startswith("data: ") and chunk_data.strip() != "data: [DONE]": + json_str = chunk_data[6:].strip() + chunks.append(json.loads(json_str)) + + # Find tool call chunk + tool_call_chunk = chunks[2] + tool_calls = tool_call_chunk["choices"][0]["delta"]["tool_calls"] + + # Should have 3 tool calls with indices 0, 1, 2 + assert len(tool_calls) == 3 + indices = [tc["index"] for tc in tool_calls] + assert indices == [0, 1, 2] + + # Verify each has required fields + for tc in tool_calls: + assert "index" in tc + assert "id" in tc + assert "type" in tc + assert tc["type"] == "function" + assert "function" in tc + assert "name" in tc["function"] + assert "arguments" in tc["function"] + + @pytest.mark.asyncio + async def test_tool_call_chunk_before_final_chunk(self): + """Test that tool call chunk is emitted before final chunk.""" + mock_tool = Mock() + mock_tool.name = "test_func" + + output = ModelOutputThunk("Response") + output.tool_calls = { + "test_func": ModelToolCall(name="test_func", func=mock_tool, args={}) + } + + chunks = [] + async for chunk_data in stream_chat_completion_chunks( + output=output, + completion_id="chatcmpl-test", + model="test-model", + created=1234567890, + ): + if chunk_data.startswith("data: ") and chunk_data.strip() != "data: [DONE]": + json_str = chunk_data[6:].strip() + chunks.append(json.loads(json_str)) + + # Verify chunk sequence + assert len(chunks) == 4 + + # Chunk 0: initial with role + assert chunks[0]["choices"][0]["delta"].get("role") == "assistant" + assert chunks[0]["choices"][0]["finish_reason"] is None + + # Chunk 1: content + assert chunks[1]["choices"][0]["delta"].get("content") == "Response" + assert chunks[1]["choices"][0]["finish_reason"] is None + + # Chunk 2: tool calls (before final) + assert "tool_calls" in chunks[2]["choices"][0]["delta"] + assert chunks[2]["choices"][0]["finish_reason"] is None + + # Chunk 3: final with finish_reason + assert chunks[3]["choices"][0]["finish_reason"] == "tool_calls" + + +class TestStreamingIncrementalContent: + """Tests for streaming with incremental content (not pre-computed).""" + + @pytest.mark.asyncio + async def test_streaming_incremental_chunks(self): + """Test streaming with incremental content via astream().""" + from unittest.mock import patch + + # Create output that streams incrementally + output = ModelOutputThunk("") + + # Mock astream to return incremental chunks + chunks_to_stream = ["Hello", " ", "world", "!"] + stream_index = 0 + + async def mock_astream(): + nonlocal stream_index + if stream_index < len(chunks_to_stream): + chunk = chunks_to_stream[stream_index] + stream_index += 1 + return chunk + else: + # Mark as computed by setting the value property + output.value = "Hello world!" + # Use object.__setattr__ to bypass property setter for _computed + object.__setattr__(output, "_computed", True) + return "" + + def mock_is_computed(): + return stream_index >= len(chunks_to_stream) + + # Patch the astream and is_computed methods + with ( + patch.object(output, "astream", side_effect=mock_astream), + patch.object(output, "is_computed", side_effect=mock_is_computed), + ): + # Collect streamed chunks + collected_chunks = [] + async for chunk_data in stream_chat_completion_chunks( + output=output, + completion_id="chatcmpl-test", + model="test-model", + created=1234567890, + ): + if ( + chunk_data.startswith("data: ") + and chunk_data.strip() != "data: [DONE]" + ): + json_str = chunk_data[6:].strip() + parsed = json.loads(json_str) + delta_content = parsed["choices"][0]["delta"].get("content") + if delta_content: + collected_chunks.append(delta_content) + + # Should have initial role chunk + 4 content chunks + # (role chunk has content=None, so not collected) + assert collected_chunks == ["Hello", " ", "world", "!"] + + @pytest.mark.asyncio + async def test_streaming_with_tool_calls_after_incremental_content(self): + """Test that tool calls are emitted after incremental content streaming.""" + from unittest.mock import patch + + mock_tool = Mock() + mock_tool.name = "test_func" + + # Create output that streams incrementally + output = ModelOutputThunk("") + output.tool_calls = { + "test_func": ModelToolCall( + name="test_func", func=mock_tool, args={"key": "value"} + ) + } + + # Mock astream + chunks_to_stream = ["Part1", "Part2"] + stream_index = 0 + + async def mock_astream(): + nonlocal stream_index + if stream_index < len(chunks_to_stream): + chunk = chunks_to_stream[stream_index] + stream_index += 1 + return chunk + else: + output.value = "Part1Part2" + object.__setattr__(output, "_computed", True) + return "" + + def mock_is_computed(): + return stream_index >= len(chunks_to_stream) + + # Patch the astream and is_computed methods + with ( + patch.object(output, "astream", side_effect=mock_astream), + patch.object(output, "is_computed", side_effect=mock_is_computed), + ): + # Collect all chunks + chunks = [] + async for chunk_data in stream_chat_completion_chunks( + output=output, + completion_id="chatcmpl-test", + model="test-model", + created=1234567890, + ): + if ( + chunk_data.startswith("data: ") + and chunk_data.strip() != "data: [DONE]" + ): + json_str = chunk_data[6:].strip() + chunks.append(json.loads(json_str)) + + # Should have: initial, Part1, Part2, tool_calls, final = 5 chunks + assert len(chunks) == 5 + + # Verify sequence + assert chunks[0]["choices"][0]["delta"].get("role") == "assistant" + assert chunks[1]["choices"][0]["delta"].get("content") == "Part1" + assert chunks[2]["choices"][0]["delta"].get("content") == "Part2" + assert "tool_calls" in chunks[3]["choices"][0]["delta"] + assert chunks[4]["choices"][0]["finish_reason"] == "tool_calls" + + +class TestStreamingUsageField: + """Tests for usage field in streaming responses.""" + + @pytest.mark.asyncio + async def test_usage_included_when_stream_options_set(self): + """Test that usage is included in final chunk when stream_options.include_usage=True.""" + output = ModelOutputThunk("Response") + output.usage = {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15} + + chunks = [] + async for chunk_data in stream_chat_completion_chunks( + output=output, + completion_id="chatcmpl-test", + model="test-model", + created=1234567890, + stream_options=StreamOptions(include_usage=True), + ): + if chunk_data.startswith("data: ") and chunk_data.strip() != "data: [DONE]": + json_str = chunk_data[6:].strip() + chunks.append(json.loads(json_str)) + + # Final chunk should include usage + final_chunk = chunks[-1] + assert "usage" in final_chunk + assert final_chunk["usage"]["prompt_tokens"] == 10 + assert final_chunk["usage"]["completion_tokens"] == 5 + assert final_chunk["usage"]["total_tokens"] == 15 + + @pytest.mark.asyncio + async def test_usage_excluded_when_stream_options_not_set(self): + """Test that usage is excluded when stream_options is None.""" + output = ModelOutputThunk("Response") + output.usage = {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15} + + chunks = [] + async for chunk_data in stream_chat_completion_chunks( + output=output, + completion_id="chatcmpl-test", + model="test-model", + created=1234567890, + stream_options=None, + ): + if chunk_data.startswith("data: ") and chunk_data.strip() != "data: [DONE]": + json_str = chunk_data[6:].strip() + chunks.append(json.loads(json_str)) + + # Final chunk should NOT include usage + final_chunk = chunks[-1] + assert "usage" not in final_chunk or final_chunk["usage"] is None + + @pytest.mark.asyncio + async def test_usage_excluded_when_include_usage_false(self): + """Test that usage is excluded when stream_options.include_usage=False.""" + output = ModelOutputThunk("Response") + output.usage = {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15} + + chunks = [] + async for chunk_data in stream_chat_completion_chunks( + output=output, + completion_id="chatcmpl-test", + model="test-model", + created=1234567890, + stream_options=StreamOptions(include_usage=False), + ): + if chunk_data.startswith("data: ") and chunk_data.strip() != "data: [DONE]": + json_str = chunk_data[6:].strip() + chunks.append(json.loads(json_str)) + + # Final chunk should NOT include usage + final_chunk = chunks[-1] + assert "usage" not in final_chunk or final_chunk["usage"] is None + + +class TestStreamingErrorHandling: + """Tests for error handling in streaming.""" + + @pytest.mark.asyncio + async def test_streaming_error_emits_error_response(self): + """Test that streaming errors emit OpenAI-compatible error responses.""" + # Create output that will raise an error during streaming + output = ModelOutputThunk("") + output._computed = False + + async def mock_astream_error(): + raise RuntimeError("Simulated streaming error") + + output.astream = mock_astream_error + + # Collect chunks + chunks = [] + async for chunk_data in stream_chat_completion_chunks( + output=output, + completion_id="chatcmpl-test", + model="test-model", + created=1234567890, + ): + chunks.append(chunk_data) + + # Should have: initial chunk, error response, [DONE] + assert len(chunks) >= 3 + + # Find error response (second-to-last before [DONE]) + error_chunk_data = chunks[-2] + assert error_chunk_data.startswith("data: ") + json_str = error_chunk_data[6:].strip() + error_response = json.loads(json_str) + + # Verify error structure + assert "error" in error_response + assert error_response["error"]["type"] == "server_error" + assert "Streaming error" in error_response["error"]["message"] + assert "Simulated streaming error" in error_response["error"]["message"] + + # Should still end with [DONE] + assert chunks[-1] == "data: [DONE]\n\n" + + +class TestStreamingChunkMetadata: + """Tests for chunk metadata fields.""" + + @pytest.mark.asyncio + async def test_all_chunks_have_required_fields(self): + """Test that all chunks have required OpenAI fields.""" + output = ModelOutputThunk("Test response") + + chunks = [] + async for chunk_data in stream_chat_completion_chunks( + output=output, + completion_id="chatcmpl-test123", + model="test-model-name", + created=1234567890, + system_fingerprint="test-fingerprint", + ): + if chunk_data.startswith("data: ") and chunk_data.strip() != "data: [DONE]": + json_str = chunk_data[6:].strip() + chunks.append(json.loads(json_str)) + + # Verify all chunks have required fields + for chunk in chunks: + assert chunk["id"] == "chatcmpl-test123" + assert chunk["model"] == "test-model-name" + assert chunk["created"] == 1234567890 + assert chunk["object"] == "chat.completion.chunk" + assert chunk["system_fingerprint"] == "test-fingerprint" + assert "choices" in chunk + assert len(chunk["choices"]) == 1 + assert chunk["choices"][0]["index"] == 0 + + @pytest.mark.asyncio + async def test_done_marker_emitted(self): + """Test that [DONE] marker is always emitted at the end.""" + output = ModelOutputThunk("Response") + + chunks = [] + async for chunk_data in stream_chat_completion_chunks( + output=output, + completion_id="chatcmpl-test", + model="test-model", + created=1234567890, + ): + chunks.append(chunk_data) + + # Last chunk should be [DONE] + assert chunks[-1] == "data: [DONE]\n\n" + + @pytest.mark.asyncio + async def test_sse_format_correct(self): + """Test that chunks follow SSE format: 'data: {json}\\n\\n'.""" + output = ModelOutputThunk("Response") + + async for chunk_data in stream_chat_completion_chunks( + output=output, + completion_id="chatcmpl-test", + model="test-model", + created=1234567890, + ): + # All chunks should start with "data: " + assert chunk_data.startswith("data: ") + # All chunks should end with double newline + assert chunk_data.endswith("\n\n") From e46afd32d3af50b0a8757f3cd71744cfca10d378 Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Wed, 29 Apr 2026 18:48:53 -0700 Subject: [PATCH 17/20] test: update output.usage -> output.generation.usage Rebased and now the new tests need updating. Signed-off-by: Mark Sturdevant --- test/cli/test_serve_streaming_tool_calls.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/test/cli/test_serve_streaming_tool_calls.py b/test/cli/test_serve_streaming_tool_calls.py index 6aa691ab5..40db343d5 100644 --- a/test/cli/test_serve_streaming_tool_calls.py +++ b/test/cli/test_serve_streaming_tool_calls.py @@ -5,7 +5,7 @@ """ import json -from unittest.mock import AsyncMock, Mock +from unittest.mock import Mock import pytest @@ -328,7 +328,11 @@ class TestStreamingUsageField: async def test_usage_included_when_stream_options_set(self): """Test that usage is included in final chunk when stream_options.include_usage=True.""" output = ModelOutputThunk("Response") - output.usage = {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15} + output.generation.usage = { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15, + } chunks = [] async for chunk_data in stream_chat_completion_chunks( @@ -353,7 +357,11 @@ async def test_usage_included_when_stream_options_set(self): async def test_usage_excluded_when_stream_options_not_set(self): """Test that usage is excluded when stream_options is None.""" output = ModelOutputThunk("Response") - output.usage = {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15} + output.generation.usage = { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15, + } chunks = [] async for chunk_data in stream_chat_completion_chunks( @@ -375,7 +383,11 @@ async def test_usage_excluded_when_stream_options_not_set(self): async def test_usage_excluded_when_include_usage_false(self): """Test that usage is excluded when stream_options.include_usage=False.""" output = ModelOutputThunk("Response") - output.usage = {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15} + output.generation.usage = { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15, + } chunks = [] async for chunk_data in stream_chat_completion_chunks( From 5b737abefc4ae67e70fb0b5ffdc233424876866b Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Wed, 29 Apr 2026 19:34:09 -0700 Subject: [PATCH 18/20] test: update tests usage -> gneration.usage More new tests need fixing after rebase. Signed-off-by: Mark Sturdevant --- test/cli/test_serve_integration.py | 6 +++--- test/cli/test_serve_tool_calling.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/cli/test_serve_integration.py b/test/cli/test_serve_integration.py index 79f1a2416..ff6792470 100644 --- a/test/cli/test_serve_integration.py +++ b/test/cli/test_serve_integration.py @@ -196,7 +196,7 @@ def test_streaming_content_chunks(self, client, mock_module): def test_streaming_with_usage_field(self, client, mock_module): """Test streaming response includes usage when stream_options.include_usage=True.""" mock_output = ModelOutputThunk("Response") - mock_output.usage = { + mock_output.generation.usage = { "prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15, @@ -346,7 +346,7 @@ def test_tool_calls_with_usage_info(self, client, mock_module): name="get_weather", func=mock_tool, args={"location": "Paris"} ) } - mock_output.usage = { + mock_output.generation.usage = { "prompt_tokens": 50, "completion_tokens": 20, "total_tokens": 70, @@ -481,7 +481,7 @@ def test_streaming_tool_calls_with_usage(self, client, mock_module): name="get_weather", func=mock_tool, args={"location": "Paris"} ) } - mock_output.usage = { + mock_output.generation.usage = { "prompt_tokens": 30, "completion_tokens": 15, "total_tokens": 45, diff --git a/test/cli/test_serve_tool_calling.py b/test/cli/test_serve_tool_calling.py index d1e76cd09..29c5bbf1b 100644 --- a/test/cli/test_serve_tool_calling.py +++ b/test/cli/test_serve_tool_calling.py @@ -275,7 +275,7 @@ async def test_tool_calls_with_usage_info(self, mock_module, sample_tool_request name="get_weather", func=mock_tool, args={"location": "Paris"} ) } - mock_output.usage = { + mock_output.generation.usage = { "prompt_tokens": 50, "completion_tokens": 20, "total_tokens": 70, From 3aea0c3ba18e26a847d234e2b71911f5c41d063b Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Tue, 5 May 2026 11:30:08 -0700 Subject: [PATCH 19/20] refactor(serve): simplify tool call construction with Pydantic validation - Use model_validate() instead of manual field mapping for tool calls - Move uuid import to module level in openai_compatible_helpers - Replace manual async function with AsyncMock in streaming error test - Remove redundant comments about tool call extraction These changes reduce code duplication and leverage Pydantic's built-in validation for cleaner, more maintainable code. Signed-off-by: Mark Sturdevant Assisted-by: IBM Bob --- cli/serve/app.py | 10 +--------- cli/serve/streaming.py | 11 +---------- mellea/helpers/openai_compatible_helpers.py | 4 +--- test/cli/test_serve_streaming_tool_calls.py | 10 +++++----- 4 files changed, 8 insertions(+), 27 deletions(-) diff --git a/cli/serve/app.py b/cli/serve/app.py index e92d648e2..e941037c8 100644 --- a/cli/serve/app.py +++ b/cli/serve/app.py @@ -194,18 +194,10 @@ async def endpoint(request: ChatCompletionRequest): media_type="text/event-stream", ) - # Extract tool calls from the ModelOutputThunk if available tool_calls_list = build_tool_calls(output) tool_calls = ( [ - ChatCompletionMessageToolCall( - id=tc["id"], - type=tc["type"], - function=ToolCallFunction( - name=tc["function"]["name"], - arguments=tc["function"]["arguments"], - ), - ) + ChatCompletionMessageToolCall.model_validate(tc) for tc in tool_calls_list ] if tool_calls_list diff --git a/cli/serve/streaming.py b/cli/serve/streaming.py index 6dea035a4..0ae7193b9 100644 --- a/cli/serve/streaming.py +++ b/cli/serve/streaming.py @@ -104,21 +104,12 @@ async def stream_chat_completion_chunks( ) yield f"data: {chunk.model_dump_json()}\n\n" - # Extract tool calls from the ModelOutputThunk if available tool_calls_list = build_tool_calls(output) if tool_calls_list: # Convert to ChatCompletionMessageToolCallDelta objects with required index tool_calls = [ - ChatCompletionMessageToolCallDelta( - index=idx, - id=tc["id"], - type=tc["type"], - function=ToolCallFunctionDelta( - name=tc["function"]["name"], - arguments=tc["function"]["arguments"], - ), - ) + ChatCompletionMessageToolCallDelta.model_validate({**tc, "index": idx}) for idx, tc in enumerate(tool_calls_list) ] diff --git a/mellea/helpers/openai_compatible_helpers.py b/mellea/helpers/openai_compatible_helpers.py index 71ecec888..bd1507910 100644 --- a/mellea/helpers/openai_compatible_helpers.py +++ b/mellea/helpers/openai_compatible_helpers.py @@ -1,6 +1,7 @@ """A file for helper functions that deal with OpenAI API compatible helpers.""" import json +import uuid from typing import Any, Literal, TypedDict from pydantic import BaseModel @@ -278,9 +279,6 @@ def build_tool_calls(output: ModelOutputThunk) -> list[ToolCallDict] | None: List of ``ToolCallDict`` objects when tool calls are present, otherwise ``None``. """ - import json - import uuid - # Check for tool calls - ModelOutputThunk always has tool_calls attribute if ( output.tool_calls is None diff --git a/test/cli/test_serve_streaming_tool_calls.py b/test/cli/test_serve_streaming_tool_calls.py index 40db343d5..0b5a3a63e 100644 --- a/test/cli/test_serve_streaming_tool_calls.py +++ b/test/cli/test_serve_streaming_tool_calls.py @@ -5,7 +5,7 @@ """ import json -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock import pytest @@ -416,10 +416,10 @@ async def test_streaming_error_emits_error_response(self): output = ModelOutputThunk("") output._computed = False - async def mock_astream_error(): - raise RuntimeError("Simulated streaming error") - - output.astream = mock_astream_error + # Use AsyncMock with side_effect to raise error + output.astream = AsyncMock( + side_effect=RuntimeError("Simulated streaming error") + ) # Collect chunks chunks = [] From 2b6792ea735f2a093195263e793a446583da9e72 Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Tue, 5 May 2026 14:49:19 -0700 Subject: [PATCH 20/20] fix: remove unused imports Signed-off-by: Mark Sturdevant --- cli/serve/app.py | 1 - cli/serve/streaming.py | 1 - 2 files changed, 2 deletions(-) diff --git a/cli/serve/app.py b/cli/serve/app.py index e941037c8..7ad354a95 100644 --- a/cli/serve/app.py +++ b/cli/serve/app.py @@ -35,7 +35,6 @@ Choice, OpenAIError, OpenAIErrorResponse, - ToolCallFunction, ) from .streaming import stream_chat_completion_chunks diff --git a/cli/serve/streaming.py b/cli/serve/streaming.py index 0ae7193b9..4d2f8f8ec 100644 --- a/cli/serve/streaming.py +++ b/cli/serve/streaming.py @@ -18,7 +18,6 @@ OpenAIError, OpenAIErrorResponse, StreamOptions, - ToolCallFunctionDelta, )