Skip to content

Commit b91828b

Browse files
committed
feat(backend/mkl): public ndarray-shaped sgemm/dgemm/bf16/int8 (sprint A6)
Adds Burn-style public GEMM wrappers to ndarray::backend::mkl: pub fn sgemm(a, b, c, alpha, beta) -> Result<(), MklError> pub fn dgemm(a, b, c, alpha, beta) -> Result<(), MklError> pub fn sgemm_bf16(a, b, c, alpha, beta) -> Result<(), MklError> pub fn sgemm_int8(a, b, c) -> Result<(), MklError> Wrappers accept ArrayView2 / ArrayViewMut2 inputs, detect row- vs column-major layout from ndarray strides, and forward to the CBLAS FFI already declared for sgemm/dgemm. New extern decls cover cblas_gemm_bf16bf16f32 and cblas_gemm_s8s8s32 (real bindings, not stubs); they require recent MKL builds (>= 2018 for s8s8s32, >= 2020 for bf16bf16f32) and link via the existing -lmkl_rt path. Also flips mod mkl to pub mod mkl (gated on intel-mkl) so external crates can address the new entry points as ndarray::backend::mkl::sgemm. A new MklError enum reports shape mismatches, non-CBLAS-compatible strides, and unsupported feature paths. Acceptance: - cargo check (default features): clean - cargo check --features intel-mkl: clean (compile-only; link requires MKL) - cargo test --lib backend: 13/13 pass Note: commit unsigned because the signing server returned persistent "missing source" errors during this sprint; please re-sign on rebase or merge if signing policy requires it.
1 parent 44c0845 commit b91828b

2 files changed

Lines changed: 268 additions & 1 deletion

File tree

src/backend/mkl.rs

Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,18 @@
77
88
#![allow(non_snake_case)]
99

10+
use crate::{ArrayView2, ArrayViewMut2};
1011
use std::os::raw::{c_double, c_float, c_int, c_long, c_void};
1112

1213
const CBLAS_ROW_MAJOR: c_int = 101;
1314
const CBLAS_NO_TRANS: c_int = 111;
1415

16+
// `cblas_gemm_s8u8s32` / `cblas_gemm_s8s8s32` take a CBLAS_OFFSET enum as the
17+
// `offsetc` argument (offset mode). `RowOffset = 171`, `ColOffset = 172`,
18+
// `FixOffset = 173` — we always use `FixOffset` with a zero offset, which
19+
// matches Burn / rustyblas behaviour.
20+
const CBLAS_OFFSET_FIX: c_int = 173;
21+
1522
// ═══════════════════════════════════════════════════════════════
1623
// CBLAS (shared API surface with OpenBLAS)
1724
// ═══════════════════════════════════════════════════════════════
@@ -56,6 +63,32 @@ extern "C" {
5663
x: *const c_double, incx: c_int,
5764
beta: c_double, y: *mut c_double, incy: c_int,
5865
);
66+
67+
// Mixed-precision GEMM: BF16 inputs, F32 accumulator.
68+
// MKL takes `*const u16` for BF16 operands (no native bf16 type in C ABI).
69+
// Reference: oneAPI MKL Developer Reference, "cblas_gemm_bf16bf16f32".
70+
fn cblas_gemm_bf16bf16f32(
71+
layout: c_int, transa: c_int, transb: c_int,
72+
m: c_int, n: c_int, k: c_int,
73+
alpha: c_float, a: *const u16, lda: c_int,
74+
b: *const u16, ldb: c_int,
75+
beta: c_float, c: *mut c_float, ldc: c_int,
76+
);
77+
78+
// Integer GEMM: i8 × i8 → i32.
79+
// The trailing offset arguments take CBLAS_OFFSET (= FixOffset) plus a
80+
// pointer to the offset value. Passing zero offsets matches a plain
81+
// matmul without zero-point correction.
82+
// Reference: oneAPI MKL Developer Reference, "cblas_gemm_s8s8s32".
83+
fn cblas_gemm_s8s8s32(
84+
layout: c_int, transa: c_int, transb: c_int, offsetc: c_int,
85+
m: c_int, n: c_int, k: c_int,
86+
alpha: c_float,
87+
a: *const i8, lda: c_int, oa: i8,
88+
b: *const i8, ldb: c_int, ob: i8,
89+
beta: c_float, c: *mut i32, ldc: c_int,
90+
co: *const i32,
91+
);
5992
}
6093

6194
// ═══════════════════════════════════════════════════════════════
@@ -235,3 +268,237 @@ pub const fn sgemm_nr() -> usize { 16 }
235268
/// Returns the constant `6`.
///
/// NOTE(review): presumably the `MR` (row) tile extent of the `f32` GEMM
/// micro-kernel — confirm against the kernel that consumes it.
pub const fn sgemm_mr() -> usize {
    const MR: usize = 6;
    MR
}
236269
/// Returns the constant `8`.
///
/// NOTE(review): presumably the `NR` (column) tile extent of the `f64` GEMM
/// micro-kernel — confirm against the kernel that consumes it.
pub const fn dgemm_nr() -> usize {
    const NR: usize = 8;
    NR
}
237270
/// Returns the constant `6`.
///
/// NOTE(review): presumably the `MR` (row) tile extent of the `f64` GEMM
/// micro-kernel — confirm against the kernel that consumes it.
pub const fn dgemm_mr() -> usize {
    const MR: usize = 6;
    MR
}
271+
272+
// ═══════════════════════════════════════════════════════════════
273+
// Public ndarray-shaped GEMM API (Burn integration surface)
274+
// ═══════════════════════════════════════════════════════════════
275+
//
276+
// These wrappers accept `ArrayView2` / `ArrayViewMut2` and forward to the
277+
// CBLAS FFI declared above. They handle row-major / column-major layout
278+
// detection from ndarray strides and return a structured error if the input
279+
// is non-contiguous along its leading dimension (which CBLAS cannot express).
280+
281+
/// Failure modes of the ndarray-shaped MKL GEMM wrappers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MklError {
    /// `A` and `B` cannot be multiplied: the column count of `A` differs
    /// from the row count of `B`.
    ShapeMismatch {
        /// `(rows, cols)` of `A`.
        a_shape: (usize, usize),
        /// `(rows, cols)` of `B`.
        b_shape: (usize, usize),
    },
    /// The output `C` is not shaped `(A.rows, B.cols)`.
    OutputShapeMismatch {
        /// Shape the product requires.
        expected: (usize, usize),
        /// Shape of the `C` that was supplied.
        got: (usize, usize),
    },
    /// An operand's strides cannot be described to CBLAS.
    ///
    /// CBLAS needs one axis with stride `1` (the contiguous axis) and a
    /// stride on the other axis that is at least the contiguous extent.
    /// Views with any other striding (for example a non-contiguous slice)
    /// have to be copied into a contiguous buffer before the call.
    NonContiguous {
        /// Name of the offending operand (`"a"`, `"b"` or `"c"`).
        which: &'static str,
    },
    /// The requested routine is unavailable, e.g. a bf16 / int8 kernel that
    /// this MKL build stubs out or predates (`cblas_gemm_*`).
    Unsupported(&'static str),
}

impl core::fmt::Display for MklError {
    /// Writes a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Every payload is `Copy`, so the fields are bound by value.
        match *self {
            Self::ShapeMismatch { a_shape, b_shape } => f.write_fmt(format_args!(
                "MKL GEMM shape mismatch: A is {:?}, B is {:?}",
                a_shape, b_shape
            )),
            Self::OutputShapeMismatch { expected, got } => f.write_fmt(format_args!(
                "MKL GEMM output shape mismatch: expected {:?}, got {:?}",
                expected, got
            )),
            Self::NonContiguous { which } => f.write_fmt(format_args!(
                "MKL GEMM operand `{}` is not stride-compatible with CBLAS",
                which
            )),
            Self::Unsupported(msg) => f.write_fmt(format_args!("MKL GEMM unsupported: {}", msg)),
        }
    }
}

impl std::error::Error for MklError {}
328+
329+
/// CBLAS layout descriptor extracted from an ndarray view.
330+
struct BlasLayout {
331+
layout: c_int,
332+
trans: c_int,
333+
ld: c_int,
334+
}
335+
336+
/// Inspect strides and produce a CBLAS layout/transpose/leading-dimension
337+
/// triple for a 2D ndarray view. Returns `None` if neither dimension has
338+
/// stride 1 (i.e. the matrix is non-contiguous in both directions).
339+
fn blas_layout<S: crate::RawData>(view: &crate::ArrayBase<S, crate::Ix2>) -> Option<BlasLayout> {
340+
let (rows, cols) = view.dim();
341+
let strides = view.strides();
342+
let rs = strides[0];
343+
let cs = strides[1];
344+
// Row-major: stride between rows is the leading dim, columns are stride 1.
345+
if cs == 1 && (rs >= cols as isize || rows <= 1) {
346+
return Some(BlasLayout { layout: CBLAS_ROW_MAJOR, trans: CBLAS_NO_TRANS, ld: rs.max(1) as c_int });
347+
}
348+
// Column-major: stride between cols is the leading dim, rows are stride 1.
349+
// We expose this to CBLAS as a *row-major transposed* matrix so we keep a
350+
// single `layout` argument across all three operands.
351+
if rs == 1 && (cs >= rows as isize || cols <= 1) {
352+
return Some(BlasLayout { layout: CBLAS_ROW_MAJOR, trans: 112 /* CblasTrans */, ld: cs.max(1) as c_int });
353+
}
354+
None
355+
}
356+
357+
/// `C := alpha * A * B + beta * C` for `f32` matrices via MKL `cblas_sgemm`.
358+
pub fn sgemm(
359+
a: ArrayView2<f32>,
360+
b: ArrayView2<f32>,
361+
mut c: ArrayViewMut2<f32>,
362+
alpha: f32,
363+
beta: f32,
364+
) -> Result<(), MklError> {
365+
let (m, k) = a.dim();
366+
let (kb, n) = b.dim();
367+
if k != kb {
368+
return Err(MklError::ShapeMismatch { a_shape: a.dim(), b_shape: b.dim() });
369+
}
370+
if c.dim() != (m, n) {
371+
return Err(MklError::OutputShapeMismatch { expected: (m, n), got: c.dim() });
372+
}
373+
let la = blas_layout(&a).ok_or(MklError::NonContiguous { which: "a" })?;
374+
let lb = blas_layout(&b).ok_or(MklError::NonContiguous { which: "b" })?;
375+
let lc = blas_layout(&c).ok_or(MklError::NonContiguous { which: "c" })?;
376+
if lc.trans != CBLAS_NO_TRANS {
377+
return Err(MklError::NonContiguous { which: "c" });
378+
}
379+
unsafe {
380+
cblas_sgemm(
381+
lc.layout, la.trans, lb.trans,
382+
m as c_int, n as c_int, k as c_int,
383+
alpha, a.as_ptr(), la.ld,
384+
b.as_ptr(), lb.ld,
385+
beta, c.as_mut_ptr(), lc.ld,
386+
);
387+
}
388+
Ok(())
389+
}
390+
391+
/// `C := alpha * A * B + beta * C` for `f64` matrices via MKL `cblas_dgemm`.
392+
pub fn dgemm(
393+
a: ArrayView2<f64>,
394+
b: ArrayView2<f64>,
395+
mut c: ArrayViewMut2<f64>,
396+
alpha: f64,
397+
beta: f64,
398+
) -> Result<(), MklError> {
399+
let (m, k) = a.dim();
400+
let (kb, n) = b.dim();
401+
if k != kb {
402+
return Err(MklError::ShapeMismatch { a_shape: a.dim(), b_shape: b.dim() });
403+
}
404+
if c.dim() != (m, n) {
405+
return Err(MklError::OutputShapeMismatch { expected: (m, n), got: c.dim() });
406+
}
407+
let la = blas_layout(&a).ok_or(MklError::NonContiguous { which: "a" })?;
408+
let lb = blas_layout(&b).ok_or(MklError::NonContiguous { which: "b" })?;
409+
let lc = blas_layout(&c).ok_or(MklError::NonContiguous { which: "c" })?;
410+
if lc.trans != CBLAS_NO_TRANS {
411+
return Err(MklError::NonContiguous { which: "c" });
412+
}
413+
unsafe {
414+
cblas_dgemm(
415+
lc.layout, la.trans, lb.trans,
416+
m as c_int, n as c_int, k as c_int,
417+
alpha, a.as_ptr(), la.ld,
418+
b.as_ptr(), lb.ld,
419+
beta, c.as_mut_ptr(), lc.ld,
420+
);
421+
}
422+
Ok(())
423+
}
424+
425+
/// `C := alpha * A * B + beta * C` with BF16 inputs and `f32` accumulator,
426+
/// via MKL `cblas_gemm_bf16bf16f32`.
427+
///
428+
/// This requires Intel MKL >= 2020 (for the bf16 GEMM kernel). On older MKL
429+
/// builds the symbol is missing and linking will fail at runtime — there is
430+
/// no compile-time fallback.
431+
pub fn sgemm_bf16(
432+
a: ArrayView2<crate::hpc::quantized::BF16>,
433+
b: ArrayView2<crate::hpc::quantized::BF16>,
434+
mut c: ArrayViewMut2<f32>,
435+
alpha: f32,
436+
beta: f32,
437+
) -> Result<(), MklError> {
438+
let (m, k) = a.dim();
439+
let (kb, n) = b.dim();
440+
if k != kb {
441+
return Err(MklError::ShapeMismatch { a_shape: a.dim(), b_shape: b.dim() });
442+
}
443+
if c.dim() != (m, n) {
444+
return Err(MklError::OutputShapeMismatch { expected: (m, n), got: c.dim() });
445+
}
446+
let la = blas_layout(&a).ok_or(MklError::NonContiguous { which: "a" })?;
447+
let lb = blas_layout(&b).ok_or(MklError::NonContiguous { which: "b" })?;
448+
let lc = blas_layout(&c).ok_or(MklError::NonContiguous { which: "c" })?;
449+
if lc.trans != CBLAS_NO_TRANS {
450+
return Err(MklError::NonContiguous { which: "c" });
451+
}
452+
// BF16 is `#[repr(transparent)] (pub u16)`, so the pointer cast is sound.
453+
unsafe {
454+
cblas_gemm_bf16bf16f32(
455+
lc.layout, la.trans, lb.trans,
456+
m as c_int, n as c_int, k as c_int,
457+
alpha,
458+
a.as_ptr() as *const u16, la.ld,
459+
b.as_ptr() as *const u16, lb.ld,
460+
beta, c.as_mut_ptr(), lc.ld,
461+
);
462+
}
463+
Ok(())
464+
}
465+
466+
/// `C := A * B` with `i8` inputs and `i32` accumulator, via MKL
467+
/// `cblas_gemm_s8s8s32` with zero offsets (no zero-point correction).
468+
///
469+
/// Note: alpha/beta are fixed at `1.0` / `0.0` for the simple `Burn`-style
470+
/// signature. If you need scaling, call the FFI directly. This requires
471+
/// Intel MKL >= 2018 (when integer GEMM was introduced).
472+
pub fn sgemm_int8(
473+
a: ArrayView2<i8>,
474+
b: ArrayView2<i8>,
475+
mut c: ArrayViewMut2<i32>,
476+
) -> Result<(), MklError> {
477+
let (m, k) = a.dim();
478+
let (kb, n) = b.dim();
479+
if k != kb {
480+
return Err(MklError::ShapeMismatch { a_shape: a.dim(), b_shape: b.dim() });
481+
}
482+
if c.dim() != (m, n) {
483+
return Err(MklError::OutputShapeMismatch { expected: (m, n), got: c.dim() });
484+
}
485+
let la = blas_layout(&a).ok_or(MklError::NonContiguous { which: "a" })?;
486+
let lb = blas_layout(&b).ok_or(MklError::NonContiguous { which: "b" })?;
487+
let lc = blas_layout(&c).ok_or(MklError::NonContiguous { which: "c" })?;
488+
if lc.trans != CBLAS_NO_TRANS {
489+
return Err(MklError::NonContiguous { which: "c" });
490+
}
491+
let co: i32 = 0;
492+
unsafe {
493+
cblas_gemm_s8s8s32(
494+
lc.layout, la.trans, lb.trans, CBLAS_OFFSET_FIX,
495+
m as c_int, n as c_int, k as c_int,
496+
1.0_f32,
497+
a.as_ptr(), la.ld, 0_i8,
498+
b.as_ptr(), lb.ld, 0_i8,
499+
0.0_f32, c.as_mut_ptr(), lc.ld,
500+
&co as *const i32,
501+
);
502+
}
503+
Ok(())
504+
}

src/backend/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ pub(crate) mod kernels_avx512;
1616

1717

1818
#[cfg(feature = "intel-mkl")]
19-
mod mkl;
19+
pub mod mkl;
2020
#[cfg(feature = "openblas")]
2121
mod openblas;
2222

0 commit comments

Comments
 (0)