diff --git a/doc/changes/dev/13713.other.rst b/doc/changes/dev/13713.other.rst new file mode 100644 index 00000000000..da7f4ead348 --- /dev/null +++ b/doc/changes/dev/13713.other.rst @@ -0,0 +1 @@ +Extended support for visualising cross-spectral densities in :func:`mne.viz.plot_csd` and :meth:`mne.time_frequency.CrossSpectralDensity.plot` to all types of :term:`data channels`, by :newcontrib:`Aniket Singh Yadav`. \ No newline at end of file diff --git a/doc/changes/names.inc b/doc/changes/names.inc index 79db14d44e3..09c5818b6af 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -22,6 +22,7 @@ .. _Andrew Gilbert: https://github.com/adgilbert .. _Andrew Quinn: https://github.com/ajquinn .. _Aniket Pradhan: https://github.com/Aniket-Pradhan +.. _Aniket Singh Yadav: https://github.com/Aniketsy .. _Anna Padee: https://github.com/apadee/ .. _Annalisa Pascarella: https://github.com/annapasca .. _Anne-Sophie Dubarry: https://github.com/annesodub diff --git a/mne/viz/misc.py b/mne/viz/misc.py index a62833d7f81..cdf57d51deb 100644 --- a/mne/viz/misc.py +++ b/mne/viz/misc.py @@ -23,7 +23,6 @@ _picks_by_type, pick_channels, pick_info, - pick_types, ) from .._fiff.proj import make_projector from .._freesurfer import _check_mri, _mri_orientation, _read_mri_info, _reorient_image @@ -53,18 +52,12 @@ ) -def _index_info_cov(info, cov, exclude): - if exclude == "bads": - exclude = info["bads"] - info = pick_info(info, pick_channels(info["ch_names"], cov["names"], exclude)) - del exclude +def _index_info_by_ch_type(info, ch_names): + """Build channel-type-grouped indices and metadata for an object's channels.""" + info_ch_names = info["ch_names"] picks_list = _picks_by_type(info, meg_combined=False, ref_meg=False, exclude=()) picks_by_type = dict(picks_list) - ch_names = [n for n in cov.ch_names if n in info["ch_names"]] - ch_idx = [cov.ch_names.index(n) for n in ch_names] - - info_ch_names = info["ch_names"] idx_by_type = defaultdict(list) for ch_type, sel in picks_by_type.items(): idx_by_type[ch_type] = [ @@ -72,16 +65,37 @@ def _index_info_cov(info, cov, exclude): for c in sel if info_ch_names[c] in ch_names ] + + indices = [] + titles = [] + units = [] + scalings = [] + ch_types = [] + for key in _DATA_CH_TYPES_SPLIT: + if len(idx_by_type[key]) > 0: + indices.append(idx_by_type[key]) + titles.append(DEFAULTS["titles"][key]) + units.append(DEFAULTS["units"][key]) + scalings.append(DEFAULTS["scalings"][key]) + ch_types.append(key) + return indices, titles, units, scalings, ch_types + + +def _index_info_cov(info, cov, exclude): + if exclude == "bads": + exclude = info["bads"] + info = pick_info(info, pick_channels(info["ch_names"], cov["names"], exclude)) + del exclude + + ch_names = [n for n in cov.ch_names if n in info["ch_names"]] + ch_idx = [cov.ch_names.index(n) for n in ch_names] + + indices, titles, units, scalings, ch_types = _index_info_by_ch_type(info, ch_names) idx_names = [ - ( - idx_by_type[key], - f"{DEFAULTS['titles'][key]} covariance", - DEFAULTS["units"][key], - DEFAULTS["scalings"][key], - key, + (idx, f"{title} covariance", unit, scaling, key) + for idx, title, unit, scaling, key in zip( + indices, titles, units, scalings, ch_types ) - for key in _DATA_CH_TYPES_SPLIT - if len(idx_by_type[key]) > 0 ] C = cov.data[ch_idx][:, ch_idx] return info, C, ch_names, idx_names @@ -1483,39 +1497,16 @@ def plot_csd( raise ValueError('"mode" should be either "csd" or "coh".') if info is not None: - info_ch_names = info["ch_names"] - sel_eeg = pick_types(info, meg=False, eeg=True, ref_meg=False, exclude=[]) - sel_mag = pick_types(info, meg="mag", eeg=False, ref_meg=False, exclude=[]) - sel_grad = pick_types(info, meg="grad", eeg=False, ref_meg=False, exclude=[]) - idx_eeg = [ - csd.ch_names.index(info_ch_names[c]) - for c in sel_eeg - if info_ch_names[c] in csd.ch_names - ] - idx_mag = [ - csd.ch_names.index(info_ch_names[c]) - for c in sel_mag - if info_ch_names[c] in csd.ch_names - ] - idx_grad = [ - csd.ch_names.index(info_ch_names[c]) - for c in sel_grad - if info_ch_names[c] in csd.ch_names - ] - indices = [idx_eeg, idx_mag, idx_grad] - titles = ["EEG", "Magnetometers", "Gradiometers"] - - if mode == "csd": - # The units in which to plot the CSD - units = dict(eeg="µV²", grad="fT²/cm²", mag="fT²") - scalings = dict(eeg=1e12, grad=1e26, mag=1e30) + indices, titles, units, scalings, ch_types = _index_info_by_ch_type( + info, csd.ch_names + ) else: indices = [np.arange(len(csd.ch_names))] + units = [""] + scalings = [1] + ch_types = [None] if mode == "csd": titles = ["Cross-spectral density"] - # Units and scaling unknown - units = dict() - scalings = dict() elif mode == "coh": titles = ["Coherence"] @@ -1526,10 +1517,9 @@ def plot_csd( n_rows = int(np.ceil(n_freqs / float(n_cols))) figs = [] - for ind, title, ch_type in zip(indices, titles, ["eeg", "mag", "grad"]): - if len(ind) == 0: - continue - + for ind, title, unit, scaling, ch_type in zip( + indices, titles, units, scalings, ch_types + ): fig, axes = plt.subplots( n_rows, n_cols, @@ -1542,7 +1532,7 @@ def plot_csd( for i in range(len(csd.frequencies)): cm = csd.get_data(index=i)[ind][:, ind] if mode == "csd": - cm = np.abs(cm) * scalings.get(ch_type, 1) + cm = np.abs(cm) * scaling**2 elif mode == "coh": # Compute coherence from the CSD matrix psd = np.diag(cm).real @@ -1566,8 +1556,10 @@ def plot_csd( cb = plt.colorbar(im, ax=[a for ax_ in axes for a in ax_]) if mode == "csd": label = "CSD" - if ch_type in units: - label += f" ({units[ch_type]})" + if ch_type is not None: + if "/" in unit: + unit = f"({unit})" + label += f" ({unit}²)" cb.set_label(label) elif mode == "coh": cb.set_label("Coherence") diff --git a/mne/viz/tests/test_misc.py b/mne/viz/tests/test_misc.py index 7e4a3b2d806..b69d71a7d4f 100644 --- a/mne/viz/tests/test_misc.py +++ b/mne/viz/tests/test_misc.py @@ -11,6 +11,7 @@ from mne import ( SourceEstimate, + create_info, pick_events, read_cov, read_dipole, @@ -18,6 +19,7 @@ read_evokeds, read_source_spaces, ) +from mne._fiff.pick import _DATA_CH_TYPES_SPLIT from mne.chpi import compute_chpi_snr from mne.datasets import testing from mne.filter import create_filter @@ -312,18 +314,48 @@ def test_plot_dipole_amplitudes(): dipoles.plot_amplitudes(show=False) -def test_plot_csd(): +@pytest.mark.parametrize( + "ch_types", + [ + None, + "eeg", + "ecog", + ["grad", "dbs"], + "misc", + ], +) +def test_plot_csd(ch_types): """Test plotting of CSD matrices.""" + n_ch = 2 + if isinstance(ch_types, list): + n_ch_types = len(ch_types) + ch_types = np.repeat(ch_types, n_ch).tolist() + n_ch *= n_ch_types + ch_names = [f"CH{i + 1}" for i in range(n_ch)] + n_data = n_ch * (n_ch + 1) // 2 + csd = CrossSpectralDensity( - [1, 2, 3], - ["CH1", "CH2"], + list(range(1, n_data + 1)), + ch_names, frequencies=[(10, 20)], n_fft=1, tmin=0, tmax=1, ) - plot_csd(csd, mode="csd") # Plot cross-spectral density - plot_csd(csd, mode="coh") # Plot coherence + + info = None + expected_n_figs = 1 + if ch_types is not None: + info = create_info(ch_names, sfreq=1.0, ch_types=ch_types) + expected_n_figs = len( + set(ch_types if isinstance(ch_types, list) else [ch_types]).intersection( + _DATA_CH_TYPES_SPLIT + ) + ) + + for mode in ("csd", "coh"): + figs = plot_csd(csd, info=info, mode=mode, show=False) + assert len(figs) == expected_n_figs @pytest.mark.slowtest # Slow on Azure