Skip to content

Commit 098c5aa

Browse files
authored
Merge pull request #183 from AdaWorldAPI/claude/continue-ndarray-x0Oaw
simd_half: TD-SIMD-8 — F16C-vectorized F16↔f32 batch conversion
2 parents ae5efaa + 1a73c37 commit 098c5aa

2 files changed

Lines changed: 315 additions & 9 deletions

File tree

src/simd_avx512.rs

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2405,6 +2405,22 @@ pub fn bf16_to_f32_batch(input: &[u16], output: &mut [f32]) {
24052405
}
24062406
return;
24072407
}
2408+
// Middle tier: pure AVX-512F bit-shift (Skylake-X, Cascade Lake,
2409+
// Ice Lake-SP — all AVX-512F CPUs without the bf16 extension).
2410+
// BF16 → f32 is lossless: BF16 IS the upper 16 bits of f32, so
2411+
// `(bf16_u16 as u32) << 16` reinterpreted as f32 IS the exact
2412+
// value. Vectorized: one _mm512_cvtepu16_epi32 zero-extends 16
2413+
// u16 → 16 u32, one _mm512_slli_epi32::<16> shifts each lane left
2414+
// by 16, _mm512_castsi512_ps reinterprets the i32 bit pattern as
2415+
// f32. Three AVX-512F instructions per 16-lane chunk vs 16
2416+
// scalar shifts in the fallback below.
2417+
if is_x86_feature_detected!("avx512f") {
2418+
// SAFETY: feature detection confirmed avx512f.
2419+
unsafe {
2420+
convert_bf16_to_f32_avx512f(input, output);
2421+
}
2422+
return;
2423+
}
24082424
}
24092425

24102426
// Scalar fallback (all platforms, all CPUs)
@@ -2413,6 +2429,36 @@ pub fn bf16_to_f32_batch(input: &[u16], output: &mut [f32]) {
24132429
}
24142430
}
24152431

2432+
/// Pure-AVX-512F BF16 → f32 conversion. Bit-exact against
2433+
/// `bf16_to_f32_scalar` on every input — BF16 is `f32_bits >> 16`, so
2434+
/// the inverse `(bf16 as u32) << 16` reconstructed as f32 is exact.
2435+
///
2436+
/// 16-lane main loop via `_mm512_cvtepu16_epi32` (zero-extend) +
2437+
/// `_mm512_slli_epi32::<16>` (shift up) + `_mm512_castsi512_ps`
2438+
/// (bit-cast). Scalar tail for the last `n % 16` lanes.
2439+
#[cfg(target_arch = "x86_64")]
2440+
#[target_feature(enable = "avx512f")]
2441+
unsafe fn convert_bf16_to_f32_avx512f(input: &[u16], output: &mut [f32]) {
2442+
let n = input.len();
2443+
let mut i = 0usize;
2444+
2445+
// Main 16-wide loop.
2446+
while i + 16 <= n {
2447+
let raw256 = _mm256_loadu_si256(input.as_ptr().add(i) as *const __m256i);
2448+
let extended = _mm512_cvtepu16_epi32(raw256);
2449+
let shifted = _mm512_slli_epi32::<16>(extended);
2450+
let as_f32 = _mm512_castsi512_ps(shifted);
2451+
_mm512_storeu_ps(output.as_mut_ptr().add(i), as_f32);
2452+
i += 16;
2453+
}
2454+
2455+
// Scalar tail (0..15 remaining lanes).
2456+
while i < n {
2457+
*output.get_unchecked_mut(i) = bf16_to_f32_scalar(*input.get_unchecked(i));
2458+
i += 1;
2459+
}
2460+
}
2461+
24162462
/// Batch f32 → BF16 conversion: same pattern.
24172463
pub fn f32_to_bf16_batch(input: &[f32], output: &mut [u16]) {
24182464
assert!(output.len() >= input.len(), "output must be >= input length");
@@ -2707,6 +2753,55 @@ mod bf16_tests {
27072753
}
27082754
}
27092755

2756+
/// Direct test for the AVX-512F bit-shift BF16 → f32 arm, exercising
/// the path the dispatcher would skip when avx512bf16 is available.
/// Verifies bit-exact parity against the scalar reference across a
/// sign × exponent × representative-mantissa sweep (which includes
/// ±0, subnormals, Inf and NaN encodings) and across several input
/// lengths: empty, tail-only, exactly one 16-lane chunk, one chunk plus
/// a tail, and the full non-16-aligned corpus.
#[cfg(target_arch = "x86_64")]
#[test]
fn batch_bf16_to_f32_avx512f_matches_scalar() {
    if !is_x86_feature_detected!("avx512f") {
        eprintln!("avx512f not detected on this host; skipping");
        return;
    }

    // Sign × exponent × representative mantissa sweep.
    let mut corpus: Vec<u16> = Vec::new();
    for sign in [0u16, 0x8000] {
        for exp in 0..256u16 {
            for &mant in &[0u16, 1, 0x40, 0x7F] {
                corpus.push(sign | (exp << 7) | mant);
            }
        }
    }
    // Append 5 extra u16 elements so the total is not a multiple of 16 and
    // the scalar tail is genuinely exercised on the full-corpus run.
    corpus.extend_from_slice(&[0x3F80, 0xBF80, 0x4000, 0xC000, 0x7F80]);
    assert_ne!(corpus.len() % 16, 0, "corpus must leave a scalar tail");

    for len in [0usize, 1, 15, 16, 17, corpus.len()] {
        let input = &corpus[..len];
        let mut output = vec![0.0f32; len];
        // SAFETY: avx512f confirmed above; `output` has exactly `input.len()`
        // elements.
        unsafe { convert_bf16_to_f32_avx512f(input, &mut output) };

        for (i, &bf16) in input.iter().enumerate() {
            let expected = bf16_to_f32_scalar(bf16);
            // BF16 → f32 is lossless: bits must be byte-equal (incl. NaN
            // payloads), so compare raw bit patterns rather than floats.
            assert_eq!(
                output[i].to_bits(),
                expected.to_bits(),
                "mismatch at index {} (bf16=0x{:04x}): got {} (0x{:08x}) vs {} (0x{:08x})",
                i,
                bf16,
                output[i],
                output[i].to_bits(),
                expected,
                expected.to_bits()
            );
        }
    }
}
2804+
27102805
// ─────────────────────────────────────────────────────────────────────
27112806
// RNE certification tests — byte-equality with `_mm512_cvtneps_pbh`.
27122807
// ─────────────────────────────────────────────────────────────────────

src/simd_half.rs

Lines changed: 220 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -351,18 +351,31 @@ pub fn cast_bf16_to_f32_batch(src: &[BF16], dst: &mut [f32]) {
351351

352352
/// Batch convert F16 → f32.
353353
///
354-
/// Uses F16x16 for chunks of 16, scalar tail for remainder.
354+
/// On x86_64 with F16C (every CPU from Ivy Bridge 2012 / Piledriver 2012
355+
/// onward), dispatches to `_mm256_cvtph_ps` — one hardware instruction
356+
/// converts 8 F16 lanes to 8 F32 lanes, IEEE-754 exact. The scalar
357+
/// fallback uses the bit-fiddle [`F16::to_f32`] which is also IEEE-754
358+
/// exact, just slower.
355359
pub fn cast_f16_to_f32_batch(src: &[F16], dst: &mut [f32]) {
356360
let n = src.len().min(dst.len());
357-
let chunks = n / 16;
358-
for c in 0..chunks {
359-
let off = c * 16;
360-
let v = F16x16::from_slice(&src[off..]);
361-
let f = v.to_f32x16();
362-
dst[off..off + 16].copy_from_slice(&f);
361+
362+
#[cfg(target_arch = "x86_64")]
363+
{
364+
if std::is_x86_feature_detected!("f16c") && std::is_x86_feature_detected!("avx") {
365+
// SAFETY: `F16` is `#[repr(transparent)] struct F16(pub u16)`
366+
// (per `hpc::quantized::F16`). Slice reinterpretation is
367+
// bit-pattern preserving. Runtime feature detection above
368+
// confirms F16C + AVX before calling the target-feature fn.
369+
let src_u16: &[u16] = unsafe { core::slice::from_raw_parts(src.as_ptr() as *const u16, src.len()) };
370+
unsafe {
371+
cast_f16_to_f32_batch_f16c(&src_u16[..n], &mut dst[..n]);
372+
}
373+
return;
374+
}
363375
}
364-
// Scalar tail
365-
for i in (chunks * 16)..n {
376+
377+
// Scalar fallback (non-x86_64 or pre-F16C silicon).
378+
for i in 0..n {
366379
dst[i] = src[i].to_f32();
367380
}
368381
}
@@ -376,13 +389,134 @@ pub fn cast_f32_to_bf16_batch(src: &[f32], dst: &mut [BF16]) {
376389
}
377390

378391
/// Batch convert f32 → F16 (round-to-nearest-even).
392+
///
393+
/// On x86_64 with F16C, dispatches to `_mm256_cvtps_ph::<8>` (RNE,
394+
/// no exceptions) — one hardware instruction converts 8 F32 lanes to
395+
/// 8 F16 lanes with IEEE 754 round-to-nearest-even. Scalar fallback
396+
/// uses [`F16::from_f32_rounded`] which matches the IEEE 754 RNE rule
397+
/// bit-for-bit on every input (including subnormal / NaN / Inf).
379398
pub fn cast_f32_to_f16_batch(src: &[f32], dst: &mut [F16]) {
380399
let n = src.len().min(dst.len());
400+
401+
#[cfg(target_arch = "x86_64")]
402+
{
403+
if std::is_x86_feature_detected!("f16c") && std::is_x86_feature_detected!("avx") {
404+
// SAFETY: same as cast_f16_to_f32_batch — `F16` is
405+
// repr(transparent) over u16; runtime feature gate ensures
406+
// F16C is present.
407+
let dst_u16: &mut [u16] =
408+
unsafe { core::slice::from_raw_parts_mut(dst.as_mut_ptr() as *mut u16, dst.len()) };
409+
unsafe {
410+
cast_f32_to_f16_batch_f16c(&src[..n], &mut dst_u16[..n]);
411+
}
412+
return;
413+
}
414+
}
415+
381416
for i in 0..n {
382417
dst[i] = F16::from_f32_rounded(src[i]);
383418
}
384419
}
385420

421+
/// F16C-vectorized F16 → f32 batch.
422+
///
423+
/// 8 F16 lanes per `_mm256_cvtph_ps` instruction (one xmm load + one
424+
/// ymm store). Scalar tail handles the remaining `n % 8` lanes via the
425+
/// bit-fiddle reference. **F16C result is bit-identical to the scalar
426+
/// reference per IEEE 754 binary16 → binary32 spec** (lossless widening,
427+
/// no rounding possible).
428+
///
429+
/// # MXCSR preservation
430+
/// `_mm256_cvtph_ps` may raise `#I` (Invalid: SNaN input) or `#D`
431+
/// (Denormal) — setting bits in MXCSR that the scalar bit-fiddle
432+
/// reference [`F16::to_f32`] does not touch. To preserve the scalar
433+
/// path's contract of "no observable FP control/status side effects,"
434+
/// the MXCSR is saved before the SIMD region and restored after. Net
435+
/// effect: callers see no MXCSR change vs. the scalar path. (See
436+
/// codex review on PR #183.)
437+
///
438+
/// # Safety
439+
/// Caller must have feature-detected `f16c` + `avx` at runtime.
440+
#[cfg(target_arch = "x86_64")]
441+
#[target_feature(enable = "f16c,avx")]
442+
unsafe fn cast_f16_to_f32_batch_f16c(src: &[u16], dst: &mut [f32]) {
443+
use core::arch::asm;
444+
use core::arch::x86_64::{__m128i, _mm256_cvtph_ps, _mm256_storeu_ps, _mm_loadu_si128};
445+
let mut saved_mxcsr: u32 = 0;
446+
// SAFETY: STMXCSR writes the 32-bit MXCSR control/status register
447+
// to the provided memory location; available on any SSE host
448+
// (baseline x86_64).
449+
asm!("stmxcsr [{ptr}]", ptr = in(reg) &mut saved_mxcsr, options(nostack));
450+
let n = src.len().min(dst.len());
451+
let chunks = n / 8;
452+
for c in 0..chunks {
453+
let off = c * 8;
454+
let h = _mm_loadu_si128(src.as_ptr().add(off) as *const __m128i);
455+
let f = _mm256_cvtph_ps(h);
456+
_mm256_storeu_ps(dst.as_mut_ptr().add(off), f);
457+
}
458+
// Scalar tail (0..7 remaining lanes).
459+
for i in (chunks * 8)..n {
460+
dst[i] = F16(src[i]).to_f32();
461+
}
462+
// SAFETY: LDMXCSR reads the value we saved at the top — preserves
463+
// every bit of the original MXCSR (rounding mode, exception masks,
464+
// flush-to-zero etc.), clearing any exception flags the SIMD path
465+
// may have set.
466+
asm!("ldmxcsr [{ptr}]", ptr = in(reg) &saved_mxcsr, options(nostack, readonly));
467+
}
468+
469+
/// F16C-vectorized f32 → F16 batch with IEEE 754 RNE rounding.
470+
///
471+
/// 8 F32 lanes per `_mm256_cvtps_ph::<0>` instruction (one ymm load +
472+
/// one xmm store). The const `IMM8 = 0` selects
473+
/// `_MM_FROUND_TO_NEAREST_INT` — round-to-nearest-even, matches the
474+
/// scalar reference [`F16::from_f32_rounded`] bit-for-bit on every
475+
/// input.
476+
///
477+
/// # IMM8 encoding limit
478+
/// `_mm256_cvtps_ph`'s `IMM8` is 3 bits wide (`static_assert_uimm_bits!
479+
/// (IMM8, 3)` in the Rust stdarch wrapper). Valid values are `0..=3`
480+
/// (the four rounding modes — RNE, down, up, truncate). Bits 2-3 of
481+
/// the underlying VCVTPS2PH IMM8 encoding are "reserved" and "select
482+
/// MXCSR.RM" per Intel SDM — NOT `_MM_FROUND_NO_EXC`, which is an
483+
/// AVX-512 convention (`_mm512_cvtps_ph` accepts `NO_EXC`, F16C does
484+
/// not). Exception suppression is handled at the MXCSR level (below).
485+
///
486+
/// # MXCSR preservation
487+
/// `_mm256_cvtps_ph` may raise `#O` (Overflow), `#U` (Underflow),
488+
/// `#P` (Precision), `#I` (Invalid for SNaN), `#D` (Denormal). The
489+
/// scalar reference [`F16::from_f32_rounded`] is pure bit
490+
/// manipulation and never touches MXCSR. We save/restore MXCSR around
491+
/// the SIMD region so callers see no observable control/status side
492+
/// effects regardless of input data. (See codex review on PR #183.)
493+
///
494+
/// # Safety
495+
/// Caller must have feature-detected `f16c` + `avx` at runtime.
496+
#[cfg(target_arch = "x86_64")]
497+
#[target_feature(enable = "f16c,avx")]
498+
unsafe fn cast_f32_to_f16_batch_f16c(src: &[f32], dst: &mut [u16]) {
499+
use core::arch::asm;
500+
use core::arch::x86_64::{__m128i, _mm256_cvtps_ph, _mm256_loadu_ps, _mm_storeu_si128};
501+
let mut saved_mxcsr: u32 = 0;
502+
// SAFETY: STMXCSR writes the 32-bit MXCSR; baseline SSE op.
503+
asm!("stmxcsr [{ptr}]", ptr = in(reg) &mut saved_mxcsr, options(nostack));
504+
let n = src.len().min(dst.len());
505+
let chunks = n / 8;
506+
for c in 0..chunks {
507+
let off = c * 8;
508+
let f = _mm256_loadu_ps(src.as_ptr().add(off));
509+
let h = _mm256_cvtps_ph::<0>(f);
510+
_mm_storeu_si128(dst.as_mut_ptr().add(off) as *mut __m128i, h);
511+
}
512+
// Scalar tail.
513+
for i in (chunks * 8)..n {
514+
dst[i] = F16::from_f32_rounded(src[i]).0;
515+
}
516+
// SAFETY: LDMXCSR restores the saved value bit-for-bit.
517+
asm!("ldmxcsr [{ptr}]", ptr = in(reg) &saved_mxcsr, options(nostack, readonly));
518+
}
519+
386520
// ============================================================================
387521
// Tests
388522
// ============================================================================
@@ -759,4 +893,81 @@ mod tests {
759893
assert_eq!(dst[i], expected[i], "mul_f16_inplace mismatch at {}", i);
760894
}
761895
}
896+
897+
/// Codex PR #183 P2: F16C `_mm256_cvtps_ph` may raise FP exceptions
/// (#O on overflow, #U on underflow, #P on precision loss, #I on
/// SNaN, #D on denormal input) which set sticky bits in MXCSR. The
/// scalar path is bit manipulation and is expected not to touch MXCSR.
/// The fix: the F16C helpers save MXCSR via STMXCSR before the SIMD
/// region and restore it via LDMXCSR after.
///
/// This test clears the MXCSR exception flags, feeds inputs chosen to
/// provoke each flag, and asserts MXCSR is identical before vs. after
/// the call. Clearing first matters: the flags are sticky, so a flag
/// already set by earlier code would make "before == after" hold even
/// if the helper failed to restore anything.
#[cfg(target_arch = "x86_64")]
#[test]
fn f16c_cast_preserves_mxcsr() {
    // Mirror the dispatcher's gate: the SIMD path is only taken when both
    // features are present.
    if !(std::is_x86_feature_detected!("f16c") && std::is_x86_feature_detected!("avx")) {
        eprintln!("f16c/avx not detected; skipping");
        return;
    }

    /// MXCSR bits 0..=5: the six sticky exception flags.
    const MXCSR_EXCEPTION_FLAGS: u32 = 0x3F;

    /// Reads the current MXCSR value.
    fn read_mxcsr() -> u32 {
        let mut value: u32 = 0;
        // SAFETY: STMXCSR stores MXCSR to the given location; baseline SSE.
        unsafe {
            core::arch::asm!("stmxcsr [{ptr}]", ptr = in(reg) &mut value, options(nostack));
        }
        value
    }

    /// Clears only the sticky exception flags, leaving control bits as-is.
    fn clear_mxcsr_flags() {
        let cleared = read_mxcsr() & !MXCSR_EXCEPTION_FLAGS;
        // SAFETY: LDMXCSR loads a value that differs from the current one
        // only in the status flags; rounding mode and masks are unchanged.
        unsafe {
            core::arch::asm!("ldmxcsr [{ptr}]", ptr = in(reg) &cleared, options(nostack, readonly));
        }
    }

    // Inputs designed to trigger #O / #U / #P / #I / #D in the F16C
    // downcast:
    //   - 1e30, -1e30            : out of F16 range (±65504)        → #O, #P
    //   - 1e-30                  : far below the smallest F16 value → #U, #P
    //   - 1.0/3.0                : not exactly representable        → #P
    //   - bits 0x7F80_0001       : signalling NaN                   → #I
    //   - bits 0x0000_0001       : subnormal f32 operand            → #D
    //   - f32::NAN (quiet), 1.0  : benign lanes padding to 8
    let inputs: Vec<f32> = vec![
        1e30,
        -1e30,
        1e-30,
        1.0 / 3.0,
        f32::from_bits(0x7F80_0001),
        f32::from_bits(0x0000_0001),
        f32::NAN,
        1.0,
    ];
    // Exactly 8 lanes so the SIMD chunk loop fires once with no tail.
    assert_eq!(inputs.len(), 8);
    let mut out = vec![F16::ZERO; 8];

    clear_mxcsr_flags();
    let mxcsr_before = read_mxcsr();
    cast_f32_to_f16_batch(&inputs, &mut out);
    let mxcsr_after = read_mxcsr();

    assert_eq!(
        mxcsr_before, mxcsr_after,
        "cast_f32_to_f16_batch must not modify MXCSR (got 0x{:08x} before, 0x{:08x} after)",
        mxcsr_before, mxcsr_after
    );

    // Same check for the upcast direction: 0x7C01..=0x7C08 have an all-ones
    // exponent, a non-zero mantissa and the quiet bit clear, i.e. they are
    // signalling-NaN F16 encodings.
    let f16_inputs: Vec<F16> = (0..8).map(|i| F16(0x7C01 + i as u16)).collect();
    let mut f32_out = vec![0.0f32; 8];

    clear_mxcsr_flags();
    let mxcsr_before = read_mxcsr();
    cast_f16_to_f32_batch(&f16_inputs, &mut f32_out);
    let mxcsr_after = read_mxcsr();

    assert_eq!(
        mxcsr_before, mxcsr_after,
        "cast_f16_to_f32_batch must not modify MXCSR (got 0x{:08x} before, 0x{:08x} after)",
        mxcsr_before, mxcsr_after
    );
}
762973
}

0 commit comments

Comments
 (0)