Skip to content

Commit e199eef

Browse files
committed
fix: suppress completed request cancellation leak
1 parent 161834d commit e199eef

2 files changed

Lines changed: 102 additions & 2 deletions

File tree

src/mcp/shared/session.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,16 +108,28 @@ def __exit__(
108108
exc_type: type[BaseException] | None,
109109
exc_val: BaseException | None,
110110
exc_tb: TracebackType | None,
111-
) -> None:
111+
) -> bool | None:
112112
"""Exit the context manager, performing cleanup and notifying completion."""
113+
suppress = False
113114
try:
114115
if self._completed:
115116
self._on_complete(self)
116117
finally:
117118
self._entered = False
118119
if not self._cancel_scope: # pragma: no cover
119120
raise RuntimeError("No active cancel scope")
120-
self._cancel_scope.__exit__(exc_type, exc_val, exc_tb)
121+
try:
122+
suppress = self._cancel_scope.__exit__(exc_type, exc_val, exc_tb)
123+
except BaseException as exc:
124+
if (
125+
self._completed
126+
and self._cancel_scope.cancel_called
127+
and isinstance(exc, anyio.get_cancelled_exc_class())
128+
):
129+
return True
130+
raise
131+
132+
return suppress
121133

122134
async def respond(self, response: SendResultT | ErrorData) -> None:
123135
"""Send a response for this request.

tests/shared/test_session.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Any, cast
2+
13
import anyio
24
import pytest
35

@@ -23,6 +25,11 @@
2325
)
2426

2527

28+
class _DummySession:
29+
async def _send_response(self, request_id: int | str, response: ClientResult | ErrorData) -> None:
30+
pass
31+
32+
2633
@pytest.mark.anyio
2734
async def test_in_flight_requests_cleared_after_completion():
2835
"""Verify that _in_flight is empty after all requests complete."""
@@ -98,6 +105,87 @@ async def make_request(client: Client):
98105
await ev_cancelled.wait()
99106

100107

108+
@pytest.mark.anyio
109+
async def test_request_responder_suppresses_completed_cancellation():
110+
"""A request-local cancellation should not leak out after cancel() responds."""
111+
112+
completed: list[RequestResponder[ServerRequest, ClientResult]] = []
113+
responder = RequestResponder[ServerRequest, ClientResult](
114+
request_id=1,
115+
request_meta=None,
116+
request=types.PingRequest(),
117+
session=cast(Any, _DummySession()),
118+
on_complete=completed.append,
119+
)
120+
121+
with responder:
122+
await responder.cancel()
123+
await anyio.sleep(0)
124+
125+
assert completed == [responder]
126+
127+
128+
@pytest.mark.anyio
129+
async def test_request_responder_ignores_late_completed_cancellation():
130+
"""Some backends can surface cancellation while leaving an already-cancelled scope."""
131+
132+
class _CancelScope:
133+
cancel_called = True
134+
135+
def __exit__(
136+
self,
137+
exc_type: type[BaseException] | None,
138+
exc_val: BaseException | None,
139+
exc_tb: object,
140+
) -> bool:
141+
raise anyio.get_cancelled_exc_class()
142+
143+
completed: list[RequestResponder[ServerRequest, ClientResult]] = []
144+
responder = RequestResponder[ServerRequest, ClientResult](
145+
request_id=1,
146+
request_meta=None,
147+
request=types.PingRequest(),
148+
session=cast(Any, _DummySession()),
149+
on_complete=completed.append,
150+
)
151+
responder._entered = True
152+
responder._completed = True
153+
responder._cancel_scope = cast(Any, _CancelScope())
154+
155+
assert responder.__exit__(None, None, None) is True
156+
assert completed == [responder]
157+
158+
159+
@pytest.mark.anyio
160+
async def test_request_responder_reraises_unexpected_exit_error():
161+
"""Unexpected cancel scope errors should still propagate."""
162+
163+
class _CancelScope:
164+
cancel_called = False
165+
166+
def __exit__(
167+
self,
168+
exc_type: type[BaseException] | None,
169+
exc_val: BaseException | None,
170+
exc_tb: object,
171+
) -> bool:
172+
raise RuntimeError("boom")
173+
174+
responder = RequestResponder[ServerRequest, ClientResult](
175+
request_id=1,
176+
request_meta=None,
177+
request=types.PingRequest(),
178+
session=cast(Any, _DummySession()),
179+
on_complete=lambda _: None,
180+
)
181+
responder._entered = True
182+
responder._completed = True
183+
responder._cancel_scope = cast(Any, _CancelScope())
184+
185+
with pytest.raises(RuntimeError, match="boom"):
186+
responder.__exit__(None, None, None)
187+
188+
101189
@pytest.mark.anyio
102190
async def test_response_id_type_mismatch_string_to_int():
103191
"""Test that responses with string IDs are correctly matched to requests sent with

0 commit comments

Comments
 (0)