@@ -351,18 +351,31 @@ pub fn cast_bf16_to_f32_batch(src: &[BF16], dst: &mut [f32]) {
351351
352352/// Batch convert F16 → f32.
353353///
354- /// Uses F16x16 for chunks of 16, scalar tail for remainder.
354+ /// On x86_64 with F16C (every CPU from Ivy Bridge 2013 / Piledriver 2012
355+ /// onward), dispatches to `_mm256_cvtph_ps` — one hardware instruction
356+ /// converts 8 F16 lanes to 8 F32 lanes, IEEE-754 exact. The scalar
357+ /// fallback uses the bit-fiddle [`F16::to_f32`] which is also IEEE-754
358+ /// exact, just slower.
355359pub fn cast_f16_to_f32_batch ( src : & [ F16 ] , dst : & mut [ f32 ] ) {
356360 let n = src. len ( ) . min ( dst. len ( ) ) ;
357- let chunks = n / 16 ;
358- for c in 0 ..chunks {
359- let off = c * 16 ;
360- let v = F16x16 :: from_slice ( & src[ off..] ) ;
361- let f = v. to_f32x16 ( ) ;
362- dst[ off..off + 16 ] . copy_from_slice ( & f) ;
361+
362+ #[ cfg( target_arch = "x86_64" ) ]
363+ {
364+ if std:: is_x86_feature_detected!( "f16c" ) && std:: is_x86_feature_detected!( "avx" ) {
365+ // SAFETY: `F16` is `#[repr(transparent)] struct F16(pub u16)`
366+ // (per `hpc::quantized::F16`). Slice reinterpretation is
367+ // bit-pattern preserving. Runtime feature detection above
368+ // confirms F16C + AVX before calling the target-feature fn.
369+ let src_u16: & [ u16 ] = unsafe { core:: slice:: from_raw_parts ( src. as_ptr ( ) as * const u16 , src. len ( ) ) } ;
370+ unsafe {
371+ cast_f16_to_f32_batch_f16c ( & src_u16[ ..n] , & mut dst[ ..n] ) ;
372+ }
373+ return ;
374+ }
363375 }
364- // Scalar tail
365- for i in ( chunks * 16 ) ..n {
376+
377+ // Scalar fallback (non-x86_64 or pre-F16C silicon).
378+ for i in 0 ..n {
366379 dst[ i] = src[ i] . to_f32 ( ) ;
367380 }
368381}
@@ -376,13 +389,134 @@ pub fn cast_f32_to_bf16_batch(src: &[f32], dst: &mut [BF16]) {
376389}
377390
378391/// Batch convert f32 → F16 (round-to-nearest-even).
392+ ///
393+ /// On x86_64 with F16C, dispatches to `_mm256_cvtps_ph::<8>` (RNE,
394+ /// no exceptions) — one hardware instruction converts 8 F32 lanes to
395+ /// 8 F16 lanes with IEEE 754 round-to-nearest-even. Scalar fallback
396+ /// uses [`F16::from_f32_rounded`] which matches the IEEE 754 RNE rule
397+ /// bit-for-bit on every input (including subnormal / NaN / Inf).
379398pub fn cast_f32_to_f16_batch ( src : & [ f32 ] , dst : & mut [ F16 ] ) {
380399 let n = src. len ( ) . min ( dst. len ( ) ) ;
400+
401+ #[ cfg( target_arch = "x86_64" ) ]
402+ {
403+ if std:: is_x86_feature_detected!( "f16c" ) && std:: is_x86_feature_detected!( "avx" ) {
404+ // SAFETY: same as cast_f16_to_f32_batch — `F16` is
405+ // repr(transparent) over u16; runtime feature gate ensures
406+ // F16C is present.
407+ let dst_u16: & mut [ u16 ] =
408+ unsafe { core:: slice:: from_raw_parts_mut ( dst. as_mut_ptr ( ) as * mut u16 , dst. len ( ) ) } ;
409+ unsafe {
410+ cast_f32_to_f16_batch_f16c ( & src[ ..n] , & mut dst_u16[ ..n] ) ;
411+ }
412+ return ;
413+ }
414+ }
415+
381416 for i in 0 ..n {
382417 dst[ i] = F16 :: from_f32_rounded ( src[ i] ) ;
383418 }
384419}
385420
421+ /// F16C-vectorized F16 → f32 batch.
422+ ///
423+ /// 8 F16 lanes per `_mm256_cvtph_ps` instruction (one xmm load + one
424+ /// ymm store). Scalar tail handles the remaining `n % 8` lanes via the
425+ /// bit-fiddle reference. **F16C result is bit-identical to the scalar
426+ /// reference per IEEE 754 binary16 → binary32 spec** (lossless widening,
427+ /// no rounding possible).
428+ ///
429+ /// # MXCSR preservation
430+ /// `_mm256_cvtph_ps` may raise `#I` (Invalid: SNaN input) or `#D`
431+ /// (Denormal) — setting bits in MXCSR that the scalar bit-fiddle
432+ /// reference [`F16::to_f32`] does not touch. To preserve the scalar
433+ /// path's contract of "no observable FP control/status side effects,"
434+ /// the MXCSR is saved before the SIMD region and restored after. Net
435+ /// effect: callers see no MXCSR change vs. the scalar path. (See
436+ /// codex review on PR #183.)
437+ ///
438+ /// # Safety
439+ /// Caller must have feature-detected `f16c` + `avx` at runtime.
440+ #[ cfg( target_arch = "x86_64" ) ]
441+ #[ target_feature( enable = "f16c,avx" ) ]
442+ unsafe fn cast_f16_to_f32_batch_f16c ( src : & [ u16 ] , dst : & mut [ f32 ] ) {
443+ use core:: arch:: asm;
444+ use core:: arch:: x86_64:: { __m128i, _mm256_cvtph_ps, _mm256_storeu_ps, _mm_loadu_si128} ;
445+ let mut saved_mxcsr: u32 = 0 ;
446+ // SAFETY: STMXCSR writes the 32-bit MXCSR control/status register
447+ // to the provided memory location; available on any SSE host
448+ // (baseline x86_64).
449+ asm ! ( "stmxcsr [{ptr}]" , ptr = in( reg) & mut saved_mxcsr, options( nostack) ) ;
450+ let n = src. len ( ) . min ( dst. len ( ) ) ;
451+ let chunks = n / 8 ;
452+ for c in 0 ..chunks {
453+ let off = c * 8 ;
454+ let h = _mm_loadu_si128 ( src. as_ptr ( ) . add ( off) as * const __m128i ) ;
455+ let f = _mm256_cvtph_ps ( h) ;
456+ _mm256_storeu_ps ( dst. as_mut_ptr ( ) . add ( off) , f) ;
457+ }
458+ // Scalar tail (0..7 remaining lanes).
459+ for i in ( chunks * 8 ) ..n {
460+ dst[ i] = F16 ( src[ i] ) . to_f32 ( ) ;
461+ }
462+ // SAFETY: LDMXCSR reads the value we saved at the top — preserves
463+ // every bit of the original MXCSR (rounding mode, exception masks,
464+ // flush-to-zero etc.), clearing any exception flags the SIMD path
465+ // may have set.
466+ asm ! ( "ldmxcsr [{ptr}]" , ptr = in( reg) & saved_mxcsr, options( nostack, readonly) ) ;
467+ }
468+
469+ /// F16C-vectorized f32 → F16 batch with IEEE 754 RNE rounding.
470+ ///
471+ /// 8 F32 lanes per `_mm256_cvtps_ph::<0>` instruction (one ymm load +
472+ /// one xmm store). The const `IMM8 = 0` selects
473+ /// `_MM_FROUND_TO_NEAREST_INT` — round-to-nearest-even, matches the
474+ /// scalar reference [`F16::from_f32_rounded`] bit-for-bit on every
475+ /// input.
476+ ///
477+ /// # IMM8 encoding limit
478+ /// `_mm256_cvtps_ph`'s `IMM8` is 3 bits wide (`static_assert_uimm_bits!
479+ /// (IMM8, 3)` in the Rust stdarch wrapper). Valid values are `0..=3`
480+ /// (the four rounding modes — RNE, down, up, truncate). Bits 2-3 of
481+ /// the underlying VCVTPS2PH IMM8 encoding are "reserved" and "select
482+ /// MXCSR.RM" per Intel SDM — NOT `_MM_FROUND_NO_EXC`, which is an
483+ /// AVX-512 convention (`_mm512_cvtps_ph` accepts `NO_EXC`, F16C does
484+ /// not). Exception suppression is handled at the MXCSR level (below).
485+ ///
486+ /// # MXCSR preservation
487+ /// `_mm256_cvtps_ph` may raise `#O` (Overflow), `#U` (Underflow),
488+ /// `#P` (Precision), `#I` (Invalid for SNaN), `#D` (Denormal). The
489+ /// scalar reference [`F16::from_f32_rounded`] is pure bit
490+ /// manipulation and never touches MXCSR. We save/restore MXCSR around
491+ /// the SIMD region so callers see no observable control/status side
492+ /// effects regardless of input data. (See codex review on PR #183.)
493+ ///
494+ /// # Safety
495+ /// Caller must have feature-detected `f16c` + `avx` at runtime.
496+ #[ cfg( target_arch = "x86_64" ) ]
497+ #[ target_feature( enable = "f16c,avx" ) ]
498+ unsafe fn cast_f32_to_f16_batch_f16c ( src : & [ f32 ] , dst : & mut [ u16 ] ) {
499+ use core:: arch:: asm;
500+ use core:: arch:: x86_64:: { __m128i, _mm256_cvtps_ph, _mm256_loadu_ps, _mm_storeu_si128} ;
501+ let mut saved_mxcsr: u32 = 0 ;
502+ // SAFETY: STMXCSR writes the 32-bit MXCSR; baseline SSE op.
503+ asm ! ( "stmxcsr [{ptr}]" , ptr = in( reg) & mut saved_mxcsr, options( nostack) ) ;
504+ let n = src. len ( ) . min ( dst. len ( ) ) ;
505+ let chunks = n / 8 ;
506+ for c in 0 ..chunks {
507+ let off = c * 8 ;
508+ let f = _mm256_loadu_ps ( src. as_ptr ( ) . add ( off) ) ;
509+ let h = _mm256_cvtps_ph :: < 0 > ( f) ;
510+ _mm_storeu_si128 ( dst. as_mut_ptr ( ) . add ( off) as * mut __m128i , h) ;
511+ }
512+ // Scalar tail.
513+ for i in ( chunks * 8 ) ..n {
514+ dst[ i] = F16 :: from_f32_rounded ( src[ i] ) . 0 ;
515+ }
516+ // SAFETY: LDMXCSR restores the saved value bit-for-bit.
517+ asm ! ( "ldmxcsr [{ptr}]" , ptr = in( reg) & saved_mxcsr, options( nostack, readonly) ) ;
518+ }
519+
386520// ============================================================================
387521// Tests
388522// ============================================================================
@@ -759,4 +893,81 @@ mod tests {
759893 assert_eq ! ( dst[ i] , expected[ i] , "mul_f16_inplace mismatch at {}" , i) ;
760894 }
761895 }
896+
897+ /// Codex PR #183 P2: F16C `_mm256_cvtps_ph` may raise FP exceptions
898+ /// (#O on overflow, #U on underflow, #P on precision loss, #I on
899+ /// SNaN, #D on denormal input) which set bits in MXCSR. The scalar
900+ /// path is pure bit manipulation and never touches MXCSR. The fix:
901+ /// `cast_f32_to_f16_batch_f16c` saves MXCSR via STMXCSR before the
902+ /// SIMD region and restores it via LDMXCSR after. This test feeds
903+ /// inputs that should trigger every exception bit and asserts
904+ /// MXCSR is byte-identical before vs. after the call.
905+ #[ cfg( target_arch = "x86_64" ) ]
906+ #[ test]
907+ fn f16c_cast_preserves_mxcsr ( ) {
908+ if !std:: is_x86_feature_detected!( "f16c" ) {
909+ eprintln ! ( "f16c not detected; skipping" ) ;
910+ return ;
911+ }
912+ use core:: arch:: asm;
913+
914+ // Inputs designed to trigger #O / #U / #P / #I / #D in F16C
915+ // downcast:
916+ // - 1e30, -1e30 : overflow (out of F16 range ±65504) → #O
917+ // - 1e-30 : underflow / denormal → #U, #D, #P
918+ // - 1.0/3.0 : precision loss → #P
919+ // - f32::NAN : invalid (if it's an sNaN representation) → #I
920+ let inputs: Vec < f32 > = vec ! [
921+ 1e30 ,
922+ -1e30 ,
923+ 1e-30 ,
924+ 1.0 / 3.0 ,
925+ f32 :: NAN ,
926+ f32 :: INFINITY ,
927+ 0.0 ,
928+ 1.0 ,
929+ // Pad to 8 lanes so the SIMD chunk loop fires once with no tail.
930+ ] ;
931+ assert_eq ! ( inputs. len( ) , 8 ) ;
932+ let mut out = vec ! [ F16 :: ZERO ; 8 ] ;
933+
934+ // Snapshot MXCSR before.
935+ let mut mxcsr_before: u32 = 0 ;
936+ unsafe {
937+ asm ! ( "stmxcsr [{ptr}]" , ptr = in( reg) & mut mxcsr_before, options( nostack) ) ;
938+ }
939+
940+ cast_f32_to_f16_batch ( & inputs, & mut out) ;
941+
942+ // Snapshot MXCSR after.
943+ let mut mxcsr_after: u32 = 0 ;
944+ unsafe {
945+ asm ! ( "stmxcsr [{ptr}]" , ptr = in( reg) & mut mxcsr_after, options( nostack) ) ;
946+ }
947+
948+ assert_eq ! (
949+ mxcsr_before, mxcsr_after,
950+ "cast_f32_to_f16_batch must not modify MXCSR (got 0x{:08x} before, 0x{:08x} after)" ,
951+ mxcsr_before, mxcsr_after
952+ ) ;
953+
954+ // Same check for the upcast direction (`_mm256_cvtph_ps` can raise
955+ // #I/#D on SNaN/denormal F16 input).
956+ let f16_inputs: Vec < F16 > = ( 0 ..8 ) . map ( |i| F16 ( 0x7C01 + i as u16 ) ) . collect ( ) ; // SNaN-ish
957+ let mut f32_out = vec ! [ 0.0f32 ; 8 ] ;
958+
959+ unsafe {
960+ asm ! ( "stmxcsr [{ptr}]" , ptr = in( reg) & mut mxcsr_before, options( nostack) ) ;
961+ }
962+ cast_f16_to_f32_batch ( & f16_inputs, & mut f32_out) ;
963+ unsafe {
964+ asm ! ( "stmxcsr [{ptr}]" , ptr = in( reg) & mut mxcsr_after, options( nostack) ) ;
965+ }
966+
967+ assert_eq ! (
968+ mxcsr_before, mxcsr_after,
969+ "cast_f16_to_f32_batch must not modify MXCSR (got 0x{:08x} before, 0x{:08x} after)" ,
970+ mxcsr_before, mxcsr_after
971+ ) ;
972+ }
762973}
0 commit comments