Skip to content

Commit ddf0905

Browse files
authored
Merge pull request #184 from AdaWorldAPI/claude/continue-ndarray-x0Oaw
hpc: TD-T2 — AMX TDPBUSD tile kernel + matmul_i8_to_i32 wiring
2 parents 098c5aa + f987937 commit ddf0905

5 files changed

Lines changed: 519 additions & 26 deletions

File tree

src/hpc/amx_matmul.rs

Lines changed: 112 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -94,13 +94,32 @@ pub unsafe fn tile_release() {
9494

9595
/// Load tile from memory.
9696
///
97+
/// Encoding: `TILELOADD tmmN, [rcx + rax]` is VEX `C4 E2 7B 4B /r` with
98+
/// a SIB byte selecting `[rcx + rax]`. The ModR/M `/r` field encodes the
99+
/// destination tile via `reg = N` (3-bit tile index). Per-tile bytes:
100+
///
101+
/// tmm0: C4 E2 7B 4B **04** 08
102+
/// tmm1: C4 E2 7B 4B **0C** 08
103+
/// tmm2: C4 E2 7B 4B **14** 08
104+
///
105+
/// `04 | (N << 3)` gives the ModR/M byte; the `08` SIB (scale=1, index=rcx, base=rax) is the same
106+
/// across tiles. tmm0 was added when codex flagged the accumulator-
107+
/// preservation bug on PR #184 (`tile_zero(0)` + `tile_store(0, c)`
108+
/// discarded any pre-existing C values — the fix is `tile_load(0, c)`
109+
/// instead of `tile_zero(0)` so TDPBUSD/TDPBF16PS truly accumulate as
110+
/// the documented `C += A·B` semantics promise).
111+
///
97112
/// # Safety
98113
/// Pointer must be valid, stride must match tile config.
99114
#[inline]
100115
pub unsafe fn tile_load(tile: u8, ptr: *const u8, stride: usize) {
101116
match tile {
102-
// TILELOADD tmm0, [ptr + stride*row]
103-
// Encoding: VEX.128.F2.0F38.W0 4B /r with memory operand
117+
0 => asm!(
118+
".byte 0xc4, 0xe2, 0x7b, 0x4b, 0x04, 0x08",
119+
in("rcx") ptr,
120+
in("rax") stride,
121+
options(nostack),
122+
),
104123
1 => asm!(
105124
".byte 0xc4, 0xe2, 0x7b, 0x4b, 0x0c, 0x08",
106125
in("rcx") ptr,
@@ -193,6 +212,32 @@ pub fn vnni_pack_bf16(src: &[u16], dst: &mut [u16], k: usize, n: usize) {
193212
}
194213
}
195214

215+
/// Pack a row-major `B[K, N]` i8 matrix into `K/4` rows of `N` VNNI quads
/// (4 consecutive-K bytes per column) for `TDPBUSD`.
///
/// Output layout:
///
///   dst[kb*N*4 + j*4 + p] = src[(4*kb + p) * N + j]
///
/// i.e. packed row `kb` holds, for each column `j`, the four values
/// `B[4*kb .. 4*kb + 4, j]` stored contiguously. For N=16 (AMX tile
/// width) each packed row is 16 quads = 64 bytes, matching the 64-byte
/// tile row width. The same byte layout serves `u8` operands: packing
/// only moves bytes, sign interpretation happens inside `TDPBUSD`.
///
/// # Panics
/// Panics if `src.len() != k * n`, if `dst.len() != k * n`, or if `k` is
/// not a multiple of 4. These checks are O(1) and are enforced in release
/// builds as well: with a ragged `k` the trailing `k % 4` source rows
/// would otherwise be silently dropped, leaving stale bytes in `dst`.
#[inline]
pub fn vnni_pack_i8(src: &[i8], dst: &mut [i8], k: usize, n: usize) {
    assert_eq!(src.len(), k * n);
    assert_eq!(dst.len(), k * n);
    assert_eq!(k % 4, 0, "K must be multiple of 4 for VNNI INT8 quads");
    if n == 0 {
        // Nothing to pack; also keeps `chunks_exact(0)` from panicking below.
        return;
    }

    // One K-block = 4 source rows (4*n bytes) = one packed output row.
    let block_len = 4 * n;
    let blocks = src
        .chunks_exact(block_len)
        .zip(dst.chunks_exact_mut(block_len));
    for (src_block, dst_row) in blocks {
        for (j, quad) in dst_row.chunks_exact_mut(4).enumerate() {
            for (p, out) in quad.iter_mut().enumerate() {
                // Row `p` of this K-block, column `j`.
                *out = src_block[p * n + j];
            }
        }
    }
}
240+
196241
// ═══════════════════════════════════════════════════════════════════════════
197242
// Public ndarray-typed matmul API (sprint A4 / Burn parity item 6)
198243
// ═══════════════════════════════════════════════════════════════════════════
@@ -207,7 +252,7 @@ pub fn vnni_pack_bf16(src: &[u16], dst: &mut [u16], k: usize, n: usize) {
207252
// strided (e.g. `view.slice(s![.., ..;2])`). Strided inputs are repacked
208253
// into contiguous staging buffers before the kernel runs.
209254

210-
use crate::hpc::quantized::{bf16_gemm_f32, int8_gemm_i32, BF16};
255+
use crate::hpc::quantized::{bf16_gemm_f32, BF16};
211256
use crate::{ArrayView2, ArrayViewMut2};
212257

213258
/// Errors returned by the public AMX matmul API.
@@ -537,14 +582,17 @@ pub fn matmul_f32(
537582

538583
/// Matrix multiply i8 × i8 → i32: `out = lhs · rhs`.
539584
///
540-
/// On AMX hosts uses `TDPBUSD` (256 MACs/instr); otherwise falls back to
541-
/// the scalar `int8_gemm_i32`.
585+
/// On AMX hosts with 16/16/64-aligned shapes uses `TDPBUSD` via the
586+
/// 16×16 tile kernel in [`crate::hpc::int8_tile_gemm::int8_tile_gemm_16x16`]
587+
/// — 16 384 MACs per instruction. Otherwise AVX-512 VNNI hosts use `VPDPBUSD` (any
588+
/// shape), and everything else falls back to the scalar i8×i8 → i32 reference.
542589
///
543-
/// Note: `TDPBUSD` natively expects unsigned-by-signed (u8 × i8). For the
544-
/// signed-by-signed surface required here, the LHS is shifted into the
545-
/// unsigned domain and the bias subtracted from the accumulator (only on
546-
/// the AMX path; the scalar path operates directly in i8). The public
547-
/// result is identical.
590+
/// Note: `TDPBUSD` natively expects unsigned-by-signed (u8 × i8). For
591+
/// the signed-by-signed surface required here, the LHS is shifted into
592+
/// the unsigned domain (i8 + 128 → u8) and the bias `128 · sum(B[:, j]
593+
/// over k)` is subtracted from the accumulator. The public result is
594+
/// bit-identical to the scalar reference because all arithmetic stays
595+
/// in i32 (no float rounding).
548596
///
549597
/// `out` must be row-contiguous; inputs may be strided.
550598
pub fn matmul_i8_to_i32(
@@ -556,26 +604,45 @@ pub fn matmul_i8_to_i32(
556604
let b_i8 = pack_contig(&rhs);
557605
let mut c = vec![0i32; m * n];
558606

559-
if amx_available() {
560-
// AMX TDPBUSD path: shift LHS i8 → u8 via (+128) and subtract the
561-
// bias 128·sum(B[:, j] over k) afterwards. This keeps numerics exact.
607+
if amx_available() && m % 16 == 0 && n % 16 == 0 && k % 64 == 0 {
608+
// Tier 1 — AMX TDPBUSD tile path: shift LHS i8 → u8 (+128),
609+
// tile-GEMM via int8_tile_gemm_16x16, subtract bias.
562610
let a_u8: Vec<u8> = a_i8.iter().map(|&v| (v as i32 + 128) as u8).collect();
563611

564-
// Compute C' = A_u8 · B_i8 in i32, then subtract 128 · colsum(B).
565-
int8_gemm_i32(&a_u8, &b_i8, &mut c, m, n, k);
566-
let mut colsum = vec![0i32; n];
567-
for p in 0..k {
568-
for j in 0..n {
569-
colsum[j] += b_i8[p * n + j] as i32;
612+
let mut b_tile = vec![0i8; k * 16];
613+
let mut tile_c = vec![0i32; 256];
614+
615+
for j_tile in (0..n).step_by(16) {
616+
for kk in 0..k {
617+
let row = kk * n + j_tile;
618+
b_tile[kk * 16..(kk + 1) * 16]
619+
.copy_from_slice(unsafe { core::slice::from_raw_parts(b_i8.as_ptr().add(row), 16) });
570620
}
571-
}
572-
for i in 0..m {
573-
for j in 0..n {
574-
c[i * n + j] -= 128 * colsum[j];
621+
for i_tile in (0..m).step_by(16) {
622+
let a_tile = &a_u8[i_tile * k..(i_tile + 16) * k];
623+
tile_c.fill(0);
624+
crate::hpc::int8_tile_gemm::int8_tile_gemm_16x16(a_tile, &b_tile, &mut tile_c, k);
625+
for ii in 0..16 {
626+
let dst_off = (i_tile + ii) * n + j_tile;
627+
c[dst_off..dst_off + 16].copy_from_slice(&tile_c[ii * 16..(ii + 1) * 16]);
628+
}
575629
}
576630
}
631+
subtract_i8_to_u8_bias(&mut c, &b_i8, m, n, k);
632+
} else if cfg!(target_arch = "x86_64") && std::is_x86_feature_detected!("avx512vnni") {
633+
// Tier 2 — AVX-512 VPDPBUSD zmm: 64 MACs per instruction, no
634+
// shape-alignment requirement (M/N/K all handled via per-block
635+
// trim and scalar K-tail). Same sign-shift bias trick as AMX.
636+
let a_u8: Vec<u8> = a_i8.iter().map(|&v| (v as i32 + 128) as u8).collect();
637+
// SAFETY: runtime feature-detected avx512vnni above.
638+
unsafe {
639+
crate::hpc::int8_tile_gemm::int8_gemm_vpdpbusd_zmm(&a_u8, &b_i8, &mut c, m, n, k);
640+
}
641+
subtract_i8_to_u8_bias(&mut c, &b_i8, m, n, k);
577642
} else {
578-
// Scalar i8×i8 → i32 reference.
643+
// Tier 3 — Scalar i8×i8 → i32 reference for non-x86 hosts,
644+
// or x86 hosts without AVX-512 VNNI (Tier 2 has no alignment
645+
// requirement, so shape alone never lands here).
579646
for i in 0..m {
580647
for p in 0..k {
581648
let av = a_i8[i * k + p] as i32;
@@ -597,6 +664,27 @@ pub fn matmul_i8_to_i32(
597664
Ok(())
598665
}
599666

667+
/// Subtract `128 · colsum(B[:, j])` from every `c[i, j]`, in place.
///
/// Undoes the LHS sign-shift bias: shifting A from i8 to u8 via `+128`
/// gives `A_u8 · B = (A_i8 + 128) · B = A_i8 · B + 128 · sum_k B[k, j]`,
/// so removing the per-column term recovers `A_i8 · B`.
///
/// * `c` — row-major `m × n` accumulator; only the first `m * n` elements
///   are touched.
/// * `b_i8` — row-major `k × n` RHS; only the first `k * n` elements are read.
///
/// All arithmetic is wrapping (mod 2³²). The correction is therefore exact
/// whenever the true `A_i8 · B` value fits in `i32`, even if the biased
/// u8×i8 accumulator wrapped on the way, and it behaves the same in debug
/// and release builds (plain `-=` would trip the debug overflow check).
///
/// # Panics
/// Panics if `c.len() < m * n` or `b_i8.len() < k * n`.
fn subtract_i8_to_u8_bias(c: &mut [i32], b_i8: &[i8], m: usize, n: usize, k: usize) {
    let c = &mut c[..m * n];
    let b = &b_i8[..k * n];
    if m == 0 || n == 0 {
        // Nothing to correct; also keeps `chunks_exact(0)` from panicking.
        return;
    }

    // Per-column sum of B over k.
    let mut bias = vec![0i32; n];
    for b_row in b.chunks_exact(n) {
        for (acc, &bv) in bias.iter_mut().zip(b_row) {
            *acc = acc.wrapping_add(i32::from(bv));
        }
    }
    // Scale once per column instead of once per output element.
    for acc in &mut bias {
        *acc = acc.wrapping_mul(128);
    }

    for c_row in c.chunks_exact_mut(n) {
        for (cv, &bv) in c_row.iter_mut().zip(&bias) {
            *cv = cv.wrapping_sub(bv);
        }
    }
}
687+
600688
#[cfg(test)]
601689
mod tests {
602690
use super::*;

src/hpc/bf16_tile_gemm.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,23 @@ pub fn bf16_tile_gemm_16x16(a_bf16: &[u16], b_bf16: &[u16], c: &mut [f32], k: us
6060
// ═════════════════════════════════════════════════════════════════════
6161

6262
/// AMX tile GEMM. B must be pre-VNNI-packed (see `vnni_pack_bf16`).
63+
/// **Accumulates** into the caller's `c` buffer — matches the
64+
/// documented `C += A·B` semantics. The C tile (tmm0) is preloaded
65+
/// from `c` before the TDPBF16PS loop so any pre-existing values are
66+
/// preserved. (Same accumulator-preservation fix the int8 sibling
67+
/// got after codex P1 on PR #184: prior `tile_zero(0)` discarded
68+
/// pre-existing C values even though docs promised accumulation.)
69+
///
6370
/// # Safety
6471
/// Caller must have verified `amx_available() == true`.
6572
#[inline]
6673
unsafe fn amx_path(a_bf16: &[u16], b_vnni: &[u16], c: &mut [f32], k: usize) {
6774
// Tile config: shapes at K_bytes=64 match BF16 K=32 case
6875
let cfg = TileConfig::for_dpbusd(64);
6976
tile_loadconfig(&cfg);
70-
tile_zero(0);
77+
// Preload C accumulator from caller's buffer (was tile_zero(0)
78+
// pre-fix — see method-level note above).
79+
tile_load(0, c.as_ptr() as *const u8, 64);
7180

7281
// Accumulate over K/32 tile blocks
7382
let k_blocks = k / 32;

0 commit comments

Comments
 (0)