Skip to content

Commit 879a641

Browse files
authored
fix: add a health manager for restarting unhealthy mqtt connections (#605)
* fix: add ability to restart the mqtt session * fix: add a health manager for restarting unhealthy mqtt connections * chore: fix async tests * fix: reset start_future each loop * chore: always use utc for now * chore: cancel the connection and reconnect tasks
1 parent e912dac commit 879a641

File tree

8 files changed

+270
-55
lines changed

8 files changed

+270
-55
lines changed

roborock/devices/mqtt_channel.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ async def publish(self, message: RoborockMessage) -> None:
8282
_LOGGER.exception("Error publishing MQTT message: %s", e)
8383
raise RoborockException(f"Failed to publish MQTT message: {e}") from e
8484

85+
async def restart(self) -> None:
86+
"""Restart the underlying MQTT session."""
87+
await self._mqtt_session.restart()
88+
8589

8690
def create_mqtt_channel(
8791
user_data: UserData, mqtt_params: MqttParams, mqtt_session: MqttSession, device: HomeDataDevice

roborock/devices/v1_rpc_channel.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from roborock.data import RoborockBase
1515
from roborock.exceptions import RoborockException
16+
from roborock.mqtt.health_manager import HealthManager
1617
from roborock.protocols.v1_protocol import (
1718
CommandType,
1819
MapResponse,
@@ -125,12 +126,14 @@ def __init__(
125126
channel: MqttChannel | LocalChannel,
126127
payload_encoder: Callable[[RequestMessage], RoborockMessage],
127128
decoder: Callable[[RoborockMessage], ResponseMessage] | Callable[[RoborockMessage], MapResponse | None],
129+
health_manager: HealthManager | None = None,
128130
) -> None:
129131
"""Initialize the channel with a raw channel and an encoder function."""
130132
self._name = name
131133
self._channel = channel
132134
self._payload_encoder = payload_encoder
133135
self._decoder = decoder
136+
self._health_manager = health_manager
134137

135138
async def _send_raw_command(
136139
self,
@@ -165,13 +168,19 @@ def find_response(response_message: RoborockMessage) -> None:
165168
unsub = await self._channel.subscribe(find_response)
166169
try:
167170
await self._channel.publish(message)
168-
return await asyncio.wait_for(future, timeout=_TIMEOUT)
171+
result = await asyncio.wait_for(future, timeout=_TIMEOUT)
169172
except TimeoutError as ex:
173+
if self._health_manager:
174+
await self._health_manager.on_timeout()
170175
future.cancel()
171176
raise RoborockException(f"Command timed out after {_TIMEOUT}s") from ex
172177
finally:
173178
unsub()
174179

180+
if self._health_manager:
181+
await self._health_manager.on_success()
182+
return result
183+
175184

176185
def create_mqtt_rpc_channel(mqtt_channel: MqttChannel, security_data: SecurityData) -> V1RpcChannel:
177186
"""Create a V1 RPC channel using an MQTT channel."""
@@ -180,6 +189,7 @@ def create_mqtt_rpc_channel(mqtt_channel: MqttChannel, security_data: SecurityDa
180189
mqtt_channel,
181190
lambda x: x.encode_message(RoborockMessageProtocol.RPC_REQUEST, security_data=security_data),
182191
decode_rpc_response,
192+
health_manager=HealthManager(mqtt_channel.restart),
183193
)
184194

185195

roborock/mqtt/health_manager.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""A health manager for monitoring MQTT connections to Roborock devices.
2+
3+
We observe a problem where sometimes the MQTT connection appears to be alive but
4+
no messages are being received. To mitigate this, we track consecutive timeouts
5+
and restart the connection if too many timeouts occur in succession.
6+
"""
7+
8+
import datetime
9+
from collections.abc import Awaitable, Callable
10+
11+
# Number of consecutive timeouts before considering the connection unhealthy.
12+
TIMEOUT_THRESHOLD = 3
13+
14+
# We won't restart the session more often than this interval.
15+
RESTART_COOLDOWN = datetime.timedelta(minutes=30)
16+
17+
18+
class HealthManager:
19+
"""Manager for monitoring the health of MQTT connections.
20+
21+
This tracks communication timeouts and can trigger restarts of the MQTT
22+
session if too many timeouts occur in succession.
23+
"""
24+
25+
def __init__(self, restart: Callable[[], Awaitable[None]]) -> None:
26+
"""Initialize the health manager.
27+
28+
Args:
29+
restart: A callable to restart the MQTT session.
30+
"""
31+
self._consecutive_timeouts = 0
32+
self._restart = restart
33+
self._last_restart: datetime.datetime | None = None
34+
35+
async def on_success(self) -> None:
36+
"""Record a successful communication event."""
37+
self._consecutive_timeouts = 0
38+
39+
async def on_timeout(self) -> None:
40+
"""Record a timeout event.
41+
42+
This may trigger a restart of the MQTT session if too many timeouts
43+
have occurred in succession.
44+
"""
45+
self._consecutive_timeouts += 1
46+
if self._consecutive_timeouts >= TIMEOUT_THRESHOLD:
47+
now = datetime.datetime.now(datetime.UTC)
48+
if self._last_restart is None or now - self._last_restart >= RESTART_COOLDOWN:
49+
await self._restart()
50+
self._last_restart = now
51+
self._consecutive_timeouts = 0

roborock/mqtt/roborock_session.py

Lines changed: 77 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,14 @@ class RoborockMqttSession(MqttSession):
4949

5050
def __init__(self, params: MqttParams):
5151
self._params = params
52-
self._background_task: asyncio.Task[None] | None = None
52+
self._reconnect_task: asyncio.Task[None] | None = None
5353
self._healthy = False
5454
self._stop = False
5555
self._backoff = MIN_BACKOFF_INTERVAL
5656
self._client: aiomqtt.Client | None = None
5757
self._client_lock = asyncio.Lock()
5858
self._listeners: CallbackMap[str, bytes] = CallbackMap(_LOGGER)
59+
self._connection_task: asyncio.Task[None] | None = None
5960

6061
@property
6162
def connected(self) -> bool:
@@ -72,7 +73,7 @@ async def start(self) -> None:
7273
"""
7374
start_future: asyncio.Future[None] = asyncio.Future()
7475
loop = asyncio.get_event_loop()
75-
self._background_task = loop.create_task(self._run_task(start_future))
76+
self._reconnect_task = loop.create_task(self._run_reconnect_loop(start_future))
7677
try:
7778
await start_future
7879
except MqttError as err:
@@ -85,68 +86,93 @@ async def start(self) -> None:
8586
async def close(self) -> None:
8687
"""Cancels the MQTT loop and shutdown the client library."""
8788
self._stop = True
88-
if self._background_task:
89-
self._background_task.cancel()
90-
try:
91-
await self._background_task
92-
except asyncio.CancelledError:
93-
pass
94-
async with self._client_lock:
95-
if self._client:
96-
await self._client.close()
89+
tasks = [task for task in [self._connection_task, self._reconnect_task] if task]
90+
for task in tasks:
91+
task.cancel()
92+
try:
93+
await asyncio.gather(*tasks)
94+
except asyncio.CancelledError:
95+
pass
9796

9897
self._healthy = False
9998

100-
async def _run_task(self, start_future: asyncio.Future[None] | None) -> None:
99+
async def restart(self) -> None:
100+
"""Force the session to disconnect and reconnect.
101+
102+
The active connection task will be cancelled and restarted in the background, retried by
103+
the reconnect loop. This is a no-op if there is no active connection.
104+
"""
105+
_LOGGER.info("Forcing MQTT session restart")
106+
if self._connection_task:
107+
self._connection_task.cancel()
108+
else:
109+
_LOGGER.debug("No message loop task to cancel")
110+
111+
async def _run_reconnect_loop(self, start_future: asyncio.Future[None] | None) -> None:
101112
"""Run the MQTT loop."""
102113
_LOGGER.info("Starting MQTT session")
103114
while True:
104115
try:
105-
async with self._mqtt_client(self._params) as client:
106-
# Reset backoff once we've successfully connected
107-
self._backoff = MIN_BACKOFF_INTERVAL
108-
self._healthy = True
109-
_LOGGER.info("MQTT Session connected.")
110-
if start_future:
111-
start_future.set_result(None)
112-
start_future = None
113-
114-
await self._process_message_loop(client)
115-
116-
except MqttError as err:
117-
if start_future:
118-
_LOGGER.info("MQTT error starting session: %s", err)
119-
start_future.set_exception(err)
120-
return
121-
_LOGGER.info("MQTT error: %s", err)
122-
except asyncio.CancelledError as err:
123-
if start_future:
124-
_LOGGER.debug("MQTT loop was cancelled while starting")
125-
start_future.set_exception(err)
126-
_LOGGER.debug("MQTT loop was cancelled")
127-
return
128-
# Catch exceptions to avoid crashing the loop
129-
# and to allow the loop to retry.
130-
except Exception as err:
131-
# This error is thrown when the MQTT loop is cancelled
132-
# and the generator is not stopped.
133-
if "generator didn't stop" in str(err) or "generator didn't yield" in str(err):
134-
_LOGGER.debug("MQTT loop was cancelled")
135-
return
136-
if start_future:
137-
_LOGGER.error("Uncaught error starting MQTT session: %s", err)
138-
start_future.set_exception(err)
116+
self._connection_task = asyncio.create_task(self._run_connection(start_future))
117+
await self._connection_task
118+
except asyncio.CancelledError:
119+
_LOGGER.debug("MQTT connection task cancelled")
120+
except Exception:
121+
# Exceptions are logged and handled in _run_connection.
122+
# There is a special case for exceptions on startup where we return
123+
# immediately. Otherwise, we let the reconnect loop retry with
124+
# backoff when the reconnect loop is active.
125+
if start_future and start_future.done() and start_future.exception():
139126
return
140-
_LOGGER.exception("Uncaught error during MQTT session: %s", err)
141127

142128
self._healthy = False
129+
start_future = None
143130
if self._stop:
144131
_LOGGER.debug("MQTT session closed, stopping retry loop")
145132
return
146133
_LOGGER.info("MQTT session disconnected, retrying in %s seconds", self._backoff.total_seconds())
147134
await asyncio.sleep(self._backoff.total_seconds())
148135
self._backoff = min(self._backoff * BACKOFF_MULTIPLIER, MAX_BACKOFF_INTERVAL)
149136

137+
async def _run_connection(self, start_future: asyncio.Future[None] | None) -> None:
138+
"""Connect to the MQTT broker and listen for messages.
139+
140+
This is the primary connection loop for the MQTT session that is
141+
long running and processes incoming messages. If the connection
142+
is lost, this method will exit.
143+
"""
144+
try:
145+
async with self._mqtt_client(self._params) as client:
146+
self._backoff = MIN_BACKOFF_INTERVAL
147+
self._healthy = True
148+
_LOGGER.info("MQTT Session connected.")
149+
if start_future and not start_future.done():
150+
start_future.set_result(None)
151+
152+
_LOGGER.debug("Processing MQTT messages")
153+
async for message in client.messages:
154+
_LOGGER.debug("Received message: %s", message)
155+
self._listeners(message.topic.value, message.payload)
156+
except MqttError as err:
157+
if start_future and not start_future.done():
158+
_LOGGER.info("MQTT error starting session: %s", err)
159+
start_future.set_exception(err)
160+
else:
161+
_LOGGER.info("MQTT error: %s", err)
162+
raise
163+
except Exception as err:
164+
# This error is thrown when the MQTT loop is cancelled
165+
# and the generator is not stopped.
166+
if "generator didn't stop" in str(err) or "generator didn't yield" in str(err):
167+
_LOGGER.debug("MQTT loop was cancelled")
168+
return
169+
if start_future and not start_future.done():
170+
_LOGGER.error("Uncaught error starting MQTT session: %s", err)
171+
start_future.set_exception(err)
172+
else:
173+
_LOGGER.exception("Uncaught error during MQTT session: %s", err)
174+
raise
175+
150176
@asynccontextmanager
151177
async def _mqtt_client(self, params: MqttParams) -> aiomqtt.Client:
152178
"""Connect to the MQTT broker and listen for messages."""
@@ -178,12 +204,6 @@ async def _mqtt_client(self, params: MqttParams) -> aiomqtt.Client:
178204
async with self._client_lock:
179205
self._client = None
180206

181-
async def _process_message_loop(self, client: aiomqtt.Client) -> None:
182-
_LOGGER.debug("Processing MQTT messages")
183-
async for message in client.messages:
184-
_LOGGER.debug("Received message: %s", message)
185-
self._listeners(message.topic.value, message.payload)
186-
187207
async def subscribe(self, topic: str, callback: Callable[[bytes], None]) -> Callable[[], None]:
188208
"""Subscribe to messages on the specified topic and invoke the callback for new messages.
189209
@@ -271,6 +291,10 @@ async def close(self) -> None:
271291
"""
272292
await self._session.close()
273293

294+
async def restart(self) -> None:
295+
"""Force the session to disconnect and reconnect."""
296+
await self._session.restart()
297+
274298

275299
async def create_mqtt_session(params: MqttParams) -> MqttSession:
276300
"""Create an MQTT session.

roborock/mqtt/session.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ async def publish(self, topic: str, message: bytes) -> None:
5454
This will raise an exception if the message could not be sent.
5555
"""
5656

57+
@abstractmethod
58+
async def restart(self) -> None:
59+
"""Force the session to disconnect and reconnect."""
60+
5761
@abstractmethod
5862
async def close(self) -> None:
5963
"""Cancels the mqtt loop"""

tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,7 @@ def __init__(self):
363363
self.connect = AsyncMock(side_effect=self._connect)
364364
self.close = MagicMock(side_effect=self._close)
365365
self.protocol_version = LocalProtocolVersion.V1
366+
self.restart = AsyncMock()
366367

367368
async def _connect(self) -> None:
368369
self._is_connected = True

tests/mqtt/test_health_manager.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
"""Tests for the health manager."""
2+
3+
import datetime
4+
from unittest.mock import AsyncMock, patch
5+
6+
from roborock.mqtt.health_manager import HealthManager
7+
8+
9+
async def test_health_manager_restart_called_after_timeouts() -> None:
10+
"""Test that the health manager calls restart after consecutive timeouts."""
11+
restart = AsyncMock()
12+
health_manager = HealthManager(restart=restart)
13+
14+
await health_manager.on_timeout()
15+
await health_manager.on_timeout()
16+
restart.assert_not_called()
17+
18+
await health_manager.on_timeout()
19+
restart.assert_called_once()
20+
21+
22+
async def test_health_manager_success_resets_counter() -> None:
23+
"""Test that a successful message resets the timeout counter."""
24+
restart = AsyncMock()
25+
health_manager = HealthManager(restart=restart)
26+
27+
await health_manager.on_timeout()
28+
await health_manager.on_timeout()
29+
restart.assert_not_called()
30+
31+
await health_manager.on_success()
32+
33+
await health_manager.on_timeout()
34+
await health_manager.on_timeout()
35+
restart.assert_not_called()
36+
37+
await health_manager.on_timeout()
38+
restart.assert_called_once()
39+
40+
41+
async def test_cooldown() -> None:
42+
"""Test that the health manager respects the restart cooldown."""
43+
restart = AsyncMock()
44+
health_manager = HealthManager(restart=restart)
45+
46+
with patch("roborock.mqtt.health_manager.datetime") as mock_datetime:
47+
now = datetime.datetime(2023, 1, 1, 12, 0, 0)
48+
mock_datetime.datetime.now.return_value = now
49+
50+
# Trigger first restart
51+
await health_manager.on_timeout()
52+
await health_manager.on_timeout()
53+
await health_manager.on_timeout()
54+
restart.assert_called_once()
55+
restart.reset_mock()
56+
57+
# Advance time but stay within cooldown (30 mins)
58+
mock_datetime.datetime.now.return_value = now + datetime.timedelta(minutes=10)
59+
60+
# Trigger timeouts again
61+
await health_manager.on_timeout()
62+
await health_manager.on_timeout()
63+
await health_manager.on_timeout()
64+
restart.assert_not_called()
65+
66+
# Advance time past cooldown
67+
mock_datetime.datetime.now.return_value = now + datetime.timedelta(minutes=31)
68+
69+
# Trigger timeouts again
70+
await health_manager.on_timeout()
71+
await health_manager.on_timeout()
72+
await health_manager.on_timeout()
73+
restart.assert_called_once()

0 commit comments

Comments
 (0)