1919_TIMEOUT = 10.0
2020
2121
22+ @dataclass
23+ class LocalChannelParams :
24+ """Parameters for local channel encoder/decoder."""
25+
26+ local_key : str
27+ connect_nonce : int
28+ ack_nonce : int | None
29+
30+
2231@dataclass
2332class _LocalProtocol (asyncio .Protocol ):
2433 """Callbacks for the Roborock local client transport."""
@@ -42,30 +51,34 @@ class LocalChannel(Channel):
4251 format most parsing to higher-level components.
4352 """
4453
45- def __init__ (self , host : str , local_key : str , local_protocol_version : LocalProtocolVersion | None = None ):
54+ def __init__ (self , host : str , local_key : str ):
4655 self ._host = host
4756 self ._transport : asyncio .Transport | None = None
4857 self ._protocol : _LocalProtocol | None = None
4958 self ._subscribers : CallbackList [RoborockMessage ] = CallbackList (_LOGGER )
5059 self ._is_connected = False
5160 self ._local_key = local_key
52- self ._local_protocol_version = local_protocol_version
61+ self ._local_protocol_version : LocalProtocolVersion | None = None
5362 self ._connect_nonce = get_next_int (10000 , 32767 )
5463 self ._ack_nonce : int | None = None
5564 self ._update_encoder_decoder ()
5665
57- def _update_encoder_decoder (self ):
66+ def _update_encoder_decoder (self , params : LocalChannelParams | None = None ):
67+ if params is None :
68+ params = LocalChannelParams (
69+ local_key = self ._local_key , connect_nonce = self ._connect_nonce , ack_nonce = self ._ack_nonce
70+ )
5871 self ._encoder = create_local_encoder (
59- local_key = self . _local_key , connect_nonce = self . _connect_nonce , ack_nonce = self . _ack_nonce
72+ local_key = params . local_key , connect_nonce = params . connect_nonce , ack_nonce = params . ack_nonce
6073 )
6174 self ._decoder = create_local_decoder (
62- local_key = self . _local_key , connect_nonce = self . _connect_nonce , ack_nonce = self . _ack_nonce
75+ local_key = params . local_key , connect_nonce = params . connect_nonce , ack_nonce = params . ack_nonce
6376 )
6477 # Callback to decode messages and dispatch to subscribers
6578 self ._data_received : Callable [[bytes ], None ] = decoder_callback (self ._decoder , self ._subscribers , _LOGGER )
6679
67- async def _do_hello (self , local_protocol_version : LocalProtocolVersion ) -> bool :
68- """Perform the initial handshaking."""
80+ async def _do_hello (self , local_protocol_version : LocalProtocolVersion ) -> LocalChannelParams | None :
81+ """Perform the initial handshaking and return encoder params if successful ."""
6982 _LOGGER .debug (
7083 "Attempting to use the %s protocol for client %s..." ,
7184 local_protocol_version ,
@@ -83,41 +96,39 @@ async def _do_hello(self, local_protocol_version: LocalProtocolVersion) -> bool:
8396 request_id = request .seq ,
8497 response_protocol = RoborockMessageProtocol .HELLO_RESPONSE ,
8598 )
86- self ._ack_nonce = response .random
87- self ._local_protocol_version = local_protocol_version
88- self ._update_encoder_decoder ()
89-
9099 _LOGGER .debug (
91100 "Client %s speaks the %s protocol." ,
92101 self ._host ,
93102 local_protocol_version ,
94103 )
95- return True
104+ return LocalChannelParams (
105+ local_key = self ._local_key , connect_nonce = self ._connect_nonce , ack_nonce = response .random
106+ )
96107 except RoborockException as e :
97108 _LOGGER .debug (
98109 "Client %s did not respond or does not speak the %s protocol. %s" ,
99110 self ._host ,
100111 local_protocol_version ,
101112 e ,
102113 )
103- return False
114+ return None
104115
105- async def hello (self ):
116+ async def _hello (self ):
106117 """Send hello to the device to negotiate protocol."""
118+ attempt_versions = [LocalProtocolVersion .V1 , LocalProtocolVersion .L01 ]
107119 if self ._local_protocol_version :
108- # version is forced - try it first, if it fails, try the opposite
109- if not await self ._do_hello (self ._local_protocol_version ):
110- if not await self ._do_hello (
111- LocalProtocolVersion .V1
112- if self ._local_protocol_version is not LocalProtocolVersion .V1
113- else LocalProtocolVersion .L01
114- ):
115- raise RoborockException ("Failed to connect to device with any known protocol" )
116- else :
117- # try 1.0, then L01
118- if not await self ._do_hello (LocalProtocolVersion .V1 ):
119- if not await self ._do_hello (LocalProtocolVersion .L01 ):
120- raise RoborockException ("Failed to connect to device with any known protocol" )
120+ # Sort to try the preferred version first
121+ attempt_versions .sort (key = lambda v : v != self ._local_protocol_version )
122+
123+ for version in attempt_versions :
124+ params = await self ._do_hello (version )
125+ if params is not None :
126+ self ._ack_nonce = params .ack_nonce
127+ self ._local_protocol_version = version
128+ self ._update_encoder_decoder (params )
129+ return
130+
131+ raise RoborockException ("Failed to connect to device with any known protocol" )
121132
122133 @property
123134 def is_connected (self ) -> bool :
@@ -130,7 +141,7 @@ def is_local_connected(self) -> bool:
130141 return self ._is_connected
131142
132143 async def connect (self ) -> None :
133- """Connect to the device."""
144+ """Connect to the device and negotiate protocol ."""
134145 if self ._is_connected :
135146 _LOGGER .warning ("Already connected" )
136147 return
@@ -143,6 +154,9 @@ async def connect(self) -> None:
143154 except OSError as e :
144155 raise RoborockConnectionException (f"Failed to connect to { self ._host } :{ _PORT } " ) from e
145156
157+ # Perform protocol negotiation
158+ await self ._hello ()
159+
146160 def close (self ) -> None :
147161 """Disconnect from the device."""
148162 if self ._transport :
0 commit comments