|
31 | 31 | // ═══════════════════════════════════════════════════════════════════════════ |
32 | 32 |
|
33 | 33 | /// Check if AMX hardware is present AND OS-enabled. |
| 34 | +/// |
| 35 | +/// Two checks required: |
| 36 | +/// 1. CPUID.07H.0H:EDX bits 24 (AMX-TILE) + 25 (AMX-INT8) = CPU supports it |
| 37 | +/// 2. XCR0 bits 17 (TILECFG) + 18 (TILEDATA) = OS has enabled tile state |
| 38 | +/// |
| 39 | +/// The XCR0 check is critical: even if CPUID reports AMX, the hypervisor |
| 40 | +/// may not have enabled the XSTATE for tiles. Without OS enablement, |
| 41 | +/// LDTILECFG will SIGILL. |
| 42 | +/// |
| 43 | +/// Previous bug: used CPUID leaf 0xD (reports what CPU supports for XSAVE) |
| 44 | +/// instead of _xgetbv(0) (reports what OS actually enabled). The old check |
| 45 | +/// could return true on a hypervisor that advertises AMX in CPUID but |
| 46 | +/// hasn't set XCR0 bits 17+18. |
34 | 47 | #[cfg(target_arch = "x86_64")] |
35 | 48 | pub fn amx_available() -> bool { |
| 49 | + // Step 1: CPU supports AMX-TILE + AMX-INT8? |
36 | 50 | let cpuid = core::arch::x86_64::__cpuid_count(7, 0); |
37 | 51 | let amx_tile = (cpuid.edx >> 24) & 1; |
38 | 52 | let amx_int8 = (cpuid.edx >> 25) & 1; |
39 | 53 | if amx_tile == 0 || amx_int8 == 0 { return false; } |
40 | | - // Check OS enabled via XCR0 bits 17+18 |
41 | | - let xcr0 = core::arch::x86_64::__cpuid_count(0xD, 0); |
42 | | - let tilecfg = (xcr0.eax >> 17) & 1; |
43 | | - let tiledata = (xcr0.eax >> 18) & 1; |
44 | | - tilecfg == 1 && tiledata == 1 |
| 54 | + |
| 55 | + // Step 2: OS enabled XSAVE? (CPUID.01H:ECX bit 27 = OSXSAVE) |
| 56 | + let cpuid_01 = core::arch::x86_64::__cpuid(1); |
| 57 | + let osxsave = (cpuid_01.ecx >> 27) & 1; |
| 58 | + if osxsave == 0 { return false; } |
| 59 | + |
| 60 | + // Step 3: OS actually enabled tile state in XCR0? |
| 61 | + // _xgetbv(0) reads the ACTUAL XCR0 register (what the OS set), |
| 62 | + // not the CPUID-reported capability. |
| 63 | + // Bit 17 = TILECFG, Bit 18 = TILEDATA. Both must be set. |
| 64 | + let xcr0: u64 = unsafe { core::arch::x86_64::_xgetbv(0) }; |
| 65 | + let tilecfg = (xcr0 >> 17) & 1; |
| 66 | + let tiledata = (xcr0 >> 18) & 1; |
| 67 | + if tilecfg == 0 || tiledata == 0 { return false; } |
| 68 | + |
| 69 | + // Step 4: Request XCOMP_PERM for TILEDATA. |
| 70 | + // Linux kernel 5.19+: processes must call prctl(ARCH_REQ_XCOMP_PERM, 18) |
| 71 | + // to request permission for TILEDATA (XFEATURE 18) before using AMX. |
| 72 | + // Without this, LDTILECFG will SIGILL even if XCR0 bits are set. |
| 73 | + // The prctl either succeeds (0) or fails (-1) — idempotent, safe to call |
| 74 | + // multiple times. |
| 75 | + #[cfg(target_os = "linux")] |
| 76 | + { |
| 77 | + const SYS_PRCTL: i64 = 157; // x86_64 syscall number for prctl |
| 78 | + const ARCH_REQ_XCOMP_PERM: i64 = 0x1023; |
| 79 | + const XFEATURE_XTILEDATA: i64 = 18; |
| 80 | + // SAFETY: syscall(prctl, ARCH_REQ_XCOMP_PERM, 18) is a simple permission |
| 81 | + // request. It either grants tile permission (returns 0) or fails (returns |
| 82 | + // -errno). No side effects on failure. Idempotent. |
| 83 | + let ret: i64; |
| 84 | + unsafe { |
| 85 | + core::arch::asm!( |
| 86 | + "syscall", |
| 87 | + inlateout("rax") SYS_PRCTL => ret, |
| 88 | + in("rdi") ARCH_REQ_XCOMP_PERM, |
| 89 | + in("rsi") XFEATURE_XTILEDATA, |
| 90 | + in("rdx") 0i64, |
| 91 | + in("r10") 0i64, |
| 92 | + in("r8") 0i64, |
| 93 | + lateout("rcx") _, |
| 94 | + lateout("r11") _, |
| 95 | + options(nostack), |
| 96 | + ); |
| 97 | + } |
| 98 | + if ret != 0 { return false; } |
| 99 | + } |
| 100 | + |
| 101 | + true |
45 | 102 | } |
46 | 103 |
|
47 | 104 | #[cfg(not(target_arch = "x86_64"))] |
@@ -203,17 +260,25 @@ pub fn vnni_matvec_scalar( |
203 | 260 |
|
204 | 261 | /// Runtime-dispatched VNNI MatVec: avx512vnni → avxvnniint8 → scalar i32. |
205 | 262 | /// |
206 | | -/// Three tiers, mutually exclusive by hardware generation: |
| 263 | +/// Three tiers, checked in order (first match wins): |
207 | 264 | /// avx512vnni — 64 MACs/instr (zmm, Cascade Lake+, Zen 4+) |
208 | 265 | /// avxvnniint8 — 32 MACs/instr (ymm, Arrow Lake, NUC 14 i9-185H) |
209 | | -/// scalar i32 — only for non-x86 or testing (caller should prefer F32x16 FMA) |
| 266 | +/// scalar i32 — only for non-x86 or testing |
| 267 | +/// |
| 268 | +/// IMPORTANT: avxvnniint8 (VNNI2, 256-bit) is NEVER reached when |
| 269 | +/// avx512vnni (VNNI512) is present. This is correct: |
| 270 | +/// - CPUs with avx512vnni always have 512-bit VPDPBUSD (faster) |
| 271 | +/// - avxvnniint8 exists ONLY for CPUs that dropped AVX-512 |
| 272 | +/// but added 256-bit VNNI (Arrow Lake, Meteor Lake U-series) |
| 273 | +/// - The two instructions have DIFFERENT encodings: |
| 274 | +/// avx512vnni: EVEX-encoded VPDPBUSD zmm (512-bit) |
| 275 | +/// avxvnniint8: VEX-encoded VPDPBUSD ymm (256-bit) |
| 276 | +/// - Running EVEX VPDPBUSD on a VEX-only CPU = SIGILL |
| 277 | +/// - Running VEX VPDPBUSD on an EVEX CPU = works but wastes half the width |
210 | 278 | /// |
211 | | -/// NOTE: The scalar path here does i32 multiply-accumulate, NOT f32. |
212 | | -/// For the thinking engine, F32x16 FMA (16 MACs/instr) is the true floor. |
213 | | -/// This scalar path exists only for correctness on non-x86 targets. |
214 | 279 | /// The thinking engine's cycle_auto() dispatches: |
215 | 280 | /// VNNI detected → cycle_vnni() → this function |
216 | | -/// No VNNI → cycle() → F32x16 (never reaches here) |
| 281 | +/// No VNNI → cycle() → F32x16 FMA (never reaches here) |
217 | 282 | pub fn matvec_dispatch( |
218 | 283 | table: &[u8], |
219 | 284 | energy_i8: &[i8], |
|
0 commit comments