Skip to content

PERF: Speed up permutation cluster tests via Numba union-find + compact graph#13731

Draft
sharifhsn wants to merge 1 commit intomne-tools:mainfrom
sharifhsn:perf-opt
Draft

PERF: Speed up permutation cluster tests via Numba union-find + compact graph#13731
sharifhsn wants to merge 1 commit intomne-tools:mainfrom
sharifhsn:perf-opt

Conversation

@sharifhsn
Copy link

@sharifhsn sharifhsn commented Mar 9, 2026

Reference issue

Related: #5439, #7784, #8095, #12609

What does this implement/fix?

Speeds up spatio_temporal_cluster_1samp_test (and the other permutation_cluster_* functions) by replacing hot-path Python/SciPy calls with tighter NumPy and a Numba JIT kernel.

  1. Compact-graph reindexing (_get_components): Before calling connected_components, build a subgraph containing only the supra-threshold vertices. This shrinks the CCL input from ~20K to ~200–1000 vertices — the single biggest win.

  2. Numba union-find (_st_fused_ccl): For spatio-temporal neighbor-list adjacency, replaces the Python BFS in _get_clusters_st with a single compiled union-find pass over spatial + temporal edges. Avoids constructing a scipy.sparse matrix every permutation.

  3. Skip cluster-list construction during permutations (_sums_only): The permutation loop only needs max cluster sums, not the cluster index arrays. A new _sums_only flag lets _find_clusters return sums directly via np.bincount on component labels, skipping argsort/split/concatenate.

  4. Precomputed sum-of-squares for sign-flip t-tests: For the default ttest_1samp_no_p, the sign-flip identity s²=1 means sum(X²) is constant across permutations. Each permutation computes t with a single matrix-vector multiply (signs @ X) plus a few elementwise ops, replacing np.dot + np.var + multiple temporaries.

  5. Smaller overhead reductions:

    • Vectorized cluster sums via np.add.reduceat (replaces per-cluster loop)
    • _pval_from_histogram via np.searchsorted instead of O(n·k) loop
    • Pre-computed CSR adjacency arrays (avoids redundant list→CSR rebuilds)
    • buffer_size verification skipped for built-in stat functions
    • Vectorized sign-order generation via bit-shifting

All optimizations are gated behind has_numba checks and fall back to the original code paths when Numba is not installed. No public API changes.

Benchmarks

Measured on AWS EC2 (AMD EPYC 9R14, 16 vCPU). Setup: spatio_temporal_cluster_1samp_test with 15 subjects, fsaverage ico-5 source space (20,484 vertices × 15 timepoints = 307,260 tests), tail=1, threshold=1.67.

End-to-end (before → after):

Permutations Before After Speedup
256 3.09 s 0.55 s 5.6×
1024 7.49 s 1.72 s 4.4×
2048 13.8 s 3.29 s 4.2×

Per-permutation cost drops from ~6.3 ms to ~1.5 ms (single-threaded). Fixed overhead drops from ~880 ms to ~140 ms.

Bitwise parity verified (t_obs, H0, p-values, cluster counts) across all three tail modes.

Additional information

  • The Numba JIT warmup happens once on the first call. Subsequent calls pay no warmup cost.
  • TFCE (threshold=dict(...)) correctly falls back to the original code path — the Numba fast paths are bypassed when the threshold is a dict.
  • Custom stat functions (anything other than ttest_1samp_no_p) also fall back to the original code path, still benefiting from the CCL and overhead optimizations but not the fused t-test.

@drammock
Copy link
Member

drammock commented Mar 9, 2026

this is a pretty large diff. Before we invest time in a review, a few questions:

  1. is this ready for review? if not, please mark as draft until it is ready
  2. have you run relevant tests and doc build locally? are they all passing?
  3. please disclose the way(s) in which you used AI to assist in this contribution (if any)

tip: next time, if you name the changelog newfeature.rst instead of XXXXX.newfeature.rst then one of our CIs will automatically rename it to include the PR number for you.

@sharifhsn sharifhsn marked this pull request as draft March 9, 2026 19:21
@sharifhsn
Copy link
Author

My bad, I meant to publish it as draft. Yes it's still in progress.

Tests are all passing.

AI was used to generate the code which was checked over manually to catch bugs.

I'm going to try to reduce the amount of code as much as possible while still maintaining the main speedups.

The permutation loop in spatio_temporal_cluster_1samp_test is the main
computational bottleneck for source-space cluster analyses. On fsaverage
ico-5 (~20K vertices, 15 timepoints, 307K tests), 2048 permutations
previously took ~14 seconds; this patch brings it to under 1 second on
a 16-core EPYC.

The key changes, roughly in order of impact:

- Fused Numba union-find (_st_fused_ccl) replaces the Python BFS in
  _get_clusters_st. Handles both spatial neighbors (CSR adjacency) and
  temporal self-connections in a single compiled pass, avoiding the
  overhead of the old _get_clusters_spatial + _reassign loop.

- Parallel permutation processing (_perm_batch_fast) fuses threshold
  scan + CCL + weighted bincount + argmax into one @jit(parallel=True)
  function with prange across permutations. Each perm gets its own
  pre-allocated work buffers to avoid data races.

- Batched fused t-test (_batched_fused_ttest) reads X_T once from DRAM
  and computes t-statistics for 32 permutations, amortizing memory
  traffic ~32x. Uses prange across variables.

- Single-perm fused t-test (_fused_ttest) replaces the numpy sequence
  (dot product + 8 elementwise ops) with one prange loop.

- Compact-graph CCL in _get_components: builds a subgraph of only the
  supra-threshold vertices before calling connected_components, so CCL
  operates on ~1K vertices instead of ~20K.

- Vectorized _pval_from_histogram with np.searchsorted (O(n log n)
  instead of O(n * n_perms)).

- Vectorized _get_1samp_orders via bit-shifting instead of per-element
  np.fromiter(np.binary_repr(...)).

- Pre-computed CSR arrays from _setup_adjacency, threaded through to
  _find_clusters and the permutation loop to avoid redundant
  neighbor-list-to-CSR conversion.

All fast paths are gated behind has_numba checks and identity checks on
the stat function (only activates for the built-in ttest_1samp_no_p and
f_oneway). Non-Numba fallback paths are unchanged.

Tested on EPYC 9R14 (16 vCPU), fsaverage ico-5, 15 subjects, 2048
permutations: 0.37 ms/perm wall time vs ~5.7 ms/perm baseline.
@sharifhsn sharifhsn changed the title PERF: Speed up permutation cluster tests ~15× via Numba JIT kernels PERF: Speed up permutation cluster tests via Numba union-find + compact graph Mar 9, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants