Skip to content

Commit 8353a9b

Browse files
committed
test: run the interaction suite over the legacy SSE transport in-process
1 parent d64f525 commit 8353a9b

8 files changed

Lines changed: 253 additions & 21 deletions

File tree

src/mcp/server/sse.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,15 +116,15 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings |
116116
logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}")
117117

118118
@asynccontextmanager
119-
async def connect_sse(self, scope: Scope, receive: Receive, send: Send): # pragma: no cover
120-
if scope["type"] != "http":
119+
async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
120+
if scope["type"] != "http": # pragma: no cover
121121
logger.error("connect_sse received non-HTTP request")
122122
raise ValueError("connect_sse can only handle HTTP requests")
123123

124124
# Validate request headers for DNS rebinding protection
125125
request = Request(scope, receive)
126126
error_response = await self._security.validate_request(request, is_post=False)
127-
if error_response:
127+
if error_response: # pragma: no cover
128128
await error_response(scope, receive, send)
129129
raise ValueError("Request validation failed")
130130

@@ -190,13 +190,13 @@ async def response_wrapper(scope: Scope, receive: Receive, send: Send):
190190
logger.debug("Yielding read and write streams")
191191
yield (read_stream, write_stream)
192192

193-
async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: # pragma: no cover
193+
async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None:
194194
logger.debug("Handling POST message")
195195
request = Request(scope, receive)
196196

197197
# Validate request headers for DNS rebinding protection
198198
error_response = await self._security.validate_request(request, is_post=True)
199-
if error_response:
199+
if error_response: # pragma: no cover
200200
return await error_response(scope, receive, send)
201201

202202
session_id_param = request.query_params.get("session_id")
@@ -225,7 +225,7 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send)
225225
try:
226226
message = types.jsonrpc_message_adapter.validate_json(body, by_name=False)
227227
logger.debug(f"Validated client message: {message}")
228-
except ValidationError as err:
228+
except ValidationError as err: # pragma: no cover
229229
logger.exception("Failed to parse message")
230230
response = Response("Could not parse message", status_code=400)
231231
await response(scope, receive, send)

tests/interaction/_connect.py

Lines changed: 93 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,31 @@
11
"""Transport-parametrized connection factories for the interaction suite.
22
33
The `connect` fixture (see conftest.py) hands tests one of these factories so the same test body
4-
runs over the in-memory transport and over streamable HTTP without naming either: the factory is a
5-
drop-in replacement for constructing `Client(server, ...)` and yields the connected client. The
6-
streamable HTTP factory drives the server's real Starlette app through the in-process streaming
7-
bridge, so the full HTTP framing layer (session ids, SSE encoding, session management) runs with
8-
no sockets, threads, or subprocesses.
4+
runs over each transport without naming any of them: the factory is a drop-in replacement for
5+
constructing `Client(server, ...)` and yields the connected client. The HTTP factories drive the
6+
server's real Starlette app through the in-process streaming bridge, so the full transport layer
7+
(session ids, SSE encoding, session management) runs with no sockets, threads, or subprocesses.
98
"""
109

10+
import gc
11+
import warnings
1112
from collections.abc import AsyncIterator
1213
from contextlib import AbstractAsyncContextManager, asynccontextmanager
1314
from typing import Protocol
1415

1516
import httpx
17+
from starlette.applications import Starlette
18+
from starlette.requests import Request
19+
from starlette.responses import Response
20+
from starlette.routing import Mount, Route
1621

1722
from mcp.client.client import Client
1823
from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
24+
from mcp.client.sse import sse_client
1925
from mcp.client.streamable_http import streamable_http_client
2026
from mcp.server import Server
2127
from mcp.server.mcpserver import MCPServer
28+
from mcp.server.sse import SseServerTransport
2229
from mcp.server.transport_security import TransportSecuritySettings
2330
from mcp.types import Implementation
2431
from tests.interaction.transports._bridge import StreamingASGITransport
@@ -115,3 +122,84 @@ async def connect_over_streamable_http(
115122
elicitation_callback=elicitation_callback,
116123
) as client:
117124
yield client
125+
126+
127+
def build_sse_app(server: Server | MCPServer) -> tuple[Starlette, SseServerTransport]:
128+
"""Mount a server on a Starlette app exposing the legacy SSE transport at /sse and /messages/.
129+
130+
`MCPServer.sse_app()` exists but does not expose the underlying `SseServerTransport`, which
131+
the SSE-specific tests need; building the app explicitly here gives both server flavours the
132+
same routing while keeping that handle.
133+
"""
134+
sse = SseServerTransport(
135+
"/messages/", security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False)
136+
)
137+
lowlevel = server._lowlevel_server if isinstance(server, MCPServer) else server
138+
139+
async def handle_sse(request: Request) -> Response:
140+
async with sse.connect_sse(request.scope, request.receive, request._send) as (read, write):
141+
await lowlevel.run(read, write, lowlevel.create_initialization_options())
142+
return Response()
143+
144+
app = Starlette(
145+
routes=[
146+
Route("/sse", endpoint=handle_sse, methods=["GET"]),
147+
Mount("/messages/", app=sse.handle_post_message),
148+
],
149+
)
150+
return app, sse
151+
152+
153+
@asynccontextmanager
154+
async def connect_over_sse(
155+
server: Server | MCPServer,
156+
*,
157+
read_timeout_seconds: float | None = None,
158+
sampling_callback: SamplingFnT | None = None,
159+
list_roots_callback: ListRootsFnT | None = None,
160+
logging_callback: LoggingFnT | None = None,
161+
message_handler: MessageHandlerFnT | None = None,
162+
client_info: Implementation | None = None,
163+
elicitation_callback: ElicitationFnT | None = None,
164+
) -> AsyncIterator[Client]:
165+
"""Yield a Client connected to the server's legacy SSE transport, entirely in process."""
166+
app, _ = build_sse_app(server)
167+
168+
def httpx_client_factory(
169+
headers: dict[str, str] | None = None,
170+
timeout: httpx.Timeout | None = None,
171+
auth: httpx.Auth | None = None,
172+
) -> httpx.AsyncClient:
173+
# The SSE server transport's connect_sse runs the entire MCP session inside the GET
174+
# request and only releases its streams after that request observes a disconnect, so the
175+
# bridge must let the application drain rather than cancelling at close.
176+
return httpx.AsyncClient(
177+
transport=StreamingASGITransport(app, cancel_on_close=False),
178+
base_url=_BASE_URL,
179+
headers=headers,
180+
timeout=timeout,
181+
auth=auth,
182+
)
183+
184+
transport = sse_client(f"{_BASE_URL}/sse", httpx_client_factory=httpx_client_factory)
185+
try:
186+
async with Client(
187+
transport,
188+
read_timeout_seconds=read_timeout_seconds,
189+
sampling_callback=sampling_callback,
190+
list_roots_callback=list_roots_callback,
191+
logging_callback=logging_callback,
192+
message_handler=message_handler,
193+
client_info=client_info,
194+
elicitation_callback=elicitation_callback,
195+
) as client:
196+
yield client
197+
finally:
198+
# SseServerTransport.connect_sse hands its internal SSE-chunk receive stream to
199+
# sse_starlette's EventSourceResponse, which never closes it when its task group is
200+
# cancelled on disconnect (see notes/findings.md). Collect the orphan here so its
201+
# ResourceWarning fires deterministically inside this fixture instead of at an
202+
# arbitrary later GC.
203+
with warnings.catch_warnings():
204+
warnings.simplefilter("ignore", ResourceWarning)
205+
gc.collect()

tests/interaction/_requirements.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1759,10 +1759,23 @@ def __post_init__(self) -> None:
17591759
"requests, with server messages delivered on the SSE stream."
17601760
),
17611761
transports=("sse",),
1762-
deferred=(
1763-
"The legacy SSE transport is covered by tests/shared/test_sse.py; in-process coverage in this "
1764-
"suite arrives with the transport fixture work."
1762+
),
1763+
"transport:sse:endpoint-event": Requirement(
1764+
source=f"{SPEC_BASE_URL}/basic/transports#backwards-compatibility",
1765+
behavior=(
1766+
"Opening the SSE stream delivers an `endpoint` event naming the message-POST URL and a fresh "
1767+
"session identifier; the server registers the session before the event is sent and releases it "
1768+
"when the stream disconnects."
1769+
),
1770+
transports=("sse",),
1771+
),
1772+
"transport:sse:post:session-routing": Requirement(
1773+
source="sdk",
1774+
behavior=(
1775+
"A POST to the SSE message endpoint that names no session id, a malformed session id, or an "
1776+
"unknown session id is rejected (400/400/404) instead of being forwarded."
17651777
),
1778+
transports=("sse",),
17661779
),
17671780
"transport:stdio": Requirement(
17681781
source=f"{SPEC_BASE_URL}/basic/transports#stdio",

tests/interaction/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22

33
import pytest
44

5-
from tests.interaction._connect import Connect, connect_in_memory, connect_over_streamable_http
5+
from tests.interaction._connect import Connect, connect_in_memory, connect_over_sse, connect_over_streamable_http
66

77
_FACTORIES: dict[str, Connect] = {
88
"in-memory": connect_in_memory,
99
"streamable-http": connect_over_streamable_http,
10+
"sse": connect_over_sse,
1011
}
1112

1213

tests/interaction/test_coverage.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
"tests.interaction.transports.test_bridge.test_response_chunks_arrive_as_the_application_sends_them",
3030
"tests.interaction.transports.test_bridge.test_closing_the_response_delivers_a_disconnect_to_the_application",
3131
"tests.interaction.transports.test_bridge.test_an_application_failure_before_the_response_starts_fails_the_request",
32+
"tests.interaction.transports.test_bridge.test_disabling_cancel_on_close_lets_the_application_finish_after_disconnect",
3233
}
3334

3435

tests/interaction/transports/_bridge.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
2222
The transport owns an anyio task group for the application tasks; it is opened and closed by
2323
`httpx.AsyncClient`'s own context manager, so use the client as a context manager (the suite
24-
always does).
24+
always does). Closing the transport cancels every running application task by default; set
25+
`cancel_on_close=False` to wait for the application's own disconnect handling instead.
2526
"""
2627

2728
import math
@@ -56,12 +57,19 @@ async def aclose(self) -> None:
5657

5758

5859
class StreamingASGITransport(httpx.AsyncBaseTransport):
59-
"""Drive an ASGI application in-process, streaming each response as it is produced."""
60+
"""Drive an ASGI application in-process, streaming each response as it is produced.
61+
62+
With `cancel_on_close` (the default), closing the transport cancels every application task
63+
still running so harness teardown can never hang. Setting it to False makes the transport wait
64+
for the application's own disconnect handling to complete instead, which is the path the legacy
65+
SSE server transport relies on for resource cleanup.
66+
"""
6067

6168
_task_group: anyio.abc.TaskGroup
6269

63-
def __init__(self, app: ASGIApp) -> None:
70+
def __init__(self, app: ASGIApp, *, cancel_on_close: bool = True) -> None:
6471
self._app = app
72+
self._cancel_on_close = cancel_on_close
6573

6674
async def __aenter__(self) -> "StreamingASGITransport":
6775
self._task_group = anyio.create_task_group()
@@ -74,9 +82,11 @@ async def __aexit__(
7482
exc_value: BaseException | None = None,
7583
traceback: TracebackType | None = None,
7684
) -> None:
77-
# Any application task still running at this point is serving a client that no longer
78-
# exists; cancel rather than wait so harness teardown can never hang.
79-
self._task_group.cancel_scope.cancel()
85+
# httpx closes every streamed response before closing the transport, so by now each
86+
# application task has been delivered `http.disconnect`. Either cancel immediately, or wait
87+
# for the application's own disconnect handling to unwind.
88+
if self._cancel_on_close:
89+
self._task_group.cancel_scope.cancel()
8090
await self._task_group.__aexit__(exc_type, exc_value, traceback)
8191

8292
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:

tests/interaction/transports/test_bridge.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,24 @@ async def broken_app(scope: Scope, receive: Receive, send: Send) -> None:
6969
async with httpx.AsyncClient(transport=StreamingASGITransport(broken_app), base_url="http://bridge") as http:
7070
with pytest.raises(RuntimeError, match="the demo application is broken"):
7171
await http.get("/broken")
72+
73+
74+
async def test_disabling_cancel_on_close_lets_the_application_finish_after_disconnect() -> None:
75+
"""With cancel_on_close=False, an application that runs cleanup after seeing http.disconnect
76+
completes that cleanup before the transport finishes closing."""
77+
cleanup_ran = anyio.Event()
78+
79+
async def lingering_app(scope: Scope, receive: Receive, send: Send) -> None:
80+
assert scope["type"] == "http"
81+
await receive()
82+
await send({"type": "http.response.start", "status": 200, "headers": []})
83+
assert (await receive())["type"] == "http.disconnect"
84+
cleanup_ran.set()
85+
86+
transport = StreamingASGITransport(lingering_app, cancel_on_close=False)
87+
async with httpx.AsyncClient(transport=transport, base_url="http://bridge") as http:
88+
with anyio.fail_after(5):
89+
async with http.stream("GET", "/linger") as response:
90+
assert response.status_code == 200
91+
assert not cleanup_ran.is_set()
92+
assert cleanup_ran.is_set()
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
"""Behaviour specific to the legacy HTTP+SSE transport, exercised entirely in process.
2+
3+
Transport-agnostic behaviour is covered by the `connect`-fixture matrix, which runs the rest of
4+
the suite over this transport as well; this file pins only what is observable on the SSE wiring
5+
itself: the GET-then-POST connection lifecycle, the endpoint event, and how the message endpoint
6+
rejects requests it cannot route to a session. Every test drives the server's real Starlette app
7+
through the suite's streaming ASGI bridge.
8+
"""
9+
10+
import gc
11+
import warnings
12+
from uuid import UUID, uuid4
13+
14+
import anyio
15+
import httpx
16+
import pytest
17+
from inline_snapshot import snapshot
18+
19+
from mcp.client.client import Client
20+
from mcp.client.sse import sse_client
21+
from mcp.server import Server
22+
from mcp.types import EmptyResult
23+
from tests.interaction._connect import build_sse_app
24+
from tests.interaction._requirements import requirement
25+
from tests.interaction.transports._bridge import StreamingASGITransport
26+
27+
pytestmark = pytest.mark.anyio
28+
29+
_BASE_URL = "http://127.0.0.1:8000"
30+
31+
32+
@requirement("transport:sse")
33+
@requirement("transport:sse:endpoint-event")
34+
async def test_endpoint_event_names_the_message_endpoint_with_a_fresh_session_id() -> None:
35+
"""Connecting opens a GET stream whose first event names the POST endpoint and a fresh
36+
session id; messages POSTed there are answered on that stream, and disconnecting releases the
37+
server's session entry."""
38+
app, sse = build_sse_app(Server("legacy"))
39+
captured_session_id: list[str] = []
40+
41+
def httpx_client_factory(
42+
headers: dict[str, str] | None = None,
43+
timeout: httpx.Timeout | None = None,
44+
auth: httpx.Auth | None = None,
45+
) -> httpx.AsyncClient:
46+
return httpx.AsyncClient(
47+
transport=StreamingASGITransport(app, cancel_on_close=False),
48+
base_url=_BASE_URL,
49+
headers=headers,
50+
timeout=timeout,
51+
auth=auth,
52+
)
53+
54+
transport = sse_client(
55+
f"{_BASE_URL}/sse", httpx_client_factory=httpx_client_factory, on_session_created=captured_session_id.append
56+
)
57+
with anyio.fail_after(5):
58+
async with Client(transport) as client:
59+
assert len(captured_session_id) == 1
60+
assert UUID(hex=captured_session_id[0]) in sse._read_stream_writers
61+
assert await client.send_ping() == snapshot(EmptyResult())
62+
63+
assert sse._read_stream_writers == {}
64+
# See connect_over_sse: collect the one stream sse_starlette never closes on disconnect.
65+
with warnings.catch_warnings():
66+
warnings.simplefilter("ignore", ResourceWarning)
67+
gc.collect()
68+
69+
70+
@requirement("transport:sse:post:session-routing")
71+
async def test_post_without_a_session_id_is_rejected() -> None:
72+
"""A POST to the message endpoint with no session_id query parameter is answered 400."""
73+
app, _ = build_sse_app(Server("legacy"))
74+
async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=_BASE_URL) as http:
75+
response = await http.post("/messages/", json={"jsonrpc": "2.0", "method": "ping", "id": 1})
76+
assert (response.status_code, response.text) == snapshot((400, "session_id is required"))
77+
78+
79+
@requirement("transport:sse:post:session-routing")
80+
async def test_post_with_a_malformed_session_id_is_rejected() -> None:
81+
"""A POST whose session_id query parameter is not a UUID is answered 400."""
82+
app, _ = build_sse_app(Server("legacy"))
83+
async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=_BASE_URL) as http:
84+
response = await http.post(
85+
"/messages/", params={"session_id": "not-a-uuid"}, json={"jsonrpc": "2.0", "method": "ping", "id": 1}
86+
)
87+
assert (response.status_code, response.text) == snapshot((400, "Invalid session ID"))
88+
89+
90+
@requirement("transport:sse:post:session-routing")
91+
async def test_post_for_an_unknown_session_is_rejected() -> None:
92+
"""A POST naming a well-formed session_id that no SSE stream owns is answered 404."""
93+
app, _ = build_sse_app(Server("legacy"))
94+
async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=_BASE_URL) as http:
95+
response = await http.post(
96+
"/messages/", params={"session_id": uuid4().hex}, json={"jsonrpc": "2.0", "method": "ping", "id": 1}
97+
)
98+
assert (response.status_code, response.text) == snapshot((404, "Could not find session"))

0 commit comments

Comments
 (0)