@@ -49,13 +49,14 @@ class RoborockMqttSession(MqttSession):
4949
5050 def __init__ (self , params : MqttParams ):
5151 self ._params = params
52- self ._background_task : asyncio .Task [None ] | None = None
52+ self ._reconnect_task : asyncio .Task [None ] | None = None
5353 self ._healthy = False
5454 self ._stop = False
5555 self ._backoff = MIN_BACKOFF_INTERVAL
5656 self ._client : aiomqtt .Client | None = None
5757 self ._client_lock = asyncio .Lock ()
5858 self ._listeners : CallbackMap [str , bytes ] = CallbackMap (_LOGGER )
59+ self ._connection_task : asyncio .Task [None ] | None = None
5960
6061 @property
6162 def connected (self ) -> bool :
@@ -72,7 +73,7 @@ async def start(self) -> None:
7273 """
7374 start_future : asyncio .Future [None ] = asyncio .Future ()
7475 loop = asyncio .get_event_loop ()
75- self ._background_task = loop .create_task (self ._run_task (start_future ))
76+ self ._reconnect_task = loop .create_task (self ._run_reconnect_loop (start_future ))
7677 try :
7778 await start_future
7879 except MqttError as err :
@@ -85,68 +86,93 @@ async def start(self) -> None:
8586 async def close (self ) -> None :
8687 """Cancels the MQTT loop and shutdown the client library."""
8788 self ._stop = True
88- if self ._background_task :
89- self ._background_task .cancel ()
90- try :
91- await self ._background_task
92- except asyncio .CancelledError :
93- pass
94- async with self ._client_lock :
95- if self ._client :
96- await self ._client .close ()
89+ tasks = [task for task in [self ._connection_task , self ._reconnect_task ] if task ]
90+ for task in tasks :
91+ task .cancel ()
92+ try :
93+ await asyncio .gather (* tasks )
94+ except asyncio .CancelledError :
95+ pass
9796
9897 self ._healthy = False
9998
100- async def _run_task (self , start_future : asyncio .Future [None ] | None ) -> None :
99+ async def restart (self ) -> None :
100+ """Force the session to disconnect and reconnect.
101+
102+ The active connection task will be cancelled and restarted in the background, retried by
103+ the reconnect loop. This is a no-op if there is no active connection.
104+ """
105+ _LOGGER .info ("Forcing MQTT session restart" )
106+ if self ._connection_task :
107+ self ._connection_task .cancel ()
108+ else :
109+ _LOGGER .debug ("No message loop task to cancel" )
110+
111+ async def _run_reconnect_loop (self , start_future : asyncio .Future [None ] | None ) -> None :
101112 """Run the MQTT loop."""
102113 _LOGGER .info ("Starting MQTT session" )
103114 while True :
104115 try :
105- async with self ._mqtt_client (self ._params ) as client :
106- # Reset backoff once we've successfully connected
107- self ._backoff = MIN_BACKOFF_INTERVAL
108- self ._healthy = True
109- _LOGGER .info ("MQTT Session connected." )
110- if start_future :
111- start_future .set_result (None )
112- start_future = None
113-
114- await self ._process_message_loop (client )
115-
116- except MqttError as err :
117- if start_future :
118- _LOGGER .info ("MQTT error starting session: %s" , err )
119- start_future .set_exception (err )
120- return
121- _LOGGER .info ("MQTT error: %s" , err )
122- except asyncio .CancelledError as err :
123- if start_future :
124- _LOGGER .debug ("MQTT loop was cancelled while starting" )
125- start_future .set_exception (err )
126- _LOGGER .debug ("MQTT loop was cancelled" )
127- return
128- # Catch exceptions to avoid crashing the loop
129- # and to allow the loop to retry.
130- except Exception as err :
131- # This error is thrown when the MQTT loop is cancelled
132- # and the generator is not stopped.
133- if "generator didn't stop" in str (err ) or "generator didn't yield" in str (err ):
134- _LOGGER .debug ("MQTT loop was cancelled" )
135- return
136- if start_future :
137- _LOGGER .error ("Uncaught error starting MQTT session: %s" , err )
138- start_future .set_exception (err )
116+ self ._connection_task = asyncio .create_task (self ._run_connection (start_future ))
117+ await self ._connection_task
118+ except asyncio .CancelledError :
119+ _LOGGER .debug ("MQTT connection task cancelled" )
120+ except Exception :
121+ # Exceptions are logged and handled in _run_connection.
122+ # There is a special case for exceptions on startup where we return
123+ # immediately. Otherwise, we let the reconnect loop retry with
124+ # backoff when the reconnect loop is active.
125+ if start_future and start_future .done () and start_future .exception ():
139126 return
140- _LOGGER .exception ("Uncaught error during MQTT session: %s" , err )
141127
142128 self ._healthy = False
129+ start_future = None
143130 if self ._stop :
144131 _LOGGER .debug ("MQTT session closed, stopping retry loop" )
145132 return
146133 _LOGGER .info ("MQTT session disconnected, retrying in %s seconds" , self ._backoff .total_seconds ())
147134 await asyncio .sleep (self ._backoff .total_seconds ())
148135 self ._backoff = min (self ._backoff * BACKOFF_MULTIPLIER , MAX_BACKOFF_INTERVAL )
149136
137+ async def _run_connection (self , start_future : asyncio .Future [None ] | None ) -> None :
138+ """Connect to the MQTT broker and listen for messages.
139+
140+ This is the primary connection loop for the MQTT session that is
141+ long running and processes incoming messages. If the connection
142+ is lost, this method will exit.
143+ """
144+ try :
145+ async with self ._mqtt_client (self ._params ) as client :
146+ self ._backoff = MIN_BACKOFF_INTERVAL
147+ self ._healthy = True
148+ _LOGGER .info ("MQTT Session connected." )
149+ if start_future and not start_future .done ():
150+ start_future .set_result (None )
151+
152+ _LOGGER .debug ("Processing MQTT messages" )
153+ async for message in client .messages :
154+ _LOGGER .debug ("Received message: %s" , message )
155+ self ._listeners (message .topic .value , message .payload )
156+ except MqttError as err :
157+ if start_future and not start_future .done ():
158+ _LOGGER .info ("MQTT error starting session: %s" , err )
159+ start_future .set_exception (err )
160+ else :
161+ _LOGGER .info ("MQTT error: %s" , err )
162+ raise
163+ except Exception as err :
164+ # This error is thrown when the MQTT loop is cancelled
165+ # and the generator is not stopped.
166+ if "generator didn't stop" in str (err ) or "generator didn't yield" in str (err ):
167+ _LOGGER .debug ("MQTT loop was cancelled" )
168+ return
169+ if start_future and not start_future .done ():
170+ _LOGGER .error ("Uncaught error starting MQTT session: %s" , err )
171+ start_future .set_exception (err )
172+ else :
173+ _LOGGER .exception ("Uncaught error during MQTT session: %s" , err )
174+ raise
175+
150176 @asynccontextmanager
151177 async def _mqtt_client (self , params : MqttParams ) -> aiomqtt .Client :
152178 """Connect to the MQTT broker and listen for messages."""
@@ -178,12 +204,6 @@ async def _mqtt_client(self, params: MqttParams) -> aiomqtt.Client:
178204 async with self ._client_lock :
179205 self ._client = None
180206
181- async def _process_message_loop (self , client : aiomqtt .Client ) -> None :
182- _LOGGER .debug ("Processing MQTT messages" )
183- async for message in client .messages :
184- _LOGGER .debug ("Received message: %s" , message )
185- self ._listeners (message .topic .value , message .payload )
186-
187207 async def subscribe (self , topic : str , callback : Callable [[bytes ], None ]) -> Callable [[], None ]:
188208 """Subscribe to messages on the specified topic and invoke the callback for new messages.
189209
@@ -271,6 +291,10 @@ async def close(self) -> None:
271291 """
272292 await self ._session .close ()
273293
294+ async def restart (self ) -> None :
295+ """Force the session to disconnect and reconnect."""
296+ await self ._session .restart ()
297+
274298
275299async def create_mqtt_session (params : MqttParams ) -> MqttSession :
276300 """Create an MQTT session.
0 commit comments