Skip to content

Commit 57d82e2

Browse files
authored
chore: unify callback handling recipes across mqtt and local channels (#456)
* chore: unify callback handling recipes across mqtt and local channels * chore: fix style and comments
1 parent 8bc3ab3 commit 57d82e2

File tree

7 files changed

+394
-50
lines changed

7 files changed

+394
-50
lines changed

roborock/callbacks.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
"""Module for managing callback utility functions."""
2+
3+
import logging
4+
from collections.abc import Callable
5+
from typing import Generic, TypeVar
6+
7+
_LOGGER = logging.getLogger(__name__)
8+
9+
K = TypeVar("K")
10+
V = TypeVar("V")
11+
12+
13+
def safe_callback(callback: Callable[[V], None], logger: logging.Logger | None = None) -> Callable[[V], None]:
14+
"""Wrap a callback to catch and log exceptions.
15+
16+
This is useful for ensuring that errors in callbacks do not propagate
17+
and cause unexpected behavior. Any failures during callback execution will be logged.
18+
"""
19+
20+
if logger is None:
21+
logger = _LOGGER
22+
23+
def wrapper(value: V) -> None:
24+
try:
25+
callback(value)
26+
except Exception as ex: # noqa: BLE001
27+
logger.error("Uncaught error in callback '%s': %s", callback.__name__, ex)
28+
29+
return wrapper
30+
31+
32+
class CallbackMap(Generic[K, V]):
33+
"""A mapping of callbacks for specific keys.
34+
35+
This allows for registering multiple callbacks for different keys and invoking them
36+
when a value is received for a specific key.
37+
"""
38+
39+
def __init__(self, logger: logging.Logger | None = None) -> None:
40+
self._callbacks: dict[K, list[Callable[[V], None]]] = {}
41+
self._logger = logger or _LOGGER
42+
43+
def keys(self) -> list[K]:
44+
"""Get all keys in the callback map."""
45+
return list(self._callbacks.keys())
46+
47+
def add_callback(self, key: K, callback: Callable[[V], None]) -> Callable[[], None]:
48+
"""Add a callback for a specific key.
49+
50+
Any failures during callback execution will be logged.
51+
52+
Returns a callable that can be used to remove the callback.
53+
"""
54+
self._callbacks.setdefault(key, []).append(callback)
55+
56+
def remove_callback() -> None:
57+
"""Remove the callback for the specific key."""
58+
if cb_list := self._callbacks.get(key):
59+
cb_list.remove(callback)
60+
if not cb_list:
61+
del self._callbacks[key]
62+
63+
return remove_callback
64+
65+
def get_callbacks(self, key: K) -> list[Callable[[V], None]]:
66+
"""Get all callbacks for a specific key."""
67+
return self._callbacks.get(key, [])
68+
69+
def __call__(self, key: K, value: V) -> None:
70+
"""Invoke all callbacks for a specific key."""
71+
for callback in self.get_callbacks(key):
72+
safe_callback(callback, self._logger)(value)
73+
74+
75+
class CallbackList(Generic[V]):
76+
"""A list of callbacks that can be invoked.
77+
78+
This combines a list of callbacks into a single callable. Callers can add
79+
additional callbacks to the list at any time.
80+
"""
81+
82+
def __init__(self, logger: logging.Logger | None = None) -> None:
83+
self._callbacks: list[Callable[[V], None]] = []
84+
self._logger = logger or _LOGGER
85+
86+
def add_callback(self, callback: Callable[[V], None]) -> Callable[[], None]:
87+
"""Add a callback to the list.
88+
89+
Any failures during callback execution will be logged.
90+
91+
Returns a callable that can be used to remove the callback.
92+
"""
93+
self._callbacks.append(callback)
94+
95+
return lambda: self._callbacks.remove(callback)
96+
97+
def __call__(self, value: V) -> None:
98+
"""Invoke all callbacks in the list."""
99+
for callback in self._callbacks:
100+
safe_callback(callback, self._logger)(value)
101+
102+
103+
def decoder_callback(
104+
decoder: Callable[[K], list[V]], callback: Callable[[V], None], logger: logging.Logger | None = None
105+
) -> Callable[[K], None]:
106+
"""Create a callback that decodes messages using a decoder and invokes a callback.
107+
108+
The decoder converts a value into a list of values. The callback is then invoked
109+
for each value in the list.
110+
111+
Any failures during decoding or invoking the callbacks will be logged.
112+
"""
113+
if logger is None:
114+
logger = _LOGGER
115+
116+
safe_cb = safe_callback(callback, logger)
117+
118+
def wrapper(data: K) -> None:
119+
if not (messages := decoder(data)):
120+
logger.warning("Failed to decode message: %s", data)
121+
return
122+
for message in messages:
123+
_LOGGER.debug("Decoded message: %s", message)
124+
safe_cb(message)
125+
126+
return wrapper

roborock/devices/local_channel.py

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from collections.abc import Callable
66
from dataclasses import dataclass
77

8+
from roborock.callbacks import CallbackList, decoder_callback
89
from roborock.exceptions import RoborockConnectionException, RoborockException
910
from roborock.protocol import Decoder, Encoder, create_local_decoder, create_local_encoder
1011
from roborock.roborock_message import RoborockMessage
@@ -42,11 +43,13 @@ def __init__(self, host: str, local_key: str):
4243
self._host = host
4344
self._transport: asyncio.Transport | None = None
4445
self._protocol: _LocalProtocol | None = None
45-
self._subscribers: list[Callable[[RoborockMessage], None]] = []
46+
self._subscribers: CallbackList[RoborockMessage] = CallbackList(_LOGGER)
4647
self._is_connected = False
4748

4849
self._decoder: Decoder = create_local_decoder(local_key)
4950
self._encoder: Encoder = create_local_encoder(local_key)
51+
# Callback to decode messages and dispatch to subscribers
52+
self._data_received: Callable[[bytes], None] = decoder_callback(self._decoder, self._subscribers, _LOGGER)
5053

5154
@property
5255
def is_connected(self) -> bool:
@@ -76,19 +79,6 @@ def close(self) -> None:
7679
self._transport = None
7780
self._is_connected = False
7881

79-
def _data_received(self, data: bytes) -> None:
80-
"""Handle incoming data from the transport."""
81-
if not (messages := self._decoder(data)):
82-
_LOGGER.warning("Failed to decode local message: %s", data)
83-
return
84-
for message in messages:
85-
_LOGGER.debug("Received message: %s", message)
86-
for callback in self._subscribers:
87-
try:
88-
callback(message)
89-
except Exception as e:
90-
_LOGGER.exception("Uncaught error in message handler callback: %s", e)
91-
9282
def _connection_lost(self, exc: Exception | None) -> None:
9383
"""Handle connection loss."""
9484
_LOGGER.warning("Connection lost to %s", self._host, exc_info=exc)
@@ -97,12 +87,7 @@ def _connection_lost(self, exc: Exception | None) -> None:
9787

9888
async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]:
9989
"""Subscribe to all messages from the device."""
100-
self._subscribers.append(callback)
101-
102-
def unsubscribe() -> None:
103-
self._subscribers.remove(callback)
104-
105-
return unsubscribe
90+
return self._subscribers.add_callback(callback)
10691

10792
async def publish(self, message: RoborockMessage) -> None:
10893
"""Send a command message.

roborock/devices/mqtt_channel.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import logging
44
from collections.abc import Callable
55

6+
from roborock.callbacks import decoder_callback
67
from roborock.containers import HomeDataDevice, RRiot, UserData
78
from roborock.exceptions import RoborockException
89
from roborock.mqtt.session import MqttParams, MqttSession, MqttSessionException
@@ -56,19 +57,8 @@ async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callab
5657
5758
Returns a callable that can be used to unsubscribe from the topic.
5859
"""
59-
60-
def message_handler(payload: bytes) -> None:
61-
if not (messages := self._decoder(payload)):
62-
_LOGGER.warning("Failed to decode MQTT message: %s", payload)
63-
return
64-
for message in messages:
65-
_LOGGER.debug("Received message: %s", message)
66-
try:
67-
callback(message)
68-
except Exception as e:
69-
_LOGGER.exception("Uncaught error in message handler callback: %s", e)
70-
71-
return await self._mqtt_session.subscribe(self._subscribe_topic, message_handler)
60+
dispatch = decoder_callback(self._decoder, callback, _LOGGER)
61+
return await self._mqtt_session.subscribe(self._subscribe_topic, dispatch)
7262

7363
async def publish(self, message: RoborockMessage) -> None:
7464
"""Publish a command message.

roborock/mqtt/roborock_session.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import aiomqtt
1818
from aiomqtt import MqttError, TLSParameters
1919

20+
from roborock.callbacks import CallbackMap
21+
2022
from .session import MqttParams, MqttSession, MqttSessionException
2123

2224
_LOGGER = logging.getLogger(__name__)
@@ -53,7 +55,7 @@ def __init__(self, params: MqttParams):
5355
self._backoff = MIN_BACKOFF_INTERVAL
5456
self._client: aiomqtt.Client | None = None
5557
self._client_lock = asyncio.Lock()
56-
self._listeners: dict[str, list[Callable[[bytes], None]]] = {}
58+
self._listeners: CallbackMap[str, bytes] = CallbackMap(_LOGGER)
5759

5860
@property
5961
def connected(self) -> bool:
@@ -164,7 +166,7 @@ async def _mqtt_client(self, params: MqttParams) -> aiomqtt.Client:
164166
# Re-establish any existing subscriptions
165167
async with self._client_lock:
166168
self._client = client
167-
for topic in self._listeners:
169+
for topic in self._listeners.keys():
168170
_LOGGER.debug("Re-establishing subscription to topic %s", topic)
169171
# TODO: If this fails it will break the whole connection. Make
170172
# this retry again in the background with backoff.
@@ -179,13 +181,7 @@ async def _process_message_loop(self, client: aiomqtt.Client) -> None:
179181
_LOGGER.debug("Processing MQTT messages")
180182
async for message in client.messages:
181183
_LOGGER.debug("Received message: %s", message)
182-
for listener in self._listeners.get(message.topic.value, []):
183-
try:
184-
listener(message.payload)
185-
except asyncio.CancelledError:
186-
raise
187-
except Exception as e:
188-
_LOGGER.exception("Uncaught exception in subscriber callback: %s", e)
184+
self._listeners(message.topic.value, message.payload)
189185

190186
async def subscribe(self, topic: str, callback: Callable[[bytes], None]) -> Callable[[], None]:
191187
"""Subscribe to messages on the specified topic and invoke the callback for new messages.
@@ -196,9 +192,7 @@ async def subscribe(self, topic: str, callback: Callable[[bytes], None]) -> Call
196192
The returned callable unsubscribes from the topic when called.
197193
"""
198194
_LOGGER.debug("Subscribing to topic %s", topic)
199-
if topic not in self._listeners:
200-
self._listeners[topic] = []
201-
self._listeners[topic].append(callback)
195+
unsub = self._listeners.add_callback(topic, callback)
202196

203197
async with self._client_lock:
204198
if self._client:
@@ -210,7 +204,7 @@ async def subscribe(self, topic: str, callback: Callable[[bytes], None]) -> Call
210204
else:
211205
_LOGGER.debug("Client not connected, will establish subscription later")
212206

213-
return lambda: self._listeners[topic].remove(callback)
207+
return unsub
214208

215209
async def publish(self, topic: str, message: bytes) -> None:
216210
"""Publish a message on the topic."""

tests/devices/test_local_channel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ async def test_message_decode_error(local_channel: LocalChannel, caplog: pytest.
148148

149149
assert len(caplog.records) == 1
150150
assert caplog.records[0].levelname == "WARNING"
151-
assert "Failed to decode local message" in caplog.records[0].message
151+
assert "Failed to decode message" in caplog.records[0].message
152152

153153

154154
async def test_subscribe_callback(
@@ -181,7 +181,7 @@ def failing_callback(message: RoborockMessage) -> None:
181181
await asyncio.sleep(0.01) # yield
182182

183183
# Should log the exception but not crash
184-
assert any("Uncaught error in message handler callback" in record.message for record in caplog.records)
184+
assert any("Uncaught error in callback 'failing_callback'" in record.message for record in caplog.records)
185185

186186

187187
async def test_unsubscribe(local_channel: LocalChannel, mock_loop: Mock) -> None:

tests/devices/test_mqtt_channel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ async def test_message_decode_error(
150150

151151
assert len(caplog.records) == 1
152152
assert caplog.records[0].levelname == "WARNING"
153-
assert "Failed to decode MQTT message" in caplog.records[0].message
153+
assert "Failed to decode message" in caplog.records[0].message
154154
unsub()
155155

156156

@@ -255,7 +255,7 @@ def failing_callback(message: RoborockMessage) -> None:
255255
# Check that exception was logged
256256
error_records = [record for record in caplog.records if record.levelname == "ERROR"]
257257
assert len(error_records) == 1
258-
assert "Uncaught error in message handler callback" in error_records[0].message
258+
assert "Uncaught error in callback 'failing_callback'" in error_records[0].message
259259

260260
# Unsubscribe all remaining subscribers
261261
unsub1()

0 commit comments

Comments
 (0)