diff --git a/.changeset/reconnection-scheduler.md b/.changeset/reconnection-scheduler.md new file mode 100644 index 000000000..add9bd6fc --- /dev/null +++ b/.changeset/reconnection-scheduler.md @@ -0,0 +1,5 @@ +--- +'@modelcontextprotocol/client': minor +--- + +Add `reconnectionScheduler` option to `StreamableHTTPClientTransport`. Lets non-persistent environments (serverless, mobile, desktop sleep/wake) override the default `setTimeout`-based SSE reconnection scheduling. The scheduler may return a cancel function that is invoked on `transport.close()`. diff --git a/packages/client/src/client/streamableHttp.examples.ts b/packages/client/src/client/streamableHttp.examples.ts new file mode 100644 index 000000000..74023fa51 --- /dev/null +++ b/packages/client/src/client/streamableHttp.examples.ts @@ -0,0 +1,31 @@ +/** + * Type-checked examples for `streamableHttp.ts`. + * + * These examples are synced into JSDoc comments via the sync-snippets script. + * Each function's region markers define the code snippet that appears in the docs. + * + * @module + */ + +/* eslint-disable unicorn/consistent-function-scoping -- examples must live inside region blocks */ + +import type { ReconnectionScheduler } from './streamableHttp.js'; + +// Stub for a hypothetical platform-specific background scheduling API +declare const platformBackgroundTask: { + schedule(callback: () => void, delay: number): number; + cancel(id: number): void; +}; + +/** + * Example: Using a platform background-task API to schedule reconnections. + */ +function ReconnectionScheduler_basicUsage() { + //#region ReconnectionScheduler_basicUsage + const scheduler: ReconnectionScheduler = (reconnect, delay) => { + const id = platformBackgroundTask.schedule(reconnect, delay); + return () => platformBackgroundTask.cancel(id); + }; + //#endregion ReconnectionScheduler_basicUsage + return scheduler; +} diff --git a/packages/client/src/client/streamableHttp.ts b/packages/client/src/client/streamableHttp.ts index 3d45b60e9..475f7ce86 100644 --- a/packages/client/src/client/streamableHttp.ts +++ b/packages/client/src/client/streamableHttp.ts @@ -78,6 +78,31 @@ export interface StreamableHTTPReconnectionOptions { maxRetries: number; } +/** + * Custom scheduler for SSE stream reconnection attempts. + * + * Called instead of `setTimeout` when the transport needs to schedule a reconnection. + * Useful in environments where `setTimeout` is unsuitable (serverless functions that + * terminate before the timer fires, mobile apps that need platform background scheduling, + * desktop apps handling sleep/wake). + * + * @param reconnect - Call this to perform the reconnection attempt. + * @param delay - Suggested delay in milliseconds (from backoff calculation). + * @param attemptCount - Zero-indexed retry attempt number. + * @returns An optional cancel function. If returned, it will be called on + * {@linkcode StreamableHTTPClientTransport.close | transport.close()} to abort the + * pending reconnection. + * + * @example + * ```ts source="./streamableHttp.examples.ts#ReconnectionScheduler_basicUsage" + * const scheduler: ReconnectionScheduler = (reconnect, delay) => { + * const id = platformBackgroundTask.schedule(reconnect, delay); + * return () => platformBackgroundTask.cancel(id); + * }; + * ``` + */ +export type ReconnectionScheduler = (reconnect: () => void, delay: number, attemptCount: number) => (() => void) | void; + /** * Configuration options for the {@linkcode StreamableHTTPClientTransport}. */ @@ -116,6 +141,12 @@ export type StreamableHTTPClientTransportOptions = { */ reconnectionOptions?: StreamableHTTPReconnectionOptions; + /** + * Custom scheduler for reconnection attempts. If not provided, `setTimeout` is used. + * See {@linkcode ReconnectionScheduler}. + */ + reconnectionScheduler?: ReconnectionScheduler; + /** * Session ID for the connection. This is used to identify the session on the server. * When not provided and connecting to a server that supports session IDs, the server will generate a new session ID. @@ -150,7 +181,8 @@ export class StreamableHTTPClientTransport implements Transport { private _protocolVersion?: string; private _lastUpscopingHeader?: string; // Track last upscoping header to prevent infinite upscoping. private _serverRetryMs?: number; // Server-provided retry delay from SSE retry field - private _reconnectionTimeout?: ReturnType; + private readonly _reconnectionScheduler?: ReconnectionScheduler; + private _cancelReconnection?: () => void; onclose?: () => void; onerror?: (error: Error) => void; @@ -172,6 +204,7 @@ export class StreamableHTTPClientTransport implements Transport { this._sessionId = opts?.sessionId; this._protocolVersion = opts?.protocolVersion; this._reconnectionOptions = opts?.reconnectionOptions ?? DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS; + this._reconnectionScheduler = opts?.reconnectionScheduler; } private async _commonHeaders(): Promise { @@ -305,15 +338,28 @@ export class StreamableHTTPClientTransport implements Transport { // Calculate next delay based on current attempt count const delay = this._getNextReconnectionDelay(attemptCount); - // Schedule the reconnection - this._reconnectionTimeout = setTimeout(() => { - // Use the last event ID to resume where we left off + const reconnect = (): void => { + this._cancelReconnection = undefined; + if (this._abortController?.signal.aborted) return; this._startOrAuthSse(options).catch(error => { this.onerror?.(new Error(`Failed to reconnect SSE stream: ${error instanceof Error ? error.message : String(error)}`)); - // Schedule another attempt if this one failed, incrementing the attempt counter - this._scheduleReconnection(options, attemptCount + 1); + try { + this._scheduleReconnection(options, attemptCount + 1); + } catch (scheduleError) { + this.onerror?.(scheduleError instanceof Error ? scheduleError : new Error(String(scheduleError))); + } }); - }, delay); + }; + + if (this._reconnectionScheduler) { + const cancel = this._reconnectionScheduler(reconnect, delay, attemptCount); + if (typeof cancel === 'function') { + this._cancelReconnection = cancel; + } + } else { + const handle = setTimeout(reconnect, delay); + this._cancelReconnection = () => clearTimeout(handle); + } } private _handleSseStream(stream: ReadableStream | null, options: StartSSEOptions, isReconnectable: boolean): void { @@ -458,12 +504,13 @@ export class StreamableHTTPClientTransport implements Transport { } async close(): Promise { - if (this._reconnectionTimeout) { - clearTimeout(this._reconnectionTimeout); - this._reconnectionTimeout = undefined; + try { + this._cancelReconnection?.(); + } finally { + this._cancelReconnection = undefined; + this._abortController?.abort(); + this.onclose?.(); } - this._abortController?.abort(); - this.onclose?.(); } async send( diff --git a/packages/client/src/index.ts b/packages/client/src/index.ts index 3abbd628c..d1af95103 100644 --- a/packages/client/src/index.ts +++ b/packages/client/src/index.ts @@ -62,7 +62,12 @@ export type { SSEClientTransportOptions } from './client/sse.js'; export { SSEClientTransport, SseError } from './client/sse.js'; export type { StdioServerParameters } from './client/stdio.js'; export { DEFAULT_INHERITED_ENV_VARS, getDefaultEnvironment, StdioClientTransport } from './client/stdio.js'; -export type { StartSSEOptions, StreamableHTTPClientTransportOptions, StreamableHTTPReconnectionOptions } from './client/streamableHttp.js'; +export type { + ReconnectionScheduler, + StartSSEOptions, + StreamableHTTPClientTransportOptions, + StreamableHTTPReconnectionOptions +} from './client/streamableHttp.js'; export { StreamableHTTPClientTransport } from './client/streamableHttp.js'; export { WebSocketClientTransport } from './client/websocket.js'; diff --git a/packages/client/test/client/streamableHttp.test.ts b/packages/client/test/client/streamableHttp.test.ts index 55bf79a50..ad376c2e0 100644 --- a/packages/client/test/client/streamableHttp.test.ts +++ b/packages/client/test/client/streamableHttp.test.ts @@ -4,7 +4,7 @@ import type { Mock, Mocked } from 'vitest'; import type { OAuthClientProvider } from '../../src/client/auth.js'; import { UnauthorizedError } from '../../src/client/auth.js'; -import type { StartSSEOptions, StreamableHTTPReconnectionOptions } from '../../src/client/streamableHttp.js'; +import type { ReconnectionScheduler, StartSSEOptions, StreamableHTTPReconnectionOptions } from '../../src/client/streamableHttp.js'; import { StreamableHTTPClientTransport } from '../../src/client/streamableHttp.js'; describe('StreamableHTTPClientTransport', () => { @@ -1617,8 +1617,8 @@ describe('StreamableHTTPClientTransport', () => { }) ); - // Verify no timeout was scheduled (no reconnection attempt) - expect(transport['_reconnectionTimeout']).toBeUndefined(); + // Verify no reconnection was scheduled + expect(transport['_cancelReconnection']).toBeUndefined(); }); it('should schedule reconnection when maxRetries is greater than 0', async () => { @@ -1640,10 +1640,10 @@ describe('StreamableHTTPClientTransport', () => { // ASSERT - should schedule a reconnection, not report error yet expect(errorSpy).not.toHaveBeenCalled(); - expect(transport['_reconnectionTimeout']).toBeDefined(); + expect(transport['_cancelReconnection']).toBeDefined(); - // Clean up the timeout to avoid test pollution - clearTimeout(transport['_reconnectionTimeout']); + // Clean up the pending reconnection to avoid test pollution + transport['_cancelReconnection']?.(); }); }); @@ -1716,4 +1716,140 @@ describe('StreamableHTTPClientTransport', () => { }); }); }); + + describe('reconnectionScheduler', () => { + const reconnectionOptions: StreamableHTTPReconnectionOptions = { + initialReconnectionDelay: 1000, + maxReconnectionDelay: 5000, + reconnectionDelayGrowFactor: 2, + maxRetries: 3 + }; + + function triggerReconnection(t: StreamableHTTPClientTransport): void { + (t as unknown as { _scheduleReconnection(opts: StartSSEOptions, attempt?: number): void })._scheduleReconnection({}, 0); + } + + beforeEach(() => { + vi.useFakeTimers(); + }); + + afterEach(() => { + vi.useRealTimers(); + }); + + it('invokes the custom scheduler with reconnect, delay, and attemptCount', () => { + const scheduler = vi.fn(); + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + reconnectionOptions, + reconnectionScheduler: scheduler + }); + + triggerReconnection(transport); + + expect(scheduler).toHaveBeenCalledTimes(1); + expect(scheduler).toHaveBeenCalledWith(expect.any(Function), 1000, 0); + }); + + it('falls back to setTimeout when no scheduler is provided', () => { + const setTimeoutSpy = vi.spyOn(global, 'setTimeout'); + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + reconnectionOptions + }); + + triggerReconnection(transport); + + expect(setTimeoutSpy).toHaveBeenCalledWith(expect.any(Function), 1000); + }); + + it('does not use setTimeout when a custom scheduler is provided', () => { + const setTimeoutSpy = vi.spyOn(global, 'setTimeout'); + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + reconnectionOptions, + reconnectionScheduler: vi.fn() + }); + + triggerReconnection(transport); + + expect(setTimeoutSpy).not.toHaveBeenCalled(); + }); + + it('calls the returned cancel function on close()', async () => { + const cancel = vi.fn(); + const scheduler: ReconnectionScheduler = vi.fn(() => cancel); + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + reconnectionOptions, + reconnectionScheduler: scheduler + }); + + triggerReconnection(transport); + expect(cancel).not.toHaveBeenCalled(); + + await transport.close(); + expect(cancel).toHaveBeenCalledTimes(1); + }); + + it('tolerates schedulers that return void (no cancel function)', async () => { + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + reconnectionOptions, + reconnectionScheduler: () => { + /* no return */ + } + }); + + triggerReconnection(transport); + await expect(transport.close()).resolves.toBeUndefined(); + }); + + it('clears the default setTimeout on close() when no scheduler is provided', async () => { + const clearTimeoutSpy = vi.spyOn(global, 'clearTimeout'); + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + reconnectionOptions + }); + + triggerReconnection(transport); + await transport.close(); + + expect(clearTimeoutSpy).toHaveBeenCalledTimes(1); + }); + + it('ignores a late-firing reconnect after close()', async () => { + let capturedReconnect: (() => void) | undefined; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + reconnectionOptions, + reconnectionScheduler: reconnect => { + capturedReconnect = reconnect; + } + }); + const onerror = vi.fn(); + transport.onerror = onerror; + + await transport.start(); + triggerReconnection(transport); + await transport.close(); + + capturedReconnect?.(); + await vi.runAllTimersAsync(); + + expect(onerror).not.toHaveBeenCalled(); + }); + + it('still aborts and fires onclose if the cancel function throws', async () => { + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + reconnectionOptions, + reconnectionScheduler: () => () => { + throw new Error('cancel failed'); + } + }); + const onclose = vi.fn(); + transport.onclose = onclose; + + await transport.start(); + triggerReconnection(transport); + const abortController = transport['_abortController']; + + await expect(transport.close()).rejects.toThrow('cancel failed'); + expect(abortController?.signal.aborted).toBe(true); + expect(onclose).toHaveBeenCalledTimes(1); + }); + }); });