Skip to content

Commit 113f35c

Browse files
committed
Make session_id optional in TaskStore; reject stateless mode instead
Previously, session_id was required (str) throughout the task store interface, which forced sessionless transports like stdio and memory to fabricate UUIDs before running tasks. This was awkward because those transports are single-client by architecture and have no session concept. The policy is now enforced at the handler layer instead of the store layer: default task handlers reject requests when the server is in stateless mode (where tasks cannot survive across requests), and pass session_id through as-is otherwise. A None session_id simply means no session-scoped isolation, which is correct for single-client transports. Isolation in InMemoryTaskStore uses strict equality: None only matches None, so tasks created by a sessionless transport are not visible to session-scoped transports and vice versa, preventing cross-transport leaks when a process serves multiple transports from one store. Changes: - TaskStore, InMemoryTaskStore, TaskContext, helpers: session_id is now str | None - ServerSession: expose stateless property - Default task handlers: reject stateless mode, pass session_id as-is - run_task() / ServerTaskContext: accept None session_id, reject stateless mode - Memory transport: revert fabricated UUID session_id - New tests for None-session isolation (strict equality, no cross-transport leaks) :house: Remote-Dev: homespace
1 parent a4f5ade commit 113f35c

File tree

14 files changed

+172
-72
lines changed

14 files changed

+172
-72
lines changed

src/mcp/client/_memory.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from __future__ import annotations
44

5-
import uuid
65
from collections.abc import AsyncIterator
76
from contextlib import AbstractAsyncContextManager, asynccontextmanager
87
from types import TracebackType
@@ -51,14 +50,12 @@ async def _connect(self) -> AsyncIterator[TransportStreams]:
5150

5251
async with anyio.create_task_group() as tg:
5352
# Start server in background
54-
memory_session_id = uuid.uuid4().hex
5553
tg.start_soon(
5654
lambda: actual_server.run(
5755
server_read,
5856
server_write,
5957
actual_server.create_initialization_options(),
6058
raise_exceptions=self._raise_exceptions,
61-
session_id=memory_session_id,
6259
)
6360
)
6461

src/mcp/server/experimental/request_context.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,9 +189,11 @@ async def work(task: ServerTaskContext) -> CallToolResult:
189189
# Access task_group via TaskSupport - raises if not in run() context
190190
task_group = support.task_group
191191

192+
if self._session.stateless:
193+
raise RuntimeError(
194+
"run_task() does not support stateless mode. Tasks require a persistent session for result retrieval."
195+
)
192196
session_id = self._session.session_id
193-
if session_id is None:
194-
raise RuntimeError("Session ID is required for task operations but session has no ID.")
195197
task = await support.store.create_task(self.task_metadata, task_id, session_id=session_id)
196198

197199
task_ctx = ServerTaskContext(

src/mcp/server/experimental/task_context.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,8 @@ def __init__(
9090
queue: The message queue for elicitation/sampling
9191
handler: The result handler for response routing (required for elicit/create_message)
9292
"""
93-
session_id = session.session_id
94-
if session_id is None:
95-
raise RuntimeError("Session ID is required for task operations but session has no ID.")
96-
self._session_id = session_id
97-
self._ctx = TaskContext(task=task, store=store, session_id=session_id)
93+
self._session_id = session.session_id
94+
self._ctx = TaskContext(task=task, store=store, session_id=self._session_id)
9895
self._session = session
9996
self._queue = queue
10097
self._handler = handler

src/mcp/server/experimental/task_result_handler.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ async def handle(
8080
request: GetTaskPayloadRequest,
8181
session: ServerSession,
8282
request_id: RequestId,
83-
session_id: str,
83+
session_id: str | None,
8484
) -> GetTaskPayloadResult:
8585
"""Handle a tasks/result request.
8686
@@ -95,7 +95,8 @@ async def handle(
9595
request: The GetTaskPayloadRequest
9696
session: The server session for sending messages
9797
request_id: The request ID for relatedRequestId routing
98-
session_id: Session identifier for access control.
98+
session_id: Session identifier for access control. Must exactly
99+
match the session_id the task was created with (including None).
99100
100101
Returns:
101102
GetTaskPayloadResult with the task's final payload

src/mcp/server/lowlevel/experimental.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -153,23 +153,24 @@ def enable_tasks(
153153
if on_cancel_task is not None:
154154
self._add_request_handler("tasks/cancel", on_cancel_task)
155155

156-
def _require_session_id(ctx: ServerRequestContext[LifespanResultT]) -> str:
157-
session_id = ctx.session.session_id
158-
if session_id is None:
156+
def _check_stateless(ctx: ServerRequestContext[LifespanResultT]) -> None:
157+
if ctx.session.stateless:
159158
raise MCPError(
160159
code=INVALID_PARAMS,
161-
message="Session ID is required for task operations.",
160+
message=(
161+
"Default task handlers do not support stateless mode. "
162+
"Provide custom task handlers if you need stateless task support."
163+
),
162164
)
163-
return session_id
164165

165166
# Fill in defaults for any not provided
166167
if not self._has_handler("tasks/get"):
167168

168169
async def _default_get_task(
169170
ctx: ServerRequestContext[LifespanResultT], params: GetTaskRequestParams
170171
) -> GetTaskResult:
171-
session_id = _require_session_id(ctx)
172-
task = await task_support.store.get_task(params.task_id, session_id=session_id)
172+
_check_stateless(ctx)
173+
task = await task_support.store.get_task(params.task_id, session_id=ctx.session.session_id)
173174
if task is None:
174175
raise MCPError(code=INVALID_PARAMS, message=f"Task not found: {params.task_id}")
175176
return GetTaskResult(
@@ -190,9 +191,11 @@ async def _default_get_task_result(
190191
ctx: ServerRequestContext[LifespanResultT], params: GetTaskPayloadRequestParams
191192
) -> GetTaskPayloadResult:
192193
assert ctx.request_id is not None
193-
session_id = _require_session_id(ctx)
194+
_check_stateless(ctx)
194195
req = GetTaskPayloadRequest(params=params)
195-
result = await task_support.handler.handle(req, ctx.session, ctx.request_id, session_id=session_id)
196+
result = await task_support.handler.handle(
197+
req, ctx.session, ctx.request_id, session_id=ctx.session.session_id
198+
)
196199
return result
197200

198201
self._add_request_handler("tasks/result", _default_get_task_result)
@@ -202,9 +205,9 @@ async def _default_get_task_result(
202205
async def _default_list_tasks(
203206
ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None
204207
) -> ListTasksResult:
208+
_check_stateless(ctx)
205209
cursor = params.cursor if params else None
206-
session_id = _require_session_id(ctx)
207-
tasks, next_cursor = await task_support.store.list_tasks(cursor, session_id=session_id)
210+
tasks, next_cursor = await task_support.store.list_tasks(cursor, session_id=ctx.session.session_id)
208211
return ListTasksResult(tasks=tasks, next_cursor=next_cursor)
209212

210213
self._add_request_handler("tasks/list", _default_list_tasks)
@@ -214,8 +217,8 @@ async def _default_list_tasks(
214217
async def _default_cancel_task(
215218
ctx: ServerRequestContext[LifespanResultT], params: CancelTaskRequestParams
216219
) -> CancelTaskResult:
217-
session_id = _require_session_id(ctx)
218-
result = await cancel_task(task_support.store, params.task_id, session_id=session_id)
220+
_check_stateless(ctx)
221+
result = await cancel_task(task_support.store, params.task_id, session_id=ctx.session.session_id)
219222
return result
220223

221224
self._add_request_handler("tasks/cancel", _default_cancel_task)

src/mcp/server/session.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,11 @@ def _receive_notification_adapter(self) -> TypeAdapter[types.ClientNotification]
111111
def client_params(self) -> types.InitializeRequestParams | None:
112112
return self._client_params
113113

114+
@property
115+
def stateless(self) -> bool:
116+
"""Whether this session is in stateless mode (no persistent server-side state)."""
117+
return self._stateless
118+
114119
@property
115120
def experimental(self) -> ExperimentalServerSessionFeatures:
116121
"""Experimental APIs for server→client task operations.

src/mcp/shared/experimental/tasks/context.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class TaskContext:
2121
use ServerTaskContext from mcp.server.experimental.
2222
2323
Example (distributed worker):
24-
async def worker_job(task_id: str, session_id: str):
24+
async def worker_job(task_id: str, session_id: str | None):
2525
store = RedisTaskStore(redis_url)
2626
task = await store.get_task(task_id, session_id=session_id)
2727
ctx = TaskContext(task=task, store=store, session_id=session_id)
@@ -31,7 +31,7 @@ async def worker_job(task_id: str, session_id: str):
3131
await ctx.complete(result)
3232
"""
3333

34-
def __init__(self, task: Task, store: TaskStore, *, session_id: str):
34+
def __init__(self, task: Task, store: TaskStore, *, session_id: str | None):
3535
self._task = task
3636
self._store = store
3737
self._session_id = session_id

src/mcp/shared/experimental/tasks/helpers.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ async def cancel_task(
5151
store: TaskStore,
5252
task_id: str,
5353
*,
54-
session_id: str,
54+
session_id: str | None,
5555
) -> CancelTaskResult:
5656
"""Cancel a task with spec-compliant validation.
5757
@@ -64,7 +64,8 @@ async def cancel_task(
6464
Args:
6565
store: The task store
6666
task_id: The task identifier to cancel
67-
session_id: Session identifier for access control.
67+
session_id: Session identifier for access control. Must exactly match
68+
the session_id the task was created with (including None).
6869
6970
Returns:
7071
CancelTaskResult with the cancelled task state
@@ -128,7 +129,7 @@ async def task_execution(
128129
task_id: str,
129130
store: TaskStore,
130131
*,
131-
session_id: str,
132+
session_id: str | None,
132133
) -> AsyncIterator[TaskContext]:
133134
"""Context manager for safe task execution (pure, no server dependencies).
134135
@@ -141,7 +142,8 @@ async def task_execution(
141142
Args:
142143
task_id: The task identifier to execute
143144
store: The task store (must be accessible by the worker)
144-
session_id: Session identifier for access control.
145+
session_id: Session identifier for access control. Must exactly match
146+
the session_id the task was created with (including None).
145147
146148
Yields:
147149
TaskContext for updating status and completing/failing the task
@@ -150,7 +152,7 @@ async def task_execution(
150152
ValueError: If the task is not found in the store
151153
152154
Example (distributed worker):
153-
async def worker_process(task_id: str, session_id: str):
155+
async def worker_process(task_id: str, session_id: str | None):
154156
store = RedisTaskStore(redis_url)
155157
async with task_execution(task_id, store, session_id=session_id) as ctx:
156158
await ctx.update_status("Working...")

src/mcp/shared/experimental/tasks/in_memory_task_store.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class StoredTask:
2222
"""Internal storage representation of a task."""
2323

2424
task: Task
25-
session_id: str
25+
session_id: str | None
2626
result: Result | None = None
2727
# Time when this task should be removed (None = never)
2828
expires_at: datetime | None = field(default=None)
@@ -68,10 +68,13 @@ def _cleanup_expired(self) -> None:
6868
for task_id in expired_ids:
6969
del self._tasks[task_id]
7070

71-
def _get_stored_task(self, task_id: str, *, session_id: str) -> StoredTask | None:
71+
def _get_stored_task(self, task_id: str, *, session_id: str | None) -> StoredTask | None:
7272
"""Retrieve a stored task, enforcing session ownership.
7373
7474
Returns None if the task does not exist or belongs to a different session.
75+
Isolation uses strict equality: None only matches None, so tasks created
76+
by sessionless transports (stdio) are not visible to session-scoped
77+
transports (HTTP) and vice versa.
7578
"""
7679
stored = self._tasks.get(task_id)
7780
if stored is None:
@@ -85,7 +88,7 @@ async def create_task(
8588
metadata: TaskMetadata,
8689
task_id: str | None = None,
8790
*,
88-
session_id: str,
91+
session_id: str | None,
8992
) -> Task:
9093
"""Create a new task with the given metadata."""
9194
# Cleanup expired tasks on access
@@ -106,7 +109,7 @@ async def create_task(
106109
# Return a copy to prevent external modification
107110
return Task(**task.model_dump())
108111

109-
async def get_task(self, task_id: str, *, session_id: str) -> Task | None:
112+
async def get_task(self, task_id: str, *, session_id: str | None) -> Task | None:
110113
"""Get a task by ID."""
111114
# Cleanup expired tasks on access
112115
self._cleanup_expired()
@@ -124,7 +127,7 @@ async def update_task(
124127
status: TaskStatus | None = None,
125128
status_message: str | None = None,
126129
*,
127-
session_id: str,
130+
session_id: str | None,
128131
) -> Task:
129132
"""Update a task's status and/or message."""
130133
stored = self._get_stored_task(task_id, session_id=session_id)
@@ -156,15 +159,15 @@ async def update_task(
156159

157160
return Task(**stored.task.model_dump())
158161

159-
async def store_result(self, task_id: str, result: Result, *, session_id: str) -> None:
162+
async def store_result(self, task_id: str, result: Result, *, session_id: str | None) -> None:
160163
"""Store the result for a task."""
161164
stored = self._get_stored_task(task_id, session_id=session_id)
162165
if stored is None:
163166
raise ValueError(f"Task with ID {task_id} not found")
164167

165168
stored.result = result
166169

167-
async def get_result(self, task_id: str, *, session_id: str) -> Result | None:
170+
async def get_result(self, task_id: str, *, session_id: str | None) -> Result | None:
168171
"""Get the stored result for a task."""
169172
stored = self._get_stored_task(task_id, session_id=session_id)
170173
if stored is None:
@@ -176,13 +179,14 @@ async def list_tasks(
176179
self,
177180
cursor: str | None = None,
178181
*,
179-
session_id: str,
182+
session_id: str | None,
180183
) -> tuple[list[Task], str | None]:
181184
"""List tasks with pagination."""
182185
# Cleanup expired tasks on access
183186
self._cleanup_expired()
184187

185-
# Filter tasks by session ownership before pagination
188+
# Filter tasks by session ownership before pagination.
189+
# Strict equality: None only matches None (sessionless transports).
186190
filtered_task_ids = [task_id for task_id, stored in self._tasks.items() if stored.session_id == session_id]
187191

188192
start_index = 0
@@ -203,7 +207,7 @@ async def list_tasks(
203207

204208
return tasks, next_cursor
205209

206-
async def delete_task(self, task_id: str, *, session_id: str) -> bool:
210+
async def delete_task(self, task_id: str, *, session_id: str | None) -> bool:
207211
"""Delete a task."""
208212
stored = self._get_stored_task(task_id, session_id=session_id)
209213
if stored is None:

0 commit comments

Comments
 (0)