Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/abort-handlers-on-close.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@modelcontextprotocol/core': patch
---

Abort in-flight request handlers when the connection closes. Previously, request handlers would continue running after the transport disconnected, wasting resources and preventing proper cleanup. Also fixes `InMemoryTransport.close()` firing `onclose` twice on the initiating side.
32 changes: 26 additions & 6 deletions packages/core/src/shared/protocol.ts
Original file line number Diff line number Diff line change
Expand Up @@ -457,8 +457,11 @@ export abstract class Protocol<ContextT extends BaseContext> {
this._transport = transport;
const _onclose = this.transport?.onclose;
this._transport.onclose = () => {
_onclose?.();
this._onclose();
try {
_onclose?.();
} finally {
this._onclose();
}
};

const _onerror = this.transport?.onerror;
Expand Down Expand Up @@ -494,13 +497,28 @@ export abstract class Protocol<ContextT extends BaseContext> {
this._taskManager.onClose();
this._pendingDebouncedNotifications.clear();

for (const info of this._timeoutInfo.values()) {
clearTimeout(info.timeoutId);
}
this._timeoutInfo.clear();

const requestHandlerAbortControllers = this._requestHandlerAbortControllers;
this._requestHandlerAbortControllers = new Map();

const error = new SdkError(SdkErrorCode.ConnectionClosed, 'Connection closed');

this._transport = undefined;
this.onclose?.();

for (const handler of responseHandlers.values()) {
handler(error);
try {
this.onclose?.();
} finally {
for (const handler of responseHandlers.values()) {
handler(error);
}

for (const controller of requestHandlerAbortControllers.values()) {
controller.abort(error);
}
}
}

Expand Down Expand Up @@ -642,7 +660,9 @@ export abstract class Protocol<ContextT extends BaseContext> {
)
.catch(error => this._onerror(new Error(`Failed to send response: ${error}`)))
.finally(() => {
this._requestHandlerAbortControllers.delete(request.id);
if (this._requestHandlerAbortControllers.get(request.id) === abortController) {
this._requestHandlerAbortControllers.delete(request.id);
}
});
}

Expand Down
5 changes: 1 addition & 4 deletions packages/core/src/shared/taskManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -801,6 +801,7 @@ export class TaskManager {

onClose(): void {
this._taskProgressTokens.clear();
this._requestResolvers.clear();
}

// -- Private helpers --
Expand Down Expand Up @@ -893,8 +894,4 @@ export class NullTaskManager extends TaskManager {
): Promise<{ queued: boolean; jsonrpcNotification?: JSONRPCNotification }> {
return { queued: false, jsonrpcNotification: { ...notification, jsonrpc: '2.0' } };
}

override onClose(): void {
// No-op
}
}
11 changes: 9 additions & 2 deletions packages/core/src/util/inMemory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ interface QueuedMessage {
export class InMemoryTransport implements Transport {
private _otherTransport?: InMemoryTransport;
private _messageQueue: QueuedMessage[] = [];
private _closed = false;

onclose?: () => void;
onerror?: (error: Error) => void;
Expand All @@ -39,10 +40,16 @@ export class InMemoryTransport implements Transport {
}

async close(): Promise<void> {
if (this._closed) return;
this._closed = true;

const other = this._otherTransport;
this._otherTransport = undefined;
await other?.close();
this.onclose?.();
try {
await other?.close();
} finally {
this.onclose?.();
}
}

/**
Expand Down
47 changes: 47 additions & 0 deletions packages/core/test/inMemory.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,53 @@ describe('InMemoryTransport', () => {
await expect(clientTransport.send({ jsonrpc: '2.0', method: 'test', id: 1 })).rejects.toThrow('Not connected');
});

test('should fire onclose exactly once per transport', async () => {
let clientCloseCount = 0;
let serverCloseCount = 0;

clientTransport.onclose = () => clientCloseCount++;
serverTransport.onclose = () => serverCloseCount++;

await clientTransport.close();

expect(clientCloseCount).toBe(1);
expect(serverCloseCount).toBe(1);
});

test('should handle double close idempotently', async () => {
let clientCloseCount = 0;
clientTransport.onclose = () => clientCloseCount++;

await clientTransport.close();
await clientTransport.close();

expect(clientCloseCount).toBe(1);
});

test('should handle concurrent close from both sides', async () => {
let clientCloseCount = 0;
let serverCloseCount = 0;

clientTransport.onclose = () => clientCloseCount++;
serverTransport.onclose = () => serverCloseCount++;

await Promise.all([clientTransport.close(), serverTransport.close()]);

expect(clientCloseCount).toBe(1);
expect(serverCloseCount).toBe(1);
});

test('should fire onclose even if peer onclose throws', async () => {
let clientCloseCount = 0;
clientTransport.onclose = () => clientCloseCount++;
serverTransport.onclose = () => {
throw new Error('boom');
};

await expect(clientTransport.close()).rejects.toThrow('boom');
expect(clientCloseCount).toBe(1);
});

test('should queue messages sent before start', async () => {
const message: JSONRPCMessage = {
jsonrpc: '2.0',
Expand Down
30 changes: 30 additions & 0 deletions packages/core/test/shared/protocol.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,36 @@ describe('protocol tests', () => {
expect(oncloseMock).toHaveBeenCalled();
});

test('should abort in-flight request handlers when the connection is closed', async () => {
await protocol.connect(transport);

let abortReason: unknown;
let handlerStarted = false;
const handlerDone = new Promise<void>(resolve => {
protocol.setRequestHandler('ping', async (_request, ctx) => {
handlerStarted = true;
await new Promise<void>(resolveInner => {
ctx.mcpReq.signal.addEventListener('abort', () => {
abortReason = ctx.mcpReq.signal.reason;
resolveInner();
});
});
resolve();
return {};
});
});

transport.onmessage?.({ jsonrpc: '2.0', id: 1, method: 'ping', params: {} });

await vi.waitFor(() => expect(handlerStarted).toBe(true));

await transport.close();
await handlerDone;

expect(abortReason).toBeInstanceOf(SdkError);
expect((abortReason as SdkError).code).toBe(SdkErrorCode.ConnectionClosed);
});

test('should not overwrite existing hooks when connecting transports', async () => {
const oncloseMock = vi.fn();
const onerrorMock = vi.fn();
Expand Down
Loading