Skip to content

Commit c97ee1d

Browse files
committed
chore: refactor v1 rpc channels
This merges the two packages v1_channel and v1_rpc_channel since they are very closely related. The differences between the RPC types are now handled with `RpcStrategy` (encoding, decoding, which channel, health checking, etc). The "PayloadEncodedV1RpcChannel" is now `_send_rpc` which accepts the rpc strategy. The `V1RpcChannel` interface is moved to the `v1_protocol` module and now has a single implementation that handles everyting (the `PickFirstAvailable` logic as well as the response parsing). Overall the code is now less generalized, but is probably easier to understand since its all in a single place. Notably, all the rpc code was already tested via the v1_channel interface.
1 parent 80d7d5a commit c97ee1d

File tree

8 files changed

+256
-261
lines changed

8 files changed

+256
-261
lines changed

roborock/devices/mqtt_channel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,5 +90,5 @@ async def restart(self) -> None:
9090
def create_mqtt_channel(
9191
user_data: UserData, mqtt_params: MqttParams, mqtt_session: MqttSession, device: HomeDataDevice
9292
) -> MqttChannel:
93-
"""Create a V1Channel for the given device."""
93+
"""Create a MQTT channel for the given device."""
9494
return MqttChannel(mqtt_session, device.duid, device.local_key, user_data.rriot, mqtt_params)

roborock/devices/traits/v1/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@
3838
from roborock.data.v1.v1_code_mappings import RoborockDockTypeCode
3939
from roborock.devices.cache import Cache
4040
from roborock.devices.traits import Trait
41-
from roborock.devices.v1_rpc_channel import V1RpcChannel
4241
from roborock.map.map_parser import MapParserConfig
42+
from roborock.protocols.v1_protocol import V1RpcChannel
4343
from roborock.web_api import UserWebApiClient
4444

4545
from .child_lock import ChildLockTrait

roborock/devices/traits/v1/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from typing import ClassVar, Self
1010

1111
from roborock.data import RoborockBase
12-
from roborock.devices.v1_rpc_channel import V1RpcChannel
12+
from roborock.protocols.v1_protocol import V1RpcChannel
1313
from roborock.roborock_typing import RoborockCommand
1414

1515
_LOGGER = logging.getLogger(__name__)

roborock/devices/v1_channel.py

Lines changed: 195 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,37 +8,43 @@
88
import datetime
99
import logging
1010
from collections.abc import Callable
11-
from typing import TypeVar
11+
from dataclasses import dataclass
12+
from typing import Any, TypeVar, override
1213

1314
from roborock.data import HomeDataDevice, NetworkInfo, RoborockBase, UserData
1415
from roborock.exceptions import RoborockException
16+
from roborock.mqtt.health_manager import HealthManager
1517
from roborock.mqtt.session import MqttParams, MqttSession
1618
from roborock.protocols.v1_protocol import (
19+
CommandType,
20+
MapResponse,
21+
ParamsType,
22+
RequestMessage,
23+
ResponseData,
24+
ResponseMessage,
1725
SecurityData,
26+
V1RpcChannel,
27+
create_map_response_decoder,
1828
create_security_data,
29+
decode_rpc_response,
1930
)
20-
from roborock.roborock_message import RoborockMessage
31+
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
2132
from roborock.roborock_typing import RoborockCommand
2233

2334
from .cache import Cache
2435
from .channel import Channel
2536
from .local_channel import LocalChannel, LocalSession, create_local_session
2637
from .mqtt_channel import MqttChannel
27-
from .v1_rpc_channel import (
28-
PickFirstAvailable,
29-
V1RpcChannel,
30-
create_local_rpc_channel,
31-
create_map_rpc_channel,
32-
create_mqtt_rpc_channel,
33-
)
3438

3539
_LOGGER = logging.getLogger(__name__)
3640

3741
__all__ = [
38-
"V1Channel",
42+
"create_v1_channel",
3943
]
4044

4145
_T = TypeVar("_T", bound=RoborockBase)
46+
_TIMEOUT = 10.0
47+
4248

4349
# Exponential backoff parameters for reconnecting to local
4450
MIN_RECONNECT_INTERVAL = datetime.timedelta(minutes=1)
@@ -50,6 +56,126 @@
5056
LOCAL_CONNECTION_CHECK_INTERVAL = datetime.timedelta(seconds=15)
5157

5258

59+
@dataclass(frozen=True)
60+
class RpcStrategy:
61+
"""Strategy for sending RPC commands over a specific channel.
62+
63+
This holds the configuration for a specific transport method that differ
64+
in how messages are encoded/decoded and which channel is used.
65+
"""
66+
67+
name: str
68+
"""Name of the strategy for logging purposes."""
69+
70+
channel: LocalChannel | MqttChannel
71+
"""Channel to use for communication."""
72+
73+
encoder: Callable[[RequestMessage], RoborockMessage]
74+
"""Function to encode request messages for the channel."""
75+
76+
decoder: Callable[[RoborockMessage], ResponseMessage | MapResponse | None]
77+
"""Function to decode response messages from the channel."""
78+
79+
health_manager: HealthManager | None = None
80+
"""Optional health manager for monitoring the channel."""
81+
82+
83+
async def _send_rpc(strategy: RpcStrategy, request: RequestMessage) -> ResponseData | bytes:
84+
"""Send a command and return a parsed response RoborockBase type.
85+
86+
This provides an RPC interface over a given channel strategy. The device
87+
channel only supports publish and subscribe, so this function handles
88+
associating requests with their corresponding responses.
89+
90+
The provided RpcStrategy defines how to encode/decode messages and which
91+
channel to use for communication.
92+
"""
93+
future: asyncio.Future[ResponseData | bytes] = asyncio.Future()
94+
_LOGGER.debug(
95+
"Sending command (%s, request_id=%s): %s, params=%s",
96+
strategy.name,
97+
request.request_id,
98+
request.method,
99+
request.params,
100+
)
101+
102+
message = strategy.encoder(request)
103+
104+
def find_response(response_message: RoborockMessage) -> None:
105+
try:
106+
decoded = strategy.decoder(response_message)
107+
except RoborockException as ex:
108+
_LOGGER.debug("Exception while decoding message (%s): %s", response_message, ex)
109+
return
110+
if decoded is None:
111+
return
112+
_LOGGER.debug("Received response (%s, request_id=%s)", strategy.name, decoded.request_id)
113+
if decoded.request_id == request.request_id:
114+
if isinstance(decoded, ResponseMessage) and decoded.api_error:
115+
future.set_exception(decoded.api_error)
116+
else:
117+
future.set_result(decoded.data)
118+
119+
unsub = await strategy.channel.subscribe(find_response)
120+
try:
121+
await strategy.channel.publish(message)
122+
result = await asyncio.wait_for(future, timeout=_TIMEOUT)
123+
except TimeoutError as ex:
124+
if strategy.health_manager:
125+
await strategy.health_manager.on_timeout()
126+
future.cancel()
127+
raise RoborockException(f"Command timed out after {_TIMEOUT}s") from ex
128+
finally:
129+
unsub()
130+
if strategy.health_manager:
131+
await strategy.health_manager.on_success()
132+
return result
133+
134+
135+
class RpcChannel(V1RpcChannel):
136+
"""Wrapper to expose V1RpcChannel interface with a specific set of RpcStrategies.
137+
138+
This is used to provide a simpler interface to v1 traits for sending commands
139+
over multiple possible transports (local, MQTT) with automatic fallback.
140+
"""
141+
142+
def __init__(self, rpc_strategies: list[RpcStrategy]) -> None:
143+
self._rpc_strategies = rpc_strategies
144+
145+
@override
146+
async def send_command(
147+
self,
148+
method: CommandType,
149+
*,
150+
response_type: type[_T] | None = None,
151+
params: ParamsType = None,
152+
) -> _T | Any:
153+
"""Send a command and return either a decoded or parsed response."""
154+
request = RequestMessage(method, params=params)
155+
156+
# Try each channel in order until one succeeds
157+
last_exception = None
158+
for strategy in self._rpc_strategies:
159+
try:
160+
decoded_response = await _send_rpc(strategy, request)
161+
except RoborockException as e:
162+
_LOGGER.warning("Command %s failed on %s channel: %s", method, strategy.name, e)
163+
last_exception = e
164+
except Exception as e:
165+
_LOGGER.exception("Unexpected error sending command %s on %s channel", method, strategy.name)
166+
last_exception = RoborockException(f"Unexpected error: {e}")
167+
else:
168+
if response_type is not None:
169+
if not isinstance(decoded_response, dict):
170+
raise RoborockException(
171+
f"Expected dict response to parse {response_type.__name__}, got {type(decoded_response)}"
172+
)
173+
return response_type.from_dict(decoded_response)
174+
return decoded_response
175+
176+
raise last_exception or RoborockException("No available connection to send command")
177+
178+
53179
class V1Channel(Channel):
54180
"""Unified V1 protocol channel with automatic MQTT/local connection handling.
55181
@@ -69,20 +195,17 @@ def __init__(
69195
"""Initialize the V1Channel.
70196
71197
Args:
198+
device_uid: Unique device identifier (DUID).
72199
mqtt_channel: MQTT channel for cloud communication
73200
local_session: Factory that creates LocalChannels for a hostname.
201+
cache: Cache for storing network information.
74202
"""
75203
self._device_uid = device_uid
204+
self._security_data = security_data
76205
self._mqtt_channel = mqtt_channel
77-
self._mqtt_rpc_channel = create_mqtt_rpc_channel(mqtt_channel, security_data)
206+
self._mqtt_health_manager = HealthManager(self._mqtt_channel.restart)
78207
self._local_session = local_session
79208
self._local_channel: LocalChannel | None = None
80-
self._local_rpc_channel: V1RpcChannel | None = None
81-
# Prefer local, fallback to MQTT
82-
self._combined_rpc_channel = PickFirstAvailable(
83-
[lambda: self._local_rpc_channel, lambda: self._mqtt_rpc_channel]
84-
)
85-
self._map_rpc_channel = create_map_rpc_channel(mqtt_channel, security_data)
86209
self._mqtt_unsub: Callable[[], None] | None = None
87210
self._local_unsub: Callable[[], None] | None = None
88211
self._callback: Callable[[RoborockMessage], None] | None = None
@@ -108,17 +231,67 @@ def is_mqtt_connected(self) -> bool:
108231
@property
109232
def rpc_channel(self) -> V1RpcChannel:
110233
"""Return the combined RPC channel prefers local with a fallback to MQTT."""
111-
return self._combined_rpc_channel
234+
strategies = []
235+
if local_rpc_strategy := self._create_local_rpc_strategy():
236+
strategies.append(local_rpc_strategy)
237+
strategies.append(self._create_mqtt_rpc_strategy())
238+
return RpcChannel(strategies)
112239

113240
@property
114241
def mqtt_rpc_channel(self) -> V1RpcChannel:
115-
"""Return the MQTT RPC channel."""
116-
return self._mqtt_rpc_channel
242+
"""Return the MQTT-only RPC channel."""
243+
return RpcChannel([self._create_mqtt_rpc_strategy()])
117244

118245
@property
119246
def map_rpc_channel(self) -> V1RpcChannel:
120247
"""Return the map RPC channel used for fetching map content."""
121-
return self._map_rpc_channel
248+
decoder = create_map_response_decoder(security_data=self._security_data)
249+
return RpcChannel([self._create_mqtt_rpc_strategy(decoder)])
250+
251+
def _create_local_rpc_strategy(self) -> RpcStrategy:
252+
"""Create the RPC strategy for local transport."""
253+
if self._local_channel is None or not self.is_local_connected:
254+
return None
255+
return RpcStrategy(
256+
name="local",
257+
channel=self._local_channel,
258+
encoder=self._local_encoder,
259+
decoder=decode_rpc_response,
260+
)
261+
262+
def _local_encoder(self, x: RequestMessage) -> RoborockMessage:
263+
"""Encode a request message for local transport.
264+
265+
This is passed to the RpcStrategy as a function so that it will
266+
read the current local channel's protocol version which changes as
267+
the protocol version is discovered.
268+
"""
269+
if self._local_channel is None:
270+
# This is for typing and should not happen since we only create the
271+
# strategy if local is connected and it will never get set back to
272+
# None once connected.
273+
raise ValueError("Local channel is not available for encoding")
274+
return x.encode_message(
275+
RoborockMessageProtocol.GENERAL_REQUEST,
276+
version=self._local_channel.protocol_version,
277+
)
278+
279+
def _create_mqtt_rpc_strategy(self, decoder: Callable[[RoborockMessage], Any] = decode_rpc_response) -> RpcStrategy:
280+
"""Create the RPC strategy for MQTT transport.
281+
282+
This can optionally take a custom decoder for different response types
283+
such as map data.
284+
"""
285+
return RpcStrategy(
286+
name="mqtt",
287+
channel=self._mqtt_channel,
288+
encoder=lambda x: x.encode_message(
289+
RoborockMessageProtocol.RPC_REQUEST,
290+
security_data=self._security_data,
291+
),
292+
decoder=decoder,
293+
health_manager=self._mqtt_health_manager,
294+
)
122295

123296
async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]:
124297
"""Subscribe to all messages from the device.
@@ -185,7 +358,7 @@ async def _get_networking_info(self, *, prefer_cache: bool = True) -> NetworkInf
185358
_LOGGER.debug("Using cached network info for device %s", self._device_uid)
186359
return network_info
187360
try:
188-
network_info = await self._mqtt_rpc_channel.send_command(
361+
network_info = await self.mqtt_rpc_channel.send_command(
189362
RoborockCommand.GET_NETWORK_INFO, response_type=NetworkInfo
190363
)
191364
except RoborockException as e:
@@ -216,7 +389,6 @@ async def _local_connect(self, *, prefer_cache: bool = True) -> None:
216389
raise RoborockException(f"Error connecting to local device {self._device_uid}: {e}") from e
217390
# Wire up the new channel
218391
self._local_channel = local_channel
219-
self._local_rpc_channel = create_local_rpc_channel(self._local_channel)
220392
self._local_unsub = await self._local_channel.subscribe(self._on_local_message)
221393
_LOGGER.info("Successfully connected to local device %s", self._device_uid)
222394

0 commit comments

Comments
 (0)