diff --git a/pkg/shellexec/shellexec.go b/pkg/shellexec/shellexec.go index 564b6c5cf7..d55af25956 100644 --- a/pkg/shellexec/shellexec.go +++ b/pkg/shellexec/shellexec.go @@ -11,7 +11,6 @@ import ( "log" "os" "os/exec" - "path/filepath" "runtime" "strings" "sync" @@ -43,6 +42,7 @@ type CommandOptsType struct { ShellPath string `json:"shellPath,omitempty"` ShellOpts []string `json:"shellOpts,omitempty"` SwapToken *shellutil.TokenSwapEntry `json:"swapToken,omitempty"` + Env map[string]string `json:"env,omitempty"` } type ShellProc struct { @@ -116,6 +116,17 @@ func checkCwd(cwd string) error { return nil } +func makeEnvPrefix(env map[string]string) string { + if len(env) == 0 { + return "" + } + var envParts []string + for key, value := range env { + envParts = append(envParts, fmt.Sprintf(`%s=%s`, key, shellutil.HardQuote(value))) + } + return strings.Join(envParts, " ") + " " +} + type PipePty struct { remoteStdinWrite *os.File remoteStdoutRead *os.File @@ -270,6 +281,8 @@ func StartWslShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr st } log.Printf("full combined command: %s", cmdCombined) ecmd := exec.Command("wsl.exe", "~", "-d", client.Name(), "--", "sh", "-c", cmdCombined) + ecmd.Env = os.Environ() + shellutil.UpdateCmdEnv(ecmd, cmdOpts.Env) if termSize.Rows == 0 || termSize.Cols == 0 { termSize.Rows = shellutil.DefaultTermRows termSize.Cols = shellutil.DefaultTermCols @@ -442,6 +455,9 @@ func StartRemoteShellProc(ctx context.Context, logCtx context.Context, termSize conn.Debugf(logCtx, "packed swaptoken %s\n", packedToken) cmdCombined = fmt.Sprintf(`%s=%s %s`, wavebase.WaveSwapTokenVarName, packedToken, cmdCombined) } + envPrefix := makeEnvPrefix(cmdOpts.Env) + cmdCombined = fmt.Sprintf("%s%s", envPrefix, cmdCombined) + shellutil.AddTokenSwapEntry(cmdOpts.SwapToken) session.RequestPty("xterm-256color", termSize.Rows, termSize.Cols, nil) sessionWrap := MakeSessionWrap(session, cmdCombined, pipePty) @@ -453,24 +469,6 @@ func StartRemoteShellProc(ctx context.Context, logCtx context.Context, termSize return &ShellProc{Cmd: sessionWrap, ConnName: conn.GetName(), CloseOnce: &sync.Once{}, DoneCh: make(chan any)}, nil } -func isZshShell(shellPath string) bool { - // get the base path, and then check contains - shellBase := filepath.Base(shellPath) - return strings.Contains(shellBase, "zsh") -} - -func isBashShell(shellPath string) bool { - // get the base path, and then check contains - shellBase := filepath.Base(shellPath) - return strings.Contains(shellBase, "bash") -} - -func isFishShell(shellPath string) bool { - // get the base path, and then check contains - shellBase := filepath.Base(shellPath) - return strings.Contains(shellBase, "fish") -} - func StartLocalShellProc(logCtx context.Context, termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType) (*ShellProc, error) { shellutil.InitCustomShellStartupFiles() var ecmd *exec.Cmd @@ -506,6 +504,7 @@ func StartLocalShellProc(logCtx context.Context, termSize waveobj.TermSize, cmdS blocklogger.Debugf(logCtx, "[conndebug] shell:%s shellOpts:%v\n", shellPath, shellOpts) ecmd = exec.Command(shellPath, shellOpts...) ecmd.Env = os.Environ() + shellutil.UpdateCmdEnv(ecmd, cmdOpts.Env) if shellType == shellutil.ShellType_zsh { shellutil.UpdateCmdEnv(ecmd, map[string]string{"ZDOTDIR": shellutil.GetLocalZshZDotDir()}) } @@ -513,6 +512,7 @@ func StartLocalShellProc(logCtx context.Context, termSize waveobj.TermSize, cmdS shellOpts = append(shellOpts, "-c", cmdStr) ecmd = exec.Command(shellPath, shellOpts...) ecmd.Env = os.Environ() + shellutil.UpdateCmdEnv(ecmd, cmdOpts.Env) } packedToken, err := cmdOpts.SwapToken.PackForClient()