From c636a907d6b95d5e478ea77c1fd291ba28e90c1f Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Sun, 15 Feb 2026 21:47:46 -0800 Subject: [PATCH 01/10] Add a Neon backend. --- .github/workflows/ci.yml | 1 + .../src/algorithms/hadamard.rs | 18 + diskann-quantization/src/bits/distances.rs | 29 +- diskann-quantization/src/spherical/iface.rs | 41 +- diskann-vector/src/conversion.rs | 82 +- .../src/distance/distance_provider.rs | 21 +- .../src/distance/implementations.rs | 7 +- diskann-vector/src/distance/simd.rs | 1051 ++++++++++++++++- diskann-wide/compile-aarch64-on-x86.sh | 4 + diskann-wide/src/arch/aarch64/double.rs | 373 ++++++ diskann-wide/src/arch/aarch64/f16x4_.rs | 113 ++ diskann-wide/src/arch/aarch64/f16x8_.rs | 133 +++ diskann-wide/src/arch/aarch64/f32x2_.rs | 92 ++ diskann-wide/src/arch/aarch64/f32x4_.rs | 194 +++ diskann-wide/src/arch/aarch64/i16x8_.rs | 106 ++ diskann-wide/src/arch/aarch64/i32x4_.rs | 259 ++++ diskann-wide/src/arch/aarch64/i64x2_.rs | 117 ++ diskann-wide/src/arch/aarch64/i8x16_.rs | 99 ++ diskann-wide/src/arch/aarch64/i8x8_.rs | 89 ++ diskann-wide/src/arch/aarch64/macros.rs | 575 +++++++++ diskann-wide/src/arch/aarch64/masks.rs | 822 +++++++++++++ diskann-wide/src/arch/aarch64/mod.rs | 382 ++++++ diskann-wide/src/arch/aarch64/u16x8_.rs | 97 ++ diskann-wide/src/arch/aarch64/u32x4_.rs | 162 +++ diskann-wide/src/arch/aarch64/u64x2_.rs | 127 ++ diskann-wide/src/arch/aarch64/u8x16_.rs | 100 ++ diskann-wide/src/arch/aarch64/u8x8_.rs | 89 ++ diskann-wide/src/arch/mod.rs | 24 + diskann-wide/src/doubled.rs | 2 + diskann-wide/src/emulated.rs | 64 + diskann-wide/src/helpers.rs | 14 + diskann-wide/src/lib.rs | 2 - diskann-wide/src/test_utils/dot_product.rs | 134 ++- diskann-wide/src/test_utils/ops.rs | 9 +- diskann-wide/tests/dispatch.rs | 121 +- 35 files changed, 5441 insertions(+), 112 deletions(-) create mode 100755 diskann-wide/compile-aarch64-on-x86.sh create mode 100644 diskann-wide/src/arch/aarch64/double.rs create mode 100644 diskann-wide/src/arch/aarch64/f16x4_.rs create mode 100644 diskann-wide/src/arch/aarch64/f16x8_.rs create mode 100644 diskann-wide/src/arch/aarch64/f32x2_.rs create mode 100644 diskann-wide/src/arch/aarch64/f32x4_.rs create mode 100644 diskann-wide/src/arch/aarch64/i16x8_.rs create mode 100644 diskann-wide/src/arch/aarch64/i32x4_.rs create mode 100644 diskann-wide/src/arch/aarch64/i64x2_.rs create mode 100644 diskann-wide/src/arch/aarch64/i8x16_.rs create mode 100644 diskann-wide/src/arch/aarch64/i8x8_.rs create mode 100644 diskann-wide/src/arch/aarch64/macros.rs create mode 100644 diskann-wide/src/arch/aarch64/masks.rs create mode 100644 diskann-wide/src/arch/aarch64/mod.rs create mode 100644 diskann-wide/src/arch/aarch64/u16x8_.rs create mode 100644 diskann-wide/src/arch/aarch64/u32x4_.rs create mode 100644 diskann-wide/src/arch/aarch64/u64x2_.rs create mode 100644 diskann-wide/src/arch/aarch64/u8x16_.rs create mode 100644 diskann-wide/src/arch/aarch64/u8x8_.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f3e5bafb4..9b9c6f2fe 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -234,6 +234,7 @@ jobs: os: - windows-latest - ubuntu-latest + - ubuntu-24.04-arm steps: - uses: actions/checkout@v4 diff --git a/diskann-quantization/src/algorithms/hadamard.rs b/diskann-quantization/src/algorithms/hadamard.rs index 8871b93cc..385a8bb05 100644 --- a/diskann-quantization/src/algorithms/hadamard.rs +++ b/diskann-quantization/src/algorithms/hadamard.rs @@ -81,6 +81,24 @@ impl } } +#[cfg(target_arch = 
"aarch64")] +impl + diskann_wide::arch::Target1< + diskann_wide::arch::aarch64::Neon, + Result<(), NotPowerOfTwo>, + &mut [f32], + > for HadamardTransform +{ + #[inline(never)] + fn run( + self, + arch: diskann_wide::arch::aarch64::Neon, + x: &mut [f32], + ) -> Result<(), NotPowerOfTwo> { + arch.retarget().run1(HadamardTransformOuter, x) + } +} + //////////////////// // Implementation // //////////////////// diff --git a/diskann-quantization/src/bits/distances.rs b/diskann-quantization/src/bits/distances.rs index b5d5d5766..9dd1359ed 100644 --- a/diskann-quantization/src/bits/distances.rs +++ b/diskann-quantization/src/bits/distances.rs @@ -118,7 +118,6 @@ type USlice<'a, const N: usize, Perm = Dense> = BitSlice<'a, N, Unsigned, Perm>; /// Retarget the [`diskann_wide::arch::x86_64::V3`] architecture to /// [`diskann_wide::arch::Scalar`] or [`diskann_wide::arch::x86_64::V4`] to V3 etc. -#[cfg(target_arch = "x86_64")] macro_rules! retarget { ($arch:path, $op:ty, $N:literal) => { impl Target2< @@ -788,6 +787,18 @@ retarget!(diskann_wide::arch::x86_64::V3, SquaredL2, 7, 6, 5, 3); #[cfg(target_arch = "x86_64")] retarget!(diskann_wide::arch::x86_64::V4, SquaredL2, 7, 6, 5, 4, 3, 2); +#[cfg(target_arch = "aarch64")] +retarget!( + diskann_wide::arch::aarch64::Neon, + SquaredL2, + 7, + 6, + 5, + 4, + 3, + 2 +); + dispatch_pure!(SquaredL2, 1, 2, 3, 4, 5, 6, 7, 8); /////////////////// @@ -1317,6 +1328,18 @@ retarget!(diskann_wide::arch::x86_64::V3, InnerProduct, 7, 6, 5, 3); #[cfg(target_arch = "x86_64")] retarget!(diskann_wide::arch::x86_64::V4, InnerProduct, 7, 6, 4, 5, 3); +#[cfg(target_arch = "aarch64")] +retarget!( + diskann_wide::arch::aarch64::Neon, + InnerProduct, + 7, + 6, + 4, + 5, + 3, + 2 +); + dispatch_pure!(InnerProduct, 1, 2, 3, 4, 5, 6, 7, 8); ////////////////// @@ -1905,7 +1928,6 @@ where } /// Implement `Target2` for higher architecture in terms of the scalar fallback. -#[cfg(target_arch = "x86_64")] macro_rules! ip_retarget { ($arch:path, $N:literal) => { impl Target2<$arch, MathematicalResult, &[f32], USlice<'_, $N>> @@ -1933,6 +1955,9 @@ ip_retarget!(diskann_wide::arch::x86_64::V3, 3, 5, 6, 7, 8); #[cfg(target_arch = "x86_64")] ip_retarget!(diskann_wide::arch::x86_64::V4, 1, 2, 3, 4, 5, 6, 7, 8); +#[cfg(target_arch = "aarch64")] +ip_retarget!(diskann_wide::arch::aarch64::Neon, 1, 2, 3, 4, 5, 6, 7, 8); + /// Delegate the implementation of `PureDistanceFunction` to `diskann_wide::arch::Target2` /// with the current architectures. macro_rules! dispatch_full_ip { diff --git a/diskann-quantization/src/spherical/iface.rs b/diskann-quantization/src/spherical/iface.rs index 6b9e9c349..f44920370 100644 --- a/diskann-quantization/src/spherical/iface.rs +++ b/diskann-quantization/src/spherical/iface.rs @@ -148,6 +148,9 @@ use thiserror::Error; #[cfg(target_arch = "x86_64")] use diskann_wide::arch::x86_64::{V3, V4}; +#[cfg(target_arch = "aarch64")] +use diskann_wide::arch::aarch64::Neon; + use super::{ CompensatedCosine, CompensatedIP, CompensatedSquaredL2, Data, DataMut, DataRef, FullQuery, FullQueryMut, FullQueryRef, Query, QueryMut, QueryRef, SphericalQuantizer, SupportedMetric, @@ -1382,6 +1385,25 @@ cfg_if::cfg_if! 
{ dispatch_map!(2, AsQuery<2>, V4); // specialized dispatch_map!(4, AsQuery<4>, V4, downcast_to_v3); dispatch_map!(8, AsQuery<8>, V4, downcast_to_v3); + } else if #[cfg(target_arch = "aarch64")] { + fn downcast(arch: Neon) -> Scalar { + arch.retarget() + } + + dispatch_map!(1, AsFull, Neon, downcast); + dispatch_map!(2, AsFull, Neon, downcast); + dispatch_map!(4, AsFull, Neon, downcast); + dispatch_map!(8, AsFull, Neon, downcast); + + dispatch_map!(1, AsData<1>, Neon, downcast); + dispatch_map!(2, AsData<2>, Neon, downcast); + dispatch_map!(4, AsData<4>, Neon, downcast); + dispatch_map!(8, AsData<8>, Neon, downcast); + + dispatch_map!(1, AsQuery<4, bits::BitTranspose>, Neon, downcast); + dispatch_map!(2, AsQuery<2>, Neon, downcast); + dispatch_map!(4, AsQuery<4>, Neon, downcast); + dispatch_map!(8, AsQuery<8>, Neon, downcast); } } @@ -1476,14 +1498,29 @@ where { } -#[cfg(not(target_arch = "x86_64"))] +#[cfg(target_arch = "aarch64")] +trait Dispatchable: BuildComputer + BuildComputer +where + Q: FromOpaque, +{ +} + +#[cfg(target_arch = "aarch64")] +impl Dispatchable for T +where + Q: FromOpaque, + T: BuildComputer + BuildComputer, +{ +} + +#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] trait Dispatchable: BuildComputer where Q: FromOpaque, { } -#[cfg(not(target_arch = "x86_64"))] +#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] impl Dispatchable for T where Q: FromOpaque, diff --git a/diskann-vector/src/conversion.rs b/diskann-vector/src/conversion.rs index d77b7df06..3da74c8aa 100644 --- a/diskann-vector/src/conversion.rs +++ b/diskann-vector/src/conversion.rs @@ -5,9 +5,7 @@ use std::convert::{AsMut, AsRef}; -use diskann_wide::arch::Target2; -#[cfg(not(target_arch = "aarch64"))] -use diskann_wide::{Architecture, Const, Constant, SIMDCast, SIMDVector}; +use diskann_wide::{arch::Target2, Architecture, Const, Constant, SIMDCast, SIMDVector}; use half::f16; /// Perform a numeric cast on a slice of values. @@ -119,12 +117,24 @@ where } } +#[cfg(target_arch = "aarch64")] +impl Target2 for SliceCast +where + T: AsMut<[To]>, + U: AsRef<[From]>, + diskann_wide::arch::aarch64::Neon: SIMDConvert, +{ + #[inline(always)] + fn run(self, arch: diskann_wide::arch::aarch64::Neon, mut to: T, from: U) { + simd_convert(arch, to.as_mut(), from.as_ref()) + } +} + ///////////////////////////// // General SIMD Conversion // ///////////////////////////// /// A helper trait to fill in the gaps for the unrolled `simd_convert` method. -#[cfg(target_arch = "x86_64")] trait SIMDConvert: Architecture { /// A constant encoding the the SIMD width of the underlying schema. type Width: Constant; @@ -171,7 +181,6 @@ trait SIMDConvert: Architecture { #[inline(never)] #[allow(clippy::panic)] -#[cfg(target_arch = "x86_64")] fn emit_length_error(xlen: usize, ylen: usize) -> ! { panic!( "lengths must be equal, instead got: xlen = {}, ylen = {}", @@ -206,7 +215,6 @@ fn emit_length_error(xlen: usize, ylen: usize) -> ! { /// This overlapping can only happen at the very end of the slice and only if the length /// of the slice is not a multiple of the SIMD width used. 
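+///
+/// Hypothetical illustration (not taken from this patch): with a SIMD width of
+/// 4 and `from.len() == 6`, the main loop converts elements `0..4` with a full
+/// vector, and the final full-width step may then cover elements `2..6`,
+/// re-converting elements 2 and 3. Converting the same source elements yields
+/// the same destination values, so the overlap is assumed to be harmless;
+/// slices shorter than one vector are assumed to go through `handle_small`.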
#[inline(always)] -#[cfg(target_arch = "x86_64")] fn simd_convert(arch: A, to: &mut [To], from: &[From]) where A: SIMDConvert, @@ -379,6 +387,50 @@ impl SIMDConvert for diskann_wide::arch::x86_64::V3 { } } +//---------// +// Aarch64 // +//---------// + +#[cfg(target_arch = "aarch64")] +impl SIMDConvert for diskann_wide::arch::aarch64::Neon { + type Width = Const<4>; + type WideTo = ::f32x4; + type WideFrom = diskann_wide::arch::aarch64::f16x4; + + #[inline(always)] + fn simd_convert(from: Self::WideFrom) -> Self::WideTo { + from.into() + } + + // SAFETY: We only access data in the valid range for `pto` and `pfrom`. + #[inline(always)] + unsafe fn handle_small(self, pto: *mut f32, pfrom: *const f16, len: usize) { + for i in 0..len { + *pto.add(i) = diskann_wide::cast_f16_to_f32(*pfrom.add(i)) + } + } +} + +#[cfg(target_arch = "aarch64")] +impl SIMDConvert for diskann_wide::arch::aarch64::Neon { + type Width = Const<4>; + type WideTo = diskann_wide::arch::aarch64::f16x4; + type WideFrom = ::f32x4; + + #[inline(always)] + fn simd_convert(from: Self::WideFrom) -> Self::WideTo { + from.simd_cast() + } + + // SAFETY: We only access data in the valid range for `pto` and `pfrom`. + #[inline(always)] + unsafe fn handle_small(self, pto: *mut f16, pfrom: *const f32, len: usize) { + for i in 0..len { + *pto.add(i) = diskann_wide::cast_f32_to_f16(*pfrom.add(i)) + } + } +} + /////////// // Tests // /////////// @@ -480,6 +532,15 @@ mod tests { if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() { SliceCast::::new().run(arch, dst.as_mut_slice(), src.as_slice()) } + + if let Some(arch) = diskann_wide::arch::x86_64::V4::new_checked_miri() { + SliceCast::::new().run(arch, dst.as_mut_slice(), src.as_slice()) + } + } + + #[cfg(target_arch = "aarch64")] + if let Some(arch) = diskann_wide::arch::aarch64::Neon::new_checked() { + SliceCast::::new().run(arch, dst.as_mut_slice(), src.as_slice()) } } } @@ -508,6 +569,15 @@ mod tests { if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() { SliceCast::::new().run(arch, dst.as_mut_slice(), src.as_slice()) } + + if let Some(arch) = diskann_wide::arch::x86_64::V4::new_checked_miri() { + SliceCast::::new().run(arch, dst.as_mut_slice(), src.as_slice()) + } + } + + #[cfg(target_arch = "aarch64")] + if let Some(arch) = diskann_wide::arch::aarch64::Neon::new_checked() { + SliceCast::::new().run(arch, dst.as_mut_slice(), src.as_slice()) } } } diff --git a/diskann-vector/src/distance/distance_provider.rs b/diskann-vector/src/distance/distance_provider.rs index 689c37531..38d746f52 100644 --- a/diskann-vector/src/distance/distance_provider.rs +++ b/diskann-vector/src/distance/distance_provider.rs @@ -3,6 +3,8 @@ * Licensed under the MIT license. */ +#[cfg(target_arch = "aarch64")] +use diskann_wide::arch::aarch64::Neon; #[cfg(target_arch = "x86_64")] use diskann_wide::arch::x86_64::{V3, V4}; use diskann_wide::{ @@ -12,12 +14,9 @@ use diskann_wide::{ }; use half::f16; -use super::{Cosine, CosineNormalized, InnerProduct, SquaredL2}; +use super::{implementations::Specialize, Cosine, CosineNormalized, InnerProduct, SquaredL2}; use crate::distance::Metric; -#[cfg(target_arch = "x86_64")] -use super::implementations::Specialize; - /// Return a function pointer-like [`Distance`] to compute the requested metric. 
/// /// If `dimension` is provided, then the returned function may **only** be used on @@ -277,6 +276,18 @@ mod x86_64 { specialize!(@integer, V4, i8, i8, 128, 100); } +#[cfg(target_arch = "aarch64")] +mod aarch64 { + use super::*; + + specialize!(Neon, f32, f32, 768, 384, 128, 100); + specialize!(Neon, f32, f16, 768, 384, 128, 100); + specialize!(Neon, f16, f16, 768, 384, 128, 100); + + specialize!(@integer, Neon, u8, u8, 128); + specialize!(@integer, Neon, i8, i8, 128, 100); +} + /// Specialize a distance function `F` for the dimension `dim` if possible. Otherwise, /// return `None`. trait TrySpecialize @@ -289,10 +300,8 @@ where } /// Specialize a distance function for the requested dimensionality. -#[cfg(target_arch = "x86_64")] struct Spec; -#[cfg(target_arch = "x86_64")] impl TrySpecialize for Spec where A: Architecture, diff --git a/diskann-vector/src/distance/implementations.rs b/diskann-vector/src/distance/implementations.rs index 88bff6892..c8ac98f49 100644 --- a/diskann-vector/src/distance/implementations.rs +++ b/diskann-vector/src/distance/implementations.rs @@ -70,11 +70,9 @@ macro_rules! architecture_hook { } /// A utility for specializing distance computatiosn for fixed-length slices. -#[cfg(any(test, target_arch = "x86_64"))] #[derive(Debug, Clone, Copy)] pub(crate) struct Specialize(std::marker::PhantomData); -#[cfg(any(test, target_arch = "x86_64"))] impl diskann_wide::arch::FTarget2 for Specialize where @@ -101,7 +99,6 @@ where // Outline the panic formatting and keep the calling convention the same as // the top function. This keeps code generation extremely lightweight. -#[cfg(any(test, target_arch = "x86_64"))] #[inline(never)] #[allow(clippy::panic)] fn fail_length_check(x: &[L], y: &[R], len: usize) -> ! { @@ -116,6 +113,10 @@ fn fail_length_check(x: &[L], y: &[R], len: usize) -> ! { ); } +pub fn test_function(x: &[Half], y: &[Half]) -> f32 { + InnerProduct::evaluate(x, y) +} + /// An internal trait to transform the result of the low-level SIMD ops into a value /// expected by the rest of DiskANN. /// diff --git a/diskann-vector/src/distance/simd.rs b/diskann-vector/src/distance/simd.rs index b83b238a8..2feef2bdc 100644 --- a/diskann-vector/src/distance/simd.rs +++ b/diskann-vector/src/distance/simd.rs @@ -8,11 +8,12 @@ use std::convert::AsRef; #[cfg(target_arch = "x86_64")] use diskann_wide::arch::x86_64::{V3, V4}; -#[cfg(not(target_arch = "aarch64"))] -use diskann_wide::SIMDDotProduct; +#[cfg(target_arch = "aarch64")] +use diskann_wide::arch::aarch64::Neon; + use diskann_wide::{ - arch::Scalar, Architecture, Const, Constant, Emulated, SIMDAbs, SIMDMulAdd, SIMDSumTree, - SIMDVector, + arch::Scalar, Architecture, Const, Constant, Emulated, SIMDAbs, SIMDDotProduct, SIMDMulAdd, + SIMDSumTree, SIMDVector, }; use crate::Half; @@ -34,6 +35,12 @@ impl LossyF32Conversion for i32 { } } +impl LossyF32Conversion for u32 { + fn as_f32_lossy(self) -> f32 { + self as f32 + } +} + cfg_if::cfg_if! { if #[cfg(miri)] { fn force_eval(_x: f32) {} @@ -741,6 +748,34 @@ where schema.reduce(s0) } +//----------// +// Epilogue // +//----------// + +#[cfg(target_arch = "aarch64")] +#[inline(always)] +unsafe fn scalar_epilogue( + left: *const L, + right: *const R, + len: usize, + mut acc: Acc, + mut f: F, +) -> Acc +where + L: Copy, + R: Copy, + F: FnMut(Acc, L, R) -> Acc, +{ + for i in 0..len { + // SAFETY: The range `[x, x.add(len))` is valid for reads. + let left = unsafe { left.add(i).read_unaligned() }; + // SAFETY: The range `[y, y.add(len))` is valid for reads. 
+ let right = unsafe { right.add(i).read_unaligned() }; + acc = f(acc, left, right); + } + acc +} + ///// ///// L2 Implementations ///// @@ -811,6 +846,59 @@ impl SIMDSchema for L2 { } } +#[cfg(target_arch = "aarch64")] +impl SIMDSchema for L2 { + type SIMDWidth = Const<4>; + type Accumulator = ::f32x4; + type Left = ::f32x4; + type Right = ::f32x4; + type Return = f32; + type Main = Strategy4x1; + + #[inline(always)] + fn init(&self, arch: Neon) -> Self::Accumulator { + Self::Accumulator::default(arch) + } + + #[inline(always)] + fn accumulate( + &self, + x: Self::Left, + y: Self::Right, + acc: Self::Accumulator, + ) -> Self::Accumulator { + let c = x - y; + c.mul_add_simd(c, acc) + } + + #[inline(always)] + unsafe fn epilogue( + &self, + arch: Neon, + x: *const f32, + y: *const f32, + len: usize, + acc: Self::Accumulator, + ) -> Self::Accumulator { + let scalar = scalar_epilogue( + x, + y, + len.min(Self::SIMDWidth::value() - 1), + 0.0f32, + |acc, x, y| -> f32 { + let c = x - y; + c.mul_add(c, acc) + }, + ); + acc + Self::Accumulator::from_array(arch, [scalar, 0.0, 0.0, 0.0]) + } + + #[inline(always)] + fn reduce(&self, x: Self::Accumulator) -> Self::Return { + x.sum_tree() + } +} + impl SIMDSchema for L2 { type SIMDWidth = Const<4>; type Accumulator = Emulated; @@ -853,9 +941,9 @@ impl SIMDSchema for L2 { let mut s: f32 = 0.0; for i in 0..len { // SAFETY: The range `[x, x.add(len))` is valid for reads. - let vx = unsafe { x.add(i).read() }; + let vx = unsafe { x.add(i).read_unaligned() }; // SAFETY: The range `[y, y.add(len))` is valid for reads. - let vy = unsafe { y.add(i).read() }; + let vy = unsafe { y.add(i).read_unaligned() }; let d = vx - vy; s += d * d; } @@ -936,6 +1024,69 @@ impl SIMDSchema for L2 { } } +#[cfg(target_arch = "aarch64")] +impl SIMDSchema for L2 { + type SIMDWidth = Const<4>; + type Accumulator = ::f32x4; + type Left = diskann_wide::arch::aarch64::f16x4; + type Right = diskann_wide::arch::aarch64::f16x4; + type Return = f32; + type Main = Strategy4x1; + + #[inline(always)] + fn init(&self, arch: Neon) -> Self::Accumulator { + Self::Accumulator::default(arch) + } + + #[inline(always)] + fn accumulate( + &self, + x: Self::Left, + y: Self::Right, + acc: Self::Accumulator, + ) -> Self::Accumulator { + diskann_wide::alias!(f32s = ::f32x4); + + let x: f32s = x.into(); + let y: f32s = y.into(); + + let c = x - y; + c.mul_add_simd(c, acc) + } + + #[inline(always)] + unsafe fn epilogue( + &self, + arch: Neon, + x: *const Half, + y: *const Half, + len: usize, + acc: Self::Accumulator, + ) -> Self::Accumulator { + diskann_wide::alias!(f32s = ::f32x4); + + let rest = scalar_epilogue( + x, + y, + len.min(Self::SIMDWidth::value() - 1), + f32s::default(arch), + |acc, x: Half, y: Half| -> f32s { + let zero = Half::default(); + let x: f32s = Self::Left::from_array(arch, [x, zero, zero, zero]).into(); + let y: f32s = Self::Right::from_array(arch, [y, zero, zero, zero]).into(); + let c: f32s = x - y; + c.mul_add_simd(c, acc) + }, + ); + acc + rest + } + + #[inline(always)] + fn reduce(&self, x: Self::Accumulator) -> Self::Return { + x.sum_tree() + } +} + impl SIMDSchema for L2 { type SIMDWidth = Const<1>; type Accumulator = Emulated; @@ -1075,6 +1226,64 @@ impl SIMDSchema for L2 { } } +#[cfg(target_arch = "aarch64")] +impl SIMDSchema for L2 { + type SIMDWidth = Const<8>; + type Accumulator = ::i32x4; + type Left = diskann_wide::arch::aarch64::i8x8; + type Right = diskann_wide::arch::aarch64::i8x8; + type Return = f32; + type Main = Strategy4x1; + + #[inline(always)] + fn 
init(&self, arch: Neon) -> Self::Accumulator { + Self::Accumulator::default(arch) + } + + #[inline(always)] + fn accumulate( + &self, + x: Self::Left, + y: Self::Right, + acc: Self::Accumulator, + ) -> Self::Accumulator { + diskann_wide::alias!(i16s = ::i16x8); + + let x: i16s = x.into(); + let y: i16s = y.into(); + let c = x - y; + acc.dot_simd(c, c) + } + + #[inline(always)] + unsafe fn epilogue( + &self, + arch: Neon, + x: *const i8, + y: *const i8, + len: usize, + acc: Self::Accumulator, + ) -> Self::Accumulator { + let scalar = scalar_epilogue( + x, + y, + len.min(Self::SIMDWidth::value() - 1), + 0i32, + |acc, x: i8, y: i8| -> i32 { + let c = (x as i32) - (y as i32); + acc + c * c + }, + ); + acc + Self::Accumulator::from_array(arch, [scalar, 0, 0, 0]) + } + + // Perform a final reduction. + #[inline(always)] + fn reduce(&self, x: Self::Accumulator) -> Self::Return { + x.sum_tree().as_f32_lossy() + } +} + impl SIMDSchema for L2 { type SIMDWidth = Const<4>; type Accumulator = Emulated; @@ -1119,9 +1328,9 @@ impl SIMDSchema for L2 { let mut s: i32 = 0; for i in 0..len { // SAFETY: The range `[x, x.add(len))` is valid for reads. - let vx: i32 = unsafe { x.add(i).read() }.into(); + let vx: i32 = unsafe { x.add(i).read_unaligned() }.into(); // SAFETY: The range `[y, y.add(len))` is valid for reads. - let vy: i32 = unsafe { y.add(i).read() }.into(); + let vy: i32 = unsafe { y.add(i).read_unaligned() }.into(); let d = vx - vy; s += d * d; } @@ -1200,6 +1409,64 @@ impl SIMDSchema for L2 { } } +#[cfg(target_arch = "aarch64")] +impl SIMDSchema for L2 { + type SIMDWidth = Const<8>; + type Accumulator = ::i32x4; + type Left = diskann_wide::arch::aarch64::u8x8; + type Right = diskann_wide::arch::aarch64::u8x8; + type Return = f32; + type Main = Strategy4x1; + + #[inline(always)] + fn init(&self, arch: Neon) -> Self::Accumulator { + Self::Accumulator::default(arch) + } + + #[inline(always)] + fn accumulate( + &self, + x: Self::Left, + y: Self::Right, + acc: Self::Accumulator, + ) -> Self::Accumulator { + diskann_wide::alias!(i16s = ::i16x8); + + let x: i16s = x.into(); + let y: i16s = y.into(); + let c = x - y; + acc.dot_simd(c, c) + } + + #[inline(always)] + unsafe fn epilogue( + &self, + arch: Neon, + x: *const u8, + y: *const u8, + len: usize, + acc: Self::Accumulator, + ) -> Self::Accumulator { + let scalar = scalar_epilogue( + x, + y, + len.min(Self::SIMDWidth::value() - 1), + 0i32, + |acc, x: u8, y: u8| -> i32 { + let c = (x as i32) - (y as i32); + acc + c * c + }, + ); + acc + Self::Accumulator::from_array(arch, [scalar, 0, 0, 0]) + } + + // Perform a final reduction. + #[inline(always)] + fn reduce(&self, x: Self::Accumulator) -> Self::Return { + x.sum_tree().as_f32_lossy() + } +} + impl SIMDSchema for L2 { type SIMDWidth = Const<4>; type Accumulator = Emulated; @@ -1244,9 +1511,9 @@ impl SIMDSchema for L2 { let mut s: i32 = 0; for i in 0..len { // SAFETY: The range `[x, x.add(len))` is valid for reads. - let vx: i32 = unsafe { x.add(i).read() }.into(); + let vx: i32 = unsafe { x.add(i).read_unaligned() }.into(); // SAFETY: The range `[y, y.add(len))` is valid for reads. 
- let vy: i32 = unsafe { y.add(i).read() }.into(); + let vy: i32 = unsafe { y.add(i).read_unaligned() }.into(); let d = vx - vy; s += d * d; } @@ -1360,6 +1627,55 @@ impl SIMDSchema for IP { } } +#[cfg(target_arch = "aarch64")] +impl SIMDSchema for IP { + type SIMDWidth = Const<4>; + type Accumulator = ::f32x4; + type Left = ::f32x4; + type Right = ::f32x4; + type Return = f32; + type Main = Strategy4x1; + + #[inline(always)] + fn init(&self, arch: Neon) -> Self::Accumulator { + Self::Accumulator::default(arch) + } + + #[inline(always)] + fn accumulate( + &self, + x: Self::Left, + y: Self::Right, + acc: Self::Accumulator, + ) -> Self::Accumulator { + x.mul_add_simd(y, acc) + } + + #[inline(always)] + unsafe fn epilogue( + &self, + arch: Neon, + x: *const f32, + y: *const f32, + len: usize, + acc: Self::Accumulator, + ) -> Self::Accumulator { + let scalar = scalar_epilogue( + x, + y, + len.min(Self::SIMDWidth::value() - 1), + 0.0f32, + |acc, x: f32, y: f32| -> f32 { x.mul_add(y, acc) }, + ); + acc + Self::Accumulator::from_array(arch, [scalar, 0.0, 0.0, 0.0]) + } + + #[inline(always)] + fn reduce(&self, x: Self::Accumulator) -> Self::Return { + x.sum_tree() + } +} + impl SIMDSchema for IP { type SIMDWidth = Const<4>; type Accumulator = Emulated; @@ -1401,9 +1717,9 @@ impl SIMDSchema for IP { let mut s: f32 = 0.0; for i in 0..len { // SAFETY: The range `[x, x.add(len))` is valid for reads. - let vx = unsafe { x.add(i).read() }; + let vx = unsafe { x.add(i).read_unaligned() }; // SAFETY: The range `[y, y.add(len))` is valid for reads. - let vy = unsafe { y.add(i).read() }; + let vy = unsafe { y.add(i).read_unaligned() }; s += vx * vy; } acc + Self::Accumulator::from_array(arch, [s, 0.0, 0.0, 0.0]) @@ -1479,6 +1795,67 @@ impl SIMDSchema for IP { } } +#[cfg(target_arch = "aarch64")] +impl SIMDSchema for IP { + type SIMDWidth = Const<4>; + type Accumulator = ::f32x4; + type Left = diskann_wide::arch::aarch64::f16x4; + type Right = diskann_wide::arch::aarch64::f16x4; + type Return = f32; + type Main = Strategy4x1; + + #[inline(always)] + fn init(&self, arch: Neon) -> Self::Accumulator { + Self::Accumulator::default(arch) + } + + #[inline(always)] + fn accumulate( + &self, + x: Self::Left, + y: Self::Right, + acc: Self::Accumulator, + ) -> Self::Accumulator { + diskann_wide::alias!(f32s = ::f32x4); + + let x: f32s = x.into(); + let y: f32s = y.into(); + + x.mul_add_simd(y, acc) + } + + #[inline(always)] + unsafe fn epilogue( + &self, + arch: Neon, + x: *const Half, + y: *const Half, + len: usize, + acc: Self::Accumulator, + ) -> Self::Accumulator { + diskann_wide::alias!(f32s = ::f32x4); + + let rest = scalar_epilogue( + x, + y, + len.min(Self::SIMDWidth::value() - 1), + f32s::default(arch), + |acc, x: Half, y: Half| -> f32s { + let zero = Half::default(); + let x: f32s = Self::Left::from_array(arch, [x, zero, zero, zero]).into(); + let y: f32s = Self::Right::from_array(arch, [y, zero, zero, zero]).into(); + x.mul_add_simd(y, acc) + }, + ); + acc + rest + } + + #[inline(always)] + fn reduce(&self, x: Self::Accumulator) -> Self::Return { + x.sum_tree() + } +} + impl SIMDSchema for IP { type SIMDWidth = Const<1>; type Accumulator = Emulated; @@ -1613,6 +1990,55 @@ impl SIMDSchema for IP { } } +#[cfg(target_arch = "aarch64")] +impl SIMDSchema for IP { + type SIMDWidth = Const<16>; + type Accumulator = ::i32x4; + type Left = ::i8x16; + type Right = ::i8x16; + type Return = f32; + type Main = Strategy4x1; + + #[inline(always)] + fn init(&self, arch: Neon) -> Self::Accumulator { + 
Self::Accumulator::default(arch) + } + + #[inline(always)] + fn accumulate( + &self, + x: Self::Left, + y: Self::Right, + acc: Self::Accumulator, + ) -> Self::Accumulator { + acc.dot_simd(x, y) + } + + #[inline(always)] + unsafe fn epilogue( + &self, + arch: Neon, + x: *const i8, + y: *const i8, + len: usize, + acc: Self::Accumulator, + ) -> Self::Accumulator { + let scalar = scalar_epilogue( + x, + y, + len.min(Self::SIMDWidth::value() - 1), + 0i32, + |acc, x: i8, y: i8| -> i32 { acc + (x as i32) * (y as i32) }, + ); + acc + Self::Accumulator::from_array(arch, [scalar, 0, 0, 0]) + } + + #[inline(always)] + fn reduce(&self, x: Self::Accumulator) -> Self::Return { + x.sum_tree().as_f32_lossy() + } +} + impl SIMDSchema for IP { type SIMDWidth = Const<1>; type Accumulator = Emulated; @@ -1728,6 +2154,55 @@ impl SIMDSchema for IP { } } +#[cfg(target_arch = "aarch64")] +impl SIMDSchema for IP { + type SIMDWidth = Const<16>; + type Accumulator = ::u32x4; + type Left = ::u8x16; + type Right = ::u8x16; + type Return = f32; + type Main = Strategy4x2; + + #[inline(always)] + fn init(&self, arch: Neon) -> Self::Accumulator { + Self::Accumulator::default(arch) + } + + #[inline(always)] + fn accumulate( + &self, + x: Self::Left, + y: Self::Right, + acc: Self::Accumulator, + ) -> Self::Accumulator { + acc.dot_simd(x, y) + } + + #[inline(always)] + unsafe fn epilogue( + &self, + arch: Neon, + x: *const u8, + y: *const u8, + len: usize, + acc: Self::Accumulator, + ) -> Self::Accumulator { + let scalar = scalar_epilogue( + x, + y, + len.min(Self::SIMDWidth::value() - 1), + 0u32, + |acc, x: u8, y: u8| -> u32 { acc + (x as u32) * (y as u32) }, + ); + acc + Self::Accumulator::from_array(arch, [scalar, 0, 0, 0]) + } + + #[inline(always)] + fn reduce(&self, x: Self::Accumulator) -> Self::Return { + x.sum_tree().as_f32_lossy() + } +} + impl SIMDSchema for IP { type SIMDWidth = Const<1>; type Accumulator = Emulated; @@ -1997,6 +2472,69 @@ impl SIMDSchema for CosineStateless { } } +#[cfg(target_arch = "aarch64")] +impl SIMDSchema for CosineStateless { + type SIMDWidth = Const<4>; + type Accumulator = FullCosineAccumulator<::f32x4>; + type Left = ::f32x4; + type Right = ::f32x4; + type Return = f32; + + // Cosine accumulators are pretty large, so only use 2 parallel accumulator with a + // hefty unroll factor. + type Main = Strategy2x4; + + #[inline(always)] + fn init(&self, arch: Neon) -> Self::Accumulator { + Self::Accumulator::new(arch) + } + + #[inline(always)] + fn accumulate( + &self, + x: Self::Left, + y: Self::Right, + acc: Self::Accumulator, + ) -> Self::Accumulator { + acc.add_with(x, y) + } + + #[inline(always)] + unsafe fn epilogue( + &self, + arch: Neon, + x: *const f32, + y: *const f32, + len: usize, + acc: Self::Accumulator, + ) -> Self::Accumulator { + let mut xx: f32 = 0.0; + let mut yy: f32 = 0.0; + let mut xy: f32 = 0.0; + for i in 0..len.min(Self::SIMDWidth::value() - 1) { + // SAFETY: The range `[x, x.add(len))` is valid for reads. + let vx = unsafe { x.add(i).read_unaligned() }; + // SAFETY: The range `[y, y.add(len))` is valid for reads. + let vy = unsafe { y.add(i).read_unaligned() }; + xx = vx.mul_add(vx, xx); + yy = vy.mul_add(vy, yy); + xy = vx.mul_add(vy, xy); + } + type V = ::f32x4; + acc + FullCosineAccumulator { + normx: V::from_array(arch, [xx, 0.0, 0.0, 0.0]), + normy: V::from_array(arch, [yy, 0.0, 0.0, 0.0]), + xy: V::from_array(arch, [xy, 0.0, 0.0, 0.0]), + } + } + + // Perform a final reduction. 
+ #[inline(always)] + fn reduce(&self, acc: Self::Accumulator) -> Self::Return { + acc.sum() + } +} + impl SIMDSchema for CosineStateless { type SIMDWidth = Const<4>; type Accumulator = FullCosineAccumulator>; @@ -2096,6 +2634,67 @@ impl SIMDSchema for CosineStateless { } } +#[cfg(target_arch = "aarch64")] +impl SIMDSchema for CosineStateless { + type SIMDWidth = Const<4>; + type Accumulator = FullCosineAccumulator<::f32x4>; + type Left = diskann_wide::arch::aarch64::f16x4; + type Right = diskann_wide::arch::aarch64::f16x4; + type Return = f32; + + type Main = Strategy2x4; + + #[inline(always)] + fn init(&self, arch: Neon) -> Self::Accumulator { + Self::Accumulator::new(arch) + } + + #[inline(always)] + fn accumulate( + &self, + x: Self::Left, + y: Self::Right, + acc: Self::Accumulator, + ) -> Self::Accumulator { + diskann_wide::alias!(f32s = ::f32x4); + + let x: f32s = x.into(); + let y: f32s = y.into(); + acc.add_with(x, y) + } + + #[inline(always)] + unsafe fn epilogue( + &self, + arch: Neon, + x: *const Half, + y: *const Half, + len: usize, + acc: Self::Accumulator, + ) -> Self::Accumulator { + type V = ::f32x4; + + let rest = scalar_epilogue( + x, + y, + len.min(Self::SIMDWidth::value() - 1), + FullCosineAccumulator::::new(arch), + |acc, x: Half, y: Half| -> FullCosineAccumulator { + let zero = Half::default(); + let x: V = Self::Left::from_array(arch, [x, zero, zero, zero]).into(); + let y: V = Self::Right::from_array(arch, [y, zero, zero, zero]).into(); + acc.add_with(x, y) + }, + ); + acc + rest + } + + #[inline(always)] + fn reduce(&self, acc: Self::Accumulator) -> Self::Return { + acc.sum() + } +} + impl SIMDSchema for CosineStateless { type SIMDWidth = Const<1>; type Accumulator = FullCosineAccumulator>; @@ -2232,7 +2831,70 @@ impl SIMDSchema for CosineStateless { } } - // Perform a final reduction. + // Perform a final reduction. + #[inline(always)] + fn reduce(&self, x: Self::Accumulator) -> Self::Return { + x.sum() + } +} + +#[cfg(target_arch = "aarch64")] +impl SIMDSchema for CosineStateless { + type SIMDWidth = Const<16>; + type Accumulator = FullCosineAccumulator<::i32x4>; + type Left = ::i8x16; + type Right = ::i8x16; + type Return = f32; + type Main = Strategy4x1; + + #[inline(always)] + fn init(&self, arch: Neon) -> Self::Accumulator { + Self::Accumulator::new(arch) + } + + #[inline(always)] + fn accumulate( + &self, + x: Self::Left, + y: Self::Right, + acc: Self::Accumulator, + ) -> Self::Accumulator { + FullCosineAccumulator { + normx: acc.normx.dot_simd(x, x), + normy: acc.normy.dot_simd(y, y), + xy: acc.xy.dot_simd(x, y), + } + } + + #[inline(always)] + unsafe fn epilogue( + &self, + arch: Neon, + x: *const i8, + y: *const i8, + len: usize, + acc: Self::Accumulator, + ) -> Self::Accumulator { + let mut xx: i32 = 0; + let mut yy: i32 = 0; + let mut xy: i32 = 0; + for i in 0..len.min(Self::SIMDWidth::value() - 1) { + // SAFETY: The range `[x, x.add(len))` is valid for reads. + let vx: i32 = unsafe { x.add(i).read_unaligned() }.into(); + // SAFETY: The range `[y, y.add(len))` is valid for reads. 
+ let vy: i32 = unsafe { y.add(i).read_unaligned() }.into(); + xx += vx * vx; + xy += vx * vy; + yy += vy * vy; + } + type V = ::i32x4; + acc + FullCosineAccumulator { + normx: V::from_array(arch, [xx, 0, 0, 0]), + normy: V::from_array(arch, [yy, 0, 0, 0]), + xy: V::from_array(arch, [xy, 0, 0, 0]), + } + } + #[inline(always)] fn reduce(&self, x: Self::Accumulator) -> Self::Return { x.sum() @@ -2285,9 +2947,9 @@ impl SIMDSchema for CosineStateless { for i in 0..len { // SAFETY: The range `[x, x.add(len))` is valid for reads. - let vx: i32 = unsafe { x.add(i).read() }.into(); + let vx: i32 = unsafe { x.add(i).read_unaligned() }.into(); // SAFETY: The range `[y, y.add(len))` is valid for reads. - let vy: i32 = unsafe { y.add(i).read() }.into(); + let vy: i32 = unsafe { y.add(i).read_unaligned() }.into(); xx += vx * vx; xy += vx * vy; @@ -2382,6 +3044,69 @@ impl SIMDSchema for CosineStateless { } } +#[cfg(target_arch = "aarch64")] +impl SIMDSchema for CosineStateless { + type SIMDWidth = Const<16>; + type Accumulator = FullCosineAccumulator<::u32x4>; + type Left = ::u8x16; + type Right = ::u8x16; + type Return = f32; + type Main = Strategy4x1; + + #[inline(always)] + fn init(&self, arch: Neon) -> Self::Accumulator { + Self::Accumulator::new(arch) + } + + #[inline(always)] + fn accumulate( + &self, + x: Self::Left, + y: Self::Right, + acc: Self::Accumulator, + ) -> Self::Accumulator { + FullCosineAccumulator { + normx: acc.normx.dot_simd(x, x), + normy: acc.normy.dot_simd(y, y), + xy: acc.xy.dot_simd(x, y), + } + } + + #[inline(always)] + unsafe fn epilogue( + &self, + arch: Neon, + x: *const u8, + y: *const u8, + len: usize, + acc: Self::Accumulator, + ) -> Self::Accumulator { + let mut xx: u32 = 0; + let mut yy: u32 = 0; + let mut xy: u32 = 0; + for i in 0..len.min(Self::SIMDWidth::value() - 1) { + // SAFETY: The range `[x, x.add(len))` is valid for reads. + let vx: u32 = unsafe { x.add(i).read_unaligned() }.into(); + // SAFETY: The range `[y, y.add(len))` is valid for reads. + let vy: u32 = unsafe { y.add(i).read_unaligned() }.into(); + xx += vx * vx; + xy += vx * vy; + yy += vy * vy; + } + type V = ::u32x4; + acc + FullCosineAccumulator { + normx: V::from_array(arch, [xx, 0, 0, 0]), + normy: V::from_array(arch, [yy, 0, 0, 0]), + xy: V::from_array(arch, [xy, 0, 0, 0]), + } + } + + #[inline(always)] + fn reduce(&self, x: Self::Accumulator) -> Self::Return { + x.sum() + } +} + impl SIMDSchema for CosineStateless { type SIMDWidth = Const<4>; type Accumulator = FullCosineAccumulator>; @@ -2428,9 +3153,9 @@ impl SIMDSchema for CosineStateless { for i in 0..len { // SAFETY: The range `[x, x.add(len))` is valid for reads. - let vx: i32 = unsafe { x.add(i).read() }.into(); + let vx: i32 = unsafe { x.add(i).read_unaligned() }.into(); // SAFETY: The range `[y, y.add(len))` is valid for reads. 
- let vy: i32 = unsafe { y.add(i).read() }.into(); + let vy: i32 = unsafe { y.add(i).read_unaligned() }.into(); xx += vx * vx; xy += vx * vy; @@ -2560,6 +3285,54 @@ impl SIMDSchema for L1Norm { } } +#[cfg(target_arch = "aarch64")] +impl SIMDSchema for L1Norm { + type SIMDWidth = Const<4>; + type Accumulator = ::f32x4; + type Left = ::f32x4; + type Right = ::f32x4; + type Return = f32; + type Main = Strategy4x1; + + #[inline(always)] + fn init(&self, arch: Neon) -> Self::Accumulator { + Self::Accumulator::default(arch) + } + + #[inline(always)] + fn accumulate( + &self, + x: Self::Left, + _y: Self::Right, + acc: Self::Accumulator, + ) -> Self::Accumulator { + x.abs_simd() + acc + } + + #[inline(always)] + unsafe fn epilogue( + &self, + arch: Neon, + x: *const f32, + _y: *const f32, + len: usize, + acc: Self::Accumulator, + ) -> Self::Accumulator { + let mut s: f32 = 0.0; + for i in 0..len.min(Self::SIMDWidth::value() - 1) { + // SAFETY: The range `[x, x.add(len))` is valid for reads. + let vx = unsafe { x.add(i).read_unaligned() }; + s += vx.abs(); + } + acc + Self::Accumulator::from_array(arch, [s, 0.0, 0.0, 0.0]) + } + + #[inline(always)] + fn reduce(&self, x: Self::Accumulator) -> Self::Return { + x.sum_tree() + } +} + impl SIMDSchema for L1Norm { type SIMDWidth = Const<4>; type Accumulator = Emulated; @@ -2601,7 +3374,7 @@ impl SIMDSchema for L1Norm { let mut s: f32 = 0.0; for i in 0..len { // SAFETY: The range `[x, x.add(len))` is valid for reads. - let vx = unsafe { x.add(i).read() }; + let vx = unsafe { x.add(i).read_unaligned() }; s += vx.abs(); } acc + Self::Accumulator::from_array(arch, [s, 0.0, 0.0, 0.0]) @@ -2672,6 +3445,62 @@ impl SIMDSchema for L1Norm { } } +#[cfg(target_arch = "aarch64")] +impl SIMDSchema for L1Norm { + type SIMDWidth = Const<4>; + type Accumulator = ::f32x4; + type Left = diskann_wide::arch::aarch64::f16x4; + type Right = diskann_wide::arch::aarch64::f16x4; + type Return = f32; + type Main = Strategy2x4; + + #[inline(always)] + fn init(&self, arch: Neon) -> Self::Accumulator { + Self::Accumulator::default(arch) + } + + #[inline(always)] + fn accumulate( + &self, + x: Self::Left, + _y: Self::Right, + acc: Self::Accumulator, + ) -> Self::Accumulator { + let x: ::f32x4 = x.into(); + x.abs_simd() + acc + } + + #[inline(always)] + unsafe fn epilogue( + &self, + arch: Neon, + x: *const Half, + _y: *const Half, + len: usize, + acc: Self::Accumulator, + ) -> Self::Accumulator { + let rest = scalar_epilogue( + x, + x, // unused, but scalar_epilogue requires a right pointer + len.min(Self::SIMDWidth::value() - 1), + Self::Accumulator::default(arch), + |acc, x: Half, _: Half| -> Self::Accumulator { + let zero = Half::default(); + let x: Self::Accumulator = + Self::Left::from_array(arch, [x, zero, zero, zero]).into(); + x.abs_simd() + acc + }, + ); + acc + rest + } + + // Perform a final reduction. 
+ #[inline(always)] + fn reduce(&self, x: Self::Accumulator) -> Self::Return { + x.sum_tree() + } +} + impl SIMDSchema for L1Norm { type SIMDWidth = Const<1>; type Accumulator = Emulated; @@ -3138,6 +3967,19 @@ mod tests { V4::new_checked_miri() ); + #[cfg(target_arch = "aarch64")] + float_test!( + test_l2_f32_aarch64_neon, + L2, + ResumableL2, + reference::reference_squared_l2_f32_mathematical, + 1e-5, + 1e-5, + 0xf149c2bcde660128, + 256, + Neon::new_checked() + ); + //----// // IP // //----// @@ -3192,6 +4034,19 @@ mod tests { V4::new_checked_miri() ); + #[cfg(target_arch = "aarch64")] + float_test!( + test_ip_f32_aarch64_neon, + IP, + ResumableIP, + reference::reference_innerproduct_f32_mathematical, + 2e-4, + 1e-3, + 0xb4687c17a9ea9866, + 256, + Neon::new_checked() + ); + //--------// // Cosine // //--------// @@ -3246,6 +4101,19 @@ mod tests { V4::new_checked_miri() ); + #[cfg(target_arch = "aarch64")] + float_test!( + test_cosine_f32_aarch64_neon, + CosineStateless, + ResumableCosine, + reference::reference_cosine_f32_mathematical, + 1e-5, + 1e-5, + 0xe860e9dc65f38bb8, + 256, + Neon::new_checked() + ); + ///////// // f16 // ///////// @@ -3339,6 +4207,18 @@ mod tests { V4::new_checked_miri() ); + #[cfg(target_arch = "aarch64")] + half_test!( + test_l2_f16_aarch64_neon, + L2, + reference::reference_squared_l2_f16_mathematical, + 1e-5, + 1e-5, + 0x87ca6f1051667500, + 256, + Neon::new_checked() + ); + //----// // IP // //----// @@ -3389,6 +4269,18 @@ mod tests { V4::new_checked_miri() ); + #[cfg(target_arch = "aarch64")] + half_test!( + test_ip_f16_aarch64_neon, + IP, + reference::reference_innerproduct_f16_mathematical, + 2e-4, + 2e-4, + 0x5909f5f20307ccbe, + 256, + Neon::new_checked() + ); + //--------// // Cosine // //--------// @@ -3439,6 +4331,18 @@ mod tests { V4::new_checked_miri() ); + #[cfg(target_arch = "aarch64")] + half_test!( + test_cosine_f16_aarch64_neon, + CosineStateless, + reference::reference_cosine_f16_mathematical, + 1e-5, + 1e-5, + 0x41dda34655f05ef6, + 256, + Neon::new_checked() + ); + ///////////// // Integer // ///////////// @@ -3520,6 +4424,17 @@ mod tests { { V4::new_checked_miri() } ); + #[cfg(target_arch = "aarch64")] + int_test!( + test_l2_u8_aarch64_neon, + u8, + L2, + reference::reference_squared_l2_u8_mathematical, + 0x74c86334ab7a51f9, + 320, + { Neon::new_checked() } + ); + int_test!( test_ip_u8_current, u8, @@ -3562,6 +4477,17 @@ mod tests { { V4::new_checked_miri() } ); + #[cfg(target_arch = "aarch64")] + int_test!( + test_ip_u8_aarch64_neon, + u8, + IP, + reference::reference_innerproduct_u8_mathematical, + 0x888e07fc489e773f, + 320, + { Neon::new_checked() } + ); + int_test!( test_cosine_u8_current, u8, @@ -3604,6 +4530,17 @@ mod tests { { V4::new_checked_miri() } ); + #[cfg(target_arch = "aarch64")] + int_test!( + test_cosine_u8_aarch64_neon, + u8, + CosineStateless, + reference::reference_cosine_u8_mathematical, + 0xcc258c9391733211, + 320, + { Neon::new_checked() } + ); + //----// // I8 // //----// @@ -3650,6 +4587,17 @@ mod tests { { V4::new_checked_miri() } ); + #[cfg(target_arch = "aarch64")] + int_test!( + test_l2_i8_aarch64_neon, + i8, + L2, + reference::reference_squared_l2_i8_mathematical, + 0x3e8bada709e176be, + 320, + { Neon::new_checked() } + ); + int_test!( test_ip_i8_current, i8, @@ -3692,6 +4640,17 @@ mod tests { { V4::new_checked_miri() } ); + #[cfg(target_arch = "aarch64")] + int_test!( + test_ip_i8_aarch64_neon, + i8, + IP, + reference::reference_innerproduct_i8_mathematical, + 0x8a263408c7b31d85, + 320, + { Neon::new_checked() } 
+ ); + int_test!( test_cosine_i8_current, i8, @@ -3734,6 +4693,17 @@ mod tests { { V4::new_checked_miri() } ); + #[cfg(target_arch = "aarch64")] + int_test!( + test_cosine_i8_aarch64_neon, + i8, + CosineStateless, + reference::reference_cosine_i8_mathematical, + 0x2d077bed2629b18e, + 320, + { Neon::new_checked() } + ); + ////////// // LInf // ////////// @@ -3803,6 +4773,18 @@ mod tests { V4::new_checked_miri() ); + #[cfg(target_arch = "aarch64")] + linf_test!( + test_linf_f32_neon, + f32, + reference::reference_linf_f32_mathematical, + 1e-6, + 1e-6, + 0xf149c2bcde660128, + 256, + Neon::new_checked() + ); + linf_test!( test_linf_f16_scalar, f16, @@ -3838,6 +4820,18 @@ mod tests { V4::new_checked_miri() ); + #[cfg(target_arch = "aarch64")] + linf_test!( + test_linf_f16_neon, + f16, + reference::reference_linf_f16_mathematical, + 1e-6, + 1e-6, + 0xf149c2bcde660128, + 256, + Neon::new_checked() + ); + //////////////// // Miri Tests // //////////////// @@ -3885,6 +4879,7 @@ mod tests { X86_64_V3, #[expect(non_camel_case_types)] X86_64_V4, + Aarch64Neon, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -3901,25 +4896,30 @@ mod tests { } static MIRI_BOUNDS: LazyLock> = LazyLock::new(|| { - use Arch::{Scalar, X86_64_V3, X86_64_V4}; + use Arch::{Aarch64Neon, Scalar, X86_64_V3, X86_64_V4}; use DataType::{Float16, Float32, Int8, UInt8}; [ (Key::new(Scalar, Float32, Float32), 64), (Key::new(X86_64_V3, Float32, Float32), 256), (Key::new(X86_64_V4, Float32, Float32), 256), + (Key::new(Aarch64Neon, Float32, Float32), 128), (Key::new(Scalar, Float16, Float16), 64), (Key::new(X86_64_V3, Float16, Float16), 256), (Key::new(X86_64_V4, Float16, Float16), 256), + (Key::new(Aarch64Neon, Float16, Float16), 128), (Key::new(Scalar, Float32, Float16), 64), (Key::new(X86_64_V3, Float32, Float16), 256), (Key::new(X86_64_V4, Float32, Float16), 256), + (Key::new(Aarch64Neon, Float32, Float16), 128), (Key::new(Scalar, UInt8, UInt8), 64), (Key::new(X86_64_V3, UInt8, UInt8), 256), (Key::new(X86_64_V4, UInt8, UInt8), 320), + (Key::new(Aarch64Neon, UInt8, UInt8), 128), (Key::new(Scalar, Int8, Int8), 64), (Key::new(X86_64_V3, Int8, Int8), 256), (Key::new(X86_64_V4, Int8, Int8), 320), + (Key::new(Aarch64Neon, Int8, Int8), 128), ] .into_iter() .collect() @@ -3980,6 +4980,19 @@ mod tests { simd_op(&CosineStateless, arch, left.as_slice(), right.as_slice()); } } + + #[cfg(target_arch = "aarch64")] + if let Some(arch) = Neon::new_checked() { + let max = MIRI_BOUNDS[&Key::new(Arch::Aarch64Neon, left_type, right_type)]; + for dim in 0..max { + let left: Vec<$left> = vec![left; dim]; + let right: Vec<$right> = vec![right; dim]; + + simd_op(&L2, arch, left.as_slice(), right.as_slice()); + simd_op(&IP, arch, left.as_slice(), right.as_slice()); + simd_op(&CosineStateless, arch, left.as_slice(), right.as_slice()); + } + } } }; } diff --git a/diskann-wide/compile-aarch64-on-x86.sh b/diskann-wide/compile-aarch64-on-x86.sh new file mode 100755 index 000000000..8d4be102f --- /dev/null +++ b/diskann-wide/compile-aarch64-on-x86.sh @@ -0,0 +1,4 @@ +RUSTFLAGS="-Ctarget-feature=+neon,+dotprod -Clinker=aarch64-linux-gnu-gcc" cargo test \ + --package "$1" \ + --target aarch64-unknown-linux-gnu \ + --profile ci diff --git a/diskann-wide/src/arch/aarch64/double.rs b/diskann-wide/src/arch/aarch64/double.rs new file mode 100644 index 000000000..5880fbc1a --- /dev/null +++ b/diskann-wide/src/arch/aarch64/double.rs @@ -0,0 +1,373 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ + +use half::f16; + +use crate::{ + LoHi, SplitJoin, + doubled::{self, Doubled}, +}; + +use super::{ + f16x8, f32x4, i8x16, i16x8, i32x4, + masks::{mask8x16, mask16x8, mask32x4, mask64x2}, + u8x16, u32x4, u64x2, +}; + +// Double Masks +doubled::double_mask!(32, mask8x16); +doubled::double_mask!(16, mask16x8); +doubled::double_mask!(8, mask32x4); +doubled::double_mask!(4, mask64x2); + +// Double-Double Masks +doubled::double_mask!(64, Doubled); +doubled::double_mask!(32, Doubled); +doubled::double_mask!(16, Doubled); + +macro_rules! double_alias { + ($type:ident, $scalar:ty, $lanes:literal, $subtype:ty) => { + // Implement `SIMDVector` and friends for the `Double` type. + doubled::double_vector!($scalar, $lanes, $subtype); + + #[allow(non_camel_case_types)] + pub type $type = Doubled<$subtype>; + }; +} + +// Double Wide +double_alias!(f32x8, f32, 8, f32x4); + +double_alias!(u8x32, u8, 32, u8x16); +double_alias!(u32x8, u32, 8, u32x4); +double_alias!(u64x4, u64, 4, u64x2); + +double_alias!(i8x32, i8, 32, i8x16); +double_alias!(i16x16, i16, 16, i16x8); +double_alias!(i32x8, i32, 8, i32x4); + +doubled::double_scalar_shift!(Doubled); +doubled::double_scalar_shift!(Doubled); +doubled::double_scalar_shift!(Doubled); + +doubled::double_scalar_shift!(Doubled); +doubled::double_scalar_shift!(Doubled); +doubled::double_scalar_shift!(Doubled); + +// Double-Double Wide +double_alias!(f32x16, f32, 16, f32x8); +double_alias!(f16x16, f16, 16, f16x8); + +double_alias!(u8x64, u8, 64, u8x32); +double_alias!(u32x16, u32, 16, u32x8); + +double_alias!(i8x64, i8, 64, i8x32); +double_alias!(i16x32, i16, 32, i16x16); +double_alias!(i32x16, i32, 16, i32x8); + +doubled::double_scalar_shift!(Doubled>); +doubled::double_scalar_shift!(Doubled>); +doubled::double_scalar_shift!(Doubled>); + +doubled::double_scalar_shift!(Doubled>); +doubled::double_scalar_shift!(Doubled>); +doubled::double_scalar_shift!(Doubled>); + +//-------------// +// Conversions // +//-------------// + +// Lossless +impl From for f32x8 { + #[inline(always)] + fn from(value: f16x8) -> Self { + let LoHi { lo, hi } = value.split(); + Self::new(lo.into(), hi.into()) + } +} + +impl From for f32x16 { + #[inline(always)] + fn from(value: f16x16) -> Self { + Self::new(value.0.into(), value.1.into()) + } +} + +impl From for i16x16 { + #[inline(always)] + fn from(value: u8x16) -> Self { + let LoHi { lo, hi } = value.split(); + Self::new(lo.into(), hi.into()) + } +} + +impl From for i16x32 { + #[inline(always)] + fn from(value: u8x32) -> Self { + Self::new(value.0.into(), value.1.into()) + } +} + +impl From for i16x16 { + #[inline(always)] + fn from(value: i8x16) -> Self { + let LoHi { lo, hi } = value.split(); + Self::new(lo.into(), hi.into()) + } +} + +impl From for i16x32 { + #[inline(always)] + fn from(value: i8x32) -> Self { + Self::new(value.0.into(), value.1.into()) + } +} + +// (Potentially) Lossy +impl crate::SIMDCast for f16x16 { + type Cast = f32x16; + #[inline(always)] + fn simd_cast(self) -> f32x16 { + self.into() + } +} + +impl crate::SIMDCast for f32x8 { + type Cast = f16x8; + #[inline(always)] + fn simd_cast(self) -> f16x8 { + f16x8::join(LoHi::new(self.0.simd_cast(), self.1.simd_cast())) + } +} + +impl crate::SIMDCast for f32x16 { + type Cast = f16x16; + #[inline(always)] + fn simd_cast(self) -> f16x16 { + f16x16::new(self.0.simd_cast(), self.1.simd_cast()) + } +} + +impl crate::SIMDCast for i32x8 { + type Cast = f32x8; + #[inline(always)] + fn simd_cast(self) -> f32x8 { + f32x8::new(self.0.simd_cast(), self.1.simd_cast()) + } +} + 
+/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + use crate::{arch::aarch64::Neon, reference::ReferenceScalarOps, test_utils}; + + // Run a standard set of: + // - Load + // - Store + // - Add, Sub, Mul, FMA + // - SIMDPartialEq, SIMDPartialCmp + macro_rules! standard_tests { + ($type:ident, $scalar:ty, $lanes:literal) => { + #[test] + fn miri_test_load() { + test_utils::test_load_simd::<$scalar, $lanes, $type>(Neon::new_checked().unwrap()); + } + + #[test] + fn miri_test_store() { + test_utils::test_store_simd::<$scalar, $lanes, $type>(Neon::new_checked().unwrap()); + } + + #[test] + fn test_constructors() { + test_utils::ops::test_splat::<$scalar, $lanes, $type>(Neon::new_checked().unwrap()); + } + + test_utils::ops::test_add!($type, 0x1c08175714ae637e, Neon::new_checked()); + test_utils::ops::test_sub!($type, 0x3746ddcb006b7b4c, Neon::new_checked()); + test_utils::ops::test_mul!($type, 0xde99e62aaea3f38a, Neon::new_checked()); + test_utils::ops::test_fma!($type, 0x2e301b7e12090d5c, Neon::new_checked()); + + test_utils::ops::test_cmp!($type, 0x90a59e23ad545de1, Neon::new_checked()); + }; + } + + // f32s + mod test_f32x8 { + use super::*; + standard_tests!(f32x8, f32, 8); + test_utils::ops::test_sumtree!(f32x8, 0x90a59e23ad545de1, Neon::new_checked()); + test_utils::ops::test_splitjoin!(f32x8 => f32x4, 0x2e301b7e12090d5c, Neon::new_checked()); + } + + mod test_f32x16 { + use super::*; + standard_tests!(f32x16, f32, 16); + test_utils::ops::test_sumtree!(f32x16, 0x90a59e23ad545de1, Neon::new_checked()); + test_utils::ops::test_splitjoin!(f32x16 => f32x8, 0x2e301b7e12090d5c, Neon::new_checked()); + } + + // u8s + mod test_u8x32 { + use super::*; + standard_tests!(u8x32, u8, 32); + + // Bit ops + test_utils::ops::test_bitops!(u8x32, 0xd62d8de09f82ed4e, Neon::new_checked()); + } + + mod test_u8x64 { + use super::*; + standard_tests!(u8x64, u8, 64); + + // Bit ops + test_utils::ops::test_bitops!(u8x64, 0xd62d8de09f82ed4e, Neon::new_checked()); + } + + // u32s + mod test_u32x8 { + use super::*; + standard_tests!(u32x8, u32, 8); + + // Bit ops + test_utils::ops::test_bitops!(u32x8, 0xd62d8de09f82ed4e, Neon::new_checked()); + + // Reductions + test_utils::ops::test_sumtree!(u32x8, 0x90a59e23ad545de1, Neon::new_checked()); + } + + mod test_u32x16 { + use super::*; + standard_tests!(u32x16, u32, 16); + + // Bit ops + test_utils::ops::test_bitops!(u32x16, 0xd62d8de09f82ed4e, Neon::new_checked()); + + // Reductions + test_utils::ops::test_sumtree!(u32x16, 0x90a59e23ad545de1, Neon::new_checked()); + } + + // u64s + mod test_u64x4 { + use super::*; + standard_tests!(u64x4, u64, 4); + + // Bit ops + test_utils::ops::test_bitops!(u64x4, 0xc4491a44af4aa58e, Neon::new_checked()); + } + + // i8s + mod test_i8x32 { + use super::*; + standard_tests!(i8x32, i8, 32); + + // Bit ops + test_utils::ops::test_bitops!(i8x32, 0xd62d8de09f82ed4e, Neon::new_checked()); + } + + mod test_i8x64 { + use super::*; + standard_tests!(i8x64, i8, 64); + + // Bit ops + test_utils::ops::test_bitops!(i8x64, 0xd62d8de09f82ed4e, Neon::new_checked()); + } + + // i16s + mod test_i16x16 { + use super::*; + standard_tests!(i16x16, i16, 16); + + // Bit ops + test_utils::ops::test_bitops!(i16x16, 0x9167644fc4ad5cfa, Neon::new_checked()); + } + + mod test_i16x32 { + use super::*; + standard_tests!(i16x32, i16, 32); + + // Bit ops + test_utils::ops::test_bitops!(i16x32, 0x9167644fc4ad5cfa, Neon::new_checked()); + } + + // i32s + mod test_i32x8 { + use super::*; + standard_tests!(i32x8, i32, 8); + + // Bit 
ops + test_utils::ops::test_bitops!(i32x8, 0xc4491a44af4aa58e, Neon::new_checked()); + + // Dot Products + test_utils::dot_product::test_dot_product!( + (i16x16, i16x16) => i32x8, + 0x145f89b446c03ff1, + Neon::new_checked() + ); + + test_utils::dot_product::test_dot_product!( + (u8x32, i8x32) => i32x8, + 0x145f89b446c03ff1, + Neon::new_checked() + ); + + test_utils::dot_product::test_dot_product!( + (i8x32, u8x32) => i32x8, + 0x145f89b446c03ff1, + Neon::new_checked() + ); + + // Reductions + test_utils::ops::test_sumtree!(i32x8, 0x90a59e23ad545de1, Neon::new_checked()); + } + + mod test_i32x16 { + use super::*; + standard_tests!(i32x16, i32, 16); + + // Bit ops + test_utils::ops::test_bitops!(i32x16, 0xc4491a44af4aa58e, Neon::new_checked()); + + // Dot Products + test_utils::dot_product::test_dot_product!( + (i16x32, i16x32) => i32x16, + 0x145f89b446c03ff1, + Neon::new_checked() + ); + + test_utils::dot_product::test_dot_product!( + (u8x64, i8x64) => i32x16, + 0x145f89b446c03ff1, + Neon::new_checked() + ); + + test_utils::dot_product::test_dot_product!( + (i8x64, u8x64) => i32x16, + 0x145f89b446c03ff1, + Neon::new_checked() + ); + + // Reductions + test_utils::ops::test_sumtree!(i32x16, 0x90a59e23ad545de1, Neon::new_checked()); + } + + // Conversions + test_utils::ops::test_lossless_convert!(f16x8 => f32x8, 0x84c1c6f05b169a20, Neon::new_checked()); + test_utils::ops::test_lossless_convert!(f16x16 => f32x16, 0x84c1c6f05b169a20, Neon::new_checked()); + + test_utils::ops::test_lossless_convert!(u8x16 => i16x16, 0x84c1c6f05b169a20, Neon::new_checked()); + test_utils::ops::test_lossless_convert!(i8x16 => i16x16, 0x84c1c6f05b169a20, Neon::new_checked()); + + test_utils::ops::test_cast!(f16x8 => f32x8, 0xba8fe343fc9dbeff, Neon::new_checked()); + test_utils::ops::test_cast!(f16x16 => f32x16, 0xba8fe343fc9dbeff, Neon::new_checked()); + test_utils::ops::test_cast!(f32x8 => f16x8, 0xba8fe343fc9dbeff, Neon::new_checked()); + test_utils::ops::test_cast!(f32x16 => f16x16, 0xba8fe343fc9dbeff, Neon::new_checked()); + + test_utils::ops::test_cast!(i32x8 => f32x8, 0xba8fe343fc9dbeff, Neon::new_checked()); +} diff --git a/diskann-wide/src/arch/aarch64/f16x4_.rs b/diskann-wide/src/arch/aarch64/f16x4_.rs new file mode 100644 index 000000000..b40814695 --- /dev/null +++ b/diskann-wide/src/arch/aarch64/f16x4_.rs @@ -0,0 +1,113 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ + +use crate::{ + Emulated, + arch::Scalar, + constant::Const, + traits::{SIMDMask, SIMDVector}, +}; + +use half::f16; + +// AArch64 masks +use super::{ + Neon, + macros::{self, AArchLoadStore, AArchSplat}, + masks::mask16x4, +}; + +// AArch64 intrinsics +use std::arch::aarch64::*; + +///////////////////// +// 16-bit floating // +///////////////////// + +macros::aarch64_define_register!(f16x4, uint16x4_t, mask16x4, f16, 4, Neon); + +impl AArchSplat for f16x4 { + #[inline(always)] + fn aarch_splat(_: Neon, value: f16) -> Self { + // SAFETY: Allowed by the `Neon` architecture. + Self(unsafe { vmov_n_u16(value.to_bits()) }) + } + + #[inline(always)] + fn aarch_default(arch: Neon) -> Self { + Self::aarch_splat(arch, f16::default()) + } +} + +impl AArchLoadStore for f16x4 { + #[inline(always)] + unsafe fn load_simd(_: Neon, ptr: *const f16) -> Self { + // SAFETY: Allowed by the `Neon` architecture. 
+ Self(unsafe { vld1_u16(ptr.cast::()) }) + } + + #[inline(always)] + unsafe fn load_simd_masked_logical(arch: Neon, ptr: *const f16, mask: Self::Mask) -> Self { + // SAFETY: Pointer access safety inhereted from the caller. + let e = unsafe { + Emulated::::load_simd_masked_logical(Scalar, ptr, mask.bitmask().as_scalar()) + }; + Self::from_array(arch, e.to_array()) + } + + #[inline(always)] + unsafe fn load_simd_first(arch: Neon, ptr: *const f16, first: usize) -> Self { + // SAFETY: Pointer access safety inhereted from the caller. + let e = unsafe { Emulated::::load_simd_first(Scalar, ptr, first) }; + Self::from_array(arch, e.to_array()) + } + + #[inline(always)] + unsafe fn store_simd(self, ptr: *mut ::Scalar) { + // SAFETY: Pointer access safety inhereted from the caller. Use of the instruction + // is allowed by the `Neon` architecture. + unsafe { vst1_u16(ptr.cast::(), self.0) } + } + + #[inline(always)] + unsafe fn store_simd_masked_logical(self, ptr: *mut f16, mask: Self::Mask) { + let e = Emulated::::from_array(Scalar, self.to_array()); + // SAFETY: Pointer access safety inhereted from the caller. + unsafe { e.store_simd_masked_logical(ptr, mask.bitmask().as_scalar()) } + } + + #[inline(always)] + unsafe fn store_simd_first(self, ptr: *mut f16, first: usize) { + let e = Emulated::::from_array(Scalar, self.to_array()); + // SAFETY: Pointer access safety inhereted from the caller. + unsafe { e.store_simd_first(ptr, first) } + } +} + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_utils; + + #[test] + fn miri_test_load() { + test_utils::test_load_simd::(Neon::new_checked().unwrap()); + } + + #[test] + fn miri_test_store() { + test_utils::test_store_simd::(Neon::new_checked().unwrap()); + } + + // constructors + #[test] + fn test_constructors() { + test_utils::ops::test_splat::(Neon::new_checked().unwrap()); + } +} diff --git a/diskann-wide/src/arch/aarch64/f16x8_.rs b/diskann-wide/src/arch/aarch64/f16x8_.rs new file mode 100644 index 000000000..a688b4060 --- /dev/null +++ b/diskann-wide/src/arch/aarch64/f16x8_.rs @@ -0,0 +1,133 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ + +use crate::{ + Emulated, + arch::Scalar, + constant::Const, + traits::{SIMDMask, SIMDVector}, +}; + +use half::f16; + +// AArch64 masks +use super::{ + Neon, f16x4, f32x8, + macros::{self, AArchLoadStore, AArchSplat}, + masks::mask16x8, +}; + +// AArch64 intrinsics +use std::arch::aarch64::*; + +///////////////////// +// 16-bit floating // +///////////////////// + +macros::aarch64_define_register!(f16x8, uint16x8_t, mask16x8, f16, 8, Neon); +macros::aarch64_splitjoin!(f16x8, f16x4, vget_low_u16, vget_high_u16, vcombine_u16); + +impl AArchSplat for f16x8 { + #[inline(always)] + fn aarch_splat(_: Neon, value: f16) -> Self { + // SAFETY: Allowed by the `Neon` architecture. + Self(unsafe { vmovq_n_u16(value.to_bits()) }) + } + + #[inline(always)] + fn aarch_default(arch: Neon) -> Self { + Self::aarch_splat(arch, f16::default()) + } +} + +impl AArchLoadStore for f16x8 { + #[inline(always)] + unsafe fn load_simd(_: Neon, ptr: *const f16) -> Self { + // SAFETY: Pointer access safety inhereted from the caller.Allowed by the `Neon` + // architecture. + Self(unsafe { vld1q_u16(ptr.cast::()) }) + } + + #[inline(always)] + unsafe fn load_simd_masked_logical(arch: Neon, ptr: *const f16, mask: Self::Mask) -> Self { + // SAFETY: Pointer access safety inhereted from the caller. 
+ let e = unsafe { + Emulated::::load_simd_masked_logical(Scalar, ptr, mask.bitmask().as_scalar()) + }; + Self::from_array(arch, e.to_array()) + } + + #[inline(always)] + unsafe fn load_simd_first(arch: Neon, ptr: *const f16, first: usize) -> Self { + // SAFETY: Pointer access safety inhereted from the caller. + let e = unsafe { Emulated::::load_simd_first(Scalar, ptr, first) }; + Self::from_array(arch, e.to_array()) + } + + #[inline(always)] + unsafe fn store_simd(self, ptr: *mut ::Scalar) { + // SAFETY: Pointer access safety inhereted from the caller. Use of the instruction + // is allowed by the `Neon` architecture. + unsafe { vst1q_u16(ptr.cast::(), self.0) } + } + + #[inline(always)] + unsafe fn store_simd_masked_logical(self, ptr: *mut f16, mask: Self::Mask) { + let e = Emulated::::from_array(Scalar, self.to_array()); + // SAFETY: Pointer access safety inhereted from the caller. + unsafe { e.store_simd_masked_logical(ptr, mask.bitmask().as_scalar()) } + } + + #[inline(always)] + unsafe fn store_simd_first(self, ptr: *mut f16, first: usize) { + let e = Emulated::::from_array(Scalar, self.to_array()); + // SAFETY: Pointer access safety inhereted from the caller. + unsafe { e.store_simd_first(ptr, first) } + } +} + +//------------// +// Conversion // +//------------// + +impl crate::SIMDCast for f16x8 { + type Cast = f32x8; + + #[inline(always)] + fn simd_cast(self) -> f32x8 { + self.into() + } +} + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_utils; + + #[test] + fn miri_test_load() { + test_utils::test_load_simd::(Neon::new_checked().unwrap()); + } + + #[test] + fn miri_test_store() { + test_utils::test_store_simd::(Neon::new_checked().unwrap()); + } + + // constructors + #[test] + fn test_constructors() { + test_utils::ops::test_splat::(Neon::new_checked().unwrap()); + } + + test_utils::ops::test_splitjoin!(f16x8 => f16x4, 0xa4d00a4d04293967, Neon::new_checked()); + + // Conversions + test_utils::ops::test_cast!(f16x8 => f32x8, 0x37314659b022466a, Neon::new_checked()); +} diff --git a/diskann-wide/src/arch/aarch64/f32x2_.rs b/diskann-wide/src/arch/aarch64/f32x2_.rs new file mode 100644 index 000000000..3bf1176ed --- /dev/null +++ b/diskann-wide/src/arch/aarch64/f32x2_.rs @@ -0,0 +1,92 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */
+
+use crate::{
+    Emulated,
+    constant::Const,
+    helpers,
+    traits::{SIMDMask, SIMDMulAdd, SIMDPartialEq, SIMDPartialOrd, SIMDSumTree, SIMDVector},
+};
+
+// AArch64 masks
+use super::{
+    Neon,
+    macros::{self, AArchLoadStore, AArchSplat},
+    masks::mask32x2,
+};
+
+// AArch64 intrinsics
+use std::arch::aarch64::*;
+
+/////////////////////
+// 32-bit floating //
+/////////////////////
+
+macros::aarch64_define_register!(f32x2, float32x2_t, mask32x2, f32, 2, Neon);
+macros::aarch64_define_splat!(f32x2, vmov_n_f32);
+macros::aarch64_define_loadstore!(f32x2, vld1_f32, vst1_f32, 2);
+
+helpers::unsafe_map_binary_op!(f32x2, std::ops::Add, add, vadd_f32, "neon");
+helpers::unsafe_map_binary_op!(f32x2, std::ops::Sub, sub, vsub_f32, "neon");
+helpers::unsafe_map_binary_op!(f32x2, std::ops::Mul, mul, vmul_f32, "neon");
+macros::aarch64_define_fma!(f32x2, vfma_f32);
+
+macros::aarch64_define_cmp!(
+    f32x2,
+    vceq_f32,
+    (vmvn_u32),
+    vclt_f32,
+    vcle_f32,
+    vcgt_f32,
+    vcge_f32
+);
+
+impl SIMDSumTree for f32x2 {
+    #[inline(always)]
+    fn sum_tree(self) -> f32 {
+        if cfg!(miri) {
+            self.emulated().sum_tree()
+        } else {
+            // SAFETY: This file is gated by the "neon" target feature.
+            unsafe { vaddv_f32(self.to_underlying()) }
+        }
+    }
+}
+
+///////////
+// Tests //
+///////////
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{reference::ReferenceScalarOps, test_utils};
+
+    #[test]
+    fn miri_test_load() {
+        test_utils::test_load_simd::<f32x2>(Neon::new_checked().unwrap());
+    }
+
+    #[test]
+    fn miri_test_store() {
+        test_utils::test_store_simd::<f32x2>(Neon::new_checked().unwrap());
+    }
+
+    // constructors
+    #[test]
+    fn test_constructors() {
+        test_utils::ops::test_splat::<f32x2>(Neon::new_checked().unwrap());
+    }
+
+    // Ops
+    test_utils::ops::test_add!(f32x2, 0xcd7a8fea9a3fb727, Neon::new_checked());
+    test_utils::ops::test_sub!(f32x2, 0x3f6562c94c923238, Neon::new_checked());
+    test_utils::ops::test_mul!(f32x2, 0x07e48666c0fc564c, Neon::new_checked());
+    test_utils::ops::test_fma!(f32x2, 0xcfde9d031302cf2c, Neon::new_checked());
+
+    test_utils::ops::test_cmp!(f32x2, 0xc4f468b224622326, Neon::new_checked());
+
+    test_utils::ops::test_sumtree!(f32x2, 0x828bd890a470dc4d, Neon::new_checked());
+}
diff --git a/diskann-wide/src/arch/aarch64/f32x4_.rs b/diskann-wide/src/arch/aarch64/f32x4_.rs
new file mode 100644
index 000000000..aefc0dfd8
--- /dev/null
+++ b/diskann-wide/src/arch/aarch64/f32x4_.rs
@@ -0,0 +1,194 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */ + +use half::f16; + +use crate::{ + Emulated, SIMDAbs, SIMDMask, SIMDMinMax, SIMDMulAdd, SIMDPartialEq, SIMDPartialOrd, SIMDSelect, + SIMDSumTree, SIMDVector, constant::Const, helpers, +}; + +// AArch64 masks +use super::{ + Neon, f16x4, f32x2, + macros::{self, AArchLoadStore, AArchSplat}, + masks::mask32x4, +}; + +// AArch64 intrinsics +use std::arch::{aarch64::*, asm}; + +///////////////////// +// 32-bit floating // +///////////////////// + +macros::aarch64_define_register!(f32x4, float32x4_t, mask32x4, f32, 4, Neon); +macros::aarch64_define_splat!(f32x4, vmovq_n_f32); +macros::aarch64_define_loadstore!(f32x4, vld1q_f32, vst1q_f32, 4); +macros::aarch64_splitjoin!(f32x4, f32x2, vget_low_f32, vget_high_f32, vcombine_f32); + +helpers::unsafe_map_binary_op!(f32x4, std::ops::Add, add, vaddq_f32, "neon"); +helpers::unsafe_map_binary_op!(f32x4, std::ops::Sub, sub, vsubq_f32, "neon"); +helpers::unsafe_map_binary_op!(f32x4, std::ops::Mul, mul, vmulq_f32, "neon"); +helpers::unsafe_map_unary_op!(f32x4, SIMDAbs, abs_simd, vabsq_f32, "neon"); +macros::aarch64_define_fma!(f32x4, vfmaq_f32); + +impl SIMDMinMax for f32x4 { + #[inline(always)] + fn min_simd(self, rhs: Self) -> Self { + Self(unsafe { vminnmq_f32(self.0, rhs.0) }) + } + + #[inline(always)] + fn min_simd_standard(self, rhs: Self) -> Self { + Self(unsafe { vminnmq_f32(self.0, rhs.0) }) + } + + #[inline(always)] + fn max_simd(self, rhs: Self) -> Self { + Self(unsafe { vmaxnmq_f32(self.0, rhs.0) }) + } + + #[inline(always)] + fn max_simd_standard(self, rhs: Self) -> Self { + Self(unsafe { vmaxnmq_f32(self.0, rhs.0) }) + } +} + +macros::aarch64_define_cmp!( + f32x4, + vceqq_f32, + (vmvnq_u32), + vcltq_f32, + vcleq_f32, + vcgtq_f32, + vcgeq_f32 +); + +impl SIMDSumTree for f32x4 { + #[inline(always)] + fn sum_tree(self) -> f32 { + // Miri does not support `vaddv_f32`. + if cfg!(miri) { + self.emulated().sum_tree() + } else { + // NOTE: `vaddvq` does not do a tree reduction, so we need to do a bit of work + // manually. + let x = self.to_underlying(); + // SAFETY: Allowed by the implicit `Neon` architecture. + unsafe { + let low = vget_low_f32(x); + let high = vget_high_f32(x); + vaddv_f32(vadd_f32(low, high)) + } + } + } +} + +impl SIMDSelect for mask32x4 { + #[inline(always)] + fn select(self, x: f32x4, y: f32x4) -> f32x4 { + // SAFETY: Allowed by the implicit `Neon` architecture. + f32x4(unsafe { vbslq_f32(self.0, x.0, y.0) }) + } +} + +//------------// +// Conversion // +//------------// + +// Rust does not expose any of the f16 style intrinsics, so we need to drop down straight +// into inline assembly. +impl From for f32x4 { + #[inline(always)] + fn from(value: f16x4) -> f32x4 { + if cfg!(miri) { + Self::from_array(value.arch(), value.to_array().map(crate::cast_f16_to_f32)) + } else { + let raw = value.0; + let result: float32x4_t; + // SAFETY: The instruction we are running is available with the `neon` platform, + // just not exposed by Rust's intrinsics. + unsafe { + asm!( + "fcvtl {0:v}.4s, {1:v}.4h", + out(vreg) result, + in(vreg) raw, + options(pure, nomem, nostack) + ); + } + Self(result) + } + } +} + +impl crate::SIMDCast for f32x4 { + type Cast = f16x4; + #[inline(always)] + fn simd_cast(self) -> f16x4 { + if cfg!(miri) { + f16x4::from_array(self.arch(), self.to_array().map(crate::cast_f32_to_f16)) + } else { + let raw = self.0; + let result: uint16x4_t; + // SAFETY: The instruction we are running is available with the `neon` platform, + // just not exposed by Rust's intrinsics. 
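+            // `fcvtn` narrows four f32 lanes to four f16 lanes (rounding per the FPCR,
+            // round-to-nearest-even by default), writing them to the low 64 bits of the
+            // destination; it is the counterpart of the widening `fcvtl` used by the
+            // `From<f16x4>` conversion above.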
+ unsafe { + asm!( + "fcvtn {0:v}.4h, {1:v}.4s", + out(vreg) result, + in(vreg) raw, + options(pure, nomem, nostack) + ); + } + f16x4::from_underlying(self.arch(), result) + } + } +} + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + use crate::{reference::ReferenceScalarOps, test_utils}; + + #[test] + fn miri_test_load() { + test_utils::test_load_simd::(Neon::new_checked().unwrap()); + } + + #[test] + fn miri_test_store() { + test_utils::test_store_simd::(Neon::new_checked().unwrap()); + } + + // constructors + #[test] + fn test_constructors() { + test_utils::ops::test_splat::(Neon::new_checked().unwrap()); + } + + // Ops + test_utils::ops::test_add!(f32x4, 0xcd7a8fea9a3fb727, Neon::new_checked()); + test_utils::ops::test_sub!(f32x4, 0x3f6562c94c923238, Neon::new_checked()); + test_utils::ops::test_mul!(f32x4, 0x07e48666c0fc564c, Neon::new_checked()); + test_utils::ops::test_fma!(f32x4, 0xcfde9d031302cf2c, Neon::new_checked()); + test_utils::ops::test_abs!(f32x4, 0xb8f702ba85375041, Neon::new_checked()); + test_utils::ops::test_minmax!(f32x4, 0x6d7fc8ed6d852187, Neon::new_checked()); + test_utils::ops::test_splitjoin!(f32x4 => f32x2, 0xa4d00a4d04293967, Neon::new_checked()); + + test_utils::ops::test_cmp!(f32x4, 0xc4f468b224622326, Neon::new_checked()); + test_utils::ops::test_select!(f32x4, 0xef24013b8578637c, Neon::new_checked()); + + test_utils::ops::test_sumtree!(f32x4, 0x828bd890a470dc4d, Neon::new_checked()); + + // Conversions + test_utils::ops::test_lossless_convert!(f16x4 => f32x4, 0xecba3008eae54ce7, Neon::new_checked()); + + test_utils::ops::test_cast!(f32x4 => f16x4, 0xba8fe343fc9dbeff, Neon::new_checked()); +} diff --git a/diskann-wide/src/arch/aarch64/i16x8_.rs b/diskann-wide/src/arch/aarch64/i16x8_.rs new file mode 100644 index 000000000..bd14b633d --- /dev/null +++ b/diskann-wide/src/arch/aarch64/i16x8_.rs @@ -0,0 +1,106 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ + +use crate::{ + Emulated, SIMDAbs, SIMDMask, SIMDMulAdd, SIMDPartialEq, SIMDPartialOrd, SIMDVector, + constant::Const, helpers, +}; + +// AArch64 masks +use super::{ + Neon, i8x8, + macros::{self, AArchLoadStore, AArchSplat}, + masks::mask16x8, + u8x8, +}; + +// AArch64 intrinsics +use std::arch::aarch64::*; + +/////////////////// +// 16-bit signed // +/////////////////// + +macros::aarch64_define_register!(i16x8, int16x8_t, mask16x8, i16, 8, Neon); +macros::aarch64_define_splat!(i16x8, vmovq_n_s16); +macros::aarch64_define_loadstore!(i16x8, vld1q_s16, vst1q_s16, 8); + +helpers::unsafe_map_binary_op!(i16x8, std::ops::Add, add, vaddq_s16, "neon"); +helpers::unsafe_map_binary_op!(i16x8, std::ops::Sub, sub, vsubq_s16, "neon"); +helpers::unsafe_map_binary_op!(i16x8, std::ops::Mul, mul, vmulq_s16, "neon"); +helpers::unsafe_map_unary_op!(i16x8, SIMDAbs, abs_simd, vabsq_s16, "neon"); +macros::aarch64_define_fma!(i16x8, vmlaq_s16); + +macros::aarch64_define_cmp!( + i16x8, + vceqq_s16, + (vmvnq_u16), + vcltq_s16, + vcleq_s16, + vcgtq_s16, + vcgeq_s16 +); +macros::aarch64_define_bitops!( + i16x8, + vmvnq_s16, + vandq_s16, + vorrq_s16, + veorq_s16, + ( + vshlq_s16, + 16, + vnegq_s16, + vminq_u16, + vreinterpretq_s16_u16, + vreinterpretq_u16_s16 + ), + (u16, i16, vmovq_n_s16), +); + +// Conversion +helpers::unsafe_map_conversion!(i8x8, i16x8, vmovl_s8, "neon"); +helpers::unsafe_map_conversion!(u8x8, i16x8, (vreinterpretq_s16_u16, vmovl_u8), "neon"); + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + use crate::{reference::ReferenceScalarOps, test_utils}; + + #[test] + fn miri_test_load() { + test_utils::test_load_simd::(Neon::new_checked().unwrap()); + } + + #[test] + fn miri_test_store() { + test_utils::test_store_simd::(Neon::new_checked().unwrap()); + } + + // constructors + #[test] + fn test_constructors() { + test_utils::ops::test_splat::(Neon::new_checked().unwrap()); + } + + // Ops + test_utils::ops::test_add!(i16x8, 0x3017fd73c99cc633, Neon::new_checked()); + test_utils::ops::test_sub!(i16x8, 0xfc627f10b5f8db8a, Neon::new_checked()); + test_utils::ops::test_mul!(i16x8, 0x0f4caa80eceaa523, Neon::new_checked()); + test_utils::ops::test_fma!(i16x8, 0xb8f702ba85375041, Neon::new_checked()); + test_utils::ops::test_abs!(i16x8, 0xb8f702ba85375041, Neon::new_checked()); + + test_utils::ops::test_cmp!(i16x8, 0x941757bd5cc641a1, Neon::new_checked()); + + // Bit ops + test_utils::ops::test_bitops!(i16x8, 0xd62d8de09f82ed4e, Neon::new_checked()); + + // Conversion + test_utils::ops::test_lossless_convert!(i8x8 => i16x8, 0x79458ca52356242e, Neon::new_checked()); + test_utils::ops::test_lossless_convert!(u8x8 => i16x8, 0xa9a57c5c541ce360, Neon::new_checked()); +} diff --git a/diskann-wide/src/arch/aarch64/i32x4_.rs b/diskann-wide/src/arch/aarch64/i32x4_.rs new file mode 100644 index 000000000..2c7b905ea --- /dev/null +++ b/diskann-wide/src/arch/aarch64/i32x4_.rs @@ -0,0 +1,259 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ + +use crate::{ + Emulated, SIMDAbs, SIMDCast, SIMDDotProduct, SIMDMask, SIMDMulAdd, SIMDPartialEq, + SIMDPartialOrd, SIMDSelect, SIMDSumTree, SIMDVector, constant::Const, helpers, +}; + +// AArch64 masks +use super::{ + Neon, f32x4, i8x8, i8x16, i16x8, + macros::{self, AArchLoadStore, AArchSplat}, + masks::mask32x4, + u8x8, u8x16, +}; + +// AArch64 intrinsics +use std::arch::{aarch64::*, asm}; + +/////////////////// +// 32-bit signed // +/////////////////// + +macros::aarch64_define_register!(i32x4, int32x4_t, mask32x4, i32, 4, Neon); +macros::aarch64_define_splat!(i32x4, vmovq_n_s32); +macros::aarch64_define_loadstore!(i32x4, vld1q_s32, vst1q_s32, 4); + +helpers::unsafe_map_binary_op!(i32x4, std::ops::Add, add, vaddq_s32, "neon"); +helpers::unsafe_map_binary_op!(i32x4, std::ops::Sub, sub, vsubq_s32, "neon"); +helpers::unsafe_map_binary_op!(i32x4, std::ops::Mul, mul, vmulq_s32, "neon"); +helpers::unsafe_map_unary_op!(i32x4, SIMDAbs, abs_simd, vabsq_s32, "neon"); +macros::aarch64_define_fma!(i32x4, vmlaq_s32); + +macros::aarch64_define_cmp!( + i32x4, + vceqq_s32, + (vmvnq_u32), + vcltq_s32, + vcleq_s32, + vcgtq_s32, + vcgeq_s32 +); +macros::aarch64_define_bitops!( + i32x4, + vmvnq_s32, + vandq_s32, + vorrq_s32, + veorq_s32, + ( + vshlq_s32, + 32, + vnegq_s32, + vminq_u32, + vreinterpretq_s32_u32, + vreinterpretq_u32_s32 + ), + (u32, i32, vmovq_n_s32), +); + +impl SIMDSumTree for i32x4 { + #[inline(always)] + fn sum_tree(self) -> i32 { + if cfg!(miri) { + self.emulated().sum_tree() + } else { + // SAFETY: Allowed by the `Neon` architecture. + unsafe { vaddvq_s32(self.0) } + } + } +} + +impl SIMDSelect for mask32x4 { + #[inline(always)] + fn select(self, x: i32x4, y: i32x4) -> i32x4 { + // SAFETY: Allowed by the `Neon` architecture. + i32x4(unsafe { vbslq_s32(self.0, x.0, y.0) }) + } +} + +impl SIMDDotProduct for i32x4 { + #[inline(always)] + fn dot_simd(self, left: i16x8, right: i16x8) -> Self { + if cfg!(miri) { + use crate::AsSIMD; + self.emulated() + .dot_simd(left.emulated(), right.emulated()) + .as_simd(self.arch()) + } else { + let left = left.0; + let right = right.0; + // SAFETY: Allowed by the `Neon` architecture. + unsafe { + let left_lo = vget_low_s16(left); + let left_hi = vget_high_s16(left); + + let right_lo = vget_low_s16(right); + let right_hi = vget_high_s16(right); + + let lo: int32x4_t = vmull_s16(left_lo, right_lo); + let hi: int32x4_t = vmull_s16(left_hi, right_hi); + + Self(vaddq_s32(self.0, vpaddq_s32(lo, hi))) + } + } + } +} + +impl SIMDDotProduct for i32x4 { + #[inline(always)] + fn dot_simd(self, left: u8x16, right: i8x16) -> Self { + if cfg!(miri) { + use crate::AsSIMD; + self.emulated() + .dot_simd(left.emulated(), right.emulated()) + .as_simd(self.arch()) + } else { + use crate::SplitJoin; + + // SAFETY: The intrinsics used here are allowed by the implicit `Neon` architecture. 
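+            // `vuzp1`/`vuzp2` split the sixteen bytes into their even- and odd-indexed
+            // elements, which are then widened to i16 so that the mixed u8 x i8 products
+            // are exact. Each i32 lane still accumulates the products of its four
+            // corresponding bytes (4i..4i+3): two from the "evens" pass and two from the
+            // "odds" pass.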
+ unsafe { + let left = left.split(); + let right = right.split(); + + let left_evens: i16x8 = u8x8(vuzp1_u8(left.lo.0, left.hi.0)).into(); + let left_odds: i16x8 = u8x8(vuzp2_u8(left.lo.0, left.hi.0)).into(); + + let right_evens: i16x8 = i8x8(vuzp1_s8(right.lo.0, right.hi.0)).into(); + let right_odds: i16x8 = i8x8(vuzp2_s8(right.lo.0, right.hi.0)).into(); + + self.dot_simd(left_evens, right_evens) + .dot_simd(left_odds, right_odds) + } + } + } +} + +impl SIMDDotProduct for i32x4 { + #[inline(always)] + fn dot_simd(self, left: i8x16, right: u8x16) -> Self { + self.dot_simd(right, left) + } +} + +impl SIMDDotProduct for i32x4 { + #[inline(always)] + fn dot_simd(self, left: i8x16, right: i8x16) -> Self { + if cfg!(miri) { + use crate::AsSIMD; + self.emulated() + .dot_simd(left.emulated(), right.emulated()) + .as_simd(self.arch()) + } else { + // SAFETY: Instantiating `Neon` implies `dotprod`. + // + // We need this wrapper to allow compilation of the underlying ASM when compiling + // without the `dotprod` feature globally enabled. + #[target_feature(enable = "dotprod")] + unsafe fn sdot(mut s: int32x4_t, x: int8x16_t, y: int8x16_t) -> int32x4_t { + // SAFETY: The `Neon` architecture implies `dotprod`, allowing us to use + // this intrinsic. + unsafe { + asm!( + "sdot {0:v}.4s, {1:v}.16b, {2:v}.16b", + inout(vreg) s, + in(vreg) x, + in(vreg) y, + options(pure, nomem, nostack) + ); + } + + s + } + + // SAFETY: The `Neon` architecture guarantees the `dotprod` feature. + Self::from_underlying(self.arch(), unsafe { sdot(self.0, left.0, right.0) }) + } + } +} + +//-------------// +// Conversions // +//-------------// + +helpers::unsafe_map_cast!( + i32x4 => (f32, f32x4), + vcvtq_f32_s32, + "neon" +); + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + use crate::{reference::ReferenceScalarOps, test_utils}; + + #[test] + fn miri_test_load() { + test_utils::test_load_simd::(Neon::new_checked().unwrap()); + } + + #[test] + fn miri_test_store() { + test_utils::test_store_simd::(Neon::new_checked().unwrap()); + } + + // constructors + #[test] + fn test_constructors() { + test_utils::ops::test_splat::(Neon::new_checked().unwrap()); + } + + // Ops + test_utils::ops::test_add!(i32x4, 0x3017fd73c99cc633, Neon::new_checked()); + test_utils::ops::test_sub!(i32x4, 0xfc627f10b5f8db8a, Neon::new_checked()); + test_utils::ops::test_mul!(i32x4, 0x0f4caa80eceaa523, Neon::new_checked()); + test_utils::ops::test_fma!(i32x4, 0xb8f702ba85375041, Neon::new_checked()); + test_utils::ops::test_abs!(i32x4, 0xb8f702ba85375041, Neon::new_checked()); + + test_utils::ops::test_cmp!(i32x4, 0x941757bd5cc641a1, Neon::new_checked()); + + // Bit ops + test_utils::ops::test_bitops!(i32x4, 0xd62d8de09f82ed4e, Neon::new_checked()); + + // Dot Products + test_utils::dot_product::test_dot_product!( + (i16x8, i16x8) => i32x4, + 0x145f89b446c03ff1, + Neon::new_checked() + ); + + test_utils::dot_product::test_dot_product!( + (u8x16, i8x16) => i32x4, + 0x145f89b446c03ff1, + Neon::new_checked() + ); + + test_utils::dot_product::test_dot_product!( + (i8x16, u8x16) => i32x4, + 0x145f89b446c03ff1, + Neon::new_checked() + ); + + test_utils::dot_product::test_dot_product!( + (i8x16, i8x16) => i32x4, + 0x145f89b446c03ff1, + Neon::new_checked() + ); + + // Reductions + test_utils::ops::test_sumtree!(i32x4, 0xb9ac82ab23a855da, Neon::new_checked()); + + // Conversions + test_utils::ops::test_cast!(i32x4 => f32x4, 0xba8fe343fc9dbeff, Neon::new_checked()); +} diff --git a/diskann-wide/src/arch/aarch64/i64x2_.rs 
b/diskann-wide/src/arch/aarch64/i64x2_.rs new file mode 100644 index 000000000..bdbbd683b --- /dev/null +++ b/diskann-wide/src/arch/aarch64/i64x2_.rs @@ -0,0 +1,117 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ + +use crate::{ + Emulated, SIMDAbs, SIMDMask, SIMDMulAdd, SIMDPartialEq, SIMDPartialOrd, SIMDVector, + arch::Scalar, constant::Const, helpers, +}; + +// AArch64 masks +use super::{ + Neon, + macros::{self, AArchLoadStore, AArchSplat}, + masks::mask64x2, + u64x2_::{emulated_vminq_u64, emulated_vmvnq_u64}, +}; + +// AArch64 intrinsics +use std::arch::aarch64::*; + +/////////////////// +// 64-bit signed // +/////////////////// + +#[inline(always)] +pub(super) unsafe fn emulated_vmvnq_s64(x: int64x2_t) -> int64x2_t { + let x: [i64; 2] = i64x2(x).to_array(); + let mapped: [i64; 2] = core::array::from_fn(|i| !x[i]); + // SAFETY: This is only called in a context where the caller guarantees `Neon` is + // available. + i64x2::from_array(unsafe { Neon::new() }, mapped).0 +} + +macros::aarch64_define_register!(i64x2, int64x2_t, mask64x2, i64, 2, Neon); +macros::aarch64_define_splat!(i64x2, vmovq_n_s64); +macros::aarch64_define_loadstore!(i64x2, vld1q_s64, vst1q_s64, 2); + +helpers::unsafe_map_binary_op!(i64x2, std::ops::Add, add, vaddq_s64, "neon"); +helpers::unsafe_map_binary_op!(i64x2, std::ops::Sub, sub, vsubq_s64, "neon"); +helpers::unsafe_map_unary_op!(i64x2, SIMDAbs, abs_simd, vabsq_s64, "neon"); + +impl std::ops::Mul for i64x2 { + type Output = Self; + #[inline(always)] + fn mul(self, rhs: Self) -> Self { + let x = Emulated::::from_array(Scalar, self.to_array()); + let y = Emulated::::from_array(Scalar, rhs.to_array()); + Self::from_array(self.arch(), (x * y).to_array()) + } +} + +macros::aarch64_define_fma!(i64x2, integer); + +macros::aarch64_define_cmp!( + i64x2, + vceqq_s64, + (emulated_vmvnq_u64), + vcltq_s64, + vcleq_s64, + vcgtq_s64, + vcgeq_s64 +); +macros::aarch64_define_bitops!( + i64x2, + emulated_vmvnq_s64, + vandq_s64, + vorrq_s64, + veorq_s64, + ( + vshlq_s64, + 64, + vnegq_s64, + emulated_vminq_u64, + vreinterpretq_s64_u64, + vreinterpretq_u64_s64 + ), + (u64, i64, vmovq_n_s64), +); + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + use crate::{reference::ReferenceScalarOps, test_utils}; + + #[test] + fn miri_test_load() { + test_utils::test_load_simd::(Neon::new_checked().unwrap()); + } + + #[test] + fn miri_test_store() { + test_utils::test_store_simd::(Neon::new_checked().unwrap()); + } + + // constructors + #[test] + fn test_constructors() { + test_utils::ops::test_splat::(Neon::new_checked().unwrap()); + } + + // Binary Ops + test_utils::ops::test_add!(i64x2, 0x8d7bf28b1c6e2545, Neon::new_checked()); + test_utils::ops::test_sub!(i64x2, 0x4a1c644a1a910bed, Neon::new_checked()); + test_utils::ops::test_mul!(i64x2, 0xf42ee707a808fd10, Neon::new_checked()); + test_utils::ops::test_fma!(i64x2, 0x28540d9936a9e803, Neon::new_checked()); + test_utils::ops::test_abs!(i64x2, 0xb8f702ba85375041, Neon::new_checked()); + + test_utils::ops::test_cmp!(i64x2, 0xfae27072c6b70885, Neon::new_checked()); + + // Bit ops + test_utils::ops::test_bitops!(i64x2, 0xbe927713ea310164, Neon::new_checked()); +} diff --git a/diskann-wide/src/arch/aarch64/i8x16_.rs b/diskann-wide/src/arch/aarch64/i8x16_.rs new file mode 100644 index 000000000..494fa5a56 --- /dev/null +++ b/diskann-wide/src/arch/aarch64/i8x16_.rs @@ -0,0 +1,99 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. 
+ * Licensed under the MIT license. + */ + +use crate::{ + Emulated, SIMDAbs, SIMDMask, SIMDMulAdd, SIMDPartialEq, SIMDPartialOrd, SIMDVector, + constant::Const, helpers, +}; + +// AArch64 masks +use super::{ + Neon, i8x8, + macros::{self, AArchLoadStore, AArchSplat}, + masks::mask8x16, +}; + +// AArch64 intrinsics +use std::arch::aarch64::*; + +////////////////// +// 8-bit signed // +////////////////// + +macros::aarch64_define_register!(i8x16, int8x16_t, mask8x16, i8, 16, Neon); +macros::aarch64_define_splat!(i8x16, vmovq_n_s8); +macros::aarch64_define_loadstore!(i8x16, vld1q_s8, vst1q_s8, 16); +macros::aarch64_splitjoin!(i8x16, i8x8, vget_low_s8, vget_high_s8, vcombine_s8); + +helpers::unsafe_map_binary_op!(i8x16, std::ops::Add, add, vaddq_s8, "neon"); +helpers::unsafe_map_binary_op!(i8x16, std::ops::Sub, sub, vsubq_s8, "neon"); +helpers::unsafe_map_binary_op!(i8x16, std::ops::Mul, mul, vmulq_s8, "neon"); +helpers::unsafe_map_unary_op!(i8x16, SIMDAbs, abs_simd, vabsq_s8, "neon"); +macros::aarch64_define_fma!(i8x16, vmlaq_s8); + +macros::aarch64_define_cmp!( + i8x16, + vceqq_s8, + (vmvnq_u8), + vcltq_s8, + vcleq_s8, + vcgtq_s8, + vcgeq_s8 +); +macros::aarch64_define_bitops!( + i8x16, + vmvnq_s8, + vandq_s8, + vorrq_s8, + veorq_s8, + ( + vshlq_s8, + 8, + vnegq_s8, + vminq_u8, + vreinterpretq_s8_u8, + vreinterpretq_u8_s8 + ), + (u8, i8, vmovq_n_s8), +); + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + use crate::{reference::ReferenceScalarOps, test_utils}; + + #[test] + fn miri_test_load() { + test_utils::test_load_simd::(Neon::new_checked().unwrap()); + } + + #[test] + fn miri_test_store() { + test_utils::test_store_simd::(Neon::new_checked().unwrap()); + } + + // constructors + #[test] + fn test_constructors() { + test_utils::ops::test_splat::(Neon::new_checked().unwrap()); + } + + // Ops + test_utils::ops::test_add!(i8x16, 0x3017fd73c99cc633, Neon::new_checked()); + test_utils::ops::test_sub!(i8x16, 0xfc627f10b5f8db8a, Neon::new_checked()); + test_utils::ops::test_mul!(i8x16, 0x0f4caa80eceaa523, Neon::new_checked()); + test_utils::ops::test_fma!(i8x16, 0xb8f702ba85375041, Neon::new_checked()); + test_utils::ops::test_abs!(i8x16, 0xb8f702ba85375041, Neon::new_checked()); + test_utils::ops::test_splitjoin!(i8x16 => i8x8, 0xa4d00a4d04293967, Neon::new_checked()); + + test_utils::ops::test_cmp!(i8x16, 0x941757bd5cc641a1, Neon::new_checked()); + + // Bit ops + test_utils::ops::test_bitops!(i8x16, 0xd62d8de09f82ed4e, Neon::new_checked()); +} diff --git a/diskann-wide/src/arch/aarch64/i8x8_.rs b/diskann-wide/src/arch/aarch64/i8x8_.rs new file mode 100644 index 000000000..2dad5cc16 --- /dev/null +++ b/diskann-wide/src/arch/aarch64/i8x8_.rs @@ -0,0 +1,89 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ + +use crate::{ + Emulated, SIMDAbs, SIMDMask, SIMDMulAdd, SIMDPartialEq, SIMDPartialOrd, SIMDVector, + constant::Const, helpers, +}; + +// AArch64 masks +use super::{ + Neon, + macros::{self, AArchLoadStore, AArchSplat}, + masks::mask8x8, +}; + +// AArch64 intrinsics +use std::arch::aarch64::*; + +////////////////// +// 8-bit signed // +////////////////// + +macros::aarch64_define_register!(i8x8, int8x8_t, mask8x8, i8, 8, Neon); +macros::aarch64_define_splat!(i8x8, vmov_n_s8); +macros::aarch64_define_loadstore!(i8x8, vld1_s8, vst1_s8, 8); + +helpers::unsafe_map_binary_op!(i8x8, std::ops::Add, add, vadd_s8, "neon"); +helpers::unsafe_map_binary_op!(i8x8, std::ops::Sub, sub, vsub_s8, "neon"); +helpers::unsafe_map_binary_op!(i8x8, std::ops::Mul, mul, vmul_s8, "neon"); +helpers::unsafe_map_unary_op!(i8x8, SIMDAbs, abs_simd, vabs_s8, "neon"); +macros::aarch64_define_fma!(i8x8, vmla_s8); + +macros::aarch64_define_cmp!(i8x8, vceq_s8, (vmvn_u8), vclt_s8, vcle_s8, vcgt_s8, vcge_s8); +macros::aarch64_define_bitops!( + i8x8, + vmvn_s8, + vand_s8, + vorr_s8, + veor_s8, + ( + vshl_s8, + 8, + vneg_s8, + vmin_u8, + vreinterpret_s8_u8, + vreinterpret_u8_s8 + ), + (u8, i8, vmov_n_s8), +); + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + use crate::{reference::ReferenceScalarOps, test_utils}; + + #[test] + fn miri_test_load() { + test_utils::test_load_simd::(Neon::new_checked().unwrap()); + } + + #[test] + fn miri_test_store() { + test_utils::test_store_simd::(Neon::new_checked().unwrap()); + } + + // constructors + #[test] + fn test_constructors() { + test_utils::ops::test_splat::(Neon::new_checked().unwrap()); + } + + // Ops + test_utils::ops::test_add!(i8x8, 0x3017fd73c99cc633, Neon::new_checked()); + test_utils::ops::test_sub!(i8x8, 0xfc627f10b5f8db8a, Neon::new_checked()); + test_utils::ops::test_mul!(i8x8, 0x0f4caa80eceaa523, Neon::new_checked()); + test_utils::ops::test_fma!(i8x8, 0xb8f702ba85375041, Neon::new_checked()); + test_utils::ops::test_abs!(i8x8, 0xb8f702ba85375041, Neon::new_checked()); + + test_utils::ops::test_cmp!(i8x8, 0x941757bd5cc641a1, Neon::new_checked()); + + // Bit ops + test_utils::ops::test_bitops!(i8x8, 0xd62d8de09f82ed4e, Neon::new_checked()); +} diff --git a/diskann-wide/src/arch/aarch64/macros.rs b/diskann-wide/src/arch/aarch64/macros.rs new file mode 100644 index 000000000..4fab7de37 --- /dev/null +++ b/diskann-wide/src/arch/aarch64/macros.rs @@ -0,0 +1,575 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ + +use crate::SIMDVector; + +macro_rules! aarch64_define_register { + ($type:ident, $impl:ty, $mask:ty, $scalar:ty, $lanes:literal, $arch:ty) => { + #[derive(Debug, Clone, Copy)] + #[allow(non_camel_case_types)] + #[repr(transparent)] + pub struct $type(pub $impl); + + impl $type { + /// Convert `self` to its corresponding [`crate::Emulated`] type. 
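+            ///
+            /// The `cfg!(miri)` fallbacks in this backend route through this conversion,
+            /// since Miri does not support many of the Neon intrinsics used here.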
+ #[inline(always)] + pub fn emulated(self) -> $crate::Emulated<$scalar, $lanes> { + $crate::Emulated::from_array($crate::arch::Scalar, self.to_array()) + } + } + + impl $crate::AsSIMD<$type> for $crate::Emulated<$scalar, $lanes> { + #[inline(always)] + fn as_simd(self, arch: $arch) -> $type { + $type::from_array(arch, self.to_array()) + } + } + + impl SIMDVector for $type { + type Arch = $arch; + type Scalar = $scalar; + type Underlying = $impl; + + type Mask = $mask; + type ConstLanes = Const<$lanes>; + const LANES: usize = $lanes; + const EMULATED: bool = false; + + #[inline(always)] + fn arch(self) -> $arch { + // SAFETY: The existence of `self` provides a witness that it is safe to + // instantiate its architecture. + unsafe { <$arch>::new() } + } + + #[inline(always)] + fn default(arch: $arch) -> Self { + ::aarch_default(arch) + } + + fn to_underlying(self) -> Self::Underlying { + self.0 + } + + fn from_underlying(_: $arch, repr: Self::Underlying) -> Self { + Self(repr) + } + + fn to_array(self) -> [$scalar; $lanes] { + // SAFETY: Provided the scalar type is an integer or floating point, + // then all bit pattens are valid between source and destination types. + // (provided an x86 intrinsic is one of the transmuted types). + // + // The source argument is taken by value (no reference conversion) and + // as long as `T` is `[repr(C)]`, then `[T; N]` will be `[repr(C)]`. + // + // The intrinsic types are `[repr(simd)]` which amounts to `[repr(C)]` and + // change. + unsafe { std::mem::transmute::(self) } + } + + fn from_array(_: $arch, x: [$scalar; $lanes]) -> Self { + // SAFETY: Provided the scalar type is an integer or floating point, + // then all bit pattens are valid between source and destination types. + // (provided an x86 intrinsic is one of the transmuted types). + // + // The source argument is taken by value (no reference conversion) and + // as long as `T` is `[repr(C)]`, then `[T; N]` will be `[repr(C)]`. + // + // The intrinsic types are `[repr(simd)]` which amounts to `[repr(C)]` and + // change. + unsafe { std::mem::transmute::<[$scalar; $lanes], Self>(x) } + } + + #[inline(always)] + fn splat(arch: $arch, value: Self::Scalar) -> Self { + ::aarch_splat(arch, value) + } + + #[inline(always)] + unsafe fn load_simd(arch: $arch, ptr: *const $scalar) -> Self { + // SAFETY: Inherited from caller. + unsafe { ::load_simd(arch, ptr) } + } + + #[inline(always)] + unsafe fn load_simd_masked_logical( + arch: $arch, + ptr: *const $scalar, + mask: $mask, + ) -> Self { + // SAFETY: Inherited from caller. + unsafe { ::load_simd_masked_logical(arch, ptr, mask) } + } + + #[inline(always)] + unsafe fn load_simd_first(arch: $arch, ptr: *const $scalar, first: usize) -> Self { + // SAFETY: Inherited from caller. + unsafe { ::load_simd_first(arch, ptr, first) } + } + + #[inline(always)] + unsafe fn store_simd(self, ptr: *mut $scalar) { + // SAFETY: Inherited from caller. + unsafe { ::store_simd(self, ptr) } + } + + #[inline(always)] + unsafe fn store_simd_masked_logical(self, ptr: *mut $scalar, mask: $mask) { + // SAFETY: Inherited from caller. + unsafe { ::store_simd_masked_logical(self, ptr, mask) } + } + + #[inline(always)] + unsafe fn store_simd_first(self, ptr: *mut $scalar, first: usize) { + // SAFETY: Inherited from caller. 
+ unsafe { ::store_simd_first(self, ptr, first) } + } + } + }; +} + +pub(super) trait AArchSplat: SIMDVector { + fn aarch_splat(arch: ::Arch, value: ::Scalar) -> Self; + + fn aarch_default(arch: ::Arch) -> Self; +} + +pub(super) trait AArchLoadStore: SIMDVector { + unsafe fn load_simd( + arch: ::Arch, + ptr: *const ::Scalar, + ) -> Self; + unsafe fn load_simd_masked_logical( + arch: ::Arch, + ptr: *const ::Scalar, + mask: Self::Mask, + ) -> Self; + unsafe fn load_simd_first( + arch: ::Arch, + ptr: *const ::Scalar, + first: usize, + ) -> Self; + + unsafe fn store_simd(self, ptr: *mut ::Scalar); + unsafe fn store_simd_masked_logical( + self, + ptr: *mut ::Scalar, + mask: Self::Mask, + ); + unsafe fn store_simd_first(self, ptr: *mut ::Scalar, first: usize); +} + +/// Utility macro for defining `AArchSplat`. +/// +/// SAFETY: It is the invoker's responsibility to ensure that the intrinsic is safe to call. +macro_rules! aarch64_define_splat { + ($type:ty, $intrinsic:expr) => { + impl AArchSplat for $type { + #[inline(always)] + fn aarch_splat( + _arch: ::Arch, + value: ::Scalar, + ) -> Self { + // SAFETY: Instantiator asserts that `$intrinsic` is allowed by `Arch`. + Self(unsafe { $intrinsic(value) }) + } + + #[inline(always)] + fn aarch_default(arch: ::Arch) -> Self { + Self::aarch_splat(arch, ::Scalar::default()) + } + } + }; +} + +macro_rules! aarch64_define_loadstore { + ($type:ty, $load:expr, $store:expr, $lanes:literal) => { + impl AArchLoadStore for $type { + #[inline(always)] + unsafe fn load_simd( + _arch: ::Arch, + ptr: *const ::Scalar, + ) -> Self { + // SAFETY: Instantiator asserts that `$load` is allowed by `Arch`. + Self(unsafe { $load(ptr) }) + } + + #[inline(always)] + unsafe fn load_simd_masked_logical( + arch: ::Arch, + ptr: *const ::Scalar, + mask: Self::Mask, + ) -> Self { + // SAFETY: Inherited from caller. + let e = unsafe { + Emulated::<_, $lanes>::load_simd_masked_logical( + $crate::arch::Scalar, + ptr, + mask.bitmask().as_scalar(), + ) + }; + + Self::from_array(arch, e.to_array()) + } + + #[inline(always)] + unsafe fn load_simd_first( + arch: ::Arch, + ptr: *const ::Scalar, + first: usize, + ) -> Self { + // SAFETY: Inherited from caller. + let e = unsafe { + Emulated::<_, $lanes>::load_simd_first($crate::arch::Scalar, ptr, first) + }; + + Self::from_array(arch, e.to_array()) + } + + #[inline(always)] + unsafe fn store_simd(self, ptr: *mut ::Scalar) { + // SAFETY: Instantiator asserts that `$store` is allowed by `Arch`. + unsafe { $store(ptr, self.0) } + } + + unsafe fn store_simd_masked_logical( + self, + ptr: *mut ::Scalar, + mask: Self::Mask, + ) { + let e = Emulated::<_, $lanes>::from_array($crate::arch::Scalar, self.to_array()); + + // SAFETY: Inherited from caller. + unsafe { e.store_simd_masked_logical(ptr, mask.bitmask().as_scalar()) } + } + + #[inline(always)] + unsafe fn store_simd_first(self, ptr: *mut ::Scalar, first: usize) { + let e = Emulated::<_, $lanes>::from_array($crate::arch::Scalar, self.to_array()); + + // SAFETY: Inherited from caller. + unsafe { e.store_simd_first(ptr, first) } + } + } + }; +} + +macro_rules! aarch64_define_cmp { + ($type:ty, $eq:ident, ($not:expr), $lt:ident, $le:ident, $gt:ident, $ge:ident) => { + impl SIMDPartialEq for $type { + #[inline(always)] + fn eq_simd(self, other: Self) -> Self::Mask { + // SAFETY: Inclusion of this macro is gated by the "neon" feature. + // + // It is the caller's responsibility to instantiate the macro with an + // intrinsics also gated by "neon". 
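+                // Neon comparisons return lanes that are either all-ones or all-zero,
+                // which is exactly the mask representation described in `masks.rs`, so the
+                // result can be wrapped directly without any normalization.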
+ Self::Mask::from_underlying(self.arch(), unsafe { $eq(self.0, other.0) }) + } + + #[inline(always)] + fn ne_simd(self, other: Self) -> Self::Mask { + // SAFETY: Inclusion of this macro is gated by the "neon" feature. + // + // It is the caller's responsibility to instantiate the macro with an + // intrinsics also gated by "neon". + Self::Mask::from_underlying(self.arch(), unsafe { $not($eq(self.0, other.0)) }) + } + } + + impl SIMDPartialOrd for $type { + #[inline(always)] + fn lt_simd(self, other: Self) -> Self::Mask { + // SAFETY: Inclusion of this macro is gated by the "neon" feature. + // + // It is the caller's responsibility to instantiate the macro with an + // intrinsics also gated by "neon". + Self::Mask::from_underlying(self.arch(), unsafe { $lt(self.0, other.0) }) + } + + #[inline(always)] + fn le_simd(self, other: Self) -> Self::Mask { + // SAFETY: Inclusion of this macro is gated by the "neon" feature. + // + // It is the caller's responsibility to instantiate the macro with an + // intrinsics also gated by "neon". + Self::Mask::from_underlying(self.arch(), unsafe { $le(self.0, other.0) }) + } + + #[inline(always)] + fn gt_simd(self, other: Self) -> Self::Mask { + // SAFETY: Inclusion of this macro is gated by the "neon" feature. + // + // It is the caller's responsibility to instantiate the macro with an + // intrinsics also gated by "neon". + Self::Mask::from_underlying(self.arch(), unsafe { $gt(self.0, other.0) }) + } + + #[inline(always)] + fn ge_simd(self, other: Self) -> Self::Mask { + // SAFETY: Inclusion of this macro is gated by the "neon" feature. + // + // It is the caller's responsibility to instantiate the macro with an + // intrinsics also gated by "neon". + Self::Mask::from_underlying(self.arch(), unsafe { $ge(self.0, other.0) }) + } + } + }; +} + +/// Utility macro for defining simple operations that lower to a single intrinsic. +/// +/// SAFETY: It is the invoker's responsibility to ensure that the intrinsic is safe to call +macro_rules! aarch64_define_fma { + ($type:ty, integer) => { + impl SIMDMulAdd for $type { + #[inline(always)] + fn mul_add_simd(self, rhs: Self, accumulator: Self) -> $type { + self * rhs + accumulator + } + } + }; + // This variant maps the implementation to an intrinsic. + ($type:ty, $intrinsic:expr) => { + impl SIMDMulAdd for $type { + #[inline(always)] + fn mul_add_simd(self, rhs: Self, accumulator: Self) -> $type { + // SAFETY: The invoker of this macro must pass the `target_feature` + // requirement of the intrinsic. + // + // That way, if the intrinsic is not available, we get a compile-time error. + Self(unsafe { $intrinsic(accumulator.0, self.0, rhs.0) }) + } + } + }; +} + +/// # Notes on vector shifts. +/// +/// Neon only has the `vector`x`vector` left shift function. However, it takes signed +/// arguments for the shift amount. Right shifts are achieved by using negative left-shifts. +/// +/// To maintain consistency in `Wide`, we only allow positive left shifts and positive right +/// shifts. +/// +/// * Left shifts: We need to clamp the shift amount between 0 and the maximum shift +/// (inclusive). This is done by first reinrepreting the shift vector as unsigned +/// (`cvtpre`), taking the unsigned `min` with the maximal shift, and then reinterpret +/// to signed (`cvtpost`). +/// +/// If the shift vector is already unsigned, then `cvtpre` can be the identity. +/// +/// * Right shifts: Right shifts follow the same logic as left shifts, just with a final +/// negative before invoking the left-shift intrinsic. 
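+///
+/// As a concrete walk-through of the scheme above (illustrative only): an `i32x4` right
+/// shift by `[1, 40, 3, 5]` first clamps the requested amounts to `[1, 32, 3, 5]` with the
+/// unsigned `min` against the lane width, then negates them to `[-1, -32, -3, -5]` before
+/// handing them to the `vshlq_s32` left-shift intrinsic.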
+/// +/// # Shifts by a Scalar +/// +/// LLVM is not smart enough to constant propagate properly though a `splat` followed by +/// a vector shift if we do the range-limitation after the `splat`. So, the scalar shift +/// operations perform the "cast-to-positive + min + cast-to-signed" in the scalar space +/// before splatting. LLVM optimizes this correctly. +macro_rules! aarch64_define_bitops { + ($type:ty, + $not:ident, + $and:ident, + $or:ident, + $xor:ident, + ($shlv:ident, $mask:literal, $neg:ident, $min:ident, $cvtpost:path, $cvtpre:path), + ($unsigned:ty, $signed:ty, $broadcast_signed:ident), + ) => { + impl std::ops::Not for $type { + type Output = Self; + #[inline(always)] + fn not(self) -> Self { + // SAFETY: Inclusion of this macro is gated by the "neon" feature. + // + // It is the caller's responsibility to instantiate the macro with an + // intrinsics also gated by "neon". + Self(unsafe { $not(self.0) }) + } + } + + impl std::ops::BitAnd for $type { + type Output = Self; + #[inline(always)] + fn bitand(self, rhs: Self) -> Self { + // SAFETY: Inclusion of this macro is gated by the "neon" feature. + // + // It is the caller's responsibility to instantiate the macro with an + // intrinsics also gated by "neon". + Self(unsafe { $and(self.0, rhs.0) }) + } + } + + impl std::ops::BitOr for $type { + type Output = Self; + #[inline(always)] + fn bitor(self, rhs: Self) -> Self { + // SAFETY: Inclusion of this macro is gated by the "neon" feature. + // + // It is the caller's responsibility to instantiate the macro with an + // intrinsics also gated by "neon". + Self(unsafe { $or(self.0, rhs.0) }) + } + } + + impl std::ops::BitXor for $type { + type Output = Self; + #[inline(always)] + fn bitxor(self, rhs: Self) -> Self { + // SAFETY: Inclusion of this macro is gated by the "neon" feature. + // + // It is the caller's responsibility to instantiate the macro with an + // intrinsics also gated by "neon". + Self(unsafe { $xor(self.0, rhs.0) }) + } + } + + /////////////////// + // vector shifts // + /////////////////// + + impl std::ops::Shr for $type { + type Output = Self; + #[inline(always)] + fn shr(self, rhs: Self) -> Self { + use $crate::AsSIMD; + if cfg!(miri) { + self.emulated().shr(rhs.emulated()).as_simd(self.arch()) + } else { + // SAFETY: Inclusion of this macro is gated by the "neon" feature. + // + // It is the caller's responsibility to instantiate the macro with an + // intrinsics also gated by "neon". + Self(unsafe { + $shlv( + self.0, + $neg($cvtpost($min( + $cvtpre(rhs.0), + $cvtpre(<$type as SIMDVector>::splat(self.arch(), $mask).0), + ))), + ) + }) + } + } + } + + impl std::ops::Shl for $type { + type Output = Self; + #[inline(always)] + fn shl(self, rhs: Self) -> Self { + use $crate::AsSIMD; + if cfg!(miri) { + self.emulated().shl(rhs.emulated()).as_simd(self.arch()) + } else { + // SAFETY: Inclusion of this macro is gated by the "neon" feature. + // + // It is the caller's responsibility to instantiate the macro with an + // intrinsics also gated by "neon". 
+ Self(unsafe { + $shlv( + self.0, + $cvtpost($min( + $cvtpre(rhs.0), + $cvtpre(<$type as SIMDVector>::splat(self.arch(), $mask).0), + )), + ) + }) + } + } + } + + /////////////////// + // scalar shifts // + /////////////////// + + impl std::ops::Shr<<$type as SIMDVector>::Scalar> for $type { + type Output = Self; + #[inline(always)] + fn shr(self, rhs: <$type as SIMDVector>::Scalar) -> Self { + use $crate::AsSIMD; + if cfg!(miri) { + self.emulated().shr(rhs).as_simd(self.arch()) + } else { + // SAFETY: Inclusion of this macro is gated by the "neon" feature. + // + // It is the caller's responsibility to instantiate the macro with an + // intrinsics also gated by "neon". + Self(unsafe { + $shlv( + self.0, + $broadcast_signed(-((rhs as $unsigned).min($mask) as $signed)), + ) + }) + } + } + } + + impl std::ops::Shl<<$type as SIMDVector>::Scalar> for $type { + type Output = Self; + #[inline(always)] + fn shl(self, rhs: <$type as SIMDVector>::Scalar) -> Self { + use $crate::AsSIMD; + if cfg!(miri) { + self.emulated().shl(rhs).as_simd(self.arch()) + } else { + // SAFETY: Inclusion of this macro is gated by the "neon" feature. + // + // It is the caller's responsibility to instantiate the macro with an + // intrinsics also gated by "neon". + Self(unsafe { + $shlv( + self.0, + $broadcast_signed((rhs as $unsigned).min($mask) as $signed), + ) + }) + } + } + } + }; +} + +/// SAFETY: It is the invoker's responsibility to ensure that the provided intrinsics are +/// safe to call. T hat is - any intrinsics invoked must be compatible with `$type`'s +/// associated architecture. +macro_rules! aarch64_splitjoin { + ($type:path, $half:path, $getlo:ident, $gethi:ident, $join:ident) => { + impl $crate::SplitJoin for $type { + type Halved = $half; + + #[inline(always)] + fn split(self) -> $crate::LoHi { + // SAFETY: This should only be instantiated for types where the associated + // architecture provides a license to use it. + unsafe { + $crate::LoHi::new( + Self::Halved::from_underlying(self.arch(), $getlo(self.to_underlying())), + Self::Halved::from_underlying(self.arch(), $gethi(self.to_underlying())), + ) + } + } + + #[inline(always)] + fn join(lohi: $crate::LoHi) -> Self { + // SAFETY: This should only be instantiated for types where the associated + // architecture provides a license to use it. + unsafe { + Self::from_underlying( + lohi.lo.arch(), + $join(lohi.lo.to_underlying(), lohi.hi.to_underlying()), + ) + } + } + } + }; +} + +pub(crate) use aarch64_define_bitops; +pub(crate) use aarch64_define_cmp; +pub(crate) use aarch64_define_fma; +pub(crate) use aarch64_define_loadstore; +pub(crate) use aarch64_define_register; +pub(crate) use aarch64_define_splat; +pub(crate) use aarch64_splitjoin; diff --git a/diskann-wide/src/arch/aarch64/masks.rs b/diskann-wide/src/arch/aarch64/masks.rs new file mode 100644 index 000000000..f88c5f57a --- /dev/null +++ b/diskann-wide/src/arch/aarch64/masks.rs @@ -0,0 +1,822 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ + +//! # Masks +//! +//! Neon masks are bit-width specific, but type agnostic (meaning the mask representation +//! for `u32x4` and `f32x4` are the same). +//! +//! The representation for a type with bitwidth `B` and `L` lanes is a SIMD register +//! containing `L` lanes of `B`-bit unsigned integers. +//! +//! Within each lane, a mask is "set" is all `B` bits of the corresponding integer are 1, +//! and unset if all `B` bits are 0. These masks are automatically generated in this form +//! 
by the various compare intrinsics. +//! +//! Setting all bits is important because the `select` operations in Neon are bit-wise +//! selects, unlike AVX2 where only the most-significant bit is important. +//! +//! The conversion implementation in this file still refer to the uppermost bit when +//! implementing `move_mask`-like functionality. + +use crate::{BitMask, FromInt, SIMDMask}; + +use super::Neon; + +use std::arch::aarch64::*; + +macro_rules! define_mask { + ($mask:ident, $repr:ident, $lanes:literal, $arch:ty) => { + #[derive(Debug, Clone, Copy)] + #[allow(non_camel_case_types)] + #[repr(transparent)] + pub struct $mask(pub(crate) $repr); + + impl SIMDMask for $mask { + type Arch = $arch; + type Underlying = $repr; + type BitMask = BitMask<$lanes, $arch>; + const ISBITS: bool = false; + const LANES: usize = $lanes; + + #[inline(always)] + fn arch(self) -> Self::Arch { + // SAFETY: Since `self` cannot be safely constructed without its `Arch`, + // it's safe to construct the arch. + unsafe { <$arch>::new() } + } + + #[inline(always)] + fn to_underlying(self) -> Self::Underlying { + self.0 + } + + #[inline(always)] + fn from_underlying(_arch: $arch, value: Self::Underlying) -> Self { + Self(value) + } + + #[inline(always)] + fn keep_first(arch: $arch, lanes: usize) -> Self { + Self(<$repr as MaskOps>::keep_first(arch, lanes)) + } + + #[inline(always)] + fn get_unchecked(&self, i: usize) -> bool { + <$repr as MaskOps>::move_mask(self.0, self.arch()).get_unchecked(i) + } + } + + impl From> for $mask { + #[inline(always)] + fn from(mask: BitMask<$lanes, $arch>) -> Self { + Self(<$repr as MaskOps>::from_mask(mask)) + } + } + + impl From<$mask> for BitMask<$lanes, $arch> { + #[inline(always)] + fn from(mask: $mask) -> BitMask<$lanes, $arch> { + <$repr as MaskOps>::move_mask(mask.0, mask.arch()) + } + } + }; +} + +define_mask!(mask8x8, uint8x8_t, 8, Neon); +define_mask!(mask8x16, uint8x16_t, 16, Neon); +define_mask!(mask16x4, uint16x4_t, 4, Neon); +define_mask!(mask16x8, uint16x8_t, 8, Neon); +define_mask!(mask32x2, uint32x2_t, 2, Neon); +define_mask!(mask32x4, uint32x4_t, 4, Neon); +define_mask!(mask64x1, uint64x1_t, 1, Neon); +define_mask!(mask64x2, uint64x2_t, 2, Neon); + +///////////// +// MaskOps // +///////////// + +trait MaskOps: Sized { + type BitMask: SIMDMask; + type Array; + + /// Convert `self` into a BitMask. + fn move_mask(self, arch: Neon) -> Self::BitMask; + + /// Construct `Self` from a BitMask. + fn from_mask(mask: Self::BitMask) -> Self; + + /// Convert `self` to an array. + #[cfg(any(test, miri))] + fn to_array(self) -> Self::Array; + + /// Construct `Self` by only keeping up to the first `lanes` lanes. + #[inline(always)] + fn keep_first(arch: Neon, lanes: usize) -> Self { + Self::from_mask(Self::BitMask::keep_first(arch, lanes)) + } +} + +// Two approaches are used for `move_mask` depending on lane count: +// +// * For types with few lanes (≤4), we use a shift-right-accumulate (USRA) chain: +// normalize each lane to 0/1 via a right-shift, then progressively fold adjacent bits +// together by reinterpreting at wider lane widths and using USRA. This requires zero +// constants and no horizontal reductions. +// +// * For types with many lanes (8+), we use the MSB-isolate + variable-shift + horizontal-add +// approach: mask the upper-most bit of each lane, perform a variable shift to place each +// retained bit in a unique position, then finish with a horizontal sum to concatenate +// all bits together. 
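+// As a concrete example (illustrative only): for a `uint8x8_t` mask whose lanes are
+// [FF, 00, FF, FF, 00, 00, 00, FF], lane `i` maps to bit `i`, so `move_mask` yields
+// 0b1000_1101 (0x8D). In the 8-lane path this falls out of the variable shift: the
+// isolated MSB (0x80) in lane `i` is shifted right by `7 - i`, leaving a single bit at
+// position `i`, and the widening horizontal add then concatenates the disjoint bits into
+// one integer.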
+ +impl MaskOps for uint8x8_t { + type BitMask = BitMask<8, Neon>; + type Array = [u8; 8]; + + #[inline(always)] + fn move_mask(self, arch: Neon) -> Self::BitMask { + cfg_if::cfg_if! { + if #[cfg(miri)] { + let array = self.to_array(); + BitMask::from_fn(arch, |i| array[i] == u8::MAX) + } else { + // SAFETY: Inclusion of this function is dependent on the "neon" target + // feature. This function does not access memory directly. + let value = unsafe { + let mask = vmov_n_u8(0x80); + // Effectively creates [-7, -6, -5, -4, -3, -2, -1, 0] + let shifts = vcreate_s8(0x00FF_FEFD_FCFB_FAF9); + vaddlv_u8(vshl_u8(vand_u8(self, mask), shifts)) + }; + BitMask::from_int(arch, value as u8) + } + } + } + + #[inline(always)] + fn from_mask(mask: Self::BitMask) -> Self { + const BIT_SELECTOR: u64 = 0x8040201008040201; + // SAFETY: Inclusion of this function is dependent on the "neon" target + // feature. This function does not access memory directly. + unsafe { vtst_u8(vmov_n_u8(mask.0), vcreate_u8(BIT_SELECTOR)) } + } + + #[inline(always)] + fn keep_first(_arch: Neon, lanes: usize) -> Self { + const INDICES: u64 = 0x0706050403020100; + let n = lanes.min(8) as u8; + // SAFETY: Inclusion of this function is dependent on the "neon" target + // feature. This function does not access memory directly. + unsafe { vclt_u8(vcreate_u8(INDICES), vmov_n_u8(n)) } + } + + #[cfg(any(test, miri))] + fn to_array(self) -> Self::Array { + // SAFETY: Both the source and destination types are trivially destructible and are + // valid for all possible bit-patterns. + unsafe { std::mem::transmute::(self) } + } +} + +impl MaskOps for uint8x16_t { + type BitMask = BitMask<16, Neon>; + type Array = [u8; 16]; + + #[inline(always)] + fn move_mask(self, arch: Neon) -> Self::BitMask { + cfg_if::cfg_if! { + if #[cfg(miri)] { + let array = self.to_array(); + BitMask::from_fn(arch, |i| array[i] == u8::MAX) + } else { + // SAFETY: Inclusion of this function is dependent on the "neon" target + // feature. This function does not access memory directly. + let value = unsafe { + let mask = vmovq_n_u8(0x80); + let masked = vandq_u8(self, mask); + // Effectively creates [-7, -6, -5, -4, -3, -2, -1, 0] + let shifts = vcreate_s8(0x00FF_FEFD_FCFB_FAF9); + + let low = vaddlv_u8(vshl_u8(vget_low_u8(masked), shifts)); + let high = vaddlv_u8(vshl_u8(vget_high_u8(masked), shifts)); + + (low as u16) | ((high as u16) << 8) + }; + BitMask::from_int(arch, value) + } + } + } + + #[inline(always)] + fn from_mask(mask: Self::BitMask) -> Self { + let mask: u16 = mask.0; + const BIT_SELECTOR: u64 = 0x8040201008040201; + + let low = mask as u8; + let high = (mask >> 8) as u8; + // SAFETY: Inclusion of this function is dependent on the "neon" target + // feature. This function does not access memory directly. + unsafe { + vtstq_u8( + vcombine_u8(vmov_n_u8(low), vmov_n_u8(high)), + vcombine_u8(vcreate_u8(BIT_SELECTOR), vcreate_u8(BIT_SELECTOR)), + ) + } + } + + #[inline(always)] + fn keep_first(_arch: Neon, lanes: usize) -> Self { + const LO: u64 = 0x0706050403020100; + const HI: u64 = 0x0F0E0D0C0B0A0908; + let n = lanes.min(16) as u8; + // SAFETY: Inclusion of this function is dependent on the "neon" target + // feature. This function does not access memory directly. + unsafe { vcltq_u8(vcombine_u8(vcreate_u8(LO), vcreate_u8(HI)), vmovq_n_u8(n)) } + } + + #[cfg(any(test, miri))] + fn to_array(self) -> Self::Array { + // SAFETY: Both the source and destination types are trivially destructible and are + // valid for all possible bit-patterns. 
+ unsafe { std::mem::transmute::(self) } + } +} + +impl MaskOps for uint16x4_t { + type BitMask = BitMask<4, Neon>; + type Array = [u16; 4]; + + #[inline(always)] + fn move_mask(self, arch: Neon) -> Self::BitMask { + cfg_if::cfg_if! { + if #[cfg(miri)] { + let array = self.to_array(); + BitMask::from_fn(arch, |i| array[i] == u16::MAX) + } else { + // Step 1: Isolate single bits in each lane and compact to bytes: + // + // | Lane 0 | Lane 1 | Lane 2 | Lane 3 | + // | 0b0000'000a | 0b0000'000b | 0b0000'000c | 0b0000'000d | + // + // Step 2: Shift the even lanes and then add with the odd lanes: + // + // | Lane 0 | Lane 1 | + // | 0b0000'0000'0000'00ab | 0b0000'0000'0000'00cd | + // + // Step 3: Shift the even lane and add with the odd lane. + // + // | 0b0000'0000'0000'0000'b0000'0000'0000'abcd | + // + // Thus, everything gets compressed down to 4-bits. + // + // SAFETY: Inclusion of this function is dependent on the "neon" target + // feature. This function does not access memory directly. + let value = unsafe { + let bits = vshr_n_u16(self, 15); + let paired = vsra_n_u32( + vreinterpret_u32_u16(bits), + vreinterpret_u32_u16(bits), + 15, + ); + let packed = vsra_n_u64( + vreinterpret_u64_u32(paired), + vreinterpret_u64_u32(paired), + 30, + ); + vget_lane_u8(vreinterpret_u8_u64(packed), 0) + }; + BitMask::from_int(arch, value) + } + } + } + + #[inline(always)] + fn from_mask(mask: Self::BitMask) -> Self { + const BIT_SELECTOR: u64 = 0x0008_0004_0002_0001; + // SAFETY: Inclusion of this function is dependent on the "neon" target + // feature. This function does not access memory directly. + unsafe { vtst_u16(vmov_n_u16(mask.0 as u16), vcreate_u16(BIT_SELECTOR)) } + } + + #[inline(always)] + fn keep_first(_arch: Neon, lanes: usize) -> Self { + const INDICES: u64 = 0x0003_0002_0001_0000; + let n = lanes.min(4) as u16; + // SAFETY: Inclusion of this function is dependent on the "neon" target + // feature. This function does not access memory directly. + unsafe { vclt_u16(vcreate_u16(INDICES), vmov_n_u16(n)) } + } + + #[cfg(any(test, miri))] + fn to_array(self) -> Self::Array { + // SAFETY: Both the source and destination types are trivially destructible and are + // valid for all possible bit-patterns. + unsafe { std::mem::transmute::(self) } + } +} + +impl MaskOps for uint16x8_t { + type BitMask = BitMask<8, Neon>; + type Array = [u16; 8]; + + #[inline(always)] + fn move_mask(self, arch: Neon) -> Self::BitMask { + cfg_if::cfg_if! { + if #[cfg(miri)] { + let array = self.to_array(); + BitMask::from_fn(arch, |i| array[i] == u16::MAX) + } else { + // SAFETY: Inclusion of this function is dependent on the "neon" target + // feature. This function does not access memory directly. + let value = unsafe { + // Effectively creates [-15, -14, -13, -12, -11, -10, -9, -8] + let shifts = vcombine_s16( + vcreate_s16(0xFFF4_FFF3_FFF2_FFF1), + vcreate_s16(0xFFF8_FFF7_FFF6_FFF5), + ); + let mask = vmovq_n_u16(0x8000); + vaddlvq_u16(vshlq_u16(vandq_u16(self, mask), shifts)) + }; + BitMask::from_int(arch, value as u8) + } + } + } + + #[inline(always)] + fn from_mask(mask: Self::BitMask) -> Self { + const BIT_SELECTOR_LOW: u64 = 0x0008_0004_0002_0001; + const BIT_SELECTOR_HIGH: u64 = 0x0080_0040_0020_0010; + // SAFETY: Inclusion of this function is dependent on the "neon" target + // feature. This function does not access memory directly. 
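+        // `vtst` sets a lane to all-ones whenever `(a & b) != 0`, so broadcasting the
+        // bitmask and testing it against per-lane single-bit selectors expands bit `i` of
+        // the bitmask into lane `i` of the vector mask.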
+ unsafe { + vtstq_u16( + vmovq_n_u16(mask.0 as u16), + vcombine_u16( + vcreate_u16(BIT_SELECTOR_LOW), + vcreate_u16(BIT_SELECTOR_HIGH), + ), + ) + } + } + + #[inline(always)] + fn keep_first(_arch: Neon, lanes: usize) -> Self { + const LO: u64 = 0x0003_0002_0001_0000; + const HI: u64 = 0x0007_0006_0005_0004; + let n = lanes.min(8) as u16; + // SAFETY: Inclusion of this function is dependent on the "neon" target + // feature. This function does not access memory directly. + unsafe { + vcltq_u16( + vcombine_u16(vcreate_u16(LO), vcreate_u16(HI)), + vmovq_n_u16(n), + ) + } + } + + #[cfg(any(test, miri))] + fn to_array(self) -> Self::Array { + // SAFETY: Both the source and destination types are trivially destructible and are + // valid for all possible bit-patterns. + unsafe { std::mem::transmute::(self) } + } +} + +impl MaskOps for uint32x2_t { + type BitMask = BitMask<2, Neon>; + type Array = [u32; 2]; + + #[inline(always)] + fn move_mask(self, arch: Neon) -> Self::BitMask { + cfg_if::cfg_if! { + if #[cfg(miri)] { + let array = self.to_array(); + BitMask::from_fn(arch, |i| array[i] == u32::MAX) + } else { + // Normalize each lane to 0 or 1, then use shift-right-accumulate to pack + // bits into position. + // + // SAFETY: Inclusion of this function is dependent on the "neon" target + // feature. This function does not access memory directly. + let value = unsafe { + let bits = vshr_n_u32(self, 31); + let packed = vsra_n_u64( + vreinterpret_u64_u32(bits), + vreinterpret_u64_u32(bits), + 31, + ); + vget_lane_u8(vreinterpret_u8_u64(packed), 0) + }; + BitMask::from_int(arch, value) + } + } + } + + #[inline(always)] + fn from_mask(mask: Self::BitMask) -> Self { + const BIT_SELECTOR: u64 = 0x0000_0002_0000_0001; + // SAFETY: Inclusion of this function is dependent on the "neon" target + // feature. This function does not access memory directly. + unsafe { vtst_u32(vmov_n_u32(mask.0 as u32), vcreate_u32(BIT_SELECTOR)) } + } + + #[inline(always)] + fn keep_first(_arch: Neon, lanes: usize) -> Self { + const INDICES: u64 = 0x0000_0001_0000_0000; + let n = lanes.min(2) as u32; + // SAFETY: Inclusion of this function is dependent on the "neon" target + // feature. This function does not access memory directly. + unsafe { vclt_u32(vcreate_u32(INDICES), vmov_n_u32(n)) } + } + + #[cfg(any(test, miri))] + fn to_array(self) -> Self::Array { + // SAFETY: Both the source and destination types are trivially destructible and are + // valid for all possible bit-patterns. + unsafe { std::mem::transmute::(self) } + } +} + +impl MaskOps for uint32x4_t { + type BitMask = BitMask<4, Neon>; + type Array = [u32; 4]; + + #[inline(always)] + fn move_mask(self, arch: Neon) -> Self::BitMask { + cfg_if::cfg_if! { + if #[cfg(miri)] { + let array = self.to_array(); + BitMask::from_fn(arch, |i| array[i] == u32::MAX) + } else { + // Refer to the implementation for `uint16x4_t`. The approach here is + // identical, just twice as wide. + // + // SAFETY: Inclusion of this function is dependent on the "neon" target + // feature. This function does not access memory directly. + let value = unsafe { + let bits = vshrq_n_u32(self, 31); + let paired = vsraq_n_u64( + vreinterpretq_u64_u32(bits), + vreinterpretq_u64_u32(bits), + 31, + ); + // Narrow the two u64 lanes to two u32 lanes in a 64-bit register. 
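+                    // After narrowing, the second pair of bits sits at bit 32, so one more
+                    // 64-bit shift-right-accumulate (by 30) folds it down to bits 2..3,
+                    // leaving all four mask bits in the low byte.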
+ let narrowed = vmovn_u64(paired); + let packed = vsra_n_u64( + vreinterpret_u64_u32(narrowed), + vreinterpret_u64_u32(narrowed), + 30, + ); + vget_lane_u8(vreinterpret_u8_u64(packed), 0) + }; + BitMask::from_int(arch, value) + } + } + } + + #[inline(always)] + fn from_mask(mask: Self::BitMask) -> Self { + const BIT_SELECTOR_LOW: u64 = 0x0000_0002_0000_0001; + const BIT_SELECTOR_HIGH: u64 = 0x0000_0008_0000_0004; + // SAFETY: Inclusion of this function is dependent on the "neon" target + // feature. This function does not access memory directly. + unsafe { + vtstq_u32( + vmovq_n_u32(mask.0 as u32), + vcombine_u32( + vcreate_u32(BIT_SELECTOR_LOW), + vcreate_u32(BIT_SELECTOR_HIGH), + ), + ) + } + } + + #[inline(always)] + fn keep_first(_arch: Neon, lanes: usize) -> Self { + const LO: u64 = 0x0000_0001_0000_0000; + const HI: u64 = 0x0000_0003_0000_0002; + let n = lanes.min(4) as u32; + // SAFETY: Inclusion of this function is dependent on the "neon" target + // feature. This function does not access memory directly. + unsafe { + vcltq_u32( + vcombine_u32(vcreate_u32(LO), vcreate_u32(HI)), + vmovq_n_u32(n), + ) + } + } + + #[cfg(any(test, miri))] + fn to_array(self) -> Self::Array { + // SAFETY: Both the source and destination types are trivially destructible and are + // valid for all possible bit-patterns. + unsafe { std::mem::transmute::(self) } + } +} + +impl MaskOps for uint64x1_t { + type BitMask = BitMask<1, Neon>; + type Array = [u64; 1]; + + #[inline(always)] + fn move_mask(self, arch: Neon) -> Self::BitMask { + cfg_if::cfg_if! { + if #[cfg(miri)] { + let array = self.to_array(); + BitMask::from_fn(arch, |i| array[i] == u64::MAX) + } else { + // Single lane: just shift the MSB down to bit 0 and extract. + // + // SAFETY: Inclusion of this function is dependent on the "neon" target + // feature. This function does not access memory directly. + let value = unsafe { + vget_lane_u8(vreinterpret_u8_u64(vshr_n_u64(self, 63)), 0) + }; + BitMask::from_int(arch, value) + } + } + } + + #[inline(always)] + fn from_mask(mask: Self::BitMask) -> Self { + // Single lane: negation maps 0→0 and 1→0xFFFF_FFFF_FFFF_FFFF. + // SAFETY: Inclusion of this function is dependent on the "neon" target + // feature. This function does not access memory directly. + unsafe { vcreate_u64((mask.0 as u64).wrapping_neg()) } + } + + #[inline(always)] + fn keep_first(_arch: Neon, lanes: usize) -> Self { + // Single lane: negation maps 0→0 and 1→0xFFFF_FFFF_FFFF_FFFF. + // SAFETY: Inclusion of this function is dependent on the "neon" target + // feature. This function does not access memory directly. + unsafe { vcreate_u64((lanes.min(1) as u64).wrapping_neg()) } + } + + #[cfg(any(test, miri))] + fn to_array(self) -> Self::Array { + // SAFETY: Both the source and destination types are trivially destructible and are + // valid for all possible bit-patterns. + unsafe { std::mem::transmute::(self) } + } +} + +impl MaskOps for uint64x2_t { + type BitMask = BitMask<2, Neon>; + type Array = [u64; 2]; + + #[inline(always)] + fn move_mask(self, arch: Neon) -> Self::BitMask { + cfg_if::cfg_if! { + if #[cfg(miri)] { + let array = self.to_array(); + BitMask::from_fn(arch, |i| array[i] == u64::MAX) + } else { + // Normalize each lane to 0 or 1, then narrow to a 64-bit register and + // use shift-right-accumulate to combine the two bits. + // + // SAFETY: Inclusion of this function is dependent on the "neon" target + // feature. This function does not access memory directly. 
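+                // Worked example: lanes `[all-ones, 0]` shift down to `[1, 0]`, narrow
+                // to the 64-bit value `0x0000_0000_0000_0001`, and the 31-bit
+                // shift-right-accumulate leaves `0b01` in the low byte; lanes
+                // `[0, all-ones]` produce `0b10` instead.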
+ let value = unsafe { + let bits = vshrq_n_u64(self, 63); + let narrowed = vmovn_u64(bits); + let packed = vsra_n_u64( + vreinterpret_u64_u32(narrowed), + vreinterpret_u64_u32(narrowed), + 31, + ); + vget_lane_u8(vreinterpret_u8_u64(packed), 0) + }; + BitMask::from_int(arch, value) + } + } + } + + #[inline(always)] + fn from_mask(mask: Self::BitMask) -> Self { + const BIT_SELECTOR_LOW: u64 = 0x0000_0000_0000_0001; + const BIT_SELECTOR_HIGH: u64 = 0x0000_0000_0000_0002; + // SAFETY: Inclusion of this function is dependent on the "neon" target + // feature. This function does not access memory directly. + unsafe { + vtstq_u64( + vmovq_n_u64(mask.0 as u64), + vcombine_u64( + vcreate_u64(BIT_SELECTOR_LOW), + vcreate_u64(BIT_SELECTOR_HIGH), + ), + ) + } + } + + #[inline(always)] + fn keep_first(_arch: Neon, lanes: usize) -> Self { + const LO: u64 = 0; + const HI: u64 = 1; + let n = lanes.min(2) as u64; + // SAFETY: Inclusion of this function is dependent on the "neon" target + // feature. This function does not access memory directly. + unsafe { + vcltq_u64( + vcombine_u64(vcreate_u64(LO), vcreate_u64(HI)), + vmovq_n_u64(n), + ) + } + } + + #[cfg(any(test, miri))] + fn to_array(self) -> Self::Array { + // SAFETY: Both the source and destination types are trivially destructible and are + // valid for all possible bit-patterns. + unsafe { std::mem::transmute::(self) } + } +} + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + use crate::{Const, SupportedLaneCount}; + + trait MaskTraits: std::fmt::Debug { + const SET: Self; + const UNSET: Self; + } + + impl MaskTraits for u8 { + const SET: u8 = u8::MAX; + const UNSET: u8 = 0; + } + + impl MaskTraits for u16 { + const SET: u16 = u16::MAX; + const UNSET: u16 = 0; + } + + impl MaskTraits for u32 { + const SET: u32 = u32::MAX; + const UNSET: u32 = 0; + } + + impl MaskTraits for u64 { + const SET: u64 = u64::MAX; + const UNSET: u64 = 0; + } + + trait AllValues: SIMDMask { + fn all_values() -> impl Iterator::Underlying>; + } + + impl AllValues for BitMask<1, Neon> { + fn all_values() -> impl Iterator::Underlying> { + 0..2 + } + } + + impl AllValues for BitMask<2, Neon> { + fn all_values() -> impl Iterator::Underlying> { + 0..4 + } + } + + impl AllValues for BitMask<4, Neon> { + fn all_values() -> impl Iterator::Underlying> { + 0..16 + } + } + + impl AllValues for BitMask<8, Neon> { + fn all_values() -> impl Iterator::Underlying> { + 0..=u8::MAX + } + } + + impl AllValues for BitMask<16, Neon> { + fn all_values() -> impl Iterator::Underlying> { + 0..=u16::MAX + } + } + + fn test_mask() + where + Const: SupportedLaneCount, + BitMask: SIMDMask + AllValues + From, + T: MaskTraits + PartialEq + Copy, + M: SIMDMask> + From>, + ::Underlying: MaskOps, Array = [T; N]>, + { + let arch = Neon::new_checked().unwrap(); + + // Test keep-first. + for i in 0..N + 5 { + let m = M::keep_first(arch, i); + + // Inspect the underlying mask. + let a = m.to_underlying().to_array(); + assert_eq!(a.len(), N); + for (j, v) in a.into_iter().enumerate() { + if j < i { + assert_eq!( + v, + T::SET, + "expected lane {} of keep_first({}) to be {:?}. Instead, it is {:?}", + j, + i, + T::SET, + v + ); + } else { + assert_eq!( + v, + T::UNSET, + "expected lane {} of keep_first({}) to be {:?}. Instead, it is {:?}", + j, + i, + T::UNSET, + v + ); + } + } + + // Inspect the bitmask. + assert_eq!(m.bitmask(), BitMask::::keep_first(arch, i)); + } + + // Test all bitmask precursors. 
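+        //
+        // That is, every representable underlying integer for an `N`-lane bitmask
+        // (`0..2^N`, as enumerated by `AllValues`): each value is expanded into a full
+        // lane mask and then round-tripped back to a bitmask.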
+ for v in BitMask::::all_values() { + let bitmask = BitMask::::from_underlying(arch, v); + let mask = >>::from(bitmask); + + assert_eq!(BitMask::::from(mask), bitmask); + let a = mask.to_underlying().to_array(); + assert_eq!(a.len(), N); + for (j, v) in a.into_iter().enumerate() { + if bitmask.get_unchecked(j) { + assert_eq!( + v, + T::SET, + "expected lane {} to be {:?}. Instead, it is {:?}", + j, + T::SET, + v + ); + } else { + assert_eq!( + v, + T::UNSET, + "expected lane {} to be {:?}. Instead, it is {:?}", + j, + T::UNSET, + v + ); + } + } + } + } + + #[test] + fn test_mask8x8() { + test_mask::(); + } + + #[cfg(not(miri))] + #[test] + fn test_mask8x16() { + test_mask::(); + } + + #[test] + fn test_mask16x4() { + test_mask::(); + } + + #[test] + fn test_mask16x8() { + test_mask::(); + } + + #[test] + fn test_mask32x2() { + test_mask::(); + } + + #[test] + fn test_mask32x4() { + test_mask::(); + } + + #[test] + fn test_mask64x1() { + test_mask::(); + } + + #[test] + fn test_mask64x2() { + test_mask::(); + } +} diff --git a/diskann-wide/src/arch/aarch64/mod.rs b/diskann-wide/src/arch/aarch64/mod.rs new file mode 100644 index 000000000..74df594f9 --- /dev/null +++ b/diskann-wide/src/arch/aarch64/mod.rs @@ -0,0 +1,382 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ + +use crate::{ + Architecture, SIMDVector, + arch::{ + self, AddLifetime, Dispatched1, Dispatched2, Dispatched3, FTarget1, FTarget2, FTarget3, + Hidden, Scalar, Target, Target1, Target2, Target3, + }, +}; + +pub mod f16x4_; +pub use f16x4_::f16x4; + +pub mod f16x8_; +pub use f16x8_::f16x8; + +pub mod f32x2_; +pub use f32x2_::f32x2; + +pub mod f32x4_; +pub use f32x4_::f32x4; + +// Unsigned +pub mod u8x8_; +pub use u8x8_::u8x8; + +pub mod u8x16_; +pub use u8x16_::u8x16; + +pub mod u16x8_; +pub use u16x8_::u16x8; + +pub mod u32x4_; +pub use u32x4_::u32x4; + +pub mod u64x2_; +pub use u64x2_::u64x2; + +// Signed +pub mod i8x8_; +pub use i8x8_::i8x8; + +pub mod i8x16_; +pub use i8x16_::i8x16; + +pub mod i16x8_; +pub use i16x8_::i16x8; + +pub mod i32x4_; +pub use i32x4_::i32x4; + +pub mod i64x2_; +pub use i64x2_::i64x2; + +// Extra wide types. +pub mod double; + +pub use double::f16x16; + +pub use double::f32x8; +pub use double::f32x16; + +pub use double::i8x32; +pub use double::i8x64; + +pub use double::i16x16; +pub use double::i16x32; + +pub use double::i32x8; +pub use double::i32x16; + +pub use double::u8x32; +pub use double::u8x64; + +pub use double::u32x8; +pub use double::u32x16; + +pub use double::u64x4; + +// Internal helpers +mod macros; +mod masks; + +// The ordering is `Scalar < V3 < V4`. +#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)] +pub(super) enum LevelInner { + Scalar, + Neon, +} + +///////////// +// Current // +///////////// + +cfg_if::cfg_if! { + if #[cfg(all(target_feature = "neon", target_feature = "dotprod"))] { + pub type Current = Neon; + + pub const fn current() -> Current { + // SAFETY: Requirements are checked at compile time. 
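+            //
+            // Concretely, the enclosing `cfg_if` only selects this branch when both
+            // `neon` and `dotprod` are compile-time target features, so the unchecked
+            // constructor below cannot be reached on hardware lacking them.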
+ unsafe { Neon::new() } + } + } else { + pub type Current = Scalar; + + pub const fn current() -> Current { + Scalar::new() + } + } +} + +///////////////// +// Dispatching // +///////////////// + +pub fn dispatch(f: T) -> R +where + T: Target + Target, +{ + if let Some(arch) = Neon::new_checked() { + arch.run(f) + } else { + Scalar::new().run(f) + } +} + +pub fn dispatch_no_features(f: T) -> R +where + T: Target + Target, +{ + dispatch(f) +} + +pub fn dispatch1(f: T, x0: T0) -> R +where + T: Target1 + Target1, +{ + if let Some(arch) = Neon::new_checked() { + arch.run1(f, x0) + } else { + Scalar::new().run1(f, x0) + } +} + +pub fn dispatch1_no_features(f: T, x0: T0) -> R +where + T: Target1 + Target1, +{ + dispatch1(f, x0) +} + +pub fn dispatch2(f: T, x0: T0, x1: T1) -> R +where + T: Target2 + Target2, +{ + if let Some(arch) = Neon::new_checked() { + arch.run2(f, x0, x1) + } else { + Scalar::new().run2(f, x0, x1) + } +} + +pub fn dispatch2_no_features(f: T, x0: T0, x1: T1) -> R +where + T: Target2 + Target2, +{ + dispatch2(f, x0, x1) +} + +pub fn dispatch3(f: T, x0: T0, x1: T1, x2: T2) -> R +where + T: Target3 + Target3, +{ + if let Some(arch) = Neon::new_checked() { + arch.run3(f, x0, x1, x2) + } else { + Scalar::new().run3(f, x0, x1, x2) + } +} + +pub fn dispatch3_no_features(f: T, x0: T0, x1: T1, x2: T2) -> R +where + T: Target3 + Target3, +{ + dispatch3(f, x0, x1, x2) +} + +////////////////// +// Architecture // +////////////////// + +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct Neon(Hidden); + +impl arch::sealed::Sealed for Neon {} + +impl Neon { + /// Construct a new `Neon` architecture struct. + /// + /// # Safety + /// + /// To avoid undefined behavior, this function must only be called on a machine that + /// supports following features + /// + /// * `neon` + /// * `dotprod` + pub const unsafe fn new() -> Self { + Self(Hidden) + } + + /// Construct a new `Neon` architecture if it is safe to do so on the current hardware. + pub fn new_checked() -> Option { + // This check here ensure that if we ever switch to dynamically dispatching to + // `Neon` that we do not forget to update `new_checked`. + if cfg!(all(target_feature = "neon", target_feature = "dotprod")) { + // SAFETY: The compile erorr check above ensure we do not accidentally return an + // unsafe instance of `Self`. + Some(unsafe { Self::new() }) + } else { + None + } + } + + /// Retarget the [`Scalar`] architecture. 
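+    //
+    // Unlike the x86_64 backend, this backend has no intermediate SIMD tier
+    // (`LevelInner` is only `Scalar` or `Neon`), so retargeting always drops all the
+    // way down to `Scalar`.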
+ pub const fn retarget(self) -> Scalar { + Scalar::new() + } + + fn run_function_with_1(self, x0: T0::Of<'_>) -> R + where + T0: AddLifetime, + F: for<'a> FTarget1>, + { + F::run(self, x0) + } + + fn run_function_with_2(self, x0: T0::Of<'_>, x1: T1::Of<'_>) -> R + where + T0: AddLifetime, + T1: AddLifetime, + F: for<'a, 'b> FTarget2, T1::Of<'b>>, + { + F::run(self, x0, x1) + } + + fn run_function_with_3( + self, + x0: T0::Of<'_>, + x1: T1::Of<'_>, + x2: T2::Of<'_>, + ) -> R + where + T0: AddLifetime, + T1: AddLifetime, + T2: AddLifetime, + F: for<'a, 'b, 'c> FTarget3, T1::Of<'b>, T2::Of<'c>>, + { + F::run(self, x0, x1, x2) + } +} + +impl From for Scalar { + fn from(neon: Neon) -> Self { + neon.retarget() + } +} + +impl arch::Architecture for Neon { + arch::maskdef!(); + arch::typedef!(); + + fn level() -> arch::Level { + arch::Level::neon() + } + + fn run(self, f: F) -> R + where + F: Target, + { + f.run(self) + } + + fn run1(self, f: F, x0: T0) -> R + where + F: Target1, + { + f.run(self, x0) + } + + fn run2(self, f: F, x0: T0, x1: T1) -> R + where + F: Target2, + { + f.run(self, x0, x1) + } + + fn run3(self, f: F, x0: T0, x1: T1, x2: T2) -> R + where + F: Target3, + { + f.run(self, x0, x1, x2) + } + + #[inline(always)] + fn run_inline(self, f: F) -> R + where + F: Target, + { + f.run(self) + } + + #[inline(always)] + fn run1_inline(self, f: F, x0: T0) -> R + where + F: Target1, + { + f.run(self, x0) + } + + #[inline(always)] + fn run2_inline(self, f: F, x0: T0, x1: T1) -> R + where + F: Target2, + { + f.run(self, x0, x1) + } + + #[inline(always)] + fn run3_inline(self, f: F, x0: T0, x1: T1, x2: T2) -> R + where + F: Target3, + { + f.run(self, x0, x1, x2) + } + + fn dispatch1(self) -> Dispatched1 + where + T0: AddLifetime, + F: for<'a> FTarget1>, + { + let f: unsafe fn(Self, T0::Of<'_>) -> R = Self::run_function_with_1::; + + // SAFETY: The present of `self` as an argument attests that it is safe to construct + // a `Neon` architecture. Additionally, since `V3` is a `Copy` zero-sized type, + // it is safe to wink into existence and is ABI compattible with `Hidden`. + unsafe { arch::hide1(f) } + } + + fn dispatch2(self) -> Dispatched2 + where + T0: AddLifetime, + T1: AddLifetime, + F: for<'a, 'b> FTarget2, T1::Of<'b>>, + { + let f: unsafe fn(Self, T0::Of<'_>, T1::Of<'_>) -> R = + Self::run_function_with_2::; + + // SAFETY: The present of `self` as an argument attests that it is safe to construct + // a `Neon` architecture. Additionally, since `V3` is a `Copy` zero-sized type, + // it is safe to wink into existence and is ABI compattible with `Hidden`. + unsafe { arch::hide2(f) } + } + + fn dispatch3(self) -> Dispatched3 + where + T0: AddLifetime, + T1: AddLifetime, + T2: AddLifetime, + F: for<'a, 'b, 'c> FTarget3, T1::Of<'b>, T2::Of<'c>>, + { + let f: unsafe fn(Self, T0::Of<'_>, T1::Of<'_>, T2::Of<'_>) -> R = + Self::run_function_with_3::; + + // SAFETY: The present of `self` as an argument attests that it is safe to construct + // A `Neon` architecture. Additionally, since `V3` is a `Copy` zero-sized type, + // it is safe to wink into existence and is ABI compattible with `Hidden`. + unsafe { arch::hide3(f) } + } +} diff --git a/diskann-wide/src/arch/aarch64/u16x8_.rs b/diskann-wide/src/arch/aarch64/u16x8_.rs new file mode 100644 index 000000000..bc5e941b1 --- /dev/null +++ b/diskann-wide/src/arch/aarch64/u16x8_.rs @@ -0,0 +1,97 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ + +use crate::{ + Emulated, + constant::Const, + helpers, + traits::{SIMDMask, SIMDMulAdd, SIMDPartialEq, SIMDPartialOrd, SIMDVector}, +}; + +// AArch64 masks +use super::{ + Neon, + macros::{self, AArchLoadStore, AArchSplat}, + masks::mask16x8, +}; + +// AArch64 intrinsics +use std::arch::aarch64::*; + +///////////////////// +// 16-bit unsigned // +///////////////////// + +macros::aarch64_define_register!(u16x8, uint16x8_t, mask16x8, u16, 8, Neon); +macros::aarch64_define_splat!(u16x8, vmovq_n_u16); +macros::aarch64_define_loadstore!(u16x8, vld1q_u16, vst1q_u16, 8); + +helpers::unsafe_map_binary_op!(u16x8, std::ops::Add, add, vaddq_u16, "neon"); +helpers::unsafe_map_binary_op!(u16x8, std::ops::Sub, sub, vsubq_u16, "neon"); +helpers::unsafe_map_binary_op!(u16x8, std::ops::Mul, mul, vmulq_u16, "neon"); +macros::aarch64_define_fma!(u16x8, vmlaq_u16); + +macros::aarch64_define_cmp!( + u16x8, + vceqq_u16, + (vmvnq_u16), + vcltq_u16, + vcleq_u16, + vcgtq_u16, + vcgeq_u16 +); +macros::aarch64_define_bitops!( + u16x8, + vmvnq_u16, + vandq_u16, + vorrq_u16, + veorq_u16, + ( + vshlq_u16, + 16, + vnegq_s16, + vminq_u16, + vreinterpretq_s16_u16, + std::convert::identity + ), + (u16, i16, vmovq_n_s16), +); + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + use crate::{reference::ReferenceScalarOps, test_utils}; + + #[test] + fn miri_test_load() { + test_utils::test_load_simd::(Neon::new_checked().unwrap()); + } + + #[test] + fn miri_test_store() { + test_utils::test_store_simd::(Neon::new_checked().unwrap()); + } + + // constructors + #[test] + fn test_constructors() { + test_utils::ops::test_splat::(Neon::new_checked().unwrap()); + } + + // Ops + test_utils::ops::test_add!(u16x8, 0x3017fd73c99cc633, Neon::new_checked()); + test_utils::ops::test_sub!(u16x8, 0xfc627f10b5f8db8a, Neon::new_checked()); + test_utils::ops::test_mul!(u16x8, 0x0f4caa80eceaa523, Neon::new_checked()); + test_utils::ops::test_fma!(u16x8, 0xb8f702ba85375041, Neon::new_checked()); + + test_utils::ops::test_cmp!(u16x8, 0x941757bd5cc641a1, Neon::new_checked()); + + // Bit ops + test_utils::ops::test_bitops!(u16x8, 0xd62d8de09f82ed4e, Neon::new_checked()); +} diff --git a/diskann-wide/src/arch/aarch64/u32x4_.rs b/diskann-wide/src/arch/aarch64/u32x4_.rs new file mode 100644 index 000000000..58d001390 --- /dev/null +++ b/diskann-wide/src/arch/aarch64/u32x4_.rs @@ -0,0 +1,162 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ + +use crate::{ + Emulated, SIMDDotProduct, SIMDMask, SIMDMulAdd, SIMDPartialEq, SIMDPartialOrd, SIMDSelect, + SIMDSumTree, SIMDVector, constant::Const, helpers, +}; + +// AArch64 masks +use super::{ + Neon, + macros::{self, AArchLoadStore, AArchSplat}, + masks::mask32x4, + u8x16, +}; + +// AArch64 intrinsics +use std::arch::{aarch64::*, asm}; + +///////////////////// +// 32-bit unsigned // +///////////////////// + +macros::aarch64_define_register!(u32x4, uint32x4_t, mask32x4, u32, 4, Neon); +macros::aarch64_define_splat!(u32x4, vmovq_n_u32); +macros::aarch64_define_loadstore!(u32x4, vld1q_u32, vst1q_u32, 4); + +helpers::unsafe_map_binary_op!(u32x4, std::ops::Add, add, vaddq_u32, "neon"); +helpers::unsafe_map_binary_op!(u32x4, std::ops::Sub, sub, vsubq_u32, "neon"); +helpers::unsafe_map_binary_op!(u32x4, std::ops::Mul, mul, vmulq_u32, "neon"); +macros::aarch64_define_fma!(u32x4, vmlaq_u32); + +macros::aarch64_define_cmp!( + u32x4, + vceqq_u32, + (vmvnq_u32), + vcltq_u32, + vcleq_u32, + vcgtq_u32, + vcgeq_u32 +); +macros::aarch64_define_bitops!( + u32x4, + vmvnq_u32, + vandq_u32, + vorrq_u32, + veorq_u32, + ( + vshlq_u32, + 32, + vnegq_s32, + vminq_u32, + vreinterpretq_s32_u32, + std::convert::identity + ), + (u32, i32, vmovq_n_s32), +); + +impl SIMDSumTree for u32x4 { + #[inline(always)] + fn sum_tree(self) -> u32 { + if cfg!(miri) { + self.emulated().sum_tree() + } else { + // SAFETY: Allowed by the `Neon` architecture. + unsafe { vaddvq_u32(self.0) } + } + } +} + +impl SIMDSelect for mask32x4 { + #[inline(always)] + fn select(self, x: u32x4, y: u32x4) -> u32x4 { + // SAFETY: Allowed by the `Neon` architecture. + u32x4(unsafe { vbslq_u32(self.0, x.0, y.0) }) + } +} + +impl SIMDDotProduct for u32x4 { + #[inline(always)] + fn dot_simd(self, left: u8x16, right: u8x16) -> Self { + if cfg!(miri) { + use crate::AsSIMD; + self.emulated() + .dot_simd(left.emulated(), right.emulated()) + .as_simd(self.arch()) + } else { + // SAFETY: Instantiating `Neon` implies `dotprod`. + // + // We need this wrapper to allow compilation of the underlying ASM when compiling + // without the `dotprod` feature globally enabled. + #[target_feature(enable = "dotprod")] + unsafe fn udot(mut s: uint32x4_t, x: uint8x16_t, y: uint8x16_t) -> uint32x4_t { + // SAFETY: The `Neon` architecture implies `dotprod`, allowing us to use + // this intrinsic. + unsafe { + asm!( + "udot {0:v}.4s, {1:v}.16b, {2:v}.16b", + inout(vreg) s, + in(vreg) x, + in(vreg) y, + options(pure, nomem, nostack) + ); + } + + s + } + + // SAFETY: The `Neon` architecture guarantees the `dotprod` feature. 
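+            //
+            // `udot` accumulates a 4-way dot product per 32-bit lane: for lane `i` it
+            // computes `s[i] + x[4i]*y[4i] + x[4i+1]*y[4i+1] + x[4i+2]*y[4i+2] +
+            // x[4i+3]*y[4i+3]`, which is exactly what the `Emulated` fallback in the
+            // `miri` branch above computes.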
+ Self::from_underlying(self.arch(), unsafe { udot(self.0, left.0, right.0) }) + } + } +} + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + use crate::{reference::ReferenceScalarOps, test_utils}; + + #[test] + fn miri_test_load() { + test_utils::test_load_simd::(Neon::new_checked().unwrap()); + } + + #[test] + fn miri_test_store() { + test_utils::test_store_simd::(Neon::new_checked().unwrap()); + } + + // constructors + #[test] + fn test_constructors() { + test_utils::ops::test_splat::(Neon::new_checked().unwrap()); + } + + // Ops + test_utils::ops::test_add!(u32x4, 0x3017fd73c99cc633, Neon::new_checked()); + test_utils::ops::test_sub!(u32x4, 0xfc627f10b5f8db8a, Neon::new_checked()); + test_utils::ops::test_mul!(u32x4, 0x0f4caa80eceaa523, Neon::new_checked()); + test_utils::ops::test_fma!(u32x4, 0xb8f702ba85375041, Neon::new_checked()); + + test_utils::ops::test_cmp!(u32x4, 0x941757bd5cc641a1, Neon::new_checked()); + + // Dot Product + test_utils::dot_product::test_dot_product!( + (u8x16, u8x16) => u32x4, + 0x145f89b446c03ff1, + Neon::new_checked() + ); + + // Bit ops + test_utils::ops::test_bitops!(u32x4, 0xd62d8de09f82ed4e, Neon::new_checked()); + + // Reductions + test_utils::ops::test_sumtree!(u32x4, 0xb9ac82ab23a855da, Neon::new_checked()); +} diff --git a/diskann-wide/src/arch/aarch64/u64x2_.rs b/diskann-wide/src/arch/aarch64/u64x2_.rs new file mode 100644 index 000000000..45e1fbaab --- /dev/null +++ b/diskann-wide/src/arch/aarch64/u64x2_.rs @@ -0,0 +1,127 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ + +use crate::{ + Emulated, + arch::Scalar, + constant::Const, + helpers, + traits::{SIMDMask, SIMDMulAdd, SIMDPartialEq, SIMDPartialOrd, SIMDVector}, +}; + +// AArch64 masks +use super::{ + Neon, + macros::{self, AArchLoadStore, AArchSplat}, + masks::mask64x2, +}; + +// AArch64 intrinsics +use std::arch::aarch64::*; + +///////////////////// +// 64-bit unsigned // +///////////////////// + +#[inline(always)] +pub(super) unsafe fn emulated_vmvnq_u64(x: uint64x2_t) -> uint64x2_t { + let x: [u64; 2] = u64x2(x).to_array(); + let mapped: [u64; 2] = core::array::from_fn(|i| !x[i]); + // SAFETY: This is only called in a context where the caller guarantees `Neon` is + // available. + u64x2::from_array(unsafe { Neon::new() }, mapped).0 +} + +#[inline(always)] +pub(super) unsafe fn emulated_vminq_u64(x: uint64x2_t, y: uint64x2_t) -> uint64x2_t { + let x = u64x2(x).to_array(); + let y = u64x2(y).to_array(); + let mapped: [u64; 2] = core::array::from_fn(|i| x[i].min(y[i])); + // SAFETY: This is only called in a context where the caller guarantees `Neon` is + // available. 
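+    //
+    // (NEON provides no 64-bit element-wise integer minimum and no `u64`-typed
+    // bitwise-NOT intrinsic, which is why this helper and `emulated_vmvnq_u64` above
+    // go through per-lane scalar code before repacking into a register here.)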
+ u64x2::from_array(unsafe { Neon::new() }, mapped).0 +} + +macros::aarch64_define_register!(u64x2, uint64x2_t, mask64x2, u64, 2, Neon); +macros::aarch64_define_splat!(u64x2, vmovq_n_u64); +macros::aarch64_define_loadstore!(u64x2, vld1q_u64, vst1q_u64, 2); + +helpers::unsafe_map_binary_op!(u64x2, std::ops::Add, add, vaddq_u64, "neon"); +helpers::unsafe_map_binary_op!(u64x2, std::ops::Sub, sub, vsubq_u64, "neon"); + +impl std::ops::Mul for u64x2 { + type Output = Self; + #[inline(always)] + fn mul(self, rhs: Self) -> Self { + let x = Emulated::::from_array(Scalar, self.to_array()); + let y = Emulated::::from_array(Scalar, rhs.to_array()); + Self::from_array(self.arch(), (x * y).to_array()) + } +} + +macros::aarch64_define_fma!(u64x2, integer); + +macros::aarch64_define_cmp!( + u64x2, + vceqq_u64, + (emulated_vmvnq_u64), + vcltq_u64, + vcleq_u64, + vcgtq_u64, + vcgeq_u64 +); +macros::aarch64_define_bitops!( + u64x2, + emulated_vmvnq_u64, + vandq_u64, + vorrq_u64, + veorq_u64, + ( + vshlq_u64, + 64, + vnegq_s64, + emulated_vminq_u64, + vreinterpretq_s64_u64, + std::convert::identity + ), + (u64, i64, vmovq_n_s64), +); + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + use crate::{reference::ReferenceScalarOps, test_utils}; + + #[test] + fn miri_test_load() { + test_utils::test_load_simd::(Neon::new_checked().unwrap()); + } + + #[test] + fn miri_test_store() { + test_utils::test_store_simd::(Neon::new_checked().unwrap()); + } + + // constructors + #[test] + fn test_constructors() { + test_utils::ops::test_splat::(Neon::new_checked().unwrap()); + } + + // Binary Ops + test_utils::ops::test_add!(u64x2, 0x8d7bf28b1c6e2545, Neon::new_checked()); + test_utils::ops::test_sub!(u64x2, 0x4a1c644a1a910bed, Neon::new_checked()); + test_utils::ops::test_mul!(u64x2, 0xf42ee707a808fd10, Neon::new_checked()); + test_utils::ops::test_fma!(u64x2, 0x28540d9936a9e803, Neon::new_checked()); + + test_utils::ops::test_cmp!(u64x2, 0xfae27072c6b70885, Neon::new_checked()); + + // Bit ops + test_utils::ops::test_bitops!(u64x2, 0xbe927713ea310164, Neon::new_checked()); +} diff --git a/diskann-wide/src/arch/aarch64/u8x16_.rs b/diskann-wide/src/arch/aarch64/u8x16_.rs new file mode 100644 index 000000000..7933f9437 --- /dev/null +++ b/diskann-wide/src/arch/aarch64/u8x16_.rs @@ -0,0 +1,100 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ + +use crate::{ + Emulated, + constant::Const, + helpers, + traits::{SIMDMask, SIMDMulAdd, SIMDPartialEq, SIMDPartialOrd, SIMDVector}, +}; + +// AArch64 masks +use super::{ + Neon, + macros::{self, AArchLoadStore, AArchSplat}, + masks::mask8x16, + u8x8, +}; + +// AArch64 intrinsics +use std::arch::aarch64::*; + +//////////////////// +// 8-bit unsigned // +//////////////////// + +macros::aarch64_define_register!(u8x16, uint8x16_t, mask8x16, u8, 16, Neon); +macros::aarch64_define_splat!(u8x16, vmovq_n_u8); +macros::aarch64_define_loadstore!(u8x16, vld1q_u8, vst1q_u8, 16); +macros::aarch64_splitjoin!(u8x16, u8x8, vget_low_u8, vget_high_u8, vcombine_u8); + +helpers::unsafe_map_binary_op!(u8x16, std::ops::Add, add, vaddq_u8, "neon"); +helpers::unsafe_map_binary_op!(u8x16, std::ops::Sub, sub, vsubq_u8, "neon"); +helpers::unsafe_map_binary_op!(u8x16, std::ops::Mul, mul, vmulq_u8, "neon"); +macros::aarch64_define_fma!(u8x16, vmlaq_u8); + +macros::aarch64_define_cmp!( + u8x16, + vceqq_u8, + (vmvnq_u8), + vcltq_u8, + vcleq_u8, + vcgtq_u8, + vcgeq_u8 +); +macros::aarch64_define_bitops!( + u8x16, + vmvnq_u8, + vandq_u8, + vorrq_u8, + veorq_u8, + ( + vshlq_u8, + 8, + vnegq_s8, + vminq_u8, + vreinterpretq_s8_u8, + std::convert::identity + ), + (u8, i8, vmovq_n_s8), +); + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + use crate::{reference::ReferenceScalarOps, test_utils}; + + #[test] + fn miri_test_load() { + test_utils::test_load_simd::(Neon::new_checked().unwrap()); + } + + #[test] + fn miri_test_store() { + test_utils::test_store_simd::(Neon::new_checked().unwrap()); + } + + // constructors + #[test] + fn test_constructors() { + test_utils::ops::test_splat::(Neon::new_checked().unwrap()); + } + + // Ops + test_utils::ops::test_add!(u8x16, 0x3017fd73c99cc633, Neon::new_checked()); + test_utils::ops::test_sub!(u8x16, 0xfc627f10b5f8db8a, Neon::new_checked()); + test_utils::ops::test_mul!(u8x16, 0x0f4caa80eceaa523, Neon::new_checked()); + test_utils::ops::test_fma!(u8x16, 0xb8f702ba85375041, Neon::new_checked()); + test_utils::ops::test_splitjoin!(u8x16 => u8x8, 0xa4d00a4d04293967, Neon::new_checked()); + + test_utils::ops::test_cmp!(u8x16, 0x941757bd5cc641a1, Neon::new_checked()); + + // Bit ops + test_utils::ops::test_bitops!(u8x16, 0xd62d8de09f82ed4e, Neon::new_checked()); +} diff --git a/diskann-wide/src/arch/aarch64/u8x8_.rs b/diskann-wide/src/arch/aarch64/u8x8_.rs new file mode 100644 index 000000000..914000037 --- /dev/null +++ b/diskann-wide/src/arch/aarch64/u8x8_.rs @@ -0,0 +1,89 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ + +use crate::{ + Emulated, + constant::Const, + helpers, + traits::{SIMDMask, SIMDMulAdd, SIMDPartialEq, SIMDPartialOrd, SIMDVector}, +}; + +// AArch64 masks +use super::{ + Neon, + macros::{self, AArchLoadStore, AArchSplat}, + masks::mask8x8, +}; + +// AArch64 intrinsics +use std::arch::aarch64::*; + +//////////////////// +// 8-bit unsigned // +//////////////////// + +macros::aarch64_define_register!(u8x8, uint8x8_t, mask8x8, u8, 8, Neon); +macros::aarch64_define_splat!(u8x8, vmov_n_u8); +macros::aarch64_define_loadstore!(u8x8, vld1_u8, vst1_u8, 8); + +helpers::unsafe_map_binary_op!(u8x8, std::ops::Add, add, vadd_u8, "neon"); +helpers::unsafe_map_binary_op!(u8x8, std::ops::Sub, sub, vsub_u8, "neon"); +helpers::unsafe_map_binary_op!(u8x8, std::ops::Mul, mul, vmul_u8, "neon"); +macros::aarch64_define_fma!(u8x8, vmla_u8); + +macros::aarch64_define_cmp!(u8x8, vceq_u8, (vmvn_u8), vclt_u8, vcle_u8, vcgt_u8, vcge_u8); +macros::aarch64_define_bitops!( + u8x8, + vmvn_u8, + vand_u8, + vorr_u8, + veor_u8, + ( + vshl_u8, + 8, + vneg_s8, + vmin_u8, + vreinterpret_s8_u8, + std::convert::identity + ), + (u8, i8, vmov_n_s8), +); + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + use crate::{reference::ReferenceScalarOps, test_utils}; + + #[test] + fn miri_test_load() { + test_utils::test_load_simd::(Neon::new_checked().unwrap()); + } + + #[test] + fn miri_test_store() { + test_utils::test_store_simd::(Neon::new_checked().unwrap()); + } + + // constructors + #[test] + fn test_constructors() { + test_utils::ops::test_splat::(Neon::new_checked().unwrap()); + } + + // Ops + test_utils::ops::test_add!(u8x8, 0x3017fd73c99cc633, Neon::new_checked()); + test_utils::ops::test_sub!(u8x8, 0xfc627f10b5f8db8a, Neon::new_checked()); + test_utils::ops::test_mul!(u8x8, 0x0f4caa80eceaa523, Neon::new_checked()); + test_utils::ops::test_fma!(u8x8, 0xb8f702ba85375041, Neon::new_checked()); + + test_utils::ops::test_cmp!(u8x8, 0x941757bd5cc641a1, Neon::new_checked()); + + // Bit ops + test_utils::ops::test_bitops!(u8x8, 0xd62d8de09f82ed4e, Neon::new_checked()); +} diff --git a/diskann-wide/src/arch/mod.rs b/diskann-wide/src/arch/mod.rs index 4b030c9a3..9ec5a3573 100644 --- a/diskann-wide/src/arch/mod.rs +++ b/diskann-wide/src/arch/mod.rs @@ -409,6 +409,30 @@ cfg_if::cfg_if! { Self(LevelInner::V4) } } + } else if #[cfg(target_arch = "aarch64")] { + // Delegate to the architecture selection within the `aarch64` module,. + pub mod aarch64; + + use aarch64::LevelInner; + + pub use aarch64::current; + pub use aarch64::Current; + + pub use aarch64::dispatch; + pub use aarch64::dispatch1; + pub use aarch64::dispatch2; + pub use aarch64::dispatch3; + + pub use aarch64::dispatch_no_features; + pub use aarch64::dispatch1_no_features; + pub use aarch64::dispatch2_no_features; + pub use aarch64::dispatch3_no_features; + + impl Level { + const fn neon() -> Self { + Self(LevelInner::Neon) + } + } } else { pub type Current = Scalar; diff --git a/diskann-wide/src/doubled.rs b/diskann-wide/src/doubled.rs index 74e529c86..66e69be65 100644 --- a/diskann-wide/src/doubled.rs +++ b/diskann-wide/src/doubled.rs @@ -377,6 +377,7 @@ macro_rules! double_mask { ) } + #[inline(always)] fn get_unchecked(&self, i: usize) -> bool { if i < { $N / 2 } { self.0.get_unchecked(i) @@ -385,6 +386,7 @@ macro_rules! 
double_mask { } } + #[inline(always)] fn keep_first(arch: Self::Arch, i: usize) -> Self { let lo = <$repr>::keep_first(arch, i); let hi = <$repr>::keep_first(arch, i.saturating_sub({ $N / 2 })); diff --git a/diskann-wide/src/emulated.rs b/diskann-wide/src/emulated.rs index f9a454698..2554527a3 100644 --- a/diskann-wide/src/emulated.rs +++ b/diskann-wide/src/emulated.rs @@ -437,12 +437,62 @@ macro_rules! impl_simd_dot_product_iu8_to_i32 { self.dot_simd(right, left) } } + + impl SIMDDotProduct, Emulated> + for Emulated + where + A: arch::Sealed, + { + fn dot_simd(self, left: Emulated, right: Emulated) -> Self { + self + Self::from_arch_fn(self.1, |i| { + let l0: u32 = left.0[4 * i].into(); + let l1: u32 = left.0[4 * i + 1].into(); + let l2: u32 = left.0[4 * i + 2].into(); + let l3: u32 = left.0[4 * i + 3].into(); + + let r0: u32 = right.0[4 * i].into(); + let r1: u32 = right.0[4 * i + 1].into(); + let r2: u32 = right.0[4 * i + 2].into(); + let r3: u32 = right.0[4 * i + 3].into(); + + let a = l0.expected_fma_(r0, l1.expected_mul_(r1)); + let b = l2.expected_fma_(r2, l3.expected_mul_(r3)); + a + b + }) + } + } + + impl SIMDDotProduct, Emulated> + for Emulated + where + A: arch::Sealed, + { + fn dot_simd(self, left: Emulated, right: Emulated) -> Self { + self + Self::from_arch_fn(self.1, |i| { + let l0: i32 = left.0[4 * i].into(); + let l1: i32 = left.0[4 * i + 1].into(); + let l2: i32 = left.0[4 * i + 2].into(); + let l3: i32 = left.0[4 * i + 3].into(); + + let r0: i32 = right.0[4 * i].into(); + let r1: i32 = right.0[4 * i + 1].into(); + let r2: i32 = right.0[4 * i + 2].into(); + let r3: i32 = right.0[4 * i + 3].into(); + + let a = l0.expected_fma_(r0, l1.expected_mul_(r1)); + let b = l2.expected_fma_(r2, l3.expected_mul_(r3)); + a + b + }) + } + } }; } +impl_simd_dot_product_i16_to_i32!(4, 8); impl_simd_dot_product_i16_to_i32!(8, 16); impl_simd_dot_product_i16_to_i32!(16, 32); +impl_simd_dot_product_iu8_to_i32!(4, 16); impl_simd_dot_product_iu8_to_i32!(8, 32); impl_simd_dot_product_iu8_to_i32!(16, 64); @@ -814,12 +864,26 @@ mod test_emulated { test_utils::dot_product::test_dot_product!( (Emulated, Emulated) => Emulated, 0x3001f05604e96289, SC ); + test_utils::dot_product::test_dot_product!( + (Emulated, Emulated) => Emulated, 0x3001f05604e96289, SC + ); + test_utils::dot_product::test_dot_product!( (Emulated, Emulated) => Emulated, 0x3001f05604e96289, SC ); test_utils::dot_product::test_dot_product!( (Emulated, Emulated) => Emulated, 0x3001f05604e96289, SC ); + test_utils::dot_product::test_dot_product!( + (Emulated, Emulated) => Emulated, 0x3001f05604e96289, SC + ); + + test_utils::dot_product::test_dot_product!( + (Emulated, Emulated) => Emulated, 0x3001f05604e96289, SC + ); + test_utils::dot_product::test_dot_product!( + (Emulated, Emulated) => Emulated, 0x3001f05604e96289, SC + ); // reductions test_utils::ops::test_sumtree!(Emulated, 0x410bad8207a8ccfc, SC); diff --git a/diskann-wide/src/helpers.rs b/diskann-wide/src/helpers.rs index da1ef7ba4..2d204afda 100644 --- a/diskann-wide/src/helpers.rs +++ b/diskann-wide/src/helpers.rs @@ -42,6 +42,18 @@ macro_rules! unsafe_map_unary_op { /// /// SAFETY: It is the invoker's responsibility to ensure that the intrinsic is safe to call. macro_rules! unsafe_map_conversion { + ($from:ty, $to:ty, ($i1:ident, $i0:ident), $requires:literal) => { + impl From<$from> for $to { + #[inline(always)] + fn from(value: $from) -> $to { + // SAFETY: The invoker of this macro must pass the `target_feature` + // requirement of the intrinsic. 
+ // + // That way, if the intrinsic is not available, we get a compile-time error. + Self(unsafe { $i1($i0(value.0)) }) + } + } + }; ($from:ty, $to:ty, $intrinsic:expr, $requires:literal) => { impl From<$from> for $to { #[inline(always)] @@ -79,6 +91,7 @@ macro_rules! unsafe_map_cast { } /// Implement shifting by calling Splat. +#[cfg(target_arch = "x86_64")] macro_rules! scalar_shift_by_splat { ($T:ty, $scalar:ty) => { impl std::ops::Shr<$scalar> for $T { @@ -100,6 +113,7 @@ macro_rules! scalar_shift_by_splat { } // Allow modules in this crate to use these macros. +#[cfg(target_arch = "x86_64")] pub(crate) use scalar_shift_by_splat; pub(crate) use unsafe_map_binary_op; pub(crate) use unsafe_map_cast; diff --git a/diskann-wide/src/lib.rs b/diskann-wide/src/lib.rs index 5e382775e..287284553 100644 --- a/diskann-wide/src/lib.rs +++ b/diskann-wide/src/lib.rs @@ -154,7 +154,6 @@ pub use splitjoin::{LoHi, SplitJoin}; mod bitmask; pub use bitmask::{BitMask, FromInt}; -#[cfg(target_arch = "x86_64")] pub(crate) mod doubled; mod emulated; @@ -232,7 +231,6 @@ fn get_test_arch() -> Option { } } -#[cfg(not(target_arch = "aarch64"))] pub(crate) mod helpers; #[cfg(test)] diff --git a/diskann-wide/src/test_utils/dot_product.rs b/diskann-wide/src/test_utils/dot_product.rs index 960c00d08..0b97df6d9 100644 --- a/diskann-wide/src/test_utils/dot_product.rs +++ b/diskann-wide/src/test_utils/dot_product.rs @@ -70,6 +70,48 @@ impl ExpectedDot for DotSchema { } } +impl ExpectedDot for DotSchema { + fn expected_dot_impl(accumulator: i32, left: &[i8; 4], right: &[i8; 4]) -> i32 { + let l0: i32 = left[0].into(); + let l1: i32 = left[1].into(); + let l2: i32 = left[2].into(); + let l3: i32 = left[3].into(); + + let r0: i32 = right[0].into(); + let r1: i32 = right[1].into(); + let r2: i32 = right[2].into(); + let r3: i32 = right[3].into(); + + accumulator.expected_add_( + l0.expected_mul_(r0) + .expected_add_(l1.expected_mul_(r1)) + .expected_add_(l2.expected_mul_(r2)) + .expected_add_(l3.expected_mul_(r3)), + ) + } +} + +impl ExpectedDot for DotSchema { + fn expected_dot_impl(accumulator: u32, left: &[u8; 4], right: &[u8; 4]) -> u32 { + let l0: u32 = left[0].into(); + let l1: u32 = left[1].into(); + let l2: u32 = left[2].into(); + let l3: u32 = left[3].into(); + + let r0: u32 = right[0].into(); + let r1: u32 = right[1].into(); + let r2: u32 = right[2].into(); + let r3: u32 = right[3].into(); + + accumulator.expected_add_( + l0.expected_mul_(r0) + .expected_add_(l1.expected_mul_(r1)) + .expected_add_(l2.expected_mul_(r2)) + .expected_add_(l3.expected_mul_(r3)), + ) + } +} + //////////////// // Test Macro // //////////////// @@ -202,9 +244,8 @@ mod tests { ); } - #[test] - fn test_u8_i8_to_i32() { - let a: &[[u8; 4]] = &[ + fn u8_range() -> &'static [[u8; 4]] { + &[ [u8::MIN, u8::MIN, u8::MIN, u8::MIN], [u8::MIN, u8::MIN, u8::MIN, u8::MAX], [u8::MIN, u8::MIN, u8::MAX, u8::MIN], @@ -221,9 +262,11 @@ mod tests { [u8::MAX, u8::MAX, u8::MIN, u8::MAX], [u8::MAX, u8::MAX, u8::MAX, u8::MIN], [u8::MAX, u8::MAX, u8::MAX, u8::MAX], - ]; + ] + } - let b: &[[i8; 4]] = &[ + fn i8_range() -> &'static [[i8; 4]] { + &[ [i8::MIN, i8::MIN, i8::MIN, i8::MIN], [i8::MIN, i8::MIN, i8::MIN, i8::MAX], [i8::MIN, i8::MIN, i8::MAX, i8::MIN], @@ -240,7 +283,13 @@ mod tests { [i8::MAX, i8::MAX, i8::MIN, i8::MAX], [i8::MAX, i8::MAX, i8::MAX, i8::MIN], [i8::MAX, i8::MAX, i8::MAX, i8::MAX], - ]; + ] + } + + #[test] + fn test_u8_i8_to_i32() { + let a = u8_range(); + let b = i8_range(); let bases = [0, 1, -1, i16::MAX as i32, i16::MIN as i32]; @@ 
-274,4 +323,77 @@ mod tests { } } } + + #[test] + fn test_i8_i8_to_i32() { + let a = i8_range(); + let bases = [0, 1, -1, i16::MAX as i32, i16::MIN as i32]; + + for left in a { + for right in a { + let dot: i32 = (*left) + .into_iter() + .zip((*right).into_iter()) + .map(|(l, r)| (l as i32) * (r as i32)) + .sum(); + for b in bases { + let expected = dot + b; + assert_eq!( + expected, + DotSchema::expected_dot(b, left, right), + "failed for: base = {}, left = {:?}, right = {:?}", + b, + left, + right, + ); + + assert_eq!( + expected, + DotSchema::expected_dot(b, right, left), + "failed for: base = {}, left = {:?}, right = {:?}", + b, + right, + left, + ); + } + } + } + } + + #[test] + fn test_u8_u8_to_u32() { + let a = u8_range(); + + let bases = [0, 1, i16::MAX as u32, u16::MAX as u32]; + + for left in a { + for right in a { + let dot: u32 = (*left) + .into_iter() + .zip((*right).into_iter()) + .map(|(l, r)| (l as u32) * (r as u32)) + .sum(); + for b in bases { + let expected = dot + b; + assert_eq!( + expected, + DotSchema::expected_dot(b, left, right), + "failed for: base = {}, left = {:?}, right = {:?}", + b, + left, + right, + ); + + assert_eq!( + expected, + DotSchema::expected_dot(b, right, left), + "failed for: base = {}, left = {:?}, right = {:?}", + b, + right, + left, + ); + } + } + } + } } diff --git a/diskann-wide/src/test_utils/ops.rs b/diskann-wide/src/test_utils/ops.rs index 7a38b09e4..6acfcea1c 100644 --- a/diskann-wide/src/test_utils/ops.rs +++ b/diskann-wide/src/test_utils/ops.rs @@ -7,15 +7,12 @@ use std::fmt::Debug; // Common test traits. use super::common::{self, ScalarTraits}; -#[cfg(target_arch = "x86_64")] -use crate::SplitJoin; use crate::{ BitMask, Const, SIMDMask, SIMDMinMax, SIMDPartialEq, SIMDPartialOrd, SIMDSumTree, SIMDVector, - SupportedLaneCount, arch, + SplitJoin, SupportedLaneCount, arch, reference::{ReferenceScalarOps, ReferenceShifts, TreeReduce}, }; -#[cfg(target_arch = "x86_64")] fn identity(x: T) -> T { x } @@ -948,7 +945,7 @@ macro_rules! test_sumtree { /////////////// // SplitJoin // /////////////// -#[cfg(target_arch = "x86_64")] + pub fn test_splitjoin_impl(arch: V::Arch, a: &[T]) where T: Copy + Debug + ScalarTraits, @@ -976,7 +973,6 @@ where test_unary_op(&joined.to_array(), a, &identity, "join"); } -#[cfg(target_arch = "x86_64")] macro_rules! test_splitjoin { ($wide:ident $(< $($ps:tt),+ >)? => $half:ident $(< $($hs:tt),+ >)?, $seed:literal, $arch:expr) => { paste::paste! { @@ -1021,7 +1017,6 @@ pub(crate) use test_lossless_convert; pub(crate) use test_minmax; pub(crate) use test_mul; pub(crate) use test_select; -#[cfg(target_arch = "x86_64")] pub(crate) use test_splitjoin; pub(crate) use test_sub; pub(crate) use test_sumtree; diff --git a/diskann-wide/tests/dispatch.rs b/diskann-wide/tests/dispatch.rs index 0507372b0..d935dc1fe 100644 --- a/diskann-wide/tests/dispatch.rs +++ b/diskann-wide/tests/dispatch.rs @@ -3,8 +3,7 @@ * Licensed under the MIT license. 
*/ -#[cfg(target_arch = "x86_64")] -use diskann_wide::{SIMDMulAdd, SIMDSumTree, SIMDVector}; +use diskann_wide::{SIMDFloat, SIMDSumTree, SIMDVector}; #[cfg(target_arch = "x86_64")] type V3 = diskann_wide::arch::x86_64::V3; @@ -12,6 +11,9 @@ type V3 = diskann_wide::arch::x86_64::V3; #[cfg(target_arch = "x86_64")] type V4 = diskann_wide::arch::x86_64::V4; +#[cfg(target_arch = "aarch64")] +use diskann_wide::arch::aarch64::Neon; + struct InnerProduct; impl diskann_wide::arch::Target2 for InnerProduct { @@ -21,40 +23,49 @@ impl diskann_wide::arch::Target2(arch: F::Arch, a: &[f32], b: &[f32]) -> f32 +where + F: SIMDVector + SIMDFloat + SIMDSumTree, +{ + assert_eq!(a.len(), b.len()); + + let lanes: usize = F::LANES; + + let len = a.len(); + let a = a.as_ptr(); + let b = b.as_ptr(); + + let mut sum = F::default(arch); + let trips = len / lanes; + let remainder = len % lanes; + for i in 0..trips { + // SAFETY: By loop construction, `a.add(lanes * (i + 1) - 1)` is always in-bounds. + let wa = unsafe { F::load_simd(arch, a.add(lanes * i)) }; + // SAFETY: By loop construction, `b.add(lanes * (i + 1) - 1)` is always in-bounds. + let wb = unsafe { F::load_simd(arch, b.add(lanes * i)) }; + sum = wa.mul_add_simd(wb, sum); + } + + // Handle and remaining using predicated loads. + if remainder != 0 { + // SAFETY: By loop construction, `a.add(lanes * trips)` is always in-bounds. + let wa = unsafe { F::load_simd_first(arch, a.add(trips * lanes), remainder) }; + // SAFETY: By loop construction, `b.add(lanes * trips)` is always in-bounds. + let wb = unsafe { F::load_simd_first(arch, b.add(trips * lanes), remainder) }; + sum = wa.mul_add_simd(wb, sum); + } + + sum.sum_tree() +} + #[cfg(target_arch = "x86_64")] impl diskann_wide::arch::Target2 for InnerProduct { #[inline(always)] fn run(self, arch: V3, a: &[f32], b: &[f32]) -> f32 { diskann_wide::alias!(f32s = ::f32x8); - // The number of lanes to use in a SIMD register. - const LANES: usize = f32s::LANES; - - let len = a.len(); - let a = a.as_ptr(); - let b = b.as_ptr(); - - let mut sum = f32s::default(arch); - let trips = len / LANES; - let remainder = len % LANES; - for i in 0..trips { - // SAFETY: By loop construction, `a.add(LANES * (i + 1) - 1)` is always in-bounds. - let wa = unsafe { f32s::load_simd(arch, a.add(LANES * i)) }; - // SAFETY: By loop construction, `b.add(LANES * (i + 1) - 1)` is always in-bounds. - let wb = unsafe { f32s::load_simd(arch, b.add(LANES * i)) }; - sum = wa.mul_add_simd(wb, sum); - } - - // Handle and remaining using predicated loads. - if remainder != 0 { - // SAFETY: By loop construction, `a.add(LANES * trips)` is always in-bounds. - let wa = unsafe { f32s::load_simd_first(arch, a.add(trips * LANES), remainder) }; - // SAFETY: By loop construction, `b.add(LANES * trips)` is always in-bounds. - let wb = unsafe { f32s::load_simd_first(arch, b.add(trips * LANES), remainder) }; - sum = wa.mul_add_simd(wb, sum); - } - - sum.sum_tree() + inner_product::(arch, a, b) } } @@ -64,38 +75,21 @@ impl diskann_wide::arch::Target2 for InnerProduct { fn run(self, arch: V4, a: &[f32], b: &[f32]) -> f32 { diskann_wide::alias!(f32s = ::f32x16); - // The number of lanes to use in a SIMD register. - const LANES: usize = f32s::LANES; - - let len = a.len(); - let a = a.as_ptr(); - let b = b.as_ptr(); - - let mut sum = f32s::default(arch); - let trips = len / LANES; - let remainder = len % LANES; - for i in 0..trips { - // SAFETY: By loop construction, `a.add(LANES * (i + 1) - 1)` is always in-bounds. 
- let wa = unsafe { f32s::load_simd(arch, a.add(LANES * i)) }; - // SAFETY: By loop construction, `b.add(LANES * (i + 1) - 1)` is always in-bounds. - let wb = unsafe { f32s::load_simd(arch, b.add(LANES * i)) }; - sum = wa.mul_add_simd(wb, sum); - } - - // Handle and remaining using predicated loads. - if remainder != 0 { - // SAFETY: By loop construction, `a.add(LANES * trips)` is always in-bounds. - let wa = unsafe { f32s::load_simd_first(arch, a.add(trips * LANES), remainder) }; - // SAFETY: By loop construction, `b.add(LANES * trips)` is always in-bounds. - let wb = unsafe { f32s::load_simd_first(arch, b.add(trips * LANES), remainder) }; - sum = wa.mul_add_simd(wb, sum); - } - - sum.sum_tree() + inner_product::(arch, a, b) } } -// This example +#[cfg(target_arch = "aarch64")] +impl diskann_wide::arch::Target2 for InnerProduct { + #[inline(always)] + fn run(self, arch: Neon, a: &[f32], b: &[f32]) -> f32 { + diskann_wide::alias!(f32s = ::f32x4); + + inner_product::(arch, a, b) + } +} + +// This example shows how `Architectures` can be used in the context of auto-vectorization. struct SquaredL2; impl diskann_wide::arch::Target2 for SquaredL2 { @@ -128,6 +122,15 @@ impl diskann_wide::arch::Target2 for SquaredL2 { } } +#[cfg(target_arch = "aarch64")] +impl diskann_wide::arch::Target2 for SquaredL2 { + #[inline(always)] + fn run(self, arch: Neon, x: &[f32], y: &[f32]) -> f32 { + let arch: diskann_wide::arch::Scalar = arch.into(); + self.run(arch, x, y) + } +} + #[inline(never)] pub fn test_inner_product(a: &[f32], b: &[f32]) -> f32 { diskann_wide::arch::dispatch2(InnerProduct, a, b) From e4cd3352f454bf828675f77d00a29e838a027403 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Sun, 15 Feb 2026 21:52:09 -0800 Subject: [PATCH 02/10] Address Clippy. --- diskann-wide/src/arch/aarch64/f32x4_.rs | 4 ++++ diskann-wide/src/arch/aarch64/masks.rs | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/diskann-wide/src/arch/aarch64/f32x4_.rs b/diskann-wide/src/arch/aarch64/f32x4_.rs index aefc0dfd8..f857673c3 100644 --- a/diskann-wide/src/arch/aarch64/f32x4_.rs +++ b/diskann-wide/src/arch/aarch64/f32x4_.rs @@ -38,21 +38,25 @@ macros::aarch64_define_fma!(f32x4, vfmaq_f32); impl SIMDMinMax for f32x4 { #[inline(always)] fn min_simd(self, rhs: Self) -> Self { + // SAFETY: `vminnmq_f32` requires "neon", implied by the `Neon` architecture. Self(unsafe { vminnmq_f32(self.0, rhs.0) }) } #[inline(always)] fn min_simd_standard(self, rhs: Self) -> Self { + // SAFETY: `vminnmq_f32` requires "neon", implied by the `Neon` architecture. Self(unsafe { vminnmq_f32(self.0, rhs.0) }) } #[inline(always)] fn max_simd(self, rhs: Self) -> Self { + // SAFETY: `vminnmq_f32` requires "neon", implied by the `Neon` architecture. Self(unsafe { vmaxnmq_f32(self.0, rhs.0) }) } #[inline(always)] fn max_simd_standard(self, rhs: Self) -> Self { + // SAFETY: `vminnmq_f32` requires "neon", implied by the `Neon` architecture. 
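+        // (The call below is `vmaxnmq_f32`; both it and `vminnmq_f32` lower to
+        // FMAXNM/FMINNM, which follow IEEE 754-2008 maxNum/minNum semantics: when
+        // exactly one operand is a quiet NaN, the numeric operand is returned.)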
Self(unsafe { vmaxnmq_f32(self.0, rhs.0) }) } } diff --git a/diskann-wide/src/arch/aarch64/masks.rs b/diskann-wide/src/arch/aarch64/masks.rs index f88c5f57a..8658c4b94 100644 --- a/diskann-wide/src/arch/aarch64/masks.rs +++ b/diskann-wide/src/arch/aarch64/masks.rs @@ -202,7 +202,7 @@ impl MaskOps for uint8x16_t { let low = vaddlv_u8(vshl_u8(vget_low_u8(masked), shifts)); let high = vaddlv_u8(vshl_u8(vget_high_u8(masked), shifts)); - (low as u16) | ((high as u16) << 8) + low | (high << 8) }; BitMask::from_int(arch, value) } From 5e8b87a9b48ec0002afe42094ad39064e26bf72d Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Sun, 15 Feb 2026 22:03:45 -0800 Subject: [PATCH 03/10] Enable `neon` and `dotprod` --- .cargo/config.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.cargo/config.toml b/.cargo/config.toml index c6ab5b16d..34bee66f8 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -6,3 +6,6 @@ rustflags = ["-C", "control-flow-guard"] # by setting RUSTFLAGS in the enviornment: RUSTFLAGS="-C target-cpu=x86-64" [target.'cfg(target_arch="x86_64")'] rustflags = ["-C", "target-cpu=x86-64-v3"] + +[target.'cfg(target_arch="aarch64")'] +rustflags = ["-C", "target-feature=+neon,+dotprod"] From ff0c5a479e4209a1aaab1c354131b4bc1c103ef5 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Sun, 15 Feb 2026 22:23:20 -0800 Subject: [PATCH 04/10] Fix `benchmark-simd` --- diskann-benchmark-simd/src/lib.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/diskann-benchmark-simd/src/lib.rs b/diskann-benchmark-simd/src/lib.rs index 9f4611020..94509e0e4 100644 --- a/diskann-benchmark-simd/src/lib.rs +++ b/diskann-benchmark-simd/src/lib.rs @@ -801,6 +801,11 @@ stamp!("x86_64", diskann_wide::arch::x86_64::V3, f16, f16); stamp!("x86_64", diskann_wide::arch::x86_64::V3, u8, u8); stamp!("x86_64", diskann_wide::arch::x86_64::V3, i8, i8); +stamp!("aarch64", diskann_wide::arch::aarch64::Neon, f32, f32); +stamp!("aarch64", diskann_wide::arch::aarch64::Neon, f16, f16); +stamp!("aarch64", diskann_wide::arch::aarch64::Neon, u8, u8); +stamp!("aarch64", diskann_wide::arch::aarch64::Neon, i8, i8); + stamp!(diskann_wide::arch::Scalar, f32, f32); stamp!(diskann_wide::arch::Scalar, f16, f16); stamp!(diskann_wide::arch::Scalar, u8, u8); @@ -875,6 +880,13 @@ mod reference { } } + #[cfg(target_arch = "aarch64")] + impl MaybeFMA for diskann_wide::arch::aarch64::Neon { + fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 { + a.mul_add(b, c) + } + } + //------------// // Squared L2 // //------------// From 6a26bd620706f5f008b6c94d1044c1b09eb9e45e Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Mon, 16 Feb 2026 10:36:12 -0800 Subject: [PATCH 05/10] Checkpoint. 
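
The new `load_first` helpers below build "load the first `n` elements, zero the rest"
out of overlapping unaligned reads plus a TBL shuffle for the 8-to-16-byte case. As a
mental model, a safe, portable sketch of the intended byte-level result looks like this
(the `load_first_of_8_bytes_model` name is invented for illustration and assumes the
little-endian layout used on these aarch64 targets):

    fn load_first_of_8_bytes_model(src: &[u8], n: usize) -> u64 {
        // Copy at most the first 8 bytes into a zeroed buffer, then reassemble
        // them as a little-endian u64. Bytes at and beyond `n` stay zero.
        let n = n.min(8).min(src.len());
        let mut buf = [0u8; 8];
        buf[..n].copy_from_slice(&src[..n]);
        u64::from_le_bytes(buf)
    }

A sketch like this is handy for cross-checking the intrinsic fast paths in property
tests or under Miri.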
--- .../src/arch/aarch64/algorithms/load_first.rs | 404 ++++++++++++++++++ .../src/arch/aarch64/algorithms/mod.rs | 6 + diskann-wide/src/arch/aarch64/f32x2_.rs | 6 +- diskann-wide/src/arch/aarch64/f32x4_.rs | 10 +- diskann-wide/src/arch/aarch64/i16x8_.rs | 10 +- diskann-wide/src/arch/aarch64/i32x4_.rs | 10 +- diskann-wide/src/arch/aarch64/i64x2_.rs | 10 +- diskann-wide/src/arch/aarch64/i8x16_.rs | 4 +- diskann-wide/src/arch/aarch64/i8x8_.rs | 4 +- diskann-wide/src/arch/aarch64/macros.rs | 36 +- diskann-wide/src/arch/aarch64/masks.rs | 72 ++-- diskann-wide/src/arch/aarch64/mod.rs | 2 + diskann-wide/src/arch/aarch64/u16x8_.rs | 10 +- diskann-wide/src/arch/aarch64/u32x4_.rs | 10 +- diskann-wide/src/arch/aarch64/u64x2_.rs | 10 +- diskann-wide/src/arch/aarch64/u8x16_.rs | 4 +- diskann-wide/src/arch/aarch64/u8x8_.rs | 4 +- diskann-wide/src/doubled.rs | 34 ++ 18 files changed, 553 insertions(+), 93 deletions(-) create mode 100644 diskann-wide/src/arch/aarch64/algorithms/load_first.rs create mode 100644 diskann-wide/src/arch/aarch64/algorithms/mod.rs diff --git a/diskann-wide/src/arch/aarch64/algorithms/load_first.rs b/diskann-wide/src/arch/aarch64/algorithms/load_first.rs new file mode 100644 index 000000000..b679e28cb --- /dev/null +++ b/diskann-wide/src/arch/aarch64/algorithms/load_first.rs @@ -0,0 +1,404 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use std::arch::aarch64::*; + +use crate::arch::aarch64::Neon; + +//////////////// +// Load First // +//////////////// + +//-------------// +// 64-bit wide // +//-------------// + +/// Load the first `first` elements from `ptr` into a `uint8x8_t` register. +/// +/// # Safety +/// +/// The caller must ensure `[ptr, ptr + first)` is readable. The presence of `Neon` +/// enables the use of "neon" intrinsics. +#[inline(always)] +pub(in crate::arch::aarch64) unsafe fn u8x8(_: Neon, ptr: *const u8, first: usize) -> uint8x8_t { + // SAFETY: Pointer access inherited from caller. `Neon` enables "neon" intrinsics. + unsafe { vcreate_u8(load_first_of_8_bytes(ptr, first)) } +} + +/// Load the first `first` elements from `ptr` into an `int8x8_t` register. +/// +/// # Safety +/// +/// The caller must ensure `[ptr, ptr + first)` is readable. The presence of `Neon` +/// enables the use of "neon" intrinsics. +#[inline(always)] +pub(in crate::arch::aarch64) unsafe fn i8x8(_: Neon, ptr: *const i8, first: usize) -> int8x8_t { + // SAFETY: Pointer access inherited from caller. `Neon` enables "neon" intrinsics. + unsafe { vreinterpret_s8_u8(vcreate_u8(load_first_of_8_bytes(ptr.cast::(), first))) } +} + +/// Load the first `first` elements from `ptr` into a `float32x2_t` register. +/// +/// # Safety +/// +/// The caller must ensure `[ptr, ptr + first)` is readable. The presence of `Neon` +/// enables the use of "neon" intrinsics. +#[inline(always)] +pub(in crate::arch::aarch64) unsafe fn f32x2( + arch: Neon, + ptr: *const f32, + first: usize, +) -> float32x2_t { + // SAFETY: Pointer access inherited from caller. `Neon` enables "neon" intrinsics. + unsafe { vreinterpret_f32_u32(load_first_32x2(arch, ptr.cast::(), first)) } +} + +//--------------// +// 128-bit wide // +//--------------// + +/// Load the first `first` elements from `ptr` into a `uint8x16_t` register. +/// +/// # Safety +/// +/// The caller must ensure `[ptr, ptr + first)` is readable. The presence of `Neon` +/// enables the use of "neon" intrinsics. 
+#[inline(always)] +pub(in crate::arch::aarch64) unsafe fn u8x16( + arch: Neon, + ptr: *const u8, + first: usize, +) -> uint8x16_t { + // SAFETY: Pointer access inherited from caller. `Neon` enables "neon" intrinsics. + unsafe { load_first_of_16_bytes(arch, ptr, first) } +} + +/// Load the first `first` elements from `ptr` into an `int8x16_t` register. +/// +/// # Safety +/// +/// The caller must ensure `[ptr, ptr + first)` is readable. The presence of `Neon` +/// enables the use of "neon" intrinsics. +#[inline(always)] +pub(in crate::arch::aarch64) unsafe fn i8x16( + arch: Neon, + ptr: *const i8, + first: usize, +) -> int8x16_t { + // SAFETY: Pointer access inherited from caller. `Neon` enables "neon" intrinsics. + unsafe { vreinterpretq_s8_u8(u8x16(arch, ptr.cast::(), first)) } +} + +/// Load the first `first` elements from `ptr` into a `uint16x8_t` register. +/// +/// # Safety +/// +/// The caller must ensure `[ptr, ptr + first)` is readable. The presence of `Neon` +/// enables the use of "neon" intrinsics. +#[inline(always)] +pub(in crate::arch::aarch64) unsafe fn u16x8( + arch: Neon, + ptr: *const u16, + first: usize, +) -> uint16x8_t { + // SAFETY: Pointer access inherited from caller. `Neon` enables "neon" intrinsics. + unsafe { vreinterpretq_u16_u8(load_first_of_16_bytes(arch, ptr.cast::(), 2 * first)) } +} + +/// Load the first `first` elements from `ptr` into an `int16x8_t` register. +/// +/// # Safety +/// +/// The caller must ensure `[ptr, ptr + first)` is readable. The presence of `Neon` +/// enables the use of "neon" intrinsics. +#[inline(always)] +pub(in crate::arch::aarch64) unsafe fn i16x8( + arch: Neon, + ptr: *const i16, + first: usize, +) -> int16x8_t { + // SAFETY: Pointer access inherited from caller. `Neon` enables "neon" intrinsics. + unsafe { vreinterpretq_s16_u16(u16x8(arch, ptr.cast::(), first)) } +} + +/// Load the first `first` elements from `ptr` into a `uint32x4_t` register. +/// +/// # Safety +/// +/// The caller must ensure `[ptr, ptr + first)` is readable. The presence of `Neon` +/// enables the use of "neon" intrinsics. +#[inline(always)] +pub(in crate::arch::aarch64) unsafe fn u32x4( + arch: Neon, + ptr: *const u32, + first: usize, +) -> uint32x4_t { + // SAFETY: Pointer access inherited from caller. `Neon` enables "neon" intrinsics. + unsafe { load_first_32x4(arch, ptr, first) } +} + +/// Load the first `first` elements from `ptr` into an `int32x4_t` register. +/// +/// # Safety +/// +/// The caller must ensure `[ptr, ptr + first)` is readable. The presence of `Neon` +/// enables the use of "neon" intrinsics. +#[inline(always)] +pub(in crate::arch::aarch64) unsafe fn i32x4( + arch: Neon, + ptr: *const i32, + first: usize, +) -> int32x4_t { + // SAFETY: Pointer access inherited from caller. `Neon` enables "neon" intrinsics. + unsafe { vreinterpretq_s32_u32(u32x4(arch, ptr.cast::(), first)) } +} + +/// Load the first `first` elements from `ptr` into a `float32x4_t` register. +/// +/// # Safety +/// +/// The caller must ensure `[ptr, ptr + first)` is readable. The presence of `Neon` +/// enables the use of "neon" intrinsics. +#[inline(always)] +pub(in crate::arch::aarch64) unsafe fn f32x4( + arch: Neon, + ptr: *const f32, + first: usize, +) -> float32x4_t { + // SAFETY: Pointer access inherited from caller. `Neon` enables "neon" intrinsics. + unsafe { vreinterpretq_f32_u32(u32x4(arch, ptr.cast::(), first)) } +} + +/// Load the first `first` elements from `ptr` into a `uint64x2_t` register. 
+///
+/// # Safety
+///
+/// The caller must ensure `[ptr, ptr + first)` is readable. The presence of `Neon`
+/// enables the use of "neon" intrinsics.
+#[inline(always)]
+pub(in crate::arch::aarch64) unsafe fn u64x2(
+    arch: Neon,
+    ptr: *const u64,
+    first: usize,
+) -> uint64x2_t {
+    // SAFETY: Pointer access inherited from caller. `Neon` enables "neon" intrinsics.
+    unsafe { load_first_64x2(arch, ptr, first) }
+}
+
+/// Load the first `first` elements from `ptr` into an `int64x2_t` register.
+///
+/// # Safety
+///
+/// The caller must ensure `[ptr, ptr + first)` is readable. The presence of `Neon`
+/// enables the use of "neon" intrinsics.
+#[inline(always)]
+pub(in crate::arch::aarch64) unsafe fn i64x2(
+    arch: Neon,
+    ptr: *const i64,
+    first: usize,
+) -> int64x2_t {
+    // SAFETY: Pointer access inherited from caller. `Neon` enables "neon" intrinsics.
+    unsafe { vreinterpretq_s64_u64(u64x2(arch, ptr.cast::<u64>(), first)) }
+}
+
+////////////////////
+// Implementation //
+////////////////////
+
+/// Efficiently load the first `8 < bytes < 16` bytes from `ptr` without accessing memory
+/// outside of `[ptr, ptr + bytes)`.
+///
+/// Uses two overlapping 8-byte loads combined with `TBL` to shift the high portion into
+/// position, mirroring the x86 `PSHUFB` technique.
+///
+/// # Safety
+///
+/// * `bytes` must be in the range `(8, 16)`.
+/// * The memory in `[ptr, ptr + bytes)` must be readable and valid.
+#[inline(always)]
+unsafe fn load_8_to_16_bytes(arch: Neon, ptr: *const u8, bytes: usize) -> uint8x16_t {
+    debug_assert!(bytes > 8 && bytes < 16);
+
+    // Two overlapping 8-byte loads: [ptr, ptr+8) and [ptr+bytes-8, ptr+bytes).
+    //
+    // `lo` occupies the lower 64 bits of a 128-bit register (upper half is zero).
+    // We need to shift `hi` right by `bytes - 8` positions so that its first valid
+    // byte aligns with byte `bytes - 8` of `lo`, then OR the two together.
+    //
+    // We achieve this with `vqtbl1q_u8`: build an identity index vector and subtract
+    // `bytes - 8`. Lanes that underflow get their high bit set, which `TBL` maps to
+    // zero — exactly what we want for the overlapping region.
+    //
+    // SAFETY: Both reads are within `[ptr, ptr + bytes)`. The intrinsics require NEON.
+    unsafe {
+        let base = vcombine_u8(
+            vcreate_u8(0x0706050403020100),
+            vcreate_u8(0x0F0E0D0C0B0A0908),
+        );
+
+        let lo = vcombine_u8(vld1_u8(ptr), vcreate_u8(0));
+        let hi = vcombine_u8(vld1_u8(ptr.add(bytes - 8)), vcreate_u8(0));
+        let shift = vmovq_n_u8((bytes - 8) as u8);
+        let mask = vsubq_u8(base, shift);
+
+        // Miri does not support the `vqtbl1q_u8` instruction.
+        //
+        // Because we want `Miri` to see the loads, we just emulate the shift portion of
+        // the operation.
+        let combined = if cfg!(miri) {
+            use crate::{SIMDVector, arch::aarch64::u8x16};
+            let lo = u8x16::from_underlying(arch, lo).to_array();
+            let hi = u8x16::from_underlying(arch, hi).to_array();
+
+            let combined: [u8; 16] = core::array::from_fn(|i| {
+                if i < 8 {
+                    lo[i]
+                } else if i < bytes {
+                    hi[i - (bytes - 8)]
+                } else {
+                    0
+                }
+            });
+
+            u8x16::from_array(arch, combined).to_underlying()
+        } else {
+            vqtbl1q_u8(hi, mask)
+        };
+
+        vorrq_u8(lo, combined)
+    }
+}
+
+/// Load the first `bytes` bytes from `ptr` into a `u64`.
+///
+/// Bytes beyond `bytes` are zero. This is efficient for small loads (≤8 bytes) because
+/// it stays entirely in general-purpose registers with no SIMD involvement.
+///
+/// # Safety
+///
+/// * The memory in `[ptr, ptr + bytes.min(8))` must be readable and valid.
+#[inline(always)] +unsafe fn load_first_of_8_bytes(ptr: *const u8, bytes: usize) -> u64 { + // SAFETY: All reads are within `[ptr, ptr + bytes)`, which the caller asserts is valid. + unsafe { + if bytes >= 8 { + std::ptr::read_unaligned(ptr as *const u64) + } else if bytes >= 4 { + let lo = std::ptr::read_unaligned(ptr as *const u32) as u64; + let hi = std::ptr::read_unaligned(ptr.add(bytes - 4) as *const u32) as u64; + lo | (hi << ((bytes - 4) * 8)) + } else if bytes >= 2 { + let lo = std::ptr::read_unaligned(ptr as *const u16) as u64; + let hi = std::ptr::read_unaligned(ptr.add(bytes - 2) as *const u16) as u64; + lo | (hi << ((bytes - 2) * 8)) + } else if bytes == 1 { + std::ptr::read(ptr) as u64 + } else { + 0 + } + } +} + +/// Load the first `bytes` bytes from `ptr` into a 128-bit Neon register. +/// +/// For full loads (≥16), uses a single `vld1q_u8`. For 8 < bytes < 16, uses the two-load +/// shuffle technique. For ≤8 bytes, uses GPR-based overlapping reads. +/// +/// # Safety +/// +/// * The memory in `[ptr, ptr + bytes)` must be readable and valid. +/// * Memory at and above `ptr + bytes` will not be accessed. +#[inline(always)] +unsafe fn load_first_of_16_bytes(arch: Neon, ptr: *const u8, bytes: usize) -> uint8x16_t { + if bytes >= 16 { + // SAFETY: Full load is valid since `bytes >= 16`. + return unsafe { vld1q_u8(ptr) }; + } + + if bytes > 8 { + // SAFETY: `bytes` is in `(8, 16)` and `[ptr, ptr + bytes)` is valid. + return unsafe { load_8_to_16_bytes(arch, ptr, bytes) }; + } + + // SAFETY: `bytes` is in `[0, 8]` and `[ptr, ptr + bytes)` is valid. + // + // The presence of `Neon` enables the use of "neon" intrinsics. + unsafe { + let v = load_first_of_8_bytes(ptr, bytes); + vcombine_u8(vcreate_u8(v), vcreate_u8(0)) + } +} + +/// Load the first `first` elements of a 32-bit type from `ptr` into a 128-bit register. +/// +/// # Safety +/// +/// * The memory in `[ptr, ptr + first)` must be readable and valid (element-wise). +/// * Memory at and above `ptr + first` will not be accessed. +#[inline(always)] +unsafe fn load_first_32x4(_: Neon, ptr: *const u32, first: usize) -> uint32x4_t { + // SAFETY: All reads are within `[ptr, ptr + first)`. + // + // The presence of `Neon` enables the use of "neon" intrinsics. + unsafe { + if first >= 4 { + vld1q_u32(ptr) + } else if first == 3 { + let lo = vld1_u32(ptr); + let hi = vld1_lane_u32(ptr.add(2), vcreate_u32(0), 0); + vcombine_u32(lo, hi) + } else if first == 2 { + vcombine_u32(vld1_u32(ptr), vcreate_u32(0)) + } else if first == 1 { + vcombine_u32(vcreate_u32(ptr.read_unaligned() as u64), vcreate_u32(0)) + } else { + vmovq_n_u32(0) + } + } +} + +/// Load the first `first` elements of a 32-bit type from `ptr` into a 64-bit register. +/// +/// # Safety +/// +/// * The memory in `[ptr, ptr + first)` must be readable and valid (element-wise). +/// * Memory at and above `ptr + first` will not be accessed. +#[inline(always)] +unsafe fn load_first_32x2(_: Neon, ptr: *const u32, first: usize) -> uint32x2_t { + // SAFETY: All reads are within `[ptr, ptr + first)`. + // + // The presence of `Neon` enables the use of "neon" intrinsics. + unsafe { + if first >= 2 { + vld1_u32(ptr) + } else if first == 1 { + vcreate_u32(ptr.read_unaligned() as u64) + } else { + vmov_n_u32(0) + } + } +} + +/// Load the first `first` elements of a 64-bit type from `ptr` into a 128-bit register. +/// +/// # Safety +/// +/// * The memory in `[ptr, ptr + first)` must be readable and valid (element-wise). +/// * Memory at and above `ptr + first` will not be accessed. 
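
As an aside, the branch ladder in `load_first_of_8_bytes` above packs the first `bytes` bytes into the low end of a `u64` with overlapping reads; the overlap is harmless because both reads come from the same buffer, so ORing the shifted halves together cannot change any byte. The following is an illustrative, safe-Rust sketch of that packing only (it is not part of the patch, and the names `word` and `pack_first_n_le` are hypothetical); it uses slice reads where the patch uses unaligned pointer reads.

// Zero-extend `src[start..end]` (at most 8 bytes) into a little-endian u64.
fn word(src: &[u8], start: usize, end: usize) -> u64 {
    let mut buf = [0u8; 8];
    buf[..end - start].copy_from_slice(&src[start..end]);
    u64::from_le_bytes(buf)
}

// Same shape as `load_first_of_8_bytes`: one full read, or two overlapping reads,
// or a single-byte read, with everything above byte `n` guaranteed to be zero.
fn pack_first_n_le(src: &[u8], n: usize) -> u64 {
    let n = n.min(8).min(src.len());
    if n == 8 {
        word(src, 0, 8)
    } else if n >= 4 {
        // Two overlapping 4-byte reads: [0, 4) and [n - 4, n).
        word(src, 0, 4) | (word(src, n - 4, n) << ((n - 4) * 8))
    } else if n >= 2 {
        // Two overlapping 2-byte reads: [0, 2) and [n - 2, n).
        word(src, 0, 2) | (word(src, n - 2, n) << ((n - 2) * 8))
    } else if n == 1 {
        src[0] as u64
    } else {
        0
    }
}

fn main() {
    let src: Vec<u8> = (1..=8).collect();
    for n in 0..=8 {
        // Reference result: the first `n` bytes little-endian, the rest zero.
        let mut expected = [0u8; 8];
        expected[..n].copy_from_slice(&src[..n]);
        assert_eq!(pack_first_n_le(&src, n), u64::from_le_bytes(expected));
    }
    println!("overlapping-read packing matches the zero-padded reference");
}
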
+#[inline(always)] +unsafe fn load_first_64x2(_: Neon, ptr: *const u64, first: usize) -> uint64x2_t { + // SAFETY: All reads are within `[ptr, ptr + first)`. + // + // The presence of `Neon` enables the use of "neon" intrinsics. + unsafe { + if first >= 2 { + vld1q_u64(ptr) + } else if first == 1 { + vcombine_u64(vld1_u64(ptr), vcreate_u64(0)) + } else { + vmovq_n_u64(0) + } + } +} diff --git a/diskann-wide/src/arch/aarch64/algorithms/mod.rs b/diskann-wide/src/arch/aarch64/algorithms/mod.rs new file mode 100644 index 000000000..09a7ff1fa --- /dev/null +++ b/diskann-wide/src/arch/aarch64/algorithms/mod.rs @@ -0,0 +1,6 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +pub(super) mod load_first; diff --git a/diskann-wide/src/arch/aarch64/f32x2_.rs b/diskann-wide/src/arch/aarch64/f32x2_.rs index 3bf1176ed..2f9abbc95 100644 --- a/diskann-wide/src/arch/aarch64/f32x2_.rs +++ b/diskann-wide/src/arch/aarch64/f32x2_.rs @@ -12,7 +12,7 @@ use crate::{ // AArch64 masks use super::{ - Neon, + Neon, algorithms, macros::{self, AArchLoadStore, AArchSplat}, masks::mask32x2, }; @@ -26,7 +26,7 @@ use std::arch::aarch64::*; macros::aarch64_define_register!(f32x2, float32x2_t, mask32x2, f32, 2, Neon); macros::aarch64_define_splat!(f32x2, vmov_n_f32); -macros::aarch64_define_loadstore!(f32x2, vld1_f32, vst1_f32, 2); +macros::aarch64_define_loadstore!(f32x2, vld1_f32, algorithms::load_first::f32x2, vst1_f32, 2); helpers::unsafe_map_binary_op!(f32x2, std::ops::Add, add, vadd_f32, "neon"); helpers::unsafe_map_binary_op!(f32x2, std::ops::Sub, sub, vsub_f32, "neon"); @@ -49,7 +49,7 @@ impl SIMDSumTree for f32x2 { if cfg!(miri) { self.sum_tree() } else { - // SAFETY: This file is gated by the "neon" target feature. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. 
unsafe { vaddv_f32(self.to_underlying()) } } } diff --git a/diskann-wide/src/arch/aarch64/f32x4_.rs b/diskann-wide/src/arch/aarch64/f32x4_.rs index f857673c3..43bd680b8 100644 --- a/diskann-wide/src/arch/aarch64/f32x4_.rs +++ b/diskann-wide/src/arch/aarch64/f32x4_.rs @@ -12,7 +12,7 @@ use crate::{ // AArch64 masks use super::{ - Neon, f16x4, f32x2, + Neon, algorithms, f16x4, f32x2, macros::{self, AArchLoadStore, AArchSplat}, masks::mask32x4, }; @@ -26,7 +26,13 @@ use std::arch::{aarch64::*, asm}; macros::aarch64_define_register!(f32x4, float32x4_t, mask32x4, f32, 4, Neon); macros::aarch64_define_splat!(f32x4, vmovq_n_f32); -macros::aarch64_define_loadstore!(f32x4, vld1q_f32, vst1q_f32, 4); +macros::aarch64_define_loadstore!( + f32x4, + vld1q_f32, + algorithms::load_first::f32x4, + vst1q_f32, + 4 +); macros::aarch64_splitjoin!(f32x4, f32x2, vget_low_f32, vget_high_f32, vcombine_f32); helpers::unsafe_map_binary_op!(f32x4, std::ops::Add, add, vaddq_f32, "neon"); diff --git a/diskann-wide/src/arch/aarch64/i16x8_.rs b/diskann-wide/src/arch/aarch64/i16x8_.rs index bd14b633d..10b44ef4d 100644 --- a/diskann-wide/src/arch/aarch64/i16x8_.rs +++ b/diskann-wide/src/arch/aarch64/i16x8_.rs @@ -10,7 +10,7 @@ use crate::{ // AArch64 masks use super::{ - Neon, i8x8, + Neon, algorithms, i8x8, macros::{self, AArchLoadStore, AArchSplat}, masks::mask16x8, u8x8, @@ -25,7 +25,13 @@ use std::arch::aarch64::*; macros::aarch64_define_register!(i16x8, int16x8_t, mask16x8, i16, 8, Neon); macros::aarch64_define_splat!(i16x8, vmovq_n_s16); -macros::aarch64_define_loadstore!(i16x8, vld1q_s16, vst1q_s16, 8); +macros::aarch64_define_loadstore!( + i16x8, + vld1q_s16, + algorithms::load_first::i16x8, + vst1q_s16, + 8 +); helpers::unsafe_map_binary_op!(i16x8, std::ops::Add, add, vaddq_s16, "neon"); helpers::unsafe_map_binary_op!(i16x8, std::ops::Sub, sub, vsubq_s16, "neon"); diff --git a/diskann-wide/src/arch/aarch64/i32x4_.rs b/diskann-wide/src/arch/aarch64/i32x4_.rs index 2c7b905ea..d7f6f09c4 100644 --- a/diskann-wide/src/arch/aarch64/i32x4_.rs +++ b/diskann-wide/src/arch/aarch64/i32x4_.rs @@ -10,7 +10,7 @@ use crate::{ // AArch64 masks use super::{ - Neon, f32x4, i8x8, i8x16, i16x8, + Neon, algorithms, f32x4, i8x8, i8x16, i16x8, macros::{self, AArchLoadStore, AArchSplat}, masks::mask32x4, u8x8, u8x16, @@ -25,7 +25,13 @@ use std::arch::{aarch64::*, asm}; macros::aarch64_define_register!(i32x4, int32x4_t, mask32x4, i32, 4, Neon); macros::aarch64_define_splat!(i32x4, vmovq_n_s32); -macros::aarch64_define_loadstore!(i32x4, vld1q_s32, vst1q_s32, 4); +macros::aarch64_define_loadstore!( + i32x4, + vld1q_s32, + algorithms::load_first::i32x4, + vst1q_s32, + 4 +); helpers::unsafe_map_binary_op!(i32x4, std::ops::Add, add, vaddq_s32, "neon"); helpers::unsafe_map_binary_op!(i32x4, std::ops::Sub, sub, vsubq_s32, "neon"); diff --git a/diskann-wide/src/arch/aarch64/i64x2_.rs b/diskann-wide/src/arch/aarch64/i64x2_.rs index bdbbd683b..b37a24f1a 100644 --- a/diskann-wide/src/arch/aarch64/i64x2_.rs +++ b/diskann-wide/src/arch/aarch64/i64x2_.rs @@ -10,7 +10,7 @@ use crate::{ // AArch64 masks use super::{ - Neon, + Neon, algorithms, macros::{self, AArchLoadStore, AArchSplat}, masks::mask64x2, u64x2_::{emulated_vminq_u64, emulated_vmvnq_u64}, @@ -34,7 +34,13 @@ pub(super) unsafe fn emulated_vmvnq_s64(x: int64x2_t) -> int64x2_t { macros::aarch64_define_register!(i64x2, int64x2_t, mask64x2, i64, 2, Neon); macros::aarch64_define_splat!(i64x2, vmovq_n_s64); -macros::aarch64_define_loadstore!(i64x2, vld1q_s64, vst1q_s64, 2); 
+macros::aarch64_define_loadstore!( + i64x2, + vld1q_s64, + algorithms::load_first::i64x2, + vst1q_s64, + 2 +); helpers::unsafe_map_binary_op!(i64x2, std::ops::Add, add, vaddq_s64, "neon"); helpers::unsafe_map_binary_op!(i64x2, std::ops::Sub, sub, vsubq_s64, "neon"); diff --git a/diskann-wide/src/arch/aarch64/i8x16_.rs b/diskann-wide/src/arch/aarch64/i8x16_.rs index 494fa5a56..017d9f80a 100644 --- a/diskann-wide/src/arch/aarch64/i8x16_.rs +++ b/diskann-wide/src/arch/aarch64/i8x16_.rs @@ -10,7 +10,7 @@ use crate::{ // AArch64 masks use super::{ - Neon, i8x8, + Neon, algorithms, i8x8, macros::{self, AArchLoadStore, AArchSplat}, masks::mask8x16, }; @@ -24,7 +24,7 @@ use std::arch::aarch64::*; macros::aarch64_define_register!(i8x16, int8x16_t, mask8x16, i8, 16, Neon); macros::aarch64_define_splat!(i8x16, vmovq_n_s8); -macros::aarch64_define_loadstore!(i8x16, vld1q_s8, vst1q_s8, 16); +macros::aarch64_define_loadstore!(i8x16, vld1q_s8, algorithms::load_first::i8x16, vst1q_s8, 16); macros::aarch64_splitjoin!(i8x16, i8x8, vget_low_s8, vget_high_s8, vcombine_s8); helpers::unsafe_map_binary_op!(i8x16, std::ops::Add, add, vaddq_s8, "neon"); diff --git a/diskann-wide/src/arch/aarch64/i8x8_.rs b/diskann-wide/src/arch/aarch64/i8x8_.rs index 2dad5cc16..8cd501f93 100644 --- a/diskann-wide/src/arch/aarch64/i8x8_.rs +++ b/diskann-wide/src/arch/aarch64/i8x8_.rs @@ -10,7 +10,7 @@ use crate::{ // AArch64 masks use super::{ - Neon, + Neon, algorithms, macros::{self, AArchLoadStore, AArchSplat}, masks::mask8x8, }; @@ -24,7 +24,7 @@ use std::arch::aarch64::*; macros::aarch64_define_register!(i8x8, int8x8_t, mask8x8, i8, 8, Neon); macros::aarch64_define_splat!(i8x8, vmov_n_s8); -macros::aarch64_define_loadstore!(i8x8, vld1_s8, vst1_s8, 8); +macros::aarch64_define_loadstore!(i8x8, vld1_s8, algorithms::load_first::i8x8, vst1_s8, 8); helpers::unsafe_map_binary_op!(i8x8, std::ops::Add, add, vadd_s8, "neon"); helpers::unsafe_map_binary_op!(i8x8, std::ops::Sub, sub, vsub_s8, "neon"); diff --git a/diskann-wide/src/arch/aarch64/macros.rs b/diskann-wide/src/arch/aarch64/macros.rs index 4fab7de37..972ae32ea 100644 --- a/diskann-wide/src/arch/aarch64/macros.rs +++ b/diskann-wide/src/arch/aarch64/macros.rs @@ -186,7 +186,7 @@ macro_rules! aarch64_define_splat { } macro_rules! aarch64_define_loadstore { - ($type:ty, $load:expr, $store:expr, $lanes:literal) => { + ($type:ty, $load:expr, $load_first:expr, $store:expr, $lanes:literal) => { impl AArchLoadStore for $type { #[inline(always)] unsafe fn load_simd( @@ -222,11 +222,7 @@ macro_rules! aarch64_define_loadstore { first: usize, ) -> Self { // SAFETY: Inherited from caller. - let e = unsafe { - Emulated::<_, $lanes>::load_simd_first($crate::arch::Scalar, ptr, first) - }; - - Self::from_array(arch, e.to_array()) + Self(unsafe { ($load_first)(arch, ptr, first) }) } #[inline(always)] @@ -262,7 +258,7 @@ macro_rules! aarch64_define_cmp { impl SIMDPartialEq for $type { #[inline(always)] fn eq_simd(self, other: Self) -> Self::Mask { - // SAFETY: Inclusion of this macro is gated by the "neon" feature. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. // // It is the caller's responsibility to instantiate the macro with an // intrinsics also gated by "neon". @@ -271,7 +267,7 @@ macro_rules! aarch64_define_cmp { #[inline(always)] fn ne_simd(self, other: Self) -> Self::Mask { - // SAFETY: Inclusion of this macro is gated by the "neon" feature. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. 
// // It is the caller's responsibility to instantiate the macro with an // intrinsics also gated by "neon". @@ -282,7 +278,7 @@ macro_rules! aarch64_define_cmp { impl SIMDPartialOrd for $type { #[inline(always)] fn lt_simd(self, other: Self) -> Self::Mask { - // SAFETY: Inclusion of this macro is gated by the "neon" feature. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. // // It is the caller's responsibility to instantiate the macro with an // intrinsics also gated by "neon". @@ -291,7 +287,7 @@ macro_rules! aarch64_define_cmp { #[inline(always)] fn le_simd(self, other: Self) -> Self::Mask { - // SAFETY: Inclusion of this macro is gated by the "neon" feature. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. // // It is the caller's responsibility to instantiate the macro with an // intrinsics also gated by "neon". @@ -300,7 +296,7 @@ macro_rules! aarch64_define_cmp { #[inline(always)] fn gt_simd(self, other: Self) -> Self::Mask { - // SAFETY: Inclusion of this macro is gated by the "neon" feature. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. // // It is the caller's responsibility to instantiate the macro with an // intrinsics also gated by "neon". @@ -309,7 +305,7 @@ macro_rules! aarch64_define_cmp { #[inline(always)] fn ge_simd(self, other: Self) -> Self::Mask { - // SAFETY: Inclusion of this macro is gated by the "neon" feature. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. // // It is the caller's responsibility to instantiate the macro with an // intrinsics also gated by "neon". @@ -383,7 +379,7 @@ macro_rules! aarch64_define_bitops { type Output = Self; #[inline(always)] fn not(self) -> Self { - // SAFETY: Inclusion of this macro is gated by the "neon" feature. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. // // It is the caller's responsibility to instantiate the macro with an // intrinsics also gated by "neon". @@ -395,7 +391,7 @@ macro_rules! aarch64_define_bitops { type Output = Self; #[inline(always)] fn bitand(self, rhs: Self) -> Self { - // SAFETY: Inclusion of this macro is gated by the "neon" feature. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. // // It is the caller's responsibility to instantiate the macro with an // intrinsics also gated by "neon". @@ -407,7 +403,7 @@ macro_rules! aarch64_define_bitops { type Output = Self; #[inline(always)] fn bitor(self, rhs: Self) -> Self { - // SAFETY: Inclusion of this macro is gated by the "neon" feature. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. // // It is the caller's responsibility to instantiate the macro with an // intrinsics also gated by "neon". @@ -419,7 +415,7 @@ macro_rules! aarch64_define_bitops { type Output = Self; #[inline(always)] fn bitxor(self, rhs: Self) -> Self { - // SAFETY: Inclusion of this macro is gated by the "neon" feature. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. // // It is the caller's responsibility to instantiate the macro with an // intrinsics also gated by "neon". @@ -439,7 +435,7 @@ macro_rules! aarch64_define_bitops { if cfg!(miri) { self.emulated().shr(rhs.emulated()).as_simd(self.arch()) } else { - // SAFETY: Inclusion of this macro is gated by the "neon" feature. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. // // It is the caller's responsibility to instantiate the macro with an // intrinsics also gated by "neon". 
@@ -464,7 +460,7 @@ macro_rules! aarch64_define_bitops { if cfg!(miri) { self.emulated().shl(rhs.emulated()).as_simd(self.arch()) } else { - // SAFETY: Inclusion of this macro is gated by the "neon" feature. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. // // It is the caller's responsibility to instantiate the macro with an // intrinsics also gated by "neon". @@ -493,7 +489,7 @@ macro_rules! aarch64_define_bitops { if cfg!(miri) { self.emulated().shr(rhs).as_simd(self.arch()) } else { - // SAFETY: Inclusion of this macro is gated by the "neon" feature. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. // // It is the caller's responsibility to instantiate the macro with an // intrinsics also gated by "neon". @@ -515,7 +511,7 @@ macro_rules! aarch64_define_bitops { if cfg!(miri) { self.emulated().shl(rhs).as_simd(self.arch()) } else { - // SAFETY: Inclusion of this macro is gated by the "neon" feature. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. // // It is the caller's responsibility to instantiate the macro with an // intrinsics also gated by "neon". diff --git a/diskann-wide/src/arch/aarch64/masks.rs b/diskann-wide/src/arch/aarch64/masks.rs index 8658c4b94..d20044eb7 100644 --- a/diskann-wide/src/arch/aarch64/masks.rs +++ b/diskann-wide/src/arch/aarch64/masks.rs @@ -142,8 +142,7 @@ impl MaskOps for uint8x8_t { let array = self.to_array(); BitMask::from_fn(arch, |i| array[i] == u8::MAX) } else { - // SAFETY: Inclusion of this function is dependent on the "neon" target - // feature. This function does not access memory directly. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. let value = unsafe { let mask = vmov_n_u8(0x80); // Effectively creates [-7, -6, -5, -4, -3, -2, -1, 0] @@ -158,8 +157,7 @@ impl MaskOps for uint8x8_t { #[inline(always)] fn from_mask(mask: Self::BitMask) -> Self { const BIT_SELECTOR: u64 = 0x8040201008040201; - // SAFETY: Inclusion of this function is dependent on the "neon" target - // feature. This function does not access memory directly. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. unsafe { vtst_u8(vmov_n_u8(mask.0), vcreate_u8(BIT_SELECTOR)) } } @@ -167,8 +165,7 @@ impl MaskOps for uint8x8_t { fn keep_first(_arch: Neon, lanes: usize) -> Self { const INDICES: u64 = 0x0706050403020100; let n = lanes.min(8) as u8; - // SAFETY: Inclusion of this function is dependent on the "neon" target - // feature. This function does not access memory directly. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. unsafe { vclt_u8(vcreate_u8(INDICES), vmov_n_u8(n)) } } @@ -191,8 +188,7 @@ impl MaskOps for uint8x16_t { let array = self.to_array(); BitMask::from_fn(arch, |i| array[i] == u8::MAX) } else { - // SAFETY: Inclusion of this function is dependent on the "neon" target - // feature. This function does not access memory directly. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. let value = unsafe { let mask = vmovq_n_u8(0x80); let masked = vandq_u8(self, mask); @@ -216,8 +212,7 @@ impl MaskOps for uint8x16_t { let low = mask as u8; let high = (mask >> 8) as u8; - // SAFETY: Inclusion of this function is dependent on the "neon" target - // feature. This function does not access memory directly. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. 
unsafe { vtstq_u8( vcombine_u8(vmov_n_u8(low), vmov_n_u8(high)), @@ -231,8 +226,7 @@ impl MaskOps for uint8x16_t { const LO: u64 = 0x0706050403020100; const HI: u64 = 0x0F0E0D0C0B0A0908; let n = lanes.min(16) as u8; - // SAFETY: Inclusion of this function is dependent on the "neon" target - // feature. This function does not access memory directly. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. unsafe { vcltq_u8(vcombine_u8(vcreate_u8(LO), vcreate_u8(HI)), vmovq_n_u8(n)) } } @@ -271,8 +265,7 @@ impl MaskOps for uint16x4_t { // // Thus, everything gets compressed down to 4-bits. // - // SAFETY: Inclusion of this function is dependent on the "neon" target - // feature. This function does not access memory directly. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. let value = unsafe { let bits = vshr_n_u16(self, 15); let paired = vsra_n_u32( @@ -295,8 +288,7 @@ impl MaskOps for uint16x4_t { #[inline(always)] fn from_mask(mask: Self::BitMask) -> Self { const BIT_SELECTOR: u64 = 0x0008_0004_0002_0001; - // SAFETY: Inclusion of this function is dependent on the "neon" target - // feature. This function does not access memory directly. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. unsafe { vtst_u16(vmov_n_u16(mask.0 as u16), vcreate_u16(BIT_SELECTOR)) } } @@ -304,8 +296,7 @@ impl MaskOps for uint16x4_t { fn keep_first(_arch: Neon, lanes: usize) -> Self { const INDICES: u64 = 0x0003_0002_0001_0000; let n = lanes.min(4) as u16; - // SAFETY: Inclusion of this function is dependent on the "neon" target - // feature. This function does not access memory directly. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. unsafe { vclt_u16(vcreate_u16(INDICES), vmov_n_u16(n)) } } @@ -328,8 +319,7 @@ impl MaskOps for uint16x8_t { let array = self.to_array(); BitMask::from_fn(arch, |i| array[i] == u16::MAX) } else { - // SAFETY: Inclusion of this function is dependent on the "neon" target - // feature. This function does not access memory directly. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. let value = unsafe { // Effectively creates [-15, -14, -13, -12, -11, -10, -9, -8] let shifts = vcombine_s16( @@ -348,8 +338,7 @@ impl MaskOps for uint16x8_t { fn from_mask(mask: Self::BitMask) -> Self { const BIT_SELECTOR_LOW: u64 = 0x0008_0004_0002_0001; const BIT_SELECTOR_HIGH: u64 = 0x0080_0040_0020_0010; - // SAFETY: Inclusion of this function is dependent on the "neon" target - // feature. This function does not access memory directly. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. unsafe { vtstq_u16( vmovq_n_u16(mask.0 as u16), @@ -366,8 +355,7 @@ impl MaskOps for uint16x8_t { const LO: u64 = 0x0003_0002_0001_0000; const HI: u64 = 0x0007_0006_0005_0004; let n = lanes.min(8) as u16; - // SAFETY: Inclusion of this function is dependent on the "neon" target - // feature. This function does not access memory directly. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. unsafe { vcltq_u16( vcombine_u16(vcreate_u16(LO), vcreate_u16(HI)), @@ -398,8 +386,7 @@ impl MaskOps for uint32x2_t { // Normalize each lane to 0 or 1, then use shift-right-accumulate to pack // bits into position. // - // SAFETY: Inclusion of this function is dependent on the "neon" target - // feature. This function does not access memory directly. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. 
let value = unsafe { let bits = vshr_n_u32(self, 31); let packed = vsra_n_u64( @@ -417,8 +404,7 @@ impl MaskOps for uint32x2_t { #[inline(always)] fn from_mask(mask: Self::BitMask) -> Self { const BIT_SELECTOR: u64 = 0x0000_0002_0000_0001; - // SAFETY: Inclusion of this function is dependent on the "neon" target - // feature. This function does not access memory directly. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. unsafe { vtst_u32(vmov_n_u32(mask.0 as u32), vcreate_u32(BIT_SELECTOR)) } } @@ -426,8 +412,7 @@ impl MaskOps for uint32x2_t { fn keep_first(_arch: Neon, lanes: usize) -> Self { const INDICES: u64 = 0x0000_0001_0000_0000; let n = lanes.min(2) as u32; - // SAFETY: Inclusion of this function is dependent on the "neon" target - // feature. This function does not access memory directly. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. unsafe { vclt_u32(vcreate_u32(INDICES), vmov_n_u32(n)) } } @@ -453,8 +438,7 @@ impl MaskOps for uint32x4_t { // Refer to the implementation for `uint16x4_t`. The approach here is // identical, just twice as wide. // - // SAFETY: Inclusion of this function is dependent on the "neon" target - // feature. This function does not access memory directly. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. let value = unsafe { let bits = vshrq_n_u32(self, 31); let paired = vsraq_n_u64( @@ -480,8 +464,7 @@ impl MaskOps for uint32x4_t { fn from_mask(mask: Self::BitMask) -> Self { const BIT_SELECTOR_LOW: u64 = 0x0000_0002_0000_0001; const BIT_SELECTOR_HIGH: u64 = 0x0000_0008_0000_0004; - // SAFETY: Inclusion of this function is dependent on the "neon" target - // feature. This function does not access memory directly. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. unsafe { vtstq_u32( vmovq_n_u32(mask.0 as u32), @@ -498,8 +481,7 @@ impl MaskOps for uint32x4_t { const LO: u64 = 0x0000_0001_0000_0000; const HI: u64 = 0x0000_0003_0000_0002; let n = lanes.min(4) as u32; - // SAFETY: Inclusion of this function is dependent on the "neon" target - // feature. This function does not access memory directly. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. unsafe { vcltq_u32( vcombine_u32(vcreate_u32(LO), vcreate_u32(HI)), @@ -529,8 +511,7 @@ impl MaskOps for uint64x1_t { } else { // Single lane: just shift the MSB down to bit 0 and extract. // - // SAFETY: Inclusion of this function is dependent on the "neon" target - // feature. This function does not access memory directly. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. let value = unsafe { vget_lane_u8(vreinterpret_u8_u64(vshr_n_u64(self, 63)), 0) }; @@ -542,16 +523,14 @@ impl MaskOps for uint64x1_t { #[inline(always)] fn from_mask(mask: Self::BitMask) -> Self { // Single lane: negation maps 0→0 and 1→0xFFFF_FFFF_FFFF_FFFF. - // SAFETY: Inclusion of this function is dependent on the "neon" target - // feature. This function does not access memory directly. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. unsafe { vcreate_u64((mask.0 as u64).wrapping_neg()) } } #[inline(always)] fn keep_first(_arch: Neon, lanes: usize) -> Self { // Single lane: negation maps 0→0 and 1→0xFFFF_FFFF_FFFF_FFFF. - // SAFETY: Inclusion of this function is dependent on the "neon" target - // feature. This function does not access memory directly. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. 
unsafe { vcreate_u64((lanes.min(1) as u64).wrapping_neg()) } } @@ -577,8 +556,7 @@ impl MaskOps for uint64x2_t { // Normalize each lane to 0 or 1, then narrow to a 64-bit register and // use shift-right-accumulate to combine the two bits. // - // SAFETY: Inclusion of this function is dependent on the "neon" target - // feature. This function does not access memory directly. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. let value = unsafe { let bits = vshrq_n_u64(self, 63); let narrowed = vmovn_u64(bits); @@ -598,8 +576,7 @@ impl MaskOps for uint64x2_t { fn from_mask(mask: Self::BitMask) -> Self { const BIT_SELECTOR_LOW: u64 = 0x0000_0000_0000_0001; const BIT_SELECTOR_HIGH: u64 = 0x0000_0000_0000_0002; - // SAFETY: Inclusion of this function is dependent on the "neon" target - // feature. This function does not access memory directly. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. unsafe { vtstq_u64( vmovq_n_u64(mask.0 as u64), @@ -616,8 +593,7 @@ impl MaskOps for uint64x2_t { const LO: u64 = 0; const HI: u64 = 1; let n = lanes.min(2) as u64; - // SAFETY: Inclusion of this function is dependent on the "neon" target - // feature. This function does not access memory directly. + // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. unsafe { vcltq_u64( vcombine_u64(vcreate_u64(LO), vcreate_u64(HI)), diff --git a/diskann-wide/src/arch/aarch64/mod.rs b/diskann-wide/src/arch/aarch64/mod.rs index 74df594f9..87c6b4e39 100644 --- a/diskann-wide/src/arch/aarch64/mod.rs +++ b/diskann-wide/src/arch/aarch64/mod.rs @@ -11,6 +11,8 @@ use crate::{ }, }; +mod algorithms; + pub mod f16x4_; pub use f16x4_::f16x4; diff --git a/diskann-wide/src/arch/aarch64/u16x8_.rs b/diskann-wide/src/arch/aarch64/u16x8_.rs index bc5e941b1..6e1b1d51f 100644 --- a/diskann-wide/src/arch/aarch64/u16x8_.rs +++ b/diskann-wide/src/arch/aarch64/u16x8_.rs @@ -12,7 +12,7 @@ use crate::{ // AArch64 masks use super::{ - Neon, + Neon, algorithms, macros::{self, AArchLoadStore, AArchSplat}, masks::mask16x8, }; @@ -26,7 +26,13 @@ use std::arch::aarch64::*; macros::aarch64_define_register!(u16x8, uint16x8_t, mask16x8, u16, 8, Neon); macros::aarch64_define_splat!(u16x8, vmovq_n_u16); -macros::aarch64_define_loadstore!(u16x8, vld1q_u16, vst1q_u16, 8); +macros::aarch64_define_loadstore!( + u16x8, + vld1q_u16, + algorithms::load_first::u16x8, + vst1q_u16, + 8 +); helpers::unsafe_map_binary_op!(u16x8, std::ops::Add, add, vaddq_u16, "neon"); helpers::unsafe_map_binary_op!(u16x8, std::ops::Sub, sub, vsubq_u16, "neon"); diff --git a/diskann-wide/src/arch/aarch64/u32x4_.rs b/diskann-wide/src/arch/aarch64/u32x4_.rs index 58d001390..e25844288 100644 --- a/diskann-wide/src/arch/aarch64/u32x4_.rs +++ b/diskann-wide/src/arch/aarch64/u32x4_.rs @@ -10,7 +10,7 @@ use crate::{ // AArch64 masks use super::{ - Neon, + Neon, algorithms, macros::{self, AArchLoadStore, AArchSplat}, masks::mask32x4, u8x16, @@ -25,7 +25,13 @@ use std::arch::{aarch64::*, asm}; macros::aarch64_define_register!(u32x4, uint32x4_t, mask32x4, u32, 4, Neon); macros::aarch64_define_splat!(u32x4, vmovq_n_u32); -macros::aarch64_define_loadstore!(u32x4, vld1q_u32, vst1q_u32, 4); +macros::aarch64_define_loadstore!( + u32x4, + vld1q_u32, + algorithms::load_first::u32x4, + vst1q_u32, + 4 +); helpers::unsafe_map_binary_op!(u32x4, std::ops::Add, add, vaddq_u32, "neon"); helpers::unsafe_map_binary_op!(u32x4, std::ops::Sub, sub, vsubq_u32, "neon"); diff --git a/diskann-wide/src/arch/aarch64/u64x2_.rs 
b/diskann-wide/src/arch/aarch64/u64x2_.rs index 45e1fbaab..7d0e45bde 100644 --- a/diskann-wide/src/arch/aarch64/u64x2_.rs +++ b/diskann-wide/src/arch/aarch64/u64x2_.rs @@ -13,7 +13,7 @@ use crate::{ // AArch64 masks use super::{ - Neon, + Neon, algorithms, macros::{self, AArchLoadStore, AArchSplat}, masks::mask64x2, }; @@ -46,7 +46,13 @@ pub(super) unsafe fn emulated_vminq_u64(x: uint64x2_t, y: uint64x2_t) -> uint64x macros::aarch64_define_register!(u64x2, uint64x2_t, mask64x2, u64, 2, Neon); macros::aarch64_define_splat!(u64x2, vmovq_n_u64); -macros::aarch64_define_loadstore!(u64x2, vld1q_u64, vst1q_u64, 2); +macros::aarch64_define_loadstore!( + u64x2, + vld1q_u64, + algorithms::load_first::u64x2, + vst1q_u64, + 2 +); helpers::unsafe_map_binary_op!(u64x2, std::ops::Add, add, vaddq_u64, "neon"); helpers::unsafe_map_binary_op!(u64x2, std::ops::Sub, sub, vsubq_u64, "neon"); diff --git a/diskann-wide/src/arch/aarch64/u8x16_.rs b/diskann-wide/src/arch/aarch64/u8x16_.rs index 7933f9437..be1a1f8f5 100644 --- a/diskann-wide/src/arch/aarch64/u8x16_.rs +++ b/diskann-wide/src/arch/aarch64/u8x16_.rs @@ -12,7 +12,7 @@ use crate::{ // AArch64 masks use super::{ - Neon, + Neon, algorithms, macros::{self, AArchLoadStore, AArchSplat}, masks::mask8x16, u8x8, @@ -27,7 +27,7 @@ use std::arch::aarch64::*; macros::aarch64_define_register!(u8x16, uint8x16_t, mask8x16, u8, 16, Neon); macros::aarch64_define_splat!(u8x16, vmovq_n_u8); -macros::aarch64_define_loadstore!(u8x16, vld1q_u8, vst1q_u8, 16); +macros::aarch64_define_loadstore!(u8x16, vld1q_u8, algorithms::load_first::u8x16, vst1q_u8, 16); macros::aarch64_splitjoin!(u8x16, u8x8, vget_low_u8, vget_high_u8, vcombine_u8); helpers::unsafe_map_binary_op!(u8x16, std::ops::Add, add, vaddq_u8, "neon"); diff --git a/diskann-wide/src/arch/aarch64/u8x8_.rs b/diskann-wide/src/arch/aarch64/u8x8_.rs index 914000037..44812de81 100644 --- a/diskann-wide/src/arch/aarch64/u8x8_.rs +++ b/diskann-wide/src/arch/aarch64/u8x8_.rs @@ -12,7 +12,7 @@ use crate::{ // AArch64 masks use super::{ - Neon, + Neon, algorithms, macros::{self, AArchLoadStore, AArchSplat}, masks::mask8x8, }; @@ -26,7 +26,7 @@ use std::arch::aarch64::*; macros::aarch64_define_register!(u8x8, uint8x8_t, mask8x8, u8, 8, Neon); macros::aarch64_define_splat!(u8x8, vmov_n_u8); -macros::aarch64_define_loadstore!(u8x8, vld1_u8, vst1_u8, 8); +macros::aarch64_define_loadstore!(u8x8, vld1_u8, algorithms::load_first::u8x8, vst1_u8, 8); helpers::unsafe_map_binary_op!(u8x8, std::ops::Add, add, vadd_u8, "neon"); helpers::unsafe_map_binary_op!(u8x8, std::ops::Sub, sub, vsub_u8, "neon"); diff --git a/diskann-wide/src/doubled.rs b/diskann-wide/src/doubled.rs index 66e69be65..4f0ecd681 100644 --- a/diskann-wide/src/doubled.rs +++ b/diskann-wide/src/doubled.rs @@ -121,6 +121,27 @@ macro_rules! double_vector { ) } + #[inline(always)] + unsafe fn load_simd_first( + arch: Self::Arch, + ptr: *const Self::Scalar, + first: usize, + ) -> Self { + const HALF: usize = { $N / 2 }; + Self( + // SAFETY: Inherited from caller. + unsafe { <$repr as $crate::SIMDVector>::load_simd_first(arch, ptr, first) }, + // SAFETY: Inherited from caller. + unsafe { + <$repr as $crate::SIMDVector>::load_simd_first( + arch, + ptr.wrapping_add(HALF), + first.saturating_sub(HALF), + ) + }, + ) + } + #[inline(always)] unsafe fn store_simd(self, ptr: *mut Self::Scalar) { // SAFETY: The caller asserts this pointer access is safe. @@ -143,6 +164,19 @@ macro_rules! 
double_vector { .store_simd_masked_logical(ptr.wrapping_add({ $N / 2 }), mask.1) }; } + + #[inline(always)] + unsafe fn store_simd_first(self, ptr: *mut Self::Scalar, first: usize) { + const HALF: usize = { $N / 2 }; + + // SAFETY: Inherited from caller. + unsafe { self.0.store_simd_first(ptr, first) }; + // SAFETY: Inherited from caller. + unsafe { + self.1 + .store_simd_first(ptr.wrapping_add(HALF), first.saturating_sub(HALF)) + }; + } } }; } From 8f3aedcdc613d17fccf82ef5e6abc8d92ba254ae Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Mon, 16 Feb 2026 12:16:29 -0800 Subject: [PATCH 06/10] Wrapping up!. --- .cargo/config.toml | 4 + .../examples/test-aarch64.json | 1390 +++++++++++++++++ diskann-benchmark-simd/src/bin.rs | 10 +- diskann-benchmark-simd/src/lib.rs | 139 +- .../src/model/pq/distance/dynamic.rs | 2 +- .../src/algorithms/hadamard.rs | 13 +- diskann-quantization/src/bits/distances.rs | 49 +- .../src/arch/aarch64/algorithms/load_first.rs | 12 + diskann-wide/src/arch/aarch64/double.rs | 92 +- diskann-wide/src/arch/aarch64/f16x4_.rs | 22 +- diskann-wide/src/arch/aarch64/f16x8_.rs | 27 +- diskann-wide/src/arch/aarch64/f32x2_.rs | 28 +- diskann-wide/src/arch/aarch64/f32x4_.rs | 38 +- diskann-wide/src/arch/aarch64/i16x8_.rs | 32 +- diskann-wide/src/arch/aarch64/i32x4_.rs | 40 +- diskann-wide/src/arch/aarch64/i64x2_.rs | 28 +- diskann-wide/src/arch/aarch64/i8x16_.rs | 30 +- diskann-wide/src/arch/aarch64/i8x8_.rs | 28 +- diskann-wide/src/arch/aarch64/masks.rs | 18 +- diskann-wide/src/arch/aarch64/mod.rs | 34 + diskann-wide/src/arch/aarch64/u16x8_.rs | 26 +- diskann-wide/src/arch/aarch64/u32x4_.rs | 30 +- diskann-wide/src/arch/aarch64/u64x2_.rs | 26 +- diskann-wide/src/arch/aarch64/u8x16_.rs | 28 +- diskann-wide/src/arch/aarch64/u8x8_.rs | 26 +- diskann-wide/src/lib.rs | 4 +- 26 files changed, 1894 insertions(+), 282 deletions(-) create mode 100644 diskann-benchmark-simd/examples/test-aarch64.json diff --git a/.cargo/config.toml b/.cargo/config.toml index 34bee66f8..906b9fc85 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -7,5 +7,9 @@ rustflags = ["-C", "control-flow-guard"] [target.'cfg(target_arch="x86_64")'] rustflags = ["-C", "target-cpu=x86-64-v3"] +# If running on an Aarch64 CPU, this enables the `neon` and `dotprod`. +# +# Generally speaking, CPUs like Apple M-series and Graviton 2+ will support `dotprod`, so +# enabling this seems acceptable. 
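
The aarch64 `rustflags` entry just below bakes `neon` and `dotprod` in at compile time. For builds that cannot make that assumption, the same capabilities can also be probed at runtime on aarch64 via the standard detection macro; a minimal sketch follows (illustrative only, not part of the patch).

// Minimal sketch: probe NEON/dotprod availability at runtime instead of relying on
// `-C target-feature=+neon,+dotprod`. The detection macro is part of std on aarch64.
#[cfg(target_arch = "aarch64")]
fn main() {
    let neon = std::arch::is_aarch64_feature_detected!("neon");
    let dotprod = std::arch::is_aarch64_feature_detected!("dotprod");
    // A dispatcher could select an SDOT/UDOT-based kernel only when `dotprod` is true,
    // and otherwise fall back to a plain NEON (or scalar) implementation.
    println!("neon: {neon}, dotprod: {dotprod}");
}

#[cfg(not(target_arch = "aarch64"))]
fn main() {
    println!("not an aarch64 target");
}
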
[target.'cfg(target_arch="aarch64")'] rustflags = ["-C", "target-feature=+neon,+dotprod"] diff --git a/diskann-benchmark-simd/examples/test-aarch64.json b/diskann-benchmark-simd/examples/test-aarch64.json new file mode 100644 index 000000000..15afa0a31 --- /dev/null +++ b/diskann-benchmark-simd/examples/test-aarch64.json @@ -0,0 +1,1390 @@ +{ + "search_directories": [], + "jobs": [ + { + "type": "simd-op", + "content": { + "query_type": "float32", + "data_type": "float32", + "arch": "neon", + "runs": [ + { + "distance": "squared_l2", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 768, + "num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 768, + "num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 768, + "num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + } + ] + } + }, + { + "type": "simd-op", + "content": { + "query_type": "float32", + "data_type": "float32", + "arch": "scalar", + "runs": [ + { + "distance": "squared_l2", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 768, + "num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": 
"inner_product", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 768, + "num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 768, + "num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + } + ] + } + }, + { + "type": "simd-op", + "content": { + "query_type": "float32", + "data_type": "float32", + "arch": "reference", + "runs": [ + { + "distance": "squared_l2", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 768, + "num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 768, + "num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 768, + "num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + } + ] + } + }, + + { + "type": "simd-op", + "content": { + "query_type": "float16", + "data_type": "float16", + "arch": "neon", + "runs": [ + { + "distance": "squared_l2", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 768, + "num_points": 
12, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 768, + "num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 768, + "num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + } + ] + } + }, + { + "type": "simd-op", + "content": { + "query_type": "float16", + "data_type": "float16", + "arch": "scalar", + "runs": [ + { + "distance": "squared_l2", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 768, + "num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 768, + "num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 768, + "num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + } + ] + } + }, + { + "type": "simd-op", + "content": { + "query_type": "float16", + "data_type": "float16", + "arch": "reference", + "runs": [ + { + "distance": "squared_l2", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + 
"num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 768, + "num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 768, + "num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 768, + "num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + } + ] + } + }, + + + { + "type": "simd-op", + "content": { + "query_type": "uint8", + "data_type": "uint8", + "arch": "neon", + "runs": [ + { + "distance": "squared_l2", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 768, + "num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 768, + "num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": 
"cosine", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 768, + "num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + } + ] + } + }, + { + "type": "simd-op", + "content": { + "query_type": "uint8", + "data_type": "uint8", + "arch": "scalar", + "runs": [ + { + "distance": "squared_l2", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 768, + "num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 768, + "num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 768, + "num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + } + ] + } + }, + { + "type": "simd-op", + "content": { + "query_type": "uint8", + "data_type": "uint8", + "arch": "reference", + "runs": [ + { + "distance": "squared_l2", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 768, + "num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 768, + 
"num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 768, + "num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + } + ] + } + }, + + + { + "type": "simd-op", + "content": { + "query_type": "int8", + "data_type": "int8", + "arch": "neon", + "runs": [ + { + "distance": "squared_l2", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 768, + "num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 768, + "num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 768, + "num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + } + ] + } + }, + { + "type": "simd-op", + "content": { + "query_type": "int8", + "data_type": "int8", + "arch": "scalar", + "runs": [ + { + "distance": "squared_l2", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 768, + "num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + 
"num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 768, + "num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 768, + "num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + } + ] + } + }, + { + "type": "simd-op", + "content": { + "query_type": "int8", + "data_type": "int8", + "arch": "reference", + "runs": [ + { + "distance": "squared_l2", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "squared_l2", + "dim": 768, + "num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "inner_product", + "dim": 768, + "num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 100, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 128, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 160, + "num_points": 50, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 384, + "num_points": 24, + "loops_per_measurement": 5, + "num_measurements": 1 + }, + { + "distance": "cosine", + "dim": 768, + "num_points": 12, + "loops_per_measurement": 5, + "num_measurements": 1 + } + ] + } + } + ] +} diff --git a/diskann-benchmark-simd/src/bin.rs b/diskann-benchmark-simd/src/bin.rs index 74d11cd15..c570b0111 100644 --- a/diskann-benchmark-simd/src/bin.rs +++ b/diskann-benchmark-simd/src/bin.rs @@ -57,9 +57,17 @@ mod tests { #[test] fn integration_test() { + let input = if cfg!(target_arch = "x86_64") { + "test.json" + } else if cfg!(target_arch = "aarch64") { + "test-aarch64.json" 
+ } else { + panic!("Please add a dedicated test input for the compiled architecture"); + }; + let input_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")) .join("examples") - .join("test.json"); + .join(input); let tempdir = tempfile::tempdir().unwrap(); let output_path = tempdir.path().join("output.json"); diff --git a/diskann-benchmark-simd/src/lib.rs b/diskann-benchmark-simd/src/lib.rs index 94509e0e4..36c8b2adc 100644 --- a/diskann-benchmark-simd/src/lib.rs +++ b/diskann-benchmark-simd/src/lib.rs @@ -84,6 +84,7 @@ pub(crate) enum Arch { #[serde(rename = "x86-64-v3")] #[allow(non_camel_case_types)] X86_64_V3, + Neon, Scalar, Reference, } @@ -93,6 +94,7 @@ impl std::fmt::Display for Arch { let st = match self { Self::X86_64_V4 => "x86-64-v4", Self::X86_64_V3 => "x86-64-v3", + Self::Neon => "neon", Self::Scalar => "scalar", Self::Reference => "reference", }; @@ -277,6 +279,32 @@ fn register_benchmarks_impl(dispatcher: &mut diskann_benchmark_runner::registry: Kernel<'static, diskann_wide::arch::x86_64::V3, i8, i8> ); + // x86-64-v3 + register!( + "aarch64", + dispatcher, + "simd-op-f32xf32-aarch64_neon", + Kernel<'static, diskann_wide::arch::aarch64::Neon, f32, f32> + ); + register!( + "aarch64", + dispatcher, + "simd-op-f16xf16-aarch64_neon", + Kernel<'static, diskann_wide::arch::aarch64::Neon, f16, f16> + ); + register!( + "aarch64", + dispatcher, + "simd-op-u8xu8-aarch64_neon", + Kernel<'static, diskann_wide::arch::aarch64::Neon, u8, u8> + ); + register!( + "aarch64", + dispatcher, + "simd-op-i8xi8-aarch64_neon", + Kernel<'static, diskann_wide::arch::aarch64::Neon, i8, i8> + ); + // scalar register!( dispatcher, @@ -377,7 +405,7 @@ impl DispatchRule for Identity { if *from == Arch::Reference { Ok(MatchScore(0)) } else { - Err(FailureScore(0)) + Err(FailureScore(10)) } } @@ -407,7 +435,7 @@ impl DispatchRule for Identity { if *from == Arch::Scalar { Ok(MatchScore(0)) } else { - Err(FailureScore(0)) + Err(FailureScore(1)) } } @@ -430,72 +458,56 @@ impl DispatchRule for Identity { } } -#[cfg(target_arch = "x86_64")] -impl DispatchRule for Identity { - type Error = ArchNotSupported; - - fn try_match(from: &Arch) -> Result { - if *from == Arch::X86_64_V4 { - Ok(MatchScore(0)) - } else { - Err(FailureScore(0)) - } - } - - fn convert(from: Arch) -> Result { - assert_eq!(from, Arch::X86_64_V4); - diskann_wide::arch::x86_64::V4::new_checked() - .ok_or(ArchNotSupported(from)) - .map(Identity) - } - - fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&Arch>) -> std::fmt::Result { - match from { - None => write!(f, "x86-64-v4"), - Some(arch) => { - if Self::try_match(arch).is_ok() { - write!(f, "matched {}", arch) +macro_rules! 
match_arch { + ($target_arch:literal, $arch:path, $enum:ident) => { + #[cfg(target_arch = $target_arch)] + impl DispatchRule for Identity<$arch> { + type Error = ArchNotSupported; + + fn try_match(from: &Arch) -> Result { + let available = <$arch>::new_checked().is_some(); + if available && *from == Arch::$enum { + Ok(MatchScore(0)) + } else if !available && *from == Arch::$enum { + Err(FailureScore(0)) } else { - write!(f, "expected {}, got {}", Arch::X86_64_V4, arch) + Err(FailureScore(1)) } } - } - } -} - -#[cfg(target_arch = "x86_64")] -impl DispatchRule for Identity { - type Error = ArchNotSupported; - - fn try_match(from: &Arch) -> Result { - if *from == Arch::X86_64_V3 { - Ok(MatchScore(0)) - } else { - Err(FailureScore(0)) - } - } - fn convert(from: Arch) -> Result { - assert_eq!(from, Arch::X86_64_V3); - diskann_wide::arch::x86_64::V3::new_checked() - .ok_or(ArchNotSupported(from)) - .map(Identity) - } + fn convert(from: Arch) -> Result { + assert_eq!(from, Arch::$enum); + <$arch>::new_checked() + .ok_or(ArchNotSupported(from)) + .map(Identity) + } - fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&Arch>) -> std::fmt::Result { - match from { - None => write!(f, "x86-64-v3"), - Some(arch) => { - if Self::try_match(arch).is_ok() { - write!(f, "matched {}", arch) - } else { - write!(f, "expected {}, got {}", Arch::X86_64_V3, arch) + fn description( + f: &mut std::fmt::Formatter<'_>, + from: Option<&Arch>, + ) -> std::fmt::Result { + let available = <$arch>::new_checked().is_some(); + match from { + None => write!(f, "{}", Arch::$enum), + Some(arch) => { + if Self::try_match(arch).is_ok() { + write!(f, "matched {}", arch) + } else if !available && *arch == Arch::$enum { + write!(f, "matched {} but unsupported by this CPU", Arch::$enum) + } else { + write!(f, "expected {}, got {}", Arch::$enum, arch) + } + } } } } - } + }; } +match_arch!("x86_64", diskann_wide::arch::x86_64::V4, X86_64_V4); +match_arch!("x86_64", diskann_wide::arch::x86_64::V3, X86_64_V3); +match_arch!("aarch64", diskann_wide::arch::aarch64::Neon, Neon); + impl<'a, A, Q, D> DispatchRule<&'a SimdOp> for Kernel<'a, A, Q, D> where datatype::Type: DispatchRule, @@ -513,9 +525,10 @@ where if datatype::Type::::try_match(&from.data_type).is_err() { *failscore.get_or_insert(0) += 10; } - if Identity::::try_match(&from.arch).is_err() { - *failscore.get_or_insert(0) += 2; + if let Err(FailureScore(score)) = Identity::::try_match(&from.arch) { + *failscore.get_or_insert(0) += 2 + score; } + match failscore { None => Ok(MatchScore(0)), Some(score) => Err(FailureScore(score)), @@ -549,13 +562,13 @@ where } Some(input) => { if let Err(err) = datatype::Type::::try_match_verbose(&input.query_type) { - describeln!(f, "- Mismatched query type: {}", err)?; + describeln!(f, "\n - Mismatched query type: {}", err)?; } if let Err(err) = datatype::Type::::try_match_verbose(&input.data_type) { - describeln!(f, "- Mismatched data type: {}", err)?; + describeln!(f, "\n - Mismatched data type: {}", err)?; } if let Err(err) = Identity::::try_match_verbose(&input.arch) { - describeln!(f, "- Mismatched architecture: {}", err)?; + describeln!(f, "\n - Mismatched architecture: {}", err)?; } } } diff --git a/diskann-providers/src/model/pq/distance/dynamic.rs b/diskann-providers/src/model/pq/distance/dynamic.rs index 108fe0b30..dc0c6a73c 100644 --- a/diskann-providers/src/model/pq/distance/dynamic.rs +++ b/diskann-providers/src/model/pq/distance/dynamic.rs @@ -369,7 +369,7 @@ mod tests { let num_trials = 100; let errors = 
test_utils::RelativeAndAbsolute { - relative: 6e-7, + relative: 6.3e-7, absolute: 0.0, }; diff --git a/diskann-quantization/src/algorithms/hadamard.rs b/diskann-quantization/src/algorithms/hadamard.rs index 385a8bb05..5be88d473 100644 --- a/diskann-quantization/src/algorithms/hadamard.rs +++ b/diskann-quantization/src/algorithms/hadamard.rs @@ -491,7 +491,10 @@ mod tests { // Queue up a list of implementations. type Implementation = Box; - #[cfg_attr(not(target_arch = "x86_64"), expect(unused_mut))] + #[cfg_attr( + not(any(target_arch = "x86_64", target_arch = "aarch64")), + expect(unused_mut) + )] let mut impls: Vec<(Implementation, &'static str)> = vec![ ( Box::new(|x| hadamard_transform(x).unwrap()), @@ -515,6 +518,14 @@ mod tests { )); } + #[cfg(target_arch = "aarch64")] + if let Some(arch) = diskann_wide::arch::aarch64::Neon::new_checked() { + impls.push(( + Box::new(move |x| arch.run1(HadamardTransform, x).unwrap()), + "neon", + )); + } + for (f, kernel) in impls.into_iter() { let mut src_clone = src.clone(); f(src_clone.as_mut_slice()); diff --git a/diskann-quantization/src/bits/distances.rs b/diskann-quantization/src/bits/distances.rs index 9dd1359ed..f2b18f2d1 100644 --- a/diskann-quantization/src/bits/distances.rs +++ b/diskann-quantization/src/bits/distances.rs @@ -2135,34 +2135,42 @@ mod tests { // However, some SIMD kernels (especially for the lower bit widths), require higher bounds // to trigger all possible corner cases. static BITSLICE_TEST_BOUNDS: LazyLock> = LazyLock::new(|| { - use ArchKey::{Scalar, X86_64_V3, X86_64_V4}; + use ArchKey::{Neon, Scalar, X86_64_V3, X86_64_V4}; [ (Key::new(1, Scalar), Bounds::new(64, 64)), (Key::new(1, X86_64_V3), Bounds::new(256, 256)), (Key::new(1, X86_64_V4), Bounds::new(256, 256)), + (Key::new(1, Neon), Bounds::new(64, 64)), (Key::new(2, Scalar), Bounds::new(64, 64)), // Need a higher miri-amount due to the larget block size (Key::new(2, X86_64_V3), Bounds::new(512, 300)), (Key::new(2, X86_64_V4), Bounds::new(768, 600)), // main loop processes 256 items + (Key::new(2, Neon), Bounds::new(64, 64)), (Key::new(3, Scalar), Bounds::new(64, 64)), (Key::new(3, X86_64_V3), Bounds::new(256, 96)), (Key::new(3, X86_64_V4), Bounds::new(256, 96)), + (Key::new(3, Neon), Bounds::new(64, 64)), (Key::new(4, Scalar), Bounds::new(64, 64)), // Need a higher miri-amount due to the larget block size (Key::new(4, X86_64_V3), Bounds::new(256, 150)), (Key::new(4, X86_64_V4), Bounds::new(256, 150)), + (Key::new(4, Neon), Bounds::new(64, 64)), (Key::new(5, Scalar), Bounds::new(64, 64)), (Key::new(5, X86_64_V3), Bounds::new(256, 96)), (Key::new(5, X86_64_V4), Bounds::new(256, 96)), + (Key::new(5, Neon), Bounds::new(64, 64)), (Key::new(6, Scalar), Bounds::new(64, 64)), (Key::new(6, X86_64_V3), Bounds::new(256, 96)), (Key::new(6, X86_64_V4), Bounds::new(256, 96)), + (Key::new(6, Neon), Bounds::new(64, 64)), (Key::new(7, Scalar), Bounds::new(64, 64)), (Key::new(7, X86_64_V3), Bounds::new(256, 96)), (Key::new(7, X86_64_V4), Bounds::new(256, 96)), + (Key::new(7, Neon), Bounds::new(64, 64)), (Key::new(8, Scalar), Bounds::new(64, 64)), (Key::new(8, X86_64_V3), Bounds::new(256, 96)), (Key::new(8, X86_64_V4), Bounds::new(256, 96)), + (Key::new(8, Neon), Bounds::new(64, 64)), ] .into_iter() .collect() @@ -2175,6 +2183,7 @@ mod tests { X86_64_V3, #[expect(non_camel_case_types)] X86_64_V4, + Neon, } #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -2257,6 +2266,19 @@ mod tests { &mut rng, ); } + + #[cfg(target_arch = "aarch64")] + if let Some(arch) = 
diskann_wide::arch::aarch64::Neon::new_checked() { + let max_dim = BITSLICE_TEST_BOUNDS[&Key::new($nbits, ArchKey::Neon)].get(); + test_bitslice_distances::<$nbits, _>( + max_dim, + TRIALS_PER_DIM, + &|x, y| arch.run2(SquaredL2, x, y), + &|x, y| arch.run2(InnerProduct, x, y), + "neon", + &mut rng, + ); + } } }; } @@ -2453,6 +2475,18 @@ mod tests { &mut rng, ); } + + // Architecture Specific. + #[cfg(target_arch = "aarch64")] + if let Some(arch) = diskann_wide::arch::aarch64::Neon::new_checked() { + test_bit_transpose_distances( + MAX_DIM, + TRIALS_PER_DIM, + &|x, y| arch.run2(InnerProduct, x, y), + "neon", + &mut rng, + ); + } } ////////// @@ -2571,7 +2605,7 @@ mod tests { } #[cfg(target_arch = "x86_64")] - if let Some(arch) = diskann_wide::arch::x86_64::V4::new_checked() { + if let Some(arch) = diskann_wide::arch::x86_64::V4::new_checked_miri() { test_full_distances::<$nbits>( MAX_DIM, TRIALS_PER_DIM, @@ -2580,6 +2614,17 @@ mod tests { &mut rng, ); } + + #[cfg(target_arch = "aarch64")] + if let Some(arch) = diskann_wide::arch::aarch64::Neon::new_checked() { + test_full_distances::<$nbits>( + MAX_DIM, + TRIALS_PER_DIM, + &|x, y| arch.run2(InnerProduct, x, y), + "neon", + &mut rng, + ); + } } }; } diff --git a/diskann-wide/src/arch/aarch64/algorithms/load_first.rs b/diskann-wide/src/arch/aarch64/algorithms/load_first.rs index b679e28cb..ebe41c350 100644 --- a/diskann-wide/src/arch/aarch64/algorithms/load_first.rs +++ b/diskann-wide/src/arch/aarch64/algorithms/load_first.rs @@ -39,6 +39,18 @@ pub(in crate::arch::aarch64) unsafe fn i8x8(_: Neon, ptr: *const i8, first: usiz unsafe { vreinterpret_s8_u8(vcreate_u8(load_first_of_8_bytes(ptr.cast::(), first))) } } +/// Load the first `first` elements from `ptr` into a `uint16x4_t` register. +/// +/// # Safety +/// +/// The caller must ensure `[ptr, ptr + first)` is readable. The presence of `Neon` +/// enables the use of "neon" intrinsics. +#[inline(always)] +pub(in crate::arch::aarch64) unsafe fn u16x4(_: Neon, ptr: *const u16, first: usize) -> uint16x4_t { + // SAFETY: Pointer access inherited from caller. `Neon` enables "neon" intrinsics. + unsafe { vcreate_u16(load_first_of_8_bytes(ptr.cast::(), 2 * first)) } +} + /// Load the first `first` elements from `ptr` into a `float32x2_t` register. 
/// /// # Safety diff --git a/diskann-wide/src/arch/aarch64/double.rs b/diskann-wide/src/arch/aarch64/double.rs index 5880fbc1a..38186669d 100644 --- a/diskann-wide/src/arch/aarch64/double.rs +++ b/diskann-wide/src/arch/aarch64/double.rs @@ -165,7 +165,7 @@ impl crate::SIMDCast for i32x8 { #[cfg(test)] mod tests { use super::*; - use crate::{arch::aarch64::Neon, reference::ReferenceScalarOps, test_utils}; + use crate::{arch::aarch64::test_neon, reference::ReferenceScalarOps, test_utils}; // Run a standard set of: // - Load @@ -176,25 +176,31 @@ mod tests { ($type:ident, $scalar:ty, $lanes:literal) => { #[test] fn miri_test_load() { - test_utils::test_load_simd::<$scalar, $lanes, $type>(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::test_load_simd::<$scalar, $lanes, $type>(arch); + } } #[test] fn miri_test_store() { - test_utils::test_store_simd::<$scalar, $lanes, $type>(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::test_store_simd::<$scalar, $lanes, $type>(arch); + } } #[test] fn test_constructors() { - test_utils::ops::test_splat::<$scalar, $lanes, $type>(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::ops::test_splat::<$scalar, $lanes, $type>(arch); + } } - test_utils::ops::test_add!($type, 0x1c08175714ae637e, Neon::new_checked()); - test_utils::ops::test_sub!($type, 0x3746ddcb006b7b4c, Neon::new_checked()); - test_utils::ops::test_mul!($type, 0xde99e62aaea3f38a, Neon::new_checked()); - test_utils::ops::test_fma!($type, 0x2e301b7e12090d5c, Neon::new_checked()); + test_utils::ops::test_add!($type, 0x1c08175714ae637e, test_neon()); + test_utils::ops::test_sub!($type, 0x3746ddcb006b7b4c, test_neon()); + test_utils::ops::test_mul!($type, 0xde99e62aaea3f38a, test_neon()); + test_utils::ops::test_fma!($type, 0x2e301b7e12090d5c, test_neon()); - test_utils::ops::test_cmp!($type, 0x90a59e23ad545de1, Neon::new_checked()); + test_utils::ops::test_cmp!($type, 0x90a59e23ad545de1, test_neon()); }; } @@ -202,15 +208,15 @@ mod tests { mod test_f32x8 { use super::*; standard_tests!(f32x8, f32, 8); - test_utils::ops::test_sumtree!(f32x8, 0x90a59e23ad545de1, Neon::new_checked()); - test_utils::ops::test_splitjoin!(f32x8 => f32x4, 0x2e301b7e12090d5c, Neon::new_checked()); + test_utils::ops::test_sumtree!(f32x8, 0x90a59e23ad545de1, test_neon()); + test_utils::ops::test_splitjoin!(f32x8 => f32x4, 0x2e301b7e12090d5c, test_neon()); } mod test_f32x16 { use super::*; standard_tests!(f32x16, f32, 16); - test_utils::ops::test_sumtree!(f32x16, 0x90a59e23ad545de1, Neon::new_checked()); - test_utils::ops::test_splitjoin!(f32x16 => f32x8, 0x2e301b7e12090d5c, Neon::new_checked()); + test_utils::ops::test_sumtree!(f32x16, 0x90a59e23ad545de1, test_neon()); + test_utils::ops::test_splitjoin!(f32x16 => f32x8, 0x2e301b7e12090d5c, test_neon()); } // u8s @@ -219,7 +225,7 @@ mod tests { standard_tests!(u8x32, u8, 32); // Bit ops - test_utils::ops::test_bitops!(u8x32, 0xd62d8de09f82ed4e, Neon::new_checked()); + test_utils::ops::test_bitops!(u8x32, 0xd62d8de09f82ed4e, test_neon()); } mod test_u8x64 { @@ -227,7 +233,7 @@ mod tests { standard_tests!(u8x64, u8, 64); // Bit ops - test_utils::ops::test_bitops!(u8x64, 0xd62d8de09f82ed4e, Neon::new_checked()); + test_utils::ops::test_bitops!(u8x64, 0xd62d8de09f82ed4e, test_neon()); } // u32s @@ -236,10 +242,10 @@ mod tests { standard_tests!(u32x8, u32, 8); // Bit ops - test_utils::ops::test_bitops!(u32x8, 0xd62d8de09f82ed4e, Neon::new_checked()); + test_utils::ops::test_bitops!(u32x8, 
0xd62d8de09f82ed4e, test_neon()); // Reductions - test_utils::ops::test_sumtree!(u32x8, 0x90a59e23ad545de1, Neon::new_checked()); + test_utils::ops::test_sumtree!(u32x8, 0x90a59e23ad545de1, test_neon()); } mod test_u32x16 { @@ -247,10 +253,10 @@ mod tests { standard_tests!(u32x16, u32, 16); // Bit ops - test_utils::ops::test_bitops!(u32x16, 0xd62d8de09f82ed4e, Neon::new_checked()); + test_utils::ops::test_bitops!(u32x16, 0xd62d8de09f82ed4e, test_neon()); // Reductions - test_utils::ops::test_sumtree!(u32x16, 0x90a59e23ad545de1, Neon::new_checked()); + test_utils::ops::test_sumtree!(u32x16, 0x90a59e23ad545de1, test_neon()); } // u64s @@ -259,7 +265,7 @@ mod tests { standard_tests!(u64x4, u64, 4); // Bit ops - test_utils::ops::test_bitops!(u64x4, 0xc4491a44af4aa58e, Neon::new_checked()); + test_utils::ops::test_bitops!(u64x4, 0xc4491a44af4aa58e, test_neon()); } // i8s @@ -268,7 +274,7 @@ mod tests { standard_tests!(i8x32, i8, 32); // Bit ops - test_utils::ops::test_bitops!(i8x32, 0xd62d8de09f82ed4e, Neon::new_checked()); + test_utils::ops::test_bitops!(i8x32, 0xd62d8de09f82ed4e, test_neon()); } mod test_i8x64 { @@ -276,7 +282,7 @@ mod tests { standard_tests!(i8x64, i8, 64); // Bit ops - test_utils::ops::test_bitops!(i8x64, 0xd62d8de09f82ed4e, Neon::new_checked()); + test_utils::ops::test_bitops!(i8x64, 0xd62d8de09f82ed4e, test_neon()); } // i16s @@ -285,7 +291,7 @@ mod tests { standard_tests!(i16x16, i16, 16); // Bit ops - test_utils::ops::test_bitops!(i16x16, 0x9167644fc4ad5cfa, Neon::new_checked()); + test_utils::ops::test_bitops!(i16x16, 0x9167644fc4ad5cfa, test_neon()); } mod test_i16x32 { @@ -293,7 +299,7 @@ mod tests { standard_tests!(i16x32, i16, 32); // Bit ops - test_utils::ops::test_bitops!(i16x32, 0x9167644fc4ad5cfa, Neon::new_checked()); + test_utils::ops::test_bitops!(i16x32, 0x9167644fc4ad5cfa, test_neon()); } // i32s @@ -302,29 +308,29 @@ mod tests { standard_tests!(i32x8, i32, 8); // Bit ops - test_utils::ops::test_bitops!(i32x8, 0xc4491a44af4aa58e, Neon::new_checked()); + test_utils::ops::test_bitops!(i32x8, 0xc4491a44af4aa58e, test_neon()); // Dot Products test_utils::dot_product::test_dot_product!( (i16x16, i16x16) => i32x8, 0x145f89b446c03ff1, - Neon::new_checked() + test_neon() ); test_utils::dot_product::test_dot_product!( (u8x32, i8x32) => i32x8, 0x145f89b446c03ff1, - Neon::new_checked() + test_neon() ); test_utils::dot_product::test_dot_product!( (i8x32, u8x32) => i32x8, 0x145f89b446c03ff1, - Neon::new_checked() + test_neon() ); // Reductions - test_utils::ops::test_sumtree!(i32x8, 0x90a59e23ad545de1, Neon::new_checked()); + test_utils::ops::test_sumtree!(i32x8, 0x90a59e23ad545de1, test_neon()); } mod test_i32x16 { @@ -332,42 +338,42 @@ mod tests { standard_tests!(i32x16, i32, 16); // Bit ops - test_utils::ops::test_bitops!(i32x16, 0xc4491a44af4aa58e, Neon::new_checked()); + test_utils::ops::test_bitops!(i32x16, 0xc4491a44af4aa58e, test_neon()); // Dot Products test_utils::dot_product::test_dot_product!( (i16x32, i16x32) => i32x16, 0x145f89b446c03ff1, - Neon::new_checked() + test_neon() ); test_utils::dot_product::test_dot_product!( (u8x64, i8x64) => i32x16, 0x145f89b446c03ff1, - Neon::new_checked() + test_neon() ); test_utils::dot_product::test_dot_product!( (i8x64, u8x64) => i32x16, 0x145f89b446c03ff1, - Neon::new_checked() + test_neon() ); // Reductions - test_utils::ops::test_sumtree!(i32x16, 0x90a59e23ad545de1, Neon::new_checked()); + test_utils::ops::test_sumtree!(i32x16, 0x90a59e23ad545de1, test_neon()); } // Conversions - 
test_utils::ops::test_lossless_convert!(f16x8 => f32x8, 0x84c1c6f05b169a20, Neon::new_checked()); - test_utils::ops::test_lossless_convert!(f16x16 => f32x16, 0x84c1c6f05b169a20, Neon::new_checked()); + test_utils::ops::test_lossless_convert!(f16x8 => f32x8, 0x84c1c6f05b169a20, test_neon()); + test_utils::ops::test_lossless_convert!(f16x16 => f32x16, 0x84c1c6f05b169a20, test_neon()); - test_utils::ops::test_lossless_convert!(u8x16 => i16x16, 0x84c1c6f05b169a20, Neon::new_checked()); - test_utils::ops::test_lossless_convert!(i8x16 => i16x16, 0x84c1c6f05b169a20, Neon::new_checked()); + test_utils::ops::test_lossless_convert!(u8x16 => i16x16, 0x84c1c6f05b169a20, test_neon()); + test_utils::ops::test_lossless_convert!(i8x16 => i16x16, 0x84c1c6f05b169a20, test_neon()); - test_utils::ops::test_cast!(f16x8 => f32x8, 0xba8fe343fc9dbeff, Neon::new_checked()); - test_utils::ops::test_cast!(f16x16 => f32x16, 0xba8fe343fc9dbeff, Neon::new_checked()); - test_utils::ops::test_cast!(f32x8 => f16x8, 0xba8fe343fc9dbeff, Neon::new_checked()); - test_utils::ops::test_cast!(f32x16 => f16x16, 0xba8fe343fc9dbeff, Neon::new_checked()); + test_utils::ops::test_cast!(f16x8 => f32x8, 0xba8fe343fc9dbeff, test_neon()); + test_utils::ops::test_cast!(f16x16 => f32x16, 0xba8fe343fc9dbeff, test_neon()); + test_utils::ops::test_cast!(f32x8 => f16x8, 0xba8fe343fc9dbeff, test_neon()); + test_utils::ops::test_cast!(f32x16 => f16x16, 0xba8fe343fc9dbeff, test_neon()); - test_utils::ops::test_cast!(i32x8 => f32x8, 0xba8fe343fc9dbeff, Neon::new_checked()); + test_utils::ops::test_cast!(i32x8 => f32x8, 0xba8fe343fc9dbeff, test_neon()); } diff --git a/diskann-wide/src/arch/aarch64/f16x4_.rs b/diskann-wide/src/arch/aarch64/f16x4_.rs index b40814695..f9673a98b 100644 --- a/diskann-wide/src/arch/aarch64/f16x4_.rs +++ b/diskann-wide/src/arch/aarch64/f16x4_.rs @@ -14,7 +14,7 @@ use half::f16; // AArch64 masks use super::{ - Neon, + Neon, algorithms, macros::{self, AArchLoadStore, AArchSplat}, masks::mask16x4, }; @@ -59,9 +59,9 @@ impl AArchLoadStore for f16x4 { #[inline(always)] unsafe fn load_simd_first(arch: Neon, ptr: *const f16, first: usize) -> Self { - // SAFETY: Pointer access safety inhereted from the caller. - let e = unsafe { Emulated::::load_simd_first(Scalar, ptr, first) }; - Self::from_array(arch, e.to_array()) + // SAFETY: f16 and u16 share the same 2-byte representation. Pointer access + // inherited from caller. 
+ Self(unsafe { algorithms::load_first::u16x4(arch, ptr.cast::(), first) }) } #[inline(always)] @@ -93,21 +93,27 @@ impl AArchLoadStore for f16x4 { #[cfg(test)] mod tests { use super::*; - use crate::test_utils; + use crate::{arch::aarch64::test_neon, test_utils}; #[test] fn miri_test_load() { - test_utils::test_load_simd::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::test_load_simd::(arch); + } } #[test] fn miri_test_store() { - test_utils::test_store_simd::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::test_store_simd::(arch); + } } // constructors #[test] fn test_constructors() { - test_utils::ops::test_splat::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::ops::test_splat::(arch); + } } } diff --git a/diskann-wide/src/arch/aarch64/f16x8_.rs b/diskann-wide/src/arch/aarch64/f16x8_.rs index a688b4060..40e119824 100644 --- a/diskann-wide/src/arch/aarch64/f16x8_.rs +++ b/diskann-wide/src/arch/aarch64/f16x8_.rs @@ -17,6 +17,7 @@ use super::{ Neon, f16x4, f32x8, macros::{self, AArchLoadStore, AArchSplat}, masks::mask16x8, + u16x8, }; // AArch64 intrinsics @@ -61,9 +62,11 @@ impl AArchLoadStore for f16x8 { #[inline(always)] unsafe fn load_simd_first(arch: Neon, ptr: *const f16, first: usize) -> Self { - // SAFETY: Pointer access safety inhereted from the caller. - let e = unsafe { Emulated::::load_simd_first(Scalar, ptr, first) }; - Self::from_array(arch, e.to_array()) + // SAFETY: f16 and u16 share the same 2-byte representation. Pointer access + // inherited from caller. + Self(unsafe { + ::load_simd_first(arch, ptr.cast::(), first).0 + }) } #[inline(always)] @@ -108,26 +111,32 @@ impl crate::SIMDCast for f16x8 { #[cfg(test)] mod tests { use super::*; - use crate::test_utils; + use crate::{arch::aarch64::test_neon, test_utils}; #[test] fn miri_test_load() { - test_utils::test_load_simd::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::test_load_simd::(arch); + } } #[test] fn miri_test_store() { - test_utils::test_store_simd::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::test_store_simd::(arch); + } } // constructors #[test] fn test_constructors() { - test_utils::ops::test_splat::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::ops::test_splat::(arch); + } } - test_utils::ops::test_splitjoin!(f16x8 => f16x4, 0xa4d00a4d04293967, Neon::new_checked()); + test_utils::ops::test_splitjoin!(f16x8 => f16x4, 0xa4d00a4d04293967, test_neon()); // Conversions - test_utils::ops::test_cast!(f16x8 => f32x8, 0x37314659b022466a, Neon::new_checked()); + test_utils::ops::test_cast!(f16x8 => f32x8, 0x37314659b022466a, test_neon()); } diff --git a/diskann-wide/src/arch/aarch64/f32x2_.rs b/diskann-wide/src/arch/aarch64/f32x2_.rs index 2f9abbc95..4ddd27f3e 100644 --- a/diskann-wide/src/arch/aarch64/f32x2_.rs +++ b/diskann-wide/src/arch/aarch64/f32x2_.rs @@ -47,7 +47,7 @@ impl SIMDSumTree for f32x2 { #[inline(always)] fn sum_tree(self) -> f32 { if cfg!(miri) { - self.sum_tree() + self.emulated().sum_tree() } else { // SAFETY: The presence of `Neon` enables the use of "neon" intrinsics. 
unsafe { vaddv_f32(self.to_underlying()) } @@ -62,31 +62,37 @@ impl SIMDSumTree for f32x2 { #[cfg(test)] mod tests { use super::*; - use crate::{reference::ReferenceScalarOps, test_utils}; + use crate::{arch::aarch64::test_neon, reference::ReferenceScalarOps, test_utils}; #[test] fn miri_test_load() { - test_utils::test_load_simd::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::test_load_simd::(arch); + } } #[test] fn miri_test_store() { - test_utils::test_store_simd::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::test_store_simd::(arch); + } } // constructors #[test] fn test_constructors() { - test_utils::ops::test_splat::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::ops::test_splat::(arch); + } } // Ops - test_utils::ops::test_add!(f32x2, 0xcd7a8fea9a3fb727, Neon::new_checked()); - test_utils::ops::test_sub!(f32x2, 0x3f6562c94c923238, Neon::new_checked()); - test_utils::ops::test_mul!(f32x2, 0x07e48666c0fc564c, Neon::new_checked()); - test_utils::ops::test_fma!(f32x2, 0xcfde9d031302cf2c, Neon::new_checked()); + test_utils::ops::test_add!(f32x2, 0xcd7a8fea9a3fb727, test_neon()); + test_utils::ops::test_sub!(f32x2, 0x3f6562c94c923238, test_neon()); + test_utils::ops::test_mul!(f32x2, 0x07e48666c0fc564c, test_neon()); + test_utils::ops::test_fma!(f32x2, 0xcfde9d031302cf2c, test_neon()); - test_utils::ops::test_cmp!(f32x2, 0xc4f468b224622326, Neon::new_checked()); + test_utils::ops::test_cmp!(f32x2, 0xc4f468b224622326, test_neon()); - test_utils::ops::test_sumtree!(f32x2, 0x828bd890a470dc4d, Neon::new_checked()); + test_utils::ops::test_sumtree!(f32x2, 0x828bd890a470dc4d, test_neon()); } diff --git a/diskann-wide/src/arch/aarch64/f32x4_.rs b/diskann-wide/src/arch/aarch64/f32x4_.rs index 43bd680b8..0082eed1f 100644 --- a/diskann-wide/src/arch/aarch64/f32x4_.rs +++ b/diskann-wide/src/arch/aarch64/f32x4_.rs @@ -165,40 +165,46 @@ impl crate::SIMDCast for f32x4 { #[cfg(test)] mod tests { use super::*; - use crate::{reference::ReferenceScalarOps, test_utils}; + use crate::{arch::aarch64::test_neon, reference::ReferenceScalarOps, test_utils}; #[test] fn miri_test_load() { - test_utils::test_load_simd::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::test_load_simd::(arch); + } } #[test] fn miri_test_store() { - test_utils::test_store_simd::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::test_store_simd::(arch); + } } // constructors #[test] fn test_constructors() { - test_utils::ops::test_splat::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::ops::test_splat::(arch); + } } // Ops - test_utils::ops::test_add!(f32x4, 0xcd7a8fea9a3fb727, Neon::new_checked()); - test_utils::ops::test_sub!(f32x4, 0x3f6562c94c923238, Neon::new_checked()); - test_utils::ops::test_mul!(f32x4, 0x07e48666c0fc564c, Neon::new_checked()); - test_utils::ops::test_fma!(f32x4, 0xcfde9d031302cf2c, Neon::new_checked()); - test_utils::ops::test_abs!(f32x4, 0xb8f702ba85375041, Neon::new_checked()); - test_utils::ops::test_minmax!(f32x4, 0x6d7fc8ed6d852187, Neon::new_checked()); - test_utils::ops::test_splitjoin!(f32x4 => f32x2, 0xa4d00a4d04293967, Neon::new_checked()); + test_utils::ops::test_add!(f32x4, 0xcd7a8fea9a3fb727, test_neon()); + test_utils::ops::test_sub!(f32x4, 0x3f6562c94c923238, test_neon()); + test_utils::ops::test_mul!(f32x4, 0x07e48666c0fc564c, test_neon()); + test_utils::ops::test_fma!(f32x4, 0xcfde9d031302cf2c, 
test_neon()); + test_utils::ops::test_abs!(f32x4, 0xb8f702ba85375041, test_neon()); + test_utils::ops::test_minmax!(f32x4, 0x6d7fc8ed6d852187, test_neon()); + test_utils::ops::test_splitjoin!(f32x4 => f32x2, 0xa4d00a4d04293967, test_neon()); - test_utils::ops::test_cmp!(f32x4, 0xc4f468b224622326, Neon::new_checked()); - test_utils::ops::test_select!(f32x4, 0xef24013b8578637c, Neon::new_checked()); + test_utils::ops::test_cmp!(f32x4, 0xc4f468b224622326, test_neon()); + test_utils::ops::test_select!(f32x4, 0xef24013b8578637c, test_neon()); - test_utils::ops::test_sumtree!(f32x4, 0x828bd890a470dc4d, Neon::new_checked()); + test_utils::ops::test_sumtree!(f32x4, 0x828bd890a470dc4d, test_neon()); // Conversions - test_utils::ops::test_lossless_convert!(f16x4 => f32x4, 0xecba3008eae54ce7, Neon::new_checked()); + test_utils::ops::test_lossless_convert!(f16x4 => f32x4, 0xecba3008eae54ce7, test_neon()); - test_utils::ops::test_cast!(f32x4 => f16x4, 0xba8fe343fc9dbeff, Neon::new_checked()); + test_utils::ops::test_cast!(f32x4 => f16x4, 0xba8fe343fc9dbeff, test_neon()); } diff --git a/diskann-wide/src/arch/aarch64/i16x8_.rs b/diskann-wide/src/arch/aarch64/i16x8_.rs index 10b44ef4d..0321ead00 100644 --- a/diskann-wide/src/arch/aarch64/i16x8_.rs +++ b/diskann-wide/src/arch/aarch64/i16x8_.rs @@ -76,37 +76,43 @@ helpers::unsafe_map_conversion!(u8x8, i16x8, (vreinterpretq_s16_u16, vmovl_u8), #[cfg(test)] mod tests { use super::*; - use crate::{reference::ReferenceScalarOps, test_utils}; + use crate::{arch::aarch64::test_neon, reference::ReferenceScalarOps, test_utils}; #[test] fn miri_test_load() { - test_utils::test_load_simd::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::test_load_simd::(arch); + } } #[test] fn miri_test_store() { - test_utils::test_store_simd::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::test_store_simd::(arch); + } } // constructors #[test] fn test_constructors() { - test_utils::ops::test_splat::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::ops::test_splat::(arch); + } } // Ops - test_utils::ops::test_add!(i16x8, 0x3017fd73c99cc633, Neon::new_checked()); - test_utils::ops::test_sub!(i16x8, 0xfc627f10b5f8db8a, Neon::new_checked()); - test_utils::ops::test_mul!(i16x8, 0x0f4caa80eceaa523, Neon::new_checked()); - test_utils::ops::test_fma!(i16x8, 0xb8f702ba85375041, Neon::new_checked()); - test_utils::ops::test_abs!(i16x8, 0xb8f702ba85375041, Neon::new_checked()); + test_utils::ops::test_add!(i16x8, 0x3017fd73c99cc633, test_neon()); + test_utils::ops::test_sub!(i16x8, 0xfc627f10b5f8db8a, test_neon()); + test_utils::ops::test_mul!(i16x8, 0x0f4caa80eceaa523, test_neon()); + test_utils::ops::test_fma!(i16x8, 0xb8f702ba85375041, test_neon()); + test_utils::ops::test_abs!(i16x8, 0xb8f702ba85375041, test_neon()); - test_utils::ops::test_cmp!(i16x8, 0x941757bd5cc641a1, Neon::new_checked()); + test_utils::ops::test_cmp!(i16x8, 0x941757bd5cc641a1, test_neon()); // Bit ops - test_utils::ops::test_bitops!(i16x8, 0xd62d8de09f82ed4e, Neon::new_checked()); + test_utils::ops::test_bitops!(i16x8, 0xd62d8de09f82ed4e, test_neon()); // Conversion - test_utils::ops::test_lossless_convert!(i8x8 => i16x8, 0x79458ca52356242e, Neon::new_checked()); - test_utils::ops::test_lossless_convert!(u8x8 => i16x8, 0xa9a57c5c541ce360, Neon::new_checked()); + test_utils::ops::test_lossless_convert!(i8x8 => i16x8, 0x79458ca52356242e, test_neon()); + test_utils::ops::test_lossless_convert!(u8x8 => i16x8, 
0xa9a57c5c541ce360, test_neon()); } diff --git a/diskann-wide/src/arch/aarch64/i32x4_.rs b/diskann-wide/src/arch/aarch64/i32x4_.rs index d7f6f09c4..dd0cc83b6 100644 --- a/diskann-wide/src/arch/aarch64/i32x4_.rs +++ b/diskann-wide/src/arch/aarch64/i32x4_.rs @@ -202,64 +202,70 @@ helpers::unsafe_map_cast!( #[cfg(test)] mod tests { use super::*; - use crate::{reference::ReferenceScalarOps, test_utils}; + use crate::{arch::aarch64::test_neon, reference::ReferenceScalarOps, test_utils}; #[test] fn miri_test_load() { - test_utils::test_load_simd::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::test_load_simd::(arch); + } } #[test] fn miri_test_store() { - test_utils::test_store_simd::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::test_store_simd::(arch); + } } // constructors #[test] fn test_constructors() { - test_utils::ops::test_splat::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::ops::test_splat::(arch); + } } // Ops - test_utils::ops::test_add!(i32x4, 0x3017fd73c99cc633, Neon::new_checked()); - test_utils::ops::test_sub!(i32x4, 0xfc627f10b5f8db8a, Neon::new_checked()); - test_utils::ops::test_mul!(i32x4, 0x0f4caa80eceaa523, Neon::new_checked()); - test_utils::ops::test_fma!(i32x4, 0xb8f702ba85375041, Neon::new_checked()); - test_utils::ops::test_abs!(i32x4, 0xb8f702ba85375041, Neon::new_checked()); + test_utils::ops::test_add!(i32x4, 0x3017fd73c99cc633, test_neon()); + test_utils::ops::test_sub!(i32x4, 0xfc627f10b5f8db8a, test_neon()); + test_utils::ops::test_mul!(i32x4, 0x0f4caa80eceaa523, test_neon()); + test_utils::ops::test_fma!(i32x4, 0xb8f702ba85375041, test_neon()); + test_utils::ops::test_abs!(i32x4, 0xb8f702ba85375041, test_neon()); - test_utils::ops::test_cmp!(i32x4, 0x941757bd5cc641a1, Neon::new_checked()); + test_utils::ops::test_cmp!(i32x4, 0x941757bd5cc641a1, test_neon()); // Bit ops - test_utils::ops::test_bitops!(i32x4, 0xd62d8de09f82ed4e, Neon::new_checked()); + test_utils::ops::test_bitops!(i32x4, 0xd62d8de09f82ed4e, test_neon()); // Dot Products test_utils::dot_product::test_dot_product!( (i16x8, i16x8) => i32x4, 0x145f89b446c03ff1, - Neon::new_checked() + test_neon() ); test_utils::dot_product::test_dot_product!( (u8x16, i8x16) => i32x4, 0x145f89b446c03ff1, - Neon::new_checked() + test_neon() ); test_utils::dot_product::test_dot_product!( (i8x16, u8x16) => i32x4, 0x145f89b446c03ff1, - Neon::new_checked() + test_neon() ); test_utils::dot_product::test_dot_product!( (i8x16, i8x16) => i32x4, 0x145f89b446c03ff1, - Neon::new_checked() + test_neon() ); // Reductions - test_utils::ops::test_sumtree!(i32x4, 0xb9ac82ab23a855da, Neon::new_checked()); + test_utils::ops::test_sumtree!(i32x4, 0xb9ac82ab23a855da, test_neon()); // Conversions - test_utils::ops::test_cast!(i32x4 => f32x4, 0xba8fe343fc9dbeff, Neon::new_checked()); + test_utils::ops::test_cast!(i32x4 => f32x4, 0xba8fe343fc9dbeff, test_neon()); } diff --git a/diskann-wide/src/arch/aarch64/i64x2_.rs b/diskann-wide/src/arch/aarch64/i64x2_.rs index b37a24f1a..bb5bf3e4c 100644 --- a/diskann-wide/src/arch/aarch64/i64x2_.rs +++ b/diskann-wide/src/arch/aarch64/i64x2_.rs @@ -91,33 +91,39 @@ macros::aarch64_define_bitops!( #[cfg(test)] mod tests { use super::*; - use crate::{reference::ReferenceScalarOps, test_utils}; + use crate::{arch::aarch64::test_neon, reference::ReferenceScalarOps, test_utils}; #[test] fn miri_test_load() { - test_utils::test_load_simd::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + 
test_utils::test_load_simd::(arch); + } } #[test] fn miri_test_store() { - test_utils::test_store_simd::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::test_store_simd::(arch); + } } // constructors #[test] fn test_constructors() { - test_utils::ops::test_splat::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::ops::test_splat::(arch); + } } // Binary Ops - test_utils::ops::test_add!(i64x2, 0x8d7bf28b1c6e2545, Neon::new_checked()); - test_utils::ops::test_sub!(i64x2, 0x4a1c644a1a910bed, Neon::new_checked()); - test_utils::ops::test_mul!(i64x2, 0xf42ee707a808fd10, Neon::new_checked()); - test_utils::ops::test_fma!(i64x2, 0x28540d9936a9e803, Neon::new_checked()); - test_utils::ops::test_abs!(i64x2, 0xb8f702ba85375041, Neon::new_checked()); + test_utils::ops::test_add!(i64x2, 0x8d7bf28b1c6e2545, test_neon()); + test_utils::ops::test_sub!(i64x2, 0x4a1c644a1a910bed, test_neon()); + test_utils::ops::test_mul!(i64x2, 0xf42ee707a808fd10, test_neon()); + test_utils::ops::test_fma!(i64x2, 0x28540d9936a9e803, test_neon()); + test_utils::ops::test_abs!(i64x2, 0xb8f702ba85375041, test_neon()); - test_utils::ops::test_cmp!(i64x2, 0xfae27072c6b70885, Neon::new_checked()); + test_utils::ops::test_cmp!(i64x2, 0xfae27072c6b70885, test_neon()); // Bit ops - test_utils::ops::test_bitops!(i64x2, 0xbe927713ea310164, Neon::new_checked()); + test_utils::ops::test_bitops!(i64x2, 0xbe927713ea310164, test_neon()); } diff --git a/diskann-wide/src/arch/aarch64/i8x16_.rs b/diskann-wide/src/arch/aarch64/i8x16_.rs index 017d9f80a..51f401c28 100644 --- a/diskann-wide/src/arch/aarch64/i8x16_.rs +++ b/diskann-wide/src/arch/aarch64/i8x16_.rs @@ -66,34 +66,40 @@ macros::aarch64_define_bitops!( #[cfg(test)] mod tests { use super::*; - use crate::{reference::ReferenceScalarOps, test_utils}; + use crate::{arch::aarch64::test_neon, reference::ReferenceScalarOps, test_utils}; #[test] fn miri_test_load() { - test_utils::test_load_simd::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::test_load_simd::(arch); + } } #[test] fn miri_test_store() { - test_utils::test_store_simd::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::test_store_simd::(arch); + } } // constructors #[test] fn test_constructors() { - test_utils::ops::test_splat::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::ops::test_splat::(arch); + } } // Ops - test_utils::ops::test_add!(i8x16, 0x3017fd73c99cc633, Neon::new_checked()); - test_utils::ops::test_sub!(i8x16, 0xfc627f10b5f8db8a, Neon::new_checked()); - test_utils::ops::test_mul!(i8x16, 0x0f4caa80eceaa523, Neon::new_checked()); - test_utils::ops::test_fma!(i8x16, 0xb8f702ba85375041, Neon::new_checked()); - test_utils::ops::test_abs!(i8x16, 0xb8f702ba85375041, Neon::new_checked()); - test_utils::ops::test_splitjoin!(i8x16 => i8x8, 0xa4d00a4d04293967, Neon::new_checked()); + test_utils::ops::test_add!(i8x16, 0x3017fd73c99cc633, test_neon()); + test_utils::ops::test_sub!(i8x16, 0xfc627f10b5f8db8a, test_neon()); + test_utils::ops::test_mul!(i8x16, 0x0f4caa80eceaa523, test_neon()); + test_utils::ops::test_fma!(i8x16, 0xb8f702ba85375041, test_neon()); + test_utils::ops::test_abs!(i8x16, 0xb8f702ba85375041, test_neon()); + test_utils::ops::test_splitjoin!(i8x16 => i8x8, 0xa4d00a4d04293967, test_neon()); - test_utils::ops::test_cmp!(i8x16, 0x941757bd5cc641a1, Neon::new_checked()); + test_utils::ops::test_cmp!(i8x16, 0x941757bd5cc641a1, test_neon()); // Bit ops 
- test_utils::ops::test_bitops!(i8x16, 0xd62d8de09f82ed4e, Neon::new_checked()); + test_utils::ops::test_bitops!(i8x16, 0xd62d8de09f82ed4e, test_neon()); } diff --git a/diskann-wide/src/arch/aarch64/i8x8_.rs b/diskann-wide/src/arch/aarch64/i8x8_.rs index 8cd501f93..a2409afd4 100644 --- a/diskann-wide/src/arch/aarch64/i8x8_.rs +++ b/diskann-wide/src/arch/aarch64/i8x8_.rs @@ -57,33 +57,39 @@ macros::aarch64_define_bitops!( #[cfg(test)] mod tests { use super::*; - use crate::{reference::ReferenceScalarOps, test_utils}; + use crate::{arch::aarch64::test_neon, reference::ReferenceScalarOps, test_utils}; #[test] fn miri_test_load() { - test_utils::test_load_simd::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::test_load_simd::(arch); + } } #[test] fn miri_test_store() { - test_utils::test_store_simd::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::test_store_simd::(arch); + } } // constructors #[test] fn test_constructors() { - test_utils::ops::test_splat::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::ops::test_splat::(arch); + } } // Ops - test_utils::ops::test_add!(i8x8, 0x3017fd73c99cc633, Neon::new_checked()); - test_utils::ops::test_sub!(i8x8, 0xfc627f10b5f8db8a, Neon::new_checked()); - test_utils::ops::test_mul!(i8x8, 0x0f4caa80eceaa523, Neon::new_checked()); - test_utils::ops::test_fma!(i8x8, 0xb8f702ba85375041, Neon::new_checked()); - test_utils::ops::test_abs!(i8x8, 0xb8f702ba85375041, Neon::new_checked()); + test_utils::ops::test_add!(i8x8, 0x3017fd73c99cc633, test_neon()); + test_utils::ops::test_sub!(i8x8, 0xfc627f10b5f8db8a, test_neon()); + test_utils::ops::test_mul!(i8x8, 0x0f4caa80eceaa523, test_neon()); + test_utils::ops::test_fma!(i8x8, 0xb8f702ba85375041, test_neon()); + test_utils::ops::test_abs!(i8x8, 0xb8f702ba85375041, test_neon()); - test_utils::ops::test_cmp!(i8x8, 0x941757bd5cc641a1, Neon::new_checked()); + test_utils::ops::test_cmp!(i8x8, 0x941757bd5cc641a1, test_neon()); // Bit ops - test_utils::ops::test_bitops!(i8x8, 0xd62d8de09f82ed4e, Neon::new_checked()); + test_utils::ops::test_bitops!(i8x8, 0xd62d8de09f82ed4e, test_neon()); } diff --git a/diskann-wide/src/arch/aarch64/masks.rs b/diskann-wide/src/arch/aarch64/masks.rs index d20044eb7..8a3d8b959 100644 --- a/diskann-wide/src/arch/aarch64/masks.rs +++ b/diskann-wide/src/arch/aarch64/masks.rs @@ -17,15 +17,15 @@ //! //! Setting all bits is important because the `select` operations in Neon are bit-wise //! selects, unlike AVX2 where only the most-significant bit is important. -//! -//! The conversion implementation in this file still refer to the uppermost bit when -//! implementing `move_mask`-like functionality. -use crate::{BitMask, FromInt, SIMDMask}; +use std::arch::aarch64::*; -use super::Neon; +use crate::{BitMask, SIMDMask}; -use std::arch::aarch64::*; +#[cfg(not(miri))] +use crate::FromInt; + +use super::Neon; macro_rules! define_mask { ($mask:ident, $repr:ident, $lanes:literal, $arch:ty) => { @@ -617,7 +617,7 @@ impl MaskOps for uint64x2_t { #[cfg(test)] mod tests { use super::*; - use crate::{Const, SupportedLaneCount}; + use crate::{Const, SupportedLaneCount, arch::aarch64::test_neon}; trait MaskTraits: std::fmt::Debug { const SET: Self; @@ -686,7 +686,9 @@ mod tests { M: SIMDMask> + From>, ::Underlying: MaskOps, Array = [T; N]>, { - let arch = Neon::new_checked().unwrap(); + let Some(arch) = test_neon() else { + return; + }; // Test keep-first. 
for i in 0..N + 5 { diff --git a/diskann-wide/src/arch/aarch64/mod.rs b/diskann-wide/src/arch/aarch64/mod.rs index 87c6b4e39..01190cbe3 100644 --- a/diskann-wide/src/arch/aarch64/mod.rs +++ b/diskann-wide/src/arch/aarch64/mod.rs @@ -382,3 +382,37 @@ impl arch::Architecture for Neon { unsafe { arch::hide3(f) } } } + +/////////// +// Tests // +/////////// + +/// Return `Some(Neon)` if the Neon architecture should be used for testing. +/// +/// If the environment variable `WIDE_TEST_MIN_ARCH` is set, this uses the configured +/// architecture with the following mapping: +/// +/// * `all` or `neon`: Run the Neon backend +/// * `scalar`: Skip the Neon backend (returns `None`) +/// +/// If the variable is not set, this defaults to [`Neon::new_checked()`]. +#[cfg(test)] +pub(super) fn test_neon() -> Option { + match crate::get_test_arch() { + Some(arch) => { + if arch == "all" || arch == "neon" { + match Neon::new_checked() { + Some(v) => Some(v), + None => panic!( + "Neon architecture was requested but is not available on the current target" + ), + } + } else if arch == "scalar" { + None + } else { + panic!("Unrecognized test architecture: \"{arch}\""); + } + } + None => Neon::new_checked(), + } +} diff --git a/diskann-wide/src/arch/aarch64/u16x8_.rs b/diskann-wide/src/arch/aarch64/u16x8_.rs index 6e1b1d51f..7a345895a 100644 --- a/diskann-wide/src/arch/aarch64/u16x8_.rs +++ b/diskann-wide/src/arch/aarch64/u16x8_.rs @@ -72,32 +72,38 @@ macros::aarch64_define_bitops!( #[cfg(test)] mod tests { use super::*; - use crate::{reference::ReferenceScalarOps, test_utils}; + use crate::{arch::aarch64::test_neon, reference::ReferenceScalarOps, test_utils}; #[test] fn miri_test_load() { - test_utils::test_load_simd::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::test_load_simd::(arch); + } } #[test] fn miri_test_store() { - test_utils::test_store_simd::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::test_store_simd::(arch); + } } // constructors #[test] fn test_constructors() { - test_utils::ops::test_splat::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::ops::test_splat::(arch); + } } // Ops - test_utils::ops::test_add!(u16x8, 0x3017fd73c99cc633, Neon::new_checked()); - test_utils::ops::test_sub!(u16x8, 0xfc627f10b5f8db8a, Neon::new_checked()); - test_utils::ops::test_mul!(u16x8, 0x0f4caa80eceaa523, Neon::new_checked()); - test_utils::ops::test_fma!(u16x8, 0xb8f702ba85375041, Neon::new_checked()); + test_utils::ops::test_add!(u16x8, 0x3017fd73c99cc633, test_neon()); + test_utils::ops::test_sub!(u16x8, 0xfc627f10b5f8db8a, test_neon()); + test_utils::ops::test_mul!(u16x8, 0x0f4caa80eceaa523, test_neon()); + test_utils::ops::test_fma!(u16x8, 0xb8f702ba85375041, test_neon()); - test_utils::ops::test_cmp!(u16x8, 0x941757bd5cc641a1, Neon::new_checked()); + test_utils::ops::test_cmp!(u16x8, 0x941757bd5cc641a1, test_neon()); // Bit ops - test_utils::ops::test_bitops!(u16x8, 0xd62d8de09f82ed4e, Neon::new_checked()); + test_utils::ops::test_bitops!(u16x8, 0xd62d8de09f82ed4e, test_neon()); } diff --git a/diskann-wide/src/arch/aarch64/u32x4_.rs b/diskann-wide/src/arch/aarch64/u32x4_.rs index e25844288..f980e4fb8 100644 --- a/diskann-wide/src/arch/aarch64/u32x4_.rs +++ b/diskann-wide/src/arch/aarch64/u32x4_.rs @@ -127,42 +127,48 @@ impl SIMDDotProduct for u32x4 { #[cfg(test)] mod tests { use super::*; - use crate::{reference::ReferenceScalarOps, test_utils}; + use crate::{arch::aarch64::test_neon, 
reference::ReferenceScalarOps, test_utils}; #[test] fn miri_test_load() { - test_utils::test_load_simd::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::test_load_simd::(arch); + } } #[test] fn miri_test_store() { - test_utils::test_store_simd::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::test_store_simd::(arch); + } } // constructors #[test] fn test_constructors() { - test_utils::ops::test_splat::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::ops::test_splat::(arch); + } } // Ops - test_utils::ops::test_add!(u32x4, 0x3017fd73c99cc633, Neon::new_checked()); - test_utils::ops::test_sub!(u32x4, 0xfc627f10b5f8db8a, Neon::new_checked()); - test_utils::ops::test_mul!(u32x4, 0x0f4caa80eceaa523, Neon::new_checked()); - test_utils::ops::test_fma!(u32x4, 0xb8f702ba85375041, Neon::new_checked()); + test_utils::ops::test_add!(u32x4, 0x3017fd73c99cc633, test_neon()); + test_utils::ops::test_sub!(u32x4, 0xfc627f10b5f8db8a, test_neon()); + test_utils::ops::test_mul!(u32x4, 0x0f4caa80eceaa523, test_neon()); + test_utils::ops::test_fma!(u32x4, 0xb8f702ba85375041, test_neon()); - test_utils::ops::test_cmp!(u32x4, 0x941757bd5cc641a1, Neon::new_checked()); + test_utils::ops::test_cmp!(u32x4, 0x941757bd5cc641a1, test_neon()); // Dot Product test_utils::dot_product::test_dot_product!( (u8x16, u8x16) => u32x4, 0x145f89b446c03ff1, - Neon::new_checked() + test_neon() ); // Bit ops - test_utils::ops::test_bitops!(u32x4, 0xd62d8de09f82ed4e, Neon::new_checked()); + test_utils::ops::test_bitops!(u32x4, 0xd62d8de09f82ed4e, test_neon()); // Reductions - test_utils::ops::test_sumtree!(u32x4, 0xb9ac82ab23a855da, Neon::new_checked()); + test_utils::ops::test_sumtree!(u32x4, 0xb9ac82ab23a855da, test_neon()); } diff --git a/diskann-wide/src/arch/aarch64/u64x2_.rs b/diskann-wide/src/arch/aarch64/u64x2_.rs index 7d0e45bde..2ac54ac07 100644 --- a/diskann-wide/src/arch/aarch64/u64x2_.rs +++ b/diskann-wide/src/arch/aarch64/u64x2_.rs @@ -102,32 +102,38 @@ macros::aarch64_define_bitops!( #[cfg(test)] mod tests { use super::*; - use crate::{reference::ReferenceScalarOps, test_utils}; + use crate::{arch::aarch64::test_neon, reference::ReferenceScalarOps, test_utils}; #[test] fn miri_test_load() { - test_utils::test_load_simd::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::test_load_simd::(arch); + } } #[test] fn miri_test_store() { - test_utils::test_store_simd::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::test_store_simd::(arch); + } } // constructors #[test] fn test_constructors() { - test_utils::ops::test_splat::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::ops::test_splat::(arch); + } } // Binary Ops - test_utils::ops::test_add!(u64x2, 0x8d7bf28b1c6e2545, Neon::new_checked()); - test_utils::ops::test_sub!(u64x2, 0x4a1c644a1a910bed, Neon::new_checked()); - test_utils::ops::test_mul!(u64x2, 0xf42ee707a808fd10, Neon::new_checked()); - test_utils::ops::test_fma!(u64x2, 0x28540d9936a9e803, Neon::new_checked()); + test_utils::ops::test_add!(u64x2, 0x8d7bf28b1c6e2545, test_neon()); + test_utils::ops::test_sub!(u64x2, 0x4a1c644a1a910bed, test_neon()); + test_utils::ops::test_mul!(u64x2, 0xf42ee707a808fd10, test_neon()); + test_utils::ops::test_fma!(u64x2, 0x28540d9936a9e803, test_neon()); - test_utils::ops::test_cmp!(u64x2, 0xfae27072c6b70885, Neon::new_checked()); + test_utils::ops::test_cmp!(u64x2, 0xfae27072c6b70885, 
test_neon()); // Bit ops - test_utils::ops::test_bitops!(u64x2, 0xbe927713ea310164, Neon::new_checked()); + test_utils::ops::test_bitops!(u64x2, 0xbe927713ea310164, test_neon()); } diff --git a/diskann-wide/src/arch/aarch64/u8x16_.rs b/diskann-wide/src/arch/aarch64/u8x16_.rs index be1a1f8f5..819bb9f6b 100644 --- a/diskann-wide/src/arch/aarch64/u8x16_.rs +++ b/diskann-wide/src/arch/aarch64/u8x16_.rs @@ -68,33 +68,39 @@ macros::aarch64_define_bitops!( #[cfg(test)] mod tests { use super::*; - use crate::{reference::ReferenceScalarOps, test_utils}; + use crate::{arch::aarch64::test_neon, reference::ReferenceScalarOps, test_utils}; #[test] fn miri_test_load() { - test_utils::test_load_simd::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::test_load_simd::(arch); + } } #[test] fn miri_test_store() { - test_utils::test_store_simd::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::test_store_simd::(arch); + } } // constructors #[test] fn test_constructors() { - test_utils::ops::test_splat::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::ops::test_splat::(arch); + } } // Ops - test_utils::ops::test_add!(u8x16, 0x3017fd73c99cc633, Neon::new_checked()); - test_utils::ops::test_sub!(u8x16, 0xfc627f10b5f8db8a, Neon::new_checked()); - test_utils::ops::test_mul!(u8x16, 0x0f4caa80eceaa523, Neon::new_checked()); - test_utils::ops::test_fma!(u8x16, 0xb8f702ba85375041, Neon::new_checked()); - test_utils::ops::test_splitjoin!(u8x16 => u8x8, 0xa4d00a4d04293967, Neon::new_checked()); + test_utils::ops::test_add!(u8x16, 0x3017fd73c99cc633, test_neon()); + test_utils::ops::test_sub!(u8x16, 0xfc627f10b5f8db8a, test_neon()); + test_utils::ops::test_mul!(u8x16, 0x0f4caa80eceaa523, test_neon()); + test_utils::ops::test_fma!(u8x16, 0xb8f702ba85375041, test_neon()); + test_utils::ops::test_splitjoin!(u8x16 => u8x8, 0xa4d00a4d04293967, test_neon()); - test_utils::ops::test_cmp!(u8x16, 0x941757bd5cc641a1, Neon::new_checked()); + test_utils::ops::test_cmp!(u8x16, 0x941757bd5cc641a1, test_neon()); // Bit ops - test_utils::ops::test_bitops!(u8x16, 0xd62d8de09f82ed4e, Neon::new_checked()); + test_utils::ops::test_bitops!(u8x16, 0xd62d8de09f82ed4e, test_neon()); } diff --git a/diskann-wide/src/arch/aarch64/u8x8_.rs b/diskann-wide/src/arch/aarch64/u8x8_.rs index 44812de81..28602e7d6 100644 --- a/diskann-wide/src/arch/aarch64/u8x8_.rs +++ b/diskann-wide/src/arch/aarch64/u8x8_.rs @@ -58,32 +58,38 @@ macros::aarch64_define_bitops!( #[cfg(test)] mod tests { use super::*; - use crate::{reference::ReferenceScalarOps, test_utils}; + use crate::{arch::aarch64::test_neon, reference::ReferenceScalarOps, test_utils}; #[test] fn miri_test_load() { - test_utils::test_load_simd::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::test_load_simd::(arch); + } } #[test] fn miri_test_store() { - test_utils::test_store_simd::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::test_store_simd::(arch); + } } // constructors #[test] fn test_constructors() { - test_utils::ops::test_splat::(Neon::new_checked().unwrap()); + if let Some(arch) = test_neon() { + test_utils::ops::test_splat::(arch); + } } // Ops - test_utils::ops::test_add!(u8x8, 0x3017fd73c99cc633, Neon::new_checked()); - test_utils::ops::test_sub!(u8x8, 0xfc627f10b5f8db8a, Neon::new_checked()); - test_utils::ops::test_mul!(u8x8, 0x0f4caa80eceaa523, Neon::new_checked()); - test_utils::ops::test_fma!(u8x8, 0xb8f702ba85375041, 
+    test_utils::ops::test_add!(u8x8, 0x3017fd73c99cc633, test_neon());
+    test_utils::ops::test_sub!(u8x8, 0xfc627f10b5f8db8a, test_neon());
+    test_utils::ops::test_mul!(u8x8, 0x0f4caa80eceaa523, test_neon());
+    test_utils::ops::test_fma!(u8x8, 0xb8f702ba85375041, test_neon());

-    test_utils::ops::test_cmp!(u8x8, 0x941757bd5cc641a1, Neon::new_checked());
+    test_utils::ops::test_cmp!(u8x8, 0x941757bd5cc641a1, test_neon());

     // Bit ops
-    test_utils::ops::test_bitops!(u8x8, 0xd62d8de09f82ed4e, Neon::new_checked());
+    test_utils::ops::test_bitops!(u8x8, 0xd62d8de09f82ed4e, test_neon());
 }
diff --git a/diskann-wide/src/lib.rs b/diskann-wide/src/lib.rs
index 287284553..65d9f6bc2 100644
--- a/diskann-wide/src/lib.rs
+++ b/diskann-wide/src/lib.rs
@@ -217,10 +217,10 @@ macro_rules! alias {
 //////////////
 // Internal //
 //////////////
-#[cfg(all(test, target_arch = "x86_64"))]
+#[cfg(all(test, any(target_arch = "x86_64", target_arch = "aarch64")))]
 const TEST_MIN_ARCH: &str = "WIDE_TEST_MIN_ARCH";

-#[cfg(all(test, target_arch = "x86_64"))]
+#[cfg(all(test, any(target_arch = "x86_64", target_arch = "aarch64")))]
 fn get_test_arch() -> Option<String> {
     match std::env::var(TEST_MIN_ARCH) {
         Ok(v) => Some(v),

From 076a50119e7e482810fbc437b2dd7ed0073c0e16 Mon Sep 17 00:00:00 2001
From: Mark Hildebrand
Date: Mon, 16 Feb 2026 12:17:22 -0800
Subject: [PATCH 07/10] Fix typo.

---
 diskann-wide/src/arch/aarch64/mod.rs | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/diskann-wide/src/arch/aarch64/mod.rs b/diskann-wide/src/arch/aarch64/mod.rs
index 01190cbe3..9fd6b77cc 100644
--- a/diskann-wide/src/arch/aarch64/mod.rs
+++ b/diskann-wide/src/arch/aarch64/mod.rs
@@ -86,7 +86,7 @@ pub use double::u64x4;
 mod macros;
 mod masks;

-// The ordering is `Scalar < V3 < V4`.
+// The ordering is `Scalar < Neon`.
 #[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
 pub(super) enum LevelInner {
     Scalar,
@@ -346,7 +346,7 @@ impl arch::Architecture for Neon {
         let f: unsafe fn(Self, T0::Of<'_>) -> R = Self::run_function_with_1::;

         // SAFETY: The presence of `self` as an argument attests that it is safe to construct
-        // a `Neon` architecture. Additionally, since `V3` is a `Copy` zero-sized type,
+        // a `Neon` architecture. Additionally, since `Neon` is a `Copy` zero-sized type,
         // it is safe to wink into existence and is ABI compatible with `Hidden`.
         unsafe { arch::hide1(f) }
     }
@@ -361,7 +361,7 @@ impl arch::Architecture for Neon {
         let f: unsafe fn(Self, T0::Of<'_>, T1::Of<'_>) -> R = Self::run_function_with_2::;

         // SAFETY: The presence of `self` as an argument attests that it is safe to construct
-        // a `Neon` architecture. Additionally, since `V3` is a `Copy` zero-sized type,
+        // a `Neon` architecture. Additionally, since `Neon` is a `Copy` zero-sized type,
         // it is safe to wink into existence and is ABI compatible with `Hidden`.
         unsafe { arch::hide2(f) }
     }
@@ -377,7 +377,7 @@ impl arch::Architecture for Neon {
             Self::run_function_with_3::;

         // SAFETY: The presence of `self` as an argument attests that it is safe to construct
-        // A `Neon` architecture. Additionally, since `V3` is a `Copy` zero-sized type,
+        // A `Neon` architecture. Additionally, since `Neon` is a `Copy` zero-sized type,
         // it is safe to wink into existence and is ABI compatible with `Hidden`.
         unsafe { arch::hide3(f) }
     }

From b762f79f33555027a91fd5d05499e8af90aec9c0 Mon Sep 17 00:00:00 2001
From: Mark Hildebrand
Date: Mon, 16 Feb 2026 12:35:42 -0800
Subject: [PATCH 08/10] Here we gooo!

---
 diskann-benchmark-simd/src/lib.rs       | 2 +-
 diskann-wide/src/arch/aarch64/f32x4_.rs | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/diskann-benchmark-simd/src/lib.rs b/diskann-benchmark-simd/src/lib.rs
index 36c8b2adc..4701b04d7 100644
--- a/diskann-benchmark-simd/src/lib.rs
+++ b/diskann-benchmark-simd/src/lib.rs
@@ -279,7 +279,7 @@ fn register_benchmarks_impl(dispatcher: &mut diskann_benchmark_runner::registry:
         Kernel<'static, diskann_wide::arch::x86_64::V3, i8, i8>
     );

-    // x86-64-v3
+    // aarch64-neon
     register!(
         "aarch64",
         dispatcher,
diff --git a/diskann-wide/src/arch/aarch64/f32x4_.rs b/diskann-wide/src/arch/aarch64/f32x4_.rs
index 0082eed1f..a8f42566d 100644
--- a/diskann-wide/src/arch/aarch64/f32x4_.rs
+++ b/diskann-wide/src/arch/aarch64/f32x4_.rs
@@ -56,13 +56,13 @@ impl SIMDMinMax for f32x4 {

     #[inline(always)]
     fn max_simd(self, rhs: Self) -> Self {
-        // SAFETY: `vminnmq_f32` requires "neon", implied by the `Neon` architecture.
+        // SAFETY: `vmaxnmq_f32` requires "neon", implied by the `Neon` architecture.
         Self(unsafe { vmaxnmq_f32(self.0, rhs.0) })
     }

     #[inline(always)]
     fn max_simd_standard(self, rhs: Self) -> Self {
-        // SAFETY: `vminnmq_f32` requires "neon", implied by the `Neon` architecture.
+        // SAFETY: `vmaxnmq_f32` requires "neon", implied by the `Neon` architecture.
         Self(unsafe { vmaxnmq_f32(self.0, rhs.0) })
     }
 }

From c18da522f44e7d6dd93bff508b3219ede6b8f761 Mon Sep 17 00:00:00 2001
From: Mark Hildebrand
Date: Mon, 16 Feb 2026 12:47:56 -0800
Subject: [PATCH 09/10] Disable inclusion of `x86_64` module when building rust-doc.

---
 diskann-wide/src/arch/mod.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/diskann-wide/src/arch/mod.rs b/diskann-wide/src/arch/mod.rs
index 9ec5a3573..acc2f954c 100644
--- a/diskann-wide/src/arch/mod.rs
+++ b/diskann-wide/src/arch/mod.rs
@@ -381,7 +381,7 @@ impl Level {
 }

 cfg_if::cfg_if! {
-    if #[cfg(any(target_arch = "x86_64", doc))] {
+    if #[cfg(target_arch = "x86_64")] {
         // Delegate to the architecture selection within the `x86_64` module.
         pub mod x86_64;


From 9ce00a7f13f4d9de6c14790125028a74234e392a Mon Sep 17 00:00:00 2001
From: Mark Hildebrand
Date: Mon, 16 Feb 2026 15:13:28 -0800
Subject: [PATCH 10/10] Fix an oops.

---
 diskann-benchmark-simd/src/lib.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/diskann-benchmark-simd/src/lib.rs b/diskann-benchmark-simd/src/lib.rs
index 4701b04d7..530be5803 100644
--- a/diskann-benchmark-simd/src/lib.rs
+++ b/diskann-benchmark-simd/src/lib.rs
@@ -405,7 +405,7 @@ impl DispatchRule for Identity {
         if *from == Arch::Reference {
             Ok(MatchScore(0))
         } else {
-            Err(FailureScore(10))
+            Err(FailureScore(1))
         }