diff --git a/roborock/devices/device.py b/roborock/devices/device.py index d86b3f38..31931139 100644 --- a/roborock/devices/device.py +++ b/roborock/devices/device.py @@ -29,6 +29,7 @@ MIN_BACKOFF_INTERVAL = datetime.timedelta(seconds=10) MAX_BACKOFF_INTERVAL = datetime.timedelta(minutes=30) BACKOFF_MULTIPLIER = 1.5 +START_ATTEMPT_TIMEOUT = datetime.timedelta(seconds=5) class RoborockDevice(ABC, TraitsMixin): @@ -107,16 +108,21 @@ def is_local_connected(self) -> bool: """ return self._channel.is_local_connected - def start_connect(self) -> None: + async def start_connect(self) -> None: """Start a background task to connect to the device. - This will attempt to connect to the device using the appropriate protocol - channel. If the connection fails, it will retry with exponential backoff. + This will give a moment for the first connection attempt to start so + that the device will have connections established -- however, this will + never directly fail. + + If the connection fails, it will retry in the background with + exponential backoff. Once connected, the device will remain connected until `close()` is called. The device will automatically attempt to reconnect if the connection is lost. """ + start_attempt: asyncio.Event = asyncio.Event() async def connect_loop() -> None: backoff = MIN_BACKOFF_INTERVAL @@ -124,8 +130,10 @@ async def connect_loop() -> None: while True: try: await self.connect() + start_attempt.set() return except RoborockException as e: + start_attempt.set() _LOGGER.info("Failed to connect to device %s: %s", self.name, e) _LOGGER.info( "Retrying connection to device %s in %s seconds", self.name, backoff.total_seconds() @@ -136,9 +144,16 @@ async def connect_loop() -> None: _LOGGER.info("connect_loop for device %s was cancelled", self.name) # Clean exit on cancellation return + finally: + start_attempt.set() self._connect_task = asyncio.create_task(connect_loop()) + try: + await asyncio.wait_for(start_attempt.wait(), timeout=START_ATTEMPT_TIMEOUT.total_seconds()) + except TimeoutError: + _LOGGER.debug("Initial connection attempt to device %s is taking longer than expected", self.name) + async def connect(self) -> None: """Connect to the device using the appropriate protocol channel.""" if self._unsub: diff --git a/roborock/devices/device_manager.py b/roborock/devices/device_manager.py index a244cbc9..b42a5d9e 100644 --- a/roborock/devices/device_manager.py +++ b/roborock/devices/device_manager.py @@ -82,14 +82,16 @@ async def discover_devices(self) -> list[RoborockDevice]: # These are connected serially to avoid overwhelming the MQTT broker new_devices = {} + start_tasks = [] for duid, (device, product) in device_products.items(): if duid in self._devices: continue new_device = self._device_creator(home_data, device, product) - new_device.start_connect() + start_tasks.append(new_device.start_connect()) new_devices[duid] = new_device self._devices.update(new_devices) + await asyncio.gather(*start_tasks) return list(self._devices.values()) async def get_device(self, duid: str) -> RoborockDevice | None: diff --git a/tests/devices/test_device_manager.py b/tests/devices/test_device_manager.py index 3b0584c6..79acd333 100644 --- a/tests/devices/test_device_manager.py +++ b/tests/devices/test_device_manager.py @@ -176,13 +176,10 @@ async def test_start_connect_failure(home_data: HomeData, channel_failure: Mock, device_manager = await create_device_manager(USER_PARAMS) devices = await device_manager.get_devices() - # Wait for the device to attempt to connect - attempts = 0 + # The device should attempt to connect in the background at least once + # by the time this function returns. subscribe_mock = channel_failure.return_value.subscribe - while subscribe_mock.call_count < 1: - await asyncio.sleep(0.01) - attempts += 1 - assert attempts < 10, "Device did not connect after multiple attempts" + assert subscribe_mock.call_count > 0 # Device should exist but not be connected assert len(devices) == 1