Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Any, Generic, Protocol, TypeVar

import anyio
from anyio.abc import TaskGroup
from anyio.streams.memory import MemoryObjectSendStream
from opentelemetry.trace import SpanKind
from pydantic import BaseModel, TypeAdapter
Expand Down Expand Up @@ -201,6 +202,15 @@ def __init__(
self._progress_callbacks = {}
self._response_routers = []
self._exit_stack = AsyncExitStack()
self._task_group: TaskGroup = anyio.create_task_group()
self._started = False

def _require_started(self) -> None:
if not self._started:
raise RuntimeError(
"Session is not running. Use it as an async context manager "
"(e.g. `async with ClientSession(...) as session:`)."
)

def add_response_router(self, router: ResponseRouter) -> None:
"""Register a response router to handle responses for non-standard requests.
Expand All @@ -218,8 +228,11 @@ def add_response_router(self, router: ResponseRouter) -> None:
self._response_routers.append(router)

async def __aenter__(self) -> Self:
if self._started:
raise RuntimeError("Session is already running")
self._task_group = anyio.create_task_group()
await self._task_group.__aenter__()
self._started = True
self._task_group.start_soon(self._receive_loop)
return self

Expand All @@ -234,7 +247,10 @@ async def __aexit__(
# would be very surprising behavior), so make sure to cancel the tasks
# in the task group.
self._task_group.cancel_scope.cancel()
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
try:
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
finally:
self._started = False

async def send_request(
self,
Expand All @@ -251,6 +267,7 @@ async def send_request(

Do not use this method to emit notifications! Use send_notification() instead.
"""
self._require_started()
request_id = self._request_id
self._request_id = request_id + 1

Expand Down Expand Up @@ -313,6 +330,7 @@ async def send_notification(
related_request_id: RequestId | None = None,
) -> None:
"""Emits a notification, which is a one-way message that does not expect a response."""
self._require_started()
# Some transport implementations may need to set the related_request_id
# to attribute to the notifications to the request that triggered them.
jsonrpc_notification = JSONRPCNotification(
Expand Down
7 changes: 4 additions & 3 deletions tests/client/test_resource_cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,10 @@ async def mock_send(*args: Any, **kwargs: Any):
initial_stream_count = len(session._response_streams)

# Run the test with the patched method
with patch.object(session._write_stream, "send", mock_send):
with pytest.raises(RuntimeError):
await session.send_request(request, EmptyResult)
async with session:
with patch.object(session._write_stream, "send", mock_send):
with pytest.raises(RuntimeError, match="Simulated network error"): # pragma: no branch
await session.send_request(request, EmptyResult)

# Verify that no response streams were leaked
assert len(session._response_streams) == initial_stream_count, (
Expand Down
37 changes: 37 additions & 0 deletions tests/client/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,43 @@ async def message_handler( # pragma: no cover
assert isinstance(initialized_notification, InitializedNotification)


@pytest.mark.anyio
async def test_client_session_requires_context_manager():
client_to_server_send, _client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
_server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)

async with (
client_to_server_send,
_client_to_server_receive,
_server_to_client_send,
server_to_client_receive,
):
session = ClientSession(server_to_client_receive, client_to_server_send)

with pytest.raises(RuntimeError, match="async context manager"):
await session.initialize()


@pytest.mark.anyio
async def test_client_session_reentry_raises_runtime_error():
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)

async with (
client_to_server_send,
client_to_server_receive,
server_to_client_send,
server_to_client_receive,
):
session = ClientSession(server_to_client_receive, client_to_server_send)
await session.__aenter__()
try:
with pytest.raises(RuntimeError, match="already running"):
await session.__aenter__()
finally:
await session.__aexit__(None, None, None)


@pytest.mark.anyio
async def test_client_session_custom_client_info():
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
Expand Down
Loading