diff --git a/doc/api/preprocessing.rst b/doc/api/preprocessing.rst index f1d60cfdc3b..2e7047eb4b9 100644 --- a/doc/api/preprocessing.rst +++ b/doc/api/preprocessing.rst @@ -137,6 +137,7 @@ Projections: short_channels scalp_coupling_index temporal_derivative_distribution_repair + motion_correct_wavelet :py:mod:`mne.preprocessing.ieeg`: diff --git a/doc/changes/dev/13692.newfeature.rst b/doc/changes/dev/13692.newfeature.rst new file mode 100644 index 00000000000..df228454446 --- /dev/null +++ b/doc/changes/dev/13692.newfeature.rst @@ -0,0 +1 @@ +Add :func:`mne.preprocessing.nirs.motion_correct_wavelet` (alias ``mne.preprocessing.nirs.wavelet``) for wavelet-based motion correction of fNIRS data (spike removal via SWT and IQR thresholding), based on Homer3 ``hmrR_MotionCorrectWavelet`` :footcite:t:`MolaviDumont2012`, by :newcontrib:`Leonardo Zaggia`. diff --git a/doc/changes/names.inc b/doc/changes/names.inc index 09c5818b6af..d4ff059b795 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -182,6 +182,7 @@ .. _Laurent Le Mentec: https://github.com/LaurentLM .. _Leonardo Barbosa: https://github.com/noreun .. _Leonardo Rochael Almeida: https://github.com/leorochael +.. _Leonardo Zaggia: https://github.com/leonardozaggia .. _Liberty Hamilton: https://github.com/libertyh .. _Lorenzo Desantis: https://github.com/lorenzo-desantis/ .. _Lukas Breuer: https://www.researchgate.net/profile/Lukas-Breuer-2 diff --git a/doc/references.bib b/doc/references.bib index 12d63458bd1..cad83673772 100644 --- a/doc/references.bib +++ b/doc/references.bib @@ -866,6 +866,17 @@ @article{HouckClaus2020 year = {2020} } +@article{HuppertEtAl2009, + author = {Huppert, Theodore J. and Diamond, Solomon G. and Franceschini, Maria A. and Boas, David A.}, + doi = {10.1364/AO.48.00D280}, + journal = {Applied Optics}, + number = {10}, + pages = {D280-D298}, + title = {{{HomER}}: a review of time-series analysis methods for near-infrared spectroscopy of the brain}, + volume = {48}, + year = {2009} +} + @article{Hyvarinen1999, author = {Hyvärinen, Aapo}, doi = {10.1109/72.761722}, @@ -1254,6 +1265,17 @@ @misc{Mills2016 year = {2016} } +@article{MolaviDumont2012, + author = {Molavi, Behnam and Dumont, Guy A.}, + doi = {10.1088/0967-3334/33/2/259}, + journal = {Physiological Measurement}, + number = {2}, + pages = {259-270}, + title = {Wavelet-based motion artifact removal for functional near-infrared spectroscopy}, + volume = {33}, + year = {2012} +} + @article{MolinsEtAl2008, author = {Molins A, and Stufflebeam S. M., and Brown E. N., and Hämäläinen M. S.}, doi = {10.1016/j.neuroimage.2008.05.064}, diff --git a/mne/preprocessing/nirs/__init__.py b/mne/preprocessing/nirs/__init__.py index 6e35b59dd15..925d51d5f6a 100644 --- a/mne/preprocessing/nirs/__init__.py +++ b/mne/preprocessing/nirs/__init__.py @@ -20,3 +20,4 @@ from ._beer_lambert_law import beer_lambert_law from ._scalp_coupling_index import scalp_coupling_index from ._tddr import temporal_derivative_distribution_repair, tddr +from ._wavelet import motion_correct_wavelet, wavelet diff --git a/mne/preprocessing/nirs/_wavelet.py b/mne/preprocessing/nirs/_wavelet.py new file mode 100644 index 00000000000..5d50232f20a --- /dev/null +++ b/mne/preprocessing/nirs/_wavelet.py @@ -0,0 +1,248 @@ +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +# The core logic for this implementation was adapted from the Cedalion project +# (https://github.com/ibs-lab/cedalion), which is originally based on Homer3 +# (https://github.com/BUNPC/Homer3). + +import numpy as np + +from ...io import BaseRaw +from ...utils import _validate_type, verbose +from ..nirs import _validate_nirs_info + + +def _pad_to_power_2(signal): + """Pad a 1-D signal to the next power of 2. + + Parameters + ---------- + signal : array-like, shape (n,) + Input signal. + + Returns + ------- + padded : ndarray, shape (2**k,) + Zero-padded signal. + original_length : int + Length of the original signal before padding. + """ + original_length = len(signal) + if original_length <= 1: + n = 1 + else: + n = int(np.ceil(np.log2(original_length))) + padded_length = 2**n + padded = np.zeros(padded_length) + padded[:original_length] = signal + return padded, original_length + + +def _mad(x): + """Compute Median Absolute Deviation.""" + median = np.median(x) + return np.median(np.abs(x - median)) + + +def _normalize_signal(signal, wavelet_name, pywt_module): + """Normalize signal by its noise level using MAD of detail coefficients. + + Implements Homer3's ``NormalizationNoise`` function + :footcite:`HuppertEtAl2009`. + + Parameters + ---------- + signal : ndarray, shape (n,) + Input signal (should already be padded to a power of 2). + wavelet_name : str + Wavelet to use (e.g. ``'db2'``). + pywt_module : module + The PyWavelets module. + + Returns + ------- + normalized_signal : ndarray, shape (n,) + Normalized version of the input signal. + norm_coef : float + Multiply ``normalized_signal`` by ``1 / norm_coef`` to recover the + original scale. + """ + wvlt = pywt_module.Wavelet(wavelet_name) + # Homer3 uses qmf(db2, 4) which produces the HIGH-pass decomposition + # filter from the scaling (low-pass) filter. In PyWavelets this is dec_hi. + qmf = np.array(wvlt.dec_hi) + + # Circular convolution matching MATLAB's cconv(signal, qmf, len(signal)) + n = len(signal) + c = np.real(np.fft.ifft(np.fft.fft(signal, n) * np.fft.fft(qmf, n))) + + # Downsample by 2 — first-level detail coefficients for noise estimation + y_ds = c[::2] + + median_abs_dev = _mad(y_ds) + + if median_abs_dev != 0: + norm_coef = 1.0 / (1.4826 * median_abs_dev) + normalized_signal = signal * norm_coef + else: + norm_coef = 1.0 + normalized_signal = signal.copy() + + return normalized_signal, norm_coef + + +def _process_wavelet_coefficients(coeffs, iqr_factor, signal_length): + """Zero out outlier wavelet coefficients using IQR thresholding. + + Parameters + ---------- + coeffs : ndarray, shape (n_padded, n_levels + 1) + Stacked wavelet coefficient array (first column = approx, rest = + detail per level). + iqr_factor : float + Interquartile-range multiplier for the outlier threshold. + signal_length : int + Original (unpadded) signal length used to compute per-block valid + lengths. + + Returns + ------- + coeffs : ndarray + Coefficient array with outliers zeroed out. + """ + n = coeffs.shape[0] + n_levels = coeffs.shape[1] - 1 + + for j in range(n_levels): + curr_length = signal_length // (2**j) if j > 0 else signal_length + n_blocks = 2**j + block_length = n // n_blocks + + for b in range(n_blocks): + start_idx = b * block_length + end_idx = start_idx + block_length + coeff_block = coeffs[start_idx:end_idx, j + 1] + + valid_coeffs = coeff_block[:curr_length] + q25, q75 = np.percentile(valid_coeffs, [25, 75]) + iqr_val = q75 - q25 + + upper = q75 + iqr_factor * iqr_val + lower = q25 - iqr_factor * iqr_val + + coeffs[start_idx:end_idx, j + 1] = np.where( + (coeff_block > upper) | (coeff_block < lower), + 0, + coeff_block, + ) + + return coeffs + + +@verbose +def motion_correct_wavelet(raw, iqr=1.5, wavelet="db2", level=4, *, verbose=None): + """Apply wavelet-based motion correction to fNIRS data. + + Decomposes each channel with a stationary wavelet transform (SWT), zeroes + out detail coefficients that are statistical outliers (IQR-based), and + reconstructs the corrected signal. Specialises in spike removal. + + Based on Homer3 v1.80.2 ``hmrR_MotionCorrectWavelet.m`` + :footcite:`HuppertEtAl2009` and the approach described in + :footcite:`MolaviDumont2012`. + + Parameters + ---------- + raw : instance of Raw + The raw fNIRS data (optical density or hemoglobin). + iqr : float + Interquartile-range multiplier used as the outlier threshold for + wavelet coefficients. Larger values remove fewer coefficients. Set + to ``-1`` to disable thresholding entirely. Default is ``1.5``. + wavelet : str + Mother wavelet name recognised by PyWavelets (e.g. ``'db2'``). + Default is ``'db2'``. + level : int + Number of decomposition levels for the SWT. Default is ``4``. + %(verbose)s + + Returns + ------- + raw : instance of Raw + Data with wavelet motion correction applied (copy). + + Notes + ----- + Requires the ``PyWavelets`` package (``pip install PyWavelets``). + + There is a shorter alias ``mne.preprocessing.nirs.wavelet`` + that can be used instead of this function. + + References + ---------- + .. footbibliography:: + """ + try: + import pywt + except ImportError as exc: + raise ImportError( + "PyWavelets is required for wavelet motion correction. " + "Install it with: pip install PyWavelets" + ) from exc + + _validate_type(raw, BaseRaw, "raw") + raw = raw.copy().load_data() + picks = _validate_nirs_info(raw.info) + + if not len(picks): + raise RuntimeError( + "Wavelet motion correction should be run on optical density " + "or hemoglobin data." + ) + + if iqr < 0: + return raw + + for pick in picks: + signal = raw._data[pick].copy() + + # Pad to power of 2 (required by SWT) + padded_signal, original_length = _pad_to_power_2(signal) + + # Remove DC component + dc_val = np.mean(padded_signal) + padded_signal -= dc_val + + # Normalise by estimated noise level + normalized_signal, norm_coef = _normalize_signal(padded_signal, wavelet, pywt) + + # Stationary wavelet transform + n_log2 = int(np.log2(len(normalized_signal))) + actual_level = min(level, n_log2 - 1) + coeffs = pywt.swt(normalized_signal, wavelet, level=actual_level) + + # Stack into a 2-D array: col 0 = approx, cols 1..L = detail levels + coeffs_array = np.column_stack([coeffs[0][0]] + [c[1] for c in coeffs]) + + # Threshold outlier coefficients + coeffs_array = _process_wavelet_coefficients(coeffs_array, iqr, original_length) + + # Rebuild list of (approx, detail) tuples for iswt + coeffs_list = [ + (coeffs_array[:, 0], coeffs_array[:, i]) + for i in range(1, coeffs_array.shape[1]) + ] + + # Reconstruct, denormalise, restore DC and trim to original length + corrected = pywt.iswt(coeffs_list, wavelet) + corrected = corrected / norm_coef + corrected = corrected[:original_length] + dc_val + + raw._data[pick] = corrected + + return raw + + +# provide a short alias +wavelet = motion_correct_wavelet diff --git a/mne/preprocessing/nirs/tests/test_wavelet.py b/mne/preprocessing/nirs/tests/test_wavelet.py new file mode 100644 index 00000000000..2a8d52147de --- /dev/null +++ b/mne/preprocessing/nirs/tests/test_wavelet.py @@ -0,0 +1,112 @@ +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +import numpy as np +import pytest +from numpy.testing import assert_allclose + +from mne.datasets import testing +from mne.datasets.testing import data_path +from mne.io import read_raw_nirx +from mne.preprocessing.nirs import ( + _validate_nirs_info, + beer_lambert_law, + motion_correct_wavelet, + optical_density, + wavelet, +) + +fname_nirx_15_2 = ( + data_path(download=False) / "NIRx" / "nirscout" / "nirx_15_2_recording" +) + + +@testing.requires_testing_data +@pytest.mark.parametrize("fname", ([fname_nirx_15_2])) +def test_motion_correct_wavelet_reduces_spikes_od(fname): + """Test wavelet correction attenuates spike artefacts in OD data.""" + pytest.importorskip("pywt") + + raw = read_raw_nirx(fname) + raw_od = optical_density(raw) + picks = _validate_nirs_info(raw_od.info) + + sig = raw_od._data[picks[0]] + spike_amp = 20 * np.std(np.diff(sig)) + + # Inject isolated single-sample spikes + n_times = raw_od._data.shape[1] + raw_od._data[picks[0], n_times // 4] += spike_amp + raw_od._data[picks[0], n_times // 2] -= spike_amp + + spike_before = np.max(np.abs(np.diff(raw_od._data[picks[0]]))) + + raw_od_corr = motion_correct_wavelet(raw_od, iqr=1.5) + + spike_after = np.max(np.abs(np.diff(raw_od_corr._data[picks[0]]))) + assert spike_after < spike_before + + +@testing.requires_testing_data +@pytest.mark.parametrize("fname", ([fname_nirx_15_2])) +def test_motion_correct_wavelet_reduces_spikes_hb(fname): + """Test wavelet correction works on haemoglobin concentration data.""" + pytest.importorskip("pywt") + + raw = read_raw_nirx(fname) + raw_od = optical_density(raw) + raw_hb = beer_lambert_law(raw_od) + picks = _validate_nirs_info(raw_hb.info) + + spike_amp = 20 * np.std(np.diff(raw_hb._data[picks[0]])) + n_times = raw_hb._data.shape[1] + raw_hb._data[picks[0], n_times // 4] += spike_amp + raw_hb._data[picks[0], n_times // 2] -= spike_amp + + spike_before = np.max(np.abs(np.diff(raw_hb._data[picks[0]]))) + raw_hb_corr = motion_correct_wavelet(raw_hb, iqr=1.5) + spike_after = np.max(np.abs(np.diff(raw_hb_corr._data[picks[0]]))) + assert spike_after < spike_before + + +@testing.requires_testing_data +@pytest.mark.parametrize("fname", ([fname_nirx_15_2])) +def test_motion_correct_wavelet_negative_iqr_passthrough(fname): + """Test iqr < 0 returns data unchanged.""" + pytest.importorskip("pywt") + + raw = read_raw_nirx(fname) + raw_od = optical_density(raw) + picks = _validate_nirs_info(raw_od.info) + original = raw_od._data[picks[0]].copy() + + raw_od_corr = motion_correct_wavelet(raw_od, iqr=-1) + assert_allclose(raw_od_corr._data[picks[0]], original) + + +@testing.requires_testing_data +@pytest.mark.parametrize("fname", ([fname_nirx_15_2])) +def test_motion_correct_wavelet_returns_copy(fname): + """Test wavelet correction does not modify the input Raw in place.""" + pytest.importorskip("pywt") + + raw = read_raw_nirx(fname) + raw_od = optical_density(raw) + picks = _validate_nirs_info(raw_od.info) + original = raw_od._data[picks[0]].copy() + + _ = motion_correct_wavelet(raw_od) + assert_allclose(raw_od._data[picks[0]], original) + + +def test_wavelet_alias(): + """Test wavelet is an alias for motion_correct_wavelet.""" + assert wavelet is motion_correct_wavelet + + +def test_motion_correct_wavelet_wrong_type(): + """Test passing a non-Raw object raises TypeError.""" + pytest.importorskip("pywt") + with pytest.raises(TypeError): + motion_correct_wavelet(np.zeros((10, 100)))