Skip to content

Commit 6d9b85a

Browse files
committed
Fix: Prevent hang when ClientSession is initialized without context manager
1 parent 6b69f63 commit 6d9b85a

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

src/mcp/client/session.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
from typing import Any, Protocol, overload
3+
from types import TracebackType
34

45
import anyio.lowlevel
56
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
@@ -140,11 +141,33 @@ def __init__(
140141
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
141142
self._server_capabilities: types.ServerCapabilities | None = None
142143
self._experimental_features: ExperimentalClientFeatures | None = None
144+
self._entered = False
143145

144146
# Experimental: Task handlers (use defaults if not provided)
145147
self._task_handlers = experimental_task_handlers or ExperimentalTaskHandlers()
146148

149+
async def __aenter__(self) -> "ClientSession":
150+
self._entered = True
151+
await super().__aenter__()
152+
return self
153+
154+
async def __aexit__(
155+
self,
156+
exc_type: type[BaseException] | None,
157+
exc_value: BaseException | None,
158+
traceback: TracebackType | None,
159+
) -> None:
160+
self._entered = False
161+
await super().__aexit__(exc_type, exc_value, traceback)
162+
163+
def _check_is_active(self) -> None:
164+
if not self._entered:
165+
raise RuntimeError(
166+
"ClientSession must be used within an 'async with' block."
167+
)
168+
147169
async def initialize(self) -> types.InitializeResult:
170+
self._check_is_active()
148171
sampling = (
149172
(self._sampling_capabilities or types.SamplingCapability())
150173
if self._sampling_callback is not _default_sampling_callback

0 commit comments

Comments
 (0)