diff --git a/cmd/wsh/cmd/wshcmd-ai.go b/cmd/wsh/cmd/wshcmd-ai.go index 0b754eb90f..71f175089a 100644 --- a/cmd/wsh/cmd/wshcmd-ai.go +++ b/cmd/wsh/cmd/wshcmd-ai.go @@ -142,8 +142,8 @@ func aiRun(cmd *cobra.Command, args []string) (rtnErr error) { if message.Len() == 0 { return fmt.Errorf("message is empty") } - if message.Len() > 10*1024 { - return fmt.Errorf("current max message size is 10k") + if message.Len() > 50*1024 { + return fmt.Errorf("current max message size is 50k") } messageData := wshrpc.AiMessageData{ diff --git a/cmd/wsh/cmd/wshcmd-rcfiles.go b/cmd/wsh/cmd/wshcmd-rcfiles.go index 2db2fb76bf..745d325682 100644 --- a/cmd/wsh/cmd/wshcmd-rcfiles.go +++ b/cmd/wsh/cmd/wshcmd-rcfiles.go @@ -19,7 +19,7 @@ var rcfilesCmd = &cobra.Command{ Run: func(cmd *cobra.Command, args []string) { err := wshutil.InstallRcFiles() if err != nil { - WriteStderr(err.Error()) + WriteStderr("%s\n", err.Error()) return } }, diff --git a/cmd/wsh/cmd/wshcmd-setbg.go b/cmd/wsh/cmd/wshcmd-setbg.go index 5c05feeafa..ed0b114001 100644 --- a/cmd/wsh/cmd/wshcmd-setbg.go +++ b/cmd/wsh/cmd/wshcmd-setbg.go @@ -162,7 +162,7 @@ func setBgRun(cmd *cobra.Command, args []string) (rtnErr error) { if err != nil { return fmt.Errorf("error formatting metadata: %v", err) } - WriteStdout(string(jsonBytes) + "\n") + WriteStdout("%s\n", string(jsonBytes)) return nil } diff --git a/frontend/app/view/preview/preview.tsx b/frontend/app/view/preview/preview.tsx index ece27f24fc..7add7e3ded 100644 --- a/frontend/app/view/preview/preview.tsx +++ b/frontend/app/view/preview/preview.tsx @@ -38,6 +38,15 @@ import "./preview.scss"; const MaxFileSize = 1024 * 1024 * 10; // 10MB const MaxCSVSize = 1024 * 1024 * 1; // 1MB +// TODO drive this using config +const BOOKMARKS: { label: string; path: string }[] = [ + { label: "Home", path: "~" }, + { label: "Desktop", path: "~/Desktop" }, + { label: "Downloads", path: "~/Downloads" }, + { label: "Documents", path: "~/Documents" }, + { label: "Root", path: "/" }, +]; + type SpecializedViewProps = { model: PreviewModel; parentRef: React.RefObject; @@ -185,27 +194,10 @@ export class PreviewModel implements ViewModel { elemtype: "iconbutton", icon: "folder-open", longClick: (e: React.MouseEvent) => { - const menuItems: ContextMenuItem[] = []; - menuItems.push({ - label: "Go to Home", - click: () => this.goHistory("~"), - }); - menuItems.push({ - label: "Go to Desktop", - click: () => this.goHistory("~/Desktop"), - }); - menuItems.push({ - label: "Go to Downloads", - click: () => this.goHistory("~/Downloads"), - }); - menuItems.push({ - label: "Go to Documents", - click: () => this.goHistory("~/Documents"), - }); - menuItems.push({ - label: "Go to Root", - click: () => this.goHistory("/"), - }); + const menuItems: ContextMenuItem[] = BOOKMARKS.map((bookmark) => ({ + label: `Go to ${bookmark.label} (${bookmark.path})`, + click: () => this.goHistory(bookmark.path), + })); ContextMenuModel.showContextMenu(menuItems, e); }, }; diff --git a/frontend/app/view/waveai/waveai.tsx b/frontend/app/view/waveai/waveai.tsx index bcaa21e7e6..8ba1a08526 100644 --- a/frontend/app/view/waveai/waveai.tsx +++ b/frontend/app/view/waveai/waveai.tsx @@ -523,7 +523,6 @@ const ChatWindow = memo( const handleNewMessage = useCallback( throttle(100, (messagesLen: number) => { if (osRef.current?.osInstance()) { - console.log("handleNewMessage", messagesLen, isUserScrolling.current); const { viewport } = osRef.current.osInstance().elements(); if (prevMessagesLenRef.current !== messagesLen || !isUserScrolling.current) { viewport.scrollTo({ diff --git a/frontend/types/gotypes.d.ts b/frontend/types/gotypes.d.ts index 370d63137c..876e63bd5e 100644 --- a/frontend/types/gotypes.d.ts +++ b/frontend/types/gotypes.d.ts @@ -304,6 +304,7 @@ declare global { "conn:askbeforewshinstall"?: boolean; "conn:overrideconfig"?: boolean; "conn:wshpath"?: string; + "conn:shellpath"?: string; "display:hidden"?: boolean; "display:order"?: number; "term:*"?: boolean; diff --git a/pkg/blockcontroller/blockcontroller.go b/pkg/blockcontroller/blockcontroller.go index df1ff1be7a..08a75a47af 100644 --- a/pkg/blockcontroller/blockcontroller.go +++ b/pkg/blockcontroller/blockcontroller.go @@ -369,18 +369,18 @@ func (bc *BlockController) setupAndStartShellProcess(rc *RunShellOpts, blockMeta cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr } if !conn.WshEnabled.Load() { - shellProc, err = shellexec.StartRemoteShellProcNoWsh(rc.TermSize, cmdStr, cmdOpts, conn) + shellProc, err = shellexec.StartRemoteShellProcNoWsh(ctx, rc.TermSize, cmdStr, cmdOpts, conn) if err != nil { return nil, err } } else { - shellProc, err = shellexec.StartRemoteShellProc(rc.TermSize, cmdStr, cmdOpts, conn) + shellProc, err = shellexec.StartRemoteShellProc(ctx, rc.TermSize, cmdStr, cmdOpts, conn) if err != nil { conn.SetWshError(err) conn.WshEnabled.Store(false) log.Printf("error starting remote shell proc with wsh: %v", err) log.Print("attempting install without wsh") - shellProc, err = shellexec.StartRemoteShellProcNoWsh(rc.TermSize, cmdStr, cmdOpts, conn) + shellProc, err = shellexec.StartRemoteShellProcNoWsh(ctx, rc.TermSize, cmdStr, cmdOpts, conn) if err != nil { return nil, err } @@ -408,7 +408,7 @@ func (bc *BlockController) setupAndStartShellProcess(rc *RunShellOpts, blockMeta if len(blockMeta.GetStringList(waveobj.MetaKey_TermLocalShellOpts)) > 0 { cmdOpts.ShellOpts = append([]string{}, blockMeta.GetStringList(waveobj.MetaKey_TermLocalShellOpts)...) } - shellProc, err = shellexec.StartShellProc(rc.TermSize, cmdStr, cmdOpts) + shellProc, err = shellexec.StartLocalShellProc(rc.TermSize, cmdStr, cmdOpts) if err != nil { return nil, err } diff --git a/pkg/genconn/quote.go b/pkg/genconn/quote.go index ad21910eab..469359cbfa 100644 --- a/pkg/genconn/quote.go +++ b/pkg/genconn/quote.go @@ -6,16 +6,13 @@ package genconn import "regexp" var ( - safePattern = regexp.MustCompile(`^[a-zA-Z0-9_/.-]+$`) - - needsEscape = map[byte]bool{ - '"': true, - '\\': true, - '$': true, - '`': true, - } + safePattern = regexp.MustCompile(`^[a-zA-Z0-9_/.-]+$`) + psSafePattern = regexp.MustCompile(`^[a-zA-Z0-9_.-]+$`) ) +// TODO: fish quoting is slightly different +// specifically \` will cause an inconsistency between fish and bash/zsh :/ +// might need a specific fish quoting function, and an explicit fish shell detection func HardQuote(s string) string { if s == "" { return "\"\"" @@ -29,10 +26,42 @@ func HardQuote(s string) string { buf = append(buf, '"') for i := 0; i < len(s); i++ { - if needsEscape[s[i]] { - buf = append(buf, '\\') + switch s[i] { + case '"', '\\', '$', '`': + buf = append(buf, '\\', s[i]) + case '\n': + buf = append(buf, '\\', '\n') + default: + buf = append(buf, s[i]) } - buf = append(buf, s[i]) + } + + buf = append(buf, '"') + return string(buf) +} + +func HardQuotePowerShell(s string) string { + if s == "" { + return "\"\"" + } + + if psSafePattern.MatchString(s) { + return s + } + + buf := make([]byte, 0, len(s)+5) + buf = append(buf, '"') + + for i := 0; i < len(s); i++ { + c := s[i] + // In PowerShell, backtick (`) is the escape character + switch c { + case '"', '`', '$': + buf = append(buf, '`') + case '\n': + buf = append(buf, '`', 'n') // PowerShell uses `n for newline + } + buf = append(buf, c) } buf = append(buf, '"') diff --git a/pkg/genconn/ssh-impl.go b/pkg/genconn/ssh-impl.go index 945a54aa50..4e49f9e66a 100644 --- a/pkg/genconn/ssh-impl.go +++ b/pkg/genconn/ssh-impl.go @@ -6,6 +6,7 @@ package genconn import ( "fmt" "io" + "log" "sync" "golang.org/x/crypto/ssh" @@ -41,6 +42,7 @@ type SSHProcessController struct { // MakeSSHCmdClient creates a new instance of SSHCmdClient func MakeSSHCmdClient(client *ssh.Client, cmdSpec CommandSpec) (*SSHProcessController, error) { + log.Printf("SSH-NEWSESSION (cmdclient)\n") session, err := client.NewSession() if err != nil { return nil, fmt.Errorf("failed to create SSH session: %w", err) diff --git a/pkg/remote/conncontroller/conncontroller.go b/pkg/remote/conncontroller/conncontroller.go index 72b0a7789c..27bd9b7d4b 100644 --- a/pkg/remote/conncontroller/conncontroller.go +++ b/pkg/remote/conncontroller/conncontroller.go @@ -69,11 +69,12 @@ type SSHConn struct { ActiveConnNum int } -var ConnServerCmdTemplate = strings.TrimSpace(` -%s version || echo "not-installed" -read jwt_token -WAVETERM_JWT="$jwt_token" %s connserver -`) +var ConnServerCmdTemplate = strings.TrimSpace( + strings.Join([]string{ + "%s version 2> /dev/null || (echo -n \"not-installed \"; uname -sm);", + "read jwt_token;", + "WAVETERM_JWT=\"$jwt_token\" %s connserver", + }, "\n")) func GetAllConnStatus() []wshrpc.ConnStatus { globalLock.Lock() @@ -225,37 +226,55 @@ func (conn *SSHConn) OpenDomainSocketListener(ctx context.Context) error { return nil } -// expects the output of `wsh version` which looks like `wsh v0.10.4` or "not-installed" -// returns (up-to-date, semver, error) +// expects the output of `wsh version` which looks like `wsh v0.10.4` or "not-installed [os] [arch]" +// returns (up-to-date, semver, osArchStr, error) // if not up to date, or error, version might be "" -func IsWshVersionUpToDate(wshVersionLine string) (bool, string, error) { +func IsWshVersionUpToDate(wshVersionLine string) (bool, string, string, error) { wshVersionLine = strings.TrimSpace(wshVersionLine) - if wshVersionLine == "not-installed" { - return false, "", nil + if strings.HasPrefix(wshVersionLine, "not-installed") { + return false, "not-installed", strings.TrimSpace(strings.TrimPrefix(wshVersionLine, "not-installed")), nil } parts := strings.Fields(wshVersionLine) if len(parts) != 2 { - return false, "", fmt.Errorf("unexpected version format: %s", wshVersionLine) + return false, "", "", fmt.Errorf("unexpected version format: %s", wshVersionLine) } clientVersion := parts[1] expectedVersion := fmt.Sprintf("v%s", wavebase.WaveVersion) if semver.Compare(clientVersion, expectedVersion) < 0 { - return false, clientVersion, nil + return false, clientVersion, "", nil } - return true, clientVersion, nil + return true, clientVersion, "", nil } -// returns (needsInstall, clientVersion, error) -func (conn *SSHConn) StartConnServer(ctx context.Context) (bool, string, error) { +func (conn *SSHConn) getWshPath() string { + config, ok := conn.getConnectionConfig() + if ok && config.ConnWshPath != "" { + return config.ConnWshPath + } + return wavebase.RemoteFullWshBinPath +} + +func (conn *SSHConn) GetConfigShellPath() string { + config, ok := conn.getConnectionConfig() + if !ok { + return "" + } + return config.ConnShellPath +} + +// returns (needsInstall, clientVersion, osArchStr, error) +// if wsh is not installed, the clientVersion will be "not-installed", and it will also return an osArchStr +// if clientVersion is set, then no osArchStr will be returned +func (conn *SSHConn) StartConnServer(ctx context.Context) (bool, string, string, error) { conn.Infof(ctx, "running StartConnServer...\n") allowed := WithLockRtn(conn, func() bool { return conn.Status == Status_Connecting }) if !allowed { - return false, "", fmt.Errorf("cannot start conn server for %q when status is %q", conn.GetName(), conn.GetStatus()) + return false, "", "", fmt.Errorf("cannot start conn server for %q when status is %q", conn.GetName(), conn.GetStatus()) } client := conn.GetClient() - wshPath := remote.GetWshPath(client) + wshPath := conn.getWshPath() rpcCtx := wshrpc.RpcContext{ ClientType: wshrpc.ClientType_ConnServer, Conn: conn.GetName(), @@ -263,49 +282,51 @@ func (conn *SSHConn) StartConnServer(ctx context.Context) (bool, string, error) sockName := conn.GetDomainSocketName() jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx, sockName) if err != nil { - return false, "", fmt.Errorf("unable to create jwt token for conn controller: %w", err) + return false, "", "", fmt.Errorf("unable to create jwt token for conn controller: %w", err) } + conn.Infof(ctx, "SSH-NEWSESSION (StartConnServer)\n") sshSession, err := client.NewSession() if err != nil { - return false, "", fmt.Errorf("unable to create ssh session for conn controller: %w", err) + return false, "", "", fmt.Errorf("unable to create ssh session for conn controller: %w", err) } pipeRead, pipeWrite := io.Pipe() sshSession.Stdout = pipeWrite sshSession.Stderr = pipeWrite stdinPipe, err := sshSession.StdinPipe() if err != nil { - return false, "", fmt.Errorf("unable to get stdin pipe: %w", err) + return false, "", "", fmt.Errorf("unable to get stdin pipe: %w", err) } cmdStr := fmt.Sprintf(ConnServerCmdTemplate, wshPath, wshPath) - log.Printf("starting conn controller: %s\n", cmdStr) + log.Printf("starting conn controller: %q\n", cmdStr) shWrappedCmdStr := fmt.Sprintf("sh -c %s", genconn.HardQuote(cmdStr)) + blocklogger.Debugf(ctx, "[conndebug] wrapped command:\n%s\n", shWrappedCmdStr) err = sshSession.Start(shWrappedCmdStr) if err != nil { - return false, "", fmt.Errorf("unable to start conn controller command: %w", err) + return false, "", "", fmt.Errorf("unable to start conn controller command: %w", err) } linesChan := wshutil.StreamToLinesChan(pipeRead) versionLine, err := wshutil.ReadLineWithTimeout(linesChan, 2*time.Second) if err != nil { sshSession.Close() - return false, "", fmt.Errorf("error reading wsh version: %w", err) + return false, "", "", fmt.Errorf("error reading wsh version: %w", err) } conn.Infof(ctx, "got connserver version: %s\n", strings.TrimSpace(versionLine)) - isUpToDate, clientVersion, err := IsWshVersionUpToDate(versionLine) + isUpToDate, clientVersion, osArchStr, err := IsWshVersionUpToDate(versionLine) if err != nil { sshSession.Close() - return false, "", fmt.Errorf("error checking wsh version: %w", err) + return false, "", "", fmt.Errorf("error checking wsh version: %w", err) } - conn.Infof(ctx, "connserver update to date: %v\n", isUpToDate) + conn.Infof(ctx, "connserver up-to-date: %v\n", isUpToDate) if !isUpToDate { sshSession.Close() - return true, clientVersion, nil + return true, clientVersion, osArchStr, nil } // write the jwt conn.Infof(ctx, "writing jwt token to connserver\n") _, err = fmt.Fprintf(stdinPipe, "%s\n", jwtToken) if err != nil { sshSession.Close() - return false, clientVersion, fmt.Errorf("failed to write JWT token: %w", err) + return false, clientVersion, "", fmt.Errorf("failed to write JWT token: %w", err) } conn.WithLock(func() { conn.ConnController = sshSession @@ -351,11 +372,11 @@ func (conn *SSHConn) StartConnServer(ctx context.Context) (bool, string, error) defer cancelFn() err = wshutil.DefaultRouter.WaitForRegister(regCtx, wshutil.MakeConnectionRouteId(rpcCtx.Conn)) if err != nil { - return false, clientVersion, fmt.Errorf("timeout waiting for connserver to register") + return false, clientVersion, "", fmt.Errorf("timeout waiting for connserver to register") } time.Sleep(300 * time.Millisecond) // TODO remove this sleep (but we need to wait until connserver is "ready") conn.Infof(ctx, "connserver is registered and ready\n") - return false, clientVersion, nil + return false, clientVersion, "", nil } type WshInstallOpts struct { @@ -438,17 +459,22 @@ func (conn *SSHConn) getPermissionToInstallWsh(ctx context.Context, clientDispla return true, nil } -func (conn *SSHConn) InstallWsh(ctx context.Context) error { +func (conn *SSHConn) InstallWsh(ctx context.Context, osArchStr string) error { conn.Infof(ctx, "running installWsh...\n") client := conn.GetClient() if client == nil { conn.Infof(ctx, "ERROR ssh client is not connected, cannot install\n") return fmt.Errorf("ssh client is not connected, cannot install") } - clientOs, clientArch, err := remote.GetClientPlatform(ctx, genconn.MakeSSHShellClient(client)) + var clientOs, clientArch string + var err error + if osArchStr != "" { + clientOs, clientArch, err = remote.GetClientPlatformFromOsArchStr(ctx, osArchStr) + } else { + clientOs, clientArch, err = remote.GetClientPlatform(ctx, genconn.MakeSSHShellClient(client)) + } if err != nil { conn.Infof(ctx, "ERROR detecting client platform: %v\n", err) - return err } conn.Infof(ctx, "detected remote platform os:%s arch:%s\n", clientOs, clientArch) err = remote.CpWshToRemote(ctx, client, clientOs, clientArch) @@ -547,8 +573,7 @@ func (conn *SSHConn) Connect(ctx context.Context, connFlags *wshrpc.ConnKeywords // logic for saving connection and potential flags (we only save once a connection has been made successfully) // at the moment, identity files is the only saved flag var identityFiles []string - existingConfig := wconfig.GetWatcher().GetFullConfig() - existingConnection, ok := existingConfig.Connections[conn.GetName()] + existingConnection, ok := conn.getConnectionConfig() if ok { identityFiles = existingConnection.SshIdentityFile } @@ -592,7 +617,7 @@ func (conn *SSHConn) getConnWshSettings() (bool, bool) { config := wconfig.GetWatcher().GetFullConfig() enableWsh := config.Settings.ConnWshEnabled askBeforeInstall := wconfig.DefaultBoolPtr(config.Settings.ConnAskBeforeWshInstall, true) - connSettings, ok := config.Connections[conn.GetName()] + connSettings, ok := conn.getConnectionConfig() if ok { if connSettings.ConnWshEnabled != nil { enableWsh = *connSettings.ConnWshEnabled @@ -639,7 +664,7 @@ func (conn *SSHConn) tryEnableWsh(ctx context.Context, clientDisplayName string) err = fmt.Errorf("error opening domain socket listener: %w", err) return WshCheckResult{NoWshReason: "error opening domain socket", WshError: err} } - needsInstall, clientVersion, err := conn.StartConnServer(ctx) + needsInstall, clientVersion, osArchStr, err := conn.StartConnServer(ctx) if err != nil { conn.Infof(ctx, "ERROR starting conn server: %v\n", err) err = fmt.Errorf("error starting conn server: %w", err) @@ -647,13 +672,13 @@ func (conn *SSHConn) tryEnableWsh(ctx context.Context, clientDisplayName string) } if needsInstall { conn.Infof(ctx, "connserver needs to be (re)installed\n") - err = conn.InstallWsh(ctx) + err = conn.InstallWsh(ctx, osArchStr) if err != nil { conn.Infof(ctx, "ERROR installing wsh: %v\n", err) err = fmt.Errorf("error installing wsh: %w", err) return WshCheckResult{NoWshReason: "error installing wsh/connserver", WshError: err} } - needsInstall, clientVersion, err = conn.StartConnServer(ctx) + needsInstall, clientVersion, _, err = conn.StartConnServer(ctx) if err != nil { conn.Infof(ctx, "ERROR starting conn server (after install): %v\n", err) err = fmt.Errorf("error starting conn server (after install): %w", err) @@ -670,6 +695,15 @@ func (conn *SSHConn) tryEnableWsh(ctx context.Context, clientDisplayName string) } } +func (conn *SSHConn) getConnectionConfig() (wshrpc.ConnKeywords, bool) { + config := wconfig.GetWatcher().GetFullConfig() + connSettings, ok := config.Connections[conn.GetName()] + if !ok { + return wshrpc.ConnKeywords{}, false + } + return connSettings, true +} + func (conn *SSHConn) persistWshInstalled(ctx context.Context, result WshCheckResult) { conn.WshEnabled.Store(result.WshEnabled) conn.SetWshError(result.WshError) @@ -677,9 +711,8 @@ func (conn *SSHConn) persistWshInstalled(ctx context.Context, result WshCheckRes conn.NoWshReason = result.NoWshReason conn.WshVersion = result.ClientVersion }) - config := wconfig.GetWatcher().GetFullConfig() - connSettings, ok := config.Connections[conn.GetName()] - if ok && connSettings.ConnWshEnabled != nil { + connConfig, ok := conn.getConnectionConfig() + if ok && connConfig.ConnWshEnabled != nil { return } meta := make(map[string]any) diff --git a/pkg/remote/connutil.go b/pkg/remote/connutil.go index b79b104745..2407cbf5a2 100644 --- a/pkg/remote/connutil.go +++ b/pkg/remote/connutil.go @@ -16,6 +16,7 @@ import ( "strings" "text/template" + "github.com/wavetermdev/waveterm/pkg/blocklogger" "github.com/wavetermdev/waveterm/pkg/genconn" "github.com/wavetermdev/waveterm/pkg/util/shellutil" "github.com/wavetermdev/waveterm/pkg/wavebase" @@ -35,46 +36,6 @@ func ParseOpts(input string) (*SSHOpts, error) { return &SSHOpts{SSHHost: remoteHost, SSHUser: remoteUser, SSHPort: remotePort}, nil } -func GetWshPath(client *ssh.Client) string { - defaultPath := wavebase.RemoteFullWshBinPath - session, err := client.NewSession() - if err != nil { - log.Printf("unable to detect client's wsh path. using default. error: %v", err) - return defaultPath - } - - out, whichErr := session.Output("which wsh") - if whichErr == nil { - return strings.TrimSpace(string(out)) - } - - session, err = client.NewSession() - if err != nil { - log.Printf("unable to detect client's wsh path. using default. error: %v", err) - return defaultPath - } - - out, whereErr := session.Output("where.exe wsh") - if whereErr == nil { - return strings.TrimSpace(string(out)) - } - - // check cmd on windows since it requires an absolute path with backslashes - session, err = client.NewSession() - if err != nil { - log.Printf("unable to detect client's wsh path. using default. error: %v", err) - return defaultPath - } - - out, cmdErr := session.Output("(dir 2>&1 *``|echo %userprofile%\\.waveterm%\\.waveterm\\bin\\wsh.exe);&<# rem #>echo none") //todo - if cmdErr == nil && strings.TrimSpace(string(out)) != "none" { - return strings.TrimSpace(string(out)) - } - - // no custom install, use default path - return defaultPath -} - func normalizeOs(os string) string { os = strings.ToLower(strings.TrimSpace(os)) return os @@ -94,6 +55,7 @@ func normalizeArch(arch string) string { // returns (os, arch, error) // guaranteed to return a supported platform func GetClientPlatform(ctx context.Context, shell genconn.ShellClient) (string, string, error) { + blocklogger.Infof(ctx, "[conndebug] running `uname -sm` to detect client platform\n") stdout, stderr, err := genconn.RunSimpleCommand(ctx, shell, genconn.CommandSpec{ Cmd: "uname -sm", }) @@ -112,16 +74,28 @@ func GetClientPlatform(ctx context.Context, shell genconn.ShellClient) (string, return os, arch, nil } +func GetClientPlatformFromOsArchStr(ctx context.Context, osArchStr string) (string, string, error) { + parts := strings.Fields(strings.TrimSpace(osArchStr)) + if len(parts) != 2 { + return "", "", fmt.Errorf("unexpected output from uname: %s", osArchStr) + } + os, arch := normalizeOs(parts[0]), normalizeArch(parts[1]) + if err := wavebase.ValidateWshSupportedArch(os, arch); err != nil { + return "", "", err + } + return os, arch, nil +} + var installTemplateRawDefault = strings.TrimSpace(` -mkdir -p {{.installDir}} || exit 1 -cat > {{.tempPath}} || exit 1 -mv {{.tempPath}} {{.installPath}} || exit 1 -chmod a+x {{.installPath}} || exit 1 +mkdir -p {{.installDir}} || exit 1; +cat > {{.tempPath}} || exit 1; +mv {{.tempPath}} {{.installPath}} || exit 1; +chmod a+x {{.installPath}} || exit 1; `) var installTemplate = template.Must(template.New("wsh-install-template").Parse(installTemplateRawDefault)) func CpWshToRemote(ctx context.Context, client *ssh.Client, clientOs string, clientArch string) error { - wshLocalPath, err := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch) + wshLocalPath, err := shellutil.GetLocalWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch) if err != nil { return err } @@ -132,13 +106,14 @@ func CpWshToRemote(ctx context.Context, client *ssh.Client, clientOs string, cli defer input.Close() installWords := map[string]string{ "installDir": filepath.ToSlash(filepath.Dir(wavebase.RemoteFullWshBinPath)), - "tempPath": filepath.ToSlash(wavebase.RemoteFullWshBinPath + ".temp"), - "installPath": filepath.ToSlash(wavebase.RemoteFullWshBinPath), + "tempPath": wavebase.RemoteFullWshBinPath + ".temp", + "installPath": wavebase.RemoteFullWshBinPath, } var installCmd bytes.Buffer if err := installTemplate.Execute(&installCmd, installWords); err != nil { return fmt.Errorf("failed to prepare install command: %w", err) } + blocklogger.Infof(ctx, "[conndebug] copying %q to remote server %q\n", wshLocalPath, wavebase.RemoteFullWshBinPath) genCmd, err := genconn.MakeSSHCmdClient(client, genconn.CommandSpec{ Cmd: installCmd.String(), }) diff --git a/pkg/shellexec/shellexec.go b/pkg/shellexec/shellexec.go index e5ffbce845..545d5e627d 100644 --- a/pkg/shellexec/shellexec.go +++ b/pkg/shellexec/shellexec.go @@ -25,7 +25,6 @@ import ( "github.com/wavetermdev/waveterm/pkg/remote/conncontroller" "github.com/wavetermdev/waveterm/pkg/util/pamparse" "github.com/wavetermdev/waveterm/pkg/util/shellutil" - "github.com/wavetermdev/waveterm/pkg/util/utilfn" "github.com/wavetermdev/waveterm/pkg/wavebase" "github.com/wavetermdev/waveterm/pkg/waveobj" "github.com/wavetermdev/waveterm/pkg/wshrpc" @@ -36,6 +35,14 @@ import ( const DefaultGracefulKillWait = 400 * time.Millisecond +const ( + ShellType_bash = "bash" + ShellType_zsh = "zsh" + ShellType_fish = "fish" + ShellType_pwsh = "pwsh" + ShellType_unknown = "unknown" +) + type CommandOptsType struct { Interactive bool `json:"interactive,omitempty"` Login bool `json:"login,omitempty"` @@ -151,6 +158,23 @@ func (pp *PipePty) WriteString(s string) (n int, err error) { return pp.Write([]byte(s)) } +func getShellTypeFromShellPath(shellPath string) string { + shellBase := filepath.Base(shellPath) + if strings.Contains(shellBase, "bash") { + return ShellType_bash + } + if strings.Contains(shellBase, "zsh") { + return ShellType_zsh + } + if strings.Contains(shellBase, "fish") { + return ShellType_fish + } + if strings.Contains(shellBase, "pwsh") || strings.Contains(shellBase, "powershell") { + return ShellType_pwsh + } + return ShellType_unknown +} + func StartWslShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *wsl.WslConn) (*ShellProc, error) { utilCtx, cancelFn := context.WithTimeout(ctx, 2*time.Second) defer cancelFn() @@ -245,8 +269,9 @@ func StartWslShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr st return &ShellProc{Cmd: cmdWrap, ConnName: conn.GetName(), CloseOnce: &sync.Once{}, DoneCh: make(chan any)}, nil } -func StartRemoteShellProcNoWsh(termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *conncontroller.SSHConn) (*ShellProc, error) { +func StartRemoteShellProcNoWsh(ctx context.Context, termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *conncontroller.SSHConn) (*ShellProc, error) { client := conn.GetClient() + conn.Infof(ctx, "SSH-NEWSESSION (StartRemoteShellProcNoWsh)") session, err := client.NewSession() if err != nil { return nil, err @@ -287,7 +312,7 @@ func StartRemoteShellProcNoWsh(termSize waveobj.TermSize, cmdStr string, cmdOpts return &ShellProc{Cmd: sessionWrap, ConnName: conn.GetName(), CloseOnce: &sync.Once{}, DoneCh: make(chan any)}, nil } -func StartRemoteShellProc(termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *conncontroller.SSHConn) (*ShellProc, error) { +func StartRemoteShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *conncontroller.SSHConn) (*ShellProc, error) { client := conn.GetClient() connRoute := wshutil.MakeConnectionRouteId(conn.GetName()) rpcClient := wshclient.GetBareRpcClient() @@ -296,14 +321,27 @@ func StartRemoteShellProc(termSize waveobj.TermSize, cmdStr string, cmdOpts Comm return nil, fmt.Errorf("unable to obtain client info: %w", err) } log.Printf("client info collected: %+#v", remoteInfo) - - shellPath := cmdOpts.ShellPath - if shellPath == "" { + var shellPath string + if cmdOpts.ShellPath != "" { + conn.Infof(ctx, "using shell path from command opts: %s\n", cmdOpts.ShellPath) + shellPath = cmdOpts.ShellPath + } + configShellPath := conn.GetConfigShellPath() + if shellPath == "" && configShellPath != "" { + conn.Infof(ctx, "using shell path from config (conn:shellpath): %s\n", configShellPath) + shellPath = configShellPath + } + if shellPath == "" && remoteInfo.Shell != "" { + conn.Infof(ctx, "using shell path detected on remote machine: %s\n", remoteInfo.Shell) shellPath = remoteInfo.Shell } + if shellPath == "" { + conn.Infof(ctx, "no shell path detected, using default (/bin/bash)\n") + shellPath = "/bin/bash" + } var shellOpts []string var cmdCombined string - log.Printf("using shell: %s", shellPath) + log.Printf("detected shell %q for conn %q\n", shellPath, conn.GetName()) err = wshclient.RemoteInstallRcFilesCommand(rpcClient, &wshrpc.RpcOpts{Route: connRoute, Timeout: 2000}) if err != nil { @@ -311,46 +349,51 @@ func StartRemoteShellProc(termSize waveobj.TermSize, cmdStr string, cmdOpts Comm return nil, err } shellOpts = append(shellOpts, cmdOpts.ShellOpts...) + shellType := getShellTypeFromShellPath(shellPath) + conn.Infof(ctx, "detected shell type: %s\n", shellType) if cmdStr == "" { /* transform command in order to inject environment vars */ - if isBashShell(shellPath) { - log.Printf("recognized as bash shell") + if shellType == ShellType_bash { // add --rcfile // cant set -l or -i with --rcfile - bashPath := genconn.SoftQuote(fmt.Sprintf("~/.waveterm/%s/.bashrc", shellutil.BashIntegrationDir)) + bashPath := fmt.Sprintf("~/.waveterm/%s/.bashrc", shellutil.BashIntegrationDir) shellOpts = append(shellOpts, "--rcfile", bashPath) - } else if isFishShell(shellPath) { - fishDir := genconn.SoftQuote(fmt.Sprintf("~/.waveterm/%s", shellutil.WaveHomeBinDir)) - carg := fmt.Sprintf(`"set -x PATH %s $PATH"`, fishDir) + } else if shellType == ShellType_fish { + if cmdOpts.Login { + shellOpts = append(shellOpts, "-l") + } + // source the wave.fish file + waveFishPath := fmt.Sprintf("~/.waveterm/%s/wave.fish", shellutil.FishIntegrationDir) + carg := fmt.Sprintf(`"source %s"`, waveFishPath) shellOpts = append(shellOpts, "-C", carg) - } else if remote.IsPowershell(shellPath) { - pwshPath := genconn.SoftQuote(fmt.Sprintf("~/.waveterm/%s/wavepwsh.ps1", shellutil.PwshIntegrationDir)) + } else if shellType == ShellType_pwsh { + pwshPath := fmt.Sprintf("~/.waveterm/%s/wavepwsh.ps1", shellutil.PwshIntegrationDir) // powershell is weird about quoted path executables and requires an ampersand first shellPath = "& " + shellPath shellOpts = append(shellOpts, "-ExecutionPolicy", "Bypass", "-NoExit", "-File", pwshPath) } else { if cmdOpts.Login { shellOpts = append(shellOpts, "-l") - } else if cmdOpts.Interactive { + } + if cmdOpts.Interactive { shellOpts = append(shellOpts, "-i") } // zdotdir setting moved to after session is created } cmdCombined = fmt.Sprintf("%s %s", shellPath, strings.Join(shellOpts, " ")) - log.Printf("combined command is: %s", cmdCombined) } else { + // TODO check quoting of cmdStr shellPath = cmdStr shellOpts = append(shellOpts, "-c", cmdStr) cmdCombined = fmt.Sprintf("%s %s", shellPath, strings.Join(shellOpts, " ")) - log.Printf("combined command is: %s", cmdCombined) } - + conn.Infof(ctx, "starting shell, using command: %s\n", cmdCombined) + conn.Infof(ctx, "SSH-NEWSESSION (StartRemoteShellProc)\n") session, err := client.NewSession() if err != nil { return nil, err } - remoteStdinRead, remoteStdinWriteOurs, err := os.Pipe() if err != nil { return nil, err @@ -381,8 +424,9 @@ func StartRemoteShellProc(termSize waveobj.TermSize, cmdStr string, cmdOpts Comm session.Setenv(envKey, envVal) } - if isZshShell(shellPath) { - zshDir := genconn.SoftQuote(fmt.Sprintf("~/.waveterm/%s", shellutil.ZshIntegrationDir)) + if shellType == ShellType_zsh { + zshDir := fmt.Sprintf("~/.waveterm/%s", shellutil.ZshIntegrationDir) + conn.Infof(ctx, "setting ZDOTDIR to %s\n", zshDir) cmdCombined = fmt.Sprintf(`ZDOTDIR=%s %s`, zshDir, cmdCombined) } @@ -390,13 +434,7 @@ func StartRemoteShellProc(termSize waveobj.TermSize, cmdStr string, cmdOpts Comm if !ok { return nil, fmt.Errorf("no jwt token provided to connection") } - - if remote.IsPowershell(shellPath) { - cmdCombined = fmt.Sprintf(`$env:%s="%s"; %s`, wshutil.WaveJwtTokenVarName, jwtToken, cmdCombined) - } else { - cmdCombined = fmt.Sprintf(`%s=%s %s`, wshutil.WaveJwtTokenVarName, jwtToken, cmdCombined) - } - + cmdCombined = fmt.Sprintf(`%s=%s %s`, wshutil.WaveJwtTokenVarName, jwtToken, cmdCombined) session.RequestPty("xterm-256color", termSize.Rows, termSize.Cols, nil) sessionWrap := MakeSessionWrap(session, cmdCombined, pipePty) err = sessionWrap.Start() @@ -425,7 +463,7 @@ func isFishShell(shellPath string) bool { return strings.Contains(shellBase, "fish") } -func StartShellProc(termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType) (*ShellProc, error) { +func StartLocalShellProc(termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType) (*ShellProc, error) { shellutil.InitCustomShellStartupFiles() var ecmd *exec.Cmd var shellOpts []string @@ -433,29 +471,34 @@ func StartShellProc(termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOpt if shellPath == "" { shellPath = shellutil.DetectLocalShellPath() } + shellType := getShellTypeFromShellPath(shellPath) shellOpts = append(shellOpts, cmdOpts.ShellOpts...) if cmdStr == "" { - if isBashShell(shellPath) { + if shellType == ShellType_bash { // add --rcfile // cant set -l or -i with --rcfile - shellOpts = append(shellOpts, "--rcfile", shellutil.GetBashRcFileOverride()) - } else if isFishShell(shellPath) { - wshBinDir := filepath.Join(wavebase.GetWaveDataDir(), shellutil.WaveHomeBinDir) - quotedWshBinDir := utilfn.ShellQuote(wshBinDir, false, 300) - shellOpts = append(shellOpts, "-C", fmt.Sprintf("set -x PATH %s $PATH", quotedWshBinDir)) - } else if remote.IsPowershell(shellPath) { - shellOpts = append(shellOpts, "-ExecutionPolicy", "Bypass", "-NoExit", "-File", shellutil.GetWavePowershellEnv()) + shellOpts = append(shellOpts, "--rcfile", shellutil.GetLocalBashRcFileOverride()) + } else if shellType == ShellType_fish { + if cmdOpts.Login { + shellOpts = append(shellOpts, "-l") + } + waveFishPath := shellutil.GetLocalWaveFishFilePath() + carg := fmt.Sprintf("source %s", genconn.HardQuote(waveFishPath)) + shellOpts = append(shellOpts, "-C", carg) + } else if shellType == ShellType_pwsh { + shellOpts = append(shellOpts, "-ExecutionPolicy", "Bypass", "-NoExit", "-File", shellutil.GetLocalWavePowershellEnv()) } else { if cmdOpts.Login { shellOpts = append(shellOpts, "-l") - } else if cmdOpts.Interactive { + } + if cmdOpts.Interactive { shellOpts = append(shellOpts, "-i") } } ecmd = exec.Command(shellPath, shellOpts...) ecmd.Env = os.Environ() - if isZshShell(shellPath) { - shellutil.UpdateCmdEnv(ecmd, map[string]string{"ZDOTDIR": shellutil.GetZshZDotDir()}) + if shellType == ShellType_zsh { + shellutil.UpdateCmdEnv(ecmd, map[string]string{"ZDOTDIR": shellutil.GetLocalZshZDotDir()}) } } else { shellOpts = append(shellOpts, "-c", cmdStr) diff --git a/pkg/util/shellutil/shellutil.go b/pkg/util/shellutil/shellutil.go index 79f15e0ee1..f89568bcfe 100644 --- a/pkg/util/shellutil/shellutil.go +++ b/pkg/util/shellutil/shellutil.go @@ -17,6 +17,7 @@ import ( "sync" "time" + "github.com/wavetermdev/waveterm/pkg/genconn" "github.com/wavetermdev/waveterm/pkg/util/utilfn" "github.com/wavetermdev/waveterm/pkg/wavebase" "github.com/wavetermdev/waveterm/pkg/waveobj" @@ -33,9 +34,11 @@ var userShellRegexp = regexp.MustCompile(`^UserShell: (.*)$`) const DefaultShellPath = "/bin/bash" const ( + // there must be no spaces in these integration dir paths ZshIntegrationDir = "shell/zsh" BashIntegrationDir = "shell/bash" PwshIntegrationDir = "shell/pwsh" + FishIntegrationDir = "shell/fish" WaveHomeBinDir = "bin" ZshStartup_Zprofile = ` @@ -44,9 +47,12 @@ const ( ` ZshStartup_Zshrc = ` -# Source the original zshrc -[ -f ~/.zshrc ] && source ~/.zshrc +# Source the original zshrc only if ZDOTDIR has not been changed +if [ "$ZDOTDIR" = "$WAVETERM_ZDOTDIR" ]; then + [ -f ~/.zshrc ] && source ~/.zshrc +fi +# Custom additions export PATH={{.WSHBINDIR}}:$PATH if [[ -n ${_comps+x} ]]; then source <(wsh completion zsh) @@ -56,10 +62,26 @@ fi ZshStartup_Zlogin = ` # Source the original zlogin [ -f ~/.zlogin ] && source ~/.zlogin + +# Unset ZDOTDIR only if it hasn't been modified +if [ "$ZDOTDIR" = "$WAVETERM_ZDOTDIR" ]; then + unset ZDOTDIR +fi ` ZshStartup_Zshenv = ` +# Store the initial ZDOTDIR value +WAVETERM_ZDOTDIR="$ZDOTDIR" + +# Source the original zshenv [ -f ~/.zshenv ] && source ~/.zshenv + +# Detect if ZDOTDIR has changed +if [ "$ZDOTDIR" != "$WAVETERM_ZDOTDIR" ]; then + # If changed, manually source your custom zshrc from the original WAVETERM_ZDOTDIR + [ -f "$WAVETERM_ZDOTDIR/.zshrc" ] && source "$WAVETERM_ZDOTDIR/.zshrc" +fi + ` BashStartup_Bashrc = ` @@ -83,11 +105,22 @@ if type _init_completion &>/dev/null; then fi ` + + FishStartup_Wavefish = ` +# this file is sourced with -C +# Add Wave binary directory to PATH +set -x PATH {{.WSHBINDIR}} $PATH + +# Load Wave completions +wsh completion fish | source +` + PwshStartup_wavepwsh = ` -# no need to source regular profiles since we cannot -# overwrite those with powershell. Instead we will source -# this file with -NoExit -$env:PATH = "{{.WSHBINDIR}}" + "{{.PATHSEP}}" + $env:PATH +# We source this file with -NoExit -File +$env:PATH = {{.WSHBINDIR_PWSH}} + "{{.PATHSEP}}" + $env:PATH + +# Load Wave completions +wsh completion powershell | Out-String | Invoke-Expression ` ) @@ -207,19 +240,23 @@ func InitCustomShellStartupFiles() error { return err } -func GetBashRcFileOverride() string { +func GetLocalBashRcFileOverride() string { return filepath.Join(wavebase.GetWaveDataDir(), BashIntegrationDir, ".bashrc") } -func GetWavePowershellEnv() string { +func GetLocalWaveFishFilePath() string { + return filepath.Join(wavebase.GetWaveDataDir(), FishIntegrationDir, "wave.fish") +} + +func GetLocalWavePowershellEnv() string { return filepath.Join(wavebase.GetWaveDataDir(), PwshIntegrationDir, "wavepwsh.ps1") } -func GetZshZDotDir() string { +func GetLocalZshZDotDir() string { return filepath.Join(wavebase.GetWaveDataDir(), ZshIntegrationDir) } -func GetWshBinaryPath(version string, goos string, goarch string) (string, error) { +func GetLocalWshBinaryPath(version string, goos string, goarch string) (string, error) { ext := "" if goarch == "amd64" { goarch = "x64" @@ -237,8 +274,10 @@ func GetWshBinaryPath(version string, goos string, goarch string) (string, error return filepath.Join(wavebase.GetWaveAppBinPath(), baseName), nil } -func InitRcFiles(waveHome string, wshBinDir string) error { - // ensure directiries exist +// absWshBinDir must be an absolute, expanded path (no ~ or $HOME, etc.) +// it will be hard-quoted appropriately for the shell +func InitRcFiles(waveHome string, absWshBinDir string) error { + // ensure directories exist zshDir := filepath.Join(waveHome, ZshIntegrationDir) err := wavebase.CacheEnsureDir(zshDir, ZshIntegrationDir, 0755, ZshIntegrationDir) if err != nil { @@ -249,43 +288,55 @@ func InitRcFiles(waveHome string, wshBinDir string) error { if err != nil { return err } + fishDir := filepath.Join(waveHome, FishIntegrationDir) + err = wavebase.CacheEnsureDir(fishDir, FishIntegrationDir, 0755, FishIntegrationDir) + if err != nil { + return err + } pwshDir := filepath.Join(waveHome, PwshIntegrationDir) err = wavebase.CacheEnsureDir(pwshDir, PwshIntegrationDir, 0755, PwshIntegrationDir) if err != nil { return err } + var pathSep string + if runtime.GOOS == "windows" { + pathSep = ";" + } else { + pathSep = ":" + } + params := map[string]string{ + "WSHBINDIR": genconn.HardQuote(absWshBinDir), + "WSHBINDIR_PWSH": genconn.HardQuotePowerShell(absWshBinDir), + "PATHSEP": pathSep, + } + // write files to directory - zprofilePath := filepath.Join(zshDir, ".zprofile") - err = os.WriteFile(zprofilePath, []byte(ZshStartup_Zprofile), 0644) + err = utilfn.WriteTemplateToFile(filepath.Join(zshDir, ".zprofile"), ZshStartup_Zprofile, params) if err != nil { return fmt.Errorf("error writing zsh-integration .zprofile: %v", err) } - err = utilfn.WriteTemplateToFile(filepath.Join(zshDir, ".zshrc"), ZshStartup_Zshrc, map[string]string{"WSHBINDIR": fmt.Sprintf(`"%s"`, wshBinDir)}) + err = utilfn.WriteTemplateToFile(filepath.Join(zshDir, ".zshrc"), ZshStartup_Zshrc, params) if err != nil { return fmt.Errorf("error writing zsh-integration .zshrc: %v", err) } - zloginPath := filepath.Join(zshDir, ".zlogin") - err = os.WriteFile(zloginPath, []byte(ZshStartup_Zlogin), 0644) + err = utilfn.WriteTemplateToFile(filepath.Join(zshDir, ".zlogin"), ZshStartup_Zlogin, params) if err != nil { return fmt.Errorf("error writing zsh-integration .zlogin: %v", err) } - zshenvPath := filepath.Join(zshDir, ".zshenv") - err = os.WriteFile(zshenvPath, []byte(ZshStartup_Zshenv), 0644) + err = utilfn.WriteTemplateToFile(filepath.Join(zshDir, ".zshenv"), ZshStartup_Zshenv, params) if err != nil { return fmt.Errorf("error writing zsh-integration .zshenv: %v", err) } - err = utilfn.WriteTemplateToFile(filepath.Join(bashDir, ".bashrc"), BashStartup_Bashrc, map[string]string{"WSHBINDIR": fmt.Sprintf(`"%s"`, wshBinDir)}) + err = utilfn.WriteTemplateToFile(filepath.Join(bashDir, ".bashrc"), BashStartup_Bashrc, params) if err != nil { return fmt.Errorf("error writing bash-integration .bashrc: %v", err) } - var pathSep string - if runtime.GOOS == "windows" { - pathSep = ";" - } else { - pathSep = ":" + err = utilfn.WriteTemplateToFile(filepath.Join(fishDir, "wave.fish"), FishStartup_Wavefish, params) + if err != nil { + return fmt.Errorf("error writing fish-integration wave.fish: %v", err) } - err = utilfn.WriteTemplateToFile(filepath.Join(pwshDir, "wavepwsh.ps1"), PwshStartup_wavepwsh, map[string]string{"WSHBINDIR": toPwshEnvVarRef(wshBinDir), "PATHSEP": pathSep}) + err = utilfn.WriteTemplateToFile(filepath.Join(pwshDir, "wavepwsh.ps1"), PwshStartup_wavepwsh, params) if err != nil { return fmt.Errorf("error writing pwsh-integration wavepwsh.ps1: %v", err) } @@ -297,7 +348,7 @@ func initCustomShellStartupFilesInternal() error { log.Printf("initializing wsh and shell startup files\n") waveDataHome := wavebase.GetWaveDataDir() binDir := filepath.Join(waveDataHome, WaveHomeBinDir) - err := InitRcFiles(waveDataHome, `$WAVETERM_WSHBINDIR`) + err := InitRcFiles(waveDataHome, binDir) if err != nil { return err } @@ -308,7 +359,7 @@ func initCustomShellStartupFilesInternal() error { } // copy the correct binary to bin - wshFullPath, err := GetWshBinaryPath(wavebase.WaveVersion, runtime.GOOS, runtime.GOARCH) + wshFullPath, err := GetLocalWshBinaryPath(wavebase.WaveVersion, runtime.GOOS, runtime.GOARCH) if err != nil { log.Printf("error (non-fatal), could not resolve wsh binary path: %v\n", err) } @@ -328,7 +379,3 @@ func initCustomShellStartupFilesInternal() error { log.Printf("wsh binary successfully copied from %q to %q\n", wshBaseName, wshDstPath) return nil } - -func toPwshEnvVarRef(input string) string { - return strings.Replace(input, "$", "$env:", -1) -} diff --git a/pkg/wshrpc/wshrpctypes.go b/pkg/wshrpc/wshrpctypes.go index b2f7d02dc3..28c69e20a0 100644 --- a/pkg/wshrpc/wshrpctypes.go +++ b/pkg/wshrpc/wshrpctypes.go @@ -475,6 +475,7 @@ type ConnKeywords struct { ConnAskBeforeWshInstall *bool `json:"conn:askbeforewshinstall,omitempty"` ConnOverrideConfig bool `json:"conn:overrideconfig,omitempty"` ConnWshPath string `json:"conn:wshpath,omitempty"` + ConnShellPath string `json:"conn:shellpath,omitempty"` DisplayHidden *bool `json:"display:hidden,omitempty"` DisplayOrder float32 `json:"display:order,omitempty"` diff --git a/pkg/wshrpc/wshserver/wshserver.go b/pkg/wshrpc/wshserver/wshserver.go index cb7d9ff646..7473540fea 100644 --- a/pkg/wshrpc/wshserver/wshserver.go +++ b/pkg/wshrpc/wshserver/wshserver.go @@ -696,7 +696,7 @@ func (ws *WshServer) ConnReinstallWshCommand(ctx context.Context, data wshrpc.Co if conn == nil { return fmt.Errorf("connection not found: %s", connName) } - return conn.InstallWsh(ctx) + return conn.InstallWsh(ctx, "") } func (ws *WshServer) ConnUpdateWshCommand(ctx context.Context, remoteInfo wshrpc.RemoteInfo) (bool, error) { @@ -710,7 +710,7 @@ func (ws *WshServer) ConnUpdateWshCommand(ctx context.Context, remoteInfo wshrpc } log.Printf("checking wsh version for connection %s (current: %s)", connName, remoteInfo.ClientVersion) - upToDate, _, err := conncontroller.IsWshVersionUpToDate(remoteInfo.ClientVersion) + upToDate, _, _, err := conncontroller.IsWshVersionUpToDate(remoteInfo.ClientVersion) if err != nil { return false, fmt.Errorf("unable to compare wsh version: %w", err) } diff --git a/pkg/wshutil/wshutil.go b/pkg/wshutil/wshutil.go index 77342fcd6d..6075a50dd8 100644 --- a/pkg/wshutil/wshutil.go +++ b/pkg/wshutil/wshutil.go @@ -566,6 +566,6 @@ func GetInfo() wshrpc.RemoteInfo { func InstallRcFiles() error { home := wavebase.GetHomeDir() waveDir := filepath.Join(home, wavebase.RemoteWaveHomeDirName) - winBinDir := filepath.Join(waveDir, wavebase.RemoteWshBinDirName) - return shellutil.InitRcFiles(waveDir, winBinDir) + wshBinDir := filepath.Join(waveDir, wavebase.RemoteWshBinDirName) + return shellutil.InitRcFiles(waveDir, wshBinDir) } diff --git a/pkg/wsl/wsl.go b/pkg/wsl/wsl.go index 47c6b3cf1a..ab40fe54bb 100644 --- a/pkg/wsl/wsl.go +++ b/pkg/wsl/wsl.go @@ -337,7 +337,7 @@ func (conn *WslConn) CheckAndInstallWsh(ctx context.Context, clientDisplayName s return err } // attempt to install extension - wshLocalPath, err := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch) + wshLocalPath, err := shellutil.GetLocalWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch) if err != nil { return err }