Skip to content

Commit 17b046c

Browse files
authored
Merge branch 'main' into cli-session
2 parents 18b8efd + e1a9e69 commit 17b046c

File tree

7 files changed

+79
-21
lines changed

7 files changed

+79
-21
lines changed

roborock/cli.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,11 @@ def cli(ctx, debug: int):
6565

6666
@click.command()
6767
@click.option("--email", required=True)
68-
@click.option("--password", required=True)
68+
@click.option(
69+
"--password",
70+
required=False,
71+
help="Password for the Roborock account. If not provided, an email code will be requested.",
72+
)
6973
@click.pass_context
7074
@run_sync()
7175
async def login(ctx, email, password):
@@ -78,7 +82,14 @@ async def login(ctx, email, password):
7882
except RoborockException:
7983
pass
8084
client = RoborockApiClient(email)
81-
user_data = await client.pass_login(password)
85+
if password is not None:
86+
user_data = await client.pass_login(password)
87+
else:
88+
print(f"Requesting code for {email}")
89+
await client.request_code()
90+
code = click.prompt("A code has been sent to your email, please enter the code", type=str)
91+
user_data = await client.code_login(code)
92+
print("Login successful")
8293
context.update(LoginData(user_data=user_data, email=email))
8394

8495

roborock/cloud_api.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from .api import KEEPALIVE, RoborockClient
1313
from .containers import DeviceData, UserData
1414
from .exceptions import RoborockException, VacuumError
15-
from .protocol import MessageParser, create_mqtt_params
15+
from .protocol import Decoder, Encoder, MessageParser, create_mqtt_decoder, create_mqtt_encoder, create_mqtt_params, md5hex
1616
from .roborock_future import RoborockFuture
1717

1818
_LOGGER = logging.getLogger(__name__)
@@ -68,6 +68,8 @@ def __init__(self, user_data: UserData, device_info: DeviceData) -> None:
6868
self._mqtt_client.username_pw_set(mqtt_params.username, mqtt_params.password)
6969
self._waiting_queue: dict[int, RoborockFuture] = {}
7070
self._mutex = Lock()
71+
self._decoder: Decoder = create_mqtt_decoder(device_info.device.local_key)
72+
self._encoder: Encoder = create_mqtt_encoder(device_info.device.local_key)
7173

7274
def _mqtt_on_connect(self, *args, **kwargs):
7375
_, __, ___, rc, ____ = args
@@ -96,7 +98,7 @@ def _mqtt_on_connect(self, *args, **kwargs):
9698
def _mqtt_on_message(self, *args, **kwargs):
9799
client, __, msg = args
98100
try:
99-
messages, _ = MessageParser.parse(msg.payload, local_key=self.device_info.device.local_key)
101+
messages = self._decoder(msg.payload)
100102
super().on_message_received(messages)
101103
except Exception as ex:
102104
self._logger.exception(ex)

roborock/local_api.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from . import DeviceData
1313
from .api import RoborockClient
1414
from .exceptions import RoborockConnectionException, RoborockException
15-
from .protocol import MessageParser
15+
from .protocol import Decoder, Encoder, create_local_decoder, create_local_encoder
1616
from .roborock_message import RoborockMessage, RoborockMessageProtocol
1717

1818
_LOGGER = logging.getLogger(__name__)
@@ -44,20 +44,18 @@ def __init__(self, device_data: DeviceData):
4444
self.host = device_data.host
4545
self._batch_structs: list[RoborockMessage] = []
4646
self._executing = False
47-
self.remaining = b""
4847
self.transport: Transport | None = None
4948
self._mutex = Lock()
5049
self.keep_alive_task: TimerHandle | None = None
5150
RoborockClient.__init__(self, device_data)
5251
self._local_protocol = _LocalProtocol(self._data_received, self._connection_lost)
52+
self._encoder: Encoder = create_local_encoder(device_data.device.local_key)
53+
self._decoder: Decoder = create_local_decoder(device_data.device.local_key)
5354

5455
def _data_received(self, message):
5556
"""Called when data is received from the transport."""
56-
if self.remaining:
57-
message = self.remaining + message
58-
self.remaining = b""
59-
parser_msg, self.remaining = MessageParser.parse(message, local_key=self.device_info.device.local_key)
60-
self.on_message_received(parser_msg)
57+
parsed_msg = self._decoder(message)
58+
self.on_message_received(parsed_msg)
6159

6260
def _connection_lost(self, exc: Exception | None):
6361
"""Called when the transport connection is lost."""

roborock/protocol.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,3 +380,56 @@ def create_mqtt_params(rriot: RRiot) -> MqttParams:
380380
username=hashed_user,
381381
password=hashed_password,
382382
)
383+
384+
385+
Decoder = Callable[[bytes], list[RoborockMessage]]
386+
Encoder = Callable[[RoborockMessage], bytes]
387+
388+
389+
def create_mqtt_decoder(local_key: str) -> Decoder:
390+
"""Create a decoder for MQTT messages."""
391+
392+
def decode(data: bytes) -> list[RoborockMessage]:
393+
"""Parse the given data into Roborock messages."""
394+
messages, _ = MessageParser.parse(data, local_key)
395+
return messages
396+
397+
return decode
398+
399+
400+
def create_mqtt_encoder(local_key: str) -> Encoder:
401+
"""Create an encoder for MQTT messages."""
402+
403+
def encode(messages: RoborockMessage) -> bytes:
404+
"""Build the given Roborock messages into a byte string."""
405+
return MessageParser.build(messages, local_key, prefixed=False)
406+
407+
return encode
408+
409+
410+
def create_local_decoder(local_key: str) -> Decoder:
411+
"""Create a decoder for local API messages."""
412+
413+
# This buffer is used to accumulate bytes until a complete message can be parsed.
414+
# It is defined outside the decode function to maintain state across calls.
415+
buffer: bytes = b""
416+
417+
def decode(bytes: bytes) -> list[RoborockMessage]:
418+
"""Parse the given data into Roborock messages."""
419+
nonlocal buffer
420+
buffer += bytes
421+
parsed_messages, remaining = MessageParser.parse(buffer, local_key=local_key)
422+
buffer = remaining
423+
return parsed_messages
424+
425+
return decode
426+
427+
428+
def create_local_encoder(local_key: str) -> Encoder:
429+
"""Create an encoder for local API messages."""
430+
431+
def encode(message: RoborockMessage) -> bytes:
432+
"""Called when data is sent to the transport."""
433+
return MessageParser.build(message, local_key=local_key)
434+
435+
return encode

roborock/version_1_apis/roborock_local_client_v1.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
from .. import CommandVacuumError, DeviceData, RoborockCommand, RoborockException
66
from ..exceptions import VacuumError
7-
from ..protocol import MessageParser
87
from ..roborock_message import MessageRetry, RoborockMessage, RoborockMessageProtocol
98
from ..util import RoborockLoggerAdapter
109
from .roborock_client_v1 import COMMANDS_SECURED, RoborockClientV1
@@ -57,8 +56,7 @@ async def send_message(self, roborock_message: RoborockMessage):
5756
response_protocol = RoborockMessageProtocol.GENERAL_REQUEST
5857
if request_id is None:
5958
raise RoborockException(f"Failed build message {roborock_message}")
60-
local_key = self.device_info.device.local_key
61-
msg = MessageParser.build(roborock_message, local_key=local_key)
59+
msg = self._encoder(roborock_message)
6260
if method:
6361
self._logger.debug(f"id={request_id} Requesting method {method} with {params}")
6462
# Send the command to the Roborock device

roborock/version_1_apis/roborock_mqtt_client_v1.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from ..containers import DeviceData, UserData
1212
from ..exceptions import CommandVacuumError, RoborockException, VacuumError
13-
from ..protocol import MessageParser, Utils
13+
from ..protocol import Utils
1414
from ..roborock_message import (
1515
RoborockMessage,
1616
RoborockMessageProtocol,
@@ -47,9 +47,7 @@ async def send_message(self, roborock_message: RoborockMessage):
4747
response_protocol = (
4848
RoborockMessageProtocol.MAP_RESPONSE if method in COMMANDS_SECURED else RoborockMessageProtocol.RPC_RESPONSE
4949
)
50-
51-
local_key = self.device_info.device.local_key
52-
msg = MessageParser.build(roborock_message, local_key, False)
50+
msg = self._encoder(roborock_message)
5351
self._logger.debug(f"id={request_id} Requesting method {method} with {params}")
5452
async_response = self._async_response(request_id, response_protocol)
5553
self._send_msg_raw(msg)

roborock/version_a01_apis/roborock_mqtt_client_a01.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from roborock.cloud_api import RoborockMqttClient
1010
from roborock.containers import DeviceData, RoborockCategory, UserData
1111
from roborock.exceptions import RoborockException
12-
from roborock.protocol import MessageParser
1312
from roborock.roborock_message import (
1413
RoborockDyadDataProtocol,
1514
RoborockMessage,
@@ -43,8 +42,7 @@ async def send_message(self, roborock_message: RoborockMessage):
4342
await self.validate_connection()
4443
response_protocol = RoborockMessageProtocol.RPC_RESPONSE
4544

46-
local_key = self.device_info.device.local_key
47-
m = MessageParser.build(roborock_message, local_key, prefixed=False)
45+
m = self._encoder(roborock_message)
4846
# self._logger.debug(f"id={request_id} Requesting method {method} with {params}")
4947
payload = json.loads(unpad(roborock_message.payload, AES.block_size))
5048
futures = []

0 commit comments

Comments
 (0)