Skip to content

Commit 0d352fa

Browse files
committed
refactor: bgz17_bridge.rs fully rewired to crate::simd::I32x16
Zero raw _mm512_/_mm256_/_mm_ intrinsics remaining.
All 5 kernels rewired (92 intrinsics → 0):
- L1 distance: from_i16_slice → sub → abs → reduce_sum
- L1 weighted: same + from_array(WEIGHT_VEC) → mul
- Sign agreement: from_i16_slice → xor → cmpge_zero_mask
- XOR bind: from_i16_slice → xor → to_i16_array
- Inject noise: from_i16_slice → add → simd_min/max → to_i16_array
AVX2 2-pass patterns collapsed: polyfill I32x16 absorbs the split internally (array-backed [i32; 16] on AVX2, native __m512i on AVX-512).
LazyLock runtime dispatch preserved. #[target_feature] preserved. Scalar fallbacks untouched.
19/19 bgz17_bridge tests pass. 1514/1515 full suite pass (1 pre-existing timing flake in vml.rs).
https://claude.ai/code/session_01ChLvBfpJS8dQhHxRD4pYNp
1 parent ab1aa6b commit 0d352fa

1 file changed

Lines changed: 63 additions & 166 deletions

File tree

src/hpc/bgz17_bridge.rs

Lines changed: 63 additions & 166 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,10 @@ type L1Fn = unsafe fn(&[i16; 17], &[i16; 17]) -> u32;
4343
#[cfg(target_arch = "x86_64")]
4444
#[target_feature(enable = "avx512f")]
4545
unsafe fn l1_avx512(a: &[i16; 17], b: &[i16; 17]) -> u32 {
46-
use std::arch::x86_64::*;
47-
// Load 16 i16 → 16 i32 via sign-extension
48-
let va = _mm512_cvtepi16_epi32(_mm256_loadu_si256(a.as_ptr() as *const __m256i));
49-
let vb = _mm512_cvtepi16_epi32(_mm256_loadu_si256(b.as_ptr() as *const __m256i));
50-
let diff = _mm512_sub_epi32(va, vb);
51-
let abs_diff = _mm512_abs_epi32(diff);
52-
let sum16 = _mm512_reduce_add_epi32(abs_diff) as u32;
46+
let va = crate::simd::I32x16::from_i16_slice(a);
47+
let vb = crate::simd::I32x16::from_i16_slice(b);
48+
let abs_diff = (va - vb).abs();
49+
let sum16 = abs_diff.reduce_sum() as u32;
5350
// 17th dim scalar
5451
let d16 = (a[16] as i32 - b[16] as i32).unsigned_abs();
5552
sum16 + d16
@@ -58,26 +55,10 @@ unsafe fn l1_avx512(a: &[i16; 17], b: &[i16; 17]) -> u32 {
5855
#[cfg(target_arch = "x86_64")]
5956
#[target_feature(enable = "avx2")]
6057
unsafe fn l1_avx2(a: &[i16; 17], b: &[i16; 17]) -> u32 {
61-
use std::arch::x86_64::*;
62-
// Process 8 dims at a time (2 passes of 8 = 16, + 1 scalar)
63-
let va0 = _mm256_cvtepi16_epi32(_mm_loadu_si128(a.as_ptr() as *const __m128i));
64-
let vb0 = _mm256_cvtepi16_epi32(_mm_loadu_si128(b.as_ptr() as *const __m128i));
65-
let diff0 = _mm256_sub_epi32(va0, vb0);
66-
let abs0 = _mm256_abs_epi32(diff0);
67-
68-
let va1 = _mm256_cvtepi16_epi32(_mm_loadu_si128(a[8..].as_ptr() as *const __m128i));
69-
let vb1 = _mm256_cvtepi16_epi32(_mm_loadu_si128(b[8..].as_ptr() as *const __m128i));
70-
let diff1 = _mm256_sub_epi32(va1, vb1);
71-
let abs1 = _mm256_abs_epi32(diff1);
72-
73-
let sum = _mm256_add_epi32(abs0, abs1);
74-
// Horizontal sum of 8 i32
75-
let hi128 = _mm256_extracti128_si256(sum, 1);
76-
let lo128 = _mm256_castsi256_si128(sum);
77-
let sum128 = _mm_add_epi32(lo128, hi128);
78-
let sum64 = _mm_add_epi32(sum128, _mm_srli_si128(sum128, 8));
79-
let sum32 = _mm_add_epi32(sum64, _mm_srli_si128(sum64, 4));
80-
let sum16 = _mm_extract_epi32(sum32, 0) as u32;
58+
let va = crate::simd::I32x16::from_i16_slice(a);
59+
let vb = crate::simd::I32x16::from_i16_slice(b);
60+
let abs_diff = (va - vb).abs();
61+
let sum16 = abs_diff.reduce_sum() as u32;
8162
// 17th dim scalar
8263
let d16 = (a[16] as i32 - b[16] as i32).unsigned_abs();
8364
sum16 + d16
@@ -115,14 +96,12 @@ const WEIGHT_VEC: [i32; 16] = [20, 3, 3, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
11596
#[cfg(target_arch = "x86_64")]
11697
#[target_feature(enable = "avx512f")]
11798
unsafe fn l1_weighted_avx512(a: &[i16; 17], b: &[i16; 17]) -> u32 {
118-
use std::arch::x86_64::*;
119-
let va = _mm512_cvtepi16_epi32(_mm256_loadu_si256(a.as_ptr() as *const __m256i));
120-
let vb = _mm512_cvtepi16_epi32(_mm256_loadu_si256(b.as_ptr() as *const __m256i));
121-
let diff = _mm512_sub_epi32(va, vb);
122-
let abs_diff = _mm512_abs_epi32(diff);
123-
let vw = _mm512_loadu_si512(WEIGHT_VEC.as_ptr() as *const __m512i);
124-
let weighted = _mm512_mullo_epi32(abs_diff, vw);
125-
let sum16 = _mm512_reduce_add_epi32(weighted) as u32;
99+
let va = crate::simd::I32x16::from_i16_slice(a);
100+
let vb = crate::simd::I32x16::from_i16_slice(b);
101+
let abs_diff = (va - vb).abs();
102+
let vw = crate::simd::I32x16::from_array(WEIGHT_VEC);
103+
let weighted = abs_diff * vw;
104+
let sum16 = weighted.reduce_sum() as u32;
126105
// 17th dim: weight = 1
127106
let d16 = (a[16] as i32 - b[16] as i32).unsigned_abs();
128107
sum16 + d16
@@ -131,34 +110,14 @@ unsafe fn l1_weighted_avx512(a: &[i16; 17], b: &[i16; 17]) -> u32 {
131110
#[cfg(target_arch = "x86_64")]
132111
#[target_feature(enable = "avx2")]
133112
unsafe fn l1_weighted_avx2(a: &[i16; 17], b: &[i16; 17]) -> u32 {
134-
use std::arch::x86_64::*;
135-
// First 8 dims
136-
let va0 = _mm256_cvtepi16_epi32(_mm_loadu_si128(a.as_ptr() as *const __m128i));
137-
let vb0 = _mm256_cvtepi16_epi32(_mm_loadu_si128(b.as_ptr() as *const __m128i));
138-
let diff0 = _mm256_sub_epi32(va0, vb0);
139-
let abs0 = _mm256_abs_epi32(diff0);
140-
let vw0 = _mm256_loadu_si256(WEIGHT_VEC.as_ptr() as *const __m256i);
141-
let w0 = _mm256_mullo_epi32(abs0, vw0);
142-
143-
// Dims 8..16
144-
let va1 = _mm256_cvtepi16_epi32(_mm_loadu_si128(a[8..].as_ptr() as *const __m128i));
145-
let vb1 = _mm256_cvtepi16_epi32(_mm_loadu_si128(b[8..].as_ptr() as *const __m128i));
146-
let diff1 = _mm256_sub_epi32(va1, vb1);
147-
let abs1 = _mm256_abs_epi32(diff1);
148-
let vw1 = _mm256_loadu_si256(WEIGHT_VEC[8..].as_ptr() as *const __m256i);
149-
let w1 = _mm256_mullo_epi32(abs1, vw1);
150-
151-
let sum = _mm256_add_epi32(w0, w1);
152-
// Horizontal sum
153-
let hi128 = _mm256_extracti128_si256(sum, 1);
154-
let lo128 = _mm256_castsi256_si128(sum);
155-
let sum128 = _mm_add_epi32(lo128, hi128);
156-
let sum64 = _mm_add_epi32(sum128, _mm_srli_si128(sum128, 8));
157-
let sum32 = _mm_add_epi32(sum64, _mm_srli_si128(sum64, 4));
158-
let s = _mm_extract_epi32(sum32, 0) as u32;
159-
// 17th dim: weight = 1
113+
let va = crate::simd::I32x16::from_i16_slice(a);
114+
let vb = crate::simd::I32x16::from_i16_slice(b);
115+
let abs_diff = (va - vb).abs();
116+
let vw = crate::simd::I32x16::from_array(WEIGHT_VEC);
117+
let weighted = abs_diff * vw;
118+
let sum16 = weighted.reduce_sum() as u32;
160119
let d16 = (a[16] as i32 - b[16] as i32).unsigned_abs();
161-
s + d16
120+
sum16 + d16
162121
}
163122

164123
fn l1_weighted_scalar(a: &[i16; 17], b: &[i16; 17]) -> u32 {
@@ -193,14 +152,10 @@ type SignAgreementFn = unsafe fn(&[i16; 17], &[i16; 17]) -> u32;
193152
#[cfg(target_arch = "x86_64")]
194153
#[target_feature(enable = "avx512f")]
195154
unsafe fn sign_agreement_avx512(a: &[i16; 17], b: &[i16; 17]) -> u32 {
196-
use std::arch::x86_64::*;
197-
let va = _mm512_cvtepi16_epi32(_mm256_loadu_si256(a.as_ptr() as *const __m256i));
198-
let vb = _mm512_cvtepi16_epi32(_mm256_loadu_si256(b.as_ptr() as *const __m256i));
199-
// XOR: same sign → non-negative, different sign → negative
200-
let xor = _mm512_xor_si512(va, vb);
201-
// Compare >= 0: mask bit set where same sign
202-
let zero = _mm512_setzero_si512();
203-
let mask = _mm512_cmpge_epi32_mask(xor, zero);
155+
let va = crate::simd::I32x16::from_i16_slice(a);
156+
let vb = crate::simd::I32x16::from_i16_slice(b);
157+
let xor = va ^ vb;
158+
let mask = xor.cmpge_zero_mask();
204159
let count16 = mask.count_ones();
205160
// 17th dim
206161
let same17 = if (a[16] >= 0) == (b[16] >= 0) { 1u32 } else { 0u32 };
@@ -210,28 +165,14 @@ unsafe fn sign_agreement_avx512(a: &[i16; 17], b: &[i16; 17]) -> u32 {
210165
#[cfg(target_arch = "x86_64")]
211166
#[target_feature(enable = "avx2")]
212167
unsafe fn sign_agreement_avx2(a: &[i16; 17], b: &[i16; 17]) -> u32 {
213-
use std::arch::x86_64::*;
214-
// First 8 dims
215-
let va0 = _mm256_cvtepi16_epi32(_mm_loadu_si128(a.as_ptr() as *const __m128i));
216-
let vb0 = _mm256_cvtepi16_epi32(_mm_loadu_si128(b.as_ptr() as *const __m128i));
217-
let xor0 = _mm256_xor_si256(va0, vb0);
218-
let zero = _mm256_setzero_si256();
219-
let neg0 = _mm256_cmpgt_epi32(zero, xor0); // -1 where xor < 0
220-
// movemask_ps on the reinterpreted float gives 8 bits, one per 32-bit lane
221-
let mask0 = _mm256_movemask_ps(_mm256_castsi256_ps(neg0)) as u32;
222-
let same0 = 8 - mask0.count_ones();
223-
224-
// Dims 8..16
225-
let va1 = _mm256_cvtepi16_epi32(_mm_loadu_si128(a[8..].as_ptr() as *const __m128i));
226-
let vb1 = _mm256_cvtepi16_epi32(_mm_loadu_si128(b[8..].as_ptr() as *const __m128i));
227-
let xor1 = _mm256_xor_si256(va1, vb1);
228-
let neg1 = _mm256_cmpgt_epi32(zero, xor1);
229-
let mask1 = _mm256_movemask_ps(_mm256_castsi256_ps(neg1)) as u32;
230-
let same1 = 8 - mask1.count_ones();
231-
168+
let va = crate::simd::I32x16::from_i16_slice(a);
169+
let vb = crate::simd::I32x16::from_i16_slice(b);
170+
let xor = va ^ vb;
171+
let mask = xor.cmpge_zero_mask();
172+
let count16 = mask.count_ones();
232173
// 17th dim
233174
let same17 = if (a[16] >= 0) == (b[16] >= 0) { 1u32 } else { 0u32 };
234-
same0 + same1 + same17
175+
count16 + same17
235176
}
236177

237178
fn sign_agreement_scalar(a: &[i16; 17], b: &[i16; 17]) -> u32 {
@@ -267,47 +208,25 @@ type XorBindFn = unsafe fn(&[i16; 17], &[i16; 17]) -> [i16; 17];
267208
#[cfg(target_arch = "x86_64")]
268209
#[target_feature(enable = "avx512f")]
269210
unsafe fn xor_bind_avx512(a: &[i16; 17], b: &[i16; 17]) -> [i16; 17] {
270-
use std::arch::x86_64::*;
271-
// Load 16 i16 as i32, XOR, store back as i16
272-
let va = _mm512_cvtepi16_epi32(_mm256_loadu_si256(a.as_ptr() as *const __m256i));
273-
let vb = _mm512_cvtepi16_epi32(_mm256_loadu_si256(b.as_ptr() as *const __m256i));
274-
let xored = _mm512_xor_si512(va, vb);
275-
// Convert back to i16: truncate i32 -> i16 via pmovdw
276-
let packed = _mm512_cvtepi32_epi16(xored);
211+
let va = crate::simd::I32x16::from_i16_slice(a);
212+
let vb = crate::simd::I32x16::from_i16_slice(b);
213+
let xored = va ^ vb; // BitXor trait
214+
let narrow = xored.to_i16_array(); // narrow i32→i16
277215
let mut dims = [0i16; 17];
278-
_mm256_storeu_si256(dims.as_mut_ptr() as *mut __m256i, packed);
216+
dims[..16].copy_from_slice(&narrow);
279217
dims[16] = (a[16] as u16 ^ b[16] as u16) as i16;
280218
dims
281219
}
282220

283221
#[cfg(target_arch = "x86_64")]
284222
#[target_feature(enable = "avx2")]
285223
unsafe fn xor_bind_avx2(a: &[i16; 17], b: &[i16; 17]) -> [i16; 17] {
286-
use std::arch::x86_64::*;
287-
// First 8 dims: load as i32, XOR, narrow back
288-
let va0 = _mm256_cvtepi16_epi32(_mm_loadu_si128(a.as_ptr() as *const __m128i));
289-
let vb0 = _mm256_cvtepi16_epi32(_mm_loadu_si128(b.as_ptr() as *const __m128i));
290-
let xor0 = _mm256_xor_si256(va0, vb0);
291-
292-
// Dims 8..16
293-
let va1 = _mm256_cvtepi16_epi32(_mm_loadu_si128(a[8..].as_ptr() as *const __m128i));
294-
let vb1 = _mm256_cvtepi16_epi32(_mm_loadu_si128(b[8..].as_ptr() as *const __m128i));
295-
let xor1 = _mm256_xor_si256(va1, vb1);
296-
297-
// Extract results back to i16
224+
let va = crate::simd::I32x16::from_i16_slice(a);
225+
let vb = crate::simd::I32x16::from_i16_slice(b);
226+
let xored = va ^ vb; // BitXor trait
227+
let narrow = xored.to_i16_array(); // narrow i32→i16
298228
let mut dims = [0i16; 17];
299-
// Pack i32 -> i16 via shuffle + truncation
300-
// We need the low 16 bits of each i32 lane.
301-
// Use _mm256_packs_epi32 which saturates — but XOR of two i16 fits in i16,
302-
// so we use manual extraction instead to avoid saturation issues.
303-
let arr0: [i32; 8] = core::mem::transmute(xor0);
304-
let arr1: [i32; 8] = core::mem::transmute(xor1);
305-
for i in 0..8 {
306-
dims[i] = arr0[i] as i16;
307-
}
308-
for i in 0..8 {
309-
dims[8 + i] = arr1[i] as i16;
310-
}
229+
dims[..16].copy_from_slice(&narrow);
311230
dims[16] = (a[16] as u16 ^ b[16] as u16) as i16;
312231
dims
313232
}
@@ -356,25 +275,23 @@ fn noise_from_state(state: u64, scale: i16) -> i16 {
356275
#[cfg(target_arch = "x86_64")]
357276
#[target_feature(enable = "avx512f")]
358277
unsafe fn inject_noise_avx512(dims: &[i16; 17], scale: i16, seed: u64) -> [i16; 17] {
359-
use std::arch::x86_64::*;
360278
// Generate 16 noise values via PRNG
361279
let mut state = seed;
362280
let mut noise_vals = [0i32; 16];
363281
for i in 0..16 {
364282
prng_step(&mut state);
365283
noise_vals[i] = noise_from_state(state, scale) as i32;
366284
}
367-
// Load dims as i32
368-
let vd = _mm512_cvtepi16_epi32(_mm256_loadu_si256(dims.as_ptr() as *const __m256i));
369-
let vn = _mm512_loadu_si512(noise_vals.as_ptr() as *const __m512i);
370-
// Saturating add: add then clamp to i16 range
371-
let sum = _mm512_add_epi32(vd, vn);
372-
let lo = _mm512_set1_epi32(-32768);
373-
let hi = _mm512_set1_epi32(32767);
374-
let clamped = _mm512_max_epi32(_mm512_min_epi32(sum, hi), lo);
375-
let packed = _mm512_cvtepi32_epi16(clamped);
285+
// Load dims as i32, add noise, clamp to i16 range
286+
let vd = crate::simd::I32x16::from_i16_slice(dims);
287+
let vn = crate::simd::I32x16::from_array(noise_vals);
288+
let sum = vd + vn;
289+
let lo = crate::simd::I32x16::splat(-32768);
290+
let hi = crate::simd::I32x16::splat(32767);
291+
let clamped = sum.simd_min(hi).simd_max(lo);
292+
let narrow = clamped.to_i16_array();
376293
let mut result = [0i16; 17];
377-
_mm256_storeu_si256(result.as_mut_ptr() as *mut __m256i, packed);
294+
result[..16].copy_from_slice(&narrow);
378295
// 17th dim
379296
prng_step(&mut state);
380297
let n16 = noise_from_state(state, scale);
@@ -385,43 +302,23 @@ unsafe fn inject_noise_avx512(dims: &[i16; 17], scale: i16, seed: u64) -> [i16;
385302
#[cfg(target_arch = "x86_64")]
386303
#[target_feature(enable = "avx2")]
387304
unsafe fn inject_noise_avx2(dims: &[i16; 17], scale: i16, seed: u64) -> [i16; 17] {
388-
use std::arch::x86_64::*;
305+
// Generate 16 noise values via PRNG
389306
let mut state = seed;
390-
// First 8 dims
391-
let mut noise0 = [0i32; 8];
392-
for i in 0..8 {
393-
prng_step(&mut state);
394-
noise0[i] = noise_from_state(state, scale) as i32;
395-
}
396-
let vd0 = _mm256_cvtepi16_epi32(_mm_loadu_si128(dims.as_ptr() as *const __m128i));
397-
let vn0 = _mm256_loadu_si256(noise0.as_ptr() as *const __m256i);
398-
let sum0 = _mm256_add_epi32(vd0, vn0);
399-
400-
// Dims 8..16
401-
let mut noise1 = [0i32; 8];
402-
for i in 0..8 {
307+
let mut noise_vals = [0i32; 16];
308+
for i in 0..16 {
403309
prng_step(&mut state);
404-
noise1[i] = noise_from_state(state, scale) as i32;
310+
noise_vals[i] = noise_from_state(state, scale) as i32;
405311
}
406-
let vd1 = _mm256_cvtepi16_epi32(_mm_loadu_si128(dims[8..].as_ptr() as *const __m128i));
407-
let vn1 = _mm256_loadu_si256(noise1.as_ptr() as *const __m256i);
408-
let sum1 = _mm256_add_epi32(vd1, vn1);
409-
410-
// Clamp and extract
411-
let lo = _mm256_set1_epi32(-32768);
412-
let hi = _mm256_set1_epi32(32767);
413-
let c0 = _mm256_max_epi32(_mm256_min_epi32(sum0, hi), lo);
414-
let c1 = _mm256_max_epi32(_mm256_min_epi32(sum1, hi), lo);
415-
416-
let arr0: [i32; 8] = core::mem::transmute(c0);
417-
let arr1: [i32; 8] = core::mem::transmute(c1);
312+
// Load dims as i32, add noise, clamp to i16 range
313+
let vd = crate::simd::I32x16::from_i16_slice(dims);
314+
let vn = crate::simd::I32x16::from_array(noise_vals);
315+
let sum = vd + vn;
316+
let lo = crate::simd::I32x16::splat(-32768);
317+
let hi = crate::simd::I32x16::splat(32767);
318+
let clamped = sum.simd_min(hi).simd_max(lo);
319+
let narrow = clamped.to_i16_array();
418320
let mut result = [0i16; 17];
419-
for i in 0..8 {
420-
result[i] = arr0[i] as i16;
421-
}
422-
for i in 0..8 {
423-
result[8 + i] = arr1[i] as i16;
424-
}
321+
result[..16].copy_from_slice(&narrow);
425322
// 17th dim
426323
prng_step(&mut state);
427324
let n16 = noise_from_state(state, scale);

0 commit comments

Comments
 (0)