diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py index 1728191ba..111583a63 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py @@ -179,7 +179,18 @@ def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]: :return: a tuple of (access_token, token_lifespan) """ - response_json = self._make_handled_request() + try: + response_json = self._make_handled_request() + except ( + requests.exceptions.ConnectionError, + requests.exceptions.ConnectTimeout, + requests.exceptions.ReadTimeout, + ) as e: + raise AirbyteTracedException( + message="OAuth access token refresh request failed due to a network error.", + internal_message=f"Network error during OAuth token refresh after retries were exhausted: {e}", + failure_type=FailureType.transient_error, + ) from e self._ensure_access_token_in_response(response_json) return ( @@ -229,7 +240,12 @@ def _wrap_refresh_token_exception( @backoff.on_exception( backoff.expo, - DefaultBackoffException, + ( + DefaultBackoffException, + requests.exceptions.ConnectionError, + requests.exceptions.ConnectTimeout, + requests.exceptions.ReadTimeout, + ), on_backoff=lambda details: logger.info( f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..." ), @@ -295,7 +311,11 @@ def _make_handled_request(self) -> Any: ) raise except Exception as e: - raise Exception(f"Error while refreshing access token: {e}") from e + raise AirbyteTracedException( + message="OAuth access token refresh request failed.", + internal_message=f"Unexpected error during OAuth token refresh: {e}", + failure_type=FailureType.system_error, + ) from e def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) -> None: """ diff --git a/unit_tests/connector_builder/test_connector_builder_handler.py b/unit_tests/connector_builder/test_connector_builder_handler.py index fee9f3e93..000699609 100644 --- a/unit_tests/connector_builder/test_connector_builder_handler.py +++ b/unit_tests/connector_builder/test_connector_builder_handler.py @@ -1288,7 +1288,7 @@ def test_handle_read_external_requests(deployment_mode, url_base, expected_error pytest.param( "CLOUD", "https://10.0.27.27/tokens/bearer", - "Error while refreshing access token", + "OAuth access token refresh request failed.", id="test_cloud_read_with_private_endpoint", ), pytest.param( diff --git a/unit_tests/sources/declarative/auth/test_oauth.py b/unit_tests/sources/declarative/auth/test_oauth.py index e5e15a035..b41a1b7f1 100644 --- a/unit_tests/sources/declarative/auth/test_oauth.py +++ b/unit_tests/sources/declarative/auth/test_oauth.py @@ -7,16 +7,18 @@ import logging from copy import deepcopy from datetime import timedelta, timezone -from unittest.mock import Mock +from unittest.mock import Mock, patch import freezegun import pytest import requests from requests import Response +from airbyte_cdk.models import FailureType from airbyte_cdk.sources.declarative.auth import DeclarativeOauth2Authenticator from airbyte_cdk.sources.declarative.auth.jwt import JwtAuthenticator from airbyte_cdk.test.mock_http import HttpMocker, HttpRequest, HttpResponse +from airbyte_cdk.utils import AirbyteTracedException from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets from airbyte_cdk.utils.datetime_helpers import AirbyteDateTime, ab_datetime_now, ab_datetime_parse @@ -645,3 +647,75 @@ def mock_request(method, url, data, headers): raise Exception( f"Error while refreshing access token with request: {method}, {url}, {data}, {headers}" ) + + +class TestOauth2AuthenticatorTransientErrorHandling: + """Tests for transient network error handling during OAuth token refresh.""" + + def _create_authenticator(self): + return DeclarativeOauth2Authenticator( + token_refresh_endpoint="{{ config['refresh_endpoint'] }}", + client_id="{{ config['client_id'] }}", + client_secret="{{ config['client_secret'] }}", + refresh_token="{{ parameters['refresh_token'] }}", + config=config, + token_expiry_date="{{ config['token_expiry_date'] }}", + parameters=parameters, + ) + + @pytest.mark.parametrize( + "exception_class", + [ + requests.exceptions.ConnectionError, + requests.exceptions.ConnectTimeout, + requests.exceptions.ReadTimeout, + ], + ids=["ConnectionError", "ConnectTimeout", "ReadTimeout"], + ) + def test_transient_network_error_wrapped_as_transient_error(self, exception_class): + """Transient network errors during OAuth refresh are wrapped in AirbyteTracedException with transient_error.""" + oauth = self._create_authenticator() + with patch.object( + oauth, "_make_handled_request", side_effect=exception_class("connection reset") + ): + with pytest.raises(AirbyteTracedException) as exc_info: + oauth.refresh_access_token() + + assert exc_info.value.failure_type == FailureType.transient_error + assert "network error" in exc_info.value.message.lower() + + def test_connection_error_is_retried_before_raising(self, mocker): + """ConnectionError triggers backoff retries in _make_handled_request before propagating.""" + oauth = self._create_authenticator() + + call_count = 0 + + def request_side_effect(**kwargs): + nonlocal call_count + call_count += 1 + if call_count < 3: + raise requests.exceptions.ConnectionError("connection reset by peer") + mock_response = Mock(spec=requests.Response) + mock_response.ok = True + mock_response.json.return_value = {"access_token": "token_value", "expires_in": 3600} + return mock_response + + mocker.patch("requests.request", side_effect=request_side_effect) + # Patch backoff to avoid actual delays in tests + mocker.patch("time.sleep") + + token, _ = oauth.refresh_access_token() + assert token == "token_value" + assert call_count == 3 + + def test_generic_exception_wrapped_as_system_error(self, mocker): + """Generic exceptions during OAuth refresh are wrapped in AirbyteTracedException with system_error.""" + oauth = self._create_authenticator() + mocker.patch("requests.request", side_effect=ValueError("unexpected parsing error")) + mocker.patch("time.sleep") + + with pytest.raises(AirbyteTracedException) as exc_info: + oauth.refresh_access_token() + + assert exc_info.value.failure_type == FailureType.system_error + assert "OAuth access token refresh request failed" in exc_info.value.message