diff --git a/roborock/devices/a01_channel.py b/roborock/devices/a01_channel.py index ae1a5d18..39818a12 100644 --- a/roborock/devices/a01_channel.py +++ b/roborock/devices/a01_channel.py @@ -1,6 +1,7 @@ """Thin wrapper around the MQTT channel for Roborock A01 devices.""" import asyncio +import json import logging from typing import Any, overload @@ -54,6 +55,13 @@ async def send_decoded_command( await mqtt_channel.publish(roborock_message) return {} + if isinstance(query_values, str): + try: + query_values = json.loads(query_values) + except ValueError: + _LOGGER.warning("Failed to parse query values: %s", query_values) + return {} + # Merge any results together than contain the requested data. This # does not use a future since it needs to merge results across responses. # This could be simplified if we can assume there is a single response. diff --git a/tests/test_a01_api.py b/tests/test_a01_api.py index f582f5f2..f412e4bb 100644 --- a/tests/test_a01_api.py +++ b/tests/test_a01_api.py @@ -3,7 +3,7 @@ from collections.abc import AsyncGenerator from queue import Queue from typing import Any -from unittest.mock import patch +from unittest.mock import AsyncMock, MagicMock, patch import paho.mqtt.client as mqtt import pytest @@ -307,3 +307,38 @@ async def test_future_timeout( with patch("roborock.roborock_future.asyncio.timeout", side_effect=asyncio.TimeoutError): data = await connected_a01_mqtt_client.update_values([RoborockZeoProtocol.STATE]) assert data.get(RoborockZeoProtocol.STATE) is None + + +async def test_send_decoded_command_handles_stringified_query() -> None: + """Test that send_decoded_command handles ID_QUERY as a stringified list.""" + from roborock.devices.a01_channel import send_decoded_command + from roborock.devices.mqtt_channel import MqttChannel + from roborock.roborock_message import RoborockDyadDataProtocol, RoborockMessage, RoborockMessageProtocol + + channel = MagicMock(spec=MqttChannel) + channel.publish = AsyncMock() + + captured_callback = None + + async def mock_subscribe(callback): + nonlocal captured_callback + captured_callback = callback + return lambda: None + + channel.subscribe = AsyncMock(side_effect=mock_subscribe) + + protocol_id = 101 + params = {RoborockDyadDataProtocol.ID_QUERY: str([protocol_id])} + + task = asyncio.create_task(send_decoded_command(channel, params)) + await asyncio.sleep(0) + + response_data = {"dps": {str(protocol_id): 123}} + payload = pad(json.dumps(response_data).encode("utf-8"), AES.block_size) + message = RoborockMessage(protocol=RoborockMessageProtocol.RPC_RESPONSE, payload=payload) + + if captured_callback: + captured_callback(message) + + result = await task + assert result == {protocol_id: 123}