Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 56 additions & 7 deletions roborock/mqtt/roborock_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
_LOGGER = logging.getLogger(__name__)
_MQTT_LOGGER = logging.getLogger(f"{__name__}.aiomqtt")

KEEPALIVE = 60
CLIENT_KEEPALIVE = datetime.timedelta(seconds=120)
TOPIC_KEEPALIVE = datetime.timedelta(seconds=60)

# Exponential backoff parameters
MIN_BACKOFF_INTERVAL = datetime.timedelta(seconds=10)
Expand All @@ -47,7 +48,11 @@ class RoborockMqttSession(MqttSession):
re-established.
"""

def __init__(self, params: MqttParams):
def __init__(
self,
params: MqttParams,
topic_idle_timeout: datetime.timedelta = TOPIC_KEEPALIVE,
):
self._params = params
self._reconnect_task: asyncio.Task[None] | None = None
self._healthy = False
Expand All @@ -57,6 +62,8 @@ def __init__(self, params: MqttParams):
self._client_lock = asyncio.Lock()
self._listeners: CallbackMap[str, bytes] = CallbackMap(_LOGGER)
self._connection_task: asyncio.Task[None] | None = None
self._topic_idle_timeout = topic_idle_timeout
self._idle_timers: dict[str, asyncio.Task[None]] = {}

@property
def connected(self) -> bool:
Expand Down Expand Up @@ -86,11 +93,15 @@ async def start(self) -> None:
async def close(self) -> None:
"""Cancels the MQTT loop and shutdown the client library."""
self._stop = True
tasks = [task for task in [self._connection_task, self._reconnect_task] if task]
tasks = [task for task in [self._connection_task, self._reconnect_task, *self._idle_timers.values()] if task]
self._connection_task = None
self._reconnect_task = None
self._idle_timers.clear()

for task in tasks:
task.cancel()
try:
await asyncio.gather(*tasks)
await asyncio.gather(*tasks, return_exceptions=True)
except asyncio.CancelledError:
pass

Expand Down Expand Up @@ -183,7 +194,7 @@ async def _mqtt_client(self, params: MqttParams) -> aiomqtt.Client:
port=params.port,
username=params.username,
password=params.password,
keepalive=KEEPALIVE,
keepalive=int(CLIENT_KEEPALIVE.total_seconds()),
protocol=aiomqtt.ProtocolVersion.V5,
tls_params=TLSParameters() if params.tls else None,
timeout=params.timeout,
Expand All @@ -210,9 +221,17 @@ async def subscribe(self, topic: str, callback: Callable[[bytes], None]) -> Call
The callback will be called with the message payload as a bytes object. The callback
should not block since it runs in the async loop. It should not raise any exceptions.

The returned callable unsubscribes from the topic when called.
The returned callable unsubscribes from the topic when called, but will delay actual
unsubscription for the idle timeout period. If a new subscription comes in during the
timeout, the timer is cancelled and the subscription is reused.
"""
_LOGGER.debug("Subscribing to topic %s", topic)

# If there is an idle timer for this topic, cancel it (reuse subscription)
if idle_timer := self._idle_timers.pop(topic, None):
idle_timer.cancel()
_LOGGER.debug("Cancelled idle timer for topic %s (reused subscription)", topic)

unsub = self._listeners.add_callback(topic, callback)

async with self._client_lock:
Expand All @@ -221,11 +240,41 @@ async def subscribe(self, topic: str, callback: Callable[[bytes], None]) -> Call
try:
await self._client.subscribe(topic)
except MqttError as err:
# Clean up the callback if subscription fails
unsub()
raise MqttSessionException(f"Error subscribing to topic: {err}") from err
else:
_LOGGER.debug("Client not connected, will establish subscription later")

return unsub
def schedule_unsubscribe():
async def idle_unsubscribe():
try:
await asyncio.sleep(self._topic_idle_timeout.total_seconds())
# Only unsubscribe if there are no callbacks left for this topic
if not self._listeners.get_callbacks(topic):
async with self._client_lock:
if self._client:
_LOGGER.debug("Idle timeout expired, unsubscribing from topic %s", topic)
try:
await self._client.unsubscribe(topic)
except MqttError as err:
_LOGGER.warning("Error unsubscribing from topic %s: %s", topic, err)
Comment on lines +252 to +261
Copy link

Copilot AI Dec 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A race condition exists in the idle timer cleanup logic. When the timer expires, it checks `if not self._listeners.get_callbacks(topic)` before acquiring `self._client_lock`. Because acquiring the lock is an await point, another task could add a new callback between this check and the actual `unsubscribe` call. This could lead to unsubscribing from a topic that has active callbacks.

To fix this, the check and the unsubscribe operation should be atomic with respect to the lock. Consider moving the `get_callbacks` check inside the `async with self._client_lock:` block, immediately before unsubscribing.

Copilot uses AI. Check for mistakes.
# Clean up timer from dict
self._idle_timers.pop(topic, None)
except asyncio.CancelledError:
_LOGGER.debug("Idle unsubscribe for topic %s cancelled", topic)

# Start the idle timer task
task = asyncio.create_task(idle_unsubscribe())
self._idle_timers[topic] = task

def delayed_unsub():
unsub() # Remove the callback from CallbackMap
# If no more callbacks for this topic, start idle timer
if not self._listeners.get_callbacks(topic):
schedule_unsubscribe()

return delayed_unsub

async def publish(self, topic: str, message: bytes) -> None:
"""Publish a message on the topic."""
Expand Down
157 changes: 122 additions & 35 deletions tests/mqtt/test_roborock_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import paho.mqtt.client as mqtt
import pytest

from roborock.mqtt.roborock_session import create_mqtt_session
from roborock.mqtt.roborock_session import RoborockMqttSession, create_mqtt_session
from roborock.mqtt.session import MqttParams, MqttSessionException
from tests import mqtt_packet
from tests.conftest import FakeSocketHandler
Expand Down Expand Up @@ -80,6 +80,23 @@ def fast_backoff_fixture() -> Generator[None, None, None]:
yield


@pytest.fixture
def mock_mqtt_client() -> Generator[AsyncMock, None, None]:
    """Yield a mock MQTT client while ``aiomqtt.Client`` is patched.

    The patch replaces ``roborock.mqtt.roborock_session.aiomqtt.Client`` with
    a factory whose return value acts as an async context manager: entering
    it hands back the yielded mock client.
    """
    client = AsyncMock()
    client.messages = FakeAsyncIterator()

    # Stand-in for the ``aiomqtt.Client`` class; calling it produces the
    # context-manager object configured below.
    client_factory = Mock()
    context_manager = client_factory.return_value
    context_manager.__aenter__ = AsyncMock(return_value=client)
    context_manager.__aexit__ = AsyncMock()

    with patch("roborock.mqtt.roborock_session.aiomqtt.Client", client_factory):
        yield client


@pytest.fixture
def push_response(response_queue: Queue, fake_socket_handler: FakeSocketHandler) -> Callable[[bytes], None]:
"""Fixtures to push messages."""
Expand Down Expand Up @@ -195,52 +212,34 @@ async def __anext__(self) -> None:
await asyncio.sleep(1)


async def test_publish_failure() -> None:
async def test_publish_failure(mock_mqtt_client: AsyncMock) -> None:
"""Test an MQTT error is received when publishing a message."""

mock_client = AsyncMock()
mock_client.messages = FakeAsyncIterator()

mock_aenter = AsyncMock()
mock_aenter.return_value = mock_client

with patch("roborock.mqtt.roborock_session.aiomqtt.Client.__aenter__", mock_aenter):
session = await create_mqtt_session(FAKE_PARAMS)
assert session.connected
session = await create_mqtt_session(FAKE_PARAMS)
assert session.connected

mock_client.publish.side_effect = aiomqtt.MqttError
mock_mqtt_client.publish.side_effect = aiomqtt.MqttError

with pytest.raises(MqttSessionException, match="Error publishing message"):
await session.publish("topic-1", message=b"payload")
with pytest.raises(MqttSessionException, match="Error publishing message"):
await session.publish("topic-1", message=b"payload")

await session.close()
await session.close()


async def test_subscribe_failure() -> None:
async def test_subscribe_failure(mock_mqtt_client: AsyncMock) -> None:
"""Test an MQTT error while subscribing."""

mock_client = AsyncMock()
mock_client.messages = FakeAsyncIterator()

mock_aenter = AsyncMock()
mock_aenter.return_value = mock_client

mock_shim = Mock()
mock_shim.return_value.__aenter__ = mock_aenter
mock_shim.return_value.__aexit__ = AsyncMock()

with patch("roborock.mqtt.roborock_session.aiomqtt.Client", mock_shim):
session = await create_mqtt_session(FAKE_PARAMS)
assert session.connected
session = await create_mqtt_session(FAKE_PARAMS)
assert session.connected

mock_client.subscribe.side_effect = aiomqtt.MqttError
mock_mqtt_client.subscribe.side_effect = aiomqtt.MqttError

subscriber1 = Subscriber()
with pytest.raises(MqttSessionException, match="Error subscribing to topic"):
await session.subscribe("topic-1", subscriber1.append)
subscriber1 = Subscriber()
with pytest.raises(MqttSessionException, match="Error subscribing to topic"):
await session.subscribe("topic-1", subscriber1.append)

assert not subscriber1.messages
await session.close()
assert not subscriber1.messages
await session.close()


async def test_restart(push_response: Callable[[bytes], None]) -> None:
Expand Down Expand Up @@ -279,3 +278,91 @@ async def test_restart(push_response: Callable[[bytes], None]) -> None:
assert subscriber.messages == [b"12345", b"67890"]

await session.close()


async def test_idle_timeout_resubscribe(mock_mqtt_client: AsyncMock) -> None:
    """Test that resubscribing before idle timeout cancels the unsubscribe.

    The idle timeout is kept short and the test waits *past* it after the
    resubscribe, so a timer that was not cancelled would get the chance to
    fire inside the observation window.
    """
    idle_timeout = datetime.timedelta(milliseconds=50)
    session = RoborockMqttSession(FAKE_PARAMS, topic_idle_timeout=idle_timeout)
    await session.start()
    try:
        assert session.connected

        topic = "test/topic"
        subscriber1 = Subscriber()
        unsub1 = await session.subscribe(topic, subscriber1.append)

        # Releasing the only callback starts the idle timer.
        unsub1()

        # Resubscribe before the idle timeout expires (should cancel timer).
        subscriber2 = Subscriber()
        await session.subscribe(topic, subscriber2.append)

        # Wait longer than the idle timeout so the original timer would
        # have expired by now had it been left running.
        await asyncio.sleep(idle_timeout.total_seconds() * 2)

        # unsubscribe should NOT have been called because we resubscribed.
        mock_mqtt_client.unsubscribe.assert_not_called()
    finally:
        # Always stop the session so a failed assertion does not leave
        # its background tasks running.
        await session.close()


async def test_idle_timeout_unsubscribe(mock_mqtt_client: AsyncMock) -> None:
    """Verify the client unsubscribes once the idle timeout has elapsed."""

    # A very short idle timeout keeps the test fast.
    short_timeout = datetime.timedelta(milliseconds=50)
    mqtt_session = RoborockMqttSession(FAKE_PARAMS, topic_idle_timeout=short_timeout)
    await mqtt_session.start()
    assert mqtt_session.connected

    topic = "test/topic"
    listener = Subscriber()
    release = await mqtt_session.subscribe(topic, listener.append)

    # Dropping the only callback starts the idle timer.
    release()

    # Sleep for the idle timeout plus a small buffer.
    await asyncio.sleep(0.1)

    # The timer has expired, so exactly one unsubscribe is expected.
    mock_mqtt_client.unsubscribe.assert_called_once_with(topic)

    await mqtt_session.close()


async def test_idle_timeout_multiple_callbacks(mock_mqtt_client: AsyncMock) -> None:
    """Verify unsubscribe only happens after the last callback is released."""

    # A very short idle timeout keeps the test fast.
    short_timeout = datetime.timedelta(milliseconds=50)
    mqtt_session = RoborockMqttSession(FAKE_PARAMS, topic_idle_timeout=short_timeout)
    await mqtt_session.start()
    assert mqtt_session.connected

    topic = "test/topic"
    first_listener = Subscriber()
    second_listener = Subscriber()

    release_first = await mqtt_session.subscribe(topic, first_listener.append)
    release_second = await mqtt_session.subscribe(topic, second_listener.append)

    # Releasing one callback must not start the timer: the other is still active.
    release_first()

    # Wait long enough that a wrongly started timer would have fired.
    await asyncio.sleep(0.1)

    # Still subscribed, because the second listener remains registered.
    mock_mqtt_client.unsubscribe.assert_not_called()

    # Releasing the last callback starts the idle timer.
    release_second()

    # Sleep for the idle timeout plus a small buffer.
    await asyncio.sleep(0.1)

    # The timer has expired, so exactly one unsubscribe is expected.
    mock_mqtt_client.unsubscribe.assert_called_once_with(topic)

    await mqtt_session.close()