From def03a3aae05baf35fd77da3decb263416139aca Mon Sep 17 00:00:00 2001
From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com>
Date: Fri, 8 May 2026 20:36:39 +0000
Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=20Thunderbolt:=20softmax=5Fv6=20?=
 =?UTF-8?q?=E2=80=94=20AVX-512=20Vectorized=20Softmax?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: bugparty <1510776+bugparty@users.noreply.github.com>
---
 ml_kernels/include/ml_kernels/softmax.h | 156 +++++++++++++++++++++++
 ml_kernels/src/kernel_bench.cpp         |  11 +++
 ml_kernels/src/test_naive_ops.cpp       |  29 +++++
 3 files changed, 196 insertions(+)

diff --git a/ml_kernels/include/ml_kernels/softmax.h b/ml_kernels/include/ml_kernels/softmax.h
index 4c6ed7a..050238c 100644
--- a/ml_kernels/include/ml_kernels/softmax.h
+++ b/ml_kernels/include/ml_kernels/softmax.h
@@ -501,4 +501,160 @@ inline void softmax_v5(const float *input, float *output, std::size_t n) {
   }
 }
 
+
+#ifdef __AVX512F__
+// ⚡ Thunderbolt: AVX-512 Vectorized Softmax
+// Target: AVX-512 (Skylake-X and later)
+// Reason: AVX-512 processes 16 floats per vector (vs 8 in AVX2), doubling
+// the theoretical throughput; its per-lane masking also gives branch-free
+// handling of the remainder elements.
+// Expected gain: ~1.5-2.0x over the AVX2 implementation (softmax_v5) on
+// capable hardware.
+inline void softmax_v6(const float *input, float *output, std::size_t n) {
+  if (n == 0) return;
+
+  // 1. Find max (4 accumulators to hide vector max latency)
+  std::size_t i = 0;
+  __m512 max0 = _mm512_set1_ps(std::numeric_limits<float>::lowest());
+  __m512 max1 = max0, max2 = max0, max3 = max0;
+
+  for (; i + 63 < n; i += 64) {
+    max0 = _mm512_max_ps(max0, _mm512_loadu_ps(input + i));
+    max1 = _mm512_max_ps(max1, _mm512_loadu_ps(input + i + 16));
+    max2 = _mm512_max_ps(max2, _mm512_loadu_ps(input + i + 32));
+    max3 = _mm512_max_ps(max3, _mm512_loadu_ps(input + i + 48));
+  }
+  max0 = _mm512_max_ps(max0, max1);
+  max2 = _mm512_max_ps(max2, max3);
+  max0 = _mm512_max_ps(max0, max2);
+
+  for (; i + 15 < n; i += 16) {
+    max0 = _mm512_max_ps(max0, _mm512_loadu_ps(input + i));
+  }
+
+  float max_val = _mm512_reduce_max_ps(max0);
+
+  if (i < n) {
+    __mmask16 mask = (1 << (n - i)) - 1;
+    __m512 rem = _mm512_maskz_loadu_ps(mask, input + i);
+    // maskz_loadu zeroes inactive lanes; refill them with lowest() so they
+    // cannot win the reduction when every real input is negative
+    __m512 lowest = _mm512_set1_ps(std::numeric_limits<float>::lowest());
+    rem = _mm512_mask_blend_ps(mask, lowest, rem);
+    max_val = std::max(max_val, _mm512_reduce_max_ps(rem));
+  }
+
+  __m512 max_vec = _mm512_set1_ps(max_val);
+
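+  // The vectorized exp() below uses the classic range-reduction scheme
+  //   e^x = 2^k * e^r,  k = round(x * log2(e)),  r = x - k * ln(2),
+  // where ln(2) is subtracted in two parts (0.693145751953125 and
+  // 1.428606765e-06) to keep r accurate. e^r is approximated by its
+  // degree-5 Taylor polynomial in Horner form, and 2^k is assembled
+  // directly in the float exponent bits ((k + 127) << 23). Inputs are
+  // clamped at -87.3, just above float exp() underflow, so k >= -126.
+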
+  // 2. Compute exp and sum
+  i = 0;
+  __m512 sum0 = _mm512_setzero_ps();
+  __m512 sum1 = _mm512_setzero_ps();
+  __m512 sum2 = _mm512_setzero_ps();
+  __m512 sum3 = _mm512_setzero_ps();
+
+  // exp() on 16 lanes; shared by the main loop, the tail loop, and the masked remainder
+  auto exp512 = [](const __m512 &x) {
+    __m512 x_clamped = _mm512_max_ps(x, _mm512_set1_ps(-87.3f));
+    __m512 x_log2e = _mm512_mul_ps(x_clamped, _mm512_set1_ps(1.4426950408889634f));
+    __m512i k_int = _mm512_cvt_roundps_epi32(x_log2e, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
+    __m512 k_flt = _mm512_cvtepi32_ps(k_int);
+
+    __m512 r = _mm512_fnmadd_ps(k_flt, _mm512_set1_ps(0.693145751953125f), x_clamped);
+    r = _mm512_fnmadd_ps(k_flt, _mm512_set1_ps(1.428606765330187e-06f), r);
+
+    __m512 p = _mm512_fmadd_ps(_mm512_set1_ps(1.0f / 120.0f), r, _mm512_set1_ps(1.0f / 24.0f));
+    p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(1.0f / 6.0f));
+    p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(1.0f / 2.0f));
+    p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(1.0f));
+    p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(1.0f));
+
+    __m512i exp_shifted = _mm512_slli_epi32(_mm512_add_epi32(k_int, _mm512_set1_epi32(127)), 23);
+    return _mm512_mul_ps(p, _mm512_castsi512_ps(exp_shifted));
+  };
+
+  for (; i + 63 < n; i += 64) {
+    __m512 e0 = exp512(_mm512_sub_ps(_mm512_loadu_ps(input + i), max_vec));
+    __m512 e1 = exp512(_mm512_sub_ps(_mm512_loadu_ps(input + i + 16), max_vec));
+    __m512 e2 = exp512(_mm512_sub_ps(_mm512_loadu_ps(input + i + 32), max_vec));
+    __m512 e3 = exp512(_mm512_sub_ps(_mm512_loadu_ps(input + i + 48), max_vec));
+
+    _mm512_storeu_ps(output + i, e0);
+    _mm512_storeu_ps(output + i + 16, e1);
+    _mm512_storeu_ps(output + i + 32, e2);
+    _mm512_storeu_ps(output + i + 48, e3);
+
+    sum0 = _mm512_add_ps(sum0, e0);
+    sum1 = _mm512_add_ps(sum1, e1);
+    sum2 = _mm512_add_ps(sum2, e2);
+    sum3 = _mm512_add_ps(sum3, e3);
+  }
+
+  sum0 = _mm512_add_ps(sum0, sum1);
+  sum2 = _mm512_add_ps(sum2, sum3);
+  sum0 = _mm512_add_ps(sum0, sum2);
+
+  for (; i + 15 < n; i += 16) {
+    __m512 x = _mm512_loadu_ps(input + i);
+    __m512 e = exp512(_mm512_sub_ps(x, max_vec));
+    _mm512_storeu_ps(output + i, e);
+    sum0 = _mm512_add_ps(sum0, e);
+  }
+
+  float sum_val = _mm512_reduce_add_ps(sum0);
+
+  if (i < n) {
+    __mmask16 mask = (1 << (n - i)) - 1;
+    __m512 x = _mm512_maskz_loadu_ps(mask, input + i);
+    __m512 lowest = _mm512_set1_ps(std::numeric_limits<float>::lowest());
+    x = _mm512_mask_blend_ps(mask, lowest, x);
+    __m512 e = exp512(_mm512_sub_ps(x, max_vec));
+    _mm512_mask_storeu_ps(output + i, mask, e);
+    // mask out inactive elements for the sum reduction:
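+    // inactive lanes hold lowest(), which the clamped exp() maps to a
+    // tiny-but-nonzero value that would otherwise leak into the total.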
+    __m512 e_masked = _mm512_maskz_mov_ps(mask, e);
+    sum_val += _mm512_reduce_add_ps(e_masked);
+  }
+
+  // Degenerate input (every exponential flushed to zero): bail out rather
+  // than divide by zero, leaving the raw exp values in place.
+  if (sum_val == 0.0f) return;
+
+  // 3. Normalize
+  float inv_sum = 1.0f / sum_val;
+  __m512 inv_sum_v = _mm512_set1_ps(inv_sum);
+  i = 0;
+
+  for (; i + 63 < n; i += 64) {
+    _mm512_storeu_ps(output + i, _mm512_mul_ps(_mm512_loadu_ps(output + i), inv_sum_v));
+    _mm512_storeu_ps(output + i + 16, _mm512_mul_ps(_mm512_loadu_ps(output + i + 16), inv_sum_v));
+    _mm512_storeu_ps(output + i + 32, _mm512_mul_ps(_mm512_loadu_ps(output + i + 32), inv_sum_v));
+    _mm512_storeu_ps(output + i + 48, _mm512_mul_ps(_mm512_loadu_ps(output + i + 48), inv_sum_v));
+  }
+
+  for (; i + 15 < n; i += 16) {
+    _mm512_storeu_ps(output + i, _mm512_mul_ps(_mm512_loadu_ps(output + i), inv_sum_v));
+  }
+
+  if (i < n) {
+    __mmask16 mask = (1 << (n - i)) - 1;
+    __m512 o = _mm512_maskz_loadu_ps(mask, output + i);
+    _mm512_mask_storeu_ps(output + i, mask, _mm512_mul_ps(o, inv_sum_v));
+  }
+}
+#else
+// Fallback when the translation unit is not compiled with AVX-512 support
+inline void softmax_v6(const float *input, float *output, std::size_t n) {
+  softmax_v5(input, output, n);
+}
+#endif
 } // namespace ml_kernels
diff --git a/ml_kernels/src/kernel_bench.cpp b/ml_kernels/src/kernel_bench.cpp
index d22dc06..3239e7b 100644
--- a/ml_kernels/src/kernel_bench.cpp
+++ b/ml_kernels/src/kernel_bench.cpp
@@ -321,6 +321,17 @@ class SoftmaxV4Benchmark : public SoftmaxBenchmark {
 };
 REGISTER_BENCHMARK(SoftmaxV4Benchmark);
 
+class SoftmaxV6Benchmark : public SoftmaxBenchmark {
+public:
+  const char *name() const override { return "softmax_v6"; }
+
+  void run() override {
+    ml_kernels::softmax_v6(inputs_[current_idx_].data(), outputs_[current_idx_].data(), inputs_[0].size());
+    current_idx_ = (current_idx_ + 1) % pool_size_;
+  }
+};
+REGISTER_BENCHMARK(SoftmaxV6Benchmark);
+
 class SoftmaxV5Benchmark : public SoftmaxBenchmark {
 public:
   const char *name() const override { return "softmax_v5"; }
diff --git a/ml_kernels/src/test_naive_ops.cpp b/ml_kernels/src/test_naive_ops.cpp
index b0f27a6..bf7b26b 100644
--- a/ml_kernels/src/test_naive_ops.cpp
+++ b/ml_kernels/src/test_naive_ops.cpp
@@ -181,7 +181,36 @@ void test_softmax_v5() {
   std::cout << "test_softmax_v5 passed!" << std::endl;
 }
 
+#include <random>
+
+// Compare softmax_v6 against the scalar reference across sizes that exercise
+// the 64-wide main loop, the 16-wide tail loop, and the masked remainder
+// (including sizes that are not multiples of 16).
+void test_softmax_v6() {
+  std::cout << "Running test_softmax_v6..." << std::endl;
+  for (std::size_t n : {1, 3, 4, 8, 15, 16, 31, 32, 33, 63, 64, 65, 100, 128, 256, 1024, 1025, 4096}) {
+    std::vector<float> input(n);
+    std::vector<float> output_naive(n);
+    std::vector<float> output_v6(n);
+
+    std::mt19937 rng(n);
+    std::uniform_real_distribution<float> dist(-10.0f, 10.0f);
+    for (std::size_t i = 0; i < n; ++i) {
+      input[i] = dist(rng);
+    }
+
+    ml_kernels::softmax_naive(input.data(), output_naive.data(), n);
+    ml_kernels::softmax_v6(input.data(), output_v6.data(), n);
+
+    for (std::size_t i = 0; i < n; ++i) {
+      assert(std::fabs(output_naive[i] - output_v6[i]) < 1e-4f);
+    }
+  }
+  std::cout << "test_softmax_v6 passed!" << std::endl;
+}
+
 int main() {
+  test_softmax_v6();
   test_relu_naive();
   test_max_naive();
   test_softmax_v3();