From 8f9dbc59aa3af4ba4147a88c65f9af72071b6bc9 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Thu, 4 Dec 2025 07:29:08 -0500 Subject: [PATCH 01/37] Combine ScalarFnVTable and ExprVTable (#5616) We proved the idea out with ScalarFns, but realised we can minimize the diff for end-users if we actually just re-purpose an expression to == a scalar function. It pretty much was already anyway. This PR is a step on the way to having Expression -> Array be a constant-time operation, with actual compute deferred until we execute the array tree. # Breaking Changes * Expression VTable changes * `Instance` renamed to `Options` to make the purpose more obvious * Added an `Arity` function to replace `validate`, since the only possible validation in the absence of types is arity based anyway. * Added an optional `execute` function that will replace `evaluate` in the future. * Removed `fmt_data` in favor of requiring a Display bound on the associated Options type. * `return_dtype` takes a `ReturnDTypeCtx` to avoid eagerly computing child dtypes. * Removed ExpressionView in favor of returning the options from `Expression::as_` and `Expression::as_opt`. Other changes that I don't expect will impact users: * Moved `is_null_sensitive` and `child_name` from `Expression`, to `ExpressionSignature` * Moved `Expression::serialize_metadata()` to `Expression::options().serialize()` * Removed ExprOptimizer, in favor of `VTable::simplify`. Any more complex optimizations can be implemented in future over the Array tree prior to execution. --------- Signed-off-by: Nicholas Gates --- Cargo.lock | 1 + vortex-array/src/arrays/scalar_fn/array.rs | 12 +- vortex-array/src/arrays/scalar_fn/metadata.rs | 4 +- .../src/arrays/scalar_fn/vtable/array.rs | 4 +- .../src/arrays/scalar_fn/vtable/canonical.rs | 13 +- .../src/arrays/scalar_fn/vtable/mod.rs | 61 +-- .../src/arrays/scalar_fn/vtable/operations.rs | 20 +- .../src/arrays/scalar_fn/vtable/validity.rs | 33 +- .../src/arrays/scalar_fn/vtable/visitor.rs | 2 +- vortex-array/src/compute/between.rs | 18 + vortex-array/src/compute/like.rs | 15 + vortex-array/src/expr/analysis/fallible.rs | 6 +- .../src/expr/analysis/immediate_access.rs | 6 +- .../src/expr/analysis/null_sensitive.rs | 2 +- vortex-array/src/expr/display.rs | 193 ++++---- vortex-array/src/expr/expression.rs | 251 +++------- vortex-array/src/expr/exprs/between.rs | 112 +++-- vortex-array/src/expr/exprs/binary.rs | 188 +++----- .../src/expr/exprs/{cast/mod.rs => cast.rs} | 81 ++-- vortex-array/src/expr/exprs/dynamic.rs | 116 ++--- .../exprs/{get_item/mod.rs => get_item.rs} | 216 ++++++--- .../src/expr/exprs/get_item/transform.rs | 139 ------ vortex-array/src/expr/exprs/is_null.rs | 82 ++-- vortex-array/src/expr/exprs/like.rs | 58 ++- vortex-array/src/expr/exprs/list_contains.rs | 84 ++-- vortex-array/src/expr/exprs/literal.rs | 76 ++-- vortex-array/src/expr/exprs/mask.rs | 122 +++++ .../src/expr/exprs/{merge/mod.rs => merge.rs} | 149 +++++- .../src/expr/exprs/merge/transform.rs | 125 ----- vortex-array/src/expr/exprs/mod.rs | 4 +- vortex-array/src/expr/exprs/not.rs | 64 ++- vortex-array/src/expr/exprs/pack.rs | 115 ++--- vortex-array/src/expr/exprs/root.rs | 54 ++- vortex-array/src/expr/exprs/scalar_fn.rs | 185 -------- .../expr/exprs/{select/mod.rs => select.rs} | 227 +++++---- .../src/expr/exprs/select/transform.rs | 118 ----- .../src/expr/forms/extract_conjuncts.rs | 8 +- vortex-array/src/expr/functions/execution.rs | 60 --- vortex-array/src/expr/functions/mod.rs | 15 - vortex-array/src/expr/functions/scalar.rs | 223 --------- vortex-array/src/expr/functions/session.rs | 26 -- vortex-array/src/expr/functions/vtable.rs | 356 --------------- vortex-array/src/expr/mod.rs | 20 +- vortex-array/src/expr/options.rs | 58 +++ vortex-array/src/expr/proto.rs | 8 +- vortex-array/src/expr/scalar_fn.rs | 159 +++++++ vortex-array/src/expr/session/mod.rs | 112 +---- vortex-array/src/expr/signature.rs | 38 ++ vortex-array/src/expr/simplify.rs | 100 ++++ .../src/expr/transform/match_between.rs | 52 ++- vortex-array/src/expr/transform/mod.rs | 5 - vortex-array/src/expr/transform/optimizer.rs | 36 -- vortex-array/src/expr/transform/partition.rs | 60 +-- vortex-array/src/expr/transform/rules.rs | 181 -------- vortex-array/src/expr/transform/simplify.rs | 247 ---------- .../src/expr/transform/simplify_typed.rs | 129 ------ vortex-array/src/expr/traversal/fold.rs | 18 +- vortex-array/src/expr/traversal/mod.rs | 8 +- vortex-array/src/expr/traversal/references.rs | 9 +- vortex-array/src/expr/view.rs | 63 --- vortex-array/src/expr/vtable.rs | 430 +++++++++++------- vortex-array/src/scalar_fns/cast/array.rs | 59 --- vortex-array/src/scalar_fns/cast/mod.rs | 108 ----- vortex-array/src/scalar_fns/get_item/mod.rs | 115 ----- vortex-array/src/scalar_fns/is_null/mod.rs | 57 --- vortex-array/src/scalar_fns/mask/mod.rs | 84 ---- vortex-array/src/scalar_fns/mod.rs | 35 +- vortex-array/src/scalar_fns/not/mod.rs | 65 --- vortex-array/src/session/mod.rs | 4 - vortex-compute/src/logical/not.rs | 13 + vortex-datafusion/src/convert/exprs.rs | 36 +- vortex-duckdb/src/convert/expr.rs | 1 - vortex-layout/Cargo.toml | 1 + vortex-layout/src/layouts/row_idx/expr.rs | 43 +- vortex-layout/src/layouts/row_idx/mod.rs | 34 +- vortex-layout/src/layouts/struct_/reader.rs | 16 +- vortex-scan/src/scan_builder.rs | 18 +- 77 files changed, 2135 insertions(+), 3931 deletions(-) rename vortex-array/src/expr/exprs/{cast/mod.rs => cast.rs} (75%) rename vortex-array/src/expr/exprs/{get_item/mod.rs => get_item.rs} (51%) delete mode 100644 vortex-array/src/expr/exprs/get_item/transform.rs create mode 100644 vortex-array/src/expr/exprs/mask.rs rename vortex-array/src/expr/exprs/{merge/mod.rs => merge.rs} (76%) delete mode 100644 vortex-array/src/expr/exprs/merge/transform.rs delete mode 100644 vortex-array/src/expr/exprs/scalar_fn.rs rename vortex-array/src/expr/exprs/{select/mod.rs => select.rs} (70%) delete mode 100644 vortex-array/src/expr/exprs/select/transform.rs delete mode 100644 vortex-array/src/expr/functions/execution.rs delete mode 100644 vortex-array/src/expr/functions/mod.rs delete mode 100644 vortex-array/src/expr/functions/scalar.rs delete mode 100644 vortex-array/src/expr/functions/session.rs delete mode 100644 vortex-array/src/expr/functions/vtable.rs create mode 100644 vortex-array/src/expr/options.rs create mode 100644 vortex-array/src/expr/scalar_fn.rs create mode 100644 vortex-array/src/expr/signature.rs create mode 100644 vortex-array/src/expr/simplify.rs delete mode 100644 vortex-array/src/expr/transform/optimizer.rs delete mode 100644 vortex-array/src/expr/transform/rules.rs delete mode 100644 vortex-array/src/expr/transform/simplify.rs delete mode 100644 vortex-array/src/expr/transform/simplify_typed.rs delete mode 100644 vortex-array/src/expr/view.rs delete mode 100644 vortex-array/src/scalar_fns/cast/array.rs delete mode 100644 vortex-array/src/scalar_fns/cast/mod.rs delete mode 100644 vortex-array/src/scalar_fns/get_item/mod.rs delete mode 100644 vortex-array/src/scalar_fns/is_null/mod.rs delete mode 100644 vortex-array/src/scalar_fns/mask/mod.rs delete mode 100644 vortex-array/src/scalar_fns/not/mod.rs diff --git a/Cargo.lock b/Cargo.lock index 5c6e2b009f1..bed8e6d3b44 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8930,6 +8930,7 @@ dependencies = [ "vortex-sequence", "vortex-session", "vortex-utils", + "vortex-vector", "vortex-zstd", ] diff --git a/vortex-array/src/arrays/scalar_fn/array.rs b/vortex-array/src/arrays/scalar_fn/array.rs index ff4fe9f141d..580cad610f0 100644 --- a/vortex-array/src/arrays/scalar_fn/array.rs +++ b/vortex-array/src/arrays/scalar_fn/array.rs @@ -8,7 +8,7 @@ use vortex_error::vortex_ensure; use crate::Array; use crate::ArrayRef; use crate::arrays::ScalarFnVTable; -use crate::expr::functions::scalar::ScalarFn; +use crate::expr::ScalarFn; use crate::stats::ArrayStats; use crate::vtable::ArrayVTable; use crate::vtable::ArrayVTableExt; @@ -17,7 +17,7 @@ use crate::vtable::ArrayVTableExt; pub struct ScalarFnArray { // NOTE(ngates): we should fix vtables so we don't have to hold this pub(super) vtable: ArrayVTable, - pub(super) scalar_fn: ScalarFn, + pub(super) bound: ScalarFn, pub(super) dtype: DType, pub(super) len: usize, pub(super) children: Vec, @@ -26,9 +26,9 @@ pub struct ScalarFnArray { impl ScalarFnArray { /// Create a new ScalarFnArray from a scalar function and its children. - pub fn try_new(scalar_fn: ScalarFn, children: Vec, len: usize) -> VortexResult { + pub fn try_new(bound: ScalarFn, children: Vec, len: usize) -> VortexResult { let arg_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect(); - let dtype = scalar_fn.return_dtype(&arg_dtypes)?; + let dtype = bound.return_dtype(&arg_dtypes)?; vortex_ensure!( children.iter().all(|c| c.len() == len), @@ -36,8 +36,8 @@ impl ScalarFnArray { ); Ok(Self { - vtable: ScalarFnVTable::new(scalar_fn.vtable().clone()).into_vtable(), - scalar_fn, + vtable: ScalarFnVTable::new(bound.vtable().clone()).into_vtable(), + bound, dtype, len, children, diff --git a/vortex-array/src/arrays/scalar_fn/metadata.rs b/vortex-array/src/arrays/scalar_fn/metadata.rs index 1f458fa83e1..0f65ac6ba9c 100644 --- a/vortex-array/src/arrays/scalar_fn/metadata.rs +++ b/vortex-array/src/arrays/scalar_fn/metadata.rs @@ -3,10 +3,10 @@ use vortex_dtype::DType; -use crate::expr::functions::scalar::ScalarFn; +use crate::expr::ScalarFn; #[derive(Clone, Debug)] pub struct ScalarFnMetadata { - pub(super) scalar_fn: ScalarFn, + pub(super) bound: ScalarFn, pub(super) child_dtypes: Vec, } diff --git a/vortex-array/src/arrays/scalar_fn/vtable/array.rs b/vortex-array/src/arrays/scalar_fn/vtable/array.rs index 2da705e808a..516aebb943e 100644 --- a/vortex-array/src/arrays/scalar_fn/vtable/array.rs +++ b/vortex-array/src/arrays/scalar_fn/vtable/array.rs @@ -30,7 +30,7 @@ impl BaseArrayVTable for ScalarFnVTable { fn array_hash(array: &ScalarFnArray, state: &mut H, precision: Precision) { array.len.hash(state); array.dtype.hash(state); - array.scalar_fn.hash(state); + array.bound.hash(state); for child in &array.children { child.array_hash(state, precision); } @@ -43,7 +43,7 @@ impl BaseArrayVTable for ScalarFnVTable { if array.dtype != other.dtype { return false; } - if array.scalar_fn != other.scalar_fn { + if array.bound != other.bound { return false; } for (child, other_child) in array.children.iter().zip(other.children.iter()) { diff --git a/vortex-array/src/arrays/scalar_fn/vtable/canonical.rs b/vortex-array/src/arrays/scalar_fn/vtable/canonical.rs index f1222998ab9..4ae3f46658b 100644 --- a/vortex-array/src/arrays/scalar_fn/vtable/canonical.rs +++ b/vortex-array/src/arrays/scalar_fn/vtable/canonical.rs @@ -10,7 +10,7 @@ use crate::Canonical; use crate::arrays::scalar_fn::array::ScalarFnArray; use crate::arrays::scalar_fn::vtable::SCALAR_FN_SESSION; use crate::arrays::scalar_fn::vtable::ScalarFnVTable; -use crate::expr::functions::ExecutionArgs; +use crate::expr::ExecutionArgs; use crate::vectors::VectorIntoArray; use crate::vtable::CanonicalVTable; @@ -28,11 +28,16 @@ impl CanonicalVTable for ScalarFnVTable { "Failed to execute child array during canonicalization of ScalarFnArray", ); - let ctx = ExecutionArgs::new(array.len, array.dtype.clone(), child_dtypes, child_datums); + let ctx = ExecutionArgs { + datums: child_datums, + dtypes: child_dtypes, + row_count: array.len, + return_dtype: array.dtype.clone(), + }; let result_vector = array - .scalar_fn - .execute(&ctx) + .bound + .execute(ctx) .vortex_expect("Canonicalize should be fallible") .into_vector() .vortex_expect("Canonicalize should return a vector"); diff --git a/vortex-array/src/arrays/scalar_fn/vtable/mod.rs b/vortex-array/src/arrays/scalar_fn/vtable/mod.rs index 37a5cf643e8..b8359873eee 100644 --- a/vortex-array/src/arrays/scalar_fn/vtable/mod.rs +++ b/vortex-array/src/arrays/scalar_fn/vtable/mod.rs @@ -19,6 +19,7 @@ use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_ensure; use vortex_session::VortexSession; +use vortex_vector::Datum; use vortex_vector::Vector; use crate::Array; @@ -27,8 +28,10 @@ use crate::IntoArray; use crate::arrays::scalar_fn::array::ScalarFnArray; use crate::arrays::scalar_fn::metadata::ScalarFnMetadata; use crate::execution::ExecutionCtx; -use crate::expr::functions; -use crate::expr::functions::scalar::ScalarFn; +use crate::expr; +use crate::expr::ExecutionArgs; +use crate::expr::ExprVTable; +use crate::expr::ScalarFn; use crate::optimizer::rules::MatchKey; use crate::optimizer::rules::Matcher; use crate::serde::ArrayChildren; @@ -50,11 +53,11 @@ vtable!(ScalarFn); #[derive(Clone, Debug)] pub struct ScalarFnVTable { - vtable: functions::ScalarFnVTable, + vtable: ExprVTable, } impl ScalarFnVTable { - pub fn new(vtable: functions::ScalarFnVTable) -> Self { + pub fn new(vtable: ExprVTable) -> Self { Self { vtable } } } @@ -81,7 +84,7 @@ impl VTable for ScalarFnVTable { fn metadata(array: &Self::Array) -> VortexResult { let child_dtypes = array.children().iter().map(|c| c.dtype().clone()).collect(); Ok(ScalarFnMetadata { - scalar_fn: array.scalar_fn.clone(), + bound: array.bound.clone(), child_dtypes, }) } @@ -114,7 +117,7 @@ impl VTable for ScalarFnVTable { { let child_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect(); vortex_error::vortex_ensure!( - &metadata.scalar_fn.return_dtype(&child_dtypes)? == dtype, + &metadata.bound.return_dtype(&child_dtypes)? == dtype, "Return dtype mismatch when building ScalarFnArray" ); } @@ -122,7 +125,7 @@ impl VTable for ScalarFnVTable { Ok(ScalarFnArray { // This requires a new Arc, but we plan to remove this later anyway. vtable: self.to_vtable(), - scalar_fn: metadata.scalar_fn.clone(), + bound: metadata.bound.clone(), dtype: dtype.clone(), len, children, @@ -135,31 +138,31 @@ impl VTable for ScalarFnVTable { let input_datums = array .children() .iter() - .map(|child| child.batch_execute(ctx)) + .map(|child| child.batch_execute(ctx).map(Datum::Vector)) .try_collect()?; - let ctx = functions::ExecutionArgs::new( - array.len(), - array.dtype.clone(), - input_dtypes, - input_datums, - ); + let ctx = ExecutionArgs { + datums: input_datums, + dtypes: input_dtypes, + row_count: array.len, + return_dtype: array.dtype.clone(), + }; Ok(array - .scalar_fn - .execute(&ctx)? + .bound + .execute(ctx)? .into_vector() .vortex_expect("Vector inputs should return vector outputs")) } } /// Array factory functions for scalar functions. -pub trait ScalarFnArrayExt: functions::VTable { +pub trait ScalarFnArrayExt: expr::VTable { fn try_new_array( &'static self, len: usize, options: Self::Options, children: impl Into>, ) -> VortexResult { - let scalar_fn = ScalarFn::new_static(self, options); + let bound = ScalarFn::new_static(self, options); let children = children.into(); vortex_ensure!( @@ -168,16 +171,16 @@ pub trait ScalarFnArrayExt: functions::VTable { ); let child_dtypes = children.iter().map(|c| c.dtype().clone()).collect_vec(); - let dtype = scalar_fn.return_dtype(&child_dtypes)?; + let dtype = bound.return_dtype(&child_dtypes)?; let array_vtable: ArrayVTable = ScalarFnVTable { - vtable: scalar_fn.vtable().clone(), + vtable: bound.vtable().clone(), } .into_vtable(); Ok(ScalarFnArray { vtable: array_vtable, - scalar_fn, + bound, dtype, len, children, @@ -186,7 +189,7 @@ pub trait ScalarFnArrayExt: functions::VTable { .into_array()) } } -impl ScalarFnArrayExt for V {} +impl ScalarFnArrayExt for V {} /// A matcher that matches any scalar function expression. #[derive(Debug)] @@ -205,12 +208,12 @@ impl Matcher for AnyScalarFn { /// A matcher that matches a specific scalar function expression. #[derive(Debug)] -pub struct ExactScalarFn { +pub struct ExactScalarFn { id: ArrayId, _phantom: PhantomData, } -impl From<&'static F> for ExactScalarFn { +impl From<&'static F> for ExactScalarFn { fn from(value: &'static F) -> Self { Self { id: value.id(), @@ -219,7 +222,7 @@ impl From<&'static F> for ExactScalarFn { } } -impl Matcher for ExactScalarFn { +impl Matcher for ExactScalarFn { type View<'a> = ScalarFnArrayView<'a, F>; fn key(&self) -> MatchKey { @@ -229,12 +232,12 @@ impl Matcher for ExactScalarFn { fn try_match<'a>(&self, array: &'a ArrayRef) -> Option> { let scalar_fn_array = array.as_opt::()?; let scalar_fn_vtable = scalar_fn_array - .scalar_fn + .bound .vtable() .as_any() .downcast_ref::()?; let scalar_fn_options = scalar_fn_array - .scalar_fn + .bound .options() .as_any() .downcast_ref::()?; @@ -246,13 +249,13 @@ impl Matcher for ExactScalarFn { } } -pub struct ScalarFnArrayView<'a, F: functions::VTable> { +pub struct ScalarFnArrayView<'a, F: expr::VTable> { array: &'a ArrayRef, pub vtable: &'a F, pub options: &'a F::Options, } -impl Deref for ScalarFnArrayView<'_, F> { +impl Deref for ScalarFnArrayView<'_, F> { type Target = ArrayRef; fn deref(&self) -> &Self::Target { diff --git a/vortex-array/src/arrays/scalar_fn/vtable/operations.rs b/vortex-array/src/arrays/scalar_fn/vtable/operations.rs index fd7d5255667..0201d84d036 100644 --- a/vortex-array/src/arrays/scalar_fn/vtable/operations.rs +++ b/vortex-array/src/arrays/scalar_fn/vtable/operations.rs @@ -11,7 +11,7 @@ use crate::ArrayRef; use crate::IntoArray; use crate::arrays::scalar_fn::array::ScalarFnArray; use crate::arrays::scalar_fn::vtable::ScalarFnVTable; -use crate::expr::functions::ExecutionArgs; +use crate::expr::ExecutionArgs; use crate::vtable::OperationsVTable; impl OperationsVTable for ScalarFnVTable { @@ -24,7 +24,7 @@ impl OperationsVTable for ScalarFnVTable { ScalarFnArray { vtable: array.vtable.clone(), - scalar_fn: array.scalar_fn.clone(), + bound: array.bound.clone(), dtype: array.dtype.clone(), len: range.len(), children, @@ -42,16 +42,16 @@ impl OperationsVTable for ScalarFnVTable { .map(|scalar| Datum::from(scalar.to_vector_scalar())) .collect(); - let ctx = ExecutionArgs::new( - 1, - array.dtype.clone(), - array.children().iter().map(|s| s.dtype().clone()).collect(), - input_datums, - ); + let ctx = ExecutionArgs { + datums: input_datums, + dtypes: array.children().iter().map(|c| c.dtype().clone()).collect(), + row_count: 1, + return_dtype: array.dtype.clone(), + }; let _result = array - .scalar_fn - .execute(&ctx) + .bound + .execute(ctx) .vortex_expect("Scalar function execution should be fallible") .into_scalar() .vortex_expect("Scalar function execution should return scalar"); diff --git a/vortex-array/src/arrays/scalar_fn/vtable/validity.rs b/vortex-array/src/arrays/scalar_fn/vtable/validity.rs index eb3fada3d50..4de38bdbe44 100644 --- a/vortex-array/src/arrays/scalar_fn/vtable/validity.rs +++ b/vortex-array/src/arrays/scalar_fn/vtable/validity.rs @@ -8,7 +8,6 @@ use crate::Array; use crate::arrays::scalar_fn::array::ScalarFnArray; use crate::arrays::scalar_fn::vtable::SCALAR_FN_SESSION; use crate::arrays::scalar_fn::vtable::ScalarFnVTable; -use crate::expr::functions::NullHandling; use crate::vtable::ValidityVTable; impl ValidityVTable for ScalarFnVTable { @@ -17,28 +16,32 @@ impl ValidityVTable for ScalarFnVTable { } fn all_valid(array: &ScalarFnArray) -> bool { - match array.scalar_fn.signature().null_handling() { - NullHandling::Propagate | NullHandling::AbsorbsNull => { - // Requires all children to guarantee all_valid - array.children().iter().all(|child| child.all_valid()) - } - NullHandling::Custom => { - // We cannot guarantee that the array is all valid without evaluating the function + match array.bound.signature().is_null_sensitive() { + true => { + // If the function is null sensitive, we cannot guarantee all valid without evaluating + // the function false } + false => { + // If the function is not null sensitive, we can guarantee all valid if all children + // are all valid + array.children().iter().all(|child| child.all_valid()) + } } } fn all_invalid(array: &ScalarFnArray) -> bool { - match array.scalar_fn.signature().null_handling() { - NullHandling::Propagate => { - // All null if any child is all null - array.children().iter().any(|child| child.all_invalid()) - } - NullHandling::AbsorbsNull | NullHandling::Custom => { - // We cannot guarantee that the array is all valid without evaluating the function + match array.bound.signature().is_null_sensitive() { + true => { + // If the function is null sensitive, we cannot guarantee all invalid without evaluating + // the function false } + false => { + // If the function is not null sensitive, we can guarantee all invalid if any child + // is all invalid + array.children().iter().any(|child| child.all_invalid()) + } } } diff --git a/vortex-array/src/arrays/scalar_fn/vtable/visitor.rs b/vortex-array/src/arrays/scalar_fn/vtable/visitor.rs index e43b1049068..4c40e5f42bd 100644 --- a/vortex-array/src/arrays/scalar_fn/vtable/visitor.rs +++ b/vortex-array/src/arrays/scalar_fn/vtable/visitor.rs @@ -12,7 +12,7 @@ impl VisitorVTable for ScalarFnVTable { fn visit_children(array: &ScalarFnArray, visitor: &mut dyn ArrayChildVisitor) { for (idx, child) in array.children.iter().enumerate() { - let name = array.scalar_fn.signature().arg_name(idx); + let name = array.bound.signature().child_name(idx); visitor.visit_child(name.as_ref(), child.as_ref()) } } diff --git a/vortex-array/src/compute/between.rs b/vortex-array/src/compute/between.rs index ee0e8409ee6..28679bfdf24 100644 --- a/vortex-array/src/compute/between.rs +++ b/vortex-array/src/compute/between.rs @@ -2,6 +2,8 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use std::any::Any; +use std::fmt::Display; +use std::fmt::Formatter; use std::sync::LazyLock; use arcref::ArcRef; @@ -253,6 +255,22 @@ pub struct BetweenOptions { pub upper_strict: StrictComparison, } +impl Display for BetweenOptions { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let lower_op = if self.lower_strict.is_strict() { + "<" + } else { + "<=" + }; + let upper_op = if self.upper_strict.is_strict() { + "<" + } else { + "<=" + }; + write!(f, "lower_strict: {}, upper_strict: {}", lower_op, upper_op) + } +} + impl Options for BetweenOptions { fn as_any(&self) -> &dyn Any { self diff --git a/vortex-array/src/compute/like.rs b/vortex-array/src/compute/like.rs index fc352d2bd8e..fca5e6ad12f 100644 --- a/vortex-array/src/compute/like.rs +++ b/vortex-array/src/compute/like.rs @@ -2,6 +2,8 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use std::any::Any; +use std::fmt::Display; +use std::fmt::Formatter; use std::sync::LazyLock; use arcref::ArcRef; @@ -150,6 +152,19 @@ pub struct LikeOptions { pub case_insensitive: bool, } +impl Display for LikeOptions { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + if self.negated { + write!(f, "NOT ")?; + } + if self.case_insensitive { + write!(f, "ILIKE") + } else { + write!(f, "LIKE") + } + } +} + impl Options for LikeOptions { fn as_any(&self) -> &dyn Any { self diff --git a/vortex-array/src/expr/analysis/fallible.rs b/vortex-array/src/expr/analysis/fallible.rs index a7cfb2835df..c1ffb2bffda 100644 --- a/vortex-array/src/expr/analysis/fallible.rs +++ b/vortex-array/src/expr/analysis/fallible.rs @@ -6,7 +6,11 @@ use crate::expr::analysis::BooleanLabels; use crate::expr::label_tree; pub fn label_is_fallible(expr: &Expression) -> BooleanLabels<'_> { - label_tree(expr, |expr| expr.is_fallible(), |acc, &child| acc | child) + label_tree( + expr, + |expr| expr.signature().is_fallible(), + |acc, &child| acc | child, + ) } #[cfg(test)] diff --git a/vortex-array/src/expr/analysis/immediate_access.rs b/vortex-array/src/expr/analysis/immediate_access.rs index c05679b2c98..6dc26a7c950 100644 --- a/vortex-array/src/expr/analysis/immediate_access.rs +++ b/vortex-array/src/expr/analysis/immediate_access.rs @@ -24,9 +24,9 @@ pub fn annotate_scope_access(scope: &StructFields) -> impl AnnotationFn() { - if get_item.child(0).is::() { - return vec![get_item.data().clone()]; + if let Some(field_name) = expr.as_opt::() { + if expr.child(0).is::() { + return vec![field_name.clone()]; } } else if expr.is::() { return scope.names().iter().cloned().collect(); diff --git a/vortex-array/src/expr/analysis/null_sensitive.rs b/vortex-array/src/expr/analysis/null_sensitive.rs index 1711792bf9e..5e1c4d0c996 100644 --- a/vortex-array/src/expr/analysis/null_sensitive.rs +++ b/vortex-array/src/expr/analysis/null_sensitive.rs @@ -15,7 +15,7 @@ pub type BooleanLabels<'a> = HashMap<&'a Expression, bool>; pub fn label_null_sensitive(expr: &Expression) -> BooleanLabels<'_> { label_tree( expr, - |expr| expr.is_null_sensitive(), + |expr| expr.signature().is_null_sensitive(), |acc, &child| acc | child, ) } diff --git a/vortex-array/src/expr/display.rs b/vortex-array/src/expr/display.rs index abc5460d429..c1afcd8d4b8 100644 --- a/vortex-array/src/expr/display.rs +++ b/vortex-array/src/expr/display.rs @@ -3,8 +3,10 @@ use std::fmt::Display; use std::fmt::Formatter; +use std::ops::Deref; use crate::expr::Expression; +use crate::expr::ScalarFn; pub enum DisplayFormat { Compact, @@ -17,10 +19,11 @@ impl Display for DisplayTreeExpr<'_> { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { pub use termtree::Tree; fn make_tree(expr: &Expression) -> Result, std::fmt::Error> { - let node_name = format!("{}", ExpressionDebug(expr)); + let bound: &ScalarFn = expr.deref(); + let node_name = format!("{}", bound); // Get child names for display purposes - let child_names = (0..expr.children().len()).map(|i| expr.child_name(i)); + let child_names = (0..expr.children().len()).map(|i| expr.signature().child_name(i)); let children = expr.children(); let child_trees: Result>, _> = children @@ -40,18 +43,6 @@ impl Display for DisplayTreeExpr<'_> { } } -struct ExpressionDebug<'a>(&'a Expression); -impl Display for ExpressionDebug<'_> { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - // Special-case when expression has no data to avoid trailing space. - if self.0.data().is::<()>() { - return write!(f, "{}", self.0.id().as_ref()); - } - write!(f, "{} ", self.0.id().as_ref())?; - self.0.vtable().as_dyn().fmt_data(self.0.data().as_ref(), f) - } -} - #[cfg(test)] mod tests { use vortex_dtype::DType; @@ -104,76 +95,111 @@ mod tests { } #[test] - fn test_display_tree() { + fn test_display_tree_root() { use insta::assert_snapshot; - let root_expr = root(); - assert_snapshot!(root_expr.display_tree().to_string(), @"vortex.root"); + assert_snapshot!(root_expr.display_tree().to_string(), @"vortex.root()"); + } + #[test] + fn test_display_tree_literal() { + use insta::assert_snapshot; let lit_expr = lit(42); - assert_snapshot!(lit_expr.display_tree().to_string(), @"vortex.literal 42i32"); + assert_snapshot!(lit_expr.display_tree().to_string(), @"vortex.literal(42i32)"); + } + #[test] + fn test_display_tree_get_item() { + use insta::assert_snapshot; let get_item_expr = get_item("my_field", root()); - assert_snapshot!(get_item_expr.display_tree().to_string(), @r#" - vortex.get_item "my_field" - └── input: vortex.root - "#); + assert_snapshot!(get_item_expr.display_tree().to_string(), @r" + vortex.get_item(my_field) + └── input: vortex.root() + "); + } + #[test] + fn test_display_tree_binary() { + use insta::assert_snapshot; let binary_expr = gt(get_item("x", root()), lit(10)); - assert_snapshot!(binary_expr.display_tree().to_string(), @r#" - vortex.binary > - ├── lhs: vortex.get_item "x" - │ └── input: vortex.root - └── rhs: vortex.literal 10i32 - "#); + assert_snapshot!(binary_expr.display_tree().to_string(), @r" + vortex.binary(>) + ├── lhs: vortex.get_item(x) + │ └── input: vortex.root() + └── rhs: vortex.literal(10i32) + "); + } + #[test] + fn test_display_tree_complex_binary() { + use insta::assert_snapshot; let complex_binary = and( eq(get_item("name", root()), lit("alice")), gt(get_item("age", root()), lit(18)), ); assert_snapshot!(complex_binary.display_tree().to_string(), @r#" - vortex.binary and - ├── lhs: vortex.binary = - │ ├── lhs: vortex.get_item "name" - │ │ └── input: vortex.root - │ └── rhs: vortex.literal "alice" - └── rhs: vortex.binary > - ├── lhs: vortex.get_item "age" - │ └── input: vortex.root - └── rhs: vortex.literal 18i32 + vortex.binary(and) + ├── lhs: vortex.binary(=) + │ ├── lhs: vortex.get_item(name) + │ │ └── input: vortex.root() + │ └── rhs: vortex.literal("alice") + └── rhs: vortex.binary(>) + ├── lhs: vortex.get_item(age) + │ └── input: vortex.root() + └── rhs: vortex.literal(18i32) "#); + } + #[test] + fn test_display_tree_select() { + use insta::assert_snapshot; let select_expr = select(["name", "age"], root()); assert_snapshot!(select_expr.display_tree().to_string(), @r" - vortex.select include={name, age} - └── child: vortex.root + vortex.select({name, age}) + └── child: vortex.root() "); + } + #[test] + fn test_display_tree_select_exclude() { + use insta::assert_snapshot; let select_exclude_expr = select_exclude(["internal_id", "metadata"], root()); assert_snapshot!(select_exclude_expr.display_tree().to_string(), @r" - vortex.select exclude={internal_id, metadata} - └── child: vortex.root + vortex.select(~{internal_id, metadata}) + └── child: vortex.root() "); + } + #[test] + fn test_display_tree_cast() { + use insta::assert_snapshot; let cast_expr = cast( get_item("value", root()), DType::Primitive(PType::I64, Nullability::NonNullable), ); - assert_snapshot!(cast_expr.display_tree().to_string(), @r#" - vortex.cast i64 - └── input: vortex.get_item "value" - └── input: vortex.root - "#); + assert_snapshot!(cast_expr.display_tree().to_string(), @r" + vortex.cast(i64) + └── input: vortex.get_item(value) + └── input: vortex.root() + "); + } + #[test] + fn test_display_tree_not() { + use insta::assert_snapshot; let not_expr = not(eq(get_item("active", root()), lit(true))); - assert_snapshot!(not_expr.display_tree().to_string(), @r#" - vortex.not - └── input: vortex.binary = - ├── lhs: vortex.get_item "active" - │ └── input: vortex.root - └── rhs: vortex.literal true - "#); + assert_snapshot!(not_expr.display_tree().to_string(), @r" + vortex.not() + └── input: vortex.binary(=) + ├── lhs: vortex.get_item(active) + │ └── input: vortex.root() + └── rhs: vortex.literal(true) + "); + } + #[test] + fn test_display_tree_between() { + use insta::assert_snapshot; let between_expr = between( get_item("score", root()), lit(0), @@ -183,15 +209,18 @@ mod tests { upper_strict: StrictComparison::NonStrict, }, ); - assert_snapshot!(between_expr.display_tree().to_string(), @r#" - vortex.between BetweenOptions { lower_strict: NonStrict, upper_strict: NonStrict } - ├── array: vortex.get_item "score" - │ └── input: vortex.root - ├── lower: vortex.literal 0i32 - └── upper: vortex.literal 100i32 - "#); + assert_snapshot!(between_expr.display_tree().to_string(), @r" + vortex.between(lower_strict: <=, upper_strict: <=) + ├── array: vortex.get_item(score) + │ └── input: vortex.root() + ├── lower: vortex.literal(0i32) + └── upper: vortex.literal(100i32) + "); + } - // Test nested expression + #[test] + fn test_display_tree_nested() { + use insta::assert_snapshot; let nested_expr = select( ["result"], cast( @@ -207,16 +236,20 @@ mod tests { DType::Bool(Nullability::NonNullable), ), ); - assert_snapshot!(nested_expr.display_tree().to_string(), @r#" - vortex.select include={result} - └── child: vortex.cast bool - └── input: vortex.between BetweenOptions { lower_strict: Strict, upper_strict: NonStrict } - ├── array: vortex.get_item "score" - │ └── input: vortex.root - ├── lower: vortex.literal 50i32 - └── upper: vortex.literal 100i32 - "#); + assert_snapshot!(nested_expr.display_tree().to_string(), @r" + vortex.select({result}) + └── child: vortex.cast(bool) + └── input: vortex.between(lower_strict: <, upper_strict: <=) + ├── array: vortex.get_item(score) + │ └── input: vortex.root() + ├── lower: vortex.literal(50i32) + └── upper: vortex.literal(100i32) + "); + } + #[test] + fn test_display_tree_pack() { + use insta::assert_snapshot; let select_from_pack_expr = select( ["fizz", "buzz"], pack( @@ -228,15 +261,15 @@ mod tests { Nullability::Nullable, ), ); - assert_snapshot!(select_from_pack_expr.display_tree().to_string(), @r#" - vortex.select include={fizz, buzz} - └── child: vortex.pack PackOptions { names: FieldNames([FieldName("fizz"), FieldName("bar"), FieldName("buzz")]), nullability: Nullable } - ├── fizz: vortex.root - ├── bar: vortex.literal 5i32 - └── buzz: vortex.binary = - ├── lhs: vortex.literal 42i32 - └── rhs: vortex.get_item "answer" - └── input: vortex.root - "#); + assert_snapshot!(select_from_pack_expr.display_tree().to_string(), @r" + vortex.select({fizz, buzz}) + └── child: vortex.pack(names: [fizz, bar, buzz], nullability: ?) + ├── fizz: vortex.root() + ├── bar: vortex.literal(5i32) + └── buzz: vortex.binary(=) + ├── lhs: vortex.literal(42i32) + └── rhs: vortex.get_item(answer) + └── input: vortex.root() + "); } } diff --git a/vortex-array/src/expr/expression.rs b/vortex-array/src/expr/expression.rs index fa798b4b17c..05ea00601a3 100644 --- a/vortex-array/src/expr/expression.rs +++ b/vortex-array/src/expr/expression.rs @@ -1,29 +1,23 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use std::any::Any; use std::fmt; use std::fmt::Debug; use std::fmt::Display; use std::fmt::Formatter; use std::hash::Hash; -use std::hash::Hasher; +use std::ops::Deref; use std::sync::Arc; use itertools::Itertools; use vortex_dtype::DType; use vortex_error::VortexExpect; use vortex_error::VortexResult; -use vortex_vector::Vector; -use vortex_vector::VectorOps; +use vortex_error::vortex_ensure; use crate::ArrayRef; -use crate::expr::ChildName; -use crate::expr::ExecutionArgs; -use crate::expr::ExprId; -use crate::expr::ExprVTable; -use crate::expr::ExpressionView; use crate::expr::Root; +use crate::expr::ScalarFn; use crate::expr::StatsCatalog; use crate::expr::VTable; use crate::expr::display::DisplayTreeExpr; @@ -33,96 +27,62 @@ use crate::expr::stats::Stat; /// /// Expressions represent scalar computations that can be performed on data. Each /// expression consists of an encoding (vtable), heap-allocated metadata, and child expressions. -#[derive(Clone)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct Expression { - /// The vtable for this expression. - vtable: ExprVTable, - /// The instance data for this expression. - data: Arc, + /// The scalar fn for this node. + scalar_fn: ScalarFn, /// Any children of this expression. children: Arc<[Expression]>, } +impl Deref for Expression { + type Target = ScalarFn; + + fn deref(&self) -> &Self::Target { + &self.scalar_fn + } +} + impl Expression { - /// Create a new expression from a vtable. - pub fn try_new( - vtable: V, - data: V::Instance, + /// Create a new expression node from a scalar_fn expression and its children. + pub fn try_new( + scalar_fn: ScalarFn, children: impl Into>, ) -> VortexResult { - let vtable = ExprVTable::new::(vtable); - let data = Arc::new(data); - Self::try_new_erased(vtable, data, children.into()) - } + let children: Arc<[Expression]> = children.into(); - /// Create a new expression from a static vtable. - pub fn new_static( - vtable: &'static V, - data: V::Instance, - children: impl Into>, - ) -> Self { - let vtable = ExprVTable::new_static::(vtable); - let data = Arc::new(data); - Self { - vtable, - data, - children: children.into(), - } - } + vortex_ensure!( + scalar_fn.signature().arity().matches(children.len()), + "Expression arity mismatch: expected {} children but got {}", + scalar_fn.signature().arity(), + children.len() + ); - /// Creates a new expression with the given encoding, metadata, and children. - /// - /// # Errors - /// - /// Returns an error if the provided `encoding` is not compatible with the - /// `metadata` and `children` or the encoding's own validation logic fails. - pub(super) fn try_new_erased( - vtable: ExprVTable, - data: Arc, - children: Arc<[Expression]>, - ) -> VortexResult { - let this = Self { - vtable, - data, + Ok(Self { + scalar_fn, children, - }; - // Validate that the encoding is compatible with the metadata and children. - this.vtable.as_dyn().validate(&this)?; - Ok(this) + }) } - /// Returns if the expression is an instance of the given vtable. + /// Returns true if this expression is of the given vtable type. pub fn is(&self) -> bool { - self.vtable.is::() - } - - /// Returns a typed view of this expression for the given vtable. - /// - /// # Panics - /// - /// Panics if the expression's encoding or metadata cannot be cast to the specified vtable. - pub fn as_(&self) -> ExpressionView<'_, V> { - ExpressionView::maybe_new(self).vortex_expect("Failed to downcast expression {} to {}") + self.vtable().is::() } - /// Returns a typed view of this expression for the given vtable, if the types match. - pub fn as_opt(&self) -> Option> { - ExpressionView::maybe_new(self) + /// Returns the typed options for this expression if it matches the given vtable type. + pub fn as_opt(&self) -> Option<&V::Options> { + self.options().as_any().downcast_ref::() } - /// Returns the expression ID. - pub fn id(&self) -> ExprId { - self.vtable.as_dyn().id() + /// Returns the typed options for this expression if it matches the given vtable type. + pub fn as_(&self) -> &V::Options { + self.as_opt::() + .vortex_expect("Expression options type mismatch") } - /// Returns the expression's vtable. - pub fn vtable(&self) -> &ExprVTable { - &self.vtable - } - - /// Returns the opaque data of the expression. - pub fn data(&self) -> &Arc { - &self.data + /// Returns the scalar fn vtable for this expression. + pub fn scalar_fn(&self) -> &ScalarFn { + &self.scalar_fn } /// Returns the children of this expression. @@ -135,60 +95,39 @@ impl Expression { &self.children[n] } - /// Returns the name of the n'th child of this expression. - pub fn child_name(&self, n: usize) -> ChildName { - self.vtable.as_dyn().child_name(self.data().as_ref(), n) - } - /// Replace the children of this expression with the provided new children. pub fn with_children(mut self, children: impl Into>) -> VortexResult { - self.children = children.into(); - self.vtable.as_dyn().validate(&self)?; + let children = children.into(); + vortex_ensure!( + self.signature().arity().matches(children.len()), + "Expression arity mismatch: expected {} children but got {}", + self.signature().arity(), + children.len() + ); + self.children = children; Ok(self) } - /// Returns the serialized metadata for this expression. - pub fn serialize_metadata(&self) -> VortexResult>> { - self.vtable.as_dyn().serialize(self.data.as_ref()) - } - /// Computes the return dtype of this expression given the input dtype. pub fn return_dtype(&self, scope: &DType) -> VortexResult { - self.vtable.as_dyn().return_dtype(self, scope) - } - - /// Evaluates the expression in the given scope. - pub fn evaluate(&self, scope: &ArrayRef) -> VortexResult { - self.vtable.as_dyn().evaluate(self, scope) - } - - /// Executes the expression over the given vector input scope. - pub fn execute(&self, vector: &Vector, dtype: &DType) -> VortexResult { - // We special-case the "root" expression that must extract that scope vector directly. if self.is::() { - return Ok(vector.clone()); + return Ok(scope.clone()); } - let return_dtype = self.return_dtype(dtype)?; - let child_dtypes: Vec<_> = self - .children - .iter() - .map(|child| child.return_dtype(dtype)) - .try_collect()?; - let child_vectors: Vec<_> = self + let dtypes: Vec<_> = self .children .iter() - .map(|child| child.execute(vector, dtype)) + .map(|c| c.return_dtype(scope)) .try_collect()?; + self.scalar_fn.return_dtype(&dtypes) + } - let args = ExecutionArgs { - vectors: child_vectors, - dtypes: child_dtypes, - row_count: vector.len(), - return_dtype, - }; - - self.vtable.as_dyn().execute(&self.data, args) + /// Evaluates the expression in the given scope, returning an array. + pub fn evaluate(&self, scope: &ArrayRef) -> VortexResult { + if self.is::() { + return Ok(scope.clone()); + } + self.scalar_fn.evaluate(self, scope) } /// An expression over zone-statistics which implies all records in the zone evaluate to false. @@ -210,7 +149,7 @@ impl Expression { /// Some expressions, in theory, have falsifications but this function does not support them /// such as `x < (y < z)` or `x LIKE "needle%"`. pub fn stat_falsification(&self, catalog: &dyn StatsCatalog) -> Option { - self.vtable.as_dyn().stat_falsification(self, catalog) + self.vtable().as_dyn().stat_falsification(self, catalog) } /// Returns an expression representing the zoned statistic for the given stat, if available. @@ -222,46 +161,25 @@ impl Expression { /// NOTE(gatesn): we currently cannot represent statistics over nested fields. Please file an /// issue to discuss a solution to this. pub fn stat_expression(&self, stat: Stat, catalog: &dyn StatsCatalog) -> Option { - self.vtable.as_dyn().stat_expression(self, stat, catalog) + self.vtable().as_dyn().stat_expression(self, stat, catalog) } /// Returns an expression representing the zoned maximum statistic, if available. - /// - /// See [`Self::stat_expression`] for details. pub fn stat_min(&self, catalog: &dyn StatsCatalog) -> Option { self.stat_expression(Stat::Min, catalog) } /// Returns an expression representing the zoned maximum statistic, if available. - /// - /// See [`Self::stat_expression`] for details. pub fn stat_max(&self, catalog: &dyn StatsCatalog) -> Option { self.stat_expression(Stat::Max, catalog) } - /// Returns whether this expression itself is null-sensitive. - /// See [`VTable::is_null_sensitive`]. - pub fn is_null_sensitive(&self) -> bool { - self.vtable.as_dyn().is_null_sensitive(self.data.as_ref()) - } - - /// Returns whether this expression itself is fallible. - /// See [`VTable::is_fallible`]. - pub fn is_fallible(&self) -> bool { - self.vtable.as_dyn().is_fallible(self.data.as_ref()) - } - /// Format the expression as a compact string. /// /// Since this is a recursive formatter, it is exposed on the public Expression type. /// See fmt_data that is only implemented on the vtable trait. pub fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result { - self.vtable.as_dyn().fmt_sql(self, f) - } - - /// Format the instance data of the expression as a string for rendering.. - pub fn fmt_data(&self, f: &mut Formatter<'_>) -> fmt::Result { - self.vtable.as_dyn().fmt_data(self.data().as_ref(), f) + self.vtable().as_dyn().fmt_sql(self, f) } /// Display the expression as a formatted tree structure. @@ -327,50 +245,3 @@ impl Display for Expression { self.fmt_sql(f) } } - -struct FormatExpressionData<'a> { - vtable: &'a ExprVTable, - data: &'a Arc, -} - -impl<'a> Debug for FormatExpressionData<'a> { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - self.vtable.as_dyn().fmt_data(self.data.as_ref(), f) - } -} - -impl Debug for Expression { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - f.debug_struct("Expression") - .field("vtable", &self.vtable) - .field( - "data", - &FormatExpressionData { - vtable: &self.vtable, - data: &self.data, - }, - ) - .field("children", &self.children) - .finish() - } -} - -impl PartialEq for Expression { - fn eq(&self, other: &Self) -> bool { - self.vtable.as_dyn().id() == other.vtable.as_dyn().id() - && self - .vtable - .as_dyn() - .dyn_eq(self.data.as_ref(), other.data.as_ref()) - && self.children.eq(&other.children) - } -} -impl Eq for Expression {} - -impl Hash for Expression { - fn hash(&self, state: &mut H) { - self.vtable.as_dyn().id().hash(state); - self.vtable.as_dyn().dyn_hash(self.data.as_ref(), state); - self.children.hash(state); - } -} diff --git a/vortex-array/src/expr/exprs/between.rs b/vortex-array/src/expr/exprs/between.rs index 2135a5ecd8a..50bc2f94504 100644 --- a/vortex-array/src/expr/exprs/between.rs +++ b/vortex-array/src/expr/exprs/between.rs @@ -10,13 +10,15 @@ use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_proto::expr as pb; +use vortex_vector::Datum; use crate::ArrayRef; use crate::compute::BetweenOptions; use crate::compute::between as between_compute; +use crate::expr::Arity; use crate::expr::ChildName; +use crate::expr::ExecutionArgs; use crate::expr::ExprId; -use crate::expr::ExpressionView; use crate::expr::StatsCatalog; use crate::expr::VTable; use crate::expr::VTableExt; @@ -38,13 +40,13 @@ use crate::expr::exprs::operators::Operator; pub struct Between; impl VTable for Between { - type Instance = BetweenOptions; + type Options = BetweenOptions; fn id(&self) -> ExprId { ExprId::from("vortex.between") } - fn serialize(&self, instance: &Self::Instance) -> VortexResult>> { + fn serialize(&self, instance: &Self::Options) -> VortexResult>> { Ok(Some( pb::BetweenOpts { lower_strict: instance.lower_strict.is_strict(), @@ -54,9 +56,9 @@ impl VTable for Between { )) } - fn deserialize(&self, metadata: &[u8]) -> VortexResult> { + fn deserialize(&self, metadata: &[u8]) -> VortexResult { let opts = pb::BetweenOpts::decode(metadata)?; - Ok(Some(BetweenOptions { + Ok(BetweenOptions { lower_strict: if opts.lower_strict { crate::compute::StrictComparison::Strict } else { @@ -67,20 +69,14 @@ impl VTable for Between { } else { crate::compute::StrictComparison::NonStrict }, - })) + }) } - fn validate(&self, expr: &ExpressionView) -> VortexResult<()> { - if expr.children().len() != 3 { - vortex_bail!( - "Between expression requires exactly 3 children, got {}", - expr.children().len() - ); - } - Ok(()) + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(3) } - fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName { + fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName { match child_idx { 0 => ChildName::from("array"), 1 => ChildName::from("lower"), @@ -89,8 +85,12 @@ impl VTable for Between { } } - fn fmt_sql(&self, expr: &ExpressionView, f: &mut Formatter<'_>) -> std::fmt::Result { - let options = expr.data(); + fn fmt_sql( + &self, + options: &Self::Options, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { let lower_op = if options.lower_strict.is_strict() { "<" } else { @@ -104,27 +104,27 @@ impl VTable for Between { write!( f, "({} {} {} {} {})", - expr.lower(), + expr.child(1), lower_op, - expr.child(), + expr.child(0), upper_op, - expr.upper() + expr.child(2) ) } - fn return_dtype(&self, expr: &ExpressionView, scope: &DType) -> VortexResult { - let arr_dt = expr.child().return_dtype(scope)?; - let lower_dt = expr.lower().return_dtype(scope)?; - let upper_dt = expr.upper().return_dtype(scope)?; + fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { + let arr_dt = &arg_dtypes[0]; + let lower_dt = &arg_dtypes[1]; + let upper_dt = &arg_dtypes[2]; - if !arr_dt.eq_ignore_nullability(&lower_dt) { + if !arr_dt.eq_ignore_nullability(lower_dt) { vortex_bail!( "Array dtype {} does not match lower dtype {}", arr_dt, lower_dt ); } - if !arr_dt.eq_ignore_nullability(&upper_dt) { + if !arr_dt.eq_ignore_nullability(upper_dt) { vortex_bail!( "Array dtype {} does not match upper dtype {}", arr_dt, @@ -137,51 +137,45 @@ impl VTable for Between { )) } - fn evaluate(&self, expr: &ExpressionView, scope: &ArrayRef) -> VortexResult { - let arr = expr.child().evaluate(scope)?; - let lower = expr.lower().evaluate(scope)?; - let upper = expr.upper().evaluate(scope)?; - between_compute(&arr, &lower, &upper, expr.data()) + fn evaluate( + &self, + options: &Self::Options, + expr: &Expression, + scope: &ArrayRef, + ) -> VortexResult { + let arr = expr.child(0).evaluate(scope)?; + let lower = expr.child(1).evaluate(scope)?; + let upper = expr.child(2).evaluate(scope)?; + between_compute(&arr, &lower, &upper, options) + } + + fn execute(&self, _data: &Self::Options, _args: ExecutionArgs) -> VortexResult { + todo!() } fn stat_falsification( &self, - expr: &ExpressionView, + options: &Self::Options, + expr: &Expression, catalog: &dyn StatsCatalog, ) -> Option { - expr.to_binary_expr().stat_falsification(catalog) - } - - fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool { - false - } -} - -impl ExpressionView<'_, Between> { - pub fn child(&self) -> &Expression { - &self.children()[0] - } - - pub fn lower(&self) -> &Expression { - &self.children()[1] - } - - pub fn upper(&self) -> &Expression { - &self.children()[2] - } - - pub fn to_binary_expr(&self) -> Expression { - let options = self.data(); - let arr = self.children()[0].clone(); - let lower = self.children()[1].clone(); - let upper = self.children()[2].clone(); + let arr = expr.child(0).clone(); + let lower = expr.child(1).clone(); + let upper = expr.child(2).clone(); let lhs = Binary.new_expr( options.lower_strict.to_operator().into(), [lower, arr.clone()], ); let rhs = Binary.new_expr(options.upper_strict.to_operator().into(), [arr, upper]); - Binary.new_expr(Operator::And, [lhs, rhs]) + + Binary + .new_expr(Operator::And, [lhs, rhs]) + .stat_falsification(catalog) + } + + fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { + false } } diff --git a/vortex-array/src/expr/exprs/binary.rs b/vortex-array/src/expr/exprs/binary.rs index 067469b81b5..15b62203360 100644 --- a/vortex-array/src/expr/exprs/binary.rs +++ b/vortex-array/src/expr/exprs/binary.rs @@ -9,6 +9,7 @@ use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_proto::expr as pb; +use vortex_vector::Datum; use crate::ArrayRef; use crate::compute; @@ -19,9 +20,10 @@ use crate::compute::div; use crate::compute::mul; use crate::compute::or_kleene; use crate::compute::sub; +use crate::expr::Arity; use crate::expr::ChildName; +use crate::expr::ExecutionArgs; use crate::expr::ExprId; -use crate::expr::ExpressionView; use crate::expr::StatsCatalog; use crate::expr::VTable; use crate::expr::VTableExt; @@ -33,13 +35,13 @@ use crate::expr::stats::Stat; pub struct Binary; impl VTable for Binary { - type Instance = Operator; + type Options = Operator; fn id(&self) -> ExprId { ExprId::from("vortex.binary") } - fn serialize(&self, instance: &Self::Instance) -> VortexResult>> { + fn serialize(&self, instance: &Self::Options) -> VortexResult>> { Ok(Some( pb::BinaryOpts { op: (*instance).into(), @@ -48,17 +50,16 @@ impl VTable for Binary { )) } - fn deserialize(&self, metadata: &[u8]) -> VortexResult> { + fn deserialize(&self, metadata: &[u8]) -> VortexResult { let opts = pb::BinaryOpts::decode(metadata)?; - Ok(Some(Operator::try_from(opts.op)?)) + Operator::try_from(opts.op) } - fn validate(&self, _expr: &ExpressionView) -> VortexResult<()> { - // TODO(ngates): check the dtypes. - Ok(()) + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(2) } - fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName { + fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName { match child_idx { 0 => ChildName::from("lhs"), 1 => ChildName::from("rhs"), @@ -66,24 +67,25 @@ impl VTable for Binary { } } - fn fmt_sql(&self, expr: &ExpressionView, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt_sql( + &self, + operator: &Operator, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { write!(f, "(")?; - expr.lhs().fmt_sql(f)?; - write!(f, " {} ", expr.operator())?; - expr.rhs().fmt_sql(f)?; + expr.child(0).fmt_sql(f)?; + write!(f, " {} ", operator)?; + expr.child(1).fmt_sql(f)?; write!(f, ")") } - fn fmt_data(&self, instance: &Self::Instance, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", *instance) - } - - fn return_dtype(&self, expr: &ExpressionView, scope: &DType) -> VortexResult { - let lhs = expr.lhs().return_dtype(scope)?; - let rhs = expr.rhs().return_dtype(scope)?; + fn return_dtype(&self, operator: &Operator, arg_dtypes: &[DType]) -> VortexResult { + let lhs = &arg_dtypes[0]; + let rhs = &arg_dtypes[1]; - if expr.operator().is_arithmetic() { - if lhs.is_primitive() && lhs.eq_ignore_nullability(&rhs) { + if operator.is_arithmetic() { + if lhs.is_primitive() && lhs.eq_ignore_nullability(rhs) { return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability())); } vortex_bail!( @@ -96,11 +98,16 @@ impl VTable for Binary { Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into())) } - fn evaluate(&self, expr: &ExpressionView, scope: &ArrayRef) -> VortexResult { - let lhs = expr.lhs().evaluate(scope)?; - let rhs = expr.rhs().evaluate(scope)?; - - match expr.operator() { + fn evaluate( + &self, + operator: &Operator, + expr: &Expression, + scope: &ArrayRef, + ) -> VortexResult { + let lhs = expr.child(0).evaluate(scope)?; + let rhs = expr.child(1).evaluate(scope)?; + + match operator { Operator::Eq => compare(&lhs, &rhs, compute::Operator::Eq), Operator::NotEq => compare(&lhs, &rhs, compute::Operator::NotEq), Operator::Lt => compare(&lhs, &rhs, compute::Operator::Lt), @@ -116,9 +123,14 @@ impl VTable for Binary { } } + fn execute(&self, _data: &Self::Options, _args: ExecutionArgs) -> VortexResult { + todo!() + } + fn stat_falsification( &self, - expr: &ExpressionView, + operator: &Operator, + expr: &Expression, catalog: &dyn StatsCatalog, ) -> Option { // Wrap another predicate with an optional NaNCount check, if the stat is available. @@ -155,13 +167,15 @@ impl VTable for Binary { } } - match expr.operator() { + let lhs = expr.child(0); + let rhs = expr.child(1); + match operator { Operator::Eq => { - let min_lhs = expr.lhs().stat_min(catalog); - let max_lhs = expr.lhs().stat_max(catalog); + let min_lhs = lhs.stat_min(catalog); + let max_lhs = lhs.stat_max(catalog); - let min_rhs = expr.rhs().stat_min(catalog); - let max_rhs = expr.rhs().stat_max(catalog); + let min_rhs = rhs.stat_min(catalog); + let max_rhs = rhs.stat_max(catalog); let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b)); let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b)); @@ -169,99 +183,64 @@ impl VTable for Binary { let min_max_check = left.into_iter().chain(right).reduce(or)?; // NaN is not captured by the min/max stat, so we must check NaNCount before pruning - Some(with_nan_predicate( - expr.lhs(), - expr.rhs(), - min_max_check, - catalog, - )) + Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) } Operator::NotEq => { - let min_lhs = expr.lhs().stat_min(catalog)?; - let max_lhs = expr.lhs().stat_max(catalog)?; + let min_lhs = lhs.stat_min(catalog)?; + let max_lhs = lhs.stat_max(catalog)?; - let min_rhs = expr.rhs().stat_min(catalog)?; - let max_rhs = expr.rhs().stat_max(catalog)?; + let min_rhs = rhs.stat_min(catalog)?; + let max_rhs = rhs.stat_max(catalog)?; let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs)); - Some(with_nan_predicate( - expr.lhs(), - expr.rhs(), - min_max_check, - catalog, - )) + Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) } Operator::Gt => { - let min_max_check = - lt_eq(expr.lhs().stat_max(catalog)?, expr.rhs().stat_min(catalog)?); - - Some(with_nan_predicate( - expr.lhs(), - expr.rhs(), - min_max_check, - catalog, - )) + let min_max_check = lt_eq(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?); + + Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) } Operator::Gte => { // NaN is not captured by the min/max stat, so we must check NaNCount before pruning - let min_max_check = - lt(expr.lhs().stat_max(catalog)?, expr.rhs().stat_min(catalog)?); - - Some(with_nan_predicate( - expr.lhs(), - expr.rhs(), - min_max_check, - catalog, - )) + let min_max_check = lt(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?); + + Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) } Operator::Lt => { // NaN is not captured by the min/max stat, so we must check NaNCount before pruning - let min_max_check = - gt_eq(expr.lhs().stat_min(catalog)?, expr.rhs().stat_max(catalog)?); - - Some(with_nan_predicate( - expr.lhs(), - expr.rhs(), - min_max_check, - catalog, - )) + let min_max_check = gt_eq(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?); + + Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) } Operator::Lte => { // NaN is not captured by the min/max stat, so we must check NaNCount before pruning - let min_max_check = - gt(expr.lhs().stat_min(catalog)?, expr.rhs().stat_max(catalog)?); - - Some(with_nan_predicate( - expr.lhs(), - expr.rhs(), - min_max_check, - catalog, - )) + let min_max_check = gt(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?); + + Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) } - Operator::And => expr - .lhs() + Operator::And => lhs .stat_falsification(catalog) .into_iter() - .chain(expr.rhs().stat_falsification(catalog)) + .chain(rhs.stat_falsification(catalog)) .reduce(or), Operator::Or => Some(and( - expr.lhs().stat_falsification(catalog)?, - expr.rhs().stat_falsification(catalog)?, + lhs.stat_falsification(catalog)?, + rhs.stat_falsification(catalog)?, )), Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None, } } - fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool { + fn is_null_sensitive(&self, _operator: &Operator) -> bool { false } - fn is_fallible(&self, instance: &Self::Instance) -> bool { + fn is_fallible(&self, operator: &Operator) -> bool { // Opt-in not out for fallibility. // Arithmetic operations could be better modelled here. let infallible = matches!( - instance, + operator, Operator::Eq | Operator::NotEq | Operator::Gt @@ -276,20 +255,6 @@ impl VTable for Binary { } } -impl ExpressionView<'_, Binary> { - pub fn lhs(&self) -> &Expression { - &self.children()[0] - } - - pub fn rhs(&self) -> &Expression { - &self.children()[1] - } - - pub fn operator(&self) -> Operator { - *self.data() - } -} - /// Create a new [`Binary`] using the [`Eq`](crate::expr::exprs::operators::Operator::Eq) operator. /// /// ## Example usage @@ -636,15 +601,6 @@ mod tests { ); } - #[test] - fn test_debug_print() { - let expr = gt(lit(1), lit(2)); - assert_eq!( - format!("{expr:?}"), - "Expression { vtable: vortex.binary, data: >, children: [Expression { vtable: vortex.literal, data: 1i32, children: [] }, Expression { vtable: vortex.literal, data: 2i32, children: [] }] }" - ); - } - #[test] fn test_display_print() { let expr = gt(lit(1), lit(2)); diff --git a/vortex-array/src/expr/exprs/cast/mod.rs b/vortex-array/src/expr/exprs/cast.rs similarity index 75% rename from vortex-array/src/expr/exprs/cast/mod.rs rename to vortex-array/src/expr/exprs/cast.rs index 9808d42a9cc..53cdcf40856 100644 --- a/vortex-array/src/expr/exprs/cast/mod.rs +++ b/vortex-array/src/expr/exprs/cast.rs @@ -8,17 +8,16 @@ use prost::Message; use vortex_dtype::DType; use vortex_error::VortexExpect; use vortex_error::VortexResult; -use vortex_error::vortex_bail; use vortex_error::vortex_err; use vortex_proto::expr as pb; -use vortex_vector::Vector; +use vortex_vector::Datum; use crate::ArrayRef; use crate::compute::cast as compute_cast; +use crate::expr::Arity; use crate::expr::ChildName; use crate::expr::ExecutionArgs; use crate::expr::ExprId; -use crate::expr::ExpressionView; use crate::expr::StatsCatalog; use crate::expr::VTable; use crate::expr::VTableExt; @@ -29,66 +28,59 @@ use crate::expr::stats::Stat; pub struct Cast; impl VTable for Cast { - type Instance = DType; + type Options = DType; fn id(&self) -> ExprId { ExprId::from("vortex.cast") } - fn serialize(&self, instance: &Self::Instance) -> VortexResult>> { + fn serialize(&self, dtype: &DType) -> VortexResult>> { Ok(Some( pb::CastOpts { - target: Some(instance.into()), + target: Some(dtype.into()), } .encode_to_vec(), )) } - fn deserialize(&self, metadata: &[u8]) -> VortexResult> { - Ok(Some( - pb::CastOpts::decode(metadata)? - .target - .as_ref() - .ok_or_else(|| vortex_err!("Missing target dtype in Cast expression"))? - .try_into()?, - )) + fn deserialize(&self, metadata: &[u8]) -> VortexResult { + pb::CastOpts::decode(metadata)? + .target + .as_ref() + .ok_or_else(|| vortex_err!("Missing target dtype in Cast expression"))? + .try_into() } - fn validate(&self, expr: &ExpressionView) -> VortexResult<()> { - if expr.children().len() != 1 { - vortex_bail!( - "Cast expression requires exactly 1 child, got {}", - expr.children().len() - ); - } - Ok(()) + fn arity(&self, _options: &DType) -> Arity { + Arity::Exact(1) } - fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName { + fn child_name(&self, _instance: &DType, child_idx: usize) -> ChildName { match child_idx { 0 => ChildName::from("input"), _ => unreachable!("Invalid child index {} for Cast expression", child_idx), } } - fn fmt_sql(&self, expr: &ExpressionView, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt_sql(&self, dtype: &DType, expr: &Expression, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "cast(")?; expr.children()[0].fmt_sql(f)?; - write!(f, " as {}", expr.data())?; + write!(f, " as {}", dtype)?; write!(f, ")") } - fn fmt_data(&self, instance: &Self::Instance, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", instance) - } - - fn return_dtype(&self, expr: &ExpressionView, _scope: &DType) -> VortexResult { - Ok(expr.data().clone()) + fn return_dtype(&self, dtype: &DType, _arg_dtypes: &[DType]) -> VortexResult { + Ok(dtype.clone()) } - fn evaluate(&self, expr: &ExpressionView, scope: &ArrayRef) -> VortexResult { + fn evaluate( + &self, + dtype: &DType, + expr: &Expression, + scope: &ArrayRef, + ) -> VortexResult { let array = expr.children()[0].evaluate(scope)?; - compute_cast(&array, expr.data()).map_err(|e| { + compute_cast(&array, dtype).map_err(|e| { e.with_context(format!( "Failed to cast array of dtype {} to {}", array.dtype(), @@ -97,9 +89,18 @@ impl VTable for Cast { }) } + fn execute(&self, target_dtype: &DType, mut args: ExecutionArgs) -> VortexResult { + let input = args + .datums + .pop() + .vortex_expect("missing input for Cast expression"); + vortex_compute::cast::Cast::cast(&input, target_dtype) + } + fn stat_expression( &self, - expr: &ExpressionView, + dtype: &DType, + expr: &Expression, stat: Stat, catalog: &dyn StatsCatalog, ) -> Option { @@ -114,7 +115,7 @@ impl VTable for Cast { // We cast min/max to the new type expr.child(0) .stat_expression(stat, catalog) - .map(|x| cast(x, expr.data().clone())) + .map(|x| cast(x, dtype.clone())) } Stat::NullCount => { // if !expr.data().is_nullable() { @@ -129,16 +130,8 @@ impl VTable for Cast { } } - fn execute(&self, target_dtype: &DType, mut args: ExecutionArgs) -> VortexResult { - let input = args - .vectors - .pop() - .vortex_expect("missing input for Cast expression"); - vortex_compute::cast::Cast::cast(&input, target_dtype) - } - // This might apply a nullability - fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool { + fn is_null_sensitive(&self, _instance: &DType) -> bool { true } } diff --git a/vortex-array/src/expr/exprs/dynamic.rs b/vortex-array/src/expr/exprs/dynamic.rs index ff325a4fb3e..9dbe9819868 100644 --- a/vortex-array/src/expr/exprs/dynamic.rs +++ b/vortex-array/src/expr/exprs/dynamic.rs @@ -15,6 +15,7 @@ use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_scalar::Scalar; use vortex_scalar::ScalarValue; +use vortex_vector::Datum; use crate::Array; use crate::ArrayRef; @@ -22,10 +23,11 @@ use crate::IntoArray; use crate::arrays::ConstantArray; use crate::compute::Operator; use crate::compute::compare; +use crate::expr::Arity; use crate::expr::ChildName; +use crate::expr::ExecutionArgs; use crate::expr::ExprId; use crate::expr::Expression; -use crate::expr::ExpressionView; use crate::expr::StatsCatalog; use crate::expr::VTable; use crate::expr::VTableExt; @@ -39,116 +41,130 @@ use crate::expr::traversal::TraversalOrder; pub struct DynamicComparison; impl VTable for DynamicComparison { - type Instance = DynamicComparisonExpr; + type Options = DynamicComparisonExpr; fn id(&self) -> ExprId { ExprId::new_ref("vortex.dynamic") } - fn validate(&self, expr: &ExpressionView) -> VortexResult<()> { - if expr.children().len() != 1 { - vortex_bail!( - "DynamicComparison expression requires exactly one child, got {}", - expr.children().len() - ); - } - Ok(()) + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(1) } - fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName { + fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName { match child_idx { 0 => ChildName::from("lhs"), _ => unreachable!(), } } - fn fmt_sql(&self, expr: &ExpressionView, f: &mut Formatter<'_>) -> std::fmt::Result { - expr.lhs().fmt_sql(f)?; - write!(f, " {} dynamic(", expr.data())?; - match expr.scalar() { + fn fmt_sql( + &self, + dynamic: &DynamicComparisonExpr, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { + expr.child(0).fmt_sql(f)?; + write!(f, " {} dynamic(", dynamic)?; + match dynamic.scalar() { None => write!(f, "")?, Some(scalar) => write!(f, "{}", scalar)?, } write!(f, ")") } - fn return_dtype(&self, expr: &ExpressionView, scope: &DType) -> VortexResult { - let lhs = expr.lhs().return_dtype(scope)?; - if !expr.data().rhs.dtype.eq_ignore_nullability(&lhs) { + fn return_dtype( + &self, + dynamic: &DynamicComparisonExpr, + arg_dtypes: &[DType], + ) -> VortexResult { + let lhs = &arg_dtypes[0]; + if !dynamic.rhs.dtype.eq_ignore_nullability(lhs) { vortex_bail!( "Incompatible dtypes for dynamic comparison: expected {} (ignore nullability) but got {}", - &expr.data().rhs.dtype, + &dynamic.rhs.dtype, lhs ); } Ok(DType::Bool( - lhs.nullability() | expr.data().rhs.dtype.nullability(), + lhs.nullability() | dynamic.rhs.dtype.nullability(), )) } - fn evaluate(&self, expr: &ExpressionView, scope: &ArrayRef) -> VortexResult { - if let Some(value) = expr.scalar() { - let lhs = expr.lhs().evaluate(scope)?; + fn evaluate( + &self, + dynamic: &DynamicComparisonExpr, + expr: &Expression, + scope: &ArrayRef, + ) -> VortexResult { + if let Some(value) = dynamic.rhs.scalar() { + let lhs = expr.child(0).evaluate(scope)?; let rhs = ConstantArray::new(value, scope.len()); - return compare(lhs.as_ref(), rhs.as_ref(), expr.data().operator); + return compare(lhs.as_ref(), rhs.as_ref(), dynamic.operator); } // Otherwise, we return the default value. let lhs = expr.return_dtype(scope.dtype())?; Ok(ConstantArray::new( Scalar::new( - DType::Bool(lhs.nullability() | expr.data().rhs.dtype.nullability()), - expr.data().default.into(), + DType::Bool(lhs.nullability() | dynamic.rhs.dtype.nullability()), + dynamic.default.into(), ), scope.len(), ) .into_array()) } + fn execute(&self, _data: &Self::Options, _args: ExecutionArgs) -> VortexResult { + todo!() + } + fn stat_falsification( &self, - expr: &ExpressionView, + dynamic: &DynamicComparisonExpr, + expr: &Expression, catalog: &dyn StatsCatalog, ) -> Option { - match expr.data().operator { + let lhs = expr.child(0); + match dynamic.operator { Operator::Gt => Some(DynamicComparison.new_expr( DynamicComparisonExpr { operator: Operator::Lte, - rhs: expr.data().rhs.clone(), - default: !expr.data().default, + rhs: dynamic.rhs.clone(), + default: !dynamic.default, }, - vec![expr.lhs().stat_max(catalog)?], + vec![lhs.stat_max(catalog)?], )), Operator::Gte => Some(DynamicComparison.new_expr( DynamicComparisonExpr { operator: Operator::Lt, - rhs: expr.data().rhs.clone(), - default: !expr.data().default, + rhs: dynamic.rhs.clone(), + default: !dynamic.default, }, - vec![expr.lhs().stat_max(catalog)?], + vec![lhs.stat_max(catalog)?], )), Operator::Lt => Some(DynamicComparison.new_expr( DynamicComparisonExpr { operator: Operator::Gte, - rhs: expr.data().rhs.clone(), - default: !expr.data().default, + rhs: dynamic.rhs.clone(), + default: !dynamic.default, }, - vec![expr.lhs().stat_min(catalog)?], + vec![lhs.stat_min(catalog)?], )), Operator::Lte => Some(DynamicComparison.new_expr( DynamicComparisonExpr { operator: Operator::Gt, - rhs: expr.data().rhs.clone(), - default: !expr.data().default, + rhs: dynamic.rhs.clone(), + default: !dynamic.default, }, - vec![expr.lhs().stat_min(catalog)?], + vec![lhs.stat_min(catalog)?], )), _ => None, } } // Defer to the child - fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool { + fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { false } } @@ -225,6 +241,12 @@ struct Rhs { dtype: DType, } +impl Rhs { + pub fn scalar(&self) -> Option { + (self.value)().map(|v| Scalar::new(self.dtype.clone(), v)) + } +} + impl Debug for Rhs { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("Rhs") @@ -234,16 +256,6 @@ impl Debug for Rhs { } } -impl ExpressionView<'_, DynamicComparison> { - pub fn lhs(&self) -> &Expression { - &self.children()[0] - } - - pub fn scalar(&self) -> Option { - (self.data().rhs.value)().map(|v| Scalar::new(self.data().rhs.dtype.clone(), v)) - } -} - /// A utility for checking whether any dynamic expressions have been updated. pub struct DynamicExprUpdates { exprs: Box<[DynamicComparisonExpr]>, @@ -261,7 +273,7 @@ impl DynamicExprUpdates { fn visit_down(&mut self, node: &'_ Self::NodeTy) -> VortexResult { if let Some(dynamic) = node.as_opt::() { - self.0.push(dynamic.data().clone()); + self.0.push(dynamic.clone()); } Ok(TraversalOrder::Continue) } diff --git a/vortex-array/src/expr/exprs/get_item/mod.rs b/vortex-array/src/expr/exprs/get_item.rs similarity index 51% rename from vortex-array/src/expr/exprs/get_item/mod.rs rename to vortex-array/src/expr/exprs/get_item.rs index 2eafa582e8f..af8874f55d0 100644 --- a/vortex-array/src/expr/exprs/get_item/mod.rs +++ b/vortex-array/src/expr/exprs/get_item.rs @@ -1,8 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -pub mod transform; - use std::fmt::Formatter; use std::ops::Not; @@ -13,20 +11,22 @@ use vortex_dtype::FieldPath; use vortex_dtype::Nullability; use vortex_error::VortexExpect; use vortex_error::VortexResult; -use vortex_error::vortex_bail; use vortex_error::vortex_err; use vortex_proto::expr as pb; -use vortex_vector::Vector; +use vortex_vector::Datum; +use vortex_vector::ScalarOps; use vortex_vector::VectorOps; use crate::ArrayRef; use crate::ToCanonical; use crate::compute::mask; +use crate::expr::Arity; use crate::expr::ChildName; use crate::expr::ExecutionArgs; use crate::expr::ExprId; use crate::expr::Expression; -use crate::expr::ExpressionView; +use crate::expr::Pack; +use crate::expr::SimplifyCtx; use crate::expr::StatsCatalog; use crate::expr::VTable; use crate::expr::VTableExt; @@ -36,13 +36,13 @@ use crate::expr::stats::Stat; pub struct GetItem; impl VTable for GetItem { - type Instance = FieldName; + type Options = FieldName; fn id(&self) -> ExprId { ExprId::from("vortex.get_item") } - fn serialize(&self, instance: &Self::Instance) -> VortexResult>> { + fn serialize(&self, instance: &Self::Options) -> VortexResult>> { Ok(Some( pb::GetItemOpts { path: instance.to_string(), @@ -51,44 +51,39 @@ impl VTable for GetItem { )) } - fn deserialize(&self, metadata: &[u8]) -> VortexResult> { + fn deserialize(&self, metadata: &[u8]) -> VortexResult { let opts = pb::GetItemOpts::decode(metadata)?; - Ok(Some(FieldName::from(opts.path))) + Ok(FieldName::from(opts.path)) } - fn validate(&self, expr: &ExpressionView) -> VortexResult<()> { - if expr.children().len() != 1 { - vortex_bail!( - "GetItem expression requires exactly 1 child, got {}", - expr.children().len() - ); - } - Ok(()) + fn arity(&self, _field_name: &FieldName) -> Arity { + Arity::Exact(1) } - fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName { + fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName { match child_idx { 0 => ChildName::from("input"), _ => unreachable!("Invalid child index {} for GetItem expression", child_idx), } } - fn fmt_sql(&self, expr: &ExpressionView, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt_sql( + &self, + field_name: &FieldName, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { expr.children()[0].fmt_sql(f)?; - write!(f, ".{}", expr.data()) - } - - fn fmt_data(&self, instance: &Self::Instance, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "\"{}\"", instance) + write!(f, ".{}", field_name) } - fn return_dtype(&self, expr: &ExpressionView, scope: &DType) -> VortexResult { - let struct_dtype = expr.children()[0].return_dtype(scope)?; + fn return_dtype(&self, field_name: &FieldName, arg_dtypes: &[DType]) -> VortexResult { + let struct_dtype = &arg_dtypes[0]; let field_dtype = struct_dtype .as_struct_fields_opt() - .and_then(|st| st.field(expr.data())) + .and_then(|st| st.field(field_name)) .ok_or_else(|| { - vortex_err!("Couldn't find the {} field in the input scope", expr.data()) + vortex_err!("Couldn't find the {} field in the input scope", field_name) })?; // Match here to avoid cloning the dtype if nullability doesn't need to change @@ -102,9 +97,14 @@ impl VTable for GetItem { Ok(field_dtype) } - fn evaluate(&self, expr: &ExpressionView, scope: &ArrayRef) -> VortexResult { + fn evaluate( + &self, + field_name: &FieldName, + expr: &Expression, + scope: &ArrayRef, + ) -> VortexResult { let input = expr.children()[0].evaluate(scope)?.to_struct(); - let field = input.field_by_name(expr.data()).cloned()?; + let field = input.field_by_name(field_name).cloned()?; match input.dtype().nullability() { Nullability::NonNullable => Ok(field), @@ -112,9 +112,60 @@ impl VTable for GetItem { } } + fn execute(&self, field_name: &FieldName, mut args: ExecutionArgs) -> VortexResult { + let struct_dtype = args.dtypes[0] + .as_struct_fields_opt() + .ok_or_else(|| vortex_err!("Expected struct dtype for child of GetItem expression"))?; + let field_idx = struct_dtype + .find(field_name) + .ok_or_else(|| vortex_err!("Field {} not found in struct dtype", field_name))?; + + match args.datums.pop().vortex_expect("missing input") { + Datum::Scalar(s) => { + let mut field = s.as_struct().field(field_idx); + field.mask_validity(s.is_valid()); + Ok(Datum::Scalar(field)) + } + Datum::Vector(v) => { + let mut field = v.as_struct().fields()[field_idx].clone(); + field.mask_validity(v.validity()); + Ok(Datum::Vector(field)) + } + } + } + + fn simplify( + &self, + field_name: &FieldName, + expr: &Expression, + _ctx: &dyn SimplifyCtx, + ) -> VortexResult> { + let child = expr.child(0); + + // If the child is a Pack expression, we can directly return the corresponding child. + if let Some(pack) = child.as_opt::() { + let idx = pack + .names + .iter() + .position(|name| name == field_name) + .ok_or_else(|| { + vortex_err!( + "Cannot find field {} in pack fields {:?}", + field_name, + pack.names + ) + })?; + + return Ok(Some(child.child(idx).clone())); + } + + Ok(None) + } + fn stat_expression( &self, - expr: &ExpressionView, + field_name: &FieldName, + _expr: &Expression, stat: Stat, catalog: &dyn StatsCatalog, ) -> Option { @@ -126,36 +177,15 @@ impl VTable for GetItem { // TODO(ngates): this is a bug whereby we may return stats for a nested field of the same // name as a field in the root struct. This should be resolved with upcoming change to // falsify expressions, but for now I'm preserving the existing buggy behavior. - catalog.stats_ref(&FieldPath::from_name(expr.data().clone()), stat) - } - - fn execute(&self, field_name: &FieldName, mut args: ExecutionArgs) -> VortexResult { - let struct_dtype = args.dtypes[0] - .as_struct_fields_opt() - .ok_or_else(|| vortex_err!("Expected struct dtype for child of GetItem expression"))?; - let field_idx = struct_dtype - .find(field_name) - .ok_or_else(|| vortex_err!("Field {} not found in struct dtype", field_name))?; - - let struct_vector = args - .vectors - .pop() - .vortex_expect("missing input") - .into_struct(); - - // We must intersect the validity with that of the parent struct - let mut field = struct_vector.fields()[field_idx].clone(); - field.mask_validity(struct_vector.validity()); - - Ok(field) + catalog.stats_ref(&FieldPath::from_name(field_name.clone()), stat) } // This will apply struct nullability field. We could add a dtype?? - fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool { + fn is_null_sensitive(&self, _field_name: &FieldName) -> bool { true } - fn is_fallible(&self, _instance: &Self::Instance) -> bool { + fn is_fallible(&self, _field_name: &FieldName) -> bool { // If this type-checks its infallible. false } @@ -191,13 +221,18 @@ mod tests { use vortex_dtype::DType; use vortex_dtype::FieldNames; use vortex_dtype::Nullability; - use vortex_dtype::PType::I32; + use vortex_dtype::Nullability::NonNullable; + use vortex_dtype::PType; + use vortex_dtype::StructFields; use vortex_scalar::Scalar; - use super::get_item; use crate::Array; use crate::IntoArray; use crate::arrays::StructArray; + use crate::expr::exprs::binary::checked_add; + use crate::expr::exprs::get_item::get_item; + use crate::expr::exprs::literal::lit; + use crate::expr::exprs::pack::pack; use crate::expr::exprs::root::root; use crate::validity::Validity; @@ -214,7 +249,7 @@ mod tests { let st = test_array(); let get_item = get_item("a", root()); let item = get_item.evaluate(&st.to_array()).unwrap(); - assert_eq!(item.dtype(), &DType::from(I32)) + assert_eq!(item.dtype(), &DType::from(PType::I32)) } #[test] @@ -239,7 +274,70 @@ mod tests { let item = get_item.evaluate(&st).unwrap(); assert_eq!( item.scalar_at(0), - Scalar::null(DType::Primitive(I32, Nullability::Nullable)) + Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)) ); } + + #[test] + fn test_pack_get_item_rule() { + // Create: pack(a: lit(1), b: lit(2)).get_item("b") + let pack_expr = pack([("a", lit(1)), ("b", lit(2))], NonNullable); + let get_item_expr = get_item("b", pack_expr); + + let result = get_item_expr + .simplify(&DType::Struct(StructFields::empty(), NonNullable)) + .unwrap(); + + assert_eq!(result, lit(2)); + } + + #[test] + fn test_multi_level_pack_get_item_simplify() { + let inner_pack = pack([("a", lit(1)), ("b", lit(2))], NonNullable); + let get_a = get_item("a", inner_pack); + + let outer_pack = pack([("x", get_a), ("y", lit(3)), ("z", lit(4))], NonNullable); + let get_z = get_item("z", outer_pack); + + let dtype = DType::Primitive(PType::I32, NonNullable); + + let result = get_z.simplify(&dtype).unwrap(); + assert_eq!(result, lit(4)); + } + + #[test] + fn test_deeply_nested_pack_get_item() { + let innermost = pack([("a", lit(42))], NonNullable); + let get_a = get_item("a", innermost); + + let level2 = pack([("b", get_a)], NonNullable); + let get_b = get_item("b", level2); + + let level3 = pack([("c", get_b)], NonNullable); + let get_c = get_item("c", level3); + + let outermost = pack([("final", get_c)], NonNullable); + let get_final = get_item("final", outermost); + + let dtype = DType::Primitive(PType::I32, NonNullable); + + let result = get_final.simplify(&dtype).unwrap(); + assert_eq!(result, lit(42)); + } + + #[test] + fn test_partial_pack_get_item_simplify() { + let inner_pack = pack([("x", lit(1)), ("y", lit(2))], NonNullable); + let get_x = get_item("x", inner_pack); + let add_expr = checked_add(get_x, lit(10)); + + let outer_pack = pack([("result", add_expr)], NonNullable); + let get_result = get_item("result", outer_pack); + + let dtype = DType::Primitive(PType::I32, NonNullable); + + let result = get_result.simplify(&dtype).unwrap(); + let expected = checked_add(lit(1), lit(10)); + assert_eq!(&result, &expected); + } } diff --git a/vortex-array/src/expr/exprs/get_item/transform.rs b/vortex-array/src/expr/exprs/get_item/transform.rs deleted file mode 100644 index 82c74f54095..00000000000 --- a/vortex-array/src/expr/exprs/get_item/transform.rs +++ /dev/null @@ -1,139 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use vortex_error::VortexResult; - -use crate::expr::Expression; -use crate::expr::ExpressionView; -use crate::expr::exprs::get_item::GetItem; -use crate::expr::exprs::pack::Pack; -use crate::expr::transform::rules::ReduceRule; -use crate::expr::transform::rules::RuleContext; - -/// Rewrite rule: `pack(l_1: e_1, ..., l_i: e_i, ..., l_n: e_n).get_item(l_i) = e_i` -/// -/// Simplifies accessing a field from a pack expression by directly returning the field's -/// expression instead of materializing the pack. -#[derive(Debug, Default)] -pub struct PackGetItemRule; - -impl ReduceRule for PackGetItemRule { - fn reduce( - &self, - get_item: &ExpressionView, - _ctx: &RuleContext, - ) -> VortexResult> { - if let Some(pack) = get_item.child(0).as_opt::() { - let field_expr = pack.field(get_item.data())?; - return Ok(Some(field_expr)); - } - - Ok(None) - } -} - -#[cfg(test)] -mod tests { - use vortex_dtype::DType; - use vortex_dtype::Nullability::NonNullable; - use vortex_dtype::PType; - - use super::PackGetItemRule; - use crate::expr::exprs::binary::checked_add; - use crate::expr::exprs::get_item::GetItem; - use crate::expr::exprs::get_item::get_item; - use crate::expr::exprs::literal::lit; - use crate::expr::exprs::pack::pack; - use crate::expr::session::ExprSession; - use crate::expr::transform::ExprOptimizer; - use crate::expr::transform::rules::ReduceRule; - use crate::expr::transform::rules::RuleContext; - - #[test] - fn test_pack_get_item_rule() { - // Create: pack(a: lit(1), b: lit(2)).get_item("b") - let pack_expr = pack([("a", lit(1)), ("b", lit(2))], NonNullable); - let get_item_expr = get_item("b", pack_expr); - - let get_item_view = get_item_expr.as_::(); - let result = PackGetItemRule - .reduce(&get_item_view, &RuleContext) - .unwrap(); - - assert!(result.is_some()); - assert_eq!(&result.unwrap(), &lit(2)); - } - - #[test] - fn test_pack_get_item_rule_no_match() { - // Create: get_item("x", lit(42)) - not a pack child - let lit_expr = lit(42); - let get_item_expr = get_item("x", lit_expr); - - let get_item_view = get_item_expr.as_::(); - let result = PackGetItemRule - .reduce(&get_item_view, &RuleContext) - .unwrap(); - - assert!(result.is_none()); - } - - #[test] - fn test_multi_level_pack_get_item_simplify() { - let inner_pack = pack([("a", lit(1)), ("b", lit(2))], NonNullable); - let get_a = get_item("a", inner_pack); - - let outer_pack = pack([("x", get_a), ("y", lit(3)), ("z", lit(4))], NonNullable); - let get_z = get_item("z", outer_pack); - - let dtype = DType::Primitive(PType::I32, NonNullable); - - let session = ExprSession::default(); - let optimizer = ExprOptimizer::new(&session); - let result = optimizer.optimize_typed(get_z, &dtype).unwrap(); - - assert_eq!(&result, &lit(4)); - } - - #[test] - fn test_deeply_nested_pack_get_item() { - let innermost = pack([("a", lit(42))], NonNullable); - let get_a = get_item("a", innermost); - - let level2 = pack([("b", get_a)], NonNullable); - let get_b = get_item("b", level2); - - let level3 = pack([("c", get_b)], NonNullable); - let get_c = get_item("c", level3); - - let outermost = pack([("final", get_c)], NonNullable); - let get_final = get_item("final", outermost); - - let dtype = DType::Primitive(PType::I32, NonNullable); - - let session = ExprSession::default(); - let optimizer = ExprOptimizer::new(&session); - let result = optimizer.optimize_typed(get_final, &dtype).unwrap(); - - assert_eq!(&result, &lit(42)); - } - - #[test] - fn test_partial_pack_get_item_simplify() { - let inner_pack = pack([("x", lit(1)), ("y", lit(2))], NonNullable); - let get_x = get_item("x", inner_pack); - let add_expr = checked_add(get_x, lit(10)); - - let outer_pack = pack([("result", add_expr)], NonNullable); - let get_result = get_item("result", outer_pack); - - let dtype = DType::Primitive(PType::I32, NonNullable); - - let session = ExprSession::default(); - let optimizer = ExprOptimizer::new(&session); - let result = optimizer.optimize_typed(get_result, &dtype).unwrap(); - - let expected = checked_add(lit(1), lit(10)); - assert_eq!(&result, &expected); - } -} diff --git a/vortex-array/src/expr/exprs/is_null.rs b/vortex-array/src/expr/exprs/is_null.rs index ba3dacd6f7b..8205179a498 100644 --- a/vortex-array/src/expr/exprs/is_null.rs +++ b/vortex-array/src/expr/exprs/is_null.rs @@ -4,15 +4,15 @@ use std::fmt::Formatter; use std::ops::Not; -use is_null::IsNullFn; use vortex_dtype::DType; use vortex_dtype::Nullability; use vortex_error::VortexExpect; use vortex_error::VortexResult; -use vortex_error::vortex_bail; use vortex_mask::Mask; -use vortex_vector::Vector; +use vortex_vector::Datum; +use vortex_vector::ScalarOps; use vortex_vector::VectorOps; +use vortex_vector::bool::BoolScalar; use vortex_vector::bool::BoolVector; use crate::Array; @@ -20,67 +20,69 @@ use crate::ArrayRef; use crate::IntoArray; use crate::arrays::BoolArray; use crate::arrays::ConstantArray; +use crate::expr::Arity; use crate::expr::ChildName; +use crate::expr::EmptyOptions; use crate::expr::ExecutionArgs; use crate::expr::ExprId; use crate::expr::Expression; -use crate::expr::ExpressionView; -use crate::expr::ScalarFnExprExt; use crate::expr::StatsCatalog; use crate::expr::VTable; use crate::expr::VTableExt; use crate::expr::exprs::binary::eq; use crate::expr::exprs::literal::lit; -use crate::expr::functions::EmptyOptions; use crate::expr::stats::Stat; -use crate::scalar_fns::is_null; /// Expression that checks for null values. pub struct IsNull; impl VTable for IsNull { - type Instance = (); + type Options = EmptyOptions; fn id(&self) -> ExprId { ExprId::new_ref("is_null") } - fn serialize(&self, _instance: &Self::Instance) -> VortexResult>> { + fn serialize(&self, _instance: &Self::Options) -> VortexResult>> { Ok(Some(vec![])) } - fn deserialize(&self, _metadata: &[u8]) -> VortexResult> { - Ok(Some(())) + fn deserialize(&self, _metadata: &[u8]) -> VortexResult { + Ok(EmptyOptions) } - fn validate(&self, expr: &ExpressionView) -> VortexResult<()> { - if expr.children().len() != 1 { - vortex_bail!( - "IsNull expression expects exactly one child, got {}", - expr.children().len() - ); - } - Ok(()) + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(1) } - fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName { + fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName { match child_idx { 0 => ChildName::from("input"), _ => unreachable!("Invalid child index {} for IsNull expression", child_idx), } } - fn fmt_sql(&self, expr: &ExpressionView, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt_sql( + &self, + _options: &Self::Options, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { write!(f, "is_null(")?; expr.child(0).fmt_sql(f)?; write!(f, ")") } - fn return_dtype(&self, _expr: &ExpressionView, _scope: &DType) -> VortexResult { + fn return_dtype(&self, _options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult { Ok(DType::Bool(Nullability::NonNullable)) } - fn evaluate(&self, expr: &ExpressionView, scope: &ArrayRef) -> VortexResult { + fn evaluate( + &self, + _options: &Self::Options, + expr: &Expression, + scope: &ArrayRef, + ) -> VortexResult { let array = expr.child(0).evaluate(scope)?; match array.validity_mask() { Mask::AllTrue(len) => Ok(ConstantArray::new(false, len).into_array()), @@ -89,35 +91,33 @@ impl VTable for IsNull { } } + fn execute(&self, _data: &Self::Options, mut args: ExecutionArgs) -> VortexResult { + let child = args.datums.pop().vortex_expect("Missing input child"); + Ok(match child { + Datum::Scalar(s) => Datum::Scalar(BoolScalar::new(Some(s.is_invalid())).into()), + Datum::Vector(v) => Datum::Vector( + BoolVector::new(v.validity().to_bit_buffer().not(), Mask::new_true(v.len())).into(), + ), + }) + } + fn stat_falsification( &self, - expr: &ExpressionView, + _options: &Self::Options, + expr: &Expression, catalog: &dyn StatsCatalog, ) -> Option { let null_count_expr = expr.child(0).stat_expression(Stat::NullCount, catalog)?; Some(eq(null_count_expr, lit(0u64))) } - fn execute(&self, _data: &Self::Instance, mut args: ExecutionArgs) -> VortexResult { - let child = args.vectors.pop().vortex_expect("Missing input child"); - Ok(BoolVector::new( - child.validity().to_bit_buffer().not(), - Mask::new_true(child.len()), - ) - .into()) - } - - fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool { + fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { true } - fn is_fallible(&self, _instance: &Self::Instance) -> bool { + fn is_fallible(&self, _instance: &Self::Options) -> bool { false } - - fn expr_v2(&self, view: &ExpressionView) -> VortexResult { - ScalarFnExprExt::try_new_expr(&IsNullFn, EmptyOptions, view.children().clone()) - } } /// Creates an expression that checks for null values. @@ -129,7 +129,7 @@ impl VTable for IsNull { /// let expr = is_null(root()); /// ``` pub fn is_null(child: Expression) -> Expression { - IsNull.new_expr((), vec![child]) + IsNull.new_expr(EmptyOptions, vec![child]) } #[cfg(test)] @@ -279,6 +279,6 @@ mod tests { #[test] fn test_is_null_sensitive() { // is_null itself is null-sensitive - assert!(is_null(col("a")).is_null_sensitive()); + assert!(is_null(col("a")).signature().is_null_sensitive()); } } diff --git a/vortex-array/src/expr/exprs/like.rs b/vortex-array/src/expr/exprs/like.rs index 65f6d823dce..d614e2703f7 100644 --- a/vortex-array/src/expr/exprs/like.rs +++ b/vortex-array/src/expr/exprs/like.rs @@ -8,14 +8,16 @@ use vortex_dtype::DType; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_proto::expr as pb; +use vortex_vector::Datum; use crate::ArrayRef; use crate::compute::LikeOptions; use crate::compute::like as like_compute; +use crate::expr::Arity; use crate::expr::ChildName; +use crate::expr::ExecutionArgs; use crate::expr::ExprId; use crate::expr::Expression; -use crate::expr::ExpressionView; use crate::expr::VTable; use crate::expr::VTableExt; @@ -23,13 +25,13 @@ use crate::expr::VTableExt; pub struct Like; impl VTable for Like { - type Instance = LikeOptions; + type Options = LikeOptions; fn id(&self) -> ExprId { ExprId::from("vortex.like") } - fn serialize(&self, instance: &Self::Instance) -> VortexResult>> { + fn serialize(&self, instance: &Self::Options) -> VortexResult>> { Ok(Some( pb::LikeOpts { negated: instance.negated, @@ -39,25 +41,19 @@ impl VTable for Like { )) } - fn deserialize(&self, metadata: &[u8]) -> VortexResult> { + fn deserialize(&self, metadata: &[u8]) -> VortexResult { let opts = pb::LikeOpts::decode(metadata)?; - Ok(Some(LikeOptions { + Ok(LikeOptions { negated: opts.negated, case_insensitive: opts.case_insensitive, - })) + }) } - fn validate(&self, expr: &ExpressionView) -> VortexResult<()> { - if expr.children().len() != 2 { - vortex_bail!( - "Like expression requires exactly 2 children, got {}", - expr.children().len() - ); - } - Ok(()) + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(2) } - fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName { + fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName { match child_idx { 0 => ChildName::from("child"), 1 => ChildName::from("pattern"), @@ -65,12 +61,17 @@ impl VTable for Like { } } - fn fmt_sql(&self, expr: &ExpressionView, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt_sql( + &self, + options: &Self::Options, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { expr.child(0).fmt_sql(f)?; - if expr.data().negated { + if options.negated { write!(f, " not")?; } - if expr.data().case_insensitive { + if options.case_insensitive { write!(f, " ilike ")?; } else { write!(f, " like ")?; @@ -78,9 +79,9 @@ impl VTable for Like { expr.child(1).fmt_sql(f) } - fn return_dtype(&self, expr: &ExpressionView, scope: &DType) -> VortexResult { - let input = expr.children()[0].return_dtype(scope)?; - let pattern = expr.children()[1].return_dtype(scope)?; + fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { + let input = &arg_dtypes[0]; + let pattern = &arg_dtypes[1]; if !input.is_utf8() { vortex_bail!("LIKE expression requires UTF8 input dtype, got {}", input); @@ -97,13 +98,22 @@ impl VTable for Like { )) } - fn evaluate(&self, expr: &ExpressionView, scope: &ArrayRef) -> VortexResult { + fn evaluate( + &self, + options: &Self::Options, + expr: &Expression, + scope: &ArrayRef, + ) -> VortexResult { let child = expr.child(0).evaluate(scope)?; let pattern = expr.child(1).evaluate(scope)?; - like_compute(&child, &pattern, *expr.data()) + like_compute(&child, &pattern, *options) + } + + fn execute(&self, _data: &Self::Options, _args: ExecutionArgs) -> VortexResult { + todo!() } - fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool { + fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { false } } diff --git a/vortex-array/src/expr/exprs/list_contains.rs b/vortex-array/src/expr/exprs/list_contains.rs index 8a4664b8e2a..e0b5d2ebc31 100644 --- a/vortex-array/src/expr/exprs/list_contains.rs +++ b/vortex-array/src/expr/exprs/list_contains.rs @@ -2,17 +2,21 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use std::fmt::Formatter; +use std::ops::BitOr; use vortex_dtype::DType; use vortex_error::VortexResult; use vortex_error::vortex_bail; +use vortex_vector::Datum; use crate::ArrayRef; use crate::compute::list_contains as compute_list_contains; +use crate::expr::Arity; use crate::expr::ChildName; +use crate::expr::EmptyOptions; +use crate::expr::ExecutionArgs; use crate::expr::ExprId; use crate::expr::Expression; -use crate::expr::ExpressionView; use crate::expr::StatsCatalog; use crate::expr::VTable; use crate::expr::VTableExt; @@ -26,31 +30,25 @@ use crate::expr::exprs::literal::lit; pub struct ListContains; impl VTable for ListContains { - type Instance = (); + type Options = EmptyOptions; fn id(&self) -> ExprId { ExprId::from("vortex.list.contains") } - fn serialize(&self, _instance: &Self::Instance) -> VortexResult>> { + fn serialize(&self, _instance: &Self::Options) -> VortexResult>> { Ok(Some(vec![])) } - fn deserialize(&self, _metadata: &[u8]) -> VortexResult> { - Ok(Some(())) + fn deserialize(&self, _metadata: &[u8]) -> VortexResult { + Ok(EmptyOptions) } - fn validate(&self, expr: &ExpressionView) -> VortexResult<()> { - if expr.children().len() != 2 { - vortex_bail!( - "ListContains expression requires exactly 2 children, got {}", - expr.children().len() - ); - } - Ok(()) + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(2) } - fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName { + fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName { match child_idx { 0 => ChildName::from("list"), 1 => ChildName::from("needle"), @@ -60,8 +58,12 @@ impl VTable for ListContains { ), } } - - fn fmt_sql(&self, expr: &ExpressionView, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt_sql( + &self, + _options: &Self::Options, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { write!(f, "contains(")?; expr.child(0).fmt_sql(f)?; write!(f, ", ")?; @@ -69,9 +71,9 @@ impl VTable for ListContains { write!(f, ")") } - fn return_dtype(&self, expr: &ExpressionView, scope: &DType) -> VortexResult { - let list_dtype = expr.child(0).return_dtype(scope)?; - let value_dtype = expr.child(1).return_dtype(scope)?; + fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { + let list_dtype = &arg_dtypes[0]; + let needle_dtype = &arg_dtypes[0]; let nullability = match list_dtype { DType::List(_, list_nullability) => list_nullability, @@ -81,38 +83,52 @@ impl VTable for ListContains { list_dtype ); } - } | value_dtype.nullability(); + } + .bitor(needle_dtype.nullability()); Ok(DType::Bool(nullability)) } - fn evaluate(&self, expr: &ExpressionView, scope: &ArrayRef) -> VortexResult { + fn evaluate( + &self, + _options: &Self::Options, + expr: &Expression, + scope: &ArrayRef, + ) -> VortexResult { let list_array = expr.child(0).evaluate(scope)?; let value_array = expr.child(1).evaluate(scope)?; compute_list_contains(list_array.as_ref(), value_array.as_ref()) } + fn execute(&self, _data: &Self::Options, _args: ExecutionArgs) -> VortexResult { + todo!() + } + fn stat_falsification( &self, - expr: &ExpressionView, + _options: &Self::Options, + expr: &Expression, catalog: &dyn StatsCatalog, ) -> Option { + let list = expr.child(0); + let needle = expr.child(1); + // falsification(contains([1,2,5], x)) => // falsification(x != 1) and falsification(x != 2) and falsification(x != 5) - let min = expr.list().stat_min(catalog)?; - let max = expr.list().stat_max(catalog)?; + let min = list.stat_min(catalog)?; + let max = list.stat_max(catalog)?; // If the list is constant when we can compare each element to the value if min == max { let list_ = min .as_opt::() - .and_then(|l| l.data().as_list_opt()) + .and_then(|l| l.as_list_opt()) .and_then(|l| l.elements())?; if list_.is_empty() { // contains([], x) is always false. return Some(lit(true)); } - let value_max = expr.needle().stat_max(catalog)?; - let value_min = expr.needle().stat_min(catalog)?; + let value_max = needle.stat_max(catalog)?; + let value_min = needle.stat_min(catalog)?; return list_ .iter() @@ -129,7 +145,7 @@ impl VTable for ListContains { } // Nullability matters for contains([], x) where x is false. - fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool { + fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { true } } @@ -143,17 +159,7 @@ impl VTable for ListContains { /// let expr = list_contains(root(), lit(42)); /// ``` pub fn list_contains(list: Expression, value: Expression) -> Expression { - ListContains.new_expr((), [list, value]) -} - -impl ExpressionView<'_, ListContains> { - pub fn list(&self) -> &Expression { - &self.children()[0] - } - - pub fn needle(&self) -> &Expression { - &self.children()[1] - } + ListContains.new_expr(EmptyOptions, [list, value]) } #[cfg(test)] diff --git a/vortex-array/src/expr/exprs/literal.rs b/vortex-array/src/expr/exprs/literal.rs index 949871ca601..eb5f5e089d3 100644 --- a/vortex-array/src/expr/exprs/literal.rs +++ b/vortex-array/src/expr/exprs/literal.rs @@ -7,19 +7,20 @@ use prost::Message; use vortex_dtype::DType; use vortex_dtype::match_each_float_ptype; use vortex_error::VortexResult; -use vortex_error::vortex_bail; use vortex_error::vortex_err; use vortex_proto::expr as pb; use vortex_scalar::Scalar; +use vortex_vector::Datum; use crate::Array; use crate::ArrayRef; use crate::IntoArray; use crate::arrays::ConstantArray; +use crate::expr::Arity; use crate::expr::ChildName; +use crate::expr::ExecutionArgs; use crate::expr::ExprId; use crate::expr::Expression; -use crate::expr::ExpressionView; use crate::expr::StatsCatalog; use crate::expr::VTable; use crate::expr::VTableExt; @@ -29,13 +30,13 @@ use crate::expr::stats::Stat; pub struct Literal; impl VTable for Literal { - type Instance = Scalar; + type Options = Scalar; fn id(&self) -> ExprId { ExprId::new_ref("vortex.literal") } - fn serialize(&self, instance: &Self::Instance) -> VortexResult>> { + fn serialize(&self, instance: &Self::Options) -> VortexResult>> { Ok(Some( pb::LiteralOpts { value: Some(instance.as_ref().into()), @@ -44,49 +45,52 @@ impl VTable for Literal { )) } - fn deserialize(&self, metadata: &[u8]) -> VortexResult> { + fn deserialize(&self, metadata: &[u8]) -> VortexResult { let ops = pb::LiteralOpts::decode(metadata)?; - Ok(Some( - ops.value - .as_ref() - .ok_or_else(|| vortex_err!("Literal metadata missing value"))? - .try_into()?, - )) + ops.value + .as_ref() + .ok_or_else(|| vortex_err!("Literal metadata missing value"))? + .try_into() } - fn validate(&self, expr: &ExpressionView) -> VortexResult<()> { - if !expr.children().is_empty() { - vortex_bail!( - "Literal expression does not have children, got: {:?}", - expr.children() - ); - } - Ok(()) + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(0) } - fn child_name(&self, _instance: &Self::Instance, _child_idx: usize) -> ChildName { + fn child_name(&self, _instance: &Self::Options, _child_idx: usize) -> ChildName { unreachable!() } - fn fmt_sql(&self, expr: &ExpressionView, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", expr.data()) + fn fmt_sql( + &self, + scalar: &Scalar, + _expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { + write!(f, "{}", scalar) } - fn fmt_data(&self, instance: &Self::Instance, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", instance) + fn return_dtype(&self, options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult { + Ok(options.dtype().clone()) } - fn return_dtype(&self, expr: &ExpressionView, _scope: &DType) -> VortexResult { - Ok(expr.data().dtype().clone()) + fn evaluate( + &self, + scalar: &Scalar, + _expr: &Expression, + scope: &ArrayRef, + ) -> VortexResult { + Ok(ConstantArray::new(scalar.clone(), scope.len()).into_array()) } - fn evaluate(&self, expr: &ExpressionView, scope: &ArrayRef) -> VortexResult { - Ok(ConstantArray::new(expr.data().clone(), scope.len()).into_array()) + fn execute(&self, _data: &Self::Options, _args: ExecutionArgs) -> VortexResult { + todo!() } fn stat_expression( &self, - expr: &ExpressionView, + scalar: &Scalar, + _expr: &Expression, stat: Stat, _catalog: &dyn StatsCatalog, ) -> Option { @@ -96,12 +100,12 @@ impl VTable for Literal { // only currently used for pruning, it doesn't change the outcome. match stat { - Stat::Min | Stat::Max => Some(lit(expr.data().clone())), + Stat::Min | Stat::Max => Some(lit(scalar.clone())), Stat::IsConstant => Some(lit(true)), Stat::NaNCount => { // The NaNCount for a non-float literal is not defined. // For floating point types, the NaNCount is 1 for lit(NaN), and 0 otherwise. - let value = expr.data().as_primitive_opt()?; + let value = scalar.as_primitive_opt()?; if !value.ptype().is_float() { return None; } @@ -115,7 +119,7 @@ impl VTable for Literal { }) } Stat::NullCount => { - if expr.data().is_null() { + if scalar.is_null() { Some(lit(1u64)) } else { Some(lit(0u64)) @@ -127,11 +131,11 @@ impl VTable for Literal { } } - fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool { + fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { false } - fn is_fallible(&self, _instance: &Self::Instance) -> bool { + fn is_fallible(&self, _instance: &Self::Options) -> bool { false } } @@ -149,8 +153,8 @@ impl VTable for Literal { /// /// let number = lit(34i32); /// -/// let literal = number.as_::(); -/// assert_eq!(literal.data(), &Scalar::primitive(34i32, Nullability::NonNullable)); +/// let scalar = number.as_::(); +/// assert_eq!(scalar, &Scalar::primitive(34i32, Nullability::NonNullable)); /// ``` pub fn lit(value: impl Into) -> Expression { Literal.new_expr(value.into(), []) diff --git a/vortex-array/src/expr/exprs/mask.rs b/vortex-array/src/expr/exprs/mask.rs new file mode 100644 index 00000000000..fa1042d61de --- /dev/null +++ b/vortex-array/src/expr/exprs/mask.rs @@ -0,0 +1,122 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::fmt::Formatter; + +use vortex_dtype::DType; +use vortex_dtype::Nullability; +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; +use vortex_vector::BoolDatum; +use vortex_vector::Datum; +use vortex_vector::ScalarOps; +use vortex_vector::VectorMutOps; +use vortex_vector::VectorOps; + +use crate::ArrayRef; +use crate::expr::Arity; +use crate::expr::ChildName; +use crate::expr::EmptyOptions; +use crate::expr::ExecutionArgs; +use crate::expr::ExprId; +use crate::expr::Expression; +use crate::expr::VTable; +use crate::expr::VTableExt; + +/// An expression that masks an input based on a boolean mask. +/// +/// Where the mask is true, the input value is retained; where the mask is false, the output is +/// null. In other words, this performs an intersection of the input's validity with the mask. +pub struct Mask; + +impl VTable for Mask { + type Options = EmptyOptions; + + fn id(&self) -> ExprId { + ExprId::from("vortex.mask") + } + + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(2) + } + + fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName { + match child_idx { + 0 => ChildName::from("input"), + 1 => ChildName::from("mask"), + _ => unreachable!("Invalid child index {} for Mask expression", child_idx), + } + } + + fn fmt_sql( + &self, + _options: &Self::Options, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { + write!(f, "mask(")?; + expr.child(0).fmt_sql(f)?; + write!(f, ", ")?; + expr.child(1).fmt_sql(f)?; + write!(f, ")") + } + + fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { + vortex_ensure!( + arg_dtypes[1] == DType::Bool(Nullability::NonNullable), + "The mask argument to 'mask' must be a non-nullable boolean array, got {}", + arg_dtypes[1] + ); + Ok(arg_dtypes[0].as_nullable()) + } + + fn evaluate( + &self, + _options: &Self::Options, + _expr: &Expression, + _scope: &ArrayRef, + ) -> VortexResult { + todo!() + } + + fn execute(&self, _options: &Self::Options, mut args: ExecutionArgs) -> VortexResult { + let input = args.datums.pop().vortex_expect("Missing input datum"); + let mask = args + .datums + .pop() + .vortex_expect("Missing mask datum") + .into_bool(); + + match (input, mask) { + (Datum::Scalar(input), BoolDatum::Scalar(mask)) => { + let mut result = input; + result.mask_validity(mask.value().vortex_expect("mask is non-nullable")); + Ok(Datum::Scalar(result)) + } + (Datum::Scalar(input), BoolDatum::Vector(mask)) => { + let mut result = input.repeat(args.row_count).freeze(); + result.mask_validity(&vortex_mask::Mask::from(mask.into_bits())); + Ok(Datum::Vector(result)) + } + (Datum::Vector(input_array), BoolDatum::Scalar(mask)) => { + let mut result = input_array; + result.mask_validity(&vortex_mask::Mask::new( + args.row_count, + mask.value().vortex_expect("mask is non-nullable"), + )); + Ok(Datum::Vector(result)) + } + (Datum::Vector(input_array), BoolDatum::Vector(mask)) => { + let mut result = input_array; + result.mask_validity(&vortex_mask::Mask::from(mask.into_bits())); + Ok(Datum::Vector(result)) + } + } + } +} + +/// Creates a mask expression that applies the given boolean mask to the input array. +pub fn mask(array: Expression, mask: Expression) -> Expression { + Mask.new_expr(EmptyOptions, [array, mask]) +} diff --git a/vortex-array/src/expr/exprs/merge/mod.rs b/vortex-array/src/expr/exprs/merge.rs similarity index 76% rename from vortex-array/src/expr/exprs/merge/mod.rs rename to vortex-array/src/expr/exprs/merge.rs index 3d2a2ec1ee5..f0a5372e3f2 100644 --- a/vortex-array/src/expr/exprs/merge/mod.rs +++ b/vortex-array/src/expr/exprs/merge.rs @@ -1,8 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -pub mod transform; - +use std::fmt::Display; use std::fmt::Formatter; use std::hash::Hash; use std::sync::Arc; @@ -12,21 +11,27 @@ use vortex_dtype::DType; use vortex_dtype::FieldNames; use vortex_dtype::Nullability; use vortex_dtype::StructFields; +use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_utils::aliases::hash_set::HashSet; +use vortex_vector::Datum; use crate::Array; use crate::ArrayRef; use crate::IntoArray as _; use crate::ToCanonical; use crate::arrays::StructArray; +use crate::expr::Arity; use crate::expr::ChildName; +use crate::expr::ExecutionArgs; use crate::expr::ExprId; use crate::expr::Expression; -use crate::expr::ExpressionView; +use crate::expr::SimplifyCtx; use crate::expr::VTable; use crate::expr::VTableExt; +use crate::expr::get_item; +use crate::expr::pack; use crate::validity::Validity; /// Merge zero or more expressions that ALL return structs. @@ -38,20 +43,20 @@ use crate::validity::Validity; pub struct Merge; impl VTable for Merge { - type Instance = DuplicateHandling; + type Options = DuplicateHandling; fn id(&self) -> ExprId { ExprId::new_ref("vortex.merge") } - fn serialize(&self, instance: &Self::Instance) -> VortexResult>> { + fn serialize(&self, instance: &Self::Options) -> VortexResult>> { Ok(Some(match instance { DuplicateHandling::RightMost => vec![0x00], DuplicateHandling::Error => vec![0x01], })) } - fn deserialize(&self, metadata: &[u8]) -> VortexResult> { + fn deserialize(&self, metadata: &[u8]) -> VortexResult { let instance = match metadata { [0x00] => DuplicateHandling::RightMost, [0x01] => DuplicateHandling::Error, @@ -59,18 +64,23 @@ impl VTable for Merge { vortex_bail!("invalid metadata for Merge expression"); } }; - Ok(Some(instance)) + Ok(instance) } - fn validate(&self, _expr: &ExpressionView) -> VortexResult<()> { - Ok(()) + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Variadic { min: 0, max: None } } - fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName { + fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName { ChildName::from(Arc::from(format!("{}", child_idx))) } - fn fmt_sql(&self, expr: &ExpressionView, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt_sql( + &self, + _options: &Self::Options, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { write!(f, "merge(")?; for (i, child) in expr.children().iter().enumerate() { child.fmt_sql(f)?; @@ -81,14 +91,13 @@ impl VTable for Merge { write!(f, ")") } - fn return_dtype(&self, expr: &ExpressionView, scope: &DType) -> VortexResult { + fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { let mut field_names = Vec::new(); let mut arrays = Vec::new(); let mut merge_nullability = Nullability::NonNullable; let mut duplicate_names = HashSet::<_>::new(); - for child in expr.children().iter() { - let dtype = child.return_dtype(scope)?; + for dtype in arg_dtypes { let Some(fields) = dtype.as_struct_fields_opt() else { vortex_bail!("merge expects struct input"); }; @@ -109,7 +118,7 @@ impl VTable for Merge { } } - if expr.data() == &DuplicateHandling::Error && !duplicate_names.is_empty() { + if options == &DuplicateHandling::Error && !duplicate_names.is_empty() { vortex_bail!( "merge: duplicate fields in children: {}", duplicate_names.into_iter().format(", ") @@ -122,7 +131,12 @@ impl VTable for Merge { )) } - fn evaluate(&self, expr: &ExpressionView, scope: &ArrayRef) -> VortexResult { + fn evaluate( + &self, + options: &Self::Options, + expr: &Expression, + scope: &ArrayRef, + ) -> VortexResult { // Collect fields in order of appearance. Later fields overwrite earlier fields. let mut field_names = Vec::new(); let mut arrays = Vec::new(); @@ -151,7 +165,7 @@ impl VTable for Merge { } } - if expr.data() == &DuplicateHandling::Error && !duplicate_names.is_empty() { + if options == &DuplicateHandling::Error && !duplicate_names.is_empty() { vortex_bail!( "merge: duplicate fields in children: {}", duplicate_names.into_iter().format(", ") @@ -167,11 +181,68 @@ impl VTable for Merge { ) } - fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool { + fn execute(&self, _data: &Self::Options, _args: ExecutionArgs) -> VortexResult { + todo!() + } + + fn simplify( + &self, + options: &Self::Options, + expr: &Expression, + ctx: &dyn SimplifyCtx, + ) -> VortexResult> { + let merge_dtype = ctx.return_dtype(expr)?; + let mut names = Vec::with_capacity(expr.children().len() * 2); + let mut children = Vec::with_capacity(expr.children().len() * 2); + let mut duplicate_names = HashSet::<_>::new(); + + for child in expr.children().iter() { + let child_dtype = ctx.return_dtype(child)?; + if !child_dtype.is_struct() { + vortex_bail!( + "Merge child must return a non-nullable struct dtype, got {}", + child_dtype + ) + } + + let child_dtype = child_dtype + .as_struct_fields_opt() + .vortex_expect("expected struct"); + + for name in child_dtype.names().iter() { + if let Some(idx) = names.iter().position(|n| n == name) { + duplicate_names.insert(name.clone()); + children[idx] = child.clone(); + } else { + names.push(name.clone()); + children.push(child.clone()); + } + } + + if options == &DuplicateHandling::Error && !duplicate_names.is_empty() { + vortex_bail!( + "merge: duplicate fields in children: {}", + duplicate_names.into_iter().format(", ") + ) + } + } + + let expr = pack( + names + .into_iter() + .zip(children) + .map(|(name, child)| (name.clone(), get_item(name, child))), + merge_dtype.nullability(), + ); + + Ok(Some(expr)) + } + + fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { true } - fn is_fallible(&self, instance: &Self::Instance) -> bool { + fn is_fallible(&self, instance: &Self::Options) -> bool { matches!(instance, DuplicateHandling::Error) } } @@ -186,6 +257,15 @@ pub enum DuplicateHandling { Error, } +impl Display for DuplicateHandling { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + DuplicateHandling::RightMost => write!(f, "RightMost"), + DuplicateHandling::Error => write!(f, "Error"), + } + } +} + /// Creates an expression that merges struct expressions into a single struct. /// /// Combines fields from all input expressions. If field names are duplicated, @@ -212,6 +292,12 @@ pub fn merge_opts( #[cfg(test)] mod tests { use vortex_buffer::buffer; + use vortex_dtype::DType; + use vortex_dtype::Nullability::NonNullable; + use vortex_dtype::PType::I32; + use vortex_dtype::PType::I64; + use vortex_dtype::PType::U32; + use vortex_dtype::PType::U64; use vortex_error::VortexResult; use vortex_error::vortex_bail; @@ -222,6 +308,7 @@ mod tests { use crate::arrays::PrimitiveArray; use crate::arrays::StructArray; use crate::expr::Expression; + use crate::expr::Pack; use crate::expr::exprs::get_item::get_item; use crate::expr::exprs::merge::DuplicateHandling; use crate::expr::exprs::merge::merge_opts; @@ -472,4 +559,28 @@ mod tests { let expr2 = merge(vec![get_item("a", root())]); assert_eq!(expr2.to_string(), "merge($.a)"); } + + #[test] + fn test_remove_merge() { + let dtype = DType::struct_( + [ + ("0", DType::struct_([("a", I32), ("b", I64)], NonNullable)), + ("1", DType::struct_([("b", U32), ("c", U64)], NonNullable)), + ], + NonNullable, + ); + + let e = merge_opts( + [get_item("0", root()), get_item("1", root())], + DuplicateHandling::RightMost, + ); + + let result = e.simplify(&dtype).unwrap(); + + assert!(result.is::()); + assert_eq!( + result.return_dtype(&dtype).unwrap(), + DType::struct_([("a", I32), ("b", U32), ("c", U64)], NonNullable) + ); + } } diff --git a/vortex-array/src/expr/exprs/merge/transform.rs b/vortex-array/src/expr/exprs/merge/transform.rs deleted file mode 100644 index 55d7bee73d6..00000000000 --- a/vortex-array/src/expr/exprs/merge/transform.rs +++ /dev/null @@ -1,125 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use itertools::Itertools as _; -use vortex_error::VortexExpect; -use vortex_error::VortexResult; -use vortex_error::vortex_bail; -use vortex_utils::aliases::hash_set::HashSet; - -use crate::expr::Expression; -use crate::expr::ExpressionView; -use crate::expr::exprs::get_item::get_item; -use crate::expr::exprs::merge::DuplicateHandling; -use crate::expr::exprs::merge::Merge; -use crate::expr::exprs::pack::pack; -use crate::expr::transform::rules::ReduceRule; -use crate::expr::transform::rules::TypedRuleContext; - -/// Rule that removes Merge expressions by converting them to Pack + GetItem. -/// -/// Transforms: `merge([struct1, struct2])` → `pack(field1: get_item("field1", struct1), field2: get_item("field2", struct2), ...)` -#[derive(Debug, Default)] -pub struct RemoveMergeRule; - -impl ReduceRule for RemoveMergeRule { - fn reduce( - &self, - merge: &ExpressionView, - ctx: &TypedRuleContext, - ) -> VortexResult> { - let merge_dtype = merge.return_dtype(ctx.dtype())?; - let mut names = Vec::with_capacity(merge.children().len() * 2); - let mut children = Vec::with_capacity(merge.children().len() * 2); - let mut duplicate_names = HashSet::<_>::new(); - - for child in merge.children().iter() { - let child_dtype = child.return_dtype(ctx.dtype())?; - if !child_dtype.is_struct() { - vortex_bail!( - "Merge child must return a non-nullable struct dtype, got {}", - child_dtype - ) - } - - let child_dtype = child_dtype - .as_struct_fields_opt() - .vortex_expect("expected struct"); - - for name in child_dtype.names().iter() { - if let Some(idx) = names.iter().position(|n| n == name) { - duplicate_names.insert(name.clone()); - children[idx] = child.clone(); - } else { - names.push(name.clone()); - children.push(child.clone()); - } - } - - if merge.data() == &DuplicateHandling::Error && !duplicate_names.is_empty() { - vortex_bail!( - "merge: duplicate fields in children: {}", - duplicate_names.into_iter().format(", ") - ) - } - } - - let expr = pack( - names - .into_iter() - .zip(children) - .map(|(name, child)| (name.clone(), get_item(name, child))), - merge_dtype.nullability(), - ); - - Ok(Some(expr)) - } -} - -#[cfg(test)] -mod tests { - use vortex_dtype::DType; - use vortex_dtype::Nullability::NonNullable; - use vortex_dtype::PType::I32; - use vortex_dtype::PType::I64; - use vortex_dtype::PType::U32; - use vortex_dtype::PType::U64; - - use super::RemoveMergeRule; - use crate::expr::exprs::get_item::get_item; - use crate::expr::exprs::merge::DuplicateHandling; - use crate::expr::exprs::merge::Merge; - use crate::expr::exprs::merge::merge_opts; - use crate::expr::exprs::pack::Pack; - use crate::expr::exprs::root::root; - use crate::expr::transform::rules::ReduceRule; - use crate::expr::transform::rules::TypedRuleContext; - - #[test] - fn test_remove_merge() { - let dtype = DType::struct_( - [ - ("0", DType::struct_([("a", I32), ("b", I64)], NonNullable)), - ("1", DType::struct_([("b", U32), ("c", U64)], NonNullable)), - ], - NonNullable, - ); - - let e = merge_opts( - [get_item("0", root()), get_item("1", root())], - DuplicateHandling::RightMost, - ); - - let ctx = TypedRuleContext::new(dtype.clone()); - let merge_view = e.as_::(); - let result = RemoveMergeRule.reduce(&merge_view, &ctx).unwrap(); - - assert!(result.is_some()); - let result = result.unwrap(); - assert!(result.is::()); - assert_eq!( - result.return_dtype(&dtype).unwrap(), - DType::struct_([("a", I32), ("b", U32), ("c", U64)], NonNullable) - ); - } -} diff --git a/vortex-array/src/expr/exprs/mod.rs b/vortex-array/src/expr/exprs/mod.rs index 7965311aae2..c606b53f5a0 100644 --- a/vortex-array/src/expr/exprs/mod.rs +++ b/vortex-array/src/expr/exprs/mod.rs @@ -10,12 +10,12 @@ pub(crate) mod is_null; pub(crate) mod like; pub(crate) mod list_contains; pub(crate) mod literal; +pub(crate) mod mask; pub(crate) mod merge; pub(crate) mod not; pub(crate) mod operators; pub(crate) mod pack; pub(crate) mod root; -pub(crate) mod scalar_fn; pub(crate) mod select; pub use between::*; @@ -27,10 +27,10 @@ pub use is_null::*; pub use like::*; pub use list_contains::*; pub use literal::*; +pub use mask::*; pub use merge::*; pub use not::*; pub use operators::*; pub use pack::*; pub use root::*; -pub use scalar_fn::*; pub use select::*; diff --git a/vortex-array/src/expr/exprs/not.rs b/vortex-array/src/expr/exprs/not.rs index 952384e14d8..5185af58c97 100644 --- a/vortex-array/src/expr/exprs/not.rs +++ b/vortex-array/src/expr/exprs/not.rs @@ -8,94 +8,92 @@ use vortex_dtype::DType; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; -use vortex_vector::Vector; +use vortex_vector::Datum; use crate::ArrayRef; use crate::compute::invert; +use crate::expr::Arity; use crate::expr::ChildName; +use crate::expr::EmptyOptions; use crate::expr::ExecutionArgs; use crate::expr::ExprId; use crate::expr::Expression; -use crate::expr::ExpressionView; -use crate::expr::ScalarFnExprExt; use crate::expr::VTable; use crate::expr::VTableExt; -use crate::expr::functions::EmptyOptions; -use crate::scalar_fns::not; /// Expression that logically inverts boolean values. pub struct Not; impl VTable for Not { - type Instance = (); + type Options = EmptyOptions; fn id(&self) -> ExprId { - ExprId::new_ref("vortex.not") + ExprId::from("vortex.not") } - fn serialize(&self, _instance: &Self::Instance) -> VortexResult>> { + fn serialize(&self, _options: &Self::Options) -> VortexResult>> { Ok(Some(vec![])) } - fn deserialize(&self, _metadata: &[u8]) -> VortexResult> { - Ok(Some(())) + fn deserialize(&self, _metadata: &[u8]) -> VortexResult { + Ok(EmptyOptions) } - fn validate(&self, expr: &ExpressionView) -> VortexResult<()> { - if expr.children().len() != 1 { - vortex_bail!( - "Not expression expects exactly one child, got {}", - expr.children().len() - ); - } - Ok(()) + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(1) } - fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName { + fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName { match child_idx { 0 => ChildName::from("input"), _ => unreachable!("Invalid child index {} for Not expression", child_idx), } } - fn fmt_sql(&self, expr: &ExpressionView, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt_sql( + &self, + _options: &Self::Options, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { write!(f, "not(")?; expr.child(0).fmt_sql(f)?; write!(f, ")") } - fn return_dtype(&self, expr: &ExpressionView, scope: &DType) -> VortexResult { - let child_dtype = expr.child(0).return_dtype(scope)?; + fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { + let child_dtype = &arg_dtypes[0]; if !matches!(child_dtype, DType::Bool(_)) { vortex_bail!( "Not expression expects a boolean child, got: {}", child_dtype ); } - Ok(child_dtype) + Ok(child_dtype.clone()) } - fn evaluate(&self, expr: &ExpressionView, scope: &ArrayRef) -> VortexResult { + fn evaluate( + &self, + _options: &Self::Options, + expr: &Expression, + scope: &ArrayRef, + ) -> VortexResult { let child_result = expr.child(0).evaluate(scope)?; invert(&child_result) } - fn execute(&self, _data: &Self::Instance, mut args: ExecutionArgs) -> VortexResult { - let child = args.vectors.pop().vortex_expect("Missing input child"); + fn execute(&self, _data: &Self::Options, mut args: ExecutionArgs) -> VortexResult { + let child = args.datums.pop().vortex_expect("Missing input child"); Ok(child.into_bool().not().into()) } - fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool { + fn is_null_sensitive(&self, _options: &Self::Options) -> bool { false } - fn is_fallible(&self, _instance: &Self::Instance) -> bool { + fn is_fallible(&self, _options: &Self::Options) -> bool { false } - - fn expr_v2(&self, view: &ExpressionView) -> VortexResult { - ScalarFnExprExt::try_new_expr(¬::NotFn, EmptyOptions, view.children().clone()) - } } /// Creates an expression that logically inverts boolean values. @@ -107,7 +105,7 @@ impl VTable for Not { /// let expr = not(root()); /// ``` pub fn not(operand: Expression) -> Expression { - Not.new_expr((), vec![operand]) + Not.new_expr(EmptyOptions, vec![operand]) } #[cfg(test)] diff --git a/vortex-array/src/expr/exprs/pack.rs b/vortex-array/src/expr/exprs/pack.rs index a03136ca6f4..92dd76221bb 100644 --- a/vortex-array/src/expr/exprs/pack.rs +++ b/vortex-array/src/expr/exprs/pack.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use std::fmt::Display; use std::fmt::Formatter; use std::hash::Hash; @@ -12,17 +13,17 @@ use vortex_dtype::FieldNames; use vortex_dtype::Nullability; use vortex_dtype::StructFields; use vortex_error::VortexResult; -use vortex_error::vortex_bail; -use vortex_error::vortex_err; use vortex_proto::expr as pb; +use vortex_vector::Datum; use crate::ArrayRef; use crate::IntoArray; use crate::arrays::StructArray; +use crate::expr::Arity; use crate::expr::ChildName; +use crate::expr::ExecutionArgs; use crate::expr::ExprId; use crate::expr::Expression; -use crate::expr::ExpressionView; use crate::expr::VTable; use crate::expr::VTableExt; use crate::validity::Validity; @@ -36,14 +37,25 @@ pub struct PackOptions { pub nullability: Nullability, } +impl Display for PackOptions { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "names: [{}], nullability: {}", + self.names.iter().join(", "), + self.nullability + ) + } +} + impl VTable for Pack { - type Instance = PackOptions; + type Options = PackOptions; fn id(&self) -> ExprId { ExprId::new_ref("vortex.pack") } - fn serialize(&self, instance: &Self::Instance) -> VortexResult>> { + fn serialize(&self, instance: &Self::Options) -> VortexResult>> { Ok(Some( pb::PackOpts { paths: instance.names.iter().map(|n| n.to_string()).collect(), @@ -53,32 +65,24 @@ impl VTable for Pack { )) } - fn deserialize(&self, metadata: &[u8]) -> VortexResult> { + fn deserialize(&self, metadata: &[u8]) -> VortexResult { let opts = pb::PackOpts::decode(metadata)?; let names: FieldNames = opts .paths .iter() .map(|name| FieldName::from(name.as_str())) .collect(); - Ok(Some(PackOptions { + Ok(PackOptions { names, nullability: opts.nullable.into(), - })) + }) } - fn validate(&self, expr: &ExpressionView) -> VortexResult<()> { - let instance = expr.data(); - if expr.children().len() != instance.names.len() { - vortex_bail!( - "Pack expression expects {} children, got {}", - instance.names.len(), - expr.children().len() - ); - } - Ok(()) + fn arity(&self, options: &Self::Options) -> Arity { + Arity::Exact(options.names.len()) } - fn child_name(&self, instance: &Self::Instance, child_idx: usize) -> ChildName { + fn child_name(&self, instance: &Self::Options, child_idx: usize) -> ChildName { match instance.names.get(child_idx) { Some(name) => ChildName::from(name.inner().clone()), None => unreachable!( @@ -89,87 +93,68 @@ impl VTable for Pack { } } - fn fmt_sql(&self, expr: &ExpressionView, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt_sql( + &self, + options: &Self::Options, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { write!(f, "pack(")?; - for (i, (name, child)) in expr - .data() - .names - .iter() - .zip(expr.children().iter()) - .enumerate() - { + for (i, (name, child)) in options.names.iter().zip(expr.children().iter()).enumerate() { write!(f, "{}: ", name)?; child.fmt_sql(f)?; - if i + 1 < expr.data().names.len() { + if i + 1 < options.names.len() { write!(f, ", ")?; } } - write!(f, "){}", expr.data().nullability) + write!(f, "){}", options.nullability) } - fn return_dtype(&self, expr: &ExpressionView, scope: &DType) -> VortexResult { - let value_dtypes = expr - .children() - .iter() - .map(|child| child.return_dtype(scope)) - .collect::>>()?; + fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { Ok(DType::Struct( - StructFields::new(expr.data().names.clone(), value_dtypes), - expr.data().nullability, + StructFields::new(options.names.clone(), arg_dtypes.to_vec()), + options.nullability, )) } - fn evaluate(&self, expr: &ExpressionView, scope: &ArrayRef) -> VortexResult { + fn evaluate( + &self, + options: &Self::Options, + expr: &Expression, + scope: &ArrayRef, + ) -> VortexResult { let len = scope.len(); let value_arrays = expr .children() .iter() - .zip_eq(expr.data().names.iter()) + .zip_eq(options.names.iter()) .map(|(child_expr, name)| { child_expr .evaluate(scope) .map_err(|e| e.with_context(format!("Can't evaluate '{name}'"))) }) .process_results(|it| it.collect::>())?; - let validity = match expr.data().nullability { + let validity = match options.nullability { Nullability::NonNullable => Validity::NonNullable, Nullability::Nullable => Validity::AllValid, }; - Ok( - StructArray::try_new(expr.data().names.clone(), value_arrays, len, validity)? - .into_array(), - ) + Ok(StructArray::try_new(options.names.clone(), value_arrays, len, validity)?.into_array()) + } + + fn execute(&self, _data: &Self::Options, _args: ExecutionArgs) -> VortexResult { + todo!() } // This applies a nullability - fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool { + fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { true } - fn is_fallible(&self, _instance: &Self::Instance) -> bool { + fn is_fallible(&self, _instance: &Self::Options) -> bool { false } } -impl ExpressionView<'_, Pack> { - pub fn field(&self, field_name: &FieldName) -> VortexResult { - let idx = self - .data() - .names - .iter() - .position(|name| name == field_name) - .ok_or_else(|| { - vortex_err!( - "Cannot find field {} in pack fields {:?}", - field_name, - self.data().names - ) - })?; - - Ok(self.child(idx).clone()) - } -} - /// Creates an expression that packs values into a struct with named fields. /// /// ```rust diff --git a/vortex-array/src/expr/exprs/root.rs b/vortex-array/src/expr/exprs/root.rs index d3af4e1e089..2433ada7ea0 100644 --- a/vortex-array/src/expr/exprs/root.rs +++ b/vortex-array/src/expr/exprs/root.rs @@ -8,13 +8,14 @@ use vortex_dtype::FieldPath; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; -use vortex_vector::Vector; +use vortex_vector::Datum; use crate::ArrayRef; +use crate::expr::Arity; use crate::expr::ChildName; +use crate::expr::EmptyOptions; use crate::expr::ExecutionArgs; use crate::expr::ExprId; -use crate::expr::ExpressionView; use crate::expr::StatsCatalog; use crate::expr::VTable; use crate::expr::VTableExt; @@ -26,67 +27,72 @@ use crate::expr::stats::Stat; pub struct Root; impl VTable for Root { - type Instance = (); + type Options = EmptyOptions; fn id(&self) -> ExprId { ExprId::from("vortex.root") } - fn serialize(&self, _instance: &Self::Instance) -> VortexResult>> { + fn serialize(&self, _instance: &Self::Options) -> VortexResult>> { Ok(Some(vec![])) } - fn deserialize(&self, _metadata: &[u8]) -> VortexResult> { - Ok(Some(())) + fn deserialize(&self, _metadata: &[u8]) -> VortexResult { + Ok(EmptyOptions) } - fn validate(&self, expr: &ExpressionView) -> VortexResult<()> { - if !expr.children().is_empty() { - vortex_bail!( - "Root expression does not have children, got {}", - expr.children().len() - ); - } - Ok(()) + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(0) } - fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName { + fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName { unreachable!( "Root expression does not have children, got index {}", child_idx ) } - fn fmt_sql(&self, _expr: &ExpressionView, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt_sql( + &self, + _options: &Self::Options, + _expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { write!(f, "$") } - fn return_dtype(&self, _expr: &ExpressionView, scope: &DType) -> VortexResult { - Ok(scope.clone()) + fn return_dtype(&self, _options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult { + vortex_bail!("Root expression does not support return_dtype") } - fn evaluate(&self, _expr: &ExpressionView, scope: &ArrayRef) -> VortexResult { + fn evaluate( + &self, + _options: &Self::Options, + _expr: &Expression, + scope: &ArrayRef, + ) -> VortexResult { Ok(scope.clone()) } - fn execute(&self, _data: &Self::Instance, _args: ExecutionArgs) -> VortexResult { + fn execute(&self, _data: &Self::Options, _args: ExecutionArgs) -> VortexResult { vortex_bail!("Root expression is not executable") } fn stat_expression( &self, - _expr: &ExpressionView, + _options: &Self::Options, + _expr: &Expression, stat: Stat, catalog: &dyn StatsCatalog, ) -> Option { catalog.stats_ref(&FieldPath::root(), stat) } - fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool { + fn is_null_sensitive(&self, _options: &Self::Options) -> bool { false } - fn is_fallible(&self, _instance: &Self::Instance) -> bool { + fn is_fallible(&self, _options: &Self::Options) -> bool { false } } @@ -96,7 +102,7 @@ impl VTable for Root { /// Returns the entire input array as passed to the expression evaluator. /// This is commonly used as the starting point for field access and other operations. pub fn root() -> Expression { - Root.try_new_expr((), vec![]) + Root.try_new_expr(EmptyOptions, vec![]) .vortex_expect("Failed to create Root expression") } diff --git a/vortex-array/src/expr/exprs/scalar_fn.rs b/vortex-array/src/expr/exprs/scalar_fn.rs deleted file mode 100644 index 20d898c8530..00000000000 --- a/vortex-array/src/expr/exprs/scalar_fn.rs +++ /dev/null @@ -1,185 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use std::fmt::Debug; -use std::fmt::Formatter; -use std::marker::PhantomData; -use std::sync::Arc; - -use itertools::Itertools; -use vortex_dtype::DType; -use vortex_error::VortexResult; -use vortex_error::vortex_ensure; -use vortex_session::SessionVar; -use vortex_vector::Datum; -use vortex_vector::ScalarOps; -use vortex_vector::Vector; -use vortex_vector::VectorMutOps; - -use crate::ArrayRef; -use crate::IntoArray; -use crate::arrays::ScalarFnArray; -use crate::expr::ChildName; -use crate::expr::ExecutionArgs; -use crate::expr::ExprId; -use crate::expr::Expression; -use crate::expr::ExpressionView; -use crate::expr::StatsCatalog; -use crate::expr::VTable; -use crate::expr::functions; -use crate::expr::functions::ScalarFnVTable; -use crate::expr::functions::scalar::ScalarFn; -use crate::expr::stats::Stat; -use crate::expr::transform::rules::Matcher; - -/// An expression that wraps arbitrary scalar functions. -/// -/// Note that for backwards-compatibility, the `id` of this expression is the same as the -/// `id` of the underlying scalar function vtable, rather than being something constant like -/// `vortex.scalar_fn`. -pub struct ScalarFnExpr { - /// The vtable of the particular scalar function represented by this expression. - vtable: ScalarFnVTable, -} - -impl VTable for ScalarFnExpr { - type Instance = ScalarFn; - - fn id(&self) -> ExprId { - self.vtable.id() - } - - fn serialize(&self, func: &ScalarFn) -> VortexResult>> { - func.options().serialize() - } - - fn deserialize(&self, bytes: &[u8]) -> VortexResult> { - self.vtable.deserialize(bytes).map(Some) - } - - fn validate(&self, expr: &ExpressionView) -> VortexResult<()> { - vortex_ensure!( - expr.data() - .signature() - .arity() - .matches(expr.children().len()), - "invalid number of arguments for scalar function" - ); - Ok(()) - } - - fn child_name(&self, _func: &ScalarFn, _child_idx: usize) -> ChildName { - "unknown".into() - } - - fn fmt_sql(&self, expr: &ExpressionView, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}(", expr.data())?; - for (i, child) in expr.children().iter().enumerate() { - if i > 0 { - write!(f, ", ")?; - } - child.fmt_sql(f)?; - } - write!(f, ")") - } - - fn fmt_data(&self, func: &ScalarFn, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", func) - } - - fn return_dtype(&self, expr: &ExpressionView, scope: &DType) -> VortexResult { - let arg_dtypes: Vec<_> = expr - .children() - .iter() - .map(|e| e.return_dtype(scope)) - .try_collect()?; - expr.data().return_dtype(&arg_dtypes) - } - - fn evaluate(&self, expr: &ExpressionView, scope: &ArrayRef) -> VortexResult { - let children: Vec<_> = expr - .children() - .iter() - .map(|child| child.evaluate(scope)) - .try_collect()?; - Ok(ScalarFnArray::try_new(expr.data().clone(), children, scope.len())?.into_array()) - } - - fn execute(&self, func: &ScalarFn, args: ExecutionArgs) -> VortexResult { - let expr_args = functions::ExecutionArgs::new( - args.row_count, - args.return_dtype, - args.dtypes, - args.vectors.into_iter().map(Datum::Vector).collect(), - ); - let result = func.execute(&expr_args)?; - Ok(match result { - Datum::Scalar(s) => s.repeat(args.row_count).freeze(), - Datum::Vector(v) => v, - }) - } - - fn stat_falsification( - &self, - _expr: &ExpressionView, - _catalog: &dyn StatsCatalog, - ) -> Option { - // TODO(ngates): ideally this is implemented as optimizer rules over a `falsify` and - // `verify` expressions. - todo!() - } - - fn stat_expression( - &self, - _expr: &ExpressionView, - _stat: Stat, - _catalog: &dyn StatsCatalog, - ) -> Option { - // TODO(ngates): ideally this is implemented specifically for the Zoned layout, no one - // else needs to know what a specific stat over a column resolves to. - todo!() - } - - fn is_null_sensitive(&self, _func: &ScalarFn) -> bool { - todo!() - } -} - -/// A matcher that matches any scalar function expression. -#[derive(Debug)] -pub struct AnyScalarFn; -impl Matcher for AnyScalarFn { - type View<'a> = &'a ScalarFn; - - fn try_match(parent: &Expression) -> Option> { - Some(parent.as_opt::()?.data()) - } -} - -/// A matcher that matches a specific scalar function expression. -#[derive(Debug)] -pub struct ExactScalarFn(PhantomData); -impl Matcher for ExactScalarFn { - type View<'a> = &'a F::Options; - - fn try_match(parent: &Expression) -> Option> { - let expr_view = parent.as_opt::()?; - expr_view.data().as_any().downcast_ref::() - } -} - -/// Expression factory functions for ScalarFn vtables. -pub trait ScalarFnExprExt: functions::VTable { - fn try_new_expr( - &'static self, - options: Self::Options, - children: impl Into>, - ) -> VortexResult { - let expr_vtable = ScalarFnExpr { - vtable: ScalarFnVTable::new_static(self), - }; - let scalar_fn = ScalarFn::new_static(self, options); - Expression::try_new(expr_vtable, scalar_fn, children) - } -} -impl ScalarFnExprExt for V {} diff --git a/vortex-array/src/expr/exprs/select/mod.rs b/vortex-array/src/expr/exprs/select.rs similarity index 70% rename from vortex-array/src/expr/exprs/select/mod.rs rename to vortex-array/src/expr/exprs/select.rs index 8b27fb73d3a..a074d8bf643 100644 --- a/vortex-array/src/expr/exprs/select/mod.rs +++ b/vortex-array/src/expr/exprs/select.rs @@ -1,8 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -pub mod transform; - use std::fmt::Display; use std::fmt::Formatter; use std::sync::Arc; @@ -18,20 +16,25 @@ use vortex_error::vortex_err; use vortex_proto::expr::FieldNames as ProtoFieldNames; use vortex_proto::expr::SelectOpts; use vortex_proto::expr::select_opts::Opts; -use vortex_vector::Vector; +use vortex_vector::Datum; +use vortex_vector::StructDatum; +use vortex_vector::VectorOps; use vortex_vector::struct_::StructVector; use crate::ArrayRef; use crate::IntoArray; use crate::ToCanonical; +use crate::expr::Arity; use crate::expr::ChildName; use crate::expr::ExecutionArgs; use crate::expr::ExprId; -use crate::expr::ExpressionView; +use crate::expr::SimplifyCtx; use crate::expr::VTable; use crate::expr::VTableExt; use crate::expr::expression::Expression; use crate::expr::field::DisplayFieldNames; +use crate::expr::get_item; +use crate::expr::pack; #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum FieldSelection { @@ -42,13 +45,13 @@ pub enum FieldSelection { pub struct Select; impl VTable for Select { - type Instance = FieldSelection; + type Options = FieldSelection; fn id(&self) -> ExprId { ExprId::new_ref("vortex.select") } - fn serialize(&self, instance: &Self::Instance) -> VortexResult>> { + fn serialize(&self, instance: &Self::Options) -> VortexResult>> { let opts = match instance { FieldSelection::Include(fields) => Opts::Include(ProtoFieldNames { names: fields.iter().map(|f| f.to_string()).collect(), @@ -62,7 +65,7 @@ impl VTable for Select { Ok(Some(select_opts.encode_to_vec())) } - fn deserialize(&self, metadata: &[u8]) -> VortexResult> { + fn deserialize(&self, metadata: &[u8]) -> VortexResult { let prost_metadata = SelectOpts::decode(metadata)?; let select_opts = prost_metadata @@ -78,29 +81,28 @@ impl VTable for Select { )), }; - Ok(Some(field_selection)) + Ok(field_selection) } - fn validate(&self, expr: &ExpressionView) -> VortexResult<()> { - if expr.children().len() != 1 { - vortex_bail!( - "Select expression requires exactly 1 child, got {}", - expr.children().len() - ); - } - Ok(()) + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(1) } - fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName { + fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName { match child_idx { 0 => ChildName::new_ref("child"), _ => unreachable!(), } } - fn fmt_sql(&self, expr: &ExpressionView, f: &mut Formatter<'_>) -> std::fmt::Result { - expr.child().fmt_sql(f)?; - match expr.data() { + fn fmt_sql( + &self, + selection: &FieldSelection, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { + expr.child(0).fmt_sql(f)?; + match selection { FieldSelection::Include(fields) => { write!(f, "{{{}}}", DisplayFieldNames(fields)) } @@ -110,27 +112,17 @@ impl VTable for Select { } } - fn fmt_data(&self, instance: &Self::Instance, f: &mut Formatter<'_>) -> std::fmt::Result { - let names = match instance { - FieldSelection::Include(names) => { - write!(f, "include=")?; - names - } - FieldSelection::Exclude(names) => { - write!(f, "exclude=")?; - names - } - }; - write!(f, "{{{}}}", DisplayFieldNames(names)) - } - - fn return_dtype(&self, expr: &ExpressionView, scope: &DType) -> VortexResult { - let child_dtype = expr.child().return_dtype(scope)?; + fn return_dtype( + &self, + selection: &FieldSelection, + arg_dtypes: &[DType], + ) -> VortexResult { + let child_dtype = &arg_dtypes[0]; let child_struct_dtype = child_dtype .as_struct_fields_opt() .ok_or_else(|| vortex_err!("Select child not a struct dtype"))?; - let projected = match expr.data() { + let projected = match selection { FieldSelection::Include(fields) => child_struct_dtype.project(fields.as_ref())?, FieldSelection::Exclude(fields) => child_struct_dtype .names() @@ -144,9 +136,49 @@ impl VTable for Select { Ok(DType::Struct(projected, child_dtype.nullability())) } - fn evaluate(&self, expr: &ExpressionView, scope: &ArrayRef) -> VortexResult { - let batch = expr.child().evaluate(scope)?.to_struct(); - Ok(match expr.data() { + fn simplify( + &self, + options: &Self::Options, + expr: &Expression, + ctx: &dyn SimplifyCtx, + ) -> VortexResult> { + let child = expr.child(0); + let child_dtype = ctx.return_dtype(child)?; + let child_nullability = child_dtype.nullability(); + + let child_dtype = child_dtype.as_struct_fields_opt().ok_or_else(|| { + vortex_err!( + "Select child must return a struct dtype, however it was a {}", + child_dtype + ) + })?; + + let expr = pack( + options + .as_include_names(child_dtype.names()) + .map_err(|e| { + e.with_context(format!( + "Select fields {:?} must be a subset of child fields {:?}", + options, + child_dtype.names() + )) + })? + .iter() + .map(|name| (name.clone(), get_item(name.clone(), child.clone()))), + child_nullability, + ); + + Ok(Some(expr)) + } + + fn evaluate( + &self, + selection: &FieldSelection, + expr: &Expression, + scope: &ArrayRef, + ) -> VortexResult { + let batch = expr.child(0).evaluate(scope)?.to_struct(); + Ok(match selection { FieldSelection::Include(f) => batch.project(f.as_ref()), FieldSelection::Exclude(names) => { let included_names = batch @@ -161,12 +193,7 @@ impl VTable for Select { .into_array()) } - fn execute(&self, selection: &FieldSelection, mut args: ExecutionArgs) -> VortexResult { - let child = args - .vectors - .pop() - .vortex_expect("Missing input child") - .into_struct(); + fn execute(&self, selection: &FieldSelection, mut args: ExecutionArgs) -> VortexResult { let child_fields = args .dtypes .pop() @@ -194,24 +221,44 @@ impl VTable for Select { .try_collect(), }?; - let (fields, mask) = child.into_parts(); - let new_fields = field_indices - .iter() - .map(|&idx| fields[idx].clone()) - .collect(); - Ok(unsafe { StructVector::new_unchecked(Arc::new(new_fields), mask) }.into()) + let child = args + .datums + .pop() + .vortex_expect("Missing input child") + .into_struct(); + + Ok(match child { + StructDatum::Scalar(s) => StructDatum::Scalar( + select_from_struct_vector(s.value(), &field_indices)?.scalar_at(0), + ), + StructDatum::Vector(v) => { + StructDatum::Vector(select_from_struct_vector(&v, &field_indices)?) + } + } + .into()) } - fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool { + fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { true } - fn is_fallible(&self, _instance: &Self::Instance) -> bool { + fn is_fallible(&self, _instance: &Self::Options) -> bool { // If this type-checks its infallible. false } } +fn select_from_struct_vector( + vec: &StructVector, + field_indices: &[usize], +) -> VortexResult { + let new_fields = field_indices + .iter() + .map(|&idx| vec.fields()[idx].clone()) + .collect(); + Ok(unsafe { StructVector::new_unchecked(Arc::new(new_fields), vec.validity().clone()) }) +} + /// Creates an expression that selects (includes) specific fields from an array. /// /// Projects only the specified fields from the child expression, which must be of DType struct. @@ -239,34 +286,6 @@ pub fn select_exclude(fields: impl Into, child: Expression) -> Expre .vortex_expect("Failed to create Select expression") } -impl ExpressionView<'_, Select> { - pub fn child(&self) -> &Expression { - &self.children()[0] - } - - /// Turn the select expression into an `include`, relative to a provided array of field names. - /// - /// For example: - /// ```rust - /// # use vortex_array::expr::{root, Select}; - /// # use vortex_array::expr::{FieldSelection, select, select_exclude}; - /// # use vortex_dtype::FieldNames; - /// let field_names = FieldNames::from(["a", "b", "c"]); - /// let include = select(["a"], root()); - /// let exclude = select_exclude(["b", "c"], root()); - /// assert_eq!( - /// &include.as_::().as_include(&field_names).unwrap(), - /// ); - /// ``` - pub fn as_include(&self, field_names: &FieldNames) -> VortexResult { - Select.try_new_expr( - FieldSelection::Include(self.data().as_include_names(field_names)?), - [self.child().clone()], - ) - } -} - impl FieldSelection { pub fn include(columns: FieldNames) -> Self { assert_eq!(columns.iter().unique().collect_vec().len(), columns.len()); @@ -331,12 +350,16 @@ mod tests { use vortex_dtype::FieldName; use vortex_dtype::FieldNames; use vortex_dtype::Nullability; + use vortex_dtype::Nullability::Nullable; + use vortex_dtype::PType::I32; + use vortex_dtype::StructFields; use super::select; use super::select_exclude; use crate::IntoArray; use crate::ToCanonical; use crate::arrays::StructArray; + use crate::expr::exprs::pack::Pack; use crate::expr::exprs::root::root; use crate::expr::exprs::select::Select; use crate::expr::test_harness; @@ -421,14 +444,50 @@ mod tests { assert_eq!( &include .as_::() - .data() .as_include_names(&field_names) .unwrap() ); } + + #[test] + fn test_remove_select_rule() { + let dtype = DType::Struct( + StructFields::new(["a", "b"].into(), vec![I32.into(), I32.into()]), + Nullable, + ); + let e = select(["a", "b"], root()); + + let result = e.simplify(&dtype).unwrap(); + + assert!(result.is::()); + assert!(result.return_dtype(&dtype).unwrap().is_nullable()); + } + + #[test] + fn test_remove_select_rule_exclude_fields() { + use crate::expr::exprs::select::select_exclude; + + let dtype = DType::Struct( + StructFields::new( + ["a", "b", "c"].into(), + vec![I32.into(), I32.into(), I32.into()], + ), + Nullable, + ); + let e = select_exclude(["c"], root()); + + let result = e.simplify(&dtype).unwrap(); + + assert!(result.is::()); + + // Should exclude "c" and include "a" and "b" + let result_dtype = result.return_dtype(&dtype).unwrap(); + assert!(result_dtype.is_nullable()); + let fields = result_dtype.as_struct_fields_opt().unwrap(); + assert_eq!(fields.names().as_ref(), &["a", "b"]); + } } diff --git a/vortex-array/src/expr/exprs/select/transform.rs b/vortex-array/src/expr/exprs/select/transform.rs deleted file mode 100644 index 7181d0b2f18..00000000000 --- a/vortex-array/src/expr/exprs/select/transform.rs +++ /dev/null @@ -1,118 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use vortex_error::VortexResult; -use vortex_error::vortex_err; - -use crate::expr::Expression; -use crate::expr::ExpressionView; -use crate::expr::exprs::get_item::get_item; -use crate::expr::exprs::pack::pack; -use crate::expr::exprs::select::Select; -use crate::expr::transform::rules::ReduceRule; -use crate::expr::transform::rules::TypedRuleContext; - -/// Rule that removes Select expressions by converting them to Pack + GetItem. -/// -/// Transforms: `select(["a", "b"], expr)` → `pack(a: get_item("a", expr), b: get_item("b", expr))` -#[derive(Debug, Default)] -pub struct RemoveSelectRule; - -impl ReduceRule for RemoveSelectRule { - fn reduce( - &self, - select: &ExpressionView(); - let result = RemoveSelectRule.reduce(&select_view, &ctx).unwrap(); - - assert!(result.is_some()); - let transformed = result.unwrap(); - assert!(transformed.is::()); - assert!(transformed.return_dtype(&dtype).unwrap().is_nullable()); - } - - #[test] - fn test_remove_select_rule_exclude_fields() { - use crate::expr::exprs::select::select_exclude; - - let dtype = DType::Struct( - StructFields::new( - ["a", "b", "c"].into(), - vec![I32.into(), I32.into(), I32.into()], - ), - Nullable, - ); - let e = select_exclude(["c"], root()); - - let ctx = TypedRuleContext::new(dtype.clone()); - let select_view = e.as_::() { - self.fields.extend(sel.data().field_names().iter().cloned()); + if let Some(field_selection) = node.as_opt::