Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion burr/common/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.

import inspect
from typing import AsyncGenerator, AsyncIterable, Generator, List, TypeVar, Union
from typing import Any, AsyncGenerator, AsyncIterable, Coroutine, Generator, List, TypeVar, Union

T = TypeVar("T")

Expand All @@ -27,6 +27,46 @@
SyncOrAsyncGeneratorOrItemOrList = Union[SyncOrAsyncGenerator[GenType], List[GenType], GenType]


class _AsyncPersisterContextManager:
"""Wraps an async coroutine that returns a persister so it can be used
directly with ``async with``::

async with AsyncSQLitePersister.from_values(...) as persister:
...

The wrapper awaits the coroutine on ``__aenter__`` and delegates
``__aexit__`` to the persister's own ``__aexit__``.

.. note::
Each instance wraps a single coroutine and can only be consumed once,
either via ``await`` or ``async with``. A second use will raise
``RuntimeError``.
"""

def __init__(self, coro: Coroutine[Any, Any, Any]):
self._coro = coro
self._persister = None
self._consumed = False

def __await__(self):
if self._consumed:
raise RuntimeError("This factory result has already been consumed")
self._consumed = True
return self._coro.__await__()

async def __aenter__(self):
if self._consumed:
raise RuntimeError("This factory result has already been consumed")
self._consumed = True
self._persister = await self._coro
return await self._persister.__aenter__()

async def __aexit__(self, exc_type, exc_value, traceback):
if self._persister is None:
return False
return await self._persister.__aexit__(exc_type, exc_value, traceback)


async def asyncify_generator(
generator: SyncOrAsyncGenerator[GenType],
) -> AsyncGenerator[GenType, None]:
Expand Down
35 changes: 27 additions & 8 deletions burr/integrations/persisters/b_aiosqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import aiosqlite

from burr.common.async_utils import _AsyncPersisterContextManager
from burr.common.types import BaseCopyable
from burr.core import State
from burr.core.persistence import AsyncBaseStatePersister, PersistedStateData
Expand Down Expand Up @@ -60,38 +61,56 @@ def copy(self) -> "Self":
PARTITION_KEY_DEFAULT = ""

@classmethod
async def from_config(cls, config: dict) -> "AsyncSQLitePersister":
def from_config(cls, config: dict) -> "_AsyncPersisterContextManager":
"""Creates a new instance of the AsyncSQLitePersister from a configuration dictionary.

Can be used with ``await`` or as an async context manager::

persister = await AsyncSQLitePersister.from_config(config)
# or
async with AsyncSQLitePersister.from_config(config) as persister:
...

The config key:value pair needed are:
db_path: str,
table_name: str,
serde_kwargs: dict,
connect_kwargs: dict,
"""
return await cls.from_values(**config)
return cls.from_values(**config)

@classmethod
async def from_values(
def from_values(
cls,
db_path: str,
table_name: str = "burr_state",
serde_kwargs: dict = None,
connect_kwargs: dict = None,
) -> "AsyncSQLitePersister":
) -> "_AsyncPersisterContextManager":
"""Creates a new instance of the AsyncSQLitePersister from passed in values.

Can be used with ``await`` or as an async context manager::

persister = await AsyncSQLitePersister.from_values(db_path="test.db")
# or
async with AsyncSQLitePersister.from_values(db_path="test.db") as persister:
...

:param db_path: the path the DB will be stored.
:param table_name: the table name to store things under.
:param serde_kwargs: kwargs for state serialization/deserialization.
:param connect_kwargs: kwargs to pass to the aiosqlite.connect method.
:return: async sqlite persister instance with an open connection. You are responsible
for closing the connection yourself.
"""
connection = await aiosqlite.connect(
db_path, **connect_kwargs if connect_kwargs is not None else {}
)
return cls(connection, table_name, serde_kwargs)

async def _create():
connection = await aiosqlite.connect(
db_path, **connect_kwargs if connect_kwargs is not None else {}
)
return cls(connection, table_name, serde_kwargs)

return _AsyncPersisterContextManager(_create())

def __init__(
self,
Expand Down
61 changes: 40 additions & 21 deletions burr/integrations/persisters/b_asyncpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import logging
from typing import Any, ClassVar, Literal, Optional

from burr.common.async_utils import _AsyncPersisterContextManager
from burr.common.types import BaseCopyable
from burr.core import persistence, state
from burr.integrations import base
Expand Down Expand Up @@ -106,12 +107,20 @@ async def create_pool(
return cls._pool

@classmethod
async def from_config(cls, config: dict) -> "AsyncPostgreSQLPersister":
"""Creates a new instance of the PostgreSQLPersister from a configuration dictionary."""
return await cls.from_values(**config)
def from_config(cls, config: dict) -> "_AsyncPersisterContextManager":
"""Creates a new instance of the PostgreSQLPersister from a configuration dictionary.

Can be used with ``await`` or as an async context manager::

persister = await AsyncPostgreSQLPersister.from_config(config)
# or
async with AsyncPostgreSQLPersister.from_config(config) as persister:
...
"""
return cls.from_values(**config)

@classmethod
async def from_values(
def from_values(
cls,
db_name: str,
user: str,
Expand All @@ -121,9 +130,16 @@ async def from_values(
table_name: str = "burr_state",
use_pool: bool = False,
**pool_kwargs,
) -> "AsyncPostgreSQLPersister":
) -> "_AsyncPersisterContextManager":
"""Builds a new instance of the PostgreSQLPersister from the provided values.

Can be used with ``await`` or as an async context manager::

persister = await AsyncPostgreSQLPersister.from_values(...)
# or
async with AsyncPostgreSQLPersister.from_values(...) as persister:
...

:param db_name: the name of the PostgreSQL database.
:param user: the username to connect to the PostgreSQL database.
:param password: the password to connect to the PostgreSQL database.
Expand All @@ -133,22 +149,25 @@ async def from_values(
:param use_pool: whether to use a connection pool (True) or a direct connection (False)
:param pool_kwargs: additional kwargs to pass to the pool creation
"""
if use_pool:
pool = await cls.create_pool(
user=user,
password=password,
database=db_name,
host=host,
port=port,
**pool_kwargs,
)
return cls(connection=None, pool=pool, table_name=table_name)
else:
# Original behavior - direct connection
connection = await asyncpg.connect(
user=user, password=password, database=db_name, host=host, port=port
)
return cls(connection=connection, table_name=table_name)

async def _create():
if use_pool:
pool = await cls.create_pool(
user=user,
password=password,
database=db_name,
host=host,
port=port,
**pool_kwargs,
)
return cls(connection=None, pool=pool, table_name=table_name)
else:
connection = await asyncpg.connect(
user=user, password=password, database=db_name, host=host, port=port
)
return cls(connection=connection, table_name=table_name)

return _AsyncPersisterContextManager(_create())

def __init__(
self,
Expand Down
8 changes: 4 additions & 4 deletions docs/concepts/parallelism.rst
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@ When using state persistence with async parallelism, make sure to use the async
from burr.integrations.persisters.b_asyncpg import AsyncPGPersister

# Create an async persister with a connection pool
persister = AsyncPGPersister.from_values(
persister = await AsyncPGPersister.from_values(
host="localhost",
port=5432,
user="postgres",
Expand All @@ -707,7 +707,7 @@ When using state persistence with async parallelism, make sure to use the async
use_pool=True # Important for parallelism!
)

app = (
app = await (
ApplicationBuilder()
.with_state_persister(persister)
.with_action(
Expand All @@ -722,12 +722,12 @@ Remember to properly clean up your async persisters when you're done with them:

.. code-block:: python

# Using as a context manager
# Using as a context manager (recommended)
async with AsyncPGPersister.from_values(..., use_pool=True) as persister:
# Use persister here

# Or manual cleanup
persister = AsyncPGPersister.from_values(..., use_pool=True)
persister = await AsyncPGPersister.from_values(..., use_pool=True)
try:
# Use persister here
finally:
Expand Down
28 changes: 11 additions & 17 deletions tests/core/test_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,12 @@ def test_sqlite_persister_save_without_initialize_raises_runtime_error():
try:
with pytest.raises(RuntimeError, match="Uninitialized persister"):
persister.save(
"partition_key", "app_id", 1, "position", State({"key": "value"}), "completed"
"partition_key",
"app_id",
1,
"position",
State({"key": "value"}),
"completed",
)
finally:
persister.cleanup()
Expand Down Expand Up @@ -168,17 +173,6 @@ def test_persister_methods_none_partition_key(persistence, method_name: str, kwa
"""Asyncio integration for sqlite persister + """


class AsyncSQLiteContextManager:
def __init__(self, sqlite_object):
self.client = sqlite_object

async def __aenter__(self):
return self.client

async def __aexit__(self, exc_type, exc, tb):
await self.client.close()


@pytest.fixture()
async def async_persistence(request):
yield AsyncInMemoryPersister()
Expand Down Expand Up @@ -276,15 +270,15 @@ async def test_AsyncSQLitePersister_connection_shutdown():

@pytest.fixture()
async def initializing_async_persistence():
sqlite_persister = await AsyncSQLitePersister.from_values(
async with AsyncSQLitePersister.from_values(
db_path=":memory:", table_name="test_table"
)
async_context_manager = AsyncSQLiteContextManager(sqlite_persister)
async with async_context_manager as client:
) as client:
yield client


async def test_async_persistence_initialization_creates_table(initializing_async_persistence):
async def test_async_persistence_initialization_creates_table(
initializing_async_persistence,
):
await asyncio.sleep(0.00001)
await initializing_async_persistence.initialize()
assert await initializing_async_persistence.list_app_ids("partition_key") == []
Expand Down
Loading
Loading