diff --git a/doc/changes/dev/13719.newfeature.rst b/doc/changes/dev/13719.newfeature.rst new file mode 100644 index 000000000000..78a736b64776 --- /dev/null +++ b/doc/changes/dev/13719.newfeature.rst @@ -0,0 +1 @@ +Improve memory usage and runtime of :func:`mne.time_frequency.csd_fourier`, :func:`mne.time_frequency.csd_multitaper`, :func:`mne.time_frequency.csd_array_fourier`, and :func:`mne.time_frequency.csd_array_multitaper` by avoiding unnecessary full-matrix CSD construction, by `Pragnya Khandelwal`_. \ No newline at end of file diff --git a/mne/time_frequency/csd.py b/mne/time_frequency/csd.py index 4ddaa0ac6a3b..e4b67fdaed35 100644 --- a/mne/time_frequency/csd.py +++ b/mne/time_frequency/csd.py @@ -1406,22 +1406,14 @@ def _csd_fourier(X, sfreq, n_times, freq_mask, n_fft): x_mt, _ = _mt_spectra(X, np.hanning(n_times), sfreq, n_fft) # Hack so we can sum over axis=-2 - weights = np.array([1.0])[:, np.newaxis, np.newaxis, np.newaxis] + weights = np.array([1.0])[np.newaxis, :, np.newaxis] x_mt = x_mt[:, :, freq_mask] - # Calculating CSD - # Tiling x_mt so that we can easily use _csd_from_mt() - x_mt = x_mt[:, np.newaxis, :, :] - x_mt = np.tile(x_mt, [1, x_mt.shape[0], 1, 1]) - y_mt = np.transpose(x_mt, axes=[1, 0, 2, 3]) - weights_y = np.transpose(weights, axes=[1, 0, 2, 3]) - csds = _csd_from_mt(x_mt, y_mt, weights, weights_y) - - # FIXME: don't compute full matrix in the first place - csds = np.array( - [_sym_mat_to_vector(csds[:, :, i]) for i in range(csds.shape[-1])] - ).T + # Calculate CSD for upper-triangle channel pairs directly. + # This avoids computing/storing the full channel x channel matrix. + ii, jj = np.triu_indices(x_mt.shape[0]) + csds = _csd_from_mt(x_mt[ii], x_mt[jj], weights, weights) # Scaling by number of samples and compensating for loss of power # due to windowing (see section 11.5.2 in Bendat & Piersol). @@ -1445,27 +1437,23 @@ def _csd_multitaper( _, weights = _psd_from_mt_adaptive( x_mt, eigvals, freq_mask, max_iter, return_weights=True ) - # Tiling weights so that we can easily use _csd_from_mt() - weights = weights[:, np.newaxis, :, :] - weights = np.tile(weights, [1, x_mt.shape[0], 1, 1]) else: # Do not use adaptive weights - weights = np.sqrt(eigvals)[np.newaxis, np.newaxis, :, np.newaxis] + weights = np.sqrt(eigvals)[np.newaxis, :, np.newaxis] x_mt = x_mt[:, :, freq_mask] - # Calculating CSD - # Tiling x_mt so that we can easily use _csd_from_mt() - x_mt = x_mt[:, np.newaxis, :, :] - x_mt = np.tile(x_mt, [1, x_mt.shape[0], 1, 1]) - y_mt = np.transpose(x_mt, axes=[1, 0, 2, 3]) - weights_y = np.transpose(weights, axes=[1, 0, 2, 3]) - csds = _csd_from_mt(x_mt, y_mt, weights, weights_y) - - # FIXME: don't compute full matrix in the first place - csds = np.array( - [_sym_mat_to_vector(csds[:, :, i]) for i in range(csds.shape[-1])] - ).T + # Calculate CSD for upper-triangle channel pairs directly. + # This avoids computing/storing the full channel x channel matrix. + ii, jj = np.triu_indices(x_mt.shape[0]) + x_mt_i = x_mt[ii] + x_mt_j = x_mt[jj] + if adaptive: + weights_i = weights[ii] + weights_j = weights[jj] + else: + weights_i = weights_j = weights + csds = _csd_from_mt(x_mt_i, x_mt_j, weights_i, weights_j) # Scaling by sampling frequency for compatibility with Matlab csds /= sfreq