Skip to content

Commit c4d2eca

Browse files
authored
fix(sync_tools): use portal fallback when await_ called from async task (#381)
`_AwaitWrapper.__call__()` unconditionally raised `RuntimeError` when called from within an async task on the same event loop, blocking all `await_()` call sites (migrations, CLI commands) when running inside an existing event loop (e.g., uvloop, anyio). Now when `raise_sync_error=False` (the default for all 8 migration/CLI call sites), the code falls through to the portal pattern which runs coroutines on a separate background daemon thread's event loop, avoiding the deadlock entirely. When `raise_sync_error=True`, the existing strict-mode behavior is preserved.
1 parent c1726c1 commit c4d2eca

2 files changed

Lines changed: 119 additions & 2 deletions

File tree

sqlspec/utils/sync_tools.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,12 @@ def __call__(self, *args: "ParamSpecT.args", **kwargs: "ParamSpecT.kwargs") -> "
225225
current_task = None
226226

227227
if current_task is not None:
228-
msg = "await_ cannot be called from within an async task running on the same event loop. Use 'await' instead."
229-
raise RuntimeError(msg)
228+
if self._raise_sync_error:
229+
msg = "await_ cannot be called from within an async task running on the same event loop. Use 'await' instead."
230+
raise RuntimeError(msg)
231+
portal = get_global_portal()
232+
typed_partial = cast("Callable[[], Coroutine[Any, Any, ReturnT]]", partial_f)
233+
return portal.call(typed_partial)
230234
future = asyncio.run_coroutine_threadsafe(partial_f(), loop)
231235
return future.result()
232236
if self._raise_sync_error:

tests/unit/utils/test_sync_tools.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,3 +452,116 @@ async def async_func() -> int:
452452

453453
manager.stop()
454454
assert not manager.is_running
455+
456+
457+
# ---------------------------------------------------------------------------
458+
# Regression tests for the await-bridge fix in _AwaitWrapper.__call__().
459+
#
460+
# These tests verify the behavior when asyncio.current_task() returns non-None
461+
# (i.e. we are inside an async task on the same event loop). Because
462+
# asyncio.to_thread / run_in_executor spawn worker threads where current_task()
463+
# returns None, we must mock asyncio.current_task and asyncio.get_running_loop
464+
# to exercise the relevant branches.
465+
# ---------------------------------------------------------------------------
466+
467+
468+
def test_await_portal_fallback_when_current_task_exists() -> None:
469+
"""When current_task is non-None and raise_sync_error=False, await_ should
470+
fall back to get_global_portal() instead of raising RuntimeError."""
471+
from unittest.mock import MagicMock, patch
472+
473+
async def async_double(x: int) -> int:
474+
return x * 2
475+
476+
mock_loop = MagicMock()
477+
mock_loop.is_running.return_value = True
478+
479+
mock_portal = MagicMock()
480+
mock_portal.call.return_value = 42
481+
482+
with (
483+
patch("asyncio.get_running_loop", return_value=mock_loop),
484+
patch("asyncio.current_task", return_value=MagicMock()),
485+
patch("sqlspec.utils.sync_tools.get_global_portal", return_value=mock_portal) as mock_get_portal,
486+
):
487+
sync_double = await_(async_double, raise_sync_error=False)
488+
result = sync_double(21)
489+
490+
assert result == 42
491+
mock_get_portal.assert_called()
492+
mock_portal.call.assert_called_once()
493+
494+
495+
def test_await_raises_when_current_task_exists_and_raise_sync_error_true() -> None:
496+
"""When current_task is non-None and raise_sync_error=True, await_ should
497+
raise RuntimeError with the appropriate message."""
498+
from unittest.mock import MagicMock, patch
499+
500+
async def async_func() -> int:
501+
return 1
502+
503+
mock_loop = MagicMock()
504+
mock_loop.is_running.return_value = True
505+
506+
with (
507+
patch("asyncio.get_running_loop", return_value=mock_loop),
508+
patch("asyncio.current_task", return_value=MagicMock()),
509+
):
510+
sync_func = await_(async_func, raise_sync_error=True)
511+
with pytest.raises(RuntimeError, match="await_ cannot be called from within an async task"):
512+
sync_func()
513+
514+
515+
def test_await_portal_fallback_propagates_exceptions() -> None:
516+
"""When using portal fallback (current_task non-None, raise_sync_error=False),
517+
exceptions from the coroutine should propagate through the portal."""
518+
from unittest.mock import MagicMock, patch
519+
520+
async def async_explode() -> int:
521+
raise ValueError("test error from async")
522+
523+
mock_loop = MagicMock()
524+
mock_loop.is_running.return_value = True
525+
526+
mock_portal = MagicMock()
527+
mock_portal.call.side_effect = ValueError("test error from async")
528+
529+
with (
530+
patch("asyncio.get_running_loop", return_value=mock_loop),
531+
patch("asyncio.current_task", return_value=MagicMock()),
532+
patch("sqlspec.utils.sync_tools.get_global_portal", return_value=mock_portal),
533+
):
534+
sync_explode = await_(async_explode, raise_sync_error=False)
535+
with pytest.raises(ValueError, match="test error from async"):
536+
sync_explode()
537+
538+
539+
def test_await_run_coroutine_threadsafe_when_no_current_task() -> None:
540+
"""When the loop is running but current_task is None (worker thread context),
541+
await_ should use asyncio.run_coroutine_threadsafe."""
542+
from unittest.mock import MagicMock, patch
543+
544+
async def async_add(a: int, b: int) -> int:
545+
return a + b
546+
547+
mock_loop = MagicMock()
548+
mock_loop.is_running.return_value = True
549+
550+
mock_future = MagicMock()
551+
mock_future.result.return_value = 7
552+
553+
def _capture_and_close_coro(coro: "Any", loop: "Any") -> MagicMock:
554+
"""Close the coroutine to avoid 'was never awaited' warning."""
555+
coro.close()
556+
return mock_future
557+
558+
with (
559+
patch("asyncio.get_running_loop", return_value=mock_loop),
560+
patch("asyncio.current_task", return_value=None),
561+
patch("asyncio.run_coroutine_threadsafe", side_effect=_capture_and_close_coro) as mock_rcts,
562+
):
563+
sync_add = await_(async_add, raise_sync_error=False)
564+
result = sync_add(3, 4)
565+
566+
assert result == 7
567+
mock_rcts.assert_called_once()

0 commit comments

Comments
 (0)