def _normalize_column_to_non_nullable(self, df: DataFrame, column: str) -> DataFrame:
    """Rewrite ``column`` with ``coalesce`` so Spark marks it non-nullable.

    If the column is already non-nullable in ``df.schema`` the dataframe is
    returned untouched. Otherwise NULLs are replaced with a type-specific
    sentinel (zero, empty string, epoch, empty array/map), which lets Spark
    infer ``nullable=false`` for the resulting column.

    TODO: remove this workaround if Spark infers ``nullable=false`` without a dummy literal:
    https://issues.apache.org/jira/browse/SPARK-54302

    Args:
        df: dataframe containing ``column``.
        column: exact (literal) name of the column to make non-nullable.

    Returns:
        ``df`` itself when the column is already non-nullable, otherwise a new
        dataframe where ``column`` is ``coalesce(column, sentinel)``.

    Raises:
        TypeError: the column has a Spark type with no sentinel defined here.
    """
    field = df.schema[column]
    if not field.nullable:
        return df

    field_type: DataType = field.dataType
    sentinel: Column
    if isinstance(field_type, (ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, BooleanType)):
        # For BooleanType the cast turns 0 into ``false``.
        sentinel = F.lit(0).cast(field_type)
    elif isinstance(field_type, DecimalType):
        sentinel = F.lit(Decimal(0)).cast(field_type)
    elif isinstance(field_type, StringType):
        sentinel = F.lit("")
    elif isinstance(field_type, TimestampType):
        # TimestampType is an instant, so a UTC-aware epoch is unambiguous.
        sentinel = F.lit(datetime(1970, 1, 1, 0, 0, 0, tzinfo=UTC)).cast(field_type)
    elif isinstance(field_type, DateType):
        # Built from a string instead of a UTC instant: casting an instant to DATE
        # goes through the session time zone and may yield 1969-12-31.
        sentinel = F.lit("1970-01-01").cast(field_type)
    elif isinstance(field_type, TimestampNTZType):
        # Same reasoning as DateType: keep the wall-clock value independent
        # of the session time zone.
        sentinel = F.lit("1970-01-01 00:00:00").cast(field_type)
    elif isinstance(field_type, ArrayType):
        sentinel = F.array().cast(field_type)
    elif isinstance(field_type, MapType):
        sentinel = F.map_from_arrays(F.array(), F.array()).cast(field_type)
    else:
        msg = f"Unsupported Spark type for non-null: {field_type!r} (column {column!r})"
        raise TypeError(msg)

    # ``F.col`` parses its argument as an attribute path, so a name containing
    # dots would be treated as struct-field access. Backtick-quote it (doubling
    # embedded backticks) to reference the column by its literal name, matching
    # how ``df.schema[column]`` and ``withColumn`` treat it.
    quoted_column = "`" + column.replace("`", "``") + "`"
    return df.withColumn(column, F.coalesce(F.col(quoted_column), sentinel))
a/tests/test_integration/test_run_transfer/connection_fixtures/spark_fixtures.py b/tests/test_integration/test_run_transfer/connection_fixtures/spark_fixtures.py index 6dc4c01e..84219df6 100644 --- a/tests/test_integration/test_run_transfer/connection_fixtures/spark_fixtures.py +++ b/tests/test_integration/test_run_transfer/connection_fixtures/spark_fixtures.py @@ -51,7 +51,7 @@ def spark( # noqa: C901 maven_packages.extend(Oracle.get_packages()) if "clickhouse" in markers: - maven_packages.append("io.github.mtsongithub.doetl:spark-dialect-extension_2.12:0.0.2") + maven_packages.append("io.github.mtsongithub.doetl:spark-dialect-extension_2.12:0.0.4") maven_packages.extend(Clickhouse.get_packages()) if "mssql" in markers: