88import datetime
99import logging
1010from collections .abc import Callable
11- from typing import TypeVar
11+ from dataclasses import dataclass
12+ from typing import Any , TypeVar , override
1213
1314from roborock .data import HomeDataDevice , NetworkInfo , RoborockBase , UserData
1415from roborock .exceptions import RoborockException
16+ from roborock .mqtt .health_manager import HealthManager
1517from roborock .mqtt .session import MqttParams , MqttSession
1618from roborock .protocols .v1_protocol import (
19+ CommandType ,
20+ MapResponse ,
21+ ParamsType ,
22+ RequestMessage ,
23+ ResponseData ,
24+ ResponseMessage ,
1725 SecurityData ,
26+ V1RpcChannel ,
27+ create_map_response_decoder ,
1828 create_security_data ,
29+ decode_rpc_response ,
1930)
20- from roborock .roborock_message import RoborockMessage
31+ from roborock .roborock_message import RoborockMessage , RoborockMessageProtocol
2132from roborock .roborock_typing import RoborockCommand
2233
2334from .cache import Cache
2435from .channel import Channel
2536from .local_channel import LocalChannel , LocalSession , create_local_session
2637from .mqtt_channel import MqttChannel
27- from .v1_rpc_channel import (
28- PickFirstAvailable ,
29- V1RpcChannel ,
30- create_local_rpc_channel ,
31- create_map_rpc_channel ,
32- create_mqtt_rpc_channel ,
33- )
3438
3539_LOGGER = logging .getLogger (__name__ )
3640
3741__all__ = [
38- "V1Channel " ,
42+ "create_v1_channel " ,
3943]
4044
4145_T = TypeVar ("_T" , bound = RoborockBase )
46+ _TIMEOUT = 10.0
47+
4248
4349# Exponential backoff parameters for reconnecting to local
4450MIN_RECONNECT_INTERVAL = datetime .timedelta (minutes = 1 )
5056LOCAL_CONNECTION_CHECK_INTERVAL = datetime .timedelta (seconds = 15 )
5157
5258
59+ @dataclass (frozen = True )
60+ class RpcStrategy :
61+ """Strategy for sending RPC commands over a specific channel.
62+
63+ This holds the configuration for a specific transport method that differ
64+ in how messages are encoded/decoded and which channel is used.
65+ """
66+
67+ name : str
68+ """Name of the strategy for logging purposes."""
69+
70+ channel : LocalChannel | MqttChannel
71+ """Channel to use for communication."""
72+
73+ encoder : Callable [[RequestMessage ], RoborockMessage ]
74+ """Function to encode request messages for the channel."""
75+
76+ decoder : Callable [[RoborockMessage ], ResponseMessage | MapResponse | None ]
77+ """Function to decode response messages from the channel."""
78+
79+ health_manager : HealthManager | None = None
80+ """Optional health manager for monitoring the channel."""
81+
82+
83+ async def _send_rpc (strategy : RpcStrategy , request : RequestMessage ) -> ResponseData | bytes :
84+ """Send a command and return a parsed response RoborockBase type.
85+
86+ This provides an RPC interface over a given channel strategy. The device
87+ channel only supports publish and subscribe, so this function handles
88+ associating requests with their corresponding responses.
89+
90+ The provided RpcStrategy defines how to encode/decode messages and which
91+ channel to use for communication.
92+ """
93+ future : asyncio .Future [ResponseData | bytes ] = asyncio .Future ()
94+ _LOGGER .debug (
95+ "Sending command (%s, request_id=%s): %s, params=%s" ,
96+ strategy .name ,
97+ request .request_id ,
98+ request .method ,
99+ request .params ,
100+ )
101+
102+ message = strategy .encoder (request )
103+
104+ def find_response (response_message : RoborockMessage ) -> None :
105+ try :
106+ decoded = strategy .decoder (response_message )
107+ except RoborockException as ex :
108+ _LOGGER .debug ("Exception while decoding message (%s): %s" , response_message , ex )
109+ return
110+ if decoded is None :
111+ return
112+ _LOGGER .debug ("Received response (%s, request_id=%s)" , strategy .name , decoded .request_id )
113+ if decoded .request_id == request .request_id :
114+ if isinstance (decoded , ResponseMessage ) and decoded .api_error :
115+ future .set_exception (decoded .api_error )
116+ else :
117+ future .set_result (decoded .data )
118+
119+ unsub = await strategy .channel .subscribe (find_response )
120+ try :
121+ await strategy .channel .publish (message )
122+ result = await asyncio .wait_for (future , timeout = _TIMEOUT )
123+ except TimeoutError as ex :
124+ if strategy .health_manager :
125+ await strategy .health_manager .on_timeout ()
126+ future .cancel ()
127+ raise RoborockException (f"Command timed out after { _TIMEOUT } s" ) from ex
128+ finally :
129+ unsub ()
130+ if strategy .health_manager :
131+ await strategy .health_manager .on_success ()
132+ return result
133+
134+
135+ class RpcChannel (V1RpcChannel ):
136+ """Wrapper to expose V1RpcChannel interface with a specific set of RpcStrategies.
137+
138+ This is used to provide a simpler interface to v1 traits for sending commands
139+ over multiple possible transports (local, MQTT) with automatic fallback.
140+ """
141+
142+ def __init__ (self , rpc_strategies : list [RpcStrategy ]) -> None :
143+ self ._rpc_strategies = rpc_strategies
144+
145+ @override
146+ async def send_command (
147+ self ,
148+ method : CommandType ,
149+ * ,
150+ response_type : type [_T ] | None = None ,
151+ params : ParamsType = None ,
152+ ) -> _T | Any :
153+ """Send a command and return either a decoded or parsed response."""
154+ request = RequestMessage (method , params = params )
155+
156+ # Try each channel in order until one succeeds
157+ last_exception = None
158+ for strategy in self ._rpc_strategies :
159+ try :
160+ decoded_response = await _send_rpc (strategy , request )
161+ except RoborockException as e :
162+ _LOGGER .warning ("Command %s failed on %s channel: %s" , method , strategy .name , e )
163+ last_exception = e
164+ except Exception as e :
165+ _LOGGER .exception ("Unexpected error sending command %s on %s channel" , method , strategy .name )
166+ last_exception = RoborockException (f"Unexpected error: { e } " )
167+ else :
168+ if response_type is not None :
169+ if not isinstance (decoded_response , dict ):
170+ raise RoborockException (
171+ f"Expected dict response to parse { response_type .__name__ } , got { type (decoded_response )} "
172+ )
173+ return response_type .from_dict (decoded_response )
174+ return decoded_response
175+
176+ raise last_exception or RoborockException ("No available connection to send command" )
177+
178+
53179class V1Channel (Channel ):
54180 """Unified V1 protocol channel with automatic MQTT/local connection handling.
55181
@@ -69,20 +195,17 @@ def __init__(
69195 """Initialize the V1Channel.
70196
71197 Args:
198+ device_uid: Unique device identifier (DUID).
72199 mqtt_channel: MQTT channel for cloud communication
73200 local_session: Factory that creates LocalChannels for a hostname.
201+ cache: Cache for storing network information.
74202 """
75203 self ._device_uid = device_uid
204+ self ._security_data = security_data
76205 self ._mqtt_channel = mqtt_channel
77- self ._mqtt_rpc_channel = create_mqtt_rpc_channel ( mqtt_channel , security_data )
206+ self ._mqtt_health_manager = HealthManager ( self . _mqtt_channel . restart )
78207 self ._local_session = local_session
79208 self ._local_channel : LocalChannel | None = None
80- self ._local_rpc_channel : V1RpcChannel | None = None
81- # Prefer local, fallback to MQTT
82- self ._combined_rpc_channel = PickFirstAvailable (
83- [lambda : self ._local_rpc_channel , lambda : self ._mqtt_rpc_channel ]
84- )
85- self ._map_rpc_channel = create_map_rpc_channel (mqtt_channel , security_data )
86209 self ._mqtt_unsub : Callable [[], None ] | None = None
87210 self ._local_unsub : Callable [[], None ] | None = None
88211 self ._callback : Callable [[RoborockMessage ], None ] | None = None
@@ -108,17 +231,67 @@ def is_mqtt_connected(self) -> bool:
108231 @property
109232 def rpc_channel (self ) -> V1RpcChannel :
110233 """Return the combined RPC channel prefers local with a fallback to MQTT."""
111- return self ._combined_rpc_channel
234+ strategies = []
235+ if local_rpc_strategy := self ._create_local_rpc_strategy ():
236+ strategies .append (local_rpc_strategy )
237+ strategies .append (self ._create_mqtt_rpc_strategy ())
238+ return RpcChannel (strategies )
112239
113240 @property
114241 def mqtt_rpc_channel (self ) -> V1RpcChannel :
115- """Return the MQTT RPC channel."""
116- return self ._mqtt_rpc_channel
242+ """Return the MQTT-only RPC channel."""
243+ return RpcChannel ([ self ._create_mqtt_rpc_strategy ()])
117244
118245 @property
119246 def map_rpc_channel (self ) -> V1RpcChannel :
120247 """Return the map RPC channel used for fetching map content."""
121- return self ._map_rpc_channel
248+ decoder = create_map_response_decoder (security_data = self ._security_data )
249+ return RpcChannel ([self ._create_mqtt_rpc_strategy (decoder )])
250+
251+ def _create_local_rpc_strategy (self ) -> RpcStrategy :
252+ """Create the RPC strategy for local transport."""
253+ if self ._local_channel is None or not self .is_local_connected :
254+ return None
255+ return RpcStrategy (
256+ name = "local" ,
257+ channel = self ._local_channel ,
258+ encoder = self ._local_encoder ,
259+ decoder = decode_rpc_response ,
260+ )
261+
262+ def _local_encoder (self , x : RequestMessage ) -> RoborockMessage :
263+ """Encode a request message for local transport.
264+
265+ This is passed to the RpcStrategy as a function so that it will
266+ read the current local channel's protocol version which changes as
267+ the protocol version is discovered.
268+ """
269+ if self ._local_channel is None :
270+ # This is for typing and should not happen since we only create the
271+ # strategy if local is connected and it will never get set back to
272+ # None once connected.
273+ raise ValueError ("Local channel is not available for encoding" )
274+ return x .encode_message (
275+ RoborockMessageProtocol .GENERAL_REQUEST ,
276+ version = self ._local_channel .protocol_version ,
277+ )
278+
279+ def _create_mqtt_rpc_strategy (self , decoder : Callable [[RoborockMessage ], Any ] = decode_rpc_response ) -> RpcStrategy :
280+ """Create the RPC strategy for MQTT transport.
281+
282+ This can optionally take a custom decoder for different response types
283+ such as map data.
284+ """
285+ return RpcStrategy (
286+ name = "mqtt" ,
287+ channel = self ._mqtt_channel ,
288+ encoder = lambda x : x .encode_message (
289+ RoborockMessageProtocol .RPC_REQUEST ,
290+ security_data = self ._security_data ,
291+ ),
292+ decoder = decoder ,
293+ health_manager = self ._mqtt_health_manager ,
294+ )
122295
123296 async def subscribe (self , callback : Callable [[RoborockMessage ], None ]) -> Callable [[], None ]:
124297 """Subscribe to all messages from the device.
@@ -185,7 +358,7 @@ async def _get_networking_info(self, *, prefer_cache: bool = True) -> NetworkInf
185358 _LOGGER .debug ("Using cached network info for device %s" , self ._device_uid )
186359 return network_info
187360 try :
188- network_info = await self ._mqtt_rpc_channel .send_command (
361+ network_info = await self .mqtt_rpc_channel .send_command (
189362 RoborockCommand .GET_NETWORK_INFO , response_type = NetworkInfo
190363 )
191364 except RoborockException as e :
@@ -216,7 +389,6 @@ async def _local_connect(self, *, prefer_cache: bool = True) -> None:
216389 raise RoborockException (f"Error connecting to local device { self ._device_uid } : { e } " ) from e
217390 # Wire up the new channel
218391 self ._local_channel = local_channel
219- self ._local_rpc_channel = create_local_rpc_channel (self ._local_channel )
220392 self ._local_unsub = await self ._local_channel .subscribe (self ._on_local_message )
221393 _LOGGER .info ("Successfully connected to local device %s" , self ._device_uid )
222394
0 commit comments