Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from urllib.parse import parse_qs, urlparse

import httpx
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp.client._transport import ReadStream, WriteStream
from mcp.client.auth import OAuthClientProvider, TokenStorage
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
Expand Down Expand Up @@ -241,8 +241,8 @@ async def _default_redirect_handler(authorization_url: str) -> None:

async def _run_session(
self,
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
write_stream: MemoryObjectSendStream[SessionMessage],
read_stream: ReadStream[SessionMessage | Exception],
write_stream: WriteStream[SessionMessage],
):
"""Run the MCP session with the given streams."""
print("🤝 Initializing MCP session...")
Expand Down
6 changes: 3 additions & 3 deletions src/mcp/client/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from urllib.parse import urlparse

import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

from mcp import types
from mcp.client._transport import ReadStream, WriteStream
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters, stdio_client
Expand All @@ -33,8 +33,8 @@ async def message_handler(


async def run_session(
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
write_stream: MemoryObjectSendStream[SessionMessage],
read_stream: ReadStream[SessionMessage | Exception],
write_stream: WriteStream[SessionMessage],
client_info: types.Implementation | None = None,
):
async with ClientSession(
Expand Down
7 changes: 4 additions & 3 deletions src/mcp/client/_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
from contextlib import AbstractAsyncContextManager
from typing import Protocol

from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

from mcp.shared._stream_protocols import ReadStream, WriteStream
from mcp.shared.message import SessionMessage

TransportStreams = tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]]
__all__ = ["ReadStream", "WriteStream", "Transport", "TransportStreams"]

TransportStreams = tuple[ReadStream[SessionMessage | Exception], WriteStream[SessionMessage]]


class Transport(AbstractAsyncContextManager[TransportStreams], Protocol):
Expand Down
6 changes: 3 additions & 3 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from typing import Any, Protocol

import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import TypeAdapter

from mcp import types
from mcp.client._transport import ReadStream, WriteStream
from mcp.client.experimental import ExperimentalClientFeatures
from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers
from mcp.shared._context import RequestContext
Expand Down Expand Up @@ -109,8 +109,8 @@ class ClientSession(
):
def __init__(
self,
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
write_stream: MemoryObjectSendStream[SessionMessage],
read_stream: ReadStream[SessionMessage | Exception],
write_stream: WriteStream[SessionMessage],
read_timeout_seconds: float | None = None,
sampling_callback: SamplingFnT | None = None,
elicitation_callback: ElicitationFnT | None = None,
Expand Down
23 changes: 13 additions & 10 deletions src/mcp/client/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import anyio
import httpx
from anyio.abc import TaskStatus
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import aconnect_sse
from httpx_sse._exceptions import SSEError

from mcp import types
from mcp.shared._context_streams import create_context_streams
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
from mcp.shared.message import SessionMessage

Expand Down Expand Up @@ -51,14 +51,8 @@ async def sse_client(
auth: Optional HTTPX authentication handler.
on_session_created: Optional callback invoked with the session ID when received.
"""
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]

write_stream: MemoryObjectSendStream[SessionMessage]
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]

read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0)
write_stream, write_stream_reader = create_context_streams[SessionMessage](0)

async with anyio.create_task_group() as tg:
try:
Expand Down Expand Up @@ -132,7 +126,8 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
async def post_writer(endpoint_url: str):
try:
async with write_stream_reader:
async for session_message in write_stream_reader:

async def _send_message(session_message: SessionMessage) -> None:
logger.debug(f"Sending client message: {session_message}")
response = await client.post(
endpoint_url,
Expand All @@ -144,6 +139,14 @@ async def post_writer(endpoint_url: str):
)
response.raise_for_status()
logger.debug(f"Client message sent successfully: {response.status_code}")

async for session_message in write_stream_reader:
sender_ctx = write_stream_reader.last_context
if sender_ctx is not None:
async with anyio.create_task_group() as tg:
sender_ctx.run(tg.start_soon, _send_message, session_message)
else:
await _send_message(session_message) # pragma: no cover
except Exception: # pragma: lax no cover
logger.exception("Error in post_writer")
finally:
Expand Down
23 changes: 16 additions & 7 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
import anyio
import httpx
from anyio.abc import TaskGroup
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
from pydantic import ValidationError

from mcp.client._transport import TransportStreams
from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams
from mcp.shared._httpx_utils import create_mcp_http_client
from mcp.shared.message import ClientMessageMetadata, SessionMessage
from mcp.types import (
Expand All @@ -38,8 +38,8 @@

# TODO(Marcelo): Put the TransportStreams in a module under shared, so we can import here.
SessionMessageOrError = SessionMessage | Exception
StreamWriter = MemoryObjectSendStream[SessionMessageOrError]
StreamReader = MemoryObjectReceiveStream[SessionMessage]
StreamWriter = ContextSendStream[SessionMessageOrError]
StreamReader = ContextReceiveStream[SessionMessage]

MCP_SESSION_ID = "mcp-session-id"
MCP_PROTOCOL_VERSION = "mcp-protocol-version"
Expand Down Expand Up @@ -434,14 +434,15 @@ async def post_writer(
client: httpx.AsyncClient,
write_stream_reader: StreamReader,
read_stream_writer: StreamWriter,
write_stream: MemoryObjectSendStream[SessionMessage],
write_stream: ContextSendStream[SessionMessage],
start_get_stream: Callable[[], None],
tg: TaskGroup,
) -> None:
"""Handle writing requests to the server."""
try:
async with write_stream_reader:
async for session_message in write_stream_reader:

async def _handle_message(session_message: SessionMessage) -> None:
message = session_message.message
metadata = (
session_message.metadata
Expand Down Expand Up @@ -478,6 +479,14 @@ async def handle_request_async():
else:
await handle_request_async()

async for session_message in write_stream_reader:
sender_ctx = write_stream_reader.last_context
if sender_ctx is not None:
async with anyio.create_task_group() as tg_local:
sender_ctx.run(tg_local.start_soon, _handle_message, session_message)
else:
await _handle_message(session_message) # pragma: no cover

except Exception: # pragma: lax no cover
logger.exception("Error in post_writer")
finally:
Expand Down Expand Up @@ -533,8 +542,8 @@ async def streamable_http_client(
Example:
See examples/snippets/clients/ for usage patterns.
"""
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)
read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0)
write_stream, write_stream_reader = create_context_streams[SessionMessage](0)

# Determine if we need to create and manage the client
client_provided = http_client is not None
Expand Down
15 changes: 11 additions & 4 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ async def main():

from __future__ import annotations

import contextvars
import logging
import warnings
from collections.abc import AsyncIterator, Awaitable, Callable
Expand All @@ -44,7 +45,6 @@ async def main():
from typing import Any, Generic

import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
Expand All @@ -65,6 +65,7 @@ async def main():
from mcp.server.streamable_http import EventStore
from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager
from mcp.server.transport_security import TransportSecuritySettings
from mcp.shared._stream_protocols import ReadStream, WriteStream
from mcp.shared.exceptions import MCPError
from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.shared.session import RequestResponder
Expand Down Expand Up @@ -355,8 +356,8 @@ def session_manager(self) -> StreamableHTTPSessionManager:

async def run(
self,
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
write_stream: MemoryObjectSendStream[SessionMessage],
read_stream: ReadStream[SessionMessage | Exception],
write_stream: WriteStream[SessionMessage],
initialization_options: InitializationOptions,
# When False, exceptions are returned as messages to the client.
# When True, exceptions are raised, which will cause the server to shut down
Expand Down Expand Up @@ -390,7 +391,13 @@ async def run(
async for message in session.incoming_messages:
logger.debug("Received message: %s", message)

tg.start_soon(
if isinstance(message, RequestResponder) and message.context is not None:
context = message.context
else:
context = contextvars.copy_context()

context.run(
tg.start_soon,
self._handle_message,
message,
session,
Expand Down
7 changes: 4 additions & 3 deletions src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,14 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult:

import anyio
import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from anyio.streams.memory import MemoryObjectReceiveStream
from pydantic import AnyUrl, TypeAdapter

from mcp import types
from mcp.server.experimental.session_features import ExperimentalServerSessionFeatures
from mcp.server.models import InitializationOptions
from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages
from mcp.shared._stream_protocols import ReadStream, WriteStream
from mcp.shared.exceptions import StatelessModeNotSupported
from mcp.shared.experimental.tasks.capabilities import check_tasks_capability
from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY
Expand Down Expand Up @@ -79,8 +80,8 @@ class ServerSession(

def __init__(
self,
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
write_stream: MemoryObjectSendStream[SessionMessage],
read_stream: ReadStream[SessionMessage | Exception],
write_stream: WriteStream[SessionMessage],
init_options: InitializationOptions,
stateless: bool = False,
) -> None:
Expand Down
13 changes: 4 additions & 9 deletions src/mcp/server/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ async def handle_sse(request):
from uuid import UUID, uuid4

import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import ValidationError
from sse_starlette import EventSourceResponse
from starlette.requests import Request
Expand All @@ -55,6 +54,7 @@ async def handle_sse(request):
TransportSecurityMiddleware,
TransportSecuritySettings,
)
from mcp.shared._context_streams import ContextSendStream, create_context_streams
from mcp.shared.message import ServerMessageMetadata, SessionMessage

logger = logging.getLogger(__name__)
Expand All @@ -72,7 +72,7 @@ class SseServerTransport:
"""

_endpoint: str
_read_stream_writers: dict[UUID, MemoryObjectSendStream[SessionMessage | Exception]]
_read_stream_writers: dict[UUID, ContextSendStream[SessionMessage | Exception]]
_security: TransportSecurityMiddleware

def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | None = None) -> None:
Expand Down Expand Up @@ -129,14 +129,9 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): # prag
raise ValueError("Request validation failed")

logger.debug("Setting up SSE connection")
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]

write_stream: MemoryObjectSendStream[SessionMessage]
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]

read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0)
write_stream, write_stream_reader = create_context_streams[SessionMessage](0)

session_id = uuid4()
self._read_stream_writers[session_id] = read_stream_writer
Expand Down
12 changes: 3 additions & 9 deletions src/mcp/server/stdio.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ async def run_server():

import anyio
import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

from mcp import types
from mcp.shared._context_streams import create_context_streams
from mcp.shared.message import SessionMessage


Expand All @@ -43,14 +43,8 @@ async def stdio_server(stdin: anyio.AsyncFile[str] | None = None, stdout: anyio.
if not stdout:
stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8"))

read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]

write_stream: MemoryObjectSendStream[SessionMessage]
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]

read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0)
write_stream, write_stream_reader = create_context_streams[SessionMessage](0)

async def stdin_reader():
try:
Expand Down
18 changes: 10 additions & 8 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from starlette.types import Receive, Scope, Send

from mcp.server.transport_security import TransportSecurityMiddleware, TransportSecuritySettings
from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams
from mcp.shared._stream_protocols import ReadStream, WriteStream
from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
from mcp.types import (
Expand Down Expand Up @@ -119,10 +121,10 @@ class StreamableHTTPServerTransport:
"""

# Server notification streams for POST requests as well as standalone SSE stream
_read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = None
_read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None
_write_stream: MemoryObjectSendStream[SessionMessage] | None = None
_write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None
_read_stream_writer: ContextSendStream[SessionMessage | Exception] | None = None
_read_stream: ContextReceiveStream[SessionMessage | Exception] | None = None
_write_stream: ContextSendStream[SessionMessage] | None = None
_write_stream_reader: ContextReceiveStream[SessionMessage] | None = None
_security: TransportSecurityMiddleware

def __init__(
Expand Down Expand Up @@ -954,8 +956,8 @@ async def connect(
self,
) -> AsyncGenerator[
tuple[
MemoryObjectReceiveStream[SessionMessage | Exception],
MemoryObjectSendStream[SessionMessage],
ReadStream[SessionMessage | Exception],
WriteStream[SessionMessage],
],
None,
]:
Expand All @@ -967,8 +969,8 @@ async def connect(

# Create the memory streams for this connection

read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)
read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0)
write_stream, write_stream_reader = create_context_streams[SessionMessage](0)

# Store the streams
self._read_stream_writer = read_stream_writer
Expand Down
Loading
Loading