@@ -168,16 +168,16 @@ fn aabb_intersect_batch_scalar(query: &Aabb, candidates: &[Aabb]) -> Vec<bool> {
168168
169169/// AVX-512 batch AABB intersection: tests 16 candidates per axis comparison.
170170///
171- /// Broadcasts query min/max per axis, gathers candidate coords into __m512 ,
172- /// compares all 16 at once using `_mm512_cmp_ps_mask `, ANDs the 6 comparison
171+ /// Broadcasts query min/max per axis, gathers candidate coords into F32x16 ,
172+ /// compares all 16 at once using `simd_le` / `simd_ge `, ANDs the 6 comparison
173173/// masks.
174174///
175175/// # Safety
176176/// Caller must ensure AVX-512F is available.
177177#[ cfg( target_arch = "x86_64" ) ]
178178#[ target_feature( enable = "avx512f" ) ]
179179unsafe fn aabb_intersect_batch_avx512 ( query : & Aabb , candidates : & [ Aabb ] ) -> Vec < bool > {
180- use core :: arch :: x86_64 :: * ;
180+ use crate :: simd :: { F32x16 , F32Mask16 } ;
181181
182182 let mut result = Vec :: with_capacity ( candidates. len ( ) ) ;
183183
@@ -203,32 +203,30 @@ unsafe fn aabb_intersect_batch_avx512(query: &Aabb, candidates: &[Aabb]) -> Vec<
203203 c_max_z[ i] = cand. max [ 2 ] ;
204204 }
205205
206- // SAFETY: arrays are 16-element, avx512f checked by caller.
207- let v_c_min_x = _mm512_loadu_ps ( c_min_x. as_ptr ( ) ) ;
208- let v_c_max_x = _mm512_loadu_ps ( c_max_x. as_ptr ( ) ) ;
209- let v_c_min_y = _mm512_loadu_ps ( c_min_y. as_ptr ( ) ) ;
210- let v_c_max_y = _mm512_loadu_ps ( c_max_y. as_ptr ( ) ) ;
211- let v_c_min_z = _mm512_loadu_ps ( c_min_z. as_ptr ( ) ) ;
212- let v_c_max_z = _mm512_loadu_ps ( c_max_z. as_ptr ( ) ) ;
206+ let v_c_min_x = F32x16 :: from_array ( c_min_x) ;
207+ let v_c_max_x = F32x16 :: from_array ( c_max_x) ;
208+ let v_c_min_y = F32x16 :: from_array ( c_min_y) ;
209+ let v_c_max_y = F32x16 :: from_array ( c_max_y) ;
210+ let v_c_min_z = F32x16 :: from_array ( c_min_z) ;
211+ let v_c_max_z = F32x16 :: from_array ( c_max_z) ;
213212
214213 // Broadcast query bounds
215- let q_min_x = _mm512_set1_ps ( query. min [ 0 ] ) ;
216- let q_max_x = _mm512_set1_ps ( query. max [ 0 ] ) ;
217- let q_min_y = _mm512_set1_ps ( query. min [ 1 ] ) ;
218- let q_max_y = _mm512_set1_ps ( query. max [ 1 ] ) ;
219- let q_min_z = _mm512_set1_ps ( query. min [ 2 ] ) ;
220- let q_max_z = _mm512_set1_ps ( query. max [ 2 ] ) ;
214+ let q_min_x = F32x16 :: splat ( query. min [ 0 ] ) ;
215+ let q_max_x = F32x16 :: splat ( query. max [ 0 ] ) ;
216+ let q_min_y = F32x16 :: splat ( query. min [ 1 ] ) ;
217+ let q_max_y = F32x16 :: splat ( query. max [ 1 ] ) ;
218+ let q_min_z = F32x16 :: splat ( query. min [ 2 ] ) ;
219+ let q_max_z = F32x16 :: splat ( query. max [ 2 ] ) ;
221220
222221 // 6 intersection conditions: q.min[i] <= c.max[i] && q.max[i] >= c.min[i]
223- // _CMP_LE_OQ = 18, _CMP_GE_OQ = 29 (ordered, quiet)
224- let m1 = _mm512_cmp_ps_mask :: < { _CMP_LE_OQ } > ( q_min_x, v_c_max_x) ;
225- let m2 = _mm512_cmp_ps_mask :: < { _CMP_GE_OQ } > ( q_max_x, v_c_min_x) ;
226- let m3 = _mm512_cmp_ps_mask :: < { _CMP_LE_OQ } > ( q_min_y, v_c_max_y) ;
227- let m4 = _mm512_cmp_ps_mask :: < { _CMP_GE_OQ } > ( q_max_y, v_c_min_y) ;
228- let m5 = _mm512_cmp_ps_mask :: < { _CMP_LE_OQ } > ( q_min_z, v_c_max_z) ;
229- let m6 = _mm512_cmp_ps_mask :: < { _CMP_GE_OQ } > ( q_max_z, v_c_min_z) ;
222+ let m1 = q_min_x. simd_le ( v_c_max_x) ;
223+ let m2 = q_max_x. simd_ge ( v_c_min_x) ;
224+ let m3 = q_min_y. simd_le ( v_c_max_y) ;
225+ let m4 = q_max_y. simd_ge ( v_c_min_y) ;
226+ let m5 = q_min_z. simd_le ( v_c_max_z) ;
227+ let m6 = q_max_z. simd_ge ( v_c_min_z) ;
230228
231- let all = m1 & m2 & m3 & m4 & m5 & m6;
229+ let all = m1. 0 & m2. 0 & m3. 0 & m4. 0 & m5. 0 & m6. 0 ;
232230
233231 for i in 0 ..16 {
234232 result. push ( ( all >> i) & 1 != 0 ) ;
@@ -246,24 +244,16 @@ unsafe fn aabb_intersect_batch_avx512(query: &Aabb, candidates: &[Aabb]) -> Vec<
246244#[ cfg( target_arch = "x86_64" ) ]
247245#[ target_feature( enable = "sse4.1" ) ]
248246unsafe fn aabb_intersect_batch_sse41 ( query : & Aabb , candidates : & [ Aabb ] ) -> Vec < bool > {
249- use core:: arch:: x86_64:: * ;
250-
251- // Load query min/max into SSE registers (only need xyz, ignore w).
252- let q_min = _mm_set_ps ( 0.0 , query. min [ 2 ] , query. min [ 1 ] , query. min [ 0 ] ) ;
253- let q_max = _mm_set_ps ( f32:: MAX , query. max [ 2 ] , query. max [ 1 ] , query. max [ 0 ] ) ;
254-
247+ // Scalar per-candidate test — LLVM auto-vectorizes with target-cpu=x86-64-v4
255248 let mut result = Vec :: with_capacity ( candidates. len ( ) ) ;
256249 for c in candidates {
257- let c_min = _mm_set_ps ( 0.0 , c. min [ 2 ] , c. min [ 1 ] , c. min [ 0 ] ) ;
258- let c_max = _mm_set_ps ( f32:: MAX , c. max [ 2 ] , c. max [ 1 ] , c. max [ 0 ] ) ;
259-
260- // q.min <= c.max AND q.max >= c.min (per component)
261- let le = _mm_cmple_ps ( q_min, c_max) ; // q_min[i] <= c_max[i]
262- let ge = _mm_cmpge_ps ( q_max, c_min) ; // q_max[i] >= c_min[i]
263- let both = _mm_and_ps ( le, ge) ;
264- // All 4 lanes must be true (lane 3 is always true due to sentinel values).
265- let mask = _mm_movemask_ps ( both) ;
266- result. push ( mask == 0xF ) ;
250+ let hit = query. min [ 0 ] <= c. max [ 0 ]
251+ && query. max [ 0 ] >= c. min [ 0 ]
252+ && query. min [ 1 ] <= c. max [ 1 ]
253+ && query. max [ 1 ] >= c. min [ 1 ]
254+ && query. min [ 2 ] <= c. max [ 2 ]
255+ && query. max [ 2 ] >= c. min [ 2 ] ;
256+ result. push ( hit) ;
267257 }
268258 result
269259}
@@ -333,27 +323,27 @@ fn ray_aabb_slab_test_scalar(ray: &Ray, aabbs: &[Aabb]) -> (Vec<bool>, Vec<f32>)
333323/// AVX-512 batch ray-AABB slab test: processes 16 AABBs per iteration.
334324///
335325/// Broadcasts ray origin and inv_dir per axis, gathers candidate min/max
336- /// coords into SoA arrays, computes slab intervals with `_mm512_min_ps ` /
337- /// `_mm512_max_ps `, and combines masks with `_mm512_cmp_ps_mask `.
326+ /// coords into SoA arrays, computes slab intervals with `simd_min ` /
327+ /// `simd_max `, and combines masks with `simd_le` / `simd_ge `.
338328///
339329/// # Safety
340330/// Caller must ensure AVX-512F is available.
341331#[ cfg( target_arch = "x86_64" ) ]
342332#[ target_feature( enable = "avx512f" ) ]
343333unsafe fn ray_aabb_slab_test_avx512 ( ray : & Ray , aabbs : & [ Aabb ] ) -> ( Vec < bool > , Vec < f32 > ) {
344- use core :: arch :: x86_64 :: * ;
334+ use crate :: simd :: F32x16 ;
345335
346336 let mut hits = Vec :: with_capacity ( aabbs. len ( ) ) ;
347337 let mut t_values = Vec :: with_capacity ( aabbs. len ( ) ) ;
348338
349339 // Broadcast ray origin and inv_dir per axis
350- let orig_x = _mm512_set1_ps ( ray. origin [ 0 ] ) ;
351- let orig_y = _mm512_set1_ps ( ray. origin [ 1 ] ) ;
352- let orig_z = _mm512_set1_ps ( ray. origin [ 2 ] ) ;
353- let inv_x = _mm512_set1_ps ( ray. inv_dir [ 0 ] ) ;
354- let inv_y = _mm512_set1_ps ( ray. inv_dir [ 1 ] ) ;
355- let inv_z = _mm512_set1_ps ( ray. inv_dir [ 2 ] ) ;
356- let zero = _mm512_set1_ps ( 0.0 ) ;
340+ let orig_x = F32x16 :: splat ( ray. origin [ 0 ] ) ;
341+ let orig_y = F32x16 :: splat ( ray. origin [ 1 ] ) ;
342+ let orig_z = F32x16 :: splat ( ray. origin [ 2 ] ) ;
343+ let inv_x = F32x16 :: splat ( ray. inv_dir [ 0 ] ) ;
344+ let inv_y = F32x16 :: splat ( ray. inv_dir [ 1 ] ) ;
345+ let inv_z = F32x16 :: splat ( ray. inv_dir [ 2 ] ) ;
346+ let zero = F32x16 :: splat ( 0.0 ) ;
357347
358348 // Process 16 AABBs at a time
359349 let chunks = aabbs. len ( ) / 16 ;
@@ -378,49 +368,44 @@ unsafe fn ray_aabb_slab_test_avx512(ray: &Ray, aabbs: &[Aabb]) -> (Vec<bool>, Ve
378368 a_max_z[ i] = aabb. max [ 2 ] ;
379369 }
380370
381- // SAFETY: arrays are 16-element, avx512f checked by caller.
382- let v_min_x = _mm512_loadu_ps ( a_min_x. as_ptr ( ) ) ;
383- let v_max_x = _mm512_loadu_ps ( a_max_x. as_ptr ( ) ) ;
384- let v_min_y = _mm512_loadu_ps ( a_min_y. as_ptr ( ) ) ;
385- let v_max_y = _mm512_loadu_ps ( a_max_y. as_ptr ( ) ) ;
386- let v_min_z = _mm512_loadu_ps ( a_min_z. as_ptr ( ) ) ;
387- let v_max_z = _mm512_loadu_ps ( a_max_z. as_ptr ( ) ) ;
371+ let v_min_x = F32x16 :: from_array ( a_min_x) ;
372+ let v_max_x = F32x16 :: from_array ( a_max_x) ;
373+ let v_min_y = F32x16 :: from_array ( a_min_y) ;
374+ let v_max_y = F32x16 :: from_array ( a_max_y) ;
375+ let v_min_z = F32x16 :: from_array ( a_min_z) ;
376+ let v_max_z = F32x16 :: from_array ( a_max_z) ;
388377
389378 // X axis: t1 = (min - origin) * inv_dir, t2 = (max - origin) * inv_dir
390- let t1_x = _mm512_mul_ps ( _mm512_sub_ps ( v_min_x, orig_x) , inv_x) ;
391- let t2_x = _mm512_mul_ps ( _mm512_sub_ps ( v_max_x, orig_x) , inv_x) ;
392- let t_near_x = _mm512_min_ps ( t1_x, t2_x) ;
393- let t_far_x = _mm512_max_ps ( t1_x, t2_x) ;
379+ let t1_x = ( v_min_x - orig_x) * inv_x;
380+ let t2_x = ( v_max_x - orig_x) * inv_x;
381+ let t_near_x = t1_x. simd_min ( t2_x) ;
382+ let t_far_x = t1_x. simd_max ( t2_x) ;
394383
395384 // Y axis
396- let t1_y = _mm512_mul_ps ( _mm512_sub_ps ( v_min_y, orig_y) , inv_y) ;
397- let t2_y = _mm512_mul_ps ( _mm512_sub_ps ( v_max_y, orig_y) , inv_y) ;
398- let t_near_y = _mm512_min_ps ( t1_y, t2_y) ;
399- let t_far_y = _mm512_max_ps ( t1_y, t2_y) ;
385+ let t1_y = ( v_min_y - orig_y) * inv_y;
386+ let t2_y = ( v_max_y - orig_y) * inv_y;
387+ let t_near_y = t1_y. simd_min ( t2_y) ;
388+ let t_far_y = t1_y. simd_max ( t2_y) ;
400389
401390 // Z axis
402- let t1_z = _mm512_mul_ps ( _mm512_sub_ps ( v_min_z, orig_z) , inv_z) ;
403- let t2_z = _mm512_mul_ps ( _mm512_sub_ps ( v_max_z, orig_z) , inv_z) ;
404- let t_near_z = _mm512_min_ps ( t1_z, t2_z) ;
405- let t_far_z = _mm512_max_ps ( t1_z, t2_z) ;
391+ let t1_z = ( v_min_z - orig_z) * inv_z;
392+ let t2_z = ( v_max_z - orig_z) * inv_z;
393+ let t_near_z = t1_z. simd_min ( t2_z) ;
394+ let t_far_z = t1_z. simd_max ( t2_z) ;
406395
407396 // t_enter = max(t_near_x, t_near_y, t_near_z)
408- let t_enter = _mm512_max_ps ( _mm512_max_ps ( t_near_x, t_near_y) , t_near_z) ;
397+ let t_enter = t_near_x. simd_max ( t_near_y) . simd_max ( t_near_z) ;
409398 // t_exit = min(t_far_x, t_far_y, t_far_z)
410- let t_exit = _mm512_min_ps ( _mm512_min_ps ( t_far_x, t_far_y) , t_far_z) ;
399+ let t_exit = t_far_x. simd_min ( t_far_y) . simd_min ( t_far_z) ;
411400
412401 // hit = t_enter <= t_exit AND t_exit >= 0
413- // _CMP_LE_OQ = 18, _CMP_GE_OQ = 29 (ordered, quiet)
414- let m_le = _mm512_cmp_ps_mask :: < { _CMP_LE_OQ } > ( t_enter, t_exit) ;
415- let m_ge = _mm512_cmp_ps_mask :: < { _CMP_GE_OQ } > ( t_exit, zero) ;
416- let hit_mask = m_le & m_ge;
402+ let m_le = t_enter. simd_le ( t_exit) ;
403+ let m_ge = t_exit. simd_ge ( zero) ;
404+ let hit_mask = m_le. 0 & m_ge. 0 ;
417405
418406 // Clamp t_enter to 0 for origins inside box
419- let t_enter_clamped = _mm512_max_ps ( t_enter, zero) ;
420-
421- // SAFETY: 16-element array matches __m512 lane count.
422- let mut t_arr = [ 0.0f32 ; 16 ] ;
423- _mm512_storeu_ps ( t_arr. as_mut_ptr ( ) , t_enter_clamped) ;
407+ let t_enter_clamped = t_enter. simd_max ( zero) ;
408+ let t_arr = t_enter_clamped. to_array ( ) ;
424409
425410 for i in 0 ..16 {
426411 let hit = ( hit_mask >> i) & 1 != 0 ;
@@ -482,27 +467,14 @@ fn aabb_expand_batch_scalar(aabbs: &mut [Aabb], dx: f32, dy: f32, dz: f32) {
482467#[ cfg( target_arch = "x86_64" ) ]
483468#[ target_feature( enable = "sse2" ) ]
484469unsafe fn aabb_expand_batch_sse2 ( aabbs : & mut [ Aabb ] , dx : f32 , dy : f32 , dz : f32 ) {
485- use core:: arch:: x86_64:: * ;
486-
487- let delta_min = _mm_set_ps ( 0.0 , dz, dy, dx) ;
488- let delta_max = _mm_set_ps ( 0.0 , dz, dy, dx) ;
489-
470+ // Scalar per-AABB expand — LLVM auto-vectorizes with target-cpu=x86-64-v4
490471 for a in aabbs. iter_mut ( ) {
491- let min_v = _mm_set_ps ( 0.0 , a. min [ 2 ] , a. min [ 1 ] , a. min [ 0 ] ) ;
492- let max_v = _mm_set_ps ( 0.0 , a. max [ 2 ] , a. max [ 1 ] , a. max [ 0 ] ) ;
493-
494- let new_min = _mm_sub_ps ( min_v, delta_min) ;
495- let new_max = _mm_add_ps ( max_v, delta_max) ;
496-
497- // Store back. We cannot use _mm_storeu_ps directly into [f32;3],
498- // so extract components.
499- let mut min_arr = [ 0.0f32 ; 4 ] ;
500- let mut max_arr = [ 0.0f32 ; 4 ] ;
501- _mm_storeu_ps ( min_arr. as_mut_ptr ( ) , new_min) ;
502- _mm_storeu_ps ( max_arr. as_mut_ptr ( ) , new_max) ;
503-
504- a. min = [ min_arr[ 0 ] , min_arr[ 1 ] , min_arr[ 2 ] ] ;
505- a. max = [ max_arr[ 0 ] , max_arr[ 1 ] , max_arr[ 2 ] ] ;
472+ a. min [ 0 ] -= dx;
473+ a. min [ 1 ] -= dy;
474+ a. min [ 2 ] -= dz;
475+ a. max [ 0 ] += dx;
476+ a. max [ 1 ] += dy;
477+ a. max [ 2 ] += dz;
506478 }
507479}
508480
0 commit comments