From d8c3349b8ea87511496e783fff4d17f5bbb4bbd7 Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Sun, 1 Mar 2026 17:48:22 +0000 Subject: [PATCH 01/14] fix: restore responses parity across main paths - re-enable responses path for streaming when use_responses is enabled\n- normalize responses streaming tool-call events (arguments + output_item)\n- preserve call_id mapping and support usage from in_progress/failed/incomplete events\n- convert tool schemas for non-stream responses API calls\n- add sync/async parity matrix tests for use_responses=false/true and refactor test fixtures for maintainability --- src/republic/clients/chat.py | 221 ++++++++++++---- src/republic/core/execution.py | 28 +- tests/fakes.py | 48 +++- tests/test_responses_handling.py | 441 ++++++++++++++++++++++++++++++- 4 files changed, 679 insertions(+), 59 deletions(-) diff --git a/src/republic/clients/chat.py b/src/republic/clients/chat.py index 58b0f19..8c045ea 100644 --- a/src/republic/clients/chat.py +++ b/src/republic/clients/chat.py @@ -6,6 +6,7 @@ from collections.abc import AsyncIterator, Callable, Iterator from dataclasses import dataclass from functools import partial +from types import SimpleNamespace from typing import Any from republic.core.errors import ErrorKind, RepublicError @@ -29,6 +30,12 @@ MessageInput = dict[str, Any] +def _field(data: Any, key: str, default: Any = None) -> Any: + if isinstance(data, dict): + return data.get(key, default) + return getattr(data, key, default) + + @dataclass(frozen=True) class PreparedChat: payload: list[dict[str, Any]] @@ -102,8 +109,8 @@ def _resolve_key_by_index(self, tool_call: Any, index: Any, position: int) -> ob return index_key position_key = self._key_at_position(position) - func = getattr(tool_call, "function", None) - tool_name = getattr(func, "name", None) if func is not None else None + func = _field(tool_call, "function") + tool_name = _field(func, "name") if func is not None else None if (tool_name is None or tool_name == "") and position_key is not None and position_key in self._calls: self._index_to_key[index] = position_key @@ -122,8 +129,8 @@ def _resolve_key_by_index(self, tool_call: Any, index: Any, position: int) -> ob return index_key def _resolve_key(self, tool_call: Any, position: int) -> object: - call_id = getattr(tool_call, "id", None) - index = getattr(tool_call, "index", None) + call_id = _field(tool_call, "id") + index = _field(tool_call, "index") if call_id is not None: return self._resolve_key_by_id(call_id, index, position) @@ -138,6 +145,27 @@ def _resolve_key(self, tool_call: Any, position: int) -> object: return position_key return ("position", position) + @staticmethod + def _merge_arguments( + entry: dict[str, Any], + *, + arguments: Any, + arguments_complete: bool, + ) -> None: + if arguments is None: + return + if not isinstance(arguments, str): + entry["function"]["arguments"] = arguments + return + if arguments_complete: + entry["function"]["arguments"] = arguments + return + + existing = entry["function"].get("arguments", "") + if not isinstance(existing, str): + existing = "" + entry["function"]["arguments"] = existing + arguments + def add_deltas(self, tool_calls: list[Any]) -> None: for position, tool_call in enumerate(tool_calls): key = self._resolve_key(tool_call, position) @@ -145,21 +173,24 @@ def add_deltas(self, tool_calls: list[Any]) -> None: self._order.append(key) self._calls[key] = {"function": {"name": "", "arguments": ""}} entry = self._calls[key] - call_id = getattr(tool_call, "id", None) + call_id = 
_field(tool_call, "id") if call_id: entry["id"] = call_id - call_type = getattr(tool_call, "type", None) + call_type = _field(tool_call, "type") if call_type: entry["type"] = call_type - func = getattr(tool_call, "function", None) + func = _field(tool_call, "function") if func is None: continue - name = getattr(func, "name", None) + name = _field(func, "name") if name: entry["function"]["name"] = name - arguments = getattr(func, "arguments", None) - if arguments: - entry["function"]["arguments"] = entry["function"].get("arguments", "") + arguments + arguments = _field(func, "arguments") + self._merge_arguments( + entry, + arguments=arguments, + arguments_complete=bool(_field(tool_call, "arguments_complete", False)), + ) def finalize(self) -> list[dict[str, Any]]: return [self._calls[key] for key in self._order] @@ -184,6 +215,15 @@ def __init__( def default_context(self) -> TapeContext: return self._tape.default_context + @staticmethod + def _is_non_stream_response(response: Any) -> bool: + return ( + isinstance(response, str) + or _field(response, "choices") is not None + or _field(response, "output") is not None + or _field(response, "output_text") is not None + ) + def _validate_chat_input( self, *, @@ -1374,7 +1414,7 @@ def _build_text_stream( model_id: str, attempt: int, ) -> TextStream: - if hasattr(response, "choices"): + if self._is_non_stream_response(response): text = self._extract_text(response) tool_calls = self._extract_tool_calls(response) state = StreamState() @@ -1438,7 +1478,7 @@ async def _build_async_text_stream( model_id: str, attempt: int, ) -> AsyncTextStream: - if hasattr(response, "choices"): + if self._is_non_stream_response(response): text = self._extract_text(response) tool_calls = self._extract_tool_calls(response) state = StreamState() @@ -1502,25 +1542,81 @@ async def _iterator() -> AsyncIterator[str]: def _chunk_has_tool_calls(chunk: Any) -> bool: return bool(ChatClient._extract_chunk_tool_call_deltas(chunk)) + @staticmethod + def _responses_tool_delta_from_args_event(chunk: Any, event_type: str) -> list[Any]: + item_id = _field(chunk, "item_id") + if not item_id: + return [] + arguments = _field(chunk, "delta") + if event_type == "response.function_call_arguments.done": + arguments = _field(chunk, "arguments") + if not isinstance(arguments, str): + return [] + call_id = _field(chunk, "call_id") + payload: dict[str, Any] = { + "index": item_id, + "type": "function", + "function": SimpleNamespace(name=_field(chunk, "name") or "", arguments=arguments), + "arguments_complete": event_type == "response.function_call_arguments.done", + } + if call_id: + payload["id"] = call_id + return [SimpleNamespace(**payload)] + + @staticmethod + def _responses_tool_delta_from_output_item_event(chunk: Any, event_type: str) -> list[Any]: + item = _field(chunk, "item") + if _field(item, "type") != "function_call": + return [] + item_id = _field(item, "id") + call_id = _field(item, "call_id") or item_id + if not call_id: + return [] + arguments = _field(item, "arguments") + if not isinstance(arguments, str): + arguments = "" + return [ + SimpleNamespace( + id=call_id, + index=item_id or call_id, + type="function", + function=SimpleNamespace(name=_field(item, "name") or "", arguments=arguments), + arguments_complete=event_type == "response.output_item.done", + ) + ] + @staticmethod def _extract_chunk_tool_call_deltas(chunk: Any) -> list[Any]: - choices = getattr(chunk, "choices", None) + event_type = _field(chunk, "type") + if event_type in 
{"response.function_call_arguments.delta", "response.function_call_arguments.done"}: + return ChatClient._responses_tool_delta_from_args_event(chunk, event_type) + if event_type in {"response.output_item.added", "response.output_item.done"}: + return ChatClient._responses_tool_delta_from_output_item_event(chunk, event_type) + + choices = _field(chunk, "choices") if not choices: return [] - delta = getattr(choices[0], "delta", None) + delta = _field(choices[0], "delta") if delta is None: return [] - return getattr(delta, "tool_calls", None) or [] + return _field(delta, "tool_calls") or [] @staticmethod def _extract_chunk_text(chunk: Any) -> str: - choices = getattr(chunk, "choices", None) + event_type = _field(chunk, "type") + if event_type == "response.output_text.delta": + delta = _field(chunk, "delta") + if isinstance(delta, str): + return delta + return "" + + choices = _field(chunk, "choices") if not choices: return "" - delta = getattr(choices[0], "delta", None) + delta = _field(choices[0], "delta") if delta is None: return "" - return getattr(delta, "content", "") or "" + return _field(delta, "content", "") or "" def _build_event_stream( self, @@ -1530,7 +1626,7 @@ def _build_event_stream( model_id: str, attempt: int, ) -> StreamEvents: - if hasattr(response, "choices"): + if self._is_non_stream_response(response): return self._build_event_stream_from_response( prepared, response, @@ -1603,7 +1699,7 @@ def _build_async_event_stream( model_id: str, attempt: int, ) -> AsyncStreamEvents: - if hasattr(response, "choices"): + if self._is_non_stream_response(response): return self._build_async_event_stream_from_response( prepared, response, @@ -1830,50 +1926,62 @@ def _make_tool_context(prepared: PreparedChat, provider_name: str, model_id: str state={} if prepared.context is None else prepared.context.state, ) + @staticmethod + def _extract_text_from_responses_output(output: Any) -> str: + if not isinstance(output, list): + return "" + parts: list[str] = [] + for item in output: + if _field(item, "type") != "message": + continue + content = _field(item, "content") or [] + for entry in content: + if _field(entry, "type") == "output_text": + text = _field(entry, "text") + if text: + parts.append(text) + return "".join(parts) + @staticmethod def _extract_text(response: Any) -> str: if isinstance(response, str): return response - output = getattr(response, "output", None) - if output: - parts: list[str] = [] - for item in output: - if getattr(item, "type", None) != "message": - continue - content = getattr(item, "content", None) or [] - for entry in content: - if getattr(entry, "type", None) == "output_text": - text = getattr(entry, "text", None) - if text: - parts.append(text) - return "".join(parts) - choices = getattr(response, "choices", None) + output_text = _field(response, "output_text") + if isinstance(output_text, str): + return output_text + output = _field(response, "output") + text_from_output = ChatClient._extract_text_from_responses_output(output) + if text_from_output: + return text_from_output + choices = _field(response, "choices") if not choices: return "" - message = getattr(choices[0], "message", None) + message = _field(choices[0], "message") if message is None: return "" - return getattr(message, "content", "") or "" + return _field(message, "content", "") or "" @staticmethod def _extract_tool_calls(response: Any) -> list[dict[str, Any]]: - output = getattr(response, "output", None) - if output: + output = _field(response, "output") + if output is not None: return 
ChatClient._extract_responses_tool_calls(output) return ChatClient._extract_completion_tool_calls(response) @staticmethod - def _extract_responses_tool_calls(output: list[Any]) -> list[dict[str, Any]]: + def _extract_responses_tool_calls(output: Any) -> list[dict[str, Any]]: + if not isinstance(output, list): + return [] calls: list[dict[str, Any]] = [] for item in output: - if getattr(item, "type", None) != "function_call": + if _field(item, "type") != "function_call": continue - name = getattr(item, "name", None) - arguments = getattr(item, "arguments", None) + name = _field(item, "name") + arguments = _field(item, "arguments") if not name: continue entry: dict[str, Any] = {"function": {"name": name, "arguments": arguments or ""}} - call_id = getattr(item, "call_id", None) or getattr(item, "id", None) + call_id = _field(item, "call_id") or _field(item, "id") if call_id: entry["id"] = call_id entry["type"] = "function" @@ -1882,25 +1990,28 @@ def _extract_responses_tool_calls(output: list[Any]) -> list[dict[str, Any]]: @staticmethod def _extract_completion_tool_calls(response: Any) -> list[dict[str, Any]]: - choices = getattr(response, "choices", None) + choices = _field(response, "choices") if not choices: return [] - message = getattr(choices[0], "message", None) + message = _field(choices[0], "message") if message is None: return [] - tool_calls = getattr(message, "tool_calls", None) or [] + tool_calls = _field(message, "tool_calls") or [] calls: list[dict[str, Any]] = [] for tool_call in tool_calls: + function = _field(tool_call, "function") + if function is None: + continue entry: dict[str, Any] = { "function": { - "name": tool_call.function.name, - "arguments": tool_call.function.arguments, + "name": _field(function, "name"), + "arguments": _field(function, "arguments"), } } - call_id = getattr(tool_call, "id", None) + call_id = _field(tool_call, "id") if call_id: entry["id"] = call_id - call_type = getattr(tool_call, "type", None) + call_type = _field(tool_call, "type") if call_type: entry["type"] = call_type calls.append(entry) @@ -1908,7 +2019,11 @@ def _extract_completion_tool_calls(response: Any) -> list[dict[str, Any]]: @staticmethod def _extract_usage(response: Any) -> dict[str, Any] | None: - usage = getattr(response, "usage", None) + event_type = _field(response, "type") + if event_type in {"response.completed", "response.in_progress", "response.failed", "response.incomplete"}: + usage = _field(_field(response, "response"), "usage") + else: + usage = _field(response, "usage") if usage is None: return None if hasattr(usage, "model_dump"): @@ -1917,7 +2032,7 @@ def _extract_usage(response: Any) -> dict[str, Any] | None: return dict(usage) data: dict[str, Any] = {} for field in ("input_tokens", "output_tokens", "total_tokens", "requests"): - value = getattr(usage, field, None) + value = _field(usage, field) if value is not None: data[field] = value return data or None diff --git a/src/republic/core/execution.py b/src/republic/core/execution.py index ccc3a73..094bdf6 100644 --- a/src/republic/core/execution.py +++ b/src/republic/core/execution.py @@ -352,8 +352,30 @@ def _decide_responses_kwargs(self, max_tokens: int | None, kwargs: dict[str, Any return clean_kwargs return {**clean_kwargs, "max_output_tokens": max_tokens} + @staticmethod + def _convert_tools_for_responses(tools_payload: list[dict[str, Any]] | None) -> list[dict[str, Any]] | None: + if not tools_payload: + return tools_payload + + converted_tools: list[dict[str, Any]] = [] + for tool in tools_payload: + function 
= tool.get("function") + if isinstance(function, dict): + converted: dict[str, Any] = { + "type": tool.get("type", "function"), + "name": function.get("name"), + "description": function.get("description", ""), + "parameters": function.get("parameters", {}), + } + if "strict" in function: + converted["strict"] = function["strict"] + converted_tools.append(converted) + continue + converted_tools.append(dict(tool)) + return converted_tools + def _should_use_responses(self, client: AnyLLM, *, stream: bool) -> bool: - return not stream and self._use_responses and bool(getattr(client, "SUPPORTS_RESPONSES", False)) + return self._use_responses and bool(getattr(client, "SUPPORTS_RESPONSES", False)) def _call_client_sync( self, @@ -373,7 +395,7 @@ def _call_client_sync( return client.responses( model=model_id, input_data=input_items, - tools=tools_payload, + tools=self._convert_tools_for_responses(tools_payload), stream=stream, instructions=instructions, **self._decide_responses_kwargs(max_tokens, kwargs), @@ -405,7 +427,7 @@ async def _call_client_async( return await client.aresponses( model=model_id, input_data=input_items, - tools=tools_payload, + tools=self._convert_tools_for_responses(tools_payload), stream=stream, instructions=instructions, **self._decide_responses_kwargs(max_tokens, kwargs), diff --git a/tests/fakes.py b/tests/fakes.py index 0dbaef9..be13c8d 100644 --- a/tests/fakes.py +++ b/tests/fakes.py @@ -142,10 +142,56 @@ def make_responses_response( *, text: str = "", tool_calls: list[Any] | None = None, + usage: dict[str, Any] | None = None, ) -> Any: output: list[Any] = [] if text: output.append(make_responses_message(text)) if tool_calls: output.extend(tool_calls) - return SimpleNamespace(output=output) + return SimpleNamespace(output=output, usage=usage) + + +def make_responses_text_delta(delta: str) -> Any: + return SimpleNamespace(type="response.output_text.delta", delta=delta) + + +def make_responses_function_delta(delta: str, *, item_id: str = "call_1") -> Any: + return SimpleNamespace(type="response.function_call_arguments.delta", delta=delta, item_id=item_id, output_index=0) + + +def make_responses_function_done(name: str, arguments: str, *, item_id: str = "call_1") -> Any: + return SimpleNamespace( + type="response.function_call_arguments.done", + name=name, + arguments=arguments, + item_id=item_id, + output_index=0, + ) + + +def make_responses_completed(usage: dict[str, Any] | None = None) -> Any: + response = SimpleNamespace(usage=usage) + return SimpleNamespace(type="response.completed", response=response) + + +def make_responses_output_item_added( + *, + item_id: str = "fc_1", + call_id: str = "call_1", + name: str = "echo", + arguments: str = "", +) -> Any: + item = SimpleNamespace(type="function_call", id=item_id, call_id=call_id, name=name, arguments=arguments) + return SimpleNamespace(type="response.output_item.added", item=item) + + +def make_responses_output_item_done( + *, + item_id: str = "fc_1", + call_id: str = "call_1", + name: str = "echo", + arguments: str = '{"text":"tokyo"}', +) -> Any: + item = SimpleNamespace(type="function_call", id=item_id, call_id=call_id, name=name, arguments=arguments) + return SimpleNamespace(type="response.output_item.done", item=item) diff --git a/tests/test_responses_handling.py b/tests/test_responses_handling.py index 0b4993c..70cf47e 100644 --- a/tests/test_responses_handling.py +++ b/tests/test_responses_handling.py @@ -1,10 +1,142 @@ from __future__ import annotations -from republic import LLM +from typing import Any + 
+import pytest + +from republic import LLM, tool from republic.clients.chat import ChatClient from republic.core.execution import LLMCore -from .fakes import make_responses_function_call, make_responses_response +from .fakes import ( + make_chunk, + make_response, + make_responses_completed, + make_responses_function_call, + make_responses_function_delta, + make_responses_function_done, + make_responses_output_item_added, + make_responses_output_item_done, + make_responses_response, + make_responses_text_delta, + make_tool_call, +) + + +@tool +def echo(text: str) -> str: + return text.upper() + + +def _compact_stream_events(events: list[Any]) -> list[tuple[str, Any]]: + compact: list[tuple[str, Any]] = [] + for event in events: + if event.kind == "text": + compact.append(("text", event.data["delta"])) + elif event.kind == "tool_call": + call = event.data["call"] + compact.append(("tool_call", (call.get("id"), call["function"]["name"], call["function"]["arguments"]))) + elif event.kind == "tool_result": + compact.append(("tool_result", event.data["result"])) + elif event.kind == "usage": + compact.append(("usage", event.data)) + elif event.kind == "error": + compact.append(("error", event.data)) + elif event.kind == "final": + final = event.data + compact.append(( + "final", + { + "text": final.get("text"), + "tool_calls": [ + (call.get("id"), call["function"]["name"], call["function"]["arguments"]) + for call in (final.get("tool_calls") or []) + ], + "tool_results": final.get("tool_results"), + "usage": final.get("usage"), + "error": final.get("error"), + }, + )) + return compact + + +def _as_async_iter(items: list[Any]) -> Any: + async def _generator() -> Any: + for item in items: + yield item + + return _generator() + + +def _responses_stream_text_items() -> list[Any]: + return [ + make_responses_text_delta("hello"), + make_responses_text_delta(" world"), + make_responses_completed({"total_tokens": 7}), + ] + + +def _responses_stream_event_items() -> list[Any]: + return [ + make_responses_text_delta("Checking "), + make_responses_output_item_added(item_id="fc_1", call_id="call_1", name="echo", arguments=""), + make_responses_function_delta('{"text":"to', item_id="fc_1"), + make_responses_function_delta('kyo"}', item_id="fc_1"), + make_responses_output_item_done( + item_id="fc_1", + call_id="call_1", + name="echo", + arguments='{"text":"tokyo"}', + ), + make_responses_completed({"total_tokens": 12}), + ] + + +def _completion_stream_text_items() -> list[Any]: + return [ + make_chunk(text="hello"), + make_chunk(text=" world", usage={"total_tokens": 7}), + ] + + +def _completion_stream_event_items() -> list[Any]: + return [ + make_chunk(text="Checking "), + make_chunk(tool_calls=[make_tool_call("echo", '{"text":"to', call_id="call_1")]), + make_chunk( + tool_calls=[make_tool_call("echo", 'kyo"}', call_id="call_1")], + usage={"total_tokens": 12}, + ), + ] + + +def _main_path_payloads(*, use_responses: bool, async_mode: bool) -> list[Any]: + wrap_stream = _as_async_iter if async_mode else iter + if use_responses: + return [ + make_responses_response(text="ready"), + make_responses_response(tool_calls=[make_responses_function_call("echo", '{"text":"tokyo"}')]), + make_responses_response(tool_calls=[make_responses_function_call("echo", '{"text":"tokyo"}')]), + wrap_stream(_responses_stream_text_items()), + wrap_stream(_responses_stream_event_items()), + ] + + return [ + make_response(text="ready"), + make_response(tool_calls=[make_tool_call("echo", '{"text":"tokyo"}')]), + 
make_response(tool_calls=[make_tool_call("echo", '{"text":"tokyo"}')]), + wrap_stream(_completion_stream_text_items()), + wrap_stream(_completion_stream_event_items()), + ] + + +def _queue_main_path_fixtures(client: Any, *, use_responses: bool, async_mode: bool) -> None: + payloads = _main_path_payloads(use_responses=use_responses, async_mode=async_mode) + if use_responses: + queue = client.queue_aresponses if async_mode else client.queue_responses + else: + queue = client.queue_acompletion if async_mode else client.queue_completion + queue(*payloads) def test_llm_use_responses_calls_responses(fake_anyllm) -> None: @@ -59,3 +191,308 @@ def test_split_messages_for_responses() -> None: {"type": "function_call", "name": "echo", "arguments": '{"text":"hi"}', "call_id": "call_1"}, {"type": "function_call_output", "call_id": "call_1", "output": '{"ok":true}'}, ] + + +def test_stream_uses_responses_and_collects_usage(fake_anyllm) -> None: + client = fake_anyllm.ensure("openrouter") + client.queue_responses( + iter([ + make_responses_text_delta("Hello"), + make_responses_text_delta(" world"), + make_responses_completed({"total_tokens": 7}), + ]) + ) + + llm = LLM(model="openrouter:openrouter/free", api_key="dummy", use_responses=True) + stream = llm.stream("Say hello") + text = "".join(list(stream)) + + assert text == "Hello world" + assert stream.error is None + assert stream.usage == {"total_tokens": 7} + assert client.calls[-1]["responses"] is True + assert client.calls[-1]["stream"] is True + + +def test_stream_events_supports_responses_tool_events(fake_anyllm) -> None: + client = fake_anyllm.ensure("openrouter") + client.queue_responses( + iter([ + make_responses_text_delta("Checking "), + make_responses_function_delta('{"text":"to', item_id="call_rsp_1"), + make_responses_function_delta('kyo"}', item_id="call_rsp_1"), + make_responses_function_done("echo", '{"text":"tokyo"}', item_id="call_rsp_1"), + make_responses_completed({"total_tokens": 12}), + ]) + ) + + llm = LLM(model="openrouter:openrouter/free", api_key="dummy", use_responses=True) + stream = llm.stream_events("Call echo for tokyo", tools=[echo]) + events = list(stream) + + kinds = [event.kind for event in events] + assert "text" in kinds + assert "tool_call" in kinds + assert "tool_result" in kinds + assert "usage" in kinds + assert kinds[-1] == "final" + + tool_calls = [event for event in events if event.kind == "tool_call"] + assert len(tool_calls) == 1 + assert tool_calls[0].data["call"]["function"]["name"] == "echo" + assert tool_calls[0].data["call"]["function"]["arguments"] == '{"text":"tokyo"}' + + tool_results = [event for event in events if event.kind == "tool_result"] + assert len(tool_results) == 1 + assert tool_results[0].data["result"] == "TOKYO" + assert stream.error is None + assert stream.usage == {"total_tokens": 12} + + +def test_stream_and_events_support_responses_dict_events(fake_anyllm) -> None: + client = fake_anyllm.ensure("openrouter") + client.queue_responses( + iter([ + {"type": "response.output_text.delta", "delta": "Checking "}, + {"type": "response.function_call_arguments.delta", "item_id": "call_d1", "delta": '{"text":"to'}, + { + "type": "response.function_call_arguments.done", + "item_id": "call_d1", + "name": "echo", + "arguments": '{"text":"tokyo"}', + }, + {"type": "response.completed", "response": {"usage": {"total_tokens": 5}}}, + ]) + ) + client.queue_responses(make_responses_response(text="ready", usage={"total_tokens": 3})) + + llm = LLM(model="openrouter:openrouter/free", api_key="dummy", 
use_responses=True) + events_stream = llm.stream_events("Call echo for tokyo", tools=[echo]) + events = list(events_stream) + tool_call = next(event for event in events if event.kind == "tool_call") + assert tool_call.data["call"]["function"]["arguments"] == '{"text":"tokyo"}' + assert events_stream.usage == {"total_tokens": 5} + + text_stream = llm.stream("Reply with ready") + text = "".join(list(text_stream)) + assert text == "ready" + assert text_stream.usage == {"total_tokens": 3} + + +def test_stream_events_parity_between_completion_and_responses(fake_anyllm) -> None: + completion_client = fake_anyllm.ensure("openai") + completion_client.queue_completion( + iter([ + make_chunk(text="Checking "), + make_chunk(tool_calls=[make_tool_call("echo", '{"text":"to', call_id="call_1")]), + make_chunk( + tool_calls=[make_tool_call("echo", 'kyo"}', call_id="call_1")], + usage={"total_tokens": 12}, + ), + ]) + ) + + responses_client = fake_anyllm.ensure("openrouter") + responses_client.queue_responses( + iter([ + make_responses_text_delta("Checking "), + make_responses_output_item_added(item_id="fc_1", call_id="call_1", name="echo", arguments=""), + make_responses_function_delta('{"text":"to', item_id="fc_1"), + make_responses_function_delta('kyo"}', item_id="fc_1"), + make_responses_function_done("echo", '{"text":"tokyo"}', item_id="fc_1"), + make_responses_completed({"total_tokens": 12}), + ]) + ) + + completion_llm = LLM(model="openai:gpt-4o-mini", api_key="dummy") + responses_llm = LLM(model="openrouter:openrouter/free", api_key="dummy", use_responses=True) + + completion_events = list(completion_llm.stream_events("Call echo for tokyo", tools=[echo])) + responses_events = list(responses_llm.stream_events("Call echo for tokyo", tools=[echo])) + assert completion_client.calls[-1].get("responses") is None + assert responses_client.calls[-1].get("responses") is True + + assert _compact_stream_events(completion_events) == _compact_stream_events(responses_events) + + +def test_stream_events_responses_output_item_events_keep_call_id(fake_anyllm) -> None: + client = fake_anyllm.ensure("openrouter") + client.queue_responses( + iter([ + make_responses_output_item_added(item_id="fc_123", call_id="call_abc", name="echo"), + make_responses_function_delta('{"text":"to', item_id="fc_123"), + make_responses_function_delta('kyo"}', item_id="fc_123"), + make_responses_output_item_done( + item_id="fc_123", + call_id="call_abc", + name="echo", + arguments='{"text":"tokyo"}', + ), + make_responses_completed({"total_tokens": 6}), + ]) + ) + + llm = LLM(model="openrouter:openrouter/free", api_key="dummy", use_responses=True) + events = list(llm.stream_events("Call echo for tokyo", tools=[echo])) + + tool_call = next(event for event in events if event.kind == "tool_call").data["call"] + assert tool_call["id"] == "call_abc" + assert tool_call["function"]["name"] == "echo" + assert tool_call["function"]["arguments"] == '{"text":"tokyo"}' + + +def test_stream_usage_accepts_responses_in_progress_usage(fake_anyllm) -> None: + client = fake_anyllm.ensure("openrouter") + client.queue_responses( + iter([ + make_responses_text_delta("ok"), + {"type": "response.in_progress", "response": {"usage": {"total_tokens": 2}}}, + ]) + ) + + llm = LLM(model="openrouter:openrouter/free", api_key="dummy", use_responses=True) + stream = llm.stream("Reply with ok") + assert "".join(list(stream)) == "ok" + assert stream.usage == {"total_tokens": 2} + + +def test_non_stream_responses_tool_calls_converts_tools_payload(fake_anyllm) -> None: + 
client = fake_anyllm.ensure("openrouter") + client.queue_responses( + make_responses_response(tool_calls=[make_responses_function_call("echo", '{"text":"tokyo"}')]) + ) + + llm = LLM(model="openrouter:openrouter/free", api_key="dummy", use_responses=True) + calls = llm.tool_calls("Call echo for tokyo", tools=[echo]) + + assert len(calls) == 1 + assert calls[0]["function"]["name"] == "echo" + sent_tools = client.calls[-1]["tools"] + assert sent_tools[0]["type"] == "function" + assert sent_tools[0]["name"] == "echo" + assert sent_tools[0]["description"] == "" + assert sent_tools[0]["parameters"]["type"] == "object" + assert "function" not in sent_tools[0] + + +def test_non_stream_responses_run_tools_uses_converted_tools(fake_anyllm) -> None: + client = fake_anyllm.ensure("openrouter") + client.queue_responses( + make_responses_response(tool_calls=[make_responses_function_call("echo", '{"text":"tokyo"}')]) + ) + + llm = LLM(model="openrouter:openrouter/free", api_key="dummy", use_responses=True) + result = llm.run_tools("Call echo for tokyo", tools=[echo]) + + assert result.kind == "tools" + assert result.tool_results == ["TOKYO"] + sent_tools = client.calls[-1]["tools"] + assert sent_tools[0]["type"] == "function" + assert sent_tools[0]["name"] == "echo" + assert "function" not in sent_tools[0] + + +def test_non_stream_chat_parity_between_completion_and_responses(fake_anyllm) -> None: + completion_client = fake_anyllm.ensure("openai") + completion_client.queue_completion(make_response(text="ready")) + + responses_client = fake_anyllm.ensure("openrouter") + responses_client.queue_responses(make_responses_response(text="ready")) + + completion_llm = LLM(model="openai:gpt-4o-mini", api_key="dummy") + responses_llm = LLM(model="openrouter:openrouter/free", api_key="dummy", use_responses=True) + + assert completion_llm.chat("Reply with ready") == "ready" + assert responses_llm.chat("Reply with ready") == "ready" + + +def test_non_stream_run_tools_parity_between_completion_and_responses(fake_anyllm) -> None: + completion_client = fake_anyllm.ensure("openai") + completion_client.queue_completion(make_response(tool_calls=[make_tool_call("echo", '{"text":"tokyo"}')])) + + responses_client = fake_anyllm.ensure("openrouter") + responses_client.queue_responses( + make_responses_response(tool_calls=[make_responses_function_call("echo", '{"text":"tokyo"}')]) + ) + + completion_llm = LLM(model="openai:gpt-4o-mini", api_key="dummy") + responses_llm = LLM(model="openrouter:openrouter/free", api_key="dummy", use_responses=True) + + completion_result = completion_llm.run_tools("Call echo for tokyo", tools=[echo]) + responses_result = responses_llm.run_tools("Call echo for tokyo", tools=[echo]) + + assert completion_result.kind == responses_result.kind == "tools" + assert completion_result.tool_results == responses_result.tool_results == ["TOKYO"] + + +@pytest.mark.parametrize("use_responses", [False, True]) +def test_sync_main_paths_with_mode_switch(fake_anyllm, use_responses: bool) -> None: + client = fake_anyllm.ensure("openai") + _queue_main_path_fixtures(client, use_responses=use_responses, async_mode=False) + llm = LLM(model="openai:gpt-4o-mini", api_key="dummy", use_responses=use_responses) + + assert llm.chat("Reply with ready") == "ready" + + calls = llm.tool_calls("Call echo for tokyo", tools=[echo]) + assert len(calls) == 1 + assert calls[0]["function"]["name"] == "echo" + + run_result = llm.run_tools("Call echo for tokyo", tools=[echo]) + assert run_result.kind == "tools" + assert 
run_result.tool_results == ["TOKYO"] + + text_stream = llm.stream("Say hello") + assert "".join(list(text_stream)) == "hello world" + assert text_stream.usage == {"total_tokens": 7} + + event_stream = llm.stream_events("Call echo for tokyo", tools=[echo]) + events = list(event_stream) + compact = _compact_stream_events(events) + assert compact[0] == ("text", "Checking ") + assert compact[1][0] == "tool_call" + assert compact[2] == ("tool_result", "TOKYO") + assert compact[3] == ("usage", {"total_tokens": 12}) + assert compact[-1][0] == "final" + assert event_stream.usage == {"total_tokens": 12} + + if use_responses: + assert all(call.get("responses") is True for call in client.calls) + else: + assert all(call.get("responses") is None for call in client.calls) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("use_responses", [False, True]) +async def test_async_main_paths_with_mode_switch(fake_anyllm, use_responses: bool) -> None: + client = fake_anyllm.ensure("openai") + _queue_main_path_fixtures(client, use_responses=use_responses, async_mode=True) + llm = LLM(model="openai:gpt-4o-mini", api_key="dummy", use_responses=use_responses) + + assert await llm.chat_async("Reply with ready") == "ready" + + calls = await llm.tool_calls_async("Call echo for tokyo", tools=[echo]) + assert len(calls) == 1 + assert calls[0]["function"]["name"] == "echo" + + run_result = await llm.run_tools_async("Call echo for tokyo", tools=[echo]) + assert run_result.kind == "tools" + assert run_result.tool_results == ["TOKYO"] + + text_stream = await llm.stream_async("Say hello") + assert "".join([part async for part in text_stream]) == "hello world" + assert text_stream.usage == {"total_tokens": 7} + + event_stream = await llm.stream_events_async("Call echo for tokyo", tools=[echo]) + events = [event async for event in event_stream] + compact = _compact_stream_events(events) + assert compact[0] == ("text", "Checking ") + assert compact[1][0] == "tool_call" + assert compact[2] == ("tool_result", "TOKYO") + assert compact[3] == ("usage", {"total_tokens": 12}) + assert compact[-1][0] == "final" + assert event_stream.usage == {"total_tokens": 12} + + if use_responses: + assert all(call.get("responses") is True for call in client.calls) + else: + assert all(call.get("responses") is None for call in client.calls) From de41fa9c9023b3d513383416179b02d12934e50c Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Sun, 1 Mar 2026 18:09:58 +0000 Subject: [PATCH 02/14] refactor: split chat parsing into completion and responses adapters --- src/republic/clients/chat.py | 192 +++------------------ src/republic/clients/parsing/__init__.py | 1 + src/republic/clients/parsing/common.py | 11 ++ src/republic/clients/parsing/completion.py | 69 ++++++++ src/republic/clients/parsing/responses.py | 145 ++++++++++++++++ 5 files changed, 246 insertions(+), 172 deletions(-) create mode 100644 src/republic/clients/parsing/__init__.py create mode 100644 src/republic/clients/parsing/common.py create mode 100644 src/republic/clients/parsing/completion.py create mode 100644 src/republic/clients/parsing/responses.py diff --git a/src/republic/clients/chat.py b/src/republic/clients/chat.py index 8c045ea..1a314b1 100644 --- a/src/republic/clients/chat.py +++ b/src/republic/clients/chat.py @@ -6,9 +6,11 @@ from collections.abc import AsyncIterator, Callable, Iterator from dataclasses import dataclass from functools import partial -from types import SimpleNamespace from typing import Any +from republic.clients.parsing import completion as 
completion_parser +from republic.clients.parsing import responses as responses_parser +from republic.clients.parsing.common import field as _field from republic.core.errors import ErrorKind, RepublicError from republic.core.execution import LLMCore from republic.core.results import ( @@ -30,12 +32,6 @@ MessageInput = dict[str, Any] -def _field(data: Any, key: str, default: Any = None) -> Any: - if isinstance(data, dict): - return data.get(key, default) - return getattr(data, key, default) - - @dataclass(frozen=True) class PreparedChat: payload: list[dict[str, Any]] @@ -217,12 +213,7 @@ def default_context(self) -> TapeContext: @staticmethod def _is_non_stream_response(response: Any) -> bool: - return ( - isinstance(response, str) - or _field(response, "choices") is not None - or _field(response, "output") is not None - or _field(response, "output_text") is not None - ) + return responses_parser.is_non_stream_response(response) def _validate_chat_input( self, @@ -1542,81 +1533,19 @@ async def _iterator() -> AsyncIterator[str]: def _chunk_has_tool_calls(chunk: Any) -> bool: return bool(ChatClient._extract_chunk_tool_call_deltas(chunk)) - @staticmethod - def _responses_tool_delta_from_args_event(chunk: Any, event_type: str) -> list[Any]: - item_id = _field(chunk, "item_id") - if not item_id: - return [] - arguments = _field(chunk, "delta") - if event_type == "response.function_call_arguments.done": - arguments = _field(chunk, "arguments") - if not isinstance(arguments, str): - return [] - call_id = _field(chunk, "call_id") - payload: dict[str, Any] = { - "index": item_id, - "type": "function", - "function": SimpleNamespace(name=_field(chunk, "name") or "", arguments=arguments), - "arguments_complete": event_type == "response.function_call_arguments.done", - } - if call_id: - payload["id"] = call_id - return [SimpleNamespace(**payload)] - - @staticmethod - def _responses_tool_delta_from_output_item_event(chunk: Any, event_type: str) -> list[Any]: - item = _field(chunk, "item") - if _field(item, "type") != "function_call": - return [] - item_id = _field(item, "id") - call_id = _field(item, "call_id") or item_id - if not call_id: - return [] - arguments = _field(item, "arguments") - if not isinstance(arguments, str): - arguments = "" - return [ - SimpleNamespace( - id=call_id, - index=item_id or call_id, - type="function", - function=SimpleNamespace(name=_field(item, "name") or "", arguments=arguments), - arguments_complete=event_type == "response.output_item.done", - ) - ] - @staticmethod def _extract_chunk_tool_call_deltas(chunk: Any) -> list[Any]: - event_type = _field(chunk, "type") - if event_type in {"response.function_call_arguments.delta", "response.function_call_arguments.done"}: - return ChatClient._responses_tool_delta_from_args_event(chunk, event_type) - if event_type in {"response.output_item.added", "response.output_item.done"}: - return ChatClient._responses_tool_delta_from_output_item_event(chunk, event_type) - - choices = _field(chunk, "choices") - if not choices: - return [] - delta = _field(choices[0], "delta") - if delta is None: - return [] - return _field(delta, "tool_calls") or [] + responses_deltas = responses_parser.extract_chunk_tool_call_deltas(chunk) + if responses_deltas: + return responses_deltas + return completion_parser.extract_chunk_tool_call_deltas(chunk) @staticmethod def _extract_chunk_text(chunk: Any) -> str: - event_type = _field(chunk, "type") - if event_type == "response.output_text.delta": - delta = _field(chunk, "delta") - if isinstance(delta, str): - 
return delta - return "" - - choices = _field(chunk, "choices") - if not choices: - return "" - delta = _field(choices[0], "delta") - if delta is None: - return "" - return _field(delta, "content", "") or "" + responses_text = responses_parser.extract_chunk_text(chunk) + if responses_text: + return responses_text + return completion_parser.extract_chunk_text(chunk) def _build_event_stream( self, @@ -1928,38 +1857,14 @@ def _make_tool_context(prepared: PreparedChat, provider_name: str, model_id: str @staticmethod def _extract_text_from_responses_output(output: Any) -> str: - if not isinstance(output, list): - return "" - parts: list[str] = [] - for item in output: - if _field(item, "type") != "message": - continue - content = _field(item, "content") or [] - for entry in content: - if _field(entry, "type") == "output_text": - text = _field(entry, "text") - if text: - parts.append(text) - return "".join(parts) + return responses_parser.extract_text_from_output(output) @staticmethod def _extract_text(response: Any) -> str: - if isinstance(response, str): - return response - output_text = _field(response, "output_text") - if isinstance(output_text, str): - return output_text - output = _field(response, "output") - text_from_output = ChatClient._extract_text_from_responses_output(output) - if text_from_output: - return text_from_output - choices = _field(response, "choices") - if not choices: - return "" - message = _field(choices[0], "message") - if message is None: - return "" - return _field(message, "content", "") or "" + responses_text = responses_parser.extract_text(response) + if responses_text: + return responses_text + return completion_parser.extract_text(response) @staticmethod def _extract_tool_calls(response: Any) -> list[dict[str, Any]]: @@ -1970,69 +1875,12 @@ def _extract_tool_calls(response: Any) -> list[dict[str, Any]]: @staticmethod def _extract_responses_tool_calls(output: Any) -> list[dict[str, Any]]: - if not isinstance(output, list): - return [] - calls: list[dict[str, Any]] = [] - for item in output: - if _field(item, "type") != "function_call": - continue - name = _field(item, "name") - arguments = _field(item, "arguments") - if not name: - continue - entry: dict[str, Any] = {"function": {"name": name, "arguments": arguments or ""}} - call_id = _field(item, "call_id") or _field(item, "id") - if call_id: - entry["id"] = call_id - entry["type"] = "function" - calls.append(entry) - return calls + return responses_parser.extract_tool_calls(output) @staticmethod def _extract_completion_tool_calls(response: Any) -> list[dict[str, Any]]: - choices = _field(response, "choices") - if not choices: - return [] - message = _field(choices[0], "message") - if message is None: - return [] - tool_calls = _field(message, "tool_calls") or [] - calls: list[dict[str, Any]] = [] - for tool_call in tool_calls: - function = _field(tool_call, "function") - if function is None: - continue - entry: dict[str, Any] = { - "function": { - "name": _field(function, "name"), - "arguments": _field(function, "arguments"), - } - } - call_id = _field(tool_call, "id") - if call_id: - entry["id"] = call_id - call_type = _field(tool_call, "type") - if call_type: - entry["type"] = call_type - calls.append(entry) - return calls + return completion_parser.extract_tool_calls(response) @staticmethod def _extract_usage(response: Any) -> dict[str, Any] | None: - event_type = _field(response, "type") - if event_type in {"response.completed", "response.in_progress", "response.failed", "response.incomplete"}: - usage = 
_field(_field(response, "response"), "usage") - else: - usage = _field(response, "usage") - if usage is None: - return None - if hasattr(usage, "model_dump"): - return usage.model_dump() - if isinstance(usage, dict): - return dict(usage) - data: dict[str, Any] = {} - for field in ("input_tokens", "output_tokens", "total_tokens", "requests"): - value = _field(usage, field) - if value is not None: - data[field] = value - return data or None + return responses_parser.extract_usage(response) diff --git a/src/republic/clients/parsing/__init__.py b/src/republic/clients/parsing/__init__.py new file mode 100644 index 0000000..9b7915b --- /dev/null +++ b/src/republic/clients/parsing/__init__.py @@ -0,0 +1 @@ +"""Parsing helpers for provider response payloads.""" diff --git a/src/republic/clients/parsing/common.py b/src/republic/clients/parsing/common.py new file mode 100644 index 0000000..480d0bc --- /dev/null +++ b/src/republic/clients/parsing/common.py @@ -0,0 +1,11 @@ +"""Common parsing utilities shared by completion and responses adapters.""" + +from __future__ import annotations + +from typing import Any + + +def field(data: Any, key: str, default: Any = None) -> Any: + if isinstance(data, dict): + return data.get(key, default) + return getattr(data, key, default) diff --git a/src/republic/clients/parsing/completion.py b/src/republic/clients/parsing/completion.py new file mode 100644 index 0000000..58730b2 --- /dev/null +++ b/src/republic/clients/parsing/completion.py @@ -0,0 +1,69 @@ +"""OpenAI chat-completions shape parsing.""" + +from __future__ import annotations + +from typing import Any + +from republic.clients.parsing.common import field + + +def extract_chunk_tool_call_deltas(chunk: Any) -> list[Any]: + choices = field(chunk, "choices") + if not choices: + return [] + delta = field(choices[0], "delta") + if delta is None: + return [] + return field(delta, "tool_calls") or [] + + +def extract_chunk_text(chunk: Any) -> str: + choices = field(chunk, "choices") + if not choices: + return "" + delta = field(choices[0], "delta") + if delta is None: + return "" + return field(delta, "content", "") or "" + + +def extract_text(response: Any) -> str: + if isinstance(response, str): + return response + + choices = field(response, "choices") + if not choices: + return "" + message = field(choices[0], "message") + if message is None: + return "" + return field(message, "content", "") or "" + + +def extract_tool_calls(response: Any) -> list[dict[str, Any]]: + choices = field(response, "choices") + if not choices: + return [] + message = field(choices[0], "message") + if message is None: + return [] + tool_calls = field(message, "tool_calls") or [] + calls: list[dict[str, Any]] = [] + for tool_call in tool_calls: + function = field(tool_call, "function") + if function is None: + continue + entry: dict[str, Any] = { + "function": { + "name": field(function, "name"), + "arguments": field(function, "arguments"), + } + } + call_id = field(tool_call, "id") + if call_id: + entry["id"] = call_id + call_type = field(tool_call, "type") + if call_type: + entry["type"] = call_type + calls.append(entry) + return calls diff --git a/src/republic/clients/parsing/responses.py b/src/republic/clients/parsing/responses.py new file mode 100644 index 0000000..3ee442b --- /dev/null +++ b/src/republic/clients/parsing/responses.py @@ -0,0 +1,145 @@ +"""OpenAI responses shape parsing.""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +from republic.clients.parsing.common 
import field + + +def is_non_stream_response(response: Any) -> bool: + return ( + isinstance(response, str) + or field(response, "choices") is not None + or field(response, "output") is not None + or field(response, "output_text") is not None + ) + + +def _tool_delta_from_args_event(chunk: Any, event_type: str) -> list[Any]: + item_id = field(chunk, "item_id") + if not item_id: + return [] + arguments = field(chunk, "delta") + if event_type == "response.function_call_arguments.done": + arguments = field(chunk, "arguments") + if not isinstance(arguments, str): + return [] + + call_id = field(chunk, "call_id") + payload: dict[str, Any] = { + "index": item_id, + "type": "function", + "function": SimpleNamespace(name=field(chunk, "name") or "", arguments=arguments), + "arguments_complete": event_type == "response.function_call_arguments.done", + } + if call_id: + payload["id"] = call_id + return [SimpleNamespace(**payload)] + + +def _tool_delta_from_output_item_event(chunk: Any, event_type: str) -> list[Any]: + item = field(chunk, "item") + if field(item, "type") != "function_call": + return [] + + item_id = field(item, "id") + call_id = field(item, "call_id") or item_id + if not call_id: + return [] + arguments = field(item, "arguments") + if not isinstance(arguments, str): + arguments = "" + return [ + SimpleNamespace( + id=call_id, + index=item_id or call_id, + type="function", + function=SimpleNamespace(name=field(item, "name") or "", arguments=arguments), + arguments_complete=event_type == "response.output_item.done", + ) + ] + + +def extract_chunk_tool_call_deltas(chunk: Any) -> list[Any]: + event_type = field(chunk, "type") + if event_type in {"response.function_call_arguments.delta", "response.function_call_arguments.done"}: + return _tool_delta_from_args_event(chunk, event_type) + if event_type in {"response.output_item.added", "response.output_item.done"}: + return _tool_delta_from_output_item_event(chunk, event_type) + return [] + + +def extract_chunk_text(chunk: Any) -> str: + if field(chunk, "type") != "response.output_text.delta": + return "" + delta = field(chunk, "delta") + if isinstance(delta, str): + return delta + return "" + + +def extract_text_from_output(output: Any) -> str: + if not isinstance(output, list): + return "" + parts: list[str] = [] + for item in output: + if field(item, "type") != "message": + continue + content = field(item, "content") or [] + for entry in content: + if field(entry, "type") == "output_text": + text = field(entry, "text") + if text: + parts.append(text) + return "".join(parts) + + +def extract_text(response: Any) -> str: + output_text = field(response, "output_text") + if isinstance(output_text, str): + return output_text + return extract_text_from_output(field(response, "output")) + + +def extract_tool_calls(output: Any) -> list[dict[str, Any]]: + if not isinstance(output, list): + return [] + calls: list[dict[str, Any]] = [] + for item in output: + if field(item, "type") != "function_call": + continue + name = field(item, "name") + arguments = field(item, "arguments") + if not name: + continue + entry: dict[str, Any] = {"function": {"name": name, "arguments": arguments or ""}} + call_id = field(item, "call_id") or field(item, "id") + if call_id: + entry["id"] = call_id + entry["type"] = "function" + calls.append(entry) + return calls + + +def extract_usage(response: Any) -> dict[str, Any] | None: + event_type = field(response, "type") + if event_type in {"response.completed", "response.in_progress", "response.failed", 
"response.incomplete"}: + usage = field(field(response, "response"), "usage") + else: + usage = field(response, "usage") + + if usage is None: + return None + if hasattr(usage, "model_dump"): + return usage.model_dump() + if isinstance(usage, dict): + return dict(usage) + + data: dict[str, Any] = {} + for usage_field in ("input_tokens", "output_tokens", "total_tokens", "requests"): + value = field(usage, usage_field) + if value is not None: + data[usage_field] = value + return data or None From de271158c43d1baeb89401c4fc88ccca3731bd53 Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Sun, 1 Mar 2026 18:10:52 +0000 Subject: [PATCH 03/14] fix: align reasoning and stream usage defaults across chat and responses --- src/republic/clients/chat.py | 18 ++++++-- src/republic/core/execution.py | 59 +++++++++++++++++++++++---- tests/test_responses_handling.py | 70 ++++++++++++++++++++++++++++++++ 3 files changed, 134 insertions(+), 13 deletions(-) diff --git a/src/republic/clients/chat.py b/src/republic/clients/chat.py index 1a314b1..30cdd79 100644 --- a/src/republic/clients/chat.py +++ b/src/republic/clients/chat.py @@ -411,6 +411,14 @@ async def _prepare_request_async( context=context, ) + @staticmethod + def _split_reasoning_effort(kwargs: dict[str, Any]) -> tuple[Any | None, dict[str, Any]]: + if "reasoning_effort" not in kwargs: + return None, kwargs + request_kwargs = dict(kwargs) + reasoning_effort = request_kwargs.pop("reasoning_effort", None) + return reasoning_effort, request_kwargs + def _execute_sync( self, prepared: PreparedChat, @@ -425,6 +433,7 @@ def _execute_sync( ) -> Any: if prepared.context_error is not None: raise prepared.context_error + reasoning_effort, request_kwargs = self._split_reasoning_effort(kwargs) try: return self._core.run_chat_sync( messages_payload=prepared.payload, @@ -433,8 +442,8 @@ def _execute_sync( provider=provider, max_tokens=max_tokens, stream=stream, - reasoning_effort=None, - kwargs=kwargs, + reasoning_effort=reasoning_effort, + kwargs=request_kwargs, on_response=on_response, ) except RepublicError as exc: @@ -454,6 +463,7 @@ async def _execute_async( ) -> Any: if prepared.context_error is not None: raise prepared.context_error + reasoning_effort, request_kwargs = self._split_reasoning_effort(kwargs) try: return await self._core.run_chat_async( messages_payload=prepared.payload, @@ -462,8 +472,8 @@ async def _execute_async( provider=provider, max_tokens=max_tokens, stream=stream, - reasoning_effort=None, - kwargs=kwargs, + reasoning_effort=reasoning_effort, + kwargs=request_kwargs, on_response=on_response, ) except RepublicError as exc: diff --git a/src/republic/core/execution.py b/src/republic/core/execution.py index 094bdf6..79378b9 100644 --- a/src/republic/core/execution.py +++ b/src/republic/core/execution.py @@ -339,19 +339,54 @@ def _handle_attempt_error(self, exc: Exception, provider_name: str, model_id: st def _decide_kwargs_for_provider( self, provider: str, max_tokens: int | None, kwargs: dict[str, Any] ) -> dict[str, Any]: + clean_kwargs = self._sanitize_request_kwargs(kwargs) use_completion_tokens = "openai" in provider.lower() if not use_completion_tokens: - return {**kwargs, "max_tokens": max_tokens} - if "max_completion_tokens" in kwargs: - return kwargs - return {**kwargs, "max_completion_tokens": max_tokens} + return {**clean_kwargs, "max_tokens": max_tokens} + if "max_completion_tokens" in clean_kwargs: + return clean_kwargs + return {**clean_kwargs, "max_completion_tokens": max_tokens} def _decide_responses_kwargs(self, max_tokens: 
int | None, kwargs: dict[str, Any]) -> dict[str, Any]: - clean_kwargs = {k: v for k, v in kwargs.items() if k != "extra_headers"} + clean_kwargs = self._sanitize_request_kwargs(kwargs) if "max_output_tokens" in clean_kwargs: return clean_kwargs return {**clean_kwargs, "max_output_tokens": max_tokens} + @staticmethod + def _sanitize_request_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]: + return {k: v for k, v in kwargs.items() if k != "extra_headers"} + + @staticmethod + def _should_default_completion_stream_usage(provider_name: str) -> bool: + lowered_provider = provider_name.lower() + return "openai" in lowered_provider or "openrouter" in lowered_provider + + def _with_default_completion_stream_options( + self, + provider_name: str, + stream: bool, + kwargs: dict[str, Any], + ) -> dict[str, Any]: + if not stream: + return kwargs + if not self._should_default_completion_stream_usage(provider_name): + return kwargs + if "stream_options" in kwargs: + return kwargs + return {**kwargs, "stream_options": {"include_usage": True}} + + @staticmethod + def _with_responses_reasoning( + kwargs: dict[str, Any], + reasoning_effort: Any | None, + ) -> dict[str, Any]: + if reasoning_effort is None: + return kwargs + if "reasoning" in kwargs: + return kwargs + return {**kwargs, "reasoning": {"effort": reasoning_effort}} + @staticmethod def _convert_tools_for_responses(tools_payload: list[dict[str, Any]] | None) -> list[dict[str, Any]] | None: if not tools_payload: @@ -392,21 +427,24 @@ def _call_client_sync( ) -> Any: if self._should_use_responses(client, stream=stream): instructions, input_items = self._split_messages_for_responses(messages_payload) + responses_kwargs = self._with_responses_reasoning(kwargs, reasoning_effort) return client.responses( model=model_id, input_data=input_items, tools=self._convert_tools_for_responses(tools_payload), stream=stream, instructions=instructions, - **self._decide_responses_kwargs(max_tokens, kwargs), + **self._decide_responses_kwargs(max_tokens, responses_kwargs), ) + completion_kwargs = self._decide_kwargs_for_provider(provider_name, max_tokens, kwargs) + completion_kwargs = self._with_default_completion_stream_options(provider_name, stream, completion_kwargs) return client.completion( model=model_id, messages=messages_payload, tools=tools_payload, stream=stream, reasoning_effort=reasoning_effort, - **self._decide_kwargs_for_provider(provider_name, max_tokens, kwargs), + **completion_kwargs, ) async def _call_client_async( @@ -424,21 +462,24 @@ async def _call_client_async( ) -> Any: if self._should_use_responses(client, stream=stream): instructions, input_items = self._split_messages_for_responses(messages_payload) + responses_kwargs = self._with_responses_reasoning(kwargs, reasoning_effort) return await client.aresponses( model=model_id, input_data=input_items, tools=self._convert_tools_for_responses(tools_payload), stream=stream, instructions=instructions, - **self._decide_responses_kwargs(max_tokens, kwargs), + **self._decide_responses_kwargs(max_tokens, responses_kwargs), ) + completion_kwargs = self._decide_kwargs_for_provider(provider_name, max_tokens, kwargs) + completion_kwargs = self._with_default_completion_stream_options(provider_name, stream, completion_kwargs) return await client.acompletion( model=model_id, messages=messages_payload, tools=tools_payload, stream=stream, reasoning_effort=reasoning_effort, - **self._decide_kwargs_for_provider(provider_name, max_tokens, kwargs), + **completion_kwargs, ) @staticmethod diff --git 
a/tests/test_responses_handling.py b/tests/test_responses_handling.py index 70cf47e..7962f04 100644 --- a/tests/test_responses_handling.py +++ b/tests/test_responses_handling.py @@ -392,6 +392,76 @@ def test_non_stream_responses_run_tools_uses_converted_tools(fake_anyllm) -> Non assert "function" not in sent_tools[0] +def test_chat_reasoning_effort_for_completion_is_forwarded(fake_anyllm) -> None: + client = fake_anyllm.ensure("openai") + client.queue_completion(make_response(text="ready")) + + llm = LLM(model="openai:gpt-4o-mini", api_key="dummy") + assert llm.chat("Reply with ready", reasoning_effort="high") == "ready" + + call = client.calls[-1] + assert call.get("reasoning_effort") == "high" + + +def test_chat_reasoning_effort_for_responses_is_mapped(fake_anyllm) -> None: + client = fake_anyllm.ensure("openrouter") + client.queue_responses(make_responses_response(text="ready")) + + llm = LLM(model="openrouter:openrouter/free", api_key="dummy", use_responses=True) + assert llm.chat("Reply with ready", reasoning_effort="low") == "ready" + + call = client.calls[-1] + assert call["responses"] is True + assert call.get("reasoning") == {"effort": "low"} + assert "reasoning_effort" not in call + + +def test_chat_reasoning_kwarg_has_priority_over_reasoning_effort(fake_anyllm) -> None: + client = fake_anyllm.ensure("openrouter") + client.queue_responses(make_responses_response(text="ready")) + + llm = LLM(model="openrouter:openrouter/free", api_key="dummy", use_responses=True) + assert llm.chat("Reply with ready", reasoning_effort="low", reasoning={"effort": "high"}) == "ready" + + call = client.calls[-1] + assert call["responses"] is True + assert call.get("reasoning") == {"effort": "high"} + + +def test_stream_completion_defaults_include_usage(fake_anyllm) -> None: + client = fake_anyllm.ensure("openai") + client.queue_completion( + iter([ + make_chunk(text="hello"), + make_chunk(text=" world", usage={"total_tokens": 7}), + ]) + ) + + llm = LLM(model="openai:gpt-4o-mini", api_key="dummy") + stream = llm.stream("Say hello") + assert "".join(list(stream)) == "hello world" + assert stream.usage == {"total_tokens": 7} + + assert client.calls[-1].get("stream_options") == {"include_usage": True} + + +def test_stream_completion_preserves_user_stream_options(fake_anyllm) -> None: + client = fake_anyllm.ensure("openai") + client.queue_completion( + iter([ + make_chunk(text="hello"), + make_chunk(text=" world", usage={"total_tokens": 7}), + ]) + ) + + llm = LLM(model="openai:gpt-4o-mini", api_key="dummy") + stream = llm.stream("Say hello", stream_options={"include_usage": False, "custom": True}) + assert "".join(list(stream)) == "hello world" + assert stream.usage == {"total_tokens": 7} + + assert client.calls[-1].get("stream_options") == {"include_usage": False, "custom": True} + + def test_non_stream_chat_parity_between_completion_and_responses(fake_anyllm) -> None: completion_client = fake_anyllm.ensure("openai") completion_client.queue_completion(make_response(text="ready")) From 16da95d46af16d4c8e9cb079af813d6c220c3253 Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Sun, 1 Mar 2026 19:06:38 +0000 Subject: [PATCH 04/14] refactor: centralize provider compatibility and tool-call normalization --- src/republic/clients/chat.py | 3 +- src/republic/clients/parsing/common.py | 58 ++++++++++++ src/republic/clients/parsing/completion.py | 4 +- src/republic/clients/parsing/responses.py | 25 ++++- src/republic/core/execution.py | 26 +++--- src/republic/core/provider_policies.py | 50 ++++++++++ 
From 16da95d46af16d4c8e9cb079af813d6c220c3253 Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Sun, 1 Mar 2026 19:06:38 +0000 Subject: [PATCH 04/14] refactor: centralize provider compatibility and tool-call normalization --- src/republic/clients/chat.py | 3 +- src/republic/clients/parsing/common.py | 58 ++++++++++++ src/republic/clients/parsing/completion.py | 4 +- src/republic/clients/parsing/responses.py | 25 ++++- src/republic/core/execution.py | 26 +++--- src/republic/core/provider_policies.py | 50 ++++++++++ tests/test_provider_policies.py | 57 ++++++++++++ tests/test_responses_handling.py | 103 +++++++++++++++++++++ 8 files changed, 310 insertions(+), 16 deletions(-) create mode 100644 src/republic/core/provider_policies.py create mode 100644 tests/test_provider_policies.py diff --git a/src/republic/clients/chat.py b/src/republic/clients/chat.py index 30cdd79..907c88c 100644 --- a/src/republic/clients/chat.py +++ b/src/republic/clients/chat.py @@ -10,6 +10,7 @@ from republic.clients.parsing import completion as completion_parser from republic.clients.parsing import responses as responses_parser +from republic.clients.parsing.common import expand_tool_calls from republic.clients.parsing.common import field as _field from republic.core.errors import ErrorKind, RepublicError from republic.core.execution import LLMCore @@ -189,7 +190,7 @@ def add_deltas(self, tool_calls: list[Any]) -> None: ) def finalize(self) -> list[dict[str, Any]]: - return [self._calls[key] for key in self._order] + return expand_tool_calls([self._calls[key] for key in self._order]) class ChatClient: diff --git a/src/republic/clients/parsing/common.py b/src/republic/clients/parsing/common.py index 480d0bc..d46f51f 100644 --- a/src/republic/clients/parsing/common.py +++ b/src/republic/clients/parsing/common.py @@ -2,6 +2,7 @@ from __future__ import annotations +import json from typing import Any @@ -9,3 +10,60 @@ def field(data: Any, key: str, default: Any = None) -> Any: if isinstance(data, dict): return data.get(key, default) return getattr(data, key, default) + + +def expand_tool_calls(calls: list[dict[str, Any]]) -> list[dict[str, Any]]: + expanded: list[dict[str, Any]] = [] + for call in calls: + expanded.extend(_expand_tool_call(call)) + return expanded + + +def _expand_tool_call(call: dict[str, Any]) -> list[dict[str, Any]]: + function = field(call, "function") + if not isinstance(function, dict): + return [dict(call)] + + arguments = field(function, "arguments") + if not isinstance(arguments, str): + return [dict(call)] + + chunks = _split_concatenated_json_objects(arguments) + if not chunks: + return [dict(call)] + + call_id = field(call, "id") + expanded: list[dict[str, Any]] = [] + for index, chunk in enumerate(chunks): + cloned = dict(call) + cloned_function = dict(function) + cloned_function["arguments"] = chunk + cloned["function"] = cloned_function + if isinstance(call_id, str) and call_id and index > 0: + cloned["id"] = f"{call_id}__{index + 1}" + expanded.append(cloned) + return expanded + + +def _split_concatenated_json_objects(raw: str) -> list[str]: + decoder = json.JSONDecoder() + chunks: list[str] = [] + position = 0 + total = len(raw) + while position < total: + while position < total and raw[position].isspace(): + position += 1 + if position >= total: + break + try: + parsed, end = decoder.raw_decode(raw, position) + except json.JSONDecodeError: + return [] + if not isinstance(parsed, dict): + return [] + chunks.append(raw[position:end]) + position = end + + if len(chunks) <= 1: + return [] + return chunks diff --git a/src/republic/clients/parsing/completion.py b/src/republic/clients/parsing/completion.py index 58730b2..26c2f4c 100644 --- a/src/republic/clients/parsing/completion.py +++ b/src/republic/clients/parsing/completion.py @@ -4,7 +4,7 @@ from typing import Any -from republic.clients.parsing.common import field +from republic.clients.parsing.common import expand_tool_calls, field def extract_chunk_tool_call_deltas(chunk: Any) -> list[Any]: choices = field(chunk, "choices") if not choices: @@ -66,4 +66,4 @@ def extract_tool_calls(response: Any)
-> list[dict[str, Any]]: if call_type: entry["type"] = call_type calls.append(entry) - return calls + return expand_tool_calls(calls) diff --git a/src/republic/clients/parsing/responses.py b/src/republic/clients/parsing/responses.py index 3ee442b..e483d07 100644 --- a/src/republic/clients/parsing/responses.py +++ b/src/republic/clients/parsing/responses.py @@ -5,7 +5,7 @@ from types import SimpleNamespace from typing import Any -from republic.clients.parsing.common import field +from republic.clients.parsing.common import expand_tool_calls, field def is_non_stream_response(response: Any) -> bool: @@ -17,6 +17,27 @@ def is_non_stream_response(response: Any) -> bool: ) +def normalize_request_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]: + """Normalize completion-style kwargs into responses-compatible shapes.""" + tool_choice = kwargs.get("tool_choice") + if not isinstance(tool_choice, dict): + return kwargs + + function = tool_choice.get("function") + if not isinstance(function, dict): + return kwargs + + function_name = function.get("name") + if not isinstance(function_name, str) or not function_name: + return kwargs + + normalized_tool_choice = dict(tool_choice) + normalized_tool_choice.pop("function", None) + normalized_tool_choice["type"] = normalized_tool_choice.get("type", "function") + normalized_tool_choice["name"] = function_name + return {**kwargs, "tool_choice": normalized_tool_choice} + + def _tool_delta_from_args_event(chunk: Any, event_type: str) -> list[Any]: item_id = field(chunk, "item_id") if not item_id: @@ -120,7 +141,7 @@ def extract_tool_calls(output: Any) -> list[dict[str, Any]]: entry["id"] = call_id entry["type"] = "function" calls.append(entry) - return calls + return expand_tool_calls(calls) def extract_usage(response: Any) -> dict[str, Any] | None: diff --git a/src/republic/core/execution.py b/src/republic/core/execution.py index 79378b9..c17abbf 100644 --- a/src/republic/core/execution.py +++ b/src/republic/core/execution.py @@ -26,6 +26,8 @@ UnsupportedProviderError, ) +from republic.clients.parsing import responses as responses_parser +from republic.core import provider_policies from republic.core.errors import ErrorKind, RepublicError logger = logging.getLogger(__name__) @@ -340,15 +342,14 @@ def _decide_kwargs_for_provider( self, provider: str, max_tokens: int | None, kwargs: dict[str, Any] ) -> dict[str, Any]: clean_kwargs = self._sanitize_request_kwargs(kwargs) - use_completion_tokens = "openai" in provider.lower() - if not use_completion_tokens: - return {**clean_kwargs, "max_tokens": max_tokens} - if "max_completion_tokens" in clean_kwargs: + max_tokens_arg = provider_policies.completion_max_tokens_arg(provider) + if max_tokens_arg in clean_kwargs: return clean_kwargs - return {**clean_kwargs, "max_completion_tokens": max_tokens} + return {**clean_kwargs, max_tokens_arg: max_tokens} def _decide_responses_kwargs(self, max_tokens: int | None, kwargs: dict[str, Any]) -> dict[str, Any]: clean_kwargs = self._sanitize_request_kwargs(kwargs) + clean_kwargs = responses_parser.normalize_request_kwargs(clean_kwargs) if "max_output_tokens" in clean_kwargs: return clean_kwargs return {**clean_kwargs, "max_output_tokens": max_tokens} @@ -359,8 +360,7 @@ def _sanitize_request_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]: @staticmethod def _should_default_completion_stream_usage(provider_name: str) -> bool: - lowered_provider = provider_name.lower() - return "openai" in lowered_provider or "openrouter" in lowered_provider + return 
provider_policies.should_include_completion_stream_usage(provider_name) def _with_default_completion_stream_options( self, @@ -409,8 +409,12 @@ def _convert_tools_for_responses(tools_payload: list[dict[str, Any]] | None) -> converted_tools.append(dict(tool)) return converted_tools - def _should_use_responses(self, client: AnyLLM, *, stream: bool) -> bool: - return self._use_responses and bool(getattr(client, "SUPPORTS_RESPONSES", False)) + def _should_use_responses(self, client: AnyLLM, *, provider_name: str) -> bool: + return provider_policies.should_use_responses( + provider_name=provider_name, + use_responses=self._use_responses, + supports_responses=bool(getattr(client, "SUPPORTS_RESPONSES", False)), + ) def _call_client_sync( self, @@ -425,7 +429,7 @@ def _call_client_sync( reasoning_effort: Any | None, kwargs: dict[str, Any], ) -> Any: - if self._should_use_responses(client, stream=stream): + if self._should_use_responses(client, provider_name=provider_name): instructions, input_items = self._split_messages_for_responses(messages_payload) responses_kwargs = self._with_responses_reasoning(kwargs, reasoning_effort) return client.responses( @@ -460,7 +464,7 @@ async def _call_client_async( reasoning_effort: Any | None, kwargs: dict[str, Any], ) -> Any: - if self._should_use_responses(client, stream=stream): + if self._should_use_responses(client, provider_name=provider_name): instructions, input_items = self._split_messages_for_responses(messages_payload) responses_kwargs = self._with_responses_reasoning(kwargs, reasoning_effort) return await client.aresponses( diff --git a/src/republic/core/provider_policies.py b/src/republic/core/provider_policies.py new file mode 100644 index 0000000..a59851f --- /dev/null +++ b/src/republic/core/provider_policies.py @@ -0,0 +1,50 @@ +"""Provider policy decisions shared across request paths.""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class ProviderPolicy: + enable_responses_without_capability: bool = False + include_usage_in_completion_stream: bool = False + completion_max_tokens_arg: str = "max_tokens" + + +_DEFAULT_POLICY = ProviderPolicy() +_POLICIES: dict[str, ProviderPolicy] = { + "openai": ProviderPolicy( + include_usage_in_completion_stream=True, + completion_max_tokens_arg="max_completion_tokens", + ), + # any-llm supports OpenRouter responses in practice but still reports SUPPORTS_RESPONSES=False. 
+ "openrouter": ProviderPolicy( + enable_responses_without_capability=True, + include_usage_in_completion_stream=True, + ), +} + + +def provider_policy(provider_name: str) -> ProviderPolicy: + lowered = provider_name.lower() + for key, policy in _POLICIES.items(): + if key in lowered: + return policy + return _DEFAULT_POLICY + + +def should_use_responses(*, provider_name: str, use_responses: bool, supports_responses: bool) -> bool: + if not use_responses: + return False + if supports_responses: + return True + return provider_policy(provider_name).enable_responses_without_capability + + +def should_include_completion_stream_usage(provider_name: str) -> bool: + return provider_policy(provider_name).include_usage_in_completion_stream + + +def completion_max_tokens_arg(provider_name: str) -> str: + return provider_policy(provider_name).completion_max_tokens_arg diff --git a/tests/test_provider_policies.py b/tests/test_provider_policies.py new file mode 100644 index 0000000..74648bc --- /dev/null +++ b/tests/test_provider_policies.py @@ -0,0 +1,57 @@ +from republic.core import provider_policies + + +def test_should_use_responses_respects_global_flag() -> None: + assert ( + provider_policies.should_use_responses( + provider_name="openai", + use_responses=False, + supports_responses=True, + ) + is False + ) + + +def test_should_use_responses_accepts_provider_capability() -> None: + assert ( + provider_policies.should_use_responses( + provider_name="anthropic", + use_responses=True, + supports_responses=True, + ) + is True + ) + + +def test_should_use_responses_openrouter_policy_fallback() -> None: + assert ( + provider_policies.should_use_responses( + provider_name="openrouter", + use_responses=True, + supports_responses=False, + ) + is True + ) + + +def test_should_use_responses_requires_explicit_policy_or_capability() -> None: + assert ( + provider_policies.should_use_responses( + provider_name="anthropic", + use_responses=True, + supports_responses=False, + ) + is False + ) + + +def test_completion_stream_usage_policy() -> None: + assert provider_policies.should_include_completion_stream_usage("openai") + assert provider_policies.should_include_completion_stream_usage("openrouter") + assert not provider_policies.should_include_completion_stream_usage("anthropic") + + +def test_completion_max_tokens_arg_policy() -> None: + assert provider_policies.completion_max_tokens_arg("openai") == "max_completion_tokens" + assert provider_policies.completion_max_tokens_arg("openrouter") == "max_tokens" + assert provider_policies.completion_max_tokens_arg("anthropic") == "max_tokens" diff --git a/tests/test_responses_handling.py b/tests/test_responses_handling.py index 7962f04..dc6242d 100644 --- a/tests/test_responses_handling.py +++ b/tests/test_responses_handling.py @@ -151,6 +151,36 @@ def test_llm_use_responses_calls_responses(fake_anyllm) -> None: assert client.calls[-1]["input_data"][0]["role"] == "user" +def test_openrouter_uses_responses_when_enabled_even_if_provider_flag_is_false(fake_anyllm) -> None: + client = fake_anyllm.ensure("openrouter") + client.SUPPORTS_RESPONSES = False + client.queue_responses(make_responses_response(text="hello")) + + llm = LLM(model="openrouter:openrouter/free", api_key="dummy", use_responses=True) + result = llm.chat("hi") + + assert result == "hello" + assert client.calls[-1].get("responses") is True + + +def test_responses_tool_choice_accepts_completion_function_shape(fake_anyllm) -> None: + client = fake_anyllm.ensure("openai") + client.queue_responses( + 
make_responses_response(tool_calls=[make_responses_function_call("echo", '{"text":"tokyo"}')]) + ) + + llm = LLM(model="openai:gpt-4o-mini", api_key="dummy", use_responses=True) + calls = llm.tool_calls( + "Call echo for tokyo", + tools=[echo], + tool_choice={"type": "function", "function": {"name": "echo"}}, + ) + + assert len(calls) == 1 + assert calls[0]["function"]["name"] == "echo" + assert client.calls[-1]["tool_choice"] == {"type": "function", "name": "echo"} + + def test_extract_tool_calls_from_responses() -> None: response = make_responses_response(tool_calls=[make_responses_function_call("echo", '{"text":"hi"}')]) @@ -165,6 +195,55 @@ def test_extract_tool_calls_from_responses() -> None: ] +def test_non_stream_completion_splits_concatenated_tool_arguments(fake_anyllm) -> None: + client = fake_anyllm.ensure("openai") + client.queue_completion( + make_response( + tool_calls=[ + make_tool_call( + "echo", + '{"text":"tokyo"}{"text":"osaka"}', + call_id="call_1", + ) + ] + ) + ) + + llm = LLM(model="openai:gpt-4o-mini", api_key="dummy") + result = llm.run_tools("Call echo twice", tools=[echo]) + + assert result.kind == "tools" + assert [call["id"] for call in result.tool_calls] == ["call_1", "call_1__2"] + assert result.tool_results == ["TOKYO", "OSAKA"] + + +def test_stream_events_splits_concatenated_tool_arguments(fake_anyllm) -> None: + client = fake_anyllm.ensure("openai") + client.queue_completion( + iter([ + make_chunk( + tool_calls=[ + make_tool_call( + "echo", + '{"text":"tokyo"}{"text":"osaka"}', + call_id="call_1", + ) + ], + usage={"total_tokens": 8}, + ) + ]) + ) + + llm = LLM(model="openai:gpt-4o-mini", api_key="dummy") + events = list(llm.stream_events("Call echo twice", tools=[echo])) + + tool_calls = [event.data["call"] for event in events if event.kind == "tool_call"] + assert [call["id"] for call in tool_calls] == ["call_1", "call_1__2"] + assert [call["function"]["arguments"] for call in tool_calls] == ['{"text":"tokyo"}', '{"text":"osaka"}'] + tool_results = [event.data["result"] for event in events if event.kind == "tool_result"] + assert tool_results == ["TOKYO", "OSAKA"] + + def test_split_messages_for_responses() -> None: messages = [ {"role": "system", "content": "sys"}, @@ -462,6 +541,30 @@ def test_stream_completion_preserves_user_stream_options(fake_anyllm) -> None: assert client.calls[-1].get("stream_options") == {"include_usage": False, "custom": True} +def test_openai_completion_uses_max_completion_tokens(fake_anyllm) -> None: + client = fake_anyllm.ensure("openai") + client.queue_completion(make_response(text="hello")) + + llm = LLM(model="openai:gpt-4o-mini", api_key="dummy") + assert llm.chat("Say hello", max_tokens=11) == "hello" + + call = client.calls[-1] + assert call.get("max_completion_tokens") == 11 + assert "max_tokens" not in call + + +def test_non_openai_completion_uses_max_tokens(fake_anyllm) -> None: + client = fake_anyllm.ensure("anthropic") + client.queue_completion(make_response(text="hello")) + + llm = LLM(model="anthropic:claude-3-5-haiku-latest", api_key="dummy") + assert llm.chat("Say hello", max_tokens=11) == "hello" + + call = client.calls[-1] + assert call.get("max_tokens") == 11 + assert "max_completion_tokens" not in call + + def test_non_stream_chat_parity_between_completion_and_responses(fake_anyllm) -> None: completion_client = fake_anyllm.ensure("openai") completion_client.queue_completion(make_response(text="ready"))
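The splitter this patch introduces is easiest to see outside the library. The sketch below mirrors the _split_concatenated_json_objects helper added above (standalone copy for illustration; the unprefixed wrapper name and the sample payloads are mine, not the library's wire format):

import json

def split_concatenated_json_objects(raw: str) -> list[str]:
    # Walk the string decoding back-to-back JSON objects; anything that is
    # not a clean run of two or more dicts is left for the caller untouched.
    decoder = json.JSONDecoder()
    chunks: list[str] = []
    position = 0
    while position < len(raw):
        while position < len(raw) and raw[position].isspace():
            position += 1
        if position >= len(raw):
            break
        try:
            parsed, end = decoder.raw_decode(raw, position)
        except json.JSONDecodeError:
            return []
        if not isinstance(parsed, dict):
            return []
        chunks.append(raw[position:end])
        position = end
    return chunks if len(chunks) > 1 else []

assert split_concatenated_json_objects('{"text":"tokyo"}{"text":"osaka"}') == ['{"text":"tokyo"}', '{"text":"osaka"}']
assert split_concatenated_json_objects('{"text":"tokyo"}') == []  # single object: no expansion
assert split_concatenated_json_objects('[1, 2]') == []            # non-dict payload: never split

Each extra chunk then becomes its own call with a derived id ("call_1", "call_1__2", ...), which is exactly what the two concatenated-arguments tests above assert end to end.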
From 2df674f237343f0ee981ae18baa332681bfb3d03 Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Sun, 1 Mar 2026 19:15:13 +0000 Subject: [PATCH 05/14] fix: gate openrouter responses by model and tool context --- src/republic/core/execution.py | 25 +++++++++++++++-- src/republic/core/provider_policies.py | 29 +++++++++++++++---- tests/test_provider_policies.py | 39 ++++++++++++++++++++++++++ tests/test_responses_handling.py | 17 +++++++++++ 4 files changed, 101 insertions(+), 9 deletions(-) diff --git a/src/republic/core/execution.py b/src/republic/core/execution.py index c17abbf..1cc0695 100644 --- a/src/republic/core/execution.py +++ b/src/republic/core/execution.py @@ -409,9 +409,18 @@ def _convert_tools_for_responses(tools_payload: list[dict[str, Any]] | None) -> converted_tools.append(dict(tool)) return converted_tools - def _should_use_responses(self, client: AnyLLM, *, provider_name: str) -> bool: + def _should_use_responses( + self, + client: AnyLLM, + *, + provider_name: str, + model_id: str, + tools_payload: list[dict[str, Any]] | None, + ) -> bool: return provider_policies.should_use_responses( provider_name=provider_name, + model_id=model_id, + has_tools=bool(tools_payload), use_responses=self._use_responses, supports_responses=bool(getattr(client, "SUPPORTS_RESPONSES", False)), ) @@ -429,7 +438,12 @@ def _call_client_sync( reasoning_effort: Any | None, kwargs: dict[str, Any], ) -> Any: - if self._should_use_responses(client, provider_name=provider_name): + if self._should_use_responses( + client, + provider_name=provider_name, + model_id=model_id, + tools_payload=tools_payload, + ): instructions, input_items = self._split_messages_for_responses(messages_payload) responses_kwargs = self._with_responses_reasoning(kwargs, reasoning_effort) return client.responses( @@ -464,7 +478,12 @@ async def _call_client_async( reasoning_effort: Any | None, kwargs: dict[str, Any], ) -> Any: - if self._should_use_responses(client, provider_name=provider_name): + if self._should_use_responses( + client, + provider_name=provider_name, + model_id=model_id, + tools_payload=tools_payload, + ): instructions, input_items = self._split_messages_for_responses(messages_payload) responses_kwargs = self._with_responses_reasoning(kwargs, reasoning_effort) return await client.aresponses( diff --git a/src/republic/core/provider_policies.py b/src/republic/core/provider_policies.py index a59851f..1fae74b 100644 --- a/src/republic/core/provider_policies.py +++ b/src/republic/core/provider_policies.py @@ -10,6 +10,7 @@ class ProviderPolicy: enable_responses_without_capability: bool = False include_usage_in_completion_stream: bool = False completion_max_tokens_arg: str = "max_tokens" + responses_tools_blocked_model_prefixes: tuple[str, ...]
= () _DEFAULT_POLICY = ProviderPolicy() @@ -22,21 +23,37 @@ class ProviderPolicy: "openrouter": ProviderPolicy( enable_responses_without_capability=True, include_usage_in_completion_stream=True, + responses_tools_blocked_model_prefixes=("anthropic/",), ), } +def _normalize_provider_name(provider_name: str) -> str: + return provider_name.strip().lower() + + def provider_policy(provider_name: str) -> ProviderPolicy: - lowered = provider_name.lower() - for key, policy in _POLICIES.items(): - if key in lowered: - return policy - return _DEFAULT_POLICY + return _POLICIES.get(_normalize_provider_name(provider_name), _DEFAULT_POLICY) + +def _responses_tools_blocked_for_model(provider_name: str, model_id: str) -> bool: + policy = provider_policy(provider_name) + lowered_model = model_id.strip().lower() + return any(lowered_model.startswith(prefix) for prefix in policy.responses_tools_blocked_model_prefixes) -def should_use_responses(*, provider_name: str, use_responses: bool, supports_responses: bool) -> bool: + +def should_use_responses( + *, + provider_name: str, + model_id: str, + has_tools: bool, + use_responses: bool, + supports_responses: bool, +) -> bool: if not use_responses: return False + if has_tools and _responses_tools_blocked_for_model(provider_name, model_id): + return False if supports_responses: return True return provider_policy(provider_name).enable_responses_without_capability diff --git a/tests/test_provider_policies.py b/tests/test_provider_policies.py index 74648bc..e7f8108 100644 --- a/tests/test_provider_policies.py +++ b/tests/test_provider_policies.py @@ -5,6 +5,8 @@ def test_should_use_responses_respects_global_flag() -> None: assert ( provider_policies.should_use_responses( provider_name="openai", + model_id="gpt-4o-mini", + has_tools=False, use_responses=False, supports_responses=True, ) @@ -16,6 +18,8 @@ def test_should_use_responses_accepts_provider_capability() -> None: assert ( provider_policies.should_use_responses( provider_name="anthropic", + model_id="claude-3-5-haiku-latest", + has_tools=False, use_responses=True, supports_responses=True, ) @@ -27,6 +31,8 @@ def test_should_use_responses_openrouter_policy_fallback() -> None: assert ( provider_policies.should_use_responses( provider_name="openrouter", + model_id="openai/gpt-4o-mini", + has_tools=False, use_responses=True, supports_responses=False, ) @@ -38,6 +44,8 @@ def test_should_use_responses_requires_explicit_policy_or_capability() -> None: assert ( provider_policies.should_use_responses( provider_name="anthropic", + model_id="claude-3-5-haiku-latest", + has_tools=False, use_responses=True, supports_responses=False, ) @@ -45,6 +53,32 @@ def test_should_use_responses_requires_explicit_policy_or_capability() -> None: ) +def test_should_use_responses_openrouter_anthropic_tools_fallbacks_to_completion() -> None: + assert ( + provider_policies.should_use_responses( + provider_name="openrouter", + model_id="anthropic/claude-3.5-haiku", + has_tools=True, + use_responses=True, + supports_responses=False, + ) + is False + ) + + +def test_should_use_responses_openrouter_anthropic_without_tools_still_uses_responses() -> None: + assert ( + provider_policies.should_use_responses( + provider_name="openrouter", + model_id="anthropic/claude-3.5-haiku", + has_tools=False, + use_responses=True, + supports_responses=False, + ) + is True + ) + + def test_completion_stream_usage_policy() -> None: assert provider_policies.should_include_completion_stream_usage("openai") assert 
provider_policies.should_include_completion_stream_usage("openrouter") @@ -55,3 +89,8 @@ def test_completion_max_tokens_arg_policy() -> None: assert provider_policies.completion_max_tokens_arg("openai") == "max_completion_tokens" assert provider_policies.completion_max_tokens_arg("openrouter") == "max_tokens" assert provider_policies.completion_max_tokens_arg("anthropic") == "max_tokens" + + +def test_provider_policy_uses_exact_match_not_substring() -> None: + assert not provider_policies.should_include_completion_stream_usage("my-openrouter-proxy") + assert provider_policies.completion_max_tokens_arg("my-openrouter-proxy") == "max_tokens" diff --git a/tests/test_responses_handling.py b/tests/test_responses_handling.py index dc6242d..4ca236b 100644 --- a/tests/test_responses_handling.py +++ b/tests/test_responses_handling.py @@ -163,6 +163,23 @@ def test_openrouter_uses_responses_when_enabled_even_if_provider_flag_is_false(f assert client.calls[-1].get("responses") is True +def test_openrouter_anthropic_with_tools_falls_back_to_completion(fake_anyllm) -> None: + client = fake_anyllm.ensure("openrouter") + client.SUPPORTS_RESPONSES = False + client.queue_completion(make_response(tool_calls=[make_tool_call("echo", '{"text":"tokyo"}')])) + + llm = LLM(model="openrouter:anthropic/claude-3.5-haiku", api_key="dummy", use_responses=True) + calls = llm.tool_calls( + "Call echo for tokyo", + tools=[echo], + tool_choice={"type": "function", "function": {"name": "echo"}}, + ) + + assert len(calls) == 1 + assert calls[0]["function"]["name"] == "echo" + assert client.calls[-1].get("responses") is None + + def test_responses_tool_choice_accepts_completion_function_shape(fake_anyllm) -> None: client = fake_anyllm.ensure("openai") client.queue_responses( From 00833838b39a2c071a755409b3f9cd00884a0fb7 Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Sun, 1 Mar 2026 19:27:45 +0000 Subject: [PATCH 06/14] refactor: unify transport routing across responses and completion --- src/republic/clients/parsing/responses.py | 21 --- src/republic/core/execution.py | 194 +++++++++++++++++----- src/republic/core/provider_policies.py | 26 ++- src/republic/core/request_adapters.py | 26 +++ tests/test_provider_policies.py | 72 ++++---- tests/test_responses_handling.py | 30 ++++ 6 files changed, 277 insertions(+), 92 deletions(-) create mode 100644 src/republic/core/request_adapters.py diff --git a/src/republic/clients/parsing/responses.py b/src/republic/clients/parsing/responses.py index e483d07..1da7ed2 100644 --- a/src/republic/clients/parsing/responses.py +++ b/src/republic/clients/parsing/responses.py @@ -17,27 +17,6 @@ def is_non_stream_response(response: Any) -> bool: ) -def normalize_request_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]: - """Normalize completion-style kwargs into responses-compatible shapes.""" - tool_choice = kwargs.get("tool_choice") - if not isinstance(tool_choice, dict): - return kwargs - - function = tool_choice.get("function") - if not isinstance(function, dict): - return kwargs - - function_name = function.get("name") - if not isinstance(function_name, str) or not function_name: - return kwargs - - normalized_tool_choice = dict(tool_choice) - normalized_tool_choice.pop("function", None) - normalized_tool_choice["type"] = normalized_tool_choice.get("type", "function") - normalized_tool_choice["name"] = function_name - return {**kwargs, "tool_choice": normalized_tool_choice} - - def _tool_delta_from_args_event(chunk: Any, event_type: str) -> list[Any]: item_id = field(chunk, "item_id") if 
not item_id: diff --git a/src/republic/core/execution.py b/src/republic/core/execution.py index 1cc0695..0fd5394 100644 --- a/src/republic/core/execution.py +++ b/src/republic/core/execution.py @@ -26,9 +26,9 @@ UnsupportedProviderError, ) -from republic.clients.parsing import responses as responses_parser from republic.core import provider_policies from republic.core.errors import ErrorKind, RepublicError +from republic.core.request_adapters import normalize_responses_kwargs logger = logging.getLogger(__name__) @@ -349,7 +349,7 @@ def _decide_kwargs_for_provider( def _decide_responses_kwargs(self, max_tokens: int | None, kwargs: dict[str, Any]) -> dict[str, Any]: clean_kwargs = self._sanitize_request_kwargs(kwargs) - clean_kwargs = responses_parser.normalize_request_kwargs(clean_kwargs) + clean_kwargs = normalize_responses_kwargs(clean_kwargs) if "max_output_tokens" in clean_kwargs: return clean_kwargs return {**clean_kwargs, "max_output_tokens": max_tokens} @@ -409,15 +409,15 @@ def _convert_tools_for_responses(tools_payload: list[dict[str, Any]] | None) -> converted_tools.append(dict(tool)) return converted_tools - def _should_use_responses( + def _transport_order( self, client: AnyLLM, *, provider_name: str, model_id: str, tools_payload: list[dict[str, Any]] | None, - ) -> bool: - return provider_policies.should_use_responses( + ) -> tuple[str, ...]: + return provider_policies.transport_order( provider_name=provider_name, model_id=model_id, has_tools=bool(tools_payload), @@ -425,7 +425,34 @@ def _should_use_responses( supports_responses=bool(getattr(client, "SUPPORTS_RESPONSES", False)), ) - def _call_client_sync( + def _should_fallback_transport(self, exc: Exception) -> bool: + kind = self.classify_exception(exc) + return kind in {ErrorKind.INVALID_INPUT, ErrorKind.PROVIDER} + + def _call_responses_sync( + self, + *, + client: AnyLLM, + model_id: str, + messages_payload: list[dict[str, Any]], + tools_payload: list[dict[str, Any]] | None, + max_tokens: int | None, + stream: bool, + reasoning_effort: Any | None, + kwargs: dict[str, Any], + ) -> Any: + instructions, input_items = self._split_messages_for_responses(messages_payload) + responses_kwargs = self._with_responses_reasoning(kwargs, reasoning_effort) + return client.responses( + model=model_id, + input_data=input_items, + tools=self._convert_tools_for_responses(tools_payload), + stream=stream, + instructions=instructions, + **self._decide_responses_kwargs(max_tokens, responses_kwargs), + ) + + def _call_completion_sync( self, *, client: AnyLLM, @@ -438,22 +465,6 @@ def _call_client_sync( reasoning_effort: Any | None, kwargs: dict[str, Any], ) -> Any: - if self._should_use_responses( - client, - provider_name=provider_name, - model_id=model_id, - tools_payload=tools_payload, - ): - instructions, input_items = self._split_messages_for_responses(messages_payload) - responses_kwargs = self._with_responses_reasoning(kwargs, reasoning_effort) - return client.responses( - model=model_id, - input_data=input_items, - tools=self._convert_tools_for_responses(tools_payload), - stream=stream, - instructions=instructions, - **self._decide_responses_kwargs(max_tokens, responses_kwargs), - ) completion_kwargs = self._decide_kwargs_for_provider(provider_name, max_tokens, kwargs) completion_kwargs = self._with_default_completion_stream_options(provider_name, stream, completion_kwargs) return client.completion( @@ -465,7 +476,30 @@ def _call_client_sync( **completion_kwargs, ) - async def _call_client_async( + async def _call_responses_async( + 
self, + *, + client: AnyLLM, + model_id: str, + messages_payload: list[dict[str, Any]], + tools_payload: list[dict[str, Any]] | None, + max_tokens: int | None, + stream: bool, + reasoning_effort: Any | None, + kwargs: dict[str, Any], + ) -> Any: + instructions, input_items = self._split_messages_for_responses(messages_payload) + responses_kwargs = self._with_responses_reasoning(kwargs, reasoning_effort) + return await client.aresponses( + model=model_id, + input_data=input_items, + tools=self._convert_tools_for_responses(tools_payload), + stream=stream, + instructions=instructions, + **self._decide_responses_kwargs(max_tokens, responses_kwargs), + ) + + async def _call_completion_async( self, *, client: AnyLLM, @@ -478,22 +512,6 @@ async def _call_client_async( reasoning_effort: Any | None, kwargs: dict[str, Any], ) -> Any: - if self._should_use_responses( - client, - provider_name=provider_name, - model_id=model_id, - tools_payload=tools_payload, - ): - instructions, input_items = self._split_messages_for_responses(messages_payload) - responses_kwargs = self._with_responses_reasoning(kwargs, reasoning_effort) - return await client.aresponses( - model=model_id, - input_data=input_items, - tools=self._convert_tools_for_responses(tools_payload), - stream=stream, - instructions=instructions, - **self._decide_responses_kwargs(max_tokens, responses_kwargs), - ) completion_kwargs = self._decide_kwargs_for_provider(provider_name, max_tokens, kwargs) completion_kwargs = self._with_default_completion_stream_options(provider_name, stream, completion_kwargs) return await client.acompletion( @@ -505,6 +523,104 @@ async def _call_client_async( **completion_kwargs, ) + def _call_client_sync( + self, + *, + client: AnyLLM, + provider_name: str, + model_id: str, + messages_payload: list[dict[str, Any]], + tools_payload: list[dict[str, Any]] | None, + max_tokens: int | None, + stream: bool, + reasoning_effort: Any | None, + kwargs: dict[str, Any], + ) -> Any: + transports = self._transport_order( + client, + provider_name=provider_name, + model_id=model_id, + tools_payload=tools_payload, + ) + for index, transport in enumerate(transports): + try: + if transport == "responses": + return self._call_responses_sync( + client=client, + model_id=model_id, + messages_payload=messages_payload, + tools_payload=tools_payload, + max_tokens=max_tokens, + stream=stream, + reasoning_effort=reasoning_effort, + kwargs=kwargs, + ) + return self._call_completion_sync( + client=client, + provider_name=provider_name, + model_id=model_id, + messages_payload=messages_payload, + tools_payload=tools_payload, + max_tokens=max_tokens, + stream=stream, + reasoning_effort=reasoning_effort, + kwargs=kwargs, + ) + except Exception as exc: + has_next_transport = index + 1 < len(transports) + if has_next_transport and self._should_fallback_transport(exc): + continue + raise + + async def _call_client_async( + self, + *, + client: AnyLLM, + provider_name: str, + model_id: str, + messages_payload: list[dict[str, Any]], + tools_payload: list[dict[str, Any]] | None, + max_tokens: int | None, + stream: bool, + reasoning_effort: Any | None, + kwargs: dict[str, Any], + ) -> Any: + transports = self._transport_order( + client, + provider_name=provider_name, + model_id=model_id, + tools_payload=tools_payload, + ) + for index, transport in enumerate(transports): + try: + if transport == "responses": + return await self._call_responses_async( + client=client, + model_id=model_id, + messages_payload=messages_payload, + tools_payload=tools_payload, + 
max_tokens=max_tokens, + stream=stream, + reasoning_effort=reasoning_effort, + kwargs=kwargs, + ) + return await self._call_completion_async( + client=client, + provider_name=provider_name, + model_id=model_id, + messages_payload=messages_payload, + tools_payload=tools_payload, + max_tokens=max_tokens, + stream=stream, + reasoning_effort=reasoning_effort, + kwargs=kwargs, + ) + except Exception as exc: + has_next_transport = index + 1 < len(transports) + if has_next_transport and self._should_fallback_transport(exc): + continue + raise + @staticmethod def _split_messages_for_responses( messages: list[dict[str, Any]], diff --git a/src/republic/core/provider_policies.py b/src/republic/core/provider_policies.py index 1fae74b..bb3c3a2 100644 --- a/src/republic/core/provider_policies.py +++ b/src/republic/core/provider_policies.py @@ -42,16 +42,13 @@ def _responses_tools_blocked_for_model(provider_name: str, model_id: str) -> boo return any(lowered_model.startswith(prefix) for prefix in policy.responses_tools_blocked_model_prefixes) -def should_use_responses( +def should_attempt_responses( *, provider_name: str, model_id: str, has_tools: bool, - use_responses: bool, supports_responses: bool, ) -> bool: - if not use_responses: - return False if has_tools and _responses_tools_blocked_for_model(provider_name, model_id): return False if supports_responses: @@ -59,6 +56,27 @@ def should_use_responses( return provider_policy(provider_name).enable_responses_without_capability +def transport_order( + *, + provider_name: str, + model_id: str, + has_tools: bool, + use_responses: bool, + supports_responses: bool, +) -> tuple[str, ...]: + attempt_responses = should_attempt_responses( + provider_name=provider_name, + model_id=model_id, + has_tools=has_tools, + supports_responses=supports_responses, + ) + if not attempt_responses: + return ("completion",) + if use_responses: + return ("responses", "completion") + return ("completion", "responses") + + def should_include_completion_stream_usage(provider_name: str) -> bool: return provider_policy(provider_name).include_usage_in_completion_stream diff --git a/src/republic/core/request_adapters.py b/src/republic/core/request_adapters.py new file mode 100644 index 0000000..7f68bbd --- /dev/null +++ b/src/republic/core/request_adapters.py @@ -0,0 +1,26 @@ +"""Request-shape adapters for different upstream APIs.""" + +from __future__ import annotations + +from typing import Any + + +def normalize_responses_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]: + """Normalize completion-style kwargs into responses-compatible shapes.""" + tool_choice = kwargs.get("tool_choice") + if not isinstance(tool_choice, dict): + return kwargs + + function = tool_choice.get("function") + if not isinstance(function, dict): + return kwargs + + function_name = function.get("name") + if not isinstance(function_name, str) or not function_name: + return kwargs + + normalized_tool_choice = dict(tool_choice) + normalized_tool_choice.pop("function", None) + normalized_tool_choice["type"] = normalized_tool_choice.get("type", "function") + normalized_tool_choice["name"] = function_name + return {**kwargs, "tool_choice": normalized_tool_choice} diff --git a/tests/test_provider_policies.py b/tests/test_provider_policies.py index e7f8108..cd5d993 100644 --- a/tests/test_provider_policies.py +++ b/tests/test_provider_policies.py @@ -1,84 +1,100 @@ from republic.core import provider_policies -def test_should_use_responses_respects_global_flag() -> None: +def 
test_should_attempt_responses_accepts_provider_capability() -> None: assert ( - provider_policies.should_use_responses( - provider_name="openai", - model_id="gpt-4o-mini", - has_tools=False, - use_responses=False, - supports_responses=True, - ) - is False - ) - - -def test_should_use_responses_accepts_provider_capability() -> None: - assert ( - provider_policies.should_use_responses( + provider_policies.should_attempt_responses( provider_name="anthropic", model_id="claude-3-5-haiku-latest", has_tools=False, - use_responses=True, supports_responses=True, ) is True ) -def test_should_use_responses_openrouter_policy_fallback() -> None: +def test_should_attempt_responses_openrouter_policy_fallback() -> None: assert ( - provider_policies.should_use_responses( + provider_policies.should_attempt_responses( provider_name="openrouter", model_id="openai/gpt-4o-mini", has_tools=False, - use_responses=True, supports_responses=False, ) is True ) -def test_should_use_responses_requires_explicit_policy_or_capability() -> None: +def test_should_attempt_responses_requires_explicit_policy_or_capability() -> None: assert ( - provider_policies.should_use_responses( + provider_policies.should_attempt_responses( provider_name="anthropic", model_id="claude-3-5-haiku-latest", has_tools=False, - use_responses=True, supports_responses=False, ) is False ) -def test_should_use_responses_openrouter_anthropic_tools_fallbacks_to_completion() -> None: +def test_should_attempt_responses_openrouter_anthropic_tools_disabled() -> None: assert ( - provider_policies.should_use_responses( + provider_policies.should_attempt_responses( provider_name="openrouter", model_id="anthropic/claude-3.5-haiku", has_tools=True, - use_responses=True, supports_responses=False, ) is False ) -def test_should_use_responses_openrouter_anthropic_without_tools_still_uses_responses() -> None: +def test_should_attempt_responses_openrouter_anthropic_without_tools_enabled() -> None: assert ( - provider_policies.should_use_responses( + provider_policies.should_attempt_responses( provider_name="openrouter", model_id="anthropic/claude-3.5-haiku", has_tools=False, - use_responses=True, supports_responses=False, ) is True ) +def test_transport_order_respects_user_preference() -> None: + assert provider_policies.transport_order( + provider_name="openrouter", + model_id="openai/gpt-4o-mini", + has_tools=False, + use_responses=True, + supports_responses=False, + ) == ("responses", "completion") + assert provider_policies.transport_order( + provider_name="openrouter", + model_id="openai/gpt-4o-mini", + has_tools=False, + use_responses=False, + supports_responses=False, + ) == ("completion", "responses") + + +def test_transport_order_uses_completion_only_when_responses_unavailable() -> None: + assert provider_policies.transport_order( + provider_name="anthropic", + model_id="claude-3-5-haiku-latest", + has_tools=False, + use_responses=True, + supports_responses=False, + ) == ("completion",) + assert provider_policies.transport_order( + provider_name="openrouter", + model_id="anthropic/claude-3.5-haiku", + has_tools=True, + use_responses=True, + supports_responses=False, + ) == ("completion",) + + def test_completion_stream_usage_policy() -> None: assert provider_policies.should_include_completion_stream_usage("openai") assert provider_policies.should_include_completion_stream_usage("openrouter") diff --git a/tests/test_responses_handling.py b/tests/test_responses_handling.py index 4ca236b..8eb6cc7 100644 --- a/tests/test_responses_handling.py +++ 
b/tests/test_responses_handling.py @@ -3,6 +3,7 @@ from typing import Any import pytest +from any_llm.exceptions import InvalidRequestError from republic import LLM, tool from republic.clients.chat import ChatClient @@ -180,6 +181,35 @@ def test_openrouter_anthropic_with_tools_falls_back_to_completion(fake_anyllm) - assert client.calls[-1].get("responses") is None +def test_tool_calls_fallback_to_completion_when_responses_rejects_request(fake_anyllm) -> None: + client = fake_anyllm.ensure("openrouter") + client.SUPPORTS_RESPONSES = True + client.queue_responses(InvalidRequestError("responses rejected this tool request")) + client.queue_completion(make_response(tool_calls=[make_tool_call("echo", '{"text":"tokyo"}')])) + + llm = LLM(model="openrouter:openai/gpt-4o-mini", api_key="dummy", use_responses=True) + calls = llm.tool_calls("Call echo for tokyo", tools=[echo]) + + assert len(calls) == 1 + assert calls[0]["function"]["name"] == "echo" + assert client.calls[-2].get("responses") is True + assert client.calls[-1].get("responses") is None + + +def test_chat_fallback_to_responses_when_completion_rejects_request(fake_anyllm) -> None: + client = fake_anyllm.ensure("openrouter") + client.SUPPORTS_RESPONSES = True + client.queue_completion(InvalidRequestError("completion rejected request")) + client.queue_responses(make_responses_response(text="hello")) + + llm = LLM(model="openrouter:openai/gpt-4o-mini", api_key="dummy", use_responses=False) + result = llm.chat("hi") + + assert result == "hello" + assert client.calls[-2].get("responses") is None + assert client.calls[-1].get("responses") is True + + def test_responses_tool_choice_accepts_completion_function_shape(fake_anyllm) -> None: client = fake_anyllm.ensure("openai") client.queue_responses(
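Patches 05 and 06 together reduce transport selection to two steps: compute an ordered list of transports, then fall through only on request-shaped errors. A minimal standalone sketch of that control flow (the callables and ValueError here are stand-ins for republic's clients and classified error kinds; only transport_order's shape follows the patch, with availability collapsed into one boolean covering the capability flag and the provider/model/tool gating):

def transport_order(use_responses: bool, responses_available: bool) -> tuple[str, ...]:
    # Preference decides the order; availability decides whether the
    # responses transport appears in the list at all.
    if not responses_available:
        return ("completion",)
    return ("responses", "completion") if use_responses else ("completion", "responses")

def call_with_fallback(transports, calls, recoverable=(ValueError,)):
    # Only recoverable errors advance to the next transport; the last
    # transport re-raises so genuine failures still surface.
    for index, name in enumerate(transports):
        try:
            return calls[name]()
        except recoverable:
            if index + 1 == len(transports):
                raise

def rejecting_responses():
    raise ValueError("responses rejected this request")

result = call_with_fallback(
    transport_order(use_responses=True, responses_available=True),
    {"responses": rejecting_responses, "completion": lambda: "ok via completion"},
)
assert result == "ok via completion"

This is the behaviour the two fallback tests above pin down from the outside: the preferred transport is tried first, and the rejected attempt still shows up in the recorded calls before the successful one.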
From 28064d768271479fcdd20c45dd514ce6b95a2676 Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Sun, 1 Mar 2026 19:50:36 +0000 Subject: [PATCH 07/14] refactor: make API format explicit and transport-specific --- src/republic/clients/chat.py | 219 +++++++++++++------- src/republic/clients/parsing/completion.py | 4 + src/republic/core/execution.py | 221 +++++++++++---------- src/republic/core/provider_policies.py | 29 +-- src/republic/llm.py | 11 +- tests/test_provider_policies.py | 59 ++++-- tests/test_responses_handling.py | 138 ++++++------- 7 files changed, 400 insertions(+), 281 deletions(-) diff --git a/src/republic/clients/chat.py b/src/republic/clients/chat.py index 907c88c..62a82b8 100644 --- a/src/republic/clients/chat.py +++ b/src/republic/clients/chat.py @@ -6,14 +6,14 @@ from collections.abc import AsyncIterator, Callable, Iterator from dataclasses import dataclass from functools import partial -from typing import Any +from typing import Any, Literal from republic.clients.parsing import completion as completion_parser from republic.clients.parsing import responses as responses_parser from republic.clients.parsing.common import expand_tool_calls from republic.clients.parsing.common import field as _field from republic.core.errors import ErrorKind, RepublicError -from republic.core.execution import LLMCore +from republic.core.execution import LLMCore, TransportResponse from republic.core.results import ( AsyncStreamEvents, AsyncTextStream, @@ -213,7 +213,21 @@ def default_context(self) -> TapeContext: return self._tape.default_context @staticmethod - def _is_non_stream_response(response: Any) -> bool: + def _unwrap_response(response: Any) -> tuple[Any, Literal["completion", "responses"] | None]: + if isinstance(response, TransportResponse): + return response.payload, response.transport + return response, None + + @staticmethod + def _is_non_stream_response( + response: Any, + *, + transport: Literal["completion", "responses"] | None = None, + ) -> bool: + if transport == "responses": + return responses_parser.is_non_stream_response(response) + if transport == "completion": + return completion_parser.is_non_stream_response(response) return responses_parser.is_non_stream_response(response) def _validate_chat_input( self, *, @@ -882,12 +896,13 @@ def _handle_create_response( model_id: str, attempt: int, ) -> str | object: - text = self._extract_text(response) + payload, transport = self._unwrap_response(response) + text = self._extract_text(payload, transport=transport) if text: self._update_tape( prepared, text, - response=response, + response=payload, provider=provider_name, model=model_id, ) @@ -904,12 +919,13 @@ async def _handle_create_response_async( model_id: str, attempt: int, ) -> str | object: - text = self._extract_text(response) + payload, transport = self._unwrap_response(response) + text = self._extract_text(payload, transport=transport) if text: await self._update_tape_async( prepared, text, - response=response, + response=payload, provider=provider_name, model=model_id, ) @@ -926,13 +942,14 @@ def _handle_tool_calls_response( model_id: str, attempt: int, ) -> list[dict[str, Any]]: - calls = self._extract_tool_calls(response) + payload, transport = self._unwrap_response(response) + calls = self._extract_tool_calls(payload, transport=transport) self._update_tape( prepared, None, tool_calls=calls, tool_results=[], - response=response, + response=payload, provider=provider_name, model=model_id, ) @@ -946,13 +963,14 @@ async def _handle_tool_calls_response_async( model_id: str, attempt: int, ) -> list[dict[str, Any]]: - calls = self._extract_tool_calls(response) + payload, transport = self._unwrap_response(response) + calls = self._extract_tool_calls(payload, transport=transport) await self._update_tape_async( prepared, None, tool_calls=calls, tool_results=[], - response=response, + response=payload, provider=provider_name, model=model_id, ) @@ -966,7 +984,8 @@ def _handle_tools_auto_response( model_id: str, attempt: int, ) -> ToolAutoResult | object: - tool_calls = self._extract_tool_calls(response) + payload, transport = self._unwrap_response(response) + tool_calls = self._extract_tool_calls(payload, transport=transport) if tool_calls: execution = self._tool_executor.execute( tool_calls, @@ -978,18 +997,18 @@ def _handle_tools_auto_response( None, tool_calls=execution.tool_calls, tool_results=execution.tool_results, - response=response, + response=payload, provider=provider_name, model=model_id, ) return ToolAutoResult.tools_result(execution.tool_calls, execution.tool_results) - text = self._extract_text(response) + text = self._extract_text(payload, transport=transport) if text: self._update_tape( prepared, text, - response=response, + response=payload, provider=provider_name, model=model_id, ) @@ -1007,7 +1026,8 @@ async def _handle_tools_auto_response_async( model_id: str, attempt: int, ) -> ToolAutoResult | object: - tool_calls = self._extract_tool_calls(response) + payload, transport = self._unwrap_response(response) + tool_calls = self._extract_tool_calls(payload, transport=transport) if tool_calls: execution = await self._tool_executor.execute_async( tool_calls, @@ -1019,18 +1039,18 @@ async def _handle_tools_auto_response_async( None, tool_calls=execution.tool_calls, tool_results=execution.tool_results, -
response=response, + response=payload, provider=provider_name, model=model_id, ) return ToolAutoResult.tools_result(execution.tool_calls, execution.tool_results) - text = self._extract_text(response) + text = self._extract_text(payload, transport=transport) if text: await self._update_tape_async( prepared, text, - response=response, + response=payload, provider=provider_name, model=model_id, ) @@ -1416,9 +1436,10 @@ def _build_text_stream( model_id: str, attempt: int, ) -> TextStream: - if self._is_non_stream_response(response): - text = self._extract_text(response) - tool_calls = self._extract_tool_calls(response) + payload, transport = self._unwrap_response(response) + if self._is_non_stream_response(payload, transport=transport): + text = self._extract_text(payload, transport=transport) + tool_calls = self._extract_tool_calls(payload, transport=transport) state = StreamState() self._finalize_text_stream( prepared, @@ -1428,8 +1449,8 @@ def _build_text_stream( provider_name=provider_name, model_id=model_id, attempt=attempt, - usage=self._extract_usage(response), - response=response, + usage=self._extract_usage(payload, transport=transport), + response=payload, log_empty=False, ) return TextStream(iter([text]) if text else self._empty_iterator(), state=state) @@ -1443,15 +1464,15 @@ def _build_text_stream( def _iterator() -> Iterator[str]: nonlocal usage try: - for chunk in response: - deltas = self._extract_chunk_tool_call_deltas(chunk) + for chunk in payload: + deltas = self._extract_chunk_tool_call_deltas(chunk, transport=transport) if deltas: assembler.add_deltas(deltas) - text = self._extract_chunk_text(chunk) + text = self._extract_chunk_text(chunk, transport=transport) if text: parts.append(text) yield text - usage = self._extract_usage(chunk) or usage + usage = self._extract_usage(chunk, transport=transport) or usage except Exception as exc: kind = self._core.classify_exception(exc) wrapped = self._core.wrap_error(exc, kind, provider_name, model_id) @@ -1480,9 +1501,10 @@ async def _build_async_text_stream( model_id: str, attempt: int, ) -> AsyncTextStream: - if self._is_non_stream_response(response): - text = self._extract_text(response) - tool_calls = self._extract_tool_calls(response) + payload, transport = self._unwrap_response(response) + if self._is_non_stream_response(payload, transport=transport): + text = self._extract_text(payload, transport=transport) + tool_calls = self._extract_tool_calls(payload, transport=transport) state = StreamState() await self._finalize_text_stream_async( prepared, @@ -1492,8 +1514,8 @@ async def _build_async_text_stream( provider_name=provider_name, model_id=model_id, attempt=attempt, - usage=self._extract_usage(response), - response=response, + usage=self._extract_usage(payload, transport=transport), + response=payload, log_empty=False, ) @@ -1511,15 +1533,15 @@ async def _single() -> AsyncIterator[str]: async def _iterator() -> AsyncIterator[str]: nonlocal usage try: - async for chunk in response: - deltas = self._extract_chunk_tool_call_deltas(chunk) + async for chunk in payload: + deltas = self._extract_chunk_tool_call_deltas(chunk, transport=transport) if deltas: assembler.add_deltas(deltas) - text = self._extract_chunk_text(chunk) + text = self._extract_chunk_text(chunk, transport=transport) if text: parts.append(text) yield text - usage = self._extract_usage(chunk) or usage + usage = self._extract_usage(chunk, transport=transport) or usage except Exception as exc: kind = self._core.classify_exception(exc) wrapped = 
self._core.wrap_error(exc, kind, provider_name, model_id) @@ -1541,18 +1563,38 @@ async def _iterator() -> AsyncIterator[str]: return AsyncTextStream(_iterator(), state=state) @staticmethod - def _chunk_has_tool_calls(chunk: Any) -> bool: - return bool(ChatClient._extract_chunk_tool_call_deltas(chunk)) + def _chunk_has_tool_calls( + chunk: Any, + *, + transport: Literal["completion", "responses"] | None = None, + ) -> bool: + return bool(ChatClient._extract_chunk_tool_call_deltas(chunk, transport=transport)) @staticmethod - def _extract_chunk_tool_call_deltas(chunk: Any) -> list[Any]: + def _extract_chunk_tool_call_deltas( + chunk: Any, + *, + transport: Literal["completion", "responses"] | None = None, + ) -> list[Any]: + if transport == "responses": + return responses_parser.extract_chunk_tool_call_deltas(chunk) + if transport == "completion": + return completion_parser.extract_chunk_tool_call_deltas(chunk) responses_deltas = responses_parser.extract_chunk_tool_call_deltas(chunk) if responses_deltas: return responses_deltas return completion_parser.extract_chunk_tool_call_deltas(chunk) @staticmethod - def _extract_chunk_text(chunk: Any) -> str: + def _extract_chunk_text( + chunk: Any, + *, + transport: Literal["completion", "responses"] | None = None, + ) -> str: + if transport == "responses": + return responses_parser.extract_chunk_text(chunk) + if transport == "completion": + return completion_parser.extract_chunk_text(chunk) responses_text = responses_parser.extract_chunk_text(chunk) if responses_text: return responses_text @@ -1566,12 +1608,14 @@ def _build_event_stream( model_id: str, attempt: int, ) -> StreamEvents: - if self._is_non_stream_response(response): + payload, transport = self._unwrap_response(response) + if self._is_non_stream_response(payload, transport=transport): return self._build_event_stream_from_response( prepared, - response, + payload, provider_name, model_id, + transport=transport, ) state = StreamState() @@ -1584,10 +1628,10 @@ def _build_event_stream( def _iterator() -> Iterator[StreamEvent]: nonlocal usage, tool_calls, tool_results try: - for chunk in response: - usage = self._extract_usage(chunk) or usage - assembler.add_deltas(self._extract_chunk_tool_call_deltas(chunk)) - text = self._extract_chunk_text(chunk) + for chunk in payload: + usage = self._extract_usage(chunk, transport=transport) or usage + assembler.add_deltas(self._extract_chunk_tool_call_deltas(chunk, transport=transport)) + text = self._extract_chunk_text(chunk, transport=transport) if text: parts.append(text) yield StreamEvent("text", {"delta": text}) @@ -1639,12 +1683,14 @@ def _build_async_event_stream( model_id: str, attempt: int, ) -> AsyncStreamEvents: - if self._is_non_stream_response(response): + payload, transport = self._unwrap_response(response) + if self._is_non_stream_response(payload, transport=transport): return self._build_async_event_stream_from_response( prepared, - response, + payload, provider_name, model_id, + transport=transport, ) state = StreamState() @@ -1657,10 +1703,10 @@ def _build_async_event_stream( async def _iterator() -> AsyncIterator[StreamEvent]: nonlocal usage, tool_calls, tool_results try: - async for chunk in response: - usage = self._extract_usage(chunk) or usage - assembler.add_deltas(self._extract_chunk_tool_call_deltas(chunk)) - text = self._extract_chunk_text(chunk) + async for chunk in payload: + usage = self._extract_usage(chunk, transport=transport) or usage + assembler.add_deltas(self._extract_chunk_tool_call_deltas(chunk, 
transport=transport)) + text = self._extract_chunk_text(chunk, transport=transport) if text: parts.append(text) yield StreamEvent("text", {"delta": text}) @@ -1712,10 +1758,12 @@ def _build_event_stream_from_response( response: Any, provider_name: str, model_id: str, + *, + transport: Literal["completion", "responses"] | None = None, ) -> StreamEvents: - text = self._extract_text(response) - tool_calls = self._extract_tool_calls(response) - usage = self._extract_usage(response) + text = self._extract_text(response, transport=transport) + tool_calls = self._extract_tool_calls(response, transport=transport) + usage = self._extract_usage(response, transport=transport) state = StreamState(usage=usage) tool_results: list[Any] = [] try: @@ -1768,10 +1816,12 @@ def _build_async_event_stream_from_response( response: Any, provider_name: str, model_id: str, + *, + transport: Literal["completion", "responses"] | None = None, ) -> AsyncStreamEvents: - text = self._extract_text(response) - tool_calls = self._extract_tool_calls(response) - usage = self._extract_usage(response) + text = self._extract_text(response, transport=transport) + tool_calls = self._extract_tool_calls(response, transport=transport) + usage = self._extract_usage(response, transport=transport) state = StreamState(usage=usage) tool_results: list[Any] = [] @@ -1871,18 +1921,38 @@ def _extract_text_from_responses_output(output: Any) -> str: return responses_parser.extract_text_from_output(output) @staticmethod - def _extract_text(response: Any) -> str: - responses_text = responses_parser.extract_text(response) + def _extract_text( + response: Any, + *, + transport: Literal["completion", "responses"] | None = None, + ) -> str: + payload, detected_transport = ChatClient._unwrap_response(response) + effective_transport = transport or detected_transport + if effective_transport == "responses": + return responses_parser.extract_text(payload) + if effective_transport == "completion": + return completion_parser.extract_text(payload) + responses_text = responses_parser.extract_text(payload) if responses_text: return responses_text - return completion_parser.extract_text(response) + return completion_parser.extract_text(payload) @staticmethod - def _extract_tool_calls(response: Any) -> list[dict[str, Any]]: - output = _field(response, "output") + def _extract_tool_calls( + response: Any, + *, + transport: Literal["completion", "responses"] | None = None, + ) -> list[dict[str, Any]]: + payload, detected_transport = ChatClient._unwrap_response(response) + effective_transport = transport or detected_transport + if effective_transport == "responses": + return responses_parser.extract_tool_calls(_field(payload, "output")) + if effective_transport == "completion": + return completion_parser.extract_tool_calls(payload) + output = _field(payload, "output") if output is not None: return ChatClient._extract_responses_tool_calls(output) - return ChatClient._extract_completion_tool_calls(response) + return ChatClient._extract_completion_tool_calls(payload) @staticmethod def _extract_responses_tool_calls(output: Any) -> list[dict[str, Any]]: @@ -1893,5 +1963,20 @@ def _extract_completion_tool_calls(response: Any) -> list[dict[str, Any]]: return completion_parser.extract_tool_calls(response) @staticmethod - def _extract_usage(response: Any) -> dict[str, Any] | None: - return responses_parser.extract_usage(response) + def _extract_usage( + response: Any, + *, + transport: Literal["completion", "responses"] | None = None, + ) -> dict[str, Any] | None: + 
payload, detected_transport = ChatClient._unwrap_response(response) + effective_transport = transport or detected_transport + if effective_transport == "completion": + usage = _field(payload, "usage") + if usage is None: + return None + if isinstance(usage, dict): + return dict(usage) + if hasattr(usage, "model_dump"): + return usage.model_dump() + return None + return responses_parser.extract_usage(payload) diff --git a/src/republic/clients/parsing/completion.py b/src/republic/clients/parsing/completion.py index 26c2f4c..5d4dfdf 100644 --- a/src/republic/clients/parsing/completion.py +++ b/src/republic/clients/parsing/completion.py @@ -7,6 +7,10 @@ from republic.clients.parsing.common import expand_tool_calls, field +def is_non_stream_response(response: Any) -> bool: + return isinstance(response, str) or field(response, "choices") is not None + + def extract_chunk_tool_call_deltas(chunk: Any) -> list[Any]: choices = field(chunk, "choices") if not choices: diff --git a/src/republic/core/execution.py b/src/republic/core/execution.py index 0fd5394..0011e12 100644 --- a/src/republic/core/execution.py +++ b/src/republic/core/execution.py @@ -9,7 +9,7 @@ from collections.abc import Callable from dataclasses import dataclass from enum import Enum, auto -from typing import Any, NoReturn +from typing import Any, Literal, NoReturn from any_llm import AnyLLM from any_llm.exceptions import ( @@ -48,6 +48,12 @@ class AttemptOutcome: decision: AttemptDecision +@dataclass(frozen=True) +class TransportResponse: + transport: Literal["completion", "responses"] + payload: Any + + class LLMCore: """Shared LLM execution utilities (provider resolution, retries, client cache).""" @@ -63,7 +69,7 @@ def __init__( api_key: str | dict[str, str] | None, api_base: str | dict[str, str] | None, client_args: dict[str, Any], - use_responses: bool, + api_format: Literal["completion", "responses", "anthropic_messages"], verbose: int, error_classifier: Callable[[Exception], ErrorKind | None] | None = None, ) -> None: @@ -74,7 +80,7 @@ def __init__( self._api_key = api_key self._api_base = api_base self._client_args = client_args - self._use_responses = use_responses + self._api_format = api_format self._verbose = verbose self._error_classifier = error_classifier self._client_cache: dict[str, AnyLLM] = {} @@ -330,8 +336,12 @@ def raise_wrapped(self, exc: Exception, provider: str, model: str) -> NoReturn: raise self.wrap_error(exc, kind, provider, model) from exc def _handle_attempt_error(self, exc: Exception, provider_name: str, model_id: str, attempt: int) -> AttemptOutcome: - kind = self.classify_exception(exc) - wrapped = self.wrap_error(exc, kind, provider_name, model_id) + if isinstance(exc, RepublicError): + wrapped = exc + kind = exc.kind + else: + kind = self.classify_exception(exc) + wrapped = self.wrap_error(exc, kind, provider_name, model_id) self.log_error(wrapped, provider_name, model_id, attempt) can_retry_same_model = self.should_retry(kind) and attempt + 1 < self.max_attempts() if can_retry_same_model: @@ -409,25 +419,36 @@ def _convert_tools_for_responses(tools_payload: list[dict[str, Any]] | None) -> converted_tools.append(dict(tool)) return converted_tools - def _transport_order( + def _selected_transport( self, client: AnyLLM, *, provider_name: str, model_id: str, tools_payload: list[dict[str, Any]] | None, - ) -> tuple[str, ...]: - return provider_policies.transport_order( + ) -> Literal["completion", "responses"]: + if self._api_format == "completion": + return "completion" + if self._api_format == 
"anthropic_messages": + if not provider_policies.supports_anthropic_messages_format( + provider_name=provider_name, + model_id=model_id, + ): + raise RepublicError( + ErrorKind.INVALID_INPUT, + f"{provider_name}:{model_id}: anthropic_messages format is only valid for Anthropic models", + ) + return "completion" + + reason = provider_policies.responses_rejection_reason( provider_name=provider_name, model_id=model_id, has_tools=bool(tools_payload), - use_responses=self._use_responses, supports_responses=bool(getattr(client, "SUPPORTS_RESPONSES", False)), ) - - def _should_fallback_transport(self, exc: Exception) -> bool: - kind = self.classify_exception(exc) - return kind in {ErrorKind.INVALID_INPUT, ErrorKind.PROVIDER} + if reason is not None: + raise RepublicError(ErrorKind.INVALID_INPUT, f"{provider_name}:{model_id}: {reason}") + return "responses" def _call_responses_sync( self, @@ -443,13 +464,16 @@ def _call_responses_sync( ) -> Any: instructions, input_items = self._split_messages_for_responses(messages_payload) responses_kwargs = self._with_responses_reasoning(kwargs, reasoning_effort) - return client.responses( - model=model_id, - input_data=input_items, - tools=self._convert_tools_for_responses(tools_payload), - stream=stream, - instructions=instructions, - **self._decide_responses_kwargs(max_tokens, responses_kwargs), + return TransportResponse( + transport="responses", + payload=client.responses( + model=model_id, + input_data=input_items, + tools=self._convert_tools_for_responses(tools_payload), + stream=stream, + instructions=instructions, + **self._decide_responses_kwargs(max_tokens, responses_kwargs), + ), ) def _call_completion_sync( @@ -467,13 +491,16 @@ def _call_completion_sync( ) -> Any: completion_kwargs = self._decide_kwargs_for_provider(provider_name, max_tokens, kwargs) completion_kwargs = self._with_default_completion_stream_options(provider_name, stream, completion_kwargs) - return client.completion( - model=model_id, - messages=messages_payload, - tools=tools_payload, - stream=stream, - reasoning_effort=reasoning_effort, - **completion_kwargs, + return TransportResponse( + transport="completion", + payload=client.completion( + model=model_id, + messages=messages_payload, + tools=tools_payload, + stream=stream, + reasoning_effort=reasoning_effort, + **completion_kwargs, + ), ) async def _call_responses_async( @@ -490,13 +517,16 @@ async def _call_responses_async( ) -> Any: instructions, input_items = self._split_messages_for_responses(messages_payload) responses_kwargs = self._with_responses_reasoning(kwargs, reasoning_effort) - return await client.aresponses( - model=model_id, - input_data=input_items, - tools=self._convert_tools_for_responses(tools_payload), - stream=stream, - instructions=instructions, - **self._decide_responses_kwargs(max_tokens, responses_kwargs), + return TransportResponse( + transport="responses", + payload=await client.aresponses( + model=model_id, + input_data=input_items, + tools=self._convert_tools_for_responses(tools_payload), + stream=stream, + instructions=instructions, + **self._decide_responses_kwargs(max_tokens, responses_kwargs), + ), ) async def _call_completion_async( @@ -514,13 +544,16 @@ async def _call_completion_async( ) -> Any: completion_kwargs = self._decide_kwargs_for_provider(provider_name, max_tokens, kwargs) completion_kwargs = self._with_default_completion_stream_options(provider_name, stream, completion_kwargs) - return await client.acompletion( - model=model_id, - messages=messages_payload, - tools=tools_payload, - 
stream=stream, - reasoning_effort=reasoning_effort, - **completion_kwargs, + return TransportResponse( + transport="completion", + payload=await client.acompletion( + model=model_id, + messages=messages_payload, + tools=tools_payload, + stream=stream, + reasoning_effort=reasoning_effort, + **completion_kwargs, + ), ) def _call_client_sync( @@ -536,41 +569,34 @@ def _call_client_sync( reasoning_effort: Any | None, kwargs: dict[str, Any], ) -> Any: - transports = self._transport_order( + transport = self._selected_transport( client, provider_name=provider_name, model_id=model_id, tools_payload=tools_payload, ) - for index, transport in enumerate(transports): - try: - if transport == "responses": - return self._call_responses_sync( - client=client, - model_id=model_id, - messages_payload=messages_payload, - tools_payload=tools_payload, - max_tokens=max_tokens, - stream=stream, - reasoning_effort=reasoning_effort, - kwargs=kwargs, - ) - return self._call_completion_sync( - client=client, - provider_name=provider_name, - model_id=model_id, - messages_payload=messages_payload, - tools_payload=tools_payload, - max_tokens=max_tokens, - stream=stream, - reasoning_effort=reasoning_effort, - kwargs=kwargs, - ) - except Exception as exc: - has_next_transport = index + 1 < len(transports) - if has_next_transport and self._should_fallback_transport(exc): - continue - raise + if transport == "responses": + return self._call_responses_sync( + client=client, + model_id=model_id, + messages_payload=messages_payload, + tools_payload=tools_payload, + max_tokens=max_tokens, + stream=stream, + reasoning_effort=reasoning_effort, + kwargs=kwargs, + ) + return self._call_completion_sync( + client=client, + provider_name=provider_name, + model_id=model_id, + messages_payload=messages_payload, + tools_payload=tools_payload, + max_tokens=max_tokens, + stream=stream, + reasoning_effort=reasoning_effort, + kwargs=kwargs, + ) async def _call_client_async( self, @@ -585,41 +611,34 @@ async def _call_client_async( reasoning_effort: Any | None, kwargs: dict[str, Any], ) -> Any: - transports = self._transport_order( + transport = self._selected_transport( client, provider_name=provider_name, model_id=model_id, tools_payload=tools_payload, ) - for index, transport in enumerate(transports): - try: - if transport == "responses": - return await self._call_responses_async( - client=client, - model_id=model_id, - messages_payload=messages_payload, - tools_payload=tools_payload, - max_tokens=max_tokens, - stream=stream, - reasoning_effort=reasoning_effort, - kwargs=kwargs, - ) - return await self._call_completion_async( - client=client, - provider_name=provider_name, - model_id=model_id, - messages_payload=messages_payload, - tools_payload=tools_payload, - max_tokens=max_tokens, - stream=stream, - reasoning_effort=reasoning_effort, - kwargs=kwargs, - ) - except Exception as exc: - has_next_transport = index + 1 < len(transports) - if has_next_transport and self._should_fallback_transport(exc): - continue - raise + if transport == "responses": + return await self._call_responses_async( + client=client, + model_id=model_id, + messages_payload=messages_payload, + tools_payload=tools_payload, + max_tokens=max_tokens, + stream=stream, + reasoning_effort=reasoning_effort, + kwargs=kwargs, + ) + return await self._call_completion_async( + client=client, + provider_name=provider_name, + model_id=model_id, + messages_payload=messages_payload, + tools_payload=tools_payload, + max_tokens=max_tokens, + stream=stream, + 
reasoning_effort=reasoning_effort, + kwargs=kwargs, + ) @staticmethod def _split_messages_for_responses( diff --git a/src/republic/core/provider_policies.py b/src/republic/core/provider_policies.py index bb3c3a2..132183d 100644 --- a/src/republic/core/provider_policies.py +++ b/src/republic/core/provider_policies.py @@ -56,25 +56,26 @@ def should_attempt_responses( return provider_policy(provider_name).enable_responses_without_capability -def transport_order( +def responses_rejection_reason( *, provider_name: str, model_id: str, has_tools: bool, - use_responses: bool, supports_responses: bool, -) -> tuple[str, ...]: - attempt_responses = should_attempt_responses( - provider_name=provider_name, - model_id=model_id, - has_tools=has_tools, - supports_responses=supports_responses, - ) - if not attempt_responses: - return ("completion",) - if use_responses: - return ("responses", "completion") - return ("completion", "responses") +) -> str | None: + if has_tools and _responses_tools_blocked_for_model(provider_name, model_id): + return "responses format is not supported for this model when tools are enabled" + if supports_responses: + return None + if provider_policy(provider_name).enable_responses_without_capability: + return None + return "responses format is not supported by this provider" + + +def supports_anthropic_messages_format(*, provider_name: str, model_id: str) -> bool: + normalized_provider = _normalize_provider_name(provider_name) + normalized_model = model_id.strip().lower() + return normalized_provider == "anthropic" or normalized_model.startswith("anthropic/") def should_include_completion_stream_usage(provider_name: str) -> bool: diff --git a/src/republic/llm.py b/src/republic/llm.py index 445f27b..54278de 100644 --- a/src/republic/llm.py +++ b/src/republic/llm.py @@ -4,7 +4,7 @@ import warnings from collections.abc import Callable -from typing import Any, cast +from typing import Any, Literal, cast from republic.__about__ import DEFAULT_MODEL from republic.clients._internal import InternalOps @@ -48,7 +48,7 @@ def __init__( api_key: str | dict[str, str] | None = None, api_base: str | dict[str, str] | None = None, client_args: dict[str, Any] | None = None, - use_responses: bool = False, + api_format: Literal["completion", "responses", "anthropic_messages"] = "completion", verbose: int = 0, tape_store: TapeStore | AsyncTapeStore | None = None, context: TapeContext | None = None, @@ -58,6 +58,11 @@ def __init__( raise RepublicError(ErrorKind.INVALID_INPUT, "verbose must be 0, 1, or 2") if max_retries < 0: raise RepublicError(ErrorKind.INVALID_INPUT, "max_retries must be >= 0") + if api_format not in {"completion", "responses", "anthropic_messages"}: + raise RepublicError( + ErrorKind.INVALID_INPUT, + "api_format must be 'completion', 'responses', or 'anthropic_messages'", + ) if not model: model = DEFAULT_MODEL @@ -73,7 +78,7 @@ def __init__( api_key=api_key, api_base=api_base, client_args=client_args or {}, - use_responses=use_responses, + api_format=api_format, verbose=verbose, error_classifier=error_classifier, ) diff --git a/tests/test_provider_policies.py b/tests/test_provider_policies.py index cd5d993..7a1412b 100644 --- a/tests/test_provider_policies.py +++ b/tests/test_provider_policies.py @@ -61,38 +61,53 @@ def test_should_attempt_responses_openrouter_anthropic_without_tools_enabled() - ) -def test_transport_order_respects_user_preference() -> None: - assert provider_policies.transport_order( - provider_name="openrouter", - model_id="openai/gpt-4o-mini", - 
has_tools=False, - use_responses=True, - supports_responses=False, - ) == ("responses", "completion") - assert provider_policies.transport_order( - provider_name="openrouter", - model_id="openai/gpt-4o-mini", - has_tools=False, - use_responses=False, - supports_responses=False, - ) == ("completion", "responses") +def test_responses_rejection_reason_none_when_openrouter_responses_available() -> None: + assert ( + provider_policies.responses_rejection_reason( + provider_name="openrouter", + model_id="openai/gpt-4o-mini", + has_tools=False, + supports_responses=False, + ) + is None + ) -def test_transport_order_uses_completion_only_when_responses_unavailable() -> None: - assert provider_policies.transport_order( +def test_responses_rejection_reason_for_provider_without_responses() -> None: + reason = provider_policies.responses_rejection_reason( provider_name="anthropic", model_id="claude-3-5-haiku-latest", has_tools=False, - use_responses=True, supports_responses=False, - ) == ("completion",) - assert provider_policies.transport_order( + ) + assert reason is not None + assert "not supported" in reason + + +def test_responses_rejection_reason_for_openrouter_anthropic_tools() -> None: + reason = provider_policies.responses_rejection_reason( provider_name="openrouter", model_id="anthropic/claude-3.5-haiku", has_tools=True, - use_responses=True, supports_responses=False, - ) == ("completion",) + ) + assert reason is not None + assert "tools" in reason + + +def test_supports_anthropic_messages_format() -> None: + assert provider_policies.supports_anthropic_messages_format( + provider_name="anthropic", + model_id="claude-3-5-haiku-latest", + ) + assert provider_policies.supports_anthropic_messages_format( + provider_name="openrouter", + model_id="anthropic/claude-3.5-haiku", + ) + assert not provider_policies.supports_anthropic_messages_format( + provider_name="openai", + model_id="gpt-4o-mini", + ) def test_completion_stream_usage_policy() -> None: diff --git a/tests/test_responses_handling.py b/tests/test_responses_handling.py index 8eb6cc7..4d0cdd5 100644 --- a/tests/test_responses_handling.py +++ b/tests/test_responses_handling.py @@ -3,11 +3,11 @@ from typing import Any import pytest -from any_llm.exceptions import InvalidRequestError from republic import LLM, tool from republic.clients.chat import ChatClient from republic.core.execution import LLMCore +from republic.core.results import ErrorPayload from .fakes import ( make_chunk, @@ -111,17 +111,8 @@ def _completion_stream_event_items() -> list[Any]: ] -def _main_path_payloads(*, use_responses: bool, async_mode: bool) -> list[Any]: +def _main_path_payloads(*, async_mode: bool) -> list[Any]: wrap_stream = _as_async_iter if async_mode else iter - if use_responses: - return [ - make_responses_response(text="ready"), - make_responses_response(tool_calls=[make_responses_function_call("echo", '{"text":"tokyo"}')]), - make_responses_response(tool_calls=[make_responses_function_call("echo", '{"text":"tokyo"}')]), - wrap_stream(_responses_stream_text_items()), - wrap_stream(_responses_stream_event_items()), - ] - return [ make_response(text="ready"), make_response(tool_calls=[make_tool_call("echo", '{"text":"tokyo"}')]), @@ -131,20 +122,28 @@ def _main_path_payloads(*, use_responses: bool, async_mode: bool) -> list[Any]: ] -def _queue_main_path_fixtures(client: Any, *, use_responses: bool, async_mode: bool) -> None: - payloads = _main_path_payloads(use_responses=use_responses, async_mode=async_mode) - if use_responses: - queue = 
client.queue_aresponses if async_mode else client.queue_responses - else: - queue = client.queue_acompletion if async_mode else client.queue_completion +def _queue_main_path_fixtures(client: Any, *, async_mode: bool) -> None: + payloads = _main_path_payloads(async_mode=async_mode) + queue = client.queue_acompletion if async_mode else client.queue_completion queue(*payloads) -def test_llm_use_responses_calls_responses(fake_anyllm) -> None: +def test_default_api_format_uses_completion(fake_anyllm) -> None: + client = fake_anyllm.ensure("openai") + client.queue_completion(make_response(text="hello")) + + llm = LLM(model="openai:gpt-4o-mini", api_key="dummy") + result = llm.chat("hi") + + assert result == "hello" + assert client.calls[-1].get("responses") is None + + +def test_responses_api_format_uses_responses(fake_anyllm) -> None: client = fake_anyllm.ensure("openai") client.queue_responses(make_responses_response(text="hello")) - llm = LLM(model="openai:gpt-4o-mini", api_key="dummy", use_responses=True) + llm = LLM(model="openai:gpt-4o-mini", api_key="dummy", api_format="responses") result = llm.chat("hi") assert result == "hello" @@ -152,62 +151,61 @@ def test_llm_use_responses_calls_responses(fake_anyllm) -> None: assert client.calls[-1]["input_data"][0]["role"] == "user" -def test_openrouter_uses_responses_when_enabled_even_if_provider_flag_is_false(fake_anyllm) -> None: +def test_openrouter_responses_works_even_if_provider_flag_is_false(fake_anyllm) -> None: client = fake_anyllm.ensure("openrouter") client.SUPPORTS_RESPONSES = False client.queue_responses(make_responses_response(text="hello")) - llm = LLM(model="openrouter:openrouter/free", api_key="dummy", use_responses=True) + llm = LLM(model="openrouter:openrouter/free", api_key="dummy", api_format="responses") result = llm.chat("hi") assert result == "hello" assert client.calls[-1].get("responses") is True -def test_openrouter_anthropic_with_tools_falls_back_to_completion(fake_anyllm) -> None: +def test_openrouter_anthropic_tools_rejects_responses_format(fake_anyllm) -> None: client = fake_anyllm.ensure("openrouter") client.SUPPORTS_RESPONSES = False + llm = LLM(model="openrouter:anthropic/claude-3.5-haiku", api_key="dummy", api_format="responses") + + with pytest.raises(ErrorPayload) as exc_info: + llm.tool_calls( + "Call echo for tokyo", + tools=[echo], + tool_choice={"type": "function", "function": {"name": "echo"}}, + ) + assert exc_info.value.kind == "invalid_input" + + +def test_openrouter_anthropic_tools_work_with_completion_format(fake_anyllm) -> None: + client = fake_anyllm.ensure("openrouter") client.queue_completion(make_response(tool_calls=[make_tool_call("echo", '{"text":"tokyo"}')])) - llm = LLM(model="openrouter:anthropic/claude-3.5-haiku", api_key="dummy", use_responses=True) - calls = llm.tool_calls( - "Call echo for tokyo", - tools=[echo], - tool_choice={"type": "function", "function": {"name": "echo"}}, - ) + llm = LLM(model="openrouter:anthropic/claude-3.5-haiku", api_key="dummy") + calls = llm.tool_calls("Call echo for tokyo", tools=[echo]) assert len(calls) == 1 assert calls[0]["function"]["name"] == "echo" assert client.calls[-1].get("responses") is None -def test_tool_calls_fallback_to_completion_when_responses_rejects_request(fake_anyllm) -> None: +def test_anthropic_messages_format_maps_to_completion(fake_anyllm) -> None: client = fake_anyllm.ensure("openrouter") - client.SUPPORTS_RESPONSES = True - client.queue_responses(InvalidRequestError("responses rejected this tool request")) 
client.queue_completion(make_response(tool_calls=[make_tool_call("echo", '{"text":"tokyo"}')])) - llm = LLM(model="openrouter:openai/gpt-4o-mini", api_key="dummy", use_responses=True) + llm = LLM(model="openrouter:anthropic/claude-3.5-haiku", api_key="dummy", api_format="anthropic_messages") calls = llm.tool_calls("Call echo for tokyo", tools=[echo]) assert len(calls) == 1 assert calls[0]["function"]["name"] == "echo" - assert client.calls[-2].get("responses") is True assert client.calls[-1].get("responses") is None -def test_chat_fallback_to_responses_when_completion_rejects_request(fake_anyllm) -> None: - client = fake_anyllm.ensure("openrouter") - client.SUPPORTS_RESPONSES = True - client.queue_completion(InvalidRequestError("completion rejected request")) - client.queue_responses(make_responses_response(text="hello")) - - llm = LLM(model="openrouter:openai/gpt-4o-mini", api_key="dummy", use_responses=False) - result = llm.chat("hi") - - assert result == "hello" - assert client.calls[-2].get("responses") is None - assert client.calls[-1].get("responses") is True +def test_anthropic_messages_format_rejects_non_anthropic_model(fake_anyllm) -> None: + llm = LLM(model="openai:gpt-4o-mini", api_key="dummy", api_format="anthropic_messages") + with pytest.raises(ErrorPayload) as exc_info: + llm.chat("hi") + assert exc_info.value.kind == "invalid_input" def test_responses_tool_choice_accepts_completion_function_shape(fake_anyllm) -> None: @@ -216,7 +214,7 @@ def test_responses_tool_choice_accepts_completion_function_shape(fake_anyllm) -> make_responses_response(tool_calls=[make_responses_function_call("echo", '{"text":"tokyo"}')]) ) - llm = LLM(model="openai:gpt-4o-mini", api_key="dummy", use_responses=True) + llm = LLM(model="openai:gpt-4o-mini", api_key="dummy", api_format="responses") calls = llm.tool_calls( "Call echo for tokyo", tools=[echo], @@ -329,7 +327,7 @@ def test_stream_uses_responses_and_collects_usage(fake_anyllm) -> None: ]) ) - llm = LLM(model="openrouter:openrouter/free", api_key="dummy", use_responses=True) + llm = LLM(model="openrouter:openrouter/free", api_key="dummy", api_format="responses") stream = llm.stream("Say hello") text = "".join(list(stream)) @@ -352,7 +350,7 @@ def test_stream_events_supports_responses_tool_events(fake_anyllm) -> None: ]) ) - llm = LLM(model="openrouter:openrouter/free", api_key="dummy", use_responses=True) + llm = LLM(model="openrouter:openrouter/free", api_key="dummy", api_format="responses") stream = llm.stream_events("Call echo for tokyo", tools=[echo]) events = list(stream) @@ -392,7 +390,7 @@ def test_stream_and_events_support_responses_dict_events(fake_anyllm) -> None: ) client.queue_responses(make_responses_response(text="ready", usage={"total_tokens": 3})) - llm = LLM(model="openrouter:openrouter/free", api_key="dummy", use_responses=True) + llm = LLM(model="openrouter:openrouter/free", api_key="dummy", api_format="responses") events_stream = llm.stream_events("Call echo for tokyo", tools=[echo]) events = list(events_stream) tool_call = next(event for event in events if event.kind == "tool_call") @@ -431,7 +429,7 @@ def test_stream_events_parity_between_completion_and_responses(fake_anyllm) -> N ) completion_llm = LLM(model="openai:gpt-4o-mini", api_key="dummy") - responses_llm = LLM(model="openrouter:openrouter/free", api_key="dummy", use_responses=True) + responses_llm = LLM(model="openrouter:openrouter/free", api_key="dummy", api_format="responses") completion_events = list(completion_llm.stream_events("Call echo for tokyo", 
tools=[echo])) responses_events = list(responses_llm.stream_events("Call echo for tokyo", tools=[echo])) @@ -458,7 +456,7 @@ def test_stream_events_responses_output_item_events_keep_call_id(fake_anyllm) -> ]) ) - llm = LLM(model="openrouter:openrouter/free", api_key="dummy", use_responses=True) + llm = LLM(model="openrouter:openrouter/free", api_key="dummy", api_format="responses") events = list(llm.stream_events("Call echo for tokyo", tools=[echo])) tool_call = next(event for event in events if event.kind == "tool_call").data["call"] @@ -476,7 +474,7 @@ def test_stream_usage_accepts_responses_in_progress_usage(fake_anyllm) -> None: ]) ) - llm = LLM(model="openrouter:openrouter/free", api_key="dummy", use_responses=True) + llm = LLM(model="openrouter:openrouter/free", api_key="dummy", api_format="responses") stream = llm.stream("Reply with ok") assert "".join(list(stream)) == "ok" assert stream.usage == {"total_tokens": 2} @@ -488,7 +486,7 @@ def test_non_stream_responses_tool_calls_converts_tools_payload(fake_anyllm) -> make_responses_response(tool_calls=[make_responses_function_call("echo", '{"text":"tokyo"}')]) ) - llm = LLM(model="openrouter:openrouter/free", api_key="dummy", use_responses=True) + llm = LLM(model="openrouter:openrouter/free", api_key="dummy", api_format="responses") calls = llm.tool_calls("Call echo for tokyo", tools=[echo]) assert len(calls) == 1 @@ -507,7 +505,7 @@ def test_non_stream_responses_run_tools_uses_converted_tools(fake_anyllm) -> Non make_responses_response(tool_calls=[make_responses_function_call("echo", '{"text":"tokyo"}')]) ) - llm = LLM(model="openrouter:openrouter/free", api_key="dummy", use_responses=True) + llm = LLM(model="openrouter:openrouter/free", api_key="dummy", api_format="responses") result = llm.run_tools("Call echo for tokyo", tools=[echo]) assert result.kind == "tools" @@ -533,7 +531,7 @@ def test_chat_reasoning_effort_for_responses_is_mapped(fake_anyllm) -> None: client = fake_anyllm.ensure("openrouter") client.queue_responses(make_responses_response(text="ready")) - llm = LLM(model="openrouter:openrouter/free", api_key="dummy", use_responses=True) + llm = LLM(model="openrouter:openrouter/free", api_key="dummy", api_format="responses") assert llm.chat("Reply with ready", reasoning_effort="low") == "ready" call = client.calls[-1] @@ -546,7 +544,7 @@ def test_chat_reasoning_kwarg_has_priority_over_reasoning_effort(fake_anyllm) -> client = fake_anyllm.ensure("openrouter") client.queue_responses(make_responses_response(text="ready")) - llm = LLM(model="openrouter:openrouter/free", api_key="dummy", use_responses=True) + llm = LLM(model="openrouter:openrouter/free", api_key="dummy", api_format="responses") assert llm.chat("Reply with ready", reasoning_effort="low", reasoning={"effort": "high"}) == "ready" call = client.calls[-1] @@ -620,7 +618,7 @@ def test_non_stream_chat_parity_between_completion_and_responses(fake_anyllm) -> responses_client.queue_responses(make_responses_response(text="ready")) completion_llm = LLM(model="openai:gpt-4o-mini", api_key="dummy") - responses_llm = LLM(model="openrouter:openrouter/free", api_key="dummy", use_responses=True) + responses_llm = LLM(model="openrouter:openrouter/free", api_key="dummy", api_format="responses") assert completion_llm.chat("Reply with ready") == "ready" assert responses_llm.chat("Reply with ready") == "ready" @@ -636,7 +634,7 @@ def test_non_stream_run_tools_parity_between_completion_and_responses(fake_anyll ) completion_llm = LLM(model="openai:gpt-4o-mini", api_key="dummy") - 
responses_llm = LLM(model="openrouter:openrouter/free", api_key="dummy", use_responses=True) + responses_llm = LLM(model="openrouter:openrouter/free", api_key="dummy", api_format="responses") completion_result = completion_llm.run_tools("Call echo for tokyo", tools=[echo]) responses_result = responses_llm.run_tools("Call echo for tokyo", tools=[echo]) @@ -645,11 +643,10 @@ def test_non_stream_run_tools_parity_between_completion_and_responses(fake_anyll assert completion_result.tool_results == responses_result.tool_results == ["TOKYO"] -@pytest.mark.parametrize("use_responses", [False, True]) -def test_sync_main_paths_with_mode_switch(fake_anyllm, use_responses: bool) -> None: +def test_sync_main_paths(fake_anyllm) -> None: client = fake_anyllm.ensure("openai") - _queue_main_path_fixtures(client, use_responses=use_responses, async_mode=False) - llm = LLM(model="openai:gpt-4o-mini", api_key="dummy", use_responses=use_responses) + _queue_main_path_fixtures(client, async_mode=False) + llm = LLM(model="openai:gpt-4o-mini", api_key="dummy") assert llm.chat("Reply with ready") == "ready" @@ -675,18 +672,14 @@ def test_sync_main_paths_with_mode_switch(fake_anyllm, use_responses: bool) -> N assert compact[-1][0] == "final" assert event_stream.usage == {"total_tokens": 12} - if use_responses: - assert all(call.get("responses") is True for call in client.calls) - else: - assert all(call.get("responses") is None for call in client.calls) + assert all(call.get("responses") is None for call in client.calls) @pytest.mark.asyncio -@pytest.mark.parametrize("use_responses", [False, True]) -async def test_async_main_paths_with_mode_switch(fake_anyllm, use_responses: bool) -> None: +async def test_async_main_paths(fake_anyllm) -> None: client = fake_anyllm.ensure("openai") - _queue_main_path_fixtures(client, use_responses=use_responses, async_mode=True) - llm = LLM(model="openai:gpt-4o-mini", api_key="dummy", use_responses=use_responses) + _queue_main_path_fixtures(client, async_mode=True) + llm = LLM(model="openai:gpt-4o-mini", api_key="dummy") assert await llm.chat_async("Reply with ready") == "ready" @@ -712,7 +705,4 @@ async def test_async_main_paths_with_mode_switch(fake_anyllm, use_responses: boo assert compact[-1][0] == "final" assert event_stream.usage == {"total_tokens": 12} - if use_responses: - assert all(call.get("responses") is True for call in client.calls) - else: - assert all(call.get("responses") is None for call in client.calls) + assert all(call.get("responses") is None for call in client.calls) From 6871e2d315f5e37726df5c6ba95a6a7f6c4e3547 Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Sun, 1 Mar 2026 20:03:16 +0000 Subject: [PATCH 08/14] refactor: align completion/responses/messages transports across chat/tools/stream --- README.md | 8 +- docs/guides/chat.md | 21 +-- docs/guides/stream-events.md | 2 + docs/guides/tools.md | 9 +- docs/quickstart.md | 12 +- src/republic/clients/chat.py | 122 ++++++++--------- src/republic/clients/parsing/completion.py | 11 ++ src/republic/clients/parsing/messages.py | 84 ++++++++++++ src/republic/core/execution.py | 148 +++++++++++++++++++-- src/republic/core/provider_policies.py | 2 +- src/republic/llm.py | 6 +- tests/test_provider_policies.py | 8 +- tests/test_responses_handling.py | 52 +++++++- 13 files changed, 366 insertions(+), 119 deletions(-) create mode 100644 src/republic/clients/parsing/messages.py diff --git a/README.md b/README.md index e3e05c6..9237c44 100644 --- a/README.md +++ b/README.md @@ -31,17 +31,13 @@ if not api_key: llm = 
LLM(model="openrouter:openrouter/free", api_key=api_key) result = llm.chat("Describe Republic in one sentence.", max_tokens=48) - -if result.error: - print(result.error.kind, result.error.message) -else: - print(result.value) +print(result) ``` ## Why It Feels Natural - **Plain Python**: The main flow is regular functions and branches, no extra DSL. -- **Structured Result**: Core interfaces return `StructuredOutput`, with stable `ErrorKind` values. +- **Unified API surface**: Use the same `chat/tool_calls/run_tools/stream/stream_events` methods across completion, responses, and messages transports. - **Tools without magic**: Supports both automatic and manual tool execution with clear debugging and auditing. - **Tape-first memory**: Use anchor/handoff to bound context windows and replay full evidence. - **Event streaming**: Subscribe to text deltas, tool calls, tool results, usage, and final state. diff --git a/docs/guides/chat.md b/docs/guides/chat.md index 62fafe3..ee34b6a 100644 --- a/docs/guides/chat.md +++ b/docs/guides/chat.md @@ -9,7 +9,7 @@ from republic import LLM llm = LLM(model="openrouter:openrouter/free", api_key="") out = llm.chat("Output exactly one word: ready", max_tokens=8) -print(out.value, out.error) +print(out) ``` ## Messages Mode @@ -22,17 +22,22 @@ messages = [ out = llm.chat(messages=messages, max_tokens=48) ``` -## Structured Error Handling +## Transport Format (`api_format`) + +Republic exposes one public chat/tool/stream interface, and lets you choose the upstream API format explicitly: + +- `api_format="completion"` (default): chat-completions style. +- `api_format="responses"`: responses style. +- `api_format="messages"`: Anthropic messages style (only Anthropic models, including `openrouter:anthropic/...`). ```python -result = llm.chat("Write one sentence.") -if result.error: - if result.error.kind == "temporary": - print("retry later") - else: - print("fail fast:", result.error.message) +llm_completion = LLM(model="openai:gpt-4o-mini", api_key="", api_format="completion") +llm_responses = LLM(model="openrouter:openrouter/free", api_key="", api_format="responses") +llm_messages = LLM(model="openrouter:anthropic/claude-3.5-haiku", api_key="", api_format="messages") ``` +The same public methods are used in all formats: `chat`, `tool_calls`, `run_tools`, `stream`, and `stream_events`. + ## Retries and Fallback ```python diff --git a/docs/guides/stream-events.md b/docs/guides/stream-events.md index 6328fab..c447425 100644 --- a/docs/guides/stream-events.md +++ b/docs/guides/stream-events.md @@ -5,6 +5,8 @@ Republic provides two streaming modes: - `stream(...)`: text deltas only. - `stream_events(...)`: full events including text, tools, usage, and final. +Both modes keep the same public API across `completion`, `responses`, and `messages` transports. + ## Text Stream ```python diff --git a/docs/guides/tools.md b/docs/guides/tools.md index 54ee506..89cbdea 100644 --- a/docs/guides/tools.md +++ b/docs/guides/tools.md @@ -5,6 +5,8 @@ Tool workflows have two paths: - Automatic execution: `llm.run_tools(...)` - Manual execution: `llm.tool_calls(...)` + `llm.tools.execute(...)` +The tool API is transport-agnostic: you can use the same calls under `api_format="completion"`, `"responses"`, or `"messages"`. 
+ ## Define a Tool ```python @@ -23,7 +25,7 @@ from republic import LLM llm = LLM(model="openrouter:openai/gpt-4o-mini", api_key="") out = llm.run_tools("What is weather in Tokyo?", tools=[get_weather]) -print(out.kind) # text | tools | error +print(out.kind) # text | tools | error print(out.tool_results) print(out.error) ``` @@ -32,10 +34,7 @@ print(out.error) ```python calls = llm.tool_calls("Use get_weather for Berlin.", tools=[get_weather]) -if calls.error: - raise RuntimeError(calls.error.message) - -execution = llm.tools.execute(calls.value, tools=[get_weather]) +execution = llm.tools.execute(calls, tools=[get_weather]) print(execution.tool_results) print(execution.error) ``` diff --git a/docs/quickstart.md b/docs/quickstart.md index 78edb27..bdec6e0 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -20,11 +20,7 @@ llm = LLM(model="openrouter:openrouter/free", api_key="") ```python out = llm.chat("Write one short release note.", max_tokens=48) - -if out.error: - print("error:", out.error.kind, out.error.message) -else: - print("text:", out.value) +print("text:", out) ``` ## Step 3: Add an auditable trace to the session @@ -36,7 +32,7 @@ tape = llm.tape("release-notes") tape.handoff("draft_v1", state={"owner": "assistant"}) reply = tape.chat("Summarize the version changes in three bullets.", system_prompt="Keep it concise.") -print(reply.value) +print(reply) ``` ## Step 4: Handle failures and fallback @@ -50,7 +46,5 @@ llm = LLM( ) result = llm.chat("say hello", max_tokens=8) -if result.error: - # error.kind is one of invalid_input/config/provider/tool/temporary/not_found/unknown - print(result.error.kind, result.error.message) +print(result) ``` diff --git a/src/republic/clients/chat.py b/src/republic/clients/chat.py index 62a82b8..dd60237 100644 --- a/src/republic/clients/chat.py +++ b/src/republic/clients/chat.py @@ -9,6 +9,7 @@ from typing import Any, Literal from republic.clients.parsing import completion as completion_parser +from republic.clients.parsing import messages as messages_parser from republic.clients.parsing import responses as responses_parser from republic.clients.parsing.common import expand_tool_calls from republic.clients.parsing.common import field as _field @@ -31,6 +32,7 @@ from republic.tools.schema import ToolInput, ToolSet, normalize_tools MessageInput = dict[str, Any] +TransportKind = Literal["completion", "responses", "messages"] @dataclass(frozen=True) @@ -213,22 +215,42 @@ def default_context(self) -> TapeContext: return self._tape.default_context @staticmethod - def _unwrap_response(response: Any) -> tuple[Any, Literal["completion", "responses"] | None]: + def _unwrap_response(response: Any) -> tuple[Any, TransportKind | None]: if isinstance(response, TransportResponse): return response.payload, response.transport return response, None + @staticmethod + def _resolve_transport( + payload: Any, + transport: TransportKind | None = None, + ) -> TransportKind: + if transport is not None: + return transport + if _field(payload, "output") is not None: + return "responses" + event_type = _field(payload, "type") + if isinstance(event_type, str) and event_type.startswith("response."): + return "responses" + return "completion" + + @staticmethod + def _parser_for_transport(transport: TransportKind): + if transport == "responses": + return responses_parser + if transport == "messages": + return messages_parser + return completion_parser + @staticmethod def _is_non_stream_response( response: Any, *, - transport: Literal["completion", "responses"] | None = 
None, + transport: TransportKind | None = None, ) -> bool: - if transport == "responses": - return responses_parser.is_non_stream_response(response) - if transport == "completion": - return completion_parser.is_non_stream_response(response) - return responses_parser.is_non_stream_response(response) + effective_transport = ChatClient._resolve_transport(response, transport) + parser = ChatClient._parser_for_transport(effective_transport) + return parser.is_non_stream_response(response) def _validate_chat_input( self, @@ -1566,7 +1588,7 @@ async def _iterator() -> AsyncIterator[str]: def _chunk_has_tool_calls( chunk: Any, *, - transport: Literal["completion", "responses"] | None = None, + transport: TransportKind | None = None, ) -> bool: return bool(ChatClient._extract_chunk_tool_call_deltas(chunk, transport=transport)) @@ -1574,31 +1596,21 @@ def _chunk_has_tool_calls( def _extract_chunk_tool_call_deltas( chunk: Any, *, - transport: Literal["completion", "responses"] | None = None, + transport: TransportKind | None = None, ) -> list[Any]: - if transport == "responses": - return responses_parser.extract_chunk_tool_call_deltas(chunk) - if transport == "completion": - return completion_parser.extract_chunk_tool_call_deltas(chunk) - responses_deltas = responses_parser.extract_chunk_tool_call_deltas(chunk) - if responses_deltas: - return responses_deltas - return completion_parser.extract_chunk_tool_call_deltas(chunk) + effective_transport = ChatClient._resolve_transport(chunk, transport) + parser = ChatClient._parser_for_transport(effective_transport) + return parser.extract_chunk_tool_call_deltas(chunk) @staticmethod def _extract_chunk_text( chunk: Any, *, - transport: Literal["completion", "responses"] | None = None, + transport: TransportKind | None = None, ) -> str: - if transport == "responses": - return responses_parser.extract_chunk_text(chunk) - if transport == "completion": - return completion_parser.extract_chunk_text(chunk) - responses_text = responses_parser.extract_chunk_text(chunk) - if responses_text: - return responses_text - return completion_parser.extract_chunk_text(chunk) + effective_transport = ChatClient._resolve_transport(chunk, transport) + parser = ChatClient._parser_for_transport(effective_transport) + return parser.extract_chunk_text(chunk) def _build_event_stream( self, @@ -1759,7 +1771,7 @@ def _build_event_stream_from_response( provider_name: str, model_id: str, *, - transport: Literal["completion", "responses"] | None = None, + transport: TransportKind | None = None, ) -> StreamEvents: text = self._extract_text(response, transport=transport) tool_calls = self._extract_tool_calls(response, transport=transport) @@ -1817,7 +1829,7 @@ def _build_async_event_stream_from_response( provider_name: str, model_id: str, *, - transport: Literal["completion", "responses"] | None = None, + transport: TransportKind | None = None, ) -> AsyncStreamEvents: text = self._extract_text(response, transport=transport) tool_calls = self._extract_tool_calls(response, transport=transport) @@ -1916,67 +1928,37 @@ def _make_tool_context(prepared: PreparedChat, provider_name: str, model_id: str state={} if prepared.context is None else prepared.context.state, ) - @staticmethod - def _extract_text_from_responses_output(output: Any) -> str: - return responses_parser.extract_text_from_output(output) - @staticmethod def _extract_text( response: Any, *, - transport: Literal["completion", "responses"] | None = None, + transport: TransportKind | None = None, ) -> str: payload, detected_transport = 
ChatClient._unwrap_response(response) - effective_transport = transport or detected_transport - if effective_transport == "responses": - return responses_parser.extract_text(payload) - if effective_transport == "completion": - return completion_parser.extract_text(payload) - responses_text = responses_parser.extract_text(payload) - if responses_text: - return responses_text - return completion_parser.extract_text(payload) + effective_transport = ChatClient._resolve_transport(payload, transport or detected_transport) + parser = ChatClient._parser_for_transport(effective_transport) + return parser.extract_text(payload) @staticmethod def _extract_tool_calls( response: Any, *, - transport: Literal["completion", "responses"] | None = None, + transport: TransportKind | None = None, ) -> list[dict[str, Any]]: payload, detected_transport = ChatClient._unwrap_response(response) - effective_transport = transport or detected_transport + effective_transport = ChatClient._resolve_transport(payload, transport or detected_transport) if effective_transport == "responses": return responses_parser.extract_tool_calls(_field(payload, "output")) - if effective_transport == "completion": - return completion_parser.extract_tool_calls(payload) - output = _field(payload, "output") - if output is not None: - return ChatClient._extract_responses_tool_calls(output) - return ChatClient._extract_completion_tool_calls(payload) - - @staticmethod - def _extract_responses_tool_calls(output: Any) -> list[dict[str, Any]]: - return responses_parser.extract_tool_calls(output) - - @staticmethod - def _extract_completion_tool_calls(response: Any) -> list[dict[str, Any]]: - return completion_parser.extract_tool_calls(response) + parser = ChatClient._parser_for_transport(effective_transport) + return parser.extract_tool_calls(payload) @staticmethod def _extract_usage( response: Any, *, - transport: Literal["completion", "responses"] | None = None, + transport: TransportKind | None = None, ) -> dict[str, Any] | None: payload, detected_transport = ChatClient._unwrap_response(response) - effective_transport = transport or detected_transport - if effective_transport == "completion": - usage = _field(payload, "usage") - if usage is None: - return None - if isinstance(usage, dict): - return dict(usage) - if hasattr(usage, "model_dump"): - return usage.model_dump() - return None - return responses_parser.extract_usage(payload) + effective_transport = ChatClient._resolve_transport(payload, transport or detected_transport) + parser = ChatClient._parser_for_transport(effective_transport) + return parser.extract_usage(payload) diff --git a/src/republic/clients/parsing/completion.py b/src/republic/clients/parsing/completion.py index 5d4dfdf..b32a507 100644 --- a/src/republic/clients/parsing/completion.py +++ b/src/republic/clients/parsing/completion.py @@ -71,3 +71,14 @@ def extract_tool_calls(response: Any) -> list[dict[str, Any]]: entry["type"] = call_type calls.append(entry) return expand_tool_calls(calls) + + +def extract_usage(response: Any) -> dict[str, Any] | None: + usage = field(response, "usage") + if usage is None: + return None + if isinstance(usage, dict): + return dict(usage) + if hasattr(usage, "model_dump"): + return usage.model_dump() + return None diff --git a/src/republic/clients/parsing/messages.py b/src/republic/clients/parsing/messages.py new file mode 100644 index 0000000..d12dbb7 --- /dev/null +++ b/src/republic/clients/parsing/messages.py @@ -0,0 +1,84 @@ +"""Anthropic messages shape parsing.""" + +from __future__ 
import annotations + +from typing import Any + +from republic.clients.parsing.common import expand_tool_calls, field + + +def is_non_stream_response(response: Any) -> bool: + return isinstance(response, str) or field(response, "choices") is not None + + +def extract_chunk_tool_call_deltas(chunk: Any) -> list[Any]: + choices = field(chunk, "choices") + if not choices: + return [] + delta = field(choices[0], "delta") + if delta is None: + return [] + return field(delta, "tool_calls") or [] + + +def extract_chunk_text(chunk: Any) -> str: + choices = field(chunk, "choices") + if not choices: + return "" + delta = field(choices[0], "delta") + if delta is None: + return "" + return field(delta, "content", "") or "" + + +def extract_text(response: Any) -> str: + if isinstance(response, str): + return response + + choices = field(response, "choices") + if not choices: + return "" + message = field(choices[0], "message") + if message is None: + return "" + return field(message, "content", "") or "" + + +def extract_tool_calls(response: Any) -> list[dict[str, Any]]: + choices = field(response, "choices") + if not choices: + return [] + message = field(choices[0], "message") + if message is None: + return [] + tool_calls = field(message, "tool_calls") or [] + calls: list[dict[str, Any]] = [] + for tool_call in tool_calls: + function = field(tool_call, "function") + if function is None: + continue + entry: dict[str, Any] = { + "function": { + "name": field(function, "name"), + "arguments": field(function, "arguments"), + } + } + call_id = field(tool_call, "id") + if call_id: + entry["id"] = call_id + call_type = field(tool_call, "type") + if call_type: + entry["type"] = call_type + calls.append(entry) + return expand_tool_calls(calls) + + +def extract_usage(response: Any) -> dict[str, Any] | None: + usage = field(response, "usage") + if usage is None: + return None + if isinstance(usage, dict): + return dict(usage) + if hasattr(usage, "model_dump"): + return usage.model_dump() + return None diff --git a/src/republic/core/execution.py b/src/republic/core/execution.py index 0011e12..1269161 100644 --- a/src/republic/core/execution.py +++ b/src/republic/core/execution.py @@ -50,7 +50,7 @@ class AttemptOutcome: @dataclass(frozen=True) class TransportResponse: - transport: Literal["completion", "responses"] + transport: Literal["completion", "responses", "messages"] payload: Any @@ -69,7 +69,7 @@ def __init__( api_key: str | dict[str, str] | None, api_base: str | dict[str, str] | None, client_args: dict[str, Any], - api_format: Literal["completion", "responses", "anthropic_messages"], + api_format: Literal["completion", "responses", "messages"], verbose: int, error_classifier: Callable[[Exception], ErrorKind | None] | None = None, ) -> None: @@ -426,19 +426,19 @@ def _selected_transport( provider_name: str, model_id: str, tools_payload: list[dict[str, Any]] | None, - ) -> Literal["completion", "responses"]: + ) -> Literal["completion", "responses", "messages"]: if self._api_format == "completion": return "completion" - if self._api_format == "anthropic_messages": - if not provider_policies.supports_anthropic_messages_format( + if self._api_format == "messages": + if not provider_policies.supports_messages_format( provider_name=provider_name, model_id=model_id, ): raise RepublicError( ErrorKind.INVALID_INPUT, - f"{provider_name}:{model_id}: anthropic_messages format is only valid for Anthropic models", + f"{provider_name}:{model_id}: messages format is only valid for Anthropic models", ) - return 
"completion" + return "messages" reason = provider_policies.responses_rejection_reason( provider_name=provider_name, @@ -488,11 +488,64 @@ def _call_completion_sync( stream: bool, reasoning_effort: Any | None, kwargs: dict[str, Any], + ) -> Any: + return self._call_completion_like_sync( + transport="completion", + client=client, + provider_name=provider_name, + model_id=model_id, + messages_payload=messages_payload, + tools_payload=tools_payload, + max_tokens=max_tokens, + stream=stream, + reasoning_effort=reasoning_effort, + kwargs=kwargs, + ) + + def _call_messages_sync( + self, + *, + client: AnyLLM, + provider_name: str, + model_id: str, + messages_payload: list[dict[str, Any]], + tools_payload: list[dict[str, Any]] | None, + max_tokens: int | None, + stream: bool, + reasoning_effort: Any | None, + kwargs: dict[str, Any], + ) -> Any: + return self._call_completion_like_sync( + transport="messages", + client=client, + provider_name=provider_name, + model_id=model_id, + messages_payload=messages_payload, + tools_payload=tools_payload, + max_tokens=max_tokens, + stream=stream, + reasoning_effort=reasoning_effort, + kwargs=kwargs, + ) + + def _call_completion_like_sync( + self, + *, + transport: Literal["completion", "messages"], + client: AnyLLM, + provider_name: str, + model_id: str, + messages_payload: list[dict[str, Any]], + tools_payload: list[dict[str, Any]] | None, + max_tokens: int | None, + stream: bool, + reasoning_effort: Any | None, + kwargs: dict[str, Any], ) -> Any: completion_kwargs = self._decide_kwargs_for_provider(provider_name, max_tokens, kwargs) completion_kwargs = self._with_default_completion_stream_options(provider_name, stream, completion_kwargs) return TransportResponse( - transport="completion", + transport=transport, payload=client.completion( model=model_id, messages=messages_payload, @@ -541,11 +594,64 @@ async def _call_completion_async( stream: bool, reasoning_effort: Any | None, kwargs: dict[str, Any], + ) -> Any: + return await self._call_completion_like_async( + transport="completion", + client=client, + provider_name=provider_name, + model_id=model_id, + messages_payload=messages_payload, + tools_payload=tools_payload, + max_tokens=max_tokens, + stream=stream, + reasoning_effort=reasoning_effort, + kwargs=kwargs, + ) + + async def _call_messages_async( + self, + *, + client: AnyLLM, + provider_name: str, + model_id: str, + messages_payload: list[dict[str, Any]], + tools_payload: list[dict[str, Any]] | None, + max_tokens: int | None, + stream: bool, + reasoning_effort: Any | None, + kwargs: dict[str, Any], + ) -> Any: + return await self._call_completion_like_async( + transport="messages", + client=client, + provider_name=provider_name, + model_id=model_id, + messages_payload=messages_payload, + tools_payload=tools_payload, + max_tokens=max_tokens, + stream=stream, + reasoning_effort=reasoning_effort, + kwargs=kwargs, + ) + + async def _call_completion_like_async( + self, + *, + transport: Literal["completion", "messages"], + client: AnyLLM, + provider_name: str, + model_id: str, + messages_payload: list[dict[str, Any]], + tools_payload: list[dict[str, Any]] | None, + max_tokens: int | None, + stream: bool, + reasoning_effort: Any | None, + kwargs: dict[str, Any], ) -> Any: completion_kwargs = self._decide_kwargs_for_provider(provider_name, max_tokens, kwargs) completion_kwargs = self._with_default_completion_stream_options(provider_name, stream, completion_kwargs) return TransportResponse( - transport="completion", + transport=transport, payload=await 
client.acompletion( model=model_id, messages=messages_payload, @@ -586,6 +692,18 @@ def _call_client_sync( reasoning_effort=reasoning_effort, kwargs=kwargs, ) + if transport == "messages": + return self._call_messages_sync( + client=client, + provider_name=provider_name, + model_id=model_id, + messages_payload=messages_payload, + tools_payload=tools_payload, + max_tokens=max_tokens, + stream=stream, + reasoning_effort=reasoning_effort, + kwargs=kwargs, + ) return self._call_completion_sync( client=client, provider_name=provider_name, @@ -628,6 +746,18 @@ async def _call_client_async( reasoning_effort=reasoning_effort, kwargs=kwargs, ) + if transport == "messages": + return await self._call_messages_async( + client=client, + provider_name=provider_name, + model_id=model_id, + messages_payload=messages_payload, + tools_payload=tools_payload, + max_tokens=max_tokens, + stream=stream, + reasoning_effort=reasoning_effort, + kwargs=kwargs, + ) return await self._call_completion_async( client=client, provider_name=provider_name, diff --git a/src/republic/core/provider_policies.py b/src/republic/core/provider_policies.py index 132183d..6c38280 100644 --- a/src/republic/core/provider_policies.py +++ b/src/republic/core/provider_policies.py @@ -72,7 +72,7 @@ def responses_rejection_reason( return "responses format is not supported by this provider" -def supports_anthropic_messages_format(*, provider_name: str, model_id: str) -> bool: +def supports_messages_format(*, provider_name: str, model_id: str) -> bool: normalized_provider = _normalize_provider_name(provider_name) normalized_model = model_id.strip().lower() return normalized_provider == "anthropic" or normalized_model.startswith("anthropic/") diff --git a/src/republic/llm.py b/src/republic/llm.py index 54278de..cae1851 100644 --- a/src/republic/llm.py +++ b/src/republic/llm.py @@ -48,7 +48,7 @@ def __init__( api_key: str | dict[str, str] | None = None, api_base: str | dict[str, str] | None = None, client_args: dict[str, Any] | None = None, - api_format: Literal["completion", "responses", "anthropic_messages"] = "completion", + api_format: Literal["completion", "responses", "messages"] = "completion", verbose: int = 0, tape_store: TapeStore | AsyncTapeStore | None = None, context: TapeContext | None = None, @@ -58,10 +58,10 @@ def __init__( raise RepublicError(ErrorKind.INVALID_INPUT, "verbose must be 0, 1, or 2") if max_retries < 0: raise RepublicError(ErrorKind.INVALID_INPUT, "max_retries must be >= 0") - if api_format not in {"completion", "responses", "anthropic_messages"}: + if api_format not in {"completion", "responses", "messages"}: raise RepublicError( ErrorKind.INVALID_INPUT, - "api_format must be 'completion', 'responses', or 'anthropic_messages'", + "api_format must be 'completion', 'responses', or 'messages'", ) if not model: diff --git a/tests/test_provider_policies.py b/tests/test_provider_policies.py index 7a1412b..b220397 100644 --- a/tests/test_provider_policies.py +++ b/tests/test_provider_policies.py @@ -95,16 +95,16 @@ def test_responses_rejection_reason_for_openrouter_anthropic_tools() -> None: assert "tools" in reason -def test_supports_anthropic_messages_format() -> None: - assert provider_policies.supports_anthropic_messages_format( +def test_supports_messages_format() -> None: + assert provider_policies.supports_messages_format( provider_name="anthropic", model_id="claude-3-5-haiku-latest", ) - assert provider_policies.supports_anthropic_messages_format( + assert provider_policies.supports_messages_format( 
provider_name="openrouter", model_id="anthropic/claude-3.5-haiku", ) - assert not provider_policies.supports_anthropic_messages_format( + assert not provider_policies.supports_messages_format( provider_name="openai", model_id="gpt-4o-mini", ) diff --git a/tests/test_responses_handling.py b/tests/test_responses_handling.py index 4d0cdd5..ff94d15 100644 --- a/tests/test_responses_handling.py +++ b/tests/test_responses_handling.py @@ -189,11 +189,11 @@ def test_openrouter_anthropic_tools_work_with_completion_format(fake_anyllm) -> assert client.calls[-1].get("responses") is None -def test_anthropic_messages_format_maps_to_completion(fake_anyllm) -> None: +def test_messages_format_maps_to_completion(fake_anyllm) -> None: client = fake_anyllm.ensure("openrouter") client.queue_completion(make_response(tool_calls=[make_tool_call("echo", '{"text":"tokyo"}')])) - llm = LLM(model="openrouter:anthropic/claude-3.5-haiku", api_key="dummy", api_format="anthropic_messages") + llm = LLM(model="openrouter:anthropic/claude-3.5-haiku", api_key="dummy", api_format="messages") calls = llm.tool_calls("Call echo for tokyo", tools=[echo]) assert len(calls) == 1 @@ -201,8 +201,52 @@ def test_anthropic_messages_format_maps_to_completion(fake_anyllm) -> None: assert client.calls[-1].get("responses") is None -def test_anthropic_messages_format_rejects_non_anthropic_model(fake_anyllm) -> None: - llm = LLM(model="openai:gpt-4o-mini", api_key="dummy", api_format="anthropic_messages") +def test_messages_chat_uses_completion(fake_anyllm) -> None: + client = fake_anyllm.ensure("openrouter") + client.queue_completion(make_response(text="ready")) + + llm = LLM(model="openrouter:anthropic/claude-3.5-haiku", api_key="dummy", api_format="messages") + result = llm.chat("Reply with ready") + + assert result == "ready" + assert client.calls[-1].get("responses") is None + + +def test_messages_stream_uses_completion_and_collects_usage(fake_anyllm) -> None: + client = fake_anyllm.ensure("openrouter") + client.queue_completion(iter(_completion_stream_text_items())) + + llm = LLM(model="openrouter:anthropic/claude-3.5-haiku", api_key="dummy", api_format="messages") + stream = llm.stream("Say hello") + text = "".join(list(stream)) + + assert text == "hello world" + assert stream.error is None + assert stream.usage == {"total_tokens": 7} + assert client.calls[-1].get("responses") is None + assert client.calls[-1]["stream"] is True + + +def test_stream_events_parity_between_completion_and_messages(fake_anyllm) -> None: + completion_client = fake_anyllm.ensure("openai") + completion_client.queue_completion(iter(_completion_stream_event_items())) + + messages_client = fake_anyllm.ensure("openrouter") + messages_client.queue_completion(iter(_completion_stream_event_items())) + + completion_llm = LLM(model="openai:gpt-4o-mini", api_key="dummy") + messages_llm = LLM(model="openrouter:anthropic/claude-3.5-haiku", api_key="dummy", api_format="messages") + + completion_events = list(completion_llm.stream_events("Call echo for tokyo", tools=[echo])) + messages_events = list(messages_llm.stream_events("Call echo for tokyo", tools=[echo])) + + assert completion_client.calls[-1].get("responses") is None + assert messages_client.calls[-1].get("responses") is None + assert _compact_stream_events(completion_events) == _compact_stream_events(messages_events) + + +def test_messages_format_rejects_non_anthropic_model(fake_anyllm) -> None: + llm = LLM(model="openai:gpt-4o-mini", api_key="dummy", api_format="messages") with pytest.raises(ErrorPayload) as 
exc_info: llm.chat("hi") assert exc_info.value.kind == "invalid_input" From 96be279f953bee838401699cd262783da7693569 Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Mon, 2 Mar 2026 16:19:05 +0000 Subject: [PATCH 09/14] refactor: centralize transport parser/caller dispatch --- src/republic/clients/chat.py | 29 +++------- src/republic/clients/parsing/__init__.py | 39 ++++++++++++++ src/republic/clients/parsing/common.py | 6 +-- src/republic/clients/parsing/responses.py | 3 +- src/republic/core/execution.py | 66 +++++++---------------- tests/test_parsing_registry.py | 18 +++++++ 6 files changed, 87 insertions(+), 74 deletions(-) create mode 100644 tests/test_parsing_registry.py diff --git a/src/republic/clients/chat.py b/src/republic/clients/chat.py index dd60237..7e2c322 100644 --- a/src/republic/clients/chat.py +++ b/src/republic/clients/chat.py @@ -6,11 +6,9 @@ from collections.abc import AsyncIterator, Callable, Iterator from dataclasses import dataclass from functools import partial -from typing import Any, Literal +from typing import Any -from republic.clients.parsing import completion as completion_parser -from republic.clients.parsing import messages as messages_parser -from republic.clients.parsing import responses as responses_parser +from republic.clients.parsing import TransportKind, parser_for_transport from republic.clients.parsing.common import expand_tool_calls from republic.clients.parsing.common import field as _field from republic.core.errors import ErrorKind, RepublicError @@ -32,7 +30,6 @@ from republic.tools.schema import ToolInput, ToolSet, normalize_tools MessageInput = dict[str, Any] -TransportKind = Literal["completion", "responses", "messages"] @dataclass(frozen=True) @@ -234,14 +231,6 @@ def _resolve_transport( return "responses" return "completion" - @staticmethod - def _parser_for_transport(transport: TransportKind): - if transport == "responses": - return responses_parser - if transport == "messages": - return messages_parser - return completion_parser - @staticmethod def _is_non_stream_response( response: Any, @@ -249,7 +238,7 @@ def _is_non_stream_response( transport: TransportKind | None = None, ) -> bool: effective_transport = ChatClient._resolve_transport(response, transport) - parser = ChatClient._parser_for_transport(effective_transport) + parser = parser_for_transport(effective_transport) return parser.is_non_stream_response(response) def _validate_chat_input( @@ -1599,7 +1588,7 @@ def _extract_chunk_tool_call_deltas( transport: TransportKind | None = None, ) -> list[Any]: effective_transport = ChatClient._resolve_transport(chunk, transport) - parser = ChatClient._parser_for_transport(effective_transport) + parser = parser_for_transport(effective_transport) return parser.extract_chunk_tool_call_deltas(chunk) @staticmethod @@ -1609,7 +1598,7 @@ def _extract_chunk_text( transport: TransportKind | None = None, ) -> str: effective_transport = ChatClient._resolve_transport(chunk, transport) - parser = ChatClient._parser_for_transport(effective_transport) + parser = parser_for_transport(effective_transport) return parser.extract_chunk_text(chunk) def _build_event_stream( @@ -1936,7 +1925,7 @@ def _extract_text( ) -> str: payload, detected_transport = ChatClient._unwrap_response(response) effective_transport = ChatClient._resolve_transport(payload, transport or detected_transport) - parser = ChatClient._parser_for_transport(effective_transport) + parser = parser_for_transport(effective_transport) return parser.extract_text(payload) @staticmethod @@ -1947,9 
+1936,7 @@ def _extract_tool_calls( ) -> list[dict[str, Any]]: payload, detected_transport = ChatClient._unwrap_response(response) effective_transport = ChatClient._resolve_transport(payload, transport or detected_transport) - if effective_transport == "responses": - return responses_parser.extract_tool_calls(_field(payload, "output")) - parser = ChatClient._parser_for_transport(effective_transport) + parser = parser_for_transport(effective_transport) return parser.extract_tool_calls(payload) @staticmethod @@ -1960,5 +1947,5 @@ def _extract_usage( ) -> dict[str, Any] | None: payload, detected_transport = ChatClient._unwrap_response(response) effective_transport = ChatClient._resolve_transport(payload, transport or detected_transport) - parser = ChatClient._parser_for_transport(effective_transport) + parser = parser_for_transport(effective_transport) return parser.extract_usage(payload) diff --git a/src/republic/clients/parsing/__init__.py b/src/republic/clients/parsing/__init__.py index 9b7915b..b2449e1 100644 --- a/src/republic/clients/parsing/__init__.py +++ b/src/republic/clients/parsing/__init__.py @@ -1 +1,40 @@ """Parsing helpers for provider response payloads.""" + +from __future__ import annotations + +from typing import Any, Literal, Protocol + +from republic.clients.parsing import completion, messages, responses + +TransportKind = Literal["completion", "responses", "messages"] + + +class TransportParser(Protocol): + @staticmethod + def is_non_stream_response(response: Any) -> bool: ... + + @staticmethod + def extract_chunk_tool_call_deltas(chunk: Any) -> list[Any]: ... + + @staticmethod + def extract_chunk_text(chunk: Any) -> str: ... + + @staticmethod + def extract_text(response: Any) -> str: ... + + @staticmethod + def extract_tool_calls(response: Any) -> list[dict[str, Any]]: ... + + @staticmethod + def extract_usage(response: Any) -> dict[str, Any] | None: ... 
+ + +def parser_for_transport(transport: TransportKind) -> TransportParser: + if transport == "responses": + return responses + if transport == "messages": + return messages + return completion + + +__all__ = ["TransportKind", "TransportParser", "parser_for_transport"] diff --git a/src/republic/clients/parsing/common.py b/src/republic/clients/parsing/common.py index d46f51f..0a47ecd 100644 --- a/src/republic/clients/parsing/common.py +++ b/src/republic/clients/parsing/common.py @@ -3,6 +3,7 @@ from __future__ import annotations import json +from itertools import chain from typing import Any @@ -13,10 +14,7 @@ def field(data: Any, key: str, default: Any = None) -> Any: def expand_tool_calls(calls: list[dict[str, Any]]) -> list[dict[str, Any]]: - expanded: list[dict[str, Any]] = [] - for call in calls: - expanded.extend(_expand_tool_call(call)) - return expanded + return list(chain.from_iterable(_expand_tool_call(call) for call in calls)) def _expand_tool_call(call: dict[str, Any]) -> list[dict[str, Any]]: diff --git a/src/republic/clients/parsing/responses.py b/src/republic/clients/parsing/responses.py index 1da7ed2..b1c405f 100644 --- a/src/republic/clients/parsing/responses.py +++ b/src/republic/clients/parsing/responses.py @@ -103,7 +103,8 @@ def extract_text(response: Any) -> str: return extract_text_from_output(field(response, "output")) -def extract_tool_calls(output: Any) -> list[dict[str, Any]]: +def extract_tool_calls(response: Any) -> list[dict[str, Any]]: + output = response if isinstance(response, list) else field(response, "output") if not isinstance(output, list): return [] calls: list[dict[str, Any]] = [] diff --git a/src/republic/core/execution.py b/src/republic/core/execution.py index 1269161..6fb0998 100644 --- a/src/republic/core/execution.py +++ b/src/republic/core/execution.py @@ -454,6 +454,7 @@ def _call_responses_sync( self, *, client: AnyLLM, + provider_name: str, model_id: str, messages_payload: list[dict[str, Any]], tools_payload: list[dict[str, Any]] | None, @@ -462,6 +463,7 @@ def _call_responses_sync( reasoning_effort: Any | None, kwargs: dict[str, Any], ) -> Any: + _ = provider_name instructions, input_items = self._split_messages_for_responses(messages_payload) responses_kwargs = self._with_responses_reasoning(kwargs, reasoning_effort) return TransportResponse( @@ -560,6 +562,7 @@ async def _call_responses_async( self, *, client: AnyLLM, + provider_name: str, model_id: str, messages_payload: list[dict[str, Any]], tools_payload: list[dict[str, Any]] | None, @@ -568,6 +571,7 @@ async def _call_responses_async( reasoning_effort: Any | None, kwargs: dict[str, Any], ) -> Any: + _ = provider_name instructions, input_items = self._split_messages_for_responses(messages_payload) responses_kwargs = self._with_responses_reasoning(kwargs, reasoning_effort) return TransportResponse( @@ -681,30 +685,13 @@ def _call_client_sync( model_id=model_id, tools_payload=tools_payload, ) - if transport == "responses": - return self._call_responses_sync( - client=client, - model_id=model_id, - messages_payload=messages_payload, - tools_payload=tools_payload, - max_tokens=max_tokens, - stream=stream, - reasoning_effort=reasoning_effort, - kwargs=kwargs, - ) - if transport == "messages": - return self._call_messages_sync( - client=client, - provider_name=provider_name, - model_id=model_id, - messages_payload=messages_payload, - tools_payload=tools_payload, - max_tokens=max_tokens, - stream=stream, - reasoning_effort=reasoning_effort, - kwargs=kwargs, - ) - return 
self._call_completion_sync( + callers: dict[str, Callable[..., Any]] = { + "completion": self._call_completion_sync, + "responses": self._call_responses_sync, + "messages": self._call_messages_sync, + } + call_transport = callers[transport] + return call_transport( client=client, provider_name=provider_name, model_id=model_id, @@ -735,30 +722,13 @@ async def _call_client_async( model_id=model_id, tools_payload=tools_payload, ) - if transport == "responses": - return await self._call_responses_async( - client=client, - model_id=model_id, - messages_payload=messages_payload, - tools_payload=tools_payload, - max_tokens=max_tokens, - stream=stream, - reasoning_effort=reasoning_effort, - kwargs=kwargs, - ) - if transport == "messages": - return await self._call_messages_async( - client=client, - provider_name=provider_name, - model_id=model_id, - messages_payload=messages_payload, - tools_payload=tools_payload, - max_tokens=max_tokens, - stream=stream, - reasoning_effort=reasoning_effort, - kwargs=kwargs, - ) - return await self._call_completion_async( + callers: dict[str, Callable[..., Any]] = { + "completion": self._call_completion_async, + "responses": self._call_responses_async, + "messages": self._call_messages_async, + } + call_transport = callers[transport] + return await call_transport( client=client, provider_name=provider_name, model_id=model_id, diff --git a/tests/test_parsing_registry.py b/tests/test_parsing_registry.py new file mode 100644 index 0000000..3a0ced7 --- /dev/null +++ b/tests/test_parsing_registry.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from republic.clients.parsing import parser_for_transport +from republic.clients.parsing import responses as responses_parser + +from .fakes import make_responses_function_call, make_responses_response + + +def test_parser_for_transport_returns_expected_modules() -> None: + assert parser_for_transport("completion").__name__.endswith(".completion") + assert parser_for_transport("responses").__name__.endswith(".responses") + assert parser_for_transport("messages").__name__.endswith(".messages") + + +def test_responses_extract_tool_calls_accepts_full_response() -> None: + response = make_responses_response(tool_calls=[make_responses_function_call("echo", '{"text":"tokyo"}')]) + calls = responses_parser.extract_tool_calls(response) + assert calls[0]["function"]["name"] == "echo" From da811b4881fd3f57e7832ad6a862d1806debf7e3 Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Mon, 2 Mar 2026 16:33:17 +0000 Subject: [PATCH 10/14] refactor: clean parser abstraction and transport call flow --- src/republic/clients/chat.py | 4 + src/republic/clients/parsing/__init__.py | 44 ++-- src/republic/clients/parsing/completion.py | 11 + src/republic/clients/parsing/messages.py | 11 + src/republic/clients/parsing/responses.py | 11 + src/republic/clients/parsing/types.py | 74 ++++++ src/republic/core/execution.py | 275 ++++++--------------- tests/test_parsing_registry.py | 32 ++- 8 files changed, 228 insertions(+), 234 deletions(-) create mode 100644 src/republic/clients/parsing/types.py diff --git a/src/republic/clients/chat.py b/src/republic/clients/chat.py index 7e2c322..0c49edc 100644 --- a/src/republic/clients/chat.py +++ b/src/republic/clients/chat.py @@ -224,8 +224,12 @@ def _resolve_transport( ) -> TransportKind: if transport is not None: return transport + if isinstance(payload, list): + return "responses" if _field(payload, "output") is not None: return "responses" + if _field(payload, "output_text") is not None: + return 
"responses" event_type = _field(payload, "type") if isinstance(event_type, str) and event_type.startswith("response."): return "responses" diff --git a/src/republic/clients/parsing/__init__.py b/src/republic/clients/parsing/__init__.py index b2449e1..4a2b3e5 100644 --- a/src/republic/clients/parsing/__init__.py +++ b/src/republic/clients/parsing/__init__.py @@ -2,39 +2,23 @@ from __future__ import annotations -from typing import Any, Literal, Protocol +from republic.clients.parsing.completion import PARSER as completion_parser +from republic.clients.parsing.messages import PARSER as messages_parser +from republic.clients.parsing.responses import PARSER as responses_parser +from republic.clients.parsing.types import BaseTransportParser, TransportKind, validate_transport_parser -from republic.clients.parsing import completion, messages, responses +_PARSERS: dict[TransportKind, BaseTransportParser] = { + "completion": completion_parser, + "responses": responses_parser, + "messages": messages_parser, +} -TransportKind = Literal["completion", "responses", "messages"] +for _name, _parser in _PARSERS.items(): + validate_transport_parser(_parser, name=_name) -class TransportParser(Protocol): - @staticmethod - def is_non_stream_response(response: Any) -> bool: ... +def parser_for_transport(transport: TransportKind) -> BaseTransportParser: + return _PARSERS[transport] - @staticmethod - def extract_chunk_tool_call_deltas(chunk: Any) -> list[Any]: ... - @staticmethod - def extract_chunk_text(chunk: Any) -> str: ... - - @staticmethod - def extract_text(response: Any) -> str: ... - - @staticmethod - def extract_tool_calls(response: Any) -> list[dict[str, Any]]: ... - - @staticmethod - def extract_usage(response: Any) -> dict[str, Any] | None: ... - - -def parser_for_transport(transport: TransportKind) -> TransportParser: - if transport == "responses": - return responses - if transport == "messages": - return messages - return completion - - -__all__ = ["TransportKind", "TransportParser", "parser_for_transport"] +__all__ = ["BaseTransportParser", "TransportKind", "parser_for_transport"] diff --git a/src/republic/clients/parsing/completion.py b/src/republic/clients/parsing/completion.py index b32a507..aa77313 100644 --- a/src/republic/clients/parsing/completion.py +++ b/src/republic/clients/parsing/completion.py @@ -5,6 +5,7 @@ from typing import Any from republic.clients.parsing.common import expand_tool_calls, field +from republic.clients.parsing.types import FunctionTransportParser def is_non_stream_response(response: Any) -> bool: @@ -82,3 +83,13 @@ def extract_usage(response: Any) -> dict[str, Any] | None: if hasattr(usage, "model_dump"): return usage.model_dump() return None + + +PARSER = FunctionTransportParser( + is_non_stream_response_fn=is_non_stream_response, + extract_chunk_tool_call_deltas_fn=extract_chunk_tool_call_deltas, + extract_chunk_text_fn=extract_chunk_text, + extract_text_fn=extract_text, + extract_tool_calls_fn=extract_tool_calls, + extract_usage_fn=extract_usage, +) diff --git a/src/republic/clients/parsing/messages.py b/src/republic/clients/parsing/messages.py index d12dbb7..739d1e2 100644 --- a/src/republic/clients/parsing/messages.py +++ b/src/republic/clients/parsing/messages.py @@ -5,6 +5,7 @@ from typing import Any from republic.clients.parsing.common import expand_tool_calls, field +from republic.clients.parsing.types import FunctionTransportParser def is_non_stream_response(response: Any) -> bool: @@ -82,3 +83,13 @@ def extract_usage(response: Any) -> dict[str, Any] | 
None: if hasattr(usage, "model_dump"): return usage.model_dump() return None + + +PARSER = FunctionTransportParser( + is_non_stream_response_fn=is_non_stream_response, + extract_chunk_tool_call_deltas_fn=extract_chunk_tool_call_deltas, + extract_chunk_text_fn=extract_chunk_text, + extract_text_fn=extract_text, + extract_tool_calls_fn=extract_tool_calls, + extract_usage_fn=extract_usage, +) diff --git a/src/republic/clients/parsing/responses.py b/src/republic/clients/parsing/responses.py index b1c405f..f1bec0b 100644 --- a/src/republic/clients/parsing/responses.py +++ b/src/republic/clients/parsing/responses.py @@ -6,6 +6,7 @@ from typing import Any from republic.clients.parsing.common import expand_tool_calls, field +from republic.clients.parsing.types import FunctionTransportParser def is_non_stream_response(response: Any) -> bool: @@ -144,3 +145,13 @@ def extract_usage(response: Any) -> dict[str, Any] | None: if value is not None: data[usage_field] = value return data or None + + +PARSER = FunctionTransportParser( + is_non_stream_response_fn=is_non_stream_response, + extract_chunk_tool_call_deltas_fn=extract_chunk_tool_call_deltas, + extract_chunk_text_fn=extract_chunk_text, + extract_text_fn=extract_text, + extract_tool_calls_fn=extract_tool_calls, + extract_usage_fn=extract_usage, +) diff --git a/src/republic/clients/parsing/types.py b/src/republic/clients/parsing/types.py new file mode 100644 index 0000000..080c4e7 --- /dev/null +++ b/src/republic/clients/parsing/types.py @@ -0,0 +1,74 @@ +"""Shared parser typing and validation primitives.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any, Literal + +TransportKind = Literal["completion", "responses", "messages"] + + +class BaseTransportParser(ABC): + @abstractmethod + def is_non_stream_response(self, response: Any) -> bool: + raise NotImplementedError + + @abstractmethod + def extract_chunk_tool_call_deltas(self, chunk: Any) -> list[Any]: + raise NotImplementedError + + @abstractmethod + def extract_chunk_text(self, chunk: Any) -> str: + raise NotImplementedError + + @abstractmethod + def extract_text(self, response: Any) -> str: + raise NotImplementedError + + @abstractmethod + def extract_tool_calls(self, response: Any) -> list[dict[str, Any]]: + raise NotImplementedError + + @abstractmethod + def extract_usage(self, response: Any) -> dict[str, Any] | None: + raise NotImplementedError + + +@dataclass(frozen=True) +class FunctionTransportParser(BaseTransportParser): + is_non_stream_response_fn: Callable[[Any], bool] + extract_chunk_tool_call_deltas_fn: Callable[[Any], list[Any]] + extract_chunk_text_fn: Callable[[Any], str] + extract_text_fn: Callable[[Any], str] + extract_tool_calls_fn: Callable[[Any], list[dict[str, Any]]] + extract_usage_fn: Callable[[Any], dict[str, Any] | None] + + def is_non_stream_response(self, response: Any) -> bool: + return self.is_non_stream_response_fn(response) + + def extract_chunk_tool_call_deltas(self, chunk: Any) -> list[Any]: + return self.extract_chunk_tool_call_deltas_fn(chunk) + + def extract_chunk_text(self, chunk: Any) -> str: + return self.extract_chunk_text_fn(chunk) + + def extract_text(self, response: Any) -> str: + return self.extract_text_fn(response) + + def extract_tool_calls(self, response: Any) -> list[dict[str, Any]]: + return self.extract_tool_calls_fn(response) + + def extract_usage(self, response: Any) -> dict[str, Any] | None: + return 
self.extract_usage_fn(response) + + +class InvalidTransportParserError(TypeError): + def __init__(self, parser_name: str) -> None: + super().__init__(f"{parser_name} parser must inherit BaseTransportParser") + + +def validate_transport_parser(parser: object, *, name: str) -> None: + if not isinstance(parser, BaseTransportParser): + raise InvalidTransportParserError(name) diff --git a/src/republic/core/execution.py b/src/republic/core/execution.py index 6fb0998..409df0c 100644 --- a/src/republic/core/execution.py +++ b/src/republic/core/execution.py @@ -54,6 +54,19 @@ class TransportResponse: payload: Any +@dataclass(frozen=True) +class TransportCallRequest: + client: AnyLLM + provider_name: str + model_id: str + messages_payload: list[dict[str, Any]] + tools_payload: list[dict[str, Any]] | None + max_tokens: int | None + stream: bool + reasoning_effort: Any | None + kwargs: dict[str, Any] + + class LLMCore: """Shared LLM execution utilities (provider resolution, retries, client cache).""" @@ -452,216 +465,84 @@ def _selected_transport( def _call_responses_sync( self, - *, - client: AnyLLM, - provider_name: str, - model_id: str, - messages_payload: list[dict[str, Any]], - tools_payload: list[dict[str, Any]] | None, - max_tokens: int | None, - stream: bool, - reasoning_effort: Any | None, - kwargs: dict[str, Any], + request: TransportCallRequest, ) -> Any: - _ = provider_name - instructions, input_items = self._split_messages_for_responses(messages_payload) - responses_kwargs = self._with_responses_reasoning(kwargs, reasoning_effort) + instructions, input_items = self._split_messages_for_responses(request.messages_payload) + responses_kwargs = self._with_responses_reasoning(request.kwargs, request.reasoning_effort) return TransportResponse( transport="responses", - payload=client.responses( - model=model_id, + payload=request.client.responses( + model=request.model_id, input_data=input_items, - tools=self._convert_tools_for_responses(tools_payload), - stream=stream, + tools=self._convert_tools_for_responses(request.tools_payload), + stream=request.stream, instructions=instructions, - **self._decide_responses_kwargs(max_tokens, responses_kwargs), + **self._decide_responses_kwargs(request.max_tokens, responses_kwargs), ), ) - def _call_completion_sync( - self, - *, - client: AnyLLM, - provider_name: str, - model_id: str, - messages_payload: list[dict[str, Any]], - tools_payload: list[dict[str, Any]] | None, - max_tokens: int | None, - stream: bool, - reasoning_effort: Any | None, - kwargs: dict[str, Any], - ) -> Any: - return self._call_completion_like_sync( - transport="completion", - client=client, - provider_name=provider_name, - model_id=model_id, - messages_payload=messages_payload, - tools_payload=tools_payload, - max_tokens=max_tokens, - stream=stream, - reasoning_effort=reasoning_effort, - kwargs=kwargs, - ) - - def _call_messages_sync( - self, - *, - client: AnyLLM, - provider_name: str, - model_id: str, - messages_payload: list[dict[str, Any]], - tools_payload: list[dict[str, Any]] | None, - max_tokens: int | None, - stream: bool, - reasoning_effort: Any | None, - kwargs: dict[str, Any], - ) -> Any: - return self._call_completion_like_sync( - transport="messages", - client=client, - provider_name=provider_name, - model_id=model_id, - messages_payload=messages_payload, - tools_payload=tools_payload, - max_tokens=max_tokens, - stream=stream, - reasoning_effort=reasoning_effort, - kwargs=kwargs, - ) - def _call_completion_like_sync( self, *, transport: Literal["completion", "messages"], 
- client: AnyLLM, - provider_name: str, - model_id: str, - messages_payload: list[dict[str, Any]], - tools_payload: list[dict[str, Any]] | None, - max_tokens: int | None, - stream: bool, - reasoning_effort: Any | None, - kwargs: dict[str, Any], + request: TransportCallRequest, ) -> Any: - completion_kwargs = self._decide_kwargs_for_provider(provider_name, max_tokens, kwargs) - completion_kwargs = self._with_default_completion_stream_options(provider_name, stream, completion_kwargs) + completion_kwargs = self._decide_kwargs_for_provider(request.provider_name, request.max_tokens, request.kwargs) + completion_kwargs = self._with_default_completion_stream_options( + request.provider_name, + request.stream, + completion_kwargs, + ) return TransportResponse( transport=transport, - payload=client.completion( - model=model_id, - messages=messages_payload, - tools=tools_payload, - stream=stream, - reasoning_effort=reasoning_effort, + payload=request.client.completion( + model=request.model_id, + messages=request.messages_payload, + tools=request.tools_payload, + stream=request.stream, + reasoning_effort=request.reasoning_effort, **completion_kwargs, ), ) async def _call_responses_async( self, - *, - client: AnyLLM, - provider_name: str, - model_id: str, - messages_payload: list[dict[str, Any]], - tools_payload: list[dict[str, Any]] | None, - max_tokens: int | None, - stream: bool, - reasoning_effort: Any | None, - kwargs: dict[str, Any], + request: TransportCallRequest, ) -> Any: - _ = provider_name - instructions, input_items = self._split_messages_for_responses(messages_payload) - responses_kwargs = self._with_responses_reasoning(kwargs, reasoning_effort) + instructions, input_items = self._split_messages_for_responses(request.messages_payload) + responses_kwargs = self._with_responses_reasoning(request.kwargs, request.reasoning_effort) return TransportResponse( transport="responses", - payload=await client.aresponses( - model=model_id, + payload=await request.client.aresponses( + model=request.model_id, input_data=input_items, - tools=self._convert_tools_for_responses(tools_payload), - stream=stream, + tools=self._convert_tools_for_responses(request.tools_payload), + stream=request.stream, instructions=instructions, - **self._decide_responses_kwargs(max_tokens, responses_kwargs), + **self._decide_responses_kwargs(request.max_tokens, responses_kwargs), ), ) - async def _call_completion_async( - self, - *, - client: AnyLLM, - provider_name: str, - model_id: str, - messages_payload: list[dict[str, Any]], - tools_payload: list[dict[str, Any]] | None, - max_tokens: int | None, - stream: bool, - reasoning_effort: Any | None, - kwargs: dict[str, Any], - ) -> Any: - return await self._call_completion_like_async( - transport="completion", - client=client, - provider_name=provider_name, - model_id=model_id, - messages_payload=messages_payload, - tools_payload=tools_payload, - max_tokens=max_tokens, - stream=stream, - reasoning_effort=reasoning_effort, - kwargs=kwargs, - ) - - async def _call_messages_async( - self, - *, - client: AnyLLM, - provider_name: str, - model_id: str, - messages_payload: list[dict[str, Any]], - tools_payload: list[dict[str, Any]] | None, - max_tokens: int | None, - stream: bool, - reasoning_effort: Any | None, - kwargs: dict[str, Any], - ) -> Any: - return await self._call_completion_like_async( - transport="messages", - client=client, - provider_name=provider_name, - model_id=model_id, - messages_payload=messages_payload, - tools_payload=tools_payload, - max_tokens=max_tokens, - 
stream=stream, - reasoning_effort=reasoning_effort, - kwargs=kwargs, - ) - async def _call_completion_like_async( self, *, transport: Literal["completion", "messages"], - client: AnyLLM, - provider_name: str, - model_id: str, - messages_payload: list[dict[str, Any]], - tools_payload: list[dict[str, Any]] | None, - max_tokens: int | None, - stream: bool, - reasoning_effort: Any | None, - kwargs: dict[str, Any], + request: TransportCallRequest, ) -> Any: - completion_kwargs = self._decide_kwargs_for_provider(provider_name, max_tokens, kwargs) - completion_kwargs = self._with_default_completion_stream_options(provider_name, stream, completion_kwargs) + completion_kwargs = self._decide_kwargs_for_provider(request.provider_name, request.max_tokens, request.kwargs) + completion_kwargs = self._with_default_completion_stream_options( + request.provider_name, + request.stream, + completion_kwargs, + ) return TransportResponse( transport=transport, - payload=await client.acompletion( - model=model_id, - messages=messages_payload, - tools=tools_payload, - stream=stream, - reasoning_effort=reasoning_effort, + payload=await request.client.acompletion( + model=request.model_id, + messages=request.messages_payload, + tools=request.tools_payload, + stream=request.stream, + reasoning_effort=request.reasoning_effort, **completion_kwargs, ), ) @@ -679,19 +560,7 @@ def _call_client_sync( reasoning_effort: Any | None, kwargs: dict[str, Any], ) -> Any: - transport = self._selected_transport( - client, - provider_name=provider_name, - model_id=model_id, - tools_payload=tools_payload, - ) - callers: dict[str, Callable[..., Any]] = { - "completion": self._call_completion_sync, - "responses": self._call_responses_sync, - "messages": self._call_messages_sync, - } - call_transport = callers[transport] - return call_transport( + request = TransportCallRequest( client=client, provider_name=provider_name, model_id=model_id, @@ -702,6 +571,15 @@ def _call_client_sync( reasoning_effort=reasoning_effort, kwargs=kwargs, ) + transport = self._selected_transport( + client, + provider_name=provider_name, + model_id=model_id, + tools_payload=tools_payload, + ) + if transport == "responses": + return self._call_responses_sync(request) + return self._call_completion_like_sync(transport=transport, request=request) async def _call_client_async( self, @@ -716,19 +594,7 @@ async def _call_client_async( reasoning_effort: Any | None, kwargs: dict[str, Any], ) -> Any: - transport = self._selected_transport( - client, - provider_name=provider_name, - model_id=model_id, - tools_payload=tools_payload, - ) - callers: dict[str, Callable[..., Any]] = { - "completion": self._call_completion_async, - "responses": self._call_responses_async, - "messages": self._call_messages_async, - } - call_transport = callers[transport] - return await call_transport( + request = TransportCallRequest( client=client, provider_name=provider_name, model_id=model_id, @@ -739,6 +605,15 @@ async def _call_client_async( reasoning_effort=reasoning_effort, kwargs=kwargs, ) + transport = self._selected_transport( + client, + provider_name=provider_name, + model_id=model_id, + tools_payload=tools_payload, + ) + if transport == "responses": + return await self._call_responses_async(request) + return await self._call_completion_like_async(transport=transport, request=request) @staticmethod def _split_messages_for_responses( diff --git a/tests/test_parsing_registry.py b/tests/test_parsing_registry.py index 3a0ced7..3bf7db9 100644 --- a/tests/test_parsing_registry.py +++ 
b/tests/test_parsing_registry.py @@ -1,18 +1,42 @@ from __future__ import annotations +from types import SimpleNamespace + +from republic.clients.chat import ChatClient from republic.clients.parsing import parser_for_transport from republic.clients.parsing import responses as responses_parser +from republic.clients.parsing.types import BaseTransportParser from .fakes import make_responses_function_call, make_responses_response -def test_parser_for_transport_returns_expected_modules() -> None: - assert parser_for_transport("completion").__name__.endswith(".completion") - assert parser_for_transport("responses").__name__.endswith(".responses") - assert parser_for_transport("messages").__name__.endswith(".messages") +def test_parser_for_transport_returns_parser_objects() -> None: + parser = parser_for_transport("completion") + assert isinstance(parser, BaseTransportParser) + assert callable(parser.extract_text) + assert callable(parser.extract_tool_calls) + assert callable(parser.extract_usage) + + parser = parser_for_transport("responses") + assert isinstance(parser, BaseTransportParser) + assert callable(parser.extract_text) + assert callable(parser.extract_tool_calls) + assert callable(parser.extract_usage) + + parser = parser_for_transport("messages") + assert isinstance(parser, BaseTransportParser) + assert callable(parser.extract_text) + assert callable(parser.extract_tool_calls) + assert callable(parser.extract_usage) def test_responses_extract_tool_calls_accepts_full_response() -> None: response = make_responses_response(tool_calls=[make_responses_function_call("echo", '{"text":"tokyo"}')]) calls = responses_parser.extract_tool_calls(response) assert calls[0]["function"]["name"] == "echo" + + +def test_chat_client_resolve_transport_treats_output_text_as_responses() -> None: + payload = SimpleNamespace(output_text="hello") + assert ChatClient._is_non_stream_response(payload) is True + assert ChatClient._extract_text(payload) == "hello" From bcb0e50c91698c908579b6b0f21d0c4d7d964609 Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Mon, 2 Mar 2026 16:39:43 +0000 Subject: [PATCH 11/14] docs: prioritize structured error handling in guides --- README.md | 1 + docs/guides/chat.md | 39 ++++++++++++++++++++++++++++----------- docs/quickstart.md | 9 +++++++-- 3 files changed, 36 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 9237c44..b3b0ab2 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,7 @@ print(result) ## Why It Feels Natural - **Plain Python**: The main flow is regular functions and branches, no extra DSL. +- **Structured error handling**: Errors are explicit and typed, so retry and fallback logic stays deterministic. - **Unified API surface**: Use the same `chat/tool_calls/run_tools/stream/stream_events` methods across completion, responses, and messages transports. - **Tools without magic**: Supports both automatic and manual tool execution with clear debugging and auditing. - **Tape-first memory**: Use anchor/handoff to bound context windows and replay full evidence. 
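The README bullet added above leans on Republic's typed errors. A minimal sketch of what that enables, assuming only the `ErrorPayload.kind` values (`"temporary"`, `"invalid_input"`) exercised by the guides and tests in this series; the library's own `fallback_models` option already automates fallback, so the manual branch below is purely illustrative:

```python
from republic import ErrorPayload, LLM

primary = LLM(model="openai:gpt-4o-mini", api_key="")
backup = LLM(model="openrouter:openrouter/free", api_key="")


def chat_with_manual_fallback(prompt: str) -> str:
    # Typed error kinds keep the fallback decision deterministic:
    # retryable failures move to the backup model; bad input fails fast.
    try:
        return primary.chat(prompt, max_tokens=32)
    except ErrorPayload as error:
        if error.kind == "invalid_input":
            raise  # no other model will fix a malformed request
        return backup.chat(prompt, max_tokens=32)
```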
diff --git a/docs/guides/chat.md b/docs/guides/chat.md index ee34b6a..16cc597 100644 --- a/docs/guides/chat.md +++ b/docs/guides/chat.md @@ -22,22 +22,23 @@ messages = [ out = llm.chat(messages=messages, max_tokens=48) ``` -## Transport Format (`api_format`) +## Structured Error Handling -Republic exposes one public chat/tool/stream interface, and lets you choose the upstream API format explicitly: +```python +from republic import ErrorPayload, LLM -- `api_format="completion"` (default): chat-completions style. -- `api_format="responses"`: responses style. -- `api_format="messages"`: Anthropic messages style (only Anthropic models, including `openrouter:anthropic/...`). +llm = LLM(model="openrouter:openrouter/free", api_key="") -```python -llm_completion = LLM(model="openai:gpt-4o-mini", api_key="", api_format="completion") -llm_responses = LLM(model="openrouter:openrouter/free", api_key="", api_format="responses") -llm_messages = LLM(model="openrouter:anthropic/claude-3.5-haiku", api_key="", api_format="messages") +try: + out = llm.chat("Write one sentence.", max_tokens=32) + print(out) +except ErrorPayload as error: + if error.kind == "temporary": + print("retry later") + else: + print("fail fast:", error.message) ``` -The same public methods are used in all formats: `chat`, `tool_calls`, `run_tools`, `stream`, and `stream_events`. - ## Retries and Fallback ```python @@ -52,3 +53,19 @@ out = llm.chat("Give me one deployment checklist item.") ``` Recommendation: keep `max_retries` small (for example 2-4), and pick fallback models that are slightly more stable while still meeting quality requirements. + +## Transport Format (`api_format`) + +If you need explicit upstream wire-format control, choose one transport: + +- `api_format="completion"` (default): chat-completions style. +- `api_format="responses"`: responses style. +- `api_format="messages"`: Anthropic messages style (only Anthropic models, including `openrouter:anthropic/...`). + +```python +llm_completion = LLM(model="openai:gpt-4o-mini", api_key="", api_format="completion") +llm_responses = LLM(model="openrouter:openrouter/free", api_key="", api_format="responses") +llm_messages = LLM(model="openrouter:anthropic/claude-3.5-haiku", api_key="", api_format="messages") +``` + +The same public methods are used in all transports: `chat`, `tool_calls`, `run_tools`, `stream`, and `stream_events`. 
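The guide's closing claim (the same public methods across transports) is what the parity tests later in this series assert. A short usage sketch under the guide's own placeholder keys and model ids; `stream.usage` and `stream.error` are the attributes those tests read once iteration finishes:

```python
from republic import LLM

for api_format, model in (
    ("completion", "openai:gpt-4o-mini"),
    ("responses", "openrouter:openrouter/free"),
):
    llm = LLM(model=model, api_key="", api_format=api_format)
    stream = llm.stream("Say hello")  # identical call shape on every transport
    text = "".join(stream)            # iterate the text deltas
    assert stream.error is None       # error/usage are populated after the stream ends
    print(api_format, text, stream.usage)
```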
diff --git a/docs/quickstart.md b/docs/quickstart.md index bdec6e0..8abfc6b 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -38,6 +38,8 @@ print(reply) ## Step 4: Handle failures and fallback ```python +from republic import ErrorPayload, LLM + llm = LLM( model="openai:gpt-4o-mini", fallback_models=["openrouter:openrouter/free"], @@ -45,6 +47,9 @@ llm = LLM( api_key={"openai": "", "openrouter": ""}, ) -result = llm.chat("say hello", max_tokens=8) -print(result) +try: + result = llm.chat("say hello", max_tokens=8) + print(result) +except ErrorPayload as error: + print(error.kind, error.message) ``` From ae03c92ec98e6bc89916be4cc39cdcefb9d820b6 Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Mon, 2 Mar 2026 16:58:12 +0000 Subject: [PATCH 12/14] refactor: trim redundant transport policy and doc noise --- README.md | 1 - docs/guides/chat.md | 16 ------- docs/guides/stream-events.md | 2 - docs/guides/tools.md | 2 - src/republic/clients/chat.py | 8 ---- src/republic/clients/parsing/__init__.py | 5 +- src/republic/clients/parsing/types.py | 10 ---- src/republic/core/provider_policies.py | 14 ------ tests/test_provider_policies.py | 60 ------------------------ 9 files changed, 1 insertion(+), 117 deletions(-) diff --git a/README.md b/README.md index b3b0ab2..9dd46c4 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,6 @@ print(result) - **Plain Python**: The main flow is regular functions and branches, no extra DSL. - **Structured error handling**: Errors are explicit and typed, so retry and fallback logic stays deterministic. -- **Unified API surface**: Use the same `chat/tool_calls/run_tools/stream/stream_events` methods across completion, responses, and messages transports. - **Tools without magic**: Supports both automatic and manual tool execution with clear debugging and auditing. - **Tape-first memory**: Use anchor/handoff to bound context windows and replay full evidence. - **Event streaming**: Subscribe to text deltas, tool calls, tool results, usage, and final state. diff --git a/docs/guides/chat.md b/docs/guides/chat.md index 16cc597..8717c19 100644 --- a/docs/guides/chat.md +++ b/docs/guides/chat.md @@ -53,19 +53,3 @@ out = llm.chat("Give me one deployment checklist item.") ``` Recommendation: keep `max_retries` small (for example 2-4), and pick fallback models that are slightly more stable while still meeting quality requirements. - -## Transport Format (`api_format`) - -If you need explicit upstream wire-format control, choose one transport: - -- `api_format="completion"` (default): chat-completions style. -- `api_format="responses"`: responses style. -- `api_format="messages"`: Anthropic messages style (only Anthropic models, including `openrouter:anthropic/...`). - -```python -llm_completion = LLM(model="openai:gpt-4o-mini", api_key="", api_format="completion") -llm_responses = LLM(model="openrouter:openrouter/free", api_key="", api_format="responses") -llm_messages = LLM(model="openrouter:anthropic/claude-3.5-haiku", api_key="", api_format="messages") -``` - -The same public methods are used in all transports: `chat`, `tool_calls`, `run_tools`, `stream`, and `stream_events`. diff --git a/docs/guides/stream-events.md b/docs/guides/stream-events.md index c447425..6328fab 100644 --- a/docs/guides/stream-events.md +++ b/docs/guides/stream-events.md @@ -5,8 +5,6 @@ Republic provides two streaming modes: - `stream(...)`: text deltas only. - `stream_events(...)`: full events including text, tools, usage, and final. 
-Both modes keep the same public API across `completion`, `responses`, and `messages` transports. - ## Text Stream ```python diff --git a/docs/guides/tools.md b/docs/guides/tools.md index 89cbdea..0a8dab9 100644 --- a/docs/guides/tools.md +++ b/docs/guides/tools.md @@ -5,8 +5,6 @@ Tool workflows have two paths: - Automatic execution: `llm.run_tools(...)` - Manual execution: `llm.tool_calls(...)` + `llm.tools.execute(...)` -The tool API is transport-agnostic: you can use the same calls under `api_format="completion"`, `"responses"`, or `"messages"`. - ## Define a Tool ```python diff --git a/src/republic/clients/chat.py b/src/republic/clients/chat.py index 0c49edc..297225a 100644 --- a/src/republic/clients/chat.py +++ b/src/republic/clients/chat.py @@ -1577,14 +1577,6 @@ async def _iterator() -> AsyncIterator[str]: return AsyncTextStream(_iterator(), state=state) - @staticmethod - def _chunk_has_tool_calls( - chunk: Any, - *, - transport: TransportKind | None = None, - ) -> bool: - return bool(ChatClient._extract_chunk_tool_call_deltas(chunk, transport=transport)) - @staticmethod def _extract_chunk_tool_call_deltas( chunk: Any, diff --git a/src/republic/clients/parsing/__init__.py b/src/republic/clients/parsing/__init__.py index 4a2b3e5..66db4c0 100644 --- a/src/republic/clients/parsing/__init__.py +++ b/src/republic/clients/parsing/__init__.py @@ -5,7 +5,7 @@ from republic.clients.parsing.completion import PARSER as completion_parser from republic.clients.parsing.messages import PARSER as messages_parser from republic.clients.parsing.responses import PARSER as responses_parser -from republic.clients.parsing.types import BaseTransportParser, TransportKind, validate_transport_parser +from republic.clients.parsing.types import BaseTransportParser, TransportKind _PARSERS: dict[TransportKind, BaseTransportParser] = { "completion": completion_parser, @@ -13,9 +13,6 @@ "messages": messages_parser, } -for _name, _parser in _PARSERS.items(): - validate_transport_parser(_parser, name=_name) - def parser_for_transport(transport: TransportKind) -> BaseTransportParser: return _PARSERS[transport] diff --git a/src/republic/clients/parsing/types.py b/src/republic/clients/parsing/types.py index 080c4e7..5b69a5b 100644 --- a/src/republic/clients/parsing/types.py +++ b/src/republic/clients/parsing/types.py @@ -62,13 +62,3 @@ def extract_tool_calls(self, response: Any) -> list[dict[str, Any]]: def extract_usage(self, response: Any) -> dict[str, Any] | None: return self.extract_usage_fn(response) - - -class InvalidTransportParserError(TypeError): - def __init__(self, parser_name: str) -> None: - super().__init__(f"{parser_name} parser must inherit BaseTransportParser") - - -def validate_transport_parser(parser: object, *, name: str) -> None: - if not isinstance(parser, BaseTransportParser): - raise InvalidTransportParserError(name) diff --git a/src/republic/core/provider_policies.py b/src/republic/core/provider_policies.py index 6c38280..db99276 100644 --- a/src/republic/core/provider_policies.py +++ b/src/republic/core/provider_policies.py @@ -42,20 +42,6 @@ def _responses_tools_blocked_for_model(provider_name: str, model_id: str) -> boo return any(lowered_model.startswith(prefix) for prefix in policy.responses_tools_blocked_model_prefixes) -def should_attempt_responses( - *, - provider_name: str, - model_id: str, - has_tools: bool, - supports_responses: bool, -) -> bool: - if has_tools and _responses_tools_blocked_for_model(provider_name, model_id): - return False - if supports_responses: - return True - 
return provider_policy(provider_name).enable_responses_without_capability - - def responses_rejection_reason( *, provider_name: str, diff --git a/tests/test_provider_policies.py b/tests/test_provider_policies.py index b220397..12346da 100644 --- a/tests/test_provider_policies.py +++ b/tests/test_provider_policies.py @@ -1,66 +1,6 @@ from republic.core import provider_policies -def test_should_attempt_responses_accepts_provider_capability() -> None: - assert ( - provider_policies.should_attempt_responses( - provider_name="anthropic", - model_id="claude-3-5-haiku-latest", - has_tools=False, - supports_responses=True, - ) - is True - ) - - -def test_should_attempt_responses_openrouter_policy_fallback() -> None: - assert ( - provider_policies.should_attempt_responses( - provider_name="openrouter", - model_id="openai/gpt-4o-mini", - has_tools=False, - supports_responses=False, - ) - is True - ) - - -def test_should_attempt_responses_requires_explicit_policy_or_capability() -> None: - assert ( - provider_policies.should_attempt_responses( - provider_name="anthropic", - model_id="claude-3-5-haiku-latest", - has_tools=False, - supports_responses=False, - ) - is False - ) - - -def test_should_attempt_responses_openrouter_anthropic_tools_disabled() -> None: - assert ( - provider_policies.should_attempt_responses( - provider_name="openrouter", - model_id="anthropic/claude-3.5-haiku", - has_tools=True, - supports_responses=False, - ) - is False - ) - - -def test_should_attempt_responses_openrouter_anthropic_without_tools_enabled() -> None: - assert ( - provider_policies.should_attempt_responses( - provider_name="openrouter", - model_id="anthropic/claude-3.5-haiku", - has_tools=False, - supports_responses=False, - ) - is True - ) - - def test_responses_rejection_reason_none_when_openrouter_responses_available() -> None: assert ( provider_policies.responses_rejection_reason( From 65029640c347d904305e321058368554c9b5a30b Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Mon, 2 Mar 2026 17:02:14 +0000 Subject: [PATCH 13/14] test: remove redundant responses handling coverage --- tests/test_responses_handling.py | 398 ++----------------------------- 1 file changed, 17 insertions(+), 381 deletions(-) diff --git a/tests/test_responses_handling.py b/tests/test_responses_handling.py index ff94d15..57be480 100644 --- a/tests/test_responses_handling.py +++ b/tests/test_responses_handling.py @@ -5,7 +5,6 @@ import pytest from republic import LLM, tool -from republic.clients.chat import ChatClient from republic.core.execution import LLMCore from republic.core.results import ErrorPayload @@ -41,8 +40,6 @@ def _compact_stream_events(events: list[Any]) -> list[tuple[str, Any]]: compact.append(("tool_result", event.data["result"])) elif event.kind == "usage": compact.append(("usage", event.data)) - elif event.kind == "error": - compact.append(("error", event.data)) elif event.kind == "final": final = event.data compact.append(( @@ -61,45 +58,6 @@ def _compact_stream_events(events: list[Any]) -> list[tuple[str, Any]]: return compact -def _as_async_iter(items: list[Any]) -> Any: - async def _generator() -> Any: - for item in items: - yield item - - return _generator() - - -def _responses_stream_text_items() -> list[Any]: - return [ - make_responses_text_delta("hello"), - make_responses_text_delta(" world"), - make_responses_completed({"total_tokens": 7}), - ] - - -def _responses_stream_event_items() -> list[Any]: - return [ - make_responses_text_delta("Checking "), - make_responses_output_item_added(item_id="fc_1", 
call_id="call_1", name="echo", arguments=""), - make_responses_function_delta('{"text":"to', item_id="fc_1"), - make_responses_function_delta('kyo"}', item_id="fc_1"), - make_responses_output_item_done( - item_id="fc_1", - call_id="call_1", - name="echo", - arguments='{"text":"tokyo"}', - ), - make_responses_completed({"total_tokens": 12}), - ] - - -def _completion_stream_text_items() -> list[Any]: - return [ - make_chunk(text="hello"), - make_chunk(text=" world", usage={"total_tokens": 7}), - ] - - def _completion_stream_event_items() -> list[Any]: return [ make_chunk(text="Checking "), @@ -111,23 +69,6 @@ def _completion_stream_event_items() -> list[Any]: ] -def _main_path_payloads(*, async_mode: bool) -> list[Any]: - wrap_stream = _as_async_iter if async_mode else iter - return [ - make_response(text="ready"), - make_response(tool_calls=[make_tool_call("echo", '{"text":"tokyo"}')]), - make_response(tool_calls=[make_tool_call("echo", '{"text":"tokyo"}')]), - wrap_stream(_completion_stream_text_items()), - wrap_stream(_completion_stream_event_items()), - ] - - -def _queue_main_path_fixtures(client: Any, *, async_mode: bool) -> None: - payloads = _main_path_payloads(async_mode=async_mode) - queue = client.queue_acompletion if async_mode else client.queue_completion - queue(*payloads) - - def test_default_api_format_uses_completion(fake_anyllm) -> None: client = fake_anyllm.ensure("openai") client.queue_completion(make_response(text="hello")) @@ -177,18 +118,6 @@ def test_openrouter_anthropic_tools_rejects_responses_format(fake_anyllm) -> Non assert exc_info.value.kind == "invalid_input" -def test_openrouter_anthropic_tools_work_with_completion_format(fake_anyllm) -> None: - client = fake_anyllm.ensure("openrouter") - client.queue_completion(make_response(tool_calls=[make_tool_call("echo", '{"text":"tokyo"}')])) - - llm = LLM(model="openrouter:anthropic/claude-3.5-haiku", api_key="dummy") - calls = llm.tool_calls("Call echo for tokyo", tools=[echo]) - - assert len(calls) == 1 - assert calls[0]["function"]["name"] == "echo" - assert client.calls[-1].get("responses") is None - - def test_messages_format_maps_to_completion(fake_anyllm) -> None: client = fake_anyllm.ensure("openrouter") client.queue_completion(make_response(tool_calls=[make_tool_call("echo", '{"text":"tokyo"}')])) @@ -201,50 +130,6 @@ def test_messages_format_maps_to_completion(fake_anyllm) -> None: assert client.calls[-1].get("responses") is None -def test_messages_chat_uses_completion(fake_anyllm) -> None: - client = fake_anyllm.ensure("openrouter") - client.queue_completion(make_response(text="ready")) - - llm = LLM(model="openrouter:anthropic/claude-3.5-haiku", api_key="dummy", api_format="messages") - result = llm.chat("Reply with ready") - - assert result == "ready" - assert client.calls[-1].get("responses") is None - - -def test_messages_stream_uses_completion_and_collects_usage(fake_anyllm) -> None: - client = fake_anyllm.ensure("openrouter") - client.queue_completion(iter(_completion_stream_text_items())) - - llm = LLM(model="openrouter:anthropic/claude-3.5-haiku", api_key="dummy", api_format="messages") - stream = llm.stream("Say hello") - text = "".join(list(stream)) - - assert text == "hello world" - assert stream.error is None - assert stream.usage == {"total_tokens": 7} - assert client.calls[-1].get("responses") is None - assert client.calls[-1]["stream"] is True - - -def test_stream_events_parity_between_completion_and_messages(fake_anyllm) -> None: - completion_client = fake_anyllm.ensure("openai") - 
completion_client.queue_completion(iter(_completion_stream_event_items())) - - messages_client = fake_anyllm.ensure("openrouter") - messages_client.queue_completion(iter(_completion_stream_event_items())) - - completion_llm = LLM(model="openai:gpt-4o-mini", api_key="dummy") - messages_llm = LLM(model="openrouter:anthropic/claude-3.5-haiku", api_key="dummy", api_format="messages") - - completion_events = list(completion_llm.stream_events("Call echo for tokyo", tools=[echo])) - messages_events = list(messages_llm.stream_events("Call echo for tokyo", tools=[echo])) - - assert completion_client.calls[-1].get("responses") is None - assert messages_client.calls[-1].get("responses") is None - assert _compact_stream_events(completion_events) == _compact_stream_events(messages_events) - - def test_messages_format_rejects_non_anthropic_model(fake_anyllm) -> None: llm = LLM(model="openai:gpt-4o-mini", api_key="dummy", api_format="messages") with pytest.raises(ErrorPayload) as exc_info: @@ -270,20 +155,6 @@ def test_responses_tool_choice_accepts_completion_function_shape(fake_anyllm) -> assert client.calls[-1]["tool_choice"] == {"type": "function", "name": "echo"} -def test_extract_tool_calls_from_responses() -> None: - response = make_responses_response(tool_calls=[make_responses_function_call("echo", '{"text":"hi"}')]) - - calls = ChatClient._extract_tool_calls(response) - - assert calls == [ - { - "function": {"name": "echo", "arguments": '{"text":"hi"}'}, - "id": "call_1", - "type": "function", - } - ] - - def test_non_stream_completion_splits_concatenated_tool_arguments(fake_anyllm) -> None: client = fake_anyllm.ensure("openai") client.queue_completion( @@ -329,8 +200,6 @@ def test_stream_events_splits_concatenated_tool_arguments(fake_anyllm) -> None: tool_calls = [event.data["call"] for event in events if event.kind == "tool_call"] assert [call["id"] for call in tool_calls] == ["call_1", "call_1__2"] assert [call["function"]["arguments"] for call in tool_calls] == ['{"text":"tokyo"}', '{"text":"osaka"}'] - tool_results = [event.data["result"] for event in events if event.kind == "tool_result"] - assert tool_results == ["TOKYO", "OSAKA"] def test_split_messages_for_responses() -> None: @@ -378,8 +247,6 @@ def test_stream_uses_responses_and_collects_usage(fake_anyllm) -> None: assert text == "Hello world" assert stream.error is None assert stream.usage == {"total_tokens": 7} - assert client.calls[-1]["responses"] is True - assert client.calls[-1]["stream"] is True def test_stream_events_supports_responses_tool_events(fake_anyllm) -> None: @@ -399,66 +266,15 @@ def test_stream_events_supports_responses_tool_events(fake_anyllm) -> None: events = list(stream) kinds = [event.kind for event in events] - assert "text" in kinds assert "tool_call" in kinds assert "tool_result" in kinds assert "usage" in kinds assert kinds[-1] == "final" - tool_calls = [event for event in events if event.kind == "tool_call"] - assert len(tool_calls) == 1 - assert tool_calls[0].data["call"]["function"]["name"] == "echo" - assert tool_calls[0].data["call"]["function"]["arguments"] == '{"text":"tokyo"}' - - tool_results = [event for event in events if event.kind == "tool_result"] - assert len(tool_results) == 1 - assert tool_results[0].data["result"] == "TOKYO" - assert stream.error is None - assert stream.usage == {"total_tokens": 12} - - -def test_stream_and_events_support_responses_dict_events(fake_anyllm) -> None: - client = fake_anyllm.ensure("openrouter") - client.queue_responses( - iter([ - {"type": 
"response.output_text.delta", "delta": "Checking "}, - {"type": "response.function_call_arguments.delta", "item_id": "call_d1", "delta": '{"text":"to'}, - { - "type": "response.function_call_arguments.done", - "item_id": "call_d1", - "name": "echo", - "arguments": '{"text":"tokyo"}', - }, - {"type": "response.completed", "response": {"usage": {"total_tokens": 5}}}, - ]) - ) - client.queue_responses(make_responses_response(text="ready", usage={"total_tokens": 3})) - - llm = LLM(model="openrouter:openrouter/free", api_key="dummy", api_format="responses") - events_stream = llm.stream_events("Call echo for tokyo", tools=[echo]) - events = list(events_stream) - tool_call = next(event for event in events if event.kind == "tool_call") - assert tool_call.data["call"]["function"]["arguments"] == '{"text":"tokyo"}' - assert events_stream.usage == {"total_tokens": 5} - - text_stream = llm.stream("Reply with ready") - text = "".join(list(text_stream)) - assert text == "ready" - assert text_stream.usage == {"total_tokens": 3} - def test_stream_events_parity_between_completion_and_responses(fake_anyllm) -> None: completion_client = fake_anyllm.ensure("openai") - completion_client.queue_completion( - iter([ - make_chunk(text="Checking "), - make_chunk(tool_calls=[make_tool_call("echo", '{"text":"to', call_id="call_1")]), - make_chunk( - tool_calls=[make_tool_call("echo", 'kyo"}', call_id="call_1")], - usage={"total_tokens": 12}, - ), - ]) - ) + completion_client.queue_completion(iter(_completion_stream_event_items())) responses_client = fake_anyllm.ensure("openrouter") responses_client.queue_responses( @@ -467,7 +283,12 @@ def test_stream_events_parity_between_completion_and_responses(fake_anyllm) -> N make_responses_output_item_added(item_id="fc_1", call_id="call_1", name="echo", arguments=""), make_responses_function_delta('{"text":"to', item_id="fc_1"), make_responses_function_delta('kyo"}', item_id="fc_1"), - make_responses_function_done("echo", '{"text":"tokyo"}', item_id="fc_1"), + make_responses_output_item_done( + item_id="fc_1", + call_id="call_1", + name="echo", + arguments='{"text":"tokyo"}', + ), make_responses_completed({"total_tokens": 12}), ]) ) @@ -477,51 +298,24 @@ def test_stream_events_parity_between_completion_and_responses(fake_anyllm) -> N completion_events = list(completion_llm.stream_events("Call echo for tokyo", tools=[echo])) responses_events = list(responses_llm.stream_events("Call echo for tokyo", tools=[echo])) - assert completion_client.calls[-1].get("responses") is None - assert responses_client.calls[-1].get("responses") is True assert _compact_stream_events(completion_events) == _compact_stream_events(responses_events) -def test_stream_events_responses_output_item_events_keep_call_id(fake_anyllm) -> None: - client = fake_anyllm.ensure("openrouter") - client.queue_responses( - iter([ - make_responses_output_item_added(item_id="fc_123", call_id="call_abc", name="echo"), - make_responses_function_delta('{"text":"to', item_id="fc_123"), - make_responses_function_delta('kyo"}', item_id="fc_123"), - make_responses_output_item_done( - item_id="fc_123", - call_id="call_abc", - name="echo", - arguments='{"text":"tokyo"}', - ), - make_responses_completed({"total_tokens": 6}), - ]) - ) - - llm = LLM(model="openrouter:openrouter/free", api_key="dummy", api_format="responses") - events = list(llm.stream_events("Call echo for tokyo", tools=[echo])) +def test_stream_events_parity_between_completion_and_messages(fake_anyllm) -> None: + completion_client = fake_anyllm.ensure("openai") + 
-def test_stream_events_responses_output_item_events_keep_call_id(fake_anyllm) -> None:
-    client = fake_anyllm.ensure("openrouter")
-    client.queue_responses(
-        iter([
-            make_responses_output_item_added(item_id="fc_123", call_id="call_abc", name="echo"),
-            make_responses_function_delta('{"text":"to', item_id="fc_123"),
-            make_responses_function_delta('kyo"}', item_id="fc_123"),
-            make_responses_output_item_done(
-                item_id="fc_123",
-                call_id="call_abc",
-                name="echo",
-                arguments='{"text":"tokyo"}',
-            ),
-            make_responses_completed({"total_tokens": 6}),
-        ])
-    )
-
-    llm = LLM(model="openrouter:openrouter/free", api_key="dummy", api_format="responses")
-    events = list(llm.stream_events("Call echo for tokyo", tools=[echo]))
+def test_stream_events_parity_between_completion_and_messages(fake_anyllm) -> None:
+    completion_client = fake_anyllm.ensure("openai")
+    completion_client.queue_completion(iter(_completion_stream_event_items()))
 
-    tool_call = next(event for event in events if event.kind == "tool_call").data["call"]
-    assert tool_call["id"] == "call_abc"
-    assert tool_call["function"]["name"] == "echo"
-    assert tool_call["function"]["arguments"] == '{"text":"tokyo"}'
+    messages_client = fake_anyllm.ensure("openrouter")
+    messages_client.queue_completion(iter(_completion_stream_event_items()))
 
+    completion_llm = LLM(model="openai:gpt-4o-mini", api_key="dummy")
+    messages_llm = LLM(model="openrouter:anthropic/claude-3.5-haiku", api_key="dummy", api_format="messages")
 
-def test_stream_usage_accepts_responses_in_progress_usage(fake_anyllm) -> None:
-    client = fake_anyllm.ensure("openrouter")
-    client.queue_responses(
-        iter([
-            make_responses_text_delta("ok"),
-            {"type": "response.in_progress", "response": {"usage": {"total_tokens": 2}}},
-        ])
-    )
+    completion_events = list(completion_llm.stream_events("Call echo for tokyo", tools=[echo]))
+    messages_events = list(messages_llm.stream_events("Call echo for tokyo", tools=[echo]))
 
-    llm = LLM(model="openrouter:openrouter/free", api_key="dummy", api_format="responses")
-    stream = llm.stream("Reply with ok")
-    assert "".join(list(stream)) == "ok"
-    assert stream.usage == {"total_tokens": 2}
+    assert _compact_stream_events(completion_events) == _compact_stream_events(messages_events)
 
 
 def test_non_stream_responses_tool_calls_converts_tools_payload(fake_anyllm) -> None:
@@ -534,43 +328,12 @@ def test_non_stream_responses_tool_calls_converts_tools_payload(fake_anyllm) ->
     calls = llm.tool_calls("Call echo for tokyo", tools=[echo])
 
     assert len(calls) == 1
-    assert calls[0]["function"]["name"] == "echo"
 
     sent_tools = client.calls[-1]["tools"]
     assert sent_tools[0]["type"] == "function"
     assert sent_tools[0]["name"] == "echo"
-    assert sent_tools[0]["description"] == ""
-    assert sent_tools[0]["parameters"]["type"] == "object"
     assert "function" not in sent_tools[0]
 
 
-def test_non_stream_responses_run_tools_uses_converted_tools(fake_anyllm) -> None:
-    client = fake_anyllm.ensure("openrouter")
-    client.queue_responses(
-        make_responses_response(tool_calls=[make_responses_function_call("echo", '{"text":"tokyo"}')])
-    )
-
-    llm = LLM(model="openrouter:openrouter/free", api_key="dummy", api_format="responses")
-    result = llm.run_tools("Call echo for tokyo", tools=[echo])
-
-    assert result.kind == "tools"
-    assert result.tool_results == ["TOKYO"]
-    sent_tools = client.calls[-1]["tools"]
-    assert sent_tools[0]["type"] == "function"
-    assert sent_tools[0]["name"] == "echo"
-    assert "function" not in sent_tools[0]
-
-
-def test_chat_reasoning_effort_for_completion_is_forwarded(fake_anyllm) -> None:
-    client = fake_anyllm.ensure("openai")
-    client.queue_completion(make_response(text="ready"))
-
-    llm = LLM(model="openai:gpt-4o-mini", api_key="dummy")
-    assert llm.chat("Reply with ready", reasoning_effort="high") == "ready"
-
-    call = client.calls[-1]
-    assert call.get("reasoning_effort") == "high"
-
-
 def test_chat_reasoning_effort_for_responses_is_mapped(fake_anyllm) -> None:
     client = fake_anyllm.ensure("openrouter")
     client.queue_responses(make_responses_response(text="ready"))
@@ -584,18 +347,6 @@ def test_chat_reasoning_effort_for_responses_is_mapped(fake_anyllm) -> None:
     assert "reasoning_effort" not in call
 
 
-def test_chat_reasoning_kwarg_has_priority_over_reasoning_effort(fake_anyllm) -> None:
-    client = fake_anyllm.ensure("openrouter")
-    client.queue_responses(make_responses_response(text="ready"))
-
LLM(model="openrouter:openrouter/free", api_key="dummy", api_format="responses") - assert llm.chat("Reply with ready", reasoning_effort="low", reasoning={"effort": "high"}) == "ready" - - call = client.calls[-1] - assert call["responses"] is True - assert call.get("reasoning") == {"effort": "high"} - - def test_stream_completion_defaults_include_usage(fake_anyllm) -> None: client = fake_anyllm.ensure("openai") client.queue_completion( @@ -613,23 +364,6 @@ def test_stream_completion_defaults_include_usage(fake_anyllm) -> None: assert client.calls[-1].get("stream_options") == {"include_usage": True} -def test_stream_completion_preserves_user_stream_options(fake_anyllm) -> None: - client = fake_anyllm.ensure("openai") - client.queue_completion( - iter([ - make_chunk(text="hello"), - make_chunk(text=" world", usage={"total_tokens": 7}), - ]) - ) - - llm = LLM(model="openai:gpt-4o-mini", api_key="dummy") - stream = llm.stream("Say hello", stream_options={"include_usage": False, "custom": True}) - assert "".join(list(stream)) == "hello world" - assert stream.usage == {"total_tokens": 7} - - assert client.calls[-1].get("stream_options") == {"include_usage": False, "custom": True} - - def test_openai_completion_uses_max_completion_tokens(fake_anyllm) -> None: client = fake_anyllm.ensure("openai") client.queue_completion(make_response(text="hello")) @@ -652,101 +386,3 @@ def test_non_openai_completion_uses_max_tokens(fake_anyllm) -> None: call = client.calls[-1] assert call.get("max_tokens") == 11 assert "max_completion_tokens" not in call - - -def test_non_stream_chat_parity_between_completion_and_responses(fake_anyllm) -> None: - completion_client = fake_anyllm.ensure("openai") - completion_client.queue_completion(make_response(text="ready")) - - responses_client = fake_anyllm.ensure("openrouter") - responses_client.queue_responses(make_responses_response(text="ready")) - - completion_llm = LLM(model="openai:gpt-4o-mini", api_key="dummy") - responses_llm = LLM(model="openrouter:openrouter/free", api_key="dummy", api_format="responses") - - assert completion_llm.chat("Reply with ready") == "ready" - assert responses_llm.chat("Reply with ready") == "ready" - - -def test_non_stream_run_tools_parity_between_completion_and_responses(fake_anyllm) -> None: - completion_client = fake_anyllm.ensure("openai") - completion_client.queue_completion(make_response(tool_calls=[make_tool_call("echo", '{"text":"tokyo"}')])) - - responses_client = fake_anyllm.ensure("openrouter") - responses_client.queue_responses( - make_responses_response(tool_calls=[make_responses_function_call("echo", '{"text":"tokyo"}')]) - ) - - completion_llm = LLM(model="openai:gpt-4o-mini", api_key="dummy") - responses_llm = LLM(model="openrouter:openrouter/free", api_key="dummy", api_format="responses") - - completion_result = completion_llm.run_tools("Call echo for tokyo", tools=[echo]) - responses_result = responses_llm.run_tools("Call echo for tokyo", tools=[echo]) - - assert completion_result.kind == responses_result.kind == "tools" - assert completion_result.tool_results == responses_result.tool_results == ["TOKYO"] - - -def test_sync_main_paths(fake_anyllm) -> None: - client = fake_anyllm.ensure("openai") - _queue_main_path_fixtures(client, async_mode=False) - llm = LLM(model="openai:gpt-4o-mini", api_key="dummy") - - assert llm.chat("Reply with ready") == "ready" - - calls = llm.tool_calls("Call echo for tokyo", tools=[echo]) - assert len(calls) == 1 - assert calls[0]["function"]["name"] == "echo" - - run_result = 
llm.run_tools("Call echo for tokyo", tools=[echo]) - assert run_result.kind == "tools" - assert run_result.tool_results == ["TOKYO"] - - text_stream = llm.stream("Say hello") - assert "".join(list(text_stream)) == "hello world" - assert text_stream.usage == {"total_tokens": 7} - - event_stream = llm.stream_events("Call echo for tokyo", tools=[echo]) - events = list(event_stream) - compact = _compact_stream_events(events) - assert compact[0] == ("text", "Checking ") - assert compact[1][0] == "tool_call" - assert compact[2] == ("tool_result", "TOKYO") - assert compact[3] == ("usage", {"total_tokens": 12}) - assert compact[-1][0] == "final" - assert event_stream.usage == {"total_tokens": 12} - - assert all(call.get("responses") is None for call in client.calls) - - -@pytest.mark.asyncio -async def test_async_main_paths(fake_anyllm) -> None: - client = fake_anyllm.ensure("openai") - _queue_main_path_fixtures(client, async_mode=True) - llm = LLM(model="openai:gpt-4o-mini", api_key="dummy") - - assert await llm.chat_async("Reply with ready") == "ready" - - calls = await llm.tool_calls_async("Call echo for tokyo", tools=[echo]) - assert len(calls) == 1 - assert calls[0]["function"]["name"] == "echo" - - run_result = await llm.run_tools_async("Call echo for tokyo", tools=[echo]) - assert run_result.kind == "tools" - assert run_result.tool_results == ["TOKYO"] - - text_stream = await llm.stream_async("Say hello") - assert "".join([part async for part in text_stream]) == "hello world" - assert text_stream.usage == {"total_tokens": 7} - - event_stream = await llm.stream_events_async("Call echo for tokyo", tools=[echo]) - events = [event async for event in event_stream] - compact = _compact_stream_events(events) - assert compact[0] == ("text", "Checking ") - assert compact[1][0] == "tool_call" - assert compact[2] == ("tool_result", "TOKYO") - assert compact[3] == ("usage", {"total_tokens": 12}) - assert compact[-1][0] == "final" - assert event_stream.usage == {"total_tokens": 12} - - assert all(call.get("responses") is None for call in client.calls) From fbd50f413fcfed75334d706bdef1e694b62c415a Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Mon, 2 Mar 2026 17:19:31 +0000 Subject: [PATCH 14/14] refactor: simplify parsers and fix openrouter header handling --- src/republic/clients/chat.py | 42 ++++++---- src/republic/clients/parsing/messages.py | 98 ++---------------------- src/republic/clients/parsing/types.py | 34 +++----- src/republic/core/execution.py | 9 +-- tests/test_responses_handling.py | 33 ++++++++ 5 files changed, 80 insertions(+), 136 deletions(-) diff --git a/src/republic/clients/chat.py b/src/republic/clients/chat.py index 297225a..58553c8 100644 --- a/src/republic/clients/chat.py +++ b/src/republic/clients/chat.py @@ -8,7 +8,7 @@ from functools import partial from typing import Any -from republic.clients.parsing import TransportKind, parser_for_transport +from republic.clients.parsing import BaseTransportParser, TransportKind, parser_for_transport from republic.clients.parsing.common import expand_tool_calls from republic.clients.parsing.common import field as _field from republic.core.errors import ErrorKind, RepublicError @@ -235,14 +235,32 @@ def _resolve_transport( return "responses" return "completion" + @staticmethod + def _parser_for_payload( + payload: Any, + *, + transport: TransportKind | None = None, + ) -> BaseTransportParser: + effective_transport = ChatClient._resolve_transport(payload, transport) + return parser_for_transport(effective_transport) + + @staticmethod + def 
From fbd50f413fcfed75334d706bdef1e694b62c415a Mon Sep 17 00:00:00 2001
From: Chojan Shang
Date: Mon, 2 Mar 2026 17:19:31 +0000
Subject: [PATCH 14/14] refactor: simplify parsers and fix openrouter header handling

---
 src/republic/clients/chat.py             | 42 ++++++----
 src/republic/clients/parsing/messages.py | 98 ++----------------------
 src/republic/clients/parsing/types.py    | 34 +++-----
 src/republic/core/execution.py           |  9 +--
 tests/test_responses_handling.py         | 33 ++++++++
 5 files changed, 80 insertions(+), 136 deletions(-)

diff --git a/src/republic/clients/chat.py b/src/republic/clients/chat.py
index 297225a..58553c8 100644
--- a/src/republic/clients/chat.py
+++ b/src/republic/clients/chat.py
@@ -8,7 +8,7 @@ from functools import partial
 from typing import Any
 
-from republic.clients.parsing import TransportKind, parser_for_transport
+from republic.clients.parsing import BaseTransportParser, TransportKind, parser_for_transport
 from republic.clients.parsing.common import expand_tool_calls
 from republic.clients.parsing.common import field as _field
 from republic.core.errors import ErrorKind, RepublicError
@@ -235,14 +235,32 @@ def _resolve_transport(
             return "responses"
         return "completion"
 
+    @staticmethod
+    def _parser_for_payload(
+        payload: Any,
+        *,
+        transport: TransportKind | None = None,
+    ) -> BaseTransportParser:
+        effective_transport = ChatClient._resolve_transport(payload, transport)
+        return parser_for_transport(effective_transport)
+
+    @staticmethod
+    def _unwrap_response_with_parser(
+        response: Any,
+        *,
+        transport: TransportKind | None = None,
+    ) -> tuple[Any, BaseTransportParser]:
+        payload, detected_transport = ChatClient._unwrap_response(response)
+        parser = ChatClient._parser_for_payload(payload, transport=transport or detected_transport)
+        return payload, parser
+
     @staticmethod
     def _is_non_stream_response(
         response: Any,
         *,
         transport: TransportKind | None = None,
     ) -> bool:
-        effective_transport = ChatClient._resolve_transport(response, transport)
-        parser = parser_for_transport(effective_transport)
+        parser = ChatClient._parser_for_payload(response, transport=transport)
         return parser.is_non_stream_response(response)
 
     def _validate_chat_input(
@@ -1583,8 +1601,7 @@ def _extract_chunk_tool_call_deltas(
         *,
         transport: TransportKind | None = None,
     ) -> list[Any]:
-        effective_transport = ChatClient._resolve_transport(chunk, transport)
-        parser = parser_for_transport(effective_transport)
+        parser = ChatClient._parser_for_payload(chunk, transport=transport)
         return parser.extract_chunk_tool_call_deltas(chunk)
 
     @staticmethod
@@ -1593,8 +1610,7 @@ def _extract_chunk_text(
         chunk: Any,
         *,
         transport: TransportKind | None = None,
     ) -> str:
-        effective_transport = ChatClient._resolve_transport(chunk, transport)
-        parser = parser_for_transport(effective_transport)
+        parser = ChatClient._parser_for_payload(chunk, transport=transport)
         return parser.extract_chunk_text(chunk)
 
     def _build_event_stream(
@@ -1919,9 +1935,7 @@ def _extract_text(
         *,
         transport: TransportKind | None = None,
     ) -> str:
-        payload, detected_transport = ChatClient._unwrap_response(response)
-        effective_transport = ChatClient._resolve_transport(payload, transport or detected_transport)
-        parser = parser_for_transport(effective_transport)
+        payload, parser = ChatClient._unwrap_response_with_parser(response, transport=transport)
         return parser.extract_text(payload)
 
     @staticmethod
@@ -1930,9 +1944,7 @@ def _extract_tool_calls(
         *,
         transport: TransportKind | None = None,
     ) -> list[dict[str, Any]]:
-        payload, detected_transport = ChatClient._unwrap_response(response)
-        effective_transport = ChatClient._resolve_transport(payload, transport or detected_transport)
-        parser = parser_for_transport(effective_transport)
+        payload, parser = ChatClient._unwrap_response_with_parser(response, transport=transport)
         return parser.extract_tool_calls(payload)
 
     @staticmethod
@@ -1941,7 +1953,5 @@ def _extract_usage(
         *,
         transport: TransportKind | None = None,
     ) -> dict[str, Any] | None:
-        payload, detected_transport = ChatClient._unwrap_response(response)
-        effective_transport = ChatClient._resolve_transport(payload, transport or detected_transport)
-        parser = parser_for_transport(effective_transport)
+        payload, parser = ChatClient._unwrap_response_with_parser(response, transport=transport)
         return parser.extract_usage(payload)
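(Editorial note: the extraction helpers previously repeated the same unwrap, resolve, parse dance; after this refactor each reduces to one call into the two new helpers. The calling convention, restated from the diff above rather than new API:)

# Stream-side helpers resolve a parser straight from the chunk shape:
parser = ChatClient._parser_for_payload(chunk, transport=transport)

# Response-side helpers unwrap first, then resolve, preferring an explicit
# transport hint over the transport detected during unwrapping:
payload, parser = ChatClient._unwrap_response_with_parser(response, transport=transport)
text = parser.extract_text(payload)
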
"choices") - if not choices: - return [] - delta = field(choices[0], "delta") - if delta is None: - return [] - return field(delta, "tool_calls") or [] +"""Anthropic messages parsing. +Currently any-llm exposes Anthropic messages in completion-compatible payload +shapes, so this parser intentionally reuses completion parsing behavior. +""" -def extract_chunk_text(chunk: Any) -> str: - choices = field(chunk, "choices") - if not choices: - return "" - delta = field(choices[0], "delta") - if delta is None: - return "" - return field(delta, "content", "") or "" - - -def extract_text(response: Any) -> str: - if isinstance(response, str): - return response - - choices = field(response, "choices") - if not choices: - return "" - message = field(choices[0], "message") - if message is None: - return "" - return field(message, "content", "") or "" - - -def extract_tool_calls(response: Any) -> list[dict[str, Any]]: - choices = field(response, "choices") - if not choices: - return [] - message = field(choices[0], "message") - if message is None: - return [] - tool_calls = field(message, "tool_calls") or [] - calls: list[dict[str, Any]] = [] - for tool_call in tool_calls: - function = field(tool_call, "function") - if function is None: - continue - entry: dict[str, Any] = { - "function": { - "name": field(function, "name"), - "arguments": field(function, "arguments"), - } - } - call_id = field(tool_call, "id") - if call_id: - entry["id"] = call_id - call_type = field(tool_call, "type") - if call_type: - entry["type"] = call_type - calls.append(entry) - return expand_tool_calls(calls) - - -def extract_usage(response: Any) -> dict[str, Any] | None: - usage = field(response, "usage") - if usage is None: - return None - if isinstance(usage, dict): - return dict(usage) - if hasattr(usage, "model_dump"): - return usage.model_dump() - return None +from __future__ import annotations +from republic.clients.parsing.completion import PARSER -PARSER = FunctionTransportParser( - is_non_stream_response_fn=is_non_stream_response, - extract_chunk_tool_call_deltas_fn=extract_chunk_tool_call_deltas, - extract_chunk_text_fn=extract_chunk_text, - extract_text_fn=extract_text, - extract_tool_calls_fn=extract_tool_calls, - extract_usage_fn=extract_usage, -) +__all__ = ["PARSER"] diff --git a/src/republic/clients/parsing/types.py b/src/republic/clients/parsing/types.py index 5b69a5b..0290e19 100644 --- a/src/republic/clients/parsing/types.py +++ b/src/republic/clients/parsing/types.py @@ -2,42 +2,30 @@ from __future__ import annotations -from abc import ABC, abstractmethod from collections.abc import Callable from dataclasses import dataclass -from typing import Any, Literal +from typing import Any, Literal, Protocol, runtime_checkable TransportKind = Literal["completion", "responses", "messages"] -class BaseTransportParser(ABC): - @abstractmethod - def is_non_stream_response(self, response: Any) -> bool: - raise NotImplementedError +@runtime_checkable +class BaseTransportParser(Protocol): + def is_non_stream_response(self, response: Any) -> bool: ... - @abstractmethod - def extract_chunk_tool_call_deltas(self, chunk: Any) -> list[Any]: - raise NotImplementedError + def extract_chunk_tool_call_deltas(self, chunk: Any) -> list[Any]: ... - @abstractmethod - def extract_chunk_text(self, chunk: Any) -> str: - raise NotImplementedError + def extract_chunk_text(self, chunk: Any) -> str: ... 
diff --git a/src/republic/clients/parsing/types.py b/src/republic/clients/parsing/types.py
index 5b69a5b..0290e19 100644
--- a/src/republic/clients/parsing/types.py
+++ b/src/republic/clients/parsing/types.py
@@ -2,42 +2,30 @@
 
 from __future__ import annotations
 
-from abc import ABC, abstractmethod
 from collections.abc import Callable
 from dataclasses import dataclass
-from typing import Any, Literal
+from typing import Any, Literal, Protocol, runtime_checkable
 
 TransportKind = Literal["completion", "responses", "messages"]
 
 
-class BaseTransportParser(ABC):
-    @abstractmethod
-    def is_non_stream_response(self, response: Any) -> bool:
-        raise NotImplementedError
+@runtime_checkable
+class BaseTransportParser(Protocol):
+    def is_non_stream_response(self, response: Any) -> bool: ...
 
-    @abstractmethod
-    def extract_chunk_tool_call_deltas(self, chunk: Any) -> list[Any]:
-        raise NotImplementedError
+    def extract_chunk_tool_call_deltas(self, chunk: Any) -> list[Any]: ...
 
-    @abstractmethod
-    def extract_chunk_text(self, chunk: Any) -> str:
-        raise NotImplementedError
+    def extract_chunk_text(self, chunk: Any) -> str: ...
 
-    @abstractmethod
-    def extract_text(self, response: Any) -> str:
-        raise NotImplementedError
+    def extract_text(self, response: Any) -> str: ...
 
-    @abstractmethod
-    def extract_tool_calls(self, response: Any) -> list[dict[str, Any]]:
-        raise NotImplementedError
+    def extract_tool_calls(self, response: Any) -> list[dict[str, Any]]: ...
 
-    @abstractmethod
-    def extract_usage(self, response: Any) -> dict[str, Any] | None:
-        raise NotImplementedError
+    def extract_usage(self, response: Any) -> dict[str, Any] | None: ...
 
 
-@dataclass(frozen=True)
-class FunctionTransportParser(BaseTransportParser):
+@dataclass(frozen=True, slots=True)
+class FunctionTransportParser:
     is_non_stream_response_fn: Callable[[Any], bool]
     extract_chunk_tool_call_deltas_fn: Callable[[Any], list[Any]]
     extract_chunk_text_fn: Callable[[Any], str]
diff --git a/src/republic/core/execution.py b/src/republic/core/execution.py
index 409df0c..2011314 100644
--- a/src/republic/core/execution.py
+++ b/src/republic/core/execution.py
@@ -364,23 +364,20 @@ def _handle_attempt_error(self, exc: Exception, provider_name: str, model_id: st
     def _decide_kwargs_for_provider(
         self, provider: str, max_tokens: int | None, kwargs: dict[str, Any]
     ) -> dict[str, Any]:
-        clean_kwargs = self._sanitize_request_kwargs(kwargs)
+        clean_kwargs = dict(kwargs)
         max_tokens_arg = provider_policies.completion_max_tokens_arg(provider)
         if max_tokens_arg in clean_kwargs:
             return clean_kwargs
         return {**clean_kwargs, max_tokens_arg: max_tokens}
 
     def _decide_responses_kwargs(self, max_tokens: int | None, kwargs: dict[str, Any]) -> dict[str, Any]:
-        clean_kwargs = self._sanitize_request_kwargs(kwargs)
+        # any-llm responses params currently reject extra_headers, so drop it here only.
+        clean_kwargs = {k: v for k, v in kwargs.items() if k != "extra_headers"}
         clean_kwargs = normalize_responses_kwargs(clean_kwargs)
         if "max_output_tokens" in clean_kwargs:
             return clean_kwargs
         return {**clean_kwargs, "max_output_tokens": max_tokens}
 
-    @staticmethod
-    def _sanitize_request_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
-        return {k: v for k, v in kwargs.items() if k != "extra_headers"}
-
     @staticmethod
     def _should_default_completion_stream_usage(provider_name: str) -> bool:
         return provider_policies.should_include_completion_stream_usage(provider_name)
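(Editorial note: the net effect of inlining the filter is that completion and messages calls keep `extra_headers` while only the responses path drops it. A sketch against the two methods above; the `executor` instance is assumed, and `normalize_responses_kwargs` is assumed to pass this input through unchanged:)

kwargs = {"temperature": 0.2, "extra_headers": {"X-Title": "Republic"}}

# Completion path: headers survive, and the provider's max-tokens arg is added.
completion_kwargs = executor._decide_kwargs_for_provider("openai", 256, kwargs)
assert completion_kwargs["extra_headers"] == {"X-Title": "Republic"}

# Responses path: extra_headers is stripped before normalization.
responses_kwargs = executor._decide_responses_kwargs(256, kwargs)
assert "extra_headers" not in responses_kwargs
assert responses_kwargs["max_output_tokens"] == 256
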
fake_anyllm.ensure("openrouter") + client.queue_responses(make_responses_response(text="hello")) + + llm = LLM(model="openrouter:openrouter/free", api_key="dummy", api_format="responses") + assert llm.chat("Say hello", extra_headers={"X-Title": "Republic"}) == "hello" + + call = client.calls[-1] + assert "extra_headers" not in call + + def test_stream_completion_defaults_include_usage(fake_anyllm) -> None: client = fake_anyllm.ensure("openai") client.queue_completion(