diff --git a/roborock/devices/local_channel.py b/roborock/devices/local_channel.py index 13906893..52599628 100644 --- a/roborock/devices/local_channel.py +++ b/roborock/devices/local_channel.py @@ -11,6 +11,7 @@ from roborock.roborock_message import RoborockMessage from .channel import Channel +from .pending import PendingRpcs _LOGGER = logging.getLogger(__name__) _PORT = 58867 @@ -47,10 +48,9 @@ def __init__(self, host: str, local_key: str): self._is_connected = False # RPC support - self._waiting_queue: dict[int, asyncio.Future[RoborockMessage]] = {} + self._pending_rpcs: PendingRpcs[int, RoborockMessage] = PendingRpcs() self._decoder: Decoder = create_local_decoder(local_key) self._encoder: Encoder = create_local_encoder(local_key) - self._queue_lock = asyncio.Lock() @property def is_connected(self) -> bool: @@ -114,11 +114,7 @@ async def _resolve_future_with_lock(self, message: RoborockMessage) -> None: if (request_id := message.get_request_id()) is None: _LOGGER.debug("Received message with no request_id") return - async with self._queue_lock: - if (future := self._waiting_queue.pop(request_id, None)) is not None: - future.set_result(message) - else: - _LOGGER.debug("Received message with no waiting handler: request_id=%s", request_id) + await self._pending_rpcs.resolve(request_id, message) async def send_message(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage: """Send a command message and wait for the response message.""" @@ -132,24 +128,17 @@ async def send_message(self, message: RoborockMessage, timeout: float = 10.0) -> _LOGGER.exception("Error getting request_id from message: %s", err) raise RoborockException(f"Invalid message format, Message must have a request_id: {err}") from err - future: asyncio.Future[RoborockMessage] = asyncio.Future() - async with self._queue_lock: - if request_id in self._waiting_queue: - raise RoborockException(f"Request ID {request_id} already pending, cannot send command") - self._waiting_queue[request_id] = future - + future: asyncio.Future[RoborockMessage] = await self._pending_rpcs.start(request_id) try: encoded_msg = self._encoder(message) self._transport.write(encoded_msg) return await asyncio.wait_for(future, timeout=timeout) except asyncio.TimeoutError as ex: - async with self._queue_lock: - self._waiting_queue.pop(request_id, None) + await self._pending_rpcs.pop(request_id) raise RoborockException(f"Command timed out after {timeout}s") from ex except Exception: logging.exception("Uncaught error sending command") - async with self._queue_lock: - self._waiting_queue.pop(request_id, None) + await self._pending_rpcs.pop(request_id) raise diff --git a/roborock/devices/mqtt_channel.py b/roborock/devices/mqtt_channel.py index c7be8c12..27882477 100644 --- a/roborock/devices/mqtt_channel.py +++ b/roborock/devices/mqtt_channel.py @@ -12,6 +12,7 @@ from roborock.roborock_message import RoborockMessage from .channel import Channel +from .pending import PendingRpcs _LOGGER = logging.getLogger(__name__) @@ -31,10 +32,9 @@ def __init__(self, mqtt_session: MqttSession, duid: str, local_key: str, rriot: self._mqtt_params = mqtt_params # RPC support - self._waiting_queue: dict[int, asyncio.Future[RoborockMessage]] = {} + self._pending_rpcs: PendingRpcs[int, RoborockMessage] = PendingRpcs() self._decoder = create_mqtt_decoder(local_key) self._encoder = create_mqtt_encoder(local_key) - self._queue_lock = asyncio.Lock() self._mqtt_unsub: Callable[[], None] | None = None @property @@ -89,11 +89,7 @@ async def _resolve_future_with_lock(self, message: RoborockMessage) -> None: if (request_id := message.get_request_id()) is None: _LOGGER.debug("Received message with no request_id") return - async with self._queue_lock: - if (future := self._waiting_queue.pop(request_id, None)) is not None: - future.set_result(message) - else: - _LOGGER.debug("Received message with no waiting handler: request_id=%s", request_id) + await self._pending_rpcs.resolve(request_id, message) async def send_message(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage: """Send a command message and wait for the response message. @@ -107,11 +103,7 @@ async def send_message(self, message: RoborockMessage, timeout: float = 10.0) -> _LOGGER.exception("Error getting request_id from message: %s", err) raise RoborockException(f"Invalid message format, Message must have a request_id: {err}") from err - future: asyncio.Future[RoborockMessage] = asyncio.Future() - async with self._queue_lock: - if request_id in self._waiting_queue: - raise RoborockException(f"Request ID {request_id} already pending, cannot send command") - self._waiting_queue[request_id] = future + future: asyncio.Future[RoborockMessage] = await self._pending_rpcs.start(request_id) try: encoded_msg = self._encoder(message) @@ -120,13 +112,11 @@ async def send_message(self, message: RoborockMessage, timeout: float = 10.0) -> return await asyncio.wait_for(future, timeout=timeout) except asyncio.TimeoutError as ex: - async with self._queue_lock: - self._waiting_queue.pop(request_id, None) + await self._pending_rpcs.pop(request_id) raise RoborockException(f"Command timed out after {timeout}s") from ex except Exception: logging.exception("Uncaught error sending command") - async with self._queue_lock: - self._waiting_queue.pop(request_id, None) + await self._pending_rpcs.pop(request_id) raise diff --git a/roborock/devices/pending.py b/roborock/devices/pending.py new file mode 100644 index 00000000..d1ab734d --- /dev/null +++ b/roborock/devices/pending.py @@ -0,0 +1,45 @@ +"""Module for managing pending RPCs.""" + +import asyncio +import logging +from typing import Generic, TypeVar + +from roborock.exceptions import RoborockException + +_LOGGER = logging.getLogger(__name__) + + +K = TypeVar("K") +V = TypeVar("V") + + +class PendingRpcs(Generic[K, V]): + """Manage pending RPCs.""" + + def __init__(self) -> None: + """Initialize the pending RPCs.""" + self._queue_lock = asyncio.Lock() + self._waiting_queue: dict[K, asyncio.Future[V]] = {} + + async def start(self, key: K) -> asyncio.Future[V]: + """Start the pending RPCs.""" + future: asyncio.Future[V] = asyncio.Future() + async with self._queue_lock: + if key in self._waiting_queue: + raise RoborockException(f"Request ID {key} already pending, cannot send command") + self._waiting_queue[key] = future + return future + + async def pop(self, key: K) -> None: + """Pop a pending RPC.""" + async with self._queue_lock: + if (future := self._waiting_queue.pop(key, None)) is not None: + future.cancel() + + async def resolve(self, key: K, value: V) -> None: + """Resolve waiting future with proper locking.""" + async with self._queue_lock: + if (future := self._waiting_queue.pop(key, None)) is not None: + future.set_result(value) + else: + _LOGGER.debug("Received unsolicited message: %s", key) diff --git a/tests/devices/test_pending.py b/tests/devices/test_pending.py new file mode 100644 index 00000000..d8ea1864 --- /dev/null +++ b/tests/devices/test_pending.py @@ -0,0 +1,75 @@ +"""Tests for the PendingRpcs class.""" + +import asyncio + +import pytest + +from roborock.devices.pending import PendingRpcs +from roborock.exceptions import RoborockException + + +@pytest.fixture(name="pending_rpcs") +def setup_pending_rpcs() -> PendingRpcs[int, str]: + """Fixture to set up the PendingRpcs for tests.""" + return PendingRpcs[int, str]() + + +async def test_start_duplicate_rpc_raises_exception(pending_rpcs: PendingRpcs[int, str]) -> None: + """Test that starting a duplicate RPC raises an exception.""" + key = 1 + await pending_rpcs.start(key) + with pytest.raises(RoborockException, match=f"Request ID {key} already pending, cannot send command"): + await pending_rpcs.start(key) + + +async def test_resolve_pending_rpc(pending_rpcs: PendingRpcs[int, str]) -> None: + """Test resolving a pending RPC.""" + key = 1 + value = "test_result" + future = await pending_rpcs.start(key) + await pending_rpcs.resolve(key, value) + result = await future + assert result == value + + +async def test_resolve_unsolicited_message( + pending_rpcs: PendingRpcs[int, str], caplog: pytest.LogCaptureFixture +) -> None: + """Test resolving an unsolicited message does not raise.""" + key = 1 + value = "test_result" + await pending_rpcs.resolve(key, value) + + +async def test_pop_pending_rpc(pending_rpcs: PendingRpcs[int, str]) -> None: + """Test popping a pending RPC, which should cancel the future.""" + key = 1 + future = await pending_rpcs.start(key) + await pending_rpcs.pop(key) + with pytest.raises(asyncio.CancelledError): + await future + + +async def test_pop_non_existent_rpc(pending_rpcs: PendingRpcs[int, str]) -> None: + """Test that popping a non-existent RPC does not raise an exception.""" + key = 1 + await pending_rpcs.pop(key) + + +async def test_concurrent_rpcs(pending_rpcs: PendingRpcs[int, str]) -> None: + """Test handling multiple concurrent RPCs.""" + + async def start_and_resolve(key: int, value: str) -> str: + future = await pending_rpcs.start(key) + await asyncio.sleep(0.01) # yield + await pending_rpcs.resolve(key, value) + return await future + + tasks = [ + asyncio.create_task(start_and_resolve(1, "result1")), + asyncio.create_task(start_and_resolve(2, "result2")), + asyncio.create_task(start_and_resolve(3, "result3")), + ] + + results = await asyncio.gather(*tasks) + assert results == ["result1", "result2", "result3"]