From afba7176e7503a40839ba1646899c081ca1b5760 Mon Sep 17 00:00:00 2001 From: Benjamin Boudreau Date: Thu, 2 Apr 2026 12:38:49 -0400 Subject: [PATCH] feat(keycardai-oauth): add authorization code exchange and PKCE support - Implement PKCE code verifier, challenge generation, and validation - Add authorization code exchange operation (sync and async) - Add build_authorize_url for constructing OAuth authorize URLs - Add exchange_authorization_code to Client and AsyncClient - Add get_endpoints/endpoints property to expose resolved endpoints - Add id_token field to TokenResponse --- .../oauth/src/keycardai/oauth/__init__.py | 3 + packages/oauth/src/keycardai/oauth/client.py | 135 ++++++++ .../keycardai/oauth/operations/_authorize.py | 256 ++++++++++++++++ .../oauth/operations/_token_exchange.py | 2 - .../oauth/src/keycardai/oauth/types/models.py | 1 + .../oauth/src/keycardai/oauth/utils/pkce.py | 34 ++- .../oauth/operations/test_authorize.py | 289 ++++++++++++++++++ .../tests/keycardai/oauth/utils/test_pkce.py | 120 ++++++++ 8 files changed, 830 insertions(+), 10 deletions(-) create mode 100644 packages/oauth/src/keycardai/oauth/operations/_authorize.py create mode 100644 packages/oauth/tests/keycardai/oauth/operations/test_authorize.py create mode 100644 packages/oauth/tests/keycardai/oauth/utils/test_pkce.py diff --git a/packages/oauth/src/keycardai/oauth/__init__.py b/packages/oauth/src/keycardai/oauth/__init__.py index ffb644e..34f4c3d 100644 --- a/packages/oauth/src/keycardai/oauth/__init__.py +++ b/packages/oauth/src/keycardai/oauth/__init__.py @@ -41,6 +41,7 @@ TokenExchangeError, ) from .http.auth import AuthStrategy, BasicAuth, BearerAuth, MultiZoneBasicAuth, NoneAuth +from .operations._authorize import build_authorize_url from .types.models import ( PKCE, AuthorizationServerMetadata, @@ -83,6 +84,8 @@ "ClientRegistrationRequest", "TokenExchangeRequest", "AuthorizationServerMetadata", + # === Authorization === + "build_authorize_url", # === OAuth Enums === "GrantType", "ResponseType", diff --git a/packages/oauth/src/keycardai/oauth/client.py b/packages/oauth/src/keycardai/oauth/client.py index a77bd40..30f969d 100644 --- a/packages/oauth/src/keycardai/oauth/client.py +++ b/packages/oauth/src/keycardai/oauth/client.py @@ -18,6 +18,10 @@ NoneAuth, ) from .http.transport import AsyncHTTPTransport, HTTPTransport +from .operations._authorize import ( + exchange_authorization_code as _exchange_authorization_code, + exchange_authorization_code_async as _exchange_authorization_code_async, +) from .operations._discovery import ( discover_server_metadata, discover_server_metadata_async, @@ -387,6 +391,25 @@ async def get_client_secret(self) -> str | None: ) return self._client_secret + async def get_endpoints(self) -> "Endpoints": + """Get the resolved endpoint URLs. + + Returns endpoints resolved during initialization, incorporating + any discovered metadata and explicit overrides. + + Returns: + Resolved Endpoints configuration. + + Raises: + RuntimeError: If called outside of async context manager. + """ + if not self._initialized: + raise RuntimeError( + "AsyncClient must be used within 'async with' statement. " + "Use 'async with AsyncClient(...) as client:' to properly initialize." + ) + return self._discovered_endpoints or self._endpoints + async def _get_current_endpoints(self) -> "Endpoints": """Get current endpoints from cached discovery. @@ -644,6 +667,55 @@ async def exchange_token(self, request: TokenExchangeRequest | None = None, /, * return await exchange_token_async(request, ctx) + async def exchange_authorization_code( + self, + *, + code: str, + redirect_uri: str, + code_verifier: str, + client_id: str | None = None, + timeout: float | None = None, + ) -> TokenResponse: + """Exchange an authorization code for tokens. + + Supports both public clients (client_id in the body, no auth header) + and confidential clients (auth via the client's auth strategy). + + Args: + code: The authorization code from the callback. + redirect_uri: The redirect URI used in the authorize request. + code_verifier: The PKCE code verifier. + client_id: Client ID for the form body. Required for public + clients. Optional for confidential clients where identity + is provided via the auth strategy. + timeout: Optional request timeout override. + + Returns: + TokenResponse with access token and metadata. + + Raises: + OAuthHttpError: If the token endpoint returns an HTTP error. + OAuthProtocolError: If the response contains an OAuth error. + """ + endpoints = await self._get_current_endpoints() + + ctx = build_http_context( + endpoint=endpoints.token, + transport=self.transport, + auth=self.auth_strategy, + user_agent=self.config.user_agent, + custom_headers=self.config.custom_headers, + timeout=timeout or self.config.timeout, + ) + + return await _exchange_authorization_code_async( + code=code, + redirect_uri=redirect_uri, + code_verifier=code_verifier, + client_id=client_id, + context=ctx, + ) + def endpoints_summary(self) -> dict[str, dict[str, str]]: """Get diagnostic summary of resolved endpoints. @@ -844,6 +916,20 @@ def client_secret(self) -> str | None: self._ensure_initialized() return self._client_secret + @property + def endpoints(self) -> "Endpoints": + """Resolved endpoint URLs (lazily initialized). + + Returns endpoints resolved during initialization, incorporating + any discovered metadata and explicit overrides. + + Accessing this property will trigger automatic initialization if needed. + + Returns: + Resolved Endpoints configuration. + """ + self._ensure_initialized() + return self._discovered_endpoints or self._endpoints @overload def register_client( @@ -1091,6 +1177,55 @@ def exchange_token(self, request: TokenExchangeRequest | None = None, /, **token return exchange_token(request, ctx) + def exchange_authorization_code( + self, + *, + code: str, + redirect_uri: str, + code_verifier: str, + client_id: str | None = None, + timeout: float | None = None, + ) -> TokenResponse: + """Exchange an authorization code for tokens. + + Supports both public clients (client_id in the body, no auth header) + and confidential clients (auth via the client's auth strategy). + + Args: + code: The authorization code from the callback. + redirect_uri: The redirect URI used in the authorize request. + code_verifier: The PKCE code verifier. + client_id: Client ID for the form body. Required for public + clients. Optional for confidential clients where identity + is provided via the auth strategy. + timeout: Optional request timeout override. + + Returns: + TokenResponse with access token and metadata. + + Raises: + OAuthHttpError: If the token endpoint returns an HTTP error. + OAuthProtocolError: If the response contains an OAuth error. + """ + endpoints = self._get_current_endpoints() + + ctx = build_http_context( + endpoint=endpoints.token, + transport=self.transport, + auth=self.auth_strategy, + user_agent=self.config.user_agent, + custom_headers=self.config.custom_headers, + timeout=timeout or self.config.timeout, + ) + + return _exchange_authorization_code( + code=code, + redirect_uri=redirect_uri, + code_verifier=code_verifier, + client_id=client_id, + context=ctx, + ) + def endpoints_summary(self) -> dict[str, dict[str, str]]: """Get diagnostic summary of resolved endpoints. diff --git a/packages/oauth/src/keycardai/oauth/operations/_authorize.py b/packages/oauth/src/keycardai/oauth/operations/_authorize.py new file mode 100644 index 0000000..bd447f9 --- /dev/null +++ b/packages/oauth/src/keycardai/oauth/operations/_authorize.py @@ -0,0 +1,256 @@ +"""OAuth 2.0 Authorization Code operations. + +This module implements the authorization URL construction and authorization +code exchange for OAuth 2.0 authorization code flows (RFC 6749 Section 4.1) +with PKCE support (RFC 7636). +""" + +import json +from urllib.parse import urlencode + +from ..exceptions import OAuthHttpError, OAuthProtocolError +from ..http._context import HTTPContext +from ..http._wire import HttpRequest, HttpResponse +from ..types.models import TokenResponse +from ..utils.pkce import PKCEChallenge + + +def build_authorize_url( + authorize_endpoint: str, + *, + client_id: str, + redirect_uri: str, + pkce: PKCEChallenge, + resources: list[str] | None = None, + scope: str | None = None, + state: str | None = None, +) -> str: + """Build an OAuth 2.0 authorization URL with PKCE. + + Constructs the full authorization URL including PKCE challenge parameters + and multiple resource parameters per RFC 8707. + + Args: + authorize_endpoint: The authorization endpoint URL. + client_id: The OAuth client ID. + redirect_uri: The redirect URI for the callback. + pkce: PKCE challenge/verifier pair. + resources: Resource URIs to request (each becomes a separate + ``resource`` query parameter per RFC 8707). + scope: Space-separated scope string. + state: Opaque state value for CSRF protection. + + Returns: + The complete authorization URL string. + """ + params: dict[str, str | list[str]] = { + "response_type": "code", + "client_id": client_id, + "redirect_uri": redirect_uri, + "code_challenge": pkce.code_challenge, + "code_challenge_method": pkce.code_challenge_method, + } + if resources: + params["resource"] = resources + if scope: + params["scope"] = scope + if state: + params["state"] = state + + return f"{authorize_endpoint}?{urlencode(params, doseq=True)}" + + +# --------------------------------------------------------------------------- +# Authorization code exchange +# --------------------------------------------------------------------------- + +def build_authorization_code_http_request( + *, + code: str, + redirect_uri: str, + code_verifier: str, + client_id: str | None, + context: HTTPContext, +) -> HttpRequest: + """Build the HTTP request for an authorization code exchange. + + Args: + code: The authorization code from the callback. + redirect_uri: The redirect URI used in the authorize request. + code_verifier: The PKCE code verifier. + client_id: Client ID to include in the form body (required for + public clients, optional for confidential clients). + context: HTTP context with endpoint, transport, and auth. + + Returns: + HttpRequest ready to send. + """ + payload: dict[str, str] = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirect_uri, + "code_verifier": code_verifier, + } + if client_id is not None: + payload["client_id"] = client_id + + headers = { + "Accept": "application/json", + "Content-Type": "application/x-www-form-urlencoded", + } + if context.auth: + headers.update(dict(context.auth.apply_headers())) + + form_data = urlencode(payload).encode("utf-8") + + return HttpRequest( + method="POST", + url=context.endpoint, + headers=headers, + body=form_data, + ) + + +def parse_authorization_code_http_response(res: HttpResponse) -> TokenResponse: + """Parse the token endpoint response from an authorization code exchange. + + Args: + res: HTTP response from the token endpoint. + + Returns: + TokenResponse with tokens and metadata. + + Raises: + OAuthProtocolError: If the response contains an OAuth error. + OAuthHttpError: If the HTTP status indicates an error. + """ + if res.status >= 400: + full_body = res.body.decode("utf-8", "ignore") + try: + data = json.loads(full_body) + if isinstance(data, dict) and "error" in data: + raise OAuthProtocolError( + error=data["error"], + error_description=data.get("error_description"), + error_uri=data.get("error_uri"), + operation="POST /token (authorization_code)", + ) + except (json.JSONDecodeError, ValueError): + pass + raise OAuthHttpError( + status_code=res.status, + response_body=full_body[:512], + headers=dict(res.headers), + operation="POST /token (authorization_code)", + ) + + try: + data = json.loads(res.body.decode("utf-8")) + except Exception as e: + raise OAuthProtocolError( + error="invalid_response", + error_description="Invalid JSON in authorization code response", + operation="POST /token (authorization_code)", + ) from e + + if isinstance(data, dict) and "error" in data: + raise OAuthProtocolError( + error=data["error"], + error_description=data.get("error_description"), + error_uri=data.get("error_uri"), + operation="POST /token (authorization_code)", + ) + + if not isinstance(data, dict) or "access_token" not in data: + raise OAuthProtocolError( + error="invalid_response", + error_description="Missing required 'access_token' in authorization code response", + operation="POST /token (authorization_code)", + ) + + scope = data.get("scope") + if isinstance(scope, str): + scope = scope.split() if scope else None + elif isinstance(scope, list): + scope = scope if scope else None + + return TokenResponse( + access_token=data["access_token"], + token_type=data.get("token_type", "Bearer"), + expires_in=data.get("expires_in"), + refresh_token=data.get("refresh_token"), + id_token=data.get("id_token"), + scope=scope, + raw=data, + headers=dict(res.headers), + ) + + +def exchange_authorization_code( + *, + code: str, + redirect_uri: str, + code_verifier: str, + client_id: str | None = None, + context: HTTPContext, +) -> TokenResponse: + """Exchange an authorization code for tokens (sync). + + Args: + code: The authorization code from the callback. + redirect_uri: The redirect URI used in the authorize request. + code_verifier: The PKCE code verifier. + client_id: Client ID for the form body. Required for public clients. + context: HTTP context with endpoint, transport, and auth. + + Returns: + TokenResponse with tokens. + + Raises: + OAuthHttpError: If the token endpoint returns an HTTP error. + OAuthProtocolError: If the response contains an OAuth error. + """ + http_req = build_authorization_code_http_request( + code=code, + redirect_uri=redirect_uri, + code_verifier=code_verifier, + client_id=client_id, + context=context, + ) + http_res = context.transport.request_raw(http_req, timeout=context.timeout) + return parse_authorization_code_http_response(http_res) + + +async def exchange_authorization_code_async( + *, + code: str, + redirect_uri: str, + code_verifier: str, + client_id: str | None = None, + context: HTTPContext, +) -> TokenResponse: + """Exchange an authorization code for tokens (async). + + Args: + code: The authorization code from the callback. + redirect_uri: The redirect URI used in the authorize request. + code_verifier: The PKCE code verifier. + client_id: Client ID for the form body. Required for public clients. + context: HTTP context with endpoint, transport, and auth. + + Returns: + TokenResponse with tokens. + + Raises: + OAuthHttpError: If the token endpoint returns an HTTP error. + OAuthProtocolError: If the response contains an OAuth error. + """ + http_req = build_authorization_code_http_request( + code=code, + redirect_uri=redirect_uri, + code_verifier=code_verifier, + client_id=client_id, + context=context, + ) + http_res = await context.transport.request_raw(http_req, timeout=context.timeout) + return parse_authorization_code_http_response(http_res) diff --git a/packages/oauth/src/keycardai/oauth/operations/_token_exchange.py b/packages/oauth/src/keycardai/oauth/operations/_token_exchange.py index 42e6e21..a6cb3e0 100644 --- a/packages/oauth/src/keycardai/oauth/operations/_token_exchange.py +++ b/packages/oauth/src/keycardai/oauth/operations/_token_exchange.py @@ -184,8 +184,6 @@ async def exchange_token_async( Reference: https://datatracker.ietf.org/doc/html/rfc8693#section-2.1 """ - # Build HTTP request - http_req = build_token_exchange_http_request(request, context) # Execute HTTP request using async transport diff --git a/packages/oauth/src/keycardai/oauth/types/models.py b/packages/oauth/src/keycardai/oauth/types/models.py index 0a88383..bdabb04 100644 --- a/packages/oauth/src/keycardai/oauth/types/models.py +++ b/packages/oauth/src/keycardai/oauth/types/models.py @@ -65,6 +65,7 @@ class TokenResponse: expires_in: int | None = None refresh_token: str | None = None scope: list[str] | None = None + id_token: str | None = None # RFC 8693 specific fields issued_token_type: TokenType | None = None diff --git a/packages/oauth/src/keycardai/oauth/utils/pkce.py b/packages/oauth/src/keycardai/oauth/utils/pkce.py index 7e009fb..c12a9f4 100644 --- a/packages/oauth/src/keycardai/oauth/utils/pkce.py +++ b/packages/oauth/src/keycardai/oauth/utils/pkce.py @@ -17,6 +17,10 @@ - Enables secure OAuth flows for mobile and SPA applications """ +import base64 +import hashlib +import secrets + from pydantic import BaseModel @@ -86,8 +90,12 @@ def generate_code_verifier(length: int = 128) -> str: Reference: https://datatracker.ietf.org/doc/html/rfc7636#section-4.1 """ - # Implementation placeholder - raise NotImplementedError("PKCE code verifier generation not yet implemented") + if length < 43 or length > 128: + raise ValueError("Code verifier length must be between 43 and 128 characters") + # Scale byte count to requested length so we only generate as much + # entropy as needed (base64url expands 3 bytes into 4 chars). + nbytes = (length * 3 + 3) // 4 + return secrets.token_urlsafe(nbytes)[:length] @staticmethod def generate_code_challenge(verifier: str, method: str = "S256") -> str: @@ -107,8 +115,13 @@ def generate_code_challenge(verifier: str, method: str = "S256") -> str: Reference: https://datatracker.ietf.org/doc/html/rfc7636#section-4.2 """ - # Implementation placeholder - raise NotImplementedError("PKCE code challenge generation not yet implemented") + if method == PKCEMethods.S256: + digest = hashlib.sha256(verifier.encode("ascii")).digest() + return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii") + elif method == PKCEMethods.PLAIN: + return verifier + else: + raise ValueError(f"Unsupported PKCE method: {method}") def generate_pkce_pair( self, method: str = "S256", verifier_length: int = 128 @@ -127,8 +140,13 @@ def generate_pkce_pair( Reference: https://datatracker.ietf.org/doc/html/rfc7636#section-4 """ - # Implementation placeholder - raise NotImplementedError("PKCE pair generation not yet implemented") + verifier = self.generate_code_verifier(verifier_length) + challenge = self.generate_code_challenge(verifier, method) + return PKCEChallenge( + code_verifier=verifier, + code_challenge=challenge, + code_challenge_method=method, + ) @staticmethod def validate_pkce_pair( @@ -148,5 +166,5 @@ def validate_pkce_pair( Reference: https://datatracker.ietf.org/doc/html/rfc7636#section-4.6 """ - # Implementation placeholder - raise NotImplementedError("PKCE validation not yet implemented") + expected = PKCEGenerator.generate_code_challenge(code_verifier, method) + return secrets.compare_digest(expected, code_challenge) diff --git a/packages/oauth/tests/keycardai/oauth/operations/test_authorize.py b/packages/oauth/tests/keycardai/oauth/operations/test_authorize.py new file mode 100644 index 0000000..cfba47f --- /dev/null +++ b/packages/oauth/tests/keycardai/oauth/operations/test_authorize.py @@ -0,0 +1,289 @@ +"""Unit tests for OAuth 2.0 Authorization Code operations.""" + +from unittest.mock import AsyncMock, Mock +from urllib.parse import parse_qs, urlparse + +import pytest + +from keycardai.oauth.exceptions import OAuthHttpError, OAuthProtocolError +from keycardai.oauth.http._context import HTTPContext +from keycardai.oauth.http._wire import HttpResponse +from keycardai.oauth.http.auth import BasicAuth, NoneAuth +from keycardai.oauth.operations._authorize import ( + build_authorization_code_http_request, + build_authorize_url, + exchange_authorization_code, + exchange_authorization_code_async, + parse_authorization_code_http_response, +) +from keycardai.oauth.types.models import TokenResponse +from keycardai.oauth.utils.pkce import PKCEChallenge + + +class TestBuildAuthorizeUrl: + """Test authorize URL construction.""" + + def _make_pkce(self) -> PKCEChallenge: + return PKCEChallenge( + code_verifier="test_verifier", + code_challenge="test_challenge", + code_challenge_method="S256", + ) + + def test_minimal(self): + url = build_authorize_url( + "https://auth.example.com/authorize", + client_id="my-client", + redirect_uri="http://localhost:9999/callback", + pkce=self._make_pkce(), + ) + parsed = urlparse(url) + qs = parse_qs(parsed.query) + + assert parsed.scheme == "https" + assert parsed.netloc == "auth.example.com" + assert parsed.path == "/authorize" + assert qs["response_type"] == ["code"] + assert qs["client_id"] == ["my-client"] + assert qs["redirect_uri"] == ["http://localhost:9999/callback"] + assert qs["code_challenge"] == ["test_challenge"] + assert qs["code_challenge_method"] == ["S256"] + assert "resource" not in qs + assert "scope" not in qs + assert "state" not in qs + + def test_single_resource(self): + url = build_authorize_url( + "https://auth.example.com/authorize", + client_id="my-client", + redirect_uri="http://localhost:9999/callback", + pkce=self._make_pkce(), + resources=["https://graph.microsoft.com"], + ) + qs = parse_qs(urlparse(url).query) + assert qs["resource"] == ["https://graph.microsoft.com"] + + def test_multiple_resources(self): + url = build_authorize_url( + "https://auth.example.com/authorize", + client_id="my-client", + redirect_uri="http://localhost:9999/callback", + pkce=self._make_pkce(), + resources=[ + "https://graph.microsoft.com", + "https://api.github.com", + "https://api.linear.app", + ], + ) + qs = parse_qs(urlparse(url).query) + assert qs["resource"] == [ + "https://graph.microsoft.com", + "https://api.github.com", + "https://api.linear.app", + ] + + def test_with_scope_and_state(self): + url = build_authorize_url( + "https://auth.example.com/authorize", + client_id="my-client", + redirect_uri="http://localhost:9999/callback", + pkce=self._make_pkce(), + scope="openid email", + state="csrf-token-123", + ) + qs = parse_qs(urlparse(url).query) + assert qs["scope"] == ["openid email"] + assert qs["state"] == ["csrf-token-123"] + + def test_empty_resources_omitted(self): + url = build_authorize_url( + "https://auth.example.com/authorize", + client_id="my-client", + redirect_uri="http://localhost:9999/callback", + pkce=self._make_pkce(), + resources=[], + ) + qs = parse_qs(urlparse(url).query) + assert "resource" not in qs + + +class TestBuildAuthorizationCodeHttpRequest: + """Test HTTP request construction for code exchange.""" + + def test_public_client(self): + auth = NoneAuth() + ctx = HTTPContext( + endpoint="https://auth.example.com/token", + transport=Mock(), + auth=auth, + ) + http_req = build_authorization_code_http_request( + code="AUTH_CODE_123", + redirect_uri="http://localhost:9999/callback", + code_verifier="test_verifier", + client_id="public-client-id", + context=ctx, + ) + + assert http_req.method == "POST" + assert http_req.url == "https://auth.example.com/token" + assert http_req.headers["Content-Type"] == "application/x-www-form-urlencoded" + assert "Authorization" not in http_req.headers + + body = http_req.body.decode("utf-8") + form = parse_qs(body) + assert form["grant_type"] == ["authorization_code"] + assert form["code"] == ["AUTH_CODE_123"] + assert form["redirect_uri"] == ["http://localhost:9999/callback"] + assert form["code_verifier"] == ["test_verifier"] + assert form["client_id"] == ["public-client-id"] + + def test_confidential_client(self): + auth = BasicAuth("conf-client", "conf-secret") + ctx = HTTPContext( + endpoint="https://auth.example.com/token", + transport=Mock(), + auth=auth, + ) + http_req = build_authorization_code_http_request( + code="AUTH_CODE_456", + redirect_uri="http://localhost:9999/callback", + code_verifier="test_verifier", + client_id=None, + context=ctx, + ) + + body = http_req.body.decode("utf-8") + form = parse_qs(body) + assert "client_id" not in form + assert form["grant_type"] == ["authorization_code"] + assert form["code"] == ["AUTH_CODE_456"] + assert "Authorization" in http_req.headers + assert http_req.headers["Authorization"].startswith("Basic ") + + +class TestParseAuthorizationCodeHttpResponse: + """Test response parsing for code exchange.""" + + def test_success(self): + res = HttpResponse( + status=200, + headers={"Content-Type": "application/json"}, + body=b'{"access_token":"at_123","token_type":"Bearer","expires_in":3600,"refresh_token":"rt_456","id_token":"ey.header.sig","scope":"openid email"}', + ) + result = parse_authorization_code_http_response(res) + + assert isinstance(result, TokenResponse) + assert result.access_token == "at_123" + assert result.token_type == "Bearer" + assert result.expires_in == 3600 + assert result.refresh_token == "rt_456" + assert result.id_token == "ey.header.sig" + assert result.scope == ["openid", "email"] + assert result.raw["access_token"] == "at_123" + + def test_minimal_success(self): + res = HttpResponse( + status=200, + headers={"Content-Type": "application/json"}, + body=b'{"access_token":"at_minimal"}', + ) + result = parse_authorization_code_http_response(res) + assert result.access_token == "at_minimal" + assert result.token_type == "Bearer" + assert result.refresh_token is None + assert result.id_token is None + + def test_oauth_error(self): + res = HttpResponse( + status=400, + headers={"Content-Type": "application/json"}, + body=b'{"error":"invalid_grant","error_description":"Code expired"}', + ) + with pytest.raises(OAuthProtocolError, match="invalid_grant") as exc_info: + parse_authorization_code_http_response(res) + assert exc_info.value.error_description == "Code expired" + + def test_http_error_non_json(self): + res = HttpResponse( + status=500, + headers={"Content-Type": "text/plain"}, + body=b"Internal Server Error", + ) + with pytest.raises(OAuthHttpError, match="HTTP 500"): + parse_authorization_code_http_response(res) + + def test_invalid_json(self): + res = HttpResponse( + status=200, + headers={"Content-Type": "application/json"}, + body=b"not json{", + ) + with pytest.raises(OAuthProtocolError, match="Invalid JSON"): + parse_authorization_code_http_response(res) + + def test_missing_access_token(self): + res = HttpResponse( + status=200, + headers={"Content-Type": "application/json"}, + body=b'{"token_type":"Bearer"}', + ) + with pytest.raises(OAuthProtocolError, match="Missing required"): + parse_authorization_code_http_response(res) + + +class TestExchangeAuthorizationCode: + """Test the sync exchange function.""" + + def test_sync_exchange(self): + mock_transport = Mock() + mock_transport.request_raw.return_value = HttpResponse( + status=200, + headers={"Content-Type": "application/json"}, + body=b'{"access_token":"sync_at","token_type":"Bearer","expires_in":3600}', + ) + ctx = HTTPContext( + endpoint="https://auth.example.com/token", + transport=mock_transport, + auth=NoneAuth(), + timeout=30.0, + ) + + result = exchange_authorization_code( + code="CODE", + redirect_uri="http://localhost:9999/callback", + code_verifier="verifier", + client_id="pub-client", + context=ctx, + ) + + assert result.access_token == "sync_at" + assert result.expires_in == 3600 + mock_transport.request_raw.assert_called_once() + + @pytest.mark.asyncio + async def test_async_exchange(self): + mock_transport = AsyncMock() + mock_transport.request_raw.return_value = HttpResponse( + status=200, + headers={"Content-Type": "application/json"}, + body=b'{"access_token":"async_at","token_type":"Bearer","expires_in":7200}', + ) + ctx = HTTPContext( + endpoint="https://auth.example.com/token", + transport=mock_transport, + auth=NoneAuth(), + timeout=30.0, + ) + + result = await exchange_authorization_code_async( + code="CODE", + redirect_uri="http://localhost:9999/callback", + code_verifier="verifier", + client_id="pub-client", + context=ctx, + ) + + assert result.access_token == "async_at" + assert result.expires_in == 7200 + mock_transport.request_raw.assert_called_once() diff --git a/packages/oauth/tests/keycardai/oauth/utils/test_pkce.py b/packages/oauth/tests/keycardai/oauth/utils/test_pkce.py new file mode 100644 index 0000000..b62a8a8 --- /dev/null +++ b/packages/oauth/tests/keycardai/oauth/utils/test_pkce.py @@ -0,0 +1,120 @@ +"""Tests for PKCE utility functions (RFC 7636).""" + +import base64 +import hashlib + +import pytest + +from keycardai.oauth.utils.pkce import PKCEGenerator + + +class TestGenerateCodeVerifier: + """Test code verifier generation per RFC 7636 Section 4.1.""" + + def test_default_length(self): + verifier = PKCEGenerator.generate_code_verifier() + assert len(verifier) == 128 + + def test_minimum_length(self): + verifier = PKCEGenerator.generate_code_verifier(43) + assert len(verifier) == 43 + + def test_maximum_length(self): + verifier = PKCEGenerator.generate_code_verifier(128) + assert len(verifier) == 128 + + def test_too_short_raises(self): + with pytest.raises(ValueError, match="between 43 and 128"): + PKCEGenerator.generate_code_verifier(42) + + def test_too_long_raises(self): + with pytest.raises(ValueError, match="between 43 and 128"): + PKCEGenerator.generate_code_verifier(129) + + def test_uses_unreserved_characters_only(self): + """RFC 7636 Section 4.1: verifier uses [A-Z] [a-z] [0-9] "-" "." "_" "~".""" + verifier = PKCEGenerator.generate_code_verifier() + # token_urlsafe produces [A-Za-z0-9_-], which is a subset of the allowed set. + allowed = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_") + assert set(verifier).issubset(allowed) + + def test_unique_across_calls(self): + v1 = PKCEGenerator.generate_code_verifier() + v2 = PKCEGenerator.generate_code_verifier() + assert v1 != v2 + + +class TestGenerateCodeChallenge: + """Test code challenge generation per RFC 7636 Section 4.2.""" + + def test_s256_rfc_appendix_b(self): + """Verify S256 against the test vector from RFC 7636 Appendix B. + + Reference: https://datatracker.ietf.org/doc/html/rfc7636#appendix-B + """ + verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + expected_challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM" + challenge = PKCEGenerator.generate_code_challenge(verifier, "S256") + assert challenge == expected_challenge + + def test_s256_manual_computation(self): + verifier = "test-verifier-string" + digest = hashlib.sha256(verifier.encode("ascii")).digest() + expected = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii") + assert PKCEGenerator.generate_code_challenge(verifier, "S256") == expected + + def test_plain_returns_verifier(self): + verifier = "some-code-verifier" + challenge = PKCEGenerator.generate_code_challenge(verifier, "plain") + assert challenge == verifier + + def test_unsupported_method_raises(self): + with pytest.raises(ValueError, match="Unsupported PKCE method"): + PKCEGenerator.generate_code_challenge("verifier", "RS256") + + +class TestGeneratePKCEPair: + """Test PKCE pair generation.""" + + def test_s256_pair(self): + gen = PKCEGenerator() + pair = gen.generate_pkce_pair(method="S256", verifier_length=64) + assert len(pair.code_verifier) == 64 + assert pair.code_challenge_method == "S256" + # Challenge should be a valid base64url string, not the verifier itself + assert pair.code_challenge != pair.code_verifier + + def test_plain_pair(self): + gen = PKCEGenerator() + pair = gen.generate_pkce_pair(method="plain", verifier_length=43) + assert pair.code_challenge == pair.code_verifier + assert pair.code_challenge_method == "plain" + + def test_default_method_is_s256(self): + gen = PKCEGenerator() + pair = gen.generate_pkce_pair() + assert pair.code_challenge_method == "S256" + + +class TestValidatePKCEPair: + """Test PKCE pair validation per RFC 7636 Section 4.6.""" + + def test_s256_round_trip(self): + gen = PKCEGenerator() + pair = gen.generate_pkce_pair(method="S256") + assert gen.validate_pkce_pair(pair.code_verifier, pair.code_challenge, "S256") + + def test_plain_round_trip(self): + gen = PKCEGenerator() + pair = gen.generate_pkce_pair(method="plain") + assert gen.validate_pkce_pair(pair.code_verifier, pair.code_challenge, "plain") + + def test_wrong_verifier_fails(self): + gen = PKCEGenerator() + pair = gen.generate_pkce_pair(method="S256") + assert not gen.validate_pkce_pair("wrong-verifier-value-padded-to-be-long-enough", pair.code_challenge, "S256") + + def test_wrong_challenge_fails(self): + gen = PKCEGenerator() + pair = gen.generate_pkce_pair(method="S256") + assert not gen.validate_pkce_pair(pair.code_verifier, "wrong-challenge", "S256")