diff --git a/roborock/mqtt/roborock_session.py b/roborock/mqtt/roborock_session.py index 53a337e7..b35e3910 100644 --- a/roborock/mqtt/roborock_session.py +++ b/roborock/mqtt/roborock_session.py @@ -24,7 +24,8 @@ _LOGGER = logging.getLogger(__name__) _MQTT_LOGGER = logging.getLogger(f"{__name__}.aiomqtt") -KEEPALIVE = 60 +CLIENT_KEEPALIVE = datetime.timedelta(seconds=120) +TOPIC_KEEPALIVE = datetime.timedelta(seconds=60) # Exponential backoff parameters MIN_BACKOFF_INTERVAL = datetime.timedelta(seconds=10) @@ -47,7 +48,11 @@ class RoborockMqttSession(MqttSession): re-established. """ - def __init__(self, params: MqttParams): + def __init__( + self, + params: MqttParams, + topic_idle_timeout: datetime.timedelta = TOPIC_KEEPALIVE, + ): self._params = params self._reconnect_task: asyncio.Task[None] | None = None self._healthy = False @@ -57,6 +62,8 @@ def __init__(self, params: MqttParams): self._client_lock = asyncio.Lock() self._listeners: CallbackMap[str, bytes] = CallbackMap(_LOGGER) self._connection_task: asyncio.Task[None] | None = None + self._topic_idle_timeout = topic_idle_timeout + self._idle_timers: dict[str, asyncio.Task[None]] = {} @property def connected(self) -> bool: @@ -86,11 +93,15 @@ async def start(self) -> None: async def close(self) -> None: """Cancels the MQTT loop and shutdown the client library.""" self._stop = True - tasks = [task for task in [self._connection_task, self._reconnect_task] if task] + tasks = [task for task in [self._connection_task, self._reconnect_task, *self._idle_timers.values()] if task] + self._connection_task = None + self._reconnect_task = None + self._idle_timers.clear() + for task in tasks: task.cancel() try: - await asyncio.gather(*tasks) + await asyncio.gather(*tasks, return_exceptions=True) except asyncio.CancelledError: pass @@ -183,7 +194,7 @@ async def _mqtt_client(self, params: MqttParams) -> aiomqtt.Client: port=params.port, username=params.username, password=params.password, - keepalive=KEEPALIVE, + keepalive=int(CLIENT_KEEPALIVE.total_seconds()), protocol=aiomqtt.ProtocolVersion.V5, tls_params=TLSParameters() if params.tls else None, timeout=params.timeout, @@ -210,9 +221,17 @@ async def subscribe(self, topic: str, callback: Callable[[bytes], None]) -> Call The callback will be called with the message payload as a bytes object. The callback should not block since it runs in the async loop. It should not raise any exceptions. - The returned callable unsubscribes from the topic when called. + The returned callable unsubscribes from the topic when called, but will delay actual + unsubscription for the idle timeout period. If a new subscription comes in during the + timeout, the timer is cancelled and the subscription is reused. """ _LOGGER.debug("Subscribing to topic %s", topic) + + # If there is an idle timer for this topic, cancel it (reuse subscription) + if idle_timer := self._idle_timers.pop(topic, None): + idle_timer.cancel() + _LOGGER.debug("Cancelled idle timer for topic %s (reused subscription)", topic) + unsub = self._listeners.add_callback(topic, callback) async with self._client_lock: @@ -221,11 +240,41 @@ async def subscribe(self, topic: str, callback: Callable[[bytes], None]) -> Call try: await self._client.subscribe(topic) except MqttError as err: + # Clean up the callback if subscription fails + unsub() raise MqttSessionException(f"Error subscribing to topic: {err}") from err else: _LOGGER.debug("Client not connected, will establish subscription later") - return unsub + def schedule_unsubscribe(): + async def idle_unsubscribe(): + try: + await asyncio.sleep(self._topic_idle_timeout.total_seconds()) + # Only unsubscribe if there are no callbacks left for this topic + if not self._listeners.get_callbacks(topic): + async with self._client_lock: + if self._client: + _LOGGER.debug("Idle timeout expired, unsubscribing from topic %s", topic) + try: + await self._client.unsubscribe(topic) + except MqttError as err: + _LOGGER.warning("Error unsubscribing from topic %s: %s", topic, err) + # Clean up timer from dict + self._idle_timers.pop(topic, None) + except asyncio.CancelledError: + _LOGGER.debug("Idle unsubscribe for topic %s cancelled", topic) + + # Start the idle timer task + task = asyncio.create_task(idle_unsubscribe()) + self._idle_timers[topic] = task + + def delayed_unsub(): + unsub() # Remove the callback from CallbackMap + # If no more callbacks for this topic, start idle timer + if not self._listeners.get_callbacks(topic): + schedule_unsubscribe() + + return delayed_unsub async def publish(self, topic: str, message: bytes) -> None: """Publish a message on the topic.""" diff --git a/tests/mqtt/test_roborock_session.py b/tests/mqtt/test_roborock_session.py index bb3f1bdb..f3b10139 100644 --- a/tests/mqtt/test_roborock_session.py +++ b/tests/mqtt/test_roborock_session.py @@ -11,7 +11,7 @@ import paho.mqtt.client as mqtt import pytest -from roborock.mqtt.roborock_session import create_mqtt_session +from roborock.mqtt.roborock_session import RoborockMqttSession, create_mqtt_session from roborock.mqtt.session import MqttParams, MqttSessionException from tests import mqtt_packet from tests.conftest import FakeSocketHandler @@ -80,6 +80,23 @@ def fast_backoff_fixture() -> Generator[None, None, None]: yield +@pytest.fixture +def mock_mqtt_client() -> Generator[AsyncMock, None, None]: + """Fixture to create a mock MQTT client with patched aiomqtt.Client.""" + mock_client = AsyncMock() + mock_client.messages = FakeAsyncIterator() + + mock_aenter = AsyncMock() + mock_aenter.return_value = mock_client + + mock_shim = Mock() + mock_shim.return_value.__aenter__ = mock_aenter + mock_shim.return_value.__aexit__ = AsyncMock() + + with patch("roborock.mqtt.roborock_session.aiomqtt.Client", mock_shim): + yield mock_client + + @pytest.fixture def push_response(response_queue: Queue, fake_socket_handler: FakeSocketHandler) -> Callable[[bytes], None]: """Fixtures to push messages.""" @@ -195,52 +212,34 @@ async def __anext__(self) -> None: await asyncio.sleep(1) -async def test_publish_failure() -> None: +async def test_publish_failure(mock_mqtt_client: AsyncMock) -> None: """Test an MQTT error is received when publishing a message.""" - mock_client = AsyncMock() - mock_client.messages = FakeAsyncIterator() - - mock_aenter = AsyncMock() - mock_aenter.return_value = mock_client - - with patch("roborock.mqtt.roborock_session.aiomqtt.Client.__aenter__", mock_aenter): - session = await create_mqtt_session(FAKE_PARAMS) - assert session.connected + session = await create_mqtt_session(FAKE_PARAMS) + assert session.connected - mock_client.publish.side_effect = aiomqtt.MqttError + mock_mqtt_client.publish.side_effect = aiomqtt.MqttError - with pytest.raises(MqttSessionException, match="Error publishing message"): - await session.publish("topic-1", message=b"payload") + with pytest.raises(MqttSessionException, match="Error publishing message"): + await session.publish("topic-1", message=b"payload") - await session.close() + await session.close() -async def test_subscribe_failure() -> None: +async def test_subscribe_failure(mock_mqtt_client: AsyncMock) -> None: """Test an MQTT error while subscribing.""" - mock_client = AsyncMock() - mock_client.messages = FakeAsyncIterator() - - mock_aenter = AsyncMock() - mock_aenter.return_value = mock_client - - mock_shim = Mock() - mock_shim.return_value.__aenter__ = mock_aenter - mock_shim.return_value.__aexit__ = AsyncMock() - - with patch("roborock.mqtt.roborock_session.aiomqtt.Client", mock_shim): - session = await create_mqtt_session(FAKE_PARAMS) - assert session.connected + session = await create_mqtt_session(FAKE_PARAMS) + assert session.connected - mock_client.subscribe.side_effect = aiomqtt.MqttError + mock_mqtt_client.subscribe.side_effect = aiomqtt.MqttError - subscriber1 = Subscriber() - with pytest.raises(MqttSessionException, match="Error subscribing to topic"): - await session.subscribe("topic-1", subscriber1.append) + subscriber1 = Subscriber() + with pytest.raises(MqttSessionException, match="Error subscribing to topic"): + await session.subscribe("topic-1", subscriber1.append) - assert not subscriber1.messages - await session.close() + assert not subscriber1.messages + await session.close() async def test_restart(push_response: Callable[[bytes], None]) -> None: @@ -279,3 +278,91 @@ async def test_restart(push_response: Callable[[bytes], None]) -> None: assert subscriber.messages == [b"12345", b"67890"] await session.close() + + +async def test_idle_timeout_resubscribe(mock_mqtt_client: AsyncMock) -> None: + """Test that resubscribing before idle timeout cancels the unsubscribe.""" + + # Create session with idle timeout + session = RoborockMqttSession(FAKE_PARAMS, topic_idle_timeout=datetime.timedelta(seconds=5)) + await session.start() + assert session.connected + + topic = "test/topic" + subscriber1 = Subscriber() + unsub1 = await session.subscribe(topic, subscriber1.append) + + # Unsubscribe to start idle timer + unsub1() + + # Resubscribe before idle timeout expires (should cancel timer) + subscriber2 = Subscriber() + await session.subscribe(topic, subscriber2.append) + + # Give a brief moment for any async operations to complete + await asyncio.sleep(0.01) + + # unsubscribe should NOT have been called because we resubscribed + mock_mqtt_client.unsubscribe.assert_not_called() + + await session.close() + + +async def test_idle_timeout_unsubscribe(mock_mqtt_client: AsyncMock) -> None: + """Test that unsubscribe happens after idle timeout expires.""" + + # Create session with very short idle timeout for fast test + session = RoborockMqttSession(FAKE_PARAMS, topic_idle_timeout=datetime.timedelta(milliseconds=50)) + await session.start() + assert session.connected + + topic = "test/topic" + subscriber = Subscriber() + unsub = await session.subscribe(topic, subscriber.append) + + # Unsubscribe to start idle timer + unsub() + + # Wait for idle timeout plus a small buffer + await asyncio.sleep(0.1) + + # unsubscribe should have been called after idle timeout + mock_mqtt_client.unsubscribe.assert_called_once_with(topic) + + await session.close() + + +async def test_idle_timeout_multiple_callbacks(mock_mqtt_client: AsyncMock) -> None: + """Test that unsubscribe is delayed when multiple subscribers exist.""" + + # Create session with very short idle timeout for fast test + session = RoborockMqttSession(FAKE_PARAMS, topic_idle_timeout=datetime.timedelta(milliseconds=50)) + await session.start() + assert session.connected + + topic = "test/topic" + subscriber1 = Subscriber() + subscriber2 = Subscriber() + + unsub1 = await session.subscribe(topic, subscriber1.append) + unsub2 = await session.subscribe(topic, subscriber2.append) + + # Unsubscribe first callback (should NOT start timer, subscriber2 still active) + unsub1() + + # Brief wait to ensure no timer fires + await asyncio.sleep(0.1) + + # unsubscribe should NOT have been called because subscriber2 is still active + mock_mqtt_client.unsubscribe.assert_not_called() + + # Unsubscribe second callback (NOW timer should start) + unsub2() + + # Wait for idle timeout plus a small buffer + await asyncio.sleep(0.1) + + # Now unsubscribe should have been called + mock_mqtt_client.unsubscribe.assert_called_once_with(topic) + + await session.close()