Skip to content

Commit 065264d

Browse files
fix: resolve lost-wakeup races in InMemoryTaskStore.wait_for_update
Two races in wait_for_update: 1. Concurrent waiters: second caller overwrites the first's event in _update_events[task_id], so the first waiter hangs forever. Fix: use a list of events per task_id so each waiter gets its own. 2. Notify before wait: if update_task completes before wait_for_update is called, the signal is lost because no event exists yet. Fix: track pending updates in a set; wait_for_update checks and consumes pending flags before creating an event. Both races are reachable via task_result_handler.py:126 when multiple clients poll the same task or when a task completes between status checks. Adds two tests: concurrent waiters and notify-before-wait. Fixes #2535
1 parent 3d7b311 commit 065264d

2 files changed

Lines changed: 69 additions & 7 deletions

File tree

src/mcp/shared/experimental/tasks/in_memory_task_store.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ class InMemoryTaskStore(TaskStore):
4646
def __init__(self, page_size: int = 10) -> None:
4747
self._tasks: dict[str, StoredTask] = {}
4848
self._page_size = page_size
49-
self._update_events: dict[str, anyio.Event] = {}
49+
self._update_events: dict[str, list[anyio.Event]] = {}
50+
self._pending_updates: set[str] = set()
5051

5152
def _calculate_expiry(self, ttl_ms: int | None) -> datetime | None:
5253
"""Calculate expiry time from TTL in milliseconds."""
@@ -194,22 +195,43 @@ async def wait_for_update(self, task_id: str) -> None:
194195
if task_id not in self._tasks:
195196
raise ValueError(f"Task with ID {task_id} not found")
196197

197-
# Create a fresh event for waiting (anyio.Event can't be cleared)
198-
self._update_events[task_id] = anyio.Event()
199-
event = self._update_events[task_id]
200-
await event.wait()
198+
# If an update arrived before we started waiting, consume it and return.
199+
if task_id in self._pending_updates:
200+
self._pending_updates.discard(task_id)
201+
return
202+
203+
# Create a per-waiter event so multiple concurrent waiters each get woken.
204+
event = anyio.Event()
205+
if task_id not in self._update_events:
206+
self._update_events[task_id] = []
207+
self._update_events[task_id].append(event)
208+
try:
209+
await event.wait()
210+
finally:
211+
# Clean up our event from the list (may already be removed by notify).
212+
try:
213+
self._update_events[task_id].remove(event)
214+
except (ValueError, KeyError):
215+
pass
201216

202217
async def notify_update(self, task_id: str) -> None:
203218
"""Signal that a task has been updated."""
204-
if task_id in self._update_events:
205-
self._update_events[task_id].set()
219+
events = self._update_events.pop(task_id, [])
220+
if events:
221+
for event in events:
222+
event.set()
223+
else:
224+
# No waiters yet; mark as pending so the next wait_for_update returns
225+
# immediately instead of blocking.
226+
self._pending_updates.add(task_id)
206227

207228
# --- Testing/debugging helpers ---
208229

209230
def cleanup(self) -> None:
210231
"""Cleanup all tasks (useful for testing or graceful shutdown)."""
211232
self._tasks.clear()
212233
self._update_events.clear()
234+
self._pending_updates.clear()
213235

214236
def get_all_tasks(self) -> list[Task]:
215237
"""Get all tasks (useful for debugging). Returns copies to prevent modification."""

tests/experimental/tasks/server/test_store.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from collections.abc import AsyncIterator
44
from datetime import datetime, timedelta, timezone
55

6+
import anyio
67
import pytest
78

89
from mcp.shared.exceptions import MCPError
@@ -328,6 +329,45 @@ async def test_wait_for_update_nonexistent_raises(store: InMemoryTaskStore) -> N
328329
await store.wait_for_update("nonexistent-task-id")
329330

330331

332+
@pytest.mark.anyio
333+
async def test_wait_for_update_concurrent_waiters(store: InMemoryTaskStore) -> None:
334+
"""Two concurrent waiters for the same task must both wake up."""
335+
task = await store.create_task(metadata=TaskMetadata(ttl=60000))
336+
337+
woke: dict[str, bool] = {"a": False, "b": False}
338+
339+
async def waiter(name: str) -> None:
340+
await store.wait_for_update(task.task_id)
341+
woke[name] = True
342+
343+
async def updater() -> None:
344+
await anyio.sleep(0.05)
345+
await store.update_task(task.task_id, status="completed")
346+
347+
with anyio.fail_after(2):
348+
async with anyio.create_task_group() as tg:
349+
tg.start_soon(waiter, "a")
350+
await anyio.sleep(0.01) # ensure a registers first
351+
tg.start_soon(waiter, "b")
352+
tg.start_soon(updater)
353+
354+
assert woke["a"], "waiter a should have been woken"
355+
assert woke["b"], "waiter b should have been woken"
356+
357+
358+
@pytest.mark.anyio
359+
async def test_wait_for_update_notify_before_wait(store: InMemoryTaskStore) -> None:
360+
"""If notify fires before wait, the signal must not be lost."""
361+
task = await store.create_task(metadata=TaskMetadata(ttl=60000))
362+
363+
# Task completes before anyone waits
364+
await store.update_task(task.task_id, status="completed")
365+
366+
# wait_for_update should return immediately (pending update consumed)
367+
with anyio.fail_after(1):
368+
await store.wait_for_update(task.task_id)
369+
370+
331371
@pytest.mark.anyio
332372
async def test_cancel_task_succeeds_for_working_task(store: InMemoryTaskStore) -> None:
333373
"""Test cancel_task helper succeeds for a working task."""

0 commit comments

Comments
 (0)