From 5d197a78e87310f1dfd3f34de8ffae80af76a070 Mon Sep 17 00:00:00 2001 From: sawka Date: Mon, 30 Dec 2024 21:58:03 -0800 Subject: [PATCH 01/11] use pwsh over powershell if present --- pkg/util/shellutil/shellutil.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pkg/util/shellutil/shellutil.go b/pkg/util/shellutil/shellutil.go index 91ac6d76b8..175dd59d3a 100644 --- a/pkg/util/shellutil/shellutil.go +++ b/pkg/util/shellutil/shellutil.go @@ -93,6 +93,12 @@ $env:PATH = "{{.WSHBINDIR}}" + "{{.PATHSEP}}" + $env:PATH func DetectLocalShellPath() string { if runtime.GOOS == "windows" { + if pwshPath, lpErr := exec.LookPath("pwsh"); lpErr == nil { + return pwshPath + } + if powershellPath, lpErr := exec.LookPath("powershell"); lpErr == nil { + return powershellPath + } return "powershell.exe" } shellPath := GetMacUserShell() From e5c3c231baad33d1f9d778feb0157d354d09fecb Mon Sep 17 00:00:00 2001 From: sawka Date: Mon, 30 Dec 2024 22:51:25 -0800 Subject: [PATCH 02/11] initial refactor of DoRunShellCommand into a setup / manage phase. fix bug around persisting termsize to DB --- frontend/app/store/keymodel.ts | 1 - pkg/blockcontroller/blockcontroller.go | 86 ++++++++++++++++---------- 2 files changed, 52 insertions(+), 35 deletions(-) diff --git a/frontend/app/store/keymodel.ts b/frontend/app/store/keymodel.ts index cc4f58c9a6..0550b63d93 100644 --- a/frontend/app/store/keymodel.ts +++ b/frontend/app/store/keymodel.ts @@ -341,7 +341,6 @@ function registerGlobalKeys() { return false; } globalKeyMap.set("Cmd:f", activateSearch); - globalKeyMap.set("Ctrl:f", activateSearch); globalKeyMap.set("Escape", deactivateSearch); const allKeys = Array.from(globalKeyMap.keys()); // special case keys, handled by web view diff --git a/pkg/blockcontroller/blockcontroller.go b/pkg/blockcontroller/blockcontroller.go index 35d32eb680..c800b37447 100644 --- a/pkg/blockcontroller/blockcontroller.go +++ b/pkg/blockcontroller/blockcontroller.go @@ -273,27 +273,34 @@ func createCmdStrAndOpts(blockId string, blockMeta waveobj.MetaMapType) (string, } func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj.MetaMapType) error { + shellProc, err := bc.setupAndStartShellProcess(rc, blockMeta) + if err != nil { + return err + } + return bc.manageRunningShellProcess(shellProc, rc, blockMeta) +} + +func (bc *BlockController) setupAndStartShellProcess(rc *RunShellOpts, blockMeta waveobj.MetaMapType) (*shellexec.ShellProc, error) { // create a circular blockfile for the output ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second) defer cancelFn() - err := filestore.WFS.MakeFile(ctx, bc.BlockId, BlockFile_Term, nil, filestore.FileOptsType{MaxSize: DefaultTermMaxFileSize, Circular: true}) - if err != nil && err != fs.ErrExist { - err = fs.ErrExist - return fmt.Errorf("error creating blockfile: %w", err) + fsErr := filestore.WFS.MakeFile(ctx, bc.BlockId, BlockFile_Term, nil, filestore.FileOptsType{MaxSize: DefaultTermMaxFileSize, Circular: true}) + if fsErr != nil && fsErr != fs.ErrExist { + return nil, fmt.Errorf("error creating blockfile: %w", fsErr) } - if err == fs.ErrExist { + if fsErr == fs.ErrExist { // reset the terminal state bc.resetTerminalState() } - err = nil bcInitStatus := bc.GetRuntimeStatus() if bcInitStatus.ShellProcStatus == Status_Running { - return nil + return nil, nil } // TODO better sync here (don't let two starts happen at the same times) remoteName := blockMeta.GetString(waveobj.MetaKey_Connection, "") var cmdStr string var cmdOpts shellexec.CommandOptsType + var err error if bc.ControllerType == BlockController_Shell { cmdOpts.Env = make(map[string]string) cmdOpts.Interactive = true @@ -302,7 +309,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj if cmdOpts.Cwd != "" { cwdPath, err := wavebase.ExpandHomeDir(cmdOpts.Cwd) if err != nil { - return err + return nil, err } cmdOpts.Cwd = cwdPath } @@ -310,11 +317,11 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj var cmdOptsPtr *shellexec.CommandOptsType cmdStr, cmdOptsPtr, err = createCmdStrAndOpts(bc.BlockId, blockMeta) if err != nil { - return err + return nil, err } cmdOpts = *cmdOptsPtr } else { - return fmt.Errorf("unknown controller type %q", bc.ControllerType) + return nil, fmt.Errorf("unknown controller type %q", bc.ControllerType) } var shellProc *shellexec.ShellProc if strings.HasPrefix(remoteName, "wsl://") { @@ -325,20 +332,20 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj wslConn := wsl.GetWslConn(credentialCtx, wslName, false) connStatus := wslConn.DeriveConnStatus() if connStatus.Status != conncontroller.Status_Connected { - return fmt.Errorf("not connected, cannot start shellproc") + return nil, fmt.Errorf("not connected, cannot start shellproc") } // create jwt if !blockMeta.GetBool(waveobj.MetaKey_CmdNoWsh, false) { jwtStr, err := wshutil.MakeClientJWTToken(wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId, Conn: wslConn.GetName()}, wslConn.GetDomainSocketName()) if err != nil { - return fmt.Errorf("error making jwt token: %w", err) + return nil, fmt.Errorf("error making jwt token: %w", err) } cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr } shellProc, err = shellexec.StartWslShellProc(ctx, rc.TermSize, cmdStr, cmdOpts, wslConn) if err != nil { - return err + return nil, err } } else if remoteName != "" { credentialCtx, cancelFunc := context.WithTimeout(context.Background(), 60*time.Second) @@ -346,24 +353,24 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj opts, err := remote.ParseOpts(remoteName) if err != nil { - return err + return nil, err } conn := conncontroller.GetConn(credentialCtx, opts, false, &wshrpc.ConnKeywords{}) connStatus := conn.DeriveConnStatus() if connStatus.Status != conncontroller.Status_Connected { - return fmt.Errorf("not connected, cannot start shellproc") + return nil, fmt.Errorf("not connected, cannot start shellproc") } if !blockMeta.GetBool(waveobj.MetaKey_CmdNoWsh, false) { jwtStr, err := wshutil.MakeClientJWTToken(wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId, Conn: conn.Opts.String()}, conn.GetDomainSocketName()) if err != nil { - return fmt.Errorf("error making jwt token: %w", err) + return nil, fmt.Errorf("error making jwt token: %w", err) } cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr } if !conn.WshEnabled.Load() { shellProc, err = shellexec.StartRemoteShellProcNoWsh(rc.TermSize, cmdStr, cmdOpts, conn) if err != nil { - return err + return nil, err } } else { shellProc, err = shellexec.StartRemoteShellProc(rc.TermSize, cmdStr, cmdOpts, conn) @@ -376,19 +383,16 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj log.Print("attempting install without wsh") shellProc, err = shellexec.StartRemoteShellProcNoWsh(rc.TermSize, cmdStr, cmdOpts, conn) if err != nil { - return err + return nil, err } } } - if err != nil { - return err - } } else { // local terminal if !blockMeta.GetBool(waveobj.MetaKey_CmdNoWsh, false) { jwtStr, err := wshutil.MakeClientJWTToken(wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId}, wavebase.GetDomainSocketName()) if err != nil { - return fmt.Errorf("error making jwt token: %w", err) + return nil, fmt.Errorf("error making jwt token: %w", err) } cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr } @@ -407,7 +411,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj } shellProc, err = shellexec.StartShellProc(rc.TermSize, cmdStr, cmdOpts) if err != nil { - return err + return nil, err } } bc.UpdateControllerAndSendUpdate(func() bool { @@ -415,6 +419,10 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj bc.ShellProcStatus = Status_Running return true }) + return shellProc, nil +} + +func (bc *BlockController) manageRunningShellProcess(shellProc *shellexec.ShellProc, rc *RunShellOpts, blockMeta waveobj.MetaMapType) error { shellInputCh := make(chan *BlockInputUnion, 32) bc.ShellInputCh = shellInputCh @@ -469,14 +477,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj shellProc.Cmd.Write(ic.InputData) } if ic.TermSize != nil { - err = setTermSize(ctx, bc.BlockId, *ic.TermSize) - if err != nil { - log.Printf("error setting pty size: %v\n", err) - } - err = shellProc.Cmd.SetSize(ic.TermSize.Rows, ic.TermSize.Cols) - if err != nil { - log.Printf("error setting pty size: %v\n", err) - } + updateTermSize(shellProc, bc.BlockId, *ic.TermSize) } } }() @@ -514,6 +515,17 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj return nil } +func updateTermSize(shellProc *shellexec.ShellProc, blockId string, termSize waveobj.TermSize) { + err := setTermSizeInDB(blockId, termSize) + if err != nil { + log.Printf("error setting pty size: %v\n", err) + } + err = shellProc.Cmd.SetSize(termSize.Rows, termSize.Cols) + if err != nil { + log.Printf("error setting pty size: %v\n", err) + } +} + func checkCloseOnExit(blockId string, exitCode int) { ctx, cancelFn := context.WithTimeout(context.Background(), DefaultTimeout) defer cancelFn() @@ -561,16 +573,22 @@ func getTermSize(bdata *waveobj.Block) waveobj.TermSize { } } -func setTermSize(ctx context.Context, blockId string, termSize waveobj.TermSize) error { +func setTermSizeInDB(blockId string, termSize waveobj.TermSize) error { + ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second) + defer cancelFn() ctx = waveobj.ContextWithUpdates(ctx) - bdata, err := wstore.DBMustGet[*waveobj.Block](context.Background(), blockId) + bdata, err := wstore.DBMustGet[*waveobj.Block](ctx, blockId) if err != nil { return fmt.Errorf("error getting block data: %v", err) } if bdata.RuntimeOpts == nil { - return fmt.Errorf("error from nil RuntimeOpts: %v", err) + bdata.RuntimeOpts = &waveobj.RuntimeOpts{} } bdata.RuntimeOpts.TermSize = termSize + err = wstore.DBUpdate(ctx, bdata) + if err != nil { + return fmt.Errorf("error updating block data: %v", err) + } updates := waveobj.ContextGetUpdatesRtn(ctx) wps.Broker.SendUpdateEvents(updates) return nil From 6605db22383003cbc188a42f60593e9ba0146a6a Mon Sep 17 00:00:00 2001 From: sawka Date: Tue, 31 Dec 2024 10:18:28 -0800 Subject: [PATCH 03/11] fix wshcmd-conn disconnectall to work with wsl --- cmd/wsh/cmd/wshcmd-conn.go | 38 ++++++++++++++++++++++---------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/cmd/wsh/cmd/wshcmd-conn.go b/cmd/wsh/cmd/wshcmd-conn.go index 52d76ef677..3e9cacd90e 100644 --- a/cmd/wsh/cmd/wshcmd-conn.go +++ b/cmd/wsh/cmd/wshcmd-conn.go @@ -87,18 +87,26 @@ func validateConnectionName(name string) error { return nil } -func connStatusRun(cmd *cobra.Command, args []string) error { +func getAllConnStatus() ([]wshrpc.ConnStatus, error) { var allResp []wshrpc.ConnStatus sshResp, err := wshclient.ConnStatusCommand(RpcClient, nil) if err != nil { - return fmt.Errorf("getting ssh connection status: %w", err) + return nil, fmt.Errorf("getting ssh connection status: %w", err) } allResp = append(allResp, sshResp...) wslResp, err := wshclient.WslStatusCommand(RpcClient, nil) if err != nil { - return fmt.Errorf("getting wsl connection status: %w", err) + return nil, fmt.Errorf("getting wsl connection status: %w", err) } allResp = append(allResp, wslResp...) + return allResp, nil +} + +func connStatusRun(cmd *cobra.Command, args []string) error { + allResp, err := getAllConnStatus() + if err != nil { + return err + } if len(allResp) == 0 { WriteStdout("no connections\n") return nil @@ -142,21 +150,19 @@ func connDisconnectRun(cmd *cobra.Command, args []string) error { } func connDisconnectAllRun(cmd *cobra.Command, args []string) error { - resp, err := wshclient.ConnStatusCommand(RpcClient, nil) + allConns, err := getAllConnStatus() if err != nil { - return fmt.Errorf("getting connection status: %w", err) - } - if len(resp) == 0 { - return nil + return err } - for _, conn := range resp { - if conn.Status == "connected" { - err := wshclient.ConnDisconnectCommand(RpcClient, conn.Connection, &wshrpc.RpcOpts{Timeout: 10000}) - if err != nil { - WriteStdout("error disconnecting %q: %v\n", conn.Connection, err) - } else { - WriteStdout("disconnected %q\n", conn.Connection) - } + for _, conn := range allConns { + if conn.Status != "connected" { + continue + } + err := wshclient.ConnDisconnectCommand(RpcClient, conn.Connection, &wshrpc.RpcOpts{Timeout: 10000}) + if err != nil { + WriteStdout("error disconnecting %q: %v\n", conn.Connection, err) + } else { + WriteStdout("disconnected %q\n", conn.Connection) } } return nil From 0f58a0efd70284cb9cf7cee9735f48579bede9a2 Mon Sep 17 00:00:00 2001 From: sawka Date: Tue, 31 Dec 2024 12:43:04 -0800 Subject: [PATCH 04/11] more robust cp --- pkg/remote/conncontroller/conncontroller.go | 2 +- pkg/remote/connutil.go | 107 +++++++++++--------- 2 files changed, 58 insertions(+), 51 deletions(-) diff --git a/pkg/remote/conncontroller/conncontroller.go b/pkg/remote/conncontroller/conncontroller.go index c769f7f7ee..ece8b8dbb5 100644 --- a/pkg/remote/conncontroller/conncontroller.go +++ b/pkg/remote/conncontroller/conncontroller.go @@ -379,7 +379,7 @@ func (conn *SSHConn) CheckAndInstallWsh(ctx context.Context, clientDisplayName s } // attempt to install extension wshLocalPath := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch) - err = remote.CpHostToRemote(client, wshLocalPath, wavebase.RemoteFullWshBinPath) + err = remote.CpHostToRemote(ctx, client, wshLocalPath, wavebase.RemoteFullWshBinPath) if err != nil { return err } diff --git a/pkg/remote/connutil.go b/pkg/remote/connutil.go index 90707878db..8e938db7ef 100644 --- a/pkg/remote/connutil.go +++ b/pkg/remote/connutil.go @@ -5,8 +5,8 @@ package remote import ( "bytes" + "context" "fmt" - "html/template" "io" "log" "os" @@ -14,8 +14,8 @@ import ( "path/filepath" "regexp" "strings" + "text/template" - "github.com/wavetermdev/waveterm/pkg/panichandler" "github.com/wavetermdev/waveterm/pkg/util/utilfn" "github.com/wavetermdev/waveterm/pkg/wavebase" "golang.org/x/crypto/ssh" @@ -212,76 +212,83 @@ func GetClientArch(client *ssh.Client) (string, error) { return "", fmt.Errorf("unable to determine architecture: {unix: %s, cmd: %s, powershell: %s}", unixErr, cmdErr, psErr) } -var installTemplateRawBash = `bash -c ' \ -mkdir -p {{.installDir}}; \ -cat > {{.tempPath}}; \ -mv {{.tempPath}} {{.installPath}}; \ -chmod a+x {{.installPath}};' \ -` - -var installTemplateRawDefault = ` \ -mkdir -p {{.installDir}}; \ -cat > {{.tempPath}}; \ -mv {{.tempPath}} {{.installPath}}; \ -chmod a+x {{.installPath}}; \ -` - -func CpHostToRemote(client *ssh.Client, sourcePath string, destPath string) error { - // warning: does not work on windows remote yet - bashInstalled, err := hasBashInstalled(client) - if err != nil { - return err - } +var installTemplateRawDefault = strings.TrimSpace(` +mkdir -p {{.installDir}} || exit 1 +cat > {{.tempPath}} || exit 1 +mv {{.tempPath}} {{.installPath}} || exit 1 +chmod a+x {{.installPath}} || exit 1 +`) +var installTemplate = template.Must(template.New("wsh-install-template").Parse(installTemplateRawDefault)) - var selectedTemplateRaw string - if bashInstalled { - selectedTemplateRaw = installTemplateRawBash - } else { - log.Printf("bash is not installed on remote. attempting with default shell") - selectedTemplateRaw = installTemplateRawDefault +func CpHostToRemote(ctx context.Context, client *ssh.Client, sourcePath string, destPath string) error { + installWords := map[string]string{ + "installDir": filepath.ToSlash(filepath.Dir(destPath)), + "tempPath": filepath.ToSlash(destPath + ".temp"), + "installPath": filepath.ToSlash(destPath), } - // I need to use toSlash here to force unix keybindings - // this means we can't guarantee it will work on a remote windows machine - var installWords = map[string]string{ - "installDir": filepath.ToSlash(filepath.Dir(destPath)), - "tempPath": destPath + ".temp", - "installPath": destPath, + var installCmd bytes.Buffer + if err := installTemplate.Execute(&installCmd, installWords); err != nil { + return fmt.Errorf("failed to prepare install command: %w", err) } - installCmd := &bytes.Buffer{} - installTemplate := template.Must(template.New("").Parse(selectedTemplateRaw)) - installTemplate.Execute(installCmd, installWords) + // Add debug log of the command + log.Printf("Running remote command: %s", installCmd.String()) session, err := client.NewSession() if err != nil { - return err + return fmt.Errorf("failed to create SSH session: %w", err) } + defer session.Close() + + // Add stderr capture + var stderr bytes.Buffer + session.Stderr = &stderr - installStdin, err := session.StdinPipe() + stdin, err := session.StdinPipe() if err != nil { - return err + return fmt.Errorf("failed to get stdin pipe: %w", err) } - err = session.Start(installCmd.String()) - if err != nil { - return err + if err := session.Start(installCmd.String()); err != nil { + return fmt.Errorf("failed to start remote command: %w", err) } input, err := os.Open(sourcePath) if err != nil { - return fmt.Errorf("cannot open local file %s to send to host: %v", sourcePath, err) + return fmt.Errorf("cannot open local file %s: %w", sourcePath, err) } + defer input.Close() + + copyDone := make(chan error, 1) go func() { - defer func() { - panichandler.PanicHandler("connutil:CpHostToRemote", recover()) - }() - io.Copy(installStdin, input) - session.Close() // this allows the command to complete for reasons i don't fully understand + defer close(copyDone) + defer stdin.Close() + + _, err := io.Copy(stdin, input) + if err != nil && err != io.EOF { + copyDone <- err + return + } + copyDone <- nil }() - return session.Wait() + select { + case <-ctx.Done(): + session.Close() + return ctx.Err() + case err := <-copyDone: + if err != nil { + return fmt.Errorf("failed to copy data: %w", err) + } + } + + if err := session.Wait(); err != nil { + return fmt.Errorf("remote command failed: %w (stderr: %s)", err, stderr.String()) + } + + return nil } func InstallClientRcFiles(client *ssh.Client) error { From 716a185c09aaa3096186213d4e1ae1010deeecd6 Mon Sep 17 00:00:00 2001 From: sawka Date: Tue, 31 Dec 2024 12:43:33 -0800 Subject: [PATCH 05/11] remove debug log --- pkg/remote/connutil.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/pkg/remote/connutil.go b/pkg/remote/connutil.go index 8e938db7ef..e001652539 100644 --- a/pkg/remote/connutil.go +++ b/pkg/remote/connutil.go @@ -232,9 +232,6 @@ func CpHostToRemote(ctx context.Context, client *ssh.Client, sourcePath string, return fmt.Errorf("failed to prepare install command: %w", err) } - // Add debug log of the command - log.Printf("Running remote command: %s", installCmd.String()) - session, err := client.NewSession() if err != nil { return fmt.Errorf("failed to create SSH session: %w", err) From ff2793fba9ab35650cc6ec1ff13ea527622ff075 Mon Sep 17 00:00:00 2001 From: sawka Date: Wed, 1 Jan 2025 14:54:15 -0800 Subject: [PATCH 06/11] clean up the genconn interface... get the ssh impl done. write a generic RunSimple. create a new SyncBuffer struct. --- pkg/genconn/genconn.go | 132 +++++++++++++ pkg/genconn/quote.go | 78 ++++++++ pkg/genconn/quote_test.go | 110 +++++++++++ pkg/genconn/ssh-impl.go | 143 ++++++++++++++ pkg/genconn/wsl-impl.go | 200 ++++++++++++++++++++ pkg/remote/conncontroller/conncontroller.go | 8 +- pkg/remote/connutil.go | 150 +++++++++------ pkg/util/shellutil/shellutil.go | 19 +- pkg/util/syncbuf/syncbuf.go | 41 ++++ pkg/wavebase/wavebase.go | 16 ++ pkg/wsl/wsl.go | 5 +- 11 files changed, 828 insertions(+), 74 deletions(-) create mode 100644 pkg/genconn/genconn.go create mode 100644 pkg/genconn/quote.go create mode 100644 pkg/genconn/quote_test.go create mode 100644 pkg/genconn/ssh-impl.go create mode 100644 pkg/genconn/wsl-impl.go create mode 100644 pkg/util/syncbuf/syncbuf.go diff --git a/pkg/genconn/genconn.go b/pkg/genconn/genconn.go new file mode 100644 index 0000000000..da448510ef --- /dev/null +++ b/pkg/genconn/genconn.go @@ -0,0 +1,132 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +// generic connection code (WSL + SSH) +package genconn + +import ( + "context" + "fmt" + "io" + "regexp" + "strings" + "sync" + + "github.com/wavetermdev/waveterm/pkg/util/syncbuf" +) + +type CommandSpec struct { + Cmd string + Env map[string]string + Cwd string +} + +type ShellClient interface { + MakeProcessController(cmd CommandSpec) (ShellProcessController, error) +} + +type ShellProcessController interface { + Start() error + Wait() error + Kill() + + // these are not required to be called, if they are not called, the impl will set to discard output + StdinPipe() (io.WriteCloser, error) + StdoutPipe() (io.Reader, error) + StderrPipe() (io.Reader, error) +} + +func RunSimpleCommand(ctx context.Context, client ShellClient, spec CommandSpec) (string, string, error) { + proc, err := client.MakeProcessController(spec) + if err != nil { + return "", "", fmt.Errorf("failed to create process controller: %w", err) + } + + stdout, err := proc.StdoutPipe() + if err != nil { + return "", "", fmt.Errorf("failed to get stdout pipe: %w", err) + } + stderr, err := proc.StderrPipe() + if err != nil { + return "", "", fmt.Errorf("failed to get stderr pipe: %w", err) + } + + if err := proc.Start(); err != nil { + return "", "", fmt.Errorf("failed to start process: %w", err) + } + + stdoutBuf := syncbuf.MakeSyncBuffer() + stderrBuf := syncbuf.MakeSyncBuffer() + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + io.Copy(stdoutBuf, stdout) + }() + + go func() { + defer wg.Done() + io.Copy(stderrBuf, stderr) + }() + + done := make(chan error, 1) + go func() { + done <- proc.Wait() + }() + + var runErr error + select { + case <-ctx.Done(): + proc.Kill() + runErr = ctx.Err() + case err := <-done: + if err != nil { + runErr = fmt.Errorf("process failed: %w", err) + } + } + + wg.Wait() + return stdoutBuf.String(), stderrBuf.String(), runErr +} + +func MakeStdoutSyncBuffer(proc ShellProcessController) (*syncbuf.SyncBuffer, error) { + stdout, err := proc.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("failed to get stdout pipe: %w", err) + } + return syncbuf.MakeSyncBufferFromReader(stdout), nil +} + +func MakeStderrSyncBuffer(proc ShellProcessController) (*syncbuf.SyncBuffer, error) { + stderr, err := proc.StderrPipe() + if err != nil { + return nil, fmt.Errorf("failed to get stderr pipe: %w", err) + } + return syncbuf.MakeSyncBufferFromReader(stderr), nil +} + +func BuildShellCommand(opts CommandSpec) (string, error) { + // Build environment variables + var envVars strings.Builder + for key, value := range opts.Env { + if !isValidEnvVarName(key) { + return "", fmt.Errorf("invalid environment variable name: %q", key) + } + envVars.WriteString(fmt.Sprintf("%s=%s ", key, HardQuote(value))) + } + + // Build the command + shellCmd := opts.Cmd + if opts.Cwd != "" { + shellCmd = fmt.Sprintf("cd %s && %s", HardQuote(opts.Cwd), shellCmd) + } + + // Quote the command for `sh -c` + return fmt.Sprintf("sh -c %s", HardQuote(envVars.String()+shellCmd)), nil +} + +func isValidEnvVarName(name string) bool { + validEnvVarName := regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) + return validEnvVarName.MatchString(name) +} diff --git a/pkg/genconn/quote.go b/pkg/genconn/quote.go new file mode 100644 index 0000000000..8bb59732f7 --- /dev/null +++ b/pkg/genconn/quote.go @@ -0,0 +1,78 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package genconn + +import "regexp" + +var ( + safePattern = regexp.MustCompile(`^[a-zA-Z0-9_/.-]+$`) + + needsEscape = map[byte]bool{ + '"': true, + '\\': true, + '$': true, + '`': true, + } +) + +func HardQuote(s string) string { + if s == "" { + return "\"\"" + } + + if safePattern.MatchString(s) { + return s + } + + buf := make([]byte, 0, len(s)+5) + buf = append(buf, '"') + + for i := 0; i < len(s); i++ { + if needsEscape[s[i]] { + buf = append(buf, '\\') + } + buf = append(buf, s[i]) + } + + buf = append(buf, '"') + return string(buf) +} + +func SoftQuote(s string) string { + if s == "" { + return "\"\"" + } + + // Handle special case of ~ paths + if len(s) > 0 && s[0] == '~' { + // If it's just ~ or ~/something with no special chars, leave it as is + if len(s) == 1 || (len(s) > 1 && s[1] == '/' && safePattern.MatchString(s[2:])) { + return s + } + + // Otherwise quote everything after the ~ (including the /) + if len(s) > 1 && s[1] == '/' { + return "~" + SoftQuote(s[1:]) + } + } + + if safePattern.MatchString(s) { + return s + } + + buf := make([]byte, 0, len(s)+5) + buf = append(buf, '"') + + for i := 0; i < len(s); i++ { + c := s[i] + // In soft quote, we don't escape $ to allow expansion + if c == '"' || c == '\\' || c == '`' { + buf = append(buf, '\\') + } + buf = append(buf, c) + } + + buf = append(buf, '"') + return string(buf) +} diff --git a/pkg/genconn/quote_test.go b/pkg/genconn/quote_test.go new file mode 100644 index 0000000000..ea25a101dd --- /dev/null +++ b/pkg/genconn/quote_test.go @@ -0,0 +1,110 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 +package genconn + +import "testing" + +func TestQuote(t *testing.T) { + tests := []struct { + name string + input string + wantHard string + wantSoft string + }{ + { + name: "simple strings", + input: "simple", + wantHard: "simple", + wantSoft: "simple", + }, + { + name: "safe path", + input: "path/to/file.txt", + wantHard: "path/to/file.txt", + wantSoft: "path/to/file.txt", + }, + { + name: "empty string", + input: "", + wantHard: `""`, + wantSoft: `""`, + }, + { + name: "tilde alone", + input: "~", + wantHard: `"~"`, + wantSoft: "~", + }, + { + name: "tilde with safe path", + input: "~/foo", + wantHard: `"~/foo"`, + wantSoft: "~/foo", + }, + { + name: "tilde with spaces", + input: "~/foo bar", + wantHard: `"~/foo bar"`, + wantSoft: `~"/foo bar"`, + }, + { + name: "tilde with variable", + input: "~/foo$bar", + wantHard: `"~/foo\$bar"`, + wantSoft: `~"/foo$bar"`, + }, + { + name: "invalid tilde path", + input: "~foo", + wantHard: `"~foo"`, + wantSoft: `"~foo"`, + }, + { + name: "variable at start", + input: "$HOME/.config", + wantHard: `"\$HOME/.config"`, + wantSoft: `"$HOME/.config"`, + }, + { + name: "variable in middle", + input: "prefix$HOME", + wantHard: `"prefix\$HOME"`, + wantSoft: `"prefix$HOME"`, + }, + { + name: "double quotes", + input: `has "quotes"`, + wantHard: `"has \"quotes\""`, + wantSoft: `"has \"quotes\""`, + }, + { + name: "backslash", + input: `back\slash`, + wantHard: `"back\\slash"`, + wantSoft: `"back\\slash"`, + }, + { + name: "backtick", + input: "`cmd`", + wantHard: "\"\\`cmd\\`\"", + wantSoft: "\"\\`cmd\\`\"", + }, + { + name: "spaces", + input: "spaces here", + wantHard: `"spaces here"`, + wantSoft: `"spaces here"`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := HardQuote(tt.input); got != tt.wantHard { + t.Errorf("HardQuote(%q) = %q, want %q", tt.input, got, tt.wantHard) + } + if got := SoftQuote(tt.input); got != tt.wantSoft { + t.Errorf("SoftQuote(%q) = %q, want %q", tt.input, got, tt.wantSoft) + } + }) + } +} diff --git a/pkg/genconn/ssh-impl.go b/pkg/genconn/ssh-impl.go new file mode 100644 index 0000000000..83aae82f2a --- /dev/null +++ b/pkg/genconn/ssh-impl.go @@ -0,0 +1,143 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package genconn + +import ( + "fmt" + "io" + "sync" + + "golang.org/x/crypto/ssh" +) + +type SSHShellClient struct { + client *ssh.Client +} + +func MakeSSHShellClient(client *ssh.Client) *SSHShellClient { + return &SSHShellClient{client: client} +} + +func (c *SSHShellClient) MakeProcessController(cmdSpec CommandSpec) (ShellProcessController, error) { + return MakeSSHCmdClient(c.client, cmdSpec) +} + +// SSHProcessController implements ShellCmd for SSH connections +type SSHProcessController struct { + client *ssh.Client + session *ssh.Session + lock *sync.Mutex + once *sync.Once + stdinPiped bool + stdoutPiped bool + stderrPiped bool + waitErr error + started bool + cmdSpec CommandSpec +} + +// MakeSSHCmdClient creates a new instance of SSHCmdClient +func MakeSSHCmdClient(client *ssh.Client, cmdSpec CommandSpec) (*SSHProcessController, error) { + session, err := client.NewSession() + if err != nil { + return nil, fmt.Errorf("failed to create SSH session: %w", err) + } + return &SSHProcessController{ + client: client, + lock: &sync.Mutex{}, + once: &sync.Once{}, + cmdSpec: cmdSpec, + session: session, + }, nil +} + +// Start begins execution of the command +func (s *SSHProcessController) Start() error { + s.lock.Lock() + defer s.lock.Unlock() + + if s.started { + return fmt.Errorf("command already started") + } + + fullCmd, err := BuildShellCommand(s.cmdSpec) + if err != nil { + return fmt.Errorf("failed to build shell command: %w", err) + } + // if stdout/stderr weren't piped, then session.stdout/stderr will be nil + // and the library guarantees that the outputs will be attached to io.Discard + // if stdin hasn't been piped, then session.stdin will be nil + // and the libary guarantees that it will be attached to an empty bytes.Buffer, which will produce an immediate EOF + // tl;dr we don't need to worry about hanging beause of long input or explicitly closing stdin + if err := s.session.Start(fullCmd); err != nil { + return fmt.Errorf("failed to start command: %w", err) + } + s.started = true + return nil +} + +// Wait waits for the command to complete +func (s *SSHProcessController) Wait() error { + s.once.Do(func() { + s.waitErr = s.session.Wait() + }) + return s.waitErr +} + +// Kill terminates the command +func (s *SSHProcessController) Kill() { + s.lock.Lock() + defer s.lock.Unlock() + + if s.session != nil { + s.session.Close() + } +} + +func (s *SSHProcessController) StdinPipe() (io.WriteCloser, error) { + s.lock.Lock() + defer s.lock.Unlock() + if s.started { + return nil, fmt.Errorf("command already started") + } + if s.stdinPiped { + return nil, fmt.Errorf("stdin already piped") + } + s.stdinPiped = true + return s.session.StdinPipe() +} + +func (s *SSHProcessController) StdoutPipe() (io.Reader, error) { + s.lock.Lock() + defer s.lock.Unlock() + if s.started { + return nil, fmt.Errorf("command already started") + } + if s.stdoutPiped { + return nil, fmt.Errorf("stdout already piped") + } + s.stdoutPiped = true + stdout, err := s.session.StdoutPipe() + if err != nil { + return nil, err + } + return stdout, nil +} + +func (s *SSHProcessController) StderrPipe() (io.Reader, error) { + s.lock.Lock() + defer s.lock.Unlock() + if s.started { + return nil, fmt.Errorf("command already started") + } + if s.stderrPiped { + return nil, fmt.Errorf("stderr already piped") + } + s.stderrPiped = true + stderr, err := s.session.StderrPipe() + if err != nil { + return nil, err + } + return stderr, nil +} diff --git a/pkg/genconn/wsl-impl.go b/pkg/genconn/wsl-impl.go new file mode 100644 index 0000000000..7508bfd24b --- /dev/null +++ b/pkg/genconn/wsl-impl.go @@ -0,0 +1,200 @@ +//go:build windows + +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package genconn + +import ( + "bytes" + "context" + "fmt" + "io" + "sync" + + "github.com/ubuntu/gowsl" +) + +// WSLSimpleCmdClient implements SimpleShellCmd for Ubuntu/WSL connections +type WSLSimpleCmdClient struct { + client *gowsl.Distro +} + +// NewWSLSimpleCmdClient creates a new instance of WSLSimpleCmdClient +func NewWSLSimpleCmdClient(client *gowsl.Distro) *WSLSimpleCmdClient { + return &WSLSimpleCmdClient{client: client} +} + +// Run executes the given shell command with options in the WSL environment +func (w *WSLSimpleCmdClient) Run(ctx context.Context, cmdSpec CommandSpec) (string, string, error) { + if ctx == nil { + return "", "", fmt.Errorf("nil Context") + } + + // Build the shell command using the shared helper + finalCmd, err := BuildShellCommand(cmdSpec) + if err != nil { + return "", "", fmt.Errorf("failed to build shell command: %w", err) + } + + // Create the command with context + cmd := w.client.Command(ctx, finalCmd) + if cmd == nil { + return "", "", fmt.Errorf("failed to create WSL command") + } + + // Create buffers for stdout and stderr + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + // Run the command + if err := cmd.Run(); err != nil { + return stdout.String(), stderr.String(), fmt.Errorf("command failed: %w", err) + } + + return stdout.String(), stderr.String(), nil +} + +// WSLCmdClient implements the ShellCmd interface for WSL +type WSLCmdClient struct { + client *gowsl.Distro + cmd *gowsl.Cmd + lock *sync.Mutex + once *sync.Once + waitErr error + initialized bool + started bool + commandSpec CommandSpec +} + +// NewWSLCmdClient creates a new instance of WSLCmdClient +func NewWSLCmdClient(client *gowsl.Distro) *WSLCmdClient { + return &WSLCmdClient{ + client: client, + lock: &sync.Mutex{}, + once: &sync.Once{}, + } +} + +// Init prepares the command but doesn't start it +func (w *WSLCmdClient) Init(cmd CommandSpec) error { + w.lock.Lock() + defer w.lock.Unlock() + + if w.initialized { + return fmt.Errorf("command already initialized") + } + + finalCmd, err := BuildShellCommand(cmd) + if err != nil { + return fmt.Errorf("failed to build shell command: %w", err) + } + + // Create command without context since we'll manage lifecycle manually + w.cmd = w.client.Command(nil, finalCmd) + if w.cmd == nil { + return fmt.Errorf("failed to create WSL command") + } + + w.commandSpec = cmd + w.initialized = true + return nil +} + +// Start begins execution of the command +func (w *WSLCmdClient) Start() error { + w.lock.Lock() + defer w.lock.Unlock() + + if !w.initialized { + return fmt.Errorf("command not initialized") + } + if w.started { + return fmt.Errorf("command already started") + } + + if err := w.cmd.Start(); err != nil { + return fmt.Errorf("failed to start command: %w", err) + } + + w.started = true + return nil +} + +// Wait waits for the command to complete +func (w *WSLCmdClient) Wait() error { + if !w.initialized { + panic("command not initialized") + } + w.once.Do(func() { + w.waitErr = w.cmd.Wait() + }) + return w.waitErr +} + +// Kill terminates the command +func (w *WSLCmdClient) Kill() { + w.lock.Lock() + defer w.lock.Unlock() + + if w.cmd != nil && w.cmd.Process != nil { + w.cmd.Process.Kill() + } +} + +// ExitCode returns the exit code of the command +func (w *WSLCmdClient) ExitCode() int { + w.lock.Lock() + defer w.lock.Unlock() + + if w.cmd == nil || w.cmd.ProcessState == nil { + return -1 + } + return w.cmd.ProcessState.ExitCode() +} + +// StdinPipe returns a pipe that will be connected to the command's standard input +func (w *WSLCmdClient) StdinPipe() (io.WriteCloser, error) { + w.lock.Lock() + defer w.lock.Unlock() + + if !w.initialized { + return nil, fmt.Errorf("command not initialized") + } + if w.started { + return nil, fmt.Errorf("command already started") + } + + return w.cmd.StdinPipe() +} + +// StdoutPipe returns a pipe that will be connected to the command's standard output +func (w *WSLCmdClient) StdoutPipe() (io.ReadCloser, error) { + w.lock.Lock() + defer w.lock.Unlock() + + if !w.initialized { + return nil, fmt.Errorf("command not initialized") + } + if w.started { + return nil, fmt.Errorf("command already started") + } + + return w.cmd.StdoutPipe() +} + +// StderrPipe returns a pipe that will be connected to the command's standard error +func (w *WSLCmdClient) StderrPipe() (io.ReadCloser, error) { + w.lock.Lock() + defer w.lock.Unlock() + + if !w.initialized { + return nil, fmt.Errorf("command not initialized") + } + if w.started { + return nil, fmt.Errorf("command already started") + } + + return w.cmd.StderrPipe() +} diff --git a/pkg/remote/conncontroller/conncontroller.go b/pkg/remote/conncontroller/conncontroller.go index ece8b8dbb5..e025109541 100644 --- a/pkg/remote/conncontroller/conncontroller.go +++ b/pkg/remote/conncontroller/conncontroller.go @@ -20,6 +20,7 @@ import ( "github.com/kevinburke/ssh_config" "github.com/skeema/knownhosts" + "github.com/wavetermdev/waveterm/pkg/genconn" "github.com/wavetermdev/waveterm/pkg/panichandler" "github.com/wavetermdev/waveterm/pkg/remote" "github.com/wavetermdev/waveterm/pkg/telemetry" @@ -369,16 +370,15 @@ func (conn *SSHConn) CheckAndInstallWsh(ctx context.Context, clientDisplayName s } } log.Printf("attempting to install wsh to `%s`", clientDisplayName) - clientOs, err := remote.GetClientOs(client) + clientOs, clientArch, err := remote.GetClientPlatform(ctx, genconn.MakeSSHShellClient(client)) if err != nil { return err } - clientArch, err := remote.GetClientArch(client) + // attempt to install extension + wshLocalPath, err := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch) if err != nil { return err } - // attempt to install extension - wshLocalPath := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch) err = remote.CpHostToRemote(ctx, client, wshLocalPath, wavebase.RemoteFullWshBinPath) if err != nil { return err diff --git a/pkg/remote/connutil.go b/pkg/remote/connutil.go index e001652539..9d758fe5ff 100644 --- a/pkg/remote/connutil.go +++ b/pkg/remote/connutil.go @@ -16,7 +16,8 @@ import ( "strings" "text/template" - "github.com/wavetermdev/waveterm/pkg/util/utilfn" + "github.com/wavetermdev/waveterm/pkg/genconn" + "github.com/wavetermdev/waveterm/pkg/util/shellutil" "github.com/wavetermdev/waveterm/pkg/wavebase" "golang.org/x/crypto/ssh" ) @@ -140,86 +141,113 @@ func hasBashInstalled(client *ssh.Client) (bool, error) { return false, nil } -func GetClientOs(client *ssh.Client) (string, error) { - session, err := client.NewSession() - if err != nil { - return "", err - } +func normalizeOs(os string) string { + os = strings.ToLower(strings.TrimSpace(os)) + return os +} - out, unixErr := session.CombinedOutput("uname -s") - if unixErr == nil { - formatted := strings.ToLower(string(out)) - formatted = strings.TrimSpace(formatted) - return formatted, nil +func normalizeArch(arch string) string { + arch = strings.ToLower(strings.TrimSpace(arch)) + switch arch { + case "x86_64", "amd64": + arch = "x64" + case "arm64", "aarch64": + arch = "arm64" } + return arch +} - session, err = client.NewSession() +// returns (os, arch, error) +// guaranteed to return a supported platform +func GetClientPlatform(ctx context.Context, shell genconn.ShellClient) (string, string, error) { + stdout, stderr, err := genconn.RunSimpleCommand(ctx, shell, genconn.CommandSpec{ + Cmd: "uname -sm", + }) if err != nil { - return "", err - } - - out, cmdErr := session.Output("echo %OS%") - if cmdErr == nil { - formatted := strings.ToLower(string(out)) - formatted = strings.TrimSpace(formatted) - return strings.Split(formatted, "_")[0], nil + return "", "", fmt.Errorf("error running uname -sm: %w, stderr: %s", err, stderr) } - - session, err = client.NewSession() - if err != nil { - return "", err + // Parse and normalize output + parts := strings.Fields(strings.ToLower(strings.TrimSpace(stdout))) + if len(parts) != 2 { + return "", "", fmt.Errorf("unexpected output from uname: %s", stdout) } - - out, psErr := session.Output("echo $env:OS") - if psErr == nil { - formatted := strings.ToLower(string(out)) - formatted = strings.TrimSpace(formatted) - return strings.Split(formatted, "_")[0], nil + os, arch := normalizeOs(parts[0]), normalizeArch(parts[1]) + if err := wavebase.ValidateWshSupportedArch(os, arch); err != nil { + return "", "", err } - return "", fmt.Errorf("unable to determine os: {unix: %s, cmd: %s, powershell: %s}", unixErr, cmdErr, psErr) + return os, arch, nil } -func GetClientArch(client *ssh.Client) (string, error) { - session, err := client.NewSession() +var installTemplateRawDefault = strings.TrimSpace(` +mkdir -p {{.installDir}} || exit 1 +cat > {{.tempPath}} || exit 1 +mv {{.tempPath}} {{.installPath}} || exit 1 +chmod a+x {{.installPath}} || exit 1 +`) +var installTemplate = template.Must(template.New("wsh-install-template").Parse(installTemplateRawDefault)) + +func CpWshToRemote(ctx context.Context, client *ssh.Client, clientOs string, clientArch string) error { + wshLocalPath, err := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch) if err != nil { - return "", err + return err } - - out, unixErr := session.CombinedOutput("uname -m") - if unixErr == nil { - return utilfn.FilterValidArch(string(out)) + input, err := os.Open(wshLocalPath) + if err != nil { + return fmt.Errorf("cannot open local file %s: %w", wshLocalPath, err) } - - session, err = client.NewSession() + defer input.Close() + installWords := map[string]string{ + "installDir": filepath.ToSlash(filepath.Dir(wavebase.RemoteFullWshBinPath)), + "tempPath": filepath.ToSlash(wavebase.RemoteFullWshBinPath + ".temp"), + "installPath": filepath.ToSlash(wavebase.RemoteFullWshBinPath), + } + var installCmd bytes.Buffer + if err := installTemplate.Execute(&installCmd, installWords); err != nil { + return fmt.Errorf("failed to prepare install command: %w", err) + } + genCmd, err := genconn.MakeSSHCmdClient(client, genconn.CommandSpec{ + Cmd: installCmd.String(), + }) if err != nil { - return "", err + return fmt.Errorf("failed to create remote command: %w", err) } - - out, cmdErr := session.CombinedOutput("echo %PROCESSOR_ARCHITECTURE%") - if cmdErr == nil && strings.TrimSpace(string(out)) != "%PROCESSOR_ARCHITECTURE%" { - return utilfn.FilterValidArch(string(out)) + stdin, err := genCmd.StdinPipe() + if err != nil { + return fmt.Errorf("failed to get stdin pipe: %w", err) } - - session, err = client.NewSession() + defer stdin.Close() + stderrBuf, err := genconn.MakeStderrSyncBuffer(genCmd) if err != nil { - return "", err + return fmt.Errorf("failed to get stderr pipe: %w", err) } - - out, psErr := session.CombinedOutput("echo $env:PROCESSOR_ARCHITECTURE") - if psErr == nil && strings.TrimSpace(string(out)) != "$env:PROCESSOR_ARCHITECTURE" { - return utilfn.FilterValidArch(string(out)) + if err := genCmd.Start(); err != nil { + return fmt.Errorf("failed to start remote command: %w", err) } - return "", fmt.Errorf("unable to determine architecture: {unix: %s, cmd: %s, powershell: %s}", unixErr, cmdErr, psErr) + copyDone := make(chan error, 1) + go func() { + defer close(copyDone) + defer stdin.Close() + if _, err := io.Copy(stdin, input); err != nil && err != io.EOF { + copyDone <- fmt.Errorf("failed to copy data: %w", err) + } else { + copyDone <- nil + } + }() + select { + case <-ctx.Done(): + genCmd.Kill() + return ctx.Err() + case err := <-copyDone: + if err != nil { + return fmt.Errorf("failed to copy data: %w", err) + } + } + if err := genCmd.Wait(); err != nil { + return fmt.Errorf("remote command failed: %w (stderr: %s)", err, stderrBuf.String()) + } + return nil } -var installTemplateRawDefault = strings.TrimSpace(` -mkdir -p {{.installDir}} || exit 1 -cat > {{.tempPath}} || exit 1 -mv {{.tempPath}} {{.installPath}} || exit 1 -chmod a+x {{.installPath}} || exit 1 -`) -var installTemplate = template.Must(template.New("wsh-install-template").Parse(installTemplateRawDefault)) - func CpHostToRemote(ctx context.Context, client *ssh.Client, sourcePath string, destPath string) error { installWords := map[string]string{ "installDir": filepath.ToSlash(filepath.Dir(destPath)), diff --git a/pkg/util/shellutil/shellutil.go b/pkg/util/shellutil/shellutil.go index 175dd59d3a..bff73d6e03 100644 --- a/pkg/util/shellutil/shellutil.go +++ b/pkg/util/shellutil/shellutil.go @@ -219,7 +219,7 @@ func GetZshZDotDir() string { return filepath.Join(wavebase.GetWaveDataDir(), ZshIntegrationDir) } -func GetWshBaseName(version string, goos string, goarch string) string { +func GetWshBinaryPath(version string, goos string, goarch string) (string, error) { ext := "" if goarch == "amd64" { goarch = "x64" @@ -230,11 +230,11 @@ func GetWshBaseName(version string, goos string, goarch string) string { if goos == "windows" { ext = ".exe" } - return fmt.Sprintf("wsh-%s-%s.%s%s", version, goos, goarch, ext) -} - -func GetWshBinaryPath(version string, goos string, goarch string) string { - return filepath.Join(wavebase.GetWaveAppBinPath(), GetWshBaseName(version, goos, goarch)) + if !wavebase.SupportedWshBinaries[fmt.Sprintf("%s-%s", goos, goarch)] { + return "", fmt.Errorf("unsupported wsh platform: %s-%s", goos, goarch) + } + baseName := fmt.Sprintf("wsh-%s-%s.%s%s", version, goos, goarch, ext) + return filepath.Join(wavebase.GetWaveAppBinPath(), baseName), nil } func InitRcFiles(waveHome string, wshBinDir string) error { @@ -308,8 +308,10 @@ func initCustomShellStartupFilesInternal() error { } // copy the correct binary to bin - wshBaseName := GetWshBaseName(wavebase.WaveVersion, runtime.GOOS, runtime.GOARCH) - wshFullPath := GetWshBinaryPath(wavebase.WaveVersion, runtime.GOOS, runtime.GOARCH) + wshFullPath, err := GetWshBinaryPath(wavebase.WaveVersion, runtime.GOOS, runtime.GOARCH) + if err != nil { + log.Printf("error (non-fatal), could not resolve wsh binary path: %v\n", err) + } if _, err := os.Stat(wshFullPath); err != nil { log.Printf("error (non-fatal), could not resolve wsh binary %q: %v\n", wshFullPath, err) return nil @@ -322,6 +324,7 @@ func initCustomShellStartupFilesInternal() error { if err != nil { return fmt.Errorf("error copying wsh binary to bin: %v", err) } + wshBaseName := filepath.Base(wshFullPath) log.Printf("wsh binary successfully copied from %q to %q\n", wshBaseName, wshDstPath) return nil } diff --git a/pkg/util/syncbuf/syncbuf.go b/pkg/util/syncbuf/syncbuf.go new file mode 100644 index 0000000000..bf210acdf6 --- /dev/null +++ b/pkg/util/syncbuf/syncbuf.go @@ -0,0 +1,41 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package syncbuf + +import ( + "bytes" + "io" + "sync" +) + +type SyncBuffer struct { + lock sync.Mutex + buf *bytes.Buffer +} + +func MakeSyncBuffer() *SyncBuffer { + return &SyncBuffer{ + lock: sync.Mutex{}, + buf: new(bytes.Buffer), + } +} + +// spawns a goroutine to copy the reader to the buffer +func MakeSyncBufferFromReader(r io.Reader) *SyncBuffer { + rtn := MakeSyncBuffer() + go io.Copy(rtn, r) + return rtn +} + +func (s *SyncBuffer) Write(p []byte) (n int, err error) { + s.lock.Lock() + defer s.lock.Unlock() + return s.buf.Write(p) +} + +func (s *SyncBuffer) String() string { + s.lock.Lock() + defer s.lock.Unlock() + return s.buf.String() +} diff --git a/pkg/wavebase/wavebase.go b/pkg/wavebase/wavebase.go index ee4a94916c..d9c3e9c2af 100644 --- a/pkg/wavebase/wavebase.go +++ b/pkg/wavebase/wavebase.go @@ -52,6 +52,15 @@ const AppPathBinDir = "bin" var baseLock = &sync.Mutex{} var ensureDirCache = map[string]bool{} +var SupportedWshBinaries = map[string]bool{ + "darwin-x64": true, + "darwin-arm64": true, + "linux-x64": true, + "linux-arm64": true, + "windows-x64": true, + "windows-arm64": true, +} + type FDLock interface { Close() error } @@ -265,3 +274,10 @@ func UnameKernelRelease() string { }) return osRelease } + +func ValidateWshSupportedArch(os string, arch string) error { + if SupportedWshBinaries[fmt.Sprintf("%s-%s", os, arch)] { + return nil + } + return fmt.Errorf("unsupported wsh platform: %s-%s", os, arch) +} diff --git a/pkg/wsl/wsl.go b/pkg/wsl/wsl.go index 673bc1a73f..80b3e59d77 100644 --- a/pkg/wsl/wsl.go +++ b/pkg/wsl/wsl.go @@ -329,7 +329,10 @@ func (conn *WslConn) CheckAndInstallWsh(ctx context.Context, clientDisplayName s return err } // attempt to install extension - wshLocalPath := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch) + wshLocalPath, err := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch) + if err != nil { + return err + } err = CpHostToRemote(ctx, client, wshLocalPath, wavebase.RemoteFullWshBinPath) if err != nil { return err From 198886e6f85b5bc6a75c0d0d03cd6cbeb8406d6c Mon Sep 17 00:00:00 2001 From: sawka Date: Wed, 1 Jan 2025 21:52:06 -0800 Subject: [PATCH 07/11] fix wsl impl for new interface --- pkg/genconn/ssh-impl.go | 2 + pkg/genconn/wsl-impl.go | 164 ++++++++++++++-------------------------- 2 files changed, 57 insertions(+), 109 deletions(-) diff --git a/pkg/genconn/ssh-impl.go b/pkg/genconn/ssh-impl.go index 83aae82f2a..6968e63334 100644 --- a/pkg/genconn/ssh-impl.go +++ b/pkg/genconn/ssh-impl.go @@ -11,6 +11,8 @@ import ( "golang.org/x/crypto/ssh" ) +var _ ShellClient = (*SSHShellClient)(nil) + type SSHShellClient struct { client *ssh.Client } diff --git a/pkg/genconn/wsl-impl.go b/pkg/genconn/wsl-impl.go index 7508bfd24b..a0452fec76 100644 --- a/pkg/genconn/wsl-impl.go +++ b/pkg/genconn/wsl-impl.go @@ -6,8 +6,6 @@ package genconn import ( - "bytes" - "context" "fmt" "io" "sync" @@ -15,101 +13,57 @@ import ( "github.com/ubuntu/gowsl" ) -// WSLSimpleCmdClient implements SimpleShellCmd for Ubuntu/WSL connections -type WSLSimpleCmdClient struct { - client *gowsl.Distro -} +var _ ShellClient = (*WSLShellClient)(nil) -// NewWSLSimpleCmdClient creates a new instance of WSLSimpleCmdClient -func NewWSLSimpleCmdClient(client *gowsl.Distro) *WSLSimpleCmdClient { - return &WSLSimpleCmdClient{client: client} +type WSLShellClient struct { + distro *gowsl.Distro } -// Run executes the given shell command with options in the WSL environment -func (w *WSLSimpleCmdClient) Run(ctx context.Context, cmdSpec CommandSpec) (string, string, error) { - if ctx == nil { - return "", "", fmt.Errorf("nil Context") - } - - // Build the shell command using the shared helper - finalCmd, err := BuildShellCommand(cmdSpec) - if err != nil { - return "", "", fmt.Errorf("failed to build shell command: %w", err) - } - - // Create the command with context - cmd := w.client.Command(ctx, finalCmd) - if cmd == nil { - return "", "", fmt.Errorf("failed to create WSL command") - } - - // Create buffers for stdout and stderr - var stdout, stderr bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr - - // Run the command - if err := cmd.Run(); err != nil { - return stdout.String(), stderr.String(), fmt.Errorf("command failed: %w", err) - } +func MakeWSLShellClient(distro *gowsl.Distro) *WSLShellClient { + return &WSLShellClient{distro: distro} +} - return stdout.String(), stderr.String(), nil +func (c *WSLShellClient) MakeProcessController(cmdSpec CommandSpec) (ShellProcessController, error) { + return MakeWSLProcessController(c.distro, cmdSpec) } -// WSLCmdClient implements the ShellCmd interface for WSL -type WSLCmdClient struct { - client *gowsl.Distro +type WSLProcessController struct { + distro *gowsl.Distro cmd *gowsl.Cmd lock *sync.Mutex once *sync.Once + stdinPiped bool + stdoutPiped bool + stderrPiped bool waitErr error - initialized bool started bool - commandSpec CommandSpec -} - -// NewWSLCmdClient creates a new instance of WSLCmdClient -func NewWSLCmdClient(client *gowsl.Distro) *WSLCmdClient { - return &WSLCmdClient{ - client: client, - lock: &sync.Mutex{}, - once: &sync.Once{}, - } + cmdSpec CommandSpec } -// Init prepares the command but doesn't start it -func (w *WSLCmdClient) Init(cmd CommandSpec) error { - w.lock.Lock() - defer w.lock.Unlock() - - if w.initialized { - return fmt.Errorf("command already initialized") - } - - finalCmd, err := BuildShellCommand(cmd) +func MakeWSLProcessController(distro *gowsl.Distro, cmdSpec CommandSpec) (*WSLProcessController, error) { + fullCmd, err := BuildShellCommand(cmdSpec) if err != nil { - return fmt.Errorf("failed to build shell command: %w", err) + return nil, fmt.Errorf("failed to build shell command: %w", err) } - // Create command without context since we'll manage lifecycle manually - w.cmd = w.client.Command(nil, finalCmd) - if w.cmd == nil { - return fmt.Errorf("failed to create WSL command") + cmd := distro.Command(nil, fullCmd) + if cmd == nil { + return nil, fmt.Errorf("failed to create WSL command") } - w.commandSpec = cmd - w.initialized = true - return nil + return &WSLProcessController{ + distro: distro, + cmd: cmd, + lock: &sync.Mutex{}, + once: &sync.Once{}, + cmdSpec: cmdSpec, + }, nil } -// Start begins execution of the command -func (w *WSLCmdClient) Start() error { +func (w *WSLProcessController) Start() error { w.lock.Lock() defer w.lock.Unlock() - if !w.initialized { - return fmt.Errorf("command not initialized") - } if w.started { return fmt.Errorf("command already started") } @@ -122,19 +76,14 @@ func (w *WSLCmdClient) Start() error { return nil } -// Wait waits for the command to complete -func (w *WSLCmdClient) Wait() error { - if !w.initialized { - panic("command not initialized") - } +func (w *WSLProcessController) Wait() error { w.once.Do(func() { w.waitErr = w.cmd.Wait() }) return w.waitErr } -// Kill terminates the command -func (w *WSLCmdClient) Kill() { +func (w *WSLProcessController) Kill() { w.lock.Lock() defer w.lock.Unlock() @@ -143,58 +92,55 @@ func (w *WSLCmdClient) Kill() { } } -// ExitCode returns the exit code of the command -func (w *WSLCmdClient) ExitCode() int { - w.lock.Lock() - defer w.lock.Unlock() - - if w.cmd == nil || w.cmd.ProcessState == nil { - return -1 - } - return w.cmd.ProcessState.ExitCode() -} - -// StdinPipe returns a pipe that will be connected to the command's standard input -func (w *WSLCmdClient) StdinPipe() (io.WriteCloser, error) { +func (w *WSLProcessController) StdinPipe() (io.WriteCloser, error) { w.lock.Lock() defer w.lock.Unlock() - if !w.initialized { - return nil, fmt.Errorf("command not initialized") - } if w.started { return nil, fmt.Errorf("command already started") } + if w.stdinPiped { + return nil, fmt.Errorf("stdin already piped") + } + w.stdinPiped = true return w.cmd.StdinPipe() } -// StdoutPipe returns a pipe that will be connected to the command's standard output -func (w *WSLCmdClient) StdoutPipe() (io.ReadCloser, error) { +func (w *WSLProcessController) StdoutPipe() (io.Reader, error) { w.lock.Lock() defer w.lock.Unlock() - if !w.initialized { - return nil, fmt.Errorf("command not initialized") - } if w.started { return nil, fmt.Errorf("command already started") } + if w.stdoutPiped { + return nil, fmt.Errorf("stdout already piped") + } - return w.cmd.StdoutPipe() + w.stdoutPiped = true + stdout, err := w.cmd.StdoutPipe() + if err != nil { + return nil, err + } + return stdout, nil } -// StderrPipe returns a pipe that will be connected to the command's standard error -func (w *WSLCmdClient) StderrPipe() (io.ReadCloser, error) { +func (w *WSLProcessController) StderrPipe() (io.Reader, error) { w.lock.Lock() defer w.lock.Unlock() - if !w.initialized { - return nil, fmt.Errorf("command not initialized") - } if w.started { return nil, fmt.Errorf("command already started") } + if w.stderrPiped { + return nil, fmt.Errorf("stderr already piped") + } - return w.cmd.StderrPipe() + w.stderrPiped = true + stderr, err := w.cmd.StderrPipe() + if err != nil { + return nil, err + } + return stderr, nil } From 23d21caf5f8248028a17f1de96c700a24a5741d2 Mon Sep 17 00:00:00 2001 From: sawka Date: Thu, 2 Jan 2025 10:22:27 -0800 Subject: [PATCH 08/11] mark conn updates as in-progress on ROADMAP --- ROADMAP.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ROADMAP.md b/ROADMAP.md index 4853080c0b..c913dc369e 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -29,7 +29,7 @@ Legend: ✅ Done | 🔧 In Progress | 🔷 Planned | 🤞 Stretch Goal - 🔷 Monaco Theming - 🤞 Blockcontroller fixes for terminal escape sequences - 🤞 Explore VSCode Extension Compatibility with standalone Monaco Editor (language servers) -- 🔷 Various Connection Bugs + Improvements +- 🔧 Various Connection Bugs + Improvements - 🔧 More Connection Config Options ## Future Releases From e9ca08a6e4c6aa790913a70c30b5e791050d3433 Mon Sep 17 00:00:00 2001 From: sawka Date: Thu, 2 Jan 2025 12:05:38 -0800 Subject: [PATCH 09/11] new utility function to wait on process/context. use in RunSimple and in the CpWsh func --- pkg/genconn/genconn.go | 17 +++-- pkg/remote/conncontroller/conncontroller.go | 10 +-- pkg/remote/connutil.go | 84 ++------------------- 3 files changed, 17 insertions(+), 94 deletions(-) diff --git a/pkg/genconn/genconn.go b/pkg/genconn/genconn.go index da448510ef..8c503d20b5 100644 --- a/pkg/genconn/genconn.go +++ b/pkg/genconn/genconn.go @@ -70,24 +70,25 @@ func RunSimpleCommand(ctx context.Context, client ShellClient, spec CommandSpec) io.Copy(stderrBuf, stderr) }() + runErr := ProcessContextWait(ctx, proc) + wg.Wait() + + return stdoutBuf.String(), stderrBuf.String(), runErr +} + +func ProcessContextWait(ctx context.Context, proc ShellProcessController) error { done := make(chan error, 1) go func() { done <- proc.Wait() }() - var runErr error select { case <-ctx.Done(): proc.Kill() - runErr = ctx.Err() + return ctx.Err() case err := <-done: - if err != nil { - runErr = fmt.Errorf("process failed: %w", err) - } + return err } - - wg.Wait() - return stdoutBuf.String(), stderrBuf.String(), runErr } func MakeStdoutSyncBuffer(proc ShellProcessController) (*syncbuf.SyncBuffer, error) { diff --git a/pkg/remote/conncontroller/conncontroller.go b/pkg/remote/conncontroller/conncontroller.go index e025109541..df99bd4f02 100644 --- a/pkg/remote/conncontroller/conncontroller.go +++ b/pkg/remote/conncontroller/conncontroller.go @@ -25,7 +25,6 @@ import ( "github.com/wavetermdev/waveterm/pkg/remote" "github.com/wavetermdev/waveterm/pkg/telemetry" "github.com/wavetermdev/waveterm/pkg/userinput" - "github.com/wavetermdev/waveterm/pkg/util/shellutil" "github.com/wavetermdev/waveterm/pkg/util/utilfn" "github.com/wavetermdev/waveterm/pkg/wavebase" "github.com/wavetermdev/waveterm/pkg/waveobj" @@ -374,14 +373,9 @@ func (conn *SSHConn) CheckAndInstallWsh(ctx context.Context, clientDisplayName s if err != nil { return err } - // attempt to install extension - wshLocalPath, err := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch) + err = remote.CpWshToRemote(ctx, client, clientOs, clientArch) if err != nil { - return err - } - err = remote.CpHostToRemote(ctx, client, wshLocalPath, wavebase.RemoteFullWshBinPath) - if err != nil { - return err + return fmt.Errorf("error installing wsh to remote: %w", err) } log.Printf("successfully installed wsh on %s\n", conn.GetName()) return nil diff --git a/pkg/remote/connutil.go b/pkg/remote/connutil.go index 9d758fe5ff..3e76c102ec 100644 --- a/pkg/remote/connutil.go +++ b/pkg/remote/connutil.go @@ -233,86 +233,14 @@ func CpWshToRemote(ctx context.Context, client *ssh.Client, clientOs string, cli copyDone <- nil } }() - select { - case <-ctx.Done(): - genCmd.Kill() - return ctx.Err() - case err := <-copyDone: - if err != nil { - return fmt.Errorf("failed to copy data: %w", err) - } - } - if err := genCmd.Wait(); err != nil { - return fmt.Errorf("remote command failed: %w (stderr: %s)", err, stderrBuf.String()) - } - return nil -} - -func CpHostToRemote(ctx context.Context, client *ssh.Client, sourcePath string, destPath string) error { - installWords := map[string]string{ - "installDir": filepath.ToSlash(filepath.Dir(destPath)), - "tempPath": filepath.ToSlash(destPath + ".temp"), - "installPath": filepath.ToSlash(destPath), - } - - var installCmd bytes.Buffer - if err := installTemplate.Execute(&installCmd, installWords); err != nil { - return fmt.Errorf("failed to prepare install command: %w", err) - } - - session, err := client.NewSession() - if err != nil { - return fmt.Errorf("failed to create SSH session: %w", err) - } - defer session.Close() - - // Add stderr capture - var stderr bytes.Buffer - session.Stderr = &stderr - - stdin, err := session.StdinPipe() - if err != nil { - return fmt.Errorf("failed to get stdin pipe: %w", err) + procErr := genconn.ProcessContextWait(ctx, genCmd) + if procErr != nil { + return fmt.Errorf("remote command failed: %w (stderr: %s)", procErr, stderrBuf.String()) } - - if err := session.Start(installCmd.String()); err != nil { - return fmt.Errorf("failed to start remote command: %w", err) - } - - input, err := os.Open(sourcePath) - if err != nil { - return fmt.Errorf("cannot open local file %s: %w", sourcePath, err) + copyErr := <-copyDone + if copyErr != nil { + return fmt.Errorf("failed to copy data: %w (stderr: %s)", copyErr, stderrBuf.String()) } - defer input.Close() - - copyDone := make(chan error, 1) - - go func() { - defer close(copyDone) - defer stdin.Close() - - _, err := io.Copy(stdin, input) - if err != nil && err != io.EOF { - copyDone <- err - return - } - copyDone <- nil - }() - - select { - case <-ctx.Done(): - session.Close() - return ctx.Err() - case err := <-copyDone: - if err != nil { - return fmt.Errorf("failed to copy data: %w", err) - } - } - - if err := session.Wait(); err != nil { - return fmt.Errorf("remote command failed: %w (stderr: %s)", err, stderr.String()) - } - return nil } From 002e96472b4a918e9df5c2fd6009e31ba5c364db Mon Sep 17 00:00:00 2001 From: sawka Date: Thu, 2 Jan 2025 12:09:19 -0800 Subject: [PATCH 10/11] switch to text/template --- pkg/wsl/wsl-util.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/wsl/wsl-util.go b/pkg/wsl/wsl-util.go index b837d0e3c7..73decd9577 100644 --- a/pkg/wsl/wsl-util.go +++ b/pkg/wsl/wsl-util.go @@ -8,12 +8,12 @@ import ( "context" "errors" "fmt" - "html/template" "io" "log" "os" "path/filepath" "strings" + "text/template" "time" "github.com/wavetermdev/waveterm/pkg/panichandler" From 9b7d4de518fbc8522eacfe4cd6e97f4df490162e Mon Sep 17 00:00:00 2001 From: sawka Date: Thu, 2 Jan 2025 12:19:39 -0800 Subject: [PATCH 11/11] check semver compatibility for installed wsh --- go.mod | 1 + go.sum | 2 ++ pkg/remote/conncontroller/conncontroller.go | 5 +++-- pkg/remote/connutil.go | 15 +++++++++++++-- 4 files changed, 19 insertions(+), 4 deletions(-) diff --git a/go.mod b/go.mod index de542ae694..55f7503542 100644 --- a/go.mod +++ b/go.mod @@ -63,6 +63,7 @@ require ( go.opentelemetry.io/otel/metric v1.29.0 // indirect go.opentelemetry.io/otel/trace v1.29.0 // indirect go.uber.org/atomic v1.7.0 // indirect + golang.org/x/mod v0.22.0 // indirect golang.org/x/net v0.33.0 // indirect golang.org/x/oauth2 v0.24.0 // indirect golang.org/x/sync v0.10.0 // indirect diff --git a/go.sum b/go.sum index 7e137603ee..9379a5a573 100644 --- a/go.sum +++ b/go.sum @@ -133,6 +133,8 @@ go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= +golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4= +golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE= diff --git a/pkg/remote/conncontroller/conncontroller.go b/pkg/remote/conncontroller/conncontroller.go index df99bd4f02..0428ff1d4b 100644 --- a/pkg/remote/conncontroller/conncontroller.go +++ b/pkg/remote/conncontroller/conncontroller.go @@ -33,6 +33,7 @@ import ( "github.com/wavetermdev/waveterm/pkg/wshrpc" "github.com/wavetermdev/waveterm/pkg/wshutil" "golang.org/x/crypto/ssh" + "golang.org/x/mod/semver" ) const ( @@ -315,9 +316,9 @@ func (conn *SSHConn) CheckAndInstallWsh(ctx context.Context, clientDisplayName s return fmt.Errorf("client is nil") } // check that correct wsh extensions are installed - expectedVersion := fmt.Sprintf("wsh v%s", wavebase.WaveVersion) + expectedVersion := fmt.Sprintf("v%s", wavebase.WaveVersion) clientVersion, err := remote.GetWshVersion(client) - if err == nil && clientVersion == expectedVersion && !opts.Force { + if err == nil && !opts.Force && semver.Compare(clientVersion, expectedVersion) >= 0 { return nil } var queryText string diff --git a/pkg/remote/connutil.go b/pkg/remote/connutil.go index 3e76c102ec..c72f0deb01 100644 --- a/pkg/remote/connutil.go +++ b/pkg/remote/connutil.go @@ -20,6 +20,7 @@ import ( "github.com/wavetermdev/waveterm/pkg/util/shellutil" "github.com/wavetermdev/waveterm/pkg/wavebase" "golang.org/x/crypto/ssh" + "golang.org/x/mod/semver" ) var userHostRe = regexp.MustCompile(`^([a-zA-Z0-9][a-zA-Z0-9._@\\-]*@)?([a-zA-Z0-9][a-zA-Z0-9.-]*)(?::([0-9]+))?$`) @@ -54,6 +55,7 @@ func DetectShell(client *ssh.Client) (string, error) { return fmt.Sprintf(`"%s"`, strings.TrimSpace(string(out))), nil } +// returns a valid semver version string func GetWshVersion(client *ssh.Client) (string, error) { wshPath := GetWshPath(client) @@ -66,8 +68,17 @@ func GetWshVersion(client *ssh.Client) (string, error) { if err != nil { return "", err } - - return strings.TrimSpace(string(out)), nil + // output is expected to be in the form of "wsh v0.10.4" + // should strip off the "wsh" prefix, and return a semver object + fields := strings.Fields(strings.TrimSpace(string(out))) + if len(fields) != 2 { + return "", fmt.Errorf("unexpected output from wsh version: %s", out) + } + wshVersion := strings.TrimSpace(fields[1]) + if !semver.IsValid(wshVersion) { + return "", fmt.Errorf("invalid semver version: %s", wshVersion) + } + return wshVersion, nil } func GetWshPath(client *ssh.Client) string {