Skip to content

Commit 9113c36

Browse files
committed
Add a hook for handling background rate limit errors
When a rate limit error is received the background loop will jump to the maximum backoff (now 6 hours) and will also invoke a callback so that the caller can decide to re-authenticate or stop harder. The new backoff follows this trajectory, where the change in behavior is introduced after 15 attempts: - attempt 1: wait 10 seconds - attempt 5: waits 50 seconds - attempt 7: waits 2 minutes - attempt 10: waits 6 minutes - attempt 15: waits 32 minutes - attempt 17: waits 2 hours - attempt 20: waits 6 hours
1 parent a8f5d06 commit 9113c36

File tree

4 files changed

+142
-23
lines changed

4 files changed

+142
-23
lines changed

roborock/devices/device_manager.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from roborock.exceptions import RoborockException
2121
from roborock.map.map_parser import MapParserConfig
2222
from roborock.mqtt.roborock_session import create_lazy_mqtt_session
23-
from roborock.mqtt.session import MqttSession
23+
from roborock.mqtt.session import MqttSession, SessionUnauthorizedHook
2424
from roborock.protocol import create_mqtt_params
2525
from roborock.web_api import RoborockApiClient, UserWebApiClient
2626

@@ -173,6 +173,7 @@ async def create_device_manager(
173173
map_parser_config: MapParserConfig | None = None,
174174
session: aiohttp.ClientSession | None = None,
175175
ready_callback: DeviceReadyCallback | None = None,
176+
mqtt_session_unauthorized_hook: SessionUnauthorizedHook | None = None,
176177
) -> DeviceManager:
177178
"""Convenience function to create and initialize a DeviceManager.
178179
@@ -196,6 +197,7 @@ async def create_device_manager(
196197

197198
mqtt_params = create_mqtt_params(user_data.rriot)
198199
mqtt_params.diagnostics = diagnostics.subkey("mqtt_session")
200+
mqtt_params.unauthorized_hook = mqtt_session_unauthorized_hook
199201
mqtt_session = await create_lazy_mqtt_session(mqtt_params)
200202

201203
def device_creator(home_data: HomeData, device: HomeDataDevice, product: HomeDataProduct) -> RoborockDevice:

roborock/mqtt/roborock_session.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
# Exponential backoff parameters
3333
MIN_BACKOFF_INTERVAL = datetime.timedelta(seconds=10)
34-
MAX_BACKOFF_INTERVAL = datetime.timedelta(minutes=30)
34+
MAX_BACKOFF_INTERVAL = datetime.timedelta(hours=6)
3535
BACKOFF_MULTIPLIER = 1.5
3636

3737

@@ -79,6 +79,7 @@ def __init__(
7979
self._idle_timers: dict[str, asyncio.Task[None]] = {}
8080
self._diagnostics = params.diagnostics
8181
self._health_manager = HealthManager(self.restart)
82+
self._unauthorized_hook = params.unauthorized_hook
8283

8384
@property
8485
def connected(self) -> bool:
@@ -199,14 +200,28 @@ async def _run_connection(self, start_future: asyncio.Future[None] | None) -> No
199200
_LOGGER.debug("Received message: %s", message)
200201
with self._diagnostics.timer("dispatch_message"):
201202
self._listeners(message.topic.value, message.payload)
203+
except MqttCodeError as err:
204+
self._diagnostics.increment(f"connect_failure:{err.rc}")
205+
if start_future and not start_future.done():
206+
_LOGGER.debug("MQTT error starting session: %s", err)
207+
start_future.set_exception(err)
208+
else:
209+
_LOGGER.debug("MQTT error: %s", err)
210+
if err.rc == MqttReasonCode.RC_ERROR_UNAUTHORIZED and self._unauthorized_hook:
211+
_LOGGER.info("MQTT unauthorized/rate-limit error received, setting backoff to maximum")
212+
self._unauthorized_hook()
213+
self._backoff = MAX_BACKOFF_INTERVAL
214+
raise
202215
except MqttError as err:
216+
self._diagnostics.increment("connect_failure:unknown")
203217
if start_future and not start_future.done():
204218
_LOGGER.info("MQTT error starting session: %s", err)
205219
start_future.set_exception(err)
206220
else:
207221
_LOGGER.info("MQTT error: %s", err)
208222
raise
209-
except Exception as err:
223+
except (RuntimeError, Exception) as err:
224+
self._diagnostics.increment("connect_failure:uncaught")
210225
# This error is thrown when the MQTT loop is cancelled
211226
# and the generator is not stopped.
212227
if "generator didn't stop" in str(err) or "generator didn't yield" in str(err):

roborock/mqtt/session.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
DEFAULT_TIMEOUT = 30.0
1212

13+
SessionUnauthorizedHook = Callable[[], None]
14+
1315

1416
@dataclass
1517
class MqttParams:
@@ -41,6 +43,14 @@ class MqttParams:
4143
shared MQTT session diagnostics are included in the overall diagnostics.
4244
"""
4345

46+
unauthorized_hook: SessionUnauthorizedHook | None = None
47+
"""Optional hook invoked when an unauthorized error is received.
48+
49+
This may be invoked by the background reconnect logic when an
50+
unauthorized error is received from the broker. The caller may use
51+
this hook to refresh credentials or take other actions as needed.
52+
"""
53+
4454

4555
class MqttSession(ABC):
4656
"""An MQTT session for sending and receiving messages."""

tests/mqtt/test_roborock_session.py

Lines changed: 112 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import copy
55
import datetime
66
from collections.abc import Callable, Generator
7+
from typing import Any
78
from unittest.mock import AsyncMock, Mock, patch
89

910
import aiomqtt
@@ -31,17 +32,62 @@ def auto_fast_backoff(fast_backoff_fixture: None) -> None:
3132
"""Automatically use the fast backoff fixture."""
3233

3334

34-
@pytest.fixture
35-
def mock_mqtt_client() -> Generator[AsyncMock, None, None]:
36-
"""Fixture to create a mock MQTT client with patched aiomqtt.Client."""
35+
class FakeAsyncIterator:
36+
"""Fake async iterator that waits for messages to arrive, but they never do.
37+
38+
This is used for testing exceptions in other client functions.
39+
"""
40+
41+
def __init__(self) -> None:
42+
self.loop = True
43+
44+
def __aiter__(self):
45+
return self
46+
47+
async def __anext__(self) -> None:
48+
"""Iterator that does not generate any messages."""
49+
while self.loop:
50+
await asyncio.sleep(0.01)
51+
52+
53+
@pytest.fixture(name="message_iterator")
54+
def message_iterator_fixture() -> FakeAsyncIterator:
55+
"""Fixture to provide a side effect for creating the MQTT client."""
56+
return FakeAsyncIterator()
57+
58+
59+
@pytest.fixture(name="mock_client")
60+
def mock_client_fixture(message_iterator: FakeAsyncIterator) -> AsyncMock:
61+
"""Fixture to provide a side effect for creating the MQTT client."""
3762
mock_client = AsyncMock()
38-
mock_client.messages = FakeAsyncIterator()
63+
mock_client.messages = message_iterator
64+
return mock_client
65+
66+
67+
@pytest.fixture(name="create_client_side_effect")
68+
def create_client_side_effect_fixture() -> Exception | None:
69+
"""Fixture to provide a side effect for creating the MQTT client."""
70+
return None
71+
3972

73+
@pytest.fixture(name="mock_aenter_client")
74+
def mock_aenter_client_fixture(mock_client: AsyncMock, create_client_side_effect: Exception | None) -> AsyncMock:
75+
"""Fixture to provide a side effect for creating the MQTT client."""
4076
mock_aenter = AsyncMock()
4177
mock_aenter.return_value = mock_client
78+
mock_aenter.side_effect = create_client_side_effect
79+
return mock_aenter
80+
81+
82+
@pytest.fixture
83+
def mock_mqtt_client(
84+
mock_client: AsyncMock,
85+
mock_aenter_client: AsyncMock,
86+
) -> Generator[AsyncMock, None, None]:
87+
"""Fixture to create a mock MQTT client with patched aiomqtt.Client."""
4288

4389
mock_shim = Mock()
44-
mock_shim.return_value.__aenter__ = mock_aenter
90+
mock_shim.return_value.__aenter__ = mock_aenter_client
4591
mock_shim.return_value.__aexit__ = AsyncMock()
4692

4793
with patch("roborock.mqtt.roborock_session.aiomqtt.Client", mock_shim):
@@ -114,21 +160,6 @@ async def test_publish_command(push_response: Callable[[bytes], None]) -> None:
114160
assert not session.connected
115161

116162

117-
class FakeAsyncIterator:
118-
"""Fake async iterator that waits for messages to arrive, but they never do.
119-
120-
This is used for testing exceptions in other client functions.
121-
"""
122-
123-
def __aiter__(self):
124-
return self
125-
126-
async def __anext__(self) -> None:
127-
"""Iterator that does not generate any messages."""
128-
while True:
129-
await asyncio.sleep(1)
130-
131-
132163
async def test_publish_failure(mock_mqtt_client: AsyncMock) -> None:
133164
"""Test an MQTT error is received when publishing a message."""
134165

@@ -432,3 +463,64 @@ async def test_diagnostics_data(push_response: Callable[[bytes], None]) -> None:
432463
assert data.get("subscribe_count") == 2
433464
assert data.get("dispatch_message_count") == 3
434465
assert data.get("close") == 1
466+
467+
468+
@pytest.mark.parametrize(
469+
("create_client_side_effect"),
470+
[
471+
# Unauthorized
472+
aiomqtt.MqttCodeError(rc=135),
473+
],
474+
)
475+
async def test_session_unauthorized_hook(mock_mqtt_client: AsyncMock, push_response: Callable[[bytes], None]) -> None:
476+
"""Test the MQTT session."""
477+
478+
unauthorized = asyncio.Event()
479+
480+
params = copy.deepcopy(FAKE_PARAMS)
481+
params.unauthorized_hook = unauthorized.set
482+
483+
with pytest.raises(MqttSessionUnauthorized):
484+
await create_mqtt_session(params)
485+
486+
assert unauthorized.is_set()
487+
488+
489+
async def test_session_unauthorized_after_start(
490+
mock_aenter_client: AsyncMock,
491+
message_iterator: FakeAsyncIterator,
492+
mock_mqtt_client: AsyncMock,
493+
push_response: Callable[[bytes], None],
494+
) -> None:
495+
"""Test the MQTT session."""
496+
497+
# Configure a hook that is notified of unauthorized errors
498+
unauthorized = asyncio.Event()
499+
params = copy.deepcopy(FAKE_PARAMS)
500+
params.unauthorized_hook = unauthorized.set
501+
502+
# The client will succeed on first connection attempt, then fail with
503+
# unauthorized messages on all future attempts.
504+
request_count = 0
505+
506+
def succeed_then_fail_unauthorized() -> Any:
507+
nonlocal request_count
508+
request_count += 1
509+
if request_count == 1:
510+
return mock_mqtt_client
511+
raise aiomqtt.MqttCodeError(rc=135)
512+
513+
mock_aenter_client.side_effect = succeed_then_fail_unauthorized
514+
# Don't produce messages, just exit and restart to reconnect
515+
message_iterator.loop = False
516+
517+
push_response(mqtt_packet.gen_connack(rc=0, flags=2))
518+
519+
session = await create_mqtt_session(params)
520+
assert session.connected
521+
522+
try:
523+
async with asyncio.timeout(10):
524+
assert await unauthorized.wait()
525+
finally:
526+
await session.close()

0 commit comments

Comments
 (0)