Skip to content

Commit 17b3ab2

Browse files
fix: resolve lost-wakeup races in InMemoryTaskStore.wait_for_update
Two races in wait_for_update: 1. Concurrent waiters: second caller overwrites the first's event in _update_events[task_id], so the first waiter hangs forever. Fix: use a list of events per task_id so each waiter gets its own. 2. Notify before wait: if update_task completes before wait_for_update is called, the signal is lost because no event exists yet. Fix: track pending updates in a set; wait_for_update checks and consumes pending flags before creating an event. Both races are reachable via task_result_handler.py:126 when multiple clients poll the same task or when a task completes between status checks. Fixes #2535
1 parent 3d7b311 commit 17b3ab2

1 file changed

Lines changed: 29 additions & 7 deletions

File tree

src/mcp/shared/experimental/tasks/in_memory_task_store.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ class InMemoryTaskStore(TaskStore):
4646
def __init__(self, page_size: int = 10) -> None:
4747
self._tasks: dict[str, StoredTask] = {}
4848
self._page_size = page_size
49-
self._update_events: dict[str, anyio.Event] = {}
49+
self._update_events: dict[str, list[anyio.Event]] = {}
50+
self._pending_updates: set[str] = set()
5051

5152
def _calculate_expiry(self, ttl_ms: int | None) -> datetime | None:
5253
"""Calculate expiry time from TTL in milliseconds."""
@@ -194,22 +195,43 @@ async def wait_for_update(self, task_id: str) -> None:
194195
if task_id not in self._tasks:
195196
raise ValueError(f"Task with ID {task_id} not found")
196197

197-
# Create a fresh event for waiting (anyio.Event can't be cleared)
198-
self._update_events[task_id] = anyio.Event()
199-
event = self._update_events[task_id]
200-
await event.wait()
198+
# If an update arrived before we started waiting, consume it and return.
199+
if task_id in self._pending_updates:
200+
self._pending_updates.discard(task_id)
201+
return
202+
203+
# Create a per-waiter event so multiple concurrent waiters each get woken.
204+
event = anyio.Event()
205+
if task_id not in self._update_events:
206+
self._update_events[task_id] = []
207+
self._update_events[task_id].append(event)
208+
try:
209+
await event.wait()
210+
finally:
211+
# Clean up our event from the list (may already be removed by notify).
212+
try:
213+
self._update_events[task_id].remove(event)
214+
except (ValueError, KeyError):
215+
pass
201216

202217
async def notify_update(self, task_id: str) -> None:
203218
"""Signal that a task has been updated."""
204-
if task_id in self._update_events:
205-
self._update_events[task_id].set()
219+
events = self._update_events.pop(task_id, [])
220+
if events:
221+
for event in events:
222+
event.set()
223+
else:
224+
# No waiters yet; mark as pending so the next wait_for_update returns
225+
# immediately instead of blocking.
226+
self._pending_updates.add(task_id)
206227

207228
# --- Testing/debugging helpers ---
208229

209230
def cleanup(self) -> None:
210231
"""Cleanup all tasks (useful for testing or graceful shutdown)."""
211232
self._tasks.clear()
212233
self._update_events.clear()
234+
self._pending_updates.clear()
213235

214236
def get_all_tasks(self) -> list[Task]:
215237
"""Get all tasks (useful for debugging). Returns copies to prevent modification."""

0 commit comments

Comments
 (0)