Skip to content

Commit 509ff6a

Browse files
authored
feat: Update device manager and device to establish an MQTT subscription (#409)
* feat: Update device manager and device to establish an MQTT subscription * feat: Add test coverage to device modules * feat: Add test coverage for device manager close * feat: Update roborock/devices/mqtt_channel.py * feat: Apply suggestions from code review * feat: Add support for sending/recieving messages * feat: Simplify rpc handling and tests * feat: Gather tasks * feat: Add debug lines
1 parent 69114b2 commit 509ff6a

File tree

7 files changed

+541
-43
lines changed

7 files changed

+541
-43
lines changed

roborock/cli.py

Lines changed: 15 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
from pyshark.packet.packet import Packet # type: ignore
1313

1414
from roborock import RoborockException
15-
from roborock.containers import DeviceData, HomeDataProduct, LoginData
16-
from roborock.mqtt.roborock_session import create_mqtt_session
17-
from roborock.protocol import MessageParser, create_mqtt_params
15+
from roborock.containers import DeviceData, HomeData, HomeDataProduct, LoginData
16+
from roborock.devices.device_manager import create_device_manager, create_home_data_api
17+
from roborock.protocol import MessageParser
1818
from roborock.util import run_sync
1919
from roborock.version_1_apis.roborock_local_client_v1 import RoborockLocalClientV1
2020
from roborock.version_1_apis.roborock_mqtt_client_v1 import RoborockMqttClientV1
@@ -101,44 +101,25 @@ async def session(ctx, duration: int):
101101
context: RoborockContext = ctx.obj
102102
login_data = context.login_data()
103103

104-
# Discovery devices if not already available
105-
if not login_data.home_data:
106-
await _discover(ctx)
107-
login_data = context.login_data()
108-
if not login_data.home_data or not login_data.home_data.devices:
109-
raise RoborockException("Unable to discover devices")
110-
111-
all_devices = login_data.home_data.devices + login_data.home_data.received_devices
112-
click.echo(f"Discovered devices: {', '.join([device.name for device in all_devices])}")
113-
114-
rriot = login_data.user_data.rriot
115-
params = create_mqtt_params(rriot)
116-
117-
mqtt_session = await create_mqtt_session(params)
118-
click.echo("Starting MQTT session...")
119-
if not mqtt_session.connected:
120-
raise RoborockException("Failed to connect to MQTT broker")
104+
home_data_api = create_home_data_api(login_data.email, login_data.user_data)
121105

122-
def on_message(bytes: bytes):
123-
"""Callback function to handle incoming MQTT messages."""
124-
# Decode the first 20 bytes of the message for display
125-
bytes = bytes[:20]
106+
async def home_data_cache() -> HomeData:
107+
if login_data.home_data is None:
108+
login_data.home_data = await home_data_api()
109+
context.update(login_data)
110+
return login_data.home_data
126111

127-
click.echo(f"Received message: {bytes}...")
112+
# Create device manager
113+
device_manager = await create_device_manager(login_data.user_data, home_data_cache)
128114

129-
unsubs = []
130-
for device in all_devices:
131-
device_topic = f"rr/m/o/{rriot.u}/{params.username}/{device.duid}"
132-
unsub = await mqtt_session.subscribe(device_topic, on_message)
133-
unsubs.append(unsub)
115+
devices = await device_manager.get_devices()
116+
click.echo(f"Discovered devices: {', '.join([device.name for device in devices])}")
134117

135118
click.echo("MQTT session started. Listening for messages...")
136119
await asyncio.sleep(duration)
137120

138-
click.echo("Stopping MQTT session...")
139-
for unsub in unsubs:
140-
unsub()
141-
await mqtt_session.close()
121+
# Close the device manager (this will close all devices and MQTT session)
122+
await device_manager.close()
142123

143124

144125
async def _discover(ctx):

roborock/devices/device.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,13 @@
66

77
import enum
88
import logging
9+
from collections.abc import Callable
910
from functools import cached_property
1011

1112
from roborock.containers import HomeDataDevice, HomeDataProduct, UserData
13+
from roborock.roborock_message import RoborockMessage
14+
15+
from .mqtt_channel import MqttChannel
1216

1317
_LOGGER = logging.getLogger(__name__)
1418

@@ -29,11 +33,25 @@ class DeviceVersion(enum.StrEnum):
2933
class RoborockDevice:
3034
"""Unified Roborock device class with automatic connection setup."""
3135

32-
def __init__(self, user_data: UserData, device_info: HomeDataDevice, product_info: HomeDataProduct) -> None:
33-
"""Initialize the RoborockDevice with device info, user data, and capabilities."""
36+
def __init__(
37+
self,
38+
user_data: UserData,
39+
device_info: HomeDataDevice,
40+
product_info: HomeDataProduct,
41+
mqtt_channel: MqttChannel,
42+
) -> None:
43+
"""Initialize the RoborockDevice.
44+
45+
The device takes ownership of the MQTT channel for communication with the device.
46+
Use `connect()` to establish the connection, which will set up the MQTT channel
47+
for receiving messages from the device. Use `close()` to unsubscribe from the MQTT
48+
channel.
49+
"""
3450
self._user_data = user_data
3551
self._device_info = device_info
3652
self._product_info = product_info
53+
self._mqtt_channel = mqtt_channel
54+
self._unsub: Callable[[], None] | None = None
3755

3856
@property
3957
def duid(self) -> str:
@@ -63,3 +81,28 @@ def device_version(self) -> str:
6381
self._device_info.name,
6482
)
6583
return DeviceVersion.UNKNOWN
84+
85+
async def connect(self) -> None:
86+
"""Connect to the device using MQTT.
87+
88+
This method will set up the MQTT channel for communication with the device.
89+
"""
90+
if self._unsub:
91+
raise ValueError("Already connected to the device")
92+
self._unsub = await self._mqtt_channel.subscribe(self._on_mqtt_message)
93+
94+
async def close(self) -> None:
95+
"""Close the MQTT connection to the device.
96+
97+
This method will unsubscribe from the MQTT channel and clean up resources.
98+
"""
99+
if self._unsub:
100+
self._unsub()
101+
self._unsub = None
102+
103+
def _on_mqtt_message(self, message: RoborockMessage) -> None:
104+
"""Handle incoming MQTT messages from the device.
105+
106+
This method should be overridden in subclasses to handle specific device messages.
107+
"""
108+
_LOGGER.debug("Received message from device %s: %s", self.duid, message)

roborock/devices/device_manager.py

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Module for discovering Roborock devices."""
22

3+
import asyncio
34
import logging
45
from collections.abc import Awaitable, Callable
56

@@ -10,8 +11,13 @@
1011
UserData,
1112
)
1213
from roborock.devices.device import RoborockDevice
14+
from roborock.mqtt.roborock_session import create_mqtt_session
15+
from roborock.mqtt.session import MqttSession
16+
from roborock.protocol import create_mqtt_params
1317
from roborock.web_api import RoborockApiClient
1418

19+
from .mqtt_channel import MqttChannel
20+
1521
_LOGGER = logging.getLogger(__name__)
1622

1723
__all__ = [
@@ -34,21 +40,33 @@ def __init__(
3440
self,
3541
home_data_api: HomeDataApi,
3642
device_creator: DeviceCreator,
43+
mqtt_session: MqttSession,
3744
) -> None:
38-
"""Initialize the DeviceManager with user data and optional cache storage."""
45+
"""Initialize the DeviceManager with user data and optional cache storage.
46+
47+
This takes ownership of the MQTT session and will close it when the manager is closed.
48+
"""
3949
self._home_data_api = home_data_api
4050
self._device_creator = device_creator
4151
self._devices: dict[str, RoborockDevice] = {}
52+
self._mqtt_session = mqtt_session
4253

4354
async def discover_devices(self) -> list[RoborockDevice]:
4455
"""Discover all devices for the logged-in user."""
4556
home_data = await self._home_data_api()
4657
device_products = home_data.device_products
4758
_LOGGER.debug("Discovered %d devices %s", len(device_products), home_data)
4859

49-
self._devices = {
50-
duid: self._device_creator(device, product) for duid, (device, product) in device_products.items()
51-
}
60+
# These are connected serially to avoid overwhelming the MQTT broker
61+
new_devices = {}
62+
for duid, (device, product) in device_products.items():
63+
if duid in self._devices:
64+
continue
65+
new_device = self._device_creator(device, product)
66+
await new_device.connect()
67+
new_devices[duid] = new_device
68+
69+
self._devices.update(new_devices)
5270
return list(self._devices.values())
5371

5472
async def get_device(self, duid: str) -> RoborockDevice | None:
@@ -59,6 +77,13 @@ async def get_devices(self) -> list[RoborockDevice]:
5977
"""Get all discovered devices."""
6078
return list(self._devices.values())
6179

80+
async def close(self) -> None:
81+
"""Close all MQTT connections and clean up resources."""
82+
tasks = [device.close() for device in self._devices.values()]
83+
self._devices.clear()
84+
tasks.append(self._mqtt_session.close())
85+
await asyncio.gather(*tasks)
86+
6287

6388
def create_home_data_api(email: str, user_data: UserData) -> HomeDataApi:
6489
"""Create a home data API wrapper.
@@ -67,7 +92,9 @@ def create_home_data_api(email: str, user_data: UserData) -> HomeDataApi:
6792
home data for the user.
6893
"""
6994

70-
client = RoborockApiClient(email, user_data)
95+
# Note: This will auto discover the API base URL. This can be improved
96+
# by caching this next to `UserData` if needed to avoid unnecessary API calls.
97+
client = RoborockApiClient(email)
7198

7299
async def home_data_api() -> HomeData:
73100
return await client.get_home_data(user_data)
@@ -83,9 +110,13 @@ async def create_device_manager(user_data: UserData, home_data_api: HomeDataApi)
83110
include caching or other optimizations.
84111
"""
85112

113+
mqtt_params = create_mqtt_params(user_data.rriot)
114+
mqtt_session = await create_mqtt_session(mqtt_params)
115+
86116
def device_creator(device: HomeDataDevice, product: HomeDataProduct) -> RoborockDevice:
87-
return RoborockDevice(user_data, device, product)
117+
mqtt_channel = MqttChannel(mqtt_session, device.duid, device.local_key, user_data.rriot, mqtt_params)
118+
return RoborockDevice(user_data, device, product, mqtt_channel)
88119

89-
manager = DeviceManager(home_data_api, device_creator)
120+
manager = DeviceManager(home_data_api, device_creator, mqtt_session=mqtt_session)
90121
await manager.discover_devices()
91122
return manager

roborock/devices/mqtt_channel.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
"""Modules for communicating with specific Roborock devices over MQTT."""
2+
3+
import asyncio
4+
import logging
5+
from collections.abc import Callable
6+
from json import JSONDecodeError
7+
8+
from roborock.containers import RRiot
9+
from roborock.exceptions import RoborockException
10+
from roborock.mqtt.session import MqttParams, MqttSession
11+
from roborock.protocol import create_mqtt_decoder, create_mqtt_encoder
12+
from roborock.roborock_message import RoborockMessage
13+
14+
_LOGGER = logging.getLogger(__name__)
15+
16+
17+
class MqttChannel:
18+
"""Simple RPC-style channel for communicating with a device over MQTT.
19+
20+
Handles request/response correlation and timeouts, but leaves message
21+
format most parsing to higher-level components.
22+
"""
23+
24+
def __init__(self, mqtt_session: MqttSession, duid: str, local_key: str, rriot: RRiot, mqtt_params: MqttParams):
25+
self._mqtt_session = mqtt_session
26+
self._duid = duid
27+
self._local_key = local_key
28+
self._rriot = rriot
29+
self._mqtt_params = mqtt_params
30+
31+
# RPC support
32+
self._waiting_queue: dict[int, asyncio.Future[RoborockMessage]] = {}
33+
self._decoder = create_mqtt_decoder(local_key)
34+
self._encoder = create_mqtt_encoder(local_key)
35+
self._queue_lock = asyncio.Lock()
36+
37+
@property
38+
def _publish_topic(self) -> str:
39+
"""Topic to send commands to the device."""
40+
return f"rr/m/i/{self._rriot.u}/{self._mqtt_params.username}/{self._duid}"
41+
42+
@property
43+
def _subscribe_topic(self) -> str:
44+
"""Topic to receive responses from the device."""
45+
return f"rr/m/o/{self._rriot.u}/{self._mqtt_params.username}/{self._duid}"
46+
47+
async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]:
48+
"""Subscribe to the device's response topic.
49+
50+
The callback will be called with the message payload when a message is received.
51+
52+
All messages received will be processed through the provided callback, even
53+
those sent in response to the `send_command` command.
54+
55+
Returns a callable that can be used to unsubscribe from the topic.
56+
"""
57+
58+
def message_handler(payload: bytes) -> None:
59+
if not (messages := self._decoder(payload)):
60+
_LOGGER.warning("Failed to decode MQTT message: %s", payload)
61+
return
62+
for message in messages:
63+
_LOGGER.debug("Received message: %s", message)
64+
asyncio.create_task(self._resolve_future_with_lock(message))
65+
try:
66+
callback(message)
67+
except Exception as e:
68+
_LOGGER.exception("Uncaught error in message handler callback: %s", e)
69+
70+
return await self._mqtt_session.subscribe(self._subscribe_topic, message_handler)
71+
72+
async def _resolve_future_with_lock(self, message: RoborockMessage) -> None:
73+
"""Resolve waiting future with proper locking."""
74+
if (request_id := message.get_request_id()) is None:
75+
_LOGGER.debug("Received message with no request_id")
76+
return
77+
async with self._queue_lock:
78+
if (future := self._waiting_queue.pop(request_id, None)) is not None:
79+
future.set_result(message)
80+
else:
81+
_LOGGER.debug("Received message with no waiting handler: request_id=%s", request_id)
82+
83+
async def send_command(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage:
84+
"""Send a command message and wait for the response message.
85+
86+
Returns the raw response message - caller is responsible for parsing.
87+
"""
88+
try:
89+
if (request_id := message.get_request_id()) is None:
90+
raise RoborockException("Message must have a request_id for RPC calls")
91+
except (ValueError, JSONDecodeError) as err:
92+
_LOGGER.exception("Error getting request_id from message: %s", err)
93+
raise RoborockException(f"Invalid message format, Message must have a request_id: {err}") from err
94+
95+
future: asyncio.Future[RoborockMessage] = asyncio.Future()
96+
async with self._queue_lock:
97+
if request_id in self._waiting_queue:
98+
raise RoborockException(f"Request ID {request_id} already pending, cannot send command")
99+
self._waiting_queue[request_id] = future
100+
101+
try:
102+
encoded_msg = self._encoder(message)
103+
await self._mqtt_session.publish(self._publish_topic, encoded_msg)
104+
105+
return await asyncio.wait_for(future, timeout=timeout)
106+
107+
except asyncio.TimeoutError as ex:
108+
async with self._queue_lock:
109+
self._waiting_queue.pop(request_id, None)
110+
raise RoborockException(f"Command timed out after {timeout}s") from ex
111+
except Exception:
112+
logging.exception("Uncaught error sending command")
113+
async with self._queue_lock:
114+
self._waiting_queue.pop(request_id, None)
115+
raise

tests/devices/test_device.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
"""Tests for the Device class."""
2+
3+
from unittest.mock import AsyncMock, Mock
4+
5+
from roborock.containers import HomeData, UserData
6+
from roborock.devices.device import DeviceVersion, RoborockDevice
7+
8+
from .. import mock_data
9+
10+
USER_DATA = UserData.from_dict(mock_data.USER_DATA)
11+
HOME_DATA = HomeData.from_dict(mock_data.HOME_DATA_RAW)
12+
13+
14+
async def test_device_connection() -> None:
15+
"""Test the Device connection setup."""
16+
17+
unsub = Mock()
18+
subscribe = AsyncMock()
19+
subscribe.return_value = unsub
20+
mqtt_channel = AsyncMock()
21+
mqtt_channel.subscribe = subscribe
22+
23+
device = RoborockDevice(
24+
USER_DATA,
25+
device_info=HOME_DATA.devices[0],
26+
product_info=HOME_DATA.products[0],
27+
mqtt_channel=mqtt_channel,
28+
)
29+
assert device.duid == "abc123"
30+
assert device.name == "Roborock S7 MaxV"
31+
assert device.device_version == DeviceVersion.V1
32+
33+
assert not subscribe.called
34+
35+
await device.connect()
36+
assert subscribe.called
37+
assert not unsub.called
38+
39+
await device.close()
40+
assert unsub.called

0 commit comments

Comments
 (0)