From 6bc5d3b52b702a67c49191cd8f31df97744ee45d Mon Sep 17 00:00:00 2001 From: Jianke LIN Date: Mon, 18 May 2026 15:16:06 +0200 Subject: [PATCH 1/3] fix(streamable-http): close SSE responses on errors --- src/mcp/client/streamable_http.py | 36 +++--- .../test_streamable_http_response_cleanup.py | 116 ++++++++++++++++++ 2 files changed, 137 insertions(+), 15 deletions(-) create mode 100644 tests/client/test_streamable_http_response_cleanup.py diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 9a119c633..58688488c 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -236,20 +236,25 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: if isinstance(ctx.session_message.message, JSONRPCRequest): # pragma: no branch original_request_id = ctx.session_message.message.id - async with aconnect_sse(ctx.client, "GET", self.url, headers=headers) as event_source: - event_source.response.raise_for_status() - logger.debug("Resumption GET SSE connection established") + event_source: EventSource | None = None + try: + async with aconnect_sse(ctx.client, "GET", self.url, headers=headers) as es: + event_source = es + event_source.response.raise_for_status() + logger.debug("Resumption GET SSE connection established") - async for sse in event_source.aiter_sse(): # pragma: no branch - is_complete = await self._handle_sse_event( - sse, - ctx.read_stream_writer, - original_request_id, - ctx.metadata.on_resumption_token_update if ctx.metadata else None, - ) - if is_complete: - await event_source.response.aclose() - break + async for sse in event_source.aiter_sse(): # pragma: no branch + is_complete = await self._handle_sse_event( + sse, + ctx.read_stream_writer, + original_request_id, + ctx.metadata.on_resumption_token_update if ctx.metadata else None, + ) + if is_complete: + break + finally: + if event_source is not None: + await event_source.response.aclose() async def _handle_post_request(self, ctx: RequestContext) -> None: """Handle a POST request with response processing.""" @@ -361,10 +366,11 @@ async def _handle_sse_response( # If the SSE event indicates completion, like returning response/error # break the loop if is_complete: - await response.aclose() return # Normal completion, no reconnect needed except Exception: - logger.debug("SSE stream ended", exc_info=True) # pragma: no cover + logger.debug("SSE stream ended", exc_info=True) + finally: + await response.aclose() # Stream ended without response - reconnect if we received an event with ID if last_event_id is not None: # pragma: no branch diff --git a/tests/client/test_streamable_http_response_cleanup.py b/tests/client/test_streamable_http_response_cleanup.py new file mode 100644 index 000000000..d30c3212a --- /dev/null +++ b/tests/client/test_streamable_http_response_cleanup.py @@ -0,0 +1,116 @@ +import contextlib + +import anyio +import httpx +import pytest +from httpx_sse import ServerSentEvent + +from mcp.client.streamable_http import RequestContext, StreamableHTTPTransport +from mcp.shared.message import ClientMessageMetadata, SessionMessage +from mcp.types import JSONRPCRequest + + +class _RaiseEventSource: + def __init__(self, response: httpx.Response) -> None: + self.response = response + + async def aiter_sse(self): + yield ServerSentEvent(event="message", data="", id=None, retry=None) + raise RuntimeError("boom") + + +@pytest.mark.anyio +async def test_handle_sse_response_closes_response_on_exception(monkeypatch: pytest.MonkeyPatch) -> None: + closed = False + + async def spy_aclose() -> None: + nonlocal closed + closed = True + + response = httpx.Response(200, headers={"content-type": "text/event-stream"}) + response.aclose = spy_aclose # type: ignore[method-assign] + + monkeypatch.setattr("mcp.client.streamable_http.EventSource", _RaiseEventSource) + + send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](1) + async with send_stream, receive_stream: + transport = StreamableHTTPTransport("http://example.invalid/mcp") + async with httpx.AsyncClient(transport=httpx.MockTransport(lambda _: httpx.Response(200))) as client: + ctx = RequestContext( + client=client, + session_id=None, + session_message=SessionMessage(JSONRPCRequest(method="initialize", params={}, jsonrpc="2.0", id=1)), + metadata=ClientMessageMetadata(), + read_stream_writer=send_stream, + ) + await transport._handle_sse_response(response, ctx) + + assert closed + + +@pytest.mark.anyio +async def test_handle_resumption_request_closes_response_when_aconnect_sse_raises( + monkeypatch: pytest.MonkeyPatch, +) -> None: + @contextlib.asynccontextmanager + async def fake_aconnect_sse(*_args, **_kwargs): + raise RuntimeError("connect failed") + yield + + monkeypatch.setattr("mcp.client.streamable_http.aconnect_sse", fake_aconnect_sse) + + send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](1) + async with send_stream, receive_stream: + transport = StreamableHTTPTransport("http://example.invalid/mcp") + metadata = ClientMessageMetadata(resumption_token="1") + async with httpx.AsyncClient(transport=httpx.MockTransport(lambda _: httpx.Response(200))) as client: + ctx = RequestContext( + client=client, + session_id=None, + session_message=SessionMessage(JSONRPCRequest(method="initialize", params={}, jsonrpc="2.0", id=1)), + metadata=metadata, + read_stream_writer=send_stream, + ) + + with pytest.raises(RuntimeError, match="connect failed"): + await transport._handle_resumption_request(ctx) + + +@pytest.mark.anyio +async def test_handle_resumption_request_closes_response_on_exception(monkeypatch: pytest.MonkeyPatch) -> None: + closed = False + + async def spy_aclose() -> None: + nonlocal closed + closed = True + + response = httpx.Response( + 200, + headers={"content-type": "text/event-stream"}, + request=httpx.Request("GET", "http://example.invalid/mcp"), + ) + response.aclose = spy_aclose # type: ignore[method-assign] + + @contextlib.asynccontextmanager + async def fake_aconnect_sse(*_args, **_kwargs): + yield _RaiseEventSource(response) + + monkeypatch.setattr("mcp.client.streamable_http.aconnect_sse", fake_aconnect_sse) + + send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](1) + async with send_stream, receive_stream: + transport = StreamableHTTPTransport("http://example.invalid/mcp") + metadata = ClientMessageMetadata(resumption_token="1") + async with httpx.AsyncClient(transport=httpx.MockTransport(lambda _: httpx.Response(200))) as client: + ctx = RequestContext( + client=client, + session_id=None, + session_message=SessionMessage(JSONRPCRequest(method="initialize", params={}, jsonrpc="2.0", id=1)), + metadata=metadata, + read_stream_writer=send_stream, + ) + + with pytest.raises(RuntimeError, match="boom"): + await transport._handle_resumption_request(ctx) + + assert closed From f0443d6e5affbe37327e0e2faa0fd590d6354ff2 Mon Sep 17 00:00:00 2001 From: Jianke LIN Date: Mon, 18 May 2026 15:27:21 +0200 Subject: [PATCH 2/3] test: satisfy pyright + branch coverage for SSE cleanup --- .../test_streamable_http_response_cleanup.py | 28 +++++++++++++------ 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/tests/client/test_streamable_http_response_cleanup.py b/tests/client/test_streamable_http_response_cleanup.py index d30c3212a..0fa4ede44 100644 --- a/tests/client/test_streamable_http_response_cleanup.py +++ b/tests/client/test_streamable_http_response_cleanup.py @@ -1,11 +1,11 @@ import contextlib -import anyio import httpx import pytest from httpx_sse import ServerSentEvent from mcp.client.streamable_http import RequestContext, StreamableHTTPTransport +from mcp.shared._context_streams import create_context_streams from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.types import JSONRPCRequest @@ -32,7 +32,7 @@ async def spy_aclose() -> None: monkeypatch.setattr("mcp.client.streamable_http.EventSource", _RaiseEventSource) - send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](1) + send_stream, receive_stream = create_context_streams[SessionMessage | Exception](1) async with send_stream, receive_stream: transport = StreamableHTTPTransport("http://example.invalid/mcp") async with httpx.AsyncClient(transport=httpx.MockTransport(lambda _: httpx.Response(200))) as client: @@ -53,13 +53,13 @@ async def test_handle_resumption_request_closes_response_when_aconnect_sse_raise monkeypatch: pytest.MonkeyPatch, ) -> None: @contextlib.asynccontextmanager - async def fake_aconnect_sse(*_args, **_kwargs): + async def fake_aconnect_sse(*_args: object, **_kwargs: object): raise RuntimeError("connect failed") yield monkeypatch.setattr("mcp.client.streamable_http.aconnect_sse", fake_aconnect_sse) - send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](1) + send_stream, receive_stream = create_context_streams[SessionMessage | Exception](1) async with send_stream, receive_stream: transport = StreamableHTTPTransport("http://example.invalid/mcp") metadata = ClientMessageMetadata(resumption_token="1") @@ -72,8 +72,14 @@ async def fake_aconnect_sse(*_args, **_kwargs): read_stream_writer=send_stream, ) - with pytest.raises(RuntimeError, match="connect failed"): + error: RuntimeError | None = None + try: await transport._handle_resumption_request(ctx) + except RuntimeError as exc: + error = exc + + assert error is not None + assert str(error) == "connect failed" @pytest.mark.anyio @@ -92,12 +98,12 @@ async def spy_aclose() -> None: response.aclose = spy_aclose # type: ignore[method-assign] @contextlib.asynccontextmanager - async def fake_aconnect_sse(*_args, **_kwargs): + async def fake_aconnect_sse(*_args: object, **_kwargs: object): yield _RaiseEventSource(response) monkeypatch.setattr("mcp.client.streamable_http.aconnect_sse", fake_aconnect_sse) - send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](1) + send_stream, receive_stream = create_context_streams[SessionMessage | Exception](1) async with send_stream, receive_stream: transport = StreamableHTTPTransport("http://example.invalid/mcp") metadata = ClientMessageMetadata(resumption_token="1") @@ -110,7 +116,13 @@ async def fake_aconnect_sse(*_args, **_kwargs): read_stream_writer=send_stream, ) - with pytest.raises(RuntimeError, match="boom"): + error: RuntimeError | None = None + try: await transport._handle_resumption_request(ctx) + except RuntimeError as exc: + error = exc + + assert error is not None + assert str(error) == "boom" assert closed From fe40b80191ae22f39f8a63c43eb695957d98d2c1 Mon Sep 17 00:00:00 2001 From: Jianke LIN Date: Mon, 18 May 2026 15:40:01 +0200 Subject: [PATCH 3/3] fix(streamable-http): stabilize branch coverage --- src/mcp/client/streamable_http.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 58688488c..fdd024b7d 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -253,7 +253,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: if is_complete: break finally: - if event_source is not None: + if event_source is not None: # pragma: no branch await event_source.response.aclose() async def _handle_post_request(self, ctx: RequestContext) -> None: