1717//! Both shapes are SIMD-friendly storage layouts: each field is a
1818//! contiguous `Vec<T>`, so per-field SIMD loops iterate one `Vec`.
1919//!
20- //! # Element-type scope (this PR )
20+ //! # Element-type scope (PR-X2 )
2121//!
22- //! The macro and `SoaVec` are generic over `T`. The closure-based
23- //! conversion helpers ([`aos_to_soa`], [`soa_to_aos`]) are currently
24- //! **hardwired to `f32` output** (`SoaVec<f32, N>`). Downstream consumers
25- //! with `i8` / `u8` / `u16` / `bf16` SoA fields (palette indices,
26- //! quantized embeddings, BF16 mantissa bytes) must write their own
27- //! extract loop today; the public surface for generic-T conversion is
28- //! a follow-up. The macro itself supports any field type.
22+ //! `SoaVec`, the `soa_struct!` macro, and the closure-based conversion
23+ //! helpers [`aos_to_soa`] / [`soa_to_aos`] are **fully generic over the
24+ //! element type `U`** (was f32-hardwired through W3-W6; PR-X2 lifted the
25+ //! constraint). Common element types now flow through directly:
26+ //!
27+ //! - `f32` — Gaussian batch means, covariances (original W3-W6 case)
28+ //! - `u64` — `CausalEdge64` mantissa cells, NARS evidence packs
29+ //! - `u16` — BF16 carrier values, packed depth fields
30+ //! - `u8` — palette indices, quantized embeddings
31+ //! - `i8` — quantized weights with signed range
32+ //!
33+ //! Callers passing turbofish should now use four type params:
34+ //! `aos_to_soa::<_, U, N, _>(...)` instead of the pre-PR-X2 form
35+ //! `aos_to_soa::<_, N, _>(...)`. Callers using return-type inference are
36+ //! unaffected by the generalisation.
2937//!
3038//! # Layering — why `hpc::soa` and not `simd_ops`
3139//!
3240//! `crate::simd_ops` is the SIMD-dispatch glue layer (every fn there
3341//! dispatches through `F32x16` / `F64x8`). Per the W1a consumer contract
3442//! at `.claude/knowledge/vertical-simd-consumer-contract.md`, free-function
35- //! shapes like `fn aos_to_soa(&[T], extract) -> SoaVec<f32 , N>` belong
43+ //! shapes like `fn aos_to_soa(&[T], extract) -> SoaVec<U , N>` belong
3644//! at the `crate::hpc` level, co-located with the data types they
3745//! convert between. Putting pure-scalar helpers in `simd_ops` would
3846//! contradict that module's charter and the W1a litmus that rejects
@@ -370,30 +378,34 @@ macro_rules! soa_struct {
370378 } ;
371379}
372380
373- /// Deinterleave an AoS slice into a [`SoaVec`] by extracting `N` field
374- /// values per item via the user-supplied `extract` closure.
381+ /// Deinterleave an AoS slice into a [`SoaVec<U, N>`] by extracting `N`
382+ /// field values per item via the user-supplied `extract` closure.
383+ ///
384+ /// `U` is the element type of the resulting `SoaVec` — generic over all
385+ /// `Copy` types. Common values:
386+ /// - `f32` — Gaussian batch means, covariances (original W3-W6 use case)
387+ /// - `u64` — `CausalEdge64` mantissa cells, NARS evidence packs
388+ /// - `u16` — BF16 carrier values, packed depth fields
389+ /// - `u8` — palette indices, quantized embeddings
375390///
376391/// Scalar implementation. A future bench-justified wave may add per-arch
377- /// SIMD gather (VPGATHERDD on AVX-512, LD3/LD4 on NEON) for stride-known
378- /// dense layouts; the public API is forward-compatible — the dispatcher
379- /// will grow internal per-arch arms without changing this signature.
392+ /// SIMD gather (VPGATHERDD on AVX-512, LD3/LD4 on NEON). The public
393+ /// signature is forward-compatible — the dispatcher will grow internal
394+ /// per-arch arms without changing this signature.
380395///
381- /// `T` need not be `Copy`; only the extracted `[f32; N]` row is
382- /// materialized.
396+ /// `T` need not be `Copy`; only the extracted `[U; N]` row is materialised.
383397///
384398/// # Inference
385399///
386- /// If the const-generic `N` fails to infer from the closure return type,
387- /// annotate either with a turbofish or a closure return-type ascription :
400+ /// If `N` fails to infer from the closure return type, annotate via
401+ /// turbofish (note: 4 type params now, was 3 in the f32-only era) :
388402///
389403/// ```ignore
390- /// aos_to_soa::<_, 3, _>(&aos, |it| [it.a, it.b, it.c]);
391- /// aos_to_soa(&aos, |it| -> [f32 ; 3] { [it.a, it.b, it.c] });
404+ /// aos_to_soa::<_, u64, 3, _>(&aos, |it| [it.a, it.b, it.c]);
405+ /// aos_to_soa(&aos, |it| -> [u64 ; 3] { [it.a, it.b, it.c] });
392406/// ```
393407///
394- /// (Verified on Rust 1.94.)
395- ///
396- /// # Example
408+ /// # Example — f32 (backwards-compatible)
397409///
398410/// ```
399411/// use ndarray::hpc::soa::aos_to_soa;
@@ -402,32 +414,58 @@ macro_rules! soa_struct {
402414/// Item { a: 1.0, b: 2.0, c: 3.0 },
403415/// Item { a: 4.0, b: 5.0, c: 6.0 },
404416/// ];
405- /// let soa = aos_to_soa::<_, 3, _>(&aos, |it| [it.a, it.b, it.c]);
417+ /// let soa = aos_to_soa::<_, f32, 3, _>(&aos, |it| [it.a, it.b, it.c]);
406418/// assert_eq!(soa.field(0), &[1.0, 4.0]);
407419/// assert_eq!(soa.field(1), &[2.0, 5.0]);
408420/// assert_eq!(soa.field(2), &[3.0, 6.0]);
409421/// ```
410- pub fn aos_to_soa < T , const N : usize , F > ( aos : & [ T ] , extract : F ) -> SoaVec < f32 , N >
422+ ///
423+ /// # Example — u64 (CausalEdge64-style)
424+ ///
425+ /// ```
426+ /// use ndarray::hpc::soa::aos_to_soa;
427+ /// struct Edge { src: u64, dst: u64, weight: u64 }
428+ /// let aos = vec![
429+ /// Edge { src: 1, dst: 2, weight: 10 },
430+ /// Edge { src: 3, dst: 4, weight: 20 },
431+ /// ];
432+ /// let soa = aos_to_soa::<_, u64, 3, _>(&aos, |e| [e.src, e.dst, e.weight]);
433+ /// assert_eq!(soa.field(0), &[1u64, 3]);
434+ /// assert_eq!(soa.field(2), &[10u64, 20]);
435+ /// ```
436+ ///
437+ /// # Example — u8 (palette indices)
438+ ///
439+ /// ```
440+ /// use ndarray::hpc::soa::aos_to_soa;
441+ /// struct Cell { palette: u8, alpha: u8 }
442+ /// let aos = vec![Cell { palette: 7, alpha: 255 }, Cell { palette: 3, alpha: 128 }];
443+ /// let soa = aos_to_soa::<_, u8, 2, _>(&aos, |c| [c.palette, c.alpha]);
444+ /// assert_eq!(soa.field(0), &[7u8, 3]);
445+ /// assert_eq!(soa.field(1), &[255u8, 128]);
446+ /// ```
447+ pub fn aos_to_soa < T , U , const N : usize , F > ( aos : & [ T ] , extract : F ) -> SoaVec < U , N >
411448where
412- F : Fn ( & T ) -> [ f32 ; N ] ,
449+ F : Fn ( & T ) -> [ U ; N ] ,
413450{
414- let mut soa = SoaVec :: < f32 , N > :: with_capacity ( aos. len ( ) ) ;
451+ let mut soa = SoaVec :: < U , N > :: with_capacity ( aos. len ( ) ) ;
415452 for item in aos {
416453 soa. push ( extract ( item) ) ;
417454 }
418455 soa
419456}
420457
421- /// Interleave a [`SoaVec`] into an AoS `Vec<T>` by building each item
458+ /// Interleave a [`SoaVec<U, N> `] into an AoS `Vec<T>` by building each item
422459/// from the per-field values via the user-supplied `build` closure.
423460///
424- /// Scalar implementation. See [`aos_to_soa`] for the forward-compatible
425- /// note on future SIMD acceleration.
461+ /// `U` is the element type of the input `SoaVec` (must be `Copy` so a
462+ /// per-row `[U; N]` can be materialised by indexing). Scalar implementation;
463+ /// the public signature is forward-compatible per [`aos_to_soa`].
426464///
427465/// Complexity: O(N·len) where N is the field count and len is the row
428466/// count.
429467///
430- /// # Example
468+ /// # Example — f32 (backwards-compatible)
431469///
432470/// ```
433471/// use ndarray::hpc::soa::{aos_to_soa, soa_to_aos};
@@ -436,20 +474,34 @@ where
436474/// Item { a: 1.0, b: 2.0, c: 3.0 },
437475/// Item { a: 4.0, b: 5.0, c: 6.0 },
438476/// ];
439- /// let soa = aos_to_soa::<_, 3, _>(&aos, |it| [it.a, it.b, it.c]);
477+ /// let soa = aos_to_soa::<_, f32, 3, _>(&aos, |it| [it.a, it.b, it.c]);
440478/// let back: Vec<Item> = soa_to_aos(&soa, |[a, b, c]| Item { a, b, c });
441479/// assert_eq!(back[0].a, 1.0);
442480/// assert_eq!(back[1].c, 6.0);
443481/// ```
444- pub fn soa_to_aos < T , const N : usize , F > ( soa : & SoaVec < f32 , N > , build : F ) -> Vec < T >
482+ ///
483+ /// # Example — u16 (BF16 carrier)
484+ ///
485+ /// ```
486+ /// use ndarray::hpc::soa::{aos_to_soa, soa_to_aos};
487+ /// #[derive(Debug, PartialEq)]
488+ /// struct Pair { lo: u16, hi: u16 }
489+ /// let aos = vec![Pair { lo: 0x1234, hi: 0xABCD }, Pair { lo: 0x5678, hi: 0xEF01 }];
490+ /// let soa = aos_to_soa::<_, u16, 2, _>(&aos, |p| [p.lo, p.hi]);
491+ /// let back: Vec<Pair> = soa_to_aos(&soa, |[lo, hi]| Pair { lo, hi });
492+ /// assert_eq!(back[0], Pair { lo: 0x1234, hi: 0xABCD });
493+ /// assert_eq!(back[1], Pair { lo: 0x5678, hi: 0xEF01 });
494+ /// ```
495+ pub fn soa_to_aos < T , U , const N : usize , F > ( soa : & SoaVec < U , N > , build : F ) -> Vec < T >
445496where
446- F : Fn ( [ f32 ; N ] ) -> T ,
497+ F : Fn ( [ U ; N ] ) -> T ,
498+ U : Copy ,
447499{
448500 let n = soa. len ( ) ;
449501 let fields = soa. all_fields ( ) ;
450502 let mut out = Vec :: with_capacity ( n) ;
451503 for i in 0 ..n {
452- let row: [ f32 ; N ] = core:: array:: from_fn ( |k| fields[ k] [ i] ) ;
504+ let row: [ U ; N ] = core:: array:: from_fn ( |k| fields[ k] [ i] ) ;
453505 out. push ( build ( row) ) ;
454506 }
455507 out
@@ -787,7 +839,7 @@ mod tests {
787839 #[ test]
788840 fn aos_to_soa_n2_roundtrip ( ) {
789841 let aos = vec ! [ ItemN2 { a: 1.0 , b: 2.0 } , ItemN2 { a: 3.0 , b: 4.0 } , ItemN2 { a: 5.0 , b: 6.0 } ] ;
790- let soa = aos_to_soa :: < _ , 2 , _ > ( & aos, |it| [ it. a , it. b ] ) ;
842+ let soa = aos_to_soa :: < _ , _ , 2 , _ > ( & aos, |it| [ it. a , it. b ] ) ;
791843 assert_eq ! ( soa. len( ) , 3 ) ;
792844 assert_eq ! ( soa. field( 0 ) , & [ 1.0 , 3.0 , 5.0 ] ) ;
793845 assert_eq ! ( soa. field( 1 ) , & [ 2.0 , 4.0 , 6.0 ] ) ;
@@ -798,7 +850,7 @@ mod tests {
798850 #[ test]
799851 fn aos_to_soa_n3_roundtrip ( ) {
800852 let aos = vec ! [ ItemN3 { a: 1.0 , b: 2.0 , c: 3.0 } , ItemN3 { a: 4.0 , b: 5.0 , c: 6.0 } ] ;
801- let soa = aos_to_soa :: < _ , 3 , _ > ( & aos, |it| [ it. a , it. b , it. c ] ) ;
853+ let soa = aos_to_soa :: < _ , _ , 3 , _ > ( & aos, |it| [ it. a , it. b , it. c ] ) ;
802854 assert_eq ! ( soa. field( 0 ) , & [ 1.0 , 4.0 ] ) ;
803855 assert_eq ! ( soa. field( 1 ) , & [ 2.0 , 5.0 ] ) ;
804856 assert_eq ! ( soa. field( 2 ) , & [ 3.0 , 6.0 ] ) ;
@@ -828,7 +880,7 @@ mod tests {
828880 d: 12.0 ,
829881 } ,
830882 ] ;
831- let soa = aos_to_soa :: < _ , 4 , _ > ( & aos, |it| [ it. a , it. b , it. c , it. d ] ) ;
883+ let soa = aos_to_soa :: < _ , _ , 4 , _ > ( & aos, |it| [ it. a , it. b , it. c , it. d ] ) ;
832884 assert_eq ! ( soa. field( 0 ) , & [ 1.0 , 5.0 , 9.0 ] ) ;
833885 assert_eq ! ( soa. field( 1 ) , & [ 2.0 , 6.0 , 10.0 ] ) ;
834886 assert_eq ! ( soa. field( 2 ) , & [ 3.0 , 7.0 , 11.0 ] ) ;
@@ -840,7 +892,7 @@ mod tests {
840892 #[ test]
841893 fn aos_to_soa_empty_input ( ) {
842894 let aos: Vec < ItemN3 > = Vec :: new ( ) ;
843- let soa = aos_to_soa :: < _ , 3 , _ > ( & aos, |it| [ it. a , it. b , it. c ] ) ;
895+ let soa = aos_to_soa :: < _ , _ , 3 , _ > ( & aos, |it| [ it. a , it. b , it. c ] ) ;
844896 assert ! ( soa. is_empty( ) ) ;
845897 assert_eq ! ( soa. field( 0 ) , & [ ] as & [ f32 ] ) ;
846898 assert_eq ! ( soa. field( 1 ) , & [ ] as & [ f32 ] ) ;
@@ -856,7 +908,7 @@ mod tests {
856908 // applied per row.
857909 let scale: f32 = 10.0 ;
858910 let aos = vec ! [ ItemN2 { a: 1.0 , b: 2.0 } , ItemN2 { a: 3.0 , b: 4.0 } ] ;
859- let soa = aos_to_soa :: < _ , 2 , _ > ( & aos, |it| [ it. a * scale, it. b * scale] ) ;
911+ let soa = aos_to_soa :: < _ , _ , 2 , _ > ( & aos, |it| [ it. a * scale, it. b * scale] ) ;
860912 assert_eq ! ( soa. field( 0 ) , & [ 10.0 , 30.0 ] ) ;
861913 assert_eq ! ( soa. field( 1 ) , & [ 20.0 , 40.0 ] ) ;
862914 }
@@ -867,4 +919,93 @@ mod tests {
867919 let back: Vec < ItemN2 > = soa_to_aos ( & soa, |[ a, b] | ItemN2 { a, b } ) ;
868920 assert ! ( back. is_empty( ) ) ;
869921 }
922+
923+ // ------------------------------------------------------------------
924+ // PR-X2 — generic-U coverage (was f32-hardwired through W3-W6)
925+ // ------------------------------------------------------------------
926+
927+ /// `aos_to_soa` over `u64` (CausalEdge64-style fields).
928+ #[ test]
929+ fn aos_to_soa_u64_round_trip ( ) {
930+ struct Edge {
931+ src : u64 ,
932+ dst : u64 ,
933+ weight : u64 ,
934+ }
935+ let aos = [
936+ Edge {
937+ src : 1 ,
938+ dst : 2 ,
939+ weight : 10 ,
940+ } ,
941+ Edge {
942+ src : 3 ,
943+ dst : 4 ,
944+ weight : 20 ,
945+ } ,
946+ Edge {
947+ src : 0xDEAD_BEEF_CAFE_BABE ,
948+ dst : 0 ,
949+ weight : u64:: MAX ,
950+ } ,
951+ ] ;
952+ let soa = aos_to_soa :: < _ , u64 , 3 , _ > ( & aos, |e| [ e. src , e. dst , e. weight ] ) ;
953+ assert_eq ! ( soa. len( ) , 3 ) ;
954+ assert_eq ! ( soa. field( 0 ) , & [ 1u64 , 3 , 0xDEAD_BEEF_CAFE_BABE ] ) ;
955+ assert_eq ! ( soa. field( 1 ) , & [ 2u64 , 4 , 0 ] ) ;
956+ assert_eq ! ( soa. field( 2 ) , & [ 10u64 , 20 , u64 :: MAX ] ) ;
957+ }
958+
959+ /// `aos_to_soa` over `u8` (palette indices) plus `soa_to_aos` round-trip.
960+ #[ test]
961+ fn aos_to_soa_u8_round_trip ( ) {
962+ #[ derive( Debug , PartialEq , Eq ) ]
963+ struct Cell {
964+ palette : u8 ,
965+ alpha : u8 ,
966+ }
967+ let aos = vec ! [ Cell { palette: 7 , alpha: 255 } , Cell { palette: 3 , alpha: 128 } , Cell { palette: 0 , alpha: 0 } ] ;
968+ let soa = aos_to_soa :: < _ , u8 , 2 , _ > ( & aos, |c| [ c. palette , c. alpha ] ) ;
969+ assert_eq ! ( soa. field( 0 ) , & [ 7u8 , 3 , 0 ] ) ;
970+ assert_eq ! ( soa. field( 1 ) , & [ 255u8 , 128 , 0 ] ) ;
971+
972+ let back: Vec < Cell > = soa_to_aos ( & soa, |[ palette, alpha] | Cell { palette, alpha } ) ;
973+ assert_eq ! ( back, aos) ;
974+ }
975+
976+ /// `aos_to_soa` over `u16` (BF16 carrier bytes).
977+ #[ test]
978+ fn aos_to_soa_u16_round_trip ( ) {
979+ #[ derive( Debug , PartialEq , Eq ) ]
980+ struct Bf16Pair {
981+ lo : u16 ,
982+ hi : u16 ,
983+ }
984+ let aos = vec ! [
985+ Bf16Pair { lo: 0x1234 , hi: 0xABCD } ,
986+ Bf16Pair { lo: 0x5678 , hi: 0xEF01 } ,
987+ Bf16Pair { lo: 0xFFFF , hi: 0x0000 } ,
988+ ] ;
989+ let soa = aos_to_soa :: < _ , u16 , 2 , _ > ( & aos, |p| [ p. lo , p. hi ] ) ;
990+ assert_eq ! ( soa. field( 0 ) , & [ 0x1234u16 , 0x5678 , 0xFFFF ] ) ;
991+ assert_eq ! ( soa. field( 1 ) , & [ 0xABCDu16 , 0xEF01 , 0x0000 ] ) ;
992+
993+ let back: Vec < Bf16Pair > = soa_to_aos ( & soa, |[ lo, hi] | Bf16Pair { lo, hi } ) ;
994+ assert_eq ! ( back, aos) ;
995+ }
996+
997+ /// Inference-only entry: caller relies on closure return-type ascription,
998+ /// no turbofish at all.
999+ #[ test]
1000+ fn aos_to_soa_inference_only ( ) {
1001+ struct Triple {
1002+ a : i8 ,
1003+ b : i8 ,
1004+ c : i8 ,
1005+ }
1006+ let aos = [ Triple { a : 1 , b : 2 , c : 3 } , Triple { a : -1 , b : -2 , c : -3 } ] ;
1007+ let soa = aos_to_soa ( & aos, |t| -> [ i8 ; 3 ] { [ t. a , t. b , t. c ] } ) ;
1008+ assert_eq ! ( soa. field( 0 ) , & [ 1i8 , -1 ] ) ;
1009+ assert_eq ! ( soa. field( 2 ) , & [ 3i8 , -3 ] ) ;
1010+ }
8701011}
0 commit comments