diff --git a/roborock/cli.py b/roborock/cli.py index a6959784..4532ca21 100644 --- a/roborock/cli.py +++ b/roborock/cli.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import json import logging from pathlib import Path @@ -12,7 +13,8 @@ from roborock import RoborockException from roborock.containers import DeviceData, HomeDataProduct, LoginData -from roborock.protocol import MessageParser +from roborock.mqtt.roborock_session import create_mqtt_session +from roborock.protocol import MessageParser, create_mqtt_params from roborock.util import run_sync from roborock.version_1_apis.roborock_local_client_v1 import RoborockLocalClientV1 from roborock.version_1_apis.roborock_mqtt_client_v1 import RoborockMqttClientV1 @@ -45,7 +47,8 @@ def validate(self): if self._login_data is None: raise RoborockException("You must login first") - def login_data(self): + def login_data(self) -> LoginData: + """Get the login data.""" self.validate() return self._login_data @@ -90,6 +93,54 @@ async def login(ctx, email, password): context.update(LoginData(user_data=user_data, email=email)) +@click.command() +@click.pass_context +@click.option("--duration", default=10, help="Duration to run the MQTT session in seconds") +@run_sync() +async def session(ctx, duration: int): + context: RoborockContext = ctx.obj + login_data = context.login_data() + + # Discovery devices if not already available + if not login_data.home_data: + await _discover(ctx) + login_data = context.login_data() + if not login_data.home_data or not login_data.home_data.devices: + raise RoborockException("Unable to discover devices") + + all_devices = login_data.home_data.devices + login_data.home_data.received_devices + click.echo(f"Discovered devices: {', '.join([device.name for device in all_devices])}") + + rriot = login_data.user_data.rriot + params = create_mqtt_params(rriot) + + mqtt_session = await create_mqtt_session(params) + click.echo("Starting MQTT session...") + if not mqtt_session.connected: + raise RoborockException("Failed to connect to MQTT broker") + + def on_message(bytes: bytes): + """Callback function to handle incoming MQTT messages.""" + # Decode the first 20 bytes of the message for display + bytes = bytes[:20] + + click.echo(f"Received message: {bytes}...") + + unsubs = [] + for device in all_devices: + device_topic = f"rr/m/o/{rriot.u}/{params.username}/{device.duid}" + unsub = await mqtt_session.subscribe(device_topic, on_message) + unsubs.append(unsub) + + click.echo("MQTT session started. Listening for messages...") + await asyncio.sleep(duration) + + click.echo("Stopping MQTT session...") + for unsub in unsubs: + unsub() + await mqtt_session.close() + + async def _discover(ctx): context: RoborockContext = ctx.obj login_data = context.login_data() @@ -264,6 +315,7 @@ def on_package(packet: Packet): cli.add_command(status) cli.add_command(command) cli.add_command(parser) +cli.add_command(session) def main(): diff --git a/roborock/cloud_api.py b/roborock/cloud_api.py index 68a0a853..0a7f2aa3 100644 --- a/roborock/cloud_api.py +++ b/roborock/cloud_api.py @@ -6,14 +6,19 @@ from abc import ABC from asyncio import Lock from typing import Any -from urllib.parse import urlparse import paho.mqtt.client as mqtt from .api import KEEPALIVE, RoborockClient from .containers import DeviceData, UserData from .exceptions import RoborockException, VacuumError -from .protocol import Decoder, Encoder, create_mqtt_decoder, create_mqtt_encoder, md5hex +from .protocol import ( + Decoder, + Encoder, + create_mqtt_decoder, + create_mqtt_encoder, + create_mqtt_params, +) from .roborock_future import RoborockFuture _LOGGER = logging.getLogger(__name__) @@ -53,25 +58,20 @@ def __init__(self, user_data: UserData, device_info: DeviceData) -> None: if rriot is None: raise RoborockException("Got no rriot data from user_data") RoborockClient.__init__(self, device_info) + mqtt_params = create_mqtt_params(rriot) self._mqtt_user = rriot.u - self._hashed_user = md5hex(self._mqtt_user + ":" + rriot.k)[2:10] - url = urlparse(rriot.r.m) - if not isinstance(url.hostname, str): - raise RoborockException("Url parsing returned an invalid hostname") - self._mqtt_host = str(url.hostname) - self._mqtt_port = url.port - self._mqtt_ssl = url.scheme == "ssl" + self._hashed_user = mqtt_params.username + self._mqtt_host = mqtt_params.host + self._mqtt_port = mqtt_params.port self._mqtt_client = _Mqtt() self._mqtt_client.on_connect = self._mqtt_on_connect self._mqtt_client.on_message = self._mqtt_on_message self._mqtt_client.on_disconnect = self._mqtt_on_disconnect - if self._mqtt_ssl: + if mqtt_params.tls: self._mqtt_client.tls_set() - self._mqtt_password = rriot.s - self._hashed_password = md5hex(self._mqtt_password + ":" + rriot.k)[16:] - self._mqtt_client.username_pw_set(self._hashed_user, self._hashed_password) + self._mqtt_client.username_pw_set(mqtt_params.username, mqtt_params.password) self._waiting_queue: dict[int, RoborockFuture] = {} self._mutex = Lock() self._decoder: Decoder = create_mqtt_decoder(device_info.device.local_key) diff --git a/roborock/mqtt/roborock_session.py b/roborock/mqtt/roborock_session.py index edd8fc7e..e28fe152 100644 --- a/roborock/mqtt/roborock_session.py +++ b/roborock/mqtt/roborock_session.py @@ -116,9 +116,9 @@ async def _run_task(self, start_future: asyncio.Future[None] | None) -> None: _LOGGER.info("MQTT error: %s", err) except asyncio.CancelledError as err: if start_future: - _LOGGER.debug("MQTT loop was cancelled") + _LOGGER.debug("MQTT loop was cancelled while starting") start_future.set_exception(err) - _LOGGER.debug("MQTT loop was cancelled while starting") + _LOGGER.debug("MQTT loop was cancelled") return # Catch exceptions to avoid crashing the loop # and to allow the loop to retry. @@ -171,8 +171,7 @@ async def _mqtt_client(self, params: MqttParams) -> aiomqtt.Client: self._client = None async def _process_message_loop(self, client: aiomqtt.Client) -> None: - _LOGGER.debug("client=%s", client) - _LOGGER.debug("Processing MQTT messages: %s", client.messages) + _LOGGER.debug("Processing MQTT messages") async for message in client.messages: _LOGGER.debug("Received message: %s", message) for listener in self._listeners.get(message.topic.value, []): diff --git a/roborock/protocol.py b/roborock/protocol.py index 9626eb43..5edcb347 100644 --- a/roborock/protocol.py +++ b/roborock/protocol.py @@ -8,6 +8,7 @@ import logging from asyncio import BaseTransport, Lock from collections.abc import Callable +from urllib.parse import urlparse from construct import ( # type: ignore Bytes, @@ -30,7 +31,9 @@ from Crypto.Cipher import AES from Crypto.Util.Padding import pad, unpad -from roborock import BroadcastMessage, RoborockException +from roborock.containers import BroadcastMessage, RRiot +from roborock.exceptions import RoborockException +from roborock.mqtt.session import MqttParams from roborock.roborock_message import RoborockMessage _LOGGER = logging.getLogger(__name__) @@ -361,6 +364,24 @@ def build( BroadcastParser: _Parser = _Parser(_BroadcastMessage, False) +def create_mqtt_params(rriot: RRiot) -> MqttParams: + """Return the MQTT parameters for this user.""" + url = urlparse(rriot.r.m) + if not isinstance(url.hostname, str): + raise RoborockException(f"Url parsing '{rriot.r.m}' returned an invalid hostname") + if not url.port: + raise RoborockException(f"Url parsing '{rriot.r.m}' returned an invalid port") + hashed_user = md5hex(rriot.u + ":" + rriot.k)[2:10] + hashed_password = md5hex(rriot.s + ":" + rriot.k)[16:] + return MqttParams( + host=str(url.hostname), + port=url.port, + tls=(url.scheme == "ssl"), + username=hashed_user, + password=hashed_password, + ) + + Decoder = Callable[[bytes], list[RoborockMessage]] Encoder = Callable[[RoborockMessage], bytes]