Add BF16 tile GEMM with AMX/AVX-512 dispatch#104
Merged
Conversation
Previously: no global target-cpu, per-function #[target_feature] + runtime dispatch so "one binary runs on AVX2 and AVX-512 machines". Now: compile-time AVX-512 baseline via rustflags target-cpu=x86-64-v4 (AVX-512F, AVX-512BW, AVX-512CD, AVX-512DQ, AVX-512VL). The v4 level does not include AMX; AMX (amx-tile / amx-int8 / amx-bf16) remains per-function #[target_feature] with runtime amx_available() gating (CPUID + _xgetbv(0) bits 17/18 + prctl ARCH_REQ_XCOMP_PERM for Linux 5.19+ tile state). AVX2 is used as the CI-only fallback path; local and Railway builds pin v4 so AVX-512 lanes (F32x16, kernels_avx512) light up at compile time. https://claude.ai/code/session_01NYGrxVopyszZYgLBxe4hgj
This reverts commit d7731ba.
Additive only — no existing symbols modified.
hpc/amx_matmul.rs (ADD):
pub unsafe fn tile_dpbf16ps()
TDPBF16PS tmm0, tmm1, tmm2 via stable inline asm .byte encoding
(C4 E2 72 5C C1 — same binary-trick pattern as existing
tile_dpbusd/tile_zero/tile_release for pre-nightly AMX on Rust 1.94)
pub fn vnni_pack_bf16(src, dst, k, n)
Pack K×N row-major bf16 → K/2 × (N*2) VNNI pairs for TDPBF16PS B tile
hpc/bf16_tile_gemm.rs (NEW module, additive):
pub fn bf16_tile_gemm_16x16(a_bf16, b_bf16, c_f32, k)
Same API, runtime tier dispatch:
amx_available() → AMX TDPBF16PS tile GEMM (K/32 tile iters)
amx_available() = false → AVX-512 F32x16 + mul_add FMA fallback
(BF16→f32 via bf16_to_f32_batch, then F32x16 chunks_exact(16)
+ mul_add = VFMADD231PS on __m512 with target-cpu=x86-64-v4,
emulated as 2× F32x8 FMA on AVX2-only hosts)
hpc/mod.rs (ADD):
pub mod bf16_tile_gemm;
Test results:
hpc::bf16_tile_gemm::tests::fallback_matches_scalar_reference_k64 ... ok
hpc::bf16_tile_gemm::tests::public_api_runs_on_any_hardware ... ok
Full suite: 1616 passed, 0 failed, 36 ignored, no SIGILL.
Baseline was 1612 → +4 (two new here, two other).
Design invariants honored:
- simd.rs polyfill boundary untouched (F32x16/F32x8 re-exports unchanged)
- additive only: no modifications to tile_dpbusd, TileConfig, or mod layout
- runtime-dispatched via amx_available() — same binary works on AMX
machines and AVX-512-only machines
- stable Rust 1.94 (inline asm .byte encoding, no nightly intrinsics)
https://claude.ai/code/session_01NYGrxVopyszZYgLBxe4hgj
AdaWorldAPI
pushed a commit
that referenced
this pull request
May 21, 2026
… kernel Per the PR #180 dispatch table for BF16 GEMM: SapphireRapids and GraniteRapids should route through `tile_dpbf16ps` (AMX TDPBF16PS, 256 BF16×BF16 multiply-accumulates per instruction, single-rounded into an f32 tile accumulator). Until this commit, the AMX branch of `matmul_bf16_to_f32` was a placebo — both `if amx_available()` and `else` called the scalar `bf16_gemm_f32`. The actual kernel (`bf16_tile_gemm::bf16_tile_gemm_16x16`, shipped by PR #104) was unreached by the consumer entry point. This wires it. When AMX is OS-enabled AND the matmul shape is 16/16/32-aligned in (M, N, K), the inner loop tiles 16×16 blocks through `bf16_tile_gemm_16x16` — that kernel emits TDPBF16PS via the asm-byte path in `simd_amx.rs::tile_dpbf16ps` (the stable-Rust 1.95 encoding documented at simd_amx.rs:16-19; AMX intrinsics are nightly-only per issue #126622, hence asm-byte). Aligned tiles get the full hardware throughput; misaligned shapes (any of M/N/K not at the alignment boundary) fall back to the validated scalar `bf16_gemm_f32` reference. Non-AMX hosts always take the scalar fallback. The B sub-block extraction copies a K × 16 packed scratch per j_tile column band (B is K × N row-major; the kernel wants K × 16 contiguous). Allocation cost is amortized across M/16 i-tile iterations under each j_tile. Phase-4 work will land a fully mixed-tile path (AMX 16×16 core + per-axis scalar tails on the same matmul) for arbitrary shapes. Verification: * Default v3 build: 11 amx_matmul tests pass (this host lacks AMX per /proc/cpuinfo, so the path falls through to scalar; behaviour identical to pre-commit on this runner). * Full lib sweep: 2087 tests pass; clippy -D warnings clean. * Real SPR silicon: the gating is correctness-by-construction — the new branch only fires when amx_available() == true AND the alignment predicates hold; the inner kernel is the same one PR #104 shipped and tested. Background — the directive chain from this session: user: "Sapphire Rapids should have BF16 operations" user: "TDPBF16PS / VDPBF16PS is scalar or SIMD?" → both are SIMD, TDPBF16PS does 8192 BF16×BF16 multiplies + 256 f32 accums per instruction (16×16 outer-product matmul tile), VDPBF16PS does 32 BF16×BF16 multiplies + 16 f32 accums per zmm instruction. Neither is scalar. The "no scalar lane-by-lane f32 round-trip" rule the user gave is what this PR delivers: the AMX tile op is hardware-fused, single-rounded into f32 accumulator, BF16 mantissa bits preserved bit-exactly per IEEE BF16 spec at the multiply step. Closes TD-T1 from `.claude/knowledge/agnostic-surface-cpu-matrix.md` § J Phase 1. https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
Merged
6 tasks
AdaWorldAPI
pushed a commit
that referenced
this pull request
May 21, 2026
Mirror of the BF16 AMX work (TD-T1 / TD-T1b in PR #182) for the integer operand family. Builds the missing int8 tile kernel from scratch (the BF16 equivalent shipped in PR #104; the int8 one had never been built despite the primitives existing in simd_amx since day one) and wires matmul_i8_to_i32's AMX arm through it. New module `hpc::int8_tile_gemm`: * `int8_tile_gemm_16x16(a_u8, b_i8, c, k)` — public tile kernel, K must be multiple of 64. Mirror shape of `bf16_tile_gemm_16x16` but for the `u8 × i8 → i32` operand family that TDPBUSD natively supports. **One TDPBUSD = 16 384 multiply-accumulates per instruction** (16×16 output tile × 64 K-elements per A row × 4 K-elements per inner-product). That's 256× the VPDPBUSD-zmm throughput per instruction. * Internal `amx_path()` uses the existing primitives in `amx_matmul`: TileConfig::for_dpbusd(64) → tile_loadconfig → tile_zero → K/64 iterations of (tile_load A, tile_load B, tile_dpbusd) → tile_store → tile_release. * `fallback_path()` for non-AMX hosts: scalar u8 × i8 → i32 triple-loop reference. New primitive `amx_matmul::vnni_pack_i8(src, dst, k, n)`: * Packs K × N row-major i8 into K/4 outer rows × (N*4) VNNI quad layout required by TDPBUSD tile 2. * `dst[kb*N*4 + j*4 + p] = src[(4*kb + p) * N + j]` * Sibling of `vnni_pack_bf16` (which uses K/2 × (N*2) pair layout for TDPBF16PS — both kernels reach the same 64-byte tile row width via element-width × pack-factor symmetry: BF16 is 2B × 2, INT8 is 1B × 4). Wiring `matmul_i8_to_i32`'s AMX arm (was placebo): Pre-commit the AMX branch shifted i8 → u8 then called the SCALAR `int8_gemm_i32` reference and subtracted the bias — TDPBUSD itself was never reached even on real AMX silicon. Now: 1. Shift A: i8 → u8 via (+128). 2. Tile-loop over M/16 i_tile × N/16 j_tile blocks, calling int8_tile_gemm_16x16 per (i_tile, j_tile). B sub-block extracted into K × 16 scratch once per j_tile, reused across i_tile iterations. 3. Subtract bias: c[i, j] -= 128 × colsum(B[:, j]). The shape requirement is m%16 == 0 && n%16 == 0 && k%64 == 0; misaligned shapes fall back to the scalar reference. Phase-4 work will land mixed AMX-tile + per-axis scalar tail handling for arbitrary shapes (same shape of Phase-4 work TD-T1 deferred). Verification: * Default v3 build: 2092 lib tests pass (was 2087 — adds 5 new tests: 4 in int8_tile_gemm + the existing matmul_i8_to_i32 test now exercises the actual TDPBUSD path because this host has amx_int8 + amx_tile in /proc/cpuinfo; the test continues to pass with bit-identical results to the scalar reference). * `vnni_pack_i8_roundtrip` test verifies the pack layout matches the spec exactly for an 8 × 4 sample. * `fallback_matches_scalar_reference_k64` test verifies the non-AMX path produces the same i32 output as a hand-written reference for a 64-K, pseudo-random u8/i8 matrix pair. * `public_api_diagonal_k128` test asserts a structured pattern (A = identity-like, B = constant 2) gives the expected accumulation through the full dispatch chain. * `cargo clippy --lib -D warnings` clean. * `cargo fmt --all --check` clean. Dropped: `int8_gemm_i32` import in `amx_matmul.rs` since the AMX arm no longer falls back to it (the scalar else-branch uses an inline triple-loop directly). After this commit, the per-CPU dispatch table from PR #180 has the AMX tier wired for BOTH operand families on Sapphire Rapids+: BF16 GEMM: SPR+ → TDPBF16PS (TD-T1 / TD-T1b in PR #182) INT8 GEMM: SPR+ → TDPBUSD (this commit) Out of scope (separate PRs): * VPDPBUSD-zmm arm of matmul_i8_to_i32 for Cooper Lake / Cascade Lake / Zen 4+ (avx512vnni without AMX). The kernel function `vnni_dot_u8_i8` and `vnni_matvec` exist in simd_amx.rs; just need to assemble them into a m×n×k GEMM and wire as the middle dispatch tier (analogous to the VDPBF16PS arm in PR #182's bf16_gemm_dispatch). * AMX tile path for `simd_int_ops::gemm_u8_i8` (the slice-level surface from PR #182) — it's u8 × i8 natively so no sign-shift needed, simpler to wire than matmul_i8_to_i32. https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
3 tasks
AdaWorldAPI
pushed a commit
that referenced
this pull request
May 21, 2026
Per codex review on PR #184: `int8_tile_gemm_16x16` is documented as `C += A·B` and the scalar `fallback_path` correctly accumulates, but the AMX `amx_path` did `tile_zero(0)` + `tile_store(0, c, 64)` which **overwrote** any pre-existing values in `c` on AMX-enabled hosts. Hardware-dependent behavior: callers relying on accumulation (blocked GEMM, repeated partial-K updates) would get incorrect results only when AMX was active. Same bug in the BF16 sibling `bf16_tile_gemm::amx_path` (shipped in PR #104) — fixing both. The fix: 1. Add a `tile 0` case to `tile_load` (encoding `C4 E2 7B 4B 04 08` — same SIB byte as the existing tmm1/tmm2 cases, ModR/M `04` = `mod=00, reg=000 (tmm0), r/m=100 (SIB follows)`). 2. Both AMX paths replace `tile_zero(0)` with `tile_load(0, c.as_ptr() as *const u8, 64)` — preloads tmm0 from caller's C buffer. TDPBUSD / TDPBF16PS then accumulate into the pre-loaded values; `tile_store(0, c, 64)` writes back the true `+=` result. Consumer impact: zero. Both `matmul_bf16_to_f32` and `matmul_i8_to_i32` (the only callers of these kernels in this crate) `tile_c.fill(0)` before each call — so the now-accumulating behavior + zero-initialized C = same result as the prior overwrite semantics. The fix removes a latent trap for future blocked-GEMM / partial-K consumers without changing any shipped behavior. New regression test `amx_path_preserves_c_accumulator`: pre-loads C with a known non-zero marker pattern, runs A·B where B=0 (so the contribution is 0), asserts the marker is preserved. Would fail on the pre-fix code because the tile_store would zero everything. Passes on this host (which has amx_int8). Verification: * 2094 lib tests pass (was 2093 — +1 regression test). * 11 amx_matmul tests pass (consumers' fill(0)-then-call pattern continues to produce correct results). * 2 bf16_tile_gemm tests pass. * 6 int8_tile_gemm tests pass. * cargo clippy --lib --tests --features rayon,native -- -D warnings clean. * cargo fmt --all --check clean. https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Introduces a new
bf16_tile_gemmmodule that provides a unified API for 16×16 BF16 matrix multiplication with runtime tier dispatch between AMX (TDPBF16PS) and AVX-512 F32x16 fallback paths.Key Changes
New module
src/hpc/bf16_tile_gemm.rs: Implementsbf16_tile_gemm_16x16()public API with:amx_available()checkF32x16SIMD with FMA operationsExtended
src/hpc/amx_matmul.rs:tile_dpbf16ps(): Inline assembly wrapper for TDPBF16PS instruction (C += A(bf16) × B(bf16_vnni) → f32)vnni_pack_bf16(): Utility to repack B matrix from row-major to VNNI pair layout required by TDPBF16PSUpdated
src/hpc/mod.rs: Exported newbf16_tile_gemmmoduleImplementation Details
bf16_to_f32_batch(), then accumulates viaF32x16::mul_add()with column-gathering optimizationThe implementation follows the pattern of one dispatch check per call, with identical numerical results across both execution tiers.
https://claude.ai/code/session_01SbYsmmbPf9YQuYbHZN52Zh