Skip to content

Commit 5ab3d9b

Browse files
committed
fix: eliminate test port allocation race by running uvicorn in-thread
The previous pattern picked a free port via socket.bind(0), released it, then started a uvicorn subprocess hoping to rebind — a TOCTOU race that caused intermittent CI failures when pytest-xdist workers stole the port between release and rebind (connection errors, 404s against wrong server, WS 403s). Replaced with a context manager that runs uvicorn in a background thread with port=0 and reads the actual bound port back from the server's socket after startup. The OS atomically assigns the port at bind time and the server holds it until shutdown — no race window. Also fixed a latent stream leak in SseServerTransport.connect_sse() that the in-process server surfaced: sse_stream_reader was never closed on normal completion or cancellation (sse_starlette only closes it on SendTimeout). Cleanup now runs in a finally block. Side benefits: no more subprocess spawn overhead, faster startup, shared process state when needed, and wait_for_server() is no longer used.
1 parent 62eb08e commit 5ab3d9b

File tree

9 files changed

+276
-683
lines changed

9 files changed

+276
-683
lines changed

src/mcp/server/sse.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -181,13 +181,19 @@ async def response_wrapper(scope: Scope, receive: Receive, send: Send):
181181
In this case we close our side of the streams to signal the client that
182182
the connection has been closed.
183183
"""
184-
await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)(
185-
scope, receive, send
186-
)
187-
await read_stream_writer.aclose()
188-
await write_stream_reader.aclose()
189-
self._read_stream_writers.pop(session_id, None)
190-
logging.debug(f"Client session disconnected {session_id}")
184+
try:
185+
await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)(
186+
scope, receive, send
187+
)
188+
finally:
189+
# EventSourceResponse does not close its body iterator on
190+
# normal completion or cancellation (only on SendTimeout),
191+
# so we must close it here to avoid leaking the stream.
192+
await sse_stream_reader.aclose()
193+
await read_stream_writer.aclose()
194+
await write_stream_reader.aclose()
195+
self._read_stream_writers.pop(session_id, None)
196+
logging.debug(f"Client session disconnected {session_id}")
191197

192198
logger.debug("Starting SSE response task")
193199
tg.start_soon(response_wrapper, scope, receive, send)

tests/client/test_http_unicode.py

Lines changed: 8 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
(server→client and client→server) using the streamable HTTP transport.
55
"""
66

7-
import multiprocessing
8-
import socket
97
from collections.abc import AsyncGenerator, Generator
108
from contextlib import asynccontextmanager
119

@@ -19,7 +17,7 @@
1917
from mcp.server import Server, ServerRequestContext
2018
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
2119
from mcp.types import TextContent, Tool
22-
from tests.test_helpers import wait_for_server
20+
from tests.test_helpers import run_uvicorn_in_thread
2321

2422
# Test constants with various Unicode characters
2523
UNICODE_TEST_STRINGS = {
@@ -41,11 +39,9 @@
4139
}
4240

4341

44-
def run_unicode_server(port: int) -> None: # pragma: no cover
45-
"""Run the Unicode test server in a separate process."""
46-
import uvicorn
42+
def make_unicode_server_app() -> Starlette: # pragma: no cover
43+
"""Create the Unicode test server app."""
4744

48-
# Need to recreate the server setup in this process
4945
async def handle_list_tools(
5046
ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
5147
) -> types.ListToolsResult:
@@ -129,51 +125,20 @@ async def lifespan(app: Starlette) -> AsyncGenerator[None, None]:
129125
yield
130126

131127
# Create an ASGI application
132-
app = Starlette(
128+
return Starlette(
133129
debug=True,
134130
routes=[
135131
Mount("/mcp", app=session_manager.handle_request),
136132
],
137133
lifespan=lifespan,
138134
)
139135

140-
# Run the server
141-
config = uvicorn.Config(
142-
app=app,
143-
host="127.0.0.1",
144-
port=port,
145-
log_level="error",
146-
)
147-
uvicorn_server = uvicorn.Server(config)
148-
uvicorn_server.run()
149-
150-
151-
@pytest.fixture
152-
def unicode_server_port() -> int:
153-
"""Find an available port for the Unicode test server."""
154-
with socket.socket() as s:
155-
s.bind(("127.0.0.1", 0))
156-
return s.getsockname()[1]
157-
158136

159137
@pytest.fixture
160-
def running_unicode_server(unicode_server_port: int) -> Generator[str, None, None]:
161-
"""Start a Unicode test server in a separate process."""
162-
proc = multiprocessing.Process(target=run_unicode_server, kwargs={"port": unicode_server_port}, daemon=True)
163-
proc.start()
164-
165-
# Wait for server to be ready
166-
wait_for_server(unicode_server_port)
167-
168-
try:
169-
yield f"http://127.0.0.1:{unicode_server_port}"
170-
finally:
171-
# Clean up - try graceful termination first
172-
proc.terminate()
173-
proc.join(timeout=2)
174-
if proc.is_alive(): # pragma: no cover
175-
proc.kill()
176-
proc.join(timeout=1)
138+
def running_unicode_server() -> Generator[str, None, None]:
139+
"""Start a Unicode test server in a background thread."""
140+
with run_uvicorn_in_thread(make_unicode_server_app()) as url:
141+
yield url
177142

178143

179144
@pytest.mark.anyio

tests/server/mcpserver/test_integration.py

Lines changed: 31 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,11 @@
1010
# pyright: reportUnknownArgumentType=false
1111

1212
import json
13-
import multiprocessing
14-
import socket
1513
from collections.abc import Generator
1614

1715
import pytest
18-
import uvicorn
1916
from inline_snapshot import snapshot
17+
from starlette.applications import Starlette
2018

2119
from examples.snippets.servers import (
2220
basic_prompt,
@@ -58,7 +56,7 @@
5856
TextResourceContents,
5957
ToolListChangedNotification,
6058
)
61-
from tests.test_helpers import wait_for_server
59+
from tests.test_helpers import run_uvicorn_in_thread
6260

6361

6462
class NotificationCollector:
@@ -85,23 +83,8 @@ async def handle_generic_notification(
8583
self.tool_notifications.append(message.params)
8684

8785

88-
# Common fixtures
89-
@pytest.fixture
90-
def server_port() -> int:
91-
"""Get a free port for testing."""
92-
with socket.socket() as s:
93-
s.bind(("127.0.0.1", 0))
94-
return s.getsockname()[1]
95-
96-
97-
@pytest.fixture
98-
def server_url(server_port: int) -> str:
99-
"""Get the server URL for testing."""
100-
return f"http://127.0.0.1:{server_port}"
101-
102-
103-
def run_server_with_transport(module_name: str, port: int, transport: str) -> None: # pragma: no cover
104-
"""Run server with specified transport."""
86+
def make_transport_app(module_name: str, transport: str) -> Starlette: # pragma: no cover
87+
"""Create server app for the specified example module and transport."""
10588
# Get the MCP instance based on module name
10689
if module_name == "basic_tool":
10790
mcp = basic_tool.mcp
@@ -128,46 +111,27 @@ def run_server_with_transport(module_name: str, port: int, transport: str) -> No
128111

129112
# Create app based on transport type
130113
if transport == "sse":
131-
app = mcp.sse_app()
114+
return mcp.sse_app()
132115
elif transport == "streamable-http":
133-
app = mcp.streamable_http_app()
116+
return mcp.streamable_http_app()
134117
else:
135118
raise ValueError(f"Invalid transport for test server: {transport}")
136119

137-
server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=port, log_level="error"))
138-
print(f"Starting {transport} server on port {port}")
139-
server.run()
140-
141120

142121
@pytest.fixture
143-
def server_transport(request: pytest.FixtureRequest, server_port: int) -> Generator[str, None, None]:
144-
"""Start server in a separate process with specified MCP instance and transport.
122+
def server_transport(request: pytest.FixtureRequest) -> Generator[tuple[str, str], None, None]:
123+
"""Start server in a background thread with specified MCP instance and transport.
145124
146125
Args:
147126
request: pytest request with param tuple of (module_name, transport)
148-
server_port: Port to run the server on
149127
150128
Yields:
151-
str: The transport type ('sse' or 'streamable_http')
129+
tuple[str, str]: The (transport, url) pair.
152130
"""
153131
module_name, transport = request.param
154132

155-
proc = multiprocessing.Process(
156-
target=run_server_with_transport,
157-
args=(module_name, server_port, transport),
158-
daemon=True,
159-
)
160-
proc.start()
161-
162-
# Wait for server to be ready
163-
wait_for_server(server_port)
164-
165-
yield transport
166-
167-
proc.kill()
168-
proc.join(timeout=2)
169-
if proc.is_alive(): # pragma: no cover
170-
print("Server process failed to terminate")
133+
with run_uvicorn_in_thread(make_transport_app(module_name, transport)) as url:
134+
yield transport, url
171135

172136

173137
# Helper function to create client based on transport
@@ -220,9 +184,9 @@ async def elicitation_callback(context: RequestContext[ClientSession], params: E
220184
],
221185
indirect=True,
222186
)
223-
async def test_basic_tools(server_transport: str, server_url: str) -> None:
187+
async def test_basic_tools(server_transport: tuple[str, str]) -> None:
224188
"""Test basic tool functionality."""
225-
transport = server_transport
189+
transport, server_url = server_transport
226190
client_cm = create_client_for_transport(transport, server_url)
227191

228192
async with client_cm as (read_stream, write_stream):
@@ -256,9 +220,9 @@ async def test_basic_tools(server_transport: str, server_url: str) -> None:
256220
],
257221
indirect=True,
258222
)
259-
async def test_basic_resources(server_transport: str, server_url: str) -> None:
223+
async def test_basic_resources(server_transport: tuple[str, str]) -> None:
260224
"""Test basic resource functionality."""
261-
transport = server_transport
225+
transport, server_url = server_transport
262226
client_cm = create_client_for_transport(transport, server_url)
263227

264228
async with client_cm as (read_stream, write_stream):
@@ -296,9 +260,9 @@ async def test_basic_resources(server_transport: str, server_url: str) -> None:
296260
],
297261
indirect=True,
298262
)
299-
async def test_basic_prompts(server_transport: str, server_url: str) -> None:
263+
async def test_basic_prompts(server_transport: tuple[str, str]) -> None:
300264
"""Test basic prompt functionality."""
301-
transport = server_transport
265+
transport, server_url = server_transport
302266
client_cm = create_client_for_transport(transport, server_url)
303267

304268
async with client_cm as (read_stream, write_stream):
@@ -348,9 +312,9 @@ async def test_basic_prompts(server_transport: str, server_url: str) -> None:
348312
],
349313
indirect=True,
350314
)
351-
async def test_tool_progress(server_transport: str, server_url: str) -> None:
315+
async def test_tool_progress(server_transport: tuple[str, str]) -> None:
352316
"""Test tool progress reporting."""
353-
transport = server_transport
317+
transport, server_url = server_transport
354318
collector = NotificationCollector()
355319

356320
async def message_handler(message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception):
@@ -404,9 +368,9 @@ async def progress_callback(progress: float, total: float | None, message: str |
404368
],
405369
indirect=True,
406370
)
407-
async def test_sampling(server_transport: str, server_url: str) -> None:
371+
async def test_sampling(server_transport: tuple[str, str]) -> None:
408372
"""Test sampling (LLM interaction) functionality."""
409-
transport = server_transport
373+
transport, server_url = server_transport
410374
client_cm = create_client_for_transport(transport, server_url)
411375

412376
async with client_cm as (read_stream, write_stream):
@@ -434,9 +398,9 @@ async def test_sampling(server_transport: str, server_url: str) -> None:
434398
],
435399
indirect=True,
436400
)
437-
async def test_elicitation(server_transport: str, server_url: str) -> None:
401+
async def test_elicitation(server_transport: tuple[str, str]) -> None:
438402
"""Test elicitation (user interaction) functionality."""
439-
transport = server_transport
403+
transport, server_url = server_transport
440404
client_cm = create_client_for_transport(transport, server_url)
441405

442406
async with client_cm as (read_stream, write_stream):
@@ -483,9 +447,9 @@ async def test_elicitation(server_transport: str, server_url: str) -> None:
483447
],
484448
indirect=True,
485449
)
486-
async def test_notifications(server_transport: str, server_url: str) -> None:
450+
async def test_notifications(server_transport: tuple[str, str]) -> None:
487451
"""Test notifications and logging functionality."""
488-
transport = server_transport
452+
transport, server_url = server_transport
489453
collector = NotificationCollector()
490454

491455
async def message_handler(message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception):
@@ -530,9 +494,9 @@ async def message_handler(message: RequestResponder[ServerRequest, ClientResult]
530494
],
531495
indirect=True,
532496
)
533-
async def test_completion(server_transport: str, server_url: str) -> None:
497+
async def test_completion(server_transport: tuple[str, str]) -> None:
534498
"""Test completion (autocomplete) functionality."""
535-
transport = server_transport
499+
transport, server_url = server_transport
536500
client_cm = create_client_for_transport(transport, server_url)
537501

538502
async with client_cm as (read_stream, write_stream):
@@ -582,9 +546,9 @@ async def test_completion(server_transport: str, server_url: str) -> None:
582546
],
583547
indirect=True,
584548
)
585-
async def test_mcpserver_quickstart(server_transport: str, server_url: str) -> None:
549+
async def test_mcpserver_quickstart(server_transport: tuple[str, str]) -> None:
586550
"""Test MCPServer quickstart example."""
587-
transport = server_transport
551+
transport, server_url = server_transport
588552
client_cm = create_client_for_transport(transport, server_url)
589553

590554
async with client_cm as (read_stream, write_stream):
@@ -617,9 +581,9 @@ async def test_mcpserver_quickstart(server_transport: str, server_url: str) -> N
617581
],
618582
indirect=True,
619583
)
620-
async def test_structured_output(server_transport: str, server_url: str) -> None:
584+
async def test_structured_output(server_transport: tuple[str, str]) -> None:
621585
"""Test structured output functionality."""
622-
transport = server_transport
586+
transport, server_url = server_transport
623587
client_cm = create_client_for_transport(transport, server_url)
624588

625589
async with client_cm as (read_stream, write_stream):

0 commit comments

Comments
 (0)