Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/changelog/next_release/371.dependency.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Updated ClickHouse ``spark-dialect-extension`` to 0.0.4.
53 changes: 52 additions & 1 deletion syncmaster/worker/handlers/db/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,36 @@

from __future__ import annotations

from datetime import UTC, datetime
from decimal import Decimal
from typing import TYPE_CHECKING, ClassVar

from onetl.connection import Clickhouse
from onetl.db import DBWriter
from onetl.hooks import slot, support_hooks
from pyspark.sql import functions as F # noqa: N812
from pyspark.sql.types import (
ArrayType,
BooleanType,
ByteType,
DataType,
DateType,
DecimalType,
DoubleType,
FloatType,
IntegerType,
LongType,
MapType,
ShortType,
StringType,
TimestampNTZType,
TimestampType,
)

from syncmaster.worker.handlers.db.base import DBHandler

if TYPE_CHECKING:
from pyspark.sql import SparkSession
from pyspark.sql import Column, SparkSession
from pyspark.sql.dataframe import DataFrame

from syncmaster.dto.connections import ClickhouseConnectionDTO
Expand Down Expand Up @@ -55,6 +75,7 @@ def write(self, df: DataFrame) -> None:
(col for col in normalized_df.columns if col.lower().endswith("id")),
normalized_df.columns[0], # if there is no column with "id", take the first column
)
normalized_df = self._normalize_column_to_non_nullable(normalized_df, sort_column)
self.transfer_dto.options["createTableOptions"] = (
f"ENGINE = MergeTree() ORDER BY {self._quote_field(sort_column)}"
)
Expand All @@ -74,6 +95,36 @@ def _normalize_column_names(self, df: DataFrame) -> DataFrame:
df = df.withColumnRenamed(column_name, column_name.lower())
return df

def _normalize_column_to_non_nullable(self, df: DataFrame, column: str) -> DataFrame:
"""Rewrite ``column`` with ``coalesce`` so Spark marks it non-nullable.

TODO: remove this workaround if Spark infers ``nullable=false`` without a dummy literal:
https://issues.apache.org/jira/browse/SPARK-54302
"""
field = df.schema[column]
if not field.nullable:
return df

field_type: DataType = field.dataType
sentinel: Column
if isinstance(field_type, (ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, BooleanType)):
sentinel = F.lit(0).cast(field_type)
elif isinstance(field_type, DecimalType):
sentinel = F.lit(Decimal(0)).cast(field_type)
elif isinstance(field_type, StringType):
sentinel = F.lit("")
elif isinstance(field_type, (DateType, TimestampType, TimestampNTZType)):
sentinel = F.lit(datetime(1970, 1, 1, 0, 0, 0, tzinfo=UTC)).cast(field_type)
elif isinstance(field_type, ArrayType):
sentinel = F.array().cast(field_type)
elif isinstance(field_type, MapType):
sentinel = F.map_from_arrays(F.array(), F.array()).cast(field_type)
else:
msg = f"Unsupported Spark type for non-null: {field_type!r} (column {column!r})"
raise TypeError(msg)

return df.withColumn(column, F.coalesce(F.col(column), sentinel))

def _make_rows_filter_expression(self, filters: list[dict]) -> str | None:
expressions = []
for filter_ in filters:
Expand Down
2 changes: 1 addition & 1 deletion syncmaster/worker/ivy2.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def get_packages(connection_types: set[str]) -> list[str]:
if connection_types & {"oracle", "all"}:
result.extend(Oracle.get_packages())
if connection_types & {"clickhouse", "all"}:
result.append("io.github.mtsongithub.doetl:spark-dialect-extension_2.12:0.0.2")
result.append("io.github.mtsongithub.doetl:spark-dialect-extension_2.12:0.0.4")
result.extend(Clickhouse.get_packages())
if connection_types & {"mssql", "all"}:
result.extend(MSSQL.get_packages())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def spark( # noqa: C901
maven_packages.extend(Oracle.get_packages())

if "clickhouse" in markers:
maven_packages.append("io.github.mtsongithub.doetl:spark-dialect-extension_2.12:0.0.2")
maven_packages.append("io.github.mtsongithub.doetl:spark-dialect-extension_2.12:0.0.4")
maven_packages.extend(Clickhouse.get_packages())

if "mssql" in markers:
Expand Down