feat: Native OpenAI adapter with retry and AIMD throttle infrastructure #402
Status: Open. nabinchha wants to merge 77 commits into main from nm/overhaul-model-facade-guts-pr3 (base: main).
Commits (77):
ab30a2d  plans for model facade overhaul (nabinchha)
43824ea  update plan (nabinchha)
2a5f1e4  add review (johnnygreco)
f945d5b  address feedback + add more details after several self reviews (nabinchha)
dfa3817  update plan doc (nabinchha)
5b18f74  Merge branch 'main' into nm/overhaul-model-facade-guts (nabinchha)
0f449a7  address nits (nabinchha)
37f092a  Merge branch 'nm/overhaul-model-facade-guts' into nm/overhaul-model-f… (nabinchha)
08e57f8  Add cannonical objects (nabinchha)
3ab18ee  Merge branch 'main' into nm/overhaul-model-facade-guts-pr1 (nabinchha)
34349c7  self-review feedback + address (nabinchha)
6aae4b6  add LiteLLMRouter protocol to strongly type bridge router param (nabinchha)
2a53d37  simplify some things (nabinchha)
4e2f3af  add a protol for http response like object (nabinchha)
b1c85f2  move HttpResponse (nabinchha)
f6dc769  update PR-1 architecture notes for lifecycle and router protocol (nabinchha)
ec5ed9b  Address PR #359 feedback: exception wrapping, shared parsing, test im… (nabinchha)
b6b4028  Merge branch 'main' into nm/overhaul-model-facade-guts-pr1 (nabinchha)
ba22397  Use contextlib to dry out some code (nabinchha)
aeac3b9  Address Greptile feedback: HTTP-date retry-after parsing, docstring c… (nabinchha)
55f3c96  Address Greptile feedback: FastAPI detail parsing, comment fixes (nabinchha)
c390912  Merge branch 'main' into nm/overhaul-model-facade-guts-pr1 (nabinchha)
828cc49  add PR-2 architecture notes for model facade overhaul (nabinchha)
89a6d4e  save progress on pr2 (nabinchha)
e527503  Merge branch 'main' into nm/overhaul-model-facade-guts-pr1 (nabinchha)
f6fa447  Merge branch 'nm/overhaul-model-facade-guts-pr1' into nm/overhaul-mod… (nabinchha)
b8579c2  small refactor (nabinchha)
61024c0  address feedback (nabinchha)
d47d508  Merge branch 'nm/overhaul-model-facade-guts-pr1' into nm/overhaul-mod… (nabinchha)
49a45ba  Address greptile comment in pr1 (nabinchha)
e8445cc  refactor ProviderError from dataclass to regular Exception (nabinchha)
8a385ff  Merge branch 'main' into nm/overhaul-model-facade-guts-pr2 (nabinchha)
521c1e4  Address greptile feedback (nabinchha)
a831d24  Merge branch 'main' into nm/overhaul-model-facade-guts-pr2 (nabinchha)
4836e03  PR feedback (nabinchha)
ae1bf98  track usage tracking in finally block for images (nabinchha)
18b9966  pr feedback (nabinchha)
ad45ee2  add native OpenAI adapter with retry and throttle infrastructure (nabinchha)
724e734  Self CR (nabinchha)
25650b0  Merge branch 'main' into nm/overhaul-model-facade-guts-pr2 (nabinchha)
651813b  Merge branch 'main' into nm/overhaul-model-facade-guts-pr2 (nabinchha)
bfed5af  Merge branch 'main' into nm/overhaul-model-facade-guts-pr2 (nabinchha)
3636c56  Merge branch 'nm/overhaul-model-facade-guts-pr2' into nm/overhaul-mod… (nabinchha)
504d040  fix claude slop (nabinchha)
afbe197  Updates after self-review. Simplify use of ThrottleManager in light o… (nabinchha)
c9d6f4c  Merge branch 'main' into nm/overhaul-model-facade-guts-pr2 (nabinchha)
a084038  Merge branch 'main' into nm/overhaul-model-facade-guts-pr2 (nabinchha)
632c7c6  wrap facade close in try/catch (nabinchha)
40c05ae  Merge branch 'main' into nm/overhaul-model-facade-guts-pr2 (nabinchha)
c1d807c  clean up stray params (nabinchha)
56caed5  Merge branch 'nm/overhaul-model-facade-guts-pr2' into nm/overhaul-mod… (nabinchha)
e34e566  fix: address review findings from model facade overhaul PR3 (nabinchha)
879b941  fix stray inclusion of metadata (nabinchha)
dcbbcba  small regression fix (nabinchha)
462810c  address more feedback (nabinchha)
ac02f2c  Merge branch 'main' into nm/overhaul-model-facade-guts-pr2 (nabinchha)
fb809bd  Merge branch 'nm/overhaul-model-facade-guts-pr2' into nm/overhaul-mod… (nabinchha)
c538367  Merge branch 'main' into nm/overhaul-model-facade-guts-pr3 (nabinchha)
117baf4  self review (nabinchha)
843227b  Fixes (nabinchha)
c7e67d6  new test for aimd lifecycle (nabinchha)
707a22f  update plan docs (nabinchha)
0f55d4c  update plans with refs to prs (nabinchha)
e13f7b6  fix: cap acquire_sync/acquire_async sleep to remaining budget to prev… (nabinchha)
664e3cf  test lay init (nabinchha)
eb27418  Merge branch 'main' into nm/overhaul-model-facade-guts-pr3 (nabinchha)
7d7fd41  fix timeout for openaicompatibleadapter (nabinchha)
9142494  remove unused attr (nabinchha)
bdd0202  fix: address review findings from PR #402 (nabinchha)
e46efdf  Merge branch 'main' into nm/overhaul-model-facade-guts-pr3 (nabinchha)
5c5bfab  Merge branch 'main' into nm/overhaul-model-facade-guts-pr3 (nabinchha)
c23e360  Address pr feedback (nabinchha)
7bd4763  fix method order (nabinchha)
1658544  Merge branch 'main' into nm/overhaul-model-facade-guts-pr3 (nabinchha)
ffa0e11  Merge branch 'main' into nm/overhaul-model-facade-guts-pr3 (nabinchha)
c0c2418  Merge branch 'main' into nm/overhaul-model-facade-guts-pr3 (nabinchha)
3b491c9  Fix failing test (nabinchha)
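Commit e13f7b6 in the list above caps `acquire_sync`/`acquire_async` sleeps to the remaining time budget so a throttled acquire cannot overshoot its deadline. The actual acquire methods are not part of this diff; a minimal sketch of just the capping arithmetic (function name and signature are illustrative):

```python
def capped_sleep_duration(proposed_sleep_s: float, deadline_s: float, now_s: float) -> float:
    """Clamp a proposed backoff sleep to the time remaining before the deadline.

    Illustrative helper: never sleep past the deadline, and never return a
    negative duration once the deadline has already passed.
    """
    remaining = max(0.0, deadline_s - now_s)
    return min(proposed_sleep_s, remaining)
```

Without this cap, a long backoff (say 5 s) requested 2 s before the deadline would block well past the budget instead of waking in time to fail fast.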
319 changes: 319 additions & 0 deletions
...ata-designer-engine/src/data_designer/engine/models/clients/adapters/openai_compatible.py
```python
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import logging
import threading
from typing import TYPE_CHECKING, Any

import data_designer.lazy_heavy_imports as lazy
from data_designer.engine.models.clients.base import ModelClient
from data_designer.engine.models.clients.errors import (
    ProviderError,
    ProviderErrorKind,
    infer_error_kind_from_exception,
    map_http_error_to_provider_error,
)
from data_designer.engine.models.clients.parsing import (
    aextract_images_from_chat_response,
    aextract_images_from_image_response,
    aparse_chat_completion_response,
    extract_embedding_vector,
    extract_images_from_chat_response,
    extract_images_from_image_response,
    extract_usage,
    parse_chat_completion_response,
)
from data_designer.engine.models.clients.retry import RetryConfig, create_retry_transport
from data_designer.engine.models.clients.types import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    EmbeddingRequest,
    EmbeddingResponse,
    ImageGenerationRequest,
    ImageGenerationResponse,
    TransportKwargs,
)

if TYPE_CHECKING:
    import httpx

logger = logging.getLogger(__name__)
```
```python
class OpenAICompatibleClient(ModelClient):
    """Native HTTP adapter for OpenAI-compatible provider APIs.

    Uses ``httpx`` with ``httpx_retries.RetryTransport`` for resilient HTTP
    calls. Concurrency / throttle policy is an orchestration concern and
    is not managed here; see ``ThrottleManager`` and ``AsyncTaskScheduler``.
    """

    _ROUTE_CHAT = "/chat/completions"
    _ROUTE_EMBEDDING = "/embeddings"
    _ROUTE_IMAGE = "/images/generations"
    _IMAGE_EXCLUDE = frozenset({"messages", "prompt"})

    def __init__(
        self,
        *,
        provider_name: str,
        model_id: str,
        endpoint: str,
        api_key: str | None = None,
        retry_config: RetryConfig | None = None,
        max_parallel_requests: int = 32,
        timeout_s: float = 60.0,
        sync_client: httpx.Client | None = None,
        async_client: httpx.AsyncClient | None = None,
    ) -> None:
        self.provider_name = provider_name
        self._model_id = model_id
        self._endpoint = endpoint.rstrip("/")
        self._api_key = api_key
        self._timeout_s = timeout_s
        self._retry_config = retry_config

        # 2x headroom for burst traffic across domains; floor of 32/16 for low-concurrency configs.
        pool_max = max(32, 2 * max_parallel_requests)
        pool_keepalive = max(16, max_parallel_requests)
        self._limits = lazy.httpx.Limits(
            max_connections=pool_max,
            max_keepalive_connections=pool_keepalive,
        )
        self._transport = create_retry_transport(self._retry_config)
        self._client: httpx.Client | None = sync_client
        self._aclient: httpx.AsyncClient | None = async_client
        self._init_lock = threading.Lock()

    def _get_sync_client(self) -> httpx.Client:
        # Double-checked locking: lazily create the client exactly once across threads.
        if self._client is None:
            with self._init_lock:
                if self._client is None:
                    self._client = lazy.httpx.Client(
                        transport=self._transport,
                        limits=self._limits,
                        timeout=lazy.httpx.Timeout(self._timeout_s),
                    )
        return self._client

    def _get_async_client(self) -> httpx.AsyncClient:
        if self._aclient is None:
            with self._init_lock:
                if self._aclient is None:
                    self._aclient = lazy.httpx.AsyncClient(
                        transport=self._transport,
                        limits=self._limits,
                        timeout=lazy.httpx.Timeout(self._timeout_s),
                    )
        return self._aclient
```
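The PR title mentions AIMD throttle infrastructure, but the `ThrottleManager` referenced in the class docstring is defined elsewhere and not part of this file. As a hedged illustration of the additive-increase/multiplicative-decrease idea behind such a throttle (class name, defaults, and method names here are hypothetical, not the project's API):

```python
class AimdThrottle:
    """Illustrative AIMD concurrency limiter; not the project's ThrottleManager.

    Additive increase: each success raises the concurrency limit by a fixed step.
    Multiplicative decrease: each backoff signal (e.g. HTTP 429) scales it down.
    """

    def __init__(
        self,
        initial_limit: float = 8.0,
        min_limit: float = 1.0,
        max_limit: float = 64.0,
        increase_step: float = 1.0,
        decrease_factor: float = 0.5,
    ) -> None:
        self.limit = initial_limit
        self._min = min_limit
        self._max = max_limit
        self._step = increase_step
        self._factor = decrease_factor

    def on_success(self) -> None:
        # Additive increase, capped at the configured maximum.
        self.limit = min(self._max, self.limit + self._step)

    def on_backoff(self) -> None:
        # Multiplicative decrease, floored at the configured minimum.
        self.limit = max(self._min, self.limit * self._factor)
```

The asymmetry is the point: the limit recovers slowly after rate limiting but collapses quickly when the provider pushes back, which keeps sustained throughput near whatever the provider will actually tolerate.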
```python
    # -------------------------------------------------------------------
    # Capability checks - adapter-level (see ModelClient docstring)
    # -------------------------------------------------------------------

    def supports_chat_completion(self) -> bool:
        return True

    def supports_embeddings(self) -> bool:
        return True

    def supports_image_generation(self) -> bool:
        return True

    # -------------------------------------------------------------------
    # Chat completion
    # -------------------------------------------------------------------

    def completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
        transport = TransportKwargs.from_request(request)
        payload = {"model": request.model, "messages": request.messages, **transport.body}
        response_json = self._post_sync(self._ROUTE_CHAT, payload, transport.headers, request.model, transport.timeout)
        return parse_chat_completion_response(response_json)

    async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
        transport = TransportKwargs.from_request(request)
        payload = {"model": request.model, "messages": request.messages, **transport.body}
        response_json = await self._apost(
            self._ROUTE_CHAT, payload, transport.headers, request.model, transport.timeout
        )
        return await aparse_chat_completion_response(response_json)

    # -------------------------------------------------------------------
    # Embeddings
    # -------------------------------------------------------------------

    def embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse:
        transport = TransportKwargs.from_request(request)
        payload = {"model": request.model, "input": request.inputs, **transport.body}
        response_json = self._post_sync(
            self._ROUTE_EMBEDDING, payload, transport.headers, request.model, transport.timeout
        )
        return _parse_embedding_json(response_json)

    async def aembeddings(self, request: EmbeddingRequest) -> EmbeddingResponse:
        transport = TransportKwargs.from_request(request)
        payload = {"model": request.model, "input": request.inputs, **transport.body}
        response_json = await self._apost(
            self._ROUTE_EMBEDDING, payload, transport.headers, request.model, transport.timeout
        )
        return _parse_embedding_json(response_json)

    # -------------------------------------------------------------------
    # Image generation
    # -------------------------------------------------------------------

    def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse:
        transport = TransportKwargs.from_request(request, exclude=self._IMAGE_EXCLUDE)
        if request.messages is not None:
            route = self._ROUTE_CHAT
            payload = {"model": request.model, "messages": request.messages, **transport.body}
        else:
            route = self._ROUTE_IMAGE
            payload = {"model": request.model, "prompt": request.prompt, **transport.body}
        response_json = self._post_sync(route, payload, transport.headers, request.model, transport.timeout)
        return _parse_image_json(response_json, is_chat_route=request.messages is not None)

    async def agenerate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse:
        transport = TransportKwargs.from_request(request, exclude=self._IMAGE_EXCLUDE)
        if request.messages is not None:
            route = self._ROUTE_CHAT
            payload = {"model": request.model, "messages": request.messages, **transport.body}
        else:
            route = self._ROUTE_IMAGE
            payload = {"model": request.model, "prompt": request.prompt, **transport.body}
        response_json = await self._apost(route, payload, transport.headers, request.model, transport.timeout)
        return await _aparse_image_json(response_json, is_chat_route=request.messages is not None)
```
```python
    # -------------------------------------------------------------------
    # Lifecycle
    # -------------------------------------------------------------------

    def close(self) -> None:
        if self._client is not None:
            self._client.close()
            self._client = None

    async def aclose(self) -> None:
        if self._aclient is not None:
            await self._aclient.aclose()
            self._aclient = None
        if self._client is not None:
            self._client.close()
            self._client = None

    # -------------------------------------------------------------------
    # HTTP helpers
    # -------------------------------------------------------------------

    def _build_headers(self, extra_headers: dict[str, str]) -> dict[str, str]:
        headers: dict[str, str] = {"Content-Type": "application/json"}
        if self._api_key:
            headers["Authorization"] = f"Bearer {self._api_key}"
        if extra_headers:
            headers.update(extra_headers)
        return headers

    def _resolve_timeout(self, per_request: float | None) -> httpx.Timeout:
        return lazy.httpx.Timeout(per_request if per_request is not None else self._timeout_s)
```
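Because `_build_headers` applies `extra_headers` last, a per-request `Authorization` or `Content-Type` silently overrides the adapter defaults. A standalone sketch of that merge order (free function rather than the adapter's method, written here only to pin down the precedence):

```python
from __future__ import annotations


def build_headers(api_key: str | None, extra_headers: dict[str, str]) -> dict[str, str]:
    """Illustrative mirror of the adapter's header merge: defaults first,
    then bearer auth, then caller-supplied headers, which win on conflict."""
    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    headers.update(extra_headers)
    return headers
```

This "last write wins" ordering is the usual choice for per-request overrides; callers that pass their own `Authorization` header are assumed to know what they are doing.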
```python
    def _post_sync(
        self,
        route: str,
        payload: dict[str, Any],
        extra_headers: dict[str, str],
        model_name: str,
        timeout: float | None = None,
    ) -> dict[str, Any]:
        url = f"{self._endpoint}{route}"
        headers = self._build_headers(extra_headers)
        try:
            response = self._get_sync_client().post(
                url, json=payload, headers=headers, timeout=self._resolve_timeout(timeout)
            )
        except Exception as exc:
            raise _wrap_transport_error(exc, self.provider_name, model_name) from exc
        if response.status_code >= 400:
            raise map_http_error_to_provider_error(
                response=response, provider_name=self.provider_name, model_name=model_name
            )
        return _parse_json_body(response, self.provider_name, model_name)

    async def _apost(
        self,
        route: str,
        payload: dict[str, Any],
        extra_headers: dict[str, str],
        model_name: str,
        timeout: float | None = None,
    ) -> dict[str, Any]:
        url = f"{self._endpoint}{route}"
        headers = self._build_headers(extra_headers)
        try:
            response = await self._get_async_client().post(
                url, json=payload, headers=headers, timeout=self._resolve_timeout(timeout)
            )
        except Exception as exc:
            raise _wrap_transport_error(exc, self.provider_name, model_name) from exc
        if response.status_code >= 400:
            raise map_http_error_to_provider_error(
                response=response, provider_name=self.provider_name, model_name=model_name
            )
        return _parse_json_body(response, self.provider_name, model_name)
```
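Retries themselves happen below these helpers, inside the transport built by `create_retry_transport` (not shown in this diff). Per commit aeac3b9, that layer understands both forms of the `Retry-After` header. A stdlib-only sketch of that parsing, illustrative rather than the project's actual implementation:

```python
from __future__ import annotations

from datetime import datetime, timezone
from email.utils import parsedate_to_datetime


def parse_retry_after(value: str, now: datetime | None = None) -> float | None:
    """Return a delay in seconds from a Retry-After header, or None if unparseable.

    RFC 9110 allows either delta-seconds ("120") or an HTTP-date
    ("Wed, 21 Oct 2026 07:28:00 GMT"); the HTTP-date form is converted
    to a non-negative delay relative to ``now``.
    """
    value = value.strip()
    if value.isdigit():
        return float(value)
    try:
        target = parsedate_to_datetime(value)
    except (TypeError, ValueError):  # exception type varies across Python versions
        return None
    if target.tzinfo is None:
        target = target.replace(tzinfo=timezone.utc)
    now = now or datetime.now(timezone.utc)
    return max(0.0, (target - now).total_seconds())
```

Clamping to zero matters for the HTTP-date form: a date in the past should mean "retry now", not a negative sleep.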
```python
# ---------------------------------------------------------------------------
# Response parsing helpers
# ---------------------------------------------------------------------------


def _parse_embedding_json(response_json: dict[str, Any]) -> EmbeddingResponse:
    data = response_json.get("data") or []
    vectors = [extract_embedding_vector(item) for item in data]
    usage = extract_usage(response_json.get("usage"))
    return EmbeddingResponse(vectors=vectors, usage=usage, raw=response_json)


def _parse_image_json(response_json: dict[str, Any], *, is_chat_route: bool) -> ImageGenerationResponse:
    if is_chat_route:
        images = extract_images_from_chat_response(response_json)
    else:
        images = extract_images_from_image_response(response_json)
    usage = extract_usage(response_json.get("usage"), generated_images=len(images))
    return ImageGenerationResponse(images=images, usage=usage, raw=response_json)


async def _aparse_image_json(response_json: dict[str, Any], *, is_chat_route: bool) -> ImageGenerationResponse:
    if is_chat_route:
        images = await aextract_images_from_chat_response(response_json)
    else:
        images = await aextract_images_from_image_response(response_json)
    usage = extract_usage(response_json.get("usage"), generated_images=len(images))
    return ImageGenerationResponse(images=images, usage=usage, raw=response_json)


def _parse_json_body(response: httpx.Response, provider_name: str, model_name: str) -> dict[str, Any]:
    """Parse JSON from a successful HTTP response, wrapping decode errors as ``ProviderError``."""
    try:
        return response.json()
    except Exception as exc:
        raise ProviderError(
            kind=ProviderErrorKind.API_ERROR,
            message=f"Provider {provider_name!r} returned a non-JSON response (status {response.status_code}).",
            status_code=response.status_code,
            provider_name=provider_name,
            model_name=model_name,
            cause=exc,
        ) from exc


def _wrap_transport_error(exc: Exception, provider_name: str, model_name: str) -> ProviderError:
    """Convert httpx transport exceptions into canonical ``ProviderError``."""
    return ProviderError(
        kind=infer_error_kind_from_exception(exc),
        message=str(exc) or f"Transport error from provider {provider_name!r}",
        provider_name=provider_name,
        model_name=model_name,
        cause=exc,
    )
```
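The actual classification logic lives in `map_http_error_to_provider_error` and `infer_error_kind_from_exception`, which are imported from the errors module and not part of this diff. As an illustration of the kind of status-code bucketing such a mapper typically performs (the enum and its values here are hypothetical; the project's `ProviderErrorKind` may differ):

```python
from __future__ import annotations

from enum import Enum


class ErrorKind(Enum):
    # Hypothetical kinds; the project's ProviderErrorKind is not shown in this diff.
    AUTH = "auth"
    RATE_LIMIT = "rate_limit"
    BAD_REQUEST = "bad_request"
    API_ERROR = "api_error"


def classify_status(status_code: int) -> ErrorKind:
    """Map an HTTP status code to a coarse error kind."""
    if status_code in (401, 403):
        return ErrorKind.AUTH
    if status_code == 429:
        return ErrorKind.RATE_LIMIT
    if 400 <= status_code < 500:
        return ErrorKind.BAD_REQUEST
    return ErrorKind.API_ERROR  # 5xx and anything else
```

Splitting kinds this way lets the retry and throttle layers react differently: 429 feeds the AIMD backoff, auth and 4xx errors fail fast, and 5xx errors are retry candidates.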