diff --git a/Cargo.lock b/Cargo.lock index d0dd8ec56ba..67d74e95b5a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8529,6 +8529,7 @@ dependencies = [ "vortex-error", "vortex-mask", "vortex-scalar", + "vortex-session", ] [[package]] @@ -8871,6 +8872,7 @@ dependencies = [ "vortex-sequence", "vortex-session", "vortex-utils", + "vortex-vector", "vortex-zstd", ] @@ -8878,6 +8880,7 @@ dependencies = [ name = "vortex-mask" version = "0.1.0" dependencies = [ + "arrow-buffer", "itertools 0.14.0", "rstest", "serde", @@ -8960,6 +8963,7 @@ dependencies = [ "vortex-error", "vortex-mask", "vortex-scalar", + "vortex-session", ] [[package]] diff --git a/encodings/alp/src/alp/array.rs b/encodings/alp/src/alp/array.rs index d2812f32d9f..fb0e432bc79 100644 --- a/encodings/alp/src/alp/array.rs +++ b/encodings/alp/src/alp/array.rs @@ -15,7 +15,9 @@ use vortex_array::DeserializeMetadata; use vortex_array::Precision; use vortex_array::ProstMetadata; use vortex_array::SerializeMetadata; -use vortex_array::execution::ExecutionCtx; +use vortex_array::kernel::BindCtx; +use vortex_array::kernel::KernelRef; +use vortex_array::kernel::kernel; use vortex_array::patches::Patches; use vortex_array::patches::PatchesMetadata; use vortex_array::serde::ArrayChildren; @@ -41,7 +43,6 @@ use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_ensure; -use vortex_vector::Vector; use crate::ALPFloat; use crate::alp::Exponents; @@ -140,17 +141,16 @@ impl VTable for ALPVTable { ) } - fn batch_execute(array: &ALPArray, ctx: &mut ExecutionCtx) -> VortexResult { - let encoded_vector = array.encoded().batch_execute(ctx)?; - - let patches_vectors = if let Some(patches) = array.patches() { + fn bind_kernel(array: &ALPArray, ctx: &mut BindCtx) -> VortexResult { + let encoded = array.encoded().bind_kernel(ctx)?; + let patches_kernels = if let Some(patches) = array.patches() { Some(( - patches.indices().batch_execute(ctx)?, - patches.values().batch_execute(ctx)?, + patches.indices().bind_kernel(ctx)?, + patches.values().bind_kernel(ctx)?, patches .chunk_offsets() .as_ref() - .map(|co| co.batch_execute(ctx)) + .map(|co| co.bind_kernel(ctx)) .transpose()?, )) } else { @@ -161,7 +161,24 @@ impl VTable for ALPVTable { let exponents = array.exponents(); match_each_alp_float_ptype!(array.dtype().as_ptype(), |T| { - decompress_into_vector::(encoded_vector, exponents, patches_vectors, patches_offset) + Ok(kernel(move || { + let encoded_vector = encoded.execute()?; + let patches_vectors = match patches_kernels { + Some((idx_kernel, val_kernel, co_kernel)) => Some(( + idx_kernel.execute()?, + val_kernel.execute()?, + co_kernel.map(|k| k.execute()).transpose()?, + )), + None => None, + }; + + decompress_into_vector::( + encoded_vector, + exponents, + patches_vectors, + patches_offset, + ) + })) }) } } @@ -456,7 +473,9 @@ mod tests { use std::sync::LazyLock; use rstest::rstest; + use vortex_array::VectorExecutor; use vortex_array::arrays::PrimitiveArray; + use vortex_array::session::ArraySession; use vortex_array::vtable::ValidityHelper; use vortex_dtype::PTypeDowncast; use vortex_session::VortexSession; @@ -464,7 +483,8 @@ mod tests { use super::*; - static SESSION: LazyLock = LazyLock::new(VortexSession::empty); + static SESSION: LazyLock = + LazyLock::new(|| VortexSession::empty().with::()); #[rstest] #[case(0)] @@ -480,7 +500,7 @@ mod tests { let values = PrimitiveArray::from_iter((0..size).map(|i| i as f32)); let encoded = alp_encode(&values, None).unwrap(); - let result_vector = encoded.to_array().execute(&SESSION).unwrap(); + let result_vector = encoded.to_array().execute_vector(&SESSION).unwrap(); // Compare against the traditional array-based decompress path let expected = decompress_into_array(encoded); @@ -504,7 +524,7 @@ mod tests { let values = PrimitiveArray::from_iter((0..size).map(|i| i as f64)); let encoded = alp_encode(&values, None).unwrap(); - let result_vector = encoded.to_array().execute(&SESSION).unwrap(); + let result_vector = encoded.to_array().execute_vector(&SESSION).unwrap(); // Compare against the traditional array-based decompress path let expected = decompress_into_array(encoded); @@ -534,7 +554,7 @@ mod tests { let encoded = alp_encode(&array, None).unwrap(); assert!(encoded.patches().unwrap().array_len() > 0); - let result_vector = encoded.to_array().execute(&SESSION).unwrap(); + let result_vector = encoded.to_array().execute_vector(&SESSION).unwrap(); // Compare against the traditional array-based decompress path let expected = decompress_into_array(encoded); @@ -562,7 +582,7 @@ mod tests { let array = PrimitiveArray::from_option_iter(values); let encoded = alp_encode(&array, None).unwrap(); - let result_vector = encoded.to_array().execute(&SESSION).unwrap(); + let result_vector = encoded.to_array().execute_vector(&SESSION).unwrap(); // Compare against the traditional array-based decompress path let expected = decompress_into_array(encoded); @@ -601,7 +621,7 @@ mod tests { let encoded = alp_encode(&array, None).unwrap(); assert!(encoded.patches().unwrap().array_len() > 0); - let result_vector = encoded.to_array().execute(&SESSION).unwrap(); + let result_vector = encoded.to_array().execute_vector(&SESSION).unwrap(); // Compare against the traditional array-based decompress path let expected = decompress_into_array(encoded); @@ -643,7 +663,7 @@ mod tests { let slice_len = slice_end - slice_start; let sliced_encoded = encoded.slice(slice_start..slice_end); - let result_vector = sliced_encoded.execute(&SESSION).unwrap(); + let result_vector = sliced_encoded.execute_vector_optimized(&SESSION).unwrap(); let result_primitive = result_vector.into_primitive().into_f64(); for idx in 0..slice_len { diff --git a/encodings/datetime-parts/Cargo.toml b/encodings/datetime-parts/Cargo.toml index 55d33e567c0..ba791407857 100644 --- a/encodings/datetime-parts/Cargo.toml +++ b/encodings/datetime-parts/Cargo.toml @@ -25,6 +25,7 @@ vortex-dtype = { workspace = true } vortex-error = { workspace = true } vortex-mask = { workspace = true } vortex-scalar = { workspace = true } +vortex-session = { workspace = true } [dev-dependencies] rstest = { workspace = true } diff --git a/encodings/datetime-parts/src/lib.rs b/encodings/datetime-parts/src/lib.rs index 3188a6a76bf..f86e2b8bf73 100644 --- a/encodings/datetime-parts/src/lib.rs +++ b/encodings/datetime-parts/src/lib.rs @@ -3,18 +3,31 @@ pub use array::*; pub use compress::*; +use vortex_array::session::ArraySessionExt; +use vortex_array::vtable::ArrayVTableExt; +use vortex_session::VortexSession; mod array; mod canonical; mod compress; mod compute; mod ops; +mod rules; mod timestamp; +/// Initialize the DateTimeParts encoding in the given session. +pub fn initialize(session: &mut VortexSession) { + session.arrays().register(DateTimePartsVTable.as_vtable()); + // session + // .arrays_mut() + // .optimizer_mut() + // .register_reduce_rule(DateTimePartsExpandRule); +} + #[cfg(test)] mod test { - use vortex_array::ProstMetadata; use vortex_array::test_harness::check_metadata; + use vortex_array::ProstMetadata; use vortex_dtype::PType; use crate::DateTimePartsMetadata; diff --git a/encodings/datetime-parts/src/rules.rs b/encodings/datetime-parts/src/rules.rs new file mode 100644 index 00000000000..0a131f17b41 --- /dev/null +++ b/encodings/datetime-parts/src/rules.rs @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::arrays::ConstantArray; +use vortex_array::builtins::ArrayBuiltins; +use vortex_array::optimizer::rules::ArrayReduceRule; +use vortex_array::optimizer::rules::Exact; +use vortex_array::ArrayRef; +use vortex_array::IntoArray; +use vortex_dtype::datetime::TemporalMetadata; +use vortex_dtype::datetime::TimeUnit; +use vortex_dtype::DType; +use vortex_dtype::PType; +use vortex_error::vortex_panic; +use vortex_error::VortexResult; + +use crate::DateTimePartsArray; +use crate::DateTimePartsVTable; + +/// Expand a date-time-parts array into an expression that evaluates to the timestamp. +#[derive(Debug)] +pub(crate) struct DateTimePartsExpandRule; + +impl ArrayReduceRule> for DateTimePartsExpandRule { + fn matcher(&self) -> Exact { + Exact::from(&DateTimePartsVTable) + } + + fn reduce(&self, array: &DateTimePartsArray) -> VortexResult> { + let DType::Extension(ext) = array.dtype().clone() else { + vortex_panic!(ComputeError: "expected dtype to be DType::Extension variant") + }; + + let Ok(temporal_metadata) = TemporalMetadata::try_from(ext.as_ref()) else { + vortex_panic!(ComputeError: "must decode TemporalMetadata from extension metadata"); + }; + + let divisor: i64 = match temporal_metadata.time_unit() { + TimeUnit::Nanoseconds => 1_000_000_000, + TimeUnit::Microseconds => 1_000_000, + TimeUnit::Milliseconds => 1_000, + TimeUnit::Seconds => 1, + TimeUnit::Days => vortex_panic!(InvalidArgument: "cannot decode into TimeUnit::D"), + }; + + // Up-cast days to i64 for computation. + let days = array + .days() + .cast(DType::Primitive(PType::I64, array.dtype().nullability()))?; + + // Multiply days by the number of seconds in a day and the unit divisor. + let days = days.mul(ConstantArray::new(divisor * 86_400, array.len()).into_array())?; + + // Multiply the seconds by the unit divisor. + let seconds = array + .seconds() + .cast(DType::Primitive(PType::I64, array.dtype().nullability()))? + .mul(ConstantArray::new(divisor, array.len()).into_array())?; + + // The subseconds are already in the correct unit, just cast to i64. + let subseconds = array + .subseconds() + .cast(DType::Primitive(PType::I64, array.dtype().nullability()))?; + + // Sum the three components together. + Ok(Some(days.add(seconds)?.add(subseconds)?)) + } +} diff --git a/encodings/fastlanes/src/bitpacking/array/bitpack_decompress.rs b/encodings/fastlanes/src/bitpacking/array/bitpack_decompress.rs index 7c6cb0eb0b3..0d9b2628260 100644 --- a/encodings/fastlanes/src/bitpacking/array/bitpack_decompress.rs +++ b/encodings/fastlanes/src/bitpacking/array/bitpack_decompress.rs @@ -204,6 +204,7 @@ mod tests { use std::sync::LazyLock; use vortex_array::IntoArray; + use vortex_array::VectorExecutor; use vortex_array::assert_arrays_eq; use vortex_array::validity::Validity; use vortex_buffer::Buffer; @@ -536,7 +537,7 @@ mod tests { let unpacked_array = unpack_array(&bitpacked); // Method 3: Using the execute() method (this is what would be used in production). - let executed = bitpacked.into_array().execute(&SESSION).unwrap(); + let executed = bitpacked.into_array().execute_vector(&SESSION).unwrap(); // All three should produce the same length. assert_eq!(vector_result.len(), array.len(), "vector length mismatch"); @@ -556,7 +557,10 @@ mod tests { // Verify that the execute() method works correctly by comparing with unpack_array. // We convert unpack_array result to a vector using execute() to compare. - let unpacked_executed = unpacked_array.into_array().execute(&SESSION).unwrap(); + let unpacked_executed = unpacked_array + .into_array() + .execute_vector(&SESSION) + .unwrap(); match (&executed, &unpacked_executed) { (Vector::Primitive(exec_pv), Vector::Primitive(unpack_pv)) => { assert_eq!( @@ -593,7 +597,7 @@ mod tests { let sliced_bp = sliced.as_::(); let vector_result = unpack_to_primitive_vector(sliced_bp); let unpacked_array = unpack_array(sliced_bp); - let executed = sliced.execute(&SESSION).unwrap(); + let executed = sliced.execute_vector(&SESSION).unwrap(); assert_eq!( vector_result.len(), diff --git a/encodings/fastlanes/src/bitpacking/vtable/mod.rs b/encodings/fastlanes/src/bitpacking/vtable/mod.rs index e47ed45f044..8d9cf5b5ff8 100644 --- a/encodings/fastlanes/src/bitpacking/vtable/mod.rs +++ b/encodings/fastlanes/src/bitpacking/vtable/mod.rs @@ -4,7 +4,9 @@ use vortex_array::DeserializeMetadata; use vortex_array::ProstMetadata; use vortex_array::SerializeMetadata; -use vortex_array::execution::ExecutionCtx; +use vortex_array::kernel::BindCtx; +use vortex_array::kernel::KernelRef; +use vortex_array::kernel::kernel; use vortex_array::patches::Patches; use vortex_array::patches::PatchesMetadata; use vortex_array::serde::ArrayChildren; @@ -23,7 +25,6 @@ use vortex_error::VortexError; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_err; -use vortex_vector::Vector; use vortex_vector::VectorMutOps; use crate::BitPackedArray; @@ -172,8 +173,11 @@ impl VTable for BitPackedVTable { ) } - fn batch_execute(array: &BitPackedArray, _ctx: &mut ExecutionCtx) -> VortexResult { - Ok(unpack_to_primitive_vector(array).freeze().into()) + fn bind_kernel(array: &BitPackedArray, _ctx: &mut BindCtx) -> VortexResult { + let array = array.clone(); + Ok(kernel(move || { + Ok(unpack_to_primitive_vector(&array).freeze().into()) + })) } } diff --git a/encodings/runend/Cargo.toml b/encodings/runend/Cargo.toml index 2eed95f4f16..05d1235727a 100644 --- a/encodings/runend/Cargo.toml +++ b/encodings/runend/Cargo.toml @@ -24,6 +24,7 @@ vortex-dtype = { workspace = true } vortex-error = { workspace = true } vortex-mask = { workspace = true } vortex-scalar = { workspace = true } +vortex-session = { workspace = true } [lints] workspace = true diff --git a/encodings/runend/src/lib.rs b/encodings/runend/src/lib.rs index c5e51d39c6a..3a521b53e53 100644 --- a/encodings/runend/src/lib.rs +++ b/encodings/runend/src/lib.rs @@ -11,6 +11,7 @@ pub mod compress; mod compute; mod iter; mod ops; +mod rules; #[doc(hidden)] pub mod _benchmarking { @@ -23,11 +24,17 @@ pub mod _benchmarking { use vortex_array::ArrayBufferVisitor; use vortex_array::ArrayChildVisitor; use vortex_array::Canonical; +use vortex_array::session::ArraySession; +use vortex_array::session::ArraySessionExt; +use vortex_array::vtable::ArrayVTableExt; use vortex_array::vtable::EncodeVTable; use vortex_array::vtable::VisitorVTable; use vortex_error::VortexResult; +use vortex_session::SessionExt; +use vortex_session::VortexSession; use crate::compress::runend_encode; +use crate::rules::RunEndScalarFnRule; impl EncodeVTable for RunEndVTable { fn encode( @@ -59,6 +66,15 @@ impl VisitorVTable for RunEndVTable { } } +/// Initialize run-end encoding in the given session. +pub fn initialize(session: &mut VortexSession) { + session.arrays().register(RunEndVTable.as_vtable()); + session + .get_mut::() + .optimizer_mut() + .register_parent_rule(RunEndScalarFnRule); +} + #[cfg(test)] mod tests { use vortex_array::ProstMetadata; diff --git a/encodings/runend/src/rules.rs b/encodings/runend/src/rules.rs new file mode 100644 index 00000000000..af1ad9c5bbc --- /dev/null +++ b/encodings/runend/src/rules.rs @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::ArrayRef; +use vortex_array::IntoArray; +use vortex_array::arrays::AnyScalarFn; +use vortex_array::arrays::ConstantArray; +use vortex_array::arrays::ConstantVTable; +use vortex_array::arrays::ScalarFnArray; +use vortex_array::optimizer::rules::ArrayParentReduceRule; +use vortex_array::optimizer::rules::Exact; +use vortex_dtype::DType; +use vortex_error::VortexResult; + +use crate::RunEndArray; +use crate::RunEndVTable; + +/// A rule to push down scalar functions through run-end encoding into the values array. +/// +/// This only works if all other children of the scalar function array are constants. +#[derive(Debug)] +pub(crate) struct RunEndScalarFnRule; + +impl ArrayParentReduceRule, AnyScalarFn> for RunEndScalarFnRule { + fn child(&self) -> Exact { + Exact::from(&RunEndVTable) + } + + fn parent(&self) -> AnyScalarFn { + AnyScalarFn + } + + fn reduce_parent( + &self, + run_end: &RunEndArray, + parent: &ScalarFnArray, + child_idx: usize, + ) -> VortexResult> { + for (idx, child) in parent.children().iter().enumerate() { + if idx == child_idx { + // Skip ourselves + continue; + } + + if !child.is::() { + // We can only push down if all other children are constants + return Ok(None); + } + } + + // TODO(ngates): relax this constraint and implement run-end decoding for all vector types. + if !matches!(parent.dtype(), DType::Bool(_) | DType::Primitive(..)) { + return Ok(None); + } + + let values_len = run_end.values().len(); + let mut new_children = parent.children(); + for (idx, child) in new_children.iter_mut().enumerate() { + if idx == child_idx { + // Replace ourselves with run end values + *child = run_end.values().clone(); + continue; + } + + // Replace other children with their constant scalar value with length adjusted + // to the length of the run end values. + let constant = child.as_::(); + *child = ConstantArray::new(constant.scalar().clone(), values_len).into_array(); + } + + let new_values = + ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children, values_len)? + .into_array(); + + Ok(Some( + RunEndArray::try_new_offset_length( + run_end.ends().clone(), + new_values, + run_end.offset(), + run_end.len(), + )? + .into_array(), + )) + } +} diff --git a/encodings/sequence/src/array.rs b/encodings/sequence/src/array.rs index 8be4483b419..8052afdbac9 100644 --- a/encodings/sequence/src/array.rs +++ b/encodings/sequence/src/array.rs @@ -15,7 +15,9 @@ use vortex_array::Precision; use vortex_array::ProstMetadata; use vortex_array::SerializeMetadata; use vortex_array::arrays::PrimitiveArray; -use vortex_array::execution::ExecutionCtx; +use vortex_array::kernel::BindCtx; +use vortex_array::kernel::KernelRef; +use vortex_array::kernel::kernel; use vortex_array::serde::ArrayChildren; use vortex_array::stats::ArrayStats; use vortex_array::stats::StatsSetRef; @@ -48,7 +50,6 @@ use vortex_mask::Mask; use vortex_scalar::PValue; use vortex_scalar::Scalar; use vortex_scalar::ScalarValue; -use vortex_vector::Vector; use vortex_vector::primitive::PVector; vtable!(Sequence); @@ -268,23 +269,28 @@ impl VTable for SequenceVTable { )) } - fn batch_execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult { + fn bind_kernel(array: &Self::Array, _ctx: &mut BindCtx) -> VortexResult { + let array = array.clone(); + Ok(match_each_native_ptype!(array.ptype(), |P| { let base = array.base().cast::

(); let multiplier = array.multiplier().cast::

(); - let values = if multiplier ==

::one() { - BufferMut::from_iter( - (0..array.len()).map(|i| base +

::from_usize(i).vortex_expect("must fit")), - ) - } else { - BufferMut::from_iter( - (0..array.len()) - .map(|i| base +

::from_usize(i).vortex_expect("must fit") * multiplier), - ) - }; - - PVector::

::new(values.freeze(), Mask::new_true(array.len())).into() + kernel(move || { + let values = + if multiplier ==

::one() { + BufferMut::from_iter( + (0..array.len()) + .map(|i| base +

::from_usize(i).vortex_expect("must fit")), + ) + } else { + BufferMut::from_iter((0..array.len()).map(|i| { + base +

::from_usize(i).vortex_expect("must fit") * multiplier + })) + }; + + Ok(PVector::

::new(values.freeze(), Mask::new_true(array.len())).into()) + }) })) } } diff --git a/vortex-array/src/array/mod.rs b/vortex-array/src/array/mod.rs index 71df1db4c34..360bb45ce6e 100644 --- a/vortex-array/src/array/mod.rs +++ b/vortex-array/src/array/mod.rs @@ -22,9 +22,6 @@ use vortex_error::vortex_ensure; use vortex_error::vortex_panic; use vortex_mask::Mask; use vortex_scalar::Scalar; -use vortex_session::VortexSession; -use vortex_vector::Vector; -use vortex_vector::VectorOps; use crate::ArrayEq; use crate::ArrayHash; @@ -35,6 +32,7 @@ use crate::arrays::BoolVTable; use crate::arrays::ConstantVTable; use crate::arrays::DecimalVTable; use crate::arrays::ExtensionVTable; +use crate::arrays::FilterArray; use crate::arrays::FixedSizeListVTable; use crate::arrays::ListViewVTable; use crate::arrays::NullVTable; @@ -49,11 +47,13 @@ use crate::compute::InvocationArgs; use crate::compute::IsConstantOpts; use crate::compute::Output; use crate::compute::is_constant_opts; -use crate::execution::ExecutionCtx; use crate::expr::stats::Precision; use crate::expr::stats::Stat; use crate::expr::stats::StatsProviderExt; use crate::hash; +use crate::kernel::BindCtx; +use crate::kernel::KernelRef; +use crate::kernel::ValidateKernel; use crate::serde::ArrayChildren; use crate::stats::StatsSetRef; use crate::vtable::ArrayId; @@ -96,6 +96,9 @@ pub trait Array: /// Performs a constant-time slice of the array. fn slice(&self, range: Range) -> ArrayRef; + /// Performs a constant-time filter of the array. + fn filter(&self, mask: &Mask) -> VortexResult; + /// Fetch the scalar at the given index. /// /// This method panics if the index is out of bounds for the array. @@ -190,7 +193,7 @@ pub trait Array: -> VortexResult>; /// Invoke the batch execution function for the array to produce a canonical vector. - fn batch_execute(&self, ctx: &mut ExecutionCtx) -> VortexResult; + fn bind_kernel(&self, ctx: &mut BindCtx) -> VortexResult; } impl Array for Arc { @@ -229,6 +232,10 @@ impl Array for Arc { self.as_ref().slice(range) } + fn filter(&self, mask: &Mask) -> VortexResult { + self.as_ref().filter(mask) + } + #[inline] fn scalar_at(&self, index: usize) -> Scalar { self.as_ref().scalar_at(index) @@ -294,8 +301,8 @@ impl Array for Arc { self.as_ref().invoke(compute_fn, args) } - fn batch_execute(&self, ctx: &mut ExecutionCtx) -> VortexResult { - self.as_ref().batch_execute(ctx) + fn bind_kernel(&self, ctx: &mut BindCtx) -> VortexResult { + self.as_ref().bind_kernel(ctx) } } @@ -362,14 +369,6 @@ impl dyn Array + '_ { } nbytes } - - /// Execute the array and return the resulting vector. - /// - /// This entry-point function will choose an appropriate CPU-based execution strategy. - pub fn execute(&self, session: &VortexSession) -> VortexResult { - let mut ctx = ExecutionCtx::new(session.clone()); - self.batch_execute(&mut ctx) - } } /// Trait for converting a type into a Vortex [`ArrayRef`]. @@ -506,6 +505,12 @@ impl Array for ArrayAdapter { sliced } + fn filter(&self, mask: &Mask) -> VortexResult { + vortex_ensure!(self.len() == mask.len(), "Filter mask length mismatch"); + Ok(V::filter(&self.0, mask)? + .unwrap_or_else(|| FilterArray::new(self.to_array(), mask.clone()).into_array())) + } + fn scalar_at(&self, index: usize) -> Scalar { assert!(index < self.len(), "index {index} out of bounds"); if self.is_invalid(index) { @@ -667,18 +672,17 @@ impl Array for ArrayAdapter { >::invoke(&self.0, compute_fn, args) } - fn batch_execute(&self, ctx: &mut ExecutionCtx) -> VortexResult { - let result = V::batch_execute(&self.0, ctx)?; - - // This check is so cheap we always run it. Whereas DType checks we only do in debug builds. - vortex_ensure!(result.len() == self.len(), "Result length mismatch"); - #[cfg(debug_assertions)] - vortex_ensure!( - vortex_vector::vector_matches_dtype(&result, self.dtype()), - "Executed vector dtype mismatch", - ); - - Ok(result) + fn bind_kernel(&self, ctx: &mut BindCtx) -> VortexResult { + let kernel = V::bind_kernel(&self.0, ctx)?; + if cfg!(debug_assertions) { + Ok(Box::new(ValidateKernel::new( + kernel, + self.dtype().clone(), + self.len(), + ))) + } else { + Ok(kernel) + } } } diff --git a/vortex-array/src/arrays/bool/vtable/mod.rs b/vortex-array/src/arrays/bool/vtable/mod.rs index eb049e9fd58..9ba2f7e9795 100644 --- a/vortex-array/src/arrays/bool/vtable/mod.rs +++ b/vortex-array/src/arrays/bool/vtable/mod.rs @@ -6,14 +6,13 @@ use vortex_dtype::DType; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; -use vortex_vector::Vector; use vortex_vector::bool::BoolVector; use crate::DeserializeMetadata; use crate::ProstMetadata; use crate::SerializeMetadata; use crate::arrays::BoolArray; -use crate::execution::ExecutionCtx; +use crate::kernel::BindCtx; use crate::serde::ArrayChildren; use crate::validity::Validity; use crate::vtable; @@ -31,6 +30,8 @@ mod visitor; pub use operator::BoolMaskedValidityRule; +use crate::kernel::KernelRef; +use crate::kernel::ready; use crate::vtable::ArrayId; use crate::vtable::ArrayVTable; @@ -106,8 +107,10 @@ impl VTable for BoolVTable { BoolArray::try_new(buffer, metadata.offset as usize, len, validity) } - fn batch_execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult { - Ok(BoolVector::new(array.bit_buffer().clone(), array.validity_mask()).into()) + fn bind_kernel(array: &Self::Array, _ctx: &mut BindCtx) -> VortexResult { + Ok(ready( + BoolVector::new(array.bit_buffer().clone(), array.validity_mask()).into(), + )) } } diff --git a/vortex-array/src/arrays/chunked/vtable/mod.rs b/vortex-array/src/arrays/chunked/vtable/mod.rs index bb0491e54db..2bfd6f23adc 100644 --- a/vortex-array/src/arrays/chunked/vtable/mod.rs +++ b/vortex-array/src/arrays/chunked/vtable/mod.rs @@ -9,6 +9,7 @@ use vortex_dtype::PType; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_err; +use vortex_mask::Mask; use vortex_vector::Vector; use vortex_vector::VectorMut; use vortex_vector::VectorMutOps; @@ -17,7 +18,10 @@ use crate::EmptyMetadata; use crate::ToCanonical; use crate::arrays::ChunkedArray; use crate::arrays::PrimitiveArray; -use crate::execution::ExecutionCtx; +use crate::kernel::BindCtx; +use crate::kernel::Kernel; +use crate::kernel::KernelRef; +use crate::kernel::PushDownResult; use crate::serde::ArrayChildren; use crate::validity::Validity; use crate::vtable; @@ -36,6 +40,9 @@ mod visitor; vtable!(Chunked); +#[derive(Debug)] +pub struct ChunkedVTable; + impl VTable for ChunkedVTable { type Array = ChunkedArray; @@ -125,15 +132,35 @@ impl VTable for ChunkedVTable { }) } - fn batch_execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { - let mut vector = VectorMut::with_capacity(array.dtype(), 0); - for chunk in array.chunks() { - let chunk_vector = chunk.batch_execute(ctx)?; + fn bind_kernel(array: &Self::Array, ctx: &mut BindCtx) -> VortexResult { + Ok(Box::new(ChunkedKernel { + chunks: array + .chunks + .iter() + .map(|c| c.bind_kernel(ctx)) + .try_collect()?, + dtype: array.dtype.clone(), + })) + } +} + +#[derive(Debug)] +struct ChunkedKernel { + chunks: Vec, + dtype: DType, +} + +impl Kernel for ChunkedKernel { + fn execute(self: Box) -> VortexResult { + let mut vector = VectorMut::with_capacity(&self.dtype, 0); + for chunk in self.chunks { + let chunk_vector = chunk.execute()?; vector.extend_from_vector(&chunk_vector); } Ok(vector.freeze()) } -} -#[derive(Debug)] -pub struct ChunkedVTable; + fn push_down_filter(self: Box, _selection: &Mask) -> VortexResult { + Ok(PushDownResult::NotPushed(self)) + } +} diff --git a/vortex-array/src/arrays/constant/mod.rs b/vortex-array/src/arrays/constant/mod.rs index ca4fb78cd88..ddc4a79dbf3 100644 --- a/vortex-array/src/arrays/constant/mod.rs +++ b/vortex-array/src/arrays/constant/mod.rs @@ -6,7 +6,6 @@ pub use array::ConstantArray; mod compute; -mod vector; mod vtable; pub use vtable::ConstantVTable; diff --git a/vortex-array/src/arrays/constant/vector.rs b/vortex-array/src/arrays/constant/vector.rs deleted file mode 100644 index c4251534032..00000000000 --- a/vortex-array/src/arrays/constant/vector.rs +++ /dev/null @@ -1,98 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use vortex_dtype::DType; -use vortex_dtype::DecimalType; -use vortex_dtype::PrecisionScale; -use vortex_dtype::match_each_decimal_value_type; -use vortex_dtype::match_each_native_ptype; -use vortex_error::VortexExpect; -use vortex_scalar::BinaryScalar; -use vortex_scalar::BoolScalar; -use vortex_scalar::DecimalScalar; -use vortex_scalar::PrimitiveScalar; -use vortex_scalar::Scalar; -use vortex_scalar::Utf8Scalar; -use vortex_vector::VectorMut; -use vortex_vector::VectorMutOps; -use vortex_vector::binaryview::BinaryVectorMut; -use vortex_vector::binaryview::StringVectorMut; -use vortex_vector::bool::BoolVectorMut; -use vortex_vector::decimal::DVectorMut; -use vortex_vector::decimal::DecimalVectorMut; -use vortex_vector::null::NullVectorMut; -use vortex_vector::primitive::PVectorMut; -use vortex_vector::primitive::PrimitiveVectorMut; - -pub(super) fn to_vector(scalar: Scalar, len: usize) -> VectorMut { - match scalar.dtype() { - DType::Null => NullVectorMut::new(len).into(), - DType::Bool(_) => to_vector_bool(scalar.as_bool(), len).into(), - DType::Primitive(..) => to_vector_primitive(scalar.as_primitive(), len).into(), - DType::Decimal(..) => to_vector_decimal(scalar.as_decimal(), len).into(), - DType::Utf8(_) => to_vector_utf8(scalar.as_utf8(), len).into(), - DType::Binary(_) => to_vector_binary(scalar.as_binary(), len).into(), - DType::List(..) => unimplemented!("List constant vectorization"), - DType::FixedSizeList(..) => unimplemented!("FixedSizeList constant vectorization"), - DType::Struct(..) => unimplemented!("Struct constant vectorization"), - DType::Extension(_) => to_vector(scalar.as_extension().storage(), len), - } -} - -fn to_vector_bool(scalar: BoolScalar, len: usize) -> BoolVectorMut { - let mut vec = BoolVectorMut::with_capacity(len); - match scalar.value() { - Some(v) => vec.append_values(v, len), - None => vec.append_nulls(len), - } - vec -} - -fn to_vector_primitive(scalar: PrimitiveScalar, len: usize) -> PrimitiveVectorMut { - match_each_native_ptype!(scalar.ptype(), |T| { - let mut vec = PVectorMut::::with_capacity(len); - match scalar.typed_value::() { - Some(v) => vec.append_values(v, len), - None => vec.append_nulls(len), - } - vec.into() - }) -} - -fn to_vector_decimal(scalar: DecimalScalar, len: usize) -> DecimalVectorMut { - let decimal_dtype = scalar - .dtype() - .as_decimal_opt() - .vortex_expect("Decimal scalar must have a decimal type"); - let decimal_type = DecimalType::smallest_decimal_value_type(decimal_dtype); - - match_each_decimal_value_type!(decimal_type, |D| { - let ps = PrecisionScale::::new(decimal_dtype.precision(), decimal_dtype.scale()); - let mut vec = DVectorMut::::with_capacity(ps, len); - match scalar.decimal_value() { - Some(v) => vec - .try_append_n(v.cast::().vortex_expect("known to fit"), len) - .vortex_expect("known to fit"), - None => vec.append_nulls(len), - } - vec.into() - }) -} - -fn to_vector_utf8(scalar: Utf8Scalar, len: usize) -> StringVectorMut { - let mut vec = StringVectorMut::with_capacity(len); - match scalar.value() { - Some(v) => vec.append_values(v.as_ref(), len), - None => vec.append_nulls(len), - } - vec -} - -fn to_vector_binary(scalar: BinaryScalar, len: usize) -> BinaryVectorMut { - let mut vec = BinaryVectorMut::with_capacity(len); - match scalar.value() { - Some(v) => vec.append_values(v.as_ref(), len), - None => vec.append_nulls(len), - } - vec -} diff --git a/vortex-array/src/arrays/constant/vtable/mod.rs b/vortex-array/src/arrays/constant/vtable/mod.rs index 5dfb794f9e2..bf6fd09e261 100644 --- a/vortex-array/src/arrays/constant/vtable/mod.rs +++ b/vortex-array/src/arrays/constant/vtable/mod.rs @@ -7,13 +7,14 @@ use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_scalar::Scalar; use vortex_scalar::ScalarValue; -use vortex_vector::Vector; +use vortex_vector::ScalarOps; use vortex_vector::VectorMutOps; use crate::EmptyMetadata; use crate::arrays::ConstantArray; -use crate::arrays::constant::vector::to_vector; -use crate::execution::ExecutionCtx; +use crate::kernel::BindCtx; +use crate::kernel::KernelRef; +use crate::kernel::kernel; use crate::serde::ArrayChildren; use crate::vtable; use crate::vtable::ArrayId; @@ -85,7 +86,9 @@ impl VTable for ConstantVTable { Ok(ConstantArray::new(scalar, len)) } - fn batch_execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult { - Ok(to_vector(array.scalar().clone(), array.len()).freeze()) + fn bind_kernel(array: &Self::Array, _ctx: &mut BindCtx) -> VortexResult { + let scalar = array.scalar().to_vector_scalar(); + let len = array.len(); + Ok(kernel(move || Ok(scalar.clone().repeat(len).freeze()))) } } diff --git a/vortex-array/src/arrays/decimal/vtable/mod.rs b/vortex-array/src/arrays/decimal/vtable/mod.rs index 1bcf92547df..a52e48c74a1 100644 --- a/vortex-array/src/arrays/decimal/vtable/mod.rs +++ b/vortex-array/src/arrays/decimal/vtable/mod.rs @@ -8,18 +8,18 @@ use vortex_dtype::DType; use vortex_dtype::NativeDecimalType; use vortex_dtype::PrecisionScale; use vortex_dtype::match_each_decimal_value_type; +use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_ensure; use vortex_scalar::DecimalType; -use vortex_vector::Vector; use vortex_vector::decimal::DVector; use crate::DeserializeMetadata; use crate::ProstMetadata; use crate::SerializeMetadata; use crate::arrays::DecimalArray; -use crate::execution::ExecutionCtx; +use crate::kernel::BindCtx; use crate::serde::ArrayChildren; use crate::validity::Validity; use crate::vtable; @@ -37,6 +37,8 @@ mod visitor; pub use operator::DecimalMaskedValidityRule; +use crate::kernel::KernelRef; +use crate::kernel::kernel; use crate::vtable::ArrayId; use crate::vtable::ArrayVTable; @@ -123,16 +125,39 @@ impl VTable for DecimalVTable { }) } - fn batch_execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult { + fn bind_kernel(array: &Self::Array, _ctx: &mut BindCtx) -> VortexResult { + use vortex_dtype::BigCast; + match_each_decimal_value_type!(array.values_type(), |D| { - Ok(unsafe { - DVector::::new_unchecked( - PrecisionScale::new_unchecked(array.precision(), array.scale()), - array.buffer::(), - array.validity_mask(), - ) - } - .into()) + // TODO(ngates): we probably shouldn't convert here... Do we allow larger P/S for a + // given physical type, because we know that our values actually fit? + let min_value_type = DecimalType::smallest_decimal_value_type(&array.decimal_dtype()); + match_each_decimal_value_type!(min_value_type, |E| { + let decimal_dtype = array.decimal_dtype(); + let buffer = array.buffer::(); + let validity_mask = array.validity_mask(); + + Ok(kernel(move || { + // Copy from D to E, possibly widening, possibly narrowing + let values = + Buffer::::from_trusted_len_iter(buffer.iter().map(|d| { + ::from(*d).vortex_expect("Decimal cast failed") + })); + + Ok(unsafe { + DVector::::new_unchecked( + // TODO(ngates): this is too small? + PrecisionScale::new_unchecked( + decimal_dtype.precision(), + decimal_dtype.scale(), + ), + values, + validity_mask, + ) + } + .into()) + })) + }) }) } } diff --git a/vortex-array/src/arrays/extension/vtable/mod.rs b/vortex-array/src/arrays/extension/vtable/mod.rs index 95c54b40bb5..c30fc14bc3e 100644 --- a/vortex-array/src/arrays/extension/vtable/mod.rs +++ b/vortex-array/src/arrays/extension/vtable/mod.rs @@ -11,11 +11,11 @@ use vortex_buffer::BufferHandle; use vortex_dtype::DType; use vortex_error::VortexResult; use vortex_error::vortex_bail; -use vortex_vector::Vector; use crate::EmptyMetadata; use crate::arrays::extension::ExtensionArray; -use crate::execution::ExecutionCtx; +use crate::kernel::BindCtx; +use crate::kernel::KernelRef; use crate::serde::ArrayChildren; use crate::vtable; use crate::vtable::ArrayId; @@ -78,8 +78,8 @@ impl VTable for ExtensionVTable { Ok(ExtensionArray::new(ext_dtype.clone(), storage)) } - fn batch_execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { - array.storage().batch_execute(ctx) + fn bind_kernel(array: &Self::Array, ctx: &mut BindCtx) -> VortexResult { + array.storage().bind_kernel(ctx) } } diff --git a/vortex-array/src/arrays/filter/array.rs b/vortex-array/src/arrays/filter/array.rs index 87ef92fa749..afc75d828c1 100644 --- a/vortex-array/src/arrays/filter/array.rs +++ b/vortex-array/src/arrays/filter/array.rs @@ -1,10 +1,12 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use vortex_error::vortex_panic; use vortex_mask::Mask; -use crate::ArrayRef; use crate::stats::ArrayStats; +use crate::Array; +use crate::ArrayRef; #[derive(Clone, Debug)] pub struct FilterArray { @@ -12,3 +14,24 @@ pub struct FilterArray { pub(super) mask: Mask, pub(super) stats: ArrayStats, } + +impl FilterArray { + pub fn new(child: ArrayRef, mask: Mask) -> Self { + if child.len() != mask.len() { + vortex_panic!( + "FilterArray length mismatch: child array has length {} but mask has length {}", + child.len(), + mask.len() + ); + } + Self { + child, + mask, + stats: ArrayStats::default(), + } + } + + pub fn mask(&self) -> &Mask { + &self.mask + } +} diff --git a/vortex-array/src/arrays/filter/kernel.rs b/vortex-array/src/arrays/filter/kernel.rs new file mode 100644 index 00000000000..d4888f71506 --- /dev/null +++ b/vortex-array/src/arrays/filter/kernel.rs @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_compute::filter::Filter; +use vortex_error::VortexResult; +use vortex_mask::Mask; +use vortex_vector::Vector; + +use crate::kernel::Kernel; +use crate::kernel::KernelRef; +use crate::kernel::PushDownResult; + +#[derive(Debug)] +pub struct FilterKernel { + child: KernelRef, + mask: Mask, +} + +impl FilterKernel { + pub fn new(child: KernelRef, mask: Mask) -> Self { + Self { child, mask } + } +} + +impl Kernel for FilterKernel { + fn execute(self: Box) -> VortexResult { + Ok(Filter::filter(&self.child.execute()?, &self.mask)) + } + + fn push_down_filter(self: Box, selection: &Mask) -> VortexResult { + let new_mask = self.mask.intersect_by_rank(selection); + Ok(match self.child.push_down_filter(&new_mask)? { + PushDownResult::NotPushed(k) => PushDownResult::NotPushed(Box::new(FilterKernel { + child: k, + mask: new_mask, + })), + PushDownResult::Pushed(new_k) => PushDownResult::Pushed(new_k), + }) + } +} diff --git a/vortex-array/src/arrays/filter/mod.rs b/vortex-array/src/arrays/filter/mod.rs index ae7057a351c..dd45a57411a 100644 --- a/vortex-array/src/arrays/filter/mod.rs +++ b/vortex-array/src/arrays/filter/mod.rs @@ -2,6 +2,9 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors mod array; +mod kernel; mod vtable; +pub use array::*; +pub use kernel::*; pub use vtable::*; diff --git a/vortex-array/src/arrays/filter/vtable.rs b/vortex-array/src/arrays/filter/vtable.rs index 27048389320..674271e35db 100644 --- a/vortex-array/src/arrays/filter/vtable.rs +++ b/vortex-array/src/arrays/filter/vtable.rs @@ -2,31 +2,45 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use std::hash::Hasher; +use std::ops::Range; use vortex_buffer::BufferHandle; +use vortex_compute::filter::Filter; use vortex_dtype::DType; +use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_mask::Mask; -use vortex_vector::Vector; +use vortex_scalar::Scalar; use crate::Array; use crate::ArrayBufferVisitor; use crate::ArrayChildVisitor; use crate::ArrayEq; use crate::ArrayHash; +use crate::ArrayRef; +use crate::Canonical; +use crate::IntoArray; use crate::Precision; +use crate::arrays::LEGACY_SESSION; use crate::arrays::filter::array::FilterArray; -use crate::execution::ExecutionCtx; +use crate::arrays::filter::kernel::FilterKernel; +use crate::kernel::BindCtx; +use crate::kernel::KernelRef; +use crate::kernel::PushDownResult; use crate::serde::ArrayChildren; use crate::stats::StatsSetRef; +use crate::vectors::VectorIntoArray; use crate::vtable; use crate::vtable::ArrayId; use crate::vtable::ArrayVTable; use crate::vtable::ArrayVTableExt; use crate::vtable::BaseArrayVTable; +use crate::vtable::CanonicalVTable; use crate::vtable::NotSupported; +use crate::vtable::OperationsVTable; use crate::vtable::VTable; +use crate::vtable::ValidityVTable; use crate::vtable::VisitorVTable; vtable!(Filter); @@ -38,9 +52,9 @@ impl VTable for FilterVTable { type Array = FilterArray; type Metadata = Mask; type ArrayVTable = Self; - type CanonicalVTable = NotSupported; - type OperationsVTable = NotSupported; - type ValidityVTable = NotSupported; + type CanonicalVTable = Self; + type OperationsVTable = Self; + type ValidityVTable = Self; type VisitorVTable = Self; type ComputeVTable = NotSupported; type EncodeVTable = NotSupported; @@ -69,21 +83,46 @@ impl VTable for FilterVTable { &self, dtype: &DType, len: usize, - metadata: &Self::Metadata, + selection_mask: &Mask, _buffers: &[BufferHandle], children: &dyn ArrayChildren, ) -> VortexResult { - let child = children.get(0, dtype, len)?; + assert_eq!(len, selection_mask.true_count()); + let child = children.get(0, dtype, selection_mask.len())?; Ok(FilterArray { child, - mask: metadata.clone(), + mask: selection_mask.clone(), stats: Default::default(), }) } - fn batch_execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { - let child = array.child.batch_execute(ctx)?; - Ok(vortex_compute::filter::Filter::filter(&child, &array.mask)) + fn bind_kernel(array: &Self::Array, ctx: &mut BindCtx) -> VortexResult { + let mut child = array.child.bind_kernel(ctx)?; + let mask = array.mask.clone(); + + // NOTE(ngates): for now we keep the same behavior as develop where we push-down any + // query with <20% true values. + let pushdown = array.mask.density() < 0.2; + + if pushdown { + // Try to push down the filter to the child if it's cheaper. + child = match child.push_down_filter(&mask)? { + PushDownResult::Pushed(new_k) => { + tracing::debug!("Filter push down kernel:\n{:?}", new_k); + return Ok(new_k); + } + PushDownResult::NotPushed(child) => { + tracing::warn!( + "Filter pushdown was cheaper but not supported by child array {}", + array.child.display_tree() + ); + child + } + }; + } + + // Otherwise, wrap up the child in a filter kernel. + Ok(Box::new(FilterKernel::new(child, mask))) } } @@ -110,6 +149,48 @@ impl BaseArrayVTable for FilterVTable { } } +impl CanonicalVTable for FilterVTable { + fn canonicalize(array: &FilterArray) -> Canonical { + let vector = FilterVTable::bind_kernel(array, &mut BindCtx::new(LEGACY_SESSION.clone())) + .vortex_expect("Canonicalize should be fallible") + .execute() + .vortex_expect("Canonicalize should be fallible"); + vector.into_array(array.dtype()).to_canonical() + } +} + +impl OperationsVTable for FilterVTable { + fn slice(array: &FilterArray, range: Range) -> ArrayRef { + FilterArray::new(array.child.slice(range.clone()), array.mask.slice(range)).into_array() + } + + fn scalar_at(array: &FilterArray, index: usize) -> Scalar { + let rank_idx = array.mask.rank(index); + array.child.scalar_at(rank_idx) + } +} + +impl ValidityVTable for FilterVTable { + fn is_valid(array: &FilterArray, index: usize) -> bool { + let rank_idx = array.mask.rank(index); + array.child.is_valid(rank_idx) + } + + fn all_valid(array: &FilterArray) -> bool { + // An over-approximation: if the child is all valid, then the filtered array is all valid. + array.child.all_valid() + } + + fn all_invalid(array: &FilterArray) -> bool { + // An over-approximation: if the child is all invalid, then the filtered array is all invalid. + array.child.all_invalid() + } + + fn validity_mask(array: &FilterArray) -> Mask { + Filter::filter(&array.child.validity_mask(), &array.mask) + } +} + impl VisitorVTable for FilterVTable { fn visit_buffers(_array: &FilterArray, _visitor: &mut dyn ArrayBufferVisitor) {} diff --git a/vortex-array/src/arrays/fixed_size_list/vtable/mod.rs b/vortex-array/src/arrays/fixed_size_list/vtable/mod.rs index 0678275db17..0db0858de1a 100644 --- a/vortex-array/src/arrays/fixed_size_list/vtable/mod.rs +++ b/vortex-array/src/arrays/fixed_size_list/vtable/mod.rs @@ -8,12 +8,14 @@ use vortex_dtype::DType; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_ensure; -use vortex_vector::Vector; use vortex_vector::fixed_size_list::FixedSizeListVector; +use crate::Array; use crate::EmptyMetadata; use crate::arrays::FixedSizeListArray; -use crate::execution::ExecutionCtx; +use crate::kernel::BindCtx; +use crate::kernel::KernelRef; +use crate::kernel::kernel; use crate::serde::ArrayChildren; use crate::validity::Validity; use crate::vtable; @@ -108,14 +110,20 @@ impl VTable for FixedSizeListVTable { FixedSizeListArray::try_new(elements, *list_size, validity, len) } - fn batch_execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { - Ok(unsafe { - FixedSizeListVector::new_unchecked( - Arc::new(array.elements().batch_execute(ctx)?), - array.list_size(), - array.validity_mask(), - ) - } - .into()) + fn bind_kernel(array: &Self::Array, ctx: &mut BindCtx) -> VortexResult { + let elements_kernel = array.elements().bind_kernel(ctx)?; + let list_size = array.list_size(); + let validity_mask = array.validity_mask(); + + Ok(kernel(move || { + Ok(unsafe { + FixedSizeListVector::new_unchecked( + Arc::new(elements_kernel.execute()?), + list_size, + validity_mask, + ) + } + .into()) + })) } } diff --git a/vortex-array/src/arrays/listview/vtable/mod.rs b/vortex-array/src/arrays/listview/vtable/mod.rs index e1c62a109c3..27ce1bf7bb8 100644 --- a/vortex-array/src/arrays/listview/vtable/mod.rs +++ b/vortex-array/src/arrays/listview/vtable/mod.rs @@ -10,14 +10,15 @@ use vortex_dtype::PType; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_ensure; -use vortex_vector::Vector; use vortex_vector::listview::ListViewVector; use crate::DeserializeMetadata; use crate::ProstMetadata; use crate::SerializeMetadata; use crate::arrays::ListViewArray; -use crate::execution::ExecutionCtx; +use crate::kernel::BindCtx; +use crate::kernel::KernelRef; +use crate::kernel::kernel; use crate::serde::ArrayChildren; use crate::validity::Validity; use crate::vtable; @@ -140,15 +141,25 @@ impl VTable for ListViewVTable { ListViewArray::try_new(elements, offsets, sizes, validity) } - fn batch_execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { - Ok(unsafe { - ListViewVector::new_unchecked( - Arc::new(array.elements().batch_execute(ctx)?), - array.offsets().batch_execute(ctx)?.into_primitive(), - array.sizes().batch_execute(ctx)?.into_primitive(), - array.validity_mask(), - ) - } - .into()) + fn bind_kernel(array: &Self::Array, ctx: &mut BindCtx) -> VortexResult { + let elements_kernel = array.elements().bind_kernel(ctx)?; + let offsets_kernel = array.offsets().bind_kernel(ctx)?; + let sizes_kernel = array.sizes().bind_kernel(ctx)?; + let validity_mask = array.validity_mask(); + + Ok(kernel(move || { + let elements = elements_kernel.execute()?; + let offsets = offsets_kernel.execute()?; + let sizes = sizes_kernel.execute()?; + Ok(unsafe { + ListViewVector::new_unchecked( + Arc::new(elements), + offsets.into_primitive(), + sizes.into_primitive(), + validity_mask, + ) + } + .into()) + })) } } diff --git a/vortex-array/src/arrays/masked/vtable/mod.rs b/vortex-array/src/arrays/masked/vtable/mod.rs index f0b5339fcf2..fd93871be0f 100644 --- a/vortex-array/src/arrays/masked/vtable/mod.rs +++ b/vortex-array/src/arrays/masked/vtable/mod.rs @@ -10,14 +10,15 @@ use vortex_buffer::BufferHandle; use vortex_dtype::DType; use vortex_error::VortexResult; use vortex_error::vortex_bail; -use vortex_vector::Vector; use vortex_vector::VectorOps; use crate::ArrayBufferVisitor; use crate::ArrayChildVisitor; use crate::EmptyMetadata; use crate::arrays::masked::MaskedArray; -use crate::execution::ExecutionCtx; +use crate::kernel::BindCtx; +use crate::kernel::KernelRef; +use crate::kernel::kernel; use crate::serde::ArrayChildren; use crate::validity::Validity; use crate::vtable; @@ -105,10 +106,15 @@ impl VTable for MaskedVTable { MaskedArray::try_new(child, validity) } - fn batch_execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { - let mut vector = array.child().batch_execute(ctx)?; - vector.mask_validity(&array.validity_mask()); - Ok(vector) + fn bind_kernel(array: &Self::Array, ctx: &mut BindCtx) -> VortexResult { + let child_kernel = array.child().bind_kernel(ctx)?; + let validity_mask = array.validity_mask(); + + Ok(kernel(move || { + let mut vector = child_kernel.execute()?; + vector.mask_validity(&validity_mask); + Ok(vector) + })) } } diff --git a/vortex-array/src/arrays/mod.rs b/vortex-array/src/arrays/mod.rs index 8171a89490b..dca3497b457 100644 --- a/vortex-array/src/arrays/mod.rs +++ b/vortex-array/src/arrays/mod.rs @@ -6,6 +6,8 @@ #[cfg(any(test, feature = "test-harness"))] mod assertions; +use std::sync::LazyLock; + #[cfg(any(test, feature = "test-harness"))] pub use assertions::format_indices; @@ -57,3 +59,12 @@ pub use scalar_fn::*; pub use struct_::*; pub use varbin::*; pub use varbinview::*; +use vortex_session::VortexSession; + +use crate::session::ArraySession; + +// TODO(ngates): canonicalize doesn't currently take a session, therefore we cannot invoke execute +// from the new array encodings to support back-compat for legacy encodings. So we hold a session +// here... +static LEGACY_SESSION: LazyLock = + LazyLock::new(|| VortexSession::empty().with::()); diff --git a/vortex-array/src/arrays/null/mod.rs b/vortex-array/src/arrays/null/mod.rs index 3aeb4508d87..40260497049 100644 --- a/vortex-array/src/arrays/null/mod.rs +++ b/vortex-array/src/arrays/null/mod.rs @@ -9,7 +9,6 @@ use vortex_dtype::DType; use vortex_error::VortexResult; use vortex_mask::Mask; use vortex_scalar::Scalar; -use vortex_vector::Vector; use vortex_vector::null::NullVector; use crate::ArrayBufferVisitor; @@ -19,7 +18,9 @@ use crate::Canonical; use crate::EmptyMetadata; use crate::IntoArray; use crate::Precision; -use crate::execution::ExecutionCtx; +use crate::kernel::BindCtx; +use crate::kernel::KernelRef; +use crate::kernel::ready; use crate::serde::ArrayChildren; use crate::stats::ArrayStats; use crate::stats::StatsSetRef; @@ -83,8 +84,8 @@ impl VTable for NullVTable { Ok(NullArray::new(len)) } - fn batch_execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult { - Ok(NullVector::new(array.len()).into()) + fn bind_kernel(array: &Self::Array, _ctx: &mut BindCtx) -> VortexResult { + Ok(ready(NullVector::new(array.len()).into())) } } diff --git a/vortex-array/src/arrays/primitive/vtable/mod.rs b/vortex-array/src/arrays/primitive/vtable/mod.rs index 986bc159007..01604a10f69 100644 --- a/vortex-array/src/arrays/primitive/vtable/mod.rs +++ b/vortex-array/src/arrays/primitive/vtable/mod.rs @@ -9,12 +9,11 @@ use vortex_dtype::PType; use vortex_dtype::match_each_native_ptype; use vortex_error::VortexResult; use vortex_error::vortex_bail; -use vortex_vector::Vector; use vortex_vector::primitive::PVector; use crate::EmptyMetadata; use crate::arrays::PrimitiveArray; -use crate::execution::ExecutionCtx; +use crate::kernel::BindCtx; use crate::serde::ArrayChildren; use crate::validity::Validity; use crate::vtable; @@ -32,6 +31,8 @@ mod visitor; pub use operator::PrimitiveMaskedValidityRule; +use crate::kernel::KernelRef; +use crate::kernel::ready; use crate::vtable::ArrayId; use crate::vtable::ArrayVTable; @@ -116,10 +117,10 @@ impl VTable for PrimitiveVTable { }) } - fn batch_execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult { - Ok(match_each_native_ptype!(array.ptype(), |T| { + fn bind_kernel(array: &Self::Array, _ctx: &mut BindCtx) -> VortexResult { + Ok(ready(match_each_native_ptype!(array.ptype(), |T| { PVector::new(array.buffer::(), array.validity_mask()).into() - })) + }))) } } diff --git a/vortex-array/src/arrays/scalar_fn/array.rs b/vortex-array/src/arrays/scalar_fn/array.rs index ff4fe9f141d..919742c1610 100644 --- a/vortex-array/src/arrays/scalar_fn/array.rs +++ b/vortex-array/src/arrays/scalar_fn/array.rs @@ -8,7 +8,7 @@ use vortex_error::vortex_ensure; use crate::Array; use crate::ArrayRef; use crate::arrays::ScalarFnVTable; -use crate::expr::functions::scalar::ScalarFn; +use crate::expr::ScalarFn; use crate::stats::ArrayStats; use crate::vtable::ArrayVTable; use crate::vtable::ArrayVTableExt; @@ -26,9 +26,9 @@ pub struct ScalarFnArray { impl ScalarFnArray { /// Create a new ScalarFnArray from a scalar function and its children. - pub fn try_new(scalar_fn: ScalarFn, children: Vec, len: usize) -> VortexResult { + pub fn try_new(bound: ScalarFn, children: Vec, len: usize) -> VortexResult { let arg_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect(); - let dtype = scalar_fn.return_dtype(&arg_dtypes)?; + let dtype = bound.return_dtype(&arg_dtypes)?; vortex_ensure!( children.iter().all(|c| c.len() == len), @@ -36,12 +36,17 @@ impl ScalarFnArray { ); Ok(Self { - vtable: ScalarFnVTable::new(scalar_fn.vtable().clone()).into_vtable(), - scalar_fn, + vtable: ScalarFnVTable::new(bound.vtable().clone()).into_vtable(), + scalar_fn: bound, dtype, len, children, stats: Default::default(), }) } + + /// Get the scalar function bound to this array. + pub fn scalar_fn(&self) -> &ScalarFn { + &self.scalar_fn + } } diff --git a/vortex-array/src/arrays/scalar_fn/kernel.rs b/vortex-array/src/arrays/scalar_fn/kernel.rs new file mode 100644 index 00000000000..3a23a3b6299 --- /dev/null +++ b/vortex-array/src/arrays/scalar_fn/kernel.rs @@ -0,0 +1,84 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_dtype::DType; +use vortex_error::VortexResult; +use vortex_mask::Mask; +use vortex_vector::Datum; +use vortex_vector::Scalar; +use vortex_vector::Vector; + +use crate::expr::ExecutionArgs; +use crate::expr::ScalarFn; +use crate::kernel::Kernel; +use crate::kernel::KernelRef; +use crate::kernel::PushDownResult; + +/// A CPU kernel for executing scalar functions. +#[derive(Debug)] +pub struct ScalarFnKernel { + /// The scalar function to apply. + pub(super) scalar_fn: ScalarFn, + + /// Inputs to the kernel, either constants or other kernels. + pub(super) inputs: Vec, + /// The input data types + pub(super) input_dtypes: Vec, + /// The row count for vector inputs + pub(super) row_count: usize, + /// The return data type + pub(super) return_dtype: DType, +} + +#[derive(Debug)] +pub(super) enum KernelInput { + Scalar(Scalar), + Vector(KernelRef), +} + +impl Kernel for ScalarFnKernel { + fn execute(self: Box) -> VortexResult { + let mut datums: Vec = Vec::with_capacity(self.inputs.len()); + for input in self.inputs { + match input { + KernelInput::Scalar(s) => { + datums.push(Datum::Scalar(s)); + } + KernelInput::Vector(kernel) => { + datums.push(Datum::Vector(kernel.execute()?)); + } + } + } + + let args = ExecutionArgs { + datums, + dtypes: self.input_dtypes, + row_count: self.row_count, + return_dtype: self.return_dtype, + }; + + Ok(self.scalar_fn.execute(args)?.ensure_vector(self.row_count)) + } + + fn push_down_filter(self: Box, selection: &Mask) -> VortexResult { + let mut new_inputs = Vec::with_capacity(self.inputs.len()); + for input in self.inputs { + match input { + KernelInput::Scalar(s) => { + new_inputs.push(KernelInput::Scalar(s.clone())); + } + KernelInput::Vector(k) => { + new_inputs.push(KernelInput::Vector(k.force_push_down_filter(selection)?)); + } + } + } + + Ok(PushDownResult::Pushed(Box::new(ScalarFnKernel { + scalar_fn: self.scalar_fn, + inputs: new_inputs, + input_dtypes: self.input_dtypes, + row_count: selection.true_count(), + return_dtype: self.return_dtype, + }))) + } +} diff --git a/vortex-array/src/arrays/scalar_fn/metadata.rs b/vortex-array/src/arrays/scalar_fn/metadata.rs index 1f458fa83e1..0901288695f 100644 --- a/vortex-array/src/arrays/scalar_fn/metadata.rs +++ b/vortex-array/src/arrays/scalar_fn/metadata.rs @@ -1,12 +1,22 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use std::fmt::Debug; +use std::fmt::Formatter; + use vortex_dtype::DType; -use crate::expr::functions::scalar::ScalarFn; +use crate::expr::ScalarFn; -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct ScalarFnMetadata { pub(super) scalar_fn: ScalarFn, pub(super) child_dtypes: Vec, } + +// Array tree display wrongly uses debug... +impl Debug for ScalarFnMetadata { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.scalar_fn.options()) + } +} diff --git a/vortex-array/src/arrays/scalar_fn/mod.rs b/vortex-array/src/arrays/scalar_fn/mod.rs index db178a139d1..bbadf45e7a0 100644 --- a/vortex-array/src/arrays/scalar_fn/mod.rs +++ b/vortex-array/src/arrays/scalar_fn/mod.rs @@ -2,7 +2,9 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors mod array; +mod kernel; mod metadata; +pub(crate) mod rules; mod vtable; pub use array::*; diff --git a/vortex-array/src/arrays/scalar_fn/rules.rs b/vortex-array/src/arrays/scalar_fn/rules.rs new file mode 100644 index 00000000000..05231b3947e --- /dev/null +++ b/vortex-array/src/arrays/scalar_fn/rules.rs @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_error::VortexResult; +use vortex_scalar::Scalar; +use vortex_vector::Datum; +use vortex_vector::VectorOps; +use vortex_vector::scalar_matches_dtype; + +use crate::Array; +use crate::ArrayRef; +use crate::IntoArray; +use crate::arrays::AnyScalarFn; +use crate::arrays::ConstantArray; +use crate::arrays::ConstantVTable; +use crate::arrays::ScalarFnArray; +use crate::expr::ExecutionArgs; +use crate::optimizer::rules::ArrayReduceRule; + +#[derive(Debug)] +pub(crate) struct ScalarFnConstantRule; +impl ArrayReduceRule for ScalarFnConstantRule { + fn matcher(&self) -> AnyScalarFn { + AnyScalarFn + } + + fn reduce(&self, array: &ScalarFnArray) -> VortexResult> { + if !array.children.iter().all(|c| c.is::()) { + return Ok(None); + } + + let input_datums: Vec<_> = array + .children + .iter() + .map(|c| c.as_::().scalar().to_vector_scalar()) + .map(Datum::Scalar) + .collect(); + let input_dtypes = array.children.iter().map(|c| c.dtype().clone()).collect(); + + let result = array.scalar_fn.execute(ExecutionArgs { + datums: input_datums, + dtypes: input_dtypes, + row_count: array.len, + return_dtype: array.dtype.clone(), + })?; + + let result = match result { + Datum::Scalar(s) => s, + Datum::Vector(v) => { + tracing::warn!( + "Scalar function {} returned vector from execution over all scalar inputs", + array.scalar_fn + ); + v.scalar_at(0) + } + }; + assert!(scalar_matches_dtype(&result, &array.dtype)); + + let _fn = format!("{}", array.scalar_fn); + Ok(Some( + ConstantArray::new(Scalar::from_vector_scalar(result, &array.dtype)?, array.len) + .into_array(), + )) + } +} diff --git a/vortex-array/src/arrays/scalar_fn/vtable/canonical.rs b/vortex-array/src/arrays/scalar_fn/vtable/canonical.rs index f1222998ab9..58ae22e4cc7 100644 --- a/vortex-array/src/arrays/scalar_fn/vtable/canonical.rs +++ b/vortex-array/src/arrays/scalar_fn/vtable/canonical.rs @@ -1,41 +1,45 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use itertools::Itertools; use vortex_error::VortexExpect; -use vortex_vector::Datum; use crate::Array; use crate::Canonical; +use crate::arrays::LEGACY_SESSION; use crate::arrays::scalar_fn::array::ScalarFnArray; -use crate::arrays::scalar_fn::vtable::SCALAR_FN_SESSION; use crate::arrays::scalar_fn::vtable::ScalarFnVTable; -use crate::expr::functions::ExecutionArgs; +use crate::executor::VectorExecutor; +use crate::expr::ExecutionArgs; use crate::vectors::VectorIntoArray; use crate::vtable::CanonicalVTable; impl CanonicalVTable for ScalarFnVTable { fn canonicalize(array: &ScalarFnArray) -> Canonical { let child_dtypes: Vec<_> = array.children.iter().map(|c| c.dtype().clone()).collect(); - let child_datums: Vec<_> = array - .children() - .iter() - // TODO(ngates): we could make all execution operate over datums - .map(|child| child.execute(&SCALAR_FN_SESSION).map(Datum::Vector)) - .try_collect() - // FIXME(ngates): canonicalizing really ought to be fallible - .vortex_expect( - "Failed to execute child array during canonicalization of ScalarFnArray", - ); - let ctx = ExecutionArgs::new(array.len, array.dtype.clone(), child_dtypes, child_datums); + let mut child_datums = Vec::with_capacity(array.children.len()); + for child in array.children.iter() { + let datum = child + .execute_datum_optimized(&LEGACY_SESSION) + .vortex_expect( + "Failed to execute child array during canonicalization of ScalarFnArray", + ); + child_datums.push(datum); + } + let ctx = ExecutionArgs { + datums: child_datums, + dtypes: child_dtypes, + row_count: array.len, + return_dtype: array.dtype.clone(), + }; + + let len = array.len; let result_vector = array .scalar_fn - .execute(&ctx) + .execute(ctx) .vortex_expect("Canonicalize should be fallible") - .into_vector() - .vortex_expect("Canonicalize should return a vector"); + .ensure_vector(len); result_vector.into_array(&array.dtype).to_canonical() } diff --git a/vortex-array/src/arrays/scalar_fn/vtable/mod.rs b/vortex-array/src/arrays/scalar_fn/vtable/mod.rs index 37a5cf643e8..e94e7e96aa8 100644 --- a/vortex-array/src/arrays/scalar_fn/vtable/mod.rs +++ b/vortex-array/src/arrays/scalar_fn/vtable/mod.rs @@ -9,30 +9,30 @@ mod visitor; use std::marker::PhantomData; use std::ops::Deref; -use std::sync::LazyLock; use itertools::Itertools; use vortex_buffer::BufferHandle; use vortex_dtype::DType; -use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_ensure; -use vortex_session::VortexSession; -use vortex_vector::Vector; use crate::Array; use crate::ArrayRef; use crate::IntoArray; +use crate::arrays::ConstantVTable; use crate::arrays::scalar_fn::array::ScalarFnArray; +use crate::arrays::scalar_fn::kernel::KernelInput; +use crate::arrays::scalar_fn::kernel::ScalarFnKernel; use crate::arrays::scalar_fn::metadata::ScalarFnMetadata; -use crate::execution::ExecutionCtx; -use crate::expr::functions; -use crate::expr::functions::scalar::ScalarFn; +use crate::expr; +use crate::expr::ExprVTable; +use crate::expr::ScalarFn; +use crate::kernel::BindCtx; +use crate::kernel::KernelRef; use crate::optimizer::rules::MatchKey; use crate::optimizer::rules::Matcher; use crate::serde::ArrayChildren; -use crate::session::ArraySession; use crate::vtable; use crate::vtable::ArrayId; use crate::vtable::ArrayVTable; @@ -40,21 +40,15 @@ use crate::vtable::ArrayVTableExt; use crate::vtable::NotSupported; use crate::vtable::VTable; -// TODO(ngates): canonicalize doesn't currently take a session, therefore we cannot dispatch -// to registered scalar function kernels. We therefore hold our own non-pluggable session here -// that contains all the built-in kernels while we migrate over to "execute" instead of canonicalize. -static SCALAR_FN_SESSION: LazyLock = - LazyLock::new(|| VortexSession::empty().with::()); - vtable!(ScalarFn); #[derive(Clone, Debug)] pub struct ScalarFnVTable { - vtable: functions::ScalarFnVTable, + vtable: ExprVTable, } impl ScalarFnVTable { - pub fn new(vtable: functions::ScalarFnVTable) -> Self { + pub fn new(vtable: ExprVTable) -> Self { Self { vtable } } } @@ -130,29 +124,37 @@ impl VTable for ScalarFnVTable { }) } - fn batch_execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { - let input_dtypes: Vec<_> = array.children().iter().map(|c| c.dtype().clone()).collect(); - let input_datums = array + fn bind_kernel(array: &Self::Array, ctx: &mut BindCtx) -> VortexResult { + let inputs: Vec<_> = array .children() .iter() - .map(|child| child.batch_execute(ctx)) + .map(|child| match child.as_opt::() { + None => child.bind_kernel(ctx).map(KernelInput::Vector), + Some(constant) => { + let scalar = constant.scalar().to_vector_scalar(); + Ok(KernelInput::Scalar(scalar)) + } + }) .try_collect()?; - let ctx = functions::ExecutionArgs::new( - array.len(), - array.dtype.clone(), + + let input_dtypes: Vec<_> = array + .children() + .iter() + .map(|child| child.dtype().clone()) + .collect(); + + Ok(Box::new(ScalarFnKernel { + scalar_fn: array.scalar_fn.clone(), + inputs, input_dtypes, - input_datums, - ); - Ok(array - .scalar_fn - .execute(&ctx)? - .into_vector() - .vortex_expect("Vector inputs should return vector outputs")) + row_count: array.len(), + return_dtype: array.dtype().clone(), + })) } } /// Array factory functions for scalar functions. -pub trait ScalarFnArrayExt: functions::VTable { +pub trait ScalarFnArrayExt: expr::VTable { fn try_new_array( &'static self, len: usize, @@ -186,7 +188,7 @@ pub trait ScalarFnArrayExt: functions::VTable { .into_array()) } } -impl ScalarFnArrayExt for V {} +impl ScalarFnArrayExt for V {} /// A matcher that matches any scalar function expression. #[derive(Debug)] @@ -205,12 +207,12 @@ impl Matcher for AnyScalarFn { /// A matcher that matches a specific scalar function expression. #[derive(Debug)] -pub struct ExactScalarFn { +pub struct ExactScalarFn { id: ArrayId, _phantom: PhantomData, } -impl From<&'static F> for ExactScalarFn { +impl From<&'static F> for ExactScalarFn { fn from(value: &'static F) -> Self { Self { id: value.id(), @@ -219,7 +221,7 @@ impl From<&'static F> for ExactScalarFn { } } -impl Matcher for ExactScalarFn { +impl Matcher for ExactScalarFn { type View<'a> = ScalarFnArrayView<'a, F>; fn key(&self) -> MatchKey { @@ -246,13 +248,13 @@ impl Matcher for ExactScalarFn { } } -pub struct ScalarFnArrayView<'a, F: functions::VTable> { +pub struct ScalarFnArrayView<'a, F: expr::VTable> { array: &'a ArrayRef, pub vtable: &'a F, pub options: &'a F::Options, } -impl Deref for ScalarFnArrayView<'_, F> { +impl Deref for ScalarFnArrayView<'_, F> { type Target = ArrayRef; fn deref(&self) -> &Self::Target { diff --git a/vortex-array/src/arrays/scalar_fn/vtable/operations.rs b/vortex-array/src/arrays/scalar_fn/vtable/operations.rs index fd7d5255667..aff294863c8 100644 --- a/vortex-array/src/arrays/scalar_fn/vtable/operations.rs +++ b/vortex-array/src/arrays/scalar_fn/vtable/operations.rs @@ -11,7 +11,7 @@ use crate::ArrayRef; use crate::IntoArray; use crate::arrays::scalar_fn::array::ScalarFnArray; use crate::arrays::scalar_fn::vtable::ScalarFnVTable; -use crate::expr::functions::ExecutionArgs; +use crate::expr::ExecutionArgs; use crate::vtable::OperationsVTable; impl OperationsVTable for ScalarFnVTable { @@ -42,16 +42,16 @@ impl OperationsVTable for ScalarFnVTable { .map(|scalar| Datum::from(scalar.to_vector_scalar())) .collect(); - let ctx = ExecutionArgs::new( - 1, - array.dtype.clone(), - array.children().iter().map(|s| s.dtype().clone()).collect(), - input_datums, - ); + let ctx = ExecutionArgs { + datums: input_datums, + dtypes: array.children().iter().map(|c| c.dtype().clone()).collect(), + row_count: 1, + return_dtype: array.dtype.clone(), + }; let _result = array .scalar_fn - .execute(&ctx) + .execute(ctx) .vortex_expect("Scalar function execution should be fallible") .into_scalar() .vortex_expect("Scalar function execution should return scalar"); diff --git a/vortex-array/src/arrays/scalar_fn/vtable/validity.rs b/vortex-array/src/arrays/scalar_fn/vtable/validity.rs index 4b43acc6cab..3e899ef09e3 100644 --- a/vortex-array/src/arrays/scalar_fn/vtable/validity.rs +++ b/vortex-array/src/arrays/scalar_fn/vtable/validity.rs @@ -3,12 +3,15 @@ use vortex_error::VortexExpect; use vortex_mask::Mask; +use vortex_vector::Datum; +use vortex_vector::ScalarOps; +use vortex_vector::VectorOps; use crate::Array; +use crate::arrays::LEGACY_SESSION; use crate::arrays::scalar_fn::array::ScalarFnArray; -use crate::arrays::scalar_fn::vtable::SCALAR_FN_SESSION; use crate::arrays::scalar_fn::vtable::ScalarFnVTable; -use crate::expr::functions::NullHandling; +use crate::executor::VectorExecutor; use crate::vtable::ValidityVTable; impl ValidityVTable for ScalarFnVTable { @@ -17,35 +20,43 @@ impl ValidityVTable for ScalarFnVTable { } fn all_valid(array: &ScalarFnArray) -> bool { - match array.scalar_fn.signature().null_handling() { - NullHandling::Propagate | NullHandling::AbsorbsNull => { - // Requires all children to guarantee all_valid - array.children().iter().all(|child| child.all_valid()) - } - NullHandling::Custom => { - // We cannot guarantee that the array is all valid without evaluating the function + match array.scalar_fn.signature().is_null_sensitive() { + true => { + // If the function is null sensitive, we cannot guarantee all valid without evaluating + // the function false } + false => { + // If the function is not null sensitive, we can guarantee all valid if all children + // are all valid + array.children().iter().all(|child| child.all_valid()) + } } } fn all_invalid(array: &ScalarFnArray) -> bool { - match array.scalar_fn.signature().null_handling() { - NullHandling::Propagate => { - // All null if any child is all null - array.children().iter().any(|child| child.all_invalid()) - } - NullHandling::AbsorbsNull | NullHandling::Custom => { - // We cannot guarantee that the array is all valid without evaluating the function + match array.scalar_fn.signature().is_null_sensitive() { + true => { + // If the function is null sensitive, we cannot guarantee all invalid without evaluating + // the function false } + false => { + // If the function is not null sensitive, we can guarantee all invalid if any child + // is all invalid + array.children().iter().any(|child| child.all_invalid()) + } } } fn validity_mask(array: &ScalarFnArray) -> Mask { - let vector = array - .execute(&SCALAR_FN_SESSION) + let datum = array + .to_array() + .execute_datum(&LEGACY_SESSION) .vortex_expect("Validity mask computation should be fallible"); - Mask::from_buffer(vector.into_bool().into_parts().0) + match datum { + Datum::Scalar(s) => Mask::new(array.len, s.is_valid()), + Datum::Vector(v) => v.validity().clone(), + } } } diff --git a/vortex-array/src/arrays/scalar_fn/vtable/visitor.rs b/vortex-array/src/arrays/scalar_fn/vtable/visitor.rs index e43b1049068..e650ba0d6a3 100644 --- a/vortex-array/src/arrays/scalar_fn/vtable/visitor.rs +++ b/vortex-array/src/arrays/scalar_fn/vtable/visitor.rs @@ -12,7 +12,7 @@ impl VisitorVTable for ScalarFnVTable { fn visit_children(array: &ScalarFnArray, visitor: &mut dyn ArrayChildVisitor) { for (idx, child) in array.children.iter().enumerate() { - let name = array.scalar_fn.signature().arg_name(idx); + let name = array.scalar_fn.signature().child_name(idx); visitor.visit_child(name.as_ref(), child.as_ref()) } } diff --git a/vortex-array/src/arrays/struct_/mod.rs b/vortex-array/src/arrays/struct_/mod.rs index 91cc3e54505..dcdb2736ea2 100644 --- a/vortex-array/src/arrays/struct_/mod.rs +++ b/vortex-array/src/arrays/struct_/mod.rs @@ -3,10 +3,11 @@ mod array; pub use array::StructArray; - mod compute; +mod rules; mod vtable; +pub(crate) use rules::*; pub use vtable::StructVTable; #[cfg(test)] diff --git a/vortex-array/src/arrays/struct_/rules.rs b/vortex-array/src/arrays/struct_/rules.rs new file mode 100644 index 00000000000..d73cfddcd78 --- /dev/null +++ b/vortex-array/src/arrays/struct_/rules.rs @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_error::VortexResult; + +use crate::Array; +use crate::ArrayRef; +use crate::IntoArray; +use crate::arrays::ConstantArray; +use crate::arrays::ExactScalarFn; +use crate::arrays::ScalarFnArrayExt; +use crate::arrays::ScalarFnArrayView; +use crate::arrays::StructArray; +use crate::arrays::StructVTable; +use crate::expr::EmptyOptions; +use crate::expr::GetItem; +use crate::expr::Mask; +use crate::optimizer::rules::ArrayParentReduceRule; +use crate::optimizer::rules::Exact; +use crate::validity::Validity; +use crate::vtable::ValidityHelper; + +#[derive(Debug)] +pub(crate) struct StructGetItemRule; +impl ArrayParentReduceRule, ExactScalarFn> for StructGetItemRule { + fn child(&self) -> Exact { + Exact::from(&StructVTable) + } + + fn parent(&self) -> ExactScalarFn { + ExactScalarFn::from(&GetItem) + } + + fn reduce_parent( + &self, + child: &StructArray, + parent: ScalarFnArrayView<'_, GetItem>, + _child_idx: usize, + ) -> VortexResult> { + let field_name = parent.options; + let Some(field) = child.field_by_name_opt(field_name) else { + return Ok(None); + }; + + match child.validity() { + Validity::NonNullable | Validity::AllValid => { + // If the struct is non-nullable or all valid, the field's validity is unchanged + Ok(Some(field.clone())) + } + Validity::AllInvalid => { + // If everything is invalid, the field is also all invalid + Ok(Some( + ConstantArray::new( + vortex_scalar::Scalar::null(field.dtype().clone()), + field.len(), + ) + .into_array(), + )) + } + Validity::Array(mask) => { + // If the validity is an array, we need to combine it with the field's validity + Mask.try_new_array(field.len(), EmptyOptions, [field.clone(), mask.clone()]) + .map(Some) + } + } + } +} diff --git a/vortex-array/src/arrays/struct_/vtable/mod.rs b/vortex-array/src/arrays/struct_/vtable/mod.rs index a621793985e..39e4d29368e 100644 --- a/vortex-array/src/arrays/struct_/vtable/mod.rs +++ b/vortex-array/src/arrays/struct_/vtable/mod.rs @@ -9,12 +9,13 @@ use vortex_dtype::DType; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; -use vortex_vector::Vector; use vortex_vector::struct_::StructVector; use crate::EmptyMetadata; use crate::arrays::struct_::StructArray; -use crate::execution::ExecutionCtx; +use crate::kernel::BindCtx; +use crate::kernel::KernelRef; +use crate::kernel::kernel; use crate::serde::ArrayChildren; use crate::validity::Validity; use crate::vtable; @@ -106,14 +107,19 @@ impl VTable for StructVTable { StructArray::try_new_with_dtype(children, struct_dtype.clone(), len, validity) } - fn batch_execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { + fn bind_kernel(array: &Self::Array, ctx: &mut BindCtx) -> VortexResult { let fields: Box<[_]> = array .fields() .iter() - .map(|field| field.batch_execute(ctx)) + .map(|field| field.bind_kernel(ctx)) .try_collect()?; - // SAFETY: we know that all field lengths match the struct array length, and the validity - Ok(unsafe { StructVector::new_unchecked(Arc::new(fields), array.validity_mask()) }.into()) + let validity_mask = array.validity_mask(); + + Ok(kernel(move || { + // SAFETY: we know that all field lengths match the struct array length, and the validity + let fields = fields.into_iter().map(|k| k.execute()).try_collect()?; + Ok(unsafe { StructVector::new_unchecked(Arc::new(fields), validity_mask) }.into()) + })) } } diff --git a/vortex-array/src/arrays/varbinview/vtable/mod.rs b/vortex-array/src/arrays/varbinview/vtable/mod.rs index 677cd767ccc..5d03e4f4713 100644 --- a/vortex-array/src/arrays/varbinview/vtable/mod.rs +++ b/vortex-array/src/arrays/varbinview/vtable/mod.rs @@ -10,14 +10,15 @@ use vortex_dtype::DType; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; -use vortex_vector::Vector; use vortex_vector::binaryview::BinaryVector; use vortex_vector::binaryview::BinaryView; use vortex_vector::binaryview::StringVector; use crate::EmptyMetadata; use crate::arrays::varbinview::VarBinViewArray; -use crate::execution::ExecutionCtx; +use crate::kernel::BindCtx; +use crate::kernel::KernelRef; +use crate::kernel::ready; use crate::serde::ArrayChildren; use crate::validity::Validity; use crate::vtable; @@ -36,6 +37,9 @@ mod visitor; vtable!(VarBinView); +#[derive(Debug)] +pub struct VarBinViewVTable; + impl VTable for VarBinViewVTable { type Array = VarBinViewArray; @@ -104,28 +108,25 @@ impl VTable for VarBinViewVTable { VarBinViewArray::try_new(views, Arc::from(buffers), dtype.clone(), validity) } - fn batch_execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult { + fn bind_kernel(array: &Self::Array, _ctx: &mut BindCtx) -> VortexResult { Ok(match array.dtype() { - DType::Utf8(_) => unsafe { + DType::Utf8(_) => ready(unsafe { StringVector::new_unchecked( array.views().clone(), Arc::new(array.buffers().to_vec().into_boxed_slice()), array.validity_mask(), ) - } - .into(), - DType::Binary(_) => unsafe { + .into() + }), + DType::Binary(_) => ready(unsafe { BinaryVector::new_unchecked( array.views().clone(), Arc::new(array.buffers().to_vec().into_boxed_slice()), array.validity_mask(), ) - } - .into(), + .into() + }), _ => unreachable!("VarBinViewArray must have Binary or Utf8 dtype"), }) } } - -#[derive(Debug)] -pub struct VarBinViewVTable; diff --git a/vortex-array/src/scalar_fns/mod.rs b/vortex-array/src/builtins.rs similarity index 55% rename from vortex-array/src/scalar_fns/mod.rs rename to vortex-array/src/builtins.rs index ddf14a9bed8..2fb41b88be7 100644 --- a/vortex-array/src/scalar_fns/mod.rs +++ b/vortex-array/src/builtins.rs @@ -13,19 +13,18 @@ use vortex_dtype::DType; use vortex_dtype::FieldName; use vortex_error::VortexResult; -use crate::Array; -use crate::ArrayRef; use crate::arrays::ScalarFnArrayExt; +use crate::expr::Binary; +use crate::expr::Cast; +use crate::expr::EmptyOptions; use crate::expr::Expression; -use crate::expr::ScalarFnExprExt; -use crate::expr::functions::EmptyOptions; - -pub mod binary; -pub mod cast; -pub mod get_item; -pub mod is_null; -pub mod mask; -pub mod not; +use crate::expr::GetItem; +use crate::expr::IsNull; +use crate::expr::Mask; +use crate::expr::Not; +use crate::expr::VTableExt; +use crate::Array; +use crate::ArrayRef; /// A collection of built-in scalar functions that can be applied to expressions or arrays. pub trait ExprBuiltins: Sized { @@ -49,23 +48,23 @@ pub trait ExprBuiltins: Sized { impl ExprBuiltins for Expression { fn cast(&self, dtype: DType) -> VortexResult { - cast::CastFn.try_new_expr(dtype, [self.clone()]) + Cast.try_new_expr(dtype, [self.clone()]) } fn get_item(&self, field_name: impl Into) -> VortexResult { - get_item::GetItemFn.try_new_expr(field_name.into(), [self.clone()]) + GetItem.try_new_expr(field_name.into(), [self.clone()]) } fn is_null(&self) -> VortexResult { - is_null::IsNullFn.try_new_expr(EmptyOptions, [self.clone()]) + IsNull.try_new_expr(EmptyOptions, [self.clone()]) } fn mask(&self, mask: Expression) -> VortexResult { - mask::MaskFn.try_new_expr(EmptyOptions, [self.clone(), mask]) + Mask.try_new_expr(EmptyOptions, [self.clone(), mask]) } fn not(&self) -> VortexResult { - not::NotFn.try_new_expr(EmptyOptions, [self.clone()]) + Not.try_new_expr(EmptyOptions, [self.clone()]) } } @@ -82,30 +81,74 @@ pub trait ArrayBuiltins: Sized { /// Mask the array using the given boolean mask. /// The resulting array's validity is the intersection of the original array's validity /// and the mask's validity. - fn mask(&self, mask: &ArrayRef) -> VortexResult; + fn mask(&self, mask: ArrayRef) -> VortexResult; /// Boolean negation. fn not(&self) -> VortexResult; + + /// Add two arrays together. + fn add(&self, other: ArrayRef) -> VortexResult; + + /// Subtract two arrays. + fn sub(&self, other: ArrayRef) -> VortexResult; + + /// Multiply two arrays together. + fn mul(&self, other: ArrayRef) -> VortexResult; + + /// Divide two arrays. + fn div(&self, other: ArrayRef) -> VortexResult; } impl ArrayBuiltins for ArrayRef { fn cast(&self, dtype: DType) -> VortexResult { - cast::CastFn.try_new_array(self.len(), dtype, [self.clone()]) + Cast.try_new_array(self.len(), dtype, [self.clone()]) } fn get_item(&self, field_name: impl Into) -> VortexResult { - get_item::GetItemFn.try_new_array(self.len(), field_name.into(), [self.clone()]) + GetItem.try_new_array(self.len(), field_name.into(), [self.clone()]) } fn is_null(&self) -> VortexResult { - is_null::IsNullFn.try_new_array(self.len(), EmptyOptions, [self.clone()]) + IsNull.try_new_array(self.len(), EmptyOptions, [self.clone()]) } - fn mask(&self, mask: &ArrayRef) -> VortexResult { - mask::MaskFn.try_new_array(self.len(), EmptyOptions, [self.clone(), mask.clone()]) + fn mask(&self, mask: ArrayRef) -> VortexResult { + Mask.try_new_array(self.len(), EmptyOptions, [self.clone(), mask.clone()]) } fn not(&self) -> VortexResult { - not::NotFn.try_new_array(self.len(), EmptyOptions, [self.clone()]) + Not.try_new_array(self.len(), EmptyOptions, [self.clone()]) + } + + fn add(&self, other: ArrayRef) -> VortexResult { + Binary.try_new_array( + self.len(), + crate::expr::Operator::Add, + [self.clone(), other], + ) + } + + fn sub(&self, other: ArrayRef) -> VortexResult { + Binary.try_new_array( + self.len(), + crate::expr::Operator::Sub, + [self.clone(), other], + ) + } + + fn mul(&self, other: ArrayRef) -> VortexResult { + Binary.try_new_array( + self.len(), + crate::expr::Operator::Mul, + [self.clone(), other], + ) + } + + fn div(&self, other: ArrayRef) -> VortexResult { + Binary.try_new_array( + self.len(), + crate::expr::Operator::Div, + [self.clone(), other], + ) } } diff --git a/vortex-array/src/compute/between.rs b/vortex-array/src/compute/between.rs index ee0e8409ee6..28679bfdf24 100644 --- a/vortex-array/src/compute/between.rs +++ b/vortex-array/src/compute/between.rs @@ -2,6 +2,8 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use std::any::Any; +use std::fmt::Display; +use std::fmt::Formatter; use std::sync::LazyLock; use arcref::ArcRef; @@ -253,6 +255,22 @@ pub struct BetweenOptions { pub upper_strict: StrictComparison, } +impl Display for BetweenOptions { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let lower_op = if self.lower_strict.is_strict() { + "<" + } else { + "<=" + }; + let upper_op = if self.upper_strict.is_strict() { + "<" + } else { + "<=" + }; + write!(f, "lower_strict: {}, upper_strict: {}", lower_op, upper_op) + } +} + impl Options for BetweenOptions { fn as_any(&self) -> &dyn Any { self diff --git a/vortex-array/src/compute/filter.rs b/vortex-array/src/compute/filter.rs index 2aab501dc34..01437f3c6d1 100644 --- a/vortex-array/src/compute/filter.rs +++ b/vortex-array/src/compute/filter.rs @@ -222,7 +222,7 @@ impl Kernel for FilterKernelAdapter { let Some(array) = inputs.array.as_opt::() else { return Ok(None); }; - let filtered = V::filter(&self.0, array, inputs.mask)?; + let filtered = ::filter(&self.0, array, inputs.mask)?; Ok(Some(filtered.into())) } } diff --git a/vortex-array/src/compute/like.rs b/vortex-array/src/compute/like.rs index fc352d2bd8e..fca5e6ad12f 100644 --- a/vortex-array/src/compute/like.rs +++ b/vortex-array/src/compute/like.rs @@ -2,6 +2,8 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use std::any::Any; +use std::fmt::Display; +use std::fmt::Formatter; use std::sync::LazyLock; use arcref::ArcRef; @@ -150,6 +152,19 @@ pub struct LikeOptions { pub case_insensitive: bool, } +impl Display for LikeOptions { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + if self.negated { + write!(f, "NOT ")?; + } + if self.case_insensitive { + write!(f, "ILIKE") + } else { + write!(f, "LIKE") + } + } +} + impl Options for LikeOptions { fn as_any(&self) -> &dyn Any { self diff --git a/vortex-array/src/execution/batch.rs b/vortex-array/src/execution/batch.rs deleted file mode 100644 index 928a9b7487b..00000000000 --- a/vortex-array/src/execution/batch.rs +++ /dev/null @@ -1,43 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use vortex_error::VortexResult; -use vortex_vector::Vector; - -use crate::ArrayRef; - -/// Type-alias for heap-allocated batch execution kernels. -pub type BatchKernelRef = Box; - -/// Trait for batch execution kernels that produce a vector result. -pub trait BatchKernel: 'static + Send { - fn execute(self: Box) -> VortexResult; -} - -/// Adapter to create a batch kernel from a closure. -pub struct BatchKernelAdapter(F); -impl VortexResult + Send + 'static> BatchKernel for BatchKernelAdapter { - fn execute(self: Box) -> VortexResult { - self.0() - } -} - -/// Create a batch execution kernel from the given closure. -#[inline(always)] -pub fn kernel VortexResult + Send + 'static>(f: F) -> BatchKernelRef { - Box::new(BatchKernelAdapter(f)) -} - -/// Context for binding batch execution kernels. -/// -/// By binding child arrays through this context, we can perform common subtree elimination and -/// share canonicalized results across multiple kernels. -pub trait BindCtx { - /// Bind the given array and optional selection to produce a batch kernel, possibly reusing - /// previously bound results from this context. - fn bind( - &mut self, - array: &ArrayRef, - selection: Option<&ArrayRef>, - ) -> VortexResult; -} diff --git a/vortex-array/src/execution/mask.rs b/vortex-array/src/execution/mask.rs deleted file mode 100644 index 81d786f0137..00000000000 --- a/vortex-array/src/execution/mask.rs +++ /dev/null @@ -1,88 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use vortex_dtype::DType; -use vortex_dtype::Nullability::NonNullable; -use vortex_error::VortexExpect; -use vortex_error::VortexResult; -use vortex_error::vortex_bail; -use vortex_mask::Mask; - -use crate::ArrayRef; -use crate::execution::BindCtx; - -pub enum MaskExecution { - AllTrue(usize), - AllFalse(usize), - Lazy(Box VortexResult + Send + 'static>), -} - -impl MaskExecution { - pub fn lazy VortexResult + Send + 'static>(f: F) -> MaskExecution { - MaskExecution::Lazy(Box::new(f)) - } - - pub fn execute(self) -> VortexResult { - match self { - MaskExecution::AllTrue(len) => Ok(Mask::new_true(len)), - MaskExecution::AllFalse(len) => Ok(Mask::new_false(len)), - MaskExecution::Lazy(f) => f(), - } - } -} - -impl dyn BindCtx + '_ { - /// Bind an optional selection mask into a `MaskExecution`. - /// - /// The caller must provide a mask length to handle the case where no mask is provided. - pub fn bind_selection( - &mut self, - mask_len: usize, - mask: Option<&ArrayRef>, - ) -> VortexResult { - match mask { - Some(mask) => { - assert_eq!(mask.len(), mask_len); - self.bind_mask(mask) - } - None => Ok(MaskExecution::AllTrue(mask_len)), - } - } - - /// Bind a non-nullable boolean array into a `MaskExecution`. - /// - /// This binding will optimize for constant arrays or other array types that can be more - /// efficiently converted into a `Mask`. - pub fn bind_mask(&mut self, mask: &ArrayRef) -> VortexResult { - if !matches!(mask.dtype(), DType::Bool(NonNullable)) { - vortex_bail!( - "Expected non-nullable boolean array for mask binding, got {}", - mask.dtype() - ); - } - - // Check for a constant mask - if let Some(scalar) = mask.as_constant() { - let constant = scalar - .as_bool() - .value() - .vortex_expect("checked non-nullable"); - let len = mask.len(); - if constant { - return Ok(MaskExecution::AllTrue(len)); - } else { - return Ok(MaskExecution::AllFalse(len)); - } - } - - // TODO(ngates): we may want to support creating masks from iterator of slices, in which - // case we could check for run-end encoding here? - - // If none of the above patterns match, we fall back to canonicalizing. - let execution = self.bind(mask, None)?; - Ok(MaskExecution::lazy(move || { - let mask = execution.execute()?.into_bool(); - Ok(Mask::from(mask.bits().clone())) - })) - } -} diff --git a/vortex-array/src/execution/mod.rs b/vortex-array/src/execution/mod.rs deleted file mode 100644 index 0c20c86f901..00000000000 --- a/vortex-array/src/execution/mod.rs +++ /dev/null @@ -1,30 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -mod batch; -mod mask; -mod validity; - -pub use batch::*; -pub use mask::*; -use vortex_session::VortexSession; - -/// Execution context for batch array compute. -// NOTE(ngates): This context will eventually hold cached resources for execution, such as CSE -// nodes, and may well eventually support a type-map interface for arrays to stash arbitrary -// execution-related data. -pub struct ExecutionCtx { - session: VortexSession, -} - -impl ExecutionCtx { - /// Create a new execution context with the given session. - pub(crate) fn new(session: VortexSession) -> Self { - Self { session } - } - - /// Get the session associated with this execution context. - pub fn session(&self) -> &VortexSession { - &self.session - } -} diff --git a/vortex-array/src/execution/validity.rs b/vortex-array/src/execution/validity.rs deleted file mode 100644 index 507f16d6b26..00000000000 --- a/vortex-array/src/execution/validity.rs +++ /dev/null @@ -1,51 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use vortex_compute::filter::Filter; -use vortex_error::VortexResult; -use vortex_mask::Mask; - -use crate::ArrayRef; -use crate::execution::BindCtx; -use crate::execution::MaskExecution; -use crate::validity::Validity; - -impl dyn BindCtx + '_ { - /// Bind a validity helper into a [`MaskExecution`]. - pub fn bind_validity( - &mut self, - validity: &Validity, - array_len: usize, - selection: Option<&ArrayRef>, - ) -> VortexResult { - match selection { - None => match validity { - Validity::NonNullable | Validity::AllValid => Ok(MaskExecution::AllTrue(array_len)), - Validity::AllInvalid => Ok(MaskExecution::AllFalse(array_len)), - Validity::Array(validity) => self.bind_mask(validity), - }, - Some(selection) => { - let selection = self.bind_mask(selection)?; - match validity { - Validity::NonNullable | Validity::AllValid => { - Ok(MaskExecution::lazy(move || { - Ok(Mask::AllTrue(selection.execute()?.true_count())) - })) - } - Validity::AllInvalid => Ok(MaskExecution::lazy(move || { - Ok(Mask::AllFalse(selection.execute()?.true_count())) - })), - Validity::Array(validity) => { - let validity = self.bind_mask(validity)?; - Ok(MaskExecution::lazy(move || { - let validity = validity.execute()?; - let selection = selection.execute()?; - // We perform a take on the validity mask using the selection mask. - Ok(validity.filter(&selection)) - })) - } - } - } - } - } -} diff --git a/vortex-array/src/executor.rs b/vortex-array/src/executor.rs new file mode 100644 index 00000000000..af321f7e34a --- /dev/null +++ b/vortex-array/src/executor.rs @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; +use vortex_session::VortexSession; +use vortex_vector::Datum; +use vortex_vector::Vector; +use vortex_vector::VectorOps; + +use crate::Array; +use crate::ArrayRef; +use crate::arrays::ConstantVTable; +use crate::kernel::BindCtx; +use crate::session::ArraySessionExt; + +/// Executor for exporting a Vortex [`Vector`] or [`Datum`] from an [`ArrayRef`]. +pub trait VectorExecutor { + /// Execute the array and return the resulting datum after running the optimizer. + fn execute_datum_optimized(&self, session: &VortexSession) -> VortexResult; + /// Execute the array and return the resulting datum. + fn execute_datum(&self, session: &VortexSession) -> VortexResult; + /// Execute the array and return the resulting vector after running the optimizer. + fn execute_vector_optimized(&self, session: &VortexSession) -> VortexResult; + /// Execute the array and return the resulting vector. + fn execute_vector(&self, session: &VortexSession) -> VortexResult; +} + +impl VectorExecutor for ArrayRef { + fn execute_datum_optimized(&self, session: &VortexSession) -> VortexResult { + session + .arrays() + .optimizer() + .optimize_array(self)? + .execute_datum(session) + } + + fn execute_datum(&self, session: &VortexSession) -> VortexResult { + // Attempt to short-circuit constant arrays. + if let Some(constant) = self.as_opt::() { + return Ok(Datum::Scalar(constant.scalar().to_vector_scalar())); + } + + let mut ctx = BindCtx::new(session.clone()); + + // NOTE(ngates): in the future we can choose a different mode of execution, or run + // optimization here, etc. + let kernel = self.bind_kernel(&mut ctx)?; + let result = kernel.execute()?; + + vortex_ensure!( + result.len() == self.len(), + "Result length mismatch for {}", + self.encoding_id() + ); + + #[cfg(debug_assertions)] + { + vortex_ensure!( + vortex_vector::vector_matches_dtype(&result, self.dtype()), + "Executed vector dtype mismatch for {}", + self.encoding_id(), + ); + } + + Ok(Datum::Vector(result)) + } + + fn execute_vector_optimized(&self, session: &VortexSession) -> VortexResult { + session + .arrays() + .optimizer() + .optimize_array(self)? + .execute_vector(session) + } + + fn execute_vector(&self, session: &VortexSession) -> VortexResult { + let len = self.len(); + Ok(self.execute_datum(session)?.ensure_vector(len)) + } +} diff --git a/vortex-array/src/expr/analysis/fallible.rs b/vortex-array/src/expr/analysis/fallible.rs index a7cfb2835df..c1ffb2bffda 100644 --- a/vortex-array/src/expr/analysis/fallible.rs +++ b/vortex-array/src/expr/analysis/fallible.rs @@ -6,7 +6,11 @@ use crate::expr::analysis::BooleanLabels; use crate::expr::label_tree; pub fn label_is_fallible(expr: &Expression) -> BooleanLabels<'_> { - label_tree(expr, |expr| expr.is_fallible(), |acc, &child| acc | child) + label_tree( + expr, + |expr| expr.signature().is_fallible(), + |acc, &child| acc | child, + ) } #[cfg(test)] diff --git a/vortex-array/src/expr/analysis/immediate_access.rs b/vortex-array/src/expr/analysis/immediate_access.rs index c05679b2c98..6dc26a7c950 100644 --- a/vortex-array/src/expr/analysis/immediate_access.rs +++ b/vortex-array/src/expr/analysis/immediate_access.rs @@ -24,9 +24,9 @@ pub fn annotate_scope_access(scope: &StructFields) -> impl AnnotationFn() { - if get_item.child(0).is::() { - return vec![get_item.data().clone()]; + if let Some(field_name) = expr.as_opt::() { + if expr.child(0).is::() { + return vec![field_name.clone()]; } } else if expr.is::() { return scope.names().iter().cloned().collect(); diff --git a/vortex-array/src/expr/analysis/null_sensitive.rs b/vortex-array/src/expr/analysis/null_sensitive.rs index 1711792bf9e..5e1c4d0c996 100644 --- a/vortex-array/src/expr/analysis/null_sensitive.rs +++ b/vortex-array/src/expr/analysis/null_sensitive.rs @@ -15,7 +15,7 @@ pub type BooleanLabels<'a> = HashMap<&'a Expression, bool>; pub fn label_null_sensitive(expr: &Expression) -> BooleanLabels<'_> { label_tree( expr, - |expr| expr.is_null_sensitive(), + |expr| expr.signature().is_null_sensitive(), |acc, &child| acc | child, ) } diff --git a/vortex-array/src/expr/display.rs b/vortex-array/src/expr/display.rs index abc5460d429..8f2f4161382 100644 --- a/vortex-array/src/expr/display.rs +++ b/vortex-array/src/expr/display.rs @@ -3,8 +3,10 @@ use std::fmt::Display; use std::fmt::Formatter; +use std::ops::Deref; use crate::expr::Expression; +use crate::expr::ScalarFn; pub enum DisplayFormat { Compact, @@ -17,10 +19,11 @@ impl Display for DisplayTreeExpr<'_> { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { pub use termtree::Tree; fn make_tree(expr: &Expression) -> Result, std::fmt::Error> { - let node_name = format!("{}", ExpressionDebug(expr)); + let scalar_fn: &ScalarFn = expr.deref(); + let node_name = format!("{}", scalar_fn); // Get child names for display purposes - let child_names = (0..expr.children().len()).map(|i| expr.child_name(i)); + let child_names = (0..expr.children().len()).map(|i| expr.signature().child_name(i)); let children = expr.children(); let child_trees: Result>, _> = children @@ -40,18 +43,6 @@ impl Display for DisplayTreeExpr<'_> { } } -struct ExpressionDebug<'a>(&'a Expression); -impl Display for ExpressionDebug<'_> { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - // Special-case when expression has no data to avoid trailing space. - if self.0.data().is::<()>() { - return write!(f, "{}", self.0.id().as_ref()); - } - write!(f, "{} ", self.0.id().as_ref())?; - self.0.vtable().as_dyn().fmt_data(self.0.data().as_ref(), f) - } -} - #[cfg(test)] mod tests { use vortex_dtype::DType; @@ -104,76 +95,111 @@ mod tests { } #[test] - fn test_display_tree() { + fn test_display_tree_root() { use insta::assert_snapshot; - let root_expr = root(); - assert_snapshot!(root_expr.display_tree().to_string(), @"vortex.root"); + assert_snapshot!(root_expr.display_tree().to_string(), @"vortex.root()"); + } + #[test] + fn test_display_tree_literal() { + use insta::assert_snapshot; let lit_expr = lit(42); - assert_snapshot!(lit_expr.display_tree().to_string(), @"vortex.literal 42i32"); + assert_snapshot!(lit_expr.display_tree().to_string(), @"vortex.literal(42i32)"); + } + #[test] + fn test_display_tree_get_item() { + use insta::assert_snapshot; let get_item_expr = get_item("my_field", root()); - assert_snapshot!(get_item_expr.display_tree().to_string(), @r#" - vortex.get_item "my_field" - └── input: vortex.root - "#); + assert_snapshot!(get_item_expr.display_tree().to_string(), @r" + vortex.get_item(my_field) + └── input: vortex.root() + "); + } + #[test] + fn test_display_tree_binary() { + use insta::assert_snapshot; let binary_expr = gt(get_item("x", root()), lit(10)); - assert_snapshot!(binary_expr.display_tree().to_string(), @r#" - vortex.binary > - ├── lhs: vortex.get_item "x" - │ └── input: vortex.root - └── rhs: vortex.literal 10i32 - "#); + assert_snapshot!(binary_expr.display_tree().to_string(), @r" + vortex.binary(>) + ├── lhs: vortex.get_item(x) + │ └── input: vortex.root() + └── rhs: vortex.literal(10i32) + "); + } + #[test] + fn test_display_tree_complex_binary() { + use insta::assert_snapshot; let complex_binary = and( eq(get_item("name", root()), lit("alice")), gt(get_item("age", root()), lit(18)), ); assert_snapshot!(complex_binary.display_tree().to_string(), @r#" - vortex.binary and - ├── lhs: vortex.binary = - │ ├── lhs: vortex.get_item "name" - │ │ └── input: vortex.root - │ └── rhs: vortex.literal "alice" - └── rhs: vortex.binary > - ├── lhs: vortex.get_item "age" - │ └── input: vortex.root - └── rhs: vortex.literal 18i32 + vortex.binary(and) + ├── lhs: vortex.binary(=) + │ ├── lhs: vortex.get_item(name) + │ │ └── input: vortex.root() + │ └── rhs: vortex.literal("alice") + └── rhs: vortex.binary(>) + ├── lhs: vortex.get_item(age) + │ └── input: vortex.root() + └── rhs: vortex.literal(18i32) "#); + } + #[test] + fn test_display_tree_select() { + use insta::assert_snapshot; let select_expr = select(["name", "age"], root()); assert_snapshot!(select_expr.display_tree().to_string(), @r" - vortex.select include={name, age} - └── child: vortex.root + vortex.select({name, age}) + └── child: vortex.root() "); + } + #[test] + fn test_display_tree_select_exclude() { + use insta::assert_snapshot; let select_exclude_expr = select_exclude(["internal_id", "metadata"], root()); assert_snapshot!(select_exclude_expr.display_tree().to_string(), @r" - vortex.select exclude={internal_id, metadata} - └── child: vortex.root + vortex.select(~{internal_id, metadata}) + └── child: vortex.root() "); + } + #[test] + fn test_display_tree_cast() { + use insta::assert_snapshot; let cast_expr = cast( get_item("value", root()), DType::Primitive(PType::I64, Nullability::NonNullable), ); - assert_snapshot!(cast_expr.display_tree().to_string(), @r#" - vortex.cast i64 - └── input: vortex.get_item "value" - └── input: vortex.root - "#); + assert_snapshot!(cast_expr.display_tree().to_string(), @r" + vortex.cast(i64) + └── input: vortex.get_item(value) + └── input: vortex.root() + "); + } + #[test] + fn test_display_tree_not() { + use insta::assert_snapshot; let not_expr = not(eq(get_item("active", root()), lit(true))); - assert_snapshot!(not_expr.display_tree().to_string(), @r#" - vortex.not - └── input: vortex.binary = - ├── lhs: vortex.get_item "active" - │ └── input: vortex.root - └── rhs: vortex.literal true - "#); + assert_snapshot!(not_expr.display_tree().to_string(), @r" + vortex.not() + └── input: vortex.binary(=) + ├── lhs: vortex.get_item(active) + │ └── input: vortex.root() + └── rhs: vortex.literal(true) + "); + } + #[test] + fn test_display_tree_between() { + use insta::assert_snapshot; let between_expr = between( get_item("score", root()), lit(0), @@ -183,15 +209,18 @@ mod tests { upper_strict: StrictComparison::NonStrict, }, ); - assert_snapshot!(between_expr.display_tree().to_string(), @r#" - vortex.between BetweenOptions { lower_strict: NonStrict, upper_strict: NonStrict } - ├── array: vortex.get_item "score" - │ └── input: vortex.root - ├── lower: vortex.literal 0i32 - └── upper: vortex.literal 100i32 - "#); + assert_snapshot!(between_expr.display_tree().to_string(), @r" + vortex.between(lower_strict: <=, upper_strict: <=) + ├── array: vortex.get_item(score) + │ └── input: vortex.root() + ├── lower: vortex.literal(0i32) + └── upper: vortex.literal(100i32) + "); + } - // Test nested expression + #[test] + fn test_display_tree_nested() { + use insta::assert_snapshot; let nested_expr = select( ["result"], cast( @@ -207,16 +236,20 @@ mod tests { DType::Bool(Nullability::NonNullable), ), ); - assert_snapshot!(nested_expr.display_tree().to_string(), @r#" - vortex.select include={result} - └── child: vortex.cast bool - └── input: vortex.between BetweenOptions { lower_strict: Strict, upper_strict: NonStrict } - ├── array: vortex.get_item "score" - │ └── input: vortex.root - ├── lower: vortex.literal 50i32 - └── upper: vortex.literal 100i32 - "#); + assert_snapshot!(nested_expr.display_tree().to_string(), @r" + vortex.select({result}) + └── child: vortex.cast(bool) + └── input: vortex.between(lower_strict: <, upper_strict: <=) + ├── array: vortex.get_item(score) + │ └── input: vortex.root() + ├── lower: vortex.literal(50i32) + └── upper: vortex.literal(100i32) + "); + } + #[test] + fn test_display_tree_pack() { + use insta::assert_snapshot; let select_from_pack_expr = select( ["fizz", "buzz"], pack( @@ -228,15 +261,15 @@ mod tests { Nullability::Nullable, ), ); - assert_snapshot!(select_from_pack_expr.display_tree().to_string(), @r#" - vortex.select include={fizz, buzz} - └── child: vortex.pack PackOptions { names: FieldNames([FieldName("fizz"), FieldName("bar"), FieldName("buzz")]), nullability: Nullable } - ├── fizz: vortex.root - ├── bar: vortex.literal 5i32 - └── buzz: vortex.binary = - ├── lhs: vortex.literal 42i32 - └── rhs: vortex.get_item "answer" - └── input: vortex.root - "#); + assert_snapshot!(select_from_pack_expr.display_tree().to_string(), @r" + vortex.select({fizz, buzz}) + └── child: vortex.pack(names: [fizz, bar, buzz], nullability: ?) + ├── fizz: vortex.root() + ├── bar: vortex.literal(5i32) + └── buzz: vortex.binary(=) + ├── lhs: vortex.literal(42i32) + └── rhs: vortex.get_item(answer) + └── input: vortex.root() + "); } } diff --git a/vortex-array/src/expr/expression.rs b/vortex-array/src/expr/expression.rs index fa798b4b17c..05ea00601a3 100644 --- a/vortex-array/src/expr/expression.rs +++ b/vortex-array/src/expr/expression.rs @@ -1,29 +1,23 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use std::any::Any; use std::fmt; use std::fmt::Debug; use std::fmt::Display; use std::fmt::Formatter; use std::hash::Hash; -use std::hash::Hasher; +use std::ops::Deref; use std::sync::Arc; use itertools::Itertools; use vortex_dtype::DType; use vortex_error::VortexExpect; use vortex_error::VortexResult; -use vortex_vector::Vector; -use vortex_vector::VectorOps; +use vortex_error::vortex_ensure; use crate::ArrayRef; -use crate::expr::ChildName; -use crate::expr::ExecutionArgs; -use crate::expr::ExprId; -use crate::expr::ExprVTable; -use crate::expr::ExpressionView; use crate::expr::Root; +use crate::expr::ScalarFn; use crate::expr::StatsCatalog; use crate::expr::VTable; use crate::expr::display::DisplayTreeExpr; @@ -33,96 +27,62 @@ use crate::expr::stats::Stat; /// /// Expressions represent scalar computations that can be performed on data. Each /// expression consists of an encoding (vtable), heap-allocated metadata, and child expressions. -#[derive(Clone)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct Expression { - /// The vtable for this expression. - vtable: ExprVTable, - /// The instance data for this expression. - data: Arc, + /// The scalar fn for this node. + scalar_fn: ScalarFn, /// Any children of this expression. children: Arc<[Expression]>, } +impl Deref for Expression { + type Target = ScalarFn; + + fn deref(&self) -> &Self::Target { + &self.scalar_fn + } +} + impl Expression { - /// Create a new expression from a vtable. - pub fn try_new( - vtable: V, - data: V::Instance, + /// Create a new expression node from a scalar_fn expression and its children. + pub fn try_new( + scalar_fn: ScalarFn, children: impl Into>, ) -> VortexResult { - let vtable = ExprVTable::new::(vtable); - let data = Arc::new(data); - Self::try_new_erased(vtable, data, children.into()) - } + let children: Arc<[Expression]> = children.into(); - /// Create a new expression from a static vtable. - pub fn new_static( - vtable: &'static V, - data: V::Instance, - children: impl Into>, - ) -> Self { - let vtable = ExprVTable::new_static::(vtable); - let data = Arc::new(data); - Self { - vtable, - data, - children: children.into(), - } - } + vortex_ensure!( + scalar_fn.signature().arity().matches(children.len()), + "Expression arity mismatch: expected {} children but got {}", + scalar_fn.signature().arity(), + children.len() + ); - /// Creates a new expression with the given encoding, metadata, and children. - /// - /// # Errors - /// - /// Returns an error if the provided `encoding` is not compatible with the - /// `metadata` and `children` or the encoding's own validation logic fails. - pub(super) fn try_new_erased( - vtable: ExprVTable, - data: Arc, - children: Arc<[Expression]>, - ) -> VortexResult { - let this = Self { - vtable, - data, + Ok(Self { + scalar_fn, children, - }; - // Validate that the encoding is compatible with the metadata and children. - this.vtable.as_dyn().validate(&this)?; - Ok(this) + }) } - /// Returns if the expression is an instance of the given vtable. + /// Returns true if this expression is of the given vtable type. pub fn is(&self) -> bool { - self.vtable.is::() - } - - /// Returns a typed view of this expression for the given vtable. - /// - /// # Panics - /// - /// Panics if the expression's encoding or metadata cannot be cast to the specified vtable. - pub fn as_(&self) -> ExpressionView<'_, V> { - ExpressionView::maybe_new(self).vortex_expect("Failed to downcast expression {} to {}") + self.vtable().is::() } - /// Returns a typed view of this expression for the given vtable, if the types match. - pub fn as_opt(&self) -> Option> { - ExpressionView::maybe_new(self) + /// Returns the typed options for this expression if it matches the given vtable type. + pub fn as_opt(&self) -> Option<&V::Options> { + self.options().as_any().downcast_ref::() } - /// Returns the expression ID. - pub fn id(&self) -> ExprId { - self.vtable.as_dyn().id() + /// Returns the typed options for this expression if it matches the given vtable type. + pub fn as_(&self) -> &V::Options { + self.as_opt::() + .vortex_expect("Expression options type mismatch") } - /// Returns the expression's vtable. - pub fn vtable(&self) -> &ExprVTable { - &self.vtable - } - - /// Returns the opaque data of the expression. - pub fn data(&self) -> &Arc { - &self.data + /// Returns the scalar fn vtable for this expression. + pub fn scalar_fn(&self) -> &ScalarFn { + &self.scalar_fn } /// Returns the children of this expression. @@ -135,60 +95,39 @@ impl Expression { &self.children[n] } - /// Returns the name of the n'th child of this expression. - pub fn child_name(&self, n: usize) -> ChildName { - self.vtable.as_dyn().child_name(self.data().as_ref(), n) - } - /// Replace the children of this expression with the provided new children. pub fn with_children(mut self, children: impl Into>) -> VortexResult { - self.children = children.into(); - self.vtable.as_dyn().validate(&self)?; + let children = children.into(); + vortex_ensure!( + self.signature().arity().matches(children.len()), + "Expression arity mismatch: expected {} children but got {}", + self.signature().arity(), + children.len() + ); + self.children = children; Ok(self) } - /// Returns the serialized metadata for this expression. - pub fn serialize_metadata(&self) -> VortexResult>> { - self.vtable.as_dyn().serialize(self.data.as_ref()) - } - /// Computes the return dtype of this expression given the input dtype. pub fn return_dtype(&self, scope: &DType) -> VortexResult { - self.vtable.as_dyn().return_dtype(self, scope) - } - - /// Evaluates the expression in the given scope. - pub fn evaluate(&self, scope: &ArrayRef) -> VortexResult { - self.vtable.as_dyn().evaluate(self, scope) - } - - /// Executes the expression over the given vector input scope. - pub fn execute(&self, vector: &Vector, dtype: &DType) -> VortexResult { - // We special-case the "root" expression that must extract that scope vector directly. if self.is::() { - return Ok(vector.clone()); + return Ok(scope.clone()); } - let return_dtype = self.return_dtype(dtype)?; - let child_dtypes: Vec<_> = self - .children - .iter() - .map(|child| child.return_dtype(dtype)) - .try_collect()?; - let child_vectors: Vec<_> = self + let dtypes: Vec<_> = self .children .iter() - .map(|child| child.execute(vector, dtype)) + .map(|c| c.return_dtype(scope)) .try_collect()?; + self.scalar_fn.return_dtype(&dtypes) + } - let args = ExecutionArgs { - vectors: child_vectors, - dtypes: child_dtypes, - row_count: vector.len(), - return_dtype, - }; - - self.vtable.as_dyn().execute(&self.data, args) + /// Evaluates the expression in the given scope, returning an array. + pub fn evaluate(&self, scope: &ArrayRef) -> VortexResult { + if self.is::() { + return Ok(scope.clone()); + } + self.scalar_fn.evaluate(self, scope) } /// An expression over zone-statistics which implies all records in the zone evaluate to false. @@ -210,7 +149,7 @@ impl Expression { /// Some expressions, in theory, have falsifications but this function does not support them /// such as `x < (y < z)` or `x LIKE "needle%"`. pub fn stat_falsification(&self, catalog: &dyn StatsCatalog) -> Option { - self.vtable.as_dyn().stat_falsification(self, catalog) + self.vtable().as_dyn().stat_falsification(self, catalog) } /// Returns an expression representing the zoned statistic for the given stat, if available. @@ -222,46 +161,25 @@ impl Expression { /// NOTE(gatesn): we currently cannot represent statistics over nested fields. Please file an /// issue to discuss a solution to this. pub fn stat_expression(&self, stat: Stat, catalog: &dyn StatsCatalog) -> Option { - self.vtable.as_dyn().stat_expression(self, stat, catalog) + self.vtable().as_dyn().stat_expression(self, stat, catalog) } /// Returns an expression representing the zoned maximum statistic, if available. - /// - /// See [`Self::stat_expression`] for details. pub fn stat_min(&self, catalog: &dyn StatsCatalog) -> Option { self.stat_expression(Stat::Min, catalog) } /// Returns an expression representing the zoned maximum statistic, if available. - /// - /// See [`Self::stat_expression`] for details. pub fn stat_max(&self, catalog: &dyn StatsCatalog) -> Option { self.stat_expression(Stat::Max, catalog) } - /// Returns whether this expression itself is null-sensitive. - /// See [`VTable::is_null_sensitive`]. - pub fn is_null_sensitive(&self) -> bool { - self.vtable.as_dyn().is_null_sensitive(self.data.as_ref()) - } - - /// Returns whether this expression itself is fallible. - /// See [`VTable::is_fallible`]. - pub fn is_fallible(&self) -> bool { - self.vtable.as_dyn().is_fallible(self.data.as_ref()) - } - /// Format the expression as a compact string. /// /// Since this is a recursive formatter, it is exposed on the public Expression type. /// See fmt_data that is only implemented on the vtable trait. pub fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result { - self.vtable.as_dyn().fmt_sql(self, f) - } - - /// Format the instance data of the expression as a string for rendering.. - pub fn fmt_data(&self, f: &mut Formatter<'_>) -> fmt::Result { - self.vtable.as_dyn().fmt_data(self.data().as_ref(), f) + self.vtable().as_dyn().fmt_sql(self, f) } /// Display the expression as a formatted tree structure. @@ -327,50 +245,3 @@ impl Display for Expression { self.fmt_sql(f) } } - -struct FormatExpressionData<'a> { - vtable: &'a ExprVTable, - data: &'a Arc, -} - -impl<'a> Debug for FormatExpressionData<'a> { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - self.vtable.as_dyn().fmt_data(self.data.as_ref(), f) - } -} - -impl Debug for Expression { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - f.debug_struct("Expression") - .field("vtable", &self.vtable) - .field( - "data", - &FormatExpressionData { - vtable: &self.vtable, - data: &self.data, - }, - ) - .field("children", &self.children) - .finish() - } -} - -impl PartialEq for Expression { - fn eq(&self, other: &Self) -> bool { - self.vtable.as_dyn().id() == other.vtable.as_dyn().id() - && self - .vtable - .as_dyn() - .dyn_eq(self.data.as_ref(), other.data.as_ref()) - && self.children.eq(&other.children) - } -} -impl Eq for Expression {} - -impl Hash for Expression { - fn hash(&self, state: &mut H) { - self.vtable.as_dyn().id().hash(state); - self.vtable.as_dyn().dyn_hash(self.data.as_ref(), state); - self.children.hash(state); - } -} diff --git a/vortex-array/src/expr/exprs/between.rs b/vortex-array/src/expr/exprs/between.rs index 2135a5ecd8a..a0689b1dd34 100644 --- a/vortex-array/src/expr/exprs/between.rs +++ b/vortex-array/src/expr/exprs/between.rs @@ -9,14 +9,17 @@ use vortex_dtype::DType::Bool; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; +use vortex_error::vortex_err; use vortex_proto::expr as pb; +use vortex_vector::Datum; use crate::ArrayRef; use crate::compute::BetweenOptions; use crate::compute::between as between_compute; +use crate::expr::Arity; use crate::expr::ChildName; +use crate::expr::ExecutionArgs; use crate::expr::ExprId; -use crate::expr::ExpressionView; use crate::expr::StatsCatalog; use crate::expr::VTable; use crate::expr::VTableExt; @@ -38,13 +41,13 @@ use crate::expr::exprs::operators::Operator; pub struct Between; impl VTable for Between { - type Instance = BetweenOptions; + type Options = BetweenOptions; fn id(&self) -> ExprId { ExprId::from("vortex.between") } - fn serialize(&self, instance: &Self::Instance) -> VortexResult>> { + fn serialize(&self, instance: &Self::Options) -> VortexResult>> { Ok(Some( pb::BetweenOpts { lower_strict: instance.lower_strict.is_strict(), @@ -54,9 +57,9 @@ impl VTable for Between { )) } - fn deserialize(&self, metadata: &[u8]) -> VortexResult> { + fn deserialize(&self, metadata: &[u8]) -> VortexResult { let opts = pb::BetweenOpts::decode(metadata)?; - Ok(Some(BetweenOptions { + Ok(BetweenOptions { lower_strict: if opts.lower_strict { crate::compute::StrictComparison::Strict } else { @@ -67,20 +70,14 @@ impl VTable for Between { } else { crate::compute::StrictComparison::NonStrict }, - })) + }) } - fn validate(&self, expr: &ExpressionView) -> VortexResult<()> { - if expr.children().len() != 3 { - vortex_bail!( - "Between expression requires exactly 3 children, got {}", - expr.children().len() - ); - } - Ok(()) + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(3) } - fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName { + fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName { match child_idx { 0 => ChildName::from("array"), 1 => ChildName::from("lower"), @@ -89,8 +86,12 @@ impl VTable for Between { } } - fn fmt_sql(&self, expr: &ExpressionView, f: &mut Formatter<'_>) -> std::fmt::Result { - let options = expr.data(); + fn fmt_sql( + &self, + options: &Self::Options, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { let lower_op = if options.lower_strict.is_strict() { "<" } else { @@ -104,27 +105,27 @@ impl VTable for Between { write!( f, "({} {} {} {} {})", - expr.lower(), + expr.child(1), lower_op, - expr.child(), + expr.child(0), upper_op, - expr.upper() + expr.child(2) ) } - fn return_dtype(&self, expr: &ExpressionView, scope: &DType) -> VortexResult { - let arr_dt = expr.child().return_dtype(scope)?; - let lower_dt = expr.lower().return_dtype(scope)?; - let upper_dt = expr.upper().return_dtype(scope)?; + fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { + let arr_dt = &arg_dtypes[0]; + let lower_dt = &arg_dtypes[1]; + let upper_dt = &arg_dtypes[2]; - if !arr_dt.eq_ignore_nullability(&lower_dt) { + if !arr_dt.eq_ignore_nullability(lower_dt) { vortex_bail!( "Array dtype {} does not match lower dtype {}", arr_dt, lower_dt ); } - if !arr_dt.eq_ignore_nullability(&upper_dt) { + if !arr_dt.eq_ignore_nullability(upper_dt) { vortex_bail!( "Array dtype {} does not match upper dtype {}", arr_dt, @@ -137,51 +138,76 @@ impl VTable for Between { )) } - fn evaluate(&self, expr: &ExpressionView, scope: &ArrayRef) -> VortexResult { - let arr = expr.child().evaluate(scope)?; - let lower = expr.lower().evaluate(scope)?; - let upper = expr.upper().evaluate(scope)?; - between_compute(&arr, &lower, &upper, expr.data()) - } - - fn stat_falsification( + fn evaluate( &self, - expr: &ExpressionView, - catalog: &dyn StatsCatalog, - ) -> Option { - expr.to_binary_expr().stat_falsification(catalog) + options: &Self::Options, + expr: &Expression, + scope: &ArrayRef, + ) -> VortexResult { + let arr = expr.child(0).evaluate(scope)?; + let lower = expr.child(1).evaluate(scope)?; + let upper = expr.child(2).evaluate(scope)?; + between_compute(&arr, &lower, &upper, options) } - fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool { - false - } -} + fn execute(&self, options: &Self::Options, args: ExecutionArgs) -> VortexResult { + let [arr, lower, upper]: [Datum; _] = args + .datums + .try_into() + .map_err(|_| vortex_err!("Expected 3 arguments for Between expression",))?; + let [arr_dt, lower_dt, upper_dt]: [DType; _] = args + .dtypes + .try_into() + .map_err(|_| vortex_err!("Expected 3 dtypes for Between expression",))?; -impl ExpressionView<'_, Between> { - pub fn child(&self) -> &Expression { - &self.children()[0] - } - - pub fn lower(&self) -> &Expression { - &self.children()[1] - } + let lower_bound = Binary + .bind(options.lower_strict.to_operator().into()) + .execute(ExecutionArgs { + datums: vec![lower, arr.clone()], + dtypes: vec![lower_dt, arr_dt.clone()], + row_count: args.row_count, + return_dtype: args.return_dtype.clone(), + })?; + let upper_bound = Binary + .bind(options.upper_strict.to_operator().into()) + .execute(ExecutionArgs { + datums: vec![arr, upper], + dtypes: vec![arr_dt, upper_dt], + row_count: args.row_count, + return_dtype: args.return_dtype.clone(), + })?; - pub fn upper(&self) -> &Expression { - &self.children()[2] + Binary.bind(Operator::And).execute(ExecutionArgs { + datums: vec![lower_bound, upper_bound], + dtypes: vec![args.return_dtype.clone(), args.return_dtype.clone()], + row_count: args.row_count, + return_dtype: args.return_dtype, + }) } - pub fn to_binary_expr(&self) -> Expression { - let options = self.data(); - let arr = self.children()[0].clone(); - let lower = self.children()[1].clone(); - let upper = self.children()[2].clone(); + fn stat_falsification( + &self, + options: &Self::Options, + expr: &Expression, + catalog: &dyn StatsCatalog, + ) -> Option { + let arr = expr.child(0).clone(); + let lower = expr.child(1).clone(); + let upper = expr.child(2).clone(); let lhs = Binary.new_expr( options.lower_strict.to_operator().into(), [lower, arr.clone()], ); let rhs = Binary.new_expr(options.upper_strict.to_operator().into(), [arr, upper]); - Binary.new_expr(Operator::And, [lhs, rhs]) + + Binary + .new_expr(Operator::And, [lhs, rhs]) + .stat_falsification(catalog) + } + + fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { + false } } diff --git a/vortex-array/src/expr/exprs/binary.rs b/vortex-array/src/expr/exprs/binary.rs index dfe1bb03f31..af7a19ec36a 100644 --- a/vortex-array/src/expr/exprs/binary.rs +++ b/vortex-array/src/expr/exprs/binary.rs @@ -3,12 +3,18 @@ use std::fmt::Formatter; +use arrow_ord::cmp; use prost::Message; +use vortex_compute::arrow::IntoArrow; +use vortex_compute::arrow::IntoVector; use vortex_dtype::DType; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; +use vortex_error::vortex_err; use vortex_proto::expr as pb; +use vortex_vector::Datum; +use vortex_vector::VectorOps; use crate::ArrayRef; use crate::compute; @@ -19,10 +25,10 @@ use crate::compute::div; use crate::compute::mul; use crate::compute::or_kleene; use crate::compute::sub; +use crate::expr::Arity; use crate::expr::ChildName; +use crate::expr::ExecutionArgs; use crate::expr::ExprId; -use crate::expr::ExpressionView; -use crate::expr::ScalarFnExprExt; use crate::expr::StatsCatalog; use crate::expr::VTable; use crate::expr::VTableExt; @@ -30,18 +36,17 @@ use crate::expr::expression::Expression; use crate::expr::exprs::literal::lit; use crate::expr::exprs::operators::Operator; use crate::expr::stats::Stat; -use crate::scalar_fns::binary; pub struct Binary; impl VTable for Binary { - type Instance = Operator; + type Options = Operator; fn id(&self) -> ExprId { ExprId::from("vortex.binary") } - fn serialize(&self, instance: &Self::Instance) -> VortexResult>> { + fn serialize(&self, instance: &Self::Options) -> VortexResult>> { Ok(Some( pb::BinaryOpts { op: (*instance).into(), @@ -50,17 +55,16 @@ impl VTable for Binary { )) } - fn deserialize(&self, metadata: &[u8]) -> VortexResult> { + fn deserialize(&self, metadata: &[u8]) -> VortexResult { let opts = pb::BinaryOpts::decode(metadata)?; - Ok(Some(Operator::try_from(opts.op)?)) + Operator::try_from(opts.op) } - fn validate(&self, _expr: &ExpressionView) -> VortexResult<()> { - // TODO(ngates): check the dtypes. - Ok(()) + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(2) } - fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName { + fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName { match child_idx { 0 => ChildName::from("lhs"), 1 => ChildName::from("rhs"), @@ -68,24 +72,25 @@ impl VTable for Binary { } } - fn fmt_sql(&self, expr: &ExpressionView, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt_sql( + &self, + operator: &Operator, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { write!(f, "(")?; - expr.lhs().fmt_sql(f)?; - write!(f, " {} ", expr.operator())?; - expr.rhs().fmt_sql(f)?; + expr.child(0).fmt_sql(f)?; + write!(f, " {} ", operator)?; + expr.child(1).fmt_sql(f)?; write!(f, ")") } - fn fmt_data(&self, instance: &Self::Instance, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", *instance) - } - - fn return_dtype(&self, expr: &ExpressionView, scope: &DType) -> VortexResult { - let lhs = expr.lhs().return_dtype(scope)?; - let rhs = expr.rhs().return_dtype(scope)?; + fn return_dtype(&self, operator: &Operator, arg_dtypes: &[DType]) -> VortexResult { + let lhs = &arg_dtypes[0]; + let rhs = &arg_dtypes[1]; - if expr.operator().is_arithmetic() { - if lhs.is_primitive() && lhs.eq_ignore_nullability(&rhs) { + if operator.is_arithmetic() { + if lhs.is_primitive() && lhs.eq_ignore_nullability(rhs) { return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability())); } vortex_bail!( @@ -98,11 +103,16 @@ impl VTable for Binary { Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into())) } - fn evaluate(&self, expr: &ExpressionView, scope: &ArrayRef) -> VortexResult { - let lhs = expr.lhs().evaluate(scope)?; - let rhs = expr.rhs().evaluate(scope)?; - - match expr.operator() { + fn evaluate( + &self, + operator: &Operator, + expr: &Expression, + scope: &ArrayRef, + ) -> VortexResult { + let lhs = expr.child(0).evaluate(scope)?; + let rhs = expr.child(1).evaluate(scope)?; + + match operator { Operator::Eq => compare(&lhs, &rhs, compute::Operator::Eq), Operator::NotEq => compare(&lhs, &rhs, compute::Operator::NotEq), Operator::Lt => compare(&lhs, &rhs, compute::Operator::Lt), @@ -118,9 +128,80 @@ impl VTable for Binary { } } + fn execute(&self, op: &Operator, args: ExecutionArgs) -> VortexResult { + let [lhs, rhs]: [Datum; _] = args + .datums + .try_into() + .map_err(|_| vortex_err!("Wrong arg count"))?; + + match op { + Operator::And => { + // FIXME(ngates): implement logical compute over datums + let lhs = lhs.ensure_vector(args.row_count).into_bool().into_arrow()?; + let rhs = rhs.ensure_vector(args.row_count).into_bool().into_arrow()?; + return Ok(Datum::Vector( + arrow_arith::boolean::and_kleene(&lhs, &rhs)? + .into_vector()? + .into(), + )); + } + Operator::Or => { + // FIXME(ngates): implement logical compute over datums + let lhs = lhs.ensure_vector(args.row_count).into_bool().into_arrow()?; + let rhs = rhs.ensure_vector(args.row_count).into_bool().into_arrow()?; + return Ok(Datum::Vector( + arrow_arith::boolean::or_kleene(&lhs, &rhs)? + .into_vector()? + .into(), + )); + } + _ => {} + } + + let lhs = lhs.into_arrow()?; + let rhs = rhs.into_arrow()?; + + let vector = match op { + Operator::Eq => cmp::eq(lhs.as_ref(), rhs.as_ref())?.into_vector()?.into(), + Operator::NotEq => cmp::neq(lhs.as_ref(), rhs.as_ref())?.into_vector()?.into(), + Operator::Gt => cmp::gt(lhs.as_ref(), rhs.as_ref())?.into_vector()?.into(), + Operator::Gte => cmp::gt_eq(lhs.as_ref(), rhs.as_ref())? + .into_vector()? + .into(), + Operator::Lt => cmp::lt(lhs.as_ref(), rhs.as_ref())?.into_vector()?.into(), + Operator::Lte => cmp::lt_eq(lhs.as_ref(), rhs.as_ref())? + .into_vector()? + .into(), + + Operator::Add => { + arrow_arith::numeric::add(lhs.as_ref(), rhs.as_ref())?.into_vector()? + } + Operator::Sub => { + arrow_arith::numeric::sub(lhs.as_ref(), rhs.as_ref())?.into_vector()? + } + Operator::Mul => { + arrow_arith::numeric::mul(lhs.as_ref(), rhs.as_ref())?.into_vector()? + } + Operator::Div => { + arrow_arith::numeric::div(lhs.as_ref(), rhs.as_ref())?.into_vector()? + } + Operator::And | Operator::Or => { + unreachable!("Already dealt with above") + } + }; + + // Arrow computed over scalar datums + if vector.len() == 1 && args.row_count != 1 { + return Ok(Datum::Scalar(vector.scalar_at(0))); + } + + Ok(Datum::Vector(vector)) + } + fn stat_falsification( &self, - expr: &ExpressionView, + operator: &Operator, + expr: &Expression, catalog: &dyn StatsCatalog, ) -> Option { // Wrap another predicate with an optional NaNCount check, if the stat is available. @@ -157,13 +238,15 @@ impl VTable for Binary { } } - match expr.operator() { + let lhs = expr.child(0); + let rhs = expr.child(1); + match operator { Operator::Eq => { - let min_lhs = expr.lhs().stat_min(catalog); - let max_lhs = expr.lhs().stat_max(catalog); + let min_lhs = lhs.stat_min(catalog); + let max_lhs = lhs.stat_max(catalog); - let min_rhs = expr.rhs().stat_min(catalog); - let max_rhs = expr.rhs().stat_max(catalog); + let min_rhs = rhs.stat_min(catalog); + let max_rhs = rhs.stat_max(catalog); let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b)); let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b)); @@ -171,99 +254,64 @@ impl VTable for Binary { let min_max_check = left.into_iter().chain(right).reduce(or)?; // NaN is not captured by the min/max stat, so we must check NaNCount before pruning - Some(with_nan_predicate( - expr.lhs(), - expr.rhs(), - min_max_check, - catalog, - )) + Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) } Operator::NotEq => { - let min_lhs = expr.lhs().stat_min(catalog)?; - let max_lhs = expr.lhs().stat_max(catalog)?; + let min_lhs = lhs.stat_min(catalog)?; + let max_lhs = lhs.stat_max(catalog)?; - let min_rhs = expr.rhs().stat_min(catalog)?; - let max_rhs = expr.rhs().stat_max(catalog)?; + let min_rhs = rhs.stat_min(catalog)?; + let max_rhs = rhs.stat_max(catalog)?; let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs)); - Some(with_nan_predicate( - expr.lhs(), - expr.rhs(), - min_max_check, - catalog, - )) + Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) } Operator::Gt => { - let min_max_check = - lt_eq(expr.lhs().stat_max(catalog)?, expr.rhs().stat_min(catalog)?); - - Some(with_nan_predicate( - expr.lhs(), - expr.rhs(), - min_max_check, - catalog, - )) + let min_max_check = lt_eq(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?); + + Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) } Operator::Gte => { // NaN is not captured by the min/max stat, so we must check NaNCount before pruning - let min_max_check = - lt(expr.lhs().stat_max(catalog)?, expr.rhs().stat_min(catalog)?); - - Some(with_nan_predicate( - expr.lhs(), - expr.rhs(), - min_max_check, - catalog, - )) + let min_max_check = lt(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?); + + Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) } Operator::Lt => { // NaN is not captured by the min/max stat, so we must check NaNCount before pruning - let min_max_check = - gt_eq(expr.lhs().stat_min(catalog)?, expr.rhs().stat_max(catalog)?); - - Some(with_nan_predicate( - expr.lhs(), - expr.rhs(), - min_max_check, - catalog, - )) + let min_max_check = gt_eq(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?); + + Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) } Operator::Lte => { // NaN is not captured by the min/max stat, so we must check NaNCount before pruning - let min_max_check = - gt(expr.lhs().stat_min(catalog)?, expr.rhs().stat_max(catalog)?); - - Some(with_nan_predicate( - expr.lhs(), - expr.rhs(), - min_max_check, - catalog, - )) + let min_max_check = gt(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?); + + Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) } - Operator::And => expr - .lhs() + Operator::And => lhs .stat_falsification(catalog) .into_iter() - .chain(expr.rhs().stat_falsification(catalog)) + .chain(rhs.stat_falsification(catalog)) .reduce(or), Operator::Or => Some(and( - expr.lhs().stat_falsification(catalog)?, - expr.rhs().stat_falsification(catalog)?, + lhs.stat_falsification(catalog)?, + rhs.stat_falsification(catalog)?, )), Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None, } } - fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool { + fn is_null_sensitive(&self, _operator: &Operator) -> bool { false } - fn is_fallible(&self, instance: &Self::Instance) -> bool { + fn is_fallible(&self, operator: &Operator) -> bool { // Opt-in not out for fallibility. // Arithmetic operations could be better modelled here. let infallible = matches!( - instance, + operator, Operator::Eq | Operator::NotEq | Operator::Gt @@ -276,24 +324,6 @@ impl VTable for Binary { !infallible } - - fn expr_v2(&self, view: &ExpressionView) -> VortexResult { - ScalarFnExprExt::try_new_expr(&binary::BinaryFn, view.operator(), view.children().clone()) - } -} - -impl ExpressionView<'_, Binary> { - pub fn lhs(&self) -> &Expression { - &self.children()[0] - } - - pub fn rhs(&self) -> &Expression { - &self.children()[1] - } - - pub fn operator(&self) -> Operator { - *self.data() - } } /// Create a new [`Binary`] using the [`Eq`](crate::expr::exprs::operators::Operator::Eq) operator. @@ -642,15 +672,6 @@ mod tests { ); } - #[test] - fn test_debug_print() { - let expr = gt(lit(1), lit(2)); - assert_eq!( - format!("{expr:?}"), - "Expression { vtable: vortex.binary, data: >, children: [Expression { vtable: vortex.literal, data: 1i32, children: [] }, Expression { vtable: vortex.literal, data: 2i32, children: [] }] }" - ); - } - #[test] fn test_display_print() { let expr = gt(lit(1), lit(2)); diff --git a/vortex-array/src/expr/exprs/cast/mod.rs b/vortex-array/src/expr/exprs/cast.rs similarity index 75% rename from vortex-array/src/expr/exprs/cast/mod.rs rename to vortex-array/src/expr/exprs/cast.rs index 9808d42a9cc..53cdcf40856 100644 --- a/vortex-array/src/expr/exprs/cast/mod.rs +++ b/vortex-array/src/expr/exprs/cast.rs @@ -8,17 +8,16 @@ use prost::Message; use vortex_dtype::DType; use vortex_error::VortexExpect; use vortex_error::VortexResult; -use vortex_error::vortex_bail; use vortex_error::vortex_err; use vortex_proto::expr as pb; -use vortex_vector::Vector; +use vortex_vector::Datum; use crate::ArrayRef; use crate::compute::cast as compute_cast; +use crate::expr::Arity; use crate::expr::ChildName; use crate::expr::ExecutionArgs; use crate::expr::ExprId; -use crate::expr::ExpressionView; use crate::expr::StatsCatalog; use crate::expr::VTable; use crate::expr::VTableExt; @@ -29,66 +28,59 @@ use crate::expr::stats::Stat; pub struct Cast; impl VTable for Cast { - type Instance = DType; + type Options = DType; fn id(&self) -> ExprId { ExprId::from("vortex.cast") } - fn serialize(&self, instance: &Self::Instance) -> VortexResult>> { + fn serialize(&self, dtype: &DType) -> VortexResult>> { Ok(Some( pb::CastOpts { - target: Some(instance.into()), + target: Some(dtype.into()), } .encode_to_vec(), )) } - fn deserialize(&self, metadata: &[u8]) -> VortexResult> { - Ok(Some( - pb::CastOpts::decode(metadata)? - .target - .as_ref() - .ok_or_else(|| vortex_err!("Missing target dtype in Cast expression"))? - .try_into()?, - )) + fn deserialize(&self, metadata: &[u8]) -> VortexResult { + pb::CastOpts::decode(metadata)? + .target + .as_ref() + .ok_or_else(|| vortex_err!("Missing target dtype in Cast expression"))? + .try_into() } - fn validate(&self, expr: &ExpressionView) -> VortexResult<()> { - if expr.children().len() != 1 { - vortex_bail!( - "Cast expression requires exactly 1 child, got {}", - expr.children().len() - ); - } - Ok(()) + fn arity(&self, _options: &DType) -> Arity { + Arity::Exact(1) } - fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName { + fn child_name(&self, _instance: &DType, child_idx: usize) -> ChildName { match child_idx { 0 => ChildName::from("input"), _ => unreachable!("Invalid child index {} for Cast expression", child_idx), } } - fn fmt_sql(&self, expr: &ExpressionView, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt_sql(&self, dtype: &DType, expr: &Expression, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "cast(")?; expr.children()[0].fmt_sql(f)?; - write!(f, " as {}", expr.data())?; + write!(f, " as {}", dtype)?; write!(f, ")") } - fn fmt_data(&self, instance: &Self::Instance, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", instance) - } - - fn return_dtype(&self, expr: &ExpressionView, _scope: &DType) -> VortexResult { - Ok(expr.data().clone()) + fn return_dtype(&self, dtype: &DType, _arg_dtypes: &[DType]) -> VortexResult { + Ok(dtype.clone()) } - fn evaluate(&self, expr: &ExpressionView, scope: &ArrayRef) -> VortexResult { + fn evaluate( + &self, + dtype: &DType, + expr: &Expression, + scope: &ArrayRef, + ) -> VortexResult { let array = expr.children()[0].evaluate(scope)?; - compute_cast(&array, expr.data()).map_err(|e| { + compute_cast(&array, dtype).map_err(|e| { e.with_context(format!( "Failed to cast array of dtype {} to {}", array.dtype(), @@ -97,9 +89,18 @@ impl VTable for Cast { }) } + fn execute(&self, target_dtype: &DType, mut args: ExecutionArgs) -> VortexResult { + let input = args + .datums + .pop() + .vortex_expect("missing input for Cast expression"); + vortex_compute::cast::Cast::cast(&input, target_dtype) + } + fn stat_expression( &self, - expr: &ExpressionView, + dtype: &DType, + expr: &Expression, stat: Stat, catalog: &dyn StatsCatalog, ) -> Option { @@ -114,7 +115,7 @@ impl VTable for Cast { // We cast min/max to the new type expr.child(0) .stat_expression(stat, catalog) - .map(|x| cast(x, expr.data().clone())) + .map(|x| cast(x, dtype.clone())) } Stat::NullCount => { // if !expr.data().is_nullable() { @@ -129,16 +130,8 @@ impl VTable for Cast { } } - fn execute(&self, target_dtype: &DType, mut args: ExecutionArgs) -> VortexResult { - let input = args - .vectors - .pop() - .vortex_expect("missing input for Cast expression"); - vortex_compute::cast::Cast::cast(&input, target_dtype) - } - // This might apply a nullability - fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool { + fn is_null_sensitive(&self, _instance: &DType) -> bool { true } } diff --git a/vortex-array/src/expr/exprs/dynamic.rs b/vortex-array/src/expr/exprs/dynamic.rs index ff325a4fb3e..9dbe9819868 100644 --- a/vortex-array/src/expr/exprs/dynamic.rs +++ b/vortex-array/src/expr/exprs/dynamic.rs @@ -15,6 +15,7 @@ use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_scalar::Scalar; use vortex_scalar::ScalarValue; +use vortex_vector::Datum; use crate::Array; use crate::ArrayRef; @@ -22,10 +23,11 @@ use crate::IntoArray; use crate::arrays::ConstantArray; use crate::compute::Operator; use crate::compute::compare; +use crate::expr::Arity; use crate::expr::ChildName; +use crate::expr::ExecutionArgs; use crate::expr::ExprId; use crate::expr::Expression; -use crate::expr::ExpressionView; use crate::expr::StatsCatalog; use crate::expr::VTable; use crate::expr::VTableExt; @@ -39,116 +41,130 @@ use crate::expr::traversal::TraversalOrder; pub struct DynamicComparison; impl VTable for DynamicComparison { - type Instance = DynamicComparisonExpr; + type Options = DynamicComparisonExpr; fn id(&self) -> ExprId { ExprId::new_ref("vortex.dynamic") } - fn validate(&self, expr: &ExpressionView) -> VortexResult<()> { - if expr.children().len() != 1 { - vortex_bail!( - "DynamicComparison expression requires exactly one child, got {}", - expr.children().len() - ); - } - Ok(()) + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(1) } - fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName { + fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName { match child_idx { 0 => ChildName::from("lhs"), _ => unreachable!(), } } - fn fmt_sql(&self, expr: &ExpressionView, f: &mut Formatter<'_>) -> std::fmt::Result { - expr.lhs().fmt_sql(f)?; - write!(f, " {} dynamic(", expr.data())?; - match expr.scalar() { + fn fmt_sql( + &self, + dynamic: &DynamicComparisonExpr, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { + expr.child(0).fmt_sql(f)?; + write!(f, " {} dynamic(", dynamic)?; + match dynamic.scalar() { None => write!(f, "")?, Some(scalar) => write!(f, "{}", scalar)?, } write!(f, ")") } - fn return_dtype(&self, expr: &ExpressionView, scope: &DType) -> VortexResult { - let lhs = expr.lhs().return_dtype(scope)?; - if !expr.data().rhs.dtype.eq_ignore_nullability(&lhs) { + fn return_dtype( + &self, + dynamic: &DynamicComparisonExpr, + arg_dtypes: &[DType], + ) -> VortexResult { + let lhs = &arg_dtypes[0]; + if !dynamic.rhs.dtype.eq_ignore_nullability(lhs) { vortex_bail!( "Incompatible dtypes for dynamic comparison: expected {} (ignore nullability) but got {}", - &expr.data().rhs.dtype, + &dynamic.rhs.dtype, lhs ); } Ok(DType::Bool( - lhs.nullability() | expr.data().rhs.dtype.nullability(), + lhs.nullability() | dynamic.rhs.dtype.nullability(), )) } - fn evaluate(&self, expr: &ExpressionView, scope: &ArrayRef) -> VortexResult { - if let Some(value) = expr.scalar() { - let lhs = expr.lhs().evaluate(scope)?; + fn evaluate( + &self, + dynamic: &DynamicComparisonExpr, + expr: &Expression, + scope: &ArrayRef, + ) -> VortexResult { + if let Some(value) = dynamic.rhs.scalar() { + let lhs = expr.child(0).evaluate(scope)?; let rhs = ConstantArray::new(value, scope.len()); - return compare(lhs.as_ref(), rhs.as_ref(), expr.data().operator); + return compare(lhs.as_ref(), rhs.as_ref(), dynamic.operator); } // Otherwise, we return the default value. let lhs = expr.return_dtype(scope.dtype())?; Ok(ConstantArray::new( Scalar::new( - DType::Bool(lhs.nullability() | expr.data().rhs.dtype.nullability()), - expr.data().default.into(), + DType::Bool(lhs.nullability() | dynamic.rhs.dtype.nullability()), + dynamic.default.into(), ), scope.len(), ) .into_array()) } + fn execute(&self, _data: &Self::Options, _args: ExecutionArgs) -> VortexResult { + todo!() + } + fn stat_falsification( &self, - expr: &ExpressionView, + dynamic: &DynamicComparisonExpr, + expr: &Expression, catalog: &dyn StatsCatalog, ) -> Option { - match expr.data().operator { + let lhs = expr.child(0); + match dynamic.operator { Operator::Gt => Some(DynamicComparison.new_expr( DynamicComparisonExpr { operator: Operator::Lte, - rhs: expr.data().rhs.clone(), - default: !expr.data().default, + rhs: dynamic.rhs.clone(), + default: !dynamic.default, }, - vec![expr.lhs().stat_max(catalog)?], + vec![lhs.stat_max(catalog)?], )), Operator::Gte => Some(DynamicComparison.new_expr( DynamicComparisonExpr { operator: Operator::Lt, - rhs: expr.data().rhs.clone(), - default: !expr.data().default, + rhs: dynamic.rhs.clone(), + default: !dynamic.default, }, - vec![expr.lhs().stat_max(catalog)?], + vec![lhs.stat_max(catalog)?], )), Operator::Lt => Some(DynamicComparison.new_expr( DynamicComparisonExpr { operator: Operator::Gte, - rhs: expr.data().rhs.clone(), - default: !expr.data().default, + rhs: dynamic.rhs.clone(), + default: !dynamic.default, }, - vec![expr.lhs().stat_min(catalog)?], + vec![lhs.stat_min(catalog)?], )), Operator::Lte => Some(DynamicComparison.new_expr( DynamicComparisonExpr { operator: Operator::Gt, - rhs: expr.data().rhs.clone(), - default: !expr.data().default, + rhs: dynamic.rhs.clone(), + default: !dynamic.default, }, - vec![expr.lhs().stat_min(catalog)?], + vec![lhs.stat_min(catalog)?], )), _ => None, } } // Defer to the child - fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool { + fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { false } } @@ -225,6 +241,12 @@ struct Rhs { dtype: DType, } +impl Rhs { + pub fn scalar(&self) -> Option { + (self.value)().map(|v| Scalar::new(self.dtype.clone(), v)) + } +} + impl Debug for Rhs { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("Rhs") @@ -234,16 +256,6 @@ impl Debug for Rhs { } } -impl ExpressionView<'_, DynamicComparison> { - pub fn lhs(&self) -> &Expression { - &self.children()[0] - } - - pub fn scalar(&self) -> Option { - (self.data().rhs.value)().map(|v| Scalar::new(self.data().rhs.dtype.clone(), v)) - } -} - /// A utility for checking whether any dynamic expressions have been updated. pub struct DynamicExprUpdates { exprs: Box<[DynamicComparisonExpr]>, @@ -261,7 +273,7 @@ impl DynamicExprUpdates { fn visit_down(&mut self, node: &'_ Self::NodeTy) -> VortexResult { if let Some(dynamic) = node.as_opt::() { - self.0.push(dynamic.data().clone()); + self.0.push(dynamic.clone()); } Ok(TraversalOrder::Continue) } diff --git a/vortex-array/src/expr/exprs/get_item.rs b/vortex-array/src/expr/exprs/get_item.rs new file mode 100644 index 00000000000..ce8cb3e2805 --- /dev/null +++ b/vortex-array/src/expr/exprs/get_item.rs @@ -0,0 +1,354 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::fmt::Formatter; +use std::ops::Not; + +use prost::Message; +use vortex_dtype::DType; +use vortex_dtype::FieldName; +use vortex_dtype::FieldPath; +use vortex_dtype::Nullability; +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_error::vortex_err; +use vortex_proto::expr as pb; +use vortex_vector::Datum; +use vortex_vector::ScalarOps; +use vortex_vector::VectorOps; + +use crate::ArrayRef; +use crate::ToCanonical; +use crate::builtins::ExprBuiltins; +use crate::compute::mask; +use crate::expr::Arity; +use crate::expr::ChildName; +use crate::expr::ExecutionArgs; +use crate::expr::ExprId; +use crate::expr::Expression; +use crate::expr::Pack; +use crate::expr::StatsCatalog; +use crate::expr::VTable; +use crate::expr::VTableExt; +use crate::expr::exprs::root::root; +use crate::expr::lit; +use crate::expr::stats::Stat; + +pub struct GetItem; + +impl VTable for GetItem { + type Options = FieldName; + + fn id(&self) -> ExprId { + ExprId::from("vortex.get_item") + } + + fn serialize(&self, instance: &Self::Options) -> VortexResult>> { + Ok(Some( + pb::GetItemOpts { + path: instance.to_string(), + } + .encode_to_vec(), + )) + } + + fn deserialize(&self, metadata: &[u8]) -> VortexResult { + let opts = pb::GetItemOpts::decode(metadata)?; + Ok(FieldName::from(opts.path)) + } + + fn arity(&self, _field_name: &FieldName) -> Arity { + Arity::Exact(1) + } + + fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName { + match child_idx { + 0 => ChildName::from("input"), + _ => unreachable!("Invalid child index {} for GetItem expression", child_idx), + } + } + + fn fmt_sql( + &self, + field_name: &FieldName, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { + expr.children()[0].fmt_sql(f)?; + write!(f, ".{}", field_name) + } + + fn return_dtype(&self, field_name: &FieldName, arg_dtypes: &[DType]) -> VortexResult { + let struct_dtype = &arg_dtypes[0]; + let field_dtype = struct_dtype + .as_struct_fields_opt() + .and_then(|st| st.field(field_name)) + .ok_or_else(|| { + vortex_err!("Couldn't find the {} field in the input scope", field_name) + })?; + + // Match here to avoid cloning the dtype if nullability doesn't need to change + if matches!( + (struct_dtype.nullability(), field_dtype.nullability()), + (Nullability::Nullable, Nullability::NonNullable) + ) { + return Ok(field_dtype.with_nullability(Nullability::Nullable)); + } + + Ok(field_dtype) + } + + fn evaluate( + &self, + field_name: &FieldName, + expr: &Expression, + scope: &ArrayRef, + ) -> VortexResult { + let input = expr.children()[0].evaluate(scope)?.to_struct(); + let field = input.field_by_name(field_name).cloned()?; + + match input.dtype().nullability() { + Nullability::NonNullable => Ok(field), + Nullability::Nullable => mask(&field, &input.validity_mask().not()), + } + } + + fn execute(&self, field_name: &FieldName, mut args: ExecutionArgs) -> VortexResult { + let struct_dtype = args.dtypes[0] + .as_struct_fields_opt() + .ok_or_else(|| vortex_err!("Expected struct dtype for child of GetItem expression"))?; + let field_idx = struct_dtype + .find(field_name) + .ok_or_else(|| vortex_err!("Field {} not found in struct dtype", field_name))?; + + match args.datums.pop().vortex_expect("missing input") { + Datum::Scalar(s) => { + let mut field = s.as_struct().field(field_idx); + field.mask_validity(s.is_valid()); + Ok(Datum::Scalar(field)) + } + Datum::Vector(v) => { + let mut field = v.as_struct().fields()[field_idx].clone(); + field.mask_validity(v.validity()); + Ok(Datum::Vector(field)) + } + } + } + + fn simplify_untyped( + &self, + field_name: &FieldName, + expr: &Expression, + ) -> VortexResult> { + let child = expr.child(0); + + // If the child is a Pack expression, we can directly return the corresponding child. + if let Some(pack) = child.as_opt::() { + let idx = pack + .names + .iter() + .position(|name| name == field_name) + .ok_or_else(|| { + vortex_err!( + "Cannot find field {} in pack fields {:?}", + field_name, + pack.names + ) + })?; + + let mut field = child.child(idx).clone(); + + // It's useful to simplify this node without type info, but we need to make sure + // the nullability is correct. We cannot cast since we don't have the dtype info here, + // so instead we insert a Mask expression that we know converts a child's dtype to + // nullable. + if pack.nullability.is_nullable() { + // Mask with an all-true array to ensure the field DType is nullable. + field = field.mask(lit(true))?; + } + + return Ok(Some(field)); + } + + Ok(None) + } + + fn stat_expression( + &self, + field_name: &FieldName, + _expr: &Expression, + stat: Stat, + catalog: &dyn StatsCatalog, + ) -> Option { + // TODO(ngates): I think we can do better here and support stats over nested fields. + // It would be nice if delegating to our child would return a struct of statistics + // matching the nested DType such that we can write: + // `get_item(expr.child(0).stat_expression(...), expr.data().field_name())` + + // TODO(ngates): this is a bug whereby we may return stats for a nested field of the same + // name as a field in the root struct. This should be resolved with upcoming change to + // falsify expressions, but for now I'm preserving the existing buggy behavior. + catalog.stats_ref(&FieldPath::from_name(field_name.clone()), stat) + } + + // This will apply struct nullability field. We could add a dtype?? + fn is_null_sensitive(&self, _field_name: &FieldName) -> bool { + true + } + + fn is_fallible(&self, _field_name: &FieldName) -> bool { + // If this type-checks its infallible. + false + } +} + +/// Creates an expression that accesses a field from the root array. +/// +/// Equivalent to `get_item(field, root())` - extracts a named field from the input array. +/// +/// ```rust +/// # use vortex_array::expr::col; +/// let expr = col("name"); +/// ``` +pub fn col(field: impl Into) -> Expression { + GetItem.new_expr(field.into(), vec![root()]) +} + +/// Creates an expression that extracts a named field from a struct expression. +/// +/// Accesses the specified field from the result of the child expression. +/// +/// ```rust +/// # use vortex_array::expr::{get_item, root}; +/// let expr = get_item("user_id", root()); +/// ``` +pub fn get_item(field: impl Into, child: Expression) -> Expression { + GetItem.new_expr(field.into(), vec![child]) +} + +#[cfg(test)] +mod tests { + use vortex_buffer::buffer; + use vortex_dtype::DType; + use vortex_dtype::FieldNames; + use vortex_dtype::Nullability; + use vortex_dtype::Nullability::NonNullable; + use vortex_dtype::PType; + use vortex_dtype::StructFields; + use vortex_scalar::Scalar; + + use crate::Array; + use crate::IntoArray; + use crate::arrays::StructArray; + use crate::expr::exprs::binary::checked_add; + use crate::expr::exprs::get_item::get_item; + use crate::expr::exprs::literal::lit; + use crate::expr::exprs::pack::pack; + use crate::expr::exprs::root::root; + use crate::validity::Validity; + + fn test_array() -> StructArray { + StructArray::from_fields(&[ + ("a", buffer![0i32, 1, 2].into_array()), + ("b", buffer![4i64, 5, 6].into_array()), + ]) + .unwrap() + } + + #[test] + fn get_item_by_name() { + let st = test_array(); + let get_item = get_item("a", root()); + let item = get_item.evaluate(&st.to_array()).unwrap(); + assert_eq!(item.dtype(), &DType::from(PType::I32)) + } + + #[test] + fn get_item_by_name_none() { + let st = test_array(); + let get_item = get_item("c", root()); + assert!(get_item.evaluate(&st.to_array()).is_err()); + } + + #[test] + fn get_nullable_field() { + let st = StructArray::try_new( + FieldNames::from(["a"]), + vec![buffer![1i32].into_array()], + 1, + Validity::AllInvalid, + ) + .unwrap() + .to_array(); + + let get_item = get_item("a", root()); + let item = get_item.evaluate(&st).unwrap(); + assert_eq!( + item.scalar_at(0), + Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)) + ); + } + + #[test] + fn test_pack_get_item_rule() { + // Create: pack(a: lit(1), b: lit(2)).get_item("b") + let pack_expr = pack([("a", lit(1)), ("b", lit(2))], NonNullable); + let get_item_expr = get_item("b", pack_expr); + + let result = get_item_expr + .simplify(&DType::Struct(StructFields::empty(), NonNullable)) + .unwrap(); + + assert_eq!(result, lit(2)); + } + + #[test] + fn test_multi_level_pack_get_item_simplify() { + let inner_pack = pack([("a", lit(1)), ("b", lit(2))], NonNullable); + let get_a = get_item("a", inner_pack); + + let outer_pack = pack([("x", get_a), ("y", lit(3)), ("z", lit(4))], NonNullable); + let get_z = get_item("z", outer_pack); + + let dtype = DType::Primitive(PType::I32, NonNullable); + + let result = get_z.simplify(&dtype).unwrap(); + assert_eq!(result, lit(4)); + } + + #[test] + fn test_deeply_nested_pack_get_item() { + let innermost = pack([("a", lit(42))], NonNullable); + let get_a = get_item("a", innermost); + + let level2 = pack([("b", get_a)], NonNullable); + let get_b = get_item("b", level2); + + let level3 = pack([("c", get_b)], NonNullable); + let get_c = get_item("c", level3); + + let outermost = pack([("final", get_c)], NonNullable); + let get_final = get_item("final", outermost); + + let dtype = DType::Primitive(PType::I32, NonNullable); + + let result = get_final.simplify(&dtype).unwrap(); + assert_eq!(result, lit(42)); + } + + #[test] + fn test_partial_pack_get_item_simplify() { + let inner_pack = pack([("x", lit(1)), ("y", lit(2))], NonNullable); + let get_x = get_item("x", inner_pack); + let add_expr = checked_add(get_x, lit(10)); + + let outer_pack = pack([("result", add_expr)], NonNullable); + let get_result = get_item("result", outer_pack); + + let dtype = DType::Primitive(PType::I32, NonNullable); + + let result = get_result.simplify(&dtype).unwrap(); + let expected = checked_add(lit(1), lit(10)); + assert_eq!(&result, &expected); + } +} diff --git a/vortex-array/src/expr/exprs/get_item/mod.rs b/vortex-array/src/expr/exprs/get_item/mod.rs deleted file mode 100644 index 2eafa582e8f..00000000000 --- a/vortex-array/src/expr/exprs/get_item/mod.rs +++ /dev/null @@ -1,245 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -pub mod transform; - -use std::fmt::Formatter; -use std::ops::Not; - -use prost::Message; -use vortex_dtype::DType; -use vortex_dtype::FieldName; -use vortex_dtype::FieldPath; -use vortex_dtype::Nullability; -use vortex_error::VortexExpect; -use vortex_error::VortexResult; -use vortex_error::vortex_bail; -use vortex_error::vortex_err; -use vortex_proto::expr as pb; -use vortex_vector::Vector; -use vortex_vector::VectorOps; - -use crate::ArrayRef; -use crate::ToCanonical; -use crate::compute::mask; -use crate::expr::ChildName; -use crate::expr::ExecutionArgs; -use crate::expr::ExprId; -use crate::expr::Expression; -use crate::expr::ExpressionView; -use crate::expr::StatsCatalog; -use crate::expr::VTable; -use crate::expr::VTableExt; -use crate::expr::exprs::root::root; -use crate::expr::stats::Stat; - -pub struct GetItem; - -impl VTable for GetItem { - type Instance = FieldName; - - fn id(&self) -> ExprId { - ExprId::from("vortex.get_item") - } - - fn serialize(&self, instance: &Self::Instance) -> VortexResult>> { - Ok(Some( - pb::GetItemOpts { - path: instance.to_string(), - } - .encode_to_vec(), - )) - } - - fn deserialize(&self, metadata: &[u8]) -> VortexResult> { - let opts = pb::GetItemOpts::decode(metadata)?; - Ok(Some(FieldName::from(opts.path))) - } - - fn validate(&self, expr: &ExpressionView) -> VortexResult<()> { - if expr.children().len() != 1 { - vortex_bail!( - "GetItem expression requires exactly 1 child, got {}", - expr.children().len() - ); - } - Ok(()) - } - - fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName { - match child_idx { - 0 => ChildName::from("input"), - _ => unreachable!("Invalid child index {} for GetItem expression", child_idx), - } - } - - fn fmt_sql(&self, expr: &ExpressionView, f: &mut Formatter<'_>) -> std::fmt::Result { - expr.children()[0].fmt_sql(f)?; - write!(f, ".{}", expr.data()) - } - - fn fmt_data(&self, instance: &Self::Instance, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "\"{}\"", instance) - } - - fn return_dtype(&self, expr: &ExpressionView, scope: &DType) -> VortexResult { - let struct_dtype = expr.children()[0].return_dtype(scope)?; - let field_dtype = struct_dtype - .as_struct_fields_opt() - .and_then(|st| st.field(expr.data())) - .ok_or_else(|| { - vortex_err!("Couldn't find the {} field in the input scope", expr.data()) - })?; - - // Match here to avoid cloning the dtype if nullability doesn't need to change - if matches!( - (struct_dtype.nullability(), field_dtype.nullability()), - (Nullability::Nullable, Nullability::NonNullable) - ) { - return Ok(field_dtype.with_nullability(Nullability::Nullable)); - } - - Ok(field_dtype) - } - - fn evaluate(&self, expr: &ExpressionView, scope: &ArrayRef) -> VortexResult { - let input = expr.children()[0].evaluate(scope)?.to_struct(); - let field = input.field_by_name(expr.data()).cloned()?; - - match input.dtype().nullability() { - Nullability::NonNullable => Ok(field), - Nullability::Nullable => mask(&field, &input.validity_mask().not()), - } - } - - fn stat_expression( - &self, - expr: &ExpressionView, - stat: Stat, - catalog: &dyn StatsCatalog, - ) -> Option { - // TODO(ngates): I think we can do better here and support stats over nested fields. - // It would be nice if delegating to our child would return a struct of statistics - // matching the nested DType such that we can write: - // `get_item(expr.child(0).stat_expression(...), expr.data().field_name())` - - // TODO(ngates): this is a bug whereby we may return stats for a nested field of the same - // name as a field in the root struct. This should be resolved with upcoming change to - // falsify expressions, but for now I'm preserving the existing buggy behavior. - catalog.stats_ref(&FieldPath::from_name(expr.data().clone()), stat) - } - - fn execute(&self, field_name: &FieldName, mut args: ExecutionArgs) -> VortexResult { - let struct_dtype = args.dtypes[0] - .as_struct_fields_opt() - .ok_or_else(|| vortex_err!("Expected struct dtype for child of GetItem expression"))?; - let field_idx = struct_dtype - .find(field_name) - .ok_or_else(|| vortex_err!("Field {} not found in struct dtype", field_name))?; - - let struct_vector = args - .vectors - .pop() - .vortex_expect("missing input") - .into_struct(); - - // We must intersect the validity with that of the parent struct - let mut field = struct_vector.fields()[field_idx].clone(); - field.mask_validity(struct_vector.validity()); - - Ok(field) - } - - // This will apply struct nullability field. We could add a dtype?? - fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool { - true - } - - fn is_fallible(&self, _instance: &Self::Instance) -> bool { - // If this type-checks its infallible. - false - } -} - -/// Creates an expression that accesses a field from the root array. -/// -/// Equivalent to `get_item(field, root())` - extracts a named field from the input array. -/// -/// ```rust -/// # use vortex_array::expr::col; -/// let expr = col("name"); -/// ``` -pub fn col(field: impl Into) -> Expression { - GetItem.new_expr(field.into(), vec![root()]) -} - -/// Creates an expression that extracts a named field from a struct expression. -/// -/// Accesses the specified field from the result of the child expression. -/// -/// ```rust -/// # use vortex_array::expr::{get_item, root}; -/// let expr = get_item("user_id", root()); -/// ``` -pub fn get_item(field: impl Into, child: Expression) -> Expression { - GetItem.new_expr(field.into(), vec![child]) -} - -#[cfg(test)] -mod tests { - use vortex_buffer::buffer; - use vortex_dtype::DType; - use vortex_dtype::FieldNames; - use vortex_dtype::Nullability; - use vortex_dtype::PType::I32; - use vortex_scalar::Scalar; - - use super::get_item; - use crate::Array; - use crate::IntoArray; - use crate::arrays::StructArray; - use crate::expr::exprs::root::root; - use crate::validity::Validity; - - fn test_array() -> StructArray { - StructArray::from_fields(&[ - ("a", buffer![0i32, 1, 2].into_array()), - ("b", buffer![4i64, 5, 6].into_array()), - ]) - .unwrap() - } - - #[test] - fn get_item_by_name() { - let st = test_array(); - let get_item = get_item("a", root()); - let item = get_item.evaluate(&st.to_array()).unwrap(); - assert_eq!(item.dtype(), &DType::from(I32)) - } - - #[test] - fn get_item_by_name_none() { - let st = test_array(); - let get_item = get_item("c", root()); - assert!(get_item.evaluate(&st.to_array()).is_err()); - } - - #[test] - fn get_nullable_field() { - let st = StructArray::try_new( - FieldNames::from(["a"]), - vec![buffer![1i32].into_array()], - 1, - Validity::AllInvalid, - ) - .unwrap() - .to_array(); - - let get_item = get_item("a", root()); - let item = get_item.evaluate(&st).unwrap(); - assert_eq!( - item.scalar_at(0), - Scalar::null(DType::Primitive(I32, Nullability::Nullable)) - ); - } -} diff --git a/vortex-array/src/expr/exprs/get_item/transform.rs b/vortex-array/src/expr/exprs/get_item/transform.rs deleted file mode 100644 index 82c74f54095..00000000000 --- a/vortex-array/src/expr/exprs/get_item/transform.rs +++ /dev/null @@ -1,139 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use vortex_error::VortexResult; - -use crate::expr::Expression; -use crate::expr::ExpressionView; -use crate::expr::exprs::get_item::GetItem; -use crate::expr::exprs::pack::Pack; -use crate::expr::transform::rules::ReduceRule; -use crate::expr::transform::rules::RuleContext; - -/// Rewrite rule: `pack(l_1: e_1, ..., l_i: e_i, ..., l_n: e_n).get_item(l_i) = e_i` -/// -/// Simplifies accessing a field from a pack expression by directly returning the field's -/// expression instead of materializing the pack. -#[derive(Debug, Default)] -pub struct PackGetItemRule; - -impl ReduceRule for PackGetItemRule { - fn reduce( - &self, - get_item: &ExpressionView, - _ctx: &RuleContext, - ) -> VortexResult> { - if let Some(pack) = get_item.child(0).as_opt::() { - let field_expr = pack.field(get_item.data())?; - return Ok(Some(field_expr)); - } - - Ok(None) - } -} - -#[cfg(test)] -mod tests { - use vortex_dtype::DType; - use vortex_dtype::Nullability::NonNullable; - use vortex_dtype::PType; - - use super::PackGetItemRule; - use crate::expr::exprs::binary::checked_add; - use crate::expr::exprs::get_item::GetItem; - use crate::expr::exprs::get_item::get_item; - use crate::expr::exprs::literal::lit; - use crate::expr::exprs::pack::pack; - use crate::expr::session::ExprSession; - use crate::expr::transform::ExprOptimizer; - use crate::expr::transform::rules::ReduceRule; - use crate::expr::transform::rules::RuleContext; - - #[test] - fn test_pack_get_item_rule() { - // Create: pack(a: lit(1), b: lit(2)).get_item("b") - let pack_expr = pack([("a", lit(1)), ("b", lit(2))], NonNullable); - let get_item_expr = get_item("b", pack_expr); - - let get_item_view = get_item_expr.as_::(); - let result = PackGetItemRule - .reduce(&get_item_view, &RuleContext) - .unwrap(); - - assert!(result.is_some()); - assert_eq!(&result.unwrap(), &lit(2)); - } - - #[test] - fn test_pack_get_item_rule_no_match() { - // Create: get_item("x", lit(42)) - not a pack child - let lit_expr = lit(42); - let get_item_expr = get_item("x", lit_expr); - - let get_item_view = get_item_expr.as_::(); - let result = PackGetItemRule - .reduce(&get_item_view, &RuleContext) - .unwrap(); - - assert!(result.is_none()); - } - - #[test] - fn test_multi_level_pack_get_item_simplify() { - let inner_pack = pack([("a", lit(1)), ("b", lit(2))], NonNullable); - let get_a = get_item("a", inner_pack); - - let outer_pack = pack([("x", get_a), ("y", lit(3)), ("z", lit(4))], NonNullable); - let get_z = get_item("z", outer_pack); - - let dtype = DType::Primitive(PType::I32, NonNullable); - - let session = ExprSession::default(); - let optimizer = ExprOptimizer::new(&session); - let result = optimizer.optimize_typed(get_z, &dtype).unwrap(); - - assert_eq!(&result, &lit(4)); - } - - #[test] - fn test_deeply_nested_pack_get_item() { - let innermost = pack([("a", lit(42))], NonNullable); - let get_a = get_item("a", innermost); - - let level2 = pack([("b", get_a)], NonNullable); - let get_b = get_item("b", level2); - - let level3 = pack([("c", get_b)], NonNullable); - let get_c = get_item("c", level3); - - let outermost = pack([("final", get_c)], NonNullable); - let get_final = get_item("final", outermost); - - let dtype = DType::Primitive(PType::I32, NonNullable); - - let session = ExprSession::default(); - let optimizer = ExprOptimizer::new(&session); - let result = optimizer.optimize_typed(get_final, &dtype).unwrap(); - - assert_eq!(&result, &lit(42)); - } - - #[test] - fn test_partial_pack_get_item_simplify() { - let inner_pack = pack([("x", lit(1)), ("y", lit(2))], NonNullable); - let get_x = get_item("x", inner_pack); - let add_expr = checked_add(get_x, lit(10)); - - let outer_pack = pack([("result", add_expr)], NonNullable); - let get_result = get_item("result", outer_pack); - - let dtype = DType::Primitive(PType::I32, NonNullable); - - let session = ExprSession::default(); - let optimizer = ExprOptimizer::new(&session); - let result = optimizer.optimize_typed(get_result, &dtype).unwrap(); - - let expected = checked_add(lit(1), lit(10)); - assert_eq!(&result, &expected); - } -} diff --git a/vortex-array/src/expr/exprs/is_null.rs b/vortex-array/src/expr/exprs/is_null.rs index ba3dacd6f7b..8205179a498 100644 --- a/vortex-array/src/expr/exprs/is_null.rs +++ b/vortex-array/src/expr/exprs/is_null.rs @@ -4,15 +4,15 @@ use std::fmt::Formatter; use std::ops::Not; -use is_null::IsNullFn; use vortex_dtype::DType; use vortex_dtype::Nullability; use vortex_error::VortexExpect; use vortex_error::VortexResult; -use vortex_error::vortex_bail; use vortex_mask::Mask; -use vortex_vector::Vector; +use vortex_vector::Datum; +use vortex_vector::ScalarOps; use vortex_vector::VectorOps; +use vortex_vector::bool::BoolScalar; use vortex_vector::bool::BoolVector; use crate::Array; @@ -20,67 +20,69 @@ use crate::ArrayRef; use crate::IntoArray; use crate::arrays::BoolArray; use crate::arrays::ConstantArray; +use crate::expr::Arity; use crate::expr::ChildName; +use crate::expr::EmptyOptions; use crate::expr::ExecutionArgs; use crate::expr::ExprId; use crate::expr::Expression; -use crate::expr::ExpressionView; -use crate::expr::ScalarFnExprExt; use crate::expr::StatsCatalog; use crate::expr::VTable; use crate::expr::VTableExt; use crate::expr::exprs::binary::eq; use crate::expr::exprs::literal::lit; -use crate::expr::functions::EmptyOptions; use crate::expr::stats::Stat; -use crate::scalar_fns::is_null; /// Expression that checks for null values. pub struct IsNull; impl VTable for IsNull { - type Instance = (); + type Options = EmptyOptions; fn id(&self) -> ExprId { ExprId::new_ref("is_null") } - fn serialize(&self, _instance: &Self::Instance) -> VortexResult>> { + fn serialize(&self, _instance: &Self::Options) -> VortexResult>> { Ok(Some(vec![])) } - fn deserialize(&self, _metadata: &[u8]) -> VortexResult> { - Ok(Some(())) + fn deserialize(&self, _metadata: &[u8]) -> VortexResult { + Ok(EmptyOptions) } - fn validate(&self, expr: &ExpressionView) -> VortexResult<()> { - if expr.children().len() != 1 { - vortex_bail!( - "IsNull expression expects exactly one child, got {}", - expr.children().len() - ); - } - Ok(()) + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(1) } - fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName { + fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName { match child_idx { 0 => ChildName::from("input"), _ => unreachable!("Invalid child index {} for IsNull expression", child_idx), } } - fn fmt_sql(&self, expr: &ExpressionView, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt_sql( + &self, + _options: &Self::Options, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { write!(f, "is_null(")?; expr.child(0).fmt_sql(f)?; write!(f, ")") } - fn return_dtype(&self, _expr: &ExpressionView, _scope: &DType) -> VortexResult { + fn return_dtype(&self, _options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult { Ok(DType::Bool(Nullability::NonNullable)) } - fn evaluate(&self, expr: &ExpressionView, scope: &ArrayRef) -> VortexResult { + fn evaluate( + &self, + _options: &Self::Options, + expr: &Expression, + scope: &ArrayRef, + ) -> VortexResult { let array = expr.child(0).evaluate(scope)?; match array.validity_mask() { Mask::AllTrue(len) => Ok(ConstantArray::new(false, len).into_array()), @@ -89,35 +91,33 @@ impl VTable for IsNull { } } + fn execute(&self, _data: &Self::Options, mut args: ExecutionArgs) -> VortexResult { + let child = args.datums.pop().vortex_expect("Missing input child"); + Ok(match child { + Datum::Scalar(s) => Datum::Scalar(BoolScalar::new(Some(s.is_invalid())).into()), + Datum::Vector(v) => Datum::Vector( + BoolVector::new(v.validity().to_bit_buffer().not(), Mask::new_true(v.len())).into(), + ), + }) + } + fn stat_falsification( &self, - expr: &ExpressionView, + _options: &Self::Options, + expr: &Expression, catalog: &dyn StatsCatalog, ) -> Option { let null_count_expr = expr.child(0).stat_expression(Stat::NullCount, catalog)?; Some(eq(null_count_expr, lit(0u64))) } - fn execute(&self, _data: &Self::Instance, mut args: ExecutionArgs) -> VortexResult { - let child = args.vectors.pop().vortex_expect("Missing input child"); - Ok(BoolVector::new( - child.validity().to_bit_buffer().not(), - Mask::new_true(child.len()), - ) - .into()) - } - - fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool { + fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { true } - fn is_fallible(&self, _instance: &Self::Instance) -> bool { + fn is_fallible(&self, _instance: &Self::Options) -> bool { false } - - fn expr_v2(&self, view: &ExpressionView) -> VortexResult { - ScalarFnExprExt::try_new_expr(&IsNullFn, EmptyOptions, view.children().clone()) - } } /// Creates an expression that checks for null values. @@ -129,7 +129,7 @@ impl VTable for IsNull { /// let expr = is_null(root()); /// ``` pub fn is_null(child: Expression) -> Expression { - IsNull.new_expr((), vec![child]) + IsNull.new_expr(EmptyOptions, vec![child]) } #[cfg(test)] @@ -279,6 +279,6 @@ mod tests { #[test] fn test_is_null_sensitive() { // is_null itself is null-sensitive - assert!(is_null(col("a")).is_null_sensitive()); + assert!(is_null(col("a")).signature().is_null_sensitive()); } } diff --git a/vortex-array/src/expr/exprs/like.rs b/vortex-array/src/expr/exprs/like.rs index 65f6d823dce..09a0eddca25 100644 --- a/vortex-array/src/expr/exprs/like.rs +++ b/vortex-array/src/expr/exprs/like.rs @@ -4,18 +4,24 @@ use std::fmt::Formatter; use prost::Message; +use vortex_compute::arrow::IntoArrow; +use vortex_compute::arrow::IntoVector; use vortex_dtype::DType; use vortex_error::VortexResult; use vortex_error::vortex_bail; +use vortex_error::vortex_err; use vortex_proto::expr as pb; +use vortex_vector::Datum; +use vortex_vector::VectorOps; use crate::ArrayRef; use crate::compute::LikeOptions; use crate::compute::like as like_compute; +use crate::expr::Arity; use crate::expr::ChildName; +use crate::expr::ExecutionArgs; use crate::expr::ExprId; use crate::expr::Expression; -use crate::expr::ExpressionView; use crate::expr::VTable; use crate::expr::VTableExt; @@ -23,13 +29,13 @@ use crate::expr::VTableExt; pub struct Like; impl VTable for Like { - type Instance = LikeOptions; + type Options = LikeOptions; fn id(&self) -> ExprId { ExprId::from("vortex.like") } - fn serialize(&self, instance: &Self::Instance) -> VortexResult>> { + fn serialize(&self, instance: &Self::Options) -> VortexResult>> { Ok(Some( pb::LikeOpts { negated: instance.negated, @@ -39,25 +45,19 @@ impl VTable for Like { )) } - fn deserialize(&self, metadata: &[u8]) -> VortexResult> { + fn deserialize(&self, metadata: &[u8]) -> VortexResult { let opts = pb::LikeOpts::decode(metadata)?; - Ok(Some(LikeOptions { + Ok(LikeOptions { negated: opts.negated, case_insensitive: opts.case_insensitive, - })) + }) } - fn validate(&self, expr: &ExpressionView) -> VortexResult<()> { - if expr.children().len() != 2 { - vortex_bail!( - "Like expression requires exactly 2 children, got {}", - expr.children().len() - ); - } - Ok(()) + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(2) } - fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName { + fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName { match child_idx { 0 => ChildName::from("child"), 1 => ChildName::from("pattern"), @@ -65,12 +65,17 @@ impl VTable for Like { } } - fn fmt_sql(&self, expr: &ExpressionView, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt_sql( + &self, + options: &Self::Options, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { expr.child(0).fmt_sql(f)?; - if expr.data().negated { + if options.negated { write!(f, " not")?; } - if expr.data().case_insensitive { + if options.case_insensitive { write!(f, " ilike ")?; } else { write!(f, " like ")?; @@ -78,9 +83,9 @@ impl VTable for Like { expr.child(1).fmt_sql(f) } - fn return_dtype(&self, expr: &ExpressionView, scope: &DType) -> VortexResult { - let input = expr.children()[0].return_dtype(scope)?; - let pattern = expr.children()[1].return_dtype(scope)?; + fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { + let input = &arg_dtypes[0]; + let pattern = &arg_dtypes[1]; if !input.is_utf8() { vortex_bail!("LIKE expression requires UTF8 input dtype, got {}", input); @@ -97,13 +102,43 @@ impl VTable for Like { )) } - fn evaluate(&self, expr: &ExpressionView, scope: &ArrayRef) -> VortexResult { + fn evaluate( + &self, + options: &Self::Options, + expr: &Expression, + scope: &ArrayRef, + ) -> VortexResult { let child = expr.child(0).evaluate(scope)?; let pattern = expr.child(1).evaluate(scope)?; - like_compute(&child, &pattern, *expr.data()) + like_compute(&child, &pattern, *options) + } + + fn execute(&self, options: &Self::Options, args: ExecutionArgs) -> VortexResult { + let [child, pattern]: [Datum; _] = args + .datums + .try_into() + .map_err(|_| vortex_err!("Wrong argument count"))?; + + let child = child.into_arrow()?; + let pattern = pattern.into_arrow()?; + + let array = match (options.negated, options.case_insensitive) { + (false, false) => arrow_string::like::like(child.as_ref(), pattern.as_ref()), + (false, true) => arrow_string::like::ilike(child.as_ref(), pattern.as_ref()), + (true, false) => arrow_string::like::nlike(child.as_ref(), pattern.as_ref()), + (true, true) => arrow_string::like::nilike(child.as_ref(), pattern.as_ref()), + }?; + + let vector = array.into_vector()?; + if vector.len() == 1 && args.row_count != 1 { + // Arrow returns a scalar datum result + return Ok(Datum::Scalar(vector.scalar_at(0).into())); + } + + Ok(Datum::Vector(array.into_vector()?.into())) } - fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool { + fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { false } } diff --git a/vortex-array/src/expr/exprs/list_contains.rs b/vortex-array/src/expr/exprs/list_contains.rs index 8a4664b8e2a..f48607a0471 100644 --- a/vortex-array/src/expr/exprs/list_contains.rs +++ b/vortex-array/src/expr/exprs/list_contains.rs @@ -2,17 +2,40 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use std::fmt::Formatter; +use std::ops::BitOr; +use std::ops::Deref; +use arrow_buffer::bit_iterator::BitIndexIterator; +use vortex_buffer::BitBuffer; +use vortex_compute::logical::LogicalOr; use vortex_dtype::DType; +use vortex_dtype::IntegerPType; +use vortex_dtype::Nullability; +use vortex_dtype::PTypeDowncastExt; +use vortex_dtype::match_each_integer_ptype; +use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; +use vortex_error::vortex_err; +use vortex_mask::Mask; +use vortex_vector::BoolDatum; +use vortex_vector::Datum; +use vortex_vector::Vector; +use vortex_vector::VectorOps; +use vortex_vector::bool::BoolVector; +use vortex_vector::listview::ListViewScalar; +use vortex_vector::listview::ListViewVector; +use vortex_vector::primitive::PVector; use crate::ArrayRef; use crate::compute::list_contains as compute_list_contains; +use crate::expr::Arity; +use crate::expr::Binary; use crate::expr::ChildName; +use crate::expr::EmptyOptions; +use crate::expr::ExecutionArgs; use crate::expr::ExprId; use crate::expr::Expression; -use crate::expr::ExpressionView; use crate::expr::StatsCatalog; use crate::expr::VTable; use crate::expr::VTableExt; @@ -22,35 +45,30 @@ use crate::expr::exprs::binary::lt; use crate::expr::exprs::binary::or; use crate::expr::exprs::literal::Literal; use crate::expr::exprs::literal::lit; +use crate::expr::operators; pub struct ListContains; impl VTable for ListContains { - type Instance = (); + type Options = EmptyOptions; fn id(&self) -> ExprId { ExprId::from("vortex.list.contains") } - fn serialize(&self, _instance: &Self::Instance) -> VortexResult>> { + fn serialize(&self, _instance: &Self::Options) -> VortexResult>> { Ok(Some(vec![])) } - fn deserialize(&self, _metadata: &[u8]) -> VortexResult> { - Ok(Some(())) + fn deserialize(&self, _metadata: &[u8]) -> VortexResult { + Ok(EmptyOptions) } - fn validate(&self, expr: &ExpressionView) -> VortexResult<()> { - if expr.children().len() != 2 { - vortex_bail!( - "ListContains expression requires exactly 2 children, got {}", - expr.children().len() - ); - } - Ok(()) + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(2) } - fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName { + fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName { match child_idx { 0 => ChildName::from("list"), 1 => ChildName::from("needle"), @@ -60,8 +78,12 @@ impl VTable for ListContains { ), } } - - fn fmt_sql(&self, expr: &ExpressionView, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt_sql( + &self, + _options: &Self::Options, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { write!(f, "contains(")?; expr.child(0).fmt_sql(f)?; write!(f, ", ")?; @@ -69,9 +91,9 @@ impl VTable for ListContains { write!(f, ")") } - fn return_dtype(&self, expr: &ExpressionView, scope: &DType) -> VortexResult { - let list_dtype = expr.child(0).return_dtype(scope)?; - let value_dtype = expr.child(1).return_dtype(scope)?; + fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { + let list_dtype = &arg_dtypes[0]; + let needle_dtype = &arg_dtypes[0]; let nullability = match list_dtype { DType::List(_, list_nullability) => list_nullability, @@ -81,38 +103,76 @@ impl VTable for ListContains { list_dtype ); } - } | value_dtype.nullability(); + } + .bitor(needle_dtype.nullability()); Ok(DType::Bool(nullability)) } - fn evaluate(&self, expr: &ExpressionView, scope: &ArrayRef) -> VortexResult { + fn evaluate( + &self, + _options: &Self::Options, + expr: &Expression, + scope: &ArrayRef, + ) -> VortexResult { let list_array = expr.child(0).evaluate(scope)?; let value_array = expr.child(1).evaluate(scope)?; compute_list_contains(list_array.as_ref(), value_array.as_ref()) } + fn execute(&self, _options: &Self::Options, args: ExecutionArgs) -> VortexResult { + let [lhs, rhs]: [Datum; _] = args + .datums + .try_into() + .map_err(|_| vortex_err!("Wrong number of arguments for ListContains expression"))?; + + let matches = match (lhs.as_scalar().is_some(), rhs.as_scalar().is_some()) { + (true, true) => { + todo!("Implement ListContains for two scalars") + } + (true, false) => constant_list_scalar_contains( + lhs.into_scalar().vortex_expect("scalar").into_list(), + rhs.into_vector().vortex_expect("vector"), + ), + (false, true) => list_contains_scalar( + lhs.ensure_vector(args.row_count).into_list(), + rhs.into_scalar().vortex_expect("scalar").into_list(), + ), + (false, false) => { + vortex_bail!( + "ListContains currently only supports constant needle (RHS) or constant list (LHS)" + ) + } + }?; + + Ok(Datum::Vector(matches.into())) + } + fn stat_falsification( &self, - expr: &ExpressionView, + _options: &Self::Options, + expr: &Expression, catalog: &dyn StatsCatalog, ) -> Option { + let list = expr.child(0); + let needle = expr.child(1); + // falsification(contains([1,2,5], x)) => // falsification(x != 1) and falsification(x != 2) and falsification(x != 5) - let min = expr.list().stat_min(catalog)?; - let max = expr.list().stat_max(catalog)?; + let min = list.stat_min(catalog)?; + let max = list.stat_max(catalog)?; // If the list is constant when we can compare each element to the value if min == max { let list_ = min .as_opt::() - .and_then(|l| l.data().as_list_opt()) + .and_then(|l| l.as_list_opt()) .and_then(|l| l.elements())?; if list_.is_empty() { // contains([], x) is always false. return Some(lit(true)); } - let value_max = expr.needle().stat_max(catalog)?; - let value_min = expr.needle().stat_min(catalog)?; + let value_max = needle.stat_max(catalog)?; + let value_min = needle.stat_min(catalog)?; return list_ .iter() @@ -129,7 +189,7 @@ impl VTable for ListContains { } // Nullability matters for contains([], x) where x is false. - fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool { + fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { true } } @@ -143,17 +203,153 @@ impl VTable for ListContains { /// let expr = list_contains(root(), lit(42)); /// ``` pub fn list_contains(list: Expression, value: Expression) -> Expression { - ListContains.new_expr((), [list, value]) + ListContains.new_expr(EmptyOptions, [list, value]) } -impl ExpressionView<'_, ListContains> { - pub fn list(&self) -> &Expression { - &self.children()[0] +/// Returns a [`BoolVector`] where each bit represents if a list contains the scalar. +// FIXME(ngates): test implementation and move to vortex-compute +fn list_contains_scalar(list: ListViewVector, value: ListViewScalar) -> VortexResult { + // If the list array is constant, we perform a single comparison. + // if list.len() > 1 && list.is_constant() { + // let contains = list_contains_scalar(&array.slice(0..1), value, nullability)?; + // return Ok(ConstantArray::new(contains.scalar_at(0), array.len()).into_array()); + // } + + let elems = list.elements(); + if elems.is_empty() { + // Must return false when a list is empty (but valid), or null when the list itself is null. + // return crate::compute::list_contains::list_false_or_null(&list_array, nullability); + todo!() } - pub fn needle(&self) -> &Expression { - &self.children()[1] + let matches = Binary + .bind(operators::Operator::Eq) + .execute(ExecutionArgs { + datums: vec![ + Datum::Vector(elems.deref().clone()), + Datum::Scalar(value.into()), + ], + // FIXME(ngates): dtypes + dtypes: vec![], + row_count: elems.len(), + return_dtype: DType::Bool(Nullability::Nullable), + })? + .ensure_vector(elems.len()) + .into_bool() + .into_bits(); + + // // Fast path: no elements match. + // if let Some(pred) = matches.as_constant() { + // return match pred.as_bool().value() { + // // All comparisons are invalid (result in `null`), and search is not null because + // // we already checked for null above. + // None => { + // assert!( + // !rhs.scalar().is_null(), + // "Search value must not be null here" + // ); + // // False, unless the list itself is null in which case we return null. + // crate::compute::list_contains::list_false_or_null(&list_array, nullability) + // } + // // No elements match, and all comparisons are valid (result in `false`). + // Some(false) => { + // // False, but match the nullability to the input list array. + // Ok( + // ConstantArray::new(Scalar::bool(false, nullability), list_array.len()) + // .into_array(), + // ) + // } + // // All elements match, and all comparisons are valid (result in `true`). + // Some(true) => { + // // True, unless the list itself is empty or NULL. + // crate::compute::list_contains::list_is_not_empty(&list_array, nullability) + // } + // }; + // } + + // Get the offsets and sizes as primitive arrays. + let offsets = list.offsets(); + let sizes = list.sizes(); + + // Process based on the offset and size types. + let list_matches = match_each_integer_ptype!(offsets.ptype(), |O| { + match_each_integer_ptype!(sizes.ptype(), |S| { + process_matches::( + matches, + list.len(), + offsets.downcast::(), + sizes.downcast::(), + ) + }) + }); + + Ok(BoolVector::new(list_matches, list.validity().clone())) +} + +// Then there is a constant list scalar (haystack) being compared to an array of needles. +// FIXME(ngates): test implementation and move to vortex-compute +fn constant_list_scalar_contains(list: ListViewScalar, values: Vector) -> VortexResult { + let elements = list.value().elements(); + + // For each element in the list, we perform a full comparison over the values and OR + // the results together. + let mut result: BoolVector = BoolVector::new( + BitBuffer::new_unset(values.len()), + Mask::new(values.len(), true), + ); + for i in 0..elements.len() { + let element = Datum::Scalar(elements.scalar_at(i)); + let compared: BoolDatum = Binary + .bind(operators::Operator::Eq) + .execute(ExecutionArgs { + datums: vec![Datum::Vector(values.clone()), element], + dtypes: vec![ + // FIXME(ngates): call compute function directly! + ], + row_count: values.len(), + return_dtype: DType::Bool(Nullability::Nullable), + })? + .into_bool(); + let compared = Datum::from(compared) + .ensure_vector(values.len()) + .into_bool(); + + result = LogicalOr::or(result, &compared); } + + Ok(result) +} + +/// Returns a [`BitBuffer`] where each bit represents if a list contains the scalar, derived from a +/// [`BoolArray`] of matches on the child elements array. +/// +/// TODO(ngates): replace this for aggregation function. +fn process_matches( + matches: BitBuffer, + list_array_len: usize, + offsets: &PVector, + sizes: &PVector, +) -> BitBuffer +where + O: IntegerPType, + S: IntegerPType, +{ + let offsets_slice = offsets.elements().as_slice(); + let sizes_slice = sizes.elements().as_slice(); + + (0..list_array_len) + .map(|i| { + // TODO(ngates): does validity render this invalid? + let offset = offsets_slice[i].as_(); + let size = sizes_slice[i].as_(); + + // BitIndexIterator yields indices of true bits only. If `.next()` returns + // `Some(_)`, at least one element in this list's range matches. + let mut set_bits = + BitIndexIterator::new(matches.inner().as_slice(), matches.offset() + offset, size); + set_bits.next().is_some() + }) + .collect::() } #[cfg(test)] diff --git a/vortex-array/src/expr/exprs/literal.rs b/vortex-array/src/expr/exprs/literal.rs index 949871ca601..76480e67658 100644 --- a/vortex-array/src/expr/exprs/literal.rs +++ b/vortex-array/src/expr/exprs/literal.rs @@ -7,19 +7,20 @@ use prost::Message; use vortex_dtype::DType; use vortex_dtype::match_each_float_ptype; use vortex_error::VortexResult; -use vortex_error::vortex_bail; use vortex_error::vortex_err; use vortex_proto::expr as pb; use vortex_scalar::Scalar; +use vortex_vector::Datum; use crate::Array; use crate::ArrayRef; use crate::IntoArray; use crate::arrays::ConstantArray; +use crate::expr::Arity; use crate::expr::ChildName; +use crate::expr::ExecutionArgs; use crate::expr::ExprId; use crate::expr::Expression; -use crate::expr::ExpressionView; use crate::expr::StatsCatalog; use crate::expr::VTable; use crate::expr::VTableExt; @@ -29,13 +30,13 @@ use crate::expr::stats::Stat; pub struct Literal; impl VTable for Literal { - type Instance = Scalar; + type Options = Scalar; fn id(&self) -> ExprId { ExprId::new_ref("vortex.literal") } - fn serialize(&self, instance: &Self::Instance) -> VortexResult>> { + fn serialize(&self, instance: &Self::Options) -> VortexResult>> { Ok(Some( pb::LiteralOpts { value: Some(instance.as_ref().into()), @@ -44,49 +45,53 @@ impl VTable for Literal { )) } - fn deserialize(&self, metadata: &[u8]) -> VortexResult> { + fn deserialize(&self, metadata: &[u8]) -> VortexResult { let ops = pb::LiteralOpts::decode(metadata)?; - Ok(Some( - ops.value - .as_ref() - .ok_or_else(|| vortex_err!("Literal metadata missing value"))? - .try_into()?, - )) + ops.value + .as_ref() + .ok_or_else(|| vortex_err!("Literal metadata missing value"))? + .try_into() } - fn validate(&self, expr: &ExpressionView) -> VortexResult<()> { - if !expr.children().is_empty() { - vortex_bail!( - "Literal expression does not have children, got: {:?}", - expr.children() - ); - } - Ok(()) + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(0) } - fn child_name(&self, _instance: &Self::Instance, _child_idx: usize) -> ChildName { + fn child_name(&self, _instance: &Self::Options, _child_idx: usize) -> ChildName { unreachable!() } - fn fmt_sql(&self, expr: &ExpressionView, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", expr.data()) + fn fmt_sql( + &self, + scalar: &Scalar, + _expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { + write!(f, "{}", scalar) } - fn fmt_data(&self, instance: &Self::Instance, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", instance) + fn return_dtype(&self, options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult { + Ok(options.dtype().clone()) } - fn return_dtype(&self, expr: &ExpressionView, _scope: &DType) -> VortexResult { - Ok(expr.data().dtype().clone()) + fn evaluate( + &self, + scalar: &Scalar, + _expr: &Expression, + scope: &ArrayRef, + ) -> VortexResult { + Ok(ConstantArray::new(scalar.clone(), scope.len()).into_array()) } - fn evaluate(&self, expr: &ExpressionView, scope: &ArrayRef) -> VortexResult { - Ok(ConstantArray::new(expr.data().clone(), scope.len()).into_array()) + fn execute(&self, scalar: &Scalar, _args: ExecutionArgs) -> VortexResult { + let vector_scalar = scalar.to_vector_scalar(); + Ok(Datum::Scalar(vector_scalar)) } fn stat_expression( &self, - expr: &ExpressionView, + scalar: &Scalar, + _expr: &Expression, stat: Stat, _catalog: &dyn StatsCatalog, ) -> Option { @@ -96,12 +101,12 @@ impl VTable for Literal { // only currently used for pruning, it doesn't change the outcome. match stat { - Stat::Min | Stat::Max => Some(lit(expr.data().clone())), + Stat::Min | Stat::Max => Some(lit(scalar.clone())), Stat::IsConstant => Some(lit(true)), Stat::NaNCount => { // The NaNCount for a non-float literal is not defined. // For floating point types, the NaNCount is 1 for lit(NaN), and 0 otherwise. - let value = expr.data().as_primitive_opt()?; + let value = scalar.as_primitive_opt()?; if !value.ptype().is_float() { return None; } @@ -115,7 +120,7 @@ impl VTable for Literal { }) } Stat::NullCount => { - if expr.data().is_null() { + if scalar.is_null() { Some(lit(1u64)) } else { Some(lit(0u64)) @@ -127,11 +132,11 @@ impl VTable for Literal { } } - fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool { + fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { false } - fn is_fallible(&self, _instance: &Self::Instance) -> bool { + fn is_fallible(&self, _instance: &Self::Options) -> bool { false } } @@ -149,8 +154,8 @@ impl VTable for Literal { /// /// let number = lit(34i32); /// -/// let literal = number.as_::(); -/// assert_eq!(literal.data(), &Scalar::primitive(34i32, Nullability::NonNullable)); +/// let scalar = number.as_::(); +/// assert_eq!(scalar, &Scalar::primitive(34i32, Nullability::NonNullable)); /// ``` pub fn lit(value: impl Into) -> Expression { Literal.new_expr(value.into(), []) diff --git a/vortex-array/src/expr/exprs/mask.rs b/vortex-array/src/expr/exprs/mask.rs new file mode 100644 index 00000000000..61b0eb0dfe1 --- /dev/null +++ b/vortex-array/src/expr/exprs/mask.rs @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::fmt::Formatter; +use std::ops::Not; + +use vortex_dtype::DType; +use vortex_dtype::Nullability; +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; +use vortex_error::vortex_err; +use vortex_vector::BoolDatum; +use vortex_vector::Datum; +use vortex_vector::ScalarOps; +use vortex_vector::VectorMutOps; +use vortex_vector::VectorOps; + +use crate::Array; +use crate::ArrayRef; +use crate::expr::Arity; +use crate::expr::ChildName; +use crate::expr::EmptyOptions; +use crate::expr::ExecutionArgs; +use crate::expr::ExprId; +use crate::expr::Expression; +use crate::expr::VTable; +use crate::expr::VTableExt; + +/// An expression that masks an input based on a boolean mask. +/// +/// Where the mask is true, the input value is retained; where the mask is false, the output is +/// null. In other words, this performs an intersection of the input's validity with the mask. +pub struct Mask; + +impl VTable for Mask { + type Options = EmptyOptions; + + fn id(&self) -> ExprId { + ExprId::from("vortex.mask") + } + + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(2) + } + + fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName { + match child_idx { + 0 => ChildName::from("input"), + 1 => ChildName::from("mask"), + _ => unreachable!("Invalid child index {} for Mask expression", child_idx), + } + } + + fn fmt_sql( + &self, + _options: &Self::Options, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { + write!(f, "mask(")?; + expr.child(0).fmt_sql(f)?; + write!(f, ", ")?; + expr.child(1).fmt_sql(f)?; + write!(f, ")") + } + + fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { + vortex_ensure!( + arg_dtypes[1] == DType::Bool(Nullability::NonNullable), + "The mask argument to 'mask' must be a non-nullable boolean array, got {}", + arg_dtypes[1] + ); + Ok(arg_dtypes[0].as_nullable()) + } + + fn evaluate( + &self, + _options: &Self::Options, + expr: &Expression, + scope: &ArrayRef, + ) -> VortexResult { + let child = expr.child(0).evaluate(scope)?; + + // Invert the validity mask - we want to set values to null where validity is false. + let inverted_mask = child.validity_mask().not(); + + crate::compute::mask(&child, &inverted_mask) + } + + fn execute(&self, _options: &Self::Options, args: ExecutionArgs) -> VortexResult { + let [input, mask]: [Datum; _] = args + .datums + .try_into() + .map_err(|_| vortex_err!("Wrong arg count"))?; + let mask = mask.into_bool(); + + match (input, mask) { + (Datum::Scalar(input), BoolDatum::Scalar(mask)) => { + let mut result = input; + result.mask_validity(mask.value().vortex_expect("mask is non-nullable")); + Ok(Datum::Scalar(result)) + } + (Datum::Scalar(input), BoolDatum::Vector(mask)) => { + let mut result = input.repeat(args.row_count).freeze(); + result.mask_validity(&vortex_mask::Mask::from(mask.into_bits())); + Ok(Datum::Vector(result)) + } + (Datum::Vector(input_array), BoolDatum::Scalar(mask)) => { + let mut result = input_array; + result.mask_validity(&vortex_mask::Mask::new( + args.row_count, + mask.value().vortex_expect("mask is non-nullable"), + )); + Ok(Datum::Vector(result)) + } + (Datum::Vector(input_array), BoolDatum::Vector(mask)) => { + let mut result = input_array; + result.mask_validity(&vortex_mask::Mask::from(mask.into_bits())); + Ok(Datum::Vector(result)) + } + } + } +} + +/// Creates a mask expression that applies the given boolean mask to the input array. +pub fn mask(array: Expression, mask: Expression) -> Expression { + Mask.new_expr(EmptyOptions, [array, mask]) +} diff --git a/vortex-array/src/expr/exprs/merge/mod.rs b/vortex-array/src/expr/exprs/merge.rs similarity index 76% rename from vortex-array/src/expr/exprs/merge/mod.rs rename to vortex-array/src/expr/exprs/merge.rs index 3d2a2ec1ee5..f0a5372e3f2 100644 --- a/vortex-array/src/expr/exprs/merge/mod.rs +++ b/vortex-array/src/expr/exprs/merge.rs @@ -1,8 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -pub mod transform; - +use std::fmt::Display; use std::fmt::Formatter; use std::hash::Hash; use std::sync::Arc; @@ -12,21 +11,27 @@ use vortex_dtype::DType; use vortex_dtype::FieldNames; use vortex_dtype::Nullability; use vortex_dtype::StructFields; +use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_utils::aliases::hash_set::HashSet; +use vortex_vector::Datum; use crate::Array; use crate::ArrayRef; use crate::IntoArray as _; use crate::ToCanonical; use crate::arrays::StructArray; +use crate::expr::Arity; use crate::expr::ChildName; +use crate::expr::ExecutionArgs; use crate::expr::ExprId; use crate::expr::Expression; -use crate::expr::ExpressionView; +use crate::expr::SimplifyCtx; use crate::expr::VTable; use crate::expr::VTableExt; +use crate::expr::get_item; +use crate::expr::pack; use crate::validity::Validity; /// Merge zero or more expressions that ALL return structs. @@ -38,20 +43,20 @@ use crate::validity::Validity; pub struct Merge; impl VTable for Merge { - type Instance = DuplicateHandling; + type Options = DuplicateHandling; fn id(&self) -> ExprId { ExprId::new_ref("vortex.merge") } - fn serialize(&self, instance: &Self::Instance) -> VortexResult>> { + fn serialize(&self, instance: &Self::Options) -> VortexResult>> { Ok(Some(match instance { DuplicateHandling::RightMost => vec![0x00], DuplicateHandling::Error => vec![0x01], })) } - fn deserialize(&self, metadata: &[u8]) -> VortexResult> { + fn deserialize(&self, metadata: &[u8]) -> VortexResult { let instance = match metadata { [0x00] => DuplicateHandling::RightMost, [0x01] => DuplicateHandling::Error, @@ -59,18 +64,23 @@ impl VTable for Merge { vortex_bail!("invalid metadata for Merge expression"); } }; - Ok(Some(instance)) + Ok(instance) } - fn validate(&self, _expr: &ExpressionView) -> VortexResult<()> { - Ok(()) + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Variadic { min: 0, max: None } } - fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName { + fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName { ChildName::from(Arc::from(format!("{}", child_idx))) } - fn fmt_sql(&self, expr: &ExpressionView, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt_sql( + &self, + _options: &Self::Options, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { write!(f, "merge(")?; for (i, child) in expr.children().iter().enumerate() { child.fmt_sql(f)?; @@ -81,14 +91,13 @@ impl VTable for Merge { write!(f, ")") } - fn return_dtype(&self, expr: &ExpressionView, scope: &DType) -> VortexResult { + fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { let mut field_names = Vec::new(); let mut arrays = Vec::new(); let mut merge_nullability = Nullability::NonNullable; let mut duplicate_names = HashSet::<_>::new(); - for child in expr.children().iter() { - let dtype = child.return_dtype(scope)?; + for dtype in arg_dtypes { let Some(fields) = dtype.as_struct_fields_opt() else { vortex_bail!("merge expects struct input"); }; @@ -109,7 +118,7 @@ impl VTable for Merge { } } - if expr.data() == &DuplicateHandling::Error && !duplicate_names.is_empty() { + if options == &DuplicateHandling::Error && !duplicate_names.is_empty() { vortex_bail!( "merge: duplicate fields in children: {}", duplicate_names.into_iter().format(", ") @@ -122,7 +131,12 @@ impl VTable for Merge { )) } - fn evaluate(&self, expr: &ExpressionView, scope: &ArrayRef) -> VortexResult { + fn evaluate( + &self, + options: &Self::Options, + expr: &Expression, + scope: &ArrayRef, + ) -> VortexResult { // Collect fields in order of appearance. Later fields overwrite earlier fields. let mut field_names = Vec::new(); let mut arrays = Vec::new(); @@ -151,7 +165,7 @@ impl VTable for Merge { } } - if expr.data() == &DuplicateHandling::Error && !duplicate_names.is_empty() { + if options == &DuplicateHandling::Error && !duplicate_names.is_empty() { vortex_bail!( "merge: duplicate fields in children: {}", duplicate_names.into_iter().format(", ") @@ -167,11 +181,68 @@ impl VTable for Merge { ) } - fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool { + fn execute(&self, _data: &Self::Options, _args: ExecutionArgs) -> VortexResult { + todo!() + } + + fn simplify( + &self, + options: &Self::Options, + expr: &Expression, + ctx: &dyn SimplifyCtx, + ) -> VortexResult> { + let merge_dtype = ctx.return_dtype(expr)?; + let mut names = Vec::with_capacity(expr.children().len() * 2); + let mut children = Vec::with_capacity(expr.children().len() * 2); + let mut duplicate_names = HashSet::<_>::new(); + + for child in expr.children().iter() { + let child_dtype = ctx.return_dtype(child)?; + if !child_dtype.is_struct() { + vortex_bail!( + "Merge child must return a non-nullable struct dtype, got {}", + child_dtype + ) + } + + let child_dtype = child_dtype + .as_struct_fields_opt() + .vortex_expect("expected struct"); + + for name in child_dtype.names().iter() { + if let Some(idx) = names.iter().position(|n| n == name) { + duplicate_names.insert(name.clone()); + children[idx] = child.clone(); + } else { + names.push(name.clone()); + children.push(child.clone()); + } + } + + if options == &DuplicateHandling::Error && !duplicate_names.is_empty() { + vortex_bail!( + "merge: duplicate fields in children: {}", + duplicate_names.into_iter().format(", ") + ) + } + } + + let expr = pack( + names + .into_iter() + .zip(children) + .map(|(name, child)| (name.clone(), get_item(name, child))), + merge_dtype.nullability(), + ); + + Ok(Some(expr)) + } + + fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { true } - fn is_fallible(&self, instance: &Self::Instance) -> bool { + fn is_fallible(&self, instance: &Self::Options) -> bool { matches!(instance, DuplicateHandling::Error) } } @@ -186,6 +257,15 @@ pub enum DuplicateHandling { Error, } +impl Display for DuplicateHandling { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + DuplicateHandling::RightMost => write!(f, "RightMost"), + DuplicateHandling::Error => write!(f, "Error"), + } + } +} + /// Creates an expression that merges struct expressions into a single struct. /// /// Combines fields from all input expressions. If field names are duplicated, @@ -212,6 +292,12 @@ pub fn merge_opts( #[cfg(test)] mod tests { use vortex_buffer::buffer; + use vortex_dtype::DType; + use vortex_dtype::Nullability::NonNullable; + use vortex_dtype::PType::I32; + use vortex_dtype::PType::I64; + use vortex_dtype::PType::U32; + use vortex_dtype::PType::U64; use vortex_error::VortexResult; use vortex_error::vortex_bail; @@ -222,6 +308,7 @@ mod tests { use crate::arrays::PrimitiveArray; use crate::arrays::StructArray; use crate::expr::Expression; + use crate::expr::Pack; use crate::expr::exprs::get_item::get_item; use crate::expr::exprs::merge::DuplicateHandling; use crate::expr::exprs::merge::merge_opts; @@ -472,4 +559,28 @@ mod tests { let expr2 = merge(vec![get_item("a", root())]); assert_eq!(expr2.to_string(), "merge($.a)"); } + + #[test] + fn test_remove_merge() { + let dtype = DType::struct_( + [ + ("0", DType::struct_([("a", I32), ("b", I64)], NonNullable)), + ("1", DType::struct_([("b", U32), ("c", U64)], NonNullable)), + ], + NonNullable, + ); + + let e = merge_opts( + [get_item("0", root()), get_item("1", root())], + DuplicateHandling::RightMost, + ); + + let result = e.simplify(&dtype).unwrap(); + + assert!(result.is::()); + assert_eq!( + result.return_dtype(&dtype).unwrap(), + DType::struct_([("a", I32), ("b", U32), ("c", U64)], NonNullable) + ); + } } diff --git a/vortex-array/src/expr/exprs/merge/transform.rs b/vortex-array/src/expr/exprs/merge/transform.rs deleted file mode 100644 index 55d7bee73d6..00000000000 --- a/vortex-array/src/expr/exprs/merge/transform.rs +++ /dev/null @@ -1,125 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use itertools::Itertools as _; -use vortex_error::VortexExpect; -use vortex_error::VortexResult; -use vortex_error::vortex_bail; -use vortex_utils::aliases::hash_set::HashSet; - -use crate::expr::Expression; -use crate::expr::ExpressionView; -use crate::expr::exprs::get_item::get_item; -use crate::expr::exprs::merge::DuplicateHandling; -use crate::expr::exprs::merge::Merge; -use crate::expr::exprs::pack::pack; -use crate::expr::transform::rules::ReduceRule; -use crate::expr::transform::rules::TypedRuleContext; - -/// Rule that removes Merge expressions by converting them to Pack + GetItem. -/// -/// Transforms: `merge([struct1, struct2])` → `pack(field1: get_item("field1", struct1), field2: get_item("field2", struct2), ...)` -#[derive(Debug, Default)] -pub struct RemoveMergeRule; - -impl ReduceRule for RemoveMergeRule { - fn reduce( - &self, - merge: &ExpressionView, - ctx: &TypedRuleContext, - ) -> VortexResult> { - let merge_dtype = merge.return_dtype(ctx.dtype())?; - let mut names = Vec::with_capacity(merge.children().len() * 2); - let mut children = Vec::with_capacity(merge.children().len() * 2); - let mut duplicate_names = HashSet::<_>::new(); - - for child in merge.children().iter() { - let child_dtype = child.return_dtype(ctx.dtype())?; - if !child_dtype.is_struct() { - vortex_bail!( - "Merge child must return a non-nullable struct dtype, got {}", - child_dtype - ) - } - - let child_dtype = child_dtype - .as_struct_fields_opt() - .vortex_expect("expected struct"); - - for name in child_dtype.names().iter() { - if let Some(idx) = names.iter().position(|n| n == name) { - duplicate_names.insert(name.clone()); - children[idx] = child.clone(); - } else { - names.push(name.clone()); - children.push(child.clone()); - } - } - - if merge.data() == &DuplicateHandling::Error && !duplicate_names.is_empty() { - vortex_bail!( - "merge: duplicate fields in children: {}", - duplicate_names.into_iter().format(", ") - ) - } - } - - let expr = pack( - names - .into_iter() - .zip(children) - .map(|(name, child)| (name.clone(), get_item(name, child))), - merge_dtype.nullability(), - ); - - Ok(Some(expr)) - } -} - -#[cfg(test)] -mod tests { - use vortex_dtype::DType; - use vortex_dtype::Nullability::NonNullable; - use vortex_dtype::PType::I32; - use vortex_dtype::PType::I64; - use vortex_dtype::PType::U32; - use vortex_dtype::PType::U64; - - use super::RemoveMergeRule; - use crate::expr::exprs::get_item::get_item; - use crate::expr::exprs::merge::DuplicateHandling; - use crate::expr::exprs::merge::Merge; - use crate::expr::exprs::merge::merge_opts; - use crate::expr::exprs::pack::Pack; - use crate::expr::exprs::root::root; - use crate::expr::transform::rules::ReduceRule; - use crate::expr::transform::rules::TypedRuleContext; - - #[test] - fn test_remove_merge() { - let dtype = DType::struct_( - [ - ("0", DType::struct_([("a", I32), ("b", I64)], NonNullable)), - ("1", DType::struct_([("b", U32), ("c", U64)], NonNullable)), - ], - NonNullable, - ); - - let e = merge_opts( - [get_item("0", root()), get_item("1", root())], - DuplicateHandling::RightMost, - ); - - let ctx = TypedRuleContext::new(dtype.clone()); - let merge_view = e.as_::(); - let result = RemoveMergeRule.reduce(&merge_view, &ctx).unwrap(); - - assert!(result.is_some()); - let result = result.unwrap(); - assert!(result.is::()); - assert_eq!( - result.return_dtype(&dtype).unwrap(), - DType::struct_([("a", I32), ("b", U32), ("c", U64)], NonNullable) - ); - } -} diff --git a/vortex-array/src/expr/exprs/mod.rs b/vortex-array/src/expr/exprs/mod.rs index 7965311aae2..c606b53f5a0 100644 --- a/vortex-array/src/expr/exprs/mod.rs +++ b/vortex-array/src/expr/exprs/mod.rs @@ -10,12 +10,12 @@ pub(crate) mod is_null; pub(crate) mod like; pub(crate) mod list_contains; pub(crate) mod literal; +pub(crate) mod mask; pub(crate) mod merge; pub(crate) mod not; pub(crate) mod operators; pub(crate) mod pack; pub(crate) mod root; -pub(crate) mod scalar_fn; pub(crate) mod select; pub use between::*; @@ -27,10 +27,10 @@ pub use is_null::*; pub use like::*; pub use list_contains::*; pub use literal::*; +pub use mask::*; pub use merge::*; pub use not::*; pub use operators::*; pub use pack::*; pub use root::*; -pub use scalar_fn::*; pub use select::*; diff --git a/vortex-array/src/expr/exprs/not.rs b/vortex-array/src/expr/exprs/not.rs index 952384e14d8..5185af58c97 100644 --- a/vortex-array/src/expr/exprs/not.rs +++ b/vortex-array/src/expr/exprs/not.rs @@ -8,94 +8,92 @@ use vortex_dtype::DType; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; -use vortex_vector::Vector; +use vortex_vector::Datum; use crate::ArrayRef; use crate::compute::invert; +use crate::expr::Arity; use crate::expr::ChildName; +use crate::expr::EmptyOptions; use crate::expr::ExecutionArgs; use crate::expr::ExprId; use crate::expr::Expression; -use crate::expr::ExpressionView; -use crate::expr::ScalarFnExprExt; use crate::expr::VTable; use crate::expr::VTableExt; -use crate::expr::functions::EmptyOptions; -use crate::scalar_fns::not; /// Expression that logically inverts boolean values. pub struct Not; impl VTable for Not { - type Instance = (); + type Options = EmptyOptions; fn id(&self) -> ExprId { - ExprId::new_ref("vortex.not") + ExprId::from("vortex.not") } - fn serialize(&self, _instance: &Self::Instance) -> VortexResult>> { + fn serialize(&self, _options: &Self::Options) -> VortexResult>> { Ok(Some(vec![])) } - fn deserialize(&self, _metadata: &[u8]) -> VortexResult> { - Ok(Some(())) + fn deserialize(&self, _metadata: &[u8]) -> VortexResult { + Ok(EmptyOptions) } - fn validate(&self, expr: &ExpressionView) -> VortexResult<()> { - if expr.children().len() != 1 { - vortex_bail!( - "Not expression expects exactly one child, got {}", - expr.children().len() - ); - } - Ok(()) + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(1) } - fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName { + fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName { match child_idx { 0 => ChildName::from("input"), _ => unreachable!("Invalid child index {} for Not expression", child_idx), } } - fn fmt_sql(&self, expr: &ExpressionView, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt_sql( + &self, + _options: &Self::Options, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { write!(f, "not(")?; expr.child(0).fmt_sql(f)?; write!(f, ")") } - fn return_dtype(&self, expr: &ExpressionView, scope: &DType) -> VortexResult { - let child_dtype = expr.child(0).return_dtype(scope)?; + fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { + let child_dtype = &arg_dtypes[0]; if !matches!(child_dtype, DType::Bool(_)) { vortex_bail!( "Not expression expects a boolean child, got: {}", child_dtype ); } - Ok(child_dtype) + Ok(child_dtype.clone()) } - fn evaluate(&self, expr: &ExpressionView, scope: &ArrayRef) -> VortexResult { + fn evaluate( + &self, + _options: &Self::Options, + expr: &Expression, + scope: &ArrayRef, + ) -> VortexResult { let child_result = expr.child(0).evaluate(scope)?; invert(&child_result) } - fn execute(&self, _data: &Self::Instance, mut args: ExecutionArgs) -> VortexResult { - let child = args.vectors.pop().vortex_expect("Missing input child"); + fn execute(&self, _data: &Self::Options, mut args: ExecutionArgs) -> VortexResult { + let child = args.datums.pop().vortex_expect("Missing input child"); Ok(child.into_bool().not().into()) } - fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool { + fn is_null_sensitive(&self, _options: &Self::Options) -> bool { false } - fn is_fallible(&self, _instance: &Self::Instance) -> bool { + fn is_fallible(&self, _options: &Self::Options) -> bool { false } - - fn expr_v2(&self, view: &ExpressionView) -> VortexResult { - ScalarFnExprExt::try_new_expr(¬::NotFn, EmptyOptions, view.children().clone()) - } } /// Creates an expression that logically inverts boolean values. @@ -107,7 +105,7 @@ impl VTable for Not { /// let expr = not(root()); /// ``` pub fn not(operand: Expression) -> Expression { - Not.new_expr((), vec![operand]) + Not.new_expr(EmptyOptions, vec![operand]) } #[cfg(test)] diff --git a/vortex-array/src/expr/exprs/pack.rs b/vortex-array/src/expr/exprs/pack/mod.rs similarity index 73% rename from vortex-array/src/expr/exprs/pack.rs rename to vortex-array/src/expr/exprs/pack/mod.rs index a03136ca6f4..d7f5725fca3 100644 --- a/vortex-array/src/expr/exprs/pack.rs +++ b/vortex-array/src/expr/exprs/pack/mod.rs @@ -1,8 +1,12 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +pub(crate) mod rules; + +use std::fmt::Display; use std::fmt::Formatter; use std::hash::Hash; +use std::sync::Arc; use itertools::Itertools as _; use prost::Message; @@ -11,21 +15,27 @@ use vortex_dtype::FieldName; use vortex_dtype::FieldNames; use vortex_dtype::Nullability; use vortex_dtype::StructFields; +use vortex_error::VortexExpect; use vortex_error::VortexResult; -use vortex_error::vortex_bail; -use vortex_error::vortex_err; +use vortex_mask::Mask; use vortex_proto::expr as pb; +use vortex_vector::struct_::StructVector; +use vortex_vector::Datum; +use vortex_vector::ScalarOps; +use vortex_vector::VectorMutOps; +use vortex_vector::VectorOps; -use crate::ArrayRef; -use crate::IntoArray; use crate::arrays::StructArray; +use crate::expr::Arity; use crate::expr::ChildName; +use crate::expr::ExecutionArgs; use crate::expr::ExprId; use crate::expr::Expression; -use crate::expr::ExpressionView; use crate::expr::VTable; use crate::expr::VTableExt; use crate::validity::Validity; +use crate::ArrayRef; +use crate::IntoArray; /// Pack zero or more expressions into a structure with named fields. pub struct Pack; @@ -36,14 +46,25 @@ pub struct PackOptions { pub nullability: Nullability, } +impl Display for PackOptions { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "names: [{}], nullability: {}", + self.names.iter().join(", "), + self.nullability + ) + } +} + impl VTable for Pack { - type Instance = PackOptions; + type Options = PackOptions; fn id(&self) -> ExprId { ExprId::new_ref("vortex.pack") } - fn serialize(&self, instance: &Self::Instance) -> VortexResult>> { + fn serialize(&self, instance: &Self::Options) -> VortexResult>> { Ok(Some( pb::PackOpts { paths: instance.names.iter().map(|n| n.to_string()).collect(), @@ -53,32 +74,24 @@ impl VTable for Pack { )) } - fn deserialize(&self, metadata: &[u8]) -> VortexResult> { + fn deserialize(&self, metadata: &[u8]) -> VortexResult { let opts = pb::PackOpts::decode(metadata)?; let names: FieldNames = opts .paths .iter() .map(|name| FieldName::from(name.as_str())) .collect(); - Ok(Some(PackOptions { + Ok(PackOptions { names, nullability: opts.nullable.into(), - })) + }) } - fn validate(&self, expr: &ExpressionView) -> VortexResult<()> { - let instance = expr.data(); - if expr.children().len() != instance.names.len() { - vortex_bail!( - "Pack expression expects {} children, got {}", - instance.names.len(), - expr.children().len() - ); - } - Ok(()) + fn arity(&self, options: &Self::Options) -> Arity { + Arity::Exact(options.names.len()) } - fn child_name(&self, instance: &Self::Instance, child_idx: usize) -> ChildName { + fn child_name(&self, instance: &Self::Options, child_idx: usize) -> ChildName { match instance.names.get(child_idx) { Some(name) => ChildName::from(name.inner().clone()), None => unreachable!( @@ -89,87 +102,92 @@ impl VTable for Pack { } } - fn fmt_sql(&self, expr: &ExpressionView, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt_sql( + &self, + options: &Self::Options, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { write!(f, "pack(")?; - for (i, (name, child)) in expr - .data() - .names - .iter() - .zip(expr.children().iter()) - .enumerate() - { + for (i, (name, child)) in options.names.iter().zip(expr.children().iter()).enumerate() { write!(f, "{}: ", name)?; child.fmt_sql(f)?; - if i + 1 < expr.data().names.len() { + if i + 1 < options.names.len() { write!(f, ", ")?; } } - write!(f, "){}", expr.data().nullability) + write!(f, "){}", options.nullability) } - fn return_dtype(&self, expr: &ExpressionView, scope: &DType) -> VortexResult { - let value_dtypes = expr - .children() - .iter() - .map(|child| child.return_dtype(scope)) - .collect::>>()?; + fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { Ok(DType::Struct( - StructFields::new(expr.data().names.clone(), value_dtypes), - expr.data().nullability, + StructFields::new(options.names.clone(), arg_dtypes.to_vec()), + options.nullability, )) } - fn evaluate(&self, expr: &ExpressionView, scope: &ArrayRef) -> VortexResult { + fn evaluate( + &self, + options: &Self::Options, + expr: &Expression, + scope: &ArrayRef, + ) -> VortexResult { let len = scope.len(); let value_arrays = expr .children() .iter() - .zip_eq(expr.data().names.iter()) + .zip_eq(options.names.iter()) .map(|(child_expr, name)| { child_expr .evaluate(scope) .map_err(|e| e.with_context(format!("Can't evaluate '{name}'"))) }) .process_results(|it| it.collect::>())?; - let validity = match expr.data().nullability { + let validity = match options.nullability { Nullability::NonNullable => Validity::NonNullable, Nullability::Nullable => Validity::AllValid, }; - Ok( - StructArray::try_new(expr.data().names.clone(), value_arrays, len, validity)? - .into_array(), - ) + Ok(StructArray::try_new(options.names.clone(), value_arrays, len, validity)?.into_array()) + } + + fn execute(&self, _options: &Self::Options, args: ExecutionArgs) -> VortexResult { + // If any datum is a vector, we must convert them all to vectors. + if args.datums.iter().any(|d| matches!(d, Datum::Vector(_))) { + let fields: Box<[_]> = args + .datums + .into_iter() + .map(|v| v.ensure_vector(args.row_count)) + .collect(); + return Ok(Datum::Vector( + StructVector::try_new(Arc::new(fields), Mask::new_true(args.row_count))?.into(), + )); + } + + // Otherwise, we can produce a scalar datum by constructing a length-1 struct vector. + let fields: Box<[_]> = args + .datums + .into_iter() + .map(|d| { + d.into_scalar() + .vortex_expect("all scalars") + .repeat(1) + .freeze() + }) + .collect(); + let vector = StructVector::new(Arc::new(fields), Mask::new_true(1)); + Ok(Datum::Scalar(vector.scalar_at(0).into())) } // This applies a nullability - fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool { + fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { true } - fn is_fallible(&self, _instance: &Self::Instance) -> bool { + fn is_fallible(&self, _instance: &Self::Options) -> bool { false } } -impl ExpressionView<'_, Pack> { - pub fn field(&self, field_name: &FieldName) -> VortexResult { - let idx = self - .data() - .names - .iter() - .position(|name| name == field_name) - .ok_or_else(|| { - vortex_err!( - "Cannot find field {} in pack fields {:?}", - field_name, - self.data().names - ) - })?; - - Ok(self.child(idx).clone()) - } -} - /// Creates an expression that packs values into a struct with named fields. /// /// ```rust @@ -198,22 +216,22 @@ pub fn pack( mod tests { use vortex_buffer::buffer; use vortex_dtype::Nullability; - use vortex_error::VortexResult; use vortex_error::vortex_bail; + use vortex_error::VortexResult; + use super::pack; use super::Pack; use super::PackOptions; - use super::pack; - use crate::Array; - use crate::ArrayRef; - use crate::IntoArray; - use crate::ToCanonical; use crate::arrays::PrimitiveArray; use crate::arrays::StructArray; - use crate::expr::VTableExt; use crate::expr::exprs::get_item::col; + use crate::expr::VTableExt; use crate::validity::Validity; use crate::vtable::ValidityHelper; + use crate::Array; + use crate::ArrayRef; + use crate::IntoArray; + use crate::ToCanonical; fn test_array() -> ArrayRef { StructArray::from_fields(&[ diff --git a/vortex-array/src/expr/exprs/pack/rules.rs b/vortex-array/src/expr/exprs/pack/rules.rs new file mode 100644 index 00000000000..d07904594ed --- /dev/null +++ b/vortex-array/src/expr/exprs/pack/rules.rs @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_error::VortexResult; + +use crate::arrays::ExactScalarFn; +use crate::arrays::FilterArray; +use crate::arrays::FilterVTable; +use crate::arrays::ScalarFnArrayExt; +use crate::arrays::ScalarFnArrayView; +use crate::expr::Pack; +use crate::optimizer::rules::ArrayParentReduceRule; +use crate::optimizer::rules::Exact; +use crate::ArrayRef; +use crate::ArrayVisitor; +use crate::IntoArray; + +/// Pack expression should always push-down filter, regardless of cost. +#[derive(Debug)] +pub(crate) struct PackFilterPushdown; + +impl ArrayParentReduceRule, Exact> for PackFilterPushdown { + fn child(&self) -> ExactScalarFn { + ExactScalarFn::from(&Pack) + } + + fn parent(&self) -> Exact { + Exact::from(&FilterVTable) + } + + fn reduce_parent( + &self, + child: ScalarFnArrayView, + parent: &FilterArray, + _child_idx: usize, + ) -> VortexResult> { + let new_children: Vec<_> = child + .children() + .into_iter() + .map(|child| FilterArray::new(child, parent.mask().clone()).into_array()) + .collect(); + Ok(Some( + Pack.try_new_array(parent.len(), child.options.clone(), new_children)? + .into_array(), + )) + } +} diff --git a/vortex-array/src/expr/exprs/root.rs b/vortex-array/src/expr/exprs/root.rs index d3af4e1e089..2433ada7ea0 100644 --- a/vortex-array/src/expr/exprs/root.rs +++ b/vortex-array/src/expr/exprs/root.rs @@ -8,13 +8,14 @@ use vortex_dtype::FieldPath; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; -use vortex_vector::Vector; +use vortex_vector::Datum; use crate::ArrayRef; +use crate::expr::Arity; use crate::expr::ChildName; +use crate::expr::EmptyOptions; use crate::expr::ExecutionArgs; use crate::expr::ExprId; -use crate::expr::ExpressionView; use crate::expr::StatsCatalog; use crate::expr::VTable; use crate::expr::VTableExt; @@ -26,67 +27,72 @@ use crate::expr::stats::Stat; pub struct Root; impl VTable for Root { - type Instance = (); + type Options = EmptyOptions; fn id(&self) -> ExprId { ExprId::from("vortex.root") } - fn serialize(&self, _instance: &Self::Instance) -> VortexResult>> { + fn serialize(&self, _instance: &Self::Options) -> VortexResult>> { Ok(Some(vec![])) } - fn deserialize(&self, _metadata: &[u8]) -> VortexResult> { - Ok(Some(())) + fn deserialize(&self, _metadata: &[u8]) -> VortexResult { + Ok(EmptyOptions) } - fn validate(&self, expr: &ExpressionView) -> VortexResult<()> { - if !expr.children().is_empty() { - vortex_bail!( - "Root expression does not have children, got {}", - expr.children().len() - ); - } - Ok(()) + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(0) } - fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName { + fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName { unreachable!( "Root expression does not have children, got index {}", child_idx ) } - fn fmt_sql(&self, _expr: &ExpressionView, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt_sql( + &self, + _options: &Self::Options, + _expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { write!(f, "$") } - fn return_dtype(&self, _expr: &ExpressionView, scope: &DType) -> VortexResult { - Ok(scope.clone()) + fn return_dtype(&self, _options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult { + vortex_bail!("Root expression does not support return_dtype") } - fn evaluate(&self, _expr: &ExpressionView, scope: &ArrayRef) -> VortexResult { + fn evaluate( + &self, + _options: &Self::Options, + _expr: &Expression, + scope: &ArrayRef, + ) -> VortexResult { Ok(scope.clone()) } - fn execute(&self, _data: &Self::Instance, _args: ExecutionArgs) -> VortexResult { + fn execute(&self, _data: &Self::Options, _args: ExecutionArgs) -> VortexResult { vortex_bail!("Root expression is not executable") } fn stat_expression( &self, - _expr: &ExpressionView, + _options: &Self::Options, + _expr: &Expression, stat: Stat, catalog: &dyn StatsCatalog, ) -> Option { catalog.stats_ref(&FieldPath::root(), stat) } - fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool { + fn is_null_sensitive(&self, _options: &Self::Options) -> bool { false } - fn is_fallible(&self, _instance: &Self::Instance) -> bool { + fn is_fallible(&self, _options: &Self::Options) -> bool { false } } @@ -96,7 +102,7 @@ impl VTable for Root { /// Returns the entire input array as passed to the expression evaluator. /// This is commonly used as the starting point for field access and other operations. pub fn root() -> Expression { - Root.try_new_expr((), vec![]) + Root.try_new_expr(EmptyOptions, vec![]) .vortex_expect("Failed to create Root expression") } diff --git a/vortex-array/src/expr/exprs/scalar_fn.rs b/vortex-array/src/expr/exprs/scalar_fn.rs deleted file mode 100644 index 20d898c8530..00000000000 --- a/vortex-array/src/expr/exprs/scalar_fn.rs +++ /dev/null @@ -1,185 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use std::fmt::Debug; -use std::fmt::Formatter; -use std::marker::PhantomData; -use std::sync::Arc; - -use itertools::Itertools; -use vortex_dtype::DType; -use vortex_error::VortexResult; -use vortex_error::vortex_ensure; -use vortex_session::SessionVar; -use vortex_vector::Datum; -use vortex_vector::ScalarOps; -use vortex_vector::Vector; -use vortex_vector::VectorMutOps; - -use crate::ArrayRef; -use crate::IntoArray; -use crate::arrays::ScalarFnArray; -use crate::expr::ChildName; -use crate::expr::ExecutionArgs; -use crate::expr::ExprId; -use crate::expr::Expression; -use crate::expr::ExpressionView; -use crate::expr::StatsCatalog; -use crate::expr::VTable; -use crate::expr::functions; -use crate::expr::functions::ScalarFnVTable; -use crate::expr::functions::scalar::ScalarFn; -use crate::expr::stats::Stat; -use crate::expr::transform::rules::Matcher; - -/// An expression that wraps arbitrary scalar functions. -/// -/// Note that for backwards-compatibility, the `id` of this expression is the same as the -/// `id` of the underlying scalar function vtable, rather than being something constant like -/// `vortex.scalar_fn`. -pub struct ScalarFnExpr { - /// The vtable of the particular scalar function represented by this expression. - vtable: ScalarFnVTable, -} - -impl VTable for ScalarFnExpr { - type Instance = ScalarFn; - - fn id(&self) -> ExprId { - self.vtable.id() - } - - fn serialize(&self, func: &ScalarFn) -> VortexResult>> { - func.options().serialize() - } - - fn deserialize(&self, bytes: &[u8]) -> VortexResult> { - self.vtable.deserialize(bytes).map(Some) - } - - fn validate(&self, expr: &ExpressionView) -> VortexResult<()> { - vortex_ensure!( - expr.data() - .signature() - .arity() - .matches(expr.children().len()), - "invalid number of arguments for scalar function" - ); - Ok(()) - } - - fn child_name(&self, _func: &ScalarFn, _child_idx: usize) -> ChildName { - "unknown".into() - } - - fn fmt_sql(&self, expr: &ExpressionView, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}(", expr.data())?; - for (i, child) in expr.children().iter().enumerate() { - if i > 0 { - write!(f, ", ")?; - } - child.fmt_sql(f)?; - } - write!(f, ")") - } - - fn fmt_data(&self, func: &ScalarFn, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", func) - } - - fn return_dtype(&self, expr: &ExpressionView, scope: &DType) -> VortexResult { - let arg_dtypes: Vec<_> = expr - .children() - .iter() - .map(|e| e.return_dtype(scope)) - .try_collect()?; - expr.data().return_dtype(&arg_dtypes) - } - - fn evaluate(&self, expr: &ExpressionView, scope: &ArrayRef) -> VortexResult { - let children: Vec<_> = expr - .children() - .iter() - .map(|child| child.evaluate(scope)) - .try_collect()?; - Ok(ScalarFnArray::try_new(expr.data().clone(), children, scope.len())?.into_array()) - } - - fn execute(&self, func: &ScalarFn, args: ExecutionArgs) -> VortexResult { - let expr_args = functions::ExecutionArgs::new( - args.row_count, - args.return_dtype, - args.dtypes, - args.vectors.into_iter().map(Datum::Vector).collect(), - ); - let result = func.execute(&expr_args)?; - Ok(match result { - Datum::Scalar(s) => s.repeat(args.row_count).freeze(), - Datum::Vector(v) => v, - }) - } - - fn stat_falsification( - &self, - _expr: &ExpressionView, - _catalog: &dyn StatsCatalog, - ) -> Option { - // TODO(ngates): ideally this is implemented as optimizer rules over a `falsify` and - // `verify` expressions. - todo!() - } - - fn stat_expression( - &self, - _expr: &ExpressionView, - _stat: Stat, - _catalog: &dyn StatsCatalog, - ) -> Option { - // TODO(ngates): ideally this is implemented specifically for the Zoned layout, no one - // else needs to know what a specific stat over a column resolves to. - todo!() - } - - fn is_null_sensitive(&self, _func: &ScalarFn) -> bool { - todo!() - } -} - -/// A matcher that matches any scalar function expression. -#[derive(Debug)] -pub struct AnyScalarFn; -impl Matcher for AnyScalarFn { - type View<'a> = &'a ScalarFn; - - fn try_match(parent: &Expression) -> Option> { - Some(parent.as_opt::()?.data()) - } -} - -/// A matcher that matches a specific scalar function expression. -#[derive(Debug)] -pub struct ExactScalarFn(PhantomData); -impl Matcher for ExactScalarFn { - type View<'a> = &'a F::Options; - - fn try_match(parent: &Expression) -> Option> { - let expr_view = parent.as_opt::()?; - expr_view.data().as_any().downcast_ref::() - } -} - -/// Expression factory functions for ScalarFn vtables. -pub trait ScalarFnExprExt: functions::VTable { - fn try_new_expr( - &'static self, - options: Self::Options, - children: impl Into>, - ) -> VortexResult { - let expr_vtable = ScalarFnExpr { - vtable: ScalarFnVTable::new_static(self), - }; - let scalar_fn = ScalarFn::new_static(self, options); - Expression::try_new(expr_vtable, scalar_fn, children) - } -} -impl ScalarFnExprExt for V {} diff --git a/vortex-array/src/expr/exprs/select/mod.rs b/vortex-array/src/expr/exprs/select.rs similarity index 70% rename from vortex-array/src/expr/exprs/select/mod.rs rename to vortex-array/src/expr/exprs/select.rs index 8b27fb73d3a..a074d8bf643 100644 --- a/vortex-array/src/expr/exprs/select/mod.rs +++ b/vortex-array/src/expr/exprs/select.rs @@ -1,8 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -pub mod transform; - use std::fmt::Display; use std::fmt::Formatter; use std::sync::Arc; @@ -18,20 +16,25 @@ use vortex_error::vortex_err; use vortex_proto::expr::FieldNames as ProtoFieldNames; use vortex_proto::expr::SelectOpts; use vortex_proto::expr::select_opts::Opts; -use vortex_vector::Vector; +use vortex_vector::Datum; +use vortex_vector::StructDatum; +use vortex_vector::VectorOps; use vortex_vector::struct_::StructVector; use crate::ArrayRef; use crate::IntoArray; use crate::ToCanonical; +use crate::expr::Arity; use crate::expr::ChildName; use crate::expr::ExecutionArgs; use crate::expr::ExprId; -use crate::expr::ExpressionView; +use crate::expr::SimplifyCtx; use crate::expr::VTable; use crate::expr::VTableExt; use crate::expr::expression::Expression; use crate::expr::field::DisplayFieldNames; +use crate::expr::get_item; +use crate::expr::pack; #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum FieldSelection { @@ -42,13 +45,13 @@ pub enum FieldSelection { pub struct Select; impl VTable for Select { - type Instance = FieldSelection; + type Options = FieldSelection; fn id(&self) -> ExprId { ExprId::new_ref("vortex.select") } - fn serialize(&self, instance: &Self::Instance) -> VortexResult>> { + fn serialize(&self, instance: &Self::Options) -> VortexResult>> { let opts = match instance { FieldSelection::Include(fields) => Opts::Include(ProtoFieldNames { names: fields.iter().map(|f| f.to_string()).collect(), @@ -62,7 +65,7 @@ impl VTable for Select { Ok(Some(select_opts.encode_to_vec())) } - fn deserialize(&self, metadata: &[u8]) -> VortexResult> { + fn deserialize(&self, metadata: &[u8]) -> VortexResult { let prost_metadata = SelectOpts::decode(metadata)?; let select_opts = prost_metadata @@ -78,29 +81,28 @@ impl VTable for Select { )), }; - Ok(Some(field_selection)) + Ok(field_selection) } - fn validate(&self, expr: &ExpressionView) -> VortexResult<()> { - if expr.children().len() != 1 { - vortex_bail!( - "Select expression requires exactly 1 child, got {}", - expr.children().len() - ); - } - Ok(()) + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(1) } - fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName { + fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName { match child_idx { 0 => ChildName::new_ref("child"), _ => unreachable!(), } } - fn fmt_sql(&self, expr: &ExpressionView, f: &mut Formatter<'_>) -> std::fmt::Result { - expr.child().fmt_sql(f)?; - match expr.data() { + fn fmt_sql( + &self, + selection: &FieldSelection, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { + expr.child(0).fmt_sql(f)?; + match selection { FieldSelection::Include(fields) => { write!(f, "{{{}}}", DisplayFieldNames(fields)) } @@ -110,27 +112,17 @@ impl VTable for Select { } } - fn fmt_data(&self, instance: &Self::Instance, f: &mut Formatter<'_>) -> std::fmt::Result { - let names = match instance { - FieldSelection::Include(names) => { - write!(f, "include=")?; - names - } - FieldSelection::Exclude(names) => { - write!(f, "exclude=")?; - names - } - }; - write!(f, "{{{}}}", DisplayFieldNames(names)) - } - - fn return_dtype(&self, expr: &ExpressionView, scope: &DType) -> VortexResult { - let child_dtype = expr.child().return_dtype(scope)?; + fn return_dtype( + &self, + selection: &FieldSelection, + arg_dtypes: &[DType], + ) -> VortexResult { + let child_dtype = &arg_dtypes[0]; let child_struct_dtype = child_dtype .as_struct_fields_opt() .ok_or_else(|| vortex_err!("Select child not a struct dtype"))?; - let projected = match expr.data() { + let projected = match selection { FieldSelection::Include(fields) => child_struct_dtype.project(fields.as_ref())?, FieldSelection::Exclude(fields) => child_struct_dtype .names() @@ -144,9 +136,49 @@ impl VTable for Select { Ok(DType::Struct(projected, child_dtype.nullability())) } - fn evaluate(&self, expr: &ExpressionView, scope: &ArrayRef) -> VortexResult { - let batch = expr.child().evaluate(scope)?.to_struct(); - Ok(match expr.data() { + fn simplify( + &self, + options: &Self::Options, + expr: &Expression, + ctx: &dyn SimplifyCtx, + ) -> VortexResult> { + let child = expr.child(0); + let child_dtype = ctx.return_dtype(child)?; + let child_nullability = child_dtype.nullability(); + + let child_dtype = child_dtype.as_struct_fields_opt().ok_or_else(|| { + vortex_err!( + "Select child must return a struct dtype, however it was a {}", + child_dtype + ) + })?; + + let expr = pack( + options + .as_include_names(child_dtype.names()) + .map_err(|e| { + e.with_context(format!( + "Select fields {:?} must be a subset of child fields {:?}", + options, + child_dtype.names() + )) + })? + .iter() + .map(|name| (name.clone(), get_item(name.clone(), child.clone()))), + child_nullability, + ); + + Ok(Some(expr)) + } + + fn evaluate( + &self, + selection: &FieldSelection, + expr: &Expression, + scope: &ArrayRef, + ) -> VortexResult { + let batch = expr.child(0).evaluate(scope)?.to_struct(); + Ok(match selection { FieldSelection::Include(f) => batch.project(f.as_ref()), FieldSelection::Exclude(names) => { let included_names = batch @@ -161,12 +193,7 @@ impl VTable for Select { .into_array()) } - fn execute(&self, selection: &FieldSelection, mut args: ExecutionArgs) -> VortexResult { - let child = args - .vectors - .pop() - .vortex_expect("Missing input child") - .into_struct(); + fn execute(&self, selection: &FieldSelection, mut args: ExecutionArgs) -> VortexResult { let child_fields = args .dtypes .pop() @@ -194,24 +221,44 @@ impl VTable for Select { .try_collect(), }?; - let (fields, mask) = child.into_parts(); - let new_fields = field_indices - .iter() - .map(|&idx| fields[idx].clone()) - .collect(); - Ok(unsafe { StructVector::new_unchecked(Arc::new(new_fields), mask) }.into()) + let child = args + .datums + .pop() + .vortex_expect("Missing input child") + .into_struct(); + + Ok(match child { + StructDatum::Scalar(s) => StructDatum::Scalar( + select_from_struct_vector(s.value(), &field_indices)?.scalar_at(0), + ), + StructDatum::Vector(v) => { + StructDatum::Vector(select_from_struct_vector(&v, &field_indices)?) + } + } + .into()) } - fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool { + fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { true } - fn is_fallible(&self, _instance: &Self::Instance) -> bool { + fn is_fallible(&self, _instance: &Self::Options) -> bool { // If this type-checks its infallible. false } } +fn select_from_struct_vector( + vec: &StructVector, + field_indices: &[usize], +) -> VortexResult { + let new_fields = field_indices + .iter() + .map(|&idx| vec.fields()[idx].clone()) + .collect(); + Ok(unsafe { StructVector::new_unchecked(Arc::new(new_fields), vec.validity().clone()) }) +} + /// Creates an expression that selects (includes) specific fields from an array. /// /// Projects only the specified fields from the child expression, which must be of DType struct. @@ -239,34 +286,6 @@ pub fn select_exclude(fields: impl Into, child: Expression) -> Expre .vortex_expect("Failed to create Select expression") } -impl ExpressionView<'_, Select> { - pub fn child(&self) -> &Expression { - &self.children()[0] - } - - /// Turn the select expression into an `include`, relative to a provided array of field names. - /// - /// For example: - /// ```rust - /// # use vortex_array::expr::{root, Select}; - /// # use vortex_array::expr::{FieldSelection, select, select_exclude}; - /// # use vortex_dtype::FieldNames; - /// let field_names = FieldNames::from(["a", "b", "c"]); - /// let include = select(["a"], root()); - /// let exclude = select_exclude(["b", "c"], root()); - /// assert_eq!( - /// &include.as_::().as_include(&field_names).unwrap(), - /// ); - /// ``` - pub fn as_include(&self, field_names: &FieldNames) -> VortexResult { - Select.try_new_expr( - FieldSelection::Include(self.data().as_include_names(field_names)?), - [self.child().clone()], - ) - } -} - impl FieldSelection { pub fn include(columns: FieldNames) -> Self { assert_eq!(columns.iter().unique().collect_vec().len(), columns.len()); @@ -331,12 +350,16 @@ mod tests { use vortex_dtype::FieldName; use vortex_dtype::FieldNames; use vortex_dtype::Nullability; + use vortex_dtype::Nullability::Nullable; + use vortex_dtype::PType::I32; + use vortex_dtype::StructFields; use super::select; use super::select_exclude; use crate::IntoArray; use crate::ToCanonical; use crate::arrays::StructArray; + use crate::expr::exprs::pack::Pack; use crate::expr::exprs::root::root; use crate::expr::exprs::select::Select; use crate::expr::test_harness; @@ -421,14 +444,50 @@ mod tests { assert_eq!( &include .as_::() - .data() .as_include_names(&field_names) .unwrap() ); } + + #[test] + fn test_remove_select_rule() { + let dtype = DType::Struct( + StructFields::new(["a", "b"].into(), vec![I32.into(), I32.into()]), + Nullable, + ); + let e = select(["a", "b"], root()); + + let result = e.simplify(&dtype).unwrap(); + + assert!(result.is::()); + assert!(result.return_dtype(&dtype).unwrap().is_nullable()); + } + + #[test] + fn test_remove_select_rule_exclude_fields() { + use crate::expr::exprs::select::select_exclude; + + let dtype = DType::Struct( + StructFields::new( + ["a", "b", "c"].into(), + vec![I32.into(), I32.into(), I32.into()], + ), + Nullable, + ); + let e = select_exclude(["c"], root()); + + let result = e.simplify(&dtype).unwrap(); + + assert!(result.is::()); + + // Should exclude "c" and include "a" and "b" + let result_dtype = result.return_dtype(&dtype).unwrap(); + assert!(result_dtype.is_nullable()); + let fields = result_dtype.as_struct_fields_opt().unwrap(); + assert_eq!(fields.names().as_ref(), &["a", "b"]); + } } diff --git a/vortex-array/src/expr/exprs/select/transform.rs b/vortex-array/src/expr/exprs/select/transform.rs deleted file mode 100644 index 7181d0b2f18..00000000000 --- a/vortex-array/src/expr/exprs/select/transform.rs +++ /dev/null @@ -1,118 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use vortex_error::VortexResult; -use vortex_error::vortex_err; - -use crate::expr::Expression; -use crate::expr::ExpressionView; -use crate::expr::exprs::get_item::get_item; -use crate::expr::exprs::pack::pack; -use crate::expr::exprs::select::Select; -use crate::expr::transform::rules::ReduceRule; -use crate::expr::transform::rules::TypedRuleContext; - -/// Rule that removes Select expressions by converting them to Pack + GetItem. -/// -/// Transforms: `select(["a", "b"], expr)` → `pack(a: get_item("a", expr), b: get_item("b", expr))` -#[derive(Debug, Default)] -pub struct RemoveSelectRule; - -impl ReduceRule for RemoveSelectRule { - fn reduce( - &self, - select: &ExpressionView(); - let result = RemoveSelectRule.reduce(&select_view, &ctx).unwrap(); - - assert!(result.is_some()); - let transformed = result.unwrap(); - assert!(transformed.is::()); - assert!(transformed.return_dtype(&dtype).unwrap().is_nullable()); - } - - #[test] - fn test_remove_select_rule_exclude_fields() { - use crate::expr::exprs::select::select_exclude; - - let dtype = DType::Struct( - StructFields::new( - ["a", "b", "c"].into(), - vec![I32.into(), I32.into(), I32.into()], - ), - Nullable, - ); - let e = select_exclude(["c"], root()); - - let ctx = TypedRuleContext::new(dtype.clone()); - let select_view = e.as_::() { - self.fields.extend(sel.data().field_names().iter().cloned()); + if let Some(field_selection) = node.as_opt::