Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/osekit/core/audio_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,12 @@ def _data_from_dict(cls, dictionary: dict) -> list[AudioData]:
The list of deserialized ``AudioData`` objects.

"""
return [AudioData.from_dict(data) for data in dictionary.values()]
ads = []
for name, value in dictionary.items():
ad = AudioData.from_dict(value)
ad.name = name
ads.append(ad)
return ads

@classmethod
def from_folder( # noqa: PLR0913
Expand Down
11 changes: 5 additions & 6 deletions src/osekit/core/base_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,18 +161,17 @@ def write(self, folder: Path, *, link: bool = False) -> None:
"""Abstract method for writing data to file."""

@abstractmethod
def link(self, folder: Path) -> None:
"""Abstract method for linking data to a file in a given folder.
def link(self, file: Path) -> None:
"""Abstract method for linking data to a file in a given file.

Linking is intended for data objects that have been written to disk.
After linking the data to the written file, it will have a single
item that matches the File properties.
The folder should contain a file named as ``str(self).extension``.
item that matches the ``File`` properties.

Parameters
----------
folder: Path
Folder in which is the file to which the ``BaseData`` instance should be linked.
file: Path
File to which the ``BaseData`` instance should be linked.

"""

Expand Down
9 changes: 9 additions & 0 deletions src/osekit/core/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,15 @@ def folder(self) -> Path:
else next(iter(file.path.parent for file in self.files), None)
)

@property
def data(self) -> list[TData]:
"""List of Data contained in this Dataset."""
return sorted(self._data, key=lambda d: d.begin)

@data.setter
def data(self, data: list[TData]) -> None:
self._data = data

@folder.setter
def folder(self, folder: Path) -> None:
"""Set the folder in which the dataset files might be written.
Expand Down
5 changes: 4 additions & 1 deletion src/osekit/core/base_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,10 @@ def from_dict(cls: type[Self], serialized: dict) -> type[Self]:
path = serialized["path"]
return cls(
path=path,
strptime_format=TIMESTAMP_FORMATS_EXPORTED_FILES,
begin=strptime_from_text(
text=serialized["begin"],
datetime_template=TIMESTAMP_FORMATS_EXPORTED_FILES,
),
)

def __hash__(self) -> int:
Expand Down
19 changes: 7 additions & 12 deletions src/osekit/core/spectro_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@
from pandas import Timedelta
from scipy.signal import ShortTimeFFT, welch

from osekit.config import (
TIMESTAMP_FORMATS_EXPORTED_FILES,
)
from osekit.core.audio_data import AudioData
from osekit.core.base_data import BaseData, TFile
from osekit.core.spectro_file import SpectroFile
Expand Down Expand Up @@ -602,27 +599,25 @@ def write(
timestamps="_".join(timestamps),
)
if link:
self.link(folder=folder)
self.link(file=folder / f"{self}.npz")

def link(self, folder: Path) -> None:
"""Link the ``SpectroData`` to a ``SpectroFile`` in the folder.
def link(self, file: Path) -> None:
"""Link the ``SpectroData`` to a ``SpectroFile``.

The given folder should contain a file named ``"str(self).npz"``.
Linking is intended for ``SpectroData`` objects that have already been
written to disk.
After linking, the ``SpectroData`` will have a single item with the same
properties of the target ``SpectroFile``.

Parameters
----------
folder: Path
Folder in which is located the ``SpectroFile`` to which the ``SpectroData``
instance should be linked.
file: Path
File to which the ``SpectroData`` instance should be linked.

"""
file = SpectroFile(
path=folder / f"{self}.npz",
strptime_format=TIMESTAMP_FORMATS_EXPORTED_FILES,
path=file,
begin=self.begin,
)
self.items = SpectroData.from_files([file]).items

Expand Down
11 changes: 6 additions & 5 deletions src/osekit/core/spectro_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,18 +504,19 @@ def from_dict(cls, dictionary: dict) -> SpectroDataset:
)
for name, sft in dictionary["sft"].items()
]
sd = [
cls.data_cls.from_dict(
sds = []
for name, params in dictionary["data"].items():
sd = cls.data_cls.from_dict(
params,
sft=next(sft for sft, linked_data in sfts if name in linked_data),
)
for name, params in dictionary["data"].items()
]
sd.name = name
sds.append(sd)
scale = dictionary["scale"]
if dictionary["scale"] is not None:
scale = Scale.from_dict_value(scale)
return cls(
data=sd,
data=sds,
name=dictionary["name"],
suffix=dictionary["suffix"],
folder=Path(dictionary["folder"]),
Expand Down
24 changes: 18 additions & 6 deletions tests/test_public_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,11 +618,11 @@ def test_serialization(
id="spectro_only",
),
pytest.param(
OutputType.SPECTROGRAM,
OutputType.SPECTRUM,
id="spectrum_only",
),
pytest.param(
OutputType.SPECTROGRAM,
OutputType.SPECTRUM | OutputType.SPECTROGRAM,
id="both_spectral_flags",
),
],
Expand Down Expand Up @@ -1169,8 +1169,8 @@ def test_prepare_spectro(


def test_edit_transform_before_run(
tmp_path: pytest.fixture,
audio_files: pytest.fixture,
tmp_path: Path,
audio_files: None,
) -> None:
project = Project(
folder=tmp_path,
Expand All @@ -1181,10 +1181,11 @@ def test_edit_transform_before_run(
project.build()

transform = Transform(
output_type=OutputType.AUDIO | OutputType.SPECTROGRAM,
output_type=OutputType.AUDIO | OutputType.SPECTRUM | OutputType.SPECTROGRAM,
data_duration=project.origin_dataset.duration / 2,
name="original_transform",
sample_rate=24_000,
v_lim=(0.0, 120.0),
fft=ShortTimeFFT(win=hamming(1024), hop=1024, fs=24_000),
)

Expand All @@ -1205,7 +1206,14 @@ def test_edit_transform_before_run(
ads.data = new_data
ads.normalization = new_normalization

project.run(transform, audio_dataset=ads)
# Spectro edits
new_v_lim = (50.0, 100.0)
sds = project.prepare_spectro(transform=transform, audio_dataset=ads)
sds.v_lim = new_v_lim
for idx, sd in enumerate(sds.data):
sd.name = str(idx)

project.run(transform, audio_dataset=ads, spectro_dataset=sds)

# New ads name
assert (project.folder / "data" / "audio" / ads.name).exists()
Expand Down Expand Up @@ -1234,6 +1242,10 @@ def test_edit_transform_before_run(
# Instrument has been edited
assert output_ads.instrument.end_to_end_db == new_instrument.end_to_end_db

# Spectro data have been edited
assert output_sds.v_lim == new_v_lim
assert all(sd.name == str(i) for i, sd in enumerate(output_sds.data))


def test_delete_output_dataset(
tmp_path: pytest.fixture,
Expand Down
13 changes: 9 additions & 4 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,10 @@ def test_audio_dataset_serialization(
name=name,
)

# NO TIMESTAMP IN THE NAME SHOULD STILL BE DESERIALIZED PROPERLY
for idx, ad in list(enumerate(ads.data))[::2]:
ad.name = str(idx)

assert ads.begin == begin

if type(sample_rate) is list:
Expand Down Expand Up @@ -308,10 +312,11 @@ def test_audio_dataset_serialization(
assert ads.begin == ads2.begin
assert ads.normalization == ads2.normalization

assert all(
np.array_equal(ad.get_value(), ad2.get_value())
for ad, ad2 in zip(ads.data, ads2.data, strict=False)
)
zipped = zip(ads.data, ads2.data, strict=True)

assert all(ad1.name == ad2.name for ad1, ad2 in zipped)

assert all(np.array_equal(ad.get_value(), ad2.get_value()) for ad, ad2 in zipped)


@pytest.mark.parametrize(
Expand Down