From c85008e939c74cde579c322f0d4afc8f17b5f20b Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Fri, 17 Apr 2026 12:57:43 -0700 Subject: [PATCH 01/13] feat: cli OpenAI-compatible API `response_format` support - Added `JsonSchemaFormat` model to represent JSON schema definitions - Extended `ResponseFormat` to support `json_schema` type (in addition to existing `text` and `json_object`) - Used field alias to avoid conflict with Pydantic's `schema` method - Added `_json_schema_to_pydantic()` utility function to dynamically convert JSON schemas to Pydantic models - Updated `_build_model_options()` to exclude `response_format` from model options (handled separately) - Modified `make_chat_endpoint()` to: - Parse `response_format` from requests - Convert `json_schema` type to Pydantic models using the utility function - Detect if the serve function accepts a `format` parameter using `inspect.signature()` - Pass the generated Pydantic model as `format=` parameter to serve functions that support it - Handle backward compatibility with serve functions that don't accept `format` - Added proper error handling for invalid schemas - Test json_schema format is converted to Pydantic model and passed to serve - Test json_object format doesn't pass a schema - Test text format doesn't pass a schema - Test error handling for missing json_schema field - Test error handling for invalid JSON schemas - Test backward compatibility with serve functions without format parameter - Test optional fields in JSON schemas When a client sends a request with `response_format.type = "json_schema"`, the server: 1. Extracts the JSON schema from `response_format.json_schema.schema` 2. Dynamically creates a Pydantic model from the schema 3. Passes it as the `format=` parameter to the serve function 4. The serve function can then use this for constrained decoding via Mellea's `instruct()` method This maps OpenAI's `response_format` API to Mellea's native `format=` parameter for structured output. Signed-off-by: Mark Sturdevant --- cli/serve/app.py | 127 +++++++++++++++++-- cli/serve/models.py | 20 ++- test/cli/test_serve.py | 276 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 410 insertions(+), 13 deletions(-) diff --git a/cli/serve/app.py b/cli/serve/app.py index 583b28c01..57a93524e 100644 --- a/cli/serve/app.py +++ b/cli/serve/app.py @@ -7,6 +7,7 @@ import sys import time import uuid +from typing import Any try: import typer @@ -14,6 +15,7 @@ from fastapi import FastAPI, Request from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, StreamingResponse + from pydantic import BaseModel, create_model except ImportError as e: raise ImportError( "The 'm serve' command requires extra dependencies. " @@ -90,6 +92,58 @@ def create_openai_error_response( ) +def _json_schema_to_pydantic( + schema: dict[str, Any], model_name: str = "DynamicModel" +) -> type[BaseModel]: + """Convert a JSON Schema to a Pydantic model dynamically. + + Args: + schema: JSON Schema definition (must have 'properties' and 'type': 'object'). + model_name: Name for the generated Pydantic model. + + Returns: + A dynamically created Pydantic model class. + + Raises: + ValueError: If the schema is invalid or unsupported. + """ + if not isinstance(schema, dict): + raise ValueError("Schema must be a dictionary") + + if schema.get("type") != "object": + raise ValueError("Only object-type schemas are supported") + + properties = schema.get("properties", {}) + required = schema.get("required", []) + + if not properties: + raise ValueError("Schema must have 'properties' field") + + # Map JSON Schema types to Python types + type_mapping = { + "string": str, + "integer": int, + "number": float, + "boolean": bool, + "array": list, + "object": dict, + } + + # Build field definitions for create_model + field_definitions: dict[str, Any] = {} + for field_name, field_schema in properties.items(): + field_type = field_schema.get("type", "string") + python_type = type_mapping.get(field_type, str) + + # Handle optional fields + if field_name in required: + field_definitions[field_name] = (python_type, ...) + else: + field_definitions[field_name] = (python_type | None, None) + + return create_model(model_name, **field_definitions) + + def _build_model_options(request: ChatCompletionRequest) -> dict: """Build model_options dict from OpenAI-compatible request parameters.""" excluded_fields = { @@ -108,7 +162,7 @@ def _build_model_options(request: ChatCompletionRequest) -> dict: "presence_penalty", # Presence penalty - not yet implemented "frequency_penalty", # Frequency penalty - not yet implemented "logit_bias", # Logit bias - not yet implemented - "response_format", # Response format (json_object) - not yet implemented + "response_format", # Response format - handled separately "functions", # Legacy function calling - not yet implemented "function_call", # Legacy function calling - not yet implemented "tools", # Tool calling - not yet implemented @@ -154,22 +208,71 @@ async def endpoint(request: ChatCompletionRequest): model_options = _build_model_options(request) + # Handle response_format + format_model: type[BaseModel] | None = None + if request.response_format is not None: + if request.response_format.type == "json_schema": + if request.response_format.json_schema is None: + return create_openai_error_response( + status_code=400, + message="json_schema field is required when response_format.type is 'json_schema'", + error_type="invalid_request_error", + param="response_format.json_schema", + ) + try: + format_model = _json_schema_to_pydantic( + request.response_format.json_schema.schema_, + request.response_format.json_schema.name, + ) + except ValueError as e: + return create_openai_error_response( + status_code=400, + message=f"Invalid JSON schema: {e!s}", + error_type="invalid_request_error", + param="response_format.json_schema.schema", + ) + elif request.response_format.type == "json_object": + # For json_object, we don't enforce a specific schema + # The backend will handle JSON mode if supported + pass + + # Check if serve function accepts format parameter + serve_sig = inspect.signature(module.serve) + accepts_format = "format" in serve_sig.parameters + # Detect if serve is async or sync and handle accordingly if inspect.iscoroutinefunction(module.serve): # It's async, await it directly - output = await module.serve( - input=request.messages, - requirements=request.requirements, - model_options=model_options, - ) + if accepts_format: + output = await module.serve( + input=request.messages, + requirements=request.requirements, + model_options=model_options, + format=format_model, + ) + else: + output = await module.serve( + input=request.messages, + requirements=request.requirements, + model_options=model_options, + ) else: # It's sync, run in thread pool to avoid blocking event loop - output = await asyncio.to_thread( - module.serve, - input=request.messages, - requirements=request.requirements, - model_options=model_options, - ) + if accepts_format: + output = await asyncio.to_thread( + module.serve, + input=request.messages, + requirements=request.requirements, + model_options=model_options, + format=format_model, + ) + else: + output = await asyncio.to_thread( + module.serve, + input=request.messages, + requirements=request.requirements, + model_options=model_options, + ) # system_fingerprint represents backend config hash, not model name # The model name is already in response.model (line 73) diff --git a/cli/serve/models.py b/cli/serve/models.py index 7e738730e..b68588417 100644 --- a/cli/serve/models.py +++ b/cli/serve/models.py @@ -29,8 +29,26 @@ class ToolFunction(BaseModel): function: FunctionDefinition +class JsonSchemaFormat(BaseModel): + """JSON Schema definition for structured output.""" + + name: str + """Name of the schema.""" + + schema_: dict[str, Any] = Field(alias="schema") + """JSON Schema definition.""" + + strict: bool | None = None + """Whether to enforce strict schema validation.""" + + model_config = {"populate_by_name": True} + + class ResponseFormat(BaseModel): - type: Literal["text", "json_object"] + type: Literal["text", "json_object", "json_schema"] + + json_schema: JsonSchemaFormat | None = None + """JSON Schema definition when type is 'json_schema'.""" class StreamOptions(BaseModel): diff --git a/test/cli/test_serve.py b/test/cli/test_serve.py index 515cc82f2..2b67db9e1 100644 --- a/test/cli/test_serve.py +++ b/test/cli/test_serve.py @@ -535,3 +535,279 @@ async def test_response_format_excluded_from_model_options(self, mock_module): # response_format should NOT be in model_options assert "response_format" not in model_options + + +class TestResponseFormat: + """Tests for response_format parameter handling.""" + + @pytest.mark.asyncio + async def test_json_schema_format_passed_to_serve(self): + """Test that json_schema response_format is converted to Pydantic model and passed to serve.""" + from pydantic import BaseModel + + from cli.serve.models import JsonSchemaFormat, ResponseFormat + + # Create a mock module with serve that accepts format parameter + mock_module = Mock() + mock_module.__name__ = "test_module" + + # Track calls manually + captured_format = None + + def mock_serve(input, requirements=None, model_options=None, format=None): + nonlocal captured_format + captured_format = format + return ModelOutputThunk('{"name": "Alice", "age": 30}') + + # Assign the real function so signature inspection works + mock_module.serve = mock_serve + + # Create a request with json_schema response_format + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Generate a person")], + response_format=ResponseFormat( + type="json_schema", + json_schema=JsonSchemaFormat( + name="Person", + schema={ + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + "required": ["name", "age"], + }, + ), + ), + ) + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + # Verify format was passed + assert captured_format is not None + assert issubclass(captured_format, BaseModel) + assert "name" in captured_format.model_fields + assert "age" in captured_format.model_fields + + # Verify response is successful + assert isinstance(response, ChatCompletion) + assert response.choices[0].message.content == '{"name": "Alice", "age": 30}' + + @pytest.mark.asyncio + async def test_json_object_format_no_schema(self, mock_module): + """Test that json_object response_format doesn't pass a format model.""" + from cli.serve.models import ResponseFormat + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Generate JSON")], + response_format=ResponseFormat(type="json_object"), + ) + + mock_output = ModelOutputThunk('{"result": "success"}') + mock_module.serve.return_value = mock_output + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + # Verify serve was called + call_args = mock_module.serve.call_args + assert call_args is not None + + # For json_object, format should be None (no specific schema) + if "format" in call_args.kwargs: + assert call_args.kwargs["format"] is None + + # Verify response is successful + assert isinstance(response, ChatCompletion) + + @pytest.mark.asyncio + async def test_text_format_no_schema(self, mock_module): + """Test that text response_format doesn't pass a format model.""" + from cli.serve.models import ResponseFormat + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Hello")], + response_format=ResponseFormat(type="text"), + ) + + mock_output = ModelOutputThunk("Hello there!") + mock_module.serve.return_value = mock_output + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + # Verify serve was called + call_args = mock_module.serve.call_args + assert call_args is not None + + # For text, format should be None + if "format" in call_args.kwargs: + assert call_args.kwargs["format"] is None + + # Verify response is successful + assert isinstance(response, ChatCompletion) + + @pytest.mark.asyncio + async def test_json_schema_missing_schema_field(self, mock_module): + """Test that json_schema without schema field returns error.""" + import json + + from fastapi.responses import JSONResponse + + from cli.serve.models import ResponseFormat + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Generate")], + response_format=ResponseFormat( + type="json_schema", + json_schema=None, # Missing schema + ), + ) + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + # Should return error + assert isinstance(response, JSONResponse) + assert response.status_code == 400 + + body_bytes = response.body + if isinstance(body_bytes, memoryview): + body_bytes = bytes(body_bytes) + error_data = json.loads(body_bytes.decode("utf-8")) + assert "error" in error_data + assert error_data["error"]["type"] == "invalid_request_error" + assert "json_schema" in error_data["error"]["message"].lower() + + @pytest.mark.asyncio + async def test_json_schema_invalid_schema(self, mock_module): + """Test that invalid JSON schema returns error.""" + import json + + from fastapi.responses import JSONResponse + + from cli.serve.models import JsonSchemaFormat, ResponseFormat + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Generate")], + response_format=ResponseFormat( + type="json_schema", + json_schema=JsonSchemaFormat( + name="Invalid", + schema={ + "type": "array", # Not supported (only object) + "items": {"type": "string"}, + }, + ), + ), + ) + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + # Should return error + assert isinstance(response, JSONResponse) + assert response.status_code == 400 + + body_bytes = response.body + if isinstance(body_bytes, memoryview): + body_bytes = bytes(body_bytes) + error_data = json.loads(body_bytes.decode("utf-8")) + assert "error" in error_data + assert error_data["error"]["type"] == "invalid_request_error" + assert "schema" in error_data["error"]["message"].lower() + + @pytest.mark.asyncio + async def test_serve_without_format_parameter(self, mock_module): + """Test that serve functions without format parameter still work.""" + from cli.serve.models import JsonSchemaFormat, ResponseFormat + + # Create a serve function that doesn't accept format + def serve_no_format(input, requirements=None, model_options=None): + return ModelOutputThunk("Response without format") + + mock_module.serve = serve_no_format + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Hello")], + response_format=ResponseFormat( + type="json_schema", + json_schema=JsonSchemaFormat( + name="Test", + schema={ + "type": "object", + "properties": {"result": {"type": "string"}}, + "required": ["result"], + }, + ), + ), + ) + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + # Should succeed even though serve doesn't accept format + assert isinstance(response, ChatCompletion) + assert response.choices[0].message.content == "Response without format" + + @pytest.mark.asyncio + async def test_json_schema_with_optional_fields(self): + """Test that JSON schema with optional fields is handled correctly.""" + from pydantic import BaseModel + + from cli.serve.models import JsonSchemaFormat, ResponseFormat + + # Create a mock module with serve that accepts format parameter + mock_module = Mock() + mock_module.__name__ = "test_module" + + # Track calls manually + captured_format = None + + def mock_serve(input, requirements=None, model_options=None, format=None): + nonlocal captured_format + captured_format = format + return ModelOutputThunk('{"name": "Widget", "price": 9.99}') + + # Assign the real function so signature inspection works + mock_module.serve = mock_serve + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Generate")], + response_format=ResponseFormat( + type="json_schema", + json_schema=JsonSchemaFormat( + name="Product", + schema={ + "type": "object", + "properties": { + "name": {"type": "string"}, + "price": {"type": "number"}, + "description": {"type": "string"}, + }, + "required": ["name", "price"], # description is optional + }, + ), + ), + ) + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + # Verify format model was created correctly + assert captured_format is not None + assert issubclass(captured_format, BaseModel) + assert "name" in captured_format.model_fields + assert "price" in captured_format.model_fields + assert "description" in captured_format.model_fields + + # Verify response is successful + assert isinstance(response, ChatCompletion) From 79f327fdf78b4ba28ca618efea0032e540cf0b54 Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Fri, 17 Apr 2026 15:59:01 -0700 Subject: [PATCH 02/13] feat: add response_format support in cli when streaming Signed-off-by: Mark Sturdevant --- cli/serve/app.py | 1 + cli/serve/streaming.py | 46 ++++++++ test/cli/test_serve.py | 256 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 303 insertions(+) diff --git a/cli/serve/app.py b/cli/serve/app.py index 57a93524e..859d37045 100644 --- a/cli/serve/app.py +++ b/cli/serve/app.py @@ -289,6 +289,7 @@ async def endpoint(request: ChatCompletionRequest): created=created_timestamp, stream_options=request.stream_options, system_fingerprint=system_fingerprint, + format_model=format_model, ), media_type="text/event-stream", ) diff --git a/cli/serve/streaming.py b/cli/serve/streaming.py index 51ff33c3c..15298a98a 100644 --- a/cli/serve/streaming.py +++ b/cli/serve/streaming.py @@ -1,7 +1,10 @@ """Streaming utilities for OpenAI-compatible server responses.""" +import json from collections.abc import AsyncGenerator +from pydantic import BaseModel, ValidationError + from mellea.core.base import ModelOutputThunk from mellea.core.utils import MelleaLogger from mellea.helpers.openai_compatible_helpers import build_completion_usage @@ -23,6 +26,7 @@ async def stream_chat_completion_chunks( created: int, stream_options: StreamOptions | None = None, system_fingerprint: str | None = None, + format_model: type[BaseModel] | None = None, ) -> AsyncGenerator[str, None]: """Generate OpenAI-compatible SSE chat completion chunks from a model output. @@ -36,6 +40,9 @@ async def stream_chat_completion_chunks( ``include_usage`` field. system_fingerprint: Backend configuration fingerprint to include in chunks. Defaults to ``None``. + format_model: Optional Pydantic model for validating structured output. + When provided, the complete streamed output will be validated against + this schema before the final chunk is sent. Yields: Server-sent event payload strings representing OpenAI-compatible chat @@ -98,6 +105,45 @@ async def stream_chat_completion_chunks( ) yield f"data: {chunk.model_dump_json()}\n\n" + # Validate format if format_model is provided + if format_model is not None: + if output.value is None: + error_response = OpenAIErrorResponse( + error=OpenAIError( + message="Output value is None, cannot validate format", + type="invalid_response_error", + ) + ) + yield f"data: {error_response.model_dump_json()}\n\n" + yield "data: [DONE]\n\n" + return + + try: + # Parse the complete output as JSON + output_json = json.loads(output.value) + # Validate against the Pydantic model + format_model.model_validate(output_json) + except json.JSONDecodeError as e: + error_response = OpenAIErrorResponse( + error=OpenAIError( + message=f"Output is not valid JSON: {e!s}", + type="invalid_response_error", + ) + ) + yield f"data: {error_response.model_dump_json()}\n\n" + yield "data: [DONE]\n\n" + return + except ValidationError as e: + error_response = OpenAIErrorResponse( + error=OpenAIError( + message=f"Output does not match required schema: {e!s}", + type="invalid_response_error", + ) + ) + yield f"data: {error_response.model_dump_json()}\n\n" + yield "data: [DONE]\n\n" + return + # Include usage in final chunk only if explicitly requested via stream_options # Per OpenAI spec: usage is only included when stream_options.include_usage=True include_usage = stream_options is not None and stream_options.include_usage diff --git a/test/cli/test_serve.py b/test/cli/test_serve.py index 2b67db9e1..afeda4ebe 100644 --- a/test/cli/test_serve.py +++ b/test/cli/test_serve.py @@ -811,3 +811,259 @@ def mock_serve(input, requirements=None, model_options=None, format=None): # Verify response is successful assert isinstance(response, ChatCompletion) + + +class TestResponseFormatStreaming: + """Tests for response_format parameter with streaming enabled.""" + + @pytest.mark.asyncio + async def test_json_schema_format_with_streaming(self): + """Test that json_schema response_format works with stream=True.""" + from cli.serve.models import JsonSchemaFormat, ResponseFormat + + # Create a mock module with serve that accepts format parameter + mock_module = Mock() + mock_module.__name__ = "test_module" + + # Create a mock output that supports streaming + mock_output = ModelOutputThunk('{"name": "Alice", "age": 30}') + mock_output._computed = True # Mark as pre-computed + + def mock_serve(input, requirements=None, model_options=None, format=None): + return mock_output + + mock_module.serve = mock_serve + + # Create a request with json_schema response_format and streaming + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Generate a person")], + stream=True, + response_format=ResponseFormat( + type="json_schema", + json_schema=JsonSchemaFormat( + name="Person", + schema={ + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + "required": ["name", "age"], + }, + ), + ), + ) + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + # Verify it's a streaming response + from fastapi.responses import StreamingResponse + + assert isinstance(response, StreamingResponse) + + # Consume the stream and verify chunks + chunks = [] + async for chunk in response.body_iterator: + chunks.append(chunk) + + # Should have multiple chunks including initial, content, final, and [DONE] + assert len(chunks) > 0 + + # Verify no error chunks (all should start with "data: ") + for chunk in chunks: + chunk_str = chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk + assert chunk_str.startswith("data: ") + + @pytest.mark.asyncio + async def test_json_schema_format_streaming_validation_error(self): + """Test that invalid JSON in streaming response returns error chunk.""" + from cli.serve.models import JsonSchemaFormat, ResponseFormat + + # Create a mock module + mock_module = Mock() + mock_module.__name__ = "test_module" + + # Create output with invalid JSON (missing required field) + mock_output = ModelOutputThunk('{"name": "Alice"}') # Missing 'age' + mock_output._computed = True + + def mock_serve(input, requirements=None, model_options=None, format=None): + return mock_output + + mock_module.serve = mock_serve + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Generate")], + stream=True, + response_format=ResponseFormat( + type="json_schema", + json_schema=JsonSchemaFormat( + name="Person", + schema={ + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + "required": ["name", "age"], + }, + ), + ), + ) + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + from fastapi.responses import StreamingResponse + + assert isinstance(response, StreamingResponse) + + # Consume the stream + chunks = [] + async for chunk in response.body_iterator: + chunks.append(chunk) + + # Should contain an error chunk + error_found = False + for chunk in chunks: + chunk_str = chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk + if "error" in chunk_str.lower() and "schema" in chunk_str.lower(): + error_found = True + break + + assert error_found, "Expected validation error in stream" + + @pytest.mark.asyncio + async def test_json_schema_format_streaming_invalid_json(self): + """Test that non-JSON output in streaming response returns error chunk.""" + from cli.serve.models import JsonSchemaFormat, ResponseFormat + + # Create a mock module + mock_module = Mock() + mock_module.__name__ = "test_module" + + # Create output with invalid JSON + mock_output = ModelOutputThunk("This is not JSON") + mock_output._computed = True + + def mock_serve(input, requirements=None, model_options=None, format=None): + return mock_output + + mock_module.serve = mock_serve + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Generate")], + stream=True, + response_format=ResponseFormat( + type="json_schema", + json_schema=JsonSchemaFormat( + name="Person", + schema={ + "type": "object", + "properties": {"name": {"type": "string"}}, + "required": ["name"], + }, + ), + ), + ) + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + from fastapi.responses import StreamingResponse + + assert isinstance(response, StreamingResponse) + + # Consume the stream + chunks = [] + async for chunk in response.body_iterator: + chunks.append(chunk) + + # Should contain an error chunk about invalid JSON + error_found = False + for chunk in chunks: + chunk_str = chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk + if "error" in chunk_str.lower() and "json" in chunk_str.lower(): + error_found = True + break + + assert error_found, "Expected JSON parsing error in stream" + + @pytest.mark.asyncio + async def test_json_object_format_with_streaming(self): + """Test that json_object response_format works with stream=True.""" + from cli.serve.models import ResponseFormat + + mock_module = Mock() + mock_module.__name__ = "test_module" + + # Valid JSON output + mock_output = ModelOutputThunk('{"result": "success"}') + mock_output._computed = True + mock_module.serve.return_value = mock_output + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Generate JSON")], + stream=True, + response_format=ResponseFormat(type="json_object"), + ) + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + from fastapi.responses import StreamingResponse + + assert isinstance(response, StreamingResponse) + + # Consume the stream + chunks = [] + async for chunk in response.body_iterator: + chunks.append(chunk) + + # Should complete successfully without errors + assert len(chunks) > 0 + # Verify no error chunks + for chunk in chunks: + chunk_str = chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk + assert "error" not in chunk_str.lower() or chunk_str.startswith( + "data: [DONE]" + ) + + @pytest.mark.asyncio + async def test_text_format_with_streaming(self): + """Test that text response_format works with stream=True.""" + from cli.serve.models import ResponseFormat + + mock_module = Mock() + mock_module.__name__ = "test_module" + + mock_output = ModelOutputThunk("Plain text response") + mock_output._computed = True + mock_module.serve.return_value = mock_output + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Hello")], + stream=True, + response_format=ResponseFormat(type="text"), + ) + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + from fastapi.responses import StreamingResponse + + assert isinstance(response, StreamingResponse) + + # Consume the stream + chunks = [] + async for chunk in response.body_iterator: + chunks.append(chunk) + + # Should complete successfully + assert len(chunks) > 0 From af7e90578fc310b7504060497eeab9eab6455a79 Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Fri, 17 Apr 2026 16:35:30 -0700 Subject: [PATCH 03/13] feat: cli response_format features adding doc examples Signed-off-by: Mark Sturdevant --- docs/examples/m_serve/README.md | 91 +++++++ .../m_serve/client_response_format.py | 254 ++++++++++++++++++ .../m_serve_example_response_format.py | 56 ++++ 3 files changed, 401 insertions(+) create mode 100644 docs/examples/m_serve/client_response_format.py create mode 100644 docs/examples/m_serve/m_serve_example_response_format.py diff --git a/docs/examples/m_serve/README.md b/docs/examples/m_serve/README.md index 70fcb5f5e..c65ba8819 100644 --- a/docs/examples/m_serve/README.md +++ b/docs/examples/m_serve/README.md @@ -19,6 +19,14 @@ A dedicated streaming example for `m serve` that supports both modes: - `stream=True` returns an uncomputed thunk so the server can emit incremental Server-Sent Events (SSE) chunks +### m_serve_example_response_format.py +Example demonstrating structured output with the `response_format` parameter. + +**Key Features:** +- Supporting the `format` parameter in serve functions +- Structured output validation with JSON schemas +- Three format types: `text`, `json_object`, `json_schema` + ### pii_serve.py Example of serving a PII (Personally Identifiable Information) detection service. @@ -29,6 +37,9 @@ Client code for testing the served API endpoints with non-streaming requests. Client code demonstrating streaming responses using Server-Sent Events (SSE) against `m_serve_example_streaming.py`. +### client_response_format.py +Client code demonstrating all three `response_format` types with examples. + ## Concepts Demonstrated - **API Deployment**: Exposing Mellea programs as REST APIs @@ -37,6 +48,7 @@ against `m_serve_example_streaming.py`. - **Validation in Production**: Using requirements in deployed services - **Model Options**: Passing model configuration through API - **Streaming Responses**: Real-time token streaming via Server-Sent Events (SSE) +- **Structured Output**: Using `response_format` for JSON schema validation ## Basic Pattern @@ -84,6 +96,85 @@ m serve docs/examples/m_serve/m_serve_example_streaming.py python docs/examples/m_serve/client_streaming.py ``` +### Response Format + +```bash +# Start the response_format example server +m serve docs/examples/m_serve/m_serve_example_response_format.py + +# In another terminal, test with the response_format client +python docs/examples/m_serve/client_response_format.py +``` + +## Response Format Support + +The server supports structured output via the `response_format` parameter, which allows you to control the format of the model's response. This is compatible with OpenAI's response format API. + +**Three Format Types:** + +1. **`text`** (default): Plain text output +2. **`json_object`**: Unstructured JSON output (model decides the schema) +3. **`json_schema`**: Structured output validated against a JSON schema + +**Key Features:** +- Automatic JSON schema to Pydantic model conversion +- Schema validation for structured outputs +- OpenAI-compatible API +- Works with the `format` parameter in serve functions + +**Example - JSON Schema:** +```python +import openai + +client = openai.OpenAI(api_key="na", base_url="http://0.0.0.0:8080/v1") + +# Define a schema for structured output +person_schema = { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "email": {"type": "string"}, + }, + "required": ["name", "age", "email"], +} + +response = client.chat.completions.create( + messages=[{"role": "user", "content": "Generate a person named Alice"}], + model="granite4:micro-h", + response_format={ + "type": "json_schema", + "json_schema": { + "name": "Person", + "schema": person_schema, + "strict": True, + }, + }, +) + +# Response will be valid JSON matching the schema +print(response.choices[0].message.content) +``` + +**Server Implementation:** +Your serve function must accept a `format` parameter to support `json_schema`: + +```python +def serve( + input: list[ChatMessage], + requirements: list[str] | None = None, + model_options: dict | None = None, + format: type | None = None, # Add this parameter +) -> ModelOutputThunk: + result = session.instruct( + description=input[-1].content, + requirements=requirements, + model_options=model_options, + format=format, # Pass to instruct() + ) + return result +``` + ## Streaming Support The server supports streaming responses via Server-Sent Events (SSE) when the diff --git a/docs/examples/m_serve/client_response_format.py b/docs/examples/m_serve/client_response_format.py new file mode 100644 index 000000000..a51f371b1 --- /dev/null +++ b/docs/examples/m_serve/client_response_format.py @@ -0,0 +1,254 @@ +# pytest: skip_always +"""Client demonstrating response_format parameter with m serve. + +This example shows how to use the three response_format types: +1. text - Plain text output (default) +2. json_object - Unstructured JSON output +3. json_schema - Structured output validated against a JSON schema + +Prerequisites: + Start the server first: + m serve docs/examples/m_serve/m_serve_example_response_format.py + + Then run this client: + python docs/examples/m_serve/client_response_format.py +""" + +import json + +import openai + +PORT = 8080 +BASE_URL = f"http://0.0.0.0:{PORT}/v1" + +# Create OpenAI client pointing to our m serve endpoint +client = openai.OpenAI(api_key="not-needed", base_url=BASE_URL) + + +def example_text_format(): + """Example 1: Plain text output (default behavior).""" + print("\n" + "=" * 60) + print("Example 1: Text Format (default)") + print("=" * 60) + + response = client.chat.completions.create( + model="granite4:micro-h", + messages=[{"role": "user", "content": "Write a haiku about programming."}], + response_format={"type": "text"}, + ) + + print(f"Response: {response.choices[0].message.content}") + + +def example_json_object(): + """Example 2: Unstructured JSON output. + + Note: json_object format requests JSON but doesn't enforce it strictly. + The model may wrap JSON in markdown or add explanatory text. + For strict JSON validation, use json_schema instead. + """ + print("\n" + "=" * 60) + print("Example 2: JSON Object Format") + print("=" * 60) + + response = client.chat.completions.create( + model="granite4:micro-h", + messages=[ + { + "role": "user", + "content": "Generate a JSON object with information about a fictional person. Include name, age, and city. Return ONLY the JSON, no markdown formatting.", + } + ], + response_format={"type": "json_object"}, + ) + + content = response.choices[0].message.content or "" + print(f"Response: {content}") + + # First, try to parse as-is (valid JSON) + try: + data = json.loads(content) + print("\n✓ Valid JSON received") + print(f"\nParsed JSON:\n{json.dumps(data, indent=2)}") + return + except json.JSONDecodeError: + # Not valid JSON, try to extract from markdown + print("\n⚠ Response is not valid JSON, attempting to extract from markdown...") + + # Fallback: Try to extract JSON from markdown code blocks + json_content = content + if "```json" in content: + # Extract JSON from markdown code block + start = content.find("```json") + 7 + end = content.find("```", start) + if end > start: + json_content = content[start:end].strip() + print("Extracted from ```json block") + elif "```" in content: + # Generic code block + start = content.find("```") + 3 + end = content.find("```", start) + if end > start: + json_content = content[start:end].strip() + print("Extracted from ``` block") + + # Try parsing the extracted content + try: + data = json.loads(json_content) + print( + f"\n✓ Successfully extracted and parsed JSON:\n{json.dumps(data, indent=2)}" + ) + except json.JSONDecodeError as e: + print("\n✗ Failed to parse JSON even after extraction") + print("Note: json_object format doesn't enforce strict JSON.") + print("For guaranteed JSON output, use json_schema format instead.") + print(f"Parse error: {e}") + + +def example_json_schema_person(): + """Example 3: Structured output with JSON schema validation.""" + print("\n" + "=" * 60) + print("Example 3: JSON Schema Format - Person") + print("=" * 60) + + # Define a JSON schema for a person + person_schema = { + "type": "object", + "properties": { + "name": {"type": "string", "description": "The person's full name"}, + "age": {"type": "integer", "description": "The person's age in years"}, + "email": {"type": "string", "description": "The person's email address"}, + "city": { + "type": "string", + "description": "The city where the person lives", + }, + }, + "required": ["name", "age", "email"], + "additionalProperties": False, + } + + response = client.chat.completions.create( + model="granite4:micro-h", + messages=[ + { + "role": "user", + "content": "Generate information about a software engineer named Alice.", + } + ], + response_format={ + "type": "json_schema", + "json_schema": {"name": "Person", "schema": person_schema, "strict": True}, + }, + ) + + content = response.choices[0].message.content + print(f"Response: {content}") + + # Parse and validate the structured output + try: + data = json.loads(content or "{}") + print(f"\nParsed structured output:\n{json.dumps(data, indent=2)}") + + # Verify required fields + assert "name" in data, "Missing required field: name" + assert "age" in data, "Missing required field: age" + assert "email" in data, "Missing required field: email" + print("\n✓ All required fields present") + + except json.JSONDecodeError as e: + print(f"Failed to parse JSON: {e}") + except AssertionError as e: + print(f"Validation error: {e}") + + +def example_json_schema_product(): + """Example 4: Structured output for a product catalog.""" + print("\n" + "=" * 60) + print("Example 4: JSON Schema Format - Product") + print("=" * 60) + + # Define a JSON schema for a product + product_schema = { + "type": "object", + "properties": { + "name": {"type": "string", "description": "Product name"}, + "price": {"type": "number", "description": "Price in USD"}, + "category": { + "type": "string", + "enum": ["electronics", "clothing", "food", "books"], + "description": "Product category", + }, + "in_stock": { + "type": "boolean", + "description": "Whether the product is in stock", + }, + "description": {"type": "string", "description": "Product description"}, + }, + "required": ["name", "price", "category", "in_stock"], + "additionalProperties": False, + } + + response = client.chat.completions.create( + model="granite4:micro-h", + messages=[ + { + "role": "user", + "content": "Generate a product listing for a laptop computer.", + } + ], + response_format={ + "type": "json_schema", + "json_schema": { + "name": "Product", + "schema": product_schema, + "strict": True, + }, + }, + ) + + content = response.choices[0].message.content + print(f"Response: {content}") + + # Parse and display the structured output + try: + data = json.loads(content or "{}") + print(f"\nParsed product data:\n{json.dumps(data, indent=2)}") + + # Verify the category is valid + valid_categories = ["electronics", "clothing", "food", "books"] + if data.get("category") in valid_categories: + print(f"\n✓ Valid category: {data['category']}") + + except json.JSONDecodeError as e: + print(f"Failed to parse JSON: {e}") + + +def main(): + """Run all examples.""" + print("\n" + "=" * 60) + print("RESPONSE_FORMAT EXAMPLES") + print("=" * 60) + print(f"Connecting to: {BASE_URL}") + print("=" * 60) + + try: + # Run all examples + example_text_format() + example_json_object() + example_json_schema_person() + example_json_schema_product() + + print("\n" + "=" * 60) + print("ALL EXAMPLES COMPLETED") + print("=" * 60) + + except Exception as e: + print(f"\nError: {e}") + print("\nMake sure the server is running:") + print( + f" m serve docs/examples/m_serve/m_serve_example_response_format.py --port {PORT}" + ) + + +if __name__ == "__main__": + main() diff --git a/docs/examples/m_serve/m_serve_example_response_format.py b/docs/examples/m_serve/m_serve_example_response_format.py new file mode 100644 index 000000000..4d2bc6b5c --- /dev/null +++ b/docs/examples/m_serve/m_serve_example_response_format.py @@ -0,0 +1,56 @@ +# pytest: ollama, e2e + +"""Example demonstrating response_format with m serve. + +This example shows how to use the response_format parameter to get structured +output from the model. The server supports three format types: +- text: Plain text output (default) +- json_object: Unstructured JSON output +- json_schema: Structured output validated against a JSON schema + +Run the server: + m serve docs/examples/m_serve/m_serve_example_response_format.py + +Test with the client: + python docs/examples/m_serve/client_response_format.py +""" + +from typing import Any + +import mellea +from cli.serve.models import ChatMessage +from mellea.core import ModelOutputThunk +from mellea.stdlib.context import ChatContext + +session = mellea.start_session(ctx=ChatContext()) + + +def serve( + input: list[ChatMessage], + requirements: list[str] | None = None, + model_options: dict[str, Any] | None = None, + format: type | None = None, +) -> ModelOutputThunk: + """Serve function that supports response_format parameter. + + Args: + input: List of chat messages from the client + requirements: Optional list of requirement strings + model_options: Optional model configuration parameters + format: Optional Pydantic model for structured output (from response_format) + + Returns: + ModelOutputThunk with the generated response + """ + message = input[-1].content or "No message provided" + + # When format is provided (from json_schema response_format), + # pass it to instruct() to get structured output + result = session.instruct( + description=message, + requirements=requirements, # type: ignore + model_options=model_options, + format=format, # This enables structured output validation + ) + + return result From bfcd50083240586bdd0db568dd7951723fab08f3 Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Wed, 22 Apr 2026 14:08:12 -0700 Subject: [PATCH 04/13] fix: cli use output stop reason and cleanup extra conditions * elif/pass not needed. Comment * set finish reason from result Signed-off-by: Mark Sturdevant --- cli/serve/app.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/cli/serve/app.py b/cli/serve/app.py index 859d37045..c85024c2e 100644 --- a/cli/serve/app.py +++ b/cli/serve/app.py @@ -231,10 +231,9 @@ async def endpoint(request: ChatCompletionRequest): error_type="invalid_request_error", param="response_format.json_schema.schema", ) - elif request.response_format.type == "json_object": - # For json_object, we don't enforce a specific schema - # The backend will handle JSON mode if supported - pass + # For "json_object" and "text", format_model remains None + # Note: "json_object" mode is not yet implemented - the backend + # receives no signal to produce JSON output (same as "text" mode) # Check if serve function accepts format parameter serve_sig = inspect.signature(module.serve) @@ -304,7 +303,7 @@ async def endpoint(request: ChatCompletionRequest): message=ChatCompletionMessage( content=output.value, role="assistant" ), - finish_reason="stop", + finish_reason=output.finish_reason, ) ], object="chat.completion", # type: ignore From e82bdc9dfbc631e1715c31e93a6ee97f535e3420 Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Wed, 29 Apr 2026 00:23:00 -0700 Subject: [PATCH 05/13] feat: cli serve supports json_schema * inspect server() to set accepts_format and is_async in make_chat_endpoint instead of on every request * improved support for json_schema that can be converted to pydantic * no output validation. The module is given a format when applicable, but instead of throwing an error if the formatting is incorrect, the results are sent to the client. * consolidated redundant blocks of kwargs by building a dict of kwargs needed. * improved setting of finish_reason * added tests * improved comments/docstrings where clarification needed Signed-off-by: Mark Sturdevant --- cli/serve/app.py | 108 ++------ cli/serve/models.py | 2 +- cli/serve/schema_converter.py | 407 ++++++++++++++++++++++++++++++ cli/serve/streaming.py | 54 +--- cli/serve/utils.py | 49 ++++ test/cli/test_schema_converter.py | 207 +++++++++++++++ test/cli/test_serve.py | 199 +++------------ test/cli/test_serve_utils.py | 185 ++++++++++++++ 8 files changed, 917 insertions(+), 294 deletions(-) create mode 100644 cli/serve/schema_converter.py create mode 100644 cli/serve/utils.py create mode 100644 test/cli/test_schema_converter.py create mode 100644 test/cli/test_serve_utils.py diff --git a/cli/serve/app.py b/cli/serve/app.py index c85024c2e..31b0f2da3 100644 --- a/cli/serve/app.py +++ b/cli/serve/app.py @@ -15,7 +15,7 @@ from fastapi import FastAPI, Request from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, StreamingResponse - from pydantic import BaseModel, create_model + from pydantic import BaseModel except ImportError as e: raise ImportError( "The 'm serve' command requires extra dependencies. " @@ -33,7 +33,9 @@ OpenAIError, OpenAIErrorResponse, ) +from .schema_converter import json_schema_to_pydantic from .streaming import stream_chat_completion_chunks +from .utils import extract_finish_reason app = FastAPI( title="M serve OpenAI API Compatible Server", @@ -92,58 +94,6 @@ def create_openai_error_response( ) -def _json_schema_to_pydantic( - schema: dict[str, Any], model_name: str = "DynamicModel" -) -> type[BaseModel]: - """Convert a JSON Schema to a Pydantic model dynamically. - - Args: - schema: JSON Schema definition (must have 'properties' and 'type': 'object'). - model_name: Name for the generated Pydantic model. - - Returns: - A dynamically created Pydantic model class. - - Raises: - ValueError: If the schema is invalid or unsupported. - """ - if not isinstance(schema, dict): - raise ValueError("Schema must be a dictionary") - - if schema.get("type") != "object": - raise ValueError("Only object-type schemas are supported") - - properties = schema.get("properties", {}) - required = schema.get("required", []) - - if not properties: - raise ValueError("Schema must have 'properties' field") - - # Map JSON Schema types to Python types - type_mapping = { - "string": str, - "integer": int, - "number": float, - "boolean": bool, - "array": list, - "object": dict, - } - - # Build field definitions for create_model - field_definitions: dict[str, Any] = {} - for field_name, field_schema in properties.items(): - field_type = field_schema.get("type", "string") - python_type = type_mapping.get(field_type, str) - - # Handle optional fields - if field_name in required: - field_definitions[field_name] = (python_type, ...) - else: - field_definitions[field_name] = (python_type | None, None) - - return create_model(model_name, **field_definitions) - - def _build_model_options(request: ChatCompletionRequest) -> dict: """Build model_options dict from OpenAI-compatible request parameters.""" excluded_fields = { @@ -191,6 +141,10 @@ def _build_model_options(request: ChatCompletionRequest) -> dict: def make_chat_endpoint(module): """Makes a chat endpoint using a custom module.""" + # Inspect serve function once at endpoint creation time + serve_sig = inspect.signature(module.serve) + accepts_format = "format" in serve_sig.parameters + is_async = inspect.iscoroutinefunction(module.serve) async def endpoint(request: ChatCompletionRequest): try: @@ -220,7 +174,7 @@ async def endpoint(request: ChatCompletionRequest): param="response_format.json_schema", ) try: - format_model = _json_schema_to_pydantic( + format_model = json_schema_to_pydantic( request.response_format.json_schema.schema_, request.response_format.json_schema.name, ) @@ -235,43 +189,22 @@ async def endpoint(request: ChatCompletionRequest): # Note: "json_object" mode is not yet implemented - the backend # receives no signal to produce JSON output (same as "text" mode) - # Check if serve function accepts format parameter - serve_sig = inspect.signature(module.serve) - accepts_format = "format" in serve_sig.parameters + # Build kwargs for serve call + serve_kwargs: dict[str, Any] = { + "input": request.messages, + "requirements": request.requirements, + "model_options": model_options, + } + if accepts_format: + serve_kwargs["format"] = format_model # Detect if serve is async or sync and handle accordingly - if inspect.iscoroutinefunction(module.serve): + if is_async: # It's async, await it directly - if accepts_format: - output = await module.serve( - input=request.messages, - requirements=request.requirements, - model_options=model_options, - format=format_model, - ) - else: - output = await module.serve( - input=request.messages, - requirements=request.requirements, - model_options=model_options, - ) + output = await module.serve(**serve_kwargs) else: # It's sync, run in thread pool to avoid blocking event loop - if accepts_format: - output = await asyncio.to_thread( - module.serve, - input=request.messages, - requirements=request.requirements, - model_options=model_options, - format=format_model, - ) - else: - output = await asyncio.to_thread( - module.serve, - input=request.messages, - requirements=request.requirements, - model_options=model_options, - ) + output = await asyncio.to_thread(module.serve, **serve_kwargs) # system_fingerprint represents backend config hash, not model name # The model name is already in response.model (line 73) @@ -288,7 +221,6 @@ async def endpoint(request: ChatCompletionRequest): created=created_timestamp, stream_options=request.stream_options, system_fingerprint=system_fingerprint, - format_model=format_model, ), media_type="text/event-stream", ) @@ -303,7 +235,7 @@ async def endpoint(request: ChatCompletionRequest): message=ChatCompletionMessage( content=output.value, role="assistant" ), - finish_reason=output.finish_reason, + finish_reason=extract_finish_reason(output), ) ], object="chat.completion", # type: ignore diff --git a/cli/serve/models.py b/cli/serve/models.py index b68588417..64ee1e11f 100644 --- a/cli/serve/models.py +++ b/cli/serve/models.py @@ -39,7 +39,7 @@ class JsonSchemaFormat(BaseModel): """JSON Schema definition.""" strict: bool | None = None - """Whether to enforce strict schema validation.""" + """Accepted for OpenAI compatibility; currently ignored by ``m serve``.""" model_config = {"populate_by_name": True} diff --git a/cli/serve/schema_converter.py b/cli/serve/schema_converter.py new file mode 100644 index 000000000..866e9341d --- /dev/null +++ b/cli/serve/schema_converter.py @@ -0,0 +1,407 @@ +"""Helpers for converting OpenAI-style JSON Schema response formats.""" + +from enum import Enum +from typing import Annotated, Any, Literal, cast + +from pydantic import BaseModel, ConfigDict, Strict, create_model + + +def json_schema_to_pydantic( + schema: dict[str, Any], model_name: str = "DynamicModel" +) -> type[BaseModel]: + """Convert a practical subset of JSON Schema to a Pydantic model dynamically. + + This converter targets the OpenAI-style structured output schemas used by + ``m serve``. It intentionally maps JSON Schema features into Python typing + and Pydantic model semantics rather than attempting to preserve every JSON + Schema validation rule exactly. + + Supported features: + - top-level and nested ``object`` schemas with ``properties`` and ``required`` + - primitive types: ``string``, ``integer``, ``number``, ``boolean`` + - arrays via ``type: "array"`` with supported ``items`` + - string or primitive enums via ``enum`` + - nullable fields via ``type: ["", "null"]`` + - local ``$ref`` into ``$defs`` / ``definitions`` + - simple ``allOf`` merging for object-like schemas + - simple ``anyOf`` / ``oneOf`` unions when each branch is representable + - boolean and schema-valued ``additionalProperties`` + + Behavior notes: + - ``additionalProperties: false`` maps to ``extra="forbid"`` + - ``additionalProperties: true`` maps to ``extra="ignore"`` + - schema-valued ``additionalProperties`` maps to ``dict[str, ValueType]`` + only for open-ended object maps. It cannot be combined with named + ``properties`` because that is not representable as a single standard + Pydantic field shape without custom validators. + - sibling keywords next to ``$ref`` are merged over the resolved target, + matching common JSON Schema practice for OpenAI-compatible schemas + + Still unsupported and will raise ``ValueError``: + - non-local refs + - tuple-style array schemas + - object schemas without ``properties`` unless they are pure + ``additionalProperties`` maps + - schema constraints beyond representable typing/extra handling + + Args: + schema: JSON Schema definition (must have top-level ``type: "object"``). + model_name: Name for the generated Pydantic model. + + Returns: + A dynamically created Pydantic model class. + + Raises: + ValueError: If the schema is invalid or unsupported. + """ + defs = schema.get("$defs") + if defs is None: + defs = schema.get("definitions", {}) + if defs is None: + defs = {} + if not isinstance(defs, dict): + raise ValueError("Schema '$defs' must be an object") + + ref_cache: dict[str, Any] = {} + model_cache: dict[str, type[BaseModel]] = {} + + def _sanitize_model_name(name: str) -> str: + sanitized = "".join(ch if ch.isalnum() else "_" for ch in name).strip("_") + return sanitized or "DynamicModel" + + def _format_path(path: str) -> str: + return path or "" + + def _resolve_ref(ref: str) -> dict[str, Any]: + if ref in ref_cache: + resolved = ref_cache[ref] + if not isinstance(resolved, dict): + raise ValueError(f"Resolved ref is invalid: {ref}") + return resolved + + prefixes = ("#/$defs/", "#/definitions/") + for prefix in prefixes: + if ref.startswith(prefix): + key = ref[len(prefix) :] + if key not in defs: + raise ValueError(f"Unresolved local ref: {ref}") + target = defs[key] + if not isinstance(target, dict): + raise ValueError(f"Ref target must be an object: {ref}") + ref_cache[ref] = target + return target + + raise ValueError( + f"Only local $ref values into $defs/definitions are supported: {ref}" + ) + + def _merge_nullable(annotation: Any, is_nullable: bool) -> Any: + """Wrap an annotation in ``None`` when the source schema is nullable.""" + if is_nullable: + return annotation | None + return annotation + + def _enum_annotation(enum_values: list[Any], path: str) -> Any: + """Convert JSON Schema enum values into a Python typing annotation.""" + if not enum_values: + raise ValueError(f"{_format_path(path)} enum must not be empty") + + value_types = {type(value) for value in enum_values} + if len(value_types) != 1: + raise ValueError( + f"{_format_path(path)} enum values must all have the same primitive type" + ) + + value_type = value_types.pop() + allowed_types = {str, int, float, bool} + if value_type not in allowed_types: + raise ValueError( + f"{_format_path(path)} enum values must be string, integer, number, or boolean" + ) + + if value_type is str: + enum_name = _sanitize_model_name( + path.replace(".", "_").replace("[", "_").replace("]", "") + ) + members = { + ( + value.upper() if value and value[0].isalpha() else f"VALUE_{index}" + ): value + for index, value in enumerate(enum_values) + } + return Enum(enum_name or "GeneratedEnum", members) + + return Literal[tuple(enum_values)] + + def _merge_object_schemas( + schemas: list[dict[str, Any]], path: str + ) -> dict[str, Any]: + """Merge simple object schemas for ``allOf``. + + This supports the common OpenAI-compatible case where ``allOf`` is used + to compose object fragments. Conflicting keywords are rejected rather + than silently guessed. + """ + merged: dict[str, Any] = {"type": "object", "properties": {}, "required": []} + merged_required: set[str] = set() + merged_additional_properties: bool | dict[str, Any] = True + + for index, branch in enumerate(schemas): + resolved_branch = _normalize_schema(branch, f"{path}.allOf[{index}]") + branch_type = resolved_branch.get("type", "object") + if branch_type != "object": + raise ValueError( + f"{_format_path(path)} allOf only supports object branches" + ) + + branch_properties = resolved_branch.get("properties", {}) + if not isinstance(branch_properties, dict): + raise ValueError( + f"{_format_path(path)} allOf branch properties must be an object" + ) + + for property_name, property_schema in branch_properties.items(): + if property_name in merged["properties"]: + raise ValueError( + f"{_format_path(path)} allOf has conflicting property " + f"definitions for '{property_name}'" + ) + cast(dict[str, Any], merged["properties"])[property_name] = ( + property_schema + ) + + branch_required = resolved_branch.get("required", []) + if not isinstance(branch_required, list): + raise ValueError( + f"{_format_path(path)} allOf branch 'required' must be an array" + ) + merged_required.update( + field_name + for field_name in branch_required + if isinstance(field_name, str) + ) + + branch_additional_properties = resolved_branch.get( + "additionalProperties", True + ) + if branch_additional_properties is False: + merged_additional_properties = False + elif isinstance(branch_additional_properties, dict): + if merged_additional_properties is True: + merged_additional_properties = branch_additional_properties + elif merged_additional_properties is False: + continue + elif merged_additional_properties != branch_additional_properties: + raise ValueError( + f"{_format_path(path)} allOf has conflicting " + "additionalProperties schemas" + ) + + merged["required"] = sorted(merged_required) + merged["additionalProperties"] = merged_additional_properties + return merged + + def _union_annotation( + keyword: str, union_schemas: list[dict[str, Any]], path: str + ) -> Any: + """Convert ``anyOf``/``oneOf`` branches into a Python union annotation.""" + if not union_schemas: + raise ValueError(f"{_format_path(path)} {keyword} must not be empty") + + annotations: list[Any] = [] + for index, branch in enumerate(union_schemas): + annotations.append( + _schema_to_type(branch, f"{path}.{keyword}[{index}]", in_union=True) + ) + + annotation = annotations[0] + for branch_annotation in annotations[1:]: + annotation = annotation | branch_annotation + return annotation + + def _normalize_schema(field_schema: dict[str, Any], path: str) -> dict[str, Any]: + """Resolve refs and simple combinators into a normalized schema object.""" + if not isinstance(field_schema, dict): + raise ValueError(f"{_format_path(path)} schema must be an object") + + normalized = dict(field_schema) + + if "$ref" in normalized: + ref = normalized["$ref"] + if not isinstance(ref, str): + raise ValueError(f"{_format_path(path)} $ref must be a string") + resolved = _resolve_ref(ref) + sibling_keywords = {k: v for k, v in normalized.items() if k != "$ref"} + if sibling_keywords: + merged = dict(resolved) + merged.update(sibling_keywords) + normalized = merged + else: + normalized = dict(resolved) + + if "allOf" in normalized: + all_of = normalized.pop("allOf") + if not isinstance(all_of, list): + raise ValueError(f"{_format_path(path)} allOf must be an array") + merged = _merge_object_schemas(all_of, path) + merged.update(normalized) + normalized = merged + + return normalized + + def _schema_to_type( + field_schema: dict[str, Any], path: str, in_union: bool = False + ) -> Any: + """Convert a JSON Schema node into a Python typing annotation.""" + normalized_schema = _normalize_schema(field_schema, path) + + for keyword in ("anyOf", "oneOf"): + if keyword in normalized_schema: + union_schemas = normalized_schema[keyword] + if not isinstance(union_schemas, list): + raise ValueError(f"{_format_path(path)} {keyword} must be an array") + sibling_keywords = { + key: value + for key, value in normalized_schema.items() + if key != keyword + } + branch_schemas: list[dict[str, Any]] = [] + for branch in union_schemas: + if not isinstance(branch, dict): + raise ValueError( + f"{_format_path(path)} {keyword} branches must be objects" + ) + merged_branch = dict(branch) + for sibling_key, sibling_value in sibling_keywords.items(): + merged_branch.setdefault(sibling_key, sibling_value) + branch_schemas.append(merged_branch) + return _union_annotation(keyword, branch_schemas, path) + + if "enum" in normalized_schema: + enum_values = normalized_schema["enum"] + if not isinstance(enum_values, list): + raise ValueError(f"{_format_path(path)} enum must be an array") + return _enum_annotation(enum_values, path) + + field_type = normalized_schema.get("type", "string") + is_nullable = False + if isinstance(field_type, list): + non_null_types = [item for item in field_type if item != "null"] + null_count = len(field_type) - len(non_null_types) + if null_count > 1 or len(non_null_types) != 1: + raise ValueError( + f"{_format_path(path)} uses unsupported multi-type schema: {field_type}" + ) + if null_count == 1: + is_nullable = True + field_type = non_null_types[0] + + primitive_type_mapping = { + "string": str, + "integer": int, + "number": float, + "boolean": bool, + } + + if field_type in primitive_type_mapping: + base_type = primitive_type_mapping[field_type] + if in_union: + annotated_type = Annotated[base_type, Strict()] # type: ignore[valid-type] + return _merge_nullable(annotated_type, is_nullable) + return _merge_nullable(base_type, is_nullable) + + if field_type == "object": + properties = normalized_schema.get("properties") + additional_properties = normalized_schema.get("additionalProperties", True) + + if properties is None and isinstance(additional_properties, dict): + value_annotation = _schema_to_type(additional_properties, f"{path}.*") + dict_type = dict[str, value_annotation] # type: ignore[valid-type] + return _merge_nullable(dict_type, is_nullable) + + nested_name = _sanitize_model_name(f"{model_name}_{path.replace('.', '_')}") + nested_model = _object_schema_to_model(normalized_schema, nested_name, path) + return _merge_nullable(nested_model, is_nullable) + + if field_type == "array": + items_schema = normalized_schema.get("items") + item_annotation: Any + if items_schema is None: + item_annotation = Any + elif isinstance(items_schema, list): + raise ValueError( + f"{_format_path(path)} uses unsupported tuple-style array schema" + ) + elif isinstance(items_schema, dict): + item_annotation = _schema_to_type(items_schema, f"{path}[]") + else: + raise ValueError(f"{_format_path(path)} items must be an object") + # Construct list type at runtime to avoid mypy subscript error. + list_type = list[item_annotation] # type: ignore[valid-type] + return _merge_nullable(list_type, is_nullable) + + raise ValueError( + f"{_format_path(path)} uses unsupported JSON schema type: {field_type}" + ) + + def _object_schema_to_model( + object_schema: dict[str, Any], current_model_name: str, path: str + ) -> type[BaseModel]: + normalized_schema = _normalize_schema(object_schema, path) + if normalized_schema.get("type") != "object": + raise ValueError(f"{_format_path(path)} must be an object schema") + + cache_key = f"{current_model_name}:{id(object_schema)}" + cached = model_cache.get(cache_key) + if cached is not None: + return cached + + properties = normalized_schema.get("properties", {}) + required = normalized_schema.get("required", []) + additional_properties = normalized_schema.get("additionalProperties", True) + + if not isinstance(required, list): + raise ValueError(f"{_format_path(path)} 'required' must be an array") + + if not isinstance(properties, dict): + raise ValueError(f"{_format_path(path)} 'properties' must be an object") + + if not properties: + if isinstance(additional_properties, dict): + raise ValueError( + f"{_format_path(path)} is a pure additionalProperties map and should " + "be used as a field type, not as a model root" + ) + raise ValueError( + f"{_format_path(path)} must have a non-empty 'properties' object" + ) + + field_definitions: dict[str, Any] = {} + for field_name, field_schema in properties.items(): + child_path = f"{path}.{field_name}" if path else field_name + annotation = _schema_to_type(field_schema, child_path) + if field_name in required: + field_definitions[field_name] = (annotation, ...) + else: + field_definitions[field_name] = (annotation | None, None) + + if additional_properties not in (True, False): + raise ValueError( + f"{_format_path(path)} only supports boolean additionalProperties " + "when combined with named properties" + ) + + model_config = ConfigDict( + extra="forbid" if additional_properties is False else "ignore", + use_enum_values=True, + ) + dynamic_model = create_model( + current_model_name, __config__=model_config, **field_definitions + ) + model_cache[cache_key] = dynamic_model + return dynamic_model + + if not isinstance(schema, dict): + raise ValueError("Schema must be a dictionary") + + return _object_schema_to_model(schema, _sanitize_model_name(model_name), "") diff --git a/cli/serve/streaming.py b/cli/serve/streaming.py index 15298a98a..6b75f9c4b 100644 --- a/cli/serve/streaming.py +++ b/cli/serve/streaming.py @@ -1,10 +1,7 @@ """Streaming utilities for OpenAI-compatible server responses.""" -import json from collections.abc import AsyncGenerator -from pydantic import BaseModel, ValidationError - from mellea.core.base import ModelOutputThunk from mellea.core.utils import MelleaLogger from mellea.helpers.openai_compatible_helpers import build_completion_usage @@ -17,6 +14,7 @@ OpenAIErrorResponse, StreamOptions, ) +from .utils import extract_finish_reason async def stream_chat_completion_chunks( @@ -26,10 +24,14 @@ async def stream_chat_completion_chunks( created: int, stream_options: StreamOptions | None = None, system_fingerprint: str | None = None, - format_model: type[BaseModel] | None = None, ) -> AsyncGenerator[str, None]: """Generate OpenAI-compatible SSE chat completion chunks from a model output. + This function acts as a pass-through streaming layer, forwarding chunks directly + from the backend to the client without buffering or validation. Format validation + for structured outputs happens at the module level (in the serve function) and + client side, not in this streaming layer. + Args: output: The model output object to stream. completion_id: Unique identifier for this completion. @@ -40,9 +42,6 @@ async def stream_chat_completion_chunks( ``include_usage`` field. system_fingerprint: Backend configuration fingerprint to include in chunks. Defaults to ``None``. - format_model: Optional Pydantic model for validating structured output. - When provided, the complete streamed output will be validated against - this schema before the final chunk is sent. Yields: Server-sent event payload strings representing OpenAI-compatible chat @@ -105,45 +104,6 @@ async def stream_chat_completion_chunks( ) yield f"data: {chunk.model_dump_json()}\n\n" - # Validate format if format_model is provided - if format_model is not None: - if output.value is None: - error_response = OpenAIErrorResponse( - error=OpenAIError( - message="Output value is None, cannot validate format", - type="invalid_response_error", - ) - ) - yield f"data: {error_response.model_dump_json()}\n\n" - yield "data: [DONE]\n\n" - return - - try: - # Parse the complete output as JSON - output_json = json.loads(output.value) - # Validate against the Pydantic model - format_model.model_validate(output_json) - except json.JSONDecodeError as e: - error_response = OpenAIErrorResponse( - error=OpenAIError( - message=f"Output is not valid JSON: {e!s}", - type="invalid_response_error", - ) - ) - yield f"data: {error_response.model_dump_json()}\n\n" - yield "data: [DONE]\n\n" - return - except ValidationError as e: - error_response = OpenAIErrorResponse( - error=OpenAIError( - message=f"Output does not match required schema: {e!s}", - type="invalid_response_error", - ) - ) - yield f"data: {error_response.model_dump_json()}\n\n" - yield "data: [DONE]\n\n" - return - # Include usage in final chunk only if explicitly requested via stream_options # Per OpenAI spec: usage is only included when stream_options.include_usage=True include_usage = stream_options is not None and stream_options.include_usage @@ -158,7 +118,7 @@ async def stream_chat_completion_chunks( ChatCompletionChunkChoice( index=0, delta=ChatCompletionChunkDelta(content=None), - finish_reason="stop", + finish_reason=extract_finish_reason(output), ) ], object="chat.completion.chunk", diff --git a/cli/serve/utils.py b/cli/serve/utils.py new file mode 100644 index 000000000..81fd39731 --- /dev/null +++ b/cli/serve/utils.py @@ -0,0 +1,49 @@ +from typing import Any, Literal + +FinishReason = Literal[ + "stop", "length", "content_filter", "tool_calls", "function_call" +] + + +def extract_finish_reason(output: Any) -> FinishReason: + """Extract finish_reason from ModelOutputThunk metadata. + + Args: + output: The model output thunk containing response metadata. + + Returns: + The finish_reason from the backend response, defaulting to "stop" if unavailable. + Possible values: "stop", "length", "content_filter", "tool_calls", "function_call". + """ + # Valid finish_reason values per OpenAI spec + valid_reasons: set[FinishReason] = { + "stop", + "length", + "content_filter", + "tool_calls", + "function_call", + } + + # Try to get finish_reason from the response metadata + # Different backends store this in different places + if hasattr(output, "_meta") and output._meta: + # Ollama backend stores response in chat_response with done_reason field + # (ollama.ChatResponse object with done_reason attribute) + chat_response = output._meta.get("chat_response") + if chat_response and hasattr(chat_response, "done_reason"): + done_reason = chat_response.done_reason + if done_reason in valid_reasons: + return done_reason + + # OpenAI backend stores full response dict in oai_chat_response + # (from chunk.model_dump() which includes choices array) + oai_response = output._meta.get("oai_chat_response") + if oai_response and isinstance(oai_response, dict): + choices = oai_response.get("choices", []) + if choices and len(choices) > 0: + finish_reason = choices[0].get("finish_reason") + if finish_reason in valid_reasons: + return finish_reason + + # Default to "stop" per OpenAI spec + return "stop" diff --git a/test/cli/test_schema_converter.py b/test/cli/test_schema_converter.py new file mode 100644 index 000000000..78004292a --- /dev/null +++ b/test/cli/test_schema_converter.py @@ -0,0 +1,207 @@ +"""Unit tests for JSON Schema to Pydantic conversion.""" + +import pytest + +from cli.serve.schema_converter import json_schema_to_pydantic + + +def test_json_schema_supports_enum_field(): + """Test that enum constraints are converted to a narrower Pydantic type.""" + model = json_schema_to_pydantic( + { + "type": "object", + "properties": {"status": {"type": "string", "enum": ["open", "closed"]}}, + "required": ["status"], + }, + "EnumExample", + ) + + parsed = model.model_validate({"status": "open"}) + assert parsed.model_dump()["status"] == "open" + + with pytest.raises(Exception): + model.model_validate({"status": "pending"}) + + +def test_json_schema_supports_nested_object_field(): + """Test that nested object schemas are converted recursively.""" + model = json_schema_to_pydantic( + { + "type": "object", + "properties": { + "user": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + "required": ["name"], + "additionalProperties": False, + } + }, + "required": ["user"], + }, + "NestedObjectExample", + ) + + parsed = model.model_validate({"user": {"name": "Alice", "age": 30}}) + parsed_user = parsed.model_dump()["user"] + assert parsed_user["name"] == "Alice" + assert parsed_user["age"] == 30 + + with pytest.raises(Exception): + model.model_validate({"user": {"name": "Alice", "extra": True}}) + + +def test_json_schema_supports_array_items_schema(): + """Test that arrays validate their item schemas.""" + model = json_schema_to_pydantic( + { + "type": "object", + "properties": {"tags": {"type": "array", "items": {"type": "string"}}}, + "required": ["tags"], + }, + "ArrayExample", + ) + + parsed = model.model_validate({"tags": ["a", "b"]}) + assert parsed.model_dump()["tags"] == ["a", "b"] + + with pytest.raises(Exception): + model.model_validate({"tags": ["a", 1]}) + + +def test_json_schema_supports_top_level_ref(): + """Test that local refs are resolved from $defs.""" + model = json_schema_to_pydantic( + { + "type": "object", + "$defs": { + "User": { + "type": "object", + "properties": {"name": {"type": "string"}}, + "required": ["name"], + } + }, + "properties": {"user": {"$ref": "#/$defs/User"}}, + "required": ["user"], + }, + "RefExample", + ) + + parsed = model.model_validate({"user": {"name": "Alice"}}) + assert parsed.model_dump()["user"]["name"] == "Alice" + + with pytest.raises(Exception): + model.model_validate({"user": {}}) + + +def test_json_schema_supports_anyof_field(): + """Test that representable anyOf branches are converted to unions.""" + model = json_schema_to_pydantic( + { + "type": "object", + "properties": { + "value": {"anyOf": [{"type": "string"}, {"type": "integer"}]} + }, + "required": ["value"], + }, + "AnyOfExample", + ) + + parsed_string = model.model_validate({"value": "hello"}) + assert parsed_string.model_dump()["value"] == "hello" + + parsed_integer = model.model_validate({"value": 7}) + assert parsed_integer.model_dump()["value"] == 7 + + with pytest.raises(Exception): + model.model_validate({"value": True}) + + +def test_json_schema_supports_allof_object_merge(): + """Test that allOf merges object fragments into one model.""" + model = json_schema_to_pydantic( + { + "type": "object", + "properties": { + "user": { + "allOf": [ + { + "type": "object", + "properties": {"name": {"type": "string"}}, + "required": ["name"], + }, + { + "type": "object", + "properties": {"age": {"type": "integer"}}, + "required": ["age"], + "additionalProperties": False, + }, + ] + } + }, + "required": ["user"], + }, + "AllOfExample", + ) + + parsed = model.model_validate({"user": {"name": "Alice", "age": 30}}) + parsed_user = parsed.model_dump()["user"] + assert parsed_user["name"] == "Alice" + assert parsed_user["age"] == 30 + + with pytest.raises(Exception): + model.model_validate({"user": {"name": "Alice"}}) + + with pytest.raises(Exception): + model.model_validate({"user": {"name": "Alice", "age": 30, "extra": True}}) + + +def test_json_schema_supports_additional_properties_schema_map(): + """Test schema-valued additionalProperties as a typed dict field.""" + model = json_schema_to_pydantic( + { + "type": "object", + "properties": { + "metadata": { + "type": "object", + "additionalProperties": {"type": "integer"}, + } + }, + "required": ["metadata"], + }, + "AdditionalPropertiesMapExample", + ) + + parsed = model.model_validate({"metadata": {"a": 1, "b": 2}}) + assert parsed.model_dump()["metadata"] == {"a": 1, "b": 2} + + with pytest.raises(Exception): + model.model_validate({"metadata": {"a": "bad"}}) + + +def test_json_schema_supports_nested_ref_in_array_items(): + """Test local refs nested under array items.""" + model = json_schema_to_pydantic( + { + "type": "object", + "$defs": { + "Tag": { + "type": "object", + "properties": {"label": {"type": "string"}}, + "required": ["label"], + "additionalProperties": False, + } + }, + "properties": {"tags": {"type": "array", "items": {"$ref": "#/$defs/Tag"}}}, + "required": ["tags"], + }, + "NestedRefArrayExample", + ) + + parsed = model.model_validate({"tags": [{"label": "alpha"}]}) + assert parsed.model_dump()["tags"][0]["label"] == "alpha" + + with pytest.raises(Exception): + model.model_validate({"tags": [{"label": "alpha", "extra": True}]}) diff --git a/test/cli/test_serve.py b/test/cli/test_serve.py index afeda4ebe..dbe0fe751 100644 --- a/test/cli/test_serve.py +++ b/test/cli/test_serve.py @@ -1,19 +1,28 @@ """Tests for the m serve OpenAI-compatible API server.""" +import json from unittest.mock import Mock import pytest from fastapi import FastAPI from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse, StreamingResponse from fastapi.testclient import TestClient +from pydantic import BaseModel, ValidationError -from cli.serve.app import make_chat_endpoint +from cli.serve.app import make_chat_endpoint, validation_exception_handler from cli.serve.models import ( ChatCompletion, ChatCompletionRequest, ChatMessage, CompletionUsage, + FunctionDefinition, + FunctionParameters, + JsonSchemaFormat, + ResponseFormat, + ToolFunction, ) +from mellea.backends.model_options import ModelOption from mellea.core.base import ModelOutputThunk @@ -125,8 +134,6 @@ async def test_system_fingerprint_always_none(self, mock_module, sample_request) @pytest.mark.asyncio async def test_model_options_passed_correctly(self, mock_module, sample_request): """Test that model options are passed to serve function correctly.""" - from mellea.backends.model_options import ModelOption - mock_output = ModelOutputThunk("Test response") mock_module.serve.return_value = mock_output @@ -233,9 +240,6 @@ async def test_all_fields_together(self, mock_module, sample_request): @pytest.mark.asyncio async def test_n_greater_than_1_rejected(self, mock_module): """Test that requests with n > 1 are rejected with appropriate error.""" - import json - - from fastapi.responses import JSONResponse request = ChatCompletionRequest( model="test-model", @@ -293,7 +297,6 @@ async def test_n_less_than_1_rejected_by_pydantic(self, mock_module): so n=0 or negative values will be caught by the framework, not our code. This test documents that behavior. """ - from pydantic import ValidationError # Pydantic validation happens before the endpoint is called with pytest.raises(ValidationError) as exc_info: @@ -320,8 +323,6 @@ def test_n_zero_rejected_at_http_level(self, mock_module): and converted to OpenAI-compatible 400 errors (not FastAPI's default 422). """ # Setup a test app with the exception handler - from cli.serve.app import validation_exception_handler - app = FastAPI() app.add_exception_handler(RequestValidationError, validation_exception_handler) app.add_api_route( @@ -439,8 +440,6 @@ async def test_unsupported_params_excluded_from_model_options(self, mock_module) model_options = call_args.kwargs["model_options"] # Supported parameters should be present - from mellea.backends.model_options import ModelOption - assert ModelOption.TEMPERATURE in model_options assert model_options[ModelOption.TEMPERATURE] == 0.7 assert ModelOption.MAX_NEW_TOKENS in model_options @@ -457,12 +456,6 @@ async def test_unsupported_params_excluded_from_model_options(self, mock_module) @pytest.mark.asyncio async def test_tool_params_excluded_from_model_options(self, mock_module): """Test that tool-related parameters are excluded from model_options.""" - from cli.serve.models import ( - FunctionDefinition, - FunctionParameters, - ToolFunction, - ) - request = ChatCompletionRequest( model="test-model", messages=[ChatMessage(role="user", content="Hello")], @@ -511,7 +504,6 @@ async def test_tool_params_excluded_from_model_options(self, mock_module): @pytest.mark.asyncio async def test_response_format_excluded_from_model_options(self, mock_module): """Test that response_format parameter is excluded from model_options.""" - from cli.serve.models import ResponseFormat request = ChatCompletionRequest( model="test-model", @@ -543,9 +535,6 @@ class TestResponseFormat: @pytest.mark.asyncio async def test_json_schema_format_passed_to_serve(self): """Test that json_schema response_format is converted to Pydantic model and passed to serve.""" - from pydantic import BaseModel - - from cli.serve.models import JsonSchemaFormat, ResponseFormat # Create a mock module with serve that accepts format parameter mock_module = Mock() @@ -598,7 +587,6 @@ def mock_serve(input, requirements=None, model_options=None, format=None): @pytest.mark.asyncio async def test_json_object_format_no_schema(self, mock_module): """Test that json_object response_format doesn't pass a format model.""" - from cli.serve.models import ResponseFormat request = ChatCompletionRequest( model="test-model", @@ -626,7 +614,6 @@ async def test_json_object_format_no_schema(self, mock_module): @pytest.mark.asyncio async def test_text_format_no_schema(self, mock_module): """Test that text response_format doesn't pass a format model.""" - from cli.serve.models import ResponseFormat request = ChatCompletionRequest( model="test-model", @@ -654,11 +641,6 @@ async def test_text_format_no_schema(self, mock_module): @pytest.mark.asyncio async def test_json_schema_missing_schema_field(self, mock_module): """Test that json_schema without schema field returns error.""" - import json - - from fastapi.responses import JSONResponse - - from cli.serve.models import ResponseFormat request = ChatCompletionRequest( model="test-model", @@ -687,11 +669,6 @@ async def test_json_schema_missing_schema_field(self, mock_module): @pytest.mark.asyncio async def test_json_schema_invalid_schema(self, mock_module): """Test that invalid JSON schema returns error.""" - import json - - from fastapi.responses import JSONResponse - - from cli.serve.models import JsonSchemaFormat, ResponseFormat request = ChatCompletionRequest( model="test-model", @@ -726,7 +703,6 @@ async def test_json_schema_invalid_schema(self, mock_module): @pytest.mark.asyncio async def test_serve_without_format_parameter(self, mock_module): """Test that serve functions without format parameter still work.""" - from cli.serve.models import JsonSchemaFormat, ResponseFormat # Create a serve function that doesn't accept format def serve_no_format(input, requirements=None, model_options=None): @@ -760,9 +736,6 @@ def serve_no_format(input, requirements=None, model_options=None): @pytest.mark.asyncio async def test_json_schema_with_optional_fields(self): """Test that JSON schema with optional fields is handled correctly.""" - from pydantic import BaseModel - - from cli.serve.models import JsonSchemaFormat, ResponseFormat # Create a mock module with serve that accepts format parameter mock_module = Mock() @@ -812,44 +785,23 @@ def mock_serve(input, requirements=None, model_options=None, format=None): # Verify response is successful assert isinstance(response, ChatCompletion) - -class TestResponseFormatStreaming: - """Tests for response_format parameter with streaming enabled.""" - @pytest.mark.asyncio - async def test_json_schema_format_with_streaming(self): - """Test that json_schema response_format works with stream=True.""" - from cli.serve.models import JsonSchemaFormat, ResponseFormat - - # Create a mock module with serve that accepts format parameter - mock_module = Mock() - mock_module.__name__ = "test_module" - - # Create a mock output that supports streaming - mock_output = ModelOutputThunk('{"name": "Alice", "age": 30}') - mock_output._computed = True # Mark as pre-computed - - def mock_serve(input, requirements=None, model_options=None, format=None): - return mock_output - - mock_module.serve = mock_serve + async def test_json_schema_rejects_non_local_ref(self, mock_module): + """Test that non-local refs still return a request error.""" - # Create a request with json_schema response_format and streaming request = ChatCompletionRequest( model="test-model", - messages=[ChatMessage(role="user", content="Generate a person")], - stream=True, + messages=[ChatMessage(role="user", content="Generate")], response_format=ResponseFormat( type="json_schema", json_schema=JsonSchemaFormat( - name="Person", + name="RemoteRefExample", schema={ "type": "object", "properties": { - "name": {"type": "string"}, - "age": {"type": "integer"}, + "value": {"$ref": "https://example.com/schemas/value.json"} }, - "required": ["name", "age"], + "required": ["value"], }, ), ), @@ -858,45 +810,42 @@ def mock_serve(input, requirements=None, model_options=None, format=None): endpoint = make_chat_endpoint(mock_module) response = await endpoint(request) - # Verify it's a streaming response - from fastapi.responses import StreamingResponse - - assert isinstance(response, StreamingResponse) + assert isinstance(response, JSONResponse) + assert response.status_code == 400 - # Consume the stream and verify chunks - chunks = [] - async for chunk in response.body_iterator: - chunks.append(chunk) + body_bytes = response.body + if isinstance(body_bytes, memoryview): + body_bytes = bytes(body_bytes) + error_data = json.loads(body_bytes.decode("utf-8")) + assert error_data["error"]["type"] == "invalid_request_error" + assert "local" in error_data["error"]["message"].lower() + assert "$ref" in error_data["error"]["message"].lower() - # Should have multiple chunks including initial, content, final, and [DONE] - assert len(chunks) > 0 - # Verify no error chunks (all should start with "data: ") - for chunk in chunks: - chunk_str = chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk - assert chunk_str.startswith("data: ") +class TestResponseFormatStreaming: + """Tests for response_format parameter with streaming enabled.""" @pytest.mark.asyncio - async def test_json_schema_format_streaming_validation_error(self): - """Test that invalid JSON in streaming response returns error chunk.""" - from cli.serve.models import JsonSchemaFormat, ResponseFormat + async def test_json_schema_format_with_streaming(self): + """Test that json_schema response_format works with stream=True.""" - # Create a mock module + # Create a mock module with serve that accepts format parameter mock_module = Mock() mock_module.__name__ = "test_module" - # Create output with invalid JSON (missing required field) - mock_output = ModelOutputThunk('{"name": "Alice"}') # Missing 'age' - mock_output._computed = True + # Create a mock output that supports streaming + mock_output = ModelOutputThunk('{"name": "Alice", "age": 30}') + mock_output._computed = True # Mark as pre-computed def mock_serve(input, requirements=None, model_options=None, format=None): return mock_output mock_module.serve = mock_serve + # Create a request with json_schema response_format and streaming request = ChatCompletionRequest( model="test-model", - messages=[ChatMessage(role="user", content="Generate")], + messages=[ChatMessage(role="user", content="Generate a person")], stream=True, response_format=ResponseFormat( type="json_schema", @@ -917,86 +866,25 @@ def mock_serve(input, requirements=None, model_options=None, format=None): endpoint = make_chat_endpoint(mock_module) response = await endpoint(request) - from fastapi.responses import StreamingResponse - + # Verify it's a streaming response assert isinstance(response, StreamingResponse) - # Consume the stream + # Consume the stream and verify chunks chunks = [] async for chunk in response.body_iterator: chunks.append(chunk) - # Should contain an error chunk - error_found = False - for chunk in chunks: - chunk_str = chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk - if "error" in chunk_str.lower() and "schema" in chunk_str.lower(): - error_found = True - break - - assert error_found, "Expected validation error in stream" - - @pytest.mark.asyncio - async def test_json_schema_format_streaming_invalid_json(self): - """Test that non-JSON output in streaming response returns error chunk.""" - from cli.serve.models import JsonSchemaFormat, ResponseFormat - - # Create a mock module - mock_module = Mock() - mock_module.__name__ = "test_module" - - # Create output with invalid JSON - mock_output = ModelOutputThunk("This is not JSON") - mock_output._computed = True - - def mock_serve(input, requirements=None, model_options=None, format=None): - return mock_output - - mock_module.serve = mock_serve - - request = ChatCompletionRequest( - model="test-model", - messages=[ChatMessage(role="user", content="Generate")], - stream=True, - response_format=ResponseFormat( - type="json_schema", - json_schema=JsonSchemaFormat( - name="Person", - schema={ - "type": "object", - "properties": {"name": {"type": "string"}}, - "required": ["name"], - }, - ), - ), - ) - - endpoint = make_chat_endpoint(mock_module) - response = await endpoint(request) - - from fastapi.responses import StreamingResponse - - assert isinstance(response, StreamingResponse) - - # Consume the stream - chunks = [] - async for chunk in response.body_iterator: - chunks.append(chunk) + # Should have multiple chunks including initial, content, final, and [DONE] + assert len(chunks) > 0 - # Should contain an error chunk about invalid JSON - error_found = False + # Verify no error chunks (all should start with "data: ") for chunk in chunks: chunk_str = chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk - if "error" in chunk_str.lower() and "json" in chunk_str.lower(): - error_found = True - break - - assert error_found, "Expected JSON parsing error in stream" + assert chunk_str.startswith("data: ") @pytest.mark.asyncio async def test_json_object_format_with_streaming(self): """Test that json_object response_format works with stream=True.""" - from cli.serve.models import ResponseFormat mock_module = Mock() mock_module.__name__ = "test_module" @@ -1016,8 +904,6 @@ async def test_json_object_format_with_streaming(self): endpoint = make_chat_endpoint(mock_module) response = await endpoint(request) - from fastapi.responses import StreamingResponse - assert isinstance(response, StreamingResponse) # Consume the stream @@ -1037,7 +923,6 @@ async def test_json_object_format_with_streaming(self): @pytest.mark.asyncio async def test_text_format_with_streaming(self): """Test that text response_format works with stream=True.""" - from cli.serve.models import ResponseFormat mock_module = Mock() mock_module.__name__ = "test_module" @@ -1056,8 +941,6 @@ async def test_text_format_with_streaming(self): endpoint = make_chat_endpoint(mock_module) response = await endpoint(request) - from fastapi.responses import StreamingResponse - assert isinstance(response, StreamingResponse) # Consume the stream diff --git a/test/cli/test_serve_utils.py b/test/cli/test_serve_utils.py new file mode 100644 index 000000000..e0c44f89f --- /dev/null +++ b/test/cli/test_serve_utils.py @@ -0,0 +1,185 @@ +"""Unit tests for cli/serve/utils.py — finish_reason extraction.""" + +from unittest.mock import Mock + +import pytest + +from cli.serve.utils import extract_finish_reason +from mellea.core.base import ModelOutputThunk + + +class TestExtractFinishReason: + """Tests for extract_finish_reason function.""" + + def test_default_finish_reason_when_no_meta(self): + """Test that 'stop' is returned when output has no _meta attribute.""" + output = ModelOutputThunk("test response") + # Don't set _meta attribute + assert extract_finish_reason(output) == "stop" + + def test_default_finish_reason_when_meta_is_none(self): + """Test that 'stop' is returned when _meta is None.""" + output = ModelOutputThunk("test response") + output._meta = None + assert extract_finish_reason(output) == "stop" + + def test_default_finish_reason_when_meta_is_empty(self): + """Test that 'stop' is returned when _meta is empty dict.""" + output = ModelOutputThunk("test response") + output._meta = {} + assert extract_finish_reason(output) == "stop" + + def test_ollama_done_reason_stop(self): + """Test extraction of 'stop' from Ollama chat_response.done_reason.""" + output = ModelOutputThunk("test response") + chat_response = Mock() + chat_response.done_reason = "stop" + output._meta = {"chat_response": chat_response} + assert extract_finish_reason(output) == "stop" + + def test_ollama_done_reason_length(self): + """Test extraction of 'length' from Ollama chat_response.done_reason.""" + output = ModelOutputThunk("test response") + chat_response = Mock() + chat_response.done_reason = "length" + output._meta = {"chat_response": chat_response} + assert extract_finish_reason(output) == "length" + + def test_ollama_done_reason_none(self): + """Test that default 'stop' is returned when done_reason is None.""" + output = ModelOutputThunk("test response") + chat_response = Mock() + chat_response.done_reason = None + output._meta = {"chat_response": chat_response} + assert extract_finish_reason(output) == "stop" + + def test_ollama_chat_response_without_done_reason(self): + """Test that default 'stop' is returned when chat_response lacks done_reason.""" + output = ModelOutputThunk("test response") + chat_response = Mock(spec=[]) # Mock without done_reason attribute + output._meta = {"chat_response": chat_response} + assert extract_finish_reason(output) == "stop" + + def test_openai_finish_reason_stop(self): + """Test extraction of 'stop' from OpenAI oai_chat_response.""" + output = ModelOutputThunk("test response") + output._meta = { + "oai_chat_response": {"choices": [{"finish_reason": "stop", "index": 0}]} + } + assert extract_finish_reason(output) == "stop" + + def test_openai_finish_reason_length(self): + """Test extraction of 'length' from OpenAI oai_chat_response.""" + output = ModelOutputThunk("test response") + output._meta = { + "oai_chat_response": {"choices": [{"finish_reason": "length", "index": 0}]} + } + assert extract_finish_reason(output) == "length" + + def test_openai_finish_reason_content_filter(self): + """Test extraction of 'content_filter' from OpenAI oai_chat_response.""" + output = ModelOutputThunk("test response") + output._meta = { + "oai_chat_response": { + "choices": [{"finish_reason": "content_filter", "index": 0}] + } + } + assert extract_finish_reason(output) == "content_filter" + + def test_openai_finish_reason_tool_calls(self): + """Test extraction of 'tool_calls' from OpenAI oai_chat_response.""" + output = ModelOutputThunk("test response") + output._meta = { + "oai_chat_response": { + "choices": [{"finish_reason": "tool_calls", "index": 0}] + } + } + assert extract_finish_reason(output) == "tool_calls" + + def test_openai_finish_reason_function_call(self): + """Test extraction of 'function_call' from OpenAI oai_chat_response.""" + output = ModelOutputThunk("test response") + output._meta = { + "oai_chat_response": { + "choices": [{"finish_reason": "function_call", "index": 0}] + } + } + assert extract_finish_reason(output) == "function_call" + + def test_openai_empty_choices_array(self): + """Test that default 'stop' is returned when choices array is empty.""" + output = ModelOutputThunk("test response") + output._meta = {"oai_chat_response": {"choices": []}} + assert extract_finish_reason(output) == "stop" + + def test_openai_missing_choices_key(self): + """Test that default 'stop' is returned when choices key is missing.""" + output = ModelOutputThunk("test response") + output._meta = {"oai_chat_response": {}} + assert extract_finish_reason(output) == "stop" + + def test_openai_finish_reason_none(self): + """Test that default 'stop' is returned when finish_reason is None.""" + output = ModelOutputThunk("test response") + output._meta = { + "oai_chat_response": {"choices": [{"finish_reason": None, "index": 0}]} + } + assert extract_finish_reason(output) == "stop" + + def test_openai_non_dict_response(self): + """Test that default 'stop' is returned when oai_chat_response is not a dict.""" + output = ModelOutputThunk("test response") + output._meta = {"oai_chat_response": "not a dict"} + assert extract_finish_reason(output) == "stop" + + def test_ollama_takes_precedence_over_openai(self): + """Test that Ollama done_reason is checked before OpenAI finish_reason.""" + output = ModelOutputThunk("test response") + chat_response = Mock() + chat_response.done_reason = "length" + output._meta = { + "chat_response": chat_response, + "oai_chat_response": {"choices": [{"finish_reason": "stop", "index": 0}]}, + } + # Should return Ollama's done_reason, not OpenAI's finish_reason + assert extract_finish_reason(output) == "length" + + def test_openai_used_when_ollama_missing(self): + """Test that OpenAI finish_reason is used when Ollama data is missing.""" + output = ModelOutputThunk("test response") + output._meta = { + "oai_chat_response": { + "choices": [{"finish_reason": "content_filter", "index": 0}] + } + } + assert extract_finish_reason(output) == "content_filter" + + def test_multiple_choices_uses_first(self): + """Test that first choice is used when multiple choices exist.""" + output = ModelOutputThunk("test response") + output._meta = { + "oai_chat_response": { + "choices": [ + {"finish_reason": "stop", "index": 0}, + {"finish_reason": "length", "index": 1}, + ] + } + } + assert extract_finish_reason(output) == "stop" + + def test_other_meta_keys_ignored(self): + """Test that unrelated _meta keys don't interfere.""" + output = ModelOutputThunk("test response") + output._meta = { + "model": "gpt-4", + "provider": "openai", + "usage": {"total_tokens": 100}, + "random_key": "random_value", + } + assert extract_finish_reason(output) == "stop" + + def test_output_without_meta_attribute(self): + """Test handling of output objects that don't have _meta attribute at all.""" + # Create a simple object without _meta + output = Mock(spec=[]) + assert extract_finish_reason(output) == "stop" From 4d5455bdb7048139f8c19c296d2f2d46bc903548 Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Tue, 5 May 2026 15:28:22 -0700 Subject: [PATCH 06/13] feat: add model_validator to ResponseFormat for json_schema validation Add Pydantic model_validator to ResponseFormat class to validate that json_schema field is provided when type is 'json_schema'. This moves validation from the endpoint handler to the model level, providing earlier error detection and cleaner separation of concerns. Changes: - Import model_validator from pydantic in cli/serve/models.py - Add validate_json_schema_required() validator to ResponseFormat - Update test to expect ValidationError at model instantiation The validator raises ValueError with message "json_schema field is required when type is 'json_schema'" when validation fails. Signed-off-by: Mark Sturdevant Assisted-by: IBM Bob --- cli/serve/models.py | 9 ++++++++- test/cli/test_serve.py | 30 +++++++++--------------------- 2 files changed, 17 insertions(+), 22 deletions(-) diff --git a/cli/serve/models.py b/cli/serve/models.py index 64ee1e11f..2797b4b7e 100644 --- a/cli/serve/models.py +++ b/cli/serve/models.py @@ -1,6 +1,6 @@ from typing import Any, Literal -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from mellea.helpers.openai_compatible_helpers import CompletionUsage @@ -50,6 +50,13 @@ class ResponseFormat(BaseModel): json_schema: JsonSchemaFormat | None = None """JSON Schema definition when type is 'json_schema'.""" + @model_validator(mode="after") + def validate_json_schema_required(self) -> "ResponseFormat": + """Validate that json_schema is provided when type is 'json_schema'.""" + if self.type == "json_schema" and self.json_schema is None: + raise ValueError("json_schema field is required when type is 'json_schema'") + return self + class StreamOptions(BaseModel): """OpenAI-compatible streaming options. diff --git a/test/cli/test_serve.py b/test/cli/test_serve.py index dbe0fe751..89ee3d940 100644 --- a/test/cli/test_serve.py +++ b/test/cli/test_serve.py @@ -640,31 +640,19 @@ async def test_text_format_no_schema(self, mock_module): @pytest.mark.asyncio async def test_json_schema_missing_schema_field(self, mock_module): - """Test that json_schema without schema field returns error.""" + """Test that json_schema without schema field raises ValidationError.""" + from pydantic import ValidationError - request = ChatCompletionRequest( - model="test-model", - messages=[ChatMessage(role="user", content="Generate")], - response_format=ResponseFormat( + # Should raise ValidationError when creating ResponseFormat + with pytest.raises(ValidationError) as exc_info: + ResponseFormat( type="json_schema", json_schema=None, # Missing schema - ), - ) - - endpoint = make_chat_endpoint(mock_module) - response = await endpoint(request) - - # Should return error - assert isinstance(response, JSONResponse) - assert response.status_code == 400 + ) - body_bytes = response.body - if isinstance(body_bytes, memoryview): - body_bytes = bytes(body_bytes) - error_data = json.loads(body_bytes.decode("utf-8")) - assert "error" in error_data - assert error_data["error"]["type"] == "invalid_request_error" - assert "json_schema" in error_data["error"]["message"].lower() + # Verify error message mentions json_schema requirement + error_str = str(exc_info.value) + assert "json_schema" in error_str.lower() @pytest.mark.asyncio async def test_json_schema_invalid_schema(self, mock_module): From 7a796e14f4d8bd0ff7142d3f0523d8f61fc4c5b0 Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Tue, 5 May 2026 15:53:34 -0700 Subject: [PATCH 07/13] fix: remove redundant validation check - Remove redundant validation check from cli/serve/app.py - Use cast() for type narrowing Signed-off-by: Mark Sturdevant Assisted-by: IBM Bob --- cli/serve/app.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/cli/serve/app.py b/cli/serve/app.py index 31b0f2da3..d5cb6e8ad 100644 --- a/cli/serve/app.py +++ b/cli/serve/app.py @@ -7,7 +7,7 @@ import sys import time import uuid -from typing import Any +from typing import Any, cast try: import typer @@ -30,6 +30,7 @@ ChatCompletionMessage, ChatCompletionRequest, Choice, + JsonSchemaFormat, OpenAIError, OpenAIErrorResponse, ) @@ -166,17 +167,13 @@ async def endpoint(request: ChatCompletionRequest): format_model: type[BaseModel] | None = None if request.response_format is not None: if request.response_format.type == "json_schema": - if request.response_format.json_schema is None: - return create_openai_error_response( - status_code=400, - message="json_schema field is required when response_format.type is 'json_schema'", - error_type="invalid_request_error", - param="response_format.json_schema", - ) + # json_schema presence is validated by ResponseFormat.model_validator + json_schema = cast( + JsonSchemaFormat, request.response_format.json_schema + ) try: format_model = json_schema_to_pydantic( - request.response_format.json_schema.schema_, - request.response_format.json_schema.name, + json_schema.schema_, json_schema.name ) except ValueError as e: return create_openai_error_response( From 4437d55fecd854a6bd9eccde3de3fc79fadb92f5 Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Tue, 5 May 2026 16:16:24 -0700 Subject: [PATCH 08/13] fix: catch TypeError in JSON schema validation for m serve The exception handler at line 178 in cli/serve/app.py now catches both ValueError and TypeError when validating JSON schemas. This handles edge cases where: - Enum values contain unhashable types (lists, dicts) - Pydantic's create_model() receives malformed field definitions Other exceptions (KeyError, AttributeError, RecursionError) are intentionally not caught as they indicate internal bugs in the schema converter, not invalid user input. Signed-off-by: Mark Sturdevant Assisted-by: IBM Bob --- cli/serve/app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cli/serve/app.py b/cli/serve/app.py index d5cb6e8ad..2104c8de0 100644 --- a/cli/serve/app.py +++ b/cli/serve/app.py @@ -175,7 +175,7 @@ async def endpoint(request: ChatCompletionRequest): format_model = json_schema_to_pydantic( json_schema.schema_, json_schema.name ) - except ValueError as e: + except (ValueError, TypeError) as e: return create_openai_error_response( status_code=400, message=f"Invalid JSON schema: {e!s}", From d87310aa224c3cb396a634df6942d9d7bf425112 Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Tue, 5 May 2026 16:27:06 -0700 Subject: [PATCH 09/13] test: fix conditional asserts that never fired in test_serve.py Replace conditional asserts with unconditional checks in test_json_object_format_no_schema and test_text_format_no_schema. The previous pattern 'if "format" in kwargs: assert kwargs["format"] is None' would pass silently if format was absent, failing to verify the intended behavior. Changed to 'assert "format" not in kwargs' to properly test that no format parameter is passed for json_object and text response types. Signed-off-by: Mark Sturdevant Assisted-by: IBM Bob --- test/cli/test_serve.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/test/cli/test_serve.py b/test/cli/test_serve.py index 89ee3d940..eaf363a4f 100644 --- a/test/cli/test_serve.py +++ b/test/cli/test_serve.py @@ -604,9 +604,8 @@ async def test_json_object_format_no_schema(self, mock_module): call_args = mock_module.serve.call_args assert call_args is not None - # For json_object, format should be None (no specific schema) - if "format" in call_args.kwargs: - assert call_args.kwargs["format"] is None + # For json_object, format should not be passed (no specific schema) + assert "format" not in call_args.kwargs # Verify response is successful assert isinstance(response, ChatCompletion) @@ -631,9 +630,8 @@ async def test_text_format_no_schema(self, mock_module): call_args = mock_module.serve.call_args assert call_args is not None - # For text, format should be None - if "format" in call_args.kwargs: - assert call_args.kwargs["format"] is None + # For text, format should not be passed + assert "format" not in call_args.kwargs # Verify response is successful assert isinstance(response, ChatCompletion) From c2e550ad26dc0fc1a95210297801ee7dbf2e77cf Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Tue, 5 May 2026 16:39:36 -0700 Subject: [PATCH 10/13] feat: extract finish_reason with LiteLLM Signed-off-by: Mark Sturdevant --- cli/serve/utils.py | 11 +++++ test/cli/test_serve_utils.py | 79 ++++++++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+) diff --git a/cli/serve/utils.py b/cli/serve/utils.py index 81fd39731..e93baac77 100644 --- a/cli/serve/utils.py +++ b/cli/serve/utils.py @@ -8,6 +8,10 @@ def extract_finish_reason(output: Any) -> FinishReason: """Extract finish_reason from ModelOutputThunk metadata. + Checks backend-specific metadata fields in order: Ollama, OpenAI, LiteLLM. + Backends without finish_reason metadata (e.g., HuggingFace) fall through to + the default "stop" value. + Args: output: The model output thunk containing response metadata. @@ -45,5 +49,12 @@ def extract_finish_reason(output: Any) -> FinishReason: if finish_reason in valid_reasons: return finish_reason + # LiteLLM backend stores response dict in litellm_chat_response + litellm_response = output._meta.get("litellm_chat_response") + if litellm_response and isinstance(litellm_response, dict): + finish_reason = litellm_response.get("finish_reason") + if finish_reason in valid_reasons: + return finish_reason + # Default to "stop" per OpenAI spec return "stop" diff --git a/test/cli/test_serve_utils.py b/test/cli/test_serve_utils.py index e0c44f89f..b3ea52a72 100644 --- a/test/cli/test_serve_utils.py +++ b/test/cli/test_serve_utils.py @@ -183,3 +183,82 @@ def test_output_without_meta_attribute(self): # Create a simple object without _meta output = Mock(spec=[]) assert extract_finish_reason(output) == "stop" + + def test_litellm_finish_reason_stop(self): + """Test extraction of 'stop' from LiteLLM litellm_chat_response.""" + output = ModelOutputThunk("test response") + output._meta = {"litellm_chat_response": {"finish_reason": "stop"}} + assert extract_finish_reason(output) == "stop" + + def test_litellm_finish_reason_length(self): + """Test extraction of 'length' from LiteLLM litellm_chat_response.""" + output = ModelOutputThunk("test response") + output._meta = {"litellm_chat_response": {"finish_reason": "length"}} + assert extract_finish_reason(output) == "length" + + def test_litellm_finish_reason_tool_calls(self): + """Test extraction of 'tool_calls' from LiteLLM litellm_chat_response.""" + output = ModelOutputThunk("test response") + output._meta = {"litellm_chat_response": {"finish_reason": "tool_calls"}} + assert extract_finish_reason(output) == "tool_calls" + + def test_litellm_finish_reason_content_filter(self): + """Test extraction of 'content_filter' from LiteLLM litellm_chat_response.""" + output = ModelOutputThunk("test response") + output._meta = {"litellm_chat_response": {"finish_reason": "content_filter"}} + assert extract_finish_reason(output) == "content_filter" + + def test_litellm_finish_reason_function_call(self): + """Test extraction of 'function_call' from LiteLLM litellm_chat_response.""" + output = ModelOutputThunk("test response") + output._meta = {"litellm_chat_response": {"finish_reason": "function_call"}} + assert extract_finish_reason(output) == "function_call" + + def test_litellm_finish_reason_none(self): + """Test that default 'stop' is returned when LiteLLM finish_reason is None.""" + output = ModelOutputThunk("test response") + output._meta = {"litellm_chat_response": {"finish_reason": None}} + assert extract_finish_reason(output) == "stop" + + def test_litellm_missing_finish_reason_key(self): + """Test that default 'stop' is returned when finish_reason key is missing.""" + output = ModelOutputThunk("test response") + output._meta = {"litellm_chat_response": {}} + assert extract_finish_reason(output) == "stop" + + def test_litellm_non_dict_response(self): + """Test that default 'stop' is returned when litellm_chat_response is not a dict.""" + output = ModelOutputThunk("test response") + output._meta = {"litellm_chat_response": "not a dict"} + assert extract_finish_reason(output) == "stop" + + def test_backend_precedence_ollama_openai_litellm(self): + """Test that backends are checked in order: Ollama, OpenAI, LiteLLM.""" + output = ModelOutputThunk("test response") + chat_response = Mock() + chat_response.done_reason = "length" + output._meta = { + "chat_response": chat_response, + "oai_chat_response": {"choices": [{"finish_reason": "stop", "index": 0}]}, + "litellm_chat_response": {"finish_reason": "content_filter"}, + } + # Should return Ollama's done_reason (checked first) + assert extract_finish_reason(output) == "length" + + def test_litellm_used_when_ollama_and_openai_missing(self): + """Test that LiteLLM finish_reason is used when Ollama and OpenAI data missing.""" + output = ModelOutputThunk("test response") + output._meta = {"litellm_chat_response": {"finish_reason": "tool_calls"}} + assert extract_finish_reason(output) == "tool_calls" + + def test_openai_takes_precedence_over_litellm(self): + """Test that OpenAI finish_reason is checked before LiteLLM.""" + output = ModelOutputThunk("test response") + output._meta = { + "oai_chat_response": { + "choices": [{"finish_reason": "content_filter", "index": 0}] + }, + "litellm_chat_response": {"finish_reason": "stop"}, + } + # Should return OpenAI's finish_reason (checked before LiteLLM) + assert extract_finish_reason(output) == "content_filter" From da8a302af1ea597cbb1d29734e97c7bcb87b198f Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Tue, 5 May 2026 17:06:13 -0700 Subject: [PATCH 11/13] fix: return 400 for recursive json schema refs Signed-off-by: Mark Sturdevant Assisted-by: IBM Bob --- cli/serve/app.py | 9 ++- cli/serve/schema_converter.py | 116 ++++++++++++++++++++-------------- test/cli/test_serve.py | 44 +++++++++++++ test/cli/test_serve_utils.py | 2 - 4 files changed, 119 insertions(+), 52 deletions(-) diff --git a/cli/serve/app.py b/cli/serve/app.py index 2104c8de0..365b07898 100644 --- a/cli/serve/app.py +++ b/cli/serve/app.py @@ -175,10 +175,15 @@ async def endpoint(request: ChatCompletionRequest): format_model = json_schema_to_pydantic( json_schema.schema_, json_schema.name ) - except (ValueError, TypeError) as e: + except (ValueError, TypeError, RecursionError) as e: + message = ( + "Invalid JSON schema: recursive $ref is not supported" + if isinstance(e, RecursionError) + else f"Invalid JSON schema: {e!s}" + ) return create_openai_error_response( status_code=400, - message=f"Invalid JSON schema: {e!s}", + message=message, error_type="invalid_request_error", param="response_format.json_schema.schema", ) diff --git a/cli/serve/schema_converter.py b/cli/serve/schema_converter.py index 866e9341d..d682a0752 100644 --- a/cli/serve/schema_converter.py +++ b/cli/serve/schema_converter.py @@ -39,6 +39,7 @@ def json_schema_to_pydantic( Still unsupported and will raise ``ValueError``: - non-local refs + - recursive ``$ref`` cycles - tuple-style array schemas - object schemas without ``properties`` unless they are pure ``additionalProperties`` maps @@ -64,6 +65,8 @@ def json_schema_to_pydantic( ref_cache: dict[str, Any] = {} model_cache: dict[str, type[BaseModel]] = {} + in_flight_refs: set[str] = set() + ref_name_by_schema_id: dict[int, str] = {} def _sanitize_model_name(name: str) -> str: sanitized = "".join(ch if ch.isalnum() else "_" for ch in name).strip("_") @@ -72,12 +75,18 @@ def _sanitize_model_name(name: str) -> str: def _format_path(path: str) -> str: return path or "" - def _resolve_ref(ref: str) -> dict[str, Any]: + def _resolve_ref(ref: str) -> tuple[str, dict[str, Any]]: if ref in ref_cache: resolved = ref_cache[ref] if not isinstance(resolved, dict): raise ValueError(f"Resolved ref is invalid: {ref}") - return resolved + + for prefix in ("#/$defs/", "#/definitions/"): + if ref.startswith(prefix): + return ref[len(prefix) :], resolved + raise ValueError( + f"Only local $ref values into $defs/definitions are supported: {ref}" + ) prefixes = ("#/$defs/", "#/definitions/") for prefix in prefixes: @@ -89,7 +98,8 @@ def _resolve_ref(ref: str) -> dict[str, Any]: if not isinstance(target, dict): raise ValueError(f"Ref target must be an object: {ref}") ref_cache[ref] = target - return target + ref_name_by_schema_id[id(target)] = key + return key, target raise ValueError( f"Only local $ref values into $defs/definitions are supported: {ref}" @@ -230,7 +240,7 @@ def _normalize_schema(field_schema: dict[str, Any], path: str) -> dict[str, Any] ref = normalized["$ref"] if not isinstance(ref, str): raise ValueError(f"{_format_path(path)} $ref must be a string") - resolved = _resolve_ref(ref) + _, resolved = _resolve_ref(ref) sibling_keywords = {k: v for k, v in normalized.items() if k != "$ref"} if sibling_keywords: merged = dict(resolved) @@ -347,59 +357,69 @@ def _schema_to_type( def _object_schema_to_model( object_schema: dict[str, Any], current_model_name: str, path: str ) -> type[BaseModel]: - normalized_schema = _normalize_schema(object_schema, path) - if normalized_schema.get("type") != "object": - raise ValueError(f"{_format_path(path)} must be an object schema") + current_ref_name = ref_name_by_schema_id.get(id(object_schema)) + if current_ref_name is not None: + if current_ref_name in in_flight_refs: + raise ValueError("recursive $ref is not supported") + in_flight_refs.add(current_ref_name) + + try: + normalized_schema = _normalize_schema(object_schema, path) + if normalized_schema.get("type") != "object": + raise ValueError(f"{_format_path(path)} must be an object schema") + + cache_key = f"{current_model_name}:{id(object_schema)}" + cached = model_cache.get(cache_key) + if cached is not None: + return cached + + properties = normalized_schema.get("properties", {}) + required = normalized_schema.get("required", []) + additional_properties = normalized_schema.get("additionalProperties", True) - cache_key = f"{current_model_name}:{id(object_schema)}" - cached = model_cache.get(cache_key) - if cached is not None: - return cached + if not isinstance(required, list): + raise ValueError(f"{_format_path(path)} 'required' must be an array") - properties = normalized_schema.get("properties", {}) - required = normalized_schema.get("required", []) - additional_properties = normalized_schema.get("additionalProperties", True) + if not isinstance(properties, dict): + raise ValueError(f"{_format_path(path)} 'properties' must be an object") - if not isinstance(required, list): - raise ValueError(f"{_format_path(path)} 'required' must be an array") + if not properties: + if isinstance(additional_properties, dict): + raise ValueError( + f"{_format_path(path)} is a pure additionalProperties map and should " + "be used as a field type, not as a model root" + ) + raise ValueError( + f"{_format_path(path)} must have a non-empty 'properties' object" + ) - if not isinstance(properties, dict): - raise ValueError(f"{_format_path(path)} 'properties' must be an object") + field_definitions: dict[str, Any] = {} + for field_name, field_schema in properties.items(): + child_path = f"{path}.{field_name}" if path else field_name + annotation = _schema_to_type(field_schema, child_path) + if field_name in required: + field_definitions[field_name] = (annotation, ...) + else: + field_definitions[field_name] = (annotation | None, None) - if not properties: - if isinstance(additional_properties, dict): + if additional_properties not in (True, False): raise ValueError( - f"{_format_path(path)} is a pure additionalProperties map and should " - "be used as a field type, not as a model root" + f"{_format_path(path)} only supports boolean additionalProperties " + "when combined with named properties" ) - raise ValueError( - f"{_format_path(path)} must have a non-empty 'properties' object" - ) - - field_definitions: dict[str, Any] = {} - for field_name, field_schema in properties.items(): - child_path = f"{path}.{field_name}" if path else field_name - annotation = _schema_to_type(field_schema, child_path) - if field_name in required: - field_definitions[field_name] = (annotation, ...) - else: - field_definitions[field_name] = (annotation | None, None) - if additional_properties not in (True, False): - raise ValueError( - f"{_format_path(path)} only supports boolean additionalProperties " - "when combined with named properties" + model_config = ConfigDict( + extra="forbid" if additional_properties is False else "ignore", + use_enum_values=True, ) - - model_config = ConfigDict( - extra="forbid" if additional_properties is False else "ignore", - use_enum_values=True, - ) - dynamic_model = create_model( - current_model_name, __config__=model_config, **field_definitions - ) - model_cache[cache_key] = dynamic_model - return dynamic_model + dynamic_model = create_model( + current_model_name, __config__=model_config, **field_definitions + ) + model_cache[cache_key] = dynamic_model + return dynamic_model + finally: + if current_ref_name is not None: + in_flight_refs.remove(current_ref_name) if not isinstance(schema, dict): raise ValueError("Schema must be a dictionary") diff --git a/test/cli/test_serve.py b/test/cli/test_serve.py index eaf363a4f..e9d5a6d6b 100644 --- a/test/cli/test_serve.py +++ b/test/cli/test_serve.py @@ -807,6 +807,50 @@ async def test_json_schema_rejects_non_local_ref(self, mock_module): assert "local" in error_data["error"]["message"].lower() assert "$ref" in error_data["error"]["message"].lower() + @pytest.mark.asyncio + async def test_json_schema_rejects_recursive_ref(self, mock_module): + """Test that recursive local refs return a request error instead of crashing.""" + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Generate")], + response_format=ResponseFormat( + type="json_schema", + json_schema=JsonSchemaFormat( + name="RecursiveNode", + schema={ + "type": "object", + "$defs": { + "Node": { + "type": "object", + "properties": { + "val": {"type": "integer"}, + "child": {"$ref": "#/$defs/Node"}, + }, + "required": ["val"], + } + }, + "properties": {"root": {"$ref": "#/$defs/Node"}}, + "required": ["root"], + }, + ), + ), + ) + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + assert isinstance(response, JSONResponse) + assert response.status_code == 400 + + body_bytes = response.body + if isinstance(body_bytes, memoryview): + body_bytes = bytes(body_bytes) + error_data = json.loads(body_bytes.decode("utf-8")) + assert error_data["error"]["type"] == "invalid_request_error" + assert "recursive" in error_data["error"]["message"].lower() + assert "$ref" in error_data["error"]["message"].lower() + class TestResponseFormatStreaming: """Tests for response_format parameter with streaming enabled.""" diff --git a/test/cli/test_serve_utils.py b/test/cli/test_serve_utils.py index b3ea52a72..f2e83d0ac 100644 --- a/test/cli/test_serve_utils.py +++ b/test/cli/test_serve_utils.py @@ -2,8 +2,6 @@ from unittest.mock import Mock -import pytest - from cli.serve.utils import extract_finish_reason from mellea.core.base import ModelOutputThunk From 2cfc9eeb3ff264d14e6c7bfda176c30e1adad039 Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Tue, 5 May 2026 17:18:24 -0700 Subject: [PATCH 12/13] fix: raise error on missing type in JSON Schema instead of defaulting to string Signed-off-by: Mark Sturdevant --- cli/serve/schema_converter.py | 8 ++- test/cli/test_schema_converter.py | 102 ++++++++++++++++++++++++++++++ 2 files changed, 109 insertions(+), 1 deletion(-) diff --git a/cli/serve/schema_converter.py b/cli/serve/schema_converter.py index d682a0752..e253b66ee 100644 --- a/cli/serve/schema_converter.py +++ b/cli/serve/schema_converter.py @@ -293,7 +293,13 @@ def _schema_to_type( raise ValueError(f"{_format_path(path)} enum must be an array") return _enum_annotation(enum_values, path) - field_type = normalized_schema.get("type", "string") + field_type = normalized_schema.get("type") + if field_type is None: + raise ValueError( + f"{_format_path(path)} schema must have a 'type' keyword. " + "JSON Schema without 'type' is valid but not supported by this converter. " + "Please add an explicit type (e.g., 'string', 'integer', 'object', 'array')." + ) is_nullable = False if isinstance(field_type, list): non_null_types = [item for item in field_type if item != "null"] diff --git a/test/cli/test_schema_converter.py b/test/cli/test_schema_converter.py index 78004292a..a303f6069 100644 --- a/test/cli/test_schema_converter.py +++ b/test/cli/test_schema_converter.py @@ -205,3 +205,105 @@ def test_json_schema_supports_nested_ref_in_array_items(): with pytest.raises(Exception): model.model_validate({"tags": [{"label": "alpha", "extra": True}]}) + + +def test_json_schema_rejects_missing_type_on_property(): + """Test that properties without explicit type raise ValueError.""" + with pytest.raises( + ValueError, + match=r"schema must have a 'type' keyword.*not supported by this converter", + ): + json_schema_to_pydantic( + { + "type": "object", + "properties": {"data": {"description": "anything"}}, + "required": ["data"], + } + ) + + +def test_json_schema_rejects_missing_type_on_nested_object(): + """Test that nested objects without type raise ValueError.""" + with pytest.raises( + ValueError, + match=r"schema must have a 'type' keyword.*not supported by this converter", + ): + json_schema_to_pydantic( + { + "type": "object", + "properties": { + "user": { + "properties": {"name": {"type": "string"}}, + "required": ["name"], + } + }, + "required": ["user"], + } + ) + + +def test_json_schema_rejects_missing_type_on_array_items(): + """Test that array items without type raise ValueError.""" + with pytest.raises( + ValueError, + match=r"schema must have a 'type' keyword.*not supported by this converter", + ): + json_schema_to_pydantic( + { + "type": "object", + "properties": { + "items": {"type": "array", "items": {"description": "any item"}} + }, + "required": ["items"], + } + ) + + +def test_json_schema_rejects_missing_type_in_anyof_branch(): + """Test that anyOf branches without type raise ValueError.""" + with pytest.raises( + ValueError, + match=r"schema must have a 'type' keyword.*not supported by this converter", + ): + json_schema_to_pydantic( + { + "type": "object", + "properties": { + "value": { + "anyOf": [{"type": "string"}, {"description": "anything"}] + } + }, + "required": ["value"], + } + ) + + +def test_json_schema_allows_missing_type_in_allof_branches(): + """Test that allOf branches without type default to object (intentional).""" + # allOf is specifically for merging object fragments, so missing type + # defaults to "object" rather than raising an error + model = json_schema_to_pydantic( + { + "type": "object", + "properties": { + "user": { + "allOf": [ + { + "properties": {"name": {"type": "string"}}, + "required": ["name"], + }, + { + "properties": {"age": {"type": "integer"}}, + "required": ["age"], + }, + ] + } + }, + "required": ["user"], + } + ) + + parsed = model.model_validate({"user": {"name": "Alice", "age": 30}}) + parsed_user = parsed.model_dump()["user"] + assert parsed_user["name"] == "Alice" + assert parsed_user["age"] == 30 From 19baa226713a16e4526a4db94eec68b2923bb0a4 Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Tue, 5 May 2026 17:34:48 -0700 Subject: [PATCH 13/13] fix(serve): use Literal for enums to prevent case collisions Replace Enum generation with Literal[tuple(values)] to preserve case-variant enum values like ["open", "OPEN"] that were previously collapsing to a single member. Signed-off-by: Mark Sturdevant Assisted-by: IBM Bob --- cli/serve/schema_converter.py | 13 --------- test/cli/test_schema_converter.py | 48 +++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 13 deletions(-) diff --git a/cli/serve/schema_converter.py b/cli/serve/schema_converter.py index e253b66ee..4aa5fd519 100644 --- a/cli/serve/schema_converter.py +++ b/cli/serve/schema_converter.py @@ -1,6 +1,5 @@ """Helpers for converting OpenAI-style JSON Schema response formats.""" -from enum import Enum from typing import Annotated, Any, Literal, cast from pydantic import BaseModel, ConfigDict, Strict, create_model @@ -129,18 +128,6 @@ def _enum_annotation(enum_values: list[Any], path: str) -> Any: f"{_format_path(path)} enum values must be string, integer, number, or boolean" ) - if value_type is str: - enum_name = _sanitize_model_name( - path.replace(".", "_").replace("[", "_").replace("]", "") - ) - members = { - ( - value.upper() if value and value[0].isalpha() else f"VALUE_{index}" - ): value - for index, value in enumerate(enum_values) - } - return Enum(enum_name or "GeneratedEnum", members) - return Literal[tuple(enum_values)] def _merge_object_schemas( diff --git a/test/cli/test_schema_converter.py b/test/cli/test_schema_converter.py index a303f6069..49edb663b 100644 --- a/test/cli/test_schema_converter.py +++ b/test/cli/test_schema_converter.py @@ -307,3 +307,51 @@ def test_json_schema_allows_missing_type_in_allof_branches(): parsed_user = parsed.model_dump()["user"] assert parsed_user["name"] == "Alice" assert parsed_user["age"] == 30 + + +def test_json_schema_supports_case_variant_enum_values(): + """Test that enum values differing only in case are preserved correctly. + + Regression test for issue where ["open", "OPEN"] would collapse to a + single enum member, causing validation to fail for one of the values. + """ + model = json_schema_to_pydantic( + { + "type": "object", + "properties": {"status": {"type": "string", "enum": ["open", "OPEN"]}}, + "required": ["status"], + }, + "CaseVariantEnumExample", + ) + + # Both case variants should validate successfully + parsed_lower = model.model_validate({"status": "open"}) + assert parsed_lower.model_dump()["status"] == "open" + + parsed_upper = model.model_validate({"status": "OPEN"}) + assert parsed_upper.model_dump()["status"] == "OPEN" + + # Invalid value should still fail + with pytest.raises(Exception): + model.model_validate({"status": "closed"}) + + +def test_json_schema_supports_time_period_enum(): + """Test AM/PM style enums that are common in migrated schemas.""" + model = json_schema_to_pydantic( + { + "type": "object", + "properties": {"period": {"type": "string", "enum": ["AM", "PM"]}}, + "required": ["period"], + }, + "TimePeriodExample", + ) + + parsed_am = model.model_validate({"period": "AM"}) + assert parsed_am.model_dump()["period"] == "AM" + + parsed_pm = model.model_validate({"period": "PM"}) + assert parsed_pm.model_dump()["period"] == "PM" + + with pytest.raises(Exception): + model.model_validate({"period": "am"})