Skip to content

Commit 8925d66

Browse files
committed
chore: tests
1 parent 5ac9718 commit 8925d66

File tree

4 files changed

+208
-2
lines changed

4 files changed

+208
-2
lines changed

roborock/protocols/b01_protocol.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,9 @@ def decode_rpc_response(message: RoborockMessage) -> dict[int, Any]:
4747
raise RoborockException("Invalid B01 message format: missing payload")
4848
try:
4949
unpadded = unpad(message.payload, AES.block_size)
50-
except ValueError as err:
51-
raise RoborockException(f"Unable to unpad B01 payload: {err}")
50+
except ValueError:
51+
# It would be better to fail down the line.
52+
unpadded = message.payload
5253

5354
try:
5455
payload = json.loads(unpadded.decode())
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
import json
2+
from typing import Any
3+
from unittest.mock import patch
4+
5+
import pytest
6+
from Crypto.Cipher import AES
7+
from Crypto.Util.Padding import pad, unpad
8+
9+
from roborock.data.b01_q7 import WorkStatusMapping
10+
from roborock.devices.traits.b01.q7 import Q7PropertiesApi
11+
from roborock.protocols.b01_protocol import B01_VERSION
12+
from roborock.roborock_message import RoborockB01Props, RoborockMessage, RoborockMessageProtocol
13+
from tests.conftest import FakeChannel
14+
15+
16+
def build_b01_message(message: dict[Any, Any], msg_id: str = "123456789", seq: int = 2020) -> RoborockMessage:
17+
"""Build an encoded B01 RPC response message."""
18+
dps_payload = {
19+
"dps": {
20+
"10000": json.dumps(
21+
{
22+
"msgId": msg_id,
23+
"data": message,
24+
}
25+
)
26+
}
27+
}
28+
return RoborockMessage(
29+
protocol=RoborockMessageProtocol.RPC_RESPONSE,
30+
payload=pad(
31+
json.dumps(dps_payload).encode(),
32+
AES.block_size,
33+
),
34+
version=b"B01",
35+
seq=seq,
36+
)
37+
38+
39+
@pytest.fixture(name="fake_channel")
40+
def fake_channel_fixture() -> FakeChannel:
41+
return FakeChannel()
42+
43+
44+
@pytest.fixture(name="q7_api")
45+
def q7_api_fixture(fake_channel: FakeChannel) -> Q7PropertiesApi:
46+
return Q7PropertiesApi(fake_channel) # type: ignore[arg-type]
47+
48+
49+
async def test_q7_api_query_values(q7_api: Q7PropertiesApi, fake_channel: FakeChannel):
50+
"""Test that Q7PropertiesApi correctly converts raw values."""
51+
expected_msg_id = "123456789"
52+
53+
# We need to construct the expected result based on the mappings
54+
# status: 1 -> WAITING_FOR_ORDERS
55+
# wind: 1 -> STANDARD
56+
response_data = {
57+
"status": 1,
58+
"wind": 1,
59+
"battery": 100,
60+
}
61+
62+
# Patch get_next_int to return our expected msg_id so the channel waits for it
63+
with patch("roborock.devices.b01_channel.get_next_int", return_value=int(expected_msg_id)):
64+
# Queue the response
65+
fake_channel.response_queue.append(build_b01_message(response_data, msg_id=expected_msg_id))
66+
67+
result = await q7_api.query_values(
68+
[
69+
RoborockB01Props.STATUS,
70+
RoborockB01Props.WIND,
71+
]
72+
)
73+
74+
assert result is not None
75+
assert result.status == WorkStatusMapping.WAITING_FOR_ORDERS
76+
# wind might be mapped to SCWindMapping.STANDARD (1)
77+
# let's verify checking the prop definition in B01Props
78+
# wind: SCWindMapping | None = None
79+
# SCWindMapping.STANDARD is 1 ('balanced')
80+
from roborock.data.b01_q7 import SCWindMapping
81+
82+
assert result.wind == SCWindMapping.STANDARD
83+
84+
assert len(fake_channel.published_messages) == 1
85+
message = fake_channel.published_messages[0]
86+
assert message.protocol == RoborockMessageProtocol.RPC_REQUEST
87+
assert message.version == B01_VERSION
88+
89+
# Verify request payload
90+
payload_data = json.loads(unpad(message.payload, AES.block_size))
91+
# {"dps": {"10000": {"method": "prop.get", "msgId": "123456789", "params": {"property": ["status", "wind"]}}}}
92+
assert "dps" in payload_data
93+
assert "10000" in payload_data["dps"]
94+
inner = payload_data["dps"]["10000"]
95+
assert inner["method"] == "prop.get"
96+
assert inner["msgId"] == expected_msg_id
97+
assert inner["params"] == {"property": [RoborockB01Props.STATUS, RoborockB01Props.WIND]}
98+
99+
100+
@pytest.mark.parametrize(
101+
("query", "response_data", "expected_status"),
102+
[
103+
(
104+
[RoborockB01Props.STATUS],
105+
{"status": 2},
106+
WorkStatusMapping.PAUSED,
107+
),
108+
(
109+
[RoborockB01Props.STATUS],
110+
{"status": 5},
111+
WorkStatusMapping.SWEEP_MOPING,
112+
),
113+
],
114+
)
115+
async def test_q7_response_value_mapping(
116+
query: list[RoborockB01Props],
117+
response_data: dict[str, Any],
118+
expected_status: WorkStatusMapping,
119+
q7_api: Q7PropertiesApi,
120+
fake_channel: FakeChannel,
121+
):
122+
"""Test Q7PropertiesApi value mapping for different statuses."""
123+
msg_id = "987654321"
124+
125+
with patch("roborock.devices.b01_channel.get_next_int", return_value=int(msg_id)):
126+
fake_channel.response_queue.append(build_b01_message(response_data, msg_id=msg_id))
127+
128+
result = await q7_api.query_values(query)
129+
130+
assert result is not None
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
"""Tests for the B01 protocol message encoding and decoding."""
2+
3+
import json
4+
import pathlib
5+
from collections.abc import Generator
6+
7+
import pytest
8+
from Crypto.Cipher import AES
9+
from Crypto.Util.Padding import unpad
10+
from freezegun import freeze_time
11+
from syrupy import SnapshotAssertion
12+
13+
from roborock.protocols.b01_protocol import (
14+
decode_rpc_response,
15+
encode_mqtt_payload,
16+
)
17+
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
18+
19+
TESTDATA_PATH = pathlib.Path("tests/protocols/testdata/b01_protocol")
20+
TESTDATA_FILES = list(TESTDATA_PATH.glob("**/*.json"))
21+
TESTDATA_IDS = [x.stem for x in TESTDATA_FILES]
22+
23+
24+
@pytest.fixture(autouse=True)
25+
def fixed_time_fixture() -> Generator[None, None, None]:
26+
"""Fixture to freeze time for predictable request IDs."""
27+
with freeze_time("2025-01-20T12:00:00"):
28+
yield
29+
30+
31+
@pytest.mark.parametrize("filename", TESTDATA_FILES, ids=TESTDATA_IDS)
32+
def test_decode_rpc_payload(filename: str, snapshot: SnapshotAssertion) -> None:
33+
"""Test decoding a B01 RPC response protocol message."""
34+
with open(filename, "rb") as f:
35+
payload = f.read()
36+
37+
message = RoborockMessage(
38+
protocol=RoborockMessageProtocol.RPC_RESPONSE,
39+
payload=payload,
40+
seq=12750,
41+
version=b"B01",
42+
random=97431,
43+
timestamp=1652547161,
44+
)
45+
46+
decoded_message = decode_rpc_response(message)
47+
assert json.dumps(decoded_message, indent=2) == snapshot
48+
49+
50+
@pytest.mark.parametrize(
51+
("dps", "command", "params", "msg_id"),
52+
[
53+
(
54+
10000,
55+
"prop.get",
56+
{"property": ["status", "fault"]},
57+
"123456789",
58+
),
59+
],
60+
)
61+
def test_encode_mqtt_payload(dps: int, command: str, params: dict[str, list[str]], msg_id: str) -> None:
62+
"""Test encoding of MQTT payload for B01 commands."""
63+
64+
message = encode_mqtt_payload(dps, command, params, msg_id)
65+
assert isinstance(message, RoborockMessage)
66+
assert message.protocol == RoborockMessageProtocol.RPC_REQUEST
67+
assert message.version == b"B01"
68+
assert message.payload is not None
69+
unpadded = unpad(message.payload, AES.block_size)
70+
decoded_json = json.loads(unpadded.decode("utf-8"))
71+
72+
assert decoded_json["dps"][str(dps)]["method"] == command
73+
assert decoded_json["dps"][str(dps)]["msgId"] == msg_id
74+
assert decoded_json["dps"][str(dps)]["params"] == params
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"t":1765660648,"dps":{"10001":"{\"msgId\":\"200000000001\",\"code\":0,\"method\":\"prop.get\",\"data\":{\"status\":4,\"main_brush\":4088}}"}}

0 commit comments

Comments
 (0)