Skip to content

Commit 6bc5d3b

Browse files
committed
fix(streamable-http): close SSE responses on errors
1 parent 161834d commit 6bc5d3b

2 files changed

Lines changed: 137 additions & 15 deletions

File tree

src/mcp/client/streamable_http.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -236,20 +236,25 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
236236
if isinstance(ctx.session_message.message, JSONRPCRequest): # pragma: no branch
237237
original_request_id = ctx.session_message.message.id
238238

239-
async with aconnect_sse(ctx.client, "GET", self.url, headers=headers) as event_source:
240-
event_source.response.raise_for_status()
241-
logger.debug("Resumption GET SSE connection established")
239+
event_source: EventSource | None = None
240+
try:
241+
async with aconnect_sse(ctx.client, "GET", self.url, headers=headers) as es:
242+
event_source = es
243+
event_source.response.raise_for_status()
244+
logger.debug("Resumption GET SSE connection established")
242245

243-
async for sse in event_source.aiter_sse(): # pragma: no branch
244-
is_complete = await self._handle_sse_event(
245-
sse,
246-
ctx.read_stream_writer,
247-
original_request_id,
248-
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
249-
)
250-
if is_complete:
251-
await event_source.response.aclose()
252-
break
246+
async for sse in event_source.aiter_sse(): # pragma: no branch
247+
is_complete = await self._handle_sse_event(
248+
sse,
249+
ctx.read_stream_writer,
250+
original_request_id,
251+
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
252+
)
253+
if is_complete:
254+
break
255+
finally:
256+
if event_source is not None:
257+
await event_source.response.aclose()
253258

254259
async def _handle_post_request(self, ctx: RequestContext) -> None:
255260
"""Handle a POST request with response processing."""
@@ -361,10 +366,11 @@ async def _handle_sse_response(
361366
# If the SSE event indicates completion, like returning response/error
362367
# break the loop
363368
if is_complete:
364-
await response.aclose()
365369
return # Normal completion, no reconnect needed
366370
except Exception:
367-
logger.debug("SSE stream ended", exc_info=True) # pragma: no cover
371+
logger.debug("SSE stream ended", exc_info=True)
372+
finally:
373+
await response.aclose()
368374

369375
# Stream ended without response - reconnect if we received an event with ID
370376
if last_event_id is not None: # pragma: no branch
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
import contextlib
2+
3+
import anyio
4+
import httpx
5+
import pytest
6+
from httpx_sse import ServerSentEvent
7+
8+
from mcp.client.streamable_http import RequestContext, StreamableHTTPTransport
9+
from mcp.shared.message import ClientMessageMetadata, SessionMessage
10+
from mcp.types import JSONRPCRequest
11+
12+
13+
class _RaiseEventSource:
14+
def __init__(self, response: httpx.Response) -> None:
15+
self.response = response
16+
17+
async def aiter_sse(self):
18+
yield ServerSentEvent(event="message", data="", id=None, retry=None)
19+
raise RuntimeError("boom")
20+
21+
22+
@pytest.mark.anyio
23+
async def test_handle_sse_response_closes_response_on_exception(monkeypatch: pytest.MonkeyPatch) -> None:
24+
closed = False
25+
26+
async def spy_aclose() -> None:
27+
nonlocal closed
28+
closed = True
29+
30+
response = httpx.Response(200, headers={"content-type": "text/event-stream"})
31+
response.aclose = spy_aclose # type: ignore[method-assign]
32+
33+
monkeypatch.setattr("mcp.client.streamable_http.EventSource", _RaiseEventSource)
34+
35+
send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](1)
36+
async with send_stream, receive_stream:
37+
transport = StreamableHTTPTransport("http://example.invalid/mcp")
38+
async with httpx.AsyncClient(transport=httpx.MockTransport(lambda _: httpx.Response(200))) as client:
39+
ctx = RequestContext(
40+
client=client,
41+
session_id=None,
42+
session_message=SessionMessage(JSONRPCRequest(method="initialize", params={}, jsonrpc="2.0", id=1)),
43+
metadata=ClientMessageMetadata(),
44+
read_stream_writer=send_stream,
45+
)
46+
await transport._handle_sse_response(response, ctx)
47+
48+
assert closed
49+
50+
51+
@pytest.mark.anyio
52+
async def test_handle_resumption_request_closes_response_when_aconnect_sse_raises(
53+
monkeypatch: pytest.MonkeyPatch,
54+
) -> None:
55+
@contextlib.asynccontextmanager
56+
async def fake_aconnect_sse(*_args, **_kwargs):
57+
raise RuntimeError("connect failed")
58+
yield
59+
60+
monkeypatch.setattr("mcp.client.streamable_http.aconnect_sse", fake_aconnect_sse)
61+
62+
send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](1)
63+
async with send_stream, receive_stream:
64+
transport = StreamableHTTPTransport("http://example.invalid/mcp")
65+
metadata = ClientMessageMetadata(resumption_token="1")
66+
async with httpx.AsyncClient(transport=httpx.MockTransport(lambda _: httpx.Response(200))) as client:
67+
ctx = RequestContext(
68+
client=client,
69+
session_id=None,
70+
session_message=SessionMessage(JSONRPCRequest(method="initialize", params={}, jsonrpc="2.0", id=1)),
71+
metadata=metadata,
72+
read_stream_writer=send_stream,
73+
)
74+
75+
with pytest.raises(RuntimeError, match="connect failed"):
76+
await transport._handle_resumption_request(ctx)
77+
78+
79+
@pytest.mark.anyio
80+
async def test_handle_resumption_request_closes_response_on_exception(monkeypatch: pytest.MonkeyPatch) -> None:
81+
closed = False
82+
83+
async def spy_aclose() -> None:
84+
nonlocal closed
85+
closed = True
86+
87+
response = httpx.Response(
88+
200,
89+
headers={"content-type": "text/event-stream"},
90+
request=httpx.Request("GET", "http://example.invalid/mcp"),
91+
)
92+
response.aclose = spy_aclose # type: ignore[method-assign]
93+
94+
@contextlib.asynccontextmanager
95+
async def fake_aconnect_sse(*_args, **_kwargs):
96+
yield _RaiseEventSource(response)
97+
98+
monkeypatch.setattr("mcp.client.streamable_http.aconnect_sse", fake_aconnect_sse)
99+
100+
send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](1)
101+
async with send_stream, receive_stream:
102+
transport = StreamableHTTPTransport("http://example.invalid/mcp")
103+
metadata = ClientMessageMetadata(resumption_token="1")
104+
async with httpx.AsyncClient(transport=httpx.MockTransport(lambda _: httpx.Response(200))) as client:
105+
ctx = RequestContext(
106+
client=client,
107+
session_id=None,
108+
session_message=SessionMessage(JSONRPCRequest(method="initialize", params={}, jsonrpc="2.0", id=1)),
109+
metadata=metadata,
110+
read_stream_writer=send_stream,
111+
)
112+
113+
with pytest.raises(RuntimeError, match="boom"):
114+
await transport._handle_resumption_request(ctx)
115+
116+
assert closed

0 commit comments

Comments
 (0)