Skip to content

Commit 3e385a9

Browse files
Fail fast on invalid polling callback return types
Co-authored-by: Shri Sukhani <shrisukhani@users.noreply.github.com>
1 parent 6ddcd59 commit 3e385a9

2 files changed

Lines changed: 43 additions & 4 deletions

File tree

hyperbrowser/client/polling.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import inspect
23
import math
34
from numbers import Real
45
import time
@@ -60,6 +61,16 @@ def _ensure_status_string(status: object, *, operation_name: str) -> str:
6061
return status
6162

6263

64+
def _ensure_awaitable(
65+
result: object, *, callback_name: str, operation_name: str
66+
) -> Awaitable[object]:
67+
if not inspect.isawaitable(result):
68+
raise HyperbrowserError(
69+
f"{callback_name} must return an awaitable for {operation_name}"
70+
)
71+
return result
72+
73+
6374
def _validate_retry_config(
6475
*,
6576
max_attempts: int,
@@ -153,7 +164,7 @@ def poll_until_terminal_status(
153164

154165
while True:
155166
try:
156-
status = _ensure_status_string(get_status(), operation_name=operation_name)
167+
status = get_status()
157168
failures = 0
158169
except Exception as exc:
159170
failures += 1
@@ -168,6 +179,7 @@ def poll_until_terminal_status(
168179
time.sleep(poll_interval_seconds)
169180
continue
170181

182+
status = _ensure_status_string(status, operation_name=operation_name)
171183
if _ensure_boolean_terminal_result(
172184
is_terminal_status(status), operation_name=operation_name
173185
):
@@ -226,9 +238,25 @@ async def poll_until_terminal_status_async(
226238

227239
while True:
228240
try:
229-
status = _ensure_status_string(
230-
await get_status(), operation_name=operation_name
231-
)
241+
status_result = get_status()
242+
except Exception as exc:
243+
failures += 1
244+
if failures >= max_status_failures:
245+
raise HyperbrowserPollingError(
246+
f"Failed to poll {operation_name} after {max_status_failures} attempts: {exc}"
247+
) from exc
248+
if has_exceeded_max_wait(start_time, max_wait_seconds):
249+
raise HyperbrowserTimeoutError(
250+
f"Timed out waiting for {operation_name} after {max_wait_seconds} seconds"
251+
)
252+
await asyncio.sleep(poll_interval_seconds)
253+
continue
254+
255+
status_awaitable = _ensure_awaitable(
256+
status_result, callback_name="get_status", operation_name=operation_name
257+
)
258+
try:
259+
status = await status_awaitable
232260
failures = 0
233261
except Exception as exc:
234262
failures += 1
@@ -243,6 +271,7 @@ async def poll_until_terminal_status_async(
243271
await asyncio.sleep(poll_interval_seconds)
244272
continue
245273

274+
status = _ensure_status_string(status, operation_name=operation_name)
246275
if _ensure_boolean_terminal_result(
247276
is_terminal_status(status), operation_name=operation_name
248277
):

tests/test_polling.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -894,6 +894,16 @@ async def validate_async_operation_name() -> None:
894894
poll_interval_seconds=0.0,
895895
max_wait_seconds=1.0,
896896
)
897+
with pytest.raises(
898+
HyperbrowserError, match="get_status must return an awaitable"
899+
):
900+
await poll_until_terminal_status_async(
901+
operation_name="invalid-status-awaitable-async",
902+
get_status=lambda: "completed", # type: ignore[return-value]
903+
is_terminal_status=lambda value: value == "completed",
904+
poll_interval_seconds=0.0,
905+
max_wait_seconds=1.0,
906+
)
897907
with pytest.raises(
898908
HyperbrowserError, match="operation_name must be 200 characters or fewer"
899909
):

0 commit comments

Comments
 (0)