From 58cefcab587201124713ccc88374a05428a6f123 Mon Sep 17 00:00:00 2001 From: David Gardner Date: Tue, 24 Feb 2026 09:26:09 -0800 Subject: [PATCH 01/24] Optionally disable SSL verification, this is useful when using a self-hosted LLM using a self-signed SSL certificate Signed-off-by: David Gardner --- packages/nvidia_nat_core/src/nat/llm/utils/hooks.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/packages/nvidia_nat_core/src/nat/llm/utils/hooks.py b/packages/nvidia_nat_core/src/nat/llm/utils/hooks.py index b0559f2609..f7be89c1d5 100644 --- a/packages/nvidia_nat_core/src/nat/llm/utils/hooks.py +++ b/packages/nvidia_nat_core/src/nat/llm/utils/hooks.py @@ -21,6 +21,7 @@ """ import logging +import os from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -30,6 +31,11 @@ logger = logging.getLogger(__name__) +# Lazily read environment variable to disable SSL verification for LLM requests later. +# When the value is None, the environment variable has not been read yet. When it's a boolean, it indicates whether +# SSL verification should be disabled. 
+_DISABLE_SSL_VERIFICATION = None + def create_metadata_injection_client(timeout: float = 600.0) -> "httpx.AsyncClient": """ @@ -46,6 +52,11 @@ def create_metadata_injection_client(timeout: float = 600.0) -> "httpx.AsyncClie """ import httpx + global _DISABLE_SSL_VERIFICATION + if _DISABLE_SSL_VERIFICATION is None: + env_value = os.getenv("NAT_DISABLE_SSL_VERIFICATION", "").lower() + _DISABLE_SSL_VERIFICATION = (env_value in ("1", "true", "yes")) + from nat.builder.context import ContextState async def on_request(request: httpx.Request) -> None: @@ -66,4 +77,5 @@ async def on_request(request: httpx.Request) -> None: return httpx.AsyncClient( event_hooks={"request": [on_request]}, timeout=httpx.Timeout(timeout), + verify=not _DISABLE_SSL_VERIFICATION ) From fbe530292680a386d975ca8d3d6486c0d85603d9 Mon Sep 17 00:00:00 2001 From: David Gardner Date: Thu, 5 Mar 2026 09:08:28 -0800 Subject: [PATCH 02/24] First pass at a mixin for disabling SSL verification Signed-off-by: David Gardner --- packages/nvidia_nat_core/src/nat/data_models/llm.py | 9 +++++++++ packages/nvidia_nat_core/src/nat/llm/litellm_llm.py | 2 ++ packages/nvidia_nat_core/src/nat/llm/nim_llm.py | 3 ++- packages/nvidia_nat_core/src/nat/llm/openai_llm.py | 3 ++- 4 files changed, 15 insertions(+), 2 deletions(-) diff --git a/packages/nvidia_nat_core/src/nat/data_models/llm.py b/packages/nvidia_nat_core/src/nat/data_models/llm.py index a64422825a..fa591458fb 100644 --- a/packages/nvidia_nat_core/src/nat/data_models/llm.py +++ b/packages/nvidia_nat_core/src/nat/data_models/llm.py @@ -16,6 +16,7 @@ import typing from enum import StrEnum +from pydantic import BaseModel from pydantic import Field from .common import BaseModelRegistryTag @@ -39,3 +40,11 @@ class LLMBaseConfig(TypedBaseModel, BaseModelRegistryTag): LLMBaseConfigT = typing.TypeVar("LLMBaseConfigT", bound=LLMBaseConfig) + +class SSLVerificationMixin(BaseModel): + """Mixin for SSL verification configuration.""" + + verify_ssl: bool = Field( + 
default=True, + description="Whether to verify SSL certificates when making API calls to the LLM provider. Defaults to True.", + ) \ No newline at end of file diff --git a/packages/nvidia_nat_core/src/nat/llm/litellm_llm.py b/packages/nvidia_nat_core/src/nat/llm/litellm_llm.py index 892bb3a2a6..95f38e6224 100644 --- a/packages/nvidia_nat_core/src/nat/llm/litellm_llm.py +++ b/packages/nvidia_nat_core/src/nat/llm/litellm_llm.py @@ -24,6 +24,7 @@ from nat.cli.register_workflow import register_llm_provider from nat.data_models.common import OptionalSecretStr from nat.data_models.llm import LLMBaseConfig +from nat.data_models.llm import SSLVerificationMixin from nat.data_models.optimizable import OptimizableField from nat.data_models.optimizable import OptimizableMixin from nat.data_models.optimizable import SearchSpace @@ -36,6 +37,7 @@ class LiteLlmModelConfig( OptimizableMixin, RetryMixin, ThinkingMixin, + SSLVerificationMixin, name="litellm", ): """A LiteLlm provider to be used with an LLM client.""" diff --git a/packages/nvidia_nat_core/src/nat/llm/nim_llm.py b/packages/nvidia_nat_core/src/nat/llm/nim_llm.py index a3257b84d5..c701e42caa 100644 --- a/packages/nvidia_nat_core/src/nat/llm/nim_llm.py +++ b/packages/nvidia_nat_core/src/nat/llm/nim_llm.py @@ -23,6 +23,7 @@ from nat.cli.register_workflow import register_llm_provider from nat.data_models.common import OptionalSecretStr from nat.data_models.llm import LLMBaseConfig +from nat.data_models.llm import SSLVerificationMixin from nat.data_models.optimizable import OptimizableField from nat.data_models.optimizable import OptimizableMixin from nat.data_models.optimizable import SearchSpace @@ -30,7 +31,7 @@ from nat.data_models.thinking_mixin import ThinkingMixin -class NIMModelConfig(LLMBaseConfig, RetryMixin, OptimizableMixin, ThinkingMixin, name="nim"): +class NIMModelConfig(LLMBaseConfig, RetryMixin, OptimizableMixin, ThinkingMixin, SSLVerificationMixin, name="nim"): """An NVIDIA Inference Microservice (NIM) llm 
provider to be used with an LLM client.""" model_config = ConfigDict(protected_namespaces=(), extra="allow") diff --git a/packages/nvidia_nat_core/src/nat/llm/openai_llm.py b/packages/nvidia_nat_core/src/nat/llm/openai_llm.py index 4e32075d41..c5a31de6f1 100644 --- a/packages/nvidia_nat_core/src/nat/llm/openai_llm.py +++ b/packages/nvidia_nat_core/src/nat/llm/openai_llm.py @@ -22,6 +22,7 @@ from nat.cli.register_workflow import register_llm_provider from nat.data_models.common import OptionalSecretStr from nat.data_models.llm import LLMBaseConfig +from nat.data_models.llm import SSLVerificationMixin from nat.data_models.optimizable import OptimizableField from nat.data_models.optimizable import OptimizableMixin from nat.data_models.optimizable import SearchSpace @@ -29,7 +30,7 @@ from nat.data_models.thinking_mixin import ThinkingMixin -class OpenAIModelConfig(LLMBaseConfig, RetryMixin, OptimizableMixin, ThinkingMixin, name="openai"): +class OpenAIModelConfig(LLMBaseConfig, RetryMixin, OptimizableMixin, ThinkingMixin, SSLVerificationMixin, name="openai"): """An OpenAI LLM provider to be used with an LLM client.""" model_config = ConfigDict(protected_namespaces=(), extra="allow") From a81533c86c953e36ee09b4b30b0a74b8ded27f77 Mon Sep 17 00:00:00 2001 From: David Gardner Date: Thu, 5 Mar 2026 09:49:41 -0800 Subject: [PATCH 03/24] Update create_metadata_injection_client to not use the env variable Signed-off-by: David Gardner --- .../src/nat/llm/utils/hooks.py | 16 ++---- .../src/nat/plugins/langchain/llm.py | 50 +++++++++++++------ 2 files changed, 37 insertions(+), 29 deletions(-) diff --git a/packages/nvidia_nat_core/src/nat/llm/utils/hooks.py b/packages/nvidia_nat_core/src/nat/llm/utils/hooks.py index f7be89c1d5..cd5bba2673 100644 --- a/packages/nvidia_nat_core/src/nat/llm/utils/hooks.py +++ b/packages/nvidia_nat_core/src/nat/llm/utils/hooks.py @@ -21,7 +21,6 @@ """ import logging -import os from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -31,13 +30,8 @@ 
logger = logging.getLogger(__name__) -# Lazily read environment variable to disable SSL verification for LLM requests later. -# When the value is None, the environment variable has not been read yet. When it's a boolean, it indicates whether -# SSL verification should be disabled. -_DISABLE_SSL_VERIFICATION = None - -def create_metadata_injection_client(timeout: float = 600.0) -> "httpx.AsyncClient": +def create_metadata_injection_client(timeout: float = 600.0, verify_ssl: bool = True) -> "httpx.AsyncClient": """ Httpx event hook that injects custom metadata as HTTP headers. @@ -46,17 +40,13 @@ def create_metadata_injection_client(timeout: float = 600.0) -> "httpx.AsyncClie Args: timeout: HTTP request timeout in seconds + verify_ssl: Whether to verify SSL certificates Returns: An httpx.AsyncClient configured with metadata header injection """ import httpx - global _DISABLE_SSL_VERIFICATION - if _DISABLE_SSL_VERIFICATION is None: - env_value = os.getenv("NAT_DISABLE_SSL_VERIFICATION", "").lower() - _DISABLE_SSL_VERIFICATION = (env_value in ("1", "true", "yes")) - from nat.builder.context import ContextState async def on_request(request: httpx.Request) -> None: @@ -77,5 +67,5 @@ async def on_request(request: httpx.Request) -> None: return httpx.AsyncClient( event_hooks={"request": [on_request]}, timeout=httpx.Timeout(timeout), - verify=not _DISABLE_SSL_VERIFICATION + verify=verify_ssl ) diff --git a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py index 5af5fe861e..cfe7d519ea 100644 --- a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py +++ b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py @@ -23,15 +23,11 @@ from typing import Any from typing import TypeVar -if TYPE_CHECKING: - import httpx - from nat.builder.builder import Builder from nat.builder.framework_enum import LLMFrameworkEnum from nat.cli.register_workflow import register_llm_client from 
nat.data_models.common import get_secret_value from nat.data_models.llm import APITypeEnum -from nat.data_models.llm import LLMBaseConfig from nat.data_models.retry_mixin import RetryMixin from nat.data_models.thinking_mixin import ThinkingMixin from nat.llm.aws_bedrock_llm import AWSBedrockModelConfig @@ -51,12 +47,18 @@ from nat.utils.responses_api import validate_no_responses_api from nat.utils.type_utils import override + +if TYPE_CHECKING: + import httpx + from nat.data_models.llm import LLMBaseConfig + + logger = logging.getLogger(__name__) ModelType = TypeVar("ModelType") -def _patch_llm_based_on_config(client: ModelType, llm_config: LLMBaseConfig) -> ModelType: +def _patch_llm_based_on_config(client: ModelType, llm_config: "LLMBaseConfig") -> ModelType: from langchain_core.language_models import LanguageModelInput from langchain_core.messages import BaseMessage @@ -125,6 +127,30 @@ def inject(self, messages: LanguageModelInput, *args, **kwargs) -> FunctionArgum return client +def _create_metadata_injection_client(llm_config: "LLMBaseConfig") -> "httpx.AsyncClient": + """ + Create an httpx.AsyncClient with event hooks to inject custom metadata as HTTP headers. + + This client injects custom payload fields as X-Payload-* HTTP headers, + enabling end-to-end traceability in LLM server logs. 
+ + Args: + llm_config: The LLM configuration containing timeout and SSL verification settings + + Returns: + An httpx.AsyncClient configured with metadata header injection + """ + client_kwargs: dict = {} + + if hasattr(llm_config, "verify_ssl"): + client_kwargs["verify_ssl"] = llm_config.verify_ssl + + if llm_config.request_timeout is not None: + client_kwargs["timeout"] = llm_config.request_timeout + + return create_metadata_injection_client(**client_kwargs) + + @register_llm_client(config_type=AWSBedrockModelConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN) async def aws_bedrock_langchain(llm_config: AWSBedrockModelConfig, _builder: Builder): @@ -149,17 +175,12 @@ async def azure_openai_langchain(llm_config: AzureOpenAIModelConfig, _builder: B validate_no_responses_api(llm_config, LLMFrameworkEnum.LANGCHAIN) - client_kwargs: dict = {} - if llm_config.request_timeout is not None: - client_kwargs["timeout"] = llm_config.request_timeout - http_async_client: httpx.AsyncClient = create_metadata_injection_client(**client_kwargs) - try: client = AzureChatOpenAI( http_async_client=http_async_client, # type: ignore[call-arg] api_version=llm_config.api_version, # type: ignore[call-arg] **llm_config.model_dump( - exclude={"type", "thinking", "api_type", "api_version"}, + exclude={"type", "thinking", "api_type", "api_version", "verify_ssl"}, by_alias=True, exclude_none=True, exclude_unset=True, @@ -199,13 +220,10 @@ async def openai_langchain(llm_config: OpenAIModelConfig, _builder: Builder): from langchain_openai import ChatOpenAI - client_kwargs: dict = {} - if llm_config.request_timeout is not None: - client_kwargs["timeout"] = llm_config.request_timeout - http_async_client: httpx.AsyncClient = create_metadata_injection_client(**client_kwargs) + http_async_client: httpx.AsyncClient = _create_metadata_injection_client(llm_config) config_dict = llm_config.model_dump( - exclude={"type", "thinking", "api_type", "api_key", "base_url"}, + exclude={"type", "thinking", 
"api_type", "api_key", "base_url", "verify_ssl"}, by_alias=True, exclude_none=True, exclude_unset=True, From a53a2b99865887c4e2141e20a296d61f88b057f4 Mon Sep 17 00:00:00 2001 From: David Gardner Date: Thu, 5 Mar 2026 11:10:02 -0800 Subject: [PATCH 04/24] First pass at ADK Signed-off-by: David Gardner --- .../nvidia_nat_adk/src/nat/plugins/adk/llm.py | 37 +++++++++++++------ 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/packages/nvidia_nat_adk/src/nat/plugins/adk/llm.py b/packages/nvidia_nat_adk/src/nat/plugins/adk/llm.py index c68a0ca251..88f20d3e98 100644 --- a/packages/nvidia_nat_adk/src/nat/plugins/adk/llm.py +++ b/packages/nvidia_nat_adk/src/nat/plugins/adk/llm.py @@ -15,6 +15,7 @@ import logging import os +import typing from nat.builder.builder import Builder from nat.builder.framework_enum import LLMFrameworkEnum @@ -43,15 +44,16 @@ async def azure_openai_adk(config: AzureOpenAIModelConfig, _builder: Builder): config_dict = config.model_dump( exclude={ - "type", - "max_retries", - "thinking", - "azure_endpoint", + "api_type", "azure_deployment", - "model_name", + "azure_endpoint", + "max_retries", "model", - "api_type", - "request_timeout" + "model_name", + "request_timeout", + "thinking", + "type", + "verify_ssl" }, by_alias=True, exclude_none=True, @@ -63,6 +65,7 @@ async def azure_openai_adk(config: AzureOpenAIModelConfig, _builder: Builder): config_dict["timeout"] = config.request_timeout config_dict["api_version"] = config.api_version + config_dict["ssl_verify"] = config.verify_ssl yield LiteLlm(f"azure/{config.azure_deployment}", **config_dict) @@ -73,12 +76,13 @@ async def litellm_adk(litellm_config: LiteLlmModelConfig, _builder: Builder): validate_no_responses_api(litellm_config, LLMFrameworkEnum.ADK) + kwargs = {"ssl_verify": litellm_config.verify_ssl} yield LiteLlm(**litellm_config.model_dump( - exclude={"type", "max_retries", "thinking", "api_type"}, + exclude={"api_type", "max_retries", "thinking", "type", "verify_ssl"}, 
by_alias=True, exclude_none=True, exclude_unset=True, - )) + ), **kwargs) @register_llm_client(config_type=NIMModelConfig, wrapper_type=LLMFrameworkEnum.ADK) @@ -102,7 +106,7 @@ async def nim_adk(config: NIMModelConfig, _builder: Builder): os.environ["NVIDIA_NIM_API_KEY"] = api_key config_dict = config.model_dump( - exclude={"type", "max_retries", "thinking", "model_name", "model", "base_url", "api_type"}, + exclude={"api_type", "base_url", "max_retries", "model", "model_name", "thinking", "type", "verify_ssl"}, by_alias=True, exclude_none=True, exclude_unset=True, @@ -110,6 +114,8 @@ async def nim_adk(config: NIMModelConfig, _builder: Builder): if config.base_url: config_dict["api_base"] = config.base_url + config_dict["ssl_verify"] = config.verify_ssl + yield LiteLlm(f"nvidia_nim/{config.model_name}", **config_dict) @@ -126,7 +132,15 @@ async def openai_adk(config: OpenAIModelConfig, _builder: Builder): validate_no_responses_api(config, LLMFrameworkEnum.ADK) config_dict = config.model_dump( - exclude={"type", "max_retries", "thinking", "model_name", "model", "base_url", "api_type", "request_timeout"}, + exclude={"api_type", + "base_url", + "max_retries", + "model", + "model_name", + "request_timeout", + "thinking", + "type", + "verify_ssl"}, by_alias=True, exclude_none=True, exclude_unset=True, @@ -138,6 +152,7 @@ async def openai_adk(config: OpenAIModelConfig, _builder: Builder): config_dict["api_base"] = base_url if config.request_timeout is not None: config_dict["timeout"] = config.request_timeout + config_dict["ssl_verify"] = config.verify_ssl yield LiteLlm(config.model_name, **config_dict) From 08c3d551a68b86ddef9d0c24253ef06b6d53f999 Mon Sep 17 00:00:00 2001 From: David Gardner Date: Thu, 5 Mar 2026 11:28:34 -0800 Subject: [PATCH 05/24] First pass at agno Signed-off-by: David Gardner --- .../src/nat/plugins/agno/llm.py | 53 ++++++++++++++++--- 1 file changed, 47 insertions(+), 6 deletions(-) diff --git a/packages/nvidia_nat_agno/src/nat/plugins/agno/llm.py 
b/packages/nvidia_nat_agno/src/nat/plugins/agno/llm.py index fa4b70bc63..29c33e443f 100644 --- a/packages/nvidia_nat_agno/src/nat/plugins/agno/llm.py +++ b/packages/nvidia_nat_agno/src/nat/plugins/agno/llm.py @@ -14,7 +14,7 @@ # limitations under the License. import os -from typing import TypeVar +import typing from nat.builder.builder import Builder from nat.builder.framework_enum import LLMFrameworkEnum @@ -34,7 +34,10 @@ from nat.utils.responses_api import validate_no_responses_api from nat.utils.type_utils import override -ModelType = TypeVar("ModelType") +if typing.TYPE_CHECKING: + import httpx + +ModelType = typing.TypeVar("ModelType") def _patch_llm_based_on_config(client: ModelType, llm_config: LLMBaseConfig) -> ModelType: @@ -79,6 +82,16 @@ def inject(self, messages: list[Message], *args, **kwargs) -> FunctionArgumentWr return client +def _create_http_client(llm_config: LLMBaseConfig) -> "httpx.AsyncClient": + """Create an httpx.AsyncClient with event hooks to inject custom metadata as HTTP headers.""" + import httpx + + kwargs = {} + if hasattr(llm_config, "verify_ssl"): + kwargs["verify"] = llm_config.verify_ssl + return httpx.AsyncClient(**kwargs) + + @register_llm_client(config_type=NIMModelConfig, wrapper_type=LLMFrameworkEnum.AGNO) async def nim_agno(llm_config: NIMModelConfig, _builder: Builder): @@ -88,14 +101,24 @@ async def nim_agno(llm_config: NIMModelConfig, _builder: Builder): config_obj = { **llm_config.model_dump( - exclude={"type", "model_name", "thinking", "api_type"}, + exclude={ + "api_type", + "model_name", + "thinking", + "type", + "verify_ssl", + }, by_alias=True, exclude_none=True, exclude_unset=True, ), + "http_client": + _create_http_client(llm_config), + "id": + llm_config.model_name } - client = Nvidia(**config_obj, id=llm_config.model_name) + client = Nvidia(**config_obj) yield _patch_llm_based_on_config(client, llm_config) @@ -108,11 +131,22 @@ async def openai_agno(llm_config: OpenAIModelConfig, _builder: Builder): config_obj 
= { **llm_config.model_dump( - exclude={"type", "model_name", "thinking", "api_type", "api_key", "base_url", "request_timeout"}, + exclude={ + "api_key", + "api_type", + "base_url", + "model_name", + "request_timeout", + "thinking", + "type", + "verify_ssl", + }, by_alias=True, exclude_none=True, exclude_unset=True, ), + "http_client": + _create_http_client(llm_config), } if (api_key := get_secret_value(llm_config.api_key) or os.getenv("OPENAI_API_KEY")): @@ -139,11 +173,18 @@ async def litellm_agno(llm_config: LiteLlmModelConfig, _builder: Builder): client = LiteLLM( **llm_config.model_dump( - exclude={"type", "thinking", "model_name", "api_type"}, + exclude={ + "api_type", + "model_name", + "thinking", + "type", + "verify_ssl", + }, by_alias=True, exclude_none=True, exclude_unset=True, ), + http_client=_create_http_client(llm_config), id=llm_config.model_name, ) From 399c471d471cb01acb9949002e40fd6c86b6076b Mon Sep 17 00:00:00 2001 From: David Gardner Date: Thu, 5 Mar 2026 13:45:07 -0800 Subject: [PATCH 06/24] Add _create_http_client helper method, refactor create_metadata_injection_client to use it Signed-off-by: David Gardner --- .../src/nat/llm/utils/hooks.py | 12 ++---- .../src/nat/llm/utils/http_client.py | 42 +++++++++++++++++++ 2 files changed, 46 insertions(+), 8 deletions(-) create mode 100644 packages/nvidia_nat_core/src/nat/llm/utils/http_client.py diff --git a/packages/nvidia_nat_core/src/nat/llm/utils/hooks.py b/packages/nvidia_nat_core/src/nat/llm/utils/hooks.py index cd5bba2673..a45253921a 100644 --- a/packages/nvidia_nat_core/src/nat/llm/utils/hooks.py +++ b/packages/nvidia_nat_core/src/nat/llm/utils/hooks.py @@ -27,11 +27,12 @@ import httpx from nat.llm.utils.constants import LLMHeaderPrefix +from nat.llm.utils.http_client import _create_http_client logger = logging.getLogger(__name__) -def create_metadata_injection_client(timeout: float = 600.0, verify_ssl: bool = True) -> "httpx.AsyncClient": +def create_metadata_injection_client(llm_config: 
"LLMBaseConfig") -> "httpx.AsyncClient": """ Httpx event hook that injects custom metadata as HTTP headers. @@ -39,8 +40,7 @@ def create_metadata_injection_client(timeout: float = 600.0, verify_ssl: bool = enabling end-to-end traceability in LLM server logs. Args: - timeout: HTTP request timeout in seconds - verify_ssl: Whether to verify SSL certificates + llm_config: LLM configuration object Returns: An httpx.AsyncClient configured with metadata header injection @@ -64,8 +64,4 @@ async def on_request(request: httpx.Request) -> None: except Exception as e: logger.debug("Could not inject custom metadata headers, request will proceed without them: %s", e) - return httpx.AsyncClient( - event_hooks={"request": [on_request]}, - timeout=httpx.Timeout(timeout), - verify=verify_ssl - ) + return _create_http_client(llm_config=llm_config, use_async=True, event_hooks={"request": [on_request]}) diff --git a/packages/nvidia_nat_core/src/nat/llm/utils/http_client.py b/packages/nvidia_nat_core/src/nat/llm/utils/http_client.py new file mode 100644 index 0000000000..8e48bc269e --- /dev/null +++ b/packages/nvidia_nat_core/src/nat/llm/utils/http_client.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import httpx + + from nat.data_models.llm import LLMBaseConfig + + +def _create_http_client(llm_config: "LLMBaseConfig", + use_async: bool = True, + **kwargs) -> "httpx.AsyncClient | httpx.Client": + """Create an httpx.AsyncClient based on LLM configuration.""" + import httpx + + def _set_kwarg(kwarg_name: str, config_attr: str): + if kwarg_name not in kwargs and getattr(llm_config, config_attr, None) is not None: + kwargs[kwarg_name] = getattr(llm_config, config_attr) + + _set_kwarg("verify", "verify_ssl") + _set_kwarg("timeout", "request_timeout") + + if use_async: + client_class = httpx.AsyncClient + else: + client_class = httpx.Client + + return client_class(**kwargs) From a2df04f19d0a9e8fe487a511a5b083aed3b8b2f1 Mon Sep 17 00:00:00 2001 From: David Gardner Date: Thu, 5 Mar 2026 13:47:51 -0800 Subject: [PATCH 07/24] Update to use create_metadata_injection_client and _create_http_client Signed-off-by: David Gardner --- .../src/nat/plugins/agno/llm.py | 14 +-------- .../src/nat/plugins/langchain/llm.py | 31 +++---------------- 2 files changed, 5 insertions(+), 40 deletions(-) diff --git a/packages/nvidia_nat_agno/src/nat/plugins/agno/llm.py b/packages/nvidia_nat_agno/src/nat/plugins/agno/llm.py index 29c33e443f..fb8b4f74b7 100644 --- a/packages/nvidia_nat_agno/src/nat/plugins/agno/llm.py +++ b/packages/nvidia_nat_agno/src/nat/plugins/agno/llm.py @@ -27,6 +27,7 @@ from nat.llm.litellm_llm import LiteLlmModelConfig from nat.llm.nim_llm import NIMModelConfig from nat.llm.openai_llm import OpenAIModelConfig +from nat.llm.utils.http_client import _create_http_client from nat.llm.utils.thinking import BaseThinkingInjector from nat.llm.utils.thinking import FunctionArgumentWrapper from nat.llm.utils.thinking import patch_with_thinking @@ -34,9 +35,6 @@ from nat.utils.responses_api import validate_no_responses_api from nat.utils.type_utils import override -if typing.TYPE_CHECKING: - import httpx - ModelType = 
typing.TypeVar("ModelType") @@ -82,16 +80,6 @@ def inject(self, messages: list[Message], *args, **kwargs) -> FunctionArgumentWr return client -def _create_http_client(llm_config: LLMBaseConfig) -> "httpx.AsyncClient": - """Create an httpx.AsyncClient with event hooks to inject custom metadata as HTTP headers.""" - import httpx - - kwargs = {} - if hasattr(llm_config, "verify_ssl"): - kwargs["verify"] = llm_config.verify_ssl - return httpx.AsyncClient(**kwargs) - - @register_llm_client(config_type=NIMModelConfig, wrapper_type=LLMFrameworkEnum.AGNO) async def nim_agno(llm_config: NIMModelConfig, _builder: Builder): diff --git a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py index cfe7d519ea..ed82f1e0e6 100644 --- a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py +++ b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py @@ -47,11 +47,10 @@ from nat.utils.responses_api import validate_no_responses_api from nat.utils.type_utils import override - if TYPE_CHECKING: import httpx - from nat.data_models.llm import LLMBaseConfig + from nat.data_models.llm import LLMBaseConfig logger = logging.getLogger(__name__) @@ -127,30 +126,6 @@ def inject(self, messages: LanguageModelInput, *args, **kwargs) -> FunctionArgum return client -def _create_metadata_injection_client(llm_config: "LLMBaseConfig") -> "httpx.AsyncClient": - """ - Create an httpx.AsyncClient with event hooks to inject custom metadata as HTTP headers. - - This client injects custom payload fields as X-Payload-* HTTP headers, - enabling end-to-end traceability in LLM server logs. 
- - Args: - llm_config: The LLM configuration containing timeout and SSL verification settings - - Returns: - An httpx.AsyncClient configured with metadata header injection - """ - client_kwargs: dict = {} - - if hasattr(llm_config, "verify_ssl"): - client_kwargs["verify_ssl"] = llm_config.verify_ssl - - if llm_config.request_timeout is not None: - client_kwargs["timeout"] = llm_config.request_timeout - - return create_metadata_injection_client(**client_kwargs) - - @register_llm_client(config_type=AWSBedrockModelConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN) async def aws_bedrock_langchain(llm_config: AWSBedrockModelConfig, _builder: Builder): @@ -175,6 +150,8 @@ async def azure_openai_langchain(llm_config: AzureOpenAIModelConfig, _builder: B validate_no_responses_api(llm_config, LLMFrameworkEnum.LANGCHAIN) + http_async_client: httpx.AsyncClient = create_metadata_injection_client(llm_config) + try: client = AzureChatOpenAI( http_async_client=http_async_client, # type: ignore[call-arg] @@ -220,7 +197,7 @@ async def openai_langchain(llm_config: OpenAIModelConfig, _builder: Builder): from langchain_openai import ChatOpenAI - http_async_client: httpx.AsyncClient = _create_metadata_injection_client(llm_config) + http_async_client: httpx.AsyncClient = create_metadata_injection_client(llm_config) config_dict = llm_config.model_dump( exclude={"type", "thinking", "api_type", "api_key", "base_url", "verify_ssl"}, From 9270f77610e27adb1739c66dbe6d311afea0f27b Mon Sep 17 00:00:00 2001 From: David Gardner Date: Thu, 5 Mar 2026 13:52:14 -0800 Subject: [PATCH 08/24] Rename create_metadata_injection_client to indicate that it is not part of the public API Signed-off-by: David Gardner --- .../nvidia_nat_core/src/nat/llm/utils/hooks.py | 2 +- .../tests/nat/llm/utils/test_hooks.py | 16 ++++++++-------- .../src/nat/plugins/langchain/llm.py | 6 +++--- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/packages/nvidia_nat_core/src/nat/llm/utils/hooks.py 
b/packages/nvidia_nat_core/src/nat/llm/utils/hooks.py index a45253921a..3a3da6ae34 100644 --- a/packages/nvidia_nat_core/src/nat/llm/utils/hooks.py +++ b/packages/nvidia_nat_core/src/nat/llm/utils/hooks.py @@ -32,7 +32,7 @@ logger = logging.getLogger(__name__) -def create_metadata_injection_client(llm_config: "LLMBaseConfig") -> "httpx.AsyncClient": +def _create_metadata_injection_client(llm_config: "LLMBaseConfig") -> "httpx.AsyncClient": """ Httpx event hook that injects custom metadata as HTTP headers. diff --git a/packages/nvidia_nat_core/tests/nat/llm/utils/test_hooks.py b/packages/nvidia_nat_core/tests/nat/llm/utils/test_hooks.py index d645fc687c..0c7c2bb2a7 100644 --- a/packages/nvidia_nat_core/tests/nat/llm/utils/test_hooks.py +++ b/packages/nvidia_nat_core/tests/nat/llm/utils/test_hooks.py @@ -20,7 +20,7 @@ from pytest_httpserver import HTTPServer from nat.builder.context import ContextState -from nat.llm.utils.hooks import create_metadata_injection_client +from nat.llm.utils.hooks import _create_metadata_injection_client class TestMetadataInjectionHook: @@ -46,7 +46,7 @@ def fixture_mock_input_message(self): async def test_hook_injects_metadata_fields(self, mock_httpx_request, mock_input_message): """Test that the hook injects custom metadata fields as headers.""" - client = create_metadata_injection_client() + client = _create_metadata_injection_client() hook = client.event_hooks["request"][0] context_state = ContextState.get() @@ -67,7 +67,7 @@ async def test_hook_skips_none_values(self, mock_httpx_request, mock_input_messa "optional_field": None, } - client = create_metadata_injection_client() + client = _create_metadata_injection_client() hook = client.event_hooks["request"][0] context_state = ContextState.get() @@ -82,7 +82,7 @@ async def test_hook_skips_none_values(self, mock_httpx_request, mock_input_messa async def test_hook_handles_missing_context(self, mock_httpx_request): """Test that hook handles missing context gracefully.""" - client = 
create_metadata_injection_client() + client = _create_metadata_injection_client() hook = client.event_hooks["request"][0] await hook(mock_httpx_request) @@ -94,11 +94,11 @@ async def test_hook_handles_missing_context(self, mock_httpx_request): class TestCreateMetadataInjectionClient: - """Unit tests for create_metadata_injection_client function.""" + """Unit tests for _create_metadata_injection_client function.""" async def test_creates_client_with_event_hooks(self): """Test that client is created with event hooks.""" - client = create_metadata_injection_client() + client = _create_metadata_injection_client() assert "request" in client.event_hooks assert len(client.event_hooks["request"]) == 1 @@ -139,7 +139,7 @@ async def test_headers_sent_in_http_request(self, httpserver: HTTPServer, mock_i } }) - client = create_metadata_injection_client() + client = _create_metadata_injection_client() context_state = ContextState.get() context_state.input_message.set(mock_input_message) @@ -181,7 +181,7 @@ async def test_request_succeeds_without_context(self, httpserver: HTTPServer): } }) - client = create_metadata_injection_client() + client = _create_metadata_injection_client() response = await client.post(httpserver.url_for("/v1/chat/completions"), json={ diff --git a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py index ed82f1e0e6..716ff14a1a 100644 --- a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py +++ b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py @@ -39,7 +39,7 @@ from nat.llm.litellm_llm import LiteLlmModelConfig from nat.llm.nim_llm import NIMModelConfig from nat.llm.openai_llm import OpenAIModelConfig -from nat.llm.utils.hooks import create_metadata_injection_client +from nat.llm.utils.hooks import _create_metadata_injection_client from nat.llm.utils.thinking import BaseThinkingInjector from nat.llm.utils.thinking import FunctionArgumentWrapper from 
nat.llm.utils.thinking import patch_with_thinking @@ -150,7 +150,7 @@ async def azure_openai_langchain(llm_config: AzureOpenAIModelConfig, _builder: B validate_no_responses_api(llm_config, LLMFrameworkEnum.LANGCHAIN) - http_async_client: httpx.AsyncClient = create_metadata_injection_client(llm_config) + http_async_client: httpx.AsyncClient = _create_metadata_injection_client(llm_config) try: client = AzureChatOpenAI( @@ -197,7 +197,7 @@ async def openai_langchain(llm_config: OpenAIModelConfig, _builder: Builder): from langchain_openai import ChatOpenAI - http_async_client: httpx.AsyncClient = create_metadata_injection_client(llm_config) + http_async_client: httpx.AsyncClient = _create_metadata_injection_client(llm_config) config_dict = llm_config.model_dump( exclude={"type", "thinking", "api_type", "api_key", "base_url", "verify_ssl"}, From 6689f8218848df971e0344466f5f2c22dbe6893f Mon Sep 17 00:00:00 2001 From: David Gardner Date: Thu, 5 Mar 2026 14:23:32 -0800 Subject: [PATCH 09/24] First pass at autogen Signed-off-by: David Gardner --- .../src/nat/plugins/autogen/llm.py | 29 +++++++++++++++---- 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/packages/nvidia_nat_autogen/src/nat/plugins/autogen/llm.py b/packages/nvidia_nat_autogen/src/nat/plugins/autogen/llm.py index 7fb64ed6c7..02727e9da4 100644 --- a/packages/nvidia_nat_autogen/src/nat/plugins/autogen/llm.py +++ b/packages/nvidia_nat_autogen/src/nat/plugins/autogen/llm.py @@ -48,6 +48,7 @@ from nat.llm.litellm_llm import LiteLlmModelConfig from nat.llm.nim_llm import NIMModelConfig from nat.llm.openai_llm import OpenAIModelConfig +from nat.llm.utils.http_client import _create_http_client from nat.llm.utils.thinking import BaseThinkingInjector from nat.llm.utils.thinking import FunctionArgumentWrapper from nat.llm.utils.thinking import patch_with_thinking @@ -147,10 +148,20 @@ async def openai_autogen(llm_config: OpenAIModelConfig, _builder: Builder) -> As # Extract AutoGen-compatible configuration 
config_obj = { **llm_config.model_dump( - exclude={"type", "model_name", "thinking", "api_key", "base_url", "request_timeout"}, + exclude={ + "api_key", + "base_url", + "model_name", + "request_timeout", + "thinking", + "type", + "verify_ssl", + }, by_alias=True, exclude_none=True, ), + "http_client": + _create_http_client(llm_config) } if (api_key := get_secret_value(llm_config.api_key) or os.getenv("OPENAI_API_KEY")): @@ -201,12 +212,14 @@ async def azure_openai_autogen(llm_config: AzureOpenAIModelConfig, config_obj = { "api_key": llm_config.api_key, - "base_url": - f"{llm_config.azure_endpoint}/openai/deployments/{llm_config.azure_deployment}", "api_version": llm_config.api_version, + "base_url": + f"{llm_config.azure_endpoint}/openai/deployments/{llm_config.azure_deployment}", + "http_client": + _create_http_client(llm_config), **llm_config.model_dump( - exclude={"type", "azure_deployment", "thinking", "azure_endpoint", "api_version", "request_timeout"}, + exclude={"api_version", "azure_deployment", "azure_endpoint", "request_timeout", "thinking", "type"}, by_alias=True, exclude_none=True, ), @@ -326,8 +339,10 @@ async def nim_autogen(llm_config: NIMModelConfig, _builder: Builder) -> AsyncGen # Extract NIM configuration for OpenAI-compatible client config_obj = { + "http_client": + _create_http_client(llm_config), **llm_config.model_dump( - exclude={"type", "model_name", "thinking"}, + exclude={"model_name", "thinking", "type"}, by_alias=True, exclude_none=True, exclude_unset=True, @@ -386,8 +401,10 @@ async def litellm_autogen(llm_config: LiteLlmModelConfig, _builder: Builder) -> # Extract LiteLLM configuration for OpenAI-compatible client config_obj = { + "http_client": + _create_http_client(llm_config), **llm_config.model_dump( - exclude={"type", "model_name", "thinking"}, + exclude={"model_name", "thinking", "type"}, by_alias=True, exclude_none=True, exclude_unset=True, From 07785e7cce91e3bcf6dafbac57cb8fbe7bcdd7ad Mon Sep 17 00:00:00 2001 From: David 
Gardner Date: Thu, 5 Mar 2026 14:28:02 -0800 Subject: [PATCH 10/24] Add SSLVerificationMixin to AzureOpenAIModelConfig Signed-off-by: David Gardner --- packages/nvidia_nat_core/src/nat/llm/azure_openai_llm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/packages/nvidia_nat_core/src/nat/llm/azure_openai_llm.py b/packages/nvidia_nat_core/src/nat/llm/azure_openai_llm.py index 08abc67244..23ed772536 100644 --- a/packages/nvidia_nat_core/src/nat/llm/azure_openai_llm.py +++ b/packages/nvidia_nat_core/src/nat/llm/azure_openai_llm.py @@ -23,6 +23,7 @@ from nat.cli.register_workflow import register_llm_provider from nat.data_models.common import OptionalSecretStr from nat.data_models.llm import LLMBaseConfig +from nat.data_models.llm import SSLVerificationMixin from nat.data_models.optimizable import OptimizableField from nat.data_models.optimizable import SearchSpace from nat.data_models.retry_mixin import RetryMixin @@ -33,6 +34,7 @@ class AzureOpenAIModelConfig( LLMBaseConfig, RetryMixin, ThinkingMixin, + SSLVerificationMixin, name="azure_openai", ): """An Azure OpenAI LLM provider to be used with an LLM client.""" From 6d8c8e3fc8b2af6fb2857a1264e354bc4d88c8f8 Mon Sep 17 00:00:00 2001 From: David Gardner Date: Thu, 5 Mar 2026 15:53:49 -0800 Subject: [PATCH 11/24] Fix ssl handling for ADK, first pass at handling for crewai Signed-off-by: David Gardner --- .../nvidia_nat_adk/src/nat/plugins/adk/llm.py | 32 +++++++++++-------- .../src/nat/llm/utils/http_client.py | 8 +++++ .../src/nat/plugins/crewai/llm.py | 9 ++++++ 3 files changed, 35 insertions(+), 14 deletions(-) diff --git a/packages/nvidia_nat_adk/src/nat/plugins/adk/llm.py b/packages/nvidia_nat_adk/src/nat/plugins/adk/llm.py index 88f20d3e98..8b84abc29c 100644 --- a/packages/nvidia_nat_adk/src/nat/plugins/adk/llm.py +++ b/packages/nvidia_nat_adk/src/nat/plugins/adk/llm.py @@ -25,6 +25,7 @@ from nat.llm.litellm_llm import LiteLlmModelConfig from nat.llm.nim_llm import NIMModelConfig from nat.llm.openai_llm 
import OpenAIModelConfig +from nat.llm.utils.http_client import _handle_litellm_verify_ssl # ADK uses litellm under the hood from nat.utils.responses_api import validate_no_responses_api logger = logging.getLogger(__name__) @@ -65,7 +66,7 @@ async def azure_openai_adk(config: AzureOpenAIModelConfig, _builder: Builder): config_dict["timeout"] = config.request_timeout config_dict["api_version"] = config.api_version - config_dict["ssl_verify"] = config.verify_ssl + _handle_litellm_verify_ssl(config.verify_ssl) yield LiteLlm(f"azure/{config.azure_deployment}", **config_dict) @@ -76,13 +77,13 @@ async def litellm_adk(litellm_config: LiteLlmModelConfig, _builder: Builder): validate_no_responses_api(litellm_config, LLMFrameworkEnum.ADK) - kwargs = {"ssl_verify": litellm_config.verify_ssl} + _handle_litellm_verify_ssl(litellm_config.verify_ssl) yield LiteLlm(**litellm_config.model_dump( exclude={"api_type", "max_retries", "thinking", "type", "verify_ssl"}, by_alias=True, exclude_none=True, exclude_unset=True, - ), **kwargs) + )) @register_llm_client(config_type=NIMModelConfig, wrapper_type=LLMFrameworkEnum.ADK) @@ -114,7 +115,7 @@ async def nim_adk(config: NIMModelConfig, _builder: Builder): if config.base_url: config_dict["api_base"] = config.base_url - config_dict["ssl_verify"] = config.verify_ssl + _handle_litellm_verify_ssl(config.verify_ssl) yield LiteLlm(f"nvidia_nim/{config.model_name}", **config_dict) @@ -132,15 +133,17 @@ async def openai_adk(config: OpenAIModelConfig, _builder: Builder): validate_no_responses_api(config, LLMFrameworkEnum.ADK) config_dict = config.model_dump( - exclude={"api_type", - "base_url", - "max_retries", - "model", - "model_name", - "request_timeout", - "thinking", - "type", - "verify_ssl"}, + exclude={ + "api_type", + "base_url", + "max_retries", + "model", + "model_name", + "request_timeout", + "thinking", + "type", + "verify_ssl" + }, by_alias=True, exclude_none=True, exclude_unset=True, @@ -152,7 +155,8 @@ async def openai_adk(config: 
OpenAIModelConfig, _builder: Builder): config_dict["api_base"] = base_url if config.request_timeout is not None: config_dict["timeout"] = config.request_timeout - config_dict["ssl_verify"] = config.verify_ssl + + _handle_litellm_verify_ssl(config.verify_ssl) yield LiteLlm(config.model_name, **config_dict) diff --git a/packages/nvidia_nat_core/src/nat/llm/utils/http_client.py b/packages/nvidia_nat_core/src/nat/llm/utils/http_client.py index 8e48bc269e..9d9bd75552 100644 --- a/packages/nvidia_nat_core/src/nat/llm/utils/http_client.py +++ b/packages/nvidia_nat_core/src/nat/llm/utils/http_client.py @@ -40,3 +40,11 @@ def _set_kwarg(kwarg_name: str, config_attr: str): client_class = httpx.Client return client_class(**kwargs) + + +def _handle_litellm_verify_ssl(llm_config: "LLMBaseConfig") -> None: + if not getattr(llm_config, "verify_ssl", True): + # Currently litellm does not support disabling this on a per-LLM basis for any backend other than Bedrock and + # AIM Guardrail. + import litellm + litellm.ssl_verify = False diff --git a/packages/nvidia_nat_crewai/src/nat/plugins/crewai/llm.py b/packages/nvidia_nat_crewai/src/nat/plugins/crewai/llm.py index 8d54ff6d71..faa8b25b3d 100644 --- a/packages/nvidia_nat_crewai/src/nat/plugins/crewai/llm.py +++ b/packages/nvidia_nat_crewai/src/nat/plugins/crewai/llm.py @@ -27,6 +27,7 @@ from nat.llm.litellm_llm import LiteLlmModelConfig from nat.llm.nim_llm import NIMModelConfig from nat.llm.openai_llm import OpenAIModelConfig +from nat.llm.utils.http_client import _handle_litellm_verify_ssl # crewAI uses litellm under the hood from nat.llm.utils.thinking import BaseThinkingInjector from nat.llm.utils.thinking import FunctionArgumentWrapper from nat.llm.utils.thinking import patch_with_thinking @@ -114,6 +115,8 @@ async def azure_openai_crewai(llm_config: AzureOpenAIModelConfig, _builder: Buil if llm_config.request_timeout is not None: config_dict["timeout"] = llm_config.request_timeout + 
_handle_litellm_verify_ssl(llm_config.verify_ssl) + client = LLM( **config_dict, model=model, @@ -136,6 +139,8 @@ async def nim_crewai(llm_config: NIMModelConfig, _builder: Builder): if nvidia_api_key is not None: os.environ["NVIDIA_NIM_API_KEY"] = nvidia_api_key + _handle_litellm_verify_ssl(llm_config.verify_ssl) + client = LLM( **llm_config.model_dump( exclude={"type", "model_name", "thinking", "api_type"}, @@ -163,6 +168,8 @@ async def openai_crewai(llm_config: OpenAIModelConfig, _builder: Builder): exclude_unset=True, ) + _handle_litellm_verify_ssl(llm_config.verify_ssl) + if (api_key := get_secret_value(llm_config.api_key) or os.getenv("OPENAI_API_KEY")): config_dict["api_key"] = api_key if (base_url := llm_config.base_url or os.getenv("OPENAI_BASE_URL")): @@ -182,6 +189,8 @@ async def litellm_crewai(llm_config: LiteLlmModelConfig, _builder: Builder): validate_no_responses_api(llm_config, LLMFrameworkEnum.CREWAI) + _handle_litellm_verify_ssl(llm_config.verify_ssl) + client = LLM(**llm_config.model_dump( exclude={"type", "thinking", "api_type"}, by_alias=True, exclude_none=True, exclude_unset=True)) From e8173c65c87792cd5e741937f10e8ceee3382956 Mon Sep 17 00:00:00 2001 From: David Gardner Date: Thu, 5 Mar 2026 16:12:40 -0800 Subject: [PATCH 12/24] Document that the verify_ssl parameter is supported as-is Signed-off-by: David Gardner --- packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py index 716ff14a1a..dc71c1bce1 100644 --- a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py +++ b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py @@ -179,6 +179,7 @@ async def nim_langchain(llm_config: NIMModelConfig, _builder: Builder): validate_no_responses_api(llm_config, LLMFrameworkEnum.LANGCHAIN) # prefer max_completion_tokens over max_tokens + # verify_ssl 
is a supported keyword parameter for the ChatNVIDIA client client = ChatNVIDIA( **llm_config.model_dump( exclude={"type", "max_tokens", "thinking", "api_type"}, From 3ec4a56301fa7585cc60b9668895b4469fdd414b Mon Sep 17 00:00:00 2001 From: David Gardner Date: Fri, 6 Mar 2026 09:20:57 -0800 Subject: [PATCH 13/24] First pass at llama_index Signed-off-by: David Gardner --- .../src/nat/plugins/llama_index/llm.py | 54 +++++++++++++++---- 1 file changed, 43 insertions(+), 11 deletions(-) diff --git a/packages/nvidia_nat_llama_index/src/nat/plugins/llama_index/llm.py b/packages/nvidia_nat_llama_index/src/nat/plugins/llama_index/llm.py index fb478ae3c3..12ea5bcc5e 100644 --- a/packages/nvidia_nat_llama_index/src/nat/plugins/llama_index/llm.py +++ b/packages/nvidia_nat_llama_index/src/nat/plugins/llama_index/llm.py @@ -30,6 +30,7 @@ from nat.llm.litellm_llm import LiteLlmModelConfig from nat.llm.nim_llm import NIMModelConfig from nat.llm.openai_llm import OpenAIModelConfig +from nat.llm.utils.http_client import _create_http_client from nat.llm.utils.thinking import BaseThinkingInjector from nat.llm.utils.thinking import FunctionArgumentWrapper from nat.llm.utils.thinking import patch_with_thinking @@ -81,6 +82,14 @@ def inject(self, messages: Sequence[ChatMessage], *args, **kwargs) -> FunctionAr return client +def _get_http_clients(llm_config: LLMBaseConfig) -> dict[str, "httpx.AsyncClient | httpx.Client"]: + """Get a dictionary of HTTP clients, one sync one async.""" + return { + "http_client": _create_http_client(llm_config, use_async=False), + "async_http_client": _create_http_client(llm_config, use_async=True) + } + + @register_llm_client(config_type=AWSBedrockModelConfig, wrapper_type=LLMFrameworkEnum.LLAMA_INDEX) async def aws_bedrock_llama_index(llm_config: AWSBedrockModelConfig, _builder: Builder): @@ -89,8 +98,10 @@ async def aws_bedrock_llama_index(llm_config: AWSBedrockModelConfig, _builder: B validate_no_responses_api(llm_config, LLMFrameworkEnum.LLAMA_INDEX) # 
LlamaIndex uses context_size instead of max_tokens - llm = Bedrock(**llm_config.model_dump( - exclude={"type", "top_p", "thinking", "api_type"}, by_alias=True, exclude_none=True, exclude_unset=True)) + llm = Bedrock(**llm_config.model_dump(exclude={"api_type", "thinking", "top_p", "type", "verify_ssl"}, + by_alias=True, + exclude_none=True, + exclude_unset=True)) yield _patch_llm_based_on_config(llm, llm_config) @@ -102,13 +113,15 @@ async def azure_openai_llama_index(llm_config: AzureOpenAIModelConfig, _builder: validate_no_responses_api(llm_config, LLMFrameworkEnum.LLAMA_INDEX) - config_dict = llm_config.model_dump(exclude={"type", "thinking", "api_type", "api_version", "request_timeout"}, - by_alias=True, - exclude_none=True, - exclude_unset=True) + config_dict = llm_config.model_dump( + exclude={"api_type", "api_version", "request_timeout", "thinking", "type", "verify_ssl"}, + by_alias=True, + exclude_none=True, + exclude_unset=True) if llm_config.request_timeout is not None: config_dict["timeout"] = llm_config.request_timeout + config_dict.update(_get_http_clients(llm_config)) llm = AzureOpenAI( **config_dict, api_version=llm_config.api_version, @@ -124,8 +137,20 @@ async def nim_llama_index(llm_config: NIMModelConfig, _builder: Builder): validate_no_responses_api(llm_config, LLMFrameworkEnum.LLAMA_INDEX) - llm = NVIDIA(**llm_config.model_dump( - exclude={"type", "thinking", "api_type"}, by_alias=True, exclude_none=True, exclude_unset=True)) + config_dict = llm_config.model_dump( + exclude={ + "api_type", + "thinking", + "type", + "verify_ssl", + }, + by_alias=True, + exclude_none=True, + exclude_unset=True, + ) + + config_dict.update(_get_http_clients(llm_config)) + llm = NVIDIA(**config_dict) yield _patch_llm_based_on_config(llm, llm_config) @@ -137,7 +162,7 @@ async def openai_llama_index(llm_config: OpenAIModelConfig, _builder: Builder): from llama_index.llms.openai import OpenAIResponses config_dict = llm_config.model_dump( - exclude={"type", "thinking", 
"api_type", "api_key", "base_url", "request_timeout"}, + exclude={"api_key", "api_type", "base_url", "request_timeout", "thinking", "type", "verify_ssl"}, by_alias=True, exclude_none=True, exclude_unset=True, @@ -151,6 +176,7 @@ async def openai_llama_index(llm_config: OpenAIModelConfig, _builder: Builder): if llm_config.request_timeout is not None: config_dict["timeout"] = llm_config.request_timeout + config_dict.update(_get_http_clients(llm_config)) if llm_config.api_type == APITypeEnum.RESPONSES: llm = OpenAIResponses(**config_dict) else: @@ -164,9 +190,15 @@ async def litellm_llama_index(llm_config: LiteLlmModelConfig, _builder: Builder) from llama_index.llms.litellm import LiteLLM + from nat.llm.utils.http_client import _handle_litellm_verify_ssl + + _handle_litellm_verify_ssl(llm_config.verify_ssl) validate_no_responses_api(llm_config, LLMFrameworkEnum.LLAMA_INDEX) - llm = LiteLLM(**llm_config.model_dump( - exclude={"type", "thinking", "api_type"}, by_alias=True, exclude_none=True, exclude_unset=True)) + llm = LiteLLM( + **llm_config.model_dump(exclude={"api_type", "thinking", "type", "verify_ssl"}, + by_alias=True, + exclude_none=True, + exclude_unset=True), ) yield _patch_llm_based_on_config(llm, llm_config) From 1f5e4b7b135225d7fe7d855cc5e10b2c06cc795e Mon Sep 17 00:00:00 2001 From: David Gardner Date: Fri, 6 Mar 2026 11:16:52 -0800 Subject: [PATCH 14/24] First pass at semantic kernel Signed-off-by: David Gardner --- .../src/nat/plugins/semantic_kernel/llm.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/packages/nvidia_nat_semantic_kernel/src/nat/plugins/semantic_kernel/llm.py b/packages/nvidia_nat_semantic_kernel/src/nat/plugins/semantic_kernel/llm.py index 5a92d34e62..d912cc3cde 100644 --- a/packages/nvidia_nat_semantic_kernel/src/nat/plugins/semantic_kernel/llm.py +++ b/packages/nvidia_nat_semantic_kernel/src/nat/plugins/semantic_kernel/llm.py @@ -25,6 +25,7 @@ from nat.data_models.thinking_mixin import 
ThinkingMixin from nat.llm.azure_openai_llm import AzureOpenAIModelConfig from nat.llm.openai_llm import OpenAIModelConfig +from nat.llm.utils.http_client import _create_http_client from nat.llm.utils.thinking import BaseThinkingInjector from nat.llm.utils.thinking import FunctionArgumentWrapper from nat.llm.utils.thinking import patch_with_thinking @@ -90,18 +91,19 @@ def inject(self, chat_history: ChatHistory, *args, **kwargs) -> FunctionArgument @register_llm_client(config_type=AzureOpenAIModelConfig, wrapper_type=LLMFrameworkEnum.SEMANTIC_KERNEL) async def azure_openai_semantic_kernel(llm_config: AzureOpenAIModelConfig, _builder: Builder): + from openai import AsyncAzureOpenAI from semantic_kernel.connectors.ai.open_ai import AzureChatCompletion validate_no_responses_api(llm_config, LLMFrameworkEnum.SEMANTIC_KERNEL) - llm = AzureChatCompletion( - api_key=get_secret_value(llm_config.api_key), - api_version=llm_config.api_version, - endpoint=llm_config.azure_endpoint, - deployment_name=llm_config.azure_deployment, - ) + async with AsyncAzureOpenAI(api_key=get_secret_value(llm_config.api_key), + api_version=llm_config.api_version, + azure_endpoint=llm_config.azure_endpoint, + azure_deployment=llm_config.azure_deployment, + http_client=_create_http_client(llm_config, use_async=True)) as async_client: + llm = AzureChatCompletion(async_client=async_client) - yield _patch_llm_based_on_config(llm, llm_config) + yield _patch_llm_based_on_config(llm, llm_config) @register_llm_client(config_type=OpenAIModelConfig, wrapper_type=LLMFrameworkEnum.SEMANTIC_KERNEL) @@ -115,6 +117,8 @@ async def openai_semantic_kernel(llm_config: OpenAIModelConfig, _builder: Builde api_key = get_secret_value(llm_config.api_key) or os.getenv("OPENAI_API_KEY") base_url = llm_config.base_url or os.getenv("OPENAI_BASE_URL") - async with AsyncOpenAI(api_key=api_key, base_url=base_url) as async_client: + async with AsyncOpenAI(api_key=api_key, + base_url=base_url, + 
http_client=_create_http_client(llm_config, use_async=True)) as async_client: llm = OpenAIChatCompletion(ai_model_id=llm_config.model_name, async_client=async_client) yield _patch_llm_based_on_config(llm, llm_config) From 73af0ab9a35d83dd731326c582765fce0d79eabc Mon Sep 17 00:00:00 2001 From: David Gardner Date: Fri, 6 Mar 2026 11:53:17 -0800 Subject: [PATCH 15/24] First pass at strands Signed-off-by: David Gardner --- .../src/nat/plugins/strands/llm.py | 59 ++++++++++++++----- 1 file changed, 45 insertions(+), 14 deletions(-) diff --git a/packages/nvidia_nat_strands/src/nat/plugins/strands/llm.py b/packages/nvidia_nat_strands/src/nat/plugins/strands/llm.py index 7d499a502b..7a506c5574 100644 --- a/packages/nvidia_nat_strands/src/nat/plugins/strands/llm.py +++ b/packages/nvidia_nat_strands/src/nat/plugins/strands/llm.py @@ -59,6 +59,7 @@ from nat.llm.aws_bedrock_llm import AWSBedrockModelConfig from nat.llm.nim_llm import NIMModelConfig from nat.llm.openai_llm import OpenAIModelConfig +from nat.llm.utils.http_client import _create_http_client from nat.llm.utils.thinking import BaseThinkingInjector from nat.llm.utils.thinking import FunctionArgumentWrapper from nat.llm.utils.thinking import patch_with_thinking @@ -158,13 +159,26 @@ async def openai_strands(llm_config: OpenAIModelConfig, _builder: Builder) -> As validate_no_responses_api(llm_config, LLMFrameworkEnum.STRANDS) + from openai import AsyncOpenAI + from strands.models.openai import OpenAIModel params = llm_config.model_dump( - exclude={"type", "api_type", "api_key", "base_url", "model_name", "max_retries", "thinking", "request_timeout"}, + exclude={ + "api_key", + "api_type", + "base_url", + "max_retries", + "model_name", + "request_timeout", + "thinking", + "type", + "verify_ssl", + }, by_alias=True, exclude_none=True, - exclude_unset=True) + exclude_unset=True, + ) api_key = get_secret_value(llm_config.api_key) or os.getenv("OPENAI_API_KEY") base_url = llm_config.base_url or os.getenv("OPENAI_BASE_URL") 
@@ -172,12 +186,14 @@ async def openai_strands(llm_config: OpenAIModelConfig, _builder: Builder) -> As client_args: dict[str, Any] = { "api_key": api_key, "base_url": base_url, + "http_client": _create_http_client(llm_config, use_async=True), } if llm_config.request_timeout is not None: client_args["timeout"] = llm_config.request_timeout + oai_client = AsyncOpenAI(**client_args) client = OpenAIModel( - client_args=client_args, + client=oai_client, model_id=llm_config.model_name, params=params, ) @@ -207,6 +223,8 @@ async def nim_strands(llm_config: NIMModelConfig, _builder: Builder) -> AsyncGen validate_no_responses_api(llm_config, LLMFrameworkEnum.STRANDS) # NIM is OpenAI compatible; use OpenAI model with NIM base_url and api_key + from openai import AsyncOpenAI + from strands.models.openai import OpenAIModel # Create a custom OpenAI model that formats text content as strings for NIM compatibility @@ -266,10 +284,20 @@ def format_request_messages(cls, messages, system_prompt=None, *, system_prompt_ return formatted_messages params = llm_config.model_dump( - exclude={"type", "api_type", "api_key", "base_url", "model_name", "max_retries", "thinking"}, + exclude={ + "api_key", + "api_type", + "base_url", + "max_retries", + "model_name", + "thinking", + "type", + "verify_ssl", + }, by_alias=True, exclude_none=True, - exclude_unset=True) + exclude_unset=True, + ) # Determine base_url base_url = llm_config.base_url or "https://integrate.api.nvidia.com/v1" @@ -280,11 +308,13 @@ def format_request_messages(cls, messages, system_prompt=None, *, system_prompt_ if llm_config.base_url and llm_config.base_url.strip() and api_key is None: api_key = "dummy-api-key" + oai_client = AsyncOpenAI( + api_key=api_key, + base_url=base_url, + http_client=_create_http_client(llm_config, use_async=True), + ) client = NIMCompatibleOpenAIModel( - client_args={ - "api_key": api_key, - "base_url": base_url, - }, + client=oai_client, model_id=llm_config.model_name, params=params, ) @@ -326,15 
+356,16 @@ async def bedrock_strands(llm_config: AWSBedrockModelConfig, _builder: Builder) params = llm_config.model_dump( exclude={ - "type", "api_type", - "model_name", - "region_name", "base_url", - "max_retries", - "thinking", "context_size", "credentials_profile_name", + "max_retries", + "model_name", + "region_name", + "thinking", + "type", + "verify_ssl", }, by_alias=True, exclude_none=True, From 23092a9949d81326c9d47b03f59f7ce4819acbbd Mon Sep 17 00:00:00 2001 From: David Gardner Date: Fri, 6 Mar 2026 12:11:11 -0800 Subject: [PATCH 16/24] Fix type checking imports Signed-off-by: David Gardner --- packages/nvidia_nat_adk/src/nat/plugins/adk/llm.py | 1 - packages/nvidia_nat_core/src/nat/data_models/llm.py | 3 ++- packages/nvidia_nat_core/src/nat/llm/openai_llm.py | 3 ++- packages/nvidia_nat_core/src/nat/llm/utils/hooks.py | 2 ++ .../src/nat/plugins/llama_index/llm.py | 7 +++++-- 5 files changed, 11 insertions(+), 5 deletions(-) diff --git a/packages/nvidia_nat_adk/src/nat/plugins/adk/llm.py b/packages/nvidia_nat_adk/src/nat/plugins/adk/llm.py index 8b84abc29c..2e6c7894cd 100644 --- a/packages/nvidia_nat_adk/src/nat/plugins/adk/llm.py +++ b/packages/nvidia_nat_adk/src/nat/plugins/adk/llm.py @@ -15,7 +15,6 @@ import logging import os -import typing from nat.builder.builder import Builder from nat.builder.framework_enum import LLMFrameworkEnum diff --git a/packages/nvidia_nat_core/src/nat/data_models/llm.py b/packages/nvidia_nat_core/src/nat/data_models/llm.py index fa591458fb..344d9b0915 100644 --- a/packages/nvidia_nat_core/src/nat/data_models/llm.py +++ b/packages/nvidia_nat_core/src/nat/data_models/llm.py @@ -41,10 +41,11 @@ class LLMBaseConfig(TypedBaseModel, BaseModelRegistryTag): LLMBaseConfigT = typing.TypeVar("LLMBaseConfigT", bound=LLMBaseConfig) + class SSLVerificationMixin(BaseModel): """Mixin for SSL verification configuration.""" verify_ssl: bool = Field( default=True, description="Whether to verify SSL certificates when making API calls to the LLM 
provider. Defaults to True.", - ) \ No newline at end of file + ) diff --git a/packages/nvidia_nat_core/src/nat/llm/openai_llm.py b/packages/nvidia_nat_core/src/nat/llm/openai_llm.py index c5a31de6f1..a268bd77d1 100644 --- a/packages/nvidia_nat_core/src/nat/llm/openai_llm.py +++ b/packages/nvidia_nat_core/src/nat/llm/openai_llm.py @@ -30,7 +30,8 @@ from nat.data_models.thinking_mixin import ThinkingMixin -class OpenAIModelConfig(LLMBaseConfig, RetryMixin, OptimizableMixin, ThinkingMixin, SSLVerificationMixin, name="openai"): +class OpenAIModelConfig(LLMBaseConfig, RetryMixin, OptimizableMixin, ThinkingMixin, SSLVerificationMixin, + name="openai"): """An OpenAI LLM provider to be used with an LLM client.""" model_config = ConfigDict(protected_namespaces=(), extra="allow") diff --git a/packages/nvidia_nat_core/src/nat/llm/utils/hooks.py b/packages/nvidia_nat_core/src/nat/llm/utils/hooks.py index 3a3da6ae34..5699410584 100644 --- a/packages/nvidia_nat_core/src/nat/llm/utils/hooks.py +++ b/packages/nvidia_nat_core/src/nat/llm/utils/hooks.py @@ -26,6 +26,8 @@ if TYPE_CHECKING: import httpx + from nat.data_models.llm import LLMBaseConfig + from nat.llm.utils.constants import LLMHeaderPrefix from nat.llm.utils.http_client import _create_http_client diff --git a/packages/nvidia_nat_llama_index/src/nat/plugins/llama_index/llm.py b/packages/nvidia_nat_llama_index/src/nat/plugins/llama_index/llm.py index 12ea5bcc5e..bfb85a773f 100644 --- a/packages/nvidia_nat_llama_index/src/nat/plugins/llama_index/llm.py +++ b/packages/nvidia_nat_llama_index/src/nat/plugins/llama_index/llm.py @@ -14,8 +14,8 @@ # limitations under the License. 
import os +import typing from collections.abc import Sequence -from typing import TypeVar from nat.builder.builder import Builder from nat.builder.framework_enum import LLMFrameworkEnum @@ -38,7 +38,10 @@ from nat.utils.responses_api import validate_no_responses_api from nat.utils.type_utils import override -ModelType = TypeVar("ModelType") +if typing.TYPE_CHECKING: + import httpx + +ModelType = typing.TypeVar("ModelType") def _patch_llm_based_on_config(client: ModelType, llm_config: LLMBaseConfig) -> ModelType: From 84692dfdc8f337dd70e0771f0824d176df56d6dd Mon Sep 17 00:00:00 2001 From: David Gardner Date: Fri, 6 Mar 2026 12:51:13 -0800 Subject: [PATCH 17/24] Fix test Signed-off-by: David Gardner --- .../tests/nat/llm/utils/test_hooks.py | 28 +++++++++++++++---- 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/packages/nvidia_nat_core/tests/nat/llm/utils/test_hooks.py b/packages/nvidia_nat_core/tests/nat/llm/utils/test_hooks.py index 0c7c2bb2a7..a57d9212b0 100644 --- a/packages/nvidia_nat_core/tests/nat/llm/utils/test_hooks.py +++ b/packages/nvidia_nat_core/tests/nat/llm/utils/test_hooks.py @@ -17,12 +17,28 @@ from unittest.mock import MagicMock import pytest +from pydantic import Field from pytest_httpserver import HTTPServer from nat.builder.context import ContextState +from nat.data_models.llm import LLMBaseConfig +from nat.data_models.llm import SSLVerificationMixin from nat.llm.utils.hooks import _create_metadata_injection_client +# TODO: need some tests for handling for request_timeout and verify_ssl +class LLMConfig(LLMBaseConfig): + pass + + +class LLMConfigWithTimeout(LLMBaseConfig): + request_timeout: float | None = Field(default=None, gt=0.0, description="HTTP request timeout in seconds.") + + +class LLMConfigWithSSL(LLMConfigWithTimeout, SSLVerificationMixin): + pass + + class TestMetadataInjectionHook: """Unit tests for the metadata injection hook function.""" @@ -46,7 +62,7 @@ def fixture_mock_input_message(self): async def 
test_hook_injects_metadata_fields(self, mock_httpx_request, mock_input_message): """Test that the hook injects custom metadata fields as headers.""" - client = _create_metadata_injection_client() + client = _create_metadata_injection_client(llm_config=LLMConfig()) hook = client.event_hooks["request"][0] context_state = ContextState.get() @@ -67,7 +83,7 @@ async def test_hook_skips_none_values(self, mock_httpx_request, mock_input_messa "optional_field": None, } - client = _create_metadata_injection_client() + client = _create_metadata_injection_client(llm_config=LLMConfig()) hook = client.event_hooks["request"][0] context_state = ContextState.get() @@ -82,7 +98,7 @@ async def test_hook_skips_none_values(self, mock_httpx_request, mock_input_messa async def test_hook_handles_missing_context(self, mock_httpx_request): """Test that hook handles missing context gracefully.""" - client = _create_metadata_injection_client() + client = _create_metadata_injection_client(llm_config=LLMConfig()) hook = client.event_hooks["request"][0] await hook(mock_httpx_request) @@ -98,7 +114,7 @@ class TestCreateMetadataInjectionClient: async def test_creates_client_with_event_hooks(self): """Test that client is created with event hooks.""" - client = _create_metadata_injection_client() + client = _create_metadata_injection_client(llm_config=LLMConfig()) assert "request" in client.event_hooks assert len(client.event_hooks["request"]) == 1 @@ -139,7 +155,7 @@ async def test_headers_sent_in_http_request(self, httpserver: HTTPServer, mock_i } }) - client = _create_metadata_injection_client() + client = _create_metadata_injection_client(llm_config=LLMConfig()) context_state = ContextState.get() context_state.input_message.set(mock_input_message) @@ -181,7 +197,7 @@ async def test_request_succeeds_without_context(self, httpserver: HTTPServer): } }) - client = _create_metadata_injection_client() + client = _create_metadata_injection_client(llm_config=LLMConfig()) response = await 
client.post(httpserver.url_for("/v1/chat/completions"), json={ From 29aae30cb156894c90421a74e1dc9b5c16e912b5 Mon Sep 17 00:00:00 2001 From: David Gardner Date: Fri, 6 Mar 2026 13:59:15 -0800 Subject: [PATCH 18/24] Update tests Signed-off-by: David Gardner --- .../tests/test_strands_llm.py | 105 ++++++++++-------- 1 file changed, 59 insertions(+), 46 deletions(-) diff --git a/packages/nvidia_nat_strands/tests/test_strands_llm.py b/packages/nvidia_nat_strands/tests/test_strands_llm.py index 0313f2bc3d..34d972ee2b 100644 --- a/packages/nvidia_nat_strands/tests/test_strands_llm.py +++ b/packages/nvidia_nat_strands/tests/test_strands_llm.py @@ -156,76 +156,89 @@ def nim_config_wrong_api(self): api_type=APITypeEnum.RESPONSES, ) + @pytest.fixture(name="mock_oai_clients") + def mock_oai_clients_fixture(self): + with patch("openai.AsyncOpenAI") as mock_oai: + mock_oai.return_value = mock_oai + + # Patch OpenAIModel constructor to track the call + with patch("strands.models.openai.OpenAIModel.__init__", return_value=None) as mock_oai_model: + yield mock_oai, mock_oai_model + @pytest.mark.asyncio - async def test_nim_strands_basic(self, nim_config, mock_builder): + async def test_nim_strands_basic(self, nim_config, mock_builder, mock_oai_clients): """Test that nim_strands creates a NIMCompatibleOpenAIModel.""" - # Patch OpenAIModel.__init__ to track the call - with patch("strands.models.openai.OpenAIModel.__init__", return_value=None) as mock_init: - # pylint: disable=not-async-context-manager - async with nim_strands(nim_config, mock_builder) as result: - # Verify the result is a NIMCompatibleOpenAIModel instance - assert result is not None + (mock_oai, mock_oai_model) = mock_oai_clients - # Verify OpenAIModel.__init__ was called (the base class) - mock_init.assert_called_once() - call_args = mock_init.call_args + # pylint: disable=not-async-context-manager + async with nim_strands(nim_config, mock_builder) as result: + # Verify the result is a NIMCompatibleOpenAIModel 
instance + assert result is not None + + mock_oai.assert_called_once() # Ensure OpenAI client init was called + oai_call_args = mock_oai.call_args + oai_call_kwargs = oai_call_args[1] + assert oai_call_kwargs["api_key"] == "test-api-key" + assert oai_call_kwargs["base_url"] == "https://integrate.api.nvidia.com/v1" + + # Verify OpenAIModel.__init__ was called (the base class) + mock_oai_model.assert_called_once() + call_args = mock_oai_model.call_args - # First arg is self, get kwargs - call_kwargs = call_args[1] + # First arg is self, get kwargs + call_kwargs = call_args[1] - # Verify client_args - assert "client_args" in call_kwargs - client_args = call_kwargs["client_args"] - assert client_args["api_key"] == "test-api-key" - assert client_args["base_url"] == "https://integrate.api.nvidia.com/v1" + # Verify client + assert "client" in call_kwargs + call_kwargs["client"] = mock_oai - # Verify model_id - assert call_kwargs["model_id"] == "meta/llama-3.1-8b-instruct" + # Verify model_id + assert call_kwargs["model_id"] == "meta/llama-3.1-8b-instruct" @pytest.mark.asyncio - async def test_nim_strands_with_env_var(self, mock_builder): + async def test_nim_strands_with_env_var(self, mock_builder, mock_oai_clients): """Test nim_strands with environment variable for API key.""" + (mock_oai, mock_oai_model) = mock_oai_clients nim_config = NIMModelConfig(model_name="test-model") - with patch("strands.models.openai.OpenAIModel.__init__", return_value=None) as mock_init: - with patch.dict("os.environ", {"NVIDIA_API_KEY": "env-api-key"}): - # pylint: disable=not-async-context-manager - async with nim_strands(nim_config, mock_builder): - mock_init.assert_called_once() - call_kwargs = mock_init.call_args[1] - client_args = call_kwargs["client_args"] - assert client_args["api_key"] == "env-api-key" + with patch.dict("os.environ", {"NVIDIA_API_KEY": "env-api-key"}): + # pylint: disable=not-async-context-manager + async with nim_strands(nim_config, mock_builder): + 
mock_oai_model.assert_called_once() + mock_oai.assert_called_once() + + call_kwargs = mock_oai.call_args[1] + assert call_kwargs["api_key"] == "env-api-key" @pytest.mark.asyncio - async def test_nim_strands_default_base_url(self, mock_builder): + async def test_nim_strands_default_base_url(self, mock_builder, mock_oai_clients): """Test nim_strands uses default base_url when not provided.""" + (mock_oai, mock_oai_model) = mock_oai_clients nim_config = NIMModelConfig(model_name="test-model", api_key="test-key") - with patch("strands.models.openai.OpenAIModel.__init__", return_value=None) as mock_init: - # pylint: disable=not-async-context-manager - async with nim_strands(nim_config, mock_builder): - mock_init.assert_called_once() - call_kwargs = mock_init.call_args[1] - client_args = call_kwargs["client_args"] - assert client_args["base_url"] == "https://integrate.api.nvidia.com/v1" + async with nim_strands(nim_config, mock_builder): # pylint: disable=not-async-context-manager + mock_oai_model.assert_called_once() + mock_oai.assert_called_once() + call_kwargs = mock_oai.call_args[1] + assert call_kwargs["base_url"] == "https://integrate.api.nvidia.com/v1" @pytest.mark.asyncio - async def test_nim_strands_nim_override_dummy_api_key(self, mock_builder): + async def test_nim_strands_nim_override_dummy_api_key(self, mock_builder, mock_oai_clients): """Test nim_strands uses dummy API key when base_url is set but no API key available.""" + (mock_oai, mock_oai_model) = mock_oai_clients nim_config = NIMModelConfig( model_name="test-model", base_url="https://custom-nim.example.com/v1", ) - with patch("strands.models.openai.OpenAIModel.__init__", return_value=None) as mock_init: - with patch.dict(os.environ, {}, clear=True): - # pylint: disable=not-async-context-manager - async with nim_strands(nim_config, mock_builder): - mock_init.assert_called_once() - call_kwargs = mock_init.call_args[1] - client_args = call_kwargs["client_args"] - assert client_args["base_url"] == 
"https://custom-nim.example.com/v1" - assert client_args["api_key"] == "dummy-api-key" + with patch.dict(os.environ, {}, clear=True): + # pylint: disable=not-async-context-manager + async with nim_strands(nim_config, mock_builder): + mock_oai_model.assert_called_once() + mock_oai.assert_called_once() + call_kwargs = mock_oai.call_args[1] + assert call_kwargs["base_url"] == "https://custom-nim.example.com/v1" + assert call_kwargs["api_key"] == "dummy-api-key" def test_nim_compatible_openai_model_format_request_messages(self): """Test NIMCompatibleOpenAIModel.format_request_messages.""" From d6502841a617746a9cc1d1470a022254a9da4902 Mon Sep 17 00:00:00 2001 From: David Gardner Date: Fri, 6 Mar 2026 15:32:32 -0800 Subject: [PATCH 19/24] Rename create_httpx_client_with_dynamo_hooks to indicate that it is private, update the method to receive a config instance to reduce some of the redundant code Signed-off-by: David Gardner --- .../nvidia_nat_adk/src/nat/plugins/adk/llm.py | 45 +++------------- .../nvidia_nat_core/src/nat/llm/dynamo_llm.py | 51 ++++++++++++------- .../tests/nat/llm/test_dynamo_llm.py | 20 ++++---- .../src/nat/plugins/langchain/llm.py | 33 +----------- .../tests/test_dynamo_trie_loading.py | 8 +-- .../tests/test_llm_langchain.py | 6 +-- 6 files changed, 59 insertions(+), 104 deletions(-) diff --git a/packages/nvidia_nat_adk/src/nat/plugins/adk/llm.py b/packages/nvidia_nat_adk/src/nat/plugins/adk/llm.py index 2e6c7894cd..f490d9e14d 100644 --- a/packages/nvidia_nat_adk/src/nat/plugins/adk/llm.py +++ b/packages/nvidia_nat_adk/src/nat/plugins/adk/llm.py @@ -177,7 +177,7 @@ async def dynamo_adk(config: DynamoModelConfig, _builder: Builder): from google.adk.models.lite_llm import LiteLlm - from nat.llm.dynamo_llm import create_httpx_client_with_dynamo_hooks + from nat.llm.dynamo_llm import _create_httpx_client_with_dynamo_hooks validate_no_responses_api(config, LLMFrameworkEnum.ADK) @@ -200,36 +200,11 @@ async def dynamo_adk(config: DynamoModelConfig, 
_builder: Builder): if config.base_url: config_dict["api_base"] = config.base_url + http_client = None if config.enable_nvext_hints: - from pathlib import Path - from openai import AsyncOpenAI - from nat.profiler.prediction_trie import load_prediction_trie - from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup - - prediction_lookup: PredictionTrieLookup | None = None - if config.nvext_prediction_trie_path: - try: - trie_path = Path(config.nvext_prediction_trie_path) - trie = load_prediction_trie(trie_path) - prediction_lookup = PredictionTrieLookup(trie) - logger.info("Loaded prediction trie from %s", config.nvext_prediction_trie_path) - except FileNotFoundError: - logger.warning("Prediction trie file not found: %s", config.nvext_prediction_trie_path) - except Exception as e: - logger.exception("Failed to load prediction trie: %s", e) - - http_client = create_httpx_client_with_dynamo_hooks( - total_requests=config.nvext_prefix_total_requests, - osl=config.nvext_prefix_osl, - iat=config.nvext_prefix_iat, - timeout=config.request_timeout, - prediction_lookup=prediction_lookup, - cache_pin_type=config.nvext_cache_pin_type, - cache_control_mode=config.nvext_cache_control_mode, - max_sensitivity=config.nvext_max_sensitivity, - ) + http_client = _create_httpx_client_with_dynamo_hooks(config) api_key = (config.api_key.get_secret_value() if config.api_key else os.getenv("OPENAI_API_KEY", "unused")) base_url = config.base_url or os.getenv("OPENAI_BASE_URL", "http://localhost:8000/v1") @@ -241,12 +216,8 @@ async def dynamo_adk(config: DynamoModelConfig, _builder: Builder): ) config_dict["client"] = openai_client - logger.info( - "Dynamo agent hints enabled for ADK: total_requests=%d, osl=%s, iat=%s, prediction_trie=%s", - config.nvext_prefix_total_requests, - config.nvext_prefix_osl, - config.nvext_prefix_iat, - "loaded" if prediction_lookup else "disabled", - ) - - yield LiteLlm(config.model_name, **config_dict) + try: + yield LiteLlm(config.model_name, 
**config_dict) + finally: + if http_client is not None: + await http_client.aclose() diff --git a/packages/nvidia_nat_core/src/nat/llm/dynamo_llm.py b/packages/nvidia_nat_core/src/nat/llm/dynamo_llm.py index f7e03a05c5..d161af5ad4 100644 --- a/packages/nvidia_nat_core/src/nat/llm/dynamo_llm.py +++ b/packages/nvidia_nat_core/src/nat/llm/dynamo_llm.py @@ -64,6 +64,7 @@ from contextlib import contextmanager from contextvars import ContextVar from enum import StrEnum +from pathlib import Path from typing import TYPE_CHECKING import httpx @@ -702,16 +703,7 @@ async def aclose(self) -> None: # ============================================================================= -def create_httpx_client_with_dynamo_hooks( - total_requests: int, - osl: int, - iat: int, - timeout: float = 600.0, - prediction_lookup: "PredictionTrieLookup | None" = None, - cache_pin_type: CachePinType | None = CachePinType.EPHEMERAL, - cache_control_mode: CacheControlMode = CacheControlMode.ALWAYS, - max_sensitivity: int = 1000, -) -> "httpx.AsyncClient": +def _create_httpx_client_with_dynamo_hooks(config: DynamoModelConfig) -> "httpx.AsyncClient": """ Create an httpx.AsyncClient with Dynamo hint injection via custom transport. 
@@ -734,24 +726,45 @@ def create_httpx_client_with_dynamo_hooks( """ import httpx + from nat.llm.utils.httpx_utils import _create_http_client + from nat.profiler.prediction_trie import load_prediction_trie + from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup + + prediction_lookup: PredictionTrieLookup | None = None + if config.nvext_prediction_trie_path: + try: + trie_path = Path(config.nvext_prediction_trie_path) + trie = load_prediction_trie(trie_path) + prediction_lookup = PredictionTrieLookup(trie) + logger.info("Loaded prediction trie from %s", config.nvext_prediction_trie_path) + except FileNotFoundError: + logger.warning("Prediction trie file not found: %s", config.nvext_prediction_trie_path) + except Exception as e: + logger.exception("Failed to load prediction trie: %s", e) + # Create base transport and wrap with custom transport base_transport = httpx.AsyncHTTPTransport() dynamo_transport = _DynamoTransport( transport=base_transport, - total_requests=total_requests, - osl=osl, - iat=iat, + total_requests=config.nvext_prefix_total_requests, + osl=config.nvext_prefix_osl, + iat=config.nvext_prefix_iat, prediction_lookup=prediction_lookup, - cache_pin_type=cache_pin_type, - cache_control_mode=cache_control_mode, - max_sensitivity=max_sensitivity, + cache_pin_type=config.nvext_cache_pin_type, + cache_control_mode=config.nvext_cache_control_mode, + max_sensitivity=config.nvext_max_sensitivity, ) - return httpx.AsyncClient( - transport=dynamo_transport, - timeout=httpx.Timeout(timeout), + logger.info( + "Dynamo agent hints enabled: total_requests=%d, osl=%s, iat=%s, prediction_trie=%s", + config.nvext_prefix_total_requests, + config.nvext_prefix_osl, + config.nvext_prefix_iat, + "loaded" if config.nvext_prediction_trie_path else "disabled", ) + return _create_http_client(llm_config=config, use_async=True, transport=dynamo_transport) + # ============================================================================= # PROVIDER REGISTRATION 
diff --git a/packages/nvidia_nat_core/tests/nat/llm/test_dynamo_llm.py b/packages/nvidia_nat_core/tests/nat/llm/test_dynamo_llm.py index 1c6d292742..b88f10a754 100644 --- a/packages/nvidia_nat_core/tests/nat/llm/test_dynamo_llm.py +++ b/packages/nvidia_nat_core/tests/nat/llm/test_dynamo_llm.py @@ -23,7 +23,7 @@ from nat.llm.dynamo_llm import CachePinType from nat.llm.dynamo_llm import DynamoModelConfig from nat.llm.dynamo_llm import DynamoPrefixContext -from nat.llm.dynamo_llm import create_httpx_client_with_dynamo_hooks +from nat.llm.dynamo_llm import _create_httpx_client_with_dynamo_hooks # --------------------------------------------------------------------------- # DynamoModelConfig Tests @@ -409,11 +409,11 @@ async def test_prefix_id_consistent_across_transport_requests(self): class TestCreateHttpxClient: - """Tests for create_httpx_client_with_dynamo_hooks.""" + """Tests for _create_httpx_client_with_dynamo_hooks.""" def test_uses_custom_timeout(self): """Test that the function uses the provided timeout.""" - client = create_httpx_client_with_dynamo_hooks( + client = _create_httpx_client_with_dynamo_hooks( total_requests=10, osl=512, iat=250, @@ -426,7 +426,7 @@ def test_uses_custom_timeout(self): def test_uses_default_timeout(self): """Test that the function uses default timeout when not specified.""" - client = create_httpx_client_with_dynamo_hooks( + client = _create_httpx_client_with_dynamo_hooks( total_requests=10, osl=512, iat=250, @@ -435,10 +435,10 @@ def test_uses_default_timeout(self): assert client.timeout.connect == 600.0 def test_creates_client_with_custom_transport(self): - """Test that create_httpx_client_with_dynamo_hooks uses _DynamoTransport.""" + """Test that _create_httpx_client_with_dynamo_hooks uses _DynamoTransport.""" from nat.llm.dynamo_llm import _DynamoTransport - client = create_httpx_client_with_dynamo_hooks( + client = _create_httpx_client_with_dynamo_hooks( total_requests=7, osl=2048, iat=50, @@ -459,10 +459,10 @@ def 
test_creates_client_with_custom_transport(self): assert client.timeout.read == 120.0 def test_creates_client_with_cache_pin_type_none(self): - """Test that create_httpx_client_with_dynamo_hooks passes cache_pin_type=None through.""" + """Test that _create_httpx_client_with_dynamo_hooks passes cache_pin_type=None through.""" from nat.llm.dynamo_llm import _DynamoTransport - client = create_httpx_client_with_dynamo_hooks( + client = _create_httpx_client_with_dynamo_hooks( total_requests=10, osl=512, iat=250, @@ -473,10 +473,10 @@ def test_creates_client_with_cache_pin_type_none(self): assert client._transport._cache_pin_type is None def test_creates_client_with_cache_control_mode_first_only(self): - """Test that create_httpx_client_with_dynamo_hooks passes cache_control_mode through.""" + """Test that _create_httpx_client_with_dynamo_hooks passes cache_control_mode through.""" from nat.llm.dynamo_llm import _DynamoTransport - client = create_httpx_client_with_dynamo_hooks( + client = _create_httpx_client_with_dynamo_hooks( total_requests=10, osl=512, iat=250, diff --git a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py index dc71c1bce1..ef6189c595 100644 --- a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py +++ b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py @@ -33,7 +33,7 @@ from nat.llm.aws_bedrock_llm import AWSBedrockModelConfig from nat.llm.azure_openai_llm import AzureOpenAIModelConfig from nat.llm.dynamo_llm import DynamoModelConfig -from nat.llm.dynamo_llm import create_httpx_client_with_dynamo_hooks +from nat.llm.dynamo_llm import _create_httpx_client_with_dynamo_hooks from nat.llm.huggingface_inference_llm import HuggingFaceInferenceLLMConfig from nat.llm.huggingface_llm import HuggingFaceConfig from nat.llm.litellm_llm import LiteLlmModelConfig @@ -256,39 +256,10 @@ async def dynamo_langchain(llm_config: DynamoModelConfig, _builder: Builder): # 
Initialize http_async_client to None for proper cleanup http_async_client = None - # Load prediction trie if configured - prediction_lookup: PredictionTrieLookup | None = None - if llm_config.nvext_prediction_trie_path: - try: - trie_path = Path(llm_config.nvext_prediction_trie_path) - trie = load_prediction_trie(trie_path) - prediction_lookup = PredictionTrieLookup(trie) - logger.info("Loaded prediction trie from %s", llm_config.nvext_prediction_trie_path) - except FileNotFoundError: - logger.warning("Prediction trie file not found: %s", llm_config.nvext_prediction_trie_path) - except Exception as e: - logger.warning("Failed to load prediction trie: %s", e) - try: if llm_config.enable_nvext_hints: - http_async_client = create_httpx_client_with_dynamo_hooks( - total_requests=llm_config.nvext_prefix_total_requests, - osl=llm_config.nvext_prefix_osl, - iat=llm_config.nvext_prefix_iat, - timeout=llm_config.request_timeout, - prediction_lookup=prediction_lookup, - cache_pin_type=llm_config.nvext_cache_pin_type, - cache_control_mode=llm_config.nvext_cache_control_mode, - max_sensitivity=llm_config.nvext_max_sensitivity, - ) + http_async_client = _create_httpx_client_with_dynamo_hooks(llm_config) config_dict["http_async_client"] = http_async_client - logger.info( - "Dynamo agent hints enabled: total_requests=%d, osl=%s, iat=%s, prediction_trie=%s", - llm_config.nvext_prefix_total_requests, - llm_config.nvext_prefix_osl, - llm_config.nvext_prefix_iat, - "loaded" if prediction_lookup else "disabled", - ) # Create the ChatOpenAI client if llm_config.api_type == APITypeEnum.RESPONSES: diff --git a/packages/nvidia_nat_langchain/tests/test_dynamo_trie_loading.py b/packages/nvidia_nat_langchain/tests/test_dynamo_trie_loading.py index a386238a77..5c388ff812 100644 --- a/packages/nvidia_nat_langchain/tests/test_dynamo_trie_loading.py +++ b/packages/nvidia_nat_langchain/tests/test_dynamo_trie_loading.py @@ -83,7 +83,7 @@ def test_dynamo_config_with_nonexistent_trie_path(): assert 
config.nvext_prediction_trie_path == "/nonexistent/path/trie.json" -@patch("nat.plugins.langchain.llm.create_httpx_client_with_dynamo_hooks") +@patch("nat.plugins.langchain.llm._create_httpx_client_with_dynamo_hooks") @patch("langchain_openai.ChatOpenAI") async def test_dynamo_langchain_loads_trie_and_passes_to_client(mock_chat, mock_create_client, trie_file, mock_builder): """Test that dynamo_langchain loads trie from path and passes PredictionTrieLookup to httpx client.""" @@ -110,7 +110,7 @@ async def test_dynamo_langchain_loads_trie_and_passes_to_client(mock_chat, mock_ mock_httpx_client.aclose.assert_awaited_once() -@patch("nat.plugins.langchain.llm.create_httpx_client_with_dynamo_hooks") +@patch("nat.plugins.langchain.llm._create_httpx_client_with_dynamo_hooks") @patch("langchain_openai.ChatOpenAI") async def test_dynamo_langchain_handles_nonexistent_trie_gracefully(mock_chat, mock_create_client, mock_builder): """Test that dynamo_langchain logs warning and continues when trie file doesn't exist.""" @@ -137,7 +137,7 @@ async def test_dynamo_langchain_handles_nonexistent_trie_gracefully(mock_chat, m mock_httpx_client.aclose.assert_awaited_once() -@patch("nat.plugins.langchain.llm.create_httpx_client_with_dynamo_hooks") +@patch("nat.plugins.langchain.llm._create_httpx_client_with_dynamo_hooks") @patch("langchain_openai.ChatOpenAI") async def test_dynamo_langchain_no_trie_path_means_no_lookup(mock_chat, mock_create_client, mock_builder): """Test that dynamo_langchain passes None when no trie path is configured.""" @@ -161,7 +161,7 @@ async def test_dynamo_langchain_no_trie_path_means_no_lookup(mock_chat, mock_cre mock_httpx_client.aclose.assert_awaited_once() -@patch("nat.plugins.langchain.llm.create_httpx_client_with_dynamo_hooks") +@patch("nat.plugins.langchain.llm._create_httpx_client_with_dynamo_hooks") @patch("langchain_openai.ChatOpenAI") async def test_dynamo_langchain_handles_invalid_trie_file_gracefully(mock_chat, mock_create_client, mock_builder): 
"""Test that dynamo_langchain logs warning and continues when trie file is invalid JSON.""" diff --git a/packages/nvidia_nat_langchain/tests/test_llm_langchain.py b/packages/nvidia_nat_langchain/tests/test_llm_langchain.py index d6bcbabedc..de9f7e5fb1 100644 --- a/packages/nvidia_nat_langchain/tests/test_llm_langchain.py +++ b/packages/nvidia_nat_langchain/tests/test_llm_langchain.py @@ -224,7 +224,7 @@ async def test_basic_creation_without_prefix(self, mock_chat, dynamo_cfg_no_pref assert "http_async_client" not in kwargs assert client is mock_chat.return_value - @patch("nat.plugins.langchain.llm.create_httpx_client_with_dynamo_hooks") + @patch("nat.plugins.langchain.llm._create_httpx_client_with_dynamo_hooks") @patch("langchain_openai.ChatOpenAI") async def test_creation_with_prefix_template(self, mock_chat, @@ -260,7 +260,7 @@ async def test_creation_with_prefix_template(self, # Verify the httpx client was properly closed mock_httpx_client.aclose.assert_awaited_once() - @patch("nat.plugins.langchain.llm.create_httpx_client_with_dynamo_hooks") + @patch("nat.plugins.langchain.llm._create_httpx_client_with_dynamo_hooks") @patch("langchain_openai.ChatOpenAI") async def test_responses_api_branch(self, mock_chat, mock_create_client, dynamo_cfg_responses_api, mock_builder): """When APIType==RESPONSES, special flags should be added.""" @@ -279,7 +279,7 @@ async def test_responses_api_branch(self, mock_chat, mock_create_client, dynamo_ # Verify the httpx client was properly closed mock_httpx_client.aclose.assert_awaited_once() - @patch("nat.plugins.langchain.llm.create_httpx_client_with_dynamo_hooks") + @patch("nat.plugins.langchain.llm._create_httpx_client_with_dynamo_hooks") @patch("langchain_openai.ChatOpenAI") async def test_excludes_dynamo_specific_fields(self, mock_chat, From bebed1b50ac9b4c161aa7742cbede97fab163b5d Mon Sep 17 00:00:00 2001 From: David Gardner Date: Fri, 6 Mar 2026 15:41:22 -0800 Subject: [PATCH 20/24] Move more code into the method Signed-off-by: 
David Gardner --- .../nvidia_nat_adk/src/nat/plugins/adk/llm.py | 23 +++--- .../nvidia_nat_core/src/nat/llm/dynamo_llm.py | 72 ++++++++++--------- .../src/nat/plugins/langchain/llm.py | 11 +-- 3 files changed, 51 insertions(+), 55 deletions(-) diff --git a/packages/nvidia_nat_adk/src/nat/plugins/adk/llm.py b/packages/nvidia_nat_adk/src/nat/plugins/adk/llm.py index f490d9e14d..0331391235 100644 --- a/packages/nvidia_nat_adk/src/nat/plugins/adk/llm.py +++ b/packages/nvidia_nat_adk/src/nat/plugins/adk/llm.py @@ -176,6 +176,7 @@ async def dynamo_adk(config: DynamoModelConfig, _builder: Builder): import os from google.adk.models.lite_llm import LiteLlm + from openai import AsyncOpenAI from nat.llm.dynamo_llm import _create_httpx_client_with_dynamo_hooks @@ -200,21 +201,17 @@ async def dynamo_adk(config: DynamoModelConfig, _builder: Builder): if config.base_url: config_dict["api_base"] = config.base_url - http_client = None - if config.enable_nvext_hints: - from openai import AsyncOpenAI + http_client = _create_httpx_client_with_dynamo_hooks(config) - http_client = _create_httpx_client_with_dynamo_hooks(config) + api_key = (config.api_key.get_secret_value() if config.api_key else os.getenv("OPENAI_API_KEY", "unused")) + base_url = config.base_url or os.getenv("OPENAI_BASE_URL", "http://localhost:8000/v1") - api_key = (config.api_key.get_secret_value() if config.api_key else os.getenv("OPENAI_API_KEY", "unused")) - base_url = config.base_url or os.getenv("OPENAI_BASE_URL", "http://localhost:8000/v1") - - openai_client = AsyncOpenAI( - api_key=api_key, - base_url=base_url, - http_client=http_client, - ) - config_dict["client"] = openai_client + openai_client = AsyncOpenAI( + api_key=api_key, + base_url=base_url, + http_client=http_client, + ) + config_dict["client"] = openai_client try: yield LiteLlm(config.model_name, **config_dict) diff --git a/packages/nvidia_nat_core/src/nat/llm/dynamo_llm.py b/packages/nvidia_nat_core/src/nat/llm/dynamo_llm.py index 
d161af5ad4..75106f8179 100644 --- a/packages/nvidia_nat_core/src/nat/llm/dynamo_llm.py +++ b/packages/nvidia_nat_core/src/nat/llm/dynamo_llm.py @@ -727,43 +727,47 @@ def _create_httpx_client_with_dynamo_hooks(config: DynamoModelConfig) -> "httpx. import httpx from nat.llm.utils.httpx_utils import _create_http_client - from nat.profiler.prediction_trie import load_prediction_trie - from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup - prediction_lookup: PredictionTrieLookup | None = None - if config.nvext_prediction_trie_path: - try: - trie_path = Path(config.nvext_prediction_trie_path) - trie = load_prediction_trie(trie_path) - prediction_lookup = PredictionTrieLookup(trie) - logger.info("Loaded prediction trie from %s", config.nvext_prediction_trie_path) - except FileNotFoundError: - logger.warning("Prediction trie file not found: %s", config.nvext_prediction_trie_path) - except Exception as e: - logger.exception("Failed to load prediction trie: %s", e) - - # Create base transport and wrap with custom transport - base_transport = httpx.AsyncHTTPTransport() - dynamo_transport = _DynamoTransport( - transport=base_transport, - total_requests=config.nvext_prefix_total_requests, - osl=config.nvext_prefix_osl, - iat=config.nvext_prefix_iat, - prediction_lookup=prediction_lookup, - cache_pin_type=config.nvext_cache_pin_type, - cache_control_mode=config.nvext_cache_control_mode, - max_sensitivity=config.nvext_max_sensitivity, - ) + http_client_kwargs = {} + if config.enable_nvext_hints: + from nat.profiler.prediction_trie import load_prediction_trie + from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup - logger.info( - "Dynamo agent hints enabled: total_requests=%d, osl=%s, iat=%s, prediction_trie=%s", - config.nvext_prefix_total_requests, - config.nvext_prefix_osl, - config.nvext_prefix_iat, - "loaded" if config.nvext_prediction_trie_path else "disabled", - ) + prediction_lookup: PredictionTrieLookup | None = None + if 
config.nvext_prediction_trie_path: + try: + trie_path = Path(config.nvext_prediction_trie_path) + trie = load_prediction_trie(trie_path) + prediction_lookup = PredictionTrieLookup(trie) + logger.info("Loaded prediction trie from %s", config.nvext_prediction_trie_path) + except FileNotFoundError: + logger.warning("Prediction trie file not found: %s", config.nvext_prediction_trie_path) + except Exception as e: + logger.exception("Failed to load prediction trie: %s", e) + + # Create base transport and wrap with custom transport + base_transport = httpx.AsyncHTTPTransport() + dynamo_transport = _DynamoTransport( + transport=base_transport, + total_requests=config.nvext_prefix_total_requests, + osl=config.nvext_prefix_osl, + iat=config.nvext_prefix_iat, + prediction_lookup=prediction_lookup, + cache_pin_type=config.nvext_cache_pin_type, + cache_control_mode=config.nvext_cache_control_mode, + max_sensitivity=config.nvext_max_sensitivity, + ) + + http_client_kwargs["transport"] = dynamo_transport + logger.info( + "Dynamo agent hints enabled: total_requests=%d, osl=%s, iat=%s, prediction_trie=%s", + config.nvext_prefix_total_requests, + config.nvext_prefix_osl, + config.nvext_prefix_iat, + "loaded" if config.nvext_prediction_trie_path else "disabled", + ) - return _create_http_client(llm_config=config, use_async=True, transport=dynamo_transport) + return _create_http_client(llm_config=config, use_async=True, **http_client_kwargs) # ============================================================================= diff --git a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py index ef6189c595..3a6f85c89e 100644 --- a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py +++ b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py @@ -253,14 +253,10 @@ async def dynamo_langchain(llm_config: DynamoModelConfig, _builder: Builder): exclude_unset=True, ) - # Initialize http_async_client to 
None for proper cleanup - http_async_client = None + http_async_client = _create_httpx_client_with_dynamo_hooks(llm_config) + config_dict["http_async_client"] = http_async_client try: - if llm_config.enable_nvext_hints: - http_async_client = _create_httpx_client_with_dynamo_hooks(llm_config) - config_dict["http_async_client"] = http_async_client - # Create the ChatOpenAI client if llm_config.api_type == APITypeEnum.RESPONSES: client = ChatOpenAI(stream_usage=True, use_responses_api=True, use_previous_response_id=True, **config_dict) @@ -270,8 +266,7 @@ async def dynamo_langchain(llm_config: DynamoModelConfig, _builder: Builder): yield _patch_llm_based_on_config(client, llm_config) finally: # Ensure the httpx client is properly closed to avoid resource leaks - if http_async_client is not None: - await http_async_client.aclose() + await http_async_client.aclose() @register_llm_client(config_type=LiteLlmModelConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN) From e2371958e507bc273e1a412489f184ad3dc5e5e2 Mon Sep 17 00:00:00 2001 From: David Gardner Date: Fri, 6 Mar 2026 16:17:14 -0800 Subject: [PATCH 21/24] Update tests Signed-off-by: David Gardner --- packages/nvidia_nat_adk/tests/test_adk_llm.py | 5 +- .../nvidia_nat_core/src/nat/llm/dynamo_llm.py | 14 ++--- .../tests/nat/llm/test_dynamo_llm.py | 53 +++++++++---------- .../tests/test_dynamo_trie_loading.py | 31 ++++------- .../tests/test_llm_langchain.py | 20 ++----- 5 files changed, 44 insertions(+), 79 deletions(-) diff --git a/packages/nvidia_nat_adk/tests/test_adk_llm.py b/packages/nvidia_nat_adk/tests/test_adk_llm.py index 6ca8d4eedc..890fd2b800 100644 --- a/packages/nvidia_nat_adk/tests/test_adk_llm.py +++ b/packages/nvidia_nat_adk/tests/test_adk_llm.py @@ -189,7 +189,7 @@ def dynamo_cfg_with_prefix(self): @patch('google.adk.models.lite_llm.LiteLlm') @pytest.mark.asyncio async def test_basic_creation_without_prefix(self, mock_litellm_class, dynamo_cfg_no_prefix, mock_builder): - """Wrapper should create LiteLlm 
without client kwarg when nvext hints are disabled.""" + """Wrapper should create LiteLlm with client kwarg (no Dynamo transport when nvext hints disabled).""" mock_llm_instance = MagicMock() mock_litellm_class.return_value = mock_llm_instance @@ -199,7 +199,8 @@ async def test_basic_creation_without_prefix(self, mock_litellm_class, dynamo_cf assert mock_litellm_class.call_args.args[0] == "test-model" assert kwargs["api_base"] == "http://localhost:8000/v1" - assert "client" not in kwargs + # Always passes a client; when enable_nvext_hints=False it has no _DynamoTransport + assert "client" in kwargs assert client is mock_llm_instance @patch('google.adk.models.lite_llm.LiteLlm') diff --git a/packages/nvidia_nat_core/src/nat/llm/dynamo_llm.py b/packages/nvidia_nat_core/src/nat/llm/dynamo_llm.py index 75106f8179..da126f255b 100644 --- a/packages/nvidia_nat_core/src/nat/llm/dynamo_llm.py +++ b/packages/nvidia_nat_core/src/nat/llm/dynamo_llm.py @@ -705,28 +705,22 @@ async def aclose(self) -> None: def _create_httpx_client_with_dynamo_hooks(config: DynamoModelConfig) -> "httpx.AsyncClient": """ - Create an httpx.AsyncClient with Dynamo hint injection via custom transport. + Create an httpx.AsyncClient, when `config.enable_nvext_hints` is True, Dynamo hint injection via custom transport + is added. This client can be passed to the OpenAI SDK or wrapped in an AsyncOpenAI client for use with LiteLLM/ADK. All hints are injected into ``nvext.agent_hints`` in the request body. Args: - total_requests: Expected number of requests for this prefix - osl: Expected output tokens (raw integer, always sent as int in agent_hints) - iat: Expected inter-arrival time in ms (raw integer, always sent as int) - timeout: HTTP request timeout in seconds - prediction_lookup: Optional PredictionTrieLookup for dynamic hint injection - cache_pin_type: Cache pinning strategy. When set, injects nvext.cache_control with TTL. Set to None to disable. 
- cache_control_mode: When to inject cache_control: 'always' or 'first_only' per prefix. - max_sensitivity: Maximum latency sensitivity for computing priority + config: LLM Config Returns: An httpx.AsyncClient configured with Dynamo hint injection. """ import httpx - from nat.llm.utils.httpx_utils import _create_http_client + from nat.llm.utils.http_client import _create_http_client http_client_kwargs = {} if config.enable_nvext_hints: diff --git a/packages/nvidia_nat_core/tests/nat/llm/test_dynamo_llm.py b/packages/nvidia_nat_core/tests/nat/llm/test_dynamo_llm.py index b88f10a754..4aad137186 100644 --- a/packages/nvidia_nat_core/tests/nat/llm/test_dynamo_llm.py +++ b/packages/nvidia_nat_core/tests/nat/llm/test_dynamo_llm.py @@ -412,13 +412,9 @@ class TestCreateHttpxClient: """Tests for _create_httpx_client_with_dynamo_hooks.""" def test_uses_custom_timeout(self): - """Test that the function uses the provided timeout.""" - client = _create_httpx_client_with_dynamo_hooks( - total_requests=10, - osl=512, - iat=250, - timeout=120.0, - ) + """Test that the function uses the provided timeout from config.""" + config = DynamoModelConfig(model_name="test", request_timeout=120.0) + client = _create_httpx_client_with_dynamo_hooks(config) assert client.timeout.connect == 120.0 assert client.timeout.read == 120.0 @@ -426,25 +422,24 @@ def test_uses_custom_timeout(self): def test_uses_default_timeout(self): """Test that the function uses default timeout when not specified.""" - client = _create_httpx_client_with_dynamo_hooks( - total_requests=10, - osl=512, - iat=250, - ) + config = DynamoModelConfig(model_name="test") + client = _create_httpx_client_with_dynamo_hooks(config) assert client.timeout.connect == 600.0 def test_creates_client_with_custom_transport(self): - """Test that _create_httpx_client_with_dynamo_hooks uses _DynamoTransport.""" + """Test that _create_httpx_client_with_dynamo_hooks uses _DynamoTransport when enable_nvext_hints=True.""" from nat.llm.dynamo_llm 
import _DynamoTransport - client = _create_httpx_client_with_dynamo_hooks( - total_requests=7, - osl=2048, - iat=50, - timeout=120.0, - prediction_lookup=None, + config = DynamoModelConfig( + model_name="test", + enable_nvext_hints=True, + nvext_prefix_total_requests=7, + nvext_prefix_osl=2048, + nvext_prefix_iat=50, + request_timeout=120.0, ) + client = _create_httpx_client_with_dynamo_hooks(config) # Verify client uses custom transport assert isinstance(client._transport, _DynamoTransport) @@ -462,12 +457,12 @@ def test_creates_client_with_cache_pin_type_none(self): """Test that _create_httpx_client_with_dynamo_hooks passes cache_pin_type=None through.""" from nat.llm.dynamo_llm import _DynamoTransport - client = _create_httpx_client_with_dynamo_hooks( - total_requests=10, - osl=512, - iat=250, - cache_pin_type=None, + config = DynamoModelConfig( + model_name="test", + enable_nvext_hints=True, + nvext_cache_pin_type=None, ) + client = _create_httpx_client_with_dynamo_hooks(config) assert isinstance(client._transport, _DynamoTransport) assert client._transport._cache_pin_type is None @@ -476,12 +471,12 @@ def test_creates_client_with_cache_control_mode_first_only(self): """Test that _create_httpx_client_with_dynamo_hooks passes cache_control_mode through.""" from nat.llm.dynamo_llm import _DynamoTransport - client = _create_httpx_client_with_dynamo_hooks( - total_requests=10, - osl=512, - iat=250, - cache_control_mode=CacheControlMode.FIRST_ONLY, + config = DynamoModelConfig( + model_name="test", + enable_nvext_hints=True, + nvext_cache_control_mode=CacheControlMode.FIRST_ONLY, ) + client = _create_httpx_client_with_dynamo_hooks(config) assert isinstance(client._transport, _DynamoTransport) assert client._transport._cache_control_mode == CacheControlMode.FIRST_ONLY diff --git a/packages/nvidia_nat_langchain/tests/test_dynamo_trie_loading.py b/packages/nvidia_nat_langchain/tests/test_dynamo_trie_loading.py index 5c388ff812..205e8e31ea 100644 --- 
a/packages/nvidia_nat_langchain/tests/test_dynamo_trie_loading.py +++ b/packages/nvidia_nat_langchain/tests/test_dynamo_trie_loading.py @@ -28,7 +28,6 @@ from nat.profiler.prediction_trie.data_models import LLMCallPrediction from nat.profiler.prediction_trie.data_models import PredictionMetrics from nat.profiler.prediction_trie.data_models import PredictionTrieNode -from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup @pytest.fixture(name="trie_file") @@ -86,7 +85,7 @@ def test_dynamo_config_with_nonexistent_trie_path(): @patch("nat.plugins.langchain.llm._create_httpx_client_with_dynamo_hooks") @patch("langchain_openai.ChatOpenAI") async def test_dynamo_langchain_loads_trie_and_passes_to_client(mock_chat, mock_create_client, trie_file, mock_builder): - """Test that dynamo_langchain loads trie from path and passes PredictionTrieLookup to httpx client.""" + """Test that dynamo_langchain calls _create_httpx_client_with_dynamo_hooks with config that has trie path.""" mock_httpx_client = MagicMock() mock_httpx_client.aclose = AsyncMock() mock_create_client.return_value = mock_httpx_client @@ -101,11 +100,8 @@ async def test_dynamo_langchain_loads_trie_and_passes_to_client(mock_chat, mock_ ) async with dynamo_langchain(config, mock_builder): - # Verify httpx client was created with prediction_lookup - mock_create_client.assert_called_once() - call_kwargs = mock_create_client.call_args.kwargs - assert "prediction_lookup" in call_kwargs - assert isinstance(call_kwargs["prediction_lookup"], PredictionTrieLookup) + mock_create_client.assert_called_once_with(config) + assert mock_create_client.call_args[0][0].nvext_prediction_trie_path == trie_file mock_httpx_client.aclose.assert_awaited_once() @@ -113,7 +109,7 @@ async def test_dynamo_langchain_loads_trie_and_passes_to_client(mock_chat, mock_ @patch("nat.plugins.langchain.llm._create_httpx_client_with_dynamo_hooks") @patch("langchain_openai.ChatOpenAI") async def 
test_dynamo_langchain_handles_nonexistent_trie_gracefully(mock_chat, mock_create_client, mock_builder): - """Test that dynamo_langchain logs warning and continues when trie file doesn't exist.""" + """Test that dynamo_langchain calls client creation with config when trie path doesn't exist.""" mock_httpx_client = MagicMock() mock_httpx_client.aclose = AsyncMock() mock_create_client.return_value = mock_httpx_client @@ -127,12 +123,8 @@ async def test_dynamo_langchain_handles_nonexistent_trie_gracefully(mock_chat, m nvext_prediction_trie_path="/nonexistent/path/trie.json", ) - # Should not raise an exception async with dynamo_langchain(config, mock_builder): - # Verify httpx client was created with prediction_lookup=None - mock_create_client.assert_called_once() - call_kwargs = mock_create_client.call_args.kwargs - assert call_kwargs["prediction_lookup"] is None + mock_create_client.assert_called_once_with(config) mock_httpx_client.aclose.assert_awaited_once() @@ -140,7 +132,7 @@ async def test_dynamo_langchain_handles_nonexistent_trie_gracefully(mock_chat, m @patch("nat.plugins.langchain.llm._create_httpx_client_with_dynamo_hooks") @patch("langchain_openai.ChatOpenAI") async def test_dynamo_langchain_no_trie_path_means_no_lookup(mock_chat, mock_create_client, mock_builder): - """Test that dynamo_langchain passes None when no trie path is configured.""" + """Test that dynamo_langchain calls client creation with config when no trie path is configured.""" mock_httpx_client = MagicMock() mock_httpx_client.aclose = AsyncMock() mock_create_client.return_value = mock_httpx_client @@ -154,9 +146,8 @@ async def test_dynamo_langchain_no_trie_path_means_no_lookup(mock_chat, mock_cre ) async with dynamo_langchain(config, mock_builder): - mock_create_client.assert_called_once() - call_kwargs = mock_create_client.call_args.kwargs - assert call_kwargs["prediction_lookup"] is None + mock_create_client.assert_called_once_with(config) + assert 
mock_create_client.call_args[0][0].nvext_prediction_trie_path is None mock_httpx_client.aclose.assert_awaited_once() @@ -183,12 +174,8 @@ async def test_dynamo_langchain_handles_invalid_trie_file_gracefully(mock_chat, nvext_prediction_trie_path=invalid_trie_path, ) - # Should not raise an exception async with dynamo_langchain(config, mock_builder): - # Verify httpx client was created with prediction_lookup=None - mock_create_client.assert_called_once() - call_kwargs = mock_create_client.call_args.kwargs - assert call_kwargs["prediction_lookup"] is None + mock_create_client.assert_called_once_with(config) mock_httpx_client.aclose.assert_awaited_once() finally: diff --git a/packages/nvidia_nat_langchain/tests/test_llm_langchain.py b/packages/nvidia_nat_langchain/tests/test_llm_langchain.py index de9f7e5fb1..e30e9722a2 100644 --- a/packages/nvidia_nat_langchain/tests/test_llm_langchain.py +++ b/packages/nvidia_nat_langchain/tests/test_llm_langchain.py @@ -25,8 +25,6 @@ from nat.builder.framework_enum import LLMFrameworkEnum from nat.data_models.llm import APITypeEnum from nat.llm.aws_bedrock_llm import AWSBedrockModelConfig -from nat.llm.dynamo_llm import CacheControlMode -from nat.llm.dynamo_llm import CachePinType from nat.llm.dynamo_llm import DynamoModelConfig from nat.llm.nim_llm import NIMModelConfig from nat.llm.openai_llm import OpenAIModelConfig @@ -212,7 +210,7 @@ def dynamo_cfg_responses_api(self): @patch("langchain_openai.ChatOpenAI") async def test_basic_creation_without_prefix(self, mock_chat, dynamo_cfg_no_prefix, mock_builder): - """Wrapper should create ChatOpenAI without custom httpx client when nvext hints disabled.""" + """Wrapper should create ChatOpenAI with httpx client (no Dynamo transport when nvext hints disabled).""" async with dynamo_langchain(dynamo_cfg_no_prefix, mock_builder) as client: mock_chat.assert_called_once() kwargs = mock_chat.call_args.kwargs @@ -220,8 +218,8 @@ async def test_basic_creation_without_prefix(self, mock_chat, 
dynamo_cfg_no_pref assert kwargs["model"] == "test-model" assert kwargs["base_url"] == "http://localhost:8000/v1" assert kwargs["stream_usage"] is True - # Should NOT have custom httpx client - assert "http_async_client" not in kwargs + # Always passes an httpx client; when enable_nvext_hints=False it has no _DynamoTransport + assert "http_async_client" in kwargs assert client is mock_chat.return_value @patch("nat.plugins.langchain.llm._create_httpx_client_with_dynamo_hooks") @@ -237,17 +235,7 @@ async def test_creation_with_prefix_template(self, mock_create_client.return_value = mock_httpx_client async with dynamo_langchain(dynamo_cfg_with_prefix, mock_builder) as client: - # Verify httpx client was created with correct parameters - mock_create_client.assert_called_once_with( - total_requests=15, - osl=2048, - iat=50, - timeout=300.0, - prediction_lookup=None, - cache_pin_type=CachePinType.EPHEMERAL, - cache_control_mode=CacheControlMode.ALWAYS, - max_sensitivity=1000, - ) + mock_create_client.assert_called_once_with(dynamo_cfg_with_prefix) # Verify ChatOpenAI was called with the custom httpx client mock_chat.assert_called_once() From 0c9acde5d1476b3b2e7f8e9bbaf4eea8004a6d70 Mon Sep 17 00:00:00 2001 From: David Gardner Date: Fri, 6 Mar 2026 16:20:07 -0800 Subject: [PATCH 22/24] Remove unused imports Signed-off-by: David Gardner --- packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py index 3a6f85c89e..7bf0be2718 100644 --- a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py +++ b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py @@ -18,7 +18,6 @@ import os from collections.abc import AsyncIterator from collections.abc import Sequence -from pathlib import Path from typing import TYPE_CHECKING from typing import Any from typing import TypeVar @@ -242,8 
+241,6 @@ async def dynamo_langchain(llm_config: DynamoModelConfig, _builder: Builder): """ from langchain_openai import ChatOpenAI - from nat.profiler.prediction_trie import load_prediction_trie - from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup # Build config dict excluding Dynamo-specific and NAT-specific fields config_dict = llm_config.model_dump( From 0f9f02e28a93f7edfcfefce5e95b42181ce7e8f7 Mon Sep 17 00:00:00 2001 From: David Gardner Date: Fri, 6 Mar 2026 16:20:41 -0800 Subject: [PATCH 23/24] formatting Signed-off-by: David Gardner --- packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py index 7bf0be2718..e51fc948a2 100644 --- a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py +++ b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py @@ -241,7 +241,6 @@ async def dynamo_langchain(llm_config: DynamoModelConfig, _builder: Builder): """ from langchain_openai import ChatOpenAI - # Build config dict excluding Dynamo-specific and NAT-specific fields config_dict = llm_config.model_dump( exclude={"type", "thinking", "api_type", *DynamoModelConfig.get_dynamo_field_names()}, From 809d9c0b085ced4dde005c3f02f8f8170fc0f72b Mon Sep 17 00:00:00 2001 From: David Gardner Date: Fri, 6 Mar 2026 17:02:03 -0800 Subject: [PATCH 24/24] Add new tests Signed-off-by: David Gardner --- .../tests/nat/llm/utils/test_hooks.py | 54 ++++++++++++++++++- 1 file changed, 53 insertions(+), 1 deletion(-) diff --git a/packages/nvidia_nat_core/tests/nat/llm/utils/test_hooks.py b/packages/nvidia_nat_core/tests/nat/llm/utils/test_hooks.py index a57d9212b0..2a512556e7 100644 --- a/packages/nvidia_nat_core/tests/nat/llm/utils/test_hooks.py +++ b/packages/nvidia_nat_core/tests/nat/llm/utils/test_hooks.py @@ -15,6 +15,7 @@ """Unit and integration tests for LLM HTTP 
event hooks.""" from unittest.mock import MagicMock +from unittest.mock import patch import pytest from pydantic import Field @@ -26,7 +27,6 @@ from nat.llm.utils.hooks import _create_metadata_injection_client -# TODO: need some tests for handling for request_timeout and verify_ssl class LLMConfig(LLMBaseConfig): pass @@ -121,6 +121,58 @@ async def test_creates_client_with_event_hooks(self): await client.aclose() + @pytest.mark.parametrize( + "llm_config,expected_timeout", + [ + (LLMConfig(), None), + (LLMConfigWithTimeout(), None), + (LLMConfigWithTimeout(request_timeout=45), 45), + ], + ids=["no_request_timeout_attr", "request_timeout_none", "request_timeout_45"], + ) + async def test_request_timeout_passed_to_client(self, llm_config, expected_timeout): + """Client receives timeout from config when request_timeout is set.""" + import httpx + captured: dict = {} + real_async_client = httpx.AsyncClient + def capture_async_client(*args, **kwargs): + captured.clear() + captured.update(kwargs) + return real_async_client(*args, **kwargs) + with patch.object(httpx, "AsyncClient", side_effect=capture_async_client): + client = _create_metadata_injection_client(llm_config=llm_config) + if expected_timeout is None: + assert "timeout" not in captured + else: + assert captured["timeout"] == expected_timeout + await client.aclose() + + @pytest.mark.parametrize( + "llm_config,expected_verify", + [ + (LLMConfig(), None), + (LLMConfigWithSSL(verify_ssl=True), True), + (LLMConfigWithSSL(verify_ssl=False), False), + ], + ids=["no_verify_ssl_attr", "verify_ssl_true", "verify_ssl_false"], + ) + async def test_verify_ssl_passed_to_client(self, llm_config, expected_verify): + """Client receives verify from config when verify_ssl is set.""" + import httpx + captured: dict = {} + real_async_client = httpx.AsyncClient + def capture_async_client(*args, **kwargs): + captured.clear() + captured.update(kwargs) + return real_async_client(*args, **kwargs) + with patch.object(httpx, 
"AsyncClient", side_effect=capture_async_client): + client = _create_metadata_injection_client(llm_config=llm_config) + if expected_verify is None: + assert "verify" not in captured + else: + assert captured["verify"] is expected_verify + await client.aclose() + class TestMetadataInjectionIntegration: """Integration tests with mock HTTP server."""