Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api/preprocessing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ Projections:
short_channels
scalp_coupling_index
temporal_derivative_distribution_repair
motion_correct_pca

:py:mod:`mne.preprocessing.ieeg`:

Expand Down
1 change: 1 addition & 0 deletions doc/changes/dev/13691.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add :func:`mne.preprocessing.nirs.motion_correct_pca` (alias ``mne.preprocessing.nirs.pca``) for PCA-based motion correction of fNIRS data, based on Homer3 ``hmrR_MotionCorrectPCA``, by :newcontrib:`Leonardo Zaggia`.
1 change: 1 addition & 0 deletions doc/changes/names.inc
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@
.. _Laurent Le Mentec: https://github.com/LaurentLM
.. _Leonardo Barbosa: https://github.com/noreun
.. _Leonardo Rochael Almeida: https://github.com/leorochael
.. _Leonardo Zaggia: https://github.com/leonardozaggia
.. _Liberty Hamilton: https://github.com/libertyh
.. _Lorenzo Desantis: https://github.com/lorenzo-desantis/
.. _Lukas Breuer: https://www.researchgate.net/profile/Lukas-Breuer-2
Expand Down
22 changes: 22 additions & 0 deletions doc/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,17 @@ @article{HouckClaus2020
year = {2020}
}

@article{HuppertEtAl2009,
author = {Huppert, Theodore J. and Diamond, Solomon G. and Franceschini, Maria A. and Boas, David A.},
doi = {10.1364/AO.48.00D280},
journal = {Applied Optics},
number = {10},
pages = {D280-D298},
title = {{{HomER}}: a review of time-series analysis methods for near-infrared spectroscopy of the brain},
volume = {48},
year = {2009}
}

@article{Hyvarinen1999,
author = {Hyvärinen, Aapo},
doi = {10.1109/72.761722},
Expand Down Expand Up @@ -1254,6 +1265,17 @@ @misc{Mills2016
year = {2016}
}

@article{MolaviDumont2012,
author = {Molavi, Behnam and Dumont, Guy A.},
doi = {10.1088/0967-3334/33/2/259},
journal = {Physiological Measurement},
number = {2},
pages = {259-270},
title = {Wavelet-based motion artifact removal for functional near-infrared spectroscopy},
volume = {33},
year = {2012}
}

@article{MolinsEtAl2008,
author = {Molins A, and Stufflebeam S. M., and Brown E. N., and Hämäläinen M. S.},
doi = {10.1016/j.neuroimage.2008.05.064},
Expand Down
1 change: 1 addition & 0 deletions mne/preprocessing/nirs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@
from ._beer_lambert_law import beer_lambert_law
from ._scalp_coupling_index import scalp_coupling_index
from ._tddr import temporal_derivative_distribution_repair, tddr
from ._pca import motion_correct_pca, pca
183 changes: 183 additions & 0 deletions mne/preprocessing/nirs/_pca.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

# The core logic for this implementation was adapted from the Cedalion project
# (https://github.com/ibs-lab/cedalion), which is originally based on Homer3
# (https://github.com/BUNPC/Homer3).

import numpy as np
from scipy.linalg import svd

from ...io import BaseRaw
from ...utils import _validate_type, verbose
from ..nirs import _validate_nirs_info


@verbose
def motion_correct_pca(raw, tInc, nSV=0.97, *, verbose=None):
    """Apply PCA-based motion correction to fNIRS data.

    Extracts the motion-artifact portions of the signal, performs a singular
    value decomposition across all fNIRS channels, removes the dominant
    principal components, and reinserts the cleaned segments, shifting each
    segment so the time series remains continuous at the segment boundaries.

    Based on Homer3 v1.80.2 ``hmrR_MotionCorrectPCA.m``
    :footcite:`HuppertEtAl2009`.

    Parameters
    ----------
    raw : instance of Raw
        The raw fNIRS data (optical density or hemoglobin).
    tInc : array-like of bool, shape (n_times,)
        Global motion-artifact mask. ``True`` = clean sample,
        ``False`` = motion artifact.
    nSV : float | int
        Number of principal components to remove.

        * ``0 < nSV < 1`` – remove the leading components whose cumulative
          explained variance stays below ``nSV`` (Homer3 semantics; e.g.
          ``0.97`` removes the components jointly explaining up to 97 %% of
          the variance, excluding the component that crosses the threshold).
        * ``nSV >= 1`` – remove exactly ``int(nSV)`` components.

        Default is ``0.97``.
    %(verbose)s

    Returns
    -------
    raw : instance of Raw
        Data with PCA motion correction applied (copy).
    svs : ndarray
        Normalised singular values (sum = 1) of the PCA decomposition.
    nSV : int
        Actual number of principal components removed.

    Notes
    -----
    There is a shorter alias ``mne.preprocessing.nirs.pca`` that
    can be used instead of this function.

    References
    ----------
    .. footbibliography::
    """
    _validate_type(raw, BaseRaw, "raw")
    # Work on a copy so the caller's Raw is never modified in place.
    raw = raw.copy().load_data()
    picks = _validate_nirs_info(raw.info)

    if not len(picks):
        raise RuntimeError(
            "PCA motion correction should be run on optical density or hemoglobin data."
        )

    tInc = np.asarray(tInc, dtype=bool)
    n_times = raw._data.shape[1]

    # (n_picks, n_times)
    y = raw._data[picks, :]

    # Motion-artifact samples only (tInc False), shape (n_motion_samples, n_picks)
    y_motion = y[:, ~tInc].T

    if y_motion.shape[0] == 0:
        raise ValueError(
            "No motion-artifact samples found (tInc is all True). "
            "Mark artifact regions with False before calling this function."
        )

    # Z-score each channel across time within the motion segments so the PCA
    # is not dominated by channels with large baseline amplitude.
    y_mean = y_motion.mean(axis=0)
    y_std = y_motion.std(axis=0)
    y_std[y_std == 0] = 1.0  # avoid divide-by-zero for flat channels
    y_zscore = (y_motion - y_mean) / y_std

    # PCA via SVD of the (symmetric, PSD) channel covariance-like matrix;
    # the left singular vectors V are then its eigenvectors and St its
    # eigenvalues, which are proportional to per-component variance.
    yo = y_zscore.copy()
    c = np.dot(y_zscore.T, y_zscore)
    V, St, _ = svd(c)

    svs = St / np.sum(St)

    # Cumulative explained variance for fractional nSV selection
    svsc = svs.copy()
    for i in range(1, len(svs)):
        svsc[i] = svsc[i - 1] + svs[i]

    if 0 < nSV < 1:
        # Count components with cumulative variance strictly below nSV
        # (first index where svsc >= nSV); matches Homer3's ev(svsc<nSV)=1.
        mask_keep = svsc < nSV
        nSV = int(np.where(~mask_keep)[0][0])
    else:
        nSV = int(nSV)

    # Diagonal selector: 1 for the nSV components to remove, 0 elsewhere
    ev = np.zeros((len(svs), 1))
    ev[:nSV] = 1
    ev = np.diag(np.squeeze(ev))

    # Project out the selected PCs from the motion-artifact segments
    yc = yo - np.dot(np.dot(yo, V), np.dot(ev, V.T))
    yc = (yc * y_std) + y_mean  # undo the z-scoring, back to original scale

    # Identify motion-segment boundaries from transitions in the mask:
    # diff == -1 is clean→motion (start), diff == +1 is motion→clean (end).
    lst_ms = np.where(np.diff(tInc.astype(int)) == -1)[0]  # motion starts
    lst_mf = np.where(np.diff(tInc.astype(int)) == 1)[0]  # motion ends

    # Handle masks that begin or end inside a motion segment, so every
    # start index has a matching end index.
    if len(lst_ms) == 0:
        lst_ms = np.asarray([0])
    if len(lst_mf) == 0:
        lst_mf = np.asarray([n_times - 1])
    if lst_ms[0] > lst_mf[0]:
        lst_ms = np.insert(lst_ms, 0, 0)
    if lst_ms[-1] > lst_mf[-1]:
        lst_mf = np.append(lst_mf, n_times - 1)

    # Cumulative segment lengths: lst_mb[k] is the index of the last row in
    # yc belonging to motion segment k (yc stacks all motion samples).
    lst_mb = lst_mf - lst_ms
    for ii in range(1, len(lst_mb)):
        lst_mb[ii] = lst_mb[ii - 1] + lst_mb[ii]
    lst_mb = lst_mb - 1

    n_picks = len(picks)
    cleaned_ts = y.T.copy()  # (n_times, n_picks)
    orig_ts = y.T.copy()

    # Stitch the corrected motion segments back into the full time series,
    # per channel. Each pasted span is offset so it joins the neighbouring
    # data without a step (segment indexing mirrors Homer3 — note the
    # spans below intentionally overlap adjacent samples by one).
    for jj in range(n_picks):
        # First motion segment: anchor to the preceding clean sample when
        # one exists, otherwise anchor the segment's end instead.
        lst = np.arange(lst_ms[0], lst_mf[0])
        if lst_ms[0] > 0:
            cleaned_ts[lst, jj] = (
                yc[: lst_mb[0] + 1, jj] - yc[0, jj] + cleaned_ts[lst[0], jj]
            )
        else:
            cleaned_ts[lst, jj] = (
                yc[: lst_mb[0] + 1, jj] - yc[lst_mb[0], jj] + cleaned_ts[lst[-1], jj]
            )

        # Intermediate non-motion and motion segments
        for kk in range(len(lst_mf) - 1):
            # Non-motion gap between MA[kk] and MA[kk+1]: shift the original
            # data so it continues from the end of the corrected segment.
            lst = np.arange(lst_mf[kk] - 1, lst_ms[kk + 1] + 1)
            cleaned_ts[lst, jj] = (
                orig_ts[lst, jj] - orig_ts[lst[0], jj] + cleaned_ts[lst[0], jj]
            )

            # Next motion segment: paste its corrected rows from yc, again
            # re-anchored to the sample it starts from.
            lst = np.arange(lst_ms[kk + 1], lst_mf[kk + 1])
            cleaned_ts[lst, jj] = (
                yc[lst_mb[kk] + 1 : lst_mb[kk + 1] + 1, jj]
                - yc[lst_mb[kk] + 1, jj]
                + cleaned_ts[lst[0], jj]
            )

        # Trailing non-motion segment after the last motion artifact
        if lst_mf[-1] < n_times - 1:
            lst = np.arange(lst_mf[-1] - 1, n_times)
            cleaned_ts[lst, jj] = (
                orig_ts[lst, jj] - orig_ts[lst[0], jj] + cleaned_ts[lst[0], jj]
            )

    raw._data[picks, :] = cleaned_ts.T
    return raw, svs, nSV


# Short public alias (mirrors ``tddr`` for
# ``temporal_derivative_distribution_repair`` in this subpackage).
pca = motion_correct_pca
125 changes: 125 additions & 0 deletions mne/preprocessing/nirs/tests/test_pca.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import numpy as np
import pytest
from numpy.testing import assert_allclose

from mne.datasets import testing
from mne.datasets.testing import data_path
from mne.io import read_raw_nirx
from mne.preprocessing.nirs import (
_validate_nirs_info,
motion_correct_pca,
optical_density,
pca,
)

fname_nirx_15_2 = (
data_path(download=False) / "NIRx" / "nirscout" / "nirx_15_2_recording"
)


def _make_tinc(n_times, bad_start, bad_stop):
"""Return a per-time boolean mask (True=clean, False=motion)."""
tInc = np.ones(n_times, dtype=bool)
tInc[bad_start:bad_stop] = False
return tInc


@testing.requires_testing_data
@pytest.mark.parametrize("fname", [fname_nirx_15_2])
def test_motion_correct_pca_removes_shared_artifact(fname):
    """Test PCA correction reduces a correlated motion artefact."""
    raw_od = optical_density(read_raw_nirx(fname))
    picks = _validate_nirs_info(raw_od.info)
    n_samp = raw_od._data.shape[1]

    # Inject a correlated step shift in the middle of the recording
    centre = n_samp // 2
    start, stop = centre - 10, centre + 10

    # Keep the clean trace of one channel for comparison
    clean = raw_od._data[picks[0]].copy()

    amp = 20 * np.abs(np.diff(raw_od._data[picks[0]])).max()
    raw_od._data[picks, start:stop] += amp

    mask = _make_tinc(n_samp, start, stop)
    corrected, svs, n_removed = motion_correct_pca(raw_od, tInc=mask, nSV=0.97)

    # Corrected signal should be closer to original than corrupted signal
    err_corrupt = np.mean((raw_od._data[picks[0]] - clean) ** 2)
    err_fixed = np.mean((corrected._data[picks[0]] - clean) ** 2)
    assert err_fixed < err_corrupt


@testing.requires_testing_data
@pytest.mark.parametrize("fname", [fname_nirx_15_2])
def test_motion_correct_pca_svs_sum_to_one(fname):
    """Test returned singular values are normalised and sum to 1."""
    raw_od = optical_density(read_raw_nirx(fname))
    n_samp = raw_od._data.shape[1]
    half = n_samp // 2

    mask = _make_tinc(n_samp, half - 5, half + 5)
    _, svs, _ = motion_correct_pca(raw_od, tInc=mask, nSV=0.97)

    assert_allclose(svs.sum(), 1.0, rtol=1e-6)


@testing.requires_testing_data
@pytest.mark.parametrize("fname", [fname_nirx_15_2])
def test_motion_correct_pca_integer_nsv(fname):
    """Test integer nSV removes exactly that many components."""
    raw_od = optical_density(read_raw_nirx(fname))
    n_samp = raw_od._data.shape[1]
    third = n_samp // 3

    mask = _make_tinc(n_samp, third, third + 20)
    _, _, n_removed = motion_correct_pca(raw_od, tInc=mask, nSV=2)
    assert n_removed == 2


@testing.requires_testing_data
@pytest.mark.parametrize("fname", [fname_nirx_15_2])
def test_motion_correct_pca_returns_copy(fname):
    """Test PCA correction does not modify the input Raw in place."""
    raw_od = optical_density(read_raw_nirx(fname))
    picks = _validate_nirs_info(raw_od.info)
    before = raw_od._data[picks[0]].copy()

    mask = _make_tinc(raw_od._data.shape[1], 10, 30)
    motion_correct_pca(raw_od, tInc=mask)
    assert_allclose(raw_od._data[picks[0]], before)


@testing.requires_testing_data
@pytest.mark.parametrize("fname", [fname_nirx_15_2])
def test_motion_correct_pca_all_good_raises(fname):
    """Test PCA correction raises when tInc has no artefact samples."""
    raw_od = optical_density(read_raw_nirx(fname))

    all_clean = np.full(raw_od._data.shape[1], True)  # nothing marked bad
    with pytest.raises(ValueError, match="No motion-artifact samples"):
        motion_correct_pca(raw_od, tInc=all_clean)


def test_pca_alias():
    """Test pca is an alias for motion_correct_pca."""
    assert motion_correct_pca is pca


def test_motion_correct_pca_wrong_type():
    """Test passing a non-Raw object raises TypeError."""
    bad_input = np.zeros((10, 100))
    mask = np.full(100, True)
    with pytest.raises(TypeError):
        motion_correct_pca(bad_input, tInc=mask)