Skip to content

Commit 4d36a67

Browse files
fix: reject use of pooled connections after release to prevent stale reference escape
ConnectionPool.acquire() yields a raw DqliteConnection. A caller could stash the reference and use it after the context manager exits, silently bypassing pool exclusivity. Two tasks could then operate on the same TCP connection concurrently, corrupting the dqlite wire protocol. Add a _pool_released flag to DqliteConnection that is checked in _check_in_use(). The pool sets the flag after cleanup (including any ROLLBACK for open transactions) completes, and clears it when the connection is re-acquired. Standalone (non-pooled) connections are not affected since _pool_released defaults to False. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 601b4a3 commit 4d36a67

File tree

3 files changed

+156
-1
lines changed

3 files changed

+156
-1
lines changed

src/dqliteclient/connection.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def __init__(
8282
self._in_use = False
8383
self._bound_loop: asyncio.AbstractEventLoop | None = None
8484
self._tx_owner: asyncio.Task[Any] | None = None
85+
self._pool_released = False
8586

8687
@property
8788
def address(self) -> str:
@@ -154,7 +155,12 @@ def _ensure_connected(self) -> tuple[DqliteProtocol, int]:
154155
return self._protocol, self._db_id
155156

156157
def _check_in_use(self) -> None:
157-
"""Raise on misuse: wrong event loop or concurrent coroutine access."""
158+
"""Raise on misuse: wrong event loop, concurrent access, or use after pool release."""
159+
if self._pool_released:
160+
raise InterfaceError(
161+
"This connection has been returned to the pool and can no longer "
162+
"be used directly. Acquire a new connection from the pool."
163+
)
158164
if self._bound_loop is not None:
159165
try:
160166
current_loop = asyncio.get_running_loop()

src/dqliteclient/pool.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,17 +143,20 @@ async def acquire(self) -> AsyncIterator[DqliteConnection]:
143143
self._size -= 1
144144
conn = await self._create_connection()
145145

146+
conn._pool_released = False
146147
try:
147148
yield conn
148149
except BaseException:
149150
if conn.is_connected and not self._closed:
150151
# Connection is healthy — user code raised a non-connection error.
151152
# Roll back any open transaction, then return to pool.
152153
if not await self._reset_connection(conn):
154+
conn._pool_released = True
153155
with contextlib.suppress(Exception):
154156
await conn.close()
155157
self._size -= 1
156158
raise
159+
conn._pool_released = True
157160
try:
158161
self._pool.put_nowait(conn)
159162
except asyncio.QueueFull:
@@ -162,6 +165,7 @@ async def acquire(self) -> AsyncIterator[DqliteConnection]:
162165
else:
163166
# Connection is broken (invalidated by execute/fetch error handlers).
164167
# Drain other idle connections — they likely point to the same dead server.
168+
conn._pool_released = True
165169
await self._drain_idle()
166170
with contextlib.suppress(BaseException):
167171
await conn.close()
@@ -188,16 +192,19 @@ async def _reset_connection(self, conn: DqliteConnection) -> bool:
188192
async def _release(self, conn: DqliteConnection) -> None:
189193
"""Return a connection to the pool or close it."""
190194
if self._closed:
195+
conn._pool_released = True
191196
await conn.close()
192197
self._size -= 1
193198
return
194199

195200
if not await self._reset_connection(conn):
201+
conn._pool_released = True
196202
with contextlib.suppress(Exception):
197203
await conn.close()
198204
self._size -= 1
199205
return
200206

207+
conn._pool_released = True
201208
try:
202209
self._pool.put_nowait(conn)
203210
except asyncio.QueueFull:

tests/test_pool.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,148 @@ async def failing_execute(sql, params=None):
640640

641641
await pool.close()
642642

643+
async def test_escaped_reference_rejected_after_release(self) -> None:
644+
"""Using a connection after it's returned to the pool must raise InterfaceError."""
645+
import asyncio
646+
647+
from dqliteclient.connection import DqliteConnection
648+
from dqliteclient.exceptions import InterfaceError
649+
650+
pool = ConnectionPool(["localhost:9001"], max_size=1)
651+
652+
real_conn = DqliteConnection("127.0.0.1:9001")
653+
real_conn._protocol = MagicMock()
654+
real_conn._db_id = 1
655+
real_conn._bound_loop = asyncio.get_running_loop()
656+
real_conn._protocol.exec_sql = AsyncMock(return_value=(0, 1))
657+
real_conn._protocol.query_sql = AsyncMock(return_value=(["id"], [[1]]))
658+
real_conn._protocol.close = MagicMock()
659+
real_conn._protocol.wait_closed = AsyncMock()
660+
661+
with patch.object(pool._cluster, "connect", return_value=real_conn):
662+
await pool.initialize()
663+
664+
# Stash a reference to the connection
665+
escaped = None
666+
async with pool.acquire() as conn:
667+
escaped = conn
668+
await conn.execute("SELECT 1")
669+
670+
# Connection is now back in the pool — escaped reference must be rejected
671+
assert escaped is not None
672+
with pytest.raises(InterfaceError, match="returned to the pool"):
673+
await escaped.execute("SELECT 1")
674+
675+
await pool.close()
676+
677+
async def test_escaped_reference_rejected_after_exception(self) -> None:
678+
"""Escaped reference must be rejected even when user code raises."""
679+
import asyncio
680+
681+
from dqliteclient.connection import DqliteConnection
682+
from dqliteclient.exceptions import InterfaceError
683+
684+
pool = ConnectionPool(["localhost:9001"], max_size=1)
685+
686+
real_conn = DqliteConnection("127.0.0.1:9001")
687+
real_conn._protocol = MagicMock()
688+
real_conn._db_id = 1
689+
real_conn._bound_loop = asyncio.get_running_loop()
690+
real_conn._protocol.exec_sql = AsyncMock(return_value=(0, 1))
691+
real_conn._protocol.close = MagicMock()
692+
real_conn._protocol.wait_closed = AsyncMock()
693+
694+
with patch.object(pool._cluster, "connect", return_value=real_conn):
695+
await pool.initialize()
696+
697+
escaped = None
698+
with pytest.raises(ValueError, match="app error"):
699+
async with pool.acquire() as conn:
700+
escaped = conn
701+
raise ValueError("app error")
702+
703+
assert escaped is not None
704+
with pytest.raises(InterfaceError, match="returned to the pool"):
705+
await escaped.execute("SELECT 1")
706+
707+
await pool.close()
708+
709+
async def test_pool_release_rolls_back_transaction_with_real_connection(self) -> None:
710+
"""_reset_connection must be able to ROLLBACK before _pool_released is set."""
711+
import asyncio
712+
713+
from dqliteclient.connection import DqliteConnection
714+
715+
pool = ConnectionPool(["localhost:9001"], max_size=1)
716+
717+
real_conn = DqliteConnection("127.0.0.1:9001")
718+
real_conn._protocol = MagicMock()
719+
real_conn._db_id = 1
720+
real_conn._bound_loop = asyncio.get_running_loop()
721+
real_conn._protocol.exec_sql = AsyncMock(return_value=(0, 0))
722+
real_conn._protocol.close = MagicMock()
723+
real_conn._protocol.wait_closed = AsyncMock()
724+
725+
with patch.object(pool._cluster, "connect", return_value=real_conn):
726+
await pool.initialize()
727+
728+
# Simulate a connection with an open transaction returned to pool
729+
async with pool.acquire() as conn:
730+
conn._in_transaction = True
731+
732+
# The pool should have issued ROLLBACK successfully (not destroyed the conn)
733+
assert not real_conn._in_transaction
734+
# Connection should be back in the pool (not destroyed)
735+
assert pool._pool.qsize() == 1
736+
assert pool._size == 1
737+
738+
await pool.close()
739+
740+
async def test_escaped_reference_works_when_reacquired(self) -> None:
741+
"""A connection re-acquired from the pool must work normally."""
742+
import asyncio
743+
744+
from dqliteclient.connection import DqliteConnection
745+
746+
pool = ConnectionPool(["localhost:9001"], max_size=1)
747+
748+
real_conn = DqliteConnection("127.0.0.1:9001")
749+
real_conn._protocol = MagicMock()
750+
real_conn._db_id = 1
751+
real_conn._bound_loop = asyncio.get_running_loop()
752+
real_conn._protocol.exec_sql = AsyncMock(return_value=(0, 1))
753+
real_conn._protocol.close = MagicMock()
754+
real_conn._protocol.wait_closed = AsyncMock()
755+
756+
with patch.object(pool._cluster, "connect", return_value=real_conn):
757+
await pool.initialize()
758+
759+
# First acquire and release
760+
async with pool.acquire() as conn:
761+
await conn.execute("SELECT 1")
762+
763+
# Second acquire — same connection from pool must work
764+
async with pool.acquire() as conn:
765+
await conn.execute("SELECT 2")
766+
767+
await pool.close()
768+
769+
async def test_standalone_connection_not_affected_by_pool_guard(self) -> None:
770+
"""A DqliteConnection used standalone (not from a pool) must not be affected."""
771+
import asyncio
772+
773+
from dqliteclient.connection import DqliteConnection
774+
775+
conn = DqliteConnection("127.0.0.1:9001")
776+
conn._protocol = MagicMock()
777+
conn._db_id = 1
778+
conn._bound_loop = asyncio.get_running_loop()
779+
conn._protocol.exec_sql = AsyncMock(return_value=(0, 1))
780+
781+
# Standalone connection — no pool involved, should work fine
782+
await conn.execute("SELECT 1")
783+
await conn.execute("SELECT 2") # No error — _pool_released is always False
784+
643785

644786
class TestConnectionPoolIntegration:
645787
"""Integration tests requiring mocked connections."""

0 commit comments

Comments
 (0)