diff --git a/cmd/auth/login.go b/cmd/auth/login.go index cd4d81ad25..c4d0851011 100644 --- a/cmd/auth/login.go +++ b/cmd/auth/login.go @@ -146,7 +146,11 @@ a new profile is created. ctx := cmd.Context() profileName := cmd.Flag("profile").Value.String() - tokenCache, mode, err := storage.ResolveCache(ctx, "") + // Resolve the cache before the browser step so a missing/locked keyring + // surfaces here rather than after the user completes OAuth. When secure + // is selected but the keyring is unreachable, this silently falls back + // to plaintext and persists auth_storage = plaintext for next time. + tokenCache, mode, err := storage.ResolveCacheForLogin(ctx, "") if err != nil { return err } diff --git a/libs/auth/storage/cache.go b/libs/auth/storage/cache.go index 151646081d..2801c8a6b8 100644 --- a/libs/auth/storage/cache.go +++ b/libs/auth/storage/cache.go @@ -4,6 +4,9 @@ import ( "context" "fmt" + "github.com/databricks/cli/libs/databrickscfg" + "github.com/databricks/cli/libs/env" + "github.com/databricks/cli/libs/log" "github.com/databricks/databricks-sdk-go/credentials/u2m" "github.com/databricks/databricks-sdk-go/credentials/u2m/cache" ) @@ -12,15 +15,17 @@ import ( // so unit tests can inject stubs without hitting the real OS keyring or // filesystem. Production code uses defaultCacheFactories(). type cacheFactories struct { - newFile func(context.Context) (cache.TokenCache, error) - newKeyring func() cache.TokenCache + newFile func(context.Context) (cache.TokenCache, error) + newKeyring func() cache.TokenCache + probeKeyring func() error } // defaultCacheFactories returns the production factory set. func defaultCacheFactories() cacheFactories { return cacheFactories{ - newFile: func(ctx context.Context) (cache.TokenCache, error) { return NewFileTokenCache(ctx) }, - newKeyring: NewKeyringCache, + newFile: func(ctx context.Context) (cache.TokenCache, error) { return NewFileTokenCache(ctx) }, + newKeyring: NewKeyringCache, + probeKeyring: ProbeKeyring, } } @@ -38,6 +43,24 @@ func ResolveCache(ctx context.Context, override StorageMode) (cache.TokenCache, return resolveCacheWith(ctx, override, defaultCacheFactories()) } +// ResolveCacheForLogin resolves the cache like ResolveCache with extra rules +// for the auth login path: +// +// 1. When the resolved mode is secure and the user did not explicitly ask for +// it (no override flag, no env var, no config), and the OS keyring is +// unreachable, fall back silently to plaintext and persist +// auth_storage = plaintext so subsequent commands skip the probe. +// 2. When the user explicitly asked for secure (override, env var, or config) +// but the keyring is unreachable, return an error. An explicit "I want +// secure" is honored strictly: never silently downgrade. +// +// Login-specific. Read paths (auth token, bundle commands) keep the original +// keyring error so they don't silently mint plaintext copies of tokens that +// were stored in the keyring on another machine. +func ResolveCacheForLogin(ctx context.Context, override StorageMode) (cache.TokenCache, StorageMode, error) { + return resolveCacheForLoginWith(ctx, override, defaultCacheFactories()) +} + // WrapForOAuthArgument wraps tokenCache so SDK-side writes (Challenge, refresh) // dual-write to the legacy host-based cache key when mode is plaintext. Other // modes return tokenCache unchanged: secure mode never writes a host-key entry, @@ -73,3 +96,69 @@ func resolveCacheWith(ctx context.Context, override StorageMode, f cacheFactorie return nil, "", fmt.Errorf("unsupported storage mode %q", string(mode)) } } + +// resolveCacheForLoginWith is the pure form of ResolveCacheForLogin. It takes +// the factory set as a parameter so tests can inject stubs. +func resolveCacheForLoginWith(ctx context.Context, override StorageMode, f cacheFactories) (cache.TokenCache, StorageMode, error) { + mode, explicit, err := ResolveStorageModeWithSource(ctx, override) + if err != nil { + return nil, "", err + } + return applyLoginFallback(ctx, mode, explicit, f) +} + +// applyLoginFallback realizes the login-time fallback rules given an already- +// resolved mode and whether the user explicitly asked for it. Split out so +// tests can drive the (mode, explicit) input space directly without depending +// on whatever the resolver's default mode happens to be at any point in time. +func applyLoginFallback(ctx context.Context, mode StorageMode, explicit bool, f cacheFactories) (cache.TokenCache, StorageMode, error) { + if mode != StorageModeSecure { + switch mode { + case StorageModePlaintext: + c, err := f.newFile(ctx) + if err != nil { + return nil, "", fmt.Errorf("open file token cache: %w", err) + } + return c, mode, nil + default: + return nil, "", fmt.Errorf("unsupported storage mode %q", string(mode)) + } + } + if probeErr := f.probeKeyring(); probeErr != nil { + if explicit { + return nil, "", fmt.Errorf("secure storage was requested but the OS keyring is not reachable: %w", probeErr) + } + log.Debugf(ctx, "secure storage unavailable (%v), falling back to plaintext", probeErr) + fileCache, fileErr := f.newFile(ctx) + if fileErr != nil { + return nil, "", fmt.Errorf("open file token cache: %w", fileErr) + } + if err := persistPlaintextFallback(ctx); err != nil { + log.Debugf(ctx, "persisting auth_storage fallback failed: %v", err) + } + return fileCache, StorageModePlaintext, nil + } + return f.newKeyring(), StorageModeSecure, nil +} + +// persistPlaintextFallback writes auth_storage = plaintext to [__settings__] +// in .databrickscfg so subsequent commands skip the (slow/blocking) keyring +// probe and route straight to the file cache. +// +// We deliberately persist only on the default-mode + probe-fail path, never +// on the success paths: +// - default + probe ok: writing the runtime mode would lock the current +// default into the user's config and prevent a future change to the +// default from reaching them. +// - explicit secure (override, env, config): the value is already set +// somewhere by definition, so a write would be redundant. +// +// The fallback is the only path where persisting changes future behavior. +// It also pins these users to plaintext explicitly, so any future changes to +// this logic don't accidentally disrupt them: they're already using plaintext +// implicitly (the keyring is unreachable), and the persisted setting makes +// that choice stable across CLI versions. +func persistPlaintextFallback(ctx context.Context) error { + configPath := env.Get(ctx, "DATABRICKS_CONFIG_FILE") + return databrickscfg.SetConfiguredAuthStorage(ctx, string(StorageModePlaintext), configPath) +} diff --git a/libs/auth/storage/cache_test.go b/libs/auth/storage/cache_test.go index b84c1ef3ba..34d3f38c13 100644 --- a/libs/auth/storage/cache_test.go +++ b/libs/auth/storage/cache_test.go @@ -3,9 +3,11 @@ package storage import ( "context" "errors" + "os" "path/filepath" "testing" + "github.com/databricks/cli/libs/databrickscfg" "github.com/databricks/cli/libs/env" "github.com/databricks/databricks-sdk-go/credentials/u2m" "github.com/databricks/databricks-sdk-go/credentials/u2m/cache" @@ -24,8 +26,9 @@ func (stubCache) Lookup(string) (*oauth2.Token, error) { return nil, cache.ErrNo func fakeFactories(t *testing.T) cacheFactories { t.Helper() return cacheFactories{ - newFile: func(context.Context) (cache.TokenCache, error) { return stubCache{source: "file"}, nil }, - newKeyring: func() cache.TokenCache { return stubCache{source: "keyring"} }, + newFile: func(context.Context) (cache.TokenCache, error) { return stubCache{source: "file"}, nil }, + newKeyring: func() cache.TokenCache { return stubCache{source: "keyring"} }, + probeKeyring: func() error { return nil }, } } @@ -106,8 +109,9 @@ func TestResolveCache_FileFactoryErrorPropagates(t *testing.T) { ctx := t.Context() boom := errors.New("disk full") factories := cacheFactories{ - newFile: func(context.Context) (cache.TokenCache, error) { return nil, boom }, - newKeyring: func() cache.TokenCache { return stubCache{source: "keyring"} }, + newFile: func(context.Context) (cache.TokenCache, error) { return nil, boom }, + newKeyring: func() cache.TokenCache { return stubCache{source: "keyring"} }, + probeKeyring: func() error { return nil }, } _, _, err := resolveCacheWith(ctx, StorageModePlaintext, factories) @@ -116,6 +120,118 @@ func TestResolveCache_FileFactoryErrorPropagates(t *testing.T) { assert.ErrorIs(t, err, boom) } +func TestResolveCacheForLogin_PlaintextSkipsProbe(t *testing.T) { + hermetic(t) + ctx := t.Context() + probed := false + f := fakeFactories(t) + f.probeKeyring = func() error { + probed = true + return nil + } + + got, mode, err := resolveCacheForLoginWith(ctx, StorageModePlaintext, f) + + require.NoError(t, err) + assert.Equal(t, StorageModePlaintext, mode) + assert.Equal(t, "file", got.(stubCache).source) + assert.False(t, probed, "probe must not run when mode is already plaintext") +} + +func TestResolveCacheForLogin_SecureProbeOK(t *testing.T) { + hermetic(t) + ctx := env.Set(t.Context(), EnvVar, "secure") + + got, mode, err := resolveCacheForLoginWith(ctx, "", fakeFactories(t)) + + require.NoError(t, err) + assert.Equal(t, StorageModeSecure, mode) + assert.Equal(t, "keyring", got.(stubCache).source) +} + +func TestResolveCacheForLogin_ExplicitEnvSecure_ProbeFail_Errors(t *testing.T) { + hermetic(t) + ctx := env.Set(t.Context(), EnvVar, "secure") + configPath := env.Get(ctx, "DATABRICKS_CONFIG_FILE") + + f := fakeFactories(t) + f.probeKeyring = func() error { return errors.New("no keyring") } + + _, _, err := resolveCacheForLoginWith(ctx, "", f) + require.Error(t, err) + assert.ErrorContains(t, err, "secure storage was requested") + + persisted, gerr := databrickscfg.GetConfiguredAuthStorage(ctx, configPath) + require.NoError(t, gerr) + assert.Equal(t, "", persisted, "env-set secure must not be persisted as plaintext") +} + +func TestResolveCacheForLogin_ExplicitConfigSecure_ProbeFail_Errors(t *testing.T) { + hermetic(t) + ctx := t.Context() + configPath := env.Get(ctx, "DATABRICKS_CONFIG_FILE") + require.NoError(t, os.WriteFile(configPath, []byte("[__settings__]\nauth_storage = secure\n"), 0o600)) + + f := fakeFactories(t) + f.probeKeyring = func() error { return errors.New("no keyring") } + + _, _, err := resolveCacheForLoginWith(ctx, "", f) + require.Error(t, err) + assert.ErrorContains(t, err, "secure storage was requested") + + persisted, gerr := databrickscfg.GetConfiguredAuthStorage(ctx, configPath) + require.NoError(t, gerr) + assert.Equal(t, "secure", persisted, "config-set secure must not be silently rewritten") +} + +func TestResolveCacheForLogin_ExplicitOverrideSecure_ProbeFail_Errors(t *testing.T) { + hermetic(t) + ctx := t.Context() + + f := fakeFactories(t) + f.probeKeyring = func() error { return errors.New("no keyring") } + + _, _, err := resolveCacheForLoginWith(ctx, StorageModeSecure, f) + require.Error(t, err) + assert.ErrorContains(t, err, "secure storage was requested") +} + +func TestApplyLoginFallback_DefaultSecure_ProbeFail_FallsBackAndPersists(t *testing.T) { + hermetic(t) + ctx := t.Context() + configPath := env.Get(ctx, "DATABRICKS_CONFIG_FILE") + + f := fakeFactories(t) + f.probeKeyring = func() error { return errors.New("no keyring") } + + got, mode, err := applyLoginFallback(ctx, StorageModeSecure, false, f) + + require.NoError(t, err) + assert.Equal(t, StorageModePlaintext, mode) + assert.Equal(t, "file", got.(stubCache).source) + + persisted, err := databrickscfg.GetConfiguredAuthStorage(ctx, configPath) + require.NoError(t, err) + assert.Equal(t, "plaintext", persisted, "default-mode fallback must persist auth_storage = plaintext") +} + +func TestApplyLoginFallback_ExplicitSecure_ProbeFail_Errors(t *testing.T) { + hermetic(t) + ctx := t.Context() + configPath := env.Get(ctx, "DATABRICKS_CONFIG_FILE") + + f := fakeFactories(t) + f.probeKeyring = func() error { return errors.New("no keyring") } + + _, _, err := applyLoginFallback(ctx, StorageModeSecure, true, f) + require.Error(t, err) + assert.ErrorContains(t, err, "secure storage was requested") + + persisted, gerr := databrickscfg.GetConfiguredAuthStorage(ctx, configPath) + require.NoError(t, gerr) + assert.Equal(t, "", persisted, "explicit-secure error must not write config") +} + func TestWrapForOAuthArgument(t *testing.T) { const ( host = "https://example.com" diff --git a/libs/auth/storage/keyring.go b/libs/auth/storage/keyring.go index e9fc7d13df..d16428a57b 100644 --- a/libs/auth/storage/keyring.go +++ b/libs/auth/storage/keyring.go @@ -17,6 +17,11 @@ import ( // cache key the SDK passes through TokenCache.Store / Lookup. const keyringServiceName = "databricks-cli" +// keyringProbeAccount is the account name ProbeKeyring writes and deletes +// to verify the keyring is reachable. Distinct from any real cache key so a +// concurrent probe cannot collide with an actual OAuth token entry. +const keyringProbeAccount = "__probe__" + // defaultKeyringTimeout is how long a single keyring operation is allowed // to run before the wrapper returns a TimeoutError. Matches the value used // by GitHub CLI. @@ -79,6 +84,34 @@ func NewKeyringCache() cache.TokenCache { } } +// ProbeKeyring returns nil if the OS keyring is reachable and accepts a +// write+delete cycle within the standard timeout. A non-nil error means the +// keyring cannot be used in this environment (no backend, headless Linux +// session waiting on a UI prompt, locked keychain refusing access, etc.). +// +// Used by databricks auth login to decide whether to silently fall back to +// plaintext storage before opening the browser, so the user does not +// complete an OAuth flow only to fail at the final Store call. +func ProbeKeyring() error { + return probeWithBackend(zalandoBackend{}, defaultKeyringTimeout) +} + +func probeWithBackend(backend keyringBackend, timeout time.Duration) error { + c := &keyringCache{ + backend: backend, + timeout: timeout, + keyringSvcName: keyringServiceName, + } + tok := &oauth2.Token{AccessToken: "probe"} + if err := c.Store(keyringProbeAccount, tok); err != nil { + return fmt.Errorf("write: %w", err) + } + if err := c.Store(keyringProbeAccount, nil); err != nil { + return fmt.Errorf("delete: %w", err) + } + return nil +} + // Store stores t under key. Nil t deletes the entry; deleting a missing // entry is not an error. func (k *keyringCache) Store(key string, t *oauth2.Token) error { diff --git a/libs/auth/storage/keyring_test.go b/libs/auth/storage/keyring_test.go index 74ea3c0c63..1057c75d20 100644 --- a/libs/auth/storage/keyring_test.go +++ b/libs/auth/storage/keyring_test.go @@ -217,3 +217,64 @@ func TestKeyringCache_StoreNil_TimesOut(t *testing.T) { var timeoutErr *TimeoutError assert.ErrorAs(t, err, &timeoutErr, "expected TimeoutError, got %T: %v", err, err) } + +func TestProbeKeyring(t *testing.T) { + boom := errors.New("backend boom") + cases := []struct { + name string + setErr error + deleteErr error + setBlock bool + timeout time.Duration + wantErr error + wantTimeout bool + }{ + { + name: "success leaves no entry", + timeout: 100 * time.Millisecond, + }, + { + name: "set error propagates", + setErr: boom, + timeout: 100 * time.Millisecond, + wantErr: boom, + }, + { + name: "set times out", + setBlock: true, + timeout: 50 * time.Millisecond, + wantTimeout: true, + }, + { + name: "delete error propagates", + deleteErr: boom, + timeout: 100 * time.Millisecond, + wantErr: boom, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + backend := newFakeBackend() + backend.setErr = tc.setErr + backend.deleteErr = tc.deleteErr + backend.setBlock = tc.setBlock + + err := probeWithBackend(backend, tc.timeout) + + switch { + case tc.wantErr != nil: + require.Error(t, err) + assert.ErrorIs(t, err, tc.wantErr) + case tc.wantTimeout: + require.Error(t, err) + var timeoutErr *TimeoutError + assert.ErrorAs(t, err, &timeoutErr) + default: + require.NoError(t, err) + _, ok := backend.items[itemKey(keyringServiceName, keyringProbeAccount)] + assert.False(t, ok, "probe must clean up after itself") + } + }) + } +} diff --git a/libs/auth/storage/mode.go b/libs/auth/storage/mode.go index b3dc846536..2caace171f 100644 --- a/libs/auth/storage/mode.go +++ b/libs/auth/storage/mode.go @@ -65,24 +65,37 @@ func ParseMode(raw string) StorageMode { // unrecognized env or config value is reported as an error wrapped with // the source name. func ResolveStorageMode(ctx context.Context, override StorageMode) (StorageMode, error) { + mode, _, err := ResolveStorageModeWithSource(ctx, override) + return mode, err +} + +// ResolveStorageModeWithSource is like ResolveStorageMode but also reports +// whether the resolved mode came from an explicit user choice (override flag, +// env var, or config) versus the built-in default. Callers use this to honor +// "I want secure" strictly: when the user explicitly asked for secure storage +// but it cannot be provided, the right move is to error out, not to silently +// downgrade. +func ResolveStorageModeWithSource(ctx context.Context, override StorageMode) (StorageMode, bool, error) { if override != StorageModeUnknown { - return override, nil + return override, true, nil } if raw := env.Get(ctx, EnvVar); raw != "" { - return parseFromSource(raw, EnvVar) + mode, err := parseFromSource(raw, EnvVar) + return mode, true, err } configPath := env.Get(ctx, "DATABRICKS_CONFIG_FILE") raw, err := databrickscfg.GetConfiguredAuthStorage(ctx, configPath) if err != nil { - return "", fmt.Errorf("read auth_storage setting: %w", err) + return "", false, fmt.Errorf("read auth_storage setting: %w", err) } if raw != "" { - return parseFromSource(raw, "auth_storage") + mode, err := parseFromSource(raw, "auth_storage") + return mode, true, err } - return StorageModePlaintext, nil + return StorageModePlaintext, false, nil } func parseFromSource(raw, source string) (StorageMode, error) { diff --git a/libs/auth/storage/mode_test.go b/libs/auth/storage/mode_test.go index d932d2253a..bec3e571eb 100644 --- a/libs/auth/storage/mode_test.go +++ b/libs/auth/storage/mode_test.go @@ -128,3 +128,70 @@ func TestResolveStorageMode_SkipsConfigReadWhenOverrideOrEnvSet(t *testing.T) { assert.Equal(t, StorageModeSecure, got) }) } + +func TestResolveStorageModeWithSource(t *testing.T) { + cases := []struct { + name string + override StorageMode + envValue string + configBody string + wantMode StorageMode + wantExplicit bool + wantErrSub string + }{ + { + name: "default is not explicit", + wantMode: StorageModePlaintext, + wantExplicit: false, + }, + { + name: "override is explicit", + override: StorageModeSecure, + wantMode: StorageModeSecure, + wantExplicit: true, + }, + { + name: "env is explicit", + envValue: "secure", + wantMode: StorageModeSecure, + wantExplicit: true, + }, + { + name: "config is explicit", + configBody: "[__settings__]\nauth_storage = secure\n", + wantMode: StorageModeSecure, + wantExplicit: true, + }, + { + name: "invalid env is rejected", + envValue: "bogus", + wantErrSub: "DATABRICKS_AUTH_STORAGE", + }, + { + name: "invalid config value is rejected", + configBody: "[__settings__]\nauth_storage = bogus\n", + wantErrSub: "auth_storage", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + cfgPath := filepath.Join(t.TempDir(), ".databrickscfg") + if tc.configBody != "" { + require.NoError(t, os.WriteFile(cfgPath, []byte(tc.configBody), 0o600)) + } + t.Setenv("DATABRICKS_CONFIG_FILE", cfgPath) + t.Setenv(EnvVar, tc.envValue) + + mode, explicit, err := ResolveStorageModeWithSource(t.Context(), tc.override) + if tc.wantErrSub != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.wantErrSub) + return + } + require.NoError(t, err) + assert.Equal(t, tc.wantMode, mode) + assert.Equal(t, tc.wantExplicit, explicit) + }) + } +} diff --git a/libs/databrickscfg/ops.go b/libs/databrickscfg/ops.go index c4d0f1cc79..43b3fc7000 100644 --- a/libs/databrickscfg/ops.go +++ b/libs/databrickscfg/ops.go @@ -196,6 +196,29 @@ func SetDefaultProfile(ctx context.Context, profileName, configFilePath string) return writeConfigFile(ctx, configFile) } +// SetConfiguredAuthStorage writes the auth_storage key to the [__settings__] +// section. Used by auth login to persist a plaintext fallback when the OS +// keyring is unreachable, so subsequent commands skip the keyring probe and +// route directly to the file cache. +func SetConfiguredAuthStorage(ctx context.Context, value, configFilePath string) error { + configFile, err := loadOrCreateConfigFile(ctx, configFilePath) + if err != nil { + return err + } + + section, err := configFile.GetSection(databricksSettingsSection) + if err != nil { + section, err = configFile.NewSection(databricksSettingsSection) + if err != nil { + return fmt.Errorf("cannot create %s section: %w", databricksSettingsSection, err) + } + } + + section.Key(authStorageKey).SetValue(value) + + return writeConfigFile(ctx, configFile) +} + // ClearDefaultProfile removes the default_profile key from the [__settings__] // section if the current default matches the given profile name. func ClearDefaultProfile(ctx context.Context, profileName, configFilePath string) error { diff --git a/libs/databrickscfg/ops_test.go b/libs/databrickscfg/ops_test.go index 0555a8171f..a8ef811e75 100644 --- a/libs/databrickscfg/ops_test.go +++ b/libs/databrickscfg/ops_test.go @@ -709,3 +709,57 @@ func TestGetConfiguredAuthStorage_MissingFile(t *testing.T) { require.NoError(t, err) assert.Equal(t, "", got) } + +func TestSetConfiguredAuthStorage(t *testing.T) { + cases := []struct { + name string + contents string + }{ + { + name: "missing file is created", + contents: "", + }, + { + name: "missing settings section is created", + contents: "[my-ws]\nhost = https://example.cloud.databricks.com\n", + }, + { + name: "settings section without auth_storage gets the key added", + contents: "[__settings__]\ndefault_profile = my-ws\n", + }, + { + name: "existing auth_storage value is overwritten", + contents: "[__settings__]\nauth_storage = secure\n", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + path := filepath.Join(t.TempDir(), ".databrickscfg") + if tc.contents != "" { + require.NoError(t, os.WriteFile(path, []byte(tc.contents), 0o600)) + } + + require.NoError(t, SetConfiguredAuthStorage(t.Context(), "plaintext", path)) + + got, err := GetConfiguredAuthStorage(t.Context(), path) + require.NoError(t, err) + assert.Equal(t, "plaintext", got) + }) + } +} + +func TestSetConfiguredAuthStorage_PreservesOtherSettings(t *testing.T) { + path := filepath.Join(t.TempDir(), ".databrickscfg") + require.NoError(t, os.WriteFile(path, []byte("[__settings__]\ndefault_profile = dev\n\n[dev]\nhost = https://example.cloud.databricks.com\n"), 0o600)) + + require.NoError(t, SetConfiguredAuthStorage(t.Context(), "plaintext", path)) + + defaultProfile, err := GetConfiguredDefaultProfile(t.Context(), path) + require.NoError(t, err) + assert.Equal(t, "dev", defaultProfile) + + authStorage, err := GetConfiguredAuthStorage(t.Context(), path) + require.NoError(t, err) + assert.Equal(t, "plaintext", authStorage) +}