Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions asyncpg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from .connection import connect, Connection # NOQA
from .exceptions import * # NOQA
from .pool import create_pool, Pool # NOQA
from .pool import create_pool, Pool, AcquireEvent # NOQA
from .protocol import Record # NOQA
from .types import * # NOQA

Expand All @@ -19,6 +19,6 @@


__all__: tuple[str, ...] = (
'connect', 'create_pool', 'Pool', 'Record', 'Connection'
'connect', 'create_pool', 'Pool', 'Record', 'Connection', 'AcquireEvent'
)
__all__ += exceptions.__all__ # NOQA
54 changes: 51 additions & 3 deletions asyncpg/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import asyncio
from collections.abc import Awaitable, Callable
import dataclasses
import functools
import inspect
import logging
Expand All @@ -25,6 +26,19 @@
logger = logging.getLogger(__name__)


@dataclasses.dataclass(frozen=True)
class AcquireEvent:
    """Emitted by :meth:`Pool.acquire` on every successful dispatch.

    Instances are built by the pool's internal ``_fire_on_acquire`` hook and
    handed to the user's synchronous ``on_acquire`` callback.  The dataclass
    is frozen so callbacks cannot mutate a shared event.

    .. versionadded:: 0.32.0
    """

    # Wall-clock seconds (time.monotonic delta) spent inside Pool.acquire().
    wait_seconds: float
    # Current pool size at dispatch time (Pool.get_size()).
    size: int
    # Number of idle connections at dispatch time (Pool.get_idle_size()).
    idle: int
    # Configured pool maximum (the pool's max_size setting).
    max_size: int


class PoolConnectionProxyMeta(type):

def __new__(
Expand Down Expand Up @@ -342,7 +356,8 @@ class Pool:
'_init', '_connect', '_reset', '_connect_args', '_connect_kwargs',
'_holders', '_initialized', '_initializing', '_closing',
'_closed', '_connection_class', '_record_class', '_generation',
'_setup', '_max_queries', '_max_inactive_connection_lifetime'
'_setup', '_max_queries', '_max_inactive_connection_lifetime',
'_on_acquire',
)

def __init__(self, *connect_args,
Expand All @@ -357,6 +372,8 @@ def __init__(self, *connect_args,
loop,
connection_class,
record_class,
on_acquire: Optional[
Callable[[AcquireEvent], None]] = None,
**connect_kwargs):

if len(connect_args) > 1:
Expand Down Expand Up @@ -399,6 +416,8 @@ def __init__(self, *connect_args,
'record_class is expected to be a subclass of '
'asyncpg.Record, got {!r}'.format(record_class))

self._on_acquire = on_acquire

self._minsize = min_size
self._maxsize = max_size

Expand Down Expand Up @@ -892,11 +911,27 @@ async def _acquire_impl():
raise exceptions.InterfaceError('pool is closing')
self._check_init()

started = time.monotonic()
if timeout is None:
return await _acquire_impl()
proxy = await _acquire_impl()
else:
return await compat.wait_for(
proxy = await compat.wait_for(
_acquire_impl(), timeout=timeout)
if self._on_acquire is not None:
self._fire_on_acquire(time.monotonic() - started)
return proxy

def _fire_on_acquire(self, wait_seconds: float) -> None:
    """Deliver an :class:`AcquireEvent` to the user's ``on_acquire`` hook.

    :param wait_seconds:
        Wall-clock time the caller spent inside :meth:`Pool.acquire`.

    Any exception escaping the callback (or the pool-state snapshot taken
    to build the event) is logged via ``logger.exception`` and swallowed,
    so a faulty callback can never break :meth:`Pool.acquire` itself.
    """
    try:
        # Snapshot pool state and invoke the callback inside one try so
        # that a failure at any step is suppressed uniformly.
        event = AcquireEvent(
            wait_seconds=wait_seconds,
            size=self.get_size(),
            idle=self.get_idle_size(),
            max_size=self._maxsize,
        )
        self._on_acquire(event)
    except Exception:
        logger.exception(
            'asyncpg on_acquire callback raised; suppressing')

async def release(self, connection, *, timeout=None):
"""Release a database connection back to the pool.
Expand Down Expand Up @@ -1084,6 +1119,8 @@ def create_pool(dsn=None, *,
loop=None,
connection_class=connection.Connection,
record_class=protocol.Record,
on_acquire: Optional[
Callable[[AcquireEvent], None]] = None,
**connect_kwargs):
r"""Create a connection pool.

Expand Down Expand Up @@ -1230,6 +1267,16 @@ def create_pool(dsn=None, *,

.. versionchanged:: 0.30.0
Added the *connect* and *reset* parameters.

:param on_acquire:
Synchronous callback invoked with an :class:`AcquireEvent` after
every successful :meth:`Pool.acquire` dispatch. ``wait_seconds``
is wall-clock time spent inside :meth:`Pool.acquire` (queue wait
plus any reconnect or ``setup`` callback). Exceptions are
logged and suppressed.

.. versionchanged:: 0.32.0
Added the *on_acquire* parameter.
"""
return Pool(
dsn,
Expand All @@ -1244,5 +1291,6 @@ def create_pool(dsn=None, *,
init=init,
reset=reset,
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
on_acquire=on_acquire,
**connect_kwargs,
)
56 changes: 56 additions & 0 deletions tests/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,6 +1004,62 @@ async def worker():
conn = await pool.acquire(timeout=0.1)
await pool.release(conn)

async def test_pool_on_acquire_reports_saturation_wait(self):
    """on_acquire fires once per successful acquire and its
    wait_seconds reflects time spent queued on a saturated pool."""
    seen = []
    pool = await self.create_pool(
        database='postgres',
        min_size=1,
        max_size=1,
        on_acquire=seen.append,
    )
    try:
        blocker_has_conn = asyncio.Event()
        let_blocker_go = asyncio.Event()

        async def blocker():
            # Pin the pool's only connection until told to release it.
            async with pool.acquire():
                blocker_has_conn.set()
                await let_blocker_go.wait()

        async def contender():
            # Queue behind the blocker; this acquire must wait.
            await blocker_has_conn.wait()
            async with pool.acquire() as con:
                await con.fetchval('SELECT 1')

        blocker_task = self.loop.create_task(blocker())
        contender_task = self.loop.create_task(contender())
        await blocker_has_conn.wait()
        # Keep the pool saturated long enough to produce a measurable wait.
        await asyncio.sleep(0.15)
        let_blocker_go.set()
        await asyncio.gather(blocker_task, contender_task)
    finally:
        await pool.close()

    # One event per successful acquire: blocker + contender.
    self.assertEqual(len(seen), 2)
    for ev in seen:
        self.assertEqual(ev.max_size, 1)
        self.assertGreaterEqual(ev.wait_seconds, 0)
    # The contender queued for ~0.15s; allow slack down to 0.1s.
    self.assertGreaterEqual(
        max(ev.wait_seconds for ev in seen), 0.1)

async def test_pool_on_acquire_not_fired_on_timeout(self):
    """A timed-out acquire must not emit an AcquireEvent."""
    captured = []
    pool = await self.create_pool(
        database='postgres',
        min_size=1,
        max_size=1,
        on_acquire=captured.append,
    )
    try:
        async with pool.acquire():
            # The pool is exhausted, so this second acquire times out.
            with self.assertRaises(asyncio.TimeoutError):
                await pool.acquire(timeout=0.1)
    finally:
        await pool.close()

    # Only the successful outer acquire produced an event; the
    # timed-out attempt produced none.
    self.assertEqual(len(captured), 1)


@unittest.skipIf(os.environ.get('PGHOST'), 'unmanaged cluster')
class TestPoolReconnectWithTargetSessionAttrs(tb.ClusterTestCase):
Expand Down