diff --git a/encodings/alp/src/alp/compute/mod.rs b/encodings/alp/src/alp/compute/mod.rs index 599b7624dc1..b87ecb37244 100644 --- a/encodings/alp/src/alp/compute/mod.rs +++ b/encodings/alp/src/alp/compute/mod.rs @@ -1,8 +1,8 @@ mod between; mod compare; -mod filter; +mod nan_count; -use vortex_array::compute::{ScalarAtFn, SliceFn, TakeFn, scalar_at, slice, take}; +use vortex_array::compute::{NaNCountFn, ScalarAtFn, SliceFn, TakeFn, scalar_at, slice, take}; use vortex_array::variants::PrimitiveArrayTrait; use vortex_array::vtable::ComputeVTable; use vortex_array::{Array, ArrayRef}; @@ -12,6 +12,9 @@ use vortex_scalar::Scalar; use crate::{ALPArray, ALPEncoding, ALPFloat, match_each_alp_float_ptype}; impl ComputeVTable for ALPEncoding { + fn nan_count_fn(&self) -> Option<&dyn NaNCountFn<&dyn Array>> { + Some(self) + } fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<&dyn Array>> { Some(self) } diff --git a/encodings/alp/src/alp/compute/nan_count.rs b/encodings/alp/src/alp/compute/nan_count.rs new file mode 100644 index 00000000000..102b9cd5340 --- /dev/null +++ b/encodings/alp/src/alp/compute/nan_count.rs @@ -0,0 +1,15 @@ +use vortex_array::compute::{NaNCountFn, nan_count}; +use vortex_error::VortexResult; + +use crate::{ALPArray, ALPEncoding}; + +impl NaNCountFn<&ALPArray> for ALPEncoding { + fn nan_count(&self, array: &ALPArray) -> VortexResult> { + // NANs can only be in patches + if let Some(patches) = array.patches() { + nan_count(patches.values()) + } else { + Ok(Some(0)) + } + } +} diff --git a/vortex-array/src/arrays/primitive/compute/mod.rs b/vortex-array/src/arrays/primitive/compute/mod.rs index d2fd7b93225..49d7bd3d8c8 100644 --- a/vortex-array/src/arrays/primitive/compute/mod.rs +++ b/vortex-array/src/arrays/primitive/compute/mod.rs @@ -1,7 +1,7 @@ use crate::Array; use crate::arrays::PrimitiveEncoding; use crate::compute::{ - FillNullFn, IsConstantFn, IsSortedFn, MinMaxFn, ScalarAtFn, SearchSortedFn, + FillNullFn, IsConstantFn, IsSortedFn, MinMaxFn, 
NaNCountFn, ScalarAtFn, SearchSortedFn, SearchSortedUsizeFn, SliceFn, TakeFn, ToArrowFn, UncompressedSizeFn, }; use crate::vtable::ComputeVTable; @@ -14,6 +14,7 @@ mod is_constant; mod is_sorted; mod mask; mod min_max; +mod nan_count; mod scalar_at; mod search_sorted; mod slice; @@ -41,6 +42,10 @@ impl ComputeVTable for PrimitiveEncoding { Some(self) } + fn nan_count_fn(&self) -> Option<&dyn NaNCountFn<&dyn Array>> { + Some(self) + } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<&dyn Array>> { Some(self) } diff --git a/vortex-array/src/arrays/primitive/compute/nan_count.rs b/vortex-array/src/arrays/primitive/compute/nan_count.rs new file mode 100644 index 00000000000..8c4980af175 --- /dev/null +++ b/vortex-array/src/arrays/primitive/compute/nan_count.rs @@ -0,0 +1,56 @@ +use vortex_dtype::{NativePType, match_each_float_ptype}; +use vortex_error::VortexResult; +use vortex_mask::Mask; + +use crate::Array; +use crate::arrays::{PrimitiveArray, PrimitiveEncoding}; +use crate::compute::NaNCountFn; +use crate::variants::PrimitiveArrayTrait; + +impl NaNCountFn<&PrimitiveArray> for PrimitiveEncoding { + fn nan_count(&self, array: &PrimitiveArray) -> VortexResult> { + Ok(Some(match_each_float_ptype!(array.ptype(), |$F| { + compute_nan_count_with_validity(array.as_slice::<$F>(), array.validity_mask()?) 
+ }))) + } +} + +#[inline] +fn compute_nan_count_with_validity(values: &[T], validity: Mask) -> usize { + match validity { + Mask::AllTrue(_) => values.iter().filter(|v| v.is_nan()).count(), + Mask::AllFalse(_) => 0, + Mask::Values(v) => values + .iter() + .zip(v.boolean_buffer().iter()) + .filter_map(|(v, m)| m.then_some(v)) + .filter(|v| v.is_nan()) + .count(), + } +} + +#[cfg(test)] +mod tests { + use vortex_buffer::buffer; + + use crate::arrays::PrimitiveArray; + use crate::compute::nan_count; + use crate::validity::Validity; + + #[test] + fn primitive_nan_count() { + let p = PrimitiveArray::new( + buffer![ + -f32::NAN, + f32::NAN, + 0.1, + 1.1, + -0.0, + f32::INFINITY, + f32::NEG_INFINITY + ], + Validity::NonNullable, + ); + assert_eq!(nan_count(&p).unwrap(), Some(2)); + } +} diff --git a/vortex-array/src/compute/conformance/binary_numeric.rs b/vortex-array/src/compute/conformance/binary_numeric.rs index 12bd9b7c20b..c610606d940 100644 --- a/vortex-array/src/compute/conformance/binary_numeric.rs +++ b/vortex-array/src/compute/conformance/binary_numeric.rs @@ -4,7 +4,8 @@ use vortex_error::{VortexExpect, VortexResult, VortexUnwrap, vortex_err}; use vortex_scalar::{NumericOperator, PrimitiveScalar, Scalar}; use crate::arrays::ConstantArray; -use crate::compute::{numeric, scalar_at}; +use crate::compute::numeric::numeric; +use crate::compute::scalar_at; use crate::{Array, ArrayRef, ToCanonical}; fn to_vec_of_scalar(array: &dyn Array) -> Vec { diff --git a/vortex-array/src/compute/mod.rs b/vortex-array/src/compute/mod.rs index 90693ecf909..e5e81e7a697 100644 --- a/vortex-array/src/compute/mod.rs +++ b/vortex-array/src/compute/mod.rs @@ -23,6 +23,7 @@ use itertools::Itertools; pub use like::{LikeFn, LikeOptions, like}; pub use mask::*; pub use min_max::{MinMaxFn, MinMaxResult, min_max}; +pub use nan_count::*; pub use numeric::*; pub use optimize::*; pub use scalar_at::{ScalarAtFn, scalar_at}; @@ -58,6 +59,7 @@ mod is_sorted; mod like; mod mask; mod min_max; +mod 
nan_count; mod numeric; mod optimize; mod scalar_at; diff --git a/vortex-array/src/compute/nan_count.rs b/vortex-array/src/compute/nan_count.rs new file mode 100644 index 00000000000..48cd5de931d --- /dev/null +++ b/vortex-array/src/compute/nan_count.rs @@ -0,0 +1,65 @@ +use vortex_error::{VortexExpect, VortexResult, vortex_bail}; +use vortex_scalar::ScalarValue; + +use crate::stats::{Precision, Stat}; +use crate::{Array, Encoding}; + +/// Encoding-specific kernel that counts the NaN values in an array. +/// It is looked up through the encoding's compute vtable by [`nan_count`]. +pub trait NaNCountFn { + fn nan_count(&self, array: A) -> VortexResult>; +} + +impl NaNCountFn<&dyn Array> for E +where + E: for<'a> NaNCountFn<&'a E::Array>, +{ + fn nan_count(&self, array: &dyn Array) -> VortexResult> { + let array_ref = array + .as_any() + .downcast_ref::() + .vortex_expect("Failed to downcast array"); + NaNCountFn::nan_count(self, array_ref) + } +} + +/// Computes the number of NaN values in the array +/// This will update the stats set of this array (as a side effect). +pub fn nan_count(array: &dyn Array) -> VortexResult> { + if array.is_empty() || array.valid_count()? == 0 { + return Ok(Some(0)); + } + + let nan_count = array + .statistics() + .get_as::(Stat::NaNCount) + .and_then(Precision::as_exact); + + if let Some(nan_count) = nan_count { + return Ok(Some(nan_count)); + } + + // Only float arrays can have NaNs + let nan_count = if !array.dtype().is_float() { + Some(0) + } else if let Some(fn_) = array.vtable().nan_count_fn() { + fn_.nan_count(array)? + } else { + let canonical = array.to_canonical()?; + if let Some(fn_) = canonical.as_ref().vtable().nan_count_fn() { + fn_.nan_count(canonical.as_ref())? 
+ } else { + vortex_bail!(NotImplemented: "nan_count", array.encoding()); + } + }; + + if let Some(nan_count) = nan_count { + // Update the stats set with the computed NaN count + array.statistics().set( + Stat::NaNCount, + Precision::Exact(ScalarValue::from(nan_count)), + ); + } + + Ok(nan_count) +} diff --git a/vortex-array/src/stats/array.rs b/vortex-array/src/stats/array.rs index 3b93abd3cb5..4d0829b3019 100644 --- a/vortex-array/src/stats/array.rs +++ b/vortex-array/src/stats/array.rs @@ -11,7 +11,8 @@ use super::{ }; use crate::Array; use crate::compute::{ - MinMaxResult, is_constant, is_sorted, is_strict_sorted, min_max, sum, uncompressed_size, + MinMaxResult, is_constant, is_sorted, is_strict_sorted, min_max, nan_count, sum, + uncompressed_size, }; /// A shared [`StatsSet`] stored in an array. Can be shared by copies of the array and can also be mutated in place. @@ -135,6 +136,7 @@ impl StatsSetRef<'_> { Stat::IsSorted => Some(is_sorted(self.dyn_array_ref)?.into()), Stat::IsStrictSorted => Some(is_strict_sorted(self.dyn_array_ref)?.into()), Stat::UncompressedSizeInBytes => Some(uncompressed_size(self.dyn_array_ref)?.into()), + Stat::NaNCount => Some(nan_count(self.dyn_array_ref)?.into()), }) } diff --git a/vortex-array/src/stats/flatbuffers.rs b/vortex-array/src/stats/flatbuffers.rs index 9efa1b749e0..e66afe3ed30 100644 --- a/vortex-array/src/stats/flatbuffers.rs +++ b/vortex-array/src/stats/flatbuffers.rs @@ -49,6 +49,9 @@ impl WriteFlatBuffer for StatsSet { uncompressed_size_in_bytes: self .get_as::(Stat::UncompressedSizeInBytes) .and_then(Precision::as_exact), + nan_count: self + .get_as::(Stat::NaNCount) + .and_then(Precision::as_exact), }; crate::flatbuffers::ArrayStats::create(fbb, stat_args) @@ -112,6 +115,14 @@ impl ReadFlatBuffer for StatsSet { stats_set.set(Stat::Sum, Precision::Exact(ScalarValue::try_from(sum)?)); } } + Stat::NaNCount => { + if let Some(nan_count) = fb.nan_count() { + stats_set.set( + Stat::NaNCount, + 
Precision::Exact(ScalarValue::from(nan_count)), + ); + } + } } } diff --git a/vortex-array/src/stats/mod.rs b/vortex-array/src/stats/mod.rs index fa05c018f1e..4996973c23b 100644 --- a/vortex-array/src/stats/mod.rs +++ b/vortex-array/src/stats/mod.rs @@ -29,13 +29,20 @@ use vortex_error::VortexExpect; /// Statistics that are used for pruning files (i.e., we want to ensure they are computed when compressing/writing). /// Sum is included for boolean arrays. -pub const PRUNING_STATS: &[Stat] = &[Stat::Min, Stat::Max, Stat::Sum, Stat::NullCount]; +pub const PRUNING_STATS: &[Stat] = &[ + Stat::Min, + Stat::Max, + Stat::Sum, + Stat::NullCount, + Stat::NaNCount, +]; /// Stats to keep when serializing arrays to layouts pub const STATS_TO_WRITE: &[Stat] = &[ Stat::Min, Stat::Max, Stat::NullCount, + Stat::NaNCount, Stat::Sum, Stat::IsConstant, Stat::IsSorted, @@ -75,6 +82,8 @@ pub enum Stat { NullCount = 6, /// The uncompressed size of the array in bytes UncompressedSizeInBytes = 7, + /// The number of NaN values in the array + NaNCount = 8, } /// These structs allow the extraction of the bound from the `Precision` value. 
@@ -87,6 +96,7 @@ pub struct IsSorted; pub struct IsStrictSorted; pub struct NullCount; pub struct UncompressedSizeInBytes; +pub struct NaNCount; impl StatType for IsConstant { type Bound = Precision; @@ -136,19 +146,26 @@ impl StatType for Sum { const STAT: Stat = Stat::Sum; } +impl StatType for NaNCount { + type Bound = UpperBound; + + const STAT: Stat = Stat::NaNCount; +} + impl Stat { /// Whether the statistic is commutative (i.e., whether merging can be done independently of ordering) /// e.g., min/max are commutative, but is_sorted is not pub fn is_commutative(&self) -> bool { // NOTE: we prefer this syntax to force a compile error if we add a new stat match self { - Stat::IsConstant - | Stat::Max - | Stat::Min - | Stat::NullCount - | Stat::Sum - | Stat::UncompressedSizeInBytes => true, - Stat::IsSorted | Stat::IsStrictSorted => false, + Self::IsConstant + | Self::Max + | Self::Min + | Self::NullCount + | Self::Sum + | Self::NaNCount + | Self::UncompressedSizeInBytes => true, + Self::IsSorted | Self::IsStrictSorted => false, } } @@ -159,14 +176,15 @@ impl Stat { pub fn dtype(&self, data_type: &DType) -> Option { Some(match self { - Stat::IsConstant => DType::Bool(NonNullable), - Stat::IsSorted => DType::Bool(NonNullable), - Stat::IsStrictSorted => DType::Bool(NonNullable), - Stat::Max => data_type.clone(), - Stat::Min => data_type.clone(), - Stat::NullCount => DType::Primitive(PType::U64, NonNullable), - Stat::UncompressedSizeInBytes => DType::Primitive(PType::U64, NonNullable), - Stat::Sum => { + Self::IsConstant => DType::Bool(NonNullable), + Self::IsSorted => DType::Bool(NonNullable), + Self::IsStrictSorted => DType::Bool(NonNullable), + Self::Max => data_type.clone(), + Self::Min => data_type.clone(), + Self::NullCount => DType::Primitive(PType::U64, NonNullable), + Self::UncompressedSizeInBytes => DType::Primitive(PType::U64, NonNullable), + Self::NaNCount => DType::Primitive(PType::U64, NonNullable), + Self::Sum => { // Any array that cannot be summed 
has a sum DType of null. // Any array that can be summed, but overflows, has a sum _value_ of null. // Therefore, we make integer sum stats nullable. @@ -207,13 +225,14 @@ impl Stat { Self::Min => "min", Self::NullCount => "null_count", Self::UncompressedSizeInBytes => "uncompressed_size_in_bytes", - Stat::Sum => "sum", + Self::Sum => "sum", + Self::NaNCount => "nan_count", } } } pub fn as_stat_bitset_bytes(stats: &[Stat]) -> Vec { - let max_stat = u8::from(last::().vortex_expect("last stat")) as usize; + let max_stat = u8::from(last::().vortex_expect("last stat")) as usize + 1; // TODO(ngates): use vortex-buffer::BitBuffer let mut stat_bitset = BooleanBufferBuilder::new_from_buffer( MutableBuffer::from_len_zeroed(max_stat.div_ceil(8)), diff --git a/vortex-array/src/stats/stat_bound.rs b/vortex-array/src/stats/stat_bound.rs index ebbbc758d68..1a43e1a1998 100644 --- a/vortex-array/src/stats/stat_bound.rs +++ b/vortex-array/src/stats/stat_bound.rs @@ -48,6 +48,10 @@ impl StatBound for Precision { value } + fn into_value(self) -> Precision { + self + } + fn union(&self, other: &Self) -> Option { self.clone() .zip(other.clone()) @@ -84,8 +88,4 @@ impl StatBound for Precision { _ => None, } } - - fn into_value(self) -> Precision { - self - } } diff --git a/vortex-array/src/stats/stats_set.rs b/vortex-array/src/stats/stats_set.rs index 0fd85188b03..bb40fd56163 100644 --- a/vortex-array/src/stats/stats_set.rs +++ b/vortex-array/src/stats/stats_set.rs @@ -7,7 +7,7 @@ use vortex_error::{VortexExpect, VortexResult, vortex_err}; use vortex_scalar::{Scalar, ScalarValue}; use super::traits::StatsProvider; -use super::{IsSorted, IsStrictSorted, NullCount, StatType, UncompressedSizeInBytes}; +use super::{IsSorted, IsStrictSorted, NaNCount, NullCount, StatType, UncompressedSizeInBytes}; use crate::stats::{IsConstant, Max, Min, Precision, Stat, StatBound, StatsProviderExt, Sum}; #[derive(Default, Debug, Clone)] @@ -199,6 +199,7 @@ impl StatsSet { Stat::Sum => self.merge_sum(other, 
dtype), Stat::NullCount => self.merge_null_count(other), Stat::UncompressedSizeInBytes => self.merge_uncompressed_size_in_bytes(other), + Stat::NaNCount => self.merge_nan_count(other), } } @@ -224,6 +225,7 @@ impl StatsSet { Stat::IsSorted | Stat::IsStrictSorted => { unreachable!("not commutative") } + Stat::NaNCount => self.merge_nan_count(other), } } @@ -245,6 +247,7 @@ impl StatsSet { Stat::IsStrictSorted => self.combine_bool_stat::(other)?, Stat::NullCount => self.combine_bound::(other, dtype)?, Stat::Sum => self.combine_bound::(other, dtype)?, + Stat::NaNCount => self.combine_bound::(other, dtype)?, } } Ok(()) @@ -429,6 +432,10 @@ impl StatsSet { self.merge_sum_stat(Stat::NullCount, other) } + fn merge_nan_count(&mut self, other: &Self) { + self.merge_sum_stat(Stat::NaNCount, other) + } + fn merge_uncompressed_size_in_bytes(&mut self, other: &Self) { self.merge_sum_stat(Stat::UncompressedSizeInBytes, other) } diff --git a/vortex-array/src/vtable/compute.rs b/vortex-array/src/vtable/compute.rs index 7ce3de61d7b..cb1ee4429ac 100644 --- a/vortex-array/src/vtable/compute.rs +++ b/vortex-array/src/vtable/compute.rs @@ -1,7 +1,8 @@ use crate::Array; use crate::compute::{ - FillNullFn, IsConstantFn, IsSortedFn, LikeFn, MinMaxFn, OptimizeFn, ScalarAtFn, SearchSortedFn, - SearchSortedUsizeFn, SliceFn, TakeFn, TakeFromFn, ToArrowFn, UncompressedSizeFn, + FillNullFn, IsConstantFn, IsSortedFn, LikeFn, MinMaxFn, NaNCountFn, OptimizeFn, ScalarAtFn, + SearchSortedFn, SearchSortedUsizeFn, SliceFn, TakeFn, TakeFromFn, ToArrowFn, + UncompressedSizeFn, }; /// VTable for dispatching compute functions to Vortex encodings. @@ -38,6 +39,13 @@ pub trait ComputeVTable { None } + /// Compute nan count of the array + /// + /// See: [`NaNCountFn`] + fn nan_count_fn(&self) -> Option<&dyn NaNCountFn<&dyn Array>> { + None + } + /// Try and optimize the layout of an array. 
/// /// See: [`OptimizeFn`] diff --git a/vortex-flatbuffers/flatbuffers/vortex-array/array.fbs b/vortex-flatbuffers/flatbuffers/vortex-array/array.fbs index e09af236ec8..a708b158c14 100644 --- a/vortex-flatbuffers/flatbuffers/vortex-array/array.fbs +++ b/vortex-flatbuffers/flatbuffers/vortex-array/array.fbs @@ -45,6 +45,7 @@ table ArrayStats { is_constant: bool = null; null_count: uint64 = null; uncompressed_size_in_bytes: uint64 = null; + nan_count: uint64 = null; } root_type Array; diff --git a/vortex-flatbuffers/src/generated/array.rs b/vortex-flatbuffers/src/generated/array.rs index a51b7383448..04a0e9f92ed 100644 --- a/vortex-flatbuffers/src/generated/array.rs +++ b/vortex-flatbuffers/src/generated/array.rs @@ -601,6 +601,7 @@ impl<'a> ArrayStats<'a> { pub const VT_IS_CONSTANT: flatbuffers::VOffsetT = 14; pub const VT_NULL_COUNT: flatbuffers::VOffsetT = 16; pub const VT_UNCOMPRESSED_SIZE_IN_BYTES: flatbuffers::VOffsetT = 18; + pub const VT_NAN_COUNT: flatbuffers::VOffsetT = 20; #[inline] pub unsafe fn init_from_table(table: flatbuffers::Table<'a>) -> Self { @@ -612,6 +613,7 @@ impl<'a> ArrayStats<'a> { args: &'args ArrayStatsArgs<'args> ) -> flatbuffers::WIPOffset> { let mut builder = ArrayStatsBuilder::new(_fbb); + if let Some(x) = args.nan_count { builder.add_nan_count(x); } if let Some(x) = args.uncompressed_size_in_bytes { builder.add_uncompressed_size_in_bytes(x); } if let Some(x) = args.null_count { builder.add_null_count(x); } if let Some(x) = args.sum { builder.add_sum(x); } @@ -680,6 +682,13 @@ impl<'a> ArrayStats<'a> { // which contains a valid value in this slot unsafe { self._tab.get::(ArrayStats::VT_UNCOMPRESSED_SIZE_IN_BYTES, None)} } + #[inline] + pub fn nan_count(&self) -> Option { + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { self._tab.get::(ArrayStats::VT_NAN_COUNT, None)} + } } impl flatbuffers::Verifiable for ArrayStats<'_> { @@ -697,6 +706,7 @@ impl 
flatbuffers::Verifiable for ArrayStats<'_> { .visit_field::("is_constant", Self::VT_IS_CONSTANT, false)? .visit_field::("null_count", Self::VT_NULL_COUNT, false)? .visit_field::("uncompressed_size_in_bytes", Self::VT_UNCOMPRESSED_SIZE_IN_BYTES, false)? + .visit_field::("nan_count", Self::VT_NAN_COUNT, false)? .finish(); Ok(()) } @@ -710,6 +720,7 @@ pub struct ArrayStatsArgs<'a> { pub is_constant: Option, pub null_count: Option, pub uncompressed_size_in_bytes: Option, + pub nan_count: Option, } impl<'a> Default for ArrayStatsArgs<'a> { #[inline] @@ -723,6 +734,7 @@ impl<'a> Default for ArrayStatsArgs<'a> { is_constant: None, null_count: None, uncompressed_size_in_bytes: None, + nan_count: None, } } } @@ -765,6 +777,10 @@ impl<'a: 'b, 'b, A: flatbuffers::Allocator + 'a> ArrayStatsBuilder<'a, 'b, A> { self.fbb_.push_slot_always::(ArrayStats::VT_UNCOMPRESSED_SIZE_IN_BYTES, uncompressed_size_in_bytes); } #[inline] + pub fn add_nan_count(&mut self, nan_count: u64) { + self.fbb_.push_slot_always::(ArrayStats::VT_NAN_COUNT, nan_count); + } + #[inline] pub fn new(_fbb: &'b mut flatbuffers::FlatBufferBuilder<'a, A>) -> ArrayStatsBuilder<'a, 'b, A> { let start = _fbb.start_table(); ArrayStatsBuilder { @@ -790,6 +806,7 @@ impl core::fmt::Debug for ArrayStats<'_> { ds.field("is_constant", &self.is_constant()); ds.field("null_count", &self.null_count()); ds.field("uncompressed_size_in_bytes", &self.uncompressed_size_in_bytes()); + ds.field("nan_count", &self.nan_count()); ds.finish() } } diff --git a/vortex-layout/src/layouts/stats/stats_table.rs b/vortex-layout/src/layouts/stats/stats_table.rs index 82f7a27a090..f52c2d51445 100644 --- a/vortex-layout/src/layouts/stats/stats_table.rs +++ b/vortex-layout/src/layouts/stats/stats_table.rs @@ -78,7 +78,7 @@ impl StatsTable { } } // These stats sum up - Stat::NullCount | Stat::UncompressedSizeInBytes => { + Stat::NullCount | Stat::NaNCount | Stat::UncompressedSizeInBytes => { let sum = sum(&array)? 
.cast(&DType::Primitive(PType::U64, Nullability::Nullable))? .into_value();