Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions packages/oauth/src/keycardai/oauth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
TokenExchangeError,
)
from .http.auth import AuthStrategy, BasicAuth, BearerAuth, MultiZoneBasicAuth, NoneAuth
from .operations._authorize import build_authorize_url
from .types.models import (
PKCE,
AuthorizationServerMetadata,
Expand Down Expand Up @@ -83,6 +84,8 @@
"ClientRegistrationRequest",
"TokenExchangeRequest",
"AuthorizationServerMetadata",
# === Authorization ===
"build_authorize_url",
# === OAuth Enums ===
"GrantType",
"ResponseType",
Expand Down
135 changes: 135 additions & 0 deletions packages/oauth/src/keycardai/oauth/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
NoneAuth,
)
from .http.transport import AsyncHTTPTransport, HTTPTransport
from .operations._authorize import (
exchange_authorization_code as _exchange_authorization_code,
exchange_authorization_code_async as _exchange_authorization_code_async,
)
from .operations._discovery import (
discover_server_metadata,
discover_server_metadata_async,
Expand Down Expand Up @@ -387,6 +391,25 @@ async def get_client_secret(self) -> str | None:
)
return self._client_secret

async def get_endpoints(self) -> "Endpoints":
"""Get the resolved endpoint URLs.

Returns endpoints resolved during initialization, incorporating
any discovered metadata and explicit overrides.

Returns:
Resolved Endpoints configuration.

Raises:
RuntimeError: If called outside of async context manager.
"""
if not self._initialized:
raise RuntimeError(
"AsyncClient must be used within 'async with' statement. "
"Use 'async with AsyncClient(...) as client:' to properly initialize."
)
return self._discovered_endpoints or self._endpoints

async def _get_current_endpoints(self) -> "Endpoints":
"""Get current endpoints from cached discovery.

Expand Down Expand Up @@ -644,6 +667,55 @@ async def exchange_token(self, request: TokenExchangeRequest | None = None, /, *

return await exchange_token_async(request, ctx)

async def exchange_authorization_code(
self,
*,
code: str,
redirect_uri: str,
code_verifier: str,
client_id: str | None = None,
timeout: float | None = None,
) -> TokenResponse:
"""Exchange an authorization code for tokens.

Supports both public clients (client_id in the body, no auth header)
and confidential clients (auth via the client's auth strategy).

Args:
code: The authorization code from the callback.
redirect_uri: The redirect URI used in the authorize request.
code_verifier: The PKCE code verifier.
client_id: Client ID for the form body. Required for public
clients. Optional for confidential clients where identity
is provided via the auth strategy.
timeout: Optional request timeout override.

Returns:
TokenResponse with access token and metadata.

Raises:
OAuthHttpError: If the token endpoint returns an HTTP error.
OAuthProtocolError: If the response contains an OAuth error.
"""
endpoints = await self._get_current_endpoints()

ctx = build_http_context(
endpoint=endpoints.token,
transport=self.transport,
auth=self.auth_strategy,
user_agent=self.config.user_agent,
custom_headers=self.config.custom_headers,
timeout=timeout or self.config.timeout,
)

return await _exchange_authorization_code_async(
code=code,
redirect_uri=redirect_uri,
code_verifier=code_verifier,
client_id=client_id,
context=ctx,
)

def endpoints_summary(self) -> dict[str, dict[str, str]]:
"""Get diagnostic summary of resolved endpoints.

Expand Down Expand Up @@ -844,6 +916,20 @@ def client_secret(self) -> str | None:
self._ensure_initialized()
return self._client_secret

@property
def endpoints(self) -> "Endpoints":
"""Resolved endpoint URLs (lazily initialized).

Returns endpoints resolved during initialization, incorporating
any discovered metadata and explicit overrides.

Accessing this property will trigger automatic initialization if needed.

Returns:
Resolved Endpoints configuration.
"""
self._ensure_initialized()
return self._discovered_endpoints or self._endpoints

@overload
def register_client(
Expand Down Expand Up @@ -1091,6 +1177,55 @@ def exchange_token(self, request: TokenExchangeRequest | None = None, /, **token

return exchange_token(request, ctx)

def exchange_authorization_code(
self,
*,
code: str,
redirect_uri: str,
code_verifier: str,
client_id: str | None = None,
timeout: float | None = None,
) -> TokenResponse:
"""Exchange an authorization code for tokens.

Supports both public clients (client_id in the body, no auth header)
and confidential clients (auth via the client's auth strategy).

Args:
code: The authorization code from the callback.
redirect_uri: The redirect URI used in the authorize request.
code_verifier: The PKCE code verifier.
client_id: Client ID for the form body. Required for public
clients. Optional for confidential clients where identity
is provided via the auth strategy.
timeout: Optional request timeout override.

Returns:
TokenResponse with access token and metadata.

Raises:
OAuthHttpError: If the token endpoint returns an HTTP error.
OAuthProtocolError: If the response contains an OAuth error.
"""
endpoints = self._get_current_endpoints()

ctx = build_http_context(
endpoint=endpoints.token,
transport=self.transport,
auth=self.auth_strategy,
user_agent=self.config.user_agent,
custom_headers=self.config.custom_headers,
timeout=timeout or self.config.timeout,
)

return _exchange_authorization_code(
code=code,
redirect_uri=redirect_uri,
code_verifier=code_verifier,
client_id=client_id,
context=ctx,
)

def endpoints_summary(self) -> dict[str, dict[str, str]]:
"""Get diagnostic summary of resolved endpoints.

Expand Down
Loading
Loading