Skip to content

Commit e08f9eb

Browse files
ElektrikAkar and claude committed
feat(simd): wire multi-pair DTW into fillDistanceMatrix
Batches pairs into groups of 4 for SIMD dispatch. OpenMP parallelizes over batches. Scalar fallback preserved when Highway is disabled.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent bc15994 commit e08f9eb

File tree

1 file changed

+55
-6
lines changed

1 file changed

+55
-6
lines changed

dtwc/Problem.cpp

Lines changed: 55 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,9 @@
2222
#include "types/Range.hpp" // for Range
2323
#include "initialisation.hpp" // For initialisation functions
2424

25+
#ifdef DTWC_HAS_HIGHWAY
26+
#include "simd/multi_pair_dtw.hpp"
27+
#endif
2528

2629
#include <algorithm> // for max_element, min, min_element, sample
2730
#include <cmath> // for sqrt, floor
@@ -187,14 +190,60 @@ void Problem::fillDistanceMatrix()
187190
if (isDistanceMatrixFilled()) return;
188191

189192
const size_t N = data.size();
193+
194+
std::cout << "Distance matrix is being filled!" << std::endl;
195+
196+
#ifdef DTWC_HAS_HIGHWAY
197+
// Multi-pair SIMD path: batch 4 pairs per SIMD call.
198+
const size_t total_pairs = N * (N + 1) / 2;
199+
200+
// Pre-generate all (i,j) pairs for batching.
201+
struct Pair { size_t i, j; };
202+
std::vector<Pair> pairs;
203+
pairs.reserve(total_pairs);
204+
for (size_t i = 0; i < N; ++i)
205+
for (size_t j = i; j < N; ++j)
206+
pairs.push_back({i, j});
207+
208+
// Process pairs in batches of 4, parallelized over batches.
209+
const size_t n_batches = (pairs.size() + 3) / 4;
210+
auto batchTask = [&](size_t batch_idx) {
211+
const size_t start = batch_idx * 4;
212+
const size_t end = std::min(start + 4, pairs.size());
213+
const size_t n_in_batch = end - start;
214+
215+
const double* x_ptrs[4] = {};
216+
const double* y_ptrs[4] = {};
217+
std::size_t x_lens[4] = {};
218+
std::size_t y_lens[4] = {};
219+
220+
for (size_t k = 0; k < n_in_batch; ++k) {
221+
const auto& [pi, pj] = pairs[start + k];
222+
x_ptrs[k] = p_vec(pi).data();
223+
y_ptrs[k] = p_vec(pj).data();
224+
x_lens[k] = p_vec(pi).size();
225+
y_lens[k] = p_vec(pj).size();
226+
}
227+
228+
auto result = simd::dtw_multi_pair(x_ptrs, y_ptrs, x_lens, y_lens, n_in_batch);
229+
230+
for (size_t k = 0; k < n_in_batch; ++k) {
231+
const auto& [pi, pj] = pairs[start + k];
232+
if (pi == pj) {
233+
distMat.set(pi, pj, 0.0);
234+
} else {
235+
distMat.set(pi, pj, result.distances[k]);
236+
}
237+
}
238+
};
239+
run(batchTask, n_batches);
240+
241+
#else
242+
// Scalar fallback: one pair at a time (original implementation).
190243
auto oneTask = [this, N](size_t k) {
191-
// Decode linear triangular index k into (i, j) where 0 <= i <= j < N.
192-
// Row i starts at cumulative offset i*N - i*(i-1)/2, so we solve for i
193-
// using the quadratic formula (O(1) instead of O(N) loop).
194244
const double Nd = static_cast<double>(N);
195245
const double kd = static_cast<double>(k);
196246
size_t i = static_cast<size_t>(std::floor(Nd + 0.5 - std::sqrt((Nd + 0.5) * (Nd + 0.5) - 2.0 * kd)));
197-
// Guard against floating-point rounding: if row_start for i+1 still fits, advance.
198247
size_t row_start = i * N - i * (i - 1) / 2;
199248
if (row_start + (N - i) <= k) {
200249
row_start += (N - i);
@@ -203,9 +252,9 @@ void Problem::fillDistanceMatrix()
203252
size_t j = i + (k - row_start);
204253
distByInd(static_cast<int>(i), static_cast<int>(j));
205254
};
206-
207-
std::cout << "Distance matrix is being filled!" << std::endl;
208255
run(oneTask, N * (N + 1) / 2);
256+
#endif
257+
209258
is_distMat_filled = true;
210259
std::cout << "Distance matrix has been filled!" << std::endl;
211260
}

0 commit comments

Comments
 (0)