Skip to content

Commit 45bfe1b

Browse files
committed
fix: Ensure immediate local connection is attempted
Always start a local connection immediately to ensure initial RPCs can be sent locally. Without this, the first RPC may be sent over MQTT if the local connection didn't have a chance to start yet.
1 parent a80b306 commit 45bfe1b

File tree

3 files changed

+25
-10
lines changed

3 files changed

+25
-10
lines changed

roborock/devices/device.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
MIN_BACKOFF_INTERVAL = datetime.timedelta(seconds=10)
3030
MAX_BACKOFF_INTERVAL = datetime.timedelta(minutes=30)
3131
BACKOFF_MULTIPLIER = 1.5
32+
START_ATTEMPT_TIMEOUT = datetime.timedelta(seconds=5)
3233

3334

3435
class RoborockDevice(ABC, TraitsMixin):
@@ -107,25 +108,32 @@ def is_local_connected(self) -> bool:
107108
"""
108109
return self._channel.is_local_connected
109110

110-
def start_connect(self) -> None:
111+
async def start_connect(self) -> None:
111112
"""Start a background task to connect to the device.
112113
113-
This will attempt to connect to the device using the appropriate protocol
114-
channel. If the connection fails, it will retry with exponential backoff.
114+
This will give a moment for the first connection attempt to start so
115+
that the device will have connections established -- however, this will
116+
never directly fail.
117+
118+
If the connection fails, it will retry in the background with
119+
exponential backoff.
115120
116121
Once connected, the device will remain connected until `close()` is
117122
called. The device will automatically attempt to reconnect if the connection
118123
is lost.
119124
"""
125+
start_attempt: asyncio.Event = asyncio.Event()
120126

121127
async def connect_loop() -> None:
122128
backoff = MIN_BACKOFF_INTERVAL
123129
try:
124130
while True:
125131
try:
126132
await self.connect()
133+
start_attempt.set()
127134
return
128135
except RoborockException as e:
136+
start_attempt.set()
129137
_LOGGER.info("Failed to connect to device %s: %s", self.name, e)
130138
_LOGGER.info(
131139
"Retrying connection to device %s in %s seconds", self.name, backoff.total_seconds()
@@ -136,8 +144,16 @@ async def connect_loop() -> None:
136144
_LOGGER.info("connect_loop for device %s was cancelled", self.name)
137145
# Clean exit on cancellation
138146
return
147+
finally:
148+
start_attempt.set()
139149

140150
self._connect_task = asyncio.create_task(connect_loop())
151+
152+
try:
153+
await asyncio.wait_for(start_attempt.wait(), timeout=START_ATTEMPT_TIMEOUT.total_seconds())
154+
except asyncio.TimeoutError:
155+
_LOGGER.debug("Initial connection attempt to device %s is taking longer than expected", self.name)
156+
141157

142158
async def connect(self) -> None:
143159
"""Connect to the device using the appropriate protocol channel."""

roborock/devices/device_manager.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,14 +82,16 @@ async def discover_devices(self) -> list[RoborockDevice]:
8282

8383
# NOTE(review): the initial connection attempts are gathered concurrently below, not serially -- confirm this does not overwhelm the MQTT broker
8484
new_devices = {}
85+
start_tasks = []
8586
for duid, (device, product) in device_products.items():
8687
if duid in self._devices:
8788
continue
8889
new_device = self._device_creator(home_data, device, product)
89-
new_device.start_connect()
90+
start_tasks.append(new_device.start_connect())
9091
new_devices[duid] = new_device
9192

9293
self._devices.update(new_devices)
94+
await asyncio.gather(*start_tasks)
9395
return list(self._devices.values())
9496

9597
async def get_device(self, duid: str) -> RoborockDevice | None:

tests/devices/test_device_manager.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -176,13 +176,10 @@ async def test_start_connect_failure(home_data: HomeData, channel_failure: Mock,
176176
device_manager = await create_device_manager(USER_PARAMS)
177177
devices = await device_manager.get_devices()
178178

179-
# Wait for the device to attempt to connect
180-
attempts = 0
179+
# The device should attempt to connect in the background at least once
180+
# by the time this function returns.
181181
subscribe_mock = channel_failure.return_value.subscribe
182-
while subscribe_mock.call_count < 1:
183-
await asyncio.sleep(0.01)
184-
attempts += 1
185-
assert attempts < 10, "Device did not connect after multiple attempts"
182+
assert subscribe_mock.call_count > 0
186183

187184
# Device should exist but not be connected
188185
assert len(devices) == 1

0 commit comments

Comments
 (0)