Skip to content

Commit 38d4800

Browse files
committed
feat(hpc): VPDPBUSD-ymm AVX-VNNI tier for matmul_i8_to_i32
Completes the per-CPU dispatch chain for `matmul_i8_to_i32` by adding the AVX-VNNI ymm tier — Arrow Lake, Meteor Lake U, Alder Lake silicon that has AVX-VNNI but dropped AVX-512. Mirrors the shape of the avx512vnni-zmm arm shipped in PR #184 with the narrower 8-wide kernel.

New kernel `hpc::int8_tile_gemm::int8_gemm_vpdpbusd_ymm`:
* One `_mm256_dpbusd_avx_epi32` instruction: 8 i32 accumulator lanes, each receiving 4 u8×i8 products = 32 MACs per instruction. Half the throughput-per-instruction of the `_mm512_dpbusd_epi32` zmm version.
* Same B-pre-pack scheme (quad-interleaved per 8-wide j-block), same K-tail / N-tail handling. Just narrower.
* Stable intrinsic under `target_feature = "avxvnni,avx2"` — no asm-byte needed.

Wiring `matmul_i8_to_i32`'s dispatch as Tier 3:
1. amx_available() + 16/16/64-aligned → AMX TDPBUSD (PR #184: int8_gemm_amx_tiled, 16 384 MACs/instr)
2. is_x86_feature_detected!("avx512vnni") → VPDPBUSD-zmm (PR #184: int8_gemm_vpdpbusd_zmm, 64 MACs/instr)
3. is_x86_feature_detected!("avxvnni") → VPDPBUSD-ymm (THIS COMMIT: int8_gemm_vpdpbusd_ymm, 32 MACs/instr)
4. scalar i8×i8 → i32 reference (was Tier 3)

All three SIMD tiers share the sign-shift bias trick: shift LHS i8 → u8 (+128), run the kernel, subtract 128·colsum(B). Same `subtract_i8_to_u8_bias` helper (factored in PR #184).

New direct test `vpdpbusd_ymm_matches_scalar` mirrors the zmm version's test: sweeps shapes spanning 8-aligned, K-tail (k % 4), N-tail (n % 8), and small shapes, asserts byte-equal output vs scalar reference.

Verification:
* Default v3 (this host has avx512vnni so the new arm doesn't fire from matmul_i8_to_i32 — Tier 2 catches first): 2096 lib tests pass (was 2095 — +1 new direct test).
* Direct test exercises int8_gemm_vpdpbusd_ymm on this host since avxvnni is present alongside avx512vnni.
* cargo clippy --lib --tests --features rayon,native -- -D warnings clean.
* cargo fmt --all --check clean.
Per-CPU dispatch state after this commit (final on the int8 side):

    matmul_i8_to_i32:  SPR+ AMX   | CPL/Zen4 zmm | ARL ymm | scalar
                       (PR #184)  | (PR #184)    | (THIS)  | (always)

The matmul_i8_to_i32 column of PR #180's dispatch table is now fully filled. The gemm_u8_i8 slice surface (in PR #185) already has AVX-VNNI ymm via its existing compile-time cascade — both i8-related public surfaces now cover every x86_64 tier with a hardware-accelerated arm.

Out of scope (separate PRs):
* NEON BFMMLA / SDOT on aarch64 via asm-byte — Phase 3b, needs aarch64 CI runner verification.
* TD-T6: real _mm256_* for AVX2 BLAS-1 (scal/nrm2/asum).

https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
1 parent f8e9453 commit 38d4800

2 files changed

Lines changed: 142 additions & 2 deletions

File tree

src/hpc/amx_matmul.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -621,9 +621,19 @@ pub fn matmul_i8_to_i32(
621621
crate::hpc::int8_tile_gemm::int8_gemm_vpdpbusd_zmm(&a_u8, &b_i8, &mut c, m, n, k);
622622
}
623623
subtract_i8_to_u8_bias(&mut c, &b_i8, m, n, k);
624+
} else if cfg!(target_arch = "x86_64") && std::is_x86_feature_detected!("avxvnni") {
625+
// Tier 3 — AVX-VNNI ymm VPDPBUSD: 32 MACs per instruction.
626+
// Arrow Lake, Meteor Lake U, Alder Lake silicon that has
627+
// AVX-VNNI but dropped AVX-512. Same sign-shift bias trick.
628+
let a_u8: Vec<u8> = a_i8.iter().map(|&v| (v as i32 + 128) as u8).collect();
629+
// SAFETY: runtime feature-detected avxvnni above.
630+
unsafe {
631+
crate::hpc::int8_tile_gemm::int8_gemm_vpdpbusd_ymm(&a_u8, &b_i8, &mut c, m, n, k);
632+
}
633+
subtract_i8_to_u8_bias(&mut c, &b_i8, m, n, k);
624634
} else {
625-
// Tier 3 — Scalar i8×i8 → i32 reference for non-x86 hosts,
626-
// pre-AVX-512 silicon, or shapes that don't satisfy either of
635+
// Tier 4 — Scalar i8×i8 → i32 reference for non-x86 hosts,
636+
// pre-AVX-VNNI silicon, or shapes that don't satisfy any of
627637
// the SIMD tiers' alignment requirements.
628638
for i in 0..m {
629639
for p in 0..k {

src/hpc/int8_tile_gemm.rs

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,97 @@ pub unsafe fn int8_gemm_vpdpbusd_zmm(a_u8: &[u8], b_i8: &[i8], c: &mut [i32], m:
215215
}
216216
}
217217

218+
// ═════════════════════════════════════════════════════════════════════
219+
// VPDPBUSD-ymm AVX-VNNI tier (Arrow Lake / Meteor Lake U / Alder Lake)
220+
// ═════════════════════════════════════════════════════════════════════
221+
222+
/// AVX-VNNI ymm `u8 × i8 → i32` GEMM kernel for arbitrary M × N × K.
223+
///
224+
/// One `_mm256_dpbusd_avx_epi32` instruction: 8 i32 accumulator lanes,
225+
/// each receiving the sum of 4 `u8 × i8` products = **32 MACs per
226+
/// instruction**. Half the throughput-per-instruction of the
227+
/// `_mm512_dpbusd_epi32` zmm version (which does 64 MACs); fires on
228+
/// Arrow Lake / Meteor Lake U / Alder Lake silicon that has AVX-VNNI
229+
/// but NOT AVX-512.
230+
///
231+
/// Same B pre-packing scheme as the zmm version (quad-interleaved per
232+
/// 8-wide j-block), same K-tail and N-tail handling, just narrower.
233+
/// Mirrors the `vnni2_dot_u8_i8` shape in `simd_amx.rs` but as a
234+
/// matrix-product instead of single-row dot.
235+
///
236+
/// Output behavior: overwrites `c` (does NOT accumulate). Caller's
237+
/// responsibility to zero `c` first if needed.
238+
///
239+
/// # Safety
240+
/// Caller must have feature-detected `avxvnni + avx2` at runtime.
241+
#[cfg(target_arch = "x86_64")]
242+
#[target_feature(enable = "avxvnni,avx2")]
243+
pub unsafe fn int8_gemm_vpdpbusd_ymm(a_u8: &[u8], b_i8: &[i8], c: &mut [i32], m: usize, n: usize, k: usize) {
244+
use core::arch::x86_64::{
245+
__m256i, _mm256_dpbusd_avx_epi32, _mm256_loadu_si256, _mm256_set1_epi32, _mm256_setzero_si256,
246+
_mm256_storeu_si256,
247+
};
248+
249+
let k_quads = k / 4;
250+
let k_tail = k % 4;
251+
252+
// Pre-pack scratch: 8 i32 lanes per k_quad (vs 16 in the zmm
253+
// version). Same per-lane layout: each i32 holds 4 consecutive
254+
// B K-bytes for output column j+lane.
255+
let mut b_col_quads = vec![0i32; k_quads.max(1) * 8];
256+
let mut out_buf = [0i32; 8];
257+
258+
for j_base in (0..n).step_by(8) {
259+
let j_count = 8.min(n - j_base);
260+
261+
for k_quad in 0..k_quads {
262+
let row0 = 4 * k_quad * n;
263+
let row1 = (4 * k_quad + 1) * n;
264+
let row2 = (4 * k_quad + 2) * n;
265+
let row3 = (4 * k_quad + 3) * n;
266+
for jj in 0..j_count {
267+
let b0 = b_i8[row0 + j_base + jj] as u8 as u32;
268+
let b1 = b_i8[row1 + j_base + jj] as u8 as u32;
269+
let b2 = b_i8[row2 + j_base + jj] as u8 as u32;
270+
let b3 = b_i8[row3 + j_base + jj] as u8 as u32;
271+
b_col_quads[k_quad * 8 + jj] = (b0 | (b1 << 8) | (b2 << 16) | (b3 << 24)) as i32;
272+
}
273+
for jj in j_count..8 {
274+
b_col_quads[k_quad * 8 + jj] = 0;
275+
}
276+
}
277+
278+
for i in 0..m {
279+
let mut acc = _mm256_setzero_si256();
280+
let a_row_off = i * k;
281+
for k_quad in 0..k_quads {
282+
let a0 = a_u8[a_row_off + 4 * k_quad] as u32;
283+
let a1 = a_u8[a_row_off + 4 * k_quad + 1] as u32;
284+
let a2 = a_u8[a_row_off + 4 * k_quad + 2] as u32;
285+
let a3 = a_u8[a_row_off + 4 * k_quad + 3] as u32;
286+
let packed_a = a0 | (a1 << 8) | (a2 << 16) | (a3 << 24);
287+
let a_v = _mm256_set1_epi32(packed_a as i32);
288+
let b_v = _mm256_loadu_si256(b_col_quads.as_ptr().add(k_quad * 8) as *const __m256i);
289+
acc = _mm256_dpbusd_avx_epi32(acc, a_v, b_v);
290+
}
291+
_mm256_storeu_si256(out_buf.as_mut_ptr() as *mut __m256i, acc);
292+
293+
if k_tail > 0 {
294+
for kk in (k_quads * 4)..k {
295+
let a_val = a_u8[a_row_off + kk] as i32;
296+
let tail_row = kk * n;
297+
for jj in 0..j_count {
298+
out_buf[jj] += a_val * b_i8[tail_row + j_base + jj] as i32;
299+
}
300+
}
301+
}
302+
303+
let dst_off = i * n + j_base;
304+
c[dst_off..dst_off + j_count].copy_from_slice(&out_buf[..j_count]);
305+
}
306+
}
307+
}
308+
218309
// ═════════════════════════════════════════════════════════════════════
219310
// Scalar fallback (i32 reference)
220311
// ═════════════════════════════════════════════════════════════════════
@@ -422,6 +513,45 @@ mod tests {
422513
}
423514
}
424515

516+
/// Direct test for the VPDPBUSD-ymm arm (AVX-VNNI tier of
/// `matmul_i8_to_i32`). Same shape / bit-exactness contract as
/// the zmm version's test, just on the narrower 8-wide kernel.
///
/// Skips (with a note on stderr) when the host CPU lacks the features
/// the kernel's safety contract requires.
#[cfg(target_arch = "x86_64")]
#[test]
fn vpdpbusd_ymm_matches_scalar() {
    // The kernel is compiled with `avxvnni,avx2`, so gate on both.
    if !(std::is_x86_feature_detected!("avxvnni") && std::is_x86_feature_detected!("avx2")) {
        eprintln!("avxvnni + avx2 not detected; skipping");
        return;
    }

    /// Scalar row-major `u8 × i8 → i32` reference product.
    fn ref_gemm(a: &[u8], b: &[i8], m: usize, n: usize, k: usize) -> Vec<i32> {
        let mut c = vec![0i32; m * n];
        for i in 0..m {
            for kk in 0..k {
                let av = a[i * k + kk] as i32;
                for j in 0..n {
                    c[i * n + j] += av * b[kk * n + j] as i32;
                }
            }
        }
        c
    }

    // Sweep shapes spanning 8-aligned, K-tail (k % 4), N-tail
    // (n % 8), small shapes, and the degenerate k < 4 / k == 0 cases
    // where no full K-quad exists (tail-only and empty-reduction paths).
    for (m, n, k) in [
        (16, 8, 64),
        (3, 5, 7),
        (17, 33, 100),
        (1, 17, 12),
        (8, 8, 4),
        (2, 9, 3),
        (3, 4, 0),
    ] {
        let a: Vec<u8> = (0..m * k).map(|i| ((i * 31 + 7) % 256) as u8).collect();
        let b: Vec<i8> = (0..k * n)
            .map(|i| ((i * 17 + 3) % 256) as u8 as i8)
            .collect();
        let expected = ref_gemm(&a, &b, m, n, k);
        // Seed with a sentinel rather than zeros so the test also proves
        // the kernel overwrites every output element instead of
        // accumulating into (or skipping) stale data.
        let mut got = vec![i32::MIN; m * n];
        // SAFETY: avxvnni and avx2 confirmed at the top of the test.
        unsafe { int8_gemm_vpdpbusd_ymm(&a, &b, &mut got, m, n, k) };
        assert_eq!(got, expected, "VPDPBUSD-ymm mismatch at (M={}, N={}, K={})", m, n, k);
    }
}
554+
425555
#[test]
426556
fn vnni_pack_i8_roundtrip() {
427557
// Pack then verify the VNNI layout matches the spec:

0 commit comments

Comments
 (0)