@@ -297,8 +297,15 @@ fn write_contig<A: Copy>(view: &mut ArrayViewMut2<'_, A>, src: &[A]) {
297297
298298/// Matrix multiply BF16 × BF16 → f32: `out = lhs · rhs`.
299299///
300- /// Uses AMX `TDPBF16PS` (256 mul-adds per instruction) when available,
301- /// otherwise falls back to [`bf16_gemm_f32`].
300+ /// On AMX hardware (Sapphire Rapids+, Granite Rapids), shapes with M and N
301+ /// multiples of 16 and K a multiple of 32 dispatch to
302+ /// [`crate::hpc::bf16_tile_gemm::bf16_tile_gemm_16x16`], which emits
303+ /// `TDPBF16PS` via the asm-byte path in `simd_amx.rs` — 16×16×32 = 8 192
304+ /// BF16×BF16 multiply-accumulates per instruction into a 16×16 f32 tile.
305+ /// If any dimension is misaligned, or the host has no AMX, the whole matmul
306+ /// uses the AVX-512BF16 `VDPBF16PS` kernel when `avx512bf16` is detected,
307+ /// and the scalar [`crate::hpc::quantized::bf16_gemm_f32`] reference
308+ /// otherwise.
302309///
303310/// `out` must be row-contiguous (column stride = 1); inputs may be strided.
304311pub fn matmul_bf16_to_f32 (
@@ -310,26 +317,180 @@ pub fn matmul_bf16_to_f32(
310317 let b = pack_contig ( & rhs) ;
311318 let mut c = vec ! [ 0.0f32 ; m * n] ;
312319
313- // AMX path: a tiled 16×16 kernel exists in `bf16_tile_gemm` for sizes that
314- // fit cleanly. For any leftover tail (or hosts without AMX), defer to the
315- // scalar `bf16_gemm_f32`. The tile kernel itself is maintained alongside
316- // the low-level primitives at the top of this file; the public surface
317- // intentionally goes through the validated scalar path so we always
318- // produce a numerically-stable f32 result.
319- if amx_available ( ) {
320- // Future: AMX-tiled fast path. Today we route through the same
321- // f32 reference kernel; correctness is identical regardless of
322- // hardware. The `amx_available()` branch is preserved so callers
323- // can be sure the AMX detection runs.
324- bf16_gemm_f32 ( & a, & b, & mut c, m, n, k, 1.0 , 0.0 ) ;
325- } else {
326- bf16_gemm_f32 ( & a, & b, & mut c, m, n, k, 1.0 , 0.0 ) ;
327- }
320+ bf16_gemm_dispatch ( & a, & b, & mut c, m, n, k) ;
328321
329322 write_contig ( & mut out, & c) ;
330323 Ok ( ( ) )
331324}
332325
326+ /// BF16 × BF16 → f32 GEMM with three-tier dispatch (AMX → VDPBF16PS → scalar).
327+ ///
328+ /// Inputs are packed row-major (`a` is M × K, `b` is K × N). Output `c`
329+ /// is M × N row-major and is overwritten (not accumulated).
330+ ///
331+ /// Tier selection:
332+ ///
333+ /// 1. **AMX `TDPBF16PS`** (Sapphire Rapids+, Granite Rapids) when
334+ /// `amx_available()` is true AND shapes are 16/16/32-aligned.
335+ /// Dispatches through
336+ /// [`crate::hpc::bf16_tile_gemm::bf16_tile_gemm_16x16`] →
337+ /// `simd_amx::tile_dpbf16ps` via asm-byte (`TDPBF16PS` intrinsic is
338+ /// nightly-only on Rust 1.95). 8 192 BF16×BF16 multiplies + 256 f32
339+ /// accumulates per instruction.
340+ /// 2. **`VDPBF16PS`** (Cooper Lake, Cascade Lake AVX-512BF16, Zen 4+)
341+ /// when `is_x86_feature_detected!("avx512bf16")` is true. The
342+ /// intrinsic `_mm512_dpbf16_ps` is stable on Rust 1.95 (no asm-byte
343+ /// needed). Per instruction: 32 BF16×BF16 multiplies + 16 f32
344+ /// accumulates, single-rounded. Handles arbitrary shapes — M / N
345+ /// tails fall through the per-iteration j-block trimming; K-tail
346+ /// (odd K) is handled with a final scalar pair.
347+ /// 3. **Scalar reference** [`bf16_gemm_f32`] for hosts without either
348+ /// extension or for shapes the AMX arm rejects.
349+ ///
350+ /// The per-tier dispatch table comes from PR #180's BF16 GEMM column.
351+ fn bf16_gemm_dispatch ( a : & [ BF16 ] , b : & [ BF16 ] , c : & mut [ f32 ] , m : usize , n : usize , k : usize ) {
352+ if amx_available ( ) && m % 16 == 0 && n % 16 == 0 && k % 32 == 0 {
353+ // SAFETY: BF16 is `#[repr(transparent)] struct BF16(pub u16)`
354+ // (per `hpc::quantized::BF16`). Reinterpreting `&[BF16]` as
355+ // `&[u16]` is bit-pattern preserving.
356+ let a_u16: & [ u16 ] = unsafe { core:: slice:: from_raw_parts ( a. as_ptr ( ) as * const u16 , a. len ( ) ) } ;
357+
358+ // B is packed row-major K × N; the 16×16 tile kernel wants a
359+ // K × 16 contiguous sub-block. Extract per (j_tile) into a
360+ // scratch buffer once and reuse across i_tile.
361+ let mut b_tile = vec ! [ 0u16 ; k * 16 ] ;
362+ let mut tile_c = vec ! [ 0.0f32 ; 256 ] ;
363+
364+ for j_tile in ( 0 ..n) . step_by ( 16 ) {
365+ // Pack b[0..k, j_tile..j_tile+16] into row-major 16-wide K-rows.
366+ for kk in 0 ..k {
367+ let row = kk * n + j_tile;
368+ for jj in 0 ..16 {
369+ b_tile[ kk * 16 + jj] = b[ row + jj] . 0 ;
370+ }
371+ }
372+ for i_tile in ( 0 ..m) . step_by ( 16 ) {
373+ // A_tile = a[i_tile..i_tile+16, 0..k] — already contiguous
374+ // since `a` is packed row-major M × K.
375+ let a_tile = & a_u16[ i_tile * k..( i_tile + 16 ) * k] ;
376+ tile_c. fill ( 0.0 ) ;
377+ crate :: hpc:: bf16_tile_gemm:: bf16_tile_gemm_16x16 ( a_tile, & b_tile, & mut tile_c, k) ;
378+ // Write tile_c (16 × 16, row-major) into c (M × N, row-major).
379+ for ii in 0 ..16 {
380+ let dst_off = ( i_tile + ii) * n + j_tile;
381+ c[ dst_off..dst_off + 16 ] . copy_from_slice ( & tile_c[ ii * 16 ..( ii + 1 ) * 16 ] ) ;
382+ }
383+ }
384+ }
385+ return ;
386+ }
387+
388+ #[ cfg( target_arch = "x86_64" ) ]
389+ {
390+ if std:: is_x86_feature_detected!( "avx512bf16" ) {
391+ // SAFETY: feature-detected at runtime; the kernel is
392+ // `#[target_feature(enable = "avx512bf16,avx512f")]`.
393+ unsafe {
394+ bf16_gemm_vdpbf16ps ( a, b, c, m, n, k) ;
395+ }
396+ return ;
397+ }
398+ }
399+
400+ bf16_gemm_f32 ( a, b, c, m, n, k, 1.0 , 0.0 ) ;
401+ }
402+
403+ /// AVX-512BF16 BF16 GEMM using `_mm512_dpbf16_ps` (`VDPBF16PS`).
404+ ///
405+ /// One VDPBF16PS instruction: 16 f32 accumulator lanes each receive
406+ /// `acc[j] += a.bf16[2j] * b.bf16[2j] + a.bf16[2j+1] * b.bf16[2j+1]`,
407+ /// single-rounded. The kernel maps the 16 output lanes to a row of 16
408+ /// j-columns of C[i, ·], with one i row processed at a time and a K-pair
409+ /// inner loop accumulating into the same 16 f32 lanes across iterations.
410+ ///
411+ /// B-column packing: VDPBF16PS wants the 32 B BF16s per call laid out
412+ /// as 16 lane-pairs (lane j contains `B[2k_pair, j_base+j]` followed by
413+ /// `B[2k_pair+1, j_base+j]`, packed into one u32). We pre-pack B for
414+ /// the current j-block into `b_col_pairs[k_pair * 16 + j] = u32` once
415+ /// per j_block and reuse across all i — amortizes the gather cost.
416+ ///
417+ /// K-tail (when K is odd) is handled with a final scalar BF16 multiply
418+ /// per output cell; N-tail (when the j-block has < 16 valid columns)
419+ /// is handled by trimming the store after the VDPBF16PS chain.
420+ ///
421+ /// # Safety
422+ /// Caller must have feature-detected `avx512bf16` at runtime.
423+ #[ cfg( target_arch = "x86_64" ) ]
424+ #[ target_feature( enable = "avx512bf16,avx512f" ) ]
425+ unsafe fn bf16_gemm_vdpbf16ps ( a : & [ BF16 ] , b : & [ BF16 ] , c : & mut [ f32 ] , m : usize , n : usize , k : usize ) {
426+ use core:: arch:: x86_64:: {
427+ __m512bh, __m512i, _mm512_dpbf16_ps, _mm512_loadu_si512, _mm512_set1_epi32, _mm512_setzero_ps, _mm512_storeu_ps,
428+ } ;
429+
430+ let k_pairs = k / 2 ;
431+ let k_tail = k % 2 ;
432+
433+ // SAFETY: BF16 is repr(transparent) over u16.
434+ let a_u16: & [ u16 ] = core:: slice:: from_raw_parts ( a. as_ptr ( ) as * const u16 , a. len ( ) ) ;
435+ let b_u16: & [ u16 ] = core:: slice:: from_raw_parts ( b. as_ptr ( ) as * const u16 , b. len ( ) ) ;
436+
437+ // Pre-pack scratch: 16 u32 lanes per k_pair, holding (b_lo | b_hi << 16).
438+ let mut b_col_pairs = vec ! [ 0u32 ; k_pairs. max( 1 ) * 16 ] ;
439+ // Scratch for the 16-wide store + N-tail trim.
440+ let mut out_buf = [ 0.0f32 ; 16 ] ;
441+
442+ for j_base in ( 0 ..n) . step_by ( 16 ) {
443+ let j_count = 16 . min ( n - j_base) ;
444+
445+ // Pack B columns [j_base..j_base+j_count] in pair-interleaved layout.
446+ // For lanes j >= j_count (the N-tail of this j_block), pad with 0 —
447+ // they're not stored back, but the VDPBF16PS still touches them.
448+ for k_pair in 0 ..k_pairs {
449+ let row_lo = 2 * k_pair * n;
450+ let row_hi = ( 2 * k_pair + 1 ) * n;
451+ for jj in 0 ..j_count {
452+ let b_lo = b_u16[ row_lo + j_base + jj] as u32 ;
453+ let b_hi = b_u16[ row_hi + j_base + jj] as u32 ;
454+ b_col_pairs[ k_pair * 16 + jj] = ( b_hi << 16 ) | b_lo;
455+ }
456+ for jj in j_count..16 {
457+ b_col_pairs[ k_pair * 16 + jj] = 0 ;
458+ }
459+ }
460+
461+ for i in 0 ..m {
462+ let mut acc = _mm512_setzero_ps ( ) ;
463+ let a_row_off = i * k;
464+ for k_pair in 0 ..k_pairs {
465+ // Broadcast A[i, 2k_pair..2k_pair+2] as the (BF16 lo, BF16 hi)
466+ // pair across all 16 lanes.
467+ let a_lo = a_u16[ a_row_off + 2 * k_pair] as u32 ;
468+ let a_hi = a_u16[ a_row_off + 2 * k_pair + 1 ] as u32 ;
469+ let pair = ( a_hi << 16 ) | a_lo;
470+ let a_bh: __m512bh = core:: mem:: transmute ( _mm512_set1_epi32 ( pair as i32 ) ) ;
471+ let b_bh: __m512bh =
472+ core:: mem:: transmute ( _mm512_loadu_si512 ( b_col_pairs. as_ptr ( ) . add ( k_pair * 16 ) as * const __m512i ) ) ;
473+ acc = _mm512_dpbf16_ps ( acc, a_bh, b_bh) ;
474+ }
475+ _mm512_storeu_ps ( out_buf. as_mut_ptr ( ) , acc) ;
476+
477+ // K-tail: one extra scalar BF16 multiply for k = k_pairs*2.
478+ if k_tail == 1 {
479+ let a_last_f32 = BF16 ( a_u16[ a_row_off + k - 1 ] ) . to_f32 ( ) ;
480+ let tail_row = ( k - 1 ) * n;
481+ for jj in 0 ..j_count {
482+ let b_last_f32 = BF16 ( b_u16[ tail_row + j_base + jj] ) . to_f32 ( ) ;
483+ out_buf[ jj] += a_last_f32 * b_last_f32;
484+ }
485+ }
486+
487+ // Store the j_count valid lanes (drops N-tail padding lanes).
488+ let dst_off = i * n + j_base;
489+ c[ dst_off..dst_off + j_count] . copy_from_slice ( & out_buf[ ..j_count] ) ;
490+ }
491+ }
492+ }
493+
333494// ── f32 → f32 (BF16 compute on AMX) ────────────────────────────────────────
334495
335496/// Matrix multiply f32 × f32 → f32: `out = lhs · rhs`.
@@ -349,10 +510,13 @@ pub fn matmul_f32(
349510 let mut c = vec ! [ 0.0f32 ; m * n] ;
350511
351512 if amx_available ( ) {
352- // AMX path: down-cast to BF16, run BF16 GEMM, accumulate in f32.
513+ // AMX path: down-cast to BF16 (RNE, ~1 ULP at BF16 mantissa
514+ // precision), then dispatch through the shared BF16 helper
515+ // which picks `TDPBF16PS` tile kernel for 16/16/32-aligned
516+ // shapes and the scalar `bf16_gemm_f32` reference otherwise.
353517 let a_bf16: Vec < BF16 > = a_f32. iter ( ) . map ( |& v| BF16 :: from_f32_rounded ( v) ) . collect ( ) ;
354518 let b_bf16: Vec < BF16 > = b_f32. iter ( ) . map ( |& v| BF16 :: from_f32_rounded ( v) ) . collect ( ) ;
355- bf16_gemm_f32 ( & a_bf16, & b_bf16, & mut c, m, n, k, 1.0 , 0.0 ) ;
519+ bf16_gemm_dispatch ( & a_bf16, & b_bf16, & mut c, m, n, k) ;
356520 } else {
357521 // Pure f32 reference path.
358522 for i in 0 ..m {
0 commit comments