From ec75a28f4209eee90272fa15d93550b472e8f792 Mon Sep 17 00:00:00 2001 From: Asish Kumar Date: Tue, 14 Apr 2026 01:19:56 +0530 Subject: [PATCH] [AURON #2177] Implement native support for lag window function Spark `lag(...)` is not supported in Auron's native window execution path, causing queries using it to fall back to the Spark path instead of being executed natively. This PR extends native window coverage to include `lag(...)`: - adds `Lag` handling in `NativeWindowBase` - extends the protobuf/planner window function enum with `LAG` - adds native planner support to decode `LAG` into the native window plan - introduces a native `LagProcessor` in `datafusion-ext-plans` - evaluates `lag` using Spark-compatible offset/default/null behavior - adds a full-partition processing path for `lag` so that lookback works correctly across input batches - adds Rust regression coverage for cross-batch `lag` - adds Scala regression tests for: - native `lag(...)` execution - Spark fallback for `lag(...) IGNORE NULLS` The native implementation supports Spark semantics for: - `lag(input)` - default offset is 1 - default value is null - `lag(input, offset, default)` - returns the value of `input` at the `offset`th row before the current row in the same window partition - if the target row exists and `input` there is null, returns null - if the target row does not exist, returns `default` Supported scope in this PR: - standard `RESPECT NULLS` behavior Not supported natively in this PR: - `IGNORE NULLS` Unsupported `IGNORE NULLS` queries continue to fall back to Spark to preserve correctness. The full-partition processing infrastructure added here mirrors the approach used for `lead` offset functions, ensuring all rows in a partition are available before computing lag values across batch boundaries. 
Signed-off-by: Asish Kumar --- native-engine/auron-planner/proto/auron.proto | 1 + native-engine/auron-planner/src/planner.rs | 1 + .../datafusion-ext-plans/src/window/mod.rs | 10 +- .../src/window/processors/lag_processor.rs | 110 +++++++++++ .../src/window/processors/mod.rs | 1 + .../src/window/window_context.rs | 6 + .../datafusion-ext-plans/src/window_exec.rs | 171 ++++++++++++++---- .../org/apache/auron/AuronWindowSuite.scala | 64 +++++++ .../auron/plan/NativeWindowBase.scala | 17 ++ 9 files changed, 340 insertions(+), 41 deletions(-) create mode 100644 native-engine/datafusion-ext-plans/src/window/processors/lag_processor.rs create mode 100644 spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronWindowSuite.scala diff --git a/native-engine/auron-planner/proto/auron.proto b/native-engine/auron-planner/proto/auron.proto index b0618b971..748bc447f 100644 --- a/native-engine/auron-planner/proto/auron.proto +++ b/native-engine/auron-planner/proto/auron.proto @@ -129,6 +129,7 @@ enum WindowFunction { ROW_NUMBER = 0; RANK = 1; DENSE_RANK = 2; + LAG = 3; } enum AggFunction { diff --git a/native-engine/auron-planner/src/planner.rs b/native-engine/auron-planner/src/planner.rs index 84a625734..60b649b02 100644 --- a/native-engine/auron-planner/src/planner.rs +++ b/native-engine/auron-planner/src/planner.rs @@ -636,6 +636,7 @@ impl PhysicalPlanner { protobuf::WindowFunction::DenseRank => { WindowFunction::RankLike(WindowRankType::DenseRank) } + protobuf::WindowFunction::Lag => WindowFunction::Lag, }, protobuf::WindowFunctionType::Agg => match w.agg_func() { protobuf::AggFunction::Min => WindowFunction::Agg(AggFunction::Min), diff --git a/native-engine/datafusion-ext-plans/src/window/mod.rs b/native-engine/datafusion-ext-plans/src/window/mod.rs index a9e9da29d..a3d6ad3a2 100644 --- a/native-engine/datafusion-ext-plans/src/window/mod.rs +++ b/native-engine/datafusion-ext-plans/src/window/mod.rs @@ -23,8 +23,8 @@ use crate::{ agg::{AggFunction, agg::create_agg}, 
window::{ processors::{ - agg_processor::AggProcessor, rank_processor::RankProcessor, - row_number_processor::RowNumberProcessor, + agg_processor::AggProcessor, lag_processor::LagProcessor, + rank_processor::RankProcessor, row_number_processor::RowNumberProcessor, }, window_context::WindowContext, }, @@ -36,6 +36,7 @@ pub mod window_context; #[derive(Debug, Clone, Copy)] pub enum WindowFunction { RankLike(WindowRankType), + Lag, Agg(AggFunction), } @@ -87,6 +88,7 @@ impl WindowExpr { WindowFunction::RankLike(WindowRankType::DenseRank) => { Ok(Box::new(RankProcessor::new(true))) } + WindowFunction::Lag => Ok(Box::new(LagProcessor::new(self.children.clone()))), WindowFunction::Agg(agg_func) => { let agg = create_agg( agg_func.clone(), @@ -98,4 +100,8 @@ impl WindowExpr { } } } + + pub fn requires_full_partition(&self) -> bool { + matches!(self.func, WindowFunction::Lag) + } } diff --git a/native-engine/datafusion-ext-plans/src/window/processors/lag_processor.rs b/native-engine/datafusion-ext-plans/src/window/processors/lag_processor.rs new file mode 100644 index 000000000..737c81313 --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/window/processors/lag_processor.rs @@ -0,0 +1,110 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use arrow::{array::ArrayRef, datatypes::DataType, record_batch::RecordBatch}; +use datafusion::{ + common::{DataFusionError, Result, ScalarValue}, + physical_expr::PhysicalExprRef, +}; +use datafusion_ext_commons::arrow::cast::cast; + +use crate::window::{WindowFunctionProcessor, window_context::WindowContext}; + +pub struct LagProcessor { + children: Vec<PhysicalExprRef>, +} + +impl LagProcessor { + pub fn new(children: Vec<PhysicalExprRef>) -> Self { + Self { children } + } +} + +impl WindowFunctionProcessor for LagProcessor { + fn process_batch(&mut self, context: &WindowContext, batch: &RecordBatch) -> Result<ArrayRef> { + assert_eq!( + self.children.len(), + 3, + "lag expects input/offset/default children", + ); + + let input_values = self.children[0] + .evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows()))?; + + let offset_values = self.children[1] + .evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows()))?; + let offset_values = if offset_values.data_type() == &DataType::Int32 { + offset_values + } else { + cast(&offset_values, &DataType::Int32)? + }; + let offset = match ScalarValue::try_from_array(&offset_values, 0)? { + ScalarValue::Int32(Some(offset)) => offset as i64, + other => { + return Err(DataFusionError::Execution(format!( + "lag offset must be a non-null foldable integer, got {other:?}", + ))); + } + }; + + let default_values = self.children[2] + .evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows()))?; + let default_values = if default_values.data_type() == input_values.data_type() { + default_values + } else { + cast(&default_values, input_values.data_type())?
+ }; + + let mut partition_starts = vec![0usize; batch.num_rows()]; + let mut partition_ends = vec![batch.num_rows(); batch.num_rows()]; + if context.has_partition() && batch.num_rows() > 0 { + let partition_rows = context.get_partition_rows(batch)?; + let mut partition_start = 0usize; + for row_idx in 1..=batch.num_rows() { + let is_boundary = row_idx == batch.num_rows() + || partition_rows.row(row_idx).as_ref() + != partition_rows.row(partition_start).as_ref(); + if is_boundary { + for idx in partition_start..row_idx { + partition_starts[idx] = partition_start; + partition_ends[idx] = row_idx; + } + partition_start = row_idx; + } + } + } + + let mut output = Vec::with_capacity(batch.num_rows()); + for row_idx in 0..batch.num_rows() { + // lag looks backward: target is offset rows before current row + let target_idx = row_idx as i64 - offset; + let partition_start = partition_starts[row_idx] as i64; + let partition_end = partition_ends[row_idx] as i64; + let value = if target_idx >= partition_start && target_idx < partition_end { + ScalarValue::try_from_array(&input_values, target_idx as usize)? + } else { + ScalarValue::try_from_array(&default_values, row_idx)? + }; + output.push(value); + } + + ScalarValue::iter_to_array(output) + } +} diff --git a/native-engine/datafusion-ext-plans/src/window/processors/mod.rs b/native-engine/datafusion-ext-plans/src/window/processors/mod.rs index 7d4a72b55..1fc34b220 100644 --- a/native-engine/datafusion-ext-plans/src/window/processors/mod.rs +++ b/native-engine/datafusion-ext-plans/src/window/processors/mod.rs @@ -14,5 +14,6 @@ // limitations under the License. 
pub mod agg_processor; +pub mod lag_processor; pub mod rank_processor; pub mod row_number_processor; diff --git a/native-engine/datafusion-ext-plans/src/window/window_context.rs b/native-engine/datafusion-ext-plans/src/window/window_context.rs index a76eb1253..1c5f68f12 100644 --- a/native-engine/datafusion-ext-plans/src/window/window_context.rs +++ b/native-engine/datafusion-ext-plans/src/window/window_context.rs @@ -167,4 +167,10 @@ impl WindowContext { .collect::<Result<Vec<_>>>()?, )?) } + + pub fn requires_full_partition(&self) -> bool { + self.window_exprs + .iter() + .any(|expr| expr.requires_full_partition()) + } } diff --git a/native-engine/datafusion-ext-plans/src/window_exec.rs b/native-engine/datafusion-ext-plans/src/window_exec.rs index 5bb698eec..58008605f 100644 --- a/native-engine/datafusion-ext-plans/src/window_exec.rs +++ b/native-engine/datafusion-ext-plans/src/window_exec.rs @@ -17,6 +17,7 @@ use std::{any::Any, fmt::Formatter, sync::Arc}; use arrow::{ array::{Array, ArrayRef, Int32Array}, + compute::concat_batches, datatypes::SchemaRef, record_batch::{RecordBatch, RecordBatchOptions}, }; @@ -37,7 +38,7 @@ use once_cell::sync::OnceCell; use crate::{ common::execution_context::ExecutionContext, - window::{WindowExpr, window_context::WindowContext}, + window::{WindowExpr, WindowFunctionProcessor, window_context::WindowContext}, }; #[derive(Debug)] @@ -209,7 +210,8 @@ fn execute_window( Ok(exec_ctx .clone() .output_with_sender("Window", |sender| async move { - sender.exclude_time(exec_ctx.baseline_metrics().elapsed_compute()); + let elapsed_compute = exec_ctx.baseline_metrics().elapsed_compute().clone(); + sender.exclude_time(&elapsed_compute); let mut processors = window_ctx .window_exprs @@ -217,45 +219,29 @@ fn execute_window( .map(|expr: &WindowExpr| expr.create_processor(&window_ctx)) .collect::<Result<Vec<_>>>()?; - while let Some(mut batch) = input.next().await.transpose()?
{ - let _timer = exec_ctx.baseline_metrics().elapsed_compute().timer(); - let mut window_cols: Vec<ArrayRef> = processors - .iter_mut() - .map(|processor| processor.process_batch(&window_ctx, &batch)) - .collect::<Result<_>>()?; - - if let Some(group_limit) = window_ctx.group_limit { - assert_eq!(window_cols.len(), 1); - let limited = arrow::compute::kernels::cmp::lt_eq( - &window_cols[0], - &Int32Array::new_scalar(group_limit as i32), - )?; - window_cols[0] = arrow::compute::filter(&window_cols[0], &limited)?; - batch = arrow::compute::filter_record_batch(&batch, &limited)?; + if window_ctx.requires_full_partition() { + let mut staging_batches = vec![]; + while let Some(batch) = input.next().await.transpose()? { + staging_batches.push(batch); } - let outputs: Vec<ArrayRef> = batch - .columns() - .iter() - .cloned() - .chain(if window_ctx.output_window_cols { - window_cols - } else { - vec![] - }) - .zip(window_ctx.output_schema.fields()) - .map(|(array, field)| { - if array.data_type() != field.data_type() { - return cast(&array, field.data_type()); - } - Ok(array.clone()) - }) - .collect::<Result<_>>()?; - let output_batch = RecordBatch::try_new_with_options( - window_ctx.output_schema.clone(), - outputs, - &RecordBatchOptions::new().with_row_count(Some(batch.num_rows())), - )?; + if !staging_batches.is_empty() { + let _timer = elapsed_compute.timer(); + let batch = concat_batches(&window_ctx.input_schema, &staging_batches)?; + let output_batch = + process_window_batch(batch, &window_ctx, processors.as_mut_slice())?; + exec_ctx + .baseline_metrics() + .record_output(output_batch.num_rows()); + sender.send(output_batch).await; + } + return Ok(()); + } + + while let Some(batch) = input.next().await.transpose()?
{ + let _timer = elapsed_compute.timer(); + let output_batch = + process_window_batch(batch, &window_ctx, processors.as_mut_slice())?; exec_ctx .baseline_metrics() .record_output(output_batch.num_rows()); @@ -265,6 +251,50 @@ })) } +fn process_window_batch( + mut batch: RecordBatch, + window_ctx: &WindowContext, + processors: &mut [Box<dyn WindowFunctionProcessor>], +) -> Result<RecordBatch> { + let mut window_cols: Vec<ArrayRef> = processors + .iter_mut() + .map(|processor| processor.process_batch(window_ctx, &batch)) + .collect::<Result<_>>()?; + + if let Some(group_limit) = window_ctx.group_limit { + assert_eq!(window_cols.len(), 1); + let limited = arrow::compute::kernels::cmp::lt_eq( + &window_cols[0], + &Int32Array::new_scalar(group_limit as i32), + )?; + window_cols[0] = arrow::compute::filter(&window_cols[0], &limited)?; + batch = arrow::compute::filter_record_batch(&batch, &limited)?; + } + + let outputs: Vec<ArrayRef> = batch + .columns() + .iter() + .cloned() + .chain(if window_ctx.output_window_cols { + window_cols + } else { + vec![] + }) + .zip(window_ctx.output_schema.fields()) + .map(|(array, field)| { + if array.data_type() != field.data_type() { + return cast(&array, field.data_type()); + } + Ok(array.clone()) + }) + .collect::<Result<_>>()?; + Ok(RecordBatch::try_new_with_options( + window_ctx.output_schema.clone(), + outputs, + &RecordBatchOptions::new().with_row_count(Some(batch.num_rows())), + )?)
+} + +#[cfg(test)] mod test { use std::sync::Arc; @@ -278,6 +308,8 @@ mod test { prelude::SessionContext, }; + use datafusion::{physical_expr::expressions::Literal, scalar::ScalarValue}; + use crate::{ agg::AggFunction, window::{WindowExpr, WindowFunction, WindowRankType}, @@ -491,4 +523,65 @@ assert_batches_eq!(expected, &batches); Ok(()) } + + #[tokio::test] + async fn test_window_lag_across_batches() -> Result<(), Box<dyn std::error::Error>> { + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); + + let batch1 = build_table_i32( + ("a1", &vec![1, 1]), + ("b1", &vec![10, 20]), + ("c1", &vec![0, 0]), + )?; + let batch2 = build_table_i32( + ("a1", &vec![1, 2]), + ("b1", &vec![30, 40]), + ("c1", &vec![0, 0]), + )?; + let schema = batch1.schema(); + let input = Arc::new(TestMemoryExec::try_new( + &[vec![batch1, batch2]], + schema, + None, + )?); + + let window_exprs = vec![WindowExpr::new( + WindowFunction::Lag, + vec![ + Arc::new(Column::new("b1", 1)), + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + Arc::new(Literal::new(ScalarValue::Int32(Some(-1)))), + ], + Arc::new(Field::new("b1_lag", DataType::Int32, false)), + DataType::Int32, + )]; + + let window = Arc::new(WindowExec::try_new( + input, + window_exprs, + vec![Arc::new(Column::new("a1", 0))], + vec![PhysicalSortExpr { + expr: Arc::new(Column::new("b1", 1)), + options: Default::default(), + }], + None, + true, + )?); + + let stream = window.execute(0, task_ctx)?; + let batches = datafusion::physical_plan::common::collect(stream).await?; + let expected = vec![ + "+----+----+----+--------+", + "| a1 | b1 | c1 | b1_lag |", + "+----+----+----+--------+", + "| 1 | 10 | 0 | -1 |", + "| 1 | 20 | 0 | 10 |", + "| 1 | 30 | 0 | 20 |", + "| 2 | 40 | 0 | -1 |", + "+----+----+----+--------+", + ]; + assert_batches_eq!(expected, &batches); + Ok(()) + } } diff --git a/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronWindowSuite.scala
b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronWindowSuite.scala new file mode 100644 index 000000000..18b87679d --- /dev/null +++ b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronWindowSuite.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.auron + +import org.apache.spark.sql.AuronQueryTest +import org.apache.spark.sql.execution.auron.plan.NativeWindowBase + +import org.apache.auron.util.AuronTestUtils + +class AuronWindowSuite extends AuronQueryTest with BaseAuronSQLSuite with AuronSQLTestHelper { + + test("lag window function") { + withSQLConf("spark.auron.enable.window" -> "true") { + withTable("t1") { + sql("create table t1(id int, grp int, v string) using parquet") + sql("insert into t1 values (1, 1, 'a'), (2, 1, null), (3, 1, 'c'), (4, 2, 'x')") + + checkSparkAnswerAndOperator("""select + | id, + | grp, + | v, + | lag(v) over (partition by grp order by id) as prev_v, + | lag(v, 2, 'fallback') over (partition by grp order by id) as prev2_v + |from t1 + |""".stripMargin) + } + } + } + + test("lag window function with ignore nulls falls back") { + if (AuronTestUtils.isSparkV32OrGreater) { + withSQLConf("spark.auron.enable.window" -> "true") { + withTable("t1") { + sql("create table t1(id int, grp int, v string) using parquet") + sql("insert into t1 values (1, 1, 'a'), (2, 1, null), (3, 1, 'c'), (4, 2, 'x')") + + val df = checkSparkAnswer("""select + | id, + | grp, + | lag(v, 1, 'fallback') ignore nulls + | over (partition by grp order by id) as prev_non_null_v + |from t1 + |""".stripMargin) + val plan = stripAQEPlan(df.queryExecution.executedPlan) + assert(plan.collectFirst { case _: NativeWindowBase => true }.isEmpty) + } + } + } + } +} diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowBase.scala index fad61ff09..2c33e6e74 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowBase.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.Ascending import 
org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.expressions.DenseRank import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.Lag import org.apache.spark.sql.catalyst.expressions.NamedExpression import org.apache.spark.sql.catalyst.expressions.NullsFirst import org.apache.spark.sql.catalyst.expressions.Rank @@ -89,6 +90,11 @@ abstract class NativeWindowBase( override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec) + private def lagIgnoreNulls(expr: Lag): Boolean = + expr.getClass.getMethods + .find(method => method.getName == "ignoreNulls" && method.getParameterCount == 0) + .exists(method => method.invoke(expr).asInstanceOf[Boolean]) + private def nativeWindowExprs = windowExpression.map { named => val field = NativeConverters.convertField(Util.getSchema(named :: Nil).fields(0)) val windowExprBuilder = pb.WindowExprNode.newBuilder().setField(field) @@ -118,6 +124,17 @@ abstract class NativeWindowBase( windowExprBuilder.setFuncType(pb.WindowFunctionType.Window) windowExprBuilder.setWindowFunc(pb.WindowFunction.DENSE_RANK) + case e: Lag => + assert( + spec.frameSpecification == e.frame, + s"window frame not supported: ${spec.frameSpecification}") + assert(!lagIgnoreNulls(e), "window function not supported: lag with IGNORE NULLS") + windowExprBuilder.setFuncType(pb.WindowFunctionType.Window) + windowExprBuilder.setWindowFunc(pb.WindowFunction.LAG) + windowExprBuilder.addChildren(NativeConverters.convertExpr(e.input)) + windowExprBuilder.addChildren(NativeConverters.convertExpr(e.inputOffset)) + windowExprBuilder.addChildren(NativeConverters.convertExpr(e.default)) + case e: Sum => assert( spec.frameSpecification == RowNumber().frame, // only supports RowFrame(Unbounde, CurrentRow)