diff --git a/.changeset/fix-abort-listener-leak.md b/.changeset/fix-abort-listener-leak.md new file mode 100644 index 000000000..f1dd3163b --- /dev/null +++ b/.changeset/fix-abort-listener-leak.md @@ -0,0 +1,5 @@ +--- +'@modelcontextprotocol/core': patch +--- + +Consolidate per-request cleanup in `_requestWithSchema` into a single `.finally()` block. This fixes an abort signal listener leak (listeners accumulated when a caller reused one `AbortSignal` across requests) and two cases where `_responseHandlers` entries leaked on send-failure paths. diff --git a/packages/core/src/shared/protocol.ts b/packages/core/src/shared/protocol.ts index d6daf0172..ffa642998 100644 --- a/packages/core/src/shared/protocol.ts +++ b/packages/core/src/shared/protocol.ts @@ -800,6 +800,9 @@ export abstract class Protocol { ): Promise> { const { relatedRequestId, resumptionToken, onresumptiontoken } = options ?? {}; + let onAbort: (() => void) | undefined; + let cleanupMessageId: number | undefined; + // Send the request return new Promise>((resolve, reject) => { const earlyReject = (error: unknown) => { @@ -823,6 +826,7 @@ export abstract class Protocol { options?.signal?.throwIfAborted(); const messageId = this._requestMessageId++; + cleanupMessageId = messageId; const jsonrpcRequest: JSONRPCRequest = { ...request, jsonrpc: '2.0', @@ -841,9 +845,7 @@ export abstract class Protocol { } const cancel = (reason: unknown) => { - this._responseHandlers.delete(messageId); this._progressHandlers.delete(messageId); - this._cleanupTimeout(messageId); this._transport ?.send( @@ -885,9 +887,8 @@ export abstract class Protocol { } }); - options?.signal?.addEventListener('abort', () => { - cancel(options?.signal?.reason); - }); + onAbort = () => cancel(options?.signal?.reason); + options?.signal?.addEventListener('abort', onAbort, { once: true }); const timeout = options?.timeout ?? DEFAULT_REQUEST_TIMEOUT_MSEC; const timeoutHandler = () => cancel(new SdkError(SdkErrorCode.RequestTimeout, 'Request timed out', { timeout })); @@ -907,16 +908,14 @@ export abstract class Protocol { let outboundQueued = false; try { const taskResult = this._taskManager.processOutboundRequest(jsonrpcRequest, options, messageId, responseHandler, error => { - this._cleanupTimeout(messageId); + this._progressHandlers.delete(messageId); reject(error); }); if (taskResult.queued) { outboundQueued = true; } } catch (error) { - this._responseHandlers.delete(messageId); this._progressHandlers.delete(messageId); - this._cleanupTimeout(messageId); reject(error); return; } @@ -924,10 +923,23 @@ export abstract class Protocol { if (!outboundQueued) { // No related task or no module - send through transport normally this._transport.send(jsonrpcRequest, { relatedRequestId, resumptionToken, onresumptiontoken }).catch(error => { - this._cleanupTimeout(messageId); + this._progressHandlers.delete(messageId); reject(error); }); } + }).finally(() => { + // Per-request cleanup that must run on every exit path. Consolidated + // here so new exit paths added to the promise body can't forget it. + // _progressHandlers is NOT cleaned up here: _onresponse deletes it + // conditionally (preserveProgress for task flows), and error paths + // above delete it inline since no task exists in those cases. + if (onAbort) { + options?.signal?.removeEventListener('abort', onAbort); + } + if (cleanupMessageId !== undefined) { + this._responseHandlers.delete(cleanupMessageId); + this._cleanupTimeout(cleanupMessageId); + } }); } diff --git a/packages/core/test/shared/protocol.test.ts b/packages/core/test/shared/protocol.test.ts index 69735bc3a..619e09376 100644 --- a/packages/core/test/shared/protocol.test.ts +++ b/packages/core/test/shared/protocol.test.ts @@ -247,6 +247,64 @@ describe('protocol tests', () => { expect((abortReason as SdkError).code).toBe(SdkErrorCode.ConnectionClosed); }); + test('should remove abort listener from caller signal when request settles', async () => { + await protocol.connect(transport); + + const controller = new AbortController(); + const addSpy = vi.spyOn(controller.signal, 'addEventListener'); + const removeSpy = vi.spyOn(controller.signal, 'removeEventListener'); + + const mockSchema = z.object({ result: z.string() }); + const reqPromise = testRequest(protocol, { method: 'example', params: {} }, mockSchema, { + signal: controller.signal + }); + + expect(addSpy).toHaveBeenCalledTimes(1); + const listener = addSpy.mock.calls[0]![1]; + + transport.onmessage?.({ jsonrpc: '2.0', id: 0, result: { result: 'ok' } }); + await reqPromise; + + expect(removeSpy).toHaveBeenCalledWith('abort', listener); + }); + + test('should not accumulate abort listeners when reusing a signal across requests', async () => { + await protocol.connect(transport); + + const controller = new AbortController(); + const addSpy = vi.spyOn(controller.signal, 'addEventListener'); + const removeSpy = vi.spyOn(controller.signal, 'removeEventListener'); + + const mockSchema = z.object({ result: z.string() }); + for (let i = 0; i < 5; i++) { + const reqPromise = testRequest(protocol, { method: 'example', params: {} }, mockSchema, { + signal: controller.signal + }); + transport.onmessage?.({ jsonrpc: '2.0', id: i, result: { result: 'ok' } }); + await reqPromise; + } + + expect(addSpy).toHaveBeenCalledTimes(5); + expect(removeSpy).toHaveBeenCalledTimes(5); + }); + + test('should remove abort listener when request rejects', async () => { + await protocol.connect(transport); + + const controller = new AbortController(); + const removeSpy = vi.spyOn(controller.signal, 'removeEventListener'); + + const mockSchema = z.object({ result: z.string() }); + await expect( + testRequest(protocol, { method: 'example', params: {} }, mockSchema, { + signal: controller.signal, + timeout: 0 + }) + ).rejects.toThrow(); + + expect(removeSpy).toHaveBeenCalledWith('abort', expect.any(Function)); + }); + test('should not overwrite existing hooks when connecting transports', async () => { const oncloseMock = vi.fn(); const onerrorMock = vi.fn();