diff --git a/README.md b/README.md index e3e05c6..9dd46c4 100644 --- a/README.md +++ b/README.md @@ -31,17 +31,13 @@ if not api_key: llm = LLM(model="openrouter:openrouter/free", api_key=api_key) result = llm.chat("Describe Republic in one sentence.", max_tokens=48) - -if result.error: - print(result.error.kind, result.error.message) -else: - print(result.value) +print(result) ``` ## Why It Feels Natural - **Plain Python**: The main flow is regular functions and branches, no extra DSL. -- **Structured Result**: Core interfaces return `StructuredOutput`, with stable `ErrorKind` values. +- **Structured error handling**: Errors are explicit and typed, so retry and fallback logic stays deterministic. - **Tools without magic**: Supports both automatic and manual tool execution with clear debugging and auditing. - **Tape-first memory**: Use anchor/handoff to bound context windows and replay full evidence. - **Event streaming**: Subscribe to text deltas, tool calls, tool results, usage, and final state. diff --git a/docs/guides/chat.md b/docs/guides/chat.md index 62fafe3..8717c19 100644 --- a/docs/guides/chat.md +++ b/docs/guides/chat.md @@ -9,7 +9,7 @@ from republic import LLM llm = LLM(model="openrouter:openrouter/free", api_key="") out = llm.chat("Output exactly one word: ready", max_tokens=8) -print(out.value, out.error) +print(out) ``` ## Messages Mode @@ -25,12 +25,18 @@ out = llm.chat(messages=messages, max_tokens=48) ## Structured Error Handling ```python -result = llm.chat("Write one sentence.") -if result.error: - if result.error.kind == "temporary": +from republic import ErrorPayload, LLM + +llm = LLM(model="openrouter:openrouter/free", api_key="") + +try: + out = llm.chat("Write one sentence.", max_tokens=32) + print(out) +except ErrorPayload as error: + if error.kind == "temporary": print("retry later") else: - print("fail fast:", result.error.message) + print("fail fast:", error.message) ``` ## Retries and Fallback diff --git a/docs/guides/tools.md b/docs/guides/tools.md index 54ee506..0a8dab9 100644 --- a/docs/guides/tools.md +++ b/docs/guides/tools.md @@ -23,7 +23,7 @@ from republic import LLM llm = LLM(model="openrouter:openai/gpt-4o-mini", api_key="") out = llm.run_tools("What is weather in Tokyo?", tools=[get_weather]) -print(out.kind) # text | tools | error +print(out.kind) # text | tools | error print(out.tool_results) print(out.error) ``` @@ -32,10 +32,7 @@ print(out.error) ```python calls = llm.tool_calls("Use get_weather for Berlin.", tools=[get_weather]) -if calls.error: - raise RuntimeError(calls.error.message) - -execution = llm.tools.execute(calls.value, tools=[get_weather]) +execution = llm.tools.execute(calls, tools=[get_weather]) print(execution.tool_results) print(execution.error) ``` diff --git a/docs/quickstart.md b/docs/quickstart.md index 78edb27..8abfc6b 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -20,11 +20,7 @@ llm = LLM(model="openrouter:openrouter/free", api_key="") ```python out = llm.chat("Write one short release note.", max_tokens=48) - -if out.error: - print("error:", out.error.kind, out.error.message) -else: - print("text:", out.value) +print("text:", out) ``` ## Step 3: Add an auditable trace to the session @@ -36,12 +32,14 @@ tape = llm.tape("release-notes") tape.handoff("draft_v1", state={"owner": "assistant"}) reply = tape.chat("Summarize the version changes in three bullets.", system_prompt="Keep it concise.") -print(reply.value) +print(reply) ``` ## Step 4: Handle failures and fallback ```python +from republic import ErrorPayload, LLM + llm = LLM( model="openai:gpt-4o-mini", fallback_models=["openrouter:openrouter/free"], @@ -49,8 +47,9 @@ llm = LLM( api_key={"openai": "", "openrouter": ""}, ) -result = llm.chat("say hello", max_tokens=8) -if result.error: - # error.kind is one of invalid_input/config/provider/tool/temporary/not_found/unknown - print(result.error.kind, result.error.message) +try: + result = llm.chat("say hello", max_tokens=8) + print(result) +except ErrorPayload as error: + print(error.kind, error.message) ``` diff --git a/src/republic/clients/chat.py b/src/republic/clients/chat.py index 58b0f19..58553c8 100644 --- a/src/republic/clients/chat.py +++ b/src/republic/clients/chat.py @@ -8,8 +8,11 @@ from functools import partial from typing import Any +from republic.clients.parsing import BaseTransportParser, TransportKind, parser_for_transport +from republic.clients.parsing.common import expand_tool_calls +from republic.clients.parsing.common import field as _field from republic.core.errors import ErrorKind, RepublicError -from republic.core.execution import LLMCore +from republic.core.execution import LLMCore, TransportResponse from republic.core.results import ( AsyncStreamEvents, AsyncTextStream, @@ -102,8 +105,8 @@ def _resolve_key_by_index(self, tool_call: Any, index: Any, position: int) -> ob return index_key position_key = self._key_at_position(position) - func = getattr(tool_call, "function", None) - tool_name = getattr(func, "name", None) if func is not None else None + func = _field(tool_call, "function") + tool_name = _field(func, "name") if func is not None else None if (tool_name is None or tool_name == "") and position_key is not None and position_key in self._calls: self._index_to_key[index] = position_key @@ -122,8 +125,8 @@ def _resolve_key_by_index(self, tool_call: Any, index: Any, position: int) -> ob return index_key def _resolve_key(self, tool_call: Any, position: int) -> object: - call_id = getattr(tool_call, "id", None) - index = getattr(tool_call, "index", None) + call_id = _field(tool_call, "id") + index = _field(tool_call, "index") if call_id is not None: return self._resolve_key_by_id(call_id, index, position) @@ -138,6 +141,27 @@ def _resolve_key(self, tool_call: Any, position: int) -> object: return position_key return ("position", position) + @staticmethod + def _merge_arguments( + entry: dict[str, Any], + *, + arguments: Any, + arguments_complete: bool, + ) -> None: + if arguments is None: + return + if not isinstance(arguments, str): + entry["function"]["arguments"] = arguments + return + if arguments_complete: + entry["function"]["arguments"] = arguments + return + + existing = entry["function"].get("arguments", "") + if not isinstance(existing, str): + existing = "" + entry["function"]["arguments"] = existing + arguments + def add_deltas(self, tool_calls: list[Any]) -> None: for position, tool_call in enumerate(tool_calls): key = self._resolve_key(tool_call, position) @@ -145,24 +169,27 @@ def add_deltas(self, tool_calls: list[Any]) -> None: self._order.append(key) self._calls[key] = {"function": {"name": "", "arguments": ""}} entry = self._calls[key] - call_id = getattr(tool_call, "id", None) + call_id = _field(tool_call, "id") if call_id: entry["id"] = call_id - call_type = getattr(tool_call, "type", None) + call_type = _field(tool_call, "type") if call_type: entry["type"] = call_type - func = getattr(tool_call, "function", None) + func = _field(tool_call, "function") if func is None: continue - name = getattr(func, "name", None) + name = _field(func, "name") if name: entry["function"]["name"] = name - arguments = getattr(func, "arguments", None) - if arguments: - entry["function"]["arguments"] = entry["function"].get("arguments", "") + arguments + arguments = _field(func, "arguments") + self._merge_arguments( + entry, + arguments=arguments, + arguments_complete=bool(_field(tool_call, "arguments_complete", False)), + ) def finalize(self) -> list[dict[str, Any]]: - return [self._calls[key] for key in self._order] + return expand_tool_calls([self._calls[key] for key in self._order]) class ChatClient: @@ -184,6 +211,58 @@ def __init__( def default_context(self) -> TapeContext: return self._tape.default_context + @staticmethod + def _unwrap_response(response: Any) -> tuple[Any, TransportKind | None]: + if isinstance(response, TransportResponse): + return response.payload, response.transport + return response, None + + @staticmethod + def _resolve_transport( + payload: Any, + transport: TransportKind | None = None, + ) -> TransportKind: + if transport is not None: + return transport + if isinstance(payload, list): + return "responses" + if _field(payload, "output") is not None: + return "responses" + if _field(payload, "output_text") is not None: + return "responses" + event_type = _field(payload, "type") + if isinstance(event_type, str) and event_type.startswith("response."): + return "responses" + return "completion" + + @staticmethod + def _parser_for_payload( + payload: Any, + *, + transport: TransportKind | None = None, + ) -> BaseTransportParser: + effective_transport = ChatClient._resolve_transport(payload, transport) + return parser_for_transport(effective_transport) + + @staticmethod + def _unwrap_response_with_parser( + response: Any, + *, + transport: TransportKind | None = None, + ) -> tuple[Any, BaseTransportParser]: + payload, detected_transport = ChatClient._unwrap_response(response) + parser = ChatClient._parser_for_payload(payload, transport=transport or detected_transport) + return payload, parser + + @staticmethod + def _is_non_stream_response( + response: Any, + *, + transport: TransportKind | None = None, + ) -> bool: + parser = ChatClient._parser_for_payload(response, transport=transport) + return parser.is_non_stream_response(response) + def _validate_chat_input( self, *, @@ -380,6 +459,14 @@ async def _prepare_request_async( context=context, ) + @staticmethod + def _split_reasoning_effort(kwargs: dict[str, Any]) -> tuple[Any | None, dict[str, Any]]: + if "reasoning_effort" not in kwargs: + return None, kwargs + request_kwargs = dict(kwargs) + reasoning_effort = request_kwargs.pop("reasoning_effort", None) + return reasoning_effort, request_kwargs + def _execute_sync( self, prepared: PreparedChat, @@ -394,6 +481,7 @@ def _execute_sync( ) -> Any: if prepared.context_error is not None: raise prepared.context_error + reasoning_effort, request_kwargs = self._split_reasoning_effort(kwargs) try: return self._core.run_chat_sync( messages_payload=prepared.payload, @@ -402,8 +490,8 @@ def _execute_sync( provider=provider, max_tokens=max_tokens, stream=stream, - reasoning_effort=None, - kwargs=kwargs, + reasoning_effort=reasoning_effort, + kwargs=request_kwargs, on_response=on_response, ) except RepublicError as exc: @@ -423,6 +511,7 @@ async def _execute_async( ) -> Any: if prepared.context_error is not None: raise prepared.context_error + reasoning_effort, request_kwargs = self._split_reasoning_effort(kwargs) try: return await self._core.run_chat_async( messages_payload=prepared.payload, @@ -431,8 +520,8 @@ async def _execute_async( provider=provider, max_tokens=max_tokens, stream=stream, - reasoning_effort=None, - kwargs=kwargs, + reasoning_effort=reasoning_effort, + kwargs=request_kwargs, on_response=on_response, ) except RepublicError as exc: @@ -840,12 +929,13 @@ def _handle_create_response( model_id: str, attempt: int, ) -> str | object: - text = self._extract_text(response) + payload, transport = self._unwrap_response(response) + text = self._extract_text(payload, transport=transport) if text: self._update_tape( prepared, text, - response=response, + response=payload, provider=provider_name, model=model_id, ) @@ -862,12 +952,13 @@ async def _handle_create_response_async( model_id: str, attempt: int, ) -> str | object: - text = self._extract_text(response) + payload, transport = self._unwrap_response(response) + text = self._extract_text(payload, transport=transport) if text: await self._update_tape_async( prepared, text, - response=response, + response=payload, provider=provider_name, model=model_id, ) @@ -884,13 +975,14 @@ def _handle_tool_calls_response( model_id: str, attempt: int, ) -> list[dict[str, Any]]: - calls = self._extract_tool_calls(response) + payload, transport = self._unwrap_response(response) + calls = self._extract_tool_calls(payload, transport=transport) self._update_tape( prepared, None, tool_calls=calls, tool_results=[], - response=response, + response=payload, provider=provider_name, model=model_id, ) @@ -904,13 +996,14 @@ async def _handle_tool_calls_response_async( model_id: str, attempt: int, ) -> list[dict[str, Any]]: - calls = self._extract_tool_calls(response) + payload, transport = self._unwrap_response(response) + calls = self._extract_tool_calls(payload, transport=transport) await self._update_tape_async( prepared, None, tool_calls=calls, tool_results=[], - response=response, + response=payload, provider=provider_name, model=model_id, ) @@ -924,7 +1017,8 @@ def _handle_tools_auto_response( model_id: str, attempt: int, ) -> ToolAutoResult | object: - tool_calls = self._extract_tool_calls(response) + payload, transport = self._unwrap_response(response) + tool_calls = self._extract_tool_calls(payload, transport=transport) if tool_calls: execution = self._tool_executor.execute( tool_calls, @@ -936,18 +1030,18 @@ def _handle_tools_auto_response( None, tool_calls=execution.tool_calls, tool_results=execution.tool_results, - response=response, + response=payload, provider=provider_name, model=model_id, ) return ToolAutoResult.tools_result(execution.tool_calls, execution.tool_results) - text = self._extract_text(response) + text = self._extract_text(payload, transport=transport) if text: self._update_tape( prepared, text, - response=response, + response=payload, provider=provider_name, model=model_id, ) @@ -965,7 +1059,8 @@ async def _handle_tools_auto_response_async( model_id: str, attempt: int, ) -> ToolAutoResult | object: - tool_calls = self._extract_tool_calls(response) + payload, transport = self._unwrap_response(response) + tool_calls = self._extract_tool_calls(payload, transport=transport) if tool_calls: execution = await self._tool_executor.execute_async( tool_calls, @@ -977,18 +1072,18 @@ async def _handle_tools_auto_response_async( None, tool_calls=execution.tool_calls, tool_results=execution.tool_results, - response=response, + response=payload, provider=provider_name, model=model_id, ) return ToolAutoResult.tools_result(execution.tool_calls, execution.tool_results) - text = self._extract_text(response) + text = self._extract_text(payload, transport=transport) if text: await self._update_tape_async( prepared, text, - response=response, + response=payload, provider=provider_name, model=model_id, ) @@ -1374,9 +1469,10 @@ def _build_text_stream( model_id: str, attempt: int, ) -> TextStream: - if hasattr(response, "choices"): - text = self._extract_text(response) - tool_calls = self._extract_tool_calls(response) + payload, transport = self._unwrap_response(response) + if self._is_non_stream_response(payload, transport=transport): + text = self._extract_text(payload, transport=transport) + tool_calls = self._extract_tool_calls(payload, transport=transport) state = StreamState() self._finalize_text_stream( prepared, @@ -1386,8 +1482,8 @@ def _build_text_stream( provider_name=provider_name, model_id=model_id, attempt=attempt, - usage=self._extract_usage(response), - response=response, + usage=self._extract_usage(payload, transport=transport), + response=payload, log_empty=False, ) return TextStream(iter([text]) if text else self._empty_iterator(), state=state) @@ -1401,15 +1497,15 @@ def _build_text_stream( def _iterator() -> Iterator[str]: nonlocal usage try: - for chunk in response: - deltas = self._extract_chunk_tool_call_deltas(chunk) + for chunk in payload: + deltas = self._extract_chunk_tool_call_deltas(chunk, transport=transport) if deltas: assembler.add_deltas(deltas) - text = self._extract_chunk_text(chunk) + text = self._extract_chunk_text(chunk, transport=transport) if text: parts.append(text) yield text - usage = self._extract_usage(chunk) or usage + usage = self._extract_usage(chunk, transport=transport) or usage except Exception as exc: kind = self._core.classify_exception(exc) wrapped = self._core.wrap_error(exc, kind, provider_name, model_id) @@ -1438,9 +1534,10 @@ async def _build_async_text_stream( model_id: str, attempt: int, ) -> AsyncTextStream: - if hasattr(response, "choices"): - text = self._extract_text(response) - tool_calls = self._extract_tool_calls(response) + payload, transport = self._unwrap_response(response) + if self._is_non_stream_response(payload, transport=transport): + text = self._extract_text(payload, transport=transport) + tool_calls = self._extract_tool_calls(payload, transport=transport) state = StreamState() await self._finalize_text_stream_async( prepared, @@ -1450,8 +1547,8 @@ async def _build_async_text_stream( provider_name=provider_name, model_id=model_id, attempt=attempt, - usage=self._extract_usage(response), - response=response, + usage=self._extract_usage(payload, transport=transport), + response=payload, log_empty=False, ) @@ -1469,15 +1566,15 @@ async def _single() -> AsyncIterator[str]: async def _iterator() -> AsyncIterator[str]: nonlocal usage try: - async for chunk in response: - deltas = self._extract_chunk_tool_call_deltas(chunk) + async for chunk in payload: + deltas = self._extract_chunk_tool_call_deltas(chunk, transport=transport) if deltas: assembler.add_deltas(deltas) - text = self._extract_chunk_text(chunk) + text = self._extract_chunk_text(chunk, transport=transport) if text: parts.append(text) yield text - usage = self._extract_usage(chunk) or usage + usage = self._extract_usage(chunk, transport=transport) or usage except Exception as exc: kind = self._core.classify_exception(exc) wrapped = self._core.wrap_error(exc, kind, provider_name, model_id) @@ -1499,28 +1596,22 @@ async def _iterator() -> AsyncIterator[str]: return AsyncTextStream(_iterator(), state=state) @staticmethod - def _chunk_has_tool_calls(chunk: Any) -> bool: - return bool(ChatClient._extract_chunk_tool_call_deltas(chunk)) - - @staticmethod - def _extract_chunk_tool_call_deltas(chunk: Any) -> list[Any]: - choices = getattr(chunk, "choices", None) - if not choices: - return [] - delta = getattr(choices[0], "delta", None) - if delta is None: - return [] - return getattr(delta, "tool_calls", None) or [] + def _extract_chunk_tool_call_deltas( + chunk: Any, + *, + transport: TransportKind | None = None, + ) -> list[Any]: + parser = ChatClient._parser_for_payload(chunk, transport=transport) + return parser.extract_chunk_tool_call_deltas(chunk) @staticmethod - def _extract_chunk_text(chunk: Any) -> str: - choices = getattr(chunk, "choices", None) - if not choices: - return "" - delta = getattr(choices[0], "delta", None) - if delta is None: - return "" - return getattr(delta, "content", "") or "" + def _extract_chunk_text( + chunk: Any, + *, + transport: TransportKind | None = None, + ) -> str: + parser = ChatClient._parser_for_payload(chunk, transport=transport) + return parser.extract_chunk_text(chunk) def _build_event_stream( self, @@ -1530,12 +1621,14 @@ def _build_event_stream( model_id: str, attempt: int, ) -> StreamEvents: - if hasattr(response, "choices"): + payload, transport = self._unwrap_response(response) + if self._is_non_stream_response(payload, transport=transport): return self._build_event_stream_from_response( prepared, - response, + payload, provider_name, model_id, + transport=transport, ) state = StreamState() @@ -1548,10 +1641,10 @@ def _build_event_stream( def _iterator() -> Iterator[StreamEvent]: nonlocal usage, tool_calls, tool_results try: - for chunk in response: - usage = self._extract_usage(chunk) or usage - assembler.add_deltas(self._extract_chunk_tool_call_deltas(chunk)) - text = self._extract_chunk_text(chunk) + for chunk in payload: + usage = self._extract_usage(chunk, transport=transport) or usage + assembler.add_deltas(self._extract_chunk_tool_call_deltas(chunk, transport=transport)) + text = self._extract_chunk_text(chunk, transport=transport) if text: parts.append(text) yield StreamEvent("text", {"delta": text}) @@ -1603,12 +1696,14 @@ def _build_async_event_stream( model_id: str, attempt: int, ) -> AsyncStreamEvents: - if hasattr(response, "choices"): + payload, transport = self._unwrap_response(response) + if self._is_non_stream_response(payload, transport=transport): return self._build_async_event_stream_from_response( prepared, - response, + payload, provider_name, model_id, + transport=transport, ) state = StreamState() @@ -1621,10 +1716,10 @@ def _build_async_event_stream( async def _iterator() -> AsyncIterator[StreamEvent]: nonlocal usage, tool_calls, tool_results try: - async for chunk in response: - usage = self._extract_usage(chunk) or usage - assembler.add_deltas(self._extract_chunk_tool_call_deltas(chunk)) - text = self._extract_chunk_text(chunk) + async for chunk in payload: + usage = self._extract_usage(chunk, transport=transport) or usage + assembler.add_deltas(self._extract_chunk_tool_call_deltas(chunk, transport=transport)) + text = self._extract_chunk_text(chunk, transport=transport) if text: parts.append(text) yield StreamEvent("text", {"delta": text}) @@ -1676,10 +1771,12 @@ def _build_event_stream_from_response( response: Any, provider_name: str, model_id: str, + *, + transport: TransportKind | None = None, ) -> StreamEvents: - text = self._extract_text(response) - tool_calls = self._extract_tool_calls(response) - usage = self._extract_usage(response) + text = self._extract_text(response, transport=transport) + tool_calls = self._extract_tool_calls(response, transport=transport) + usage = self._extract_usage(response, transport=transport) state = StreamState(usage=usage) tool_results: list[Any] = [] try: @@ -1732,10 +1829,12 @@ def _build_async_event_stream_from_response( response: Any, provider_name: str, model_id: str, + *, + transport: TransportKind | None = None, ) -> AsyncStreamEvents: - text = self._extract_text(response) - tool_calls = self._extract_tool_calls(response) - usage = self._extract_usage(response) + text = self._extract_text(response, transport=transport) + tool_calls = self._extract_tool_calls(response, transport=transport) + usage = self._extract_usage(response, transport=transport) state = StreamState(usage=usage) tool_results: list[Any] = [] @@ -1831,93 +1930,28 @@ def _make_tool_context(prepared: PreparedChat, provider_name: str, model_id: str ) @staticmethod - def _extract_text(response: Any) -> str: - if isinstance(response, str): - return response - output = getattr(response, "output", None) - if output: - parts: list[str] = [] - for item in output: - if getattr(item, "type", None) != "message": - continue - content = getattr(item, "content", None) or [] - for entry in content: - if getattr(entry, "type", None) == "output_text": - text = getattr(entry, "text", None) - if text: - parts.append(text) - return "".join(parts) - choices = getattr(response, "choices", None) - if not choices: - return "" - message = getattr(choices[0], "message", None) - if message is None: - return "" - return getattr(message, "content", "") or "" - - @staticmethod - def _extract_tool_calls(response: Any) -> list[dict[str, Any]]: - output = getattr(response, "output", None) - if output: - return ChatClient._extract_responses_tool_calls(output) - return ChatClient._extract_completion_tool_calls(response) - - @staticmethod - def _extract_responses_tool_calls(output: list[Any]) -> list[dict[str, Any]]: - calls: list[dict[str, Any]] = [] - for item in output: - if getattr(item, "type", None) != "function_call": - continue - name = getattr(item, "name", None) - arguments = getattr(item, "arguments", None) - if not name: - continue - entry: dict[str, Any] = {"function": {"name": name, "arguments": arguments or ""}} - call_id = getattr(item, "call_id", None) or getattr(item, "id", None) - if call_id: - entry["id"] = call_id - entry["type"] = "function" - calls.append(entry) - return calls + def _extract_text( + response: Any, + *, + transport: TransportKind | None = None, + ) -> str: + payload, parser = ChatClient._unwrap_response_with_parser(response, transport=transport) + return parser.extract_text(payload) @staticmethod - def _extract_completion_tool_calls(response: Any) -> list[dict[str, Any]]: - choices = getattr(response, "choices", None) - if not choices: - return [] - message = getattr(choices[0], "message", None) - if message is None: - return [] - tool_calls = getattr(message, "tool_calls", None) or [] - calls: list[dict[str, Any]] = [] - for tool_call in tool_calls: - entry: dict[str, Any] = { - "function": { - "name": tool_call.function.name, - "arguments": tool_call.function.arguments, - } - } - call_id = getattr(tool_call, "id", None) - if call_id: - entry["id"] = call_id - call_type = getattr(tool_call, "type", None) - if call_type: - entry["type"] = call_type - calls.append(entry) - return calls + def _extract_tool_calls( + response: Any, + *, + transport: TransportKind | None = None, + ) -> list[dict[str, Any]]: + payload, parser = ChatClient._unwrap_response_with_parser(response, transport=transport) + return parser.extract_tool_calls(payload) @staticmethod - def _extract_usage(response: Any) -> dict[str, Any] | None: - usage = getattr(response, "usage", None) - if usage is None: - return None - if hasattr(usage, "model_dump"): - return usage.model_dump() - if isinstance(usage, dict): - return dict(usage) - data: dict[str, Any] = {} - for field in ("input_tokens", "output_tokens", "total_tokens", "requests"): - value = getattr(usage, field, None) - if value is not None: - data[field] = value - return data or None + def _extract_usage( + response: Any, + *, + transport: TransportKind | None = None, + ) -> dict[str, Any] | None: + payload, parser = ChatClient._unwrap_response_with_parser(response, transport=transport) + return parser.extract_usage(payload) diff --git a/src/republic/clients/parsing/__init__.py b/src/republic/clients/parsing/__init__.py new file mode 100644 index 0000000..66db4c0 --- /dev/null +++ b/src/republic/clients/parsing/__init__.py @@ -0,0 +1,21 @@ +"""Parsing helpers for provider response payloads.""" + +from __future__ import annotations + +from republic.clients.parsing.completion import PARSER as completion_parser +from republic.clients.parsing.messages import PARSER as messages_parser +from republic.clients.parsing.responses import PARSER as responses_parser +from republic.clients.parsing.types import BaseTransportParser, TransportKind + +_PARSERS: dict[TransportKind, BaseTransportParser] = { + "completion": completion_parser, + "responses": responses_parser, + "messages": messages_parser, +} + + +def parser_for_transport(transport: TransportKind) -> BaseTransportParser: + return _PARSERS[transport] + + +__all__ = ["BaseTransportParser", "TransportKind", "parser_for_transport"] diff --git a/src/republic/clients/parsing/common.py b/src/republic/clients/parsing/common.py new file mode 100644 index 0000000..0a47ecd --- /dev/null +++ b/src/republic/clients/parsing/common.py @@ -0,0 +1,67 @@ +"""Common parsing utilities shared by completion and responses adapters.""" + +from __future__ import annotations + +import json +from itertools import chain +from typing import Any + + +def field(data: Any, key: str, default: Any = None) -> Any: + if isinstance(data, dict): + return data.get(key, default) + return getattr(data, key, default) + + +def expand_tool_calls(calls: list[dict[str, Any]]) -> list[dict[str, Any]]: + return list(chain.from_iterable(_expand_tool_call(call) for call in calls)) + + +def _expand_tool_call(call: dict[str, Any]) -> list[dict[str, Any]]: + function = field(call, "function") + if not isinstance(function, dict): + return [dict(call)] + + arguments = field(function, "arguments") + if not isinstance(arguments, str): + return [dict(call)] + + chunks = _split_concatenated_json_objects(arguments) + if not chunks: + return [dict(call)] + + call_id = field(call, "id") + expanded: list[dict[str, Any]] = [] + for index, chunk in enumerate(chunks): + cloned = dict(call) + cloned_function = dict(function) + cloned_function["arguments"] = chunk + cloned["function"] = cloned_function + if isinstance(call_id, str) and call_id and index > 0: + cloned["id"] = f"{call_id}__{index + 1}" + expanded.append(cloned) + return expanded + + +def _split_concatenated_json_objects(raw: str) -> list[str]: + decoder = json.JSONDecoder() + chunks: list[str] = [] + position = 0 + total = len(raw) + while position < total: + while position < total and raw[position].isspace(): + position += 1 + if position >= total: + break + try: + parsed, end = decoder.raw_decode(raw, position) + except json.JSONDecodeError: + return [] + if not isinstance(parsed, dict): + return [] + chunks.append(raw[position:end]) + position = end + + if len(chunks) <= 1: + return [] + return chunks diff --git a/src/republic/clients/parsing/completion.py b/src/republic/clients/parsing/completion.py new file mode 100644 index 0000000..aa77313 --- /dev/null +++ b/src/republic/clients/parsing/completion.py @@ -0,0 +1,95 @@ +"""OpenAI chat-completions shape parsing.""" + +from __future__ import annotations + +from typing import Any + +from republic.clients.parsing.common import expand_tool_calls, field +from republic.clients.parsing.types import FunctionTransportParser + + +def is_non_stream_response(response: Any) -> bool: + return isinstance(response, str) or field(response, "choices") is not None + + +def extract_chunk_tool_call_deltas(chunk: Any) -> list[Any]: + choices = field(chunk, "choices") + if not choices: + return [] + delta = field(choices[0], "delta") + if delta is None: + return [] + return field(delta, "tool_calls") or [] + + +def extract_chunk_text(chunk: Any) -> str: + choices = field(chunk, "choices") + if not choices: + return "" + delta = field(choices[0], "delta") + if delta is None: + return "" + return field(delta, "content", "") or "" + + +def extract_text(response: Any) -> str: + if isinstance(response, str): + return response + + choices = field(response, "choices") + if not choices: + return "" + message = field(choices[0], "message") + if message is None: + return "" + return field(message, "content", "") or "" + + +def extract_tool_calls(response: Any) -> list[dict[str, Any]]: + choices = field(response, "choices") + if not choices: + return [] + message = field(choices[0], "message") + if message is None: + return [] + tool_calls = field(message, "tool_calls") or [] + calls: list[dict[str, Any]] = [] + for tool_call in tool_calls: + function = field(tool_call, "function") + if function is None: + continue + entry: dict[str, Any] = { + "function": { + "name": field(function, "name"), + "arguments": field(function, "arguments"), + } + } + call_id = field(tool_call, "id") + if call_id: + entry["id"] = call_id + call_type = field(tool_call, "type") + if call_type: + entry["type"] = call_type + calls.append(entry) + return expand_tool_calls(calls) + + +def extract_usage(response: Any) -> dict[str, Any] | None: + usage = field(response, "usage") + if usage is None: + return None + if isinstance(usage, dict): + return dict(usage) + if hasattr(usage, "model_dump"): + return usage.model_dump() + return None + + +PARSER = FunctionTransportParser( + is_non_stream_response_fn=is_non_stream_response, + extract_chunk_tool_call_deltas_fn=extract_chunk_tool_call_deltas, + extract_chunk_text_fn=extract_chunk_text, + extract_text_fn=extract_text, + extract_tool_calls_fn=extract_tool_calls, + extract_usage_fn=extract_usage, +) diff --git a/src/republic/clients/parsing/messages.py b/src/republic/clients/parsing/messages.py new file mode 100644 index 0000000..190f354 --- /dev/null +++ b/src/republic/clients/parsing/messages.py @@ -0,0 +1,11 @@ +"""Anthropic messages parsing. + +Currently any-llm exposes Anthropic messages in completion-compatible payload +shapes, so this parser intentionally reuses completion parsing behavior. +""" + +from __future__ import annotations + +from republic.clients.parsing.completion import PARSER + +__all__ = ["PARSER"] diff --git a/src/republic/clients/parsing/responses.py b/src/republic/clients/parsing/responses.py new file mode 100644 index 0000000..f1bec0b --- /dev/null +++ b/src/republic/clients/parsing/responses.py @@ -0,0 +1,157 @@ +"""OpenAI responses shape parsing.""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +from republic.clients.parsing.common import expand_tool_calls, field +from republic.clients.parsing.types import FunctionTransportParser + + +def is_non_stream_response(response: Any) -> bool: + return ( + isinstance(response, str) + or field(response, "choices") is not None + or field(response, "output") is not None + or field(response, "output_text") is not None + ) + + +def _tool_delta_from_args_event(chunk: Any, event_type: str) -> list[Any]: + item_id = field(chunk, "item_id") + if not item_id: + return [] + arguments = field(chunk, "delta") + if event_type == "response.function_call_arguments.done": + arguments = field(chunk, "arguments") + if not isinstance(arguments, str): + return [] + + call_id = field(chunk, "call_id") + payload: dict[str, Any] = { + "index": item_id, + "type": "function", + "function": SimpleNamespace(name=field(chunk, "name") or "", arguments=arguments), + "arguments_complete": event_type == "response.function_call_arguments.done", + } + if call_id: + payload["id"] = call_id + return [SimpleNamespace(**payload)] + + +def _tool_delta_from_output_item_event(chunk: Any, event_type: str) -> list[Any]: + item = field(chunk, "item") + if field(item, "type") != "function_call": + return [] + + item_id = field(item, "id") + call_id = field(item, "call_id") or item_id + if not call_id: + return [] + arguments = field(item, "arguments") + if not isinstance(arguments, str): + arguments = "" + return [ + SimpleNamespace( + id=call_id, + index=item_id or call_id, + type="function", + function=SimpleNamespace(name=field(item, "name") or "", arguments=arguments), + arguments_complete=event_type == "response.output_item.done", + ) + ] + + +def extract_chunk_tool_call_deltas(chunk: Any) -> list[Any]: + event_type = field(chunk, "type") + if event_type in {"response.function_call_arguments.delta", "response.function_call_arguments.done"}: + return _tool_delta_from_args_event(chunk, event_type) + if event_type in {"response.output_item.added", "response.output_item.done"}: + return _tool_delta_from_output_item_event(chunk, event_type) + return [] + + +def extract_chunk_text(chunk: Any) -> str: + if field(chunk, "type") != "response.output_text.delta": + return "" + delta = field(chunk, "delta") + if isinstance(delta, str): + return delta + return "" + + +def extract_text_from_output(output: Any) -> str: + if not isinstance(output, list): + return "" + parts: list[str] = [] + for item in output: + if field(item, "type") != "message": + continue + content = field(item, "content") or [] + for entry in content: + if field(entry, "type") == "output_text": + text = field(entry, "text") + if text: + parts.append(text) + return "".join(parts) + + +def extract_text(response: Any) -> str: + output_text = field(response, "output_text") + if isinstance(output_text, str): + return output_text + return extract_text_from_output(field(response, "output")) + + +def extract_tool_calls(response: Any) -> list[dict[str, Any]]: + output = response if isinstance(response, list) else field(response, "output") + if not isinstance(output, list): + return [] + calls: list[dict[str, Any]] = [] + for item in output: + if field(item, "type") != "function_call": + continue + name = field(item, "name") + arguments = field(item, "arguments") + if not name: + continue + entry: dict[str, Any] = {"function": {"name": name, "arguments": arguments or ""}} + call_id = field(item, "call_id") or field(item, "id") + if call_id: + entry["id"] = call_id + entry["type"] = "function" + calls.append(entry) + return expand_tool_calls(calls) + + +def extract_usage(response: Any) -> dict[str, Any] | None: + event_type = field(response, "type") + if event_type in {"response.completed", "response.in_progress", "response.failed", "response.incomplete"}: + usage = field(field(response, "response"), "usage") + else: + usage = field(response, "usage") + + if usage is None: + return None + if hasattr(usage, "model_dump"): + return usage.model_dump() + if isinstance(usage, dict): + return dict(usage) + + data: dict[str, Any] = {} + for usage_field in ("input_tokens", "output_tokens", "total_tokens", "requests"): + value = field(usage, usage_field) + if value is not None: + data[usage_field] = value + return data or None + + +PARSER = FunctionTransportParser( + is_non_stream_response_fn=is_non_stream_response, + extract_chunk_tool_call_deltas_fn=extract_chunk_tool_call_deltas, + extract_chunk_text_fn=extract_chunk_text, + extract_text_fn=extract_text, + extract_tool_calls_fn=extract_tool_calls, + extract_usage_fn=extract_usage, +) diff --git a/src/republic/clients/parsing/types.py b/src/republic/clients/parsing/types.py new file mode 100644 index 0000000..0290e19 --- /dev/null +++ b/src/republic/clients/parsing/types.py @@ -0,0 +1,52 @@ +"""Shared parser typing and validation primitives.""" + +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any, Literal, Protocol, runtime_checkable + +TransportKind = Literal["completion", "responses", "messages"] + + +@runtime_checkable +class BaseTransportParser(Protocol): + def is_non_stream_response(self, response: Any) -> bool: ... + + def extract_chunk_tool_call_deltas(self, chunk: Any) -> list[Any]: ... + + def extract_chunk_text(self, chunk: Any) -> str: ... + + def extract_text(self, response: Any) -> str: ... + + def extract_tool_calls(self, response: Any) -> list[dict[str, Any]]: ... + + def extract_usage(self, response: Any) -> dict[str, Any] | None: ... + + +@dataclass(frozen=True, slots=True) +class FunctionTransportParser: + is_non_stream_response_fn: Callable[[Any], bool] + extract_chunk_tool_call_deltas_fn: Callable[[Any], list[Any]] + extract_chunk_text_fn: Callable[[Any], str] + extract_text_fn: Callable[[Any], str] + extract_tool_calls_fn: Callable[[Any], list[dict[str, Any]]] + extract_usage_fn: Callable[[Any], dict[str, Any] | None] + + def is_non_stream_response(self, response: Any) -> bool: + return self.is_non_stream_response_fn(response) + + def extract_chunk_tool_call_deltas(self, chunk: Any) -> list[Any]: + return self.extract_chunk_tool_call_deltas_fn(chunk) + + def extract_chunk_text(self, chunk: Any) -> str: + return self.extract_chunk_text_fn(chunk) + + def extract_text(self, response: Any) -> str: + return self.extract_text_fn(response) + + def extract_tool_calls(self, response: Any) -> list[dict[str, Any]]: + return self.extract_tool_calls_fn(response) + + def extract_usage(self, response: Any) -> dict[str, Any] | None: + return self.extract_usage_fn(response) diff --git a/src/republic/core/execution.py b/src/republic/core/execution.py index ccc3a73..2011314 100644 --- a/src/republic/core/execution.py +++ b/src/republic/core/execution.py @@ -9,7 +9,7 @@ from collections.abc import Callable from dataclasses import dataclass from enum import Enum, auto -from typing import Any, NoReturn +from typing import Any, Literal, NoReturn from any_llm import AnyLLM from any_llm.exceptions import ( @@ -26,7 +26,9 @@ UnsupportedProviderError, ) +from republic.core import provider_policies from republic.core.errors import ErrorKind, RepublicError +from republic.core.request_adapters import normalize_responses_kwargs logger = logging.getLogger(__name__) @@ -46,6 +48,25 @@ class AttemptOutcome: decision: AttemptDecision +@dataclass(frozen=True) +class TransportResponse: + transport: Literal["completion", "responses", "messages"] + payload: Any + + +@dataclass(frozen=True) +class TransportCallRequest: + client: AnyLLM + provider_name: str + model_id: str + messages_payload: list[dict[str, Any]] + tools_payload: list[dict[str, Any]] | None + max_tokens: int | None + stream: bool + reasoning_effort: Any | None + kwargs: dict[str, Any] + + class LLMCore: """Shared LLM execution utilities (provider resolution, retries, client cache).""" @@ -61,7 +82,7 @@ def __init__( api_key: str | dict[str, str] | None, api_base: str | dict[str, str] | None, client_args: dict[str, Any], - use_responses: bool, + api_format: Literal["completion", "responses", "messages"], verbose: int, error_classifier: Callable[[Exception], ErrorKind | None] | None = None, ) -> None: @@ -72,7 +93,7 @@ def __init__( self._api_key = api_key self._api_base = api_base self._client_args = client_args - self._use_responses = use_responses + self._api_format = api_format self._verbose = verbose self._error_classifier = error_classifier self._client_cache: dict[str, AnyLLM] = {} @@ -328,8 +349,12 @@ def raise_wrapped(self, exc: Exception, provider: str, model: str) -> NoReturn: raise self.wrap_error(exc, kind, provider, model) from exc def _handle_attempt_error(self, exc: Exception, provider_name: str, model_id: str, attempt: int) -> AttemptOutcome: - kind = self.classify_exception(exc) - wrapped = self.wrap_error(exc, kind, provider_name, model_id) + if isinstance(exc, RepublicError): + wrapped = exc + kind = exc.kind + else: + kind = self.classify_exception(exc) + wrapped = self.wrap_error(exc, kind, provider_name, model_id) self.log_error(wrapped, provider_name, model_id, attempt) can_retry_same_model = self.should_retry(kind) and attempt + 1 < self.max_attempts() if can_retry_same_model: @@ -339,21 +364,185 @@ def _handle_attempt_error(self, exc: Exception, provider_name: str, model_id: st def _decide_kwargs_for_provider( self, provider: str, max_tokens: int | None, kwargs: dict[str, Any] ) -> dict[str, Any]: - use_completion_tokens = "openai" in provider.lower() - if not use_completion_tokens: - return {**kwargs, "max_tokens": max_tokens} - if "max_completion_tokens" in kwargs: - return kwargs - return {**kwargs, "max_completion_tokens": max_tokens} + clean_kwargs = dict(kwargs) + max_tokens_arg = provider_policies.completion_max_tokens_arg(provider) + if max_tokens_arg in clean_kwargs: + return clean_kwargs + return {**clean_kwargs, max_tokens_arg: max_tokens} def _decide_responses_kwargs(self, max_tokens: int | None, kwargs: dict[str, Any]) -> dict[str, Any]: + # any-llm responses params currently reject extra_headers, so drop it here only. clean_kwargs = {k: v for k, v in kwargs.items() if k != "extra_headers"} + clean_kwargs = normalize_responses_kwargs(clean_kwargs) if "max_output_tokens" in clean_kwargs: return clean_kwargs return {**clean_kwargs, "max_output_tokens": max_tokens} - def _should_use_responses(self, client: AnyLLM, *, stream: bool) -> bool: - return not stream and self._use_responses and bool(getattr(client, "SUPPORTS_RESPONSES", False)) + @staticmethod + def _should_default_completion_stream_usage(provider_name: str) -> bool: + return provider_policies.should_include_completion_stream_usage(provider_name) + + def _with_default_completion_stream_options( + self, + provider_name: str, + stream: bool, + kwargs: dict[str, Any], + ) -> dict[str, Any]: + if not stream: + return kwargs + if not self._should_default_completion_stream_usage(provider_name): + return kwargs + if "stream_options" in kwargs: + return kwargs + return {**kwargs, "stream_options": {"include_usage": True}} + + @staticmethod + def _with_responses_reasoning( + kwargs: dict[str, Any], + reasoning_effort: Any | None, + ) -> dict[str, Any]: + if reasoning_effort is None: + return kwargs + if "reasoning" in kwargs: + return kwargs + return {**kwargs, "reasoning": {"effort": reasoning_effort}} + + @staticmethod + def _convert_tools_for_responses(tools_payload: list[dict[str, Any]] | None) -> list[dict[str, Any]] | None: + if not tools_payload: + return tools_payload + + converted_tools: list[dict[str, Any]] = [] + for tool in tools_payload: + function = tool.get("function") + if isinstance(function, dict): + converted: dict[str, Any] = { + "type": tool.get("type", "function"), + "name": function.get("name"), + "description": function.get("description", ""), + "parameters": function.get("parameters", {}), + } + if "strict" in function: + converted["strict"] = function["strict"] + converted_tools.append(converted) + continue + converted_tools.append(dict(tool)) + return converted_tools + + def _selected_transport( + self, + client: AnyLLM, + *, + provider_name: str, + model_id: str, + tools_payload: list[dict[str, Any]] | None, + ) -> Literal["completion", "responses", "messages"]: + if self._api_format == "completion": + return "completion" + if self._api_format == "messages": + if not provider_policies.supports_messages_format( + provider_name=provider_name, + model_id=model_id, + ): + raise RepublicError( + ErrorKind.INVALID_INPUT, + f"{provider_name}:{model_id}: messages format is only valid for Anthropic models", + ) + return "messages" + + reason = provider_policies.responses_rejection_reason( + provider_name=provider_name, + model_id=model_id, + has_tools=bool(tools_payload), + supports_responses=bool(getattr(client, "SUPPORTS_RESPONSES", False)), + ) + if reason is not None: + raise RepublicError(ErrorKind.INVALID_INPUT, f"{provider_name}:{model_id}: {reason}") + return "responses" + + def _call_responses_sync( + self, + request: TransportCallRequest, + ) -> Any: + instructions, input_items = self._split_messages_for_responses(request.messages_payload) + responses_kwargs = self._with_responses_reasoning(request.kwargs, request.reasoning_effort) + return TransportResponse( + transport="responses", + payload=request.client.responses( + model=request.model_id, + input_data=input_items, + tools=self._convert_tools_for_responses(request.tools_payload), + stream=request.stream, + instructions=instructions, + **self._decide_responses_kwargs(request.max_tokens, responses_kwargs), + ), + ) + + def _call_completion_like_sync( + self, + *, + transport: Literal["completion", "messages"], + request: TransportCallRequest, + ) -> Any: + completion_kwargs = self._decide_kwargs_for_provider(request.provider_name, request.max_tokens, request.kwargs) + completion_kwargs = self._with_default_completion_stream_options( + request.provider_name, + request.stream, + completion_kwargs, + ) + return TransportResponse( + transport=transport, + payload=request.client.completion( + model=request.model_id, + messages=request.messages_payload, + tools=request.tools_payload, + stream=request.stream, + reasoning_effort=request.reasoning_effort, + **completion_kwargs, + ), + ) + + async def _call_responses_async( + self, + request: TransportCallRequest, + ) -> Any: + instructions, input_items = self._split_messages_for_responses(request.messages_payload) + responses_kwargs = self._with_responses_reasoning(request.kwargs, request.reasoning_effort) + return TransportResponse( + transport="responses", + payload=await request.client.aresponses( + model=request.model_id, + input_data=input_items, + tools=self._convert_tools_for_responses(request.tools_payload), + stream=request.stream, + instructions=instructions, + **self._decide_responses_kwargs(request.max_tokens, responses_kwargs), + ), + ) + + async def _call_completion_like_async( + self, + *, + transport: Literal["completion", "messages"], + request: TransportCallRequest, + ) -> Any: + completion_kwargs = self._decide_kwargs_for_provider(request.provider_name, request.max_tokens, request.kwargs) + completion_kwargs = self._with_default_completion_stream_options( + request.provider_name, + request.stream, + completion_kwargs, + ) + return TransportResponse( + transport=transport, + payload=await request.client.acompletion( + model=request.model_id, + messages=request.messages_payload, + tools=request.tools_payload, + stream=request.stream, + reasoning_effort=request.reasoning_effort, + **completion_kwargs, + ), + ) def _call_client_sync( self, @@ -368,24 +557,26 @@ def _call_client_sync( reasoning_effort: Any | None, kwargs: dict[str, Any], ) -> Any: - if self._should_use_responses(client, stream=stream): - instructions, input_items = self._split_messages_for_responses(messages_payload) - return client.responses( - model=model_id, - input_data=input_items, - tools=tools_payload, - stream=stream, - instructions=instructions, - **self._decide_responses_kwargs(max_tokens, kwargs), - ) - return client.completion( - model=model_id, - messages=messages_payload, - tools=tools_payload, + request = TransportCallRequest( + client=client, + provider_name=provider_name, + model_id=model_id, + messages_payload=messages_payload, + tools_payload=tools_payload, + max_tokens=max_tokens, stream=stream, reasoning_effort=reasoning_effort, - **self._decide_kwargs_for_provider(provider_name, max_tokens, kwargs), + kwargs=kwargs, + ) + transport = self._selected_transport( + client, + provider_name=provider_name, + model_id=model_id, + tools_payload=tools_payload, ) + if transport == "responses": + return self._call_responses_sync(request) + return self._call_completion_like_sync(transport=transport, request=request) async def _call_client_async( self, @@ -400,24 +591,26 @@ async def _call_client_async( reasoning_effort: Any | None, kwargs: dict[str, Any], ) -> Any: - if self._should_use_responses(client, stream=stream): - instructions, input_items = self._split_messages_for_responses(messages_payload) - return await client.aresponses( - model=model_id, - input_data=input_items, - tools=tools_payload, - stream=stream, - instructions=instructions, - **self._decide_responses_kwargs(max_tokens, kwargs), - ) - return await client.acompletion( - model=model_id, - messages=messages_payload, - tools=tools_payload, + request = TransportCallRequest( + client=client, + provider_name=provider_name, + model_id=model_id, + messages_payload=messages_payload, + tools_payload=tools_payload, + max_tokens=max_tokens, stream=stream, reasoning_effort=reasoning_effort, - **self._decide_kwargs_for_provider(provider_name, max_tokens, kwargs), + kwargs=kwargs, + ) + transport = self._selected_transport( + client, + provider_name=provider_name, + model_id=model_id, + tools_payload=tools_payload, ) + if transport == "responses": + return await self._call_responses_async(request) + return await self._call_completion_like_async(transport=transport, request=request) @staticmethod def _split_messages_for_responses( diff --git a/src/republic/core/provider_policies.py b/src/republic/core/provider_policies.py new file mode 100644 index 0000000..db99276 --- /dev/null +++ b/src/republic/core/provider_policies.py @@ -0,0 +1,72 @@ +"""Provider policy decisions shared across request paths.""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class ProviderPolicy: + enable_responses_without_capability: bool = False + include_usage_in_completion_stream: bool = False + completion_max_tokens_arg: str = "max_tokens" + responses_tools_blocked_model_prefixes: tuple[str, ...] = () + + +_DEFAULT_POLICY = ProviderPolicy() +_POLICIES: dict[str, ProviderPolicy] = { + "openai": ProviderPolicy( + include_usage_in_completion_stream=True, + completion_max_tokens_arg="max_completion_tokens", + ), + # any-llm supports OpenRouter responses in practice but still reports SUPPORTS_RESPONSES=False. + "openrouter": ProviderPolicy( + enable_responses_without_capability=True, + include_usage_in_completion_stream=True, + responses_tools_blocked_model_prefixes=("anthropic/",), + ), +} + + +def _normalize_provider_name(provider_name: str) -> str: + return provider_name.strip().lower() + + +def provider_policy(provider_name: str) -> ProviderPolicy: + return _POLICIES.get(_normalize_provider_name(provider_name), _DEFAULT_POLICY) + + +def _responses_tools_blocked_for_model(provider_name: str, model_id: str) -> bool: + policy = provider_policy(provider_name) + lowered_model = model_id.strip().lower() + return any(lowered_model.startswith(prefix) for prefix in policy.responses_tools_blocked_model_prefixes) + + +def responses_rejection_reason( + *, + provider_name: str, + model_id: str, + has_tools: bool, + supports_responses: bool, +) -> str | None: + if has_tools and _responses_tools_blocked_for_model(provider_name, model_id): + return "responses format is not supported for this model when tools are enabled" + if supports_responses: + return None + if provider_policy(provider_name).enable_responses_without_capability: + return None + return "responses format is not supported by this provider" + + +def supports_messages_format(*, provider_name: str, model_id: str) -> bool: + normalized_provider = _normalize_provider_name(provider_name) + normalized_model = model_id.strip().lower() + return normalized_provider == "anthropic" or normalized_model.startswith("anthropic/") + + +def should_include_completion_stream_usage(provider_name: str) -> bool: + return provider_policy(provider_name).include_usage_in_completion_stream + + +def completion_max_tokens_arg(provider_name: str) -> str: + return provider_policy(provider_name).completion_max_tokens_arg diff --git a/src/republic/core/request_adapters.py b/src/republic/core/request_adapters.py new file mode 100644 index 0000000..7f68bbd --- /dev/null +++ b/src/republic/core/request_adapters.py @@ -0,0 +1,26 @@ +"""Request-shape adapters for different upstream APIs.""" + +from __future__ import annotations + +from typing import Any + + +def normalize_responses_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]: + """Normalize completion-style kwargs into responses-compatible shapes.""" + tool_choice = kwargs.get("tool_choice") + if not isinstance(tool_choice, dict): + return kwargs + + function = tool_choice.get("function") + if not isinstance(function, dict): + return kwargs + + function_name = function.get("name") + if not isinstance(function_name, str) or not function_name: + return kwargs + + normalized_tool_choice = dict(tool_choice) + normalized_tool_choice.pop("function", None) + normalized_tool_choice["type"] = normalized_tool_choice.get("type", "function") + normalized_tool_choice["name"] = function_name + return {**kwargs, "tool_choice": normalized_tool_choice} diff --git a/src/republic/llm.py b/src/republic/llm.py index 445f27b..cae1851 100644 --- a/src/republic/llm.py +++ b/src/republic/llm.py @@ -4,7 +4,7 @@ import warnings from collections.abc import Callable -from typing import Any, cast +from typing import Any, Literal, cast from republic.__about__ import DEFAULT_MODEL from republic.clients._internal import InternalOps @@ -48,7 +48,7 @@ def __init__( api_key: str | dict[str, str] | None = None, api_base: str | dict[str, str] | None = None, client_args: dict[str, Any] | None = None, - use_responses: bool = False, + api_format: Literal["completion", "responses", "messages"] = "completion", verbose: int = 0, tape_store: TapeStore | AsyncTapeStore | None = None, context: TapeContext | None = None, @@ -58,6 +58,11 @@ def __init__( raise RepublicError(ErrorKind.INVALID_INPUT, "verbose must be 0, 1, or 2") if max_retries < 0: raise RepublicError(ErrorKind.INVALID_INPUT, "max_retries must be >= 0") + if api_format not in {"completion", "responses", "messages"}: + raise RepublicError( + ErrorKind.INVALID_INPUT, + "api_format must be 'completion', 'responses', or 'messages'", + ) if not model: model = DEFAULT_MODEL @@ -73,7 +78,7 @@ def __init__( api_key=api_key, api_base=api_base, client_args=client_args or {}, - use_responses=use_responses, + api_format=api_format, verbose=verbose, error_classifier=error_classifier, ) diff --git a/tests/fakes.py b/tests/fakes.py index 0dbaef9..be13c8d 100644 --- a/tests/fakes.py +++ b/tests/fakes.py @@ -142,10 +142,56 @@ def make_responses_response( *, text: str = "", tool_calls: list[Any] | None = None, + usage: dict[str, Any] | None = None, ) -> Any: output: list[Any] = [] if text: output.append(make_responses_message(text)) if tool_calls: output.extend(tool_calls) - return SimpleNamespace(output=output) + return SimpleNamespace(output=output, usage=usage) + + +def make_responses_text_delta(delta: str) -> Any: + return SimpleNamespace(type="response.output_text.delta", delta=delta) + + +def make_responses_function_delta(delta: str, *, item_id: str = "call_1") -> Any: + return SimpleNamespace(type="response.function_call_arguments.delta", delta=delta, item_id=item_id, output_index=0) + + +def make_responses_function_done(name: str, arguments: str, *, item_id: str = "call_1") -> Any: + return SimpleNamespace( + type="response.function_call_arguments.done", + name=name, + arguments=arguments, + item_id=item_id, + output_index=0, + ) + + +def make_responses_completed(usage: dict[str, Any] | None = None) -> Any: + response = SimpleNamespace(usage=usage) + return SimpleNamespace(type="response.completed", response=response) + + +def make_responses_output_item_added( + *, + item_id: str = "fc_1", + call_id: str = "call_1", + name: str = "echo", + arguments: str = "", +) -> Any: + item = SimpleNamespace(type="function_call", id=item_id, call_id=call_id, name=name, arguments=arguments) + return SimpleNamespace(type="response.output_item.added", item=item) + + +def make_responses_output_item_done( + *, + item_id: str = "fc_1", + call_id: str = "call_1", + name: str = "echo", + arguments: str = '{"text":"tokyo"}', +) -> Any: + item = SimpleNamespace(type="function_call", id=item_id, call_id=call_id, name=name, arguments=arguments) + return SimpleNamespace(type="response.output_item.done", item=item) diff --git a/tests/test_parsing_registry.py b/tests/test_parsing_registry.py new file mode 100644 index 0000000..3bf7db9 --- /dev/null +++ b/tests/test_parsing_registry.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from types import SimpleNamespace + +from republic.clients.chat import ChatClient +from republic.clients.parsing import parser_for_transport +from republic.clients.parsing import responses as responses_parser +from republic.clients.parsing.types import BaseTransportParser + +from .fakes import make_responses_function_call, make_responses_response + + +def test_parser_for_transport_returns_parser_objects() -> None: + parser = parser_for_transport("completion") + assert isinstance(parser, BaseTransportParser) + assert callable(parser.extract_text) + assert callable(parser.extract_tool_calls) + assert callable(parser.extract_usage) + + parser = parser_for_transport("responses") + assert isinstance(parser, BaseTransportParser) + assert callable(parser.extract_text) + assert callable(parser.extract_tool_calls) + assert callable(parser.extract_usage) + + parser = parser_for_transport("messages") + assert isinstance(parser, BaseTransportParser) + assert callable(parser.extract_text) + assert callable(parser.extract_tool_calls) + assert callable(parser.extract_usage) + + +def test_responses_extract_tool_calls_accepts_full_response() -> None: + response = make_responses_response(tool_calls=[make_responses_function_call("echo", '{"text":"tokyo"}')]) + calls = responses_parser.extract_tool_calls(response) + assert calls[0]["function"]["name"] == "echo" + + +def test_chat_client_resolve_transport_treats_output_text_as_responses() -> None: + payload = SimpleNamespace(output_text="hello") + assert ChatClient._is_non_stream_response(payload) is True + assert ChatClient._extract_text(payload) == "hello" diff --git a/tests/test_provider_policies.py b/tests/test_provider_policies.py new file mode 100644 index 0000000..12346da --- /dev/null +++ b/tests/test_provider_policies.py @@ -0,0 +1,67 @@ +from republic.core import provider_policies + + +def test_responses_rejection_reason_none_when_openrouter_responses_available() -> None: + assert ( + provider_policies.responses_rejection_reason( + provider_name="openrouter", + model_id="openai/gpt-4o-mini", + has_tools=False, + supports_responses=False, + ) + is None + ) + + +def test_responses_rejection_reason_for_provider_without_responses() -> None: + reason = provider_policies.responses_rejection_reason( + provider_name="anthropic", + model_id="claude-3-5-haiku-latest", + has_tools=False, + supports_responses=False, + ) + assert reason is not None + assert "not supported" in reason + + +def test_responses_rejection_reason_for_openrouter_anthropic_tools() -> None: + reason = provider_policies.responses_rejection_reason( + provider_name="openrouter", + model_id="anthropic/claude-3.5-haiku", + has_tools=True, + supports_responses=False, + ) + assert reason is not None + assert "tools" in reason + + +def test_supports_messages_format() -> None: + assert provider_policies.supports_messages_format( + provider_name="anthropic", + model_id="claude-3-5-haiku-latest", + ) + assert provider_policies.supports_messages_format( + provider_name="openrouter", + model_id="anthropic/claude-3.5-haiku", + ) + assert not provider_policies.supports_messages_format( + provider_name="openai", + model_id="gpt-4o-mini", + ) + + +def test_completion_stream_usage_policy() -> None: + assert provider_policies.should_include_completion_stream_usage("openai") + assert provider_policies.should_include_completion_stream_usage("openrouter") + assert not provider_policies.should_include_completion_stream_usage("anthropic") + + +def test_completion_max_tokens_arg_policy() -> None: + assert provider_policies.completion_max_tokens_arg("openai") == "max_completion_tokens" + assert provider_policies.completion_max_tokens_arg("openrouter") == "max_tokens" + assert provider_policies.completion_max_tokens_arg("anthropic") == "max_tokens" + + +def test_provider_policy_uses_exact_match_not_substring() -> None: + assert not provider_policies.should_include_completion_stream_usage("my-openrouter-proxy") + assert provider_policies.completion_max_tokens_arg("my-openrouter-proxy") == "max_tokens" diff --git a/tests/test_responses_handling.py b/tests/test_responses_handling.py index 0b4993c..92e76cf 100644 --- a/tests/test_responses_handling.py +++ b/tests/test_responses_handling.py @@ -1,17 +1,90 @@ from __future__ import annotations -from republic import LLM -from republic.clients.chat import ChatClient +from typing import Any + +import pytest + +from republic import LLM, tool from republic.core.execution import LLMCore +from republic.core.results import ErrorPayload + +from .fakes import ( + make_chunk, + make_response, + make_responses_completed, + make_responses_function_call, + make_responses_function_delta, + make_responses_function_done, + make_responses_output_item_added, + make_responses_output_item_done, + make_responses_response, + make_responses_text_delta, + make_tool_call, +) + + +@tool +def echo(text: str) -> str: + return text.upper() + + +def _compact_stream_events(events: list[Any]) -> list[tuple[str, Any]]: + compact: list[tuple[str, Any]] = [] + for event in events: + if event.kind == "text": + compact.append(("text", event.data["delta"])) + elif event.kind == "tool_call": + call = event.data["call"] + compact.append(("tool_call", (call.get("id"), call["function"]["name"], call["function"]["arguments"]))) + elif event.kind == "tool_result": + compact.append(("tool_result", event.data["result"])) + elif event.kind == "usage": + compact.append(("usage", event.data)) + elif event.kind == "final": + final = event.data + compact.append(( + "final", + { + "text": final.get("text"), + "tool_calls": [ + (call.get("id"), call["function"]["name"], call["function"]["arguments"]) + for call in (final.get("tool_calls") or []) + ], + "tool_results": final.get("tool_results"), + "usage": final.get("usage"), + "error": final.get("error"), + }, + )) + return compact + + +def _completion_stream_event_items() -> list[Any]: + return [ + make_chunk(text="Checking "), + make_chunk(tool_calls=[make_tool_call("echo", '{"text":"to', call_id="call_1")]), + make_chunk( + tool_calls=[make_tool_call("echo", 'kyo"}', call_id="call_1")], + usage={"total_tokens": 12}, + ), + ] + + +def test_default_api_format_uses_completion(fake_anyllm) -> None: + client = fake_anyllm.ensure("openai") + client.queue_completion(make_response(text="hello")) + + llm = LLM(model="openai:gpt-4o-mini", api_key="dummy") + result = llm.chat("hi") -from .fakes import make_responses_function_call, make_responses_response + assert result == "hello" + assert client.calls[-1].get("responses") is None -def test_llm_use_responses_calls_responses(fake_anyllm) -> None: +def test_responses_api_format_uses_responses(fake_anyllm) -> None: client = fake_anyllm.ensure("openai") client.queue_responses(make_responses_response(text="hello")) - llm = LLM(model="openai:gpt-4o-mini", api_key="dummy", use_responses=True) + llm = LLM(model="openai:gpt-4o-mini", api_key="dummy", api_format="responses") result = llm.chat("hi") assert result == "hello" @@ -19,18 +92,114 @@ def test_llm_use_responses_calls_responses(fake_anyllm) -> None: assert client.calls[-1]["input_data"][0]["role"] == "user" -def test_extract_tool_calls_from_responses() -> None: - response = make_responses_response(tool_calls=[make_responses_function_call("echo", '{"text":"hi"}')]) +def test_openrouter_responses_works_even_if_provider_flag_is_false(fake_anyllm) -> None: + client = fake_anyllm.ensure("openrouter") + client.SUPPORTS_RESPONSES = False + client.queue_responses(make_responses_response(text="hello")) - calls = ChatClient._extract_tool_calls(response) + llm = LLM(model="openrouter:openrouter/free", api_key="dummy", api_format="responses") + result = llm.chat("hi") - assert calls == [ - { - "function": {"name": "echo", "arguments": '{"text":"hi"}'}, - "id": "call_1", - "type": "function", - } - ] + assert result == "hello" + assert client.calls[-1].get("responses") is True + + +def test_openrouter_anthropic_tools_rejects_responses_format(fake_anyllm) -> None: + client = fake_anyllm.ensure("openrouter") + client.SUPPORTS_RESPONSES = False + llm = LLM(model="openrouter:anthropic/claude-3.5-haiku", api_key="dummy", api_format="responses") + + with pytest.raises(ErrorPayload) as exc_info: + llm.tool_calls( + "Call echo for tokyo", + tools=[echo], + tool_choice={"type": "function", "function": {"name": "echo"}}, + ) + assert exc_info.value.kind == "invalid_input" + + +def test_messages_format_maps_to_completion(fake_anyllm) -> None: + client = fake_anyllm.ensure("openrouter") + client.queue_completion(make_response(tool_calls=[make_tool_call("echo", '{"text":"tokyo"}')])) + + llm = LLM(model="openrouter:anthropic/claude-3.5-haiku", api_key="dummy", api_format="messages") + calls = llm.tool_calls("Call echo for tokyo", tools=[echo]) + + assert len(calls) == 1 + assert calls[0]["function"]["name"] == "echo" + assert client.calls[-1].get("responses") is None + + +def test_messages_format_rejects_non_anthropic_model(fake_anyllm) -> None: + llm = LLM(model="openai:gpt-4o-mini", api_key="dummy", api_format="messages") + with pytest.raises(ErrorPayload) as exc_info: + llm.chat("hi") + assert exc_info.value.kind == "invalid_input" + + +def test_responses_tool_choice_accepts_completion_function_shape(fake_anyllm) -> None: + client = fake_anyllm.ensure("openai") + client.queue_responses( + make_responses_response(tool_calls=[make_responses_function_call("echo", '{"text":"tokyo"}')]) + ) + + llm = LLM(model="openai:gpt-4o-mini", api_key="dummy", api_format="responses") + calls = llm.tool_calls( + "Call echo for tokyo", + tools=[echo], + tool_choice={"type": "function", "function": {"name": "echo"}}, + ) + + assert len(calls) == 1 + assert calls[0]["function"]["name"] == "echo" + assert client.calls[-1]["tool_choice"] == {"type": "function", "name": "echo"} + + +def test_non_stream_completion_splits_concatenated_tool_arguments(fake_anyllm) -> None: + client = fake_anyllm.ensure("openai") + client.queue_completion( + make_response( + tool_calls=[ + make_tool_call( + "echo", + '{"text":"tokyo"}{"text":"osaka"}', + call_id="call_1", + ) + ] + ) + ) + + llm = LLM(model="openai:gpt-4o-mini", api_key="dummy") + result = llm.run_tools("Call echo twice", tools=[echo]) + + assert result.kind == "tools" + assert [call["id"] for call in result.tool_calls] == ["call_1", "call_1__2"] + assert result.tool_results == ["TOKYO", "OSAKA"] + + +def test_stream_events_splits_concatenated_tool_arguments(fake_anyllm) -> None: + client = fake_anyllm.ensure("openai") + client.queue_completion( + iter([ + make_chunk( + tool_calls=[ + make_tool_call( + "echo", + '{"text":"tokyo"}{"text":"osaka"}', + call_id="call_1", + ) + ], + usage={"total_tokens": 8}, + ) + ]) + ) + + llm = LLM(model="openai:gpt-4o-mini", api_key="dummy") + events = list(llm.stream_events("Call echo twice", tools=[echo])) + + tool_calls = [event.data["call"] for event in events if event.kind == "tool_call"] + assert [call["id"] for call in tool_calls] == ["call_1", "call_1__2"] + assert [call["function"]["arguments"] for call in tool_calls] == ['{"text":"tokyo"}', '{"text":"osaka"}'] def test_split_messages_for_responses() -> None: @@ -59,3 +228,194 @@ def test_split_messages_for_responses() -> None: {"type": "function_call", "name": "echo", "arguments": '{"text":"hi"}', "call_id": "call_1"}, {"type": "function_call_output", "call_id": "call_1", "output": '{"ok":true}'}, ] + + +def test_stream_uses_responses_and_collects_usage(fake_anyllm) -> None: + client = fake_anyllm.ensure("openrouter") + client.queue_responses( + iter([ + make_responses_text_delta("Hello"), + make_responses_text_delta(" world"), + make_responses_completed({"total_tokens": 7}), + ]) + ) + + llm = LLM(model="openrouter:openrouter/free", api_key="dummy", api_format="responses") + stream = llm.stream("Say hello") + text = "".join(list(stream)) + + assert text == "Hello world" + assert stream.error is None + assert stream.usage == {"total_tokens": 7} + + +def test_stream_events_supports_responses_tool_events(fake_anyllm) -> None: + client = fake_anyllm.ensure("openrouter") + client.queue_responses( + iter([ + make_responses_text_delta("Checking "), + make_responses_function_delta('{"text":"to', item_id="call_rsp_1"), + make_responses_function_delta('kyo"}', item_id="call_rsp_1"), + make_responses_function_done("echo", '{"text":"tokyo"}', item_id="call_rsp_1"), + make_responses_completed({"total_tokens": 12}), + ]) + ) + + llm = LLM(model="openrouter:openrouter/free", api_key="dummy", api_format="responses") + stream = llm.stream_events("Call echo for tokyo", tools=[echo]) + events = list(stream) + + kinds = [event.kind for event in events] + assert "tool_call" in kinds + assert "tool_result" in kinds + assert "usage" in kinds + assert kinds[-1] == "final" + + +def test_stream_events_parity_between_completion_and_responses(fake_anyllm) -> None: + completion_client = fake_anyllm.ensure("openai") + completion_client.queue_completion(iter(_completion_stream_event_items())) + + responses_client = fake_anyllm.ensure("openrouter") + responses_client.queue_responses( + iter([ + make_responses_text_delta("Checking "), + make_responses_output_item_added(item_id="fc_1", call_id="call_1", name="echo", arguments=""), + make_responses_function_delta('{"text":"to', item_id="fc_1"), + make_responses_function_delta('kyo"}', item_id="fc_1"), + make_responses_output_item_done( + item_id="fc_1", + call_id="call_1", + name="echo", + arguments='{"text":"tokyo"}', + ), + make_responses_completed({"total_tokens": 12}), + ]) + ) + + completion_llm = LLM(model="openai:gpt-4o-mini", api_key="dummy") + responses_llm = LLM(model="openrouter:openrouter/free", api_key="dummy", api_format="responses") + + completion_events = list(completion_llm.stream_events("Call echo for tokyo", tools=[echo])) + responses_events = list(responses_llm.stream_events("Call echo for tokyo", tools=[echo])) + + assert _compact_stream_events(completion_events) == _compact_stream_events(responses_events) + + +def test_stream_events_parity_between_completion_and_messages(fake_anyllm) -> None: + completion_client = fake_anyllm.ensure("openai") + completion_client.queue_completion(iter(_completion_stream_event_items())) + + messages_client = fake_anyllm.ensure("openrouter") + messages_client.queue_completion(iter(_completion_stream_event_items())) + + completion_llm = LLM(model="openai:gpt-4o-mini", api_key="dummy") + messages_llm = LLM(model="openrouter:anthropic/claude-3.5-haiku", api_key="dummy", api_format="messages") + + completion_events = list(completion_llm.stream_events("Call echo for tokyo", tools=[echo])) + messages_events = list(messages_llm.stream_events("Call echo for tokyo", tools=[echo])) + + assert _compact_stream_events(completion_events) == _compact_stream_events(messages_events) + + +def test_non_stream_responses_tool_calls_converts_tools_payload(fake_anyllm) -> None: + client = fake_anyllm.ensure("openrouter") + client.queue_responses( + make_responses_response(tool_calls=[make_responses_function_call("echo", '{"text":"tokyo"}')]) + ) + + llm = LLM(model="openrouter:openrouter/free", api_key="dummy", api_format="responses") + calls = llm.tool_calls("Call echo for tokyo", tools=[echo]) + + assert len(calls) == 1 + sent_tools = client.calls[-1]["tools"] + assert sent_tools[0]["type"] == "function" + assert sent_tools[0]["name"] == "echo" + assert "function" not in sent_tools[0] + + +def test_chat_reasoning_effort_for_responses_is_mapped(fake_anyllm) -> None: + client = fake_anyllm.ensure("openrouter") + client.queue_responses(make_responses_response(text="ready")) + + llm = LLM(model="openrouter:openrouter/free", api_key="dummy", api_format="responses") + assert llm.chat("Reply with ready", reasoning_effort="low") == "ready" + + call = client.calls[-1] + assert call["responses"] is True + assert call.get("reasoning") == {"effort": "low"} + assert "reasoning_effort" not in call + + +def test_completion_preserves_extra_headers(fake_anyllm) -> None: + client = fake_anyllm.ensure("openai") + client.queue_completion(make_response(text="hello")) + + llm = LLM(model="openai:gpt-4o-mini", api_key="dummy") + assert llm.chat("Say hello", extra_headers={"X-Title": "Republic"}) == "hello" + + call = client.calls[-1] + assert call.get("extra_headers") == {"X-Title": "Republic"} + + +def test_messages_preserves_extra_headers(fake_anyllm) -> None: + client = fake_anyllm.ensure("openrouter") + client.queue_completion(make_response(text="hello")) + + llm = LLM(model="openrouter:anthropic/claude-3.5-haiku", api_key="dummy", api_format="messages") + assert llm.chat("Say hello", extra_headers={"X-Title": "Republic"}) == "hello" + + call = client.calls[-1] + assert call.get("extra_headers") == {"X-Title": "Republic"} + + +def test_responses_drops_extra_headers(fake_anyllm) -> None: + client = fake_anyllm.ensure("openrouter") + client.queue_responses(make_responses_response(text="hello")) + + llm = LLM(model="openrouter:openrouter/free", api_key="dummy", api_format="responses") + assert llm.chat("Say hello", extra_headers={"X-Title": "Republic"}) == "hello" + + call = client.calls[-1] + assert "extra_headers" not in call + + +def test_stream_completion_defaults_include_usage(fake_anyllm) -> None: + client = fake_anyllm.ensure("openai") + client.queue_completion( + iter([ + make_chunk(text="hello"), + make_chunk(text=" world", usage={"total_tokens": 7}), + ]) + ) + + llm = LLM(model="openai:gpt-4o-mini", api_key="dummy") + stream = llm.stream("Say hello") + assert "".join(list(stream)) == "hello world" + assert stream.usage == {"total_tokens": 7} + + assert client.calls[-1].get("stream_options") == {"include_usage": True} + + +def test_openai_completion_uses_max_completion_tokens(fake_anyllm) -> None: + client = fake_anyllm.ensure("openai") + client.queue_completion(make_response(text="hello")) + + llm = LLM(model="openai:gpt-4o-mini", api_key="dummy") + assert llm.chat("Say hello", max_tokens=11) == "hello" + + call = client.calls[-1] + assert call.get("max_completion_tokens") == 11 + assert "max_tokens" not in call + + +def test_non_openai_completion_uses_max_tokens(fake_anyllm) -> None: + client = fake_anyllm.ensure("anthropic") + client.queue_completion(make_response(text="hello")) + + llm = LLM(model="anthropic:claude-3-5-haiku-latest", api_key="dummy") + assert llm.chat("Say hello", max_tokens=11) == "hello" + + call = client.calls[-1] + assert call.get("max_tokens") == 11 + assert "max_completion_tokens" not in call