diff --git a/doc/changes/dev/13698.other.rst b/doc/changes/dev/13698.other.rst new file mode 100644 index 00000000000..fab48e560d1 --- /dev/null +++ b/doc/changes/dev/13698.other.rst @@ -0,0 +1 @@ +Add optional low-variance ("hat") regularization to :func:`mne.stats.f_oneway` via new ``sigma`` and ``method`` parameters, by `Aniket Singh Yadav`_. \ No newline at end of file diff --git a/mne/stats/parametric.py b/mne/stats/parametric.py index 2cc0bff2ea1..7cdbcb367ea 100644 --- a/mne/stats/parametric.py +++ b/mne/stats/parametric.py @@ -111,7 +111,7 @@ def ttest_ind_no_p(a, b, equal_var=True, sigma=0.0): return t -def f_oneway(*args): +def f_oneway(*args, sigma=0.0, method="relative"): """Perform a 1-way ANOVA. The one-way ANOVA tests the null hypothesis that 2 or more groups have @@ -125,6 +125,16 @@ def f_oneway(*args): ---------- *args : array_like The sample measurements should be given as arguments. + sigma : float + Regularization parameter (>= 0) added to the within-group mean + square to mitigate F-statistic inflation under low-variance + conditions. ``0`` (default) disables correction. + method : str + How *sigma* is interpreted when ``sigma > 0``. Can be + ``'relative'`` (default) or ``'absolute'``. + ``'relative'`` multiplies *sigma* by the median within-group + mean square (scale-invariant, recommended). + ``'absolute'`` uses *sigma* directly as a raw sigma squared. Returns ------- @@ -151,6 +161,9 @@ def f_oneway(*args): ---------- .. footbibliography:: """ + _check_option("method", method, ["absolute", "relative"]) + if sigma < 0: + raise ValueError(f"sigma must be >= 0, got {sigma}") n_classes = len(args) n_samples_per_class = np.array([len(a) for a in args]) n_samples = np.sum(n_samples_per_class) @@ -168,6 +181,12 @@ def f_oneway(*args): dfwn = n_samples - n_classes msb = ssbn / float(dfbn) msw = sswn / float(dfwn) + if sigma > 0.0: + if method == "relative": + sigma_sq = sigma * np.median(msw) + else: + sigma_sq = float(sigma) + msw = (sswn + sigma_sq) / float(dfwn) f = msb / msw return f diff --git a/mne/stats/tests/test_parametric.py b/mne/stats/tests/test_parametric.py index 61ecbc43af3..79d274ca4ba 100644 --- a/mne/stats/tests/test_parametric.py +++ b/mne/stats/tests/test_parametric.py @@ -11,7 +11,7 @@ from numpy.testing import assert_allclose, assert_array_almost_equal, assert_array_less import mne -from mne.stats.parametric import _map_effects, f_mway_rm, f_threshold_mway_rm +from mne.stats.parametric import _map_effects, f_mway_rm, f_oneway, f_threshold_mway_rm # hardcoded external test results, manually transferred test_external = { @@ -175,3 +175,57 @@ def theirs(*a, **kw): # something to the divisor (var) assert_allclose(got, want, rtol=2e-1, atol=1e-2) assert_array_less(np.abs(got), np.abs(want)) + + +@pytest.mark.parametrize("sigma", [0.0, 1e-3]) +@pytest.mark.parametrize("method", ["absolute", "relative"]) +@pytest.mark.parametrize("seed", [0, 42, 1337]) +def test_f_oneway_hat(sigma, method, seed): + """Test f_oneway hat (low-variance) regularization.""" + rng = np.random.RandomState(seed) + X1 = rng.randn(10, 50) + X2 = rng.randn(10, 50) + + f_ours = f_oneway(X1, X2, sigma=0.0, method=method) + f_scipy = scipy.stats.f_oneway(X1, X2)[0] + assert_allclose(f_ours, f_scipy, rtol=1e-7, atol=1e-6) + + if sigma > 0: + f_reg = f_oneway(X1, X2, sigma=sigma, method=method) + f_unreg = f_oneway(X1, X2, sigma=0.0) + pos = f_unreg > 0 + assert_array_less(f_reg[pos], f_unreg[pos]) + + +def test_f_oneway_hat_small_variance(): + """Test that f_oneway hat stabilizes F-values for near-zero variance.""" + rng = np.random.RandomState(0) + X1 = rng.normal(0, 1e-6, (10, 100)) + X2 = rng.normal(1, 1e-6, (10, 100)) + + f_unreg = f_oneway(X1, X2, sigma=0.0) + f_abs = f_oneway(X1, X2, sigma=1e-3, method="absolute") + f_rel = f_oneway(X1, X2, sigma=1e-3, method="relative") + + assert np.median(f_unreg) > 1e6 + assert np.median(f_abs) < np.median(f_unreg) + assert np.median(f_rel) < np.median(f_unreg) + assert np.all(np.isfinite(f_abs)) + assert np.all(np.isfinite(f_rel)) + + +def test_f_oneway_hat_input_validation(): + """Test f_oneway input validation for sigma and method.""" + rng = np.random.RandomState(0) + X1 = rng.randn(5, 10) + X2 = rng.randn(5, 10) + + f_plain = f_oneway(X1, X2, sigma=0.0) + f_scipy = scipy.stats.f_oneway(X1, X2)[0] + assert_allclose(f_plain, f_scipy, rtol=1e-7) + + with pytest.raises(ValueError, match="sigma must be >= 0"): + f_oneway(X1, X2, sigma=-0.1) + + with pytest.raises(ValueError, match="method"): + f_oneway(X1, X2, sigma=1e-3, method="invalid")