diff --git a/roborock/devices/device.py b/roborock/devices/device.py index 3be8b8cb..623e9e68 100644 --- a/roborock/devices/device.py +++ b/roborock/devices/device.py @@ -117,4 +117,4 @@ async def get_status(self) -> Status: This is a placeholder command and will likely be changed/moved in the future. """ status_type: type[Status] = ModelStatus.get(self._product_info.model, S7MaxVStatus) - return await self._v1_channel.send_decoded_command(RoborockCommand.GET_STATUS, response_type=status_type) + return await self._v1_channel.rpc_channel.send_command(RoborockCommand.GET_STATUS, response_type=status_type) diff --git a/roborock/devices/local_channel.py b/roborock/devices/local_channel.py index 9048947d..c4bb20ea 100644 --- a/roborock/devices/local_channel.py +++ b/roborock/devices/local_channel.py @@ -50,6 +50,11 @@ def __init__(self, host: str, local_key: str): self._encoder: Encoder = create_local_encoder(local_key) self._queue_lock = asyncio.Lock() + @property + def is_connected(self) -> bool: + """Check if the channel is currently connected.""" + return self._is_connected + async def connect(self) -> None: """Connect to the device.""" if self._is_connected: @@ -113,7 +118,7 @@ async def _resolve_future_with_lock(self, message: RoborockMessage) -> None: else: _LOGGER.debug("Received message with no waiting handler: request_id=%s", request_id) - async def send_command(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage: + async def send_message(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage: """Send a command message and wait for the response message.""" if not self._transport or not self._is_connected: raise RoborockConnectionException("Not connected to device") diff --git a/roborock/devices/mqtt_channel.py b/roborock/devices/mqtt_channel.py index 00a01210..eb147436 100644 --- a/roborock/devices/mqtt_channel.py +++ b/roborock/devices/mqtt_channel.py @@ -80,7 +80,7 @@ async def _resolve_future_with_lock(self, message: RoborockMessage) -> None: else: _LOGGER.debug("Received message with no waiting handler: request_id=%s", request_id) - async def send_command(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage: + async def send_message(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage: """Send a command message and wait for the response message. Returns the raw response message - caller is responsible for parsing. diff --git a/roborock/devices/v1_channel.py b/roborock/devices/v1_channel.py index 927b4083..dd2e7a14 100644 --- a/roborock/devices/v1_channel.py +++ b/roborock/devices/v1_channel.py @@ -6,25 +6,21 @@ import logging from collections.abc import Callable -from typing import Any, TypeVar +from typing import TypeVar from roborock.containers import HomeDataDevice, NetworkInfo, RoborockBase, UserData from roborock.exceptions import RoborockException from roborock.mqtt.session import MqttParams, MqttSession from roborock.protocols.v1_protocol import ( - CommandType, - ParamsType, SecurityData, - create_mqtt_payload_encoder, create_security_data, - decode_rpc_response, - encode_local_payload, ) from roborock.roborock_message import RoborockMessage from roborock.roborock_typing import RoborockCommand from .local_channel import LocalChannel, LocalSession, create_local_session from .mqtt_channel import MqttChannel +from .v1_rpc_channel import V1RpcChannel, create_combined_rpc_channel, create_mqtt_rpc_channel _LOGGER = logging.getLogger(__name__) @@ -58,9 +54,10 @@ def __init__( """ self._device_uid = device_uid self._mqtt_channel = mqtt_channel - self._mqtt_payload_encoder = create_mqtt_payload_encoder(security_data) + self._mqtt_rpc_channel = create_mqtt_rpc_channel(mqtt_channel, security_data) self._local_session = local_session self._local_channel: LocalChannel | None = None + self._combined_rpc_channel: V1RpcChannel | None = None self._mqtt_unsub: Callable[[], None] | None = None self._local_unsub: Callable[[], None] | None = None self._callback: Callable[[RoborockMessage], None] | None = None @@ -76,6 +73,16 @@ def is_mqtt_connected(self) -> bool: """Return whether MQTT connection is available.""" return self._mqtt_unsub is not None + @property + def rpc_channel(self) -> V1RpcChannel: + """Return the combined RPC channel prefers local with a fallback to MQTT.""" + return self._combined_rpc_channel or self._mqtt_rpc_channel + + @property + def mqtt_rpc_channel(self) -> V1RpcChannel: + """Return the MQTT RPC channel.""" + return self._mqtt_rpc_channel + async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]: """Subscribe to all messages from the device. @@ -119,7 +126,9 @@ async def _get_networking_info(self) -> NetworkInfo: This is a cloud only command used to get the local device's IP address. """ try: - return await self._send_mqtt_decoded_command(RoborockCommand.GET_NETWORK_INFO, response_type=NetworkInfo) + return await self._mqtt_rpc_channel.send_command( + RoborockCommand.GET_NETWORK_INFO, response_type=NetworkInfo + ) except RoborockException as e: raise RoborockException(f"Network info failed for device {self._device_uid}") from e @@ -136,59 +145,9 @@ async def _local_connect(self) -> Callable[[], None]: except RoborockException as e: self._local_channel = None raise RoborockException(f"Error connecting to local device {self._device_uid}: {e}") from e - + self._combined_rpc_channel = create_combined_rpc_channel(self._local_channel, self._mqtt_rpc_channel) return await self._local_channel.subscribe(self._on_local_message) - async def send_decoded_command( - self, - method: CommandType, - *, - response_type: type[_T], - params: ParamsType = None, - ) -> _T: - """Send a command using the best available transport. - - Will prefer local connection if available, falling back to MQTT. - """ - connection = "local" if self.is_local_connected else "mqtt" - _LOGGER.debug("Sending command (%s): %s, params=%s", connection, method, params) - if self._local_channel: - return await self._send_local_decoded_command(method, response_type=response_type, params=params) - return await self._send_mqtt_decoded_command(method, response_type=response_type, params=params) - - async def _send_mqtt_raw_command(self, method: CommandType, params: ParamsType | None = None) -> dict[str, Any]: - """Send a raw command and return a raw unparsed response.""" - message = self._mqtt_payload_encoder(method, params) - _LOGGER.debug("Sending MQTT message for device %s: %s", self._device_uid, message) - response = await self._mqtt_channel.send_command(message) - return decode_rpc_response(response) - - async def _send_mqtt_decoded_command( - self, method: CommandType, *, response_type: type[_T], params: ParamsType | None = None - ) -> _T: - """Send a command over MQTT and decode the response.""" - decoded_response = await self._send_mqtt_raw_command(method, params) - return response_type.from_dict(decoded_response) - - async def _send_local_raw_command(self, method: CommandType, params: ParamsType | None = None) -> dict[str, Any]: - """Send a raw command over local connection.""" - if not self._local_channel: - raise RoborockException("Local channel is not connected") - - message = encode_local_payload(method, params) - _LOGGER.debug("Sending local message for device %s: %s", self._device_uid, message) - response = await self._local_channel.send_command(message) - return decode_rpc_response(response) - - async def _send_local_decoded_command( - self, method: CommandType, *, response_type: type[_T], params: ParamsType | None = None - ) -> _T: - """Send a command over local connection and decode the response.""" - if not self._local_channel: - raise RoborockException("Local channel is not connected") - decoded_response = await self._send_local_raw_command(method, params) - return response_type.from_dict(decoded_response) - def _on_mqtt_message(self, message: RoborockMessage) -> None: """Handle incoming MQTT messages.""" _LOGGER.debug("V1Channel received MQTT message from device %s: %s", self._device_uid, message) diff --git a/roborock/devices/v1_rpc_channel.py b/roborock/devices/v1_rpc_channel.py new file mode 100644 index 00000000..f66f33fb --- /dev/null +++ b/roborock/devices/v1_rpc_channel.py @@ -0,0 +1,148 @@ +"""V1 Rpc Channel for Roborock devices. + +This is a wrapper around the V1 channel that provides a higher level interface +for sending typed commands and receiving typed responses. This also provides +a simple interface for sending commands and receiving responses over both MQTT +and local connections, preferring local when available. +""" + +import logging +from collections.abc import Callable +from typing import Any, Protocol, TypeVar, overload + +from roborock.containers import RoborockBase +from roborock.protocols.v1_protocol import ( + CommandType, + ParamsType, + SecurityData, + create_mqtt_payload_encoder, + decode_rpc_response, + encode_local_payload, +) +from roborock.roborock_message import RoborockMessage + +from .local_channel import LocalChannel +from .mqtt_channel import MqttChannel + +_LOGGER = logging.getLogger(__name__) + + +_T = TypeVar("_T", bound=RoborockBase) + + +class V1RpcChannel(Protocol): + """Protocol for V1 RPC channels. + + This is a wrapper around a raw channel that provides a high-level interface + for sending commands and receiving responses. + """ + + @overload + async def send_command( + self, + method: CommandType, + *, + params: ParamsType = None, + ) -> Any: + """Send a command and return a decoded response.""" + ... + + @overload + async def send_command( + self, + method: CommandType, + *, + response_type: type[_T], + params: ParamsType = None, + ) -> _T: + """Send a command and return a parsed response RoborockBase type.""" + ... + + +class BaseV1RpcChannel(V1RpcChannel): + """Base implementation that provides the typed response logic.""" + + async def send_command( + self, + method: CommandType, + *, + response_type: type[_T] | None = None, + params: ParamsType = None, + ) -> _T | Any: + """Send a command and return either a decoded or parsed response.""" + decoded_response = await self._send_raw_command(method, params=params) + + if response_type is not None: + return response_type.from_dict(decoded_response) + return decoded_response + + async def _send_raw_command( + self, + method: CommandType, + *, + params: ParamsType = None, + ) -> Any: + """Send a raw command and return the decoded response. Must be implemented by subclasses.""" + raise NotImplementedError + + +class CombinedV1RpcChannel(BaseV1RpcChannel): + """A V1 RPC channel that can use both local and MQTT channels, preferring local when available.""" + + def __init__( + self, local_channel: LocalChannel, local_rpc_channel: V1RpcChannel, mqtt_channel: V1RpcChannel + ) -> None: + """Initialize the combined channel with local and MQTT channels.""" + self._local_channel = local_channel + self._local_rpc_channel = local_rpc_channel + self._mqtt_rpc_channel = mqtt_channel + + async def _send_raw_command( + self, + method: CommandType, + *, + params: ParamsType = None, + ) -> Any: + """Send a command and return a parsed response RoborockBase type.""" + if self._local_channel.is_connected: + return await self._local_rpc_channel.send_command(method, params=params) + return await self._mqtt_rpc_channel.send_command(method, params=params) + + +class PayloadEncodedV1RpcChannel(BaseV1RpcChannel): + """Protocol for V1 channels that send encoded commands.""" + + def __init__( + self, + name: str, + channel: MqttChannel | LocalChannel, + payload_encoder: Callable[[CommandType, ParamsType], RoborockMessage], + ) -> None: + """Initialize the channel with a raw channel and an encoder function.""" + self._name = name + self._channel = channel + self._payload_encoder = payload_encoder + + async def _send_raw_command( + self, + method: CommandType, + *, + params: ParamsType = None, + ) -> Any: + """Send a command and return a parsed response RoborockBase type.""" + _LOGGER.debug("Sending command (%s): %s, params=%s", self._name, method, params) + message = self._payload_encoder(method, params) + response = await self._channel.send_message(message) + return decode_rpc_response(response) + + +def create_mqtt_rpc_channel(mqtt_channel: MqttChannel, security_data: SecurityData) -> V1RpcChannel: + """Create a V1 RPC channel using an MQTT channel.""" + payload_encoder = create_mqtt_payload_encoder(security_data) + return PayloadEncodedV1RpcChannel("mqtt", mqtt_channel, payload_encoder) + + +def create_combined_rpc_channel(local_channel: LocalChannel, mqtt_rpc_channel: V1RpcChannel) -> V1RpcChannel: + """Create a V1 RPC channel that combines local and MQTT channels.""" + local_rpc_channel = PayloadEncodedV1RpcChannel("local", local_channel, encode_local_payload) + return CombinedV1RpcChannel(local_channel, local_rpc_channel, mqtt_rpc_channel) diff --git a/tests/devices/test_device.py b/tests/devices/test_device.py index 3bb85670..52c941ca 100644 --- a/tests/devices/test_device.py +++ b/tests/devices/test_device.py @@ -56,11 +56,11 @@ async def test_device_connection(device: RoborockDevice, channel: AsyncMock) -> async def test_device_get_status_command(device: RoborockDevice, channel: AsyncMock) -> None: """Test the device get_status command.""" # Mock response for get_status command - channel.send_decoded_command.return_value = STATUS + channel.rpc_channel.send_command.return_value = STATUS # Test get_status and verify the command was sent status = await device.get_status() - assert channel.send_decoded_command.called + assert channel.rpc_channel.send_command.called # Verify the result assert status is not None diff --git a/tests/devices/test_local_channel.py b/tests/devices/test_local_channel.py index 801c81fc..04339168 100644 --- a/tests/devices/test_local_channel.py +++ b/tests/devices/test_local_channel.py @@ -116,13 +116,13 @@ async def test_close_without_connection(local_channel: LocalChannel) -> None: assert local_channel._is_connected is False -async def test_send_command_not_connected(local_channel: LocalChannel) -> None: +async def test_send_message_not_connected(local_channel: LocalChannel) -> None: """Test sending command when not connected raises exception.""" with pytest.raises(RoborockConnectionException, match="Not connected to device"): - await local_channel.send_command(TEST_REQUEST) + await local_channel.send_message(TEST_REQUEST) -async def test_send_command_without_request_id(local_channel: LocalChannel, mock_loop: Mock) -> None: +async def test_send_message_without_request_id(local_channel: LocalChannel, mock_loop: Mock) -> None: """Test sending command without request ID raises exception.""" await local_channel.connect() @@ -133,7 +133,7 @@ async def test_send_command_without_request_id(local_channel: LocalChannel, mock ) with pytest.raises(RoborockException, match="Message must have a request_id"): - await local_channel.send_command(test_message) + await local_channel.send_message(test_message) async def test_successful_command_response(local_channel: LocalChannel, mock_loop: Mock, mock_transport: Mock) -> None: @@ -141,7 +141,7 @@ async def test_successful_command_response(local_channel: LocalChannel, mock_loo await local_channel.connect() # Send command in background task - command_task = asyncio.create_task(local_channel.send_command(TEST_REQUEST)) + command_task = asyncio.create_task(local_channel.send_message(TEST_REQUEST)) await asyncio.sleep(0.01) # yield # Simulate receiving response via the protocol callback @@ -165,8 +165,8 @@ async def test_concurrent_commands(local_channel: LocalChannel, mock_loop: Mock, await local_channel.connect() # Start both commands concurrently - task1 = asyncio.create_task(local_channel.send_command(TEST_REQUEST, timeout=5.0)) - task2 = asyncio.create_task(local_channel.send_command(TEST_REQUEST2, timeout=5.0)) + task1 = asyncio.create_task(local_channel.send_message(TEST_REQUEST, timeout=5.0)) + task2 = asyncio.create_task(local_channel.send_message(TEST_REQUEST2, timeout=5.0)) await asyncio.sleep(0.01) # yield # Send responses @@ -188,12 +188,12 @@ async def test_duplicate_request_id_prevention(local_channel: LocalChannel, mock await local_channel.connect() # Start first command - task1 = asyncio.create_task(local_channel.send_command(TEST_REQUEST, timeout=5.0)) + task1 = asyncio.create_task(local_channel.send_message(TEST_REQUEST, timeout=5.0)) await asyncio.sleep(0.01) # yield # Try to start second command with same request ID with pytest.raises(RoborockException, match="Request ID 12345 already pending"): - await local_channel.send_command(TEST_REQUEST, timeout=5.0) + await local_channel.send_message(TEST_REQUEST, timeout=5.0) # Complete first command local_channel._data_received(ENCODER(TEST_RESPONSE)) @@ -208,7 +208,7 @@ async def test_command_timeout(local_channel: LocalChannel, mock_loop: Mock) -> await local_channel.connect() with pytest.raises(RoborockException, match="Command timed out after 0.1s"): - await local_channel.send_command(TEST_REQUEST, timeout=0.1) + await local_channel.send_message(TEST_REQUEST, timeout=0.1) async def test_message_decode_error(local_channel: LocalChannel, caplog: pytest.LogCaptureFixture) -> None: @@ -242,7 +242,7 @@ async def test_subscribe_callback_with_rpc_response( await local_channel.connect() # Send request - task = asyncio.create_task(local_channel.send_command(TEST_REQUEST, timeout=5.0)) + task = asyncio.create_task(local_channel.send_message(TEST_REQUEST, timeout=5.0)) await asyncio.sleep(0.01) # yield # Send response and unrelated message diff --git a/tests/devices/test_mqtt_channel.py b/tests/devices/test_mqtt_channel.py index 4bebcf0e..37a7c27c 100644 --- a/tests/devices/test_mqtt_channel.py +++ b/tests/devices/test_mqtt_channel.py @@ -122,7 +122,7 @@ async def test_mqtt_channel(mqtt_session: Mock, mqtt_channel: MqttChannel) -> No assert result == unsub -async def test_send_command_success( +async def test_send_message_success( mqtt_session: Mock, mqtt_channel: MqttChannel, mqtt_message_handler: Callable[[bytes], None], @@ -130,7 +130,7 @@ async def test_send_command_success( """Test successful RPC command sending and response handling.""" # Send a test request. We use a task so we can simulate receiving the response # while the command is still being processed. - command_task = asyncio.create_task(mqtt_channel.send_command(TEST_REQUEST)) + command_task = asyncio.create_task(mqtt_channel.send_message(TEST_REQUEST)) await asyncio.sleep(0.01) # yield # Simulate receiving the response message via MQTT @@ -153,7 +153,7 @@ async def test_send_command_success( assert result == TEST_RESPONSE -async def test_send_command_without_request_id( +async def test_send_message_without_request_id( mqtt_session: Mock, mqtt_channel: MqttChannel, mqtt_message_handler: Callable[[bytes], None], @@ -166,7 +166,7 @@ async def test_send_command_without_request_id( ) with pytest.raises(RoborockException, match="Message must have a request_id"): - await mqtt_channel.send_command(test_message) + await mqtt_channel.send_message(test_message) async def test_concurrent_commands( @@ -179,8 +179,8 @@ async def test_concurrent_commands( # Create multiple test messages with different request IDs # Start both commands concurrently - task1 = asyncio.create_task(mqtt_channel.send_command(TEST_REQUEST, timeout=5.0)) - task2 = asyncio.create_task(mqtt_channel.send_command(TEST_REQUEST2, timeout=5.0)) + task1 = asyncio.create_task(mqtt_channel.send_message(TEST_REQUEST, timeout=5.0)) + task2 = asyncio.create_task(mqtt_channel.send_message(TEST_REQUEST2, timeout=5.0)) await asyncio.sleep(0.01) # yield # Create responses for both @@ -209,8 +209,8 @@ async def test_concurrent_commands_same_request_id( # Create multiple test messages with different request IDs # Start both commands concurrently - task1 = asyncio.create_task(mqtt_channel.send_command(TEST_REQUEST, timeout=5.0)) - task2 = asyncio.create_task(mqtt_channel.send_command(TEST_REQUEST, timeout=5.0)) + task1 = asyncio.create_task(mqtt_channel.send_message(TEST_REQUEST, timeout=5.0)) + task2 = asyncio.create_task(mqtt_channel.send_message(TEST_REQUEST, timeout=5.0)) await asyncio.sleep(0.01) # yield # Create response @@ -233,7 +233,7 @@ async def test_handle_completed_future( ) -> None: """Test handling response for an already completed future.""" # Send request - task = asyncio.create_task(mqtt_channel.send_command(TEST_REQUEST, timeout=5.0)) + task = asyncio.create_task(mqtt_channel.send_message(TEST_REQUEST, timeout=5.0)) await asyncio.sleep(0.01) # yield # Send the response twice @@ -255,7 +255,7 @@ async def test_subscribe_callback_with_rpc_response( ) -> None: """Test that subscribe callback is called independent of RPC handling.""" # Send request - task = asyncio.create_task(mqtt_channel.send_command(TEST_REQUEST, timeout=5.0)) + task = asyncio.create_task(mqtt_channel.send_message(TEST_REQUEST, timeout=5.0)) await asyncio.sleep(0.01) # yield assert not received_messages diff --git a/tests/devices/test_v1_channel.py b/tests/devices/test_v1_channel.py index ac24a8fd..fb4ee281 100644 --- a/tests/devices/test_v1_channel.py +++ b/tests/devices/test_v1_channel.py @@ -16,7 +16,7 @@ from roborock.devices.v1_channel import V1Channel from roborock.exceptions import RoborockException from roborock.protocol import create_local_decoder, create_local_encoder, create_mqtt_decoder, create_mqtt_encoder -from roborock.protocols.v1_protocol import SecurityData, encode_local_payload +from roborock.protocols.v1_protocol import SecurityData from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol from roborock.roborock_typing import RoborockCommand @@ -59,7 +59,7 @@ def setup_mock_mqtt_channel() -> Mock: """Mock MQTT channel for testing.""" mock_mqtt = AsyncMock(spec=MqttChannel) mock_mqtt.subscribe = AsyncMock() - mock_mqtt.send_command = AsyncMock() + mock_mqtt.send_message = AsyncMock() return mock_mqtt @@ -69,10 +69,10 @@ def setup_mqtt_responses(mock_mqtt_channel: Mock) -> list[RoborockMessage]: responses: list[RoborockMessage] = [TEST_NETWORK_INFO_RESPONSE] - def send_command(*args) -> RoborockMessage: + def send_message(*args) -> RoborockMessage: return responses.pop(0) - mock_mqtt_channel.send_command.side_effect = send_command + mock_mqtt_channel.send_message.side_effect = send_message return responses @@ -82,7 +82,7 @@ def setup_mock_local_channel() -> Mock: mock_local = AsyncMock(spec=LocalChannel) mock_local.connect = AsyncMock() mock_local.subscribe = AsyncMock() - mock_local.send_command = AsyncMock() + mock_local.send_message = AsyncMock() return mock_local @@ -125,7 +125,7 @@ async def test_v1_channel_subscribe_mqtt_only_success( # Setup: MQTT succeeds, local fails mqtt_unsub = Mock() mock_mqtt_channel.subscribe.return_value = mqtt_unsub - mock_mqtt_channel.send_command.return_value = TEST_NETWORK_INFO_RESPONSE + mock_mqtt_channel.send_message.return_value = TEST_NETWORK_INFO_RESPONSE mock_local_channel.connect.side_effect = RoborockException("Connection failed") callback = Mock() @@ -211,7 +211,7 @@ async def test_v1_channel_local_connection_warning_logged( # V1Channel command sending with fallback logic tests -async def test_v1_channel_send_decoded_command_local_preferred( +async def test_v1_channel_send_command_local_preferred( v1_channel: V1Channel, mock_mqtt_channel: Mock, mock_local_channel: Mock, @@ -219,22 +219,22 @@ async def test_v1_channel_send_decoded_command_local_preferred( """Test command sending prefers local connection when available.""" # Establish connections await v1_channel.subscribe(Mock()) - mock_mqtt_channel.send_command.reset_mock(return_value=False) + mock_mqtt_channel.send_message.reset_mock(return_value=False) # Send command - mock_local_channel.send_command.return_value = TEST_RESPONSE - result = await v1_channel.send_decoded_command( + mock_local_channel.send_message.return_value = TEST_RESPONSE + result = await v1_channel.rpc_channel.send_command( RoborockCommand.CHANGE_SOUND_VOLUME, response_type=S5MaxStatus, ) # Verify local was used, not MQTT - mock_local_channel.send_command.assert_called_once() - mock_mqtt_channel.send_command.assert_not_called() + mock_local_channel.send_message.assert_called_once() + mock_mqtt_channel.send_message.assert_not_called() assert result.state == RoborockStateCode.cleaning -async def test_v1_channel_send_decoded_command_local_fails( +async def test_v1_channel_send_command_local_fails( v1_channel: V1Channel, mock_mqtt_channel: Mock, mock_local_channel: Mock, @@ -244,22 +244,22 @@ async def test_v1_channel_send_decoded_command_local_fails( # Establish connections await v1_channel.subscribe(Mock()) - mock_mqtt_channel.send_command.reset_mock(return_value=False) + mock_mqtt_channel.send_message.reset_mock(return_value=False) # Local command fails - mock_local_channel.send_command.side_effect = RoborockException("Local failed") + mock_local_channel.send_message.side_effect = RoborockException("Local failed") # Send command mqtt_responses.append(TEST_RESPONSE) with pytest.raises(RoborockException, match="Local failed"): - await v1_channel.send_decoded_command( + await v1_channel.rpc_channel.send_command( RoborockCommand.CHANGE_SOUND_VOLUME, response_type=S5MaxStatus, ) # Verify local was attempted but not mqtt - mock_local_channel.send_command.assert_called_once() - mock_mqtt_channel.send_command.assert_not_called() + mock_local_channel.send_message.assert_called_once() + mock_mqtt_channel.send_message.assert_not_called() async def test_v1_channel_send_decoded_command_mqtt_only( @@ -274,19 +274,19 @@ async def test_v1_channel_send_decoded_command_mqtt_only( mock_local_channel.connect.side_effect = RoborockException("No local") await v1_channel.subscribe(Mock()) - mock_mqtt_channel.send_command.assert_called_once() # network info - mock_mqtt_channel.send_command.reset_mock(return_value=False) + mock_mqtt_channel.send_message.assert_called_once() # network info + mock_mqtt_channel.send_message.reset_mock(return_value=False) # Send command mqtt_responses.append(TEST_RESPONSE) - result = await v1_channel.send_decoded_command( + result = await v1_channel.rpc_channel.send_command( RoborockCommand.CHANGE_SOUND_VOLUME, response_type=S5MaxStatus, ) # Verify only MQTT was used - mock_local_channel.send_command.assert_not_called() - mock_mqtt_channel.send_command.assert_called_once() + mock_local_channel.send_message.assert_not_called() + mock_mqtt_channel.send_message.assert_called_once() assert result.state == RoborockStateCode.cleaning @@ -300,17 +300,17 @@ async def test_v1_channel_send_decoded_command_with_params( await v1_channel.subscribe(Mock()) # Send command with params - mock_local_channel.send_command.return_value = TEST_RESPONSE + mock_local_channel.send_message.return_value = TEST_RESPONSE test_params = {"volume": 80} - await v1_channel.send_decoded_command( + await v1_channel.rpc_channel.send_command( RoborockCommand.CHANGE_SOUND_VOLUME, response_type=S5MaxStatus, params=test_params, ) # Verify command was sent with correct params - mock_local_channel.send_command.assert_called_once() - call_args = mock_local_channel.send_command.call_args + mock_local_channel.send_message.assert_called_once() + call_args = mock_local_channel.send_message.call_args sent_message = call_args[0][0] assert sent_message assert isinstance(sent_message, RoborockMessage) @@ -388,7 +388,7 @@ async def test_v1_channel_networking_info_retrieved_during_connection( """Test that networking information is retrieved during local connection setup.""" # Setup: MQTT returns network info when requested mock_mqtt_channel.subscribe.return_value = Mock() - mock_mqtt_channel.send_command.return_value = RoborockMessage( + mock_mqtt_channel.send_message.return_value = RoborockMessage( protocol=RoborockMessageProtocol.RPC_RESPONSE, payload=json.dumps({"dps": {"102": json.dumps({"id": 12345, "result": mock_data.NETWORK_INFO})}}).encode(), ) @@ -402,7 +402,7 @@ async def test_v1_channel_networking_info_retrieved_during_connection( assert v1_channel.is_local_connected # Verify network info was requested via MQTT - mock_mqtt_channel.send_command.assert_called_once() + mock_mqtt_channel.send_message.assert_called_once() # Verify local session was created with the correct IP mock_local_session.assert_called_once_with(mock_data.NETWORK_INFO["ip"]) @@ -416,7 +416,7 @@ async def test_v1_channel_local_connect_network_info_failure( mock_mqtt_channel: Mock, ) -> None: """Test local connection when network info retrieval fails.""" - mock_mqtt_channel.send_command.side_effect = RoborockException("Network info failed") + mock_mqtt_channel.send_message.side_effect = RoborockException("Network info failed") with pytest.raises(RoborockException): await v1_channel._local_connect() @@ -429,7 +429,7 @@ async def test_v1_channel_local_connect_connection_failure( ) -> None: """Test local connection when connection itself fails.""" # Network info succeeds but connection fails - mock_mqtt_channel.send_command.return_value = RoborockMessage( + mock_mqtt_channel.send_message.return_value = RoborockMessage( protocol=RoborockMessageProtocol.RPC_RESPONSE, payload=json.dumps({"dps": {"102": json.dumps({"id": 12345, "result": mock_data.NETWORK_INFO})}}).encode(), ) @@ -439,13 +439,27 @@ async def test_v1_channel_local_connect_connection_failure( await v1_channel._local_connect() -async def test_v1_channel_command_encoding_validation(v1_channel: V1Channel) -> None: +async def test_v1_channel_command_encoding_validation( + v1_channel: V1Channel, + mqtt_responses: list[RoborockMessage], + mock_mqtt_channel: Mock, + mock_local_channel: Mock, +) -> None: """Test that command encoding works for different protocols.""" - # Test MQTT encoding - mqtt_message = v1_channel._mqtt_payload_encoder(RoborockCommand.CHANGE_SOUND_VOLUME, {"volume": 50}) + await v1_channel.subscribe(Mock()) + mock_mqtt_channel.send_message.reset_mock(return_value=False) + + # Send mqtt command and capture the request + mqtt_responses.append(TEST_RESPONSE) + await v1_channel.mqtt_rpc_channel.send_command(RoborockCommand.CHANGE_SOUND_VOLUME, params={"volume": 50}) + mock_mqtt_channel.send_message.assert_called_once() + mqtt_message = mock_mqtt_channel.send_message.call_args[0][0] - # Test local encoding - local_message = encode_local_payload(RoborockCommand.CHANGE_SOUND_VOLUME, {"volume": 50}) + # Send local command and capture the request + mock_local_channel.send_message.return_value = TEST_RESPONSE + await v1_channel.rpc_channel.send_command(RoborockCommand.CHANGE_SOUND_VOLUME, params={"volume": 50}) + mock_local_channel.send_message.assert_called_once() + local_message = mock_local_channel.send_message.call_args[0][0] # Verify both are RoborockMessage instances assert isinstance(mqtt_message, RoborockMessage) @@ -491,7 +505,7 @@ async def test_v1_channel_full_subscribe_and_command_flow( local_unsub = Mock() mock_mqtt_channel.subscribe.return_value = mqtt_unsub mock_local_channel.subscribe.return_value = local_unsub - mock_local_channel.send_command.return_value = TEST_RESPONSE + mock_local_channel.send_message.return_value = TEST_RESPONSE # Create V1Channel and subscribe v1_channel = V1Channel( @@ -504,21 +518,21 @@ async def test_v1_channel_full_subscribe_and_command_flow( # Mock network info for local connection callback = Mock() unsub = await v1_channel.subscribe(callback) - mock_mqtt_channel.send_command.reset_mock(return_value=False) + mock_mqtt_channel.send_message.reset_mock(return_value=False) # Verify both connections established assert v1_channel.is_mqtt_connected assert v1_channel.is_local_connected # Send a command (should use local) - result = await v1_channel.send_decoded_command( + result = await v1_channel.rpc_channel.send_command( RoborockCommand.GET_STATUS, response_type=S5MaxStatus, ) # Verify command was sent via local connection - mock_local_channel.send_command.assert_called_once() - mock_mqtt_channel.send_command.assert_not_called() + mock_local_channel.send_message.assert_called_once() + mock_mqtt_channel.send_message.assert_not_called() assert result.state == RoborockStateCode.cleaning # Test message callback