Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion roborock/devices/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,4 +117,4 @@ async def get_status(self) -> Status:
This is a placeholder command and will likely be changed/moved in the future.
"""
status_type: type[Status] = ModelStatus.get(self._product_info.model, S7MaxVStatus)
return await self._v1_channel.send_decoded_command(RoborockCommand.GET_STATUS, response_type=status_type)
return await self._v1_channel.rpc_channel.send_command(RoborockCommand.GET_STATUS, response_type=status_type)
7 changes: 6 additions & 1 deletion roborock/devices/local_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ def __init__(self, host: str, local_key: str):
self._encoder: Encoder = create_local_encoder(local_key)
self._queue_lock = asyncio.Lock()

@property
def is_connected(self) -> bool:
"""Check if the channel is currently connected."""
return self._is_connected

async def connect(self) -> None:
"""Connect to the device."""
if self._is_connected:
Expand Down Expand Up @@ -113,7 +118,7 @@ async def _resolve_future_with_lock(self, message: RoborockMessage) -> None:
else:
_LOGGER.debug("Received message with no waiting handler: request_id=%s", request_id)

async def send_command(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage:
async def send_message(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage:
"""Send a command message and wait for the response message."""
if not self._transport or not self._is_connected:
raise RoborockConnectionException("Not connected to device")
Expand Down
2 changes: 1 addition & 1 deletion roborock/devices/mqtt_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ async def _resolve_future_with_lock(self, message: RoborockMessage) -> None:
else:
_LOGGER.debug("Received message with no waiting handler: request_id=%s", request_id)

async def send_command(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage:
async def send_message(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage:
"""Send a command message and wait for the response message.

Returns the raw response message - caller is responsible for parsing.
Expand Down
77 changes: 18 additions & 59 deletions roborock/devices/v1_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,21 @@

import logging
from collections.abc import Callable
from typing import Any, TypeVar
from typing import TypeVar

from roborock.containers import HomeDataDevice, NetworkInfo, RoborockBase, UserData
from roborock.exceptions import RoborockException
from roborock.mqtt.session import MqttParams, MqttSession
from roborock.protocols.v1_protocol import (
CommandType,
ParamsType,
SecurityData,
create_mqtt_payload_encoder,
create_security_data,
decode_rpc_response,
encode_local_payload,
)
from roborock.roborock_message import RoborockMessage
from roborock.roborock_typing import RoborockCommand

from .local_channel import LocalChannel, LocalSession, create_local_session
from .mqtt_channel import MqttChannel
from .v1_rpc_channel import V1RpcChannel, create_combined_rpc_channel, create_mqtt_rpc_channel

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -58,9 +54,10 @@ def __init__(
"""
self._device_uid = device_uid
self._mqtt_channel = mqtt_channel
self._mqtt_payload_encoder = create_mqtt_payload_encoder(security_data)
self._mqtt_rpc_channel = create_mqtt_rpc_channel(mqtt_channel, security_data)
self._local_session = local_session
self._local_channel: LocalChannel | None = None
self._combined_rpc_channel: V1RpcChannel | None = None
self._mqtt_unsub: Callable[[], None] | None = None
self._local_unsub: Callable[[], None] | None = None
self._callback: Callable[[RoborockMessage], None] | None = None
Expand All @@ -76,6 +73,16 @@ def is_mqtt_connected(self) -> bool:
"""Return whether MQTT connection is available."""
return self._mqtt_unsub is not None

@property
def rpc_channel(self) -> V1RpcChannel:
"""Return the combined RPC channel prefers local with a fallback to MQTT."""
return self._combined_rpc_channel or self._mqtt_rpc_channel

@property
def mqtt_rpc_channel(self) -> V1RpcChannel:
"""Return the MQTT RPC channel."""
return self._mqtt_rpc_channel

async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]:
"""Subscribe to all messages from the device.

Expand Down Expand Up @@ -119,7 +126,9 @@ async def _get_networking_info(self) -> NetworkInfo:
This is a cloud only command used to get the local device's IP address.
"""
try:
return await self._send_mqtt_decoded_command(RoborockCommand.GET_NETWORK_INFO, response_type=NetworkInfo)
return await self._mqtt_rpc_channel.send_command(
RoborockCommand.GET_NETWORK_INFO, response_type=NetworkInfo
)
except RoborockException as e:
raise RoborockException(f"Network info failed for device {self._device_uid}") from e

Expand All @@ -136,59 +145,9 @@ async def _local_connect(self) -> Callable[[], None]:
except RoborockException as e:
self._local_channel = None
raise RoborockException(f"Error connecting to local device {self._device_uid}: {e}") from e

self._combined_rpc_channel = create_combined_rpc_channel(self._local_channel, self._mqtt_rpc_channel)
return await self._local_channel.subscribe(self._on_local_message)

async def send_decoded_command(
self,
method: CommandType,
*,
response_type: type[_T],
params: ParamsType = None,
) -> _T:
"""Send a command using the best available transport.

Will prefer local connection if available, falling back to MQTT.
"""
connection = "local" if self.is_local_connected else "mqtt"
_LOGGER.debug("Sending command (%s): %s, params=%s", connection, method, params)
if self._local_channel:
return await self._send_local_decoded_command(method, response_type=response_type, params=params)
return await self._send_mqtt_decoded_command(method, response_type=response_type, params=params)

async def _send_mqtt_raw_command(self, method: CommandType, params: ParamsType | None = None) -> dict[str, Any]:
"""Send a raw command and return a raw unparsed response."""
message = self._mqtt_payload_encoder(method, params)
_LOGGER.debug("Sending MQTT message for device %s: %s", self._device_uid, message)
response = await self._mqtt_channel.send_command(message)
return decode_rpc_response(response)

async def _send_mqtt_decoded_command(
self, method: CommandType, *, response_type: type[_T], params: ParamsType | None = None
) -> _T:
"""Send a command over MQTT and decode the response."""
decoded_response = await self._send_mqtt_raw_command(method, params)
return response_type.from_dict(decoded_response)

async def _send_local_raw_command(self, method: CommandType, params: ParamsType | None = None) -> dict[str, Any]:
"""Send a raw command over local connection."""
if not self._local_channel:
raise RoborockException("Local channel is not connected")

message = encode_local_payload(method, params)
_LOGGER.debug("Sending local message for device %s: %s", self._device_uid, message)
response = await self._local_channel.send_command(message)
return decode_rpc_response(response)

async def _send_local_decoded_command(
self, method: CommandType, *, response_type: type[_T], params: ParamsType | None = None
) -> _T:
"""Send a command over local connection and decode the response."""
if not self._local_channel:
raise RoborockException("Local channel is not connected")
decoded_response = await self._send_local_raw_command(method, params)
return response_type.from_dict(decoded_response)

def _on_mqtt_message(self, message: RoborockMessage) -> None:
"""Handle incoming MQTT messages."""
_LOGGER.debug("V1Channel received MQTT message from device %s: %s", self._device_uid, message)
Expand Down
148 changes: 148 additions & 0 deletions roborock/devices/v1_rpc_channel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
"""V1 Rpc Channel for Roborock devices.

This is a wrapper around the V1 channel that provides a higher level interface
for sending typed commands and receiving typed responses. This also provides
a simple interface for sending commands and receiving responses over both MQTT
and local connections, preferring local when available.
"""

import logging
from collections.abc import Callable
from typing import Any, Protocol, TypeVar, overload

from roborock.containers import RoborockBase
from roborock.protocols.v1_protocol import (
CommandType,
ParamsType,
SecurityData,
create_mqtt_payload_encoder,
decode_rpc_response,
encode_local_payload,
)
from roborock.roborock_message import RoborockMessage

from .local_channel import LocalChannel
from .mqtt_channel import MqttChannel

_LOGGER = logging.getLogger(__name__)


_T = TypeVar("_T", bound=RoborockBase)


class V1RpcChannel(Protocol):
"""Protocol for V1 RPC channels.

This is a wrapper around a raw channel that provides a high-level interface
for sending commands and receiving responses.
"""

@overload
async def send_command(
self,
method: CommandType,
*,
params: ParamsType = None,
) -> Any:
"""Send a command and return a decoded response."""
...

@overload
async def send_command(
self,
method: CommandType,
*,
response_type: type[_T],
params: ParamsType = None,
) -> _T:
"""Send a command and return a parsed response RoborockBase type."""
...


class BaseV1RpcChannel(V1RpcChannel):
"""Base implementation that provides the typed response logic."""

async def send_command(
self,
method: CommandType,
*,
response_type: type[_T] | None = None,
params: ParamsType = None,
) -> _T | Any:
"""Send a command and return either a decoded or parsed response."""
decoded_response = await self._send_raw_command(method, params=params)

if response_type is not None:
return response_type.from_dict(decoded_response)
return decoded_response

async def _send_raw_command(
self,
method: CommandType,
*,
params: ParamsType = None,
) -> Any:
"""Send a raw command and return the decoded response. Must be implemented by subclasses."""
raise NotImplementedError


class CombinedV1RpcChannel(BaseV1RpcChannel):
"""A V1 RPC channel that can use both local and MQTT channels, preferring local when available."""

def __init__(
self, local_channel: LocalChannel, local_rpc_channel: V1RpcChannel, mqtt_channel: V1RpcChannel
) -> None:
"""Initialize the combined channel with local and MQTT channels."""
self._local_channel = local_channel
self._local_rpc_channel = local_rpc_channel
self._mqtt_rpc_channel = mqtt_channel

async def _send_raw_command(
self,
method: CommandType,
*,
params: ParamsType = None,
) -> Any:
"""Send a command and return a parsed response RoborockBase type."""
if self._local_channel.is_connected:
return await self._local_rpc_channel.send_command(method, params=params)
return await self._mqtt_rpc_channel.send_command(method, params=params)


class PayloadEncodedV1RpcChannel(BaseV1RpcChannel):
"""Protocol for V1 channels that send encoded commands."""

def __init__(
self,
name: str,
channel: MqttChannel | LocalChannel,
payload_encoder: Callable[[CommandType, ParamsType], RoborockMessage],
) -> None:
"""Initialize the channel with a raw channel and an encoder function."""
self._name = name
self._channel = channel
self._payload_encoder = payload_encoder

async def _send_raw_command(
self,
method: CommandType,
*,
params: ParamsType = None,
) -> Any:
"""Send a command and return a parsed response RoborockBase type."""
_LOGGER.debug("Sending command (%s): %s, params=%s", self._name, method, params)
message = self._payload_encoder(method, params)
response = await self._channel.send_message(message)
return decode_rpc_response(response)


def create_mqtt_rpc_channel(mqtt_channel: MqttChannel, security_data: SecurityData) -> V1RpcChannel:
"""Create a V1 RPC channel using an MQTT channel."""
payload_encoder = create_mqtt_payload_encoder(security_data)
return PayloadEncodedV1RpcChannel("mqtt", mqtt_channel, payload_encoder)


def create_combined_rpc_channel(local_channel: LocalChannel, mqtt_rpc_channel: V1RpcChannel) -> V1RpcChannel:
"""Create a V1 RPC channel that combines local and MQTT channels."""
local_rpc_channel = PayloadEncodedV1RpcChannel("local", local_channel, encode_local_payload)
return CombinedV1RpcChannel(local_channel, local_rpc_channel, mqtt_rpc_channel)
4 changes: 2 additions & 2 deletions tests/devices/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ async def test_device_connection(device: RoborockDevice, channel: AsyncMock) ->
async def test_device_get_status_command(device: RoborockDevice, channel: AsyncMock) -> None:
"""Test the device get_status command."""
# Mock response for get_status command
channel.send_decoded_command.return_value = STATUS
channel.rpc_channel.send_command.return_value = STATUS

# Test get_status and verify the command was sent
status = await device.get_status()
assert channel.send_decoded_command.called
assert channel.rpc_channel.send_command.called

# Verify the result
assert status is not None
Expand Down
Loading
Loading