Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions roborock/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,9 +341,35 @@ class PrefixedStruct(Struct):
def _parse(self, stream, context, path):
subcon1 = Peek(Optional(Bytes(3)))
peek_version = subcon1.parse_stream(stream, **context)
if peek_version not in (b"1.0", b"A01", b"B01", b"L01"):
subcon2 = Bytes(4)
subcon2.parse_stream(stream, **context)

valid_versions = (b"1.0", b"A01", b"B01", b"L01")
if peek_version not in valid_versions:
# Current stream position does not start with a valid version.
# Scan forward to find one.
current_pos = stream_tell(stream, path)
# Read remaining data to find a valid header
data = stream.read()

start_index = -1
# Find the earliest occurrence of any valid version in a single pass
for i in range(len(data) - 2):
if data[i : i + 3] in valid_versions:
start_index = i
break

if start_index != -1:
# Found a valid version header at `start_index`.
# Seek to that position (original_pos + index).
if start_index != 4:
# 4 is the typical/expected amount we prune off,
# therefore, we only want a debug if we have a different length.
_LOGGER.debug("Stripping %d bytes of invalid data from stream", start_index)
stream_seek(stream, current_pos + start_index, 0, path)
else:
_LOGGER.debug("No valid version header found in stream, continuing anyways...")
# Seek back to the original position to avoid parsing at EOF
stream_seek(stream, current_pos, 0, path)

return super()._parse(stream, context, path)

def _build(self, obj, stream, context, path):
Expand Down Expand Up @@ -511,6 +537,8 @@ def decode(bytes_data: bytes) -> list[RoborockMessage]:
parsed_messages, remaining = MessageParser.parse(
buffer, local_key=local_key, connect_nonce=connect_nonce, ack_nonce=ack_nonce
)
if remaining:
_LOGGER.debug("Found %d extra bytes: %s", len(remaining), remaining)
buffer = remaining
return parsed_messages

Expand Down
6 changes: 3 additions & 3 deletions tests/devices/test_local_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,9 @@ async def test_message_decode_error(local_channel: LocalChannel, caplog: pytest.
local_channel._data_received(b"invalid_payload")
await asyncio.sleep(0.01) # yield

assert len(caplog.records) == 1
assert caplog.records[0].levelname == "WARNING"
assert "Failed to decode message" in caplog.records[0].message
warning_records = [record for record in caplog.records if record.levelname == "WARNING"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think theres something you can do like with caplog.at_level(logging.WARNING): as well surrounding this, here and below

assert len(warning_records) == 1
assert "Failed to decode message" in warning_records[0].message


async def test_subscribe_callback(
Expand Down
89 changes: 89 additions & 0 deletions tests/devices/test_local_decoder_padding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from roborock.protocol import create_local_decoder, create_local_encoder
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We currently have the pattern that each test module corresponds to an actual model. I know AI likes to make up random test files depending on the subject, but i'd rather see this correspond to a module.

Should this be in tests/test_protocol.py?

from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol

TEST_LOCAL_KEY = "local_key"


def test_decoder_clean_message():
encoder = create_local_encoder(TEST_LOCAL_KEY)
decoder = create_local_decoder(TEST_LOCAL_KEY)

msg = RoborockMessage(
protocol=RoborockMessageProtocol.RPC_REQUEST,
payload=b"test_payload",
version=b"1.0",
seq=1,
random=123,
)
encoded = encoder(msg)

decoded = decoder(encoded)
assert len(decoded) == 1
assert decoded[0].payload == b"test_payload"
Comment on lines +7 to +22
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe you can write the first three tests as one test like this:

Suggested change
def test_decoder_clean_message():
encoder = create_local_encoder(TEST_LOCAL_KEY)
decoder = create_local_decoder(TEST_LOCAL_KEY)
msg = RoborockMessage(
protocol=RoborockMessageProtocol.RPC_REQUEST,
payload=b"test_payload",
version=b"1.0",
seq=1,
random=123,
)
encoded = encoder(msg)
decoded = decoder(encoded)
assert len(decoded) == 1
assert decoded[0].payload == b"test_payload"
@pytest.mark.parametrize(
("garbage"),
[
b"",
b"\x00\x00\x05\xa1"
b"\x00\x00\x05\xa1\xff\xff"
],
)
def test_decoder_clean_message(garbage: bytes):
encoder = create_local_encoder(TEST_LOCAL_KEY)
decoder = create_local_decoder(TEST_LOCAL_KEY)
msg = RoborockMessage(
protocol=RoborockMessageProtocol.RPC_REQUEST,
payload=b"test_payload",
version=b"1.0",
seq=1,
random=123,
)
encoded = encoder(msg)
decoded = decoder(garbage + encoded)
assert len(decoded) == 1
assert decoded[0].payload == b"test_payload"



def test_decoder_4byte_padding():
"""Test existing behavior: 4 byte padding should be skipped."""
encoder = create_local_encoder(TEST_LOCAL_KEY)
decoder = create_local_decoder(TEST_LOCAL_KEY)

msg = RoborockMessage(
protocol=RoborockMessageProtocol.RPC_REQUEST,
payload=b"test_payload",
version=b"1.0",
)
encoded = encoder(msg)

# Prepend 4 bytes of garbage
garbage = b"\x00\x00\x05\xa1"
data = garbage + encoded

decoded = decoder(data)
assert len(decoded) == 1
assert decoded[0].payload == b"test_payload"


def test_decoder_variable_padding():
"""Test variable length padding handling."""
encoder = create_local_encoder(TEST_LOCAL_KEY, connect_nonce=123, ack_nonce=456)
decoder = create_local_decoder(TEST_LOCAL_KEY, connect_nonce=123, ack_nonce=456)

msg = RoborockMessage(
protocol=RoborockMessageProtocol.RPC_REQUEST,
payload=b"test_payload",
version=b"L01",
)
encoded = encoder(msg)

# Prepend 6 bytes of garbage
garbage = b"\x00\x00\x05\xa1\xff\xff"
data = garbage + encoded

decoded = decoder(data)
assert len(decoded) == 1
assert decoded[0].payload == b"test_payload"


def test_decoder_split_padding_variable():
"""Test variable padding split across chunks."""
encoder = create_local_encoder(TEST_LOCAL_KEY, connect_nonce=123, ack_nonce=456)
decoder = create_local_decoder(TEST_LOCAL_KEY, connect_nonce=123, ack_nonce=456)

msg = RoborockMessage(
protocol=RoborockMessageProtocol.RPC_REQUEST,
payload=b"test_payload",
version=b"L01",
)
encoded = encoder(msg)

garbage = b"\x00\x00\x05\xa1\xff\xff" # 6 bytes

# Send garbage
decoded1 = decoder(garbage)
assert len(decoded1) == 0

# Send message
decoded2 = decoder(encoded)

assert len(decoded2) == 1
assert decoded2[0].payload == b"test_payload"
6 changes: 3 additions & 3 deletions tests/devices/test_mqtt_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,9 @@ async def test_message_decode_error(
mqtt_message_handler(b"invalid_payload")
await asyncio.sleep(0.01) # yield

assert len(caplog.records) == 1
assert caplog.records[0].levelname == "WARNING"
assert "Failed to decode message" in caplog.records[0].message
warning_records = [record for record in caplog.records if record.levelname == "WARNING"]
assert len(warning_records) == 1
assert "Failed to decode message" in warning_records[0].message
unsub()


Expand Down