diff --git a/cmd/server/main-server.go b/cmd/server/main-server.go index 4cfe5c76b4..dfb9695134 100644 --- a/cmd/server/main-server.go +++ b/cmd/server/main-server.go @@ -196,7 +196,7 @@ func createMainWshClient() { wshfs.RpcClient = rpc wshutil.DefaultRouter.RegisterRoute(wshutil.DefaultRoute, rpc, true) wps.Broker.SetClient(wshutil.DefaultRouter) - localConnWsh := wshutil.MakeWshRpc(nil, nil, wshrpc.RpcContext{Conn: wshrpc.LocalConnName}, &wshremote.ServerImpl{}) + localConnWsh := wshutil.MakeWshRpc(nil, nil, wshrpc.RpcContext{Conn: wshrpc.LocalConnName}, &wshremote.ServerImpl{}, "conn:local") go wshremote.RunSysInfoLoop(localConnWsh, wshrpc.LocalConnName) wshutil.DefaultRouter.RegisterRoute(wshutil.MakeConnectionRouteId(wshrpc.LocalConnName), localConnWsh, true) } diff --git a/cmd/wsh/cmd/wshcmd-connserver.go b/cmd/wsh/cmd/wshcmd-connserver.go index 995eb0bb5a..3a3e20d643 100644 --- a/cmd/wsh/cmd/wshcmd-connserver.go +++ b/cmd/wsh/cmd/wshcmd-connserver.go @@ -137,7 +137,7 @@ func setupConnServerRpcClientWithRouter(router *wshutil.WshRouter, jwtToken stri } inputCh := make(chan []byte, wshutil.DefaultInputChSize) outputCh := make(chan []byte, wshutil.DefaultOutputChSize) - connServerClient := wshutil.MakeWshRpc(inputCh, outputCh, *rpcCtx, &wshremote.ServerImpl{LogWriter: os.Stdout}) + connServerClient := wshutil.MakeWshRpc(inputCh, outputCh, *rpcCtx, &wshremote.ServerImpl{LogWriter: os.Stdout}, authRtn.RouteId) connServerClient.SetAuthToken(authRtn.AuthToken) router.RegisterRoute(authRtn.RouteId, connServerClient, false) wshclient.RouteAnnounceCommand(connServerClient, nil) diff --git a/cmd/wsh/cmd/wshcmd-root.go b/cmd/wsh/cmd/wshcmd-root.go index 3502915c4e..18ef5c8cff 100644 --- a/cmd/wsh/cmd/wshcmd-root.go +++ b/cmd/wsh/cmd/wshcmd-root.go @@ -86,7 +86,7 @@ func preRunSetupRpcClient(cmd *cobra.Command, args []string) error { if jwtToken == "" { wshutil.SetTermRawModeAndInstallShutdownHandlers(true) UsingTermWshMode = true - RpcClient, WrappedStdin = wshutil.SetupTerminalRpcClient(nil) + RpcClient, WrappedStdin = wshutil.SetupTerminalRpcClient(nil, "wshcmd-termclient") return nil } err := setupRpcClient(nil, jwtToken) @@ -148,7 +148,7 @@ func setupRpcClientWithToken(swapTokenStr string) (wshrpc.CommandAuthenticateRtn return rtn, fmt.Errorf("no rpccontext in token") } RpcContext = *token.RpcContext - RpcClient, err = wshutil.SetupDomainSocketRpcClient(token.SockName, nil) + RpcClient, err = wshutil.SetupDomainSocketRpcClient(token.SockName, nil, "wshcmd") if err != nil { return rtn, fmt.Errorf("error setting up domain socket rpc client: %w", err) } @@ -166,7 +166,7 @@ func setupRpcClient(serverImpl wshutil.ServerImpl, jwtToken string) error { if err != nil { return fmt.Errorf("error extracting socket name from %s: %v", wshutil.WaveJwtTokenVarName, err) } - RpcClient, err = wshutil.SetupDomainSocketRpcClient(sockName, serverImpl) + RpcClient, err = wshutil.SetupDomainSocketRpcClient(sockName, serverImpl, "wshcmd") if err != nil { return fmt.Errorf("error setting up domain socket rpc client: %v", err) } diff --git a/cmd/wsh/cmd/wshcmd-test.go b/cmd/wsh/cmd/wshcmd-test.go new file mode 100644 index 0000000000..20ec59e868 --- /dev/null +++ b/cmd/wsh/cmd/wshcmd-test.go @@ -0,0 +1,24 @@ +// Copyright 2025, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package cmd + +import ( + "github.com/spf13/cobra" +) + +var testCmd = &cobra.Command{ + Use: "test", + Hidden: true, + Short: "test command", + PreRunE: preRunSetupRpcClient, + RunE: runTestCmd, +} + +func init() { + rootCmd.AddCommand(testCmd) +} + +func runTestCmd(cmd *cobra.Command, args []string) error { + return nil +} diff --git a/go.mod b/go.mod index 1df0a0c274..b32bef0ba7 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/s3 v1.74.1 github.com/aws/smithy-go v1.22.2 github.com/creack/pty v1.1.21 + github.com/emirpasic/gods v1.18.1 github.com/fsnotify/fsnotify v1.8.0 github.com/golang-jwt/jwt/v5 v5.2.1 github.com/golang-migrate/migrate/v4 v4.18.1 diff --git a/go.sum b/go.sum index c43eb254ac..64e8d740f6 100644 --- a/go.sum +++ b/go.sum @@ -58,6 +58,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/ebitengine/purego v0.8.1 h1:sdRKd6plj7KYW33EH5As6YKfe8m9zbN9JMrOjNVF/BE= github.com/ebitengine/purego v0.8.1/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= +github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= +github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M= diff --git a/pkg/util/ds/expmap.go b/pkg/util/ds/expmap.go new file mode 100644 index 0000000000..cc5aa0a734 --- /dev/null +++ b/pkg/util/ds/expmap.go @@ -0,0 +1,87 @@ +// Copyright 2025, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package ds + +import ( + "sync" + "time" + + "github.com/emirpasic/gods/trees/binaryheap" +) + +// an ExpMap has "expiring" keys, which are automatically deleted after a certain time + +type ExpMap[T any] struct { + lock *sync.Mutex + expHeap *binaryheap.Heap // heap of expEntries (sorted by time) + m map[string]expMapEntry[T] +} + +type expMapEntry[T any] struct { + Val T + Exp time.Time +} + +type expEntry struct { + Key string + Exp time.Time +} + +func heapComparator(aArg, bArg any) int { + a := aArg.(expEntry) + b := bArg.(expEntry) + if a.Exp.Before(b.Exp) { + return -1 + } else if a.Exp.After(b.Exp) { + return 1 + } + return 0 +} + +func MakeExpMap[T any]() *ExpMap[T] { + return &ExpMap[T]{ + lock: &sync.Mutex{}, + expHeap: binaryheap.NewWith(heapComparator), + m: make(map[string]expMapEntry[T]), + } +} + +func (em *ExpMap[T]) Set(key string, value T, exp time.Time) { + em.lock.Lock() + defer em.lock.Unlock() + oldEntry, ok := em.m[key] + em.m[key] = expMapEntry[T]{Val: value, Exp: exp} + if !ok || oldEntry.Exp != exp { + em.expHeap.Push(expEntry{Key: key, Exp: exp}) // this might create duplicates. that's ok. + } +} + +func (em *ExpMap[T]) expireItems_nolock() { + // should already hold the lock + now := time.Now() + for { + if em.expHeap.Empty() { + break + } + // we know it isn't empty, so we ignore "ok" + topI, _ := em.expHeap.Peek() + top := topI.(expEntry) + if top.Exp.After(now) { + break + } + em.expHeap.Pop() + entry, ok := em.m[top.Key] + if ok && (entry.Exp.Before(now) || entry.Exp.Equal(now)) { + delete(em.m, top.Key) + } + } +} + +func (em *ExpMap[T]) Get(key string) (T, bool) { + em.lock.Lock() + defer em.lock.Unlock() + em.expireItems_nolock() + v, ok := em.m[key] + return v.Val, ok +} diff --git a/pkg/waveapp/waveapp.go b/pkg/waveapp/waveapp.go index b05f852f10..4b7ff00e28 100644 --- a/pkg/waveapp/waveapp.go +++ b/pkg/waveapp/waveapp.go @@ -176,7 +176,7 @@ func (client *Client) Connect() error { if err != nil { return fmt.Errorf("error extracting socket name from %s: %v", wshutil.WaveJwtTokenVarName, err) } - rpcClient, err := wshutil.SetupDomainSocketRpcClient(sockName, client.ServerImpl) + rpcClient, err := wshutil.SetupDomainSocketRpcClient(sockName, client.ServerImpl, "vdomclient") if err != nil { return fmt.Errorf("error setting up domain socket rpc client: %v", err) } diff --git a/pkg/wshrpc/wshclient/barerpcclient.go b/pkg/wshrpc/wshclient/barerpcclient.go index 686db3afd6..f05b75095c 100644 --- a/pkg/wshrpc/wshclient/barerpcclient.go +++ b/pkg/wshrpc/wshclient/barerpcclient.go @@ -31,7 +31,7 @@ func GetBareRpcClient() *wshutil.WshRpc { waveSrvClient_Once.Do(func() { inputCh := make(chan []byte, DefaultInputChSize) outputCh := make(chan []byte, DefaultOutputChSize) - waveSrvClient_Singleton = wshutil.MakeWshRpc(inputCh, outputCh, wshrpc.RpcContext{}, &WshServerImpl) + waveSrvClient_Singleton = wshutil.MakeWshRpc(inputCh, outputCh, wshrpc.RpcContext{}, &WshServerImpl, "bare-client") wshutil.DefaultRouter.RegisterRoute(BareClientRoute, waveSrvClient_Singleton, true) wps.Broker.SetClient(wshutil.DefaultRouter) }) diff --git a/pkg/wshrpc/wshremote/wshremote.go b/pkg/wshrpc/wshremote/wshremote.go index 6e17ae8063..df241dd26c 100644 --- a/pkg/wshrpc/wshremote/wshremote.go +++ b/pkg/wshrpc/wshremote/wshremote.go @@ -45,6 +45,22 @@ func (impl *ServerImpl) MessageCommand(ctx context.Context, data wshrpc.CommandM return nil } +func (impl *ServerImpl) StreamTestCommand(ctx context.Context) chan wshrpc.RespOrErrorUnion[int] { + ch := make(chan wshrpc.RespOrErrorUnion[int], 16) + go func() { + defer close(ch) + idx := 0 + for { + ch <- wshrpc.RespOrErrorUnion[int]{Response: idx} + idx++ + if idx == 1000 { + break + } + } + }() + return ch +} + type ByteRangeType struct { All bool Start int64 diff --git a/pkg/wshrpc/wshserver/wshserverutil.go b/pkg/wshrpc/wshserver/wshserverutil.go index 2f912db303..252d358914 100644 --- a/pkg/wshrpc/wshserver/wshserverutil.go +++ b/pkg/wshrpc/wshserver/wshserverutil.go @@ -23,7 +23,7 @@ func GetMainRpcClient() *wshutil.WshRpc { waveSrvClient_Once.Do(func() { inputCh := make(chan []byte, DefaultInputChSize) outputCh := make(chan []byte, DefaultOutputChSize) - waveSrvClient_Singleton = wshutil.MakeWshRpc(inputCh, outputCh, wshrpc.RpcContext{}, &WshServerImpl) + waveSrvClient_Singleton = wshutil.MakeWshRpc(inputCh, outputCh, wshrpc.RpcContext{}, &WshServerImpl, "main-client") }) return waveSrvClient_Singleton } diff --git a/pkg/wshutil/wshrouter.go b/pkg/wshutil/wshrouter.go index 0514acb611..0e232a918f 100644 --- a/pkg/wshutil/wshrouter.go +++ b/pkg/wshutil/wshrouter.go @@ -196,26 +196,21 @@ func (router *WshRouter) getAnnouncedRoute(routeId string) string { func (router *WshRouter) sendRoutedMessage(msgBytes []byte, routeId string) bool { rpc := router.GetRpc(routeId) if rpc != nil { - // log.Printf("[router] sending message to %q via rpc\n", routeId) rpc.SendRpcMessage(msgBytes) return true } upstream := router.GetUpstreamClient() if upstream != nil { - log.Printf("[router] sending message to %q via upstream\n", routeId) upstream.SendRpcMessage(msgBytes) return true } else { - log.Printf("[router] sending message to %q via announced route\n", routeId) // we are the upstream, so consult our announced routes map localRouteId := router.getAnnouncedRoute(routeId) - log.Printf("[router] local route id: %q\n", localRouteId) rpc := router.GetRpc(localRouteId) if rpc == nil { log.Printf("[router] no rpc for local route id %q\n", localRouteId) return false } - log.Printf("[router] sending message to %q via local route\n", localRouteId) rpc.SendRpcMessage(msgBytes) return true } diff --git a/pkg/wshutil/wshrpc.go b/pkg/wshutil/wshrpc.go index 84f9628b46..a1ea1f71ae 100644 --- a/pkg/wshutil/wshrpc.go +++ b/pkg/wshutil/wshrpc.go @@ -16,6 +16,7 @@ import ( "github.com/google/uuid" "github.com/wavetermdev/waveterm/pkg/panichandler" + "github.com/wavetermdev/waveterm/pkg/util/ds" "github.com/wavetermdev/waveterm/pkg/util/utilfn" "github.com/wavetermdev/waveterm/pkg/wps" "github.com/wavetermdev/waveterm/pkg/wshrpc" @@ -26,6 +27,8 @@ const RespChSize = 32 const DefaultMessageChSize = 32 const CtxDoneChSize = 10 +var blockingExpMap = ds.MakeExpMap[bool]() + type ResponseFnType = func(any) error // returns true if handler is complete, false for an async handler @@ -42,7 +45,6 @@ type AbstractRpcClient interface { type WshRpc struct { Lock *sync.Mutex - clientId string InputCh chan []byte OutputCh chan []byte CtxDoneCh chan string // for context cancellation, value is ResId @@ -181,8 +183,10 @@ func (r *RpcMessage) Validate() error { } type rpcData struct { - ResCh chan *RpcMessage - Ctx context.Context + Command string + Route string + ResCh chan *RpcMessage + Ctx context.Context } func validateServerImpl(serverImpl ServerImpl) { @@ -196,7 +200,7 @@ func validateServerImpl(serverImpl ServerImpl) { } // closes outputCh when inputCh is closed/done -func MakeWshRpc(inputCh chan []byte, outputCh chan []byte, rpcCtx wshrpc.RpcContext, serverImpl ServerImpl) *WshRpc { +func MakeWshRpc(inputCh chan []byte, outputCh chan []byte, rpcCtx wshrpc.RpcContext, serverImpl ServerImpl, debugName string) *WshRpc { if inputCh == nil { inputCh = make(chan []byte, DefaultInputChSize) } @@ -206,7 +210,7 @@ func MakeWshRpc(inputCh chan []byte, outputCh chan []byte, rpcCtx wshrpc.RpcCont validateServerImpl(serverImpl) rtn := &WshRpc{ Lock: &sync.Mutex{}, - clientId: uuid.New().String(), + DebugName: debugName, InputCh: inputCh, OutputCh: outputCh, CtxDoneCh: make(chan string, CtxDoneChSize), @@ -221,10 +225,6 @@ func MakeWshRpc(inputCh chan []byte, outputCh chan []byte, rpcCtx wshrpc.RpcCont return rtn } -func (w *WshRpc) ClientId() string { - return w.clientId -} - func (w *WshRpc) GetRpcContext() wshrpc.RpcContext { rtnPtr := w.RpcContext.Load() return *rtnPtr @@ -377,11 +377,7 @@ outer: w.handleRequest(&msg) }() } else { - respCh := w.getResponseCh(msg.ResId) - if respCh == nil { - continue - } - respCh <- &msg + w.sendRespWithBlockMessage(msg) if !msg.Cont { w.unregisterRpc(msg.ResId, nil) } @@ -389,17 +385,17 @@ outer: } } -func (w *WshRpc) getResponseCh(resId string) chan *RpcMessage { +func (w *WshRpc) getResponseCh(resId string) (chan *RpcMessage, *rpcData) { if resId == "" { - return nil + return nil, nil } w.Lock.Lock() defer w.Lock.Unlock() rd := w.RpcMap[resId] if rd == nil { - return nil + return nil, nil } - return rd.ResCh + return rd.ResCh, rd } func (w *WshRpc) SetServerImpl(serverImpl ServerImpl) { @@ -409,13 +405,15 @@ func (w *WshRpc) SetServerImpl(serverImpl ServerImpl) { w.ServerImpl = serverImpl } -func (w *WshRpc) registerRpc(ctx context.Context, reqId string) chan *RpcMessage { +func (w *WshRpc) registerRpc(ctx context.Context, command string, route string, reqId string) chan *RpcMessage { w.Lock.Lock() defer w.Lock.Unlock() rpcCh := make(chan *RpcMessage, RespChSize) w.RpcMap[reqId] = &rpcData{ - ResCh: rpcCh, - Ctx: ctx, + Command: command, + Route: route, + ResCh: rpcCh, + Ctx: ctx, } go func() { defer func() { @@ -712,7 +710,7 @@ func (w *WshRpc) SendComplexRequest(command string, data any, opts *wshrpc.RpcOp if err != nil { return nil, err } - handler.respCh = w.registerRpc(handler.ctx, handler.reqId) + handler.respCh = w.registerRpc(handler.ctx, command, opts.Route, handler.reqId) w.OutputCh <- barr return handler, nil } @@ -753,5 +751,34 @@ func (w *WshRpc) retrySendTimeout(resId string) { } time.Sleep(100 * time.Millisecond) } +} +func (w *WshRpc) sendRespWithBlockMessage(msg RpcMessage) { + respCh, rd := w.getResponseCh(msg.ResId) + if respCh == nil { + return + } + select { + case respCh <- &msg: + // normal case, message got sent, just return! + return + default: + // channel is full, we would block... + } + // log the fact that we're blocking + _, noLog := blockingExpMap.Get(msg.ResId) + if !noLog { + log.Printf("[rpc:%s] blocking on response command:%s route:%s resid:%s\n", w.DebugName, rd.Command, rd.Route, msg.ResId) + blockingExpMap.Set(msg.ResId, true, time.Now().Add(time.Second)) + } + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + select { + case respCh <- &msg: + // message got sent, just return! + return + case <-ctx.Done(): + } + log.Printf("[rpc:%s] failed to clear response channel (waited 1s), will fail RPC command:%s route:%s resid:%s\n", w.DebugName, rd.Command, rd.Route, msg.ResId) + w.unregisterRpc(msg.ResId, nil) // we don't pass an error because the channel is full, it won't work anyway... } diff --git a/pkg/wshutil/wshutil.go b/pkg/wshutil/wshutil.go index 424d64db19..871fd72d1f 100644 --- a/pkg/wshutil/wshutil.go +++ b/pkg/wshutil/wshutil.go @@ -199,11 +199,11 @@ func RestoreTermState() { } // returns (wshRpc, wrappedStdin) -func SetupTerminalRpcClient(serverImpl ServerImpl) (*WshRpc, io.Reader) { +func SetupTerminalRpcClient(serverImpl ServerImpl, debugStr string) (*WshRpc, io.Reader) { messageCh := make(chan []byte, DefaultInputChSize) outputCh := make(chan []byte, DefaultOutputChSize) ptyBuf := MakePtyBuffer(WaveServerOSCPrefix, os.Stdin, messageCh) - rpcClient := MakeWshRpc(messageCh, outputCh, wshrpc.RpcContext{}, serverImpl) + rpcClient := MakeWshRpc(messageCh, outputCh, wshrpc.RpcContext{}, serverImpl, debugStr) go func() { defer func() { panichandler.PanicHandler("SetupTerminalRpcClient", recover()) @@ -221,11 +221,11 @@ func SetupTerminalRpcClient(serverImpl ServerImpl) (*WshRpc, io.Reader) { return rpcClient, ptyBuf } -func SetupPacketRpcClient(input io.Reader, output io.Writer, serverImpl ServerImpl) (*WshRpc, chan []byte) { +func SetupPacketRpcClient(input io.Reader, output io.Writer, serverImpl ServerImpl, debugStr string) (*WshRpc, chan []byte) { messageCh := make(chan []byte, DefaultInputChSize) outputCh := make(chan []byte, DefaultOutputChSize) rawCh := make(chan []byte, DefaultOutputChSize) - rpcClient := MakeWshRpc(messageCh, outputCh, wshrpc.RpcContext{}, serverImpl) + rpcClient := MakeWshRpc(messageCh, outputCh, wshrpc.RpcContext{}, serverImpl, debugStr) go packetparser.Parse(input, messageCh, rawCh) go func() { defer func() { @@ -238,7 +238,7 @@ func SetupPacketRpcClient(input io.Reader, output io.Writer, serverImpl ServerIm return rpcClient, rawCh } -func SetupConnRpcClient(conn net.Conn, serverImpl ServerImpl) (*WshRpc, chan error, error) { +func SetupConnRpcClient(conn net.Conn, serverImpl ServerImpl, debugStr string) (*WshRpc, chan error, error) { inputCh := make(chan []byte, DefaultInputChSize) outputCh := make(chan []byte, DefaultOutputChSize) writeErrCh := make(chan error, 1) @@ -260,7 +260,7 @@ func SetupConnRpcClient(conn net.Conn, serverImpl ServerImpl) (*WshRpc, chan err defer conn.Close() AdaptStreamToMsgCh(conn, inputCh) }() - rtn := MakeWshRpc(inputCh, outputCh, wshrpc.RpcContext{}, serverImpl) + rtn := MakeWshRpc(inputCh, outputCh, wshrpc.RpcContext{}, serverImpl, debugStr) return rtn, writeErrCh, nil } @@ -272,7 +272,7 @@ func tryTcpSocket(sockName string) (net.Conn, error) { return net.DialTCP("tcp", nil, addr) } -func SetupDomainSocketRpcClient(sockName string, serverImpl ServerImpl) (*WshRpc, error) { +func SetupDomainSocketRpcClient(sockName string, serverImpl ServerImpl, debugName string) (*WshRpc, error) { conn, tcpErr := tryTcpSocket(sockName) var unixErr error if tcpErr != nil { @@ -281,7 +281,7 @@ func SetupDomainSocketRpcClient(sockName string, serverImpl ServerImpl) (*WshRpc if tcpErr != nil && unixErr != nil { return nil, fmt.Errorf("failed to connect to tcp or unix domain socket: tcp err:%w: unix socket err: %w", tcpErr, unixErr) } - rtn, errCh, err := SetupConnRpcClient(conn, serverImpl) + rtn, errCh, err := SetupConnRpcClient(conn, serverImpl, debugName) go func() { defer func() { panichandler.PanicHandler("SetupDomainSocketRpcClient:closeConn", recover())