diff --git a/plugins/flytekit-spark/flytekitplugins/spark/connector.py b/plugins/flytekit-spark/flytekitplugins/spark/connector.py index 5adb9157c4..06b02048d1 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/connector.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/connector.py @@ -1,5 +1,6 @@ import http import json +import logging import os import typing from dataclasses import dataclass @@ -13,12 +14,14 @@ from flytekit.extend.backend.utils import convert_to_flyte_phase, get_connector_secret from flytekit.models.core.execution import TaskLog from flytekit.models.literals import LiteralMap -from flytekit.models.task import TaskTemplate +from flytekit.models.task import TaskExecutionMetadata, TaskTemplate from .utils import is_serverless_config as _is_serverless_config aiohttp = lazy_module("aiohttp") +logger = logging.getLogger(__name__) + DATABRICKS_API_ENDPOINT = "/api/2.1/jobs" DEFAULT_DATABRICKS_INSTANCE_ENV_KEY = "FLYTE_DATABRICKS_INSTANCE" DEFAULT_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER_ENV_KEY = "FLYTE_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER" @@ -28,6 +31,7 @@ class DatabricksJobMetadata(ResourceMeta): databricks_instance: str run_id: str + auth_token: Optional[str] = None # Store auth token for get/delete operations def _configure_serverless(databricks_job: dict, envs: dict) -> str: @@ -252,7 +256,11 @@ def __init__(self): super().__init__(task_type_name="spark", metadata_type=DatabricksJobMetadata) async def create( - self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs + self, + task_template: TaskTemplate, + inputs: Optional[LiteralMap] = None, + task_execution_metadata: Optional[TaskExecutionMetadata] = None, + **kwargs, ) -> DatabricksJobMetadata: data = json.dumps(_get_databricks_job_spec(task_template)) databricks_instance = task_template.custom.get( @@ -264,15 +272,31 @@ async def create( f"Missing databricks instance. Please set the value through the task config or set the {DEFAULT_DATABRICKS_INSTANCE_ENV_KEY} environment variable in the connector." ) + # Get workflow-specific token or fall back to default + namespace = task_execution_metadata.namespace if task_execution_metadata else None + + # Extract custom secret name from task template (if provided) + custom_secret_name = task_template.custom.get("databricksTokenSecret") + + logger.info(f"Creating Databricks job for namespace: {namespace or 'unknown'}") + if custom_secret_name: + logger.info(f"Using custom secret name: {custom_secret_name}") + + auth_token = get_databricks_token( + namespace=namespace, task_template=task_template, secret_name=custom_secret_name + ) databricks_url = f"https://{databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/submit" async with aiohttp.ClientSession() as session: - async with session.post(databricks_url, headers=get_header(), data=data) as resp: + async with session.post(databricks_url, headers=get_header(auth_token=auth_token), data=data) as resp: response = await resp.json() if resp.status != http.HTTPStatus.OK: raise RuntimeError(f"Failed to create databricks job with error: {response}") - return DatabricksJobMetadata(databricks_instance=databricks_instance, run_id=str(response["run_id"])) + logger.info(f"Successfully created Databricks job with run_id: {response['run_id']}") + return DatabricksJobMetadata( + databricks_instance=databricks_instance, run_id=str(response["run_id"]), auth_token=auth_token + ) async def get(self, resource_meta: DatabricksJobMetadata, **kwargs) -> Resource: databricks_instance = resource_meta.databricks_instance @@ -280,8 +304,11 @@ async def get(self, resource_meta: DatabricksJobMetadata, **kwargs) -> Resource: f"https://{databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/get?run_id={resource_meta.run_id}" ) + # Use the stored auth token if available, otherwise fall back to default + headers = get_header(auth_token=resource_meta.auth_token) + async with aiohttp.ClientSession() as session: - async with session.get(databricks_url, headers=get_header()) as resp: + async with session.get(databricks_url, headers=headers) as resp: if resp.status != http.HTTPStatus.OK: raise RuntimeError(f"Failed to get databricks job {resource_meta.run_id} with error: {resp.reason}") response = await resp.json() @@ -312,8 +339,11 @@ async def delete(self, resource_meta: DatabricksJobMetadata, **kwargs): databricks_url = f"https://{resource_meta.databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/cancel" data = json.dumps({"run_id": resource_meta.run_id}) + # Use the stored auth token if available, otherwise fall back to default + headers = get_header(auth_token=resource_meta.auth_token) + async with aiohttp.ClientSession() as session: - async with session.post(databricks_url, headers=get_header(), data=data) as resp: + async with session.post(databricks_url, headers=headers, data=data) as resp: if resp.status != http.HTTPStatus.OK: raise RuntimeError( f"Failed to cancel databricks job {resource_meta.run_id} with error: {resp.reason}" @@ -334,9 +364,139 @@ def __init__(self): super(DatabricksConnector, self).__init__(task_type_name="databricks", metadata_type=DatabricksJobMetadata) -def get_header() -> typing.Dict[str, str]: - token = get_connector_secret("FLYTE_DATABRICKS_ACCESS_TOKEN") - return {"Authorization": f"Bearer {token}", "content-type": "application/json"} +def get_secret_from_k8s(secret_name: str, secret_key: str, namespace: str) -> Optional[str]: + """Read a secret from Kubernetes using the Kubernetes Python client. + + Args: + secret_name (str): Name of the Kubernetes secret (e.g., "databricks-token"). + secret_key (str): Key within the secret (e.g., "token"). + namespace (str): Kubernetes namespace where the secret is stored. + + Returns: + Optional[str]: The secret value as a string, or None if not found. + """ + try: + import base64 + + from kubernetes import client, config + + # Try to load in-cluster config first (when running in K8s) + try: + config.load_incluster_config() + except config.ConfigException: + # Fall back to kubeconfig (for local testing) + try: + config.load_kube_config() + except Exception as e: + logger.warning(f"Failed to load Kubernetes config: {e}") + return None + + v1 = client.CoreV1Api() + + try: + secret = v1.read_namespaced_secret(name=secret_name, namespace=namespace) + if secret.data and secret_key in secret.data: + # Kubernetes secrets are base64 encoded + secret_value = base64.b64decode(secret.data[secret_key]).decode("utf-8") + return secret_value + else: + logger.debug( + f"Secret '{secret_name}' exists but key '{secret_key}' not found in namespace '{namespace}'" + ) + return None + except client.exceptions.ApiException as e: + if e.status == 404: + logger.debug(f"Secret '{secret_name}' not found in namespace '{namespace}'") + else: + logger.warning(f"Error reading secret '{secret_name}' from namespace '{namespace}': {e}") + return None + + except ImportError: + logger.warning("kubernetes Python package not installed - cannot read namespace secrets") + return None + except Exception as e: + logger.warning(f"Unexpected error reading K8s secret: {e}") + return None + + +def get_databricks_token( + namespace: Optional[str] = None, task_template: Optional[TaskTemplate] = None, secret_name: Optional[str] = None +) -> str: + """Get the Databricks access token with multi-tenant support. + + Token resolution: namespace K8s secret -> FLYTE_DATABRICKS_ACCESS_TOKEN env var. + + Args: + namespace (Optional[str]): Kubernetes namespace for workflow-specific token lookup. + task_template (Optional[TaskTemplate]): Optional TaskTemplate (kept for API compatibility). + secret_name (Optional[str]): Custom secret name. Defaults to 'databricks-token'. + + Returns: + str: The Databricks access token. + + Raises: + ValueError: If no token is found from any source. + """ + token = None + token_source = "unknown" + + # Use custom secret name or default to 'databricks-token' + k8s_secret_name = secret_name or "databricks-token" + + # Step 1: Try namespace-specific K8s secret (cross-namespace lookup) + if namespace: + logger.info(f"Looking for Databricks token in workflow namespace: {namespace} (secret: {k8s_secret_name})") + token = get_secret_from_k8s(secret_name=k8s_secret_name, secret_key="token", namespace=namespace) + + if token: + logger.info(f"Found Databricks token in namespace '{namespace}' from secret '{k8s_secret_name}'") + token_source = f"k8s_namespace:{namespace}/secret:{k8s_secret_name}" + else: + logger.info( + f"Databricks token not found in secret '{k8s_secret_name}' in namespace '{namespace}' - trying fallback" + ) + else: + logger.info("No namespace provided for cross-namespace lookup") + + # Step 2: Fall back to environment variable (backward compatibility) + if token is None: + logger.info("Falling back to default Databricks token (FLYTE_DATABRICKS_ACCESS_TOKEN)") + try: + token = get_connector_secret("FLYTE_DATABRICKS_ACCESS_TOKEN") + token_source = "env_variable" + except Exception as e: + logger.error(f"Failed to get default Databricks token: {e}") + raise ValueError( + "No Databricks token found from any source:\n" + f"1. Namespace-specific K8s secret '{k8s_secret_name}'\n" + "2. FLYTE_DATABRICKS_ACCESS_TOKEN environment variable\n" + f"Workflow namespace: {namespace or 'N/A'}" + ) + + if not token: + raise ValueError("Databricks token is empty") + + # Log token info without exposing the actual token value + token_preview = f"{token[:8]}..." if len(token) > 8 else "***" + logger.info(f"Using Databricks token from: {token_source} (preview: {token_preview})") + + return token + + +def get_header(task_template: Optional[TaskTemplate] = None, auth_token: Optional[str] = None) -> typing.Dict[str, str]: + """Get the authorization header for Databricks API calls. + + Args: + task_template (Optional[TaskTemplate]): TaskTemplate with workflow-specific secret requests. + auth_token (Optional[str]): Pre-fetched auth token to use directly. + + Returns: + typing.Dict[str, str]: Authorization and content-type headers. + """ + if auth_token is None: + auth_token = get_databricks_token(task_template) + + return {"Authorization": f"Bearer {auth_token}", "content-type": "application/json"} def result_state_is_available(life_cycle_state: str) -> bool: diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index 6492183969..6c447e8bd4 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -87,6 +87,8 @@ class DatabricksV2(Spark): Use the form .cloud.databricks.com. databricks_service_credential_provider (Optional[str]): Provider name for Databricks Service Credentials for S3 access. Falls back to FLYTE_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER env var. + databricks_token_secret (Optional[str]): Custom name for the K8s secret containing + the Databricks token. Defaults to 'databricks-token' if not specified. notebook_path (Optional[str]): Path to Databricks notebook (e.g., "/Users/user@example.com/notebook"). notebook_base_parameters (Optional[Dict[str, str]]): Parameters to pass to the notebook. @@ -194,12 +196,11 @@ class DatabricksV2(Spark): """ databricks_conf: Optional[Dict[str, Union[str, dict]]] = None - databricks_instance: Optional[str] = None # Falls back to FLYTE_DATABRICKS_INSTANCE env var - databricks_service_credential_provider: Optional[str] = ( - None # Falls back to FLYTE_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER env var - ) - notebook_path: Optional[str] = None # Path to Databricks notebook (e.g., "/Users/user@example.com/notebook") - notebook_base_parameters: Optional[Dict[str, str]] = None # Parameters to pass to the notebook + databricks_instance: Optional[str] = None + databricks_service_credential_provider: Optional[str] = None + databricks_token_secret: Optional[str] = None + notebook_path: Optional[str] = None + notebook_base_parameters: Optional[Dict[str, str]] = None # This method does not reset the SparkSession since it's a bit hard to handle multiple @@ -311,6 +312,8 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: cfg = cast(DatabricksV2, self.task_config) if cfg.databricks_service_credential_provider: custom_dict["databricksServiceCredentialProvider"] = cfg.databricks_service_credential_provider + if cfg.databricks_token_secret: + custom_dict["databricksTokenSecret"] = cfg.databricks_token_secret if cfg.notebook_path: custom_dict["notebookPath"] = cfg.notebook_path if cfg.notebook_base_parameters: @@ -479,7 +482,7 @@ def execute(self, **kwargs) -> Any: if ctx.execution_state and ctx.execution_state.is_local_execution(): return AsyncConnectorExecutorMixin.execute(self, **kwargs) except Exception as e: - click.secho(f"❌ Connector failed to run the task with error: {e}", fg="red") + click.secho(f"Connector failed to run the task with error: {e}", fg="red") click.secho("Falling back to local execution", fg="red") return PythonFunctionTask.execute(self, **kwargs) diff --git a/plugins/flytekit-spark/tests/test_connector.py b/plugins/flytekit-spark/tests/test_connector.py index b2828de9fa..0fc4effbb9 100644 --- a/plugins/flytekit-spark/tests/test_connector.py +++ b/plugins/flytekit-spark/tests/test_connector.py @@ -1,29 +1,28 @@ import http import json +import os from datetime import timedelta from unittest import mock import pytest from aioresponses import aioresponses from flyteidl.core.execution_pb2 import TaskExecution - -from flytekit.core.constants import FLYTE_FAIL_ON_ERROR from flytekitplugins.spark.connector import ( DATABRICKS_API_ENDPOINT, + DEFAULT_DATABRICKS_INSTANCE_ENV_KEY, DatabricksJobMetadata, - get_header, + _configure_serverless, _get_databricks_job_spec, _is_serverless_config, - _configure_serverless, - DEFAULT_DATABRICKS_INSTANCE_ENV_KEY, + get_header, ) +from flytekit.core.constants import FLYTE_FAIL_ON_ERROR from flytekit.extend.backend.base_agent import AgentRegistry from flytekit.interfaces.cli_identifiers import Identifier from flytekit.models import literals, task from flytekit.models.core.identifier import ResourceType from flytekit.models.task import Container, Resources, TaskTemplate -import os @pytest.fixture(scope="function") @@ -61,7 +60,7 @@ def task_template() -> TaskTemplate: }, "timeout_seconds": 3600, "max_retries": 1, - } + }, } container = Container( image="flyteorg/flytekit:databricks-0.18.0-py3.7", @@ -125,6 +124,7 @@ async def test_databricks_agent(task_template: TaskTemplate): databricks_metadata = DatabricksJobMetadata( databricks_instance="test-account.cloud.databricks.com", run_id="123", + auth_token=mocked_token, ) mock_create_response = {"run_id": "123"} @@ -142,7 +142,7 @@ async def test_databricks_agent(task_template: TaskTemplate): res = await agent.create(task_template, None) spec = _get_databricks_job_spec(task_template) data = json.dumps(spec) - mocked.assert_called_with(create_url, method="POST", data=data, headers=get_header()) + mocked.assert_called_with(create_url, method="POST", data=data, headers=get_header(auth_token=mocked_token)) spark_envs = spec["new_cluster"]["spark_env_vars"] assert spark_envs["foo"] == "bar" assert spark_envs[FLYTE_FAIL_ON_ERROR] == "true" @@ -159,7 +159,10 @@ async def test_databricks_agent(task_template: TaskTemplate): mocked.post(delete_url, status=http.HTTPStatus.OK, payload=mock_delete_response) await agent.delete(databricks_metadata) - assert get_header() == {"Authorization": f"Bearer {mocked_token}", "content-type": "application/json"} + assert get_header(auth_token=mocked_token) == { + "Authorization": f"Bearer {mocked_token}", + "content-type": "application/json", + } mock.patch.stopall() @@ -168,7 +171,7 @@ async def test_databricks_agent(task_template: TaskTemplate): async def test_agent_create_with_no_instance(task_template: TaskTemplate): agent = AgentRegistry.get_agent("spark") - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError): await agent.create(task_template, None) @@ -183,6 +186,7 @@ async def test_agent_create_with_default_instance(task_template: TaskTemplate): databricks_metadata = DatabricksJobMetadata( databricks_instance="test-account.cloud.databricks.com", run_id="123", + auth_token=mocked_token, ) mock_create_response = {"run_id": "123"} @@ -195,7 +199,7 @@ async def test_agent_create_with_default_instance(task_template: TaskTemplate): res = await agent.create(task_template, None) spec = _get_databricks_job_spec(task_template) data = json.dumps(spec) - mocked.assert_called_with(create_url, method="POST", data=data, headers=get_header()) + mocked.assert_called_with(create_url, method="POST", data=data, headers=get_header(auth_token=mocked_token)) assert res == databricks_metadata mock.patch.stopall() @@ -235,7 +239,7 @@ def serverless_task_template_with_env_key() -> TaskTemplate: "git_branch": "main", }, "python_file": "entrypoint_serverless.py", - } + }, } container = Container( image="flyteorg/flytekit:databricks-0.18.0-py3.7", @@ -280,13 +284,15 @@ def serverless_task_template_with_inline_env() -> TaskTemplate: "databricksConf": { "run_name": "flytekit serverless job with inline env", "environment_key": "default", - "environments": [{ - "environment_key": "default", - "spec": { - "client": "1", - "dependencies": ["pandas==2.0.0"], + "environments": [ + { + "environment_key": "default", + "spec": { + "client": "1", + "dependencies": ["pandas==2.0.0"], + }, } - }], + ], "timeout_seconds": 3600, "git_source": { "git_url": "https://github.com/test-org/test-repo", @@ -294,7 +300,7 @@ def serverless_task_template_with_inline_env() -> TaskTemplate: "git_branch": "main", }, "python_file": "entrypoint_serverless.py", - } + }, } container = Container( image="flyteorg/flytekit:databricks-0.18.0-py3.7", @@ -339,14 +345,16 @@ def serverless_task_template_no_git_source() -> TaskTemplate: "databricksConf": { "run_name": "flytekit serverless job - no git source", "environment_key": "default", - "environments": [{ - "environment_key": "default", - "spec": { - "client": "4", + "environments": [ + { + "environment_key": "default", + "spec": { + "client": "4", + }, } - }], + ], "timeout_seconds": 3600, - } + }, } container = Container( image="flyteorg/flytekit:databricks-0.18.0-py3.7", @@ -391,7 +399,7 @@ def invalid_task_template_no_compute() -> TaskTemplate: "databricksConf": { "run_name": "invalid job - no compute config", "timeout_seconds": 3600, - } + }, } container = Container( image="flyteorg/flytekit:databricks-0.18.0-py3.7", @@ -427,19 +435,15 @@ def test_is_serverless_config_detection(): assert _is_serverless_config({"environments": [{"environment_key": "default"}]}) is True # Serverless with both environment_key and environments - assert _is_serverless_config({ - "environment_key": "default", - "environments": [{"environment_key": "default"}] - }) is True + assert ( + _is_serverless_config({"environment_key": "default", "environments": [{"environment_key": "default"}]}) is True + ) # No compute config at all assert _is_serverless_config({"run_name": "test"}) is False # Has cluster AND environment (cluster takes precedence, not serverless) - assert _is_serverless_config({ - "new_cluster": {"spark_version": "13.3"}, - "environment_key": "my-env" - }) is False + assert _is_serverless_config({"new_cluster": {"spark_version": "13.3"}, "environment_key": "my-env"}) is False def test_configure_serverless_with_env_key_only(): @@ -466,13 +470,15 @@ def test_configure_serverless_with_inline_env(): """Test serverless configuration with inline environment spec.""" databricks_job = { "environment_key": "default", - "environments": [{ - "environment_key": "default", - "spec": { - "client": "1", - "dependencies": ["pandas==2.0.0"], + "environments": [ + { + "environment_key": "default", + "spec": { + "client": "1", + "dependencies": ["pandas==2.0.0"], + }, } - }] + ], } envs = {"FOO": "bar", FLYTE_FAIL_ON_ERROR: "true"} @@ -577,6 +583,7 @@ def test_get_databricks_job_spec_error_no_compute(invalid_task_template_no_compu async def test_databricks_agent_serverless(serverless_task_template_with_env_key: TaskTemplate): """Test the full agent flow with serverless compute.""" import copy + agent = AgentRegistry.get_agent("spark") serverless_task_template_with_env_key.custom["databricksInstance"] = "test-account.cloud.databricks.com" @@ -601,6 +608,7 @@ async def test_databricks_agent_serverless(serverless_task_template_with_env_key databricks_metadata = DatabricksJobMetadata( databricks_instance="test-account.cloud.databricks.com", run_id="456", + auth_token=mocked_token, ) mock_create_response = {"run_id": "456"} @@ -659,7 +667,9 @@ def test_serverless_task_git_source_overrides_default(serverless_task_template_w assert task_def["spark_python_task"]["python_file"] == "entrypoint_serverless.py" -def test_classic_and_serverless_use_same_repo(task_template: TaskTemplate, serverless_task_template_no_git_source: TaskTemplate): +def test_classic_and_serverless_use_same_repo( + task_template: TaskTemplate, serverless_task_template_no_git_source: TaskTemplate +): """Test that both classic and serverless default to the same flytetools repo.""" classic_spec = _get_databricks_job_spec(task_template) serverless_spec = _get_databricks_job_spec(serverless_task_template_no_git_source) @@ -670,4 +680,7 @@ def test_classic_and_serverless_use_same_repo(task_template: TaskTemplate, serve assert classic_spec["git_source"]["git_commit"] == serverless_spec["git_source"]["git_commit"] # Different python_file assert classic_spec["spark_python_task"]["python_file"] == "flytekitplugins/databricks/entrypoint.py" - assert serverless_spec["tasks"][0]["spark_python_task"]["python_file"] == "flytekitplugins/databricks/entrypoint_serverless.py" + assert ( + serverless_spec["tasks"][0]["spark_python_task"]["python_file"] + == "flytekitplugins/databricks/entrypoint_serverless.py" + ) diff --git a/plugins/flytekit-spark/tests/test_databricks_token.py b/plugins/flytekit-spark/tests/test_databricks_token.py new file mode 100644 index 0000000000..9bed13bb56 --- /dev/null +++ b/plugins/flytekit-spark/tests/test_databricks_token.py @@ -0,0 +1,654 @@ +""" +Comprehensive tests for the Databricks per-project token support feature. + +Tests cover: +- get_secret_from_k8s: Cross-namespace K8s secret reading +- get_databricks_token: Token resolution strategy (K8s -> env var fallback) +- get_header: Authorization header generation with token support +- DatabricksConnector.create/get/delete: Token persistence in job metadata +- DatabricksV2 task config: databricks_token_secret serialization +""" + +import base64 +import http +from datetime import timedelta +from unittest.mock import MagicMock, patch + +import pytest +from aioresponses import aioresponses +from flytekitplugins.spark.connector import ( + DATABRICKS_API_ENDPOINT, + DatabricksJobMetadata, + get_databricks_token, + get_header, + get_secret_from_k8s, +) + +from flytekit.extend.backend.base_agent import AgentRegistry +from flytekit.interfaces.cli_identifiers import Identifier +from flytekit.models import literals, task +from flytekit.models.core.identifier import ResourceType +from flytekit.models.task import Container, Resources, TaskExecutionMetadata, TaskTemplate + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="function") +def task_template() -> TaskTemplate: + """Standard Databricks task template for testing.""" + task_id = Identifier( + resource_type=ResourceType.TASK, + project="project", + domain="domain", + name="name", + version="version", + ) + task_metadata = task.TaskMetadata( + True, + task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + timedelta(days=1), + literals.RetryStrategy(3), + True, + "0.1.1b0", + "This is deprecated!", + True, + "A", + (), + ) + task_config = { + "sparkConf": { + "spark.driver.memory": "1000M", + "spark.executor.memory": "1000M", + "spark.executor.cores": "1", + "spark.executor.instances": "2", + "spark.driver.cores": "1", + }, + "mainApplicationFile": "dbfs:/entrypoint.py", + "databricksConf": { + "run_name": "flytekit databricks plugin example", + "new_cluster": { + "spark_version": "12.2.x-scala2.12", + "node_type_id": "n2-highmem-4", + "num_workers": 1, + }, + "timeout_seconds": 3600, + "max_retries": 1, + }, + } + container = Container( + image="flyteorg/flytekit:databricks-0.18.0-py3.7", + command=[], + args=[ + "pyflyte-fast-execute", + "--additional-distribution", + "s3://my-s3-bucket/flytesnacks/development/24UYJEF2HDZQN3SG4VAZSM4PLI======/script_mode.tar.gz", + "--dest-dir", + "/root", + "--", + "pyflyte-execute", + "--inputs", + "s3://my-s3-bucket", + "--output-prefix", + "s3://my-s3-bucket", + "--raw-output-data-prefix", + "s3://my-s3-bucket", + "--checkpoint-path", + "s3://my-s3-bucket", + "--prev-checkpoint", + "s3://my-s3-bucket", + "--resolver", + "flytekit.core.python_auto_container.default_task_resolver", + "--", + "task-module", + "spark_local_example", + "task-name", + "hello_spark", + ], + resources=Resources(requests=[], limits=[]), + env={"foo": "bar"}, + config={}, + ) + return TaskTemplate( + id=task_id, + custom=task_config, + metadata=task_metadata, + container=container, + interface=None, + type="spark", + ) + + +@pytest.fixture(scope="function") +def task_template_with_custom_secret(task_template) -> TaskTemplate: + """Task template with a custom databricksTokenSecret.""" + task_template.custom["databricksTokenSecret"] = "my-team-databricks-token" + return task_template + + +def _make_task_execution_metadata(namespace: str) -> TaskExecutionMetadata: + """Helper to create a TaskExecutionMetadata with a given namespace.""" + task_exec_id = MagicMock() + return TaskExecutionMetadata( + task_execution_id=task_exec_id, + namespace=namespace, + labels={}, + annotations={}, + k8s_service_account="default", + environment_variables={}, + identity=MagicMock(), + ) + + +# =========================================================================== +# Tests for get_secret_from_k8s +# =========================================================================== + + +class TestGetSecretFromK8s: + """Tests for cross-namespace Kubernetes secret reading.""" + + def test_success(self): + """Secret found via mocked Kubernetes API client.""" + mock_secret = MagicMock() + mock_secret.data = {"token": base64.b64encode(b"dapi_real_token_xyz").decode()} + + with patch("kubernetes.client.CoreV1Api") as mock_api_cls, patch("kubernetes.config.load_incluster_config"): + mock_api_cls.return_value.read_namespaced_secret.return_value = mock_secret + result = get_secret_from_k8s("databricks-token", "token", "production") + assert result == "dapi_real_token_xyz" + mock_api_cls.return_value.read_namespaced_secret.assert_called_once_with( + name="databricks-token", namespace="production" + ) + + def test_secret_not_found_404(self): + """Secret doesn't exist in the namespace (404).""" + from kubernetes.client.exceptions import ApiException + + with patch("kubernetes.client.CoreV1Api") as mock_api_cls, patch("kubernetes.config.load_incluster_config"): + mock_api_cls.return_value.read_namespaced_secret.side_effect = ApiException(status=404) + result = get_secret_from_k8s("databricks-token", "token", "staging") + assert result is None + + def test_secret_key_missing(self): + """Secret exists but the 'token' key is not present.""" + mock_secret = MagicMock() + mock_secret.data = {"other_key": base64.b64encode(b"something").decode()} + + with patch("kubernetes.client.CoreV1Api") as mock_api_cls, patch("kubernetes.config.load_incluster_config"): + mock_api_cls.return_value.read_namespaced_secret.return_value = mock_secret + result = get_secret_from_k8s("databricks-token", "token", "production") + assert result is None + + def test_secret_data_is_none(self): + """Secret exists but data field is None.""" + mock_secret = MagicMock() + mock_secret.data = None + + with patch("kubernetes.client.CoreV1Api") as mock_api_cls, patch("kubernetes.config.load_incluster_config"): + mock_api_cls.return_value.read_namespaced_secret.return_value = mock_secret + result = get_secret_from_k8s("databricks-token", "token", "production") + assert result is None + + def test_kubernetes_import_error(self): + """kubernetes package not installed - graceful fallback.""" + import builtins + + real_import = builtins.__import__ + + def mock_import(name, *args, **kwargs): + if name == "kubernetes" or name.startswith("kubernetes."): + raise ImportError("No module named 'kubernetes'") + return real_import(name, *args, **kwargs) + + with patch("builtins.__import__", side_effect=mock_import): + result = get_secret_from_k8s("databricks-token", "token", "production") + assert result is None + + def test_api_exception_non_404(self): + """Non-404 API exception is logged as warning, returns None.""" + from kubernetes.client.exceptions import ApiException + + with patch("kubernetes.client.CoreV1Api") as mock_api_cls, patch("kubernetes.config.load_incluster_config"): + mock_api_cls.return_value.read_namespaced_secret.side_effect = ApiException(status=403) + result = get_secret_from_k8s("databricks-token", "token", "restricted-ns") + assert result is None + + def test_kubeconfig_fallback(self): + """Falls back to kubeconfig when in-cluster config fails.""" + from kubernetes.config import ConfigException + + mock_secret = MagicMock() + mock_secret.data = {"token": base64.b64encode(b"local_token").decode()} + + with ( + patch("kubernetes.client.CoreV1Api") as mock_api_cls, + patch("kubernetes.config.load_incluster_config", side_effect=ConfigException("not in cluster")), + patch("kubernetes.config.load_kube_config"), + ): + mock_api_cls.return_value.read_namespaced_secret.return_value = mock_secret + result = get_secret_from_k8s("databricks-token", "token", "dev") + assert result == "local_token" + + def test_both_configs_fail(self): + """Both in-cluster and kubeconfig fail - returns None.""" + from kubernetes.config import ConfigException + + with ( + patch("kubernetes.config.load_incluster_config", side_effect=ConfigException("not in cluster")), + patch("kubernetes.config.load_kube_config", side_effect=Exception("no kubeconfig")), + ): + result = get_secret_from_k8s("databricks-token", "token", "orphan-ns") + assert result is None + + +# =========================================================================== +# Tests for get_databricks_token +# =========================================================================== + + +class TestGetDatabricksToken: + """Tests for the multi-tenant token resolution strategy.""" + + @patch("flytekitplugins.spark.connector.get_secret_from_k8s") + def test_token_from_namespace_secret(self, mock_k8s): + """Token found in namespace-specific K8s secret.""" + mock_k8s.return_value = "dapi_namespace_token" + token = get_databricks_token(namespace="team-a") + assert token == "dapi_namespace_token" + mock_k8s.assert_called_once_with( + secret_name="databricks-token", + secret_key="token", + namespace="team-a", + ) + + @patch("flytekitplugins.spark.connector.get_secret_from_k8s") + def test_token_with_custom_secret_name(self, mock_k8s): + """Token found using a custom K8s secret name.""" + mock_k8s.return_value = "dapi_custom_secret_token" + token = get_databricks_token(namespace="team-b", secret_name="my-db-token") + assert token == "dapi_custom_secret_token" + mock_k8s.assert_called_once_with( + secret_name="my-db-token", + secret_key="token", + namespace="team-b", + ) + + @patch("flytekitplugins.spark.connector.get_connector_secret") + @patch("flytekitplugins.spark.connector.get_secret_from_k8s") + def test_fallback_to_env_when_namespace_secret_missing(self, mock_k8s, mock_env): + """Falls back to FLYTE_DATABRICKS_ACCESS_TOKEN when namespace secret not found.""" + mock_k8s.return_value = None + mock_env.return_value = "dapi_env_fallback_token" + token = get_databricks_token(namespace="team-c") + assert token == "dapi_env_fallback_token" + mock_env.assert_called_once_with("FLYTE_DATABRICKS_ACCESS_TOKEN") + + @patch("flytekitplugins.spark.connector.get_connector_secret") + def test_no_namespace_falls_back_to_env(self, mock_env): + """When no namespace is provided, goes directly to env var.""" + mock_env.return_value = "dapi_default_token" + token = get_databricks_token(namespace=None) + assert token == "dapi_default_token" + mock_env.assert_called_once_with("FLYTE_DATABRICKS_ACCESS_TOKEN") + + @patch("flytekitplugins.spark.connector.get_connector_secret") + @patch("flytekitplugins.spark.connector.get_secret_from_k8s") + def test_no_token_from_any_source_raises(self, mock_k8s, mock_env): + """ValueError raised when neither K8s secret nor env var has a token.""" + mock_k8s.return_value = None + mock_env.side_effect = Exception("Secret not found") + with pytest.raises(ValueError, match="No Databricks token found from any source"): + get_databricks_token(namespace="orphan-ns") + + @patch("flytekitplugins.spark.connector.get_secret_from_k8s") + def test_empty_token_raises(self, mock_k8s): + """ValueError raised when token is an empty string.""" + mock_k8s.return_value = "" + with pytest.raises(ValueError, match="Databricks token is empty"): + get_databricks_token(namespace="team-d") + + @patch("flytekitplugins.spark.connector.get_connector_secret") + @patch("flytekitplugins.spark.connector.get_secret_from_k8s") + def test_default_secret_name_is_databricks_token(self, mock_k8s, mock_env): + """Default secret name is 'databricks-token' when not specified.""" + mock_k8s.return_value = "found_it" + get_databricks_token(namespace="ns-1") + mock_k8s.assert_called_once_with( + secret_name="databricks-token", + secret_key="token", + namespace="ns-1", + ) + + @patch("flytekitplugins.spark.connector.get_connector_secret") + def test_backward_compatibility_no_namespace_no_secret(self, mock_env): + """Backward compatible: no namespace + no secret name = env var only.""" + mock_env.return_value = "dapi_legacy_token" + token = get_databricks_token() + assert token == "dapi_legacy_token" + + +# =========================================================================== +# Tests for get_header +# =========================================================================== + + +class TestGetHeader: + """Tests for authorization header generation.""" + + def test_with_preresolved_auth_token(self): + """Header uses pre-fetched auth token directly.""" + headers = get_header(auth_token="dapi_preresolved_123") + assert headers == { + "Authorization": "Bearer dapi_preresolved_123", + "content-type": "application/json", + } + + @patch("flytekitplugins.spark.connector.get_databricks_token") + def test_without_auth_token_resolves(self, mock_get_token): + """Header resolves token via get_databricks_token when not provided.""" + mock_get_token.return_value = "dapi_resolved_456" + headers = get_header() + assert headers["Authorization"] == "Bearer dapi_resolved_456" + mock_get_token.assert_called_once() + + +# =========================================================================== +# Tests for DatabricksJobMetadata +# =========================================================================== + + +class TestDatabricksJobMetadata: + """Tests for token persistence in job metadata.""" + + def test_metadata_stores_auth_token(self): + """Auth token is stored in metadata.""" + meta = DatabricksJobMetadata( + databricks_instance="test.cloud.databricks.com", + run_id="42", + auth_token="dapi_persistent_token", + ) + assert meta.auth_token == "dapi_persistent_token" + + def test_metadata_auth_token_defaults_to_none(self): + """Auth token defaults to None for backward compatibility.""" + meta = DatabricksJobMetadata( + databricks_instance="test.cloud.databricks.com", + run_id="42", + ) + assert meta.auth_token is None + + +# =========================================================================== +# Tests for DatabricksConnector create/get/delete with token +# =========================================================================== + + +class TestDatabricksConnectorWithToken: + """Integration tests for connector operations with per-project tokens.""" + + @pytest.mark.asyncio + async def test_create_uses_namespace_token(self, task_template): + """create() fetches token from namespace and uses it in API call.""" + task_template.custom["databricksInstance"] = "test.cloud.databricks.com" + agent = AgentRegistry.get_agent("spark") + metadata = _make_task_execution_metadata("project-alpha") + + with patch("flytekitplugins.spark.connector.get_databricks_token") as mock_token: + mock_token.return_value = "dapi_project_alpha_token" + create_url = f"https://test.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/submit" + + with aioresponses() as mocked: + mocked.post(create_url, status=http.HTTPStatus.OK, payload={"run_id": "999"}) + result = await agent.create(task_template, None, task_execution_metadata=metadata) + + mock_token.assert_called_once_with( + namespace="project-alpha", + task_template=task_template, + secret_name=None, + ) + assert result.auth_token == "dapi_project_alpha_token" + assert result.run_id == "999" + + @pytest.mark.asyncio + async def test_create_with_custom_secret_name(self, task_template_with_custom_secret): + """create() passes custom secret name from task template.""" + tt = task_template_with_custom_secret + tt.custom["databricksInstance"] = "test.cloud.databricks.com" + agent = AgentRegistry.get_agent("spark") + metadata = _make_task_execution_metadata("team-x") + + with patch("flytekitplugins.spark.connector.get_databricks_token") as mock_token: + mock_token.return_value = "dapi_team_x_token" + create_url = f"https://test.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/submit" + + with aioresponses() as mocked: + mocked.post(create_url, status=http.HTTPStatus.OK, payload={"run_id": "888"}) + result = await agent.create(tt, None, task_execution_metadata=metadata) + + mock_token.assert_called_once_with( + namespace="team-x", + task_template=tt, + secret_name="my-team-databricks-token", + ) + assert result.auth_token == "dapi_team_x_token" + + @pytest.mark.asyncio + async def test_create_without_metadata_uses_no_namespace(self, task_template): + """create() works when task_execution_metadata is None (backward compat).""" + task_template.custom["databricksInstance"] = "test.cloud.databricks.com" + agent = AgentRegistry.get_agent("spark") + + with patch("flytekitplugins.spark.connector.get_databricks_token") as mock_token: + mock_token.return_value = "dapi_default_token" + create_url = f"https://test.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/submit" + + with aioresponses() as mocked: + mocked.post(create_url, status=http.HTTPStatus.OK, payload={"run_id": "777"}) + await agent.create(task_template, None) + + mock_token.assert_called_once_with( + namespace=None, + task_template=task_template, + secret_name=None, + ) + + @pytest.mark.asyncio + async def test_create_stores_token_in_metadata(self, task_template): + """create() persists auth_token in returned DatabricksJobMetadata.""" + task_template.custom["databricksInstance"] = "test.cloud.databricks.com" + agent = AgentRegistry.get_agent("spark") + metadata = _make_task_execution_metadata("data-team") + + with patch("flytekitplugins.spark.connector.get_databricks_token") as mock_token: + mock_token.return_value = "dapi_data_team_abc" + create_url = f"https://test.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/submit" + + with aioresponses() as mocked: + mocked.post(create_url, status=http.HTTPStatus.OK, payload={"run_id": "555"}) + result = await agent.create(task_template, None, task_execution_metadata=metadata) + + assert isinstance(result, DatabricksJobMetadata) + assert result.auth_token == "dapi_data_team_abc" + assert result.run_id == "555" + assert result.databricks_instance == "test.cloud.databricks.com" + + @pytest.mark.asyncio + async def test_get_uses_stored_token(self): + """get() uses the auth_token stored in DatabricksJobMetadata.""" + meta = DatabricksJobMetadata( + databricks_instance="test.cloud.databricks.com", + run_id="123", + auth_token="dapi_stored_get_token", + ) + agent = AgentRegistry.get_agent("spark") + get_url = f"https://test.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/get?run_id=123" + + mock_response = { + "job_id": "1", + "run_id": "123", + "state": {"life_cycle_state": "TERMINATED", "result_state": "SUCCESS", "state_message": "OK"}, + } + + with aioresponses() as mocked: + mocked.get(get_url, status=http.HTTPStatus.OK, payload=mock_response) + await agent.get(meta) + + # Verify the correct token was used in the request header + call_args = list(mocked.requests.values())[0][0] + assert call_args.kwargs["headers"]["Authorization"] == "Bearer dapi_stored_get_token" + + @pytest.mark.asyncio + async def test_delete_uses_stored_token(self): + """delete() uses the auth_token stored in DatabricksJobMetadata.""" + meta = DatabricksJobMetadata( + databricks_instance="test.cloud.databricks.com", + run_id="456", + auth_token="dapi_stored_delete_token", + ) + agent = AgentRegistry.get_agent("spark") + delete_url = f"https://test.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/cancel" + + with aioresponses() as mocked: + mocked.post(delete_url, status=http.HTTPStatus.OK, payload={}) + await agent.delete(meta) + + call_args = list(mocked.requests.values())[0][0] + assert call_args.kwargs["headers"]["Authorization"] == "Bearer dapi_stored_delete_token" + + @pytest.mark.asyncio + async def test_get_with_none_token_falls_back(self): + """get() falls back to get_header default when auth_token is None.""" + meta = DatabricksJobMetadata( + databricks_instance="test.cloud.databricks.com", + run_id="789", + auth_token=None, + ) + agent = AgentRegistry.get_agent("spark") + get_url = f"https://test.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/get?run_id=789" + + mock_response = { + "job_id": "2", + "run_id": "789", + "state": {"life_cycle_state": "RUNNING"}, + } + + with patch("flytekitplugins.spark.connector.get_databricks_token") as mock_token: + mock_token.return_value = "dapi_fallback_env_token" + with aioresponses() as mocked: + mocked.get(get_url, status=http.HTTPStatus.OK, payload=mock_response) + await agent.get(meta) + + call_args = list(mocked.requests.values())[0][0] + assert call_args.kwargs["headers"]["Authorization"] == "Bearer dapi_fallback_env_token" + + +# =========================================================================== +# Tests for DatabricksV2 task config token secret +# =========================================================================== + + +class TestDatabricksV2TokenSecret: + """Tests for DatabricksV2 databricks_token_secret field and serialization.""" + + def test_token_secret_field_exists(self): + """DatabricksV2 has databricks_token_secret field.""" + from flytekitplugins.spark.task import DatabricksV2 + + config = DatabricksV2( + databricks_conf={"new_cluster": {"spark_version": "12.2.x"}}, + databricks_instance="test.cloud.databricks.com", + databricks_token_secret="my-team-secret", + ) + assert config.databricks_token_secret == "my-team-secret" + + def test_token_secret_defaults_to_none(self): + """databricks_token_secret defaults to None.""" + from flytekitplugins.spark.task import DatabricksV2 + + config = DatabricksV2( + databricks_conf={"new_cluster": {"spark_version": "12.2.x"}}, + databricks_instance="test.cloud.databricks.com", + ) + assert config.databricks_token_secret is None + + def test_get_custom_includes_token_secret(self): + """get_custom() serializes databricksTokenSecret when set.""" + from flytekitplugins.spark.task import DatabricksV2 + + import flytekit + from flytekit.configuration import Image, ImageConfig, SerializationSettings + + databricks_conf = { + "name": "test", + "new_cluster": { + "spark_version": "12.2.x-scala2.12", + "node_type_id": "r3.xlarge", + "num_workers": 1, + "docker_image": {"url": "test:latest"}, + }, + "timeout_seconds": 3600, + "max_retries": 1, + } + + @flytekit.task( + task_config=DatabricksV2( + databricks_conf=databricks_conf, + databricks_instance="test.cloud.databricks.com", + databricks_token_secret="project-x-token", + ) + ) + def my_task(x: int) -> int: + return x + + default_img = Image(name="default", fqn="test", tag="tag") + settings = SerializationSettings( + project="project", + domain="domain", + version="version", + env={}, + image_config=ImageConfig(default_image=default_img, images=[default_img]), + ) + + custom = my_task.get_custom(settings) + assert "databricksTokenSecret" in custom + assert custom["databricksTokenSecret"] == "project-x-token" + + def test_get_custom_excludes_token_secret_when_none(self): + """get_custom() does NOT include databricksTokenSecret when None.""" + from flytekitplugins.spark.task import DatabricksV2 + + import flytekit + from flytekit.configuration import Image, ImageConfig, SerializationSettings + + databricks_conf = { + "name": "test", + "new_cluster": { + "spark_version": "12.2.x-scala2.12", + "node_type_id": "r3.xlarge", + "num_workers": 1, + "docker_image": {"url": "test:latest"}, + }, + } + + @flytekit.task( + task_config=DatabricksV2( + databricks_conf=databricks_conf, + databricks_instance="test.cloud.databricks.com", + ) + ) + def my_task(x: int) -> int: + return x + + default_img = Image(name="default", fqn="test", tag="tag") + settings = SerializationSettings( + project="project", + domain="domain", + version="version", + env={}, + image_config=ImageConfig(default_image=default_img, images=[default_img]), + ) + + custom = my_task.get_custom(settings) + assert "databricksTokenSecret" not in custom