From 010c0bbdaccf9f42da5c67d1b57ab24eef84c0dd Mon Sep 17 00:00:00 2001 From: Rohit Sharma Date: Tue, 17 Feb 2026 13:45:19 +0000 Subject: [PATCH 1/3] feat(spark): Add Databricks Serverless Compute support Signed-off-by: Rohit Sharma --- Dockerfile.connector | 13 +- .../flytekitplugins/spark/connector.py | 224 +++++++- .../flytekitplugins/spark/task.py | 224 +++++++- .../flytekit-spark/tests/test_connector.py | 483 +++++++++++++++++- .../flytekit-spark/tests/test_spark_task.py | 214 ++++++++ 5 files changed, 1129 insertions(+), 29 deletions(-) diff --git a/Dockerfile.connector b/Dockerfile.connector index bed0fa0160..7eea1d42b9 100644 --- a/Dockerfile.connector +++ b/Dockerfile.connector @@ -7,8 +7,15 @@ ARG VERSION RUN apt-get update && apt-get install build-essential -y \ && pip install uv - -RUN uv pip install --system --no-cache-dir -U flytekit[connector]==$VERSION \ +# Pin pendulum<3.0: Apache Airflow (via flytekitplugins-airflow) imports +# pendulum.tz.timezone() at module load time (airflow/settings.py). 
+# Pendulum 3.x changed the tz API, causing the connector to crash on startup: +# airflow/settings.py → TIMEZONE = pendulum.tz.timezone("UTC") → AttributeError +# Without this pin, uv resolves to pendulum 3.x which breaks the import chain: +# pyflyte serve connector → load_implicit_plugins → airflow → pendulum → crash +RUN uv pip install --system --no-cache-dir -U \ + "pendulum>=2.0.0,<3.0" \ + flytekit[connector]==$VERSION \ flytekitplugins-airflow==$VERSION \ flytekitplugins-bigquery==$VERSION \ flytekitplugins-k8sdataservice==$VERSION \ @@ -28,4 +35,4 @@ ARG VERSION RUN uv pip install --system --no-cache-dir -U \ flytekitplugins-mmcloud==$VERSION \ - flytekitplugins-spark==$VERSION + flytekitplugins-spark==$VERSION \ No newline at end of file diff --git a/plugins/flytekit-spark/flytekitplugins/spark/connector.py b/plugins/flytekit-spark/flytekitplugins/spark/connector.py index 895c7d153d..628ef057f8 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/connector.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/connector.py @@ -19,6 +19,7 @@ DATABRICKS_API_ENDPOINT = "/api/2.1/jobs" DEFAULT_DATABRICKS_INSTANCE_ENV_KEY = "FLYTE_DATABRICKS_INSTANCE" +DEFAULT_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER_ENV_KEY = "FLYTE_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER" @dataclass @@ -27,36 +28,221 @@ class DatabricksJobMetadata(ResourceMeta): run_id: str +def _is_serverless_config(databricks_job: dict) -> bool: + """ + Detect if the configuration is for serverless compute. + Serverless is indicated by having environment_key or environments without cluster config. 
+ """ + # Check if cluster config keys exist (even empty dict counts as cluster config) + has_cluster_config = "existing_cluster_id" in databricks_job or "new_cluster" in databricks_job + has_serverless_config = bool(databricks_job.get("environment_key") or databricks_job.get("environments")) + return not has_cluster_config and has_serverless_config + + +def _configure_serverless(databricks_job: dict, envs: dict) -> str: + """ + Configure serverless compute settings and return the environment_key to use. + + Databricks serverless requires the 'environments' array to be defined in the job + submission. This function ensures the environments array exists and injects + Flyte environment variables. + + Args: + databricks_job: The databricks job configuration dict + envs: Environment variables to inject + + Returns: + The environment_key to use for the task + """ + environment_key = databricks_job.get("environment_key", "default") + environments = databricks_job.get("environments", []) + + # Check if environment already exists in the array + env_exists = any(env.get("environment_key") == environment_key for env in environments) + + if not env_exists: + # Create the environment entry - Databricks serverless requires environments + # to be defined in the job submission (not externally pre-configured) + new_env = { + "environment_key": environment_key, + "spec": { + "client": "1", # Required: Databricks serverless client version + } + } + environments.append(new_env) + databricks_job["environments"] = environments + + # Inject Flyte environment variables into the environment spec + for env in environments: + if env.get("environment_key") == environment_key: + spec = env.setdefault("spec", {}) + existing_env_vars = spec.get("environment_vars", {}) + # Merge Flyte env vars with any existing ones (Flyte vars take precedence) + merged_env_vars = {**existing_env_vars, **{k: v for k, v in envs.items()}} + spec["environment_vars"] = merged_env_vars + break + + # Remove 
environment_key from top level (it's now in the task definition) + databricks_job.pop("environment_key", None) + + return environment_key + + +def _configure_classic_cluster(databricks_job: dict, custom: dict, container, envs: dict) -> None: + """ + Configure classic compute (existing cluster or new cluster). + + Args: + databricks_job: The databricks job configuration dict + custom: The custom config from task template + container: The container config from task template + envs: Environment variables to inject + """ + if databricks_job.get("existing_cluster_id") is not None: + # Using an existing cluster, no additional configuration needed + return + + new_cluster = databricks_job.get("new_cluster") + if new_cluster is None: + raise ValueError( + "Either existing_cluster_id, new_cluster, environment_key, or environments must be specified" + ) + + if not new_cluster.get("docker_image"): + new_cluster["docker_image"] = {"url": container.image} + if not new_cluster.get("spark_conf"): + new_cluster["spark_conf"] = custom.get("sparkConf", {}) + if not new_cluster.get("spark_env_vars"): + new_cluster["spark_env_vars"] = {k: v for k, v in envs.items()} + else: + new_cluster["spark_env_vars"].update({k: v for k, v in envs.items()}) + + def _get_databricks_job_spec(task_template: TaskTemplate) -> dict: custom = task_template.custom container = task_template.container envs = task_template.container.env envs[FLYTE_FAIL_ON_ERROR] = "true" databricks_job = custom["databricksConf"] - if databricks_job.get("existing_cluster_id") is None: - new_cluster = databricks_job.get("new_cluster") - if new_cluster is None: - raise ValueError("Either existing_cluster_id or new_cluster must be specified") - if not new_cluster.get("docker_image"): - new_cluster["docker_image"] = {"url": container.image} - if not new_cluster.get("spark_conf"): - new_cluster["spark_conf"] = custom.get("sparkConf", {}) - if not new_cluster.get("spark_env_vars"): - new_cluster["spark_env_vars"] = {k: v for k, v 
in envs.items()} + + # Check if this is a notebook task + notebook_path = custom.get("notebookPath") + notebook_base_parameters = custom.get("notebookBaseParameters", {}) + + # Determine compute mode and configure accordingly + is_serverless = _is_serverless_config(databricks_job) + + if notebook_path: + # Notebook task - runs a Databricks notebook + notebook_task = { + "notebook_path": notebook_path, + } + if notebook_base_parameters: + notebook_task["base_parameters"] = notebook_base_parameters + + # Check if notebook should be sourced from git + user_git_source = databricks_job.get("git_source") + if user_git_source: + notebook_task["source"] = "GIT" + # Set git_source at job level + databricks_job["git_source"] = user_git_source + + if is_serverless: + # Serverless notebook task + environment_key = _configure_serverless(databricks_job, envs) + + task_def = { + "task_key": "flyte_notebook_task", + "notebook_task": notebook_task, + "environment_key": environment_key, + } + databricks_job["tasks"] = [task_def] + databricks_job.pop("environment_key", None) else: - new_cluster["spark_env_vars"].update({k: v for k, v in envs.items()}) - # https://docs.databricks.com/api/workspace/jobs/submit - databricks_job["spark_python_task"] = { - "python_file": "flytekitplugins/databricks/entrypoint.py", - "source": "GIT", - "parameters": container.args, - } - databricks_job["git_source"] = { + # Classic compute notebook task + _configure_classic_cluster(databricks_job, custom, container, envs) + databricks_job["notebook_task"] = notebook_task + + # Clean up git_source from databricks_job if it was there (already set at job level) + databricks_job.pop("git_source", None) + if user_git_source: + databricks_job["git_source"] = user_git_source + + return databricks_job + + # Python file task (original behavior) + # Allow custom git_source and python_file override from user config + user_git_source = databricks_job.get("git_source") + user_python_file = 
databricks_job.get("python_file") + + # Default entrypoints from the flytetools repo. + # Both classic and serverless use the same repo; only the python_file differs. + default_git_source = { "git_url": "https://github.com/flyteorg/flytetools", "git_provider": "gitHub", - # https://github.com/flyteorg/flytetools/commit/572298df1f971fb58c258398bd70a6372f811c96 "git_commit": "572298df1f971fb58c258398bd70a6372f811c96", } + default_classic_python_file = "flytekitplugins/databricks/entrypoint.py" + default_serverless_python_file = "flytekitplugins/databricks/entrypoint_serverless.py" + + if is_serverless: + # Serverless compute - use flytetools serverless entrypoint by default + git_source = user_git_source if user_git_source else default_git_source + python_file = user_python_file if user_python_file else default_serverless_python_file + + # Serverless requires multi-task format with tasks array + environment_key = _configure_serverless(databricks_job, envs) + + # Build parameters list - append credential provider if specified + # This allows the entrypoint to receive the credential provider via command line + # We append at the END to avoid breaking pyflyte-fast-execute which must be first + parameters = list(container.args) if container.args else [] + + # Resolve service credential provider: task config > env var + service_credential_provider = custom.get( + "databricksServiceCredentialProvider", + os.getenv(DEFAULT_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER_ENV_KEY) + ) + if service_credential_provider: + # Append as a special argument that the entrypoint will parse and remove + parameters.append(f"--flyte-credential-provider={service_credential_provider}") + + spark_python_task = { + "python_file": python_file, + "source": "GIT", + "parameters": parameters, + } + + # Build the task definition for serverless + task_def = { + "task_key": "flyte_task", + "spark_python_task": spark_python_task, + "environment_key": environment_key, + } + + # Add tasks array for serverless 
(required by Databricks API) + databricks_job["tasks"] = [task_def] + + # Remove environment_key from top level (it's now in the task) + databricks_job.pop("environment_key", None) + else: + # Classic compute - use flytetools entrypoint by default + git_source = user_git_source if user_git_source else default_git_source + python_file = user_python_file if user_python_file else default_classic_python_file + + spark_python_task = { + "python_file": python_file, + "source": "GIT", + "parameters": container.args, + } + + _configure_classic_cluster(databricks_job, custom, container, envs) + databricks_job["spark_python_task"] = spark_python_task + + # Set git_source (remove from user config if it was there to avoid duplication) + databricks_job.pop("git_source", None) + databricks_job.pop("python_file", None) + databricks_job["git_source"] = git_source return databricks_job diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index 5801d24fde..c8b09447dd 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -73,17 +73,123 @@ def __post_init__(self): class DatabricksV2(Spark): """ Use this to configure a Databricks task. Task's marked with this will automatically execute - natively onto databricks platform as a distributed execution of spark + natively onto databricks platform as a distributed execution of spark. + + Supports both classic compute (clusters) and serverless compute. Args: databricks_conf: Databricks job configuration compliant with API version 2.1, supporting 2.0 use cases. 
- For the configuration structure, visit here.https://docs.databricks.com/dev-tools/api/2.0/jobs.html#request-structure - For updates in API 2.1, refer to: https://docs.databricks.com/en/workflows/jobs/jobs-api-updates.html + For the configuration structure, visit: https://docs.databricks.com/dev-tools/api/2.0/jobs.html#request-structure + For updates in API 2.1, refer to: https://docs.databricks.com/en/workflows/jobs/jobs-api-updates.html databricks_instance: Domain name of your deployment. Use the form <account>.cloud.databricks.com. + + Compute Modes: + The connector auto-detects the compute mode based on the databricks_conf contents: + + 1. Classic Compute (existing cluster): + Provide `existing_cluster_id` in databricks_conf. + + 2. Classic Compute (new cluster): + Provide `new_cluster` configuration in databricks_conf. + + 3. Serverless Compute (pre-configured environment): + Provide `environment_key` naming the environment to use. If `environments` has no entry with that key, the connector adds a minimal one (spec client "1") to the job submission. + Do not include `existing_cluster_id` or `new_cluster`. + + 4. Serverless Compute (inline environment spec): + Provide `environments` array with environment specifications. + Optionally include `environment_key` to specify which environment to use. + Do not include `existing_cluster_id` or `new_cluster`. 
+ + Example - Classic Compute with new cluster:: + + DatabricksV2( + databricks_conf={ + "run_name": "my-spark-job", + "new_cluster": { + "spark_version": "13.3.x-scala2.12", + "node_type_id": "m5.xlarge", + "num_workers": 2, + }, + }, + databricks_instance="my-workspace.cloud.databricks.com", + ) + + Example - Serverless Compute with pre-configured environment:: + + DatabricksV2( + databricks_conf={ + "run_name": "my-serverless-job", + "environment_key": "my-preconfigured-env", + }, + databricks_instance="my-workspace.cloud.databricks.com", + ) + + Example - Serverless Compute with inline environment spec:: + + DatabricksV2( + databricks_conf={ + "run_name": "my-serverless-job", + "environment_key": "default", + "environments": [{ + "environment_key": "default", + "spec": { + "client": "1", + "dependencies": ["pandas==2.0.0", "numpy==1.24.0"], + } + }], + }, + databricks_instance="my-workspace.cloud.databricks.com", + ) + + Note: + Serverless compute has certain limitations compared to classic compute: + - Only Python and SQL are supported (no Scala or R) + - Only Spark Connect APIs are supported (no RDD APIs) + - Must use Unity Catalog for external data sources + - No support for compute-scoped init scripts or libraries + For full details, see: https://docs.databricks.com/en/compute/serverless/limitations.html + + Serverless Entrypoint: + Both classic and serverless use the same ``flytetools`` repo for their entrypoints. + Classic uses ``flytekitplugins/databricks/entrypoint.py`` and serverless uses + ``flytekitplugins/databricks/entrypoint_serverless.py``. No additional configuration needed. + + To override the default, provide ``git_source`` and ``python_file`` in ``databricks_conf``. + + AWS Credentials for Serverless: + Databricks serverless does not provide AWS credentials via instance metadata. + To access S3 (for Flyte data), configure a Databricks Service Credential. + + The provider name is resolved in this order: + 1. 
``databricks_service_credential_provider`` in the task config (per-task override) + 2. ``FLYTE_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER`` environment variable on the connector (default for all tasks) + + The entrypoint will use this to obtain AWS credentials via: + dbutils.credentials.getServiceCredentialsProvider(provider_name) + + Notebook Support: + To run a Databricks notebook instead of a Python file, set `notebook_path`. + Parameters can be passed via `notebook_base_parameters`. + + Example - Running a notebook:: + + DatabricksV2( + databricks_conf={ + "run_name": "my-notebook-job", + "new_cluster": {...}, + }, + databricks_instance="my-workspace.cloud.databricks.com", + notebook_path="/Users/user@example.com/my-notebook", + notebook_base_parameters={"param1": "value1"}, + ) """ databricks_conf: Optional[Dict[str, Union[str, dict]]] = None - databricks_instance: Optional[str] = None + databricks_instance: Optional[str] = None # Falls back to FLYTE_DATABRICKS_INSTANCE env var + databricks_service_credential_provider: Optional[str] = None # Falls back to FLYTE_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER env var + notebook_path: Optional[str] = None # Path to Databricks notebook (e.g., "/Users/user@example.com/notebook") + notebook_base_parameters: Optional[Dict[str, str]] = None # Parameters to pass to the notebook # This method does not reset the SparkSession since it's a bit hard to handle multiple @@ -187,7 +293,20 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: job._databricks_conf = cfg.databricks_conf job._databricks_instance = cfg.databricks_instance - return MessageToDict(job.to_flyte_idl()) + # Serialize to dict + custom_dict = MessageToDict(job.to_flyte_idl()) + + # Add DatabricksV2-specific fields (not part of protobuf) + if isinstance(self.task_config, DatabricksV2): + cfg = cast(DatabricksV2, self.task_config) + if cfg.databricks_service_credential_provider: + custom_dict['databricksServiceCredentialProvider'] = 
cfg.databricks_service_credential_provider + if cfg.notebook_path: + custom_dict['notebookPath'] = cfg.notebook_path + if cfg.notebook_base_parameters: + custom_dict['notebookBaseParameters'] = cfg.notebook_base_parameters + + return custom_dict def to_k8s_pod(self, pod_template: Optional[PodTemplate] = None) -> Optional[K8sPod]: """ @@ -210,10 +329,105 @@ def to_k8s_pod(self, pod_template: Optional[PodTemplate] = None) -> Optional[K8s return K8sPod.from_pod_template(pod_template) + def _is_databricks_serverless(self) -> bool: + """ + Detect if we're running in Databricks serverless environment. + + Serverless uses Spark Connect and requires different SparkSession handling. + """ + # Check for explicit serverless markers set by our entrypoint + if os.environ.get("DATABRICKS_SERVERLESS") == "true": + return True + if os.environ.get("SPARK_CONNECT_MODE") == "true": + return True + + # Check for Databricks serverless indicators + # 1. DATABRICKS_RUNTIME_VERSION exists (Databricks environment) + # 2. No SPARK_HOME (serverless doesn't have traditional Spark) + is_databricks = "DATABRICKS_RUNTIME_VERSION" in os.environ + + # Additional check: if using DatabricksV2 with serverless config + if isinstance(self.task_config, DatabricksV2): + conf = self.task_config.databricks_conf or {} + has_serverless_config = ( + "environment_key" in conf or + "environments" in conf + ) and "new_cluster" not in conf and "existing_cluster_id" not in conf + if has_serverless_config: + return True + + return is_databricks and "SPARK_HOME" not in os.environ + + def _get_databricks_serverless_spark_session(self): + """ + Get SparkSession in Databricks serverless environment. + + The entrypoint injects the SparkSession into: + 1. Custom module '_flyte_spark_session' in sys.modules (most reliable) + 2. 
builtins.spark (backup) + + Returns: + SparkSession or None if not available + """ + import sys + + # Method 1: Try custom module (most reliable - survives module reloads) + try: + if '_flyte_spark_session' in sys.modules: + spark_module = sys.modules['_flyte_spark_session'] + if hasattr(spark_module, 'spark') and spark_module.spark is not None: + logger.info(f"Got SparkSession from _flyte_spark_session module") + return spark_module.spark + except Exception as e: + logger.debug(f"Could not get spark from _flyte_spark_session: {e}") + + # Method 2: Try builtins (backup location) + try: + import builtins + if hasattr(builtins, 'spark') and builtins.spark is not None: + logger.info(f"Got SparkSession from builtins") + return builtins.spark + except Exception as e: + logger.debug(f"Could not get spark from builtins: {e}") + + # Method 3: Try __main__ module + try: + import __main__ + if hasattr(__main__, 'spark') and __main__.spark is not None: + logger.info(f"Got SparkSession from __main__") + return __main__.spark + except Exception as e: + logger.debug(f"Could not get spark from __main__: {e}") + + # Method 4: Try active session + try: + from pyspark.sql import SparkSession + active = SparkSession.getActiveSession() + if active: + logger.info(f"Got active SparkSession") + return active + except Exception as e: + logger.debug(f"Could not get active SparkSession: {e}") + + logger.warning("Could not obtain SparkSession in serverless environment") + return None + def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: import pyspark as _pyspark ctx = FlyteContextManager.current_context() + + # Databricks serverless uses Spark Connect - SparkSession is pre-configured + if self._is_databricks_serverless(): + logger.info("Detected Databricks serverless environment - using pre-configured SparkSession") + self.sess = self._get_databricks_serverless_spark_session() + + if self.sess is None: + logger.warning("No SparkSession available - task will run 
without Spark") + + return user_params.builder().add_attr("SPARK_SESSION", self.sess).build() + + # Standard Spark session creation for non-serverless environments sess_builder = _pyspark.sql.SparkSession.builder.appName(f"FlyteSpark: {user_params.execution_id}") if not (ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION): # If either of above cases is not true, then we are in local execution of this task diff --git a/plugins/flytekit-spark/tests/test_connector.py b/plugins/flytekit-spark/tests/test_connector.py index 5136d39ce8..b2828de9fa 100644 --- a/plugins/flytekit-spark/tests/test_connector.py +++ b/plugins/flytekit-spark/tests/test_connector.py @@ -8,8 +8,15 @@ from flyteidl.core.execution_pb2 import TaskExecution from flytekit.core.constants import FLYTE_FAIL_ON_ERROR -from flytekitplugins.spark.connector import DATABRICKS_API_ENDPOINT, DatabricksJobMetadata, get_header, \ - _get_databricks_job_spec, DEFAULT_DATABRICKS_INSTANCE_ENV_KEY +from flytekitplugins.spark.connector import ( + DATABRICKS_API_ENDPOINT, + DatabricksJobMetadata, + get_header, + _get_databricks_job_spec, + _is_serverless_config, + _configure_serverless, + DEFAULT_DATABRICKS_INSTANCE_ENV_KEY, +) from flytekit.extend.backend.base_agent import AgentRegistry from flytekit.interfaces.cli_identifiers import Identifier @@ -192,3 +199,475 @@ async def test_agent_create_with_default_instance(task_template: TaskTemplate): assert res == databricks_metadata mock.patch.stopall() + + +# ==================== Serverless Compute Tests ==================== + + +@pytest.fixture(scope="function") +def serverless_task_template_with_env_key() -> TaskTemplate: + """Task template configured for serverless with pre-configured environment_key.""" + task_id = Identifier( + resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version" + ) + task_metadata = task.TaskMetadata( + True, + 
task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + timedelta(days=1), + literals.RetryStrategy(3), + True, + "0.1.1b0", + "This is deprecated!", + True, + "A", + (), + ) + task_config = { + "sparkConf": {}, + "mainApplicationFile": "dbfs:/entrypoint.py", + "databricksConf": { + "run_name": "flytekit serverless job", + "environment_key": "my-preconfigured-env", + "timeout_seconds": 3600, + "git_source": { + "git_url": "https://github.com/test-org/test-repo", + "git_provider": "gitHub", + "git_branch": "main", + }, + "python_file": "entrypoint_serverless.py", + } + } + container = Container( + image="flyteorg/flytekit:databricks-0.18.0-py3.7", + command=[], + args=["pyflyte-execute", "--inputs", "s3://my-s3-bucket"], + resources=Resources(requests=[], limits=[]), + env={"foo": "bar"}, + config={}, + ) + + return TaskTemplate( + id=task_id, + custom=task_config, + metadata=task_metadata, + container=container, + interface=None, + type="spark", + ) + + +@pytest.fixture(scope="function") +def serverless_task_template_with_inline_env() -> TaskTemplate: + """Task template configured for serverless with inline environments spec.""" + task_id = Identifier( + resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version" + ) + task_metadata = task.TaskMetadata( + True, + task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + timedelta(days=1), + literals.RetryStrategy(3), + True, + "0.1.1b0", + "This is deprecated!", + True, + "A", + (), + ) + task_config = { + "sparkConf": {}, + "mainApplicationFile": "dbfs:/entrypoint.py", + "databricksConf": { + "run_name": "flytekit serverless job with inline env", + "environment_key": "default", + "environments": [{ + "environment_key": "default", + "spec": { + "client": "1", + "dependencies": ["pandas==2.0.0"], + } + }], + "timeout_seconds": 3600, + "git_source": { + "git_url": "https://github.com/test-org/test-repo", + 
"git_provider": "gitHub", + "git_branch": "main", + }, + "python_file": "entrypoint_serverless.py", + } + } + container = Container( + image="flyteorg/flytekit:databricks-0.18.0-py3.7", + command=[], + args=["pyflyte-execute", "--inputs", "s3://my-s3-bucket"], + resources=Resources(requests=[], limits=[]), + env={"foo": "bar", "MY_VAR": "my_value"}, + config={}, + ) + + return TaskTemplate( + id=task_id, + custom=task_config, + metadata=task_metadata, + container=container, + interface=None, + type="spark", + ) + + +@pytest.fixture(scope="function") +def serverless_task_template_no_git_source() -> TaskTemplate: + """Task template for serverless without git_source - relies on connector env vars.""" + task_id = Identifier( + resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version" + ) + task_metadata = task.TaskMetadata( + True, + task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + timedelta(days=1), + literals.RetryStrategy(3), + True, + "0.1.1b0", + "This is deprecated!", + True, + "A", + (), + ) + task_config = { + "sparkConf": {}, + "mainApplicationFile": "dbfs:/entrypoint.py", + "databricksConf": { + "run_name": "flytekit serverless job - no git source", + "environment_key": "default", + "environments": [{ + "environment_key": "default", + "spec": { + "client": "4", + } + }], + "timeout_seconds": 3600, + } + } + container = Container( + image="flyteorg/flytekit:databricks-0.18.0-py3.7", + command=[], + args=["pyflyte-execute", "--inputs", "s3://my-s3-bucket"], + resources=Resources(requests=[], limits=[]), + env={"foo": "bar"}, + config={}, + ) + + return TaskTemplate( + id=task_id, + custom=task_config, + metadata=task_metadata, + container=container, + interface=None, + type="spark", + ) + + +@pytest.fixture(scope="function") +def invalid_task_template_no_compute() -> TaskTemplate: + """Task template with no cluster or environment config - should fail.""" + task_id = Identifier( + 
resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version" + ) + task_metadata = task.TaskMetadata( + True, + task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + timedelta(days=1), + literals.RetryStrategy(3), + True, + "0.1.1b0", + "This is deprecated!", + True, + "A", + (), + ) + task_config = { + "sparkConf": {}, + "mainApplicationFile": "dbfs:/entrypoint.py", + "databricksConf": { + "run_name": "invalid job - no compute config", + "timeout_seconds": 3600, + } + } + container = Container( + image="flyteorg/flytekit:databricks-0.18.0-py3.7", + command=[], + args=["pyflyte-execute", "--inputs", "s3://my-s3-bucket"], + resources=Resources(requests=[], limits=[]), + env={"foo": "bar"}, + config={}, + ) + + return TaskTemplate( + id=task_id, + custom=task_config, + metadata=task_metadata, + container=container, + interface=None, + type="spark", + ) + + +def test_is_serverless_config_detection(): + """Test the serverless configuration detection logic.""" + # Classic compute with existing_cluster_id + assert _is_serverless_config({"existing_cluster_id": "abc123"}) is False + + # Classic compute with new_cluster + assert _is_serverless_config({"new_cluster": {"spark_version": "13.3"}}) is False + + # Serverless with environment_key only + assert _is_serverless_config({"environment_key": "my-env"}) is True + + # Serverless with environments array + assert _is_serverless_config({"environments": [{"environment_key": "default"}]}) is True + + # Serverless with both environment_key and environments + assert _is_serverless_config({ + "environment_key": "default", + "environments": [{"environment_key": "default"}] + }) is True + + # No compute config at all + assert _is_serverless_config({"run_name": "test"}) is False + + # Has cluster AND environment (cluster takes precedence, not serverless) + assert _is_serverless_config({ + "new_cluster": {"spark_version": "13.3"}, + "environment_key": "my-env" + 
}) is False + + +def test_configure_serverless_with_env_key_only(): + """Test serverless configuration with environment_key only (no environments array).""" + databricks_job = {"environment_key": "my-env"} + envs = {"FOO": "bar", FLYTE_FAIL_ON_ERROR: "true"} + + result_key = _configure_serverless(databricks_job, envs) + + assert result_key == "my-env" + # Databricks serverless requires environments array - it should be auto-created + assert "environments" in databricks_job + assert len(databricks_job["environments"]) == 1 + assert databricks_job["environments"][0]["environment_key"] == "my-env" + # Environment variables should be injected + env_vars = databricks_job["environments"][0]["spec"]["environment_vars"] + assert env_vars["FOO"] == "bar" + assert env_vars[FLYTE_FAIL_ON_ERROR] == "true" + # environment_key should be removed from top level + assert "environment_key" not in databricks_job + + +def test_configure_serverless_with_inline_env(): + """Test serverless configuration with inline environment spec.""" + databricks_job = { + "environment_key": "default", + "environments": [{ + "environment_key": "default", + "spec": { + "client": "1", + "dependencies": ["pandas==2.0.0"], + } + }] + } + envs = {"FOO": "bar", FLYTE_FAIL_ON_ERROR: "true"} + + result_key = _configure_serverless(databricks_job, envs) + + assert result_key == "default" + # Environment variables should be injected + env_vars = databricks_job["environments"][0]["spec"]["environment_vars"] + assert env_vars["FOO"] == "bar" + assert env_vars[FLYTE_FAIL_ON_ERROR] == "true" + # environment_key should be removed from top level + assert "environment_key" not in databricks_job + + +def test_configure_serverless_creates_default_env(): + """Test that serverless creates a default environment when no environment specified.""" + databricks_job = {} # No environment_key or environments + envs = {"FOO": "bar"} + + result_key = _configure_serverless(databricks_job, envs) + + assert result_key == "default" + 
assert len(databricks_job["environments"]) == 1 + assert databricks_job["environments"][0]["environment_key"] == "default" + # Should have env vars injected + assert databricks_job["environments"][0]["spec"]["environment_vars"]["FOO"] == "bar" + + +def test_get_databricks_job_spec_serverless_with_env_key(serverless_task_template_with_env_key: TaskTemplate): + """Test job spec generation for serverless with environment_key only.""" + serverless_task_template_with_env_key.custom["databricksInstance"] = "test-account.cloud.databricks.com" + + spec = _get_databricks_job_spec(serverless_task_template_with_env_key) + + # Serverless uses multi-task format with tasks array + assert "tasks" in spec + assert len(spec["tasks"]) == 1 + + task_def = spec["tasks"][0] + assert task_def["task_key"] == "flyte_task" + assert task_def["environment_key"] == "my-preconfigured-env" + assert "spark_python_task" in task_def + + # Databricks serverless requires environments array - should be auto-created + assert "environments" in spec + assert len(spec["environments"]) == 1 + assert spec["environments"][0]["environment_key"] == "my-preconfigured-env" + + # Should NOT have spark_python_task at top level for serverless + assert "spark_python_task" not in spec + + # Should NOT have environment_key at top level (moved to task) + assert "environment_key" not in spec + + # Should NOT have cluster config + assert "new_cluster" not in spec + assert "existing_cluster_id" not in spec + + # Should have git_source + assert "git_source" in spec + + +def test_get_databricks_job_spec_serverless_with_inline_env(serverless_task_template_with_inline_env: TaskTemplate): + """Test job spec generation for serverless with inline environment spec.""" + serverless_task_template_with_inline_env.custom["databricksInstance"] = "test-account.cloud.databricks.com" + + spec = _get_databricks_job_spec(serverless_task_template_with_inline_env) + + # Serverless uses multi-task format with tasks array + assert "tasks" in 
spec + assert len(spec["tasks"]) == 1 + + task_def = spec["tasks"][0] + assert task_def["task_key"] == "flyte_task" + assert task_def["environment_key"] == "default" + assert "spark_python_task" in task_def + + # Should have environments array with injected env vars + assert "environments" in spec + env_vars = spec["environments"][0]["spec"]["environment_vars"] + assert env_vars["foo"] == "bar" + assert env_vars["MY_VAR"] == "my_value" + assert env_vars[FLYTE_FAIL_ON_ERROR] == "true" + + # Should NOT have cluster config + assert "new_cluster" not in spec + assert "existing_cluster_id" not in spec + + +def test_get_databricks_job_spec_error_no_compute(invalid_task_template_no_compute: TaskTemplate): + """Test that job spec generation fails when no compute config is provided.""" + with pytest.raises(ValueError) as exc_info: + _get_databricks_job_spec(invalid_task_template_no_compute) + + assert "existing_cluster_id" in str(exc_info.value) + assert "new_cluster" in str(exc_info.value) + assert "environment_key" in str(exc_info.value) + assert "environments" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_databricks_agent_serverless(serverless_task_template_with_env_key: TaskTemplate): + """Test the full agent flow with serverless compute.""" + import copy + agent = AgentRegistry.get_agent("spark") + + serverless_task_template_with_env_key.custom["databricksInstance"] = "test-account.cloud.databricks.com" + + # Generate spec BEFORE agent.create() mutates the template in-place + spec_copy = copy.deepcopy(serverless_task_template_with_env_key) + spec = _get_databricks_job_spec(spec_copy) + + # Verify serverless config uses multi-task format with environments + assert "tasks" in spec + task_def = spec["tasks"][0] + assert task_def["task_key"] == "flyte_task" + assert task_def["environment_key"] == "my-preconfigured-env" + assert "spark_python_task" in task_def + assert "environments" in spec # Required for serverless + assert "new_cluster" not in spec + + 
mocked_token = "mocked_databricks_token" + mocked_context = mock.patch("flytekit.current_context", autospec=True).start() + mocked_context.return_value.secrets.get.return_value = mocked_token + + databricks_metadata = DatabricksJobMetadata( + databricks_instance="test-account.cloud.databricks.com", + run_id="456", + ) + + mock_create_response = {"run_id": "456"} + mock_get_response = { + "job_id": "2", + "run_id": "456", + "state": {"life_cycle_state": "TERMINATED", "result_state": "SUCCESS", "state_message": "OK"}, + } + + create_url = f"https://test-account.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/submit" + get_url = f"https://test-account.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/get?run_id=456" + + with aioresponses() as mocked: + mocked.post(create_url, status=http.HTTPStatus.OK, payload=mock_create_response) + res = await agent.create(serverless_task_template_with_env_key, None) + assert res == databricks_metadata + + mocked.get(get_url, status=http.HTTPStatus.OK, payload=mock_get_response) + resource = await agent.get(databricks_metadata) + assert resource.phase == TaskExecution.SUCCEEDED + + mock.patch.stopall() + + +# ==================== Default Serverless Entrypoint Tests ==================== + + +def test_serverless_default_entrypoint_from_flytetools(serverless_task_template_no_git_source: TaskTemplate): + """Test that serverless uses the default flytetools entrypoint when no git_source in task config.""" + spec = _get_databricks_job_spec(serverless_task_template_no_git_source) + + # Should use the same flytetools repo as classic + assert spec["git_source"]["git_url"] == "https://github.com/flyteorg/flytetools" + assert spec["git_source"]["git_provider"] == "gitHub" + assert "git_commit" in spec["git_source"] + + # Should use the serverless-specific python_file + task_def = spec["tasks"][0] + assert task_def["spark_python_task"]["python_file"] == "flytekitplugins/databricks/entrypoint_serverless.py" + + # Should still be valid 
serverless format + assert "environments" in spec + assert "new_cluster" not in spec + + +def test_serverless_task_git_source_overrides_default(serverless_task_template_with_env_key: TaskTemplate): + """Test that task-level git_source takes precedence over the flytetools default.""" + spec = _get_databricks_job_spec(serverless_task_template_with_env_key) + + # Should use the task-level git_source, NOT the flytetools default + assert spec["git_source"]["git_url"] == "https://github.com/test-org/test-repo" + assert spec["git_source"]["git_branch"] == "main" + + # Should use the task-level python_file + task_def = spec["tasks"][0] + assert task_def["spark_python_task"]["python_file"] == "entrypoint_serverless.py" + + +def test_classic_and_serverless_use_same_repo(task_template: TaskTemplate, serverless_task_template_no_git_source: TaskTemplate): + """Test that both classic and serverless default to the same flytetools repo.""" + classic_spec = _get_databricks_job_spec(task_template) + serverless_spec = _get_databricks_job_spec(serverless_task_template_no_git_source) + + # Same repo + assert classic_spec["git_source"]["git_url"] == serverless_spec["git_source"]["git_url"] + # Same commit + assert classic_spec["git_source"]["git_commit"] == serverless_spec["git_source"]["git_commit"] + # Different python_file + assert classic_spec["spark_python_task"]["python_file"] == "flytekitplugins/databricks/entrypoint.py" + assert serverless_spec["tasks"][0]["spark_python_task"]["python_file"] == "flytekitplugins/databricks/entrypoint_serverless.py" diff --git a/plugins/flytekit-spark/tests/test_spark_task.py b/plugins/flytekit-spark/tests/test_spark_task.py index 7198a4dec0..76c229766c 100644 --- a/plugins/flytekit-spark/tests/test_spark_task.py +++ b/plugins/flytekit-spark/tests/test_spark_task.py @@ -1,3 +1,4 @@ +import os import os.path from unittest import mock @@ -493,3 +494,216 @@ def my_spark(a: str) -> int: configs = my_spark.sess.sparkContext.getConf().getAll() assert 
("spark.driver.memory", "1000M") in configs assert ("spark.app.name", "FlyteSpark: ex:local:local:local") in configs + + +# ==================== Serverless Detection Tests ==================== + + +def test_databricks_v2_serverless_detection_with_env_var(reset_spark_session): + """Test that serverless is detected when DATABRICKS_SERVERLESS env var is set.""" + databricks_conf = { + "run_name": "test", + "new_cluster": {"spark_version": "13.3.x-scala2.12"}, # Has cluster config + } + + @task( + task_config=DatabricksV2( + databricks_conf=databricks_conf, + databricks_instance="test.cloud.databricks.com", + ) + ) + def my_task(a: int) -> int: + return a + + # Without env var, should NOT be serverless (has new_cluster) + assert my_task._is_databricks_serverless() is False + + # With env var set, should BE serverless + os.environ["DATABRICKS_SERVERLESS"] = "true" + try: + assert my_task._is_databricks_serverless() is True + finally: + del os.environ["DATABRICKS_SERVERLESS"] + + +def test_databricks_v2_serverless_detection_with_config(reset_spark_session): + """Test that serverless is detected based on DatabricksV2 config.""" + # Serverless config: has environment_key, no cluster config + serverless_conf = { + "run_name": "serverless-test", + "environment_key": "my-env", + } + + @task( + task_config=DatabricksV2( + databricks_conf=serverless_conf, + databricks_instance="test.cloud.databricks.com", + ) + ) + def serverless_task(a: int) -> int: + return a + + # Should detect serverless from config + assert serverless_task._is_databricks_serverless() is True + + # Classic config: has new_cluster + classic_conf = { + "run_name": "classic-test", + "new_cluster": {"spark_version": "13.3.x-scala2.12"}, + } + + @task( + task_config=DatabricksV2( + databricks_conf=classic_conf, + databricks_instance="test.cloud.databricks.com", + ) + ) + def classic_task(a: int) -> int: + return a + + # Should NOT detect serverless + assert classic_task._is_databricks_serverless() is False + + 
+def test_databricks_v2_serverless_detection_with_environments_array(reset_spark_session): + """Test serverless detection with inline environments array.""" + serverless_conf = { + "run_name": "serverless-inline", + "environments": [{ + "environment_key": "default", + "spec": {"client": "1", "dependencies": ["pandas"]} + }], + } + + @task( + task_config=DatabricksV2( + databricks_conf=serverless_conf, + databricks_instance="test.cloud.databricks.com", + ) + ) + def serverless_task(a: int) -> int: + return a + + assert serverless_task._is_databricks_serverless() is True + + +def test_databricks_v2_classic_not_detected_as_serverless(reset_spark_session): + """Test that classic compute is not incorrectly detected as serverless.""" + # Classic with existing_cluster_id + existing_cluster_conf = { + "run_name": "existing-cluster", + "existing_cluster_id": "abc-123", + } + + @task( + task_config=DatabricksV2( + databricks_conf=existing_cluster_conf, + databricks_instance="test.cloud.databricks.com", + ) + ) + def existing_cluster_task(a: int) -> int: + return a + + assert existing_cluster_task._is_databricks_serverless() is False + + # Classic with new_cluster AND environment_key (cluster takes precedence) + mixed_conf = { + "run_name": "mixed", + "new_cluster": {"spark_version": "13.3.x-scala2.12"}, + "environment_key": "my-env", # Should be ignored + } + + @task( + task_config=DatabricksV2( + databricks_conf=mixed_conf, + databricks_instance="test.cloud.databricks.com", + ) + ) + def mixed_task(a: int) -> int: + return a + + assert mixed_task._is_databricks_serverless() is False + + +def test_databricks_v2_service_credential_provider(): + """Test that service credential provider is properly serialized.""" + serverless_conf = { + "run_name": "serverless-with-creds", + "environment_key": "my-env", + } + + @task( + task_config=DatabricksV2( + databricks_conf=serverless_conf, + databricks_instance="test.cloud.databricks.com", + 
databricks_service_credential_provider="my-credential-provider", + ) + ) + def task_with_creds(a: int) -> int: + return a + + default_img = Image(name="default", fqn="test", tag="tag") + settings = SerializationSettings( + project="project", + domain="domain", + version="version", + env={}, + image_config=ImageConfig(default_image=default_img, images=[default_img]), + ) + + custom = task_with_creds.get_custom(settings) + assert custom.get("databricksServiceCredentialProvider") == "my-credential-provider" + + +def test_databricks_v2_no_service_credential_provider(): + """Test that custom dict doesn't have credential provider when not set.""" + serverless_conf = { + "run_name": "serverless-no-creds", + "environment_key": "my-env", + } + + @task( + task_config=DatabricksV2( + databricks_conf=serverless_conf, + databricks_instance="test.cloud.databricks.com", + ) + ) + def task_no_creds(a: int) -> int: + return a + + default_img = Image(name="default", fqn="test", tag="tag") + settings = SerializationSettings( + project="project", + domain="domain", + version="version", + env={}, + image_config=ImageConfig(default_image=default_img, images=[default_img]), + ) + + custom = task_no_creds.get_custom(settings) + assert "databricksServiceCredentialProvider" not in custom + + +def test_spark_classic_not_affected_by_serverless_code(reset_spark_session): + """Test that regular Spark tasks (non-Databricks) are not affected by serverless code.""" + @task( + task_config=Spark( + spark_conf={"spark.driver.memory": "512M"}, + ) + ) + def spark_task(a: int) -> int: + return a + + # Regular Spark task should NOT be detected as serverless + assert spark_task._is_databricks_serverless() is False + + # pre_execute should work normally + pb = ExecutionParameters.new_builder() + pb.working_dir = "/tmp" + pb.execution_id = "ex:local:local:local" + p = pb.build() + new_p = spark_task.pre_execute(p) + + assert new_p is not None + assert new_p.has_attr("SPARK_SESSION") + assert 
spark_task.sess is not None From 94c5a1d5f831ba376a54de5176b8e1d99c22b14f Mon Sep 17 00:00:00 2001 From: Rohit Sharma Date: Mon, 2 Mar 2026 21:21:46 +0000 Subject: [PATCH 2/3] refactor(spark): Address review comments on serverless support - Extract _build_notebook_job_spec() and _build_python_file_job_spec() - Move compute validation to beginning of _get_databricks_job_spec() - Add API docs link and config format to _configure_serverless docstring - Create shared utils.py with is_serverless_config() for reuse in task.py Signed-off-by: Rohit Sharma --- .../flytekitplugins/spark/connector.py | 181 +++++++++--------- .../flytekitplugins/spark/task.py | 7 +- .../flytekitplugins/spark/utils.py | 16 ++ 3 files changed, 108 insertions(+), 96 deletions(-) create mode 100644 plugins/flytekit-spark/flytekitplugins/spark/utils.py diff --git a/plugins/flytekit-spark/flytekitplugins/spark/connector.py b/plugins/flytekit-spark/flytekitplugins/spark/connector.py index 628ef057f8..fe8692da9f 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/connector.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/connector.py @@ -17,6 +17,8 @@ aiohttp = lazy_module("aiohttp") +from .utils import is_serverless_config as _is_serverless_config + DATABRICKS_API_ENDPOINT = "/api/2.1/jobs" DEFAULT_DATABRICKS_INSTANCE_ENV_KEY = "FLYTE_DATABRICKS_INSTANCE" DEFAULT_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER_ENV_KEY = "FLYTE_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER" @@ -28,31 +30,38 @@ class DatabricksJobMetadata(ResourceMeta): run_id: str -def _is_serverless_config(databricks_job: dict) -> bool: - """ - Detect if the configuration is for serverless compute. - Serverless is indicated by having environment_key or environments without cluster config. 
- """ - # Check if cluster config keys exist (even empty dict counts as cluster config) - has_cluster_config = "existing_cluster_id" in databricks_job or "new_cluster" in databricks_job - has_serverless_config = bool(databricks_job.get("environment_key") or databricks_job.get("environments")) - return not has_cluster_config and has_serverless_config - - def _configure_serverless(databricks_job: dict, envs: dict) -> str: """ Configure serverless compute settings and return the environment_key to use. - Databricks serverless requires the 'environments' array to be defined in the job - submission. This function ensures the environments array exists and injects - Flyte environment variables. + Databricks serverless requires the ``environments`` array in the job submission. + This function ensures the array exists and injects Flyte environment variables + into the matching environment's ``spec.environment_vars``. + + Reference: https://docs.databricks.com/api/workspace/jobs/submit + + Expected ``environments`` format:: + + "environments": [ + { + "environment_key": "", + "spec": { + "client": "1", + "dependencies": ["pandas==2.0.0"], + "environment_vars": {"KEY": "VALUE"} + } + } + ] + + Tasks reference an environment via their own ``environment_key`` field, + analogous to how ``job_cluster_key`` links a task to a shared cluster. Args: - databricks_job: The databricks job configuration dict - envs: Environment variables to inject + databricks_job: The databricks job configuration dict. + envs: Environment variables to inject into the environment spec. Returns: - The environment_key to use for the task + The environment_key to use in the task definition. 
""" environment_key = databricks_job.get("environment_key", "default") environments = databricks_job.get("environments", []) @@ -104,9 +113,7 @@ def _configure_classic_cluster(databricks_job: dict, custom: dict, container, en new_cluster = databricks_job.get("new_cluster") if new_cluster is None: - raise ValueError( - "Either existing_cluster_id, new_cluster, environment_key, or environments must be specified" - ) + return if not new_cluster.get("docker_image"): new_cluster["docker_image"] = {"url": container.image} @@ -118,65 +125,47 @@ def _configure_classic_cluster(databricks_job: dict, custom: dict, container, en new_cluster["spark_env_vars"].update({k: v for k, v in envs.items()}) -def _get_databricks_job_spec(task_template: TaskTemplate) -> dict: - custom = task_template.custom - container = task_template.container - envs = task_template.container.env - envs[FLYTE_FAIL_ON_ERROR] = "true" - databricks_job = custom["databricksConf"] - - # Check if this is a notebook task - notebook_path = custom.get("notebookPath") +def _build_notebook_job_spec( + databricks_job: dict, custom: dict, container, envs: dict, is_serverless: bool +) -> dict: + """Build the Databricks job spec for a notebook task.""" + notebook_path = custom["notebookPath"] notebook_base_parameters = custom.get("notebookBaseParameters", {}) - # Determine compute mode and configure accordingly - is_serverless = _is_serverless_config(databricks_job) + notebook_task = {"notebook_path": notebook_path} + if notebook_base_parameters: + notebook_task["base_parameters"] = notebook_base_parameters + + user_git_source = databricks_job.get("git_source") + if user_git_source: + notebook_task["source"] = "GIT" - if notebook_path: - # Notebook task - runs a Databricks notebook - notebook_task = { - "notebook_path": notebook_path, + if is_serverless: + environment_key = _configure_serverless(databricks_job, envs) + task_def = { + "task_key": "flyte_notebook_task", + "notebook_task": notebook_task, + 
"environment_key": environment_key, } - if notebook_base_parameters: - notebook_task["base_parameters"] = notebook_base_parameters - - # Check if notebook should be sourced from git - user_git_source = databricks_job.get("git_source") - if user_git_source: - notebook_task["source"] = "GIT" - # Set git_source at job level - databricks_job["git_source"] = user_git_source - - if is_serverless: - # Serverless notebook task - environment_key = _configure_serverless(databricks_job, envs) - - task_def = { - "task_key": "flyte_notebook_task", - "notebook_task": notebook_task, - "environment_key": environment_key, - } - databricks_job["tasks"] = [task_def] - databricks_job.pop("environment_key", None) - else: - # Classic compute notebook task - _configure_classic_cluster(databricks_job, custom, container, envs) - databricks_job["notebook_task"] = notebook_task - - # Clean up git_source from databricks_job if it was there (already set at job level) - databricks_job.pop("git_source", None) - if user_git_source: - databricks_job["git_source"] = user_git_source - - return databricks_job - - # Python file task (original behavior) - # Allow custom git_source and python_file override from user config + databricks_job["tasks"] = [task_def] + else: + _configure_classic_cluster(databricks_job, custom, container, envs) + databricks_job["notebook_task"] = notebook_task + + databricks_job.pop("git_source", None) + if user_git_source: + databricks_job["git_source"] = user_git_source + + return databricks_job + + +def _build_python_file_job_spec( + databricks_job: dict, custom: dict, container, envs: dict, is_serverless: bool +) -> dict: + """Build the Databricks job spec for a python file (spark_python_task).""" user_git_source = databricks_job.get("git_source") user_python_file = databricks_job.get("python_file") - # Default entrypoints from the flytetools repo. - # Both classic and serverless use the same repo; only the python_file differs. 
default_git_source = { "git_url": "https://github.com/flyteorg/flytetools", "git_provider": "gitHub", @@ -186,25 +175,18 @@ def _get_databricks_job_spec(task_template: TaskTemplate) -> dict: default_serverless_python_file = "flytekitplugins/databricks/entrypoint_serverless.py" if is_serverless: - # Serverless compute - use flytetools serverless entrypoint by default - git_source = user_git_source if user_git_source else default_git_source - python_file = user_python_file if user_python_file else default_serverless_python_file + git_source = user_git_source or default_git_source + python_file = user_python_file or default_serverless_python_file - # Serverless requires multi-task format with tasks array environment_key = _configure_serverless(databricks_job, envs) - # Build parameters list - append credential provider if specified - # This allows the entrypoint to receive the credential provider via command line - # We append at the END to avoid breaking pyflyte-fast-execute which must be first parameters = list(container.args) if container.args else [] - - # Resolve service credential provider: task config > env var + service_credential_provider = custom.get( "databricksServiceCredentialProvider", os.getenv(DEFAULT_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER_ENV_KEY) ) if service_credential_provider: - # Append as a special argument that the entrypoint will parse and remove parameters.append(f"--flyte-credential-provider={service_credential_provider}") spark_python_task = { @@ -213,22 +195,16 @@ def _get_databricks_job_spec(task_template: TaskTemplate) -> dict: "parameters": parameters, } - # Build the task definition for serverless task_def = { "task_key": "flyte_task", "spark_python_task": spark_python_task, "environment_key": environment_key, } - # Add tasks array for serverless (required by Databricks API) databricks_job["tasks"] = [task_def] - - # Remove environment_key from top level (it's now in the task) - databricks_job.pop("environment_key", None) else: - # 
Classic compute - use flytetools entrypoint by default - git_source = user_git_source if user_git_source else default_git_source - python_file = user_python_file if user_python_file else default_classic_python_file + git_source = user_git_source or default_git_source + python_file = user_python_file or default_classic_python_file spark_python_task = { "python_file": python_file, @@ -239,7 +215,6 @@ def _get_databricks_job_spec(task_template: TaskTemplate) -> dict: _configure_classic_cluster(databricks_job, custom, container, envs) databricks_job["spark_python_task"] = spark_python_task - # Set git_source (remove from user config if it was there to avoid duplication) databricks_job.pop("git_source", None) databricks_job.pop("python_file", None) databricks_job["git_source"] = git_source @@ -247,6 +222,30 @@ def _get_databricks_job_spec(task_template: TaskTemplate) -> dict: return databricks_job +def _get_databricks_job_spec(task_template: TaskTemplate) -> dict: + custom = task_template.custom + container = task_template.container + envs = task_template.container.env + envs[FLYTE_FAIL_ON_ERROR] = "true" + databricks_job = custom["databricksConf"] + + has_cluster = "existing_cluster_id" in databricks_job or "new_cluster" in databricks_job + has_serverless = bool(databricks_job.get("environment_key") or databricks_job.get("environments")) + if not has_cluster and not has_serverless: + raise ValueError( + "No compute configuration found in databricks_conf. " + "Provide one of: 'existing_cluster_id' (classic), 'new_cluster' (classic), " + "'environment_key' (serverless), or 'environments' (serverless)." 
+ ) + + is_serverless = _is_serverless_config(databricks_job) + + if custom.get("notebookPath"): + return _build_notebook_job_spec(databricks_job, custom, container, envs, is_serverless) + + return _build_python_file_job_spec(databricks_job, custom, container, envs, is_serverless) + + class DatabricksConnector(AsyncConnectorBase): name = "Databricks Connector" diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index c8b09447dd..1ffb68ca9d 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -18,6 +18,7 @@ from flytekit.models.task import K8sPod from .models import SparkJob, SparkType +from .utils import is_serverless_config pyspark_sql = lazy_module("pyspark.sql") SparkSession = pyspark_sql.SparkSession @@ -349,11 +350,7 @@ def _is_databricks_serverless(self) -> bool: # Additional check: if using DatabricksV2 with serverless config if isinstance(self.task_config, DatabricksV2): conf = self.task_config.databricks_conf or {} - has_serverless_config = ( - "environment_key" in conf or - "environments" in conf - ) and "new_cluster" not in conf and "existing_cluster_id" not in conf - if has_serverless_config: + if is_serverless_config(conf): return True return is_databricks and "SPARK_HOME" not in os.environ diff --git a/plugins/flytekit-spark/flytekitplugins/spark/utils.py b/plugins/flytekit-spark/flytekitplugins/spark/utils.py new file mode 100644 index 0000000000..3abe59b377 --- /dev/null +++ b/plugins/flytekit-spark/flytekitplugins/spark/utils.py @@ -0,0 +1,16 @@ +def is_serverless_config(databricks_conf: dict) -> bool: + """ + Detect if the Databricks configuration is for serverless compute. + + Serverless is indicated by having ``environment_key`` or ``environments`` + without any cluster config (``existing_cluster_id`` or ``new_cluster``). + + Args: + databricks_conf: The databricks job configuration dict. 
+ + Returns: + True if the configuration targets serverless compute. + """ + has_cluster_config = "existing_cluster_id" in databricks_conf or "new_cluster" in databricks_conf + has_serverless_config = bool(databricks_conf.get("environment_key") or databricks_conf.get("environments")) + return not has_cluster_config and has_serverless_config From 27ea1887f97f3254218a1d137b7c214fa2069e90 Mon Sep 17 00:00:00 2001 From: Rohit Sharma Date: Thu, 5 Mar 2026 13:02:11 +0000 Subject: [PATCH 3/3] refactor(spark): Address review comments - value checks, serverless flag, lint fixes - Switch validation to value-based .get() checks so None values are rejected - Combine serverless config flag with is_databricks check in task.py - Fix ruff, ruff-format, and pydoclint violations - Update tests for new serverless detection behavior - Update pydoclint-errors-baseline.txt (fixed DatabricksV2 docstring) Signed-off-by: Rohit Sharma --- Dockerfile.connector | 2 +- .../flytekitplugins/spark/connector.py | 31 +++--- .../flytekitplugins/spark/task.py | 103 ++++++++++-------- .../flytekitplugins/spark/utils.py | 8 +- .../flytekit-spark/tests/test_spark_task.py | 99 +++++++++-------- pydoclint-errors-baseline.txt | 4 - 6 files changed, 126 insertions(+), 121 deletions(-) diff --git a/Dockerfile.connector b/Dockerfile.connector index 7eea1d42b9..bcb51e5db9 100644 --- a/Dockerfile.connector +++ b/Dockerfile.connector @@ -35,4 +35,4 @@ ARG VERSION RUN uv pip install --system --no-cache-dir -U \ flytekitplugins-mmcloud==$VERSION \ - flytekitplugins-spark==$VERSION \ No newline at end of file + flytekitplugins-spark==$VERSION diff --git a/plugins/flytekit-spark/flytekitplugins/spark/connector.py b/plugins/flytekit-spark/flytekitplugins/spark/connector.py index fe8692da9f..5adb9157c4 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/connector.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/connector.py @@ -15,10 +15,10 @@ from flytekit.models.literals import LiteralMap from 
flytekit.models.task import TaskTemplate -aiohttp = lazy_module("aiohttp") - from .utils import is_serverless_config as _is_serverless_config +aiohttp = lazy_module("aiohttp") + DATABRICKS_API_ENDPOINT = "/api/2.1/jobs" DEFAULT_DATABRICKS_INSTANCE_ENV_KEY = "FLYTE_DATABRICKS_INSTANCE" DEFAULT_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER_ENV_KEY = "FLYTE_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER" @@ -57,11 +57,11 @@ def _configure_serverless(databricks_job: dict, envs: dict) -> str: analogous to how ``job_cluster_key`` links a task to a shared cluster. Args: - databricks_job: The databricks job configuration dict. - envs: Environment variables to inject into the environment spec. + databricks_job (dict): The databricks job configuration dict. + envs (dict): Environment variables to inject into the environment spec. Returns: - The environment_key to use in the task definition. + str: The environment_key to use in the task definition. """ environment_key = databricks_job.get("environment_key", "default") environments = databricks_job.get("environments", []) @@ -76,7 +76,7 @@ def _configure_serverless(databricks_job: dict, envs: dict) -> str: "environment_key": environment_key, "spec": { "client": "1", # Required: Databricks serverless client version - } + }, } environments.append(new_env) databricks_job["environments"] = environments @@ -97,15 +97,15 @@ def _configure_serverless(databricks_job: dict, envs: dict) -> str: return environment_key -def _configure_classic_cluster(databricks_job: dict, custom: dict, container, envs: dict) -> None: +def _configure_classic_cluster(databricks_job: dict, custom: dict, container: typing.Any, envs: dict) -> None: """ Configure classic compute (existing cluster or new cluster). Args: - databricks_job: The databricks job configuration dict - custom: The custom config from task template - container: The container config from task template - envs: Environment variables to inject + databricks_job (dict): The databricks job configuration dict. 
+ custom (dict): The custom config from task template. + container (typing.Any): The container config from task template. + envs (dict): Environment variables to inject. """ if databricks_job.get("existing_cluster_id") is not None: # Using an existing cluster, no additional configuration needed @@ -126,7 +126,7 @@ def _configure_classic_cluster(databricks_job: dict, custom: dict, container, en def _build_notebook_job_spec( - databricks_job: dict, custom: dict, container, envs: dict, is_serverless: bool + databricks_job: dict, custom: dict, container: typing.Any, envs: dict, is_serverless: bool ) -> dict: """Build the Databricks job spec for a notebook task.""" notebook_path = custom["notebookPath"] @@ -160,7 +160,7 @@ def _build_notebook_job_spec( def _build_python_file_job_spec( - databricks_job: dict, custom: dict, container, envs: dict, is_serverless: bool + databricks_job: dict, custom: dict, container: typing.Any, envs: dict, is_serverless: bool ) -> dict: """Build the Databricks job spec for a python file (spark_python_task).""" user_git_source = databricks_job.get("git_source") @@ -183,8 +183,7 @@ def _build_python_file_job_spec( parameters = list(container.args) if container.args else [] service_credential_provider = custom.get( - "databricksServiceCredentialProvider", - os.getenv(DEFAULT_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER_ENV_KEY) + "databricksServiceCredentialProvider", os.getenv(DEFAULT_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER_ENV_KEY) ) if service_credential_provider: parameters.append(f"--flyte-credential-provider={service_credential_provider}") @@ -229,7 +228,7 @@ def _get_databricks_job_spec(task_template: TaskTemplate) -> dict: envs[FLYTE_FAIL_ON_ERROR] = "true" databricks_job = custom["databricksConf"] - has_cluster = "existing_cluster_id" in databricks_job or "new_cluster" in databricks_job + has_cluster = databricks_job.get("existing_cluster_id") is not None or databricks_job.get("new_cluster") is not None has_serverless = 
bool(databricks_job.get("environment_key") or databricks_job.get("environments")) if not has_cluster and not has_serverless: raise ValueError( diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index 1ffb68ca9d..6492183969 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -78,11 +78,18 @@ class DatabricksV2(Spark): Supports both classic compute (clusters) and serverless compute. - Args: - databricks_conf: Databricks job configuration compliant with API version 2.1, supporting 2.0 use cases. + Attributes: + databricks_conf (Optional[Dict[str, Union[str, dict]]]): Databricks job configuration + compliant with API version 2.1, supporting 2.0 use cases. For the configuration structure, visit: https://docs.databricks.com/dev-tools/api/2.0/jobs.html#request-structure For updates in API 2.1, refer to: https://docs.databricks.com/en/workflows/jobs/jobs-api-updates.html - databricks_instance: Domain name of your deployment. Use the form .cloud.databricks.com. + databricks_instance (Optional[str]): Domain name of your deployment. + Use the form .cloud.databricks.com. + databricks_service_credential_provider (Optional[str]): Provider name for Databricks + Service Credentials for S3 access. Falls back to FLYTE_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER env var. + notebook_path (Optional[str]): Path to Databricks notebook + (e.g., "/Users/user@example.com/notebook"). + notebook_base_parameters (Optional[Dict[str, str]]): Parameters to pass to the notebook. 
Compute Modes: The connector auto-detects the compute mode based on the databricks_conf contents: @@ -150,31 +157,31 @@ class DatabricksV2(Spark): - Must use Unity Catalog for external data sources - No support for compute-scoped init scripts or libraries For full details, see: https://docs.databricks.com/en/compute/serverless/limitations.html - + Serverless Entrypoint: Both classic and serverless use the same ``flytetools`` repo for their entrypoints. Classic uses ``flytekitplugins/databricks/entrypoint.py`` and serverless uses ``flytekitplugins/databricks/entrypoint_serverless.py``. No additional configuration needed. - + To override the default, provide ``git_source`` and ``python_file`` in ``databricks_conf``. AWS Credentials for Serverless: Databricks serverless does not provide AWS credentials via instance metadata. To access S3 (for Flyte data), configure a Databricks Service Credential. - + The provider name is resolved in this order: 1. ``databricks_service_credential_provider`` in the task config (per-task override) 2. ``FLYTE_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER`` environment variable on the connector (default for all tasks) - + The entrypoint will use this to obtain AWS credentials via: dbutils.credentials.getServiceCredentialsProvider(provider_name) Notebook Support: To run a Databricks notebook instead of a Python file, set `notebook_path`. Parameters can be passed via `notebook_base_parameters`. 
- + Example - Running a notebook:: - + DatabricksV2( databricks_conf={ "run_name": "my-notebook-job", @@ -188,7 +195,9 @@ class DatabricksV2(Spark): databricks_conf: Optional[Dict[str, Union[str, dict]]] = None databricks_instance: Optional[str] = None # Falls back to FLYTE_DATABRICKS_INSTANCE env var - databricks_service_credential_provider: Optional[str] = None # Falls back to FLYTE_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER env var + databricks_service_credential_provider: Optional[str] = ( + None # Falls back to FLYTE_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER env var + ) notebook_path: Optional[str] = None # Path to Databricks notebook (e.g., "/Users/user@example.com/notebook") notebook_base_parameters: Optional[Dict[str, str]] = None # Parameters to pass to the notebook @@ -296,17 +305,17 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: # Serialize to dict custom_dict = MessageToDict(job.to_flyte_idl()) - + # Add DatabricksV2-specific fields (not part of protobuf) if isinstance(self.task_config, DatabricksV2): cfg = cast(DatabricksV2, self.task_config) if cfg.databricks_service_credential_provider: - custom_dict['databricksServiceCredentialProvider'] = cfg.databricks_service_credential_provider + custom_dict["databricksServiceCredentialProvider"] = cfg.databricks_service_credential_provider if cfg.notebook_path: - custom_dict['notebookPath'] = cfg.notebook_path + custom_dict["notebookPath"] = cfg.notebook_path if cfg.notebook_base_parameters: - custom_dict['notebookBaseParameters'] = cfg.notebook_base_parameters - + custom_dict["notebookBaseParameters"] = cfg.notebook_base_parameters + return custom_dict def to_k8s_pod(self, pod_template: Optional[PodTemplate] = None) -> Optional[K8sPod]: @@ -333,7 +342,7 @@ def to_k8s_pod(self, pod_template: Optional[PodTemplate] = None) -> Optional[K8s def _is_databricks_serverless(self) -> bool: """ Detect if we're running in Databricks serverless environment. 
- + Serverless uses Spark Connect and requires different SparkSession handling. """ # Check for explicit serverless markers set by our entrypoint @@ -341,71 +350,71 @@ def _is_databricks_serverless(self) -> bool: return True if os.environ.get("SPARK_CONNECT_MODE") == "true": return True - - # Check for Databricks serverless indicators - # 1. DATABRICKS_RUNTIME_VERSION exists (Databricks environment) - # 2. No SPARK_HOME (serverless doesn't have traditional Spark) + is_databricks = "DATABRICKS_RUNTIME_VERSION" in os.environ - - # Additional check: if using DatabricksV2 with serverless config + + is_serverless_cfg = False if isinstance(self.task_config, DatabricksV2): conf = self.task_config.databricks_conf or {} if is_serverless_config(conf): - return True - - return is_databricks and "SPARK_HOME" not in os.environ + is_serverless_cfg = True + + return is_databricks and (is_serverless_cfg or "SPARK_HOME" not in os.environ) - def _get_databricks_serverless_spark_session(self): + def _get_databricks_serverless_spark_session(self) -> Optional[SparkSession]: """ Get SparkSession in Databricks serverless environment. - + The entrypoint injects the SparkSession into: 1. Custom module '_flyte_spark_session' in sys.modules (most reliable) 2. builtins.spark (backup) - + Returns: - SparkSession or None if not available + Optional[SparkSession]: SparkSession or None if not available. 
""" import sys - + # Method 1: Try custom module (most reliable - survives module reloads) try: - if '_flyte_spark_session' in sys.modules: - spark_module = sys.modules['_flyte_spark_session'] - if hasattr(spark_module, 'spark') and spark_module.spark is not None: - logger.info(f"Got SparkSession from _flyte_spark_session module") + if "_flyte_spark_session" in sys.modules: + spark_module = sys.modules["_flyte_spark_session"] + if hasattr(spark_module, "spark") and spark_module.spark is not None: + logger.info("Got SparkSession from _flyte_spark_session module") return spark_module.spark except Exception as e: logger.debug(f"Could not get spark from _flyte_spark_session: {e}") - + # Method 2: Try builtins (backup location) try: import builtins - if hasattr(builtins, 'spark') and builtins.spark is not None: - logger.info(f"Got SparkSession from builtins") + + if hasattr(builtins, "spark") and builtins.spark is not None: + logger.info("Got SparkSession from builtins") return builtins.spark except Exception as e: logger.debug(f"Could not get spark from builtins: {e}") - + # Method 3: Try __main__ module try: import __main__ - if hasattr(__main__, 'spark') and __main__.spark is not None: - logger.info(f"Got SparkSession from __main__") + + if hasattr(__main__, "spark") and __main__.spark is not None: + logger.info("Got SparkSession from __main__") return __main__.spark except Exception as e: logger.debug(f"Could not get spark from __main__: {e}") - + # Method 4: Try active session try: from pyspark.sql import SparkSession + active = SparkSession.getActiveSession() if active: - logger.info(f"Got active SparkSession") + logger.info("Got active SparkSession") return active except Exception as e: logger.debug(f"Could not get active SparkSession: {e}") - + logger.warning("Could not obtain SparkSession in serverless environment") return None @@ -413,17 +422,17 @@ def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: import pyspark as _pyspark ctx = 
FlyteContextManager.current_context() - + # Databricks serverless uses Spark Connect - SparkSession is pre-configured if self._is_databricks_serverless(): logger.info("Detected Databricks serverless environment - using pre-configured SparkSession") self.sess = self._get_databricks_serverless_spark_session() - + if self.sess is None: logger.warning("No SparkSession available - task will run without Spark") - + return user_params.builder().add_attr("SPARK_SESSION", self.sess).build() - + # Standard Spark session creation for non-serverless environments sess_builder = _pyspark.sql.SparkSession.builder.appName(f"FlyteSpark: {user_params.execution_id}") if not (ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION): diff --git a/plugins/flytekit-spark/flytekitplugins/spark/utils.py b/plugins/flytekit-spark/flytekitplugins/spark/utils.py index 3abe59b377..c60b2871a3 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/utils.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/utils.py @@ -6,11 +6,13 @@ def is_serverless_config(databricks_conf: dict) -> bool: without any cluster config (``existing_cluster_id`` or ``new_cluster``). Args: - databricks_conf: The databricks job configuration dict. + databricks_conf (dict): The databricks job configuration dict. Returns: - True if the configuration targets serverless compute. + bool: True if the configuration targets serverless compute. 
""" - has_cluster_config = "existing_cluster_id" in databricks_conf or "new_cluster" in databricks_conf + has_cluster_config = ( + databricks_conf.get("existing_cluster_id") is not None or databricks_conf.get("new_cluster") is not None + ) has_serverless_config = bool(databricks_conf.get("environment_key") or databricks_conf.get("environments")) return not has_cluster_config and has_serverless_config diff --git a/plugins/flytekit-spark/tests/test_spark_task.py b/plugins/flytekit-spark/tests/test_spark_task.py index 76c229766c..d8d47a3c06 100644 --- a/plugins/flytekit-spark/tests/test_spark_task.py +++ b/plugins/flytekit-spark/tests/test_spark_task.py @@ -5,41 +5,39 @@ import pandas as pd import pyspark import pytest -import tempfile - -from google.protobuf.json_format import MessageToDict -from flytekit import PodTemplate -from flytekit.core import context_manager from flytekitplugins.spark import Spark from flytekitplugins.spark.task import Databricks, DatabricksV2, new_spark_session +from google.protobuf.json_format import MessageToDict +from kubernetes.client.models import ( + V1Container, + V1EnvVar, + V1PodSpec, + V1Toleration, +) from pyspark.sql import SparkSession import flytekit from flytekit import ( + ImageSpec, + PodTemplate, StructuredDataset, StructuredDatasetTransformerEngine, task, - ImageSpec, ) from flytekit.configuration import ( + DefaultImages, + FastSerializationSettings, Image, ImageConfig, SerializationSettings, - FastSerializationSettings, - DefaultImages, ) +from flytekit.core import context_manager from flytekit.core.context_manager import ( ExecutionParameters, - FlyteContextManager, ExecutionState, + FlyteContextManager, ) from flytekit.models.task import K8sObjectMetadata, K8sPod -from kubernetes.client.models import ( - V1Container, - V1PodSpec, - V1Toleration, - V1EnvVar, -) @pytest.fixture(scope="function") @@ -52,6 +50,7 @@ def reset_spark_session() -> None: SparkSession.builder.getOrCreate().stop() 
SparkSession._instantiatedSession = None + def test_spark_task(reset_spark_session): databricks_conf = { "name": "flytekit databricks plugin example", @@ -97,10 +96,7 @@ def my_spark(a: str) -> int: retrieved_settings = my_spark.get_custom(settings) assert retrieved_settings["sparkConf"] == {"spark": "1"} assert retrieved_settings["executorPath"] == "/usr/bin/python3" - assert ( - retrieved_settings["mainApplicationFile"] - == "local:///usr/local/bin/entrypoint.py" - ) + assert retrieved_settings["mainApplicationFile"] == "local:///usr/local/bin/entrypoint.py" pb = ExecutionParameters.new_builder() pb.working_dir = "/tmp" @@ -193,9 +189,7 @@ def test_to_html(): df = spark.createDataFrame([("Bob", 10)], ["name", "age"]) sd = StructuredDataset(dataframe=df) tf = StructuredDatasetTransformerEngine() - output = tf.to_html( - FlyteContextManager.current_context(), sd, pyspark.sql.DataFrame - ) + output = tf.to_html(FlyteContextManager.current_context(), sd, pyspark.sql.DataFrame) assert pd.DataFrame(df.schema, columns=["StructField"]).to_html() == output @@ -228,9 +222,7 @@ def my_spark(a: int) -> int: ctx = context_manager.FlyteContextManager.current_context() with context_manager.FlyteContextManager.with_context( ctx.with_execution_state( - ctx.new_execution_state().with_params( - mode=ExecutionState.Mode.TASK_EXECUTION - ) + ctx.new_execution_state().with_params(mode=ExecutionState.Mode.TASK_EXECUTION) ).with_serialization_settings(serialization_settings) ) as new_ctx: my_spark.pre_execute(new_ctx.user_space_params) @@ -239,6 +231,7 @@ def my_spark(a: int) -> int: mock_shutil_make_archive.assert_called_once_with("/tmp/123/flyte_wf", "zip", os.getcwd()) mock_add_pyfile.assert_called_once_with("/tmp/123/flyte_wf.zip") + def test_spark_with_image_spec(): custom_image = ImageSpec( registry="ghcr.io/flyteorg", @@ -428,7 +421,7 @@ def test_spark_driver_executor_podSpec(reset_spark_session): annotations={"aKeyA_d": "aValA", "aKeyB_d": "aValB"}, ), 
pod_spec=driver_pod_spec_dict_remove_None, # type: ignore - primary_container_name="driver-primary" + primary_container_name="driver-primary", ) target_executor_k8sPod = K8sPod( @@ -437,7 +430,7 @@ def test_spark_driver_executor_podSpec(reset_spark_session): annotations={"aKeyA_e": "aValA", "aKeyB_e": "aValB"}, ), pod_spec=executor_pod_spec_dict_remove_None, # type: ignore - primary_container_name="executor-primary" + primary_container_name="executor-primary", ) @task( @@ -471,16 +464,9 @@ def my_spark(a: str) -> int: retrieved_settings = my_spark.get_custom(settings) assert retrieved_settings["sparkConf"] == {"spark.driver.memory": "1000M"} assert retrieved_settings["executorPath"] == "/usr/bin/python3" - assert ( - retrieved_settings["mainApplicationFile"] - == "local:///usr/local/bin/entrypoint.py" - ) - assert retrieved_settings["driverPod"] == MessageToDict( - target_driver_k8sPod.to_flyte_idl() - ) - assert retrieved_settings["executorPod"] == MessageToDict( - target_executor_k8sPod.to_flyte_idl() - ) + assert retrieved_settings["mainApplicationFile"] == "local:///usr/local/bin/entrypoint.py" + assert retrieved_settings["driverPod"] == MessageToDict(target_driver_k8sPod.to_flyte_idl()) + assert retrieved_settings["executorPod"] == MessageToDict(target_executor_k8sPod.to_flyte_idl()) pb = ExecutionParameters.new_builder() pb.working_dir = "/tmp" @@ -527,8 +513,7 @@ def my_task(a: int) -> int: def test_databricks_v2_serverless_detection_with_config(reset_spark_session): - """Test that serverless is detected based on DatabricksV2 config.""" - # Serverless config: has environment_key, no cluster config + """Test that serverless is detected based on DatabricksV2 config when on Databricks.""" serverless_conf = { "run_name": "serverless-test", "environment_key": "my-env", @@ -543,10 +528,16 @@ def test_databricks_v2_serverless_detection_with_config(reset_spark_session): def serverless_task(a: int) -> int: return a - # Should detect serverless from config - assert 
serverless_task._is_databricks_serverless() is True + # Serverless config alone is not enough; must also be on Databricks + assert serverless_task._is_databricks_serverless() is False + + # With DATABRICKS_RUNTIME_VERSION set (simulating Databricks env), should detect serverless + os.environ["DATABRICKS_RUNTIME_VERSION"] = "15.4" + try: + assert serverless_task._is_databricks_serverless() is True + finally: + del os.environ["DATABRICKS_RUNTIME_VERSION"] - # Classic config: has new_cluster classic_conf = { "run_name": "classic-test", "new_cluster": {"spark_version": "13.3.x-scala2.12"}, @@ -561,18 +552,20 @@ def serverless_task(a: int) -> int: def classic_task(a: int) -> int: return a - # Should NOT detect serverless + # Should NOT detect serverless (classic config) assert classic_task._is_databricks_serverless() is False def test_databricks_v2_serverless_detection_with_environments_array(reset_spark_session): - """Test serverless detection with inline environments array.""" + """Test serverless detection with inline environments array on Databricks.""" serverless_conf = { "run_name": "serverless-inline", - "environments": [{ - "environment_key": "default", - "spec": {"client": "1", "dependencies": ["pandas"]} - }], + "environments": [ + { + "environment_key": "default", + "spec": {"client": "1", "dependencies": ["pandas"]}, + } + ], } @task( @@ -584,7 +577,12 @@ def test_databricks_v2_serverless_detection_with_environments_array(reset_spark_ def serverless_task(a: int) -> int: return a - assert serverless_task._is_databricks_serverless() is True + # Must be on Databricks for config-based detection to work + os.environ["DATABRICKS_RUNTIME_VERSION"] = "15.4" + try: + assert serverless_task._is_databricks_serverless() is True + finally: + del os.environ["DATABRICKS_RUNTIME_VERSION"] def test_databricks_v2_classic_not_detected_as_serverless(reset_spark_session): @@ -686,6 +684,7 @@ def task_no_creds(a: int) -> int: def 
test_spark_classic_not_affected_by_serverless_code(reset_spark_session): """Test that regular Spark tasks (non-Databricks) are not affected by serverless code.""" + @task( task_config=Spark( spark_conf={"spark.driver.memory": "512M"}, @@ -703,7 +702,7 @@ def spark_task(a: int) -> int: pb.execution_id = "ex:local:local:local" p = pb.build() new_p = spark_task.pre_execute(p) - + assert new_p is not None assert new_p.has_attr("SPARK_SESSION") assert spark_task.sess is not None diff --git a/pydoclint-errors-baseline.txt b/pydoclint-errors-baseline.txt index 606633e114..521eb5cafb 100644 --- a/pydoclint-errors-baseline.txt +++ b/pydoclint-errors-baseline.txt @@ -578,10 +578,6 @@ plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py plugins/flytekit-spark/flytekitplugins/spark/models.py DOC301: Class `SparkJob`: __init__() should not have a docstring; please combine it with the docstring of the class -------------------- -plugins/flytekit-spark/flytekitplugins/spark/task.py - DOC601: Class `DatabricksV2`: Class docstring contains fewer class attributes than actual class attributes. (Please read https://jsh9.github.io/pydoclint/checking_class_attributes.html on how to correctly document class attributes.) - DOC603: Class `DatabricksV2`: Class docstring attributes are different from actual class attributes. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Attributes in the class definition but not in the docstring: [databricks_conf: Optional[Dict[str, Union[str, dict]]], databricks_instance: Optional[str]]. (Please read https://jsh9.github.io/pydoclint/checking_class_attributes.html on how to correctly document class attributes.) --------------------- plugins/flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py DOC601: Class `SQLAlchemyConfig`: Class docstring contains fewer class attributes than actual class attributes. 
(Please read https://jsh9.github.io/pydoclint/checking_class_attributes.html on how to correctly document class attributes.) DOC603: Class `SQLAlchemyConfig`: Class docstring attributes are different from actual class attributes. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Attributes in the class definition but not in the docstring: [connect_args: typing.Optional[typing.Dict[str, typing.Any]], secret_connect_args: typing.Optional[typing.Dict[str, Secret]], uri: str]. (Please read https://jsh9.github.io/pydoclint/checking_class_attributes.html on how to correctly document class attributes.)