|
| 1 | +"""Module for managing callback utility functions.""" |
| 2 | + |
| 3 | +import logging |
| 4 | +from collections.abc import Callable |
| 5 | +from typing import Generic, TypeVar |
| 6 | + |
| 7 | +_LOGGER = logging.getLogger(__name__) |
| 8 | + |
| 9 | +K = TypeVar("K") |
| 10 | +V = TypeVar("V") |
| 11 | + |
| 12 | + |
| 13 | +def safe_callback(callback: Callable[[V], None], logger: logging.Logger | None = None) -> Callable[[V], None]: |
| 14 | + """Wrap a callback to catch and log exceptions. |
| 15 | +
|
| 16 | + This is useful for ensuring that errors in callbacks do not propagate |
| 17 | + and cause unexpected behavior. Any failures during callback execution will be logged. |
| 18 | + """ |
| 19 | + |
| 20 | + if logger is None: |
| 21 | + logger = _LOGGER |
| 22 | + |
| 23 | + def wrapper(value: V) -> None: |
| 24 | + try: |
| 25 | + callback(value) |
| 26 | + except Exception as ex: # noqa: BLE001 |
| 27 | + logger.error("Uncaught error in callback '%s': %s", callback.__name__, ex) |
| 28 | + |
| 29 | + return wrapper |
| 30 | + |
| 31 | + |
| 32 | +class CallbackMap(Generic[K, V]): |
| 33 | + """A mapping of callbacks for specific keys. |
| 34 | +
|
| 35 | + This allows for registering multiple callbacks for different keys and invoking them |
| 36 | + when a value is received for a specific key. |
| 37 | + """ |
| 38 | + |
| 39 | + def __init__(self, logger: logging.Logger | None = None) -> None: |
| 40 | + self._callbacks: dict[K, list[Callable[[V], None]]] = {} |
| 41 | + self._logger = logger or _LOGGER |
| 42 | + |
| 43 | + def keys(self) -> list[K]: |
| 44 | + """Get all keys in the callback map.""" |
| 45 | + return list(self._callbacks.keys()) |
| 46 | + |
| 47 | + def add_callback(self, key: K, callback: Callable[[V], None]) -> Callable[[], None]: |
| 48 | + """Add a callback for a specific key. |
| 49 | +
|
| 50 | + Any failures during callback execution will be logged. |
| 51 | +
|
| 52 | + Returns a callable that can be used to remove the callback. |
| 53 | + """ |
| 54 | + self._callbacks.setdefault(key, []).append(callback) |
| 55 | + |
| 56 | + def remove_callback() -> None: |
| 57 | + """Remove the callback for the specific key.""" |
| 58 | + if cb_list := self._callbacks.get(key): |
| 59 | + cb_list.remove(callback) |
| 60 | + if not cb_list: |
| 61 | + del self._callbacks[key] |
| 62 | + |
| 63 | + return remove_callback |
| 64 | + |
| 65 | + def get_callbacks(self, key: K) -> list[Callable[[V], None]]: |
| 66 | + """Get all callbacks for a specific key.""" |
| 67 | + return self._callbacks.get(key, []) |
| 68 | + |
| 69 | + def __call__(self, key: K, value: V) -> None: |
| 70 | + """Invoke all callbacks for a specific key.""" |
| 71 | + for callback in self.get_callbacks(key): |
| 72 | + safe_callback(callback, self._logger)(value) |
| 73 | + |
| 74 | + |
| 75 | +class CallbackList(Generic[V]): |
| 76 | + """A list of callbacks that can be invoked. |
| 77 | +
|
| 78 | + This combines a list of callbacks into a single callable. Callers can add |
| 79 | + additional callbacks to the list at any time. |
| 80 | + """ |
| 81 | + |
| 82 | + def __init__(self, logger: logging.Logger | None = None) -> None: |
| 83 | + self._callbacks: list[Callable[[V], None]] = [] |
| 84 | + self._logger = logger or _LOGGER |
| 85 | + |
| 86 | + def add_callback(self, callback: Callable[[V], None]) -> Callable[[], None]: |
| 87 | + """Add a callback to the list. |
| 88 | +
|
| 89 | + Any failures during callback execution will be logged. |
| 90 | +
|
| 91 | + Returns a callable that can be used to remove the callback. |
| 92 | + """ |
| 93 | + self._callbacks.append(callback) |
| 94 | + |
| 95 | + return lambda: self._callbacks.remove(callback) |
| 96 | + |
| 97 | + def __call__(self, value: V) -> None: |
| 98 | + """Invoke all callbacks in the list.""" |
| 99 | + for callback in self._callbacks: |
| 100 | + safe_callback(callback, self._logger)(value) |
| 101 | + |
| 102 | + |
| 103 | +def decoder_callback( |
| 104 | + decoder: Callable[[K], list[V]], callback: Callable[[V], None], logger: logging.Logger | None = None |
| 105 | +) -> Callable[[K], None]: |
| 106 | + """Create a callback that decodes messages using a decoder and invokes a callback. |
| 107 | +
|
| 108 | + The decoder converts a value into a list of values. The callback is then invoked |
| 109 | + for each value in the list. |
| 110 | +
|
| 111 | + Any failures during decoding or invoking the callbacks will be logged. |
| 112 | + """ |
| 113 | + if logger is None: |
| 114 | + logger = _LOGGER |
| 115 | + |
| 116 | + safe_cb = safe_callback(callback, logger) |
| 117 | + |
| 118 | + def wrapper(data: K) -> None: |
| 119 | + if not (messages := decoder(data)): |
| 120 | + logger.warning("Failed to decode message: %s", data) |
| 121 | + return |
| 122 | + for message in messages: |
| 123 | + _LOGGER.debug("Decoded message: %s", message) |
| 124 | + safe_cb(message) |
| 125 | + |
| 126 | + return wrapper |
0 commit comments