diff --git a/roborock/protocol.py b/roborock/protocol.py index 9d6cbca5..6d098d20 100644 --- a/roborock/protocol.py +++ b/roborock/protocol.py @@ -341,9 +341,35 @@ class PrefixedStruct(Struct): def _parse(self, stream, context, path): subcon1 = Peek(Optional(Bytes(3))) peek_version = subcon1.parse_stream(stream, **context) - if peek_version not in (b"1.0", b"A01", b"B01", b"L01"): - subcon2 = Bytes(4) - subcon2.parse_stream(stream, **context) + + valid_versions = (b"1.0", b"A01", b"B01", b"L01") + if peek_version not in valid_versions: + # Current stream position does not start with a valid version. + # Scan forward to find one. + current_pos = stream_tell(stream, path) + # Read remaining data to find a valid header + data = stream.read() + + start_index = -1 + # Find the earliest occurrence of any valid version in a single pass + for i in range(len(data) - 2): + if data[i : i + 3] in valid_versions: + start_index = i + break + + if start_index != -1: + # Found a valid version header at `start_index`. + # Seek to that position (original_pos + index). + if start_index != 4: + # 4 is the typical/expected amount we prune off, + # therefore, we only want a debug if we have a different length. + _LOGGER.debug("Stripping %d bytes of invalid data from stream", start_index) + stream_seek(stream, current_pos + start_index, 0, path) + else: + _LOGGER.debug("No valid version header found in stream, continuing anyways...") + # Seek back to the original position to avoid parsing at EOF + stream_seek(stream, current_pos, 0, path) + return super()._parse(stream, context, path) def _build(self, obj, stream, context, path): @@ -511,6 +537,8 @@ def decode(bytes_data: bytes) -> list[RoborockMessage]: parsed_messages, remaining = MessageParser.parse( buffer, local_key=local_key, connect_nonce=connect_nonce, ack_nonce=ack_nonce ) + if remaining: + _LOGGER.debug("Found %d extra bytes: %s", len(remaining), remaining) buffer = remaining return parsed_messages diff --git a/tests/devices/test_local_channel.py b/tests/devices/test_local_channel.py index 5c1226c8..e4f7dcd2 100644 --- a/tests/devices/test_local_channel.py +++ b/tests/devices/test_local_channel.py @@ -147,9 +147,9 @@ async def test_message_decode_error(local_channel: LocalChannel, caplog: pytest. local_channel._data_received(b"invalid_payload") await asyncio.sleep(0.01) # yield - assert len(caplog.records) == 1 - assert caplog.records[0].levelname == "WARNING" - assert "Failed to decode message" in caplog.records[0].message + warning_records = [record for record in caplog.records if record.levelname == "WARNING"] + assert len(warning_records) == 1 + assert "Failed to decode message" in warning_records[0].message async def test_subscribe_callback( diff --git a/tests/devices/test_local_decoder_padding.py b/tests/devices/test_local_decoder_padding.py new file mode 100644 index 00000000..8f5b6e67 --- /dev/null +++ b/tests/devices/test_local_decoder_padding.py @@ -0,0 +1,89 @@ +from roborock.protocol import create_local_decoder, create_local_encoder +from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol + +TEST_LOCAL_KEY = "local_key" + + +def test_decoder_clean_message(): + encoder = create_local_encoder(TEST_LOCAL_KEY) + decoder = create_local_decoder(TEST_LOCAL_KEY) + + msg = RoborockMessage( + protocol=RoborockMessageProtocol.RPC_REQUEST, + payload=b"test_payload", + version=b"1.0", + seq=1, + random=123, + ) + encoded = encoder(msg) + + decoded = decoder(encoded) + assert len(decoded) == 1 + assert decoded[0].payload == b"test_payload" + + +def test_decoder_4byte_padding(): + """Test existing behavior: 4 byte padding should be skipped.""" + encoder = create_local_encoder(TEST_LOCAL_KEY) + decoder = create_local_decoder(TEST_LOCAL_KEY) + + msg = RoborockMessage( + protocol=RoborockMessageProtocol.RPC_REQUEST, + payload=b"test_payload", + version=b"1.0", + ) + encoded = encoder(msg) + + # Prepend 4 bytes of garbage + garbage = b"\x00\x00\x05\xa1" + data = garbage + encoded + + decoded = decoder(data) + assert len(decoded) == 1 + assert decoded[0].payload == b"test_payload" + + +def test_decoder_variable_padding(): + """Test variable length padding handling.""" + encoder = create_local_encoder(TEST_LOCAL_KEY, connect_nonce=123, ack_nonce=456) + decoder = create_local_decoder(TEST_LOCAL_KEY, connect_nonce=123, ack_nonce=456) + + msg = RoborockMessage( + protocol=RoborockMessageProtocol.RPC_REQUEST, + payload=b"test_payload", + version=b"L01", + ) + encoded = encoder(msg) + + # Prepend 6 bytes of garbage + garbage = b"\x00\x00\x05\xa1\xff\xff" + data = garbage + encoded + + decoded = decoder(data) + assert len(decoded) == 1 + assert decoded[0].payload == b"test_payload" + + +def test_decoder_split_padding_variable(): + """Test variable padding split across chunks.""" + encoder = create_local_encoder(TEST_LOCAL_KEY, connect_nonce=123, ack_nonce=456) + decoder = create_local_decoder(TEST_LOCAL_KEY, connect_nonce=123, ack_nonce=456) + + msg = RoborockMessage( + protocol=RoborockMessageProtocol.RPC_REQUEST, + payload=b"test_payload", + version=b"L01", + ) + encoded = encoder(msg) + + garbage = b"\x00\x00\x05\xa1\xff\xff" # 6 bytes + + # Send garbage + decoded1 = decoder(garbage) + assert len(decoded1) == 0 + + # Send message + decoded2 = decoder(encoded) + + assert len(decoded2) == 1 + assert decoded2[0].payload == b"test_payload" diff --git a/tests/devices/test_mqtt_channel.py b/tests/devices/test_mqtt_channel.py index c9866cbb..85d9a543 100644 --- a/tests/devices/test_mqtt_channel.py +++ b/tests/devices/test_mqtt_channel.py @@ -158,9 +158,9 @@ async def test_message_decode_error( mqtt_message_handler(b"invalid_payload") await asyncio.sleep(0.01) # yield - assert len(caplog.records) == 1 - assert caplog.records[0].levelname == "WARNING" - assert "Failed to decode message" in caplog.records[0].message + warning_records = [record for record in caplog.records if record.levelname == "WARNING"] + assert len(warning_records) == 1 + assert "Failed to decode message" in warning_records[0].message unsub()