Merged
4 changes: 4 additions & 0 deletions roborock/devices/mqtt_channel.py
@@ -82,6 +82,10 @@ async def publish(self, message: RoborockMessage) -> None:
_LOGGER.exception("Error publishing MQTT message: %s", e)
raise RoborockException(f"Failed to publish MQTT message: {e}") from e

async def restart(self) -> None:
"""Restart the underlying MQTT session."""
await self._mqtt_session.restart()


def create_mqtt_channel(
user_data: UserData, mqtt_params: MqttParams, mqtt_session: MqttSession, device: HomeDataDevice
12 changes: 11 additions & 1 deletion roborock/devices/v1_rpc_channel.py
@@ -13,6 +13,7 @@

from roborock.data import RoborockBase
from roborock.exceptions import RoborockException
from roborock.mqtt.health_manager import HealthManager
from roborock.protocols.v1_protocol import (
CommandType,
MapResponse,
@@ -125,12 +126,14 @@ def __init__(
channel: MqttChannel | LocalChannel,
payload_encoder: Callable[[RequestMessage], RoborockMessage],
decoder: Callable[[RoborockMessage], ResponseMessage] | Callable[[RoborockMessage], MapResponse | None],
health_manager: HealthManager | None = None,
) -> None:
"""Initialize the channel with a raw channel and an encoder function."""
self._name = name
self._channel = channel
self._payload_encoder = payload_encoder
self._decoder = decoder
self._health_manager = health_manager

async def _send_raw_command(
self,
@@ -165,13 +168,19 @@ def find_response(response_message: RoborockMessage) -> None:
unsub = await self._channel.subscribe(find_response)
try:
await self._channel.publish(message)
return await asyncio.wait_for(future, timeout=_TIMEOUT)
result = await asyncio.wait_for(future, timeout=_TIMEOUT)
except TimeoutError as ex:
if self._health_manager:
await self._health_manager.on_timeout()
future.cancel()
raise RoborockException(f"Command timed out after {_TIMEOUT}s") from ex
finally:
unsub()

if self._health_manager:
await self._health_manager.on_success()
return result


def create_mqtt_rpc_channel(mqtt_channel: MqttChannel, security_data: SecurityData) -> V1RpcChannel:
"""Create a V1 RPC channel using an MQTT channel."""
@@ -180,6 +189,7 @@ def create_mqtt_rpc_channel(mqtt_channel: MqttChannel, security_data: SecurityDa
mqtt_channel,
lambda x: x.encode_message(RoborockMessageProtocol.RPC_REQUEST, security_data=security_data),
decode_rpc_response,
health_manager=HealthManager(mqtt_channel.restart),
)


51 changes: 51 additions & 0 deletions roborock/mqtt/health_manager.py
@@ -0,0 +1,51 @@
"""A health manager for monitoring MQTT connections to Roborock devices.
We observe a problem where sometimes the MQTT connection appears to be alive but
no messages are being received. To mitigate this, we track consecutive timeouts
and restart the connection if too many timeouts occur in succession.
"""

import datetime
from collections.abc import Awaitable, Callable

# Number of consecutive timeouts before considering the connection unhealthy.
TIMEOUT_THRESHOLD = 3
Collaborator:
This would technically trigger immediately with gathered function calls. But anything more than this is probably too much complexity.
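
For illustration, a rough sketch of that concern; the helper and command names below are hypothetical placeholders, not this library's API. Several commands issued concurrently with asyncio.gather against a dead connection each time out, so the consecutive-timeout counter can reach the threshold after a single batch.

import asyncio

# Hypothetical sketch of the concern above: send_with_health and the command
# names are placeholders, not the real API surface of this PR.
async def send_with_health(rpc_channel, health_manager, method: str) -> None:
    try:
        await rpc_channel.send_command(method)  # times out on a dead connection
    except Exception:
        await health_manager.on_timeout()
    else:
        await health_manager.on_success()

async def refresh_all(rpc_channel, health_manager) -> None:
    # Three concurrent timeouts record on_timeout() three times for what is
    # effectively a single failure window, reaching TIMEOUT_THRESHOLD at once.
    await asyncio.gather(
        send_with_health(rpc_channel, health_manager, "get_status"),
        send_with_health(rpc_channel, health_manager, "get_clean_summary"),
        send_with_health(rpc_channel, health_manager, "get_consumable"),
    )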

Collaborator:

A potential follow-up could be to keep track of the last timeout: if the timeout is more than, say, 15 seconds ago, it increases the increment. That could be a follow-up PR though; I don't want to slow this one down as we are on a time crunch.

Contributor Author:

I think we're currently sending most commands serially, but yeah, this kind of heuristic is hard. I was also considering whether we could do it entirely in the MQTT session, but that's hard when you can't correlate incoming and outgoing messages to know whether something really did time out.

Not sure I get the timeout point you're making, but I'm interested in following up.
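
One possible reading of that follow-up, sketched here only for discussion and not part of this PR: count a timeout toward the threshold only if it arrived within a short window of the previous one. The 15-second window is an assumption taken from the comment above.

import datetime

# Sketch only: a windowed counter where stale timeouts reset the streak.
TIMEOUT_THRESHOLD = 3
TIMEOUT_WINDOW = datetime.timedelta(seconds=15)

class WindowedTimeoutCounter:
    def __init__(self) -> None:
        self._count = 0
        self._last_timeout: datetime.datetime | None = None

    def record_timeout(self) -> bool:
        """Return True when the threshold is reached within the window."""
        now = datetime.datetime.now(datetime.UTC)
        if self._last_timeout is not None and now - self._last_timeout > TIMEOUT_WINDOW:
            # The previous timeout is too old to be part of the same streak.
            self._count = 0
        self._last_timeout = now
        self._count += 1
        return self._count >= TIMEOUT_THRESHOLD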


# We won't restart the session more often than this interval.
RESTART_COOLDOWN = datetime.timedelta(minutes=30)


class HealthManager:
"""Manager for monitoring the health of MQTT connections.
This tracks communication timeouts and can trigger restarts of the MQTT
session if too many timeouts occur in succession.
"""

def __init__(self, restart: Callable[[], Awaitable[None]]) -> None:
"""Initialize the health manager.
Args:
restart: A callable to restart the MQTT session.
"""
self._consecutive_timeouts = 0
self._restart = restart
self._last_restart: datetime.datetime | None = None

async def on_success(self) -> None:
"""Record a successful communication event."""
self._consecutive_timeouts = 0

async def on_timeout(self) -> None:
"""Record a timeout event.
This may trigger a restart of the MQTT session if too many timeouts
have occurred in succession.
"""
self._consecutive_timeouts += 1
if self._consecutive_timeouts >= TIMEOUT_THRESHOLD:
now = datetime.datetime.now(datetime.UTC)
if self._last_restart is None or now - self._last_restart >= RESTART_COOLDOWN:
await self._restart()
self._last_restart = now
self._consecutive_timeouts = 0
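
For reference, a minimal usage sketch of the new class as wired above in v1_rpc_channel.py; the restart coroutine below is a stand-in for MqttChannel.restart, not the real channel.

import asyncio

from roborock.mqtt.health_manager import HealthManager

async def main() -> None:
    async def restart() -> None:
        # Stand-in for MqttChannel.restart.
        print("restarting MQTT session")

    manager = HealthManager(restart)
    # Three consecutive timeouts reach TIMEOUT_THRESHOLD and trigger a restart.
    await manager.on_timeout()
    await manager.on_timeout()
    await manager.on_timeout()
    # A later success resets the consecutive-timeout counter.
    await manager.on_success()

asyncio.run(main())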
130 changes: 77 additions & 53 deletions roborock/mqtt/roborock_session.py
@@ -49,13 +49,14 @@ class RoborockMqttSession(MqttSession):

def __init__(self, params: MqttParams):
self._params = params
self._background_task: asyncio.Task[None] | None = None
self._reconnect_task: asyncio.Task[None] | None = None
self._healthy = False
self._stop = False
self._backoff = MIN_BACKOFF_INTERVAL
self._client: aiomqtt.Client | None = None
self._client_lock = asyncio.Lock()
self._listeners: CallbackMap[str, bytes] = CallbackMap(_LOGGER)
self._connection_task: asyncio.Task[None] | None = None

@property
def connected(self) -> bool:
@@ -72,7 +73,7 @@ async def start(self) -> None:
"""
start_future: asyncio.Future[None] = asyncio.Future()
loop = asyncio.get_event_loop()
self._background_task = loop.create_task(self._run_task(start_future))
self._reconnect_task = loop.create_task(self._run_reconnect_loop(start_future))
try:
await start_future
except MqttError as err:
@@ -85,68 +86,93 @@ async def close(self) -> None:
async def close(self) -> None:
"""Cancels the MQTT loop and shutdown the client library."""
self._stop = True
if self._background_task:
self._background_task.cancel()
try:
await self._background_task
except asyncio.CancelledError:
pass
async with self._client_lock:
if self._client:
await self._client.close()
tasks = [task for task in [self._connection_task, self._reconnect_task] if task]
for task in tasks:
task.cancel()
try:
await asyncio.gather(*tasks)
except asyncio.CancelledError:
pass

self._healthy = False

async def _run_task(self, start_future: asyncio.Future[None] | None) -> None:
async def restart(self) -> None:
"""Force the session to disconnect and reconnect.

The active connection task will be cancelled and restarted in the background, retried by
the reconnect loop. This is a no-op if there is no active connection.
"""
_LOGGER.info("Forcing MQTT session restart")
if self._connection_task:
self._connection_task.cancel()
else:
_LOGGER.debug("No message loop task to cancel")

async def _run_reconnect_loop(self, start_future: asyncio.Future[None] | None) -> None:
"""Run the MQTT loop."""
_LOGGER.info("Starting MQTT session")
while True:
try:
async with self._mqtt_client(self._params) as client:
# Reset backoff once we've successfully connected
self._backoff = MIN_BACKOFF_INTERVAL
self._healthy = True
_LOGGER.info("MQTT Session connected.")
if start_future:
start_future.set_result(None)
start_future = None

await self._process_message_loop(client)

except MqttError as err:
if start_future:
_LOGGER.info("MQTT error starting session: %s", err)
start_future.set_exception(err)
return
_LOGGER.info("MQTT error: %s", err)
except asyncio.CancelledError as err:
if start_future:
_LOGGER.debug("MQTT loop was cancelled while starting")
start_future.set_exception(err)
_LOGGER.debug("MQTT loop was cancelled")
return
# Catch exceptions to avoid crashing the loop
# and to allow the loop to retry.
except Exception as err:
# This error is thrown when the MQTT loop is cancelled
# and the generator is not stopped.
if "generator didn't stop" in str(err) or "generator didn't yield" in str(err):
_LOGGER.debug("MQTT loop was cancelled")
return
if start_future:
_LOGGER.error("Uncaught error starting MQTT session: %s", err)
start_future.set_exception(err)
self._connection_task = asyncio.create_task(self._run_connection(start_future))
await self._connection_task
except asyncio.CancelledError:
_LOGGER.debug("MQTT connection task cancelled")
except Exception:
# Exceptions are logged and handled in _run_connection.
# There is a special case for exceptions on startup where we return
# immediately. Otherwise, we let the reconnect loop retry with
# backoff when the reconnect loop is active.
if start_future and start_future.done() and start_future.exception():
return
_LOGGER.exception("Uncaught error during MQTT session: %s", err)

self._healthy = False
start_future = None
if self._stop:
_LOGGER.debug("MQTT session closed, stopping retry loop")
return
_LOGGER.info("MQTT session disconnected, retrying in %s seconds", self._backoff.total_seconds())
await asyncio.sleep(self._backoff.total_seconds())
self._backoff = min(self._backoff * BACKOFF_MULTIPLIER, MAX_BACKOFF_INTERVAL)

async def _run_connection(self, start_future: asyncio.Future[None] | None) -> None:
"""Connect to the MQTT broker and listen for messages.

This is the primary connection loop for the MQTT session that is
long running and processes incoming messages. If the connection
is lost, this method will exit.
"""
try:
async with self._mqtt_client(self._params) as client:
self._backoff = MIN_BACKOFF_INTERVAL
self._healthy = True
_LOGGER.info("MQTT Session connected.")
if start_future and not start_future.done():
start_future.set_result(None)

_LOGGER.debug("Processing MQTT messages")
async for message in client.messages:
_LOGGER.debug("Received message: %s", message)
self._listeners(message.topic.value, message.payload)
except MqttError as err:
if start_future and not start_future.done():
_LOGGER.info("MQTT error starting session: %s", err)
start_future.set_exception(err)
else:
_LOGGER.info("MQTT error: %s", err)
raise
except Exception as err:
# This error is thrown when the MQTT loop is cancelled
# and the generator is not stopped.
if "generator didn't stop" in str(err) or "generator didn't yield" in str(err):
_LOGGER.debug("MQTT loop was cancelled")
return
if start_future and not start_future.done():
_LOGGER.error("Uncaught error starting MQTT session: %s", err)
start_future.set_exception(err)
else:
_LOGGER.exception("Uncaught error during MQTT session: %s", err)
raise

@asynccontextmanager
async def _mqtt_client(self, params: MqttParams) -> aiomqtt.Client:
"""Connect to the MQTT broker and listen for messages."""
@@ -178,12 +204,6 @@ async def _mqtt_client(self, params: MqttParams) -> aiomqtt.Client:
async with self._client_lock:
self._client = None

async def _process_message_loop(self, client: aiomqtt.Client) -> None:
_LOGGER.debug("Processing MQTT messages")
async for message in client.messages:
_LOGGER.debug("Received message: %s", message)
self._listeners(message.topic.value, message.payload)

async def subscribe(self, topic: str, callback: Callable[[bytes], None]) -> Callable[[], None]:
"""Subscribe to messages on the specified topic and invoke the callback for new messages.

@@ -271,6 +291,10 @@ async def close(self) -> None:
"""
await self._session.close()

async def restart(self) -> None:
"""Force the session to disconnect and reconnect."""
await self._session.restart()


async def create_mqtt_session(params: MqttParams) -> MqttSession:
"""Create an MQTT session.
4 changes: 4 additions & 0 deletions roborock/mqtt/session.py
@@ -54,6 +54,10 @@ async def publish(self, topic: str, message: bytes) -> None:
This will raise an exception if the message could not be sent.
"""

@abstractmethod
async def restart(self) -> None:
"""Force the session to disconnect and reconnect."""

@abstractmethod
async def close(self) -> None:
"""Cancels the mqtt loop"""
1 change: 1 addition & 0 deletions tests/conftest.py
@@ -363,6 +363,7 @@ def __init__(self):
self.connect = AsyncMock(side_effect=self._connect)
self.close = MagicMock(side_effect=self._close)
self.protocol_version = LocalProtocolVersion.V1
self.restart = AsyncMock()

async def _connect(self) -> None:
self._is_connected = True
73 changes: 73 additions & 0 deletions tests/mqtt/test_health_manager.py
@@ -0,0 +1,73 @@
"""Tests for the health manager."""

import datetime
from unittest.mock import AsyncMock, patch

from roborock.mqtt.health_manager import HealthManager


async def test_health_manager_restart_called_after_timeouts() -> None:
"""Test that the health manager calls restart after consecutive timeouts."""
restart = AsyncMock()
health_manager = HealthManager(restart=restart)

await health_manager.on_timeout()
await health_manager.on_timeout()
restart.assert_not_called()

await health_manager.on_timeout()
restart.assert_called_once()


async def test_health_manager_success_resets_counter() -> None:
"""Test that a successful message resets the timeout counter."""
restart = AsyncMock()
health_manager = HealthManager(restart=restart)

await health_manager.on_timeout()
await health_manager.on_timeout()
restart.assert_not_called()

await health_manager.on_success()

await health_manager.on_timeout()
await health_manager.on_timeout()
restart.assert_not_called()

await health_manager.on_timeout()
restart.assert_called_once()


async def test_cooldown() -> None:
"""Test that the health manager respects the restart cooldown."""
restart = AsyncMock()
health_manager = HealthManager(restart=restart)

with patch("roborock.mqtt.health_manager.datetime") as mock_datetime:
now = datetime.datetime(2023, 1, 1, 12, 0, 0)
mock_datetime.datetime.now.return_value = now

# Trigger first restart
await health_manager.on_timeout()
await health_manager.on_timeout()
await health_manager.on_timeout()
restart.assert_called_once()
restart.reset_mock()

# Advance time but stay within cooldown (30 mins)
mock_datetime.datetime.now.return_value = now + datetime.timedelta(minutes=10)

# Trigger timeouts again
await health_manager.on_timeout()
await health_manager.on_timeout()
await health_manager.on_timeout()
restart.assert_not_called()

# Advance time past cooldown
mock_datetime.datetime.now.return_value = now + datetime.timedelta(minutes=31)

# Trigger timeouts again
await health_manager.on_timeout()
await health_manager.on_timeout()
await health_manager.on_timeout()
Comment on lines +66 to +72
Copilot AI Nov 29, 2025:

[nitpick] After the restart is triggered and the cooldown timer is set (lines 51-55), the consecutive timeout counter is reset to 0. When three more timeouts occur (lines 61-63), the counter reaches the threshold of 3 again, but restart isn't called due to the cooldown check.

However, the consecutive timeout counter remains at 3 after line 64. When time advances past the cooldown (line 67) and the test triggers three more timeouts (lines 70-72), the counter would go from 3 to 6, not from 0 to 3.

This works in the current test because the implementation still triggers a restart when _consecutive_timeouts >= TIMEOUT_THRESHOLD, but it's testing a slightly different scenario than intended. Consider adding an assertion after line 64 or resetting expectations to make the test behavior clearer:

# After cooldown period, counter is still at 3, so even one timeout would trigger restart
# Advance time past cooldown
mock_datetime.datetime.now.return_value = now + datetime.timedelta(minutes=31)

await health_manager.on_timeout()  # Counter now at 4, triggers restart
restart.assert_called_once()
Suggested change
- # Advance time past cooldown
- mock_datetime.datetime.now.return_value = now + datetime.timedelta(minutes=31)
- # Trigger timeouts again
- await health_manager.on_timeout()
- await health_manager.on_timeout()
- await health_manager.on_timeout()
+ # The consecutive timeout counter is now at 3
+ assert health_manager._consecutive_timeouts == 3
+ # Advance time past cooldown
+ mock_datetime.datetime.now.return_value = now + datetime.timedelta(minutes=31)
+ # Even a single timeout now triggers restart
+ await health_manager.on_timeout()

restart.assert_called_once()