diff --git a/roborock/mqtt/roborock_session.py b/roborock/mqtt/roborock_session.py index b35e3910..7f02f216 100644 --- a/roborock/mqtt/roborock_session.py +++ b/roborock/mqtt/roborock_session.py @@ -15,11 +15,11 @@ from contextlib import asynccontextmanager import aiomqtt -from aiomqtt import MqttError, TLSParameters +from aiomqtt import MqttCodeError, MqttError, TLSParameters from roborock.callbacks import CallbackMap -from .session import MqttParams, MqttSession, MqttSessionException +from .session import MqttParams, MqttSession, MqttSessionException, MqttSessionUnauthorized _LOGGER = logging.getLogger(__name__) _MQTT_LOGGER = logging.getLogger(f"{__name__}.aiomqtt") @@ -33,6 +33,16 @@ BACKOFF_MULTIPLIER = 1.5 +class MqttReasonCode: + """MQTT Reason Codes used by Roborock devices. + + This is a subset of paho.mqtt.reasoncodes.ReasonCode where we would like + different error handling behavior. + """ + + RC_ERROR_UNAUTHORIZED = 135 + + class RoborockMqttSession(MqttSession): """An MQTT session for sending and receiving messages. @@ -83,6 +93,10 @@ async def start(self) -> None: self._reconnect_task = loop.create_task(self._run_reconnect_loop(start_future)) try: await start_future + except MqttCodeError as err: + if err.rc == MqttReasonCode.RC_ERROR_UNAUTHORIZED: + raise MqttSessionUnauthorized(f"Authorization error starting MQTT session: {err}") from err + raise MqttSessionException(f"Error starting MQTT session: {err}") from err except MqttError as err: raise MqttSessionException(f"Error starting MQTT session: {err}") from err except Exception as err: diff --git a/roborock/mqtt/session.py b/roborock/mqtt/session.py index f5922d23..02295f1e 100644 --- a/roborock/mqtt/session.py +++ b/roborock/mqtt/session.py @@ -64,4 +64,17 @@ async def close(self) -> None: class MqttSessionException(RoborockException): - """ "Raised when there is an error communicating with MQTT.""" + """Raised when there is an error communicating with MQTT.""" + + +class MqttSessionUnauthorized(RoborockException): + """Raised when there is an authorization error communicating with MQTT. + + This error may be raised in multiple scenarios so there is not a well + defined behavior for how the caller should behave. The two cases are: + - Rate limiting is in effect and the caller should retry after some time. + - The credentials are invalid and the caller needs to obtain new credentials + + However, it is observed that obtaining new credentials may resolve the + issue in both cases. + """ diff --git a/tests/mqtt/test_roborock_session.py b/tests/mqtt/test_roborock_session.py index f3b10139..505ba539 100644 --- a/tests/mqtt/test_roborock_session.py +++ b/tests/mqtt/test_roborock_session.py @@ -12,7 +12,7 @@ import pytest from roborock.mqtt.roborock_session import RoborockMqttSession, create_mqtt_session -from roborock.mqtt.session import MqttParams, MqttSessionException +from roborock.mqtt.session import MqttParams, MqttSessionException, MqttSessionUnauthorized from tests import mqtt_packet from tests.conftest import FakeSocketHandler @@ -366,3 +366,42 @@ async def test_idle_timeout_multiple_callbacks(mock_mqtt_client: AsyncMock) -> N mock_mqtt_client.unsubscribe.assert_called_once_with(topic) await session.close() + + +@pytest.mark.parametrize( + ("side_effect", "expected_exception", "match"), + [ + ( + aiomqtt.MqttError("Connection failed"), + MqttSessionException, + "Error starting MQTT session", + ), + ( + aiomqtt.MqttCodeError(rc=135), + MqttSessionUnauthorized, + "Authorization error starting MQTT session", + ), + ( + aiomqtt.MqttCodeError(rc=128), + MqttSessionException, + "Error starting MQTT session", + ), + ( + ValueError("Unexpected"), + MqttSessionException, + "Unexpected error starting session", + ), + ], +) +async def test_connect_failure( + side_effect: Exception, + expected_exception: type[Exception], + match: str, +) -> None: + """Test connection failure with different exceptions.""" + mock_aenter = AsyncMock() + mock_aenter.side_effect = side_effect + + with patch("roborock.mqtt.roborock_session.aiomqtt.Client.__aenter__", mock_aenter): + with pytest.raises(expected_exception, match=match): + await create_mqtt_session(FAKE_PARAMS)