Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,18 @@ def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:

:return: a tuple of (access_token, token_lifespan)
"""
response_json = self._make_handled_request()
try:
response_json = self._make_handled_request()
except (
requests.exceptions.ConnectionError,
requests.exceptions.ConnectTimeout,
requests.exceptions.ReadTimeout,
) as e:
raise AirbyteTracedException(
message="OAuth access token refresh request failed due to a network error.",
internal_message=f"Network error during OAuth token refresh after retries were exhausted: {e}",
failure_type=FailureType.transient_error,
) from e
self._ensure_access_token_in_response(response_json)

return (
Expand Down Expand Up @@ -229,7 +240,12 @@ def _wrap_refresh_token_exception(

@backoff.on_exception(
backoff.expo,
DefaultBackoffException,
(
DefaultBackoffException,
requests.exceptions.ConnectionError,
requests.exceptions.ConnectTimeout,
requests.exceptions.ReadTimeout,
),
on_backoff=lambda details: logger.info(
f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
),
Expand Down Expand Up @@ -295,7 +311,11 @@ def _make_handled_request(self) -> Any:
)
raise
except Exception as e:
raise Exception(f"Error while refreshing access token: {e}") from e
raise AirbyteTracedException(
message="OAuth access token refresh request failed.",
internal_message=f"Unexpected error during OAuth token refresh: {e}",
failure_type=FailureType.system_error,
) from e

def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) -> None:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1288,7 +1288,7 @@ def test_handle_read_external_requests(deployment_mode, url_base, expected_error
pytest.param(
"CLOUD",
"https://10.0.27.27/tokens/bearer",
"Error while refreshing access token",
"OAuth access token refresh request failed.",
id="test_cloud_read_with_private_endpoint",
),
pytest.param(
Expand Down
76 changes: 75 additions & 1 deletion unit_tests/sources/declarative/auth/test_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,18 @@
import logging
from copy import deepcopy
from datetime import timedelta, timezone
from unittest.mock import Mock
from unittest.mock import Mock, patch

import freezegun
import pytest
import requests
from requests import Response

from airbyte_cdk.models import FailureType
from airbyte_cdk.sources.declarative.auth import DeclarativeOauth2Authenticator
from airbyte_cdk.sources.declarative.auth.jwt import JwtAuthenticator
from airbyte_cdk.test.mock_http import HttpMocker, HttpRequest, HttpResponse
from airbyte_cdk.utils import AirbyteTracedException
from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets
from airbyte_cdk.utils.datetime_helpers import AirbyteDateTime, ab_datetime_now, ab_datetime_parse

Expand Down Expand Up @@ -645,3 +647,75 @@ def mock_request(method, url, data, headers):
raise Exception(
f"Error while refreshing access token with request: {method}, {url}, {data}, {headers}"
)


class TestOauth2AuthenticatorTransientErrorHandling:
"""Tests for transient network error handling during OAuth token refresh."""

def _create_authenticator(self):
return DeclarativeOauth2Authenticator(
token_refresh_endpoint="{{ config['refresh_endpoint'] }}",
client_id="{{ config['client_id'] }}",
client_secret="{{ config['client_secret'] }}",
refresh_token="{{ parameters['refresh_token'] }}",
config=config,
token_expiry_date="{{ config['token_expiry_date'] }}",
parameters=parameters,
)

@pytest.mark.parametrize(
"exception_class",
[
requests.exceptions.ConnectionError,
requests.exceptions.ConnectTimeout,
requests.exceptions.ReadTimeout,
],
ids=["ConnectionError", "ConnectTimeout", "ReadTimeout"],
)
def test_transient_network_error_wrapped_as_transient_error(self, exception_class):
"""Transient network errors during OAuth refresh are wrapped in AirbyteTracedException with transient_error."""
oauth = self._create_authenticator()
with patch.object(
oauth, "_make_handled_request", side_effect=exception_class("connection reset")
):
with pytest.raises(AirbyteTracedException) as exc_info:
oauth.refresh_access_token()

assert exc_info.value.failure_type == FailureType.transient_error
assert "network error" in exc_info.value.message.lower()

def test_connection_error_is_retried_before_raising(self, mocker):
"""ConnectionError triggers backoff retries in _make_handled_request before propagating."""
oauth = self._create_authenticator()

call_count = 0

def request_side_effect(**kwargs):
nonlocal call_count
call_count += 1
if call_count < 3:
raise requests.exceptions.ConnectionError("connection reset by peer")
mock_response = Mock(spec=requests.Response)
mock_response.ok = True
mock_response.json.return_value = {"access_token": "token_value", "expires_in": 3600}
return mock_response

mocker.patch("requests.request", side_effect=request_side_effect)
# Patch backoff to avoid actual delays in tests
mocker.patch("time.sleep")

token, _ = oauth.refresh_access_token()
assert token == "token_value"
assert call_count == 3

def test_generic_exception_wrapped_as_system_error(self, mocker):
"""Generic exceptions during OAuth refresh are wrapped in AirbyteTracedException with system_error."""
oauth = self._create_authenticator()
mocker.patch("requests.request", side_effect=ValueError("unexpected parsing error"))
mocker.patch("time.sleep")

with pytest.raises(AirbyteTracedException) as exc_info:
oauth.refresh_access_token()

assert exc_info.value.failure_type == FailureType.system_error
assert "OAuth access token refresh request failed" in exc_info.value.message
Loading