diff --git a/roborock/devices/device.py b/roborock/devices/device.py index 87814391..dfb924e3 100644 --- a/roborock/devices/device.py +++ b/roborock/devices/device.py @@ -147,34 +147,45 @@ async def start_connect(self) -> None: called. The device will automatically attempt to reconnect if the connection is lost. """ - start_attempt: asyncio.Event = asyncio.Event() + # The future will be set to True if the first attempt succeeds, False if + # it fails, or an exception if an unexpected error occurs. + # We use this to wait a short time for the first attempt to complete. We + # don't actually care about the result, just that we waited long enough. + start_attempt: asyncio.Future[bool] = asyncio.Future() async def connect_loop() -> None: - backoff = MIN_BACKOFF_INTERVAL try: + backoff = MIN_BACKOFF_INTERVAL while True: try: await self.connect() - start_attempt.set() + if not start_attempt.done(): + start_attempt.set_result(True) self._has_connected = True self._ready_callbacks(self) return except RoborockException as e: - start_attempt.set() + if not start_attempt.done(): + start_attempt.set_result(False) self._logger.info("Failed to connect (retry %s): %s", backoff.total_seconds(), e) await asyncio.sleep(backoff.total_seconds()) backoff = min(backoff * BACKOFF_MULTIPLIER, MAX_BACKOFF_INTERVAL) + except Exception as e: # pylint: disable=broad-except + if not start_attempt.done(): + start_attempt.set_exception(e) + self._logger.exception("Uncaught error during connect: %s", e) + return except asyncio.CancelledError: self._logger.debug("connect_loop was cancelled for device %s", self.duid) - # Clean exit on cancellation - return finally: - start_attempt.set() + if not start_attempt.done(): + start_attempt.set_result(False) self._connect_task = asyncio.create_task(connect_loop()) try: - await asyncio.wait_for(start_attempt.wait(), timeout=START_ATTEMPT_TIMEOUT.total_seconds()) + async with asyncio.timeout(START_ATTEMPT_TIMEOUT.total_seconds()): + await start_attempt except TimeoutError: self._logger.debug("Initial connection attempt took longer than expected, will keep trying in background") @@ -190,6 +201,7 @@ async def connect(self) -> None: except RoborockException: unsub() raise + self._logger.info("Connected to device") self._unsub = unsub async def close(self) -> None: diff --git a/tests/devices/test_device_manager.py b/tests/devices/test_device_manager.py index 239762bb..85696df2 100644 --- a/tests/devices/test_device_manager.py +++ b/tests/devices/test_device_manager.py @@ -66,11 +66,17 @@ def mock_sleep() -> Generator[None, None, None]: yield +@pytest.fixture(name="channel_exception") +def channel_failure_exception_fixture(mock_rpc_channel: AsyncMock) -> Exception: + """Fixture that provides the exception to be raised by the failing channel.""" + return RoborockException("Connection failed") + + @pytest.fixture(name="channel_failure") -def channel_failure_fixture(mock_rpc_channel: AsyncMock) -> Generator[Mock, None, None]: +def channel_failure_fixture(mock_rpc_channel: AsyncMock, channel_exception: Exception) -> Generator[Mock, None, None]: """Fixture that makes channel subscribe fail.""" with patch("roborock.devices.device_manager.create_v1_channel") as mock_channel: - mock_channel.return_value.subscribe = AsyncMock(side_effect=RoborockException("Connection failed")) + mock_channel.return_value.subscribe = AsyncMock(side_effect=channel_exception) mock_channel.return_value.is_connected = False mock_channel.return_value.rpc_channel = mock_rpc_channel yield mock_channel @@ -192,6 +198,12 @@ async def test_ready_callback(home_data: HomeData) -> None: await device_manager.close() +@pytest.mark.parametrize( + ("channel_exception"), + [ + RoborockException("Connection failed"), + ], +) async def test_start_connect_failure(home_data: HomeData, channel_failure: Mock, mock_sleep: Mock) -> None: """Test that start_connect retries when connection fails.""" ready_devices: list[RoborockDevice] = [] @@ -231,3 +243,15 @@ async def test_start_connect_failure(home_data: HomeData, channel_failure: Mock, await device_manager.close() assert mock_unsub.call_count == 1 + + +@pytest.mark.parametrize( + ("channel_exception"), + [ + Exception("Unexpected error"), + ], +) +async def test_start_connect_unexpected_error(home_data: HomeData, channel_failure: Mock, mock_sleep: Mock) -> None: + """Test that some unexpected errors from start_connect are propagated.""" + with pytest.raises(Exception, match="Unexpected error"): + await create_device_manager(USER_PARAMS)