Skip to content

Commit 4bf1dbf

Browse files
authored
[FIX] Spectrum edited name (#365)
* fix output_type params * add spectrum name edition test cases * remove strptime dependency on file deserialization * fix AudioData name deserialization * force begin-based sorting of data in BaseDataset.data
1 parent 23ef208 commit 4bf1dbf

8 files changed

Lines changed: 64 additions & 35 deletions

File tree

src/osekit/core/audio_dataset.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,12 @@ def _data_from_dict(cls, dictionary: dict) -> list[AudioData]:
165165
The list of deserialized ``AudioData`` objects.
166166
167167
"""
168-
return [AudioData.from_dict(data) for data in dictionary.values()]
168+
ads = []
169+
for name, value in dictionary.items():
170+
ad = AudioData.from_dict(value)
171+
ad.name = name
172+
ads.append(ad)
173+
return ads
169174

170175
@classmethod
171176
def from_folder( # noqa: PLR0913

src/osekit/core/base_data.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -161,18 +161,17 @@ def write(self, folder: Path, *, link: bool = False) -> None:
161161
"""Abstract method for writing data to file."""
162162

163163
@abstractmethod
164-
def link(self, folder: Path) -> None:
165-
"""Abstract method for linking data to a file in a given folder.
164+
def link(self, file: Path) -> None:
165+
"""Abstract method for linking data to a file in a given file.
166166
167167
Linking is intended for data objects that have been written to disk.
168168
After linking the data to the written file, it will have a single
169-
item that matches the File properties.
170-
The folder should contain a file named as ``str(self).extension``.
169+
item that matches the ``File`` properties.
171170
172171
Parameters
173172
----------
174-
folder: Path
175-
Folder in which is the file to which the ``BaseData`` instance should be linked.
173+
file: Path
174+
File to which the ``BaseData`` instance should be linked.
176175
177176
"""
178177

src/osekit/core/base_dataset.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,15 @@ def folder(self) -> Path:
130130
else next(iter(file.path.parent for file in self.files), None)
131131
)
132132

133+
@property
134+
def data(self) -> list[TData]:
135+
"""List of Data contained in this Dataset."""
136+
return sorted(self._data, key=lambda d: d.begin)
137+
138+
@data.setter
139+
def data(self, data: list[TData]) -> None:
140+
self._data = data
141+
133142
@folder.setter
134143
def folder(self, folder: Path) -> None:
135144
"""Set the folder in which the dataset files might be written.

src/osekit/core/base_file.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,10 @@ def from_dict(cls: type[Self], serialized: dict) -> type[Self]:
148148
path = serialized["path"]
149149
return cls(
150150
path=path,
151-
strptime_format=TIMESTAMP_FORMATS_EXPORTED_FILES,
151+
begin=strptime_from_text(
152+
text=serialized["begin"],
153+
datetime_template=TIMESTAMP_FORMATS_EXPORTED_FILES,
154+
),
152155
)
153156

154157
def __hash__(self) -> int:

src/osekit/core/spectro_data.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,6 @@
2121
from pandas import Timedelta
2222
from scipy.signal import ShortTimeFFT, welch
2323

24-
from osekit.config import (
25-
TIMESTAMP_FORMATS_EXPORTED_FILES,
26-
)
2724
from osekit.core.audio_data import AudioData
2825
from osekit.core.base_data import BaseData, TFile
2926
from osekit.core.spectro_file import SpectroFile
@@ -602,27 +599,25 @@ def write(
602599
timestamps="_".join(timestamps),
603600
)
604601
if link:
605-
self.link(folder=folder)
602+
self.link(file=folder / f"{self}.npz")
606603

607-
def link(self, folder: Path) -> None:
608-
"""Link the ``SpectroData`` to a ``SpectroFile`` in the folder.
604+
def link(self, file: Path) -> None:
605+
"""Link the ``SpectroData`` to a ``SpectroFile``.
609606
610-
The given folder should contain a file named ``"str(self).npz"``.
611607
Linking is intended for ``SpectroData`` objects that have already been
612608
written to disk.
613609
After linking, the ``SpectroData`` will have a single item with the same
614610
properties of the target ``SpectroFile``.
615611
616612
Parameters
617613
----------
618-
folder: Path
619-
Folder in which is located the ``SpectroFile`` to which the ``SpectroData``
620-
instance should be linked.
614+
file: Path
615+
File to which the ``SpectroData`` instance should be linked.
621616
622617
"""
623618
file = SpectroFile(
624-
path=folder / f"{self}.npz",
625-
strptime_format=TIMESTAMP_FORMATS_EXPORTED_FILES,
619+
path=file,
620+
begin=self.begin,
626621
)
627622
self.items = SpectroData.from_files([file]).items
628623

src/osekit/core/spectro_dataset.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -504,18 +504,19 @@ def from_dict(cls, dictionary: dict) -> SpectroDataset:
504504
)
505505
for name, sft in dictionary["sft"].items()
506506
]
507-
sd = [
508-
cls.data_cls.from_dict(
507+
sds = []
508+
for name, params in dictionary["data"].items():
509+
sd = cls.data_cls.from_dict(
509510
params,
510511
sft=next(sft for sft, linked_data in sfts if name in linked_data),
511512
)
512-
for name, params in dictionary["data"].items()
513-
]
513+
sd.name = name
514+
sds.append(sd)
514515
scale = dictionary["scale"]
515516
if dictionary["scale"] is not None:
516517
scale = Scale.from_dict_value(scale)
517518
return cls(
518-
data=sd,
519+
data=sds,
519520
name=dictionary["name"],
520521
suffix=dictionary["suffix"],
521522
folder=Path(dictionary["folder"]),

tests/test_public_api.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -618,11 +618,11 @@ def test_serialization(
618618
id="spectro_only",
619619
),
620620
pytest.param(
621-
OutputType.SPECTROGRAM,
621+
OutputType.SPECTRUM,
622622
id="spectrum_only",
623623
),
624624
pytest.param(
625-
OutputType.SPECTROGRAM,
625+
OutputType.SPECTRUM | OutputType.SPECTROGRAM,
626626
id="both_spectral_flags",
627627
),
628628
],
@@ -1169,8 +1169,8 @@ def test_prepare_spectro(
11691169

11701170

11711171
def test_edit_transform_before_run(
1172-
tmp_path: pytest.fixture,
1173-
audio_files: pytest.fixture,
1172+
tmp_path: Path,
1173+
audio_files: None,
11741174
) -> None:
11751175
project = Project(
11761176
folder=tmp_path,
@@ -1181,10 +1181,11 @@ def test_edit_transform_before_run(
11811181
project.build()
11821182

11831183
transform = Transform(
1184-
output_type=OutputType.AUDIO | OutputType.SPECTROGRAM,
1184+
output_type=OutputType.AUDIO | OutputType.SPECTRUM | OutputType.SPECTROGRAM,
11851185
data_duration=project.origin_dataset.duration / 2,
11861186
name="original_transform",
11871187
sample_rate=24_000,
1188+
v_lim=(0.0, 120.0),
11881189
fft=ShortTimeFFT(win=hamming(1024), hop=1024, fs=24_000),
11891190
)
11901191

@@ -1205,7 +1206,14 @@ def test_edit_transform_before_run(
12051206
ads.data = new_data
12061207
ads.normalization = new_normalization
12071208

1208-
project.run(transform, audio_dataset=ads)
1209+
# Spectro edits
1210+
new_v_lim = (50.0, 100.0)
1211+
sds = project.prepare_spectro(transform=transform, audio_dataset=ads)
1212+
sds.v_lim = new_v_lim
1213+
for idx, sd in enumerate(sds.data):
1214+
sd.name = str(idx)
1215+
1216+
project.run(transform, audio_dataset=ads, spectro_dataset=sds)
12091217

12101218
# New ads name
12111219
assert (project.folder / "data" / "audio" / ads.name).exists()
@@ -1234,6 +1242,10 @@ def test_edit_transform_before_run(
12341242
# Instrument has been edited
12351243
assert output_ads.instrument.end_to_end_db == new_instrument.end_to_end_db
12361244

1245+
# Spectro data have been edited
1246+
assert output_sds.v_lim == new_v_lim
1247+
assert all(sd.name == str(i) for i, sd in enumerate(output_sds.data))
1248+
12371249

12381250
def test_delete_output_dataset(
12391251
tmp_path: pytest.fixture,

tests/test_serialization.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,10 @@ def test_audio_dataset_serialization(
280280
name=name,
281281
)
282282

283+
# NO TIMESTAMP IN THE NAME SHOULD STILL BE DESERIALIZED PROPERLY
284+
for idx, ad in list(enumerate(ads.data))[::2]:
285+
ad.name = str(idx)
286+
283287
assert ads.begin == begin
284288

285289
if type(sample_rate) is list:
@@ -308,10 +312,11 @@ def test_audio_dataset_serialization(
308312
assert ads.begin == ads2.begin
309313
assert ads.normalization == ads2.normalization
310314

311-
assert all(
312-
np.array_equal(ad.get_value(), ad2.get_value())
313-
for ad, ad2 in zip(ads.data, ads2.data, strict=False)
314-
)
315+
zipped = zip(ads.data, ads2.data, strict=True)
316+
317+
assert all(ad1.name == ad2.name for ad1, ad2 in zipped)
318+
319+
assert all(np.array_equal(ad.get_value(), ad2.get_value()) for ad, ad2 in zipped)
315320

316321

317322
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)