Skip to content

Commit e640e47

Browse files
committed
feat: Update device manager and device to establish an MQTT subscription
1 parent 54547d8 commit e640e47

File tree

4 files changed

+132
-43
lines changed

4 files changed

+132
-43
lines changed

roborock/cli.py

Lines changed: 15 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
from pyshark.packet.packet import Packet # type: ignore
1313

1414
from roborock import RoborockException
15-
from roborock.containers import DeviceData, HomeDataProduct, LoginData
16-
from roborock.mqtt.roborock_session import create_mqtt_session
17-
from roborock.protocol import MessageParser, create_mqtt_params
15+
from roborock.containers import DeviceData, HomeData, HomeDataProduct, LoginData
16+
from roborock.devices.device_manager import create_device_manager, create_home_data_api
17+
from roborock.protocol import MessageParser
1818
from roborock.util import run_sync
1919
from roborock.version_1_apis.roborock_local_client_v1 import RoborockLocalClientV1
2020
from roborock.version_1_apis.roborock_mqtt_client_v1 import RoborockMqttClientV1
@@ -101,44 +101,25 @@ async def session(ctx, duration: int):
101101
context: RoborockContext = ctx.obj
102102
login_data = context.login_data()
103103

104-
# Discovery devices if not already available
105-
if not login_data.home_data:
106-
await _discover(ctx)
107-
login_data = context.login_data()
108-
if not login_data.home_data or not login_data.home_data.devices:
109-
raise RoborockException("Unable to discover devices")
110-
111-
all_devices = login_data.home_data.devices + login_data.home_data.received_devices
112-
click.echo(f"Discovered devices: {', '.join([device.name for device in all_devices])}")
113-
114-
rriot = login_data.user_data.rriot
115-
params = create_mqtt_params(rriot)
116-
117-
mqtt_session = await create_mqtt_session(params)
118-
click.echo("Starting MQTT session...")
119-
if not mqtt_session.connected:
120-
raise RoborockException("Failed to connect to MQTT broker")
104+
home_data_api = create_home_data_api(login_data.email, login_data.user_data)
121105

122-
def on_message(bytes: bytes):
123-
"""Callback function to handle incoming MQTT messages."""
124-
# Decode the first 20 bytes of the message for display
125-
bytes = bytes[:20]
106+
async def home_data_cache() -> HomeData:
107+
if login_data.home_data is None:
108+
login_data.home_data = await home_data_api()
109+
context.update(login_data)
110+
return login_data.home_data
126111

127-
click.echo(f"Received message: {bytes}...")
112+
# Create device manager
113+
device_manager = await create_device_manager(login_data.user_data, home_data_cache)
128114

129-
unsubs = []
130-
for device in all_devices:
131-
device_topic = f"rr/m/o/{rriot.u}/{params.username}/{device.duid}"
132-
unsub = await mqtt_session.subscribe(device_topic, on_message)
133-
unsubs.append(unsub)
115+
devices = await device_manager.get_devices()
116+
click.echo(f"Discovered devices: {', '.join([device.name for device in devices])}")
134117

135118
click.echo("MQTT session started. Listening for messages...")
136119
await asyncio.sleep(duration)
137120

138-
click.echo("Stopping MQTT session...")
139-
for unsub in unsubs:
140-
unsub()
141-
await mqtt_session.close()
121+
# Close the device manager (this will close all devices and MQTT session)
122+
await device_manager.close()
142123

143124

144125
async def _discover(ctx):

roborock/devices/device.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
from roborock.containers import HomeDataDevice, HomeDataProduct, UserData
1212

13+
from .mqtt_channel import MqttChannel
14+
1315
_LOGGER = logging.getLogger(__name__)
1416

1517
__all__ = [
@@ -29,11 +31,22 @@ class DeviceVersion(enum.StrEnum):
2931
class RoborockDevice:
3032
"""Unified Roborock device class with automatic connection setup."""
3133

32-
def __init__(self, user_data: UserData, device_info: HomeDataDevice, product_info: HomeDataProduct) -> None:
33-
"""Initialize the RoborockDevice with device info, user data, and capabilities."""
34+
def __init__(
35+
self,
36+
user_data: UserData,
37+
device_info: HomeDataDevice,
38+
product_info: HomeDataProduct,
39+
mqtt_channel: MqttChannel,
40+
) -> None:
41+
"""Initialize the RoborockDevice.
42+
43+
The device takes ownership of the MQTT channel for communication with the device
44+
and will close it when the device is closed.
45+
"""
3446
self._user_data = user_data
3547
self._device_info = device_info
3648
self._product_info = product_info
49+
self._mqtt_channel = mqtt_channel
3750

3851
@property
3952
def duid(self) -> str:
@@ -63,3 +76,24 @@ def device_version(self) -> str:
6376
self._device_info.name,
6477
)
6578
return DeviceVersion.UNKNOWN
79+
80+
async def connect(self) -> None:
81+
"""Connect to the device using MQTT.
82+
83+
This method will set up the MQTT channel for communication with the device.
84+
"""
85+
await self._mqtt_channel.subscribe(self._on_mqtt_message)
86+
87+
async def close(self) -> None:
88+
"""Close the MQTT connection to the device.
89+
90+
This method will unsubscribe from the MQTT channel and clean up resources.
91+
"""
92+
await self._mqtt_channel.close()
93+
94+
def _on_mqtt_message(self, message: bytes) -> None:
95+
"""Handle incoming MQTT messages from the device.
96+
97+
This method should be overridden in subclasses to handle specific device messages.
98+
"""
99+
_LOGGER.debug("Received message from device %s: %s", self.duid, message[:50]) # Log first 50 bytes for brevity

roborock/devices/device_manager.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,13 @@
1010
UserData,
1111
)
1212
from roborock.devices.device import RoborockDevice
13+
from roborock.mqtt.roborock_session import create_mqtt_session
14+
from roborock.mqtt.session import MqttSession
15+
from roborock.protocol import create_mqtt_params
1316
from roborock.web_api import RoborockApiClient
1417

18+
from .mqtt_channel import MqttChannel
19+
1520
_LOGGER = logging.getLogger(__name__)
1621

1722
__all__ = [
@@ -34,21 +39,32 @@ def __init__(
3439
self,
3540
home_data_api: HomeDataApi,
3641
device_creator: DeviceCreator,
42+
mqtt_session: MqttSession,
3743
) -> None:
38-
"""Initialize the DeviceManager with user data and optional cache storage."""
44+
"""Initialize the DeviceManager with user data and optional cache storage.
45+
46+
This takes ownership of the MQTT session and will close it when the manager is closed.
47+
"""
3948
self._home_data_api = home_data_api
4049
self._device_creator = device_creator
4150
self._devices: dict[str, RoborockDevice] = {}
51+
self._mqtt_session = mqtt_session
4252

4353
async def discover_devices(self) -> list[RoborockDevice]:
4454
"""Discover all devices for the logged-in user."""
4555
home_data = await self._home_data_api()
4656
device_products = home_data.device_products
4757
_LOGGER.debug("Discovered %d devices %s", len(device_products), home_data)
4858

49-
self._devices = {
50-
duid: self._device_creator(device, product) for duid, (device, product) in device_products.items()
51-
}
59+
new_devices = {}
60+
for duid, (device, product) in device_products.items():
61+
if duid in self._devices:
62+
continue
63+
new_device = self._device_creator(device, product)
64+
await new_device.connect()
65+
new_devices[duid] = new_device
66+
67+
self._devices.update(new_devices)
5268
return list(self._devices.values())
5369

5470
async def get_device(self, duid: str) -> RoborockDevice | None:
@@ -59,6 +75,14 @@ async def get_devices(self) -> list[RoborockDevice]:
5975
"""Get all discovered devices."""
6076
return list(self._devices.values())
6177

78+
async def close(self) -> None:
79+
"""Close all MQTT connections and clean up resources."""
80+
for device in self._devices.values():
81+
await device.close()
82+
self._devices.clear()
83+
if self._mqtt_session:
84+
await self._mqtt_session.close()
85+
6286

6387
def create_home_data_api(email: str, user_data: UserData) -> HomeDataApi:
6488
"""Create a home data API wrapper.
@@ -67,7 +91,9 @@ def create_home_data_api(email: str, user_data: UserData) -> HomeDataApi:
6791
home data for the user.
6892
"""
6993

70-
client = RoborockApiClient(email, user_data)
94+
# Note: This will auto discover the API base URL. This can be improved
95+
# by caching this next to `UserData` if needed to avoid unnecessary API calls.
96+
client = RoborockApiClient(email)
7197

7298
async def home_data_api() -> HomeData:
7399
return await client.get_home_data(user_data)
@@ -83,9 +109,13 @@ async def create_device_manager(user_data: UserData, home_data_api: HomeDataApi)
83109
include caching or other optimizations.
84110
"""
85111

112+
mqtt_params = create_mqtt_params(user_data.rriot)
113+
mqtt_session = await create_mqtt_session(mqtt_params)
114+
86115
def device_creator(device: HomeDataDevice, product: HomeDataProduct) -> RoborockDevice:
87-
return RoborockDevice(user_data, device, product)
116+
mqtt_channel = MqttChannel(mqtt_session, device.duid, user_data.rriot, mqtt_params)
117+
return RoborockDevice(user_data, device, product, mqtt_channel)
88118

89-
manager = DeviceManager(home_data_api, device_creator)
119+
manager = DeviceManager(home_data_api, device_creator, mqtt_session=mqtt_session)
90120
await manager.discover_devices()
91121
return manager

roborock/devices/mqtt_channel.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import logging
2+
from collections.abc import Callable
3+
4+
from roborock.containers import RRiot
5+
from roborock.mqtt.session import MqttParams, MqttSession
6+
7+
_LOGGER = logging.getLogger(__name__)
8+
9+
10+
class MqttChannel:
11+
"""RPC-style channel for communicating with a specific device over MQTT.
12+
13+
This currently only supports listening to messages and does not yet
14+
support RPC functionality.
15+
"""
16+
17+
def __init__(self, mqtt_session: MqttSession, duid: str, rriot: RRiot, mqtt_params: MqttParams):
18+
self._mqtt_session = mqtt_session
19+
self._duid = duid
20+
self._rriot = rriot
21+
self._mqtt_params = mqtt_params
22+
self._unsub: Callable[[], None] | None = None
23+
24+
@property
25+
def _publish_topic(self) -> str:
26+
"""Topic to send commands to the device."""
27+
return f"rr/m/i/{self._rriot.u}/{self._mqtt_params.username}/{self._duid}"
28+
29+
@property
30+
def _subscribe_topic(self) -> str:
31+
"""Topic to receive responses from the device."""
32+
return f"rr/m/o/{self._rriot.u}/{self._mqtt_params.username}/{self._duid}"
33+
34+
async def subscribe(self, callback: Callable[[bytes], None]) -> None:
35+
"""Subscribe to the device's response topic."""
36+
if self._unsub:
37+
raise ValueError("Already subscribed to the response topic")
38+
self._unsub = await self._mqtt_session.subscribe(self._subscribe_topic, callback)
39+
40+
async def close(self) -> None:
41+
"""Close the MQTT subscription."""
42+
if self._unsub:
43+
self._unsub()
44+
self._unsub = None

0 commit comments

Comments
 (0)