diff --git a/ROADMAP.md b/ROADMAP.md index 4853080c0b..c913dc369e 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -29,7 +29,7 @@ Legend: ✅ Done | 🔧 In Progress | 🔷 Planned | 🤞 Stretch Goal - 🔷 Monaco Theming - 🤞 Blockcontroller fixes for terminal escape sequences - 🤞 Explore VSCode Extension Compatibility with standalone Monaco Editor (language servers) -- 🔷 Various Connection Bugs + Improvements +- 🔧 Various Connection Bugs + Improvements - 🔧 More Connection Config Options ## Future Releases diff --git a/cmd/wsh/cmd/wshcmd-conn.go b/cmd/wsh/cmd/wshcmd-conn.go index 52d76ef677..3e9cacd90e 100644 --- a/cmd/wsh/cmd/wshcmd-conn.go +++ b/cmd/wsh/cmd/wshcmd-conn.go @@ -87,18 +87,26 @@ func validateConnectionName(name string) error { return nil } -func connStatusRun(cmd *cobra.Command, args []string) error { +func getAllConnStatus() ([]wshrpc.ConnStatus, error) { var allResp []wshrpc.ConnStatus sshResp, err := wshclient.ConnStatusCommand(RpcClient, nil) if err != nil { - return fmt.Errorf("getting ssh connection status: %w", err) + return nil, fmt.Errorf("getting ssh connection status: %w", err) } allResp = append(allResp, sshResp...) wslResp, err := wshclient.WslStatusCommand(RpcClient, nil) if err != nil { - return fmt.Errorf("getting wsl connection status: %w", err) + return nil, fmt.Errorf("getting wsl connection status: %w", err) } allResp = append(allResp, wslResp...) + return allResp, nil +} + +func connStatusRun(cmd *cobra.Command, args []string) error { + allResp, err := getAllConnStatus() + if err != nil { + return err + } if len(allResp) == 0 { WriteStdout("no connections\n") return nil @@ -142,21 +150,19 @@ func connDisconnectRun(cmd *cobra.Command, args []string) error { } func connDisconnectAllRun(cmd *cobra.Command, args []string) error { - resp, err := wshclient.ConnStatusCommand(RpcClient, nil) + allConns, err := getAllConnStatus() if err != nil { - return fmt.Errorf("getting connection status: %w", err) - } - if len(resp) == 0 { - return nil + return err } - for _, conn := range resp { - if conn.Status == "connected" { - err := wshclient.ConnDisconnectCommand(RpcClient, conn.Connection, &wshrpc.RpcOpts{Timeout: 10000}) - if err != nil { - WriteStdout("error disconnecting %q: %v\n", conn.Connection, err) - } else { - WriteStdout("disconnected %q\n", conn.Connection) - } + for _, conn := range allConns { + if conn.Status != "connected" { + continue + } + err := wshclient.ConnDisconnectCommand(RpcClient, conn.Connection, &wshrpc.RpcOpts{Timeout: 10000}) + if err != nil { + WriteStdout("error disconnecting %q: %v\n", conn.Connection, err) + } else { + WriteStdout("disconnected %q\n", conn.Connection) } } return nil diff --git a/frontend/app/store/keymodel.ts b/frontend/app/store/keymodel.ts index cadbc8a43a..d6ec2a1c5a 100644 --- a/frontend/app/store/keymodel.ts +++ b/frontend/app/store/keymodel.ts @@ -372,7 +372,6 @@ function registerGlobalKeys() { return false; } globalKeyMap.set("Cmd:f", activateSearch); - globalKeyMap.set("Ctrl:f", activateSearch); globalKeyMap.set("Escape", deactivateSearch); const allKeys = Array.from(globalKeyMap.keys()); // special case keys, handled by web view diff --git a/go.mod b/go.mod index de542ae694..55f7503542 100644 --- a/go.mod +++ b/go.mod @@ -63,6 +63,7 @@ require ( go.opentelemetry.io/otel/metric v1.29.0 // indirect go.opentelemetry.io/otel/trace v1.29.0 // indirect go.uber.org/atomic v1.7.0 // indirect + golang.org/x/mod v0.22.0 // indirect golang.org/x/net v0.33.0 // indirect golang.org/x/oauth2 v0.24.0 // indirect golang.org/x/sync v0.10.0 // indirect diff --git a/go.sum b/go.sum index 7e137603ee..9379a5a573 100644 --- a/go.sum +++ b/go.sum @@ -133,6 +133,8 @@ go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= +golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4= +golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE= diff --git a/pkg/blockcontroller/blockcontroller.go b/pkg/blockcontroller/blockcontroller.go index d35f47ecc3..c5afeb5c8f 100644 --- a/pkg/blockcontroller/blockcontroller.go +++ b/pkg/blockcontroller/blockcontroller.go @@ -273,27 +273,34 @@ func createCmdStrAndOpts(blockId string, blockMeta waveobj.MetaMapType) (string, } func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj.MetaMapType) error { + shellProc, err := bc.setupAndStartShellProcess(rc, blockMeta) + if err != nil { + return err + } + return bc.manageRunningShellProcess(shellProc, rc, blockMeta) +} + +func (bc *BlockController) setupAndStartShellProcess(rc *RunShellOpts, blockMeta waveobj.MetaMapType) (*shellexec.ShellProc, error) { // create a circular blockfile for the output ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second) defer cancelFn() - err := filestore.WFS.MakeFile(ctx, bc.BlockId, BlockFile_Term, nil, filestore.FileOptsType{MaxSize: DefaultTermMaxFileSize, Circular: true}) - if err != nil && err != fs.ErrExist { - err = fs.ErrExist - return fmt.Errorf("error creating blockfile: %w", err) + fsErr := filestore.WFS.MakeFile(ctx, bc.BlockId, BlockFile_Term, nil, filestore.FileOptsType{MaxSize: DefaultTermMaxFileSize, Circular: true}) + if fsErr != nil && fsErr != fs.ErrExist { + return nil, fmt.Errorf("error creating blockfile: %w", fsErr) } - if err == fs.ErrExist { + if fsErr == fs.ErrExist { // reset the terminal state bc.resetTerminalState() } - err = nil bcInitStatus := bc.GetRuntimeStatus() if bcInitStatus.ShellProcStatus == Status_Running { - return nil + return nil, nil } // TODO better sync here (don't let two starts happen at the same times) remoteName := blockMeta.GetString(waveobj.MetaKey_Connection, "") var cmdStr string var cmdOpts shellexec.CommandOptsType + var err error if bc.ControllerType == BlockController_Shell { cmdOpts.Env = make(map[string]string) cmdOpts.Interactive = true @@ -302,7 +309,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj if cmdOpts.Cwd != "" { cwdPath, err := wavebase.ExpandHomeDir(cmdOpts.Cwd) if err != nil { - return err + return nil, err } cmdOpts.Cwd = cwdPath } @@ -310,11 +317,11 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj var cmdOptsPtr *shellexec.CommandOptsType cmdStr, cmdOptsPtr, err = createCmdStrAndOpts(bc.BlockId, blockMeta) if err != nil { - return err + return nil, err } cmdOpts = *cmdOptsPtr } else { - return fmt.Errorf("unknown controller type %q", bc.ControllerType) + return nil, fmt.Errorf("unknown controller type %q", bc.ControllerType) } var shellProc *shellexec.ShellProc if strings.HasPrefix(remoteName, "wsl://") { @@ -325,20 +332,20 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj wslConn := wsl.GetWslConn(credentialCtx, wslName, false) connStatus := wslConn.DeriveConnStatus() if connStatus.Status != conncontroller.Status_Connected { - return fmt.Errorf("not connected, cannot start shellproc") + return nil, fmt.Errorf("not connected, cannot start shellproc") } // create jwt if !blockMeta.GetBool(waveobj.MetaKey_CmdNoWsh, false) { jwtStr, err := wshutil.MakeClientJWTToken(wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId, Conn: wslConn.GetName()}, wslConn.GetDomainSocketName()) if err != nil { - return fmt.Errorf("error making jwt token: %w", err) + return nil, fmt.Errorf("error making jwt token: %w", err) } cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr } shellProc, err = shellexec.StartWslShellProc(ctx, rc.TermSize, cmdStr, cmdOpts, wslConn) if err != nil { - return err + return nil, err } } else if remoteName != "" { credentialCtx, cancelFunc := context.WithTimeout(context.Background(), 60*time.Second) @@ -346,24 +353,24 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj opts, err := remote.ParseOpts(remoteName) if err != nil { - return err + return nil, err } conn := conncontroller.GetConn(credentialCtx, opts, false, &wshrpc.ConnKeywords{}) connStatus := conn.DeriveConnStatus() if connStatus.Status != conncontroller.Status_Connected { - return fmt.Errorf("not connected, cannot start shellproc") + return nil, fmt.Errorf("not connected, cannot start shellproc") } if !blockMeta.GetBool(waveobj.MetaKey_CmdNoWsh, false) { jwtStr, err := wshutil.MakeClientJWTToken(wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId, Conn: conn.Opts.String()}, conn.GetDomainSocketName()) if err != nil { - return fmt.Errorf("error making jwt token: %w", err) + return nil, fmt.Errorf("error making jwt token: %w", err) } cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr } if !conn.WshEnabled.Load() { shellProc, err = shellexec.StartRemoteShellProcNoWsh(rc.TermSize, cmdStr, cmdOpts, conn) if err != nil { - return err + return nil, err } } else { shellProc, err = shellexec.StartRemoteShellProc(rc.TermSize, cmdStr, cmdOpts, conn) @@ -376,19 +383,16 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj log.Print("attempting install without wsh") shellProc, err = shellexec.StartRemoteShellProcNoWsh(rc.TermSize, cmdStr, cmdOpts, conn) if err != nil { - return err + return nil, err } } } - if err != nil { - return err - } } else { // local terminal if !blockMeta.GetBool(waveobj.MetaKey_CmdNoWsh, false) { jwtStr, err := wshutil.MakeClientJWTToken(wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId}, wavebase.GetDomainSocketName()) if err != nil { - return fmt.Errorf("error making jwt token: %w", err) + return nil, fmt.Errorf("error making jwt token: %w", err) } cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr } @@ -407,7 +411,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj } shellProc, err = shellexec.StartShellProc(rc.TermSize, cmdStr, cmdOpts) if err != nil { - return err + return nil, err } } bc.UpdateControllerAndSendUpdate(func() bool { @@ -415,6 +419,10 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj bc.ShellProcStatus = Status_Running return true }) + return shellProc, nil +} + +func (bc *BlockController) manageRunningShellProcess(shellProc *shellexec.ShellProc, rc *RunShellOpts, blockMeta waveobj.MetaMapType) error { shellInputCh := make(chan *BlockInputUnion, 32) bc.ShellInputCh = shellInputCh @@ -473,14 +481,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj shellProc.Cmd.Write(ic.InputData) } if ic.TermSize != nil { - err = setTermSize(ctx, bc.BlockId, *ic.TermSize) - if err != nil { - log.Printf("error setting pty size: %v\n", err) - } - err = shellProc.Cmd.SetSize(ic.TermSize.Rows, ic.TermSize.Cols) - if err != nil { - log.Printf("error setting pty size: %v\n", err) - } + updateTermSize(shellProc, bc.BlockId, *ic.TermSize) } } }() @@ -522,6 +523,17 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj return nil } +func updateTermSize(shellProc *shellexec.ShellProc, blockId string, termSize waveobj.TermSize) { + err := setTermSizeInDB(blockId, termSize) + if err != nil { + log.Printf("error setting pty size: %v\n", err) + } + err = shellProc.Cmd.SetSize(termSize.Rows, termSize.Cols) + if err != nil { + log.Printf("error setting pty size: %v\n", err) + } +} + func checkCloseOnExit(blockId string, exitCode int) { ctx, cancelFn := context.WithTimeout(context.Background(), DefaultTimeout) defer cancelFn() @@ -569,16 +581,22 @@ func getTermSize(bdata *waveobj.Block) waveobj.TermSize { } } -func setTermSize(ctx context.Context, blockId string, termSize waveobj.TermSize) error { +func setTermSizeInDB(blockId string, termSize waveobj.TermSize) error { + ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second) + defer cancelFn() ctx = waveobj.ContextWithUpdates(ctx) - bdata, err := wstore.DBMustGet[*waveobj.Block](context.Background(), blockId) + bdata, err := wstore.DBMustGet[*waveobj.Block](ctx, blockId) if err != nil { return fmt.Errorf("error getting block data: %v", err) } if bdata.RuntimeOpts == nil { - return fmt.Errorf("error from nil RuntimeOpts: %v", err) + bdata.RuntimeOpts = &waveobj.RuntimeOpts{} } bdata.RuntimeOpts.TermSize = termSize + err = wstore.DBUpdate(ctx, bdata) + if err != nil { + return fmt.Errorf("error updating block data: %v", err) + } updates := waveobj.ContextGetUpdatesRtn(ctx) wps.Broker.SendUpdateEvents(updates) return nil diff --git a/pkg/genconn/genconn.go b/pkg/genconn/genconn.go new file mode 100644 index 0000000000..8c503d20b5 --- /dev/null +++ b/pkg/genconn/genconn.go @@ -0,0 +1,133 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +// generic connection code (WSL + SSH) +package genconn + +import ( + "context" + "fmt" + "io" + "regexp" + "strings" + "sync" + + "github.com/wavetermdev/waveterm/pkg/util/syncbuf" +) + +type CommandSpec struct { + Cmd string + Env map[string]string + Cwd string +} + +type ShellClient interface { + MakeProcessController(cmd CommandSpec) (ShellProcessController, error) +} + +type ShellProcessController interface { + Start() error + Wait() error + Kill() + + // these are not required to be called, if they are not called, the impl will set to discard output + StdinPipe() (io.WriteCloser, error) + StdoutPipe() (io.Reader, error) + StderrPipe() (io.Reader, error) +} + +func RunSimpleCommand(ctx context.Context, client ShellClient, spec CommandSpec) (string, string, error) { + proc, err := client.MakeProcessController(spec) + if err != nil { + return "", "", fmt.Errorf("failed to create process controller: %w", err) + } + + stdout, err := proc.StdoutPipe() + if err != nil { + return "", "", fmt.Errorf("failed to get stdout pipe: %w", err) + } + stderr, err := proc.StderrPipe() + if err != nil { + return "", "", fmt.Errorf("failed to get stderr pipe: %w", err) + } + + if err := proc.Start(); err != nil { + return "", "", fmt.Errorf("failed to start process: %w", err) + } + + stdoutBuf := syncbuf.MakeSyncBuffer() + stderrBuf := syncbuf.MakeSyncBuffer() + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + io.Copy(stdoutBuf, stdout) + }() + + go func() { + defer wg.Done() + io.Copy(stderrBuf, stderr) + }() + + runErr := ProcessContextWait(ctx, proc) + wg.Wait() + + return stdoutBuf.String(), stderrBuf.String(), runErr +} + +func ProcessContextWait(ctx context.Context, proc ShellProcessController) error { + done := make(chan error, 1) + go func() { + done <- proc.Wait() + }() + + select { + case <-ctx.Done(): + proc.Kill() + return ctx.Err() + case err := <-done: + return err + } +} + +func MakeStdoutSyncBuffer(proc ShellProcessController) (*syncbuf.SyncBuffer, error) { + stdout, err := proc.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("failed to get stdout pipe: %w", err) + } + return syncbuf.MakeSyncBufferFromReader(stdout), nil +} + +func MakeStderrSyncBuffer(proc ShellProcessController) (*syncbuf.SyncBuffer, error) { + stderr, err := proc.StderrPipe() + if err != nil { + return nil, fmt.Errorf("failed to get stderr pipe: %w", err) + } + return syncbuf.MakeSyncBufferFromReader(stderr), nil +} + +func BuildShellCommand(opts CommandSpec) (string, error) { + // Build environment variables + var envVars strings.Builder + for key, value := range opts.Env { + if !isValidEnvVarName(key) { + return "", fmt.Errorf("invalid environment variable name: %q", key) + } + envVars.WriteString(fmt.Sprintf("%s=%s ", key, HardQuote(value))) + } + + // Build the command + shellCmd := opts.Cmd + if opts.Cwd != "" { + shellCmd = fmt.Sprintf("cd %s && %s", HardQuote(opts.Cwd), shellCmd) + } + + // Quote the command for `sh -c` + return fmt.Sprintf("sh -c %s", HardQuote(envVars.String()+shellCmd)), nil +} + +func isValidEnvVarName(name string) bool { + validEnvVarName := regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) + return validEnvVarName.MatchString(name) +} diff --git a/pkg/genconn/quote.go b/pkg/genconn/quote.go new file mode 100644 index 0000000000..8bb59732f7 --- /dev/null +++ b/pkg/genconn/quote.go @@ -0,0 +1,78 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package genconn + +import "regexp" + +var ( + safePattern = regexp.MustCompile(`^[a-zA-Z0-9_/.-]+$`) + + needsEscape = map[byte]bool{ + '"': true, + '\\': true, + '$': true, + '`': true, + } +) + +func HardQuote(s string) string { + if s == "" { + return "\"\"" + } + + if safePattern.MatchString(s) { + return s + } + + buf := make([]byte, 0, len(s)+5) + buf = append(buf, '"') + + for i := 0; i < len(s); i++ { + if needsEscape[s[i]] { + buf = append(buf, '\\') + } + buf = append(buf, s[i]) + } + + buf = append(buf, '"') + return string(buf) +} + +func SoftQuote(s string) string { + if s == "" { + return "\"\"" + } + + // Handle special case of ~ paths + if len(s) > 0 && s[0] == '~' { + // If it's just ~ or ~/something with no special chars, leave it as is + if len(s) == 1 || (len(s) > 1 && s[1] == '/' && safePattern.MatchString(s[2:])) { + return s + } + + // Otherwise quote everything after the ~ (including the /) + if len(s) > 1 && s[1] == '/' { + return "~" + SoftQuote(s[1:]) + } + } + + if safePattern.MatchString(s) { + return s + } + + buf := make([]byte, 0, len(s)+5) + buf = append(buf, '"') + + for i := 0; i < len(s); i++ { + c := s[i] + // In soft quote, we don't escape $ to allow expansion + if c == '"' || c == '\\' || c == '`' { + buf = append(buf, '\\') + } + buf = append(buf, c) + } + + buf = append(buf, '"') + return string(buf) +} diff --git a/pkg/genconn/quote_test.go b/pkg/genconn/quote_test.go new file mode 100644 index 0000000000..ea25a101dd --- /dev/null +++ b/pkg/genconn/quote_test.go @@ -0,0 +1,110 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 +package genconn + +import "testing" + +func TestQuote(t *testing.T) { + tests := []struct { + name string + input string + wantHard string + wantSoft string + }{ + { + name: "simple strings", + input: "simple", + wantHard: "simple", + wantSoft: "simple", + }, + { + name: "safe path", + input: "path/to/file.txt", + wantHard: "path/to/file.txt", + wantSoft: "path/to/file.txt", + }, + { + name: "empty string", + input: "", + wantHard: `""`, + wantSoft: `""`, + }, + { + name: "tilde alone", + input: "~", + wantHard: `"~"`, + wantSoft: "~", + }, + { + name: "tilde with safe path", + input: "~/foo", + wantHard: `"~/foo"`, + wantSoft: "~/foo", + }, + { + name: "tilde with spaces", + input: "~/foo bar", + wantHard: `"~/foo bar"`, + wantSoft: `~"/foo bar"`, + }, + { + name: "tilde with variable", + input: "~/foo$bar", + wantHard: `"~/foo\$bar"`, + wantSoft: `~"/foo$bar"`, + }, + { + name: "invalid tilde path", + input: "~foo", + wantHard: `"~foo"`, + wantSoft: `"~foo"`, + }, + { + name: "variable at start", + input: "$HOME/.config", + wantHard: `"\$HOME/.config"`, + wantSoft: `"$HOME/.config"`, + }, + { + name: "variable in middle", + input: "prefix$HOME", + wantHard: `"prefix\$HOME"`, + wantSoft: `"prefix$HOME"`, + }, + { + name: "double quotes", + input: `has "quotes"`, + wantHard: `"has \"quotes\""`, + wantSoft: `"has \"quotes\""`, + }, + { + name: "backslash", + input: `back\slash`, + wantHard: `"back\\slash"`, + wantSoft: `"back\\slash"`, + }, + { + name: "backtick", + input: "`cmd`", + wantHard: "\"\\`cmd\\`\"", + wantSoft: "\"\\`cmd\\`\"", + }, + { + name: "spaces", + input: "spaces here", + wantHard: `"spaces here"`, + wantSoft: `"spaces here"`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := HardQuote(tt.input); got != tt.wantHard { + t.Errorf("HardQuote(%q) = %q, want %q", tt.input, got, tt.wantHard) + } + if got := SoftQuote(tt.input); got != tt.wantSoft { + t.Errorf("SoftQuote(%q) = %q, want %q", tt.input, got, tt.wantSoft) + } + }) + } +} diff --git a/pkg/genconn/ssh-impl.go b/pkg/genconn/ssh-impl.go new file mode 100644 index 0000000000..6968e63334 --- /dev/null +++ b/pkg/genconn/ssh-impl.go @@ -0,0 +1,145 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package genconn + +import ( + "fmt" + "io" + "sync" + + "golang.org/x/crypto/ssh" +) + +var _ ShellClient = (*SSHShellClient)(nil) + +type SSHShellClient struct { + client *ssh.Client +} + +func MakeSSHShellClient(client *ssh.Client) *SSHShellClient { + return &SSHShellClient{client: client} +} + +func (c *SSHShellClient) MakeProcessController(cmdSpec CommandSpec) (ShellProcessController, error) { + return MakeSSHCmdClient(c.client, cmdSpec) +} + +// SSHProcessController implements ShellCmd for SSH connections +type SSHProcessController struct { + client *ssh.Client + session *ssh.Session + lock *sync.Mutex + once *sync.Once + stdinPiped bool + stdoutPiped bool + stderrPiped bool + waitErr error + started bool + cmdSpec CommandSpec +} + +// MakeSSHCmdClient creates a new instance of SSHCmdClient +func MakeSSHCmdClient(client *ssh.Client, cmdSpec CommandSpec) (*SSHProcessController, error) { + session, err := client.NewSession() + if err != nil { + return nil, fmt.Errorf("failed to create SSH session: %w", err) + } + return &SSHProcessController{ + client: client, + lock: &sync.Mutex{}, + once: &sync.Once{}, + cmdSpec: cmdSpec, + session: session, + }, nil +} + +// Start begins execution of the command +func (s *SSHProcessController) Start() error { + s.lock.Lock() + defer s.lock.Unlock() + + if s.started { + return fmt.Errorf("command already started") + } + + fullCmd, err := BuildShellCommand(s.cmdSpec) + if err != nil { + return fmt.Errorf("failed to build shell command: %w", err) + } + // if stdout/stderr weren't piped, then session.stdout/stderr will be nil + // and the library guarantees that the outputs will be attached to io.Discard + // if stdin hasn't been piped, then session.stdin will be nil + // and the libary guarantees that it will be attached to an empty bytes.Buffer, which will produce an immediate EOF + // tl;dr we don't need to worry about hanging beause of long input or explicitly closing stdin + if err := s.session.Start(fullCmd); err != nil { + return fmt.Errorf("failed to start command: %w", err) + } + s.started = true + return nil +} + +// Wait waits for the command to complete +func (s *SSHProcessController) Wait() error { + s.once.Do(func() { + s.waitErr = s.session.Wait() + }) + return s.waitErr +} + +// Kill terminates the command +func (s *SSHProcessController) Kill() { + s.lock.Lock() + defer s.lock.Unlock() + + if s.session != nil { + s.session.Close() + } +} + +func (s *SSHProcessController) StdinPipe() (io.WriteCloser, error) { + s.lock.Lock() + defer s.lock.Unlock() + if s.started { + return nil, fmt.Errorf("command already started") + } + if s.stdinPiped { + return nil, fmt.Errorf("stdin already piped") + } + s.stdinPiped = true + return s.session.StdinPipe() +} + +func (s *SSHProcessController) StdoutPipe() (io.Reader, error) { + s.lock.Lock() + defer s.lock.Unlock() + if s.started { + return nil, fmt.Errorf("command already started") + } + if s.stdoutPiped { + return nil, fmt.Errorf("stdout already piped") + } + s.stdoutPiped = true + stdout, err := s.session.StdoutPipe() + if err != nil { + return nil, err + } + return stdout, nil +} + +func (s *SSHProcessController) StderrPipe() (io.Reader, error) { + s.lock.Lock() + defer s.lock.Unlock() + if s.started { + return nil, fmt.Errorf("command already started") + } + if s.stderrPiped { + return nil, fmt.Errorf("stderr already piped") + } + s.stderrPiped = true + stderr, err := s.session.StderrPipe() + if err != nil { + return nil, err + } + return stderr, nil +} diff --git a/pkg/genconn/wsl-impl.go b/pkg/genconn/wsl-impl.go new file mode 100644 index 0000000000..a0452fec76 --- /dev/null +++ b/pkg/genconn/wsl-impl.go @@ -0,0 +1,146 @@ +//go:build windows + +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package genconn + +import ( + "fmt" + "io" + "sync" + + "github.com/ubuntu/gowsl" +) + +var _ ShellClient = (*WSLShellClient)(nil) + +type WSLShellClient struct { + distro *gowsl.Distro +} + +func MakeWSLShellClient(distro *gowsl.Distro) *WSLShellClient { + return &WSLShellClient{distro: distro} +} + +func (c *WSLShellClient) MakeProcessController(cmdSpec CommandSpec) (ShellProcessController, error) { + return MakeWSLProcessController(c.distro, cmdSpec) +} + +type WSLProcessController struct { + distro *gowsl.Distro + cmd *gowsl.Cmd + lock *sync.Mutex + once *sync.Once + stdinPiped bool + stdoutPiped bool + stderrPiped bool + waitErr error + started bool + cmdSpec CommandSpec +} + +func MakeWSLProcessController(distro *gowsl.Distro, cmdSpec CommandSpec) (*WSLProcessController, error) { + fullCmd, err := BuildShellCommand(cmdSpec) + if err != nil { + return nil, fmt.Errorf("failed to build shell command: %w", err) + } + + cmd := distro.Command(nil, fullCmd) + if cmd == nil { + return nil, fmt.Errorf("failed to create WSL command") + } + + return &WSLProcessController{ + distro: distro, + cmd: cmd, + lock: &sync.Mutex{}, + once: &sync.Once{}, + cmdSpec: cmdSpec, + }, nil +} + +func (w *WSLProcessController) Start() error { + w.lock.Lock() + defer w.lock.Unlock() + + if w.started { + return fmt.Errorf("command already started") + } + + if err := w.cmd.Start(); err != nil { + return fmt.Errorf("failed to start command: %w", err) + } + + w.started = true + return nil +} + +func (w *WSLProcessController) Wait() error { + w.once.Do(func() { + w.waitErr = w.cmd.Wait() + }) + return w.waitErr +} + +func (w *WSLProcessController) Kill() { + w.lock.Lock() + defer w.lock.Unlock() + + if w.cmd != nil && w.cmd.Process != nil { + w.cmd.Process.Kill() + } +} + +func (w *WSLProcessController) StdinPipe() (io.WriteCloser, error) { + w.lock.Lock() + defer w.lock.Unlock() + + if w.started { + return nil, fmt.Errorf("command already started") + } + if w.stdinPiped { + return nil, fmt.Errorf("stdin already piped") + } + + w.stdinPiped = true + return w.cmd.StdinPipe() +} + +func (w *WSLProcessController) StdoutPipe() (io.Reader, error) { + w.lock.Lock() + defer w.lock.Unlock() + + if w.started { + return nil, fmt.Errorf("command already started") + } + if w.stdoutPiped { + return nil, fmt.Errorf("stdout already piped") + } + + w.stdoutPiped = true + stdout, err := w.cmd.StdoutPipe() + if err != nil { + return nil, err + } + return stdout, nil +} + +func (w *WSLProcessController) StderrPipe() (io.Reader, error) { + w.lock.Lock() + defer w.lock.Unlock() + + if w.started { + return nil, fmt.Errorf("command already started") + } + if w.stderrPiped { + return nil, fmt.Errorf("stderr already piped") + } + + w.stderrPiped = true + stderr, err := w.cmd.StderrPipe() + if err != nil { + return nil, err + } + return stderr, nil +} diff --git a/pkg/remote/conncontroller/conncontroller.go b/pkg/remote/conncontroller/conncontroller.go index c769f7f7ee..0428ff1d4b 100644 --- a/pkg/remote/conncontroller/conncontroller.go +++ b/pkg/remote/conncontroller/conncontroller.go @@ -20,11 +20,11 @@ import ( "github.com/kevinburke/ssh_config" "github.com/skeema/knownhosts" + "github.com/wavetermdev/waveterm/pkg/genconn" "github.com/wavetermdev/waveterm/pkg/panichandler" "github.com/wavetermdev/waveterm/pkg/remote" "github.com/wavetermdev/waveterm/pkg/telemetry" "github.com/wavetermdev/waveterm/pkg/userinput" - "github.com/wavetermdev/waveterm/pkg/util/shellutil" "github.com/wavetermdev/waveterm/pkg/util/utilfn" "github.com/wavetermdev/waveterm/pkg/wavebase" "github.com/wavetermdev/waveterm/pkg/waveobj" @@ -33,6 +33,7 @@ import ( "github.com/wavetermdev/waveterm/pkg/wshrpc" "github.com/wavetermdev/waveterm/pkg/wshutil" "golang.org/x/crypto/ssh" + "golang.org/x/mod/semver" ) const ( @@ -315,9 +316,9 @@ func (conn *SSHConn) CheckAndInstallWsh(ctx context.Context, clientDisplayName s return fmt.Errorf("client is nil") } // check that correct wsh extensions are installed - expectedVersion := fmt.Sprintf("wsh v%s", wavebase.WaveVersion) + expectedVersion := fmt.Sprintf("v%s", wavebase.WaveVersion) clientVersion, err := remote.GetWshVersion(client) - if err == nil && clientVersion == expectedVersion && !opts.Force { + if err == nil && !opts.Force && semver.Compare(clientVersion, expectedVersion) >= 0 { return nil } var queryText string @@ -369,19 +370,13 @@ func (conn *SSHConn) CheckAndInstallWsh(ctx context.Context, clientDisplayName s } } log.Printf("attempting to install wsh to `%s`", clientDisplayName) - clientOs, err := remote.GetClientOs(client) + clientOs, clientArch, err := remote.GetClientPlatform(ctx, genconn.MakeSSHShellClient(client)) if err != nil { return err } - clientArch, err := remote.GetClientArch(client) + err = remote.CpWshToRemote(ctx, client, clientOs, clientArch) if err != nil { - return err - } - // attempt to install extension - wshLocalPath := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch) - err = remote.CpHostToRemote(client, wshLocalPath, wavebase.RemoteFullWshBinPath) - if err != nil { - return err + return fmt.Errorf("error installing wsh to remote: %w", err) } log.Printf("successfully installed wsh on %s\n", conn.GetName()) return nil diff --git a/pkg/remote/connutil.go b/pkg/remote/connutil.go index 90707878db..c72f0deb01 100644 --- a/pkg/remote/connutil.go +++ b/pkg/remote/connutil.go @@ -5,8 +5,8 @@ package remote import ( "bytes" + "context" "fmt" - "html/template" "io" "log" "os" @@ -14,11 +14,13 @@ import ( "path/filepath" "regexp" "strings" + "text/template" - "github.com/wavetermdev/waveterm/pkg/panichandler" - "github.com/wavetermdev/waveterm/pkg/util/utilfn" + "github.com/wavetermdev/waveterm/pkg/genconn" + "github.com/wavetermdev/waveterm/pkg/util/shellutil" "github.com/wavetermdev/waveterm/pkg/wavebase" "golang.org/x/crypto/ssh" + "golang.org/x/mod/semver" ) var userHostRe = regexp.MustCompile(`^([a-zA-Z0-9][a-zA-Z0-9._@\\-]*@)?([a-zA-Z0-9][a-zA-Z0-9.-]*)(?::([0-9]+))?$`) @@ -53,6 +55,7 @@ func DetectShell(client *ssh.Client) (string, error) { return fmt.Sprintf(`"%s"`, strings.TrimSpace(string(out))), nil } +// returns a valid semver version string func GetWshVersion(client *ssh.Client) (string, error) { wshPath := GetWshPath(client) @@ -65,8 +68,17 @@ func GetWshVersion(client *ssh.Client) (string, error) { if err != nil { return "", err } - - return strings.TrimSpace(string(out)), nil + // output is expected to be in the form of "wsh v0.10.4" + // should strip off the "wsh" prefix, and return a semver object + fields := strings.Fields(strings.TrimSpace(string(out))) + if len(fields) != 2 { + return "", fmt.Errorf("unexpected output from wsh version: %s", out) + } + wshVersion := strings.TrimSpace(fields[1]) + if !semver.IsValid(wshVersion) { + return "", fmt.Errorf("invalid semver version: %s", wshVersion) + } + return wshVersion, nil } func GetWshPath(client *ssh.Client) string { @@ -140,148 +152,107 @@ func hasBashInstalled(client *ssh.Client) (bool, error) { return false, nil } -func GetClientOs(client *ssh.Client) (string, error) { - session, err := client.NewSession() - if err != nil { - return "", err - } - - out, unixErr := session.CombinedOutput("uname -s") - if unixErr == nil { - formatted := strings.ToLower(string(out)) - formatted = strings.TrimSpace(formatted) - return formatted, nil - } - - session, err = client.NewSession() - if err != nil { - return "", err - } - - out, cmdErr := session.Output("echo %OS%") - if cmdErr == nil { - formatted := strings.ToLower(string(out)) - formatted = strings.TrimSpace(formatted) - return strings.Split(formatted, "_")[0], nil - } - - session, err = client.NewSession() - if err != nil { - return "", err - } +func normalizeOs(os string) string { + os = strings.ToLower(strings.TrimSpace(os)) + return os +} - out, psErr := session.Output("echo $env:OS") - if psErr == nil { - formatted := strings.ToLower(string(out)) - formatted = strings.TrimSpace(formatted) - return strings.Split(formatted, "_")[0], nil +func normalizeArch(arch string) string { + arch = strings.ToLower(strings.TrimSpace(arch)) + switch arch { + case "x86_64", "amd64": + arch = "x64" + case "arm64", "aarch64": + arch = "arm64" } - return "", fmt.Errorf("unable to determine os: {unix: %s, cmd: %s, powershell: %s}", unixErr, cmdErr, psErr) + return arch } -func GetClientArch(client *ssh.Client) (string, error) { - session, err := client.NewSession() +// returns (os, arch, error) +// guaranteed to return a supported platform +func GetClientPlatform(ctx context.Context, shell genconn.ShellClient) (string, string, error) { + stdout, stderr, err := genconn.RunSimpleCommand(ctx, shell, genconn.CommandSpec{ + Cmd: "uname -sm", + }) if err != nil { - return "", err + return "", "", fmt.Errorf("error running uname -sm: %w, stderr: %s", err, stderr) } - - out, unixErr := session.CombinedOutput("uname -m") - if unixErr == nil { - return utilfn.FilterValidArch(string(out)) + // Parse and normalize output + parts := strings.Fields(strings.ToLower(strings.TrimSpace(stdout))) + if len(parts) != 2 { + return "", "", fmt.Errorf("unexpected output from uname: %s", stdout) } - - session, err = client.NewSession() - if err != nil { - return "", err + os, arch := normalizeOs(parts[0]), normalizeArch(parts[1]) + if err := wavebase.ValidateWshSupportedArch(os, arch); err != nil { + return "", "", err } + return os, arch, nil +} - out, cmdErr := session.CombinedOutput("echo %PROCESSOR_ARCHITECTURE%") - if cmdErr == nil && strings.TrimSpace(string(out)) != "%PROCESSOR_ARCHITECTURE%" { - return utilfn.FilterValidArch(string(out)) - } +var installTemplateRawDefault = strings.TrimSpace(` +mkdir -p {{.installDir}} || exit 1 +cat > {{.tempPath}} || exit 1 +mv {{.tempPath}} {{.installPath}} || exit 1 +chmod a+x {{.installPath}} || exit 1 +`) +var installTemplate = template.Must(template.New("wsh-install-template").Parse(installTemplateRawDefault)) - session, err = client.NewSession() +func CpWshToRemote(ctx context.Context, client *ssh.Client, clientOs string, clientArch string) error { + wshLocalPath, err := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch) if err != nil { - return "", err - } - - out, psErr := session.CombinedOutput("echo $env:PROCESSOR_ARCHITECTURE") - if psErr == nil && strings.TrimSpace(string(out)) != "$env:PROCESSOR_ARCHITECTURE" { - return utilfn.FilterValidArch(string(out)) + return err } - return "", fmt.Errorf("unable to determine architecture: {unix: %s, cmd: %s, powershell: %s}", unixErr, cmdErr, psErr) -} - -var installTemplateRawBash = `bash -c ' \ -mkdir -p {{.installDir}}; \ -cat > {{.tempPath}}; \ -mv {{.tempPath}} {{.installPath}}; \ -chmod a+x {{.installPath}};' \ -` - -var installTemplateRawDefault = ` \ -mkdir -p {{.installDir}}; \ -cat > {{.tempPath}}; \ -mv {{.tempPath}} {{.installPath}}; \ -chmod a+x {{.installPath}}; \ -` - -func CpHostToRemote(client *ssh.Client, sourcePath string, destPath string) error { - // warning: does not work on windows remote yet - bashInstalled, err := hasBashInstalled(client) + input, err := os.Open(wshLocalPath) if err != nil { - return err + return fmt.Errorf("cannot open local file %s: %w", wshLocalPath, err) } - - var selectedTemplateRaw string - if bashInstalled { - selectedTemplateRaw = installTemplateRawBash - } else { - log.Printf("bash is not installed on remote. attempting with default shell") - selectedTemplateRaw = installTemplateRawDefault + defer input.Close() + installWords := map[string]string{ + "installDir": filepath.ToSlash(filepath.Dir(wavebase.RemoteFullWshBinPath)), + "tempPath": filepath.ToSlash(wavebase.RemoteFullWshBinPath + ".temp"), + "installPath": filepath.ToSlash(wavebase.RemoteFullWshBinPath), } - - // I need to use toSlash here to force unix keybindings - // this means we can't guarantee it will work on a remote windows machine - var installWords = map[string]string{ - "installDir": filepath.ToSlash(filepath.Dir(destPath)), - "tempPath": destPath + ".temp", - "installPath": destPath, + var installCmd bytes.Buffer + if err := installTemplate.Execute(&installCmd, installWords); err != nil { + return fmt.Errorf("failed to prepare install command: %w", err) } - - installCmd := &bytes.Buffer{} - installTemplate := template.Must(template.New("").Parse(selectedTemplateRaw)) - installTemplate.Execute(installCmd, installWords) - - session, err := client.NewSession() + genCmd, err := genconn.MakeSSHCmdClient(client, genconn.CommandSpec{ + Cmd: installCmd.String(), + }) if err != nil { - return err + return fmt.Errorf("failed to create remote command: %w", err) } - - installStdin, err := session.StdinPipe() + stdin, err := genCmd.StdinPipe() if err != nil { - return err + return fmt.Errorf("failed to get stdin pipe: %w", err) } - - err = session.Start(installCmd.String()) + defer stdin.Close() + stderrBuf, err := genconn.MakeStderrSyncBuffer(genCmd) if err != nil { - return err + return fmt.Errorf("failed to get stderr pipe: %w", err) } - - input, err := os.Open(sourcePath) - if err != nil { - return fmt.Errorf("cannot open local file %s to send to host: %v", sourcePath, err) + if err := genCmd.Start(); err != nil { + return fmt.Errorf("failed to start remote command: %w", err) } - + copyDone := make(chan error, 1) go func() { - defer func() { - panichandler.PanicHandler("connutil:CpHostToRemote", recover()) - }() - io.Copy(installStdin, input) - session.Close() // this allows the command to complete for reasons i don't fully understand + defer close(copyDone) + defer stdin.Close() + if _, err := io.Copy(stdin, input); err != nil && err != io.EOF { + copyDone <- fmt.Errorf("failed to copy data: %w", err) + } else { + copyDone <- nil + } }() - - return session.Wait() + procErr := genconn.ProcessContextWait(ctx, genCmd) + if procErr != nil { + return fmt.Errorf("remote command failed: %w (stderr: %s)", procErr, stderrBuf.String()) + } + copyErr := <-copyDone + if copyErr != nil { + return fmt.Errorf("failed to copy data: %w (stderr: %s)", copyErr, stderrBuf.String()) + } + return nil } func InstallClientRcFiles(client *ssh.Client) error { diff --git a/pkg/util/shellutil/shellutil.go b/pkg/util/shellutil/shellutil.go index 91ac6d76b8..bff73d6e03 100644 --- a/pkg/util/shellutil/shellutil.go +++ b/pkg/util/shellutil/shellutil.go @@ -93,6 +93,12 @@ $env:PATH = "{{.WSHBINDIR}}" + "{{.PATHSEP}}" + $env:PATH func DetectLocalShellPath() string { if runtime.GOOS == "windows" { + if pwshPath, lpErr := exec.LookPath("pwsh"); lpErr == nil { + return pwshPath + } + if powershellPath, lpErr := exec.LookPath("powershell"); lpErr == nil { + return powershellPath + } return "powershell.exe" } shellPath := GetMacUserShell() @@ -213,7 +219,7 @@ func GetZshZDotDir() string { return filepath.Join(wavebase.GetWaveDataDir(), ZshIntegrationDir) } -func GetWshBaseName(version string, goos string, goarch string) string { +func GetWshBinaryPath(version string, goos string, goarch string) (string, error) { ext := "" if goarch == "amd64" { goarch = "x64" @@ -224,11 +230,11 @@ func GetWshBaseName(version string, goos string, goarch string) string { if goos == "windows" { ext = ".exe" } - return fmt.Sprintf("wsh-%s-%s.%s%s", version, goos, goarch, ext) -} - -func GetWshBinaryPath(version string, goos string, goarch string) string { - return filepath.Join(wavebase.GetWaveAppBinPath(), GetWshBaseName(version, goos, goarch)) + if !wavebase.SupportedWshBinaries[fmt.Sprintf("%s-%s", goos, goarch)] { + return "", fmt.Errorf("unsupported wsh platform: %s-%s", goos, goarch) + } + baseName := fmt.Sprintf("wsh-%s-%s.%s%s", version, goos, goarch, ext) + return filepath.Join(wavebase.GetWaveAppBinPath(), baseName), nil } func InitRcFiles(waveHome string, wshBinDir string) error { @@ -302,8 +308,10 @@ func initCustomShellStartupFilesInternal() error { } // copy the correct binary to bin - wshBaseName := GetWshBaseName(wavebase.WaveVersion, runtime.GOOS, runtime.GOARCH) - wshFullPath := GetWshBinaryPath(wavebase.WaveVersion, runtime.GOOS, runtime.GOARCH) + wshFullPath, err := GetWshBinaryPath(wavebase.WaveVersion, runtime.GOOS, runtime.GOARCH) + if err != nil { + log.Printf("error (non-fatal), could not resolve wsh binary path: %v\n", err) + } if _, err := os.Stat(wshFullPath); err != nil { log.Printf("error (non-fatal), could not resolve wsh binary %q: %v\n", wshFullPath, err) return nil @@ -316,6 +324,7 @@ func initCustomShellStartupFilesInternal() error { if err != nil { return fmt.Errorf("error copying wsh binary to bin: %v", err) } + wshBaseName := filepath.Base(wshFullPath) log.Printf("wsh binary successfully copied from %q to %q\n", wshBaseName, wshDstPath) return nil } diff --git a/pkg/util/syncbuf/syncbuf.go b/pkg/util/syncbuf/syncbuf.go new file mode 100644 index 0000000000..bf210acdf6 --- /dev/null +++ b/pkg/util/syncbuf/syncbuf.go @@ -0,0 +1,41 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package syncbuf + +import ( + "bytes" + "io" + "sync" +) + +type SyncBuffer struct { + lock sync.Mutex + buf *bytes.Buffer +} + +func MakeSyncBuffer() *SyncBuffer { + return &SyncBuffer{ + lock: sync.Mutex{}, + buf: new(bytes.Buffer), + } +} + +// spawns a goroutine to copy the reader to the buffer +func MakeSyncBufferFromReader(r io.Reader) *SyncBuffer { + rtn := MakeSyncBuffer() + go io.Copy(rtn, r) + return rtn +} + +func (s *SyncBuffer) Write(p []byte) (n int, err error) { + s.lock.Lock() + defer s.lock.Unlock() + return s.buf.Write(p) +} + +func (s *SyncBuffer) String() string { + s.lock.Lock() + defer s.lock.Unlock() + return s.buf.String() +} diff --git a/pkg/wavebase/wavebase.go b/pkg/wavebase/wavebase.go index ee4a94916c..d9c3e9c2af 100644 --- a/pkg/wavebase/wavebase.go +++ b/pkg/wavebase/wavebase.go @@ -52,6 +52,15 @@ const AppPathBinDir = "bin" var baseLock = &sync.Mutex{} var ensureDirCache = map[string]bool{} +var SupportedWshBinaries = map[string]bool{ + "darwin-x64": true, + "darwin-arm64": true, + "linux-x64": true, + "linux-arm64": true, + "windows-x64": true, + "windows-arm64": true, +} + type FDLock interface { Close() error } @@ -265,3 +274,10 @@ func UnameKernelRelease() string { }) return osRelease } + +func ValidateWshSupportedArch(os string, arch string) error { + if SupportedWshBinaries[fmt.Sprintf("%s-%s", os, arch)] { + return nil + } + return fmt.Errorf("unsupported wsh platform: %s-%s", os, arch) +} diff --git a/pkg/wsl/wsl-util.go b/pkg/wsl/wsl-util.go index b837d0e3c7..73decd9577 100644 --- a/pkg/wsl/wsl-util.go +++ b/pkg/wsl/wsl-util.go @@ -8,12 +8,12 @@ import ( "context" "errors" "fmt" - "html/template" "io" "log" "os" "path/filepath" "strings" + "text/template" "time" "github.com/wavetermdev/waveterm/pkg/panichandler" diff --git a/pkg/wsl/wsl.go b/pkg/wsl/wsl.go index 673bc1a73f..80b3e59d77 100644 --- a/pkg/wsl/wsl.go +++ b/pkg/wsl/wsl.go @@ -329,7 +329,10 @@ func (conn *WslConn) CheckAndInstallWsh(ctx context.Context, clientDisplayName s return err } // attempt to install extension - wshLocalPath := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch) + wshLocalPath, err := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch) + if err != nil { + return err + } err = CpHostToRemote(ctx, client, wshLocalPath, wavebase.RemoteFullWshBinPath) if err != nil { return err