65 changes: 51 additions & 14 deletions cli/serve/app.py
@@ -7,13 +7,15 @@
 import sys
 import time
 import uuid
+from typing import Any, cast

 try:
     import typer
     import uvicorn
     from fastapi import FastAPI, Request
     from fastapi.exceptions import RequestValidationError
     from fastapi.responses import JSONResponse, StreamingResponse
+    from pydantic import BaseModel
 except ImportError as e:
     raise ImportError(
         "The 'm serve' command requires extra dependencies. "
@@ -28,10 +30,13 @@
     ChatCompletionMessage,
     ChatCompletionRequest,
     Choice,
+    JsonSchemaFormat,
     OpenAIError,
     OpenAIErrorResponse,
 )
+from .schema_converter import json_schema_to_pydantic
 from .streaming import stream_chat_completion_chunks
+from .utils import extract_finish_reason

 app = FastAPI(
     title="M serve OpenAI API Compatible Server",
@@ -108,7 +113,7 @@ def _build_model_options(request: ChatCompletionRequest) -> dict:
         "presence_penalty", # Presence penalty - not yet implemented
         "frequency_penalty", # Frequency penalty - not yet implemented
         "logit_bias", # Logit bias - not yet implemented
-        "response_format", # Response format (json_object) - not yet implemented
+        "response_format", # Response format - handled separately
         "functions", # Legacy function calling - not yet implemented
         "function_call", # Legacy function calling - not yet implemented
         "tools", # Tool calling - not yet implemented
@@ -137,6 +142,10 @@ def _build_model_options(request: ChatCompletionRequest) -> dict:

 def make_chat_endpoint(module):
     """Makes a chat endpoint using a custom module."""
+    # Inspect serve function once at endpoint creation time
+    serve_sig = inspect.signature(module.serve)
+    accepts_format = "format" in serve_sig.parameters
+    is_async = inspect.iscoroutinefunction(module.serve)

     async def endpoint(request: ChatCompletionRequest):
         try:
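
To illustrate the two call shapes this introspection supports, here are two hypothetical user-module serve functions (the parameter names follow the calls in the endpoint below; the bodies are placeholders):

# Sync serve without a `format` parameter: accepts_format is False, so the
# endpoint omits the kwarg and runs the call via asyncio.to_thread.
def serve(input, requirements, model_options):
    ...

# Async serve that opts in to structured output: accepts_format is True, so
# the endpoint passes format=<generated Pydantic model> (or None).
async def serve(input, requirements, model_options, format=None):
    ...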
@@ -154,22 +163,50 @@ async def endpoint(request: ChatCompletionRequest):

             model_options = _build_model_options(request)

+            # Handle response_format
+            format_model: type[BaseModel] | None = None
+            if request.response_format is not None:
+                if request.response_format.type == "json_schema":
+                    # json_schema presence is validated by ResponseFormat.model_validator
+                    json_schema = cast(
+                        JsonSchemaFormat, request.response_format.json_schema
+                    )
+                    try:
+                        format_model = json_schema_to_pydantic(
+                            json_schema.schema_, json_schema.name
+                        )
+                    except (ValueError, TypeError, RecursionError) as e:
+                        message = (
+                            "Invalid JSON schema: recursive $ref is not supported"
+                            if isinstance(e, RecursionError)
+                            else f"Invalid JSON schema: {e!s}"
+                        )
+                        return create_openai_error_response(
+                            status_code=400,
+                            message=message,
+                            error_type="invalid_request_error",
+                            param="response_format.json_schema.schema",
+                        )
+                # For "json_object" and "text", format_model remains None
+                # Note: "json_object" mode is not yet implemented - the backend
+                # receives no signal to produce JSON output (same as "text" mode)
+
+            # Build kwargs for serve call
+            serve_kwargs: dict[str, Any] = {
+                "input": request.messages,
+                "requirements": request.requirements,
+                "model_options": model_options,
+            }
+            if accepts_format:
+                serve_kwargs["format"] = format_model
+
             # Detect if serve is async or sync and handle accordingly
-            if inspect.iscoroutinefunction(module.serve):
+            if is_async:
                 # It's async, await it directly
-                output = await module.serve(
-                    input=request.messages,
-                    requirements=request.requirements,
-                    model_options=model_options,
-                )
+                output = await module.serve(**serve_kwargs)
             else:
                 # It's sync, run in thread pool to avoid blocking event loop
-                output = await asyncio.to_thread(
-                    module.serve,
-                    input=request.messages,
-                    requirements=request.requirements,
-                    model_options=model_options,
-                )
+                output = await asyncio.to_thread(module.serve, **serve_kwargs)

             # system_fingerprint represents backend config hash, not model name
             # The model name is already in response.model (line 73)
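
As a sketch of what the RecursionError branch guards against: a self-referencing $ref (hypothetical schema) would presumably recurse without bound during conversion:

recursive_schema = {
    "type": "object",
    "properties": {
        "child": {"$ref": "#"},  # "#" points back at the whole schema
    },
}
# Converting this presumably exhausts the recursion limit; the endpoint maps
# the resulting RecursionError to a 400 invalid_request_error.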
@@ -200,7 +237,7 @@ async def endpoint(request: ChatCompletionRequest):
                         message=ChatCompletionMessage(
                             content=output.value, role="assistant"
                         ),
-                        finish_reason="stop",
+                        finish_reason=extract_finish_reason(output),
                     )
                 ],
                 object="chat.completion", # type: ignore
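For context, a minimal sketch of a client request that exercises the new path, assuming the server runs locally on port 8000 and exposes the usual OpenAI-compatible route (the model name and port are placeholders):

import requests

payload = {
    "model": "my-module",  # placeholder
    "messages": [{"role": "user", "content": "Describe a person as JSON."}],
    "response_format": {
        "type": "json_schema",
        "json_schema": {
            "name": "person",
            "schema": {
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "age": {"type": "integer"},
                },
                "required": ["name", "age"],
            },
        },
    },
}

resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])

The schema is converted to a Pydantic model via json_schema_to_pydantic and forwarded as format= only when the module's serve function declares that parameter.
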
29 changes: 27 additions & 2 deletions cli/serve/models.py
@@ -1,6 +1,6 @@
 from typing import Any, Literal

-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, model_validator

 from mellea.helpers.openai_compatible_helpers import CompletionUsage

@@ -29,8 +29,33 @@ class ToolFunction(BaseModel):
     function: FunctionDefinition


+class JsonSchemaFormat(BaseModel):
+    """JSON Schema definition for structured output."""
+
+    name: str
+    """Name of the schema."""
+
+    schema_: dict[str, Any] = Field(alias="schema")
+    """JSON Schema definition."""
+
+    strict: bool | None = None
+    """Accepted for OpenAI compatibility; currently ignored by ``m serve``."""
+
+    model_config = {"populate_by_name": True}
+
+
 class ResponseFormat(BaseModel):
-    type: Literal["text", "json_object"]
+    type: Literal["text", "json_object", "json_schema"]
+
+    json_schema: JsonSchemaFormat | None = None
+    """JSON Schema definition when type is 'json_schema'."""
+
+    @model_validator(mode="after")
+    def validate_json_schema_required(self) -> "ResponseFormat":
+        """Validate that json_schema is provided when type is 'json_schema'."""
+        if self.type == "json_schema" and self.json_schema is None:
+            raise ValueError("json_schema field is required when type is 'json_schema'")
+        return self


 class StreamOptions(BaseModel):
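
A quick sketch of the validation behavior these models encode, runnable against the definitions above:

from pydantic import ValidationError

# The "schema" alias populates schema_; populate_by_name also accepts the
# field name directly (JsonSchemaFormat(..., schema_={...})).
fmt = ResponseFormat(
    type="json_schema",
    json_schema=JsonSchemaFormat(name="person", schema={"type": "object"}),
)
assert fmt.json_schema is not None
assert fmt.json_schema.schema_ == {"type": "object"}

# type="json_schema" without a payload is rejected by the model_validator.
try:
    ResponseFormat(type="json_schema")
except ValidationError as e:
    print(e)

# "text" and "json_object" carry no schema.
ResponseFormat(type="text")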