diff --git a/CHANGELOG.md b/CHANGELOG.md index bbaa80e..f929453 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - `CrossAppAccessFlow.start()` now accepts an optional `resource` parameter (RFC 8707), forwarded to the token exchange alongside `audience` and `scope`. +- `OAuth2Error` now exposes an `additional_fields` mapping containing any non-standard keys returned in the error response body, so server-specific remediation hints are no longer discarded. +- `OAuth2Error.from_response()` classmethod builds an error from a parsed OAuth2 error response body, mapping standard RFC 6749 fields to their attributes and collecting the rest into `additional_fields`. ## 0.2.0 diff --git a/README.md b/README.md index 5ad8336..b7ba7af 100644 --- a/README.md +++ b/README.md @@ -652,6 +652,39 @@ flow = CrossAppAccessFlow( +## Error Handling + +Authentication flows raise `OAuth2Error` when the authorization server +returns an error response, or when the SDK detects a protocol violation +locally (e.g., a `state` mismatch on the authorization-code callback). + +```python +from okta_client.authfoundation import OAuth2Error + +try: + token = await flow.start(...) +except OAuth2Error as err: + print(err.error) # RFC 6749 error code, e.g. "invalid_grant" + print(err.error_description) # Human-readable description (if provided) + print(err.error_uri) # Documentation link (if provided) + print(err.status_code) # HTTP status (server responses only) + print(err.request_id) # Request ID header (server responses only) +``` + +Servers sometimes return additional keys alongside the standard fields — +for example `required_acr` and `max_age` on a step-up challenge, or +Okta-specific `errorCauses` / `errorId` values. Any keys the SDK doesn't +already model are preserved verbatim on `OAuth2Error.additional_fields`: + +```python +except OAuth2Error as err: + if err.error == "interaction_required": + required_acr = err.additional_fields.get("required_acr") + # ...re-prompt the user at the requested assurance level +``` + +Locally-raised errors (no server payload) leave `additional_fields` empty. + ## Listeners A common pattern within this SDK is the use of "Listeners" which enable developers to observe key events within the SDK's lifecycle. This permits you to implement some protocol within your application, and add your class instance as a listener to the client or flow you would like to observe. diff --git a/src/okta_client/authfoundation/oauth2/client.py b/src/okta_client/authfoundation/oauth2/client.py index 65fa4ed..86a72d8 100644 --- a/src/okta_client/authfoundation/oauth2/client.py +++ b/src/okta_client/authfoundation/oauth2/client.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable from okta_client.authfoundation.oauth2.requests.oauth_authorization_server import OAuthAuthorizationServerRequest -from okta_client.authfoundation.utils import coerce_optional_sequence, coerce_optional_str +from okta_client.authfoundation.utils import coerce_optional_sequence from ..coalesced_result import CoalescedResult from ..networking import APIClient, APIClientListener, APIResponse, NetworkInterface @@ -413,10 +413,8 @@ def _raise_for_oauth2_error( except Exception: error = None if error is None and ("error" in result or response.status_code >= 400): - error = OAuth2Error( - error=str(result.get("error", "oauth2_error")), - error_description=coerce_optional_str(result.get("error_description")), - error_uri=coerce_optional_str(result.get("error_uri")), + error = OAuth2Error.from_response( + result, status_code=response.status_code, request_id=response.request_id, ) diff --git a/src/okta_client/authfoundation/oauth2/errors.py b/src/okta_client/authfoundation/oauth2/errors.py index 9125897..8d36fb5 100644 --- a/src/okta_client/authfoundation/oauth2/errors.py +++ b/src/okta_client/authfoundation/oauth2/errors.py @@ -10,7 +10,14 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Any, Mapping + +_STANDARD_FIELDS = frozenset({"error", "error_description", "error_uri"}) + + +def _coerce_optional_str(value: Any) -> str | None: + return None if value is None else str(value) @dataclass @@ -21,6 +28,7 @@ class OAuth2Error(Exception): error_uri: str | None = None status_code: int | None = None request_id: str | None = None + additional_fields: Mapping[str, Any] = field(default_factory=dict) def __str__(self) -> str: """Return a readable error string.""" @@ -30,3 +38,30 @@ def __str__(self) -> str: if self.error_uri: details.append(self.error_uri) return ": ".join(details) + + @classmethod + def from_response( + cls, + data: Mapping[str, Any], + *, + status_code: int | None = None, + request_id: str | None = None, + ) -> "OAuth2Error": + """Build an :class:`OAuth2Error` from a parsed OAuth2 error response body. + + Standard RFC 6749 keys (``error``, ``error_description``, ``error_uri``) + are mapped to their dedicated attributes; any other keys are kept + verbatim on :attr:`additional_fields` so callers can inspect + server-specific remediation hints. + + ``error`` defaults to ``"oauth2_error"`` when the response body omits it + (e.g., a 5xx with no JSON ``error`` key). + """ + return cls( + error=str(data.get("error", "oauth2_error")), + error_description=_coerce_optional_str(data.get("error_description")), + error_uri=_coerce_optional_str(data.get("error_uri")), + status_code=status_code, + request_id=request_id, + additional_fields={k: v for k, v in data.items() if k not in _STANDARD_FIELDS}, + ) diff --git a/src/okta_client/authfoundation/oauth2/request_protocols.py b/src/okta_client/authfoundation/oauth2/request_protocols.py index 2d3e439..38d2b1e 100644 --- a/src/okta_client/authfoundation/oauth2/request_protocols.py +++ b/src/okta_client/authfoundation/oauth2/request_protocols.py @@ -13,8 +13,6 @@ from collections.abc import Mapping from typing import Any, Protocol, runtime_checkable -from okta_client.authfoundation.utils import coerce_optional_str - from ..networking import ( APIContentType, APIParsingContext, @@ -164,11 +162,6 @@ def accepts_type(self) -> APIContentType | None: def parse_error(self, data: Mapping[str, Any]) -> Exception | None: """Parse standard OAuth2 error fields when present.""" - error = data.get("error") - if not error: + if not data.get("error"): return None - return OAuth2Error( - error=str(error), - error_description=coerce_optional_str(data.get("error_description")), - error_uri=coerce_optional_str(data.get("error_uri")), - ) + return OAuth2Error.from_response(data) diff --git a/tests/test_oauth2_exchange.py b/tests/test_oauth2_exchange.py index f25c39d..d02f79c 100644 --- a/tests/test_oauth2_exchange.py +++ b/tests/test_oauth2_exchange.py @@ -227,3 +227,77 @@ def test_oauth2_exchange_oauth_error() -> None: assert error.error_description == "invalid credentials" return raise AssertionError("Expected OAuth2Error for invalid_grant") + + +def test_oauth2_error_preserves_server_additional_fields() -> None: + """Non-standard fields in the token-endpoint error body are preserved on OAuth2Error.""" + openid = OpenIdConfiguration.from_json( + { + "authorization_endpoint": "https://example.com/auth", + "token_endpoint": "https://example.com/token", + "jwks_uri": "https://example.com/keys", + } + ) + token_body = json.dumps( + { + "error": "interaction_required", + "error_description": "step-up required", + "required_acr": "urn:okta:loa:2fa:any", + "max_age": 0, + } + ).encode("utf-8") + discovery_body = json.dumps( + { + "issuer": "https://example.com", + "authorization_endpoint": "https://example.com/auth", + "token_endpoint": "https://example.com/token", + "jwks_uri": "https://example.com/keys", + } + ).encode("utf-8") + jwks_body = json.dumps({"keys": []}).encode("utf-8") + network = DummyNetwork( + responses={ + "https://example.com/.well-known/openid-configuration": RawResponse( + status_code=200, headers={}, body=discovery_body, + ), + "https://example.com/keys?client_id=client": RawResponse( + status_code=200, headers={}, body=jwks_body, + ), + "https://example.com/token": RawResponse( + status_code=400, headers={}, body=token_body, + ), + } + ) + client = OAuth2Client( + configuration=OAuth2ClientConfiguration( + issuer="https://example.com", + scope=["openid"], + client_authorization=ClientIdAuthorization(id="client"), + ), + network=network, + ) + request = TokenExchangeRequest( + _openid_configuration=openid, + _client_configuration=client.configuration, + username="user", + password="pass", + ) + + try: + asyncio.run(client.exchange(request)) + except OAuth2Error as error: + assert error.error == "interaction_required" + assert error.additional_fields == { + "required_acr": "urn:okta:loa:2fa:any", + "max_age": 0, + } + # str() should remain unchanged (no extras appended). + assert str(error) == "interaction_required: step-up required" + return + raise AssertionError("Expected OAuth2Error for interaction_required") + + +def test_oauth2_error_default_additional_fields_is_empty() -> None: + """Locally-raised OAuth2Errors have an empty additional_fields mapping.""" + err = OAuth2Error(error="state_mismatch", error_description="bad state") + assert err.additional_fields == {}