Skip to content

Commit d8028db

Browse files
committed
chore: fix comments
1 parent 0d807db commit d8028db

File tree

2 files changed

+42
-29
lines changed

2 files changed

+42
-29
lines changed

roborock/devices/local_channel.py

Lines changed: 42 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,15 @@
1919
_TIMEOUT = 10.0
2020

2121

22+
@dataclass
23+
class LocalChannelParams:
24+
"""Parameters for local channel encoder/decoder."""
25+
26+
local_key: str
27+
connect_nonce: int
28+
ack_nonce: int | None
29+
30+
2231
@dataclass
2332
class _LocalProtocol(asyncio.Protocol):
2433
"""Callbacks for the Roborock local client transport."""
@@ -42,30 +51,34 @@ class LocalChannel(Channel):
4251
format most parsing to higher-level components.
4352
"""
4453

45-
def __init__(self, host: str, local_key: str, local_protocol_version: LocalProtocolVersion | None = None):
54+
def __init__(self, host: str, local_key: str):
4655
self._host = host
4756
self._transport: asyncio.Transport | None = None
4857
self._protocol: _LocalProtocol | None = None
4958
self._subscribers: CallbackList[RoborockMessage] = CallbackList(_LOGGER)
5059
self._is_connected = False
5160
self._local_key = local_key
52-
self._local_protocol_version = local_protocol_version
61+
self._local_protocol_version: LocalProtocolVersion | None = None
5362
self._connect_nonce = get_next_int(10000, 32767)
5463
self._ack_nonce: int | None = None
5564
self._update_encoder_decoder()
5665

57-
def _update_encoder_decoder(self):
66+
def _update_encoder_decoder(self, params: LocalChannelParams | None = None):
67+
if params is None:
68+
params = LocalChannelParams(
69+
local_key=self._local_key, connect_nonce=self._connect_nonce, ack_nonce=self._ack_nonce
70+
)
5871
self._encoder = create_local_encoder(
59-
local_key=self._local_key, connect_nonce=self._connect_nonce, ack_nonce=self._ack_nonce
72+
local_key=params.local_key, connect_nonce=params.connect_nonce, ack_nonce=params.ack_nonce
6073
)
6174
self._decoder = create_local_decoder(
62-
local_key=self._local_key, connect_nonce=self._connect_nonce, ack_nonce=self._ack_nonce
75+
local_key=params.local_key, connect_nonce=params.connect_nonce, ack_nonce=params.ack_nonce
6376
)
6477
# Callback to decode messages and dispatch to subscribers
6578
self._data_received: Callable[[bytes], None] = decoder_callback(self._decoder, self._subscribers, _LOGGER)
6679

67-
async def _do_hello(self, local_protocol_version: LocalProtocolVersion) -> bool:
68-
"""Perform the initial handshaking."""
80+
async def _do_hello(self, local_protocol_version: LocalProtocolVersion) -> LocalChannelParams | None:
81+
"""Perform the initial handshaking and return encoder params if successful."""
6982
_LOGGER.debug(
7083
"Attempting to use the %s protocol for client %s...",
7184
local_protocol_version,
@@ -83,41 +96,39 @@ async def _do_hello(self, local_protocol_version: LocalProtocolVersion) -> bool:
8396
request_id=request.seq,
8497
response_protocol=RoborockMessageProtocol.HELLO_RESPONSE,
8598
)
86-
self._ack_nonce = response.random
87-
self._local_protocol_version = local_protocol_version
88-
self._update_encoder_decoder()
89-
9099
_LOGGER.debug(
91100
"Client %s speaks the %s protocol.",
92101
self._host,
93102
local_protocol_version,
94103
)
95-
return True
104+
return LocalChannelParams(
105+
local_key=self._local_key, connect_nonce=self._connect_nonce, ack_nonce=response.random
106+
)
96107
except RoborockException as e:
97108
_LOGGER.debug(
98109
"Client %s did not respond or does not speak the %s protocol. %s",
99110
self._host,
100111
local_protocol_version,
101112
e,
102113
)
103-
return False
114+
return None
104115

105-
async def hello(self):
116+
async def _hello(self):
106117
"""Send hello to the device to negotiate protocol."""
118+
attempt_versions = [LocalProtocolVersion.V1, LocalProtocolVersion.L01]
107119
if self._local_protocol_version:
108-
# version is forced - try it first, if it fails, try the opposite
109-
if not await self._do_hello(self._local_protocol_version):
110-
if not await self._do_hello(
111-
LocalProtocolVersion.V1
112-
if self._local_protocol_version is not LocalProtocolVersion.V1
113-
else LocalProtocolVersion.L01
114-
):
115-
raise RoborockException("Failed to connect to device with any known protocol")
116-
else:
117-
# try 1.0, then L01
118-
if not await self._do_hello(LocalProtocolVersion.V1):
119-
if not await self._do_hello(LocalProtocolVersion.L01):
120-
raise RoborockException("Failed to connect to device with any known protocol")
120+
# Sort to try the preferred version first
121+
attempt_versions.sort(key=lambda v: v != self._local_protocol_version)
122+
123+
for version in attempt_versions:
124+
params = await self._do_hello(version)
125+
if params is not None:
126+
self._ack_nonce = params.ack_nonce
127+
self._local_protocol_version = version
128+
self._update_encoder_decoder(params)
129+
return
130+
131+
raise RoborockException("Failed to connect to device with any known protocol")
121132

122133
@property
123134
def is_connected(self) -> bool:
@@ -130,7 +141,7 @@ def is_local_connected(self) -> bool:
130141
return self._is_connected
131142

132143
async def connect(self) -> None:
133-
"""Connect to the device."""
144+
"""Connect to the device and negotiate protocol."""
134145
if self._is_connected:
135146
_LOGGER.warning("Already connected")
136147
return
@@ -143,6 +154,9 @@ async def connect(self) -> None:
143154
except OSError as e:
144155
raise RoborockConnectionException(f"Failed to connect to {self._host}:{_PORT}") from e
145156

157+
# Perform protocol negotiation
158+
await self._hello()
159+
146160
def close(self) -> None:
147161
"""Disconnect from the device."""
148162
if self._transport:

roborock/devices/v1_channel.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,6 @@ async def _local_connect(self, *, use_cache: bool = True) -> None:
212212
raise RoborockException(f"Error connecting to local device {self._device_uid}: {e}") from e
213213
# Wire up the new channel
214214
self._local_channel = local_channel
215-
await self._local_channel.hello()
216215
self._local_rpc_channel = create_local_rpc_channel(self._local_channel)
217216
self._local_unsub = await self._local_channel.subscribe(self._on_local_message)
218217
_LOGGER.info("Successfully connected to local device %s", self._device_uid)

0 commit comments

Comments
 (0)