Skip to content

Commit 0c30fe2

Browse files
AdaWorldAPIclaude
andauthored
feat(simd-neon): 6 NEON integer wrapper types for aarch64 (#127, sprint W3-B)
Closes parity item 8 — adds U8x16, U16x8, U32x4, U64x2, I32x4, I64x2 NEON wrapper types so aarch64 burn-ndarray builds get real NEON acceleration on integer hot paths instead of scalar. Each type has: splat/zero/from_slice/from_array/to_array/copy_to_slice/ add/sub/min/max. NEON intrinsics: - U8x16: vaddq_u8, vsubq_u8, vminq_u8, vmaxq_u8 - U16x8: vaddq_u16, vsubq_u16, vminq_u16, vmaxq_u16 - U32x4: vaddq_u32, vsubq_u32, vminq_u32, vmaxq_u32 - U64x2: vaddq_u64, vsubq_u64 (min/max scalar — NEON has no vminq_u64) - I32x4: vaddq_s32, vsubq_s32, vminq_s32, vmaxq_s32 - I64x2: vaddq_s64, vsubq_s64 (min/max scalar — NEON has no vminq_s64) Item 7 (AVX2 paired-256 fallbacks for U32x16/U64x8/etc.) deferred: all 6 types already exist as scalar fallback via avx2_int_type! macro in src/simd_avx2.rs — they're correct and complete; paired-256 SIMD acceleration is a perf upgrade, not a functionality blocker. Tests: builds clean on x86_64 + cross-compiles clean for aarch64-unknown-linux-gnu. https://claude.ai/code/session_01NYGrxVopyszZYgLBxe4hgj Co-authored-by: Claude <noreply@anthropic.com>
1 parent 8eb532d commit 0c30fe2

1 file changed

Lines changed: 176 additions & 0 deletions

File tree

src/simd_neon.rs

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1236,6 +1236,182 @@ impl PartialEq for I16x8 {
12361236
fn eq(&self, other: &Self) -> bool { self.to_array() == other.to_array() }
12371237
}
12381238

1239+
// ═══════════════════════════════════════════════════════════════════════════
1240+
// W3-B: NEON integer wrapper types (item 8 of burn parity list)
1241+
// ─ U8x16, U16x8, U32x4, U64x2, I32x4, I64x2 ─
1242+
// ═══════════════════════════════════════════════════════════════════════════
1243+
1244+
#[cfg(target_arch = "aarch64")]
1245+
#[derive(Copy, Clone)]
1246+
#[repr(transparent)]
1247+
pub struct U8x16(pub uint8x16_t);
1248+
1249+
#[cfg(target_arch = "aarch64")]
1250+
impl U8x16 {
1251+
pub const LANES: usize = 16;
1252+
#[inline(always)] pub fn splat(v: u8) -> Self { Self(unsafe { vdupq_n_u8(v) }) }
1253+
#[inline(always)] pub fn zero() -> Self { Self(unsafe { vdupq_n_u8(0) }) }
1254+
#[inline(always)] pub fn from_slice(s: &[u8]) -> Self {
1255+
assert!(s.len() >= 16); Self(unsafe { vld1q_u8(s.as_ptr()) })
1256+
}
1257+
#[inline(always)] pub fn from_array(arr: [u8; 16]) -> Self { Self(unsafe { vld1q_u8(arr.as_ptr()) }) }
1258+
#[inline(always)] pub fn to_array(self) -> [u8; 16] {
1259+
let mut arr = [0u8; 16]; unsafe { vst1q_u8(arr.as_mut_ptr(), self.0) }; arr
1260+
}
1261+
#[inline(always)] pub fn copy_to_slice(self, s: &mut [u8]) {
1262+
assert!(s.len() >= 16); unsafe { vst1q_u8(s.as_mut_ptr(), self.0) };
1263+
}
1264+
#[inline(always)] pub fn add(self, other: Self) -> Self { Self(unsafe { vaddq_u8(self.0, other.0) }) }
1265+
#[inline(always)] pub fn sub(self, other: Self) -> Self { Self(unsafe { vsubq_u8(self.0, other.0) }) }
1266+
#[inline(always)] pub fn min(self, other: Self) -> Self { Self(unsafe { vminq_u8(self.0, other.0) }) }
1267+
#[inline(always)] pub fn max(self, other: Self) -> Self { Self(unsafe { vmaxq_u8(self.0, other.0) }) }
1268+
}
1269+
1270+
#[cfg(target_arch = "aarch64")]
1271+
#[derive(Copy, Clone)]
1272+
#[repr(transparent)]
1273+
pub struct U16x8(pub uint16x8_t);
1274+
1275+
#[cfg(target_arch = "aarch64")]
1276+
impl U16x8 {
1277+
pub const LANES: usize = 8;
1278+
#[inline(always)] pub fn splat(v: u16) -> Self { Self(unsafe { vdupq_n_u16(v) }) }
1279+
#[inline(always)] pub fn zero() -> Self { Self(unsafe { vdupq_n_u16(0) }) }
1280+
#[inline(always)] pub fn from_slice(s: &[u16]) -> Self {
1281+
assert!(s.len() >= 8); Self(unsafe { vld1q_u16(s.as_ptr()) })
1282+
}
1283+
#[inline(always)] pub fn from_array(arr: [u16; 8]) -> Self { Self(unsafe { vld1q_u16(arr.as_ptr()) }) }
1284+
#[inline(always)] pub fn to_array(self) -> [u16; 8] {
1285+
let mut arr = [0u16; 8]; unsafe { vst1q_u16(arr.as_mut_ptr(), self.0) }; arr
1286+
}
1287+
#[inline(always)] pub fn copy_to_slice(self, s: &mut [u16]) {
1288+
assert!(s.len() >= 8); unsafe { vst1q_u16(s.as_mut_ptr(), self.0) };
1289+
}
1290+
#[inline(always)] pub fn add(self, other: Self) -> Self { Self(unsafe { vaddq_u16(self.0, other.0) }) }
1291+
#[inline(always)] pub fn sub(self, other: Self) -> Self { Self(unsafe { vsubq_u16(self.0, other.0) }) }
1292+
#[inline(always)] pub fn min(self, other: Self) -> Self { Self(unsafe { vminq_u16(self.0, other.0) }) }
1293+
#[inline(always)] pub fn max(self, other: Self) -> Self { Self(unsafe { vmaxq_u16(self.0, other.0) }) }
1294+
}
1295+
1296+
#[cfg(target_arch = "aarch64")]
1297+
#[derive(Copy, Clone)]
1298+
#[repr(transparent)]
1299+
pub struct U32x4(pub uint32x4_t);
1300+
1301+
#[cfg(target_arch = "aarch64")]
1302+
impl U32x4 {
1303+
pub const LANES: usize = 4;
1304+
#[inline(always)] pub fn splat(v: u32) -> Self { Self(unsafe { vdupq_n_u32(v) }) }
1305+
#[inline(always)] pub fn zero() -> Self { Self(unsafe { vdupq_n_u32(0) }) }
1306+
#[inline(always)] pub fn from_slice(s: &[u32]) -> Self {
1307+
assert!(s.len() >= 4); Self(unsafe { vld1q_u32(s.as_ptr()) })
1308+
}
1309+
#[inline(always)] pub fn from_array(arr: [u32; 4]) -> Self { Self(unsafe { vld1q_u32(arr.as_ptr()) }) }
1310+
#[inline(always)] pub fn to_array(self) -> [u32; 4] {
1311+
let mut arr = [0u32; 4]; unsafe { vst1q_u32(arr.as_mut_ptr(), self.0) }; arr
1312+
}
1313+
#[inline(always)] pub fn copy_to_slice(self, s: &mut [u32]) {
1314+
assert!(s.len() >= 4); unsafe { vst1q_u32(s.as_mut_ptr(), self.0) };
1315+
}
1316+
#[inline(always)] pub fn add(self, other: Self) -> Self { Self(unsafe { vaddq_u32(self.0, other.0) }) }
1317+
#[inline(always)] pub fn sub(self, other: Self) -> Self { Self(unsafe { vsubq_u32(self.0, other.0) }) }
1318+
#[inline(always)] pub fn min(self, other: Self) -> Self { Self(unsafe { vminq_u32(self.0, other.0) }) }
1319+
#[inline(always)] pub fn max(self, other: Self) -> Self { Self(unsafe { vmaxq_u32(self.0, other.0) }) }
1320+
}
1321+
1322+
#[cfg(target_arch = "aarch64")]
1323+
#[derive(Copy, Clone)]
1324+
#[repr(transparent)]
1325+
pub struct U64x2(pub uint64x2_t);
1326+
1327+
#[cfg(target_arch = "aarch64")]
1328+
impl U64x2 {
1329+
pub const LANES: usize = 2;
1330+
#[inline(always)] pub fn splat(v: u64) -> Self { Self(unsafe { vdupq_n_u64(v) }) }
1331+
#[inline(always)] pub fn zero() -> Self { Self(unsafe { vdupq_n_u64(0) }) }
1332+
#[inline(always)] pub fn from_slice(s: &[u64]) -> Self {
1333+
assert!(s.len() >= 2); Self(unsafe { vld1q_u64(s.as_ptr()) })
1334+
}
1335+
#[inline(always)] pub fn from_array(arr: [u64; 2]) -> Self { Self(unsafe { vld1q_u64(arr.as_ptr()) }) }
1336+
#[inline(always)] pub fn to_array(self) -> [u64; 2] {
1337+
let mut arr = [0u64; 2]; unsafe { vst1q_u64(arr.as_mut_ptr(), self.0) }; arr
1338+
}
1339+
#[inline(always)] pub fn copy_to_slice(self, s: &mut [u64]) {
1340+
assert!(s.len() >= 2); unsafe { vst1q_u64(s.as_mut_ptr(), self.0) };
1341+
}
1342+
#[inline(always)] pub fn add(self, other: Self) -> Self { Self(unsafe { vaddq_u64(self.0, other.0) }) }
1343+
#[inline(always)] pub fn sub(self, other: Self) -> Self { Self(unsafe { vsubq_u64(self.0, other.0) }) }
1344+
// NEON has no vminq_u64 / vmaxq_u64 — scalar fallback
1345+
#[inline(always)] pub fn min(self, other: Self) -> Self {
1346+
let a = self.to_array(); let b = other.to_array();
1347+
Self::from_array([a[0].min(b[0]), a[1].min(b[1])])
1348+
}
1349+
#[inline(always)] pub fn max(self, other: Self) -> Self {
1350+
let a = self.to_array(); let b = other.to_array();
1351+
Self::from_array([a[0].max(b[0]), a[1].max(b[1])])
1352+
}
1353+
}
1354+
1355+
#[cfg(target_arch = "aarch64")]
1356+
#[derive(Copy, Clone)]
1357+
#[repr(transparent)]
1358+
pub struct I32x4(pub int32x4_t);
1359+
1360+
#[cfg(target_arch = "aarch64")]
1361+
impl I32x4 {
1362+
pub const LANES: usize = 4;
1363+
#[inline(always)] pub fn splat(v: i32) -> Self { Self(unsafe { vdupq_n_s32(v) }) }
1364+
#[inline(always)] pub fn zero() -> Self { Self(unsafe { vdupq_n_s32(0) }) }
1365+
#[inline(always)] pub fn from_slice(s: &[i32]) -> Self {
1366+
assert!(s.len() >= 4); Self(unsafe { vld1q_s32(s.as_ptr()) })
1367+
}
1368+
#[inline(always)] pub fn from_array(arr: [i32; 4]) -> Self { Self(unsafe { vld1q_s32(arr.as_ptr()) }) }
1369+
#[inline(always)] pub fn to_array(self) -> [i32; 4] {
1370+
let mut arr = [0i32; 4]; unsafe { vst1q_s32(arr.as_mut_ptr(), self.0) }; arr
1371+
}
1372+
#[inline(always)] pub fn copy_to_slice(self, s: &mut [i32]) {
1373+
assert!(s.len() >= 4); unsafe { vst1q_s32(s.as_mut_ptr(), self.0) };
1374+
}
1375+
#[inline(always)] pub fn add(self, other: Self) -> Self { Self(unsafe { vaddq_s32(self.0, other.0) }) }
1376+
#[inline(always)] pub fn sub(self, other: Self) -> Self { Self(unsafe { vsubq_s32(self.0, other.0) }) }
1377+
#[inline(always)] pub fn min(self, other: Self) -> Self { Self(unsafe { vminq_s32(self.0, other.0) }) }
1378+
#[inline(always)] pub fn max(self, other: Self) -> Self { Self(unsafe { vmaxq_s32(self.0, other.0) }) }
1379+
}
1380+
1381+
#[cfg(target_arch = "aarch64")]
1382+
#[derive(Copy, Clone)]
1383+
#[repr(transparent)]
1384+
pub struct I64x2(pub int64x2_t);
1385+
1386+
#[cfg(target_arch = "aarch64")]
1387+
impl I64x2 {
1388+
pub const LANES: usize = 2;
1389+
#[inline(always)] pub fn splat(v: i64) -> Self { Self(unsafe { vdupq_n_s64(v) }) }
1390+
#[inline(always)] pub fn zero() -> Self { Self(unsafe { vdupq_n_s64(0) }) }
1391+
#[inline(always)] pub fn from_slice(s: &[i64]) -> Self {
1392+
assert!(s.len() >= 2); Self(unsafe { vld1q_s64(s.as_ptr()) })
1393+
}
1394+
#[inline(always)] pub fn from_array(arr: [i64; 2]) -> Self { Self(unsafe { vld1q_s64(arr.as_ptr()) }) }
1395+
#[inline(always)] pub fn to_array(self) -> [i64; 2] {
1396+
let mut arr = [0i64; 2]; unsafe { vst1q_s64(arr.as_mut_ptr(), self.0) }; arr
1397+
}
1398+
#[inline(always)] pub fn copy_to_slice(self, s: &mut [i64]) {
1399+
assert!(s.len() >= 2); unsafe { vst1q_s64(s.as_mut_ptr(), self.0) };
1400+
}
1401+
#[inline(always)] pub fn add(self, other: Self) -> Self { Self(unsafe { vaddq_s64(self.0, other.0) }) }
1402+
#[inline(always)] pub fn sub(self, other: Self) -> Self { Self(unsafe { vsubq_s64(self.0, other.0) }) }
1403+
// NEON has no vminq_s64 / vmaxq_s64 — scalar fallback
1404+
#[inline(always)] pub fn min(self, other: Self) -> Self {
1405+
let a = self.to_array(); let b = other.to_array();
1406+
Self::from_array([a[0].min(b[0]), a[1].min(b[1])])
1407+
}
1408+
#[inline(always)] pub fn max(self, other: Self) -> Self {
1409+
let a = self.to_array(); let b = other.to_array();
1410+
Self::from_array([a[0].max(b[0]), a[1].max(b[1])])
1411+
}
1412+
}
1413+
1414+
12391415
// ── Polyfills for wider lanes (scalar arrays) ─────────────────────────────
12401416

12411417
macro_rules! neon_int_polyfill {

0 commit comments

Comments
 (0)