1+ import asyncio
12import logging
3+ from asyncio import Lock , TimerHandle , Transport , get_running_loop
4+ from collections .abc import Callable
5+ from dataclasses import dataclass
26
3- from roborock . local_api import RoborockLocalClient
7+ import async_timeout
48
5- from .. import CommandVacuumError , DeviceData , RoborockCommand , RoborockException
6- from ..exceptions import VacuumError
9+ from .. import CommandVacuumError , DeviceData , RoborockCommand
10+ from ..api import RoborockClient
11+ from ..exceptions import RoborockConnectionException , RoborockException , VacuumError
12+ from ..protocol import Decoder , Encoder , create_local_decoder , create_local_encoder
713from ..protocols .v1_protocol import encode_local_payload
814from ..roborock_message import RoborockMessage , RoborockMessageProtocol
915from ..util import RoborockLoggerAdapter
1218_LOGGER = logging .getLogger (__name__ )
1319
1420
15- class RoborockLocalClientV1 (RoborockLocalClient , RoborockClientV1 ):
21+ @dataclass
22+ class _LocalProtocol (asyncio .Protocol ):
23+ """Callbacks for the Roborock local client transport."""
24+
25+ messages_cb : Callable [[bytes ], None ]
26+ connection_lost_cb : Callable [[Exception | None ], None ]
27+
28+ def data_received (self , bytes ) -> None :
29+ """Called when data is received from the transport."""
30+ self .messages_cb (bytes )
31+
32+ def connection_lost (self , exc : Exception | None ) -> None :
33+ """Called when the transport connection is lost."""
34+ self .connection_lost_cb (exc )
35+
36+
37+ class RoborockLocalClientV1 (RoborockClientV1 , RoborockClient ):
1638 """Roborock local client for v1 devices."""
1739
1840 def __init__ (self , device_data : DeviceData , queue_timeout : int = 4 ):
1941 """Initialize the Roborock local client."""
20- RoborockLocalClient .__init__ (self , device_data )
42+ if device_data .host is None :
43+ raise RoborockException ("Host is required" )
44+ self .host = device_data .host
45+ self ._batch_structs : list [RoborockMessage ] = []
46+ self ._executing = False
47+ self .transport : Transport | None = None
48+ self ._mutex = Lock ()
49+ self .keep_alive_task : TimerHandle | None = None
2150 RoborockClientV1 .__init__ (self , device_data , "abc" )
51+ RoborockClient .__init__ (self , device_data )
52+ self ._local_protocol = _LocalProtocol (self ._data_received , self ._connection_lost )
53+ self ._encoder : Encoder = create_local_encoder (device_data .device .local_key )
54+ self ._decoder : Decoder = create_local_decoder (device_data .device .local_key )
2255 self .queue_timeout = queue_timeout
2356 self ._logger = RoborockLoggerAdapter (device_data .device .name , _LOGGER )
2457
58+ def _data_received (self , message ):
59+ """Called when data is received from the transport."""
60+ parsed_msg = self ._decoder (message )
61+ self .on_message_received (parsed_msg )
62+
63+ def _connection_lost (self , exc : Exception | None ):
64+ """Called when the transport connection is lost."""
65+ self ._sync_disconnect ()
66+ self .on_connection_lost (exc )
67+
68+ def is_connected (self ):
69+ return self .transport and self .transport .is_reading ()
70+
71+ async def keep_alive_func (self , _ = None ):
72+ try :
73+ await self .ping ()
74+ except RoborockException :
75+ pass
76+ loop = asyncio .get_running_loop ()
77+ self .keep_alive_task = loop .call_later (10 , lambda : asyncio .create_task (self .keep_alive_func ()))
78+
79+ async def async_connect (self ) -> None :
80+ should_ping = False
81+ async with self ._mutex :
82+ try :
83+ if not self .is_connected ():
84+ self ._sync_disconnect ()
85+ async with async_timeout .timeout (self .queue_timeout ):
86+ self ._logger .debug (f"Connecting to { self .host } " )
87+ loop = get_running_loop ()
88+ self .transport , _ = await loop .create_connection ( # type: ignore
89+ lambda : self ._local_protocol , self .host , 58867
90+ )
91+ self ._logger .info (f"Connected to { self .host } " )
92+ should_ping = True
93+ except BaseException as e :
94+ raise RoborockConnectionException (f"Failed connecting to { self .host } " ) from e
95+ if should_ping :
96+ await self .hello ()
97+ await self .keep_alive_func ()
98+
99+ def _sync_disconnect (self ) -> None :
100+ loop = asyncio .get_running_loop ()
101+ if self .transport and loop .is_running ():
102+ self ._logger .debug (f"Disconnecting from { self .host } " )
103+ self .transport .close ()
104+ if self .keep_alive_task :
105+ self .keep_alive_task .cancel ()
106+
107+ async def async_disconnect (self ) -> None :
108+ async with self ._mutex :
109+ self ._sync_disconnect ()
110+
111+ async def hello (self ):
112+ request_id = 1
113+ protocol = RoborockMessageProtocol .HELLO_REQUEST
114+ try :
115+ return await self ._send_message (
116+ RoborockMessage (
117+ protocol = protocol ,
118+ seq = request_id ,
119+ random = 22 ,
120+ )
121+ )
122+ except Exception as e :
123+ self ._logger .error (e )
124+
125+ async def ping (self ) -> None :
126+ request_id = 2
127+ protocol = RoborockMessageProtocol .PING_REQUEST
128+ return await self ._send_message (
129+ RoborockMessage (
130+ protocol = protocol ,
131+ seq = request_id ,
132+ random = 23 ,
133+ )
134+ )
135+
136+ def _send_msg_raw (self , data : bytes ):
137+ try :
138+ if not self .transport :
139+ raise RoborockException ("Can not send message without connection" )
140+ self .transport .write (data )
141+ except Exception as e :
142+ raise RoborockException (e ) from e
143+
25144 async def _send_command (
26145 self ,
27146 method : RoborockCommand | str ,
@@ -32,9 +151,9 @@ async def _send_command(
32151
33152 roborock_message = encode_local_payload (method , params )
34153 self ._logger .debug ("Building message id %s for method %s" , roborock_message .get_request_id (), method )
35- return await self .send_message (roborock_message )
154+ return await self ._send_message (roborock_message )
36155
37- async def send_message (self , roborock_message : RoborockMessage ):
156+ async def _send_message (self , roborock_message : RoborockMessage ):
38157 await self .validate_connection ()
39158 method = roborock_message .get_method ()
40159 params = roborock_message .get_params ()
0 commit comments