|
7 | 7 |
|
8 | 8 | #![allow(non_snake_case)] |
9 | 9 |
|
| 10 | +use crate::{ArrayView2, ArrayViewMut2}; |
10 | 11 | use std::os::raw::{c_double, c_float, c_int, c_long, c_void}; |
11 | 12 |
|
12 | 13 | const CBLAS_ROW_MAJOR: c_int = 101; |
13 | 14 | const CBLAS_NO_TRANS: c_int = 111; |
14 | 15 |
|
| 16 | +// `cblas_gemm_s8u8s32` / `cblas_gemm_s8s8s32` use CBLAS_OFFSET enums for the |
| 17 | +// final argument (offset mode). `RowOffset = 171`, `ColOffset = 172`, |
| 18 | +// `FixOffset = 173` — we always use `FixOffset` with a zero offset, which |
| 19 | +// matches Burn / rustyblas behaviour. |
| 20 | +const CBLAS_OFFSET_FIX: c_int = 173; |
| 21 | + |
15 | 22 | // ═══════════════════════════════════════════════════════════════ |
16 | 23 | // CBLAS (shared API surface with OpenBLAS) |
17 | 24 | // ═══════════════════════════════════════════════════════════════ |
@@ -56,6 +63,32 @@ extern "C" { |
56 | 63 | x: *const c_double, incx: c_int, |
57 | 64 | beta: c_double, y: *mut c_double, incy: c_int, |
58 | 65 | ); |
| 66 | + |
| 67 | + // Mixed-precision GEMM: BF16 inputs, F32 accumulator. |
| 68 | + // MKL takes `*const u16` for BF16 operands (no native bf16 type in C ABI). |
| 69 | + // Reference: oneAPI MKL Developer Reference, "cblas_gemm_bf16bf16f32". |
| 70 | + fn cblas_gemm_bf16bf16f32( |
| 71 | + layout: c_int, transa: c_int, transb: c_int, |
| 72 | + m: c_int, n: c_int, k: c_int, |
| 73 | + alpha: c_float, a: *const u16, lda: c_int, |
| 74 | + b: *const u16, ldb: c_int, |
| 75 | + beta: c_float, c: *mut c_float, ldc: c_int, |
| 76 | + ); |
| 77 | + |
| 78 | + // Integer GEMM: i8 × i8 → i32. |
| 79 | + // The trailing offset arguments take CBLAS_OFFSET (= FixOffset) plus a |
| 80 | + // pointer to the offset value. Passing zero offsets matches a plain |
| 81 | + // matmul without zero-point correction. |
| 82 | + // Reference: oneAPI MKL Developer Reference, "cblas_gemm_s8s8s32". |
| 83 | + fn cblas_gemm_s8s8s32( |
| 84 | + layout: c_int, transa: c_int, transb: c_int, offsetc: c_int, |
| 85 | + m: c_int, n: c_int, k: c_int, |
| 86 | + alpha: c_float, |
| 87 | + a: *const i8, lda: c_int, oa: i8, |
| 88 | + b: *const i8, ldb: c_int, ob: i8, |
| 89 | + beta: c_float, c: *mut i32, ldc: c_int, |
| 90 | + co: *const i32, |
| 91 | + ); |
59 | 92 | } |
60 | 93 |
|
61 | 94 | // ═══════════════════════════════════════════════════════════════ |
@@ -235,3 +268,237 @@ pub const fn sgemm_nr() -> usize { 16 } |
235 | 268 | pub const fn sgemm_mr() -> usize { 6 } |
236 | 269 | pub const fn dgemm_nr() -> usize { 8 } |
237 | 270 | pub const fn dgemm_mr() -> usize { 6 } |
| 271 | + |
| 272 | +// ═══════════════════════════════════════════════════════════════ |
| 273 | +// Public ndarray-shaped GEMM API (Burn integration surface) |
| 274 | +// ═══════════════════════════════════════════════════════════════ |
| 275 | +// |
| 276 | +// These wrappers accept `ArrayView2` / `ArrayViewMut2` and forward to the |
| 277 | +// CBLAS FFI declared above. They handle row-major / column-major layout |
| 278 | +// detection from ndarray strides and return a structured error if the input |
| 279 | +// is non-contiguous along its leading dimension (which CBLAS cannot express). |
| 280 | + |
| 281 | +/// Errors returned from the MKL ndarray-shaped GEMM wrappers. |
| 282 | +#[derive(Debug, Clone, PartialEq, Eq)] |
| 283 | +pub enum MklError { |
| 284 | + /// Inner dimensions of A and B don't match (`A.cols != B.rows`). |
| 285 | + ShapeMismatch { |
| 286 | + a_shape: (usize, usize), |
| 287 | + b_shape: (usize, usize), |
| 288 | + }, |
| 289 | + /// Output `C` dimensions don't match `(A.rows, B.cols)`. |
| 290 | + OutputShapeMismatch { |
| 291 | + expected: (usize, usize), |
| 292 | + got: (usize, usize), |
| 293 | + }, |
| 294 | + /// One of the arrays is not stride-compatible with CBLAS. |
| 295 | + /// |
| 296 | + /// CBLAS requires that one of the two strides is `1` (the contiguous |
| 297 | + /// dimension) and the other is `>= the contiguous extent`. Arbitrary |
| 298 | + /// striding (e.g. from a non-contiguous slice) is not supported — copy |
| 299 | + /// to a contiguous buffer first. |
| 300 | + NonContiguous { which: &'static str }, |
| 301 | + /// The bf16 / int8 routines are not available in this MKL build, or are |
| 302 | + /// stubbed out (e.g. older MKL versions predate `cblas_gemm_*`). |
| 303 | + Unsupported(&'static str), |
| 304 | +} |
| 305 | + |
| 306 | +impl core::fmt::Display for MklError { |
| 307 | + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { |
| 308 | + match self { |
| 309 | + MklError::ShapeMismatch { a_shape, b_shape } => write!( |
| 310 | + f, |
| 311 | + "MKL GEMM shape mismatch: A is {:?}, B is {:?}", |
| 312 | + a_shape, b_shape |
| 313 | + ), |
| 314 | + MklError::OutputShapeMismatch { expected, got } => write!( |
| 315 | + f, |
| 316 | + "MKL GEMM output shape mismatch: expected {:?}, got {:?}", |
| 317 | + expected, got |
| 318 | + ), |
| 319 | + MklError::NonContiguous { which } => { |
| 320 | + write!(f, "MKL GEMM operand `{}` is not stride-compatible with CBLAS", which) |
| 321 | + } |
| 322 | + MklError::Unsupported(msg) => write!(f, "MKL GEMM unsupported: {}", msg), |
| 323 | + } |
| 324 | + } |
| 325 | +} |
| 326 | + |
| 327 | +impl std::error::Error for MklError {} |
| 328 | + |
| 329 | +/// CBLAS layout descriptor extracted from an ndarray view. |
| 330 | +struct BlasLayout { |
| 331 | + layout: c_int, |
| 332 | + trans: c_int, |
| 333 | + ld: c_int, |
| 334 | +} |
| 335 | + |
| 336 | +/// Inspect strides and produce a CBLAS layout/transpose/leading-dimension |
| 337 | +/// triple for a 2D ndarray view. Returns `None` if neither dimension has |
| 338 | +/// stride 1 (i.e. the matrix is non-contiguous in both directions). |
| 339 | +fn blas_layout<S: crate::RawData>(view: &crate::ArrayBase<S, crate::Ix2>) -> Option<BlasLayout> { |
| 340 | + let (rows, cols) = view.dim(); |
| 341 | + let strides = view.strides(); |
| 342 | + let rs = strides[0]; |
| 343 | + let cs = strides[1]; |
| 344 | + // Row-major: stride between rows is the leading dim, columns are stride 1. |
| 345 | + if cs == 1 && (rs >= cols as isize || rows <= 1) { |
| 346 | + return Some(BlasLayout { layout: CBLAS_ROW_MAJOR, trans: CBLAS_NO_TRANS, ld: rs.max(1) as c_int }); |
| 347 | + } |
| 348 | + // Column-major: stride between cols is the leading dim, rows are stride 1. |
| 349 | + // We expose this to CBLAS as a *row-major transposed* matrix so we keep a |
| 350 | + // single `layout` argument across all three operands. |
| 351 | + if rs == 1 && (cs >= rows as isize || cols <= 1) { |
| 352 | + return Some(BlasLayout { layout: CBLAS_ROW_MAJOR, trans: 112 /* CblasTrans */, ld: cs.max(1) as c_int }); |
| 353 | + } |
| 354 | + None |
| 355 | +} |
| 356 | + |
| 357 | +/// `C := alpha * A * B + beta * C` for `f32` matrices via MKL `cblas_sgemm`. |
| 358 | +pub fn sgemm( |
| 359 | + a: ArrayView2<f32>, |
| 360 | + b: ArrayView2<f32>, |
| 361 | + mut c: ArrayViewMut2<f32>, |
| 362 | + alpha: f32, |
| 363 | + beta: f32, |
| 364 | +) -> Result<(), MklError> { |
| 365 | + let (m, k) = a.dim(); |
| 366 | + let (kb, n) = b.dim(); |
| 367 | + if k != kb { |
| 368 | + return Err(MklError::ShapeMismatch { a_shape: a.dim(), b_shape: b.dim() }); |
| 369 | + } |
| 370 | + if c.dim() != (m, n) { |
| 371 | + return Err(MklError::OutputShapeMismatch { expected: (m, n), got: c.dim() }); |
| 372 | + } |
| 373 | + let la = blas_layout(&a).ok_or(MklError::NonContiguous { which: "a" })?; |
| 374 | + let lb = blas_layout(&b).ok_or(MklError::NonContiguous { which: "b" })?; |
| 375 | + let lc = blas_layout(&c).ok_or(MklError::NonContiguous { which: "c" })?; |
| 376 | + if lc.trans != CBLAS_NO_TRANS { |
| 377 | + return Err(MklError::NonContiguous { which: "c" }); |
| 378 | + } |
| 379 | + unsafe { |
| 380 | + cblas_sgemm( |
| 381 | + lc.layout, la.trans, lb.trans, |
| 382 | + m as c_int, n as c_int, k as c_int, |
| 383 | + alpha, a.as_ptr(), la.ld, |
| 384 | + b.as_ptr(), lb.ld, |
| 385 | + beta, c.as_mut_ptr(), lc.ld, |
| 386 | + ); |
| 387 | + } |
| 388 | + Ok(()) |
| 389 | +} |
| 390 | + |
| 391 | +/// `C := alpha * A * B + beta * C` for `f64` matrices via MKL `cblas_dgemm`. |
| 392 | +pub fn dgemm( |
| 393 | + a: ArrayView2<f64>, |
| 394 | + b: ArrayView2<f64>, |
| 395 | + mut c: ArrayViewMut2<f64>, |
| 396 | + alpha: f64, |
| 397 | + beta: f64, |
| 398 | +) -> Result<(), MklError> { |
| 399 | + let (m, k) = a.dim(); |
| 400 | + let (kb, n) = b.dim(); |
| 401 | + if k != kb { |
| 402 | + return Err(MklError::ShapeMismatch { a_shape: a.dim(), b_shape: b.dim() }); |
| 403 | + } |
| 404 | + if c.dim() != (m, n) { |
| 405 | + return Err(MklError::OutputShapeMismatch { expected: (m, n), got: c.dim() }); |
| 406 | + } |
| 407 | + let la = blas_layout(&a).ok_or(MklError::NonContiguous { which: "a" })?; |
| 408 | + let lb = blas_layout(&b).ok_or(MklError::NonContiguous { which: "b" })?; |
| 409 | + let lc = blas_layout(&c).ok_or(MklError::NonContiguous { which: "c" })?; |
| 410 | + if lc.trans != CBLAS_NO_TRANS { |
| 411 | + return Err(MklError::NonContiguous { which: "c" }); |
| 412 | + } |
| 413 | + unsafe { |
| 414 | + cblas_dgemm( |
| 415 | + lc.layout, la.trans, lb.trans, |
| 416 | + m as c_int, n as c_int, k as c_int, |
| 417 | + alpha, a.as_ptr(), la.ld, |
| 418 | + b.as_ptr(), lb.ld, |
| 419 | + beta, c.as_mut_ptr(), lc.ld, |
| 420 | + ); |
| 421 | + } |
| 422 | + Ok(()) |
| 423 | +} |
| 424 | + |
| 425 | +/// `C := alpha * A * B + beta * C` with BF16 inputs and `f32` accumulator, |
| 426 | +/// via MKL `cblas_gemm_bf16bf16f32`. |
| 427 | +/// |
| 428 | +/// This requires Intel MKL >= 2020 (for the bf16 GEMM kernel). On older MKL |
| 429 | +/// builds the symbol is missing and linking will fail at runtime — there is |
| 430 | +/// no compile-time fallback. |
| 431 | +pub fn sgemm_bf16( |
| 432 | + a: ArrayView2<crate::hpc::quantized::BF16>, |
| 433 | + b: ArrayView2<crate::hpc::quantized::BF16>, |
| 434 | + mut c: ArrayViewMut2<f32>, |
| 435 | + alpha: f32, |
| 436 | + beta: f32, |
| 437 | +) -> Result<(), MklError> { |
| 438 | + let (m, k) = a.dim(); |
| 439 | + let (kb, n) = b.dim(); |
| 440 | + if k != kb { |
| 441 | + return Err(MklError::ShapeMismatch { a_shape: a.dim(), b_shape: b.dim() }); |
| 442 | + } |
| 443 | + if c.dim() != (m, n) { |
| 444 | + return Err(MklError::OutputShapeMismatch { expected: (m, n), got: c.dim() }); |
| 445 | + } |
| 446 | + let la = blas_layout(&a).ok_or(MklError::NonContiguous { which: "a" })?; |
| 447 | + let lb = blas_layout(&b).ok_or(MklError::NonContiguous { which: "b" })?; |
| 448 | + let lc = blas_layout(&c).ok_or(MklError::NonContiguous { which: "c" })?; |
| 449 | + if lc.trans != CBLAS_NO_TRANS { |
| 450 | + return Err(MklError::NonContiguous { which: "c" }); |
| 451 | + } |
| 452 | + // BF16 is `#[repr(transparent)] (pub u16)`, so the pointer cast is sound. |
| 453 | + unsafe { |
| 454 | + cblas_gemm_bf16bf16f32( |
| 455 | + lc.layout, la.trans, lb.trans, |
| 456 | + m as c_int, n as c_int, k as c_int, |
| 457 | + alpha, |
| 458 | + a.as_ptr() as *const u16, la.ld, |
| 459 | + b.as_ptr() as *const u16, lb.ld, |
| 460 | + beta, c.as_mut_ptr(), lc.ld, |
| 461 | + ); |
| 462 | + } |
| 463 | + Ok(()) |
| 464 | +} |
| 465 | + |
| 466 | +/// `C := A * B` with `i8` inputs and `i32` accumulator, via MKL |
| 467 | +/// `cblas_gemm_s8s8s32` with zero offsets (no zero-point correction). |
| 468 | +/// |
| 469 | +/// Note: alpha/beta are fixed at `1.0` / `0.0` for the simple `Burn`-style |
| 470 | +/// signature. If you need scaling, call the FFI directly. This requires |
| 471 | +/// Intel MKL >= 2018 (when integer GEMM was introduced). |
| 472 | +pub fn sgemm_int8( |
| 473 | + a: ArrayView2<i8>, |
| 474 | + b: ArrayView2<i8>, |
| 475 | + mut c: ArrayViewMut2<i32>, |
| 476 | +) -> Result<(), MklError> { |
| 477 | + let (m, k) = a.dim(); |
| 478 | + let (kb, n) = b.dim(); |
| 479 | + if k != kb { |
| 480 | + return Err(MklError::ShapeMismatch { a_shape: a.dim(), b_shape: b.dim() }); |
| 481 | + } |
| 482 | + if c.dim() != (m, n) { |
| 483 | + return Err(MklError::OutputShapeMismatch { expected: (m, n), got: c.dim() }); |
| 484 | + } |
| 485 | + let la = blas_layout(&a).ok_or(MklError::NonContiguous { which: "a" })?; |
| 486 | + let lb = blas_layout(&b).ok_or(MklError::NonContiguous { which: "b" })?; |
| 487 | + let lc = blas_layout(&c).ok_or(MklError::NonContiguous { which: "c" })?; |
| 488 | + if lc.trans != CBLAS_NO_TRANS { |
| 489 | + return Err(MklError::NonContiguous { which: "c" }); |
| 490 | + } |
| 491 | + let co: i32 = 0; |
| 492 | + unsafe { |
| 493 | + cblas_gemm_s8s8s32( |
| 494 | + lc.layout, la.trans, lb.trans, CBLAS_OFFSET_FIX, |
| 495 | + m as c_int, n as c_int, k as c_int, |
| 496 | + 1.0_f32, |
| 497 | + a.as_ptr(), la.ld, 0_i8, |
| 498 | + b.as_ptr(), lb.ld, 0_i8, |
| 499 | + 0.0_f32, c.as_mut_ptr(), lc.ld, |
| 500 | + &co as *const i32, |
| 501 | + ); |
| 502 | + } |
| 503 | + Ok(()) |
| 504 | +} |
0 commit comments