Skip to content

Commit 6fef500

Browse files
committed
Propagate contextvars.Context through anyio streams without modifying SessionMessage
Introduce context-aware stream wrappers (ContextSendStream / ContextReceiveStream) that capture the sender's contextvars.Context at send() time and expose it on the receive side via last_context. This enables OpenTelemetry trace propagation, per-request auth via ContextVars, and other context-dependent use cases across the anyio memory stream boundary - without adding any field to SessionMessage. Key changes: - New _context_streams module with ContextSendStream, ContextReceiveStream, and create_context_streams factory (mirrors anyio's bracket syntax API) - Protocol-based ReadStream/WriteStream in _transport.py, replacing concrete MemoryObjectReceiveStream/MemoryObjectSendStream in all parameter types - All transport stream creation sites use create_context_streams - BaseSession._receive_loop and client post_writers restore sender context via ctx.run(tg.start_soon, handler, message) - RequestResponder carries context for the session-to-server handler boundary Github-Issue:#1996
1 parent e1fd62e commit 6fef500

File tree

17 files changed

+286
-90
lines changed

17 files changed

+286
-90
lines changed

examples/clients/simple-auth-client/mcp_simple_auth_client/main.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from urllib.parse import parse_qs, urlparse
1919

2020
import httpx
21-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
21+
from mcp.client._transport import ReadStream, WriteStream
2222
from mcp.client.auth import OAuthClientProvider, TokenStorage
2323
from mcp.client.session import ClientSession
2424
from mcp.client.sse import sse_client
@@ -241,8 +241,8 @@ async def _default_redirect_handler(authorization_url: str) -> None:
241241

242242
async def _run_session(
243243
self,
244-
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
245-
write_stream: MemoryObjectSendStream[SessionMessage],
244+
read_stream: ReadStream[SessionMessage | Exception],
245+
write_stream: WriteStream[SessionMessage],
246246
):
247247
"""Run the MCP session with the given streams."""
248248
print("🤝 Initializing MCP session...")

src/mcp/client/__main__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
from urllib.parse import urlparse
77

88
import anyio
9-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
109

1110
from mcp import types
11+
from mcp.client._transport import ReadStream, WriteStream
1212
from mcp.client.session import ClientSession
1313
from mcp.client.sse import sse_client
1414
from mcp.client.stdio import StdioServerParameters, stdio_client
@@ -33,8 +33,8 @@ async def message_handler(
3333

3434

3535
async def run_session(
36-
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
37-
write_stream: MemoryObjectSendStream[SessionMessage],
36+
read_stream: ReadStream[SessionMessage | Exception],
37+
write_stream: WriteStream[SessionMessage],
3838
client_info: types.Implementation | None = None,
3939
):
4040
async with ClientSession(

src/mcp/client/_transport.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,59 @@
33
from __future__ import annotations
44

55
from contextlib import AbstractAsyncContextManager
6-
from typing import Protocol
6+
from types import TracebackType
7+
from typing import Protocol, TypeVar, runtime_checkable
78

8-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
9+
from typing_extensions import Self
910

1011
from mcp.shared.message import SessionMessage
1112

12-
TransportStreams = tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]]
13+
T_co = TypeVar("T_co", covariant=True)
14+
T_contra = TypeVar("T_contra", contravariant=True)
15+
16+
17+
@runtime_checkable
18+
class ReadStream(Protocol[T_co]):
19+
"""Protocol for reading items from a stream.
20+
21+
Both ``MemoryObjectReceiveStream`` and ``ContextReceiveStream`` satisfy
22+
this protocol. Consumers that need the sender's context should use
23+
``getattr(stream, 'last_context', None)``.
24+
"""
25+
26+
async def receive(self) -> T_co: ...
27+
async def aclose(self) -> None: ...
28+
def __aiter__(self) -> ReadStream[T_co]: ...
29+
async def __anext__(self) -> T_co: ...
30+
async def __aenter__(self) -> Self: ...
31+
async def __aexit__(
32+
self,
33+
exc_type: type[BaseException] | None,
34+
exc_val: BaseException | None,
35+
exc_tb: TracebackType | None,
36+
) -> bool | None: ...
37+
38+
39+
@runtime_checkable
40+
class WriteStream(Protocol[T_contra]):
41+
"""Protocol for writing items to a stream.
42+
43+
Both ``MemoryObjectSendStream`` and ``ContextSendStream`` satisfy
44+
this protocol.
45+
"""
46+
47+
async def send(self, item: T_contra, /) -> None: ...
48+
async def aclose(self) -> None: ...
49+
async def __aenter__(self) -> Self: ...
50+
async def __aexit__(
51+
self,
52+
exc_type: type[BaseException] | None,
53+
exc_val: BaseException | None,
54+
exc_tb: TracebackType | None,
55+
) -> bool | None: ...
56+
57+
58+
TransportStreams = tuple[ReadStream[SessionMessage | Exception], WriteStream[SessionMessage]]
1359

1460

1561
class Transport(AbstractAsyncContextManager[TransportStreams], Protocol):

src/mcp/client/session.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
from typing import Any, Protocol
55

66
import anyio.lowlevel
7-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
87
from pydantic import TypeAdapter
98

109
from mcp import types
10+
from mcp.client._transport import ReadStream, WriteStream
1111
from mcp.client.experimental import ExperimentalClientFeatures
1212
from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers
1313
from mcp.shared._context import RequestContext
@@ -109,8 +109,8 @@ class ClientSession(
109109
):
110110
def __init__(
111111
self,
112-
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
113-
write_stream: MemoryObjectSendStream[SessionMessage],
112+
read_stream: ReadStream[SessionMessage | Exception],
113+
write_stream: WriteStream[SessionMessage],
114114
read_timeout_seconds: float | None = None,
115115
sampling_callback: SamplingFnT | None = None,
116116
elicitation_callback: ElicitationFnT | None = None,

src/mcp/client/sse.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
import anyio
88
import httpx
99
from anyio.abc import TaskStatus
10-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1110
from httpx_sse import aconnect_sse
1211
from httpx_sse._exceptions import SSEError
1312

1413
from mcp import types
14+
from mcp.shared._context_streams import create_context_streams
1515
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
1616
from mcp.shared.message import SessionMessage
1717

@@ -51,14 +51,8 @@ async def sse_client(
5151
auth: Optional HTTPX authentication handler.
5252
on_session_created: Optional callback invoked with the session ID when received.
5353
"""
54-
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
55-
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
56-
57-
write_stream: MemoryObjectSendStream[SessionMessage]
58-
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
59-
60-
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
61-
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
54+
read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0)
55+
write_stream, write_stream_reader = create_context_streams[SessionMessage](0)
6256

6357
async with anyio.create_task_group() as tg:
6458
try:
@@ -132,7 +126,8 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
132126
async def post_writer(endpoint_url: str):
133127
try:
134128
async with write_stream_reader:
135-
async for session_message in write_stream_reader:
129+
130+
async def _send_message(session_message: SessionMessage) -> None:
136131
logger.debug(f"Sending client message: {session_message}")
137132
response = await client.post(
138133
endpoint_url,
@@ -144,6 +139,14 @@ async def post_writer(endpoint_url: str):
144139
)
145140
response.raise_for_status()
146141
logger.debug(f"Client message sent successfully: {response.status_code}")
142+
143+
async for session_message in write_stream_reader:
144+
sender_ctx = write_stream_reader.last_context
145+
if sender_ctx is not None:
146+
async with anyio.create_task_group() as tg:
147+
sender_ctx.run(tg.start_soon, _send_message, session_message)
148+
else:
149+
await _send_message(session_message) # pragma: no cover
147150
except Exception: # pragma: lax no cover
148151
logger.exception("Error in post_writer")
149152
finally:

src/mcp/client/streamable_http.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
import anyio
1212
import httpx
1313
from anyio.abc import TaskGroup
14-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1514
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
1615
from pydantic import ValidationError
1716

1817
from mcp.client._transport import TransportStreams
18+
from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams
1919
from mcp.shared._httpx_utils import create_mcp_http_client
2020
from mcp.shared.message import ClientMessageMetadata, SessionMessage
2121
from mcp.types import (
@@ -38,8 +38,8 @@
3838

3939
# TODO(Marcelo): Put the TransportStreams in a module under shared, so we can import here.
4040
SessionMessageOrError = SessionMessage | Exception
41-
StreamWriter = MemoryObjectSendStream[SessionMessageOrError]
42-
StreamReader = MemoryObjectReceiveStream[SessionMessage]
41+
StreamWriter = ContextSendStream[SessionMessageOrError]
42+
StreamReader = ContextReceiveStream[SessionMessage]
4343

4444
MCP_SESSION_ID = "mcp-session-id"
4545
MCP_PROTOCOL_VERSION = "mcp-protocol-version"
@@ -434,14 +434,15 @@ async def post_writer(
434434
client: httpx.AsyncClient,
435435
write_stream_reader: StreamReader,
436436
read_stream_writer: StreamWriter,
437-
write_stream: MemoryObjectSendStream[SessionMessage],
437+
write_stream: ContextSendStream[SessionMessage],
438438
start_get_stream: Callable[[], None],
439439
tg: TaskGroup,
440440
) -> None:
441441
"""Handle writing requests to the server."""
442442
try:
443443
async with write_stream_reader:
444-
async for session_message in write_stream_reader:
444+
445+
async def _handle_message(session_message: SessionMessage) -> None:
445446
message = session_message.message
446447
metadata = (
447448
session_message.metadata
@@ -478,6 +479,14 @@ async def handle_request_async():
478479
else:
479480
await handle_request_async()
480481

482+
async for session_message in write_stream_reader:
483+
sender_ctx = write_stream_reader.last_context
484+
if sender_ctx is not None:
485+
async with anyio.create_task_group() as tg_local:
486+
sender_ctx.run(tg_local.start_soon, _handle_message, session_message)
487+
else:
488+
await _handle_message(session_message) # pragma: no cover
489+
481490
except Exception: # pragma: lax no cover
482491
logger.exception("Error in post_writer")
483492
finally:
@@ -533,8 +542,8 @@ async def streamable_http_client(
533542
Example:
534543
See examples/snippets/clients/ for usage patterns.
535544
"""
536-
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
537-
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)
545+
read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0)
546+
write_stream, write_stream_reader = create_context_streams[SessionMessage](0)
538547

539548
# Determine if we need to create and manage the client
540549
client_provided = http_client is not None

src/mcp/server/lowlevel/server.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ async def main():
3636

3737
from __future__ import annotations
3838

39+
import contextvars
3940
import logging
4041
import warnings
4142
from collections.abc import AsyncIterator, Awaitable, Callable
@@ -44,14 +45,14 @@ async def main():
4445
from typing import Any, Generic
4546

4647
import anyio
47-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
4848
from starlette.applications import Starlette
4949
from starlette.middleware import Middleware
5050
from starlette.middleware.authentication import AuthenticationMiddleware
5151
from starlette.routing import Mount, Route
5252
from typing_extensions import TypeVar
5353

5454
from mcp import types
55+
from mcp.client._transport import ReadStream, WriteStream
5556
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware
5657
from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware
5758
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier
@@ -355,8 +356,8 @@ def session_manager(self) -> StreamableHTTPSessionManager:
355356

356357
async def run(
357358
self,
358-
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
359-
write_stream: MemoryObjectSendStream[SessionMessage],
359+
read_stream: ReadStream[SessionMessage | Exception],
360+
write_stream: WriteStream[SessionMessage],
360361
initialization_options: InitializationOptions,
361362
# When False, exceptions are returned as messages to the client.
362363
# When True, exceptions are raised, which will cause the server to shut down
@@ -390,7 +391,13 @@ async def run(
390391
async for message in session.incoming_messages:
391392
logger.debug("Received message: %s", message)
392393

393-
tg.start_soon(
394+
if isinstance(message, RequestResponder) and message.context is not None:
395+
context = message.context
396+
else:
397+
context = contextvars.copy_context()
398+
399+
context.run(
400+
tg.start_soon,
394401
self._handle_message,
395402
message,
396403
session,

src/mcp/server/session.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,11 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult:
3333

3434
import anyio
3535
import anyio.lowlevel
36-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
36+
from anyio.streams.memory import MemoryObjectReceiveStream
3737
from pydantic import AnyUrl, TypeAdapter
3838

3939
from mcp import types
40+
from mcp.client._transport import ReadStream, WriteStream
4041
from mcp.server.experimental.session_features import ExperimentalServerSessionFeatures
4142
from mcp.server.models import InitializationOptions
4243
from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages
@@ -79,8 +80,8 @@ class ServerSession(
7980

8081
def __init__(
8182
self,
82-
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
83-
write_stream: MemoryObjectSendStream[SessionMessage],
83+
read_stream: ReadStream[SessionMessage | Exception],
84+
write_stream: WriteStream[SessionMessage],
8485
init_options: InitializationOptions,
8586
stateless: bool = False,
8687
) -> None:

src/mcp/server/sse.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ async def handle_sse(request):
4343
from uuid import UUID, uuid4
4444

4545
import anyio
46-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
4746
from pydantic import ValidationError
4847
from sse_starlette import EventSourceResponse
4948
from starlette.requests import Request
@@ -55,6 +54,7 @@ async def handle_sse(request):
5554
TransportSecurityMiddleware,
5655
TransportSecuritySettings,
5756
)
57+
from mcp.shared._context_streams import ContextSendStream, create_context_streams
5858
from mcp.shared.message import ServerMessageMetadata, SessionMessage
5959

6060
logger = logging.getLogger(__name__)
@@ -72,7 +72,7 @@ class SseServerTransport:
7272
"""
7373

7474
_endpoint: str
75-
_read_stream_writers: dict[UUID, MemoryObjectSendStream[SessionMessage | Exception]]
75+
_read_stream_writers: dict[UUID, ContextSendStream[SessionMessage | Exception]]
7676
_security: TransportSecurityMiddleware
7777

7878
def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | None = None) -> None:
@@ -129,14 +129,9 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): # prag
129129
raise ValueError("Request validation failed")
130130

131131
logger.debug("Setting up SSE connection")
132-
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
133-
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
134132

135-
write_stream: MemoryObjectSendStream[SessionMessage]
136-
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
137-
138-
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
139-
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
133+
read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0)
134+
write_stream, write_stream_reader = create_context_streams[SessionMessage](0)
140135

141136
session_id = uuid4()
142137
self._read_stream_writers[session_id] = read_stream_writer

src/mcp/server/stdio.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ async def run_server():
2323

2424
import anyio
2525
import anyio.lowlevel
26-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
2726

2827
from mcp import types
28+
from mcp.shared._context_streams import create_context_streams
2929
from mcp.shared.message import SessionMessage
3030

3131

@@ -43,14 +43,8 @@ async def stdio_server(stdin: anyio.AsyncFile[str] | None = None, stdout: anyio.
4343
if not stdout:
4444
stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8"))
4545

46-
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
47-
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
48-
49-
write_stream: MemoryObjectSendStream[SessionMessage]
50-
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
51-
52-
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
53-
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
46+
read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0)
47+
write_stream, write_stream_reader = create_context_streams[SessionMessage](0)
5448

5549
async def stdin_reader():
5650
try:

0 commit comments

Comments
 (0)