diff --git a/doc/api/decoding.rst b/doc/api/decoding.rst index f8f2257825f..fe21427a093 100644 --- a/doc/api/decoding.rst +++ b/doc/api/decoding.rst @@ -41,3 +41,7 @@ Functions that assist with decoding and model fitting: cross_val_multiscore get_coef get_spatial_filter_from_estimator + read_csp + read_spoc + read_ssd + read_xdawn_transformer diff --git a/doc/changes/dev/13718.other.rst b/doc/changes/dev/13718.other.rst new file mode 100644 index 00000000000..230eca004bd --- /dev/null +++ b/doc/changes/dev/13718.other.rst @@ -0,0 +1 @@ +Add native save and read functionality for :class:`mne.decoding.SSD`, :class:`mne.decoding.CSP`, :class:`mne.decoding.SPoC`, and :class:`mne.decoding.XdawnTransformer` objects, by `Aniket Singh Yadav`_. \ No newline at end of file diff --git a/mne/decoding/__init__.pyi b/mne/decoding/__init__.pyi index 1131f1597c5..3895604bfa8 100644 --- a/mne/decoding/__init__.pyi +++ b/mne/decoding/__init__.pyi @@ -10,6 +10,10 @@ __all__ = [ "SPoC", "SSD", "Scaler", + "read_csp", + "read_spoc", + "read_ssd", + "read_xdawn_transformer", "SlidingEstimator", "SpatialFilter", "TemporalFilter", @@ -31,12 +35,12 @@ from .base import ( cross_val_multiscore, get_coef, ) -from .csp import CSP, SPoC +from .csp import CSP, SPoC, read_csp, read_spoc from .ems import EMS, compute_ems from .receptive_field import ReceptiveField from .search_light import GeneralizingEstimator, SlidingEstimator from .spatial_filter import SpatialFilter, get_spatial_filter_from_estimator -from .ssd import SSD +from .ssd import SSD, read_ssd from .time_delaying_ridge import TimeDelayingRidge from .time_frequency import TimeFrequency from .transformer import ( @@ -47,4 +51,4 @@ from .transformer import ( UnsupervisedSpatialFilter, Vectorizer, ) -from .xdawn import XdawnTransformer +from .xdawn import XdawnTransformer, read_xdawn_transformer diff --git a/mne/decoding/csp.py b/mne/decoding/csp.py index adf857d11b9..67763f01a5f 100644 --- a/mne/decoding/csp.py +++ b/mne/decoding/csp.py @@ -9,7 +9,16 @@ from .._fiff.meas_info import Info from ..defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT -from ..utils import _check_option, _validate_type, fill_doc, legacy +from ..utils import ( + _check_fname, + _check_option, + _import_h5io_funcs, + _validate_type, + fill_doc, + legacy, + verbose, +) +from ..utils.check import check_fname from ._covs_ged import _csp_estimate, _spoc_estimate from ._mod_ged import _csp_mod, _spoc_mod from .base import _GEDTransformer @@ -160,6 +169,8 @@ def __init__( R_func=sum, ) + _save_fname_type = "csp" + def __sklearn_tags__(self): """Tag the transformer.""" tags = super().__sklearn_tags__() @@ -167,6 +178,77 @@ def __sklearn_tags__(self): tags.target_tags.multi_output = True return tags + def __getstate__(self): + """Prepare state for serialization.""" + state = self.__dict__.copy() + state.pop("cov_callable", None) + state.pop("mod_ged_callable", None) + state.pop("R_func", None) + return state + + _required_state_keys = ( + "n_components", + "reg", + "log", + "cov_est", + "transform_into", + "norm_trace", + "cov_method_params", + "restr_type", + "info", + "rank", + "component_order", + ) + + def __setstate__(self, state): + """Restore state from serialization.""" + missing = [k for k in self._required_state_keys if k not in state] + if missing: + raise ValueError( + f"State dict is missing required keys: {missing}. " + "The state may be from an incompatible version of MNE." + ) + if state["info"] is not None: + state["info"] = Info(**state["info"]) + self.__dict__.update(state) + self.cov_callable = partial( + _csp_estimate, + reg=self.reg, + cov_method_params=self.cov_method_params, + cov_est=self.cov_est, + info=self.info, + rank=self.rank, + norm_trace=self.norm_trace, + ) + self.mod_ged_callable = partial(_csp_mod, evecs_order=self.component_order) + self.R_func = sum + + @verbose + def save(self, fname, *, overwrite=False, verbose=None): + """Save the CSP object to disk (in HDF5 format). + + Parameters + ---------- + fname : path-like + The file path to save to. Should end with ``'.h5'`` or ``'.hdf5'``. + %(overwrite)s + %(verbose)s + + See Also + -------- + mne.decoding.read_csp + """ + _, write_hdf5 = _import_h5io_funcs() + check_fname(fname, self._save_fname_type, (".h5", ".hdf5")) + fname = _check_fname(fname, overwrite=overwrite, verbose=verbose) + write_hdf5( + fname, + self.__getstate__(), + overwrite=overwrite, + title="mnepython", + slash="replace", + ) + def _validate_params(self, *, y): _validate_type(self.n_components, int, "n_components") if hasattr(self, "cov_est"): @@ -766,12 +848,63 @@ def __init__( delattr(self, "cov_est") delattr(self, "norm_trace") + _save_fname_type = "spoc" + def __sklearn_tags__(self): """Tag the transformer.""" tags = super().__sklearn_tags__() tags.target_tags.multi_output = False return tags + _required_state_keys = ( + "n_components", + "reg", + "log", + "transform_into", + "cov_method_params", + "restr_type", + "info", + "rank", + ) + + def __setstate__(self, state): + """Restore state from serialization.""" + missing = [k for k in self._required_state_keys if k not in state] + if missing: + raise ValueError( + f"State dict is missing required keys: {missing}. " + "The state may be from an incompatible version of MNE." + ) + if state["info"] is not None: + state["info"] = Info(**state["info"]) + self.__dict__.update(state) + self.cov_callable = partial( + _spoc_estimate, + reg=self.reg, + cov_method_params=self.cov_method_params, + info=self.info, + rank=self.rank, + ) + self.mod_ged_callable = _spoc_mod + self.R_func = None + + @verbose + def save(self, fname, *, overwrite=False, verbose=None): + """Save the SPoC object to disk (in HDF5 format). + + Parameters + ---------- + fname : path-like + The file path to save to. Should end with ``'.h5'`` or ``'.hdf5'``. + %(overwrite)s + %(verbose)s + + See Also + -------- + mne.decoding.read_spoc + """ + super().save(fname, overwrite=overwrite, verbose=verbose) + def fit(self, X, y): """Estimate the SPoC decomposition on epochs. @@ -848,3 +981,57 @@ def fit_transform(self, X, y=None, **fit_params): """ # use parent TransformerMixin method but with custom docstring return super().fit_transform(X, y=y, **fit_params) + + +def read_csp(fname): + """Load a saved :class:`~mne.decoding.CSP` object from disk. + + Parameters + ---------- + fname : path-like + Path to a CSP file in HDF5 format, which should end with ``.h5`` or + ``.hdf5``. + + Returns + ------- + csp : instance of :class:`~mne.decoding.CSP` + The loaded CSP object with all fitted attributes restored. + + See Also + -------- + mne.decoding.CSP.save + """ + read_hdf5, _ = _import_h5io_funcs() + _validate_type(fname, "path-like", "fname") + fname = _check_fname(fname=fname, overwrite="read", must_exist=False) + state = read_hdf5(fname, title="mnepython", slash="replace") + csp = object.__new__(CSP) + csp.__setstate__(state) + return csp + + +def read_spoc(fname): + """Load a saved :class:`~mne.decoding.SPoC` object from disk. + + Parameters + ---------- + fname : path-like + Path to a SPoC file in HDF5 format, which should end with ``.h5`` or + ``.hdf5``. + + Returns + ------- + spoc : instance of :class:`~mne.decoding.SPoC` + The loaded SPoC object with all fitted attributes restored. + + See Also + -------- + mne.decoding.SPoC.save + """ + read_hdf5, _ = _import_h5io_funcs() + _validate_type(fname, "path-like", "fname") + fname = _check_fname(fname=fname, overwrite="read", must_exist=False) + state = read_hdf5(fname, title="mnepython", slash="replace") + spoc = object.__new__(SPoC) + spoc.__setstate__(state) + return spoc diff --git a/mne/decoding/ssd.py b/mne/decoding/ssd.py index 41d67ece8c6..a267b42cd2d 100644 --- a/mne/decoding/ssd.py +++ b/mne/decoding/ssd.py @@ -11,15 +11,54 @@ from .._fiff.pick import _picks_to_idx from ..filter import filter_data from ..utils import ( + _check_fname, + _import_h5io_funcs, _validate_type, fill_doc, logger, + verbose, ) +from ..utils.check import check_fname from ._covs_ged import _ssd_estimate from ._mod_ged import _get_spectral_ratio, _ssd_mod from .base import _GEDTransformer +def _create_callables( + reg, + cov_method_params, + info, + picks, + n_fft, + filt_params_signal, + filt_params_noise, + rank, + sort_by_spectral_ratio, +): + """Create covariance and mod_ged callables for SSD. + + Returns + ------- + cov_callable : callable + Partial function for computing SSD covariance estimates. + mod_ged_callable : callable + Function for modifying the GED result. + """ + cov_callable = partial( + _ssd_estimate, + reg=reg, + cov_method_params=cov_method_params, + info=info, + picks=picks, + n_fft=n_fft, + filt_params_signal=filt_params_signal, + filt_params_noise=filt_params_noise, + rank=rank, + sort_by_spectral_ratio=sort_by_spectral_ratio, + ) + return cov_callable, _ssd_mod + + @fill_doc class SSD(_GEDTransformer): """ @@ -127,25 +166,95 @@ def __init__( self.restr_type = restr_type self.rank = rank - cov_callable = partial( - _ssd_estimate, - reg=reg, - cov_method_params=cov_method_params, - info=info, - picks=picks, - n_fft=n_fft, - filt_params_signal=filt_params_signal, - filt_params_noise=filt_params_noise, - rank=rank, - sort_by_spectral_ratio=sort_by_spectral_ratio, + self.cov_callable, self.mod_ged_callable = _create_callables( + reg=self.reg, + cov_method_params=self.cov_method_params, + info=self.info, + picks=self.picks, + n_fft=self.n_fft, + filt_params_signal=self.filt_params_signal, + filt_params_noise=self.filt_params_noise, + rank=self.rank, + sort_by_spectral_ratio=self.sort_by_spectral_ratio, ) super().__init__( n_components=n_components, - cov_callable=cov_callable, - mod_ged_callable=_ssd_mod, + cov_callable=self.cov_callable, + mod_ged_callable=self.mod_ged_callable, restr_type=restr_type, ) + def __getstate__(self): + """Prepare state for serialization.""" + state = self.__dict__.copy() + state.pop("cov_callable", None) + state.pop("mod_ged_callable", None) + return state + + _required_state_keys = ( + "info", + "filt_params_signal", + "filt_params_noise", + "reg", + "n_components", + "picks", + "sort_by_spectral_ratio", + "return_filtered", + "n_fft", + "cov_method_params", + "restr_type", + "rank", + ) + + def __setstate__(self, state): + """Restore state from serialization.""" + missing = [k for k in self._required_state_keys if k not in state] + if missing: + raise ValueError( + f"State dict is missing required keys: {missing}. " + "The state may be from an incompatible version of MNE." + ) + if isinstance(state["info"], dict): + state["info"] = Info(**state["info"]) + self.__dict__.update(state) + self.cov_callable, self.mod_ged_callable = _create_callables( + reg=self.reg, + cov_method_params=self.cov_method_params, + info=self.info, + picks=self.picks, + n_fft=self.n_fft, + filt_params_signal=self.filt_params_signal, + filt_params_noise=self.filt_params_noise, + rank=self.rank, + sort_by_spectral_ratio=self.sort_by_spectral_ratio, + ) + + @verbose + def save(self, fname, *, overwrite=False, verbose=None): + """Save the SSD object to disk (in HDF5 format). + + Parameters + ---------- + fname : path-like + The file path to save to. Should end with ``'.h5'`` or ``'.hdf5'``. + %(overwrite)s + %(verbose)s + + See Also + -------- + mne.decoding.read_ssd + """ + _, write_hdf5 = _import_h5io_funcs() + check_fname(fname, "ssd", (".h5", ".hdf5")) + fname = _check_fname(fname, overwrite=overwrite, verbose=verbose) + write_hdf5( + fname, + self.__getstate__(), + overwrite=overwrite, + title="mnepython", + slash="replace", + ) + def _validate_params(self, X): if isinstance(self.info, float): # special case, mostly for testing self.sfreq_ = self.info @@ -350,3 +459,30 @@ def apply(self, X): pick_patterns = self.patterns_[: self.n_components].T X = pick_patterns @ X_ssd return X + + +def read_ssd(fname): + """Load a saved :class:`~mne.decoding.SSD` object from disk. + + Parameters + ---------- + fname : path-like + Path to an SSD file in HDF5 format, which should end with ``.h5`` or + ``.hdf5``. + + Returns + ------- + ssd : instance of :class:`~mne.decoding.SSD` + The loaded SSD object with all fitted attributes restored. + + See Also + -------- + mne.decoding.SSD.save + """ + read_hdf5, _ = _import_h5io_funcs() + _validate_type(fname, "path-like", "fname") + fname = _check_fname(fname=fname, overwrite="read", must_exist=False) + state = read_hdf5(fname, title="mnepython", slash="replace") + ssd = object.__new__(SSD) + ssd.__setstate__(state) + return ssd diff --git a/mne/decoding/tests/test_csp.py b/mne/decoding/tests/test_csp.py index f3f5e16cf98..0caadd5184a 100644 --- a/mne/decoding/tests/test_csp.py +++ b/mne/decoding/tests/test_csp.py @@ -22,7 +22,7 @@ from sklearn.utils.estimator_checks import parametrize_with_checks from mne import Epochs, compute_proj_raw, io, pick_types, read_events -from mne.decoding import CSP, LinearModel, Scaler, SPoC, get_coef +from mne.decoding import CSP, LinearModel, Scaler, SPoC, get_coef, read_csp, read_spoc from mne.decoding.csp import _ajd_pham from mne.utils import catch_logging @@ -497,3 +497,114 @@ def test_sklearn_compliance(estimator, check): """Test compliance with sklearn.""" pytest.importorskip("sklearn", minversion="1.5") # TODO VERSION remove on 1.5+ check(estimator) + + +def test_csp_save_load(tmp_path): + """Test that CSP can be saved to disk and loaded back correctly.""" + h5io = pytest.importorskip("h5io") + rng = np.random.RandomState(42) + n_epochs, n_channels = 40, 10 + X = rng.randn(n_epochs, n_channels, 50) + y = np.array([0] * 20 + [1] * 20) + + csp = CSP(n_components=4) + csp.fit(X, y) + + state = csp.__getstate__() + assert "cov_callable" not in state + assert "mod_ged_callable" not in state + assert "R_func" not in state + + fname = tmp_path / "test_csp.h5" + csp.save(fname) + + csp_loaded = read_csp(fname) + + assert hasattr(csp_loaded, "cov_callable") + assert hasattr(csp_loaded, "mod_ged_callable") + assert callable(csp_loaded.cov_callable) + assert callable(csp_loaded.mod_ged_callable) + + # Check fitted array attributes are restored + assert_array_almost_equal(csp.filters_, csp_loaded.filters_) + assert_array_almost_equal(csp.patterns_, csp_loaded.patterns_) + + # Check scalar/param attributes + assert csp.n_components == csp_loaded.n_components + assert csp.component_order == csp_loaded.component_order + assert csp.transform_into == csp_loaded.transform_into + assert csp.log == csp_loaded.log + assert csp.reg == csp_loaded.reg + assert csp.norm_trace == csp_loaded.norm_trace + + # Check transform output matches + X_orig = csp.transform(X) + X_loaded = csp_loaded.transform(X) + assert_array_almost_equal(X_orig, X_loaded) + + with pytest.raises(FileExistsError): + csp.save(fname) + csp.save(fname, overwrite=True) + + # Check that loading an HDF5 file with missing keys raises an error + bad_fname = tmp_path / "bad_csp.h5" + h5io.write_hdf5(bad_fname, dict(foo="bar"), title="mnepython", slash="replace") + with pytest.raises(ValueError, match="missing required keys"): + read_csp(bad_fname) + + with pytest.raises(OSError, match="not found"): + read_csp(tmp_path / "nonexistent.h5") + + +def test_spoc_save_load(tmp_path): + """Test that SPoC can be saved to disk and loaded back correctly.""" + h5io = pytest.importorskip("h5io") + rng = np.random.RandomState(42) + X = rng.randn(10, 10, 20) + y = rng.randn(10) + + spoc = SPoC(n_components=4) + spoc.fit(X, y) + + state = spoc.__getstate__() + assert "cov_callable" not in state + assert "mod_ged_callable" not in state + assert "R_func" not in state + + fname = tmp_path / "test_spoc.h5" + spoc.save(fname) + + spoc_loaded = read_spoc(fname) + + assert hasattr(spoc_loaded, "cov_callable") + assert hasattr(spoc_loaded, "mod_ged_callable") + assert callable(spoc_loaded.cov_callable) + assert callable(spoc_loaded.mod_ged_callable) + + # Check fitted array attributes are restored + assert_array_almost_equal(spoc.filters_, spoc_loaded.filters_) + assert_array_almost_equal(spoc.patterns_, spoc_loaded.patterns_) + + # Check scalar/param attributes + assert spoc.n_components == spoc_loaded.n_components + assert spoc.transform_into == spoc_loaded.transform_into + assert spoc.log == spoc_loaded.log + assert spoc.reg == spoc_loaded.reg + + # Check transform output matches + X_orig = spoc.transform(X) + X_loaded = spoc_loaded.transform(X) + assert_array_almost_equal(X_orig, X_loaded) + + with pytest.raises(FileExistsError): + spoc.save(fname) + spoc.save(fname, overwrite=True) + + # Check that loading an HDF5 file with missing keys raises an error + bad_fname = tmp_path / "bad_spoc.h5" + h5io.write_hdf5(bad_fname, dict(foo="bar"), title="mnepython", slash="replace") + with pytest.raises(ValueError, match="missing required keys"): + read_spoc(bad_fname) + + with pytest.raises(OSError, match="not found"): + read_spoc(tmp_path / "nonexistent.h5") diff --git a/mne/decoding/tests/test_ssd.py b/mne/decoding/tests/test_ssd.py index 086413b043f..5ddd00098e5 100644 --- a/mne/decoding/tests/test_ssd.py +++ b/mne/decoding/tests/test_ssd.py @@ -18,7 +18,7 @@ from mne._fiff.pick import _picks_to_idx from mne.decoding import CSP from mne.decoding._mod_ged import _get_spectral_ratio -from mne.decoding.ssd import SSD +from mne.decoding.ssd import SSD, read_ssd from mne.filter import filter_data from mne.time_frequency import psd_array_welch @@ -570,6 +570,74 @@ def test_picks_arg(): ssd.fit(X).transform(X) +def test_ssd_save_load(tmp_path): + """Test that SSD can be saved to disk and loaded back correctly.""" + h5io = pytest.importorskip("h5io") + X, _, _ = simulate_data() + sf = 250 + n_channels = X.shape[0] + info = create_info(ch_names=n_channels, sfreq=sf, ch_types="eeg") + n_components = 5 + + filt_params_signal = dict( + l_freq=freqs_sig[0], + h_freq=freqs_sig[1], + l_trans_bandwidth=1, + h_trans_bandwidth=1, + ) + filt_params_noise = dict( + l_freq=freqs_noise[0], + h_freq=freqs_noise[1], + l_trans_bandwidth=1, + h_trans_bandwidth=1, + ) + + ssd = SSD(info, filt_params_signal, filt_params_noise, n_components=n_components) + ssd.fit(X) + + state = ssd.__getstate__() + assert "cov_callable" not in state + assert "mod_ged_callable" not in state + + fname = tmp_path / "test_ssd.h5" + ssd.save(fname) + + ssd_loaded = read_ssd(fname) + + assert hasattr(ssd_loaded, "cov_callable") + assert hasattr(ssd_loaded, "mod_ged_callable") + assert callable(ssd_loaded.cov_callable) + assert callable(ssd_loaded.mod_ged_callable) + + # Check fitted array attributes are restored + assert_array_almost_equal(ssd.filters_, ssd_loaded.filters_) + assert_array_almost_equal(ssd.patterns_, ssd_loaded.patterns_) + + # Check scalar/param attributes + assert ssd.n_components == ssd_loaded.n_components + assert ssd.info["sfreq"] == ssd_loaded.info["sfreq"] + assert ssd.filt_params_signal == ssd_loaded.filt_params_signal + assert ssd.filt_params_noise == ssd_loaded.filt_params_noise + + # Check transform output matches + X_orig = ssd.transform(X) + X_loaded = ssd_loaded.transform(X) + assert_array_almost_equal(X_orig, X_loaded) + + with pytest.raises(FileExistsError): + ssd.save(fname) + ssd.save(fname, overwrite=True) + + # Check that loading an HDF5 file with missing keys raises an error + bad_fname = tmp_path / "bad.h5" + h5io.write_hdf5(bad_fname, dict(foo="bar"), title="mnepython", slash="replace") + with pytest.raises(ValueError, match="missing required keys"): + read_ssd(bad_fname) + + with pytest.raises(OSError, match="not found"): + read_ssd(tmp_path / "nonexistent.h5") + + def test_get_spectral_ratio(): """Test that method is the same as function in _mod_ged.py.""" X, _, _ = simulate_data() diff --git a/mne/decoding/tests/test_xdawn.py b/mne/decoding/tests/test_xdawn.py index a2936686b59..2b0839c7b40 100644 --- a/mne/decoding/tests/test_xdawn.py +++ b/mne/decoding/tests/test_xdawn.py @@ -2,12 +2,14 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. +import numpy as np import pytest +from numpy.testing import assert_array_almost_equal pytest.importorskip("sklearn") from sklearn.utils.estimator_checks import parametrize_with_checks -from mne.decoding import XdawnTransformer +from mne.decoding import XdawnTransformer, read_xdawn_transformer @pytest.mark.filterwarnings("ignore:.*Only one sample available.*") @@ -15,3 +17,57 @@ def test_sklearn_compliance(estimator, check): """Test compliance with sklearn.""" check(estimator) + + +def test_xdawn_save_load(tmp_path): + """Test that XdawnTransformer can be saved to disk and loaded correctly.""" + h5io = pytest.importorskip("h5io") + rng = np.random.RandomState(42) + n_epochs, n_channels, n_times = 40, 10, 50 + X = rng.randn(n_epochs, n_channels, n_times) + y = rng.randint(0, 2, n_epochs) + + xdawn = XdawnTransformer(n_components=2) + xdawn.fit(X, y) + + state = xdawn.__getstate__() + assert "cov_callable" not in state + assert "mod_ged_callable" not in state + + fname = tmp_path / "test_xdawn.h5" + xdawn.save(fname) + + xdawn_loaded = read_xdawn_transformer(fname) + + assert hasattr(xdawn_loaded, "cov_callable") + assert hasattr(xdawn_loaded, "mod_ged_callable") + assert callable(xdawn_loaded.cov_callable) + assert callable(xdawn_loaded.mod_ged_callable) + + # Check fitted array attributes are restored + assert_array_almost_equal(xdawn.filters_, xdawn_loaded.filters_) + assert_array_almost_equal(xdawn.patterns_, xdawn_loaded.patterns_) + + # Check scalar/param attributes + assert xdawn.n_components == xdawn_loaded.n_components + assert xdawn.reg == xdawn_loaded.reg + assert xdawn.rank == xdawn_loaded.rank + assert xdawn.restr_type == xdawn_loaded.restr_type + + # Check transform output matches + X_orig = xdawn.transform(X) + X_loaded = xdawn_loaded.transform(X) + assert_array_almost_equal(X_orig, X_loaded) + + with pytest.raises(FileExistsError): + xdawn.save(fname) + xdawn.save(fname, overwrite=True) + + # Check that loading an HDF5 file with missing keys raises an error + bad_fname = tmp_path / "bad_xdawn.h5" + h5io.write_hdf5(bad_fname, dict(foo="bar"), title="mnepython", slash="replace") + with pytest.raises(ValueError, match="missing required keys"): + read_xdawn_transformer(bad_fname) + + with pytest.raises(OSError, match="not found"): + read_xdawn_transformer(tmp_path / "nonexistent.h5") diff --git a/mne/decoding/xdawn.py b/mne/decoding/xdawn.py index a34d042e30a..9b51bfa966a 100644 --- a/mne/decoding/xdawn.py +++ b/mne/decoding/xdawn.py @@ -12,7 +12,14 @@ from ..decoding._covs_ged import _xdawn_estimate from ..decoding._mod_ged import _xdawn_mod from ..decoding.base import _GEDTransformer -from ..utils import _validate_type, fill_doc +from ..utils import ( + _check_fname, + _import_h5io_funcs, + _validate_type, + fill_doc, + verbose, +) +from ..utils.check import check_fname @fill_doc @@ -123,6 +130,70 @@ def __sklearn_tags__(self): tags.target_tags.required = True return tags + def __getstate__(self): + """Prepare state for serialization.""" + state = self.__dict__.copy() + state.pop("cov_callable", None) + state.pop("mod_ged_callable", None) + return state + + _required_state_keys = ( + "n_components", + "reg", + "signal_cov", + "cov_method_params", + "restr_type", + "info", + "rank", + ) + + def __setstate__(self, state): + """Restore state from serialization.""" + missing = [k for k in self._required_state_keys if k not in state] + if missing: + raise ValueError( + f"State dict is missing required keys: {missing}. " + "The state may be from an incompatible version of MNE." + ) + if state["info"] is not None: + state["info"] = Info(**state["info"]) + self.__dict__.update(state) + self.cov_callable = partial( + _xdawn_estimate, + reg=self.reg, + cov_method_params=self.cov_method_params, + R=self.signal_cov, + info=self.info, + rank=self.rank, + ) + self.mod_ged_callable = _xdawn_mod + + @verbose + def save(self, fname, *, overwrite=False, verbose=None): + """Save the XdawnTransformer object to disk (in HDF5 format). + + Parameters + ---------- + fname : path-like + The file path to save to. Should end with ``'.h5'`` or ``'.hdf5'``. + %(overwrite)s + %(verbose)s + + See Also + -------- + mne.decoding.read_xdawn_transformer + """ + _, write_hdf5 = _import_h5io_funcs() + check_fname(fname, "xdawn_transformer", (".h5", ".hdf5")) + fname = _check_fname(fname, overwrite=overwrite, verbose=verbose) + write_hdf5( + fname, + self.__getstate__(), + overwrite=overwrite, + title="mnepython", + slash="replace", + ) + def _validate_params(self, X): _validate_type(self.n_components, int, "n_components") @@ -211,3 +282,30 @@ def inverse_transform(self, X): pick_patterns = self._subset_multi_components(name="patterns") # Transform return np.dot(pick_patterns.T, X).transpose(1, 0, 2) + + +def read_xdawn_transformer(fname): + """Load a saved :class:`~mne.decoding.XdawnTransformer` object from disk. + + Parameters + ---------- + fname : path-like + Path to an XdawnTransformer file in HDF5 format, which should end with + ``.h5`` or ``.hdf5``. + + Returns + ------- + xdawn_transformer : instance of :class:`~mne.decoding.XdawnTransformer` + The loaded XdawnTransformer object with all fitted attributes restored. + + See Also + -------- + mne.decoding.XdawnTransformer.save + """ + read_hdf5, _ = _import_h5io_funcs() + _validate_type(fname, "path-like", "fname") + fname = _check_fname(fname=fname, overwrite="read", must_exist=False) + state = read_hdf5(fname, title="mnepython", slash="replace") + xdawn_transformer = object.__new__(XdawnTransformer) + xdawn_transformer.__setstate__(state) + return xdawn_transformer