Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/dev/13885.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix bug with :meth:`mne.preprocessing.ICA.plot_properties` when using ``reject`` in :meth:`mne.preprocessing.ICA.fit`, by `Eric Larson`_.
2 changes: 1 addition & 1 deletion examples/preprocessing/find_ref_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
on the reference channels are removed.

This technique is fully described and validated in :footcite:`HannaEtAl2020`

"""
# Authors: Jeff Hanna <jeff.hanna@gmail.com>
#
Expand Down Expand Up @@ -78,6 +77,7 @@
ica_kwargs = dict(
method="picard",
fit_params=dict(tol=1e-4), # use a high tol here for speed
random_state=99,
)
all_picks = mne.pick_types(raw_tog.info, meg=True, ref_meg=True)
ica_tog = ICA(n_components=60, max_iter="auto", allow_ref_meg=True, **ica_kwargs)
Expand Down
4 changes: 1 addition & 3 deletions examples/preprocessing/muscle_ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,7 @@

# %%
# By inspection, let's select out the muscle-artifact components based on
# :footcite:`DharmapraniEtAl2016` manually.
#
# The criteria are:
# :footcite:`DharmapraniEtAl2016` manually. The criteria are:
#
# - Positive slope of log-log power spectrum between 7 and 75 Hz
# (here just flat because it's not in log-log)
Expand Down
2 changes: 1 addition & 1 deletion mne/_fiff/pick.py
Original file line number Diff line number Diff line change
Expand Up @@ -1238,7 +1238,7 @@ def _picks_to_idx(
extra_repr = ", treated as range({n_chan})"
else:
picks = none # let _picks_str_to_idx handle it
extra_repr = f'None, treated as "{none}"'
extra_repr = f', treated as "{none}"'

#
# slice
Expand Down
216 changes: 96 additions & 120 deletions mne/viz/ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from .._fiff.proj import _has_eeg_average_ref_proj
from ..defaults import DEFAULTS, _handle_default
from ..utils import (
_reject_data_segments,
_validate_type,
fill_doc,
verbose,
Expand Down Expand Up @@ -202,13 +201,10 @@ def _create_properties_layout(figsize=None, fig=None):
def _plot_ica_properties(
pick,
ica,
inst,
psds_mean,
freqs,
n_trials,
epoch_var,
plot_lowpass_edge,
epochs_src,
this_epochs_src,
set_title_and_labels,
plot_std,
psd_ylabel,
Expand All @@ -219,7 +215,7 @@ def _plot_ica_properties(
fig,
axes,
kind,
dropped_indices,
bad_indices,
):
"""Plot ICA properties (helper)."""
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
Expand All @@ -237,23 +233,15 @@ def _plot_ica_properties(
)

# image and erp
# we create a new epoch with dropped rows
epoch_data = epochs_src.get_data(copy=False)
epoch_data = np.insert(
arr=epoch_data,
obj=(dropped_indices - np.arange(len(dropped_indices))).astype(int),
values=0.0,
axis=0,
)
from ..epochs import EpochsArray

epochs_src = EpochsArray(
epoch_data, epochs_src.info, tmin=epochs_src.tmin, verbose=0
)

n_trials = len(this_epochs_src)
epoch_var = np.var(this_epochs_src.get_data(), axis=-1)
assert epoch_var.shape[1] == 1 # single channel
epoch_var = epoch_var[:, 0]
assert epoch_var.shape == (len(this_epochs_src),)
this_epochs_src._data[bad_indices] = 0
plot_epochs_image(
epochs_src,
picks=pick,
this_epochs_src,
picks=[0],
axes=[image_ax, erp_ax],
combine=None,
colorbar=False,
Expand All @@ -273,44 +261,41 @@ def _plot_ica_properties(
)
if plot_lowpass_edge:
spec_ax.axvline(
inst.info["lowpass"], lw=2, linestyle="--", color="k", alpha=0.2
this_epochs_src.info["lowpass"], lw=2, linestyle="--", color="k", alpha=0.2
)

# epoch variance
good_indices = np.setdiff1d(np.arange(n_trials), bad_indices)
var_ax_divider = make_axes_locatable(var_ax)
hist_ax = var_ax_divider.append_axes("right", size="33%", pad="2.5%")
var_ax.scatter(
range(len(epoch_var)), epoch_var, alpha=0.5, facecolor=[0, 0, 0], lw=0
)
hist_ax = var_ax_divider.append_axes("right", size="33%", pad="2.5%", sharey=var_ax)
facecolor = np.zeros((len(epoch_var), 3))
alpha = np.full(len(epoch_var), 0.5)
# rejected epochs in red
facecolor[bad_indices] = [1, 0, 0]
alpha[bad_indices] = 0.75
var_ax.scatter(
dropped_indices,
epoch_var[dropped_indices],
alpha=1.0,
facecolor=[1, 0, 0],
lw=0,
np.arange(n_trials), epoch_var, alpha=alpha, facecolor=facecolor, lw=0
)
# compute percentage of dropped epochs
var_percent = float(len(dropped_indices)) / float(len(epoch_var)) * 100.0
var_percent = 100 * len(bad_indices) / n_trials

# histogram & histogram
epoch_var_good = epoch_var[good_indices]
_, counts, _ = hist_ax.hist(
epoch_var, orientation="horizontal", color="k", alpha=0.5
epoch_var_good, orientation="horizontal", color="k", alpha=0.5
)

# kde
ymin, ymax = hist_ax.get_ylim()
try:
kde = gaussian_kde(epoch_var)
kde = gaussian_kde(epoch_var_good)
except np.linalg.LinAlgError:
pass # singular: happens when there is nothing plotted
else:
x = np.linspace(ymin, ymax, 50)
x = np.linspace(epoch_var_good.min(), epoch_var_good.max(), 50)
kde_ = kde(x)
kde_ /= kde_.max() or 1.0
kde_ *= hist_ax.get_xlim()[-1] * 0.9
hist_ax.plot(kde_, x, color="k")
hist_ax.set_ylim(ymin, ymax)

# aesthetics
# ----------
Expand All @@ -319,16 +304,16 @@ def _plot_ica_properties(
# erp
set_title_and_labels(erp_ax, [], "Time (s)", "AU")
erp_ax.spines["right"].set_color("k")
erp_ax.set_xlim(epochs_src.times[[0, -1]])
erp_ax.set_xlim(this_epochs_src.times[[0, -1]])
# remove half of yticks if more than 5
yt = erp_ax.get_yticks()
if len(yt) > 5:
erp_ax.yaxis.set_ticks(yt[::2])
erp_ax.set_yticks(yt[::2])

# remove xticks - erp plot shows xticks for both image and erp plot
image_ax.xaxis.set_ticks([])
image_ax.set_xticks([])
yt = image_ax.get_yticks()
image_ax.yaxis.set_ticks(yt[1:])
image_ax.set_yticks(yt[1:])
image_ax.set_ylim([-0.5, n_trials + 0.5])

def _set_scale(ax, scale):
Expand All @@ -342,10 +327,6 @@ def _set_scale(ax, scale):
set_title_and_labels(spec_ax, "Spectrum", "Frequency (Hz)", psd_ylabel)
spec_ax.yaxis.labelpad = 0
spec_ax.set_xlim(freqs[[0, -1]])
ylim = spec_ax.get_ylim()
air = np.diff(ylim)[0] * 0.1
spec_ax.set_ylim(ylim[0] - air, ylim[1] + air)
image_ax.axhline(0, color="k", linewidth=0.5)
if log_scale:
_set_scale(spec_ax, "log")

Expand Down Expand Up @@ -603,24 +584,24 @@ def _fast_plot_ica_properties(
# calculations
# ------------
if isinstance(precomputed_data, tuple):
kind, dropped_indices, epochs_src, data = precomputed_data
kind, bad_indices, epochs_src = precomputed_data
else:
kind, dropped_indices, epochs_src, data = _prepare_data_ica_properties(
kind, bad_indices, epochs_src = _prepare_data_ica_properties(
inst, ica, reject_by_annotation, reject
)
del reject
ica_data = np.swapaxes(data[:, picks, :], 0, 1)
dropped_src = ica_data
del reject, inst
epochs_src_picked = epochs_src.pick(picks)
del epochs_src
good_indices = np.setdiff1d(np.arange(len(epochs_src_picked)), bad_indices)

# spectrum
Nyquist = inst.info["sfreq"] / 2.0
lp = inst.info["lowpass"]
Nyquist = epochs_src_picked.info["sfreq"] / 2.0
lp = epochs_src_picked.info["lowpass"]
if "fmax" not in psd_args:
psd_args["fmax"] = min(lp * 1.25, Nyquist)
plot_lowpass_edge = lp < Nyquist and (psd_args["fmax"] > lp)
spectrum = epochs_src.compute_psd(picks=picks, **psd_args)
# we've already restricted picks ↑↑↑↑↑↑↑↑↑↑↑
# in the spectrum object, so here we do picks=all ↓↓↓↓↓↓↓↓↓↓↓
# we've already restricted picks in epochs_src_picked, so here we do picks=all
spectrum = epochs_src_picked[good_indices].compute_psd(picks="all", **psd_args)
psds, freqs = spectrum.get_data(return_freqs=True, picks="all", exclude=[])
# we also pass exclude=[] so that when this is called by right-clicking in
# a plot_sources() window on an ICA component name that has been marked as
Expand Down Expand Up @@ -654,30 +635,14 @@ def set_title_and_labels(ax, title, xlab, ylab):
if idx > 0:
fig, axes = _create_properties_layout(figsize=figsize)

# we reconstruct an epoch_variance with 0 where indexes where dropped
epoch_var = np.var(ica_data[idx], axis=1)
drop_var = np.var(dropped_src[idx], axis=1)
drop_indices_corrected = (
dropped_indices - np.arange(len(dropped_indices))
).astype(int)
epoch_var = np.insert(
arr=epoch_var,
obj=drop_indices_corrected,
values=drop_var[dropped_indices],
axis=0,
)

# the actual plot
fig = _plot_ica_properties(
pick,
ica,
inst,
psds_mean,
freqs,
ica_data.shape[1],
epoch_var,
plot_lowpass_edge,
epochs_src,
epochs_src_picked.copy().pick(picks=[idx]),
set_title_and_labels,
plot_std,
psd_ylabel,
Expand All @@ -688,7 +653,7 @@ def set_title_and_labels(ax, title, xlab, ylab):
fig,
axes,
kind,
dropped_indices,
bad_indices,
)
all_fig.append(fig)

Expand Down Expand Up @@ -721,65 +686,76 @@ def _prepare_data_ica_properties(inst, ica, reject_by_annotation=True, reject="a
data : array of shape (n_epochs, n_ica_sources, n_times)
A view on epochs ICA sources data.
"""
from ..epochs import BaseEpochs
from ..epochs import BaseEpochs, Epochs, make_fixed_length_events
from ..io import BaseRaw, RawArray

_validate_type(inst, (BaseRaw, BaseEpochs), "inst", "Raw or Epochs")
bad_indices = []
if isinstance(inst, BaseRaw):
# when auto, delegate reject to the ica
from ..epochs import make_fixed_length_epochs

if reject == "auto":
reject = ica.reject_
drop_inds = None
dropped_indices = []
if reject is None:
inst_current = inst
else:
data = inst.get_data()
data, drop_inds = _reject_data_segments(
data, reject, flat=None, decim=None, info=inst.info, tstep=2.0
)
inst_current = RawArray(data, inst.info)
# break up continuous signal into segments; suppress "All epochs were
# dropped!" because we handle that case gracefully below
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", "All epochs were dropped!", RuntimeWarning
)
epochs_src = make_fixed_length_epochs(
ica.get_sources(inst_current),
duration=2,
preload=True,
reject_by_annotation=reject_by_annotation,
proj=False,
verbose=False,
)
# if all epochs were dropped by annotations, stitch the good segments
# together so that the plot can still be generated
if reject_by_annotation and len(epochs_src) == 0:
good_data = inst_current.get_data(reject_by_annotation="omit")
# First we try making epochs in the normal way and see if we have enough
events = make_fixed_length_events(inst, duration=2)
kwargs = dict(
tmin=0,
tmax=2 - 1.0 / inst.info["sfreq"],
baseline=None,
verbose="error",
proj=False,
)
epochs = Epochs(
inst,
events,
reject=reject,
reject_by_annotation=reject_by_annotation,
preload=False,
**kwargs,
).drop_bad(verbose="error")
# If all epochs were dropped, stitch the good segments according to
# reject_by_annotation back together and get sources for those, subject to
# the reject param
if reject_by_annotation and len(epochs) == 0:
good_data = inst.get_data(reject_by_annotation="omit")
inst_stitched = RawArray(good_data, inst.info.copy(), verbose="error")
events_stitched = make_fixed_length_events(inst_stitched, duration=2)
epochs_stitched = Epochs(
inst_stitched,
events_stitched,
reject=reject,
reject_by_annotation=False,
preload=False,
**kwargs,
).drop_bad(verbose="error")
got_samps = len(epochs_stitched) * len(epochs_stitched.times)
min_samples = int(2 * inst.info["sfreq"])
if good_data.shape[1] >= min_samples:
inst_good = RawArray(good_data, inst_current.info.copy(), verbose=False)
epochs_src = make_fixed_length_epochs(
ica.get_sources(inst_good),
duration=2,
preload=True,
reject_by_annotation=False,
proj=False,
verbose=False,
)
# getting dropped epochs indexes
if drop_inds is not None:
dropped_indices = [(d[0] // len(epochs_src.times)) + 1 for d in drop_inds]
if got_samps >= min_samples:
inst = inst_stitched
events = events_stitched
epochs = epochs_stitched
epochs_src = Epochs(
ica.get_sources(inst),
events,
# We have already rejected by annotation and reject above, but we don't
# here so we can keep data for bad epochs around
reject=None,
reject_by_annotation=False,
preload=True,
**kwargs,
)
bad_indices = np.where([len(log) for log in epochs.drop_log])[0]
kind = "Segment"
assert len(epochs_src) == len(epochs) + len(bad_indices)
if len(epochs_src) == len(bad_indices):
raise RuntimeError(
f"No clean 2-second segments found out of {len(events)} using "
f"{reject=} and {reject_by_annotation=}."
)
else:
drop_inds = None
epochs_src = ica.get_sources(inst)
dropped_indices = []
kind = "Epochs"
return kind, dropped_indices, epochs_src, epochs_src.get_data(copy=False)
return kind, bad_indices, epochs_src


def _plot_ica_sources_evoked(evoked, picks, exclude, title, show, ica, labels=None):
Expand Down
Loading
Loading