|
| 1 | +//! BF16 tile GEMM polyfill — AMX (TDPBF16PS) with AVX-512 F32x16 fallback. |
| 2 | +//! |
| 3 | +//! Same API, runtime tier dispatch via `amx_available()`. The AMX path uses |
| 4 | +//! the raw primitives in `hpc::amx_matmul`. The fallback decodes BF16→f32 |
| 5 | +//! and uses `crate::simd::F32x16` + `mul_add` (VFMADD231PS on AVX-512, |
| 6 | +//! emulated as 2× F32x8 FMA on AVX2). |
| 7 | +//! |
| 8 | +//! Pattern: one dispatch check per call; the caller supplies a preallocated |
| 9 | +//! output and row-major B, which the AMX path VNNI-packs internally. |
| 10 | +//! |
| 11 | +//! Tile shape: M=16, N=16, K = multiple of 32. |
| 12 | +//! |
| 13 | +//! Usage: |
| 14 | +//! ```ignore |
| 15 | +//! use ndarray::hpc::bf16_tile_gemm::bf16_tile_gemm_16x16; |
| 16 | +//! let mut c = vec![0.0f32; 16*16]; |
| 17 | +//! bf16_tile_gemm_16x16(&a_bf16, &b_bf16_row_major, &mut c, k); |
| 18 | +//! ``` |
| 19 | +
|
| 20 | +use crate::hpc::amx_matmul::{ |
| 21 | + amx_available, TileConfig, tile_loadconfig, tile_zero, |
| 22 | + tile_load, tile_store, tile_release, tile_dpbf16ps, vnni_pack_bf16, |
| 23 | +}; |
| 24 | +use crate::simd::{F32x16, bf16_to_f32_batch}; |
| 25 | + |
| 26 | +// ═════════════════════════════════════════════════════════════════════ |
| 27 | +// Public API — safe dispatching wrapper |
| 28 | +// ═════════════════════════════════════════════════════════════════════ |
| 29 | + |
| 30 | +/// Compute C[16, 16] += A[16, K] × B[K, 16] where A, B are BF16 row-major |
| 31 | +/// and C is f32 row-major. K must be a multiple of 32. |
| 32 | +/// |
| 33 | +/// Tier dispatch (runtime): |
| 34 | +/// AMX available → TDPBF16PS tile GEMM (16×16 × K/32 tile iterations) |
| 35 | +/// AMX unavailable → AVX-512 F32x16 FMA fallback (decode BF16→f32, gemm) |
| 36 | +/// |
| 37 | +/// Both paths produce identical results up to BF16 precision (~1/128 per |
| 38 | +/// multiply, O(sqrt(K)) accumulated). |
| 39 | +pub fn bf16_tile_gemm_16x16(a_bf16: &[u16], b_bf16: &[u16], c: &mut [f32], k: usize) { |
| 40 | + assert_eq!(k % 32, 0, "K must be multiple of 32"); |
| 41 | + assert_eq!(a_bf16.len(), 16 * k); |
| 42 | + assert_eq!(b_bf16.len(), k * 16); |
| 43 | + assert_eq!(c.len(), 16 * 16); |
| 44 | + |
| 45 | + if amx_available() { |
| 46 | + // AMX path: pack B into VNNI, call tile GEMM |
| 47 | + let mut b_vnni = vec![0u16; k * 16]; |
| 48 | + vnni_pack_bf16(b_bf16, &mut b_vnni, k, 16); |
| 49 | + // SAFETY: amx_available() just confirmed CPUID + XCR0 + prctl. |
| 50 | + unsafe { amx_path(a_bf16, &b_vnni, c, k); } |
| 51 | + } else { |
| 52 | + fallback_path(a_bf16, b_bf16, c, k); |
| 53 | + } |
| 54 | +} |
| 55 | + |
| 56 | +// ═════════════════════════════════════════════════════════════════════ |
| 57 | +// AMX path (TDPBF16PS) |
| 58 | +// ═════════════════════════════════════════════════════════════════════ |
| 59 | + |
| 60 | +/// AMX tile GEMM. B must be pre-VNNI-packed (see `vnni_pack_bf16`). |
| 61 | +/// # Safety |
| 62 | +/// Caller must have verified `amx_available() == true`. |
| 63 | +#[inline] |
| 64 | +unsafe fn amx_path(a_bf16: &[u16], b_vnni: &[u16], c: &mut [f32], k: usize) { |
| 65 | + // Tile config: shapes at K_bytes=64 match BF16 K=32 case |
| 66 | + let cfg = TileConfig::for_dpbusd(64); |
| 67 | + tile_loadconfig(&cfg); |
| 68 | + tile_zero(0); |
| 69 | + |
| 70 | + // Accumulate over K/32 tile blocks |
| 71 | + let k_blocks = k / 32; |
| 72 | + let a_stride = (k * 2) as usize; // full A row stride in bytes (bf16 = 2B) |
| 73 | + let b_stride = 64usize; // VNNI row stride in bytes |
| 74 | + |
| 75 | + for kb in 0..k_blocks { |
| 76 | + let a_ptr = a_bf16.as_ptr().add(kb * 32) as *const u8; |
| 77 | + let b_ptr = b_vnni.as_ptr().add(kb * 16 * 32) as *const u8; |
| 78 | + tile_load(1, a_ptr, a_stride); |
| 79 | + tile_load(2, b_ptr, b_stride); |
| 80 | + tile_dpbf16ps(); |
| 81 | + } |
| 82 | + |
| 83 | + tile_store(0, c.as_mut_ptr() as *mut u8, 64); |
| 84 | + tile_release(); |
| 85 | +} |
| 86 | + |
| 87 | +// ═════════════════════════════════════════════════════════════════════ |
| 88 | +// AVX-512 fallback (F32x16 + mul_add FMA) |
| 89 | +// ═════════════════════════════════════════════════════════════════════ |
| 90 | + |
| 91 | +/// Fallback: decode BF16→f32 and run a tight F32x16 GEMM with mul_add FMA. |
| 92 | +/// When AVX-512 is the compile-time baseline, this uses native __m512 FMA; |
| 93 | +/// on AVX2 it uses the emulated F32x16 = (F32x8, F32x8) pair — same logic. |
| 94 | +fn fallback_path(a_bf16: &[u16], b_bf16: &[u16], c: &mut [f32], k: usize) { |
| 95 | + // Decode BF16 → f32 (batch via SIMD when avx512bf16 / avx2 available) |
| 96 | + let mut a_f32 = vec![0.0f32; a_bf16.len()]; |
| 97 | + let mut b_f32 = vec![0.0f32; b_bf16.len()]; |
| 98 | + bf16_to_f32_batch(a_bf16, &mut a_f32); |
| 99 | + bf16_to_f32_batch(b_bf16, &mut b_f32); |
| 100 | + |
| 101 | + // Tight GEMM: for each output (i,j), dot row-of-A with col-of-B via F32x16+FMA. |
| 102 | + // B is row-major [K, 16]; j-th column is b_f32[kk*16 + j] over kk=0..K. |
| 103 | + // We gather the column into a stack-sized buffer once per (i,j) pair to hit |
| 104 | + // the chunks_exact(16) + mul_add fast path on contiguous memory. |
| 105 | + for i in 0..16 { |
| 106 | + let a_row = &a_f32[i * k .. i * k + k]; |
| 107 | + for j in 0..16 { |
| 108 | + // Stream the column into a contiguous buffer |
| 109 | + let mut col = vec![0.0f32; k]; |
| 110 | + for kk in 0..k { col[kk] = b_f32[kk * 16 + j]; } |
| 111 | + |
| 112 | + // Accumulate via F32x16::mul_add (FMA) |
| 113 | + let mut acc = F32x16::splat(0.0); |
| 114 | + for (ra, rb) in a_row.chunks_exact(16).zip(col.chunks_exact(16)) { |
| 115 | + let av = F32x16::from_slice(ra); |
| 116 | + let bv = F32x16::from_slice(rb); |
| 117 | + acc = av.mul_add(bv, acc); |
| 118 | + } |
| 119 | + c[i * 16 + j] += acc.reduce_sum(); |
| 120 | + } |
| 121 | + } |
| 122 | +} |
| 123 | + |
| 124 | +// ═════════════════════════════════════════════════════════════════════ |
| 125 | +// Tests |
| 126 | +// ═════════════════════════════════════════════════════════════════════ |
| 127 | + |
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simd::{f32_to_bf16_batch, bf16_to_f32_batch};

    /// Scalar reference GEMM with f32 accumulation: overwrites `c` with
    /// `A[16, k] × B[k, 16]`, all row-major.
    fn ref_gemm(a: &[f32], b: &[f32], c: &mut [f32], k: usize) {
        for i in 0..16 {
            for j in 0..16 {
                let mut s = 0.0f32;
                for kk in 0..k {
                    s += a[i * k + kk] * b[kk * 16 + j];
                }
                c[i * 16 + j] = s;
            }
        }
    }

    /// Fills `out` with a deterministic LCG sequence mapped onto [-1, 1).
    ///
    /// The top 24 bits of each state are used, so every value is exactly
    /// representable in f32 and the inputs span the whole interval.
    fn lcg_fill(seed: u32, out: &mut [f32]) {
        let mut state = seed;
        for v in out.iter_mut() {
            state = state.wrapping_mul(1103515245).wrapping_add(12345);
            *v = (state >> 8) as f32 / 8388608.0 - 1.0;
        }
    }

    #[test]
    fn fallback_matches_scalar_reference_k64() {
        let k = 64;
        // Deterministic pseudo-random inputs covering [-1, 1).
        let mut a_f32 = vec![0.0f32; 16 * k];
        let mut b_f32 = vec![0.0f32; k * 16];
        lcg_fill(1, &mut a_f32);
        lcg_fill(2, &mut b_f32);

        let mut a_bf16 = vec![0u16; a_f32.len()];
        let mut b_bf16 = vec![0u16; b_f32.len()];
        f32_to_bf16_batch(&a_f32, &mut a_bf16);
        f32_to_bf16_batch(&b_f32, &mut b_bf16);

        // Reference uses the BF16-rounded inputs (matches what the GEMM sees).
        let mut a_back = vec![0.0f32; a_f32.len()];
        let mut b_back = vec![0.0f32; b_f32.len()];
        bf16_to_f32_batch(&a_bf16, &mut a_back);
        bf16_to_f32_batch(&b_bf16, &mut b_back);
        let mut c_ref = vec![0.0f32; 16 * 16];
        ref_gemm(&a_back, &b_back, &mut c_ref, k);

        // Guard against a vacuous comparison: the reference must contain
        // values well above the tolerance used below.
        let max_ref = c_ref.iter().fold(0.0f32, |m, v| m.max(v.abs()));
        assert!(max_ref > 0.1, "reference is degenerate, max |c_ref| = {}", max_ref);

        // Fallback GEMM
        let mut c_fb = vec![0.0f32; 16 * 16];
        fallback_path(&a_bf16, &b_bf16, &mut c_fb, k);

        // Same inputs and f32 accumulation; only the summation order differs,
        // so the results must agree to within a small rounding budget.
        let max_err = c_fb
            .iter()
            .zip(c_ref.iter())
            .fold(0.0f32, |m, (got, want)| m.max((got - want).abs()));
        assert!(max_err < 1e-3, "fallback vs scalar ref max_err = {}", max_err);
    }

    #[test]
    fn fallback_accumulates_into_existing_c() {
        // A and B are all 1.0 (BF16 bit pattern 0x3F80), so every dot product
        // is exactly K; C starts at 1.0 and must end at K + 1.
        let k = 32;
        let a = vec![0x3F80u16; 16 * k];
        let b = vec![0x3F80u16; k * 16];
        let mut c = vec![1.0f32; 16 * 16];
        fallback_path(&a, &b, &mut c, k);
        for v in c.iter() {
            assert_eq!(*v, 33.0);
        }
    }

    #[test]
    fn public_api_runs_on_any_hardware() {
        // Just sanity: calling the public API doesn't panic regardless of AMX.
        // On AMX hardware it takes the tile path; otherwise the fallback.
        let k = 32;
        let a = vec![0u16; 16 * k];
        let b = vec![0u16; k * 16];
        let mut c = vec![0.0f32; 16 * 16];
        bf16_tile_gemm_16x16(&a, &b, &mut c, k);
        // All zeros × all zeros = 0
        for v in c.iter() {
            assert_eq!(*v, 0.0);
        }
    }
}
0 commit comments