diff --git a/pkg/wshutil/wshrpc.go b/pkg/wshutil/wshrpc.go index a1ea1f71ae..37bb488882 100644 --- a/pkg/wshutil/wshrpc.go +++ b/pkg/wshutil/wshrpc.go @@ -186,7 +186,7 @@ type rpcData struct { Command string Route string ResCh chan *RpcMessage - Ctx context.Context + Handler *RpcRequestHandler } func validateServerImpl(serverImpl ServerImpl) { @@ -405,21 +405,21 @@ func (w *WshRpc) SetServerImpl(serverImpl ServerImpl) { w.ServerImpl = serverImpl } -func (w *WshRpc) registerRpc(ctx context.Context, command string, route string, reqId string) chan *RpcMessage { +func (w *WshRpc) registerRpc(handler *RpcRequestHandler, command string, route string, reqId string) chan *RpcMessage { w.Lock.Lock() defer w.Lock.Unlock() rpcCh := make(chan *RpcMessage, RespChSize) w.RpcMap[reqId] = &rpcData{ + Handler: handler, Command: command, Route: route, ResCh: rpcCh, - Ctx: ctx, } go func() { defer func() { panichandler.PanicHandler("registerRpc:timeout", recover()) }() - <-ctx.Done() + <-handler.ctx.Done() w.retrySendTimeout(reqId) }() return rpcCh @@ -447,6 +447,7 @@ func (w *WshRpc) unregisterRpc(reqId string, err error) { } delete(w.RpcMap, reqId) close(rd.ResCh) + rd.Handler.callContextCancelFn() } // no response @@ -541,16 +542,19 @@ func (handler *RpcRequestHandler) NextResponse() (any, error) { } func (handler *RpcRequestHandler) finalize() { - cancelFnPtr := handler.ctxCancelFn.Load() - if cancelFnPtr != nil && *cancelFnPtr != nil { - (*cancelFnPtr)() - handler.ctxCancelFn.Store(nil) - } + handler.callContextCancelFn() if handler.reqId != "" { handler.w.unregisterRpc(handler.reqId, nil) } } +func (handler *RpcRequestHandler) callContextCancelFn() { + cancelFnPtr := handler.ctxCancelFn.Swap(nil) + if cancelFnPtr != nil && *cancelFnPtr != nil { + (*cancelFnPtr)() + } +} + type RpcResponseHandler struct { w *WshRpc ctx context.Context @@ -710,7 +714,7 @@ func (w *WshRpc) SendComplexRequest(command string, data any, opts *wshrpc.RpcOp if err != nil { return nil, err } - handler.respCh = w.registerRpc(handler.ctx, command, opts.Route, handler.reqId) + handler.respCh = w.registerRpc(handler, command, opts.Route, handler.reqId) w.OutputCh <- barr return handler, nil }