diff --git a/roborock/broadcast_protocol.py b/roborock/broadcast_protocol.py index 93b5b0a7..e1f69b18 100644 --- a/roborock/broadcast_protocol.py +++ b/roborock/broadcast_protocol.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import hashlib import json import logging from asyncio import BaseTransport, Lock @@ -8,12 +9,16 @@ from construct import ( # type: ignore Bytes, Checksum, + GreedyBytes, Int16ub, Int32ub, + Prefixed, RawCopy, Struct, ) +from Crypto.Cipher import AES +from roborock import RoborockException from roborock.containers import BroadcastMessage from roborock.protocol import EncryptionAdapter, Utils, _Parser @@ -29,14 +34,41 @@ def __init__(self, timeout: int = 5): self.devices_found: list[BroadcastMessage] = [] self._mutex = Lock() - def datagram_received(self, data, _): - [broadcast_message], _ = BroadcastParser.parse(data) - if broadcast_message.payload: - parsed_message = BroadcastMessage.from_dict(json.loads(broadcast_message.payload)) - _LOGGER.debug(f"Received broadcast: {parsed_message}") - self.devices_found.append(parsed_message) + def datagram_received(self, data: bytes, _): + """Handle incoming broadcast datagrams.""" + try: + version = data[:3] + if version == b"L01": + [parsed_msg], _ = L01Parser.parse(data) + encrypted_payload = parsed_msg.payload + if encrypted_payload is None: + raise RoborockException("No encrypted payload found in broadcast message") + ciphertext = encrypted_payload[:-16] + tag = encrypted_payload[-16:] - async def discover(self): + key = hashlib.sha256(BROADCAST_TOKEN).digest() + iv_digest_input = data[:9] + digest = hashlib.sha256(iv_digest_input).digest() + iv = digest[:12] + + cipher = AES.new(key, AES.MODE_GCM, nonce=iv) + decrypted_payload_bytes = cipher.decrypt_and_verify(ciphertext, tag) + json_payload = json.loads(decrypted_payload_bytes) + parsed_message = BroadcastMessage(duid=json_payload["duid"], ip=json_payload["ip"], version=version) + _LOGGER.debug(f"Received L01 broadcast: {parsed_message}") + self.devices_found.append(parsed_message) + else: + # Fallback to the original protocol parser for other versions + [broadcast_message], _ = BroadcastParser.parse(data) + if broadcast_message.payload: + json_payload = json.loads(broadcast_message.payload) + parsed_message = BroadcastMessage(duid=json_payload["duid"], ip=json_payload["ip"], version=version) + _LOGGER.debug(f"Received broadcast: {parsed_message}") + self.devices_found.append(parsed_message) + except Exception as e: + _LOGGER.warning(f"Failed to decode message: {data!r}. Error: {e}") + + async def discover(self) -> list[BroadcastMessage]: async with self._mutex: try: loop = asyncio.get_event_loop() @@ -64,5 +96,19 @@ def close(self): "checksum" / Checksum(Int32ub, Utils.crc, lambda ctx: ctx.message.data), ) +_L01BroadcastMessage = Struct( + "message" + / RawCopy( + Struct( + "version" / Bytes(3), + "field1" / Bytes(4), # Unknown field + "field2" / Bytes(2), # Unknown field + "payload" / Prefixed(Int16ub, GreedyBytes), # Encrypted payload with length prefix + ) + ), + "checksum" / Checksum(Int32ub, Utils.crc, lambda ctx: ctx.message.data), +) + BroadcastParser: _Parser = _Parser(_BroadcastMessage, False) +L01Parser: _Parser = _Parser(_L01BroadcastMessage, False) diff --git a/roborock/containers.py b/roborock/containers.py index 0c0ce00b..4e78c8f0 100644 --- a/roborock/containers.py +++ b/roborock/containers.py @@ -783,6 +783,7 @@ class FlowLedStatus(RoborockBase): class BroadcastMessage(RoborockBase): duid: str ip: str + version: bytes class ServerTimer(NamedTuple): diff --git a/roborock/protocol.py b/roborock/protocol.py index 08b04cca..fdc52c10 100644 --- a/roborock/protocol.py +++ b/roborock/protocol.py @@ -304,10 +304,10 @@ def parse(self, data: bytes, local_key: str | None = None) -> tuple[list[Roboroc messages.append( RoborockMessage( version=message.message.value.version, - seq=message.message.value.seq, + seq=message.message.value.get("seq"), random=message.message.value.get("random"), timestamp=message.message.value.get("timestamp"), - protocol=message.message.value.protocol, + protocol=message.message.value.get("protocol"), payload=message.message.value.payload, ) ) diff --git a/tests/test_broadcast_protocol.py b/tests/test_broadcast_protocol.py new file mode 100644 index 00000000..062d1909 --- /dev/null +++ b/tests/test_broadcast_protocol.py @@ -0,0 +1,27 @@ +from roborock.broadcast_protocol import RoborockProtocol + + +def test_l01_data(): + data = bytes.fromhex( + "4c30310000000000000043841496d5a31e34b5b02c1867c445509ba5a21aec1fa4b307bddeb27a75d9b366193e8a97d0534dc39851c" + "980609f2670cdcaee04594ec5c93e3c5ae609b0c9a203139ac8e40c8c" + ) + prot = RoborockProtocol() + prot.datagram_received(data, None) + device = prot.devices_found[0] + assert device.duid == "ZrQn1jfZtJQLoPOL7620e" + assert device.ip == "192.168.1.4" + assert device.version == b"L01" + + +def test_v1_data(): + data = bytes.fromhex( + "312e30000003e003e80040b87035058b439f36af42f249605f8661897173f111bb849a6231831f5874a0cf220a25872ea412d796b4902ee" + "57fdc120074b901b482acb1fe6d06317e3a72ddac654fe0" + ) + prot = RoborockProtocol() + prot.datagram_received(data, None) + device = prot.devices_found[0] + assert device.duid == "h96rOV3e8DTPMAOLiypREl" + assert device.ip == "192.168.20.250" + assert device.version == b"1.0"