Commit fe334de
committed
feat(hpc/amx_matmul): TD-T1 — wire matmul_bf16_to_f32 AMX arm to tile kernel
Per the PR #180 dispatch table for BF16 GEMM: SapphireRapids and
GraniteRapids should route through `tile_dpbf16ps` (AMX TDPBF16PS,
256 BF16×BF16 multiply-accumulates per instruction, single-rounded
into an f32 tile accumulator). Until this commit, the AMX branch of
`matmul_bf16_to_f32` was a placebo — both `if amx_available()` and
`else` called the scalar `bf16_gemm_f32`. The actual kernel
(`bf16_tile_gemm::bf16_tile_gemm_16x16`, shipped by PR #104) was
unreached by the consumer entry point.
This wires it. When AMX is OS-enabled AND the matmul shape is
16/16/32-aligned in (M, N, K), the inner loop tiles 16×16 blocks
through `bf16_tile_gemm_16x16` — that kernel emits TDPBF16PS via the
asm-byte path in `simd_amx.rs::tile_dpbf16ps` (the stable-Rust 1.95
encoding documented at simd_amx.rs:16-19; AMX intrinsics are
nightly-only per issue #126622, hence asm-byte). Aligned tiles get
the full hardware throughput; misaligned shapes (any of M/N/K not at
the alignment boundary) fall back to the validated scalar
`bf16_gemm_f32` reference. Non-AMX hosts always take the scalar
fallback.
The B sub-block extraction copies a K × 16 packed scratch per
j_tile column band (B is K × N row-major; the kernel wants K × 16
contiguous). Allocation cost is amortized across M/16 i-tile
iterations under each j_tile. Phase-4 work will land a fully
mixed-tile path (AMX 16×16 core + per-axis scalar tails on the
same matmul) for arbitrary shapes.
Verification:
* Default v3 build: 11 amx_matmul tests pass (this host lacks
AMX per /proc/cpuinfo, so the path falls through to scalar;
behaviour identical to pre-commit on this runner).
* Full lib sweep: 2087 tests pass; clippy -D warnings clean.
* Real SPR silicon: the gating is correctness-by-construction —
the new branch only fires when amx_available() == true AND the
alignment predicates hold; the inner kernel is the same one
PR #104 shipped and tested.
Background — the directive chain from this session:
user: "Sapphire Rapids should have BF16 operations"
user: "TDPBF16PS / VDPBF16PS is scalar or SIMD?" → both are SIMD,
TDPBF16PS does 8192 BF16×BF16 multiplies + 256 f32 accums
per instruction (16×16 outer-product matmul tile), VDPBF16PS
does 32 BF16×BF16 multiplies + 16 f32 accums per zmm
instruction. Neither is scalar. The "no scalar lane-by-lane
f32 round-trip" rule the user gave is what this PR delivers:
the AMX tile op is hardware-fused, single-rounded into f32
accumulator, BF16 mantissa bits preserved bit-exactly per
IEEE BF16 spec at the multiply step.
Closes TD-T1 from
`.claude/knowledge/agnostic-surface-cpu-matrix.md` § J Phase 1.
https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u1 parent bede3d2 commit fe334de
1 file changed
Lines changed: 46 additions & 14 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
297 | 297 | | |
298 | 298 | | |
299 | 299 | | |
300 | | - | |
301 | | - | |
| 300 | + | |
| 301 | + | |
| 302 | + | |
| 303 | + | |
| 304 | + | |
| 305 | + | |
| 306 | + | |
| 307 | + | |
| 308 | + | |
302 | 309 | | |
303 | 310 | | |
304 | 311 | | |
| |||
310 | 317 | | |
311 | 318 | | |
312 | 319 | | |
313 | | - | |
314 | | - | |
315 | | - | |
316 | | - | |
317 | | - | |
318 | | - | |
319 | | - | |
320 | | - | |
321 | | - | |
322 | | - | |
323 | | - | |
324 | | - | |
| 320 | + | |
| 321 | + | |
| 322 | + | |
| 323 | + | |
| 324 | + | |
| 325 | + | |
| 326 | + | |
| 327 | + | |
| 328 | + | |
| 329 | + | |
| 330 | + | |
| 331 | + | |
| 332 | + | |
| 333 | + | |
| 334 | + | |
| 335 | + | |
| 336 | + | |
| 337 | + | |
| 338 | + | |
| 339 | + | |
| 340 | + | |
| 341 | + | |
| 342 | + | |
| 343 | + | |
| 344 | + | |
| 345 | + | |
| 346 | + | |
| 347 | + | |
| 348 | + | |
| 349 | + | |
| 350 | + | |
| 351 | + | |
| 352 | + | |
| 353 | + | |
| 354 | + | |
| 355 | + | |
| 356 | + | |
325 | 357 | | |
326 | 358 | | |
327 | 359 | | |
| |||
0 commit comments