diff --git a/roborock/devices/local_channel.py b/roborock/devices/local_channel.py index 21fc1893..eb94646c 100644 --- a/roborock/devices/local_channel.py +++ b/roborock/devices/local_channel.py @@ -62,7 +62,14 @@ def __init__(self, host: str, local_key: str): LocalChannelParams(local_key=local_key, connect_nonce=get_next_int(10000, 32767), ack_nonce=None) ) - def _update_encoder_decoder(self, params: LocalChannelParams): + def _update_encoder_decoder(self, params: LocalChannelParams) -> None: + """Update the encoder and decoder with new parameters. + + This is invoked once with an initial set of values used for protocol + negotiation. Once negotiation completes, it is updated again to set the + correct nonces for the follow up communications and updates the encoder + and decoder functions accordingly. + """ self._params = params self._encoder = create_local_encoder( local_key=params.local_key, connect_nonce=params.connect_nonce, ack_nonce=params.ack_nonce @@ -71,7 +78,7 @@ def _update_encoder_decoder(self, params: LocalChannelParams): local_key=params.local_key, connect_nonce=params.connect_nonce, ack_nonce=params.ack_nonce ) # Callback to decode messages and dispatch to subscribers - self._data_received: Callable[[bytes], None] = decoder_callback(self._decoder, self._subscribers, _LOGGER) + self._dispatch = decoder_callback(self._decoder, self._subscribers, _LOGGER) async def _do_hello(self, local_protocol_version: LocalProtocolVersion) -> LocalChannelParams | None: """Perform the initial handshaking and return encoder params if successful.""" @@ -125,6 +132,13 @@ async def _hello(self): raise RoborockException("Failed to connect to device with any known protocol") + @property + def protocol_version(self) -> LocalProtocolVersion: + """Return the negotiated local protocol version, or a sensible default.""" + if self._local_protocol_version is not None: + return self._local_protocol_version + return LocalProtocolVersion.V1 + @property def is_connected(self) -> bool: """Check if the channel is currently connected.""" @@ -157,6 +171,10 @@ async def connect(self) -> None: self.close() raise + def _data_received(self, data: bytes) -> None: + """Invoked when data is received on the stream.""" + self._dispatch(data) + def close(self) -> None: """Disconnect from the device.""" if self._transport: diff --git a/roborock/devices/v1_rpc_channel.py b/roborock/devices/v1_rpc_channel.py index a8685bdb..ffd46d75 100644 --- a/roborock/devices/v1_rpc_channel.py +++ b/roborock/devices/v1_rpc_channel.py @@ -188,7 +188,7 @@ def create_local_rpc_channel(local_channel: LocalChannel) -> V1RpcChannel: return PayloadEncodedV1RpcChannel( "local", local_channel, - lambda x: x.encode_message(RoborockMessageProtocol.GENERAL_REQUEST), + lambda x: x.encode_message(RoborockMessageProtocol.GENERAL_REQUEST, version=local_channel.protocol_version), decode_rpc_response, ) diff --git a/roborock/protocols/v1_protocol.py b/roborock/protocols/v1_protocol.py index 47ca8145..213d8d31 100644 --- a/roborock/protocols/v1_protocol.py +++ b/roborock/protocols/v1_protocol.py @@ -73,14 +73,17 @@ class RequestMessage: request_id: int = field(default_factory=lambda: get_next_int(10000, 32767)) def encode_message( - self, protocol: RoborockMessageProtocol, security_data: SecurityData | None = None, version: str = "1.0" + self, + protocol: RoborockMessageProtocol, + security_data: SecurityData | None = None, + version: LocalProtocolVersion = LocalProtocolVersion.V1, ) -> RoborockMessage: """Convert the request message to a RoborockMessage.""" return RoborockMessage( timestamp=self.timestamp, protocol=protocol, payload=self._as_payload(security_data=security_data), - version=version.encode(), + version=version.value.encode(), ) def _as_payload(self, security_data: SecurityData | None) -> bytes: diff --git a/tests/conftest.py b/tests/conftest.py index cde3e146..97792686 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,6 +13,7 @@ from roborock import HomeData, UserData from roborock.data import DeviceData +from roborock.protocols.v1_protocol import LocalProtocolVersion from roborock.roborock_message import RoborockMessage from roborock.version_1_apis.roborock_local_client_v1 import RoborockLocalClientV1 from roborock.version_1_apis.roborock_mqtt_client_v1 import RoborockMqttClientV1 @@ -361,6 +362,7 @@ def __init__(self): self.subscribe = AsyncMock(side_effect=self._subscribe) self.connect = AsyncMock(side_effect=self._connect) self.close = MagicMock(side_effect=self._close) + self.protocol_version = LocalProtocolVersion.V1 async def _connect(self) -> None: self._is_connected = True diff --git a/tests/devices/test_local_channel.py b/tests/devices/test_local_channel.py index 19a85b39..72a9a83a 100644 --- a/tests/devices/test_local_channel.py +++ b/tests/devices/test_local_channel.py @@ -281,6 +281,8 @@ async def test_hello_success_with_v1_protocol_first(mock_loop: Mock, mock_transp # Create a channel without the automatic hello mocking channel = LocalChannel(host=TEST_HOST, local_key=TEST_LOCAL_KEY) + # Clear cached protocol to ensure V1 is tried first + channel._local_protocol_version = None # Mock _do_hello to succeed for V1 on first attempt async def mock_do_hello(local_protocol_version: LocalProtocolVersion) -> LocalChannelParams | None: