Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 25 additions & 13 deletions roborock/devices/v1_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ class RpcStrategy:
class RpcChannel(V1RpcChannel):
"""Provides an RPC interface around a pub/sub transport channel."""

def __init__(self, rpc_strategies: list[RpcStrategy]) -> None:
"""Initialize the RpcChannel with on ordered list of strategies."""
self._rpc_strategies = rpc_strategies
def __init__(self, rpc_strategies_cb: Callable[[], list[RpcStrategy]]) -> None:
"""Initialize the RpcChannel with an ordered list of strategies."""
self._rpc_strategies_cb = rpc_strategies_cb

async def send_command(
self,
Expand All @@ -86,7 +86,7 @@ async def send_command(

# Try each channel in order until one succeeds
last_exception = None
for strategy in self._rpc_strategies:
for strategy in self._rpc_strategies_cb():
try:
decoded_response = await self._send_rpc(strategy, request)
except RoborockException as e:
Expand Down Expand Up @@ -203,23 +203,35 @@ def is_mqtt_connected(self) -> bool:

@property
def rpc_channel(self) -> V1RpcChannel:
"""Return the combined RPC channel that prefers local with a fallback to MQTT."""
strategies = []
if local_rpc_strategy := self._create_local_rpc_strategy():
strategies.append(local_rpc_strategy)
strategies.append(self._create_mqtt_rpc_strategy())
return RpcChannel(strategies)
"""Return the combined RPC channel that prefers local with a fallback to MQTT.

The returned V1RpcChannel may be long lived and will respect the
current connection state of the underlying channels.
"""

def rpc_strategies_cb() -> list[RpcStrategy]:
strategies = []
if local_rpc_strategy := self._create_local_rpc_strategy():
strategies.append(local_rpc_strategy)
strategies.append(self._create_mqtt_rpc_strategy())
return strategies

return RpcChannel(rpc_strategies_cb)

@property
def mqtt_rpc_channel(self) -> V1RpcChannel:
"""Return the MQTT-only RPC channel."""
return RpcChannel([self._create_mqtt_rpc_strategy()])
"""Return the MQTT-only RPC channel.

The returned V1RpcChannel may be long lived and will respect the
current connection state of the underlying channels.
"""
return RpcChannel(lambda: [self._create_mqtt_rpc_strategy()])

@property
def map_rpc_channel(self) -> V1RpcChannel:
"""Return the map RPC channel used for fetching map content."""
decoder = create_map_response_decoder(security_data=self._security_data)
return RpcChannel([self._create_mqtt_rpc_strategy(decoder)])
return RpcChannel(lambda: [self._create_mqtt_rpc_strategy(decoder)])

def _create_local_rpc_strategy(self) -> RpcStrategy | None:
"""Create the RPC strategy for local transport."""
Expand Down
69 changes: 58 additions & 11 deletions tests/devices/test_v1_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
create_mqtt_decoder,
create_mqtt_encoder,
)
from roborock.protocols.v1_protocol import MapResponse, SecurityData
from roborock.protocols.v1_protocol import MapResponse, SecurityData, V1RpcChannel
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
from roborock.roborock_typing import RoborockCommand

Expand Down Expand Up @@ -141,6 +141,29 @@ def setup_v1_channel(
)


@pytest.fixture(name="rpc_channel")
def setup_rpc_channel(v1_channel: V1Channel) -> V1RpcChannel:
"""Fixture to set up the RPC channel for tests.

We expect tests to use this to send commands via the V1Channel since we
want to exercise the behavior that the V1RpcChannel is long lived and
respects the current state of the underlying channels.
"""
return v1_channel.rpc_channel


@pytest.fixture(name="mqtt_rpc_channel")
def setup_mqtt_rpc_channel(v1_channel: V1Channel) -> V1RpcChannel:
"""Fixture to set up the MQTT RPC channel for tests."""
return v1_channel.mqtt_rpc_channel


@pytest.fixture(name="map_rpc_channel")
def setup_map_rpc_channel(v1_channel: V1Channel) -> V1RpcChannel:
"""Fixture to set up the Map RPC channel for tests."""
return v1_channel.map_rpc_channel


@pytest.fixture(name="warning_caplog")
def setup_warning_caplog(caplog: pytest.LogCaptureFixture) -> pytest.LogCaptureFixture:
"""Fixture to capture warning messages."""
Expand Down Expand Up @@ -274,6 +297,7 @@ async def test_v1_channel_send_command_local_preferred(
v1_channel: V1Channel,
mock_mqtt_channel: Mock,
mock_local_channel: Mock,
rpc_channel: V1RpcChannel,
) -> None:
"""Test command sending prefers local connection when available."""
# Establish connections
Expand All @@ -282,7 +306,7 @@ async def test_v1_channel_send_command_local_preferred(

# Send command
mock_local_channel.response_queue.append(TEST_RESPONSE)
result = await v1_channel.rpc_channel.send_command(
result = await rpc_channel.send_command(
RoborockCommand.GET_STATUS,
response_type=S5MaxStatus,
)
Expand All @@ -295,6 +319,7 @@ async def test_v1_channel_send_command_local_fails(
v1_channel: V1Channel,
mock_mqtt_channel: Mock,
mock_local_channel: Mock,
rpc_channel: V1RpcChannel,
) -> None:
"""Test case where sending with local connection fails, falling back to MQTT."""

Expand All @@ -310,7 +335,7 @@ async def test_v1_channel_send_command_local_fails(
mock_mqtt_channel.response_queue.append(TEST_RESPONSE)

# Send command
result = await v1_channel.rpc_channel.send_command(
result = await rpc_channel.send_command(
RoborockCommand.GET_STATUS,
response_type=S5MaxStatus,
)
Expand All @@ -327,21 +352,39 @@ async def test_v1_channel_send_command_local_fails(
assert mock_mqtt_channel.published_messages[-1].protocol == RoborockMessageProtocol.RPC_REQUEST


async def test_v1_channel_send_decoded_command_mqtt_only(
@pytest.mark.parametrize(
("local_channel_side_effect", "local_channel_responses", "mock_mqtt_channel_responses"),
[
(RoborockException("Local failed"), [], [TEST_RESPONSE]),
(None, [], [TEST_RESPONSE]),
(None, [RoborockMessage(protocol=RoborockMessageProtocol.RPC_RESPONSE, payload=b"invalid")], [TEST_RESPONSE]),
],
ids=[
"local-fails-mqtt-succeeds",
"local-no-response-mqtt-succeeds",
"local-invalid-response-mqtt-succeeds",
],
)
async def test_v1_channel_send_pick_first_available(
v1_channel: V1Channel,
rpc_channel: V1RpcChannel,
mock_mqtt_channel: Mock,
mock_local_channel: Mock,
local_channel_side_effect: Exception | None,
local_channel_responses: list[RoborockMessage],
mock_mqtt_channel_responses: list[RoborockMessage],
) -> None:
"""Test command sending works with MQTT only."""
Copy link

Copilot AI Dec 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring "Test command sending works with MQTT only." doesn't match the test name and behavior. This test actually verifies that the system picks the first available channel when local fails or returns no response, falling back to MQTT. Consider updating the docstring to something like "Test command sending picks first available channel, falling back to MQTT when local is unavailable."

Suggested change
"""Test command sending works with MQTT only."""
"""Test command sending picks first available channel, falling back to MQTT when local is unavailable."""

Copilot uses AI. Check for mistakes.
# Setup: only MQTT connection
mock_mqtt_channel.response_queue.append(TEST_NETWORK_INFO_RESPONSE)
mock_local_channel.connect.side_effect = RoborockException("No local")
mock_local_channel.connect.side_effect = local_channel_side_effect

await v1_channel.subscribe(Mock())

# Send command
mock_mqtt_channel.response_queue.append(TEST_RESPONSE)
result = await v1_channel.rpc_channel.send_command(
mock_mqtt_channel.response_queue.extend(mock_mqtt_channel_responses)
mock_local_channel.response_queue.extend(local_channel_responses)
result = await rpc_channel.send_command(
RoborockCommand.GET_STATUS,
response_type=S5MaxStatus,
)
Expand All @@ -352,6 +395,7 @@ async def test_v1_channel_send_decoded_command_mqtt_only(

async def test_v1_channel_send_decoded_command_with_params(
v1_channel: V1Channel,
rpc_channel: V1RpcChannel,
mock_mqtt_channel: Mock,
mock_local_channel: Mock,
) -> None:
Expand All @@ -363,7 +407,7 @@ async def test_v1_channel_send_decoded_command_with_params(
# Send command with params
mock_local_channel.response_queue.append(TEST_RESPONSE)
test_params = {"volume": 80}
await v1_channel.rpc_channel.send_command(
await rpc_channel.send_command(
RoborockCommand.CHANGE_SOUND_VOLUME,
response_type=S5MaxStatus,
params=test_params,
Expand Down Expand Up @@ -492,6 +536,8 @@ async def test_v1_channel_local_connect_network_info_failure_fallback_to_cache(

async def test_v1_channel_command_encoding_validation(
v1_channel: V1Channel,
mqtt_rpc_channel: V1RpcChannel,
rpc_channel: V1RpcChannel,
mock_mqtt_channel: Mock,
mock_local_channel: Mock,
) -> None:
Expand All @@ -501,13 +547,13 @@ async def test_v1_channel_command_encoding_validation(

# Send mqtt command and capture the request
mock_mqtt_channel.response_queue.append(TEST_RESPONSE)
await v1_channel.mqtt_rpc_channel.send_command(RoborockCommand.CHANGE_SOUND_VOLUME, params={"volume": 50})
await mqtt_rpc_channel.send_command(RoborockCommand.CHANGE_SOUND_VOLUME, params={"volume": 50})
assert mock_mqtt_channel.published_messages
mqtt_message = mock_mqtt_channel.published_messages[0]

# Send local command and capture the request
mock_local_channel.response_queue.append(TEST_RESPONSE_2)
await v1_channel.rpc_channel.send_command(RoborockCommand.CHANGE_SOUND_VOLUME, params={"volume": 50})
await rpc_channel.send_command(RoborockCommand.CHANGE_SOUND_VOLUME, params={"volume": 50})
assert mock_local_channel.published_messages
local_message = mock_local_channel.published_messages[0]

Expand All @@ -522,6 +568,7 @@ async def test_v1_channel_command_encoding_validation(

async def test_v1_channel_send_map_command(
v1_channel: V1Channel,
map_rpc_channel: V1RpcChannel,
mock_mqtt_channel: Mock,
mock_create_map_response_decoder: Mock,
) -> None:
Expand All @@ -546,7 +593,7 @@ async def test_v1_channel_send_map_command(
mock_mqtt_channel.response_queue.append(map_response_message)

# Send the command and get the result
result = await v1_channel.map_rpc_channel.send_command(RoborockCommand.GET_MAP_V1)
result = await map_rpc_channel.send_command(RoborockCommand.GET_MAP_V1)

# Verify the result is the data from our mocked decoder
assert result == decompressed_map_data