diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 491e5b29dc..57ffb0962b 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -6,6 +6,9 @@ ### CLI +* Improve performance of `databricks fs cp` command by parallelizing file uploads when + copying directories with the `--recursive` flag. + ### Bundles * engine/direct: Fix dependency-ordered deletion by persisting depends_on in state ([#4105](https://github.com/databricks/cli/pull/4105)) * Pass SYSTEM_ACCESSTOKEN from env to the Terraform provider ([#4135](https://github.com/databricks/cli/pull/4135) diff --git a/acceptance/cmd/fs/cp/dir-to-dir/out.test.toml b/acceptance/cmd/fs/cp/dir-to-dir/out.test.toml new file mode 100644 index 0000000000..d560f1de04 --- /dev/null +++ b/acceptance/cmd/fs/cp/dir-to-dir/out.test.toml @@ -0,0 +1,5 @@ +Local = true +Cloud = false + +[EnvMatrix] + DATABRICKS_BUNDLE_ENGINE = ["terraform", "direct"] diff --git a/acceptance/cmd/fs/cp/dir-to-dir/output.txt b/acceptance/cmd/fs/cp/dir-to-dir/output.txt new file mode 100644 index 0000000000..9b708fb9f0 --- /dev/null +++ b/acceptance/cmd/fs/cp/dir-to-dir/output.txt @@ -0,0 +1,2 @@ +localdir/file1.txt -> dbfs:/Volumes/main/default/data/uploaded-dir/file1.txt +localdir/file2.txt -> dbfs:/Volumes/main/default/data/uploaded-dir/file2.txt diff --git a/acceptance/cmd/fs/cp/dir-to-dir/script b/acceptance/cmd/fs/cp/dir-to-dir/script new file mode 100644 index 0000000000..6f0b8510b0 --- /dev/null +++ b/acceptance/cmd/fs/cp/dir-to-dir/script @@ -0,0 +1,6 @@ +mkdir -p localdir +echo -n "file1 content" > localdir/file1.txt +echo -n "file2 content" > localdir/file2.txt + +# Recursive directory copy (output sorted for deterministic ordering). +$CLI fs cp -r localdir dbfs:/Volumes/main/default/data/uploaded-dir 2>&1 | sort diff --git a/acceptance/cmd/fs/cp/dir-to-dir/test.toml b/acceptance/cmd/fs/cp/dir-to-dir/test.toml new file mode 100644 index 0000000000..10da309c5b --- /dev/null +++ b/acceptance/cmd/fs/cp/dir-to-dir/test.toml @@ -0,0 +1,20 @@ +Local = true +Cloud = false +Ignore = ["localdir"] + +# Recursive copy: localdir/ -> uploaded-dir/. +[[Server]] +Pattern = "PUT /api/2.0/fs/directories/Volumes/main/default/data/uploaded-dir" +Response.StatusCode = 200 + +[[Server]] +Pattern = "HEAD /api/2.0/fs/directories/Volumes/main/default/data/uploaded-dir" +Response.StatusCode = 200 + +[[Server]] +Pattern = "PUT /api/2.0/fs/files/Volumes/main/default/data/uploaded-dir/file1.txt" +Response.StatusCode = 200 + +[[Server]] +Pattern = "PUT /api/2.0/fs/files/Volumes/main/default/data/uploaded-dir/file2.txt" +Response.StatusCode = 200 diff --git a/acceptance/cmd/fs/cp/file-to-dir/out.test.toml b/acceptance/cmd/fs/cp/file-to-dir/out.test.toml new file mode 100644 index 0000000000..d560f1de04 --- /dev/null +++ b/acceptance/cmd/fs/cp/file-to-dir/out.test.toml @@ -0,0 +1,5 @@ +Local = true +Cloud = false + +[EnvMatrix] + DATABRICKS_BUNDLE_ENGINE = ["terraform", "direct"] diff --git a/acceptance/cmd/fs/cp/file-to-dir/output.txt b/acceptance/cmd/fs/cp/file-to-dir/output.txt new file mode 100644 index 0000000000..1bbd3d2cad --- /dev/null +++ b/acceptance/cmd/fs/cp/file-to-dir/output.txt @@ -0,0 +1,3 @@ + +>>> [CLI] fs cp local.txt dbfs:/Volumes/main/default/data/mydir/ +local.txt -> dbfs:/Volumes/main/default/data/mydir/local.txt diff --git a/acceptance/cmd/fs/cp/file-to-dir/script b/acceptance/cmd/fs/cp/file-to-dir/script new file mode 100644 index 0000000000..d21baf28bd --- /dev/null +++ b/acceptance/cmd/fs/cp/file-to-dir/script @@ -0,0 +1,4 @@ +echo -n "hello world!" 
> local.txt + +# Copy file into a directory (trailing slash indicates directory target). +trace $CLI fs cp local.txt dbfs:/Volumes/main/default/data/mydir/ diff --git a/acceptance/cmd/fs/cp/file-to-dir/test.toml b/acceptance/cmd/fs/cp/file-to-dir/test.toml new file mode 100644 index 0000000000..d8c7892808 --- /dev/null +++ b/acceptance/cmd/fs/cp/file-to-dir/test.toml @@ -0,0 +1,12 @@ +Local = true +Cloud = false +Ignore = ["local.txt"] + +# Copy file into existing directory: local.txt -> mydir/local.txt. +[[Server]] +Pattern = "HEAD /api/2.0/fs/directories/Volumes/main/default/data/mydir" +Response.StatusCode = 200 + +[[Server]] +Pattern = "PUT /api/2.0/fs/files/Volumes/main/default/data/mydir/local.txt" +Response.StatusCode = 200 diff --git a/acceptance/cmd/fs/cp/file-to-file/out.test.toml b/acceptance/cmd/fs/cp/file-to-file/out.test.toml new file mode 100644 index 0000000000..d560f1de04 --- /dev/null +++ b/acceptance/cmd/fs/cp/file-to-file/out.test.toml @@ -0,0 +1,5 @@ +Local = true +Cloud = false + +[EnvMatrix] + DATABRICKS_BUNDLE_ENGINE = ["terraform", "direct"] diff --git a/acceptance/cmd/fs/cp/file-to-file/output.txt b/acceptance/cmd/fs/cp/file-to-file/output.txt new file mode 100644 index 0000000000..93b9425293 --- /dev/null +++ b/acceptance/cmd/fs/cp/file-to-file/output.txt @@ -0,0 +1,7 @@ + +>>> [CLI] fs cp local.txt dbfs:/Volumes/main/default/data/uploaded.txt +local.txt -> dbfs:/Volumes/main/default/data/uploaded.txt + +>>> [CLI] fs cp dbfs:/Volumes/main/default/data/remote.txt downloaded.txt +dbfs:/Volumes/main/default/data/remote.txt -> downloaded.txt +content from volume \ No newline at end of file diff --git a/acceptance/cmd/fs/cp/file-to-file/script b/acceptance/cmd/fs/cp/file-to-file/script new file mode 100644 index 0000000000..510ec04618 --- /dev/null +++ b/acceptance/cmd/fs/cp/file-to-file/script @@ -0,0 +1,9 @@ +echo -n "hello world!" > local.txt + +# Upload local file to volume. +trace $CLI fs cp local.txt dbfs:/Volumes/main/default/data/uploaded.txt + +# Download file from volume to local. +trace $CLI fs cp dbfs:/Volumes/main/default/data/remote.txt downloaded.txt + +cat downloaded.txt diff --git a/acceptance/cmd/fs/cp/file-to-file/test.toml b/acceptance/cmd/fs/cp/file-to-file/test.toml new file mode 100644 index 0000000000..e083bbfdbc --- /dev/null +++ b/acceptance/cmd/fs/cp/file-to-file/test.toml @@ -0,0 +1,34 @@ +Local = true +Cloud = false +Ignore = ["local.txt", "downloaded.txt"] + +# Upload: local.txt -> dbfs:/Volumes/.../uploaded.txt. +[[Server]] +Pattern = "HEAD /api/2.0/fs/directories/Volumes/main/default/data/uploaded.txt" +Response.StatusCode = 404 + +[[Server]] +Pattern = "HEAD /api/2.0/fs/files/Volumes/main/default/data/uploaded.txt" +Response.StatusCode = 404 + +[[Server]] +Pattern = "HEAD /api/2.0/fs/directories/Volumes/main/default/data" +Response.StatusCode = 200 + +[[Server]] +Pattern = "PUT /api/2.0/fs/files/Volumes/main/default/data/uploaded.txt" +Response.StatusCode = 200 + +# Download: dbfs:/Volumes/.../remote.txt -> downloaded.txt. 
+[[Server]] +Pattern = "HEAD /api/2.0/fs/directories/Volumes/main/default/data/remote.txt" +Response.StatusCode = 404 + +[[Server]] +Pattern = "HEAD /api/2.0/fs/files/Volumes/main/default/data/remote.txt" +Response.StatusCode = 200 + +[[Server]] +Pattern = "GET /api/2.0/fs/files/Volumes/main/default/data/remote.txt" +Response.StatusCode = 200 +Response.Body = "content from volume" diff --git a/acceptance/cmd/fs/cp/input-validation/out.test.toml b/acceptance/cmd/fs/cp/input-validation/out.test.toml new file mode 100644 index 0000000000..d560f1de04 --- /dev/null +++ b/acceptance/cmd/fs/cp/input-validation/out.test.toml @@ -0,0 +1,5 @@ +Local = true +Cloud = false + +[EnvMatrix] + DATABRICKS_BUNDLE_ENGINE = ["terraform", "direct"] diff --git a/acceptance/cmd/fs/cp/input-validation/output.txt b/acceptance/cmd/fs/cp/input-validation/output.txt new file mode 100644 index 0000000000..febe55b74e --- /dev/null +++ b/acceptance/cmd/fs/cp/input-validation/output.txt @@ -0,0 +1,10 @@ + +>>> errcode [CLI] fs cp src dst --concurrency -1 +Error: --concurrency must be at least 1 + +Exit code: 1 + +>>> errcode [CLI] fs cp src dst --concurrency 0 +Error: --concurrency must be at least 1 + +Exit code: 1 diff --git a/acceptance/cmd/fs/cp/input-validation/script b/acceptance/cmd/fs/cp/input-validation/script new file mode 100644 index 0000000000..a5e8cec862 --- /dev/null +++ b/acceptance/cmd/fs/cp/input-validation/script @@ -0,0 +1,3 @@ +# Invalid concurrency values should fail. +trace errcode $CLI fs cp src dst --concurrency -1 +trace errcode $CLI fs cp src dst --concurrency 0 diff --git a/acceptance/cmd/fs/cp/input-validation/test.toml b/acceptance/cmd/fs/cp/input-validation/test.toml new file mode 100644 index 0000000000..7d36fb9dc1 --- /dev/null +++ b/acceptance/cmd/fs/cp/input-validation/test.toml @@ -0,0 +1,2 @@ +Local = true +Cloud = false diff --git a/cmd/fs/cp.go b/cmd/fs/cp.go index 275620d7c3..d7ae651753 100644 --- a/cmd/fs/cp.go +++ b/cmd/fs/cp.go @@ -9,94 +9,141 @@ import ( "path" "path/filepath" "strings" + "sync" "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/filer" "github.com/spf13/cobra" + "golang.org/x/sync/errgroup" ) +// Default number of concurrent file copy operations. This is a conservative +// default that should be sufficient to fully utilize the available bandwidth +// in most cases. +const defaultConcurrency = 8 + +// errInvalidConcurrency is returned when the value of the concurrency +// flag is invalid. +var errInvalidConcurrency = errors.New("--concurrency must be at least 1") + type copy struct { - overwrite bool - recursive bool + overwrite bool + recursive bool + concurrency int - ctx context.Context sourceFiler filer.Filer targetFiler filer.Filer sourceScheme string targetScheme string + + mu sync.Mutex // protect output from concurrent writes } -func (c *copy) cpWriteCallback(sourceDir, targetDir string) fs.WalkDirFunc { - return func(sourcePath string, d fs.DirEntry, err error) error { +// cpDirToDir recursively copies the content of a directory to another +// directory. +// +// There is no guarantee on the order in which the files are copied. +// +// The method does not take care of retrying on error; this is considered to +// be the responsibility of the Filer implementation. If a file copy fails, +// the error is returned and the other copies are cancelled. +func (c *copy) cpDirToDir(ctx context.Context, sourceDir, targetDir string) error { + if !c.recursive { + return fmt.Errorf("source path %s is a directory. 
Please specify the --recursive flag", sourceDir) + } + + // Create a cancellable context purely for the purpose of having a way to + // cancel the goroutines in case of error walking the directory. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + // Pool of workers to process copy operations in parallel. The created + // context is the real context for this operation. It is shared by the + // walking function and the goroutines and can be cancelled manually + // by calling the cancel() function of its parent context. + g, ctx := errgroup.WithContext(ctx) + g.SetLimit(c.concurrency) + + // Walk the source directory, queueing file copy operations for processing. + sourceFs := filer.NewFS(ctx, c.sourceFiler) + err := fs.WalkDir(sourceFs, sourceDir, func(sourcePath string, d fs.DirEntry, err error) error { if err != nil { return err } - // Compute path relative to the target directory + // Compute path relative to the source directory. relPath, err := filepath.Rel(sourceDir, sourcePath) if err != nil { return err } relPath = filepath.ToSlash(relPath) - // Compute target path for the file + // Compute target path for the file. targetPath := path.Join(targetDir, relPath) - // create directory and return early + // Create the directory synchronously. This must happen before files + // are copied into it, and WalkDir guarantees directories are visited + // before their contents. if d.IsDir() { - return c.targetFiler.Mkdir(c.ctx, targetPath) + return c.targetFiler.Mkdir(ctx, targetPath) } - return c.cpFileToFile(sourcePath, targetPath) - } -} - -func (c *copy) cpDirToDir(sourceDir, targetDir string) error { - if !c.recursive { - return fmt.Errorf("source path %s is a directory. Please specify the --recursive flag", sourceDir) + g.Go(func() error { + // Goroutines are queued and may start after the context is already + // cancelled (e.g. a prior copy failed). This check aims to avoid + // starting work that will inevitably fail. 
+ if ctx.Err() != nil { + return ctx.Err() + } + return c.cpFileToFile(ctx, sourcePath, targetPath) + }) + return nil + }) + if err != nil { + cancel() // cancel the goroutines + _ = g.Wait() // wait for the goroutines to finish + return err // return the "real" error that led to cancellation } - - sourceFs := filer.NewFS(c.ctx, c.sourceFiler) - return fs.WalkDir(sourceFs, sourceDir, c.cpWriteCallback(sourceDir, targetDir)) + return g.Wait() } -func (c *copy) cpFileToDir(sourcePath, targetDir string) error { +func (c *copy) cpFileToDir(ctx context.Context, sourcePath, targetDir string) error { fileName := filepath.Base(sourcePath) targetPath := path.Join(targetDir, fileName) - return c.cpFileToFile(sourcePath, targetPath) + return c.cpFileToFile(ctx, sourcePath, targetPath) } -func (c *copy) cpFileToFile(sourcePath, targetPath string) error { +func (c *copy) cpFileToFile(ctx context.Context, sourcePath, targetPath string) error { // Get reader for file at source path - r, err := c.sourceFiler.Read(c.ctx, sourcePath) + r, err := c.sourceFiler.Read(ctx, sourcePath) if err != nil { return err } defer r.Close() if c.overwrite { - err = c.targetFiler.Write(c.ctx, targetPath, r, filer.OverwriteIfExists) + err = c.targetFiler.Write(ctx, targetPath, r, filer.OverwriteIfExists) if err != nil { return err } } else { - err = c.targetFiler.Write(c.ctx, targetPath, r) + err = c.targetFiler.Write(ctx, targetPath, r) // skip if file already exists if err != nil && errors.Is(err, fs.ErrExist) { - return c.emitFileSkippedEvent(sourcePath, targetPath) + return c.emitFileSkippedEvent(ctx, sourcePath, targetPath) } if err != nil { return err } } - return c.emitFileCopiedEvent(sourcePath, targetPath) + return c.emitFileCopiedEvent(ctx, sourcePath, targetPath) } // TODO: emit these events on stderr // TODO: add integration tests for these events -func (c *copy) emitFileSkippedEvent(sourcePath, targetPath string) error { +func (c *copy) emitFileSkippedEvent(ctx context.Context, sourcePath, targetPath string) error { fullSourcePath := sourcePath if c.sourceScheme != "" { fullSourcePath = path.Join(c.sourceScheme+":", sourcePath) @@ -109,10 +156,12 @@ func (c *copy) emitFileSkippedEvent(sourcePath, targetPath string) error { event := newFileSkippedEvent(fullSourcePath, fullTargetPath) template := "{{.SourcePath}} -> {{.TargetPath}} (skipped; already exists)\n" - return cmdio.RenderWithTemplate(c.ctx, event, "", template) + c.mu.Lock() + defer c.mu.Unlock() + return cmdio.RenderWithTemplate(ctx, event, "", template) } -func (c *copy) emitFileCopiedEvent(sourcePath, targetPath string) error { +func (c *copy) emitFileCopiedEvent(ctx context.Context, sourcePath, targetPath string) error { fullSourcePath := sourcePath if c.sourceScheme != "" { fullSourcePath = path.Join(c.sourceScheme+":", sourcePath) @@ -125,7 +174,9 @@ func (c *copy) emitFileCopiedEvent(sourcePath, targetPath string) error { event := newFileCopiedEvent(fullSourcePath, fullTargetPath) template := "{{.SourcePath}} -> {{.TargetPath}}\n" - return cmdio.RenderWithTemplate(c.ctx, event, "", template) + c.mu.Lock() + defer c.mu.Unlock() + return cmdio.RenderWithTemplate(ctx, event, "", template) } // hasTrailingDirSeparator checks if a path ends with a directory separator. @@ -153,13 +204,20 @@ func newCpCommand() *cobra.Command { When copying a file, if TARGET_PATH is a directory, the file will be created inside the directory, otherwise the file is created at TARGET_PATH. 
`, - Args: root.ExactArgs(2), - PreRunE: root.MustWorkspaceClient, + Args: root.ExactArgs(2), } var c copy cmd.Flags().BoolVar(&c.overwrite, "overwrite", false, "overwrite existing files") cmd.Flags().BoolVarP(&c.recursive, "recursive", "r", false, "recursively copy files from directory") + cmd.Flags().IntVar(&c.concurrency, "concurrency", defaultConcurrency, "number of parallel copy operations") + + cmd.PreRunE = func(cmd *cobra.Command, args []string) error { + if c.concurrency <= 0 { + return errInvalidConcurrency + } + return root.MustWorkspaceClient(cmd, args) + } cmd.RunE = func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() @@ -187,7 +245,6 @@ func newCpCommand() *cobra.Command { c.targetScheme = "dbfs" } - c.ctx = ctx c.sourceFiler = sourceFiler c.targetFiler = targetFiler @@ -199,7 +256,7 @@ func newCpCommand() *cobra.Command { // case 1: source path is a directory, then recursively create files at target path if sourceInfo.IsDir() { - return c.cpDirToDir(sourcePath, targetPath) + return c.cpDirToDir(ctx, sourcePath, targetPath) } // If target path has a trailing separator, trim it and let case 2 handle it @@ -210,11 +267,11 @@ func newCpCommand() *cobra.Command { // case 2: source path is a file, and target path is a directory. In this case // we copy the file to inside the directory if targetInfo, err := targetFiler.Stat(ctx, targetPath); err == nil && targetInfo.IsDir() { - return c.cpFileToDir(sourcePath, targetPath) + return c.cpFileToDir(ctx, sourcePath, targetPath) } // case 3: source path is a file, and target path is a file - return c.cpFileToFile(sourcePath, targetPath) + return c.cpFileToFile(ctx, sourcePath, targetPath) } v := newValidArgs() diff --git a/cmd/fs/cp_test.go b/cmd/fs/cp_test.go new file mode 100644 index 0000000000..b50ee0658f --- /dev/null +++ b/cmd/fs/cp_test.go @@ -0,0 +1,197 @@ +package fs + +import ( + "context" + "errors" + "io" + "io/fs" + "strings" + "testing" + "time" + + "github.com/databricks/cli/libs/filer" +) + +// mockFiler mocks filer.Filer. +type mockFiler struct { + write func(ctx context.Context, path string, r io.Reader, mode ...filer.WriteMode) error + read func(ctx context.Context, path string) (io.ReadCloser, error) + delete func(ctx context.Context, path string, mode ...filer.DeleteMode) error + readDir func(ctx context.Context, path string) ([]fs.DirEntry, error) + mkdir func(ctx context.Context, path string) error + stat func(ctx context.Context, path string) (fs.FileInfo, error) +} + +func (m *mockFiler) Write(ctx context.Context, path string, r io.Reader, mode ...filer.WriteMode) error { + if m.write == nil { + return nil + } + return m.write(ctx, path, r, mode...) +} + +func (m *mockFiler) Read(ctx context.Context, path string) (io.ReadCloser, error) { + if m.read == nil { + return nil, nil + } + return m.read(ctx, path) +} + +func (m *mockFiler) Delete(ctx context.Context, path string, mode ...filer.DeleteMode) error { + if m.delete == nil { + return nil + } + return m.delete(ctx, path, mode...) +} + +func (m *mockFiler) ReadDir(ctx context.Context, path string) ([]fs.DirEntry, error) { + if m.readDir == nil { + return nil, nil + } + return m.readDir(ctx, path) +} + +func (m *mockFiler) Mkdir(ctx context.Context, path string) error { + if m.mkdir == nil { + return nil + } + return m.mkdir(ctx, path) +} + +func (m *mockFiler) Stat(ctx context.Context, path string) (fs.FileInfo, error) { + if m.stat == nil { + return nil, nil + } + return m.stat(ctx, path) +} + +// mockFileInfo mocks fs.FileInfo. 
+type mockFileInfo struct { + name string + isDir bool +} + +func (m mockFileInfo) Name() string { return m.name } +func (m mockFileInfo) Size() int64 { return 0 } +func (m mockFileInfo) Mode() fs.FileMode { return 0o644 } +func (m mockFileInfo) ModTime() time.Time { return time.Time{} } +func (m mockFileInfo) IsDir() bool { return m.isDir } +func (m mockFileInfo) Sys() any { return nil } + +// mockDirEntry mocks fs.DirEntry. +type mockDirEntry struct { + name string + isDir bool +} + +func (m mockDirEntry) Name() string { return m.name } +func (m mockDirEntry) IsDir() bool { return m.isDir } +func (m mockDirEntry) Type() fs.FileMode { return 0 } +func (m mockDirEntry) Info() (fs.FileInfo, error) { + return mockFileInfo(m), nil +} + +func TestCp_cpDirToDir_contextCancellation(t *testing.T) { + testError := errors.New("test error") + + // Mock the stats and readDir methods for a Filer over a file system that + // has the following directory structure: + // + // src/ + // ├── subdir/ + // ├── file1.txt + // ├── file2.txt + // └── file3.txt + // + mockSourceStat := func(ctx context.Context, path string) (fs.FileInfo, error) { + isDir := path == "src" || path == "src/subdir" + return mockFileInfo{name: path, isDir: isDir}, nil + } + mockSourceReadDir := func(ctx context.Context, path string) ([]fs.DirEntry, error) { + if path == "src" { + return []fs.DirEntry{ + mockDirEntry{name: "subdir", isDir: true}, + mockDirEntry{name: "file1.txt", isDir: false}, + mockDirEntry{name: "file2.txt", isDir: false}, + mockDirEntry{name: "file3.txt", isDir: false}, + }, nil + } + return nil, nil + } + + testCases := []struct { + desc string + c *copy + wantErr error + }{ + { + // The source filer's Read method blocks until context is cancelled, + // simulating a slow file copy operation. The target filer's Mkdir + // method returns an error which should cancel the walk and all file + // copy goroutines. + desc: "cancel go routines on walk error", + c: ©{ + recursive: true, + concurrency: 5, + sourceFiler: &mockFiler{ + stat: mockSourceStat, + readDir: mockSourceReadDir, + read: func(ctx context.Context, path string) (io.ReadCloser, error) { + <-ctx.Done() // block until context is cancelled + return nil, ctx.Err() + }, + }, + targetFiler: &mockFiler{ + mkdir: func(ctx context.Context, path string) error { + return testError + }, + }, + }, + wantErr: testError, + }, + { + // The target filer's Write method returns an error when writing the + // file1.txt file. This error is expected to be returned by the file copy + // goroutine and all other file copy goroutines should be cancelled. 
+ desc: "cancel go routines on file copy error", + c: ©{ + recursive: true, + concurrency: 5, + sourceFiler: &mockFiler{ + stat: mockSourceStat, + readDir: mockSourceReadDir, + read: func(ctx context.Context, path string) (io.ReadCloser, error) { + return io.NopCloser(strings.NewReader("content")), nil + }, + }, + targetFiler: &mockFiler{ + write: func(ctx context.Context, path string, r io.Reader, mode ...filer.WriteMode) error { + if path == "dst/file1.txt" { + return testError + } + <-ctx.Done() // block until context is cancelled + return ctx.Err() + }, + }, + }, + wantErr: testError, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + done := make(chan error, 1) + go func() { + done <- tc.c.cpDirToDir(t.Context(), "src", "dst") + }() + + select { + case gotErr := <-done: + if !errors.Is(gotErr, tc.wantErr) { + t.Errorf("want error %v, got %v", tc.wantErr, gotErr) + } + case <-time.After(3 * time.Second): // do not wait too long in case of test issues + t.Fatal("cpDirToDir blocked instead of returning error immediately") + } + }) + } +}