Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions vortex-array/src/arrays/chunked/vtable/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use crate::IntoArray;
use crate::ToCanonical;
use crate::arrays::ChunkedArray;
use crate::arrays::PrimitiveArray;
use crate::arrays::chunked::vtable::rules::PARENT_RULES;
use crate::serde::ArrayChildren;
use crate::validity::Validity;
use crate::vtable;
Expand All @@ -31,6 +32,7 @@ mod array;
mod canonical;
mod compute;
mod operations;
mod rules;
mod validity;
mod visitor;

Expand Down Expand Up @@ -166,4 +168,12 @@ impl VTable for ChunkedVTable {
_ => None,
})
}

fn reduce_parent(
array: &Self::Array,
parent: &ArrayRef,
child_idx: usize,
) -> VortexResult<Option<ArrayRef>> {
PARENT_RULES.evaluate(array, parent, child_idx)
}
}
121 changes: 121 additions & 0 deletions vortex-array/src/arrays/chunked/vtable/rules.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use itertools::Itertools;
use vortex_error::VortexResult;

use crate::Array;
use crate::ArrayRef;
use crate::IntoArray;
use crate::arrays::AnyScalarFn;
use crate::arrays::ChunkedArray;
use crate::arrays::ChunkedVTable;
use crate::arrays::ConstantArray;
use crate::arrays::ConstantVTable;
use crate::arrays::ScalarFnArray;
use crate::optimizer::ArrayOptimizer;
use crate::optimizer::rules::ArrayParentReduceRule;
use crate::optimizer::rules::ParentRuleSet;

pub(super) const PARENT_RULES: ParentRuleSet<ChunkedVTable> = ParentRuleSet::new(&[
ParentRuleSet::lift(&ChunkedUnaryScalarFnPushDownRule),
ParentRuleSet::lift(&ChunkedConstantScalarFnPushDownRule),
]);

/// Push down any unary scalar function through chunked arrays.
#[derive(Debug)]
struct ChunkedUnaryScalarFnPushDownRule;
impl ArrayParentReduceRule<ChunkedVTable> for ChunkedUnaryScalarFnPushDownRule {
type Parent = AnyScalarFn;

fn parent(&self) -> Self::Parent {
AnyScalarFn
}

fn reduce_parent(
&self,
array: &ChunkedArray,
parent: &ScalarFnArray,
_child_idx: usize,
) -> VortexResult<Option<ArrayRef>> {
if parent.children().len() != 1 {
return Ok(None);
}

let new_chunks: Vec<_> = array
.chunks
.iter()
.map(|chunk| {
ScalarFnArray::try_new(
parent.scalar_fn().clone(),
vec![chunk.clone()],
chunk.len(),
)?
.into_array()
.optimize()
})
.try_collect()?;

Ok(Some(
unsafe { ChunkedArray::new_unchecked(new_chunks, parent.dtype().clone()) }.into_array(),
))
}
}

/// Push down non-unary scalar functions through chunked arrays where other siblings are constant.
#[derive(Debug)]
struct ChunkedConstantScalarFnPushDownRule;
impl ArrayParentReduceRule<ChunkedVTable> for ChunkedConstantScalarFnPushDownRule {
type Parent = AnyScalarFn;

fn parent(&self) -> Self::Parent {
AnyScalarFn
}

fn reduce_parent(
&self,
array: &ChunkedArray,
parent: &ScalarFnArray,
child_idx: usize,
) -> VortexResult<Option<ArrayRef>> {
for (idx, child) in parent.children().iter().enumerate() {
if idx == child_idx {
continue;
}
if !child.is::<ConstantVTable>() {
return Ok(None);
}
}

let new_chunks: Vec<_> = array
.chunks
.iter()
.map(|chunk| {
let new_children: Vec<_> = parent
.children()
.iter()
.enumerate()
.map(|(idx, child)| {
if idx == child_idx {
chunk.clone()
} else {
ConstantArray::new(
child.as_::<ConstantVTable>().scalar().clone(),
chunk.len(),
)
.into_array()
}
})
.collect();

ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children, chunk.len())?
.into_array()
.optimize()
})
.try_collect()?;

Ok(Some(
unsafe { ChunkedArray::new_unchecked(new_chunks, parent.dtype().clone()) }.into_array(),
))
}
}
1 change: 0 additions & 1 deletion vortex-array/src/arrays/struct_/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
mod array;
pub use array::StructArray;
mod compute;
mod rules;

mod vtable;
pub use vtable::StructVTable;
Expand Down
5 changes: 3 additions & 2 deletions vortex-array/src/arrays/struct_/vtable/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use crate::ArrayRef;
use crate::EmptyMetadata;
use crate::VectorExecutor;
use crate::arrays::struct_::StructArray;
use crate::arrays::struct_::rules::RULES;
use crate::arrays::struct_::vtable::rules::PARENT_RULES;
use crate::executor::ExecutionCtx;
use crate::serde::ArrayChildren;
use crate::validity::Validity;
Expand All @@ -30,6 +30,7 @@ use crate::vtable::ValidityVTableFromValidityHelper;
mod array;
mod canonical;
mod operations;
mod rules;
mod validity;
mod visitor;

Expand Down Expand Up @@ -159,7 +160,7 @@ impl VTable for StructVTable {
parent: &ArrayRef,
child_idx: usize,
) -> VortexResult<Option<ArrayRef>> {
RULES.evaluate(array, parent, child_idx)
PARENT_RULES.evaluate(array, parent, child_idx)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

use vortex_error::VortexResult;

use crate::Array;
use crate::ArrayRef;
use crate::IntoArray;
use crate::arrays::ConstantArray;
Expand All @@ -12,6 +11,8 @@ use crate::arrays::ScalarFnArrayExt;
use crate::arrays::ScalarFnArrayView;
use crate::arrays::StructArray;
use crate::arrays::StructVTable;
use crate::builtins::ArrayBuiltins;
use crate::expr::Cast;
use crate::expr::EmptyOptions;
use crate::expr::GetItem;
use crate::expr::Mask;
Expand All @@ -20,9 +21,48 @@ use crate::optimizer::rules::ParentRuleSet;
use crate::validity::Validity;
use crate::vtable::ValidityHelper;

pub(super) const RULES: ParentRuleSet<StructVTable> =
ParentRuleSet::new(&[ParentRuleSet::lift(&StructGetItemRule)]);
pub(super) const PARENT_RULES: ParentRuleSet<StructVTable> = ParentRuleSet::new(&[
ParentRuleSet::lift(&StructCastPushDownRule),
ParentRuleSet::lift(&StructGetItemRule),
]);

/// Rule to push down cast into struct fields
#[derive(Debug)]
struct StructCastPushDownRule;
impl ArrayParentReduceRule<StructVTable> for StructCastPushDownRule {
type Parent = ExactScalarFn<Cast>;

fn parent(&self) -> Self::Parent {
ExactScalarFn::from(&Cast)
}

fn reduce_parent(
&self,
array: &StructArray,
parent: ScalarFnArrayView<Cast>,
_child_idx: usize,
) -> VortexResult<Option<ArrayRef>> {
let target_fields = parent.options.as_struct_fields();

let mut new_fields = Vec::with_capacity(target_fields.nfields());
for (field_array, field_dtype) in array.fields.iter().zip(target_fields.fields()) {
new_fields.push(field_array.cast(field_dtype)?)
}

let new_struct = unsafe {
StructArray::new_unchecked(
new_fields,
target_fields.clone(),
array.len(),
array.validity().clone(),
)
};

Ok(Some(new_struct.into_array()))
}
}

/// Rule to flatten get_item from struct by field name
#[derive(Debug)]
pub(crate) struct StructGetItemRule;
impl ArrayParentReduceRule<StructVTable> for StructGetItemRule {
Expand Down
Loading