|
| 1 | +//! NEON + BF16 tier — ARMv8.6-A `+bf16` (or ARMv8.2-A + optional `+bf16`). |
| 2 | +//! |
| 3 | +//! Builds on `simd_neon_dotprod.rs`. Adds the BF16 instruction family: |
| 4 | +//! BFDOT, BFMMLA, BFMLALB, BFMLALT, BFCVT. These are the bf16 cousins |
| 5 | +//! of dotprod: the same widening dot-product shape that dotprod gives |
| 6 | +//! int8, applied to bfloat16 (half the width of f32), the type that LLM inference standardized on. |
| 7 | +//! |
| 8 | +//! # Silicon |
| 9 | +//! |
| 10 | +//! - **Apple M2 / M3 / M4** (Avalanche/Blizzard, Everest/Sawtooth, |
| 11 | +//! Tupai/Donan) — ARMv8.6-A+. BF16 always on. `sysctl |
| 12 | +//! hw.optional.arm.FEAT_BF16` returns 1. M1 does NOT have BF16 — it's |
| 13 | +//! ARMv8.5-A. |
| 14 | +//! - **Snapdragon X Elite / X Plus** (Qualcomm Oryon cores, not Arm |
| 15 | +//! Cortex-X): ARMv8.7-A. BF16 always on. |
| 16 | +//! - **Cortex-A510 / A520 / A710 / A720 / X2 / X3 / X4 / X925** — |
| 17 | +//! ARMv9.0-A+. BF16 always on. |
| 18 | +//! - **NVIDIA Grace** (Neoverse V2) — ARMv9-A. BF16 on. |
| 19 | +//! - **AWS Graviton 3 / 3E / 4** (Neoverse V1/V2) — V1 added BF16 as |
| 20 | +//! optional ARMv8.4-A extension; V2 makes it mandatory. |
| 21 | +//! - **Ampere One (M-series)** — ARMv8.6-A+. BF16 on. |
| 22 | +//! |
| 23 | +//! # NOT in this tier |
| 24 | +//! |
| 25 | +//! - Apple M1 (ARMv8.5-A, no BF16) — falls back to `simd_neon_dotprod.rs` |
| 26 | +//! - Raspberry Pi 5 (Cortex-A76, ARMv8.2-A, no BF16) — `simd_neon_dotprod.rs` |
| 27 | +//! - Any Pi 3/4 / Cortex-A53/A72 — `simd_neon_baseline.rs` |
| 28 | +//! |
| 29 | +//! # How to detect at runtime |
| 30 | +//! |
| 31 | +//! - **Linux**: `/proc/cpuinfo` Features line should show `bf16`. |
| 32 | +//! `getauxval(AT_HWCAP2) & HWCAP2_BF16` (bit 14). |
| 33 | +//! `std::arch::is_aarch64_feature_detected!("bf16")` — recommended. |
| 34 | +//! - **macOS**: `sysctl hw.optional.arm.FEAT_BF16` → `1` means yes. |
| 35 | +//! On M2+ it's always 1; on M1 it's 0. |
| 36 | +//! - **Windows ARM64**: `IsProcessorFeaturePresent(PF_ARM_V83_BF16)` |
| 37 | +//! (constant added in Win11 24H2 SDK). |
| 38 | +//! |
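The runtime checks listed above can be folded into one dispatch function. The following is a minimal sketch, assuming a hypothetical `NeonTier` enum and `detect_neon_tier` function (neither is part of this module); it relies only on `std::arch::is_aarch64_feature_detected!`, which the list above recommends, and falls back to the baseline tier on any other architecture.

```rust
/// Hypothetical tier selector; the variant names mirror the tier files
/// named in the module docs and are not an existing API.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NeonTier {
    Baseline,
    DotProd,
    Bf16,
}

/// Picks the highest tier the running CPU reports.
/// On non-aarch64 targets the feature checks are compiled out and the
/// function always returns `NeonTier::Baseline`.
pub fn detect_neon_tier() -> NeonTier {
    #[cfg(target_arch = "aarch64")]
    {
        // The macro consults the OS (HWCAP on Linux, sysctl on macOS)
        // and caches the answer after the first call.
        if std::arch::is_aarch64_feature_detected!("bf16") {
            return NeonTier::Bf16;
        }
        if std::arch::is_aarch64_feature_detected!("dotprod") {
            return NeonTier::DotProd;
        }
    }
    NeonTier::Baseline
}
```

Calling `detect_neon_tier()` once at startup and storing the result keeps the hot loops free of feature checks.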
| 39 | +//! # How to detect at compile time |
| 40 | +//! |
| 41 | +//! Cargo config flags: |
| 42 | +//! - `-Ctarget-feature=+bf16` — enables BF16 intrinsics + cfg gate. |
| 43 | +//! - `-Ctarget-cpu=apple-m2` — implies bf16 + everything else. |
| 44 | +//! - `-Ctarget-cpu=neoverse-v2` — Graviton 4 baseline. |
| 45 | +//! - `-Ctarget-cpu=oryon-1` — Snapdragon X Elite (Oryon cores; needs an LLVM that knows `oryon-1`). `cortex-x4` targets Arm's Cortex-X4 and enables SVE2. |
| 46 | +//! |
| 47 | +//! Inside Rust: |
| 48 | +//! |
| 49 | +//! ```ignore |
| 50 | +//! #[cfg(all(target_arch = "aarch64", target_feature = "bf16"))] |
| 51 | +//! pub use crate::simd_neon_bf16::{BF16x8, BF16x16, bfdot, bfmmla}; |
| 52 | +//! ``` |
| 53 | +//! |
| 54 | +//! # What you get |
| 55 | +//! |
| 56 | +//! ## BF16 dot-product / matrix-multiply |
| 57 | +//! |
| 58 | +//! - `vbfdotq_f32(acc, a, b)`: 4×(2×bf16·2×bf16) → 4×f32. Each f32 lane |
| 59 | +//! accumulates a two-element bf16 dot product. The bf16 analogue of `vdotq_s32`. |
| 60 | +//! - `vbfmmlaq_f32(acc, a, b)`: BFMMLA, a 2×4 by 4×2 bf16 matrix multiply. The crown |
| 61 | +//! jewel for transformer GEMM: accumulates a full 2×2 f32 tile per |
| 62 | +//! instruction, i.e. 16 bf16 mults summed into 4 f32 accumulators. |
| 63 | +//! - `vbfmlalbq_f32` / `vbfmlaltq_f32` — bottom / top half multiply- |
| 64 | +//! accumulate, lane-by-lane variant of BFDOT. |
| 65 | +//! - `vbfmlalbq_lane_f32` / `vbfmlalbq_laneq_f32`: multiply by one bf16 |
| 66 | +//! lane of `b`, broadcast across all multiplications. Useful for matvec. |
| 67 | +//! |
| 68 | +//! ## BF16 conversion |
| 69 | +//! |
| 70 | +//! - `vcvt_bf16_f32` / `vcvtq_low_bf16_f32` / `vcvtq_high_bf16_f32` — |
| 71 | +//! pack 4×f32 → 4×bf16. Hardware rounding (no manual RNE needed |
| 72 | +//! like the AVX-512BF16 `_mm512_cvtne2ps_pbh` path in |
| 73 | +//! `simd_avx512.rs`). |
| 74 | +//! - Scalar f32 ↔ bf16: trivial high-16-bit slice (the scalar paths in |
| 75 | +//! `src/simd.rs:1604-1626` work everywhere, including this tier). |
| 76 | +//! |
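The scalar conversion described above can be sketched as follows. The function names are hypothetical and this is not the code at `src/simd.rs:1604-1626`; the truncating version is the "high-16-bit slice" itself, and the round-to-nearest-even version is a common software variant included only for comparison, with no claim that it is bit-identical to what BFCVT produces.

```rust
/// Widen raw bf16 bits to f32: bf16 is the top 16 bits of an f32,
/// so the conversion is a shift and is always exact.
pub fn bf16_bits_to_f32(bits: u16) -> f32 {
    f32::from_bits((bits as u32) << 16)
}

/// Narrow f32 to bf16 by dropping the low 16 mantissa bits (truncation).
pub fn f32_to_bf16_bits_truncate(x: f32) -> u16 {
    (x.to_bits() >> 16) as u16
}

/// Narrow f32 to bf16 with round-to-nearest-even, done in software.
pub fn f32_to_bf16_bits_rne(x: f32) -> u16 {
    let bits = x.to_bits();
    if x.is_nan() {
        // Force a quiet-NaN bit so the payload cannot round to infinity.
        return ((bits >> 16) as u16) | 0x0040;
    }
    // Add 0x7FFF plus the lowest kept bit: exact ties round to even.
    let rounding = 0x7FFF + ((bits >> 16) & 1);
    (bits.wrapping_add(rounding) >> 16) as u16
}
```

The two narrowing functions agree whenever the low 16 bits of the f32 are zero, which is why values that started life as bf16 round-trip exactly.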
| 77 | +//! # Composed wrapper shapes |
| 78 | +//! |
| 79 | +//! - `BF16x8` = `bfloat16x8_t` — native 128-bit register, 8 bf16 lanes. |
| 80 | +//! Matches AVX-512BF16 `BF16x8 = __m128bh` in shape. |
| 81 | +//! - `BF16x16` = `[bfloat16x8_t; 2]` — two 128-bit registers, 16 bf16 |
| 82 | +//! lanes. Matches AVX-512BF16 `BF16x16 = __m256bh` in shape. |
| 83 | +//! |
| 84 | +//! # Cargo configs |
| 85 | +//! |
| 86 | +//! ```toml |
| 87 | +//! # .cargo/config-apple-m2.toml — Apple M2/M3/M4 |
| 88 | +//! [build] |
| 89 | +//! target = "aarch64-apple-darwin" |
| 90 | +//! [target.aarch64-apple-darwin] |
| 91 | +//! rustflags = ["-Ctarget-cpu=apple-m2", "-Ctarget-feature=+bf16,+dotprod,+fp16"] |
| 92 | +//! ``` |
| 93 | +//! |
| 94 | +//! ```toml |
| 95 | +//! # .cargo/config-graviton.toml: AWS Graviton 4 (use `neoverse-v1` for Graviton 3/3E) |
| 96 | +//! [build] |
| 97 | +//! target = "aarch64-unknown-linux-gnu" |
| 98 | +//! [target.aarch64-unknown-linux-gnu] |
| 99 | +//! rustflags = ["-Ctarget-cpu=neoverse-v2", "-Ctarget-feature=+bf16"] |
| 100 | +//! ``` |
| 101 | +//! |
| 102 | +//! ```toml |
| 103 | +//! # .cargo/config-snapdragon-x.toml — Snapdragon X Elite (Win/Linux) |
| 104 | +//! [build] |
| 105 | +//! target = "aarch64-pc-windows-msvc" # or aarch64-unknown-linux-gnu |
| 106 | +//! rustflags = ["-Ctarget-cpu=oryon-1", "-Ctarget-feature=+bf16,+i8mm"] |
| 107 | +//! ``` |
| 108 | +//! |
| 109 | +//! # Stable-Rust constraint |
| 110 | +//! |
| 111 | +//! Same as the FP16 tier: `bfloat16x8_t` exists in `core::arch::aarch64` |
| 112 | +//! on stable, but the intrinsics (`vbfdotq_f32`, `vbfmmlaq_f32`, ...) |
| 113 | +//! are nightly-only (issue #117222). Two paths on stable 1.95: |
| 114 | +//! |
| 115 | +//! 1. **asm! byte encoding** — same pattern as `src/simd_amx.rs` |
| 116 | +//! uses for AMX. Example: |
| 117 | +//! ```ignore |
| 118 | +//! // BFDOT v0.4s, v1.8h, v2.8h (Vd=0, Vn=1, Vm=2) |
| 119 | +//! asm!(".inst 0x6e42fc20", inout("v0") acc, in("v1") a, in("v2") b); |
| 120 | +//! // BFMMLA v0.4s, v1.8h, v2.8h (Vd=0, Vn=1, Vm=2) |
| 121 | +//! asm!(".inst 0x6e42ec20", inout("v0") acc, in("v1") a, in("v2") b); |
| 122 | +//! ``` |
| 123 | +//! Verify the encoding with `aarch64-linux-gnu-objdump --disassemble` |
| 124 | +//! on a reference compile. |
| 125 | +//! 2. **Round-trip through f32** — convert bf16 → f32 (scalar bit- |
| 126 | +//! shift), use the existing `vfmaq_f32` from baseline NEON, convert |
| 127 | +//! back. Loses the 4× throughput; only as a correctness anchor for |
| 128 | +//! the asm path. |
| 129 | +//! |
| 130 | +//! Path (1) is the only one worth shipping. The asm-byte fallback IS |
| 131 | +//! how `simd_amx.rs` ships AMX on stable Rust today — same pattern. |
| 132 | +
|
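The `.inst` words used by the asm-byte path can be computed instead of typed by hand. Below is a minimal sketch, assuming a hypothetical helper `encode_vvv` and base opcodes taken from the Arm A64 encoding tables for the 128-bit vector forms; as the module docs advise, the results are worth confirming against `objdump` on a reference compile before shipping.

```rust
/// Assumed base opcode for `BFDOT Vd.4s, Vn.8h, Vm.8h` with all register
/// fields zero.
pub const BFDOT_BASE: u32 = 0x6e40_fc00;
/// Assumed base opcode for `BFMMLA Vd.4s, Vn.8h, Vm.8h` with all register
/// fields zero.
pub const BFMMLA_BASE: u32 = 0x6e40_ec00;

/// Hypothetical helper: place three vector register numbers into the
/// A64 three-register layout (Rm at bits 20..16, Rn at bits 9..5,
/// Rd at bits 4..0).
pub const fn encode_vvv(base: u32, vd: u32, vn: u32, vm: u32) -> u32 {
    // Each field is five bits wide; anything larger would corrupt
    // neighbouring opcode bits.
    assert!(vd < 32 && vn < 32 && vm < 32);
    base | (vm << 16) | (vn << 5) | vd
}
```

Because `encode_vvv` is a `const fn`, the word can be evaluated at compile time, for example `const W: u32 = encode_vvv(BFMMLA_BASE, 0, 1, 2);`, and then compared against the literal passed to `.inst`.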
| 133 | +#![cfg(all(target_arch = "aarch64", feature = "std"))] |
| 134 | + |
| 135 | +// ─── BF16 stubs ────────────────────────────────────────────────────── |
| 136 | + |
| 137 | +/// Placeholder for the BF16 8-lane native wrapper. |
| 138 | +/// |
| 139 | +/// Real implementation: `pub struct BF16x8(pub bfloat16x8_t)`. API |
| 140 | +/// surface mirrors `simd_avx512::BF16x8`: |
| 141 | +/// - `splat(bits: u16) -> Self` (broadcast bf16 bit pattern across 8 lanes) |
| 142 | +/// - `from_slice(s: &[u16]) -> Self` (load 8 raw bf16 bits as u16s) |
| 143 | +/// - `to_array(self) -> [u16; 8]` |
| 144 | +/// - `dot_f32(self, other: Self, acc: F32x4) -> F32x4` — wraps BFDOT |
| 145 | +/// - `cvt_to_f32_lo(self) -> F32x4`, `cvt_to_f32_hi(self) -> F32x4` |
| 146 | +/// |
| 147 | +/// Without `target_feature = "bf16"`, this falls back to round-trip |
| 148 | +/// through f32 (slow). With the feature on, it uses asm-byte BFDOT. |
| 149 | +pub struct BF16x8Stub; |
| 150 | + |
| 151 | +/// Placeholder for the BF16 16-lane composed wrapper. |
| 152 | +/// |
| 153 | +/// Real implementation: `pub struct BF16x16(pub [bfloat16x8_t; 2])`. |
| 154 | +/// API mirror of `simd_avx512::BF16x16`. The 16-lane variant is the |
| 155 | +/// natural width for matmul tile rows in transformer attention. |
| 156 | +pub struct BF16x16Stub; |
| 157 | + |
| 158 | +impl BF16x8Stub { |
| 159 | + pub fn unimplemented() -> ! { |
| 160 | + unimplemented!( |
| 161 | + "BF16x8 NEON bf16-tier implementation TODO. See \ |
| 162 | + src/simd_neon_bf16.rs module docs for the BFDOT / BFMMLA \ |
| 163 | + asm-byte encoding (stable Rust 1.95 can't reach the \ |
| 164 | + nightly-only vbfdotq_f32 intrinsic). Reference: \ |
| 165 | + src/simd_amx.rs's `.byte` pattern." |
| 166 | + ) |
| 167 | + } |
| 168 | +} |
| 169 | + |
| 170 | +impl BF16x16Stub { |
| 171 | + pub fn unimplemented() -> ! { |
| 172 | + unimplemented!( |
| 173 | + "BF16x16 NEON bf16-tier implementation TODO. Two-half \ |
| 174 | + composed wrapper [bfloat16x8_t; 2] — see module docs." |
| 175 | + ) |
| 176 | + } |
| 177 | +} |
| 178 | + |
| 179 | +// ─── BFMMLA: the prize intrinsic ───────────────────────────────────── |
| 180 | +// |
| 181 | +// BFMMLA is the most important instruction this tier unlocks. It
| 182 | +// multiplies a 2×4 bf16 matrix by a 4×2 bf16 matrix,
| 183 | +// accumulating into a 2×2 f32 tile. One instruction = 16 bf16 mults
| 184 | +// summed into 4 f32 accumulators. On Apple M2 the throughput is ~32 GFLOP/s per core in
| 185 | +// bf16-matmul-bound kernels.
| 186 | +// |
| 187 | +// Encoding for `BFMMLA Vd.4s, Vn.8h, Vm.8h`: 0x6e40_ec00 | (Vm << 16) |
| 188 | +// | (Vn << 5) | Vd. Use a `bfmmla!` macro to emit the asm-byte for any |
| 189 | +// (acc, a, b) v-register triple. |
| 190 | +// |
| 191 | +// TODO(Phase-3): implement `bfmmla(acc: F32x4, a: BF16x8, b: BF16x8) |
| 192 | +// -> F32x4` as the primary export. The rest of the BF16 API builds on |
| 193 | +// it (BFDOT is BFMMLA's diagonal, BFMLALB/T are its half-slices). |
| 194 | + |
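A scalar model of BFMMLA is useful as a correctness anchor for the asm path. The sketch below assumes a hypothetical `bfmmla_ref` working on raw bf16 bit patterns: `a` holds a 2×4 matrix row by row, `b` holds the second operand as two rows of four (one row per output column), and the result is a row-major 2×2 f32 tile. It uses ordinary f32 arithmetic, so it is not guaranteed to match the hardware's intermediate rounding bit for bit.

```rust
/// Hypothetical scalar reference for BFMMLA on raw bf16 bits.
/// out[2*i + j] = acc[2*i + j] + sum over k of a[4*i + k] * b[4*j + k]
pub fn bfmmla_ref(acc: [f32; 4], a: [u16; 8], b: [u16; 8]) -> [f32; 4] {
    // bf16 -> f32 is exact: the bits become the top half of the f32.
    let widen = |bits: u16| f32::from_bits((bits as u32) << 16);
    let mut out = acc;
    for i in 0..2 {
        for j in 0..2 {
            let mut sum = 0.0f32;
            // Four products per output element, 16 in total.
            for k in 0..4 {
                sum += widen(a[4 * i + k]) * widen(b[4 * j + k]);
            }
            out[2 * i + j] += sum;
        }
    }
    out
}
```

Comparing this against the asm-byte kernel on small integer-valued inputs avoids any rounding differences, since every intermediate is then exactly representable.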
| 195 | +// ─── BFDOT: same shape as DotProd, but bf16 ────────────────────────── |
| 196 | +// |
| 197 | +// Where `vdotq_s32(acc, a, b)` does 4×(4×i8·4×i8) → 4×i32, BFDOT does
| 198 | +// 4×(2×bf16·2×bf16) → 4×f32. Each output lane sums HALF as many
| 199 | +// products (2 vs 4) because bf16 is twice as
| 200 | +// wide as i8.
| 201 | +// |
| 202 | +// TODO(Phase-3): implement `bfdot(acc: F32x4, a: BF16x8, b: BF16x8) |
| 203 | +// -> F32x4`. Asm-byte for `BFDOT Vd.4s, Vn.8h, Vm.8h`: |
| 204 | +// 0x6e40_fc00 | (Vm << 16) | (Vn << 5) | Vd
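The BFDOT lane arithmetic can be pinned down with a scalar model. This is a minimal sketch, assuming a hypothetical `bfdot_ref` over raw bf16 bit patterns; each of the four f32 output lanes accumulates the products of one adjacent pair of bf16 elements. It uses plain f32 arithmetic and may differ from hardware in the last bit on inputs that need rounding.

```rust
/// Hypothetical scalar reference for `BFDOT Vd.4s, Vn.8h, Vm.8h`.
/// out[i] = acc[i] + a[2*i] * b[2*i] + a[2*i + 1] * b[2*i + 1]
pub fn bfdot_ref(acc: [f32; 4], a: [u16; 8], b: [u16; 8]) -> [f32; 4] {
    // bf16 -> f32 is exact: the bits become the top half of the f32.
    let widen = |bits: u16| f32::from_bits((bits as u32) << 16);
    let mut out = acc;
    for i in 0..4 {
        // Two products per output lane, eight in total.
        let lo = widen(a[2 * i]) * widen(b[2 * i]);
        let hi = widen(a[2 * i + 1]) * widen(b[2 * i + 1]);
        out[i] += lo + hi;
    }
    out
}
```

With `a` holding 1.0 through 8.0 and `b` all ones, the lanes come out as the pair sums 3, 7, 11 and 15 on top of whatever `acc` held.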