Skip to content

Commit 0872691

Browse files
authored
chore: add end-to-end tests for the MQTT client (#278)
* test: add end-to-end tests for the MQTT client * test: extract connected client to a fixture style: fix formatting of tests refactor: extract variables for mock data used in mqtt tests style: fix lint errors in tests
1 parent b0611f0 commit 0872691

File tree

5 files changed

+381
-13
lines changed

5 files changed

+381
-13
lines changed

tests/conftest.py

Lines changed: 126 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1+
import io
2+
import logging
13
import re
4+
from collections.abc import Callable, Generator
5+
from queue import Queue
6+
from typing import Any
7+
from unittest.mock import Mock, patch
28

39
import pytest
410
from aioresponses import aioresponses
@@ -8,9 +14,128 @@
814
from roborock.version_1_apis.roborock_mqtt_client_v1 import RoborockMqttClientV1
915
from tests.mock_data import HOME_DATA_RAW, USER_DATA
1016

17+
_LOGGER = logging.getLogger(__name__)
18+
19+
20+
# Used by fixtures to handle incoming requests and prepare responses
21+
RequestHandler = Callable[[bytes], bytes | None]
22+
23+
24+
class FakeSocketHandler:
25+
"""Fake socket used by the test to simulate a connection to the broker.
26+
27+
The socket handler is used to intercept the socket send and recv calls and
28+
populate the response buffer with data to be sent back to the client. The
29+
handle request callback handles the incoming requests and prepares the responses.
30+
"""
31+
32+
def __init__(self, handle_request: RequestHandler) -> None:
33+
self.response_buf = io.BytesIO()
34+
self.handle_request = handle_request
35+
36+
def pending(self) -> int:
37+
"""Return the number of bytes in the response buffer."""
38+
return len(self.response_buf.getvalue())
39+
40+
def handle_socket_recv(self, read_size: int) -> bytes:
41+
"""Intercept a client recv() and populate the buffer."""
42+
if self.pending() == 0:
43+
raise BlockingIOError("No response queued")
44+
45+
self.response_buf.seek(0)
46+
data = self.response_buf.read(read_size)
47+
_LOGGER.debug("Response: 0x%s", data.hex())
48+
# Consume the rest of the data in the buffer
49+
remaining_data = self.response_buf.read()
50+
self.response_buf = io.BytesIO(remaining_data)
51+
return data
52+
53+
def handle_socket_send(self, client_request: bytes) -> int:
54+
"""Receive an incoming request from the client."""
55+
_LOGGER.debug("Request: 0x%s", client_request.hex())
56+
if (response := self.handle_request(client_request)) is not None:
57+
# Enqueue a response to be sent back to the client in the buffer.
58+
# The buffer will be emptied when the client calls recv() on the socket
59+
_LOGGER.debug("Queued: 0x%s", response.hex())
60+
self.response_buf.write(response)
61+
62+
return len(client_request)
63+
64+
65+
@pytest.fixture(name="received_requests")
66+
def received_requests_fixture() -> Queue[bytes]:
67+
"""Fixture that provides access to the received requests."""
68+
return Queue()
69+
70+
71+
@pytest.fixture(name="response_queue")
72+
def response_queue_fixture() -> Generator[Queue[bytes], None, None]:
73+
"""Fixture that provides access to the received requests."""
74+
response_queue: Queue[bytes] = Queue()
75+
yield response_queue
76+
assert response_queue.empty(), "Not all fake responses were consumed"
77+
78+
79+
@pytest.fixture(name="request_handler")
80+
def request_handler_fixture(received_requests: Queue[bytes], response_queue: Queue[bytes]) -> RequestHandler:
81+
"""Fixture records incoming requests and replies with responses from the queue."""
82+
83+
def handle_request(client_request: bytes) -> bytes | None:
84+
"""Handle an incoming request from the client."""
85+
received_requests.put(client_request)
86+
87+
# Insert a prepared response into the response buffer
88+
if not response_queue.empty():
89+
return response_queue.get()
90+
return None
91+
92+
return handle_request
93+
94+
95+
@pytest.fixture(name="fake_socket_handler")
96+
def fake_socket_handler_fixture(request_handler: RequestHandler) -> FakeSocketHandler:
97+
"""Fixture that creates a fake MQTT broker."""
98+
return FakeSocketHandler(request_handler)
99+
100+
101+
@pytest.fixture(name="mock_sock")
102+
def mock_sock_fixture(fake_socket_handler: FakeSocketHandler) -> Mock:
103+
"""Fixture that creates a mock socket connection and wires it to the handler."""
104+
mock_sock = Mock()
105+
mock_sock.recv = fake_socket_handler.handle_socket_recv
106+
mock_sock.send = fake_socket_handler.handle_socket_send
107+
mock_sock.pending = fake_socket_handler.pending
108+
return mock_sock
109+
110+
111+
@pytest.fixture(name="mock_create_connection")
112+
def create_connection_fixture(mock_sock: Mock) -> Generator[None, None, None]:
113+
"""Fixture that overrides the MQTT socket creation to wire it up to the mock socket."""
114+
with patch("paho.mqtt.client.socket.create_connection", return_value=mock_sock):
115+
yield
116+
117+
118+
@pytest.fixture(name="mock_select")
119+
def select_fixture(mock_sock: Mock, fake_socket_handler: FakeSocketHandler) -> Generator[None, None, None]:
120+
"""Fixture that overrides the MQTT client select calls to make select work on the mock socket.
121+
122+
This patch select to activate our mock socket when ready with data. Internal mqtt sockets are
123+
always ready since they are used internally to wake the select loop. Ours is ready if there
124+
is data in the buffer.
125+
"""
126+
127+
def is_ready(sock: Any) -> bool:
128+
return sock is not mock_sock or (fake_socket_handler.pending() > 0)
129+
130+
def handle_select(rlist: list, wlist: list, *args: Any) -> list:
131+
return [list(filter(is_ready, rlist)), list(filter(is_ready, wlist))]
132+
133+
with patch("paho.mqtt.client.select.select", side_effect=handle_select):
134+
yield
135+
11136

12137
@pytest.fixture(name="mqtt_client")
13-
def mqtt_client():
138+
def mqtt_client(mock_create_connection: None, mock_select: None) -> Generator[RoborockMqttClientV1, None, None]:
14139
user_data = UserData.from_dict(USER_DATA)
15140
home_data = HomeData.from_dict(HOME_DATA_RAW)
16141
device_info = DeviceData(

tests/mock_data.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
"""Mock data for Roborock tests."""
2+
3+
import hashlib
4+
25
# All data is based on a U.S. customer with a Roborock S7 MaxV Ultra
36
USER_EMAIL = "user@domain.com"
47

58
BASE_URL = "https://usiot.roborock.com"
69

10+
USER_ID = "user123"
11+
K_VALUE = "domain123"
712
USER_DATA = {
813
"uid": 123456,
914
"tokentype": "token_type",
@@ -14,21 +19,22 @@
1419
"country": "US",
1520
"nickname": "user_nickname",
1621
"rriot": {
17-
"u": "user123",
22+
"u": USER_ID,
1823
"s": "pass123",
1924
"h": "unknown123",
20-
"k": "domain123",
25+
"k": K_VALUE,
2126
"r": {
2227
"r": "US",
2328
"a": "https://api-us.roborock.com",
24-
"m": "ssl://mqtt-us.roborock.com:8883",
29+
"m": "tcp://mqtt-us.roborock.com:8883", # Skip SSL code in MQTT client library
2530
"l": "https://wood-us.roborock.com",
2631
},
2732
},
2833
"tuyaDeviceState": 2,
2934
"avatarurl": "https://files.roborock.com/iottest/default_avatar.png",
3035
}
31-
36+
LOCAL_KEY = "key123"
37+
PRODUCT_ID = "product-id-123"
3238
HOME_DATA_RAW = {
3339
"id": 123456,
3440
"name": "My Home",
@@ -37,7 +43,7 @@
3743
"geoName": None,
3844
"products": [
3945
{
40-
"id": "abc123",
46+
"id": PRODUCT_ID,
4147
"name": "Roborock S7 MaxV",
4248
"code": "a27",
4349
"model": "roborock.vacuum.a27",
@@ -199,7 +205,7 @@
199205
"name": "Roborock S7 MaxV",
200206
"attribute": None,
201207
"activeTime": 1672364449,
202-
"localKey": "key123",
208+
"localKey": LOCAL_KEY,
203209
"runtimeEnv": None,
204210
"timeZoneId": "America/Los_Angeles",
205211
"iconUrl": "no_url",
@@ -339,3 +345,5 @@
339345
}
340346

341347
GET_CODE_RESPONSE = {"code": 200, "msg": "success", "data": None}
348+
HASHED_USER = hashlib.md5((USER_ID + ":" + K_VALUE).encode()).hexdigest()[2:10]
349+
MQTT_PUBLISH_TOPIC = f"rr/m/o/{USER_ID}/{HASHED_USER}/{PRODUCT_ID}"

tests/mqtt_packet.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
"""Module for crafting MQTT packets.
2+
3+
This library is copied from the paho mqtt client library tests, with just the
4+
parts needed for some roborock messages. This message format in this file is
5+
not specific to roborock.
6+
"""
7+
8+
import struct
9+
10+
PROP_RECEIVE_MAXIMUM = 33
11+
PROP_TOPIC_ALIAS_MAXIMUM = 34
12+
13+
14+
def gen_uint16_prop(identifier: int, word: int) -> bytes:
15+
"""Generate a property with a uint16_t value."""
16+
prop = struct.pack("!BH", identifier, word)
17+
return prop
18+
19+
20+
def pack_varint(varint: int) -> bytes:
21+
"""Pack a variable integer."""
22+
s = b""
23+
while True:
24+
byte = varint % 128
25+
varint = varint // 128
26+
# If there are more digits to encode, set the top bit of this digit
27+
if varint > 0:
28+
byte = byte | 0x80
29+
30+
s = s + struct.pack("!B", byte)
31+
if varint == 0:
32+
return s
33+
34+
35+
def prop_finalise(props: bytes) -> bytes:
36+
"""Finalise the properties."""
37+
if props is None:
38+
return pack_varint(0)
39+
else:
40+
return pack_varint(len(props)) + props
41+
42+
43+
def gen_connack(flags=0, rc=0, properties=b"", property_helper=True):
44+
"""Generate a CONNACK packet."""
45+
if property_helper:
46+
if properties is not None:
47+
properties = (
48+
gen_uint16_prop(PROP_TOPIC_ALIAS_MAXIMUM, 10) + properties + gen_uint16_prop(PROP_RECEIVE_MAXIMUM, 20)
49+
)
50+
else:
51+
properties = b""
52+
properties = prop_finalise(properties)
53+
54+
packet = struct.pack("!BBBB", 32, 2 + len(properties), flags, rc) + properties
55+
56+
return packet
57+
58+
59+
def gen_suback(mid: int, qos: int) -> bytes:
60+
"""Generate a SUBACK packet."""
61+
return struct.pack("!BBHBB", 144, 2 + 1 + 1, mid, 0, qos)
62+
63+
64+
def _gen_short(cmd: int, reason_code: int) -> bytes:
65+
return struct.pack("!BBB", cmd, 1, reason_code)
66+
67+
68+
def gen_disconnect(reason_code: int = 0) -> bytes:
69+
"""Generate a DISCONNECT packet."""
70+
return _gen_short(0xE0, reason_code)
71+
72+
73+
def _gen_command_with_mid(cmd: int, mid: int, reason_code: int = 0) -> bytes:
74+
return struct.pack("!BBHB", cmd, 3, mid, reason_code)
75+
76+
77+
def gen_puback(mid: int, reason_code: int = -1) -> bytes:
78+
"""Generate a PUBACK packet."""
79+
return _gen_command_with_mid(64, mid, reason_code)
80+
81+
82+
def _pack_remaining_length(remaining_length: int) -> bytes:
83+
"""Pack a remaining length."""
84+
s = b""
85+
while True:
86+
byte = remaining_length % 128
87+
remaining_length = remaining_length // 128
88+
# If there are more digits to encode, set the top bit of this digit
89+
if remaining_length > 0:
90+
byte = byte | 0x80
91+
92+
s = s + struct.pack("!B", byte)
93+
if remaining_length == 0:
94+
return s
95+
96+
97+
def gen_publish(
98+
topic: str,
99+
payload: bytes | None = None,
100+
retain: bool = False,
101+
dup: bool = False,
102+
mid: int = 0,
103+
properties: bytes = b"",
104+
) -> bytes:
105+
"""Generate a PUBLISH packet."""
106+
if isinstance(topic, str):
107+
topic_b = topic.encode("utf-8")
108+
rl = 2 + len(topic_b)
109+
pack_format = "H" + str(len(topic_b)) + "s"
110+
111+
properties = prop_finalise(properties)
112+
rl += len(properties)
113+
# This will break if len(properties) > 127
114+
pack_format = pack_format + "%ds" % (len(properties))
115+
116+
if payload is not None:
117+
# payload = payload.encode("utf-8")
118+
rl = rl + len(payload)
119+
pack_format = pack_format + str(len(payload)) + "s"
120+
else:
121+
payload = b""
122+
pack_format = pack_format + "0s"
123+
124+
rlpacked = _pack_remaining_length(rl)
125+
cmd = 48
126+
if retain:
127+
cmd = cmd + 1
128+
if dup:
129+
cmd = cmd + 8
130+
131+
return struct.pack(
132+
"!B" + str(len(rlpacked)) + "s" + pack_format, cmd, rlpacked, len(topic_b), topic_b, properties, payload
133+
)

0 commit comments

Comments
 (0)