From 85e3d32dad0b5b72ece11bc3a613e78cc5c3610a Mon Sep 17 00:00:00 2001 From: Evan Simkowitz Date: Wed, 5 Feb 2025 17:59:05 -0800 Subject: [PATCH 1/8] Fix WSH copy internal --- pkg/wshrpc/wshremote/wshremote.go | 165 ++++++++++++++++++------------ 1 file changed, 100 insertions(+), 65 deletions(-) diff --git a/pkg/wshrpc/wshremote/wshremote.go b/pkg/wshrpc/wshremote/wshremote.go index f324f52fec..d1c3777b2f 100644 --- a/pkg/wshrpc/wshremote/wshremote.go +++ b/pkg/wshrpc/wshremote/wshremote.go @@ -348,11 +348,105 @@ func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.C if err != nil { return fmt.Errorf("cannot parse source URI %q: %w", srcUri, err) } + + copyFileFunc := func(path string, finfo fs.FileInfo, srcFile io.Reader) (int64, error) { + destinfo, err = os.Stat(path) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return 0, fmt.Errorf("cannot stat file %q: %w", path, err) + } + + if destinfo != nil { + if destinfo.IsDir() { + if !finfo.IsDir() { + if !overwrite { + return 0, fmt.Errorf("cannot create directory %q, file exists at path, overwrite not specified", path) + } else { + err := os.Remove(path) + if err != nil { + return 0, fmt.Errorf("cannot remove file %q: %w", path, err) + } + } + } else if !merge && !overwrite { + return 0, fmt.Errorf("cannot create directory %q, directory exists at path, neither overwrite nor merge specified", path) + } else if overwrite { + err := os.RemoveAll(path) + if err != nil { + return 0, fmt.Errorf("cannot remove directory %q: %w", path, err) + } + } + } else { + if finfo.IsDir() { + if !overwrite { + return 0, fmt.Errorf("cannot create file %q, directory exists at path, overwrite not specified", path) + } else { + err := os.RemoveAll(path) + if err != nil { + return 0, fmt.Errorf("cannot remove directory %q: %w", path, err) + } + } + } else if !overwrite { + return 0, fmt.Errorf("cannot create file %q, file exists at path, overwrite not specified", path) + } + } + } else { + if finfo.IsDir() { + err := os.MkdirAll(path, finfo.Mode()) + if err != nil { + return 0, fmt.Errorf("cannot create directory %q: %w", path, err) + } + } else { + err := os.MkdirAll(filepath.Dir(path), 0755) + if err != nil { + return 0, fmt.Errorf("cannot create parent directory %q: %w", filepath.Dir(path), err) + } + file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, finfo.Mode()) + if err != nil { + return 0, fmt.Errorf("cannot create new file %q: %w", path, err) + } + _, err = io.Copy(file, srcFile) + if err != nil { + return 0, fmt.Errorf("cannot write file %q: %w", path, err) + } + file.Close() + } + } + return finfo.Size(), nil + } + if srcConn.Host == destConn.Host { srcPathCleaned := filepath.Clean(wavebase.ExpandHomeDirSafe(srcConn.Path)) - err := os.Rename(srcPathCleaned, destPathCleaned) + + srcFileStat, err := os.Stat(srcPathCleaned) if err != nil { - return fmt.Errorf("cannot copy file %q to %q: %w", srcPathCleaned, destPathCleaned, err) + return fmt.Errorf("cannot stat file %q: %w", srcPathCleaned, err) + } + + if srcFileStat.IsDir() { + err = filepath.Walk(srcPathCleaned, func(path string, info fs.FileInfo, err error) error { + if err != nil { + return err + } + path = filepath.Join(destPathCleaned, strings.TrimPrefix(path, srcPathCleaned)) + + var file *os.File + if !info.IsDir() { + file, err = os.Open(path) + if err != nil { + return fmt.Errorf("cannot open file %q: %w", path, err) + } + } + _, err = copyFileFunc(path, info, file) + return err + }) + } else { + file, err := os.Open(srcPathCleaned) + if err != nil { + return fmt.Errorf("cannot open file %q: %w", srcPathCleaned, err) + } + _, err = copyFileFunc(srcPathCleaned, srcFileStat, file) + } + if err != nil { + return fmt.Errorf("cannot copy %q to %q: %w", srcUri, destUri, err) } } else { timeout := DefaultTimeout @@ -376,70 +470,11 @@ func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.C } numFiles++ finfo := next.FileInfo() - nextPath := filepath.Join(destPathCleaned, next.Name) - destinfo, err = os.Stat(nextPath) - if err != nil && !errors.Is(err, fs.ErrNotExist) { - return fmt.Errorf("cannot stat file %q: %w", nextPath, err) - } - if !finfo.IsDir() { - totalBytes += finfo.Size() - } - - if destinfo != nil { - if destinfo.IsDir() { - if !finfo.IsDir() { - if !overwrite { - return fmt.Errorf("cannot create directory %q, file exists at path, overwrite not specified", nextPath) - } else { - err := os.Remove(nextPath) - if err != nil { - return fmt.Errorf("cannot remove file %q: %w", nextPath, err) - } - } - } else if !merge && !overwrite { - return fmt.Errorf("cannot create directory %q, directory exists at path, neither overwrite nor merge specified", nextPath) - } else if overwrite { - err := os.RemoveAll(nextPath) - if err != nil { - return fmt.Errorf("cannot remove directory %q: %w", nextPath, err) - } - } - } else { - if finfo.IsDir() { - if !overwrite { - return fmt.Errorf("cannot create file %q, directory exists at path, overwrite not specified", nextPath) - } else { - err := os.RemoveAll(nextPath) - if err != nil { - return fmt.Errorf("cannot remove directory %q: %w", nextPath, err) - } - } - } else if !overwrite { - return fmt.Errorf("cannot create file %q, file exists at path, overwrite not specified", nextPath) - } - } - } else { - if finfo.IsDir() { - err := os.MkdirAll(nextPath, finfo.Mode()) - if err != nil { - return fmt.Errorf("cannot create directory %q: %w", nextPath, err) - } - } else { - err := os.MkdirAll(filepath.Dir(nextPath), 0755) - if err != nil { - return fmt.Errorf("cannot create parent directory %q: %w", filepath.Dir(nextPath), err) - } - file, err := os.OpenFile(nextPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, finfo.Mode()) - if err != nil { - return fmt.Errorf("cannot create new file %q: %w", nextPath, err) - } - _, err = io.Copy(file, reader) - if err != nil { - return fmt.Errorf("cannot write file %q: %w", nextPath, err) - } - file.Close() - } + n, err := copyFileFunc(filepath.Join(destPathCleaned, next.Name), finfo, reader) + if err != nil { + return fmt.Errorf("cannot copy file %q: %w", next.Name, err) } + totalBytes += n return nil }) if err != nil { From f277ae7a39d8278a8a7b13fa3b63999ad247c540 Mon Sep 17 00:00:00 2001 From: Evan Simkowitz Date: Wed, 5 Feb 2025 18:45:17 -0800 Subject: [PATCH 2/8] better handlng if file exists --- pkg/wshrpc/wshremote/wshremote.go | 68 +++++++++++++++++-------------- 1 file changed, 38 insertions(+), 30 deletions(-) diff --git a/pkg/wshrpc/wshremote/wshremote.go b/pkg/wshrpc/wshremote/wshremote.go index d1c3777b2f..b2ab0d537f 100644 --- a/pkg/wshrpc/wshremote/wshremote.go +++ b/pkg/wshrpc/wshremote/wshremote.go @@ -358,13 +358,14 @@ func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.C if destinfo != nil { if destinfo.IsDir() { if !finfo.IsDir() { - if !overwrite { - return 0, fmt.Errorf("cannot create directory %q, file exists at path, overwrite not specified", path) - } else { - err := os.Remove(path) - if err != nil { - return 0, fmt.Errorf("cannot remove file %q: %w", path, err) - } + // try to create file in directory + path = filepath.Join(path, filepath.Base(finfo.Name())) + newdestinfo, err := os.Stat(path) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return 0, fmt.Errorf("cannot stat file %q: %w", path, err) + } + if newdestinfo != nil && !overwrite { + return 0, fmt.Errorf("cannot create file %q, file exists at path, overwrite not specified", path) } } else if !merge && !overwrite { return 0, fmt.Errorf("cannot create directory %q, directory exists at path, neither overwrite nor merge specified", path) @@ -388,28 +389,30 @@ func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.C return 0, fmt.Errorf("cannot create file %q, file exists at path, overwrite not specified", path) } } + } + + if finfo.IsDir() { + err := os.MkdirAll(path, finfo.Mode()) + if err != nil { + return 0, fmt.Errorf("cannot create directory %q: %w", path, err) + } } else { - if finfo.IsDir() { - err := os.MkdirAll(path, finfo.Mode()) - if err != nil { - return 0, fmt.Errorf("cannot create directory %q: %w", path, err) - } - } else { - err := os.MkdirAll(filepath.Dir(path), 0755) - if err != nil { - return 0, fmt.Errorf("cannot create parent directory %q: %w", filepath.Dir(path), err) - } - file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, finfo.Mode()) - if err != nil { - return 0, fmt.Errorf("cannot create new file %q: %w", path, err) - } - _, err = io.Copy(file, srcFile) - if err != nil { - return 0, fmt.Errorf("cannot write file %q: %w", path, err) - } - file.Close() + err := os.MkdirAll(filepath.Dir(path), 0755) + if err != nil { + return 0, fmt.Errorf("cannot create parent directory %q: %w", filepath.Dir(path), err) } } + + file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, finfo.Mode()) + if err != nil { + return 0, fmt.Errorf("cannot create new file %q: %w", path, err) + } + _, err = io.Copy(file, srcFile) + if err != nil { + return 0, fmt.Errorf("cannot write file %q: %w", path, err) + } + file.Close() + return finfo.Size(), nil } @@ -434,19 +437,24 @@ func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.C if err != nil { return fmt.Errorf("cannot open file %q: %w", path, err) } + defer file.Close() } _, err = copyFileFunc(path, info, file) return err }) + if err != nil { + return fmt.Errorf("cannot copy %q to %q: %w", srcUri, destUri, err) + } } else { file, err := os.Open(srcPathCleaned) + defer file.Close() if err != nil { return fmt.Errorf("cannot open file %q: %w", srcPathCleaned, err) } - _, err = copyFileFunc(srcPathCleaned, srcFileStat, file) - } - if err != nil { - return fmt.Errorf("cannot copy %q to %q: %w", srcUri, destUri, err) + _, err = copyFileFunc(destPathCleaned, srcFileStat, file) + if err != nil { + return fmt.Errorf("cannot copy %q to %q: %w", srcUri, destUri, err) + } } } else { timeout := DefaultTimeout From 6f8f4818877fb5eb3ff8afad49b1295f34cdd2e0 Mon Sep 17 00:00:00 2001 From: Evan Simkowitz Date: Wed, 5 Feb 2025 18:52:26 -0800 Subject: [PATCH 3/8] Update pkg/wshrpc/wshremote/wshremote.go Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- pkg/wshrpc/wshremote/wshremote.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/wshrpc/wshremote/wshremote.go b/pkg/wshrpc/wshremote/wshremote.go index b2ab0d537f..b48ea9e66a 100644 --- a/pkg/wshrpc/wshremote/wshremote.go +++ b/pkg/wshrpc/wshremote/wshremote.go @@ -407,11 +407,11 @@ func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.C if err != nil { return 0, fmt.Errorf("cannot create new file %q: %w", path, err) } + defer file.Close() _, err = io.Copy(file, srcFile) if err != nil { return 0, fmt.Errorf("cannot write file %q: %w", path, err) } - file.Close() return finfo.Size(), nil } From e84d481c9327d782a826ddaa83968c6fe88e19a9 Mon Sep 17 00:00:00 2001 From: Evan Simkowitz Date: Wed, 5 Feb 2025 18:52:42 -0800 Subject: [PATCH 4/8] Update pkg/wshrpc/wshremote/wshremote.go Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- pkg/wshrpc/wshremote/wshremote.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pkg/wshrpc/wshremote/wshremote.go b/pkg/wshrpc/wshremote/wshremote.go index b48ea9e66a..378f00f50a 100644 --- a/pkg/wshrpc/wshremote/wshremote.go +++ b/pkg/wshrpc/wshremote/wshremote.go @@ -447,10 +447,12 @@ func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.C } } else { file, err := os.Open(srcPathCleaned) - defer file.Close() if err != nil { return fmt.Errorf("cannot open file %q: %w", srcPathCleaned, err) } + defer file.Close() + return fmt.Errorf("cannot open file %q: %w", srcPathCleaned, err) + } _, err = copyFileFunc(destPathCleaned, srcFileStat, file) if err != nil { return fmt.Errorf("cannot copy %q to %q: %w", srcUri, destUri, err) From 321bc9de8697f201223497ea69d4fd716d67f4dc Mon Sep 17 00:00:00 2001 From: Evan Simkowitz Date: Wed, 5 Feb 2025 18:58:17 -0800 Subject: [PATCH 5/8] fix bad merge, coderabbit suggestion --- pkg/wshrpc/wshremote/wshremote.go | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/pkg/wshrpc/wshremote/wshremote.go b/pkg/wshrpc/wshremote/wshremote.go index 378f00f50a..f66c8cdaa9 100644 --- a/pkg/wshrpc/wshremote/wshremote.go +++ b/pkg/wshrpc/wshremote/wshremote.go @@ -429,17 +429,17 @@ func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.C if err != nil { return err } - path = filepath.Join(destPathCleaned, strings.TrimPrefix(path, srcPathCleaned)) - + srcFilePath := path + destFilePath := filepath.Join(destPathCleaned, strings.TrimPrefix(path, srcPathCleaned)) var file *os.File if !info.IsDir() { - file, err = os.Open(path) + file, err = os.Open(srcFilePath) if err != nil { - return fmt.Errorf("cannot open file %q: %w", path, err) + return fmt.Errorf("cannot open file %q: %w", srcFilePath, err) } defer file.Close() } - _, err = copyFileFunc(path, info, file) + _, err = copyFileFunc(destFilePath, info, file) return err }) if err != nil { @@ -451,8 +451,6 @@ func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.C return fmt.Errorf("cannot open file %q: %w", srcPathCleaned, err) } defer file.Close() - return fmt.Errorf("cannot open file %q: %w", srcPathCleaned, err) - } _, err = copyFileFunc(destPathCleaned, srcFileStat, file) if err != nil { return fmt.Errorf("cannot copy %q to %q: %w", srcUri, destUri, err) From 9457583174b2322bae1272796edecb418b23fb4e Mon Sep 17 00:00:00 2001 From: Evan Simkowitz Date: Wed, 5 Feb 2025 19:03:26 -0800 Subject: [PATCH 6/8] fix relative paths --- pkg/remote/connparse/connparse.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pkg/remote/connparse/connparse.go b/pkg/remote/connparse/connparse.go index 9951065617..b099d1c0a9 100644 --- a/pkg/remote/connparse/connparse.go +++ b/pkg/remote/connparse/connparse.go @@ -128,8 +128,11 @@ func ParseURI(uri string) (*Connection, error) { } } + addPrecedingSlash := true + if scheme == "" { scheme = ConnectionTypeWsh + addPrecedingSlash = false if len(rest) != len(uri) { // This accounts for when the uri starts with "//", which would get trimmed in the first split. parseWshPath() @@ -152,7 +155,7 @@ func ParseURI(uri string) (*Connection, error) { } if strings.HasPrefix(remotePath, "/~") { remotePath = strings.TrimPrefix(remotePath, "/") - } else if len(remotePath) > 1 && !windowsDriveRegex.MatchString(remotePath) && !strings.HasPrefix(remotePath, "/") && !strings.HasPrefix(remotePath, "~") && !strings.HasPrefix(remotePath, "./") && !strings.HasPrefix(remotePath, "../") && !strings.HasPrefix(remotePath, ".\\") && !strings.HasPrefix(remotePath, "..\\") && remotePath != ".." { + } else if addPrecedingSlash && (len(remotePath) > 1 && !windowsDriveRegex.MatchString(remotePath) && !strings.HasPrefix(remotePath, "/") && !strings.HasPrefix(remotePath, "~") && !strings.HasPrefix(remotePath, "./") && !strings.HasPrefix(remotePath, "../") && !strings.HasPrefix(remotePath, ".\\") && !strings.HasPrefix(remotePath, "..\\") && remotePath != "..") { remotePath = "/" + remotePath } } From 523a084fb717cbb80023e1ffe48c760463d09c3c Mon Sep 17 00:00:00 2001 From: Evan Simkowitz Date: Wed, 5 Feb 2025 19:06:51 -0800 Subject: [PATCH 7/8] add current path test --- pkg/remote/connparse/connparse_test.go | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/pkg/remote/connparse/connparse_test.go b/pkg/remote/connparse/connparse_test.go index c530c8e768..97b9aaca9a 100644 --- a/pkg/remote/connparse/connparse_test.go +++ b/pkg/remote/connparse/connparse_test.go @@ -212,6 +212,28 @@ func TestParseURI_WSHCurrentPath(t *testing.T) { if c.GetFullURI() != expected { t.Fatalf("expected full URI to be %q, got %q", expected, c.GetFullURI()) } + + cstr = "path/to/file" + c, err = connparse.ParseURI(cstr) + if err != nil { + t.Fatalf("failed to parse URI: %v", err) + } + expected = "path/to/file" + if c.Path != expected { + t.Fatalf("expected path to be %q, got %q", expected, c.Path) + } + expected = "current" + if c.Host != expected { + t.Fatalf("expected host to be %q, got %q", expected, c.Host) + } + expected = "wsh" + if c.Scheme != expected { + t.Fatalf("expected scheme to be %q, got %q", expected, c.Scheme) + } + expected = "wsh://current/path/to/file" + if c.GetFullURI() != expected { + t.Fatalf("expected full URI to be %q, got %q", expected, c.GetFullURI()) + } } func TestParseURI_WSHCurrentPathWindows(t *testing.T) { From 1df574de4368e07a226c82670f23e688014e1104 Mon Sep 17 00:00:00 2001 From: Evan Simkowitz Date: Wed, 5 Feb 2025 19:08:38 -0800 Subject: [PATCH 8/8] more tests --- pkg/remote/connparse/connparse_test.go | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/pkg/remote/connparse/connparse_test.go b/pkg/remote/connparse/connparse_test.go index 97b9aaca9a..82ccc83625 100644 --- a/pkg/remote/connparse/connparse_test.go +++ b/pkg/remote/connparse/connparse_test.go @@ -234,6 +234,28 @@ func TestParseURI_WSHCurrentPath(t *testing.T) { if c.GetFullURI() != expected { t.Fatalf("expected full URI to be %q, got %q", expected, c.GetFullURI()) } + + cstr = "/etc/path/to/file" + c, err = connparse.ParseURI(cstr) + if err != nil { + t.Fatalf("failed to parse URI: %v", err) + } + expected = "/etc/path/to/file" + if c.Path != expected { + t.Fatalf("expected path to be %q, got %q", expected, c.Path) + } + expected = "current" + if c.Host != expected { + t.Fatalf("expected host to be %q, got %q", expected, c.Host) + } + expected = "wsh" + if c.Scheme != expected { + t.Fatalf("expected scheme to be %q, got %q", expected, c.Scheme) + } + expected = "wsh://current/etc/path/to/file" + if c.GetFullURI() != expected { + t.Fatalf("expected full URI to be %q, got %q", expected, c.GetFullURI()) + } } func TestParseURI_WSHCurrentPathWindows(t *testing.T) {