From b36b27b5ee67e979ef7e4070cd6adb948ea669c1 Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Tue, 3 Mar 2026 17:17:08 +0000 Subject: [PATCH 1/8] Make SSD dict-based reconstruction sklearn-compliant --- mne/decoding/__init__.pyi | 3 +- mne/decoding/ssd.py | 134 ++++++++++++++++++++++++++++----- mne/decoding/tests/test_ssd.py | 71 ++++++++++++++++- 3 files changed, 189 insertions(+), 19 deletions(-) diff --git a/mne/decoding/__init__.pyi b/mne/decoding/__init__.pyi index 1131f1597c5..0c3ba92e33e 100644 --- a/mne/decoding/__init__.pyi +++ b/mne/decoding/__init__.pyi @@ -10,6 +10,7 @@ __all__ = [ "SPoC", "SSD", "Scaler", + "read_ssd", "SlidingEstimator", "SpatialFilter", "TemporalFilter", @@ -36,7 +37,7 @@ from .ems import EMS, compute_ems from .receptive_field import ReceptiveField from .search_light import GeneralizingEstimator, SlidingEstimator from .spatial_filter import SpatialFilter, get_spatial_filter_from_estimator -from .ssd import SSD +from .ssd import SSD, read_ssd from .time_delaying_ridge import TimeDelayingRidge from .time_frequency import TimeFrequency from .transformer import ( diff --git a/mne/decoding/ssd.py b/mne/decoding/ssd.py index 41d67ece8c6..cf41f002031 100644 --- a/mne/decoding/ssd.py +++ b/mne/decoding/ssd.py @@ -3,17 +3,21 @@ # Copyright the MNE-Python contributors. import collections.abc as abc +import pickle from functools import partial import numpy as np +from .. import __version__ as _mne_version from .._fiff.meas_info import Info, create_info from .._fiff.pick import _picks_to_idx from ..filter import filter_data from ..utils import ( + _check_fname, _validate_type, fill_doc, logger, + warn, ) from ._covs_ged import _ssd_estimate from ._mod_ged import _get_spectral_ratio, _ssd_mod @@ -83,7 +87,10 @@ class SSD(_GEDTransformer): measurement info or estimated from the data, which determines the maximum possible number of components. See Notes of :func:`mne.compute_rank` for details. - We recommend to use 'full' when working with epoched data. + We recommend to use 'full' when working with epoched da + t + # mne.Info subclasses dict, so Info objects must be excluded from this + # branch. Only a plain dict signals a serialized SSD state.a. Attributes ---------- @@ -100,8 +107,8 @@ class SSD(_GEDTransformer): def __init__( self, info, - filt_params_signal, - filt_params_noise, + filt_params_signal=None, + filt_params_noise=None, reg=None, n_components=None, picks=None, @@ -114,6 +121,25 @@ def __init__( rank=None, ): """Initialize instance.""" + if isinstance(info, dict) and info.get("_ssd_state"): + required_keys = ( + "info", + "filt_params_signal", + "filt_params_noise", + "n_components", + "filters_", + "patterns_", + ) + missing = [k for k in required_keys if k not in info] + if missing: + raise ValueError( + "If 'info' is a dict, it must be a serialized SSD state " + f"(missing keys: {missing}). " + "Otherwise pass an mne.Info object." + ) + self.__setstate__(info) + return + self.info = info self.filt_params_signal = filt_params_signal self.filt_params_noise = filt_params_noise @@ -127,25 +153,71 @@ def __init__( self.restr_type = restr_type self.rank = rank - cov_callable = partial( - _ssd_estimate, - reg=reg, - cov_method_params=cov_method_params, - info=info, - picks=picks, - n_fft=n_fft, - filt_params_signal=filt_params_signal, - filt_params_noise=filt_params_noise, - rank=rank, - sort_by_spectral_ratio=sort_by_spectral_ratio, - ) + self._create_cov_callable() super().__init__( n_components=n_components, - cov_callable=cov_callable, - mod_ged_callable=_ssd_mod, + cov_callable=self.cov_callable, + mod_ged_callable=self.mod_ged_callable, restr_type=restr_type, ) + def _create_cov_callable(self): + """Recreate covariance callable after initialization or loading.""" + self.cov_callable = partial( + _ssd_estimate, + reg=self.reg, + cov_method_params=self.cov_method_params, + info=self.info, + picks=self.picks, + n_fft=self.n_fft, + filt_params_signal=self.filt_params_signal, + filt_params_noise=self.filt_params_noise, + rank=self.rank, + sort_by_spectral_ratio=self.sort_by_spectral_ratio, + ) + self.mod_ged_callable = _ssd_mod + + def __getstate__(self): + """Prepare state for serialization.""" + state = self.__dict__.copy() + state.pop("cov_callable", None) + state.pop("mod_ged_callable", None) + state["_ssd_state"] = True + state["class_name"] = "SSD" + state["mne_version"] = _mne_version + return state + + def __setstate__(self, state): + """Restore state from serialization.""" + saved_version = state.get("mne_version") + if saved_version is not None and saved_version != _mne_version: + warn( + f"The SSD object was saved with MNE-Python {saved_version} but is " + f"being loaded with {_mne_version}. This may cause issues." + ) + state = { + k: v + for k, v in state.items() + if k not in ("_ssd_state", "class_name", "mne_version") + } + self.__dict__.update(state) + self._create_cov_callable() + + def save(self, fname, *, overwrite=False): + """Save the SSD object to a file. + + Parameters + ---------- + fname : path-like + The file path to save to (e.g. ``'my_ssd.pkl'``). + overwrite : bool + If True, overwrite an existing file. Defaults to False. + """ + fname = _check_fname(fname, overwrite=overwrite) + state = self.__getstate__() + with open(fname, "wb") as fid: + pickle.dump(state, fid) + def _validate_params(self, X): if isinstance(self.info, float): # special case, mostly for testing self.sfreq_ = self.info @@ -350,3 +422,31 @@ def apply(self, X): pick_patterns = self.patterns_[: self.n_components].T X = pick_patterns @ X_ssd return X + + +def read_ssd(fname): + """Load a saved :class:`~mne.decoding.SSD` object from disk. + + Parameters + ---------- + fname : path-like + The full path to the ``.pkl`` file saved by :meth:`~mne.decoding.SSD.save`. + + Returns + ------- + ssd : instance of :class:`~mne.decoding.SSD` + The loaded SSD object with all fitted attributes restored. + + See Also + -------- + SSD.save + """ + fname = _check_fname(fname, must_exist=True, overwrite="read") + with open(fname, "rb") as fid: + state = pickle.load(fid) + if state.get("class_name") != "SSD": + raise ValueError( + f"File does not contain an SSD object (got class_name=" + f"{state.get('class_name')!r})." + ) + return SSD(info=state) diff --git a/mne/decoding/tests/test_ssd.py b/mne/decoding/tests/test_ssd.py index 086413b043f..fe618e8c9bf 100644 --- a/mne/decoding/tests/test_ssd.py +++ b/mne/decoding/tests/test_ssd.py @@ -18,7 +18,7 @@ from mne._fiff.pick import _picks_to_idx from mne.decoding import CSP from mne.decoding._mod_ged import _get_spectral_ratio -from mne.decoding.ssd import SSD +from mne.decoding.ssd import SSD, read_ssd from mne.filter import filter_data from mne.time_frequency import psd_array_welch @@ -570,6 +570,75 @@ def test_picks_arg(): ssd.fit(X).transform(X) +def test_ssd_save_load(tmp_path): + """Test that SSD can be saved to disk and loaded back correctly.""" + X, _, _ = simulate_data() + sf = 250 + n_channels = X.shape[0] + info = create_info(ch_names=n_channels, sfreq=sf, ch_types="eeg") + n_components = 5 + + filt_params_signal = dict( + l_freq=freqs_sig[0], + h_freq=freqs_sig[1], + l_trans_bandwidth=1, + h_trans_bandwidth=1, + ) + filt_params_noise = dict( + l_freq=freqs_noise[0], + h_freq=freqs_noise[1], + l_trans_bandwidth=1, + h_trans_bandwidth=1, + ) + + ssd = SSD(info, filt_params_signal, filt_params_noise, n_components=n_components) + ssd.fit(X) + + fname = tmp_path / "test_ssd.pkl" + ssd.save(fname) + + ssd_loaded = read_ssd(fname) + + # Check fitted array attributes are restored + assert_array_almost_equal(ssd.filters_, ssd_loaded.filters_) + assert_array_almost_equal(ssd.patterns_, ssd_loaded.patterns_) + + # Check scalar/param attributes + assert ssd.n_components == ssd_loaded.n_components + assert ssd.info["sfreq"] == ssd_loaded.info["sfreq"] + assert ssd.filt_params_signal == ssd_loaded.filt_params_signal + assert ssd.filt_params_noise == ssd_loaded.filt_params_noise + + # Check transform output matches + X_orig = ssd.transform(X) + X_loaded = ssd_loaded.transform(X) + assert_array_almost_equal(X_orig, X_loaded) + + with pytest.raises(FileExistsError): + ssd.save(fname) + ssd.save(fname, overwrite=True) + + # Check that loading a non-SSD file raises an error + import pickle + + bad_fname = tmp_path / "bad.pkl" + with open(bad_fname, "wb") as f: + pickle.dump({"class_name": "NotSSD"}, f) + with pytest.raises(ValueError, match="does not contain an SSD object"): + read_ssd(bad_fname) + with pytest.raises(FileNotFoundError): + read_ssd(tmp_path / "nonexistent.pkl") + + ssd_pickled = pickle.loads(pickle.dumps(ssd)) + assert_array_almost_equal(ssd.filters_, ssd_pickled.filters_) + assert_array_almost_equal(ssd.transform(X), ssd_pickled.transform(X)) + assert not hasattr(ssd_pickled, "class_name") + assert not hasattr(ssd_pickled, "mne_version") + + with pytest.raises(ValueError, match="serialized SSD state"): + SSD(info={"_ssd_state": True, "random": 123}) + + def test_get_spectral_ratio(): """Test that method is the same as function in _mod_ged.py.""" X, _, _ = simulate_data() From 002fedecec9d41f82199af6e87e659e23eb140b5 Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Tue, 3 Mar 2026 17:39:20 +0000 Subject: [PATCH 2/8] Make SSD dict-based reconstruction sklearn-compliant --- mne/decoding/ssd.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/mne/decoding/ssd.py b/mne/decoding/ssd.py index cf41f002031..9b23f43f5f5 100644 --- a/mne/decoding/ssd.py +++ b/mne/decoding/ssd.py @@ -87,10 +87,7 @@ class SSD(_GEDTransformer): measurement info or estimated from the data, which determines the maximum possible number of components. See Notes of :func:`mne.compute_rank` for details. - We recommend to use 'full' when working with epoched da - t - # mne.Info subclasses dict, so Info objects must be excluded from this - # branch. Only a plain dict signals a serialized SSD state.a. + We recommend to use 'full' when working with epoched data. Attributes ---------- From 8d09638cb0a0230ccb035daad3d997842c7d9967 Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Thu, 5 Mar 2026 08:09:15 +0000 Subject: [PATCH 3/8] Make SSD dict-based reconstruction sklearn-compliant --- mne/decoding/ssd.py | 169 ++++++++++++++++++++------------- mne/decoding/tests/test_ssd.py | 39 ++++---- 2 files changed, 125 insertions(+), 83 deletions(-) diff --git a/mne/decoding/ssd.py b/mne/decoding/ssd.py index 9b23f43f5f5..1165bb4c724 100644 --- a/mne/decoding/ssd.py +++ b/mne/decoding/ssd.py @@ -3,27 +3,61 @@ # Copyright the MNE-Python contributors. import collections.abc as abc -import pickle from functools import partial import numpy as np -from .. import __version__ as _mne_version from .._fiff.meas_info import Info, create_info from .._fiff.pick import _picks_to_idx from ..filter import filter_data from ..utils import ( _check_fname, + _import_h5io_funcs, _validate_type, fill_doc, logger, - warn, ) +from ..utils.check import check_fname from ._covs_ged import _ssd_estimate from ._mod_ged import _get_spectral_ratio, _ssd_mod from .base import _GEDTransformer +def _create_callables( + reg, + cov_method_params, + info, + picks, + n_fft, + filt_params_signal, + filt_params_noise, + rank, + sort_by_spectral_ratio, +): + """Create covariance and mod_ged callables for SSD. + + Returns + ------- + cov_callable : callable + Partial function for computing SSD covariance estimates. + mod_ged_callable : callable + Function for modifying the GED result. + """ + cov_callable = partial( + _ssd_estimate, + reg=reg, + cov_method_params=cov_method_params, + info=info, + picks=picks, + n_fft=n_fft, + filt_params_signal=filt_params_signal, + filt_params_noise=filt_params_noise, + rank=rank, + sort_by_spectral_ratio=sort_by_spectral_ratio, + ) + return cov_callable, _ssd_mod + + @fill_doc class SSD(_GEDTransformer): """ @@ -104,8 +138,8 @@ class SSD(_GEDTransformer): def __init__( self, info, - filt_params_signal=None, - filt_params_noise=None, + filt_params_signal, + filt_params_noise, reg=None, n_components=None, picks=None, @@ -118,25 +152,6 @@ def __init__( rank=None, ): """Initialize instance.""" - if isinstance(info, dict) and info.get("_ssd_state"): - required_keys = ( - "info", - "filt_params_signal", - "filt_params_noise", - "n_components", - "filters_", - "patterns_", - ) - missing = [k for k in required_keys if k not in info] - if missing: - raise ValueError( - "If 'info' is a dict, it must be a serialized SSD state " - f"(missing keys: {missing}). " - "Otherwise pass an mne.Info object." - ) - self.__setstate__(info) - return - self.info = info self.filt_params_signal = filt_params_signal self.filt_params_noise = filt_params_noise @@ -150,18 +165,7 @@ def __init__( self.restr_type = restr_type self.rank = rank - self._create_cov_callable() - super().__init__( - n_components=n_components, - cov_callable=self.cov_callable, - mod_ged_callable=self.mod_ged_callable, - restr_type=restr_type, - ) - - def _create_cov_callable(self): - """Recreate covariance callable after initialization or loading.""" - self.cov_callable = partial( - _ssd_estimate, + self.cov_callable, self.mod_ged_callable = _create_callables( reg=self.reg, cov_method_params=self.cov_method_params, info=self.info, @@ -172,48 +176,80 @@ def _create_cov_callable(self): rank=self.rank, sort_by_spectral_ratio=self.sort_by_spectral_ratio, ) - self.mod_ged_callable = _ssd_mod + super().__init__( + n_components=n_components, + cov_callable=self.cov_callable, + mod_ged_callable=self.mod_ged_callable, + restr_type=restr_type, + ) def __getstate__(self): """Prepare state for serialization.""" state = self.__dict__.copy() state.pop("cov_callable", None) state.pop("mod_ged_callable", None) - state["_ssd_state"] = True - state["class_name"] = "SSD" - state["mne_version"] = _mne_version return state + _required_state_keys = ( + "info", + "filt_params_signal", + "filt_params_noise", + "reg", + "n_components", + "picks", + "sort_by_spectral_ratio", + "return_filtered", + "n_fft", + "cov_method_params", + "restr_type", + "rank", + ) + def __setstate__(self, state): """Restore state from serialization.""" - saved_version = state.get("mne_version") - if saved_version is not None and saved_version != _mne_version: - warn( - f"The SSD object was saved with MNE-Python {saved_version} but is " - f"being loaded with {_mne_version}. This may cause issues." + missing = [k for k in self._required_state_keys if k not in state] + if missing: + raise ValueError( + f"State dict is missing required keys: {missing}. " + "The file may be corrupt or not a valid SSD state." ) - state = { - k: v - for k, v in state.items() - if k not in ("_ssd_state", "class_name", "mne_version") - } self.__dict__.update(state) - self._create_cov_callable() + self.cov_callable, self.mod_ged_callable = _create_callables( + reg=self.reg, + cov_method_params=self.cov_method_params, + info=self.info, + picks=self.picks, + n_fft=self.n_fft, + filt_params_signal=self.filt_params_signal, + filt_params_noise=self.filt_params_noise, + rank=self.rank, + sort_by_spectral_ratio=self.sort_by_spectral_ratio, + ) def save(self, fname, *, overwrite=False): - """Save the SSD object to a file. + """Save the SSD object to disk (in HDF5 format). Parameters ---------- fname : path-like - The file path to save to (e.g. ``'my_ssd.pkl'``). + The file path to save to. Should end with ``'.h5'`` or ``'.hdf5'``. overwrite : bool If True, overwrite an existing file. Defaults to False. + + See Also + -------- + mne.decoding.read_ssd """ + _, write_hdf5 = _import_h5io_funcs() + check_fname(fname, "ssd", (".h5", ".hdf5")) fname = _check_fname(fname, overwrite=overwrite) - state = self.__getstate__() - with open(fname, "wb") as fid: - pickle.dump(state, fid) + write_hdf5( + fname, + self.__getstate__(), + overwrite=overwrite, + title="mnepython", + slash="replace", + ) def _validate_params(self, X): if isinstance(self.info, float): # special case, mostly for testing @@ -427,7 +463,8 @@ def read_ssd(fname): Parameters ---------- fname : path-like - The full path to the ``.pkl`` file saved by :meth:`~mne.decoding.SSD.save`. + Path to an SSD file in HDF5 format, which should end with ``.h5`` or + ``.hdf5``. Returns ------- @@ -438,12 +475,12 @@ def read_ssd(fname): -------- SSD.save """ - fname = _check_fname(fname, must_exist=True, overwrite="read") - with open(fname, "rb") as fid: - state = pickle.load(fid) - if state.get("class_name") != "SSD": - raise ValueError( - f"File does not contain an SSD object (got class_name=" - f"{state.get('class_name')!r})." - ) - return SSD(info=state) + read_hdf5, _ = _import_h5io_funcs() + _validate_type(fname, "path-like", "fname") + fname = _check_fname(fname=fname, overwrite="read", must_exist=False) + state = read_hdf5(fname, title="mnepython", slash="replace") + if "info" in state and not isinstance(state["info"], Info): + state["info"] = Info(state["info"]) + ssd = object.__new__(SSD) + ssd.__setstate__(state) + return ssd diff --git a/mne/decoding/tests/test_ssd.py b/mne/decoding/tests/test_ssd.py index fe618e8c9bf..3b0500b892c 100644 --- a/mne/decoding/tests/test_ssd.py +++ b/mne/decoding/tests/test_ssd.py @@ -572,6 +572,7 @@ def test_picks_arg(): def test_ssd_save_load(tmp_path): """Test that SSD can be saved to disk and loaded back correctly.""" + h5io = pytest.importorskip("h5io") X, _, _ = simulate_data() sf = 250 n_channels = X.shape[0] @@ -594,11 +595,20 @@ def test_ssd_save_load(tmp_path): ssd = SSD(info, filt_params_signal, filt_params_noise, n_components=n_components) ssd.fit(X) - fname = tmp_path / "test_ssd.pkl" + state = ssd.__getstate__() + assert "cov_callable" not in state + assert "mod_ged_callable" not in state + + fname = tmp_path / "test_ssd.h5" ssd.save(fname) ssd_loaded = read_ssd(fname) + assert hasattr(ssd_loaded, "cov_callable") + assert hasattr(ssd_loaded, "mod_ged_callable") + assert callable(ssd_loaded.cov_callable) + assert callable(ssd_loaded.mod_ged_callable) + # Check fitted array attributes are restored assert_array_almost_equal(ssd.filters_, ssd_loaded.filters_) assert_array_almost_equal(ssd.patterns_, ssd_loaded.patterns_) @@ -618,25 +628,20 @@ def test_ssd_save_load(tmp_path): ssd.save(fname) ssd.save(fname, overwrite=True) - # Check that loading a non-SSD file raises an error - import pickle - - bad_fname = tmp_path / "bad.pkl" - with open(bad_fname, "wb") as f: - pickle.dump({"class_name": "NotSSD"}, f) - with pytest.raises(ValueError, match="does not contain an SSD object"): + # Check that loading an HDF5 file with missing keys raises an error + bad_fname = tmp_path / "bad.h5" + h5io.write_hdf5(bad_fname, dict(foo="bar"), title="mnepython", slash="replace") + with pytest.raises(ValueError, match="missing required keys"): read_ssd(bad_fname) - with pytest.raises(FileNotFoundError): - read_ssd(tmp_path / "nonexistent.pkl") - ssd_pickled = pickle.loads(pickle.dumps(ssd)) - assert_array_almost_equal(ssd.filters_, ssd_pickled.filters_) - assert_array_almost_equal(ssd.transform(X), ssd_pickled.transform(X)) - assert not hasattr(ssd_pickled, "class_name") - assert not hasattr(ssd_pickled, "mne_version") + with pytest.raises(FileNotFoundError): + read_ssd(tmp_path / "nonexistent.h5") - with pytest.raises(ValueError, match="serialized SSD state"): - SSD(info={"_ssd_state": True, "random": 123}) + fname_rt = tmp_path / "test_ssd_rt.h5" + ssd.save(fname_rt) + ssd_rt = read_ssd(fname_rt) + assert_array_almost_equal(ssd.filters_, ssd_rt.filters_) + assert_array_almost_equal(ssd.transform(X), ssd_rt.transform(X)) def test_get_spectral_ratio(): From a521c9229c994825f8b36cb12cfdc2b01b7213a5 Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Thu, 5 Mar 2026 09:52:29 +0000 Subject: [PATCH 4/8] fix-doc and add changelog --- doc/api/decoding.rst | 1 + doc/changes/13718.other.rst | 1 + mne/decoding/tests/test_ssd.py | 2 +- 3 files changed, 3 insertions(+), 1 deletion(-) create mode 100644 doc/changes/13718.other.rst diff --git a/doc/api/decoding.rst b/doc/api/decoding.rst index f8f2257825f..c2d20a0837f 100644 --- a/doc/api/decoding.rst +++ b/doc/api/decoding.rst @@ -41,3 +41,4 @@ Functions that assist with decoding and model fitting: cross_val_multiscore get_coef get_spatial_filter_from_estimator + read_ssd diff --git a/doc/changes/13718.other.rst b/doc/changes/13718.other.rst new file mode 100644 index 00000000000..b8cf9fd84cd --- /dev/null +++ b/doc/changes/13718.other.rst @@ -0,0 +1 @@ +Add native save and read functionality for :class:`mne.decoding.SSD` and :class:`mne.decoding.SPoC` objects, by `Aniket Singh Yadav`_. \ No newline at end of file diff --git a/mne/decoding/tests/test_ssd.py b/mne/decoding/tests/test_ssd.py index 3b0500b892c..810d7b86699 100644 --- a/mne/decoding/tests/test_ssd.py +++ b/mne/decoding/tests/test_ssd.py @@ -634,7 +634,7 @@ def test_ssd_save_load(tmp_path): with pytest.raises(ValueError, match="missing required keys"): read_ssd(bad_fname) - with pytest.raises(FileNotFoundError): + with pytest.raises(OSError, match="not found"): read_ssd(tmp_path / "nonexistent.h5") fname_rt = tmp_path / "test_ssd_rt.h5" From 346015105b47396405737c3b4a7f6e99a14578c4 Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Thu, 5 Mar 2026 10:11:28 +0000 Subject: [PATCH 5/8] add changelog --- doc/changes/{ => dev}/13718.other.rst | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename doc/changes/{ => dev}/13718.other.rst (100%) diff --git a/doc/changes/13718.other.rst b/doc/changes/dev/13718.other.rst similarity index 100% rename from doc/changes/13718.other.rst rename to doc/changes/dev/13718.other.rst From a9ce55303fe4dbde1c32b62ac52c3e38ecac68b0 Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Mon, 9 Mar 2026 11:42:07 +0000 Subject: [PATCH 6/8] Make SSD dict-based reconstruction sklearn-compliant --- mne/decoding/ssd.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mne/decoding/ssd.py b/mne/decoding/ssd.py index 1165bb4c724..6d30cf84785 100644 --- a/mne/decoding/ssd.py +++ b/mne/decoding/ssd.py @@ -16,6 +16,7 @@ _validate_type, fill_doc, logger, + verbose, ) from ..utils.check import check_fname from ._covs_ged import _ssd_estimate @@ -211,7 +212,7 @@ def __setstate__(self, state): if missing: raise ValueError( f"State dict is missing required keys: {missing}. " - "The file may be corrupt or not a valid SSD state." + "The state may be from an incompatible version of MNE." ) self.__dict__.update(state) self.cov_callable, self.mod_ged_callable = _create_callables( @@ -226,15 +227,16 @@ def __setstate__(self, state): sort_by_spectral_ratio=self.sort_by_spectral_ratio, ) - def save(self, fname, *, overwrite=False): + @verbose + def save(self, fname, *, overwrite=False, verbose=None): """Save the SSD object to disk (in HDF5 format). Parameters ---------- fname : path-like The file path to save to. Should end with ``'.h5'`` or ``'.hdf5'``. - overwrite : bool - If True, overwrite an existing file. Defaults to False. + %(overwrite)s + %(verbose)s See Also -------- @@ -242,7 +244,7 @@ def save(self, fname, *, overwrite=False): """ _, write_hdf5 = _import_h5io_funcs() check_fname(fname, "ssd", (".h5", ".hdf5")) - fname = _check_fname(fname, overwrite=overwrite) + fname = _check_fname(fname, overwrite=overwrite, verbose=verbose) write_hdf5( fname, self.__getstate__(), @@ -479,8 +481,6 @@ def read_ssd(fname): _validate_type(fname, "path-like", "fname") fname = _check_fname(fname=fname, overwrite="read", must_exist=False) state = read_hdf5(fname, title="mnepython", slash="replace") - if "info" in state and not isinstance(state["info"], Info): - state["info"] = Info(state["info"]) ssd = object.__new__(SSD) ssd.__setstate__(state) return ssd From 2a5730bbc0c8b68b229a048f9b48a0fcdf8f3904 Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Tue, 10 Mar 2026 04:27:09 +0000 Subject: [PATCH 7/8] extending framework to other classes --- doc/api/decoding.rst | 3 + mne/decoding/__init__.pyi | 7 +- mne/decoding/csp.py | 189 ++++++++++++++++++++++++++++++- mne/decoding/ssd.py | 3 +- mne/decoding/tests/test_csp.py | 113 +++++++++++++++++- mne/decoding/tests/test_ssd.py | 6 - mne/decoding/tests/test_xdawn.py | 58 +++++++++- mne/decoding/xdawn.py | 100 +++++++++++++++- 8 files changed, 466 insertions(+), 13 deletions(-) diff --git a/doc/api/decoding.rst b/doc/api/decoding.rst index c2d20a0837f..fe21427a093 100644 --- a/doc/api/decoding.rst +++ b/doc/api/decoding.rst @@ -41,4 +41,7 @@ Functions that assist with decoding and model fitting: cross_val_multiscore get_coef get_spatial_filter_from_estimator + read_csp + read_spoc read_ssd + read_xdawn_transformer diff --git a/mne/decoding/__init__.pyi b/mne/decoding/__init__.pyi index 0c3ba92e33e..3895604bfa8 100644 --- a/mne/decoding/__init__.pyi +++ b/mne/decoding/__init__.pyi @@ -10,7 +10,10 @@ __all__ = [ "SPoC", "SSD", "Scaler", + "read_csp", + "read_spoc", "read_ssd", + "read_xdawn_transformer", "SlidingEstimator", "SpatialFilter", "TemporalFilter", @@ -32,7 +35,7 @@ from .base import ( cross_val_multiscore, get_coef, ) -from .csp import CSP, SPoC +from .csp import CSP, SPoC, read_csp, read_spoc from .ems import EMS, compute_ems from .receptive_field import ReceptiveField from .search_light import GeneralizingEstimator, SlidingEstimator @@ -48,4 +51,4 @@ from .transformer import ( UnsupervisedSpatialFilter, Vectorizer, ) -from .xdawn import XdawnTransformer +from .xdawn import XdawnTransformer, read_xdawn_transformer diff --git a/mne/decoding/csp.py b/mne/decoding/csp.py index adf857d11b9..67763f01a5f 100644 --- a/mne/decoding/csp.py +++ b/mne/decoding/csp.py @@ -9,7 +9,16 @@ from .._fiff.meas_info import Info from ..defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT -from ..utils import _check_option, _validate_type, fill_doc, legacy +from ..utils import ( + _check_fname, + _check_option, + _import_h5io_funcs, + _validate_type, + fill_doc, + legacy, + verbose, +) +from ..utils.check import check_fname from ._covs_ged import _csp_estimate, _spoc_estimate from ._mod_ged import _csp_mod, _spoc_mod from .base import _GEDTransformer @@ -160,6 +169,8 @@ def __init__( R_func=sum, ) + _save_fname_type = "csp" + def __sklearn_tags__(self): """Tag the transformer.""" tags = super().__sklearn_tags__() @@ -167,6 +178,77 @@ def __sklearn_tags__(self): tags.target_tags.multi_output = True return tags + def __getstate__(self): + """Prepare state for serialization.""" + state = self.__dict__.copy() + state.pop("cov_callable", None) + state.pop("mod_ged_callable", None) + state.pop("R_func", None) + return state + + _required_state_keys = ( + "n_components", + "reg", + "log", + "cov_est", + "transform_into", + "norm_trace", + "cov_method_params", + "restr_type", + "info", + "rank", + "component_order", + ) + + def __setstate__(self, state): + """Restore state from serialization.""" + missing = [k for k in self._required_state_keys if k not in state] + if missing: + raise ValueError( + f"State dict is missing required keys: {missing}. " + "The state may be from an incompatible version of MNE." + ) + if state["info"] is not None: + state["info"] = Info(**state["info"]) + self.__dict__.update(state) + self.cov_callable = partial( + _csp_estimate, + reg=self.reg, + cov_method_params=self.cov_method_params, + cov_est=self.cov_est, + info=self.info, + rank=self.rank, + norm_trace=self.norm_trace, + ) + self.mod_ged_callable = partial(_csp_mod, evecs_order=self.component_order) + self.R_func = sum + + @verbose + def save(self, fname, *, overwrite=False, verbose=None): + """Save the CSP object to disk (in HDF5 format). + + Parameters + ---------- + fname : path-like + The file path to save to. Should end with ``'.h5'`` or ``'.hdf5'``. + %(overwrite)s + %(verbose)s + + See Also + -------- + mne.decoding.read_csp + """ + _, write_hdf5 = _import_h5io_funcs() + check_fname(fname, self._save_fname_type, (".h5", ".hdf5")) + fname = _check_fname(fname, overwrite=overwrite, verbose=verbose) + write_hdf5( + fname, + self.__getstate__(), + overwrite=overwrite, + title="mnepython", + slash="replace", + ) + def _validate_params(self, *, y): _validate_type(self.n_components, int, "n_components") if hasattr(self, "cov_est"): @@ -766,12 +848,63 @@ def __init__( delattr(self, "cov_est") delattr(self, "norm_trace") + _save_fname_type = "spoc" + def __sklearn_tags__(self): """Tag the transformer.""" tags = super().__sklearn_tags__() tags.target_tags.multi_output = False return tags + _required_state_keys = ( + "n_components", + "reg", + "log", + "transform_into", + "cov_method_params", + "restr_type", + "info", + "rank", + ) + + def __setstate__(self, state): + """Restore state from serialization.""" + missing = [k for k in self._required_state_keys if k not in state] + if missing: + raise ValueError( + f"State dict is missing required keys: {missing}. " + "The state may be from an incompatible version of MNE." + ) + if state["info"] is not None: + state["info"] = Info(**state["info"]) + self.__dict__.update(state) + self.cov_callable = partial( + _spoc_estimate, + reg=self.reg, + cov_method_params=self.cov_method_params, + info=self.info, + rank=self.rank, + ) + self.mod_ged_callable = _spoc_mod + self.R_func = None + + @verbose + def save(self, fname, *, overwrite=False, verbose=None): + """Save the SPoC object to disk (in HDF5 format). + + Parameters + ---------- + fname : path-like + The file path to save to. Should end with ``'.h5'`` or ``'.hdf5'``. + %(overwrite)s + %(verbose)s + + See Also + -------- + mne.decoding.read_spoc + """ + super().save(fname, overwrite=overwrite, verbose=verbose) + def fit(self, X, y): """Estimate the SPoC decomposition on epochs. @@ -848,3 +981,57 @@ def fit_transform(self, X, y=None, **fit_params): """ # use parent TransformerMixin method but with custom docstring return super().fit_transform(X, y=y, **fit_params) + + +def read_csp(fname): + """Load a saved :class:`~mne.decoding.CSP` object from disk. + + Parameters + ---------- + fname : path-like + Path to a CSP file in HDF5 format, which should end with ``.h5`` or + ``.hdf5``. + + Returns + ------- + csp : instance of :class:`~mne.decoding.CSP` + The loaded CSP object with all fitted attributes restored. + + See Also + -------- + mne.decoding.CSP.save + """ + read_hdf5, _ = _import_h5io_funcs() + _validate_type(fname, "path-like", "fname") + fname = _check_fname(fname=fname, overwrite="read", must_exist=False) + state = read_hdf5(fname, title="mnepython", slash="replace") + csp = object.__new__(CSP) + csp.__setstate__(state) + return csp + + +def read_spoc(fname): + """Load a saved :class:`~mne.decoding.SPoC` object from disk. + + Parameters + ---------- + fname : path-like + Path to a SPoC file in HDF5 format, which should end with ``.h5`` or + ``.hdf5``. + + Returns + ------- + spoc : instance of :class:`~mne.decoding.SPoC` + The loaded SPoC object with all fitted attributes restored. + + See Also + -------- + mne.decoding.SPoC.save + """ + read_hdf5, _ = _import_h5io_funcs() + _validate_type(fname, "path-like", "fname") + fname = _check_fname(fname=fname, overwrite="read", must_exist=False) + state = read_hdf5(fname, title="mnepython", slash="replace") + spoc = object.__new__(SPoC) + spoc.__setstate__(state) + return spoc diff --git a/mne/decoding/ssd.py b/mne/decoding/ssd.py index 6d30cf84785..021e2621b8a 100644 --- a/mne/decoding/ssd.py +++ b/mne/decoding/ssd.py @@ -214,6 +214,7 @@ def __setstate__(self, state): f"State dict is missing required keys: {missing}. " "The state may be from an incompatible version of MNE." ) + state["info"] = Info(**state["info"]) self.__dict__.update(state) self.cov_callable, self.mod_ged_callable = _create_callables( reg=self.reg, @@ -475,7 +476,7 @@ def read_ssd(fname): See Also -------- - SSD.save + mne.decoding.SSD.save """ read_hdf5, _ = _import_h5io_funcs() _validate_type(fname, "path-like", "fname") diff --git a/mne/decoding/tests/test_csp.py b/mne/decoding/tests/test_csp.py index f3f5e16cf98..0caadd5184a 100644 --- a/mne/decoding/tests/test_csp.py +++ b/mne/decoding/tests/test_csp.py @@ -22,7 +22,7 @@ from sklearn.utils.estimator_checks import parametrize_with_checks from mne import Epochs, compute_proj_raw, io, pick_types, read_events -from mne.decoding import CSP, LinearModel, Scaler, SPoC, get_coef +from mne.decoding import CSP, LinearModel, Scaler, SPoC, get_coef, read_csp, read_spoc from mne.decoding.csp import _ajd_pham from mne.utils import catch_logging @@ -497,3 +497,114 @@ def test_sklearn_compliance(estimator, check): """Test compliance with sklearn.""" pytest.importorskip("sklearn", minversion="1.5") # TODO VERSION remove on 1.5+ check(estimator) + + +def test_csp_save_load(tmp_path): + """Test that CSP can be saved to disk and loaded back correctly.""" + h5io = pytest.importorskip("h5io") + rng = np.random.RandomState(42) + n_epochs, n_channels = 40, 10 + X = rng.randn(n_epochs, n_channels, 50) + y = np.array([0] * 20 + [1] * 20) + + csp = CSP(n_components=4) + csp.fit(X, y) + + state = csp.__getstate__() + assert "cov_callable" not in state + assert "mod_ged_callable" not in state + assert "R_func" not in state + + fname = tmp_path / "test_csp.h5" + csp.save(fname) + + csp_loaded = read_csp(fname) + + assert hasattr(csp_loaded, "cov_callable") + assert hasattr(csp_loaded, "mod_ged_callable") + assert callable(csp_loaded.cov_callable) + assert callable(csp_loaded.mod_ged_callable) + + # Check fitted array attributes are restored + assert_array_almost_equal(csp.filters_, csp_loaded.filters_) + assert_array_almost_equal(csp.patterns_, csp_loaded.patterns_) + + # Check scalar/param attributes + assert csp.n_components == csp_loaded.n_components + assert csp.component_order == csp_loaded.component_order + assert csp.transform_into == csp_loaded.transform_into + assert csp.log == csp_loaded.log + assert csp.reg == csp_loaded.reg + assert csp.norm_trace == csp_loaded.norm_trace + + # Check transform output matches + X_orig = csp.transform(X) + X_loaded = csp_loaded.transform(X) + assert_array_almost_equal(X_orig, X_loaded) + + with pytest.raises(FileExistsError): + csp.save(fname) + csp.save(fname, overwrite=True) + + # Check that loading an HDF5 file with missing keys raises an error + bad_fname = tmp_path / "bad_csp.h5" + h5io.write_hdf5(bad_fname, dict(foo="bar"), title="mnepython", slash="replace") + with pytest.raises(ValueError, match="missing required keys"): + read_csp(bad_fname) + + with pytest.raises(OSError, match="not found"): + read_csp(tmp_path / "nonexistent.h5") + + +def test_spoc_save_load(tmp_path): + """Test that SPoC can be saved to disk and loaded back correctly.""" + h5io = pytest.importorskip("h5io") + rng = np.random.RandomState(42) + X = rng.randn(10, 10, 20) + y = rng.randn(10) + + spoc = SPoC(n_components=4) + spoc.fit(X, y) + + state = spoc.__getstate__() + assert "cov_callable" not in state + assert "mod_ged_callable" not in state + assert "R_func" not in state + + fname = tmp_path / "test_spoc.h5" + spoc.save(fname) + + spoc_loaded = read_spoc(fname) + + assert hasattr(spoc_loaded, "cov_callable") + assert hasattr(spoc_loaded, "mod_ged_callable") + assert callable(spoc_loaded.cov_callable) + assert callable(spoc_loaded.mod_ged_callable) + + # Check fitted array attributes are restored + assert_array_almost_equal(spoc.filters_, spoc_loaded.filters_) + assert_array_almost_equal(spoc.patterns_, spoc_loaded.patterns_) + + # Check scalar/param attributes + assert spoc.n_components == spoc_loaded.n_components + assert spoc.transform_into == spoc_loaded.transform_into + assert spoc.log == spoc_loaded.log + assert spoc.reg == spoc_loaded.reg + + # Check transform output matches + X_orig = spoc.transform(X) + X_loaded = spoc_loaded.transform(X) + assert_array_almost_equal(X_orig, X_loaded) + + with pytest.raises(FileExistsError): + spoc.save(fname) + spoc.save(fname, overwrite=True) + + # Check that loading an HDF5 file with missing keys raises an error + bad_fname = tmp_path / "bad_spoc.h5" + h5io.write_hdf5(bad_fname, dict(foo="bar"), title="mnepython", slash="replace") + with pytest.raises(ValueError, match="missing required keys"): + read_spoc(bad_fname) + + with pytest.raises(OSError, match="not found"): + read_spoc(tmp_path / "nonexistent.h5") diff --git a/mne/decoding/tests/test_ssd.py b/mne/decoding/tests/test_ssd.py index 810d7b86699..5ddd00098e5 100644 --- a/mne/decoding/tests/test_ssd.py +++ b/mne/decoding/tests/test_ssd.py @@ -637,12 +637,6 @@ def test_ssd_save_load(tmp_path): with pytest.raises(OSError, match="not found"): read_ssd(tmp_path / "nonexistent.h5") - fname_rt = tmp_path / "test_ssd_rt.h5" - ssd.save(fname_rt) - ssd_rt = read_ssd(fname_rt) - assert_array_almost_equal(ssd.filters_, ssd_rt.filters_) - assert_array_almost_equal(ssd.transform(X), ssd_rt.transform(X)) - def test_get_spectral_ratio(): """Test that method is the same as function in _mod_ged.py.""" diff --git a/mne/decoding/tests/test_xdawn.py b/mne/decoding/tests/test_xdawn.py index a2936686b59..2b0839c7b40 100644 --- a/mne/decoding/tests/test_xdawn.py +++ b/mne/decoding/tests/test_xdawn.py @@ -2,12 +2,14 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. +import numpy as np import pytest +from numpy.testing import assert_array_almost_equal pytest.importorskip("sklearn") from sklearn.utils.estimator_checks import parametrize_with_checks -from mne.decoding import XdawnTransformer +from mne.decoding import XdawnTransformer, read_xdawn_transformer @pytest.mark.filterwarnings("ignore:.*Only one sample available.*") @@ -15,3 +17,57 @@ def test_sklearn_compliance(estimator, check): """Test compliance with sklearn.""" check(estimator) + + +def test_xdawn_save_load(tmp_path): + """Test that XdawnTransformer can be saved to disk and loaded correctly.""" + h5io = pytest.importorskip("h5io") + rng = np.random.RandomState(42) + n_epochs, n_channels, n_times = 40, 10, 50 + X = rng.randn(n_epochs, n_channels, n_times) + y = rng.randint(0, 2, n_epochs) + + xdawn = XdawnTransformer(n_components=2) + xdawn.fit(X, y) + + state = xdawn.__getstate__() + assert "cov_callable" not in state + assert "mod_ged_callable" not in state + + fname = tmp_path / "test_xdawn.h5" + xdawn.save(fname) + + xdawn_loaded = read_xdawn_transformer(fname) + + assert hasattr(xdawn_loaded, "cov_callable") + assert hasattr(xdawn_loaded, "mod_ged_callable") + assert callable(xdawn_loaded.cov_callable) + assert callable(xdawn_loaded.mod_ged_callable) + + # Check fitted array attributes are restored + assert_array_almost_equal(xdawn.filters_, xdawn_loaded.filters_) + assert_array_almost_equal(xdawn.patterns_, xdawn_loaded.patterns_) + + # Check scalar/param attributes + assert xdawn.n_components == xdawn_loaded.n_components + assert xdawn.reg == xdawn_loaded.reg + assert xdawn.rank == xdawn_loaded.rank + assert xdawn.restr_type == xdawn_loaded.restr_type + + # Check transform output matches + X_orig = xdawn.transform(X) + X_loaded = xdawn_loaded.transform(X) + assert_array_almost_equal(X_orig, X_loaded) + + with pytest.raises(FileExistsError): + xdawn.save(fname) + xdawn.save(fname, overwrite=True) + + # Check that loading an HDF5 file with missing keys raises an error + bad_fname = tmp_path / "bad_xdawn.h5" + h5io.write_hdf5(bad_fname, dict(foo="bar"), title="mnepython", slash="replace") + with pytest.raises(ValueError, match="missing required keys"): + read_xdawn_transformer(bad_fname) + + with pytest.raises(OSError, match="not found"): + read_xdawn_transformer(tmp_path / "nonexistent.h5") diff --git a/mne/decoding/xdawn.py b/mne/decoding/xdawn.py index a34d042e30a..9b51bfa966a 100644 --- a/mne/decoding/xdawn.py +++ b/mne/decoding/xdawn.py @@ -12,7 +12,14 @@ from ..decoding._covs_ged import _xdawn_estimate from ..decoding._mod_ged import _xdawn_mod from ..decoding.base import _GEDTransformer -from ..utils import _validate_type, fill_doc +from ..utils import ( + _check_fname, + _import_h5io_funcs, + _validate_type, + fill_doc, + verbose, +) +from ..utils.check import check_fname @fill_doc @@ -123,6 +130,70 @@ def __sklearn_tags__(self): tags.target_tags.required = True return tags + def __getstate__(self): + """Prepare state for serialization.""" + state = self.__dict__.copy() + state.pop("cov_callable", None) + state.pop("mod_ged_callable", None) + return state + + _required_state_keys = ( + "n_components", + "reg", + "signal_cov", + "cov_method_params", + "restr_type", + "info", + "rank", + ) + + def __setstate__(self, state): + """Restore state from serialization.""" + missing = [k for k in self._required_state_keys if k not in state] + if missing: + raise ValueError( + f"State dict is missing required keys: {missing}. " + "The state may be from an incompatible version of MNE." + ) + if state["info"] is not None: + state["info"] = Info(**state["info"]) + self.__dict__.update(state) + self.cov_callable = partial( + _xdawn_estimate, + reg=self.reg, + cov_method_params=self.cov_method_params, + R=self.signal_cov, + info=self.info, + rank=self.rank, + ) + self.mod_ged_callable = _xdawn_mod + + @verbose + def save(self, fname, *, overwrite=False, verbose=None): + """Save the XdawnTransformer object to disk (in HDF5 format). + + Parameters + ---------- + fname : path-like + The file path to save to. Should end with ``'.h5'`` or ``'.hdf5'``. + %(overwrite)s + %(verbose)s + + See Also + -------- + mne.decoding.read_xdawn_transformer + """ + _, write_hdf5 = _import_h5io_funcs() + check_fname(fname, "xdawn_transformer", (".h5", ".hdf5")) + fname = _check_fname(fname, overwrite=overwrite, verbose=verbose) + write_hdf5( + fname, + self.__getstate__(), + overwrite=overwrite, + title="mnepython", + slash="replace", + ) + def _validate_params(self, X): _validate_type(self.n_components, int, "n_components") @@ -211,3 +282,30 @@ def inverse_transform(self, X): pick_patterns = self._subset_multi_components(name="patterns") # Transform return np.dot(pick_patterns.T, X).transpose(1, 0, 2) + + +def read_xdawn_transformer(fname): + """Load a saved :class:`~mne.decoding.XdawnTransformer` object from disk. + + Parameters + ---------- + fname : path-like + Path to an XdawnTransformer file in HDF5 format, which should end with + ``.h5`` or ``.hdf5``. + + Returns + ------- + xdawn_transformer : instance of :class:`~mne.decoding.XdawnTransformer` + The loaded XdawnTransformer object with all fitted attributes restored. + + See Also + -------- + mne.decoding.XdawnTransformer.save + """ + read_hdf5, _ = _import_h5io_funcs() + _validate_type(fname, "path-like", "fname") + fname = _check_fname(fname=fname, overwrite="read", must_exist=False) + state = read_hdf5(fname, title="mnepython", slash="replace") + xdawn_transformer = object.__new__(XdawnTransformer) + xdawn_transformer.__setstate__(state) + return xdawn_transformer From cfa26c4fbb68ef1837e6d2e0c317d244f471ba99 Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Tue, 10 Mar 2026 05:40:45 +0000 Subject: [PATCH 8/8] extending framework to other classes --- doc/changes/dev/13718.other.rst | 2 +- mne/decoding/ssd.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/doc/changes/dev/13718.other.rst b/doc/changes/dev/13718.other.rst index b8cf9fd84cd..230eca004bd 100644 --- a/doc/changes/dev/13718.other.rst +++ b/doc/changes/dev/13718.other.rst @@ -1 +1 @@ -Add native save and read functionality for :class:`mne.decoding.SSD` and :class:`mne.decoding.SPoC` objects, by `Aniket Singh Yadav`_. \ No newline at end of file +Add native save and read functionality for :class:`mne.decoding.SSD`, :class:`mne.decoding.CSP`, :class:`mne.decoding.SPoC`, and :class:`mne.decoding.XdawnTransformer` objects, by `Aniket Singh Yadav`_. \ No newline at end of file diff --git a/mne/decoding/ssd.py b/mne/decoding/ssd.py index 021e2621b8a..a267b42cd2d 100644 --- a/mne/decoding/ssd.py +++ b/mne/decoding/ssd.py @@ -214,7 +214,8 @@ def __setstate__(self, state): f"State dict is missing required keys: {missing}. " "The state may be from an incompatible version of MNE." ) - state["info"] = Info(**state["info"]) + if isinstance(state["info"], dict): + state["info"] = Info(**state["info"]) self.__dict__.update(state) self.cov_callable, self.mod_ged_callable = _create_callables( reg=self.reg,