Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions mne/viz/tests/test_topomap.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,10 +472,7 @@ def get_texts(p):
fig1 = evoked.plot_topomap(
"interactive", ch_type="mag", proj="interactive", **fast_test
)
# TODO: Clicking the slider creates a *new* image rather than updating
# the data directly. This makes it so that the projection is not applied
# to the correct matplotlib Image object.
# _fake_click(fig1, fig1.axes[1], (0.5, 0.5)) # click slider
_fake_click(fig1, fig1.axes[1], (0.5, 0.5)) # click slider
data_max = np.max(fig1.axes[0].images[0]._A)
proj_fig = plt.figure(plt.get_fignums()[-1])
assert fig1.mne.proj_checkboxes.get_status() == [False, False, False]
Expand Down
42 changes: 41 additions & 1 deletion mne/viz/topomap.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,10 @@ def _plot_update_evoked_topomap(params, bools):
new_evoked.apply_proj()

data = new_evoked.data[:, params["time_idx"]] * params["scale"]
_fig = params.get("fig")
_mne_params = getattr(_fig, "_mne_params", None) if _fig is not None else None
if _mne_params is not None and "current_time_idx" in _mne_params:
data = new_evoked.data[:, [_mne_params["current_time_idx"]]] * params["scale"]
if params["merge_channels"]:
data, _ = _merge_ch_data(data, "grad", [])

Expand Down Expand Up @@ -2375,6 +2379,8 @@ def plot_evoked_topomap(
# if ch_type in _fnirs_types:
if modality != "other":
merge_channels = False
merge_ch_type = ch_type
merge_names = names
# apply mask if requested
if mask is not None:
mask = mask.astype(bool, copy=False)
Expand Down Expand Up @@ -2460,6 +2466,12 @@ def plot_evoked_topomap(
)
slider.vline.remove() # remove initial point indicator
func = _merge_ch_data if merge_channels else lambda x: x
# Store merge metadata on the figure so slider callbacks can access it
fig._mne_params = dict(
merge_ch_type=merge_ch_type,
merge_names=merge_names,
merge_channels=merge_channels,
)

def _slider_changed(val):
publish(fig, TimeChange(time=val))
Expand All @@ -2484,6 +2496,7 @@ def _slider_changed(val):
scaling_time=scaling_time,
slider=slider,
kwargs=kwargs,
images=images,
),
)
subscribe(
Expand Down Expand Up @@ -2532,6 +2545,11 @@ def _slider_changed(val):
extrapolate=extrapolate,
)
_draw_proj_checkbox(None, params)
# When the interactive slider is also active, store a reference to
# params so that _on_time_change can keep time_idx synchronized and
# clear stale contour artists after each ax.clear().
if hasattr(fig, "_mne_params"):
fig._mne_params["proj_params"] = params
# This is mostly for testing purposes, but it's also consistent with
# raw.plot, so maybe not a bad thing in principle either
from mne.viz._figure import BrowserParams
Expand Down Expand Up @@ -2567,13 +2585,35 @@ def _on_time_change(
scaling_time,
slider,
kwargs,
images=None,
):
"""Handle updating topomap to show a new time."""
from ..channels.layout import _merge_ch_data

idx = np.argmin(np.abs(times - event.time))
data = func(data[:, idx]).ravel() * scaling
# Record the slider's current index in fig._mne_params. When the
# projection callback (_plot_update_evoked_topomap) runs later it needs
# this value to apply SSPs at the same time point the slider is showing.
_mne_params = getattr(fig, "_mne_params", None)
if _mne_params is not None:
_mne_params["current_time_idx"] = idx
_proj_params = _mne_params.get("proj_params")
if _proj_params is not None:
_proj_params["contours_"] = []
if _mne_params.get("merge_channels", False):
ch_type = _mne_params.get("merge_ch_type", "grad")
names = _mne_params.get("merge_names", [])
data, _ = _merge_ch_data(data[:, idx], ch_type, names)
else:
data = data[:, idx]
data = data.ravel() * scaling
else:
data = func(data[:, idx]).ravel() * scaling
ax = fig.axes[0]
ax.clear()
im, _ = plot_topomap(data, pos, axes=ax, **kwargs)
if images is not None:
images[0] = im
if hasattr(ax, "CB"):
ax.CB.mappable = im
_resize_cbar(ax.CB.cbar.ax, 2)
Expand Down