Skip to content

Commit 1affee2

Browse files
authored
Merge branch 'main' into supported_features_markdown
2 parents 8bdd714 + 6c9b7ad commit 1affee2

File tree

13 files changed

+302
-72
lines changed

13 files changed

+302
-72
lines changed

poetry.lock

Lines changed: 72 additions & 44 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "python-roborock"
3-
version = "2.31.0"
3+
version = "2.33.0"
44
description = "A package to control Roborock vacuums."
55
authors = ["humbertogontijo <humbertogontijo@users.noreply.github.com>"]
66
license = "GPL-3.0-only"
@@ -39,7 +39,7 @@ requires = ["poetry-core==1.8.0"]
3939
build-backend = "poetry.core.masonry.api"
4040

4141
[tool.poetry.group.dev.dependencies]
42-
pytest-asyncio = "*"
42+
pytest-asyncio = ">=1.1.0"
4343
pytest = "*"
4444
pre-commit = ">=3.5,<5.0"
4545
mypy = "*"

roborock/cli.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from roborock import SHORT_MODEL_TO_ENUM, DeviceFeatures, RoborockCommand, RoborockException
1414
from roborock.containers import DeviceData, HomeData, HomeDataProduct, LoginData, NetworkInfo, RoborockBase, UserData
15+
from roborock.devices.cache import Cache, CacheData
1516
from roborock.devices.device_manager import create_device_manager, create_home_data_api
1617
from roborock.protocol import MessageParser
1718
from roborock.util import run_sync
@@ -39,7 +40,7 @@ class ConnectionCache(RoborockBase):
3940
network_info: dict[str, NetworkInfo] | None = None
4041

4142

42-
class RoborockContext:
43+
class RoborockContext(Cache):
4344
roborock_file = Path("~/.roborock").expanduser()
4445
_cache_data: ConnectionCache | None = None
4546

@@ -68,6 +69,18 @@ def cache_data(self) -> ConnectionCache:
6869
self.validate()
6970
return self._cache_data
7071

72+
async def get(self) -> CacheData:
73+
"""Get cached value."""
74+
connection_cache = self.cache_data()
75+
return CacheData(home_data=connection_cache.home_data, network_info=connection_cache.network_info or {})
76+
77+
async def set(self, value: CacheData) -> None:
78+
"""Set value in the cache."""
79+
connection_cache = self.cache_data()
80+
connection_cache.home_data = value.home_data
81+
connection_cache.network_info = value.network_info
82+
self.update(connection_cache)
83+
7184

7285
@click.option("-d", "--debug", default=False, count=True)
7386
@click.version_option(package_name="python-roborock")
@@ -119,14 +132,8 @@ async def session(ctx, duration: int):
119132

120133
home_data_api = create_home_data_api(cache_data.email, cache_data.user_data)
121134

122-
async def home_data_cache() -> HomeData:
123-
if cache_data.home_data is None:
124-
cache_data.home_data = await home_data_api()
125-
context.update(cache_data)
126-
return cache_data.home_data
127-
128135
# Create device manager
129-
device_manager = await create_device_manager(cache_data.user_data, home_data_cache)
136+
device_manager = await create_device_manager(cache_data.user_data, home_data_api, context)
130137

131138
devices = await device_manager.get_devices()
132139
click.echo(f"Discovered devices: {', '.join([device.name for device in devices])}")

roborock/containers.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,13 @@ def get_mop_mode_code(self, mop_mode: str) -> int:
425425
raise RoborockException("Attempted to get mop_mode before status has been updated.")
426426
return self.mop_mode.as_dict().get(mop_mode)
427427

428+
@property
429+
def current_map(self) -> int | None:
430+
"""Returns the current map ID if the map is present."""
431+
if self.map_status is not None:
432+
return (self.map_status - 3) // 4
433+
return None
434+
428435

429436
@dataclass
430437
class S4MaxStatus(Status):

roborock/devices/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@
33
__all__ = [
44
"device",
55
"device_manager",
6+
"cache",
67
]

roborock/devices/cache.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
"""This module provides caching functionality for the Roborock device management system.
2+
3+
This module defines a cache interface that you may use to cache device
4+
information to avoid unnecessary API calls. Callers may implement
5+
this interface to provide their own caching mechanism.
6+
"""
7+
8+
from dataclasses import dataclass, field
9+
from typing import Protocol
10+
11+
from roborock.containers import HomeData, NetworkInfo
12+
13+
14+
@dataclass
15+
class CacheData:
16+
"""Data structure for caching device information."""
17+
18+
home_data: HomeData | None = None
19+
"""Home data containing device and product information."""
20+
21+
network_info: dict[str, NetworkInfo] = field(default_factory=dict)
22+
"""Network information indexed by device DUID."""
23+
24+
25+
class Cache(Protocol):
26+
"""Protocol for a cache that can store and retrieve values."""
27+
28+
async def get(self) -> CacheData:
29+
"""Get cached value."""
30+
...
31+
32+
async def set(self, value: CacheData) -> None:
33+
"""Set value in the cache."""
34+
...
35+
36+
37+
class InMemoryCache(Cache):
38+
"""In-memory cache implementation."""
39+
40+
def __init__(self):
41+
self._data = CacheData()
42+
43+
async def get(self) -> CacheData:
44+
return self._data
45+
46+
async def set(self, value: CacheData) -> None:
47+
self._data = value
48+
49+
50+
class NoCache(Cache):
51+
"""No-op cache implementation."""
52+
53+
async def get(self) -> CacheData:
54+
return CacheData()
55+
56+
async def set(self, value: CacheData) -> None:
57+
pass

roborock/devices/device_manager.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from roborock.protocol import create_mqtt_params
1919
from roborock.web_api import RoborockApiClient
2020

21+
from .cache import Cache, NoCache
2122
from .channel import Channel
2223
from .mqtt_channel import create_mqtt_channel
2324
from .traits.dyad import DyadApi
@@ -32,8 +33,6 @@
3233
"create_device_manager",
3334
"create_home_data_api",
3435
"DeviceManager",
35-
"HomeDataApi",
36-
"DeviceCreator",
3736
]
3837

3938

@@ -57,19 +56,27 @@ def __init__(
5756
home_data_api: HomeDataApi,
5857
device_creator: DeviceCreator,
5958
mqtt_session: MqttSession,
59+
cache: Cache,
6060
) -> None:
6161
"""Initialize the DeviceManager with user data and optional cache storage.
6262
6363
This takes ownership of the MQTT session and will close it when the manager is closed.
6464
"""
6565
self._home_data_api = home_data_api
66+
self._cache = cache
6667
self._device_creator = device_creator
6768
self._devices: dict[str, RoborockDevice] = {}
6869
self._mqtt_session = mqtt_session
6970

7071
async def discover_devices(self) -> list[RoborockDevice]:
7172
"""Discover all devices for the logged-in user."""
72-
home_data = await self._home_data_api()
73+
cache_data = await self._cache.get()
74+
if not cache_data.home_data:
75+
_LOGGER.debug("No cached home data found, fetching from API")
76+
cache_data.home_data = await self._home_data_api()
77+
await self._cache.set(cache_data)
78+
home_data = cache_data.home_data
79+
7380
device_products = home_data.device_products
7481
_LOGGER.debug("Discovered %d devices %s", len(device_products), home_data)
7582

@@ -118,13 +125,19 @@ async def home_data_api() -> HomeData:
118125
return home_data_api
119126

120127

121-
async def create_device_manager(user_data: UserData, home_data_api: HomeDataApi) -> DeviceManager:
128+
async def create_device_manager(
129+
user_data: UserData,
130+
home_data_api: HomeDataApi,
131+
cache: Cache | None = None,
132+
) -> DeviceManager:
122133
"""Convenience function to create and initialize a DeviceManager.
123134
124135
The Home Data is fetched using the provided home_data_api callable which
125136
is exposed this way to allow for swapping out other implementations to
126137
include caching or other optimizations.
127138
"""
139+
if cache is None:
140+
cache = NoCache()
128141

129142
mqtt_params = create_mqtt_params(user_data.rriot)
130143
mqtt_session = await create_mqtt_session(mqtt_params)
@@ -135,7 +148,7 @@ def device_creator(device: HomeDataDevice, product: HomeDataProduct) -> Roborock
135148
# TODO: Define a registration mechanism/factory for v1 traits
136149
match device.pv:
137150
case DeviceVersion.V1:
138-
channel = create_v1_channel(user_data, mqtt_params, mqtt_session, device)
151+
channel = create_v1_channel(user_data, mqtt_params, mqtt_session, device, cache)
139152
traits.append(StatusTrait(product, channel.rpc_channel))
140153
case DeviceVersion.A01:
141154
mqtt_channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device)
@@ -150,6 +163,6 @@ def device_creator(device: HomeDataDevice, product: HomeDataProduct) -> Roborock
150163
raise NotImplementedError(f"Device {device.name} has unsupported version {device.pv}")
151164
return RoborockDevice(device, channel, traits)
152165

153-
manager = DeviceManager(home_data_api, device_creator, mqtt_session=mqtt_session)
166+
manager = DeviceManager(home_data_api, device_creator, mqtt_session=mqtt_session, cache=cache)
154167
await manager.discover_devices()
155168
return manager

roborock/devices/v1_channel.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from roborock.roborock_message import RoborockMessage
1919
from roborock.roborock_typing import RoborockCommand
2020

21+
from .cache import Cache
2122
from .channel import Channel
2223
from .local_channel import LocalChannel, LocalSession, create_local_session
2324
from .mqtt_channel import MqttChannel
@@ -46,6 +47,7 @@ def __init__(
4647
security_data: SecurityData,
4748
mqtt_channel: MqttChannel,
4849
local_session: LocalSession,
50+
cache: Cache,
4951
) -> None:
5052
"""Initialize the V1Channel.
5153
@@ -62,7 +64,7 @@ def __init__(
6264
self._mqtt_unsub: Callable[[], None] | None = None
6365
self._local_unsub: Callable[[], None] | None = None
6466
self._callback: Callable[[RoborockMessage], None] | None = None
65-
self._networking_info: NetworkInfo | None = None
67+
self._cache = cache
6668

6769
@property
6870
def is_connected(self) -> bool:
@@ -131,19 +133,26 @@ async def _get_networking_info(self) -> NetworkInfo:
131133
132134
This is a cloud only command used to get the local device's IP address.
133135
"""
136+
cache_data = await self._cache.get()
137+
if cache_data.network_info and (network_info := cache_data.network_info.get(self._device_uid)):
138+
_LOGGER.debug("Using cached network info for device %s", self._device_uid)
139+
return network_info
134140
try:
135-
return await self._mqtt_rpc_channel.send_command(
141+
network_info = await self._mqtt_rpc_channel.send_command(
136142
RoborockCommand.GET_NETWORK_INFO, response_type=NetworkInfo
137143
)
138144
except RoborockException as e:
139145
raise RoborockException(f"Network info failed for device {self._device_uid}") from e
146+
_LOGGER.debug("Network info for device %s: %s", self._device_uid, network_info)
147+
cache_data.network_info[self._device_uid] = network_info
148+
await self._cache.set(cache_data)
149+
return network_info
140150

141151
async def _local_connect(self) -> Callable[[], None]:
142152
"""Set up local connection if possible."""
143153
_LOGGER.debug("Attempting to connect to local channel for device %s", self._device_uid)
144-
if self._networking_info is None:
145-
self._networking_info = await self._get_networking_info()
146-
host = self._networking_info.ip
154+
networking_info = await self._get_networking_info()
155+
host = networking_info.ip
147156
_LOGGER.debug("Connecting to local channel at %s", host)
148157
self._local_channel = self._local_session(host)
149158
try:
@@ -168,10 +177,14 @@ def _on_local_message(self, message: RoborockMessage) -> None:
168177

169178

170179
def create_v1_channel(
171-
user_data: UserData, mqtt_params: MqttParams, mqtt_session: MqttSession, device: HomeDataDevice
180+
user_data: UserData,
181+
mqtt_params: MqttParams,
182+
mqtt_session: MqttSession,
183+
device: HomeDataDevice,
184+
cache: Cache,
172185
) -> V1Channel:
173186
"""Create a V1Channel for the given device."""
174187
security_data = create_security_data(user_data.rriot)
175188
mqtt_channel = MqttChannel(mqtt_session, device.duid, device.local_key, user_data.rriot, mqtt_params)
176189
local_session = create_local_session(device.local_key)
177-
return V1Channel(device.duid, security_data, mqtt_channel, local_session=local_session)
190+
return V1Channel(device.duid, security_data, mqtt_channel, local_session=local_session, cache=cache)

roborock/mqtt/roborock_session.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,14 +125,14 @@ async def _run_task(self, start_future: asyncio.Future[None] | None) -> None:
125125
except Exception as err:
126126
# This error is thrown when the MQTT loop is cancelled
127127
# and the generator is not stopped.
128-
if "generator didn't stop" in str(err):
128+
if "generator didn't stop" in str(err) or "generator didn't yield" in str(err):
129129
_LOGGER.debug("MQTT loop was cancelled")
130130
return
131131
if start_future:
132132
_LOGGER.error("Uncaught error starting MQTT session: %s", err)
133133
start_future.set_exception(err)
134134
return
135-
_LOGGER.error("Uncaught error during MQTT session: %s", err)
135+
_LOGGER.exception("Uncaught error during MQTT session: %s", err)
136136

137137
self._healthy = False
138138
_LOGGER.info("MQTT session disconnected, retrying in %s seconds", self._backoff.total_seconds())
@@ -180,7 +180,7 @@ async def _process_message_loop(self, client: aiomqtt.Client) -> None:
180180
except asyncio.CancelledError:
181181
raise
182182
except Exception as e:
183-
_LOGGER.error("Uncaught exception in subscriber callback: %s", e)
183+
_LOGGER.exception("Uncaught exception in subscriber callback: %s", e)
184184

185185
async def subscribe(self, topic: str, callback: Callable[[bytes], None]) -> Callable[[], None]:
186186
"""Subscribe to messages on the specified topic and invoke the callback for new messages.

tests/devices/test_device_manager.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pytest
77

88
from roborock.containers import HomeData, UserData
9+
from roborock.devices.cache import CacheData, InMemoryCache
910
from roborock.devices.device_manager import create_device_manager, create_home_data_api
1011
from roborock.exceptions import RoborockException
1112

@@ -98,3 +99,37 @@ async def test_create_home_data_api_exception() -> None:
9899

99100
with pytest.raises(RoborockException, match="Test exception"):
100101
await api()
102+
103+
104+
async def test_cache_logic() -> None:
105+
"""Test that the cache logic works correctly."""
106+
call_count = 0
107+
108+
async def mock_home_data_with_counter() -> HomeData:
109+
nonlocal call_count
110+
call_count += 1
111+
return HomeData.from_dict(mock_data.HOME_DATA_RAW)
112+
113+
class TestCache:
114+
def __init__(self):
115+
self._data = CacheData()
116+
117+
async def get(self) -> CacheData:
118+
return self._data
119+
120+
async def set(self, value: CacheData) -> None:
121+
self._data = value
122+
123+
# First call happens during create_device_manager initialization
124+
device_manager = await create_device_manager(USER_DATA, mock_home_data_with_counter, cache=InMemoryCache())
125+
assert call_count == 1
126+
127+
# Second call should use cache, not increment call_count
128+
devices2 = await device_manager.discover_devices()
129+
assert call_count == 1 # Should still be 1, not 2
130+
assert len(devices2) == 1
131+
132+
await device_manager.close()
133+
assert len(devices2) == 1
134+
135+
await device_manager.close()

0 commit comments

Comments
 (0)