Skip to content

Commit fde1d0f

Browse files
committed
chore: unify callback handling recipes across mqtt and local channels
1 parent d40cc78 commit fde1d0f

File tree

7 files changed

+407
-50
lines changed

7 files changed

+407
-50
lines changed

roborock/callbacks.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
"""Module for managing callback utility functions."""
2+
3+
import logging
4+
from collections.abc import Callable
5+
from typing import Generic, TypeVar
6+
7+
_LOGGER = logging.getLogger(__name__)
8+
9+
K = TypeVar("K")
10+
V = TypeVar("V")
11+
12+
13+
def safe_callback(callback: Callable[[V], None], logger: logging.Logger | None = None) -> Callable[[V], None]:
14+
"""Wrap a callback to catch and log exceptions.
15+
16+
This is useful for ensuring that errors in callbacks do not propagate
17+
and cause unexpected behavior. Any failures during callback execution will be logged.
18+
"""
19+
20+
if logger is None:
21+
logger = _LOGGER
22+
23+
def wrapper(value: V) -> None:
24+
try:
25+
callback(value)
26+
except Exception as ex: # noqa: BLE001
27+
logger.error("Uncaught error in callback '%s': %s", callback.__name__, ex)
28+
29+
return wrapper
30+
31+
32+
class CallbackMap(Generic[K, V]):
33+
"""A mapping of callbacks for specific keys."""
34+
35+
def __init__(self, logger: logging.Logger | None = None) -> None:
36+
self._callbacks: dict[K, list[Callable[[V], None]]] = {}
37+
self._logger = logger or _LOGGER
38+
39+
def keys(self) -> list[K]:
40+
"""Get all keys in the callback map."""
41+
return list(self._callbacks.keys())
42+
43+
def add_callback(self, key: K, callback: Callable[[V], None]) -> Callable[[], None]:
44+
"""Add a callback for a specific key.
45+
46+
Any failures during callback execution will be logged.
47+
48+
Returns a callable that can be used to remove the callback.
49+
"""
50+
self._callbacks.setdefault(key, []).append(callback)
51+
52+
def remove_callback() -> None:
53+
"""Remove the callback for the specific key."""
54+
if cb_list := self._callbacks.get(key):
55+
cb_list.remove(callback)
56+
if not cb_list:
57+
del self._callbacks[key]
58+
59+
return remove_callback
60+
61+
def get_callbacks(self, key: K) -> list[Callable[[V], None]]:
62+
"""Get all callbacks for a specific key."""
63+
return self._callbacks.get(key, [])
64+
65+
def __call__(self, key: K, value: V) -> None:
66+
"""Invoke all callbacks for a specific key."""
67+
for callback in self.get_callbacks(key):
68+
safe_callback(callback, self._logger)(value)
69+
70+
71+
class CallbackList(Generic[V]):
72+
"""A list of callbacks for specific keys."""
73+
74+
def __init__(self, logger: logging.Logger | None = None) -> None:
75+
self._callbacks: list[Callable[[V], None]] = []
76+
self._logger = logger or _LOGGER
77+
78+
def add_callback(self, callback: Callable[[V], None]) -> Callable[[], None]:
79+
"""Add a callback to the list.
80+
81+
Any failures during callback execution will be logged.
82+
83+
Returns a callable that can be used to remove the callback.
84+
"""
85+
self._callbacks.append(callback)
86+
87+
return lambda: self._callbacks.remove(callback)
88+
89+
def __call__(self, value: V) -> None:
90+
"""Invoke all callbacks in the list."""
91+
for callback in self._callbacks:
92+
safe_callback(callback, self._logger)(value)
93+
94+
95+
def decoder_callback(
96+
decoder: Callable[[K], list[V]], callback: Callable[[V], None], logger: logging.Logger | None = None
97+
) -> Callable[[K], None]:
98+
"""Create a callback that decodes messages using a decoder and invokes a callback.
99+
100+
Any failures during decoding will be logged.
101+
"""
102+
if logger is None:
103+
logger = _LOGGER
104+
105+
safe_cb = safe_callback(callback, logger)
106+
107+
def wrapper(data: K) -> None:
108+
if not (messages := decoder(data)):
109+
logger.warning("Failed to decode message: %s", data)
110+
return
111+
for message in messages:
112+
_LOGGER.debug("Decoded message: %s", message)
113+
safe_cb(message)
114+
115+
return wrapper
116+
117+
118+
119+
def dipspatch_callback(
120+
callback: Callable[[V], None], logger: logging.Logger | None = None
121+
) -> Callable[[list[V]], None]:
122+
"""Create a callback that decodes messages using a decoder and invokes a callback.
123+
124+
Any failures during decoding will be logged.
125+
"""
126+
if logger is None:
127+
logger = _LOGGER
128+
129+
safe_cb = safe_callback(callback, logger)
130+
131+
def wrapper(data: K) -> None:
132+
if not (messages := decoder(data)):
133+
logger.warning("Failed to decode message: %s", data)
134+
return
135+
for message in messages:
136+
_LOGGER.debug("Decoded message: %s", message)
137+
safe_cb(message)
138+
139+
return wrapper

roborock/devices/local_channel.py

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from collections.abc import Callable
66
from dataclasses import dataclass
77

8+
from roborock.callbacks import CallbackList, decoder_callback
89
from roborock.exceptions import RoborockConnectionException, RoborockException
910
from roborock.protocol import Decoder, Encoder, create_local_decoder, create_local_encoder
1011
from roborock.roborock_message import RoborockMessage
@@ -42,11 +43,13 @@ def __init__(self, host: str, local_key: str):
4243
self._host = host
4344
self._transport: asyncio.Transport | None = None
4445
self._protocol: _LocalProtocol | None = None
45-
self._subscribers: list[Callable[[RoborockMessage], None]] = []
46+
self._subscribers: CallbackList[RoborockMessage] = CallbackList(_LOGGER)
4647
self._is_connected = False
4748

4849
self._decoder: Decoder = create_local_decoder(local_key)
4950
self._encoder: Encoder = create_local_encoder(local_key)
51+
# Callback to decode messages and dispatch to subscribers
52+
self._data_received: Callable[[bytes], None] = decoder_callback(self._decoder, self._subscribers, _LOGGER)
5053

5154
@property
5255
def is_connected(self) -> bool:
@@ -76,19 +79,6 @@ def close(self) -> None:
7679
self._transport = None
7780
self._is_connected = False
7881

79-
def _data_received(self, data: bytes) -> None:
80-
"""Handle incoming data from the transport."""
81-
if not (messages := self._decoder(data)):
82-
_LOGGER.warning("Failed to decode local message: %s", data)
83-
return
84-
for message in messages:
85-
_LOGGER.debug("Received message: %s", message)
86-
for callback in self._subscribers:
87-
try:
88-
callback(message)
89-
except Exception as e:
90-
_LOGGER.exception("Uncaught error in message handler callback: %s", e)
91-
9282
def _connection_lost(self, exc: Exception | None) -> None:
9383
"""Handle connection loss."""
9484
_LOGGER.warning("Connection lost to %s", self._host, exc_info=exc)
@@ -97,12 +87,7 @@ def _connection_lost(self, exc: Exception | None) -> None:
9787

9888
async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]:
9989
"""Subscribe to all messages from the device."""
100-
self._subscribers.append(callback)
101-
102-
def unsubscribe() -> None:
103-
self._subscribers.remove(callback)
104-
105-
return unsubscribe
90+
return self._subscribers.add_callback(callback)
10691

10792
async def publish(self, message: RoborockMessage) -> None:
10893
"""Send a command message.

roborock/devices/mqtt_channel.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import logging
44
from collections.abc import Callable
55

6+
from roborock.callbacks import decoder_callback
67
from roborock.containers import HomeDataDevice, RRiot, UserData
78
from roborock.exceptions import RoborockException
89
from roborock.mqtt.session import MqttParams, MqttSession, MqttSessionException
@@ -56,19 +57,8 @@ async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callab
5657
5758
Returns a callable that can be used to unsubscribe from the topic.
5859
"""
59-
60-
def message_handler(payload: bytes) -> None:
61-
if not (messages := self._decoder(payload)):
62-
_LOGGER.warning("Failed to decode MQTT message: %s", payload)
63-
return
64-
for message in messages:
65-
_LOGGER.debug("Received message: %s", message)
66-
try:
67-
callback(message)
68-
except Exception as e:
69-
_LOGGER.exception("Uncaught error in message handler callback: %s", e)
70-
71-
return await self._mqtt_session.subscribe(self._subscribe_topic, message_handler)
60+
dispatch = decoder_callback(self._decoder, callback, _LOGGER)
61+
return await self._mqtt_session.subscribe(self._subscribe_topic, dispatch)
7262

7363
async def publish(self, message: RoborockMessage) -> None:
7464
"""Publish a command message.

roborock/mqtt/roborock_session.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import aiomqtt
1818
from aiomqtt import MqttError, TLSParameters
1919

20+
from roborock.callbacks import CallbackMap
21+
2022
from .session import MqttParams, MqttSession, MqttSessionException
2123

2224
_LOGGER = logging.getLogger(__name__)
@@ -53,7 +55,7 @@ def __init__(self, params: MqttParams):
5355
self._backoff = MIN_BACKOFF_INTERVAL
5456
self._client: aiomqtt.Client | None = None
5557
self._client_lock = asyncio.Lock()
56-
self._listeners: dict[str, list[Callable[[bytes], None]]] = {}
58+
self._listeners: CallbackMap[str, bytes] = CallbackMap(_LOGGER)
5759

5860
@property
5961
def connected(self) -> bool:
@@ -164,7 +166,7 @@ async def _mqtt_client(self, params: MqttParams) -> aiomqtt.Client:
164166
# Re-establish any existing subscriptions
165167
async with self._client_lock:
166168
self._client = client
167-
for topic in self._listeners:
169+
for topic in self._listeners.keys():
168170
_LOGGER.debug("Re-establishing subscription to topic %s", topic)
169171
# TODO: If this fails it will break the whole connection. Make
170172
# this retry again in the background with backoff.
@@ -179,13 +181,7 @@ async def _process_message_loop(self, client: aiomqtt.Client) -> None:
179181
_LOGGER.debug("Processing MQTT messages")
180182
async for message in client.messages:
181183
_LOGGER.debug("Received message: %s", message)
182-
for listener in self._listeners.get(message.topic.value, []):
183-
try:
184-
listener(message.payload)
185-
except asyncio.CancelledError:
186-
raise
187-
except Exception as e:
188-
_LOGGER.exception("Uncaught exception in subscriber callback: %s", e)
184+
self._listeners(message.topic.value, message.payload)
189185

190186
async def subscribe(self, topic: str, callback: Callable[[bytes], None]) -> Callable[[], None]:
191187
"""Subscribe to messages on the specified topic and invoke the callback for new messages.
@@ -196,9 +192,7 @@ async def subscribe(self, topic: str, callback: Callable[[bytes], None]) -> Call
196192
The returned callable unsubscribes from the topic when called.
197193
"""
198194
_LOGGER.debug("Subscribing to topic %s", topic)
199-
if topic not in self._listeners:
200-
self._listeners[topic] = []
201-
self._listeners[topic].append(callback)
195+
unsub = self._listeners.add_callback(topic, callback)
202196

203197
async with self._client_lock:
204198
if self._client:
@@ -210,7 +204,7 @@ async def subscribe(self, topic: str, callback: Callable[[bytes], None]) -> Call
210204
else:
211205
_LOGGER.debug("Client not connected, will establish subscription later")
212206

213-
return lambda: self._listeners[topic].remove(callback)
207+
return unsub
214208

215209
async def publish(self, topic: str, message: bytes) -> None:
216210
"""Publish a message on the topic."""

tests/devices/test_local_channel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ async def test_message_decode_error(local_channel: LocalChannel, caplog: pytest.
148148

149149
assert len(caplog.records) == 1
150150
assert caplog.records[0].levelname == "WARNING"
151-
assert "Failed to decode local message" in caplog.records[0].message
151+
assert "Failed to decode message" in caplog.records[0].message
152152

153153

154154
async def test_subscribe_callback(
@@ -181,7 +181,7 @@ def failing_callback(message: RoborockMessage) -> None:
181181
await asyncio.sleep(0.01) # yield
182182

183183
# Should log the exception but not crash
184-
assert any("Uncaught error in message handler callback" in record.message for record in caplog.records)
184+
assert any("Uncaught error in callback 'failing_callback'" in record.message for record in caplog.records)
185185

186186

187187
async def test_unsubscribe(local_channel: LocalChannel, mock_loop: Mock) -> None:

tests/devices/test_mqtt_channel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ async def test_message_decode_error(
150150

151151
assert len(caplog.records) == 1
152152
assert caplog.records[0].levelname == "WARNING"
153-
assert "Failed to decode MQTT message" in caplog.records[0].message
153+
assert "Failed to decode message" in caplog.records[0].message
154154
unsub()
155155

156156

@@ -255,7 +255,7 @@ def failing_callback(message: RoborockMessage) -> None:
255255
# Check that exception was logged
256256
error_records = [record for record in caplog.records if record.levelname == "ERROR"]
257257
assert len(error_records) == 1
258-
assert "Uncaught error in message handler callback" in error_records[0].message
258+
assert "Uncaught error in callback 'failing_callback'" in error_records[0].message
259259

260260
# Unsubscribe all remaining subscribers
261261
unsub1()

0 commit comments

Comments
 (0)