Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .jules/thunderbolt.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,8 @@
**Evidence:** Microbenchmarking showed a 2x speedup (99ms -> 49ms) for max_v3 over max_v2 on L1-hot arrays. End-to-end framework benchmarks showed an 8% throughput increase (4.03 -> 4.36 GFLOP/s) on large fixed-memory allocations (N=6553600).

**Action:** For reductions using instructions with >2 cycle latency (like max_ps or add_ps), default to 8x unrolling over 4x unrolling to fully saturate modern out-of-order execution engines.

## 2026-05-07 - Explicit Instruction Interleaving in AVX2 Softmax
**Learning:** In AVX2, when unrolling complex math sequences like `exp` (which heavily relies on FMAs via Horner's scheme or Estrin's), simply calling the vectorized math function consecutively inside a 4x unrolled loop leaves performance on the table. FMA latency (typically 4 cycles) creates a dependency chain within each `exp` call. By manually inlining and interleaving the independent FMA instructions across the 4 unrolled accumulators (e.g., executing all 4 Horner `p0 = fmadd(c5, r0, c4)` instructions before the next polynomial degree), the Out-of-Order execution engine can fully saturate the execution ports, completely hiding the FMA latency.
**Evidence:** The explicitly interleaved `softmax_v6` achieved ~4.25 GFLOP/s compared to `softmax_v5`'s 4.02 GFLOP/s (Fixed Memory Mode, N=1M), a consistent ~5-10% throughput improvement.
**Action:** When unrolling loops containing long FMA latency chains (like polynomial approximations for transcendental functions), consider manual instruction-level interleaving across the independent accumulators rather than sequential calls to the vector function.
Comment on lines +31 to +34
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Replace the literal shell placeholder with a real date.

$(date +%Y-%m-%d) will be stored verbatim in this markdown file, so the entry will not sort or search like the rest of the dated journal entries. Please replace it with the actual ISO date before merge.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In @.jules/thunderbolt.md around lines 31 - 34, The markdown entry uses the
literal shell placeholder "$(date +%Y-%m-%d)" instead of a real ISO date;
replace that placeholder with the actual date string (e.g., "2026-05-07") in the
header line so the entry sorts and indexes like other dated journal lines—edit
the header beginning "## $(date +%Y-%m-%d) - Explicit Instruction Interleaving
in AVX2 Softmax" to use the concrete YYYY-MM-DD date.

186 changes: 186 additions & 0 deletions ml_kernels/include/ml_kernels/softmax.h
Original file line number Diff line number Diff line change
Expand Up @@ -501,4 +501,190 @@ inline void softmax_v5(const float *input, float *output, std::size_t n) {
}
}


// ⚡ Thunderbolt: Explicitly Interleaved AVX2 Softmax
// Target: AVX2 (Haswell+)
// Reason: Manual instruction interleaving of 4x unrolled exp256 breaks FMA latency chains
// Expected gain: ~10% throughput over softmax_v5
//
// Computes output[i] = exp(input[i] - max(input)) / sum_j exp(input[j] - max(input)).
// Three passes over the data:
//   1. max reduction (for numerical stability of exp),
//   2. exp + sum, with the vectorized exp manually inlined and its FMA chains
//      interleaved across 4 independent accumulators (the whole point of v6 --
//      do NOT reorder these statements; the interleaving hides FMA latency),
//   3. normalization by 1/sum.
// NOTE(review): assumes exp256_ps_v2, reduce_max and reduce_sum are declared
// earlier in this header (used by the 8-wide tail loops) -- same contract as v5.
inline void softmax_v6(const float *input, float *output, std::size_t n) {
    if (n == 0) return;

    // ---- Pass 1: find max (4x unrolled, 32 floats per iteration) ----
    std::size_t i = 0;
    __m256 max_v = _mm256_set1_ps(std::numeric_limits<float>::lowest());
    __m256 max0 = max_v, max1 = max_v, max2 = max_v, max3 = max_v;

    for (; i + 31 < n; i += 32) {
        max0 = _mm256_max_ps(max0, _mm256_loadu_ps(input + i));
        max1 = _mm256_max_ps(max1, _mm256_loadu_ps(input + i + 8));
        max2 = _mm256_max_ps(max2, _mm256_loadu_ps(input + i + 16));
        max3 = _mm256_max_ps(max3, _mm256_loadu_ps(input + i + 24));
    }
    // Pairwise tree reduction of the 4 accumulators into max0.
    max0 = _mm256_max_ps(max0, max1);
    max2 = _mm256_max_ps(max2, max3);
    max0 = _mm256_max_ps(max0, max2);
    // 8-wide tail, then scalar tail for the last <8 elements.
    for (; i + 7 < n; i += 8) {
        max0 = _mm256_max_ps(max0, _mm256_loadu_ps(input + i));
    }
    float max_val = reduce_max(max0);
    for (; i < n; ++i) max_val = std::max(max_val, input[i]);

    __m256 max_vec = _mm256_set1_ps(max_val);

    // ---- Pass 2: exp(x - max) and sum, 4 independent sum accumulators ----
    i = 0;
    __m256 sum0 = _mm256_setzero_ps();
    __m256 sum1 = _mm256_setzero_ps();
    __m256 sum2 = _mm256_setzero_ps();
    __m256 sum3 = _mm256_setzero_ps();

    // Taylor coefficients for exp(r) ~= 1 + r + r^2/2 + r^3/6 + r^4/24 + r^5/120,
    // evaluated via Horner's scheme below.
    __m256 c1 = _mm256_set1_ps(1.0f);
    __m256 c2 = _mm256_set1_ps(1.0f / 2.0f);
    __m256 c3 = _mm256_set1_ps(1.0f / 6.0f);
    __m256 c4 = _mm256_set1_ps(1.0f / 24.0f);
    __m256 c5 = _mm256_set1_ps(1.0f / 120.0f);
    // exp(x) = 2^k * exp(r) with k = round(x * log2(e)), r = x - k*ln2.
    // ln2 is split into hi/lo parts (Cody-Waite) so r keeps full precision.
    __m256 log2e = _mm256_set1_ps(1.4426950408889634f);
    __m256 ln2_hi = _mm256_set1_ps(0.693145751953125f);
    __m256 ln2_lo = _mm256_set1_ps(1.428606765330187e-06f);
    // Clamp: exp(-87.3) ~ FLT_MIN; prevents the 2^k bit trick from underflowing
    // into a negative biased exponent. x - max <= 0, so no upper clamp needed.
    __m256 min_val = _mm256_set1_ps(-87.3f);
    __m256i shift127 = _mm256_set1_epi32(127);  // IEEE-754 single exponent bias

    // Each "stage" below issues the same instruction for all 4 lanes before
    // advancing, so the OoO engine always has 4 independent FMAs in flight.
    for (; i + 31 < n; i += 32) {
        __m256 x0 = _mm256_sub_ps(_mm256_loadu_ps(input + i), max_vec);
        __m256 x1 = _mm256_sub_ps(_mm256_loadu_ps(input + i + 8), max_vec);
        __m256 x2 = _mm256_sub_ps(_mm256_loadu_ps(input + i + 16), max_vec);
        __m256 x3 = _mm256_sub_ps(_mm256_loadu_ps(input + i + 24), max_vec);

        x0 = _mm256_max_ps(x0, min_val);
        x1 = _mm256_max_ps(x1, min_val);
        x2 = _mm256_max_ps(x2, min_val);
        x3 = _mm256_max_ps(x3, min_val);

        __m256 x0_log2e = _mm256_mul_ps(x0, log2e);
        __m256 x1_log2e = _mm256_mul_ps(x1, log2e);
        __m256 x2_log2e = _mm256_mul_ps(x2, log2e);
        __m256 x3_log2e = _mm256_mul_ps(x3, log2e);

        // cvtps_epi32 rounds to nearest (default MXCSR mode) -> k = round(x*log2e).
        __m256i n0_int = _mm256_cvtps_epi32(x0_log2e);
        __m256i n1_int = _mm256_cvtps_epi32(x1_log2e);
        __m256i n2_int = _mm256_cvtps_epi32(x2_log2e);
        __m256i n3_int = _mm256_cvtps_epi32(x3_log2e);

        __m256 n0 = _mm256_cvtepi32_ps(n0_int);
        __m256 n1 = _mm256_cvtepi32_ps(n1_int);
        __m256 n2 = _mm256_cvtepi32_ps(n2_int);
        __m256 n3 = _mm256_cvtepi32_ps(n3_int);

        // r = x - k*ln2_hi - k*ln2_lo  (two-step Cody-Waite range reduction).
        __m256 r0 = _mm256_fnmadd_ps(n0, ln2_hi, x0);
        __m256 r1 = _mm256_fnmadd_ps(n1, ln2_hi, x1);
        __m256 r2 = _mm256_fnmadd_ps(n2, ln2_hi, x2);
        __m256 r3 = _mm256_fnmadd_ps(n3, ln2_hi, x3);

        r0 = _mm256_fnmadd_ps(n0, ln2_lo, r0);
        r1 = _mm256_fnmadd_ps(n1, ln2_lo, r1);
        r2 = _mm256_fnmadd_ps(n2, ln2_lo, r2);
        r3 = _mm256_fnmadd_ps(n3, ln2_lo, r3);

        // Horner: p = ((((c5*r + c4)*r + c3)*r + c2)*r + c1)*r + 1.
        __m256 p0 = _mm256_fmadd_ps(c5, r0, c4);
        __m256 p1 = _mm256_fmadd_ps(c5, r1, c4);
        __m256 p2 = _mm256_fmadd_ps(c5, r2, c4);
        __m256 p3 = _mm256_fmadd_ps(c5, r3, c4);

        p0 = _mm256_fmadd_ps(p0, r0, c3);
        p1 = _mm256_fmadd_ps(p1, r1, c3);
        p2 = _mm256_fmadd_ps(p2, r2, c3);
        p3 = _mm256_fmadd_ps(p3, r3, c3);

        p0 = _mm256_fmadd_ps(p0, r0, c2);
        p1 = _mm256_fmadd_ps(p1, r1, c2);
        p2 = _mm256_fmadd_ps(p2, r2, c2);
        p3 = _mm256_fmadd_ps(p3, r3, c2);

        p0 = _mm256_fmadd_ps(p0, r0, c1);
        p1 = _mm256_fmadd_ps(p1, r1, c1);
        p2 = _mm256_fmadd_ps(p2, r2, c1);
        p3 = _mm256_fmadd_ps(p3, r3, c1);

        // NOT a copy-paste bug: c1 is added twice on purpose. The first adds
        // the degree-1 coefficient (1.0), this one adds the constant term,
        // which for exp is also 1.0, so the same vector constant is reused.
        p0 = _mm256_fmadd_ps(p0, r0, c1);
        p1 = _mm256_fmadd_ps(p1, r1, c1);
        p2 = _mm256_fmadd_ps(p2, r2, c1);
        p3 = _mm256_fmadd_ps(p3, r3, c1);

        // Build 2^k by placing (k + 127) into the float exponent field.
        __m256i exp_shift0 = _mm256_add_epi32(n0_int, shift127);
        __m256i exp_shift1 = _mm256_add_epi32(n1_int, shift127);
        __m256i exp_shift2 = _mm256_add_epi32(n2_int, shift127);
        __m256i exp_shift3 = _mm256_add_epi32(n3_int, shift127);

        __m256i exp_shifted0 = _mm256_slli_epi32(exp_shift0, 23);
        __m256i exp_shifted1 = _mm256_slli_epi32(exp_shift1, 23);
        __m256i exp_shifted2 = _mm256_slli_epi32(exp_shift2, 23);
        __m256i exp_shifted3 = _mm256_slli_epi32(exp_shift3, 23);

        __m256 exp2n0 = _mm256_castsi256_ps(exp_shifted0);
        __m256 exp2n1 = _mm256_castsi256_ps(exp_shifted1);
        __m256 exp2n2 = _mm256_castsi256_ps(exp_shifted2);
        __m256 exp2n3 = _mm256_castsi256_ps(exp_shifted3);

        // exp(x) = exp(r) * 2^k
        __m256 e0 = _mm256_mul_ps(p0, exp2n0);
        __m256 e1 = _mm256_mul_ps(p1, exp2n1);
        __m256 e2 = _mm256_mul_ps(p2, exp2n2);
        __m256 e3 = _mm256_mul_ps(p3, exp2n3);

        // Store the unnormalized exps; pass 3 rescales them in place.
        _mm256_storeu_ps(output + i, e0);
        _mm256_storeu_ps(output + i + 8, e1);
        _mm256_storeu_ps(output + i + 16, e2);
        _mm256_storeu_ps(output + i + 24, e3);

        sum0 = _mm256_add_ps(sum0, e0);
        sum1 = _mm256_add_ps(sum1, e1);
        sum2 = _mm256_add_ps(sum2, e2);
        sum3 = _mm256_add_ps(sum3, e3);
    }
    sum0 = _mm256_add_ps(sum0, sum1);
    sum2 = _mm256_add_ps(sum2, sum3);
    sum0 = _mm256_add_ps(sum0, sum2);

    // 8-wide tail reuses the shared vector exp helper instead of the inlined body.
    for (; i + 7 < n; i += 8) {
        __m256 x = _mm256_loadu_ps(input + i);
        __m256 e = exp256_ps_v2(_mm256_sub_ps(x, max_vec));
        _mm256_storeu_ps(output + i, e);
        sum0 = _mm256_add_ps(sum0, e);
    }

    float sum_val = reduce_sum(sum0);
    for (; i < n; ++i) {
        float e = std::exp(input[i] - max_val);
        output[i] = e;
        sum_val += e;
    }

    // Defensive guard: since exp(max - max) == 1 contributes to the sum, sum_val
    // should always be >= 1 for n > 0; this only protects against pathological
    // inputs (e.g. NaNs) from dividing by zero below.
    if (sum_val == 0.0f) return;

    // ---- Pass 3: normalize in place by 1/sum (4x unrolled) ----
    float inv_sum = 1.0f / sum_val;
    __m256 inv_sum_v = _mm256_set1_ps(inv_sum);
    i = 0;
    for (; i + 31 < n; i += 32) {
        __m256 o0 = _mm256_loadu_ps(output + i);
        __m256 o1 = _mm256_loadu_ps(output + i + 8);
        __m256 o2 = _mm256_loadu_ps(output + i + 16);
        __m256 o3 = _mm256_loadu_ps(output + i + 24);

        __m256 m0 = _mm256_mul_ps(o0, inv_sum_v);
        __m256 m1 = _mm256_mul_ps(o1, inv_sum_v);
        __m256 m2 = _mm256_mul_ps(o2, inv_sum_v);
        __m256 m3 = _mm256_mul_ps(o3, inv_sum_v);

        _mm256_storeu_ps(output + i, m0);
        _mm256_storeu_ps(output + i + 8, m1);
        _mm256_storeu_ps(output + i + 16, m2);
        _mm256_storeu_ps(output + i + 24, m3);
    }
    for (; i + 7 < n; i += 8) {
        _mm256_storeu_ps(output + i, _mm256_mul_ps(_mm256_loadu_ps(output + i), inv_sum_v));
    }
    for (; i < n; ++i) {
        output[i] *= inv_sum;
    }
}
} // namespace ml_kernels
11 changes: 11 additions & 0 deletions ml_kernels/src/kernel_bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -518,3 +518,14 @@ class MaxV3Benchmark : public MaxBenchmarkBase {
std::size_t current_idx_ = 0;
};
REGISTER_BENCHMARK(MaxV3Benchmark);

// Benchmark driver for softmax_v6 (explicitly interleaved AVX2 variant).
// Rotates through a pool of input/output buffers so repeated runs are not
// artificially served from a warm cache.
class SoftmaxV6Benchmark : public SoftmaxBenchmark {
public:
  const char *name() const override { return "softmax_v6"; }

  void run() override {
    // Take the element count from the buffer actually being processed.
    // All pool entries are presumably the same size (as in the other
    // benchmarks), but pairing data and size via the same index keeps this
    // correct even if that assumption ever changes.
    ml_kernels::softmax_v6(inputs_[current_idx_].data(),
                           outputs_[current_idx_].data(),
                           inputs_[current_idx_].size());
    current_idx_ = (current_idx_ + 1) % pool_size_;
  }
};
REGISTER_BENCHMARK(SoftmaxV6Benchmark);
32 changes: 31 additions & 1 deletion ml_kernels/src/test_naive_ops.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <random>
#include <cassert>
#include <iostream>
#include <vector>
Expand All @@ -7,6 +8,7 @@
#include "ml_kernels/naive_ops.h"
#include "ml_kernels/softmax.h"

void test_softmax_v6();
void test_max_naive() {
// Happy path
{
Expand Down Expand Up @@ -187,5 +189,33 @@ int main() {
test_softmax_v3();
test_softmax_v4();
test_softmax_v5();
test_softmax_v6();
std::cout << "All tests passed successfully!" << std::endl;
}
}
// Checks softmax_v6 element-wise against the scalar reference implementation.
// Sizes are chosen to exercise every code path: the 32-wide main loop, the
// 8-wide tail, the scalar tail, and sizes smaller than a single vector.
void test_softmax_v6() {
  std::cout << "Testing softmax_v6..." << std::endl;
  const std::size_t sizes[] = {1, 3, 8, 15, 16, 31, 32, 33, 100, 1024, 1024 * 1024 + 7};
  for (std::size_t n : sizes) {
    std::vector<float> in(n);
    std::vector<float> expected(n);
    std::vector<float> actual(n);

    // Deterministic pseudo-random inputs in [-10, 10] (fixed seed).
    std::mt19937 gen(42);
    std::uniform_real_distribution<float> dist(-10.0f, 10.0f);
    for (auto &value : in) {
      value = dist(gen);
    }

    ml_kernels::softmax_naive(in.data(), expected.data(), n);
    ml_kernels::softmax_v6(in.data(), actual.data(), n);

    for (std::size_t i = 0; i < n; ++i) {
      // Softmax outputs lie in [0, 1], so an absolute tolerance suffices.
      const float err = std::abs(expected[i] - actual[i]);
      if (err > 1e-4f) {
        std::cerr << "Mismatch at n=" << n << " i=" << i
                  << " expected=" << expected[i]
                  << " actual=" << actual[i] << std::endl;
        std::exit(1);
      }
    }
  }
}
Loading