Skip to content

Commit ea68323

Browse files
authored
Merge branch 'main' into b01_improvements
2 parents 28d35b3 + d40cc78 commit ea68323

22 files changed

+702
-749
lines changed

poetry.lock

Lines changed: 152 additions & 131 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "python-roborock"
3-
version = "2.35.0"
3+
version = "2.37.0"
44
description = "A package to control Roborock vacuums."
55
authors = ["humbertogontijo <humbertogontijo@users.noreply.github.com>"]
66
license = "GPL-3.0-only"

roborock/containers.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,29 @@ class NetworkInfo(RoborockBase):
733733
rssi: int | None = None
734734

735735

736+
@dataclass
737+
class AppInitStatusLocalInfo(RoborockBase):
738+
location: str
739+
bom: str | None = None
740+
featureset: int | None = None
741+
language: str | None = None
742+
logserver: str | None = None
743+
wifiplan: str | None = None
744+
timezone: str | None = None
745+
name: str | None = None
746+
747+
748+
@dataclass
749+
class AppInitStatus(RoborockBase):
750+
local_info: AppInitStatusLocalInfo
751+
feature_info: list[int]
752+
new_feature_info: int
753+
new_feature_info_str: str
754+
new_feature_info_2: int | None = None
755+
carriage_type: int | None = None
756+
dsp_version: int | None = None
757+
758+
736759
@dataclass
737760
class DeviceData(RoborockBase):
738761
device: HomeDataDevice

roborock/devices/a01_channel.py

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,28 @@
11
"""Thin wrapper around the MQTT channel for Roborock A01 devices."""
22

3-
from __future__ import annotations
4-
3+
import asyncio
54
import logging
65
from typing import Any, overload
76

7+
from roborock.exceptions import RoborockException
88
from roborock.protocols.a01_protocol import (
99
decode_rpc_response,
1010
encode_mqtt_payload,
1111
)
12-
from roborock.roborock_message import RoborockDyadDataProtocol, RoborockZeoProtocol
12+
from roborock.roborock_message import (
13+
RoborockDyadDataProtocol,
14+
RoborockMessage,
15+
RoborockZeoProtocol,
16+
)
1317

1418
from .mqtt_channel import MqttChannel
1519

1620
_LOGGER = logging.getLogger(__name__)
21+
_TIMEOUT = 10.0
22+
23+
# Both RoborockDyadDataProtocol and RoborockZeoProtocol have the same
24+
# value for ID_QUERY
25+
_ID_QUERY = int(RoborockDyadDataProtocol.ID_QUERY)
1726

1827

1928
@overload
@@ -39,5 +48,46 @@ async def send_decoded_command(
3948
"""Send a command on the MQTT channel and get a decoded response."""
4049
_LOGGER.debug("Sending MQTT command: %s", params)
4150
roborock_message = encode_mqtt_payload(params)
42-
response = await mqtt_channel.send_message(roborock_message)
43-
return decode_rpc_response(response) # type: ignore[return-value]
51+
52+
# For commands that set values: send the command and do not
53+
# block waiting for a response. Queries are handled below.
54+
param_values = {int(k): v for k, v in params.items()}
55+
if not (query_values := param_values.get(_ID_QUERY)):
56+
await mqtt_channel.publish(roborock_message)
57+
return {}
58+
59+
# Merge any results together than contain the requested data. This
60+
# does not use a future since it needs to merge results across responses.
61+
# This could be simplified if we can assume there is a single response.
62+
finished = asyncio.Event()
63+
result: dict[int, Any] = {}
64+
65+
def find_response(response_message: RoborockMessage) -> None:
66+
"""Handle incoming messages and resolve the future."""
67+
try:
68+
decoded = decode_rpc_response(response_message)
69+
except RoborockException as ex:
70+
_LOGGER.info("Failed to decode a01 message: %s: %s", response_message, ex)
71+
return
72+
for key, value in decoded.items():
73+
if key in query_values:
74+
result[key] = value
75+
if len(result) != len(query_values):
76+
_LOGGER.debug("Incomplete query response: %s != %s", result, query_values)
77+
return
78+
_LOGGER.debug("Received query response: %s", result)
79+
if not finished.is_set():
80+
finished.set()
81+
82+
unsub = await mqtt_channel.subscribe(find_response)
83+
84+
try:
85+
await mqtt_channel.publish(roborock_message)
86+
try:
87+
await asyncio.wait_for(finished.wait(), timeout=_TIMEOUT)
88+
except TimeoutError as ex:
89+
raise RoborockException(f"Command timed out after {_TIMEOUT}s") from ex
90+
finally:
91+
unsub()
92+
93+
return result # type: ignore[return-value]

roborock/devices/b01_channel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,4 @@ async def send_decoded_command(
2424
"""Send a command on the MQTT channel and get a decoded response."""
2525
_LOGGER.debug("Sending MQTT command: %s", params)
2626
roborock_message = encode_mqtt_payload(dps, command, params)
27-
await mqtt_channel.send_message_no_wait(roborock_message)
27+
await mqtt_channel.publish(roborock_message)

roborock/devices/local_channel.py

Lines changed: 11 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import logging
55
from collections.abc import Callable
66
from dataclasses import dataclass
7-
from json import JSONDecodeError
87

98
from roborock.exceptions import RoborockConnectionException, RoborockException
109
from roborock.protocol import Decoder, Encoder, create_local_decoder, create_local_encoder
@@ -46,11 +45,8 @@ def __init__(self, host: str, local_key: str):
4645
self._subscribers: list[Callable[[RoborockMessage], None]] = []
4746
self._is_connected = False
4847

49-
# RPC support
50-
self._waiting_queue: dict[int, asyncio.Future[RoborockMessage]] = {}
5148
self._decoder: Decoder = create_local_decoder(local_key)
5249
self._encoder: Encoder = create_local_encoder(local_key)
53-
self._queue_lock = asyncio.Lock()
5450

5551
@property
5652
def is_connected(self) -> bool:
@@ -87,7 +83,6 @@ def _data_received(self, data: bytes) -> None:
8783
return
8884
for message in messages:
8985
_LOGGER.debug("Received message: %s", message)
90-
asyncio.create_task(self._resolve_future_with_lock(message))
9186
for callback in self._subscribers:
9287
try:
9388
callback(message)
@@ -109,48 +104,24 @@ def unsubscribe() -> None:
109104

110105
return unsubscribe
111106

112-
async def _resolve_future_with_lock(self, message: RoborockMessage) -> None:
113-
"""Resolve waiting future with proper locking."""
114-
if (request_id := message.get_request_id()) is None:
115-
_LOGGER.debug("Received message with no request_id")
116-
return
117-
async with self._queue_lock:
118-
if (future := self._waiting_queue.pop(request_id, None)) is not None:
119-
future.set_result(message)
120-
else:
121-
_LOGGER.debug("Received message with no waiting handler: request_id=%s", request_id)
122-
123-
async def send_message(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage:
124-
"""Send a command message and wait for the response message."""
107+
async def publish(self, message: RoborockMessage) -> None:
108+
"""Send a command message.
109+
110+
The caller is responsible for associating the message with its response.
111+
"""
125112
if not self._transport or not self._is_connected:
126113
raise RoborockConnectionException("Not connected to device")
127114

128-
try:
129-
if (request_id := message.get_request_id()) is None:
130-
raise RoborockException("Message must have a request_id for RPC calls")
131-
except (ValueError, JSONDecodeError) as err:
132-
_LOGGER.exception("Error getting request_id from message: %s", err)
133-
raise RoborockException(f"Invalid message format, Message must have a request_id: {err}") from err
134-
135-
future: asyncio.Future[RoborockMessage] = asyncio.Future()
136-
async with self._queue_lock:
137-
if request_id in self._waiting_queue:
138-
raise RoborockException(f"Request ID {request_id} already pending, cannot send command")
139-
self._waiting_queue[request_id] = future
140-
141115
try:
142116
encoded_msg = self._encoder(message)
117+
except Exception as err:
118+
_LOGGER.exception("Error encoding MQTT message: %s", err)
119+
raise RoborockException(f"Failed to encode MQTT message: {err}") from err
120+
try:
143121
self._transport.write(encoded_msg)
144-
return await asyncio.wait_for(future, timeout=timeout)
145-
except asyncio.TimeoutError as ex:
146-
async with self._queue_lock:
147-
self._waiting_queue.pop(request_id, None)
148-
raise RoborockException(f"Command timed out after {timeout}s") from ex
149-
except Exception:
122+
except Exception as err:
150123
logging.exception("Uncaught error sending command")
151-
async with self._queue_lock:
152-
self._waiting_queue.pop(request_id, None)
153-
raise
124+
raise RoborockException(f"Failed to send message: {message}") from err
154125

155126

156127
# This module provides a factory function to create LocalChannel instances.

roborock/devices/mqtt_channel.py

Lines changed: 19 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
"""Modules for communicating with specific Roborock devices over MQTT."""
22

3-
import asyncio
43
import logging
54
from collections.abc import Callable
6-
from json import JSONDecodeError
75

86
from roborock.containers import HomeDataDevice, RRiot, UserData
97
from roborock.exceptions import RoborockException
10-
from roborock.mqtt.session import MqttParams, MqttSession
8+
from roborock.mqtt.session import MqttParams, MqttSession, MqttSessionException
119
from roborock.protocol import create_mqtt_decoder, create_mqtt_encoder
1210
from roborock.roborock_message import RoborockMessage
1311

@@ -30,17 +28,16 @@ def __init__(self, mqtt_session: MqttSession, duid: str, local_key: str, rriot:
3028
self._rriot = rriot
3129
self._mqtt_params = mqtt_params
3230

33-
# RPC support
34-
self._waiting_queue: dict[int, asyncio.Future[RoborockMessage]] = {}
3531
self._decoder = create_mqtt_decoder(local_key)
3632
self._encoder = create_mqtt_encoder(local_key)
37-
self._queue_lock = asyncio.Lock()
38-
self._mqtt_unsub: Callable[[], None] | None = None
3933

4034
@property
4135
def is_connected(self) -> bool:
42-
"""Return true if the channel is connected."""
43-
return (self._mqtt_unsub is not None) and self._mqtt_session.connected
36+
"""Return true if the channel is connected.
37+
38+
This passes through the underlying MQTT session's connected state.
39+
"""
40+
return self._mqtt_session.connected
4441

4542
@property
4643
def _publish_topic(self) -> str:
@@ -57,9 +54,6 @@ async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callab
5754
5855
The callback will be called with the message payload when a message is received.
5956
60-
All messages received will be processed through the provided callback, even
61-
those sent in response to the `send_command` command.
62-
6357
Returns a callable that can be used to unsubscribe from the topic.
6458
"""
6559

@@ -69,75 +63,29 @@ def message_handler(payload: bytes) -> None:
6963
return
7064
for message in messages:
7165
_LOGGER.debug("Received message: %s", message)
72-
if message.version != b"B01":
73-
asyncio.create_task(self._resolve_future_with_lock(message))
7466
try:
7567
callback(message)
7668
except Exception as e:
7769
_LOGGER.exception("Uncaught error in message handler callback: %s", e)
7870

79-
self._mqtt_unsub = await self._mqtt_session.subscribe(self._subscribe_topic, message_handler)
80-
81-
def unsub_wrapper() -> None:
82-
if self._mqtt_unsub is not None:
83-
self._mqtt_unsub()
84-
self._mqtt_unsub = None
85-
86-
return unsub_wrapper
87-
88-
async def _resolve_future_with_lock(self, message: RoborockMessage) -> None:
89-
"""Resolve waiting future with proper locking."""
90-
if (request_id := message.get_request_id()) is None:
91-
_LOGGER.debug("Received message with no request_id")
92-
return
93-
async with self._queue_lock:
94-
if (future := self._waiting_queue.pop(request_id, None)) is not None:
95-
future.set_result(message)
96-
else:
97-
_LOGGER.debug("Received message with no waiting handler: request_id=%s", request_id)
71+
return await self._mqtt_session.subscribe(self._subscribe_topic, message_handler)
9872

99-
async def send_message_no_wait(self, message: RoborockMessage) -> None:
100-
"""Send a command message without waiting for a response."""
101-
try:
102-
encoded_msg = self._encoder(message)
103-
await self._mqtt_session.publish(self._publish_topic, encoded_msg)
104-
except Exception:
105-
logging.exception("Uncaught error sending command")
106-
raise
107-
108-
async def send_message(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage:
109-
"""Send a command message and wait for the response message.
73+
async def publish(self, message: RoborockMessage) -> None:
74+
"""Publish a command message.
11075
111-
Returns the raw response message - caller is responsible for parsing.
76+
The caller is responsible for handling any responses and associating them
77+
with the incoming request.
11278
"""
113-
try:
114-
if (request_id := message.get_request_id()) is None:
115-
raise RoborockException("Message must have a request_id for RPC calls")
116-
except (ValueError, JSONDecodeError) as err:
117-
_LOGGER.exception("Error getting request_id from message: %s", err)
118-
raise RoborockException(f"Invalid message format, Message must have a request_id: {err}") from err
119-
120-
future: asyncio.Future[RoborockMessage] = asyncio.Future()
121-
async with self._queue_lock:
122-
if request_id in self._waiting_queue:
123-
raise RoborockException(f"Request ID {request_id} already pending, cannot send command")
124-
self._waiting_queue[request_id] = future
125-
12679
try:
12780
encoded_msg = self._encoder(message)
128-
await self._mqtt_session.publish(self._publish_topic, encoded_msg)
129-
130-
return await asyncio.wait_for(future, timeout=timeout)
131-
132-
except asyncio.TimeoutError as ex:
133-
async with self._queue_lock:
134-
self._waiting_queue.pop(request_id, None)
135-
raise RoborockException(f"Command timed out after {timeout}s") from ex
136-
except Exception:
137-
logging.exception("Uncaught error sending command")
138-
async with self._queue_lock:
139-
self._waiting_queue.pop(request_id, None)
140-
raise
81+
except Exception as e:
82+
_LOGGER.exception("Error encoding MQTT message: %s", e)
83+
raise RoborockException(f"Failed to encode MQTT message: {e}") from e
84+
try:
85+
return await self._mqtt_session.publish(self._publish_topic, encoded_msg)
86+
except MqttSessionException as e:
87+
_LOGGER.exception("Error publishing MQTT message: %s", e)
88+
raise RoborockException(f"Failed to publish MQTT message: {e}") from e
14189

14290

14391
def create_mqtt_channel(

roborock/devices/traits/b01/props.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,6 @@ def __init__(self, channel: MqttChannel) -> None:
2727

2828
async def query_values(self, props: list[RoborockB01Props]) -> None:
2929
"""Query the device for the values of the given Dyad protocols."""
30-
return await send_decoded_command(
30+
await send_decoded_command(
3131
self._channel, dps=10000, command=RoborockB01Methods.GET_PROP, params={"property": props}
3232
)

roborock/devices/v1_channel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def is_local_connected(self) -> bool:
7979
@property
8080
def is_mqtt_connected(self) -> bool:
8181
"""Return whether MQTT connection is available."""
82-
return self._mqtt_unsub is not None
82+
return self._mqtt_unsub is not None and self._mqtt_channel.is_connected
8383

8484
@property
8585
def rpc_channel(self) -> V1RpcChannel:

0 commit comments

Comments
 (0)