119 changes: 2 additions & 117 deletions vortex-array/src/arrays/extension/vtable/rules.rs
@@ -6,22 +6,16 @@ use vortex_error::VortexResult;
use crate::Array;
use crate::ArrayRef;
use crate::IntoArray;
use crate::arrays::AnyScalarFn;
use crate::arrays::ConstantArray;
use crate::arrays::ConstantVTable;
use crate::arrays::ExtensionArray;
use crate::arrays::ExtensionVTable;
use crate::arrays::FilterArray;
use crate::arrays::FilterVTable;
use crate::arrays::ScalarFnArray;
use crate::matchers::Exact;
use crate::optimizer::rules::ArrayParentReduceRule;
use crate::optimizer::rules::ParentRuleSet;

pub(super) const PARENT_RULES: ParentRuleSet<ExtensionVTable> = ParentRuleSet::new(&[
ParentRuleSet::lift(&ExtensionFilterPushDownRule),
ParentRuleSet::lift(&ExtensionScalarFnConstantPushDownRule),
]);
pub(super) const PARENT_RULES: ParentRuleSet<ExtensionVTable> =
ParentRuleSet::new(&[ParentRuleSet::lift(&ExtensionFilterPushDownRule)]);

/// Push filter operations into the storage array of an extension array.
#[derive(Debug)]
@@ -51,68 +45,6 @@ impl ArrayParentReduceRule<ExtensionVTable> for ExtensionFilterPushDownRule {
}
}

/// Push scalar function operations into the storage array when the other operand is a constant
/// with the same extension type.
#[derive(Debug)]
struct ExtensionScalarFnConstantPushDownRule;

impl ArrayParentReduceRule<ExtensionVTable> for ExtensionScalarFnConstantPushDownRule {
type Parent = AnyScalarFn;

fn parent(&self) -> Self::Parent {
AnyScalarFn
}

fn reduce_parent(
&self,
child: &ExtensionArray,
parent: &ScalarFnArray,
child_idx: usize,
) -> VortexResult<Option<ArrayRef>> {
// Check that all other children are constants with matching extension types.
for (idx, sibling) in parent.children().iter().enumerate() {
if idx == child_idx {
continue;
}

// Sibling must be a constant.
let Some(const_array) = sibling.as_opt::<ConstantVTable>() else {
return Ok(None);
};

// Sibling must be an extension scalar with the same extension type.
let Some(ext_scalar) = const_array.scalar().as_extension_opt() else {
return Ok(None);
};

// ExtDType::eq_ignore_nullability checks id, metadata, and storage dtype
if !ext_scalar
.ext_dtype()
.eq_ignore_nullability(child.ext_dtype())
{
return Ok(None);
}
}

// Build new children with storage arrays/scalars.
let mut new_children = Vec::with_capacity(parent.children().len());
for (idx, sibling) in parent.children().iter().enumerate() {
if idx == child_idx {
new_children.push(child.storage().clone());
} else {
let const_array = sibling.as_::<ConstantVTable>();
let storage_scalar = const_array.scalar().as_extension().storage();
new_children.push(ConstantArray::new(storage_scalar, child.len()).into_array());
}
}

Ok(Some(
ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children, child.len())?
.into_array(),
))
}
}

#[cfg(test)]
mod tests {
use std::sync::Arc;
@@ -134,7 +66,6 @@ mod tests {
use crate::arrays::ExtensionVTable;
use crate::arrays::FilterArray;
use crate::arrays::PrimitiveArray;
use crate::arrays::PrimitiveVTable;
use crate::arrays::ScalarFnArrayExt;
use crate::expr::Binary;
use crate::expr::Operator;
@@ -202,52 +133,6 @@
assert_eq!(canonical.len(), 3);
}

#[test]
fn test_scalar_fn_constant_pushdown_comparison() {
let ext_dtype = test_ext_dtype();
let storage = buffer![10i64, 20, 30, 40, 50].into_array();
let ext_array = ExtensionArray::new(ext_dtype.clone(), storage).into_array();

// Create a constant extension scalar with value 25
let const_scalar = Scalar::extension(ext_dtype, Scalar::from(25i64));
let const_array = ConstantArray::new(const_scalar, 5).into_array();

// Create a binary comparison: ext_array < const_array
let scalar_fn_array = Binary
.try_new_array(5, Operator::Lt, [ext_array, const_array])
.unwrap();

// Optimize should push down the comparison to storage
let optimized = scalar_fn_array.optimize().unwrap();

// The result should still be a ScalarFnArray but operating on primitive storage
let scalar_fn = optimized.as_opt::<crate::arrays::ScalarFnVTable>();
assert!(
scalar_fn.is_some(),
"Expected ScalarFnArray after optimization"
);

// The children should now be primitives, not extensions
let children = scalar_fn.unwrap().children();
assert_eq!(children.len(), 2);

// First child should be the primitive storage
assert!(
children[0].as_opt::<PrimitiveVTable>().is_some(),
"Expected first child to be PrimitiveArray, got {}",
children[0].encoding_id()
);

// Second child should be a constant with primitive value
assert!(
children[1]
.as_opt::<crate::arrays::ConstantVTable>()
.is_some(),
"Expected second child to be ConstantArray, got {}",
children[1].encoding_id()
);
}

#[test]
fn test_scalar_fn_no_pushdown_different_ext_types() {
let ext_dtype1 = Arc::new(ExtDType::new(
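With the constant push-down rule removed, this file keeps only the filter push-down rule, which (per its doc comment) pushes filter operations into the extension array's storage, i.e. rewrites filter(extension(storage)) into extension(filter(storage)). A toy standalone sketch of that rewrite; the `Expr` enum and `push_filter_into_extension` function below are hypothetical stand-ins, not vortex-array types:

```rust
/// A toy expression tree: an extension array wrapping a storage array, a filter
/// over some child, or a plain storage leaf.
#[derive(Clone, Debug, PartialEq)]
enum Expr {
    Storage(Vec<i64>),
    Extension(Box<Expr>),
    Filter { child: Box<Expr>, mask: Vec<bool> },
}

/// Rewrite filter(extension(storage)) into extension(filter(storage)), so the
/// filter runs directly against the storage encoding.
fn push_filter_into_extension(expr: Expr) -> Expr {
    match expr {
        Expr::Filter { child, mask } => match *child {
            Expr::Extension(storage) => Expr::Extension(Box::new(Expr::Filter {
                child: storage,
                mask,
            })),
            other => Expr::Filter { child: Box::new(other), mask },
        },
        other => other,
    }
}

fn main() {
    let expr = Expr::Filter {
        child: Box::new(Expr::Extension(Box::new(Expr::Storage(vec![1, 2, 3])))),
        mask: vec![true, false, true],
    };
    let rewritten = push_filter_into_extension(expr);
    assert_eq!(
        rewritten,
        Expr::Extension(Box::new(Expr::Filter {
            child: Box::new(Expr::Storage(vec![1, 2, 3])),
            mask: vec![true, false, true],
        }))
    );
}
```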
18 changes: 12 additions & 6 deletions vortex-array/src/arrays/struct_/vtable/rules.rs
@@ -2,6 +2,7 @@
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use vortex_error::VortexResult;
use vortex_error::vortex_err;

use crate::ArrayRef;
use crate::IntoArray;
@@ -49,13 +50,18 @@ impl ArrayParentReduceRule<StructVTable> for StructCastPushDownRule {
new_fields.push(field_array.cast(field_dtype)?)
}

let validity = if parent.options.is_nullable() {
array.validity().clone().into_nullable()
} else {
array
.validity()
.clone()
.into_non_nullable(array.len)
.ok_or_else(|| vortex_err!("Failed to cast nullable struct to non-nullable"))?
};

let new_struct = unsafe {
StructArray::new_unchecked(
new_fields,
target_fields.clone(),
array.len(),
array.validity().clone(),
)
StructArray::new_unchecked(new_fields, target_fields.clone(), array.len(), validity)
};

Ok(Some(new_struct.into_array()))
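The hunk above makes the struct cast push-down respect the target nullability: widening validity to nullable always succeeds, while dropping nullability must fail if any element is actually null. A standalone sketch of that rule; the `Validity` enum and `cast_validity` function here are simplified stand-ins, not the vortex-array types:

```rust
/// A simplified stand-in for a validity mask: either the type is non-nullable,
/// or it is nullable with a per-element mask (true = valid).
#[derive(Clone, Debug, PartialEq)]
enum Validity {
    NonNullable,
    Nullable(Vec<bool>),
}

/// Cast validity to the target nullability: widening to nullable always works;
/// narrowing to non-nullable only succeeds when every element is valid.
fn cast_validity(validity: Validity, target_nullable: bool, len: usize) -> Result<Validity, String> {
    if target_nullable {
        Ok(match validity {
            Validity::NonNullable => Validity::Nullable(vec![true; len]),
            nullable => nullable,
        })
    } else {
        match validity {
            Validity::NonNullable => Ok(Validity::NonNullable),
            Validity::Nullable(mask) if mask.iter().all(|&valid| valid) => Ok(Validity::NonNullable),
            Validity::Nullable(_) => Err("Failed to cast nullable struct to non-nullable".to_string()),
        }
    }
}

fn main() {
    assert_eq!(
        cast_validity(Validity::NonNullable, true, 3),
        Ok(Validity::Nullable(vec![true; 3]))
    );
    assert!(cast_validity(Validity::Nullable(vec![true, false, true]), false, 3).is_err());
}
```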
51 changes: 31 additions & 20 deletions vortex-array/src/arrow/executor/list.rs
@@ -4,25 +4,23 @@
use std::any::type_name;
use std::sync::Arc;

use arrow_array::Array;
use arrow_array::ArrayRef as ArrowArrayRef;
use arrow_array::GenericListArray;
use arrow_array::OffsetSizeTrait;
use arrow_schema::DataType;
use arrow_schema::FieldRef;
use vortex_buffer::BufferMut;
use vortex_compute::arrow::IntoArrow;
use vortex_compute::cast::Cast;
use vortex_dtype::DType;
use vortex_dtype::NativePType;
use vortex_dtype::Nullability;
use vortex_dtype::PTypeDowncastExt;
use vortex_error::VortexError;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_ensure;
use vortex_error::vortex_err;
use vortex_session::VortexSession;

use crate::Array;
use crate::ArrayRef;
use crate::IntoArray;
use crate::VectorExecutor;
@@ -35,6 +33,7 @@ use crate::arrow::ArrowArrayExecutor;
use crate::arrow::executor::validity::to_arrow_null_buffer;
use crate::builtins::ArrayBuiltins;
use crate::validity::Validity;
use crate::vectors::VectorIntoArray;
use crate::vtable::ValidityHelper;

/// Convert a Vortex array into an Arrow GenericBinaryArray.
@@ -64,12 +63,32 @@ pub(super) fn to_arrow_list<O: OffsetSizeTrait + NativePType>(
// In other words, check that offsets + sizes are monotonically increasing.

// Otherwise, we execute the array to become a ListViewVector.
let list_view = array.execute_vector(session)?.into_arrow()?;
match O::IS_LARGE {
true => arrow_cast::cast(&list_view, &DataType::LargeList(elements_field.clone())),
false => arrow_cast::cast(&list_view, &DataType::List(elements_field.clone())),
}
.map_err(VortexError::from)
let elements_dtype = array
.dtype()
.as_list_element_opt()
.ok_or_else(|| vortex_err!("Cannot convert non-list array to Arrow ListArray"))?;
let list_view = array.execute_vector(session)?.into_list();
let (elements, offsets, sizes, validity) = list_view.into_parts();
let offset_dtype = DType::Primitive(O::PTYPE, Nullability::NonNullable);
let list_view = unsafe {
ListViewArray::new_unchecked(
(*elements).clone().into_array(elements_dtype),
offsets.cast(&offset_dtype)?.into_array(&offset_dtype),
sizes.cast(&offset_dtype)?.into_array(&offset_dtype),
Validity::from_mask(validity, array.dtype().nullability()),
)
};

list_view_to_list::<O>(list_view, elements_field, session)

// FIXME(ngates): we need this PR from arrow-rs:
// https://github.com/apache/arrow-rs/pull/8735
// let list_view = array.execute_vector(session)?.into_arrow()?;
// match O::IS_LARGE {
// true => arrow_cast::cast(&list_view, &DataType::LargeList(elements_field.clone())),
// false => arrow_cast::cast(&list_view, &DataType::List(elements_field.clone())),
// }
// .map_err(VortexError::from)
}

/// Convert a Vortex VarBinArray into an Arrow GenericBinaryArray.
@@ -203,6 +222,7 @@ fn list_view_to_list<O: OffsetSizeTrait + NativePType>(
}
new_offsets.push(O::usize_as(take_indices.len()));
}
assert_eq!(new_offsets.len(), offsets.len() + 1);

// Now we can "take" the elements using the computed indices.
let elements =
@@ -214,20 +234,11 @@
"Cannot convert to non-nullable Arrow array with null elements"
);

// We need to compute the final offsets from the sizes.
let mut final_offsets = Vec::with_capacity(sizes.len() + 1);
final_offsets.push(O::usize_as(0));
for i in 0..sizes.len() {
let last_offset = final_offsets[i].as_usize();
let size = sizes[i].as_usize();
final_offsets.push(O::usize_as(last_offset + size));
}

let null_buffer = to_arrow_null_buffer(&validity, sizes.len(), session)?;

Ok(Arc::new(GenericListArray::<O>::new(
elements_field.clone(),
offsets.into_arrow_offset_buffer(),
new_offsets.freeze().into_arrow_offset_buffer(),
elements,
null_buffer,
)))
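The hunk above replaces a separate offsets-from-sizes pass with offsets accumulated while building the take indices, and asserts the usual invariant that a list array carries one more offset than it has lists. For reference, a standalone sketch of the cumulative-sum construction the removed code used; `offsets_from_sizes` is a hypothetical helper, not a vortex-array API:

```rust
/// Build Arrow-style list offsets from per-list sizes: offsets[0] = 0 and
/// offsets[i + 1] = offsets[i] + sizes[i], so there is always exactly one more
/// offset than there are lists.
fn offsets_from_sizes(sizes: &[usize]) -> Vec<usize> {
    let mut offsets = Vec::with_capacity(sizes.len() + 1);
    offsets.push(0);
    for &size in sizes {
        let last = *offsets.last().expect("offsets always holds at least the leading zero");
        offsets.push(last + size);
    }
    offsets
}

fn main() {
    // Three lists of lengths 2, 0, and 3 produce offsets [0, 2, 2, 5].
    let offsets = offsets_from_sizes(&[2, 0, 3]);
    assert_eq!(offsets, vec![0, 2, 2, 5]);
    assert_eq!(offsets.len(), 3 + 1);
}
```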
34 changes: 32 additions & 2 deletions vortex-array/src/arrow/executor/mod.rs
@@ -22,10 +22,13 @@ use arrow_array::cast::AsArray;
use arrow_array::types::*;
use arrow_schema::DataType;
use arrow_schema::Schema;
use itertools::Itertools;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_error::vortex_ensure;
use vortex_session::VortexSession;

use crate::Array;
use crate::ArrayRef;
use crate::arrow::executor::bool::to_arrow_bool;
use crate::arrow::executor::byte::to_arrow_byte_array;
@@ -61,6 +64,13 @@ pub trait ArrowArrayExecutor: Sized {
let array = self.execute_arrow(&DataType::Struct(schema.fields.clone()), session)?;
Ok(RecordBatch::from(array.as_struct()))
}

/// Execute the array to produce Arrow `RecordBatch`es with the given schema.
fn execute_record_batches(
self,
schema: &Schema,
session: &VortexSession,
) -> VortexResult<Vec<RecordBatch>>;
}

impl ArrowArrayExecutor for ArrayRef {
@@ -69,7 +79,9 @@
data_type: &DataType,
session: &VortexSession,
) -> VortexResult<ArrowArrayRef> {
match data_type {
let len = self.len();

let arrow = match data_type {
DataType::Null => to_arrow_null(self, session),
DataType::Boolean => to_arrow_bool(self, session),
DataType::Int8 => to_arrow_primitive::<Int8Type>(self, session),
@@ -133,6 +145,24 @@
| DataType::Union(..) => {
vortex_bail!("Conversion to Arrow type {data_type} is not supported");
}
}
}?;

vortex_ensure!(
arrow.len() == len,
"Arrow array length does not match Vortex array length after conversion to {:?}",
arrow
);

Ok(arrow)
}

fn execute_record_batches(
self,
schema: &Schema,
session: &VortexSession,
) -> VortexResult<Vec<RecordBatch>> {
self.to_array_iterator()
.map(|a| a?.execute_record_batch(schema, session))
.try_collect()
}
}
15 changes: 14 additions & 1 deletion vortex-array/src/arrow/executor/struct_.rs
@@ -19,7 +19,10 @@ use vortex_session::VortexSession;

use crate::Array;
use crate::ArrayRef;
use crate::IntoArray;
use crate::ToCanonical;
use crate::VectorExecutor;
use crate::arrays::ChunkedVTable;
use crate::arrays::ScalarFnVTable;
use crate::arrays::StructVTable;
use crate::arrow::ArrowArrayExecutor;
@@ -35,7 +38,17 @@
) -> VortexResult<ArrowArrayRef> {
let len = array.len();

// First, we attempt to short-circuit if the array is already a StructVTable:
// If the array is chunked, then we invert the chunk-of-struct to struct-of-chunk.
let array = match array.try_into::<ChunkedVTable>() {
Ok(array) => {
// NOTE(ngates): this currently uses the old into_canonical code path, but we should
// just call directly into the swizzle-chunks function.
array.to_struct().into_array()
}
Err(array) => array,
};

// Attempt to short-circuit if the array is already a StructVTable:
let array = match array.try_into::<StructVTable>() {
Ok(array) => {
let validity = to_arrow_null_buffer(array.validity(), array.len(), session)?;
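The new branch above handles chunked input by inverting chunk-of-struct into struct-of-chunk before the Arrow conversion. A standalone sketch of that transposition, using plain `Vec`s as stand-ins for the chunked and field arrays; `swizzle_chunks` is a hypothetical name, not the vortex-array function:

```rust
/// Invert chunk-of-struct into struct-of-chunk: the input is one Vec of field
/// columns per chunk; the output is one Vec of chunk columns per field.
fn swizzle_chunks<T>(chunks: Vec<Vec<Vec<T>>>, num_fields: usize) -> Vec<Vec<Vec<T>>> {
    let mut fields: Vec<Vec<Vec<T>>> = (0..num_fields).map(|_| Vec::new()).collect();
    for chunk in chunks {
        assert_eq!(chunk.len(), num_fields, "every chunk must carry the same fields");
        for (field_idx, field_chunk) in chunk.into_iter().enumerate() {
            fields[field_idx].push(field_chunk);
        }
    }
    fields
}

fn main() {
    // Two chunks of a struct with fields (a, b).
    let chunks = vec![
        vec![vec![1, 2], vec![10, 20]], // chunk 0: a = [1, 2], b = [10, 20]
        vec![vec![3], vec![30]],        // chunk 1: a = [3],    b = [30]
    ];
    let fields = swizzle_chunks(chunks, 2);
    assert_eq!(fields[0], vec![vec![1, 2], vec![3]]);    // all chunks of field a
    assert_eq!(fields[1], vec![vec![10, 20], vec![30]]); // all chunks of field b
}
```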