Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 160 additions & 0 deletions ml_kernels/include/ml_kernels/softmax.h
Original file line number Diff line number Diff line change
Expand Up @@ -501,4 +501,164 @@ inline void softmax_v5(const float *input, float *output, std::size_t n) {
}
}


#ifdef __AVX512F__
// ⚡ Thunderbolt: AVX-512 Vectorized Softmax
// Target: AVX-512 (Skylake-X+)
// Reason: AVX-512 processes 16 floats per vector (vs 8 in AVX2), doubling the
// theoretical throughput, and its mask registers allow tail handling without
// scalar fallback code.
// Expected gain: ~1.5-2.0x over the AVX2 implementation (softmax_v5) on
// capable hardware.

// Vectorized exp() approximation for 16 floats.
//
// Range reduction: x = n*ln(2) + r, so e^x = 2^n * e^r, where e^r is a
// degree-5 Horner polynomial and 2^n is assembled directly in the float
// exponent bits ((n + 127) << 23). Input is clamped at -87.3f (just above the
// float underflow threshold for e^x) so the biased exponent cannot wrap
// negative. Hoisted to namespace scope so the unrolled loop, the 16-wide
// tail, and the masked tail of softmax_v6 all share one copy and the
// constants cannot silently diverge.
inline __m512 exp512_ps_v2(__m512 x) {
    x = _mm512_max_ps(x, _mm512_set1_ps(-87.3f));
    // n = round(x * log2(e))
    __m512 x_log2e = _mm512_mul_ps(x, _mm512_set1_ps(1.4426950408889634f));
    __m512i n_int = _mm512_cvt_roundps_epi32(
        x_log2e, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
    __m512 n_flt = _mm512_cvtepi32_ps(n_int);

    // r = x - n*ln(2), with ln(2) split into hi/lo parts for extra precision.
    __m512 r = _mm512_fnmadd_ps(n_flt, _mm512_set1_ps(0.693145751953125f), x);
    r = _mm512_fnmadd_ps(n_flt, _mm512_set1_ps(1.428606765330187e-06f), r);

    // e^r ≈ 1 + r + r^2/2 + r^3/6 + r^4/24 + r^5/120 (Horner form).
    __m512 p = _mm512_fmadd_ps(_mm512_set1_ps(1.0f / 120.0f), r, _mm512_set1_ps(1.0f / 24.0f));
    p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(1.0f / 6.0f));
    p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(1.0f / 2.0f));
    p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(1.0f));
    p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(1.0f));

    // 2^n via exponent-bit construction.
    __m512i exp_shifted = _mm512_slli_epi32(_mm512_add_epi32(n_int, _mm512_set1_epi32(127)), 23);
    return _mm512_mul_ps(p, _mm512_castsi512_ps(exp_shifted));
}

// Numerically-stable softmax: output[i] = exp(input[i] - max) / sum(exp(...)).
// Three passes (max, exp+sum, normalize), each 4x-unrolled over 64 floats with
// a 16-wide tail loop and a masked remainder.
inline void softmax_v6(const float *input, float *output, std::size_t n) {
    if (n == 0) return;

    // ---- Pass 1: global max (for numerical stability) ----
    std::size_t i = 0;
    const __m512 lowest = _mm512_set1_ps(std::numeric_limits<float>::lowest());
    __m512 max0 = lowest, max1 = lowest, max2 = lowest, max3 = lowest;

    for (; i + 63 < n; i += 64) {
        max0 = _mm512_max_ps(max0, _mm512_loadu_ps(input + i));
        max1 = _mm512_max_ps(max1, _mm512_loadu_ps(input + i + 16));
        max2 = _mm512_max_ps(max2, _mm512_loadu_ps(input + i + 32));
        max3 = _mm512_max_ps(max3, _mm512_loadu_ps(input + i + 48));
    }
    max0 = _mm512_max_ps(_mm512_max_ps(max0, max1), _mm512_max_ps(max2, max3));

    for (; i + 15 < n; i += 16) {
        max0 = _mm512_max_ps(max0, _mm512_loadu_ps(input + i));
    }

    float max_val = _mm512_reduce_max_ps(max0);

    if (i < n) {
        const __mmask16 mask = static_cast<__mmask16>((1u << (n - i)) - 1u);
        // Masked load with `lowest` in inactive lanes so they cannot win the max
        // (a plain maskz load would insert 0.0, which beats all-negative inputs).
        __m512 rem = _mm512_mask_loadu_ps(lowest, mask, input + i);
        max_val = std::max(max_val, _mm512_reduce_max_ps(rem));
    }

    const __m512 max_vec = _mm512_set1_ps(max_val);

    // ---- Pass 2: exponentials and their sum ----
    i = 0;
    __m512 sum0 = _mm512_setzero_ps();
    __m512 sum1 = _mm512_setzero_ps();
    __m512 sum2 = _mm512_setzero_ps();
    __m512 sum3 = _mm512_setzero_ps();

    for (; i + 63 < n; i += 64) {
        __m512 e0 = exp512_ps_v2(_mm512_sub_ps(_mm512_loadu_ps(input + i), max_vec));
        __m512 e1 = exp512_ps_v2(_mm512_sub_ps(_mm512_loadu_ps(input + i + 16), max_vec));
        __m512 e2 = exp512_ps_v2(_mm512_sub_ps(_mm512_loadu_ps(input + i + 32), max_vec));
        __m512 e3 = exp512_ps_v2(_mm512_sub_ps(_mm512_loadu_ps(input + i + 48), max_vec));

        _mm512_storeu_ps(output + i, e0);
        _mm512_storeu_ps(output + i + 16, e1);
        _mm512_storeu_ps(output + i + 32, e2);
        _mm512_storeu_ps(output + i + 48, e3);

        sum0 = _mm512_add_ps(sum0, e0);
        sum1 = _mm512_add_ps(sum1, e1);
        sum2 = _mm512_add_ps(sum2, e2);
        sum3 = _mm512_add_ps(sum3, e3);
    }
    sum0 = _mm512_add_ps(_mm512_add_ps(sum0, sum1), _mm512_add_ps(sum2, sum3));

    for (; i + 15 < n; i += 16) {
        __m512 e = exp512_ps_v2(_mm512_sub_ps(_mm512_loadu_ps(input + i), max_vec));
        _mm512_storeu_ps(output + i, e);
        sum0 = _mm512_add_ps(sum0, e);
    }

    float sum_val = _mm512_reduce_add_ps(sum0);

    if (i < n) {
        const __mmask16 mask = static_cast<__mmask16>((1u << (n - i)) - 1u);
        __m512 x = _mm512_mask_loadu_ps(lowest, mask, input + i);
        __m512 e = exp512_ps_v2(_mm512_sub_ps(x, max_vec));
        _mm512_mask_storeu_ps(output + i, mask, e);
        // Zero out inactive lanes so they do not contaminate the reduction.
        sum_val += _mm512_reduce_add_ps(_mm512_maskz_mov_ps(mask, e));
    }

    // Guard against divide-by-zero (should not occur for finite input, since
    // exp(max - max) = 1 always contributes to the sum).
    if (sum_val == 0.0f) return;

    // ---- Pass 3: normalize by 1 / sum ----
    const __m512 inv_sum_v = _mm512_set1_ps(1.0f / sum_val);
    i = 0;

    for (; i + 63 < n; i += 64) {
        _mm512_storeu_ps(output + i, _mm512_mul_ps(_mm512_loadu_ps(output + i), inv_sum_v));
        _mm512_storeu_ps(output + i + 16, _mm512_mul_ps(_mm512_loadu_ps(output + i + 16), inv_sum_v));
        _mm512_storeu_ps(output + i + 32, _mm512_mul_ps(_mm512_loadu_ps(output + i + 32), inv_sum_v));
        _mm512_storeu_ps(output + i + 48, _mm512_mul_ps(_mm512_loadu_ps(output + i + 48), inv_sum_v));
    }

    for (; i + 15 < n; i += 16) {
        _mm512_storeu_ps(output + i, _mm512_mul_ps(_mm512_loadu_ps(output + i), inv_sum_v));
    }

    if (i < n) {
        const __mmask16 mask = static_cast<__mmask16>((1u << (n - i)) - 1u);
        __m512 o = _mm512_maskz_loadu_ps(mask, output + i);
        _mm512_mask_storeu_ps(output + i, mask, _mm512_mul_ps(o, inv_sum_v));
    }
}
#else
// Fallback when AVX-512 is not available at compile time: defer to the AVX2
// implementation so callers can use softmax_v6 unconditionally.
inline void softmax_v6(const float *input, float *output, std::size_t n) {
    softmax_v5(input, output, n);
}
#endif
} // namespace ml_kernels
16 changes: 16 additions & 0 deletions ml_kernels/src/kernel_bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,21 @@ class SoftmaxV4Benchmark : public SoftmaxBenchmark {
};
REGISTER_BENCHMARK(SoftmaxV4Benchmark);






class SoftmaxV6Benchmark : public SoftmaxBenchmark {
public:
    const char *name() const override { return "softmax_v6"; }

    // One benchmark iteration: run softmax_v6 over the next buffer in the
    // rotating input pool.
    // NOTE(review): the element count comes from inputs_[0].size() — assumes
    // every pooled buffer has the same length; confirm against the
    // SoftmaxBenchmark setup.
    void run() override {
        const auto idx = current_idx_;
        ml_kernels::softmax_v6(inputs_[idx].data(), outputs_[idx].data(), inputs_[0].size());
        current_idx_ = (idx + 1) % pool_size_;
    }
};

class SoftmaxV5Benchmark : public SoftmaxBenchmark {
public:
const char *name() const override { return "softmax_v5"; }
Expand All @@ -330,6 +345,7 @@ class SoftmaxV5Benchmark : public SoftmaxBenchmark {
current_idx_ = (current_idx_ + 1) % pool_size_;
}
};
REGISTER_BENCHMARK(SoftmaxV6Benchmark);
REGISTER_BENCHMARK(SoftmaxV5Benchmark);

} // namespace
Expand Down
29 changes: 29 additions & 0 deletions ml_kernels/src/test_naive_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,36 @@ void test_softmax_v5() {
std::cout << "test_softmax_v5 passed!" << std::endl;
}




#include <random>

// Cross-checks softmax_v6 against the naive reference on pseudo-random input,
// covering n == 0 (the explicit early-out in softmax_v6), sizes around each
// SIMD width boundary (16-lane vectors, 64-element unrolled blocks), masked
// tails, and large buffers. Also asserts the probability-distribution
// invariant (outputs sum to 1) to catch normalization-stage bugs that an
// elementwise comparison against a matching reference error would miss.
void test_softmax_v6() {
    std::cout << "Running test_softmax_v6..." << std::endl;
    for (std::size_t n : {0, 1, 3, 4, 8, 15, 16, 31, 32, 33, 63, 64, 65, 100, 128, 256, 1024, 1025, 4096}) {
        std::vector<float> input(n);
        std::vector<float> output_naive(n);
        std::vector<float> output_v6(n);

        std::mt19937 rng(n);  // deterministic per-size seed for reproducibility
        std::uniform_real_distribution<float> dist(-10.0f, 10.0f);
        for (std::size_t i = 0; i < n; ++i) {
            input[i] = dist(rng);
        }

        ml_kernels::softmax_naive(input.data(), output_naive.data(), n);
        ml_kernels::softmax_v6(input.data(), output_v6.data(), n);

        float sum = 0.0f;
        for (std::size_t i = 0; i < n; ++i) {
            assert(std::fabs(output_naive[i] - output_v6[i]) < 1e-4f);
            sum += output_v6[i];
        }
        // Softmax output is a probability distribution: it must sum to 1.
        if (n > 0) {
            assert(std::fabs(sum - 1.0f) < 1e-4f);
        }
    }
    std::cout << "test_softmax_v6 passed!" << std::endl;
}
Comment on lines +189 to +210
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Add n == 0 and a sum-to-1 invariant check.

softmax_v6 has an explicit if (n == 0) return; early-out (softmax.h Line 511) which is currently uncovered by this test set — the smallest size exercised is 1. A zero-length probe is essentially free and locks in that the AVX-512 path doesn't, e.g., dereference input + 0 with a non-empty mask if someone refactors the prologue later.

Also, the v3/v4/v5 tests assert |sum - 1.0f| < 1e-4f after the loop, but this test only checks elementwise agreement with softmax_naive. Adding the same probability-distribution invariant catches normalization-stage bugs (the masked tail at Lines 652–656 of softmax.h is a likely place for them) even if they happen to mirror an error in the naive reference.

🧪 Suggested additions
-    for (std::size_t n : {1, 3, 4, 8, 15, 16, 31, 32, 33, 63, 64, 65, 100, 128, 256, 1024, 1025, 4096}) {
+    for (std::size_t n : {0, 1, 3, 4, 8, 15, 16, 31, 32, 33, 63, 64, 65, 100, 128, 256, 1024, 1025, 4096}) {
         std::vector<float> input(n);
         std::vector<float> output_naive(n);
         std::vector<float> output_v6(n);
@@
         ml_kernels::softmax_naive(input.data(), output_naive.data(), n);
         ml_kernels::softmax_v6(input.data(), output_v6.data(), n);
 
-        for (std::size_t i = 0; i < n; ++i) {
-            assert(std::fabs(output_naive[i] - output_v6[i]) < 1e-4f);
-        }
+        float sum = 0.0f;
+        for (std::size_t i = 0; i < n; ++i) {
+            assert(std::fabs(output_naive[i] - output_v6[i]) < 1e-4f);
+            sum += output_v6[i];
+        }
+        if (n > 0) {
+            assert(std::fabs(sum - 1.0f) < 1e-4f);
+        }
     }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
void test_softmax_v6() {
std::cout << "Running test_softmax_v6..." << std::endl;
for (std::size_t n : {1, 3, 4, 8, 15, 16, 31, 32, 33, 63, 64, 65, 100, 128, 256, 1024, 1025, 4096}) {
std::vector<float> input(n);
std::vector<float> output_naive(n);
std::vector<float> output_v6(n);
std::mt19937 rng(n);
std::uniform_real_distribution<float> dist(-10.0f, 10.0f);
for (std::size_t i = 0; i < n; ++i) {
input[i] = dist(rng);
}
ml_kernels::softmax_naive(input.data(), output_naive.data(), n);
ml_kernels::softmax_v6(input.data(), output_v6.data(), n);
for (std::size_t i = 0; i < n; ++i) {
assert(std::fabs(output_naive[i] - output_v6[i]) < 1e-4f);
}
}
std::cout << "test_softmax_v6 passed!" << std::endl;
}
void test_softmax_v6() {
std::cout << "Running test_softmax_v6..." << std::endl;
for (std::size_t n : {0, 1, 3, 4, 8, 15, 16, 31, 32, 33, 63, 64, 65, 100, 128, 256, 1024, 1025, 4096}) {
std::vector<float> input(n);
std::vector<float> output_naive(n);
std::vector<float> output_v6(n);
std::mt19937 rng(n);
std::uniform_real_distribution<float> dist(-10.0f, 10.0f);
for (std::size_t i = 0; i < n; ++i) {
input[i] = dist(rng);
}
ml_kernels::softmax_naive(input.data(), output_naive.data(), n);
ml_kernels::softmax_v6(input.data(), output_v6.data(), n);
float sum = 0.0f;
for (std::size_t i = 0; i < n; ++i) {
assert(std::fabs(output_naive[i] - output_v6[i]) < 1e-4f);
sum += output_v6[i];
}
if (n > 0) {
assert(std::fabs(sum - 1.0f) < 1e-4f);
}
}
std::cout << "test_softmax_v6 passed!" << std::endl;
}
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@ml_kernels/src/test_naive_ops.cpp` around lines 189 - 210, The
test_softmax_v6 test misses exercising the n == 0 early-out in softmax_v6 and
lacks a sum-to-1 invariant check; add a zero-length case (n = 0) where you
create empty input/output vectors and call ml_kernels::softmax_naive and
ml_kernels::softmax_v6 to ensure they return without crashing, and for every n
(including n==0) compute the sum of output_v6 (and/or output_naive) and assert
fabs(sum - 1.0f) < 1e-4f in addition to the existing elementwise comparisons to
catch normalization/tail-mask bugs (refer to functions test_softmax_v6,
softmax_v6, softmax_naive).


int main() {
test_softmax_v6();
test_relu_naive();
test_max_naive();
test_softmax_v3();
Expand Down
Loading