diff --git a/pkg/shellexec/shellexec.go b/pkg/shellexec/shellexec.go index d2cf64d498..6f2145827c 100644 --- a/pkg/shellexec/shellexec.go +++ b/pkg/shellexec/shellexec.go @@ -148,10 +148,12 @@ func (pp *PipePty) WriteString(s string) (n int, err error) { } func StartWslShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *wsl.WslConn) (*ShellProc, error) { + utilCtx, cancelFn := context.WithTimeout(ctx, 2*time.Second) + defer cancelFn() client := conn.GetClient() shellPath := cmdOpts.ShellPath if shellPath == "" { - remoteShellPath, err := wsl.DetectShell(conn.Context, client) + remoteShellPath, err := wsl.DetectShell(utilCtx, client) if err != nil { return nil, err } @@ -160,13 +162,13 @@ func StartWslShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr st var shellOpts []string log.Printf("detected shell: %s", shellPath) - err := wsl.InstallClientRcFiles(conn.Context, client) + err := wsl.InstallClientRcFiles(utilCtx, client) if err != nil { log.Printf("error installing rc files: %v", err) return nil, err } - homeDir := wsl.GetHomeDir(conn.Context, client) + homeDir := wsl.GetHomeDir(utilCtx, client) shellOpts = append(shellOpts, "~", "-d", client.Name()) var subShellOpts []string diff --git a/pkg/wsl/wsl.go b/pkg/wsl/wsl.go index 673bc1a73f..4cad32bde9 100644 --- a/pkg/wsl/wsl.go +++ b/pkg/wsl/wsl.go @@ -51,7 +51,6 @@ type WslConn struct { HasWaiter *atomic.Bool LastConnectTime int64 ActiveConnNum int - Context context.Context cancelFn func() } @@ -188,6 +187,8 @@ func (conn *WslConn) OpenDomainSocketListener() error { } func (conn *WslConn) StartConnServer() error { + utilCtx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second) + defer cancelFn() var allowed bool conn.WithLock(func() { if conn.Status != Status_Connecting { @@ -200,7 +201,7 @@ func (conn *WslConn) StartConnServer() error { return fmt.Errorf("cannot start conn server for %q when status is %q", conn.GetName(), conn.GetStatus()) } client := conn.GetClient() - wshPath := GetWshPath(conn.Context, client) + wshPath := GetWshPath(utilCtx, client) rpcCtx := wshrpc.RpcContext{ ClientType: wshrpc.ClientType_ConnServer, Conn: conn.GetName(), @@ -210,7 +211,7 @@ func (conn *WslConn) StartConnServer() error { if err != nil { return fmt.Errorf("unable to create jwt token for conn controller: %w", err) } - shellPath, err := DetectShell(conn.Context, client) + shellPath, err := DetectShell(utilCtx, client) if err != nil { return err } @@ -221,7 +222,14 @@ func (conn *WslConn) StartConnServer() error { cmdStr = fmt.Sprintf("%s=\"%s\" %s connserver --router", wshutil.WaveJwtTokenVarName, jwtToken, wshPath) } log.Printf("starting conn controller: %s\n", cmdStr) - cmd := client.WslCommand(conn.Context, cmdStr) + connServerCtx, cancelFn := context.WithCancel(context.Background()) + conn.WithLock(func() { + if conn.cancelFn != nil { + conn.cancelFn() + } + conn.cancelFn = cancelFn + }) + cmd := client.WslCommand(connServerCtx, cmdStr) pipeRead, pipeWrite := io.Pipe() inputPipeRead, inputPipeWrite := io.Pipe() cmd.SetStdout(pipeWrite) @@ -473,8 +481,7 @@ func getConnInternal(name string) *WslConn { connName := WslName{Distro: name} rtn := clientControllerMap[name] if rtn == nil { - ctx, cancelFn := context.WithCancel(context.Background()) - rtn = &WslConn{Lock: &sync.Mutex{}, Status: Status_Init, Name: connName, HasWaiter: &atomic.Bool{}, Context: ctx, cancelFn: cancelFn} + rtn = &WslConn{Lock: &sync.Mutex{}, Status: Status_Init, Name: connName, HasWaiter: &atomic.Bool{}, cancelFn: nil} clientControllerMap[name] = rtn } return rtn