Skip to content

Commit ae5efaa

Browse files
authored
Merge pull request #182 from AdaWorldAPI/claude/continue-ndarray-x0Oaw
simd: agnostic gemm_u8_i8 surface, integer-slice-op lift, per-CPU matrix, BF16 AMX wiring
2 parents 3568003 + 333e96a commit ae5efaa

7 files changed

Lines changed: 1588 additions & 46 deletions

File tree

.cargo/config-avx512.toml

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,50 @@
11
[build]
2-
# Explicit AVX-512 config — `x86-64-v4`. Use with:
2+
# Explicit AVX-512 config — Sapphire Rapids baseline. Use with:
33
# cargo --config .cargo/config-avx512.toml build
44
# cargo --config .cargo/config-avx512.toml test
55
#
6-
# Compiles `target_feature = "avx512f"` on, so `src/simd.rs` selects the
7-
# `simd_avx512` backend with native `__m512` / `__m512d` / `__m512i`
8-
# storage. Required for the Sapphire Rapids / Granite Rapids hot paths
9-
# (`f32_to_bf16_batch_rne`, the AVX-512BF16 BF16 lanes, the AMX tiles).
6+
# `-Ctarget-cpu=sapphirerapids` enables, in addition to the
7+
# `x86-64-v4` AVX-512 baseline (F + BW + CD + DQ + VL):
108
#
11-
# Binary produced here will SIGILL on AVX2-only silicon — only use on
12-
# hosts that report `avx512f` in `/proc/cpuinfo`. For shipping a single
13-
# release artifact that adapts at process start, see the LazyLock runtime
14-
# dispatch path in § 7.1 of the architecture doc instead.
9+
# - AVX-512 VNNI (VPDPBUSD u8×i8 → i32)
10+
# - AVX-512 BF16 (VDPBF16PS, VCVTNE2PS2BF16)
11+
# - AVX-512 FP16 (16-wide native FP16 arithmetic)
12+
# - AVX-512 VBMI / VBMI2 (byte permute)
13+
# - AVX-512 IFMA, BITALG, VPOPCNTDQ, GFNI, VAES, VPCLMUL
14+
# - AVX-VNNI (ymm VPDPBUSD on Alder/Sapphire client)
15+
# - AMX-TILE + AMX-INT8 + AMX-BF16 (16×16×k tile kernels)
16+
#
17+
# Effect on the agnostic surfaces in `src/simd_*ops.rs`:
18+
#
19+
# - `simd_int_ops::gemm_u8_i8` resolves to the AVX-512 VNNI `VPDPBUSD`
20+
# zmm kernel (`hpc::vnni_gemm::int8_gemm_vnni_avx512`). When the
21+
# planned `amx-int8` arm lands, it will preempt this one and route
22+
# to `TDPBUSD` instead — same source, no caller changes.
23+
# - BF16 / FP16 lane ops in `src/simd_avx512.rs` light up.
24+
# - `simd_amx::*` tile primitives are usable without further gating.
25+
#
26+
# Pure `x86-64-v4` is NOT used here — Skylake-X is the only AVX-512 CPU
27+
# without VNNI and the project's design pins VNNI as the lowest common
28+
# denominator above the scalar reference. SKX users either build with
29+
# `-Ctarget-cpu=x86-64-v4` explicitly (and accept the scalar arm for
30+
# `gemm_u8_i8`) or run a runtime-LazyLock dispatch binary.
31+
#
32+
# Binary produced here will SIGILL on CPUs that lack any of the
33+
# enabled feature sets — i.e. anything pre-Sapphire-Rapids on x86_64:
34+
#
35+
# - Cooper Lake / Cascade Lake / Ice Lake-SP (no BF16+FP16+AMX)
36+
# - Skylake-X / Skylake-SP / Skylake-W (no VNNI either)
37+
# - Zen 4 / Zen 5 (no AMX)
38+
# - Alder Lake / Arrow Lake (no AVX-512 at all)
39+
# - Haswell ⇢ Coffee Lake (AVX2 only)
40+
#
41+
# Only deploy artifacts built with this config to hosts that report
42+
# `amx_int8 amx_bf16 avx512_bf16 avx512_fp16 avx512_vnni` in
43+
# `/proc/cpuinfo`. For Cascade Lake → Ice Lake-SP → Zen 4 silicon
44+
# (AVX-512 + VNNI but no AMX/BF16/FP16), build with
45+
# `-Ctarget-cpu=cascadelake` or `-Ctarget-cpu=znver4` instead. For
46+
# shipping a single release artifact that adapts at process start,
47+
# see the LazyLock runtime dispatch path in § 7.1 of the architecture
48+
# doc instead.
1549
[target.'cfg(target_arch = "x86_64")']
16-
rustflags = ["-Ctarget-cpu=x86-64-v4"]
50+
rustflags = ["-Ctarget-cpu=sapphirerapids"]

.claude/knowledge/agnostic-surface-cpu-matrix.md

Lines changed: 548 additions & 0 deletions
Large diffs are not rendered by default.

src/hpc/amx_matmul.rs

Lines changed: 183 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -297,8 +297,15 @@ fn write_contig<A: Copy>(view: &mut ArrayViewMut2<'_, A>, src: &[A]) {
297297

298298
/// Matrix multiply BF16 × BF16 → f32: `out = lhs · rhs`.
299299
///
300-
/// Uses AMX `TDPBF16PS` (256 mul-adds per instruction) when available,
301-
/// otherwise falls back to [`bf16_gemm_f32`].
300+
/// On AMX hardware (Sapphire Rapids+, Granite Rapids), 16×16-aligned tiles
301+
/// dispatch to [`crate::hpc::bf16_tile_gemm::bf16_tile_gemm_16x16`] which
302+
/// emits `TDPBF16PS` via the asm-byte path in `simd_amx.rs` — 256
303+
/// BF16×BF16 multiply-accumulates per instruction (16×16×32 = 8 192 FLOPs)
304+
/// into f32 accumulator tiles. M/N/K tail blocks (when any dim isn't
305+
/// 16/16/32-aligned) fall through to the validated scalar
306+
/// [`crate::hpc::quantized::bf16_gemm_f32`] reference.
307+
///
308+
/// On non-AMX hosts the entire matmul goes through `bf16_gemm_f32`.
302309
///
303310
/// `out` must be row-contiguous (column stride = 1); inputs may be strided.
304311
pub fn matmul_bf16_to_f32(
@@ -310,26 +317,180 @@ pub fn matmul_bf16_to_f32(
310317
let b = pack_contig(&rhs);
311318
let mut c = vec![0.0f32; m * n];
312319

313-
// AMX path: a tiled 16×16 kernel exists in `bf16_tile_gemm` for sizes that
314-
// fit cleanly. For any leftover tail (or hosts without AMX), defer to the
315-
// scalar `bf16_gemm_f32`. The tile kernel itself is maintained alongside
316-
// the low-level primitives at the top of this file; the public surface
317-
// intentionally goes through the validated scalar path so we always
318-
// produce a numerically-stable f32 result.
319-
if amx_available() {
320-
// Future: AMX-tiled fast path. Today we route through the same
321-
// f32 reference kernel; correctness is identical regardless of
322-
// hardware. The `amx_available()` branch is preserved so callers
323-
// can be sure the AMX detection runs.
324-
bf16_gemm_f32(&a, &b, &mut c, m, n, k, 1.0, 0.0);
325-
} else {
326-
bf16_gemm_f32(&a, &b, &mut c, m, n, k, 1.0, 0.0);
327-
}
320+
bf16_gemm_dispatch(&a, &b, &mut c, m, n, k);
328321

329322
write_contig(&mut out, &c);
330323
Ok(())
331324
}
332325

326+
/// BF16 × BF16 → f32 GEMM with three-tier dispatch (AMX → VDPBF16PS → scalar).
327+
///
328+
/// Inputs are packed row-major (`a` is M × K, `b` is K × N). Output `c`
329+
/// is M × N row-major and is overwritten (not accumulated).
330+
///
331+
/// Tier selection:
332+
///
333+
/// 1. **AMX `TDPBF16PS`** (Sapphire Rapids+, Granite Rapids) when
334+
/// `amx_available()` is true AND shapes are 16/16/32-aligned.
335+
/// Dispatches through
336+
/// [`crate::hpc::bf16_tile_gemm::bf16_tile_gemm_16x16`] →
337+
/// `simd_amx::tile_dpbf16ps` via asm-byte (`TDPBF16PS` intrinsic is
338+
/// nightly-only on Rust 1.95). 8 192 BF16×BF16 multiplies + 256 f32
339+
/// accumulates per instruction.
340+
/// 2. **`VDPBF16PS`** (Cooper Lake, Cascade Lake AVX-512BF16, Zen 4+)
341+
/// when `is_x86_feature_detected!("avx512bf16")` is true. The
342+
/// intrinsic `_mm512_dpbf16_ps` is stable on Rust 1.95 (no asm-byte
343+
/// needed). Per instruction: 32 BF16×BF16 multiplies + 16 f32
344+
/// accumulates, single-rounded. Handles arbitrary shapes — M / N
345+
/// tails fall through the per-iteration j-block trimming; K-tail
346+
/// (odd K) is handled with a final scalar pair.
347+
/// 3. **Scalar reference** [`bf16_gemm_f32`] for hosts without either
348+
/// extension or for shapes the AMX arm rejects.
349+
///
350+
/// The per-tier dispatch table comes from PR #180's BF16 GEMM column.
351+
fn bf16_gemm_dispatch(a: &[BF16], b: &[BF16], c: &mut [f32], m: usize, n: usize, k: usize) {
352+
if amx_available() && m % 16 == 0 && n % 16 == 0 && k % 32 == 0 {
353+
// SAFETY: BF16 is `#[repr(transparent)] struct BF16(pub u16)`
354+
// (per `hpc::quantized::BF16`). Reinterpreting `&[BF16]` as
355+
// `&[u16]` is bit-pattern preserving.
356+
let a_u16: &[u16] = unsafe { core::slice::from_raw_parts(a.as_ptr() as *const u16, a.len()) };
357+
358+
// B is packed row-major K × N; the 16×16 tile kernel wants a
359+
// K × 16 contiguous sub-block. Extract per (j_tile) into a
360+
// scratch buffer once and reuse across i_tile.
361+
let mut b_tile = vec![0u16; k * 16];
362+
let mut tile_c = vec![0.0f32; 256];
363+
364+
for j_tile in (0..n).step_by(16) {
365+
// Pack b[0..k, j_tile..j_tile+16] into row-major 16-wide K-rows.
366+
for kk in 0..k {
367+
let row = kk * n + j_tile;
368+
for jj in 0..16 {
369+
b_tile[kk * 16 + jj] = b[row + jj].0;
370+
}
371+
}
372+
for i_tile in (0..m).step_by(16) {
373+
// A_tile = a[i_tile..i_tile+16, 0..k] — already contiguous
374+
// since `a` is packed row-major M × K.
375+
let a_tile = &a_u16[i_tile * k..(i_tile + 16) * k];
376+
tile_c.fill(0.0);
377+
crate::hpc::bf16_tile_gemm::bf16_tile_gemm_16x16(a_tile, &b_tile, &mut tile_c, k);
378+
// Write tile_c (16 × 16, row-major) into c (M × N, row-major).
379+
for ii in 0..16 {
380+
let dst_off = (i_tile + ii) * n + j_tile;
381+
c[dst_off..dst_off + 16].copy_from_slice(&tile_c[ii * 16..(ii + 1) * 16]);
382+
}
383+
}
384+
}
385+
return;
386+
}
387+
388+
#[cfg(target_arch = "x86_64")]
389+
{
390+
if std::is_x86_feature_detected!("avx512bf16") {
391+
// SAFETY: feature-detected at runtime; the kernel is
392+
// `#[target_feature(enable = "avx512bf16,avx512f")]`.
393+
unsafe {
394+
bf16_gemm_vdpbf16ps(a, b, c, m, n, k);
395+
}
396+
return;
397+
}
398+
}
399+
400+
bf16_gemm_f32(a, b, c, m, n, k, 1.0, 0.0);
401+
}
402+
403+
/// AVX-512BF16 BF16 GEMM using `_mm512_dpbf16_ps` (`VDPBF16PS`).
404+
///
405+
/// One VDPBF16PS instruction: 16 f32 accumulator lanes each receive
406+
/// `acc[j] += a.bf16[2j] * b.bf16[2j] + a.bf16[2j+1] * b.bf16[2j+1]`,
407+
/// single-rounded. The kernel maps the 16 output lanes to a row of 16
408+
/// j-columns of C[i, ·], with one i row processed at a time and a K-pair
409+
/// inner loop accumulating into the same 16 f32 lanes across iterations.
410+
///
411+
/// B-column packing: VDPBF16PS wants the 32 B BF16s per call laid out
412+
/// as 16 lane-pairs (lane j contains `B[2k_pair, j_base+j]` followed by
413+
/// `B[2k_pair+1, j_base+j]`, packed into one u32). We pre-pack B for
414+
/// the current j-block into `b_col_pairs[k_pair * 16 + j] = u32` once
415+
/// per j_block and reuse across all i — amortizes the gather cost.
416+
///
417+
/// K-tail (when K is odd) is handled with a final scalar BF16 multiply
418+
/// per output cell; N-tail (when the j-block has < 16 valid columns)
419+
/// is handled by trimming the store after the VDPBF16PS chain.
420+
///
421+
/// # Safety
422+
/// Caller must have feature-detected `avx512bf16` at runtime.
423+
#[cfg(target_arch = "x86_64")]
424+
#[target_feature(enable = "avx512bf16,avx512f")]
425+
unsafe fn bf16_gemm_vdpbf16ps(a: &[BF16], b: &[BF16], c: &mut [f32], m: usize, n: usize, k: usize) {
426+
use core::arch::x86_64::{
427+
__m512bh, __m512i, _mm512_dpbf16_ps, _mm512_loadu_si512, _mm512_set1_epi32, _mm512_setzero_ps, _mm512_storeu_ps,
428+
};
429+
430+
let k_pairs = k / 2;
431+
let k_tail = k % 2;
432+
433+
// SAFETY: BF16 is repr(transparent) over u16.
434+
let a_u16: &[u16] = core::slice::from_raw_parts(a.as_ptr() as *const u16, a.len());
435+
let b_u16: &[u16] = core::slice::from_raw_parts(b.as_ptr() as *const u16, b.len());
436+
437+
// Pre-pack scratch: 16 u32 lanes per k_pair, holding (b_lo | b_hi << 16).
438+
let mut b_col_pairs = vec![0u32; k_pairs.max(1) * 16];
439+
// Scratch for the 16-wide store + N-tail trim.
440+
let mut out_buf = [0.0f32; 16];
441+
442+
for j_base in (0..n).step_by(16) {
443+
let j_count = 16.min(n - j_base);
444+
445+
// Pack B columns [j_base..j_base+j_count] in pair-interleaved layout.
446+
// For lanes j >= j_count (the N-tail of this j_block), pad with 0 —
447+
// they're not stored back, but the VDPBF16PS still touches them.
448+
for k_pair in 0..k_pairs {
449+
let row_lo = 2 * k_pair * n;
450+
let row_hi = (2 * k_pair + 1) * n;
451+
for jj in 0..j_count {
452+
let b_lo = b_u16[row_lo + j_base + jj] as u32;
453+
let b_hi = b_u16[row_hi + j_base + jj] as u32;
454+
b_col_pairs[k_pair * 16 + jj] = (b_hi << 16) | b_lo;
455+
}
456+
for jj in j_count..16 {
457+
b_col_pairs[k_pair * 16 + jj] = 0;
458+
}
459+
}
460+
461+
for i in 0..m {
462+
let mut acc = _mm512_setzero_ps();
463+
let a_row_off = i * k;
464+
for k_pair in 0..k_pairs {
465+
// Broadcast A[i, 2k_pair..2k_pair+2] as the (BF16 lo, BF16 hi)
466+
// pair across all 16 lanes.
467+
let a_lo = a_u16[a_row_off + 2 * k_pair] as u32;
468+
let a_hi = a_u16[a_row_off + 2 * k_pair + 1] as u32;
469+
let pair = (a_hi << 16) | a_lo;
470+
let a_bh: __m512bh = core::mem::transmute(_mm512_set1_epi32(pair as i32));
471+
let b_bh: __m512bh =
472+
core::mem::transmute(_mm512_loadu_si512(b_col_pairs.as_ptr().add(k_pair * 16) as *const __m512i));
473+
acc = _mm512_dpbf16_ps(acc, a_bh, b_bh);
474+
}
475+
_mm512_storeu_ps(out_buf.as_mut_ptr(), acc);
476+
477+
// K-tail: one extra scalar BF16 multiply for k = k_pairs*2.
478+
if k_tail == 1 {
479+
let a_last_f32 = BF16(a_u16[a_row_off + k - 1]).to_f32();
480+
let tail_row = (k - 1) * n;
481+
for jj in 0..j_count {
482+
let b_last_f32 = BF16(b_u16[tail_row + j_base + jj]).to_f32();
483+
out_buf[jj] += a_last_f32 * b_last_f32;
484+
}
485+
}
486+
487+
// Store the j_count valid lanes (drops N-tail padding lanes).
488+
let dst_off = i * n + j_base;
489+
c[dst_off..dst_off + j_count].copy_from_slice(&out_buf[..j_count]);
490+
}
491+
}
492+
}
493+
333494
// ── f32 → f32 (BF16 compute on AMX) ────────────────────────────────────────
334495

335496
/// Matrix multiply f32 × f32 → f32: `out = lhs · rhs`.
@@ -349,10 +510,13 @@ pub fn matmul_f32(
349510
let mut c = vec![0.0f32; m * n];
350511

351512
if amx_available() {
352-
// AMX path: down-cast to BF16, run BF16 GEMM, accumulate in f32.
513+
// AMX path: down-cast to BF16 (RNE, ~1 ULP at BF16 mantissa
514+
// precision), then dispatch through the shared BF16 helper
515+
// which picks `TDPBF16PS` tile kernel for 16/16/32-aligned
516+
// shapes and the scalar `bf16_gemm_f32` reference otherwise.
353517
let a_bf16: Vec<BF16> = a_f32.iter().map(|&v| BF16::from_f32_rounded(v)).collect();
354518
let b_bf16: Vec<BF16> = b_f32.iter().map(|&v| BF16::from_f32_rounded(v)).collect();
355-
bf16_gemm_f32(&a_bf16, &b_bf16, &mut c, m, n, k, 1.0, 0.0);
519+
bf16_gemm_dispatch(&a_bf16, &b_bf16, &mut c, m, n, k);
356520
} else {
357521
// Pure f32 reference path.
358522
for i in 0..m {

0 commit comments

Comments
 (0)