Skip to content

Commit 5a2dac0

Browse files
authored
chore: Move a01 encoding and decoding to a separate module (#417)
* chore: Move a01 encoding and decoding to a separate module * chore: Remove logging code * chore: Revert some logging changes * chore: Remove stale comment in roborock_client_a01.py
1 parent 7c1e3aa commit 5a2dac0

File tree

6 files changed

+391
-100
lines changed

6 files changed

+391
-100
lines changed

roborock/protocols/a01_protocol.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
"""Roborock A01 Protocol encoding and decoding."""
2+
3+
import json
4+
import logging
5+
from typing import Any
6+
7+
from Crypto.Cipher import AES
8+
from Crypto.Util.Padding import pad, unpad
9+
10+
from roborock.exceptions import RoborockException
11+
from roborock.roborock_message import (
12+
RoborockDyadDataProtocol,
13+
RoborockMessage,
14+
RoborockMessageProtocol,
15+
RoborockZeoProtocol,
16+
)
17+
18+
_LOGGER = logging.getLogger(__name__)
19+
20+
A01_VERSION = b"A01"
21+
22+
23+
def encode_mqtt_payload(data: dict[RoborockDyadDataProtocol | RoborockZeoProtocol, Any]) -> RoborockMessage:
24+
"""Encode payload for A01 commands over MQTT."""
25+
dps_data = {"dps": data}
26+
payload = pad(json.dumps(dps_data).encode("utf-8"), AES.block_size)
27+
return RoborockMessage(
28+
protocol=RoborockMessageProtocol.RPC_REQUEST,
29+
version=A01_VERSION,
30+
payload=payload,
31+
)
32+
33+
34+
def decode_rpc_response(message: RoborockMessage) -> dict[int, Any]:
35+
"""Decode a V1 RPC_RESPONSE message."""
36+
if not message.payload:
37+
raise RoborockException("Invalid A01 message format: missing payload")
38+
try:
39+
unpadded = unpad(message.payload, AES.block_size)
40+
except ValueError as err:
41+
raise RoborockException(f"Unable to unpad A01 payload: {err}")
42+
43+
try:
44+
payload = json.loads(unpadded.decode())
45+
except (json.JSONDecodeError, TypeError) as e:
46+
raise RoborockException(f"Invalid A01 message payload: {e} for {message.payload!r}") from e
47+
48+
datapoints = payload.get("dps", {})
49+
if not isinstance(datapoints, dict):
50+
raise RoborockException(f"Invalid A01 message format: 'dps' should be a dictionary for {message.payload!r}")
51+
try:
52+
return {int(key): value for key, value in datapoints.items()}
53+
except ValueError:
54+
raise RoborockException(f"Invalid A01 message format: 'dps' key should be an integer for {message.payload!r}")
Lines changed: 86 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,8 @@
1-
import dataclasses
2-
import json
31
import logging
4-
import typing
52
from abc import ABC, abstractmethod
63
from collections.abc import Callable
74
from datetime import time
8-
9-
from Crypto.Cipher import AES
10-
from Crypto.Util.Padding import unpad
5+
from typing import Any
116

127
from roborock import DeviceData
138
from roborock.api import RoborockClient
@@ -33,6 +28,8 @@
3328
ZeoTemperature,
3429
)
3530
from roborock.containers import DyadProductInfo, DyadSndState, RoborockCategory
31+
from roborock.exceptions import RoborockException
32+
from roborock.protocols.a01_protocol import decode_rpc_response
3633
from roborock.roborock_message import (
3734
RoborockDyadDataProtocol,
3835
RoborockMessage,
@@ -43,111 +40,120 @@
4340
_LOGGER = logging.getLogger(__name__)
4441

4542

46-
@dataclasses.dataclass
47-
class A01ProtocolCacheEntry:
48-
post_process_fn: Callable
49-
value: typing.Any | None = None
50-
51-
52-
# Right now this cache is not active, it was too much complexity for the initial addition of dyad.
53-
protocol_entries = {
54-
RoborockDyadDataProtocol.STATUS: A01ProtocolCacheEntry(lambda val: RoborockDyadStateCode(val).name),
55-
RoborockDyadDataProtocol.SELF_CLEAN_MODE: A01ProtocolCacheEntry(lambda val: DyadSelfCleanMode(val).name),
56-
RoborockDyadDataProtocol.SELF_CLEAN_LEVEL: A01ProtocolCacheEntry(lambda val: DyadSelfCleanLevel(val).name),
57-
RoborockDyadDataProtocol.WARM_LEVEL: A01ProtocolCacheEntry(lambda val: DyadWarmLevel(val).name),
58-
RoborockDyadDataProtocol.CLEAN_MODE: A01ProtocolCacheEntry(lambda val: DyadCleanMode(val).name),
59-
RoborockDyadDataProtocol.SUCTION: A01ProtocolCacheEntry(lambda val: DyadSuction(val).name),
60-
RoborockDyadDataProtocol.WATER_LEVEL: A01ProtocolCacheEntry(lambda val: DyadWaterLevel(val).name),
61-
RoborockDyadDataProtocol.BRUSH_SPEED: A01ProtocolCacheEntry(lambda val: DyadBrushSpeed(val).name),
62-
RoborockDyadDataProtocol.POWER: A01ProtocolCacheEntry(lambda val: int(val)),
63-
RoborockDyadDataProtocol.AUTO_DRY: A01ProtocolCacheEntry(lambda val: bool(val)),
64-
RoborockDyadDataProtocol.MESH_LEFT: A01ProtocolCacheEntry(lambda val: int(360000 - val * 60)),
65-
RoborockDyadDataProtocol.BRUSH_LEFT: A01ProtocolCacheEntry(lambda val: int(360000 - val * 60)),
66-
RoborockDyadDataProtocol.ERROR: A01ProtocolCacheEntry(lambda val: DyadError(val).name),
67-
RoborockDyadDataProtocol.VOLUME_SET: A01ProtocolCacheEntry(lambda val: int(val)),
68-
RoborockDyadDataProtocol.STAND_LOCK_AUTO_RUN: A01ProtocolCacheEntry(lambda val: bool(val)),
69-
RoborockDyadDataProtocol.AUTO_DRY_MODE: A01ProtocolCacheEntry(lambda val: bool(val)),
70-
RoborockDyadDataProtocol.SILENT_DRY_DURATION: A01ProtocolCacheEntry(lambda val: int(val)), # in minutes
71-
RoborockDyadDataProtocol.SILENT_MODE: A01ProtocolCacheEntry(lambda val: bool(val)),
72-
RoborockDyadDataProtocol.SILENT_MODE_START_TIME: A01ProtocolCacheEntry(
73-
lambda val: time(hour=int(val / 60), minute=val % 60)
43+
DYAD_PROTOCOL_ENTRIES: dict[RoborockDyadDataProtocol, Callable] = {
44+
RoborockDyadDataProtocol.STATUS: lambda val: RoborockDyadStateCode(val).name,
45+
RoborockDyadDataProtocol.SELF_CLEAN_MODE: lambda val: DyadSelfCleanMode(val).name,
46+
RoborockDyadDataProtocol.SELF_CLEAN_LEVEL: lambda val: DyadSelfCleanLevel(val).name,
47+
RoborockDyadDataProtocol.WARM_LEVEL: lambda val: DyadWarmLevel(val).name,
48+
RoborockDyadDataProtocol.CLEAN_MODE: lambda val: DyadCleanMode(val).name,
49+
RoborockDyadDataProtocol.SUCTION: lambda val: DyadSuction(val).name,
50+
RoborockDyadDataProtocol.WATER_LEVEL: lambda val: DyadWaterLevel(val).name,
51+
RoborockDyadDataProtocol.BRUSH_SPEED: lambda val: DyadBrushSpeed(val).name,
52+
RoborockDyadDataProtocol.POWER: lambda val: int(val),
53+
RoborockDyadDataProtocol.AUTO_DRY: lambda val: bool(val),
54+
RoborockDyadDataProtocol.MESH_LEFT: lambda val: int(360000 - val * 60),
55+
RoborockDyadDataProtocol.BRUSH_LEFT: lambda val: int(360000 - val * 60),
56+
RoborockDyadDataProtocol.ERROR: lambda val: DyadError(val).name,
57+
RoborockDyadDataProtocol.VOLUME_SET: lambda val: int(val),
58+
RoborockDyadDataProtocol.STAND_LOCK_AUTO_RUN: lambda val: bool(val),
59+
RoborockDyadDataProtocol.AUTO_DRY_MODE: lambda val: bool(val),
60+
RoborockDyadDataProtocol.SILENT_DRY_DURATION: lambda val: int(val), # in minutes
61+
RoborockDyadDataProtocol.SILENT_MODE: lambda val: bool(val),
62+
RoborockDyadDataProtocol.SILENT_MODE_START_TIME: lambda val: time(
63+
hour=int(val / 60), minute=val % 60
7464
), # in minutes since 00:00
75-
RoborockDyadDataProtocol.SILENT_MODE_END_TIME: A01ProtocolCacheEntry(
76-
lambda val: time(hour=int(val / 60), minute=val % 60)
65+
RoborockDyadDataProtocol.SILENT_MODE_END_TIME: lambda val: time(
66+
hour=int(val / 60), minute=val % 60
7767
), # in minutes since 00:00
78-
RoborockDyadDataProtocol.RECENT_RUN_TIME: A01ProtocolCacheEntry(
79-
lambda val: [int(v) for v in val.split(",")]
80-
), # minutes of cleaning in past few days.
81-
RoborockDyadDataProtocol.TOTAL_RUN_TIME: A01ProtocolCacheEntry(lambda val: int(val)),
82-
RoborockDyadDataProtocol.SND_STATE: A01ProtocolCacheEntry(lambda val: DyadSndState.from_dict(val)),
83-
RoborockDyadDataProtocol.PRODUCT_INFO: A01ProtocolCacheEntry(lambda val: DyadProductInfo.from_dict(val)),
68+
RoborockDyadDataProtocol.RECENT_RUN_TIME: lambda val: [
69+
int(v) for v in val.split(",")
70+
], # minutes of cleaning in past few days.
71+
RoborockDyadDataProtocol.TOTAL_RUN_TIME: lambda val: int(val),
72+
RoborockDyadDataProtocol.SND_STATE: lambda val: DyadSndState.from_dict(val),
73+
RoborockDyadDataProtocol.PRODUCT_INFO: lambda val: DyadProductInfo.from_dict(val),
8474
}
8575

86-
zeo_data_protocol_entries = {
76+
ZEO_PROTOCOL_ENTRIES: dict[RoborockZeoProtocol, Callable] = {
8777
# ro
88-
RoborockZeoProtocol.STATE: A01ProtocolCacheEntry(lambda val: ZeoState(val).name),
89-
RoborockZeoProtocol.COUNTDOWN: A01ProtocolCacheEntry(lambda val: int(val)),
90-
RoborockZeoProtocol.WASHING_LEFT: A01ProtocolCacheEntry(lambda val: int(val)),
91-
RoborockZeoProtocol.ERROR: A01ProtocolCacheEntry(lambda val: ZeoError(val).name),
92-
RoborockZeoProtocol.TIMES_AFTER_CLEAN: A01ProtocolCacheEntry(lambda val: int(val)),
93-
RoborockZeoProtocol.DETERGENT_EMPTY: A01ProtocolCacheEntry(lambda val: bool(val)),
94-
RoborockZeoProtocol.SOFTENER_EMPTY: A01ProtocolCacheEntry(lambda val: bool(val)),
78+
RoborockZeoProtocol.STATE: lambda val: ZeoState(val).name,
79+
RoborockZeoProtocol.COUNTDOWN: lambda val: int(val),
80+
RoborockZeoProtocol.WASHING_LEFT: lambda val: int(val),
81+
RoborockZeoProtocol.ERROR: lambda val: ZeoError(val).name,
82+
RoborockZeoProtocol.TIMES_AFTER_CLEAN: lambda val: int(val),
83+
RoborockZeoProtocol.DETERGENT_EMPTY: lambda val: bool(val),
84+
RoborockZeoProtocol.SOFTENER_EMPTY: lambda val: bool(val),
9585
# rw
96-
RoborockZeoProtocol.MODE: A01ProtocolCacheEntry(lambda val: ZeoMode(val).name),
97-
RoborockZeoProtocol.PROGRAM: A01ProtocolCacheEntry(lambda val: ZeoProgram(val).name),
98-
RoborockZeoProtocol.TEMP: A01ProtocolCacheEntry(lambda val: ZeoTemperature(val).name),
99-
RoborockZeoProtocol.RINSE_TIMES: A01ProtocolCacheEntry(lambda val: ZeoRinse(val).name),
100-
RoborockZeoProtocol.SPIN_LEVEL: A01ProtocolCacheEntry(lambda val: ZeoSpin(val).name),
101-
RoborockZeoProtocol.DRYING_MODE: A01ProtocolCacheEntry(lambda val: ZeoDryingMode(val).name),
102-
RoborockZeoProtocol.DETERGENT_TYPE: A01ProtocolCacheEntry(lambda val: ZeoDetergentType(val).name),
103-
RoborockZeoProtocol.SOFTENER_TYPE: A01ProtocolCacheEntry(lambda val: ZeoSoftenerType(val).name),
104-
RoborockZeoProtocol.SOUND_SET: A01ProtocolCacheEntry(lambda val: bool(val)),
86+
RoborockZeoProtocol.MODE: lambda val: ZeoMode(val).name,
87+
RoborockZeoProtocol.PROGRAM: lambda val: ZeoProgram(val).name,
88+
RoborockZeoProtocol.TEMP: lambda val: ZeoTemperature(val).name,
89+
RoborockZeoProtocol.RINSE_TIMES: lambda val: ZeoRinse(val).name,
90+
RoborockZeoProtocol.SPIN_LEVEL: lambda val: ZeoSpin(val).name,
91+
RoborockZeoProtocol.DRYING_MODE: lambda val: ZeoDryingMode(val).name,
92+
RoborockZeoProtocol.DETERGENT_TYPE: lambda val: ZeoDetergentType(val).name,
93+
RoborockZeoProtocol.SOFTENER_TYPE: lambda val: ZeoSoftenerType(val).name,
94+
RoborockZeoProtocol.SOUND_SET: lambda val: bool(val),
10595
}
10696

10797

98+
def convert_dyad_value(protocol: int, value: Any) -> Any:
99+
"""Convert a dyad protocol value to its corresponding type."""
100+
protocol_value = RoborockDyadDataProtocol(protocol)
101+
if (converter := DYAD_PROTOCOL_ENTRIES.get(protocol_value)) is not None:
102+
return converter(value)
103+
return None
104+
105+
106+
def convert_zeo_value(protocol: int, value: Any) -> Any:
107+
"""Convert a zeo protocol value to its corresponding type."""
108+
protocol_value = RoborockZeoProtocol(protocol)
109+
if (converter := ZEO_PROTOCOL_ENTRIES.get(protocol_value)) is not None:
110+
return converter(value)
111+
return None
112+
113+
108114
class RoborockClientA01(RoborockClient, ABC):
109115
"""Roborock client base class for A01 devices."""
110116

117+
value_converter: Callable[[int, Any], Any] | None = None
118+
111119
def __init__(self, device_info: DeviceData, category: RoborockCategory):
112120
"""Initialize the Roborock client."""
113121
super().__init__(device_info)
114-
self.category = category
122+
if category == RoborockCategory.WET_DRY_VAC:
123+
self.value_converter = convert_dyad_value
124+
elif category == RoborockCategory.WASHING_MACHINE:
125+
self.value_converter = convert_zeo_value
126+
else:
127+
_LOGGER.debug("Device category %s is not (yet) supported", category)
128+
self.value_converter = None
115129

116130
def on_message_received(self, messages: list[RoborockMessage]) -> None:
131+
if self.value_converter is None:
132+
return
117133
for message in messages:
118134
protocol = message.protocol
119135
if message.payload and protocol in [
120136
RoborockMessageProtocol.RPC_RESPONSE,
121137
RoborockMessageProtocol.GENERAL_REQUEST,
122138
]:
123-
payload = message.payload
124139
try:
125-
payload = unpad(payload, AES.block_size)
126-
except Exception as err:
127-
self._logger.debug("Failed to unpad payload: %s", err)
140+
data_points = decode_rpc_response(message)
141+
except RoborockException as err:
142+
self._logger.debug("Failed to decode message: %s", err)
128143
continue
129-
payload_json = json.loads(payload.decode())
130-
for data_point_number, data_point in payload_json.get("dps").items():
131-
data_point_protocol: RoborockDyadDataProtocol | RoborockZeoProtocol
144+
for data_point_number, data_point in data_points.items():
132145
self._logger.debug("received msg with dps, protocol: %s, %s", data_point_number, protocol)
133-
entries: dict
134-
if self.category == RoborockCategory.WET_DRY_VAC:
135-
data_point_protocol = RoborockDyadDataProtocol(int(data_point_number))
136-
entries = protocol_entries
137-
elif self.category == RoborockCategory.WASHING_MACHINE:
138-
data_point_protocol = RoborockZeoProtocol(int(data_point_number))
139-
entries = zeo_data_protocol_entries
140-
else:
141-
continue
142-
if data_point_protocol in entries:
143-
# Auto convert into data struct we want.
144-
converted_response = entries[data_point_protocol].post_process_fn(data_point)
146+
if converted_response := self.value_converter(data_point_number, data_point):
145147
queue = self._waiting_queue.get(int(data_point_number))
146148
if queue and queue.protocol == protocol:
147149
queue.set_result(converted_response)
150+
else:
151+
self._logger.debug(
152+
"Received unknown data point %s for protocol %s, ignoring", data_point_number, protocol
153+
)
148154

149155
@abstractmethod
150156
async def update_values(
151157
self, dyad_data_protocols: list[RoborockDyadDataProtocol | RoborockZeoProtocol]
152-
) -> dict[RoborockDyadDataProtocol | RoborockZeoProtocol, typing.Any]:
158+
) -> dict[RoborockDyadDataProtocol | RoborockZeoProtocol, Any]:
153159
"""This should handle updating for each given protocol."""

roborock/version_a01_apis/roborock_mqtt_client_a01.py

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
import typing
55

66
from Crypto.Cipher import AES
7-
from Crypto.Util.Padding import pad, unpad
7+
from Crypto.Util.Padding import unpad
88

99
from roborock.cloud_api import RoborockMqttClient
1010
from roborock.containers import DeviceData, RoborockCategory, UserData
1111
from roborock.exceptions import RoborockException
12+
from roborock.protocols.a01_protocol import encode_mqtt_payload
1213
from roborock.roborock_message import (
1314
RoborockDyadDataProtocol,
1415
RoborockMessage,
@@ -43,7 +44,6 @@ async def send_message(self, roborock_message: RoborockMessage):
4344
response_protocol = RoborockMessageProtocol.RPC_RESPONSE
4445

4546
m = self._encoder(roborock_message)
46-
# self._logger.debug(f"id={request_id} Requesting method {method} with {params}")
4747
payload = json.loads(unpad(roborock_message.payload, AES.block_size))
4848
futures = []
4949
if "10000" in payload["dps"]:
@@ -56,7 +56,6 @@ async def send_message(self, roborock_message: RoborockMessage):
5656
for i, dps in enumerate(json.loads(payload["dps"]["10000"])):
5757
response = responses[i]
5858
if isinstance(response, BaseException):
59-
self._logger.warning("Timed out get req for %s after %s s", dps, self.queue_timeout)
6059
dps_responses[dps] = None
6160
else:
6261
dps_responses[dps] = response
@@ -65,24 +64,14 @@ async def send_message(self, roborock_message: RoborockMessage):
6564
async def update_values(
6665
self, dyad_data_protocols: list[RoborockDyadDataProtocol | RoborockZeoProtocol]
6766
) -> dict[RoborockDyadDataProtocol | RoborockZeoProtocol, typing.Any]:
68-
payload = {"dps": {RoborockDyadDataProtocol.ID_QUERY: str([int(protocol) for protocol in dyad_data_protocols])}}
69-
return await self.send_message(
70-
RoborockMessage(
71-
protocol=RoborockMessageProtocol.RPC_REQUEST,
72-
version=b"A01",
73-
payload=pad(json.dumps(payload).encode("utf-8"), AES.block_size),
74-
)
67+
message = encode_mqtt_payload(
68+
{RoborockDyadDataProtocol.ID_QUERY: str([int(protocol) for protocol in dyad_data_protocols])}
7569
)
70+
return await self.send_message(message)
7671

7772
async def set_value(
7873
self, protocol: RoborockDyadDataProtocol | RoborockZeoProtocol, value: typing.Any
7974
) -> dict[int, typing.Any]:
8075
"""Set a value for a specific protocol on the A01 device."""
81-
payload = {"dps": {int(protocol): value}}
82-
return await self.send_message(
83-
RoborockMessage(
84-
protocol=RoborockMessageProtocol.RPC_REQUEST,
85-
version=b"A01",
86-
payload=pad(json.dumps(payload).encode("utf-8"), AES.block_size),
87-
)
88-
)
76+
message = encode_mqtt_payload({protocol: value})
77+
return await self.send_message(message)

tests/protocols/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Tests for the protocols package."""

0 commit comments

Comments
 (0)