diff --git a/cmd/wsh/cmd/wshcmd-file.go b/cmd/wsh/cmd/wshcmd-file.go index 7c1c60ded0..0011c47527 100644 --- a/cmd/wsh/cmd/wshcmd-file.go +++ b/cmd/wsh/cmd/wshcmd-file.go @@ -103,7 +103,6 @@ func init() { fileCmd.AddCommand(fileInfoCmd) fileCmd.AddCommand(fileAppendCmd) fileCpCmd.Flags().BoolP("merge", "m", false, "merge directories") - fileCpCmd.Flags().BoolP("recursive", "r", false, "copy directories recursively") fileCpCmd.Flags().BoolP("force", "f", false, "force overwrite of existing files") fileCmd.AddCommand(fileCpCmd) fileMvCmd.Flags().BoolP("recursive", "r", false, "move directories recursively") @@ -174,7 +173,7 @@ var fileAppendCmd = &cobra.Command{ var fileCpCmd = &cobra.Command{ Use: "cp [source-uri] [destination-uri]" + UriHelpText, Aliases: []string{"copy"}, - Short: "copy files between storage systems", + Short: "copy files between storage systems, recursively if needed", Long: "Copy files between different storage systems." + UriHelpText, Example: " wsh file cp wavefile://block/config.txt ./local-config.txt\n wsh file cp ./local-config.txt wavefile://block/config.txt\n wsh file cp wsh://user@ec2/home/user/config.txt wavefile://client/config.txt", Args: cobra.ExactArgs(2), @@ -398,10 +397,6 @@ func getTargetPath(src, dst string) (string, error) { func fileCpRun(cmd *cobra.Command, args []string) error { src, dst := args[0], args[1] - recursive, err := cmd.Flags().GetBool("recursive") - if err != nil { - return err - } merge, err := cmd.Flags().GetBool("merge") if err != nil { return err @@ -419,9 +414,9 @@ func fileCpRun(cmd *cobra.Command, args []string) error { if err != nil { return fmt.Errorf("unable to parse dest path: %w", err) } - log.Printf("Copying %s to %s; recursive: %v, merge: %v, force: %v", srcPath, destPath, recursive, merge, force) + log.Printf("Copying %s to %s; merge: %v, force: %v", srcPath, destPath, merge, force) rpcOpts := &wshrpc.RpcOpts{Timeout: TimeoutYear} - err = wshclient.FileCopyCommand(RpcClient, wshrpc.CommandFileCopyData{SrcUri: srcPath, DestUri: destPath, Opts: &wshrpc.FileCopyOpts{Recursive: recursive, Merge: merge, Overwrite: force, Timeout: TimeoutYear}}, rpcOpts) + err = wshclient.FileCopyCommand(RpcClient, wshrpc.CommandFileCopyData{SrcUri: srcPath, DestUri: destPath, Opts: &wshrpc.FileCopyOpts{Merge: merge, Overwrite: force, Timeout: TimeoutYear}}, rpcOpts) if err != nil { return fmt.Errorf("copying file: %w", err) } @@ -449,7 +444,7 @@ func fileMvRun(cmd *cobra.Command, args []string) error { } log.Printf("Moving %s to %s; recursive: %v, force: %v", srcPath, destPath, recursive, force) rpcOpts := &wshrpc.RpcOpts{Timeout: TimeoutYear} - err = wshclient.FileMoveCommand(RpcClient, wshrpc.CommandFileCopyData{SrcUri: srcPath, DestUri: destPath, Opts: &wshrpc.FileCopyOpts{Recursive: recursive, Overwrite: force, Timeout: TimeoutYear}}, rpcOpts) + err = wshclient.FileMoveCommand(RpcClient, wshrpc.CommandFileCopyData{SrcUri: srcPath, DestUri: destPath, Opts: &wshrpc.FileCopyOpts{Overwrite: force, Timeout: TimeoutYear, Recursive: recursive}}, rpcOpts) if err != nil { return fmt.Errorf("moving file: %w", err) } diff --git a/cmd/wsh/cmd/wshcmd-view.go b/cmd/wsh/cmd/wshcmd-view.go index 97ee8ffdb6..a2f8f86394 100644 --- a/cmd/wsh/cmd/wshcmd-view.go +++ b/cmd/wsh/cmd/wshcmd-view.go @@ -19,6 +19,7 @@ var viewMagnified bool var viewCmd = &cobra.Command{ Use: "view {file|directory|URL}", + Aliases: []string{"preview", "open"}, Short: "preview/edit a file or directory", RunE: viewRun, PreRunE: preRunSetupRpcClient, diff --git a/frontend/app/store/wshclientapi.ts b/frontend/app/store/wshclientapi.ts index 1c6903104e..3c83bc98d1 100644 --- a/frontend/app/store/wshclientapi.ts +++ b/frontend/app/store/wshclientapi.ts @@ -263,7 +263,7 @@ class RpcApiType { } // command "remotefilecopy" [call] - RemoteFileCopyCommand(client: WshClient, data: CommandRemoteFileCopyData, opts?: RpcOpts): Promise { + RemoteFileCopyCommand(client: WshClient, data: CommandFileCopyData, opts?: RpcOpts): Promise { return client.wshRpcCall("remotefilecopy", data, opts); } @@ -283,7 +283,7 @@ class RpcApiType { } // command "remotefilemove" [call] - RemoteFileMoveCommand(client: WshClient, data: CommandRemoteFileCopyData, opts?: RpcOpts): Promise { + RemoteFileMoveCommand(client: WshClient, data: CommandFileCopyData, opts?: RpcOpts): Promise { return client.wshRpcCall("remotefilemove", data, opts); } diff --git a/frontend/types/gotypes.d.ts b/frontend/types/gotypes.d.ts index a690a60cec..180fb42cec 100644 --- a/frontend/types/gotypes.d.ts +++ b/frontend/types/gotypes.d.ts @@ -202,13 +202,6 @@ declare global { message: string; }; - // wshrpc.CommandRemoteFileCopyData - type CommandRemoteFileCopyData = { - srcuri: string; - desturi: string; - opts?: FileCopyOpts; - }; - // wshrpc.CommandRemoteListEntriesData type CommandRemoteListEntriesData = { path: string; diff --git a/pkg/remote/fileshare/fileshare.go b/pkg/remote/fileshare/fileshare.go index 9473db55a2..10ca41f97e 100644 --- a/pkg/remote/fileshare/fileshare.go +++ b/pkg/remote/fileshare/fileshare.go @@ -118,11 +118,19 @@ func Move(ctx context.Context, data wshrpc.CommandFileCopyData) error { return fmt.Errorf("error creating fileshare client, could not parse destination connection %s", data.DestUri) } if srcConn.Host != destConn.Host { - err := destClient.CopyRemote(ctx, srcConn, destConn, srcClient, data.Opts) + finfo, err := srcClient.Stat(ctx, srcConn) + if err != nil { + return fmt.Errorf("cannot stat %q: %w", data.SrcUri, err) + } + recursive := data.Opts != nil && data.Opts.Recursive + if finfo.IsDir && data.Opts != nil && !recursive { + return fmt.Errorf("cannot move directory %q to %q without recursive flag", data.SrcUri, data.DestUri) + } + err = destClient.CopyRemote(ctx, srcConn, destConn, srcClient, data.Opts) if err != nil { return fmt.Errorf("cannot copy %q to %q: %w", data.SrcUri, data.DestUri, err) } - return srcClient.Delete(ctx, srcConn, data.Opts.Recursive) + return srcClient.Delete(ctx, srcConn, recursive) } else { return srcClient.MoveInternal(ctx, srcConn, destConn, data.Opts) } diff --git a/pkg/remote/fileshare/fstype/fstype.go b/pkg/remote/fileshare/fstype/fstype.go index 3c3d6fceb3..2e44e6b003 100644 --- a/pkg/remote/fileshare/fstype/fstype.go +++ b/pkg/remote/fileshare/fstype/fstype.go @@ -5,12 +5,17 @@ package fstype import ( "context" + "time" "github.com/wavetermdev/waveterm/pkg/remote/connparse" "github.com/wavetermdev/waveterm/pkg/util/iochan/iochantypes" "github.com/wavetermdev/waveterm/pkg/wshrpc" ) +const ( + DefaultTimeout = 30 * time.Second +) + type FileShareClient interface { // Stat returns the file info at the given parsed connection path Stat(ctx context.Context, conn *connparse.Connection) (*wshrpc.FileInfo, error) diff --git a/pkg/remote/fileshare/wavefs/wavefs.go b/pkg/remote/fileshare/wavefs/wavefs.go index 63cbe36a1d..181e5699c4 100644 --- a/pkg/remote/fileshare/wavefs/wavefs.go +++ b/pkg/remote/fileshare/wavefs/wavefs.go @@ -29,10 +29,6 @@ import ( "github.com/wavetermdev/waveterm/pkg/wshutil" ) -const ( - DefaultTimeout = 30 * time.Second -) - type WaveClient struct{} var _ fstype.FileShareClient = WaveClient{} @@ -54,7 +50,7 @@ func (c WaveClient) ReadStream(ctx context.Context, conn *connparse.Connection, if !rtnData.Info.IsDir { for i := 0; i < dataLen; i += wshrpc.FileChunkSize { if ctx.Err() != nil { - ch <- wshutil.RespErr[wshrpc.FileData](ctx.Err()) + ch <- wshutil.RespErr[wshrpc.FileData](context.Cause(ctx)) return } dataEnd := min(i+wshrpc.FileChunkSize, dataLen) @@ -63,7 +59,7 @@ func (c WaveClient) ReadStream(ctx context.Context, conn *connparse.Connection, } else { for i := 0; i < len(rtnData.Entries); i += wshrpc.DirChunkSize { if ctx.Err() != nil { - ch <- wshutil.RespErr[wshrpc.FileData](ctx.Err()) + ch <- wshutil.RespErr[wshrpc.FileData](context.Cause(ctx)) return } ch <- wshrpc.RespOrErrorUnion[wshrpc.FileData]{Response: wshrpc.FileData{Entries: rtnData.Entries[i:min(i+wshrpc.DirChunkSize, len(rtnData.Entries))], Info: rtnData.Info}} @@ -116,7 +112,7 @@ func (c WaveClient) ReadTarStream(ctx context.Context, conn *connparse.Connectio pathPrefix := getPathPrefix(conn) schemeAndHost := conn.GetSchemeAndHost() + "/" - timeout := DefaultTimeout + timeout := fstype.DefaultTimeout if opts.Timeout > 0 { timeout = time.Duration(opts.Timeout) * time.Millisecond } @@ -130,12 +126,12 @@ func (c WaveClient) ReadTarStream(ctx context.Context, conn *connparse.Connectio }() for _, file := range list { if readerCtx.Err() != nil { - rtn <- wshutil.RespErr[iochantypes.Packet](readerCtx.Err()) + rtn <- wshutil.RespErr[iochantypes.Packet](context.Cause(readerCtx)) return } file.Mode = 0644 - if err = writeHeader(fileutil.ToFsFileInfo(file), file.Path); err != nil { + if err = writeHeader(fileutil.ToFsFileInfo(file), file.Path, file.Path == conn.Path); err != nil { rtn <- wshutil.RespErr[iochantypes.Packet](fmt.Errorf("error writing tar header: %w", err)) return } @@ -447,27 +443,37 @@ func (c WaveClient) CopyRemote(ctx context.Context, srcConn, destConn *connparse if zoneId == "" { return fmt.Errorf("zoneid not found in connection") } + overwrite := opts != nil && opts.Overwrite + merge := opts != nil && opts.Merge + destHasSlash := strings.HasSuffix(destConn.Path, "/") destPrefix := getPathPrefix(destConn) destPrefix = strings.TrimPrefix(destPrefix, destConn.GetSchemeAndHost()+"/") log.Printf("CopyRemote: srcConn: %v, destConn: %v, destPrefix: %s\n", srcConn, destConn, destPrefix) + entries, err := c.ListEntries(ctx, srcConn, nil) + if err != nil { + return fmt.Errorf("error listing blockfiles: %w", err) + } + if len(entries) > 1 && !merge { + return fmt.Errorf("more than one entry at destination prefix, use merge flag to copy") + } readCtx, cancel := context.WithCancelCause(ctx) ioch := srcClient.ReadTarStream(readCtx, srcConn, opts) - err := tarcopy.TarCopyDest(readCtx, cancel, ioch, func(next *tar.Header, reader *tar.Reader) error { + err = tarcopy.TarCopyDest(readCtx, cancel, ioch, func(next *tar.Header, reader *tar.Reader, singleFile bool) error { if next.Typeflag == tar.TypeDir { return nil } fileName, err := cleanPath(path.Join(destPrefix, next.Name)) + if singleFile && !destHasSlash { + fileName, err = cleanPath(destConn.Path) + } if err != nil { return fmt.Errorf("error cleaning path: %w", err) } - _, err = filestore.WFS.Stat(ctx, zoneId, fileName) - if err != nil { - if !errors.Is(err, fs.ErrNotExist) { - return fmt.Errorf("error getting blockfile info: %w", err) - } - err := filestore.WFS.MakeFile(ctx, zoneId, fileName, nil, wshrpc.FileOpts{}) - if err != nil { - return fmt.Errorf("error making blockfile: %w", err) + if !overwrite { + for _, entry := range entries { + if entry.Name == fileName { + return fmt.Errorf("destination already exists: %v", fileName) + } } } log.Printf("CopyRemote: writing file: %s; size: %d\n", fileName, next.Size) diff --git a/pkg/remote/fileshare/wshfs/wshfs.go b/pkg/remote/fileshare/wshfs/wshfs.go index 61816ea576..424c589d20 100644 --- a/pkg/remote/fileshare/wshfs/wshfs.go +++ b/pkg/remote/fileshare/wshfs/wshfs.go @@ -157,7 +157,7 @@ func (c WshClient) MoveInternal(ctx context.Context, srcConn, destConn *connpars if timeout == 0 { timeout = ThirtySeconds } - return wshclient.RemoteFileMoveCommand(RpcClient, wshrpc.CommandRemoteFileCopyData{SrcUri: srcConn.GetFullURI(), DestUri: destConn.GetFullURI(), Opts: opts}, &wshrpc.RpcOpts{Route: wshutil.MakeConnectionRouteId(destConn.Host), Timeout: timeout}) + return wshclient.RemoteFileMoveCommand(RpcClient, wshrpc.CommandFileCopyData{SrcUri: srcConn.GetFullURI(), DestUri: destConn.GetFullURI(), Opts: opts}, &wshrpc.RpcOpts{Route: wshutil.MakeConnectionRouteId(destConn.Host), Timeout: timeout}) } func (c WshClient) CopyRemote(ctx context.Context, srcConn, destConn *connparse.Connection, _ fstype.FileShareClient, opts *wshrpc.FileCopyOpts) error { @@ -172,7 +172,7 @@ func (c WshClient) CopyInternal(ctx context.Context, srcConn, destConn *connpars if timeout == 0 { timeout = ThirtySeconds } - return wshclient.RemoteFileCopyCommand(RpcClient, wshrpc.CommandRemoteFileCopyData{SrcUri: srcConn.GetFullURI(), DestUri: destConn.GetFullURI(), Opts: opts}, &wshrpc.RpcOpts{Route: wshutil.MakeConnectionRouteId(destConn.Host), Timeout: timeout}) + return wshclient.RemoteFileCopyCommand(RpcClient, wshrpc.CommandFileCopyData{SrcUri: srcConn.GetFullURI(), DestUri: destConn.GetFullURI(), Opts: opts}, &wshrpc.RpcOpts{Route: wshutil.MakeConnectionRouteId(destConn.Host), Timeout: timeout}) } func (c WshClient) Delete(ctx context.Context, conn *connparse.Connection, recursive bool) error { diff --git a/pkg/util/fileutil/fileutil.go b/pkg/util/fileutil/fileutil.go index 4c894f190c..d7c8940db4 100644 --- a/pkg/util/fileutil/fileutil.go +++ b/pkg/util/fileutil/fileutil.go @@ -19,6 +19,7 @@ import ( ) func FixPath(path string) (string, error) { + origPath := path var err error if strings.HasPrefix(path, "~") { path = filepath.Join(wavebase.GetHomeDir(), path[1:]) @@ -28,6 +29,9 @@ func FixPath(path string) (string, error) { return "", err } } + if strings.HasSuffix(origPath, "/") && !strings.HasSuffix(path, "/") { + path += "/" + } return path, nil } diff --git a/pkg/util/tarcopy/tarcopy.go b/pkg/util/tarcopy/tarcopy.go index 06e008811c..aaa480d0b3 100644 --- a/pkg/util/tarcopy/tarcopy.go +++ b/pkg/util/tarcopy/tarcopy.go @@ -29,31 +29,53 @@ const ( pipeReaderName = "pipe reader" pipeWriterName = "pipe writer" tarWriterName = "tar writer" + + // custom flag to indicate that the source is a single file + SingleFile = "singlefile" ) // TarCopySrc creates a tar stream writer and returns a channel to send the tar stream to. -// writeHeader is a function that writes the tar header for the file. +// writeHeader is a function that writes the tar header for the file. If only a single file is being written, the singleFile flag should be set to true. // writer is the tar writer to write the file data to. // close is a function that closes the tar writer and internal pipe writer. -func TarCopySrc(ctx context.Context, pathPrefix string) (outputChan chan wshrpc.RespOrErrorUnion[iochantypes.Packet], writeHeader func(fi fs.FileInfo, file string) error, writer io.Writer, close func()) { +func TarCopySrc(ctx context.Context, pathPrefix string) (outputChan chan wshrpc.RespOrErrorUnion[iochantypes.Packet], writeHeader func(fi fs.FileInfo, file string, singleFile bool) error, writer io.Writer, close func()) { pipeReader, pipeWriter := io.Pipe() tarWriter := tar.NewWriter(pipeWriter) rtnChan := iochan.ReaderChan(ctx, pipeReader, wshrpc.FileChunkSize, func() { gracefulClose(pipeReader, tarCopySrcName, pipeReaderName) }) - return rtnChan, func(fi fs.FileInfo, file string) error { + singleFileFlagSet := false + + return rtnChan, func(fi fs.FileInfo, path string, singleFile bool) error { // generate tar header - header, err := tar.FileInfoHeader(fi, file) + header, err := tar.FileInfoHeader(fi, path) if err != nil { return err } - header.Name = filepath.Clean(strings.TrimPrefix(file, pathPrefix)) - if err := validatePath(header.Name); err != nil { + if singleFile { + if singleFileFlagSet { + return errors.New("attempting to write multiple files to a single file tar stream") + } + + header.PAXRecords = map[string]string{SingleFile: "true"} + singleFileFlagSet = true + } + + path, err = fixPath(path, pathPrefix) + if err != nil { return err } + // skip if path is empty, which means the file is the root directory + if path == "" { + return nil + } + header.Name = path + + log.Printf("TarCopySrc: header name: %v\n", header.Name) + // write header if err := tarWriter.WriteHeader(header); err != nil { return err @@ -65,20 +87,18 @@ func TarCopySrc(ctx context.Context, pathPrefix string) (outputChan chan wshrpc. } } -func validatePath(path string) error { +func fixPath(path string, prefix string) (string, error) { + path = strings.TrimPrefix(strings.TrimPrefix(filepath.Clean(strings.TrimPrefix(path, prefix)), "/"), "\\") if strings.Contains(path, "..") { - return fmt.Errorf("invalid tar path containing directory traversal: %s", path) + return "", fmt.Errorf("invalid tar path containing directory traversal: %s", path) } - if strings.HasPrefix(path, "/") { - return fmt.Errorf("invalid tar path starting with /: %s", path) - } - return nil + return path, nil } // TarCopyDest reads a tar stream from a channel and writes the files to the destination. -// readNext is a function that is called for each file in the tar stream to read the file data. It should return an error if the file cannot be read. +// readNext is a function that is called for each file in the tar stream to read the file data. If only a single file is being written from the tar src, the singleFile flag will be set in this callback. It should return an error if the file cannot be read. // The function returns an error if the tar stream cannot be read. -func TarCopyDest(ctx context.Context, cancel context.CancelCauseFunc, ch <-chan wshrpc.RespOrErrorUnion[iochantypes.Packet], readNext func(next *tar.Header, reader *tar.Reader) error) error { +func TarCopyDest(ctx context.Context, cancel context.CancelCauseFunc, ch <-chan wshrpc.RespOrErrorUnion[iochantypes.Packet], readNext func(next *tar.Header, reader *tar.Reader, singleFile bool) error) error { pipeReader, pipeWriter := io.Pipe() iochan.WriterChan(ctx, pipeWriter, ch, func() { gracefulClose(pipeWriter, tarCopyDestName, pipeWriterName) @@ -110,7 +130,12 @@ func TarCopyDest(ctx context.Context, cancel context.CancelCauseFunc, ch <-chan return err } } - err = readNext(next, tarReader) + + // Check for directory traversal + if strings.Contains(next.Name, "..") { + return nil + } + err = readNext(next, tarReader, next.PAXRecords != nil && next.PAXRecords[SingleFile] == "true") if err != nil { return err } diff --git a/pkg/wshrpc/wshclient/wshclient.go b/pkg/wshrpc/wshclient/wshclient.go index 6fdbaf7473..5d2f140097 100644 --- a/pkg/wshrpc/wshclient/wshclient.go +++ b/pkg/wshrpc/wshclient/wshclient.go @@ -321,7 +321,7 @@ func RecordTEventCommand(w *wshutil.WshRpc, data telemetrydata.TEvent, opts *wsh } // command "remotefilecopy", wshserver.RemoteFileCopyCommand -func RemoteFileCopyCommand(w *wshutil.WshRpc, data wshrpc.CommandRemoteFileCopyData, opts *wshrpc.RpcOpts) error { +func RemoteFileCopyCommand(w *wshutil.WshRpc, data wshrpc.CommandFileCopyData, opts *wshrpc.RpcOpts) error { _, err := sendRpcRequestCallHelper[any](w, "remotefilecopy", data, opts) return err } @@ -345,7 +345,7 @@ func RemoteFileJoinCommand(w *wshutil.WshRpc, data []string, opts *wshrpc.RpcOpt } // command "remotefilemove", wshserver.RemoteFileMoveCommand -func RemoteFileMoveCommand(w *wshutil.WshRpc, data wshrpc.CommandRemoteFileCopyData, opts *wshrpc.RpcOpts) error { +func RemoteFileMoveCommand(w *wshutil.WshRpc, data wshrpc.CommandFileCopyData, opts *wshrpc.RpcOpts) error { _, err := sendRpcRequestCallHelper[any](w, "remotefilemove", data, opts) return err } diff --git a/pkg/wshrpc/wshremote/wshremote.go b/pkg/wshrpc/wshremote/wshremote.go index 711de2e26e..35698ca7e1 100644 --- a/pkg/wshrpc/wshremote/wshremote.go +++ b/pkg/wshrpc/wshremote/wshremote.go @@ -18,6 +18,7 @@ import ( "time" "github.com/wavetermdev/waveterm/pkg/remote/connparse" + "github.com/wavetermdev/waveterm/pkg/remote/fileshare/fstype" "github.com/wavetermdev/waveterm/pkg/remote/fileshare/wshfs" "github.com/wavetermdev/waveterm/pkg/suggestion" "github.com/wavetermdev/waveterm/pkg/util/fileutil" @@ -30,10 +31,6 @@ import ( "github.com/wavetermdev/waveterm/pkg/wshutil" ) -const ( - DefaultTimeout = 30 * time.Second -) - type ServerImpl struct { LogWriter io.Writer } @@ -240,8 +237,8 @@ func (impl *ServerImpl) RemoteTarStreamCommand(ctx context.Context, data wshrpc. if opts == nil { opts = &wshrpc.FileCopyOpts{} } - recursive := opts.Recursive - logPrintfDev("RemoteTarStreamCommand: path=%s\n", path) + log.Printf("RemoteTarStreamCommand: path=%s\n", path) + srcHasSlash := strings.HasSuffix(path, "/") path, err := wavebase.ExpandHomeDir(path) if err != nil { return wshutil.SendErrCh[iochantypes.Packet](fmt.Errorf("cannot expand path %q: %w", path, err)) @@ -253,18 +250,15 @@ func (impl *ServerImpl) RemoteTarStreamCommand(ctx context.Context, data wshrpc. } var pathPrefix string - if finfo.IsDir() && strings.HasSuffix(cleanedPath, "/") { + singleFile := !finfo.IsDir() + if !singleFile && srcHasSlash { pathPrefix = cleanedPath } else { - pathPrefix = filepath.Dir(cleanedPath) + "/" - } - if finfo.IsDir() { - if !recursive { - return wshutil.SendErrCh[iochantypes.Packet](fmt.Errorf("cannot create tar stream for %q: %w", path, errors.New("directory copy requires recursive option"))) - } + pathPrefix = filepath.Dir(cleanedPath) } + log.Printf("RemoteTarStreamCommand: path=%s, pathPrefix=%s\n", path, pathPrefix) - timeout := DefaultTimeout + timeout := fstype.DefaultTimeout if opts.Timeout > 0 { timeout = time.Duration(opts.Timeout) * time.Millisecond } @@ -283,7 +277,8 @@ func (impl *ServerImpl) RemoteTarStreamCommand(ctx context.Context, data wshrpc. if err != nil { return err } - if err = writeHeader(info, path); err != nil { + log.Printf("RemoteTarStreamCommand: path=%s\n", path) + if err = writeHeader(info, path, singleFile); err != nil { return err } // if not a dir, write file content @@ -300,10 +295,10 @@ func (impl *ServerImpl) RemoteTarStreamCommand(ctx context.Context, data wshrpc. } log.Printf("RemoteTarStreamCommand: starting\n") err = nil - if finfo.IsDir() { - err = filepath.Walk(path, walkFunc) + if singleFile { + err = walkFunc(cleanedPath, finfo, nil) } else { - err = walkFunc(path, finfo, nil) + err = filepath.Walk(cleanedPath, walkFunc) } if err != nil { rtn <- wshutil.RespErr[iochantypes.Packet](err) @@ -314,7 +309,7 @@ func (impl *ServerImpl) RemoteTarStreamCommand(ctx context.Context, data wshrpc. return rtn } -func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.CommandRemoteFileCopyData) error { +func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.CommandFileCopyData) error { log.Printf("RemoteFileCopyCommand: src=%s, dest=%s\n", data.SrcUri, data.DestUri) opts := data.Opts if opts == nil { @@ -331,19 +326,25 @@ func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.C } destPathCleaned := filepath.Clean(wavebase.ExpandHomeDirSafe(destConn.Path)) destinfo, err := os.Stat(destPathCleaned) - if err == nil { - if !destinfo.IsDir() { - if !overwrite { - return fmt.Errorf("destination %q already exists, use overwrite option", destPathCleaned) - } else { - err := os.Remove(destPathCleaned) - if err != nil { - return fmt.Errorf("cannot remove file %q: %w", destPathCleaned, err) - } + if err != nil { + if !errors.Is(err, fs.ErrNotExist) { + return fmt.Errorf("cannot stat destination %q: %w", destPathCleaned, err) + } + } + + destExists := destinfo != nil + destIsDir := destExists && destinfo.IsDir() + destHasSlash := strings.HasSuffix(destUri, "/") + + if destExists && !destIsDir { + if !overwrite { + return fmt.Errorf("file already exists at destination %q, use overwrite option", destPathCleaned) + } else { + err := os.Remove(destPathCleaned) + if err != nil { + return fmt.Errorf("cannot remove file %q: %w", destPathCleaned, err) } } - } else if !errors.Is(err, fs.ErrNotExist) { - return fmt.Errorf("cannot stat destination %q: %w", destPathCleaned, err) } srcConn, err := connparse.ParseURIAndReplaceCurrentHost(ctx, srcUri) if err != nil { @@ -351,14 +352,16 @@ func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.C } copyFileFunc := func(path string, finfo fs.FileInfo, srcFile io.Reader) (int64, error) { - destinfo, err = os.Stat(path) + nextinfo, err := os.Stat(path) if err != nil && !errors.Is(err, fs.ErrNotExist) { return 0, fmt.Errorf("cannot stat file %q: %w", path, err) } - if destinfo != nil { - if destinfo.IsDir() { + if nextinfo != nil { + if nextinfo.IsDir() { + log.Printf("RemoteFileCopyCommand: nextinfo is dir, path=%s\n", path) if !finfo.IsDir() { + log.Printf("RemoteFileCopyCommand: finfo is file: %s\n", path) // try to create file in directory path = filepath.Join(path, filepath.Base(finfo.Name())) newdestinfo, err := os.Stat(path) @@ -390,13 +393,17 @@ func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.C return 0, fmt.Errorf("cannot create file %q, file exists at path, overwrite not specified", path) } } + } else { + log.Printf("RemoteFileCopyCommand: nextinfo is nil, path=%s\n", path) } if finfo.IsDir() { + log.Printf("RemoteFileCopyCommand: making dirs %s\n", path) err := os.MkdirAll(path, finfo.Mode()) if err != nil { return 0, fmt.Errorf("cannot create directory %q: %w", path, err) } + return 0, nil } else { err := os.MkdirAll(filepath.Dir(path), 0755) if err != nil { @@ -426,12 +433,19 @@ func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.C } if srcFileStat.IsDir() { + log.Print("RemoteFileCopyCommand: copying directory\n") + srcPathPrefix := filepath.Dir(srcPathCleaned) + if strings.HasSuffix(srcUri, "/") { + log.Printf("RemoteFileCopyCommand: src has slash, using %q as src path\n", srcPathCleaned) + srcPathPrefix = srcPathCleaned + } err = filepath.Walk(srcPathCleaned, func(path string, info fs.FileInfo, err error) error { if err != nil { return err } srcFilePath := path - destFilePath := filepath.Join(destPathCleaned, strings.TrimPrefix(path, srcPathCleaned)) + destFilePath := filepath.Join(destPathCleaned, strings.TrimPrefix(path, srcPathPrefix)) + log.Printf("RemoteFileCopyCommand: copying %q to %q\n", srcFilePath, destFilePath) var file *os.File if !info.IsDir() { file, err = os.Open(srcFilePath) @@ -447,18 +461,24 @@ func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.C return fmt.Errorf("cannot copy %q to %q: %w", srcUri, destUri, err) } } else { + log.Print("RemoteFileCopyCommand: copying single file\n") file, err := os.Open(srcPathCleaned) if err != nil { return fmt.Errorf("cannot open file %q: %w", srcPathCleaned, err) } defer file.Close() - _, err = copyFileFunc(destPathCleaned, srcFileStat, file) + destFilePath := filepath.Join(destPathCleaned, filepath.Base(srcPathCleaned)) + if destHasSlash { + log.Printf("RemoteFileCopyCommand: dest has slash, using %q as dest path\n", destPathCleaned) + destFilePath = destPathCleaned + } + _, err = copyFileFunc(destFilePath, srcFileStat, file) if err != nil { return fmt.Errorf("cannot copy %q to %q: %w", srcUri, destUri, err) } } } else { - timeout := DefaultTimeout + timeout := fstype.DefaultTimeout if opts.Timeout > 0 { timeout = time.Duration(opts.Timeout) * time.Millisecond } @@ -470,16 +490,19 @@ func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.C numFiles := 0 numSkipped := 0 totalBytes := int64(0) - err := tarcopy.TarCopyDest(readCtx, cancel, ioch, func(next *tar.Header, reader *tar.Reader) error { - // Check for directory traversal - if strings.Contains(next.Name, "..") { - log.Printf("skipping file with unsafe path: %q\n", next.Name) - numSkipped++ - return nil - } + + err := tarcopy.TarCopyDest(readCtx, cancel, ioch, func(next *tar.Header, reader *tar.Reader, singleFile bool) error { numFiles++ + nextpath := filepath.Join(destPathCleaned, next.Name) + log.Printf("RemoteFileCopyCommand: copying %q to %q\n", next.Name, nextpath) + if singleFile { + // custom flag to indicate that the source is a single file, not a directory the contents of a directory + if !destHasSlash { + nextpath = destPathCleaned + } + } finfo := next.FileInfo() - n, err := copyFileFunc(filepath.Join(destPathCleaned, next.Name), finfo, reader) + n, err := copyFileFunc(nextpath, finfo, reader) if err != nil { return fmt.Errorf("cannot copy file %q: %w", next.Name, err) } @@ -689,12 +712,13 @@ func (impl *ServerImpl) RemoteFileTouchCommand(ctx context.Context, path string) return nil } -func (impl *ServerImpl) RemoteFileMoveCommand(ctx context.Context, data wshrpc.CommandRemoteFileCopyData) error { +func (impl *ServerImpl) RemoteFileMoveCommand(ctx context.Context, data wshrpc.CommandFileCopyData) error { logPrintfDev("RemoteFileCopyCommand: src=%s, dest=%s\n", data.SrcUri, data.DestUri) opts := data.Opts destUri := data.DestUri srcUri := data.SrcUri overwrite := opts != nil && opts.Overwrite + recursive := opts != nil && opts.Recursive destConn, err := connparse.ParseURIAndReplaceCurrentHost(ctx, destUri) if err != nil { @@ -722,7 +746,14 @@ func (impl *ServerImpl) RemoteFileMoveCommand(ctx context.Context, data wshrpc.C } if srcConn.Host == destConn.Host { srcPathCleaned := filepath.Clean(wavebase.ExpandHomeDirSafe(srcConn.Path)) - err := os.Rename(srcPathCleaned, destPathCleaned) + finfo, err := os.Stat(srcPathCleaned) + if err != nil { + return fmt.Errorf("cannot stat file %q: %w", srcPathCleaned, err) + } + if finfo.IsDir() && !recursive { + return fmt.Errorf("cannot move directory %q, recursive option not specified", srcUri) + } + err = os.Rename(srcPathCleaned, destPathCleaned) if err != nil { return fmt.Errorf("cannot move file %q to %q: %w", srcPathCleaned, destPathCleaned, err) } diff --git a/pkg/wshrpc/wshrpctypes.go b/pkg/wshrpc/wshrpctypes.go index 2a74a9c63e..0cf70ae3f0 100644 --- a/pkg/wshrpc/wshrpctypes.go +++ b/pkg/wshrpc/wshrpctypes.go @@ -206,11 +206,11 @@ type WshRpcInterface interface { // remotes RemoteStreamFileCommand(ctx context.Context, data CommandRemoteStreamFileData) chan RespOrErrorUnion[FileData] RemoteTarStreamCommand(ctx context.Context, data CommandRemoteStreamTarData) <-chan RespOrErrorUnion[iochantypes.Packet] - RemoteFileCopyCommand(ctx context.Context, data CommandRemoteFileCopyData) error + RemoteFileCopyCommand(ctx context.Context, data CommandFileCopyData) error RemoteListEntriesCommand(ctx context.Context, data CommandRemoteListEntriesData) chan RespOrErrorUnion[CommandRemoteListEntriesRtnData] RemoteFileInfoCommand(ctx context.Context, path string) (*FileInfo, error) RemoteFileTouchCommand(ctx context.Context, path string) error - RemoteFileMoveCommand(ctx context.Context, data CommandRemoteFileCopyData) error + RemoteFileMoveCommand(ctx context.Context, data CommandFileCopyData) error RemoteFileDeleteCommand(ctx context.Context, data CommandDeleteFileData) error RemoteWriteFileCommand(ctx context.Context, data FileData) error RemoteFileJoinCommand(ctx context.Context, paths []string) (*FileInfo, error) @@ -515,12 +515,6 @@ type CommandFileCopyData struct { Opts *FileCopyOpts `json:"opts,omitempty"` } -type CommandRemoteFileCopyData struct { - SrcUri string `json:"srcuri"` - DestUri string `json:"desturi"` - Opts *FileCopyOpts `json:"opts,omitempty"` -} - type CommandRemoteStreamTarData struct { Path string `json:"path"` Opts *FileCopyOpts `json:"opts,omitempty"` @@ -528,8 +522,8 @@ type CommandRemoteStreamTarData struct { type FileCopyOpts struct { Overwrite bool `json:"overwrite,omitempty"` - Recursive bool `json:"recursive,omitempty"` - Merge bool `json:"merge,omitempty"` + Recursive bool `json:"recursive,omitempty"` // only used for move, always true for copy + Merge bool `json:"merge,omitempty"` // only used for copy, always false for move Timeout int64 `json:"timeout,omitempty"` } diff --git a/pkg/wshutil/wshproxy.go b/pkg/wshutil/wshproxy.go index 0bc5ae088d..6ad0a96193 100644 --- a/pkg/wshutil/wshproxy.go +++ b/pkg/wshutil/wshproxy.go @@ -9,6 +9,7 @@ import ( "sync" "github.com/google/uuid" + "github.com/wavetermdev/waveterm/pkg/panichandler" "github.com/wavetermdev/waveterm/pkg/util/shellutil" "github.com/wavetermdev/waveterm/pkg/util/utilfn" "github.com/wavetermdev/waveterm/pkg/wshrpc" @@ -248,6 +249,9 @@ func (p *WshRpcProxy) HandleAuthentication() (*wshrpc.RpcContext, error) { } func (p *WshRpcProxy) SendRpcMessage(msg []byte) { + defer func() { + panichandler.PanicHandler("WshRpcProxy.SendRpcMessage", recover()) + }() p.ToRemoteCh <- msg }