Skip to content

Commit 2b94c56

Browse files
committed
Convert list_length_max to new infrastructure
Uses the eager binary function infrastructure to implement the `list_length_max` function. This function is a good playground for the new infrastructure, in part because it is not enabled by default (it is gated behind the `ENABLE_LIST_LENGTH_MAX` feature flag). Signed-off-by: Moritz Hoffmann <antiguru@gmail.com>
1 parent 5e3fd9b commit 2b94c56

File tree

4 files changed

+84
-56
lines changed

4 files changed

+84
-56
lines changed

src/expr/src/scalar/func.rs

Lines changed: 10 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -2664,7 +2664,7 @@ pub enum BinaryFunc {
26642664
TrimLeading(TrimLeading),
26652665
TrimTrailing(TrimTrailing),
26662666
EncodedBytesCharLength(EncodedBytesCharLength),
2667-
ListLengthMax { max_layer: usize },
2667+
ListLengthMax(ListLengthMax),
26682668
ArrayContains(ArrayContains),
26692669
ArrayContainsArray { rev: bool },
26702670
ArrayLength(ArrayLength),
@@ -2937,7 +2937,7 @@ impl BinaryFunc {
29372937
BinaryFunc::EncodedBytesCharLength(s) => {
29382938
return s.eval(datums, temp_storage, a_expr, b_expr);
29392939
}
2940-
// BinaryFunc::ListLengthMax { max_layer }(s) => return s.eval(datums, temp_storage, a_expr, b_expr),
2940+
BinaryFunc::ListLengthMax(s) => return s.eval(datums, temp_storage, a_expr, b_expr),
29412941
BinaryFunc::ArrayLength(s) => return s.eval(datums, temp_storage, a_expr, b_expr),
29422942
BinaryFunc::ArrayContains(s) => return s.eval(datums, temp_storage, a_expr, b_expr),
29432943
// BinaryFunc::ArrayContainsArray { rev: false } => Ok(array_contains_array(a, b)),
@@ -3037,7 +3037,6 @@ impl BinaryFunc {
30373037
BinaryFunc::JsonbGetStringStringify => Ok(jsonb_get_string(a, b, temp_storage, true)),
30383038
BinaryFunc::JsonbGetPath => Ok(jsonb_get_path(a, b, temp_storage, false)),
30393039
BinaryFunc::JsonbGetPathStringify => Ok(jsonb_get_path(a, b, temp_storage, true)),
3040-
BinaryFunc::ListLengthMax { max_layer } => list_length_max(a, b, *max_layer),
30413040
BinaryFunc::ArrayContainsArray { rev: false } => Ok(array_contains_array(a, b)),
30423041
BinaryFunc::ArrayContainsArray { rev: true } => Ok(array_contains_array(b, a)),
30433042
BinaryFunc::ListContainsList { rev: false } => Ok(list_contains_list(a, b)),
@@ -3258,7 +3257,7 @@ impl BinaryFunc {
32583257
ArrayLower(s) => s.output_type(input1_type, input2_type),
32593258
ArrayUpper(s) => s.output_type(input1_type, input2_type),
32603259

3261-
ListLengthMax { .. } => SqlScalarType::Int32.nullable(true),
3260+
ListLengthMax(s) => s.output_type(input1_type, input2_type),
32623261

32633262
ArrayArrayConcat(s) => s.output_type(input1_type, input2_type),
32643263
ArrayRemove(s) => s.output_type(input1_type, input2_type),
@@ -3432,7 +3431,7 @@ impl BinaryFunc {
34323431
BinaryFunc::LikeEscape(s) => s.propagates_nulls(),
34333432
BinaryFunc::ListContainsList { .. } => true,
34343433
BinaryFunc::ListElementConcat(s) => s.propagates_nulls(),
3435-
BinaryFunc::ListLengthMax { .. } => true,
3434+
BinaryFunc::ListLengthMax(s) => s.propagates_nulls(),
34363435
BinaryFunc::ListListConcat(s) => s.propagates_nulls(),
34373436
BinaryFunc::ListRemove(s) => s.propagates_nulls(),
34383437
BinaryFunc::LogNumeric(s) => s.propagates_nulls(),
@@ -3724,7 +3723,7 @@ impl BinaryFunc {
37243723
JsonbGetPathStringify => true,
37253724
JsonbGetString => true,
37263725
JsonbGetStringStringify => true,
3727-
ListLengthMax { .. } => true,
3726+
ListLengthMax(s) => s.introduces_nulls(),
37283727
MapGetValue(s) => s.introduces_nulls(),
37293728
}
37303729
}
@@ -3901,7 +3900,7 @@ impl BinaryFunc {
39013900
GetByte(s) => s.is_infix_op(),
39023901
Left(s) => s.is_infix_op(),
39033902
LikeEscape(s) => s.is_infix_op(),
3904-
ListLengthMax { .. } => false,
3903+
ListLengthMax(s) => s.is_infix_op(),
39053904
ListRemove(s) => s.is_infix_op(),
39063905
LogNumeric(s) => s.is_infix_op(),
39073906
MzAclItemContainsPrivilege(s) => s.is_infix_op(),
@@ -4045,7 +4044,7 @@ impl BinaryFunc {
40454044
BinaryFunc::LikeEscape(s) => s.negate(),
40464045
BinaryFunc::ListContainsList { .. } => None,
40474046
BinaryFunc::ListElementConcat(s) => s.negate(),
4048-
BinaryFunc::ListLengthMax { .. } => None,
4047+
BinaryFunc::ListLengthMax(s) => s.negate(),
40494048
BinaryFunc::ListListConcat(s) => s.negate(),
40504049
BinaryFunc::ListRemove(s) => s.negate(),
40514050
BinaryFunc::LogNumeric(s) => s.negate(),
@@ -4310,7 +4309,7 @@ impl BinaryFunc {
43104309
BinaryFunc::RepeatString => true,
43114310
BinaryFunc::Normalize => true,
43124311
BinaryFunc::EncodedBytesCharLength(s) => s.could_error(),
4313-
BinaryFunc::ListLengthMax { .. } => true,
4312+
BinaryFunc::ListLengthMax(s) => s.could_error(),
43144313
BinaryFunc::ArrayLength(s) => s.could_error(),
43154314
BinaryFunc::ArrayRemove(s) => s.could_error(),
43164315
BinaryFunc::ArrayUpper(s) => s.could_error(),
@@ -4510,7 +4509,7 @@ impl BinaryFunc {
45104509
BinaryFunc::TrimLeading(s) => s.is_monotone(),
45114510
BinaryFunc::TrimTrailing(s) => s.is_monotone(),
45124511
BinaryFunc::EncodedBytesCharLength(s) => s.is_monotone(),
4513-
BinaryFunc::ListLengthMax { .. } => (false, false),
4512+
BinaryFunc::ListLengthMax(s) => s.is_monotone(),
45144513
BinaryFunc::ArrayContains(s) => s.is_monotone(),
45154514
BinaryFunc::ArrayContainsArray { .. } => (false, false),
45164515
BinaryFunc::ArrayLength(s) => s.is_monotone(),
@@ -4718,7 +4717,7 @@ impl fmt::Display for BinaryFunc {
47184717
BinaryFunc::TrimLeading(s) => s.fmt(f),
47194718
BinaryFunc::TrimTrailing(s) => s.fmt(f),
47204719
BinaryFunc::EncodedBytesCharLength(s) => s.fmt(f),
4721-
BinaryFunc::ListLengthMax { .. } => f.write_str("list_length_max"),
4720+
BinaryFunc::ListLengthMax(s) => s.fmt(f),
47224721
BinaryFunc::ArrayContains(s) => s.fmt(f),
47234722
BinaryFunc::ArrayContainsArray { rev } => f.write_str(if *rev { "<@" } else { "@>" }),
47244723
BinaryFunc::ArrayLength(s) => s.fmt(f),
@@ -5226,48 +5225,6 @@ fn array_upper<'a>(a: Array<'a>, i: i64) -> Result<Option<i32>, EvalError> {
52265225
.transpose()
52275226
}
52285227

5229-
// TODO(benesch): remove potentially dangerous usage of `as`.
5230-
#[allow(clippy::as_conversions)]
5231-
fn list_length_max<'a>(
5232-
a: Datum<'a>,
5233-
b: Datum<'a>,
5234-
max_layer: usize,
5235-
) -> Result<Datum<'a>, EvalError> {
5236-
fn max_len_on_layer<'a>(d: Datum<'a>, on_layer: i64) -> Option<usize> {
5237-
match d {
5238-
Datum::List(i) => {
5239-
let mut i = i.iter();
5240-
if on_layer > 1 {
5241-
let mut max_len = None;
5242-
while let Some(Datum::List(i)) = i.next() {
5243-
max_len =
5244-
std::cmp::max(max_len_on_layer(Datum::List(i), on_layer - 1), max_len);
5245-
}
5246-
max_len
5247-
} else {
5248-
Some(i.count())
5249-
}
5250-
}
5251-
Datum::Null => None,
5252-
_ => unreachable!(),
5253-
}
5254-
}
5255-
5256-
let b = b.unwrap_int64();
5257-
5258-
if b as usize > max_layer || b < 1 {
5259-
Err(EvalError::InvalidLayer { max_layer, val: b })
5260-
} else {
5261-
match max_len_on_layer(a, b) {
5262-
Some(l) => match l.try_into() {
5263-
Ok(c) => Ok(Datum::Int32(c)),
5264-
Err(_) => Err(EvalError::Int32OutOfRange(l.to_string().into())),
5265-
},
5266-
None => Ok(Datum::Null),
5267-
}
5268-
}
5269-
}
5270-
52715228
#[sqlfunc(
52725229
is_infix_op = true,
52735230
sqlname = "array_contains",

src/expr/src/scalar/func/binary.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ mod derive {
291291
LikeEscape,
292292
// ListContainsList
293293
ListElementConcat,
294-
// ListLengthMax
294+
ListLengthMax,
295295
ListListConcat,
296296
ListRemove,
297297
LogNumeric(LogBaseNumeric),

src/expr/src/scalar/func/impls/list.rs

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
use std::fmt;
1111

1212
use mz_lowertest::MzReflect;
13-
use mz_repr::{Datum, Row, RowArena, SqlColumnType, SqlScalarType};
13+
use mz_repr::{Datum, DatumList, Row, RowArena, SqlColumnType, SqlScalarType};
1414
use serde::{Deserialize, Serialize};
1515

1616
use crate::scalar::func::{LazyUnaryFunc, stringify_datum};
@@ -249,3 +249,74 @@ impl fmt::Display for ListLength {
249249
f.write_str("list_length")
250250
}
251251
}
252+
253+
/// The `list_length_max` implementation.
254+
#[derive(
255+
Ord,
256+
PartialOrd,
257+
Clone,
258+
Debug,
259+
Eq,
260+
PartialEq,
261+
serde::Serialize,
262+
serde::Deserialize,
263+
Hash,
264+
mz_lowertest::MzReflect,
265+
)]
266+
pub struct ListLengthMax {
267+
/// Maximal allowed layer to query.
268+
pub max_layer: usize,
269+
}
270+
impl<'a> crate::func::binary::EagerBinaryFunc<'a> for ListLengthMax {
271+
type Input1 = DatumList<'a>;
272+
type Input2 = i64;
273+
type Output = Result<Option<i32>, EvalError>;
274+
// TODO(benesch): remove potentially dangerous usage of `as`.
275+
#[allow(clippy::as_conversions)]
276+
fn call(&self, a: Self::Input1, b: Self::Input2, _: &'a RowArena) -> Self::Output {
277+
fn max_len_on_layer<'a>(i: DatumList<'a>, on_layer: i64) -> Option<usize> {
278+
let mut i = i.iter();
279+
if on_layer > 1 {
280+
let mut max_len = None;
281+
while let Some(Datum::List(i)) = i.next() {
282+
max_len = std::cmp::max(max_len_on_layer(i, on_layer - 1), max_len);
283+
}
284+
max_len
285+
} else {
286+
Some(i.count())
287+
}
288+
}
289+
if b as usize > self.max_layer || b < 1 {
290+
Err(EvalError::InvalidLayer {
291+
max_layer: self.max_layer,
292+
val: b,
293+
})
294+
} else {
295+
match max_len_on_layer(a, b) {
296+
Some(l) => match l.try_into() {
297+
Ok(c) => Ok(Some(c)),
298+
Err(_) => Err(EvalError::Int32OutOfRange(l.to_string().into())),
299+
},
300+
None => Ok(None),
301+
}
302+
}
303+
}
304+
fn output_type(
305+
&self,
306+
input_type_a: SqlColumnType,
307+
input_type_b: SqlColumnType,
308+
) -> SqlColumnType {
309+
use mz_repr::AsColumnType;
310+
let output = Self::Output::as_column_type();
311+
let propagates_nulls = crate::func::binary::EagerBinaryFunc::propagates_nulls(self);
312+
let nullable = output.nullable;
313+
output.nullable(
314+
nullable || (propagates_nulls && (input_type_a.nullable || input_type_b.nullable)),
315+
)
316+
}
317+
}
318+
impl fmt::Display for ListLengthMax {
319+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
320+
f.write_str("list_length_max")
321+
}
322+
}

src/sql/src/func.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3753,7 +3753,7 @@ pub static MZ_CATALOG_BUILTINS: LazyLock<BTreeMap<&'static str, Func>> = LazyLoc
37533753
vec![ListAny, Plain(SqlScalarType::Int64)] => Operation::binary(|ecx, lhs, rhs| {
37543754
ecx.require_feature_flag(&crate::session::vars::ENABLE_LIST_LENGTH_MAX)?;
37553755
let max_layer = ecx.scalar_type(&lhs).unwrap_list_n_layers();
3756-
Ok(lhs.call_binary(rhs, BinaryFunc::ListLengthMax { max_layer }))
3756+
Ok(lhs.call_binary(rhs, BinaryFunc::from(func::ListLengthMax { max_layer })))
37573757
}) => Int32, oid::FUNC_LIST_LENGTH_MAX_OID;
37583758
},
37593759
"list_prepend" => Scalar {

0 commit comments

Comments
 (0)