Skip to content

Commit eb4e691

Browse files
authored
Merge branch 'main' into fix-hang-github-mcp
2 parents 0340602 + 75a80b6 commit eb4e691

File tree

3 files changed

+161
-153
lines changed

3 files changed

+161
-153
lines changed

src/mcp/client/sse.py

Lines changed: 97 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -57,108 +57,101 @@ async def sse_client(
5757
write_stream: MemoryObjectSendStream[SessionMessage]
5858
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
5959

60-
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
61-
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
62-
63-
async with anyio.create_task_group() as tg:
64-
try:
65-
logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}")
66-
async with httpx_client_factory(
67-
headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
68-
) as client:
69-
async with aconnect_sse(
70-
client,
71-
"GET",
72-
url,
73-
) as event_source:
74-
event_source.response.raise_for_status()
75-
logger.debug("SSE connection established")
76-
77-
async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
78-
try:
79-
async for sse in event_source.aiter_sse(): # pragma: no branch
80-
logger.debug(f"Received SSE event: {sse.event}")
81-
match sse.event:
82-
case "endpoint":
83-
endpoint_url = urljoin(url, sse.data)
84-
logger.debug(f"Received endpoint URL: {endpoint_url}")
85-
86-
url_parsed = urlparse(url)
87-
endpoint_parsed = urlparse(endpoint_url)
88-
if ( # pragma: no cover
89-
url_parsed.netloc != endpoint_parsed.netloc
90-
or url_parsed.scheme != endpoint_parsed.scheme
91-
):
92-
error_msg = ( # pragma: no cover
93-
f"Endpoint origin does not match connection origin: {endpoint_url}"
94-
)
95-
logger.error(error_msg) # pragma: no cover
96-
raise ValueError(error_msg) # pragma: no cover
97-
98-
if on_session_created:
99-
session_id = _extract_session_id_from_endpoint(endpoint_url)
100-
if session_id:
101-
on_session_created(session_id)
102-
103-
task_status.started(endpoint_url)
104-
105-
case "message":
106-
# Skip empty data (keep-alive pings)
107-
if not sse.data:
108-
continue
109-
try:
110-
message = types.jsonrpc_message_adapter.validate_json(
111-
sse.data, by_name=False
112-
)
113-
logger.debug(f"Received server message: {message}")
114-
except Exception as exc: # pragma: no cover
115-
logger.exception("Error parsing server message") # pragma: no cover
116-
await read_stream_writer.send(exc) # pragma: no cover
117-
continue # pragma: no cover
118-
119-
session_message = SessionMessage(message)
120-
await read_stream_writer.send(session_message)
121-
case _: # pragma: no cover
122-
logger.warning(f"Unknown SSE event: {sse.event}") # pragma: no cover
123-
except SSEError as sse_exc: # pragma: lax no cover
124-
logger.exception("Encountered SSE exception")
125-
raise sse_exc
126-
except Exception as exc: # pragma: lax no cover
127-
logger.exception("Error in sse_reader")
128-
await read_stream_writer.send(exc)
129-
finally:
130-
await read_stream_writer.aclose()
131-
132-
async def post_writer(endpoint_url: str):
133-
try:
134-
async with write_stream_reader:
135-
async for session_message in write_stream_reader:
136-
logger.debug(f"Sending client message: {session_message}")
137-
response = await client.post(
138-
endpoint_url,
139-
json=session_message.message.model_dump(
140-
by_alias=True,
141-
mode="json",
142-
exclude_unset=True,
143-
),
60+
logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}")
61+
async with httpx_client_factory(
62+
headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
63+
) as client:
64+
async with aconnect_sse(client, "GET", url) as event_source:
65+
event_source.response.raise_for_status()
66+
logger.debug("SSE connection established")
67+
68+
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
69+
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
70+
71+
async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
72+
try:
73+
async for sse in event_source.aiter_sse(): # pragma: no branch
74+
logger.debug(f"Received SSE event: {sse.event}")
75+
match sse.event:
76+
case "endpoint":
77+
endpoint_url = urljoin(url, sse.data)
78+
logger.debug(f"Received endpoint URL: {endpoint_url}")
79+
80+
url_parsed = urlparse(url)
81+
endpoint_parsed = urlparse(endpoint_url)
82+
if ( # pragma: no cover
83+
url_parsed.netloc != endpoint_parsed.netloc
84+
or url_parsed.scheme != endpoint_parsed.scheme
85+
):
86+
error_msg = ( # pragma: no cover
87+
f"Endpoint origin does not match connection origin: {endpoint_url}"
14488
)
145-
response.raise_for_status()
146-
logger.debug(f"Client message sent successfully: {response.status_code}")
147-
except Exception: # pragma: lax no cover
148-
logger.exception("Error in post_writer")
149-
finally:
150-
await write_stream.aclose()
151-
152-
endpoint_url = await tg.start(sse_reader)
153-
logger.debug(f"Starting post writer with endpoint URL: {endpoint_url}")
154-
tg.start_soon(post_writer, endpoint_url)
155-
156-
try:
157-
yield read_stream, write_stream
158-
finally:
159-
tg.cancel_scope.cancel()
160-
finally:
161-
await read_stream_writer.aclose()
162-
await write_stream.aclose()
163-
await read_stream.aclose()
164-
await write_stream_reader.aclose()
89+
logger.error(error_msg) # pragma: no cover
90+
raise ValueError(error_msg) # pragma: no cover
91+
92+
if on_session_created:
93+
session_id = _extract_session_id_from_endpoint(endpoint_url)
94+
if session_id:
95+
on_session_created(session_id)
96+
97+
task_status.started(endpoint_url)
98+
99+
case "message":
100+
# Skip empty data (keep-alive pings)
101+
if not sse.data:
102+
continue
103+
try:
104+
message = types.jsonrpc_message_adapter.validate_json(sse.data, by_name=False)
105+
logger.debug(f"Received server message: {message}")
106+
except Exception as exc: # pragma: no cover
107+
logger.exception("Error parsing server message") # pragma: no cover
108+
await read_stream_writer.send(exc) # pragma: no cover
109+
continue # pragma: no cover
110+
111+
session_message = SessionMessage(message)
112+
await read_stream_writer.send(session_message)
113+
case _: # pragma: no cover
114+
logger.warning(f"Unknown SSE event: {sse.event}") # pragma: no cover
115+
except SSEError as sse_exc: # pragma: lax no cover
116+
logger.exception("Encountered SSE exception")
117+
raise sse_exc
118+
except Exception as exc: # pragma: lax no cover
119+
logger.exception("Error in sse_reader")
120+
await read_stream_writer.send(exc)
121+
finally:
122+
await read_stream_writer.aclose()
123+
124+
async def post_writer(endpoint_url: str):
125+
try:
126+
async with write_stream_reader, write_stream:
127+
async for session_message in write_stream_reader:
128+
logger.debug(f"Sending client message: {session_message}")
129+
response = await client.post(
130+
endpoint_url,
131+
json=session_message.message.model_dump(
132+
by_alias=True,
133+
mode="json",
134+
exclude_unset=True,
135+
),
136+
)
137+
response.raise_for_status()
138+
logger.debug(f"Client message sent successfully: {response.status_code}")
139+
except Exception: # pragma: lax no cover
140+
logger.exception("Error in post_writer")
141+
142+
# On Python 3.14, coverage.py reports a phantom branch arc on this
143+
# line (->yield) when nested two async-with levels deep. The branch
144+
# is the unreachable "did __aexit__ suppress?" arm for memory streams.
145+
async with ( # pragma: no branch
146+
read_stream_writer,
147+
read_stream,
148+
write_stream,
149+
write_stream_reader,
150+
anyio.create_task_group() as tg,
151+
):
152+
endpoint_url = await tg.start(sse_reader)
153+
logger.debug(f"Starting post writer with endpoint URL: {endpoint_url}")
154+
tg.start_soon(post_writer, endpoint_url)
155+
156+
yield read_stream, write_stream
157+
tg.cancel_scope.cancel()

src/mcp/client/streamable_http.py

Lines changed: 37 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,7 @@ async def post_writer(
452452
) -> None:
453453
"""Handle writing requests to the server."""
454454
try:
455-
async with write_stream_reader:
455+
async with write_stream_reader, read_stream_writer, write_stream:
456456
async for session_message in write_stream_reader:
457457
message = session_message.message
458458
metadata = (
@@ -492,9 +492,6 @@ async def handle_request_async():
492492

493493
except Exception: # pragma: lax no cover
494494
logger.exception("Error in post_writer")
495-
finally:
496-
await read_stream_writer.aclose()
497-
await write_stream.aclose()
498495

499496
async def terminate_session(self, client: httpx.AsyncClient) -> None:
500497
"""Terminate the session by sending a DELETE request."""
@@ -545,9 +542,6 @@ async def streamable_http_client(
545542
Example:
546543
See examples/snippets/clients/ for usage patterns.
547544
"""
548-
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
549-
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)
550-
551545
# Determine if we need to create and manage the client
552546
client_provided = http_client is not None
553547
client = http_client
@@ -558,36 +552,40 @@ async def streamable_http_client(
558552

559553
transport = StreamableHTTPTransport(url)
560554

561-
async with anyio.create_task_group() as tg:
562-
try:
563-
logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")
564-
565-
async with contextlib.AsyncExitStack() as stack:
566-
# Only manage client lifecycle if we created it
567-
if not client_provided:
568-
await stack.enter_async_context(client)
569-
570-
def start_get_stream() -> None:
571-
tg.start_soon(transport.handle_get_stream, client, read_stream_writer)
572-
573-
tg.start_soon(
574-
transport.post_writer,
575-
client,
576-
write_stream_reader,
577-
read_stream_writer,
578-
write_stream,
579-
start_get_stream,
580-
tg,
581-
)
555+
logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")
556+
557+
async with contextlib.AsyncExitStack() as stack:
558+
# Only manage client lifecycle if we created it
559+
if not client_provided:
560+
await stack.enter_async_context(client)
561+
562+
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
563+
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)
564+
565+
async with (
566+
read_stream_writer,
567+
read_stream,
568+
write_stream,
569+
write_stream_reader,
570+
anyio.create_task_group() as tg,
571+
):
572+
573+
def start_get_stream() -> None:
574+
tg.start_soon(transport.handle_get_stream, client, read_stream_writer)
575+
576+
tg.start_soon(
577+
transport.post_writer,
578+
client,
579+
write_stream_reader,
580+
read_stream_writer,
581+
write_stream,
582+
start_get_stream,
583+
tg,
584+
)
582585

583-
try:
584-
yield read_stream, write_stream
585-
finally:
586-
if transport.session_id and terminate_on_close:
587-
await transport.terminate_session(client)
588-
tg.cancel_scope.cancel()
589-
finally:
590-
await read_stream_writer.aclose()
591-
await write_stream.aclose()
592-
await read_stream.aclose()
593-
await write_stream_reader.aclose()
586+
try:
587+
yield read_stream, write_stream
588+
finally:
589+
if transport.session_id and terminate_on_close:
590+
await transport.terminate_session(client)
591+
tg.cancel_scope.cancel()

tests/client/test_transport_stream_cleanup.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -58,22 +58,39 @@ def hook(args: "sys.UnraisableHookArgs") -> None: # pragma: no cover
5858

5959
@pytest.mark.anyio
6060
async def test_sse_client_closes_all_streams_on_connection_error(free_tcp_port: int) -> None:
61-
"""sse_client must close all 4 stream ends when the connection fails.
61+
"""sse_client creates streams only after the SSE connection succeeds, so a
62+
ConnectError propagates directly with nothing to leak.
6263
63-
Before the fix, only read_stream_writer and write_stream were closed in
64-
the finally block. read_stream and write_stream_reader were leaked.
64+
Before the fix, streams were created before connecting and only 2 of 4 were
65+
closed in the finally block.
6566
"""
6667
with _assert_no_memory_stream_leak():
67-
# sse_client enters a task group BEFORE connecting, so anyio wraps the
68-
# ConnectError from aconnect_sse in an ExceptionGroup.
69-
with pytest.raises(Exception) as exc_info: # noqa: B017
68+
with pytest.raises(httpx.ConnectError):
7069
async with sse_client(f"http://127.0.0.1:{free_tcp_port}/sse"):
7170
pytest.fail("should not reach here") # pragma: no cover
7271

73-
assert exc_info.group_contains(httpx.ConnectError)
74-
# exc_info holds the traceback → holds frame locals → keeps leaked
75-
# streams alive. Must drop it before gc.collect() can detect a leak.
76-
del exc_info
72+
73+
@pytest.mark.anyio
74+
async def test_sse_client_closes_all_streams_on_http_error() -> None:
75+
"""sse_client creates streams only after raise_for_status() passes, so an
76+
HTTPStatusError from a 4xx/5xx response propagates bare (not wrapped in an
77+
ExceptionGroup) with nothing to leak — the task group is never entered.
78+
"""
79+
80+
def return_403(request: httpx.Request) -> httpx.Response:
81+
return httpx.Response(403)
82+
83+
def mock_factory(
84+
headers: dict[str, str] | None = None,
85+
timeout: httpx.Timeout | None = None,
86+
auth: httpx.Auth | None = None,
87+
) -> httpx.AsyncClient:
88+
return httpx.AsyncClient(transport=httpx.MockTransport(return_403))
89+
90+
with _assert_no_memory_stream_leak():
91+
with pytest.raises(httpx.HTTPStatusError):
92+
async with sse_client("http://test/sse", httpx_client_factory=mock_factory):
93+
pytest.fail("should not reach here") # pragma: no cover
7794

7895

7996
@pytest.mark.anyio

0 commit comments

Comments
 (0)