Skip to content

Commit 9de7f7b

Browse files
committed
fix: fail fast when session is not started
1 parent 161834d commit 9de7f7b

2 files changed

Lines changed: 36 additions & 1 deletion

File tree

src/mcp/shared/session.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Any, Generic, Protocol, TypeVar
99

1010
import anyio
11+
from anyio.abc import TaskGroup
1112
from anyio.streams.memory import MemoryObjectSendStream
1213
from opentelemetry.trace import SpanKind
1314
from pydantic import BaseModel, TypeAdapter
@@ -201,6 +202,15 @@ def __init__(
201202
self._progress_callbacks = {}
202203
self._response_routers = []
203204
self._exit_stack = AsyncExitStack()
205+
self._task_group: TaskGroup = anyio.create_task_group()
206+
self._started = False
207+
208+
def _require_started(self) -> None:
209+
if not self._started:
210+
raise RuntimeError(
211+
"Session is not running. Use it as an async context manager "
212+
"(e.g. `async with ClientSession(...) as session:`)."
213+
)
204214

205215
def add_response_router(self, router: ResponseRouter) -> None:
206216
"""Register a response router to handle responses for non-standard requests.
@@ -218,8 +228,11 @@ def add_response_router(self, router: ResponseRouter) -> None:
218228
self._response_routers.append(router)
219229

220230
async def __aenter__(self) -> Self:
231+
if self._started:
232+
raise RuntimeError("Session is already running")
221233
self._task_group = anyio.create_task_group()
222234
await self._task_group.__aenter__()
235+
self._started = True
223236
self._task_group.start_soon(self._receive_loop)
224237
return self
225238

@@ -234,7 +247,10 @@ async def __aexit__(
234247
# would be very surprising behavior), so make sure to cancel the tasks
235248
# in the task group.
236249
self._task_group.cancel_scope.cancel()
237-
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
250+
try:
251+
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
252+
finally:
253+
self._started = False
238254

239255
async def send_request(
240256
self,
@@ -251,6 +267,7 @@ async def send_request(
251267
252268
Do not use this method to emit notifications! Use send_notification() instead.
253269
"""
270+
self._require_started()
254271
request_id = self._request_id
255272
self._request_id = request_id + 1
256273

@@ -313,6 +330,7 @@ async def send_notification(
313330
related_request_id: RequestId | None = None,
314331
) -> None:
315332
"""Emits a notification, which is a one-way message that does not expect a response."""
333+
self._require_started()
316334
# Some transport implementations may need to set the related_request_id
317335
# to attribute to the notifications to the request that triggered them.
318336
jsonrpc_notification = JSONRPCNotification(

tests/client/test_session.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,23 @@ async def message_handler( # pragma: no cover
110110
assert isinstance(initialized_notification, InitializedNotification)
111111

112112

113+
@pytest.mark.anyio
114+
async def test_client_session_requires_context_manager():
115+
client_to_server_send, _client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
116+
_server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
117+
118+
async with (
119+
client_to_server_send,
120+
_client_to_server_receive,
121+
_server_to_client_send,
122+
server_to_client_receive,
123+
):
124+
session = ClientSession(server_to_client_receive, client_to_server_send)
125+
126+
with pytest.raises(RuntimeError, match="async context manager"):
127+
await session.initialize()
128+
129+
113130
@pytest.mark.anyio
114131
async def test_client_session_custom_client_info():
115132
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)

0 commit comments

Comments
 (0)