Commit 90da43f
committed
feat(amx): public ndarray-typed matmul API (sprint A4)
Adds three public entry points and a `MatmulError` enum on top of the
existing AMX primitives in `hpc::amx_matmul`:
matmul_f32(lhs, rhs, out) f32 x f32 -> f32
matmul_bf16_to_f32(lhs, rhs, out) BF16 x BF16 -> f32
matmul_i8_to_i32(lhs, rhs, out) i8 x i8 -> i32
All three accept `ArrayView2` / `ArrayViewMut2`. Strided inputs are
repacked into contiguous staging buffers before the kernel runs; the
output must be row-stride-1 (returns `MatmulError::NonContiguousOutput`
otherwise). On AMX-enabled hosts the routines drive `TDPBF16PS` /
`TDPBUSD` via the existing inline-asm primitives; on hosts without AMX
they fall through to `bf16_gemm_f32` / `int8_gemm_i32`. Burn parity
item 6.
Tests cover 16x16, 17x16 row-tail, 16x65 K-tail, strided LHS via
`slice(s![.., ..;2])`, shape-mismatch / non-contiguous-output rejection,
and the AMX-unavailable fallback path. 11/11 pass.
https://claude.ai/code/session_01NYGrxVopyszZYgLBxe4hgj1 parent 44c0845 commit 90da43f
1 file changed
Lines changed: 449 additions & 7 deletions
0 commit comments