Skip to content

Commit e2cc8de

Browse files
committed
fix: Fix bugs in the subscription idle timeout
We need to actually track which subscriptionsare active and don't add duplicate subscriptions in those cases. We keep a separate object to track the subscription state since it is different than the callback logic (e.g. subscribe callback is removed from the list when unsubscribe happens)
1 parent 11f362e commit e2cc8de

File tree

2 files changed

+70
-15
lines changed

2 files changed

+70
-15
lines changed

roborock/mqtt/roborock_session.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def __init__(
6969
self._stop = False
7070
self._backoff = MIN_BACKOFF_INTERVAL
7171
self._client: aiomqtt.Client | None = None
72+
self._client_subscribed_topics: set[str] = set()
7273
self._client_lock = asyncio.Lock()
7374
self._listeners: CallbackMap[str, bytes] = CallbackMap(_LOGGER)
7475
self._connection_task: asyncio.Task[None] | None = None
@@ -218,7 +219,7 @@ async def _mqtt_client(self, params: MqttParams) -> aiomqtt.Client:
218219
# Re-establish any existing subscriptions
219220
async with self._client_lock:
220221
self._client = client
221-
for topic in self._listeners.keys():
222+
for topic in self._client_subscribed_topics:
222223
_LOGGER.debug("Re-establishing subscription to topic %s", topic)
223224
# TODO: If this fails it will break the whole connection. Make
224225
# this retry again in the background with backoff.
@@ -249,32 +250,42 @@ async def subscribe(self, topic: str, callback: Callable[[bytes], None]) -> Call
249250
unsub = self._listeners.add_callback(topic, callback)
250251

251252
async with self._client_lock:
252-
if self._client:
253-
_LOGGER.debug("Establishing subscription to topic %s", topic)
254-
try:
255-
await self._client.subscribe(topic)
256-
except MqttError as err:
257-
# Clean up the callback if subscription fails
258-
unsub()
259-
raise MqttSessionException(f"Error subscribing to topic: {err}") from err
260-
else:
261-
_LOGGER.debug("Client not connected, will establish subscription later")
262-
263-
def schedule_unsubscribe():
253+
if topic not in self._client_subscribed_topics:
254+
self._client_subscribed_topics.add(topic)
255+
if self._client:
256+
_LOGGER.debug("Establishing subscription to topic %s", topic)
257+
try:
258+
await self._client.subscribe(topic)
259+
except MqttError as err:
260+
# Clean up the callback if subscription fails
261+
unsub()
262+
self._client_subscribed_topics.discard(topic)
263+
raise MqttSessionException(f"Error subscribing to topic: {err}") from err
264+
else:
265+
_LOGGER.debug("Client not connected, will establish subscription later")
266+
267+
def schedule_unsubscribe() -> None:
264268
async def idle_unsubscribe():
265269
try:
266270
await asyncio.sleep(self._topic_idle_timeout.total_seconds())
267271
# Only unsubscribe if there are no callbacks left for this topic
268272
if not self._listeners.get_callbacks(topic):
269273
async with self._client_lock:
274+
# Check again if we have listeners, in case a subscribe happened
275+
# while we were waiting for the lock or after we popped the timer.
276+
if self._listeners.get_callbacks(topic):
277+
_LOGGER.debug("Skipping unsubscribe for %s, new listeners added", topic)
278+
return
279+
280+
self._idle_timers.pop(topic, None)
281+
self._client_subscribed_topics.discard(topic)
282+
270283
if self._client:
271284
_LOGGER.debug("Idle timeout expired, unsubscribing from topic %s", topic)
272285
try:
273286
await self._client.unsubscribe(topic)
274287
except MqttError as err:
275288
_LOGGER.warning("Error unsubscribing from topic %s: %s", topic, err)
276-
# Clean up timer from dict
277-
self._idle_timers.pop(topic, None)
278289
except asyncio.CancelledError:
279290
_LOGGER.debug("Idle unsubscribe for topic %s cancelled", topic)
280291

@@ -286,7 +297,10 @@ def delayed_unsub():
286297
unsub() # Remove the callback from CallbackMap
287298
# If no more callbacks for this topic, start idle timer
288299
if not self._listeners.get_callbacks(topic):
300+
_LOGGER.debug("Unsubscribing topic %s, starting idle timer", topic)
289301
schedule_unsubscribe()
302+
else:
303+
_LOGGER.debug("Unsubscribing topic %s, still have active callbacks", topic)
290304

291305
return delayed_unsub
292306

tests/mqtt/test_roborock_session.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,47 @@ async def test_idle_timeout_multiple_callbacks(mock_mqtt_client: AsyncMock) -> N
368368
await session.close()
369369

370370

371+
async def test_subscription_reuse(mock_mqtt_client: AsyncMock) -> None:
372+
"""Test that subscriptions are reused and not duplicated."""
373+
session = RoborockMqttSession(FAKE_PARAMS)
374+
await session.start()
375+
assert session.connected
376+
377+
# 1. First subscription
378+
cb1 = Mock()
379+
unsub1 = await session.subscribe("topic1", cb1)
380+
381+
# Verify subscribe called
382+
mock_mqtt_client.subscribe.assert_called_with("topic1")
383+
mock_mqtt_client.subscribe.reset_mock()
384+
385+
# 2. Second subscription (same topic)
386+
cb2 = Mock()
387+
unsub2 = await session.subscribe("topic1", cb2)
388+
389+
# Verify subscribe NOT called
390+
mock_mqtt_client.subscribe.assert_not_called()
391+
392+
# 3. Unsubscribe one
393+
unsub1()
394+
# Verify unsubscribe NOT called (still have cb2)
395+
mock_mqtt_client.unsubscribe.assert_not_called()
396+
397+
# 4. Unsubscribe second (starts idle timer)
398+
unsub2()
399+
# Verify unsubscribe NOT called yet (idle)
400+
mock_mqtt_client.unsubscribe.assert_not_called()
401+
402+
# 5. Resubscribe during idle
403+
cb3 = Mock()
404+
_ = await session.subscribe("topic1", cb3)
405+
406+
# Verify subscribe NOT called (reused)
407+
mock_mqtt_client.subscribe.assert_not_called()
408+
409+
await session.close()
410+
411+
371412
@pytest.mark.parametrize(
372413
("side_effect", "expected_exception", "match"),
373414
[

0 commit comments

Comments
 (0)