11"""End-to-end tests for LocalChannel using fake sockets."""
22
33import asyncio
4- from collections .abc import AsyncGenerator , Generator , Callable
5- from unittest .mock import patch , Mock
6- from typing import Any
4+ from collections .abc import AsyncGenerator
75
86import pytest
97import syrupy
1412from roborock .roborock_message import RoborockMessage , RoborockMessageProtocol
1513from tests .fixtures .logging import CapturedRequestLog
1614from tests .fixtures .mqtt import Subscriber
17- from tests .fixtures .local_async_fixtures import AsyncLocalRequestHandler
1815from tests .mock_data import LOCAL_KEY
1916
2017TEST_HOST = "192.168.1.100"
2118TEST_DEVICE_UID = "test_device_uid"
2219TEST_RANDOM = 23
2320
2421
25- @pytest .fixture (name = "mock_create_local_connection" )
26- def create_local_connection_fixture (
27- local_async_request_handler : AsyncLocalRequestHandler , log : CapturedRequestLog
28- ) -> Generator [None , None , None ]:
29- """Fixture that overrides the transport creation to wire it up to the mock socket."""
30-
31- async def create_connection (protocol_factory : Callable [[], asyncio .Protocol ], * args , ** kwargs ) -> tuple [Any , Any ]:
32- protocol = protocol_factory ()
33-
34- async def handle_write (data : bytes ) -> None :
35- log .add_log_entry ("[local >]" , data )
36- response = await local_async_request_handler (data )
37- if response is not None :
38- log .add_log_entry ("[local <]" , response )
39- # Call data_received directly to avoid loop scheduling issues in test
40- protocol .data_received (response )
41-
42- closed = asyncio .Event ()
43-
44- mock_transport = Mock ()
45- mock_transport .write = handle_write
46- mock_transport .close = closed .set
47- mock_transport .is_reading = lambda : not closed .is_set ()
48-
49- return (mock_transport , protocol )
50-
51- with patch ("roborock.devices.local_channel.asyncio.get_running_loop" ) as mock_loop :
52- mock_loop .return_value .create_connection .side_effect = create_connection
53- yield
54-
55-
5622@pytest .fixture (name = "local_channel" )
57- async def local_channel_fixture (mock_create_local_connection : None ) -> AsyncGenerator [LocalChannel , None ]:
23+ async def local_channel_fixture (mock_async_create_local_connection : None ) -> AsyncGenerator [LocalChannel , None ]:
5824 channel = LocalChannel (host = TEST_HOST , local_key = LOCAL_KEY , device_uid = TEST_DEVICE_UID )
5925 yield channel
6026 channel .close ()
@@ -88,9 +54,7 @@ async def test_connect(
8854) -> None :
8955 """Test connecting to the device."""
9056 # Queue HELLO response with payload to ensure it can be parsed
91- local_response_queue .put_nowait (
92- build_raw_response (RoborockMessageProtocol .HELLO_RESPONSE , 1 , payload = b"ok" )
93- )
57+ local_response_queue .put_nowait (build_raw_response (RoborockMessageProtocol .HELLO_RESPONSE , 1 , payload = b"ok" ))
9458
9559 await local_channel .connect ()
9660
@@ -120,9 +84,7 @@ async def test_send_command(
12084) -> None :
12185 """Test sending a command."""
12286 # Queue HELLO response
123- local_response_queue .put_nowait (
124- build_raw_response (RoborockMessageProtocol .HELLO_RESPONSE , 1 , payload = b"ok" )
125- )
87+ local_response_queue .put_nowait (build_raw_response (RoborockMessageProtocol .HELLO_RESPONSE , 1 , payload = b"ok" ))
12688
12789 await local_channel .connect ()
12890
@@ -138,7 +100,9 @@ async def test_send_command(
138100 payload = b'{"method":"get_status"}' ,
139101 )
140102 # Prepare a fake response to the command.
141- response_queue .put (build_raw_response (RoborockMessageProtocol .RPC_RESPONSE , cmd_seq , payload = b'{"status": "ok"}' ))
103+ local_response_queue .put_nowait (
104+ build_raw_response (RoborockMessageProtocol .RPC_RESPONSE , cmd_seq , payload = b'{"status": "ok"}' )
105+ )
142106
143107 subscriber = Subscriber ()
144108 unsub = await local_channel .subscribe (subscriber .append )
@@ -208,15 +172,15 @@ async def test_l01_session(
208172 assert local_channel .is_connected
209173
210174 # Verify 1.0 HELLO request
211- request_bytes = local_received_requests .get ()
175+ request_bytes = await local_received_requests .get ()
212176 # Protocol is at offset 19 (2 bytes)
213177 # Prefix(4) + Version(3) + Seq(4) + Random(4) + Timestamp(4) = 19
214178 assert len (request_bytes ) >= 21
215179 protocol_bytes = request_bytes [19 :21 ]
216180 assert int .from_bytes (protocol_bytes , "big" ) == RoborockMessageProtocol .HELLO_REQUEST
217181
218182 # Verify L01 HELLO request
219- request_bytes = local_received_requests .get ()
183+ request_bytes = await local_received_requests .get ()
220184 # Protocol is at offset 19 (2 bytes)
221185 # Prefix(4) + Version(3) + Seq(4) + Random(4) + Timestamp(4) = 19
222186 assert len (request_bytes ) >= 21
0 commit comments