Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
321 changes: 321 additions & 0 deletions doc/developer/design/20260519_lag_lead_const_args.md

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/compute-types/src/plan/reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,7 @@ pub fn reduction_type(func: &AggregateFunc) -> ReductionType {
| AggregateFunc::Rank { .. }
| AggregateFunc::DenseRank { .. }
| AggregateFunc::LagLead { .. }
| AggregateFunc::LagLeadConst { .. }
| AggregateFunc::FirstValue { .. }
| AggregateFunc::LastValue { .. }
| AggregateFunc::WindowAggregate { .. }
Expand Down
1 change: 1 addition & 0 deletions src/compute/src/render/reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2353,6 +2353,7 @@ mod monoids {
| AggregateFunc::Rank { .. }
| AggregateFunc::DenseRank { .. }
| AggregateFunc::LagLead { .. }
| AggregateFunc::LagLeadConst { .. }
| AggregateFunc::FirstValue { .. }
| AggregateFunc::LastValue { .. }
| AggregateFunc::WindowAggregate { .. }
Expand Down
101 changes: 101 additions & 0 deletions src/expr/src/relation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2548,6 +2548,7 @@ impl AggregateExpr {
| AggregateFunc::Rank { .. }
| AggregateFunc::DenseRank { .. }
| AggregateFunc::LagLead { .. }
| AggregateFunc::LagLeadConst { .. }
| AggregateFunc::FirstValue { .. }
| AggregateFunc::LastValue { .. }
| AggregateFunc::FusedValueWindowFunc { .. }
Expand Down Expand Up @@ -2674,6 +2675,60 @@ impl AggregateExpr {
self.on_unique_ranking_window_funcs(input_type, "?dense_rank?")
}

// The input type for LagLeadConst is ((OriginalRow, InputValue), OrderByExprs...)
// — i.e. the bare input value, no 3-field encoded-args record. The
// single-row computation is a plain constant fold: if the constant
// offset is 0 the result is the input value, otherwise it's the
// constant default value. No `RecordGet`, no `IsNull`, no
// equality.
AggregateFunc::LagLeadConst {
offset,
default,
ignore_nulls: _,
order_by: _,
} => {
let tuple = self
.expr
.clone()
.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));

// Get the overall return type
let return_type_with_orig_row = self
.typ(input_type)
.scalar_type
.unwrap_list_element_type()
.clone();
let lag_lead_return_type =
return_type_with_orig_row.unwrap_record_element_type()[0].clone();

// Extract the original row
let original_row = tuple
.clone()
.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));

// Extract the bare input value (no encoded-args record).
let input_value = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));

let (result_expr, column_name) = Self::on_unique_lag_lead_const(
*offset,
default,
input_value,
lag_lead_return_type,
);

MirScalarExpr::call_variadic(
ListCreate {
elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
},
vec![MirScalarExpr::call_variadic(
RecordCreate {
field_names: vec![column_name, ColumnName::from("?record?")],
},
vec![result_expr, original_row],
)],
)
}

// The input type for LagLead is ((OriginalRow, (InputValue, Offset, Default)), OrderByExprs...)
AggregateFunc::LagLead { lag_lead, .. } => {
let tuple = self
Expand Down Expand Up @@ -2968,6 +3023,22 @@ impl AggregateExpr {
assert_eq!(order_by, outer_order_by);
Self::on_unique_lag_lead(lag_lead, args_for_func, return_type_for_func)
}
AggregateFunc::LagLeadConst {
order_by,
ignore_nulls: _,
offset,
default,
} => {
assert_eq!(order_by, outer_order_by);
// For the const constituent, `args_for_func` is
// the bare input value (no encoded-args record).
Self::on_unique_lag_lead_const(
*offset,
default,
args_for_func,
return_type_for_func,
)
}
AggregateFunc::FirstValue {
window_frame,
order_by,
Expand Down Expand Up @@ -3107,6 +3178,36 @@ impl AggregateExpr {
)
}

/// `on_unique` for `LagLeadConst`.
///
/// The per-row payload for `LagLeadConst` is the *bare* input value (no
/// 3-field encoded-args record), so the single-row computation is a
/// plain constant fold: when the constant offset is `0` the result is
/// the input value, otherwise it's the constant default value. No
/// `RecordGet`, no `IsNull`, no equality.
///
/// Note: `specialize_lag_lead` declines to rewrite `offset == 0`, so in
/// practice `offset == 0` is unreachable here. The branch is kept as
/// defensive coding (and to keep `on_unique` cheap).
fn on_unique_lag_lead_const(
offset: i32,
default: &Row,
input_value: MirScalarExpr,
return_type: ReprScalarType,
) -> (MirScalarExpr, ColumnName) {
let result_expr = if offset == 0 {
input_value
} else {
MirScalarExpr::literal_ok(default.unpack_first(), return_type)
};
let column_name = if offset < 0 {
ColumnName::from("?lag?")
} else {
ColumnName::from("?lead?")
};
(result_expr, column_name)
}

/// `on_unique` for `lag` and `lead`
fn on_unique_lag_lead(
lag_lead: &LagLeadType,
Expand Down
Loading
Loading