|
| 1 | +# Authors: The MNE-Python contributors. |
| 2 | +# License: BSD-3-Clause |
| 3 | +# Copyright the MNE-Python contributors. |
| 4 | + |
| 5 | +import numpy as np |
| 6 | +from scipy.interpolate import UnivariateSpline |
| 7 | + |
| 8 | +from ...io import BaseRaw |
| 9 | +from ...utils import _validate_type, verbose |
| 10 | +from ..nirs import _validate_nirs_info |
| 11 | + |
| 12 | + |
| 13 | +def _compute_window(seg_length, dt_short, dt_long, fs): |
| 14 | + """Compute window size for spline baseline correction. |
| 15 | +
|
| 16 | + Parameters |
| 17 | + ---------- |
| 18 | + seg_length : int |
| 19 | + Length of the segment in samples. |
| 20 | + dt_short : float |
| 21 | + Short time interval in seconds. |
| 22 | + dt_long : float |
| 23 | + Long time interval in seconds. |
| 24 | + fs : float |
| 25 | + Sampling frequency in Hz. |
| 26 | +
|
| 27 | + Returns |
| 28 | + ------- |
| 29 | + wind : int |
| 30 | + Window size in samples (at least 1). |
| 31 | + """ |
| 32 | + if seg_length < dt_short * fs: |
| 33 | + wind = seg_length |
| 34 | + elif seg_length < dt_long * fs: |
| 35 | + wind = int(np.floor(dt_short * fs)) |
| 36 | + else: |
| 37 | + wind = int(np.floor(seg_length / 10)) |
| 38 | + return max(int(wind), 1) |
| 39 | + |
| 40 | + |
| 41 | +@verbose |
| 42 | +def motion_correct_spline(raw, p=0.01, tIncCh=None, *, verbose=None): |
| 43 | + """Apply spline interpolation motion correction to fNIRS data. |
| 44 | +
|
| 45 | + For each detected motion-artifact segment the signal is detrended with a |
| 46 | + smoothing spline, then consecutive segments are baseline-shifted so that |
| 47 | + they connect smoothly. Based on Homer3 v1.80.2 |
| 48 | + ``hmrR_tInc_baselineshift_Ch_Nirs.m`` :footcite:`HuppertEtAl2009` and the |
| 49 | + cedalion reimplementation. |
| 50 | +
|
| 51 | + Parameters |
| 52 | + ---------- |
| 53 | + raw : instance of Raw |
| 54 | + The raw fNIRS data (optical density or hemoglobin). |
| 55 | + p : float |
| 56 | + Smoothing parameter for the spline. Smaller values yield a spline |
| 57 | + that follows the data more closely. Default is ``0.01``. |
| 58 | + tIncCh : array-like of bool, shape (n_picks, n_times) | None |
| 59 | + Per-channel motion-artifact mask. ``True`` = clean sample, |
| 60 | + ``False`` = motion artifact. When ``None`` the entire recording is |
| 61 | + treated as a single motion-artifact segment and only spline detrending |
| 62 | + is applied (no baseline shifting). |
| 63 | + %(verbose)s |
| 64 | +
|
| 65 | + Returns |
| 66 | + ------- |
| 67 | + raw : instance of Raw |
| 68 | + Data with spline motion correction applied (copy). |
| 69 | +
|
| 70 | + Notes |
| 71 | + ----- |
| 72 | + ``n_picks`` is the number of fNIRS channels returned by |
| 73 | + ``_validate_nirs_info``. |
| 74 | +
|
| 75 | + There is a shorter alias ``mne.preprocessing.nirs.spline`` that |
| 76 | + can be used instead of this function. |
| 77 | +
|
| 78 | + References |
| 79 | + ---------- |
| 80 | + .. footbibliography:: |
| 81 | + """ |
| 82 | + _validate_type(raw, BaseRaw, "raw") |
| 83 | + raw = raw.copy().load_data() |
| 84 | + picks = _validate_nirs_info(raw.info) |
| 85 | + |
| 86 | + if not len(picks): |
| 87 | + raise RuntimeError( |
| 88 | + "Spline motion correction should be run on optical density " |
| 89 | + "or hemoglobin data." |
| 90 | + ) |
| 91 | + |
| 92 | + n_times = raw._data.shape[1] |
| 93 | + fs = raw.info["sfreq"] |
| 94 | + t = np.arange(n_times) / fs |
| 95 | + |
| 96 | + dt_short = 0.3 # seconds |
| 97 | + dt_long = 3.0 # seconds |
| 98 | + |
| 99 | + if tIncCh is None: |
| 100 | + tIncCh = np.ones((len(picks), n_times), dtype=bool) |
| 101 | + tIncCh = np.asarray(tIncCh, dtype=bool) |
| 102 | + |
| 103 | + for ch_idx, pick in enumerate(picks): |
| 104 | + channel = raw._data[pick].copy() |
| 105 | + mask = tIncCh[ch_idx] # True = good, False = motion |
| 106 | + |
| 107 | + # Only process if there are actual motion-artifact samples |
| 108 | + lst_ma = np.where(~mask)[0] |
| 109 | + if len(lst_ma) == 0: |
| 110 | + continue |
| 111 | + |
| 112 | + temp = np.diff(mask.astype(int)) |
| 113 | + lst_ms = np.where(temp == -1)[0] # good→bad (motion start) |
| 114 | + lst_mf = np.where(temp == 1)[0] # bad→good (motion end) |
| 115 | + |
| 116 | + if len(lst_ms) == 0: |
| 117 | + lst_ms = np.asarray([0]) |
| 118 | + if len(lst_mf) == 0: |
| 119 | + lst_mf = np.asarray([n_times - 1]) |
| 120 | + if lst_ms[0] > lst_mf[0]: |
| 121 | + lst_ms = np.insert(lst_ms, 0, 0) |
| 122 | + if lst_ms[-1] > lst_mf[-1]: |
| 123 | + lst_mf = np.append(lst_mf, n_times - 1) |
| 124 | + |
| 125 | + nb_ma = len(lst_ms) |
| 126 | + lst_ml = lst_mf - lst_ms |
| 127 | + |
| 128 | + dod_spline = channel.copy() |
| 129 | + |
| 130 | + # ---- Step 1: detrend each motion segment with a spline ---- |
| 131 | + for ii in range(nb_ma): |
| 132 | + idx_seg = np.arange(lst_ms[ii], lst_mf[ii]) |
| 133 | + if len(idx_seg) > 3: |
| 134 | + spl = UnivariateSpline(t[idx_seg], channel[idx_seg], s=p * len(idx_seg)) |
| 135 | + dod_spline[idx_seg] = channel[idx_seg] - spl(t[idx_seg]) |
| 136 | + |
| 137 | + # ---- Step 2: baseline-shift segments to align continuity ---- |
| 138 | + |
| 139 | + # First MA segment |
| 140 | + idx_seg = np.arange(lst_ms[0], lst_mf[0]) |
| 141 | + if len(idx_seg) > 0: |
| 142 | + seg_curr_len = lst_ml[0] |
| 143 | + wind_curr = _compute_window(seg_curr_len, dt_short, dt_long, fs) |
| 144 | + |
| 145 | + if lst_ms[0] > 0: |
| 146 | + seg_prev_len = lst_ms[0] |
| 147 | + wind_prev = _compute_window(seg_prev_len, dt_short, dt_long, fs) |
| 148 | + mean_prev = np.mean( |
| 149 | + dod_spline[max(0, idx_seg[0] - wind_prev) : idx_seg[0]] |
| 150 | + ) |
| 151 | + mean_curr = np.mean(dod_spline[idx_seg[0] : idx_seg[0] + wind_curr]) |
| 152 | + dod_spline[idx_seg] = dod_spline[idx_seg] - mean_curr + mean_prev |
| 153 | + else: |
| 154 | + seg_next_len = ( |
| 155 | + (lst_ms[1] - lst_mf[0]) if nb_ma > 1 else (n_times - lst_mf[0]) |
| 156 | + ) |
| 157 | + wind_next = _compute_window(seg_next_len, dt_short, dt_long, fs) |
| 158 | + mean_next = np.mean(dod_spline[idx_seg[-1] : idx_seg[-1] + wind_next]) |
| 159 | + mean_curr = np.mean( |
| 160 | + dod_spline[max(0, idx_seg[-1] - wind_curr) : idx_seg[-1]] |
| 161 | + ) |
| 162 | + dod_spline[idx_seg] = dod_spline[idx_seg] - mean_curr + mean_next |
| 163 | + |
| 164 | + # Intermediate non-MA and MA segments |
| 165 | + for kk in range(nb_ma - 1): |
| 166 | + # Non-motion segment between MA[kk] and MA[kk+1] |
| 167 | + idx_seg = np.arange(lst_mf[kk], lst_ms[kk + 1]) |
| 168 | + seg_prev_len = lst_ml[kk] |
| 169 | + seg_curr_len = len(idx_seg) |
| 170 | + |
| 171 | + wind_prev = _compute_window(seg_prev_len, dt_short, dt_long, fs) |
| 172 | + wind_curr = _compute_window(seg_curr_len, dt_short, dt_long, fs) |
| 173 | + |
| 174 | + mean_prev = np.mean(dod_spline[max(0, idx_seg[0] - wind_prev) : idx_seg[0]]) |
| 175 | + mean_curr = np.mean(channel[idx_seg[0] : idx_seg[0] + wind_curr]) |
| 176 | + dod_spline[idx_seg] = channel[idx_seg] - mean_curr + mean_prev |
| 177 | + |
| 178 | + # Next MA segment |
| 179 | + idx_seg = np.arange(lst_ms[kk + 1], lst_mf[kk + 1]) |
| 180 | + seg_prev_len = seg_curr_len |
| 181 | + seg_curr_len = lst_ml[kk + 1] |
| 182 | + |
| 183 | + wind_prev = _compute_window(seg_prev_len, dt_short, dt_long, fs) |
| 184 | + wind_curr = _compute_window(seg_curr_len, dt_short, dt_long, fs) |
| 185 | + |
| 186 | + mean_prev = np.mean(dod_spline[max(0, idx_seg[0] - wind_prev) : idx_seg[0]]) |
| 187 | + mean_curr = np.mean(dod_spline[idx_seg[0] : idx_seg[0] + wind_curr]) |
| 188 | + dod_spline[idx_seg] = dod_spline[idx_seg] - mean_curr + mean_prev |
| 189 | + |
| 190 | + # Last non-MA segment (after the final motion segment) |
| 191 | + if lst_mf[-1] < n_times: |
| 192 | + idx_seg = np.arange(lst_mf[-1], n_times) |
| 193 | + seg_prev_len = lst_ml[-1] |
| 194 | + seg_curr_len = len(idx_seg) |
| 195 | + |
| 196 | + wind_prev = _compute_window(seg_prev_len, dt_short, dt_long, fs) |
| 197 | + wind_curr = _compute_window(seg_curr_len, dt_short, dt_long, fs) |
| 198 | + |
| 199 | + mean_prev = np.mean(dod_spline[max(0, idx_seg[0] - wind_prev) : idx_seg[0]]) |
| 200 | + mean_curr = np.mean(channel[idx_seg[0] : idx_seg[0] + wind_curr]) |
| 201 | + dod_spline[idx_seg] = channel[idx_seg] - mean_curr + mean_prev |
| 202 | + |
| 203 | + raw._data[pick] = dod_spline |
| 204 | + |
| 205 | + return raw |
| 206 | + |
| 207 | + |
| 208 | +# provide a short alias |
| 209 | +spline = motion_correct_spline |
0 commit comments