Push scalar functions down through chunked arrays; push casts down into struct fields.

Reviewer notes on this copy of the patch:
- Line breaks and `<...>` generic arguments were missing from the received text. They are
  restored here from the hunk line counts and from the imports the patch adds; indentation of
  context lines follows rustfmt defaults. Confirm against the repository before applying
  (or apply with `--ignore-whitespace`).
- `index` hashes and the similarity score are those of the original commit.
- Additions relative to the original patch: doc comments, `SAFETY` notes on the `unsafe`
  blocks, and a field-count guard in `StructCastPushDownRule`.

diff --git a/vortex-array/src/arrays/chunked/vtable/mod.rs b/vortex-array/src/arrays/chunked/vtable/mod.rs
index c5ab0cb95e4..a24ad0c1360 100644
--- a/vortex-array/src/arrays/chunked/vtable/mod.rs
+++ b/vortex-array/src/arrays/chunked/vtable/mod.rs
@@ -18,6 +18,7 @@ use crate::IntoArray;
 use crate::ToCanonical;
 use crate::arrays::ChunkedArray;
 use crate::arrays::PrimitiveArray;
+use crate::arrays::chunked::vtable::rules::PARENT_RULES;
 use crate::serde::ArrayChildren;
 use crate::validity::Validity;
 use crate::vtable;
@@ -31,6 +32,7 @@ mod array;
 mod canonical;
 mod compute;
 mod operations;
+mod rules;
 mod validity;
 mod visitor;
 
@@ -166,4 +168,12 @@ impl VTable for ChunkedVTable {
             _ => None,
         })
     }
+
+    fn reduce_parent(
+        array: &Self::Array,
+        parent: &ArrayRef,
+        child_idx: usize,
+    ) -> VortexResult<Option<ArrayRef>> {
+        PARENT_RULES.evaluate(array, parent, child_idx)
+    }
 }
diff --git a/vortex-array/src/arrays/chunked/vtable/rules.rs b/vortex-array/src/arrays/chunked/vtable/rules.rs
new file mode 100644
index 00000000000..d5ccd23bfe5
--- /dev/null
+++ b/vortex-array/src/arrays/chunked/vtable/rules.rs
@@ -0,0 +1,133 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright the Vortex contributors
+
+use itertools::Itertools;
+use vortex_error::VortexResult;
+
+use crate::Array;
+use crate::ArrayRef;
+use crate::IntoArray;
+use crate::arrays::AnyScalarFn;
+use crate::arrays::ChunkedArray;
+use crate::arrays::ChunkedVTable;
+use crate::arrays::ConstantArray;
+use crate::arrays::ConstantVTable;
+use crate::arrays::ScalarFnArray;
+use crate::optimizer::ArrayOptimizer;
+use crate::optimizer::rules::ArrayParentReduceRule;
+use crate::optimizer::rules::ParentRuleSet;
+
+/// Parent-reduction rules for chunked arrays, evaluated from `ChunkedVTable::reduce_parent`.
+pub(super) const PARENT_RULES: ParentRuleSet<ChunkedVTable> = ParentRuleSet::new(&[
+    ParentRuleSet::lift(&ChunkedUnaryScalarFnPushDownRule),
+    ParentRuleSet::lift(&ChunkedConstantScalarFnPushDownRule),
+]);
+
+/// Push down any unary scalar function through chunked arrays.
+///
+/// The function is re-applied to every chunk on its own and the optimized per-chunk results are
+/// reassembled into a new chunked array. Parents with more than one child are left untouched.
+#[derive(Debug)]
+struct ChunkedUnaryScalarFnPushDownRule;
+impl ArrayParentReduceRule<ChunkedVTable> for ChunkedUnaryScalarFnPushDownRule {
+    type Parent = AnyScalarFn;
+
+    fn parent(&self) -> Self::Parent {
+        AnyScalarFn
+    }
+
+    fn reduce_parent(
+        &self,
+        array: &ChunkedArray,
+        parent: &ScalarFnArray,
+        _child_idx: usize,
+    ) -> VortexResult<Option<ArrayRef>> {
+        if parent.children().len() != 1 {
+            return Ok(None);
+        }
+
+        let new_chunks: Vec<_> = array
+            .chunks
+            .iter()
+            .map(|chunk| {
+                ScalarFnArray::try_new(
+                    parent.scalar_fn().clone(),
+                    vec![chunk.clone()],
+                    chunk.len(),
+                )?
+                .into_array()
+                .optimize()
+            })
+            .try_collect()?;
+
+        Ok(Some(
+            // SAFETY: each new chunk is `parent`'s scalar function applied to one chunk.
+            // NOTE(review): assumes each result therefore has `parent.dtype()` -- confirm.
+            unsafe { ChunkedArray::new_unchecked(new_chunks, parent.dtype().clone()) }.into_array(),
+        ))
+    }
+}
+
+/// Push down non-unary scalar functions through chunked arrays where other siblings are constant.
+///
+/// Applies only when every child of the parent other than the chunked one is a constant array.
+/// Each constant sibling is re-created at the length of the chunk it is paired with.
+#[derive(Debug)]
+struct ChunkedConstantScalarFnPushDownRule;
+impl ArrayParentReduceRule<ChunkedVTable> for ChunkedConstantScalarFnPushDownRule {
+    type Parent = AnyScalarFn;
+
+    fn parent(&self) -> Self::Parent {
+        AnyScalarFn
+    }
+
+    fn reduce_parent(
+        &self,
+        array: &ChunkedArray,
+        parent: &ScalarFnArray,
+        child_idx: usize,
+    ) -> VortexResult<Option<ArrayRef>> {
+        for (idx, child) in parent.children().iter().enumerate() {
+            if idx == child_idx {
+                continue;
+            }
+            if !child.is::<ConstantVTable>() {
+                return Ok(None);
+            }
+        }
+
+        let new_chunks: Vec<_> = array
+            .chunks
+            .iter()
+            .map(|chunk| {
+                let new_children: Vec<_> = parent
+                    .children()
+                    .iter()
+                    .enumerate()
+                    .map(|(idx, child)| {
+                        if idx == child_idx {
+                            chunk.clone()
+                        } else {
+                            // Re-create the constant sibling at this chunk's length.
+                            ConstantArray::new(
+                                child.as_::<ConstantVTable>().scalar().clone(),
+                                chunk.len(),
+                            )
+                            .into_array()
+                        }
+                    })
+                    .collect();
+
+                ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children, chunk.len())?
+                    .into_array()
+                    .optimize()
+            })
+            .try_collect()?;
+
+        Ok(Some(
+            // SAFETY: each new chunk is `parent`'s scalar function applied to one chunk.
+            // NOTE(review): assumes each result therefore has `parent.dtype()` -- confirm.
+            unsafe { ChunkedArray::new_unchecked(new_chunks, parent.dtype().clone()) }.into_array(),
+        ))
+    }
+}
diff --git a/vortex-array/src/arrays/struct_/mod.rs b/vortex-array/src/arrays/struct_/mod.rs
index 3541d7713e9..47d90317bc8 100644
--- a/vortex-array/src/arrays/struct_/mod.rs
+++ b/vortex-array/src/arrays/struct_/mod.rs
@@ -4,7 +4,6 @@
 mod array;
 pub use array::StructArray;
 mod compute;
-mod rules;
 mod vtable;
 pub use vtable::StructVTable;
 
diff --git a/vortex-array/src/arrays/struct_/vtable/mod.rs b/vortex-array/src/arrays/struct_/vtable/mod.rs
index 55951dcb224..1b946a38c26 100644
--- a/vortex-array/src/arrays/struct_/vtable/mod.rs
+++ b/vortex-array/src/arrays/struct_/vtable/mod.rs
@@ -17,7 +17,7 @@ use crate::ArrayRef;
 use crate::EmptyMetadata;
 use crate::VectorExecutor;
 use crate::arrays::struct_::StructArray;
-use crate::arrays::struct_::rules::RULES;
+use crate::arrays::struct_::vtable::rules::PARENT_RULES;
 use crate::executor::ExecutionCtx;
 use crate::serde::ArrayChildren;
 use crate::validity::Validity;
@@ -30,6 +30,7 @@ use crate::vtable::ValidityVTableFromValidityHelper;
 mod array;
 mod canonical;
 mod operations;
+mod rules;
 mod validity;
 mod visitor;
 
@@ -159,7 +160,7 @@ impl VTable for StructVTable {
         parent: &ArrayRef,
         child_idx: usize,
     ) -> VortexResult<Option<ArrayRef>> {
-        RULES.evaluate(array, parent, child_idx)
+        PARENT_RULES.evaluate(array, parent, child_idx)
     }
 }
 
diff --git a/vortex-array/src/arrays/struct_/rules.rs b/vortex-array/src/arrays/struct_/vtable/rules.rs
similarity index 61%
rename from vortex-array/src/arrays/struct_/rules.rs
rename to vortex-array/src/arrays/struct_/vtable/rules.rs
index 5ad946cfacc..7c1e89e38d8 100644
--- a/vortex-array/src/arrays/struct_/rules.rs
+++ b/vortex-array/src/arrays/struct_/vtable/rules.rs
@@ -3,7 +3,6 @@
 
 use vortex_error::VortexResult;
 
-use crate::Array;
 use crate::ArrayRef;
 use crate::IntoArray;
 use crate::arrays::ConstantArray;
@@ -12,6 +11,8 @@ use crate::arrays::ScalarFnArrayExt;
 use crate::arrays::ScalarFnArrayView;
 use crate::arrays::StructArray;
 use crate::arrays::StructVTable;
+use crate::builtins::ArrayBuiltins;
+use crate::expr::Cast;
 use crate::expr::EmptyOptions;
 use crate::expr::GetItem;
 use crate::expr::Mask;
@@ -20,9 +21,62 @@ use crate::optimizer::rules::ParentRuleSet;
 use crate::validity::Validity;
 use crate::vtable::ValidityHelper;
 
-pub(super) const RULES: ParentRuleSet<StructVTable> =
-    ParentRuleSet::new(&[ParentRuleSet::lift(&StructGetItemRule)]);
+/// Parent-reduction rules for struct arrays, evaluated from `StructVTable::reduce_parent`.
+pub(super) const PARENT_RULES: ParentRuleSet<StructVTable> = ParentRuleSet::new(&[
+    ParentRuleSet::lift(&StructCastPushDownRule),
+    ParentRuleSet::lift(&StructGetItemRule),
+]);
 
+/// Rule to push down cast into struct fields.
+///
+/// Each field array is cast to the matching target field dtype and the results are reassembled
+/// into a struct with the target fields, keeping the original length and validity.
+#[derive(Debug)]
+struct StructCastPushDownRule;
+impl ArrayParentReduceRule<StructVTable> for StructCastPushDownRule {
+    type Parent = ExactScalarFn<Cast>;
+
+    fn parent(&self) -> Self::Parent {
+        ExactScalarFn::from(&Cast)
+    }
+
+    fn reduce_parent(
+        &self,
+        array: &StructArray,
+        parent: ScalarFnArrayView<Cast>,
+        _child_idx: usize,
+    ) -> VortexResult<Option<ArrayRef>> {
+        // NOTE(review): assumes the cast target is a struct dtype -- confirm.
+        let target_fields = parent.options.as_struct_fields();
+
+        // `zip` stops at the shorter side, so a field-count mismatch would silently build a
+        // struct whose children disagree with `target_fields`. Decline the rewrite instead.
+        if array.fields.len() != target_fields.nfields() {
+            return Ok(None);
+        }
+
+        let mut new_fields = Vec::with_capacity(target_fields.nfields());
+        for (field_array, field_dtype) in array.fields.iter().zip(target_fields.fields()) {
+            new_fields.push(field_array.cast(field_dtype)?);
+        }
+
+        // SAFETY: `new_fields` holds one array per target field (count checked above), each
+        // produced by casting to that field's dtype; length and validity come from `array`.
+        // NOTE(review): validity is reused as-is; confirm it matches the target nullability.
+        let new_struct = unsafe {
+            StructArray::new_unchecked(
+                new_fields,
+                target_fields.clone(),
+                array.len(),
+                array.validity().clone(),
+            )
+        };
+
+        Ok(Some(new_struct.into_array()))
+    }
+}
+
+/// Rule to flatten get_item from struct by field name
 #[derive(Debug)]
 pub(crate) struct StructGetItemRule;
 impl ArrayParentReduceRule<StructVTable> for StructGetItemRule {