diff --git a/src/spike_beans/components.py b/src/spike_beans/components.py index b8432af..9cc3566 100755 --- a/src/spike_beans/components.py +++ b/src/spike_beans/components.py @@ -66,19 +66,21 @@ class FilterStack(base.Component): base.HasAttributes("signal")) def __init__(self): - self._filters = [] + self.filters = [] + self._filter_functions = [] self._signal = None super(FilterStack, self).__init__() def add_filter(self, filt, *args, **kwargs): # type checking if hasattr(filt, "__call__"): - self._filters.append(lambda signal: filt(signal, *args, **kwargs)) + filter_func = filt + filter_descr = {'type': filter_func.__name__} elif isinstance(filt, str): try: filter_func = filters.__getattribute__("flt" + filt) - self._filters.append( - lambda signal: filter_func(signal, *args, **kwargs)) + filter_descr = {'type': filt} + self._filter_functions.append(lambda signal: filter_func(signal, *args, **kwargs)) except AttributeError: raise AttributeError("No such method found in 'core.filters':" + " flt" + filt) @@ -86,10 +88,14 @@ def add_filter(self, filt, *args, **kwargs): raise TypeError(("Unsupported argument type: %s." % type(filt)) + "Only string or callable are accepted") + filter_descr.update({'args': args, 'extra_args': kwargs}) + self.filters.append(filter_descr) + self._filter_functions.append(lambda signal: filter_func(signal, *args, **kwargs)) + def read_signal(self): if self._signal is None: self._signal = self.raw_src.signal - for filt in self._filters: + for filt in self._filter_functions: self._signal = filt(self._signal) return self._signal diff --git a/tests/test_components.py b/tests/test_components.py index e44a1c5..aa76035 100644 --- a/tests/test_components.py +++ b/tests/test_components.py @@ -181,6 +181,22 @@ def test_filter_stack_add_filter_string(): # only tests that filtering has changed a signal ok_(np.linalg.norm(io.signal['data'] - io_filter.signal['data']) > 0.) +@with_setup(setup, teardown) +def test_filter_stack_is_json_serializable(): + io = DummySignalSource() + base.features.Provide("RawSource", io) + + def custom_filter(): pass + + io_filter = components.FilterStack() + io_filter.add_filter("LinearIIR", 800., 300.) + io_filter.add_filter(custom_filter) + filter_json = json.dumps(io_filter.filters) + filter_descr = json.loads(filter_json) + + assert filter_descr[0]['type'] == 'LinearIIR' + assert filter_descr[1]['type'] == 'custom_filter' + @with_setup(setup, teardown) def test_filter_stack_add_filter_attribute_error(): io = DummySignalSource()