From 0cfa9c81e40d8dd4e051ae930e84ded1b4999b5f Mon Sep 17 00:00:00 2001 From: Rohit Sharma Date: Tue, 17 Feb 2026 22:21:48 +0000 Subject: [PATCH] feat(spark): Add multi-tenant Databricks token support via cross-namespace K8s secrets Enable per-project Databricks authentication by reading tokens from Kubernetes secrets in workflow namespaces, with backward-compatible fallback to the FLYTE_DATABRICKS_ACCESS_TOKEN environment variable. Changes: - Add get_secret_from_k8s() for cross-namespace K8s secret reading - Add get_databricks_token() with tiered resolution (K8s -> env var) - Update DatabricksJobMetadata to persist auth_token across lifecycle - Update DatabricksConnector.create/get/delete to use per-project tokens - Add DatabricksV2.databricks_token_secret for custom secret names - Add 31 comprehensive tests covering all token resolution paths Tracking: https://github.com/flyteorg/flyte/issues/6911 Signed-off-by: Rohit Sharma --- .../flytekitplugins/spark/connector.py | 178 ++++- .../flytekitplugins/spark/task.py | 35 +- .../flytekit-spark/tests/test_connector.py | 8 +- .../tests/test_databricks_token.py | 654 ++++++++++++++++++ pydoclint-errors-baseline.txt | 4 - 5 files changed, 853 insertions(+), 26 deletions(-) create mode 100644 plugins/flytekit-spark/tests/test_databricks_token.py diff --git a/plugins/flytekit-spark/flytekitplugins/spark/connector.py b/plugins/flytekit-spark/flytekitplugins/spark/connector.py index 895c7d153d..6c901f7573 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/connector.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/connector.py @@ -1,5 +1,6 @@ import http import json +import logging import os import typing from dataclasses import dataclass @@ -13,10 +14,12 @@ from flytekit.extend.backend.utils import convert_to_flyte_phase, get_connector_secret from flytekit.models.core.execution import TaskLog from flytekit.models.literals import LiteralMap -from flytekit.models.task import TaskTemplate +from flytekit.models.task 
import TaskExecutionMetadata, TaskTemplate aiohttp = lazy_module("aiohttp") +logger = logging.getLogger(__name__) + DATABRICKS_API_ENDPOINT = "/api/2.1/jobs" DEFAULT_DATABRICKS_INSTANCE_ENV_KEY = "FLYTE_DATABRICKS_INSTANCE" @@ -25,6 +28,7 @@ class DatabricksJobMetadata(ResourceMeta): databricks_instance: str run_id: str + auth_token: Optional[str] = None # Reused by get/delete. NOTE(review): secret kept in metadata; confirm storage def _get_databricks_job_spec(task_template: TaskTemplate) -> dict: @@ -68,7 +72,11 @@ def __init__(self): super().__init__(task_type_name="spark", metadata_type=DatabricksJobMetadata) async def create( - self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs + self, + task_template: TaskTemplate, + inputs: Optional[LiteralMap] = None, + task_execution_metadata: Optional[TaskExecutionMetadata] = None, + **kwargs, ) -> DatabricksJobMetadata: data = json.dumps(_get_databricks_job_spec(task_template)) databricks_instance = task_template.custom.get( @@ -80,15 +88,31 @@ async def create( f"Missing databricks instance. Please set the value through the task config or set the {DEFAULT_DATABRICKS_INSTANCE_ENV_KEY} environment variable in the connector."
) + # Get workflow-specific token or fall back to default + namespace = task_execution_metadata.namespace if task_execution_metadata else None + + # Extract custom secret name from task template (if provided) + custom_secret_name = task_template.custom.get("databricksTokenSecret") + + logger.info(f"Creating Databricks job for namespace: {namespace or 'unknown'}") + if custom_secret_name: + logger.info(f"Using custom secret name: {custom_secret_name}") + + auth_token = get_databricks_token( + namespace=namespace, task_template=task_template, secret_name=custom_secret_name + ) databricks_url = f"https://{databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/submit" async with aiohttp.ClientSession() as session: - async with session.post(databricks_url, headers=get_header(), data=data) as resp: + async with session.post(databricks_url, headers=get_header(auth_token=auth_token), data=data) as resp: response = await resp.json() if resp.status != http.HTTPStatus.OK: raise RuntimeError(f"Failed to create databricks job with error: {response}") - return DatabricksJobMetadata(databricks_instance=databricks_instance, run_id=str(response["run_id"])) + logger.info(f"Successfully created Databricks job with run_id: {response['run_id']}") + return DatabricksJobMetadata( + databricks_instance=databricks_instance, run_id=str(response["run_id"]), auth_token=auth_token + ) async def get(self, resource_meta: DatabricksJobMetadata, **kwargs) -> Resource: databricks_instance = resource_meta.databricks_instance @@ -96,8 +120,11 @@ async def get(self, resource_meta: DatabricksJobMetadata, **kwargs) -> Resource: f"https://{databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/get?run_id={resource_meta.run_id}" ) + # Use the stored auth token if available, otherwise fall back to default + headers = get_header(auth_token=resource_meta.auth_token) + async with aiohttp.ClientSession() as session: - async with session.get(databricks_url, headers=get_header()) as resp: + async with 
session.get(databricks_url, headers=headers) as resp: if resp.status != http.HTTPStatus.OK: raise RuntimeError(f"Failed to get databricks job {resource_meta.run_id} with error: {resp.reason}") response = await resp.json() @@ -128,8 +155,11 @@ async def delete(self, resource_meta: DatabricksJobMetadata, **kwargs): databricks_url = f"https://{resource_meta.databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/cancel" data = json.dumps({"run_id": resource_meta.run_id}) + # Use the stored auth token if available, otherwise fall back to default + headers = get_header(auth_token=resource_meta.auth_token) + async with aiohttp.ClientSession() as session: - async with session.post(databricks_url, headers=get_header(), data=data) as resp: + async with session.post(databricks_url, headers=headers, data=data) as resp: if resp.status != http.HTTPStatus.OK: raise RuntimeError( f"Failed to cancel databricks job {resource_meta.run_id} with error: {resp.reason}" @@ -150,9 +180,139 @@ def __init__(self): super(DatabricksConnector, self).__init__(task_type_name="databricks", metadata_type=DatabricksJobMetadata) -def get_header() -> typing.Dict[str, str]: - token = get_connector_secret("FLYTE_DATABRICKS_ACCESS_TOKEN") - return {"Authorization": f"Bearer {token}", "content-type": "application/json"} +def get_secret_from_k8s(secret_name: str, secret_key: str, namespace: str) -> Optional[str]: + """Read a secret from Kubernetes using the Kubernetes Python client. + + Args: + secret_name (str): Name of the Kubernetes secret (e.g., "databricks-token"). + secret_key (str): Key within the secret (e.g., "token"). + namespace (str): Kubernetes namespace where the secret is stored. + + Returns: + Optional[str]: The secret value as a string, or None if not found. 
+ """ + try: + import base64 + + from kubernetes import client, config + + # Try to load in-cluster config first (when running in K8s) + try: + config.load_incluster_config() + except config.ConfigException: + # Fall back to kubeconfig (for local testing) + try: + config.load_kube_config() + except Exception as e: + logger.warning(f"Failed to load Kubernetes config: {e}") + return None + + v1 = client.CoreV1Api() + + try: + secret = v1.read_namespaced_secret(name=secret_name, namespace=namespace) + if secret.data and secret_key in secret.data: + # Kubernetes secrets are base64 encoded + secret_value = base64.b64decode(secret.data[secret_key]).decode("utf-8") + return secret_value + else: + logger.debug( + f"Secret '{secret_name}' exists but key '{secret_key}' not found in namespace '{namespace}'" + ) + return None + except client.exceptions.ApiException as e: + if e.status == 404: + logger.debug(f"Secret '{secret_name}' not found in namespace '{namespace}'") + else: + logger.warning(f"Error reading secret '{secret_name}' from namespace '{namespace}': {e}") + return None + + except ImportError: + logger.warning("kubernetes Python package not installed - cannot read namespace secrets") + return None + except Exception as e: + logger.warning(f"Unexpected error reading K8s secret: {e}") + return None + + +def get_databricks_token( + namespace: Optional[str] = None, task_template: Optional[TaskTemplate] = None, secret_name: Optional[str] = None +) -> str: + """Get the Databricks access token with multi-tenant support. + + Token resolution: namespace K8s secret -> FLYTE_DATABRICKS_ACCESS_TOKEN env var. + + Args: + namespace (Optional[str]): Kubernetes namespace for workflow-specific token lookup. + task_template (Optional[TaskTemplate]): Optional TaskTemplate (kept for API compatibility). + secret_name (Optional[str]): Custom secret name. Defaults to 'databricks-token'. + + Returns: + str: The Databricks access token. 
+ + Raises: + ValueError: If no token is found from any source. + """ + token = None + token_source = "unknown" + + # Use custom secret name or default to 'databricks-token' + k8s_secret_name = secret_name or "databricks-token" + + # Step 1: Try namespace-specific K8s secret (cross-namespace lookup) + if namespace: + logger.info(f"Looking for Databricks token in workflow namespace: {namespace} (secret: {k8s_secret_name})") + token = get_secret_from_k8s(secret_name=k8s_secret_name, secret_key="token", namespace=namespace) + + if token: + logger.info(f"Found Databricks token in namespace '{namespace}' from secret '{k8s_secret_name}'") + token_source = f"k8s_namespace:{namespace}/secret:{k8s_secret_name}" + else: + logger.info( + f"Databricks token not found in secret '{k8s_secret_name}' in namespace '{namespace}' - trying fallback" + ) + else: + logger.info("No namespace provided for cross-namespace lookup") + + # Step 2: Fall back to environment variable (backward compatibility) + if token is None: + logger.info("Falling back to default Databricks token (FLYTE_DATABRICKS_ACCESS_TOKEN)") + try: + token = get_connector_secret("FLYTE_DATABRICKS_ACCESS_TOKEN") + token_source = "env_variable" + except Exception as e: + logger.error(f"Failed to get default Databricks token: {e}") + raise ValueError( + "No Databricks token found from any source:\n" + f"1. Namespace-specific K8s secret '{k8s_secret_name}'\n" + "2. FLYTE_DATABRICKS_ACCESS_TOKEN environment variable\n" + f"Workflow namespace: {namespace or 'N/A'}" + ) from e + + if not token: + raise ValueError("Databricks token is empty") + + # Log only where the token came from. Never log any part of the token itself: + # even a short prefix is secret material once it reaches shared log storage.
+ logger.info(f"Using Databricks token from: {token_source}") + + return token + + +def get_header(task_template: Optional[TaskTemplate] = None, auth_token: Optional[str] = None) -> typing.Dict[str, str]: + """Get the authorization header for Databricks API calls. + + Args: + task_template (Optional[TaskTemplate]): Forwarded to get_databricks_token when auth_token is None. + auth_token (Optional[str]): Pre-fetched auth token to use directly. + + Returns: + typing.Dict[str, str]: Authorization and content-type headers. + """ + if auth_token is None: + auth_token = get_databricks_token(task_template=task_template) # by keyword: 1st positional is namespace + + return {"Authorization": f"Bearer {auth_token}", "content-type": "application/json"} def result_state_is_available(life_cycle_state: str) -> bool: diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index 5801d24fde..be6be39f53 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -71,19 +71,25 @@ def __post_init__(self): @dataclass class DatabricksV2(Spark): - """ - Use this to configure a Databricks task. Task's marked with this will automatically execute - natively onto databricks platform as a distributed execution of spark - - Args: - databricks_conf: Databricks job configuration compliant with API version 2.1, supporting 2.0 use cases. - For the configuration structure, visit here.https://docs.databricks.com/dev-tools/api/2.0/jobs.html#request-structure - For updates in API 2.1, refer to: https://docs.databricks.com/en/workflows/jobs/jobs-api-updates.html - databricks_instance: Domain name of your deployment. Use the form .cloud.databricks.com. + """Use this to configure a Databricks task. + + Tasks marked with this will automatically execute natively onto the Databricks + platform as a distributed execution of Spark.
+ + Attributes: + databricks_conf (Optional[Dict[str, Union[str, dict]]]): Databricks job configuration + compliant with API version 2.1, supporting 2.0 use cases. + See https://docs.databricks.com/dev-tools/api/2.0/jobs.html#request-structure + and https://docs.databricks.com/en/workflows/jobs/jobs-api-updates.html + databricks_instance (Optional[str]): Domain name of your deployment. + Use the form .cloud.databricks.com. + databricks_token_secret (Optional[str]): Custom name for the K8s secret containing + the Databricks token. Defaults to 'databricks-token' if not specified. """ databricks_conf: Optional[Dict[str, Union[str, dict]]] = None databricks_instance: Optional[str] = None + databricks_token_secret: Optional[str] = None # This method does not reset the SparkSession since it's a bit hard to handle multiple @@ -187,7 +193,16 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: job._databricks_conf = cfg.databricks_conf job._databricks_instance = cfg.databricks_instance - return MessageToDict(job.to_flyte_idl()) + # Serialize to dict + custom_dict = MessageToDict(job.to_flyte_idl()) + + # Add custom secret name if specified (not part of protobuf, so add separately) + if isinstance(self.task_config, (Databricks, DatabricksV2)): + cfg = cast(DatabricksV2, self.task_config) + if hasattr(cfg, "databricks_token_secret") and cfg.databricks_token_secret: + custom_dict["databricksTokenSecret"] = cfg.databricks_token_secret + + return custom_dict def to_k8s_pod(self, pod_template: Optional[PodTemplate] = None) -> Optional[K8sPod]: """ diff --git a/plugins/flytekit-spark/tests/test_connector.py b/plugins/flytekit-spark/tests/test_connector.py index 5136d39ce8..c337d27155 100644 --- a/plugins/flytekit-spark/tests/test_connector.py +++ b/plugins/flytekit-spark/tests/test_connector.py @@ -118,6 +118,7 @@ async def test_databricks_agent(task_template: TaskTemplate): databricks_metadata = DatabricksJobMetadata( 
databricks_instance="test-account.cloud.databricks.com", run_id="123", + auth_token=mocked_token, ) mock_create_response = {"run_id": "123"} @@ -135,7 +136,7 @@ async def test_databricks_agent(task_template: TaskTemplate): res = await agent.create(task_template, None) spec = _get_databricks_job_spec(task_template) data = json.dumps(spec) - mocked.assert_called_with(create_url, method="POST", data=data, headers=get_header()) + mocked.assert_called_with(create_url, method="POST", data=data, headers=get_header(auth_token=mocked_token)) spark_envs = spec["new_cluster"]["spark_env_vars"] assert spark_envs["foo"] == "bar" assert spark_envs[FLYTE_FAIL_ON_ERROR] == "true" @@ -152,7 +153,7 @@ async def test_databricks_agent(task_template: TaskTemplate): mocked.post(delete_url, status=http.HTTPStatus.OK, payload=mock_delete_response) await agent.delete(databricks_metadata) - assert get_header() == {"Authorization": f"Bearer {mocked_token}", "content-type": "application/json"} + assert get_header(auth_token=mocked_token) == {"Authorization": f"Bearer {mocked_token}", "content-type": "application/json"} mock.patch.stopall() @@ -176,6 +177,7 @@ async def test_agent_create_with_default_instance(task_template: TaskTemplate): databricks_metadata = DatabricksJobMetadata( databricks_instance="test-account.cloud.databricks.com", run_id="123", + auth_token=mocked_token, ) mock_create_response = {"run_id": "123"} @@ -188,7 +190,7 @@ async def test_agent_create_with_default_instance(task_template: TaskTemplate): res = await agent.create(task_template, None) spec = _get_databricks_job_spec(task_template) data = json.dumps(spec) - mocked.assert_called_with(create_url, method="POST", data=data, headers=get_header()) + mocked.assert_called_with(create_url, method="POST", data=data, headers=get_header(auth_token=mocked_token)) assert res == databricks_metadata mock.patch.stopall() diff --git a/plugins/flytekit-spark/tests/test_databricks_token.py 
b/plugins/flytekit-spark/tests/test_databricks_token.py new file mode 100644 index 0000000000..9bed13bb56 --- /dev/null +++ b/plugins/flytekit-spark/tests/test_databricks_token.py @@ -0,0 +1,654 @@ +""" +Comprehensive tests for the Databricks per-project token support feature. + +Tests cover: +- get_secret_from_k8s: Cross-namespace K8s secret reading +- get_databricks_token: Token resolution strategy (K8s -> env var fallback) +- get_header: Authorization header generation with token support +- DatabricksConnector.create/get/delete: Token persistence in job metadata +- DatabricksV2 task config: databricks_token_secret serialization +""" + +import base64 +import http +from datetime import timedelta +from unittest.mock import MagicMock, patch + +import pytest +from aioresponses import aioresponses +from flytekitplugins.spark.connector import ( + DATABRICKS_API_ENDPOINT, + DatabricksJobMetadata, + get_databricks_token, + get_header, + get_secret_from_k8s, +) + +from flytekit.extend.backend.base_agent import AgentRegistry +from flytekit.interfaces.cli_identifiers import Identifier +from flytekit.models import literals, task +from flytekit.models.core.identifier import ResourceType +from flytekit.models.task import Container, Resources, TaskExecutionMetadata, TaskTemplate + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="function") +def task_template() -> TaskTemplate: + """Standard Databricks task template for testing.""" + task_id = Identifier( + resource_type=ResourceType.TASK, + project="project", + domain="domain", + name="name", + version="version", + ) + task_metadata = task.TaskMetadata( + True, + task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + timedelta(days=1), + literals.RetryStrategy(3), + True, + "0.1.1b0", + "This is deprecated!", + True, + "A", + (), + ) + task_config = 
{ + "sparkConf": { + "spark.driver.memory": "1000M", + "spark.executor.memory": "1000M", + "spark.executor.cores": "1", + "spark.executor.instances": "2", + "spark.driver.cores": "1", + }, + "mainApplicationFile": "dbfs:/entrypoint.py", + "databricksConf": { + "run_name": "flytekit databricks plugin example", + "new_cluster": { + "spark_version": "12.2.x-scala2.12", + "node_type_id": "n2-highmem-4", + "num_workers": 1, + }, + "timeout_seconds": 3600, + "max_retries": 1, + }, + } + container = Container( + image="flyteorg/flytekit:databricks-0.18.0-py3.7", + command=[], + args=[ + "pyflyte-fast-execute", + "--additional-distribution", + "s3://my-s3-bucket/flytesnacks/development/24UYJEF2HDZQN3SG4VAZSM4PLI======/script_mode.tar.gz", + "--dest-dir", + "/root", + "--", + "pyflyte-execute", + "--inputs", + "s3://my-s3-bucket", + "--output-prefix", + "s3://my-s3-bucket", + "--raw-output-data-prefix", + "s3://my-s3-bucket", + "--checkpoint-path", + "s3://my-s3-bucket", + "--prev-checkpoint", + "s3://my-s3-bucket", + "--resolver", + "flytekit.core.python_auto_container.default_task_resolver", + "--", + "task-module", + "spark_local_example", + "task-name", + "hello_spark", + ], + resources=Resources(requests=[], limits=[]), + env={"foo": "bar"}, + config={}, + ) + return TaskTemplate( + id=task_id, + custom=task_config, + metadata=task_metadata, + container=container, + interface=None, + type="spark", + ) + + +@pytest.fixture(scope="function") +def task_template_with_custom_secret(task_template) -> TaskTemplate: + """Task template with a custom databricksTokenSecret.""" + task_template.custom["databricksTokenSecret"] = "my-team-databricks-token" + return task_template + + +def _make_task_execution_metadata(namespace: str) -> TaskExecutionMetadata: + """Helper to create a TaskExecutionMetadata with a given namespace.""" + task_exec_id = MagicMock() + return TaskExecutionMetadata( + task_execution_id=task_exec_id, + namespace=namespace, + labels={}, + annotations={}, + 
k8s_service_account="default", + environment_variables={}, + identity=MagicMock(), + ) + + +# =========================================================================== +# Tests for get_secret_from_k8s +# =========================================================================== + + +class TestGetSecretFromK8s: + """Tests for cross-namespace Kubernetes secret reading.""" + + def test_success(self): + """Secret found via mocked Kubernetes API client.""" + mock_secret = MagicMock() + mock_secret.data = {"token": base64.b64encode(b"dapi_real_token_xyz").decode()} + + with patch("kubernetes.client.CoreV1Api") as mock_api_cls, patch("kubernetes.config.load_incluster_config"): + mock_api_cls.return_value.read_namespaced_secret.return_value = mock_secret + result = get_secret_from_k8s("databricks-token", "token", "production") + assert result == "dapi_real_token_xyz" + mock_api_cls.return_value.read_namespaced_secret.assert_called_once_with( + name="databricks-token", namespace="production" + ) + + def test_secret_not_found_404(self): + """Secret doesn't exist in the namespace (404).""" + from kubernetes.client.exceptions import ApiException + + with patch("kubernetes.client.CoreV1Api") as mock_api_cls, patch("kubernetes.config.load_incluster_config"): + mock_api_cls.return_value.read_namespaced_secret.side_effect = ApiException(status=404) + result = get_secret_from_k8s("databricks-token", "token", "staging") + assert result is None + + def test_secret_key_missing(self): + """Secret exists but the 'token' key is not present.""" + mock_secret = MagicMock() + mock_secret.data = {"other_key": base64.b64encode(b"something").decode()} + + with patch("kubernetes.client.CoreV1Api") as mock_api_cls, patch("kubernetes.config.load_incluster_config"): + mock_api_cls.return_value.read_namespaced_secret.return_value = mock_secret + result = get_secret_from_k8s("databricks-token", "token", "production") + assert result is None + + def test_secret_data_is_none(self): + """Secret 
exists but data field is None.""" + mock_secret = MagicMock() + mock_secret.data = None + + with patch("kubernetes.client.CoreV1Api") as mock_api_cls, patch("kubernetes.config.load_incluster_config"): + mock_api_cls.return_value.read_namespaced_secret.return_value = mock_secret + result = get_secret_from_k8s("databricks-token", "token", "production") + assert result is None + + def test_kubernetes_import_error(self): + """kubernetes package not installed - graceful fallback.""" + import builtins + + real_import = builtins.__import__ + + def mock_import(name, *args, **kwargs): + if name == "kubernetes" or name.startswith("kubernetes."): + raise ImportError("No module named 'kubernetes'") + return real_import(name, *args, **kwargs) + + with patch("builtins.__import__", side_effect=mock_import): + result = get_secret_from_k8s("databricks-token", "token", "production") + assert result is None + + def test_api_exception_non_404(self): + """Non-404 API exception is logged as warning, returns None.""" + from kubernetes.client.exceptions import ApiException + + with patch("kubernetes.client.CoreV1Api") as mock_api_cls, patch("kubernetes.config.load_incluster_config"): + mock_api_cls.return_value.read_namespaced_secret.side_effect = ApiException(status=403) + result = get_secret_from_k8s("databricks-token", "token", "restricted-ns") + assert result is None + + def test_kubeconfig_fallback(self): + """Falls back to kubeconfig when in-cluster config fails.""" + from kubernetes.config import ConfigException + + mock_secret = MagicMock() + mock_secret.data = {"token": base64.b64encode(b"local_token").decode()} + + with ( + patch("kubernetes.client.CoreV1Api") as mock_api_cls, + patch("kubernetes.config.load_incluster_config", side_effect=ConfigException("not in cluster")), + patch("kubernetes.config.load_kube_config"), + ): + mock_api_cls.return_value.read_namespaced_secret.return_value = mock_secret + result = get_secret_from_k8s("databricks-token", "token", "dev") + assert 
result == "local_token" + + def test_both_configs_fail(self): + """Both in-cluster and kubeconfig fail - returns None.""" + from kubernetes.config import ConfigException + + with ( + patch("kubernetes.config.load_incluster_config", side_effect=ConfigException("not in cluster")), + patch("kubernetes.config.load_kube_config", side_effect=Exception("no kubeconfig")), + ): + result = get_secret_from_k8s("databricks-token", "token", "orphan-ns") + assert result is None + + +# =========================================================================== +# Tests for get_databricks_token +# =========================================================================== + + +class TestGetDatabricksToken: + """Tests for the multi-tenant token resolution strategy.""" + + @patch("flytekitplugins.spark.connector.get_secret_from_k8s") + def test_token_from_namespace_secret(self, mock_k8s): + """Token found in namespace-specific K8s secret.""" + mock_k8s.return_value = "dapi_namespace_token" + token = get_databricks_token(namespace="team-a") + assert token == "dapi_namespace_token" + mock_k8s.assert_called_once_with( + secret_name="databricks-token", + secret_key="token", + namespace="team-a", + ) + + @patch("flytekitplugins.spark.connector.get_secret_from_k8s") + def test_token_with_custom_secret_name(self, mock_k8s): + """Token found using a custom K8s secret name.""" + mock_k8s.return_value = "dapi_custom_secret_token" + token = get_databricks_token(namespace="team-b", secret_name="my-db-token") + assert token == "dapi_custom_secret_token" + mock_k8s.assert_called_once_with( + secret_name="my-db-token", + secret_key="token", + namespace="team-b", + ) + + @patch("flytekitplugins.spark.connector.get_connector_secret") + @patch("flytekitplugins.spark.connector.get_secret_from_k8s") + def test_fallback_to_env_when_namespace_secret_missing(self, mock_k8s, mock_env): + """Falls back to FLYTE_DATABRICKS_ACCESS_TOKEN when namespace secret not found.""" + mock_k8s.return_value = None + 
mock_env.return_value = "dapi_env_fallback_token" + token = get_databricks_token(namespace="team-c") + assert token == "dapi_env_fallback_token" + mock_env.assert_called_once_with("FLYTE_DATABRICKS_ACCESS_TOKEN") + + @patch("flytekitplugins.spark.connector.get_connector_secret") + def test_no_namespace_falls_back_to_env(self, mock_env): + """When no namespace is provided, goes directly to env var.""" + mock_env.return_value = "dapi_default_token" + token = get_databricks_token(namespace=None) + assert token == "dapi_default_token" + mock_env.assert_called_once_with("FLYTE_DATABRICKS_ACCESS_TOKEN") + + @patch("flytekitplugins.spark.connector.get_connector_secret") + @patch("flytekitplugins.spark.connector.get_secret_from_k8s") + def test_no_token_from_any_source_raises(self, mock_k8s, mock_env): + """ValueError raised when neither K8s secret nor env var has a token.""" + mock_k8s.return_value = None + mock_env.side_effect = Exception("Secret not found") + with pytest.raises(ValueError, match="No Databricks token found from any source"): + get_databricks_token(namespace="orphan-ns") + + @patch("flytekitplugins.spark.connector.get_secret_from_k8s") + def test_empty_token_raises(self, mock_k8s): + """ValueError raised when token is an empty string.""" + mock_k8s.return_value = "" + with pytest.raises(ValueError, match="Databricks token is empty"): + get_databricks_token(namespace="team-d") + + @patch("flytekitplugins.spark.connector.get_connector_secret") + @patch("flytekitplugins.spark.connector.get_secret_from_k8s") + def test_default_secret_name_is_databricks_token(self, mock_k8s, mock_env): + """Default secret name is 'databricks-token' when not specified.""" + mock_k8s.return_value = "found_it" + get_databricks_token(namespace="ns-1") + mock_k8s.assert_called_once_with( + secret_name="databricks-token", + secret_key="token", + namespace="ns-1", + ) + + @patch("flytekitplugins.spark.connector.get_connector_secret") + def 
test_backward_compatibility_no_namespace_no_secret(self, mock_env): + """Backward compatible: no namespace + no secret name = env var only.""" + mock_env.return_value = "dapi_legacy_token" + token = get_databricks_token() + assert token == "dapi_legacy_token" + + +# =========================================================================== +# Tests for get_header +# =========================================================================== + + +class TestGetHeader: + """Tests for authorization header generation.""" + + def test_with_preresolved_auth_token(self): + """Header uses pre-fetched auth token directly.""" + headers = get_header(auth_token="dapi_preresolved_123") + assert headers == { + "Authorization": "Bearer dapi_preresolved_123", + "content-type": "application/json", + } + + @patch("flytekitplugins.spark.connector.get_databricks_token") + def test_without_auth_token_resolves(self, mock_get_token): + """Header resolves token via get_databricks_token when not provided.""" + mock_get_token.return_value = "dapi_resolved_456" + headers = get_header() + assert headers["Authorization"] == "Bearer dapi_resolved_456" + mock_get_token.assert_called_once() + + +# =========================================================================== +# Tests for DatabricksJobMetadata +# =========================================================================== + + +class TestDatabricksJobMetadata: + """Tests for token persistence in job metadata.""" + + def test_metadata_stores_auth_token(self): + """Auth token is stored in metadata.""" + meta = DatabricksJobMetadata( + databricks_instance="test.cloud.databricks.com", + run_id="42", + auth_token="dapi_persistent_token", + ) + assert meta.auth_token == "dapi_persistent_token" + + def test_metadata_auth_token_defaults_to_none(self): + """Auth token defaults to None for backward compatibility.""" + meta = DatabricksJobMetadata( + databricks_instance="test.cloud.databricks.com", + run_id="42", + ) + assert meta.auth_token is 
None + + +# =========================================================================== +# Tests for DatabricksConnector create/get/delete with token +# =========================================================================== + + +class TestDatabricksConnectorWithToken: + """Integration tests for connector operations with per-project tokens.""" + + @pytest.mark.asyncio + async def test_create_uses_namespace_token(self, task_template): + """create() fetches token from namespace and uses it in API call.""" + task_template.custom["databricksInstance"] = "test.cloud.databricks.com" + agent = AgentRegistry.get_agent("spark") + metadata = _make_task_execution_metadata("project-alpha") + + with patch("flytekitplugins.spark.connector.get_databricks_token") as mock_token: + mock_token.return_value = "dapi_project_alpha_token" + create_url = f"https://test.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/submit" + + with aioresponses() as mocked: + mocked.post(create_url, status=http.HTTPStatus.OK, payload={"run_id": "999"}) + result = await agent.create(task_template, None, task_execution_metadata=metadata) + + mock_token.assert_called_once_with( + namespace="project-alpha", + task_template=task_template, + secret_name=None, + ) + assert result.auth_token == "dapi_project_alpha_token" + assert result.run_id == "999" + + @pytest.mark.asyncio + async def test_create_with_custom_secret_name(self, task_template_with_custom_secret): + """create() passes custom secret name from task template.""" + tt = task_template_with_custom_secret + tt.custom["databricksInstance"] = "test.cloud.databricks.com" + agent = AgentRegistry.get_agent("spark") + metadata = _make_task_execution_metadata("team-x") + + with patch("flytekitplugins.spark.connector.get_databricks_token") as mock_token: + mock_token.return_value = "dapi_team_x_token" + create_url = f"https://test.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/submit" + + with aioresponses() as mocked: + mocked.post(create_url, 
status=http.HTTPStatus.OK, payload={"run_id": "888"}) + result = await agent.create(tt, None, task_execution_metadata=metadata) + + mock_token.assert_called_once_with( + namespace="team-x", + task_template=tt, + secret_name="my-team-databricks-token", + ) + assert result.auth_token == "dapi_team_x_token" + + @pytest.mark.asyncio + async def test_create_without_metadata_uses_no_namespace(self, task_template): + """create() works when task_execution_metadata is None (backward compat).""" + task_template.custom["databricksInstance"] = "test.cloud.databricks.com" + agent = AgentRegistry.get_agent("spark") + + with patch("flytekitplugins.spark.connector.get_databricks_token") as mock_token: + mock_token.return_value = "dapi_default_token" + create_url = f"https://test.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/submit" + + with aioresponses() as mocked: + mocked.post(create_url, status=http.HTTPStatus.OK, payload={"run_id": "777"}) + await agent.create(task_template, None) + + mock_token.assert_called_once_with( + namespace=None, + task_template=task_template, + secret_name=None, + ) + + @pytest.mark.asyncio + async def test_create_stores_token_in_metadata(self, task_template): + """create() persists auth_token in returned DatabricksJobMetadata.""" + task_template.custom["databricksInstance"] = "test.cloud.databricks.com" + agent = AgentRegistry.get_agent("spark") + metadata = _make_task_execution_metadata("data-team") + + with patch("flytekitplugins.spark.connector.get_databricks_token") as mock_token: + mock_token.return_value = "dapi_data_team_abc" + create_url = f"https://test.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/submit" + + with aioresponses() as mocked: + mocked.post(create_url, status=http.HTTPStatus.OK, payload={"run_id": "555"}) + result = await agent.create(task_template, None, task_execution_metadata=metadata) + + assert isinstance(result, DatabricksJobMetadata) + assert result.auth_token == "dapi_data_team_abc" + assert result.run_id == 
"555" + assert result.databricks_instance == "test.cloud.databricks.com" + + @pytest.mark.asyncio + async def test_get_uses_stored_token(self): + """get() uses the auth_token stored in DatabricksJobMetadata.""" + meta = DatabricksJobMetadata( + databricks_instance="test.cloud.databricks.com", + run_id="123", + auth_token="dapi_stored_get_token", + ) + agent = AgentRegistry.get_agent("spark") + get_url = f"https://test.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/get?run_id=123" + + mock_response = { + "job_id": "1", + "run_id": "123", + "state": {"life_cycle_state": "TERMINATED", "result_state": "SUCCESS", "state_message": "OK"}, + } + + with aioresponses() as mocked: + mocked.get(get_url, status=http.HTTPStatus.OK, payload=mock_response) + await agent.get(meta) + + # Verify the correct token was used in the request header + call_args = list(mocked.requests.values())[0][0] + assert call_args.kwargs["headers"]["Authorization"] == "Bearer dapi_stored_get_token" + + @pytest.mark.asyncio + async def test_delete_uses_stored_token(self): + """delete() uses the auth_token stored in DatabricksJobMetadata.""" + meta = DatabricksJobMetadata( + databricks_instance="test.cloud.databricks.com", + run_id="456", + auth_token="dapi_stored_delete_token", + ) + agent = AgentRegistry.get_agent("spark") + delete_url = f"https://test.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/cancel" + + with aioresponses() as mocked: + mocked.post(delete_url, status=http.HTTPStatus.OK, payload={}) + await agent.delete(meta) + + call_args = list(mocked.requests.values())[0][0] + assert call_args.kwargs["headers"]["Authorization"] == "Bearer dapi_stored_delete_token" + + @pytest.mark.asyncio + async def test_get_with_none_token_falls_back(self): + """get() falls back to get_header default when auth_token is None.""" + meta = DatabricksJobMetadata( + databricks_instance="test.cloud.databricks.com", + run_id="789", + auth_token=None, + ) + agent = AgentRegistry.get_agent("spark") + get_url = 
f"https://test.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/get?run_id=789" + + mock_response = { + "job_id": "2", + "run_id": "789", + "state": {"life_cycle_state": "RUNNING"}, + } + + with patch("flytekitplugins.spark.connector.get_databricks_token") as mock_token: + mock_token.return_value = "dapi_fallback_env_token" + with aioresponses() as mocked: + mocked.get(get_url, status=http.HTTPStatus.OK, payload=mock_response) + await agent.get(meta) + + call_args = list(mocked.requests.values())[0][0] + assert call_args.kwargs["headers"]["Authorization"] == "Bearer dapi_fallback_env_token" + + +# =========================================================================== +# Tests for DatabricksV2 task config token secret +# =========================================================================== + + +class TestDatabricksV2TokenSecret: + """Tests for DatabricksV2 databricks_token_secret field and serialization.""" + + def test_token_secret_field_exists(self): + """DatabricksV2 has databricks_token_secret field.""" + from flytekitplugins.spark.task import DatabricksV2 + + config = DatabricksV2( + databricks_conf={"new_cluster": {"spark_version": "12.2.x"}}, + databricks_instance="test.cloud.databricks.com", + databricks_token_secret="my-team-secret", + ) + assert config.databricks_token_secret == "my-team-secret" + + def test_token_secret_defaults_to_none(self): + """databricks_token_secret defaults to None.""" + from flytekitplugins.spark.task import DatabricksV2 + + config = DatabricksV2( + databricks_conf={"new_cluster": {"spark_version": "12.2.x"}}, + databricks_instance="test.cloud.databricks.com", + ) + assert config.databricks_token_secret is None + + def test_get_custom_includes_token_secret(self): + """get_custom() serializes databricksTokenSecret when set.""" + from flytekitplugins.spark.task import DatabricksV2 + + import flytekit + from flytekit.configuration import Image, ImageConfig, SerializationSettings + + databricks_conf = { + "name": "test", 
+ "new_cluster": { + "spark_version": "12.2.x-scala2.12", + "node_type_id": "r3.xlarge", + "num_workers": 1, + "docker_image": {"url": "test:latest"}, + }, + "timeout_seconds": 3600, + "max_retries": 1, + } + + @flytekit.task( + task_config=DatabricksV2( + databricks_conf=databricks_conf, + databricks_instance="test.cloud.databricks.com", + databricks_token_secret="project-x-token", + ) + ) + def my_task(x: int) -> int: + return x + + default_img = Image(name="default", fqn="test", tag="tag") + settings = SerializationSettings( + project="project", + domain="domain", + version="version", + env={}, + image_config=ImageConfig(default_image=default_img, images=[default_img]), + ) + + custom = my_task.get_custom(settings) + assert "databricksTokenSecret" in custom + assert custom["databricksTokenSecret"] == "project-x-token" + + def test_get_custom_excludes_token_secret_when_none(self): + """get_custom() does NOT include databricksTokenSecret when None.""" + from flytekitplugins.spark.task import DatabricksV2 + + import flytekit + from flytekit.configuration import Image, ImageConfig, SerializationSettings + + databricks_conf = { + "name": "test", + "new_cluster": { + "spark_version": "12.2.x-scala2.12", + "node_type_id": "r3.xlarge", + "num_workers": 1, + "docker_image": {"url": "test:latest"}, + }, + } + + @flytekit.task( + task_config=DatabricksV2( + databricks_conf=databricks_conf, + databricks_instance="test.cloud.databricks.com", + ) + ) + def my_task(x: int) -> int: + return x + + default_img = Image(name="default", fqn="test", tag="tag") + settings = SerializationSettings( + project="project", + domain="domain", + version="version", + env={}, + image_config=ImageConfig(default_image=default_img, images=[default_img]), + ) + + custom = my_task.get_custom(settings) + assert "databricksTokenSecret" not in custom diff --git a/pydoclint-errors-baseline.txt b/pydoclint-errors-baseline.txt index 606633e114..521eb5cafb 100644 --- a/pydoclint-errors-baseline.txt +++ 
b/pydoclint-errors-baseline.txt @@ -578,10 +578,6 @@ plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py plugins/flytekit-spark/flytekitplugins/spark/models.py DOC301: Class `SparkJob`: __init__() should not have a docstring; please combine it with the docstring of the class -------------------- -plugins/flytekit-spark/flytekitplugins/spark/task.py - DOC601: Class `DatabricksV2`: Class docstring contains fewer class attributes than actual class attributes. (Please read https://jsh9.github.io/pydoclint/checking_class_attributes.html on how to correctly document class attributes.) - DOC603: Class `DatabricksV2`: Class docstring attributes are different from actual class attributes. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Attributes in the class definition but not in the docstring: [databricks_conf: Optional[Dict[str, Union[str, dict]]], databricks_instance: Optional[str]]. (Please read https://jsh9.github.io/pydoclint/checking_class_attributes.html on how to correctly document class attributes.) --------------------- plugins/flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py DOC601: Class `SQLAlchemyConfig`: Class docstring contains fewer class attributes than actual class attributes. (Please read https://jsh9.github.io/pydoclint/checking_class_attributes.html on how to correctly document class attributes.) DOC603: Class `SQLAlchemyConfig`: Class docstring attributes are different from actual class attributes. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Attributes in the class definition but not in the docstring: [connect_args: typing.Optional[typing.Dict[str, typing.Any]], secret_connect_args: typing.Optional[typing.Dict[str, Secret]], uri: str]. (Please read https://jsh9.github.io/pydoclint/checking_class_attributes.html on how to correctly document class attributes.)