Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/dev/13719.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve memory usage and runtime of :func:`mne.time_frequency.csd_fourier`, :func:`mne.time_frequency.csd_multitaper`, :func:`mne.time_frequency.csd_array_fourier`, and :func:`mne.time_frequency.csd_array_multitaper` by avoiding unnecessary full-matrix CSD construction, by `Pragnya Khandelwal`_.
46 changes: 17 additions & 29 deletions mne/time_frequency/csd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1406,22 +1406,14 @@ def _csd_fourier(X, sfreq, n_times, freq_mask, n_fft):
x_mt, _ = _mt_spectra(X, np.hanning(n_times), sfreq, n_fft)

# Hack so we can sum over axis=-2
weights = np.array([1.0])[:, np.newaxis, np.newaxis, np.newaxis]
weights = np.array([1.0])[np.newaxis, :, np.newaxis]

x_mt = x_mt[:, :, freq_mask]

# Calculating CSD
# Tiling x_mt so that we can easily use _csd_from_mt()
x_mt = x_mt[:, np.newaxis, :, :]
x_mt = np.tile(x_mt, [1, x_mt.shape[0], 1, 1])
y_mt = np.transpose(x_mt, axes=[1, 0, 2, 3])
weights_y = np.transpose(weights, axes=[1, 0, 2, 3])
csds = _csd_from_mt(x_mt, y_mt, weights, weights_y)

# FIXME: don't compute full matrix in the first place
csds = np.array(
[_sym_mat_to_vector(csds[:, :, i]) for i in range(csds.shape[-1])]
).T
# Calculate CSD for upper-triangle channel pairs directly.
# This avoids computing/storing the full channel x channel matrix.
ii, jj = np.triu_indices(x_mt.shape[0])
csds = _csd_from_mt(x_mt[ii], x_mt[jj], weights, weights)

# Scaling by number of samples and compensating for loss of power
# due to windowing (see section 11.5.2 in Bendat & Piersol).
Expand All @@ -1445,27 +1437,23 @@ def _csd_multitaper(
_, weights = _psd_from_mt_adaptive(
x_mt, eigvals, freq_mask, max_iter, return_weights=True
)
# Tiling weights so that we can easily use _csd_from_mt()
weights = weights[:, np.newaxis, :, :]
weights = np.tile(weights, [1, x_mt.shape[0], 1, 1])
else:
# Do not use adaptive weights
weights = np.sqrt(eigvals)[np.newaxis, np.newaxis, :, np.newaxis]
weights = np.sqrt(eigvals)[np.newaxis, :, np.newaxis]

x_mt = x_mt[:, :, freq_mask]

# Calculating CSD
# Tiling x_mt so that we can easily use _csd_from_mt()
x_mt = x_mt[:, np.newaxis, :, :]
x_mt = np.tile(x_mt, [1, x_mt.shape[0], 1, 1])
y_mt = np.transpose(x_mt, axes=[1, 0, 2, 3])
weights_y = np.transpose(weights, axes=[1, 0, 2, 3])
csds = _csd_from_mt(x_mt, y_mt, weights, weights_y)

# FIXME: don't compute full matrix in the first place
csds = np.array(
[_sym_mat_to_vector(csds[:, :, i]) for i in range(csds.shape[-1])]
).T
# Calculate CSD for upper-triangle channel pairs directly.
# This avoids computing/storing the full channel x channel matrix.
ii, jj = np.triu_indices(x_mt.shape[0])
x_mt_i = x_mt[ii]
x_mt_j = x_mt[jj]
if adaptive:
weights_i = weights[ii]
weights_j = weights[jj]
else:
weights_i = weights_j = weights
csds = _csd_from_mt(x_mt_i, x_mt_j, weights_i, weights_j)

# Scaling by sampling frequency for compatibility with Matlab
csds /= sfreq
Expand Down