Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/examples/server/simpleStreamableHttp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -379,11 +379,11 @@ const getServer = () => {

return { task };
},
async getTask(_args, { taskId, taskStore: getTaskStore }) {
return await getTaskStore.getTask(taskId);
async getTask({ taskId, taskStore }) {
return await taskStore.getTask(taskId);
},
async getTaskResult(_args, { taskId, taskStore: getResultTaskStore }) {
const result = await getResultTaskStore.getTaskResult(taskId);
async getTaskResult({ taskId, taskStore }) {
const result = await taskStore.getTaskResult(taskId);
return result as CallToolResult;
}
}
Expand Down Expand Up @@ -605,10 +605,10 @@ const getServer = () => {
task
};
},
async getTask(_args, { taskId, taskStore }) {
async getTask({ taskId, taskStore }) {
return await taskStore.getTask(taskId);
},
async getTaskResult(_args, { taskId, taskStore }) {
async getTaskResult({ taskId, taskStore }) {
Comment on lines +608 to +611
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Bug: collect-user-info-task at lines 382-388 still uses the old two-argument signature (_args, { taskId, taskStore }) for getTask and getTaskResult, while the PR updated TaskRequestHandler to accept a single extra argument. This will cause a TypeError: Cannot destructure property 'taskId' of 'undefined' at runtime when these handlers are invoked, since the entire extra object is passed as the first argument and undefined as the second. The delay tool at lines 608-614 was correctly updated but this tool was missed.

Extended reasoning...

What the bug is

The PR changes TaskRequestHandler from a multi-argument type (BaseToolCallback<SendResultT, TaskRequestHandlerExtra, Args>) to a single-argument function type ((extra: TaskRequestHandlerExtra) => SendResultT | Promise<SendResultT>). This means getTask and getTaskResult handlers now receive a single object containing { taskId, taskStore, ...extra } instead of the old (args, extra) two-argument pattern.

The specific code path

In src/server/mcp.ts, the McpServer constructor sets up taskHandlerHooks that call handler.getTask({ ...extra, taskId, taskStore }) and handler.getTaskResult({ ...extra, taskId, taskStore }) with a single argument. When collect-user-info-task's handler at line 382 declares async getTask(_args, { taskId, taskStore: getTaskStore }), JavaScript passes the entire { ...extra, taskId, taskStore } object as _args and undefined as the second parameter.

Why existing code doesn't prevent it

The delay tool was correctly updated in this PR (lines 608-614 in the diff show the fix from (_args, { taskId, taskStore }) to ({ taskId, taskStore })), but the collect-user-info-task tool at lines 382-388 was overlooked. TypeScript should catch this as a compilation error since a 2-parameter function is not assignable to the new 1-parameter TaskRequestHandler type, but example files may not be compiled as part of CI.

Step-by-step proof

  1. Client sends tools/call with name: "collect-user-info-task" and task: { ttl: 60000 }
  2. Server creates the task via createTask, returns CreateTaskResult
  3. Client later calls tasks/get with the taskId
  4. Protocol dispatches to taskHandlerHooks.getTask(taskId, extra)
  5. McpServer's hook calls handler.getTask({ ...extra, taskId, taskStore }) — one argument
  6. The old handler signature async getTask(_args, { taskId, taskStore: getTaskStore }) receives: _args = { signal, sessionId, taskId, taskStore, ... } and the second parameter = undefined
  7. Destructuring { taskId, taskStore: getTaskStore } from undefined throws TypeError: Cannot destructure property 'taskId' of 'undefined'
  8. The server crashes with an unhandled exception

Impact

This is a runtime crash in the example server. Any client that creates a task using the collect-user-info-task tool and then polls its status will crash the server. The delay tool works correctly since it was updated.

Fix

Update lines 382-388 to use the new single-argument signature, matching the delay tool:

async getTask({ taskId, taskStore: getTaskStore }) {
    return await getTaskStore.getTask(taskId);
},
async getTaskResult({ taskId, taskStore: getResultTaskStore }) {
    const result = await getResultTaskStore.getTaskResult(taskId);
    return result as CallToolResult;
}

const result = await taskStore.getTaskResult(taskId);
return result as CallToolResult;
}
Expand Down
9 changes: 3 additions & 6 deletions src/experimental/tasks/interfaces.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,19 +70,16 @@ export type CreateTaskRequestHandler<
* Handler for task operations (get, getResult).
* @experimental
*/
export type TaskRequestHandler<
SendResultT extends Result,
Args extends undefined | ZodRawShapeCompat | AnySchema = undefined
> = BaseToolCallback<SendResultT, TaskRequestHandlerExtra, Args>;
export type TaskRequestHandler<SendResultT extends Result> = (extra: TaskRequestHandlerExtra) => SendResultT | Promise<SendResultT>;

/**
* Interface for task-based tool handlers.
* @experimental
*/
export interface ToolTaskHandler<Args extends undefined | ZodRawShapeCompat | AnySchema = undefined> {
createTask: CreateTaskRequestHandler<CreateTaskResult, Args>;
getTask: TaskRequestHandler<GetTaskResult, Args>;
getTaskResult: TaskRequestHandler<CallToolResult, Args>;
getTask: TaskRequestHandler<GetTaskResult>;
getTaskResult: TaskRequestHandler<CallToolResult>;
}

/**
Expand Down
74 changes: 67 additions & 7 deletions src/server/mcp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,43 @@ export class McpServer {
private _registeredTools: { [name: string]: RegisteredTool } = {};
private _registeredPrompts: { [name: string]: RegisteredPrompt } = {};
private _experimental?: { tasks: ExperimentalMcpServerTasks };
private _taskToolMap: Map<string, string> = new Map();
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 _taskToolMap entries (taskId → toolName) are added at line 248 but never removed — there is no .delete() or .clear() call anywhere. For long-running servers processing many tasks, this map grows unboundedly even after tasks reach terminal states or expire via TTL. Consider adding cleanup when a task completes/fails/is cancelled, or lazily when _getTaskHandler finds a task no longer exists in the store.

Extended reasoning...

What the bug is

The _taskToolMap field (declared at line 85 as Map<string, string>) stores a mapping from taskId to the tool name that created it. Entries are added at line 248 via this._taskToolMap.set(taskResult.task.taskId, request.params.name) whenever a task-augmented tool call returns a CreateTaskResult. However, there is no corresponding .delete() call anywhere in the codebase — entries persist for the lifetime of the McpServer instance.

How it manifests

Every task creation adds a string → string entry to the map. When a task completes, fails, is cancelled, or expires via its TTL and gets cleaned up from the TaskStore, the corresponding _taskToolMap entry remains. Over time, for a server that processes many tasks, this map grows monotonically.

Step-by-step proof

  1. A client calls tools/call with task: { ttl: 60000 } for a registered tool task.
  2. The CallToolRequestSchema handler at line 243-249 executes: this._taskToolMap.set(taskResult.task.taskId, request.params.name).
  3. The task completes — TaskStore.storeTaskResult() is called, the task enters a terminal state.
  4. The task's TTL expires and InMemoryTaskStore cleans it up internally.
  5. The _taskToolMap still holds the taskId → toolName entry. There is no code path that removes it.
  6. Repeat steps 1-5 thousands of times — the map now holds thousands of stale entries.

Why existing code doesn't prevent it

Searching for all references to _taskToolMap reveals exactly three: the declaration (line 85), a .get() call (line 111), and the .set() call (line 248). No .delete(), .clear(), or any other cleanup mechanism exists.

Impact

Each entry is two short strings (taskId + toolName), so individual entries are small. For typical short-lived MCP server instances or low task throughput, this is unlikely to cause issues. However, for long-running servers processing a high volume of tasks (e.g., a persistent production server), memory usage will grow linearly and unboundedly over time.

Suggested fix

The simplest approach would be to add a lazy cleanup in _getTaskHandler: if the taskId is found in _taskToolMap but the task no longer exists in the TaskStore, delete the entry. Alternatively, cleanup could be added when a task reaches a terminal state (in the taskHandlerHooks or after storeTaskResult calls).


constructor(serverInfo: Implementation, options?: ServerOptions) {
this.server = new Server(serverInfo, options);
const taskHandlerHooks = {
getTask: async (taskId: string, extra: RequestHandlerExtra<ServerRequest, ServerNotification>) => {
// taskStore is guaranteed to exist here because Protocol only calls hooks when taskStore is configured
const taskStore = extra.taskStore!;
const handler = this._getTaskHandler(taskId);
if (handler) {
return await handler.getTask({ ...extra, taskId, taskStore });
}
return await taskStore.getTask(taskId);
},
getTaskResult: async (taskId: string, extra: RequestHandlerExtra<ServerRequest, ServerNotification>) => {
const taskStore = extra.taskStore!;
const handler = this._getTaskHandler(taskId);
try {
if (handler) {
return await handler.getTaskResult({ ...extra, taskId, taskStore });
}
return await taskStore.getTaskResult(taskId);
} finally {
// Once the result has been retrieved the task is complete;
// drop the taskId → toolName mapping to avoid unbounded growth.
this._taskToolMap.delete(taskId);
}
}
};
this.server = new Server(serverInfo, { ...options, taskHandlerHooks });
}

private _getTaskHandler(taskId: string): ToolTaskHandler<ZodRawShapeCompat | undefined> | null {
const toolName = this._taskToolMap.get(taskId);
if (!toolName) return null;
const tool = this._registeredTools[toolName];
if (!tool || !('createTask' in (tool.handler as AnyToolHandler<ZodRawShapeCompat>))) return null;
return tool.handler as ToolTaskHandler<ZodRawShapeCompat | undefined>;
}

/**
Expand Down Expand Up @@ -215,6 +249,10 @@ export class McpServer {

// Return CreateTaskResult immediately for task requests
if (isTaskRequest) {
const taskResult = result as CreateTaskResult;
if (taskResult.task?.taskId) {
this._taskToolMap.set(taskResult.task.taskId, request.params.name);
}
return result;
}

Expand Down Expand Up @@ -374,27 +412,28 @@ export class McpServer {
const handler = tool.handler as ToolTaskHandler<ZodRawShapeCompat | undefined>;
const taskExtra = { ...extra, taskStore: extra.taskStore };

const createTaskResult: CreateTaskResult = args // undefined only if tool.inputSchema is undefined
? await Promise.resolve((handler as ToolTaskHandler<ZodRawShapeCompat>).createTask(args, taskExtra))
: // eslint-disable-next-line @typescript-eslint/no-explicit-any
await Promise.resolve(((handler as ToolTaskHandler<undefined>).createTask as any)(taskExtra));
const wrappedHandler = toolTaskHandlerByArgs(handler, args);

const createTaskResult = await wrappedHandler.createTask(taskExtra);

// Poll until completion
const taskId = createTaskResult.task.taskId;
const taskExtraComplete = { ...extra, taskId, taskStore: extra.taskStore };
let task = createTaskResult.task;
const pollInterval = task.pollInterval ?? 5000;

while (task.status !== 'completed' && task.status !== 'failed' && task.status !== 'cancelled') {
await new Promise(resolve => setTimeout(resolve, pollInterval));
const updatedTask = await extra.taskStore.getTask(taskId);
const getTaskResult = await wrappedHandler.getTask(taskExtraComplete);
const updatedTask = getTaskResult;
if (!updatedTask) {
throw new McpError(ErrorCode.InternalError, `Task ${taskId} not found during polling`);
}
task = updatedTask;
}

// Return the final result
return (await extra.taskStore.getTaskResult(taskId)) as CallToolResult;
return await wrappedHandler.getTaskResult(taskExtraComplete);
}

private _completionHandlerInitialized = false;
Expand Down Expand Up @@ -1545,3 +1584,24 @@ const EMPTY_COMPLETION_RESULT: CompleteResult = {
hasMore: false
}
};

/**
* Wraps a tool task handler's createTask to handle args uniformly.
* getTask and getTaskResult don't take args, so they're passed through directly.
* @param handler The task handler to wrap.
* @param args The tool arguments.
* @returns A wrapped task handler for a tool, which only exposes a no-args interface for createTask.
*/
function toolTaskHandlerByArgs<Args extends AnySchema | ZodRawShapeCompat | undefined>(
handler: ToolTaskHandler<Args>,
args: unknown
): ToolTaskHandler<undefined> {
return {
createTask: extra =>
args // undefined only if tool.inputSchema is undefined
? Promise.resolve((handler as ToolTaskHandler<ZodRawShapeCompat>).createTask(args, extra))
: Promise.resolve((handler as ToolTaskHandler<undefined>).createTask(extra)),
getTask: handler.getTask,
getTaskResult: handler.getTaskResult
};
}
32 changes: 31 additions & 1 deletion src/shared/protocol.ts
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,20 @@ export type ProtocolOptions = {
* appropriately (e.g., by failing the task, dropping messages, etc.).
*/
maxTaskQueueSize?: number;
/**
* Optional hooks for customizing task request handling.
* If a hook is provided, it fully owns the behavior (no fallback to TaskStore).
*/
taskHandlerHooks?: {
/**
* Called when tasks/get is received. If provided, must return the task.
*/
getTask?: (taskId: string, extra: RequestHandlerExtra<Request, Notification>) => Promise<GetTaskResult>;
/**
* Called when tasks/payload needs to retrieve the final result. If provided, must return the result.
*/
getTaskResult?: (taskId: string, extra: RequestHandlerExtra<Request, Notification>) => Promise<Result>;
};
};

/**
Expand Down Expand Up @@ -383,6 +397,16 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
this._taskMessageQueue = _options?.taskMessageQueue;
if (this._taskStore) {
this.setRequestHandler(GetTaskRequestSchema, async (request, extra) => {
// Use hook if provided, otherwise fall back to TaskStore
if (_options?.taskHandlerHooks?.getTask) {
const hookResult = await _options.taskHandlerHooks.getTask(
request.params.taskId,
extra as unknown as RequestHandlerExtra<Request, Notification>
);
// @ts-expect-error SendResultT cannot contain GetTaskResult
return hookResult as SendResultT;
}

const task = await this._taskStore!.getTask(request.params.taskId, extra.sessionId);
if (!task) {
throw new McpError(ErrorCode.InvalidParams, 'Failed to retrieve task: Task not found');
Expand Down Expand Up @@ -462,7 +486,13 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e

// If task is terminal, return the result
if (isTerminal(task.status)) {
const result = await this._taskStore!.getTaskResult(taskId, extra.sessionId);
// Use hook if provided, otherwise fall back to TaskStore
const result = this._options?.taskHandlerHooks?.getTaskResult
? await this._options.taskHandlerHooks.getTaskResult(
taskId,
extra as unknown as RequestHandlerExtra<Request, Notification>
)
: await this._taskStore!.getTaskResult(taskId, extra.sessionId);

this._clearTaskQueue(taskId);

Expand Down
24 changes: 12 additions & 12 deletions test/client/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2386,14 +2386,14 @@ describe('Task-based execution', () => {

return { task };
},
async getTask(_args, extra) {
async getTask(extra) {
const task = await extra.taskStore.getTask(extra.taskId);
if (!task) {
throw new Error(`Task ${extra.taskId} not found`);
}
return task;
},
async getTaskResult(_args, extra) {
async getTaskResult(extra) {
const result = await extra.taskStore.getTaskResult(extra.taskId);
return result as { content: Array<{ type: 'text'; text: string }> };
}
Expand Down Expand Up @@ -2462,14 +2462,14 @@ describe('Task-based execution', () => {

return { task };
},
async getTask(_args, extra) {
async getTask(extra) {
const task = await extra.taskStore.getTask(extra.taskId);
if (!task) {
throw new Error(`Task ${extra.taskId} not found`);
}
return task;
},
async getTaskResult(_args, extra) {
async getTaskResult(extra) {
const result = await extra.taskStore.getTaskResult(extra.taskId);
return result as { content: Array<{ type: 'text'; text: string }> };
}
Expand Down Expand Up @@ -2539,14 +2539,14 @@ describe('Task-based execution', () => {

return { task };
},
async getTask(_args, extra) {
async getTask(extra) {
const task = await extra.taskStore.getTask(extra.taskId);
if (!task) {
throw new Error(`Task ${extra.taskId} not found`);
}
return task;
},
async getTaskResult(_args, extra) {
async getTaskResult(extra) {
const result = await extra.taskStore.getTaskResult(extra.taskId);
return result as { content: Array<{ type: 'text'; text: string }> };
}
Expand Down Expand Up @@ -2620,14 +2620,14 @@ describe('Task-based execution', () => {

return { task };
},
async getTask(_args, extra) {
async getTask(extra) {
const task = await extra.taskStore.getTask(extra.taskId);
if (!task) {
throw new Error(`Task ${extra.taskId} not found`);
}
return task;
},
async getTaskResult(_args, extra) {
async getTaskResult(extra) {
const result = await extra.taskStore.getTaskResult(extra.taskId);
return result as { content: Array<{ type: 'text'; text: string }> };
}
Expand Down Expand Up @@ -3105,14 +3105,14 @@ describe('Task-based execution', () => {

return { task };
},
async getTask(_args, extra) {
async getTask(extra) {
const task = await extra.taskStore.getTask(extra.taskId);
if (!task) {
throw new Error(`Task ${extra.taskId} not found`);
}
return task;
},
async getTaskResult(_args, extra) {
async getTaskResult(extra) {
const result = await extra.taskStore.getTaskResult(extra.taskId);
return result as { content: Array<{ type: 'text'; text: string }> };
}
Expand Down Expand Up @@ -3373,14 +3373,14 @@ test('should respect server task capabilities', async () => {

return { task };
},
async getTask(_args, extra) {
async getTask(extra) {
const task = await extra.taskStore.getTask(extra.taskId);
if (!task) {
throw new Error(`Task ${extra.taskId} not found`);
}
return task;
},
async getTaskResult(_args, extra) {
async getTaskResult(extra) {
const result = await extra.taskStore.getTaskResult(extra.taskId);
return result as { content: Array<{ type: 'text'; text: string }> };
}
Expand Down
Loading
Loading