@@ -452,3 +452,116 @@ async def async_func() -> int:
452452
453453 manager .stop ()
454454 assert not manager .is_running
455+
456+
457+ # ---------------------------------------------------------------------------
458+ # Regression tests for the await-bridge fix in _AwaitWrapper.__call__().
459+ #
460+ # These tests verify the behavior when asyncio.current_task() returns non-None
461+ # (i.e. we are inside an async task on the same event loop). Because
462+ # asyncio.to_thread / run_in_executor spawn worker threads where current_task()
463+ # returns None, we must mock asyncio.current_task and asyncio.get_running_loop
464+ # to exercise the relevant branches.
465+ # ---------------------------------------------------------------------------
466+
467+
468+ def test_await_portal_fallback_when_current_task_exists () -> None :
469+ """When current_task is non-None and raise_sync_error=False, await_ should
470+ fall back to get_global_portal() instead of raising RuntimeError."""
471+ from unittest .mock import MagicMock , patch
472+
473+ async def async_double (x : int ) -> int :
474+ return x * 2
475+
476+ mock_loop = MagicMock ()
477+ mock_loop .is_running .return_value = True
478+
479+ mock_portal = MagicMock ()
480+ mock_portal .call .return_value = 42
481+
482+ with (
483+ patch ("asyncio.get_running_loop" , return_value = mock_loop ),
484+ patch ("asyncio.current_task" , return_value = MagicMock ()),
485+ patch ("sqlspec.utils.sync_tools.get_global_portal" , return_value = mock_portal ) as mock_get_portal ,
486+ ):
487+ sync_double = await_ (async_double , raise_sync_error = False )
488+ result = sync_double (21 )
489+
490+ assert result == 42
491+ mock_get_portal .assert_called ()
492+ mock_portal .call .assert_called_once ()
493+
494+
495+ def test_await_raises_when_current_task_exists_and_raise_sync_error_true () -> None :
496+ """When current_task is non-None and raise_sync_error=True, await_ should
497+ raise RuntimeError with the appropriate message."""
498+ from unittest .mock import MagicMock , patch
499+
500+ async def async_func () -> int :
501+ return 1
502+
503+ mock_loop = MagicMock ()
504+ mock_loop .is_running .return_value = True
505+
506+ with (
507+ patch ("asyncio.get_running_loop" , return_value = mock_loop ),
508+ patch ("asyncio.current_task" , return_value = MagicMock ()),
509+ ):
510+ sync_func = await_ (async_func , raise_sync_error = True )
511+ with pytest .raises (RuntimeError , match = "await_ cannot be called from within an async task" ):
512+ sync_func ()
513+
514+
515+ def test_await_portal_fallback_propagates_exceptions () -> None :
516+ """When using portal fallback (current_task non-None, raise_sync_error=False),
517+ exceptions from the coroutine should propagate through the portal."""
518+ from unittest .mock import MagicMock , patch
519+
520+ async def async_explode () -> int :
521+ raise ValueError ("test error from async" )
522+
523+ mock_loop = MagicMock ()
524+ mock_loop .is_running .return_value = True
525+
526+ mock_portal = MagicMock ()
527+ mock_portal .call .side_effect = ValueError ("test error from async" )
528+
529+ with (
530+ patch ("asyncio.get_running_loop" , return_value = mock_loop ),
531+ patch ("asyncio.current_task" , return_value = MagicMock ()),
532+ patch ("sqlspec.utils.sync_tools.get_global_portal" , return_value = mock_portal ),
533+ ):
534+ sync_explode = await_ (async_explode , raise_sync_error = False )
535+ with pytest .raises (ValueError , match = "test error from async" ):
536+ sync_explode ()
537+
538+
539+ def test_await_run_coroutine_threadsafe_when_no_current_task () -> None :
540+ """When the loop is running but current_task is None (worker thread context),
541+ await_ should use asyncio.run_coroutine_threadsafe."""
542+ from unittest .mock import MagicMock , patch
543+
544+ async def async_add (a : int , b : int ) -> int :
545+ return a + b
546+
547+ mock_loop = MagicMock ()
548+ mock_loop .is_running .return_value = True
549+
550+ mock_future = MagicMock ()
551+ mock_future .result .return_value = 7
552+
553+ def _capture_and_close_coro (coro : "Any" , loop : "Any" ) -> MagicMock :
554+ """Close the coroutine to avoid 'was never awaited' warning."""
555+ coro .close ()
556+ return mock_future
557+
558+ with (
559+ patch ("asyncio.get_running_loop" , return_value = mock_loop ),
560+ patch ("asyncio.current_task" , return_value = None ),
561+ patch ("asyncio.run_coroutine_threadsafe" , side_effect = _capture_and_close_coro ) as mock_rcts ,
562+ ):
563+ sync_add = await_ (async_add , raise_sync_error = False )
564+ result = sync_add (3 , 4 )
565+
566+ assert result == 7
567+ mock_rcts .assert_called_once ()
0 commit comments