PERF: Speed up permutation cluster tests via Numba union-find + compact graph#13731
Draft
sharifhsn wants to merge 1 commit intomne-tools:mainfrom
Draft
PERF: Speed up permutation cluster tests via Numba union-find + compact graph#13731sharifhsn wants to merge 1 commit intomne-tools:mainfrom
sharifhsn wants to merge 1 commit intomne-tools:mainfrom
Conversation
Member
|
this is a pretty large diff. Before we invest time in a review, a few questions:
tip: next time, if you name the changelog |
Author
|
My bad, I meant to publish it as draft. Yes it's still in progress. Tests are all passing. AI was used to generate the code which was checked over manually to catch bugs. I'm going to try to reduce the amount of code as much as possible while still maintaining the main speedups. |
The permutation loop in spatio_temporal_cluster_1samp_test is the main computational bottleneck for source-space cluster analyses. On fsaverage ico-5 (~20K vertices, 15 timepoints, 307K tests), 2048 permutations previously took ~14 seconds; this patch brings it to under 1 second on a 16-core EPYC. The key changes, roughly in order of impact: - Fused Numba union-find (_st_fused_ccl) replaces the Python BFS in _get_clusters_st. Handles both spatial neighbors (CSR adjacency) and temporal self-connections in a single compiled pass, avoiding the overhead of the old _get_clusters_spatial + _reassign loop. - Parallel permutation processing (_perm_batch_fast) fuses threshold scan + CCL + weighted bincount + argmax into one @jit(parallel=True) function with prange across permutations. Each perm gets its own pre-allocated work buffers to avoid data races. - Batched fused t-test (_batched_fused_ttest) reads X_T once from DRAM and computes t-statistics for 32 permutations, amortizing memory traffic ~32x. Uses prange across variables. - Single-perm fused t-test (_fused_ttest) replaces the numpy sequence (dot product + 8 elementwise ops) with one prange loop. - Compact-graph CCL in _get_components: builds a subgraph of only the supra-threshold vertices before calling connected_components, so CCL operates on ~1K vertices instead of ~20K. - Vectorized _pval_from_histogram with np.searchsorted (O(n log n) instead of O(n * n_perms)). - Vectorized _get_1samp_orders via bit-shifting instead of per-element np.fromiter(np.binary_repr(...)). - Pre-computed CSR arrays from _setup_adjacency, threaded through to _find_clusters and the permutation loop to avoid redundant neighbor-list-to-CSR conversion. All fast paths are gated behind has_numba checks and identity checks on the stat function (only activates for the built-in ttest_1samp_no_p and f_oneway). Non-Numba fallback paths are unchanged. Tested on EPYC 9R14 (16 vCPU), fsaverage ico-5, 15 subjects, 2048 permutations: 0.37 ms/perm wall time vs ~5.7 ms/perm baseline.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Reference issue
Related: #5439, #7784, #8095, #12609
What does this implement/fix?
Speeds up
spatio_temporal_cluster_1samp_test(and the otherpermutation_cluster_*functions) by replacing hot-path Python/SciPy calls with tighter NumPy and a Numba JIT kernel.Compact-graph reindexing (
_get_components): Before callingconnected_components, build a subgraph containing only the supra-threshold vertices. This shrinks the CCL input from ~20K to ~200–1000 vertices — the single biggest win.Numba union-find (
_st_fused_ccl): For spatio-temporal neighbor-list adjacency, replaces the Python BFS in_get_clusters_stwith a single compiled union-find pass over spatial + temporal edges. Avoids constructing ascipy.sparsematrix every permutation.Skip cluster-list construction during permutations (
_sums_only): The permutation loop only needs max cluster sums, not the cluster index arrays. A new_sums_onlyflag lets_find_clustersreturn sums directly vianp.bincounton component labels, skippingargsort/split/concatenate.Precomputed sum-of-squares for sign-flip t-tests: For the default
ttest_1samp_no_p, the sign-flip identity s²=1 meanssum(X²)is constant across permutations. Each permutation computestwith a single matrix-vector multiply (signs @ X) plus a few elementwise ops, replacingnp.dot+np.var+ multiple temporaries.Smaller overhead reductions:
np.add.reduceat(replaces per-cluster loop)_pval_from_histogramvianp.searchsortedinstead of O(n·k) loopbuffer_sizeverification skipped for built-in stat functionsAll optimizations are gated behind
has_numbachecks and fall back to the original code paths when Numba is not installed. No public API changes.Benchmarks
Measured on AWS EC2 (AMD EPYC 9R14, 16 vCPU). Setup:
spatio_temporal_cluster_1samp_testwith 15 subjects, fsaverage ico-5 source space (20,484 vertices × 15 timepoints = 307,260 tests),tail=1,threshold=1.67.End-to-end (before → after):
Per-permutation cost drops from ~6.3 ms to ~1.5 ms (single-threaded). Fixed overhead drops from ~880 ms to ~140 ms.
Bitwise parity verified (t_obs, H0, p-values, cluster counts) across all three tail modes.
Additional information
threshold=dict(...)) correctly falls back to the original code path — the Numba fast paths are bypassed when the threshold is a dict.ttest_1samp_no_p) also fall back to the original code path, still benefiting from the CCL and overhead optimizations but not the fused t-test.