Commit 49cd860

AdaWorldAPI and claude authored
feat(simd): BF16x16 + F16x16 SIMD vectors + slice ops (#126, sprint W3-A)
Closes parity items 2 + 3.

Scalar dispatch (upcast f32 -> op -> downcast). SIMD-accelerated paths (AVX2 emulation, AVX-512-BF16 native, NEON +fp16) are a follow-up. The scalar implementation is correct and portable, and unblocks burn's NdArrayElement bound for half types.

- src/simd_half.rs: 691 LOC new module
- src/lib.rs: pub mod simd_half declaration
- src/simd.rs: re-exports

21 new tests, all passing. Total lib tests: 1817+ pass.

https://claude.ai/code/session_01NYGrxVopyszZYgLBxe4hgj

Co-authored-by: Claude <noreply@anthropic.com>
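The "upcast f32 -> op -> downcast" scalar path described in the commit message can be sketched roughly as follows. This is an illustration only, not the contents of `src/simd_half.rs`: the function names (`bf16_to_f32`, `f32_to_bf16`, `add_bf16_inplace_sketch`), the `u16` bit representation, and the round-to-nearest-even rounding are all assumptions made for the example.

```rust
/// Widen a BF16 bit pattern to f32 (hypothetical helper).
/// BF16 is the upper 16 bits of an IEEE-754 f32, so this is a shift.
pub fn bf16_to_f32(bits: u16) -> f32 {
    f32::from_bits((bits as u32) << 16)
}

/// Narrow an f32 to a BF16 bit pattern (hypothetical helper).
/// Assumes round-to-nearest-even; NaN inputs stay NaN.
pub fn f32_to_bf16(x: f32) -> u16 {
    let bits = x.to_bits();
    if x.is_nan() {
        // Force a mantissa bit so truncation cannot turn NaN into infinity.
        return ((bits >> 16) as u16) | 0x0040;
    }
    // Add 0x7FFF plus the lowest kept bit, so exact ties round to even.
    let round_bias = 0x7FFF + ((bits >> 16) & 1);
    (bits.wrapping_add(round_bias) >> 16) as u16
}

/// Element-wise `dst[i] += src[i]` over BF16 bit patterns:
/// upcast both operands to f32, add, then downcast the result.
/// The signature is an assumption, not the crate's actual API.
pub fn add_bf16_inplace_sketch(dst: &mut [u16], src: &[u16]) {
    assert_eq!(dst.len(), src.len(), "slices must have equal length");
    for (d, s) in dst.iter_mut().zip(src.iter()) {
        *d = f32_to_bf16(bf16_to_f32(*d) + bf16_to_f32(*s));
    }
}
```

Routing every operation through f32 keeps the code portable to any target, at the cost of one widen and one narrow per element; the SIMD-accelerated paths named as a follow-up would presumably replace the per-element loop.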
1 parent 0d22e44 commit 49cd860

3 files changed

Lines changed: 713 additions & 8 deletions

File tree

src/lib.rs

Lines changed: 1 addition & 1 deletion
@@ -260,7 +260,7 @@ pub mod simd_int_ops;
 /// Half-precision SIMD vectors (`BF16x16`, `F16x16`) + slice-level ops.
 #[cfg(feature = "std")]
 #[allow(clippy::all, missing_docs, dead_code, unused_variables, unused_imports)]
-// pub mod simd_half; // TODO: BF16x16/F16x16 SIMD vectors (A2 WIP)
+pub mod simd_half;
 
 /// Pluggable linear algebra backends (native SIMD, MKL, OpenBLAS).
 #[cfg(feature = "std")]

src/simd.rs

Lines changed: 21 additions & 7 deletions
@@ -1210,14 +1210,28 @@ pub use crate::hpc::quantized::{
     QuantParams,
 };
 
-// Half-precision SIMD vectors (BF16x16, F16x16) — runtime-dispatched, always
+// Half-precision SIMD vectors (BF16x16, F16x16) — portable scalar impl, always
 // available. Note: when `target_feature = "avx512bf16"` is active a separate
-// hardware-only `BF16x16` is also exported above from `simd_avx512`. The
-// hardware-native one ships unsafe `from_u16_slice` / `to_f32x16` and is
-// distinct from the portable runtime-dispatched `simd_half::BF16x16`.
-// TODO: BF16x16/F16x16 SIMD vector types + slice ops (A2 WIP — simd_half module)
-// F16 type itself is available in hpc::quantized::F16.
-// SIMD vectors land in Wave 3 after the A2 module is completed.
+// hardware-native `BF16x16` is also exported above from `simd_avx512`; in that
+// case we only re-export F16x16 + slice ops to avoid name collisions.
+//
+// On all other targets (including avx512f-without-bf16, NEON, scalar) the
+// portable `simd_half::BF16x16` is the canonical 16-lane BF16 vector.
+
+// Always re-export F16x16 + all slice-level ops (no naming conflict).
+#[cfg(feature = "std")]
+pub use crate::simd_half::{
+    F16x16,
+    add_bf16_inplace, mul_bf16_inplace,
+    add_f16_inplace, mul_f16_inplace,
+    cast_bf16_to_f32_batch, cast_f16_to_f32_batch,
+    cast_f32_to_bf16_batch, cast_f32_to_f16_batch,
+};
+
+// Re-export portable BF16x16 only when the hardware-native avx512bf16 variant
+// is NOT active (otherwise `simd_avx512::BF16x16` already occupies the name).
+#[cfg(all(feature = "std", not(all(target_arch = "x86_64", target_feature = "avx512bf16"))))]
+pub use crate::simd_half::BF16x16 as BF16x16;
 
 // K-means + L2 distance
 pub use crate::hpc::cam_pq::{kmeans, squared_l2};
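The conditional re-export in `src/simd.rs` can be illustrated with a small standalone sketch. Two modules define a type with the same name, and complementary `cfg` predicates ensure that exactly one of them is re-exported under that name. The module names `portable` and `native`, the `BACKEND` constant, and `active_backend` are hypothetical and exist only for this example; the crate's real modules are `simd_half` and `simd_avx512`.

```rust
// Portable fallback, compiled on every target (hypothetical stand-in
// for `simd_half`).
#[allow(dead_code)]
mod portable {
    pub struct BF16x16(pub [u16; 16]);
    impl BF16x16 {
        pub const BACKEND: &'static str = "portable";
    }
}

// Hardware-native variant, compiled only when the target enables
// avx512bf16 (hypothetical stand-in for `simd_avx512`).
#[cfg(all(target_arch = "x86_64", target_feature = "avx512bf16"))]
#[allow(dead_code)]
mod native {
    pub struct BF16x16(pub [u16; 16]);
    impl BF16x16 {
        pub const BACKEND: &'static str = "native";
    }
}

// The two predicates are exact complements, so exactly one `BF16x16`
// is in scope and the name never collides.
#[cfg(all(target_arch = "x86_64", target_feature = "avx512bf16"))]
pub use native::BF16x16;
#[cfg(not(all(target_arch = "x86_64", target_feature = "avx512bf16")))]
pub use portable::BF16x16;

/// Reports which implementation the `BF16x16` name resolved to.
pub fn active_backend() -> &'static str {
    BF16x16::BACKEND
}
```

Because `target_feature` is evaluated at compile time, the selection depends on build flags such as `-C target-feature=+avx512bf16`, not on the CPU the binary later runs on.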
