Skip to content

Commit 1013cb5

Browse files
authored
chore: inheritance fixes and simplifications (#282)
1 parent 39a8661 commit 1013cb5

File tree

8 files changed

+80
-64
lines changed

8 files changed

+80
-64
lines changed

roborock/api.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import logging
88
import secrets
99
import time
10+
from abc import ABC, abstractmethod
1011
from typing import Any
1112

1213
from .containers import (
@@ -21,17 +22,21 @@
2122
RoborockMessage,
2223
)
2324
from .roborock_typing import RoborockCommand
24-
from .util import RoborockLoggerAdapter, get_next_int, get_running_loop_or_create_one
25+
from .util import get_next_int, get_running_loop_or_create_one
2526

2627
_LOGGER = logging.getLogger(__name__)
2728
KEEPALIVE = 60
2829

2930

30-
class RoborockClient:
31-
def __init__(self, endpoint: str, device_info: DeviceData, queue_timeout: int = 4) -> None:
31+
class RoborockClient(ABC):
32+
"""Roborock client base class."""
33+
34+
_logger: logging.LoggerAdapter
35+
36+
def __init__(self, device_info: DeviceData, queue_timeout: int = 4) -> None:
37+
"""Initialize RoborockClient."""
3238
self.event_loop = get_running_loop_or_create_one()
3339
self.device_info = device_info
34-
self._endpoint = endpoint
3540
self._nonce = secrets.token_bytes(16)
3641
self._waiting_queue: dict[int, RoborockFuture] = {}
3742
self._last_device_msg_in = time.monotonic()
@@ -40,7 +45,6 @@ def __init__(self, endpoint: str, device_info: DeviceData, queue_timeout: int =
4045
self._diagnostic_data: dict[str, dict[str, Any]] = {
4146
"misc_info": {"Nonce": base64.b64encode(self._nonce).decode("utf-8")}
4247
}
43-
self._logger = RoborockLoggerAdapter(device_info.device.name, _LOGGER)
4448
self.is_available: bool = True
4549
self.queue_timeout = queue_timeout
4650

@@ -57,17 +61,21 @@ async def async_release(self) -> None:
5761
def diagnostic_data(self) -> dict:
5862
return self._diagnostic_data
5963

64+
@abstractmethod
6065
async def async_connect(self):
61-
raise NotImplementedError
66+
"""Connect to the Roborock device."""
6267

68+
@abstractmethod
6369
def sync_disconnect(self) -> Any:
64-
raise NotImplementedError
70+
"""Disconnect from the Roborock device."""
6571

72+
@abstractmethod
6673
async def async_disconnect(self) -> Any:
67-
raise NotImplementedError
74+
"""Disconnect from the Roborock device."""
6875

76+
@abstractmethod
6977
def on_message_received(self, messages: list[RoborockMessage]) -> None:
70-
raise NotImplementedError
78+
"""Handle received incoming messages from the device."""
7179

7280
def on_connection_lost(self, exc: Exception | None) -> None:
7381
self._last_disconnection = time.monotonic()
@@ -102,7 +110,7 @@ def _async_response(self, request_id: int, protocol_id: int = 0) -> Any:
102110
queue = RoborockFuture(protocol_id)
103111
if request_id in self._waiting_queue:
104112
new_id = get_next_int(10000, 32767)
105-
_LOGGER.warning(
113+
self._logger.warning(
106114
"Attempting to create a future with an existing id %s (%s)... New id is %s. "
107115
"Code may not function properly.",
108116
request_id,
@@ -113,12 +121,14 @@ def _async_response(self, request_id: int, protocol_id: int = 0) -> Any:
113121
self._waiting_queue[request_id] = queue
114122
return asyncio.ensure_future(self._wait_response(request_id, queue))
115123

124+
@abstractmethod
116125
async def send_message(self, roborock_message: RoborockMessage):
117-
raise NotImplementedError
126+
"""Send a message to the Roborock device."""
118127

128+
@abstractmethod
119129
async def _send_command(
120130
self,
121131
method: RoborockCommand | str,
122132
params: list | dict | int | None = None,
123133
):
124-
raise NotImplementedError
134+
"""Send a command to the Roborock device."""

roborock/cloud_api.py

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from __future__ import annotations
22

3-
import base64
43
import logging
54
import threading
65
import uuid
6+
from abc import ABC
77
from asyncio import Lock
88
from typing import Any
99
from urllib.parse import urlparse
@@ -13,29 +13,24 @@
1313
from .api import KEEPALIVE, RoborockClient
1414
from .containers import DeviceData, UserData
1515
from .exceptions import RoborockException, VacuumError
16-
from .protocol import MessageParser, Utils, md5hex
16+
from .protocol import MessageParser, md5hex
1717
from .roborock_future import RoborockFuture
18-
from .roborock_message import RoborockMessage
19-
from .roborock_typing import RoborockCommand
20-
from .util import RoborockLoggerAdapter
2118

2219
_LOGGER = logging.getLogger(__name__)
2320
CONNECT_REQUEST_ID = 0
2421
DISCONNECT_REQUEST_ID = 1
2522

2623

27-
class RoborockMqttClient(RoborockClient, mqtt.Client):
24+
class RoborockMqttClient(RoborockClient, mqtt.Client, ABC):
2825
_thread: threading.Thread
2926
_client_id: str
3027

3128
def __init__(self, user_data: UserData, device_info: DeviceData, queue_timeout: int = 10) -> None:
3229
rriot = user_data.rriot
3330
if rriot is None:
3431
raise RoborockException("Got no rriot data from user_data")
35-
endpoint = base64.b64encode(Utils.md5(rriot.k.encode())[8:14]).decode()
36-
RoborockClient.__init__(self, endpoint, device_info, queue_timeout)
32+
RoborockClient.__init__(self, device_info, queue_timeout)
3733
mqtt.Client.__init__(self, protocol=mqtt.MQTTv5)
38-
self._logger = RoborockLoggerAdapter(device_info.device.name, _LOGGER)
3934
self._mqtt_user = rriot.u
4035
self._hashed_user = md5hex(self._mqtt_user + ":" + rriot.k)[2:10]
4136
url = urlparse(rriot.r.m)
@@ -49,7 +44,6 @@ def __init__(self, user_data: UserData, device_info: DeviceData, queue_timeout:
4944
self._mqtt_password = rriot.s
5045
self._hashed_password = md5hex(self._mqtt_password + ":" + rriot.k)[16:]
5146
super().username_pw_set(self._hashed_user, self._hashed_password)
52-
self._endpoint = base64.b64encode(Utils.md5(rriot.k.encode())[8:14]).decode()
5347
self._waiting_queue: dict[int, RoborockFuture] = {}
5448
self._mutex = Lock()
5549
self.update_client_id()
@@ -164,13 +158,3 @@ def _send_msg_raw(self, msg: bytes) -> None:
164158
info = self.publish(f"rr/m/i/{self._mqtt_user}/{self._hashed_user}/{self.device_info.device.duid}", msg)
165159
if info.rc != mqtt.MQTT_ERR_SUCCESS:
166160
raise RoborockException(f"Failed to publish ({mqtt.error_string(info.rc)})")
167-
168-
async def send_message(self, roborock_message: RoborockMessage):
169-
raise NotImplementedError
170-
171-
async def _send_command(
172-
self,
173-
method: RoborockCommand | str,
174-
params: list | dict | int | None = None,
175-
):
176-
raise NotImplementedError

roborock/local_api.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import logging
5+
from abc import ABC
56
from asyncio import Lock, TimerHandle, Transport
67

78
import async_timeout
@@ -11,14 +12,15 @@
1112
from .exceptions import RoborockConnectionException, RoborockException
1213
from .protocol import MessageParser
1314
from .roborock_message import RoborockMessage, RoborockMessageProtocol
14-
from .roborock_typing import RoborockCommand
15-
from .util import RoborockLoggerAdapter
1615

1716
_LOGGER = logging.getLogger(__name__)
1817

1918

20-
class RoborockLocalClient(RoborockClient, asyncio.Protocol):
19+
class RoborockLocalClient(RoborockClient, asyncio.Protocol, ABC):
20+
"""Roborock local client base class."""
21+
2122
def __init__(self, device_data: DeviceData, queue_timeout: int = 4):
23+
"""Initialize the Roborock local client."""
2224
if device_data.host is None:
2325
raise RoborockException("Host is required")
2426
self.host = device_data.host
@@ -28,8 +30,7 @@ def __init__(self, device_data: DeviceData, queue_timeout: int = 4):
2830
self.transport: Transport | None = None
2931
self._mutex = Lock()
3032
self.keep_alive_task: TimerHandle | None = None
31-
self._logger = RoborockLoggerAdapter(device_data.device.name, _LOGGER)
32-
RoborockClient.__init__(self, "abc", device_data, queue_timeout)
33+
RoborockClient.__init__(self, device_data, queue_timeout)
3334

3435
def data_received(self, message):
3536
if self.remaining:
@@ -107,13 +108,6 @@ async def ping(self) -> None:
107108
)
108109
)
109110

110-
async def _send_command(
111-
self,
112-
method: RoborockCommand | str,
113-
params: list | dict | int | None = None,
114-
):
115-
raise NotImplementedError
116-
117111
def _send_msg_raw(self, data: bytes):
118112
try:
119113
if not self.transport:

roborock/version_1_apis/roborock_client_v1.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import math
55
import struct
66
import time
7+
from abc import ABC
78
from collections.abc import Callable, Coroutine
89
from typing import Any, TypeVar, final
910

@@ -142,19 +143,22 @@ class ListenerModel:
142143
cache: dict[CacheableAttribute, AttributeCache]
143144

144145

145-
class RoborockClientV1(RoborockClient):
146+
class RoborockClientV1(RoborockClient, ABC):
147+
"""Roborock client base class for version 1 devices."""
148+
146149
_listeners: dict[str, ListenerModel] = {}
147150

148-
def __init__(self, device_info: DeviceData, logger, endpoint: str):
149-
super().__init__(endpoint, device_info)
151+
def __init__(self, device_info: DeviceData, endpoint: str):
152+
"""Initializes the Roborock client."""
153+
super().__init__(device_info)
150154
self._status_type: type[Status] = ModelStatus.get(device_info.model, S7MaxVStatus)
151-
self._logger = logger
152155
self.cache: dict[CacheableAttribute, AttributeCache] = {
153156
cacheable_attribute: AttributeCache(attr, self) for cacheable_attribute, attr in get_cache_map().items()
154157
}
155158
if device_info.device.duid not in self._listeners:
156159
self._listeners[device_info.device.duid] = ListenerModel({}, self.cache)
157160
self.listener_model = self._listeners[device_info.device.duid]
161+
self._endpoint = endpoint
158162

159163
def release(self):
160164
super().release()

roborock/version_1_apis/roborock_local_client_v1.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,25 @@
1+
import logging
2+
13
from roborock.local_api import RoborockLocalClient
24

35
from .. import CommandVacuumError, DeviceData, RoborockCommand, RoborockException
46
from ..exceptions import VacuumError
57
from ..protocol import MessageParser
68
from ..roborock_message import MessageRetry, RoborockMessage, RoborockMessageProtocol
9+
from ..util import RoborockLoggerAdapter
710
from .roborock_client_v1 import COMMANDS_SECURED, RoborockClientV1
811

12+
_LOGGER = logging.getLogger(__name__)
13+
914

1015
class RoborockLocalClientV1(RoborockLocalClient, RoborockClientV1):
16+
"""Roborock local client for v1 devices."""
17+
1118
def __init__(self, device_data: DeviceData, queue_timeout: int = 4):
19+
"""Initialize the Roborock local client."""
1220
RoborockLocalClient.__init__(self, device_data, queue_timeout)
13-
RoborockClientV1.__init__(self, device_data, self._logger, "abc")
21+
RoborockClientV1.__init__(self, device_data, "abc")
22+
self._logger = RoborockLoggerAdapter(device_data.device.name, _LOGGER)
1423

1524
def build_roborock_message(
1625
self, method: RoborockCommand | str, params: list | dict | int | None = None

roborock/version_1_apis/roborock_mqtt_client_v1.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import base64
2+
import logging
23

3-
import paho.mqtt.client as mqtt
44
from vacuum_map_parser_base.config.color import ColorsPalette
55
from vacuum_map_parser_base.config.image_config import ImageConfig
66
from vacuum_map_parser_base.config.size import Sizes
@@ -16,23 +16,25 @@
1616
RoborockMessageProtocol,
1717
)
1818
from ..roborock_typing import RoborockCommand
19+
from ..util import RoborockLoggerAdapter
1920
from .roborock_client_v1 import COMMANDS_SECURED, CUSTOM_COMMANDS, RoborockClientV1
2021

22+
_LOGGER = logging.getLogger(__name__)
23+
2124

2225
class RoborockMqttClientV1(RoborockMqttClient, RoborockClientV1):
26+
"""Roborock mqtt client for v1 devices."""
27+
2328
def __init__(self, user_data: UserData, device_info: DeviceData, queue_timeout: int = 10) -> None:
29+
"""Initialize the Roborock mqtt client."""
2430
rriot = user_data.rriot
2531
if rriot is None:
2632
raise RoborockException("Got no rriot data from user_data")
2733
endpoint = base64.b64encode(Utils.md5(rriot.k.encode())[8:14]).decode()
2834

2935
RoborockMqttClient.__init__(self, user_data, device_info, queue_timeout)
30-
RoborockClientV1.__init__(self, device_info, self._logger, endpoint)
31-
32-
def _send_msg_raw(self, msg: bytes) -> None:
33-
info = self.publish(f"rr/m/i/{self._mqtt_user}/{self._hashed_user}/{self.device_info.device.duid}", msg)
34-
if info.rc != mqtt.MQTT_ERR_SUCCESS:
35-
raise RoborockException(f"Failed to publish ({mqtt.error_string(info.rc)})")
36+
RoborockClientV1.__init__(self, device_info, endpoint)
37+
self._logger = RoborockLoggerAdapter(device_info.device.name, _LOGGER)
3638

3739
async def send_message(self, roborock_message: RoborockMessage):
3840
await self.validate_connection()

roborock/version_a01_apis/roborock_client_a01.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import dataclasses
22
import json
3+
import logging
34
import typing
5+
from abc import ABC, abstractmethod
46
from collections.abc import Callable
57
from datetime import time
68

@@ -38,6 +40,8 @@
3840
RoborockZeoProtocol,
3941
)
4042

43+
_LOGGER = logging.getLogger(__name__)
44+
4145

4246
@dataclasses.dataclass
4347
class A01ProtocolCacheEntry:
@@ -101,9 +105,12 @@ class A01ProtocolCacheEntry:
101105
}
102106

103107

104-
class RoborockClientA01(RoborockClient):
105-
def __init__(self, endpoint: str, device_info: DeviceData, category: RoborockCategory, queue_timeout: int = 4):
106-
super().__init__(endpoint, device_info, queue_timeout)
108+
class RoborockClientA01(RoborockClient, ABC):
109+
"""Roborock client base class for A01 devices."""
110+
111+
def __init__(self, device_info: DeviceData, category: RoborockCategory, queue_timeout: int = 4):
112+
"""Initialize the Roborock client."""
113+
super().__init__(device_info, queue_timeout)
107114
self.category = category
108115

109116
def on_message_received(self, messages: list[RoborockMessage]) -> None:
@@ -137,8 +144,8 @@ def on_message_received(self, messages: list[RoborockMessage]) -> None:
137144
if queue and queue.protocol == protocol:
138145
queue.set_result(converted_response)
139146

147+
@abstractmethod
140148
async def update_values(
141149
self, dyad_data_protocols: list[RoborockDyadDataProtocol | RoborockZeoProtocol]
142150
) -> dict[RoborockDyadDataProtocol | RoborockZeoProtocol, typing.Any]:
143151
"""This should handle updating for each given protocol."""
144-
raise NotImplementedError

0 commit comments

Comments
 (0)