Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions roborock/mqtt/roborock_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
from contextlib import asynccontextmanager

import aiomqtt
from aiomqtt import MqttError, TLSParameters
from aiomqtt import MqttCodeError, MqttError, TLSParameters

from roborock.callbacks import CallbackMap

from .session import MqttParams, MqttSession, MqttSessionException
from .session import MqttParams, MqttSession, MqttSessionException, MqttSessionUnauthorized

_LOGGER = logging.getLogger(__name__)
_MQTT_LOGGER = logging.getLogger(f"{__name__}.aiomqtt")
Expand All @@ -33,6 +33,16 @@
BACKOFF_MULTIPLIER = 1.5


class MqttReasonCode:
"""MQTT Reason Codes used by Roborock devices.

This is a subset of paho.mqtt.reasoncodes.ReasonCode where we would like
different error handling behavior.
"""

RC_ERROR_UNAUTHORIZED = 135


class RoborockMqttSession(MqttSession):
"""An MQTT session for sending and receiving messages.

Expand Down Expand Up @@ -83,6 +93,10 @@ async def start(self) -> None:
self._reconnect_task = loop.create_task(self._run_reconnect_loop(start_future))
try:
await start_future
except MqttCodeError as err:
if err.rc == MqttReasonCode.RC_ERROR_UNAUTHORIZED:
raise MqttSessionUnauthorized(f"Authorization error starting MQTT session: {err}") from err
raise MqttSessionException(f"Error starting MQTT session: {err}") from err
except MqttError as err:
raise MqttSessionException(f"Error starting MQTT session: {err}") from err
except Exception as err:
Expand Down
15 changes: 14 additions & 1 deletion roborock/mqtt/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,17 @@ async def close(self) -> None:


class MqttSessionException(RoborockException):
""" "Raised when there is an error communicating with MQTT."""
"""Raised when there is an error communicating with MQTT."""


class MqttSessionUnauthorized(RoborockException):
"""Raised when there is an authorization error communicating with MQTT.

This error may be raised in multiple scenarios so there is not a well
defined behavior for how the caller should behave. The two cases are:
- Rate limiting is in effect and the caller should retry after some time.
- The credentials are invalid and the caller needs to obtain new credentials

However, it is observed that obtaining new credentials may resolve the
issue in both cases.
"""
41 changes: 40 additions & 1 deletion tests/mqtt/test_roborock_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import pytest

from roborock.mqtt.roborock_session import RoborockMqttSession, create_mqtt_session
from roborock.mqtt.session import MqttParams, MqttSessionException
from roborock.mqtt.session import MqttParams, MqttSessionException, MqttSessionUnauthorized
from tests import mqtt_packet
from tests.conftest import FakeSocketHandler

Expand Down Expand Up @@ -366,3 +366,42 @@ async def test_idle_timeout_multiple_callbacks(mock_mqtt_client: AsyncMock) -> N
mock_mqtt_client.unsubscribe.assert_called_once_with(topic)

await session.close()


@pytest.mark.parametrize(
("side_effect", "expected_exception", "match"),
[
(
aiomqtt.MqttError("Connection failed"),
MqttSessionException,
"Error starting MQTT session",
),
(
aiomqtt.MqttCodeError(rc=135),
MqttSessionUnauthorized,
"Authorization error starting MQTT session",
),
(
aiomqtt.MqttCodeError(rc=128),
MqttSessionException,
"Error starting MQTT session",
),
(
ValueError("Unexpected"),
MqttSessionException,
"Unexpected error starting session",
),
],
)
async def test_connect_failure(
side_effect: Exception,
expected_exception: type[Exception],
match: str,
) -> None:
"""Test connection failure with different exceptions."""
mock_aenter = AsyncMock()
mock_aenter.side_effect = side_effect

with patch("roborock.mqtt.roborock_session.aiomqtt.Client.__aenter__", mock_aenter):
with pytest.raises(expected_exception, match=match):
await create_mqtt_session(FAKE_PARAMS)