Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions disruption_py/settings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
"""Package initialization for the settings module."""

from .log_settings import LogSettings
from .output_setting import (
OutputSetting,
OutputSettingParams,
)
from .output_setting import OutputSetting, OutputSettingParams
from .retrieval_settings import RetrievalSettings
from .shotlist_setting import DatabaseShotlistSetting, FileShotlistSetting
from .time_setting import TimeSetting, TimeSettingParams
Expand Down
18 changes: 16 additions & 2 deletions disruption_py/settings/output_setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,13 @@ def get_results(self) -> OutputSingleType:
xr.Dataset | xr.DataTree | pd.DataFrame
The resulting object.
"""
logger.debug("Concatenating {tot} shots.", tot=len(self.results))
logger.debug("Concatenating {tot:,} shots...", tot=len(self.results))
took = -time.time()
self.result = self.concat()
took += time.time()
logger.info(
"Concatenated {tot:,} shots in {sec:.3f}s.", tot=len(self.results), sec=took
)
self.results = {}
return self.result

Expand Down Expand Up @@ -319,7 +324,16 @@ def concat(self) -> xr.Dataset:
if not self.results:
logger.critical("Nothing to concatenate!")
return xr.Dataset()
return xr.concat(self.results.values(), dim="idx", combine_attrs="no_conflicts")

ds = xr.concat(self.results.values(), dim="idx", combine_attrs="no_conflicts")
if "shot" not in ds.coords or "time" not in ds.coords:
return ds

took = -time.time()
ds = ds.sortby(["shot", "time"])
took += time.time()
logger.debug("Sorted {tot:,} rows in {sec:.3f}s.", tot=len(ds.idx), sec=took)
return ds


class DataTreeOutputSetting(SingleOutputSetting):
Expand Down
5 changes: 1 addition & 4 deletions disruption_py/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@

from disruption_py.config import config
from disruption_py.core.retrieval_manager import RetrievalManager
from disruption_py.core.utils.misc import (
get_elapsed_time,
without_duplicates,
)
from disruption_py.core.utils.misc import get_elapsed_time, without_duplicates
from disruption_py.inout.mds import ProcessMDSConnection
from disruption_py.inout.sql import ShotDatabase
from disruption_py.inout.xr import XarrayConnection
Expand Down
8 changes: 4 additions & 4 deletions tests/test_output_setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ def test_output_exists(fresh_data, test_folder_m):
pd.testing.assert_frame_equal(df_out, df_dsk)

# format equivalence
xr.testing.assert_identical(ds_out, xr.concat(dict_out.values(), dim="idx"))
xr.testing.assert_identical(
ds_out, xr.concat([dt.to_dataset() for dt in dt_out.values()], dim="idx")
)
for ds_list in [dict_out.values(), [dt.to_dataset() for dt in dt_out.values()]]:
xr.testing.assert_identical(
ds_out, xr.concat(ds_list, dim="idx").sortby(["shot", "time"])
)
pd.testing.assert_frame_equal(df_out, ds_out.to_dataframe()[df_out.columns])