@@ -94,13 +94,32 @@ pub unsafe fn tile_release() {
9494
9595/// Load tile from memory.
9696///
97+ /// Encoding: `TILELOADD tmmN, [rcx + rax]` is VEX `C4 E2 7B 4B /r` with
98+ /// a SIB byte selecting `[rcx + rax]`. The ModR/M `/r` field encodes the
99+ /// destination tile via `reg = N` (3-bit tile index). Per-tile bytes:
100+ ///
101+ /// tmm0: C4 E2 7B 4B **04** 08
102+ /// tmm1: C4 E2 7B 4B **0C** 08
103+ /// tmm2: C4 E2 7B 4B **14** 08
104+ ///
105+ /// `04 | (N << 3)` gives the ModR/M byte; the `08` SIB is the same
106+ /// across tiles. tmm0 was added when codex flagged the accumulator-
107+ /// preservation bug on PR #184 (`tile_zero(0)` + `tile_store(0, c)`
108+ /// discarded any pre-existing C values — the fix is `tile_load(0, c)`
109+ /// instead of `tile_zero(0)` so TDPBUSD/TDPBF16PS truly accumulate as
110+ /// the documented `C += A·B` semantics promise).
111+ ///
97112/// # Safety
98113/// Pointer must be valid, stride must match tile config.
99114#[ inline]
100115pub unsafe fn tile_load ( tile : u8 , ptr : * const u8 , stride : usize ) {
101116 match tile {
102- // TILELOADD tmm0, [ptr + stride*row]
103- // Encoding: VEX.128.F2.0F38.W0 4B /r with memory operand
117+ 0 => asm ! (
118+ ".byte 0xc4, 0xe2, 0x7b, 0x4b, 0x04, 0x08" ,
119+ in( "rcx" ) ptr,
120+ in( "rax" ) stride,
121+ options( nostack) ,
122+ ) ,
104123 1 => asm ! (
105124 ".byte 0xc4, 0xe2, 0x7b, 0x4b, 0x0c, 0x08" ,
106125 in( "rcx" ) ptr,
@@ -193,6 +212,32 @@ pub fn vnni_pack_bf16(src: &[u16], dst: &mut [u16], k: usize, n: usize) {
193212 }
194213}
195214
215+ /// Pack B[K, N] i8 row-major into K/4 × (N*4) VNNI quads for `TDPBUSD`.
216+ ///
217+ /// Output layout required by `TDPBUSD` tile 2 (16 rows × 64 bytes):
218+ /// dst[kb*N*4 + j*4 + p] = src[(4*kb + p) * N + j]
219+ ///
220+ /// For N=16 (AMX tile width), each output "row" holds 16 i8 quads = 64
221+ /// bytes (matches the 64-byte tile row width). K must be a multiple of
222+ /// 4. The same layout is used for `u8` operands (just bit-cast through
223+ /// — VNNI doesn't care about sign at the packing layer; sign
224+ /// interpretation happens inside TDPBUSD which treats A as u8 and B
225+ /// as i8 for the multiply).
226+ #[ inline]
227+ pub fn vnni_pack_i8 ( src : & [ i8 ] , dst : & mut [ i8 ] , k : usize , n : usize ) {
228+ debug_assert_eq ! ( src. len( ) , k * n) ;
229+ debug_assert_eq ! ( dst. len( ) , k * n) ;
230+ debug_assert_eq ! ( k % 4 , 0 , "K must be multiple of 4 for VNNI INT8 quads" ) ;
231+ for kb in 0 ..( k / 4 ) {
232+ let dst_row = kb * n * 4 ;
233+ for j in 0 ..n {
234+ for p in 0 ..4 {
235+ dst[ dst_row + j * 4 + p] = src[ ( 4 * kb + p) * n + j] ;
236+ }
237+ }
238+ }
239+ }
240+
196241// ═══════════════════════════════════════════════════════════════════════════
197242// Public ndarray-typed matmul API (sprint A4 / Burn parity item 6)
198243// ═══════════════════════════════════════════════════════════════════════════
@@ -207,7 +252,7 @@ pub fn vnni_pack_bf16(src: &[u16], dst: &mut [u16], k: usize, n: usize) {
207252// strided (e.g. `view.slice(s![.., ..;2])`). Strided inputs are repacked
208253// into contiguous staging buffers before the kernel runs.
209254
210- use crate :: hpc:: quantized:: { bf16_gemm_f32, int8_gemm_i32 , BF16 } ;
255+ use crate :: hpc:: quantized:: { bf16_gemm_f32, BF16 } ;
211256use crate :: { ArrayView2 , ArrayViewMut2 } ;
212257
213258/// Errors returned by the public AMX matmul API.
@@ -537,14 +582,17 @@ pub fn matmul_f32(
537582
538583/// Matrix multiply i8 × i8 → i32: `out = lhs · rhs`.
539584///
540- /// On AMX hosts uses `TDPBUSD` (256 MACs/instr); otherwise falls back to
541- /// the scalar `int8_gemm_i32`.
585+ /// On AMX hosts with 16/16/64-aligned shapes uses `TDPBUSD` via the
586+ /// 16×16 tile kernel in [`crate::hpc::int8_tile_gemm::int8_tile_gemm_16x16`]
587+ /// — 16 384 MACs per instruction. Mis-aligned shapes (or non-AMX hosts)
588+ /// fall back to the scalar i8×i8 → i32 reference.
542589///
543- /// Note: `TDPBUSD` natively expects unsigned-by-signed (u8 × i8). For the
544- /// signed-by-signed surface required here, the LHS is shifted into the
545- /// unsigned domain and the bias subtracted from the accumulator (only on
546- /// the AMX path; the scalar path operates directly in i8). The public
547- /// result is identical.
590+ /// Note: `TDPBUSD` natively expects unsigned-by-signed (u8 × i8). For
591+ /// the signed-by-signed surface required here, the LHS is shifted into
592+ /// the unsigned domain (i8 + 128 → u8) and the bias `128 · sum(B[:, j]
593+ /// over k)` is subtracted from the accumulator. The public result is
594+ /// bit-identical to the scalar reference because all arithmetic stays
595+ /// in i32 (no float rounding).
548596///
549597/// `out` must be row-contiguous; inputs may be strided.
550598pub fn matmul_i8_to_i32 (
@@ -556,26 +604,45 @@ pub fn matmul_i8_to_i32(
556604 let b_i8 = pack_contig ( & rhs) ;
557605 let mut c = vec ! [ 0i32 ; m * n] ;
558606
559- if amx_available ( ) {
560- // AMX TDPBUSD path: shift LHS i8 → u8 via (+128) and subtract the
561- // bias 128·sum(B[:, j] over k) afterwards. This keeps numerics exact .
607+ if amx_available ( ) && m % 16 == 0 && n % 16 == 0 && k % 64 == 0 {
608+ // Tier 1 — AMX TDPBUSD tile path: shift LHS i8 → u8 (+128),
609+ // tile-GEMM via int8_tile_gemm_16x16, subtract bias .
562610 let a_u8: Vec < u8 > = a_i8. iter ( ) . map ( |& v| ( v as i32 + 128 ) as u8 ) . collect ( ) ;
563611
564- // Compute C' = A_u8 · B_i8 in i32, then subtract 128 · colsum(B).
565- int8_gemm_i32 ( & a_u8, & b_i8, & mut c, m, n, k) ;
566- let mut colsum = vec ! [ 0i32 ; n] ;
567- for p in 0 ..k {
568- for j in 0 ..n {
569- colsum[ j] += b_i8[ p * n + j] as i32 ;
612+ let mut b_tile = vec ! [ 0i8 ; k * 16 ] ;
613+ let mut tile_c = vec ! [ 0i32 ; 256 ] ;
614+
615+ for j_tile in ( 0 ..n) . step_by ( 16 ) {
616+ for kk in 0 ..k {
617+ let row = kk * n + j_tile;
618+ b_tile[ kk * 16 ..( kk + 1 ) * 16 ]
619+ . copy_from_slice ( unsafe { core:: slice:: from_raw_parts ( b_i8. as_ptr ( ) . add ( row) , 16 ) } ) ;
570620 }
571- }
572- for i in 0 ..m {
573- for j in 0 ..n {
574- c[ i * n + j] -= 128 * colsum[ j] ;
621+ for i_tile in ( 0 ..m) . step_by ( 16 ) {
622+ let a_tile = & a_u8[ i_tile * k..( i_tile + 16 ) * k] ;
623+ tile_c. fill ( 0 ) ;
624+ crate :: hpc:: int8_tile_gemm:: int8_tile_gemm_16x16 ( a_tile, & b_tile, & mut tile_c, k) ;
625+ for ii in 0 ..16 {
626+ let dst_off = ( i_tile + ii) * n + j_tile;
627+ c[ dst_off..dst_off + 16 ] . copy_from_slice ( & tile_c[ ii * 16 ..( ii + 1 ) * 16 ] ) ;
628+ }
575629 }
576630 }
631+ subtract_i8_to_u8_bias ( & mut c, & b_i8, m, n, k) ;
632+ } else if cfg ! ( target_arch = "x86_64" ) && std:: is_x86_feature_detected!( "avx512vnni" ) {
633+ // Tier 2 — AVX-512 VPDPBUSD zmm: 64 MACs per instruction, no
634+ // shape-alignment requirement (M/N/K all handled via per-block
635+ // trim and scalar K-tail). Same sign-shift bias trick as AMX.
636+ let a_u8: Vec < u8 > = a_i8. iter ( ) . map ( |& v| ( v as i32 + 128 ) as u8 ) . collect ( ) ;
637+ // SAFETY: runtime feature-detected avx512vnni above.
638+ unsafe {
639+ crate :: hpc:: int8_tile_gemm:: int8_gemm_vpdpbusd_zmm ( & a_u8, & b_i8, & mut c, m, n, k) ;
640+ }
641+ subtract_i8_to_u8_bias ( & mut c, & b_i8, m, n, k) ;
577642 } else {
578- // Scalar i8×i8 → i32 reference.
643+ // Tier 3 — Scalar i8×i8 → i32 reference for non-x86 hosts,
644+ // pre-AVX-512 silicon, or shapes that don't satisfy either of
645+ // the SIMD tiers' alignment requirements.
579646 for i in 0 ..m {
580647 for p in 0 ..k {
581648 let av = a_i8[ i * k + p] as i32 ;
@@ -597,6 +664,27 @@ pub fn matmul_i8_to_i32(
597664 Ok ( ( ) )
598665}
599666
667+ /// Subtract `128 · colsum(B[:, j])` from each `c[i, j]` lane.
668+ ///
669+ /// Used by both the AMX and AVX-512-VNNI arms of `matmul_i8_to_i32`
670+ /// to undo the LHS sign-shift bias (A_i8 → A_u8 via +128 means
671+ /// `A_u8 · B = (A_i8 + 128) · B = A_i8 · B + 128 · sum_k B[k, j]`).
672+ /// Pure integer arithmetic, no rounding — the public result is
673+ /// bit-identical to the scalar i8 × i8 → i32 reference.
674+ fn subtract_i8_to_u8_bias ( c : & mut [ i32 ] , b_i8 : & [ i8 ] , m : usize , n : usize , k : usize ) {
675+ let mut colsum = vec ! [ 0i32 ; n] ;
676+ for p in 0 ..k {
677+ for j in 0 ..n {
678+ colsum[ j] += b_i8[ p * n + j] as i32 ;
679+ }
680+ }
681+ for i in 0 ..m {
682+ for j in 0 ..n {
683+ c[ i * n + j] -= 128 * colsum[ j] ;
684+ }
685+ }
686+ }
687+
600688#[ cfg( test) ]
601689mod tests {
602690 use super :: * ;
0 commit comments