Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion roborock/devices/a01_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import logging
from collections.abc import Callable
from typing import Any, overload

from roborock.exceptions import RoborockException
Expand Down Expand Up @@ -29,23 +30,26 @@
async def send_decoded_command(
mqtt_channel: MqttChannel,
params: dict[RoborockDyadDataProtocol, Any],
value_encoder: Callable[[Any], Any] | None = None,
) -> dict[RoborockDyadDataProtocol, Any]: ...


@overload
async def send_decoded_command(
mqtt_channel: MqttChannel,
params: dict[RoborockZeoProtocol, Any],
value_encoder: Callable[[Any], Any] | None = None,
) -> dict[RoborockZeoProtocol, Any]: ...


async def send_decoded_command(
mqtt_channel: MqttChannel,
params: dict[RoborockDyadDataProtocol, Any] | dict[RoborockZeoProtocol, Any],
value_encoder: Callable[[Any], Any] | None = None,
) -> dict[RoborockDyadDataProtocol, Any] | dict[RoborockZeoProtocol, Any]:
"""Send a command on the MQTT channel and get a decoded response."""
_LOGGER.debug("Sending MQTT command: %s", params)
roborock_message = encode_mqtt_payload(params)
roborock_message = encode_mqtt_payload(params, value_encoder)

# For commands that set values: send the command and do not
# block waiting for a response. Queries are handled below.
Expand Down
17 changes: 12 additions & 5 deletions roborock/devices/traits/a01/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from collections.abc import Callable
from datetime import time
from typing import Any
Expand Down Expand Up @@ -121,8 +122,11 @@ def __init__(self, channel: MqttChannel) -> None:

async def query_values(self, protocols: list[RoborockDyadDataProtocol]) -> dict[RoborockDyadDataProtocol, Any]:
"""Query the device for the values of the given Dyad protocols."""
params = {RoborockDyadDataProtocol.ID_QUERY: str([int(p) for p in protocols])}
response = await send_decoded_command(self._channel, params)
response = await send_decoded_command(
self._channel,
{RoborockDyadDataProtocol.ID_QUERY: protocols},
value_encoder=json.dumps,
)
return {protocol: convert_dyad_value(protocol, response.get(protocol)) for protocol in protocols}

async def set_value(self, protocol: RoborockDyadDataProtocol, value: Any) -> dict[RoborockDyadDataProtocol, Any]:
Expand All @@ -142,14 +146,17 @@ def __init__(self, channel: MqttChannel) -> None:

async def query_values(self, protocols: list[RoborockZeoProtocol]) -> dict[RoborockZeoProtocol, Any]:
"""Query the device for the values of the given protocols."""
params = {RoborockZeoProtocol.ID_QUERY: str([int(p) for p in protocols])}
response = await send_decoded_command(self._channel, params)
response = await send_decoded_command(
self._channel,
{RoborockZeoProtocol.ID_QUERY: protocols},
value_encoder=json.dumps,
)
return {protocol: convert_zeo_value(protocol, response.get(protocol)) for protocol in protocols}

async def set_value(self, protocol: RoborockZeoProtocol, value: Any) -> dict[RoborockZeoProtocol, Any]:
"""Set a value for a specific protocol on the device."""
params = {protocol: value}
return await send_decoded_command(self._channel, params)
return await send_decoded_command(self._channel, params, value_encoder=lambda x: x)


def create(product: HomeDataProduct, mqtt_channel: MqttChannel) -> DyadApi | ZeoApi:
Expand Down
20 changes: 18 additions & 2 deletions roborock/protocols/a01_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import json
import logging
from collections.abc import Callable
from typing import Any

from Crypto.Cipher import AES
Expand All @@ -20,13 +21,28 @@
A01_VERSION = b"A01"


def _no_encode(value: Any) -> Any:
return value


def encode_mqtt_payload(
data: dict[RoborockDyadDataProtocol, Any]
| dict[RoborockZeoProtocol, Any]
| dict[RoborockDyadDataProtocol | RoborockZeoProtocol, Any],
value_encoder: Callable[[Any], Any] | None = None,
) -> RoborockMessage:
"""Encode payload for A01 commands over MQTT."""
dps_data = {"dps": data}
"""Encode payload for A01 commands over MQTT.

Args:
data: The data to encode.
value_encoder: A function to encode the values of the dictionary.

Returns:
RoborockMessage: The encoded message.
"""
if value_encoder is None:
value_encoder = _no_encode
dps_data = {"dps": {key: value_encoder(value) for key, value in data.items()}}
payload = pad(json.dumps(dps_data).encode("utf-8"), AES.block_size)
return RoborockMessage(
protocol=RoborockMessageProtocol.RPC_REQUEST,
Expand Down
3 changes: 2 additions & 1 deletion tests/devices/test_a01_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ async def test_id_query(mock_mqtt_channel: FakeChannel):
{
RoborockDyadDataProtocol.WARM_LEVEL: 101,
RoborockDyadDataProtocol.POWER: 75,
}
},
value_encoder=lambda x: x,
)
response_message = RoborockMessage(
protocol=RoborockMessageProtocol.RPC_RESPONSE, payload=encoded.payload, version=encoded.version
Expand Down
169 changes: 102 additions & 67 deletions tests/devices/traits/a01/test_init.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,51 @@
import datetime
from collections.abc import Generator
import json
from typing import Any
from unittest.mock import AsyncMock, call, patch

import pytest
from Crypto.Cipher import AES
from Crypto.Util.Padding import unpad

from roborock.devices.mqtt_channel import MqttChannel
from roborock.devices.traits.a01 import DyadApi, ZeoApi
from roborock.roborock_message import RoborockDyadDataProtocol, RoborockZeoProtocol
from roborock.roborock_message import RoborockDyadDataProtocol, RoborockMessageProtocol, RoborockZeoProtocol
from tests.conftest import FakeChannel
from tests.protocols.common import build_a01_message


@pytest.fixture(name="mock_channel")
def mock_channel_fixture() -> AsyncMock:
return AsyncMock(spec=MqttChannel)
@pytest.fixture(name="fake_channel")
def fake_channel_fixture() -> FakeChannel:
return FakeChannel()


@pytest.fixture(name="mock_send")
def mock_send_fixture(mock_channel) -> Generator[AsyncMock, None, None]:
with patch("roborock.devices.traits.a01.send_decoded_command") as mock_send:
yield mock_send
@pytest.fixture(name="dyad_api")
def dyad_api_fixture(fake_channel: FakeChannel) -> DyadApi:
return DyadApi(fake_channel) # type: ignore[arg-type]


async def test_dyad_api_query_values(mock_channel: AsyncMock, mock_send: AsyncMock):
@pytest.fixture(name="zeo_api")
def zeo_api_fixture(fake_channel: FakeChannel) -> ZeoApi:
return ZeoApi(fake_channel) # type: ignore[arg-type]


async def test_dyad_api_query_values(dyad_api: DyadApi, fake_channel: FakeChannel):
"""Test that DyadApi currently returns raw values without conversion."""
api = DyadApi(mock_channel)

mock_send.return_value = {
209: 1, # POWER
201: 6, # STATUS
207: 3, # WATER_LEVEL
214: 120, # MESH_LEFT
215: 90, # BRUSH_LEFT
227: 85, # SILENT_MODE_START_TIME
229: "3,4,5", # RECENT_RUN_TIME
230: 123456, # TOTAL_RUN_TIME
222: 1, # STAND_LOCK_AUTO_RUN
224: 0, # AUTO_DRY_MODE
}
result = await api.query_values(
fake_channel.response_queue.append(
build_a01_message(
{
209: 1, # POWER
201: 6, # STATUS
207: 3, # WATER_LEVEL
214: 120, # MESH_LEFT
215: 90, # BRUSH_LEFT
227: 85, # SILENT_MODE_START_TIME
229: "3,4,5", # RECENT_RUN_TIME
230: 123456, # TOTAL_RUN_TIME
222: 1, # STAND_LOCK_AUTO_RUN
224: 0, # AUTO_DRY_MODE
}
)
)
result = await dyad_api.query_values(
[
RoborockDyadDataProtocol.POWER,
RoborockDyadDataProtocol.STATUS,
Expand All @@ -64,15 +72,12 @@ async def test_dyad_api_query_values(mock_channel: AsyncMock, mock_send: AsyncMo
RoborockDyadDataProtocol.AUTO_DRY_MODE: False,
}

# Note: Bug here, this is the wrong encoding for the query
assert mock_send.call_args_list == [
call(
mock_channel,
{
RoborockDyadDataProtocol.ID_QUERY: "[209, 201, 207, 214, 215, 227, 229, 230, 222, 224]",
},
),
]
assert len(fake_channel.published_messages) == 1
message = fake_channel.published_messages[0]
assert message.protocol == RoborockMessageProtocol.RPC_REQUEST
assert message.version == b"A01"
payload_data = json.loads(unpad(message.payload, AES.block_size))
assert payload_data == {"dps": {"10000": "[209, 201, 207, 214, 215, 227, 229, 230, 222, 224]"}}


@pytest.mark.parametrize(
Expand Down Expand Up @@ -117,33 +122,34 @@ async def test_dyad_api_query_values(mock_channel: AsyncMock, mock_send: AsyncMo
],
)
async def test_dyad_invalid_response_value(
mock_channel: AsyncMock,
mock_send: AsyncMock,
query: list[RoborockDyadDataProtocol],
response: dict[int, Any],
expected_result: dict[RoborockDyadDataProtocol, Any],
dyad_api: DyadApi,
fake_channel: FakeChannel,
):
"""Test that DyadApi currently returns raw values without conversion."""
api = DyadApi(mock_channel)
fake_channel.response_queue.append(build_a01_message(response))

mock_send.return_value = response
result = await api.query_values(query)
result = await dyad_api.query_values(query)
assert result == expected_result


async def test_zeo_api_query_values(mock_channel: AsyncMock, mock_send: AsyncMock):
async def test_zeo_api_query_values(zeo_api: ZeoApi, fake_channel: FakeChannel):
"""Test that ZeoApi currently returns raw values without conversion."""
api = ZeoApi(mock_channel)

mock_send.return_value = {
203: 6, # spinning
207: 3, # medium
226: 1,
227: 0,
224: 1, # Times after clean. Testing int value
218: 0, # Washing left. Testing zero int value
}
result = await api.query_values(
fake_channel.response_queue.append(
build_a01_message(
{
203: 6, # spinning
207: 3, # medium
226: 1,
227: 0,
224: 1, # Times after clean. Testing int value
218: 0, # Washing left. Testing zero int value
}
)
)
result = await zeo_api.query_values(
[
RoborockZeoProtocol.STATE,
RoborockZeoProtocol.TEMP,
Expand All @@ -162,15 +168,13 @@ async def test_zeo_api_query_values(mock_channel: AsyncMock, mock_send: AsyncMoc
RoborockZeoProtocol.TIMES_AFTER_CLEAN: 1,
RoborockZeoProtocol.WASHING_LEFT: 0,
}
# Note: Bug here, this is the wrong encoding for the query
assert mock_send.call_args_list == [
call(
mock_channel,
{
RoborockZeoProtocol.ID_QUERY: "[203, 207, 226, 227, 224, 218]",
},
),
]

assert len(fake_channel.published_messages) == 1
message = fake_channel.published_messages[0]
assert message.protocol == RoborockMessageProtocol.RPC_REQUEST
assert message.version == b"A01"
payload_data = json.loads(unpad(message.payload, AES.block_size))
assert payload_data == {"dps": {"10000": "[203, 207, 226, 227, 224, 218]"}}


@pytest.mark.parametrize(
Expand Down Expand Up @@ -215,15 +219,46 @@ async def test_zeo_api_query_values(mock_channel: AsyncMock, mock_send: AsyncMoc
],
)
async def test_zeo_invalid_response_value(
mock_channel: AsyncMock,
mock_send: AsyncMock,
query: list[RoborockZeoProtocol],
response: dict[int, Any],
expected_result: dict[RoborockZeoProtocol, Any],
zeo_api: ZeoApi,
fake_channel: FakeChannel,
):
"""Test that ZeoApi currently returns raw values without conversion."""
api = ZeoApi(mock_channel)
fake_channel.response_queue.append(build_a01_message(response))

mock_send.return_value = response
result = await api.query_values(query)
result = await zeo_api.query_values(query)
assert result == expected_result


async def test_dyad_api_set_value(dyad_api: DyadApi, fake_channel: FakeChannel):
"""Test DyadApi set_value sends correct command."""
await dyad_api.set_value(RoborockDyadDataProtocol.POWER, 1)

assert len(fake_channel.published_messages) == 1
message = fake_channel.published_messages[0]

assert message.protocol == RoborockMessageProtocol.RPC_REQUEST
assert message.version == b"A01"

# decode the payload to verify contents
payload_data = json.loads(unpad(message.payload, AES.block_size))
# A01 protocol expects values to be strings in the dps dict
assert payload_data == {"dps": {"209": 1}}


async def test_zeo_api_set_value(zeo_api: ZeoApi, fake_channel: FakeChannel):
"""Test ZeoApi set_value sends correct command."""
await zeo_api.set_value(RoborockZeoProtocol.MODE, "standard")

assert len(fake_channel.published_messages) == 1
message = fake_channel.published_messages[0]

assert message.protocol == RoborockMessageProtocol.RPC_REQUEST
assert message.version == b"A01"

# decode the payload to verify contents
payload_data = json.loads(unpad(message.payload, AES.block_size))
# A01 protocol expects values to be strings in the dps dict
assert payload_data == {"dps": {"204": "standard"}}
26 changes: 26 additions & 0 deletions tests/protocols/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""Common test utils for the protocols package."""

import json
from typing import Any

from Crypto.Cipher import AES
from Crypto.Util.Padding import pad

from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol


def build_a01_message(message: dict[Any, Any], seq: int = 2020) -> RoborockMessage:
"""Build an encoded A01 RPC response message."""
return RoborockMessage(
protocol=RoborockMessageProtocol.RPC_RESPONSE,
payload=pad(
json.dumps(
{
"dps": message, # {10000: json.dumps(message)},
}
).encode(),
AES.block_size,
),
version=b"A01",
seq=seq,
)
Loading
Loading