diff --git a/roborock/devices/mqtt_channel.py b/roborock/devices/mqtt_channel.py index 5bc4d9bc..8c840c57 100644 --- a/roborock/devices/mqtt_channel.py +++ b/roborock/devices/mqtt_channel.py @@ -82,6 +82,10 @@ async def publish(self, message: RoborockMessage) -> None: _LOGGER.exception("Error publishing MQTT message: %s", e) raise RoborockException(f"Failed to publish MQTT message: {e}") from e + async def restart(self) -> None: + """Restart the underlying MQTT session.""" + await self._mqtt_session.restart() + def create_mqtt_channel( user_data: UserData, mqtt_params: MqttParams, mqtt_session: MqttSession, device: HomeDataDevice diff --git a/roborock/devices/v1_rpc_channel.py b/roborock/devices/v1_rpc_channel.py index ffd46d75..270ed05a 100644 --- a/roborock/devices/v1_rpc_channel.py +++ b/roborock/devices/v1_rpc_channel.py @@ -13,6 +13,7 @@ from roborock.data import RoborockBase from roborock.exceptions import RoborockException +from roborock.mqtt.health_manager import HealthManager from roborock.protocols.v1_protocol import ( CommandType, MapResponse, @@ -125,12 +126,14 @@ def __init__( channel: MqttChannel | LocalChannel, payload_encoder: Callable[[RequestMessage], RoborockMessage], decoder: Callable[[RoborockMessage], ResponseMessage] | Callable[[RoborockMessage], MapResponse | None], + health_manager: HealthManager | None = None, ) -> None: """Initialize the channel with a raw channel and an encoder function.""" self._name = name self._channel = channel self._payload_encoder = payload_encoder self._decoder = decoder + self._health_manager = health_manager async def _send_raw_command( self, @@ -165,13 +168,19 @@ def find_response(response_message: RoborockMessage) -> None: unsub = await self._channel.subscribe(find_response) try: await self._channel.publish(message) - return await asyncio.wait_for(future, timeout=_TIMEOUT) + result = await asyncio.wait_for(future, timeout=_TIMEOUT) except TimeoutError as ex: + if self._health_manager: + await self._health_manager.on_timeout() future.cancel() raise RoborockException(f"Command timed out after {_TIMEOUT}s") from ex finally: unsub() + if self._health_manager: + await self._health_manager.on_success() + return result + def create_mqtt_rpc_channel(mqtt_channel: MqttChannel, security_data: SecurityData) -> V1RpcChannel: """Create a V1 RPC channel using an MQTT channel.""" @@ -180,6 +189,7 @@ def create_mqtt_rpc_channel(mqtt_channel: MqttChannel, security_data: SecurityDa mqtt_channel, lambda x: x.encode_message(RoborockMessageProtocol.RPC_REQUEST, security_data=security_data), decode_rpc_response, + health_manager=HealthManager(mqtt_channel.restart), ) diff --git a/roborock/mqtt/health_manager.py b/roborock/mqtt/health_manager.py new file mode 100644 index 00000000..d02cb4b4 --- /dev/null +++ b/roborock/mqtt/health_manager.py @@ -0,0 +1,51 @@ +"""A health manager for monitoring MQTT connections to Roborock devices. + +We observe a problem where sometimes the MQTT connection appears to be alive but +no messages are being received. To mitigate this, we track consecutive timeouts +and restart the connection if too many timeouts occur in succession. +""" + +import datetime +from collections.abc import Awaitable, Callable + +# Number of consecutive timeouts before considering the connection unhealthy. +TIMEOUT_THRESHOLD = 3 + +# We won't restart the session more often than this interval. +RESTART_COOLDOWN = datetime.timedelta(minutes=30) + + +class HealthManager: + """Manager for monitoring the health of MQTT connections. + + This tracks communication timeouts and can trigger restarts of the MQTT + session if too many timeouts occur in succession. + """ + + def __init__(self, restart: Callable[[], Awaitable[None]]) -> None: + """Initialize the health manager. + + Args: + restart: A callable to restart the MQTT session. + """ + self._consecutive_timeouts = 0 + self._restart = restart + self._last_restart: datetime.datetime | None = None + + async def on_success(self) -> None: + """Record a successful communication event.""" + self._consecutive_timeouts = 0 + + async def on_timeout(self) -> None: + """Record a timeout event. + + This may trigger a restart of the MQTT session if too many timeouts + have occurred in succession. + """ + self._consecutive_timeouts += 1 + if self._consecutive_timeouts >= TIMEOUT_THRESHOLD: + now = datetime.datetime.now(datetime.UTC) + if self._last_restart is None or now - self._last_restart >= RESTART_COOLDOWN: + await self._restart() + self._last_restart = now + self._consecutive_timeouts = 0 diff --git a/roborock/mqtt/roborock_session.py b/roborock/mqtt/roborock_session.py index 3d2c6917..53a337e7 100644 --- a/roborock/mqtt/roborock_session.py +++ b/roborock/mqtt/roborock_session.py @@ -49,13 +49,14 @@ class RoborockMqttSession(MqttSession): def __init__(self, params: MqttParams): self._params = params - self._background_task: asyncio.Task[None] | None = None + self._reconnect_task: asyncio.Task[None] | None = None self._healthy = False self._stop = False self._backoff = MIN_BACKOFF_INTERVAL self._client: aiomqtt.Client | None = None self._client_lock = asyncio.Lock() self._listeners: CallbackMap[str, bytes] = CallbackMap(_LOGGER) + self._connection_task: asyncio.Task[None] | None = None @property def connected(self) -> bool: @@ -72,7 +73,7 @@ async def start(self) -> None: """ start_future: asyncio.Future[None] = asyncio.Future() loop = asyncio.get_event_loop() - self._background_task = loop.create_task(self._run_task(start_future)) + self._reconnect_task = loop.create_task(self._run_reconnect_loop(start_future)) try: await start_future except MqttError as err: @@ -85,61 +86,47 @@ async def start(self) -> None: async def close(self) -> None: """Cancels the MQTT loop and shutdown the client library.""" self._stop = True - if self._background_task: - self._background_task.cancel() - try: - await self._background_task - except asyncio.CancelledError: - pass - async with self._client_lock: - if self._client: - await self._client.close() + tasks = [task for task in [self._connection_task, self._reconnect_task] if task] + for task in tasks: + task.cancel() + try: + await asyncio.gather(*tasks) + except asyncio.CancelledError: + pass self._healthy = False - async def _run_task(self, start_future: asyncio.Future[None] | None) -> None: + async def restart(self) -> None: + """Force the session to disconnect and reconnect. + + The active connection task will be cancelled and restarted in the background, retried by + the reconnect loop. This is a no-op if there is no active connection. + """ + _LOGGER.info("Forcing MQTT session restart") + if self._connection_task: + self._connection_task.cancel() + else: + _LOGGER.debug("No message loop task to cancel") + + async def _run_reconnect_loop(self, start_future: asyncio.Future[None] | None) -> None: """Run the MQTT loop.""" _LOGGER.info("Starting MQTT session") while True: try: - async with self._mqtt_client(self._params) as client: - # Reset backoff once we've successfully connected - self._backoff = MIN_BACKOFF_INTERVAL - self._healthy = True - _LOGGER.info("MQTT Session connected.") - if start_future: - start_future.set_result(None) - start_future = None - - await self._process_message_loop(client) - - except MqttError as err: - if start_future: - _LOGGER.info("MQTT error starting session: %s", err) - start_future.set_exception(err) - return - _LOGGER.info("MQTT error: %s", err) - except asyncio.CancelledError as err: - if start_future: - _LOGGER.debug("MQTT loop was cancelled while starting") - start_future.set_exception(err) - _LOGGER.debug("MQTT loop was cancelled") - return - # Catch exceptions to avoid crashing the loop - # and to allow the loop to retry. - except Exception as err: - # This error is thrown when the MQTT loop is cancelled - # and the generator is not stopped. - if "generator didn't stop" in str(err) or "generator didn't yield" in str(err): - _LOGGER.debug("MQTT loop was cancelled") - return - if start_future: - _LOGGER.error("Uncaught error starting MQTT session: %s", err) - start_future.set_exception(err) + self._connection_task = asyncio.create_task(self._run_connection(start_future)) + await self._connection_task + except asyncio.CancelledError: + _LOGGER.debug("MQTT connection task cancelled") + except Exception: + # Exceptions are logged and handled in _run_connection. + # There is a special case for exceptions on startup where we return + # immediately. Otherwise, we let the reconnect loop retry with + # backoff when the reconnect loop is active. + if start_future and start_future.done() and start_future.exception(): return - _LOGGER.exception("Uncaught error during MQTT session: %s", err) self._healthy = False + start_future = None if self._stop: _LOGGER.debug("MQTT session closed, stopping retry loop") return @@ -147,6 +134,45 @@ async def _run_task(self, start_future: asyncio.Future[None] | None) -> None: await asyncio.sleep(self._backoff.total_seconds()) self._backoff = min(self._backoff * BACKOFF_MULTIPLIER, MAX_BACKOFF_INTERVAL) + async def _run_connection(self, start_future: asyncio.Future[None] | None) -> None: + """Connect to the MQTT broker and listen for messages. + + This is the primary connection loop for the MQTT session that is + long running and processes incoming messages. If the connection + is lost, this method will exit. + """ + try: + async with self._mqtt_client(self._params) as client: + self._backoff = MIN_BACKOFF_INTERVAL + self._healthy = True + _LOGGER.info("MQTT Session connected.") + if start_future and not start_future.done(): + start_future.set_result(None) + + _LOGGER.debug("Processing MQTT messages") + async for message in client.messages: + _LOGGER.debug("Received message: %s", message) + self._listeners(message.topic.value, message.payload) + except MqttError as err: + if start_future and not start_future.done(): + _LOGGER.info("MQTT error starting session: %s", err) + start_future.set_exception(err) + else: + _LOGGER.info("MQTT error: %s", err) + raise + except Exception as err: + # This error is thrown when the MQTT loop is cancelled + # and the generator is not stopped. + if "generator didn't stop" in str(err) or "generator didn't yield" in str(err): + _LOGGER.debug("MQTT loop was cancelled") + return + if start_future and not start_future.done(): + _LOGGER.error("Uncaught error starting MQTT session: %s", err) + start_future.set_exception(err) + else: + _LOGGER.exception("Uncaught error during MQTT session: %s", err) + raise + @asynccontextmanager async def _mqtt_client(self, params: MqttParams) -> aiomqtt.Client: """Connect to the MQTT broker and listen for messages.""" @@ -178,12 +204,6 @@ async def _mqtt_client(self, params: MqttParams) -> aiomqtt.Client: async with self._client_lock: self._client = None - async def _process_message_loop(self, client: aiomqtt.Client) -> None: - _LOGGER.debug("Processing MQTT messages") - async for message in client.messages: - _LOGGER.debug("Received message: %s", message) - self._listeners(message.topic.value, message.payload) - async def subscribe(self, topic: str, callback: Callable[[bytes], None]) -> Callable[[], None]: """Subscribe to messages on the specified topic and invoke the callback for new messages. @@ -271,6 +291,10 @@ async def close(self) -> None: """ await self._session.close() + async def restart(self) -> None: + """Force the session to disconnect and reconnect.""" + await self._session.restart() + async def create_mqtt_session(params: MqttParams) -> MqttSession: """Create an MQTT session. diff --git a/roborock/mqtt/session.py b/roborock/mqtt/session.py index c72e3294..f5922d23 100644 --- a/roborock/mqtt/session.py +++ b/roborock/mqtt/session.py @@ -54,6 +54,10 @@ async def publish(self, topic: str, message: bytes) -> None: This will raise an exception if the message could not be sent. """ + @abstractmethod + async def restart(self) -> None: + """Force the session to disconnect and reconnect.""" + @abstractmethod async def close(self) -> None: """Cancels the mqtt loop""" diff --git a/tests/conftest.py b/tests/conftest.py index 97792686..6264f2b1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -363,6 +363,7 @@ def __init__(self): self.connect = AsyncMock(side_effect=self._connect) self.close = MagicMock(side_effect=self._close) self.protocol_version = LocalProtocolVersion.V1 + self.restart = AsyncMock() async def _connect(self) -> None: self._is_connected = True diff --git a/tests/mqtt/test_health_manager.py b/tests/mqtt/test_health_manager.py new file mode 100644 index 00000000..94b253b4 --- /dev/null +++ b/tests/mqtt/test_health_manager.py @@ -0,0 +1,73 @@ +"""Tests for the health manager.""" + +import datetime +from unittest.mock import AsyncMock, patch + +from roborock.mqtt.health_manager import HealthManager + + +async def test_health_manager_restart_called_after_timeouts() -> None: + """Test that the health manager calls restart after consecutive timeouts.""" + restart = AsyncMock() + health_manager = HealthManager(restart=restart) + + await health_manager.on_timeout() + await health_manager.on_timeout() + restart.assert_not_called() + + await health_manager.on_timeout() + restart.assert_called_once() + + +async def test_health_manager_success_resets_counter() -> None: + """Test that a successful message resets the timeout counter.""" + restart = AsyncMock() + health_manager = HealthManager(restart=restart) + + await health_manager.on_timeout() + await health_manager.on_timeout() + restart.assert_not_called() + + await health_manager.on_success() + + await health_manager.on_timeout() + await health_manager.on_timeout() + restart.assert_not_called() + + await health_manager.on_timeout() + restart.assert_called_once() + + +async def test_cooldown() -> None: + """Test that the health manager respects the restart cooldown.""" + restart = AsyncMock() + health_manager = HealthManager(restart=restart) + + with patch("roborock.mqtt.health_manager.datetime") as mock_datetime: + now = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = now + + # Trigger first restart + await health_manager.on_timeout() + await health_manager.on_timeout() + await health_manager.on_timeout() + restart.assert_called_once() + restart.reset_mock() + + # Advance time but stay within cooldown (30 mins) + mock_datetime.datetime.now.return_value = now + datetime.timedelta(minutes=10) + + # Trigger timeouts again + await health_manager.on_timeout() + await health_manager.on_timeout() + await health_manager.on_timeout() + restart.assert_not_called() + + # Advance time past cooldown + mock_datetime.datetime.now.return_value = now + datetime.timedelta(minutes=31) + + # Trigger timeouts again + await health_manager.on_timeout() + await health_manager.on_timeout() + await health_manager.on_timeout() + restart.assert_called_once() diff --git a/tests/mqtt/test_roborock_session.py b/tests/mqtt/test_roborock_session.py index 0cca44b0..bb3f1bdb 100644 --- a/tests/mqtt/test_roborock_session.py +++ b/tests/mqtt/test_roborock_session.py @@ -1,7 +1,8 @@ """Tests for the MQTT session module.""" import asyncio -from collections.abc import AsyncGenerator, Callable +import datetime +from collections.abc import AsyncGenerator, Callable, Generator from queue import Queue from typing import Any from unittest.mock import AsyncMock, Mock, patch @@ -72,6 +73,13 @@ def new_client(*args: Any, **kwargs: Any) -> mqtt.Client: task.cancel() +@pytest.fixture(autouse=True) +def fast_backoff_fixture() -> Generator[None, None, None]: + """Fixture to make backoff intervals fast.""" + with patch("roborock.mqtt.roborock_session.MIN_BACKOFF_INTERVAL", datetime.timedelta(seconds=0.01)): + yield + + @pytest.fixture def push_response(response_queue: Queue, fake_socket_handler: FakeSocketHandler) -> Callable[[bytes], None]: """Fixtures to push messages.""" @@ -205,6 +213,8 @@ async def test_publish_failure() -> None: with pytest.raises(MqttSessionException, match="Error publishing message"): await session.publish("topic-1", message=b"payload") + await session.close() + async def test_subscribe_failure() -> None: """Test an MQTT error while subscribing.""" @@ -231,3 +241,41 @@ async def test_subscribe_failure() -> None: assert not subscriber1.messages await session.close() + + +async def test_restart(push_response: Callable[[bytes], None]) -> None: + """Test restarting the MQTT session.""" + + push_response(mqtt_packet.gen_connack(rc=0, flags=2)) + session = await create_mqtt_session(FAKE_PARAMS) + assert session.connected + + # Subscribe to a topic + push_response(mqtt_packet.gen_suback(mid=1)) + subscriber = Subscriber() + await session.subscribe("topic-1", subscriber.append) + + # Verify we can receive messages + push_response(mqtt_packet.gen_publish("topic-1", mid=2, payload=b"12345")) + await subscriber.wait() + assert subscriber.messages == [b"12345"] + + # Restart the session. + await session.restart() + # This is a hack where we grab on to the client and wait for it to be + # closed properly and restarted. + while session._client: # type: ignore[attr-defined] + await asyncio.sleep(0.01) + + # We need to queue up a new connack for the reconnection + push_response(mqtt_packet.gen_connack(rc=0, flags=2)) + + # And a suback for the resubscription. Since we created a new client, + # the message ID resets to 1. + push_response(mqtt_packet.gen_suback(mid=1)) + + push_response(mqtt_packet.gen_publish("topic-1", mid=4, payload=b"67890")) + await subscriber.wait() + assert subscriber.messages == [b"12345", b"67890"] + + await session.close()