From 6fef500116b353a8c00a26bbbe24bca378bb6a8c Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Mon, 16 Mar 2026 08:49:55 +0100 Subject: [PATCH 1/5] Propagate contextvars.Context through anyio streams without modifying SessionMessage Introduce context-aware stream wrappers (ContextSendStream / ContextReceiveStream) that capture the sender's contextvars.Context at send() time and expose it on the receive side via last_context. This enables OpenTelemetry trace propagation, per-request auth via ContextVars, and other context-dependent use cases across the anyio memory stream boundary - without adding any field to SessionMessage. Key changes: - New _context_streams module with ContextSendStream, ContextReceiveStream, and create_context_streams factory (mirrors anyio's bracket syntax API) - Protocol-based ReadStream/WriteStream in _transport.py, replacing concrete MemoryObjectReceiveStream/MemoryObjectSendStream in all parameter types - All transport stream creation sites use create_context_streams - BaseSession._receive_loop and client post_writers restore sender context via ctx.run(tg.start_soon, handler, message) - RequestResponder carries context for the session-to-server handler boundary Github-Issue:#1996 --- .../mcp_simple_auth_client/main.py | 6 +- src/mcp/client/__main__.py | 6 +- src/mcp/client/_transport.py | 52 ++++++- src/mcp/client/session.py | 6 +- src/mcp/client/sse.py | 23 ++-- src/mcp/client/streamable_http.py | 23 +++- src/mcp/server/lowlevel/server.py | 15 +- src/mcp/server/session.py | 7 +- src/mcp/server/sse.py | 13 +- src/mcp/server/stdio.py | 12 +- src/mcp/server/streamable_http.py | 18 +-- src/mcp/server/websocket.py | 12 +- src/mcp/shared/_context_streams.py | 129 ++++++++++++++++++ src/mcp/shared/memory.py | 10 +- src/mcp/shared/session.py | 30 +++- tests/client/conftest.py | 4 +- tests/shared/test_streamable_http.py | 10 +- 17 files changed, 286 insertions(+), 90 deletions(-) create mode 100644 src/mcp/shared/_context_streams.py diff --git a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py index 5fac56be5..6ef2f0b11 100644 --- a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py +++ b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py @@ -18,7 +18,7 @@ from urllib.parse import parse_qs, urlparse import httpx -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from mcp.client._transport import ReadStream, WriteStream from mcp.client.auth import OAuthClientProvider, TokenStorage from mcp.client.session import ClientSession from mcp.client.sse import sse_client @@ -241,8 +241,8 @@ async def _default_redirect_handler(authorization_url: str) -> None: async def _run_session( self, - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], - write_stream: MemoryObjectSendStream[SessionMessage], + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], ): """Run the MCP session with the given streams.""" print("🤝 Initializing MCP session...") diff --git a/src/mcp/client/__main__.py b/src/mcp/client/__main__.py index f3db17906..b9ec34422 100644 --- a/src/mcp/client/__main__.py +++ b/src/mcp/client/__main__.py @@ -6,9 +6,9 @@ from urllib.parse import urlparse import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp import types +from mcp.client._transport import ReadStream, WriteStream from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.client.stdio import StdioServerParameters, stdio_client @@ -33,8 +33,8 @@ async def message_handler( async def run_session( - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], - write_stream: MemoryObjectSendStream[SessionMessage], + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], client_info: types.Implementation | None = None, ): async with ClientSession( diff --git a/src/mcp/client/_transport.py b/src/mcp/client/_transport.py index a86362900..ac1726be7 100644 --- a/src/mcp/client/_transport.py +++ b/src/mcp/client/_transport.py @@ -3,13 +3,59 @@ from __future__ import annotations from contextlib import AbstractAsyncContextManager -from typing import Protocol +from types import TracebackType +from typing import Protocol, TypeVar, runtime_checkable -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from typing_extensions import Self from mcp.shared.message import SessionMessage -TransportStreams = tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]] +T_co = TypeVar("T_co", covariant=True) +T_contra = TypeVar("T_contra", contravariant=True) + + +@runtime_checkable +class ReadStream(Protocol[T_co]): + """Protocol for reading items from a stream. + + Both ``MemoryObjectReceiveStream`` and ``ContextReceiveStream`` satisfy + this protocol. Consumers that need the sender's context should use + ``getattr(stream, 'last_context', None)``. + """ + + async def receive(self) -> T_co: ... + async def aclose(self) -> None: ... + def __aiter__(self) -> ReadStream[T_co]: ... + async def __anext__(self) -> T_co: ... + async def __aenter__(self) -> Self: ... + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: ... + + +@runtime_checkable +class WriteStream(Protocol[T_contra]): + """Protocol for writing items to a stream. + + Both ``MemoryObjectSendStream`` and ``ContextSendStream`` satisfy + this protocol. + """ + + async def send(self, item: T_contra, /) -> None: ... + async def aclose(self) -> None: ... + async def __aenter__(self) -> Self: ... + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: ... + + +TransportStreams = tuple[ReadStream[SessionMessage | Exception], WriteStream[SessionMessage]] class Transport(AbstractAsyncContextManager[TransportStreams], Protocol): diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index a0ca751bd..47dadd77b 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -4,10 +4,10 @@ from typing import Any, Protocol import anyio.lowlevel -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import TypeAdapter from mcp import types +from mcp.client._transport import ReadStream, WriteStream from mcp.client.experimental import ExperimentalClientFeatures from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers from mcp.shared._context import RequestContext @@ -109,8 +109,8 @@ class ClientSession( ): def __init__( self, - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], - write_stream: MemoryObjectSendStream[SessionMessage], + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], read_timeout_seconds: float | None = None, sampling_callback: SamplingFnT | None = None, elicitation_callback: ElicitationFnT | None = None, diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 972efce58..e1e886cc1 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -7,11 +7,11 @@ import anyio import httpx from anyio.abc import TaskStatus -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from httpx_sse import aconnect_sse from httpx_sse._exceptions import SSEError from mcp import types +from mcp.shared._context_streams import create_context_streams from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client from mcp.shared.message import SessionMessage @@ -51,14 +51,8 @@ async def sse_client( auth: Optional HTTPX authentication handler. on_session_created: Optional callback invoked with the session ID when received. """ - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] - read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - - write_stream: MemoryObjectSendStream[SessionMessage] - write_stream_reader: MemoryObjectReceiveStream[SessionMessage] - - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) + write_stream, write_stream_reader = create_context_streams[SessionMessage](0) async with anyio.create_task_group() as tg: try: @@ -132,7 +126,8 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED): async def post_writer(endpoint_url: str): try: async with write_stream_reader: - async for session_message in write_stream_reader: + + async def _send_message(session_message: SessionMessage) -> None: logger.debug(f"Sending client message: {session_message}") response = await client.post( endpoint_url, @@ -144,6 +139,14 @@ async def post_writer(endpoint_url: str): ) response.raise_for_status() logger.debug(f"Client message sent successfully: {response.status_code}") + + async for session_message in write_stream_reader: + sender_ctx = write_stream_reader.last_context + if sender_ctx is not None: + async with anyio.create_task_group() as tg: + sender_ctx.run(tg.start_soon, _send_message, session_message) + else: + await _send_message(session_message) # pragma: no cover except Exception: # pragma: lax no cover logger.exception("Error in post_writer") finally: diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 3416bbc81..d178953f6 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -11,11 +11,11 @@ import anyio import httpx from anyio.abc import TaskGroup -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from httpx_sse import EventSource, ServerSentEvent, aconnect_sse from pydantic import ValidationError from mcp.client._transport import TransportStreams +from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams from mcp.shared._httpx_utils import create_mcp_http_client from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.types import ( @@ -38,8 +38,8 @@ # TODO(Marcelo): Put the TransportStreams in a module under shared, so we can import here. SessionMessageOrError = SessionMessage | Exception -StreamWriter = MemoryObjectSendStream[SessionMessageOrError] -StreamReader = MemoryObjectReceiveStream[SessionMessage] +StreamWriter = ContextSendStream[SessionMessageOrError] +StreamReader = ContextReceiveStream[SessionMessage] MCP_SESSION_ID = "mcp-session-id" MCP_PROTOCOL_VERSION = "mcp-protocol-version" @@ -434,14 +434,15 @@ async def post_writer( client: httpx.AsyncClient, write_stream_reader: StreamReader, read_stream_writer: StreamWriter, - write_stream: MemoryObjectSendStream[SessionMessage], + write_stream: ContextSendStream[SessionMessage], start_get_stream: Callable[[], None], tg: TaskGroup, ) -> None: """Handle writing requests to the server.""" try: async with write_stream_reader: - async for session_message in write_stream_reader: + + async def _handle_message(session_message: SessionMessage) -> None: message = session_message.message metadata = ( session_message.metadata @@ -478,6 +479,14 @@ async def handle_request_async(): else: await handle_request_async() + async for session_message in write_stream_reader: + sender_ctx = write_stream_reader.last_context + if sender_ctx is not None: + async with anyio.create_task_group() as tg_local: + sender_ctx.run(tg_local.start_soon, _handle_message, session_message) + else: + await _handle_message(session_message) # pragma: no cover + except Exception: # pragma: lax no cover logger.exception("Error in post_writer") finally: @@ -533,8 +542,8 @@ async def streamable_http_client( Example: See examples/snippets/clients/ for usage patterns. """ - read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) - write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) + read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) + write_stream, write_stream_reader = create_context_streams[SessionMessage](0) # Determine if we need to create and manage the client client_provided = http_client is not None diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 167f34b8b..259029ccf 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -36,6 +36,7 @@ async def main(): from __future__ import annotations +import contextvars import logging import warnings from collections.abc import AsyncIterator, Awaitable, Callable @@ -44,7 +45,6 @@ async def main(): from typing import Any, Generic import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware @@ -52,6 +52,7 @@ async def main(): from typing_extensions import TypeVar from mcp import types +from mcp.client._transport import ReadStream, WriteStream from mcp.server.auth.middleware.auth_context import AuthContextMiddleware from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier @@ -355,8 +356,8 @@ def session_manager(self) -> StreamableHTTPSessionManager: async def run( self, - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], - write_stream: MemoryObjectSendStream[SessionMessage], + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], initialization_options: InitializationOptions, # When False, exceptions are returned as messages to the client. # When True, exceptions are raised, which will cause the server to shut down @@ -390,7 +391,13 @@ async def run( async for message in session.incoming_messages: logger.debug("Received message: %s", message) - tg.start_soon( + if isinstance(message, RequestResponder) and message.context is not None: + context = message.context + else: + context = contextvars.copy_context() + + context.run( + tg.start_soon, self._handle_message, message, session, diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 759d2131a..1fbb65955 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -33,10 +33,11 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult: import anyio import anyio.lowlevel -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from anyio.streams.memory import MemoryObjectReceiveStream from pydantic import AnyUrl, TypeAdapter from mcp import types +from mcp.client._transport import ReadStream, WriteStream from mcp.server.experimental.session_features import ExperimentalServerSessionFeatures from mcp.server.models import InitializationOptions from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages @@ -79,8 +80,8 @@ class ServerSession( def __init__( self, - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], - write_stream: MemoryObjectSendStream[SessionMessage], + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], init_options: InitializationOptions, stateless: bool = False, ) -> None: diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 9dcee67f7..48192ff61 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -43,7 +43,6 @@ async def handle_sse(request): from uuid import UUID, uuid4 import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import ValidationError from sse_starlette import EventSourceResponse from starlette.requests import Request @@ -55,6 +54,7 @@ async def handle_sse(request): TransportSecurityMiddleware, TransportSecuritySettings, ) +from mcp.shared._context_streams import ContextSendStream, create_context_streams from mcp.shared.message import ServerMessageMetadata, SessionMessage logger = logging.getLogger(__name__) @@ -72,7 +72,7 @@ class SseServerTransport: """ _endpoint: str - _read_stream_writers: dict[UUID, MemoryObjectSendStream[SessionMessage | Exception]] + _read_stream_writers: dict[UUID, ContextSendStream[SessionMessage | Exception]] _security: TransportSecurityMiddleware def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | None = None) -> None: @@ -129,14 +129,9 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): # prag raise ValueError("Request validation failed") logger.debug("Setting up SSE connection") - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] - read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - write_stream: MemoryObjectSendStream[SessionMessage] - write_stream_reader: MemoryObjectReceiveStream[SessionMessage] - - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) + write_stream, write_stream_reader = create_context_streams[SessionMessage](0) session_id = uuid4() self._read_stream_writers[session_id] = read_stream_writer diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index e526bab56..119676b18 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -23,9 +23,9 @@ async def run_server(): import anyio import anyio.lowlevel -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp import types +from mcp.shared._context_streams import create_context_streams from mcp.shared.message import SessionMessage @@ -43,14 +43,8 @@ async def stdio_server(stdin: anyio.AsyncFile[str] | None = None, stdout: anyio. if not stdout: stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8")) - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] - read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - - write_stream: MemoryObjectSendStream[SessionMessage] - write_stream_reader: MemoryObjectReceiveStream[SessionMessage] - - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) + write_stream, write_stream_reader = create_context_streams[SessionMessage](0) async def stdin_reader(): try: diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index aa99e7c88..e16b1d77b 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -24,7 +24,9 @@ from starlette.responses import Response from starlette.types import Receive, Scope, Send +from mcp.client._transport import ReadStream, WriteStream from mcp.server.transport_security import TransportSecurityMiddleware, TransportSecuritySettings +from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS from mcp.types import ( @@ -119,10 +121,10 @@ class StreamableHTTPServerTransport: """ # Server notification streams for POST requests as well as standalone SSE stream - _read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = None - _read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None - _write_stream: MemoryObjectSendStream[SessionMessage] | None = None - _write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None + _read_stream_writer: ContextSendStream[SessionMessage | Exception] | None = None + _read_stream: ContextReceiveStream[SessionMessage | Exception] | None = None + _write_stream: ContextSendStream[SessionMessage] | None = None + _write_stream_reader: ContextReceiveStream[SessionMessage] | None = None _security: TransportSecurityMiddleware def __init__( @@ -954,8 +956,8 @@ async def connect( self, ) -> AsyncGenerator[ tuple[ - MemoryObjectReceiveStream[SessionMessage | Exception], - MemoryObjectSendStream[SessionMessage], + ReadStream[SessionMessage | Exception], + WriteStream[SessionMessage], ], None, ]: @@ -967,8 +969,8 @@ async def connect( # Create the memory streams for this connection - read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) - write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) + read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) + write_stream, write_stream_reader = create_context_streams[SessionMessage](0) # Store the streams self._read_stream_writer = read_stream_writer diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index 3e675da5f..a1fe64d40 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -1,12 +1,12 @@ from contextlib import asynccontextmanager import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic_core import ValidationError from starlette.types import Receive, Scope, Send from starlette.websockets import WebSocket from mcp import types +from mcp.shared._context_streams import create_context_streams from mcp.shared.message import SessionMessage @@ -19,14 +19,8 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send): websocket = WebSocket(scope, receive, send) await websocket.accept(subprotocol="mcp") - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] - read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - - write_stream: MemoryObjectSendStream[SessionMessage] - write_stream_reader: MemoryObjectReceiveStream[SessionMessage] - - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) + write_stream, write_stream_reader = create_context_streams[SessionMessage](0) async def ws_reader(): try: diff --git a/src/mcp/shared/_context_streams.py b/src/mcp/shared/_context_streams.py new file mode 100644 index 000000000..69365a1d1 --- /dev/null +++ b/src/mcp/shared/_context_streams.py @@ -0,0 +1,129 @@ +"""Context-aware memory stream wrappers. + +anyio memory streams do not propagate ``contextvars.Context`` across task +boundaries. These thin wrappers capture the sender's context at ``send()`` +time and expose it on the receive side via ``last_context``, so consumers +can restore it with ``ctx.run(handler, item)``. + +The iteration interface is unchanged (yields ``T``, not tuples), keeping +these wrappers duck-type compatible with plain ``MemoryObjectSendStream`` +and ``MemoryObjectReceiveStream``. +""" + +from __future__ import annotations + +import contextvars +from types import TracebackType +from typing import Any, Generic, TypeVar + +import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream + +T = TypeVar("T") +T_Item = TypeVar("T_Item") + +# Internal payload carried through the underlying raw stream. +_Envelope = tuple[contextvars.Context, T] + + +class ContextSendStream(Generic[T]): + """Send-side wrapper that snapshots ``contextvars.copy_context()`` on every ``send()``.""" + + __slots__ = ("_inner",) + + def __init__(self, inner: MemoryObjectSendStream[_Envelope[T]]) -> None: + self._inner = inner + + async def send(self, item: T) -> None: + await self._inner.send((contextvars.copy_context(), item)) + + def close(self) -> None: + self._inner.close() + + async def aclose(self) -> None: + await self._inner.aclose() + + def clone(self) -> ContextSendStream[T]: + return ContextSendStream(self._inner.clone()) + + async def __aenter__(self) -> ContextSendStream[T]: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + await self.aclose() + return None + + +class ContextReceiveStream(Generic[T]): + """Receive-side wrapper that yields ``T`` and stores the sender's context in ``last_context``.""" + + __slots__ = ("_inner", "last_context") + + def __init__(self, inner: MemoryObjectReceiveStream[_Envelope[T]]) -> None: + self._inner = inner + self.last_context: contextvars.Context | None = None + + async def receive(self) -> T: + ctx, item = await self._inner.receive() + self.last_context = ctx + return item + + def close(self) -> None: + self._inner.close() + + async def aclose(self) -> None: + await self._inner.aclose() + + def clone(self) -> ContextReceiveStream[T]: + return ContextReceiveStream(self._inner.clone()) + + def __aiter__(self) -> ContextReceiveStream[T]: + return self + + async def __anext__(self) -> T: + try: + return await self.receive() + except anyio.EndOfStream: + raise StopAsyncIteration + + async def __aenter__(self) -> ContextReceiveStream[T]: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + await self.aclose() + return None + + +def _create_context_streams( + max_buffer_size: float = 0, +) -> tuple[ContextSendStream[Any], ContextReceiveStream[Any]]: + raw_send: MemoryObjectSendStream[Any] + raw_receive: MemoryObjectReceiveStream[Any] + raw_send, raw_receive = anyio.create_memory_object_stream(max_buffer_size) + return ContextSendStream(raw_send), ContextReceiveStream(raw_receive) + + +class _CreateContextStreams: + """Callable that supports ``create_context_streams[T](n)`` bracket syntax. + + Matches anyio's ``create_memory_object_stream`` API style. + """ + + def __getitem__(self, _item: Any) -> _CreateContextStreams: + return self + + def __call__(self, max_buffer_size: float = 0) -> tuple[ContextSendStream[Any], ContextReceiveStream[Any]]: + return _create_context_streams(max_buffer_size) + + +create_context_streams = _CreateContextStreams() diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index f2d5e2b9a..468590d09 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -5,12 +5,10 @@ from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream - +from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams from mcp.shared.message import SessionMessage -MessageStream = tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]] +MessageStream = tuple[ContextReceiveStream[SessionMessage | Exception], ContextSendStream[SessionMessage | Exception]] @asynccontextmanager @@ -22,8 +20,8 @@ async def create_client_server_memory_streams() -> AsyncGenerator[tuple[MessageS (read_stream, write_stream) """ # Create streams for both directions - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) + server_to_client_send, server_to_client_receive = create_context_streams[SessionMessage | Exception](1) + client_to_server_send, client_to_server_receive = create_context_streams[SessionMessage | Exception](1) client_streams = (server_to_client_receive, client_to_server_send) server_streams = (client_to_server_receive, server_to_client_send) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 9364abb73..6f6513472 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextvars import logging from collections.abc import Callable from contextlib import AsyncExitStack @@ -7,10 +8,11 @@ from typing import Any, Generic, Protocol, TypeVar import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from anyio.streams.memory import MemoryObjectSendStream from pydantic import BaseModel, TypeAdapter from typing_extensions import Self +from mcp.client._transport import ReadStream, WriteStream from mcp.shared.exceptions import MCPError from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage from mcp.shared.response_router import ResponseRouter @@ -79,11 +81,13 @@ def __init__( session: BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT], on_complete: Callable[[RequestResponder[ReceiveRequestT, SendResultT]], Any], message_metadata: MessageMetadata = None, + context: contextvars.Context | None = None, ) -> None: self.request_id = request_id self.request_meta = request_meta self.request = request self.message_metadata = message_metadata + self.context = context self._session = session self._completed = False self._cancel_scope = anyio.CancelScope() @@ -181,8 +185,8 @@ class BaseSession( def __init__( self, - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], - write_stream: MemoryObjectSendStream[SessionMessage], + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], # If none, reading will never time out read_timeout_seconds: float | None = None, ) -> None: @@ -333,10 +337,9 @@ def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]: async def _receive_loop(self) -> None: async with self._read_stream, self._write_stream: try: - async for message in self._read_stream: - if isinstance(message, Exception): - await self._handle_incoming(message) - elif isinstance(message.message, JSONRPCRequest): + + async def _handle_session_message(message: SessionMessage) -> None: + if isinstance(message.message, JSONRPCRequest): try: validated_request = self._receive_request_adapter.validate_python( message.message.model_dump(by_alias=True, mode="json", exclude_none=True), @@ -349,6 +352,7 @@ async def _receive_loop(self) -> None: session=self, on_complete=lambda r: self._in_flight.pop(r.request_id, None), message_metadata=message.metadata, + context=contextvars.copy_context(), ) self._in_flight[responder.request_id] = responder await self._received_request(responder) @@ -406,6 +410,18 @@ async def _receive_loop(self) -> None: else: # Response or error await self._handle_response(message) + async for message in self._read_stream: + if isinstance(message, Exception): + await self._handle_incoming(message) + continue + + sender_ctx: contextvars.Context | None = getattr(self._read_stream, "last_context", None) + if sender_ctx is not None: + async with anyio.create_task_group() as tg: + sender_ctx.run(tg.start_soon, _handle_session_message, message) + else: + await _handle_session_message(message) + except anyio.ClosedResourceError: # This is expected when the client disconnects abruptly. # Without this handler, the exception would propagate up and diff --git a/tests/client/conftest.py b/tests/client/conftest.py index 2e39f1363..081e1d68e 100644 --- a/tests/client/conftest.py +++ b/tests/client/conftest.py @@ -4,15 +4,15 @@ from unittest.mock import patch import pytest -from anyio.streams.memory import MemoryObjectSendStream import mcp.shared.memory +from mcp.client._transport import WriteStream from mcp.shared.message import SessionMessage from mcp.types import JSONRPCNotification, JSONRPCRequest class SpyMemoryObjectSendStream: - def __init__(self, original_stream: MemoryObjectSendStream[SessionMessage]): + def __init__(self, original_stream: WriteStream[SessionMessage]): self.original_stream = original_stream self.sent_messages: list[SessionMessage] = [] diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index f8ca30441..3d5770fb6 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -45,6 +45,7 @@ from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings from mcp.shared._context import RequestContext +from mcp.shared._context_streams import create_context_streams from mcp.shared._httpx_utils import ( MCP_DEFAULT_SSE_READ_TIMEOUT, MCP_DEFAULT_TIMEOUT, @@ -1783,8 +1784,8 @@ async def test_handle_sse_event_skips_empty_data(): # Create a mock SSE event with empty data (keep-alive ping) mock_sse = ServerSentEvent(event="message", data="", id=None, retry=None) - # Create a mock stream writer - write_stream, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](1) + # Create a context-aware stream writer (matches StreamWriter type alias) + write_stream, read_stream = create_context_streams[SessionMessage | Exception](1) try: # Call _handle_sse_event with empty data - should return False and not raise @@ -1794,8 +1795,9 @@ async def test_handle_sse_event_skips_empty_data(): assert result is False # Nothing should have been written to the stream - # Check buffer is empty (statistics().current_buffer_used returns buffer size) - assert write_stream.statistics().current_buffer_used == 0 + with pytest.raises(TimeoutError): + with anyio.fail_after(0): + await read_stream.receive() finally: await write_stream.aclose() await read_stream.aclose() From 2513380f03bba03cb2e5d5fa5c82a68dcbf9965b Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Mon, 16 Mar 2026 08:59:14 +0100 Subject: [PATCH 2/5] Add pragma: no cover to Protocol stubs and unused clone() methods --- src/mcp/client/_transport.py | 4 ++-- src/mcp/shared/_context_streams.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/mcp/client/_transport.py b/src/mcp/client/_transport.py index ac1726be7..ef0b1bb4e 100644 --- a/src/mcp/client/_transport.py +++ b/src/mcp/client/_transport.py @@ -15,7 +15,7 @@ @runtime_checkable -class ReadStream(Protocol[T_co]): +class ReadStream(Protocol[T_co]): # pragma: no cover """Protocol for reading items from a stream. Both ``MemoryObjectReceiveStream`` and ``ContextReceiveStream`` satisfy @@ -37,7 +37,7 @@ async def __aexit__( @runtime_checkable -class WriteStream(Protocol[T_contra]): +class WriteStream(Protocol[T_contra]): # pragma: no cover """Protocol for writing items to a stream. Both ``MemoryObjectSendStream`` and ``ContextSendStream`` satisfy diff --git a/src/mcp/shared/_context_streams.py b/src/mcp/shared/_context_streams.py index 69365a1d1..25c924bf5 100644 --- a/src/mcp/shared/_context_streams.py +++ b/src/mcp/shared/_context_streams.py @@ -43,7 +43,7 @@ def close(self) -> None: async def aclose(self) -> None: await self._inner.aclose() - def clone(self) -> ContextSendStream[T]: + def clone(self) -> ContextSendStream[T]: # pragma: no cover return ContextSendStream(self._inner.clone()) async def __aenter__(self) -> ContextSendStream[T]: @@ -79,7 +79,7 @@ def close(self) -> None: async def aclose(self) -> None: await self._inner.aclose() - def clone(self) -> ContextReceiveStream[T]: + def clone(self) -> ContextReceiveStream[T]: # pragma: no cover return ContextReceiveStream(self._inner.clone()) def __aiter__(self) -> ContextReceiveStream[T]: From 1b34f2eb21c5382f83b712dce2952fcc15dff6c0 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Mon, 16 Mar 2026 09:13:38 +0100 Subject: [PATCH 3/5] Use pragma: no branch for Protocol stubs to satisfy strict-no-cover --- src/mcp/client/_transport.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/mcp/client/_transport.py b/src/mcp/client/_transport.py index ef0b1bb4e..076387533 100644 --- a/src/mcp/client/_transport.py +++ b/src/mcp/client/_transport.py @@ -15,7 +15,7 @@ @runtime_checkable -class ReadStream(Protocol[T_co]): # pragma: no cover +class ReadStream(Protocol[T_co]): # pragma: no branch """Protocol for reading items from a stream. Both ``MemoryObjectReceiveStream`` and ``ContextReceiveStream`` satisfy @@ -23,12 +23,12 @@ class ReadStream(Protocol[T_co]): # pragma: no cover ``getattr(stream, 'last_context', None)``. """ - async def receive(self) -> T_co: ... - async def aclose(self) -> None: ... - def __aiter__(self) -> ReadStream[T_co]: ... - async def __anext__(self) -> T_co: ... - async def __aenter__(self) -> Self: ... - async def __aexit__( + async def receive(self) -> T_co: ... # pragma: no branch + async def aclose(self) -> None: ... # pragma: no branch + def __aiter__(self) -> ReadStream[T_co]: ... # pragma: no branch + async def __anext__(self) -> T_co: ... # pragma: no branch + async def __aenter__(self) -> Self: ... # pragma: no branch + async def __aexit__( # pragma: no branch self, exc_type: type[BaseException] | None, exc_val: BaseException | None, @@ -37,17 +37,17 @@ async def __aexit__( @runtime_checkable -class WriteStream(Protocol[T_contra]): # pragma: no cover +class WriteStream(Protocol[T_contra]): # pragma: no branch """Protocol for writing items to a stream. Both ``MemoryObjectSendStream`` and ``ContextSendStream`` satisfy this protocol. """ - async def send(self, item: T_contra, /) -> None: ... - async def aclose(self) -> None: ... - async def __aenter__(self) -> Self: ... - async def __aexit__( + async def send(self, item: T_contra, /) -> None: ... # pragma: no branch + async def aclose(self) -> None: ... # pragma: no branch + async def __aenter__(self) -> Self: ... # pragma: no branch + async def __aexit__( # pragma: no branch self, exc_type: type[BaseException] | None, exc_val: BaseException | None, From 32565b3ebe3b0aba566a7150f637b0b3f86af7ac Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Mon, 16 Mar 2026 09:30:09 +0100 Subject: [PATCH 4/5] Address review feedback: ExceptionGroup bug, reverse dependency, type widening - Replace task group + ctx.run(tg.start_soon, ...) with direct await sender_ctx.run(handler, msg) to avoid ExceptionGroup wrapping that would prevent ClosedResourceError from being caught - Move ReadStream/WriteStream protocols to mcp.shared._stream_protocols so shared/server modules don't depend on client internals - Restore write stream type narrowing in MessageStream (SessionMessage only, not SessionMessage | Exception) - Remove unused T_Item TypeVar --- src/mcp/client/_transport.py | 51 ++--------------------------- src/mcp/client/sse.py | 3 +- src/mcp/client/streamable_http.py | 3 +- src/mcp/server/lowlevel/server.py | 2 +- src/mcp/server/session.py | 2 +- src/mcp/server/streamable_http.py | 2 +- src/mcp/shared/_context_streams.py | 1 - src/mcp/shared/_stream_protocols.py | 51 +++++++++++++++++++++++++++++ src/mcp/shared/memory.py | 2 +- src/mcp/shared/session.py | 6 ++-- 10 files changed, 63 insertions(+), 60 deletions(-) create mode 100644 src/mcp/shared/_stream_protocols.py diff --git a/src/mcp/client/_transport.py b/src/mcp/client/_transport.py index 076387533..0163fef95 100644 --- a/src/mcp/client/_transport.py +++ b/src/mcp/client/_transport.py @@ -3,57 +3,12 @@ from __future__ import annotations from contextlib import AbstractAsyncContextManager -from types import TracebackType -from typing import Protocol, TypeVar, runtime_checkable - -from typing_extensions import Self +from typing import Protocol +from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.message import SessionMessage -T_co = TypeVar("T_co", covariant=True) -T_contra = TypeVar("T_contra", contravariant=True) - - -@runtime_checkable -class ReadStream(Protocol[T_co]): # pragma: no branch - """Protocol for reading items from a stream. - - Both ``MemoryObjectReceiveStream`` and ``ContextReceiveStream`` satisfy - this protocol. Consumers that need the sender's context should use - ``getattr(stream, 'last_context', None)``. - """ - - async def receive(self) -> T_co: ... # pragma: no branch - async def aclose(self) -> None: ... # pragma: no branch - def __aiter__(self) -> ReadStream[T_co]: ... # pragma: no branch - async def __anext__(self) -> T_co: ... # pragma: no branch - async def __aenter__(self) -> Self: ... # pragma: no branch - async def __aexit__( # pragma: no branch - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> bool | None: ... - - -@runtime_checkable -class WriteStream(Protocol[T_contra]): # pragma: no branch - """Protocol for writing items to a stream. - - Both ``MemoryObjectSendStream`` and ``ContextSendStream`` satisfy - this protocol. - """ - - async def send(self, item: T_contra, /) -> None: ... # pragma: no branch - async def aclose(self) -> None: ... # pragma: no branch - async def __aenter__(self) -> Self: ... # pragma: no branch - async def __aexit__( # pragma: no branch - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> bool | None: ... - +__all__ = ["ReadStream", "WriteStream", "Transport", "TransportStreams"] TransportStreams = tuple[ReadStream[SessionMessage | Exception], WriteStream[SessionMessage]] diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index e1e886cc1..b7213b828 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -143,8 +143,7 @@ async def _send_message(session_message: SessionMessage) -> None: async for session_message in write_stream_reader: sender_ctx = write_stream_reader.last_context if sender_ctx is not None: - async with anyio.create_task_group() as tg: - sender_ctx.run(tg.start_soon, _send_message, session_message) + await sender_ctx.run(_send_message, session_message) else: await _send_message(session_message) # pragma: no cover except Exception: # pragma: lax no cover diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index d178953f6..e025e55b8 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -482,8 +482,7 @@ async def handle_request_async(): async for session_message in write_stream_reader: sender_ctx = write_stream_reader.last_context if sender_ctx is not None: - async with anyio.create_task_group() as tg_local: - sender_ctx.run(tg_local.start_soon, _handle_message, session_message) + await sender_ctx.run(_handle_message, session_message) else: await _handle_message(session_message) # pragma: no cover diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 259029ccf..b6dcc398d 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -52,7 +52,6 @@ async def main(): from typing_extensions import TypeVar from mcp import types -from mcp.client._transport import ReadStream, WriteStream from mcp.server.auth.middleware.auth_context import AuthContextMiddleware from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier @@ -66,6 +65,7 @@ async def main(): from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings +from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.exceptions import MCPError from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 1fbb65955..2d75e9b9b 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -37,10 +37,10 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult: from pydantic import AnyUrl, TypeAdapter from mcp import types -from mcp.client._transport import ReadStream, WriteStream from mcp.server.experimental.session_features import ExperimentalServerSessionFeatures from mcp.server.models import InitializationOptions from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages +from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.exceptions import StatelessModeNotSupported from mcp.shared.experimental.tasks.capabilities import check_tasks_capability from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index e16b1d77b..f14201857 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -24,9 +24,9 @@ from starlette.responses import Response from starlette.types import Receive, Scope, Send -from mcp.client._transport import ReadStream, WriteStream from mcp.server.transport_security import TransportSecurityMiddleware, TransportSecuritySettings from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams +from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS from mcp.types import ( diff --git a/src/mcp/shared/_context_streams.py b/src/mcp/shared/_context_streams.py index 25c924bf5..8c3e98df4 100644 --- a/src/mcp/shared/_context_streams.py +++ b/src/mcp/shared/_context_streams.py @@ -20,7 +20,6 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream T = TypeVar("T") -T_Item = TypeVar("T_Item") # Internal payload carried through the underlying raw stream. _Envelope = tuple[contextvars.Context, T] diff --git a/src/mcp/shared/_stream_protocols.py b/src/mcp/shared/_stream_protocols.py new file mode 100644 index 000000000..f7da4fc20 --- /dev/null +++ b/src/mcp/shared/_stream_protocols.py @@ -0,0 +1,51 @@ +"""Stream protocols for MCP transports. + +These are general-purpose protocols satisfied by both ``MemoryObjectSendStream``/ +``MemoryObjectReceiveStream`` and the context-aware wrappers in ``_context_streams``. +""" + +from __future__ import annotations + +from types import TracebackType +from typing import Protocol, TypeVar, runtime_checkable + +from typing_extensions import Self + +T_co = TypeVar("T_co", covariant=True) +T_contra = TypeVar("T_contra", contravariant=True) + + +@runtime_checkable +class ReadStream(Protocol[T_co]): # pragma: no branch + """Protocol for reading items from a stream. + + Consumers that need the sender's context should use + ``getattr(stream, 'last_context', None)``. + """ + + async def receive(self) -> T_co: ... # pragma: no branch + async def aclose(self) -> None: ... # pragma: no branch + def __aiter__(self) -> ReadStream[T_co]: ... # pragma: no branch + async def __anext__(self) -> T_co: ... # pragma: no branch + async def __aenter__(self) -> Self: ... # pragma: no branch + async def __aexit__( # pragma: no branch + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: ... + + +@runtime_checkable +class WriteStream(Protocol[T_contra]): # pragma: no branch + """Protocol for writing items to a stream.""" + + async def send(self, item: T_contra, /) -> None: ... # pragma: no branch + async def aclose(self) -> None: ... # pragma: no branch + async def __aenter__(self) -> Self: ... # pragma: no branch + async def __aexit__( # pragma: no branch + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: ... diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index 468590d09..1f2e05cc3 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -8,7 +8,7 @@ from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams from mcp.shared.message import SessionMessage -MessageStream = tuple[ContextReceiveStream[SessionMessage | Exception], ContextSendStream[SessionMessage | Exception]] +MessageStream = tuple[ContextReceiveStream[SessionMessage | Exception], ContextSendStream[SessionMessage]] @asynccontextmanager diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 6f6513472..28dec2211 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -12,7 +12,7 @@ from pydantic import BaseModel, TypeAdapter from typing_extensions import Self -from mcp.client._transport import ReadStream, WriteStream +from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.exceptions import MCPError from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage from mcp.shared.response_router import ResponseRouter @@ -417,8 +417,8 @@ async def _handle_session_message(message: SessionMessage) -> None: sender_ctx: contextvars.Context | None = getattr(self._read_stream, "last_context", None) if sender_ctx is not None: - async with anyio.create_task_group() as tg: - sender_ctx.run(tg.start_soon, _handle_session_message, message) + coro = sender_ctx.run(_handle_session_message, message) + await coro else: await _handle_session_message(message) From b321dd206bb318e4ccec416fd26f07d773b76efb Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Mon, 16 Mar 2026 09:46:12 +0100 Subject: [PATCH 5/5] Restore correct context propagation for async handlers ctx.run(async_fn) only sets context during coroutine creation, not execution. For session._receive_loop, pass sender_ctx explicitly to the handler which stores it on RequestResponder.context. For client post_writers (sse/streamable_http), restore ctx.run(tg.start_soon, ...) so the httpx calls actually run in the sender's context. --- src/mcp/client/sse.py | 3 ++- src/mcp/client/streamable_http.py | 3 ++- src/mcp/shared/session.py | 13 ++++++------- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index b7213b828..e1e886cc1 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -143,7 +143,8 @@ async def _send_message(session_message: SessionMessage) -> None: async for session_message in write_stream_reader: sender_ctx = write_stream_reader.last_context if sender_ctx is not None: - await sender_ctx.run(_send_message, session_message) + async with anyio.create_task_group() as tg: + sender_ctx.run(tg.start_soon, _send_message, session_message) else: await _send_message(session_message) # pragma: no cover except Exception: # pragma: lax no cover diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index e025e55b8..d178953f6 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -482,7 +482,8 @@ async def handle_request_async(): async for session_message in write_stream_reader: sender_ctx = write_stream_reader.last_context if sender_ctx is not None: - await sender_ctx.run(_handle_message, session_message) + async with anyio.create_task_group() as tg_local: + sender_ctx.run(tg_local.start_soon, _handle_message, session_message) else: await _handle_message(session_message) # pragma: no cover diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 28dec2211..22a604c82 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -338,7 +338,10 @@ async def _receive_loop(self) -> None: async with self._read_stream, self._write_stream: try: - async def _handle_session_message(message: SessionMessage) -> None: + async def _handle_session_message( + message: SessionMessage, + sender_context: contextvars.Context | None = None, + ) -> None: if isinstance(message.message, JSONRPCRequest): try: validated_request = self._receive_request_adapter.validate_python( @@ -352,7 +355,7 @@ async def _handle_session_message(message: SessionMessage) -> None: session=self, on_complete=lambda r: self._in_flight.pop(r.request_id, None), message_metadata=message.metadata, - context=contextvars.copy_context(), + context=sender_context, ) self._in_flight[responder.request_id] = responder await self._received_request(responder) @@ -416,11 +419,7 @@ async def _handle_session_message(message: SessionMessage) -> None: continue sender_ctx: contextvars.Context | None = getattr(self._read_stream, "last_context", None) - if sender_ctx is not None: - coro = sender_ctx.run(_handle_session_message, message) - await coro - else: - await _handle_session_message(message) + await _handle_session_message(message, sender_context=sender_ctx) except anyio.ClosedResourceError: # This is expected when the client disconnects abruptly.