From 6ae981cd7f64681d924d9cfaf82e153a440ebb4c Mon Sep 17 00:00:00 2001
From: Asish Kumar
Date: Tue, 14 Apr 2026 01:46:02 +0530
Subject: [PATCH] [AURON #2178] [AURON #2179] Implement native support for
 first_value and last_value window functions

Spark `first_value(...)` and `last_value(...)` are not supported in Auron's
native window execution path, causing queries that use them to fall back to
the Spark path instead of being executed natively. This PR extends native
window coverage to include both functions:

first_value:
- maps `First(child, ignoreNulls)` in the window expression to the existing
  `AggFunction::First` / `AggFunction::FirstIgnoresNull` Rust aggregates via
  `NativeWindowBase`
- the running-aggregate semantics of `ROWS BETWEEN UNBOUNDED PRECEDING AND
  CURRENT ROW` naturally produce the correct first-value behavior through
  the existing `AggProcessor`
- no new Rust aggregate code is required for first_value

last_value:
- introduces `AggFunction::Last` and `AggFunction::LastIgnoresNull` to the
  Rust aggregate infrastructure
- adds `AggLast` (always updates the accumulator, including with null
  values) and `AggLastIgnoresNull` (updates the accumulator only for
  non-null values)
- extends the protobuf `AggFunction` enum with `LAST = 10` and
  `LAST_IGNORES_NULL = 11`
- adds planner and lib.rs mappings for the new proto values
- maps `Last(child, ignoreNulls)` in the window expression to the new
  `AggFunction::Last` / `AggFunction::LastIgnoresNull` aggregates

Both functions use the existing `AggProcessor` with frame `ROWS BETWEEN
UNBOUNDED PRECEDING AND CURRENT ROW`, consistent with all other aggregate
window functions in Auron.

Additional changes:
- adds a `Last` import and conversion case to `NativeConverters.scala` so
  `last()` as a group aggregate also works natively
- adds `First` and `Last` imports and match cases to `NativeWindowBase`
- adds Scala regression tests for first_value and last_value, covering both
  RESPECT NULLS and IGNORE NULLS variants

Signed-off-by: Asish Kumar
---
 native-engine/auron-planner/proto/auron.proto |   2 +
 native-engine/auron-planner/src/lib.rs        |   2 +
 native-engine/auron-planner/src/planner.rs    |   6 +
 .../datafusion-ext-plans/src/agg/agg.rs       |  10 +
 .../datafusion-ext-plans/src/agg/last.rs      | 235 ++++++++++++++++++
 .../src/agg/last_ignores_null.rs              | 232 +++++++++++++++++
 .../datafusion-ext-plans/src/agg/mod.rs       |   4 +
 .../org/apache/auron/AuronWindowSuite.scala   |  97 ++++++++
 .../spark/sql/auron/NativeConverters.scala    |  14 +-
 .../auron/plan/NativeWindowBase.scala         |  28 +++
 10 files changed, 629 insertions(+), 1 deletion(-)
 create mode 100644 native-engine/datafusion-ext-plans/src/agg/last.rs
 create mode 100644 native-engine/datafusion-ext-plans/src/agg/last_ignores_null.rs
 create mode 100644 spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronWindowSuite.scala

diff --git a/native-engine/auron-planner/proto/auron.proto b/native-engine/auron-planner/proto/auron.proto
index b0618b971..6c0ef7f72 100644
--- a/native-engine/auron-planner/proto/auron.proto
+++ b/native-engine/auron-planner/proto/auron.proto
@@ -142,6 +142,8 @@ enum AggFunction {
   FIRST = 7;
   FIRST_IGNORES_NULL = 8;
   BLOOM_FILTER = 9;
+  LAST = 10;
+  LAST_IGNORES_NULL = 11;
   BRICKHOUSE_COLLECT = 1000;
   BRICKHOUSE_COMBINE_UNIQUE = 1001;
   UDAF = 1002;
diff --git a/native-engine/auron-planner/src/lib.rs b/native-engine/auron-planner/src/lib.rs
index a0f7b83d2..fb862bffd 100644
--- a/native-engine/auron-planner/src/lib.rs
+++ b/native-engine/auron-planner/src/lib.rs
@@ -135,6 +135,8 @@ impl From<protobuf::AggFunction> for AggFunction {
             protobuf::AggFunction::CollectSet => AggFunction::CollectSet,
             protobuf::AggFunction::First => AggFunction::First,
             protobuf::AggFunction::FirstIgnoresNull => AggFunction::FirstIgnoresNull,
+            protobuf::AggFunction::Last => AggFunction::Last,
+            protobuf::AggFunction::LastIgnoresNull => AggFunction::LastIgnoresNull,
             protobuf::AggFunction::BloomFilter => AggFunction::BloomFilter,
             protobuf::AggFunction::BrickhouseCollect => AggFunction::BrickhouseCollect,
             protobuf::AggFunction::BrickhouseCombineUnique => AggFunction::BrickhouseCombineUnique,
diff --git a/native-engine/auron-planner/src/planner.rs b/native-engine/auron-planner/src/planner.rs
index 84a625734..6af38eb4b 100644
--- a/native-engine/auron-planner/src/planner.rs
+++ b/native-engine/auron-planner/src/planner.rs
@@ -657,6 +657,12 @@ impl PhysicalPlanner {
             protobuf::AggFunction::FirstIgnoresNull => {
                 WindowFunction::Agg(AggFunction::FirstIgnoresNull)
             }
+            protobuf::AggFunction::Last => {
+                WindowFunction::Agg(AggFunction::Last)
+            }
+            protobuf::AggFunction::LastIgnoresNull => {
+                WindowFunction::Agg(AggFunction::LastIgnoresNull)
+            }
             protobuf::AggFunction::BloomFilter => {
                 WindowFunction::Agg(AggFunction::BloomFilter)
             }
diff --git a/native-engine/datafusion-ext-plans/src/agg/agg.rs b/native-engine/datafusion-ext-plans/src/agg/agg.rs
index 5eb4c3dad..99adc470d 100644
--- a/native-engine/datafusion-ext-plans/src/agg/agg.rs
+++ b/native-engine/datafusion-ext-plans/src/agg/agg.rs
@@ -33,6 +33,8 @@ use crate::agg::{
     count::AggCount,
     first::AggFirst,
     first_ignores_null::AggFirstIgnoresNull,
+    last::AggLast,
+    last_ignores_null::AggLastIgnoresNull,
     maxmin::{AggMax, AggMin},
     spark_udaf_wrapper::SparkUDAFWrapper,
     sum::AggSum,
@@ -212,6 +214,14 @@ pub fn create_agg(
             let dt = children[0].data_type(input_schema)?;
             Arc::new(AggFirstIgnoresNull::try_new(children[0].clone(), dt)?)
         }
+        AggFunction::Last => {
+            let dt = children[0].data_type(input_schema)?;
+            Arc::new(AggLast::try_new(children[0].clone(), dt)?)
+        }
+        AggFunction::LastIgnoresNull => {
+            let dt = children[0].data_type(input_schema)?;
+            Arc::new(AggLastIgnoresNull::try_new(children[0].clone(), dt)?)
+        }
         AggFunction::BloomFilter => {
             let dt = children[0].data_type(input_schema)?;
             let empty_batch = RecordBatch::new_empty(Arc::new(Schema::empty()));
diff --git a/native-engine/datafusion-ext-plans/src/agg/last.rs b/native-engine/datafusion-ext-plans/src/agg/last.rs
new file mode 100644
index 000000000..6cf753b03
--- /dev/null
+++ b/native-engine/datafusion-ext-plans/src/agg/last.rs
@@ -0,0 +1,235 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
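+
+// `AggLast` implements Spark's last() / last_value() with RESPECT NULLS
+// semantics: partial_update overwrites the accumulator on every input row,
+// storing a null when the row's value is null, so the result reflects the
+// last row in frame order even when that row's value is null.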
+
+use std::{
+    any::Any,
+    fmt::{Debug, Formatter},
+    sync::Arc,
+};
+
+use arrow::{array::*, datatypes::*};
+use datafusion::{
+    common::{Result, ScalarValue},
+    physical_expr::PhysicalExprRef,
+};
+use datafusion_ext_commons::{downcast_any, scalar_value::compacted_scalar_value_from_array};
+
+use crate::{
+    agg::{
+        Agg,
+        acc::{
+            AccBooleanColumn, AccBytes, AccBytesColumn, AccColumnRef, AccPrimColumn,
+            AccScalarValueColumn, create_acc_generic_column,
+        },
+        agg::IdxSelection,
+    },
+    idx_for_zipped,
+};
+
+pub struct AggLast {
+    child: PhysicalExprRef,
+    data_type: DataType,
+    acc_array_data_types: Vec<DataType>,
+}
+
+impl AggLast {
+    pub fn try_new(child: PhysicalExprRef, data_type: DataType) -> Result<Self> {
+        let acc_array_data_types = vec![data_type.clone()];
+        Ok(Self {
+            child,
+            data_type,
+            acc_array_data_types,
+        })
+    }
+}
+
+impl Debug for AggLast {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "Last({:?})", self.child)
+    }
+}
+
+impl Agg for AggLast {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn exprs(&self) -> Vec<PhysicalExprRef> {
+        vec![self.child.clone()]
+    }
+
+    fn with_new_exprs(&self, exprs: Vec<PhysicalExprRef>) -> Result<Arc<dyn Agg>> {
+        Ok(Arc::new(Self::try_new(
+            exprs[0].clone(),
+            self.data_type.clone(),
+        )?))
+    }
+
+    fn data_type(&self) -> &DataType {
+        &self.data_type
+    }
+
+    fn nullable(&self) -> bool {
+        true
+    }
+
+    fn create_acc_column(&self, num_rows: usize) -> AccColumnRef {
+        create_acc_generic_column(self.data_type.clone(), num_rows)
+    }
+
+    fn acc_array_data_types(&self) -> &[DataType] {
+        &self.acc_array_data_types
+    }
+
+    fn partial_update(
+        &self,
+        accs: &mut AccColumnRef,
+        acc_idx: IdxSelection<'_>,
+        partial_args: &[ArrayRef],
+        partial_arg_idx: IdxSelection<'_>,
+    ) -> Result<()> {
+        let partial_arg = &partial_args[0];
+        accs.ensure_size(acc_idx);
+
+        macro_rules! handle_bytes {
+            ($array:expr) => {{
+                let accs = downcast_any!(accs, mut AccBytesColumn)?;
+                let partial_arg = $array;
+                idx_for_zipped! {
+                    ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
+                        if partial_arg.is_valid(partial_arg_idx) {
+                            accs.set_value(acc_idx, Some(AccBytes::from(partial_arg.value(partial_arg_idx).as_ref())));
+                        } else {
+                            accs.set_value(acc_idx, None);
+                        }
+                    }
+                }
+            }}
+        }
+
+        downcast_primitive_array! {
+            partial_arg => {
+                if let Ok(accs) = downcast_any!(accs, mut AccPrimColumn<_>) {
+                    idx_for_zipped! {
+                        ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
+                            if partial_arg.is_valid(partial_arg_idx) {
+                                accs.set_value(acc_idx, Some(partial_arg.value(partial_arg_idx)));
+                            } else {
+                                accs.set_value(acc_idx, None);
+                            }
+                        }
+                    }
+                }
+            }
+            DataType::Boolean => {
+                let accs = downcast_any!(accs, mut AccBooleanColumn)?;
+                let partial_arg = downcast_any!(partial_arg, BooleanArray)?;
+                idx_for_zipped! {
+                    ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
+                        if partial_arg.is_valid(partial_arg_idx) {
+                            accs.set_value(acc_idx, Some(partial_arg.value(partial_arg_idx)));
+                        } else {
+                            accs.set_value(acc_idx, None);
+                        }
+                    }
+                }
+            }
+            DataType::Utf8 => handle_bytes!(downcast_any!(partial_arg, StringArray)?),
+            DataType::Binary => handle_bytes!(downcast_any!(partial_arg, BinaryArray)?),
+            _other => {
+                let accs = downcast_any!(accs, mut AccScalarValueColumn)?;
+                idx_for_zipped! {
+                    ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
+                        if partial_arg.is_valid(partial_arg_idx) {
+                            accs.set_value(acc_idx, compacted_scalar_value_from_array(partial_arg, partial_arg_idx)?);
+                        } else {
+                            accs.set_value(acc_idx, ScalarValue::try_from(&self.data_type)?);
+                        }
+                    }
+                }
+            }
+        }
+        Ok(())
+    }
+
+    fn partial_merge(
+        &self,
+        accs: &mut AccColumnRef,
+        acc_idx: IdxSelection<'_>,
+        merging_accs: &mut AccColumnRef,
+        merging_acc_idx: IdxSelection<'_>,
+    ) -> Result<()> {
+        accs.ensure_size(acc_idx);
+
+        // For last, always overwrite with the merging accumulator's value
+        macro_rules! handle_primitive {
+            ($ty:ty) => {{
+                type TNative = <$ty as ArrowPrimitiveType>::Native;
+                let accs = downcast_any!(accs, mut AccPrimColumn<TNative>)?;
+                let merging_accs = downcast_any!(merging_accs, mut AccPrimColumn<_>)?;
+                idx_for_zipped! {
+                    ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
+                        accs.set_value(acc_idx, merging_accs.value(merging_acc_idx));
+                    }
+                }
+            }}
+        }
+
+        macro_rules! handle_boolean {
+            () => {{
+                let accs = downcast_any!(accs, mut AccBooleanColumn)?;
+                let merging_accs = downcast_any!(merging_accs, mut AccBooleanColumn)?;
+                idx_for_zipped! {
+                    ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
+                        accs.set_value(acc_idx, merging_accs.value(merging_acc_idx));
+                    }
+                }
+            }};
+        }
+
+        macro_rules! handle_bytes {
+            () => {{
+                let accs = downcast_any!(accs, mut AccBytesColumn)?;
+                let merging_accs = downcast_any!(merging_accs, mut AccBytesColumn)?;
+                idx_for_zipped! {
+                    ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
+                        accs.set_value(acc_idx, merging_accs.take_value(merging_acc_idx));
+                    }
+                }
+            }};
+        }
+
+        downcast_primitive! {
+            (&self.data_type) => (handle_primitive),
+            DataType::Boolean => handle_boolean!(),
+            DataType::Utf8 | DataType::Binary => handle_bytes!(),
+            DataType::Null => {}
+            _ => {
+                let accs = downcast_any!(accs, mut AccScalarValueColumn)?;
+                let merging_accs = downcast_any!(merging_accs, mut AccScalarValueColumn)?;
+                idx_for_zipped! {
+                    ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
+                        accs.set_value(acc_idx, merging_accs.take_value(merging_acc_idx));
+                    }
+                }
+            }
+        }
+        Ok(())
+    }
+
+    fn final_merge(&self, accs: &mut AccColumnRef, acc_idx: IdxSelection<'_>) -> Result<ArrayRef> {
+        Ok(accs.freeze_to_arrays(acc_idx)?[0].clone())
+    }
+}
diff --git a/native-engine/datafusion-ext-plans/src/agg/last_ignores_null.rs b/native-engine/datafusion-ext-plans/src/agg/last_ignores_null.rs
new file mode 100644
index 000000000..fde2afd94
--- /dev/null
+++ b/native-engine/datafusion-ext-plans/src/agg/last_ignores_null.rs
@@ -0,0 +1,232 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
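+
+// `AggLastIgnoresNull` implements Spark's last() / last_value() with IGNORE
+// NULLS semantics: partial_update only overwrites the accumulator with
+// non-null values, so the result is the last non-null value in frame order.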
+
+use std::{
+    any::Any,
+    fmt::{Debug, Formatter},
+    sync::Arc,
+};
+
+use arrow::{array::*, datatypes::*};
+use datafusion::{common::Result, physical_expr::PhysicalExprRef};
+use datafusion_ext_commons::{downcast_any, scalar_value::compacted_scalar_value_from_array};
+
+use crate::{
+    agg::{
+        Agg,
+        acc::{
+            AccBooleanColumn, AccBytes, AccBytesColumn, AccColumnRef, AccPrimColumn,
+            AccScalarValueColumn, create_acc_generic_column,
+        },
+        agg::IdxSelection,
+    },
+    idx_for_zipped,
+};
+
+pub struct AggLastIgnoresNull {
+    child: PhysicalExprRef,
+    data_type: DataType,
+    acc_array_data_types: Vec<DataType>,
+}
+
+impl AggLastIgnoresNull {
+    pub fn try_new(child: PhysicalExprRef, data_type: DataType) -> Result<Self> {
+        let acc_array_data_types = vec![data_type.clone()];
+        Ok(Self {
+            child,
+            data_type,
+            acc_array_data_types,
+        })
+    }
+}
+
+impl Debug for AggLastIgnoresNull {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "LastIgnoresNull({:?})", self.child)
+    }
+}
+
+impl Agg for AggLastIgnoresNull {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn exprs(&self) -> Vec<PhysicalExprRef> {
+        vec![self.child.clone()]
+    }
+
+    fn with_new_exprs(&self, exprs: Vec<PhysicalExprRef>) -> Result<Arc<dyn Agg>> {
+        Ok(Arc::new(Self::try_new(
+            exprs[0].clone(),
+            self.data_type.clone(),
+        )?))
+    }
+
+    fn data_type(&self) -> &DataType {
+        &self.data_type
+    }
+
+    fn nullable(&self) -> bool {
+        true
+    }
+
+    fn create_acc_column(&self, num_rows: usize) -> AccColumnRef {
+        create_acc_generic_column(self.data_type.clone(), num_rows)
+    }
+
+    fn acc_array_data_types(&self) -> &[DataType] {
+        &self.acc_array_data_types
+    }
+
+    fn partial_update(
+        &self,
+        accs: &mut AccColumnRef,
+        acc_idx: IdxSelection<'_>,
+        partial_args: &[ArrayRef],
+        partial_arg_idx: IdxSelection<'_>,
+    ) -> Result<()> {
+        let partial_arg = &partial_args[0];
+        accs.ensure_size(acc_idx);
+
+        macro_rules! handle_bytes {
+            ($array:expr) => {{
+                let accs = downcast_any!(accs, mut AccBytesColumn)?;
+                let partial_arg = $array;
+                idx_for_zipped! {
+                    ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
+                        if partial_arg.is_valid(partial_arg_idx) {
+                            accs.set_value(acc_idx, Some(AccBytes::from(partial_arg.value(partial_arg_idx).as_ref())));
+                        }
+                    }
+                }
+            }}
+        }
+
+        downcast_primitive_array! {
+            partial_arg => {
+                if let Ok(accs) = downcast_any!(accs, mut AccPrimColumn<_>) {
+                    idx_for_zipped! {
+                        ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
+                            if partial_arg.is_valid(partial_arg_idx) {
+                                accs.set_value(acc_idx, Some(partial_arg.value(partial_arg_idx)));
+                            }
+                        }
+                    }
+                }
+            }
+            DataType::Boolean => {
+                let accs = downcast_any!(accs, mut AccBooleanColumn)?;
+                let partial_arg = downcast_any!(partial_arg, BooleanArray)?;
+                idx_for_zipped! {
+                    ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
+                        if partial_arg.is_valid(partial_arg_idx) {
+                            accs.set_value(acc_idx, Some(partial_arg.value(partial_arg_idx)));
+                        }
+                    }
+                }
+            }
+            DataType::Utf8 => handle_bytes!(downcast_any!(partial_arg, StringArray)?),
+            DataType::Binary => handle_bytes!(downcast_any!(partial_arg, BinaryArray)?),
+            _other => {
+                let accs = downcast_any!(accs, mut AccScalarValueColumn)?;
+                idx_for_zipped! {
+                    ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
+                        if partial_arg.is_valid(partial_arg_idx) {
+                            accs.set_value(acc_idx, compacted_scalar_value_from_array(partial_arg, partial_arg_idx)?);
+                        }
+                    }
+                }
+            }
+        }
+        Ok(())
+    }
+
+    fn partial_merge(
+        &self,
+        accs: &mut AccColumnRef,
+        acc_idx: IdxSelection<'_>,
+        merging_accs: &mut AccColumnRef,
+        merging_acc_idx: IdxSelection<'_>,
+    ) -> Result<()> {
+        accs.ensure_size(acc_idx);
+
+        // primitive types
+        macro_rules! handle_primitive {
+            ($ty:ty) => {{
+                type TNative = <$ty as ArrowPrimitiveType>::Native;
+                let accs = downcast_any!(accs, mut AccPrimColumn<TNative>)?;
+                let merging_accs = downcast_any!(merging_accs, mut AccPrimColumn<_>)?;
+                idx_for_zipped! {
+                    ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
+                        if merging_accs.value(merging_acc_idx).is_some() {
+                            accs.set_value(acc_idx, merging_accs.value(merging_acc_idx));
+                        }
+                    }
+                }
+            }}
+        }
+
+        macro_rules! handle_boolean {
+            () => {{
+                let accs = downcast_any!(accs, mut AccBooleanColumn)?;
+                let merging_accs = downcast_any!(merging_accs, mut AccBooleanColumn)?;
+                idx_for_zipped! {
+                    ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
+                        if merging_accs.value(merging_acc_idx).is_some() {
+                            accs.set_value(acc_idx, merging_accs.value(merging_acc_idx));
+                        }
+                    }
+                }
+            }};
+        }
+
+        macro_rules! handle_bytes {
+            () => {{
+                let accs = downcast_any!(accs, mut AccBytesColumn)?;
+                let merging_accs = downcast_any!(merging_accs, mut AccBytesColumn)?;
+                idx_for_zipped! {
+                    ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
+                        if merging_accs.value(merging_acc_idx).is_some() {
+                            accs.set_value(acc_idx, merging_accs.take_value(merging_acc_idx));
+                        }
+                    }
+                }
+            }};
+        }
+
+        downcast_primitive! {
+            (&self.data_type) => (handle_primitive),
+            DataType::Boolean => handle_boolean!(),
+            DataType::Utf8 | DataType::Binary => handle_bytes!(),
+            DataType::Null => {}
+            _ => {
+                let accs = downcast_any!(accs, mut AccScalarValueColumn)?;
+                let merging_accs = downcast_any!(merging_accs, mut AccScalarValueColumn)?;
+                idx_for_zipped! {
+                    ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
+                        if !merging_accs.value(merging_acc_idx).is_null() {
+                            accs.set_value(acc_idx, merging_accs.take_value(merging_acc_idx));
+                        }
+                    }
+                }
+            }
+        }
+        Ok(())
+    }
+
+    fn final_merge(&self, accs: &mut AccColumnRef, acc_idx: IdxSelection<'_>) -> Result<ArrayRef> {
+        Ok(accs.freeze_to_arrays(acc_idx)?[0].clone())
+    }
+}
diff --git a/native-engine/datafusion-ext-plans/src/agg/mod.rs b/native-engine/datafusion-ext-plans/src/agg/mod.rs
index 9f19b02c8..0aa579ebd 100644
--- a/native-engine/datafusion-ext-plans/src/agg/mod.rs
+++ b/native-engine/datafusion-ext-plans/src/agg/mod.rs
@@ -25,6 +25,8 @@ pub mod collect;
 pub mod count;
 pub mod first;
 pub mod first_ignores_null;
+pub mod last;
+pub mod last_ignores_null;
 pub mod maxmin;
 pub mod spark_udaf_wrapper;
 pub mod sum;
@@ -69,6 +71,8 @@ pub enum AggFunction {
     Min,
     First,
     FirstIgnoresNull,
+    Last,
+    LastIgnoresNull,
     CollectList,
     CollectSet,
     BloomFilter,
diff --git a/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronWindowSuite.scala b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronWindowSuite.scala
new file mode 100644
index 000000000..1c0669e96
--- /dev/null
+++ b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronWindowSuite.scala
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.auron
+
+import org.apache.spark.sql.AuronQueryTest
+import org.apache.spark.sql.execution.auron.plan.NativeWindowBase
+
+import org.apache.auron.util.AuronTestUtils
+
+class AuronWindowSuite extends AuronQueryTest with BaseAuronSQLSuite with AuronSQLTestHelper {
+
+  test("first_value window function") {
+    withSQLConf("spark.auron.enable.window" -> "true") {
+      withTable("t1") {
+        sql("create table t1(id int, grp int, v string) using parquet")
+        sql("insert into t1 values (1, 1, 'a'), (2, 1, null), (3, 1, 'c'), (4, 2, 'x')")
+
+        checkSparkAnswerAndOperator("""select
+          | id,
+          | grp,
+          | v,
+          | first_value(v) over (partition by grp order by id) as first_v
+          |from t1
+          |""".stripMargin)
+      }
+    }
+  }
+
+  test("first_value window function with ignore nulls") {
+    if (AuronTestUtils.isSparkV32OrGreater) {
+      withSQLConf("spark.auron.enable.window" -> "true") {
+        withTable("t1") {
+          sql("create table t1(id int, grp int, v string) using parquet")
+          sql("insert into t1 values (1, 1, null), (2, 1, 'b'), (3, 1, 'c'), (4, 2, 'x')")
+
+          checkSparkAnswerAndOperator("""select
+            | id,
+            | grp,
+            | v,
+            | first_value(v) ignore nulls over (partition by grp order by id) as first_non_null_v
+            |from t1
+            |""".stripMargin)
+        }
+      }
+    }
+  }
+
+  test("last_value window function") {
+    withSQLConf("spark.auron.enable.window" -> "true") {
+      withTable("t1") {
+        sql("create table t1(id int, grp int, v string) using parquet")
+        sql("insert into t1 values (1, 1, 'a'), (2, 1, null), (3, 1, 'c'), (4, 2, 'x')")
+
+        checkSparkAnswerAndOperator("""select
+          | id,
+          | grp,
+          | v,
+          | last_value(v) over (partition by grp order by id) as last_v
+          |from t1
+          |""".stripMargin)
+      }
+    }
+  }
+
+  test("last_value window function with ignore nulls") {
+    if (AuronTestUtils.isSparkV32OrGreater) {
+      withSQLConf("spark.auron.enable.window" -> "true") {
+        withTable("t1") {
+          sql("create table t1(id int, grp int, v string) using parquet")
+          sql("insert into t1 values (1, 1, 'a'), (2, 1, null), (3, 1, 'c'), (4, 2, 'x')")
+
+          checkSparkAnswerAndOperator("""select
+            | id,
+            | grp,
+            | v,
+            | last_value(v) ignore nulls over (partition by grp order by id) as last_non_null_v
+            |from t1
+            |""".stripMargin)
+        }
+      }
+    }
+  }
+}
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
index 750aaa524..c5ad62cf6 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
@@ -37,7 +37,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.auron.util.Using
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Average, CollectList, CollectSet, Count, DeclarativeAggregate, First, Max, Min, Sum, TypedImperativeAggregate}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Average, CollectList, CollectSet, Count, DeclarativeAggregate, First, Last, Max, Min, Sum, TypedImperativeAggregate}
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
 import org.apache.spark.sql.catalyst.expressions.codegen.ExprCode
 import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
@@ -1259,6 +1259,18 @@ object NativeConverters extends Logging {
         })
         aggBuilder.addChildren(convertExpr(child))
 
+      case Last(child, ignoresNullExpr) =>
+        val ignoresNull = ignoresNullExpr.asInstanceOf[Any] match {
+          case Literal(v: Boolean, BooleanType) => v
+          case v: Boolean => v
+        }
+        aggBuilder.setAggFunction(if (ignoresNull) {
+          pb.AggFunction.LAST_IGNORES_NULL
+        } else {
+          pb.AggFunction.LAST
+        })
+        aggBuilder.addChildren(convertExpr(child))
+
       case CollectList(child, _, _) =>
         aggBuilder.setAggFunction(pb.AggFunction.COLLECT_LIST)
         aggBuilder.addChildren(convertExpr(child))
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowBase.scala
index fad61ff09..eb16a989e 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowBase.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowBase.scala
@@ -36,6 +36,8 @@ import org.apache.spark.sql.catalyst.expressions.SortOrder
 import org.apache.spark.sql.catalyst.expressions.WindowExpression
 import org.apache.spark.sql.catalyst.expressions.aggregate.Average
 import org.apache.spark.sql.catalyst.expressions.aggregate.Count
+import org.apache.spark.sql.catalyst.expressions.aggregate.First
+import org.apache.spark.sql.catalyst.expressions.aggregate.Last
 import org.apache.spark.sql.catalyst.expressions.aggregate.Max
 import org.apache.spark.sql.catalyst.expressions.aggregate.Min
 import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
@@ -158,6 +160,32 @@ abstract class NativeWindowBase(
         windowExprBuilder.setAggFunc(pb.AggFunction.COUNT)
         windowExprBuilder.addChildren(NativeConverters.convertExpr(child))
 
+      case e @ First(child, ignoresNullExpr) =>
+        assert(
+          spec.frameSpecification == RowNumber().frame, // only supports RowFrame(Unbounded, CurrentRow)
+          s"window frame not supported: ${spec.frameSpecification}")
+        val ignoresNull = ignoresNullExpr.asInstanceOf[Any] match {
+          case Literal(v: Boolean, BooleanType) => v
+          case v: Boolean => v
+        }
+        windowExprBuilder.setFuncType(pb.WindowFunctionType.Agg)
+        windowExprBuilder.setAggFunc(
+          if (ignoresNull) pb.AggFunction.FIRST_IGNORES_NULL else pb.AggFunction.FIRST)
+        windowExprBuilder.addChildren(NativeConverters.convertExpr(child))
+
+      case e @ Last(child, ignoresNullExpr) =>
+        assert(
+          spec.frameSpecification == RowNumber().frame, // only supports RowFrame(Unbounded, CurrentRow)
+          s"window frame not supported: ${spec.frameSpecification}")
+        val ignoresNull = ignoresNullExpr.asInstanceOf[Any] match {
+          case Literal(v: Boolean, BooleanType) => v
+          case v: Boolean => v
+        }
+        windowExprBuilder.setFuncType(pb.WindowFunctionType.Agg)
+        windowExprBuilder.setAggFunc(
+          if (ignoresNull) pb.AggFunction.LAST_IGNORES_NULL else pb.AggFunction.LAST)
+        windowExprBuilder.addChildren(NativeConverters.convertExpr(child))
+
       case other =>
         throw new NotImplementedError(s"window function not supported: $other")
     }