Skip to content

Commit fb95cb3

Browse files
authored
Merge pull request #168 from AdaWorldAPI/claude/pr-x2-generic-soa
PR-X2 Worker A: generalize aos_to_soa / soa_to_aos to <T, U, N, F>
2 parents 25874e7 + 8a859a3 commit fb95cb3

2 files changed

Lines changed: 183 additions & 42 deletions

File tree

src/hpc/bulk.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
//! .map(|i| Item { a: i as f32, b: (i * 2) as f32, c: (i * 3) as f32 })
3131
//! .collect();
3232
//! bulk_apply(&mut items, 16, |chunk, _start| {
33-
//! let soa = aos_to_soa::<_, 3, _>(chunk, |it| [it.a, it.b, it.c]);
33+
//! let soa = aos_to_soa::<_, _, 3, _>(chunk, |it| [it.a, it.b, it.c]);
3434
//! // ... per-field SIMD-style loops over soa.field(0), soa.field(1), ...
3535
//! let _ = soa;
3636
//! });
@@ -315,7 +315,7 @@ mod tests {
315315

316316
let mut chunk_count = 0;
317317
bulk_apply(&mut items, 16, |chunk, start_idx| {
318-
let soa = aos_to_soa::<_, 3, _>(chunk, |it| [it.a, it.b, it.c]);
318+
let soa = aos_to_soa::<_, _, 3, _>(chunk, |it| [it.a, it.b, it.c]);
319319
assert_eq!(soa.len(), chunk.len());
320320
// First row of the chunk corresponds to absolute index start_idx.
321321
assert_eq!(soa.field(0)[0], start_idx as f32);

src/hpc/soa.rs

Lines changed: 181 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,30 @@
1717
//! Both shapes are SIMD-friendly storage layouts: each field is a
1818
//! contiguous `Vec<T>`, so per-field SIMD loops iterate one `Vec`.
1919
//!
20-
//! # Element-type scope (this PR)
20+
//! # Element-type scope (PR-X2)
2121
//!
22-
//! The macro and `SoaVec` are generic over `T`. The closure-based
23-
//! conversion helpers ([`aos_to_soa`], [`soa_to_aos`]) are currently
24-
//! **hardwired to `f32` output** (`SoaVec<f32, N>`). Downstream consumers
25-
//! with `i8` / `u8` / `u16` / `bf16` SoA fields (palette indices,
26-
//! quantized embeddings, BF16 mantissa bytes) must write their own
27-
//! extract loop today; the public surface for generic-T conversion is
28-
//! a follow-up. The macro itself supports any field type.
22+
//! `SoaVec`, the `soa_struct!` macro, and the closure-based conversion
23+
//! helpers [`aos_to_soa`] / [`soa_to_aos`] are **fully generic over the
24+
//! element type `U`** (was f32-hardwired through W3-W6; PR-X2 lifted the
25+
//! constraint). Common element types now flow through directly:
26+
//!
27+
//! - `f32` — Gaussian batch means, covariances (original W3-W6 case)
28+
//! - `u64` — `CausalEdge64` mantissa cells, NARS evidence packs
29+
//! - `u16` — BF16 carrier values, packed depth fields
30+
//! - `u8` — palette indices, quantized embeddings
31+
//! - `i8` — quantized weights with signed range
32+
//!
33+
//! Callers passing turbofish should now use four type params:
34+
//! `aos_to_soa::<_, U, N, _>(...)` instead of the pre-PR-X2 form
35+
//! `aos_to_soa::<_, N, _>(...)`. Callers using return-type inference are
36+
//! unaffected by the generalisation.
2937
//!
3038
//! # Layering — why `hpc::soa` and not `simd_ops`
3139
//!
3240
//! `crate::simd_ops` is the SIMD-dispatch glue layer (every fn there
3341
//! dispatches through `F32x16` / `F64x8`). Per the W1a consumer contract
3442
//! at `.claude/knowledge/vertical-simd-consumer-contract.md`, free-function
35-
//! shapes like `fn aos_to_soa(&[T], extract) -> SoaVec<f32, N>` belong
43+
//! shapes like `fn aos_to_soa(&[T], extract) -> SoaVec<U, N>` belong
3644
//! at the `crate::hpc` level, co-located with the data types they
3745
//! convert between. Putting pure-scalar helpers in `simd_ops` would
3846
//! contradict that module's charter and the W1a litmus that rejects
@@ -370,30 +378,34 @@ macro_rules! soa_struct {
370378
};
371379
}
372380

373-
/// Deinterleave an AoS slice into a [`SoaVec`] by extracting `N` field
374-
/// values per item via the user-supplied `extract` closure.
381+
/// Deinterleave an AoS slice into a [`SoaVec<U, N>`] by extracting `N`
382+
/// field values per item via the user-supplied `extract` closure.
383+
///
384+
/// `U` is the element type of the resulting `SoaVec` — generic over all
385+
/// `Copy` types. Common values:
386+
/// - `f32` — Gaussian batch means, covariances (original W3-W6 use case)
387+
/// - `u64` — `CausalEdge64` mantissa cells, NARS evidence packs
388+
/// - `u16` — BF16 carrier values, packed depth fields
389+
/// - `u8` — palette indices, quantized embeddings
375390
///
376391
/// Scalar implementation. A future bench-justified wave may add per-arch
377-
/// SIMD gather (VPGATHERDD on AVX-512, LD3/LD4 on NEON) for stride-known
378-
/// dense layouts; the public API is forward-compatible — the dispatcher
379-
/// will grow internal per-arch arms without changing this signature.
392+
/// SIMD gather (VPGATHERDD on AVX-512, LD3/LD4 on NEON). The public
393+
/// signature is forward-compatible — the dispatcher will grow internal
394+
/// per-arch arms without changing this signature.
380395
///
381-
/// `T` need not be `Copy`; only the extracted `[f32; N]` row is
382-
/// materialized.
396+
/// `T` need not be `Copy`; only the extracted `[U; N]` row is materialised.
383397
///
384398
/// # Inference
385399
///
386-
/// If the const-generic `N` fails to infer from the closure return type,
387-
/// annotate either with a turbofish or a closure return-type ascription:
400+
/// If `N` fails to infer from the closure return type, annotate via
401+
/// turbofish (note: 4 type params now, was 3 in the f32-only era):
388402
///
389403
/// ```ignore
390-
/// aos_to_soa::<_, 3, _>(&aos, |it| [it.a, it.b, it.c]);
391-
/// aos_to_soa(&aos, |it| -> [f32; 3] { [it.a, it.b, it.c] });
404+
/// aos_to_soa::<_, u64, 3, _>(&aos, |it| [it.a, it.b, it.c]);
405+
/// aos_to_soa(&aos, |it| -> [u64; 3] { [it.a, it.b, it.c] });
392406
/// ```
393407
///
394-
/// (Verified on Rust 1.94.)
395-
///
396-
/// # Example
408+
/// # Example — f32 (backwards-compatible)
397409
///
398410
/// ```
399411
/// use ndarray::hpc::soa::aos_to_soa;
@@ -402,32 +414,58 @@ macro_rules! soa_struct {
402414
/// Item { a: 1.0, b: 2.0, c: 3.0 },
403415
/// Item { a: 4.0, b: 5.0, c: 6.0 },
404416
/// ];
405-
/// let soa = aos_to_soa::<_, 3, _>(&aos, |it| [it.a, it.b, it.c]);
417+
/// let soa = aos_to_soa::<_, f32, 3, _>(&aos, |it| [it.a, it.b, it.c]);
406418
/// assert_eq!(soa.field(0), &[1.0, 4.0]);
407419
/// assert_eq!(soa.field(1), &[2.0, 5.0]);
408420
/// assert_eq!(soa.field(2), &[3.0, 6.0]);
409421
/// ```
410-
pub fn aos_to_soa<T, const N: usize, F>(aos: &[T], extract: F) -> SoaVec<f32, N>
422+
///
423+
/// # Example — u64 (CausalEdge64-style)
424+
///
425+
/// ```
426+
/// use ndarray::hpc::soa::aos_to_soa;
427+
/// struct Edge { src: u64, dst: u64, weight: u64 }
428+
/// let aos = vec![
429+
/// Edge { src: 1, dst: 2, weight: 10 },
430+
/// Edge { src: 3, dst: 4, weight: 20 },
431+
/// ];
432+
/// let soa = aos_to_soa::<_, u64, 3, _>(&aos, |e| [e.src, e.dst, e.weight]);
433+
/// assert_eq!(soa.field(0), &[1u64, 3]);
434+
/// assert_eq!(soa.field(2), &[10u64, 20]);
435+
/// ```
436+
///
437+
/// # Example — u8 (palette indices)
438+
///
439+
/// ```
440+
/// use ndarray::hpc::soa::aos_to_soa;
441+
/// struct Cell { palette: u8, alpha: u8 }
442+
/// let aos = vec![Cell { palette: 7, alpha: 255 }, Cell { palette: 3, alpha: 128 }];
443+
/// let soa = aos_to_soa::<_, u8, 2, _>(&aos, |c| [c.palette, c.alpha]);
444+
/// assert_eq!(soa.field(0), &[7u8, 3]);
445+
/// assert_eq!(soa.field(1), &[255u8, 128]);
446+
/// ```
447+
pub fn aos_to_soa<T, U, const N: usize, F>(aos: &[T], extract: F) -> SoaVec<U, N>
411448
where
412-
F: Fn(&T) -> [f32; N],
449+
F: Fn(&T) -> [U; N],
413450
{
414-
let mut soa = SoaVec::<f32, N>::with_capacity(aos.len());
451+
let mut soa = SoaVec::<U, N>::with_capacity(aos.len());
415452
for item in aos {
416453
soa.push(extract(item));
417454
}
418455
soa
419456
}
420457

421-
/// Interleave a [`SoaVec`] into an AoS `Vec<T>` by building each item
458+
/// Interleave a [`SoaVec<U, N>`] into an AoS `Vec<T>` by building each item
422459
/// from the per-field values via the user-supplied `build` closure.
423460
///
424-
/// Scalar implementation. See [`aos_to_soa`] for the forward-compatible
425-
/// note on future SIMD acceleration.
461+
/// `U` is the element type of the input `SoaVec` (must be `Copy` so a
462+
/// per-row `[U; N]` can be materialised by indexing). Scalar implementation;
463+
/// the public signature is forward-compatible per [`aos_to_soa`].
426464
///
427465
/// Complexity: O(N·len) where N is the field count and len is the row
428466
/// count.
429467
///
430-
/// # Example
468+
/// # Example — f32 (backwards-compatible)
431469
///
432470
/// ```
433471
/// use ndarray::hpc::soa::{aos_to_soa, soa_to_aos};
@@ -436,20 +474,34 @@ where
436474
/// Item { a: 1.0, b: 2.0, c: 3.0 },
437475
/// Item { a: 4.0, b: 5.0, c: 6.0 },
438476
/// ];
439-
/// let soa = aos_to_soa::<_, 3, _>(&aos, |it| [it.a, it.b, it.c]);
477+
/// let soa = aos_to_soa::<_, f32, 3, _>(&aos, |it| [it.a, it.b, it.c]);
440478
/// let back: Vec<Item> = soa_to_aos(&soa, |[a, b, c]| Item { a, b, c });
441479
/// assert_eq!(back[0].a, 1.0);
442480
/// assert_eq!(back[1].c, 6.0);
443481
/// ```
444-
pub fn soa_to_aos<T, const N: usize, F>(soa: &SoaVec<f32, N>, build: F) -> Vec<T>
482+
///
483+
/// # Example — u16 (BF16 carrier)
484+
///
485+
/// ```
486+
/// use ndarray::hpc::soa::{aos_to_soa, soa_to_aos};
487+
/// #[derive(Debug, PartialEq)]
488+
/// struct Pair { lo: u16, hi: u16 }
489+
/// let aos = vec![Pair { lo: 0x1234, hi: 0xABCD }, Pair { lo: 0x5678, hi: 0xEF01 }];
490+
/// let soa = aos_to_soa::<_, u16, 2, _>(&aos, |p| [p.lo, p.hi]);
491+
/// let back: Vec<Pair> = soa_to_aos(&soa, |[lo, hi]| Pair { lo, hi });
492+
/// assert_eq!(back[0], Pair { lo: 0x1234, hi: 0xABCD });
493+
/// assert_eq!(back[1], Pair { lo: 0x5678, hi: 0xEF01 });
494+
/// ```
495+
pub fn soa_to_aos<T, U, const N: usize, F>(soa: &SoaVec<U, N>, build: F) -> Vec<T>
445496
where
446-
F: Fn([f32; N]) -> T,
497+
F: Fn([U; N]) -> T,
498+
U: Copy,
447499
{
448500
let n = soa.len();
449501
let fields = soa.all_fields();
450502
let mut out = Vec::with_capacity(n);
451503
for i in 0..n {
452-
let row: [f32; N] = core::array::from_fn(|k| fields[k][i]);
504+
let row: [U; N] = core::array::from_fn(|k| fields[k][i]);
453505
out.push(build(row));
454506
}
455507
out
@@ -787,7 +839,7 @@ mod tests {
787839
#[test]
788840
fn aos_to_soa_n2_roundtrip() {
789841
let aos = vec![ItemN2 { a: 1.0, b: 2.0 }, ItemN2 { a: 3.0, b: 4.0 }, ItemN2 { a: 5.0, b: 6.0 }];
790-
let soa = aos_to_soa::<_, 2, _>(&aos, |it| [it.a, it.b]);
842+
let soa = aos_to_soa::<_, _, 2, _>(&aos, |it| [it.a, it.b]);
791843
assert_eq!(soa.len(), 3);
792844
assert_eq!(soa.field(0), &[1.0, 3.0, 5.0]);
793845
assert_eq!(soa.field(1), &[2.0, 4.0, 6.0]);
@@ -798,7 +850,7 @@ mod tests {
798850
#[test]
799851
fn aos_to_soa_n3_roundtrip() {
800852
let aos = vec![ItemN3 { a: 1.0, b: 2.0, c: 3.0 }, ItemN3 { a: 4.0, b: 5.0, c: 6.0 }];
801-
let soa = aos_to_soa::<_, 3, _>(&aos, |it| [it.a, it.b, it.c]);
853+
let soa = aos_to_soa::<_, _, 3, _>(&aos, |it| [it.a, it.b, it.c]);
802854
assert_eq!(soa.field(0), &[1.0, 4.0]);
803855
assert_eq!(soa.field(1), &[2.0, 5.0]);
804856
assert_eq!(soa.field(2), &[3.0, 6.0]);
@@ -828,7 +880,7 @@ mod tests {
828880
d: 12.0,
829881
},
830882
];
831-
let soa = aos_to_soa::<_, 4, _>(&aos, |it| [it.a, it.b, it.c, it.d]);
883+
let soa = aos_to_soa::<_, _, 4, _>(&aos, |it| [it.a, it.b, it.c, it.d]);
832884
assert_eq!(soa.field(0), &[1.0, 5.0, 9.0]);
833885
assert_eq!(soa.field(1), &[2.0, 6.0, 10.0]);
834886
assert_eq!(soa.field(2), &[3.0, 7.0, 11.0]);
@@ -840,7 +892,7 @@ mod tests {
840892
#[test]
841893
fn aos_to_soa_empty_input() {
842894
let aos: Vec<ItemN3> = Vec::new();
843-
let soa = aos_to_soa::<_, 3, _>(&aos, |it| [it.a, it.b, it.c]);
895+
let soa = aos_to_soa::<_, _, 3, _>(&aos, |it| [it.a, it.b, it.c]);
844896
assert!(soa.is_empty());
845897
assert_eq!(soa.field(0), &[] as &[f32]);
846898
assert_eq!(soa.field(1), &[] as &[f32]);
@@ -856,7 +908,7 @@ mod tests {
856908
// applied per row.
857909
let scale: f32 = 10.0;
858910
let aos = vec![ItemN2 { a: 1.0, b: 2.0 }, ItemN2 { a: 3.0, b: 4.0 }];
859-
let soa = aos_to_soa::<_, 2, _>(&aos, |it| [it.a * scale, it.b * scale]);
911+
let soa = aos_to_soa::<_, _, 2, _>(&aos, |it| [it.a * scale, it.b * scale]);
860912
assert_eq!(soa.field(0), &[10.0, 30.0]);
861913
assert_eq!(soa.field(1), &[20.0, 40.0]);
862914
}
@@ -867,4 +919,93 @@ mod tests {
867919
let back: Vec<ItemN2> = soa_to_aos(&soa, |[a, b]| ItemN2 { a, b });
868920
assert!(back.is_empty());
869921
}
922+
923+
// ------------------------------------------------------------------
924+
// PR-X2 — generic-U coverage (was f32-hardwired through W3-W6)
925+
// ------------------------------------------------------------------
926+
927+
/// `aos_to_soa` over `u64` (CausalEdge64-style fields).
928+
#[test]
929+
fn aos_to_soa_u64_round_trip() {
930+
struct Edge {
931+
src: u64,
932+
dst: u64,
933+
weight: u64,
934+
}
935+
let aos = [
936+
Edge {
937+
src: 1,
938+
dst: 2,
939+
weight: 10,
940+
},
941+
Edge {
942+
src: 3,
943+
dst: 4,
944+
weight: 20,
945+
},
946+
Edge {
947+
src: 0xDEAD_BEEF_CAFE_BABE,
948+
dst: 0,
949+
weight: u64::MAX,
950+
},
951+
];
952+
let soa = aos_to_soa::<_, u64, 3, _>(&aos, |e| [e.src, e.dst, e.weight]);
953+
assert_eq!(soa.len(), 3);
954+
assert_eq!(soa.field(0), &[1u64, 3, 0xDEAD_BEEF_CAFE_BABE]);
955+
assert_eq!(soa.field(1), &[2u64, 4, 0]);
956+
assert_eq!(soa.field(2), &[10u64, 20, u64::MAX]);
957+
}
958+
959+
/// `aos_to_soa` over `u8` (palette indices) plus `soa_to_aos` round-trip.
960+
#[test]
961+
fn aos_to_soa_u8_round_trip() {
962+
#[derive(Debug, PartialEq, Eq)]
963+
struct Cell {
964+
palette: u8,
965+
alpha: u8,
966+
}
967+
let aos = vec![Cell { palette: 7, alpha: 255 }, Cell { palette: 3, alpha: 128 }, Cell { palette: 0, alpha: 0 }];
968+
let soa = aos_to_soa::<_, u8, 2, _>(&aos, |c| [c.palette, c.alpha]);
969+
assert_eq!(soa.field(0), &[7u8, 3, 0]);
970+
assert_eq!(soa.field(1), &[255u8, 128, 0]);
971+
972+
let back: Vec<Cell> = soa_to_aos(&soa, |[palette, alpha]| Cell { palette, alpha });
973+
assert_eq!(back, aos);
974+
}
975+
976+
/// `aos_to_soa` over `u16` (BF16 carrier bytes).
977+
#[test]
978+
fn aos_to_soa_u16_round_trip() {
979+
#[derive(Debug, PartialEq, Eq)]
980+
struct Bf16Pair {
981+
lo: u16,
982+
hi: u16,
983+
}
984+
let aos = vec![
985+
Bf16Pair { lo: 0x1234, hi: 0xABCD },
986+
Bf16Pair { lo: 0x5678, hi: 0xEF01 },
987+
Bf16Pair { lo: 0xFFFF, hi: 0x0000 },
988+
];
989+
let soa = aos_to_soa::<_, u16, 2, _>(&aos, |p| [p.lo, p.hi]);
990+
assert_eq!(soa.field(0), &[0x1234u16, 0x5678, 0xFFFF]);
991+
assert_eq!(soa.field(1), &[0xABCDu16, 0xEF01, 0x0000]);
992+
993+
let back: Vec<Bf16Pair> = soa_to_aos(&soa, |[lo, hi]| Bf16Pair { lo, hi });
994+
assert_eq!(back, aos);
995+
}
996+
997+
/// Inference-only entry: caller relies on closure return-type ascription,
998+
/// no turbofish at all.
999+
#[test]
1000+
fn aos_to_soa_inference_only() {
1001+
struct Triple {
1002+
a: i8,
1003+
b: i8,
1004+
c: i8,
1005+
}
1006+
let aos = [Triple { a: 1, b: 2, c: 3 }, Triple { a: -1, b: -2, c: -3 }];
1007+
let soa = aos_to_soa(&aos, |t| -> [i8; 3] { [t.a, t.b, t.c] });
1008+
assert_eq!(soa.field(0), &[1i8, -1]);
1009+
assert_eq!(soa.field(2), &[3i8, -3]);
1010+
}
8701011
}

0 commit comments

Comments
 (0)