Skip to content

Commit 90da43f

Browse files
committed
feat(amx): public ndarray-typed matmul API (sprint A4)
Adds three public entry points and a `MatmulError` enum on top of the existing AMX primitives in `hpc::amx_matmul`: matmul_f32(lhs, rhs, out) f32 x f32 -> f32 matmul_bf16_to_f32(lhs, rhs, out) BF16 x BF16 -> f32 matmul_i8_to_i32(lhs, rhs, out) i8 x i8 -> i32 All three accept `ArrayView2` / `ArrayViewMut2`. Strided inputs are repacked into contiguous staging buffers before the kernel runs; the output must be row-stride-1 (returns `MatmulError::NonContiguousOutput` otherwise). On AMX-enabled hosts the routines drive `TDPBF16PS` / `TDPBUSD` via the existing inline-asm primitives; on hosts without AMX they fall through to `bf16_gemm_f32` / `int8_gemm_i32`. Burn parity item 6. Tests cover 16x16, 17x16 row-tail, 16x65 K-tail, strided LHS via `slice(s![.., ..;2])`, shape-mismatch / non-contiguous-output rejection, and the AMX-unavailable fallback path. 11/11 pass. https://claude.ai/code/session_01NYGrxVopyszZYgLBxe4hgj
1 parent 44c0845 commit 90da43f

1 file changed

Lines changed: 449 additions & 7 deletions

File tree

0 commit comments

Comments
 (0)