From bc964fd0dc06ff0360605f87c5174ea8daf24d3d Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sat, 6 Dec 2025 20:55:38 -0800 Subject: [PATCH 1/4] fix: Encode a01 values as json strings Update the a01 internal values to be json strings inside the dictionary. This matches the old API behavior. --- roborock/protocols/a01_protocol.py | 5 +++- tests/protocols/test_a01_protocol.py | 38 ++++++++++++++++++++++------ 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/roborock/protocols/a01_protocol.py b/roborock/protocols/a01_protocol.py index 5aa5ffb2..854618e9 100644 --- a/roborock/protocols/a01_protocol.py +++ b/roborock/protocols/a01_protocol.py @@ -26,7 +26,10 @@ def encode_mqtt_payload( | dict[RoborockDyadDataProtocol | RoborockZeoProtocol, Any], ) -> RoborockMessage: """Encode payload for A01 commands over MQTT.""" - dps_data = {"dps": data} + # The A01 protocol generally expects values to be encoded as strings. + # We use json.dumps for non-string types to ensure valid JSON formatting + # (e.g. [1, 2] -> "[1, 2]", True -> "true", 123 -> "123"). + dps_data = {"dps": {key: json.dumps(value) for key, value in data.items()}} payload = pad(json.dumps(dps_data).encode("utf-8"), AES.block_size) return RoborockMessage( protocol=RoborockMessageProtocol.RPC_REQUEST, diff --git a/tests/protocols/test_a01_protocol.py b/tests/protocols/test_a01_protocol.py index 57ad6895..6fec7ca5 100644 --- a/tests/protocols/test_a01_protocol.py +++ b/tests/protocols/test_a01_protocol.py @@ -4,6 +4,8 @@ from typing import Any import pytest +from Crypto.Cipher import AES +from Crypto.Util.Padding import unpad from roborock.exceptions import RoborockException from roborock.protocols.a01_protocol import decode_rpc_response, encode_mqtt_payload @@ -33,8 +35,9 @@ def test_encode_mqtt_payload_basic(): assert len(result.payload) % 16 == 0 # Should be padded to AES block size # Decode the payload to verify structure + # With general stringification, numbers are converted to strings: 42 -> "42" decoded_data = decode_rpc_response(result) - assert decoded_data == {200: {"test": "data", "number": 42}} + assert decoded_data == {200: '{"test": "data", "number": 42}'} def test_encode_mqtt_payload_empty_data(): @@ -52,6 +55,21 @@ def test_encode_mqtt_payload_empty_data(): assert decoded_data == {} +def test_encode_mqtt_payload_list_conversion(): + """Test that lists are converted to string representation (Fix validity).""" + # This verifies the fix where lists must be encoded as strings + data: dict[RoborockDyadDataProtocol | RoborockZeoProtocol, Any] = {RoborockDyadDataProtocol.ID_QUERY: [101, 102]} + + result = encode_mqtt_payload(data) + + # Decode manually to check the raw JSON structure + decoded_json = json.loads(unpad(result.payload, AES.block_size).decode()) + + # ID_QUERY (10000) should be a string "[101, 102]", not a list [101, 102] + assert decoded_json["dps"]["10000"] == "[101, 102]" + assert isinstance(decoded_json["dps"]["10000"], str) + + def test_encode_mqtt_payload_complex_data(): """Test encoding with complex nested data.""" data: dict[RoborockDyadDataProtocol | RoborockZeoProtocol, Any] = { @@ -74,13 +92,17 @@ def test_encode_mqtt_payload_complex_data(): # Decode the payload to verify structure decoded_data = decode_rpc_response(result) assert decoded_data == { - 201: { - "nested": {"deep": {"value": 123}}, - "list": [1, 2, 3, "test"], - "boolean": True, - "null": None, - }, - 204: "simple_value", + 201: json.dumps( + { + "nested": {"deep": {"value": 123}}, + # Note: The list inside the dictionary is NOT converted because + # our fix only targets top-level list values in the dps map + "list": [1, 2, 3, "test"], + "boolean": True, + "null": None, + } + ), + 204: '"simple_value"', } From 87ce3ea7199bfa43b0de73a136eea5ef551f5fa5 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sun, 7 Dec 2025 09:57:32 -0800 Subject: [PATCH 2/4] fix: Update where the string conversion happens --- roborock/devices/traits/a01/__init__.py | 4 +-- roborock/protocols/a01_protocol.py | 5 +--- tests/devices/test_a01_channel.py | 5 +++- tests/protocols/test_a01_protocol.py | 36 ++++++------------------- 4 files changed, 15 insertions(+), 35 deletions(-) diff --git a/roborock/devices/traits/a01/__init__.py b/roborock/devices/traits/a01/__init__.py index afa884e1..e804f23c 100644 --- a/roborock/devices/traits/a01/__init__.py +++ b/roborock/devices/traits/a01/__init__.py @@ -21,7 +21,7 @@ def __init__(self, channel: MqttChannel) -> None: async def query_values(self, protocols: list[RoborockDyadDataProtocol]) -> dict[RoborockDyadDataProtocol, Any]: """Query the device for the values of the given Dyad protocols.""" - params = {RoborockDyadDataProtocol.ID_QUERY: [int(p) for p in protocols]} + params = {RoborockDyadDataProtocol.ID_QUERY: str([int(p) for p in protocols])} return await send_decoded_command(self._channel, params) async def set_value(self, protocol: RoborockDyadDataProtocol, value: Any) -> dict[RoborockDyadDataProtocol, Any]: @@ -41,7 +41,7 @@ def __init__(self, channel: MqttChannel) -> None: async def query_values(self, protocols: list[RoborockZeoProtocol]) -> dict[RoborockZeoProtocol, Any]: """Query the device for the values of the given protocols.""" - params = {RoborockZeoProtocol.ID_QUERY: [int(p) for p in protocols]} + params = {RoborockZeoProtocol.ID_QUERY: str([int(p) for p in protocols])} return await send_decoded_command(self._channel, params) async def set_value(self, protocol: RoborockZeoProtocol, value: Any) -> dict[RoborockZeoProtocol, Any]: diff --git a/roborock/protocols/a01_protocol.py b/roborock/protocols/a01_protocol.py index 854618e9..5aa5ffb2 100644 --- a/roborock/protocols/a01_protocol.py +++ b/roborock/protocols/a01_protocol.py @@ -26,10 +26,7 @@ def encode_mqtt_payload( | dict[RoborockDyadDataProtocol | RoborockZeoProtocol, Any], ) -> RoborockMessage: """Encode payload for A01 commands over MQTT.""" - # The A01 protocol generally expects values to be encoded as strings. - # We use json.dumps for non-string types to ensure valid JSON formatting - # (e.g. [1, 2] -> "[1, 2]", True -> "true", 123 -> "123"). - dps_data = {"dps": {key: json.dumps(value) for key, value in data.items()}} + dps_data = {"dps": data} payload = pad(json.dumps(dps_data).encode("utf-8"), AES.block_size) return RoborockMessage( protocol=RoborockMessageProtocol.RPC_REQUEST, diff --git a/tests/devices/test_a01_channel.py b/tests/devices/test_a01_channel.py index 1939f16d..70f1ade3 100644 --- a/tests/devices/test_a01_channel.py +++ b/tests/devices/test_a01_channel.py @@ -45,6 +45,9 @@ async def test_id_query(mock_mqtt_channel: FakeChannel): result = await send_decoded_command(mock_mqtt_channel, params) # type: ignore[call-overload] # Assertions - assert result == {RoborockDyadDataProtocol.WARM_LEVEL: 101, RoborockDyadDataProtocol.POWER: 75} + assert result == { + RoborockDyadDataProtocol.WARM_LEVEL: 101, + RoborockDyadDataProtocol.POWER: 75, + } mock_mqtt_channel.publish.assert_awaited_once() mock_mqtt_channel.subscribe.assert_awaited_once() diff --git a/tests/protocols/test_a01_protocol.py b/tests/protocols/test_a01_protocol.py index 6fec7ca5..0ec258dd 100644 --- a/tests/protocols/test_a01_protocol.py +++ b/tests/protocols/test_a01_protocol.py @@ -35,9 +35,8 @@ def test_encode_mqtt_payload_basic(): assert len(result.payload) % 16 == 0 # Should be padded to AES block size # Decode the payload to verify structure - # With general stringification, numbers are converted to strings: 42 -> "42" decoded_data = decode_rpc_response(result) - assert decoded_data == {200: '{"test": "data", "number": 42}'} + assert decoded_data == {200: {"test": "data", "number": 42}} def test_encode_mqtt_payload_empty_data(): @@ -55,21 +54,6 @@ def test_encode_mqtt_payload_empty_data(): assert decoded_data == {} -def test_encode_mqtt_payload_list_conversion(): - """Test that lists are converted to string representation (Fix validity).""" - # This verifies the fix where lists must be encoded as strings - data: dict[RoborockDyadDataProtocol | RoborockZeoProtocol, Any] = {RoborockDyadDataProtocol.ID_QUERY: [101, 102]} - - result = encode_mqtt_payload(data) - - # Decode manually to check the raw JSON structure - decoded_json = json.loads(unpad(result.payload, AES.block_size).decode()) - - # ID_QUERY (10000) should be a string "[101, 102]", not a list [101, 102] - assert decoded_json["dps"]["10000"] == "[101, 102]" - assert isinstance(decoded_json["dps"]["10000"], str) - - def test_encode_mqtt_payload_complex_data(): """Test encoding with complex nested data.""" data: dict[RoborockDyadDataProtocol | RoborockZeoProtocol, Any] = { @@ -92,17 +76,13 @@ def test_encode_mqtt_payload_complex_data(): # Decode the payload to verify structure decoded_data = decode_rpc_response(result) assert decoded_data == { - 201: json.dumps( - { - "nested": {"deep": {"value": 123}}, - # Note: The list inside the dictionary is NOT converted because - # our fix only targets top-level list values in the dps map - "list": [1, 2, 3, "test"], - "boolean": True, - "null": None, - } - ), - 204: '"simple_value"', + 201: { + "nested": {"deep": {"value": 123}}, + "list": [1, 2, 3, "test"], + "boolean": True, + "null": None, + }, + 204: "simple_value", } From 89f01090e1b86d5633826615836d6912e6315aa0 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sun, 7 Dec 2025 09:58:45 -0800 Subject: [PATCH 3/4] chore: Remove unnecessary imports --- tests/protocols/test_a01_protocol.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/protocols/test_a01_protocol.py b/tests/protocols/test_a01_protocol.py index 0ec258dd..57ad6895 100644 --- a/tests/protocols/test_a01_protocol.py +++ b/tests/protocols/test_a01_protocol.py @@ -4,8 +4,6 @@ from typing import Any import pytest -from Crypto.Cipher import AES -from Crypto.Util.Padding import unpad from roborock.exceptions import RoborockException from roborock.protocols.a01_protocol import decode_rpc_response, encode_mqtt_payload From 60208c2c77032ca433de50944d9f0a40f62ce721 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sun, 7 Dec 2025 10:32:31 -0800 Subject: [PATCH 4/4] chore: update tests to capture bug fix --- tests/devices/traits/a01/test_init.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/devices/traits/a01/test_init.py b/tests/devices/traits/a01/test_init.py index 366d8ddf..05601eff 100644 --- a/tests/devices/traits/a01/test_init.py +++ b/tests/devices/traits/a01/test_init.py @@ -69,7 +69,7 @@ async def test_dyad_api_query_values(mock_channel: AsyncMock, mock_send: AsyncMo call( mock_channel, { - RoborockDyadDataProtocol.ID_QUERY: [209, 201, 207, 214, 215, 227, 229, 230, 222, 224], + RoborockDyadDataProtocol.ID_QUERY: "[209, 201, 207, 214, 215, 227, 229, 230, 222, 224]", }, ), ] @@ -176,7 +176,7 @@ async def test_zeo_api_query_values(mock_channel: AsyncMock, mock_send: AsyncMoc call( mock_channel, { - RoborockZeoProtocol.ID_QUERY: [203, 207, 226, 227, 224, 218], + RoborockZeoProtocol.ID_QUERY: "[203, 207, 226, 227, 224, 218]", }, ), ]