diff --git a/.changeset/remove-tooltask-get-handlers.md b/.changeset/remove-tooltask-get-handlers.md new file mode 100644 index 000000000..132799837 --- /dev/null +++ b/.changeset/remove-tooltask-get-handlers.md @@ -0,0 +1,15 @@ +--- +"@modelcontextprotocol/core": minor +"@modelcontextprotocol/server": minor +--- + +Make `ToolTaskHandler.getTask`/`getTaskResult` optional and actually invoke them + +**Bug fix:** `getTask` and `getTaskResult` handlers registered via `registerToolTask` were never invoked — `tasks/get` and `tasks/result` requests always hit `TaskStore` directly. + +**Breaking changes (experimental API):** + +- `getTask` and `getTaskResult` are now **optional** on `ToolTaskHandler`. When omitted, `TaskStore` handles the requests (previous de-facto behavior). +- `TaskRequestHandler` signature changed: handlers receive only `(ctx: TaskServerContext)`, not the tool's input arguments. + +**Migration:** If your handlers just delegated to `ctx.task.store`, delete them. If you're proxying an external job system (Step Functions, CI/CD pipelines), keep them and drop the `args` parameter. diff --git a/docs/migration-SKILL.md b/docs/migration-SKILL.md index 0056795c3..66dfaaec8 100644 --- a/docs/migration-SKILL.md +++ b/docs/migration-SKILL.md @@ -96,6 +96,8 @@ Notes: | `ErrorCode.RequestTimeout` | `SdkErrorCode.RequestTimeout` | | `ErrorCode.ConnectionClosed` | `SdkErrorCode.ConnectionClosed` | | `StreamableHTTPError` | REMOVED (use `SdkError` with `SdkErrorCode.ClientHttp*`) | +| `ToolTaskHandler.getTask(args, ctx)` | `ToolTaskHandler.getTask?(ctx)` — now optional, no args | +| `ToolTaskHandler.getTaskResult(args, ctx)` | `ToolTaskHandler.getTaskResult?(ctx)` — now optional, no args | All other symbols from `@modelcontextprotocol/sdk/types.js` retain their original names (e.g., `CallToolResultSchema`, `ListToolsResultSchema`, etc.). diff --git a/docs/migration.md b/docs/migration.md index 21f8b67c9..f34693588 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -478,6 +478,37 @@ import { JSONRPCError, ResourceReference, isJSONRPCError } from '@modelcontextpr import { JSONRPCErrorResponse, ResourceTemplateReference, isJSONRPCErrorResponse } from '@modelcontextprotocol/server'; ``` +### `ToolTaskHandler.getTask` and `getTaskResult` are now optional (experimental) + +`getTask` and `getTaskResult` are now optional on `ToolTaskHandler`. When omitted, `tasks/get` and `tasks/result` are served directly from the configured `TaskStore`. Their signature has also changed — they no longer receive the tool's input arguments (which aren't available at `tasks/get`/`tasks/result` time). + +If your handlers just delegated to the store, delete them: + +**Before:** + +```typescript +server.experimental.tasks.registerToolTask('long-task', config, { + createTask: async (args, ctx) => { /* ... */ }, + getTask: async (args, ctx) => ctx.task.store.getTask(ctx.task.id), + getTaskResult: async (args, ctx) => ctx.task.store.getTaskResult(ctx.task.id) +}); +``` + +**After:** + +```typescript +server.experimental.tasks.registerToolTask('long-task', config, { + createTask: async (args, ctx) => { /* ... */ } +}); +``` + +Keep them if you're proxying an external job system (AWS Step Functions, CI/CD pipelines, etc.) — the new signature takes only `ctx`: + +```typescript +getTask: async (ctx) => describeStepFunctionExecution(ctx.task.id), +getTaskResult: async (ctx) => getStepFunctionOutput(ctx.task.id) +``` + ### Request handler context types The `RequestHandlerExtra` type has been replaced with a structured context type hierarchy using nested groups: diff --git a/examples/server/src/simpleStreamableHttp.ts b/examples/server/src/simpleStreamableHttp.ts index 1263f4bb5..09fbc905d 100644 --- a/examples/server/src/simpleStreamableHttp.ts +++ b/examples/server/src/simpleStreamableHttp.ts @@ -480,13 +480,6 @@ const getServer = () => { return { task }; - }, - async getTask(_args, ctx) { - return await ctx.task.store.getTask(ctx.task.id); - }, - async getTaskResult(_args, ctx) { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as CallToolResult; } } ); @@ -588,13 +581,6 @@ const getServer = () => { })(); return { task }; - }, - async getTask(_args, ctx) { - return await ctx.task.store.getTask(ctx.task.id); - }, - async getTaskResult(_args, ctx) { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as CallToolResult; } } ); diff --git a/packages/core/src/shared/taskManager.ts b/packages/core/src/shared/taskManager.ts index 28460d1d9..8c62eb9b4 100644 --- a/packages/core/src/shared/taskManager.ts +++ b/packages/core/src/shared/taskManager.ts @@ -154,6 +154,16 @@ export type TaskContext = { requestedTtl?: number | null; }; +/** + * Overrides for `tasks/get` and `tasks/result` lookups. Consulted before + * the configured {@linkcode TaskStore}; return `undefined` to fall through. + * @internal + */ +export type TaskLookupOverrides = { + getTask?: (taskId: string, ctx: BaseContext) => Promise; + getTaskResult?: (taskId: string, ctx: BaseContext) => Promise; +}; + export type TaskManagerOptions = { /** * Task storage implementation. Required for handling incoming task requests (server-side). @@ -199,6 +209,7 @@ export class TaskManager { private _requestResolvers: Map void> = new Map(); private _options: TaskManagerOptions; private _host?: TaskManagerHost; + private _overrides?: TaskLookupOverrides; constructor(options: TaskManagerOptions) { this._options = options; @@ -206,13 +217,22 @@ export class TaskManager { this._taskMessageQueue = options.taskMessageQueue; } + /** + * Installs per-task lookup overrides consulted before the {@linkcode TaskStore}. + * Used by McpServer to dispatch to per-tool `getTask`/`getTaskResult` handlers. + * @internal + */ + setTaskOverrides(overrides: TaskLookupOverrides): void { + this._overrides = overrides; + } + bind(host: TaskManagerHost): void { this._host = host; if (this._taskStore) { host.registerHandler('tasks/get', async (request, ctx) => { const params = request.params as { taskId: string }; - const task = await this.handleGetTask(params.taskId, ctx.sessionId); + const task = await this.handleGetTask(params.taskId, ctx); // Per spec: tasks/get responses SHALL NOT include related-task metadata // as the taskId parameter is the source of truth return { @@ -222,7 +242,7 @@ export class TaskManager { host.registerHandler('tasks/result', async (request, ctx) => { const params = request.params as { taskId: string }; - return await this.handleGetTaskPayload(params.taskId, ctx.sessionId, ctx.mcpReq.signal, async message => { + return await this.handleGetTaskPayload(params.taskId, ctx, async message => { // Send the message on the response stream by passing the relatedRequestId // This tells the transport to write the message to the tasks/result response stream await host.sendOnResponseStream(message, ctx.mcpReq.id); @@ -362,8 +382,11 @@ export class TaskManager { // -- Handler bodies (delegated from Protocol's registered handlers) -- - private async handleGetTask(taskId: string, sessionId?: string): Promise { - const task = await this._requireTaskStore.getTask(taskId, sessionId); + private async handleGetTask(taskId: string, ctx: BaseContext): Promise { + const override = await this._overrides?.getTask?.(taskId, ctx); + if (override !== undefined) return override; + + const task = await this._requireTaskStore.getTask(taskId, ctx.sessionId); if (!task) { throw new ProtocolError(ProtocolErrorCode.InvalidParams, 'Failed to retrieve task: Task not found'); } @@ -372,10 +395,12 @@ export class TaskManager { private async handleGetTaskPayload( taskId: string, - sessionId: string | undefined, - signal: AbortSignal, + ctx: BaseContext, sendOnResponseStream: (message: JSONRPCNotification | JSONRPCRequest) => Promise ): Promise { + const sessionId = ctx.sessionId; + const signal = ctx.mcpReq.signal; + const handleTaskResult = async (): Promise => { if (this._taskMessageQueue) { let queuedMessage: QueuedMessage | undefined; @@ -404,17 +429,15 @@ export class TaskManager { } } - const task = await this._requireTaskStore.getTask(taskId, sessionId); - if (!task) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Task not found: ${taskId}`); - } + const task = await this.handleGetTask(taskId, ctx); if (!isTerminal(task.status)) { await this._waitForTaskUpdate(task.pollInterval, signal); return await handleTaskResult(); } - const result = await this._requireTaskStore.getTaskResult(taskId, sessionId); + const override = await this._overrides?.getTaskResult?.(taskId, ctx); + const result = override ?? (await this._requireTaskStore.getTaskResult(taskId, sessionId)); await this._clearTaskQueue(taskId); return { diff --git a/packages/server/src/experimental/tasks/interfaces.ts b/packages/server/src/experimental/tasks/interfaces.ts index 2aef91a8c..31b33f1b1 100644 --- a/packages/server/src/experimental/tasks/interfaces.ts +++ b/packages/server/src/experimental/tasks/interfaces.ts @@ -13,7 +13,7 @@ import type { TaskServerContext } from '@modelcontextprotocol/core'; -import type { BaseToolCallback } from '../../server/mcp.js'; +import type { AnyToolHandler, BaseToolCallback } from '../../server/mcp.js'; // ============================================================================ // Task Handler Types (for registerToolTask) @@ -30,19 +30,27 @@ export type CreateTaskRequestHandler< /** * Handler for task operations (`get`, `getResult`). + * + * Receives only the context (no tool arguments — they are not available at + * `tasks/get` or `tasks/result` time). Access the task ID via `ctx.task.id`. + * * @experimental */ -export type TaskRequestHandler = BaseToolCallback< - SendResultT, - TaskServerContext, - Args ->; +export type TaskRequestHandler = (ctx: TaskServerContext) => SendResultT | Promise; /** * Interface for task-based tool handlers. * - * Task-based tools split a long-running operation into three phases: - * `createTask`, `getTask`, and `getTaskResult`. + * Task-based tools create a task on `tools/call` and by default let the SDK's + * `TaskStore` handle subsequent `tasks/get` and `tasks/result` requests. + * + * Provide `getTask` and `getTaskResult` to override the default lookups — useful + * when proxying an external job system (e.g., AWS Step Functions, CI/CD pipelines) + * where the external system is the source of truth for task state. + * + * **Note:** the taskId → tool mapping used to dispatch `getTask`/`getTaskResult` + * is held in-memory and does not survive server restarts or span multiple + * instances. In those scenarios, requests fall through to the `TaskStore`. * * @see {@linkcode @modelcontextprotocol/server!experimental/tasks/mcpServer.ExperimentalMcpServerTasks#registerToolTask | registerToolTask} for registration. * @experimental @@ -56,11 +64,23 @@ export interface ToolTaskHandler; /** - * Handler for `tasks/get` requests. + * Optional handler for `tasks/get` requests. When omitted, the configured + * `TaskStore` is consulted directly. */ - getTask: TaskRequestHandler; + getTask?: TaskRequestHandler; /** - * Handler for `tasks/result` requests. + * Optional handler for `tasks/result` requests. When omitted, the configured + * `TaskStore` is consulted directly. */ - getTaskResult: TaskRequestHandler; + getTaskResult?: TaskRequestHandler; +} + +/** + * Type guard for {@linkcode ToolTaskHandler}. + * @experimental + */ +export function isToolTaskHandler( + handler: AnyToolHandler +): handler is ToolTaskHandler { + return 'createTask' in handler; } diff --git a/packages/server/src/experimental/tasks/mcpServer.ts b/packages/server/src/experimental/tasks/mcpServer.ts index b7c28c40d..b54a82376 100644 --- a/packages/server/src/experimental/tasks/mcpServer.ts +++ b/packages/server/src/experimental/tasks/mcpServer.ts @@ -5,13 +5,25 @@ * @experimental */ -import type { StandardSchemaWithJSON, TaskToolExecution, ToolAnnotations, ToolExecution } from '@modelcontextprotocol/core'; +import type { + BaseContext, + CallToolResult, + GetTaskResult, + ServerContext, + StandardSchemaWithJSON, + TaskManager, + TaskServerContext, + TaskToolExecution, + ToolAnnotations, + ToolExecution +} from '@modelcontextprotocol/core'; import type { AnyToolHandler, McpServer, RegisteredTool } from '../../server/mcp.js'; import type { ToolTaskHandler } from './interfaces.js'; +import { isToolTaskHandler } from './interfaces.js'; /** - * Internal interface for accessing {@linkcode McpServer}'s private _createRegisteredTool method. + * Internal interface for accessing {@linkcode McpServer}'s private members. * @internal */ interface McpServerInternal { @@ -26,6 +38,7 @@ interface McpServerInternal { _meta: Record | undefined, handler: AnyToolHandler ): RegisteredTool; + _registeredTools: { [name: string]: RegisteredTool }; } /** @@ -39,14 +52,63 @@ interface McpServerInternal { * @experimental */ export class ExperimentalMcpServerTasks { + /** + * Maps taskId → toolName for tasks whose handlers define custom + * `getTask` or `getTaskResult`. In-memory only; after a server restart + * or on a different instance, lookups fall through to the TaskStore. + */ + private _taskToTool = new Map(); + constructor(private readonly _mcpServer: McpServer) {} + /** @internal */ + _installOverrides(taskManager: TaskManager): void { + taskManager.setTaskOverrides({ + getTask: (taskId, ctx) => this._dispatch(taskId, ctx, 'getTask'), + getTaskResult: (taskId, ctx) => this._dispatch(taskId, ctx, 'getTaskResult') + }); + } + + /** @internal */ + _recordTask(taskId: string, toolName: string): void { + const tool = (this._mcpServer as unknown as McpServerInternal)._registeredTools[toolName]; + if (tool && isToolTaskHandler(tool.handler) && (tool.handler.getTask || tool.handler.getTaskResult)) { + this._taskToTool.set(taskId, toolName); + } + } + + private async _dispatch( + taskId: string, + ctx: BaseContext, + method: M + ): Promise<(M extends 'getTask' ? GetTaskResult : CallToolResult) | undefined> { + const toolName = this._taskToTool.get(taskId); + if (!toolName) return undefined; + + const tool = (this._mcpServer as unknown as McpServerInternal)._registeredTools[toolName]; + if (!tool || !isToolTaskHandler(tool.handler)) return undefined; + + const handler = tool.handler[method]; + if (!handler) return undefined; + + const serverCtx = ctx as ServerContext; + if (!serverCtx.task?.store) return undefined; + + const taskCtx: TaskServerContext = { + ...serverCtx, + task: { ...serverCtx.task, id: taskId, store: serverCtx.task.store } + }; + + return handler(taskCtx) as M extends 'getTask' ? GetTaskResult : CallToolResult; + } + /** * Registers a task-based tool with a config object and handler. * * Task-based tools support long-running operations that can be polled for status - * and results. The handler must implement {@linkcode ToolTaskHandler.createTask | createTask}, {@linkcode ToolTaskHandler.getTask | getTask}, and {@linkcode ToolTaskHandler.getTaskResult | getTaskResult} - * methods. + * and results. The handler implements {@linkcode ToolTaskHandler.createTask | createTask} + * to start the task; subsequent `tasks/get` and `tasks/result` requests are served + * from the configured `TaskStore`. * * @example * ```typescript @@ -59,19 +121,13 @@ export class ExperimentalMcpServerTasks { * const task = await ctx.task.store.createTask({ ttl: 300000 }); * startBackgroundWork(task.taskId, args); * return { task }; - * }, - * getTask: async (args, ctx) => { - * return ctx.task.store.getTask(ctx.task.id); - * }, - * getTaskResult: async (args, ctx) => { - * return ctx.task.store.getTaskResult(ctx.task.id); * } * }); * ``` * * @param name - The tool name * @param config - Tool configuration (description, schemas, etc.) - * @param handler - Task handler with {@linkcode ToolTaskHandler.createTask | createTask}, {@linkcode ToolTaskHandler.getTask | getTask}, {@linkcode ToolTaskHandler.getTaskResult | getTaskResult} methods + * @param handler - Task handler with {@linkcode ToolTaskHandler.createTask | createTask} * @returns {@linkcode server/mcp.RegisteredTool | RegisteredTool} for managing the tool's lifecycle * * @experimental diff --git a/packages/server/src/index.ts b/packages/server/src/index.ts index c680dffe7..452e16b00 100644 --- a/packages/server/src/index.ts +++ b/packages/server/src/index.ts @@ -40,6 +40,7 @@ export { WebStandardStreamableHTTPServerTransport } from './server/streamableHtt // experimental exports export type { CreateTaskRequestHandler, TaskRequestHandler, ToolTaskHandler } from './experimental/tasks/interfaces.js'; +export { isToolTaskHandler } from './experimental/tasks/interfaces.js'; export { ExperimentalMcpServerTasks } from './experimental/tasks/mcpServer.js'; export { ExperimentalServerTasks } from './experimental/tasks/server.js'; diff --git a/packages/server/src/server/mcp.ts b/packages/server/src/server/mcp.ts index 4d9f81c50..ec8d6a9f1 100644 --- a/packages/server/src/server/mcp.ts +++ b/packages/server/src/server/mcp.ts @@ -40,6 +40,7 @@ import { } from '@modelcontextprotocol/core'; import type { ToolTaskHandler } from '../experimental/tasks/interfaces.js'; +import { isToolTaskHandler } from '../experimental/tasks/interfaces.js'; import { ExperimentalMcpServerTasks } from '../experimental/tasks/mcpServer.js'; import { getCompleter, isCompletable } from './completable.js'; import type { ServerOptions } from './server.js'; @@ -70,10 +71,12 @@ export class McpServer { } = {}; private _registeredTools: { [name: string]: RegisteredTool } = {}; private _registeredPrompts: { [name: string]: RegisteredPrompt } = {}; - private _experimental?: { tasks: ExperimentalMcpServerTasks }; + private _experimental: { tasks: ExperimentalMcpServerTasks }; constructor(serverInfo: Implementation, options?: ServerOptions) { this.server = new Server(serverInfo, options); + this._experimental = { tasks: new ExperimentalMcpServerTasks(this) }; + this._experimental.tasks._installOverrides(this.server.taskManager); } /** @@ -84,11 +87,6 @@ export class McpServer { * @experimental */ get experimental(): { tasks: ExperimentalMcpServerTasks } { - if (!this._experimental) { - this._experimental = { - tasks: new ExperimentalMcpServerTasks(this) - }; - } return this._experimental; } @@ -170,7 +168,7 @@ export class McpServer { try { const isTaskRequest = !!request.params.task; const taskSupport = tool.execution?.taskSupport; - const isTaskHandler = 'createTask' in (tool.handler as AnyToolHandler); + const isTaskHandler = isToolTaskHandler(tool.handler); // Validate task hint configuration if ((taskSupport === 'required' || taskSupport === 'optional') && !isTaskHandler) { @@ -199,6 +197,10 @@ export class McpServer { // Return CreateTaskResult immediately for task requests if (isTaskRequest) { + const createTaskResult = result as CreateTaskResult; + if (createTaskResult.task) { + this._experimental.tasks._recordTask(createTaskResult.task.taskId, request.params.name); + } return result; } @@ -318,22 +320,24 @@ export class McpServer { const args = await this.validateToolInput(tool, request.params.arguments, request.params.name); const createTaskResult = (await tool.executor(args, ctx)) as CreateTaskResult; - // Poll until completion const taskId = createTaskResult.task.taskId; + this._experimental.tasks._recordTask(taskId, request.params.name); + + const handler = isToolTaskHandler(tool.handler) ? tool.handler : undefined; + const taskCtx = { ...ctx, task: { ...ctx.task, id: taskId, store: ctx.task.store } }; + + // Poll until completion — use custom handlers when provided, else TaskStore let task = createTaskResult.task; const pollInterval = task.pollInterval ?? 5000; while (task.status !== 'completed' && task.status !== 'failed' && task.status !== 'cancelled') { await new Promise(resolve => setTimeout(resolve, pollInterval)); - const updatedTask = await ctx.task.store.getTask(taskId); - if (!updatedTask) { - throw new ProtocolError(ProtocolErrorCode.InternalError, `Task ${taskId} not found during polling`); - } - task = updatedTask; + task = handler?.getTask ? await handler.getTask(taskCtx) : await ctx.task.store.getTask(taskId); } - // Return the final result - return (await ctx.task.store.getTaskResult(taskId)) as CallToolResult; + return handler?.getTaskResult + ? await handler.getTaskResult(taskCtx) + : ((await ctx.task.store.getTaskResult(taskId)) as CallToolResult); } private _completionHandlerInitialized = false; @@ -1120,9 +1124,7 @@ function createToolExecutor( inputSchema: StandardSchemaWithJSON | undefined, handler: AnyToolHandler ): ToolExecutor { - const isTaskHandler = 'createTask' in handler; - - if (isTaskHandler) { + if (isToolTaskHandler(handler)) { const taskHandler = handler as TaskHandlerInternal; return async (args, ctx) => { if (!ctx.task?.store) { diff --git a/test/integration/test/client/client.test.ts b/test/integration/test/client/client.test.ts index 52d151bdd..d6ed272bb 100644 --- a/test/integration/test/client/client.test.ts +++ b/test/integration/test/client/client.test.ts @@ -2331,17 +2331,6 @@ describe('Task-based execution', () => { await ctx.task.store.storeTaskResult(task.taskId, 'completed', result); return { task }; - }, - async getTask(_args, ctx) { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error(`Task ${ctx.task.id} not found`); - } - return task; - }, - async getTaskResult(_args, ctx) { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as { content: Array<{ type: 'text'; text: string }> }; } } ); @@ -2414,17 +2403,6 @@ describe('Task-based execution', () => { await ctx.task.store.storeTaskResult(task.taskId, 'completed', result); return { task }; - }, - async getTask(_args, ctx) { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error(`Task ${ctx.task.id} not found`); - } - return task; - }, - async getTaskResult(_args, ctx) { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as { content: Array<{ type: 'text'; text: string }> }; } } ); @@ -2498,17 +2476,6 @@ describe('Task-based execution', () => { await ctx.task.store.storeTaskResult(task.taskId, 'completed', result); return { task }; - }, - async getTask(_args, ctx) { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error(`Task ${ctx.task.id} not found`); - } - return task; - }, - async getTaskResult(_args, ctx) { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as { content: Array<{ type: 'text'; text: string }> }; } } ); @@ -2586,17 +2553,6 @@ describe('Task-based execution', () => { await ctx.task.store.storeTaskResult(task.taskId, 'completed', result); return { task }; - }, - async getTask(_args, ctx) { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error(`Task ${ctx.task.id} not found`); - } - return task; - }, - async getTaskResult(_args, ctx) { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as { content: Array<{ type: 'text'; text: string }> }; } } ); @@ -3078,17 +3034,6 @@ describe('Task-based execution', () => { await ctx.task.store.storeTaskResult(task.taskId, 'completed', result); return { task }; - }, - async getTask(_args, ctx) { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error(`Task ${ctx.task.id} not found`); - } - return task; - }, - async getTaskResult(_args, ctx) { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as { content: Array<{ type: 'text'; text: string }> }; } } ); @@ -3353,17 +3298,6 @@ test('should respect server task capabilities', async () => { await ctx.task.store.storeTaskResult(task.taskId, 'completed', result); return { task }; - }, - async getTask(_args, ctx) { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error(`Task ${ctx.task.id} not found`); - } - return task; - }, - async getTaskResult(_args, ctx) { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as { content: Array<{ type: 'text'; text: string }> }; } } ); diff --git a/test/integration/test/server.test.ts b/test/integration/test/server.test.ts index 825af7ea4..465409610 100644 --- a/test/integration/test/server.test.ts +++ b/test/integration/test/server.test.ts @@ -2406,17 +2406,6 @@ describe('Task-based execution', () => { })(); return { task }; - }, - async getTask(_args, ctx) { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error(`Task ${ctx.task.id} not found`); - } - return task; - }, - async getTaskResult(_args, ctx) { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as { content: Array<{ type: 'text'; text: string }> }; } } ); @@ -2652,17 +2641,6 @@ describe('Task-based execution', () => { })(); return { task }; - }, - async getTask(_args, ctx) { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error(`Task ${ctx.task.id} not found`); - } - return task; - }, - async getTaskResult(_args, ctx) { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as { content: Array<{ type: 'text'; text: string }> }; } } ); @@ -3147,17 +3125,6 @@ describe('Task-based execution', () => { })(); return { task }; - }, - async getTask(_args, ctx) { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error(`Task ${ctx.task.id} not found`); - } - return task; - }, - async getTaskResult(_args, ctx) { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as { content: Array<{ type: 'text'; text: string }> }; } } ); diff --git a/test/integration/test/server/mcp.test.ts b/test/integration/test/server/mcp.test.ts index 967435834..232057d64 100644 --- a/test/integration/test/server/mcp.test.ts +++ b/test/integration/test/server/mcp.test.ts @@ -1,5 +1,5 @@ import { Client } from '@modelcontextprotocol/client'; -import type { CallToolResult, Notification, TextContent } from '@modelcontextprotocol/core'; +import type { CallToolResult, CreateTaskResult, Notification, TextContent } from '@modelcontextprotocol/core'; import { getDisplayName, InMemoryTaskStore, @@ -2062,14 +2062,6 @@ describe('Zod v4', () => { createTask: async (_args, ctx) => { const task = await ctx.task.store.createTask({ ttl: 60_000 }); return { task }; - }, - getTask: async (_args, ctx) => { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) throw new Error('Task not found'); - return task; - }, - getTaskResult: async (_args, ctx) => { - return (await ctx.task.store.getTaskResult(ctx.task.id)) as CallToolResult; } } ); @@ -2132,14 +2124,6 @@ describe('Zod v4', () => { createTask: async (_args, ctx) => { const task = await ctx.task.store.createTask({ ttl: 60_000 }); return { task }; - }, - getTask: async (_args, ctx) => { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) throw new Error('Task not found'); - return task; - }, - getTaskResult: async (_args, ctx) => { - return (await ctx.task.store.getTaskResult(ctx.task.id)) as CallToolResult; } } ); @@ -6422,17 +6406,6 @@ describe('Zod v4', () => { }, 200); return { task }; - }, - getTask: async (_args, ctx) => { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error('Task not found'); - } - return task; - }, - getTaskResult: async (_input, ctx) => { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as CallToolResult; } } ); @@ -6526,17 +6499,6 @@ describe('Zod v4', () => { }, 150); return { task }; - }, - getTask: async (_args, ctx) => { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error('Task not found'); - } - return task; - }, - getTaskResult: async (_value, ctx) => { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as CallToolResult; } } ); @@ -6632,17 +6594,6 @@ describe('Zod v4', () => { }, 200); return { task }; - }, - getTask: async (_args, ctx) => { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error('Task not found'); - } - return task; - }, - getTaskResult: async (_data, ctx) => { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as CallToolResult; } } ); @@ -6751,17 +6702,6 @@ describe('Zod v4', () => { }, 150); return { task }; - }, - getTask: async ctx => { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error('Task not found'); - } - return task; - }, - getTaskResult: async ctx => { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as CallToolResult; } } ); @@ -6852,17 +6792,6 @@ describe('Zod v4', () => { }, 150); return { task }; - }, - getTask: async ctx => { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error('Task not found'); - } - return task; - }, - getTaskResult: async ctx => { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as CallToolResult; } } ); @@ -6886,6 +6815,123 @@ describe('Zod v4', () => { taskStore.cleanup(); }); + test('should invoke optional getTask/getTaskResult handlers when provided', async () => { + const taskStore = new InMemoryTaskStore(); + const { releaseLatch, waitForLatch } = createLatch(); + + const mcpServer = new McpServer( + { name: 'test server', version: '1.0' }, + { capabilities: { tools: {}, tasks: { requests: { tools: { call: {} } }, taskStore } } } + ); + const client = new Client( + { name: 'test client', version: '1.0' }, + { capabilities: { tasks: { requests: { tools: { call: {} } } } } } + ); + + const getTask = vi.fn(async ctx => ctx.task.store.getTask(ctx.task.id)); + const getTaskResult = vi.fn(async ctx => (await ctx.task.store.getTaskResult(ctx.task.id)) as CallToolResult); + + mcpServer.experimental.tasks.registerToolTask( + 'proxy-task', + { inputSchema: z.object({ n: z.number() }), execution: { taskSupport: 'required' } }, + { + createTask: async ({ n }, ctx) => { + const task = await ctx.task.store.createTask({ ttl: 60_000, pollInterval: 50 }); + const store = ctx.task.store; + setTimeout(async () => { + await store.storeTaskResult(task.taskId, 'completed', { + content: [{ type: 'text' as const, text: `n=${n}` }] + }); + releaseLatch(); + }, 100); + return { task }; + }, + getTask, + getTaskResult + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + const created = (await client.callTool( + { name: 'proxy-task', arguments: { n: 7 } }, + { task: { ttl: 60_000 } } + )) as CreateTaskResult; + const taskId = created.task.taskId; + + await waitForLatch(); + + // tasks/get should route to the user's getTask handler + const task = await client.experimental.tasks.getTask(taskId); + expect(task.status).toBe('completed'); + expect(getTask).toHaveBeenCalled(); + + // tasks/result should route to the user's getTaskResult handler + const result = await client.experimental.tasks.getTaskResult(taskId); + expect((result as CallToolResult).content).toEqual([{ type: 'text', text: 'n=7' }]); + expect(getTaskResult).toHaveBeenCalledTimes(1); + + taskStore.cleanup(); + }); + + test('should fall through to TaskStore when getTask/getTaskResult are omitted', async () => { + const taskStore = new InMemoryTaskStore(); + const { releaseLatch, waitForLatch } = createLatch(); + + const mcpServer = new McpServer( + { name: 'test server', version: '1.0' }, + { capabilities: { tools: {}, tasks: { requests: { tools: { call: {} } }, taskStore } } } + ); + const client = new Client( + { name: 'test client', version: '1.0' }, + { capabilities: { tasks: { requests: { tools: { call: {} } } } } } + ); + + const storeGetTask = vi.spyOn(taskStore, 'getTask'); + const storeGetTaskResult = vi.spyOn(taskStore, 'getTaskResult'); + + mcpServer.experimental.tasks.registerToolTask( + 'store-only-task', + { inputSchema: z.object({}), execution: { taskSupport: 'required' } }, + { + createTask: async (_args, ctx) => { + const task = await ctx.task.store.createTask({ ttl: 60_000, pollInterval: 50 }); + const store = ctx.task.store; + setTimeout(async () => { + await store.storeTaskResult(task.taskId, 'completed', { + content: [{ type: 'text' as const, text: 'done' }] + }); + releaseLatch(); + }, 100); + return { task }; + } + // no getTask, no getTaskResult + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + const created = (await client.callTool( + { name: 'store-only-task', arguments: {} }, + { task: { ttl: 60_000 } } + )) as CreateTaskResult; + const taskId = created.task.taskId; + + await waitForLatch(); + storeGetTask.mockClear(); + storeGetTaskResult.mockClear(); + + await client.experimental.tasks.getTask(taskId); + expect(storeGetTask).toHaveBeenCalled(); + + await client.experimental.tasks.getTaskResult(taskId); + expect(storeGetTaskResult).toHaveBeenCalled(); + + taskStore.cleanup(); + }); + test('should raise error when registerToolTask is called with taskSupport "forbidden"', () => { const taskStore = new InMemoryTaskStore(); @@ -6927,17 +6973,6 @@ describe('Zod v4', () => { createTask: async (_args, ctx) => { const task = await ctx.task.store.createTask({ ttl: 60_000, pollInterval: 100 }); return { task }; - }, - getTask: async (_args, ctx) => { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error('Task not found'); - } - return task; - }, - getTaskResult: async (_args, ctx) => { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as CallToolResult; } } ); diff --git a/test/integration/test/taskLifecycle.test.ts b/test/integration/test/taskLifecycle.test.ts index 6281e833d..bffb267a3 100644 --- a/test/integration/test/taskLifecycle.test.ts +++ b/test/integration/test/taskLifecycle.test.ts @@ -84,17 +84,6 @@ describe('Task Lifecycle Integration Tests', () => { })(); return { task }; - }, - async getTask(_args, ctx) { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error(`Task ${ctx.task.id} not found`); - } - return task; - }, - async getTaskResult(_args, ctx) { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as { content: Array<{ type: 'text'; text: string }> }; } } ); @@ -165,17 +154,6 @@ describe('Task Lifecycle Integration Tests', () => { })(); return { task }; - }, - async getTask(_args, ctx) { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error(`Task ${ctx.task.id} not found`); - } - return task; - }, - async getTaskResult(_args, ctx) { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as { content: Array<{ type: 'text'; text: string }> }; } } ); @@ -459,17 +437,6 @@ describe('Task Lifecycle Integration Tests', () => { })(); return { task }; - }, - async getTask(_args, ctx) { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error(`Task ${ctx.task.id} not found`); - } - return task; - }, - async getTaskResult(_args, ctx) { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as { content: Array<{ type: 'text'; text: string }> }; } } ); @@ -921,17 +888,6 @@ describe('Task Lifecycle Integration Tests', () => { }); return { task }; - }, - async getTask(_args, ctx) { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error(`Task ${ctx.task.id} not found`); - } - return task; - }, - async getTaskResult(_args, ctx) { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as { content: Array<{ type: 'text'; text: string }> }; } } ); @@ -1118,17 +1074,6 @@ describe('Task Lifecycle Integration Tests', () => { })(); return { task }; - }, - async getTask(_args, ctx) { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error(`Task ${ctx.task.id} not found`); - } - return task; - }, - async getTaskResult(_args, ctx) { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as { content: Array<{ type: 'text'; text: string }> }; } } ); @@ -1314,17 +1259,6 @@ describe('Task Lifecycle Integration Tests', () => { })(); return { task }; - }, - async getTask(_args, ctx) { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error(`Task ${ctx.task.id} not found`); - } - return task; - }, - async getTaskResult(_args, ctx) { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as { content: Array<{ type: 'text'; text: string }> }; } } );