From b1bfe395941ef5ae11779db7319e28e3bbb10bdf Mon Sep 17 00:00:00 2001 From: prashanthkvs Date: Mon, 4 May 2026 12:36:45 -0700 Subject: [PATCH 1/2] feat: Add AWS Bedrock Mantle client with SigV4 authentication Add AwsOpenAI and AsyncAwsOpenAI clients that support AWS SigV4 request signing for Bedrock Mantle APIs, alongside standard API key auth. Key changes: - Add src/openai/lib/aws.py with sync and async client classes that sign requests using botocore's SigV4Auth - Support custom credential providers (sync and async) as well as automatic credential resolution via the default botocore chain - Export AwsOpenAI and AsyncAwsOpenAI from the top-level package - Add examples for basic usage and STS assume-role credential refresh - Add comprehensive test suite covering SigV4 signing, credential resolution, API key fallback, and copy/with_options behavior - Add botocore as a dev dependency in pyproject.toml - Fix all pyright and mypy lint errors for botocore type stubs --- README.md | 35 ++ examples/aws_client.py | 80 +++++ examples/aws_credential_provider.py | 144 ++++++++ pyproject.toml | 2 + src/openai/__init__.py | 4 + src/openai/lib/aws.py | 345 +++++++++++++++++++ tests/lib/test_aws.py | 502 ++++++++++++++++++++++++++++ 7 files changed, 1112 insertions(+) create mode 100644 examples/aws_client.py create mode 100644 examples/aws_credential_provider.py create mode 100644 src/openai/lib/aws.py create mode 100644 tests/lib/test_aws.py diff --git a/README.md b/README.md index 9450c0bc51..409106fe4d 100644 --- a/README.md +++ b/README.md @@ -933,6 +933,41 @@ In addition to the options provided in the base `OpenAI` client, the following o An example of using the client with Microsoft Entra ID (formerly known as Azure Active Directory) can be found [here](https://github.com/openai/openai-python/blob/main/examples/azure_ad.py). +## AWS Bedrock Mantle + +To use this library with [AWS Bedrock Mantle](https://docs.aws.amazon.com/bedrock/), use the `AwsOpenAI` +class instead of the `OpenAI` class. + +> [!IMPORTANT] +> This requires `botocore` to be installed for SigV4 request signing. Install it with: `pip install 'openai[aws]'` + +```py +from openai import AwsOpenAI + +# uses the default botocore credential chain (env vars, ~/.aws/credentials, IAM role, etc.) +client = AwsOpenAI( + region="us-west-2", +) + +completion = client.chat.completions.create( + model="openai.gpt-oss-120b", + messages=[ + { + "role": "user", + "content": "How do I output all files in a directory using Python?", + }, + ], +) +print(completion.choices[0].message.content) +``` + +In addition to the options provided in the base `OpenAI` client, the following options are provided: + +- `region` (or the `AWS_REGION` / `AWS_DEFAULT_REGION` environment variable) +- `credential_provider` - a callable that returns credentials with `access_key`, `secret_key`, and optional `token` attributes + +An example of using the client with a custom credential provider and STS assume-role refresh can be found [here](https://github.com/openai/openai-python/blob/main/examples/aws_credential_provider.py). + ## Versioning This package generally follows [SemVer](https://semver.org/spec/v2.0.0.html) conventions, though certain backwards-incompatible changes may be released as minor versions: diff --git a/examples/aws_client.py b/examples/aws_client.py new file mode 100644 index 0000000000..8fdd572919 --- /dev/null +++ b/examples/aws_client.py @@ -0,0 +1,80 @@ +"""Example: Using AwsOpenAI (sync) and AsyncAwsOpenAI (async) with SigV4 signing. + +Requires: + - botocore installed (pip install botocore) + - AWS credentials configured (env vars, ~/.aws/credentials, IAM role, etc.) + - AWS_REGION or AWS_DEFAULT_REGION set (or pass region= explicitly) + +Run: + export AWS_REGION=us-west-2 + PYTHONPATH=src python3 examples/bedrock_mantle.py +""" + +import asyncio + +from openai.lib.aws import AwsOpenAI, AsyncAwsOpenAI + +# --- Synchronous usage --- + +client = AwsOpenAI(region="us-west-2") + +response = client.chat.completions.create( + model="openai.gpt-oss-120b", + messages=[{"role": "user", "content": "Hello, how are you?"}], +) + +print("Sync:", response.choices[0].message.content) + + +# --- Asynchronous usage --- + + +async def main() -> None: + async_client = AsyncAwsOpenAI(region="us-west-2") + + response = await async_client.chat.completions.create( + model="openai.gpt-oss-120b", + messages=[{"role": "user", "content": "Hello from async!"}], + ) + + print("Async:", response.choices[0].message.content) + + +asyncio.run(main()) + + +# --- Streaming usage (sync) --- + +print("\nStreaming: ", end="") +stream = client.chat.completions.create( + model="openai.gpt-oss-120b", + messages=[{"role": "user", "content": "Count from 1 to 5."}], + stream=True, +) +for chunk in stream: + delta = chunk.choices[0].delta.content + if delta: + print(delta, end="", flush=True) +print() + + +# --- Streaming usage (async) --- + + +async def stream_async() -> None: + async_client = AsyncAwsOpenAI(region="us-west-2") + + print("Async streaming: ", end="") + stream = await async_client.chat.completions.create( + model="openai.gpt-oss-120b", + messages=[{"role": "user", "content": "Count from 1 to 5."}], + stream=True, + ) + async for chunk in stream: + delta = chunk.choices[0].delta.content + if delta: + print(delta, end="", flush=True) + print() + + +asyncio.run(stream_async()) diff --git a/examples/aws_credential_provider.py b/examples/aws_credential_provider.py new file mode 100644 index 0000000000..d8b4b3be9e --- /dev/null +++ b/examples/aws_credential_provider.py @@ -0,0 +1,144 @@ +"""Example: Using AwsOpenAI with a custom credential provider and auto-refresh. + +This shows how to: + 1. Use a custom credential provider that returns fresh credentials on each call + 2. Use botocore's RefreshableCredentials for automatic STS assume-role refresh + 3. Use an async credential provider with AsyncAwsOpenAI + +Requires: + - botocore installed (pip install botocore) + - boto3 installed (pip install boto3) — for the STS assume-role example + - AWS credentials configured for the initial session + - AWS_REGION or AWS_DEFAULT_REGION set (or pass region= explicitly) + +Run: + export AWS_REGION=us-west-2 + PYTHONPATH=src python3 examples/bedrock_mantle_credential_provider.py +""" + +from __future__ import annotations + +import asyncio +from typing import Any, Callable +from dataclasses import dataclass + +from openai.lib.aws import AwsOpenAI, AsyncAwsOpenAI + +# --------------------------------------------------------------------------- +# 1. Simple custom credential provider +# --------------------------------------------------------------------------- + + +@dataclass +class MyCredentials: + """Minimal object satisfying the Credentials protocol.""" + + access_key: str + secret_key: str + token: str | None = None + + +def my_credential_provider() -> MyCredentials: + """Return credentials from your own secret store, vault, etc. + + This callable is invoked before every request, so returning fresh + credentials here is all you need for auto-refresh. + """ + # Replace with your actual credential fetching logic + return MyCredentials( + access_key="AKIA...", + secret_key="wJalr...", + token="FwoGZX...", # optional session token + ) + + +client = AwsOpenAI( + region="us-west-2", + credential_provider=my_credential_provider, +) + +response = client.chat.completions.create( + model="openai.gpt-oss-120b", + messages=[{"role": "user", "content": "Hello from custom credentials!"}], +) +print("Custom provider:", response.choices[0].message.content) + + +# --------------------------------------------------------------------------- +# 2. Auto-refreshing STS assume-role credentials via botocore +# --------------------------------------------------------------------------- + + +def make_sts_credential_provider(role_arn: str, session_name: str = "bedrock-mantle") -> Callable[[], Any]: + """Create a credential provider that assumes an IAM role and auto-refreshes. + + botocore's RefreshableCredentials handles expiry checks and refresh + transparently — accessing .access_key / .secret_key / .token on the + returned object triggers a refresh if the credentials are expired. + """ + import botocore.session # type: ignore[import-untyped, import-not-found] + import botocore.credentials # type: ignore[import-untyped, import-not-found] + + session: Any = botocore.session.get_session() # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] + sts: Any = session.create_client("sts") # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] + + def fetch_credentials() -> dict[str, Any]: + resp: Any = sts.assume_role(RoleArn=role_arn, RoleSessionName=session_name)["Credentials"] # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] + return { + "access_key": resp["AccessKeyId"], + "secret_key": resp["SecretAccessKey"], + "token": resp["SessionToken"], + "expiry_time": resp["Expiration"].isoformat(), # pyright: ignore[reportUnknownMemberType] + } + + refreshable: Any = botocore.credentials.RefreshableCredentials.create_from_metadata( # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] + metadata=fetch_credentials(), + refresh_using=fetch_credentials, + method="sts-assume-role", + ) + + # Return a provider that gives back the refreshable object. + # Accessing its attributes auto-refreshes when expired. + def provider() -> Any: + return refreshable # pyright: ignore[reportUnknownVariableType] + + return provider + + +# Uncomment to use: +# sts_client = AwsOpenAI( +# region="us-west-2", +# credential_provider=make_sts_credential_provider("arn:aws:iam::123456789012:role/MyRole"), +# ) + + +# --------------------------------------------------------------------------- +# 3. Async credential provider +# --------------------------------------------------------------------------- + + +async def async_credential_provider() -> MyCredentials: + """An async provider — useful when credentials come from an async API.""" + # Simulate async credential fetch (e.g., from an async HTTP vault client) + await asyncio.sleep(0) + return MyCredentials( + access_key="AKIA...", + secret_key="wJalr...", + token="FwoGZX...", + ) + + +async def main() -> None: + async_client = AsyncAwsOpenAI( + region="us-west-2", + credential_provider=async_credential_provider, + ) + + response = await async_client.chat.completions.create( + model="openai.gpt-oss-120b", + messages=[{"role": "user", "content": "Hello from async credentials!"}], + ) + print("Async provider:", response.choices[0].message.content) + + +asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index b2f4dd11cb..63d0751202 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ aiohttp = ["aiohttp", "httpx_aiohttp>=0.1.9"] realtime = ["websockets >= 13, < 16"] datalib = ["numpy >= 1", "pandas >= 1.2.3", "pandas-stubs >= 1.1.0.11"] voice_helpers = ["sounddevice>=0.5.1", "numpy>=2.0.2"] +aws = ["botocore >= 1.29.0"] [tool.rye] managed = true @@ -68,6 +69,7 @@ dev-dependencies = [ "rich>=13.7.1", "inline-snapshot>=0.28.0", "azure-identity >=1.14.1", + "botocore >=1.29.0", "types-tqdm > 4", "types-pyaudio > 0", "trio >=0.22.2", diff --git a/src/openai/__init__.py b/src/openai/__init__.py index fc9675a8b5..3179f97fe3 100644 --- a/src/openai/__init__.py +++ b/src/openai/__init__.py @@ -97,6 +97,10 @@ from ._utils._resources_proxy import resources as resources from .lib import azure as _azure, pydantic_function_tool as pydantic_function_tool +from .lib.aws import ( + AwsOpenAI as AwsOpenAI, + AsyncAwsOpenAI as AsyncAwsOpenAI, +) from .version import VERSION as VERSION from .lib.azure import AzureOpenAI as AzureOpenAI, AsyncAzureOpenAI as AsyncAzureOpenAI from .lib._old_api import * diff --git a/src/openai/lib/aws.py b/src/openai/lib/aws.py new file mode 100644 index 0000000000..6d88f9df1f --- /dev/null +++ b/src/openai/lib/aws.py @@ -0,0 +1,345 @@ +from __future__ import annotations + +import os +import inspect +from typing import Any, Union, Mapping, Callable, Awaitable +from typing_extensions import Self, override + +import httpx + +from .._types import NOT_GIVEN, Timeout, NotGiven +from .._client import OpenAI, AsyncOpenAI +from .._exceptions import OpenAIError +from .._base_client import DEFAULT_MAX_RETRIES + +# Sentinel API key used when SigV4 mode is active, so the base OpenAI +# constructor (which requires a non-None api_key) is satisfied. +API_KEY_SENTINEL = "" + +# A credential provider is a callable that returns a botocore-compatible +# credentials object (with access_key, secret_key, token attributes). +CredentialProvider = Callable[[], Any] +AsyncCredentialProvider = Callable[[], "Union[Any, Awaitable[Any]]"] + + +def _ensure_botocore() -> None: + """Raise OpenAIError if botocore is not installed.""" + try: + import botocore # type: ignore[import-untyped] # noqa: F401 # pyright: ignore[reportMissingTypeStubs, reportUnusedImport] + except ImportError as err: + raise OpenAIError( + "botocore must be installed for SigV4 authentication. Install it with: pip install 'openai[aws]'" + ) from err + + +def _get_default_credentials() -> Any: + """Resolve the botocore credentials object via the default credential chain. + + Returns the unfrozen botocore credentials object so that RefreshableCredentials + (e.g. from IAM roles, EC2 instance profiles, ECS task roles) can auto-refresh + on subsequent access. Raises OpenAIError if botocore is not installed or + credentials cannot be resolved. + """ + _ensure_botocore() + import botocore.session # type: ignore[import-untyped] # pyright: ignore[reportMissingTypeStubs] + + session = botocore.session.get_session() # pyright: ignore[reportUnknownMemberType] + creds = session.get_credentials() # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] + if creds is None: + raise OpenAIError("Could not resolve AWS credentials from the default botocore credential chain.") + # Validate that credentials are usable by checking a frozen snapshot, + # but return the unfrozen object so RefreshableCredentials can auto-refresh. + frozen = creds.get_frozen_credentials() # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] + if not frozen.access_key or not frozen.secret_key: # pyright: ignore[reportUnknownMemberType] + raise OpenAIError("Could not resolve AWS credentials from the default botocore credential chain.") + return creds # pyright: ignore[reportUnknownVariableType] + + +def _sign_httpx_request( + request: httpx.Request, + credentials: Any, + region: str, + service: str = "bedrock", +) -> None: + """Sign an httpx.Request in-place using botocore's SigV4Auth.""" + import botocore.auth # type: ignore[import-untyped] # pyright: ignore[reportMissingTypeStubs] + import botocore.awsrequest # type: ignore[import-untyped] # pyright: ignore[reportMissingTypeStubs] + + # Exclude httpx transport-level headers that cause SigV4 signature mismatch + _HEADERS_TO_EXCLUDE = {"connection", "accept-encoding"} + clean_headers = {k: v for k, v in request.headers.items() if k.lower() not in _HEADERS_TO_EXCLUDE} + + # Convert httpx.Request → botocore.awsrequest.AWSRequest + aws_request = botocore.awsrequest.AWSRequest( + method=request.method, + url=str(request.url), + headers=clean_headers, + data=request.content, + ) + + signer = botocore.auth.SigV4Auth(credentials, service, region) + signer.add_auth(aws_request) # pyright: ignore[reportUnknownMemberType] + + # Copy signed headers back into the httpx request + for key, value in aws_request.headers.items(): + request.headers[key] = value + + +# --------------------------------------------------------------------------- +# Shared init / credential resolution logic +# --------------------------------------------------------------------------- + + +def _resolve_bedrock_mantle_config( + *, + api_key: str | None, + credential_provider: Any | None, + region: str | None, + base_url: str | None, +) -> tuple[bool, Any | None, str, str, Any | None]: + """Shared constructor logic for both sync and async clients. + + Returns (use_sigv4, credential_provider, region, base_url, botocore_credentials). + """ + # Normalize: treat the sentinel as "no api_key provided" + if api_key == API_KEY_SENTINEL: + api_key = None + + if api_key is not None and credential_provider is not None: + raise OpenAIError("api_key and credential_provider are mutually exclusive") + + # Determine auth mode + if api_key is not None: + use_sigv4 = False + credential_provider = None + else: + use_sigv4 = True + + # Resolve region (needed for SigV4 and base_url fallback) + resolved_region = region or os.environ.get("AWS_REGION") or os.environ.get("AWS_DEFAULT_REGION") + if use_sigv4 and not resolved_region: + raise ValueError("Must provide region or set AWS_REGION / AWS_DEFAULT_REGION environment variable") + resolved_region = resolved_region or "" + + # Resolve base_url — fall back to region-derived endpoint + if base_url is None: + if not resolved_region: + raise ValueError("Must provide base_url, or set region / AWS_REGION / AWS_DEFAULT_REGION") + base_url = f"https://bedrock-mantle.{resolved_region}.api.aws/v1" + + # Resolve botocore credentials if needed + botocore_credentials: Any = None + if use_sigv4 and credential_provider is None: + botocore_credentials = _get_default_credentials() + elif use_sigv4: + # SigV4 signing always requires botocore (for botocore.auth.SigV4Auth) + _ensure_botocore() + + return use_sigv4, credential_provider, resolved_region, base_url, botocore_credentials + + +def _resolve_credentials_sync( + credential_provider: CredentialProvider | None, + botocore_credentials: Any | None, +) -> Any: + """Resolve credentials for a sync request. Raises OpenAIError on failure.""" + try: + if credential_provider is not None: + return credential_provider() + # Call get_frozen_credentials() at signing time so that + # RefreshableCredentials can auto-refresh expired tokens. + if botocore_credentials is not None and hasattr(botocore_credentials, "get_frozen_credentials"): + return botocore_credentials.get_frozen_credentials() # pyright: ignore[reportUnknownMemberType] + return botocore_credentials + except OpenAIError: + raise + except Exception as e: + raise OpenAIError(f"Failed to refresh AWS credentials: {e}") from e + + +async def _resolve_credentials_async( + credential_provider: AsyncCredentialProvider | None, + botocore_credentials: Any | None, +) -> Any: + """Resolve credentials for an async request. Raises OpenAIError on failure.""" + try: + if credential_provider is not None: + credentials = credential_provider() + if inspect.isawaitable(credentials): + credentials = await credentials + return credentials + # Call get_frozen_credentials() at signing time so that + # RefreshableCredentials can auto-refresh expired tokens. + if botocore_credentials is not None and hasattr(botocore_credentials, "get_frozen_credentials"): + return botocore_credentials.get_frozen_credentials() # pyright: ignore[reportUnknownMemberType] + return botocore_credentials + except OpenAIError: + raise + except Exception as e: + raise OpenAIError(f"Failed to refresh AWS credentials: {e}") from e + + +# --------------------------------------------------------------------------- +# Client classes +# --------------------------------------------------------------------------- + + +class AwsOpenAI(OpenAI): + """OpenAI-compatible client for AWS Bedrock Mantle APIs. + + Supports SigV4 request signing and API key authentication. + """ + + _region: str + _credential_provider: CredentialProvider | None + _use_sigv4: bool + _botocore_credentials: Any + + def __init__( + self, + *, + base_url: str | None = None, + region: str | None = None, + credential_provider: CredentialProvider | None = None, + api_key: str | None = None, + timeout: float | Timeout | None | NotGiven = NOT_GIVEN, + max_retries: int = DEFAULT_MAX_RETRIES, + default_headers: Mapping[str, str] | None = None, + default_query: Mapping[str, object] | None = None, + http_client: httpx.Client | None = None, + _strict_response_validation: bool = False, + **kwargs: Any, + ) -> None: + ( + self._use_sigv4, + self._credential_provider, + self._region, + base_url, + self._botocore_credentials, + ) = _resolve_bedrock_mantle_config( + api_key=api_key, + credential_provider=credential_provider, + region=region, + base_url=base_url, + ) + + super().__init__( + api_key=api_key or API_KEY_SENTINEL, + base_url=base_url, + timeout=timeout, + max_retries=max_retries, + default_headers=default_headers, + default_query=default_query, + http_client=http_client, + _strict_response_validation=_strict_response_validation, + **kwargs, + ) + + @override + def _prepare_request(self, request: httpx.Request) -> None: + if not self._use_sigv4: + return + credentials = _resolve_credentials_sync(self._credential_provider, self._botocore_credentials) + _sign_httpx_request(request, credentials, self._region) + + @override + def copy( + self, + *, + region: str | None = None, + credential_provider: CredentialProvider | None = None, + _extra_kwargs: Mapping[str, Any] = {}, + **kwargs: Any, + ) -> Self: + """Create a new client instance re-using the same options given to the current client with optional overriding.""" + return super().copy( + **kwargs, + _extra_kwargs={ + "region": region or self._region, + "credential_provider": credential_provider or self._credential_provider, + **_extra_kwargs, + }, + ) + + with_options = copy + + +class AsyncAwsOpenAI(AsyncOpenAI): + """Async OpenAI-compatible client for AWS Bedrock Mantle APIs. + + Supports SigV4 request signing and API key authentication. + """ + + _region: str + _credential_provider: AsyncCredentialProvider | None + _use_sigv4: bool + _botocore_credentials: Any + + def __init__( + self, + *, + base_url: str | None = None, + region: str | None = None, + credential_provider: AsyncCredentialProvider | None = None, + api_key: str | None = None, + timeout: float | Timeout | None | NotGiven = NOT_GIVEN, + max_retries: int = DEFAULT_MAX_RETRIES, + default_headers: Mapping[str, str] | None = None, + default_query: Mapping[str, object] | None = None, + http_client: httpx.AsyncClient | None = None, + _strict_response_validation: bool = False, + **kwargs: Any, + ) -> None: + ( + self._use_sigv4, + self._credential_provider, + self._region, + base_url, + self._botocore_credentials, + ) = _resolve_bedrock_mantle_config( + api_key=api_key, + credential_provider=credential_provider, + region=region, + base_url=base_url, + ) + + super().__init__( + api_key=api_key or API_KEY_SENTINEL, + base_url=base_url, + timeout=timeout, + max_retries=max_retries, + default_headers=default_headers, + default_query=default_query, + http_client=http_client, + _strict_response_validation=_strict_response_validation, + **kwargs, + ) + + @override + async def _prepare_request(self, request: httpx.Request) -> None: + if not self._use_sigv4: + return + credentials = await _resolve_credentials_async(self._credential_provider, self._botocore_credentials) + _sign_httpx_request(request, credentials, self._region) + + @override + def copy( + self, + *, + region: str | None = None, + credential_provider: AsyncCredentialProvider | None = None, + _extra_kwargs: Mapping[str, Any] = {}, + **kwargs: Any, + ) -> Self: + """Create a new client instance re-using the same options given to the current client with optional overriding.""" + return super().copy( + **kwargs, + _extra_kwargs={ + "region": region or self._region, + "credential_provider": credential_provider or self._credential_provider, + **_extra_kwargs, + }, + ) + + with_options = copy + diff --git a/tests/lib/test_aws.py b/tests/lib/test_aws.py new file mode 100644 index 0000000000..5153c48618 --- /dev/null +++ b/tests/lib/test_aws.py @@ -0,0 +1,502 @@ +"""Unit tests for AwsOpenAI and AsyncAwsOpenAI.""" + +from __future__ import annotations + +from typing import Any, Union +from unittest.mock import MagicMock, patch + +import httpx +import pytest + +from openai import OpenAI, AsyncOpenAI +from openai.lib.aws import ( + AwsOpenAI, + AsyncAwsOpenAI, +) +from openai._exceptions import OpenAIError + +Client = Union[AwsOpenAI, AsyncAwsOpenAI] + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_frozen_creds( + access_key: str = "AKID", + secret_key: str = "secret", + token: str | None = "tok", +) -> MagicMock: + """Return a mock botocore frozen-credentials object.""" + creds = MagicMock() + creds.access_key = access_key + creds.secret_key = secret_key + creds.token = token + return creds + + +def _make_unfrozen_creds( + access_key: str = "AKID", + secret_key: str = "secret", + token: str | None = "tok", +) -> MagicMock: + """Return a mock unfrozen botocore credentials object with get_frozen_credentials().""" + frozen = _make_frozen_creds(access_key, secret_key, token) + creds = MagicMock() + creds.get_frozen_credentials.return_value = frozen + creds.access_key = access_key + creds.secret_key = secret_key + creds.token = token + return creds + + +def _patch_default_creds(unfrozen: MagicMock | None = None) -> Any: + """Patch _get_default_credentials to return an unfrozen credentials mock.""" + if unfrozen is None: + unfrozen = _make_unfrozen_creds() + return patch("openai.lib.aws._get_default_credentials", return_value=unfrozen) + + +def _patch_ensure_botocore() -> Any: + """Patch _ensure_botocore so it never raises.""" + return patch("openai.lib.aws._ensure_botocore") + + +# --------------------------------------------------------------------------- +# Constructor: region-derived base_url (Req 1.1, 2.1) +# --------------------------------------------------------------------------- + + +class TestConstructorWithRegionDerivedUrl: + def test_sync_base_url_from_region(self) -> None: + with _patch_default_creds(): + client = AwsOpenAI(region="us-west-2") + assert str(client.base_url) == "https://bedrock-mantle.us-west-2.api.aws/v1/" + + def test_async_base_url_from_region(self) -> None: + with _patch_default_creds(): + client = AsyncAwsOpenAI(region="us-east-1") + assert str(client.base_url) == "https://bedrock-mantle.us-east-1.api.aws/v1/" + + +# --------------------------------------------------------------------------- +# Constructor: explicit base_url (Req 1.2, 2.2) +# --------------------------------------------------------------------------- + + +class TestConstructorWithBaseUrl: + def test_sync_uses_provided_base_url(self) -> None: + with _patch_default_creds(): + client = AwsOpenAI( + base_url="https://custom.example.com/v1", + region="us-west-2", + ) + assert str(client.base_url) == "https://custom.example.com/v1/" + + def test_async_uses_provided_base_url(self) -> None: + with _patch_default_creds(): + client = AsyncAwsOpenAI( + base_url="https://custom.example.com/v1", + region="us-west-2", + ) + assert str(client.base_url) == "https://custom.example.com/v1/" + + +# --------------------------------------------------------------------------- +# Constructor: api_key mode (Req 5.1, 5.2) +# --------------------------------------------------------------------------- + + +class TestConstructorWithApiKey: + def test_sync_api_key_mode(self) -> None: + client = AwsOpenAI( + api_key="my-key", + base_url="https://example.com/v1", + ) + assert client._use_sigv4 is False + assert client.api_key == "my-key" + + def test_async_api_key_mode(self) -> None: + client = AsyncAwsOpenAI( + api_key="my-key", + base_url="https://example.com/v1", + ) + assert client._use_sigv4 is False + assert client.api_key == "my-key" + + def test_sync_api_key_no_region_required(self) -> None: + with patch.dict("os.environ", {"AWS_REGION": "", "AWS_DEFAULT_REGION": ""}): + client = AwsOpenAI( + api_key="my-key", + base_url="https://example.com/v1", + ) + assert client._region == "" + + def test_async_api_key_no_region_required(self) -> None: + with patch.dict("os.environ", {"AWS_REGION": "", "AWS_DEFAULT_REGION": ""}): + client = AsyncAwsOpenAI( + api_key="my-key", + base_url="https://example.com/v1", + ) + assert client._region == "" + + +# --------------------------------------------------------------------------- +# isinstance checks (Req 1.4, 2.4) +# --------------------------------------------------------------------------- + + +class TestInheritance: + def test_sync_is_instance_of_openai(self) -> None: + with _patch_default_creds(): + client = AwsOpenAI(region="us-west-2") + assert isinstance(client, OpenAI) + + def test_async_is_instance_of_async_openai(self) -> None: + with _patch_default_creds(): + client = AsyncAwsOpenAI(region="us-west-2") + assert isinstance(client, AsyncOpenAI) + + +# --------------------------------------------------------------------------- +# copy() / with_options() preserves Bedrock Mantle fields (Req 1.4, 2.4) +# --------------------------------------------------------------------------- + + +class TestCopyPreservesFields: + def test_sync_copy_preserves_region_and_sigv4(self) -> None: + with _patch_default_creds(): + client = AwsOpenAI(region="us-west-2") + copied = client.copy() + assert copied._region == "us-west-2" + assert copied._use_sigv4 is True + + def test_sync_with_options_preserves_region(self) -> None: + with _patch_default_creds(): + client = AwsOpenAI(region="us-west-2") + copied = client.with_options() + assert copied._region == "us-west-2" + + def test_async_copy_preserves_region_and_sigv4(self) -> None: + with _patch_default_creds(): + client = AsyncAwsOpenAI(region="us-west-2") + copied = client.copy() + assert copied._region == "us-west-2" + assert copied._use_sigv4 is True + + def test_async_with_options_preserves_region(self) -> None: + with _patch_default_creds(): + client = AsyncAwsOpenAI(region="us-west-2") + copied = client.with_options() + assert copied._region == "us-west-2" + + def test_sync_copy_preserves_credential_provider(self) -> None: + provider = lambda: _make_frozen_creds() + with _patch_ensure_botocore(): + client = AwsOpenAI(region="us-west-2", credential_provider=provider) + copied = client.copy() + assert copied._credential_provider is provider + + def test_async_copy_preserves_credential_provider(self) -> None: + provider = lambda: _make_frozen_creds() + with _patch_ensure_botocore(): + client = AsyncAwsOpenAI(region="us-west-2", credential_provider=provider) + copied = client.copy() + assert copied._credential_provider is provider + + def test_sync_copy_overrides_region(self) -> None: + with _patch_default_creds(): + client = AwsOpenAI(region="us-west-2") + copied = client.copy(region="eu-west-1") + assert copied._region == "eu-west-1" + + def test_async_copy_overrides_region(self) -> None: + with _patch_default_creds(): + client = AsyncAwsOpenAI(region="us-west-2") + copied = client.copy(region="eu-west-1") + assert copied._region == "eu-west-1" + + def test_sync_copy_returns_same_type(self) -> None: + with _patch_default_creds(): + client = AwsOpenAI(region="us-west-2") + copied = client.copy() + assert type(copied) is AwsOpenAI + + def test_async_copy_returns_same_type(self) -> None: + with _patch_default_creds(): + client = AsyncAwsOpenAI(region="us-west-2") + copied = client.copy() + assert type(copied) is AsyncAwsOpenAI + + +# --------------------------------------------------------------------------- +# Default botocore credential chain fallback (Req 4.4, 4.5) +# --------------------------------------------------------------------------- + + +class TestDefaultBotocoreCredentials: + def test_sync_uses_botocore_default_chain(self) -> None: + unfrozen = _make_unfrozen_creds("AKID_DEFAULT", "secret_default", "tok_default") + with _patch_default_creds(unfrozen): + client = AwsOpenAI(region="us-west-2") + assert client._botocore_credentials is unfrozen + assert client._credential_provider is None + + def test_async_uses_botocore_default_chain(self) -> None: + unfrozen = _make_unfrozen_creds("AKID_DEFAULT", "secret_default", "tok_default") + with _patch_default_creds(unfrozen): + client = AsyncAwsOpenAI(region="us-west-2") + assert client._botocore_credentials is unfrozen + assert client._credential_provider is None + + def test_sync_raises_when_botocore_missing(self) -> None: + with patch( + "openai.lib.aws._get_default_credentials", + side_effect=OpenAIError("botocore must be installed"), + ): + with pytest.raises(OpenAIError, match="botocore must be installed"): + AwsOpenAI(region="us-west-2") + + def test_async_raises_when_botocore_missing(self) -> None: + with patch( + "openai.lib.aws._get_default_credentials", + side_effect=OpenAIError("botocore must be installed"), + ): + with pytest.raises(OpenAIError, match="botocore must be installed"): + AsyncAwsOpenAI(region="us-west-2") + + def test_sync_raises_when_no_creds_resolved(self) -> None: + with patch( + "openai.lib.aws._get_default_credentials", + side_effect=OpenAIError("Could not resolve AWS credentials"), + ): + with pytest.raises(OpenAIError, match="Could not resolve AWS credentials"): + AwsOpenAI(region="us-west-2") + + +# --------------------------------------------------------------------------- +# Credential refresh failure wrapping (Req 6.3) +# --------------------------------------------------------------------------- + + +class TestCredentialRefreshFailure: + def test_sync_wraps_provider_error_in_openai_error(self) -> None: + def bad_provider() -> None: + raise RuntimeError("token expired") + + with _patch_ensure_botocore(): + client = AwsOpenAI(region="us-west-2", credential_provider=bad_provider) + + request = httpx.Request("POST", "https://example.com/v1/chat/completions", content=b'{"model":"x"}') + with pytest.raises(OpenAIError, match="Failed to refresh AWS credentials: token expired"): + client._prepare_request(request) + + async def test_async_wraps_provider_error_in_openai_error(self) -> None: + def bad_provider() -> None: + raise RuntimeError("token expired") + + with _patch_ensure_botocore(): + client = AsyncAwsOpenAI(region="us-west-2", credential_provider=bad_provider) + + request = httpx.Request("POST", "https://example.com/v1/chat/completions", content=b'{"model":"x"}') + with pytest.raises(OpenAIError, match="Failed to refresh AWS credentials: token expired"): + await client._prepare_request(request) + + async def test_async_wraps_async_provider_error(self) -> None: + async def bad_async_provider() -> None: + raise ValueError("async refresh failed") + + with _patch_ensure_botocore(): + client = AsyncAwsOpenAI(region="us-west-2", credential_provider=bad_async_provider) + + request = httpx.Request("POST", "https://example.com/v1/chat/completions", content=b'{"model":"x"}') + with pytest.raises(OpenAIError, match="Failed to refresh AWS credentials: async refresh failed"): + await client._prepare_request(request) + + def test_sync_openai_error_not_double_wrapped(self) -> None: + """If the provider raises OpenAIError directly, it should propagate as-is.""" + + def provider_raises_openai_error() -> None: + raise OpenAIError("custom auth failure") + + with _patch_ensure_botocore(): + client = AwsOpenAI(region="us-west-2", credential_provider=provider_raises_openai_error) + + request = httpx.Request("POST", "https://example.com/v1/chat/completions", content=b'{"model":"x"}') + with pytest.raises(OpenAIError, match="custom auth failure"): + client._prepare_request(request) + + +# --------------------------------------------------------------------------- +# Mutual exclusivity validations (Req 1.3, 2.3, 5.3) +# --------------------------------------------------------------------------- + + +class TestMutualExclusivity: + def test_sync_api_key_and_credential_provider_raises(self) -> None: + with pytest.raises(OpenAIError, match="api_key and credential_provider are mutually exclusive"): + AwsOpenAI( + api_key="my-key", + credential_provider=lambda: _make_frozen_creds(), + base_url="https://example.com/v1", + region="us-west-2", + ) + + def test_async_api_key_and_credential_provider_raises(self) -> None: + with pytest.raises(OpenAIError, match="api_key and credential_provider are mutually exclusive"): + AsyncAwsOpenAI( + api_key="my-key", + credential_provider=lambda: _make_frozen_creds(), + base_url="https://example.com/v1", + region="us-west-2", + ) + + def test_sync_no_base_url_no_region_raises(self) -> None: + with patch.dict("os.environ", {"AWS_REGION": "", "AWS_DEFAULT_REGION": ""}): + with pytest.raises(ValueError, match="Must provide region"): + AwsOpenAI() + + def test_async_no_base_url_no_region_raises(self) -> None: + with patch.dict("os.environ", {"AWS_REGION": "", "AWS_DEFAULT_REGION": ""}): + with pytest.raises(ValueError, match="Must provide region"): + AsyncAwsOpenAI() + + +# --------------------------------------------------------------------------- +# Region resolution from env vars (Req 7.2) +# --------------------------------------------------------------------------- + + +class TestRegionFromEnv: + def test_sync_aws_region_env(self) -> None: + with patch.dict("os.environ", {"AWS_REGION": "ap-southeast-1", "AWS_DEFAULT_REGION": ""}): + with _patch_default_creds(): + client = AwsOpenAI() + assert client._region == "ap-southeast-1" + + def test_sync_aws_default_region_env(self) -> None: + with patch.dict("os.environ", {"AWS_DEFAULT_REGION": "eu-central-1"}, clear=False): + import os + + orig = os.environ.pop("AWS_REGION", None) + try: + with _patch_default_creds(): + client = AwsOpenAI() + assert client._region == "eu-central-1" + finally: + if orig is not None: + os.environ["AWS_REGION"] = orig + + def test_sync_aws_region_takes_precedence(self) -> None: + with patch.dict("os.environ", {"AWS_REGION": "us-west-2", "AWS_DEFAULT_REGION": "eu-west-1"}): + with _patch_default_creds(): + client = AwsOpenAI() + assert client._region == "us-west-2" + + def test_async_aws_region_env(self) -> None: + with patch.dict("os.environ", {"AWS_REGION": "ap-southeast-1", "AWS_DEFAULT_REGION": ""}): + with _patch_default_creds(): + client = AsyncAwsOpenAI() + assert client._region == "ap-southeast-1" + + +# --------------------------------------------------------------------------- +# Responses API: SigV4 headers injected for responses.create (sync + async) +# --------------------------------------------------------------------------- + +_MOCK_RESPONSE_JSON: dict[str, object] = { + "id": "resp_test", + "object": "response", + "created_at": 1700000000, + "status": "completed", + "model": "gpt-4o-mini", + "output": [ + { + "type": "message", + "id": "msg_test", + "status": "completed", + "role": "assistant", + "content": [{"type": "output_text", "text": "Hello!", "annotations": []}], + } + ], + "parallel_tool_calls": True, + "tool_choice": "auto", + "text": {"format": {"type": "text"}}, + "tools": [], +} + + +def _fake_sign(request: httpx.Request, creds: object, region: str) -> None: + """Inject fake SigV4 headers for testing.""" + ak = getattr(creds, "access_key", "UNKNOWN") + tok = getattr(creds, "token", None) + request.headers["authorization"] = f"AWS4-HMAC-SHA256 Credential={ak}/{region}/bedrock/aws4_request" + request.headers["x-amz-date"] = "20240101T000000Z" + if tok: + request.headers["x-amz-security-token"] = tok + + +class TestResponsesApiSigV4: + """Verify that SigV4 headers are injected when calling the Responses API.""" + + def test_sync_responses_create_has_sigv4_headers(self) -> None: + captured: dict[str, str] = {} + + def capture_handler(request: httpx.Request) -> httpx.Response: + captured.update(dict(request.headers)) + return httpx.Response(200, json=_MOCK_RESPONSE_JSON) + + transport = httpx.MockTransport(capture_handler) + provider = lambda: _make_frozen_creds("AKID_TEST", "secret_test", "session_tok") + with _patch_ensure_botocore(), patch("openai.lib.aws._sign_httpx_request", side_effect=_fake_sign): + client = AwsOpenAI( + region="us-west-2", + credential_provider=provider, + http_client=httpx.Client(transport=transport), + ) + client.responses.create(model="gpt-4o-mini", input="Hello") + + assert "AWS4-HMAC-SHA256" in captured.get("authorization", "") + assert captured.get("x-amz-date") == "20240101T000000Z" + assert captured.get("x-amz-security-token") == "session_tok" + + async def test_async_responses_create_has_sigv4_headers(self) -> None: + captured: dict[str, str] = {} + + async def capture_handler(request: httpx.Request) -> httpx.Response: + captured.update(dict(request.headers)) + return httpx.Response(200, json=_MOCK_RESPONSE_JSON) + + transport = httpx.MockTransport(capture_handler) + provider = lambda: _make_frozen_creds("AKID_ASYNC", "secret_async", "async_tok") + with _patch_ensure_botocore(), patch("openai.lib.aws._sign_httpx_request", side_effect=_fake_sign): + client = AsyncAwsOpenAI( + region="us-west-2", + credential_provider=provider, + http_client=httpx.AsyncClient(transport=transport), + ) + await client.responses.create(model="gpt-4o-mini", input="Hello") + + assert "AWS4-HMAC-SHA256" in captured.get("authorization", "") + assert captured.get("x-amz-date") == "20240101T000000Z" + assert captured.get("x-amz-security-token") == "async_tok" + + def test_sync_responses_create_api_key_no_sigv4(self) -> None: + """In API key mode, responses.create should use Bearer auth, no SigV4 headers.""" + captured: dict[str, str] = {} + + def capture_handler(request: httpx.Request) -> httpx.Response: + captured.update(dict(request.headers)) + return httpx.Response(200, json=_MOCK_RESPONSE_JSON) + + transport = httpx.MockTransport(capture_handler) + client = AwsOpenAI( + api_key="my-api-key", + base_url="https://example.com/v1", + http_client=httpx.Client(transport=transport), + ) + client.responses.create(model="gpt-4o-mini", input="Hello") + + assert captured.get("authorization") == "Bearer my-api-key" + assert "x-amz-date" not in captured + assert "x-amz-security-token" not in captured From 65a27071183929f755f5bd17fa5955bc25ff14ac Mon Sep 17 00:00:00 2001 From: prashanthkvs Date: Mon, 4 May 2026 12:36:45 -0700 Subject: [PATCH 2/2] fix: Address PR review feedback for AWS Bedrock client - Make API key sentinel private with non-guessable value and identity check - Use OpenAIError consistently instead of ValueError for config errors - Fix incorrect filenames in example docstrings - Update README with correct Bedrock Mantle docs link and Responses API example - Move _HEADERS_TO_EXCLUDE to module-level frozenset - Defer credential resolution from __init__ to _prepare_request() - Use asyncio.to_thread in AsyncAwsOpenAI to avoid blocking the event loop - Use NamedTuple for _resolve_bedrock_mantle_config return type - Add test_async_raises_when_no_creds_resolved test --- README.md | 8 +-- examples/aws_client.py | 2 +- examples/aws_credential_provider.py | 2 +- src/openai/lib/aws.py | 61 +++++++++++++------- tests/lib/test_aws.py | 86 ++++++++++++++++++++--------- 5 files changed, 106 insertions(+), 53 deletions(-) diff --git a/README.md b/README.md index 409106fe4d..cb3e9d54fb 100644 --- a/README.md +++ b/README.md @@ -935,7 +935,7 @@ An example of using the client with Microsoft Entra ID (formerly known as Azure ## AWS Bedrock Mantle -To use this library with [AWS Bedrock Mantle](https://docs.aws.amazon.com/bedrock/), use the `AwsOpenAI` +To use this library with [AWS Bedrock Mantle](https://docs.aws.amazon.com/bedrock/latest/userguide/bedrock-mantle.html), use the `AwsOpenAI` class instead of the `OpenAI` class. > [!IMPORTANT] @@ -949,16 +949,16 @@ client = AwsOpenAI( region="us-west-2", ) -completion = client.chat.completions.create( +response = client.responses.create( model="openai.gpt-oss-120b", - messages=[ + input=[ { "role": "user", "content": "How do I output all files in a directory using Python?", }, ], ) -print(completion.choices[0].message.content) +print(response.output_text) ``` In addition to the options provided in the base `OpenAI` client, the following options are provided: diff --git a/examples/aws_client.py b/examples/aws_client.py index 8fdd572919..62380db0c3 100644 --- a/examples/aws_client.py +++ b/examples/aws_client.py @@ -7,7 +7,7 @@ Run: export AWS_REGION=us-west-2 - PYTHONPATH=src python3 examples/bedrock_mantle.py + PYTHONPATH=src python3 examples/aws_client.py """ import asyncio diff --git a/examples/aws_credential_provider.py b/examples/aws_credential_provider.py index d8b4b3be9e..7ac830f30e 100644 --- a/examples/aws_credential_provider.py +++ b/examples/aws_credential_provider.py @@ -13,7 +13,7 @@ Run: export AWS_REGION=us-west-2 - PYTHONPATH=src python3 examples/bedrock_mantle_credential_provider.py + PYTHONPATH=src python3 examples/aws_credential_provider.py """ from __future__ import annotations diff --git a/src/openai/lib/aws.py b/src/openai/lib/aws.py index 6d88f9df1f..a3b7b6738d 100644 --- a/src/openai/lib/aws.py +++ b/src/openai/lib/aws.py @@ -1,8 +1,9 @@ from __future__ import annotations import os +import asyncio import inspect -from typing import Any, Union, Mapping, Callable, Awaitable +from typing import Any, Union, Mapping, Callable, Awaitable, NamedTuple from typing_extensions import Self, override import httpx @@ -14,7 +15,10 @@ # Sentinel API key used when SigV4 mode is active, so the base OpenAI # constructor (which requires a non-None api_key) is satisfied. -API_KEY_SENTINEL = "" +_API_KEY_SENTINEL = "" + +# Exclude httpx transport-level headers that cause SigV4 signature mismatch +_HEADERS_TO_EXCLUDE = frozenset({"connection", "accept-encoding"}) # A credential provider is a callable that returns a botocore-compatible # credentials object (with access_key, secret_key, token attributes). @@ -65,8 +69,6 @@ def _sign_httpx_request( import botocore.auth # type: ignore[import-untyped] # pyright: ignore[reportMissingTypeStubs] import botocore.awsrequest # type: ignore[import-untyped] # pyright: ignore[reportMissingTypeStubs] - # Exclude httpx transport-level headers that cause SigV4 signature mismatch - _HEADERS_TO_EXCLUDE = {"connection", "accept-encoding"} clean_headers = {k: v for k, v in request.headers.items() if k.lower() not in _HEADERS_TO_EXCLUDE} # Convert httpx.Request → botocore.awsrequest.AWSRequest @@ -90,19 +92,30 @@ def _sign_httpx_request( # --------------------------------------------------------------------------- +class _BedrockMantleConfig(NamedTuple): + use_sigv4: bool + credential_provider: Any | None + region: str + base_url: str + + def _resolve_bedrock_mantle_config( *, api_key: str | None, credential_provider: Any | None, region: str | None, base_url: str | None, -) -> tuple[bool, Any | None, str, str, Any | None]: +) -> _BedrockMantleConfig: """Shared constructor logic for both sync and async clients. - Returns (use_sigv4, credential_provider, region, base_url, botocore_credentials). + Validates configuration and resolves region/base_url, but does NOT resolve + credentials — that is deferred to _prepare_request() so that construction + never performs blocking I/O. + + Returns (use_sigv4, credential_provider, region, base_url). """ # Normalize: treat the sentinel as "no api_key provided" - if api_key == API_KEY_SENTINEL: + if api_key is _API_KEY_SENTINEL: api_key = None if api_key is not None and credential_provider is not None: @@ -118,24 +131,20 @@ def _resolve_bedrock_mantle_config( # Resolve region (needed for SigV4 and base_url fallback) resolved_region = region or os.environ.get("AWS_REGION") or os.environ.get("AWS_DEFAULT_REGION") if use_sigv4 and not resolved_region: - raise ValueError("Must provide region or set AWS_REGION / AWS_DEFAULT_REGION environment variable") + raise OpenAIError("Must provide region or set AWS_REGION / AWS_DEFAULT_REGION environment variable") resolved_region = resolved_region or "" # Resolve base_url — fall back to region-derived endpoint if base_url is None: if not resolved_region: - raise ValueError("Must provide base_url, or set region / AWS_REGION / AWS_DEFAULT_REGION") + raise OpenAIError("Must provide base_url, or set region / AWS_REGION / AWS_DEFAULT_REGION") base_url = f"https://bedrock-mantle.{resolved_region}.api.aws/v1" - # Resolve botocore credentials if needed - botocore_credentials: Any = None - if use_sigv4 and credential_provider is None: - botocore_credentials = _get_default_credentials() - elif use_sigv4: - # SigV4 signing always requires botocore (for botocore.auth.SigV4Auth) + # Verify botocore is available when SigV4 is needed + if use_sigv4: _ensure_botocore() - return use_sigv4, credential_provider, resolved_region, base_url, botocore_credentials + return _BedrockMantleConfig(use_sigv4, credential_provider, resolved_region, base_url) def _resolve_credentials_sync( @@ -188,6 +197,7 @@ class AwsOpenAI(OpenAI): """OpenAI-compatible client for AWS Bedrock Mantle APIs. Supports SigV4 request signing and API key authentication. + Credentials are resolved lazily on the first request, not at construction time. """ _region: str @@ -215,16 +225,16 @@ def __init__( self._credential_provider, self._region, base_url, - self._botocore_credentials, ) = _resolve_bedrock_mantle_config( api_key=api_key, credential_provider=credential_provider, region=region, base_url=base_url, ) + self._botocore_credentials: Any = None super().__init__( - api_key=api_key or API_KEY_SENTINEL, + api_key=api_key or _API_KEY_SENTINEL, base_url=base_url, timeout=timeout, max_retries=max_retries, @@ -239,6 +249,9 @@ def __init__( def _prepare_request(self, request: httpx.Request) -> None: if not self._use_sigv4: return + # Lazily resolve default credentials on first request + if self._credential_provider is None and self._botocore_credentials is None: + self._botocore_credentials = _get_default_credentials() credentials = _resolve_credentials_sync(self._credential_provider, self._botocore_credentials) _sign_httpx_request(request, credentials, self._region) @@ -268,6 +281,8 @@ class AsyncAwsOpenAI(AsyncOpenAI): """Async OpenAI-compatible client for AWS Bedrock Mantle APIs. Supports SigV4 request signing and API key authentication. + Credentials are resolved lazily on the first request using asyncio.to_thread + to avoid blocking the event loop. """ _region: str @@ -295,16 +310,17 @@ def __init__( self._credential_provider, self._region, base_url, - self._botocore_credentials, ) = _resolve_bedrock_mantle_config( api_key=api_key, credential_provider=credential_provider, region=region, base_url=base_url, ) + self._botocore_credentials: Any = None + self._creds_lock = asyncio.Lock() super().__init__( - api_key=api_key or API_KEY_SENTINEL, + api_key=api_key or _API_KEY_SENTINEL, base_url=base_url, timeout=timeout, max_retries=max_retries, @@ -319,6 +335,11 @@ def __init__( async def _prepare_request(self, request: httpx.Request) -> None: if not self._use_sigv4: return + # Lazily resolve default credentials on first request, off the event loop + if self._credential_provider is None and self._botocore_credentials is None: + async with self._creds_lock: + if self._botocore_credentials is None: + self._botocore_credentials = await asyncio.to_thread(_get_default_credentials) credentials = await _resolve_credentials_async(self._credential_provider, self._botocore_credentials) _sign_httpx_request(request, credentials, self._region) diff --git a/tests/lib/test_aws.py b/tests/lib/test_aws.py index 5153c48618..f98b27ec68 100644 --- a/tests/lib/test_aws.py +++ b/tests/lib/test_aws.py @@ -69,12 +69,12 @@ def _patch_ensure_botocore() -> Any: class TestConstructorWithRegionDerivedUrl: def test_sync_base_url_from_region(self) -> None: - with _patch_default_creds(): + with _patch_ensure_botocore(): client = AwsOpenAI(region="us-west-2") assert str(client.base_url) == "https://bedrock-mantle.us-west-2.api.aws/v1/" def test_async_base_url_from_region(self) -> None: - with _patch_default_creds(): + with _patch_ensure_botocore(): client = AsyncAwsOpenAI(region="us-east-1") assert str(client.base_url) == "https://bedrock-mantle.us-east-1.api.aws/v1/" @@ -86,7 +86,7 @@ def test_async_base_url_from_region(self) -> None: class TestConstructorWithBaseUrl: def test_sync_uses_provided_base_url(self) -> None: - with _patch_default_creds(): + with _patch_ensure_botocore(): client = AwsOpenAI( base_url="https://custom.example.com/v1", region="us-west-2", @@ -94,7 +94,7 @@ def test_sync_uses_provided_base_url(self) -> None: assert str(client.base_url) == "https://custom.example.com/v1/" def test_async_uses_provided_base_url(self) -> None: - with _patch_default_creds(): + with _patch_ensure_botocore(): client = AsyncAwsOpenAI( base_url="https://custom.example.com/v1", region="us-west-2", @@ -148,12 +148,12 @@ def test_async_api_key_no_region_required(self) -> None: class TestInheritance: def test_sync_is_instance_of_openai(self) -> None: - with _patch_default_creds(): + with _patch_ensure_botocore(): client = AwsOpenAI(region="us-west-2") assert isinstance(client, OpenAI) def test_async_is_instance_of_async_openai(self) -> None: - with _patch_default_creds(): + with _patch_ensure_botocore(): client = AsyncAwsOpenAI(region="us-west-2") assert isinstance(client, AsyncOpenAI) @@ -165,27 +165,27 @@ def test_async_is_instance_of_async_openai(self) -> None: class TestCopyPreservesFields: def test_sync_copy_preserves_region_and_sigv4(self) -> None: - with _patch_default_creds(): + with _patch_ensure_botocore(): client = AwsOpenAI(region="us-west-2") copied = client.copy() assert copied._region == "us-west-2" assert copied._use_sigv4 is True def test_sync_with_options_preserves_region(self) -> None: - with _patch_default_creds(): + with _patch_ensure_botocore(): client = AwsOpenAI(region="us-west-2") copied = client.with_options() assert copied._region == "us-west-2" def test_async_copy_preserves_region_and_sigv4(self) -> None: - with _patch_default_creds(): + with _patch_ensure_botocore(): client = AsyncAwsOpenAI(region="us-west-2") copied = client.copy() assert copied._region == "us-west-2" assert copied._use_sigv4 is True def test_async_with_options_preserves_region(self) -> None: - with _patch_default_creds(): + with _patch_ensure_botocore(): client = AsyncAwsOpenAI(region="us-west-2") copied = client.with_options() assert copied._region == "us-west-2" @@ -205,25 +205,25 @@ def test_async_copy_preserves_credential_provider(self) -> None: assert copied._credential_provider is provider def test_sync_copy_overrides_region(self) -> None: - with _patch_default_creds(): + with _patch_ensure_botocore(): client = AwsOpenAI(region="us-west-2") copied = client.copy(region="eu-west-1") assert copied._region == "eu-west-1" def test_async_copy_overrides_region(self) -> None: - with _patch_default_creds(): + with _patch_ensure_botocore(): client = AsyncAwsOpenAI(region="us-west-2") copied = client.copy(region="eu-west-1") assert copied._region == "eu-west-1" def test_sync_copy_returns_same_type(self) -> None: - with _patch_default_creds(): + with _patch_ensure_botocore(): client = AwsOpenAI(region="us-west-2") copied = client.copy() assert type(copied) is AwsOpenAI def test_async_copy_returns_same_type(self) -> None: - with _patch_default_creds(): + with _patch_ensure_botocore(): client = AsyncAwsOpenAI(region="us-west-2") copied = client.copy() assert type(copied) is AsyncAwsOpenAI @@ -236,22 +236,32 @@ def test_async_copy_returns_same_type(self) -> None: class TestDefaultBotocoreCredentials: def test_sync_uses_botocore_default_chain(self) -> None: + """Credentials are resolved lazily on first request, not at construction.""" unfrozen = _make_unfrozen_creds("AKID_DEFAULT", "secret_default", "tok_default") - with _patch_default_creds(unfrozen): + with _patch_ensure_botocore(): client = AwsOpenAI(region="us-west-2") - assert client._botocore_credentials is unfrozen + # Not yet resolved at construction time + assert client._botocore_credentials is None assert client._credential_provider is None + # Resolve on first _prepare_request call + request = httpx.Request("POST", "https://example.com/v1/chat/completions", content=b'{"model":"x"}') + with _patch_default_creds(unfrozen), patch("openai.lib.aws._sign_httpx_request"): + client._prepare_request(request) + assert client._botocore_credentials is unfrozen + def test_async_uses_botocore_default_chain(self) -> None: + """Credentials are resolved lazily on first request, not at construction.""" unfrozen = _make_unfrozen_creds("AKID_DEFAULT", "secret_default", "tok_default") - with _patch_default_creds(unfrozen): + with _patch_ensure_botocore(): client = AsyncAwsOpenAI(region="us-west-2") - assert client._botocore_credentials is unfrozen + # Not yet resolved at construction time + assert client._botocore_credentials is None assert client._credential_provider is None def test_sync_raises_when_botocore_missing(self) -> None: with patch( - "openai.lib.aws._get_default_credentials", + "openai.lib.aws._ensure_botocore", side_effect=OpenAIError("botocore must be installed"), ): with pytest.raises(OpenAIError, match="botocore must be installed"): @@ -259,19 +269,38 @@ def test_sync_raises_when_botocore_missing(self) -> None: def test_async_raises_when_botocore_missing(self) -> None: with patch( - "openai.lib.aws._get_default_credentials", + "openai.lib.aws._ensure_botocore", side_effect=OpenAIError("botocore must be installed"), ): with pytest.raises(OpenAIError, match="botocore must be installed"): AsyncAwsOpenAI(region="us-west-2") def test_sync_raises_when_no_creds_resolved(self) -> None: + """Error is raised at request time, not construction time.""" + with _patch_ensure_botocore(): + client = AwsOpenAI(region="us-west-2") + + request = httpx.Request("POST", "https://example.com/v1/chat/completions", content=b'{"model":"x"}') with patch( "openai.lib.aws._get_default_credentials", side_effect=OpenAIError("Could not resolve AWS credentials"), ): with pytest.raises(OpenAIError, match="Could not resolve AWS credentials"): - AwsOpenAI(region="us-west-2") + client._prepare_request(request) + + @pytest.mark.asyncio + async def test_async_raises_when_no_creds_resolved(self) -> None: + """Error is raised at request time, not construction time.""" + with _patch_ensure_botocore(): + client = AsyncAwsOpenAI(region="us-west-2") + + request = httpx.Request("POST", "https://example.com/v1/chat/completions", content=b'{"model":"x"}') + with patch( + "openai.lib.aws._get_default_credentials", + side_effect=OpenAIError("Could not resolve AWS credentials"), + ): + with pytest.raises(OpenAIError, match="Could not resolve AWS credentials"): + await client._prepare_request(request) # --------------------------------------------------------------------------- @@ -291,6 +320,7 @@ def bad_provider() -> None: with pytest.raises(OpenAIError, match="Failed to refresh AWS credentials: token expired"): client._prepare_request(request) + @pytest.mark.asyncio async def test_async_wraps_provider_error_in_openai_error(self) -> None: def bad_provider() -> None: raise RuntimeError("token expired") @@ -302,6 +332,7 @@ def bad_provider() -> None: with pytest.raises(OpenAIError, match="Failed to refresh AWS credentials: token expired"): await client._prepare_request(request) + @pytest.mark.asyncio async def test_async_wraps_async_provider_error(self) -> None: async def bad_async_provider() -> None: raise ValueError("async refresh failed") @@ -353,12 +384,12 @@ def test_async_api_key_and_credential_provider_raises(self) -> None: def test_sync_no_base_url_no_region_raises(self) -> None: with patch.dict("os.environ", {"AWS_REGION": "", "AWS_DEFAULT_REGION": ""}): - with pytest.raises(ValueError, match="Must provide region"): + with pytest.raises(OpenAIError, match="Must provide region"): AwsOpenAI() def test_async_no_base_url_no_region_raises(self) -> None: with patch.dict("os.environ", {"AWS_REGION": "", "AWS_DEFAULT_REGION": ""}): - with pytest.raises(ValueError, match="Must provide region"): + with pytest.raises(OpenAIError, match="Must provide region"): AsyncAwsOpenAI() @@ -370,7 +401,7 @@ def test_async_no_base_url_no_region_raises(self) -> None: class TestRegionFromEnv: def test_sync_aws_region_env(self) -> None: with patch.dict("os.environ", {"AWS_REGION": "ap-southeast-1", "AWS_DEFAULT_REGION": ""}): - with _patch_default_creds(): + with _patch_ensure_botocore(): client = AwsOpenAI() assert client._region == "ap-southeast-1" @@ -380,7 +411,7 @@ def test_sync_aws_default_region_env(self) -> None: orig = os.environ.pop("AWS_REGION", None) try: - with _patch_default_creds(): + with _patch_ensure_botocore(): client = AwsOpenAI() assert client._region == "eu-central-1" finally: @@ -389,13 +420,13 @@ def test_sync_aws_default_region_env(self) -> None: def test_sync_aws_region_takes_precedence(self) -> None: with patch.dict("os.environ", {"AWS_REGION": "us-west-2", "AWS_DEFAULT_REGION": "eu-west-1"}): - with _patch_default_creds(): + with _patch_ensure_botocore(): client = AwsOpenAI() assert client._region == "us-west-2" def test_async_aws_region_env(self) -> None: with patch.dict("os.environ", {"AWS_REGION": "ap-southeast-1", "AWS_DEFAULT_REGION": ""}): - with _patch_default_creds(): + with _patch_ensure_botocore(): client = AsyncAwsOpenAI() assert client._region == "ap-southeast-1" @@ -460,6 +491,7 @@ def capture_handler(request: httpx.Request) -> httpx.Response: assert captured.get("x-amz-date") == "20240101T000000Z" assert captured.get("x-amz-security-token") == "session_tok" + @pytest.mark.asyncio async def test_async_responses_create_has_sigv4_headers(self) -> None: captured: dict[str, str] = {}