33from asyncio import Lock , TimerHandle , Transport , get_running_loop
44from collections .abc import Callable
55from dataclasses import dataclass
6+ from enum import StrEnum
67
78import async_timeout
89
1819_LOGGER = logging .getLogger (__name__ )
1920
2021
22+ class LocalProtocolVersion (StrEnum ):
23+ """Supported local protocol versions. Different from vacuum protocol versions."""
24+
25+ L01 = "L01"
26+ V1 = "1.0"
27+
28+
2129@dataclass
2230class _LocalProtocol (asyncio .Protocol ):
2331 """Callbacks for the Roborock local client transport."""
@@ -37,7 +45,12 @@ def connection_lost(self, exc: Exception | None) -> None:
3745class RoborockLocalClientV1 (RoborockClientV1 , RoborockClient ):
3846 """Roborock local client for v1 devices."""
3947
40- def __init__ (self , device_data : DeviceData , queue_timeout : int = 4 , version : str | None = None ):
48+ def __init__ (
49+ self ,
50+ device_data : DeviceData ,
51+ queue_timeout : int = 4 ,
52+ local_protocol_version : LocalProtocolVersion | None = None ,
53+ ):
4154 """Initialize the Roborock local client."""
4255 if device_data .host is None :
4356 raise RoborockException ("Host is required" )
@@ -50,13 +63,17 @@ def __init__(self, device_data: DeviceData, queue_timeout: int = 4, version: str
5063 RoborockClientV1 .__init__ (self , device_data , security_data = None )
5164 RoborockClient .__init__ (self , device_data )
5265 self ._local_protocol = _LocalProtocol (self ._data_received , self ._connection_lost )
53- self ._version = version
66+ self ._local_protocol_version = local_protocol_version
5467 self ._connect_nonce = get_next_int (10000 , 32767 )
5568 self ._ack_nonce : int | None = None
5669 self ._set_encoder_decoder ()
5770 self .queue_timeout = queue_timeout
5871 self ._logger = RoborockLoggerAdapter (device_data .device .name , _LOGGER )
5972
73+ @property
74+ def local_protocol_version (self ) -> LocalProtocolVersion :
75+ return LocalProtocolVersion .V1 if self ._local_protocol_version is None else self ._local_protocol_version
76+
6077 def _data_received (self , message ):
6178 """Called when data is received from the transport."""
6279 parsed_msg = self ._decoder (message )
@@ -111,16 +128,21 @@ async def async_disconnect(self) -> None:
111128 self ._sync_disconnect ()
112129
113130 def _set_encoder_decoder (self ):
114- """Updates the encoder decoder. For L01 these are updated with nonces after the first hello."""
131+ """Updates the encoder decoder. These are updated with nonces after the first hello.
132+ Only L01 uses the nonces."""
115133 self ._encoder = create_local_encoder (self .device_info .device .local_key , self ._connect_nonce , self ._ack_nonce )
116134 self ._decoder = create_local_decoder (self .device_info .device .local_key , self ._connect_nonce , self ._ack_nonce )
117135
118- async def _do_hello (self , version : str ) -> bool :
136+ async def _do_hello (self , local_protocol_version : LocalProtocolVersion ) -> bool :
119137 """Perform the initial handshaking."""
120- self ._logger .debug (f"Attempting to use the { version } protocol for client { self .device_info .device .duid } ..." )
138+ self ._logger .debug (
139+ "Attempting to use the %s protocol for client %s..." ,
140+ local_protocol_version ,
141+ self .device_info .device .duid ,
142+ )
121143 request = RoborockMessage (
122144 protocol = RoborockMessageProtocol .HELLO_REQUEST ,
123- version = version .encode (),
145+ version = local_protocol_version .encode (),
124146 random = self ._connect_nonce ,
125147 seq = 1 ,
126148 )
@@ -132,31 +154,39 @@ async def _do_hello(self, version: str) -> bool:
132154 )
133155 self ._ack_nonce = response .random
134156 self ._set_encoder_decoder ()
135- self ._version = version
136- self ._logger .debug (f"Client { self .device_info .device .duid } speaks the { version } protocol." )
157+ self ._local_protocol_version = local_protocol_version
158+
159+ self ._logger .debug (
160+ "Client %s speaks the %s protocol." ,
161+ self .device_info .device .duid ,
162+ local_protocol_version ,
163+ )
137164 return True
138165 except RoborockException as e :
139166 self ._logger .debug (
140- f"Client { self .device_info .device .duid } did not respond or does not speak the { version } protocol. { e } "
167+ "Client %s did not respond or does not speak the %s protocol. %s" ,
168+ self .device_info .device .duid ,
169+ local_protocol_version ,
170+ e ,
141171 )
142172 return False
143173
144174 async def hello (self ):
145175 """Send hello to the device to negotiate protocol."""
146- if self ._version :
176+ if self ._local_protocol_version :
147177 # version is forced
148- if not await self ._do_hello (self ._version ):
149- raise RoborockException (f"Failed to connect to device with protocol { self ._version } " )
178+ if not await self ._do_hello (self ._local_protocol_version ):
179+ raise RoborockException (f"Failed to connect to device with protocol { self ._local_protocol_version } " )
150180 else :
151181 # try 1.0, then L01
152- if not await self ._do_hello ("1.0" ):
153- if not await self ._do_hello (" L01" ):
182+ if not await self ._do_hello (LocalProtocolVersion . V1 ):
183+ if not await self ._do_hello (LocalProtocolVersion . L01 ):
154184 raise RoborockException ("Failed to connect to device with any known protocol" )
155185
156186 async def ping (self ) -> None :
157- # Realistically, this should be set here, but this is to be safe and for typing.
158- version = b"1.0" if self . _version is None else self ._version .encode ()
159- ping_message = RoborockMessage ( protocol = RoborockMessageProtocol . PING_REQUEST , version = version )
187+ ping_message = RoborockMessage (
188+ protocol = RoborockMessageProtocol . PING_REQUEST , version = self .local_protocol_version .encode ()
189+ )
160190 await self ._send_message (
161191 roborock_message = ping_message ,
162192 request_id = ping_message .seq ,
@@ -180,7 +210,8 @@ async def _send_command(
180210 raise RoborockException (f"Method { method } is not supported over local connection" )
181211 request_message = RequestMessage (method = method , params = params )
182212 roborock_message = request_message .encode_message (
183- RoborockMessageProtocol .GENERAL_REQUEST , version = self ._version if self ._version is not None else "1.0"
213+ RoborockMessageProtocol .GENERAL_REQUEST ,
214+ version = self .local_protocol_version ,
184215 )
185216 self ._logger .debug ("Building message id %s for method %s" , request_message .request_id , method )
186217 return await self ._send_message (
0 commit comments