Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/dev/13713.other.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Extended support for visualising cross-spectral densities in :func:`mne.viz.plot_csd` and :meth:`mne.time_frequency.CrossSpectralDensity.plot` to all types of :term:`data channels`, by :newcontrib:`Aniket Singh Yadav`.
1 change: 1 addition & 0 deletions doc/changes/names.inc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
.. _Andrew Gilbert: https://github.com/adgilbert
.. _Andrew Quinn: https://github.com/ajquinn
.. _Aniket Pradhan: https://github.com/Aniket-Pradhan
.. _Aniket Singh Yadav: https://github.com/Aniketsy
.. _Anna Padee: https://github.com/apadee/
.. _Annalisa Pascarella: https://github.com/annapasca
.. _Anne-Sophie Dubarry: https://github.com/annesodub
Expand Down
100 changes: 46 additions & 54 deletions mne/viz/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
_picks_by_type,
pick_channels,
pick_info,
pick_types,
)
from .._fiff.proj import make_projector
from .._freesurfer import _check_mri, _mri_orientation, _read_mri_info, _reorient_image
Expand Down Expand Up @@ -53,35 +52,50 @@
)


def _index_info_cov(info, cov, exclude):
if exclude == "bads":
exclude = info["bads"]
info = pick_info(info, pick_channels(info["ch_names"], cov["names"], exclude))
del exclude
def _index_info_by_ch_type(info, ch_names):
"""Build channel-type-grouped indices and metadata for an object's channels."""
info_ch_names = info["ch_names"]
picks_list = _picks_by_type(info, meg_combined=False, ref_meg=False, exclude=())
picks_by_type = dict(picks_list)

ch_names = [n for n in cov.ch_names if n in info["ch_names"]]
ch_idx = [cov.ch_names.index(n) for n in ch_names]

info_ch_names = info["ch_names"]
idx_by_type = defaultdict(list)
for ch_type, sel in picks_by_type.items():
idx_by_type[ch_type] = [
ch_names.index(info_ch_names[c])
for c in sel
if info_ch_names[c] in ch_names
]

indices = []
titles = []
units = []
scalings = []
ch_types = []
for key in _DATA_CH_TYPES_SPLIT:
if len(idx_by_type[key]) > 0:
indices.append(idx_by_type[key])
titles.append(DEFAULTS["titles"][key])
units.append(DEFAULTS["units"][key])
scalings.append(DEFAULTS["scalings"][key])
ch_types.append(key)
return indices, titles, units, scalings, ch_types


def _index_info_cov(info, cov, exclude):
if exclude == "bads":
exclude = info["bads"]
info = pick_info(info, pick_channels(info["ch_names"], cov["names"], exclude))
del exclude

ch_names = [n for n in cov.ch_names if n in info["ch_names"]]
ch_idx = [cov.ch_names.index(n) for n in ch_names]

indices, titles, units, scalings, ch_types = _index_info_by_ch_type(info, ch_names)
idx_names = [
(
idx_by_type[key],
f"{DEFAULTS['titles'][key]} covariance",
DEFAULTS["units"][key],
DEFAULTS["scalings"][key],
key,
(idx, f"{title} covariance", unit, scaling, key)
for idx, title, unit, scaling, key in zip(
indices, titles, units, scalings, ch_types
)
for key in _DATA_CH_TYPES_SPLIT
if len(idx_by_type[key]) > 0
]
C = cov.data[ch_idx][:, ch_idx]
return info, C, ch_names, idx_names
Expand Down Expand Up @@ -1483,39 +1497,16 @@ def plot_csd(
raise ValueError('"mode" should be either "csd" or "coh".')

if info is not None:
info_ch_names = info["ch_names"]
sel_eeg = pick_types(info, meg=False, eeg=True, ref_meg=False, exclude=[])
sel_mag = pick_types(info, meg="mag", eeg=False, ref_meg=False, exclude=[])
sel_grad = pick_types(info, meg="grad", eeg=False, ref_meg=False, exclude=[])
idx_eeg = [
csd.ch_names.index(info_ch_names[c])
for c in sel_eeg
if info_ch_names[c] in csd.ch_names
]
idx_mag = [
csd.ch_names.index(info_ch_names[c])
for c in sel_mag
if info_ch_names[c] in csd.ch_names
]
idx_grad = [
csd.ch_names.index(info_ch_names[c])
for c in sel_grad
if info_ch_names[c] in csd.ch_names
]
indices = [idx_eeg, idx_mag, idx_grad]
titles = ["EEG", "Magnetometers", "Gradiometers"]

if mode == "csd":
# The units in which to plot the CSD
units = dict(eeg="µV²", grad="fT²/cm²", mag="fT²")
scalings = dict(eeg=1e12, grad=1e26, mag=1e30)
indices, titles, units, scalings, ch_types = _index_info_by_ch_type(
info, csd.ch_names
)
else:
indices = [np.arange(len(csd.ch_names))]
units = [""]
scalings = [1]
ch_types = [None]
if mode == "csd":
titles = ["Cross-spectral density"]
# Units and scaling unknown
units = dict()
scalings = dict()
elif mode == "coh":
titles = ["Coherence"]

Expand All @@ -1526,10 +1517,9 @@ def plot_csd(
n_rows = int(np.ceil(n_freqs / float(n_cols)))

figs = []
for ind, title, ch_type in zip(indices, titles, ["eeg", "mag", "grad"]):
if len(ind) == 0:
continue

for ind, title, unit, scaling, ch_type in zip(
indices, titles, units, scalings, ch_types
):
fig, axes = plt.subplots(
n_rows,
n_cols,
Expand All @@ -1542,7 +1532,7 @@ def plot_csd(
for i in range(len(csd.frequencies)):
cm = csd.get_data(index=i)[ind][:, ind]
if mode == "csd":
cm = np.abs(cm) * scalings.get(ch_type, 1)
cm = np.abs(cm) * scaling**2
elif mode == "coh":
# Compute coherence from the CSD matrix
psd = np.diag(cm).real
Expand All @@ -1566,8 +1556,10 @@ def plot_csd(
cb = plt.colorbar(im, ax=[a for ax_ in axes for a in ax_])
if mode == "csd":
label = "CSD"
if ch_type in units:
label += f" ({units[ch_type]})"
if ch_type is not None:
if "/" in unit:
unit = f"({unit})"
label += f" ({unit}²)"
cb.set_label(label)
elif mode == "coh":
cb.set_label("Coherence")
Expand Down
42 changes: 37 additions & 5 deletions mne/viz/tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@

from mne import (
SourceEstimate,
create_info,
pick_events,
read_cov,
read_dipole,
read_events,
read_evokeds,
read_source_spaces,
)
from mne._fiff.pick import _DATA_CH_TYPES_SPLIT
from mne.chpi import compute_chpi_snr
from mne.datasets import testing
from mne.filter import create_filter
Expand Down Expand Up @@ -312,18 +314,48 @@ def test_plot_dipole_amplitudes():
dipoles.plot_amplitudes(show=False)


def test_plot_csd():
@pytest.mark.parametrize(
"ch_types",
[
None,
"eeg",
"ecog",
["grad", "dbs"],
"misc",
],
)
def test_plot_csd(ch_types):
"""Test plotting of CSD matrices."""
n_ch = 2
if isinstance(ch_types, list):
n_ch_types = len(ch_types)
ch_types = np.repeat(ch_types, n_ch).tolist()
n_ch *= n_ch_types
ch_names = [f"CH{i + 1}" for i in range(n_ch)]
n_data = n_ch * (n_ch + 1) // 2

csd = CrossSpectralDensity(
[1, 2, 3],
["CH1", "CH2"],
list(range(1, n_data + 1)),
ch_names,
frequencies=[(10, 20)],
n_fft=1,
tmin=0,
tmax=1,
)
plot_csd(csd, mode="csd") # Plot cross-spectral density
plot_csd(csd, mode="coh") # Plot coherence

info = None
expected_n_figs = 1
if ch_types is not None:
info = create_info(ch_names, sfreq=1.0, ch_types=ch_types)
expected_n_figs = len(
set(ch_types if isinstance(ch_types, list) else [ch_types]).intersection(
_DATA_CH_TYPES_SPLIT
)
)

for mode in ("csd", "coh"):
figs = plot_csd(csd, info=info, mode=mode, show=False)
assert len(figs) == expected_n_figs


@pytest.mark.slowtest # Slow on Azure
Expand Down
Loading