Skip to content

Commit 9fe6575

Browse files
committed
refactor: simplify future usage within the api clients
1 parent 9100bbf commit 9fe6575

File tree

10 files changed

+85
-70
lines changed

10 files changed

+85
-70
lines changed

roborock/api.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import logging
88
import secrets
99
import time
10-
from collections.abc import Coroutine
1110
from typing import Any
1211

1312
from .containers import (
@@ -16,7 +15,6 @@
1615
from .exceptions import (
1716
RoborockTimeout,
1817
UnknownMethodError,
19-
VacuumError,
2018
)
2119
from .roborock_future import RoborockFuture
2220
from .roborock_message import (
@@ -89,20 +87,18 @@ async def validate_connection(self) -> None:
8987
await self.async_disconnect()
9088
await self.async_connect()
9189

92-
async def _wait_response(self, request_id: int, queue: RoborockFuture) -> tuple[Any, VacuumError | None]:
90+
async def _wait_response(self, request_id: int, queue: RoborockFuture) -> Any:
9391
try:
94-
(response, err) = await queue.async_get(self.queue_timeout)
92+
response = await queue.async_get(self.queue_timeout)
9593
if response == "unknown_method":
9694
raise UnknownMethodError("Unknown method")
97-
return response, err
95+
return response
9896
except (asyncio.TimeoutError, asyncio.CancelledError):
9997
raise RoborockTimeout(f"id={request_id} Timeout after {self.queue_timeout} seconds") from None
10098
finally:
10199
self._waiting_queue.pop(request_id, None)
102100

103-
def _async_response(
104-
self, request_id: int, protocol_id: int = 0
105-
) -> Coroutine[Any, Any, tuple[Any, VacuumError | None]]:
101+
def _async_response(self, request_id: int, protocol_id: int = 0) -> Any:
106102
queue = RoborockFuture(protocol_id)
107103
if request_id in self._waiting_queue:
108104
new_id = get_next_int(10000, 32767)
@@ -115,7 +111,7 @@ def _async_response(
115111
)
116112
request_id = new_id
117113
self._waiting_queue[request_id] = queue
118-
return self._wait_response(request_id, queue)
114+
return asyncio.ensure_future(self._wait_response(request_id, queue))
119115

120116
async def send_message(self, roborock_message: RoborockMessage):
121117
raise NotImplementedError

roborock/cloud_api.py

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
from __future__ import annotations
22

3-
import asyncio
43
import base64
54
import logging
65
import threading
76
import uuid
8-
from asyncio import Lock, Task
7+
from asyncio import Lock
98
from typing import Any
109
from urllib.parse import urlparse
1110

@@ -62,7 +61,7 @@ def on_connect(self, *args, **kwargs):
6261
message = f"Failed to connect ({mqtt.error_string(rc)})"
6362
self._logger.error(message)
6463
if connection_queue:
65-
connection_queue.resolve((None, VacuumError(message)))
64+
connection_queue.set_exception(VacuumError(message))
6665
return
6766
self._logger.info(f"Connected to mqtt {self._mqtt_host}:{self._mqtt_port}")
6867
topic = f"rr/m/o/{self._mqtt_user}/{self._hashed_user}/{self.device_info.device.duid}"
@@ -71,11 +70,11 @@ def on_connect(self, *args, **kwargs):
7170
message = f"Failed to subscribe ({mqtt.error_string(rc)})"
7271
self._logger.error(message)
7372
if connection_queue:
74-
connection_queue.resolve((None, VacuumError(message)))
73+
connection_queue.set_exception(VacuumError(message))
7574
return
7675
self._logger.info(f"Subscribed to topic {topic}")
7776
if connection_queue:
78-
connection_queue.resolve((True, None))
77+
connection_queue.set_result(True)
7978

8079
def on_message(self, *args, **kwargs):
8180
client, __, msg = args
@@ -94,7 +93,7 @@ def on_disconnect(self, *args, **kwargs):
9493
self.update_client_id()
9594
connection_queue = self._waiting_queue.get(DISCONNECT_REQUEST_ID)
9695
if connection_queue:
97-
connection_queue.resolve((True, None))
96+
connection_queue.set_result(True)
9897
except Exception as ex:
9998
self._logger.exception(ex)
10099

@@ -112,53 +111,53 @@ def sync_start_loop(self) -> None:
112111
self._logger.info("Starting mqtt loop")
113112
super().loop_start()
114113

115-
def sync_disconnect(self) -> tuple[bool, Task[tuple[Any, VacuumError | None]] | None]:
114+
def sync_disconnect(self) -> Any:
116115
if not self.is_connected():
117-
return False, None
116+
return None
118117

119118
self._logger.info("Disconnecting from mqtt")
120-
disconnected_future = asyncio.ensure_future(self._async_response(DISCONNECT_REQUEST_ID))
119+
disconnected_future = self._async_response(DISCONNECT_REQUEST_ID)
121120
rc = super().disconnect()
122121

123122
if rc == mqtt.MQTT_ERR_NO_CONN:
124123
disconnected_future.cancel()
125-
return False, None
124+
return None
126125

127126
if rc != mqtt.MQTT_ERR_SUCCESS:
128127
disconnected_future.cancel()
129128
raise RoborockException(f"Failed to disconnect ({mqtt.error_string(rc)})")
130129

131-
return True, disconnected_future
130+
return disconnected_future
132131

133-
def sync_connect(self) -> tuple[bool, Task[tuple[Any, VacuumError | None]] | None]:
132+
def sync_connect(self) -> Any:
134133
if self.is_connected():
135134
self.sync_start_loop()
136-
return False, None
135+
return None
137136

138137
if self._mqtt_port is None or self._mqtt_host is None:
139138
raise RoborockException("Mqtt information was not entered. Cannot connect.")
140139

141140
self._logger.debug("Connecting to mqtt")
142-
connected_future = asyncio.ensure_future(self._async_response(CONNECT_REQUEST_ID))
141+
connected_future = self._async_response(CONNECT_REQUEST_ID)
143142
super().connect(host=self._mqtt_host, port=self._mqtt_port, keepalive=KEEPALIVE)
144143

145144
self.sync_start_loop()
146-
return True, connected_future
145+
return connected_future
147146

148147
async def async_disconnect(self) -> None:
149148
async with self._mutex:
150-
(disconnecting, disconnected_future) = self.sync_disconnect()
151-
if disconnecting and disconnected_future:
152-
(_, err) = await disconnected_future
153-
if err:
149+
if disconnected_future := self.sync_disconnect():
150+
try:
151+
await disconnected_future
152+
except VacuumError as err:
154153
raise RoborockException(err) from err
155154

156155
async def async_connect(self) -> None:
157156
async with self._mutex:
158-
(connecting, connected_future) = self.sync_connect()
159-
if connecting and connected_future:
160-
(_, err) = await connected_future
161-
if err:
157+
if connected_future := self.sync_connect():
158+
try:
159+
await connected_future
160+
except VacuumError as err:
162161
raise RoborockException(err) from err
163162

164163
def _send_msg_raw(self, msg: bytes) -> None:

roborock/roborock_future.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,19 @@ def __init__(self, protocol: int):
1414
self.fut: Future = Future()
1515
self.loop = self.fut.get_loop()
1616

17-
def _resolve(self, item: tuple[Any, VacuumError | None]) -> None:
17+
def _set_result(self, item: Any) -> None:
1818
if not self.fut.cancelled():
1919
self.fut.set_result(item)
2020

21-
def resolve(self, item: tuple[Any, VacuumError | None]) -> None:
22-
self.loop.call_soon_threadsafe(self._resolve, item)
21+
def set_result(self, item: Any) -> None:
22+
self.loop.call_soon_threadsafe(self._set_result, item)
23+
24+
def _set_exception(self, exc: VacuumError) -> None:
25+
if not self.fut.cancelled():
26+
self.fut.set_exception(exc)
27+
28+
def set_exception(self, exc: VacuumError) -> None:
29+
self.loop.call_soon_threadsafe(self._set_exception, exc)
2330

2431
async def async_get(self, timeout: float | int) -> tuple[Any, VacuumError | None]:
2532
try:

roborock/version_1_apis/roborock_client_v1.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -378,20 +378,17 @@ def on_message_received(self, messages: list[RoborockMessage]) -> None:
378378
if queue and queue.protocol == protocol:
379379
error = data_point_response.get("error")
380380
if error:
381-
queue.resolve(
382-
(
383-
None,
384-
VacuumError(
385-
error.get("code"),
386-
error.get("message"),
387-
),
388-
)
381+
queue.set_exception(
382+
VacuumError(
383+
error.get("code"),
384+
error.get("message"),
385+
),
389386
)
390387
else:
391388
result = data_point_response.get("result")
392389
if isinstance(result, list) and len(result) == 1:
393390
result = result[0]
394-
queue.resolve((result, None))
391+
queue.set_result(result)
395392
else:
396393
self._logger.debug("Received response for unknown request id %s", request_id)
397394
else:
@@ -451,13 +448,13 @@ def on_message_received(self, messages: list[RoborockMessage]) -> None:
451448
if queue:
452449
if isinstance(decompressed, list):
453450
decompressed = decompressed[0]
454-
queue.resolve((decompressed, None))
451+
queue.set_result(decompressed)
455452
else:
456453
self._logger.debug("Received response for unknown request id %s", request_id)
457454
else:
458455
queue = self._waiting_queue.get(data.seq)
459456
if queue:
460-
queue.resolve((data.payload, None))
457+
queue.set_result(data.payload)
461458
else:
462459
self._logger.debug("Received response for unknown request id %s", data.seq)
463460
except Exception as ex:

roborock/version_1_apis/roborock_local_client_v1.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
import asyncio
2-
31
from roborock.local_api import RoborockLocalClient
42

53
from .. import CommandVacuumError, DeviceData, RoborockCommand, RoborockException
4+
from ..exceptions import VacuumError
65
from ..protocol import MessageParser
76
from ..roborock_message import MessageRetry, RoborockMessage, RoborockMessageProtocol
87
from .roborock_client_v1 import COMMANDS_SECURED, RoborockClientV1
@@ -53,16 +52,21 @@ async def send_message(self, roborock_message: RoborockMessage):
5352
if method:
5453
self._logger.debug(f"id={request_id} Requesting method {method} with {params}")
5554
# Send the command to the Roborock device
56-
async_response = asyncio.ensure_future(self._async_response(request_id, response_protocol))
55+
async_response = self._async_response(request_id, response_protocol)
5756
self._send_msg_raw(msg)
58-
(response, err) = await async_response
59-
self._diagnostic_data[method if method is not None else "unknown"] = {
57+
diagnostic_key = method if method is not None else "unknown"
58+
try:
59+
response = await async_response
60+
except VacuumError as err:
61+
self._diagnostic_data[diagnostic_key] = {
62+
"params": roborock_message.get_params(),
63+
"error": err,
64+
}
65+
raise CommandVacuumError(method, err) from err
66+
self._diagnostic_data[diagnostic_key] = {
6067
"params": roborock_message.get_params(),
6168
"response": response,
62-
"error": err,
6369
}
64-
if err:
65-
raise CommandVacuumError(method, err) from err
6670
if roborock_message.protocol == RoborockMessageProtocol.GENERAL_REQUEST:
6771
self._logger.debug(f"id={request_id} Response from method {roborock_message.get_method()}: {response}")
6872
if response == "retry":

roborock/version_1_apis/roborock_mqtt_client_v1.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import asyncio
21
import base64
32

43
import paho.mqtt.client as mqtt
@@ -10,7 +9,7 @@
109
from roborock.cloud_api import RoborockMqttClient
1110

1211
from ..containers import DeviceData, UserData
13-
from ..exceptions import CommandVacuumError, RoborockException
12+
from ..exceptions import CommandVacuumError, RoborockException, VacuumError
1413
from ..protocol import MessageParser, Utils
1514
from ..roborock_message import (
1615
RoborockMessage,
@@ -49,16 +48,21 @@ async def send_message(self, roborock_message: RoborockMessage):
4948
local_key = self.device_info.device.local_key
5049
msg = MessageParser.build(roborock_message, local_key, False)
5150
self._logger.debug(f"id={request_id} Requesting method {method} with {params}")
52-
async_response = asyncio.ensure_future(self._async_response(request_id, response_protocol))
51+
async_response = self._async_response(request_id, response_protocol)
5352
self._send_msg_raw(msg)
54-
(response, err) = await async_response
55-
self._diagnostic_data[method if method is not None else "unknown"] = {
53+
diagnostic_key = method if method is not None else "unknown"
54+
try:
55+
response = await async_response
56+
except VacuumError as err:
57+
self._diagnostic_data[diagnostic_key] = {
58+
"params": roborock_message.get_params(),
59+
"error": err,
60+
}
61+
raise CommandVacuumError(method, err) from err
62+
self._diagnostic_data[diagnostic_key] = {
5663
"params": roborock_message.get_params(),
5764
"response": response,
58-
"error": err,
5965
}
60-
if err:
61-
raise CommandVacuumError(method, err) from err
6266
if response_protocol == RoborockMessageProtocol.MAP_RESPONSE:
6367
self._logger.debug(f"id={request_id} Response from {method}: {len(response)} bytes")
6468
else:

roborock/version_a01_apis/roborock_client_a01.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def on_message_received(self, messages: list[RoborockMessage]) -> None:
135135
converted_response = entries[data_point_protocol].post_process_fn(data_point)
136136
queue = self._waiting_queue.get(int(data_point_number))
137137
if queue and queue.protocol == protocol:
138-
queue.resolve((converted_response, None))
138+
queue.set_result(converted_response)
139139

140140
async def update_values(
141141
self, dyad_data_protocols: list[RoborockDyadDataProtocol | RoborockZeoProtocol]

roborock/version_a01_apis/roborock_mqtt_client_a01.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ async def send_message(self, roborock_message: RoborockMessage):
4343
futures = []
4444
if "10000" in payload["dps"]:
4545
for dps in json.loads(payload["dps"]["10000"]):
46-
futures.append(asyncio.ensure_future(self._async_response(dps, response_protocol)))
46+
futures.append(self._async_response(dps, response_protocol))
4747
self._send_msg_raw(m)
4848
responses = await asyncio.gather(*futures, return_exceptions=True)
4949
dps_responses: dict[int, typing.Any] = {}
@@ -54,7 +54,7 @@ async def send_message(self, roborock_message: RoborockMessage):
5454
self._logger.warning("Timed out get req for %s after %s s", dps, self.queue_timeout)
5555
dps_responses[dps] = None
5656
else:
57-
dps_responses[dps] = response[0]
57+
dps_responses[dps] = response
5858
return dps_responses
5959

6060
async def update_values(

tests/test_api.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,7 @@ async def test_can_create_mqtt_roborock():
5151
async def test_sync_connect(mqtt_client):
5252
with patch("paho.mqtt.client.Client.connect", return_value=mqtt.MQTT_ERR_SUCCESS):
5353
with patch("paho.mqtt.client.Client.loop_start", return_value=mqtt.MQTT_ERR_SUCCESS):
54-
connecting, connected_future = mqtt_client.sync_connect()
55-
assert connecting is True
54+
connected_future = mqtt_client.sync_connect()
5655
assert connected_future is not None
5756

5857
connected_future.cancel()

tests/test_queue.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import pytest
44

5+
from roborock.exceptions import VacuumError
56
from roborock.roborock_future import RoborockFuture
67

78

@@ -10,10 +11,18 @@ def test_can_create():
1011

1112

1213
@pytest.mark.asyncio
13-
async def test_put():
14+
async def test_set_result():
1415
rq = RoborockFuture(1)
15-
rq.resolve(("test", None))
16-
assert await rq.async_get(1) == ("test", None)
16+
rq.set_result("test")
17+
assert await rq.async_get(1) == "test"
18+
19+
20+
@pytest.mark.asyncio
21+
async def test_set_exception():
22+
rq = RoborockFuture(1)
23+
rq.set_exception(VacuumError("test"))
24+
with pytest.raises(VacuumError):
25+
assert await rq.async_get(1)
1726

1827

1928
@pytest.mark.asyncio

0 commit comments

Comments
 (0)