diff --git a/packages/nvidia_nat_adk/src/nat/plugins/adk/llm.py b/packages/nvidia_nat_adk/src/nat/plugins/adk/llm.py index c68a0ca251..0331391235 100644 --- a/packages/nvidia_nat_adk/src/nat/plugins/adk/llm.py +++ b/packages/nvidia_nat_adk/src/nat/plugins/adk/llm.py @@ -24,6 +24,7 @@ from nat.llm.litellm_llm import LiteLlmModelConfig from nat.llm.nim_llm import NIMModelConfig from nat.llm.openai_llm import OpenAIModelConfig +from nat.llm.utils.http_client import _handle_litellm_verify_ssl # ADK uses litellm under the hood from nat.utils.responses_api import validate_no_responses_api logger = logging.getLogger(__name__) @@ -43,15 +44,16 @@ async def azure_openai_adk(config: AzureOpenAIModelConfig, _builder: Builder): config_dict = config.model_dump( exclude={ - "type", - "max_retries", - "thinking", - "azure_endpoint", + "api_type", "azure_deployment", - "model_name", + "azure_endpoint", + "max_retries", "model", - "api_type", - "request_timeout" + "model_name", + "request_timeout", + "thinking", + "type", + "verify_ssl" }, by_alias=True, exclude_none=True, @@ -63,6 +65,7 @@ async def azure_openai_adk(config: AzureOpenAIModelConfig, _builder: Builder): config_dict["timeout"] = config.request_timeout config_dict["api_version"] = config.api_version + _handle_litellm_verify_ssl(config.verify_ssl) yield LiteLlm(f"azure/{config.azure_deployment}", **config_dict) @@ -73,8 +76,9 @@ async def litellm_adk(litellm_config: LiteLlmModelConfig, _builder: Builder): validate_no_responses_api(litellm_config, LLMFrameworkEnum.ADK) + _handle_litellm_verify_ssl(litellm_config.verify_ssl) yield LiteLlm(**litellm_config.model_dump( - exclude={"type", "max_retries", "thinking", "api_type"}, + exclude={"api_type", "max_retries", "thinking", "type", "verify_ssl"}, by_alias=True, exclude_none=True, exclude_unset=True, @@ -102,7 +106,7 @@ async def nim_adk(config: NIMModelConfig, _builder: Builder): os.environ["NVIDIA_NIM_API_KEY"] = api_key config_dict = config.model_dump( - 
exclude={"type", "max_retries", "thinking", "model_name", "model", "base_url", "api_type"}, + exclude={"api_type", "base_url", "max_retries", "model", "model_name", "thinking", "type", "verify_ssl"}, by_alias=True, exclude_none=True, exclude_unset=True, @@ -110,6 +114,8 @@ async def nim_adk(config: NIMModelConfig, _builder: Builder): if config.base_url: config_dict["api_base"] = config.base_url + _handle_litellm_verify_ssl(config.verify_ssl) + yield LiteLlm(f"nvidia_nim/{config.model_name}", **config_dict) @@ -126,7 +132,17 @@ async def openai_adk(config: OpenAIModelConfig, _builder: Builder): validate_no_responses_api(config, LLMFrameworkEnum.ADK) config_dict = config.model_dump( - exclude={"type", "max_retries", "thinking", "model_name", "model", "base_url", "api_type", "request_timeout"}, + exclude={ + "api_type", + "base_url", + "max_retries", + "model", + "model_name", + "request_timeout", + "thinking", + "type", + "verify_ssl" + }, by_alias=True, exclude_none=True, exclude_unset=True, @@ -139,6 +155,8 @@ async def openai_adk(config: OpenAIModelConfig, _builder: Builder): if config.request_timeout is not None: config_dict["timeout"] = config.request_timeout + _handle_litellm_verify_ssl(config.verify_ssl) + yield LiteLlm(config.model_name, **config_dict) @@ -158,8 +176,9 @@ async def dynamo_adk(config: DynamoModelConfig, _builder: Builder): import os from google.adk.models.lite_llm import LiteLlm + from openai import AsyncOpenAI - from nat.llm.dynamo_llm import create_httpx_client_with_dynamo_hooks + from nat.llm.dynamo_llm import _create_httpx_client_with_dynamo_hooks validate_no_responses_api(config, LLMFrameworkEnum.ADK) @@ -182,53 +201,20 @@ async def dynamo_adk(config: DynamoModelConfig, _builder: Builder): if config.base_url: config_dict["api_base"] = config.base_url - if config.enable_nvext_hints: - from pathlib import Path - - from openai import AsyncOpenAI - - from nat.profiler.prediction_trie import load_prediction_trie - from 
nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup - - prediction_lookup: PredictionTrieLookup | None = None - if config.nvext_prediction_trie_path: - try: - trie_path = Path(config.nvext_prediction_trie_path) - trie = load_prediction_trie(trie_path) - prediction_lookup = PredictionTrieLookup(trie) - logger.info("Loaded prediction trie from %s", config.nvext_prediction_trie_path) - except FileNotFoundError: - logger.warning("Prediction trie file not found: %s", config.nvext_prediction_trie_path) - except Exception as e: - logger.exception("Failed to load prediction trie: %s", e) - - http_client = create_httpx_client_with_dynamo_hooks( - total_requests=config.nvext_prefix_total_requests, - osl=config.nvext_prefix_osl, - iat=config.nvext_prefix_iat, - timeout=config.request_timeout, - prediction_lookup=prediction_lookup, - cache_pin_type=config.nvext_cache_pin_type, - cache_control_mode=config.nvext_cache_control_mode, - max_sensitivity=config.nvext_max_sensitivity, - ) - - api_key = (config.api_key.get_secret_value() if config.api_key else os.getenv("OPENAI_API_KEY", "unused")) - base_url = config.base_url or os.getenv("OPENAI_BASE_URL", "http://localhost:8000/v1") - - openai_client = AsyncOpenAI( - api_key=api_key, - base_url=base_url, - http_client=http_client, - ) - config_dict["client"] = openai_client - - logger.info( - "Dynamo agent hints enabled for ADK: total_requests=%d, osl=%s, iat=%s, prediction_trie=%s", - config.nvext_prefix_total_requests, - config.nvext_prefix_osl, - config.nvext_prefix_iat, - "loaded" if prediction_lookup else "disabled", - ) + http_client = _create_httpx_client_with_dynamo_hooks(config) - yield LiteLlm(config.model_name, **config_dict) + api_key = (config.api_key.get_secret_value() if config.api_key else os.getenv("OPENAI_API_KEY", "unused")) + base_url = config.base_url or os.getenv("OPENAI_BASE_URL", "http://localhost:8000/v1") + + openai_client = AsyncOpenAI( + api_key=api_key, + base_url=base_url, + 
http_client=http_client, + ) + config_dict["client"] = openai_client + + try: + yield LiteLlm(config.model_name, **config_dict) + finally: + if http_client is not None: + await http_client.aclose() diff --git a/packages/nvidia_nat_adk/tests/test_adk_llm.py b/packages/nvidia_nat_adk/tests/test_adk_llm.py index 6ca8d4eedc..890fd2b800 100644 --- a/packages/nvidia_nat_adk/tests/test_adk_llm.py +++ b/packages/nvidia_nat_adk/tests/test_adk_llm.py @@ -189,7 +189,7 @@ def dynamo_cfg_with_prefix(self): @patch('google.adk.models.lite_llm.LiteLlm') @pytest.mark.asyncio async def test_basic_creation_without_prefix(self, mock_litellm_class, dynamo_cfg_no_prefix, mock_builder): - """Wrapper should create LiteLlm without client kwarg when nvext hints are disabled.""" + """Wrapper should create LiteLlm with client kwarg (no Dynamo transport when nvext hints disabled).""" mock_llm_instance = MagicMock() mock_litellm_class.return_value = mock_llm_instance @@ -199,7 +199,8 @@ async def test_basic_creation_without_prefix(self, mock_litellm_class, dynamo_cf assert mock_litellm_class.call_args.args[0] == "test-model" assert kwargs["api_base"] == "http://localhost:8000/v1" - assert "client" not in kwargs + # Always passes a client; when enable_nvext_hints=False it has no _DynamoTransport + assert "client" in kwargs assert client is mock_llm_instance @patch('google.adk.models.lite_llm.LiteLlm') diff --git a/packages/nvidia_nat_agno/src/nat/plugins/agno/llm.py b/packages/nvidia_nat_agno/src/nat/plugins/agno/llm.py index fa4b70bc63..fb8b4f74b7 100644 --- a/packages/nvidia_nat_agno/src/nat/plugins/agno/llm.py +++ b/packages/nvidia_nat_agno/src/nat/plugins/agno/llm.py @@ -14,7 +14,7 @@ # limitations under the License. 
import os -from typing import TypeVar +import typing from nat.builder.builder import Builder from nat.builder.framework_enum import LLMFrameworkEnum @@ -27,6 +27,7 @@ from nat.llm.litellm_llm import LiteLlmModelConfig from nat.llm.nim_llm import NIMModelConfig from nat.llm.openai_llm import OpenAIModelConfig +from nat.llm.utils.http_client import _create_http_client from nat.llm.utils.thinking import BaseThinkingInjector from nat.llm.utils.thinking import FunctionArgumentWrapper from nat.llm.utils.thinking import patch_with_thinking @@ -34,7 +35,7 @@ from nat.utils.responses_api import validate_no_responses_api from nat.utils.type_utils import override -ModelType = TypeVar("ModelType") +ModelType = typing.TypeVar("ModelType") def _patch_llm_based_on_config(client: ModelType, llm_config: LLMBaseConfig) -> ModelType: @@ -88,14 +89,24 @@ async def nim_agno(llm_config: NIMModelConfig, _builder: Builder): config_obj = { **llm_config.model_dump( - exclude={"type", "model_name", "thinking", "api_type"}, + exclude={ + "api_type", + "model_name", + "thinking", + "type", + "verify_ssl", + }, by_alias=True, exclude_none=True, exclude_unset=True, ), + "http_client": + _create_http_client(llm_config), + "id": + llm_config.model_name } - client = Nvidia(**config_obj, id=llm_config.model_name) + client = Nvidia(**config_obj) yield _patch_llm_based_on_config(client, llm_config) @@ -108,11 +119,22 @@ async def openai_agno(llm_config: OpenAIModelConfig, _builder: Builder): config_obj = { **llm_config.model_dump( - exclude={"type", "model_name", "thinking", "api_type", "api_key", "base_url", "request_timeout"}, + exclude={ + "api_key", + "api_type", + "base_url", + "model_name", + "request_timeout", + "thinking", + "type", + "verify_ssl", + }, by_alias=True, exclude_none=True, exclude_unset=True, ), + "http_client": + _create_http_client(llm_config), } if (api_key := get_secret_value(llm_config.api_key) or os.getenv("OPENAI_API_KEY")): @@ -139,11 +161,18 @@ async def 
litellm_agno(llm_config: LiteLlmModelConfig, _builder: Builder): client = LiteLLM( **llm_config.model_dump( - exclude={"type", "thinking", "model_name", "api_type"}, + exclude={ + "api_type", + "model_name", + "thinking", + "type", + "verify_ssl", + }, by_alias=True, exclude_none=True, exclude_unset=True, ), + http_client=_create_http_client(llm_config), id=llm_config.model_name, ) diff --git a/packages/nvidia_nat_autogen/src/nat/plugins/autogen/llm.py b/packages/nvidia_nat_autogen/src/nat/plugins/autogen/llm.py index 7fb64ed6c7..02727e9da4 100644 --- a/packages/nvidia_nat_autogen/src/nat/plugins/autogen/llm.py +++ b/packages/nvidia_nat_autogen/src/nat/plugins/autogen/llm.py @@ -48,6 +48,7 @@ from nat.llm.litellm_llm import LiteLlmModelConfig from nat.llm.nim_llm import NIMModelConfig from nat.llm.openai_llm import OpenAIModelConfig +from nat.llm.utils.http_client import _create_http_client from nat.llm.utils.thinking import BaseThinkingInjector from nat.llm.utils.thinking import FunctionArgumentWrapper from nat.llm.utils.thinking import patch_with_thinking @@ -147,10 +148,20 @@ async def openai_autogen(llm_config: OpenAIModelConfig, _builder: Builder) -> As # Extract AutoGen-compatible configuration config_obj = { **llm_config.model_dump( - exclude={"type", "model_name", "thinking", "api_key", "base_url", "request_timeout"}, + exclude={ + "api_key", + "base_url", + "model_name", + "request_timeout", + "thinking", + "type", + "verify_ssl", + }, by_alias=True, exclude_none=True, ), + "http_client": + _create_http_client(llm_config) } if (api_key := get_secret_value(llm_config.api_key) or os.getenv("OPENAI_API_KEY")): @@ -201,12 +212,14 @@ async def azure_openai_autogen(llm_config: AzureOpenAIModelConfig, config_obj = { "api_key": llm_config.api_key, - "base_url": - f"{llm_config.azure_endpoint}/openai/deployments/{llm_config.azure_deployment}", "api_version": llm_config.api_version, + "base_url": + 
f"{llm_config.azure_endpoint}/openai/deployments/{llm_config.azure_deployment}", + "http_client": + _create_http_client(llm_config), **llm_config.model_dump( - exclude={"type", "azure_deployment", "thinking", "azure_endpoint", "api_version", "request_timeout"}, + exclude={"api_version", "azure_deployment", "azure_endpoint", "request_timeout", "thinking", "type"}, by_alias=True, exclude_none=True, ), @@ -326,8 +339,10 @@ async def nim_autogen(llm_config: NIMModelConfig, _builder: Builder) -> AsyncGen # Extract NIM configuration for OpenAI-compatible client config_obj = { + "http_client": + _create_http_client(llm_config), **llm_config.model_dump( - exclude={"type", "model_name", "thinking"}, + exclude={"model_name", "thinking", "type"}, by_alias=True, exclude_none=True, exclude_unset=True, @@ -386,8 +401,10 @@ async def litellm_autogen(llm_config: LiteLlmModelConfig, _builder: Builder) -> # Extract LiteLLM configuration for OpenAI-compatible client config_obj = { + "http_client": + _create_http_client(llm_config), **llm_config.model_dump( - exclude={"type", "model_name", "thinking"}, + exclude={"model_name", "thinking", "type"}, by_alias=True, exclude_none=True, exclude_unset=True, diff --git a/packages/nvidia_nat_core/src/nat/data_models/llm.py b/packages/nvidia_nat_core/src/nat/data_models/llm.py index a64422825a..344d9b0915 100644 --- a/packages/nvidia_nat_core/src/nat/data_models/llm.py +++ b/packages/nvidia_nat_core/src/nat/data_models/llm.py @@ -16,6 +16,7 @@ import typing from enum import StrEnum +from pydantic import BaseModel from pydantic import Field from .common import BaseModelRegistryTag @@ -39,3 +40,12 @@ class LLMBaseConfig(TypedBaseModel, BaseModelRegistryTag): LLMBaseConfigT = typing.TypeVar("LLMBaseConfigT", bound=LLMBaseConfig) + + +class SSLVerificationMixin(BaseModel): + """Mixin for SSL verification configuration.""" + + verify_ssl: bool = Field( + default=True, + description="Whether to verify SSL certificates when making API calls to the LLM 
provider. Defaults to True.", + ) diff --git a/packages/nvidia_nat_core/src/nat/llm/azure_openai_llm.py b/packages/nvidia_nat_core/src/nat/llm/azure_openai_llm.py index 08abc67244..23ed772536 100644 --- a/packages/nvidia_nat_core/src/nat/llm/azure_openai_llm.py +++ b/packages/nvidia_nat_core/src/nat/llm/azure_openai_llm.py @@ -23,6 +23,7 @@ from nat.cli.register_workflow import register_llm_provider from nat.data_models.common import OptionalSecretStr from nat.data_models.llm import LLMBaseConfig +from nat.data_models.llm import SSLVerificationMixin from nat.data_models.optimizable import OptimizableField from nat.data_models.optimizable import SearchSpace from nat.data_models.retry_mixin import RetryMixin @@ -33,6 +34,7 @@ class AzureOpenAIModelConfig( LLMBaseConfig, RetryMixin, ThinkingMixin, + SSLVerificationMixin, name="azure_openai", ): """An Azure OpenAI LLM provider to be used with an LLM client.""" diff --git a/packages/nvidia_nat_core/src/nat/llm/dynamo_llm.py b/packages/nvidia_nat_core/src/nat/llm/dynamo_llm.py index f7e03a05c5..da126f255b 100644 --- a/packages/nvidia_nat_core/src/nat/llm/dynamo_llm.py +++ b/packages/nvidia_nat_core/src/nat/llm/dynamo_llm.py @@ -64,6 +64,7 @@ from contextlib import contextmanager from contextvars import ContextVar from enum import StrEnum +from pathlib import Path from typing import TYPE_CHECKING import httpx @@ -702,55 +703,65 @@ async def aclose(self) -> None: # ============================================================================= -def create_httpx_client_with_dynamo_hooks( - total_requests: int, - osl: int, - iat: int, - timeout: float = 600.0, - prediction_lookup: "PredictionTrieLookup | None" = None, - cache_pin_type: CachePinType | None = CachePinType.EPHEMERAL, - cache_control_mode: CacheControlMode = CacheControlMode.ALWAYS, - max_sensitivity: int = 1000, -) -> "httpx.AsyncClient": +def _create_httpx_client_with_dynamo_hooks(config: DynamoModelConfig) -> "httpx.AsyncClient": """ - Create an 
httpx.AsyncClient with Dynamo hint injection via custom transport. + Create an httpx.AsyncClient, when `config.enable_nvext_hints` is True, Dynamo hint injection via custom transport + is added. This client can be passed to the OpenAI SDK or wrapped in an AsyncOpenAI client for use with LiteLLM/ADK. All hints are injected into ``nvext.agent_hints`` in the request body. Args: - total_requests: Expected number of requests for this prefix - osl: Expected output tokens (raw integer, always sent as int in agent_hints) - iat: Expected inter-arrival time in ms (raw integer, always sent as int) - timeout: HTTP request timeout in seconds - prediction_lookup: Optional PredictionTrieLookup for dynamic hint injection - cache_pin_type: Cache pinning strategy. When set, injects nvext.cache_control with TTL. Set to None to disable. - cache_control_mode: When to inject cache_control: 'always' or 'first_only' per prefix. - max_sensitivity: Maximum latency sensitivity for computing priority + config: LLM Config Returns: An httpx.AsyncClient configured with Dynamo hint injection. 
""" import httpx - # Create base transport and wrap with custom transport - base_transport = httpx.AsyncHTTPTransport() - dynamo_transport = _DynamoTransport( - transport=base_transport, - total_requests=total_requests, - osl=osl, - iat=iat, - prediction_lookup=prediction_lookup, - cache_pin_type=cache_pin_type, - cache_control_mode=cache_control_mode, - max_sensitivity=max_sensitivity, - ) + from nat.llm.utils.http_client import _create_http_client - return httpx.AsyncClient( - transport=dynamo_transport, - timeout=httpx.Timeout(timeout), - ) + http_client_kwargs = {} + if config.enable_nvext_hints: + from nat.profiler.prediction_trie import load_prediction_trie + from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup + + prediction_lookup: PredictionTrieLookup | None = None + if config.nvext_prediction_trie_path: + try: + trie_path = Path(config.nvext_prediction_trie_path) + trie = load_prediction_trie(trie_path) + prediction_lookup = PredictionTrieLookup(trie) + logger.info("Loaded prediction trie from %s", config.nvext_prediction_trie_path) + except FileNotFoundError: + logger.warning("Prediction trie file not found: %s", config.nvext_prediction_trie_path) + except Exception as e: + logger.exception("Failed to load prediction trie: %s", e) + + # Create base transport and wrap with custom transport + base_transport = httpx.AsyncHTTPTransport() + dynamo_transport = _DynamoTransport( + transport=base_transport, + total_requests=config.nvext_prefix_total_requests, + osl=config.nvext_prefix_osl, + iat=config.nvext_prefix_iat, + prediction_lookup=prediction_lookup, + cache_pin_type=config.nvext_cache_pin_type, + cache_control_mode=config.nvext_cache_control_mode, + max_sensitivity=config.nvext_max_sensitivity, + ) + + http_client_kwargs["transport"] = dynamo_transport + logger.info( + "Dynamo agent hints enabled: total_requests=%d, osl=%s, iat=%s, prediction_trie=%s", + config.nvext_prefix_total_requests, + config.nvext_prefix_osl, + 
config.nvext_prefix_iat, + "loaded" if config.nvext_prediction_trie_path else "disabled", + ) + + return _create_http_client(llm_config=config, use_async=True, **http_client_kwargs) # ============================================================================= diff --git a/packages/nvidia_nat_core/src/nat/llm/litellm_llm.py b/packages/nvidia_nat_core/src/nat/llm/litellm_llm.py index 892bb3a2a6..95f38e6224 100644 --- a/packages/nvidia_nat_core/src/nat/llm/litellm_llm.py +++ b/packages/nvidia_nat_core/src/nat/llm/litellm_llm.py @@ -24,6 +24,7 @@ from nat.cli.register_workflow import register_llm_provider from nat.data_models.common import OptionalSecretStr from nat.data_models.llm import LLMBaseConfig +from nat.data_models.llm import SSLVerificationMixin from nat.data_models.optimizable import OptimizableField from nat.data_models.optimizable import OptimizableMixin from nat.data_models.optimizable import SearchSpace @@ -36,6 +37,7 @@ class LiteLlmModelConfig( OptimizableMixin, RetryMixin, ThinkingMixin, + SSLVerificationMixin, name="litellm", ): """A LiteLlm provider to be used with an LLM client.""" diff --git a/packages/nvidia_nat_core/src/nat/llm/nim_llm.py b/packages/nvidia_nat_core/src/nat/llm/nim_llm.py index a3257b84d5..c701e42caa 100644 --- a/packages/nvidia_nat_core/src/nat/llm/nim_llm.py +++ b/packages/nvidia_nat_core/src/nat/llm/nim_llm.py @@ -23,6 +23,7 @@ from nat.cli.register_workflow import register_llm_provider from nat.data_models.common import OptionalSecretStr from nat.data_models.llm import LLMBaseConfig +from nat.data_models.llm import SSLVerificationMixin from nat.data_models.optimizable import OptimizableField from nat.data_models.optimizable import OptimizableMixin from nat.data_models.optimizable import SearchSpace @@ -30,7 +31,7 @@ from nat.data_models.thinking_mixin import ThinkingMixin -class NIMModelConfig(LLMBaseConfig, RetryMixin, OptimizableMixin, ThinkingMixin, name="nim"): +class NIMModelConfig(LLMBaseConfig, RetryMixin, 
OptimizableMixin, ThinkingMixin, SSLVerificationMixin, name="nim"): """An NVIDIA Inference Microservice (NIM) llm provider to be used with an LLM client.""" model_config = ConfigDict(protected_namespaces=(), extra="allow") diff --git a/packages/nvidia_nat_core/src/nat/llm/openai_llm.py b/packages/nvidia_nat_core/src/nat/llm/openai_llm.py index 4e32075d41..a268bd77d1 100644 --- a/packages/nvidia_nat_core/src/nat/llm/openai_llm.py +++ b/packages/nvidia_nat_core/src/nat/llm/openai_llm.py @@ -22,6 +22,7 @@ from nat.cli.register_workflow import register_llm_provider from nat.data_models.common import OptionalSecretStr from nat.data_models.llm import LLMBaseConfig +from nat.data_models.llm import SSLVerificationMixin from nat.data_models.optimizable import OptimizableField from nat.data_models.optimizable import OptimizableMixin from nat.data_models.optimizable import SearchSpace @@ -29,7 +30,8 @@ from nat.data_models.thinking_mixin import ThinkingMixin -class OpenAIModelConfig(LLMBaseConfig, RetryMixin, OptimizableMixin, ThinkingMixin, name="openai"): +class OpenAIModelConfig(LLMBaseConfig, RetryMixin, OptimizableMixin, ThinkingMixin, SSLVerificationMixin, + name="openai"): """An OpenAI LLM provider to be used with an LLM client.""" model_config = ConfigDict(protected_namespaces=(), extra="allow") diff --git a/packages/nvidia_nat_core/src/nat/llm/utils/hooks.py b/packages/nvidia_nat_core/src/nat/llm/utils/hooks.py index b0559f2609..5699410584 100644 --- a/packages/nvidia_nat_core/src/nat/llm/utils/hooks.py +++ b/packages/nvidia_nat_core/src/nat/llm/utils/hooks.py @@ -26,12 +26,15 @@ if TYPE_CHECKING: import httpx + from nat.data_models.llm import LLMBaseConfig + from nat.llm.utils.constants import LLMHeaderPrefix +from nat.llm.utils.http_client import _create_http_client logger = logging.getLogger(__name__) -def create_metadata_injection_client(timeout: float = 600.0) -> "httpx.AsyncClient": +def _create_metadata_injection_client(llm_config: "LLMBaseConfig") -> 
"httpx.AsyncClient": """ Httpx event hook that injects custom metadata as HTTP headers. @@ -39,7 +42,7 @@ def create_metadata_injection_client(timeout: float = 600.0) -> "httpx.AsyncClie enabling end-to-end traceability in LLM server logs. Args: - timeout: HTTP request timeout in seconds + llm_config: LLM configuration object Returns: An httpx.AsyncClient configured with metadata header injection @@ -63,7 +66,4 @@ async def on_request(request: httpx.Request) -> None: except Exception as e: logger.debug("Could not inject custom metadata headers, request will proceed without them: %s", e) - return httpx.AsyncClient( - event_hooks={"request": [on_request]}, - timeout=httpx.Timeout(timeout), - ) + return _create_http_client(llm_config=llm_config, use_async=True, event_hooks={"request": [on_request]}) diff --git a/packages/nvidia_nat_core/src/nat/llm/utils/http_client.py b/packages/nvidia_nat_core/src/nat/llm/utils/http_client.py new file mode 100644 index 0000000000..9d9bd75552 --- /dev/null +++ b/packages/nvidia_nat_core/src/nat/llm/utils/http_client.py @@ -0,0 +1,50 @@ +# SPDX-FileCopyrightText: Copyright (c) `2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import httpx + + from nat.data_models.llm import LLMBaseConfig + + +def _create_http_client(llm_config: "LLMBaseConfig", + use_async: bool = True, + **kwargs) -> "httpx.AsyncClient | httpx.Client": + """Create an httpx.AsyncClient based on LLM configuration.""" + import httpx + + def _set_kwarg(kwarg_name: str, config_attr: str): + if kwarg_name not in kwargs and getattr(llm_config, config_attr, None) is not None: + kwargs[kwarg_name] = getattr(llm_config, config_attr) + + _set_kwarg("verify", "verify_ssl") + _set_kwarg("timeout", "request_timeout") + + if use_async: + client_class = httpx.AsyncClient + else: + client_class = httpx.Client + + return client_class(**kwargs) + + +def _handle_litellm_verify_ssl(verify_ssl: bool) -> None: + if not verify_ssl: + # Currently litellm does not support disabling this on a per-LLM basis for any backend other than Bedrock and + # AIM Guardrail. + import litellm + litellm.ssl_verify = False diff --git a/packages/nvidia_nat_core/tests/nat/llm/test_dynamo_llm.py b/packages/nvidia_nat_core/tests/nat/llm/test_dynamo_llm.py index 1c6d292742..4aad137186 100644 --- a/packages/nvidia_nat_core/tests/nat/llm/test_dynamo_llm.py +++ b/packages/nvidia_nat_core/tests/nat/llm/test_dynamo_llm.py @@ -23,7 +23,7 @@ from nat.llm.dynamo_llm import CachePinType from nat.llm.dynamo_llm import DynamoModelConfig from nat.llm.dynamo_llm import DynamoPrefixContext -from nat.llm.dynamo_llm import create_httpx_client_with_dynamo_hooks +from nat.llm.dynamo_llm import _create_httpx_client_with_dynamo_hooks # --------------------------------------------------------------------------- # DynamoModelConfig Tests @@ -409,16 +409,12 @@ async def test_prefix_id_consistent_across_transport_requests(self): class TestCreateHttpxClient: - """Tests for create_httpx_client_with_dynamo_hooks.""" + """Tests for _create_httpx_client_with_dynamo_hooks.""" def 
test_uses_custom_timeout(self): - """Test that the function uses the provided timeout.""" - client = create_httpx_client_with_dynamo_hooks( - total_requests=10, - osl=512, - iat=250, - timeout=120.0, - ) + """Test that the function uses the provided timeout from config.""" + config = DynamoModelConfig(model_name="test", request_timeout=120.0) + client = _create_httpx_client_with_dynamo_hooks(config) assert client.timeout.connect == 120.0 assert client.timeout.read == 120.0 @@ -426,25 +422,24 @@ def test_uses_custom_timeout(self): def test_uses_default_timeout(self): """Test that the function uses default timeout when not specified.""" - client = create_httpx_client_with_dynamo_hooks( - total_requests=10, - osl=512, - iat=250, - ) + config = DynamoModelConfig(model_name="test") + client = _create_httpx_client_with_dynamo_hooks(config) assert client.timeout.connect == 600.0 def test_creates_client_with_custom_transport(self): - """Test that create_httpx_client_with_dynamo_hooks uses _DynamoTransport.""" + """Test that _create_httpx_client_with_dynamo_hooks uses _DynamoTransport when enable_nvext_hints=True.""" from nat.llm.dynamo_llm import _DynamoTransport - client = create_httpx_client_with_dynamo_hooks( - total_requests=7, - osl=2048, - iat=50, - timeout=120.0, - prediction_lookup=None, + config = DynamoModelConfig( + model_name="test", + enable_nvext_hints=True, + nvext_prefix_total_requests=7, + nvext_prefix_osl=2048, + nvext_prefix_iat=50, + request_timeout=120.0, ) + client = _create_httpx_client_with_dynamo_hooks(config) # Verify client uses custom transport assert isinstance(client._transport, _DynamoTransport) @@ -459,29 +454,29 @@ def test_creates_client_with_custom_transport(self): assert client.timeout.read == 120.0 def test_creates_client_with_cache_pin_type_none(self): - """Test that create_httpx_client_with_dynamo_hooks passes cache_pin_type=None through.""" + """Test that _create_httpx_client_with_dynamo_hooks passes cache_pin_type=None through.""" 
from nat.llm.dynamo_llm import _DynamoTransport - client = create_httpx_client_with_dynamo_hooks( - total_requests=10, - osl=512, - iat=250, - cache_pin_type=None, + config = DynamoModelConfig( + model_name="test", + enable_nvext_hints=True, + nvext_cache_pin_type=None, ) + client = _create_httpx_client_with_dynamo_hooks(config) assert isinstance(client._transport, _DynamoTransport) assert client._transport._cache_pin_type is None def test_creates_client_with_cache_control_mode_first_only(self): - """Test that create_httpx_client_with_dynamo_hooks passes cache_control_mode through.""" + """Test that _create_httpx_client_with_dynamo_hooks passes cache_control_mode through.""" from nat.llm.dynamo_llm import _DynamoTransport - client = create_httpx_client_with_dynamo_hooks( - total_requests=10, - osl=512, - iat=250, - cache_control_mode=CacheControlMode.FIRST_ONLY, + config = DynamoModelConfig( + model_name="test", + enable_nvext_hints=True, + nvext_cache_control_mode=CacheControlMode.FIRST_ONLY, ) + client = _create_httpx_client_with_dynamo_hooks(config) assert isinstance(client._transport, _DynamoTransport) assert client._transport._cache_control_mode == CacheControlMode.FIRST_ONLY diff --git a/packages/nvidia_nat_core/tests/nat/llm/utils/test_hooks.py b/packages/nvidia_nat_core/tests/nat/llm/utils/test_hooks.py index d645fc687c..2a512556e7 100644 --- a/packages/nvidia_nat_core/tests/nat/llm/utils/test_hooks.py +++ b/packages/nvidia_nat_core/tests/nat/llm/utils/test_hooks.py @@ -15,12 +15,28 @@ """Unit and integration tests for LLM HTTP event hooks.""" from unittest.mock import MagicMock +from unittest.mock import patch import pytest +from pydantic import Field from pytest_httpserver import HTTPServer from nat.builder.context import ContextState -from nat.llm.utils.hooks import create_metadata_injection_client +from nat.data_models.llm import LLMBaseConfig +from nat.data_models.llm import SSLVerificationMixin +from nat.llm.utils.hooks import 
_create_metadata_injection_client + + +class LLMConfig(LLMBaseConfig): + pass + + +class LLMConfigWithTimeout(LLMBaseConfig): + request_timeout: float | None = Field(default=None, gt=0.0, description="HTTP request timeout in seconds.") + + +class LLMConfigWithSSL(LLMConfigWithTimeout, SSLVerificationMixin): + pass class TestMetadataInjectionHook: @@ -46,7 +62,7 @@ def fixture_mock_input_message(self): async def test_hook_injects_metadata_fields(self, mock_httpx_request, mock_input_message): """Test that the hook injects custom metadata fields as headers.""" - client = create_metadata_injection_client() + client = _create_metadata_injection_client(llm_config=LLMConfig()) hook = client.event_hooks["request"][0] context_state = ContextState.get() @@ -67,7 +83,7 @@ async def test_hook_skips_none_values(self, mock_httpx_request, mock_input_messa "optional_field": None, } - client = create_metadata_injection_client() + client = _create_metadata_injection_client(llm_config=LLMConfig()) hook = client.event_hooks["request"][0] context_state = ContextState.get() @@ -82,7 +98,7 @@ async def test_hook_skips_none_values(self, mock_httpx_request, mock_input_messa async def test_hook_handles_missing_context(self, mock_httpx_request): """Test that hook handles missing context gracefully.""" - client = create_metadata_injection_client() + client = _create_metadata_injection_client(llm_config=LLMConfig()) hook = client.event_hooks["request"][0] await hook(mock_httpx_request) @@ -94,17 +110,69 @@ async def test_hook_handles_missing_context(self, mock_httpx_request): class TestCreateMetadataInjectionClient: - """Unit tests for create_metadata_injection_client function.""" + """Unit tests for _create_metadata_injection_client function.""" async def test_creates_client_with_event_hooks(self): """Test that client is created with event hooks.""" - client = create_metadata_injection_client() + client = _create_metadata_injection_client(llm_config=LLMConfig()) assert "request" in 
client.event_hooks assert len(client.event_hooks["request"]) == 1 await client.aclose() + @pytest.mark.parametrize( + "llm_config,expected_timeout", + [ + (LLMConfig(), None), + (LLMConfigWithTimeout(), None), + (LLMConfigWithTimeout(request_timeout=45), 45), + ], + ids=["no_request_timeout_attr", "request_timeout_none", "request_timeout_45"], + ) + async def test_request_timeout_passed_to_client(self, llm_config, expected_timeout): + """Client receives timeout from config when request_timeout is set.""" + import httpx + captured: dict = {} + real_async_client = httpx.AsyncClient + def capture_async_client(*args, **kwargs): + captured.clear() + captured.update(kwargs) + return real_async_client(*args, **kwargs) + with patch.object(httpx, "AsyncClient", side_effect=capture_async_client): + client = _create_metadata_injection_client(llm_config=llm_config) + if expected_timeout is None: + assert "timeout" not in captured + else: + assert captured["timeout"] == expected_timeout + await client.aclose() + + @pytest.mark.parametrize( + "llm_config,expected_verify", + [ + (LLMConfig(), None), + (LLMConfigWithSSL(verify_ssl=True), True), + (LLMConfigWithSSL(verify_ssl=False), False), + ], + ids=["no_verify_ssl_attr", "verify_ssl_true", "verify_ssl_false"], + ) + async def test_verify_ssl_passed_to_client(self, llm_config, expected_verify): + """Client receives verify from config when verify_ssl is set.""" + import httpx + captured: dict = {} + real_async_client = httpx.AsyncClient + def capture_async_client(*args, **kwargs): + captured.clear() + captured.update(kwargs) + return real_async_client(*args, **kwargs) + with patch.object(httpx, "AsyncClient", side_effect=capture_async_client): + client = _create_metadata_injection_client(llm_config=llm_config) + if expected_verify is None: + assert "verify" not in captured + else: + assert captured["verify"] is expected_verify + await client.aclose() + class TestMetadataInjectionIntegration: """Integration tests with mock HTTP 
server.""" @@ -139,7 +207,7 @@ async def test_headers_sent_in_http_request(self, httpserver: HTTPServer, mock_i } }) - client = create_metadata_injection_client() + client = _create_metadata_injection_client(llm_config=LLMConfig()) context_state = ContextState.get() context_state.input_message.set(mock_input_message) @@ -181,7 +249,7 @@ async def test_request_succeeds_without_context(self, httpserver: HTTPServer): } }) - client = create_metadata_injection_client() + client = _create_metadata_injection_client(llm_config=LLMConfig()) response = await client.post(httpserver.url_for("/v1/chat/completions"), json={ diff --git a/packages/nvidia_nat_crewai/src/nat/plugins/crewai/llm.py b/packages/nvidia_nat_crewai/src/nat/plugins/crewai/llm.py index 8d54ff6d71..faa8b25b3d 100644 --- a/packages/nvidia_nat_crewai/src/nat/plugins/crewai/llm.py +++ b/packages/nvidia_nat_crewai/src/nat/plugins/crewai/llm.py @@ -27,6 +27,7 @@ from nat.llm.litellm_llm import LiteLlmModelConfig from nat.llm.nim_llm import NIMModelConfig from nat.llm.openai_llm import OpenAIModelConfig +from nat.llm.utils.http_client import _handle_litellm_verify_ssl # crewAI uses litellm under the hood from nat.llm.utils.thinking import BaseThinkingInjector from nat.llm.utils.thinking import FunctionArgumentWrapper from nat.llm.utils.thinking import patch_with_thinking @@ -114,6 +115,8 @@ async def azure_openai_crewai(llm_config: AzureOpenAIModelConfig, _builder: Buil if llm_config.request_timeout is not None: config_dict["timeout"] = llm_config.request_timeout + _handle_litellm_verify_ssl(llm_config.verify_ssl) + client = LLM( **config_dict, model=model, @@ -136,6 +139,8 @@ async def nim_crewai(llm_config: NIMModelConfig, _builder: Builder): if nvidia_api_key is not None: os.environ["NVIDIA_NIM_API_KEY"] = nvidia_api_key + _handle_litellm_verify_ssl(llm_config.verify_ssl) + client = LLM( **llm_config.model_dump( exclude={"type", "model_name", "thinking", "api_type"}, @@ -163,6 +168,8 @@ async def 
openai_crewai(llm_config: OpenAIModelConfig, _builder: Builder): exclude_unset=True, ) + _handle_litellm_verify_ssl(llm_config.verify_ssl) + if (api_key := get_secret_value(llm_config.api_key) or os.getenv("OPENAI_API_KEY")): config_dict["api_key"] = api_key if (base_url := llm_config.base_url or os.getenv("OPENAI_BASE_URL")): @@ -182,6 +189,8 @@ async def litellm_crewai(llm_config: LiteLlmModelConfig, _builder: Builder): validate_no_responses_api(llm_config, LLMFrameworkEnum.CREWAI) + _handle_litellm_verify_ssl(llm_config.verify_ssl) + client = LLM(**llm_config.model_dump( exclude={"type", "thinking", "api_type"}, by_alias=True, exclude_none=True, exclude_unset=True)) diff --git a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py index 5af5fe861e..e51fc948a2 100644 --- a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py +++ b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py @@ -18,32 +18,27 @@ import os from collections.abc import AsyncIterator from collections.abc import Sequence -from pathlib import Path from typing import TYPE_CHECKING from typing import Any from typing import TypeVar -if TYPE_CHECKING: - import httpx - from nat.builder.builder import Builder from nat.builder.framework_enum import LLMFrameworkEnum from nat.cli.register_workflow import register_llm_client from nat.data_models.common import get_secret_value from nat.data_models.llm import APITypeEnum -from nat.data_models.llm import LLMBaseConfig from nat.data_models.retry_mixin import RetryMixin from nat.data_models.thinking_mixin import ThinkingMixin from nat.llm.aws_bedrock_llm import AWSBedrockModelConfig from nat.llm.azure_openai_llm import AzureOpenAIModelConfig from nat.llm.dynamo_llm import DynamoModelConfig -from nat.llm.dynamo_llm import create_httpx_client_with_dynamo_hooks +from nat.llm.dynamo_llm import _create_httpx_client_with_dynamo_hooks from nat.llm.huggingface_inference_llm 
import HuggingFaceInferenceLLMConfig from nat.llm.huggingface_llm import HuggingFaceConfig from nat.llm.litellm_llm import LiteLlmModelConfig from nat.llm.nim_llm import NIMModelConfig from nat.llm.openai_llm import OpenAIModelConfig -from nat.llm.utils.hooks import create_metadata_injection_client +from nat.llm.utils.hooks import _create_metadata_injection_client from nat.llm.utils.thinking import BaseThinkingInjector from nat.llm.utils.thinking import FunctionArgumentWrapper from nat.llm.utils.thinking import patch_with_thinking @@ -51,12 +46,17 @@ from nat.utils.responses_api import validate_no_responses_api from nat.utils.type_utils import override +if TYPE_CHECKING: + import httpx + + from nat.data_models.llm import LLMBaseConfig + logger = logging.getLogger(__name__) ModelType = TypeVar("ModelType") -def _patch_llm_based_on_config(client: ModelType, llm_config: LLMBaseConfig) -> ModelType: +def _patch_llm_based_on_config(client: ModelType, llm_config: "LLMBaseConfig") -> ModelType: from langchain_core.language_models import LanguageModelInput from langchain_core.messages import BaseMessage @@ -149,17 +149,14 @@ async def azure_openai_langchain(llm_config: AzureOpenAIModelConfig, _builder: B validate_no_responses_api(llm_config, LLMFrameworkEnum.LANGCHAIN) - client_kwargs: dict = {} - if llm_config.request_timeout is not None: - client_kwargs["timeout"] = llm_config.request_timeout - http_async_client: httpx.AsyncClient = create_metadata_injection_client(**client_kwargs) + http_async_client: httpx.AsyncClient = _create_metadata_injection_client(llm_config) try: client = AzureChatOpenAI( http_async_client=http_async_client, # type: ignore[call-arg] api_version=llm_config.api_version, # type: ignore[call-arg] **llm_config.model_dump( - exclude={"type", "thinking", "api_type", "api_version"}, + exclude={"type", "thinking", "api_type", "api_version", "verify_ssl"}, by_alias=True, exclude_none=True, exclude_unset=True, @@ -181,6 +178,7 @@ async def 
nim_langchain(llm_config: NIMModelConfig, _builder: Builder): validate_no_responses_api(llm_config, LLMFrameworkEnum.LANGCHAIN) # prefer max_completion_tokens over max_tokens + # verify_ssl is a supported keyword parameter for the ChatNVIDIA client client = ChatNVIDIA( **llm_config.model_dump( exclude={"type", "max_tokens", "thinking", "api_type"}, @@ -199,13 +197,10 @@ async def openai_langchain(llm_config: OpenAIModelConfig, _builder: Builder): from langchain_openai import ChatOpenAI - client_kwargs: dict = {} - if llm_config.request_timeout is not None: - client_kwargs["timeout"] = llm_config.request_timeout - http_async_client: httpx.AsyncClient = create_metadata_injection_client(**client_kwargs) + http_async_client: httpx.AsyncClient = _create_metadata_injection_client(llm_config) config_dict = llm_config.model_dump( - exclude={"type", "thinking", "api_type", "api_key", "base_url"}, + exclude={"type", "thinking", "api_type", "api_key", "base_url", "verify_ssl"}, by_alias=True, exclude_none=True, exclude_unset=True, @@ -246,9 +241,6 @@ async def dynamo_langchain(llm_config: DynamoModelConfig, _builder: Builder): """ from langchain_openai import ChatOpenAI - from nat.profiler.prediction_trie import load_prediction_trie - from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup - # Build config dict excluding Dynamo-specific and NAT-specific fields config_dict = llm_config.model_dump( exclude={"type", "thinking", "api_type", *DynamoModelConfig.get_dynamo_field_names()}, @@ -257,43 +249,10 @@ async def dynamo_langchain(llm_config: DynamoModelConfig, _builder: Builder): exclude_unset=True, ) - # Initialize http_async_client to None for proper cleanup - http_async_client = None - - # Load prediction trie if configured - prediction_lookup: PredictionTrieLookup | None = None - if llm_config.nvext_prediction_trie_path: - try: - trie_path = Path(llm_config.nvext_prediction_trie_path) - trie = load_prediction_trie(trie_path) - prediction_lookup = 
PredictionTrieLookup(trie) - logger.info("Loaded prediction trie from %s", llm_config.nvext_prediction_trie_path) - except FileNotFoundError: - logger.warning("Prediction trie file not found: %s", llm_config.nvext_prediction_trie_path) - except Exception as e: - logger.warning("Failed to load prediction trie: %s", e) + http_async_client = _create_httpx_client_with_dynamo_hooks(llm_config) + config_dict["http_async_client"] = http_async_client try: - if llm_config.enable_nvext_hints: - http_async_client = create_httpx_client_with_dynamo_hooks( - total_requests=llm_config.nvext_prefix_total_requests, - osl=llm_config.nvext_prefix_osl, - iat=llm_config.nvext_prefix_iat, - timeout=llm_config.request_timeout, - prediction_lookup=prediction_lookup, - cache_pin_type=llm_config.nvext_cache_pin_type, - cache_control_mode=llm_config.nvext_cache_control_mode, - max_sensitivity=llm_config.nvext_max_sensitivity, - ) - config_dict["http_async_client"] = http_async_client - logger.info( - "Dynamo agent hints enabled: total_requests=%d, osl=%s, iat=%s, prediction_trie=%s", - llm_config.nvext_prefix_total_requests, - llm_config.nvext_prefix_osl, - llm_config.nvext_prefix_iat, - "loaded" if prediction_lookup else "disabled", - ) - # Create the ChatOpenAI client if llm_config.api_type == APITypeEnum.RESPONSES: client = ChatOpenAI(stream_usage=True, use_responses_api=True, use_previous_response_id=True, **config_dict) @@ -303,8 +262,7 @@ async def dynamo_langchain(llm_config: DynamoModelConfig, _builder: Builder): yield _patch_llm_based_on_config(client, llm_config) finally: # Ensure the httpx client is properly closed to avoid resource leaks - if http_async_client is not None: - await http_async_client.aclose() + await http_async_client.aclose() @register_llm_client(config_type=LiteLlmModelConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN) diff --git a/packages/nvidia_nat_langchain/tests/test_dynamo_trie_loading.py b/packages/nvidia_nat_langchain/tests/test_dynamo_trie_loading.py index 
a386238a77..205e8e31ea 100644 --- a/packages/nvidia_nat_langchain/tests/test_dynamo_trie_loading.py +++ b/packages/nvidia_nat_langchain/tests/test_dynamo_trie_loading.py @@ -28,7 +28,6 @@ from nat.profiler.prediction_trie.data_models import LLMCallPrediction from nat.profiler.prediction_trie.data_models import PredictionMetrics from nat.profiler.prediction_trie.data_models import PredictionTrieNode -from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup @pytest.fixture(name="trie_file") @@ -83,10 +82,10 @@ def test_dynamo_config_with_nonexistent_trie_path(): assert config.nvext_prediction_trie_path == "/nonexistent/path/trie.json" -@patch("nat.plugins.langchain.llm.create_httpx_client_with_dynamo_hooks") +@patch("nat.plugins.langchain.llm._create_httpx_client_with_dynamo_hooks") @patch("langchain_openai.ChatOpenAI") async def test_dynamo_langchain_loads_trie_and_passes_to_client(mock_chat, mock_create_client, trie_file, mock_builder): - """Test that dynamo_langchain loads trie from path and passes PredictionTrieLookup to httpx client.""" + """Test that dynamo_langchain calls _create_httpx_client_with_dynamo_hooks with config that has trie path.""" mock_httpx_client = MagicMock() mock_httpx_client.aclose = AsyncMock() mock_create_client.return_value = mock_httpx_client @@ -101,19 +100,16 @@ async def test_dynamo_langchain_loads_trie_and_passes_to_client(mock_chat, mock_ ) async with dynamo_langchain(config, mock_builder): - # Verify httpx client was created with prediction_lookup - mock_create_client.assert_called_once() - call_kwargs = mock_create_client.call_args.kwargs - assert "prediction_lookup" in call_kwargs - assert isinstance(call_kwargs["prediction_lookup"], PredictionTrieLookup) + mock_create_client.assert_called_once_with(config) + assert mock_create_client.call_args[0][0].nvext_prediction_trie_path == trie_file mock_httpx_client.aclose.assert_awaited_once() -@patch("nat.plugins.langchain.llm.create_httpx_client_with_dynamo_hooks") 
+@patch("nat.plugins.langchain.llm._create_httpx_client_with_dynamo_hooks") @patch("langchain_openai.ChatOpenAI") async def test_dynamo_langchain_handles_nonexistent_trie_gracefully(mock_chat, mock_create_client, mock_builder): - """Test that dynamo_langchain logs warning and continues when trie file doesn't exist.""" + """Test that dynamo_langchain calls client creation with config when trie path doesn't exist.""" mock_httpx_client = MagicMock() mock_httpx_client.aclose = AsyncMock() mock_create_client.return_value = mock_httpx_client @@ -127,20 +123,16 @@ async def test_dynamo_langchain_handles_nonexistent_trie_gracefully(mock_chat, m nvext_prediction_trie_path="/nonexistent/path/trie.json", ) - # Should not raise an exception async with dynamo_langchain(config, mock_builder): - # Verify httpx client was created with prediction_lookup=None - mock_create_client.assert_called_once() - call_kwargs = mock_create_client.call_args.kwargs - assert call_kwargs["prediction_lookup"] is None + mock_create_client.assert_called_once_with(config) mock_httpx_client.aclose.assert_awaited_once() -@patch("nat.plugins.langchain.llm.create_httpx_client_with_dynamo_hooks") +@patch("nat.plugins.langchain.llm._create_httpx_client_with_dynamo_hooks") @patch("langchain_openai.ChatOpenAI") async def test_dynamo_langchain_no_trie_path_means_no_lookup(mock_chat, mock_create_client, mock_builder): - """Test that dynamo_langchain passes None when no trie path is configured.""" + """Test that dynamo_langchain calls client creation with config when no trie path is configured.""" mock_httpx_client = MagicMock() mock_httpx_client.aclose = AsyncMock() mock_create_client.return_value = mock_httpx_client @@ -154,14 +146,13 @@ async def test_dynamo_langchain_no_trie_path_means_no_lookup(mock_chat, mock_cre ) async with dynamo_langchain(config, mock_builder): - mock_create_client.assert_called_once() - call_kwargs = mock_create_client.call_args.kwargs - assert call_kwargs["prediction_lookup"] is None 
+ mock_create_client.assert_called_once_with(config) + assert mock_create_client.call_args[0][0].nvext_prediction_trie_path is None mock_httpx_client.aclose.assert_awaited_once() -@patch("nat.plugins.langchain.llm.create_httpx_client_with_dynamo_hooks") +@patch("nat.plugins.langchain.llm._create_httpx_client_with_dynamo_hooks") @patch("langchain_openai.ChatOpenAI") async def test_dynamo_langchain_handles_invalid_trie_file_gracefully(mock_chat, mock_create_client, mock_builder): """Test that dynamo_langchain logs warning and continues when trie file is invalid JSON.""" @@ -183,12 +174,8 @@ async def test_dynamo_langchain_handles_invalid_trie_file_gracefully(mock_chat, nvext_prediction_trie_path=invalid_trie_path, ) - # Should not raise an exception async with dynamo_langchain(config, mock_builder): - # Verify httpx client was created with prediction_lookup=None - mock_create_client.assert_called_once() - call_kwargs = mock_create_client.call_args.kwargs - assert call_kwargs["prediction_lookup"] is None + mock_create_client.assert_called_once_with(config) mock_httpx_client.aclose.assert_awaited_once() finally: diff --git a/packages/nvidia_nat_langchain/tests/test_llm_langchain.py b/packages/nvidia_nat_langchain/tests/test_llm_langchain.py index d6bcbabedc..e30e9722a2 100644 --- a/packages/nvidia_nat_langchain/tests/test_llm_langchain.py +++ b/packages/nvidia_nat_langchain/tests/test_llm_langchain.py @@ -25,8 +25,6 @@ from nat.builder.framework_enum import LLMFrameworkEnum from nat.data_models.llm import APITypeEnum from nat.llm.aws_bedrock_llm import AWSBedrockModelConfig -from nat.llm.dynamo_llm import CacheControlMode -from nat.llm.dynamo_llm import CachePinType from nat.llm.dynamo_llm import DynamoModelConfig from nat.llm.nim_llm import NIMModelConfig from nat.llm.openai_llm import OpenAIModelConfig @@ -212,7 +210,7 @@ def dynamo_cfg_responses_api(self): @patch("langchain_openai.ChatOpenAI") async def test_basic_creation_without_prefix(self, mock_chat, 
dynamo_cfg_no_prefix, mock_builder): - """Wrapper should create ChatOpenAI without custom httpx client when nvext hints disabled.""" + """Wrapper should create ChatOpenAI with httpx client (no Dynamo transport when nvext hints disabled).""" async with dynamo_langchain(dynamo_cfg_no_prefix, mock_builder) as client: mock_chat.assert_called_once() kwargs = mock_chat.call_args.kwargs @@ -220,11 +218,11 @@ async def test_basic_creation_without_prefix(self, mock_chat, dynamo_cfg_no_pref assert kwargs["model"] == "test-model" assert kwargs["base_url"] == "http://localhost:8000/v1" assert kwargs["stream_usage"] is True - # Should NOT have custom httpx client - assert "http_async_client" not in kwargs + # Always passes an httpx client; when enable_nvext_hints=False it has no _DynamoTransport + assert "http_async_client" in kwargs assert client is mock_chat.return_value - @patch("nat.plugins.langchain.llm.create_httpx_client_with_dynamo_hooks") + @patch("nat.plugins.langchain.llm._create_httpx_client_with_dynamo_hooks") @patch("langchain_openai.ChatOpenAI") async def test_creation_with_prefix_template(self, mock_chat, @@ -237,17 +235,7 @@ async def test_creation_with_prefix_template(self, mock_create_client.return_value = mock_httpx_client async with dynamo_langchain(dynamo_cfg_with_prefix, mock_builder) as client: - # Verify httpx client was created with correct parameters - mock_create_client.assert_called_once_with( - total_requests=15, - osl=2048, - iat=50, - timeout=300.0, - prediction_lookup=None, - cache_pin_type=CachePinType.EPHEMERAL, - cache_control_mode=CacheControlMode.ALWAYS, - max_sensitivity=1000, - ) + mock_create_client.assert_called_once_with(dynamo_cfg_with_prefix) # Verify ChatOpenAI was called with the custom httpx client mock_chat.assert_called_once() @@ -260,7 +248,7 @@ async def test_creation_with_prefix_template(self, # Verify the httpx client was properly closed mock_httpx_client.aclose.assert_awaited_once() - 
@patch("nat.plugins.langchain.llm.create_httpx_client_with_dynamo_hooks") + @patch("nat.plugins.langchain.llm._create_httpx_client_with_dynamo_hooks") @patch("langchain_openai.ChatOpenAI") async def test_responses_api_branch(self, mock_chat, mock_create_client, dynamo_cfg_responses_api, mock_builder): """When APIType==RESPONSES, special flags should be added.""" @@ -279,7 +267,7 @@ async def test_responses_api_branch(self, mock_chat, mock_create_client, dynamo_ # Verify the httpx client was properly closed mock_httpx_client.aclose.assert_awaited_once() - @patch("nat.plugins.langchain.llm.create_httpx_client_with_dynamo_hooks") + @patch("nat.plugins.langchain.llm._create_httpx_client_with_dynamo_hooks") @patch("langchain_openai.ChatOpenAI") async def test_excludes_dynamo_specific_fields(self, mock_chat, diff --git a/packages/nvidia_nat_llama_index/src/nat/plugins/llama_index/llm.py b/packages/nvidia_nat_llama_index/src/nat/plugins/llama_index/llm.py index fb478ae3c3..bfb85a773f 100644 --- a/packages/nvidia_nat_llama_index/src/nat/plugins/llama_index/llm.py +++ b/packages/nvidia_nat_llama_index/src/nat/plugins/llama_index/llm.py @@ -14,8 +14,8 @@ # limitations under the License. 
import os +import typing from collections.abc import Sequence -from typing import TypeVar from nat.builder.builder import Builder from nat.builder.framework_enum import LLMFrameworkEnum @@ -30,6 +30,7 @@ from nat.llm.litellm_llm import LiteLlmModelConfig from nat.llm.nim_llm import NIMModelConfig from nat.llm.openai_llm import OpenAIModelConfig +from nat.llm.utils.http_client import _create_http_client from nat.llm.utils.thinking import BaseThinkingInjector from nat.llm.utils.thinking import FunctionArgumentWrapper from nat.llm.utils.thinking import patch_with_thinking @@ -37,7 +38,10 @@ from nat.utils.responses_api import validate_no_responses_api from nat.utils.type_utils import override -ModelType = TypeVar("ModelType") +if typing.TYPE_CHECKING: + import httpx + +ModelType = typing.TypeVar("ModelType") def _patch_llm_based_on_config(client: ModelType, llm_config: LLMBaseConfig) -> ModelType: @@ -81,6 +85,14 @@ def inject(self, messages: Sequence[ChatMessage], *args, **kwargs) -> FunctionAr return client +def _get_http_clients(llm_config: LLMBaseConfig) -> dict[str, "httpx.AsyncClient | httpx.Client"]: + """Get a dictionary of HTTP clients, one sync one async.""" + return { + "http_client": _create_http_client(llm_config, use_async=False), + "async_http_client": _create_http_client(llm_config, use_async=True) + } + + @register_llm_client(config_type=AWSBedrockModelConfig, wrapper_type=LLMFrameworkEnum.LLAMA_INDEX) async def aws_bedrock_llama_index(llm_config: AWSBedrockModelConfig, _builder: Builder): @@ -89,8 +101,10 @@ async def aws_bedrock_llama_index(llm_config: AWSBedrockModelConfig, _builder: B validate_no_responses_api(llm_config, LLMFrameworkEnum.LLAMA_INDEX) # LlamaIndex uses context_size instead of max_tokens - llm = Bedrock(**llm_config.model_dump( - exclude={"type", "top_p", "thinking", "api_type"}, by_alias=True, exclude_none=True, exclude_unset=True)) + llm = Bedrock(**llm_config.model_dump(exclude={"api_type", "thinking", "top_p", "type", 
"verify_ssl"}, + by_alias=True, + exclude_none=True, + exclude_unset=True)) yield _patch_llm_based_on_config(llm, llm_config) @@ -102,13 +116,15 @@ async def azure_openai_llama_index(llm_config: AzureOpenAIModelConfig, _builder: validate_no_responses_api(llm_config, LLMFrameworkEnum.LLAMA_INDEX) - config_dict = llm_config.model_dump(exclude={"type", "thinking", "api_type", "api_version", "request_timeout"}, - by_alias=True, - exclude_none=True, - exclude_unset=True) + config_dict = llm_config.model_dump( + exclude={"api_type", "api_version", "request_timeout", "thinking", "type", "verify_ssl"}, + by_alias=True, + exclude_none=True, + exclude_unset=True) if llm_config.request_timeout is not None: config_dict["timeout"] = llm_config.request_timeout + config_dict.update(_get_http_clients(llm_config)) llm = AzureOpenAI( **config_dict, api_version=llm_config.api_version, @@ -124,8 +140,20 @@ async def nim_llama_index(llm_config: NIMModelConfig, _builder: Builder): validate_no_responses_api(llm_config, LLMFrameworkEnum.LLAMA_INDEX) - llm = NVIDIA(**llm_config.model_dump( - exclude={"type", "thinking", "api_type"}, by_alias=True, exclude_none=True, exclude_unset=True)) + config_dict = llm_config.model_dump( + exclude={ + "api_type", + "thinking", + "type", + "verify_ssl", + }, + by_alias=True, + exclude_none=True, + exclude_unset=True, + ) + + config_dict.update(_get_http_clients(llm_config)) + llm = NVIDIA(**config_dict) yield _patch_llm_based_on_config(llm, llm_config) @@ -137,7 +165,7 @@ async def openai_llama_index(llm_config: OpenAIModelConfig, _builder: Builder): from llama_index.llms.openai import OpenAIResponses config_dict = llm_config.model_dump( - exclude={"type", "thinking", "api_type", "api_key", "base_url", "request_timeout"}, + exclude={"api_key", "api_type", "base_url", "request_timeout", "thinking", "type", "verify_ssl"}, by_alias=True, exclude_none=True, exclude_unset=True, @@ -151,6 +179,7 @@ async def openai_llama_index(llm_config: OpenAIModelConfig, 
_builder: Builder): if llm_config.request_timeout is not None: config_dict["timeout"] = llm_config.request_timeout + config_dict.update(_get_http_clients(llm_config)) if llm_config.api_type == APITypeEnum.RESPONSES: llm = OpenAIResponses(**config_dict) else: @@ -164,9 +193,15 @@ async def litellm_llama_index(llm_config: LiteLlmModelConfig, _builder: Builder) from llama_index.llms.litellm import LiteLLM + from nat.llm.utils.http_client import _handle_litellm_verify_ssl + + _handle_litellm_verify_ssl(llm_config.verify_ssl) validate_no_responses_api(llm_config, LLMFrameworkEnum.LLAMA_INDEX) - llm = LiteLLM(**llm_config.model_dump( - exclude={"type", "thinking", "api_type"}, by_alias=True, exclude_none=True, exclude_unset=True)) + llm = LiteLLM( + **llm_config.model_dump(exclude={"api_type", "thinking", "type", "verify_ssl"}, + by_alias=True, + exclude_none=True, + exclude_unset=True), ) yield _patch_llm_based_on_config(llm, llm_config) diff --git a/packages/nvidia_nat_semantic_kernel/src/nat/plugins/semantic_kernel/llm.py b/packages/nvidia_nat_semantic_kernel/src/nat/plugins/semantic_kernel/llm.py index 5a92d34e62..d912cc3cde 100644 --- a/packages/nvidia_nat_semantic_kernel/src/nat/plugins/semantic_kernel/llm.py +++ b/packages/nvidia_nat_semantic_kernel/src/nat/plugins/semantic_kernel/llm.py @@ -25,6 +25,7 @@ from nat.data_models.thinking_mixin import ThinkingMixin from nat.llm.azure_openai_llm import AzureOpenAIModelConfig from nat.llm.openai_llm import OpenAIModelConfig +from nat.llm.utils.http_client import _create_http_client from nat.llm.utils.thinking import BaseThinkingInjector from nat.llm.utils.thinking import FunctionArgumentWrapper from nat.llm.utils.thinking import patch_with_thinking @@ -90,18 +91,19 @@ def inject(self, chat_history: ChatHistory, *args, **kwargs) -> FunctionArgument @register_llm_client(config_type=AzureOpenAIModelConfig, wrapper_type=LLMFrameworkEnum.SEMANTIC_KERNEL) async def azure_openai_semantic_kernel(llm_config: 
AzureOpenAIModelConfig, _builder: Builder): + from openai import AsyncAzureOpenAI from semantic_kernel.connectors.ai.open_ai import AzureChatCompletion validate_no_responses_api(llm_config, LLMFrameworkEnum.SEMANTIC_KERNEL) - llm = AzureChatCompletion( - api_key=get_secret_value(llm_config.api_key), - api_version=llm_config.api_version, - endpoint=llm_config.azure_endpoint, - deployment_name=llm_config.azure_deployment, - ) + async with AsyncAzureOpenAI(api_key=get_secret_value(llm_config.api_key), + api_version=llm_config.api_version, + azure_endpoint=llm_config.azure_endpoint, + azure_deployment=llm_config.azure_deployment, + http_client=_create_http_client(llm_config, use_async=True)) as async_client: + llm = AzureChatCompletion(async_client=async_client) - yield _patch_llm_based_on_config(llm, llm_config) + yield _patch_llm_based_on_config(llm, llm_config) @register_llm_client(config_type=OpenAIModelConfig, wrapper_type=LLMFrameworkEnum.SEMANTIC_KERNEL) @@ -115,6 +117,8 @@ async def openai_semantic_kernel(llm_config: OpenAIModelConfig, _builder: Builde api_key = get_secret_value(llm_config.api_key) or os.getenv("OPENAI_API_KEY") base_url = llm_config.base_url or os.getenv("OPENAI_BASE_URL") - async with AsyncOpenAI(api_key=api_key, base_url=base_url) as async_client: + async with AsyncOpenAI(api_key=api_key, + base_url=base_url, + http_client=_create_http_client(llm_config, use_async=True)) as async_client: llm = OpenAIChatCompletion(ai_model_id=llm_config.model_name, async_client=async_client) yield _patch_llm_based_on_config(llm, llm_config) diff --git a/packages/nvidia_nat_strands/src/nat/plugins/strands/llm.py b/packages/nvidia_nat_strands/src/nat/plugins/strands/llm.py index 7d499a502b..7a506c5574 100644 --- a/packages/nvidia_nat_strands/src/nat/plugins/strands/llm.py +++ b/packages/nvidia_nat_strands/src/nat/plugins/strands/llm.py @@ -59,6 +59,7 @@ from nat.llm.aws_bedrock_llm import AWSBedrockModelConfig from nat.llm.nim_llm import NIMModelConfig from 
nat.llm.openai_llm import OpenAIModelConfig +from nat.llm.utils.http_client import _create_http_client from nat.llm.utils.thinking import BaseThinkingInjector from nat.llm.utils.thinking import FunctionArgumentWrapper from nat.llm.utils.thinking import patch_with_thinking @@ -158,13 +159,26 @@ async def openai_strands(llm_config: OpenAIModelConfig, _builder: Builder) -> As validate_no_responses_api(llm_config, LLMFrameworkEnum.STRANDS) + from openai import AsyncOpenAI + from strands.models.openai import OpenAIModel params = llm_config.model_dump( - exclude={"type", "api_type", "api_key", "base_url", "model_name", "max_retries", "thinking", "request_timeout"}, + exclude={ + "api_key", + "api_type", + "base_url", + "max_retries", + "model_name", + "request_timeout", + "thinking", + "type", + "verify_ssl", + }, by_alias=True, exclude_none=True, - exclude_unset=True) + exclude_unset=True, + ) api_key = get_secret_value(llm_config.api_key) or os.getenv("OPENAI_API_KEY") base_url = llm_config.base_url or os.getenv("OPENAI_BASE_URL") @@ -172,12 +186,14 @@ async def openai_strands(llm_config: OpenAIModelConfig, _builder: Builder) -> As client_args: dict[str, Any] = { "api_key": api_key, "base_url": base_url, + "http_client": _create_http_client(llm_config, use_async=True), } if llm_config.request_timeout is not None: client_args["timeout"] = llm_config.request_timeout + oai_client = AsyncOpenAI(**client_args) client = OpenAIModel( - client_args=client_args, + client=oai_client, model_id=llm_config.model_name, params=params, ) @@ -207,6 +223,8 @@ async def nim_strands(llm_config: NIMModelConfig, _builder: Builder) -> AsyncGen validate_no_responses_api(llm_config, LLMFrameworkEnum.STRANDS) # NIM is OpenAI compatible; use OpenAI model with NIM base_url and api_key + from openai import AsyncOpenAI + from strands.models.openai import OpenAIModel # Create a custom OpenAI model that formats text content as strings for NIM compatibility @@ -266,10 +284,20 @@ def 
format_request_messages(cls, messages, system_prompt=None, *, system_prompt_ return formatted_messages params = llm_config.model_dump( - exclude={"type", "api_type", "api_key", "base_url", "model_name", "max_retries", "thinking"}, + exclude={ + "api_key", + "api_type", + "base_url", + "max_retries", + "model_name", + "thinking", + "type", + "verify_ssl", + }, by_alias=True, exclude_none=True, - exclude_unset=True) + exclude_unset=True, + ) # Determine base_url base_url = llm_config.base_url or "https://integrate.api.nvidia.com/v1" @@ -280,11 +308,13 @@ def format_request_messages(cls, messages, system_prompt=None, *, system_prompt_ if llm_config.base_url and llm_config.base_url.strip() and api_key is None: api_key = "dummy-api-key" + oai_client = AsyncOpenAI( + api_key=api_key, + base_url=base_url, + http_client=_create_http_client(llm_config, use_async=True), + ) client = NIMCompatibleOpenAIModel( - client_args={ - "api_key": api_key, - "base_url": base_url, - }, + client=oai_client, model_id=llm_config.model_name, params=params, ) @@ -326,15 +356,16 @@ async def bedrock_strands(llm_config: AWSBedrockModelConfig, _builder: Builder) params = llm_config.model_dump( exclude={ - "type", "api_type", - "model_name", - "region_name", "base_url", - "max_retries", - "thinking", "context_size", "credentials_profile_name", + "max_retries", + "model_name", + "region_name", + "thinking", + "type", + "verify_ssl", }, by_alias=True, exclude_none=True, diff --git a/packages/nvidia_nat_strands/tests/test_strands_llm.py b/packages/nvidia_nat_strands/tests/test_strands_llm.py index 0313f2bc3d..34d972ee2b 100644 --- a/packages/nvidia_nat_strands/tests/test_strands_llm.py +++ b/packages/nvidia_nat_strands/tests/test_strands_llm.py @@ -156,76 +156,89 @@ def nim_config_wrong_api(self): api_type=APITypeEnum.RESPONSES, ) + @pytest.fixture(name="mock_oai_clients") + def mock_oai_clients_fixture(self): + with patch("openai.AsyncOpenAI") as mock_oai: + mock_oai.return_value = mock_oai + + # 
Patch OpenAIModel constructor to track the call + with patch("strands.models.openai.OpenAIModel.__init__", return_value=None) as mock_oai_model: + yield mock_oai, mock_oai_model + @pytest.mark.asyncio - async def test_nim_strands_basic(self, nim_config, mock_builder): + async def test_nim_strands_basic(self, nim_config, mock_builder, mock_oai_clients): """Test that nim_strands creates a NIMCompatibleOpenAIModel.""" - # Patch OpenAIModel.__init__ to track the call - with patch("strands.models.openai.OpenAIModel.__init__", return_value=None) as mock_init: - # pylint: disable=not-async-context-manager - async with nim_strands(nim_config, mock_builder) as result: - # Verify the result is a NIMCompatibleOpenAIModel instance - assert result is not None + (mock_oai, mock_oai_model) = mock_oai_clients - # Verify OpenAIModel.__init__ was called (the base class) - mock_init.assert_called_once() - call_args = mock_init.call_args + # pylint: disable=not-async-context-manager + async with nim_strands(nim_config, mock_builder) as result: + # Verify the result is a NIMCompatibleOpenAIModel instance + assert result is not None + + mock_oai.assert_called_once() # Ensure OpenAI client init was called + oai_call_args = mock_oai.call_args + oai_call_kwargs = oai_call_args[1] + assert oai_call_kwargs["api_key"] == "test-api-key" + assert oai_call_kwargs["base_url"] == "https://integrate.api.nvidia.com/v1" + + # Verify OpenAIModel.__init__ was called (the base class) + mock_oai_model.assert_called_once() + call_args = mock_oai_model.call_args - # First arg is self, get kwargs - call_kwargs = call_args[1] + # First arg is self, get kwargs + call_kwargs = call_args[1] - # Verify client_args - assert "client_args" in call_kwargs - client_args = call_kwargs["client_args"] - assert client_args["api_key"] == "test-api-key" - assert client_args["base_url"] == "https://integrate.api.nvidia.com/v1" + # Verify client + assert "client" in call_kwargs + call_kwargs["client"] = mock_oai - # Verify 
model_id - assert call_kwargs["model_id"] == "meta/llama-3.1-8b-instruct" + # Verify model_id + assert call_kwargs["model_id"] == "meta/llama-3.1-8b-instruct" @pytest.mark.asyncio - async def test_nim_strands_with_env_var(self, mock_builder): + async def test_nim_strands_with_env_var(self, mock_builder, mock_oai_clients): """Test nim_strands with environment variable for API key.""" + (mock_oai, mock_oai_model) = mock_oai_clients nim_config = NIMModelConfig(model_name="test-model") - with patch("strands.models.openai.OpenAIModel.__init__", return_value=None) as mock_init: - with patch.dict("os.environ", {"NVIDIA_API_KEY": "env-api-key"}): - # pylint: disable=not-async-context-manager - async with nim_strands(nim_config, mock_builder): - mock_init.assert_called_once() - call_kwargs = mock_init.call_args[1] - client_args = call_kwargs["client_args"] - assert client_args["api_key"] == "env-api-key" + with patch.dict("os.environ", {"NVIDIA_API_KEY": "env-api-key"}): + # pylint: disable=not-async-context-manager + async with nim_strands(nim_config, mock_builder): + mock_oai_model.assert_called_once() + mock_oai.assert_called_once() + + call_kwargs = mock_oai.call_args[1] + assert call_kwargs["api_key"] == "env-api-key" @pytest.mark.asyncio - async def test_nim_strands_default_base_url(self, mock_builder): + async def test_nim_strands_default_base_url(self, mock_builder, mock_oai_clients): """Test nim_strands uses default base_url when not provided.""" + (mock_oai, mock_oai_model) = mock_oai_clients nim_config = NIMModelConfig(model_name="test-model", api_key="test-key") - with patch("strands.models.openai.OpenAIModel.__init__", return_value=None) as mock_init: - # pylint: disable=not-async-context-manager - async with nim_strands(nim_config, mock_builder): - mock_init.assert_called_once() - call_kwargs = mock_init.call_args[1] - client_args = call_kwargs["client_args"] - assert client_args["base_url"] == "https://integrate.api.nvidia.com/v1" + async with 
nim_strands(nim_config, mock_builder): # pylint: disable=not-async-context-manager + mock_oai_model.assert_called_once() + mock_oai.assert_called_once() + call_kwargs = mock_oai.call_args[1] + assert call_kwargs["base_url"] == "https://integrate.api.nvidia.com/v1" @pytest.mark.asyncio - async def test_nim_strands_nim_override_dummy_api_key(self, mock_builder): + async def test_nim_strands_nim_override_dummy_api_key(self, mock_builder, mock_oai_clients): """Test nim_strands uses dummy API key when base_url is set but no API key available.""" + (mock_oai, mock_oai_model) = mock_oai_clients nim_config = NIMModelConfig( model_name="test-model", base_url="https://custom-nim.example.com/v1", ) - with patch("strands.models.openai.OpenAIModel.__init__", return_value=None) as mock_init: - with patch.dict(os.environ, {}, clear=True): - # pylint: disable=not-async-context-manager - async with nim_strands(nim_config, mock_builder): - mock_init.assert_called_once() - call_kwargs = mock_init.call_args[1] - client_args = call_kwargs["client_args"] - assert client_args["base_url"] == "https://custom-nim.example.com/v1" - assert client_args["api_key"] == "dummy-api-key" + with patch.dict(os.environ, {}, clear=True): + # pylint: disable=not-async-context-manager + async with nim_strands(nim_config, mock_builder): + mock_oai_model.assert_called_once() + mock_oai.assert_called_once() + call_kwargs = mock_oai.call_args[1] + assert call_kwargs["base_url"] == "https://custom-nim.example.com/v1" + assert call_kwargs["api_key"] == "dummy-api-key" def test_nim_compatible_openai_model_format_request_messages(self): """Test NIMCompatibleOpenAIModel.format_request_messages."""