From a02b3bd21f5303bb59a9a20aada66ab9b98bffe9 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sun, 7 Dec 2025 13:51:00 -0800 Subject: [PATCH 1/3] fix: Fix exception when sending dyad/zeo requests The bug was introduced in #645. Add tests that exercise the actual request encoding. This changes the ID QUERY value encoding by passing in a function, which is another variation on the first version of #645 where the json encoding happened inside the decode function. --- roborock/devices/a01_channel.py | 6 +- roborock/devices/traits/a01/__init__.py | 17 ++- roborock/protocols/a01_protocol.py | 20 ++- tests/devices/test_a01_channel.py | 3 +- tests/devices/traits/a01/test_init.py | 169 ++++++++++++++---------- tests/protocols/common.py | 26 ++++ tests/protocols/test_a01_protocol.py | 37 ++++-- tests/test_a01_api.py | 33 +---- 8 files changed, 200 insertions(+), 111 deletions(-) create mode 100644 tests/protocols/common.py diff --git a/roborock/devices/a01_channel.py b/roborock/devices/a01_channel.py index ae1a5d18..f698bb6e 100644 --- a/roborock/devices/a01_channel.py +++ b/roborock/devices/a01_channel.py @@ -2,6 +2,7 @@ import asyncio import logging +from collections.abc import Callable from typing import Any, overload from roborock.exceptions import RoborockException @@ -29,6 +30,7 @@ async def send_decoded_command( mqtt_channel: MqttChannel, params: dict[RoborockDyadDataProtocol, Any], + value_encoder: Callable[[Any], Any] | None = None, ) -> dict[RoborockDyadDataProtocol, Any]: ... @@ -36,16 +38,18 @@ async def send_decoded_command( async def send_decoded_command( mqtt_channel: MqttChannel, params: dict[RoborockZeoProtocol, Any], + value_encoder: Callable[[Any], Any] | None = None, ) -> dict[RoborockZeoProtocol, Any]: ... async def send_decoded_command( mqtt_channel: MqttChannel, params: dict[RoborockDyadDataProtocol, Any] | dict[RoborockZeoProtocol, Any], + value_encoder: Callable[[Any], Any] | None = None, ) -> dict[RoborockDyadDataProtocol, Any] | dict[RoborockZeoProtocol, Any]: """Send a command on the MQTT channel and get a decoded response.""" _LOGGER.debug("Sending MQTT command: %s", params) - roborock_message = encode_mqtt_payload(params) + roborock_message = encode_mqtt_payload(params, value_encoder) # For commands that set values: send the command and do not # block waiting for a response. Queries are handled below. diff --git a/roborock/devices/traits/a01/__init__.py b/roborock/devices/traits/a01/__init__.py index 1e3f44ae..e56f6e23 100644 --- a/roborock/devices/traits/a01/__init__.py +++ b/roborock/devices/traits/a01/__init__.py @@ -1,3 +1,4 @@ +import json from collections.abc import Callable from datetime import time from typing import Any @@ -121,8 +122,11 @@ def __init__(self, channel: MqttChannel) -> None: async def query_values(self, protocols: list[RoborockDyadDataProtocol]) -> dict[RoborockDyadDataProtocol, Any]: """Query the device for the values of the given Dyad protocols.""" - params = {RoborockDyadDataProtocol.ID_QUERY: str([int(p) for p in protocols])} - response = await send_decoded_command(self._channel, params) + response = await send_decoded_command( + self._channel, + {RoborockDyadDataProtocol.ID_QUERY: protocols}, + value_encoder=json.dumps, + ) return {protocol: convert_dyad_value(protocol, response.get(protocol)) for protocol in protocols} async def set_value(self, protocol: RoborockDyadDataProtocol, value: Any) -> dict[RoborockDyadDataProtocol, Any]: @@ -142,14 +146,17 @@ def __init__(self, channel: MqttChannel) -> None: async def query_values(self, protocols: list[RoborockZeoProtocol]) -> dict[RoborockZeoProtocol, Any]: """Query the device for the values of the given protocols.""" - params = {RoborockZeoProtocol.ID_QUERY: str([int(p) for p in protocols])} - response = await send_decoded_command(self._channel, params) + response = await send_decoded_command( + self._channel, + {RoborockZeoProtocol.ID_QUERY: protocols}, + value_encoder=json.dumps, + ) return {protocol: convert_zeo_value(protocol, response.get(protocol)) for protocol in protocols} async def set_value(self, protocol: RoborockZeoProtocol, value: Any) -> dict[RoborockZeoProtocol, Any]: """Set a value for a specific protocol on the device.""" params = {protocol: value} - return await send_decoded_command(self._channel, params) + return await send_decoded_command(self._channel, params, value_encoder=lambda x: x) def create(product: HomeDataProduct, mqtt_channel: MqttChannel) -> DyadApi | ZeoApi: diff --git a/roborock/protocols/a01_protocol.py b/roborock/protocols/a01_protocol.py index 5aa5ffb2..f3166de8 100644 --- a/roborock/protocols/a01_protocol.py +++ b/roborock/protocols/a01_protocol.py @@ -2,6 +2,7 @@ import json import logging +from collections.abc import Callable from typing import Any from Crypto.Cipher import AES @@ -20,13 +21,28 @@ A01_VERSION = b"A01" +def _no_encode(value: Any) -> Any: + return value + + def encode_mqtt_payload( data: dict[RoborockDyadDataProtocol, Any] | dict[RoborockZeoProtocol, Any] | dict[RoborockDyadDataProtocol | RoborockZeoProtocol, Any], + value_encoder: Callable[[Any], Any] | None = None, ) -> RoborockMessage: - """Encode payload for A01 commands over MQTT.""" - dps_data = {"dps": data} + """Encode payload for A01 commands over MQTT. + + Args: + data: The data to encode. + value_encoder: A function to encode the values of the dictionary. + + Returns: + RoborockMessage: The encoded message. + """ + if value_encoder is None: + value_encoder = _no_encode + dps_data = {"dps": {key: value_encoder(value) for key, value in data.items()}} payload = pad(json.dumps(dps_data).encode("utf-8"), AES.block_size) return RoborockMessage( protocol=RoborockMessageProtocol.RPC_REQUEST, diff --git a/tests/devices/test_a01_channel.py b/tests/devices/test_a01_channel.py index 70f1ade3..e5bf112a 100644 --- a/tests/devices/test_a01_channel.py +++ b/tests/devices/test_a01_channel.py @@ -34,7 +34,8 @@ async def test_id_query(mock_mqtt_channel: FakeChannel): { RoborockDyadDataProtocol.WARM_LEVEL: 101, RoborockDyadDataProtocol.POWER: 75, - } + }, + value_encoder=lambda x: x, ) response_message = RoborockMessage( protocol=RoborockMessageProtocol.RPC_RESPONSE, payload=encoded.payload, version=encoded.version diff --git a/tests/devices/traits/a01/test_init.py b/tests/devices/traits/a01/test_init.py index a8d2a8c2..9b9c83b3 100644 --- a/tests/devices/traits/a01/test_init.py +++ b/tests/devices/traits/a01/test_init.py @@ -1,43 +1,51 @@ import datetime -from collections.abc import Generator +import json from typing import Any -from unittest.mock import AsyncMock, call, patch import pytest +from Crypto.Cipher import AES +from Crypto.Util.Padding import unpad -from roborock.devices.mqtt_channel import MqttChannel from roborock.devices.traits.a01 import DyadApi, ZeoApi -from roborock.roborock_message import RoborockDyadDataProtocol, RoborockZeoProtocol +from roborock.roborock_message import RoborockDyadDataProtocol, RoborockMessageProtocol, RoborockZeoProtocol +from tests.conftest import FakeChannel +from tests.protocols.common import build_a01_message -@pytest.fixture(name="mock_channel") -def mock_channel_fixture() -> AsyncMock: - return AsyncMock(spec=MqttChannel) +@pytest.fixture(name="fake_channel") +def fake_channel_fixture() -> FakeChannel: + return FakeChannel() -@pytest.fixture(name="mock_send") -def mock_send_fixture(mock_channel) -> Generator[AsyncMock, None, None]: - with patch("roborock.devices.traits.a01.send_decoded_command") as mock_send: - yield mock_send +@pytest.fixture(name="dyad_api") +def dyad_api_fixture(fake_channel: FakeChannel) -> DyadApi: + return DyadApi(fake_channel) # type: ignore[arg-type] -async def test_dyad_api_query_values(mock_channel: AsyncMock, mock_send: AsyncMock): +@pytest.fixture(name="zeo_api") +def zeo_api_fixture(fake_channel: FakeChannel) -> ZeoApi: + return ZeoApi(fake_channel) # type: ignore[arg-type] + + +async def test_dyad_api_query_values(dyad_api: DyadApi, fake_channel: FakeChannel): """Test that DyadApi currently returns raw values without conversion.""" - api = DyadApi(mock_channel) - - mock_send.return_value = { - 209: 1, # POWER - 201: 6, # STATUS - 207: 3, # WATER_LEVEL - 214: 120, # MESH_LEFT - 215: 90, # BRUSH_LEFT - 227: 85, # SILENT_MODE_START_TIME - 229: "3,4,5", # RECENT_RUN_TIME - 230: 123456, # TOTAL_RUN_TIME - 222: 1, # STAND_LOCK_AUTO_RUN - 224: 0, # AUTO_DRY_MODE - } - result = await api.query_values( + fake_channel.response_queue.append( + build_a01_message( + { + 209: 1, # POWER + 201: 6, # STATUS + 207: 3, # WATER_LEVEL + 214: 120, # MESH_LEFT + 215: 90, # BRUSH_LEFT + 227: 85, # SILENT_MODE_START_TIME + 229: "3,4,5", # RECENT_RUN_TIME + 230: 123456, # TOTAL_RUN_TIME + 222: 1, # STAND_LOCK_AUTO_RUN + 224: 0, # AUTO_DRY_MODE + } + ) + ) + result = await dyad_api.query_values( [ RoborockDyadDataProtocol.POWER, RoborockDyadDataProtocol.STATUS, @@ -64,15 +72,12 @@ async def test_dyad_api_query_values(mock_channel: AsyncMock, mock_send: AsyncMo RoborockDyadDataProtocol.AUTO_DRY_MODE: False, } - # Note: Bug here, this is the wrong encoding for the query - assert mock_send.call_args_list == [ - call( - mock_channel, - { - RoborockDyadDataProtocol.ID_QUERY: "[209, 201, 207, 214, 215, 227, 229, 230, 222, 224]", - }, - ), - ] + assert len(fake_channel.published_messages) == 1 + message = fake_channel.published_messages[0] + assert message.protocol == RoborockMessageProtocol.RPC_REQUEST + assert message.version == b"A01" + payload_data = json.loads(unpad(message.payload, AES.block_size)) + assert payload_data == {"dps": {"10000": "[209, 201, 207, 214, 215, 227, 229, 230, 222, 224]"}} @pytest.mark.parametrize( @@ -117,33 +122,34 @@ async def test_dyad_api_query_values(mock_channel: AsyncMock, mock_send: AsyncMo ], ) async def test_dyad_invalid_response_value( - mock_channel: AsyncMock, - mock_send: AsyncMock, query: list[RoborockDyadDataProtocol], response: dict[int, Any], expected_result: dict[RoborockDyadDataProtocol, Any], + dyad_api: DyadApi, + fake_channel: FakeChannel, ): """Test that DyadApi currently returns raw values without conversion.""" - api = DyadApi(mock_channel) + fake_channel.response_queue.append(build_a01_message(response)) - mock_send.return_value = response - result = await api.query_values(query) + result = await dyad_api.query_values(query) assert result == expected_result -async def test_zeo_api_query_values(mock_channel: AsyncMock, mock_send: AsyncMock): +async def test_zeo_api_query_values(zeo_api: ZeoApi, fake_channel: FakeChannel): """Test that ZeoApi currently returns raw values without conversion.""" - api = ZeoApi(mock_channel) - - mock_send.return_value = { - 203: 6, # spinning - 207: 3, # medium - 226: 1, - 227: 0, - 224: 1, # Times after clean. Testing int value - 218: 0, # Washing left. Testing zero int value - } - result = await api.query_values( + fake_channel.response_queue.append( + build_a01_message( + { + 203: 6, # spinning + 207: 3, # medium + 226: 1, + 227: 0, + 224: 1, # Times after clean. Testing int value + 218: 0, # Washing left. Testing zero int value + } + ) + ) + result = await zeo_api.query_values( [ RoborockZeoProtocol.STATE, RoborockZeoProtocol.TEMP, @@ -162,15 +168,13 @@ async def test_zeo_api_query_values(mock_channel: AsyncMock, mock_send: AsyncMoc RoborockZeoProtocol.TIMES_AFTER_CLEAN: 1, RoborockZeoProtocol.WASHING_LEFT: 0, } - # Note: Bug here, this is the wrong encoding for the query - assert mock_send.call_args_list == [ - call( - mock_channel, - { - RoborockZeoProtocol.ID_QUERY: "[203, 207, 226, 227, 224, 218]", - }, - ), - ] + + assert len(fake_channel.published_messages) == 1 + message = fake_channel.published_messages[0] + assert message.protocol == RoborockMessageProtocol.RPC_REQUEST + assert message.version == b"A01" + payload_data = json.loads(unpad(message.payload, AES.block_size)) + assert payload_data == {"dps": {"10000": "[203, 207, 226, 227, 224, 218]"}} @pytest.mark.parametrize( @@ -215,15 +219,46 @@ async def test_zeo_api_query_values(mock_channel: AsyncMock, mock_send: AsyncMoc ], ) async def test_zeo_invalid_response_value( - mock_channel: AsyncMock, - mock_send: AsyncMock, query: list[RoborockZeoProtocol], response: dict[int, Any], expected_result: dict[RoborockZeoProtocol, Any], + zeo_api: ZeoApi, + fake_channel: FakeChannel, ): """Test that ZeoApi currently returns raw values without conversion.""" - api = ZeoApi(mock_channel) + fake_channel.response_queue.append(build_a01_message(response)) - mock_send.return_value = response - result = await api.query_values(query) + result = await zeo_api.query_values(query) assert result == expected_result + + +async def test_dyad_api_set_value(dyad_api: DyadApi, fake_channel: FakeChannel): + """Test DyadApi set_value sends correct command.""" + await dyad_api.set_value(RoborockDyadDataProtocol.POWER, 1) + + assert len(fake_channel.published_messages) == 1 + message = fake_channel.published_messages[0] + + assert message.protocol == RoborockMessageProtocol.RPC_REQUEST + assert message.version == b"A01" + + # decode the payload to verify contents + payload_data = json.loads(unpad(message.payload, AES.block_size)) + # A01 protocol expects values to be strings in the dps dict + assert payload_data == {"dps": {"209": 1}} + + +async def test_zeo_api_set_value(zeo_api: ZeoApi, fake_channel: FakeChannel): + """Test ZeoApi set_value sends correct command.""" + await zeo_api.set_value(RoborockZeoProtocol.MODE, "standard") + + assert len(fake_channel.published_messages) == 1 + message = fake_channel.published_messages[0] + + assert message.protocol == RoborockMessageProtocol.RPC_REQUEST + assert message.version == b"A01" + + # decode the payload to verify contents + payload_data = json.loads(unpad(message.payload, AES.block_size)) + # A01 protocol expects values to be strings in the dps dict + assert payload_data == {"dps": {"204": "standard"}} diff --git a/tests/protocols/common.py b/tests/protocols/common.py new file mode 100644 index 00000000..0c1411da --- /dev/null +++ b/tests/protocols/common.py @@ -0,0 +1,26 @@ +"""Common test utils for the protocols package.""" + +import json +from typing import Any + +from Crypto.Cipher import AES +from Crypto.Util.Padding import pad + +from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol + + +def build_a01_message(message: dict[Any, Any], seq: int = 2020) -> RoborockMessage: + """Build an encoded A01 RPC response message.""" + return RoborockMessage( + protocol=RoborockMessageProtocol.RPC_RESPONSE, + payload=pad( + json.dumps( + { + "dps": message, # {10000: json.dumps(message)}, + } + ).encode(), + AES.block_size, + ), + version=b"A01", + seq=seq, + ) diff --git a/tests/protocols/test_a01_protocol.py b/tests/protocols/test_a01_protocol.py index 57ad6895..f5dddd55 100644 --- a/tests/protocols/test_a01_protocol.py +++ b/tests/protocols/test_a01_protocol.py @@ -4,6 +4,8 @@ from typing import Any import pytest +from Crypto.Cipher import AES +from Crypto.Util.Padding import unpad from roborock.exceptions import RoborockException from roborock.protocols.a01_protocol import decode_rpc_response, encode_mqtt_payload @@ -34,7 +36,7 @@ def test_encode_mqtt_payload_basic(): # Decode the payload to verify structure decoded_data = decode_rpc_response(result) - assert decoded_data == {200: {"test": "data", "number": 42}} + assert decoded_data == {200: '{"test": "data", "number": 42}'} def test_encode_mqtt_payload_empty_data(): @@ -52,6 +54,21 @@ def test_encode_mqtt_payload_empty_data(): assert decoded_data == {} +def test_encode_mqtt_payload_list_conversion(): + """Test that lists are converted to string representation (Fix validity).""" + # This verifies the fix where lists must be encoded as strings + data: dict[RoborockDyadDataProtocol | RoborockZeoProtocol, Any] = {RoborockDyadDataProtocol.ID_QUERY: [101, 102]} + + result = encode_mqtt_payload(data) + + # Decode manually to check the raw JSON structure + decoded_json = json.loads(unpad(result.payload, AES.block_size).decode()) + + # ID_QUERY (10000) should be a string "[101, 102]", not a list [101, 102] + assert decoded_json["dps"]["10000"] == "[101, 102]" + assert isinstance(decoded_json["dps"]["10000"], str) + + def test_encode_mqtt_payload_complex_data(): """Test encoding with complex nested data.""" data: dict[RoborockDyadDataProtocol | RoborockZeoProtocol, Any] = { @@ -74,13 +91,17 @@ def test_encode_mqtt_payload_complex_data(): # Decode the payload to verify structure decoded_data = decode_rpc_response(result) assert decoded_data == { - 201: { - "nested": {"deep": {"value": 123}}, - "list": [1, 2, 3, "test"], - "boolean": True, - "null": None, - }, - 204: "simple_value", + 201: json.dumps( + { + "nested": {"deep": {"value": 123}}, + # Note: The list inside the dictionary is NOT converted because + # our fix only targets top-level list values in the dps map + "list": [1, 2, 3, "test"], + "boolean": True, + "null": None, + } + ), + 204: '"simple_value"', } diff --git a/tests/test_a01_api.py b/tests/test_a01_api.py index f582f5f2..370ec6cc 100644 --- a/tests/test_a01_api.py +++ b/tests/test_a01_api.py @@ -1,5 +1,4 @@ import asyncio -import json from collections.abc import AsyncGenerator from queue import Queue from typing import Any @@ -7,8 +6,6 @@ import paho.mqtt.client as mqtt import pytest -from Crypto.Cipher import AES -from Crypto.Util.Padding import pad from roborock import ( HomeData, @@ -19,12 +16,13 @@ from roborock.protocol import MessageParser from roborock.roborock_message import ( RoborockDyadDataProtocol, - RoborockMessage, - RoborockMessageProtocol, RoborockZeoProtocol, ) from roborock.version_a01_apis import RoborockMqttClientA01 -from tests.mock_data import ( + +from . import mqtt_packet +from .conftest import QUEUE_TIMEOUT +from .mock_data import ( HOME_DATA_RAW, LOCAL_KEY, MQTT_PUBLISH_TOPIC, @@ -32,9 +30,7 @@ WASHER_PRODUCT, ZEO_ONE_DEVICE, ) - -from . import mqtt_packet -from .conftest import QUEUE_TIMEOUT +from .protocols.common import build_a01_message RELEASE_TIMEOUT = 2 @@ -170,24 +166,7 @@ async def test_subscribe_failure( def build_rpc_response(message: dict[Any, Any]) -> bytes: """Build an encoded RPC response message.""" - return MessageParser.build( - [ - RoborockMessage( - protocol=RoborockMessageProtocol.RPC_RESPONSE, - payload=pad( - json.dumps( - { - "dps": message, # {10000: json.dumps(message)}, - } - ).encode(), - AES.block_size, - ), - version=b"A01", - seq=2020, - ), - ], - local_key=LOCAL_KEY, - ) + return MessageParser.build([build_a01_message(message)], local_key=LOCAL_KEY) async def test_update_zeo_values( From 0241037a9b0481701b99b61e87f74924b92e226f Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sun, 7 Dec 2025 13:54:30 -0800 Subject: [PATCH 2/3] chore: fix tests to be focused on value encoder --- tests/protocols/test_a01_protocol.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/tests/protocols/test_a01_protocol.py b/tests/protocols/test_a01_protocol.py index f5dddd55..f6a5861a 100644 --- a/tests/protocols/test_a01_protocol.py +++ b/tests/protocols/test_a01_protocol.py @@ -36,7 +36,7 @@ def test_encode_mqtt_payload_basic(): # Decode the payload to verify structure decoded_data = decode_rpc_response(result) - assert decoded_data == {200: '{"test": "data", "number": 42}'} + assert decoded_data == {200: {"test": "data", "number": 42}} def test_encode_mqtt_payload_empty_data(): @@ -54,12 +54,11 @@ def test_encode_mqtt_payload_empty_data(): assert decoded_data == {} -def test_encode_mqtt_payload_list_conversion(): - """Test that lists are converted to string representation (Fix validity).""" - # This verifies the fix where lists must be encoded as strings +def test_value_encoder(): + """Test that value_encoder is applied to all values.""" data: dict[RoborockDyadDataProtocol | RoborockZeoProtocol, Any] = {RoborockDyadDataProtocol.ID_QUERY: [101, 102]} - result = encode_mqtt_payload(data) + result = encode_mqtt_payload(data, value_encoder=json.dumps) # Decode manually to check the raw JSON structure decoded_json = json.loads(unpad(result.payload, AES.block_size).decode()) @@ -91,17 +90,15 @@ def test_encode_mqtt_payload_complex_data(): # Decode the payload to verify structure decoded_data = decode_rpc_response(result) assert decoded_data == { - 201: json.dumps( - { + 201: { "nested": {"deep": {"value": 123}}, # Note: The list inside the dictionary is NOT converted because # our fix only targets top-level list values in the dps map "list": [1, 2, 3, "test"], "boolean": True, "null": None, - } - ), - 204: '"simple_value"', + }, + 204: "simple_value", } From de7c98ca90045520cfc958871593fbd6703c592f Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sun, 7 Dec 2025 13:55:25 -0800 Subject: [PATCH 3/3] chore: fix lint --- tests/protocols/test_a01_protocol.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/protocols/test_a01_protocol.py b/tests/protocols/test_a01_protocol.py index f6a5861a..f60cfdff 100644 --- a/tests/protocols/test_a01_protocol.py +++ b/tests/protocols/test_a01_protocol.py @@ -91,12 +91,12 @@ def test_encode_mqtt_payload_complex_data(): decoded_data = decode_rpc_response(result) assert decoded_data == { 201: { - "nested": {"deep": {"value": 123}}, - # Note: The list inside the dictionary is NOT converted because - # our fix only targets top-level list values in the dps map - "list": [1, 2, 3, "test"], - "boolean": True, - "null": None, + "nested": {"deep": {"value": 123}}, + # Note: The list inside the dictionary is NOT converted because + # our fix only targets top-level list values in the dps map + "list": [1, 2, 3, "test"], + "boolean": True, + "null": None, }, 204: "simple_value", }