From a2cc04fab7c5aac9f9cf6951afbeef35f5d4aa2d Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Mon, 2 Mar 2026 12:35:38 +0000 Subject: [PATCH 01/13] Extend plot_csd support to SEEG, ECoG, and DBS channel types --- mne/viz/misc.py | 69 ++++++++++++++++++++++++++------------ mne/viz/tests/test_misc.py | 52 ++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 22 deletions(-) diff --git a/mne/viz/misc.py b/mne/viz/misc.py index a62833d7f81..f7066b98fd7 100644 --- a/mne/viz/misc.py +++ b/mne/viz/misc.py @@ -1484,33 +1484,58 @@ def plot_csd( if info is not None: info_ch_names = info["ch_names"] - sel_eeg = pick_types(info, meg=False, eeg=True, ref_meg=False, exclude=[]) - sel_mag = pick_types(info, meg="mag", eeg=False, ref_meg=False, exclude=[]) - sel_grad = pick_types(info, meg="grad", eeg=False, ref_meg=False, exclude=[]) - idx_eeg = [ - csd.ch_names.index(info_ch_names[c]) - for c in sel_eeg - if info_ch_names[c] in csd.ch_names + # Each entry: (pick_types kwargs, ch_type key, plot title) + _ch_type_info = [ + (dict(meg=False, eeg=True, ref_meg=False, exclude=[]), "eeg", "EEG"), + ( + dict(meg="mag", eeg=False, ref_meg=False, exclude=[]), + "mag", + "Magnetometers", + ), + ( + dict(meg="grad", eeg=False, ref_meg=False, exclude=[]), + "grad", + "Gradiometers", + ), + (dict(meg=False, seeg=True, ref_meg=False, exclude=[]), "seeg", "sEEG"), + (dict(meg=False, ecog=True, ref_meg=False, exclude=[]), "ecog", "ECoG"), + (dict(meg=False, dbs=True, ref_meg=False, exclude=[]), "dbs", "DBS"), ] - idx_mag = [ - csd.ch_names.index(info_ch_names[c]) - for c in sel_mag - if info_ch_names[c] in csd.ch_names - ] - idx_grad = [ - csd.ch_names.index(info_ch_names[c]) - for c in sel_grad - if info_ch_names[c] in csd.ch_names - ] - indices = [idx_eeg, idx_mag, idx_grad] - titles = ["EEG", "Magnetometers", "Gradiometers"] + indices = [] + titles = [] + ch_types = [] + for pick_kwargs, ch_type, title in _ch_type_info: + sel = pick_types(info, **pick_kwargs) + idx = [ + csd.ch_names.index(info_ch_names[c]) + for c in sel + if info_ch_names[c] in csd.ch_names + ] + indices.append(idx) + titles.append(title) + ch_types.append(ch_type) if mode == "csd": # The units in which to plot the CSD - units = dict(eeg="µV²", grad="fT²/cm²", mag="fT²") - scalings = dict(eeg=1e12, grad=1e26, mag=1e30) + units = dict( + eeg="µV²", + grad="fT²/cm²", + mag="fT²", + seeg="mV²", + ecog="µV²", + dbs="µV²", + ) + scalings = dict( + eeg=1e12, + grad=1e26, + mag=1e30, + seeg=1e6, + ecog=1e12, + dbs=1e12, + ) else: indices = [np.arange(len(csd.ch_names))] + ch_types = [None] if mode == "csd": titles = ["Cross-spectral density"] # Units and scaling unknown @@ -1526,7 +1551,7 @@ def plot_csd( n_rows = int(np.ceil(n_freqs / float(n_cols))) figs = [] - for ind, title, ch_type in zip(indices, titles, ["eeg", "mag", "grad"]): + for ind, title, ch_type in zip(indices, titles, ch_types): if len(ind) == 0: continue diff --git a/mne/viz/tests/test_misc.py b/mne/viz/tests/test_misc.py index 7e4a3b2d806..7e8c97f2f7e 100644 --- a/mne/viz/tests/test_misc.py +++ b/mne/viz/tests/test_misc.py @@ -11,6 +11,7 @@ from mne import ( SourceEstimate, + create_info, pick_events, read_cov, read_dipole, @@ -325,6 +326,57 @@ def test_plot_csd(): plot_csd(csd, mode="csd") # Plot cross-spectral density plot_csd(csd, mode="coh") # Plot coherence + # EEG + info_eeg = create_info(["CH1", "CH2"], sfreq=1.0, ch_types="eeg") + figs = plot_csd(csd, info=info_eeg, mode="csd", show=False) + assert len(figs) == 1 + figs = plot_csd(csd, info=info_eeg, mode="coh", show=False) + assert len(figs) == 1 + + # sEEG + info_seeg = create_info(["CH1", "CH2"], sfreq=1.0, ch_types="seeg") + figs = plot_csd(csd, info=info_seeg, mode="csd", show=False) + assert len(figs) == 1 + figs = plot_csd(csd, info=info_seeg, mode="coh", show=False) + assert len(figs) == 1 + + # ECoG + info_ecog = create_info(["CH1", "CH2"], sfreq=1.0, ch_types="ecog") + figs = plot_csd(csd, info=info_ecog, mode="csd", show=False) + assert len(figs) == 1 + figs = plot_csd(csd, info=info_ecog, mode="coh", show=False) + assert len(figs) == 1 + + # DBS + info_dbs = create_info(["CH1", "CH2"], sfreq=1.0, ch_types="dbs") + figs = plot_csd(csd, info=info_dbs, mode="csd", show=False) + assert len(figs) == 1 + figs = plot_csd(csd, info=info_dbs, mode="coh", show=False) + assert len(figs) == 1 + + # Mixed: EEG + sEEG + ECoG + DBS — each should produce its own figure + info_mixed = create_info( + ["EEG1", "SEEG1", "ECOG1", "DBS1"], + sfreq=1.0, + ch_types=["eeg", "seeg", "ecog", "dbs"], + ) + csd_mixed = CrossSpectralDensity( + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + ["EEG1", "SEEG1", "ECOG1", "DBS1"], + frequencies=[(10, 20)], + n_fft=1, + tmin=0, + tmax=1, + ) + figs = plot_csd(csd_mixed, info=info_mixed, mode="csd", show=False) + assert len(figs) == 4 + figs = plot_csd(csd_mixed, info=info_mixed, mode="coh", show=False) + assert len(figs) == 4 + + info_other = create_info(["OTHERCHAN"], sfreq=1.0, ch_types="eeg") + figs = plot_csd(csd, info=info_other, mode="csd", show=False) + assert len(figs) == 0 + @pytest.mark.slowtest # Slow on Azure @testing.requires_testing_data From 55a4ab88fe5b6dfcfc9724d9615f1007b702655c Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Mon, 2 Mar 2026 12:52:15 +0000 Subject: [PATCH 02/13] Extend plot_csd support to SEEG, ECoG, and DBS channel types --- mne/viz/misc.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mne/viz/misc.py b/mne/viz/misc.py index f7066b98fd7..528b50fd76a 100644 --- a/mne/viz/misc.py +++ b/mne/viz/misc.py @@ -1484,8 +1484,7 @@ def plot_csd( if info is not None: info_ch_names = info["ch_names"] - # Each entry: (pick_types kwargs, ch_type key, plot title) - _ch_type_info = [ + _channel_configs = [ (dict(meg=False, eeg=True, ref_meg=False, exclude=[]), "eeg", "EEG"), ( dict(meg="mag", eeg=False, ref_meg=False, exclude=[]), @@ -1504,16 +1503,17 @@ def plot_csd( indices = [] titles = [] ch_types = [] - for pick_kwargs, ch_type, title in _ch_type_info: + for pick_kwargs, ch_type, title in _channel_configs: sel = pick_types(info, **pick_kwargs) idx = [ csd.ch_names.index(info_ch_names[c]) for c in sel if info_ch_names[c] in csd.ch_names ] - indices.append(idx) - titles.append(title) - ch_types.append(ch_type) + if len(idx) > 0: + indices.append(idx) + titles.append(title) + ch_types.append(ch_type) if mode == "csd": # The units in which to plot the CSD From bbaa04e74330753d014fe53aa8b43d8c4d034544 Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Tue, 3 Mar 2026 10:21:21 +0000 Subject: [PATCH 03/13] Extend plot_csd support to SEEG, ECoG, and DBS channel types --- doc/changes/dev/13713.other.rst | 1 + mne/viz/misc.py | 123 +++++++++++++------------------- 2 files changed, 51 insertions(+), 73 deletions(-) create mode 100644 doc/changes/dev/13713.other.rst diff --git a/doc/changes/dev/13713.other.rst b/doc/changes/dev/13713.other.rst new file mode 100644 index 00000000000..44d2347154c --- /dev/null +++ b/doc/changes/dev/13713.other.rst @@ -0,0 +1 @@ +Refactored CSD channel-type indexing to support all data channel types using ``_DATA_CH_TYPES_SPLIT`` and ``DEFAULTS``, removing hardcoded handling and reducing duplicated logic with covariance plotting, by `Aniket Singh Yadav`_. \ No newline at end of file diff --git a/mne/viz/misc.py b/mne/viz/misc.py index 528b50fd76a..0fc4f5c45c6 100644 --- a/mne/viz/misc.py +++ b/mne/viz/misc.py @@ -23,7 +23,6 @@ _picks_by_type, pick_channels, pick_info, - pick_types, ) from .._fiff.proj import make_projector from .._freesurfer import _check_mri, _mri_orientation, _read_mri_info, _reorient_image @@ -53,29 +52,41 @@ ) -def _index_info_cov(info, cov, exclude): - if exclude == "bads": - exclude = info["bads"] - info = pick_info(info, pick_channels(info["ch_names"], cov["names"], exclude)) - del exclude - picks_list = _picks_by_type(info, meg_combined=False, ref_meg=False, exclude=()) - picks_by_type = dict(picks_list) +def _index_info_by_ch_type(info, obj_ch_names): + """Build channel-type-grouped indices and metadata for an object's channels. - ch_names = [n for n in cov.ch_names if n in info["ch_names"]] - ch_idx = [cov.ch_names.index(n) for n in ch_names] + Parameters + ---------- + info : Info + The measurement info, already restricted to the channels of interest + (e.g. via :func:`pick_info`). + obj_ch_names : list of str + Channel names of the object being indexed (e.g. the channel list of a + :class:`~mne.time_frequency.CrossSpectralDensity` or the filtered + channel list derived from a :class:`~mne.Covariance`). + Returns + ------- + idx_names : list of tuple + Each entry is ``(indices, title, unit, scaling, ch_type)`` for each + data channel type that has at least one matching channel. + *indices* are integer positions into *obj_ch_names*. + """ info_ch_names = info["ch_names"] + picks_list = _picks_by_type(info, meg_combined=False, ref_meg=False, exclude=()) + picks_by_type = dict(picks_list) + idx_by_type = defaultdict(list) for ch_type, sel in picks_by_type.items(): idx_by_type[ch_type] = [ - ch_names.index(info_ch_names[c]) + obj_ch_names.index(info_ch_names[c]) for c in sel - if info_ch_names[c] in ch_names + if info_ch_names[c] in obj_ch_names ] - idx_names = [ + return [ ( idx_by_type[key], - f"{DEFAULTS['titles'][key]} covariance", + DEFAULTS["titles"][key], DEFAULTS["units"][key], DEFAULTS["scalings"][key], key, @@ -83,6 +94,21 @@ def _index_info_cov(info, cov, exclude): for key in _DATA_CH_TYPES_SPLIT if len(idx_by_type[key]) > 0 ] + + +def _index_info_cov(info, cov, exclude): + if exclude == "bads": + exclude = info["bads"] + info = pick_info(info, pick_channels(info["ch_names"], cov["names"], exclude)) + del exclude + + ch_names = [n for n in cov.ch_names if n in info["ch_names"]] + ch_idx = [cov.ch_names.index(n) for n in ch_names] + + idx_names = [ + (idx, f"{title} covariance", unit, scaling, key) + for idx, title, unit, scaling, key in _index_info_by_ch_type(info, ch_names) + ] C = cov.data[ch_idx][:, ch_idx] return info, C, ch_names, idx_names @@ -1483,64 +1509,15 @@ def plot_csd( raise ValueError('"mode" should be either "csd" or "coh".') if info is not None: - info_ch_names = info["ch_names"] - _channel_configs = [ - (dict(meg=False, eeg=True, ref_meg=False, exclude=[]), "eeg", "EEG"), - ( - dict(meg="mag", eeg=False, ref_meg=False, exclude=[]), - "mag", - "Magnetometers", - ), - ( - dict(meg="grad", eeg=False, ref_meg=False, exclude=[]), - "grad", - "Gradiometers", - ), - (dict(meg=False, seeg=True, ref_meg=False, exclude=[]), "seeg", "sEEG"), - (dict(meg=False, ecog=True, ref_meg=False, exclude=[]), "ecog", "ECoG"), - (dict(meg=False, dbs=True, ref_meg=False, exclude=[]), "dbs", "DBS"), - ] - indices = [] - titles = [] - ch_types = [] - for pick_kwargs, ch_type, title in _channel_configs: - sel = pick_types(info, **pick_kwargs) - idx = [ - csd.ch_names.index(info_ch_names[c]) - for c in sel - if info_ch_names[c] in csd.ch_names - ] - if len(idx) > 0: - indices.append(idx) - titles.append(title) - ch_types.append(ch_type) - - if mode == "csd": - # The units in which to plot the CSD - units = dict( - eeg="µV²", - grad="fT²/cm²", - mag="fT²", - seeg="mV²", - ecog="µV²", - dbs="µV²", - ) - scalings = dict( - eeg=1e12, - grad=1e26, - mag=1e30, - seeg=1e6, - ecog=1e12, - dbs=1e12, - ) + idx_names = _index_info_by_ch_type(info, csd.ch_names) + indices = [item[0] for item in idx_names] + titles = [item[1] for item in idx_names] + ch_types = [item[4] for item in idx_names] else: indices = [np.arange(len(csd.ch_names))] ch_types = [None] if mode == "csd": titles = ["Cross-spectral density"] - # Units and scaling unknown - units = dict() - scalings = dict() elif mode == "coh": titles = ["Coherence"] @@ -1552,9 +1529,6 @@ def plot_csd( figs = [] for ind, title, ch_type in zip(indices, titles, ch_types): - if len(ind) == 0: - continue - fig, axes = plt.subplots( n_rows, n_cols, @@ -1567,7 +1541,8 @@ def plot_csd( for i in range(len(csd.frequencies)): cm = csd.get_data(index=i)[ind][:, ind] if mode == "csd": - cm = np.abs(cm) * scalings.get(ch_type, 1) + scaling = DEFAULTS["scalings"].get(ch_type, 1) if ch_type else 1 + cm = np.abs(cm) * scaling**2 elif mode == "coh": # Compute coherence from the CSD matrix psd = np.diag(cm).real @@ -1591,8 +1566,10 @@ def plot_csd( cb = plt.colorbar(im, ax=[a for ax_ in axes for a in ax_]) if mode == "csd": label = "CSD" - if ch_type in units: - label += f" ({units[ch_type]})" + if ch_type is not None: + unit = DEFAULTS["units"].get(ch_type, "") + if unit: + label += f" ({unit}²)" cb.set_label(label) elif mode == "coh": cb.set_label("Coherence") From fa90e9fe9e09ee18ea5a5202b544b7a22ca96d77 Mon Sep 17 00:00:00 2001 From: Aniket <148300120+Aniketsy@users.noreply.github.com> Date: Thu, 5 Mar 2026 13:53:21 +0530 Subject: [PATCH 04/13] Update mne/viz/misc.py Co-authored-by: Thomas S. Binns --- mne/viz/misc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/viz/misc.py b/mne/viz/misc.py index 0fc4f5c45c6..f6a1e723bb7 100644 --- a/mne/viz/misc.py +++ b/mne/viz/misc.py @@ -1541,7 +1541,7 @@ def plot_csd( for i in range(len(csd.frequencies)): cm = csd.get_data(index=i)[ind][:, ind] if mode == "csd": - scaling = DEFAULTS["scalings"].get(ch_type, 1) if ch_type else 1 + scaling = DEFAULTS["scalings"].get(ch_type, 1) cm = np.abs(cm) * scaling**2 elif mode == "coh": # Compute coherence from the CSD matrix From 52e8cd378b4be6cffb89de12b47edde9eafc253b Mon Sep 17 00:00:00 2001 From: Aniket <148300120+Aniketsy@users.noreply.github.com> Date: Thu, 5 Mar 2026 13:53:35 +0530 Subject: [PATCH 05/13] Update mne/viz/misc.py Co-authored-by: Thomas S. Binns --- mne/viz/misc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/viz/misc.py b/mne/viz/misc.py index f6a1e723bb7..4327667f203 100644 --- a/mne/viz/misc.py +++ b/mne/viz/misc.py @@ -1569,7 +1569,7 @@ def plot_csd( if ch_type is not None: unit = DEFAULTS["units"].get(ch_type, "") if unit: - label += f" ({unit}²)" + label += f" ({'/'.join([el + '²' for el in unit.split('/')])})" cb.set_label(label) elif mode == "coh": cb.set_label("Coherence") From 7a4eddf8131b7772198b439484f279a4a2ff1a7f Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Thu, 5 Mar 2026 17:25:07 +0000 Subject: [PATCH 06/13] Extend plot_csd support to SEEG, ECoG, and DBS channel types --- mne/viz/misc.py | 58 +++++++++++++++---------- mne/viz/tests/test_misc.py | 89 ++++++++++++++------------------------ 2 files changed, 69 insertions(+), 78 deletions(-) diff --git a/mne/viz/misc.py b/mne/viz/misc.py index 4327667f203..7c6ae0ecb62 100644 --- a/mne/viz/misc.py +++ b/mne/viz/misc.py @@ -67,10 +67,16 @@ def _index_info_by_ch_type(info, obj_ch_names): Returns ------- - idx_names : list of tuple - Each entry is ``(indices, title, unit, scaling, ch_type)`` for each - data channel type that has at least one matching channel. - *indices* are integer positions into *obj_ch_names*. + indices : list of list of int + Channel indices into *obj_ch_names*, one list per channel type. + titles : list of str + Human-readable title for each channel type. + units : list of str + Unit string for each channel type. + scalings : list of float + Scaling factor for each channel type. + ch_types : list of str + Channel type key for each group. """ info_ch_names = info["ch_names"] picks_list = _picks_by_type(info, meg_combined=False, ref_meg=False, exclude=()) @@ -83,17 +89,20 @@ def _index_info_by_ch_type(info, obj_ch_names): for c in sel if info_ch_names[c] in obj_ch_names ] - return [ - ( - idx_by_type[key], - DEFAULTS["titles"][key], - DEFAULTS["units"][key], - DEFAULTS["scalings"][key], - key, - ) - for key in _DATA_CH_TYPES_SPLIT - if len(idx_by_type[key]) > 0 - ] + + indices = [] + titles = [] + units = [] + scalings = [] + ch_types = [] + for key in _DATA_CH_TYPES_SPLIT: + if len(idx_by_type[key]) > 0: + indices.append(idx_by_type[key]) + titles.append(DEFAULTS["titles"][key]) + units.append(DEFAULTS["units"][key]) + scalings.append(DEFAULTS["scalings"][key]) + ch_types.append(key) + return indices, titles, units, scalings, ch_types def _index_info_cov(info, cov, exclude): @@ -105,9 +114,12 @@ def _index_info_cov(info, cov, exclude): ch_names = [n for n in cov.ch_names if n in info["ch_names"]] ch_idx = [cov.ch_names.index(n) for n in ch_names] + indices, titles, units, scalings, ch_types = _index_info_by_ch_type(info, ch_names) idx_names = [ (idx, f"{title} covariance", unit, scaling, key) - for idx, title, unit, scaling, key in _index_info_by_ch_type(info, ch_names) + for idx, title, unit, scaling, key in zip( + indices, titles, units, scalings, ch_types + ) ] C = cov.data[ch_idx][:, ch_idx] return info, C, ch_names, idx_names @@ -1509,10 +1521,7 @@ def plot_csd( raise ValueError('"mode" should be either "csd" or "coh".') if info is not None: - idx_names = _index_info_by_ch_type(info, csd.ch_names) - indices = [item[0] for item in idx_names] - titles = [item[1] for item in idx_names] - ch_types = [item[4] for item in idx_names] + indices, titles, _, _, ch_types = _index_info_by_ch_type(info, csd.ch_names) else: indices = [np.arange(len(csd.ch_names))] ch_types = [None] @@ -1537,11 +1546,13 @@ def plot_csd( layout="constrained", ) + if mode == "csd": + scaling = DEFAULTS["scalings"].get(ch_type, 1) + csd_mats = [] for i in range(len(csd.frequencies)): cm = csd.get_data(index=i)[ind][:, ind] if mode == "csd": - scaling = DEFAULTS["scalings"].get(ch_type, 1) cm = np.abs(cm) * scaling**2 elif mode == "coh": # Compute coherence from the CSD matrix @@ -1569,7 +1580,10 @@ def plot_csd( if ch_type is not None: unit = DEFAULTS["units"].get(ch_type, "") if unit: - label += f" ({'/'.join([el + '²' for el in unit.split('/')])})" + if "/" in unit: + label += f" ({unit})²" + else: + label += f" ({unit}²)" cb.set_label(label) elif mode == "coh": cb.set_label("Coherence") diff --git a/mne/viz/tests/test_misc.py b/mne/viz/tests/test_misc.py index 7e8c97f2f7e..e09468f4a32 100644 --- a/mne/viz/tests/test_misc.py +++ b/mne/viz/tests/test_misc.py @@ -313,69 +313,46 @@ def test_plot_dipole_amplitudes(): dipoles.plot_amplitudes(show=False) -def test_plot_csd(): +@pytest.mark.parametrize( + "ch_types, expected_n_figs", + [ + (None, 1), + ("eeg", 1), + ("ecog", 1), + (["grad", "dbs"], 2), + ("misc", 0), + ], +) +def test_plot_csd(ch_types, expected_n_figs): """Test plotting of CSD matrices.""" + if isinstance(ch_types, list): + ch_type_list = ch_types + elif ch_types is not None: + ch_type_list = [ch_types, ch_types] + else: + ch_type_list = None + + n_ch = len(ch_type_list) if ch_type_list is not None else 2 + ch_names = [f"CH{i + 1}" for i in range(n_ch)] + n_data = n_ch * (n_ch + 1) // 2 + csd = CrossSpectralDensity( - [1, 2, 3], - ["CH1", "CH2"], + list(range(1, n_data + 1)), + ch_names, frequencies=[(10, 20)], n_fft=1, tmin=0, tmax=1, ) - plot_csd(csd, mode="csd") # Plot cross-spectral density - plot_csd(csd, mode="coh") # Plot coherence - - # EEG - info_eeg = create_info(["CH1", "CH2"], sfreq=1.0, ch_types="eeg") - figs = plot_csd(csd, info=info_eeg, mode="csd", show=False) - assert len(figs) == 1 - figs = plot_csd(csd, info=info_eeg, mode="coh", show=False) - assert len(figs) == 1 - - # sEEG - info_seeg = create_info(["CH1", "CH2"], sfreq=1.0, ch_types="seeg") - figs = plot_csd(csd, info=info_seeg, mode="csd", show=False) - assert len(figs) == 1 - figs = plot_csd(csd, info=info_seeg, mode="coh", show=False) - assert len(figs) == 1 - - # ECoG - info_ecog = create_info(["CH1", "CH2"], sfreq=1.0, ch_types="ecog") - figs = plot_csd(csd, info=info_ecog, mode="csd", show=False) - assert len(figs) == 1 - figs = plot_csd(csd, info=info_ecog, mode="coh", show=False) - assert len(figs) == 1 - - # DBS - info_dbs = create_info(["CH1", "CH2"], sfreq=1.0, ch_types="dbs") - figs = plot_csd(csd, info=info_dbs, mode="csd", show=False) - assert len(figs) == 1 - figs = plot_csd(csd, info=info_dbs, mode="coh", show=False) - assert len(figs) == 1 - - # Mixed: EEG + sEEG + ECoG + DBS — each should produce its own figure - info_mixed = create_info( - ["EEG1", "SEEG1", "ECOG1", "DBS1"], - sfreq=1.0, - ch_types=["eeg", "seeg", "ecog", "dbs"], - ) - csd_mixed = CrossSpectralDensity( - [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], - ["EEG1", "SEEG1", "ECOG1", "DBS1"], - frequencies=[(10, 20)], - n_fft=1, - tmin=0, - tmax=1, - ) - figs = plot_csd(csd_mixed, info=info_mixed, mode="csd", show=False) - assert len(figs) == 4 - figs = plot_csd(csd_mixed, info=info_mixed, mode="coh", show=False) - assert len(figs) == 4 - - info_other = create_info(["OTHERCHAN"], sfreq=1.0, ch_types="eeg") - figs = plot_csd(csd, info=info_other, mode="csd", show=False) - assert len(figs) == 0 + + if ch_type_list is not None: + info = create_info(ch_names, sfreq=1.0, ch_types=ch_type_list) + else: + info = None + + for mode in ("csd", "coh"): + figs = plot_csd(csd, info=info, mode=mode, show=False) + assert len(figs) == expected_n_figs @pytest.mark.slowtest # Slow on Azure From 342d04e54b31ecebec860be68ba6f47137a5feab Mon Sep 17 00:00:00 2001 From: Aniket <148300120+Aniketsy@users.noreply.github.com> Date: Fri, 6 Mar 2026 09:59:31 +0530 Subject: [PATCH 07/13] Update mne/viz/misc.py Co-authored-by: Thomas S. Binns --- mne/viz/misc.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/mne/viz/misc.py b/mne/viz/misc.py index 7c6ae0ecb62..50b16b45b77 100644 --- a/mne/viz/misc.py +++ b/mne/viz/misc.py @@ -1578,12 +1578,9 @@ def plot_csd( if mode == "csd": label = "CSD" if ch_type is not None: - unit = DEFAULTS["units"].get(ch_type, "") - if unit: - if "/" in unit: - label += f" ({unit})²" - else: - label += f" ({unit}²)" + if "/" in unit: + unit = f"({unit})" + label += f" ({unit}²)" cb.set_label(label) elif mode == "coh": cb.set_label("Coherence") From ad4cd30919d4e85c5a2da2f102f354734bbc087d Mon Sep 17 00:00:00 2001 From: Aniket <148300120+Aniketsy@users.noreply.github.com> Date: Fri, 6 Mar 2026 10:00:01 +0530 Subject: [PATCH 08/13] Update mne/viz/tests/test_misc.py Co-authored-by: Thomas S. Binns --- mne/viz/tests/test_misc.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/mne/viz/tests/test_misc.py b/mne/viz/tests/test_misc.py index e09468f4a32..ff4fca6200c 100644 --- a/mne/viz/tests/test_misc.py +++ b/mne/viz/tests/test_misc.py @@ -325,14 +325,11 @@ def test_plot_dipole_amplitudes(): ) def test_plot_csd(ch_types, expected_n_figs): """Test plotting of CSD matrices.""" + n_ch = 2 if isinstance(ch_types, list): - ch_type_list = ch_types - elif ch_types is not None: - ch_type_list = [ch_types, ch_types] - else: - ch_type_list = None - - n_ch = len(ch_type_list) if ch_type_list is not None else 2 + n_ch_types = len(ch_types) + ch_types = np.repeat(ch_types, n_ch).tolist() + n_ch *= n_ch_types ch_names = [f"CH{i + 1}" for i in range(n_ch)] n_data = n_ch * (n_ch + 1) // 2 From 988073c917cb3107094cc5756e76a93d67945a3c Mon Sep 17 00:00:00 2001 From: Aniket <148300120+Aniketsy@users.noreply.github.com> Date: Fri, 6 Mar 2026 10:00:18 +0530 Subject: [PATCH 09/13] Update doc/changes/dev/13713.other.rst Co-authored-by: Thomas S. Binns --- doc/changes/dev/13713.other.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/changes/dev/13713.other.rst b/doc/changes/dev/13713.other.rst index 44d2347154c..921a445d2fb 100644 --- a/doc/changes/dev/13713.other.rst +++ b/doc/changes/dev/13713.other.rst @@ -1 +1 @@ -Refactored CSD channel-type indexing to support all data channel types using ``_DATA_CH_TYPES_SPLIT`` and ``DEFAULTS``, removing hardcoded handling and reducing duplicated logic with covariance plotting, by `Aniket Singh Yadav`_. \ No newline at end of file +Extended support for visualising cross-spectral densities in :func:`mne.viz.plot_csd` and :meth:`mne.time_frequency.CSD.plot` to all types of :term:`data channels`, by `Aniket Singh Yadav`_. \ No newline at end of file From 6049cf0a4784d1646063dfe5d9ac35437b9ba172 Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Fri, 6 Mar 2026 04:38:56 +0000 Subject: [PATCH 10/13] remove doc and improvements --- mne/viz/misc.py | 45 ++++++++++---------------------------- mne/viz/tests/test_misc.py | 25 +++++++++++---------- 2 files changed, 26 insertions(+), 44 deletions(-) diff --git a/mne/viz/misc.py b/mne/viz/misc.py index 50b16b45b77..cdf57d51deb 100644 --- a/mne/viz/misc.py +++ b/mne/viz/misc.py @@ -52,32 +52,8 @@ ) -def _index_info_by_ch_type(info, obj_ch_names): - """Build channel-type-grouped indices and metadata for an object's channels. - - Parameters - ---------- - info : Info - The measurement info, already restricted to the channels of interest - (e.g. via :func:`pick_info`). - obj_ch_names : list of str - Channel names of the object being indexed (e.g. the channel list of a - :class:`~mne.time_frequency.CrossSpectralDensity` or the filtered - channel list derived from a :class:`~mne.Covariance`). - - Returns - ------- - indices : list of list of int - Channel indices into *obj_ch_names*, one list per channel type. - titles : list of str - Human-readable title for each channel type. - units : list of str - Unit string for each channel type. - scalings : list of float - Scaling factor for each channel type. - ch_types : list of str - Channel type key for each group. - """ +def _index_info_by_ch_type(info, ch_names): + """Build channel-type-grouped indices and metadata for an object's channels.""" info_ch_names = info["ch_names"] picks_list = _picks_by_type(info, meg_combined=False, ref_meg=False, exclude=()) picks_by_type = dict(picks_list) @@ -85,9 +61,9 @@ def _index_info_by_ch_type(info, obj_ch_names): idx_by_type = defaultdict(list) for ch_type, sel in picks_by_type.items(): idx_by_type[ch_type] = [ - obj_ch_names.index(info_ch_names[c]) + ch_names.index(info_ch_names[c]) for c in sel - if info_ch_names[c] in obj_ch_names + if info_ch_names[c] in ch_names ] indices = [] @@ -1521,9 +1497,13 @@ def plot_csd( raise ValueError('"mode" should be either "csd" or "coh".') if info is not None: - indices, titles, _, _, ch_types = _index_info_by_ch_type(info, csd.ch_names) + indices, titles, units, scalings, ch_types = _index_info_by_ch_type( + info, csd.ch_names + ) else: indices = [np.arange(len(csd.ch_names))] + units = [""] + scalings = [1] ch_types = [None] if mode == "csd": titles = ["Cross-spectral density"] @@ -1537,7 +1517,9 @@ def plot_csd( n_rows = int(np.ceil(n_freqs / float(n_cols))) figs = [] - for ind, title, ch_type in zip(indices, titles, ch_types): + for ind, title, unit, scaling, ch_type in zip( + indices, titles, units, scalings, ch_types + ): fig, axes = plt.subplots( n_rows, n_cols, @@ -1546,9 +1528,6 @@ def plot_csd( layout="constrained", ) - if mode == "csd": - scaling = DEFAULTS["scalings"].get(ch_type, 1) - csd_mats = [] for i in range(len(csd.frequencies)): cm = csd.get_data(index=i)[ind][:, ind] diff --git a/mne/viz/tests/test_misc.py b/mne/viz/tests/test_misc.py index ff4fca6200c..2c05e36442d 100644 --- a/mne/viz/tests/test_misc.py +++ b/mne/viz/tests/test_misc.py @@ -19,6 +19,7 @@ read_evokeds, read_source_spaces, ) +from mne._fiff.pick import _DATA_CH_TYPES_SPLIT from mne.chpi import compute_chpi_snr from mne.datasets import testing from mne.filter import create_filter @@ -314,16 +315,16 @@ def test_plot_dipole_amplitudes(): @pytest.mark.parametrize( - "ch_types, expected_n_figs", + "ch_types", [ - (None, 1), - ("eeg", 1), - ("ecog", 1), - (["grad", "dbs"], 2), - ("misc", 0), + None, + "eeg", + "ecog", + ["grad", "dbs"], + "misc", ], ) -def test_plot_csd(ch_types, expected_n_figs): +def test_plot_csd(ch_types): """Test plotting of CSD matrices.""" n_ch = 2 if isinstance(ch_types, list): @@ -342,10 +343,12 @@ def test_plot_csd(ch_types, expected_n_figs): tmax=1, ) - if ch_type_list is not None: - info = create_info(ch_names, sfreq=1.0, ch_types=ch_type_list) - else: - info = None + info = None + expected_n_figs = 1 + if ch_types is not None: + info = create_info(ch_names, sfreq=1.0, ch_types=ch_types) + types_iter = ch_types if isinstance(ch_types, list) else [ch_types] + expected_n_figs = len(set(types_iter).intersection(_DATA_CH_TYPES_SPLIT)) for mode in ("csd", "coh"): figs = plot_csd(csd, info=info, mode=mode, show=False) From bd561c07e0df9a3f1030782e0b00ba4f847eeb24 Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Fri, 6 Mar 2026 06:13:33 +0000 Subject: [PATCH 11/13] improve test --- mne/viz/tests/test_misc.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mne/viz/tests/test_misc.py b/mne/viz/tests/test_misc.py index 2c05e36442d..b69d71a7d4f 100644 --- a/mne/viz/tests/test_misc.py +++ b/mne/viz/tests/test_misc.py @@ -347,8 +347,11 @@ def test_plot_csd(ch_types): expected_n_figs = 1 if ch_types is not None: info = create_info(ch_names, sfreq=1.0, ch_types=ch_types) - types_iter = ch_types if isinstance(ch_types, list) else [ch_types] - expected_n_figs = len(set(types_iter).intersection(_DATA_CH_TYPES_SPLIT)) + expected_n_figs = len( + set(ch_types if isinstance(ch_types, list) else [ch_types]).intersection( + _DATA_CH_TYPES_SPLIT + ) + ) for mode in ("csd", "coh"): figs = plot_csd(csd, info=info, mode=mode, show=False) From 9bbbe180ec71427f9955e09b78e54b400e80dc64 Mon Sep 17 00:00:00 2001 From: Aniket <148300120+Aniketsy@users.noreply.github.com> Date: Sat, 7 Mar 2026 14:13:51 +0530 Subject: [PATCH 12/13] Update doc/changes/dev/13713.other.rst --- doc/changes/dev/13713.other.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/changes/dev/13713.other.rst b/doc/changes/dev/13713.other.rst index 921a445d2fb..e085843da2f 100644 --- a/doc/changes/dev/13713.other.rst +++ b/doc/changes/dev/13713.other.rst @@ -1 +1 @@ -Extended support for visualising cross-spectral densities in :func:`mne.viz.plot_csd` and :meth:`mne.time_frequency.CSD.plot` to all types of :term:`data channels`, by `Aniket Singh Yadav`_. \ No newline at end of file +Extended support for visualising cross-spectral densities in :func:`mne.viz.plot_csd` and :meth:`mne.time_frequency.CrossSpectralDensity.plot` to all types of :term:`data channels`, by `Aniket Singh Yadav`_. \ No newline at end of file From 7a210e4fa2f9e8d37307d125f67b8555854d6de2 Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Sat, 7 Mar 2026 20:01:50 +0000 Subject: [PATCH 13/13] update changelog --- doc/changes/dev/13713.other.rst | 2 +- doc/changes/names.inc | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/doc/changes/dev/13713.other.rst b/doc/changes/dev/13713.other.rst index e085843da2f..da7f4ead348 100644 --- a/doc/changes/dev/13713.other.rst +++ b/doc/changes/dev/13713.other.rst @@ -1 +1 @@ -Extended support for visualising cross-spectral densities in :func:`mne.viz.plot_csd` and :meth:`mne.time_frequency.CrossSpectralDensity.plot` to all types of :term:`data channels`, by `Aniket Singh Yadav`_. \ No newline at end of file +Extended support for visualising cross-spectral densities in :func:`mne.viz.plot_csd` and :meth:`mne.time_frequency.CrossSpectralDensity.plot` to all types of :term:`data channels`, by :newcontrib:`Aniket Singh Yadav`. \ No newline at end of file diff --git a/doc/changes/names.inc b/doc/changes/names.inc index 79db14d44e3..09c5818b6af 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -22,6 +22,7 @@ .. _Andrew Gilbert: https://github.com/adgilbert .. _Andrew Quinn: https://github.com/ajquinn .. _Aniket Pradhan: https://github.com/Aniket-Pradhan +.. _Aniket Singh Yadav: https://github.com/Aniketsy .. _Anna Padee: https://github.com/apadee/ .. _Annalisa Pascarella: https://github.com/annapasca .. _Anne-Sophie Dubarry: https://github.com/annesodub