Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions Dockerfile.connector
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,15 @@ ARG VERSION

RUN apt-get update && apt-get install build-essential -y \
&& pip install uv

RUN uv pip install --system --no-cache-dir -U flytekit[connector]==$VERSION \
# Pin pendulum<3.0: Apache Airflow (via flytekitplugins-airflow) imports
# pendulum.tz.timezone() at module load time (airflow/settings.py).
# Pendulum 3.x changed the tz API, causing the connector to crash on startup:
# airflow/settings.py → TIMEZONE = pendulum.tz.timezone("UTC") → AttributeError
# Without this pin, uv resolves to pendulum 3.x which breaks the import chain:
# pyflyte serve connector → load_implicit_plugins → airflow → pendulum → crash
RUN uv pip install --system --no-cache-dir -U \
"pendulum>=2.0.0,<3.0" \
flytekit[connector]==$VERSION \
flytekitplugins-airflow==$VERSION \
flytekitplugins-bigquery==$VERSION \
flytekitplugins-k8sdataservice==$VERSION \
Expand Down
234 changes: 209 additions & 25 deletions plugins/flytekit-spark/flytekitplugins/spark/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate

from .utils import is_serverless_config as _is_serverless_config

aiohttp = lazy_module("aiohttp")

# Path prefix for the Databricks Jobs API, version 2.1.
DATABRICKS_API_ENDPOINT = "/api/2.1/jobs"
# Name of an environment variable; presumably it holds the Databricks workspace
# instance (host) -- its use is outside this view, TODO confirm.
DEFAULT_DATABRICKS_INSTANCE_ENV_KEY = "FLYTE_DATABRICKS_INSTANCE"
# Name of the environment variable read as the fallback source for the
# service credential provider when the task's custom config does not supply
# ``databricksServiceCredentialProvider`` (serverless python-file jobs).
DEFAULT_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER_ENV_KEY = "FLYTE_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER"


@dataclass
Expand All @@ -27,38 +30,219 @@ class DatabricksJobMetadata(ResourceMeta):
run_id: str


def _configure_serverless(databricks_job: dict, envs: dict) -> str:
"""
Configure serverless compute settings and return the environment_key to use.

Databricks serverless requires the ``environments`` array in the job submission.
This function ensures the array exists and injects Flyte environment variables
into the matching environment's ``spec.environment_vars``.

Reference: https://docs.databricks.com/api/workspace/jobs/submit

Expected ``environments`` format::

"environments": [
{
"environment_key": "<key>",
"spec": {
"client": "1",
"dependencies": ["pandas==2.0.0"],
"environment_vars": {"KEY": "VALUE"}
}
}
]

Tasks reference an environment via their own ``environment_key`` field,
analogous to how ``job_cluster_key`` links a task to a shared cluster.

Args:
databricks_job (dict): The databricks job configuration dict.
envs (dict): Environment variables to inject into the environment spec.

Returns:
str: The environment_key to use in the task definition.
"""
environment_key = databricks_job.get("environment_key", "default")
environments = databricks_job.get("environments", [])

# Check if environment already exists in the array
env_exists = any(env.get("environment_key") == environment_key for env in environments)

if not env_exists:
# Create the environment entry - Databricks serverless requires environments
# to be defined in the job submission (not externally pre-configured)
new_env = {
"environment_key": environment_key,
"spec": {
"client": "1", # Required: Databricks serverless client version
},
}
environments.append(new_env)
databricks_job["environments"] = environments

# Inject Flyte environment variables into the environment spec
for env in environments:
if env.get("environment_key") == environment_key:
spec = env.setdefault("spec", {})
existing_env_vars = spec.get("environment_vars", {})
# Merge Flyte env vars with any existing ones (Flyte vars take precedence)
merged_env_vars = {**existing_env_vars, **{k: v for k, v in envs.items()}}
spec["environment_vars"] = merged_env_vars
break

# Remove environment_key from top level (it's now in the task definition)
databricks_job.pop("environment_key", None)

return environment_key


def _configure_classic_cluster(databricks_job: dict, custom: dict, container: typing.Any, envs: dict) -> None:
"""
Configure classic compute (existing cluster or new cluster).

Args:
databricks_job (dict): The databricks job configuration dict.
custom (dict): The custom config from task template.
container (typing.Any): The container config from task template.
envs (dict): Environment variables to inject.
"""
if databricks_job.get("existing_cluster_id") is not None:
# Using an existing cluster, no additional configuration needed
return

new_cluster = databricks_job.get("new_cluster")
if new_cluster is None:
return

if not new_cluster.get("docker_image"):
new_cluster["docker_image"] = {"url": container.image}
if not new_cluster.get("spark_conf"):
new_cluster["spark_conf"] = custom.get("sparkConf", {})
if not new_cluster.get("spark_env_vars"):
new_cluster["spark_env_vars"] = {k: v for k, v in envs.items()}
else:
new_cluster["spark_env_vars"].update({k: v for k, v in envs.items()})


def _build_notebook_job_spec(
    databricks_job: dict, custom: dict, container: typing.Any, envs: dict, is_serverless: bool
) -> dict:
    """
    Build the Databricks job spec for a notebook task.

    The job dict is completed in place and also returned. Serverless jobs get a
    single-entry ``tasks`` list whose task points at the configured environment;
    classic jobs get a top-level ``notebook_task`` next to their cluster settings.

    Args:
        databricks_job (dict): The databricks job configuration dict.
        custom (dict): The custom config from task template; must contain ``notebookPath``.
        container (typing.Any): The container config from task template.
        envs (dict): Environment variables to inject.
        is_serverless (bool): Whether the job runs on serverless compute.

    Returns:
        dict: The same ``databricks_job`` dict, ready for submission.
    """
    git_source = databricks_job.get("git_source")

    notebook_task = {"notebook_path": custom["notebookPath"]}
    base_parameters = custom.get("notebookBaseParameters", {})
    if base_parameters:
        notebook_task["base_parameters"] = base_parameters
    if git_source:
        # The notebook is read from the user's repository instead of the workspace.
        notebook_task["source"] = "GIT"

    if not is_serverless:
        _configure_classic_cluster(databricks_job, custom, container, envs)
        databricks_job["notebook_task"] = notebook_task
    else:
        databricks_job["tasks"] = [
            {
                "task_key": "flyte_notebook_task",
                "notebook_task": notebook_task,
                "environment_key": _configure_serverless(databricks_job, envs),
            }
        ]

    # Keep git_source only when the user supplied a non-empty one.
    databricks_job.pop("git_source", None)
    if git_source:
        databricks_job["git_source"] = git_source

    return databricks_job


def _build_python_file_job_spec(
    databricks_job: dict, custom: dict, container: typing.Any, envs: dict, is_serverless: bool
) -> dict:
    """
    Build the Databricks job spec for a python file (spark_python_task).

    The job dict is completed in place and also returned. A user-supplied
    ``git_source`` / ``python_file`` in the job dict overrides the default Flyte
    entrypoint; ``python_file`` is then removed from the top level because it is
    carried inside the ``spark_python_task``.

    Args:
        databricks_job (dict): The databricks job configuration dict.
        custom (dict): The custom config from task template.
        container (typing.Any): The container config from task template.
        envs (dict): Environment variables to inject.
        is_serverless (bool): Whether the job runs on serverless compute.

    Returns:
        dict: The same ``databricks_job`` dict, ready for submission.
    """
    default_git_source = {
        "git_url": "https://github.com/flyteorg/flytetools",
        "git_provider": "gitHub",
        # https://github.com/flyteorg/flytetools/commit/572298df1f971fb58c258398bd70a6372f811c96
        "git_commit": "572298df1f971fb58c258398bd70a6372f811c96",
    }
    default_classic_python_file = "flytekitplugins/databricks/entrypoint.py"
    default_serverless_python_file = "flytekitplugins/databricks/entrypoint_serverless.py"

    # Source resolution is identical for both compute modes.
    git_source = databricks_job.get("git_source") or default_git_source
    user_python_file = databricks_job.get("python_file")

    if is_serverless:
        environment_key = _configure_serverless(databricks_job, envs)

        # Copy the args so appending the credential flag never mutates the container.
        parameters = list(container.args) if container.args else []

        # Fall back to the environment variable when the setting is absent OR
        # explicitly None; an empty string still disables the flag.
        service_credential_provider = custom.get("databricksServiceCredentialProvider")
        if service_credential_provider is None:
            service_credential_provider = os.getenv(DEFAULT_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER_ENV_KEY)
        if service_credential_provider:
            parameters.append(f"--flyte-credential-provider={service_credential_provider}")

        databricks_job["tasks"] = [
            {
                "task_key": "flyte_task",
                "spark_python_task": {
                    "python_file": user_python_file or default_serverless_python_file,
                    "source": "GIT",
                    "parameters": parameters,
                },
                "environment_key": environment_key,
            }
        ]
    else:
        _configure_classic_cluster(databricks_job, custom, container, envs)
        # https://docs.databricks.com/api/workspace/jobs/submit
        databricks_job["spark_python_task"] = {
            "python_file": user_python_file or default_classic_python_file,
            "source": "GIT",
            "parameters": container.args,
        }

    databricks_job.pop("python_file", None)
    databricks_job["git_source"] = git_source

    return databricks_job


def _get_databricks_job_spec(task_template: TaskTemplate) -> dict:
    """
    Build the Databricks job submission spec for a Flyte task template.

    Reads ``databricksConf`` from the task's custom config, checks that it
    declares some form of compute, and delegates to the notebook or python-file
    builder. The container's env mapping is updated in place with
    ``FLYTE_FAIL_ON_ERROR`` set to ``"true"`` before it is passed on.

    Args:
        task_template (TaskTemplate): The task template to translate.

    Returns:
        dict: The Databricks job configuration dict.

    Raises:
        ValueError: If the config declares neither classic compute
            (``existing_cluster_id`` / ``new_cluster``) nor serverless compute
            (``environment_key`` / ``environments``).
    """
    custom = task_template.custom
    container = task_template.container
    envs = container.env
    envs[FLYTE_FAIL_ON_ERROR] = "true"
    databricks_job = custom["databricksConf"]

    # Cluster and environment details are filled in by the builders below; the
    # only check made here is that some kind of compute was requested.
    has_cluster = databricks_job.get("existing_cluster_id") is not None or databricks_job.get("new_cluster") is not None
    has_serverless = bool(databricks_job.get("environment_key") or databricks_job.get("environments"))
    if not has_cluster and not has_serverless:
        raise ValueError(
            "No compute configuration found in databricks_conf. "
            "Provide one of: 'existing_cluster_id' (classic), 'new_cluster' (classic), "
            "'environment_key' (serverless), or 'environments' (serverless)."
        )

    is_serverless = _is_serverless_config(databricks_job)

    # A notebook path selects the notebook builder; otherwise a python file is run.
    if custom.get("notebookPath"):
        return _build_notebook_job_spec(databricks_job, custom, container, envs, is_serverless)

    return _build_python_file_job_spec(databricks_job, custom, container, envs, is_serverless)


class DatabricksConnector(AsyncConnectorBase):
Expand Down
Loading
Loading