diff --git a/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/resolvers/grpc.py b/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/resolvers/grpc.py index 01724061..117ea3a0 100644 --- a/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/resolvers/grpc.py +++ b/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/resolvers/grpc.py @@ -65,10 +65,12 @@ def __init__( logger.debug(self.config.fatal_status_codes) self.retry_grace_period = config.retry_grace_period + self.retry_backoff_max_seconds = config.retry_backoff_max_ms * 0.001 self.streamline_deadline_seconds = config.stream_deadline_ms * 0.001 self.deadline = config.deadline_ms * 0.001 self.connected = False self._is_fatal = False + self._shutdown_event = threading.Event() self.channel = self._generate_channel(config) self.stub = evaluation_pb2_grpc.ServiceStub(self.channel) @@ -135,6 +137,7 @@ def initialize(self, evaluation_context: EvaluationContext) -> None: def shutdown(self) -> None: self.active = False + self._shutdown_event.set() self.channel.unsubscribe(self._state_change_callback) self.channel.close() if self.timer and self.timer.is_alive(): @@ -145,6 +148,7 @@ def shutdown(self) -> None: def connect(self) -> None: self.active = True + self._shutdown_event.clear() # Run monitoring in a separate thread self.monitor_thread = threading.Thread( @@ -215,6 +219,37 @@ def emit_error(self) -> None: ) ) + def _wait_before_reconnect(self) -> None: + self._shutdown_event.wait(self.retry_backoff_max_seconds) + + def _handle_rpc_error(self, e: grpc.RpcError) -> bool: + # although it seems like this error log is not interesting, without it, the retry is not working as expected + logger.debug(f"SyncFlags stream error, {e.code()=} {e.details()=}") + if e.code().name in self.config.fatal_status_codes: + self._is_fatal = True + self.active = False + self.emit_provider_error( + ProviderEventDetails( + message=f"Fatal gRPC status code: {e.code()}", + error_code=ErrorCode.PROVIDER_FATAL, + ) + ) + return True + return False + + def _handle_event_stream_message( + self, message: evaluation_pb2.EventStreamResponse + ) -> None: + if message.type == "provider_ready": + self.emit_provider_ready( + ProviderEventDetails(message="gRPC sync connection established") + ) + self.connected = True + elif message.type == "configuration_change": + msg_dict = MessageToDict(message) + data = msg_dict.get("data", {}) + self.handle_changed_flags(data) + def listen(self) -> None: logger.debug("gRPC starting listener thread") call_args: GrpcMultiCallableArgs = {"wait_for_ready": True} @@ -227,38 +262,20 @@ def listen(self) -> None: try: logger.debug("Setting up gRPC sync flags connection") for message in self.stub.EventStream(request, **call_args): - if message.type == "provider_ready": - self.emit_provider_ready( - ProviderEventDetails( - message="gRPC sync connection established" - ) - ) - self.connected = True - elif message.type == "configuration_change": - msg_dict = MessageToDict(message) - data = msg_dict.get("data", {}) - self.handle_changed_flags(data) + self._handle_event_stream_message(message) if not self.active: logger.info("Terminating gRPC sync thread") return - except grpc.RpcError as e: # noqa: PERF203 - # although it seems like this error log is not interesting, without it, the retry is not working as expected - logger.debug(f"SyncFlags stream error, {e.code()=} {e.details()=}") - if e.code().name in self.config.fatal_status_codes: - self._is_fatal = True - self.active = False - self.emit_provider_error( - ProviderEventDetails( - message=f"Fatal gRPC status code: {e.code()}", - error_code=ErrorCode.PROVIDER_FATAL, - ) - ) + except grpc.RpcError as e: + if self._handle_rpc_error(e): return except ParseError: logger.exception( f"Could not parse flag data using flagd syntax: {message=}" ) + if self.active: + self._wait_before_reconnect() def handle_changed_flags(self, data: typing.Any) -> None: changed_flags = list(data.get("flags", {}).keys()) diff --git a/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/resolvers/process/connector/grpc_watcher.py b/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/resolvers/process/connector/grpc_watcher.py index e625ae33..f4cc337c 100644 --- a/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/resolvers/process/connector/grpc_watcher.py +++ b/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/resolvers/process/connector/grpc_watcher.py @@ -44,7 +44,7 @@ def __init__( self.channel = self._generate_channel(config) self.stub = sync_pb2_grpc.FlagSyncServiceStub(self.channel) self.retry_backoff_seconds = config.retry_backoff_ms * 0.001 - self.retry_backoff_max_seconds = config.retry_backoff_ms * 0.001 + self.retry_backoff_max_seconds = config.retry_backoff_max_ms * 0.001 self.retry_grace_period = config.retry_grace_period self.streamline_deadline_seconds = config.stream_deadline_ms * 0.001 self.deadline = config.deadline_ms * 0.001 @@ -56,6 +56,7 @@ def __init__( self.connected = False self._is_fatal = False + self._shutdown_event = threading.Event() self.thread: threading.Thread | None = None self.timer: threading.Timer | None = None @@ -129,6 +130,7 @@ def initialize(self, context: EvaluationContext) -> None: def connect(self) -> None: self.active = True + self._shutdown_event.clear() # Run monitoring in a separate thread self.monitor_thread = threading.Thread( @@ -199,6 +201,7 @@ def emit_error(self) -> None: def shutdown(self) -> None: self.active = False + self._shutdown_event.set() self.channel.close() def _create_request_args(self) -> dict: @@ -283,6 +286,9 @@ def _handle_rpc_error(self, e: grpc.RpcError) -> bool: return True return False + def _wait_before_reconnect(self) -> None: + self._shutdown_event.wait(self.retry_backoff_max_seconds) + def listen(self) -> None: call_args = self.generate_grpc_call_args() request_args = self._create_request_args() @@ -295,7 +301,7 @@ def listen(self) -> None: for flag_rsp in self.stub.SyncFlags(request, **call_args): if self._handle_flag_response(flag_rsp, context_values_response): return - except grpc.RpcError as e: # noqa: PERF203 + except grpc.RpcError as e: if self._handle_rpc_error(e): return except json.JSONDecodeError: @@ -304,6 +310,8 @@ def listen(self) -> None: ) except ParseError: logger.exception("Could not parse flag data using flagd syntax") + if self.active: + self._wait_before_reconnect() def generate_grpc_call_args(self) -> GrpcMultiCallableArgs: call_args: GrpcMultiCallableArgs = {"wait_for_ready": True} diff --git a/providers/openfeature-provider-flagd/tests/test_grpc_resolver.py b/providers/openfeature-provider-flagd/tests/test_grpc_resolver.py new file mode 100644 index 00000000..11356d48 --- /dev/null +++ b/providers/openfeature-provider-flagd/tests/test_grpc_resolver.py @@ -0,0 +1,74 @@ +import unittest +from unittest.mock import MagicMock, Mock, patch + +import grpc +from grpc import Channel + +from openfeature.contrib.provider.flagd.config import CacheType, Config +from openfeature.contrib.provider.flagd.resolvers.grpc import GrpcResolver + + +class FakeRpcError(grpc.RpcError): + def code(self): + return grpc.StatusCode.UNAVAILABLE + + def details(self): + return "stream unavailable" + + +class TestGrpcResolver(unittest.TestCase): + def setUp(self): + config = Config( + cache=CacheType.DISABLED, + deadline_ms=100, + retry_backoff_ms=1000, + retry_backoff_max_ms=5000, + stream_deadline_ms=1000, + ) + channel = Mock(spec=Channel) + + with patch( + "openfeature.contrib.provider.flagd.resolvers.grpc.GrpcResolver._generate_channel", + return_value=channel, + ): + self.grpc_resolver = GrpcResolver( + config=config, + emit_provider_ready=Mock(), + emit_provider_error=Mock(), + emit_provider_stale=Mock(), + emit_provider_configuration_changed=Mock(), + ) + + self.grpc_resolver.stub = MagicMock() + self.grpc_resolver.active = True + + def test_uses_max_retry_backoff_for_application_level_reconnect_delay(self): + self.assertEqual(self.grpc_resolver.retry_backoff_max_seconds, 5) + + def test_listen_backs_off_after_rpc_stream_error(self): + self.grpc_resolver.stub.EventStream = Mock(side_effect=FakeRpcError()) + + with patch.object( + self.grpc_resolver, + "_wait_before_reconnect", + side_effect=lambda: setattr(self.grpc_resolver, "active", False), + ) as wait_before_reconnect: + self.grpc_resolver.listen() + + wait_before_reconnect.assert_called_once() + + def test_listen_backs_off_after_stream_completion(self): + self.grpc_resolver.stub.EventStream = Mock(return_value=iter([])) + + with patch.object( + self.grpc_resolver, + "_wait_before_reconnect", + side_effect=lambda: setattr(self.grpc_resolver, "active", False), + ) as wait_before_reconnect: + self.grpc_resolver.listen() + + wait_before_reconnect.assert_called_once() + + +if __name__ == "__main__": + unittest.main() diff --git a/providers/openfeature-provider-flagd/tests/test_grpc_watcher.py b/providers/openfeature-provider-flagd/tests/test_grpc_watcher.py index 395dd6a2..a8944f76 100644 --- a/providers/openfeature-provider-flagd/tests/test_grpc_watcher.py +++ b/providers/openfeature-provider-flagd/tests/test_grpc_watcher.py @@ -3,6 +3,7 @@ import unittest from unittest.mock import MagicMock, Mock, patch +import grpc from google.protobuf.json_format import MessageToDict from google.protobuf.struct_pb2 import Struct from grpc import Channel @@ -20,6 +21,14 @@ from openfeature.schemas.protobuf.flagd.sync.v1.sync_pb2_grpc import FlagSyncServiceStub +class FakeRpcError(grpc.RpcError): + def code(self): + return grpc.StatusCode.UNAVAILABLE + + def details(self): + return "stream unavailable" + + class TestGrpcWatcher(unittest.TestCase): def setUp(self): config = Mock(spec=Config) @@ -36,6 +45,7 @@ def setUp(self): config.host = "localhost" config.port = 5000 config.sync_metadata_disabled = False + config.fatal_status_codes = [] flag_store = Mock(spec=FlagStore) flag_store.update.return_value = None @@ -133,6 +143,33 @@ def test_listen_with_sync_metadata_disabled_in_config(self): ) self.assertEqual(self.context, {}) + def test_uses_max_retry_backoff_for_application_level_reconnect_delay(self): + self.assertEqual(self.grpc_watcher.retry_backoff_max_seconds, 5) + + def test_listen_backs_off_after_rpc_stream_error(self): + self.mock_stub.SyncFlags = Mock(side_effect=FakeRpcError()) + + with patch.object( + self.grpc_watcher, + "_wait_before_reconnect", + side_effect=lambda: setattr(self.grpc_watcher, "active", False), + ) as wait_before_reconnect: + self.grpc_watcher.listen() + + wait_before_reconnect.assert_called_once() + + def test_listen_backs_off_after_stream_completion(self): + self.mock_stub.SyncFlags = Mock(return_value=iter([])) + + with patch.object( + self.grpc_watcher, + "_wait_before_reconnect", + side_effect=lambda: setattr(self.grpc_watcher, "active", False), + ) as wait_before_reconnect: + self.grpc_watcher.listen() + + wait_before_reconnect.assert_called_once() + def test_selector_passed_via_both_metadata_and_body(self): """Test that selector is passed via both gRPC metadata header and request body for backward compatibility""" self.grpc_watcher.selector = "test-selector"