Skip to content

Commit b3c6511

Browse files
committed
feat: Add diagnostics library for tracking stats/counters
Add a generic library that can be used to track counters or elapsed time (e.g. how long it takes on average to connect to MQTT). This is a copy of https://github.com/allenporter/python-google-nest-sdm/blob/main/google_nest_sdm/diagnostics.py. This is an initial pass that adds a few example metrics for MQTT, but we can add more as we need fine-grained details in diagnostics.
1 parent 813d675 commit b3c6511

File tree

7 files changed

+295
-17
lines changed

7 files changed

+295
-17
lines changed

roborock/devices/device_manager.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
import asyncio
44
import enum
55
import logging
6-
from collections.abc import Callable
6+
from collections.abc import Callable, Mapping
77
from dataclasses import dataclass
8+
from typing import Any
89

910
import aiohttp
1011

@@ -15,6 +16,7 @@
1516
UserData,
1617
)
1718
from roborock.devices.device import DeviceReadyCallback, RoborockDevice
19+
from roborock.diagnostics import Diagnostics
1820
from roborock.map.map_parser import MapParserConfig
1921
from roborock.mqtt.roborock_session import create_lazy_mqtt_session
2022
from roborock.mqtt.session import MqttSession
@@ -57,6 +59,7 @@ def __init__(
5759
device_creator: DeviceCreator,
5860
mqtt_session: MqttSession,
5961
cache: Cache,
62+
diagnostics: Diagnostics,
6063
) -> None:
6164
"""Initialize the DeviceManager with user data and optional cache storage.
6265
@@ -67,11 +70,14 @@ def __init__(
6770
self._device_creator = device_creator
6871
self._devices: dict[str, RoborockDevice] = {}
6972
self._mqtt_session = mqtt_session
73+
self._diagnostics = diagnostics
7074

7175
async def discover_devices(self) -> list[RoborockDevice]:
7276
"""Discover all devices for the logged-in user."""
77+
self._diagnostics.increment("discover_devices")
7378
cache_data = await self._cache.get()
7479
if not cache_data.home_data:
80+
self._diagnostics.increment("fetch_home_data")
7581
_LOGGER.debug("No cached home data found, fetching from API")
7682
cache_data.home_data = await self._web_api.get_home_data()
7783
await self._cache.set(cache_data)
@@ -109,6 +115,10 @@ async def close(self) -> None:
109115
tasks.append(self._mqtt_session.close())
110116
await asyncio.gather(*tasks)
111117

118+
def diagnostic_data(self) -> Mapping[str, Any]:
119+
"""Return diagnostics information about the device manager."""
120+
return self._diagnostics.as_dict()
121+
112122

113123
@dataclass
114124
class UserParams:
@@ -175,7 +185,10 @@ async def create_device_manager(
175185
web_api = create_web_api_wrapper(user_params, session=session, cache=cache)
176186
user_data = user_params.user_data
177187

188+
diagnostics = Diagnostics()
189+
178190
mqtt_params = create_mqtt_params(user_data.rriot)
191+
mqtt_params.diagnostics = diagnostics.subkey("mqtt_session")
179192
mqtt_session = await create_lazy_mqtt_session(mqtt_params)
180193

181194
def device_creator(home_data: HomeData, device: HomeDataDevice, product: HomeDataProduct) -> RoborockDevice:
@@ -219,6 +232,6 @@ def device_creator(home_data: HomeData, device: HomeDataDevice, product: HomeDat
219232
dev.add_ready_callback(ready_callback)
220233
return dev
221234

222-
manager = DeviceManager(web_api, device_creator, mqtt_session=mqtt_session, cache=cache)
235+
manager = DeviceManager(web_api, device_creator, mqtt_session=mqtt_session, cache=cache, diagnostics=diagnostics)
223236
await manager.discover_devices()
224237
return manager

roborock/diagnostics.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
"""Diagnostics for debugging.
2+
3+
A Diagnostics object can be used to track counts and latencies of various
4+
operations within a module. This can be useful for debugging performance issues
5+
or understanding usage patterns.
6+
7+
This is an internal facing module and is not intended for public use. Diagnostics
8+
data is collected and exposed to clients via higher level APIs like the
9+
DeviceManager.
10+
"""
11+
12+
from __future__ import annotations
13+
14+
import time
15+
from collections import Counter
16+
from collections.abc import Generator, Mapping
17+
from contextlib import contextmanager
18+
from typing import Any
19+
20+
21+
class Diagnostics:
    """A class that holds diagnostics information for a module.

    You can use this class to hold counters or to record timing information
    that can be exported as a dictionary for debugging purposes. Diagnostics
    objects form a tree: child objects created with `subkey` are nested under
    their key when the parent is exported with `as_dict`.
    """

    def __init__(self) -> None:
        """Initialize Diagnostics with no counters and no subkeys."""
        self._counter: Counter = Counter()
        self._subkeys: dict[str, Diagnostics] = {}

    def increment(self, key: str, count: int = 1) -> None:
        """Increment a counter for the specified key/event.

        Args:
            key: The name of the counter to update.
            count: The amount to add to the counter (defaults to 1).
        """
        # Counter returns 0 for missing keys, so no explicit initialization
        # (or temporary Counter object) is needed.
        self._counter[key] += count

    def elapsed(self, key_prefix: str, elapsed_ms: int = 1) -> None:
        """Track a latency event for the specified key/event prefix.

        Two counters are maintained per prefix: `<key_prefix>_count` holds the
        number of recorded events and `<key_prefix>_sum` holds the total of the
        recorded values, so an average can be derived as sum / count.

        Args:
            key_prefix: The prefix used to build the two counter names.
            elapsed_ms: The elapsed time to record, in milliseconds.
        """
        self.increment(f"{key_prefix}_count", 1)
        self.increment(f"{key_prefix}_sum", elapsed_ms)

    def as_dict(self) -> Mapping[str, Any]:
        """Return diagnostics as a debug dictionary.

        Returns:
            A new dictionary mapping each counter name to its value, plus one
            nested dictionary per subkey. Subkeys whose own export is empty
            are omitted.
        """
        data: dict[str, Any] = dict(self._counter)
        for name, child in self._subkeys.items():
            # Skip children that have not recorded anything yet to keep the
            # exported dictionary free of empty entries.
            if child_data := child.as_dict():
                data[name] = child_data
        return data

    def subkey(self, key: str) -> "Diagnostics":
        """Return sub-Diagnostics object with the specified subkey.

        This will create a new Diagnostics object if one does not already exist
        for the specified subkey. Stats from the sub-Diagnostics will be included
        in the parent Diagnostics when exported as a dictionary.

        Args:
            key: The subkey for the diagnostics.

        Returns:
            The Diagnostics object for the specified subkey.
        """
        if key not in self._subkeys:
            self._subkeys[key] = Diagnostics()
        return self._subkeys[key]

    @contextmanager
    def timer(self, key_prefix: str) -> Generator[None, None, None]:
        """A context manager that records the timing of operations as a diagnostic.

        The elapsed wall time of the `with` body is recorded via `elapsed`,
        truncated to whole milliseconds. The measurement is recorded even when
        the body raises; the exception is then propagated unchanged.

        Args:
            key_prefix: The prefix passed to `elapsed` for the recorded timing.
        """
        start = time.perf_counter()
        try:
            yield
        finally:
            end = time.perf_counter()
            ms = int((end - start) * 1000)
            self.elapsed(key_prefix, ms)

    def reset(self) -> None:
        """Clear all diagnostics, for testing.

        Counters of this object and of every subkey are cleared recursively.
        The subkey objects themselves are kept, so references handed out
        earlier by `subkey` remain attached to this object.
        """
        self._counter = Counter()
        for child in self._subkeys.values():
            child.reset()

roborock/mqtt/roborock_session.py

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from aiomqtt import MqttCodeError, MqttError, TLSParameters
1919

2020
from roborock.callbacks import CallbackMap
21+
from roborock.diagnostics import Diagnostics
2122

2223
from .session import MqttParams, MqttSession, MqttSessionException, MqttSessionUnauthorized
2324

@@ -74,6 +75,7 @@ def __init__(
7475
self._connection_task: asyncio.Task[None] | None = None
7576
self._topic_idle_timeout = topic_idle_timeout
7677
self._idle_timers: dict[str, asyncio.Task[None]] = {}
78+
self._diagnostics = params.diagnostics
7779

7880
@property
7981
def connected(self) -> bool:
@@ -88,24 +90,30 @@ async def start(self) -> None:
8890
handle the failure and retry if desired itself. Once connected,
8991
the session will retry connecting in the background.
9092
"""
93+
self._diagnostics.increment("start_attempt")
9194
start_future: asyncio.Future[None] = asyncio.Future()
9295
loop = asyncio.get_event_loop()
9396
self._reconnect_task = loop.create_task(self._run_reconnect_loop(start_future))
9497
try:
9598
await start_future
9699
except MqttCodeError as err:
100+
self._diagnostics.increment(f"start_failure:{err.rc}")
97101
if err.rc == MqttReasonCode.RC_ERROR_UNAUTHORIZED:
98102
raise MqttSessionUnauthorized(f"Authorization error starting MQTT session: {err}") from err
99103
raise MqttSessionException(f"Error starting MQTT session: {err}") from err
100104
except MqttError as err:
105+
self._diagnostics.increment("start_failure:unknown")
101106
raise MqttSessionException(f"Error starting MQTT session: {err}") from err
102107
except Exception as err:
108+
self._diagnostics.increment("start_failure:uncaught")
103109
raise MqttSessionException(f"Unexpected error starting session: {err}") from err
104110
else:
111+
self._diagnostics.increment("start_success")
105112
_LOGGER.debug("MQTT session started successfully")
106113

107114
async def close(self) -> None:
108115
"""Cancels the MQTT loop and shutdown the client library."""
116+
self._diagnostics.increment("close")
109117
self._stop = True
110118
tasks = [task for task in [self._connection_task, self._reconnect_task, *self._idle_timers.values()] if task]
111119
self._connection_task = None
@@ -128,6 +136,7 @@ async def restart(self) -> None:
128136
the reconnect loop. This is a no-op if there is no active connection.
129137
"""
130138
_LOGGER.info("Forcing MQTT session restart")
139+
self._diagnostics.increment("restart")
131140
if self._connection_task:
132141
self._connection_task.cancel()
133142
else:
@@ -136,6 +145,7 @@ async def restart(self) -> None:
136145
async def _run_reconnect_loop(self, start_future: asyncio.Future[None] | None) -> None:
137146
"""Run the MQTT loop."""
138147
_LOGGER.info("Starting MQTT session")
148+
self._diagnostics.increment("start_loop")
139149
while True:
140150
try:
141151
self._connection_task = asyncio.create_task(self._run_connection(start_future))
@@ -156,6 +166,7 @@ async def _run_reconnect_loop(self, start_future: asyncio.Future[None] | None) -
156166
_LOGGER.debug("MQTT session closed, stopping retry loop")
157167
return
158168
_LOGGER.info("MQTT session disconnected, retrying in %s seconds", self._backoff.total_seconds())
169+
self._diagnostics.increment("reconnect_wait")
159170
await asyncio.sleep(self._backoff.total_seconds())
160171
self._backoff = min(self._backoff * BACKOFF_MULTIPLIER, MAX_BACKOFF_INTERVAL)
161172

@@ -167,17 +178,19 @@ async def _run_connection(self, start_future: asyncio.Future[None] | None) -> No
167178
is lost, this method will exit.
168179
"""
169180
try:
170-
async with self._mqtt_client(self._params) as client:
171-
self._backoff = MIN_BACKOFF_INTERVAL
172-
self._healthy = True
173-
_LOGGER.info("MQTT Session connected.")
174-
if start_future and not start_future.done():
175-
start_future.set_result(None)
176-
177-
_LOGGER.debug("Processing MQTT messages")
178-
async for message in client.messages:
179-
_LOGGER.debug("Received message: %s", message)
180-
self._listeners(message.topic.value, message.payload)
181+
with self._diagnostics.timer("connection"):
182+
async with self._mqtt_client(self._params) as client:
183+
self._backoff = MIN_BACKOFF_INTERVAL
184+
self._healthy = True
185+
_LOGGER.info("MQTT Session connected.")
186+
if start_future and not start_future.done():
187+
start_future.set_result(None)
188+
189+
_LOGGER.debug("Processing MQTT messages")
190+
async for message in client.messages:
191+
_LOGGER.debug("Received message: %s", message)
192+
with self._diagnostics.timer("dispatch_message"):
193+
self._listeners(message.topic.value, message.payload)
181194
except MqttError as err:
182195
if start_future and not start_future.done():
183196
_LOGGER.info("MQTT error starting session: %s", err)
@@ -219,6 +232,7 @@ async def _mqtt_client(self, params: MqttParams) -> aiomqtt.Client:
219232
async with self._client_lock:
220233
self._client = client
221234
for topic in self._listeners.keys():
235+
self._diagnostics.increment("resubscribe")
222236
_LOGGER.debug("Re-establishing subscription to topic %s", topic)
223237
# TODO: If this fails it will break the whole connection. Make
224238
# this retry again in the background with backoff.
@@ -243,6 +257,7 @@ async def subscribe(self, topic: str, callback: Callable[[bytes], None]) -> Call
243257

244258
# If there is an idle timer for this topic, cancel it (reuse subscription)
245259
if idle_timer := self._idle_timers.pop(topic, None):
260+
self._diagnostics.increment("unsubscribe_idle_cancel")
246261
idle_timer.cancel()
247262
_LOGGER.debug("Cancelled idle timer for topic %s (reused subscription)", topic)
248263

@@ -252,12 +267,14 @@ async def subscribe(self, topic: str, callback: Callable[[bytes], None]) -> Call
252267
if self._client:
253268
_LOGGER.debug("Establishing subscription to topic %s", topic)
254269
try:
255-
await self._client.subscribe(topic)
270+
with self._diagnostics.timer("subscribe"):
271+
await self._client.subscribe(topic)
256272
except MqttError as err:
257273
# Clean up the callback if subscription fails
258274
unsub()
259275
raise MqttSessionException(f"Error subscribing to topic: {err}") from err
260276
else:
277+
self._diagnostics.increment("subscribe_pending")
261278
_LOGGER.debug("Client not connected, will establish subscription later")
262279

263280
def schedule_unsubscribe():
@@ -283,9 +300,11 @@ async def idle_unsubscribe():
283300
self._idle_timers[topic] = task
284301

285302
def delayed_unsub():
303+
self._diagnostics.increment("unsubscribe")
286304
unsub() # Remove the callback from CallbackMap
287305
# If no more callbacks for this topic, start idle timer
288306
if not self._listeners.get_callbacks(topic):
307+
self._diagnostics.increment("unsubscribe_idle_start")
289308
schedule_unsubscribe()
290309

291310
return delayed_unsub
@@ -299,7 +318,8 @@ async def publish(self, topic: str, message: bytes) -> None:
299318
raise MqttSessionException("Could not publish message, MQTT client not connected")
300319
client = self._client
301320
try:
302-
await client.publish(topic, message)
321+
with self._diagnostics.timer("publish"):
322+
await client.publish(topic, message)
303323
except MqttError as err:
304324
raise MqttSessionException(f"Error publishing message: {err}") from err
305325

@@ -312,11 +332,12 @@ class LazyMqttSession(MqttSession):
312332
is made.
313333
"""
314334

315-
def __init__(self, session: RoborockMqttSession) -> None:
335+
def __init__(self, session: RoborockMqttSession, diagnostics: Diagnostics) -> None:
316336
"""Initialize the lazy session with an existing session."""
317337
self._lock = asyncio.Lock()
318338
self._started = False
319339
self._session = session
340+
self._diagnostics = diagnostics
320341

321342
@property
322343
def connected(self) -> bool:
@@ -327,6 +348,7 @@ async def _maybe_start(self) -> None:
327348
"""Start the MQTT session if not already started."""
328349
async with self._lock:
329350
if not self._started:
351+
self._diagnostics.increment("start")
330352
await self._session.start()
331353
self._started = True
332354

@@ -377,4 +399,4 @@ async def create_lazy_mqtt_session(params: MqttParams) -> MqttSession:
377399
This function is a factory for creating an MQTT session that will
378400
only connect when the first attempt to subscribe or publish is made.
379401
"""
380-
return LazyMqttSession(RoborockMqttSession(params))
402+
return LazyMqttSession(RoborockMqttSession(params), diagnostics=params.diagnostics.subkey("lazy_mqtt"))

roborock/mqtt/session.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from collections.abc import Callable
55
from dataclasses import dataclass
66

7+
from roborock.diagnostics import Diagnostics
78
from roborock.exceptions import RoborockException
89

910
DEFAULT_TIMEOUT = 30.0
@@ -31,6 +32,14 @@ class MqttParams:
3132
timeout: float = DEFAULT_TIMEOUT
3233
"""Timeout for communications with the broker in seconds."""
3334

35+
diagnostics: Diagnostics = Diagnostics()
36+
"""Diagnostics object for tracking MQTT session stats.
37+
38+
This defaults to a new Diagnostics object, but the common case is the
39+
caller will provide their own (e.g., from a DeviceManager) so that the
40+
shared MQTT session diagnostics are included in the overall diagnostics.
41+
"""
42+
3443

3544
class MqttSession(ABC):
3645
"""An MQTT session for sending and receiving messages."""

tests/devices/test_device_manager.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,3 +231,17 @@ async def test_start_connect_failure(home_data: HomeData, channel_failure: Mock,
231231

232232
await device_manager.close()
233233
assert mock_unsub.call_count == 1
234+
235+
236+
async def test_diagnostics_collection(home_data: HomeData) -> None:
    """Test that diagnostics are collected correctly in the DeviceManager.

    Creating the manager runs device discovery once, which should be visible
    as counters in the exported diagnostics.
    """
    # NOTE(review): the `home_data` fixture presumably stubs the web API
    # response used during discovery -- confirm against the fixture definition.
    device_manager = await create_device_manager(USER_PARAMS)
    try:
        devices = await device_manager.get_devices()
        assert len(devices) == 1

        diagnostics = device_manager.diagnostic_data()
        assert diagnostics is not None
        # create_device_manager() calls discover_devices() exactly once, and
        # with no cached home data that call also fetches it from the API.
        assert diagnostics.get("discover_devices") == 1
        assert diagnostics.get("fetch_home_data") == 1
    finally:
        # Always close the manager so a failed assertion does not leave the
        # manager (and its MQTT session) open for subsequent tests.
        await device_manager.close()

0 commit comments

Comments
 (0)