318 changes: 291 additions & 27 deletions native/core/src/parquet/schema_adapter.rs
@@ -17,9 +17,10 @@

use crate::parquet::cast_column::CometCastColumnExpr;
use crate::parquet::parquet_support::{spark_parquet_convert, SparkParquetOptions};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion::common::Result as DataFusionResult;
use datafusion::error::DataFusionError;
use datafusion::physical_expr::expressions::Column;
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_plan::ColumnarValue;
@@ -32,6 +33,18 @@ use datafusion_physical_expr_adapter::{
use std::collections::HashMap;
use std::sync::Arc;

/// Corresponds to Spark's `SchemaColumnConvertNotSupportedException`.
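/// Returned at planning time, wrapped in `DataFusionError::External`, when a Parquet
/// column cannot be read as the requested Spark type.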
#[derive(Debug)]
struct SchemaColumnConvertError(String);

impl std::fmt::Display for SchemaColumnConvertError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}

impl std::error::Error for SchemaColumnConvertError {}

/// Factory for creating Spark-compatible physical expression adapters.
///
/// This factory creates adapters that rewrite expressions at planning time
@@ -147,13 +160,20 @@ impl PhysicalExprAdapterFactory for SparkPhysicalExprAdapterFactory {
Arc::clone(&adapted_physical_schema),
);

let schema_validation_error = validate_spark_schema_compatibility(
&logical_file_schema,
&adapted_physical_schema,
self.parquet_options.case_sensitive,
);

Arc::new(SparkPhysicalExprAdapter {
logical_file_schema,
physical_file_schema: adapted_physical_schema,
parquet_options: self.parquet_options.clone(),
default_values: self.default_values.clone(),
default_adapter,
logical_to_physical_names,
schema_validation_error,
})
}
}
@@ -183,10 +203,17 @@ struct SparkPhysicalExprAdapter {
/// physical names so that downstream reassign_expr_columns can find
/// columns in the actual stream schema.
logical_to_physical_names: Option<HashMap<String, String>>,
schema_validation_error: Option<String>,
}

impl PhysicalExprAdapter for SparkPhysicalExprAdapter {
fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> DataFusionResult<Arc<dyn PhysicalExpr>> {
if let Some(error_msg) = &self.schema_validation_error {
return Err(DataFusionError::External(Box::new(
SchemaColumnConvertError(error_msg.clone()),
)));
}

// First let the default adapter handle column remapping, missing columns,
// and simple scalar type casts. Then replace DataFusion's CastColumnExpr
// with Spark-compatible equivalents.
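//
// A minimal sketch of that two-step shape (illustrative only, using the fields on
// this struct; the real body handles more cases):
//
//     let expr = self.default_adapter.rewrite(expr)?;
//     expr.transform(|e| {
//         // swap a CastColumnExpr for a Spark-compatible CometCastColumnExpr here
//     })
//     .data()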
@@ -436,12 +463,174 @@ impl SparkPhysicalExprAdapter {
}
}

/// Validates physical-vs-logical schema compatibility per Spark's `TypeUtil.checkParquetType()`.
/// Returns an error message for the first incompatible column, or `None` if all columns are compatible.
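///
/// For example (illustrative): reading a file column `ts` stored as a timezone-aware
/// timestamp (Parquet INT64) with a required Spark type of `timestamp_ntz` yields
/// `Column: [ts], Expected: timestamp_ntz, Found: INT64`.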
fn validate_spark_schema_compatibility(
logical_schema: &SchemaRef,
physical_schema: &SchemaRef,
case_sensitive: bool,
) -> Option<String> {
for logical_field in logical_schema.fields() {
let physical_field = if case_sensitive {
physical_schema
.fields()
.iter()
.find(|f| f.name() == logical_field.name())
} else {
physical_schema
.fields()
.iter()
.find(|f| f.name().to_lowercase() == logical_field.name().to_lowercase())
};

if let Some(physical_field) = physical_field {
let physical_type = physical_field.data_type();
let logical_type = logical_field.data_type();
if physical_type != logical_type
&& !is_spark_compatible_read(physical_type, logical_type)
{
return Some(format!(
"Column: [{}], Expected: {}, Found: {}",
logical_field.name(),
spark_type_name(logical_type),
arrow_to_parquet_type_name(physical_type),
));
}
}
}
None
}

/// Whether a Parquet column with Arrow type `physical_type` may be read as Spark type `logical_type`.
/// See Spark's `TypeUtil.checkParquetType()`.
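///
/// A few illustrative cases (these mirror the unit tests below):
/// - `(Int32, Int64)` and `(Float32, Float64)` are allowed widenings.
/// - `(Utf8, Int32)` is rejected.
/// - `(Decimal128(18, 2), Decimal128(10, 2))` is rejected because precision would shrink.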
fn is_spark_compatible_read(physical_type: &DataType, logical_type: &DataType) -> bool {
use DataType::*;

match (physical_type, logical_type) {
_ if physical_type == logical_type => true,

// RunEndEncoded is an Arrow encoding wrapper (e.g., from Iceberg).
// Unwrap to the inner values type and check compatibility.
(RunEndEncoded(_, values_field), _) => {
is_spark_compatible_read(values_field.data_type(), logical_type)
}

(_, Null) => true,

// Signed integer family (Parquet stores 8/16-bit ints as INT32; widening up to Int64 is allowed)
(Int8 | Int16 | Int32, Int8 | Int16 | Int32 | Int64) => true,
(Int64, Int64) => true,
(Int32 | Int8 | Int16, Date32) => true,
(Int8 | Int16 | Int32 | Int64, Decimal128(_, _)) => true,

// Unsigned int conversions
(UInt8, Int8 | Int16 | Int32 | Int64) => true,
(UInt16, Int16 | Int32 | Int64) => true,
(UInt32, Int32 | Int64) => true,
(UInt64, Decimal128(20, 0)) => true,

// Float widening
(Float32, Float64) => true,

// Timestamps: only LTZ → NTZ is rejected (SPARK-36182).
// NTZ → LTZ is allowed because DataFusion coerces INT96 to Timestamp(us, None)
// and the Spark schema may expect Timestamp(us, Some("UTC")).
(Timestamp(_, tz_physical), Timestamp(_, tz_logical)) => {
!(tz_physical.is_some() && tz_logical.is_none())
}

// Timestamp ↔ Int64: nanosAsLong and Iceberg timestamp partition columns
(Timestamp(_, _), Int64) | (Int64, Timestamp(_, _)) => true,

// BINARY / String interop
(Binary | LargeBinary | Utf8 | LargeUtf8, Binary | LargeBinary | Utf8 | LargeUtf8) => true,
(Binary | LargeBinary | FixedSizeBinary(_), Decimal128(_, _)) => true,
(FixedSizeBinary(_), Binary | LargeBinary | Utf8 | LargeUtf8) => true,

// Decimal: the requested (logical) precision must be >= the physical precision; scales must match exactly
(Decimal128(p1, s1), Decimal128(p2, s2)) => p1 <= p2 && s1 == s2,

// Nested types (DataFusion handles inner-type adaptation)
(Struct(_), Struct(_))
| (List(_), List(_))
| (LargeList(_), List(_) | LargeList(_))
| (Map(_, _), Map(_, _)) => true,

_ => false,
}
}

fn spark_type_name(dt: &DataType) -> String {
match dt {
DataType::Boolean => "boolean".to_string(),
DataType::Int8 => "tinyint".to_string(),
DataType::Int16 => "smallint".to_string(),
DataType::Int32 => "int".to_string(),
DataType::Int64 => "bigint".to_string(),
DataType::Float32 => "float".to_string(),
DataType::Float64 => "double".to_string(),
DataType::Utf8 | DataType::LargeUtf8 => "string".to_string(),
DataType::Binary | DataType::LargeBinary => "binary".to_string(),
DataType::Date32 => "date".to_string(),
DataType::Timestamp(TimeUnit::Microsecond, None) => "timestamp_ntz".to_string(),
DataType::Timestamp(TimeUnit::Microsecond, Some(_)) => "timestamp".to_string(),
DataType::Timestamp(unit, tz) => format!("timestamp({unit:?}, {tz:?})"),
DataType::Decimal128(p, s) => format!("decimal({p},{s})"),
DataType::List(f) => format!("array<{}>", spark_type_name(f.data_type())),
DataType::LargeList(f) => format!("array<{}>", spark_type_name(f.data_type())),
DataType::Map(f, _) => {
if let DataType::Struct(fields) = f.data_type() {
if fields.len() == 2 {
return format!(
"map<{},{}>",
spark_type_name(fields[0].data_type()),
spark_type_name(fields[1].data_type())
);
}
}
format!("map<{}>", spark_type_name(f.data_type()))
}
DataType::Struct(fields) => {
let field_strs: Vec<String> = fields
.iter()
.map(|f| format!("{}:{}", f.name(), spark_type_name(f.data_type())))
.collect();
format!("struct<{}>", field_strs.join(","))
}
other => format!("{other}"),
}
}

fn arrow_to_parquet_type_name(dt: &DataType) -> String {
match dt {
DataType::Boolean => "BOOLEAN".to_string(),
DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::UInt32
| DataType::Date32 => "INT32".to_string(),
DataType::Int64 | DataType::UInt64 => "INT64".to_string(),
DataType::Float32 => "FLOAT".to_string(),
DataType::Float64 => "DOUBLE".to_string(),
DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary | DataType::LargeBinary => {
"BINARY".to_string()
}
DataType::FixedSizeBinary(n) => format!("FIXED_LEN_BYTE_ARRAY({n})"),
DataType::Timestamp(_, _) => "INT64".to_string(),
DataType::Decimal128(p, s) => format!("DECIMAL({p},{s})"),
DataType::RunEndEncoded(_, values_field) => {
arrow_to_parquet_type_name(values_field.data_type())
}
other => format!("{other}"),
}
}

#[cfg(test)]
mod test {
use crate::parquet::parquet_support::SparkParquetOptions;
use crate::parquet::schema_adapter::SparkPhysicalExprAdapterFactory;
use arrow::array::Int32Array;
use arrow::array::UInt32Array;
use arrow::array::{Int32Array, StringArray};
use arrow::datatypes::SchemaRef;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
@@ -460,28 +649,6 @@ mod test {
use std::fs::File;
use std::sync::Arc;

#[tokio::test]
async fn parquet_roundtrip_int_as_string() -> Result<(), DataFusionError> {
let file_schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("name", DataType::Utf8, false),
]));

let ids = Arc::new(Int32Array::from(vec![1, 2, 3])) as Arc<dyn arrow::array::Array>;
let names = Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"]))
as Arc<dyn arrow::array::Array>;
let batch = RecordBatch::try_new(Arc::clone(&file_schema), vec![ids, names])?;

let required_schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Utf8, false),
Field::new("name", DataType::Utf8, false),
]));

let _ = roundtrip(&batch, required_schema).await?;

Ok(())
}

#[tokio::test]
async fn parquet_roundtrip_unsigned_int() -> Result<(), DataFusionError> {
let file_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::UInt32, false)]));
@@ -496,8 +663,106 @@
Ok(())
}

/// Create a Parquet file containing a single batch and then read the batch back using
/// the specified required_schema. This will cause the PhysicalExprAdapter code to be used.
// Int32→Int64 is a valid type widening that DataFusion handles correctly
#[tokio::test]
async fn parquet_int_as_long_should_succeed() -> Result<(), DataFusionError> {
let file_schema = Arc::new(Schema::new(vec![Field::new("_1", DataType::Int32, true)]));
let values = Arc::new(Int32Array::from(vec![1, 2, 3])) as Arc<dyn arrow::array::Array>;
let batch = RecordBatch::try_new(Arc::clone(&file_schema), vec![values])?;
let required_schema = Arc::new(Schema::new(vec![Field::new("_1", DataType::Int64, true)]));

let result = roundtrip(&batch, required_schema).await?;
assert_eq!(result.num_rows(), 3);
Ok(())
}

// SPARK-36182: reading TimestampLTZ as TimestampNTZ should fail
#[tokio::test]
async fn parquet_timestamp_ltz_as_ntz_should_fail() -> Result<(), DataFusionError> {
use arrow::datatypes::TimeUnit;
let file_schema = Arc::new(Schema::new(vec![Field::new(
"ts",
DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())),
true,
)]));
let values = Arc::new(
arrow::array::TimestampMicrosecondArray::from(vec![1_000_000, 2_000_000, 3_000_000])
.with_timezone("UTC"),
) as Arc<dyn arrow::array::Array>;
let batch = RecordBatch::try_new(Arc::clone(&file_schema), vec![values])?;
let required_schema = Arc::new(Schema::new(vec![Field::new(
"ts",
DataType::Timestamp(TimeUnit::Microsecond, None),
true,
)]));

let result = roundtrip(&batch, required_schema).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Column: [ts]"));
Ok(())
}

#[test]
fn test_is_spark_compatible_read() {
use super::is_spark_compatible_read;
use arrow::datatypes::TimeUnit;

// Compatible
assert!(is_spark_compatible_read(&DataType::Binary, &DataType::Utf8));
assert!(is_spark_compatible_read(
&DataType::UInt32,
&DataType::Int64
));
assert!(is_spark_compatible_read(
&DataType::Int32,
&DataType::Date32
));
assert!(is_spark_compatible_read(
&DataType::Decimal128(10, 2),
&DataType::Decimal128(18, 2)
));
// NTZ → LTZ allowed (INT96 coercion produces NTZ, Spark schema expects LTZ)
assert!(is_spark_compatible_read(
&DataType::Timestamp(TimeUnit::Microsecond, None),
&DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into()))
));
// Timestamp → Int64 allowed (nanosAsLong)
assert!(is_spark_compatible_read(
&DataType::Timestamp(TimeUnit::Nanosecond, None),
&DataType::Int64
));

// Compatible widenings
assert!(is_spark_compatible_read(&DataType::Int32, &DataType::Int64));
assert!(is_spark_compatible_read(
&DataType::Float32,
&DataType::Float64
));
assert!(is_spark_compatible_read(
&DataType::Int64,
&DataType::Timestamp(TimeUnit::Microsecond, None)
));

// Incompatible (#3720 cases)
assert!(!is_spark_compatible_read(
&DataType::Utf8,
&DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into()))
));
assert!(!is_spark_compatible_read(
&DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())),
&DataType::Timestamp(TimeUnit::Microsecond, None)
));
assert!(!is_spark_compatible_read(&DataType::Utf8, &DataType::Int32));
assert!(!is_spark_compatible_read(
&DataType::Decimal128(18, 2),
&DataType::Decimal128(10, 2)
));
assert!(!is_spark_compatible_read(
&DataType::Decimal128(10, 2),
&DataType::Decimal128(10, 3)
));
}

async fn roundtrip(
batch: &RecordBatch,
required_schema: SchemaRef,
@@ -514,7 +779,6 @@
let mut spark_parquet_options = SparkParquetOptions::new(EvalMode::Legacy, "UTC", false);
spark_parquet_options.allow_cast_unsigned_ints = true;

// Create expression adapter factory for Spark-compatible schema adaptation
let expr_adapter_factory: Arc<dyn PhysicalExprAdapterFactory> = Arc::new(
SparkPhysicalExprAdapterFactory::new(spark_parquet_options, None),
);