Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 26 additions & 11 deletions packages/nvidia_nat_adk/src/nat/plugins/adk/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import logging
import os
import typing

from nat.builder.builder import Builder
from nat.builder.framework_enum import LLMFrameworkEnum
Expand Down Expand Up @@ -43,15 +44,16 @@ async def azure_openai_adk(config: AzureOpenAIModelConfig, _builder: Builder):

config_dict = config.model_dump(
exclude={
"type",
"max_retries",
"thinking",
"azure_endpoint",
"api_type",
"azure_deployment",
"model_name",
"azure_endpoint",
"max_retries",
"model",
"api_type",
"request_timeout"
"model_name",
"request_timeout",
"thinking",
"type",
"verify_ssl"
},
by_alias=True,
exclude_none=True,
Expand All @@ -63,6 +65,7 @@ async def azure_openai_adk(config: AzureOpenAIModelConfig, _builder: Builder):
config_dict["timeout"] = config.request_timeout

config_dict["api_version"] = config.api_version
config_dict["ssl_verify"] = config.verify_ssl

yield LiteLlm(f"azure/{config.azure_deployment}", **config_dict)

Expand All @@ -73,12 +76,13 @@ async def litellm_adk(litellm_config: LiteLlmModelConfig, _builder: Builder):

validate_no_responses_api(litellm_config, LLMFrameworkEnum.ADK)

kwargs = {"ssl_verify": litellm_config.verify_ssl}
yield LiteLlm(**litellm_config.model_dump(
exclude={"type", "max_retries", "thinking", "api_type"},
exclude={"api_type", "max_retries", "thinking", "type", "verify_ssl"},
by_alias=True,
exclude_none=True,
exclude_unset=True,
))
), **kwargs)


@register_llm_client(config_type=NIMModelConfig, wrapper_type=LLMFrameworkEnum.ADK)
Expand All @@ -102,14 +106,16 @@ async def nim_adk(config: NIMModelConfig, _builder: Builder):
os.environ["NVIDIA_NIM_API_KEY"] = api_key

config_dict = config.model_dump(
exclude={"type", "max_retries", "thinking", "model_name", "model", "base_url", "api_type"},
exclude={"api_type", "base_url", "max_retries", "model", "model_name", "thinking", "type", "verify_ssl"},
by_alias=True,
exclude_none=True,
exclude_unset=True,
)
if config.base_url:
config_dict["api_base"] = config.base_url

config_dict["ssl_verify"] = config.verify_ssl

yield LiteLlm(f"nvidia_nim/{config.model_name}", **config_dict)


Expand All @@ -126,7 +132,15 @@ async def openai_adk(config: OpenAIModelConfig, _builder: Builder):
validate_no_responses_api(config, LLMFrameworkEnum.ADK)

config_dict = config.model_dump(
exclude={"type", "max_retries", "thinking", "model_name", "model", "base_url", "api_type", "request_timeout"},
exclude={"api_type",
"base_url",
"max_retries",
"model",
"model_name",
"request_timeout",
"thinking",
"type",
"verify_ssl"},
by_alias=True,
exclude_none=True,
exclude_unset=True,
Expand All @@ -138,6 +152,7 @@ async def openai_adk(config: OpenAIModelConfig, _builder: Builder):
config_dict["api_base"] = base_url
if config.request_timeout is not None:
config_dict["timeout"] = config.request_timeout
config_dict["ssl_verify"] = config.verify_ssl

yield LiteLlm(config.model_name, **config_dict)

Expand Down
41 changes: 35 additions & 6 deletions packages/nvidia_nat_agno/src/nat/plugins/agno/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.

import os
from typing import TypeVar
import typing

from nat.builder.builder import Builder
from nat.builder.framework_enum import LLMFrameworkEnum
Expand All @@ -27,14 +27,15 @@
from nat.llm.litellm_llm import LiteLlmModelConfig
from nat.llm.nim_llm import NIMModelConfig
from nat.llm.openai_llm import OpenAIModelConfig
from nat.llm.utils.http_client import _create_http_client
from nat.llm.utils.thinking import BaseThinkingInjector
from nat.llm.utils.thinking import FunctionArgumentWrapper
from nat.llm.utils.thinking import patch_with_thinking
from nat.utils.exception_handlers.automatic_retries import patch_with_retry
from nat.utils.responses_api import validate_no_responses_api
from nat.utils.type_utils import override

ModelType = TypeVar("ModelType")
ModelType = typing.TypeVar("ModelType")


def _patch_llm_based_on_config(client: ModelType, llm_config: LLMBaseConfig) -> ModelType:
Expand Down Expand Up @@ -88,14 +89,24 @@ async def nim_agno(llm_config: NIMModelConfig, _builder: Builder):

config_obj = {
**llm_config.model_dump(
exclude={"type", "model_name", "thinking", "api_type"},
exclude={
"api_type",
"model_name",
"thinking",
"type",
"verify_ssl",
},
by_alias=True,
exclude_none=True,
exclude_unset=True,
),
"http_client":
_create_http_client(llm_config),
"id":
llm_config.model_name
}

client = Nvidia(**config_obj, id=llm_config.model_name)
client = Nvidia(**config_obj)

yield _patch_llm_based_on_config(client, llm_config)

Expand All @@ -108,11 +119,22 @@ async def openai_agno(llm_config: OpenAIModelConfig, _builder: Builder):

config_obj = {
**llm_config.model_dump(
exclude={"type", "model_name", "thinking", "api_type", "api_key", "base_url", "request_timeout"},
exclude={
"api_key",
"api_type",
"base_url",
"model_name",
"request_timeout",
"thinking",
"type",
"verify_ssl",
},
by_alias=True,
exclude_none=True,
exclude_unset=True,
),
"http_client":
_create_http_client(llm_config),
}

if (api_key := get_secret_value(llm_config.api_key) or os.getenv("OPENAI_API_KEY")):
Expand All @@ -139,11 +161,18 @@ async def litellm_agno(llm_config: LiteLlmModelConfig, _builder: Builder):

client = LiteLLM(
**llm_config.model_dump(
exclude={"type", "thinking", "model_name", "api_type"},
exclude={
"api_type",
"model_name",
"thinking",
"type",
"verify_ssl",
},
by_alias=True,
exclude_none=True,
exclude_unset=True,
),
http_client=_create_http_client(llm_config),
id=llm_config.model_name,
)

Expand Down
29 changes: 23 additions & 6 deletions packages/nvidia_nat_autogen/src/nat/plugins/autogen/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from nat.llm.litellm_llm import LiteLlmModelConfig
from nat.llm.nim_llm import NIMModelConfig
from nat.llm.openai_llm import OpenAIModelConfig
from nat.llm.utils.http_client import _create_http_client
from nat.llm.utils.thinking import BaseThinkingInjector
from nat.llm.utils.thinking import FunctionArgumentWrapper
from nat.llm.utils.thinking import patch_with_thinking
Expand Down Expand Up @@ -147,10 +148,20 @@ async def openai_autogen(llm_config: OpenAIModelConfig, _builder: Builder) -> As
# Extract AutoGen-compatible configuration
config_obj = {
**llm_config.model_dump(
exclude={"type", "model_name", "thinking", "api_key", "base_url", "request_timeout"},
exclude={
"api_key",
"base_url",
"model_name",
"request_timeout",
"thinking",
"type",
"verify_ssl",
},
by_alias=True,
exclude_none=True,
),
"http_client":
_create_http_client(llm_config)
}

if (api_key := get_secret_value(llm_config.api_key) or os.getenv("OPENAI_API_KEY")):
Expand Down Expand Up @@ -201,12 +212,14 @@ async def azure_openai_autogen(llm_config: AzureOpenAIModelConfig,
config_obj = {
"api_key":
llm_config.api_key,
"base_url":
f"{llm_config.azure_endpoint}/openai/deployments/{llm_config.azure_deployment}",
"api_version":
llm_config.api_version,
"base_url":
f"{llm_config.azure_endpoint}/openai/deployments/{llm_config.azure_deployment}",
"http_client":
_create_http_client(llm_config),
**llm_config.model_dump(
exclude={"type", "azure_deployment", "thinking", "azure_endpoint", "api_version", "request_timeout"},
exclude={"api_version", "azure_deployment", "azure_endpoint", "request_timeout", "thinking", "type"},
by_alias=True,
exclude_none=True,
),
Expand Down Expand Up @@ -326,8 +339,10 @@ async def nim_autogen(llm_config: NIMModelConfig, _builder: Builder) -> AsyncGen

# Extract NIM configuration for OpenAI-compatible client
config_obj = {
"http_client":
_create_http_client(llm_config),
**llm_config.model_dump(
exclude={"type", "model_name", "thinking"},
exclude={"model_name", "thinking", "type"},
by_alias=True,
exclude_none=True,
exclude_unset=True,
Expand Down Expand Up @@ -386,8 +401,10 @@ async def litellm_autogen(llm_config: LiteLlmModelConfig, _builder: Builder) ->

# Extract LiteLLM configuration for OpenAI-compatible client
config_obj = {
"http_client":
_create_http_client(llm_config),
**llm_config.model_dump(
exclude={"type", "model_name", "thinking"},
exclude={"model_name", "thinking", "type"},
by_alias=True,
exclude_none=True,
exclude_unset=True,
Expand Down
9 changes: 9 additions & 0 deletions packages/nvidia_nat_core/src/nat/data_models/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import typing
from enum import StrEnum

from pydantic import BaseModel
from pydantic import Field

from .common import BaseModelRegistryTag
Expand All @@ -39,3 +40,11 @@ class LLMBaseConfig(TypedBaseModel, BaseModelRegistryTag):


LLMBaseConfigT = typing.TypeVar("LLMBaseConfigT", bound=LLMBaseConfig)

class SSLVerificationMixin(BaseModel):
    """Config mixin that adds a ``verify_ssl`` toggle to a model.

    Attributes:
        verify_ssl: Whether SSL certificates are verified when making API
            calls to the LLM provider. Verification stays enabled unless
            the configuration explicitly sets this field to ``False``.
    """

    # Secure by default: certificate verification must be opted out of explicitly.
    verify_ssl: bool = Field(
        True,
        description="Whether to verify SSL certificates when making API calls to the LLM provider. Defaults to True.",
    )
2 changes: 2 additions & 0 deletions packages/nvidia_nat_core/src/nat/llm/litellm_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from nat.cli.register_workflow import register_llm_provider
from nat.data_models.common import OptionalSecretStr
from nat.data_models.llm import LLMBaseConfig
from nat.data_models.llm import SSLVerificationMixin
from nat.data_models.optimizable import OptimizableField
from nat.data_models.optimizable import OptimizableMixin
from nat.data_models.optimizable import SearchSpace
Expand All @@ -36,6 +37,7 @@ class LiteLlmModelConfig(
OptimizableMixin,
RetryMixin,
ThinkingMixin,
SSLVerificationMixin,
name="litellm",
):
"""A LiteLlm provider to be used with an LLM client."""
Expand Down
3 changes: 2 additions & 1 deletion packages/nvidia_nat_core/src/nat/llm/nim_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,15 @@
from nat.cli.register_workflow import register_llm_provider
from nat.data_models.common import OptionalSecretStr
from nat.data_models.llm import LLMBaseConfig
from nat.data_models.llm import SSLVerificationMixin
from nat.data_models.optimizable import OptimizableField
from nat.data_models.optimizable import OptimizableMixin
from nat.data_models.optimizable import SearchSpace
from nat.data_models.retry_mixin import RetryMixin
from nat.data_models.thinking_mixin import ThinkingMixin


class NIMModelConfig(LLMBaseConfig, RetryMixin, OptimizableMixin, ThinkingMixin, name="nim"):
class NIMModelConfig(LLMBaseConfig, RetryMixin, OptimizableMixin, ThinkingMixin, SSLVerificationMixin, name="nim"):
"""An NVIDIA Inference Microservice (NIM) llm provider to be used with an LLM client."""

model_config = ConfigDict(protected_namespaces=(), extra="allow")
Expand Down
3 changes: 2 additions & 1 deletion packages/nvidia_nat_core/src/nat/llm/openai_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@
from nat.cli.register_workflow import register_llm_provider
from nat.data_models.common import OptionalSecretStr
from nat.data_models.llm import LLMBaseConfig
from nat.data_models.llm import SSLVerificationMixin
from nat.data_models.optimizable import OptimizableField
from nat.data_models.optimizable import OptimizableMixin
from nat.data_models.optimizable import SearchSpace
from nat.data_models.retry_mixin import RetryMixin
from nat.data_models.thinking_mixin import ThinkingMixin


class OpenAIModelConfig(LLMBaseConfig, RetryMixin, OptimizableMixin, ThinkingMixin, name="openai"):
class OpenAIModelConfig(LLMBaseConfig, RetryMixin, OptimizableMixin, ThinkingMixin, SSLVerificationMixin, name="openai"):
"""An OpenAI LLM provider to be used with an LLM client."""

model_config = ConfigDict(protected_namespaces=(), extra="allow")
Expand Down
10 changes: 4 additions & 6 deletions packages/nvidia_nat_core/src/nat/llm/utils/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,20 @@
import httpx

from nat.llm.utils.constants import LLMHeaderPrefix
from nat.llm.utils.http_client import _create_http_client

logger = logging.getLogger(__name__)


def create_metadata_injection_client(timeout: float = 600.0) -> "httpx.AsyncClient":
def _create_metadata_injection_client(llm_config: "LLMBaseConfig") -> "httpx.AsyncClient":
"""
Httpx event hook that injects custom metadata as HTTP headers.

This client injects custom payload fields as X-Payload-* HTTP headers,
enabling end-to-end traceability in LLM server logs.

Args:
timeout: HTTP request timeout in seconds
llm_config: LLM configuration object

Returns:
An httpx.AsyncClient configured with metadata header injection
Expand All @@ -63,7 +64,4 @@ async def on_request(request: httpx.Request) -> None:
except Exception as e:
logger.debug("Could not inject custom metadata headers, request will proceed without them: %s", e)

return httpx.AsyncClient(
event_hooks={"request": [on_request]},
timeout=httpx.Timeout(timeout),
)
return _create_http_client(llm_config=llm_config, use_async=True, event_hooks={"request": [on_request]})
Loading