Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 53 additions & 7 deletions roborock/broadcast_protocol.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
from __future__ import annotations

import asyncio
import hashlib
import json
import logging
from asyncio import BaseTransport, Lock

from construct import ( # type: ignore
Bytes,
Checksum,
GreedyBytes,
Int16ub,
Int32ub,
Prefixed,
RawCopy,
Struct,
)
from Crypto.Cipher import AES

from roborock import RoborockException
from roborock.containers import BroadcastMessage
from roborock.protocol import EncryptionAdapter, Utils, _Parser

Expand All @@ -29,14 +34,41 @@ def __init__(self, timeout: int = 5):
self.devices_found: list[BroadcastMessage] = []
self._mutex = Lock()

def datagram_received(self, data, _):
[broadcast_message], _ = BroadcastParser.parse(data)
if broadcast_message.payload:
parsed_message = BroadcastMessage.from_dict(json.loads(broadcast_message.payload))
_LOGGER.debug(f"Received broadcast: {parsed_message}")
self.devices_found.append(parsed_message)
def datagram_received(self, data: bytes, _):
"""Handle incoming broadcast datagrams."""
try:
version = data[:3]
if version == b"L01":
[parsed_msg], _ = L01Parser.parse(data)
encrypted_payload = parsed_msg.payload
if encrypted_payload is None:
raise RoborockException("No encrypted payload found in broadcast message")
ciphertext = encrypted_payload[:-16]
tag = encrypted_payload[-16:]

async def discover(self):
key = hashlib.sha256(BROADCAST_TOKEN).digest()
iv_digest_input = data[:9]
digest = hashlib.sha256(iv_digest_input).digest()
iv = digest[:12]

cipher = AES.new(key, AES.MODE_GCM, nonce=iv)
decrypted_payload_bytes = cipher.decrypt_and_verify(ciphertext, tag)
json_payload = json.loads(decrypted_payload_bytes)
parsed_message = BroadcastMessage(duid=json_payload["duid"], ip=json_payload["ip"], version=version)
_LOGGER.debug(f"Received L01 broadcast: {parsed_message}")
self.devices_found.append(parsed_message)
else:
# Fallback to the original protocol parser for other versions
[broadcast_message], _ = BroadcastParser.parse(data)
if broadcast_message.payload:
json_payload = json.loads(broadcast_message.payload)
parsed_message = BroadcastMessage(duid=json_payload["duid"], ip=json_payload["ip"], version=version)
_LOGGER.debug(f"Received broadcast: {parsed_message}")
self.devices_found.append(parsed_message)
except Exception as e:
_LOGGER.warning(f"Failed to decode message: {data!r}. Error: {e}")

async def discover(self) -> list[BroadcastMessage]:
async with self._mutex:
try:
loop = asyncio.get_event_loop()
Expand Down Expand Up @@ -64,5 +96,19 @@ def close(self):
"checksum" / Checksum(Int32ub, Utils.crc, lambda ctx: ctx.message.data),
)

_L01BroadcastMessage = Struct(
"message"
/ RawCopy(
Struct(
"version" / Bytes(3),
"field1" / Bytes(4), # Unknown field
"field2" / Bytes(2), # Unknown field
"payload" / Prefixed(Int16ub, GreedyBytes), # Encrypted payload with length prefix
)
),
"checksum" / Checksum(Int32ub, Utils.crc, lambda ctx: ctx.message.data),
)


BroadcastParser: _Parser = _Parser(_BroadcastMessage, False)
L01Parser: _Parser = _Parser(_L01BroadcastMessage, False)
1 change: 1 addition & 0 deletions roborock/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,7 @@ class FlowLedStatus(RoborockBase):
class BroadcastMessage(RoborockBase):
duid: str
ip: str
version: bytes


class ServerTimer(NamedTuple):
Expand Down
4 changes: 2 additions & 2 deletions roborock/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,10 +304,10 @@ def parse(self, data: bytes, local_key: str | None = None) -> tuple[list[Roboroc
messages.append(
RoborockMessage(
version=message.message.value.version,
seq=message.message.value.seq,
seq=message.message.value.get("seq"),
random=message.message.value.get("random"),
timestamp=message.message.value.get("timestamp"),
protocol=message.message.value.protocol,
protocol=message.message.value.get("protocol"),
payload=message.message.value.payload,
)
)
Expand Down
27 changes: 27 additions & 0 deletions tests/test_broadcast_protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from roborock.broadcast_protocol import RoborockProtocol


def test_l01_data():
data = bytes.fromhex(
"4c30310000000000000043841496d5a31e34b5b02c1867c445509ba5a21aec1fa4b307bddeb27a75d9b366193e8a97d0534dc39851c"
"980609f2670cdcaee04594ec5c93e3c5ae609b0c9a203139ac8e40c8c"
)
prot = RoborockProtocol()
prot.datagram_received(data, None)
device = prot.devices_found[0]
assert device.duid == "ZrQn1jfZtJQLoPOL7620e"
assert device.ip == "192.168.1.4"
assert device.version == b"L01"


def test_v1_data():
data = bytes.fromhex(
"312e30000003e003e80040b87035058b439f36af42f249605f8661897173f111bb849a6231831f5874a0cf220a25872ea412d796b4902ee"
"57fdc120074b901b482acb1fe6d06317e3a72ddac654fe0"
)
prot = RoborockProtocol()
prot.datagram_received(data, None)
device = prot.devices_found[0]
assert device.duid == "h96rOV3e8DTPMAOLiypREl"
assert device.ip == "192.168.20.250"
assert device.version == b"1.0"