diff --git a/bundle/generate/downloader.go b/bundle/generate/downloader.go index 1d6ed73ce9..4d1c883b25 100644 --- a/bundle/generate/downloader.go +++ b/bundle/generate/downloader.go @@ -3,22 +3,20 @@ package generate import ( "context" "fmt" - "io" "net/http" - "os" "path" "path/filepath" "strings" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/filer" "github.com/databricks/cli/libs/notebook" "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/client" "github.com/databricks/databricks-sdk-go/service/jobs" "github.com/databricks/databricks-sdk-go/service/pipelines" "github.com/databricks/databricks-sdk-go/service/workspace" "golang.org/x/sync/errgroup" - - "github.com/databricks/databricks-sdk-go/client" ) type exportFile struct { @@ -27,11 +25,12 @@ type exportFile struct { } type Downloader struct { - files map[string]exportFile - w *databricks.WorkspaceClient - sourceDir string - configDir string - basePath string + files map[string]exportFile + w *databricks.WorkspaceClient + sourceDir string + configDir string + basePath string + outputFiler filer.Filer } func (n *Downloader) MarkTaskForDownload(ctx context.Context, task *jobs.Task) error { @@ -194,7 +193,7 @@ func (n *Downloader) relativePath(fullPath string) string { func (n *Downloader) FlushToDisk(ctx context.Context, force bool) error { // First check that all files can be written for targetPath := range n.files { - info, err := os.Stat(targetPath) + info, err := n.outputFiler.Stat(ctx, targetPath) if err == nil { if info.IsDir() { return fmt.Errorf("%s is a directory", targetPath) @@ -207,42 +206,36 @@ func (n *Downloader) FlushToDisk(ctx context.Context, force bool) error { errs, errCtx := errgroup.WithContext(ctx) for targetPath, exportFile := range n.files { - // Create parent directories if they don't exist - dir := filepath.Dir(targetPath) - err := os.MkdirAll(dir, 0o755) - if err != nil { - return err - } errs.Go(func() error { reader, err 
:= n.w.Workspace.Download(errCtx, exportFile.path, workspace.DownloadFormat(exportFile.format)) if err != nil { return err } + defer reader.Close() - file, err := os.Create(targetPath) - if err != nil { - return err + mode := []filer.WriteMode{filer.CreateParentDirectories} + if force { + mode = append(mode, filer.OverwriteIfExists) } - defer file.Close() - - _, err = io.Copy(file, reader) + err = n.outputFiler.Write(errCtx, targetPath, reader, mode...) if err != nil { return err } cmdio.LogString(errCtx, "File successfully saved to "+targetPath) - return reader.Close() + return nil }) } return errs.Wait() } -func NewDownloader(w *databricks.WorkspaceClient, sourceDir, configDir string) *Downloader { +func NewDownloader(w *databricks.WorkspaceClient, sourceDir, configDir string, outputFiler filer.Filer) *Downloader { return &Downloader{ - files: make(map[string]exportFile), - w: w, - sourceDir: sourceDir, - configDir: configDir, + files: make(map[string]exportFile), + w: w, + sourceDir: sourceDir, + configDir: configDir, + outputFiler: outputFiler, } } diff --git a/bundle/generate/downloader_test.go b/bundle/generate/downloader_test.go index d87b373262..d453e210f5 100644 --- a/bundle/generate/downloader_test.go +++ b/bundle/generate/downloader_test.go @@ -5,6 +5,8 @@ import ( "path/filepath" "testing" + "github.com/databricks/cli/libs/fakefs" + "github.com/databricks/cli/libs/filer" "github.com/databricks/databricks-sdk-go/experimental/mocks" "github.com/databricks/databricks-sdk-go/service/workspace" "github.com/stretchr/testify/assert" @@ -18,7 +20,8 @@ func TestDownloader_MarkFileReturnsRelativePath(t *testing.T) { dir := "base/dir/doesnt/matter" sourceDir := filepath.Join(dir, "source") configDir := filepath.Join(dir, "config") - downloader := NewDownloader(m.WorkspaceClient, sourceDir, configDir) + fakeFiler := filer.NewFakeFiler(map[string]fakefs.FileInfo{}) + downloader := NewDownloader(m.WorkspaceClient, sourceDir, configDir, fakeFiler) var err error diff 
--git a/cmd/bundle/generate/app.go b/cmd/bundle/generate/app.go index 323d92835c..122aaccca3 100644 --- a/cmd/bundle/generate/app.go +++ b/cmd/bundle/generate/app.go @@ -10,6 +10,7 @@ import ( "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/dyn" "github.com/databricks/cli/libs/dyn/yamlsaver" + "github.com/databricks/cli/libs/filer" "github.com/databricks/cli/libs/logdiag" "github.com/databricks/cli/libs/textutil" "github.com/databricks/databricks-sdk-go/service/apps" @@ -77,7 +78,22 @@ per target environment.`, return err } - downloader := generate.NewDownloader(w, sourceDir, configDir) + outputFiler, err := filer.NewOutputFiler(ctx, b.BundleRootPath) + if err != nil { + return err + } + + // Make sourceDir and configDir relative to the bundle root + sourceDir, err = makeRelativeToRoot(b.BundleRootPath, sourceDir) + if err != nil { + return err + } + configDir, err = makeRelativeToRoot(b.BundleRootPath, configDir) + if err != nil { + return err + } + + downloader := generate.NewDownloader(w, sourceDir, configDir, outputFiler) sourceCodePath := app.DefaultSourceCodePath // If the source code path is not set, we don't need to download anything. 
@@ -121,7 +137,7 @@ per target environment.`, filename := filepath.Join(configDir, appKey+".app.yml") saver := yamlsaver.NewSaver() - err = saver.SaveAsYAML(result, filename, force) + err = saver.SaveAsYAMLToFiler(ctx, outputFiler, result, filename, force) if err != nil { return err } diff --git a/cmd/bundle/generate/dashboard.go b/cmd/bundle/generate/dashboard.go index b3caa4b679..b674908e35 100644 --- a/cmd/bundle/generate/dashboard.go +++ b/cmd/bundle/generate/dashboard.go @@ -7,7 +7,6 @@ import ( "errors" "fmt" "io" - "os" "path" "path/filepath" "strings" @@ -25,6 +24,7 @@ import ( "github.com/databricks/cli/libs/diag" "github.com/databricks/cli/libs/dyn" "github.com/databricks/cli/libs/dyn/yamlsaver" + "github.com/databricks/cli/libs/filer" "github.com/databricks/cli/libs/logdiag" "github.com/databricks/cli/libs/textutil" "github.com/databricks/databricks-sdk-go" @@ -66,6 +66,9 @@ type dashboard struct { // Output and error streams. out io.Writer err io.Writer + + // Output filer for writing files. + outputFiler filer.Filer } func (d *dashboard) resolveID(ctx context.Context, b *bundle.Bundle) string { @@ -165,7 +168,7 @@ func remarshalJSON(data []byte) ([]byte, error) { return buf.Bytes(), nil } -func (d *dashboard) saveSerializedDashboard(_ context.Context, b *bundle.Bundle, dashboard *dashboards.Dashboard, filename string) error { +func (d *dashboard) saveSerializedDashboard(ctx context.Context, dashboard *dashboards.Dashboard, filename string) error { // Unmarshal and remarshal the serialized dashboard to ensure it is formatted correctly. // The result will have alphabetically sorted keys and be indented. data, err := remarshalJSON([]byte(dashboard.SerializedDashboard)) @@ -173,40 +176,34 @@ func (d *dashboard) saveSerializedDashboard(_ context.Context, b *bundle.Bundle, return err } - // Make sure the output directory exists. 
- if err := os.MkdirAll(filepath.Dir(filename), 0o755); err != nil { - return err - } - // Clean the filename to ensure it is a valid path (and can be used on this OS). filename = filepath.Clean(filename) - // Attempt to make the path relative to the bundle root. - rel, err := filepath.Rel(b.BundleRootPath, filename) - if err != nil { - rel = filename - } - // Verify that the file does not already exist. - info, err := os.Stat(filename) + info, err := d.outputFiler.Stat(ctx, filename) if err == nil { if info.IsDir() { - return fmt.Errorf("%s is a directory", rel) + return fmt.Errorf("%s is a directory", filename) } if !d.force { - return fmt.Errorf("%s already exists. Use --force to overwrite", rel) + return fmt.Errorf("%s already exists. Use --force to overwrite", filename) } } - fmt.Fprintf(d.out, "Writing dashboard to %q\n", rel) - return os.WriteFile(filename, data, 0o644) + fmt.Fprintf(d.out, "Writing dashboard to %q\n", filename) + + mode := []filer.WriteMode{filer.CreateParentDirectories} + if d.force { + mode = append(mode, filer.OverwriteIfExists) + } + return d.outputFiler.Write(ctx, filename, bytes.NewReader(data), mode...) } -func (d *dashboard) saveConfiguration(ctx context.Context, b *bundle.Bundle, dashboard *dashboards.Dashboard, key string) error { +func (d *dashboard) saveConfiguration(ctx context.Context, dashboard *dashboards.Dashboard, key string) error { // Save serialized dashboard definition to the dashboard directory. dashboardBasename := key + ".lvdash.json" dashboardPath := filepath.Join(d.dashboardDir, dashboardBasename) - err := d.saveSerializedDashboard(ctx, b, dashboard, dashboardPath) + err := d.saveSerializedDashboard(ctx, dashboard, dashboardPath) if err != nil { return err } @@ -225,25 +222,14 @@ func (d *dashboard) saveConfiguration(ctx context.Context, b *bundle.Bundle, das }), } - // Make sure the output directory exists. 
- if err := os.MkdirAll(d.resourceDir, 0o755); err != nil { - return err - } - // Save the configuration to the resource directory. resourcePath := filepath.Join(d.resourceDir, key+".dashboard.yml") saver := yamlsaver.NewSaverWithStyle(map[string]yaml.Style{ "display_name": yaml.DoubleQuotedStyle, }) - // Attempt to make the path relative to the bundle root. - rel, err := filepath.Rel(b.BundleRootPath, resourcePath) - if err != nil { - rel = resourcePath - } - - fmt.Fprintf(d.out, "Writing configuration to %q\n", rel) - err = saver.SaveAsYAML(result, resourcePath, d.force) + fmt.Fprintf(d.out, "Writing configuration to %q\n", resourcePath) + err = saver.SaveAsYAMLToFiler(ctx, d.outputFiler, result, resourcePath, d.force) if err != nil { return err } @@ -306,7 +292,7 @@ func (d *dashboard) updateDashboardForResource(ctx context.Context, b *bundle.Bu } if etag != dashboard.Etag { - err = d.saveSerializedDashboard(ctx, b, dashboard, dashboardPath) + err = d.saveSerializedDashboard(ctx, dashboard, dashboardPath) if err != nil { logdiag.LogError(ctx, err) return @@ -338,7 +324,7 @@ func (d *dashboard) generateForExisting(ctx context.Context, b *bundle.Bundle, d } key := textutil.NormalizeString(dashboard.DisplayName) - err = d.saveConfiguration(ctx, b, dashboard, key) + err = d.saveConfiguration(ctx, dashboard, key) if err != nil { logdiag.LogError(ctx, err) } @@ -354,12 +340,18 @@ func (d *dashboard) generateForExisting(ctx context.Context, b *bundle.Bundle, d } func (d *dashboard) initialize(ctx context.Context, b *bundle.Bundle) { - // Make the paths absolute if they aren't already. - if !filepath.IsAbs(d.resourceDir) { - d.resourceDir = filepath.Join(b.BundleRootPath, d.resourceDir) + var err error + + // Make paths relative to the bundle root (required for the filer which is rooted there). 
+ d.resourceDir, err = makeRelativeToRoot(b.BundleRootPath, d.resourceDir) + if err != nil { + logdiag.LogError(ctx, err) + return } - if !filepath.IsAbs(d.dashboardDir) { - d.dashboardDir = filepath.Join(b.BundleRootPath, d.dashboardDir) + d.dashboardDir, err = makeRelativeToRoot(b.BundleRootPath, d.dashboardDir) + if err != nil { + logdiag.LogError(ctx, err) + return } // Make sure we know how the dashboard path is relative to the resource path. @@ -370,6 +362,14 @@ func (d *dashboard) initialize(ctx context.Context, b *bundle.Bundle) { } d.relativeDashboardDir = filepath.ToSlash(rel) + + // Construct output filer for writing files. + outputFiler, err := filer.NewOutputFiler(ctx, b.BundleRootPath) + if err != nil { + logdiag.LogError(ctx, err) + return + } + d.outputFiler = outputFiler } func (d *dashboard) runForResource(ctx context.Context, b *bundle.Bundle) { diff --git a/cmd/bundle/generate/job.go b/cmd/bundle/generate/job.go index fd5436fced..9c43664ff0 100644 --- a/cmd/bundle/generate/job.go +++ b/cmd/bundle/generate/job.go @@ -1,10 +1,12 @@ package generate import ( + "bytes" + "context" "errors" "fmt" + "io" "io/fs" - "os" "path/filepath" "strconv" @@ -14,6 +16,7 @@ import ( "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/dyn" "github.com/databricks/cli/libs/dyn/yamlsaver" + "github.com/databricks/cli/libs/filer" "github.com/databricks/cli/libs/logdiag" "github.com/databricks/cli/libs/textutil" "github.com/databricks/databricks-sdk-go/service/jobs" @@ -21,6 +24,31 @@ import ( "gopkg.in/yaml.v3" ) +// filerRename renames a file using filer operations (read, write, delete). +// This is needed because the filer interface doesn't have a native rename method. 
+func filerRename(ctx context.Context, f filer.Filer, oldPath, newPath string) error { + // Read the old file + r, err := f.Read(ctx, oldPath) + if err != nil { + return err + } + defer r.Close() + + content, err := io.ReadAll(r) + if err != nil { + return err + } + + // Write to new path + err = f.Write(ctx, newPath, bytes.NewReader(content), filer.CreateParentDirectories, filer.OverwriteIfExists) + if err != nil { + return err + } + + // Delete the old file + return f.Delete(ctx, oldPath) +} + func NewGenerateJobCommand() *cobra.Command { var configDir string var sourceDir string @@ -80,7 +108,22 @@ After generation, you can deploy this job to other targets using: return err } - downloader := generate.NewDownloader(w, sourceDir, configDir) + outputFiler, err := filer.NewOutputFiler(ctx, b.BundleRootPath) + if err != nil { + return err + } + + // Make sourceDir and configDir relative to the bundle root + sourceDir, err = makeRelativeToRoot(b.BundleRootPath, sourceDir) + if err != nil { + return err + } + configDir, err = makeRelativeToRoot(b.BundleRootPath, configDir) + if err != nil { + return err + } + + downloader := generate.NewDownloader(w, sourceDir, configDir, outputFiler) // Don't download files if the job is using Git source // When Git source is used, the job will be using the files from the Git repository @@ -129,7 +172,7 @@ After generation, you can deploy this job to other targets using: // User might continuously run generate command to update their bundle jobs with any changes made in Databricks UI. // Due to changing in the generated file names, we need to first rename existing resource file to the new name. // Otherwise users can end up with duplicated resources. - err = os.Rename(oldFilename, filename) + err = filerRename(ctx, outputFiler, oldFilename, filename) if err != nil && !errors.Is(err, fs.ErrNotExist) { return fmt.Errorf("failed to rename file %s. 
DABs uses the resource type as a sub-extension for generated content, please rename it to %s, err: %w", oldFilename, filename, err) } @@ -140,7 +183,7 @@ After generation, you can deploy this job to other targets using: "custom_tags": yaml.DoubleQuotedStyle, "tags": yaml.DoubleQuotedStyle, }) - err = saver.SaveAsYAML(result, filename, force) + err = saver.SaveAsYAMLToFiler(ctx, outputFiler, result, filename, force) if err != nil { return err } diff --git a/cmd/bundle/generate/pipeline.go b/cmd/bundle/generate/pipeline.go index d725d746f8..f3f96da954 100644 --- a/cmd/bundle/generate/pipeline.go +++ b/cmd/bundle/generate/pipeline.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "io/fs" - "os" "path/filepath" "github.com/databricks/cli/bundle/generate" @@ -13,6 +12,7 @@ import ( "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/dyn" "github.com/databricks/cli/libs/dyn/yamlsaver" + "github.com/databricks/cli/libs/filer" "github.com/databricks/cli/libs/logdiag" "github.com/databricks/cli/libs/textutil" "github.com/databricks/databricks-sdk-go/service/pipelines" @@ -79,7 +79,22 @@ like catalogs, schemas, and compute configurations per target.`, return err } - downloader := generate.NewDownloader(w, sourceDir, configDir) + outputFiler, err := filer.NewOutputFiler(ctx, b.BundleRootPath) + if err != nil { + return err + } + + // Make sourceDir and configDir relative to the bundle root + sourceDir, err = makeRelativeToRoot(b.BundleRootPath, sourceDir) + if err != nil { + return err + } + configDir, err = makeRelativeToRoot(b.BundleRootPath, configDir) + if err != nil { + return err + } + + downloader := generate.NewDownloader(w, sourceDir, configDir, outputFiler) for _, lib := range pipeline.Spec.Libraries { err := downloader.MarkPipelineLibraryForDownload(ctx, &lib) if err != nil { @@ -131,7 +146,7 @@ like catalogs, schemas, and compute configurations per target.`, // User might continuously run generate command to update their bundle jobs with any changes 
made in Databricks UI. // Due to changing in the generated file names, we need to first rename existing resource file to the new name. // Otherwise users can end up with duplicated resources. - err = os.Rename(oldFilename, filename) + err = filerRename(ctx, outputFiler, oldFilename, filename) if err != nil && !errors.Is(err, fs.ErrNotExist) { return fmt.Errorf("failed to rename file %s. DABs uses the resource type as a sub-extension for generated content, please rename it to %s, err: %w", oldFilename, filename, err) } @@ -144,7 +159,7 @@ like catalogs, schemas, and compute configurations per target.`, "configuration": yaml.DoubleQuotedStyle, }, ) - err = saver.SaveAsYAML(result, filename, force) + err = saver.SaveAsYAMLToFiler(ctx, outputFiler, result, filename, force) if err != nil { return err } diff --git a/cmd/bundle/generate/utils.go b/cmd/bundle/generate/utils.go new file mode 100644 index 0000000000..7ed21d2ad2 --- /dev/null +++ b/cmd/bundle/generate/utils.go @@ -0,0 +1,18 @@ +package generate + +import ( + "path/filepath" +) + +// makeRelativeToRoot converts a path to be relative to the bundle root. +// If the path is already relative, it is returned as-is. +// If the path is absolute and under the root, it is made relative. +// This is needed because the output filer is rooted at the bundle root, +// and paths must be relative to that root for the filer to write correctly. 
+func makeRelativeToRoot(root, path string) (string, error) { + if !filepath.IsAbs(path) { + return path, nil + } + + return filepath.Rel(root, path) +} diff --git a/libs/dyn/yamlsaver/saver.go b/libs/dyn/yamlsaver/saver.go index 8aaa260377..d758fbe7e6 100644 --- a/libs/dyn/yamlsaver/saver.go +++ b/libs/dyn/yamlsaver/saver.go @@ -1,6 +1,8 @@ package yamlsaver import ( + "bytes" + "context" "fmt" "io" "os" @@ -9,6 +11,7 @@ import ( "strconv" "github.com/databricks/cli/libs/dyn" + "github.com/databricks/cli/libs/filer" "gopkg.in/yaml.v3" ) @@ -56,6 +59,32 @@ func (s *saver) SaveAsYAML(data any, filename string, force bool) error { return nil } +// SaveAsYAMLToFiler saves the data as YAML to the given filename using the provided filer. +func (s *saver) SaveAsYAMLToFiler(ctx context.Context, f filer.Filer, data any, filename string, force bool) error { + // Refuse to overwrite an existing file unless force is set. + info, err := f.Stat(ctx, filename) + if err == nil { + if info.IsDir() { + return fmt.Errorf("%s is a directory", filename) + } + if !force { + return fmt.Errorf("%s already exists. Use --force to overwrite", filename) + } + } + + var buf bytes.Buffer + err = s.encode(data, &buf) + if err != nil { + return err + } + + mode := []filer.WriteMode{filer.CreateParentDirectories} + if force { + mode = append(mode, filer.OverwriteIfExists) + } + return f.Write(ctx, filename, &buf, mode...) +} + func (s *saver) encode(data any, w io.Writer) error { yamlNode, err := s.toYamlNode(dyn.V(data)) if err != nil { diff --git a/libs/filer/output_filer.go b/libs/filer/output_filer.go new file mode 100644 index 0000000000..a7b1681596 --- /dev/null +++ b/libs/filer/output_filer.go @@ -0,0 +1,33 @@ +package filer + +import ( + "context" + "path/filepath" + "strings" + + "github.com/databricks/cli/libs/cmdctx" + "github.com/databricks/cli/libs/dbr" +) + +// NewOutputFiler creates a filer for writing output files.
+// When running on DBR and writing to the workspace filesystem, it uses the +// workspace files extensions client (import/export API) to support writing notebooks. +// Otherwise, it uses the local filesystem client. +// +// It is not possible to write notebooks through the workspace filesystem's FUSE mount for DBR versions less than 16.4. +// This function ensures the correct filer is used based on the runtime environment. +func NewOutputFiler(ctx context.Context, outputDir string) (Filer, error) { + outputDir, err := filepath.Abs(outputDir) + if err != nil { + return nil, err + } + + // If the CLI is running on DBR and we're writing to the workspace file system, + // use the extension-aware workspace filesystem filer. + if strings.HasPrefix(outputDir, "/Workspace/") && dbr.RunsOnRuntime(ctx) { + w := cmdctx.WorkspaceClient(ctx) + return NewWorkspaceFilesExtensionsClient(w, outputDir) + } + + return NewLocalClient(outputDir) +} diff --git a/libs/filer/output_filer_test.go b/libs/filer/output_filer_test.go new file mode 100644 index 0000000000..ed3f4ae9e2 --- /dev/null +++ b/libs/filer/output_filer_test.go @@ -0,0 +1,58 @@ +package filer + +import ( + "context" + "runtime" + "testing" + + "github.com/databricks/cli/libs/cmdctx" + "github.com/databricks/cli/libs/dbr" + "github.com/databricks/databricks-sdk-go" + workspaceConfig "github.com/databricks/databricks-sdk-go/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewOutputFilerLocal(t *testing.T) { + ctx := dbr.MockRuntime(context.Background(), dbr.Environment{IsDbr: false}) + + tmpDir := t.TempDir() + f, err := NewOutputFiler(ctx, tmpDir) + require.NoError(t, err) + + assert.IsType(t, &LocalClient{}, f) +} + +func TestNewOutputFilerLocalForNonWorkspacePath(t *testing.T) { + // This test is not valid on windows because a DBR image is always based on Linux. 
+ if runtime.GOOS == "windows" { + t.Skip("Skipping test on Windows") + } + + // Even on DBR, if path doesn't start with /Workspace/, use local client + ctx := dbr.MockRuntime(context.Background(), dbr.Environment{IsDbr: true, Version: "15.4"}) + + tmpDir := t.TempDir() + f, err := NewOutputFiler(ctx, tmpDir) + require.NoError(t, err) + + assert.IsType(t, &LocalClient{}, f) +} + +func TestNewOutputFilerDBR(t *testing.T) { + // This test is not valid on windows because a DBR image is always based on Linux. + if runtime.GOOS == "windows" { + t.Skip("Skipping test on Windows") + } + + ctx := dbr.MockRuntime(context.Background(), dbr.Environment{IsDbr: true, Version: "15.4"}) + ctx = cmdctx.SetWorkspaceClient(ctx, &databricks.WorkspaceClient{ + Config: &workspaceConfig.Config{Host: "https://myhost.com"}, + }) + + // On DBR with /Workspace/ path, should use workspace files extensions client + f, err := NewOutputFiler(ctx, "/Workspace/Users/test@example.com/my-bundle") + require.NoError(t, err) + + assert.IsType(t, &WorkspaceFilesExtensionsClient{}, f) +} diff --git a/libs/template/resolver_test.go b/libs/template/resolver_test.go index 1dee1c45fe..abb4d30531 100644 --- a/libs/template/resolver_test.go +++ b/libs/template/resolver_test.go @@ -5,23 +5,28 @@ import ( "testing" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/dbr" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func testContext(t *testing.T) context.Context { + return dbr.DetectRuntime(context.Background()) +} + func TestTemplateResolverBothTagAndBranch(t *testing.T) { r := Resolver{ Tag: "tag", Branch: "branch", } - _, err := r.Resolve(context.Background()) + _, err := r.Resolve(testContext(t)) assert.EqualError(t, err, "only one of tag or branch can be specified") } func TestTemplateResolverErrorsWhenPromptingIsNotSupported(t *testing.T) { r := Resolver{} - ctx := cmdio.MockDiscard(context.Background()) + ctx := cmdio.MockDiscard(testContext(t)) _, err 
:= r.Resolve(ctx) assert.EqualError(t, err, "prompting is not supported. Please specify the path, name or URL of the template to use") @@ -38,7 +43,7 @@ func TestTemplateResolverForDefaultTemplates(t *testing.T) { TemplatePathOrUrl: name, } - tmpl, err := r.Resolve(context.Background()) + tmpl, err := r.Resolve(testContext(t)) require.NoError(t, err) assert.Equal(t, &builtinReader{name: name}, tmpl.Reader) @@ -52,7 +57,7 @@ func TestTemplateResolverForDefaultTemplates(t *testing.T) { ConfigFile: "/config/file", } - tmpl, err := r.Resolve(context.Background()) + tmpl, err := r.Resolve(testContext(t)) require.NoError(t, err) // Assert reader and writer configuration @@ -69,7 +74,7 @@ func TestTemplateResolverForCustomUrl(t *testing.T) { ConfigFile: "/config/file", } - tmpl, err := r.Resolve(context.Background()) + tmpl, err := r.Resolve(testContext(t)) require.NoError(t, err) assert.Equal(t, Custom, tmpl.name) @@ -89,7 +94,7 @@ func TestTemplateResolverForCustomPath(t *testing.T) { ConfigFile: "/config/file", } - tmpl, err := r.Resolve(context.Background()) + tmpl, err := r.Resolve(testContext(t)) require.NoError(t, err) assert.Equal(t, Custom, tmpl.name) diff --git a/libs/template/writer.go b/libs/template/writer.go index 37e3fec0e7..c30bbca24d 100644 --- a/libs/template/writer.go +++ b/libs/template/writer.go @@ -2,14 +2,10 @@ package template import ( "context" - "path/filepath" "sort" "strconv" - "strings" - "github.com/databricks/cli/libs/cmdctx" "github.com/databricks/cli/libs/cmdio" - "github.com/databricks/cli/libs/dbr" "github.com/databricks/cli/libs/filer" "github.com/databricks/cli/libs/jsonschema" "github.com/databricks/cli/libs/telemetry" @@ -46,30 +42,10 @@ type defaultWriter struct { renderer *renderer } -func constructOutputFiler(ctx context.Context, outputDir string) (filer.Filer, error) { - outputDir, err := filepath.Abs(outputDir) - if err != nil { - return nil, err - } - - // If the CLI is running on DBR and we're writing to the workspace file 
system, - // use the extension-aware workspace filesystem filer to instantiate the template. - // - // It is not possible to write notebooks through the workspace filesystem's FUSE mount. - // Therefore this is the only way we can initialize templates that contain notebooks - // when running the CLI on DBR and initializing a template to the workspace. - // - if strings.HasPrefix(outputDir, "/Workspace/") && dbr.RunsOnRuntime(ctx) { - return filer.NewWorkspaceFilesExtensionsClient(cmdctx.WorkspaceClient(ctx), outputDir) - } - - return filer.NewLocalClient(outputDir) -} - func (tmpl *defaultWriter) Configure(ctx context.Context, configPath, outputDir string) error { tmpl.configPath = configPath - outputFiler, err := constructOutputFiler(ctx, outputDir) + outputFiler, err := filer.NewOutputFiler(ctx, outputDir) if err != nil { return err } diff --git a/libs/template/writer_test.go b/libs/template/writer_test.go index 8b440f34f8..53996acd97 100644 --- a/libs/template/writer_test.go +++ b/libs/template/writer_test.go @@ -15,9 +15,11 @@ import ( ) func TestDefaultWriterConfigure(t *testing.T) { + ctx := dbr.DetectRuntime(context.Background()) + // Test on local file system. w := &defaultWriter{} - err := w.Configure(context.Background(), "/foo/bar", "/out/abc") + err := w.Configure(ctx, "/foo/bar", "/out/abc") assert.NoError(t, err) assert.Equal(t, "/foo/bar", w.configPath) @@ -46,7 +48,7 @@ func TestDefaultWriterConfigureOnDBR(t *testing.T) { func TestMaterializeForNonTemplateDirectory(t *testing.T) { tmpDir1 := t.TempDir() tmpDir2 := t.TempDir() - ctx := context.Background() + ctx := dbr.DetectRuntime(context.Background()) w := &defaultWriter{} err := w.Configure(ctx, "/foo/bar", tmpDir1)