Skip to content

Commit 564ffa7

Browse files
committed
test: add end-to-end tests for the MQTT client
1 parent 8f779c3 commit 564ffa7

File tree

5 files changed

+367
-8
lines changed

5 files changed

+367
-8
lines changed

tests/conftest.py

Lines changed: 126 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1+
import io
2+
import logging
3+
import queue
14
import re
5+
from collections.abc import Callable, Generator
6+
from typing import Any
7+
from unittest.mock import Mock, patch
28

39
import pytest
410
from aioresponses import aioresponses
@@ -8,9 +14,128 @@
814
from roborock.version_1_apis.roborock_mqtt_client_v1 import RoborockMqttClientV1
915
from tests.mock_data import HOME_DATA_RAW, USER_DATA
1016

17+
_LOGGER = logging.getLogger(__name__)
18+
19+
20+
# Used by fixtures to handle incoming requests and prepare responses
21+
RequestHandler = Callable[[bytes], bytes | None]
22+
23+
24+
class FakeSocketHandler:
25+
"""Fake socket used by the test to simulate a connection to the broker.
26+
27+
The socket handler is used to intercept the socket send and recv calls and
28+
populate the response buffer with data to be sent back to the client. The
29+
handle request callback handles the incoming requests and prepares the responses.
30+
"""
31+
32+
def __init__(self, handle_request: RequestHandler) -> None:
33+
self.response_buf = io.BytesIO()
34+
self.handle_request = handle_request
35+
36+
def pending(self) -> int:
37+
"""Return the number of bytes in the response buffer."""
38+
return len(self.response_buf.getvalue())
39+
40+
def handle_socket_recv(self, read_size: int) -> bytes:
41+
"""Intercept a client recv() and populate the buffer."""
42+
if self.pending() == 0:
43+
raise BlockingIOError("No response queued")
44+
45+
self.response_buf.seek(0)
46+
data = self.response_buf.read(read_size)
47+
_LOGGER.debug("Response: 0x%s", data.hex())
48+
# Consume the rest of the data in the buffer
49+
remaining_data = self.response_buf.read()
50+
self.response_buf = io.BytesIO(remaining_data)
51+
return data
52+
53+
def handle_socket_send(self, client_request: bytes) -> int:
54+
"""Receive an incoming request from the client."""
55+
_LOGGER.debug("Request: 0x%s", client_request.hex())
56+
if (response := self.handle_request(client_request)) is not None:
57+
# Enqueue a response to be sent back to the client in the buffer.
58+
# The buffer will be emptied when the client calls recv() on the socket
59+
_LOGGER.debug("Queued: 0x%s", response.hex())
60+
self.response_buf.write(response)
61+
62+
return len(client_request)
63+
64+
65+
@pytest.fixture(name="received_requests")
66+
def received_requests_fixture() -> queue.Queue[bytes]:
67+
"""Fixture that provides access to the received requests."""
68+
return queue.Queue()
69+
70+
71+
@pytest.fixture(name="response_queue")
72+
def response_queue_fixture() -> queue.Queue[bytes]:
73+
"""Fixture that provides access to the received requests."""
74+
return queue.Queue()
75+
76+
77+
@pytest.fixture(name="request_handler")
78+
def request_handler_fixture(
79+
received_requests: queue.Queue[bytes], response_queue: queue.Queue[bytes]
80+
) -> RequestHandler:
81+
"""Fixture records incoming requests and replies with responses from the queue."""
82+
83+
def handle_request(client_request: bytes) -> bytes | None:
84+
"""Handle an incoming request from the client."""
85+
received_requests.put(client_request)
86+
87+
# Insert a prepared response into the response buffer
88+
if response_queue.qsize() > 0:
89+
return response_queue.get()
90+
return None
91+
92+
return handle_request
93+
94+
95+
@pytest.fixture(name="fake_socket_handler")
96+
def fake_socket_handler_fixture(request_handler: RequestHandler) -> FakeSocketHandler:
97+
"""Fixture that creates a fake MQTT broker."""
98+
return FakeSocketHandler(request_handler)
99+
100+
101+
@pytest.fixture(name="mock_sock")
102+
def mock_sock_fixture(fake_socket_handler: FakeSocketHandler) -> Mock:
103+
"""Fixture that creates a mock socket connection and wires it to the handler."""
104+
mock_sock = Mock()
105+
mock_sock.recv = fake_socket_handler.handle_socket_recv
106+
mock_sock.send = fake_socket_handler.handle_socket_send
107+
mock_sock.pending = fake_socket_handler.pending
108+
return mock_sock
109+
110+
111+
@pytest.fixture(name="mock_create_connection")
112+
def create_connection_fixture(mock_sock: Mock) -> Generator[None, None, None]:
113+
"""Fixture that overrides the MQTT socket creation to wire it up to the mock socket."""
114+
with patch("paho.mqtt.client.socket.create_connection", return_value=mock_sock):
115+
yield
116+
117+
118+
@pytest.fixture(name="mock_select")
119+
def select_fixture(mock_sock: Mock, fake_socket_handler: FakeSocketHandler) -> Generator[None, None, None]:
120+
"""Fixture that overrides the MQTT client select calls to make select work on the mock socket.
121+
122+
This patch select to activate our mock socket when ready with data. Internal mqtt sockets are
123+
always ready since they are used internally to wake the select loop. Ours is ready if there
124+
is data in the buffer.
125+
"""
126+
127+
def is_ready(sock: Any) -> bool:
128+
return sock is not mock_sock or (fake_socket_handler.pending() > 0)
129+
130+
def handle_select(rlist: list, wlist: list, *args: Any) -> list:
131+
return [list(filter(is_ready, rlist)), list(filter(is_ready, wlist))]
132+
133+
with patch("paho.mqtt.client.select.select", side_effect=handle_select):
134+
yield
135+
11136

12137
@pytest.fixture(name="mqtt_client")
13-
def mqtt_client():
138+
def mqtt_client(mock_create_connection: None, mock_select: None) -> Generator[RoborockMqttClientV1, None, None]:
14139
user_data = UserData.from_dict(USER_DATA)
15140
home_data = HomeData.from_dict(HOME_DATA_RAW)
16141
device_info = DeviceData(

tests/mock_data.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@
2121
"r": {
2222
"r": "US",
2323
"a": "https://api-us.roborock.com",
24-
"m": "ssl://mqtt-us.roborock.com:8883",
24+
"m": "tcp://mqtt-us.roborock.com:8883", # Skip SSL code in MQTT client library
2525
"l": "https://wood-us.roborock.com",
2626
},
2727
},
2828
"tuyaDeviceState": 2,
2929
"avatarurl": "https://files.roborock.com/iottest/default_avatar.png",
3030
}
31-
31+
LOCAL_KEY = "key123"
3232
HOME_DATA_RAW = {
3333
"id": 123456,
3434
"name": "My Home",
@@ -199,7 +199,7 @@
199199
"name": "Roborock S7 MaxV",
200200
"attribute": None,
201201
"activeTime": 1672364449,
202-
"localKey": "key123",
202+
"localKey": LOCAL_KEY,
203203
"runtimeEnv": None,
204204
"timeZoneId": "America/Los_Angeles",
205205
"iconUrl": "no_url",
@@ -339,3 +339,5 @@
339339
}
340340

341341
GET_CODE_RESPONSE = {"code": 200, "msg": "success", "data": None}
342+
343+
MQTT_PUBLISH_TOPIC = "rr/m/o/user123/6ac2e6f8/abc123"

tests/mqtt_packet.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
"""Module for crafting MQTT packets.
2+
3+
This library is copied from the paho mqtt client library tests, with just the
4+
parts needed for some roborock messages. This message format in this file is
5+
not specific to roborock.
6+
"""
7+
8+
import struct
9+
10+
PROP_RECEIVE_MAXIMUM = 33
11+
PROP_TOPIC_ALIAS_MAXIMUM = 34
12+
13+
14+
def gen_uint16_prop(identifier: int, word: int) -> bytes:
15+
"""Generate a property with a uint16_t value."""
16+
prop = struct.pack("!BH", identifier, word)
17+
return prop
18+
19+
20+
def pack_varint(varint: int) -> bytes:
21+
"""Pack a variable integer."""
22+
s = b""
23+
while True:
24+
byte = varint % 128
25+
varint = varint // 128
26+
# If there are more digits to encode, set the top bit of this digit
27+
if varint > 0:
28+
byte = byte | 0x80
29+
30+
s = s + struct.pack("!B", byte)
31+
if varint == 0:
32+
return s
33+
34+
35+
def prop_finalise(props: bytes) -> bytes:
36+
"""Finalise the properties."""
37+
if props is None:
38+
return pack_varint(0)
39+
else:
40+
return pack_varint(len(props)) + props
41+
42+
43+
def gen_connack(flags=0, rc=0, properties=b"", property_helper=True):
44+
"""Generate a CONNACK packet."""
45+
if property_helper:
46+
if properties is not None:
47+
properties = (
48+
gen_uint16_prop(PROP_TOPIC_ALIAS_MAXIMUM, 10) + properties + gen_uint16_prop(PROP_RECEIVE_MAXIMUM, 20)
49+
)
50+
else:
51+
properties = b""
52+
properties = prop_finalise(properties)
53+
54+
packet = struct.pack("!BBBB", 32, 2 + len(properties), flags, rc) + properties
55+
56+
return packet
57+
58+
59+
def gen_suback(mid: int, qos: int) -> bytes:
60+
"""Generate a SUBACK packet."""
61+
return struct.pack("!BBHBB", 144, 2 + 1 + 1, mid, 0, qos)
62+
63+
64+
def _gen_short(cmd: int, reason_code: int) -> bytes:
65+
return struct.pack("!BBB", cmd, 1, reason_code)
66+
67+
68+
def gen_disconnect(reason_code: int = 0) -> bytes:
69+
"""Generate a DISCONNECT packet."""
70+
return _gen_short(0xE0, reason_code)
71+
72+
73+
def _gen_command_with_mid(cmd: int, mid: int, reason_code: int = 0) -> bytes:
74+
return struct.pack("!BBHB", cmd, 3, mid, reason_code)
75+
76+
77+
def gen_puback(mid: int, reason_code: int = -1) -> bytes:
78+
"""Generate a PUBACK packet."""
79+
return _gen_command_with_mid(64, mid, reason_code)
80+
81+
82+
def _pack_remaining_length(remaining_length: int) -> bytes:
83+
"""Pack a remaining length."""
84+
s = b""
85+
while True:
86+
byte = remaining_length % 128
87+
remaining_length = remaining_length // 128
88+
# If there are more digits to encode, set the top bit of this digit
89+
if remaining_length > 0:
90+
byte = byte | 0x80
91+
92+
s = s + struct.pack("!B", byte)
93+
if remaining_length == 0:
94+
return s
95+
96+
97+
def gen_publish(
98+
topic: str,
99+
payload: bytes | None = None,
100+
retain: bool = False,
101+
dup: bool = False,
102+
mid: int = 0,
103+
properties: bytes = b"",
104+
) -> bytes:
105+
"""Generate a PUBLISH packet."""
106+
if isinstance(topic, str):
107+
topic_b = topic.encode("utf-8")
108+
rl = 2 + len(topic_b)
109+
pack_format = "H" + str(len(topic_b)) + "s"
110+
111+
properties = prop_finalise(properties)
112+
rl += len(properties)
113+
# This will break if len(properties) > 127
114+
pack_format = pack_format + "%ds" % (len(properties))
115+
116+
if payload is not None:
117+
# payload = payload.encode("utf-8")
118+
rl = rl + len(payload)
119+
pack_format = pack_format + str(len(payload)) + "s"
120+
else:
121+
payload = b""
122+
pack_format = pack_format + "0s"
123+
124+
rlpacked = _pack_remaining_length(rl)
125+
cmd = 48
126+
if retain:
127+
cmd = cmd + 1
128+
if dup:
129+
cmd = cmd + 8
130+
131+
return struct.pack(
132+
"!B" + str(len(rlpacked)) + "s" + pack_format, cmd, rlpacked, len(topic_b), topic_b, properties, payload
133+
)

0 commit comments

Comments
 (0)