diff --git a/cmd/wsh/cmd/setmeta_test.go b/cmd/wsh/cmd/setmeta_test.go new file mode 100644 index 0000000000..6f0e7be6b4 --- /dev/null +++ b/cmd/wsh/cmd/setmeta_test.go @@ -0,0 +1,170 @@ +// Copyright 2025, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package cmd + +import ( + "reflect" + "testing" +) + +func TestParseMetaSets(t *testing.T) { + tests := []struct { + name string + input []string + want map[string]any + wantErr bool + }{ + { + name: "basic types", + input: []string{"str=hello", "num=42", "float=3.14", "bool=true", "null=null"}, + want: map[string]any{ + "str": "hello", + "num": int64(42), + "float": float64(3.14), + "bool": true, + "null": nil, + }, + }, + { + name: "json values", + input: []string{ + `arr=[1,2,3]`, + `obj={"foo":"bar"}`, + `str="quoted"`, + }, + want: map[string]any{ + "arr": []any{float64(1), float64(2), float64(3)}, + "obj": map[string]any{"foo": "bar"}, + "str": "quoted", + }, + }, + { + name: "nested paths", + input: []string{ + "a/b=55", + "a/c=2", + }, + want: map[string]any{ + "a": map[string]any{ + "b": int64(55), + "c": int64(2), + }, + }, + }, + { + name: "deep nesting", + input: []string{ + "a/b/c/d=hello", + }, + want: map[string]any{ + "a": map[string]any{ + "b": map[string]any{ + "c": map[string]any{ + "d": "hello", + }, + }, + }, + }, + }, + { + name: "override nested value", + input: []string{ + "a/b/c=1", + "a/b=2", + }, + want: map[string]any{ + "a": map[string]any{ + "b": int64(2), + }, + }, + }, + { + name: "override with null", + input: []string{ + "a/b=1", + "a/c=2", + "a=null", + }, + want: map[string]any{ + "a": nil, + }, + }, + { + name: "mixed types in path", + input: []string{ + "a/b=1", + "a/c=[1,2,3]", + "a/d/e=true", + }, + want: map[string]any{ + "a": map[string]any{ + "b": int64(1), + "c": []any{float64(1), float64(2), float64(3)}, + "d": map[string]any{ + "e": true, + }, + }, + }, + }, + { + name: "invalid format", + input: []string{"invalid"}, + wantErr: true, + }, + { + name: "invalid json", + input: []string{`a={"invalid`}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseMetaSets(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("parseMetaSets() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && !reflect.DeepEqual(got, tt.want) { + t.Errorf("parseMetaSets() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestParseMetaValue(t *testing.T) { + tests := []struct { + name string + input string + want any + wantErr bool + }{ + {"empty string", "", nil, false}, + {"null", "null", nil, false}, + {"true", "true", true, false}, + {"false", "false", false, false}, + {"integer", "42", int64(42), false}, + {"negative integer", "-42", int64(-42), false}, + {"hex integer", "0xff", int64(255), false}, + {"float", "3.14", float64(3.14), false}, + {"string", "hello", "hello", false}, + {"json array", "[1,2,3]", []any{float64(1), float64(2), float64(3)}, false}, + {"json object", `{"foo":"bar"}`, map[string]any{"foo": "bar"}, false}, + {"quoted string", `"quoted"`, "quoted", false}, + {"invalid json", `{"invalid`, nil, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseMetaValue(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("parseMetaValue() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && !reflect.DeepEqual(got, tt.want) { + t.Errorf("parseMetaValue() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/cmd/wsh/cmd/wshcmd-setmeta.go b/cmd/wsh/cmd/wshcmd-setmeta.go index 32af2c5529..13e3a352b7 100644 --- a/cmd/wsh/cmd/wshcmd-setmeta.go +++ b/cmd/wsh/cmd/wshcmd-setmeta.go @@ -58,40 +58,88 @@ func loadJSONFile(filepath string) (map[string]interface{}, error) { return result, nil } -func parseMetaSets(metaSets []string) (map[string]interface{}, error) { - meta := make(map[string]interface{}) +func parseMetaValue(setVal string) (any, error) { + if setVal == "" || setVal == "null" { + return nil, nil + } + if setVal == "true" { + return true, nil + } + if setVal == "false" { + return false, nil + } + if setVal[0] == '[' || setVal[0] == '{' || setVal[0] == '"' { + var val any + err := json.Unmarshal([]byte(setVal), &val) + if err != nil { + return nil, fmt.Errorf("invalid json value: %v", err) + } + return val, nil + } + + // Try parsing as integer + ival, err := strconv.ParseInt(setVal, 0, 64) + if err == nil { + return ival, nil + } + + // Try parsing as float + fval, err := strconv.ParseFloat(setVal, 64) + if err == nil { + return fval, nil + } + + // Fallback to string + return setVal, nil +} + +func setNestedValue(meta map[string]any, path []string, value any) { + // For single key, just set directly + if len(path) == 1 { + meta[path[0]] = value + return + } + + // For nested path, traverse or create maps as needed + current := meta + for i := 0; i < len(path)-1; i++ { + key := path[i] + // If next level doesn't exist or isn't a map, create new map + next, exists := current[key] + if !exists { + nextMap := make(map[string]any) + current[key] = nextMap + current = nextMap + } else if nextMap, ok := next.(map[string]any); ok { + current = nextMap + } else { + // If existing value isn't a map, replace with new map + nextMap = make(map[string]any) + current[key] = nextMap + current = nextMap + } + } + + // Set the final value + current[path[len(path)-1]] = value +} + +func parseMetaSets(metaSets []string) (map[string]any, error) { + meta := make(map[string]any) for _, metaSet := range metaSets { fields := strings.SplitN(metaSet, "=", 2) if len(fields) != 2 { return nil, fmt.Errorf("invalid meta set: %q", metaSet) } - setVal := fields[1] - if setVal == "" || setVal == "null" { - meta[fields[0]] = nil - } else if setVal == "true" { - meta[fields[0]] = true - } else if setVal == "false" { - meta[fields[0]] = false - } else if setVal[0] == '[' || setVal[0] == '{' || setVal[0] == '"' { - var val interface{} - err := json.Unmarshal([]byte(setVal), &val) - if err != nil { - return nil, fmt.Errorf("invalid json value: %v", err) - } - meta[fields[0]] = val - } else { - ival, err := strconv.ParseInt(setVal, 0, 64) - if err == nil { - meta[fields[0]] = ival - } else { - fval, err := strconv.ParseFloat(setVal, 64) - if err == nil { - meta[fields[0]] = fval - } else { - meta[fields[0]] = setVal - } - } + + val, err := parseMetaValue(fields[1]) + if err != nil { + return nil, err } + + // Split the key path and set nested value + path := strings.Split(fields[0], "/") + setNestedValue(meta, path, val) } return meta, nil } diff --git a/frontend/types/gotypes.d.ts b/frontend/types/gotypes.d.ts index e90db47c2f..4b71dfcd59 100644 --- a/frontend/types/gotypes.d.ts +++ b/frontend/types/gotypes.d.ts @@ -306,6 +306,13 @@ declare global { "term:fontsize"?: number; "term:fontfamily"?: string; "term:theme"?: string; + "cmd:env"?: {[key: string]: string}; + "cmd:initscript"?: string; + "cmd:initscript.sh"?: string; + "cmd:initscript.bash"?: string; + "cmd:initscript.zsh"?: string; + "cmd:initscript.pwsh"?: string; + "cmd:initscript.fish"?: string; "ssh:user"?: string; "ssh:hostname"?: string; "ssh:port"?: string; diff --git a/pkg/blockcontroller/blockcontroller.go b/pkg/blockcontroller/blockcontroller.go index 6787078a78..d7363af451 100644 --- a/pkg/blockcontroller/blockcontroller.go +++ b/pkg/blockcontroller/blockcontroller.go @@ -11,6 +11,7 @@ import ( "io" "io/fs" "log" + "os" "strings" "sync" "sync/atomic" @@ -24,6 +25,7 @@ import ( "github.com/wavetermdev/waveterm/pkg/remote/conncontroller" "github.com/wavetermdev/waveterm/pkg/shellexec" "github.com/wavetermdev/waveterm/pkg/util/envutil" + "github.com/wavetermdev/waveterm/pkg/util/fileutil" "github.com/wavetermdev/waveterm/pkg/util/shellutil" "github.com/wavetermdev/waveterm/pkg/util/utilfn" "github.com/wavetermdev/waveterm/pkg/wavebase" @@ -57,6 +59,7 @@ const ( const ( DefaultTermMaxFileSize = 256 * 1024 DefaultHtmlMaxFileSize = 256 * 1024 + MaxInitScriptSize = 50 * 1024 ) const DefaultTimeout = 2 * time.Second @@ -232,25 +235,79 @@ func getCustomInitScriptKeyCascade(shellType string) []string { return []string{waveobj.MetaKey_CmdInitScript} } -func getCustomInitScript(meta waveobj.MetaMapType, connName string, shellType string) string { +func getCustomInitScript(logCtx context.Context, meta waveobj.MetaMapType, connName string, shellType string) string { + initScriptVal, metaKeyName := getCustomInitScriptValue(meta, connName, shellType) + if initScriptVal == "" { + return "" + } + if !fileutil.IsInitScriptPath(initScriptVal) { + blocklogger.Infof(logCtx, "[conndebug] inline initScript (size=%d) found in meta key: %s\n", len(initScriptVal), metaKeyName) + return initScriptVal + } + blocklogger.Infof(logCtx, "[conndebug] initScript detected as a file %q from meta key: %s\n", initScriptVal, metaKeyName) + initScriptVal, err := wavebase.ExpandHomeDir(initScriptVal) + if err != nil { + blocklogger.Infof(logCtx, "[conndebug] cannot expand home dir in Wave initscript file: %v\n", err) + return fmt.Sprintf("echo \"cannot expand home dir in Wave initscript file, from key %s\";\n", metaKeyName) + } + fileData, err := os.ReadFile(initScriptVal) + if err != nil { + blocklogger.Infof(logCtx, "[conndebug] cannot open Wave initscript file: %v\n", err) + return fmt.Sprintf("echo \"cannot open Wave initscript file, from key %s\";\n", metaKeyName) + } + if len(fileData) > MaxInitScriptSize { + blocklogger.Infof(logCtx, "[conndebug] initscript file too large, size=%d, max=%d\n", len(fileData), MaxInitScriptSize) + return fmt.Sprintf("echo \"initscript file too large, from key %s\";\n", metaKeyName) + } + if utilfn.HasBinaryData(fileData) { + blocklogger.Infof(logCtx, "[conndebug] initscript file contains binary data\n") + return fmt.Sprintf("echo \"initscript file contains binary data, from key %s\";\n", metaKeyName) + } + blocklogger.Infof(logCtx, "[conndebug] initscript file read successfully, size=%d\n", len(fileData)) + return string(fileData) +} + +// returns (value, metakey) +func getCustomInitScriptValue(meta waveobj.MetaMapType, connName string, shellType string) (string, string) { keys := getCustomInitScriptKeyCascade(shellType) connMeta := meta.GetConnectionOverride(connName) if connMeta != nil { for _, key := range keys { if connMeta.HasKey(key) { - return connMeta.GetString(key, "") + return connMeta.GetString(key, ""), "blockmeta/[" + connName + "]/" + key } } } for _, key := range keys { if meta.HasKey(key) { - return meta.GetString(key, "") + return meta.GetString(key, ""), "blockmeta/" + key + } + } + fullConfig := wconfig.GetWatcher().GetFullConfig() + connKeywords := fullConfig.Connections[connName] + connKeywordsMap := make(map[string]any) + err := utilfn.ReUnmarshal(&connKeywordsMap, connKeywords) + if err != nil { + log.Printf("error re-unmarshalling connKeywords: %v\n", err) + return "", "" + } + ckMeta := waveobj.MetaMapType(connKeywordsMap) + for _, key := range keys { + if ckMeta.HasKey(key) { + return ckMeta.GetString(key, ""), "connections.json/" + connName + "/" + key } } - return "" + return "", "" } func resolveEnvMap(blockId string, blockMeta waveobj.MetaMapType, connName string) (map[string]string, error) { + rtn := make(map[string]string) + config := wconfig.GetWatcher().GetFullConfig() + connKeywords := config.Connections[connName] + ckEnv := connKeywords.CmdEnv + for k, v := range ckEnv { + rtn[k] = v + } ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second) defer cancelFn() _, envFileData, err := filestore.WFS.ReadFile(ctx, blockId, wavebase.BlockFile_Env) @@ -260,25 +317,27 @@ func resolveEnvMap(blockId string, blockMeta waveobj.MetaMapType, connName strin if err != nil { return nil, fmt.Errorf("error reading command env file: %w", err) } - rtn := make(map[string]string) if len(envFileData) > 0 { envMap := envutil.EnvToMap(string(envFileData)) for k, v := range envMap { rtn[k] = v } } - cmdEnv := blockMeta.GetMap(waveobj.MetaKey_CmdEnv) + cmdEnv := blockMeta.GetStringMap(waveobj.MetaKey_CmdEnv, true) for k, v := range cmdEnv { - if v == nil { + if v == waveobj.MetaMap_DeleteSentinel { delete(rtn, k) continue } - if strVal, ok := v.(string); ok { - rtn[k] = strVal - } - if floatVal, ok := v.(float64); ok { - rtn[k] = fmt.Sprintf("%v", floatVal) + rtn[k] = v + } + connEnv := blockMeta.GetConnectionOverride(connName).GetStringMap(waveobj.MetaKey_CmdEnv, true) + for k, v := range connEnv { + if v == waveobj.MetaMap_DeleteSentinel { + delete(rtn, k) + continue } + rtn[k] = v } return rtn, nil } @@ -322,7 +381,7 @@ func (bc *BlockController) DoRunShellCommand(logCtx context.Context, rc *RunShel return bc.manageRunningShellProcess(shellProc, rc, blockMeta) } -func (bc *BlockController) makeSwapToken(ctx context.Context, blockMeta waveobj.MetaMapType, remoteName string, shellType string) *shellutil.TokenSwapEntry { +func (bc *BlockController) makeSwapToken(ctx context.Context, logCtx context.Context, blockMeta waveobj.MetaMapType, remoteName string, shellType string) *shellutil.TokenSwapEntry { token := &shellutil.TokenSwapEntry{ Token: uuid.New().String(), Env: make(map[string]string), @@ -360,7 +419,7 @@ func (bc *BlockController) makeSwapToken(ctx context.Context, blockMeta waveobj. for k, v := range envMap { token.Env[k] = v } - token.ScriptText = getCustomInitScript(blockMeta, remoteName, shellType) + token.ScriptText = getCustomInitScript(logCtx, blockMeta, remoteName, shellType) return token } @@ -509,9 +568,9 @@ func (bc *BlockController) setupAndStartShellProcess(logCtx context.Context, rc return nil, fmt.Errorf("unknown controller type %q", bc.ControllerType) } var shellProc *shellexec.ShellProc - swapToken := bc.makeSwapToken(ctx, blockMeta, remoteName, connUnion.ShellType) + swapToken := bc.makeSwapToken(ctx, logCtx, blockMeta, remoteName, connUnion.ShellType) cmdOpts.SwapToken = swapToken - blocklogger.Infof(logCtx, "[conndebug] created swaptoken: %s\n", swapToken.Token) + blocklogger.Debugf(logCtx, "[conndebug] created swaptoken: %s\n", swapToken.Token) if connUnion.ConnType == ConnType_Wsl { wslConn := connUnion.WslConn if !connUnion.WshEnabled { @@ -533,8 +592,8 @@ func (bc *BlockController) setupAndStartShellProcess(logCtx context.Context, rc if err != nil { wslConn.SetWshError(err) wslConn.WshEnabled.Store(false) - log.Printf("error starting wsl shell proc with wsh: %v", err) - log.Print("attempting install without wsh") + blocklogger.Infof(logCtx, "[conndebug] error starting wsl shell proc with wsh: %v\n", err) + blocklogger.Infof(logCtx, "[conndebug] attempting install without wsh\n") shellProc, err = shellexec.StartWslShellProcNoWsh(ctx, rc.TermSize, cmdStr, cmdOpts, wslConn) if err != nil { return nil, err @@ -562,8 +621,8 @@ func (bc *BlockController) setupAndStartShellProcess(logCtx context.Context, rc if err != nil { conn.SetWshError(err) conn.WshEnabled.Store(false) - log.Printf("error starting remote shell proc with wsh: %v", err) - log.Print("attempting install without wsh") + blocklogger.Infof(logCtx, "[conndebug] error starting remote shell proc with wsh: %v\n", err) + blocklogger.Infof(logCtx, "[conndebug] attempting install without wsh\n") shellProc, err = shellexec.StartRemoteShellProcNoWsh(ctx, rc.TermSize, cmdStr, cmdOpts, conn) if err != nil { return nil, err diff --git a/pkg/util/fileutil/fileutil.go b/pkg/util/fileutil/fileutil.go index afefbe6f5d..84d9e1e88d 100644 --- a/pkg/util/fileutil/fileutil.go +++ b/pkg/util/fileutil/fileutil.go @@ -11,6 +11,7 @@ import ( "net/http" "os" "path/filepath" + "regexp" "strings" "github.com/wavetermdev/waveterm/pkg/wavebase" @@ -116,3 +117,52 @@ func DetectMimeType(path string, fileInfo fs.FileInfo, extended bool) string { } return rtn } + +var ( + systemBinDirs = []string{ + "/bin/", + "/usr/bin/", + "/usr/local/bin/", + "/opt/bin/", + "/sbin/", + "/usr/sbin/", + } + suspiciousPattern = regexp.MustCompile(`[:;#!&$\t%="|>{}]`) + flagPattern = regexp.MustCompile(` --?[a-zA-Z0-9]`) +) + +// IsInitScriptPath tries to determine if the input string is a path to a script +// rather than an inline script content. +func IsInitScriptPath(input string) bool { + if len(input) == 0 || strings.Contains(input, "\n") { + return false + } + + if suspiciousPattern.MatchString(input) { + return false + } + + if flagPattern.MatchString(input) { + return false + } + + // Check for home directory path + if strings.HasPrefix(input, "~/") { + return true + } + + // Path must be absolute (if not home directory) + if !filepath.IsAbs(input) { + return false + } + + // Check if path starts with system binary directories + normalizedPath := filepath.ToSlash(input) + for _, binDir := range systemBinDirs { + if strings.HasPrefix(normalizedPath, binDir) { + return false + } + } + + return true +} diff --git a/pkg/util/utilfn/utilfn.go b/pkg/util/utilfn/utilfn.go index a9588c8ab4..36cd4be169 100644 --- a/pkg/util/utilfn/utilfn.go +++ b/pkg/util/utilfn/utilfn.go @@ -932,3 +932,12 @@ func TimeoutFromContext(ctx context.Context, defaultTimeout time.Duration) time. } return time.Until(deadline) } + +func HasBinaryData(data []byte) bool { + for _, b := range data { + if b < 32 && b != '\n' && b != '\r' && b != '\t' && b != '\f' && b != '\b' { + return true + } + } + return false +} diff --git a/pkg/waveobj/metamap.go b/pkg/waveobj/metamap.go index a74665ec59..f0fc3a5084 100644 --- a/pkg/waveobj/metamap.go +++ b/pkg/waveobj/metamap.go @@ -3,8 +3,12 @@ package waveobj +import "github.com/google/uuid" + type MetaMapType map[string]any +var MetaMap_DeleteSentinel = uuid.NewString() + func (m MetaMapType) GetString(key string, def string) string { if v, ok := m[key]; ok { if s, ok := v.(string); ok { @@ -48,6 +52,26 @@ func (m MetaMapType) GetStringList(key string) []string { return rtn } +func (m MetaMapType) GetStringMap(key string, useDeleteSentinel bool) map[string]string { + mval := m.GetMap(key) + if len(mval) == 0 { + return nil + } + rtn := make(map[string]string, len(mval)) + for k, v := range mval { + if v == nil { + if useDeleteSentinel { + rtn[k] = MetaMap_DeleteSentinel + } + continue + } + if s, ok := v.(string); ok { + rtn[k] = s + } + } + return rtn +} + func (m MetaMapType) GetBool(key string, def bool) bool { if v, ok := m[key]; ok { if b, ok := v.(bool); ok { diff --git a/pkg/wconfig/settingsconfig.go b/pkg/wconfig/settingsconfig.go index 632934f05c..b3866ab0ee 100644 --- a/pkg/wconfig/settingsconfig.go +++ b/pkg/wconfig/settingsconfig.go @@ -151,6 +151,14 @@ type ConnKeywords struct { TermFontFamily string `json:"term:fontfamily,omitempty"` TermTheme string `json:"term:theme,omitempty"` + CmdEnv map[string]string `json:"cmd:env,omitempty"` + CmdInitScript string `json:"cmd:initscript,omitempty"` + CmdInitScriptSh string `json:"cmd:initscript.sh,omitempty"` + CmdInitScriptBash string `json:"cmd:initscript.bash,omitempty"` + CmdInitScriptZsh string `json:"cmd:initscript.zsh,omitempty"` + CmdInitScriptPwsh string `json:"cmd:initscript.pwsh,omitempty"` + CmdInitScriptFish string `json:"cmd:initscript.fish,omitempty"` + SshUser *string `json:"ssh:user,omitempty"` SshHostName *string `json:"ssh:hostname,omitempty"` SshPort *string `json:"ssh:port,omitempty"`