From 49ffe1f697635ea35e67bdc303ee41a731b97d35 Mon Sep 17 00:00:00 2001 From: Luke Date: Tue, 9 Dec 2025 11:47:45 -0500 Subject: [PATCH 1/6] fix: handle random length bytes before version bytes --- roborock/protocol.py | 29 ++++++- tests/devices/test_local_decoder_padding.py | 89 +++++++++++++++++++++ 2 files changed, 115 insertions(+), 3 deletions(-) create mode 100644 tests/devices/test_local_decoder_padding.py diff --git a/roborock/protocol.py b/roborock/protocol.py index 9d6cbca5..229a436d 100644 --- a/roborock/protocol.py +++ b/roborock/protocol.py @@ -341,9 +341,32 @@ class PrefixedStruct(Struct): def _parse(self, stream, context, path): subcon1 = Peek(Optional(Bytes(3))) peek_version = subcon1.parse_stream(stream, **context) - if peek_version not in (b"1.0", b"A01", b"B01", b"L01"): - subcon2 = Bytes(4) - subcon2.parse_stream(stream, **context) + + valid_versions = (b"1.0", b"A01", b"B01", b"L01") + if peek_version not in valid_versions: + # Current stream position does not start with a valid version. + # Scan forward to find one. + current_pos = stream_tell(stream, path) + # Read remaining data to find a valid header + data = stream.read() + + start_index = -1 + # Find the earliest occurrence of any valid version + candidates = [idx for v in valid_versions if (idx := data.find(v)) != -1] + if candidates: + start_index = min(candidates) + + if start_index != -1: + # Found a valid version header at `start_index`. + # Seek to that position (original_pos + index). + if start_index != 4: + # 4 is the typical/expected amount we prune off, + # therefore, we only want a debug if we have a different length. + _LOGGER.debug("Stripping %d bytes of invalid data from stream", start_index) + stream_seek(stream, current_pos + start_index, 0, path) + else: + _LOGGER.debug("No valid version header found in stream, continuing anyways...") + return super()._parse(stream, context, path) def _build(self, obj, stream, context, path): diff --git a/tests/devices/test_local_decoder_padding.py b/tests/devices/test_local_decoder_padding.py new file mode 100644 index 00000000..ef5bea3c --- /dev/null +++ b/tests/devices/test_local_decoder_padding.py @@ -0,0 +1,89 @@ +from roborock.protocol import create_local_decoder, create_local_encoder +from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol + +TEST_LOCAL_KEY = "local_key" + + +def test_decoder_clean_message(): + encoder = create_local_encoder(TEST_LOCAL_KEY) + decoder = create_local_decoder(TEST_LOCAL_KEY) + + msg = RoborockMessage( + protocol=RoborockMessageProtocol.RPC_REQUEST, + payload=b"test_payload", + version=b"1.0", + seq=1, + random=123, + ) + encoded = encoder(msg) + + decoded = decoder(encoded) + assert len(decoded) == 1 + assert decoded[0].payload == b"test_payload" + + +def test_decoder_4byte_padding(): + """Test existing behavior: 4 byte padding should be skipped.""" + encoder = create_local_encoder(TEST_LOCAL_KEY) + decoder = create_local_decoder(TEST_LOCAL_KEY) + + msg = RoborockMessage( + protocol=RoborockMessageProtocol.RPC_REQUEST, + payload=b"test_payload", + version=b"1.0", + ) + encoded = encoder(msg) + + # Prepend 4 bytes of garbage + garbage = b"\x00\x00\x05\xa1" + data = garbage + encoded + + decoded = decoder(data) + assert len(decoded) == 1 + assert decoded[0].payload == b"test_payload" + + +def test_decoder_variable_padding(): + """Test proposed fix: variable length padding should be skipped.""" + encoder = create_local_encoder(TEST_LOCAL_KEY, connect_nonce=123, ack_nonce=456) + decoder = create_local_decoder(TEST_LOCAL_KEY, connect_nonce=123, ack_nonce=456) + + msg = RoborockMessage( + protocol=RoborockMessageProtocol.RPC_REQUEST, + payload=b"test_payload", + version=b"L01", + ) + encoded = encoder(msg) + + # Prepend 6 bytes of garbage + garbage = b"\x00\x00\x05\xa1\xff\xff" + data = garbage + encoded + + decoded = decoder(data) + assert len(decoded) == 1 + assert decoded[0].payload == b"test_payload" + + +def test_decoder_split_padding_variable(): + """Test variable padding split across chunks.""" + encoder = create_local_encoder(TEST_LOCAL_KEY, connect_nonce=123, ack_nonce=456) + decoder = create_local_decoder(TEST_LOCAL_KEY, connect_nonce=123, ack_nonce=456) + + msg = RoborockMessage( + protocol=RoborockMessageProtocol.RPC_REQUEST, + payload=b"test_payload", + version=b"L01", + ) + encoded = encoder(msg) + + garbage = b"\x00\x00\x05\xa1\xff\xff" # 6 bytes + + # Send garbage + decoded1 = decoder(garbage) + assert len(decoded1) == 0 + + # Send message + decoded2 = decoder(encoded) + + assert len(decoded2) == 1 + assert decoded2[0].payload == b"test_payload" From 37d7cf148b22ed75cccaa88c6664bb6d18a3f569 Mon Sep 17 00:00:00 2001 From: Luke Date: Tue, 9 Dec 2025 12:16:51 -0500 Subject: [PATCH 2/6] fix: filter tests to be warnings only --- tests/devices/test_local_channel.py | 6 +++--- tests/devices/test_mqtt_channel.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/devices/test_local_channel.py b/tests/devices/test_local_channel.py index 5c1226c8..e4f7dcd2 100644 --- a/tests/devices/test_local_channel.py +++ b/tests/devices/test_local_channel.py @@ -147,9 +147,9 @@ async def test_message_decode_error(local_channel: LocalChannel, caplog: pytest. local_channel._data_received(b"invalid_payload") await asyncio.sleep(0.01) # yield - assert len(caplog.records) == 1 - assert caplog.records[0].levelname == "WARNING" - assert "Failed to decode message" in caplog.records[0].message + warning_records = [record for record in caplog.records if record.levelname == "WARNING"] + assert len(warning_records) == 1 + assert "Failed to decode message" in warning_records[0].message async def test_subscribe_callback( diff --git a/tests/devices/test_mqtt_channel.py b/tests/devices/test_mqtt_channel.py index c9866cbb..85d9a543 100644 --- a/tests/devices/test_mqtt_channel.py +++ b/tests/devices/test_mqtt_channel.py @@ -158,9 +158,9 @@ async def test_message_decode_error( mqtt_message_handler(b"invalid_payload") await asyncio.sleep(0.01) # yield - assert len(caplog.records) == 1 - assert caplog.records[0].levelname == "WARNING" - assert "Failed to decode message" in caplog.records[0].message + warning_records = [record for record in caplog.records if record.levelname == "WARNING"] + assert len(warning_records) == 1 + assert "Failed to decode message" in warning_records[0].message unsub() From 46bceb2b126e20df267bcebd5b50163cc0d39437 Mon Sep 17 00:00:00 2001 From: Luke Lashley Date: Tue, 9 Dec 2025 19:42:28 -0500 Subject: [PATCH 3/6] chore: apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- roborock/protocol.py | 2 ++ tests/devices/test_local_decoder_padding.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/roborock/protocol.py b/roborock/protocol.py index 229a436d..0ed5f4c8 100644 --- a/roborock/protocol.py +++ b/roborock/protocol.py @@ -366,6 +366,8 @@ def _parse(self, stream, context, path): stream_seek(stream, current_pos + start_index, 0, path) else: _LOGGER.debug("No valid version header found in stream, continuing anyways...") + # Seek back to the original position to avoid parsing at EOF + stream_seek(stream, current_pos, 0, path) return super()._parse(stream, context, path) diff --git a/tests/devices/test_local_decoder_padding.py b/tests/devices/test_local_decoder_padding.py index ef5bea3c..8f5b6e67 100644 --- a/tests/devices/test_local_decoder_padding.py +++ b/tests/devices/test_local_decoder_padding.py @@ -44,7 +44,7 @@ def test_decoder_4byte_padding(): def test_decoder_variable_padding(): - """Test proposed fix: variable length padding should be skipped.""" + """Test variable length padding handling.""" encoder = create_local_encoder(TEST_LOCAL_KEY, connect_nonce=123, ack_nonce=456) decoder = create_local_decoder(TEST_LOCAL_KEY, connect_nonce=123, ack_nonce=456) From 2576f494559e79f13e2dc4e2b21bbe2256fddf15 Mon Sep 17 00:00:00 2001 From: Luke Lashley Date: Tue, 9 Dec 2025 19:43:38 -0500 Subject: [PATCH 4/6] chore: update roborock/protocol.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- roborock/protocol.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/roborock/protocol.py b/roborock/protocol.py index 0ed5f4c8..724bf0cc 100644 --- a/roborock/protocol.py +++ b/roborock/protocol.py @@ -351,10 +351,11 @@ def _parse(self, stream, context, path): data = stream.read() start_index = -1 - # Find the earliest occurrence of any valid version - candidates = [idx for v in valid_versions if (idx := data.find(v)) != -1] - if candidates: - start_index = min(candidates) + # Find the earliest occurrence of any valid version in a single pass + for i in range(len(data) - 2): + if data[i:i+3] in valid_versions: + start_index = i + break if start_index != -1: # Found a valid version header at `start_index`. From 3cff6c43e43e12c043bd8ecf56a4c4fb526f2219 Mon Sep 17 00:00:00 2001 From: Luke Date: Tue, 9 Dec 2025 19:49:51 -0500 Subject: [PATCH 5/6] chore: add debug to help us determine if buffer is source of problem --- roborock/protocol.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/roborock/protocol.py b/roborock/protocol.py index 724bf0cc..58dc4671 100644 --- a/roborock/protocol.py +++ b/roborock/protocol.py @@ -353,7 +353,7 @@ def _parse(self, stream, context, path): start_index = -1 # Find the earliest occurrence of any valid version in a single pass for i in range(len(data) - 2): - if data[i:i+3] in valid_versions: + if data[i : i + 3] in valid_versions: start_index = i break @@ -537,6 +537,7 @@ def decode(bytes_data: bytes) -> list[RoborockMessage]: parsed_messages, remaining = MessageParser.parse( buffer, local_key=local_key, connect_nonce=connect_nonce, ack_nonce=ack_nonce ) + _LOGGER.debug("Found %d extra bytes: %s", len(remaining), remaining) buffer = remaining return parsed_messages From ee8f6ff35dcffa48d9385044e17de7e023e69ada Mon Sep 17 00:00:00 2001 From: Luke Date: Tue, 9 Dec 2025 20:18:43 -0500 Subject: [PATCH 6/6] chore: only log if remaining --- roborock/protocol.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/roborock/protocol.py b/roborock/protocol.py index 58dc4671..6d098d20 100644 --- a/roborock/protocol.py +++ b/roborock/protocol.py @@ -537,7 +537,8 @@ def decode(bytes_data: bytes) -> list[RoborockMessage]: parsed_messages, remaining = MessageParser.parse( buffer, local_key=local_key, connect_nonce=connect_nonce, ack_nonce=ack_nonce ) - _LOGGER.debug("Found %d extra bytes: %s", len(remaining), remaining) + if remaining: + _LOGGER.debug("Found %d extra bytes: %s", len(remaining), remaining) buffer = remaining return parsed_messages