Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/api/decoding.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,7 @@ Functions that assist with decoding and model fitting:
cross_val_multiscore
get_coef
get_spatial_filter_from_estimator
read_csp
read_spoc
read_ssd
read_xdawn_transformer
1 change: 1 addition & 0 deletions doc/changes/dev/13718.other.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add native save and read functionality for :class:`mne.decoding.SSD`, :class:`mne.decoding.CSP`, :class:`mne.decoding.SPoC`, and :class:`mne.decoding.XdawnTransformer` objects, by `Aniket Singh Yadav`_.
10 changes: 7 additions & 3 deletions mne/decoding/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ __all__ = [
"SPoC",
"SSD",
"Scaler",
"read_csp",
"read_spoc",
"read_ssd",
"read_xdawn_transformer",
"SlidingEstimator",
"SpatialFilter",
"TemporalFilter",
Expand All @@ -31,12 +35,12 @@ from .base import (
cross_val_multiscore,
get_coef,
)
from .csp import CSP, SPoC
from .csp import CSP, SPoC, read_csp, read_spoc
from .ems import EMS, compute_ems
from .receptive_field import ReceptiveField
from .search_light import GeneralizingEstimator, SlidingEstimator
from .spatial_filter import SpatialFilter, get_spatial_filter_from_estimator
from .ssd import SSD
from .ssd import SSD, read_ssd
from .time_delaying_ridge import TimeDelayingRidge
from .time_frequency import TimeFrequency
from .transformer import (
Expand All @@ -47,4 +51,4 @@ from .transformer import (
UnsupervisedSpatialFilter,
Vectorizer,
)
from .xdawn import XdawnTransformer
from .xdawn import XdawnTransformer, read_xdawn_transformer
189 changes: 188 additions & 1 deletion mne/decoding/csp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,16 @@

from .._fiff.meas_info import Info
from ..defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT
from ..utils import _check_option, _validate_type, fill_doc, legacy
from ..utils import (
_check_fname,
_check_option,
_import_h5io_funcs,
_validate_type,
fill_doc,
legacy,
verbose,
)
from ..utils.check import check_fname
from ._covs_ged import _csp_estimate, _spoc_estimate
from ._mod_ged import _csp_mod, _spoc_mod
from .base import _GEDTransformer
Expand Down Expand Up @@ -160,13 +169,86 @@ def __init__(
R_func=sum,
)

_save_fname_type = "csp"

def __sklearn_tags__(self):
"""Tag the transformer."""
tags = super().__sklearn_tags__()
tags.target_tags.required = True
tags.target_tags.multi_output = True
return tags

def __getstate__(self):
"""Prepare state for serialization."""
state = self.__dict__.copy()
state.pop("cov_callable", None)
state.pop("mod_ged_callable", None)
state.pop("R_func", None)
return state

_required_state_keys = (
"n_components",
"reg",
"log",
"cov_est",
"transform_into",
"norm_trace",
"cov_method_params",
"restr_type",
"info",
"rank",
"component_order",
)

def __setstate__(self, state):
"""Restore state from serialization."""
missing = [k for k in self._required_state_keys if k not in state]
if missing:
raise ValueError(
f"State dict is missing required keys: {missing}. "
"The state may be from an incompatible version of MNE."
)
if state["info"] is not None:
state["info"] = Info(**state["info"])
self.__dict__.update(state)
self.cov_callable = partial(
_csp_estimate,
reg=self.reg,
cov_method_params=self.cov_method_params,
cov_est=self.cov_est,
info=self.info,
rank=self.rank,
norm_trace=self.norm_trace,
)
self.mod_ged_callable = partial(_csp_mod, evecs_order=self.component_order)
self.R_func = sum

@verbose
def save(self, fname, *, overwrite=False, verbose=None):
"""Save the CSP object to disk (in HDF5 format).

Parameters
----------
fname : path-like
The file path to save to. Should end with ``'.h5'`` or ``'.hdf5'``.
%(overwrite)s
%(verbose)s

See Also
--------
mne.decoding.read_csp
"""
_, write_hdf5 = _import_h5io_funcs()
check_fname(fname, self._save_fname_type, (".h5", ".hdf5"))
fname = _check_fname(fname, overwrite=overwrite, verbose=verbose)
write_hdf5(
fname,
self.__getstate__(),
overwrite=overwrite,
title="mnepython",
slash="replace",
)

def _validate_params(self, *, y):
_validate_type(self.n_components, int, "n_components")
if hasattr(self, "cov_est"):
Expand Down Expand Up @@ -766,12 +848,63 @@ def __init__(
delattr(self, "cov_est")
delattr(self, "norm_trace")

_save_fname_type = "spoc"

def __sklearn_tags__(self):
"""Tag the transformer."""
tags = super().__sklearn_tags__()
tags.target_tags.multi_output = False
return tags

_required_state_keys = (
"n_components",
"reg",
"log",
"transform_into",
"cov_method_params",
"restr_type",
"info",
"rank",
)

def __setstate__(self, state):
"""Restore state from serialization."""
missing = [k for k in self._required_state_keys if k not in state]
if missing:
raise ValueError(
f"State dict is missing required keys: {missing}. "
"The state may be from an incompatible version of MNE."
)
if state["info"] is not None:
state["info"] = Info(**state["info"])
self.__dict__.update(state)
self.cov_callable = partial(
_spoc_estimate,
reg=self.reg,
cov_method_params=self.cov_method_params,
info=self.info,
rank=self.rank,
)
self.mod_ged_callable = _spoc_mod
self.R_func = None

@verbose
def save(self, fname, *, overwrite=False, verbose=None):
"""Save the SPoC object to disk (in HDF5 format).

Parameters
----------
fname : path-like
The file path to save to. Should end with ``'.h5'`` or ``'.hdf5'``.
%(overwrite)s
%(verbose)s

See Also
--------
mne.decoding.read_spoc
"""
super().save(fname, overwrite=overwrite, verbose=verbose)

def fit(self, X, y):
"""Estimate the SPoC decomposition on epochs.

Expand Down Expand Up @@ -848,3 +981,57 @@ def fit_transform(self, X, y=None, **fit_params):
"""
# use parent TransformerMixin method but with custom docstring
return super().fit_transform(X, y=y, **fit_params)


def read_csp(fname):
"""Load a saved :class:`~mne.decoding.CSP` object from disk.

Parameters
----------
fname : path-like
Path to a CSP file in HDF5 format, which should end with ``.h5`` or
``.hdf5``.

Returns
-------
csp : instance of :class:`~mne.decoding.CSP`
The loaded CSP object with all fitted attributes restored.

See Also
--------
mne.decoding.CSP.save
"""
read_hdf5, _ = _import_h5io_funcs()
_validate_type(fname, "path-like", "fname")
fname = _check_fname(fname=fname, overwrite="read", must_exist=False)
state = read_hdf5(fname, title="mnepython", slash="replace")
csp = object.__new__(CSP)
csp.__setstate__(state)
return csp


def read_spoc(fname):
"""Load a saved :class:`~mne.decoding.SPoC` object from disk.

Parameters
----------
fname : path-like
Path to a SPoC file in HDF5 format, which should end with ``.h5`` or
``.hdf5``.

Returns
-------
spoc : instance of :class:`~mne.decoding.SPoC`
The loaded SPoC object with all fitted attributes restored.

See Also
--------
mne.decoding.SPoC.save
"""
read_hdf5, _ = _import_h5io_funcs()
_validate_type(fname, "path-like", "fname")
fname = _check_fname(fname=fname, overwrite="read", must_exist=False)
state = read_hdf5(fname, title="mnepython", slash="replace")
spoc = object.__new__(SPoC)
spoc.__setstate__(state)
return spoc
Loading