Skip to content

Commit 0fcf224

Browse files
committed
Update the messages callback to not mutate the protocol once created.
1 parent 278f54c commit 0fcf224

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

roborock/devices/local_channel.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,20 +51,26 @@ class LocalChannel(Channel):
5151
format most parsing to higher-level components.
5252
"""
5353

54-
_protocol_cache: dict[str, LocalProtocolVersion] = {}
5554

5655
def __init__(self, host: str, local_key: str):
5756
self._host = host
5857
self._transport: asyncio.Transport | None = None
5958
self._protocol: _LocalProtocol | None = None
6059
self._subscribers: CallbackList[RoborockMessage] = CallbackList(_LOGGER)
6160
self._is_connected = False
62-
self._local_protocol_version: LocalProtocolVersion | None = self._protocol_cache.get(host)
61+
self._local_protocol_version: LocalProtocolVersion | None = None
6362
self._update_encoder_decoder(
6463
LocalChannelParams(local_key=local_key, connect_nonce=get_next_int(10000, 32767), ack_nonce=None)
6564
)
6665

67-
def _update_encoder_decoder(self, params: LocalChannelParams):
66+
def _update_encoder_decoder(self, params: LocalChannelParams) -> None:
67+
"""Update the encoder and decoder with new parameters.
68+
69+
This is invoked once with an initial set of values used for protocol
70+
negotiation. Once negotiation completes, it is updated again to set the
71+
correct nonces for the follow up communications and updates the encoder
72+
and decoder functions accordingly.
73+
"""
6874
self._params = params
6975
self._encoder = create_local_encoder(
7076
local_key=params.local_key, connect_nonce=params.connect_nonce, ack_nonce=params.ack_nonce
@@ -73,9 +79,7 @@ def _update_encoder_decoder(self, params: LocalChannelParams):
7379
local_key=params.local_key, connect_nonce=params.connect_nonce, ack_nonce=params.ack_nonce
7480
)
7581
# Callback to decode messages and dispatch to subscribers
76-
self._data_received: Callable[[bytes], None] = decoder_callback(self._decoder, self._subscribers, _LOGGER)
77-
if self._protocol:
78-
self._protocol.messages_cb = self._data_received
82+
self._dispatch = decoder_callback(self._decoder, self._subscribers, _LOGGER)
7983

8084
async def _do_hello(self, local_protocol_version: LocalProtocolVersion) -> LocalChannelParams | None:
8185
"""Perform the initial handshaking and return encoder params if successful."""
@@ -125,7 +129,6 @@ async def _hello(self):
125129
if params is not None:
126130
self._local_protocol_version = version
127131
self._update_encoder_decoder(params)
128-
self._protocol_cache[self._host] = self._local_protocol_version
129132
return
130133

131134
raise RoborockException("Failed to connect to device with any known protocol")
@@ -169,6 +172,10 @@ async def connect(self) -> None:
169172
self.close()
170173
raise
171174

175+
def _data_received(self, data: bytes) -> None:
176+
"""Invoked when data is received on the stream."""
177+
self._dispatch(data)
178+
172179
def close(self) -> None:
173180
"""Disconnect from the device."""
174181
if self._transport:

0 commit comments

Comments
 (0)