Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/dev/13698.other.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add optional low-variance ("hat") regularization to :func:`mne.stats.f_oneway` via new ``sigma`` and ``method`` parameters, by `Aniket Singh Yadav`_.
21 changes: 20 additions & 1 deletion mne/stats/parametric.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def ttest_ind_no_p(a, b, equal_var=True, sigma=0.0):
return t


def f_oneway(*args):
def f_oneway(*args, sigma=0.0, method="relative"):
"""Perform a 1-way ANOVA.

The one-way ANOVA tests the null hypothesis that 2 or more groups have
Expand All @@ -125,6 +125,16 @@ def f_oneway(*args):
----------
*args : array_like
The sample measurements should be given as arguments.
sigma : float
Regularization parameter (>= 0) added to the within-group mean
square to mitigate F-statistic inflation under low-variance
conditions. ``0`` (default) disables correction.
method : str
How *sigma* is interpreted when ``sigma > 0``. Can be
``'relative'`` (default) or ``'absolute'``.
``'relative'`` multiplies *sigma* by the median within-group
mean square (scale-invariant, recommended).
``'absolute'`` uses *sigma* directly as a raw sigma squared.

Returns
-------
Expand All @@ -151,6 +161,9 @@ def f_oneway(*args):
----------
.. footbibliography::
"""
_check_option("method", method, ["absolute", "relative"])
if sigma < 0:
raise ValueError(f"sigma must be >= 0, got {sigma}")
n_classes = len(args)
n_samples_per_class = np.array([len(a) for a in args])
n_samples = np.sum(n_samples_per_class)
Expand All @@ -168,6 +181,12 @@ def f_oneway(*args):
dfwn = n_samples - n_classes
msb = ssbn / float(dfbn)
msw = sswn / float(dfwn)
if sigma > 0.0:
if method == "relative":
sigma_sq = sigma * np.median(msw)
else:
sigma_sq = float(sigma)
msw = (sswn + sigma_sq) / float(dfwn)
f = msb / msw
return f

Expand Down
56 changes: 55 additions & 1 deletion mne/stats/tests/test_parametric.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from numpy.testing import assert_allclose, assert_array_almost_equal, assert_array_less

import mne
from mne.stats.parametric import _map_effects, f_mway_rm, f_threshold_mway_rm
from mne.stats.parametric import _map_effects, f_mway_rm, f_oneway, f_threshold_mway_rm

# hardcoded external test results, manually transferred
test_external = {
Expand Down Expand Up @@ -175,3 +175,57 @@ def theirs(*a, **kw):
# something to the divisor (var)
assert_allclose(got, want, rtol=2e-1, atol=1e-2)
assert_array_less(np.abs(got), np.abs(want))


@pytest.mark.parametrize("sigma", [0.0, 1e-3])
@pytest.mark.parametrize("method", ["absolute", "relative"])
@pytest.mark.parametrize("seed", [0, 42, 1337])
def test_f_oneway_hat(sigma, method, seed):
"""Test f_oneway hat (low-variance) regularization."""
rng = np.random.RandomState(seed)
X1 = rng.randn(10, 50)
X2 = rng.randn(10, 50)

f_ours = f_oneway(X1, X2, sigma=0.0, method=method)
f_scipy = scipy.stats.f_oneway(X1, X2)[0]
assert_allclose(f_ours, f_scipy, rtol=1e-7, atol=1e-6)

if sigma > 0:
f_reg = f_oneway(X1, X2, sigma=sigma, method=method)
f_unreg = f_oneway(X1, X2, sigma=0.0)
pos = f_unreg > 0
assert_array_less(f_reg[pos], f_unreg[pos])


def test_f_oneway_hat_small_variance():
"""Test that f_oneway hat stabilizes F-values for near-zero variance."""
rng = np.random.RandomState(0)
X1 = rng.normal(0, 1e-6, (10, 100))
X2 = rng.normal(1, 1e-6, (10, 100))

f_unreg = f_oneway(X1, X2, sigma=0.0)
f_abs = f_oneway(X1, X2, sigma=1e-3, method="absolute")
f_rel = f_oneway(X1, X2, sigma=1e-3, method="relative")

assert np.median(f_unreg) > 1e6
assert np.median(f_abs) < np.median(f_unreg)
assert np.median(f_rel) < np.median(f_unreg)
assert np.all(np.isfinite(f_abs))
assert np.all(np.isfinite(f_rel))


def test_f_oneway_hat_input_validation():
"""Test f_oneway input validation for sigma and method."""
rng = np.random.RandomState(0)
X1 = rng.randn(5, 10)
X2 = rng.randn(5, 10)

f_plain = f_oneway(X1, X2, sigma=0.0)
f_scipy = scipy.stats.f_oneway(X1, X2)[0]
assert_allclose(f_plain, f_scipy, rtol=1e-7)

with pytest.raises(ValueError, match="sigma must be >= 0"):
f_oneway(X1, X2, sigma=-0.1)

with pytest.raises(ValueError, match="method"):
f_oneway(X1, X2, sigma=1e-3, method="invalid")
Loading