@@ -43,13 +43,10 @@ type L1Fn = unsafe fn(&[i16; 17], &[i16; 17]) -> u32;
4343#[ cfg( target_arch = "x86_64" ) ]
4444#[ target_feature( enable = "avx512f" ) ]
4545unsafe fn l1_avx512 ( a : & [ i16 ; 17 ] , b : & [ i16 ; 17 ] ) -> u32 {
46- use std:: arch:: x86_64:: * ;
47- // Load 16 i16 → 16 i32 via sign-extension
48- let va = _mm512_cvtepi16_epi32 ( _mm256_loadu_si256 ( a. as_ptr ( ) as * const __m256i ) ) ;
49- let vb = _mm512_cvtepi16_epi32 ( _mm256_loadu_si256 ( b. as_ptr ( ) as * const __m256i ) ) ;
50- let diff = _mm512_sub_epi32 ( va, vb) ;
51- let abs_diff = _mm512_abs_epi32 ( diff) ;
52- let sum16 = _mm512_reduce_add_epi32 ( abs_diff) as u32 ;
46+ let va = crate :: simd:: I32x16 :: from_i16_slice ( a) ;
47+ let vb = crate :: simd:: I32x16 :: from_i16_slice ( b) ;
48+ let abs_diff = ( va - vb) . abs ( ) ;
49+ let sum16 = abs_diff. reduce_sum ( ) as u32 ;
5350 // 17th dim scalar
5451 let d16 = ( a[ 16 ] as i32 - b[ 16 ] as i32 ) . unsigned_abs ( ) ;
5552 sum16 + d16
@@ -58,26 +55,10 @@ unsafe fn l1_avx512(a: &[i16; 17], b: &[i16; 17]) -> u32 {
5855#[ cfg( target_arch = "x86_64" ) ]
5956#[ target_feature( enable = "avx2" ) ]
6057unsafe fn l1_avx2 ( a : & [ i16 ; 17 ] , b : & [ i16 ; 17 ] ) -> u32 {
61- use std:: arch:: x86_64:: * ;
62- // Process 8 dims at a time (2 passes of 8 = 16, + 1 scalar)
63- let va0 = _mm256_cvtepi16_epi32 ( _mm_loadu_si128 ( a. as_ptr ( ) as * const __m128i ) ) ;
64- let vb0 = _mm256_cvtepi16_epi32 ( _mm_loadu_si128 ( b. as_ptr ( ) as * const __m128i ) ) ;
65- let diff0 = _mm256_sub_epi32 ( va0, vb0) ;
66- let abs0 = _mm256_abs_epi32 ( diff0) ;
67-
68- let va1 = _mm256_cvtepi16_epi32 ( _mm_loadu_si128 ( a[ 8 ..] . as_ptr ( ) as * const __m128i ) ) ;
69- let vb1 = _mm256_cvtepi16_epi32 ( _mm_loadu_si128 ( b[ 8 ..] . as_ptr ( ) as * const __m128i ) ) ;
70- let diff1 = _mm256_sub_epi32 ( va1, vb1) ;
71- let abs1 = _mm256_abs_epi32 ( diff1) ;
72-
73- let sum = _mm256_add_epi32 ( abs0, abs1) ;
74- // Horizontal sum of 8 i32
75- let hi128 = _mm256_extracti128_si256 ( sum, 1 ) ;
76- let lo128 = _mm256_castsi256_si128 ( sum) ;
77- let sum128 = _mm_add_epi32 ( lo128, hi128) ;
78- let sum64 = _mm_add_epi32 ( sum128, _mm_srli_si128 ( sum128, 8 ) ) ;
79- let sum32 = _mm_add_epi32 ( sum64, _mm_srli_si128 ( sum64, 4 ) ) ;
80- let sum16 = _mm_extract_epi32 ( sum32, 0 ) as u32 ;
58+ let va = crate :: simd:: I32x16 :: from_i16_slice ( a) ;
59+ let vb = crate :: simd:: I32x16 :: from_i16_slice ( b) ;
60+ let abs_diff = ( va - vb) . abs ( ) ;
61+ let sum16 = abs_diff. reduce_sum ( ) as u32 ;
8162 // 17th dim scalar
8263 let d16 = ( a[ 16 ] as i32 - b[ 16 ] as i32 ) . unsigned_abs ( ) ;
8364 sum16 + d16
@@ -115,14 +96,12 @@ const WEIGHT_VEC: [i32; 16] = [20, 3, 3, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
11596#[ cfg( target_arch = "x86_64" ) ]
11697#[ target_feature( enable = "avx512f" ) ]
11798unsafe fn l1_weighted_avx512 ( a : & [ i16 ; 17 ] , b : & [ i16 ; 17 ] ) -> u32 {
118- use std:: arch:: x86_64:: * ;
119- let va = _mm512_cvtepi16_epi32 ( _mm256_loadu_si256 ( a. as_ptr ( ) as * const __m256i ) ) ;
120- let vb = _mm512_cvtepi16_epi32 ( _mm256_loadu_si256 ( b. as_ptr ( ) as * const __m256i ) ) ;
121- let diff = _mm512_sub_epi32 ( va, vb) ;
122- let abs_diff = _mm512_abs_epi32 ( diff) ;
123- let vw = _mm512_loadu_si512 ( WEIGHT_VEC . as_ptr ( ) as * const __m512i ) ;
124- let weighted = _mm512_mullo_epi32 ( abs_diff, vw) ;
125- let sum16 = _mm512_reduce_add_epi32 ( weighted) as u32 ;
99+ let va = crate :: simd:: I32x16 :: from_i16_slice ( a) ;
100+ let vb = crate :: simd:: I32x16 :: from_i16_slice ( b) ;
101+ let abs_diff = ( va - vb) . abs ( ) ;
102+ let vw = crate :: simd:: I32x16 :: from_array ( WEIGHT_VEC ) ;
103+ let weighted = abs_diff * vw;
104+ let sum16 = weighted. reduce_sum ( ) as u32 ;
126105 // 17th dim: weight = 1
127106 let d16 = ( a[ 16 ] as i32 - b[ 16 ] as i32 ) . unsigned_abs ( ) ;
128107 sum16 + d16
@@ -131,34 +110,14 @@ unsafe fn l1_weighted_avx512(a: &[i16; 17], b: &[i16; 17]) -> u32 {
131110#[ cfg( target_arch = "x86_64" ) ]
132111#[ target_feature( enable = "avx2" ) ]
133112unsafe fn l1_weighted_avx2 ( a : & [ i16 ; 17 ] , b : & [ i16 ; 17 ] ) -> u32 {
134- use std:: arch:: x86_64:: * ;
135- // First 8 dims
136- let va0 = _mm256_cvtepi16_epi32 ( _mm_loadu_si128 ( a. as_ptr ( ) as * const __m128i ) ) ;
137- let vb0 = _mm256_cvtepi16_epi32 ( _mm_loadu_si128 ( b. as_ptr ( ) as * const __m128i ) ) ;
138- let diff0 = _mm256_sub_epi32 ( va0, vb0) ;
139- let abs0 = _mm256_abs_epi32 ( diff0) ;
140- let vw0 = _mm256_loadu_si256 ( WEIGHT_VEC . as_ptr ( ) as * const __m256i ) ;
141- let w0 = _mm256_mullo_epi32 ( abs0, vw0) ;
142-
143- // Dims 8..16
144- let va1 = _mm256_cvtepi16_epi32 ( _mm_loadu_si128 ( a[ 8 ..] . as_ptr ( ) as * const __m128i ) ) ;
145- let vb1 = _mm256_cvtepi16_epi32 ( _mm_loadu_si128 ( b[ 8 ..] . as_ptr ( ) as * const __m128i ) ) ;
146- let diff1 = _mm256_sub_epi32 ( va1, vb1) ;
147- let abs1 = _mm256_abs_epi32 ( diff1) ;
148- let vw1 = _mm256_loadu_si256 ( WEIGHT_VEC [ 8 ..] . as_ptr ( ) as * const __m256i ) ;
149- let w1 = _mm256_mullo_epi32 ( abs1, vw1) ;
150-
151- let sum = _mm256_add_epi32 ( w0, w1) ;
152- // Horizontal sum
153- let hi128 = _mm256_extracti128_si256 ( sum, 1 ) ;
154- let lo128 = _mm256_castsi256_si128 ( sum) ;
155- let sum128 = _mm_add_epi32 ( lo128, hi128) ;
156- let sum64 = _mm_add_epi32 ( sum128, _mm_srli_si128 ( sum128, 8 ) ) ;
157- let sum32 = _mm_add_epi32 ( sum64, _mm_srli_si128 ( sum64, 4 ) ) ;
158- let s = _mm_extract_epi32 ( sum32, 0 ) as u32 ;
159- // 17th dim: weight = 1
113+ let va = crate :: simd:: I32x16 :: from_i16_slice ( a) ;
114+ let vb = crate :: simd:: I32x16 :: from_i16_slice ( b) ;
115+ let abs_diff = ( va - vb) . abs ( ) ;
116+ let vw = crate :: simd:: I32x16 :: from_array ( WEIGHT_VEC ) ;
117+ let weighted = abs_diff * vw;
118+ let sum16 = weighted. reduce_sum ( ) as u32 ;
160119 let d16 = ( a[ 16 ] as i32 - b[ 16 ] as i32 ) . unsigned_abs ( ) ;
161- s + d16
120+ sum16 + d16
162121}
163122
164123fn l1_weighted_scalar ( a : & [ i16 ; 17 ] , b : & [ i16 ; 17 ] ) -> u32 {
@@ -193,14 +152,10 @@ type SignAgreementFn = unsafe fn(&[i16; 17], &[i16; 17]) -> u32;
193152#[ cfg( target_arch = "x86_64" ) ]
194153#[ target_feature( enable = "avx512f" ) ]
195154unsafe fn sign_agreement_avx512 ( a : & [ i16 ; 17 ] , b : & [ i16 ; 17 ] ) -> u32 {
196- use std:: arch:: x86_64:: * ;
197- let va = _mm512_cvtepi16_epi32 ( _mm256_loadu_si256 ( a. as_ptr ( ) as * const __m256i ) ) ;
198- let vb = _mm512_cvtepi16_epi32 ( _mm256_loadu_si256 ( b. as_ptr ( ) as * const __m256i ) ) ;
199- // XOR: same sign → non-negative, different sign → negative
200- let xor = _mm512_xor_si512 ( va, vb) ;
201- // Compare >= 0: mask bit set where same sign
202- let zero = _mm512_setzero_si512 ( ) ;
203- let mask = _mm512_cmpge_epi32_mask ( xor, zero) ;
155+ let va = crate :: simd:: I32x16 :: from_i16_slice ( a) ;
156+ let vb = crate :: simd:: I32x16 :: from_i16_slice ( b) ;
157+ let xor = va ^ vb;
158+ let mask = xor. cmpge_zero_mask ( ) ;
204159 let count16 = mask. count_ones ( ) ;
205160 // 17th dim
206161 let same17 = if ( a[ 16 ] >= 0 ) == ( b[ 16 ] >= 0 ) { 1u32 } else { 0u32 } ;
@@ -210,28 +165,14 @@ unsafe fn sign_agreement_avx512(a: &[i16; 17], b: &[i16; 17]) -> u32 {
210165#[ cfg( target_arch = "x86_64" ) ]
211166#[ target_feature( enable = "avx2" ) ]
212167unsafe fn sign_agreement_avx2 ( a : & [ i16 ; 17 ] , b : & [ i16 ; 17 ] ) -> u32 {
213- use std:: arch:: x86_64:: * ;
214- // First 8 dims
215- let va0 = _mm256_cvtepi16_epi32 ( _mm_loadu_si128 ( a. as_ptr ( ) as * const __m128i ) ) ;
216- let vb0 = _mm256_cvtepi16_epi32 ( _mm_loadu_si128 ( b. as_ptr ( ) as * const __m128i ) ) ;
217- let xor0 = _mm256_xor_si256 ( va0, vb0) ;
218- let zero = _mm256_setzero_si256 ( ) ;
219- let neg0 = _mm256_cmpgt_epi32 ( zero, xor0) ; // -1 where xor < 0
220- // movemask_ps on the reinterpreted float gives 8 bits, one per 32-bit lane
221- let mask0 = _mm256_movemask_ps ( _mm256_castsi256_ps ( neg0) ) as u32 ;
222- let same0 = 8 - mask0. count_ones ( ) ;
223-
224- // Dims 8..16
225- let va1 = _mm256_cvtepi16_epi32 ( _mm_loadu_si128 ( a[ 8 ..] . as_ptr ( ) as * const __m128i ) ) ;
226- let vb1 = _mm256_cvtepi16_epi32 ( _mm_loadu_si128 ( b[ 8 ..] . as_ptr ( ) as * const __m128i ) ) ;
227- let xor1 = _mm256_xor_si256 ( va1, vb1) ;
228- let neg1 = _mm256_cmpgt_epi32 ( zero, xor1) ;
229- let mask1 = _mm256_movemask_ps ( _mm256_castsi256_ps ( neg1) ) as u32 ;
230- let same1 = 8 - mask1. count_ones ( ) ;
231-
168+ let va = crate :: simd:: I32x16 :: from_i16_slice ( a) ;
169+ let vb = crate :: simd:: I32x16 :: from_i16_slice ( b) ;
170+ let xor = va ^ vb;
171+ let mask = xor. cmpge_zero_mask ( ) ;
172+ let count16 = mask. count_ones ( ) ;
232173 // 17th dim
233174 let same17 = if ( a[ 16 ] >= 0 ) == ( b[ 16 ] >= 0 ) { 1u32 } else { 0u32 } ;
234- same0 + same1 + same17
175+ count16 + same17
235176}
236177
237178fn sign_agreement_scalar ( a : & [ i16 ; 17 ] , b : & [ i16 ; 17 ] ) -> u32 {
@@ -267,47 +208,25 @@ type XorBindFn = unsafe fn(&[i16; 17], &[i16; 17]) -> [i16; 17];
267208#[ cfg( target_arch = "x86_64" ) ]
268209#[ target_feature( enable = "avx512f" ) ]
269210unsafe fn xor_bind_avx512 ( a : & [ i16 ; 17 ] , b : & [ i16 ; 17 ] ) -> [ i16 ; 17 ] {
270- use std:: arch:: x86_64:: * ;
271- // Load 16 i16 as i32, XOR, store back as i16
272- let va = _mm512_cvtepi16_epi32 ( _mm256_loadu_si256 ( a. as_ptr ( ) as * const __m256i ) ) ;
273- let vb = _mm512_cvtepi16_epi32 ( _mm256_loadu_si256 ( b. as_ptr ( ) as * const __m256i ) ) ;
274- let xored = _mm512_xor_si512 ( va, vb) ;
275- // Convert back to i16: truncate i32 -> i16 via pmovdw
276- let packed = _mm512_cvtepi32_epi16 ( xored) ;
211+ let va = crate :: simd:: I32x16 :: from_i16_slice ( a) ;
212+ let vb = crate :: simd:: I32x16 :: from_i16_slice ( b) ;
213+ let xored = va ^ vb; // BitXor trait
214+ let narrow = xored. to_i16_array ( ) ; // narrow i32→i16
277215 let mut dims = [ 0i16 ; 17 ] ;
278- _mm256_storeu_si256 ( dims. as_mut_ptr ( ) as * mut __m256i , packed ) ;
216+ dims[ .. 16 ] . copy_from_slice ( & narrow ) ;
279217 dims[ 16 ] = ( a[ 16 ] as u16 ^ b[ 16 ] as u16 ) as i16 ;
280218 dims
281219}
282220
283221#[ cfg( target_arch = "x86_64" ) ]
284222#[ target_feature( enable = "avx2" ) ]
285223unsafe fn xor_bind_avx2 ( a : & [ i16 ; 17 ] , b : & [ i16 ; 17 ] ) -> [ i16 ; 17 ] {
286- use std:: arch:: x86_64:: * ;
287- // First 8 dims: load as i32, XOR, narrow back
288- let va0 = _mm256_cvtepi16_epi32 ( _mm_loadu_si128 ( a. as_ptr ( ) as * const __m128i ) ) ;
289- let vb0 = _mm256_cvtepi16_epi32 ( _mm_loadu_si128 ( b. as_ptr ( ) as * const __m128i ) ) ;
290- let xor0 = _mm256_xor_si256 ( va0, vb0) ;
291-
292- // Dims 8..16
293- let va1 = _mm256_cvtepi16_epi32 ( _mm_loadu_si128 ( a[ 8 ..] . as_ptr ( ) as * const __m128i ) ) ;
294- let vb1 = _mm256_cvtepi16_epi32 ( _mm_loadu_si128 ( b[ 8 ..] . as_ptr ( ) as * const __m128i ) ) ;
295- let xor1 = _mm256_xor_si256 ( va1, vb1) ;
296-
297- // Extract results back to i16
224+ let va = crate :: simd:: I32x16 :: from_i16_slice ( a) ;
225+ let vb = crate :: simd:: I32x16 :: from_i16_slice ( b) ;
226+ let xored = va ^ vb; // BitXor trait
227+ let narrow = xored. to_i16_array ( ) ; // narrow i32→i16
298228 let mut dims = [ 0i16 ; 17 ] ;
299- // Pack i32 -> i16 via shuffle + truncation
300- // We need the low 16 bits of each i32 lane.
301- // Use _mm256_packs_epi32 which saturates — but XOR of two i16 fits in i16,
302- // so we use manual extraction instead to avoid saturation issues.
303- let arr0: [ i32 ; 8 ] = core:: mem:: transmute ( xor0) ;
304- let arr1: [ i32 ; 8 ] = core:: mem:: transmute ( xor1) ;
305- for i in 0 ..8 {
306- dims[ i] = arr0[ i] as i16 ;
307- }
308- for i in 0 ..8 {
309- dims[ 8 + i] = arr1[ i] as i16 ;
310- }
229+ dims[ ..16 ] . copy_from_slice ( & narrow) ;
311230 dims[ 16 ] = ( a[ 16 ] as u16 ^ b[ 16 ] as u16 ) as i16 ;
312231 dims
313232}
@@ -356,25 +275,23 @@ fn noise_from_state(state: u64, scale: i16) -> i16 {
356275#[ cfg( target_arch = "x86_64" ) ]
357276#[ target_feature( enable = "avx512f" ) ]
358277unsafe fn inject_noise_avx512 ( dims : & [ i16 ; 17 ] , scale : i16 , seed : u64 ) -> [ i16 ; 17 ] {
359- use std:: arch:: x86_64:: * ;
360278 // Generate 16 noise values via PRNG
361279 let mut state = seed;
362280 let mut noise_vals = [ 0i32 ; 16 ] ;
363281 for i in 0 ..16 {
364282 prng_step ( & mut state) ;
365283 noise_vals[ i] = noise_from_state ( state, scale) as i32 ;
366284 }
367- // Load dims as i32
368- let vd = _mm512_cvtepi16_epi32 ( _mm256_loadu_si256 ( dims. as_ptr ( ) as * const __m256i ) ) ;
369- let vn = _mm512_loadu_si512 ( noise_vals. as_ptr ( ) as * const __m512i ) ;
370- // Saturating add: add then clamp to i16 range
371- let sum = _mm512_add_epi32 ( vd, vn) ;
372- let lo = _mm512_set1_epi32 ( -32768 ) ;
373- let hi = _mm512_set1_epi32 ( 32767 ) ;
374- let clamped = _mm512_max_epi32 ( _mm512_min_epi32 ( sum, hi) , lo) ;
375- let packed = _mm512_cvtepi32_epi16 ( clamped) ;
285+ // Load dims as i32, add noise, clamp to i16 range
286+ let vd = crate :: simd:: I32x16 :: from_i16_slice ( dims) ;
287+ let vn = crate :: simd:: I32x16 :: from_array ( noise_vals) ;
288+ let sum = vd + vn;
289+ let lo = crate :: simd:: I32x16 :: splat ( -32768 ) ;
290+ let hi = crate :: simd:: I32x16 :: splat ( 32767 ) ;
291+ let clamped = sum. simd_min ( hi) . simd_max ( lo) ;
292+ let narrow = clamped. to_i16_array ( ) ;
376293 let mut result = [ 0i16 ; 17 ] ;
377- _mm256_storeu_si256 ( result. as_mut_ptr ( ) as * mut __m256i , packed ) ;
294+ result[ .. 16 ] . copy_from_slice ( & narrow ) ;
378295 // 17th dim
379296 prng_step ( & mut state) ;
380297 let n16 = noise_from_state ( state, scale) ;
@@ -385,43 +302,23 @@ unsafe fn inject_noise_avx512(dims: &[i16; 17], scale: i16, seed: u64) -> [i16;
385302#[ cfg( target_arch = "x86_64" ) ]
386303#[ target_feature( enable = "avx2" ) ]
387304unsafe fn inject_noise_avx2 ( dims : & [ i16 ; 17 ] , scale : i16 , seed : u64 ) -> [ i16 ; 17 ] {
388- use std :: arch :: x86_64 :: * ;
305+ // Generate 16 noise values via PRNG
389306 let mut state = seed;
390- // First 8 dims
391- let mut noise0 = [ 0i32 ; 8 ] ;
392- for i in 0 ..8 {
393- prng_step ( & mut state) ;
394- noise0[ i] = noise_from_state ( state, scale) as i32 ;
395- }
396- let vd0 = _mm256_cvtepi16_epi32 ( _mm_loadu_si128 ( dims. as_ptr ( ) as * const __m128i ) ) ;
397- let vn0 = _mm256_loadu_si256 ( noise0. as_ptr ( ) as * const __m256i ) ;
398- let sum0 = _mm256_add_epi32 ( vd0, vn0) ;
399-
400- // Dims 8..16
401- let mut noise1 = [ 0i32 ; 8 ] ;
402- for i in 0 ..8 {
307+ let mut noise_vals = [ 0i32 ; 16 ] ;
308+ for i in 0 ..16 {
403309 prng_step ( & mut state) ;
404- noise1 [ i] = noise_from_state ( state, scale) as i32 ;
310+ noise_vals [ i] = noise_from_state ( state, scale) as i32 ;
405311 }
406- let vd1 = _mm256_cvtepi16_epi32 ( _mm_loadu_si128 ( dims[ 8 ..] . as_ptr ( ) as * const __m128i ) ) ;
407- let vn1 = _mm256_loadu_si256 ( noise1. as_ptr ( ) as * const __m256i ) ;
408- let sum1 = _mm256_add_epi32 ( vd1, vn1) ;
409-
410- // Clamp and extract
411- let lo = _mm256_set1_epi32 ( -32768 ) ;
412- let hi = _mm256_set1_epi32 ( 32767 ) ;
413- let c0 = _mm256_max_epi32 ( _mm256_min_epi32 ( sum0, hi) , lo) ;
414- let c1 = _mm256_max_epi32 ( _mm256_min_epi32 ( sum1, hi) , lo) ;
415-
416- let arr0: [ i32 ; 8 ] = core:: mem:: transmute ( c0) ;
417- let arr1: [ i32 ; 8 ] = core:: mem:: transmute ( c1) ;
312+ // Load dims as i32, add noise, clamp to i16 range
313+ let vd = crate :: simd:: I32x16 :: from_i16_slice ( dims) ;
314+ let vn = crate :: simd:: I32x16 :: from_array ( noise_vals) ;
315+ let sum = vd + vn;
316+ let lo = crate :: simd:: I32x16 :: splat ( -32768 ) ;
317+ let hi = crate :: simd:: I32x16 :: splat ( 32767 ) ;
318+ let clamped = sum. simd_min ( hi) . simd_max ( lo) ;
319+ let narrow = clamped. to_i16_array ( ) ;
418320 let mut result = [ 0i16 ; 17 ] ;
419- for i in 0 ..8 {
420- result[ i] = arr0[ i] as i16 ;
421- }
422- for i in 0 ..8 {
423- result[ 8 + i] = arr1[ i] as i16 ;
424- }
321+ result[ ..16 ] . copy_from_slice ( & narrow) ;
425322 // 17th dim
426323 prng_step ( & mut state) ;
427324 let n16 = noise_from_state ( state, scale) ;
0 commit comments