Skip to content

Commit 6609f10

Browse files
authored
Merge pull request #104 from AdaWorldAPI/claude/teleport-session-setup-wMZfb
Add BF16 tile GEMM with AMX/AVX-512 dispatch
2 parents 00a3c16 + 77ac069 commit 6609f10

3 files changed

Lines changed: 243 additions & 0 deletions

File tree

src/hpc/amx_matmul.rs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,50 @@ pub unsafe fn tile_dpbusd() {
149149
asm!(".byte 0xc4, 0xe2, 0x73, 0x5e, 0xc1", options(nostack, nomem));
150150
}
151151

152+
/// TDPBF16PS: C += A(bf16) × B(bf16_vnni) → f32.
/// tmm0 += tmm1 × tmm2.
///
/// With full-size tiles (M = N = 16, K = 32) a single instruction performs
/// 16 × 16 × 32 = 8192 bf16 multiply-adds into the f32 accumulator tile.
///
/// Encoding (`VEX.128.F3.0F38.W0 5C /r`), byte by byte:
///   C4   3-byte VEX escape
///   E2   inverted R/X/B = 111, opcode map = 0F38
///   6A   W = 0, vvvv = !2 = 1101 (second source = tmm2), L = 0, pp = 10 (F3)
///   5C   opcode
///   C1   ModRM: mod = 11, reg = 0 (accumulator tmm0), rm = 1 (first source tmm1)
///
/// The second source tile lives in the *inverted* `vvvv` field. The byte
/// `0x72` has vvvv = 1110, which selects tmm1 and would encode
/// `TDPBF16PS tmm0, tmm1, tmm1`; AMX requires the three tile operands to be
/// distinct and raises #UD otherwise, so the byte must be `0x6A`.
///
/// Tile shapes at K=32, M=N=16:
///   tmm0 (C): 16×16 f32        (16 rows × 64 bytes)
///   tmm1 (A): 16×32 bf16       (16 rows × 64 bytes, plain row-major)
///   tmm2 (B): 16×16 bf16 pairs (K/2=16 rows × 64 bytes, VNNI pairs)
///
/// # Safety
/// Tiles 0/1/2 must be configured (e.g. via
/// `tile_loadconfig(&TileConfig::for_dpbusd(64))`) and loaded with valid data;
/// AMX must be supported by the CPU and enabled by the OS (check
/// `amx_available()`), otherwise executing this instruction faults.
#[inline]
pub unsafe fn tile_dpbf16ps() {
    // Raw bytes are used so the crate builds with assemblers that lack AMX
    // mnemonics. The instruction only touches tile registers, hence `nomem`.
    asm!(".byte 0xc4, 0xe2, 0x6a, 0x5c, 0xc1", options(nostack, nomem));
}
174+
175+
/// Repack a row-major bf16 matrix `B[K, N]` from `src` into the VNNI pair
/// layout in the separate buffer `dst` (`K/2` rows of `2 * N` elements).
///
/// Each output row interleaves two consecutive source rows column by column,
/// which is the layout TDPBF16PS expects for its second source tile:
///   dst[i, 2j]   = src[2i,   j]
///   dst[i, 2j+1] = src[2i+1, j]
///
/// For N=16 (AMX tile width) every output row holds 16 bf16 pairs = 64 bytes.
///
/// `src` and `dst` must both hold `k * n` elements and `k` must be even;
/// these preconditions are checked with `debug_assert!` only.
#[inline]
pub fn vnni_pack_bf16(src: &[u16], dst: &mut [u16], k: usize, n: usize) {
    debug_assert_eq!(src.len(), k * n);
    debug_assert_eq!(dst.len(), k * n);
    debug_assert_eq!(k % 2, 0, "K must be even for VNNI BF16 pairs");

    let pair_rows = k / 2;
    for pair in 0..pair_rows {
        // Source rows 2*pair and 2*pair+1 collapse into output row `pair`,
        // which starts at the same flat offset as the even source row.
        let even_base = 2 * pair * n;
        let odd_base = (2 * pair + 1) * n;
        let out_base = pair * n * 2;
        for col in 0..n {
            let slot = out_base + 2 * col;
            dst[slot] = src[even_base + col];
            dst[slot + 1] = src[odd_base + col];
        }
    }
}
195+
152196
#[cfg(test)]
153197
mod tests {
154198
use super::*;

src/hpc/bf16_tile_gemm.rs

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
//! BF16 tile GEMM polyfill — AMX (TDPBF16PS) with AVX-512 F32x16 fallback.
2+
//!
3+
//! Same API, runtime tier dispatch via `amx_available()`. The AMX path uses
4+
//! the raw primitives in `hpc::amx_matmul`. The fallback decodes BF16→f32
5+
//! and uses `crate::simd::F32x16` + `mul_add` (VFMADD231PS on AVX-512,
6+
//! emulated as 2× F32x8 FMA on AVX2).
7+
//!
8+
//! Pattern: one dispatch check per call; caller supplies preallocated
9+
//! output and row-major B (the AMX path VNNI-packs B internally).
10+
//!
11+
//! Tile shape: M=16, N=16, K = multiple of 32.
12+
//!
13+
//! Usage:
14+
//! ```ignore
15+
//! use ndarray::hpc::bf16_tile_gemm::bf16_tile_gemm_16x16;
16+
//! let mut c = vec![0.0f32; 16*16];
17+
//! bf16_tile_gemm_16x16(&a_bf16, &b_bf16_row_major, &mut c, k);
18+
//! ```
19+
20+
use crate::hpc::amx_matmul::{
21+
amx_available, TileConfig, tile_loadconfig, tile_zero,
22+
tile_load, tile_store, tile_release, tile_dpbf16ps, vnni_pack_bf16,
23+
};
24+
use crate::simd::{F32x16, bf16_to_f32_batch};
25+
26+
// ═════════════════════════════════════════════════════════════════════
27+
// Public API — safe dispatching wrapper
28+
// ═════════════════════════════════════════════════════════════════════
29+
30+
/// Compute C[16, 16] += A[16, K] × B[K, 16] where A, B are BF16 row-major
31+
/// and C is f32 row-major. K must be a multiple of 32.
32+
///
33+
/// Tier dispatch (runtime):
34+
/// AMX available → TDPBF16PS tile GEMM (16×16 × K/32 tile iterations)
35+
/// AMX unavailable → AVX-512 F32x16 FMA fallback (decode BF16→f32, gemm)
36+
///
37+
/// Both paths produce identical results up to BF16 precision (~1/128 per
38+
/// multiply, O(sqrt(K)) accumulated).
39+
pub fn bf16_tile_gemm_16x16(a_bf16: &[u16], b_bf16: &[u16], c: &mut [f32], k: usize) {
40+
assert_eq!(k % 32, 0, "K must be multiple of 32");
41+
assert_eq!(a_bf16.len(), 16 * k);
42+
assert_eq!(b_bf16.len(), k * 16);
43+
assert_eq!(c.len(), 16 * 16);
44+
45+
if amx_available() {
46+
// AMX path: pack B into VNNI, call tile GEMM
47+
let mut b_vnni = vec![0u16; k * 16];
48+
vnni_pack_bf16(b_bf16, &mut b_vnni, k, 16);
49+
// SAFETY: amx_available() just confirmed CPUID + XCR0 + prctl.
50+
unsafe { amx_path(a_bf16, &b_vnni, c, k); }
51+
} else {
52+
fallback_path(a_bf16, b_bf16, c, k);
53+
}
54+
}
55+
56+
// ═════════════════════════════════════════════════════════════════════
57+
// AMX path (TDPBF16PS)
58+
// ═════════════════════════════════════════════════════════════════════
59+
60+
/// AMX tile GEMM. B must be pre-VNNI-packed (see `vnni_pack_bf16`).
61+
/// # Safety
62+
/// Caller must have verified `amx_available() == true`.
63+
#[inline]
64+
unsafe fn amx_path(a_bf16: &[u16], b_vnni: &[u16], c: &mut [f32], k: usize) {
65+
// Tile config: shapes at K_bytes=64 match BF16 K=32 case
66+
let cfg = TileConfig::for_dpbusd(64);
67+
tile_loadconfig(&cfg);
68+
tile_zero(0);
69+
70+
// Accumulate over K/32 tile blocks
71+
let k_blocks = k / 32;
72+
let a_stride = (k * 2) as usize; // full A row stride in bytes (bf16 = 2B)
73+
let b_stride = 64usize; // VNNI row stride in bytes
74+
75+
for kb in 0..k_blocks {
76+
let a_ptr = a_bf16.as_ptr().add(kb * 32) as *const u8;
77+
let b_ptr = b_vnni.as_ptr().add(kb * 16 * 32) as *const u8;
78+
tile_load(1, a_ptr, a_stride);
79+
tile_load(2, b_ptr, b_stride);
80+
tile_dpbf16ps();
81+
}
82+
83+
tile_store(0, c.as_mut_ptr() as *mut u8, 64);
84+
tile_release();
85+
}
86+
87+
// ═════════════════════════════════════════════════════════════════════
88+
// AVX-512 fallback (F32x16 + mul_add FMA)
89+
// ═════════════════════════════════════════════════════════════════════
90+
91+
/// Fallback: decode BF16→f32 and run a tight F32x16 GEMM with mul_add FMA.
92+
/// When AVX-512 is the compile-time baseline, this uses native __m512 FMA;
93+
/// on AVX2 it uses the emulated F32x16 = (F32x8, F32x8) pair — same logic.
94+
fn fallback_path(a_bf16: &[u16], b_bf16: &[u16], c: &mut [f32], k: usize) {
95+
// Decode BF16 → f32 (batch via SIMD when avx512bf16 / avx2 available)
96+
let mut a_f32 = vec![0.0f32; a_bf16.len()];
97+
let mut b_f32 = vec![0.0f32; b_bf16.len()];
98+
bf16_to_f32_batch(a_bf16, &mut a_f32);
99+
bf16_to_f32_batch(b_bf16, &mut b_f32);
100+
101+
// Tight GEMM: for each output (i,j), dot row-of-A with col-of-B via F32x16+FMA.
102+
// B is row-major [K, 16]; j-th column is b_f32[kk*16 + j] over kk=0..K.
103+
// We gather the column into a stack-sized buffer once per (i,j) pair to hit
104+
// the chunks_exact(16) + mul_add fast path on contiguous memory.
105+
for i in 0..16 {
106+
let a_row = &a_f32[i * k .. i * k + k];
107+
for j in 0..16 {
108+
// Stream the column into a contiguous buffer
109+
let mut col = vec![0.0f32; k];
110+
for kk in 0..k { col[kk] = b_f32[kk * 16 + j]; }
111+
112+
// Accumulate via F32x16::mul_add (FMA)
113+
let mut acc = F32x16::splat(0.0);
114+
for (ra, rb) in a_row.chunks_exact(16).zip(col.chunks_exact(16)) {
115+
let av = F32x16::from_slice(ra);
116+
let bv = F32x16::from_slice(rb);
117+
acc = av.mul_add(bv, acc);
118+
}
119+
c[i * 16 + j] += acc.reduce_sum();
120+
}
121+
}
122+
}
123+
124+
// ═════════════════════════════════════════════════════════════════════
125+
// Tests
126+
// ═════════════════════════════════════════════════════════════════════
127+
128+
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simd::{f32_to_bf16_batch, bf16_to_f32_batch};

    /// Scalar reference GEMM (f32-accumulated), used as ground truth.
    /// Overwrites `c` with A[16, k] × B[k, 16].
    fn ref_gemm(a: &[f32], b: &[f32], c: &mut [f32], k: usize) {
        for i in 0..16 {
            for j in 0..16 {
                let mut s = 0.0f32;
                for kk in 0..k {
                    s += a[i * k + kk] * b[kk * 16 + j];
                }
                c[i * 16 + j] = s;
            }
        }
    }

    /// Fill `buf` with deterministic pseudo-random values in [-1, 1).
    ///
    /// Uses a 32-bit LCG and keeps the top 24 bits, which convert to f32
    /// exactly. Shifting by 8 and then dividing by 2^31 would confine the
    /// values to ±1/256, making every product so small that an absolute
    /// tolerance of 1e-3 can never fail.
    fn lcg_fill(buf: &mut [f32], seed: u32) {
        let mut state = seed;
        for v in buf.iter_mut() {
            state = state.wrapping_mul(1664525).wrapping_add(1013904223);
            *v = (state >> 8) as f32 / 8388608.0 - 1.0;
        }
    }

    #[test]
    fn fallback_matches_scalar_reference_k64() {
        let k = 64;
        // Deterministic pseudo-random inputs spanning [-1, 1)
        let mut a_f32 = vec![0.0f32; 16 * k];
        let mut b_f32 = vec![0.0f32; k * 16];
        lcg_fill(&mut a_f32, 0x1234_5678);
        lcg_fill(&mut b_f32, 0x9e37_79b9);

        let mut a_bf16 = vec![0u16; a_f32.len()];
        let mut b_bf16 = vec![0u16; b_f32.len()];
        f32_to_bf16_batch(&a_f32, &mut a_bf16);
        f32_to_bf16_batch(&b_f32, &mut b_bf16);

        // Reference uses the bf16-rounded inputs (matches what the GEMM sees)
        let mut a_back = vec![0.0f32; a_f32.len()];
        let mut b_back = vec![0.0f32; b_f32.len()];
        bf16_to_f32_batch(&a_bf16, &mut a_back);
        bf16_to_f32_batch(&b_bf16, &mut b_back);
        let mut c_ref = vec![0.0f32; 16 * 16];
        ref_gemm(&a_back, &b_back, &mut c_ref, k);

        // Guard against a vacuous comparison: the reference must contain
        // values well above the tolerance used below.
        let max_ref = c_ref.iter().fold(0.0f32, |m, v| m.max(v.abs()));
        assert!(max_ref > 0.1, "reference output too small: max |c_ref| = {}", max_ref);

        // Fallback GEMM
        let mut c_fb = vec![0.0f32; 16 * 16];
        fallback_path(&a_bf16, &b_bf16, &mut c_fb, k);

        // Compare. Only the f32 summation order differs between the two.
        let mut max_err = 0.0f32;
        for i in 0..(16 * 16) {
            let e = (c_fb[i] - c_ref[i]).abs();
            if e > max_err { max_err = e; }
        }
        assert!(max_err < 1e-3, "fallback vs scalar ref max_err = {}", max_err);
    }

    #[test]
    fn fallback_accumulates_into_existing_c() {
        // fallback_path adds into C, so zero inputs must leave C untouched.
        let k = 32;
        let a = vec![0u16; 16 * k];
        let b = vec![0u16; k * 16];
        let mut c = vec![1.5f32; 16 * 16];
        fallback_path(&a, &b, &mut c, k);
        for v in c.iter() { assert_eq!(*v, 1.5); }
    }

    #[test]
    fn public_api_runs_on_any_hardware() {
        // Sanity only: the public API must not panic whether or not AMX is
        // present. Which path runs depends on the host.
        let k = 32;
        let a = vec![0u16; 16 * k];
        let b = vec![0u16; k * 16];
        let mut c = vec![0.0f32; 16 * 16];
        bf16_tile_gemm_16x16(&a, &b, &mut c, k);
        // All zeros × all zeros = 0
        for v in c.iter() { assert_eq!(*v, 0.0); }
    }
}

src/hpc/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ pub mod cascade;
5555
pub mod heel_f64x8;
5656
#[allow(missing_docs)]
5757
pub mod amx_matmul;
58+
pub mod bf16_tile_gemm;
5859
#[allow(missing_docs)]
5960
pub mod bf16_truth;
6061
#[allow(missing_docs)]

0 commit comments

Comments
 (0)