Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions packages/core/src/shared/protocol.ts
Original file line number Diff line number Diff line change
Expand Up @@ -302,9 +302,9 @@ export abstract class Protocol<ContextT extends BaseContext> {
private _requestHandlers: Map<string, (request: JSONRPCRequest, ctx: ContextT) => Promise<Result>> = new Map();
private _requestHandlerAbortControllers: Map<RequestId, AbortController> = new Map();
private _notificationHandlers: Map<string, (notification: JSONRPCNotification) => Promise<void>> = new Map();
private _responseHandlers: Map<number, (response: JSONRPCResultResponse | Error) => void> = new Map();
private _progressHandlers: Map<number, ProgressCallback> = new Map();
private _timeoutInfo: Map<number, TimeoutInfo> = new Map();
private _responseHandlers: Map<RequestId, (response: JSONRPCResultResponse | Error) => void> = new Map();
private _progressHandlers: Map<RequestId, ProgressCallback> = new Map();
private _timeoutInfo: Map<RequestId, TimeoutInfo> = new Map();
private _pendingDebouncedNotifications = new Set<string>();

private _taskManager: TaskManager;
Expand Down Expand Up @@ -406,7 +406,7 @@ export abstract class Protocol<ContextT extends BaseContext> {
}

private _setupTimeout(
messageId: number,
messageId: RequestId,
timeout: number,
maxTotalTimeout: number | undefined,
onTimeout: () => void,
Expand All @@ -422,7 +422,7 @@ export abstract class Protocol<ContextT extends BaseContext> {
});
}

private _resetTimeout(messageId: number): boolean {
private _resetTimeout(messageId: RequestId): boolean {
const info = this._timeoutInfo.get(messageId);
if (!info) return false;

Expand All @@ -440,7 +440,7 @@ export abstract class Protocol<ContextT extends BaseContext> {
return true;
}

private _cleanupTimeout(messageId: number) {
private _cleanupTimeout(messageId: RequestId) {
const info = this._timeoutInfo.get(messageId);
if (info) {
clearTimeout(info.timeoutId);
Expand Down Expand Up @@ -648,7 +648,7 @@ export abstract class Protocol<ContextT extends BaseContext> {

private _onprogress(notification: ProgressNotification): void {
const { progressToken, ...params } = notification.params;
const messageId = Number(progressToken);
const messageId = progressToken as RequestId;

const handler = this._progressHandlers.get(messageId);
if (!handler) {
Expand Down Expand Up @@ -676,7 +676,12 @@ export abstract class Protocol<ContextT extends BaseContext> {
}

private _onresponse(response: JSONRPCResponse | JSONRPCErrorResponse): void {
const messageId = Number(response.id);
// Handle responses without an ID (shouldn't happen for valid protocol messages)
if (response.id === undefined) {
this._onerror(new Error(`Received a response without an ID: ${JSON.stringify(response)}`));
return;
}
const messageId = response.id;

// Delegate to TaskManager for task-related response handling
const taskResult = this._taskManager.processInboundResponse(response, messageId);
Expand Down
14 changes: 9 additions & 5 deletions packages/core/src/shared/taskManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ export interface TaskManagerHost {
request<T extends AnySchema>(request: Request, resultSchema: T, options?: RequestOptions): Promise<SchemaOutput<T>>;
notification(notification: Notification, options?: NotificationOptions): Promise<void>;
reportError(error: Error): void;
removeProgressHandler(token: number): void;
removeProgressHandler(token: RequestId): void;
registerHandler(method: string, handler: (request: JSONRPCRequest, ctx: BaseContext) => Promise<Result>): void;
sendOnResponseStream(message: JSONRPCNotification | JSONRPCRequest, relatedRequestId: RequestId): Promise<void>;
enforceStrictCapabilities: boolean;
Expand Down Expand Up @@ -195,7 +195,7 @@ export function extractTaskManagerOptions(tasksCapability: TaskManagerOptions |
export class TaskManager {
private _taskStore?: TaskStore;
private _taskMessageQueue?: TaskMessageQueue;
private _taskProgressTokens: Map<string, number> = new Map();
private _taskProgressTokens: Map<string, RequestId> = new Map();
private _requestResolvers: Map<RequestId, (response: JSONRPCResultResponse | Error) => void> = new Map();
private _options: TaskManagerOptions;
private _host?: TaskManagerHost;
Expand Down Expand Up @@ -584,7 +584,11 @@ export class TaskManager {
}

private handleResponse(response: JSONRPCResponse | JSONRPCErrorResponse): boolean {
const messageId = Number(response.id);
// Skip responses without an ID
if (response.id === undefined) {
return false;
}
const messageId = response.id;
const resolver = this._requestResolvers.get(messageId);
if (resolver) {
this._requestResolvers.delete(messageId);
Expand All @@ -598,7 +602,7 @@ export class TaskManager {
return false;
}

private shouldPreserveProgressHandler(response: JSONRPCResponse | JSONRPCErrorResponse, messageId: number): boolean {
private shouldPreserveProgressHandler(response: JSONRPCResponse | JSONRPCErrorResponse, messageId: RequestId): boolean {
if (isJSONRPCResultResponse(response) && response.result && typeof response.result === 'object') {
const result = response.result as Record<string, unknown>;
if (result.task && typeof result.task === 'object') {
Expand Down Expand Up @@ -764,7 +768,7 @@ export class TaskManager {

processInboundResponse(
response: JSONRPCResponse | JSONRPCErrorResponse,
messageId: number
messageId: RequestId
): { consumed: boolean; preserveProgress: boolean } {
const consumed = this.handleResponse(response);
if (consumed) {
Expand Down
Loading