Skip to content
1 change: 1 addition & 0 deletions doc/changes/dev/13885.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix bug with :meth:`mne.preprocessing.ICA.plot_properties` when using ``reject`` in :meth:`mne.preprocessing.ICA.fit`, by `Eric Larson`_.
2 changes: 1 addition & 1 deletion examples/preprocessing/find_ref_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
on the reference channels are removed.

This technique is fully described and validated in :footcite:`HannaEtAl2020`

"""
# Authors: Jeff Hanna <jeff.hanna@gmail.com>
#
Expand Down Expand Up @@ -78,6 +77,7 @@
ica_kwargs = dict(
method="picard",
fit_params=dict(tol=1e-4), # use a high tol here for speed
random_state=99,
)
all_picks = mne.pick_types(raw_tog.info, meg=True, ref_meg=True)
ica_tog = ICA(n_components=60, max_iter="auto", allow_ref_meg=True, **ica_kwargs)
Expand Down
4 changes: 1 addition & 3 deletions examples/preprocessing/muscle_ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,7 @@

# %%
# By inspection, let's select out the muscle-artifact components based on
# :footcite:`DharmapraniEtAl2016` manually.
#
# The criteria are:
# :footcite:`DharmapraniEtAl2016` manually. The criteria are:
#
# - Positive slope of log-log power spectrum between 7 and 75 Hz
# (here just flat because it's not in log-log)
Expand Down
2 changes: 1 addition & 1 deletion mne/_fiff/pick.py
Original file line number Diff line number Diff line change
Expand Up @@ -1238,7 +1238,7 @@ def _picks_to_idx(
extra_repr = ", treated as range({n_chan})"
else:
picks = none # let _picks_str_to_idx handle it
extra_repr = f'None, treated as "{none}"'
extra_repr = f', treated as "{none}"'

#
# slice
Expand Down
216 changes: 96 additions & 120 deletions mne/viz/ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from .._fiff.proj import _has_eeg_average_ref_proj
from ..defaults import DEFAULTS, _handle_default
from ..utils import (
_reject_data_segments,
_validate_type,
fill_doc,
verbose,
Expand Down Expand Up @@ -202,13 +201,10 @@ def _create_properties_layout(figsize=None, fig=None):
def _plot_ica_properties(
pick,
ica,
inst,
psds_mean,
freqs,
n_trials,
epoch_var,
plot_lowpass_edge,
epochs_src,
this_epochs_src,
set_title_and_labels,
plot_std,
psd_ylabel,
Expand All @@ -219,7 +215,7 @@ def _plot_ica_properties(
fig,
axes,
kind,
dropped_indices,
bad_indices,
):
"""Plot ICA properties (helper)."""
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
Expand All @@ -237,23 +233,15 @@ def _plot_ica_properties(
)

# image and erp
# we create a new epoch with dropped rows
epoch_data = epochs_src.get_data(copy=False)
epoch_data = np.insert(
arr=epoch_data,
obj=(dropped_indices - np.arange(len(dropped_indices))).astype(int),
values=0.0,
axis=0,
)
from ..epochs import EpochsArray

epochs_src = EpochsArray(
epoch_data, epochs_src.info, tmin=epochs_src.tmin, verbose=0
)

n_trials = len(this_epochs_src)
epoch_var = np.var(this_epochs_src.get_data(), axis=-1)
assert epoch_var.shape[1] == 1 # single channel
epoch_var = epoch_var[:, 0]
assert epoch_var.shape == (len(this_epochs_src),)
this_epochs_src._data[bad_indices] = 0
plot_epochs_image(
epochs_src,
picks=pick,
this_epochs_src,
picks=[0],
axes=[image_ax, erp_ax],
combine=None,
colorbar=False,
Expand All @@ -273,44 +261,41 @@ def _plot_ica_properties(
)
if plot_lowpass_edge:
spec_ax.axvline(
inst.info["lowpass"], lw=2, linestyle="--", color="k", alpha=0.2
this_epochs_src.info["lowpass"], lw=2, linestyle="--", color="k", alpha=0.2
)

# epoch variance
good_indices = np.setdiff1d(np.arange(n_trials), bad_indices)
var_ax_divider = make_axes_locatable(var_ax)
hist_ax = var_ax_divider.append_axes("right", size="33%", pad="2.5%")
var_ax.scatter(
range(len(epoch_var)), epoch_var, alpha=0.5, facecolor=[0, 0, 0], lw=0
)
hist_ax = var_ax_divider.append_axes("right", size="33%", pad="2.5%", sharey=var_ax)
facecolor = np.zeros((len(epoch_var), 3))
alpha = np.full(len(epoch_var), 0.5)
# rejected epochs in red
facecolor[bad_indices] = [1, 0, 0]
alpha[bad_indices] = 0.75
var_ax.scatter(
dropped_indices,
epoch_var[dropped_indices],
alpha=1.0,
facecolor=[1, 0, 0],
lw=0,
np.arange(n_trials), epoch_var, alpha=alpha, facecolor=facecolor, lw=0
)
# compute percentage of dropped epochs
var_percent = float(len(dropped_indices)) / float(len(epoch_var)) * 100.0
var_percent = 100 * len(bad_indices) / n_trials

# histogram & histogram
epoch_var_good = epoch_var[good_indices]
_, counts, _ = hist_ax.hist(
epoch_var, orientation="horizontal", color="k", alpha=0.5
epoch_var_good, orientation="horizontal", color="k", alpha=0.5
)

# kde
ymin, ymax = hist_ax.get_ylim()
try:
kde = gaussian_kde(epoch_var)
kde = gaussian_kde(epoch_var_good)
except np.linalg.LinAlgError:
pass # singular: happens when there is nothing plotted
else:
x = np.linspace(ymin, ymax, 50)
x = np.linspace(epoch_var_good.min(), epoch_var_good.max(), 50)
kde_ = kde(x)
kde_ /= kde_.max() or 1.0
kde_ *= hist_ax.get_xlim()[-1] * 0.9
hist_ax.plot(kde_, x, color="k")
hist_ax.set_ylim(ymin, ymax)

# aesthetics
# ----------
Expand All @@ -319,16 +304,16 @@ def _plot_ica_properties(
# erp
set_title_and_labels(erp_ax, [], "Time (s)", "AU")
erp_ax.spines["right"].set_color("k")
erp_ax.set_xlim(epochs_src.times[[0, -1]])
erp_ax.set_xlim(this_epochs_src.times[[0, -1]])
# remove half of yticks if more than 5
yt = erp_ax.get_yticks()
if len(yt) > 5:
erp_ax.yaxis.set_ticks(yt[::2])
erp_ax.set_yticks(yt[::2])

# remove xticks - erp plot shows xticks for both image and erp plot
image_ax.xaxis.set_ticks([])
image_ax.set_xticks([])
yt = image_ax.get_yticks()
image_ax.yaxis.set_ticks(yt[1:])
image_ax.set_yticks(yt[1:])
image_ax.set_ylim([-0.5, n_trials + 0.5])

def _set_scale(ax, scale):
Expand All @@ -342,10 +327,6 @@ def _set_scale(ax, scale):
set_title_and_labels(spec_ax, "Spectrum", "Frequency (Hz)", psd_ylabel)
spec_ax.yaxis.labelpad = 0
spec_ax.set_xlim(freqs[[0, -1]])
ylim = spec_ax.get_ylim()
air = np.diff(ylim)[0] * 0.1
spec_ax.set_ylim(ylim[0] - air, ylim[1] + air)
image_ax.axhline(0, color="k", linewidth=0.5)
if log_scale:
_set_scale(spec_ax, "log")

Expand Down Expand Up @@ -603,24 +584,24 @@ def _fast_plot_ica_properties(
# calculations
# ------------
if isinstance(precomputed_data, tuple):
kind, dropped_indices, epochs_src, data = precomputed_data
kind, bad_indices, epochs_src = precomputed_data
else:
kind, dropped_indices, epochs_src, data = _prepare_data_ica_properties(
kind, bad_indices, epochs_src = _prepare_data_ica_properties(
inst, ica, reject_by_annotation, reject
)
del reject
ica_data = np.swapaxes(data[:, picks, :], 0, 1)
dropped_src = ica_data
del reject, inst
epochs_src_picked = epochs_src.pick(picks)
del epochs_src
good_indices = np.setdiff1d(np.arange(len(epochs_src_picked)), bad_indices)

# spectrum
Nyquist = inst.info["sfreq"] / 2.0
lp = inst.info["lowpass"]
Nyquist = epochs_src_picked.info["sfreq"] / 2.0
lp = epochs_src_picked.info["lowpass"]
if "fmax" not in psd_args:
psd_args["fmax"] = min(lp * 1.25, Nyquist)
plot_lowpass_edge = lp < Nyquist and (psd_args["fmax"] > lp)
spectrum = epochs_src.compute_psd(picks=picks, **psd_args)
# we've already restricted picks ↑↑↑↑↑↑↑↑↑↑↑
# in the spectrum object, so here we do picks=all ↓↓↓↓↓↓↓↓↓↓↓
# we've already restricted picks in epochs_src_picked, so here we do picks=all
spectrum = epochs_src_picked[good_indices].compute_psd(picks="all", **psd_args)
psds, freqs = spectrum.get_data(return_freqs=True, picks="all", exclude=[])
# we also pass exclude=[] so that when this is called by right-clicking in
# a plot_sources() window on an ICA component name that has been marked as
Expand Down Expand Up @@ -654,30 +635,14 @@ def set_title_and_labels(ax, title, xlab, ylab):
if idx > 0:
fig, axes = _create_properties_layout(figsize=figsize)

# we reconstruct an epoch_variance with 0 where indexes where dropped
epoch_var = np.var(ica_data[idx], axis=1)
drop_var = np.var(dropped_src[idx], axis=1)
drop_indices_corrected = (
dropped_indices - np.arange(len(dropped_indices))
).astype(int)
epoch_var = np.insert(
arr=epoch_var,
obj=drop_indices_corrected,
values=drop_var[dropped_indices],
axis=0,
)

# the actual plot
fig = _plot_ica_properties(
pick,
ica,
inst,
psds_mean,
freqs,
ica_data.shape[1],
epoch_var,
plot_lowpass_edge,
epochs_src,
epochs_src_picked.copy().pick(picks=[idx]),
set_title_and_labels,
plot_std,
psd_ylabel,
Expand All @@ -688,7 +653,7 @@ def set_title_and_labels(ax, title, xlab, ylab):
fig,
axes,
kind,
dropped_indices,
bad_indices,
)
all_fig.append(fig)

Expand Down Expand Up @@ -721,65 +686,76 @@ def _prepare_data_ica_properties(inst, ica, reject_by_annotation=True, reject="a
data : array of shape (n_epochs, n_ica_sources, n_times)
A view on epochs ICA sources data.
"""
from ..epochs import BaseEpochs
from ..epochs import BaseEpochs, Epochs, make_fixed_length_events
from ..io import BaseRaw, RawArray

_validate_type(inst, (BaseRaw, BaseEpochs), "inst", "Raw or Epochs")
bad_indices = []
if isinstance(inst, BaseRaw):
# when auto, delegate reject to the ica
from ..epochs import make_fixed_length_epochs

if reject == "auto":
reject = ica.reject_
drop_inds = None
dropped_indices = []
if reject is None:
inst_current = inst
else:
data = inst.get_data()
data, drop_inds = _reject_data_segments(
data, reject, flat=None, decim=None, info=inst.info, tstep=2.0
)
inst_current = RawArray(data, inst.info)
# break up continuous signal into segments; suppress "All epochs were
# dropped!" because we handle that case gracefully below
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", "All epochs were dropped!", RuntimeWarning
)
epochs_src = make_fixed_length_epochs(
ica.get_sources(inst_current),
duration=2,
preload=True,
reject_by_annotation=reject_by_annotation,
proj=False,
verbose=False,
)
# if all epochs were dropped by annotations, stitch the good segments
# together so that the plot can still be generated
if reject_by_annotation and len(epochs_src) == 0:
good_data = inst_current.get_data(reject_by_annotation="omit")
# First we try making epochs in the normal way and see if we have enough
events = make_fixed_length_events(inst, duration=2)
kwargs = dict(
tmin=0,
tmax=2 - 1.0 / inst.info["sfreq"],
baseline=None,
verbose="error",
proj=False,
)
epochs = Epochs(
inst,
events,
reject=reject,
reject_by_annotation=reject_by_annotation,
preload=False,
**kwargs,
).drop_bad(verbose="error")
# If all epochs were dropped, stitch the good segments according to
# reject_by_annotation back together and get sources for those, subject to
# the reject param
if reject_by_annotation and len(epochs) == 0:
good_data = inst.get_data(reject_by_annotation="omit")
inst_stitched = RawArray(good_data, inst.info.copy(), verbose="error")
events_stitched = make_fixed_length_events(inst_stitched, duration=2)
epochs_stitched = Epochs(
inst_stitched,
events_stitched,
reject=reject,
reject_by_annotation=False,
preload=False,
**kwargs,
).drop_bad(verbose="error")
got_samps = len(epochs_stitched) * len(epochs_stitched.times)
min_samples = int(2 * inst.info["sfreq"])
if good_data.shape[1] >= min_samples:
inst_good = RawArray(good_data, inst_current.info.copy(), verbose=False)
epochs_src = make_fixed_length_epochs(
ica.get_sources(inst_good),
duration=2,
preload=True,
reject_by_annotation=False,
proj=False,
verbose=False,
)
# getting dropped epochs indexes
if drop_inds is not None:
dropped_indices = [(d[0] // len(epochs_src.times)) + 1 for d in drop_inds]
if got_samps >= min_samples:
inst = inst_stitched
events = events_stitched
epochs = epochs_stitched
epochs_src = Epochs(
ica.get_sources(inst),
events,
# We have already rejected by annotation and reject above, but we don't
# here so we can keep data for bad epochs around
reject=None,
reject_by_annotation=False,
preload=True,
**kwargs,
)
bad_indices = np.where([len(log) for log in epochs.drop_log])[0]
kind = "Segment"
assert len(epochs_src) == len(epochs) + len(bad_indices)
if len(epochs_src) == len(bad_indices):
raise RuntimeError(
f"No clean 2-second segments found out of {len(events)} using "
f"{reject=} and {reject_by_annotation=}."
)
else:
drop_inds = None
epochs_src = ica.get_sources(inst)
dropped_indices = []
kind = "Epochs"
return kind, dropped_indices, epochs_src, epochs_src.get_data(copy=False)
return kind, bad_indices, epochs_src


def _plot_ica_sources_evoked(evoked, picks, exclude, title, show, ica, labels=None):
Expand Down
Loading
Loading