diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 5af31fcc22..2175670361 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -123,8 +123,8 @@ use datafusion_comet_proto::{
 use datafusion_comet_spark_expr::{
     ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct,
     DecimalRescaleCheckOverflow, GetArrayStructFields, GetStructField, IfExpr, ListExtract,
-    NormalizeNaNAndZero, SparkCastOptions, Stddev, SumDecimal, ToJson, UnboundColumn, Variance,
-    WideDecimalBinaryExpr, WideDecimalOp,
+    NormalizeNaNAndZero, Percentile, SparkCastOptions, Stddev, SumDecimal, ToJson, UnboundColumn,
+    Variance, WideDecimalBinaryExpr, WideDecimalOp,
 };
 use itertools::Itertools;
 use jni::objects::GlobalRef;
@@ -2267,6 +2267,58 @@ impl PhysicalPlanner {
                 ));
                 Self::create_aggr_func_expr("bloom_filter_agg", schema, vec![child], func)
             }
+            AggExprStruct::PercentileCont(expr) => {
+                let return_type = to_arrow_datatype(expr.datatype.as_ref().unwrap());
+                let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
+
+                // Cast input to appropriate type based on return type
+                // For interval types, we preserve the type; for numeric types, cast to Float64
+                let child = match &return_type {
+                    DataType::Interval(_) => child,
+                    _ => Arc::new(CastExpr::new(child, DataType::Float64, None)) as Arc<dyn PhysicalExpr>,
+                };
+
+                // Extract the literal percentile value
+                let percentile_expr =
+                    self.create_expr(expr.percentile.as_ref().unwrap(), Arc::clone(&schema))?;
+                let percentile_value = percentile_expr
+                    .as_any()
+                    .downcast_ref::<Literal>()
+                    .ok_or_else(|| {
+                        ExecutionError::GeneralError("percentile must be a literal".into())
+                    })?
+ .value() + .clone(); + + let percentile = match percentile_value { + ScalarValue::Float64(Some(p)) => p, + ScalarValue::Float32(Some(p)) => p as f64, + ScalarValue::Int64(Some(p)) => p as f64, + ScalarValue::Int32(Some(p)) => p as f64, + _ => { + return Err(ExecutionError::GeneralError(format!( + "percentile must be a numeric literal, got {:?}", + percentile_value + ))) + } + }; + + // Custom Spark-compatible Percentile implementation + let func = AggregateUDF::new_from_impl(Percentile::new( + "spark_percentile", + percentile, + expr.reverse, + return_type, + )); + + AggregateExprBuilder::new(Arc::new(func), vec![child]) + .schema(schema) + .alias("spark_percentile") + .with_ignore_nulls(false) + .with_distinct(false) + .build() + .map_err(|e| ExecutionError::DataFusionError(e.to_string())) + } } } diff --git a/native/core/src/execution/serde.rs b/native/core/src/execution/serde.rs index ae0554ee76..e11afb33e4 100644 --- a/native/core/src/execution/serde.rs +++ b/native/core/src/execution/serde.rs @@ -168,5 +168,11 @@ pub fn to_arrow_datatype(dt_value: &DataType) -> ArrowDataType { } _ => unreachable!(), }, + DataTypeId::YearMonthInterval => { + ArrowDataType::Interval(arrow::datatypes::IntervalUnit::YearMonth) + } + DataTypeId::DayTimeInterval => { + ArrowDataType::Interval(arrow::datatypes::IntervalUnit::DayTime) + } } } diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 32cbc0ce13..1301b17957 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -139,6 +139,7 @@ message AggExpr { Stddev stddev = 14; Correlation correlation = 15; BloomFilterAgg bloomFilterAgg = 16; + PercentileCont percentileCont = 17; } // Optional QueryContext for error reporting (contains SQL text and position) @@ -243,6 +244,13 @@ message BloomFilterAgg { DataType datatype = 4; } +message PercentileCont { + Expr child = 1; // The column to compute percentile on + Expr percentile = 2; // The percentile value (0.0-1.0) 
+ DataType datatype = 3; // Return type + bool reverse = 4; // True if ORDER BY DESC +} + enum EvalMode { LEGACY = 0; TRY = 1; diff --git a/native/proto/src/proto/types.proto b/native/proto/src/proto/types.proto index 2fd3d59a73..361607f14d 100644 --- a/native/proto/src/proto/types.proto +++ b/native/proto/src/proto/types.proto @@ -59,6 +59,8 @@ message DataType { LIST = 14; MAP = 15; STRUCT = 16; + YEAR_MONTH_INTERVAL = 17; + DAY_TIME_INTERVAL = 18; } DataTypeId type_id = 1; diff --git a/native/spark-expr/src/agg_funcs/mod.rs b/native/spark-expr/src/agg_funcs/mod.rs index b1027153e8..e71c49c17b 100644 --- a/native/spark-expr/src/agg_funcs/mod.rs +++ b/native/spark-expr/src/agg_funcs/mod.rs @@ -19,6 +19,7 @@ mod avg; mod avg_decimal; mod correlation; mod covariance; +mod percentile; mod stddev; mod sum_decimal; mod sum_int; @@ -28,6 +29,7 @@ pub use avg::Avg; pub use avg_decimal::AvgDecimal; pub use correlation::Correlation; pub use covariance::Covariance; +pub use percentile::Percentile; pub use stddev::Stddev; pub use sum_decimal::SumDecimal; pub use sum_int::SumInteger; diff --git a/native/spark-expr/src/agg_funcs/percentile.rs b/native/spark-expr/src/agg_funcs/percentile.rs new file mode 100644 index 0000000000..3cadbe3392 --- /dev/null +++ b/native/spark-expr/src/agg_funcs/percentile.rs @@ -0,0 +1,481 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Spark-compatible Percentile aggregate function. +//! +//! This implementation matches Spark's `Percentile` class intermediate state format, +//! which uses a serialized map of (value -> frequency) stored as BinaryType. + +use arrow::array::{Array, ArrayRef, BinaryArray, Float64Array, IntervalDayTimeArray, IntervalYearMonthArray}; +use arrow::datatypes::{DataType, Field, FieldRef, IntervalDayTimeType, IntervalUnit}; +use datafusion::common::{Result, ScalarValue}; +use datafusion::logical_expr::{Accumulator, AggregateUDFImpl, Signature}; +use datafusion::physical_expr::expressions::format_state_name; +use std::any::Any; +use std::collections::BTreeMap; +use std::sync::Arc; + +use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion::logical_expr::Volatility::Immutable; + +/// Spark-compatible Percentile aggregate function. +/// +/// Stores intermediate state as BinaryType containing serialized (value, count) pairs, +/// matching Spark's `TypedAggregateWithHashMapAsBuffer` format. 
+#[derive(Debug, Clone, PartialEq)]
+pub struct Percentile {
+    name: String,
+    signature: Signature,
+    /// Percentile value stored as bits for Hash/Eq
+    percentile_bits: u64,
+    reverse: bool,
+    /// The return data type
+    return_type: DataType,
+}
+
+impl std::hash::Hash for Percentile {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        self.name.hash(state);
+        self.percentile_bits.hash(state);
+        self.reverse.hash(state);
+    }
+}
+
+impl Eq for Percentile {}
+
+impl Percentile {
+    pub fn new(name: impl Into<String>, percentile: f64, reverse: bool, return_type: DataType) -> Self {
+        Self {
+            name: name.into(),
+            signature: Signature::any(1, Immutable),
+            percentile_bits: percentile.to_bits(),
+            reverse,
+            return_type,
+        }
+    }
+
+    fn percentile(&self) -> f64 {
+        f64::from_bits(self.percentile_bits)
+    }
+}
+
+impl AggregateUDFImpl for Percentile {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(self.return_type.clone())
+    }
+
+    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
+        // Match Spark's BinaryType state format
+        Ok(vec![Arc::new(Field::new(
+            format_state_name(args.name, "counts"),
+            DataType::Binary,
+            true,
+        ))])
+    }
+
+    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+        Ok(Box::new(PercentileAccumulator::new(
+            self.percentile(),
+            self.reverse,
+            self.return_type.clone(),
+        )))
+    }
+
+    fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
+        match &self.return_type {
+            DataType::Float64 => Ok(ScalarValue::Float64(None)),
+            DataType::Interval(IntervalUnit::YearMonth) => Ok(ScalarValue::IntervalYearMonth(None)),
+            DataType::Interval(IntervalUnit::DayTime) => Ok(ScalarValue::IntervalDayTime(None)),
+            _ => Ok(ScalarValue::Float64(None)),
+        }
+    }
+}
+
+/// Accumulator for Percentile that stores (value -> count) map.
+/// Values are stored as i64 regardless of input type to simplify the implementation.
+#[derive(Debug)]
+pub struct PercentileAccumulator {
+    /// Map of value (as i64 bits) -> frequency count (using BTreeMap for sorted iteration).
+    /// NOTE(review): reinterpreting f64 bit patterns as i64 keys does not sort negative floats in numeric order (and -0.0/NaN get distinct keys) — confirm results for negative inputs against Spark.
+    counts: BTreeMap<i64, i64>,
+    /// The percentile to compute (0.0 to 1.0)
+    percentile: f64,
+    /// Whether to reverse the order (for DESC)
+    reverse: bool,
+    /// The return data type
+    return_type: DataType,
+}
+
+impl PercentileAccumulator {
+    pub fn new(percentile: f64, reverse: bool, return_type: DataType) -> Self {
+        Self {
+            counts: BTreeMap::new(),
+            percentile,
+            reverse,
+            return_type,
+        }
+    }
+
+    /// Serialize the counts map to Spark's binary format.
+    fn serialize(&self) -> Vec<u8> {
+        let mut buf = Vec::new();
+
+        for (&key, &count) in &self.counts {
+            // Each entry: [size: i32][key: i64][count: i64]
+            // Size = 8 (i64) + 8 (i64) = 16 bytes
+            let size: i32 = 16;
+            buf.extend_from_slice(&size.to_be_bytes());
+            buf.extend_from_slice(&key.to_be_bytes());
+            buf.extend_from_slice(&count.to_be_bytes());
+        }
+
+        // End marker
+        buf.extend_from_slice(&(-1i32).to_be_bytes());
+        buf
+    }
+
+    /// Deserialize counts map from Spark's binary format.
+    fn deserialize(bytes: &[u8]) -> Result<BTreeMap<i64, i64>> {
+        let mut counts = BTreeMap::new();
+        let mut offset = 0;
+
+        while offset + 4 <= bytes.len() {
+            let size = i32::from_be_bytes(bytes[offset..offset + 4].try_into().unwrap());
+            offset += 4;
+
+            if size < 0 {
+                // End marker
+                break;
+            }
+
+            if offset + 16 > bytes.len() {
+                break;
+            }
+
+            let key = i64::from_be_bytes(bytes[offset..offset + 8].try_into().unwrap());
+            offset += 8;
+            let count = i64::from_be_bytes(bytes[offset..offset + 8].try_into().unwrap());
+            offset += 8;
+
+            counts.insert(key, count);
+        }
+
+        Ok(counts)
+    }
+
+    /// Compute the percentile from the accumulated counts.
+    /// Returns the result as i64 bits (can be interpreted as f64 bits or interval value).
+    fn compute_percentile_i64(&self) -> Option<i64> {
+        if self.counts.is_empty() {
+            return None;
+        }
+
+        // Get sorted (value, accumulated_count) pairs
+        let sorted_counts: Vec<(i64, i64)> = if self.reverse {
+            self.counts.iter().rev().map(|(&k, &v)| (k, v)).collect()
+        } else {
+            self.counts.iter().map(|(&k, &v)| (k, v)).collect()
+        };
+
+        // Compute accumulated counts
+        let mut accumulated: Vec<(i64, i64)> = Vec::with_capacity(sorted_counts.len());
+        let mut total: i64 = 0;
+        for (value, count) in sorted_counts {
+            total += count;
+            accumulated.push((value, total));
+        }
+
+        let total_count = total;
+        if total_count == 0 {
+            return None;
+        }
+
+        // Position in the distribution (0-indexed)
+        let position = (total_count - 1) as f64 * self.percentile;
+        let lower = position.floor() as i64;
+        let higher = position.ceil() as i64;
+
+        // Binary search for lower and higher indices
+        let lower_idx = Self::binary_search_count(&accumulated, lower + 1);
+        let higher_idx = Self::binary_search_count(&accumulated, higher + 1);
+
+        let lower_key = accumulated[lower_idx].0;
+
+        if higher == lower {
+            // No interpolation needed
+            return Some(lower_key);
+        }
+
+        let higher_key = accumulated[higher_idx].0;
+
+        if lower_key == higher_key {
+            // Same key, no interpolation needed
+            return Some(lower_key);
+        }
+
+        // Linear interpolation
+        let fraction = position - lower as f64;
+
+        // Handle interpolation based on return type
+        match &self.return_type {
+            DataType::Float64 => {
+                // Interpret i64 bits as f64 (NOTE(review): bit-pattern order != numeric order for negatives — verify)
+                let lower_f = f64::from_bits(lower_key as u64);
+                let higher_f = f64::from_bits(higher_key as u64);
+                let result = (1.0 - fraction) * lower_f + fraction * higher_f;
+                Some(result.to_bits() as i64)
+            }
+            DataType::Interval(IntervalUnit::YearMonth) => {
+                // Values are i32 months stored as i64
+                let lower_months = lower_key as i32;
+                let higher_months = higher_key as i32;
+                let result = (1.0 - fraction) * (lower_months as f64) + fraction * (higher_months as f64);
+                Some(result.round() as i64)
+            }
+            DataType::Interval(IntervalUnit::DayTime) => {
+                // Values are packed as (days << 32) | milliseconds
+                let lower_days = (lower_key >> 32) as i32;
+                let lower_ms = lower_key as i32;
+                let higher_days = (higher_key >> 32) as i32;
+                let higher_ms = higher_key as i32;
+
+                // Convert to total milliseconds for interpolation
+                let lower_total_ms = (lower_days as i64) * 86_400_000 + (lower_ms as i64);
+                let higher_total_ms = (higher_days as i64) * 86_400_000 + (higher_ms as i64);
+                let result_ms = ((1.0 - fraction) * (lower_total_ms as f64) + fraction * (higher_total_ms as f64)).round() as i64;
+
+                // Convert back to days and milliseconds
+                let result_days = (result_ms / 86_400_000) as i32;
+                let result_remaining_ms = (result_ms % 86_400_000) as i32;
+
+                Some(((result_days as i64) << 32) | (result_remaining_ms as i64 & 0xFFFFFFFF))
+            }
+            _ => Some(lower_key),
+        }
+    }
+
+    /// Binary search to find the index where accumulated count >= target
+    fn binary_search_count(accumulated: &[(i64, i64)], target: i64) -> usize {
+        match accumulated.binary_search_by(|(_, count)| count.cmp(&target)) {
+            Ok(idx) => idx,
+            Err(idx) => idx.min(accumulated.len() - 1),
+        }
+    }
+}
+
+impl Accumulator for PercentileAccumulator {
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let array = &values[0];
+
+        match array.data_type() {
+            DataType::Float64 => {
+                let values = array.as_any().downcast_ref::<Float64Array>().unwrap();
+                for i in 0..values.len() {
+                    if values.is_null(i) {
+                        continue;
+                    }
+                    let key = values.value(i).to_bits() as i64;
+                    *self.counts.entry(key).or_insert(0) += 1;
+                }
+            }
+            DataType::Interval(IntervalUnit::YearMonth) => {
+                let values = array.as_any().downcast_ref::<IntervalYearMonthArray>().unwrap();
+                for i in 0..values.len() {
+                    if values.is_null(i) {
+                        continue;
+                    }
+                    let key = values.value(i) as i64;
+                    *self.counts.entry(key).or_insert(0) += 1;
+                }
+            }
+            DataType::Interval(IntervalUnit::DayTime) => {
+                let values = array.as_any().downcast_ref::<IntervalDayTimeArray>().unwrap();
+                for i in 0..values.len() {
+                    if values.is_null(i) {
+                        continue;
+                    }
+                    // Convert IntervalDayTime struct to packed i64: (days << 32) | milliseconds
+                    let (days, ms) = IntervalDayTimeType::to_parts(values.value(i));
+                    let key = ((days as i64) << 32) | (ms as i64 & 0xFFFFFFFF);
+                    *self.counts.entry(key).or_insert(0) += 1;
+                }
+            }
+            _ => {
+                // Fallback: try to treat as Float64
+                if let Some(values) = array.as_any().downcast_ref::<Float64Array>() {
+                    for i in 0..values.len() {
+                        if values.is_null(i) {
+                            continue;
+                        }
+                        let key = values.value(i).to_bits() as i64;
+                        *self.counts.entry(key).or_insert(0) += 1;
+                    }
+                }
+            }
+        }
+
+        Ok(())
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        let binary_array = states[0].as_any().downcast_ref::<BinaryArray>().unwrap();
+
+        for i in 0..binary_array.len() {
+            if binary_array.is_null(i) {
+                continue;
+            }
+            let bytes = binary_array.value(i);
+            let other_counts = Self::deserialize(bytes)?;
+
+            for (key, count) in other_counts {
+                *self.counts.entry(key).or_insert(0) += count;
+            }
+        }
+
+        Ok(())
+    }
+
+    fn state(&mut self) -> Result<Vec<ScalarValue>> {
+        let bytes = self.serialize();
+        Ok(vec![ScalarValue::Binary(Some(bytes))])
+    }
+
+    fn evaluate(&mut self) -> Result<ScalarValue> {
+        match self.compute_percentile_i64() {
+            Some(value) => match &self.return_type {
+                DataType::Float64 => Ok(ScalarValue::Float64(Some(f64::from_bits(value as u64)))),
+                DataType::Interval(IntervalUnit::YearMonth) => {
+                    Ok(ScalarValue::IntervalYearMonth(Some(value as i32)))
+                }
+                DataType::Interval(IntervalUnit::DayTime) => {
+                    // Unpack i64 to (days, milliseconds) and create IntervalDayTime struct
+                    let days = (value >> 32) as i32;
+                    let ms = value as i32;
+                    Ok(ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value(days, ms))))
+                }
+                _ => Ok(ScalarValue::Float64(Some(f64::from_bits(value as u64)))),
+            },
+            None => match &self.return_type {
+                DataType::Float64 => Ok(ScalarValue::Float64(None)),
+                DataType::Interval(IntervalUnit::YearMonth) => Ok(ScalarValue::IntervalYearMonth(None)),
+                DataType::Interval(IntervalUnit::DayTime) => Ok(ScalarValue::IntervalDayTime(None)),
+                _ => Ok(ScalarValue::Float64(None)),
+            },
+        }
+    }
+
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self) + self.counts.len() * (std::mem::size_of::<i64>() * 2)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::Float64Array;
+    use std::sync::Arc;
+
+    #[test]
+    fn test_percentile_median() {
+        let mut acc = PercentileAccumulator::new(0.5, false, DataType::Float64);
+        let values: ArrayRef = Arc::new(Float64Array::from(vec![0.0, 10.0, 20.0, 30.0, 40.0]));
+        acc.update_batch(&[values]).unwrap();
+
+        let result = acc.evaluate().unwrap();
+        assert_eq!(result, ScalarValue::Float64(Some(20.0)));
+    }
+
+    #[test]
+    fn test_percentile_25th() {
+        let mut acc = PercentileAccumulator::new(0.25, false, DataType::Float64);
+        let values: ArrayRef = Arc::new(Float64Array::from(vec![0.0, 10.0, 20.0, 30.0, 40.0]));
+        acc.update_batch(&[values]).unwrap();
+
+        let result = acc.evaluate().unwrap();
+        assert_eq!(result, ScalarValue::Float64(Some(10.0)));
+    }
+
+    #[test]
+    fn test_percentile_serialize_deserialize() {
+        let mut acc = PercentileAccumulator::new(0.5, false, DataType::Float64);
+        let values: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0]));
+        acc.update_batch(&[values]).unwrap();
+
+        let state = acc.state().unwrap();
+        let bytes = match &state[0] {
+            ScalarValue::Binary(Some(b)) => b.clone(),
+            _ => panic!("Expected Binary state"),
+        };
+
+        let deserialized = PercentileAccumulator::deserialize(&bytes).unwrap();
+        assert_eq!(deserialized.len(), 3);
+    }
+
+    #[test]
+    fn test_percentile_reverse() {
+        // With DESC ordering, 25th percentile should equal 75th percentile of ASC
+        let mut acc_asc = PercentileAccumulator::new(0.75, false, DataType::Float64);
+        let mut acc_desc = PercentileAccumulator::new(0.25, true, DataType::Float64);
+
+        let values: ArrayRef = Arc::new(Float64Array::from(vec![0.0, 10.0, 20.0, 30.0, 40.0]));
+
acc_asc.update_batch(&[values.clone()]).unwrap(); + acc_desc.update_batch(&[values]).unwrap(); + + let result_asc = acc_asc.evaluate().unwrap(); + let result_desc = acc_desc.evaluate().unwrap(); + assert_eq!(result_asc, result_desc); + } + + #[test] + fn test_percentile_year_month_interval() { + let mut acc = PercentileAccumulator::new(0.5, false, DataType::Interval(IntervalUnit::YearMonth)); + let values: ArrayRef = Arc::new(IntervalYearMonthArray::from(vec![0, 10, 20, 30, 40])); + acc.update_batch(&[values]).unwrap(); + + let result = acc.evaluate().unwrap(); + assert_eq!(result, ScalarValue::IntervalYearMonth(Some(20))); + } + + #[test] + fn test_percentile_day_time_interval() { + let mut acc = PercentileAccumulator::new(0.5, false, DataType::Interval(IntervalUnit::DayTime)); + // Create intervals: 1 day, 2 days, 3 days, 4 days, 5 days (no milliseconds) + let values: ArrayRef = Arc::new(IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(1, 0), + IntervalDayTimeType::make_value(2, 0), + IntervalDayTimeType::make_value(3, 0), + IntervalDayTimeType::make_value(4, 0), + IntervalDayTimeType::make_value(5, 0), + ])); + acc.update_batch(&[values]).unwrap(); + + let result = acc.evaluate().unwrap(); + // Median should be 3 days + assert_eq!(result, ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value(3, 0)))); + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 8c39ba779d..74e4fdb567 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -263,6 +263,7 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[Last] -> CometLast, classOf[Max] -> CometMax, classOf[Min] -> CometMin, + classOf[Percentile] -> CometPercentile, classOf[StddevPop] -> CometStddevPop, classOf[StddevSamp] -> CometStddevSamp, classOf[Sum] -> CometSum, @@ 
-370,6 +371,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
       case _: ArrayType => 14
       case _: MapType => 15
       case _: StructType => 16
+      case _: YearMonthIntervalType => 17
+      case _: DayTimeIntervalType => 18
       case dt =>
         logWarning(s"Cannot serialize Spark data type: $dt")
         return None
diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
index 1485589b46..6746a5f825 100644
--- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
@@ -22,9 +22,10 @@ package org.apache.comet.serde
 import scala.jdk.CollectionConverters._
 import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, Percentile, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{ByteType, DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType}
+import org.apache.spark.sql.types.{ArrayType, ByteType, DataTypes, DayTimeIntervalType, DecimalType, IntegerType, LongType, NumericType, ShortType, StringType, YearMonthIntervalType}
 import org.apache.comet.CometConf
 import org.apache.comet.CometConf.COMET_EXEC_STRICT_FLOATING_POINT
@@ -671,6 +672,62 @@ object CometBloomFilterAggregate extends CometAggregateExpressionSerde[BloomFilterAggregate] {
   }
 }
+object CometPercentile extends CometAggregateExpressionSerde[Percentile] {
+  override def convert(
+      aggExpr: AggregateExpression,
+      expr: Percentile,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+
+    // Only support when frequency is Literal(1L) - i.e., percentile_cont behavior
+    expr.frequencyExpression match {
+      case Literal(1L, LongType) =>
+      case _ =>
+        withInfo(aggExpr, "weighted percentile not supported")
+        return None
+    }
+
+    // Only support scalar percentile, not array of percentiles
+    if (expr.percentageExpression.dataType.isInstanceOf[ArrayType]) {
+      withInfo(aggExpr, "array of percentiles not supported")
+      return None
+    }
+
+    // Support numeric types and interval types
+    expr.child.dataType match {
+      case _: NumericType =>
+      case _: DecimalType =>
+      case _: YearMonthIntervalType =>
+      case _: DayTimeIntervalType =>
+      case _ =>
+        withInfo(aggExpr, s"unsupported input type: ${expr.child.dataType}")
+        return None
+    }
+
+    val childExpr = exprToProto(expr.child, inputs, binding)
+    val percentileExpr = exprToProto(expr.percentageExpression, inputs, binding)
+    val dataType = serializeDataType(expr.dataType)
+
+    if (childExpr.isDefined && percentileExpr.isDefined && dataType.isDefined) {
+      val builder = ExprOuterClass.PercentileCont.newBuilder()
+      builder.setChild(childExpr.get)
+      builder.setPercentile(percentileExpr.get)
+      builder.setDatatype(dataType.get)
+      builder.setReverse(expr.reverse)
+
+      Some(
+        ExprOuterClass.AggExpr
+          .newBuilder()
+          .setPercentileCont(builder)
+          .build())
+    } else {
+      withInfo(aggExpr, expr.child, expr.percentageExpression)
+      None
+    }
+  }
+}
+
 object AggSerde {
   import org.apache.spark.sql.types._
diff --git a/spark/src/test/resources/sql-tests/expressions/aggregate/percentile_cont.sql b/spark/src/test/resources/sql-tests/expressions/aggregate/percentile_cont.sql
new file mode 100644
index 0000000000..5cb61c6610
--- /dev/null
+++ b/spark/src/test/resources/sql-tests/expressions/aggregate/percentile_cont.sql
@@ 
-0,0 +1,77 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Tests for percentile_cont aggregate function +-- Uses similar test data as Spark's percentiles.sql + +statement +CREATE TABLE test_percentile(k int, v int) USING parquet + +statement +INSERT INTO test_percentile VALUES (0, 0), (0, 10), (0, 20), (0, 30), (0, 40), (1, 10), (1, 20), (2, 10), (2, 20), (2, 25), (2, 30), (3, 60), (4, NULL) + +-- Basic percentile_cont (25th percentile) - should match Spark result: 10.0 +query +SELECT percentile_cont(0.25) WITHIN GROUP (ORDER BY v) FROM test_percentile + +-- percentile_cont with DESC ordering - should match Spark result: 30.0 +query +SELECT percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC) FROM test_percentile + +-- percentile_cont with GROUP BY - should match Spark results +query +SELECT k, percentile_cont(0.25) WITHIN GROUP (ORDER BY v) FROM test_percentile GROUP BY k ORDER BY k + +-- percentile_cont with GROUP BY and DESC - should match Spark results +query +SELECT k, percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC) FROM test_percentile GROUP BY k ORDER BY k + +-- median (50th percentile) +query +SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY v) FROM test_percentile + +-- Multiple 
percentile_cont in same query +query +SELECT percentile_cont(0.25) WITHIN GROUP (ORDER BY v), percentile_cont(0.75) WITHIN GROUP (ORDER BY v) FROM test_percentile + +-- Tests for interval types +statement +CREATE TABLE test_interval ( + id INT, + ym INTERVAL YEAR TO MONTH, + dt INTERVAL DAY TO SECOND +) USING parquet + +statement +INSERT INTO test_interval VALUES + (1, INTERVAL '1' YEAR, INTERVAL '1' DAY), + (2, INTERVAL '2' YEAR, INTERVAL '2' DAY), + (3, INTERVAL '3' YEAR, INTERVAL '3' DAY), + (4, INTERVAL '4' YEAR, INTERVAL '4' DAY), + (5, INTERVAL '5' YEAR, INTERVAL '5' DAY) + +-- percentile_cont with YearMonthIntervalType +query +SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY ym) FROM test_interval + +-- percentile_cont with DayTimeIntervalType +query +SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY dt) FROM test_interval + +-- percentile_cont with interval types and GROUP BY +query +SELECT id % 2 AS grp, percentile_cont(0.5) WITHIN GROUP (ORDER BY ym) FROM test_interval GROUP BY id % 2 ORDER BY grp