diff --git a/roborock/devices/mqtt_channel.py b/roborock/devices/mqtt_channel.py index 8c840c57..dbcfb9dc 100644 --- a/roborock/devices/mqtt_channel.py +++ b/roborock/devices/mqtt_channel.py @@ -90,5 +90,5 @@ async def restart(self) -> None: def create_mqtt_channel( user_data: UserData, mqtt_params: MqttParams, mqtt_session: MqttSession, device: HomeDataDevice ) -> MqttChannel: - """Create a V1Channel for the given device.""" + """Create a MQTT channel for the given device.""" return MqttChannel(mqtt_session, device.duid, device.local_key, user_data.rriot, mqtt_params) diff --git a/roborock/devices/traits/v1/__init__.py b/roborock/devices/traits/v1/__init__.py index 6128b1e0..7020bd0f 100644 --- a/roborock/devices/traits/v1/__init__.py +++ b/roborock/devices/traits/v1/__init__.py @@ -38,8 +38,8 @@ from roborock.data.v1.v1_code_mappings import RoborockDockTypeCode from roborock.devices.cache import Cache from roborock.devices.traits import Trait -from roborock.devices.v1_rpc_channel import V1RpcChannel from roborock.map.map_parser import MapParserConfig +from roborock.protocols.v1_protocol import V1RpcChannel from roborock.web_api import UserWebApiClient from .child_lock import ChildLockTrait diff --git a/roborock/devices/traits/v1/common.py b/roborock/devices/traits/v1/common.py index 6decf0bc..63ae2e20 100644 --- a/roborock/devices/traits/v1/common.py +++ b/roborock/devices/traits/v1/common.py @@ -9,7 +9,7 @@ from typing import ClassVar, Self from roborock.data import RoborockBase -from roborock.devices.v1_rpc_channel import V1RpcChannel +from roborock.protocols.v1_protocol import V1RpcChannel from roborock.roborock_typing import RoborockCommand _LOGGER = logging.getLogger(__name__) diff --git a/roborock/devices/v1_channel.py b/roborock/devices/v1_channel.py index 8f1c148e..1f016aa9 100644 --- a/roborock/devices/v1_channel.py +++ b/roborock/devices/v1_channel.py @@ -8,37 +8,43 @@ import datetime import logging from collections.abc import Callable -from typing import TypeVar +from dataclasses import dataclass +from typing import Any, TypeVar from roborock.data import HomeDataDevice, NetworkInfo, RoborockBase, UserData from roborock.exceptions import RoborockException +from roborock.mqtt.health_manager import HealthManager from roborock.mqtt.session import MqttParams, MqttSession from roborock.protocols.v1_protocol import ( + CommandType, + MapResponse, + ParamsType, + RequestMessage, + ResponseData, + ResponseMessage, SecurityData, + V1RpcChannel, + create_map_response_decoder, create_security_data, + decode_rpc_response, ) -from roborock.roborock_message import RoborockMessage +from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol from roborock.roborock_typing import RoborockCommand from .cache import Cache from .channel import Channel from .local_channel import LocalChannel, LocalSession, create_local_session from .mqtt_channel import MqttChannel -from .v1_rpc_channel import ( - PickFirstAvailable, - V1RpcChannel, - create_local_rpc_channel, - create_map_rpc_channel, - create_mqtt_rpc_channel, -) _LOGGER = logging.getLogger(__name__) __all__ = [ - "V1Channel", + "create_v1_channel", ] _T = TypeVar("_T", bound=RoborockBase) +_TIMEOUT = 10.0 + # Exponential backoff parameters for reconnecting to local MIN_RECONNECT_INTERVAL = datetime.timedelta(minutes=1) @@ -50,6 +56,106 @@ LOCAL_CONNECTION_CHECK_INTERVAL = datetime.timedelta(seconds=15) +@dataclass(frozen=True) +class RpcStrategy: + """Strategy for encoding/sending/decoding RPC commands.""" + + name: str # For debug logging + channel: LocalChannel | MqttChannel + encoder: Callable[[RequestMessage], RoborockMessage] + decoder: Callable[[RoborockMessage], ResponseMessage | MapResponse | None] + health_manager: HealthManager | None = None + + +class RpcChannel(V1RpcChannel): + """Provides an RPC interface around a pub/sub transport channel.""" + + def __init__(self, rpc_strategies: list[RpcStrategy]) -> None: + """Initialize the RpcChannel with on ordered list of strategies.""" + self._rpc_strategies = rpc_strategies + + async def send_command( + self, + method: CommandType, + *, + response_type: type[_T] | None = None, + params: ParamsType = None, + ) -> _T | Any: + """Send a command and return either a decoded or parsed response.""" + request = RequestMessage(method, params=params) + + # Try each channel in order until one succeeds + last_exception = None + for strategy in self._rpc_strategies: + try: + decoded_response = await self._send_rpc(strategy, request) + except RoborockException as e: + _LOGGER.warning("Command %s failed on %s channel: %s", method, strategy.name, e) + last_exception = e + except Exception as e: + _LOGGER.exception("Unexpected error sending command %s on %s channel", method, strategy.name) + last_exception = RoborockException(f"Unexpected error: {e}") + else: + if response_type is not None: + if not isinstance(decoded_response, dict): + raise RoborockException( + f"Expected dict response to parse {response_type.__name__}, got {type(decoded_response)}" + ) + return response_type.from_dict(decoded_response) + return decoded_response + + raise last_exception or RoborockException("No available connection to send command") + + @staticmethod + async def _send_rpc(strategy: RpcStrategy, request: RequestMessage) -> ResponseData | bytes: + """Send a command and return a decoded response type. + + This provides an RPC interface over a given channel strategy. The device + channel only supports publish and subscribe, so this function handles + associating requests with their corresponding responses. + """ + future: asyncio.Future[ResponseData | bytes] = asyncio.Future() + _LOGGER.debug( + "Sending command (%s, request_id=%s): %s, params=%s", + strategy.name, + request.request_id, + request.method, + request.params, + ) + + message = strategy.encoder(request) + + def find_response(response_message: RoborockMessage) -> None: + try: + decoded = strategy.decoder(response_message) + except RoborockException as ex: + _LOGGER.debug("Exception while decoding message (%s): %s", response_message, ex) + return + if decoded is None: + return + _LOGGER.debug("Received response (%s, request_id=%s)", strategy.name, decoded.request_id) + if decoded.request_id == request.request_id: + if isinstance(decoded, ResponseMessage) and decoded.api_error: + future.set_exception(decoded.api_error) + else: + future.set_result(decoded.data) + + unsub = await strategy.channel.subscribe(find_response) + try: + await strategy.channel.publish(message) + result = await asyncio.wait_for(future, timeout=_TIMEOUT) + except TimeoutError as ex: + if strategy.health_manager: + await strategy.health_manager.on_timeout() + future.cancel() + raise RoborockException(f"Command timed out after {_TIMEOUT}s") from ex + finally: + unsub() + if strategy.health_manager: + await strategy.health_manager.on_success() + return result + + class V1Channel(Channel): """Unified V1 protocol channel with automatic MQTT/local connection handling. @@ -66,23 +172,13 @@ def __init__( local_session: LocalSession, cache: Cache, ) -> None: - """Initialize the V1Channel. - - Args: - mqtt_channel: MQTT channel for cloud communication - local_session: Factory that creates LocalChannels for a hostname. - """ + """Initialize the V1Channel.""" self._device_uid = device_uid + self._security_data = security_data self._mqtt_channel = mqtt_channel - self._mqtt_rpc_channel = create_mqtt_rpc_channel(mqtt_channel, security_data) + self._mqtt_health_manager = HealthManager(self._mqtt_channel.restart) self._local_session = local_session self._local_channel: LocalChannel | None = None - self._local_rpc_channel: V1RpcChannel | None = None - # Prefer local, fallback to MQTT - self._combined_rpc_channel = PickFirstAvailable( - [lambda: self._local_rpc_channel, lambda: self._mqtt_rpc_channel] - ) - self._map_rpc_channel = create_map_rpc_channel(mqtt_channel, security_data) self._mqtt_unsub: Callable[[], None] | None = None self._local_unsub: Callable[[], None] | None = None self._callback: Callable[[RoborockMessage], None] | None = None @@ -107,18 +203,60 @@ def is_mqtt_connected(self) -> bool: @property def rpc_channel(self) -> V1RpcChannel: - """Return the combined RPC channel prefers local with a fallback to MQTT.""" - return self._combined_rpc_channel + """Return the combined RPC channel that prefers local with a fallback to MQTT.""" + strategies = [] + if local_rpc_strategy := self._create_local_rpc_strategy(): + strategies.append(local_rpc_strategy) + strategies.append(self._create_mqtt_rpc_strategy()) + return RpcChannel(strategies) @property def mqtt_rpc_channel(self) -> V1RpcChannel: - """Return the MQTT RPC channel.""" - return self._mqtt_rpc_channel + """Return the MQTT-only RPC channel.""" + return RpcChannel([self._create_mqtt_rpc_strategy()]) @property def map_rpc_channel(self) -> V1RpcChannel: """Return the map RPC channel used for fetching map content.""" - return self._map_rpc_channel + decoder = create_map_response_decoder(security_data=self._security_data) + return RpcChannel([self._create_mqtt_rpc_strategy(decoder)]) + + def _create_local_rpc_strategy(self) -> RpcStrategy | None: + """Create the RPC strategy for local transport.""" + if self._local_channel is None or not self.is_local_connected: + return None + return RpcStrategy( + name="local", + channel=self._local_channel, + encoder=self._local_encoder, + decoder=decode_rpc_response, + ) + + def _local_encoder(self, x: RequestMessage) -> RoborockMessage: + """Encode a request message for local transport. + + This will read the current local channel's protocol version which + changes as the protocol version is discovered. + """ + if self._local_channel is None: + raise ValueError("Local channel unavailable for encoding") + return x.encode_message( + RoborockMessageProtocol.GENERAL_REQUEST, + version=self._local_channel.protocol_version, + ) + + def _create_mqtt_rpc_strategy(self, decoder: Callable[[RoborockMessage], Any] = decode_rpc_response) -> RpcStrategy: + """Create the RPC strategy for MQTT transport with optional custom decoder.""" + return RpcStrategy( + name="mqtt", + channel=self._mqtt_channel, + encoder=lambda x: x.encode_message( + RoborockMessageProtocol.RPC_REQUEST, + security_data=self._security_data, + ), + decoder=decoder, + health_manager=self._mqtt_health_manager, + ) async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]: """Subscribe to all messages from the device. @@ -185,7 +323,7 @@ async def _get_networking_info(self, *, prefer_cache: bool = True) -> NetworkInf _LOGGER.debug("Using cached network info for device %s", self._device_uid) return network_info try: - network_info = await self._mqtt_rpc_channel.send_command( + network_info = await self.mqtt_rpc_channel.send_command( RoborockCommand.GET_NETWORK_INFO, response_type=NetworkInfo ) except RoborockException as e: @@ -216,7 +354,6 @@ async def _local_connect(self, *, prefer_cache: bool = True) -> None: raise RoborockException(f"Error connecting to local device {self._device_uid}: {e}") from e # Wire up the new channel self._local_channel = local_channel - self._local_rpc_channel = create_local_rpc_channel(self._local_channel) self._local_unsub = await self._local_channel.subscribe(self._on_local_message) _LOGGER.info("Successfully connected to local device %s", self._device_uid) diff --git a/roborock/devices/v1_rpc_channel.py b/roborock/devices/v1_rpc_channel.py deleted file mode 100644 index 270ed05a..00000000 --- a/roborock/devices/v1_rpc_channel.py +++ /dev/null @@ -1,221 +0,0 @@ -"""V1 Rpc Channel for Roborock devices. - -This is a wrapper around the V1 channel that provides a higher level interface -for sending typed commands and receiving typed responses. This also provides -a simple interface for sending commands and receiving responses over both MQTT -and local connections, preferring local when available. -""" - -import asyncio -import logging -from collections.abc import Callable -from typing import Any, Protocol, TypeVar, overload - -from roborock.data import RoborockBase -from roborock.exceptions import RoborockException -from roborock.mqtt.health_manager import HealthManager -from roborock.protocols.v1_protocol import ( - CommandType, - MapResponse, - ParamsType, - RequestMessage, - ResponseData, - ResponseMessage, - SecurityData, - create_map_response_decoder, - decode_rpc_response, -) -from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol - -from .local_channel import LocalChannel -from .mqtt_channel import MqttChannel - -_LOGGER = logging.getLogger(__name__) -_TIMEOUT = 10.0 - - -_T = TypeVar("_T", bound=RoborockBase) -_V = TypeVar("_V") - - -class V1RpcChannel(Protocol): - """Protocol for V1 RPC channels. - - This is a wrapper around a raw channel that provides a high-level interface - for sending commands and receiving responses. - """ - - @overload - async def send_command( - self, - method: CommandType, - *, - params: ParamsType = None, - ) -> Any: - """Send a command and return a decoded response.""" - ... - - @overload - async def send_command( - self, - method: CommandType, - *, - response_type: type[_T], - params: ParamsType = None, - ) -> _T: - """Send a command and return a parsed response RoborockBase type.""" - ... - - -class BaseV1RpcChannel(V1RpcChannel): - """Base implementation that provides the typed response logic.""" - - async def send_command( - self, - method: CommandType, - *, - response_type: type[_T] | None = None, - params: ParamsType = None, - ) -> _T | Any: - """Send a command and return either a decoded or parsed response.""" - decoded_response = await self._send_raw_command(method, params=params) - - if response_type is not None: - return response_type.from_dict(decoded_response) - return decoded_response - - async def _send_raw_command( - self, - method: CommandType, - *, - params: ParamsType = None, - ) -> Any: - """Send a raw command and return the decoded response. Must be implemented by subclasses.""" - raise NotImplementedError - - -class PickFirstAvailable(BaseV1RpcChannel): - """A V1 RPC channel that tries multiple channels and picks the first that works.""" - - def __init__( - self, - channel_cbs: list[Callable[[], V1RpcChannel | None]], - ) -> None: - """Initialize the pick-first-available channel.""" - self._channel_cbs = channel_cbs - - async def _send_raw_command( - self, - method: CommandType, - *, - params: ParamsType = None, - ) -> Any: - """Send a command and return a parsed response RoborockBase type.""" - for channel_cb in self._channel_cbs: - if channel := channel_cb(): - return await channel.send_command(method, params=params) - raise RoborockException("No available connection to send command") - - -class PayloadEncodedV1RpcChannel(BaseV1RpcChannel): - """Protocol for V1 channels that send encoded commands.""" - - def __init__( - self, - name: str, - channel: MqttChannel | LocalChannel, - payload_encoder: Callable[[RequestMessage], RoborockMessage], - decoder: Callable[[RoborockMessage], ResponseMessage] | Callable[[RoborockMessage], MapResponse | None], - health_manager: HealthManager | None = None, - ) -> None: - """Initialize the channel with a raw channel and an encoder function.""" - self._name = name - self._channel = channel - self._payload_encoder = payload_encoder - self._decoder = decoder - self._health_manager = health_manager - - async def _send_raw_command( - self, - method: CommandType, - *, - params: ParamsType = None, - ) -> ResponseData | bytes: - """Send a command and return a parsed response RoborockBase type.""" - request_message = RequestMessage(method, params=params) - _LOGGER.debug( - "Sending command (%s, request_id=%s): %s, params=%s", self._name, request_message.request_id, method, params - ) - message = self._payload_encoder(request_message) - - future: asyncio.Future[ResponseData | bytes] = asyncio.Future() - - def find_response(response_message: RoborockMessage) -> None: - try: - decoded = self._decoder(response_message) - except RoborockException as ex: - _LOGGER.debug("Exception while decoding message (%s): %s", response_message, ex) - return - if decoded is None: - return - _LOGGER.debug("Received response (%s, request_id=%s)", self._name, decoded.request_id) - if decoded.request_id == request_message.request_id: - if isinstance(decoded, ResponseMessage) and decoded.api_error: - future.set_exception(decoded.api_error) - else: - future.set_result(decoded.data) - - unsub = await self._channel.subscribe(find_response) - try: - await self._channel.publish(message) - result = await asyncio.wait_for(future, timeout=_TIMEOUT) - except TimeoutError as ex: - if self._health_manager: - await self._health_manager.on_timeout() - future.cancel() - raise RoborockException(f"Command timed out after {_TIMEOUT}s") from ex - finally: - unsub() - - if self._health_manager: - await self._health_manager.on_success() - return result - - -def create_mqtt_rpc_channel(mqtt_channel: MqttChannel, security_data: SecurityData) -> V1RpcChannel: - """Create a V1 RPC channel using an MQTT channel.""" - return PayloadEncodedV1RpcChannel( - "mqtt", - mqtt_channel, - lambda x: x.encode_message(RoborockMessageProtocol.RPC_REQUEST, security_data=security_data), - decode_rpc_response, - health_manager=HealthManager(mqtt_channel.restart), - ) - - -def create_local_rpc_channel(local_channel: LocalChannel) -> V1RpcChannel: - """Create a V1 RPC channel using a local channel.""" - return PayloadEncodedV1RpcChannel( - "local", - local_channel, - lambda x: x.encode_message(RoborockMessageProtocol.GENERAL_REQUEST, version=local_channel.protocol_version), - decode_rpc_response, - ) - - -def create_map_rpc_channel( - mqtt_channel: MqttChannel, - security_data: SecurityData, -) -> V1RpcChannel: - """Create a V1 RPC channel that fetches map data. - - This will prefer local channels when available, falling back to MQTT - channels if not. If neither is available, an exception will be raised - when trying to send a command. - """ - return PayloadEncodedV1RpcChannel( - "map", - mqtt_channel, - lambda x: x.encode_message(RoborockMessageProtocol.RPC_REQUEST, security_data=security_data), - create_map_response_decoder(security_data=security_data), - ) diff --git a/roborock/protocols/v1_protocol.py b/roborock/protocols/v1_protocol.py index 213d8d31..34a596ae 100644 --- a/roborock/protocols/v1_protocol.py +++ b/roborock/protocols/v1_protocol.py @@ -12,9 +12,9 @@ from collections.abc import Callable from dataclasses import dataclass, field from enum import StrEnum -from typing import Any +from typing import Any, Protocol, TypeVar, overload -from roborock.data import RRiot +from roborock.data import RoborockBase, RRiot from roborock.exceptions import RoborockException, RoborockUnsupportedFeature from roborock.protocol import Utils from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol @@ -27,6 +27,7 @@ "SecurityData", "create_security_data", "decode_rpc_response", + "V1RpcChannel", ] CommandType = RoborockCommand | str @@ -208,3 +209,35 @@ def _decode_map_response(message: RoborockMessage) -> MapResponse | None: return MapResponse(request_id=request_id, data=decompressed) return _decode_map_response + + +_T = TypeVar("_T", bound=RoborockBase) + + +class V1RpcChannel(Protocol): + """Protocol for V1 RPC channels. + + This is a wrapper around a raw channel that provides a high-level interface + for sending commands and receiving responses. + """ + + @overload + async def send_command( + self, + method: CommandType, + *, + params: ParamsType = None, + ) -> Any: + """Send a command and return a decoded response.""" + ... + + @overload + async def send_command( + self, + method: CommandType, + *, + response_type: type[_T], + params: ParamsType = None, + ) -> _T: + """Send a command and return a parsed response RoborockBase type.""" + ... diff --git a/tests/devices/test_v1_channel.py b/tests/devices/test_v1_channel.py index b2c6acb9..8d2afcbc 100644 --- a/tests/devices/test_v1_channel.py +++ b/tests/devices/test_v1_channel.py @@ -108,7 +108,7 @@ def fake_next_int(*args) -> int: @pytest.fixture(name="mock_create_map_response_decoder") def setup_mock_map_decoder() -> Iterator[Mock]: """Mock the map response decoder to control its behavior in tests.""" - with patch("roborock.devices.v1_rpc_channel.create_map_response_decoder") as mock_create_decoder: + with patch("roborock.devices.v1_channel.create_map_response_decoder") as mock_create_decoder: yield mock_create_decoder @@ -277,7 +277,7 @@ async def test_v1_channel_send_command_local_preferred( # Send command mock_local_channel.response_queue.append(TEST_RESPONSE) result = await v1_channel.rpc_channel.send_command( - RoborockCommand.CHANGE_SOUND_VOLUME, + RoborockCommand.GET_STATUS, response_type=S5MaxStatus, ) @@ -290,7 +290,7 @@ async def test_v1_channel_send_command_local_fails( mock_mqtt_channel: Mock, mock_local_channel: Mock, ) -> None: - """Test case where sending with local connection fails.""" + """Test case where sending with local connection fails, falling back to MQTT.""" # Establish connections mock_mqtt_channel.response_queue.append(TEST_NETWORK_INFO_RESPONSE) @@ -300,12 +300,25 @@ async def test_v1_channel_send_command_local_fails( mock_local_channel.publish = Mock() mock_local_channel.publish.side_effect = RoborockException("Local failed") + # MQTT command succeeds + mock_mqtt_channel.response_queue.append(TEST_RESPONSE) + # Send command - with pytest.raises(RoborockException, match="Local failed"): - await v1_channel.rpc_channel.send_command( - RoborockCommand.CHANGE_SOUND_VOLUME, - response_type=S5MaxStatus, - ) + result = await v1_channel.rpc_channel.send_command( + RoborockCommand.GET_STATUS, + response_type=S5MaxStatus, + ) + + # Verify result + assert result.state == RoborockStateCode.cleaning + + # Verify local was attempted + mock_local_channel.publish.assert_called_once() + + # Verify MQTT was used + assert mock_mqtt_channel.published_messages + # The last message should be the command we sent + assert mock_mqtt_channel.published_messages[-1].protocol == RoborockMessageProtocol.RPC_REQUEST async def test_v1_channel_send_decoded_command_mqtt_only( @@ -323,7 +336,7 @@ async def test_v1_channel_send_decoded_command_mqtt_only( # Send command mock_mqtt_channel.response_queue.append(TEST_RESPONSE) result = await v1_channel.rpc_channel.send_command( - RoborockCommand.CHANGE_SOUND_VOLUME, + RoborockCommand.GET_STATUS, response_type=S5MaxStatus, ) @@ -503,9 +516,7 @@ async def test_v1_channel_command_encoding_validation( assert local_message.protocol == RoborockMessageProtocol.GENERAL_REQUEST -@patch("roborock.devices.v1_rpc_channel.create_map_response_decoder") async def test_v1_channel_send_map_command( - mock_create_decoder: Mock, v1_channel: V1Channel, mock_mqtt_channel: Mock, mock_create_map_response_decoder: Mock, diff --git a/tests/devices/test_v1_device.py b/tests/devices/test_v1_device.py index e33c93b2..16dcb8e1 100644 --- a/tests/devices/test_v1_device.py +++ b/tests/devices/test_v1_device.py @@ -13,7 +13,7 @@ from roborock.devices.device import RoborockDevice from roborock.devices.traits import v1 from roborock.devices.traits.v1.common import V1TraitMixin -from roborock.devices.v1_rpc_channel import decode_rpc_response +from roborock.protocols.v1_protocol import decode_rpc_response from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol from .. import mock_data