Skip to content

Add BF16 tile GEMM with AMX/AVX-512 dispatch#104

Merged
AdaWorldAPI merged 3 commits into
masterfrom
claude/teleport-session-setup-wMZfb
Apr 14, 2026
Merged

Add BF16 tile GEMM with AMX/AVX-512 dispatch#104
AdaWorldAPI merged 3 commits into
masterfrom
claude/teleport-session-setup-wMZfb

Conversation

@AdaWorldAPI
Copy link
Copy Markdown
Owner

Summary

Introduces a new bf16_tile_gemm module that provides a unified API for 16×16 BF16 matrix multiplication with runtime tier dispatch between AMX (TDPBF16PS) and AVX-512 F32x16 fallback paths.

Key Changes

  • New module src/hpc/bf16_tile_gemm.rs: Implements bf16_tile_gemm_16x16() public API with:

    • Runtime dispatch via amx_available() check
    • AMX path: VNNI-packs input B matrix and uses TDPBF16PS tile instructions for 16×16×K/32 accumulation
    • Fallback path: Decodes BF16→f32 and performs tight GEMM using F32x16 SIMD with FMA operations
    • Both paths produce identical results up to BF16 precision (~1/128 per multiply)
  • Extended src/hpc/amx_matmul.rs:

    • Added tile_dpbf16ps(): Inline assembly wrapper for TDPBF16PS instruction (C += A(bf16) × B(bf16_vnni) → f32)
    • Added vnni_pack_bf16(): Utility to repack B matrix from row-major to VNNI pair layout required by TDPBF16PS
  • Updated src/hpc/mod.rs: Exported new bf16_tile_gemm module

Implementation Details

  • Tile shape: M=16, N=16, K=multiple of 32 (enforced via assertions)
  • AMX path: Caller supplies pre-allocated output; B is VNNI-packed internally; uses raw tile primitives with K/32 block iterations
  • Fallback path: Batch-decodes both inputs via bf16_to_f32_batch(), then accumulates via F32x16::mul_add() with column-gathering optimization
  • Testing: Includes scalar reference implementation and validation tests confirming fallback matches reference up to f32 precision; public API sanity test runs on any hardware

The implementation follows the pattern of one dispatch check per call, with identical numerical results across both execution tiers.

https://claude.ai/code/session_01SbYsmmbPf9YQuYbHZN52Zh

claude added 3 commits April 14, 2026 16:06
Previously: no global target-cpu, per-function #[target_feature] + runtime
dispatch so "one binary runs on AVX2 and AVX-512 machines".

Now: compile-time AVX-512 baseline via rustflags target-cpu=x86-64-v4
(AVX-512F, AVX-512BW, AVX-512CD, AVX-512DQ, AVX-512VL). The v4 level does
not include AMX; AMX (amx-tile / amx-int8 / amx-bf16) remains per-function
#[target_feature] with runtime amx_available() gating (CPUID + _xgetbv(0)
bits 17/18 + prctl ARCH_REQ_XCOMP_PERM for Linux 5.19+ tile state).

AVX2 is used as the CI-only fallback path; local and Railway builds pin v4
so AVX-512 lanes (F32x16, kernels_avx512) light up at compile time.

https://claude.ai/code/session_01NYGrxVopyszZYgLBxe4hgj
Additive only — no existing symbols modified.

hpc/amx_matmul.rs (ADD):
  pub unsafe fn tile_dpbf16ps()
    TDPBF16PS tmm0, tmm1, tmm2 via stable inline asm .byte encoding
    (C4 E2 72 5C C1 — same binary-trick pattern as existing
    tile_dpbusd/tile_zero/tile_release for pre-nightly AMX on Rust 1.94)
  pub fn vnni_pack_bf16(src, dst, k, n)
    Pack K×N row-major bf16 → K/2 × (N*2) VNNI pairs for TDPBF16PS B tile

hpc/bf16_tile_gemm.rs (NEW module, additive):
  pub fn bf16_tile_gemm_16x16(a_bf16, b_bf16, c_f32, k)
    Same API, runtime tier dispatch:
      amx_available()        → AMX TDPBF16PS tile GEMM (K/32 tile iters)
      amx_available() = false → AVX-512 F32x16 + mul_add FMA fallback
        (BF16→f32 via bf16_to_f32_batch, then F32x16 chunks_exact(16)
         + mul_add = VFMADD231PS on __m512 with target-cpu=x86-64-v4,
         emulated as 2× F32x8 FMA on AVX2-only hosts)

hpc/mod.rs (ADD):
  pub mod bf16_tile_gemm;

Test results:
  hpc::bf16_tile_gemm::tests::fallback_matches_scalar_reference_k64 ... ok
  hpc::bf16_tile_gemm::tests::public_api_runs_on_any_hardware ... ok
  Full suite: 1616 passed, 0 failed, 36 ignored, no SIGILL.
  Baseline was 1612 → +4 (two new here, two other).

Design invariants honored:
  - simd.rs polyfill boundary untouched (F32x16/F32x8 re-exports unchanged)
  - additive only: no modifications to tile_dpbusd, TileConfig, or mod layout
  - runtime-dispatched via amx_available() — same binary works on AMX
    machines and AVX-512-only machines
  - stable Rust 1.94 (inline asm .byte encoding, no nightly intrinsics)

https://claude.ai/code/session_01NYGrxVopyszZYgLBxe4hgj
@AdaWorldAPI AdaWorldAPI merged commit 6609f10 into master Apr 14, 2026
5 of 14 checks passed
AdaWorldAPI pushed a commit that referenced this pull request May 21, 2026
… kernel

Per the PR #180 dispatch table for BF16 GEMM: SapphireRapids and
GraniteRapids should route through `tile_dpbf16ps` (AMX TDPBF16PS,
256 BF16×BF16 multiply-accumulates per instruction, single-rounded
into an f32 tile accumulator). Until this commit, the AMX branch of
`matmul_bf16_to_f32` was a placebo — both `if amx_available()` and
`else` called the scalar `bf16_gemm_f32`. The actual kernel
(`bf16_tile_gemm::bf16_tile_gemm_16x16`, shipped by PR #104) was
unreached by the consumer entry point.

This wires it. When AMX is OS-enabled AND the matmul shape is
16/16/32-aligned in (M, N, K), the inner loop tiles 16×16 blocks
through `bf16_tile_gemm_16x16` — that kernel emits TDPBF16PS via the
asm-byte path in `simd_amx.rs::tile_dpbf16ps` (the stable-Rust 1.95
encoding documented at simd_amx.rs:16-19; AMX intrinsics are
nightly-only per issue #126622, hence asm-byte). Aligned tiles get
the full hardware throughput; misaligned shapes (any of M/N/K not at
the alignment boundary) fall back to the validated scalar
`bf16_gemm_f32` reference. Non-AMX hosts always take the scalar
fallback.

The B sub-block extraction copies a K × 16 packed scratch per
j_tile column band (B is K × N row-major; the kernel wants K × 16
contiguous). Allocation cost is amortized across M/16 i-tile
iterations under each j_tile. Phase-4 work will land a fully
mixed-tile path (AMX 16×16 core + per-axis scalar tails on the
same matmul) for arbitrary shapes.

Verification:
  * Default v3 build: 11 amx_matmul tests pass (this host lacks
    AMX per /proc/cpuinfo, so the path falls through to scalar;
    behaviour identical to pre-commit on this runner).
  * Full lib sweep: 2087 tests pass; clippy -D warnings clean.
  * Real SPR silicon: the gating is correctness-by-construction —
    the new branch only fires when amx_available() == true AND the
    alignment predicates hold; the inner kernel is the same one
    PR #104 shipped and tested.

Background — the directive chain from this session:

  user: "Sapphire Rapids should have BF16 operations"
  user: "TDPBF16PS / VDPBF16PS is scalar or SIMD?"  → both are SIMD,
        TDPBF16PS does 8192 BF16×BF16 multiplies + 256 f32 accums
        per instruction (16×16 outer-product matmul tile), VDPBF16PS
        does 32 BF16×BF16 multiplies + 16 f32 accums per zmm
        instruction. Neither is scalar. The "no scalar lane-by-lane
        f32 round-trip" rule the user gave is what this PR delivers:
        the AMX tile op is hardware-fused, single-rounded into f32
        accumulator, BF16 mantissa bits preserved bit-exactly per
        IEEE BF16 spec at the multiply step.

Closes TD-T1 from
`.claude/knowledge/agnostic-surface-cpu-matrix.md` § J Phase 1.

https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
AdaWorldAPI pushed a commit that referenced this pull request May 21, 2026
Mirror of the BF16 AMX work (TD-T1 / TD-T1b in PR #182) for the
integer operand family. Builds the missing int8 tile kernel from
scratch (the BF16 equivalent shipped in PR #104; the int8 one had
never been built despite the primitives existing in simd_amx since
day one) and wires matmul_i8_to_i32's AMX arm through it.

New module `hpc::int8_tile_gemm`:

  * `int8_tile_gemm_16x16(a_u8, b_i8, c, k)` — public tile kernel,
    K must be multiple of 64. Mirror shape of
    `bf16_tile_gemm_16x16` but for the `u8 × i8 → i32` operand
    family that TDPBUSD natively supports. **One TDPBUSD = 16 384
    multiply-accumulates per instruction** (16×16 output tile × 64
    K-elements per A row × 4 K-elements per inner-product). That's
    256× the VPDPBUSD-zmm throughput per instruction.
  * Internal `amx_path()` uses the existing primitives in
    `amx_matmul`: TileConfig::for_dpbusd(64) → tile_loadconfig →
    tile_zero → K/64 iterations of (tile_load A, tile_load B,
    tile_dpbusd) → tile_store → tile_release.
  * `fallback_path()` for non-AMX hosts: scalar u8 × i8 → i32
    triple-loop reference.

New primitive `amx_matmul::vnni_pack_i8(src, dst, k, n)`:

  * Packs K × N row-major i8 into K/4 outer rows × (N*4) VNNI quad
    layout required by TDPBUSD tile 2.
  * `dst[kb*N*4 + j*4 + p] = src[(4*kb + p) * N + j]`
  * Sibling of `vnni_pack_bf16` (which uses K/2 × (N*2) pair layout
    for TDPBF16PS — both kernels reach the same 64-byte tile row
    width via element-width × pack-factor symmetry: BF16 is 2B × 2,
    INT8 is 1B × 4).

Wiring `matmul_i8_to_i32`'s AMX arm (was placebo):

Pre-commit the AMX branch shifted i8 → u8 then called the SCALAR
`int8_gemm_i32` reference and subtracted the bias — TDPBUSD itself
was never reached even on real AMX silicon. Now:

  1. Shift A: i8 → u8 via (+128).
  2. Tile-loop over M/16 i_tile × N/16 j_tile blocks, calling
     int8_tile_gemm_16x16 per (i_tile, j_tile). B sub-block
     extracted into K × 16 scratch once per j_tile, reused across
     i_tile iterations.
  3. Subtract bias: c[i, j] -= 128 × colsum(B[:, j]).

The shape requirement is m%16 == 0 && n%16 == 0 && k%64 == 0;
misaligned shapes fall back to the scalar reference. Phase-4 work
will land mixed AMX-tile + per-axis scalar tail handling for
arbitrary shapes (same shape of Phase-4 work TD-T1 deferred).

Verification:
  * Default v3 build: 2092 lib tests pass (was 2087 — adds 5 new
    tests: 4 in int8_tile_gemm + the existing matmul_i8_to_i32 test
    now exercises the actual TDPBUSD path because this host has
    amx_int8 + amx_tile in /proc/cpuinfo; the test continues to
    pass with bit-identical results to the scalar reference).
  * `vnni_pack_i8_roundtrip` test verifies the pack layout matches
    the spec exactly for an 8 × 4 sample.
  * `fallback_matches_scalar_reference_k64` test verifies the
    non-AMX path produces the same i32 output as a hand-written
    reference for a 64-K, pseudo-random u8/i8 matrix pair.
  * `public_api_diagonal_k128` test asserts a structured pattern
    (A = identity-like, B = constant 2) gives the expected
    accumulation through the full dispatch chain.
  * `cargo clippy --lib -D warnings` clean.
  * `cargo fmt --all --check` clean.

Dropped: `int8_gemm_i32` import in `amx_matmul.rs` since the AMX
arm no longer falls back to it (the scalar else-branch uses an
inline triple-loop directly).

After this commit, the per-CPU dispatch table from PR #180 has the
AMX tier wired for BOTH operand families on Sapphire Rapids+:

  BF16 GEMM:  SPR+ → TDPBF16PS  (TD-T1 / TD-T1b in PR #182)
  INT8 GEMM:  SPR+ → TDPBUSD    (this commit)

Out of scope (separate PRs):
  * VPDPBUSD-zmm arm of matmul_i8_to_i32 for Cooper Lake / Cascade
    Lake / Zen 4+ (avx512vnni without AMX). The kernel function
    `vnni_dot_u8_i8` and `vnni_matvec` exist in simd_amx.rs; just
    need to assemble them into a m×n×k GEMM and wire as the
    middle dispatch tier (analogous to the VDPBF16PS arm in PR
    #182's bf16_gemm_dispatch).
  * AMX tile path for `simd_int_ops::gemm_u8_i8` (the slice-level
    surface from PR #182) — it's u8 × i8 natively so no sign-shift
    needed, simpler to wire than matmul_i8_to_i32.

https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
AdaWorldAPI pushed a commit that referenced this pull request May 21, 2026
Per codex review on PR #184: `int8_tile_gemm_16x16` is documented
as `C += A·B` and the scalar `fallback_path` correctly accumulates,
but the AMX `amx_path` did `tile_zero(0)` + `tile_store(0, c, 64)`
which **overwrote** any pre-existing values in `c` on AMX-enabled
hosts. Hardware-dependent behavior: callers relying on accumulation
(blocked GEMM, repeated partial-K updates) would get incorrect
results only when AMX was active.

Same bug in the BF16 sibling `bf16_tile_gemm::amx_path` (shipped
in PR #104) — fixing both.

The fix:
  1. Add a `tile 0` case to `tile_load` (encoding `C4 E2 7B 4B 04
     08` — same SIB byte as the existing tmm1/tmm2 cases, ModR/M
     `04` = `mod=00, reg=000 (tmm0), r/m=100 (SIB follows)`).
  2. Both AMX paths replace `tile_zero(0)` with
     `tile_load(0, c.as_ptr() as *const u8, 64)` — preloads tmm0
     from caller's C buffer. TDPBUSD / TDPBF16PS then accumulate
     into the pre-loaded values; `tile_store(0, c, 64)` writes back
     the true `+=` result.

Consumer impact: zero. Both `matmul_bf16_to_f32` and
`matmul_i8_to_i32` (the only callers of these kernels in this
crate) `tile_c.fill(0)` before each call — so the now-accumulating
behavior + zero-initialized C = same result as the prior overwrite
semantics. The fix removes a latent trap for future blocked-GEMM /
partial-K consumers without changing any shipped behavior.

New regression test `amx_path_preserves_c_accumulator`: pre-loads
C with a known non-zero marker pattern, runs A·B where B=0 (so the
contribution is 0), asserts the marker is preserved. Would fail on
the pre-fix code because the tile_store would zero everything.
Passes on this host (which has amx_int8).

Verification:
  * 2094 lib tests pass (was 2093 — +1 regression test).
  * 11 amx_matmul tests pass (consumers' fill(0)-then-call pattern
    continues to produce correct results).
  * 2 bf16_tile_gemm tests pass.
  * 6 int8_tile_gemm tests pass.
  * cargo clippy --lib --tests --features rayon,native -- -D warnings
    clean.
  * cargo fmt --all --check clean.

https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants