diff --git a/cmd/wsh/cmd/wshcmd-run.go b/cmd/wsh/cmd/wshcmd-run.go index ccf1a90a64..ae7099ad82 100644 --- a/cmd/wsh/cmd/wshcmd-run.go +++ b/cmd/wsh/cmd/wshcmd-run.go @@ -11,6 +11,7 @@ import ( "github.com/spf13/cobra" "github.com/wavetermdev/waveterm/pkg/util/envutil" + "github.com/wavetermdev/waveterm/pkg/wavebase" "github.com/wavetermdev/waveterm/pkg/waveobj" "github.com/wavetermdev/waveterm/pkg/wshrpc" "github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient" @@ -136,7 +137,7 @@ func runRun(cmd *cobra.Command, args []string) (rtnErr error) { BlockDef: &waveobj.BlockDef{ Meta: createMeta, Files: map[string]*waveobj.FileDef{ - "env": { + wavebase.BlockFile_Env: { Content: envContent, }, }, diff --git a/frontend/types/gotypes.d.ts b/frontend/types/gotypes.d.ts index 1f0ca441ea..e90db47c2f 100644 --- a/frontend/types/gotypes.d.ts +++ b/frontend/types/gotypes.d.ts @@ -499,11 +499,18 @@ declare global { "cmd:closeonexit"?: boolean; "cmd:closeonexitforce"?: boolean; "cmd:closeonexitdelay"?: number; - "cmd:env"?: {[key: string]: string}; - "cmd:cwd"?: string; "cmd:nowsh"?: boolean; "cmd:args"?: string[]; "cmd:shell"?: boolean; + "cmd:allowconnchange"?: boolean; + "cmd:env"?: {[key: string]: string}; + "cmd:cwd"?: string; + "cmd:initscript"?: string; + "cmd:initscript.sh"?: string; + "cmd:initscript.bash"?: string; + "cmd:initscript.zsh"?: string; + "cmd:initscript.pwsh"?: string; + "cmd:initscript.fish"?: string; "ai:*"?: boolean; "ai:preset"?: string; "ai:apitype"?: string; diff --git a/pkg/blockcontroller/blockcontroller.go b/pkg/blockcontroller/blockcontroller.go index 8e42e63bb0..6787078a78 100644 --- a/pkg/blockcontroller/blockcontroller.go +++ b/pkg/blockcontroller/blockcontroller.go @@ -43,9 +43,9 @@ const ( ) const ( - BlockFile_Term = "term" // used for main pty output - BlockFile_Cache = "cache:term:full" // for cached block - BlockFile_VDom = "vdom" // used for alt html layout + ConnType_Local = "local" + ConnType_Wsl = "wsl" + ConnType_Ssh = "ssh" ) const ( @@ -146,14 +146,14 @@ func (bc *BlockController) UpdateControllerAndSendUpdate(updateFn func() bool) { func HandleTruncateBlockFile(blockId string) error { ctx, cancelFn := context.WithTimeout(context.Background(), DefaultTimeout) defer cancelFn() - err := filestore.WFS.WriteFile(ctx, blockId, BlockFile_Term, nil) + err := filestore.WFS.WriteFile(ctx, blockId, wavebase.BlockFile_Term, nil) if err == fs.ErrNotExist { return nil } if err != nil { return fmt.Errorf("error truncating blockfile: %w", err) } - err = filestore.WFS.DeleteFile(ctx, blockId, BlockFile_Cache) + err = filestore.WFS.DeleteFile(ctx, blockId, wavebase.BlockFile_Cache) if err == fs.ErrNotExist { err = nil } @@ -165,7 +165,7 @@ func HandleTruncateBlockFile(blockId string) error { Scopes: []string{waveobj.MakeORef(waveobj.OType_Block, blockId).String()}, Data: &wps.WSFileEventData{ ZoneId: blockId, - FileName: BlockFile_Term, + FileName: wavebase.BlockFile_Term, FileOp: wps.FileOp_Truncate, }, }) @@ -195,87 +195,126 @@ func HandleAppendBlockFile(blockId string, blockFile string, data []byte) error return nil } -func (bc *BlockController) resetTerminalState() { +func (bc *BlockController) resetTerminalState(logCtx context.Context) { ctx, cancelFn := context.WithTimeout(context.Background(), DefaultTimeout) defer cancelFn() - wfile, statErr := filestore.WFS.Stat(ctx, bc.BlockId, BlockFile_Term) + wfile, statErr := filestore.WFS.Stat(ctx, bc.BlockId, wavebase.BlockFile_Term) if statErr == fs.ErrNotExist || wfile.Size == 0 { return } + blocklogger.Debugf(logCtx, "[conndebug] resetTerminalState: resetting terminal state\n") // controller type = "shell" var buf bytes.Buffer // buf.WriteString("\x1b[?1049l") // disable alternative buffer buf.WriteString("\x1b[0m") // reset attributes buf.WriteString("\x1b[?25h") // show cursor buf.WriteString("\x1b[?1000l") // disable mouse tracking - buf.WriteString("\r\n\r\n(restored terminal state)\r\n\r\n") - err := filestore.WFS.AppendData(ctx, bc.BlockId, BlockFile_Term, buf.Bytes()) + buf.WriteString("\r\n\r\n") + err := HandleAppendBlockFile(bc.BlockId, wavebase.BlockFile_Term, buf.Bytes()) if err != nil { log.Printf("error appending to blockfile (terminal reset): %v\n", err) } } -// for "cmd" type blocks -func createCmdStrAndOpts(blockId string, blockMeta waveobj.MetaMapType) (string, *shellexec.CommandOptsType, error) { - var cmdStr string - var cmdOpts shellexec.CommandOptsType - cmdOpts.Env = make(map[string]string) - cmdStr = blockMeta.GetString(waveobj.MetaKey_Cmd, "") - if cmdStr == "" { - return "", nil, fmt.Errorf("missing cmd in block meta") +func getCustomInitScriptKeyCascade(shellType string) []string { + if shellType == "bash" { + return []string{waveobj.MetaKey_CmdInitScriptBash, waveobj.MetaKey_CmdInitScriptSh, waveobj.MetaKey_CmdInitScript} } - cmdOpts.Cwd = blockMeta.GetString(waveobj.MetaKey_CmdCwd, "") - if cmdOpts.Cwd != "" { - cwdPath, err := wavebase.ExpandHomeDir(cmdOpts.Cwd) - if err != nil { - return "", nil, err - } - cmdOpts.Cwd = cwdPath + if shellType == "zsh" { + return []string{waveobj.MetaKey_CmdInitScriptZsh, waveobj.MetaKey_CmdInitScriptSh, waveobj.MetaKey_CmdInitScript} } - useShell := blockMeta.GetBool(waveobj.MetaKey_CmdShell, true) - if !useShell { - if strings.Contains(cmdStr, " ") { - return "", nil, fmt.Errorf("cmd should not have spaces if cmd:shell is false (use cmd:args)") + if shellType == "pwsh" { + return []string{waveobj.MetaKey_CmdInitScriptPwsh, waveobj.MetaKey_CmdInitScript} + } + if shellType == "fish" { + return []string{waveobj.MetaKey_CmdInitScriptFish, waveobj.MetaKey_CmdInitScript} + } + return []string{waveobj.MetaKey_CmdInitScript} +} + +func getCustomInitScript(meta waveobj.MetaMapType, connName string, shellType string) string { + keys := getCustomInitScriptKeyCascade(shellType) + connMeta := meta.GetConnectionOverride(connName) + if connMeta != nil { + for _, key := range keys { + if connMeta.HasKey(key) { + return connMeta.GetString(key, "") + } } - cmdArgs := blockMeta.GetStringList(waveobj.MetaKey_CmdArgs) - // shell escape the args - for _, arg := range cmdArgs { - cmdStr = cmdStr + " " + utilfn.ShellQuote(arg, false, -1) + } + for _, key := range keys { + if meta.HasKey(key) { + return meta.GetString(key, "") } } + return "" +} - // get the "env" file +func resolveEnvMap(blockId string, blockMeta waveobj.MetaMapType, connName string) (map[string]string, error) { ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second) defer cancelFn() - _, envFileData, err := filestore.WFS.ReadFile(ctx, blockId, "env") + _, envFileData, err := filestore.WFS.ReadFile(ctx, blockId, wavebase.BlockFile_Env) if err == fs.ErrNotExist { err = nil } if err != nil { - return "", nil, fmt.Errorf("error reading command env file: %w", err) + return nil, fmt.Errorf("error reading command env file: %w", err) } + rtn := make(map[string]string) if len(envFileData) > 0 { envMap := envutil.EnvToMap(string(envFileData)) for k, v := range envMap { - cmdOpts.Env[k] = v + rtn[k] = v } } cmdEnv := blockMeta.GetMap(waveobj.MetaKey_CmdEnv) for k, v := range cmdEnv { if v == nil { + delete(rtn, k) continue } - if _, ok := v.(string); ok { - cmdOpts.Env[k] = v.(string) + if strVal, ok := v.(string); ok { + rtn[k] = strVal } - if _, ok := v.(float64); ok { - cmdOpts.Env[k] = fmt.Sprintf("%v", v) + if floatVal, ok := v.(float64); ok { + rtn[k] = fmt.Sprintf("%v", floatVal) + } + } + return rtn, nil +} + +// for "cmd" type blocks +func createCmdStrAndOpts(blockId string, blockMeta waveobj.MetaMapType, connName string) (string, *shellexec.CommandOptsType, error) { + var cmdStr string + var cmdOpts shellexec.CommandOptsType + cmdStr = blockMeta.GetString(waveobj.MetaKey_Cmd, "") + if cmdStr == "" { + return "", nil, fmt.Errorf("missing cmd in block meta") + } + cmdOpts.Cwd = blockMeta.GetString(waveobj.MetaKey_CmdCwd, "") + if cmdOpts.Cwd != "" { + cwdPath, err := wavebase.ExpandHomeDir(cmdOpts.Cwd) + if err != nil { + return "", nil, err + } + cmdOpts.Cwd = cwdPath + } + useShell := blockMeta.GetBool(waveobj.MetaKey_CmdShell, true) + if !useShell { + if strings.Contains(cmdStr, " ") { + return "", nil, fmt.Errorf("cmd should not have spaces if cmd:shell is false (use cmd:args)") + } + cmdArgs := blockMeta.GetStringList(waveobj.MetaKey_CmdArgs) + // shell escape the args + for _, arg := range cmdArgs { + cmdStr = cmdStr + " " + utilfn.ShellQuote(arg, false, -1) } } return cmdStr, &cmdOpts, nil } func (bc *BlockController) DoRunShellCommand(logCtx context.Context, rc *RunShellOpts, blockMeta waveobj.MetaMapType) error { + blocklogger.Debugf(logCtx, "[conndebug] DoRunShellCommand\n") shellProc, err := bc.setupAndStartShellProcess(logCtx, rc, blockMeta) if err != nil { return err @@ -283,7 +322,7 @@ func (bc *BlockController) DoRunShellCommand(logCtx context.Context, rc *RunShel return bc.manageRunningShellProcess(shellProc, rc, blockMeta) } -func (bc *BlockController) makeSwapToken(ctx context.Context, remoteName string) *shellutil.TokenSwapEntry { +func (bc *BlockController) makeSwapToken(ctx context.Context, blockMeta waveobj.MetaMapType, remoteName string, shellType string) *shellutil.TokenSwapEntry { token := &shellutil.TokenSwapEntry{ Token: uuid.New().String(), Env: make(map[string]string), @@ -314,20 +353,126 @@ func (bc *BlockController) makeSwapToken(ctx context.Context, remoteName string) token.Env["WAVETERM_CLIENTID"] = clientData.OID } token.Env["WAVETERM_CONN"] = remoteName + envMap, err := resolveEnvMap(bc.BlockId, blockMeta, remoteName) + if err != nil { + log.Printf("error resolving env map: %v\n", err) + } + for k, v := range envMap { + token.Env[k] = v + } + token.ScriptText = getCustomInitScript(blockMeta, remoteName, shellType) return token } +type ConnUnion struct { + ConnName string + ConnType string + SshConn *conncontroller.SSHConn + WslConn *wslconn.WslConn + WshEnabled bool + ShellPath string + ShellOpts []string + ShellType string +} + +func getLocalShellPath(blockMeta waveobj.MetaMapType) string { + shellPath := blockMeta.GetString(waveobj.MetaKey_TermLocalShellPath, "") + if shellPath != "" { + return shellPath + } + settings := wconfig.GetWatcher().GetFullConfig().Settings + if settings.TermLocalShellPath != "" { + return settings.TermLocalShellPath + } + return shellutil.DetectLocalShellPath() +} + +func getLocalShellOpts(blockMeta waveobj.MetaMapType) []string { + if blockMeta.HasKey(waveobj.MetaKey_TermLocalShellOpts) { + opts := blockMeta.GetStringList(waveobj.MetaKey_TermLocalShellOpts) + return append([]string{}, opts...) + } + settings := wconfig.GetWatcher().GetFullConfig().Settings + if len(settings.TermLocalShellOpts) > 0 { + return append([]string{}, settings.TermLocalShellOpts...) + } + return nil +} + +func (union *ConnUnion) getRemoteInfoAndShellType(blockMeta waveobj.MetaMapType) error { + if !union.WshEnabled { + return nil + } + if union.ConnType == ConnType_Ssh || union.ConnType == ConnType_Wsl { + connRoute := wshutil.MakeConnectionRouteId(union.ConnName) + remoteInfo, err := wshclient.RemoteGetInfoCommand(wshclient.GetBareRpcClient(), &wshrpc.RpcOpts{Route: connRoute, Timeout: 2000}) + if err != nil { + // weird error, could flip the wshEnabled flag and allow it to go forward, but the connection should have already been vetted + return fmt.Errorf("unable to obtain remote info from connserver: %w", err) + } + // TODO allow overriding remote shell path + union.ShellPath = remoteInfo.Shell + } else { + union.ShellPath = getLocalShellPath(blockMeta) + } + union.ShellType = shellutil.GetShellTypeFromShellPath(union.ShellPath) + return nil +} + +func (bc *BlockController) getConnUnion(logCtx context.Context, remoteName string, blockMeta waveobj.MetaMapType) (ConnUnion, error) { + rtn := ConnUnion{ConnName: remoteName} + wshEnabled := !blockMeta.GetBool(waveobj.MetaKey_CmdNoWsh, false) + if strings.HasPrefix(remoteName, "wsl://") { + wslName := strings.TrimPrefix(remoteName, "wsl://") + wslConn := wslconn.GetWslConn(wslName) + if wslConn == nil { + return ConnUnion{}, fmt.Errorf("wsl connection not found: %s", remoteName) + } + connStatus := wslConn.DeriveConnStatus() + if connStatus.Status != conncontroller.Status_Connected { + return ConnUnion{}, fmt.Errorf("wsl connection %s not connected, cannot start shellproc", remoteName) + } + rtn.ConnType = ConnType_Wsl + rtn.WslConn = wslConn + rtn.WshEnabled = wshEnabled && wslConn.WshEnabled.Load() + } else if remoteName != "" { + opts, err := remote.ParseOpts(remoteName) + if err != nil { + return ConnUnion{}, fmt.Errorf("invalid ssh remote name (%s): %w", remoteName, err) + } + conn := conncontroller.GetConn(opts) + if conn == nil { + return ConnUnion{}, fmt.Errorf("ssh connection not found: %s", remoteName) + } + connStatus := conn.DeriveConnStatus() + if connStatus.Status != conncontroller.Status_Connected { + return ConnUnion{}, fmt.Errorf("ssh connection %s not connected, cannot start shellproc", remoteName) + } + rtn.ConnType = ConnType_Ssh + rtn.SshConn = conn + rtn.WshEnabled = wshEnabled && conn.WshEnabled.Load() + } else { + rtn.ConnType = ConnType_Local + rtn.WshEnabled = wshEnabled + } + err := rtn.getRemoteInfoAndShellType(blockMeta) + if err != nil { + return ConnUnion{}, err + } + return rtn, nil +} + func (bc *BlockController) setupAndStartShellProcess(logCtx context.Context, rc *RunShellOpts, blockMeta waveobj.MetaMapType) (*shellexec.ShellProc, error) { // create a circular blockfile for the output ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second) defer cancelFn() - fsErr := filestore.WFS.MakeFile(ctx, bc.BlockId, BlockFile_Term, nil, wshrpc.FileOpts{MaxSize: DefaultTermMaxFileSize, Circular: true}) + fsErr := filestore.WFS.MakeFile(ctx, bc.BlockId, wavebase.BlockFile_Term, nil, wshrpc.FileOpts{MaxSize: DefaultTermMaxFileSize, Circular: true}) if fsErr != nil && fsErr != fs.ErrExist { return nil, fmt.Errorf("error creating blockfile: %w", fsErr) } if fsErr == fs.ErrExist { // reset the terminal state - bc.resetTerminalState() + bc.resetTerminalState(logCtx) } bcInitStatus := bc.GetRuntimeStatus() if bcInitStatus.ShellProcStatus == Status_Running { @@ -335,11 +480,14 @@ func (bc *BlockController) setupAndStartShellProcess(logCtx context.Context, rc } // TODO better sync here (don't let two starts happen at the same times) remoteName := blockMeta.GetString(waveobj.MetaKey_Connection, "") + connUnion, err := bc.getConnUnion(logCtx, remoteName, blockMeta) + if err != nil { + return nil, err + } + blocklogger.Infof(logCtx, "[conndebug] remoteName: %q, connType: %s, wshEnabled: %v, shell: %q, shellType: %s\n", remoteName, connUnion.ConnType, connUnion.WshEnabled, connUnion.ShellPath, connUnion.ShellType) var cmdStr string var cmdOpts shellexec.CommandOptsType - var err error if bc.ControllerType == BlockController_Shell { - cmdOpts.Env = make(map[string]string) cmdOpts.Interactive = true cmdOpts.Login = true cmdOpts.Cwd = blockMeta.GetString(waveobj.MetaKey_CmdCwd, "") @@ -352,7 +500,7 @@ func (bc *BlockController) setupAndStartShellProcess(logCtx context.Context, rc } } else if bc.ControllerType == BlockController_Cmd { var cmdOptsPtr *shellexec.CommandOptsType - cmdStr, cmdOptsPtr, err = createCmdStrAndOpts(bc.BlockId, blockMeta) + cmdStr, cmdOptsPtr, err = createCmdStrAndOpts(bc.BlockId, blockMeta, remoteName) if err != nil { return nil, err } @@ -361,22 +509,17 @@ func (bc *BlockController) setupAndStartShellProcess(logCtx context.Context, rc return nil, fmt.Errorf("unknown controller type %q", bc.ControllerType) } var shellProc *shellexec.ShellProc - swapToken := bc.makeSwapToken(ctx, remoteName) + swapToken := bc.makeSwapToken(ctx, blockMeta, remoteName, connUnion.ShellType) cmdOpts.SwapToken = swapToken blocklogger.Infof(logCtx, "[conndebug] created swaptoken: %s\n", swapToken.Token) - if strings.HasPrefix(remoteName, "wsl://") { - wslName := strings.TrimPrefix(remoteName, "wsl://") - credentialCtx, cancelFunc := context.WithTimeout(context.Background(), 60*time.Second) - defer cancelFunc() - - wslConn := wslconn.GetWslConn(credentialCtx, wslName, false) - connStatus := wslConn.DeriveConnStatus() - if connStatus.Status != conncontroller.Status_Connected { - return nil, fmt.Errorf("not connected, cannot start shellproc") - } - - // create jwt - if !blockMeta.GetBool(waveobj.MetaKey_CmdNoWsh, false) { + if connUnion.ConnType == ConnType_Wsl { + wslConn := connUnion.WslConn + if !connUnion.WshEnabled { + shellProc, err = shellexec.StartWslShellProcNoWsh(ctx, rc.TermSize, cmdStr, cmdOpts, wslConn) + if err != nil { + return nil, err + } + } else { sockName := wslConn.GetDomainSocketName() rpcContext := wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId, Conn: wslConn.GetName()} jwtStr, err := wshutil.MakeClientJWTToken(rpcContext, sockName) @@ -386,14 +529,6 @@ func (bc *BlockController) setupAndStartShellProcess(logCtx context.Context, rc swapToken.SockName = sockName swapToken.RpcContext = &rpcContext swapToken.Env[wshutil.WaveJwtTokenVarName] = jwtStr - cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr - } - if !wslConn.WshEnabled.Load() { - shellProc, err = shellexec.StartWslShellProcNoWsh(ctx, rc.TermSize, cmdStr, cmdOpts, wslConn) - if err != nil { - return nil, err - } - } else { shellProc, err = shellexec.StartWslShellProc(ctx, rc.TermSize, cmdStr, cmdOpts, wslConn) if err != nil { wslConn.SetWshError(err) @@ -406,20 +541,14 @@ func (bc *BlockController) setupAndStartShellProcess(logCtx context.Context, rc } } } - } else if remoteName != "" { - credentialCtx, cancelFunc := context.WithTimeout(context.Background(), 60*time.Second) - defer cancelFunc() - - opts, err := remote.ParseOpts(remoteName) - if err != nil { - return nil, err - } - conn := conncontroller.GetConn(credentialCtx, opts, &wconfig.ConnKeywords{}) - connStatus := conn.DeriveConnStatus() - if connStatus.Status != conncontroller.Status_Connected { - return nil, fmt.Errorf("not connected, cannot start shellproc") - } - if !blockMeta.GetBool(waveobj.MetaKey_CmdNoWsh, false) { + } else if connUnion.ConnType == ConnType_Ssh { + conn := connUnion.SshConn + if !connUnion.WshEnabled { + shellProc, err = shellexec.StartRemoteShellProcNoWsh(ctx, rc.TermSize, cmdStr, cmdOpts, conn) + if err != nil { + return nil, err + } + } else { sockName := conn.GetDomainSocketName() rpcContext := wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId, Conn: conn.Opts.String()} jwtStr, err := wshutil.MakeClientJWTToken(rpcContext, sockName) @@ -429,14 +558,6 @@ func (bc *BlockController) setupAndStartShellProcess(logCtx context.Context, rc swapToken.SockName = sockName swapToken.RpcContext = &rpcContext swapToken.Env[wshutil.WaveJwtTokenVarName] = jwtStr - cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr - } - if !conn.WshEnabled.Load() { - shellProc, err = shellexec.StartRemoteShellProcNoWsh(ctx, rc.TermSize, cmdStr, cmdOpts, conn) - if err != nil { - return nil, err - } - } else { shellProc, err = shellexec.StartRemoteShellProc(ctx, logCtx, rc.TermSize, cmdStr, cmdOpts, conn) if err != nil { conn.SetWshError(err) @@ -449,9 +570,8 @@ func (bc *BlockController) setupAndStartShellProcess(logCtx context.Context, rc } } } - } else { - // local terminal - if !blockMeta.GetBool(waveobj.MetaKey_CmdNoWsh, false) { + } else if connUnion.ConnType == ConnType_Local { + if connUnion.WshEnabled { sockName := wavebase.GetDomainSocketName() rpcContext := wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId} jwtStr, err := wshutil.MakeClientJWTToken(rpcContext, sockName) @@ -461,25 +581,15 @@ func (bc *BlockController) setupAndStartShellProcess(logCtx context.Context, rc swapToken.SockName = sockName swapToken.RpcContext = &rpcContext swapToken.Env[wshutil.WaveJwtTokenVarName] = jwtStr - cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr - } - settings := wconfig.GetWatcher().GetFullConfig().Settings - if settings.TermLocalShellPath != "" { - cmdOpts.ShellPath = settings.TermLocalShellPath - } - if blockMeta.GetString(waveobj.MetaKey_TermLocalShellPath, "") != "" { - cmdOpts.ShellPath = blockMeta.GetString(waveobj.MetaKey_TermLocalShellPath, "") - } - if len(settings.TermLocalShellOpts) > 0 { - cmdOpts.ShellOpts = append([]string{}, settings.TermLocalShellOpts...) - } - if len(blockMeta.GetStringList(waveobj.MetaKey_TermLocalShellOpts)) > 0 { - cmdOpts.ShellOpts = append([]string{}, blockMeta.GetStringList(waveobj.MetaKey_TermLocalShellOpts)...) } + cmdOpts.ShellPath = connUnion.ShellPath + cmdOpts.ShellOpts = getLocalShellOpts(blockMeta) shellProc, err = shellexec.StartLocalShellProc(logCtx, rc.TermSize, cmdStr, cmdOpts) if err != nil { return nil, err } + } else { + return nil, fmt.Errorf("unknown connection type for conn %q: %s", remoteName, connUnion.ConnType) } bc.UpdateControllerAndSendUpdate(func() bool { bc.ShellProc = shellProc @@ -489,6 +599,17 @@ func (bc *BlockController) setupAndStartShellProcess(logCtx context.Context, rc return shellProc, nil } +func (bc *BlockController) getBlockData_noErr() *waveobj.Block { + ctx, cancelFn := context.WithTimeout(context.Background(), DefaultTimeout) + defer cancelFn() + blockData, err := wstore.DBGet[*waveobj.Block](ctx, bc.BlockId) + if err != nil { + log.Printf("error getting block data (getBlockData_noErr): %v\n", err) + return nil + } + return blockData +} + func (bc *BlockController) manageRunningShellProcess(shellProc *shellexec.ShellProc, rc *RunShellOpts, blockMeta waveobj.MetaMapType) error { shellInputCh := make(chan *BlockInputUnion, 32) bc.ShellInputCh = shellInputCh @@ -513,8 +634,11 @@ func (bc *BlockController) manageRunningShellProcess(shellProc *shellexec.ShellP }) shellProc.Cmd.Wait() exitCode := shellProc.Cmd.ExitCode() - termMsg := fmt.Sprintf("\r\nprocess finished with exit code = %d\r\n\r\n", exitCode) - HandleAppendBlockFile(bc.BlockId, BlockFile_Term, []byte(termMsg)) + blockData := bc.getBlockData_noErr() + if blockData != nil && blockData.Meta.GetString(waveobj.MetaKey_Controller, "") == BlockController_Cmd { + termMsg := fmt.Sprintf("\r\nprocess finished with exit code = %d\r\n\r\n", exitCode) + HandleAppendBlockFile(bc.BlockId, wavebase.BlockFile_Term, []byte(termMsg)) + } // to stop the inputCh loop time.Sleep(100 * time.Millisecond) close(shellInputCh) // don't use bc.ShellInputCh (it's nil) @@ -523,7 +647,7 @@ func (bc *BlockController) manageRunningShellProcess(shellProc *shellexec.ShellP for { nr, err := ptyBuffer.Read(buf) if nr > 0 { - err := HandleAppendBlockFile(bc.BlockId, BlockFile_Term, buf[:nr]) + err := HandleAppendBlockFile(bc.BlockId, wavebase.BlockFile_Term, buf[:nr]) if err != nil { log.Printf("error appending to blockfile: %v\n", err) } @@ -683,6 +807,7 @@ func (bc *BlockController) UnlockRunLock() { } func (bc *BlockController) run(logCtx context.Context, bdata *waveobj.Block, blockMeta map[string]any, rtOpts *waveobj.RuntimeOpts, force bool) { + blocklogger.Debugf(logCtx, "[conndebug] BlockController.run() %q\n", bc.BlockId) runningShellCommand := false ok := bc.LockRunLock() if !ok { @@ -765,7 +890,7 @@ func CheckConnStatus(blockId string) error { } if strings.HasPrefix(connName, "wsl://") { distroName := strings.TrimPrefix(connName, "wsl://") - conn := wslconn.GetWslConn(context.Background(), distroName, false) + conn := wslconn.GetWslConn(distroName) connStatus := conn.DeriveConnStatus() if connStatus.Status != conncontroller.Status_Connected { return fmt.Errorf("not connected: %s", connStatus.Status) @@ -776,7 +901,7 @@ func CheckConnStatus(blockId string) error { if err != nil { return fmt.Errorf("error parsing connection name: %w", err) } - conn := conncontroller.GetConn(context.Background(), opts, &wconfig.ConnKeywords{}) + conn := conncontroller.GetConn(opts) connStatus := conn.DeriveConnStatus() if connStatus.Status != conncontroller.Status_Connected { return fmt.Errorf("not connected: %s", connStatus.Status) diff --git a/pkg/filestore/blockstore_test.go b/pkg/filestore/blockstore_test.go index 235242dc0d..c1fd524698 100644 --- a/pkg/filestore/blockstore_test.go +++ b/pkg/filestore/blockstore_test.go @@ -316,6 +316,31 @@ func checkFileDataAt(t *testing.T, ctx context.Context, zoneId string, name stri } } +func TestWriteAt(t *testing.T) { + initDb(t) + defer cleanupDb(t) + + ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second) + defer cancelFn() + fileName := "t3" + zoneId := uuid.NewString() + err := WFS.MakeFile(ctx, zoneId, fileName, nil, FileOptsType{}) + if err != nil { + t.Fatalf("error creating file: %v", err) + } + err = WFS.WriteFile(ctx, zoneId, fileName, []byte("hello world!")) + if err != nil { + t.Fatalf("error writing data: %v", err) + } + checkFileData(t, ctx, zoneId, fileName, "hello world!") + err = WFS.WriteAt(ctx, zoneId, fileName, 0, []byte("foo")) + if err != nil { + t.Fatalf("error writing data: %v", err) + } + checkFileSize(t, ctx, zoneId, fileName, 12) + checkFileData(t, ctx, zoneId, fileName, "foolo world!") +} + func TestAppend(t *testing.T) { initDb(t) defer cleanupDb(t) diff --git a/pkg/remote/conncontroller/conncontroller.go b/pkg/remote/conncontroller/conncontroller.go index 8fe1e1a470..31fbe84586 100644 --- a/pkg/remote/conncontroller/conncontroller.go +++ b/pkg/remote/conncontroller/conncontroller.go @@ -812,7 +812,7 @@ func getConnInternal(opts *remote.SSHOpts) *SSHConn { } // does NOT connect, can return nil if connection does not exist -func GetConn(ctx context.Context, opts *remote.SSHOpts, connFlags *wconfig.ConnKeywords) *SSHConn { +func GetConn(opts *remote.SSHOpts) *SSHConn { conn := getConnInternal(opts) return conn } @@ -826,7 +826,7 @@ func EnsureConnection(ctx context.Context, connName string) error { if err != nil { return fmt.Errorf("error parsing connection name: %w", err) } - conn := GetConn(ctx, connOpts, &wconfig.ConnKeywords{}) + conn := GetConn(connOpts) if conn == nil { return fmt.Errorf("connection not found: %s", connName) } diff --git a/pkg/shellexec/shellexec.go b/pkg/shellexec/shellexec.go index 394f2fbe7b..1f9324116a 100644 --- a/pkg/shellexec/shellexec.go +++ b/pkg/shellexec/shellexec.go @@ -38,7 +38,6 @@ type CommandOptsType struct { Interactive bool `json:"interactive,omitempty"` Login bool `json:"login,omitempty"` Cwd string `json:"cwd,omitempty"` - Env map[string]string `json:"env,omitempty"` ShellPath string `json:"shellPath,omitempty"` ShellOpts []string `json:"shellOpts,omitempty"` SwapToken *shellutil.TokenSwapEntry `json:"swapToken,omitempty"` @@ -256,13 +255,13 @@ func StartWslShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr st conn.Infof(ctx, "setting ZDOTDIR to %s\n", zshDir) cmdCombined = fmt.Sprintf(`ZDOTDIR=%s %s`, zshDir, cmdCombined) } - - jwtToken, ok := cmdOpts.Env[wshutil.WaveJwtTokenVarName] - if !ok { - return nil, fmt.Errorf("no jwt token provided to connection") + packedToken, err := cmdOpts.SwapToken.PackForClient() + if err != nil { + conn.Infof(ctx, "error packing swap token: %v", err) + } else { + conn.Debugf(ctx, "packed swaptoken %s\n", packedToken) + cmdCombined = fmt.Sprintf(`%s=%s %s`, wavebase.WaveSwapTokenVarName, packedToken, cmdCombined) } - cmdCombined = fmt.Sprintf(`%s=%s %s`, wshutil.WaveJwtTokenVarName, jwtToken, cmdCombined) - log.Printf("full combined command: %s", cmdCombined) ecmd := exec.Command("wsl.exe", "~", "-d", client.Name(), "--", "sh", "-c", cmdCombined) if termSize.Rows == 0 || termSize.Cols == 0 { @@ -425,12 +424,6 @@ func StartRemoteShellProc(ctx context.Context, logCtx context.Context, termSize session.Stdin = remoteStdinRead session.Stdout = remoteStdoutWrite session.Stderr = remoteStdoutWrite - - for envKey, envVal := range cmdOpts.Env { - // note these might fail depending on server settings, but we still try - session.Setenv(envKey, envVal) - } - if shellType == shellutil.ShellType_zsh { zshDir := fmt.Sprintf("~/.waveterm/%s", shellutil.ZshIntegrationDir) conn.Infof(logCtx, "setting ZDOTDIR to %s\n", zshDir) @@ -557,7 +550,6 @@ func StartLocalShellProc(logCtx context.Context, termSize waveobj.TermSize, cmdS envToAdd["LANG"] = wavebase.DetermineLang() } shellutil.UpdateCmdEnv(ecmd, envToAdd) - shellutil.UpdateCmdEnv(ecmd, cmdOpts.Env) if termSize.Rows == 0 || termSize.Cols == 0 { termSize.Rows = shellutil.DefaultTermRows termSize.Cols = shellutil.DefaultTermCols diff --git a/pkg/util/shellutil/shellutil.go b/pkg/util/shellutil/shellutil.go index 9eb1ce1864..e19e8298fe 100644 --- a/pkg/util/shellutil/shellutil.go +++ b/pkg/util/shellutil/shellutil.go @@ -226,6 +226,7 @@ func WaveshellLocalEnvVars(termType string) map[string]string { if termType != "" { rtn["TERM"] = termType } + // these are not necessary since they should be set with the swap token, but no harm in setting them here rtn["TERM_PROGRAM"] = "waveterm" rtn["WAVETERM"], _ = os.Executable() rtn["WAVETERM_VERSION"] = wavebase.WaveVersion diff --git a/pkg/wavebase/wavebase.go b/pkg/wavebase/wavebase.go index 019bfd9c11..2d2c30064b 100644 --- a/pkg/wavebase/wavebase.go +++ b/pkg/wavebase/wavebase.go @@ -37,6 +37,13 @@ const ( WaveSwapTokenVarName = "WAVETERM_SWAPTOKEN" ) +const ( + BlockFile_Term = "term" // used for main pty output + BlockFile_Cache = "cache:term:full" // for cached block + BlockFile_VDom = "vdom" // used for alt html layout + BlockFile_Env = "env" +) + const NeedJwtConst = "NEED-JWT" var ConfigHome_VarCache string // caches WAVETERM_CONFIG_HOME diff --git a/pkg/waveobj/metaconsts.go b/pkg/waveobj/metaconsts.go index d576e6bf19..c6fec36d57 100644 --- a/pkg/waveobj/metaconsts.go +++ b/pkg/waveobj/metaconsts.go @@ -47,11 +47,18 @@ const ( MetaKey_CmdCloseOnExit = "cmd:closeonexit" MetaKey_CmdCloseOnExitForce = "cmd:closeonexitforce" MetaKey_CmdCloseOnExitDelay = "cmd:closeonexitdelay" - MetaKey_CmdEnv = "cmd:env" - MetaKey_CmdCwd = "cmd:cwd" MetaKey_CmdNoWsh = "cmd:nowsh" MetaKey_CmdArgs = "cmd:args" MetaKey_CmdShell = "cmd:shell" + MetaKey_CmdAllowConnChange = "cmd:allowconnchange" + MetaKey_CmdEnv = "cmd:env" + MetaKey_CmdCwd = "cmd:cwd" + MetaKey_CmdInitScript = "cmd:initscript" + MetaKey_CmdInitScriptSh = "cmd:initscript.sh" + MetaKey_CmdInitScriptBash = "cmd:initscript.bash" + MetaKey_CmdInitScriptZsh = "cmd:initscript.zsh" + MetaKey_CmdInitScriptPwsh = "cmd:initscript.pwsh" + MetaKey_CmdInitScriptFish = "cmd:initscript.fish" MetaKey_AiClear = "ai:*" MetaKey_AiPresetKey = "ai:preset" diff --git a/pkg/waveobj/metamap.go b/pkg/waveobj/metamap.go index 1c20881a9e..a74665ec59 100644 --- a/pkg/waveobj/metamap.go +++ b/pkg/waveobj/metamap.go @@ -14,6 +14,22 @@ func (m MetaMapType) GetString(key string, def string) string { return def } +func (m MetaMapType) HasKey(key string) bool { + _, ok := m[key] + return ok +} + +func (m MetaMapType) GetConnectionOverride(connName string) MetaMapType { + v, ok := m["["+connName+"]"] + if !ok { + return nil + } + if mval, ok := v.(map[string]any); ok { + return MetaMapType(mval) + } + return nil +} + func (m MetaMapType) GetStringList(key string) []string { v, ok := m[key] if !ok { diff --git a/pkg/waveobj/wtypemeta.go b/pkg/waveobj/wtypemeta.go index 01a1327925..6705aa4fd8 100644 --- a/pkg/waveobj/wtypemeta.go +++ b/pkg/waveobj/wtypemeta.go @@ -36,21 +36,30 @@ type MetaTSType struct { FrameIcon string `json:"frame:icon,omitempty"` FrameText string `json:"frame:text,omitempty"` - CmdClear bool `json:"cmd:*,omitempty"` - Cmd string `json:"cmd,omitempty"` - CmdInteractive bool `json:"cmd:interactive,omitempty"` - CmdLogin bool `json:"cmd:login,omitempty"` - CmdRunOnStart bool `json:"cmd:runonstart,omitempty"` - CmdClearOnStart bool `json:"cmd:clearonstart,omitempty"` - CmdRunOnce bool `json:"cmd:runonce,omitempty"` - CmdCloseOnExit bool `json:"cmd:closeonexit,omitempty"` - CmdCloseOnExitForce bool `json:"cmd:closeonexitforce,omitempty"` - CmdCloseOnExitDelay float64 `json:"cmd:closeonexitdelay,omitempty"` - CmdEnv map[string]string `json:"cmd:env,omitempty"` - CmdCwd string `json:"cmd:cwd,omitempty"` - CmdNoWsh bool `json:"cmd:nowsh,omitempty"` - CmdArgs []string `json:"cmd:args,omitempty"` // args for cmd (only if cmd:shell is false) - CmdShell bool `json:"cmd:shell,omitempty"` // shell expansion for cmd+args (defaults to true) + CmdClear bool `json:"cmd:*,omitempty"` + Cmd string `json:"cmd,omitempty"` + CmdInteractive bool `json:"cmd:interactive,omitempty"` + CmdLogin bool `json:"cmd:login,omitempty"` + CmdRunOnStart bool `json:"cmd:runonstart,omitempty"` + CmdClearOnStart bool `json:"cmd:clearonstart,omitempty"` + CmdRunOnce bool `json:"cmd:runonce,omitempty"` + CmdCloseOnExit bool `json:"cmd:closeonexit,omitempty"` + CmdCloseOnExitForce bool `json:"cmd:closeonexitforce,omitempty"` + CmdCloseOnExitDelay float64 `json:"cmd:closeonexitdelay,omitempty"` + CmdNoWsh bool `json:"cmd:nowsh,omitempty"` + CmdArgs []string `json:"cmd:args,omitempty"` // args for cmd (only if cmd:shell is false) + CmdShell bool `json:"cmd:shell,omitempty"` // shell expansion for cmd+args (defaults to true) + CmdAllowConnChange bool `json:"cmd:allowconnchange,omitempty"` + + // these can be nested under "[conn]" + CmdEnv map[string]string `json:"cmd:env,omitempty"` + CmdCwd string `json:"cmd:cwd,omitempty"` + CmdInitScript string `json:"cmd:initscript,omitempty"` + CmdInitScriptSh string `json:"cmd:initscript.sh,omitempty"` + CmdInitScriptBash string `json:"cmd:initscript.bash,omitempty"` + CmdInitScriptZsh string `json:"cmd:initscript.zsh,omitempty"` + CmdInitScriptPwsh string `json:"cmd:initscript.pwsh,omitempty"` + CmdInitScriptFish string `json:"cmd:initscript.fish,omitempty"` // AI options match settings AiClear bool `json:"ai:*,omitempty"` diff --git a/pkg/wshrpc/wshserver/wshserver.go b/pkg/wshrpc/wshserver/wshserver.go index 9a655e0f8b..edc7f76294 100644 --- a/pkg/wshrpc/wshserver/wshserver.go +++ b/pkg/wshrpc/wshserver/wshserver.go @@ -273,7 +273,7 @@ func (ws *WshServer) ControllerAppendOutputCommand(ctx context.Context, data wsh if err != nil { return fmt.Errorf("error decoding output data: %w", err) } - err = blockcontroller.HandleAppendBlockFile(data.BlockId, blockcontroller.BlockFile_Term, outputBuf[:nw]) + err = blockcontroller.HandleAppendBlockFile(data.BlockId, wavebase.BlockFile_Term, outputBuf[:nw]) if err != nil { return fmt.Errorf("error appending to block file: %w", err) } @@ -335,7 +335,7 @@ func (ws *WshServer) FileAppendCommand(ctx context.Context, data wshrpc.FileData func (ws *WshServer) FileAppendIJsonCommand(ctx context.Context, data wshrpc.CommandAppendIJsonData) error { tryCreate := true - if data.FileName == blockcontroller.BlockFile_VDom && tryCreate { + if data.FileName == wavebase.BlockFile_VDom && tryCreate { err := filestore.WFS.MakeFile(ctx, data.ZoneId, data.FileName, nil, wshrpc.FileOpts{MaxSize: blockcontroller.DefaultHtmlMaxFileSize, IJson: true}) if err != nil && err != fs.ErrExist { return fmt.Errorf("error creating blockfile[vdom]: %w", err) @@ -496,7 +496,7 @@ func (ws *WshServer) ConnEnsureCommand(ctx context.Context, data wshrpc.ConnExtD func (ws *WshServer) ConnDisconnectCommand(ctx context.Context, connName string) error { if strings.HasPrefix(connName, "wsl://") { distroName := strings.TrimPrefix(connName, "wsl://") - conn := wslconn.GetWslConn(ctx, distroName, false) + conn := wslconn.GetWslConn(distroName) if conn == nil { return fmt.Errorf("distro not found: %s", connName) } @@ -506,7 +506,7 @@ func (ws *WshServer) ConnDisconnectCommand(ctx context.Context, connName string) if err != nil { return fmt.Errorf("error parsing connection name: %w", err) } - conn := conncontroller.GetConn(ctx, connOpts, &wconfig.ConnKeywords{}) + conn := conncontroller.GetConn(connOpts) if conn == nil { return fmt.Errorf("connection not found: %s", connName) } @@ -519,7 +519,7 @@ func (ws *WshServer) ConnConnectCommand(ctx context.Context, connRequest wshrpc. connName := connRequest.Host if strings.HasPrefix(connName, "wsl://") { distroName := strings.TrimPrefix(connName, "wsl://") - conn := wslconn.GetWslConn(ctx, distroName, false) + conn := wslconn.GetWslConn(distroName) if conn == nil { return fmt.Errorf("connection not found: %s", connName) } @@ -529,7 +529,7 @@ func (ws *WshServer) ConnConnectCommand(ctx context.Context, connRequest wshrpc. if err != nil { return fmt.Errorf("error parsing connection name: %w", err) } - conn := conncontroller.GetConn(ctx, connOpts, &connRequest.Keywords) + conn := conncontroller.GetConn(connOpts) if conn == nil { return fmt.Errorf("connection not found: %s", connName) } @@ -542,7 +542,7 @@ func (ws *WshServer) ConnReinstallWshCommand(ctx context.Context, data wshrpc.Co connName := data.ConnName if strings.HasPrefix(connName, "wsl://") { distroName := strings.TrimPrefix(connName, "wsl://") - conn := wslconn.GetWslConn(ctx, distroName, false) + conn := wslconn.GetWslConn(distroName) if conn == nil { return fmt.Errorf("connection not found: %s", connName) } @@ -552,7 +552,7 @@ func (ws *WshServer) ConnReinstallWshCommand(ctx context.Context, data wshrpc.Co if err != nil { return fmt.Errorf("error parsing connection name: %w", err) } - conn := conncontroller.GetConn(ctx, connOpts, &wconfig.ConnKeywords{}) + conn := conncontroller.GetConn(connOpts) if conn == nil { return fmt.Errorf("connection not found: %s", connName) } @@ -589,7 +589,7 @@ func (ws *WshServer) ConnUpdateWshCommand(ctx context.Context, remoteInfo wshrpc if err != nil { return false, fmt.Errorf("error parsing connection name: %w", err) } - conn := conncontroller.GetConn(ctx, connOpts, &wconfig.ConnKeywords{}) + conn := conncontroller.GetConn(connOpts) if conn == nil { return false, fmt.Errorf("connection not found: %s", connName) } @@ -639,7 +639,7 @@ func (ws *WshServer) WslDefaultDistroCommand(ctx context.Context) (string, error func (ws *WshServer) DismissWshFailCommand(ctx context.Context, connName string) error { if strings.HasPrefix(connName, "wsl://") { distroName := strings.TrimPrefix(connName, "wsl://") - conn := wslconn.GetWslConn(ctx, distroName, false) + conn := wslconn.GetWslConn(distroName) if conn == nil { return fmt.Errorf("connection not found: %s", connName) } @@ -651,7 +651,7 @@ func (ws *WshServer) DismissWshFailCommand(ctx context.Context, connName string) if err != nil { return err } - conn := conncontroller.GetConn(ctx, opts, nil) + conn := conncontroller.GetConn(opts) if conn == nil { return fmt.Errorf("connection %s not found", connName) } diff --git a/pkg/wshutil/wshutil.go b/pkg/wshutil/wshutil.go index ff72ac7a0a..424d64db19 100644 --- a/pkg/wshutil/wshutil.go +++ b/pkg/wshutil/wshutil.go @@ -549,7 +549,6 @@ func getShell() string { if runtime.GOOS == "darwin" { return shellutil.GetMacUserShell() } - shell := os.Getenv("SHELL") if shell == "" { return "/bin/bash" diff --git a/pkg/wslconn/wslconn.go b/pkg/wslconn/wslconn.go index 4faeedaf22..63d99f1979 100644 --- a/pkg/wslconn/wslconn.go +++ b/pkg/wslconn/wslconn.go @@ -117,6 +117,10 @@ func (conn *WslConn) Infof(ctx context.Context, format string, args ...any) { blocklogger.Infof(ctx, "[conndebug] "+format, args...) } +func (conn *WslConn) Debugf(ctx context.Context, format string, args ...any) { + blocklogger.Infof(ctx, "[conndebug] "+format, args...) +} + func (conn *WslConn) FireConnChangeEvent() { status := conn.DeriveConnStatus() event := wps.WaveEvent{ @@ -751,11 +755,8 @@ func getConnInternal(name string) *WslConn { return rtn } -func GetWslConn(ctx context.Context, name string, shouldConnect bool) *WslConn { +func GetWslConn(name string) *WslConn { conn := getConnInternal(name) - if conn.Client == nil && shouldConnect { - conn.Connect(ctx) - } return conn } @@ -764,7 +765,7 @@ func EnsureConnection(ctx context.Context, connName string) error { if connName == "" { return nil } - conn := GetWslConn(ctx, connName, false) + conn := GetWslConn(connName) if conn == nil { return fmt.Errorf("connection not found: %s", connName) }