diff --git a/doc/changes/dev/13885.bugfix.rst b/doc/changes/dev/13885.bugfix.rst new file mode 100644 index 00000000000..828da0d5673 --- /dev/null +++ b/doc/changes/dev/13885.bugfix.rst @@ -0,0 +1 @@ +Fix bug with :meth:`mne.preprocessing.ICA.plot_properties` when using ``reject`` in :meth:`mne.preprocessing.ICA.fit`, by `Eric Larson`_. diff --git a/examples/preprocessing/find_ref_artifacts.py b/examples/preprocessing/find_ref_artifacts.py index 90e3d1fb0da..d03701e1e5a 100644 --- a/examples/preprocessing/find_ref_artifacts.py +++ b/examples/preprocessing/find_ref_artifacts.py @@ -24,7 +24,6 @@ on the reference channels are removed. This technique is fully described and validated in :footcite:`HannaEtAl2020` - """ # Authors: Jeff Hanna # @@ -78,6 +77,7 @@ ica_kwargs = dict( method="picard", fit_params=dict(tol=1e-4), # use a high tol here for speed + random_state=99, ) all_picks = mne.pick_types(raw_tog.info, meg=True, ref_meg=True) ica_tog = ICA(n_components=60, max_iter="auto", allow_ref_meg=True, **ica_kwargs) diff --git a/examples/preprocessing/muscle_ica.py b/examples/preprocessing/muscle_ica.py index f61d1e22bc4..0395503defd 100644 --- a/examples/preprocessing/muscle_ica.py +++ b/examples/preprocessing/muscle_ica.py @@ -48,9 +48,7 @@ # %% # By inspection, let's select out the muscle-artifact components based on -# :footcite:`DharmapraniEtAl2016` manually. -# -# The criteria are: +# :footcite:`DharmapraniEtAl2016` manually. The criteria are: # # - Positive slope of log-log power spectrum between 7 and 75 Hz # (here just flat because it's not in log-log) diff --git a/mne/_fiff/pick.py b/mne/_fiff/pick.py index e007100dae1..1b24fd27509 100644 --- a/mne/_fiff/pick.py +++ b/mne/_fiff/pick.py @@ -1238,7 +1238,7 @@ def _picks_to_idx( extra_repr = ", treated as range({n_chan})" else: picks = none # let _picks_str_to_idx handle it - extra_repr = f'None, treated as "{none}"' + extra_repr = f', treated as "{none}"' # # slice diff --git a/mne/viz/ica.py b/mne/viz/ica.py index b1458896e69..08d1b81bd7e 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -15,7 +15,6 @@ from .._fiff.proj import _has_eeg_average_ref_proj from ..defaults import DEFAULTS, _handle_default from ..utils import ( - _reject_data_segments, _validate_type, fill_doc, verbose, @@ -202,13 +201,10 @@ def _create_properties_layout(figsize=None, fig=None): def _plot_ica_properties( pick, ica, - inst, psds_mean, freqs, - n_trials, - epoch_var, plot_lowpass_edge, - epochs_src, + this_epochs_src, set_title_and_labels, plot_std, psd_ylabel, @@ -219,7 +215,7 @@ def _plot_ica_properties( fig, axes, kind, - dropped_indices, + bad_indices, ): """Plot ICA properties (helper).""" from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable @@ -237,23 +233,15 @@ def _plot_ica_properties( ) # image and erp - # we create a new epoch with dropped rows - epoch_data = epochs_src.get_data(copy=False) - epoch_data = np.insert( - arr=epoch_data, - obj=(dropped_indices - np.arange(len(dropped_indices))).astype(int), - values=0.0, - axis=0, - ) - from ..epochs import EpochsArray - - epochs_src = EpochsArray( - epoch_data, epochs_src.info, tmin=epochs_src.tmin, verbose=0 - ) - + n_trials = len(this_epochs_src) + epoch_var = np.var(this_epochs_src.get_data(), axis=-1) + assert epoch_var.shape[1] == 1 # single channel + epoch_var = epoch_var[:, 0] + assert epoch_var.shape == (len(this_epochs_src),) + this_epochs_src._data[bad_indices] = 0 plot_epochs_image( - epochs_src, - picks=pick, + this_epochs_src, + picks=[0], axes=[image_ax, erp_ax], combine=None, colorbar=False, @@ -273,44 +261,41 @@ def _plot_ica_properties( ) if plot_lowpass_edge: spec_ax.axvline( - inst.info["lowpass"], lw=2, linestyle="--", color="k", alpha=0.2 + this_epochs_src.info["lowpass"], lw=2, linestyle="--", color="k", alpha=0.2 ) # epoch variance + good_indices = np.setdiff1d(np.arange(n_trials), bad_indices) var_ax_divider = make_axes_locatable(var_ax) - hist_ax = var_ax_divider.append_axes("right", size="33%", pad="2.5%") - var_ax.scatter( - range(len(epoch_var)), epoch_var, alpha=0.5, facecolor=[0, 0, 0], lw=0 - ) + hist_ax = var_ax_divider.append_axes("right", size="33%", pad="2.5%", sharey=var_ax) + facecolor = np.zeros((len(epoch_var), 3)) + alpha = np.full(len(epoch_var), 0.5) # rejected epochs in red + facecolor[bad_indices] = [1, 0, 0] + alpha[bad_indices] = 0.75 var_ax.scatter( - dropped_indices, - epoch_var[dropped_indices], - alpha=1.0, - facecolor=[1, 0, 0], - lw=0, + np.arange(n_trials), epoch_var, alpha=alpha, facecolor=facecolor, lw=0 ) # compute percentage of dropped epochs - var_percent = float(len(dropped_indices)) / float(len(epoch_var)) * 100.0 + var_percent = 100 * len(bad_indices) / n_trials # histogram & histogram + epoch_var_good = epoch_var[good_indices] _, counts, _ = hist_ax.hist( - epoch_var, orientation="horizontal", color="k", alpha=0.5 + epoch_var_good, orientation="horizontal", color="k", alpha=0.5 ) # kde - ymin, ymax = hist_ax.get_ylim() try: - kde = gaussian_kde(epoch_var) + kde = gaussian_kde(epoch_var_good) except np.linalg.LinAlgError: pass # singular: happens when there is nothing plotted else: - x = np.linspace(ymin, ymax, 50) + x = np.linspace(epoch_var_good.min(), epoch_var_good.max(), 50) kde_ = kde(x) kde_ /= kde_.max() or 1.0 kde_ *= hist_ax.get_xlim()[-1] * 0.9 hist_ax.plot(kde_, x, color="k") - hist_ax.set_ylim(ymin, ymax) # aesthetics # ---------- @@ -319,16 +304,16 @@ def _plot_ica_properties( # erp set_title_and_labels(erp_ax, [], "Time (s)", "AU") erp_ax.spines["right"].set_color("k") - erp_ax.set_xlim(epochs_src.times[[0, -1]]) + erp_ax.set_xlim(this_epochs_src.times[[0, -1]]) # remove half of yticks if more than 5 yt = erp_ax.get_yticks() if len(yt) > 5: - erp_ax.yaxis.set_ticks(yt[::2]) + erp_ax.set_yticks(yt[::2]) # remove xticks - erp plot shows xticks for both image and erp plot - image_ax.xaxis.set_ticks([]) + image_ax.set_xticks([]) yt = image_ax.get_yticks() - image_ax.yaxis.set_ticks(yt[1:]) + image_ax.set_yticks(yt[1:]) image_ax.set_ylim([-0.5, n_trials + 0.5]) def _set_scale(ax, scale): @@ -342,10 +327,6 @@ def _set_scale(ax, scale): set_title_and_labels(spec_ax, "Spectrum", "Frequency (Hz)", psd_ylabel) spec_ax.yaxis.labelpad = 0 spec_ax.set_xlim(freqs[[0, -1]]) - ylim = spec_ax.get_ylim() - air = np.diff(ylim)[0] * 0.1 - spec_ax.set_ylim(ylim[0] - air, ylim[1] + air) - image_ax.axhline(0, color="k", linewidth=0.5) if log_scale: _set_scale(spec_ax, "log") @@ -603,24 +584,24 @@ def _fast_plot_ica_properties( # calculations # ------------ if isinstance(precomputed_data, tuple): - kind, dropped_indices, epochs_src, data = precomputed_data + kind, bad_indices, epochs_src = precomputed_data else: - kind, dropped_indices, epochs_src, data = _prepare_data_ica_properties( + kind, bad_indices, epochs_src = _prepare_data_ica_properties( inst, ica, reject_by_annotation, reject ) - del reject - ica_data = np.swapaxes(data[:, picks, :], 0, 1) - dropped_src = ica_data + del reject, inst + epochs_src_picked = epochs_src.pick(picks) + del epochs_src + good_indices = np.setdiff1d(np.arange(len(epochs_src_picked)), bad_indices) # spectrum - Nyquist = inst.info["sfreq"] / 2.0 - lp = inst.info["lowpass"] + Nyquist = epochs_src_picked.info["sfreq"] / 2.0 + lp = epochs_src_picked.info["lowpass"] if "fmax" not in psd_args: psd_args["fmax"] = min(lp * 1.25, Nyquist) plot_lowpass_edge = lp < Nyquist and (psd_args["fmax"] > lp) - spectrum = epochs_src.compute_psd(picks=picks, **psd_args) - # we've already restricted picks ↑↑↑↑↑↑↑↑↑↑↑ - # in the spectrum object, so here we do picks=all ↓↓↓↓↓↓↓↓↓↓↓ + # we've already restricted picks in epochs_src_picked, so here we do picks=all + spectrum = epochs_src_picked[good_indices].compute_psd(picks="all", **psd_args) psds, freqs = spectrum.get_data(return_freqs=True, picks="all", exclude=[]) # we also pass exclude=[] so that when this is called by right-clicking in # a plot_sources() window on an ICA component name that has been marked as @@ -654,30 +635,14 @@ def set_title_and_labels(ax, title, xlab, ylab): if idx > 0: fig, axes = _create_properties_layout(figsize=figsize) - # we reconstruct an epoch_variance with 0 where indexes where dropped - epoch_var = np.var(ica_data[idx], axis=1) - drop_var = np.var(dropped_src[idx], axis=1) - drop_indices_corrected = ( - dropped_indices - np.arange(len(dropped_indices)) - ).astype(int) - epoch_var = np.insert( - arr=epoch_var, - obj=drop_indices_corrected, - values=drop_var[dropped_indices], - axis=0, - ) - # the actual plot fig = _plot_ica_properties( pick, ica, - inst, psds_mean, freqs, - ica_data.shape[1], - epoch_var, plot_lowpass_edge, - epochs_src, + epochs_src_picked.copy().pick(picks=[idx]), set_title_and_labels, plot_std, psd_ylabel, @@ -688,7 +653,7 @@ def set_title_and_labels(ax, title, xlab, ylab): fig, axes, kind, - dropped_indices, + bad_indices, ) all_fig.append(fig) @@ -721,65 +686,76 @@ def _prepare_data_ica_properties(inst, ica, reject_by_annotation=True, reject="a data : array of shape (n_epochs, n_ica_sources, n_times) A view on epochs ICA sources data. """ - from ..epochs import BaseEpochs + from ..epochs import BaseEpochs, Epochs, make_fixed_length_events from ..io import BaseRaw, RawArray _validate_type(inst, (BaseRaw, BaseEpochs), "inst", "Raw or Epochs") + bad_indices = [] if isinstance(inst, BaseRaw): # when auto, delegate reject to the ica - from ..epochs import make_fixed_length_epochs if reject == "auto": reject = ica.reject_ - drop_inds = None - dropped_indices = [] - if reject is None: - inst_current = inst - else: - data = inst.get_data() - data, drop_inds = _reject_data_segments( - data, reject, flat=None, decim=None, info=inst.info, tstep=2.0 - ) - inst_current = RawArray(data, inst.info) - # break up continuous signal into segments; suppress "All epochs were - # dropped!" because we handle that case gracefully below - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", "All epochs were dropped!", RuntimeWarning - ) - epochs_src = make_fixed_length_epochs( - ica.get_sources(inst_current), - duration=2, - preload=True, - reject_by_annotation=reject_by_annotation, - proj=False, - verbose=False, - ) - # if all epochs were dropped by annotations, stitch the good segments - # together so that the plot can still be generated - if reject_by_annotation and len(epochs_src) == 0: - good_data = inst_current.get_data(reject_by_annotation="omit") + # First we try making epochs in the normal way and see if we have enough + events = make_fixed_length_events(inst, duration=2) + kwargs = dict( + tmin=0, + tmax=2 - 1.0 / inst.info["sfreq"], + baseline=None, + verbose="error", + proj=False, + ) + epochs = Epochs( + inst, + events, + reject=reject, + reject_by_annotation=reject_by_annotation, + preload=False, + **kwargs, + ).drop_bad(verbose="error") + # If all epochs were dropped, stitch the good segments according to + # reject_by_annotation back together and get sources for those, subject to + # the reject param + if reject_by_annotation and len(epochs) == 0: + good_data = inst.get_data(reject_by_annotation="omit") + inst_stitched = RawArray(good_data, inst.info.copy(), verbose="error") + events_stitched = make_fixed_length_events(inst_stitched, duration=2) + epochs_stitched = Epochs( + inst_stitched, + events_stitched, + reject=reject, + reject_by_annotation=False, + preload=False, + **kwargs, + ).drop_bad(verbose="error") + got_samps = len(epochs_stitched) * len(epochs_stitched.times) min_samples = int(2 * inst.info["sfreq"]) - if good_data.shape[1] >= min_samples: - inst_good = RawArray(good_data, inst_current.info.copy(), verbose=False) - epochs_src = make_fixed_length_epochs( - ica.get_sources(inst_good), - duration=2, - preload=True, - reject_by_annotation=False, - proj=False, - verbose=False, - ) - # getting dropped epochs indexes - if drop_inds is not None: - dropped_indices = [(d[0] // len(epochs_src.times)) + 1 for d in drop_inds] + if got_samps >= min_samples: + inst = inst_stitched + events = events_stitched + epochs = epochs_stitched + epochs_src = Epochs( + ica.get_sources(inst), + events, + # We have already rejected by annotation and reject above, but we don't + # here so we can keep data for bad epochs around + reject=None, + reject_by_annotation=False, + preload=True, + **kwargs, + ) + bad_indices = np.where([len(log) for log in epochs.drop_log])[0] kind = "Segment" + assert len(epochs_src) == len(epochs) + len(bad_indices) + if len(epochs_src) == len(bad_indices): + raise RuntimeError( + f"No clean 2-second segments found out of {len(events)} using " + f"{reject=} and {reject_by_annotation=}." + ) else: - drop_inds = None epochs_src = ica.get_sources(inst) - dropped_indices = [] kind = "Epochs" - return kind, dropped_indices, epochs_src, epochs_src.get_data(copy=False) + return kind, bad_indices, epochs_src def _plot_ica_sources_evoked(evoked, picks, exclude, title, show, ica, labels=None): diff --git a/mne/viz/tests/test_ica.py b/mne/viz/tests/test_ica.py index 46d1145e851..37491985b4a 100644 --- a/mne/viz/tests/test_ica.py +++ b/mne/viz/tests/test_ica.py @@ -13,6 +13,7 @@ from mne import ( Annotations, Epochs, + create_info, make_fixed_length_events, pick_types, read_cov, @@ -143,7 +144,7 @@ def test_plot_ica_components(): @pytest.mark.slowtest -def test_plot_ica_properties(): +def test_plot_ica_properties_basic(): """Test plotting of ICA properties.""" raw = _get_raw(preload=True).crop(0, 5) raw.add_proj([], remove_existing=True) @@ -278,7 +279,50 @@ def test_plot_ica_properties(): raw_all_bad.del_proj() fig = ica.plot_properties(raw_all_bad, picks=[0], **topoargs) assert_equal(len(fig), 1) - plt.close("all") + + +@pytest.mark.parametrize("kind", ["first", "last"]) +def test_plot_ica_properties_reject(kind): + """Check for gh-13879.""" + sfreq, n_epochs = 100.0, 5 + n_samples = int(sfreq * n_epochs * 2.0) # 2s per segment + rng = np.random.default_rng(0) + n_channels = 3 + data = rng.uniform(-3e-6, 3e-6, size=(n_channels, n_samples)) + assert kind in ("first", "last") + idx = 0 if kind == "first" else -1 + data[0, idx] = 1000e-6 + info = create_info(["Fz", "Cz", "C2"], sfreq, "eeg") + raw = RawArray(data, info) + raw.set_montage("standard_1020") + ica = ICA( + n_components=2, + random_state=0, + max_iter=1, + ) + with ( + pytest.warns(RuntimeWarning, match="filtered"), + pytest.warns(Warning, match="converge"), + catch_logging(True) as log, + ): + ica.fit(raw, reject=dict(eeg=500e-6)) + log = log.getvalue() + assert log.count("Artifact detected") == 1 # dropped one epoch + figs = ica.plot_properties(raw, picks=[0], show=False) + assert len(figs) == 1 + fig = figs[0] + img_ax = fig.axes[1] + img_ylim = img_ax.get_ylim() + assert img_ylim[0] == -0.5 + assert img_ylim[1] == n_epochs + 0.5 + hist_ax = fig.axes[-1] + var_ax = fig.axes[-2] + min_hist = np.min(hist_ax.lines[0].get_ydata()) + assert min_hist > 0 + scatter_x, _ = var_ax.collections[0].get_offsets().data.T + assert_array_equal(scatter_x, np.arange(n_epochs)) + with pytest.raises(RuntimeError, match="No clean"): + ica.plot_properties(raw, reject=dict(eeg=1e-6)) def test_plot_ica_sources(raw_orig, browser_backend, monkeypatch): @@ -451,7 +495,6 @@ def test_plot_ica_overlay(): pytest.raises(TypeError, ica.plot_overlay, raw[:2, :3][0]) pytest.raises(TypeError, ica.plot_overlay, raw, exclude=2) ica.plot_overlay(raw) - plt.close("all") # smoke test for CTF raw = read_raw_fif(raw_ctf_fname) diff --git a/tutorials/preprocessing/40_artifact_correction_ica.py b/tutorials/preprocessing/40_artifact_correction_ica.py index 257b1f85051..aceee1f610d 100644 --- a/tutorials/preprocessing/40_artifact_correction_ica.py +++ b/tutorials/preprocessing/40_artifact_correction_ica.py @@ -255,7 +255,7 @@ # baseline correction. ica = ICA(n_components=15, max_iter="auto", random_state=97) -ica.fit(filt_raw) +ica.fit(filt_raw, reject=dict(eeg=200e-6)) # avoid a couple of big artifacts ica # %%