diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py index fc72e1e4b..e3eb3bc27 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py @@ -3,6 +3,7 @@ from __future__ import annotations +from data_designer.engine.models.clients.adapters.openai_compatible import OpenAICompatibleClient from data_designer.engine.models.clients.base import ModelClient from data_designer.engine.models.clients.errors import ( ProviderError, @@ -11,6 +12,8 @@ map_http_status_to_provider_error_kind, ) from data_designer.engine.models.clients.factory import create_model_client +from data_designer.engine.models.clients.retry import RetryConfig +from data_designer.engine.models.clients.throttle import ThrottleDomain, ThrottleManager from data_designer.engine.models.clients.types import ( AssistantMessage, ChatCompletionRequest, @@ -36,8 +39,12 @@ "ImageGenerationResponse", "ImagePayload", "ModelClient", + "OpenAICompatibleClient", "ProviderError", "ProviderErrorKind", + "RetryConfig", + "ThrottleDomain", + "ThrottleManager", "ToolCall", "Usage", "create_model_client", diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/__init__.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/__init__.py index cc9feefea..d2f24b41f 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/__init__.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/__init__.py @@ -4,5 +4,6 @@ from __future__ import annotations from data_designer.engine.models.clients.adapters.litellm_bridge import LiteLLMBridgeClient, LiteLLMRouter +from data_designer.engine.models.clients.adapters.openai_compatible import OpenAICompatibleClient -__all__ = ["LiteLLMBridgeClient", "LiteLLMRouter"] +__all__ = ["LiteLLMBridgeClient", "LiteLLMRouter", "OpenAICompatibleClient"] diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py index 017a4af75..07bf884d6 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py @@ -11,8 +11,8 @@ from data_designer.engine.models.clients.base import ModelClient from data_designer.engine.models.clients.errors import ( ProviderError, - ProviderErrorKind, extract_message_from_exception_string, + infer_error_kind_from_exception, map_http_status_to_provider_error_kind, ) from data_designer.engine.models.clients.parsing import ( @@ -192,7 +192,7 @@ def _handle_non_provider_errors(provider_name: str) -> Iterator[None]: if isinstance(status_code, int): kind = map_http_status_to_provider_error_kind(status_code=status_code, body_text=str(exc)) else: - kind = _infer_error_kind(exc) + kind = infer_error_kind_from_exception(exc) raise ProviderError( kind=kind, @@ -201,17 +201,3 @@ def _handle_non_provider_errors(provider_name: str) -> Iterator[None]: provider_name=provider_name, cause=exc, ) from exc - - -def _infer_error_kind(exc: Exception) -> ProviderErrorKind: - """Infer error kind from exception type name when no status code is available.""" - 
type_name = type(exc).__name__.lower() - if "timeout" in type_name: - return ProviderErrorKind.TIMEOUT - if "connection" in type_name or "connect" in type_name: - return ProviderErrorKind.API_CONNECTION - if "auth" in type_name: - return ProviderErrorKind.AUTHENTICATION - if "ratelimit" in type_name: - return ProviderErrorKind.RATE_LIMIT - return ProviderErrorKind.API_ERROR diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/openai_compatible.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/openai_compatible.py new file mode 100644 index 000000000..ab5e47929 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/openai_compatible.py @@ -0,0 +1,319 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import logging +import threading +from typing import TYPE_CHECKING, Any + +import data_designer.lazy_heavy_imports as lazy +from data_designer.engine.models.clients.base import ModelClient +from data_designer.engine.models.clients.errors import ( + ProviderError, + ProviderErrorKind, + infer_error_kind_from_exception, + map_http_error_to_provider_error, +) +from data_designer.engine.models.clients.parsing import ( + aextract_images_from_chat_response, + aextract_images_from_image_response, + aparse_chat_completion_response, + extract_embedding_vector, + extract_images_from_chat_response, + extract_images_from_image_response, + extract_usage, + parse_chat_completion_response, +) +from data_designer.engine.models.clients.retry import RetryConfig, create_retry_transport +from data_designer.engine.models.clients.types import ( + ChatCompletionRequest, + ChatCompletionResponse, + EmbeddingRequest, + EmbeddingResponse, + ImageGenerationRequest, + ImageGenerationResponse, + TransportKwargs, +) + +if TYPE_CHECKING: + import httpx + +logger = logging.getLogger(__name__) + + +class OpenAICompatibleClient(ModelClient): + """Native HTTP adapter for OpenAI-compatible provider APIs. + + Uses ``httpx`` with ``httpx_retries.RetryTransport`` for resilient HTTP + calls. Concurrency / throttle policy is an orchestration concern and + is not managed here — see ``ThrottleManager`` and ``AsyncTaskScheduler``. + """ + + _ROUTE_CHAT = "/chat/completions" + _ROUTE_EMBEDDING = "/embeddings" + _ROUTE_IMAGE = "/images/generations" + _IMAGE_EXCLUDE = frozenset({"messages", "prompt"}) + + def __init__( + self, + *, + provider_name: str, + model_id: str, + endpoint: str, + api_key: str | None = None, + retry_config: RetryConfig | None = None, + max_parallel_requests: int = 32, + timeout_s: float = 60.0, + sync_client: httpx.Client | None = None, + async_client: httpx.AsyncClient | None = None, + ) -> None: + self.provider_name = provider_name + self._model_id = model_id + self._endpoint = endpoint.rstrip("/") + self._api_key = api_key + self._timeout_s = timeout_s + self._retry_config = retry_config + + # 2x headroom for burst traffic across domains; floor of 32/16 for low-concurrency configs. 
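+        # e.g. max_parallel_requests=32 -> max_connections=64, keepalive=32;
+        #      max_parallel_requests=4  -> floors apply: max_connections=32, keepalive=16.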
+ pool_max = max(32, 2 * max_parallel_requests) + pool_keepalive = max(16, max_parallel_requests) + self._limits = lazy.httpx.Limits( + max_connections=pool_max, + max_keepalive_connections=pool_keepalive, + ) + self._transport = create_retry_transport(self._retry_config) + self._client: httpx.Client | None = sync_client + self._aclient: httpx.AsyncClient | None = async_client + self._init_lock = threading.Lock() + + def _get_sync_client(self) -> httpx.Client: + if self._client is None: + with self._init_lock: + if self._client is None: + self._client = lazy.httpx.Client( + transport=self._transport, + limits=self._limits, + timeout=lazy.httpx.Timeout(self._timeout_s), + ) + return self._client + + def _get_async_client(self) -> httpx.AsyncClient: + if self._aclient is None: + with self._init_lock: + if self._aclient is None: + self._aclient = lazy.httpx.AsyncClient( + transport=self._transport, + limits=self._limits, + timeout=lazy.httpx.Timeout(self._timeout_s), + ) + return self._aclient + + # ------------------------------------------------------------------- + # Capability checks — adapter-level (see ModelClient docstring) + # ------------------------------------------------------------------- + + def supports_chat_completion(self) -> bool: + return True + + def supports_embeddings(self) -> bool: + return True + + def supports_image_generation(self) -> bool: + return True + + # ------------------------------------------------------------------- + # Chat completion + # ------------------------------------------------------------------- + + def completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + transport = TransportKwargs.from_request(request) + payload = {"model": request.model, "messages": request.messages, **transport.body} + response_json = self._post_sync(self._ROUTE_CHAT, payload, transport.headers, request.model, transport.timeout) + return parse_chat_completion_response(response_json) + + async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + transport = TransportKwargs.from_request(request) + payload = {"model": request.model, "messages": request.messages, **transport.body} + response_json = await self._apost( + self._ROUTE_CHAT, payload, transport.headers, request.model, transport.timeout + ) + return await aparse_chat_completion_response(response_json) + + # ------------------------------------------------------------------- + # Embeddings + # ------------------------------------------------------------------- + + def embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: + transport = TransportKwargs.from_request(request) + payload = {"model": request.model, "input": request.inputs, **transport.body} + response_json = self._post_sync( + self._ROUTE_EMBEDDING, payload, transport.headers, request.model, transport.timeout + ) + return _parse_embedding_json(response_json) + + async def aembeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: + transport = TransportKwargs.from_request(request) + payload = {"model": request.model, "input": request.inputs, **transport.body} + response_json = await self._apost( + self._ROUTE_EMBEDDING, payload, transport.headers, request.model, transport.timeout + ) + return _parse_embedding_json(response_json) + + # ------------------------------------------------------------------- + # Image generation + # ------------------------------------------------------------------- + + def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: + 
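+        # Chat-based image models (request.messages set) are routed through
+        # /chat/completions; plain prompts go to /images/generations.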
transport = TransportKwargs.from_request(request, exclude=self._IMAGE_EXCLUDE) + if request.messages is not None: + route = self._ROUTE_CHAT + payload = {"model": request.model, "messages": request.messages, **transport.body} + else: + route = self._ROUTE_IMAGE + payload = {"model": request.model, "prompt": request.prompt, **transport.body} + response_json = self._post_sync(route, payload, transport.headers, request.model, transport.timeout) + return _parse_image_json(response_json, is_chat_route=request.messages is not None) + + async def agenerate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: + transport = TransportKwargs.from_request(request, exclude=self._IMAGE_EXCLUDE) + if request.messages is not None: + route = self._ROUTE_CHAT + payload = {"model": request.model, "messages": request.messages, **transport.body} + else: + route = self._ROUTE_IMAGE + payload = {"model": request.model, "prompt": request.prompt, **transport.body} + response_json = await self._apost(route, payload, transport.headers, request.model, transport.timeout) + return await _aparse_image_json(response_json, is_chat_route=request.messages is not None) + + # ------------------------------------------------------------------- + # Lifecycle + # ------------------------------------------------------------------- + + def close(self) -> None: + if self._client is not None: + self._client.close() + self._client = None + + async def aclose(self) -> None: + if self._aclient is not None: + await self._aclient.aclose() + self._aclient = None + if self._client is not None: + self._client.close() + self._client = None + + # ------------------------------------------------------------------- + # HTTP helpers + # ------------------------------------------------------------------- + + def _build_headers(self, extra_headers: dict[str, str]) -> dict[str, str]: + headers: dict[str, str] = {"Content-Type": "application/json"} + if self._api_key: + headers["Authorization"] = f"Bearer {self._api_key}" + if extra_headers: + headers.update(extra_headers) + return headers + + def _resolve_timeout(self, per_request: float | None) -> httpx.Timeout: + return lazy.httpx.Timeout(per_request if per_request is not None else self._timeout_s) + + def _post_sync( + self, + route: str, + payload: dict[str, Any], + extra_headers: dict[str, str], + model_name: str, + timeout: float | None = None, + ) -> dict[str, Any]: + url = f"{self._endpoint}{route}" + headers = self._build_headers(extra_headers) + try: + response = self._get_sync_client().post( + url, json=payload, headers=headers, timeout=self._resolve_timeout(timeout) + ) + except Exception as exc: + raise _wrap_transport_error(exc, self.provider_name, model_name) from exc + if response.status_code >= 400: + raise map_http_error_to_provider_error( + response=response, provider_name=self.provider_name, model_name=model_name + ) + return _parse_json_body(response, self.provider_name, model_name) + + async def _apost( + self, + route: str, + payload: dict[str, Any], + extra_headers: dict[str, str], + model_name: str, + timeout: float | None = None, + ) -> dict[str, Any]: + url = f"{self._endpoint}{route}" + headers = self._build_headers(extra_headers) + try: + response = await self._get_async_client().post( + url, json=payload, headers=headers, timeout=self._resolve_timeout(timeout) + ) + except Exception as exc: + raise _wrap_transport_error(exc, self.provider_name, model_name) from exc + if response.status_code >= 400: + raise map_http_error_to_provider_error( + 
response=response, provider_name=self.provider_name, model_name=model_name + ) + return _parse_json_body(response, self.provider_name, model_name) + + +# --------------------------------------------------------------------------- +# Response parsing helpers +# --------------------------------------------------------------------------- + + +def _parse_embedding_json(response_json: dict[str, Any]) -> EmbeddingResponse: + data = response_json.get("data") or [] + vectors = [extract_embedding_vector(item) for item in data] + usage = extract_usage(response_json.get("usage")) + return EmbeddingResponse(vectors=vectors, usage=usage, raw=response_json) + + +def _parse_image_json(response_json: dict[str, Any], *, is_chat_route: bool) -> ImageGenerationResponse: + if is_chat_route: + images = extract_images_from_chat_response(response_json) + else: + images = extract_images_from_image_response(response_json) + usage = extract_usage(response_json.get("usage"), generated_images=len(images)) + return ImageGenerationResponse(images=images, usage=usage, raw=response_json) + + +async def _aparse_image_json(response_json: dict[str, Any], *, is_chat_route: bool) -> ImageGenerationResponse: + if is_chat_route: + images = await aextract_images_from_chat_response(response_json) + else: + images = await aextract_images_from_image_response(response_json) + usage = extract_usage(response_json.get("usage"), generated_images=len(images)) + return ImageGenerationResponse(images=images, usage=usage, raw=response_json) + + +def _parse_json_body(response: httpx.Response, provider_name: str, model_name: str) -> dict[str, Any]: + """Parse JSON from a successful HTTP response, wrapping decode errors as ``ProviderError``.""" + try: + return response.json() + except Exception as exc: + raise ProviderError( + kind=ProviderErrorKind.API_ERROR, + message=f"Provider {provider_name!r} returned a non-JSON response (status {response.status_code}).", + status_code=response.status_code, + provider_name=provider_name, + model_name=model_name, + cause=exc, + ) from exc + + +def _wrap_transport_error(exc: Exception, provider_name: str, model_name: str) -> ProviderError: + """Convert httpx transport exceptions into canonical ``ProviderError``.""" + return ProviderError( + kind=infer_error_kind_from_exception(exc), + message=str(exc) or f"Transport error from provider {provider_name!r}", + provider_name=provider_name, + model_name=model_name, + cause=exc, + ) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/base.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/base.py index d1b5cd23a..cda6fac47 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/base.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/base.py @@ -34,6 +34,15 @@ async def agenerate_image(self, request: ImageGenerationRequest) -> ImageGenerat class ModelClient(ChatCompletionClient, EmbeddingClient, ImageGenerationClient, Protocol): + """Unified protocol for model provider adapters. + + The ``supports_*`` methods indicate whether this **adapter implementation** + is capable of handling a given modality (i.e. it has the code paths and + route mappings to make the call). They do **not** reflect whether a + specific model or endpoint actually supports that modality — that is a + configuration concern handled by ``ModelConfig.generation_type``. + """ + provider_name: str def supports_chat_completion(self) -> bool: ... 
diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py index da3c19383..64d5c5d85 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py @@ -216,6 +216,24 @@ def _parse_http_date_as_delay(value: str) -> float | None: return max(delay, 0.0) +def infer_error_kind_from_exception(exc: Exception) -> ProviderErrorKind: + """Infer a ``ProviderErrorKind`` from an exception's type name. + + Used by adapters to classify transport-level exceptions (timeouts, + connection failures, etc.) that don't carry an HTTP status code. + """ + type_name = type(exc).__name__.lower() + if "timeout" in type_name: + return ProviderErrorKind.TIMEOUT + if "connection" in type_name or "connect" in type_name: + return ProviderErrorKind.API_CONNECTION + if "auth" in type_name: + return ProviderErrorKind.AUTHENTICATION + if "ratelimit" in type_name: + return ProviderErrorKind.RATE_LIMIT + return ProviderErrorKind.API_ERROR + + def _looks_like_context_window_error(text: str) -> bool: return any( token in text diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/factory.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/factory.py index c7e32ebcd..6ca8d2639 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/factory.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/factory.py @@ -3,44 +3,84 @@ from __future__ import annotations +import os + import data_designer.lazy_heavy_imports as lazy from data_designer.config.models import ModelConfig -from data_designer.engine.model_provider import ModelProviderRegistry +from data_designer.engine.model_provider import ModelProvider, ModelProviderRegistry from data_designer.engine.models.clients.adapters.litellm_bridge import LiteLLMBridgeClient +from data_designer.engine.models.clients.adapters.openai_compatible import OpenAICompatibleClient from data_designer.engine.models.clients.base import ModelClient +from data_designer.engine.models.clients.retry import RetryConfig from data_designer.engine.models.litellm_overrides import CustomRouter, LiteLLMRouterDefaultKwargs from data_designer.engine.secret_resolver import SecretResolver +_BACKEND_ENV_VAR = "DATA_DESIGNER_MODEL_BACKEND" +_BACKEND_BRIDGE = "litellm_bridge" + def create_model_client( model_config: ModelConfig, secret_resolver: SecretResolver, model_provider_registry: ModelProviderRegistry, + *, + retry_config: RetryConfig | None = None, ) -> ModelClient: - """Create a ModelClient for the given model configuration. - - Resolves the provider, API key, and constructs a LiteLLM router wrapped in - a LiteLLMBridgeClient adapter. - - Args: - model_config: The model configuration to create a client for. - secret_resolver: Resolver for secrets referenced in provider configs. - model_provider_registry: Registry of model provider configurations. + """Create a ``ModelClient`` for the given model configuration. - Returns: - A ModelClient instance ready for use. + Routing logic: + 1. If ``DATA_DESIGNER_MODEL_BACKEND=litellm_bridge`` → always use bridge. + 2. If ``provider_type == "openai"`` → ``OpenAICompatibleClient``. + 3. Otherwise → ``LiteLLMBridgeClient`` (Anthropic native adapter is PR-4). 
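+
+    For example, a provider with ``provider_type="openai"`` and no backend
+    override yields an ``OpenAICompatibleClient``, while setting
+    ``DATA_DESIGNER_MODEL_BACKEND=litellm_bridge`` forces the bridge for the
+    same configuration.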
""" provider = model_provider_registry.get_provider(model_config.provider) - api_key = None - if provider.api_key: - api_key = secret_resolver.resolve(provider.api_key) - api_key = api_key or "not-used-but-required" + api_key = _resolve_api_key(provider.api_key, secret_resolver) + max_parallel = model_config.inference_parameters.max_parallel_requests + raw_timeout = model_config.inference_parameters.timeout + timeout_s = float(raw_timeout if raw_timeout is not None else 60) + + backend = os.environ.get(_BACKEND_ENV_VAR, "").strip().lower() + if backend == _BACKEND_BRIDGE: + return _create_bridge_client(model_config, provider, api_key, max_parallel) + + if provider.provider_type.lower() == "openai": + return OpenAICompatibleClient( + provider_name=provider.name, + model_id=model_config.model, + endpoint=provider.endpoint, + api_key=api_key, + retry_config=retry_config, + max_parallel_requests=max_parallel, + timeout_s=timeout_s, + ) + + return _create_bridge_client(model_config, provider, api_key, max_parallel) + +# --------------------------------------------------------------------------- +# Private helpers +# --------------------------------------------------------------------------- + + +def _resolve_api_key(api_key_ref: str | None, secret_resolver: SecretResolver) -> str | None: + if not api_key_ref: + return None + resolved = secret_resolver.resolve(api_key_ref) + return resolved or None + + +def _create_bridge_client( + model_config: ModelConfig, + provider: ModelProvider, + api_key: str | None, + max_parallel: int, +) -> LiteLLMBridgeClient: + bridge_key = api_key or "not-used-but-required" litellm_params = lazy.litellm.LiteLLM_Params( model=f"{provider.provider_type}/{model_config.model}", api_base=provider.endpoint, - api_key=api_key, - max_parallel_requests=model_config.inference_parameters.max_parallel_requests, + api_key=bridge_key, + max_parallel_requests=max_parallel, ) deployment = { "model_name": model_config.model, diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py index e5d74d440..fa11f7423 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py @@ -33,32 +33,32 @@ def parse_chat_completion_response(response: Any) -> ChatCompletionResponse: - first_choice = get_first_value_or_none(getattr(response, "choices", None)) + first_choice = get_first_value_or_none(get_value_from(response, "choices")) message = get_value_from(first_choice, "message") tool_calls = extract_tool_calls(get_value_from(message, "tool_calls")) images = extract_images_from_chat_message(message) assistant_message = AssistantMessage( content=coerce_message_content(get_value_from(message, "content")), - reasoning_content=get_value_from(message, "reasoning_content"), + reasoning_content=extract_reasoning_content(message), tool_calls=tool_calls, images=images, ) - usage = extract_usage(getattr(response, "usage", None), generated_images=len(images) if images else None) + usage = extract_usage(get_value_from(response, "usage"), generated_images=len(images) if images else None) return ChatCompletionResponse(message=assistant_message, usage=usage, raw=response) async def aparse_chat_completion_response(response: Any) -> ChatCompletionResponse: - first_choice = get_first_value_or_none(getattr(response, "choices", None)) + first_choice = 
get_first_value_or_none(get_value_from(response, "choices")) message = get_value_from(first_choice, "message") tool_calls = extract_tool_calls(get_value_from(message, "tool_calls")) images = await aextract_images_from_chat_message(message) assistant_message = AssistantMessage( content=coerce_message_content(get_value_from(message, "content")), - reasoning_content=get_value_from(message, "reasoning_content"), + reasoning_content=extract_reasoning_content(message), tool_calls=tool_calls, images=images, ) - usage = extract_usage(getattr(response, "usage", None), generated_images=len(images) if images else None) + usage = extract_usage(get_value_from(response, "usage"), generated_images=len(images) if images else None) return ChatCompletionResponse(message=assistant_message, usage=usage, raw=response) @@ -68,13 +68,13 @@ async def aparse_chat_completion_response(response: Any) -> ChatCompletionRespon def extract_images_from_chat_response(response: Any) -> list[ImagePayload]: - first_choice = get_first_value_or_none(getattr(response, "choices", None)) + first_choice = get_first_value_or_none(get_value_from(response, "choices")) message = get_value_from(first_choice, "message") return extract_images_from_chat_message(message) async def aextract_images_from_chat_response(response: Any) -> list[ImagePayload]: - first_choice = get_first_value_or_none(getattr(response, "choices", None)) + first_choice = get_first_value_or_none(get_value_from(response, "choices")) message = get_value_from(first_choice, "message") return await aextract_images_from_chat_message(message) @@ -92,11 +92,11 @@ async def aextract_images_from_chat_message(message: Any) -> list[ImagePayload]: def extract_images_from_image_response(response: Any) -> list[ImagePayload]: - return parse_image_list(getattr(response, "data", [])) + return parse_image_list(get_value_from(response, "data") or []) async def aextract_images_from_image_response(response: Any) -> list[ImagePayload]: - return await aparse_image_list(getattr(response, "data", [])) + return await aparse_image_list(get_value_from(response, "data") or []) def collect_raw_image_candidates(message: Any) -> tuple[list[Any], list[Any]]: @@ -227,6 +227,27 @@ def serialize_tool_arguments(arguments_value: Any) -> str: return str(arguments_value) +# --------------------------------------------------------------------------- +# Reasoning content extraction +# --------------------------------------------------------------------------- + + +def extract_reasoning_content(message: Any) -> str | None: + """Extract reasoning content from a provider response message. + + vLLM >= 0.16.0 uses ``message.reasoning`` as the canonical field; + ``message.reasoning_content`` is the legacy / LiteLLM-normalized fallback. + Check the canonical field first so reasoning traces survive LiteLLM removal. 
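+
+    For example, ``{"reasoning": "step 1"}`` yields ``"step 1"``, while a
+    message carrying only ``{"reasoning_content": "legacy"}`` falls back to
+    ``"legacy"``; empty strings in either field are treated as absent.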
+ + Ref: https://github.com/NVIDIA-NeMo/DataDesigner/issues/374 + """ + value = get_value_from(message, "reasoning") + if isinstance(value, str) and value: + return value + fallback = get_value_from(message, "reasoning_content") + return fallback if isinstance(fallback, str) and fallback else None + + # --------------------------------------------------------------------------- # Usage & content helpers # --------------------------------------------------------------------------- diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/retry.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/retry.py new file mode 100644 index 000000000..4b7b45898 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/retry.py @@ -0,0 +1,44 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass, field + +from httpx_retries import Retry, RetryTransport + + +@dataclass(frozen=True) +class RetryConfig: + """Retry policy for native HTTP adapters. + + Defaults mirror the current LiteLLM router settings in + ``LiteLLMRouterDefaultKwargs`` so behavior is preserved during migration. + """ + + max_retries: int = 3 + backoff_factor: float = 2.0 + backoff_jitter: float = 0.2 + max_backoff_wait: float = 60.0 + # TODO: Remove 429 from retryable_status_codes once ThrottleManager is + # wired via AsyncTaskScheduler (plan 346), so every rate-limit signal + # reaches AIMD backoff instead of being silently retried at the transport layer. + retryable_status_codes: frozenset[int] = field(default_factory=lambda: frozenset({429, 502, 503, 504})) + + +def create_retry_transport(config: RetryConfig | None = None) -> RetryTransport: + """Build an httpx ``RetryTransport`` from a :class:`RetryConfig`. + + The returned transport handles both sync and async requests (``RetryTransport`` + inherits from ``httpx.BaseTransport`` and ``httpx.AsyncBaseTransport``). + """ + cfg = config or RetryConfig() + retry = Retry( + total=cfg.max_retries, + backoff_factor=cfg.backoff_factor, + backoff_jitter=cfg.backoff_jitter, + max_backoff_wait=cfg.max_backoff_wait, + status_forcelist=cfg.retryable_status_codes, + respect_retry_after_header=True, + ) + return RetryTransport(retry=retry) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/throttle.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/throttle.py new file mode 100644 index 000000000..6324f6dbc --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/throttle.py @@ -0,0 +1,367 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import asyncio +import logging +import math +import threading +import time +from dataclasses import dataclass, field +from enum import Enum + +logger = logging.getLogger(__name__) + + +class ThrottleDomain(str, Enum): + CHAT = "chat" + EMBEDDING = "embedding" + IMAGE = "image" + HEALTHCHECK = "healthcheck" + + +# --------------------------------------------------------------------------- +# AIMD tuning defaults +# --------------------------------------------------------------------------- + +DEFAULT_REDUCE_FACTOR: float = 0.5 +DEFAULT_ADDITIVE_INCREASE: int = 1 +DEFAULT_SUCCESS_WINDOW: int = 50 +DEFAULT_BLOCK_SECONDS: float = 2.0 +DEFAULT_MIN_LIMIT: int = 1 +DEFAULT_ACQUIRE_TIMEOUT: float = 300.0 + + +# --------------------------------------------------------------------------- +# Internal state containers +# --------------------------------------------------------------------------- + + +@dataclass +class DomainThrottleState: + """Per-domain AIMD concurrency state. + + All mutations must be performed while holding the owning + ``ThrottleManager._lock``. + """ + + current_limit: int + in_flight: int = 0 + blocked_until: float = 0.0 + success_streak: int = 0 + + +@dataclass +class GlobalCapState: + """Tracks the effective hard cap across aliases sharing a provider+model.""" + + limits_by_alias: dict[str, int] = field(default_factory=dict) + effective_max: int = 0 + + def register_alias(self, alias: str, max_parallel: int) -> None: + self.limits_by_alias[alias] = max_parallel + self.effective_max = min(self.limits_by_alias.values()) + + +# --------------------------------------------------------------------------- +# ThrottleManager +# --------------------------------------------------------------------------- + + +class ThrottleManager: + """Adaptive concurrency manager using AIMD (Additive Increase / + Multiplicative Decrease). + + Keyed at two levels: + + - **Global cap**: ``(provider_name, model_id)`` — shared hard ceiling. + - **Domain**: ``(provider_name, model_id, throttle_domain)`` — per-route + AIMD state that floats between 1 and the global effective max. + + **AIMD behaviour**: + + - *Decrease* — on a 429 / rate-limit signal the domain's concurrency limit + is multiplied by ``reduce_factor`` (default 0.5, i.e. halved) and a + cooldown block is applied for ``retry_after`` seconds (or + ``default_block_seconds``). + - *Increase* — after every ``success_window`` consecutive successful + releases the limit grows by ``additive_increase`` (default 1), up to the + global effective max. + - *Recovery cost* — after a single rate-limit halves the limit from *L* to + *L/2*, full recovery requires ``(L/2) * success_window / additive_increase`` + successful requests. With defaults (window=50, step=1) and L=32 that is + 800 requests. Raise ``additive_increase`` for faster recovery at the + cost of higher 429 risk. + + Thread-safe: all state mutations are guarded by a single lock so that + sync and async callers co-throttle correctly. 
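+
+    Typical call pattern (illustrative sketch; ``send_request`` and
+    ``RateLimitError`` are hypothetical caller-side names)::
+
+        manager.register(provider_name="p", model_id="m", alias="a", max_parallel_requests=8)
+        wait = manager.try_acquire(provider_name="p", model_id="m", domain=ThrottleDomain.CHAT)
+        if wait == 0.0:  # slot acquired; release exactly once
+            try:
+                send_request()
+            except RateLimitError:
+                manager.release_rate_limited(provider_name="p", model_id="m", domain=ThrottleDomain.CHAT)
+            except Exception:
+                manager.release_failure(provider_name="p", model_id="m", domain=ThrottleDomain.CHAT)
+            else:
+                manager.release_success(provider_name="p", model_id="m", domain=ThrottleDomain.CHAT)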
+ """ + + def __init__( + self, + *, + reduce_factor: float = DEFAULT_REDUCE_FACTOR, + additive_increase: int = DEFAULT_ADDITIVE_INCREASE, + success_window: int = DEFAULT_SUCCESS_WINDOW, + default_block_seconds: float = DEFAULT_BLOCK_SECONDS, + ) -> None: + self._reduce_factor = reduce_factor + self._additive_increase = additive_increase + self._success_window = success_window + self._default_block_seconds = default_block_seconds + self._lock = threading.Lock() + self._global_caps: dict[tuple[str, str], GlobalCapState] = {} + self._domains: dict[tuple[str, str, str], DomainThrottleState] = {} + + # ------------------------------------------------------------------- + # Registration + # ------------------------------------------------------------------- + + def register( + self, + *, + provider_name: str, + model_id: str, + alias: str, + max_parallel_requests: int, + ) -> None: + """Register a model alias and its concurrency limit. + + If multiple aliases share the same ``(provider_name, model_id)`` the + effective max is ``min()`` of all registered limits. Existing domain + states are clamped to the new effective max. + + **Ordering invariant:** ``register()`` must be called for a + ``(provider_name, model_id)`` pair *before* any ``try_acquire()`` for + the same key. If ``try_acquire()`` runs first it creates a domain at + ``DEFAULT_MIN_LIMIT`` and ``_clamp_domains`` only *decreases* limits, + so a later ``register()`` will not raise the domain to the intended + capacity. + """ + with self._lock: + global_key = (provider_name, model_id) + cap = self._global_caps.setdefault(global_key, GlobalCapState()) + cap.register_alias(alias, max_parallel_requests) + self._clamp_domains(global_key, cap.effective_max) + logger.debug( + "Throttle registered alias=%r for %s/%s (max_parallel=%d, effective_max=%d)", + alias, + provider_name, + model_id, + max_parallel_requests, + cap.effective_max, + ) + + # ------------------------------------------------------------------- + # Core non-blocking primitives + # ------------------------------------------------------------------- + + def try_acquire( + self, + *, + provider_name: str, + model_id: str, + domain: ThrottleDomain, + now: float | None = None, + ) -> float: + """Attempt to acquire a concurrency slot. + + Returns ``0.0`` if the slot was acquired, otherwise the number of + seconds the caller should wait before retrying. 
+ """ + now = now if now is not None else time.monotonic() + with self._lock: + state = self._get_or_create_domain(provider_name, model_id, domain) + if now < state.blocked_until: + wait = state.blocked_until - now + logger.debug( + "Throttle %s/%s [%s] blocked for %.1fs (cooldown)", + provider_name, + model_id, + domain.value, + wait, + ) + return wait + if state.in_flight >= state.current_limit: + logger.debug( + "Throttle %s/%s [%s] at capacity (%d/%d), backing off %.1fs", + provider_name, + model_id, + domain.value, + state.in_flight, + state.current_limit, + self._default_block_seconds, + ) + return self._default_block_seconds + state.in_flight += 1 + return 0.0 + + def release_success( + self, + *, + provider_name: str, + model_id: str, + domain: ThrottleDomain, + now: float | None = None, + ) -> None: + with self._lock: + state = self._get_or_create_domain(provider_name, model_id, domain) + state.in_flight = max(0, state.in_flight - 1) + state.success_streak += 1 + if state.success_streak >= self._success_window: + effective_max = self._effective_max_for(provider_name, model_id) + if state.current_limit < effective_max: + prev = state.current_limit + state.current_limit = min(state.current_limit + self._additive_increase, effective_max) + if state.current_limit >= effective_max: + logger.info( + "🟢 Throttle %s/%s [%s] recovered to full capacity (%d)", + provider_name, + model_id, + domain.value, + state.current_limit, + ) + else: + logger.debug( + "Throttle %s/%s [%s] AIMD increase: limit %d -> %d (max %d)", + provider_name, + model_id, + domain.value, + prev, + state.current_limit, + effective_max, + ) + state.success_streak = 0 + + def release_rate_limited( + self, + *, + provider_name: str, + model_id: str, + domain: ThrottleDomain, + retry_after: float | None = None, + now: float | None = None, + ) -> None: + now = now if now is not None else time.monotonic() + with self._lock: + state = self._get_or_create_domain(provider_name, model_id, domain) + state.in_flight = max(0, state.in_flight - 1) + prev_limit = state.current_limit + state.current_limit = max(DEFAULT_MIN_LIMIT, math.floor(state.current_limit * self._reduce_factor)) + block_duration = retry_after if retry_after is not None and retry_after > 0 else self._default_block_seconds + state.blocked_until = now + block_duration + state.success_streak = 0 + logger.warning( + "🚦 Throttle %s/%s [%s] rate-limited: limit %d -> %d, blocked for %.1fs%s", + provider_name, + model_id, + domain.value, + prev_limit, + state.current_limit, + block_duration, + f" (retry-after={retry_after:.1f}s)" if retry_after else "", + ) + + def release_failure( + self, + *, + provider_name: str, + model_id: str, + domain: ThrottleDomain, + now: float | None = None, + ) -> None: + with self._lock: + state = self._get_or_create_domain(provider_name, model_id, domain) + state.in_flight = max(0, state.in_flight - 1) + + # ------------------------------------------------------------------- + # Sync / async wrappers + # ------------------------------------------------------------------- + + def acquire_sync( + self, + *, + provider_name: str, + model_id: str, + domain: ThrottleDomain, + timeout: float = DEFAULT_ACQUIRE_TIMEOUT, + ) -> None: + deadline = time.monotonic() + timeout + while True: + wait = self.try_acquire(provider_name=provider_name, model_id=model_id, domain=domain) + if wait == 0.0: + return + remaining = deadline - time.monotonic() + if remaining <= 0 or wait > remaining: + raise TimeoutError( + f"Throttle acquire timed out after 
{timeout:.0f}s for {provider_name}/{model_id} [{domain.value}]" + ) + time.sleep(min(wait, remaining)) + + async def acquire_async( + self, + *, + provider_name: str, + model_id: str, + domain: ThrottleDomain, + timeout: float = DEFAULT_ACQUIRE_TIMEOUT, + ) -> None: + deadline = time.monotonic() + timeout + while True: + wait = self.try_acquire(provider_name=provider_name, model_id=model_id, domain=domain) + if wait == 0.0: + return + remaining = deadline - time.monotonic() + if remaining <= 0 or wait > remaining: + raise TimeoutError( + f"Throttle acquire timed out after {timeout:.0f}s for {provider_name}/{model_id} [{domain.value}]" + ) + await asyncio.sleep(min(wait, remaining)) + + # ------------------------------------------------------------------- + # Introspection (useful for tests and telemetry) + # ------------------------------------------------------------------- + + def get_domain_state( + self, + provider_name: str, + model_id: str, + domain: ThrottleDomain, + ) -> DomainThrottleState | None: + with self._lock: + return self._domains.get((provider_name, model_id, domain.value)) + + def get_effective_max(self, provider_name: str, model_id: str) -> int: + with self._lock: + return self._effective_max_for(provider_name, model_id) + + # ------------------------------------------------------------------- + # Private helpers + # ------------------------------------------------------------------- + + def _get_or_create_domain( + self, + provider_name: str, + model_id: str, + domain: ThrottleDomain, + ) -> DomainThrottleState: + key = (provider_name, model_id, domain.value) + state = self._domains.get(key) + if state is None: + effective_max = self._effective_max_for(provider_name, model_id) + state = DomainThrottleState(current_limit=effective_max) + self._domains[key] = state + return state + + def _effective_max_for(self, provider_name: str, model_id: str) -> int: + cap = self._global_caps.get((provider_name, model_id)) + if cap is None or cap.effective_max <= 0: + return DEFAULT_MIN_LIMIT + return cap.effective_max + + def _clamp_domains(self, global_key: tuple[str, str], effective_max: int) -> None: + provider_name, model_id = global_key + for (pn, mid, _dom), state in self._domains.items(): + if pn == provider_name and mid == model_id and state.current_limit > effective_max: + state.current_limit = effective_max diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py index 034170b5d..3100e9b77 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py @@ -124,10 +124,11 @@ class TransportKwargs: - ``headers``: Extra HTTP headers to attach to the outgoing request. """ - _META_FIELDS: ClassVar[frozenset[str]] = frozenset({"extra_body", "extra_headers"}) + _META_FIELDS: ClassVar[frozenset[str]] = frozenset({"extra_body", "extra_headers", "timeout"}) body: dict[str, Any] headers: dict[str, str] + timeout: float | None = None @classmethod def from_request( @@ -146,11 +147,14 @@ def from_request( - ``False``: preserves it as ``extra_body`` in the body dict so that callers like LiteLLM can forward it without param validation. 3. Pops ``extra_headers`` into a separate headers dict. + 4. Extracts ``timeout`` as a per-request HTTP timeout override + (not forwarded to the API body). 
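+
+        For example (illustrative values), a request carrying
+        ``temperature=0.7``, ``extra_body={"seed": 42}``,
+        ``extra_headers={"X-Trace": "1"}``, and ``timeout=120.0`` flattens
+        (with the default ``flatten_extra_body``) to
+        ``body={"temperature": 0.7, "seed": 42}``,
+        ``headers={"X-Trace": "1"}``, ``timeout=120.0``.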
""" optional_fields = cls._collect_optional_fields(request, exclude=exclude | cls._META_FIELDS) extra_body = getattr(request, "extra_body", None) or {} extra_headers = getattr(request, "extra_headers", None) or {} + timeout = getattr(request, "timeout", None) if flatten_extra_body: body = {**optional_fields, **extra_body} @@ -159,7 +163,7 @@ def from_request( if extra_body: body["extra_body"] = extra_body - return cls(body=body, headers=dict(extra_headers)) + return cls(body=body, headers=dict(extra_headers), timeout=timeout) @staticmethod def _collect_optional_fields(request: Any, *, exclude: frozenset[str] = frozenset()) -> dict[str, Any]: diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py b/packages/data-designer-engine/src/data_designer/engine/models/facade.py index 8df7beffe..ec1fd6b7a 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py @@ -704,7 +704,7 @@ def _build_chat_completion_request( if metadata: logger.debug( - "Unknown kwargs %s routed to LiteLLM metadata (not forwarded as model parameters). " + "Unknown kwargs %s dropped (not forwarded as model parameters). " "Use 'extra_body' to pass non-standard parameters to the model.", metadata.keys(), ) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/factory.py b/packages/data-designer-engine/src/data_designer/engine/models/factory.py index a23c0dbcc..4fde41986 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/factory.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/factory.py @@ -38,6 +38,8 @@ def create_model_registry( A configured ModelRegistry instance. """ from data_designer.engine.models.clients.factory import create_model_client + from data_designer.engine.models.clients.retry import RetryConfig + from data_designer.engine.models.clients.throttle import ThrottleManager from data_designer.engine.models.facade import ModelFacade from data_designer.engine.models.litellm_overrides import apply_litellm_patches from data_designer.engine.models.registry import ModelRegistry @@ -48,8 +50,14 @@ def model_facade_factory( model_config: ModelConfig, secret_resolver: SecretResolver, model_provider_registry: ModelProviderRegistry, + retry_config: RetryConfig | None, ) -> ModelFacade: - client = create_model_client(model_config, secret_resolver, model_provider_registry) + client = create_model_client( + model_config, + secret_resolver, + model_provider_registry, + retry_config=retry_config, + ) return ModelFacade( model_config, model_provider_registry, @@ -62,4 +70,7 @@ def model_facade_factory( secret_resolver=secret_resolver, model_provider_registry=model_provider_registry, model_facade_factory=model_facade_factory, + # Throttle acquire/release is wired in a follow-up PR (AsyncTaskScheduler integration). 
+ throttle_manager=ThrottleManager(), + retry_config=RetryConfig(), ) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/registry.py b/packages/data-designer-engine/src/data_designer/engine/models/registry.py index dbb2d613d..1335fec6b 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/registry.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/registry.py @@ -4,7 +4,6 @@ from __future__ import annotations import logging -from collections.abc import Callable from typing import TYPE_CHECKING, Any from data_designer.config.models import GenerationType, ModelConfig @@ -14,8 +13,17 @@ from data_designer.logging import LOG_INDENT if TYPE_CHECKING: + from collections.abc import Callable + + from data_designer.engine.models.clients.retry import RetryConfig + from data_designer.engine.models.clients.throttle import ThrottleManager from data_designer.engine.models.facade import ModelFacade + ModelFacadeFactory = Callable[ + [ModelConfig, SecretResolver, ModelProviderRegistry, RetryConfig | None], + ModelFacade, + ] + logger = logging.getLogger(__name__) @@ -26,11 +34,15 @@ def __init__( secret_resolver: SecretResolver, model_provider_registry: ModelProviderRegistry, model_configs: list[ModelConfig] | None = None, - model_facade_factory: Callable[[ModelConfig, SecretResolver, ModelProviderRegistry], ModelFacade] | None = None, + model_facade_factory: ModelFacadeFactory | None = None, + throttle_manager: ThrottleManager | None = None, + retry_config: RetryConfig | None = None, ) -> None: self._secret_resolver = secret_resolver self._model_provider_registry = model_provider_registry self._model_facade_factory = model_facade_factory + self._throttle_manager = throttle_manager + self._retry_config = retry_config self._model_configs: dict[str, ModelConfig] = {} self._models: dict[str, ModelFacade] = {} self._set_model_configs(model_configs) @@ -43,6 +55,14 @@ def model_configs(self) -> dict[str, ModelConfig]: def models(self) -> dict[str, ModelFacade]: return self._models + @property + def throttle_manager(self) -> ThrottleManager | None: + return self._throttle_manager + + @property + def retry_config(self) -> RetryConfig | None: + return self._retry_config + def register_model_configs(self, model_configs: list[ModelConfig]) -> None: """Register a new Model configuration at runtime. @@ -200,9 +220,6 @@ def run_health_check(self, model_aliases: list[str]) -> None: logger.error(f"{LOG_INDENT}❌ Failed!") raise e - def _set_model_configs(self, model_configs: list[ModelConfig] | None) -> None: - self._model_configs = {mc.alias: mc for mc in (model_configs or [])} - def close(self) -> None: """Release resources held by all model facades. 
@@ -227,7 +244,23 @@ async def aclose(self) -> None: except Exception: logger.exception("Error closing facade for %s", facade.model_alias) + def _set_model_configs(self, model_configs: list[ModelConfig] | None) -> None: + self._model_configs = {mc.alias: mc for mc in (model_configs or [])} + def _get_model(self, model_config: ModelConfig) -> ModelFacade: if self._model_facade_factory is None: raise RuntimeError("ModelRegistry was not initialized with a model_facade_factory") - return self._model_facade_factory(model_config, self._secret_resolver, self._model_provider_registry) + facade = self._model_facade_factory( + model_config, + self._secret_resolver, + self._model_provider_registry, + self._retry_config, + ) + if self._throttle_manager is not None: + self._throttle_manager.register( + provider_name=facade.model_provider_name, + model_id=model_config.model, + alias=model_config.alias, + max_parallel_requests=model_config.inference_parameters.max_parallel_requests, + ) + return facade diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_factory.py b/packages/data-designer-engine/tests/engine/models/clients/test_factory.py new file mode 100644 index 000000000..d2b99fb01 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/models/clients/test_factory.py @@ -0,0 +1,133 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from data_designer.config.models import ( + ChatCompletionInferenceParams, + ModelConfig, + ModelProvider, +) +from data_designer.engine.model_provider import ModelProviderRegistry +from data_designer.engine.models.clients.adapters.litellm_bridge import LiteLLMBridgeClient +from data_designer.engine.models.clients.adapters.openai_compatible import OpenAICompatibleClient +from data_designer.engine.models.clients.factory import create_model_client +from data_designer.engine.models.clients.retry import RetryConfig +from data_designer.engine.secret_resolver import SecretResolver + + +@pytest.fixture +def openai_registry() -> ModelProviderRegistry: + provider = ModelProvider(name="openai-prod", endpoint="https://api.openai.com/v1", provider_type="openai") + return ModelProviderRegistry(providers=[provider]) + + +@pytest.fixture +def anthropic_registry() -> ModelProviderRegistry: + provider = ModelProvider(name="anthropic-prod", endpoint="https://api.anthropic.com", provider_type="anthropic") + return ModelProviderRegistry(providers=[provider]) + + +@pytest.fixture +def openai_model_config() -> ModelConfig: + return ModelConfig( + alias="test-model", + model="gpt-test", + inference_parameters=ChatCompletionInferenceParams(), + provider="openai-prod", + ) + + +@pytest.fixture +def anthropic_model_config() -> ModelConfig: + return ModelConfig( + alias="test-anthropic", + model="claude-test", + inference_parameters=ChatCompletionInferenceParams(), + provider="anthropic-prod", + ) + + +@pytest.fixture +def secret_resolver() -> SecretResolver: + resolver = MagicMock(spec=SecretResolver) + resolver.resolve.return_value = "resolved-key" + return resolver + + +# --- Provider routing --- + + +def test_openai_provider_creates_native_client( + openai_model_config: ModelConfig, + secret_resolver: SecretResolver, + openai_registry: ModelProviderRegistry, +) -> None: + client = create_model_client( + openai_model_config, + secret_resolver, + openai_registry, + 
retry_config=RetryConfig(), + ) + assert isinstance(client, OpenAICompatibleClient) + + +@patch("data_designer.engine.models.clients.factory.CustomRouter") +@patch("data_designer.engine.models.clients.factory.LiteLLMRouterDefaultKwargs") +def test_non_openai_provider_creates_bridge_client( + mock_kwargs: MagicMock, + mock_router: MagicMock, + anthropic_model_config: ModelConfig, + secret_resolver: SecretResolver, + anthropic_registry: ModelProviderRegistry, +) -> None: + mock_kwargs.return_value.model_dump.return_value = {} + client = create_model_client(anthropic_model_config, secret_resolver, anthropic_registry) + assert isinstance(client, LiteLLMBridgeClient) + + +# --- Backend env var override --- + + +@patch("data_designer.engine.models.clients.factory.CustomRouter") +@patch("data_designer.engine.models.clients.factory.LiteLLMRouterDefaultKwargs") +def test_bridge_env_override_forces_bridge_for_openai_provider( + mock_kwargs: MagicMock, + mock_router: MagicMock, + openai_model_config: ModelConfig, + secret_resolver: SecretResolver, + openai_registry: ModelProviderRegistry, +) -> None: + mock_kwargs.return_value.model_dump.return_value = {} + with patch.dict("os.environ", {"DATA_DESIGNER_MODEL_BACKEND": "litellm_bridge"}): + client = create_model_client(openai_model_config, secret_resolver, openai_registry) + assert isinstance(client, LiteLLMBridgeClient) + + +def test_openai_provider_type_case_insensitive( + openai_model_config: ModelConfig, + secret_resolver: SecretResolver, +) -> None: + for variant in ("OpenAI", "OPENAI", "Openai"): + provider = ModelProvider(name="openai-prod", endpoint="https://api.openai.com/v1", provider_type=variant) + registry = ModelProviderRegistry(providers=[provider]) + client = create_model_client(openai_model_config, secret_resolver, registry) + assert isinstance(client, OpenAICompatibleClient), f"Failed for provider_type={variant!r}" + + +def test_native_env_var_still_uses_native_for_openai_provider( + openai_model_config: ModelConfig, + secret_resolver: SecretResolver, + openai_registry: ModelProviderRegistry, +) -> None: + with patch.dict("os.environ", {"DATA_DESIGNER_MODEL_BACKEND": "native"}): + client = create_model_client( + openai_model_config, + secret_resolver, + openai_registry, + ) + assert isinstance(client, OpenAICompatibleClient) diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py b/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py new file mode 100644 index 000000000..392130ac1 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py @@ -0,0 +1,452 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import json +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from data_designer.engine.models.clients.adapters.openai_compatible import OpenAICompatibleClient +from data_designer.engine.models.clients.errors import ProviderError, ProviderErrorKind +from data_designer.engine.models.clients.types import ( + ChatCompletionRequest, + EmbeddingRequest, + ImageGenerationRequest, +) + +PROVIDER = "test-provider" +MODEL = "gpt-test" +ENDPOINT = "https://api.example.com/v1" + + +def _mock_httpx_response(json_data: dict[str, Any], status_code: int = 200) -> MagicMock: + resp = MagicMock() + resp.status_code = status_code + resp.json.return_value = json_data + resp.text = json.dumps(json_data) + resp.headers = {} + return resp + + +def _make_sync_client(response_json: dict[str, Any], status_code: int = 200) -> MagicMock: + mock = MagicMock() + mock.post = MagicMock(return_value=_mock_httpx_response(response_json, status_code)) + return mock + + +def _make_async_client(response_json: dict[str, Any], status_code: int = 200) -> MagicMock: + mock = MagicMock() + mock.post = AsyncMock(return_value=_mock_httpx_response(response_json, status_code)) + return mock + + +def _make_client( + *, + sync_client: MagicMock | None = None, + async_client: MagicMock | None = None, + api_key: str | None = "sk-test-key", +) -> OpenAICompatibleClient: + return OpenAICompatibleClient( + provider_name=PROVIDER, + model_id=MODEL, + endpoint=ENDPOINT, + api_key=api_key, + sync_client=sync_client, + async_client=async_client, + ) + + +# --- Response helpers --- + + +def _chat_response( + content: str = "Hello!", + reasoning: str | None = None, + tool_calls: list[dict[str, Any]] | None = None, +) -> dict[str, Any]: + message: dict[str, Any] = {"role": "assistant", "content": content} + if reasoning is not None: + message["reasoning"] = reasoning + if tool_calls is not None: + message["tool_calls"] = tool_calls + return { + "choices": [{"index": 0, "message": message, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, + } + + +def _embedding_response() -> dict[str, Any]: + return { + "data": [{"embedding": [0.1, 0.2, 0.3], "index": 0}], + "usage": {"prompt_tokens": 5, "total_tokens": 5}, + } + + +def _image_response() -> dict[str, Any]: + return {"data": [{"b64_json": "aW1hZ2VkYXRh"}]} + + +# --- Chat completion --- + + +def test_completion_maps_canonical_fields() -> None: + response_json = _chat_response(content="Hello!", reasoning="step-by-step") + client = _make_client(sync_client=_make_sync_client(response_json)) + + request = ChatCompletionRequest(model=MODEL, messages=[{"role": "user", "content": "Hi"}]) + result = client.completion(request) + + assert result.message.content == "Hello!" 
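+    # the canonical "reasoning" field surfaces as reasoning_content on the parsed message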
+ assert result.message.reasoning_content == "step-by-step" + assert result.usage is not None + assert result.usage.input_tokens == 10 + assert result.usage.output_tokens == 5 + + +def test_completion_with_tool_calls() -> None: + tool_calls = [{"id": "tc1", "type": "function", "function": {"name": "search", "arguments": '{"q": "x"}'}}] + client = _make_client(sync_client=_make_sync_client(_chat_response(tool_calls=tool_calls))) + + request = ChatCompletionRequest(model=MODEL, messages=[{"role": "user", "content": "Search"}]) + result = client.completion(request) + + assert len(result.message.tool_calls) == 1 + assert result.message.tool_calls[0].name == "search" + assert result.message.tool_calls[0].arguments_json == '{"q": "x"}' + + +def test_completion_posts_to_chat_completions_route() -> None: + sync_mock = _make_sync_client(_chat_response()) + client = _make_client(sync_client=sync_mock) + + request = ChatCompletionRequest( + model=MODEL, + messages=[{"role": "user", "content": "Hi"}], + temperature=0.7, + extra_body={"seed": 42}, + extra_headers={"X-Trace": "1"}, + ) + client.completion(request) + + call_args = sync_mock.post.call_args + assert "/chat/completions" in call_args.args[0] + payload = call_args.kwargs["json"] + assert payload["model"] == MODEL + assert payload["temperature"] == 0.7 + assert payload["seed"] == 42 + assert "timeout" not in payload + assert call_args.kwargs["headers"]["X-Trace"] == "1" + + +def test_timeout_excluded_from_body_and_used_as_http_timeout() -> None: + sync_mock = _make_sync_client(_chat_response()) + client = _make_client(sync_client=sync_mock) + + request = ChatCompletionRequest( + model=MODEL, + messages=[{"role": "user", "content": "Hi"}], + timeout=120.0, + ) + client.completion(request) + + call_args = sync_mock.post.call_args + payload = call_args.kwargs["json"] + assert "timeout" not in payload + http_timeout = call_args.kwargs["timeout"] + assert http_timeout.connect == 120.0 + assert http_timeout.read == 120.0 + + +def test_default_timeout_used_when_request_timeout_is_none() -> None: + sync_mock = _make_sync_client(_chat_response()) + client = OpenAICompatibleClient( + provider_name=PROVIDER, + model_id=MODEL, + endpoint=ENDPOINT, + timeout_s=45.0, + sync_client=sync_mock, + ) + + request = ChatCompletionRequest(model=MODEL, messages=[{"role": "user", "content": "Hi"}]) + client.completion(request) + + call_args = sync_mock.post.call_args + http_timeout = call_args.kwargs["timeout"] + assert http_timeout.connect == 45.0 + + +@pytest.mark.asyncio +async def test_acompletion_maps_canonical_fields() -> None: + client = _make_client(async_client=_make_async_client(_chat_response(content="async result"))) + + request = ChatCompletionRequest(model=MODEL, messages=[{"role": "user", "content": "Hi"}]) + result = await client.acompletion(request) + + assert result.message.content == "async result" + + +# --- Embeddings --- + + +def test_embeddings_maps_vectors_and_usage() -> None: + client = _make_client(sync_client=_make_sync_client(_embedding_response())) + + request = EmbeddingRequest(model=MODEL, inputs=["hello world"]) + result = client.embeddings(request) + + assert result.vectors == [[0.1, 0.2, 0.3]] + assert result.usage is not None + assert result.usage.input_tokens == 5 + + +def test_embeddings_posts_to_embeddings_route() -> None: + sync_mock = _make_sync_client(_embedding_response()) + client = _make_client(sync_client=sync_mock) + + request = EmbeddingRequest(model=MODEL, inputs=["hello"]) + client.embeddings(request) + + call_url = 
sync_mock.post.call_args.args[0] + assert "/embeddings" in call_url + + +@pytest.mark.asyncio +async def test_aembeddings_maps_vectors() -> None: + client = _make_client(async_client=_make_async_client(_embedding_response())) + + request = EmbeddingRequest(model=MODEL, inputs=["hello"]) + result = await client.aembeddings(request) + + assert len(result.vectors) == 1 + + +# --- Image generation --- + + +def test_generate_image_diffusion_route() -> None: + sync_mock = _make_sync_client(_image_response()) + client = _make_client(sync_client=sync_mock) + + request = ImageGenerationRequest(model=MODEL, prompt="a sunset") + result = client.generate_image(request) + + assert len(result.images) == 1 + assert result.images[0].b64_data == "aW1hZ2VkYXRh" + call_url = sync_mock.post.call_args.args[0] + assert "/images/generations" in call_url + + +def test_generate_image_chat_route_when_messages_present() -> None: + chat_img_response = { + "choices": [{"message": {"content": None, "images": [{"b64_json": "Y2hhdGltZw=="}]}}], + "usage": {"prompt_tokens": 5, "completion_tokens": 0, "total_tokens": 5}, + } + sync_mock = _make_sync_client(chat_img_response) + client = _make_client(sync_client=sync_mock) + + request = ImageGenerationRequest( + model=MODEL, + prompt="a sunset", + messages=[{"role": "user", "content": "draw a sunset"}], + ) + result = client.generate_image(request) + + assert len(result.images) == 1 + call_url = sync_mock.post.call_args.args[0] + assert "/chat/completions" in call_url + + +@pytest.mark.asyncio +async def test_agenerate_image_maps_images() -> None: + client = _make_client(async_client=_make_async_client(_image_response())) + + request = ImageGenerationRequest(model=MODEL, prompt="a cat") + result = await client.agenerate_image(request) + + assert len(result.images) == 1 + + +# --- Auth headers --- + + +def test_auth_header_present_when_api_key_set() -> None: + sync_mock = _make_sync_client(_chat_response()) + client = _make_client(sync_client=sync_mock) + + request = ChatCompletionRequest(model=MODEL, messages=[{"role": "user", "content": "Hi"}]) + client.completion(request) + + headers = sync_mock.post.call_args.kwargs["headers"] + assert headers["Authorization"] == "Bearer sk-test-key" + assert headers["Content-Type"] == "application/json" + + +def test_no_auth_header_when_api_key_none() -> None: + sync_mock = _make_sync_client(_chat_response()) + client = _make_client(sync_client=sync_mock, api_key=None) + + request = ChatCompletionRequest(model=MODEL, messages=[{"role": "user", "content": "Hi"}]) + client.completion(request) + + headers = sync_mock.post.call_args.kwargs["headers"] + assert "Authorization" not in headers + + +def test_extra_headers_merged_into_auth_headers() -> None: + sync_mock = _make_sync_client(_chat_response()) + client = _make_client(sync_client=sync_mock) + + request = ChatCompletionRequest( + model=MODEL, + messages=[{"role": "user", "content": "Hi"}], + extra_headers={"X-Custom": "val"}, + ) + client.completion(request) + + headers = sync_mock.post.call_args.kwargs["headers"] + assert headers["X-Custom"] == "val" + assert headers["Authorization"] == "Bearer sk-test-key" + + +# --- Error mapping --- + + +@pytest.mark.parametrize( + "status_code,expected_kind", + [ + (429, ProviderErrorKind.RATE_LIMIT), + (401, ProviderErrorKind.AUTHENTICATION), + (403, ProviderErrorKind.PERMISSION_DENIED), + (404, ProviderErrorKind.NOT_FOUND), + (500, ProviderErrorKind.INTERNAL_SERVER), + ], + ids=["rate-limit", "auth", "permission", "not-found", "server-error"], 
+) +def test_http_error_maps_to_provider_error( + status_code: int, + expected_kind: ProviderErrorKind, +) -> None: + client = _make_client(sync_client=_make_sync_client({"error": {"message": "fail"}}, status_code=status_code)) + + request = ChatCompletionRequest(model=MODEL, messages=[{"role": "user", "content": "Hi"}]) + with pytest.raises(ProviderError) as exc_info: + client.completion(request) + + assert exc_info.value.kind == expected_kind + + +def test_transport_timeout_raises_timeout_error() -> None: + sync_mock = MagicMock() + sync_mock.post = MagicMock(side_effect=TimeoutError("timed out")) + client = _make_client(sync_client=sync_mock) + + request = ChatCompletionRequest(model=MODEL, messages=[{"role": "user", "content": "Hi"}]) + with pytest.raises(ProviderError) as exc_info: + client.completion(request) + + assert exc_info.value.kind == ProviderErrorKind.TIMEOUT + + +def test_transport_connection_error_raises_connection_error() -> None: + sync_mock = MagicMock() + sync_mock.post = MagicMock(side_effect=ConnectionError("refused")) + client = _make_client(sync_client=sync_mock) + + request = ChatCompletionRequest(model=MODEL, messages=[{"role": "user", "content": "Hi"}]) + with pytest.raises(ProviderError) as exc_info: + client.completion(request) + + assert exc_info.value.kind == ProviderErrorKind.API_CONNECTION + + +def test_non_json_response_raises_provider_error() -> None: + resp = MagicMock() + resp.status_code = 200 + resp.json.side_effect = ValueError("not json") + sync_mock = MagicMock() + sync_mock.post = MagicMock(return_value=resp) + client = _make_client(sync_client=sync_mock) + + request = ChatCompletionRequest(model=MODEL, messages=[{"role": "user", "content": "Hi"}]) + with pytest.raises(ProviderError) as exc_info: + client.completion(request) + + assert exc_info.value.kind == ProviderErrorKind.API_ERROR + assert "non-JSON" in exc_info.value.message + + +# --- Lifecycle --- + + +def test_close_delegates_to_httpx_client() -> None: + sync_mock = MagicMock() + client = _make_client(sync_client=sync_mock) + client.close() + sync_mock.close.assert_called_once() + + +@pytest.mark.asyncio +async def test_aclose_closes_both_clients() -> None: + sync_mock = MagicMock() + async_mock = MagicMock() + async_mock.aclose = AsyncMock() + client = _make_client(sync_client=sync_mock, async_client=async_mock) + + await client.aclose() + + async_mock.aclose.assert_awaited_once() + sync_mock.close.assert_called_once() + + +def test_close_noop_when_no_client_created() -> None: + client = _make_client() + client.close() # should not raise + + +@pytest.mark.asyncio +async def test_aclose_noop_when_no_client_created() -> None: + client = _make_client() + await client.aclose() # should not raise + + +# --- Lazy client initialization --- + + +def test_lazy_sync_client_creates_real_httpx_client() -> None: + """Exercise the lazy-init path that production code uses (no injected mock).""" + client = _make_client() + assert client._client is None + + sync_client = client._get_sync_client() + + assert sync_client is not None + assert client._client is sync_client + # Second call returns the same instance (double-check locking). 
+ assert client._get_sync_client() is sync_client + client.close() + + +def test_lazy_async_client_creates_real_httpx_async_client() -> None: + client = _make_client() + assert client._aclient is None + + async_client = client._get_async_client() + + assert async_client is not None + assert client._aclient is async_client + assert client._get_async_client() is async_client + + +# --- Capabilities --- + + +@pytest.mark.parametrize( + "method", + ["supports_chat_completion", "supports_embeddings", "supports_image_generation"], +) +def test_capability_checks_return_true(method: str) -> None: + client = _make_client() + assert getattr(client, method)() is True diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_parsing.py b/packages/data-designer-engine/tests/engine/models/clients/test_parsing.py index e6662e3f2..76ff2b2a3 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_parsing.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_parsing.py @@ -5,7 +5,7 @@ import pytest -from data_designer.engine.models.clients.parsing import extract_tool_calls +from data_designer.engine.models.clients.parsing import extract_reasoning_content, extract_tool_calls from data_designer.engine.models.clients.types import ( ChatCompletionRequest, EmbeddingRequest, @@ -236,3 +236,43 @@ def test_extract_tool_calls_none_arguments() -> None: result = extract_tool_calls(raw) assert result[0].arguments_json == "{}" + + +# --- extract_reasoning_content (vLLM field migration) --- + + +@pytest.mark.parametrize( + "message,expected", + [ + ({"reasoning": "step-by-step thinking"}, "step-by-step thinking"), + ({"reasoning_content": "legacy thinking"}, "legacy thinking"), + ({"reasoning": "canonical", "reasoning_content": "legacy"}, "canonical"), + ({"content": "hello"}, None), + (None, None), + ({"reasoning": "", "reasoning_content": "fallback"}, "fallback"), + ({"reasoning_content": {"nested": "dict"}}, None), + ({"reasoning_content": ["list", "value"]}, None), + ({"reasoning_content": ""}, None), + ], + ids=[ + "only-reasoning", + "only-reasoning_content", + "both-reasoning-takes-precedence", + "neither-field", + "none-message", + "empty-reasoning-falls-back", + "non-string-dict-fallback-returns-none", + "non-string-list-fallback-returns-none", + "empty-string-fallback-returns-none", + ], +) +def test_extract_reasoning_content(message: dict | None, expected: str | None) -> None: + assert extract_reasoning_content(message) == expected + + +def test_extract_reasoning_content_works_with_object_style_message() -> None: + class Msg: + reasoning = "from object" + reasoning_content = "legacy object" + + assert extract_reasoning_content(Msg()) == "from object" diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_retry.py b/packages/data-designer-engine/tests/engine/models/clients/test_retry.py new file mode 100644 index 000000000..f51409932 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/models/clients/test_retry.py @@ -0,0 +1,58 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest +from httpx_retries import RetryTransport + +from data_designer.engine.models.clients.retry import RetryConfig, create_retry_transport + + +@pytest.mark.parametrize( + "field,expected", + [ + ("max_retries", 3), + ("backoff_factor", 2.0), + ("backoff_jitter", 0.2), + ("max_backoff_wait", 60.0), + ("retryable_status_codes", frozenset({429, 502, 503, 504})), + ], + ids=["max_retries", "backoff_factor", "backoff_jitter", "max_backoff_wait", "retryable_status_codes"], +) +def test_retry_config_defaults_match_litellm_router(field: str, expected: object) -> None: + config = RetryConfig() + assert getattr(config, field) == expected + + +def test_retry_config_is_frozen() -> None: + config = RetryConfig() + with pytest.raises(AttributeError): + config.max_retries = 10 # type: ignore[misc] + + +def test_create_retry_transport_returns_retry_transport() -> None: + transport = create_retry_transport() + assert isinstance(transport, RetryTransport) + + +def test_create_retry_transport_with_none_uses_defaults() -> None: + transport = create_retry_transport(None) + assert transport.retry.total == 3 + + +@pytest.mark.parametrize( + "field,config_value,retry_attr,expected", + [ + ("max_retries", 5, "total", 5), + ("backoff_factor", 1.0, "backoff_factor", 1.0), + ("retryable_status_codes", frozenset({429, 500}), "status_forcelist", frozenset({429, 500})), + ], + ids=["max_retries", "backoff_factor", "status_forcelist"], +) +def test_create_retry_transport_propagates_config( + field: str, config_value: object, retry_attr: str, expected: object +) -> None: + config = RetryConfig(**{field: config_value}) + transport = create_retry_transport(config) + assert getattr(transport.retry, retry_attr) == expected diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_throttle.py b/packages/data-designer-engine/tests/engine/models/clients/test_throttle.py new file mode 100644 index 000000000..2de6ab574 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/models/clients/test_throttle.py @@ -0,0 +1,347 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import threading +import time + +import pytest + +from data_designer.engine.models.clients.throttle import ( + DEFAULT_BLOCK_SECONDS, + ThrottleDomain, + ThrottleManager, +) + +PROVIDER = "test-provider" +MODEL = "gpt-test" +DOMAIN = ThrottleDomain.CHAT + + +@pytest.fixture +def manager() -> ThrottleManager: + tm = ThrottleManager() + tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=4) + return tm + + +# --- try_acquire --- + + +def test_acquire_under_limit_returns_zero(manager: ThrottleManager) -> None: + wait = manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) + assert wait == 0.0 + + +def test_acquire_at_capacity_returns_positive_wait(manager: ThrottleManager) -> None: + for _ in range(4): + manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) + wait = manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) + assert wait > 0.0 + + +def test_acquire_respects_blocked_until(manager: ThrottleManager) -> None: + manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) + manager.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, retry_after=5.0, now=1.0) + wait = manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=2.0) + assert wait == pytest.approx(4.0, abs=0.01) + + +def test_acquire_without_registration_uses_min_limit() -> None: + tm = ThrottleManager() + assert tm.try_acquire(provider_name="unknown", model_id="m", domain=DOMAIN, now=0.0) == 0.0 + assert tm.try_acquire(provider_name="unknown", model_id="m", domain=DOMAIN, now=0.0) > 0.0 + + +# --- release_success --- + + +def test_release_success_frees_slot(manager: ThrottleManager) -> None: + for _ in range(4): + manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) + manager.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) + wait = manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) + assert wait == 0.0 + + +def test_additive_increase_after_success_window() -> None: + tm = ThrottleManager(success_window=5) + tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=10) + tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) + tm.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) + state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) + assert state is not None + limit_after_drop = state.current_limit + + for i in range(5): + tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=float(i)) + tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=float(i)) + + assert state.current_limit == limit_after_drop + 1 + + +def test_additive_increase_uses_configured_step() -> None: + tm = ThrottleManager(success_window=1, additive_increase=3) + tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=20) + tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) + tm.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) + state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) + assert state is not None + limit_after_drop = state.current_limit + + tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=1.0) + tm.release_success(provider_name=PROVIDER, 
model_id=MODEL, domain=DOMAIN, now=1.0) + + assert state.current_limit == limit_after_drop + 3 + + +def test_current_limit_never_exceeds_effective_max() -> None: + tm = ThrottleManager(success_window=1) + tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=2) + for i in range(20): + tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=float(i)) + tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=float(i)) + state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) + assert state is not None + assert state.current_limit <= 2 + + +def test_additive_increase_clamped_to_effective_max() -> None: + tm = ThrottleManager(success_window=1, additive_increase=100) + tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=5) + tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) + tm.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) + + tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=1.0) + tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=1.0) + + state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) + assert state is not None + assert state.current_limit == 5 + + +# --- release_rate_limited --- + + +def test_rate_limited_halves_current_limit(manager: ThrottleManager) -> None: + manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) + manager.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) + state = manager.get_domain_state(PROVIDER, MODEL, DOMAIN) + assert state is not None + assert state.current_limit == 2 + + +def test_rate_limited_never_drops_below_one() -> None: + tm = ThrottleManager() + tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=1) + tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) + tm.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) + state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) + assert state is not None + assert state.current_limit >= 1 + + +def test_rate_limited_resets_success_streak(manager: ThrottleManager) -> None: + manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) + manager.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) + manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) + manager.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) + state = manager.get_domain_state(PROVIDER, MODEL, DOMAIN) + assert state is not None + assert state.success_streak == 0 + + +def test_rate_limited_uses_retry_after_for_blocked_until(manager: ThrottleManager) -> None: + manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=10.0) + manager.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, retry_after=7.0, now=10.0) + state = manager.get_domain_state(PROVIDER, MODEL, DOMAIN) + assert state is not None + assert state.blocked_until == pytest.approx(17.0, abs=0.01) + + +def test_rate_limited_uses_default_block_when_no_retry_after(manager: ThrottleManager) -> None: + manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=10.0) + manager.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=10.0) + state = manager.get_domain_state(PROVIDER, MODEL, DOMAIN) + assert state is not None + 
assert state.blocked_until == pytest.approx(10.0 + DEFAULT_BLOCK_SECONDS, abs=0.01) + + +# --- release_failure --- + + +def test_failure_releases_slot_without_limit_change(manager: ThrottleManager) -> None: + manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) + state = manager.get_domain_state(PROVIDER, MODEL, DOMAIN) + assert state is not None + limit_before = state.current_limit + manager.release_failure(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) + assert state.current_limit == limit_before + assert state.in_flight == 0 + + +# --- Global cap --- + + +def test_two_aliases_effective_max_is_minimum() -> None: + tm = ThrottleManager() + tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=10) + tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a2", max_parallel_requests=3) + assert tm.get_effective_max(PROVIDER, MODEL) == 3 + + +def test_domain_clamped_when_new_alias_lowers_cap() -> None: + tm = ThrottleManager() + tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=10) + tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) + state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) + assert state is not None + assert state.current_limit == 10 + + tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a2", max_parallel_requests=3) + assert state.current_limit == 3 + + +# --- Domain isolation --- + + +def test_chat_and_embedding_throttle_independently() -> None: + tm = ThrottleManager() + tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=2) + + for _ in range(2): + tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=ThrottleDomain.CHAT, now=0.0) + wait_chat = tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=ThrottleDomain.CHAT, now=0.0) + assert wait_chat > 0.0 + + wait_emb = tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=ThrottleDomain.EMBEDDING, now=0.0) + assert wait_emb == 0.0 + + +# --- 429 lifecycle scenario --- + + +def test_rate_limit_lifecycle_acquire_backoff_recover() -> None: + """End-to-end AIMD lifecycle: steady-state → 429 → backoff → cooldown → recovery. + + Uses the ``now`` parameter to simulate time without real sleeps. + Config: success_window=3, additive_increase=1, max_parallel=4. + """ + tm = ThrottleManager(success_window=3, additive_increase=1) + tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=4) + t = 0.0 + + # Phase 1 — Steady state (t=0): all 4 slots acquired and released successfully. + # Limit stays at 4 because no rate-limit event has occurred. + for _ in range(4): + assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) == 0.0 + for _ in range(4): + tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) + + state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) + assert state is not None + assert state.current_limit == 4 + + # Phase 2 — 429 hits (t=10): one request gets rate-limited with retry-after=5s. + # Multiplicative decrease halves the limit: 4 → 2. + # Domain is blocked until t=10+5=15. 
+ t = 10.0 + assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) == 0.0 + tm.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, retry_after=5.0, now=t) + assert state.current_limit == 2 + assert state.blocked_until == 15.0 + + # Phase 3 — During cooldown (t=12): acquire returns positive wait since 12 < 15. + wait = tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=12.0) + assert wait > 0.0 + + # Phase 4 — Cooldown expires, reduced capacity (t=16): acquire succeeds again. + # One success → streak=1 (need 3 for a window), so limit stays at 2. + t = 16.0 + assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) == 0.0 + tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) + assert state.current_limit == 2 + + # Phase 5 — First recovery window (t=17-18): two more successes complete the + # window (streak hits 3). Additive increase: limit 2 → 3. + for i in range(2): + t += 1.0 + assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) == 0.0 + tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) + + assert state.current_limit == 3 + + # Phase 6 — Second recovery window (t=19-21): three more successes complete + # another window. Additive increase: limit 3 → 4 (fully recovered). + for i in range(3): + t += 1.0 + assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) == 0.0 + tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) + + assert state.current_limit == 4 + + +# --- Acquire timeout --- + + +def test_acquire_sync_raises_timeout_when_at_capacity() -> None: + tm = ThrottleManager() + tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=1) + # Saturate the single slot so try_acquire returns a positive wait. + tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN) + + with pytest.raises(TimeoutError, match="timed out"): + tm.acquire_sync(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, timeout=0.0) + + +def test_acquire_sync_does_not_overshoot_timeout() -> None: + """When wait > remaining budget, raise immediately instead of sleeping the full wait.""" + tm = ThrottleManager(default_block_seconds=5.0) + tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=1) + tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN) + + # Timeout of 0.5s is less than the 5s block wait — should raise fast, not sleep 5s. 
+ start = time.monotonic() + with pytest.raises(TimeoutError, match="timed out"): + tm.acquire_sync(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, timeout=0.5) + elapsed = time.monotonic() - start + assert elapsed < 2.0, f"acquire_sync overshot timeout: elapsed {elapsed:.1f}s (expected <2s)" + + +@pytest.mark.asyncio +async def test_acquire_async_raises_timeout_when_at_capacity() -> None: + tm = ThrottleManager() + tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=1) + tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN) + + with pytest.raises(TimeoutError, match="timed out"): + await tm.acquire_async(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, timeout=0.0) + + +# --- Thread safety --- + + +def test_concurrent_acquire_release_no_errors() -> None: + tm = ThrottleManager() + tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=4) + errors: list[Exception] = [] + + def worker() -> None: + try: + for _ in range(50): + tm.acquire_sync(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN) + tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN) + except Exception as exc: + errors.append(exc) + + threads = [threading.Thread(target=worker) for _ in range(8)] + for t in threads: + t.start() + for t in threads: + t.join(timeout=10) + assert not errors, f"Thread errors: {errors}" + + state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) + assert state is not None + assert state.in_flight == 0 diff --git a/packages/data-designer-engine/tests/engine/resources/test_resource_provider.py b/packages/data-designer-engine/tests/engine/resources/test_resource_provider.py index 5b046bff8..cb5d569f9 100644 --- a/packages/data-designer-engine/tests/engine/resources/test_resource_provider.py +++ b/packages/data-designer-engine/tests/engine/resources/test_resource_provider.py @@ -19,7 +19,7 @@ def _stub_model_registry() -> ModelRegistry: return ModelRegistry( secret_resolver=Mock(), model_provider_registry=Mock(), - model_facade_factory=lambda *_args, **_kwargs: Mock(), + model_facade_factory=lambda *_args: Mock(), model_configs=[], ) diff --git a/plans/343/model-facade-overhaul-plan-step-1.md b/plans/343/model-facade-overhaul-plan-step-1.md index 56f95cfe5..320145f20 100644 --- a/plans/343/model-facade-overhaul-plan-step-1.md +++ b/plans/343/model-facade-overhaul-plan-step-1.md @@ -192,13 +192,13 @@ Updated files (Step 1): ### PR slicing (recommended) -1. PR-1: canonical types/interfaces/errors + bridge adapter + no behavior change. +1. PR-1 ([#359](https://github.com/NVIDIA-NeMo/DataDesigner/pull/359) — merged): canonical types/interfaces/errors + bridge adapter + no behavior change. - files: `clients/base.py`, `clients/types.py`, `clients/errors.py`, `clients/adapters/litellm_bridge.py` - docs: add architecture notes for canonical adapter boundary and bridge purpose. -2. PR-2: `ModelFacade` switched to `ModelClient` + lifecycle wiring + parity tests on bridge. +2. PR-2 ([#373](https://github.com/NVIDIA-NeMo/DataDesigner/pull/373) — merged): `ModelFacade` switched to `ModelClient` + lifecycle wiring + parity tests on bridge. - files: `models/facade.py`, `models/errors.py`, `models/factory.py`, `clients/factory.py`, `models/registry.py`, `resources/resource_provider.py`, `interface/data_designer.py` - docs: update internal lifecycle/ownership docs for adapter teardown and resource shutdown behavior. -3. PR-3: OpenAI-compatible adapter + shared retry/throttle + auth integration. +3. 
PR-3 (in progress): OpenAI-compatible adapter + shared retry/throttle + auth integration.
    - files: `clients/retry.py`, `clients/throttle.py`, `clients/adapters/openai_compatible.py`
    - docs: add provider docs for openai-compatible routing, endpoint expectations, and retry/throttle semantics.
 4. PR-4: Anthropic adapter + auth integration + capability gating.
diff --git a/plans/343/model-facade-overhaul-pr-1-architecture-notes.md b/plans/343/model-facade-overhaul-pr-1-architecture-notes.md
index 6667d04dc..47bab726a 100644
--- a/plans/343/model-facade-overhaul-pr-1-architecture-notes.md
+++ b/plans/343/model-facade-overhaul-pr-1-architecture-notes.md
@@ -6,8 +6,9 @@ authors:
 
 # Model Facade Overhaul PR-1 Architecture Notes
 
-This document captures the architecture intent for PR-1 from
-`plans/343/model-facade-overhaul-plan-step-1.md`.
+This document captures the architecture intent for
+[PR #359](https://github.com/NVIDIA-NeMo/DataDesigner/pull/359)
+from `plans/343/model-facade-overhaul-plan-step-1.md`.
 
 ## Canonical Adapter Boundary
diff --git a/plans/343/model-facade-overhaul-pr-2-architecture-notes.md b/plans/343/model-facade-overhaul-pr-2-architecture-notes.md
index 5a07a9549..cb0b4f3d9 100644
--- a/plans/343/model-facade-overhaul-pr-2-architecture-notes.md
+++ b/plans/343/model-facade-overhaul-pr-2-architecture-notes.md
@@ -6,8 +6,9 @@ authors:
 
 # Model Facade Overhaul PR-2 Architecture Notes
 
-This document captures the architecture intent for PR-2 from
-`plans/343/model-facade-overhaul-plan-step-1.md`.
+This document captures the architecture intent for
+[PR #373](https://github.com/NVIDIA-NeMo/DataDesigner/pull/373)
+from `plans/343/model-facade-overhaul-plan-step-1.md`.
 
 ## Goal
diff --git a/plans/343/model-facade-overhaul-pr-3-architecture-notes.md b/plans/343/model-facade-overhaul-pr-3-architecture-notes.md
new file mode 100644
index 000000000..af4bb9382
--- /dev/null
+++ b/plans/343/model-facade-overhaul-pr-3-architecture-notes.md
@@ -0,0 +1,193 @@
+---
+date: 2026-03-06
+authors:
+  - nmulepati
+---
+
+# Model Facade Overhaul PR-3 Architecture Notes
+
+This document captures the architecture intent for PR-3 from
+`plans/343/model-facade-overhaul-plan-step-1.md`.
+
+## Goal
+
+Introduce the first native HTTP adapter (`OpenAICompatibleClient`) with shared
+retry infrastructure and a standalone adaptive throttle resource. After this
+PR, the client factory routes `provider_type="openai"` to the native adapter
+while all other provider types continue through the `LiteLLMBridgeClient`.
+
+## What Changes
+
+### 1. Shared retry module (`clients/retry.py`)
+
+`RetryConfig` is a frozen dataclass whose defaults mirror the current LiteLLM
+router settings (`LiteLLMRouterDefaultKwargs`):
+
+- `max_retries = 3`
+- `backoff_factor = 2.0` (exponential)
+- `backoff_jitter = 0.2`
+- `max_backoff_wait = 60.0`
+- `retryable_status_codes = {429, 502, 503, 504}`
+
+`create_retry_transport()` builds an `httpx_retries.RetryTransport` from a
+`RetryConfig`. The transport handles both sync and async requests (it inherits
+from `httpx.BaseTransport` and `httpx.AsyncBaseTransport`).
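+
+A minimal usage sketch (hedged: it assumes only the `RetryConfig` fields and
+the `create_retry_transport` signature described above; mounting the transport
+on plain `httpx` clients is illustrative, not how the adapter wires it
+internally):
+
+```python
+import httpx
+
+from data_designer.engine.models.clients.retry import RetryConfig, create_retry_transport
+
+# Override only the fields that differ from the LiteLLM-mirroring defaults.
+config = RetryConfig(max_retries=5, retryable_status_codes=frozenset({429, 500, 502, 503, 504}))
+
+# One RetryTransport instance serves both sync and async clients.
+transport = create_retry_transport(config)
+sync_client = httpx.Client(transport=transport)
+async_client = httpx.AsyncClient(transport=transport)
+```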
+### 2. Adaptive throttle module (`clients/throttle.py`)
+
+`ThrottleManager` implements AIMD (additive-increase / multiplicative-decrease)
+concurrency control keyed at two levels:
+
+- **Global cap** `(provider_name, model_id)` — shared hard ceiling derived as
+  `min(max_parallel_requests)` across all aliases targeting the same provider
+  and model.
+- **Domain** `(provider_name, model_id, throttle_domain)` — per-route AIMD
+  state (`chat`, `embedding`, `image`, `healthcheck`) that floats between 1
+  and the global effective max.
+
+AIMD behavior:
+
+- *Decrease* — on a 429 the domain limit is multiplied by `reduce_factor`
+  (default 0.5) and a cooldown block is applied.
+- *Increase* — after every `success_window` (default 50) consecutive
+  successful releases the limit grows by `additive_increase` (default 1),
+  up to the global effective max. Both `additive_increase` and
+  `success_window` are constructor parameters for tuning recovery speed.
+- *Recovery cost* — after a single halving from *L* to *L/2*, full recovery
+  requires `(L/2) × success_window / additive_increase` successful requests.
+  With the defaults, a limit halved from 8 to 4 needs 4 × 50 = 200 successes
+  to climb back to 8.
+
+Core state methods are non-blocking so both sync and async wrappers reuse
+the same thread-safe state:
+
+- `try_acquire(now) -> wait_seconds` (0 = acquired)
+- `release_success(now)`
+- `release_rate_limited(now, retry_after)`
+- `release_failure(now)`
+
+`acquire_sync` and `acquire_async` wrap `try_acquire` in a poll loop with a
+configurable `timeout` (default 300s) to prevent indefinite blocking when a
+domain is persistently at capacity or in cooldown.
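+
+A sketch of one request's lifecycle against this API (hedged: `send_chat_request`
+is a placeholder for the adapter call; the keyword signatures follow the test
+suite in this PR):
+
+```python
+import time
+
+from data_designer.engine.models.clients.errors import ProviderError, ProviderErrorKind
+from data_designer.engine.models.clients.throttle import ThrottleDomain, ThrottleManager
+
+PROVIDER, MODEL, DOMAIN = "openai-prod", "gpt-test", ThrottleDomain.CHAT
+
+
+def send_chat_request() -> None:
+    """Placeholder for the adapter call; raises ProviderError on failure."""
+
+
+tm = ThrottleManager()
+tm.register(provider_name=PROVIDER, model_id=MODEL, alias="default", max_parallel_requests=4)
+
+wait = tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=time.monotonic())
+if wait == 0.0:  # slot acquired; a positive value is the suggested wait in seconds
+    try:
+        send_chat_request()
+    except ProviderError as exc:
+        if exc.kind == ProviderErrorKind.RATE_LIMIT:
+            # Multiplicative decrease plus a cooldown block.
+            tm.release_rate_limited(
+                provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, retry_after=5.0, now=time.monotonic()
+            )
+        else:
+            # Frees the slot without changing the limit.
+            tm.release_failure(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=time.monotonic())
+    else:
+        # Counts toward the additive-increase success window.
+        tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=time.monotonic())
+```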
+#### Ownership — standalone resource, not adapter-owned
+
+`ThrottleManager` is **not** owned by the adapter. It lives as a shared
+resource on `ModelRegistry` and is intended to be called by the orchestration
+layer (the `AsyncTaskScheduler` from plan 346).
+
+Rationale:
+
+- **Separation of concerns** — the adapter is pure HTTP transport (request,
+  retry, parse). Concurrency policy is an orchestration concern.
+- **Scheduler optimization** — the async scheduler needs to release its
+  execution semaphore slot *while waiting* for a throttle permit, then
+  reacquire it before executing. This is only possible if the scheduler
+  owns the acquire/release lifecycle directly.
+- **Sync path** — the current sync builder is sequential (one call at a time),
+  so it cannot exceed concurrency limits and does not need throttle gating.
+
+The layered responsibility is:
+
+| Layer | Responsibility |
+|---|---|
+| **Scheduler / Builder** | Concurrency policy: execution slots + throttle acquire/release |
+| **ModelFacade** | Business logic: prompt assembly, usage tracking, correction loops |
+| **Adapter** | Transport: HTTP, retry, response parsing |
+
+### 3. OpenAI-compatible adapter (`clients/adapters/openai_compatible.py`)
+
+`OpenAICompatibleClient` implements the `ModelClient` protocol using `httpx`
+with `RetryTransport` for resilient HTTP calls. The adapter is pure transport
+— it has no knowledge of throttle or concurrency policy.
+
+Routes:
+
+- `POST /chat/completions` — chat completion and autoregressive image generation
+- `POST /embeddings` — text embeddings
+- `POST /images/generations` — diffusion-style image generation
+
+Image routing is request-shape-based: if `request.messages is not None` the
+chat route is used, otherwise the dedicated image route.
+
+Response parsing reuses the shared `parsing.py` helpers. The `get_value_from()`
+utility handles both dict and object access, so raw JSON dicts from `httpx`
+responses are passed directly to the parsing functions.
+
+### 4. Reasoning field migration (`clients/parsing.py`)
+
+`extract_reasoning_content(message)` checks `message.reasoning` first
+(vLLM >= 0.16.0 canonical field), falling back to `message.reasoning_content`
+(legacy / LiteLLM-normalized). Both `parse_chat_completion_response` and
+`aparse_chat_completion_response` now use this helper.
+
+The internal canonical field remains `reasoning_content` — no downstream
+contract change.
+
+Ref: [GitHub issue #374](https://github.com/NVIDIA-NeMo/DataDesigner/issues/374)
+
+### 5. Client factory routing (`clients/factory.py`)
+
+`create_model_client` accepts an optional `retry_config` parameter and routes
+based on provider type via sequential early returns:
+
+1. If `DATA_DESIGNER_MODEL_BACKEND=litellm_bridge` → always `LiteLLMBridgeClient`
+   (rollback safety during migration).
+2. If `provider_type == "openai"` (matched case-insensitively) → `OpenAICompatibleClient`.
+3. Otherwise → `LiteLLMBridgeClient` (the Anthropic native adapter is PR-4).
+
+The factory does not pass a `ThrottleManager` to adapters — throttle is an
+orchestration concern (see §2).
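+
+A condensed sketch of the routing order (hedged: `select_backend` is a
+hypothetical helper that returns backend labels; the real `create_model_client`
+takes a model config, secret resolver, and provider registry and returns a
+client instance):
+
+```python
+import os
+
+
+def select_backend(provider_type: str) -> str:
+    """Mirror the factory's sequential early returns."""
+    # 1. Env override wins: rollback safety during migration.
+    if os.environ.get("DATA_DESIGNER_MODEL_BACKEND") == "litellm_bridge":
+        return "litellm_bridge"
+    # 2. OpenAI-compatible providers route to the native adapter.
+    if provider_type.lower() == "openai":
+        return "openai_compatible"
+    # 3. Everything else stays on the bridge until its native adapter lands.
+    return "litellm_bridge"
+```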
+### 6. Registry integration (`models/factory.py`, `models/registry.py`)
+
+`create_model_registry` creates a shared `ThrottleManager` (held on
+`ModelRegistry` for the scheduler to access) and a shared `RetryConfig`
+(passed through to each `create_model_client` call). The throttle manager
+is not forwarded to adapters.
+
+`ModelRegistry._get_model()` calls `throttle_manager.register()` when it
+lazily creates each `ModelFacade`. This ensures the throttle manager's
+per-`(provider_name, model_id)` global caps are populated before the
+scheduler (or any other caller) attempts to acquire permits.
+
+## What Does NOT Change
+
+1. `ModelFacade` public method signatures — callers see the same API.
+2. MCP tool-loop behavior — tool turns, refusal, parallel execution all preserved.
+3. Usage accounting semantics — token, request, image, and tool usage remain identical.
+4. Error boundaries — `@catch_llm_exceptions` / `@acatch_llm_exceptions` decorators
+   and the `DataDesignerError` subclass hierarchy remain stable.
+5. `consolidate_kwargs` merge semantics for `extra_body` / `extra_headers`.
+6. `generate` / `agenerate` parser correction/restart loop logic.
+
+## Files Touched
+
+| File | Change |
+|---|---|
+| `clients/retry.py` | New — `RetryConfig` + `create_retry_transport` |
+| `clients/throttle.py` | New — `ThrottleManager` with AIMD (standalone resource) |
+| `clients/adapters/openai_compatible.py` | New — native OpenAI-compatible adapter (pure transport, no throttle) |
+| `clients/errors.py` | Extract `infer_error_kind_from_exception` as a shared function |
+| `clients/base.py` | Add docstring to `ModelClient` protocol |
+| `clients/parsing.py` | Add `extract_reasoning_content` helper; use `get_value_from` consistently |
+| `clients/factory.py` | Route `provider_type=openai` to the native adapter (sequential early returns) |
+| `clients/__init__.py` | Export new public names |
+| `clients/adapters/__init__.py` | Export `OpenAICompatibleClient` |
+| `models/factory.py` | Create shared `ThrottleManager` (on registry) and `RetryConfig` |
+| `models/registry.py` | Hold `ThrottleManager` as a shared resource; expose via property |
+| `models/facade.py` | Minor log message tweak (drop LiteLLM-specific wording) |
+
+## Planned Follow-On
+
+PR-4 introduces the Anthropic native adapter. At that point, the client
+factory gains a third adapter option alongside the LiteLLM bridge and the
+OpenAI-compatible adapter.
+
+The `AsyncTaskScheduler` (plan 346) will call `ThrottleManager` directly
+from the orchestration layer, acquiring throttle permits as part of its
+execution-slot lifecycle:
+
+1. Acquire execution slot
+2. Release execution slot, await `throttle_manager.acquire_async(...)`
+3. Reacquire execution slot, execute via `ModelFacade`
+
+This pattern lets the scheduler free execution slots while waiting for
+throttle permits (e.g., during 429 cooldowns), maximizing throughput.
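+
+A minimal sketch of that lifecycle (hedged: `run_one` and the `execution_slots`
+semaphore are hypothetical scheduler internals; only the `ThrottleManager`
+calls follow the API introduced in this PR):
+
+```python
+import asyncio
+
+from data_designer.engine.models.clients.throttle import ThrottleDomain, ThrottleManager
+
+
+async def run_one(tm: ThrottleManager, execution_slots: asyncio.Semaphore) -> None:
+    # 1. Acquire an execution slot (per-task setup happens while holding it).
+    await execution_slots.acquire()
+
+    # 2. Hand the slot back while awaiting a throttle permit, so peers can run
+    #    during long waits such as 429 cooldowns. May raise TimeoutError.
+    execution_slots.release()
+    await tm.acquire_async(
+        provider_name="openai-prod", model_id="gpt-test", domain=ThrottleDomain.CHAT, timeout=300.0
+    )
+
+    # 3. Reacquire a slot, execute, and report the outcome to the throttle.
+    await execution_slots.acquire()
+    try:
+        ...  # execute via ModelFacade
+        tm.release_success(provider_name="openai-prod", model_id="gpt-test", domain=ThrottleDomain.CHAT)
+    except Exception:
+        tm.release_failure(provider_name="openai-prod", model_id="gpt-test", domain=ThrottleDomain.CHAT)
+        raise
+    finally:
+        execution_slots.release()
+```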