Skip to content

Commit 21a3979

Browse files
committed
fix: eliminate port allocation race in test_streamable_http fixtures
The previous pattern picked a free port via socket.bind(0), released it, then started a uvicorn subprocess hoping to rebind — a TOCTOU race that caused intermittent CI failures when pytest-xdist workers stole the port between release and rebind (ConnectError, 404 against wrong server). Added run_uvicorn_in_thread() which pre-binds the listening socket with port=0 and passes it to uvicorn via server.run(sockets=[sock]). The port is held atomically from bind until shutdown and is known before the server thread even starts — no polling, no race. The kernel's listen queue buffers any connections that arrive during uvicorn startup. Migrated the four test_streamable_http.py fixtures (basic_server, event_server, json_response_server, context_aware_server) that share create_app(). These include the SSE auto-reconnect tests that genuinely need real TCP to exercise connection lifecycle. Running the server in-process means coverage now tracks transport code that was previously subprocess-invisible; adjusted pragmas accordingly (targeted no-cover on unreached error paths, lax no-cover on timing-dependent branches). wait_for_server() is kept for files not touched by this PR.
1 parent 2c73a2a commit 21a3979

File tree

5 files changed

+199
-279
lines changed

5 files changed

+199
-279
lines changed

src/mcp/server/session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ async def send_log_message(
222222
related_request_id,
223223
)
224224

225-
async def send_resource_updated(self, uri: str | AnyUrl) -> None: # pragma: no cover
225+
async def send_resource_updated(self, uri: str | AnyUrl) -> None:
226226
"""Send a resource updated notification."""
227227
await self.send_notification(
228228
types.ResourceUpdatedNotification(

src/mcp/server/streamable_http.py

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def is_terminated(self) -> bool:
177177
"""Check if this transport has been explicitly terminated."""
178178
return self._terminated
179179

180-
def close_sse_stream(self, request_id: RequestId) -> None: # pragma: no cover
180+
def close_sse_stream(self, request_id: RequestId) -> None:
181181
"""Close SSE connection for a specific request without terminating the stream.
182182
183183
This method closes the HTTP connection for the specified request, triggering
@@ -200,12 +200,12 @@ def close_sse_stream(self, request_id: RequestId) -> None: # pragma: no cover
200200
writer.close()
201201

202202
# Also close and remove request streams
203-
if request_id in self._request_streams:
203+
if request_id in self._request_streams: # pragma: no branch
204204
send_stream, receive_stream = self._request_streams.pop(request_id)
205205
send_stream.close()
206206
receive_stream.close()
207207

208-
def close_standalone_sse_stream(self) -> None: # pragma: no cover
208+
def close_standalone_sse_stream(self) -> None:
209209
"""Close the standalone GET SSE stream, triggering client reconnection.
210210
211211
This method closes the HTTP connection for the standalone GET stream used
@@ -240,10 +240,10 @@ def _create_session_message(
240240
# Only provide close callbacks when client supports resumability
241241
if self._event_store and protocol_version >= "2025-11-25":
242242

243-
async def close_stream_callback() -> None: # pragma: no cover
243+
async def close_stream_callback() -> None:
244244
self.close_sse_stream(request_id)
245245

246-
async def close_standalone_stream_callback() -> None: # pragma: no cover
246+
async def close_standalone_stream_callback() -> None:
247247
self.close_standalone_sse_stream()
248248

249249
metadata = ServerMessageMetadata(
@@ -291,7 +291,7 @@ def _create_error_response(
291291
) -> Response:
292292
"""Create an error response with a simple string message."""
293293
response_headers = {"Content-Type": CONTENT_TYPE_JSON}
294-
if headers: # pragma: no cover
294+
if headers:
295295
response_headers.update(headers)
296296

297297
if self.mcp_session_id:
@@ -342,7 +342,7 @@ def _create_event_data(self, event_message: EventMessage) -> dict[str, str]:
342342
}
343343

344344
# If an event ID was provided, include it
345-
if event_message.event_id: # pragma: no cover
345+
if event_message.event_id:
346346
event_data["id"] = event_message.event_id
347347

348348
return event_data
@@ -372,7 +372,7 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No
372372
await error_response(scope, receive, send)
373373
return
374374

375-
if self._terminated: # pragma: no cover
375+
if self._terminated:
376376
# If the session has been terminated, return 404 Not Found
377377
response = self._create_error_response(
378378
"Not Found: Session has been terminated",
@@ -387,7 +387,7 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No
387387
await self._handle_get_request(request, send)
388388
elif request.method == "DELETE":
389389
await self._handle_delete_request(request, send)
390-
else: # pragma: no cover
390+
else:
391391
await self._handle_unsupported_request(request, send)
392392

393393
def _check_accept_headers(self, request: Request) -> tuple[bool, bool]:
@@ -467,7 +467,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
467467

468468
try:
469469
message = jsonrpc_message_adapter.validate_python(raw_message, by_name=False)
470-
except ValidationError as e: # pragma: no cover
470+
except ValidationError as e:
471471
response = self._create_error_response(
472472
f"Validation error: {str(e)}",
473473
HTTPStatus.BAD_REQUEST,
@@ -493,7 +493,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
493493
)
494494
await response(scope, receive, send)
495495
return
496-
elif not await self._validate_request_headers(request, send): # pragma: no cover
496+
elif not await self._validate_request_headers(request, send):
497497
return
498498

499499
# For notifications and responses only, return 202 Accepted
@@ -633,7 +633,7 @@ async def sse_writer(): # pragma: lax no cover
633633
finally:
634634
await sse_stream_reader.aclose()
635635

636-
except Exception as err: # pragma: no cover
636+
except Exception as err: # pragma: lax no cover
637637
logger.exception("Error handling POST request")
638638
response = self._create_error_response(
639639
f"Error handling POST request: {err}",
@@ -659,19 +659,19 @@ async def _handle_get_request(self, request: Request, send: Send) -> None:
659659
# Validate Accept header - must include text/event-stream
660660
_, has_sse = self._check_accept_headers(request)
661661

662-
if not has_sse: # pragma: no cover
662+
if not has_sse:
663663
response = self._create_error_response(
664664
"Not Acceptable: Client must accept text/event-stream",
665665
HTTPStatus.NOT_ACCEPTABLE,
666666
)
667667
await response(request.scope, request.receive, send)
668668
return
669669

670-
if not await self._validate_request_headers(request, send): # pragma: no cover
671-
return
670+
if not await self._validate_request_headers(request, send):
671+
return # pragma: no cover
672672

673673
# Handle resumability: check for Last-Event-ID header
674-
if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): # pragma: no cover
674+
if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER):
675675
await self._replay_events(last_event_id, request, send)
676676
return
677677

@@ -681,11 +681,11 @@ async def _handle_get_request(self, request: Request, send: Send) -> None:
681681
"Content-Type": CONTENT_TYPE_SSE,
682682
}
683683

684-
if self.mcp_session_id:
684+
if self.mcp_session_id: # pragma: no branch
685685
headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id
686686

687687
# Check if we already have an active GET stream
688-
if GET_STREAM_KEY in self._request_streams: # pragma: no cover
688+
if GET_STREAM_KEY in self._request_streams:
689689
response = self._create_error_response(
690690
"Conflict: Only one SSE stream is allowed per session",
691691
HTTPStatus.CONFLICT,
@@ -714,7 +714,7 @@ async def standalone_sse_writer():
714714
# Send the message via SSE
715715
event_data = self._create_event_data(event_message)
716716
await sse_stream_writer.send(event_data)
717-
except Exception: # pragma: no cover
717+
except Exception: # pragma: lax no cover
718718
logger.exception("Error in standalone SSE writer")
719719
finally:
720720
logger.debug("Closing standalone SSE writer")
@@ -791,13 +791,13 @@ async def terminate(self) -> None:
791791
# During cleanup, we catch all exceptions since streams might be in various states
792792
logger.debug(f"Error closing streams: {e}")
793793

794-
async def _handle_unsupported_request(self, request: Request, send: Send) -> None: # pragma: no cover
794+
async def _handle_unsupported_request(self, request: Request, send: Send) -> None:
795795
"""Handle unsupported HTTP methods."""
796796
headers = {
797797
"Content-Type": CONTENT_TYPE_JSON,
798798
"Allow": "GET, POST, DELETE",
799799
}
800-
if self.mcp_session_id:
800+
if self.mcp_session_id: # pragma: no branch
801801
headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id
802802

803803
response = self._create_error_response(
@@ -824,7 +824,7 @@ async def _validate_session(self, request: Request, send: Send) -> bool:
824824
request_session_id = self._get_session_id(request)
825825

826826
# If no session ID provided but required, return error
827-
if not request_session_id: # pragma: no cover
827+
if not request_session_id:
828828
response = self._create_error_response(
829829
"Bad Request: Missing session ID",
830830
HTTPStatus.BAD_REQUEST,
@@ -849,11 +849,11 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool
849849
protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER)
850850

851851
# If no protocol version provided, assume default version
852-
if protocol_version is None: # pragma: no cover
852+
if protocol_version is None:
853853
protocol_version = DEFAULT_NEGOTIATED_VERSION
854854

855855
# Check if the protocol version is supported
856-
if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: # pragma: no cover
856+
if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS:
857857
supported_versions = ", ".join(SUPPORTED_PROTOCOL_VERSIONS)
858858
response = self._create_error_response(
859859
f"Bad Request: Unsupported protocol version: {protocol_version}. "
@@ -865,13 +865,13 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool
865865

866866
return True
867867

868-
async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: # pragma: no cover
868+
async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None:
869869
"""Replays events that would have been sent after the specified event ID.
870870
871871
Only used when resumability is enabled.
872872
"""
873873
event_store = self._event_store
874-
if not event_store:
874+
if not event_store: # pragma: no cover
875875
return
876876

877877
try:
@@ -881,7 +881,7 @@ async def _replay_events(self, last_event_id: str, request: Request, send: Send)
881881
"Content-Type": CONTENT_TYPE_SSE,
882882
}
883883

884-
if self.mcp_session_id:
884+
if self.mcp_session_id: # pragma: no branch
885885
headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id
886886

887887
# Get protocol version from header (already validated in _validate_protocol_version)
@@ -902,7 +902,7 @@ async def send_event(event_message: EventMessage) -> None:
902902
stream_id = await event_store.replay_events_after(last_event_id, send_event)
903903

904904
# If stream ID not in mapping, create it
905-
if stream_id and stream_id not in self._request_streams:
905+
if stream_id and stream_id not in self._request_streams: # pragma: no branch
906906
# Register SSE writer so close_sse_stream() can close it
907907
self._sse_stream_writers[stream_id] = sse_stream_writer
908908

@@ -919,10 +919,10 @@ async def send_event(event_message: EventMessage) -> None:
919919
event_data = self._create_event_data(event_message)
920920

921921
await sse_stream_writer.send(event_data)
922-
except anyio.ClosedResourceError:
922+
except anyio.ClosedResourceError: # pragma: lax no cover
923923
# Expected when close_sse_stream() is called
924924
logger.debug("Replay SSE stream closed by close_sse_stream()")
925-
except Exception:
925+
except Exception: # pragma: lax no cover
926926
logger.exception("Error in replay sender")
927927

928928
# Create and start EventSourceResponse
@@ -934,13 +934,13 @@ async def send_event(event_message: EventMessage) -> None:
934934

935935
try:
936936
await response(request.scope, request.receive, send)
937-
except Exception:
937+
except Exception: # pragma: no cover
938938
logger.exception("Error in replay response")
939939
finally:
940940
await sse_stream_writer.aclose()
941941
await sse_stream_reader.aclose()
942942

943-
except Exception:
943+
except Exception: # pragma: no cover
944944
logger.exception("Error replaying events")
945945
response = self._create_error_response(
946946
"Error replaying events",
@@ -991,7 +991,7 @@ async def message_router():
991991
if isinstance(message, JSONRPCResponse | JSONRPCError) and message.id is not None:
992992
target_request_id = str(message.id)
993993
# Extract related_request_id from meta if it exists
994-
elif ( # pragma: no cover
994+
elif (
995995
session_message.metadata is not None
996996
and isinstance(
997997
session_message.metadata,
@@ -1015,10 +1015,10 @@ async def message_router():
10151015
try:
10161016
# Send both the message and the event ID
10171017
await self._request_streams[request_stream_id][0].send(EventMessage(message, event_id))
1018-
except (anyio.BrokenResourceError, anyio.ClosedResourceError): # pragma: no cover
1018+
except (anyio.BrokenResourceError, anyio.ClosedResourceError): # pragma: lax no cover
10191019
# Stream might be closed, remove from registry
10201020
self._request_streams.pop(request_stream_id, None)
1021-
else: # pragma: no cover
1021+
else:
10221022
logger.debug(
10231023
f"""Request stream {request_stream_id} not found
10241024
for message. Still processing message as the client

src/mcp/server/transport_security.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def __init__(self, settings: TransportSecuritySettings | None = None):
4040
# If not specified, disable DNS rebinding protection by default for backwards compatibility
4141
self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False)
4242

43-
def _validate_host(self, host: str | None) -> bool: # pragma: no cover
43+
def _validate_host(self, host: str | None) -> bool: # pragma: lax no cover
4444
"""Validate the Host header against allowed values."""
4545
if not host:
4646
logger.warning("Missing Host header in request")
@@ -62,7 +62,7 @@ def _validate_host(self, host: str | None) -> bool: # pragma: no cover
6262
logger.warning(f"Invalid Host header: {host}")
6363
return False
6464

65-
def _validate_origin(self, origin: str | None) -> bool: # pragma: no cover
65+
def _validate_origin(self, origin: str | None) -> bool: # pragma: lax no cover
6666
"""Validate the Origin header against allowed values."""
6767
# Origin can be absent for same-origin requests
6868
if not origin:
@@ -103,14 +103,14 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res
103103
if not self.settings.enable_dns_rebinding_protection:
104104
return None
105105

106-
# Validate Host header # pragma: no cover
107-
host = request.headers.get("host") # pragma: no cover
108-
if not self._validate_host(host): # pragma: no cover
109-
return Response("Invalid Host header", status_code=421) # pragma: no cover
106+
# Validate Host header
107+
host = request.headers.get("host") # pragma: lax no cover
108+
if not self._validate_host(host): # pragma: lax no cover
109+
return Response("Invalid Host header", status_code=421)
110110

111-
# Validate Origin header # pragma: no cover
112-
origin = request.headers.get("origin") # pragma: no cover
113-
if not self._validate_origin(origin): # pragma: no cover
114-
return Response("Invalid Origin header", status_code=403) # pragma: no cover
111+
# Validate Origin header
112+
origin = request.headers.get("origin") # pragma: lax no cover
113+
if not self._validate_origin(origin): # pragma: lax no cover
114+
return Response("Invalid Origin header", status_code=403)
115115

116-
return None # pragma: no cover
116+
return None # pragma: lax no cover

0 commit comments

Comments
 (0)