From 859e7d812b3b6dd0d6d240dfbf1ec8fc4e604929 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 12 Dec 2025 16:13:19 -0500 Subject: [PATCH 1/3] bring portable simd take back add OOB check + safety comments Signed-off-by: Connor Tsui --- vortex-compute/src/take/slice/mod.rs | 11 +- vortex-compute/src/take/slice/portable.rs | 139 ++++++++++++++-------- vortex-compute/src/take/vector/tests.rs | 2 - 3 files changed, 95 insertions(+), 57 deletions(-) diff --git a/vortex-compute/src/take/slice/mod.rs b/vortex-compute/src/take/slice/mod.rs index eed5623ca54..f9d806bd155 100644 --- a/vortex-compute/src/take/slice/mod.rs +++ b/vortex-compute/src/take/slice/mod.rs @@ -17,14 +17,15 @@ impl Take<[I]> for &[T] { type Output = Buffer; fn take(self, indices: &[I]) -> Buffer { - // TODO(connor): Make the SIMD implementations bound by `Copy` instead of `NativePType`. - /* - #[cfg(vortex_nightly)] { return portable::take_portable(self, indices); } + // TODO(connor): Make the SIMD implementations bound by `Copy` instead of `NativePType`. + + /* + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { if is_x86_feature_detected!("avx2") { @@ -40,10 +41,6 @@ impl Take<[I]> for &[T] { } } -#[allow( - unused, - reason = "Compiler may see this as unused based on enabled features" -)] fn take_scalar(buffer: &[T], indices: &[I]) -> Buffer { // NB: The simpler `indices.iter().map(|idx| buff1er[idx.as_()]).collect()` generates suboptimal // assembly where the buffer length is repeatedly loaded from the stack on each iteration. diff --git a/vortex-compute/src/take/slice/portable.rs b/vortex-compute/src/take/slice/portable.rs index b3ef83862fb..feec9ca312d 100644 --- a/vortex-compute/src/take/slice/portable.rs +++ b/vortex-compute/src/take/slice/portable.rs @@ -7,18 +7,15 @@ use std::mem::MaybeUninit; use std::mem::size_of; -use std::mem::transmute; use std::simd; +use std::simd::cmp::SimdPartialOrd; use std::simd::num::SimdUint; use multiversion::multiversion; use vortex_buffer::Alignment; use vortex_buffer::Buffer; use vortex_buffer::BufferMut; -use vortex_dtype::NativePType; -use vortex_dtype::PType; use vortex_dtype::UnsignedPType; -use vortex_dtype::match_each_native_simd_ptype; use vortex_dtype::match_each_unsigned_integer_ptype; /// SIMD types larger than the SIMD register size are beneficial for @@ -27,38 +24,49 @@ pub const SIMD_WIDTH: usize = 64; /// Takes the specified indices into a new [`Buffer`] using portable SIMD. /// -/// This function handles the type matching required to satisfy `SimdElement` bounds. -/// For `f16` values, it reinterprets them as `u16` since `f16` doesn't implement `SimdElement`. +/// This function handles the type matching required to satisfy `SimdElement` bounds by casting +/// to unsigned integers of the same size. Falls back to scalar implementation for unsupported +/// type sizes. #[inline] -pub fn take_portable(buffer: &[T], indices: &[I]) -> Buffer { - if T::PTYPE == PType::F16 { - assert_eq!(size_of::(), size_of::()); - - // Since Rust does not actually support 16-bit floats, we first reinterpret the data as - // `u16` integers. - // SAFETY: We know that f16 has the same bit pattern as u16, so this transmute is fine to - // make. - let u16_slice: &[u16] = - unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const u16, buffer.len()) }; - return take_with_indices(u16_slice, indices).cast_into::(); +pub fn take_portable(buffer: &[T], indices: &[I]) -> Buffer { + // SIMD gather operations only care about bit patterns, not semantic type. We cast to unsigned + // integers which implement `SimdElement` and then cast back. + // + // SAFETY: The pointer casts below are safe because: + // - `T` and the target type have the same size (matched by `size_of::()`). + // - The alignment of unsigned integers is always <= their size, and `buffer` came from a valid + // `&[T]` which guarantees proper alignment for types of the same size. + match size_of::() { + 1 => { + let buffer: &[u8] = + unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const u8, buffer.len()) }; + take_with_indices(buffer, indices).cast_into::() + } + 2 => { + let buffer: &[u16] = + unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const u16, buffer.len()) }; + take_with_indices(buffer, indices).cast_into::() + } + 4 => { + let buffer: &[u32] = + unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const u32, buffer.len()) }; + take_with_indices(buffer, indices).cast_into::() + } + 8 => { + let buffer: &[u64] = + unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const u64, buffer.len()) }; + take_with_indices(buffer, indices).cast_into::() + } + // Fall back to scalar implementation for unsupported type sizes. + _ => super::take_scalar(buffer, indices), } - - match_each_native_simd_ptype!(T::PTYPE, |TC| { - assert_eq!(size_of::(), size_of::()); - - // SAFETY: This is essentially a no-op that tricks the compiler into adding the - // `simd::SimdElement` bound we need to call `take_with_indices`. - let buffer: &[TC] = - unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const TC, buffer.len()) }; - take_with_indices(buffer, indices).cast_into::() - }) } /// Helper that matches on index type and calls `take_portable_simd`. /// /// We separate this code out from above to add the [`simd::SimdElement`] constraint. #[inline] -fn take_with_indices( +fn take_with_indices( buffer: &[T], indices: &[I], ) -> Buffer { @@ -75,10 +83,14 @@ fn take_with_indices( /// buffer. Uses SIMD instructions to process `LANE_COUNT` indices in parallel. /// /// Returns a `Buffer` where each element corresponds to `values[indices[i]]`. +/// +/// # Panics +/// +/// Panics if any index is out of bounds for `values`. #[multiversion(targets("x86_64+avx2", "x86_64+avx", "aarch64+neon"))] pub fn take_portable_simd(values: &[T], indices: &[I]) -> Buffer where - T: NativePType + simd::SimdElement, + T: Copy + Default + simd::SimdElement, I: UnsignedPType + simd::SimdElement, simd::LaneCount: simd::SupportedLaneCount, simd::Simd: SimdUint = simd::Simd>, @@ -92,27 +104,57 @@ where let buf_slice = buffer.spare_capacity_mut(); + // Set up a vector that we can SIMD compare against for out-of-bounds indices. + let len_vec = simd::Simd::::splat(values.len()); + let mut all_valid = simd::Mask::::splat(true); + for chunk_idx in 0..(indices_len / LANE_COUNT) { let offset = chunk_idx * LANE_COUNT; - let mask = simd::Mask::from_bitmask(u64::MAX); let codes_chunk = simd::Simd::::from_slice(&indices[offset..]); - - let selection = simd::Simd::gather_select( - values, - mask, - codes_chunk.cast::(), - simd::Simd::::default(), - ); - + let codes_usize = codes_chunk.cast::(); + + // Accumulate validity and use as gather mask. An out-of-bounds index will turn a bit off. + all_valid &= codes_usize.simd_lt(len_vec); + + // SAFETY: We use `all_valid` to mask the gather, preventing OOB memory access. If any + // index is OOB, `all_valid` will have those bits turned off, masking out the invalid + // indices. + // Note that this may also mask out valid indices in subsequent iterations. This is fine + // because we will panic after the loop if **any** index was OOB, so we do not care if the + // resulting gathered data is correct or not. + let selection = unsafe { + simd::Simd::gather_select_unchecked( + values, + all_valid, + codes_usize, + simd::Simd::::default(), + ) + }; + + // SAFETY: `MaybeUninit` has the same layout as `T`, and we are about to initialize these + // elements with the store. + let uninit = unsafe { + std::mem::transmute::<&mut [MaybeUninit], &mut [T]>( + &mut buf_slice[offset..][..LANE_COUNT], + ) + }; + + // SAFETY: The slice `buf_slice[offset..][..LANE_COUNT]` is guaranteed to have exactly + // `LANE_COUNT` elements since `offset` is a multiple of `LANE_COUNT` and we only iterate + // while `offset + LANE_COUNT <= indices_len`. unsafe { - selection.store_select_unchecked( - transmute::<&mut [MaybeUninit], &mut [T]>(&mut buf_slice[offset..][..64]), - mask.cast(), - ); + selection.store_select_unchecked(uninit, simd::Mask::splat(true)); } } + // Check accumulated validity after hot loop. If there are any 0's, then there was an + // out-of-bounds index. + assert!(all_valid.all(), "index out of bounds in SIMD take"); + + // Fall back to scalar iteration for the remainder. for idx in ((indices_len / LANE_COUNT) * LANE_COUNT)..indices_len { + // SAFETY: `idx` is in bounds for `buf_slice` since `idx < indices_len == buf_slice.len()`. + // Note that the `values[...]` access is already bounds-checked and will panic if OOB. unsafe { buf_slice .get_unchecked_mut(idx) @@ -120,24 +162,25 @@ where } } - unsafe { - buffer.set_len(indices_len); - } + // SAFETY: All elements have been initialized: the SIMD loop handles `0..chunks * LANE_COUNT` + // and the scalar loop handles the remainder up to `indices_len`. + unsafe { buffer.set_len(indices_len) }; buffer.freeze() } #[cfg(test)] +#[allow(clippy::cast_possible_truncation)] mod tests { use super::take_portable_simd; #[test] + #[should_panic(expected = "index out of bounds")] fn test_take_out_of_bounds() { let indices = vec![2_000_000u32; 64]; let values = vec![1i32]; - let result = take_portable_simd::(&values, &indices); - assert_eq!(result.as_slice(), [0i32; 64]); + drop(take_portable_simd::(&values, &indices)); } /// Tests SIMD gather with a mix of sequential, strided, and repeated indices. This exercises @@ -159,7 +202,7 @@ mod tests { // Strided by 4: 0, 4, 8, ..., 252. indices.extend((0u32..64).map(|i| i * 4)); // Repeated: index 42 repeated 32 times. - indices.extend(std::iter::repeat(42u32).take(32)); + indices.extend(std::iter::repeat_n(42u32, 32)); // Reverse: 255, 254, ..., 216. indices.extend((216u32..256).rev()); diff --git a/vortex-compute/src/take/vector/tests.rs b/vortex-compute/src/take/vector/tests.rs index 72928193d06..057f3789d90 100644 --- a/vortex-compute/src/take/vector/tests.rs +++ b/vortex-compute/src/take/vector/tests.rs @@ -120,7 +120,6 @@ fn test_null_vector_take() { assert!(result.validity().all_false()); } -#[ignore = "TODO(connor): Implement `DecimalVector::take`."] #[test] fn test_dvector_take() { use vortex_buffer::buffer; @@ -154,7 +153,6 @@ fn test_dvector_take() { assert_eq!(validity, vec![true, false, true, false]); } -#[ignore = "TODO(connor): Implement `DecimalVector::take`."] #[test] fn test_decimal_vector_take() { use vortex_buffer::buffer; From e184b25a93852fc4f948c2fad3e5889cb77239fd Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 12 Dec 2025 17:21:40 -0500 Subject: [PATCH 2/3] add avx2 take impl back and bound by `Copy` Signed-off-by: Connor Tsui --- vortex-compute/src/take/slice/avx2.rs | 294 ++++++++++---------------- vortex-compute/src/take/slice/mod.rs | 13 +- 2 files changed, 115 insertions(+), 192 deletions(-) diff --git a/vortex-compute/src/take/slice/avx2.rs b/vortex-compute/src/take/slice/avx2.rs index 0731b20e400..e7c420162b6 100644 --- a/vortex-compute/src/take/slice/avx2.rs +++ b/vortex-compute/src/take/slice/avx2.rs @@ -6,10 +6,15 @@ //! Only enabled for x86_64 hosts and it is gated at runtime behind feature detection to ensure AVX2 //! instructions are available. +#![allow( + unused, + reason = "Compiler may see things in this module as unused based on enabled features" +)] #![cfg(any(target_arch = "x86_64", target_arch = "x86"))] use std::arch::x86_64::__m256i; use std::arch::x86_64::_mm_loadu_si128; +use std::arch::x86_64::_mm_movemask_epi8; use std::arch::x86_64::_mm_setzero_si128; use std::arch::x86_64::_mm_shuffle_epi32; use std::arch::x86_64::_mm_storeu_si128; @@ -26,92 +31,68 @@ use std::arch::x86_64::_mm256_loadu_si256; use std::arch::x86_64::_mm256_mask_i32gather_epi32; use std::arch::x86_64::_mm256_mask_i64gather_epi32; use std::arch::x86_64::_mm256_mask_i64gather_epi64; +use std::arch::x86_64::_mm256_movemask_epi8; use std::arch::x86_64::_mm256_set1_epi32; use std::arch::x86_64::_mm256_set1_epi64x; use std::arch::x86_64::_mm256_setzero_si256; use std::arch::x86_64::_mm256_storeu_si256; use std::convert::identity; +use std::mem::size_of; use vortex_buffer::Alignment; use vortex_buffer::Buffer; use vortex_buffer::BufferMut; -use vortex_dtype::NativePType; -use vortex_dtype::PType; use vortex_dtype::UnsignedPType; +use vortex_dtype::match_each_unsigned_integer_ptype; use crate::take::slice::take_scalar; /// Takes the specified indices into a new [`Buffer`] using AVX2 SIMD. /// -/// This returns None if the AVX2 feature is not detected at runtime, signalling to the caller -/// that it should fall back to the scalar implementation. -/// -/// If AVX2 is available, this returns a PrimitiveArray containing the result of the take operation -/// accelerated using AVX2 instructions. +/// This function handles the type matching required to satisfy AVX2 gather instruction requirements +/// by casting to unsigned integers of the same size. Falls back to scalar implementation for +/// unsupported type sizes. /// /// # Panics /// -/// This function panics if any of the provided `indices` are out of bounds for `values` +/// This function panics if any of the provided `indices` are out of bounds for `values`. /// /// # Safety /// /// The caller must ensure the `avx2` feature is enabled. -#[allow(dead_code, unused_variables, reason = "TODO(connor): Implement this")] #[target_feature(enable = "avx2")] #[inline] -pub unsafe fn take_avx2( - buffer: &[V], - indices: &[I], -) -> Buffer { - macro_rules! dispatch_avx2 { - ($indices:ty, $values:ty) => { - { let result = dispatch_avx2!($indices, $values, cast: $values); result } - }; - ($indices:ty, $values:ty, cast: $cast:ty) => {{ - let indices = unsafe { std::mem::transmute::<&[I], &[$indices]>(indices) }; - let values = unsafe { std::mem::transmute::<&[V], &[$cast]>(buffer) }; - - let result = exec_take::<$cast, $indices, AVX2Gather>(values, indices); - result.cast_into::() - }}; - } - - match (I::PTYPE, V::PTYPE) { - // Int value types. Only 32 and 64 bit types are supported. - (PType::U8, PType::I32) => dispatch_avx2!(u8, i32), - (PType::U8, PType::U32) => dispatch_avx2!(u8, u32), - (PType::U8, PType::I64) => dispatch_avx2!(u8, i64), - (PType::U8, PType::U64) => dispatch_avx2!(u8, u64), - (PType::U16, PType::I32) => dispatch_avx2!(u16, i32), - (PType::U16, PType::U32) => dispatch_avx2!(u16, u32), - (PType::U16, PType::I64) => dispatch_avx2!(u16, i64), - (PType::U16, PType::U64) => dispatch_avx2!(u16, u64), - (PType::U32, PType::I32) => dispatch_avx2!(u32, i32), - (PType::U32, PType::U32) => dispatch_avx2!(u32, u32), - (PType::U32, PType::I64) => dispatch_avx2!(u32, i64), - (PType::U32, PType::U64) => dispatch_avx2!(u32, u64), - - // Float value types, treat them as if they were corresponding int types. - (PType::U8, PType::F32) => dispatch_avx2!(u8, f32, cast: u32), - (PType::U16, PType::F32) => dispatch_avx2!(u16, f32, cast: u32), - (PType::U32, PType::F32) => dispatch_avx2!(u32, f32, cast: u32), - (PType::U64, PType::F32) => dispatch_avx2!(u64, f32, cast: u32), - - (PType::U8, PType::F64) => dispatch_avx2!(u8, f64, cast: u64), - (PType::U16, PType::F64) => dispatch_avx2!(u16, f64, cast: u64), - (PType::U32, PType::F64) => dispatch_avx2!(u32, f64, cast: u64), - (PType::U64, PType::F64) => dispatch_avx2!(u64, f64, cast: u64), - - // Scalar fallback for unsupported value types. - _ => { - tracing::trace!( - "take AVX2 kernel missing for indices {} values {}, falling back to scalar", - I::PTYPE, - V::PTYPE - ); - - take_scalar(buffer, indices) +pub unsafe fn take_avx2(buffer: &[V], indices: &[I]) -> Buffer { + // AVX2 gather operations only care about bit patterns, not semantic type. We cast to unsigned + // integers which have the required gather implementations and then cast back. + // + // SAFETY: The pointer casts below are safe because: + // - `V` and the target type have the same size (matched by `size_of::()`) + // - The alignment of unsigned integers is always <= their size, and `buffer` came from a valid + // `&[V]` which guarantees proper alignment for types of the same size. + match size_of::() { + 4 => { + let values: &[u32] = + unsafe { std::slice::from_raw_parts(buffer.as_ptr().cast::(), buffer.len()) }; + match_each_unsigned_integer_ptype!(I::PTYPE, |IC| { + let indices: &[IC] = unsafe { + std::slice::from_raw_parts(indices.as_ptr().cast::(), indices.len()) + }; + exec_take::(values, indices).cast_into::() + }) + } + 8 => { + let values: &[u64] = + unsafe { std::slice::from_raw_parts(buffer.as_ptr().cast::(), buffer.len()) }; + match_each_unsigned_integer_ptype!(I::PTYPE, |IC| { + let indices: &[IC] = unsafe { + std::slice::from_raw_parts(indices.as_ptr().cast::(), indices.len()) + }; + exec_take::(values, indices).cast_into::() + }) } + // Fall back to scalar implementation for unsupported type sizes (1, 2 byte types). + _ => take_scalar(buffer, indices), } } @@ -127,30 +108,43 @@ pub(crate) trait GatherFn { /// Gather values from `src` into the `dst` using the `indices`, optionally using /// SIMD instructions. /// + /// Returns `true` if all indices in this batch were valid (less than `max_idx`), `false` + /// otherwise. Invalid indices are masked out during the gather (substituting zeros). + /// /// # Safety /// /// This function can read up to `STRIDE` elements through `indices`, and read/write up to /// `WIDTH` elements through `src` and `dst` respectively. - unsafe fn gather(indices: *const Idx, max_idx: Idx, src: *const Values, dst: *mut Values); + unsafe fn gather( + indices: *const Idx, + max_idx: Idx, + src: *const Values, + dst: *mut Values, + ) -> bool; } /// AVX2 version of GatherFn defined for 32- and 64-bit value types. enum AVX2Gather {} macro_rules! impl_gather { - ($idx:ty, $({$value:ty => load: $load:ident, extend: $extend:ident, splat: $splat:ident, zero_vec: $zero_vec:ident, mask_indices: $mask_indices:ident, mask_cvt: |$mask_var:ident| $mask_cvt:block, gather: $masked_gather:ident, store: $store:ident, WIDTH = $WIDTH:literal, STRIDE = $STRIDE:literal }),+) => { + ($idx:ty, $({$value:ty => load: $load:ident, extend: $extend:ident, splat: $splat:ident, zero_vec: $zero_vec:ident, mask_indices: $mask_indices:ident, mask_cvt: |$mask_var:ident| $mask_cvt:block, movemask: $movemask:ident, all_valid_mask: $all_valid_mask:expr, gather: $masked_gather:ident, store: $store:ident, WIDTH = $WIDTH:literal, STRIDE = $STRIDE:literal }),+) => { $( - impl_gather!(single; $idx, $value, load: $load, extend: $extend, splat: $splat, zero_vec: $zero_vec, mask_indices: $mask_indices, mask_cvt: |$mask_var| $mask_cvt, gather: $masked_gather, store: $store, WIDTH = $WIDTH, STRIDE = $STRIDE); + impl_gather!(single; $idx, $value, load: $load, extend: $extend, splat: $splat, zero_vec: $zero_vec, mask_indices: $mask_indices, mask_cvt: |$mask_var| $mask_cvt, movemask: $movemask, all_valid_mask: $all_valid_mask, gather: $masked_gather, store: $store, WIDTH = $WIDTH, STRIDE = $STRIDE); )* }; - (single; $idx:ty, $value:ty, load: $load:ident, extend: $extend:ident, splat: $splat:ident, zero_vec: $zero_vec:ident, mask_indices: $mask_indices:ident, mask_cvt: |$mask_var:ident| $mask_cvt:block, gather: $masked_gather:ident, store: $store:ident, WIDTH = $WIDTH:literal, STRIDE = $STRIDE:literal) => { + (single; $idx:ty, $value:ty, load: $load:ident, extend: $extend:ident, splat: $splat:ident, zero_vec: $zero_vec:ident, mask_indices: $mask_indices:ident, mask_cvt: |$mask_var:ident| $mask_cvt:block, movemask: $movemask:ident, all_valid_mask: $all_valid_mask:expr, gather: $masked_gather:ident, store: $store:ident, WIDTH = $WIDTH:literal, STRIDE = $STRIDE:literal) => { impl GatherFn<$idx, $value> for AVX2Gather { const WIDTH: usize = $WIDTH; const STRIDE: usize = $STRIDE; #[allow(unused_unsafe, clippy::cast_possible_truncation)] #[inline(always)] - unsafe fn gather(indices: *const $idx, max_idx: $idx, src: *const $value, dst: *mut $value) { + unsafe fn gather( + indices: *const $idx, + max_idx: $idx, + src: *const $value, + dst: *mut $value + ) -> bool { const { assert!($WIDTH <= $STRIDE, "dst cannot advance by more than the stride"); } @@ -158,33 +152,39 @@ macro_rules! impl_gather { const SCALE: i32 = std::mem::size_of::<$value>() as i32; let indices_vec = unsafe { $load(indices.cast()) }; - // Extend indices to fill vector register + // Extend indices to fill vector register. let indices_vec = unsafe { $extend(indices_vec) }; - // create a vec of the max idx + // Create a vec of the max idx. let max_idx_vec = unsafe { $splat(max_idx as _) }; - // create a mask for valid indices (where the max_idx > provided index). - let invalid_mask = unsafe { $mask_indices(max_idx_vec, indices_vec) }; - let invalid_mask = { - let $mask_var = invalid_mask; + // Create a mask for valid indices (where the max_idx > provided index). + let valid_mask = unsafe { $mask_indices(max_idx_vec, indices_vec) }; + let valid_mask = { + let $mask_var = valid_mask; $mask_cvt }; let zero_vec = unsafe { $zero_vec() }; // Gather the values into new vector register, for masked positions // it substitutes zero instead of accessing the src. - let values_vec = unsafe { $masked_gather::(zero_vec, src.cast(), indices_vec, invalid_mask) }; + let values_vec = unsafe { + $masked_gather::(zero_vec, src.cast(), indices_vec, valid_mask) + }; // Write the vec out to dst. unsafe { $store(dst.cast(), values_vec) }; + + // Return true if all indices were valid (all mask bits set). + let mask_bits = unsafe { $movemask(valid_mask) }; + mask_bits == $all_valid_mask } } }; } -// kernels for u8 indices +// Kernels for u8 indices. impl_gather!(u8, - // 32-bit values, loaded 8 at a time + // 32-bit values, loaded 8 at a time. { u32 => load: _mm_loadu_si128, extend: _mm256_cvtepu8_epi32, @@ -192,23 +192,13 @@ impl_gather!(u8, zero_vec: _mm256_setzero_si256, mask_indices: _mm256_cmpgt_epi32, mask_cvt: |x| { x }, + movemask: _mm256_movemask_epi8, + all_valid_mask: -1_i32, gather: _mm256_mask_i32gather_epi32, store: _mm256_storeu_si256, WIDTH = 8, STRIDE = 16 }, - { i32 => - load: _mm_loadu_si128, - extend: _mm256_cvtepu8_epi32, - splat: _mm256_set1_epi32, - zero_vec: _mm256_setzero_si256, - mask_indices: _mm256_cmpgt_epi32, - mask_cvt: |x| { x }, - gather: _mm256_mask_i32gather_epi32, - store: _mm256_storeu_si256, - WIDTH = 8, STRIDE = 16 - }, - - // 64-bit values, loaded 4 at a time + // 64-bit values, loaded 4 at a time. { u64 => load: _mm_loadu_si128, extend: _mm256_cvtepu8_epi64, @@ -216,26 +206,17 @@ impl_gather!(u8, zero_vec: _mm256_setzero_si256, mask_indices: _mm256_cmpgt_epi64, mask_cvt: |x| { x }, - gather: _mm256_mask_i64gather_epi64, - store: _mm256_storeu_si256, - WIDTH = 4, STRIDE = 16 - }, - { i64 => - load: _mm_loadu_si128, - extend: _mm256_cvtepu8_epi64, - splat: _mm256_set1_epi64x, - zero_vec: _mm256_setzero_si256, - mask_indices: _mm256_cmpgt_epi64, - mask_cvt: |x| { x }, + movemask: _mm256_movemask_epi8, + all_valid_mask: -1_i32, gather: _mm256_mask_i64gather_epi64, store: _mm256_storeu_si256, WIDTH = 4, STRIDE = 16 } ); -// kernels for u16 indices +// Kernels for u16 indices. impl_gather!(u16, - // 32-bit values. 8x indices loaded at a time and 8x values written at a time + // 32-bit values. 8x indices loaded at a time and 8x values written at a time. { u32 => load: _mm_loadu_si128, extend: _mm256_cvtepu16_epi32, @@ -243,22 +224,12 @@ impl_gather!(u16, zero_vec: _mm256_setzero_si256, mask_indices: _mm256_cmpgt_epi32, mask_cvt: |x| { x }, + movemask: _mm256_movemask_epi8, + all_valid_mask: -1_i32, gather: _mm256_mask_i32gather_epi32, store: _mm256_storeu_si256, WIDTH = 8, STRIDE = 8 }, - { i32 => - load: _mm_loadu_si128, - extend: _mm256_cvtepu16_epi32, - splat: _mm256_set1_epi32, - zero_vec: _mm256_setzero_si256, - mask_indices: _mm256_cmpgt_epi32, - mask_cvt: |x| { x }, - gather: _mm256_mask_i32gather_epi32, - store: _mm256_storeu_si256, - WIDTH = 8, STRIDE = 8 - }, - // 64-bit values. 8x indices loaded at a time and 4x values loaded at a time. { u64 => load: _mm_loadu_si128, @@ -267,26 +238,17 @@ impl_gather!(u16, zero_vec: _mm256_setzero_si256, mask_indices: _mm256_cmpgt_epi64, mask_cvt: |x| { x }, - gather: _mm256_mask_i64gather_epi64, - store: _mm256_storeu_si256, - WIDTH = 4, STRIDE = 8 - }, - { i64 => - load: _mm_loadu_si128, - extend: _mm256_cvtepu16_epi64, - splat: _mm256_set1_epi64x, - zero_vec: _mm256_setzero_si256, - mask_indices: _mm256_cmpgt_epi64, - mask_cvt: |x| { x }, + movemask: _mm256_movemask_epi8, + all_valid_mask: -1_i32, gather: _mm256_mask_i64gather_epi64, store: _mm256_storeu_si256, WIDTH = 4, STRIDE = 8 } ); -// kernels for u32 indices +// Kernels for u32 indices. impl_gather!(u32, - // 32-bit values. 8x indices loaded at a time and 8x values written + // 32-bit values. 8x indices loaded at a time and 8x values written. { u32 => load: _mm256_loadu_si256, extend: identity, @@ -294,23 +256,13 @@ impl_gather!(u32, zero_vec: _mm256_setzero_si256, mask_indices: _mm256_cmpgt_epi32, mask_cvt: |x| { x }, + movemask: _mm256_movemask_epi8, + all_valid_mask: -1_i32, gather: _mm256_mask_i32gather_epi32, store: _mm256_storeu_si256, WIDTH = 8, STRIDE = 8 }, - { i32 => - load: _mm256_loadu_si256, - extend: identity, - splat: _mm256_set1_epi32, - zero_vec: _mm256_setzero_si256, - mask_indices: _mm256_cmpgt_epi32, - mask_cvt: |x| { x }, - gather: _mm256_mask_i32gather_epi32, - store: _mm256_storeu_si256, - WIDTH = 8, STRIDE = 8 - }, - - // 64-bit values + // 64-bit values. { u64 => load: _mm_loadu_si128, extend: _mm256_cvtepu32_epi64, @@ -318,25 +270,17 @@ impl_gather!(u32, zero_vec: _mm256_setzero_si256, mask_indices: _mm256_cmpgt_epi64, mask_cvt: |x| { x }, - gather: _mm256_mask_i64gather_epi64, - store: _mm256_storeu_si256, - WIDTH = 4, STRIDE = 4 - }, - { i64 => - load: _mm_loadu_si128, - extend: _mm256_cvtepu32_epi64, - splat: _mm256_set1_epi64x, - zero_vec: _mm256_setzero_si256, - mask_indices: _mm256_cmpgt_epi64, - mask_cvt: |x| { x }, + movemask: _mm256_movemask_epi8, + all_valid_mask: -1_i32, gather: _mm256_mask_i64gather_epi64, store: _mm256_storeu_si256, WIDTH = 4, STRIDE = 4 } ); -// kernels for u64 indices +// Kernels for u64 indices. impl_gather!(u64, + // 32-bit values. { u32 => load: _mm256_loadu_si256, extend: identity, @@ -352,31 +296,13 @@ impl_gather!(u64, _mm_unpacklo_epi64(lo_packed, hi_packed) } }, + movemask: _mm_movemask_epi8, + all_valid_mask: 0xFFFF_i32, gather: _mm256_mask_i64gather_epi32, store: _mm_storeu_si128, WIDTH = 4, STRIDE = 4 }, - { i32 => - load: _mm256_loadu_si256, - extend: identity, - splat: _mm256_set1_epi64x, - zero_vec: _mm_setzero_si128, - mask_indices: _mm256_cmpgt_epi64, - mask_cvt: |m| { - unsafe { - let lo_bits = _mm256_extracti128_si256::<0>(m); // lower half - let hi_bits = _mm256_extracti128_si256::<1>(m); // upper half - let lo_packed = _mm_shuffle_epi32::<0b01_01_01_01>(lo_bits); - let hi_packed = _mm_shuffle_epi32::<0b01_01_01_01>(hi_bits); - _mm_unpacklo_epi64(lo_packed, hi_packed) - } - }, - gather: _mm256_mask_i64gather_epi32, - store: _mm_storeu_si128, - WIDTH = 4, STRIDE = 4 - }, - - // 64-bit values + // 64-bit values. { u64 => load: _mm256_loadu_si256, extend: identity, @@ -384,17 +310,8 @@ impl_gather!(u64, zero_vec: _mm256_setzero_si256, mask_indices: _mm256_cmpgt_epi64, mask_cvt: |x| { x }, - gather: _mm256_mask_i64gather_epi64, - store: _mm256_storeu_si256, - WIDTH = 4, STRIDE = 4 - }, - { i64 => - load: _mm256_loadu_si256, - extend: identity, - splat: _mm256_set1_epi64x, - zero_vec: _mm256_setzero_si256, - mask_indices: _mm256_cmpgt_epi64, - mask_cvt: |x| { x }, + movemask: _mm256_movemask_epi8, + all_valid_mask: -1_i32, gather: _mm256_mask_i64gather_epi64, store: _mm256_storeu_si256, WIDTH = 4, STRIDE = 4 @@ -416,6 +333,8 @@ where let buf_uninit = buffer.spare_capacity_mut(); let mut offset = 0; + let mut all_valid = true; + // Loop terminates STRIDE elements before end of the indices array because the GatherFn // might read up to STRIDE src elements at a time, even though it only advances WIDTH elements // in the dst. @@ -423,7 +342,7 @@ where // SAFETY: gather_simd preconditions satisfied: // 1. `(indices + offset)..(indices + offset + STRIDE)` is in-bounds for indices allocation // 2. `buffer` has same len as indices so `buffer + offset + STRIDE` is always valid. - unsafe { + let batch_valid = unsafe { Gather::gather( indices.as_ptr().add(offset), max_index, @@ -431,10 +350,15 @@ where buf_uninit.as_mut_ptr().add(offset).cast(), ) }; + all_valid &= batch_valid; offset += Gather::WIDTH; } - // Remainder + // Check accumulated validity after hot loop. If there are any 0's, then there was an + // out-of-bounds index. + assert!(all_valid, "index out of bounds in AVX2 take"); + + // Fall back to scalar iteration for the remainder. while offset < indices_len { buf_uninit[offset].write(values[indices[offset].as_()]); offset += 1; diff --git a/vortex-compute/src/take/slice/mod.rs b/vortex-compute/src/take/slice/mod.rs index f9d806bd155..e8f7f81ebaf 100644 --- a/vortex-compute/src/take/slice/mod.rs +++ b/vortex-compute/src/take/slice/mod.rs @@ -3,6 +3,11 @@ //! Take function implementations on slices. +#![allow( + unused, + reason = "Compiler may see things in this module as unused based on enabled features" +)] + use vortex_buffer::Buffer; use vortex_buffer::BufferMut; use vortex_dtype::UnsignedPType; @@ -22,10 +27,6 @@ impl Take<[I]> for &[T] { return portable::take_portable(self, indices); } - // TODO(connor): Make the SIMD implementations bound by `Copy` instead of `NativePType`. - - /* - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { if is_x86_feature_detected!("avx2") { @@ -34,14 +35,12 @@ impl Take<[I]> for &[T] { } } - */ - #[allow(unreachable_code, reason = "`vortex_nightly` path returns early")] take_scalar(self, indices) } } -fn take_scalar(buffer: &[T], indices: &[I]) -> Buffer { +pub(crate) fn take_scalar(buffer: &[T], indices: &[I]) -> Buffer { // NB: The simpler `indices.iter().map(|idx| buff1er[idx.as_()]).collect()` generates suboptimal // assembly where the buffer length is repeatedly loaded from the stack on each iteration. From 078ca7c9ec56a9b4b6027e6dcb26e777837b0b15 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Sun, 14 Dec 2025 10:05:03 -0500 Subject: [PATCH 3/3] TODO Signed-off-by: Connor Tsui --- .../src/arrays/primitive/compute/take/avx2.rs | 5 + vortex-compute/src/take/slice/asm_stubs.rs | 91 +++++++++++++++++++ vortex-compute/src/take/slice/mod.rs | 16 ++-- vortex-compute/src/take/slice/portable.rs | 50 +++++----- 4 files changed, 131 insertions(+), 31 deletions(-) create mode 100644 vortex-compute/src/take/slice/asm_stubs.rs diff --git a/vortex-array/src/arrays/primitive/compute/take/avx2.rs b/vortex-array/src/arrays/primitive/compute/take/avx2.rs index c330a9226a3..c86c57f6ac9 100644 --- a/vortex-array/src/arrays/primitive/compute/take/avx2.rs +++ b/vortex-array/src/arrays/primitive/compute/take/avx2.rs @@ -6,6 +6,11 @@ //! Only enabled for x86_64 hosts and it is gated at runtime behind feature detection to //! ensure AVX2 instructions are available. +#![allow( + unused, + reason = "Compiler may see things in this module as unused based on enabled features" +)] + use vortex_compute::take::slice::avx2; use vortex_dtype::NativePType; use vortex_dtype::UnsignedPType; diff --git a/vortex-compute/src/take/slice/asm_stubs.rs b/vortex-compute/src/take/slice/asm_stubs.rs new file mode 100644 index 00000000000..f855a3f25f8 --- /dev/null +++ b/vortex-compute/src/take/slice/asm_stubs.rs @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Assembly inspection stubs for cargo-show-asm. +//! +//! These functions are `#[inline(never)]` wrappers around the take implementations +//! to allow inspecting the generated assembly with cargo-show-asm. +//! +//! # Usage +//! +//! ```bash +//! # Scalar implementations +//! cargo asm -p vortex-compute take_scalar_u32_u32 --rust +//! cargo asm -p vortex-compute take_scalar_u64_u32 --rust +//! +//! # AVX2 implementations +//! cargo asm -p vortex-compute take_avx2_u32_u32 --rust +//! cargo asm -p vortex-compute take_avx2_u64_u32 --rust +//! +//! # Portable SIMD implementations (requires nightly) +//! RUSTFLAGS='--cfg vortex_nightly' cargo +nightly asm -p vortex-compute take_portable_simd_u32_u32 --rust +//! RUSTFLAGS='--cfg vortex_nightly' cargo +nightly asm -p vortex-compute take_portable_simd_u64_u32 --rust +//! ``` + +#![allow(unused, reason = "These stubs are for assembly inspection only")] + +use vortex_buffer::Buffer; + +// ============ SCALAR STUBS ============ + +/// Scalar take: u32 values, u32 indices. +#[inline(never)] +pub fn take_scalar_u32_u32(buffer: &[u32], indices: &[u32]) -> Buffer { + super::take_scalar(buffer, indices) +} + +/// Scalar take: u64 values, u32 indices. +#[inline(never)] +pub fn take_scalar_u64_u32(buffer: &[u64], indices: &[u32]) -> Buffer { + super::take_scalar(buffer, indices) +} + +// ============ PORTABLE SIMD STUBS ============ + +/// Portable SIMD assembly stubs. +#[cfg(vortex_nightly)] +pub mod portable { + use vortex_buffer::Buffer; + + /// Portable SIMD take: u32 values, u32 indices. + #[inline(never)] + pub fn take_portable_simd_u32_u32(buffer: &[u32], indices: &[u32]) -> Buffer { + super::super::portable::take_portable_simd::(buffer, indices) + } + + /// Portable SIMD take: u64 values, u32 indices. + #[inline(never)] + pub fn take_portable_simd_u64_u32(buffer: &[u64], indices: &[u32]) -> Buffer { + super::super::portable::take_portable_simd::(buffer, indices) + } +} + +// ============ AVX2 STUBS ============ + +/// AVX2 assembly stubs. +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +pub mod avx2 { + use vortex_buffer::Buffer; + + /// AVX2 take: u32 values, u32 indices. + /// + /// # Safety + /// + /// Caller must ensure AVX2 is available. + #[inline(never)] + #[target_feature(enable = "avx2")] + pub unsafe fn take_avx2_u32_u32(buffer: &[u32], indices: &[u32]) -> Buffer { + unsafe { super::super::avx2::take_avx2(buffer, indices) } + } + + /// AVX2 take: u64 values, u32 indices. + /// + /// # Safety + /// + /// Caller must ensure AVX2 is available. + #[inline(never)] + #[target_feature(enable = "avx2")] + pub unsafe fn take_avx2_u64_u32(buffer: &[u64], indices: &[u32]) -> Buffer { + unsafe { super::super::avx2::take_avx2(buffer, indices) } + } +} diff --git a/vortex-compute/src/take/slice/mod.rs b/vortex-compute/src/take/slice/mod.rs index e8f7f81ebaf..cf6b0f0075c 100644 --- a/vortex-compute/src/take/slice/mod.rs +++ b/vortex-compute/src/take/slice/mod.rs @@ -14,6 +14,7 @@ use vortex_dtype::UnsignedPType; use crate::take::Take; +pub mod asm_stubs; pub mod avx2; pub mod portable; @@ -22,19 +23,22 @@ impl Take<[I]> for &[T] { type Output = Buffer; fn take(self, indices: &[I]) -> Buffer { - #[cfg(vortex_nightly)] - { - return portable::take_portable(self, indices); - } - + // Prefer hand-tuned AVX2 on x86 - it's faster than portable SIMD because it uses + // native 32-bit index gathers (`vpgatherdd`) instead of extending to 64-bit. #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { if is_x86_feature_detected!("avx2") { - // SAFETY: We just checked that the AVX2 feature in enabled. + // SAFETY: We just checked that the AVX2 feature is enabled. return unsafe { avx2::take_avx2(self, indices) }; } } + // Fall back to portable SIMD on non-x86 platforms (ARM, RISC-V) or x86 without AVX2. + #[cfg(vortex_nightly)] + { + return portable::take_portable(self, indices); + } + #[allow(unreachable_code, reason = "`vortex_nightly` path returns early")] take_scalar(self, indices) } diff --git a/vortex-compute/src/take/slice/portable.rs b/vortex-compute/src/take/slice/portable.rs index feec9ca312d..ae4a9d9a932 100644 --- a/vortex-compute/src/take/slice/portable.rs +++ b/vortex-compute/src/take/slice/portable.rs @@ -18,9 +18,12 @@ use vortex_buffer::BufferMut; use vortex_dtype::UnsignedPType; use vortex_dtype::match_each_unsigned_integer_ptype; -/// SIMD types larger than the SIMD register size are beneficial for -/// performance as this leads to better instruction level parallelism. -pub const SIMD_WIDTH: usize = 64; +/// SIMD lane count. Benchmarking shows 16 is optimal on AVX-512 systems: +/// - 8 lanes: ~33 µs for 100k elements +/// - 16 lanes: ~31 µs (best) +/// - 32 lanes: ~34 µs +/// - 64 lanes: ~41 µs (causes stack spills) +pub const SIMD_WIDTH: usize = 16; /// Takes the specified indices into a new [`Buffer`] using portable SIMD. /// @@ -87,7 +90,7 @@ fn take_with_indices( /// # Panics /// /// Panics if any index is out of bounds for `values`. -#[multiversion(targets("x86_64+avx2", "x86_64+avx", "aarch64+neon"))] +#[multiversion(targets("x86_64+avx512f+avx512vl", "x86_64+avx2", "x86_64+avx", "aarch64+neon"))] pub fn take_portable_simd(values: &[T], indices: &[I]) -> Buffer where T: Copy + Default + simd::SimdElement, @@ -177,10 +180,10 @@ mod tests { #[test] #[should_panic(expected = "index out of bounds")] fn test_take_out_of_bounds() { - let indices = vec![2_000_000u32; 64]; + let indices = vec![2_000_000u32; 8]; let values = vec![1i32]; - drop(take_portable_simd::(&values, &indices)); + drop(take_portable_simd::(&values, &indices)); } /// Tests SIMD gather with a mix of sequential, strided, and repeated indices. This exercises @@ -206,7 +209,8 @@ mod tests { // Reverse: 255, 254, ..., 216. indices.extend((216u32..256).rev()); - let result = take_portable_simd::(&values, &indices); + // Use 4 lanes for i64 (256-bit / 64-bit = 4). + let result = take_portable_simd::(&values, &indices); let result_slice = result.as_slice(); // Verify sequential portion. @@ -244,25 +248,25 @@ mod tests { fn test_take_with_remainder() { let values: Vec = (0..1000).collect(); - // Use 64 + 37 = 101 indices to test both the SIMD loop (64 elements) and the scalar - // remainder (37 elements). - let indices: Vec = (0u8..101).collect(); + // Use 8 + 5 = 13 indices to test both the SIMD loop (8 elements) and the scalar + // remainder (5 elements). Using 8 lanes for u16 values. + let indices: Vec = (0u8..13).collect(); - let result = take_portable_simd::(&values, &indices); + let result = take_portable_simd::(&values, &indices); let result_slice = result.as_slice(); - assert_eq!(result_slice.len(), 101); + assert_eq!(result_slice.len(), 13); // Verify all elements. - for i in 0..101 { + for i in 0..13 { assert_eq!(result_slice[i], i as u16, "mismatch at index {i}"); } // Also test with exactly 1 remainder element. - let indices_one_remainder: Vec = (0u8..65).collect(); - let result_one = take_portable_simd::(&values, &indices_one_remainder); - assert_eq!(result_one.as_slice().len(), 65); - assert_eq!(result_one.as_slice()[64], 64); + let indices_one_remainder: Vec = (0u8..9).collect(); + let result_one = take_portable_simd::(&values, &indices_one_remainder); + assert_eq!(result_one.as_slice().len(), 9); + assert_eq!(result_one.as_slice()[8], 8); } /// Tests gather with large 64-bit values and various index types to ensure no truncation @@ -283,17 +287,13 @@ mod tests { ]; // Indices that access each value multiple times in different orders. + // Pad to 8 to ensure we hit the SIMD path (4 lanes for i64). let indices: Vec = vec![ - 0, 8, 1, 7, 2, 6, 3, 5, 4, // Forward-backward interleaved. - 8, 8, 8, 0, 0, 0, // Repeated extremes. - 4, 4, 4, 4, 4, 4, 4, 4, // Repeated zero. - 0, 1, 2, 3, 4, 5, 6, 7, 8, // Sequential. - 8, 7, 6, 5, 4, 3, 2, 1, 0, // Reverse. - // Pad to 64 to ensure we hit the SIMD path. - 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, + 0, 8, 1, 7, 2, 6, 3, 5, // 8 indices - exercises SIMD path ]; - let result = take_portable_simd::(&values, &indices); + // Use 4 lanes for i64 (256-bit / 64-bit = 4). + let result = take_portable_simd::(&values, &indices); let result_slice = result.as_slice(); // Verify each result matches the expected value.