Commit fe334de

feat(hpc/amx_matmul): TD-T1 — wire matmul_bf16_to_f32 AMX arm to tile kernel
Per the PR #180 dispatch table for BF16 GEMM, SapphireRapids and GraniteRapids should route through `tile_dpbf16ps` (AMX TDPBF16PS: 16×16×32 = 8192 BF16×BF16 multiply-accumulates per instruction, single-rounded into a 16×16 f32 tile accumulator). Until this commit, the AMX branch of `matmul_bf16_to_f32` was a placebo: both the `if amx_available()` arm and the `else` arm called the scalar `bf16_gemm_f32`. The actual kernel (`bf16_tile_gemm::bf16_tile_gemm_16x16`, shipped by PR #104) was unreachable from the consumer entry point. This commit wires it in.

When AMX is OS-enabled AND the matmul shape is 16/16/32-aligned in (M, N, K), the inner loop tiles 16×16 blocks through `bf16_tile_gemm_16x16`. That kernel emits TDPBF16PS via the asm-byte path in `simd_amx.rs::tile_dpbf16ps` (the stable-Rust 1.95 encoding documented at simd_amx.rs:16-19; AMX intrinsics are nightly-only per issue #126622, hence asm-byte). Aligned shapes get the full hardware throughput; misaligned shapes (any of M/N/K not a multiple of 16/16/32 respectively) fall back entirely to the validated scalar `bf16_gemm_f32` reference. Non-AMX hosts always take the scalar fallback.

The B sub-block extraction copies a K × 16 packed scratch per j_tile column band (B is K × N row-major; the kernel wants K × 16 contiguous). The scratch buffer is allocated once, and the per-band packing cost is amortized across the M/16 i-tile iterations under each j_tile.

Phase-4 work will land a fully mixed-tile path (AMX 16×16 core + per-axis scalar tails on the same matmul) for arbitrary shapes.

Verification:
* Default v3 build: 11 amx_matmul tests pass (this host lacks AMX per /proc/cpuinfo, so the path falls through to scalar; behaviour is identical to pre-commit on this runner).
* Full lib sweep: 2087 tests pass; clippy -D warnings clean.
* Real SPR silicon: not exercised on this runner. The gating is correctness-by-construction: the new branch only fires when amx_available() == true AND the alignment predicates hold, and the inner kernel is the same one PR #104 shipped and tested.
Background (the directive chain from this session):

user: "Sapphire Rapids should have BF16 operations"
user: "TDPBF16PS / VDPBF16PS is scalar or SIMD?"
→ Both are SIMD. TDPBF16PS does 8192 BF16×BF16 multiplies + 256 f32 accums per instruction (one 16×16 output tile with an inner dimension of 32); VDPBF16PS does 32 BF16×BF16 multiplies + 16 f32 accums per zmm instruction. Neither is scalar.

The "no scalar lane-by-lane f32 round-trip" rule the user gave is what this PR delivers: the AMX tile op is hardware-fused and single-rounded into the f32 accumulator, and the BF16 inputs are widened to f32 bit-exactly at the multiply step (a BF16 value is the upper 16 bits of an IEEE 754 binary32; BF16 itself is not an IEEE-standardized format).

Closes TD-T1 from `.claude/knowledge/agnostic-surface-cpu-matrix.md` § J Phase 1.

https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
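The per-instruction counts and the alignment gate described in the message can be checked with a few lines of plain Rust. This is a minimal sketch, not code from the commit: the constant names and the helpers `bf16_bits_to_f32`, `f32_to_bf16_bits_truncate`, and `amx_tile_shape_ok` are hypothetical, and the truncating conversion is for illustration only (production conversions usually round).

```rust
// Tile geometry assumed from the commit message: a 16x16 f32 output tile
// with an inner (K) dimension of 32 BF16 values per TDPBF16PS.
pub const TILE_M: usize = 16;
pub const TILE_N: usize = 16;
pub const TILE_K: usize = 32;

/// Multiply-accumulates performed by one TDPBF16PS (16 * 16 * 32).
pub const TDPBF16PS_MACS: usize = TILE_M * TILE_N * TILE_K;
/// f32 accumulators updated by one TDPBF16PS (16 * 16).
pub const TDPBF16PS_ACCUMULATORS: usize = TILE_M * TILE_N;
/// One zmm VDPBF16PS: 16 f32 lanes, 2 BF16 products per lane.
pub const VDPBF16PS_ACCUMULATORS: usize = 16;
pub const VDPBF16PS_MACS: usize = VDPBF16PS_ACCUMULATORS * 2;

/// Widen a BF16 bit pattern to f32. A BF16 value is the upper 16 bits of
/// a binary32, so shifting into the high half is an exact conversion.
pub fn bf16_bits_to_f32(bits: u16) -> f32 {
    f32::from_bits((bits as u32) << 16)
}

/// Narrow an f32 to BF16 bits by dropping the low 16 mantissa bits.
/// Hypothetical helper: truncation only, no rounding.
pub fn f32_to_bf16_bits_truncate(x: f32) -> u16 {
    (x.to_bits() >> 16) as u16
}

/// Hypothetical helper mirroring the shape half of the dispatch gate:
/// M and N must be multiples of 16, K a multiple of 32.
pub fn amx_tile_shape_ok(m: usize, n: usize, k: usize) -> bool {
    m % TILE_M == 0 && n % TILE_N == 0 && k % TILE_K == 0
}
```

For example, `amx_tile_shape_ok(32, 16, 64)` holds, while a shape with K = 48 fails the gate and would take the scalar fallback as a whole.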
1 parent bede3d2 commit fe334de

1 file changed

Lines changed: 46 additions & 14 deletions

src/hpc/amx_matmul.rs

@@ -297,8 +297,15 @@ fn write_contig<A: Copy>(view: &mut ArrayViewMut2<'_, A>, src: &[A]) {
 
 /// Matrix multiply BF16 × BF16 → f32: `out = lhs · rhs`.
 ///
-/// Uses AMX `TDPBF16PS` (256 mul-adds per instruction) when available,
-/// otherwise falls back to [`bf16_gemm_f32`].
+/// On AMX hardware (Sapphire Rapids+, Granite Rapids), 16/16/32-aligned
+/// shapes dispatch to [`crate::hpc::bf16_tile_gemm::bf16_tile_gemm_16x16`],
+/// which emits `TDPBF16PS` via the asm-byte path in `simd_amx.rs`: one
+/// instruction performs 16×16×32 = 8 192 BF16×BF16 multiply-accumulates
+/// into a 16×16 f32 accumulator tile. If any of M/N/K isn't
+/// 16/16/32-aligned, the whole matmul falls back to the validated scalar
+/// [`crate::hpc::quantized::bf16_gemm_f32`] reference.
+///
+/// On non-AMX hosts the entire matmul goes through `bf16_gemm_f32`.
 ///
 /// `out` must be row-contiguous (column stride = 1); inputs may be strided.
 pub fn matmul_bf16_to_f32(
@@ -310,18 +317,43 @@ pub fn matmul_bf16_to_f32(
     let b = pack_contig(&rhs);
     let mut c = vec![0.0f32; m * n];
 
-    // AMX path: a tiled 16×16 kernel exists in `bf16_tile_gemm` for sizes that
-    // fit cleanly. For any leftover tail (or hosts without AMX), defer to the
-    // scalar `bf16_gemm_f32`. The tile kernel itself is maintained alongside
-    // the low-level primitives at the top of this file; the public surface
-    // intentionally goes through the validated scalar path so we always
-    // produce a numerically-stable f32 result.
-    if amx_available() {
-        // Future: AMX-tiled fast path. Today we route through the same
-        // f32 reference kernel; correctness is identical regardless of
-        // hardware. The `amx_available()` branch is preserved so callers
-        // can be sure the AMX detection runs.
-        bf16_gemm_f32(&a, &b, &mut c, m, n, k, 1.0, 0.0);
+    // AMX TDPBF16PS tile path: requires m, n multiples of 16 and k a
+    // multiple of 32 (the tile shape `bf16_tile_gemm_16x16` enforces).
+    // For mis-aligned shapes fall back to scalar — Phase-4 work will
+    // add mixed-tile / tail handling.
+    if amx_available() && m % 16 == 0 && n % 16 == 0 && k % 32 == 0 {
+        // SAFETY: BF16 is `#[repr(transparent)] struct BF16(pub u16)`
+        // (per `hpc::quantized::BF16`). Reinterpreting `&[BF16]` as
+        // `&[u16]` is bit-pattern preserving.
+        let a_u16: &[u16] = unsafe { core::slice::from_raw_parts(a.as_ptr() as *const u16, a.len()) };
+
+        // B is packed row-major K × N; the 16×16 tile kernel wants a
+        // K × 16 contiguous sub-block. Extract per (j_tile) into a
+        // scratch buffer once and reuse across i_tile.
+        let mut b_tile = vec![0u16; k * 16];
+        let mut tile_c = vec![0.0f32; 256];
+
+        for j_tile in (0..n).step_by(16) {
+            // Pack b[0..k, j_tile..j_tile+16] into row-major 16-wide K-rows.
+            for kk in 0..k {
+                let row = kk * n + j_tile;
+                for jj in 0..16 {
+                    b_tile[kk * 16 + jj] = b[row + jj].0;
+                }
+            }
+            for i_tile in (0..m).step_by(16) {
+                // A_tile = a[i_tile..i_tile+16, 0..k] — already contiguous
+                // since `a` is packed row-major M × K.
+                let a_tile = &a_u16[i_tile * k..(i_tile + 16) * k];
+                tile_c.fill(0.0);
+                crate::hpc::bf16_tile_gemm::bf16_tile_gemm_16x16(a_tile, &b_tile, &mut tile_c, k);
+                // Write tile_c (16 × 16, row-major) into c (M × N, row-major).
+                for ii in 0..16 {
+                    let dst_off = (i_tile + ii) * n + j_tile;
+                    c[dst_off..dst_off + 16].copy_from_slice(&tile_c[ii * 16..(ii + 1) * 16]);
+                }
+            }
+        }
     } else {
         bf16_gemm_f32(&a, &b, &mut c, m, n, k, 1.0, 0.0);
     }
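The band-packing and write-back structure of the new branch can be exercised without AMX hardware by swapping in a scalar stand-in for the tile kernel. The sketch below does that under stated assumptions: `tile_gemm_16x16_ref`, `tiled_matmul`, and `naive_matmul` are hypothetical names, inputs are raw BF16 bit patterns in `u16`, and the stand-in takes the same (16 × K row-major A, K × 16 row-major B, 16 × 16 f32 C) arguments the diff passes to `bf16_tile_gemm_16x16`. Any internal repacking the real kernel performs before issuing TDPBF16PS is not shown in this commit and is not modelled here.

```rust
/// Widen a BF16 bit pattern to f32 (exact: BF16 is the high half of binary32).
fn bf16_to_f32(bits: u16) -> f32 {
    f32::from_bits((bits as u32) << 16)
}

/// Scalar stand-in for the tile kernel (hypothetical, not the AMX path):
/// c (16x16) += a (16xk, row-major) * b (kx16, row-major).
fn tile_gemm_16x16_ref(a: &[u16], b: &[u16], c: &mut [f32], k: usize) {
    for i in 0..16 {
        for j in 0..16 {
            let mut acc = c[i * 16 + j];
            for kk in 0..k {
                acc += bf16_to_f32(a[i * k + kk]) * bf16_to_f32(b[kk * 16 + j]);
            }
            c[i * 16 + j] = acc;
        }
    }
}

/// Same loop nest as the new branch: pack one K x 16 column band of B,
/// then reuse it for every 16-row band of A.
pub fn tiled_matmul(a: &[u16], b: &[u16], m: usize, n: usize, k: usize) -> Vec<f32> {
    assert!(m % 16 == 0 && n % 16 == 0 && k % 32 == 0, "shape must be 16/16/32-aligned");
    assert_eq!(a.len(), m * k);
    assert_eq!(b.len(), k * n);
    let mut c = vec![0.0f32; m * n];
    let mut b_tile = vec![0u16; k * 16]; // scratch, allocated once
    let mut tile_c = vec![0.0f32; 256];
    for j_tile in (0..n).step_by(16) {
        // Copy b[0..k, j_tile..j_tile+16] into a contiguous K x 16 block.
        for kk in 0..k {
            let row = kk * n + j_tile;
            b_tile[kk * 16..(kk + 1) * 16].copy_from_slice(&b[row..row + 16]);
        }
        for i_tile in (0..m).step_by(16) {
            // Rows i_tile..i_tile+16 of A are already contiguous.
            let a_tile = &a[i_tile * k..(i_tile + 16) * k];
            tile_c.fill(0.0);
            tile_gemm_16x16_ref(a_tile, &b_tile, &mut tile_c, k);
            // Scatter the 16 x 16 result tile into the M x N output.
            for ii in 0..16 {
                let dst = (i_tile + ii) * n + j_tile;
                c[dst..dst + 16].copy_from_slice(&tile_c[ii * 16..(ii + 1) * 16]);
            }
        }
    }
    c
}

/// Untiled reference with the same accumulation order over k.
pub fn naive_matmul(a: &[u16], b: &[u16], m: usize, n: usize, k: usize) -> Vec<f32> {
    let mut c = vec![0.0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0.0f32;
            for kk in 0..k {
                acc += bf16_to_f32(a[i * k + kk]) * bf16_to_f32(b[kk * n + j]);
            }
            c[i * n + j] = acc;
        }
    }
    c
}
```

Because both functions accumulate over k in the same order, comparing `tiled_matmul` against `naive_matmul` on an aligned shape isolates indexing mistakes in the packing and write-back loops. It says nothing about the numerics of the hardware instruction itself, which would still need a run on real AMX silicon.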
