Skip to content

Commit 7613930

Browse files
authored
Merge branch 'main' into 325adependabot/pip/paho-mqtt-2.1.0
2 parents 2a84279 + 5add0da commit 7613930

File tree

5 files changed

+152
-34
lines changed

5 files changed

+152
-34
lines changed

roborock/cloud_api.py

Lines changed: 53 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,44 @@
2121
DISCONNECT_REQUEST_ID = 1
2222

2323

24-
class RoborockMqttClient(RoborockClient, mqtt.Client, ABC):
24+
class _Mqtt(mqtt.Client):
25+
"""Internal MQTT client.
26+
27+
This is a subclass of the Paho MQTT client that adds some additional functionality
28+
for error cases where things get stuck.
29+
"""
30+
2531
_thread: threading.Thread
2632
_client_id: str
2733

34+
def __init__(self) -> None:
35+
"""Initialize the MQTT client."""
36+
super().__init__(protocol=mqtt.MQTTv5)
37+
self.reset_client_id()
38+
39+
def reset_client_id(self):
40+
"""Generate a new client id to make a new session when reconnecting."""
41+
self._client_id = mqtt.base62(uuid.uuid4().int, padding=22)
42+
43+
def maybe_restart_loop(self) -> None:
44+
"""Ensure that the MQTT loop is running in case it previously exited."""
45+
if not self._thread or not self._thread.is_alive():
46+
if self._thread:
47+
_LOGGER.info("Stopping mqtt loop")
48+
super().loop_stop()
49+
_LOGGER.info("Starting mqtt loop")
50+
super().loop_start()
51+
52+
53+
class RoborockMqttClient(RoborockClient, ABC):
54+
"""Roborock MQTT client base class."""
55+
2856
def __init__(self, user_data: UserData, device_info: DeviceData, queue_timeout: int = 10) -> None:
57+
"""Initialize the Roborock MQTT client."""
2958
rriot = user_data.rriot
3059
if rriot is None:
3160
raise RoborockException("Got no rriot data from user_data")
3261
RoborockClient.__init__(self, device_info, queue_timeout)
33-
mqtt.Client.__init__(self, protocol=mqtt.MQTTv5)
3462
self._mqtt_user = rriot.u
3563
self._hashed_user = md5hex(self._mqtt_user + ":" + rriot.k)[2:10]
3664
url = urlparse(rriot.r.m)
@@ -39,16 +67,21 @@ def __init__(self, user_data: UserData, device_info: DeviceData, queue_timeout:
3967
self._mqtt_host = str(url.hostname)
4068
self._mqtt_port = url.port
4169
self._mqtt_ssl = url.scheme == "ssl"
70+
71+
self._mqtt_client = _Mqtt()
72+
self._mqtt_client.on_connect = self._mqtt_on_connect
73+
self._mqtt_client.on_message = self._mqtt_on_message
74+
self._mqtt_client.on_disconnect = self._mqtt_on_disconnect
4275
if self._mqtt_ssl:
43-
super().tls_set()
76+
self._mqtt_client.tls_set()
77+
4478
self._mqtt_password = rriot.s
4579
self._hashed_password = md5hex(self._mqtt_password + ":" + rriot.k)[16:]
46-
super().username_pw_set(self._hashed_user, self._hashed_password)
80+
self._mqtt_client.username_pw_set(self._hashed_user, self._hashed_password)
4781
self._waiting_queue: dict[int, RoborockFuture] = {}
4882
self._mutex = Lock()
49-
self.update_client_id()
5083

51-
def on_connect(self, *args, **kwargs):
84+
def _mqtt_on_connect(self, *args, **kwargs):
5285
_, __, ___, rc, ____ = args
5386
connection_queue = self._waiting_queue.get(CONNECT_REQUEST_ID)
5487
if rc != mqtt.MQTT_ERR_SUCCESS:
@@ -59,7 +92,7 @@ def on_connect(self, *args, **kwargs):
5992
return
6093
self._logger.info(f"Connected to mqtt {self._mqtt_host}:{self._mqtt_port}")
6194
topic = f"rr/m/o/{self._mqtt_user}/{self._hashed_user}/{self.device_info.device.duid}"
62-
(result, mid) = self.subscribe(topic)
95+
(result, mid) = self._mqtt_client.subscribe(topic)
6396
if result != 0:
6497
message = f"Failed to subscribe ({mqtt.error_string(rc)})"
6598
self._logger.error(message)
@@ -70,48 +103,38 @@ def on_connect(self, *args, **kwargs):
70103
if connection_queue:
71104
connection_queue.set_result(True)
72105

73-
def on_message(self, *args, **kwargs):
106+
def _mqtt_on_message(self, *args, **kwargs):
74107
client, __, msg = args
75108
try:
76109
messages, _ = MessageParser.parse(msg.payload, local_key=self.device_info.device.local_key)
77110
super().on_message_received(messages)
78111
except Exception as ex:
79112
self._logger.exception(ex)
80113

81-
def on_disconnect(self, *args, **kwargs):
114+
def _mqtt_on_disconnect(self, *args, **kwargs):
82115
_, __, rc, ___ = args
83116
try:
84117
exc = RoborockException(mqtt.error_string(rc)) if rc != mqtt.MQTT_ERR_SUCCESS else None
85118
super().on_connection_lost(exc)
86119
if rc == mqtt.MQTT_ERR_PROTOCOL:
87-
self.update_client_id()
120+
self._mqtt_client.reset_client_id()
88121
connection_queue = self._waiting_queue.get(DISCONNECT_REQUEST_ID)
89122
if connection_queue:
90123
connection_queue.set_result(True)
91124
except Exception as ex:
92125
self._logger.exception(ex)
93126

94-
def update_client_id(self):
95-
self._client_id = mqtt.base62(uuid.uuid4().int, padding=22)
96-
97-
def sync_stop_loop(self) -> None:
98-
if self._thread:
99-
self._logger.info("Stopping mqtt loop")
100-
super().loop_stop()
101-
102-
def sync_start_loop(self) -> None:
103-
if not self._thread or not self._thread.is_alive():
104-
self.sync_stop_loop()
105-
self._logger.info("Starting mqtt loop")
106-
super().loop_start()
127+
def is_connected(self) -> bool:
128+
"""Check if the mqtt client is connected."""
129+
return self._mqtt_client.is_connected()
107130

108131
def sync_disconnect(self) -> Any:
109132
if not self.is_connected():
110133
return None
111134

112135
self._logger.info("Disconnecting from mqtt")
113136
disconnected_future = self._async_response(DISCONNECT_REQUEST_ID)
114-
rc = super().disconnect()
137+
rc = self._mqtt_client.disconnect()
115138

116139
if rc == mqtt.MQTT_ERR_NO_CONN:
117140
disconnected_future.cancel()
@@ -125,17 +148,16 @@ def sync_disconnect(self) -> Any:
125148

126149
def sync_connect(self) -> Any:
127150
if self.is_connected():
128-
self.sync_start_loop()
151+
self._mqtt_client.maybe_restart_loop()
129152
return None
130153

131154
if self._mqtt_port is None or self._mqtt_host is None:
132155
raise RoborockException("Mqtt information was not entered. Cannot connect.")
133156

134157
self._logger.debug("Connecting to mqtt")
135158
connected_future = self._async_response(CONNECT_REQUEST_ID)
136-
super().connect(host=self._mqtt_host, port=self._mqtt_port, keepalive=KEEPALIVE)
137-
138-
self.sync_start_loop()
159+
self._mqtt_client.connect(host=self._mqtt_host, port=self._mqtt_port, keepalive=KEEPALIVE)
160+
self._mqtt_client.maybe_restart_loop()
139161
return connected_future
140162

141163
async def async_disconnect(self) -> None:
@@ -155,6 +177,8 @@ async def async_connect(self) -> None:
155177
raise RoborockException(err) from err
156178

157179
def _send_msg_raw(self, msg: bytes) -> None:
158-
info = self.publish(f"rr/m/i/{self._mqtt_user}/{self._hashed_user}/{self.device_info.device.duid}", msg)
180+
info = self._mqtt_client.publish(
181+
f"rr/m/i/{self._mqtt_user}/{self._hashed_user}/{self.device_info.device.duid}", msg
182+
)
159183
if info.rc != mqtt.MQTT_ERR_SUCCESS:
160184
raise RoborockException(f"Failed to publish ({mqtt.error_string(info.rc)})")

roborock/local_api.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import logging
55
from abc import ABC
66
from asyncio import Lock, TimerHandle, Transport
7+
from collections.abc import Callable
8+
from dataclasses import dataclass
79

810
import async_timeout
911

@@ -16,7 +18,15 @@
1618
_LOGGER = logging.getLogger(__name__)
1719

1820

19-
class RoborockLocalClient(RoborockClient, asyncio.Protocol, ABC):
21+
@dataclass
22+
class _LocalProtocol(asyncio.Protocol):
23+
"""Callbacks for the Roborock local client transport."""
24+
25+
messages_cb: Callable[[bytes], None]
26+
connection_lost_cb: Callable[[Exception | None], None]
27+
28+
29+
class RoborockLocalClient(RoborockClient, ABC):
2030
"""Roborock local client base class."""
2131

2232
def __init__(self, device_data: DeviceData, queue_timeout: int = 4):
@@ -31,15 +41,18 @@ def __init__(self, device_data: DeviceData, queue_timeout: int = 4):
3141
self._mutex = Lock()
3242
self.keep_alive_task: TimerHandle | None = None
3343
RoborockClient.__init__(self, device_data, queue_timeout)
44+
self._local_protocol = _LocalProtocol(self._data_received, self._connection_lost)
3445

35-
def data_received(self, message):
46+
def _data_received(self, message):
47+
"""Called when data is received from the transport."""
3648
if self.remaining:
3749
message = self.remaining + message
3850
self.remaining = b""
3951
parser_msg, self.remaining = MessageParser.parse(message, local_key=self.device_info.device.local_key)
4052
self.on_message_received(parser_msg)
4153

42-
def connection_lost(self, exc: Exception | None):
54+
def _connection_lost(self, exc: Exception | None):
55+
"""Called when the transport connection is lost."""
4356
self.sync_disconnect()
4457
self.on_connection_lost(exc)
4558

@@ -62,7 +75,7 @@ async def async_connect(self) -> None:
6275
async with async_timeout.timeout(self.queue_timeout):
6376
self._logger.debug(f"Connecting to {self.host}")
6477
self.transport, _ = await self.event_loop.create_connection( # type: ignore
65-
lambda: self, self.host, 58867
78+
lambda: self._local_protocol, self.host, 58867
6679
)
6780
self._logger.info(f"Connected to {self.host}")
6881
should_ping = True

tests/conftest.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import asyncio
12
import io
23
import logging
34
import re
5+
from asyncio import Protocol
46
from collections.abc import Callable, Generator
57
from queue import Queue
68
from typing import Any
@@ -11,8 +13,9 @@
1113

1214
from roborock import HomeData, UserData
1315
from roborock.containers import DeviceData
16+
from roborock.version_1_apis.roborock_local_client_v1 import RoborockLocalClientV1
1417
from roborock.version_1_apis.roborock_mqtt_client_v1 import RoborockMqttClientV1
15-
from tests.mock_data import HOME_DATA_RAW, USER_DATA
18+
from tests.mock_data import HOME_DATA_RAW, TEST_LOCAL_API_HOST, USER_DATA
1619

1720
_LOGGER = logging.getLogger(__name__)
1821

@@ -191,3 +194,43 @@ def mock_rest() -> aioresponses:
191194
payload={"api": None, "code": 200, "result": HOME_DATA_RAW, "status": "ok", "success": True},
192195
)
193196
yield mocked
197+
198+
199+
@pytest.fixture(name="mock_create_local_connection")
200+
def create_local_connection_fixture(request_handler: RequestHandler) -> Generator[None, None, None]:
201+
"""Fixture that overrides the transport creation to wire it up to the mock socket."""
202+
203+
async def create_connection(protocol_factory: Callable[[], Protocol], *args) -> tuple[Any, Any]:
204+
protocol = protocol_factory()
205+
206+
def handle_write(data: bytes) -> None:
207+
_LOGGER.debug("Received: %s", data)
208+
response = request_handler(data)
209+
if response is not None:
210+
_LOGGER.debug("Replying with %s", response)
211+
loop = asyncio.get_running_loop()
212+
loop.call_soon(protocol.data_received, response)
213+
214+
closed = asyncio.Event()
215+
216+
mock_transport = Mock()
217+
mock_transport.write = handle_write
218+
mock_transport.close = closed.set
219+
mock_transport.is_reading = lambda: not closed.is_set()
220+
221+
return (mock_transport, "proto")
222+
223+
with patch("roborock.api.get_running_loop_or_create_one") as mock_loop:
224+
mock_loop.return_value.create_connection.side_effect = create_connection
225+
yield
226+
227+
228+
@pytest.fixture(name="local_client")
229+
def local_client_fixture(mock_create_local_connection: None) -> Generator[RoborockLocalClientV1, None, None]:
230+
home_data = HomeData.from_dict(HOME_DATA_RAW)
231+
device_info = DeviceData(
232+
device=home_data.devices[0],
233+
model=home_data.products[0].model,
234+
host=TEST_LOCAL_API_HOST,
235+
)
236+
yield RoborockLocalClientV1(device_info)

tests/mock_data.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,3 +347,4 @@
347347
GET_CODE_RESPONSE = {"code": 200, "msg": "success", "data": None}
348348
HASHED_USER = hashlib.md5((USER_ID + ":" + K_VALUE).encode()).hexdigest()[2:10]
349349
MQTT_PUBLISH_TOPIC = f"rr/m/o/{USER_ID}/{HASHED_USER}/{PRODUCT_ID}"
350+
TEST_LOCAL_API_HOST = "1.1.1.1"

tests/test_local_api_v1.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
"""Tests for the Roborock Local Client V1."""
2+
3+
from queue import Queue
4+
5+
from roborock.protocol import MessageParser
6+
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
7+
from roborock.version_1_apis import RoborockLocalClientV1
8+
9+
from .mock_data import LOCAL_KEY
10+
11+
12+
def build_rpc_response(protocol: RoborockMessageProtocol, seq: int) -> bytes:
13+
"""Build an encoded RPC response message."""
14+
message = RoborockMessage(
15+
protocol=protocol,
16+
random=23,
17+
seq=seq,
18+
payload=b"ignored",
19+
)
20+
return MessageParser.build(message, local_key=LOCAL_KEY)
21+
22+
23+
async def test_async_connect(
24+
local_client: RoborockLocalClientV1,
25+
received_requests: Queue,
26+
response_queue: Queue,
27+
):
28+
"""Test that we can connect to the Roborock device."""
29+
response_queue.put(build_rpc_response(RoborockMessageProtocol.HELLO_RESPONSE, 1))
30+
response_queue.put(build_rpc_response(RoborockMessageProtocol.PING_RESPONSE, 2))
31+
32+
await local_client.async_connect()
33+
assert local_client.is_connected()
34+
assert received_requests.qsize() == 2
35+
36+
await local_client.async_disconnect()
37+
assert not local_client.is_connected()

0 commit comments

Comments
 (0)