Skip to content

Commit 4a44b05

Browse files
ENH: add spline interpolation motion correction for fNIRS (motion_correct_spline)
1 parent 557eca6 commit 4a44b05

6 files changed

Lines changed: 346 additions & 0 deletions

File tree

doc/api/preprocessing.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@ Projections:
137137
short_channels
138138
scalp_coupling_index
139139
temporal_derivative_distribution_repair
140+
motion_correct_spline
141+
spline
140142

141143
:py:mod:`mne.preprocessing.ieeg`:
142144

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add :func:`mne.preprocessing.nirs.motion_correct_spline` (alias :func:`mne.preprocessing.nirs.spline`) for spline interpolation-based motion correction of fNIRS data, based on Homer3 ``hmrR_tInc_baselineshift_Ch_Nirs``, by :newcontrib:`Leonardo Zaggia`.

doc/changes/names.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@
181181
.. _Laurent Le Mentec: https://github.com/LaurentLM
182182
.. _Leonardo Barbosa: https://github.com/noreun
183183
.. _Leonardo Rochael Almeida: https://github.com/leorochael
184+
.. _Leonardo Zaggia: https://github.com/leonardozaggia
184185
.. _Liberty Hamilton: https://github.com/libertyh
185186
.. _Lorenzo Desantis: https://github.com/lorenzo-desantis/
186187
.. _Lukas Breuer: https://www.researchgate.net/profile/Lukas-Breuer-2

mne/preprocessing/nirs/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@
2020
from ._beer_lambert_law import beer_lambert_law
2121
from ._scalp_coupling_index import scalp_coupling_index
2222
from ._tddr import temporal_derivative_distribution_repair, tddr
23+
from ._spline import motion_correct_spline, spline

mne/preprocessing/nirs/_spline.py

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
# Authors: The MNE-Python contributors.
2+
# License: BSD-3-Clause
3+
# Copyright the MNE-Python contributors.
4+
5+
import numpy as np
6+
from scipy.interpolate import UnivariateSpline
7+
8+
from ...io import BaseRaw
9+
from ...utils import _validate_type, verbose
10+
from ..nirs import _validate_nirs_info
11+
12+
13+
def _compute_window(seg_length, dt_short, dt_long, fs):
14+
"""Compute window size for spline baseline correction.
15+
16+
Parameters
17+
----------
18+
seg_length : int
19+
Length of the segment in samples.
20+
dt_short : float
21+
Short time interval in seconds.
22+
dt_long : float
23+
Long time interval in seconds.
24+
fs : float
25+
Sampling frequency in Hz.
26+
27+
Returns
28+
-------
29+
wind : int
30+
Window size in samples (at least 1).
31+
"""
32+
if seg_length < dt_short * fs:
33+
wind = seg_length
34+
elif seg_length < dt_long * fs:
35+
wind = int(np.floor(dt_short * fs))
36+
else:
37+
wind = int(np.floor(seg_length / 10))
38+
return max(int(wind), 1)
39+
40+
41+
@verbose
42+
def motion_correct_spline(raw, p=0.01, tIncCh=None, *, verbose=None):
43+
"""Apply spline interpolation motion correction to fNIRS data.
44+
45+
For each detected motion-artifact segment the signal is detrended with a
46+
smoothing spline, then consecutive segments are baseline-shifted so that
47+
they connect smoothly. Based on Homer3 v1.80.2
48+
``hmrR_tInc_baselineshift_Ch_Nirs.m`` :footcite:`HuppertEtAl2009` and the
49+
cedalion reimplementation.
50+
51+
Parameters
52+
----------
53+
raw : instance of Raw
54+
The raw fNIRS data (optical density or hemoglobin).
55+
p : float
56+
Smoothing parameter for the spline. Smaller values yield a spline
57+
that follows the data more closely. Default is ``0.01``.
58+
tIncCh : array-like of bool, shape (n_picks, n_times) | None
59+
Per-channel motion-artifact mask. ``True`` = clean sample,
60+
``False`` = motion artifact. When ``None`` the entire recording is
61+
treated as a single motion-artifact segment and only spline detrending
62+
is applied (no baseline shifting).
63+
%(verbose)s
64+
65+
Returns
66+
-------
67+
raw : instance of Raw
68+
Data with spline motion correction applied (copy).
69+
70+
Notes
71+
-----
72+
``n_picks`` is the number of fNIRS channels returned by
73+
``_validate_nirs_info``.
74+
75+
There is a shorter alias ``mne.preprocessing.nirs.spline`` that
76+
can be used instead of this function.
77+
78+
References
79+
----------
80+
.. footbibliography::
81+
"""
82+
_validate_type(raw, BaseRaw, "raw")
83+
raw = raw.copy().load_data()
84+
picks = _validate_nirs_info(raw.info)
85+
86+
if not len(picks):
87+
raise RuntimeError(
88+
"Spline motion correction should be run on optical density "
89+
"or hemoglobin data."
90+
)
91+
92+
n_times = raw._data.shape[1]
93+
fs = raw.info["sfreq"]
94+
t = np.arange(n_times) / fs
95+
96+
dt_short = 0.3 # seconds
97+
dt_long = 3.0 # seconds
98+
99+
if tIncCh is None:
100+
tIncCh = np.ones((len(picks), n_times), dtype=bool)
101+
tIncCh = np.asarray(tIncCh, dtype=bool)
102+
103+
for ch_idx, pick in enumerate(picks):
104+
channel = raw._data[pick].copy()
105+
mask = tIncCh[ch_idx] # True = good, False = motion
106+
107+
# Only process if there are actual motion-artifact samples
108+
lst_ma = np.where(~mask)[0]
109+
if len(lst_ma) == 0:
110+
continue
111+
112+
temp = np.diff(mask.astype(int))
113+
lst_ms = np.where(temp == -1)[0] # good→bad (motion start)
114+
lst_mf = np.where(temp == 1)[0] # bad→good (motion end)
115+
116+
if len(lst_ms) == 0:
117+
lst_ms = np.asarray([0])
118+
if len(lst_mf) == 0:
119+
lst_mf = np.asarray([n_times - 1])
120+
if lst_ms[0] > lst_mf[0]:
121+
lst_ms = np.insert(lst_ms, 0, 0)
122+
if lst_ms[-1] > lst_mf[-1]:
123+
lst_mf = np.append(lst_mf, n_times - 1)
124+
125+
nb_ma = len(lst_ms)
126+
lst_ml = lst_mf - lst_ms
127+
128+
dod_spline = channel.copy()
129+
130+
# ---- Step 1: detrend each motion segment with a spline ----
131+
for ii in range(nb_ma):
132+
idx_seg = np.arange(lst_ms[ii], lst_mf[ii])
133+
if len(idx_seg) > 3:
134+
spl = UnivariateSpline(t[idx_seg], channel[idx_seg], s=p * len(idx_seg))
135+
dod_spline[idx_seg] = channel[idx_seg] - spl(t[idx_seg])
136+
137+
# ---- Step 2: baseline-shift segments to align continuity ----
138+
139+
# First MA segment
140+
idx_seg = np.arange(lst_ms[0], lst_mf[0])
141+
if len(idx_seg) > 0:
142+
seg_curr_len = lst_ml[0]
143+
wind_curr = _compute_window(seg_curr_len, dt_short, dt_long, fs)
144+
145+
if lst_ms[0] > 0:
146+
seg_prev_len = lst_ms[0]
147+
wind_prev = _compute_window(seg_prev_len, dt_short, dt_long, fs)
148+
mean_prev = np.mean(
149+
dod_spline[max(0, idx_seg[0] - wind_prev) : idx_seg[0]]
150+
)
151+
mean_curr = np.mean(dod_spline[idx_seg[0] : idx_seg[0] + wind_curr])
152+
dod_spline[idx_seg] = dod_spline[idx_seg] - mean_curr + mean_prev
153+
else:
154+
seg_next_len = (
155+
(lst_ms[1] - lst_mf[0]) if nb_ma > 1 else (n_times - lst_mf[0])
156+
)
157+
wind_next = _compute_window(seg_next_len, dt_short, dt_long, fs)
158+
mean_next = np.mean(dod_spline[idx_seg[-1] : idx_seg[-1] + wind_next])
159+
mean_curr = np.mean(
160+
dod_spline[max(0, idx_seg[-1] - wind_curr) : idx_seg[-1]]
161+
)
162+
dod_spline[idx_seg] = dod_spline[idx_seg] - mean_curr + mean_next
163+
164+
# Intermediate non-MA and MA segments
165+
for kk in range(nb_ma - 1):
166+
# Non-motion segment between MA[kk] and MA[kk+1]
167+
idx_seg = np.arange(lst_mf[kk], lst_ms[kk + 1])
168+
seg_prev_len = lst_ml[kk]
169+
seg_curr_len = len(idx_seg)
170+
171+
wind_prev = _compute_window(seg_prev_len, dt_short, dt_long, fs)
172+
wind_curr = _compute_window(seg_curr_len, dt_short, dt_long, fs)
173+
174+
mean_prev = np.mean(dod_spline[max(0, idx_seg[0] - wind_prev) : idx_seg[0]])
175+
mean_curr = np.mean(channel[idx_seg[0] : idx_seg[0] + wind_curr])
176+
dod_spline[idx_seg] = channel[idx_seg] - mean_curr + mean_prev
177+
178+
# Next MA segment
179+
idx_seg = np.arange(lst_ms[kk + 1], lst_mf[kk + 1])
180+
seg_prev_len = seg_curr_len
181+
seg_curr_len = lst_ml[kk + 1]
182+
183+
wind_prev = _compute_window(seg_prev_len, dt_short, dt_long, fs)
184+
wind_curr = _compute_window(seg_curr_len, dt_short, dt_long, fs)
185+
186+
mean_prev = np.mean(dod_spline[max(0, idx_seg[0] - wind_prev) : idx_seg[0]])
187+
mean_curr = np.mean(dod_spline[idx_seg[0] : idx_seg[0] + wind_curr])
188+
dod_spline[idx_seg] = dod_spline[idx_seg] - mean_curr + mean_prev
189+
190+
# Last non-MA segment (after the final motion segment)
191+
if lst_mf[-1] < n_times:
192+
idx_seg = np.arange(lst_mf[-1], n_times)
193+
seg_prev_len = lst_ml[-1]
194+
seg_curr_len = len(idx_seg)
195+
196+
wind_prev = _compute_window(seg_prev_len, dt_short, dt_long, fs)
197+
wind_curr = _compute_window(seg_curr_len, dt_short, dt_long, fs)
198+
199+
mean_prev = np.mean(dod_spline[max(0, idx_seg[0] - wind_prev) : idx_seg[0]])
200+
mean_curr = np.mean(channel[idx_seg[0] : idx_seg[0] + wind_curr])
201+
dod_spline[idx_seg] = channel[idx_seg] - mean_curr + mean_prev
202+
203+
raw._data[pick] = dod_spline
204+
205+
return raw
206+
207+
208+
# provide a short alias
209+
spline = motion_correct_spline
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# Authors: The MNE-Python contributors.
2+
# License: BSD-3-Clause
3+
# Copyright the MNE-Python contributors.
4+
5+
import numpy as np
6+
import pytest
7+
from numpy.testing import assert_allclose
8+
9+
from mne.datasets import testing
10+
from mne.datasets.testing import data_path
11+
from mne.io import read_raw_nirx
12+
from mne.preprocessing.nirs import (
13+
_validate_nirs_info,
14+
beer_lambert_law,
15+
motion_correct_spline,
16+
optical_density,
17+
spline,
18+
)
19+
20+
fname_nirx_15_2 = (
21+
data_path(download=False) / "NIRx" / "nirscout" / "nirx_15_2_recording"
22+
)
23+
24+
25+
@testing.requires_testing_data
26+
@pytest.mark.parametrize("fname", ([fname_nirx_15_2]))
27+
def test_motion_correct_spline_reduces_step_od(fname):
28+
"""Test spline correction reduces a step artefact in OD data."""
29+
raw = read_raw_nirx(fname)
30+
raw_od = optical_density(raw)
31+
picks = _validate_nirs_info(raw_od.info)
32+
n_times = raw_od._data.shape[1]
33+
34+
# Save clean signal for comparison
35+
original = raw_od._data[0].copy()
36+
37+
# Inject a step artefact in the first 30 samples of channel 0
38+
max_shift = np.max(np.abs(np.diff(raw_od._data[0])))
39+
shift_amp = 20 * max_shift
40+
raw_od._data[0, 0:30] -= shift_amp
41+
42+
# Build per-channel motion mask: first 30 samples are artefact
43+
tIncCh = np.ones((len(picks), n_times), dtype=bool)
44+
tIncCh[:, 0:30] = False
45+
46+
raw_od_corr = motion_correct_spline(raw_od, p=0.01, tIncCh=tIncCh)
47+
48+
# Corrected signal should be closer to original than corrupted signal
49+
mse_before = np.mean((raw_od._data[0] - original) ** 2)
50+
mse_after = np.mean((raw_od_corr._data[0] - original) ** 2)
51+
assert mse_after < mse_before
52+
53+
54+
@testing.requires_testing_data
55+
@pytest.mark.parametrize("fname", ([fname_nirx_15_2]))
56+
def test_motion_correct_spline_reduces_step_hb(fname):
57+
"""Test spline correction works on haemoglobin concentration data."""
58+
raw = read_raw_nirx(fname)
59+
raw_od = optical_density(raw)
60+
raw_hb = beer_lambert_law(raw_od)
61+
picks = _validate_nirs_info(raw_hb.info)
62+
n_times = raw_hb._data.shape[1]
63+
64+
max_shift = np.max(np.diff(raw_hb._data[0]))
65+
shift_amp = 5 * max_shift
66+
raw_hb._data[0, 0:30] -= shift_amp
67+
68+
tIncCh = np.ones((len(picks), n_times), dtype=bool)
69+
tIncCh[:, 0:30] = False
70+
71+
raw_hb_corr = motion_correct_spline(raw_hb, p=0.01, tIncCh=tIncCh)
72+
assert np.max(np.diff(raw_hb_corr._data[0])) < shift_amp
73+
74+
75+
@testing.requires_testing_data
76+
@pytest.mark.parametrize("fname", ([fname_nirx_15_2]))
77+
def test_motion_correct_spline_constant_channels(fname):
78+
"""Test spline correction does not crash on constant channels."""
79+
raw = read_raw_nirx(fname)
80+
raw_od = optical_density(raw)
81+
picks = _validate_nirs_info(raw_od.info)
82+
n_times = raw_od._data.shape[1]
83+
84+
raw_od._data[picks[0]] = 0.0
85+
raw_od._data[picks[1]] = 1.0
86+
87+
tIncCh = np.ones((len(picks), n_times), dtype=bool)
88+
tIncCh[:, 10:20] = False
89+
90+
raw_od_corr = motion_correct_spline(raw_od, p=0.01, tIncCh=tIncCh)
91+
92+
assert_allclose(raw_od_corr._data[picks[0]], 0.0)
93+
assert_allclose(raw_od_corr._data[picks[1]], 1.0)
94+
95+
96+
@testing.requires_testing_data
97+
@pytest.mark.parametrize("fname", ([fname_nirx_15_2]))
98+
def test_motion_correct_spline_returns_copy(fname):
99+
"""Test spline correction does not modify the input Raw in place."""
100+
raw = read_raw_nirx(fname)
101+
raw_od = optical_density(raw)
102+
picks = _validate_nirs_info(raw_od.info)
103+
n_times = raw_od._data.shape[1]
104+
original = raw_od._data[picks[0]].copy()
105+
106+
tIncCh = np.ones((len(picks), n_times), dtype=bool)
107+
tIncCh[0, 10:30] = False
108+
109+
_ = motion_correct_spline(raw_od, p=0.01, tIncCh=tIncCh)
110+
assert_allclose(raw_od._data[picks[0]], original)
111+
112+
113+
@testing.requires_testing_data
114+
@pytest.mark.parametrize("fname", ([fname_nirx_15_2]))
115+
def test_motion_correct_spline_no_artifacts(fname):
116+
"""Test with tIncCh=None the function runs without raising."""
117+
raw = read_raw_nirx(fname)
118+
raw_od = optical_density(raw)
119+
120+
raw_od_corr = motion_correct_spline(raw_od, p=0.01, tIncCh=None)
121+
assert raw_od_corr._data.shape == raw_od._data.shape
122+
123+
124+
def test_spline_alias():
125+
"""Test spline is an alias for motion_correct_spline."""
126+
assert spline is motion_correct_spline
127+
128+
129+
def test_motion_correct_spline_wrong_type():
130+
"""Test passing a non-Raw object raises TypeError."""
131+
with pytest.raises(TypeError):
132+
motion_correct_spline(np.zeros((10, 100)), p=0.01)

0 commit comments

Comments
 (0)