Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions roborock/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
from __future__ import annotations

import asyncio
import base64
import logging
import secrets
import time
from abc import ABC, abstractmethod
from typing import Any
Expand Down Expand Up @@ -37,14 +35,11 @@ class RoborockClient(ABC):
def __init__(self, device_info: DeviceData) -> None:
"""Initialize RoborockClient."""
self.device_info = device_info
self._nonce = secrets.token_bytes(16)
self._waiting_queue: dict[int, RoborockFuture] = {}
self._last_device_msg_in = time.monotonic()
self._last_disconnection = time.monotonic()
self.keep_alive = KEEPALIVE
self._diagnostic_data: dict[str, dict[str, Any]] = {
"misc_info": {"Nonce": base64.b64encode(self._nonce).decode("utf-8")}
}
self._diagnostic_data: dict[str, dict[str, Any]] = {}
self.is_available: bool = True

async def async_release(self) -> None:
Expand Down
20 changes: 20 additions & 0 deletions roborock/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,26 @@ def decrypt_ecb(ciphertext: bytes, token: bytes) -> bytes:
return unpad(decipher.decrypt(ciphertext), AES.block_size)
return ciphertext

@staticmethod
def encrypt_cbc(plaintext: bytes, token: bytes) -> bytes:
"""Encrypt plaintext with a given token using cbc mode.

This is currently used for testing purposes only.

:param bytes plaintext: Plaintext (json) to encrypt
:param bytes token: Token to use
:return: Encrypted bytes
"""
if not isinstance(plaintext, bytes):
raise TypeError("plaintext requires bytes")
Utils.verify_token(token)
iv = bytes(AES.block_size)
cipher = AES.new(token, AES.MODE_CBC, iv)
if plaintext:
plaintext = pad(plaintext, AES.block_size)
return cipher.encrypt(plaintext)
return plaintext

@staticmethod
def decrypt_cbc(ciphertext: bytes, token: bytes) -> bytes:
"""Decrypt ciphertext with a given token using cbc mode.
Expand Down
39 changes: 39 additions & 0 deletions roborock/protocols/v1_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging
import math
import secrets
import struct
import time
from collections.abc import Callable
from dataclasses import dataclass, field
Expand Down Expand Up @@ -44,6 +45,10 @@ def to_dict(self) -> dict[str, Any]:
"""Convert security data to a dictionary for sending in the payload."""
return {"security": {"endpoint": self.endpoint, "nonce": self.nonce.hex().lower()}}

def to_diagnostic_data(self) -> dict[str, Any]:
"""Convert security data to a dictionary for debugging purposes."""
return {"nonce": self.nonce.hex().lower()}


def create_security_data(rriot: RRiot) -> SecurityData:
"""Create a SecurityData instance for the given endpoint and nonce."""
Expand Down Expand Up @@ -142,3 +147,37 @@ def decode_rpc_response(message: RoborockMessage) -> dict[str, Any]:
if not isinstance(result, dict):
raise RoborockException(f"Invalid V1 message format: 'result' should be a dictionary for {message.payload!r}")
return result


@dataclass
class MapResponse:
"""Data structure for the V1 Map response."""

request_id: int
"""The request ID of the map response."""

data: bytes
"""The map data, decrypted and decompressed."""


def create_map_response_decoder(security_data: SecurityData) -> Callable[[RoborockMessage], MapResponse]:
"""Create a decoder for V1 map response messages."""

def _decode_map_response(message: RoborockMessage) -> MapResponse:
"""Decode a V1 map response message."""
if not message.payload or len(message.payload) < 24:
raise RoborockException("Invalid V1 map response format: missing payload")
header, body = message.payload[:24], message.payload[24:]
[endpoint, _, request_id, _] = struct.unpack("<8s8sH6s", header)
if not endpoint.decode().startswith(security_data.endpoint):
raise RoborockException(
f"Invalid V1 map response endpoint: {endpoint!r}, expected {security_data.endpoint!r}"
)
try:
decrypted = Utils.decrypt_cbc(body, security_data.nonce)
except ValueError as err:
raise RoborockException("Failed to decode map message payload") from err
decompressed = Utils.decompress(decrypted)
return MapResponse(request_id=request_id, data=decompressed)

return _decode_map_response
31 changes: 14 additions & 17 deletions roborock/version_1_apis/roborock_client_v1.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import dataclasses
import json
import struct
import time
from abc import ABC, abstractmethod
from collections.abc import Callable, Coroutine
Expand Down Expand Up @@ -45,7 +44,7 @@
ValleyElectricityTimer,
WashTowelMode,
)
from roborock.protocol import Utils
from roborock.protocols.v1_protocol import MapResponse, SecurityData, create_map_response_decoder
from roborock.roborock_message import (
ROBOROCK_DATA_CONSUMABLE_PROTOCOL,
ROBOROCK_DATA_STATUS_PROTOCOL,
Expand Down Expand Up @@ -150,10 +149,15 @@ class RoborockClientV1(RoborockClient, ABC):
"""Roborock client base class for version 1 devices."""

_listeners: dict[str, ListenerModel] = {}
_map_response_decoder: Callable[[RoborockMessage], MapResponse] | None = None

def __init__(self, device_info: DeviceData, endpoint: str):
def __init__(self, device_info: DeviceData, security_data: SecurityData | None) -> None:
"""Initializes the Roborock client."""
super().__init__(device_info)
if security_data is not None:
self._diagnostic_data.update({"misc_info": security_data.to_diagnostic_data()})
self._map_response_decoder = create_map_response_decoder(security_data)

self._status_type: type[Status] = ModelStatus.get(device_info.model, S7MaxVStatus)
self.cache: dict[CacheableAttribute, AttributeCache] = {
cacheable_attribute: AttributeCache(attr, self._send_command)
Expand All @@ -162,7 +166,6 @@ def __init__(self, device_info: DeviceData, endpoint: str):
if device_info.device.duid not in self._listeners:
self._listeners[device_info.device.duid] = ListenerModel({}, self.cache)
self.listener_model = self._listeners[device_info.device.duid]
self._endpoint = endpoint

async def async_release(self) -> None:
await super().async_release()
Expand Down Expand Up @@ -429,21 +432,15 @@ def on_message_received(self, messages: list[RoborockMessage]) -> None:
dps = {data_point_number: data_point}
self._logger.debug(f"Got unknown data point {dps}")
elif data.payload and protocol == RoborockMessageProtocol.MAP_RESPONSE:
payload = data.payload[0:24]
[endpoint, _, request_id, _] = struct.unpack("<8s8sH6s", payload)
if endpoint.decode().startswith(self._endpoint):
try:
decrypted = Utils.decrypt_cbc(data.payload[24:], self._nonce)
except ValueError as err:
raise RoborockException(f"Failed to decode {data.payload!r} for {data.protocol}") from err
decompressed = Utils.decompress(decrypted)
queue = self._waiting_queue.get(request_id)
if self._map_response_decoder is not None:
map_response = self._map_response_decoder(data)
queue = self._waiting_queue.get(map_response.request_id)
if queue:
if isinstance(decompressed, list):
decompressed = decompressed[0]
queue.set_result(decompressed)
queue.set_result(map_response.data)
else:
self._logger.debug("Received response for unknown request id %s", request_id)
self._logger.debug(
"Received unsolicited map response for request_id %s", map_response.request_id
)
else:
queue = self._waiting_queue.get(data.seq)
if queue:
Expand Down
2 changes: 1 addition & 1 deletion roborock/version_1_apis/roborock_local_client_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(self, device_data: DeviceData, queue_timeout: int = 4):
self.transport: Transport | None = None
self._mutex = Lock()
self.keep_alive_task: TimerHandle | None = None
RoborockClientV1.__init__(self, device_data, "abc")
RoborockClientV1.__init__(self, device_data, security_data=None)
RoborockClient.__init__(self, device_data)
self._local_protocol = _LocalProtocol(self._data_received, self._connection_lost)
self._encoder: Encoder = create_local_encoder(device_data.device.local_key)
Expand Down
13 changes: 4 additions & 9 deletions roborock/version_1_apis/roborock_mqtt_client_v1.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import base64
import logging

from vacuum_map_parser_base.config.color import ColorsPalette
Expand All @@ -10,8 +9,7 @@

from ..containers import DeviceData, UserData
from ..exceptions import CommandVacuumError, RoborockException, VacuumError
from ..protocol import Utils
from ..protocols.v1_protocol import SecurityData, create_mqtt_payload_encoder
from ..protocols.v1_protocol import create_mqtt_payload_encoder, create_security_data
from ..roborock_message import (
RoborockMessageProtocol,
)
Expand All @@ -30,15 +28,12 @@ def __init__(self, user_data: UserData, device_info: DeviceData, queue_timeout:
rriot = user_data.rriot
if rriot is None:
raise RoborockException("Got no rriot data from user_data")
endpoint = base64.b64encode(Utils.md5(rriot.k.encode())[8:14]).decode()

security_data = create_security_data(rriot)
RoborockMqttClient.__init__(self, user_data, device_info)
RoborockClientV1.__init__(self, device_info, endpoint)
RoborockClientV1.__init__(self, device_info, security_data=security_data)
self.queue_timeout = queue_timeout
self._logger = RoborockLoggerAdapter(device_info.device.name, _LOGGER)
self._payload_encoder = create_mqtt_payload_encoder(
SecurityData(endpoint=self._endpoint, nonce=self._nonce),
)
self._payload_encoder = create_mqtt_payload_encoder(security_data)

async def _send_command(
self,
Expand Down
2 changes: 1 addition & 1 deletion tests/mock_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
BASE_URL = "https://usiot.roborock.com"

USER_ID = "user123"
K_VALUE = "domain123"
K_VALUE = "qiCNieZa"
USER_DATA = {
"uid": 123456,
"tokentype": "token_type",
Expand Down
79 changes: 77 additions & 2 deletions tests/protocols/test_v1_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@
from freezegun import freeze_time

from roborock.containers import RoborockBase, UserData
from roborock.exceptions import RoborockException
from roborock.protocol import Utils
from roborock.protocols.v1_protocol import (
SecurityData,
create_map_response_decoder,
create_mqtt_payload_encoder,
decode_rpc_response,
encode_local_payload,
Expand All @@ -20,7 +23,12 @@

USER_DATA = UserData.from_dict(mock_data.USER_DATA)
TEST_REQUEST_ID = 44444
SECURITY_DATA = SecurityData(endpoint="3PBTIjvc", nonce=b"fake-nonce")
TEST_ENDPOINT = "87ItGWdb"
TEST_ENDPOINT_BYTES = TEST_ENDPOINT.encode()
SECURITY_DATA = SecurityData(
endpoint=TEST_ENDPOINT,
nonce=b"\x91\xbe\x10\xc9b+\x9d\x8a\xcdH*\x19\xf6\xfe\x81h",
)


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -62,7 +70,7 @@ def test_encode_local_payload(command, params, expected):
(
RoborockCommand.GET_STATUS,
None,
b'{"dps":{"101":"{\\"id\\":44444,\\"method\\":\\"get_status\\",\\"params\\":[],\\"security\\":{\\"endpoint\\":\\"3PBTIjvc\\",\\"nonce\\":\\"66616b652d6e6f6e6365\\"}}"},"t":1737374400}',
b'{"dps":{"101":"{\\"id\\":44444,\\"method\\":\\"get_status\\",\\"params\\":[],\\"security\\":{\\"endpoint\\":\\"87ItGWdb\\",\\"nonce\\":\\"91be10c9622b9d8acd482a19f6fe8168\\"}}"},"t":1737374400}',
)
],
)
Expand Down Expand Up @@ -122,3 +130,70 @@ def test_decode_rpc_response(payload: bytes, expected: RoborockBase) -> None:
)
decoded_message = decode_rpc_response(message)
assert decoded_message == expected


def test_create_map_response_decoder():
"""Test creating and using a map response decoder."""
test_data = b"some map\n"
compressed_data = (
b"\x1f\x8b\x08\x08\xf9\x13\x99h\x00\x03foo\x00+\xce\xcfMU\xc8M,\xe0\x02\x00@\xdb\xc6\x1a\t\x00\x00\x00"
)

# Create header: endpoint(8) + padding(8) + request_id(2) + padding(6)
# request_id = 44508 (0xaddc in little endian)
header = TEST_ENDPOINT_BYTES + b"\x00" * 8 + b"\xdc\xad" + b"\x00" * 6
encrypted_data = Utils.encrypt_cbc(compressed_data, SECURITY_DATA.nonce)
payload = header + encrypted_data

message = RoborockMessage(
protocol=RoborockMessageProtocol.MAP_RESPONSE,
payload=payload,
seq=12750,
version=b"1.0",
random=97431,
timestamp=1652547161,
)

decoder = create_map_response_decoder(SECURITY_DATA)
result = decoder(message)

assert result.request_id == 44508
assert result.data == test_data


def test_create_map_response_decoder_invalid_endpoint():
"""Test map response decoder with invalid endpoint."""
# Create header with wrong endpoint
header = b"wrongend" + b"\x00" * 8 + b"\xdc\xad" + b"\x00" * 6
payload = header + b"encrypted_data"

message = RoborockMessage(
protocol=RoborockMessageProtocol.MAP_RESPONSE,
payload=payload,
seq=12750,
version=b"1.0",
random=97431,
timestamp=1652547161,
)

decoder = create_map_response_decoder(SECURITY_DATA)

with pytest.raises(RoborockException, match="Invalid V1 map response endpoint"):
decoder(message)


def test_create_map_response_decoder_invalid_payload():
"""Test map response decoder with invalid payload."""
message = RoborockMessage(
protocol=RoborockMessageProtocol.MAP_RESPONSE,
payload=b"short", # Too short payload
seq=12750,
version=b"1.0",
random=97431,
timestamp=1652547161,
)

decoder = create_map_response_decoder(SECURITY_DATA)

with pytest.raises(RoborockException, match="Invalid V1 map response format: missing payload"):
decoder(message)
3 changes: 2 additions & 1 deletion tests/test_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
CONSUMABLE,
DND_TIMER,
HOME_DATA_RAW,
K_VALUE,
LOCAL_KEY,
PRODUCT_ID,
STATUS,
Expand Down Expand Up @@ -130,7 +131,7 @@ def test_user_data():
assert ud.rriot.u == "user123"
assert ud.rriot.s == "pass123"
assert ud.rriot.h == "unknown123"
assert ud.rriot.k == "domain123"
assert ud.rriot.k == K_VALUE
assert ud.rriot.r.r == "US"
assert ud.rriot.r.a == "https://api-us.roborock.com"
assert ud.rriot.r.m == "tcp://mqtt-us.roborock.com:8883"
Expand Down