Skip to content

Commit 09fb355

Browse files
maarten-icolivhoenen
authored and committed
Implement imas.util.to_xarray
Reuses most of the tensorization and metadata logic from the netCDF export.
1 parent 2429595 commit 09fb355

File tree

4 files changed

+241
-3
lines changed

4 files changed

+241
-3
lines changed

imas/_to_xarray.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# xarray is an optional dependency, but this module won't be imported when xarray is not
2+
# available
3+
import numpy
4+
import xarray
5+
6+
from imas.ids_toplevel import IDSToplevel
7+
from imas.backends.netcdf.ids_tensorizer import IDSTensorizer
8+
from imas.ids_data_type import IDSDataType
9+
10+
# Fill value per IDS data type. ``to_xarray`` looks up the entry matching a
# node's data type and passes it as second argument to
# ``IDSTensorizer.tensorize``.
fillvals = {
    IDSDataType.INT: 1 - 2**31,  # == -2147483647
    IDSDataType.STR: "",
    IDSDataType.FLT: numpy.nan,
    IDSDataType.CPX: complex(numpy.nan, numpy.nan),  # nan + nan*j
}
16+
17+
18+
def to_xarray(ids: IDSToplevel, *paths: str) -> xarray.Dataset:
    """Build an ``xarray.Dataset`` from an IDS; see :func:`imas.util.to_xarray`.

    Args:
        ids: IDS toplevel element to convert.
        paths: Element paths to convert, using ``/`` or ``.`` as separator.

    Returns:
        A dataset with one variable per filled, non-structure path. Variables
        that are listed as coordinate of another variable are stored as dataset
        coordinates instead of data variables.

    Raises:
        TypeError: ``ids`` is not an :class:`IDSToplevel`.
        ValueError: One of ``paths`` cannot be looked up in ``ids.metadata``.
        RuntimeError: ``ids`` is lazy loaded and no ``paths`` were provided.
    """
    # Only toplevel IDS elements can be converted
    if not isinstance(ids, IDSToplevel):
        raise TypeError(
            f"to_xarray needs a toplevel IDS element as first argument, but got {ids!r}"
        )

    # The metadata lookup accepts both separators and doubles as validation, while
    # "path_string" yields the "/"-separated form that IDSTensorizer expects.
    try:
        requested = [ids.metadata[given].path_string for given in paths]
    except KeyError as exc:
        raise ValueError(str(exc)) from None

    # A lazy-loaded IDS is only converted when at least one path is requested
    if ids._lazy and not requested:
        raise RuntimeError(
            "This IDS is lazy loaded. Please provide at least one path to convert to"
            " xarray."
        )

    # Reuse the netCDF tensorizer for the data layout and the metadata
    tensorizer = IDSTensorizer(ids, requested)
    tensorizer.include_coordinate_paths()
    tensorizer.collect_filled_data()
    tensorizer.determine_data_shapes()

    # Structures and arrays of structures are not stored in the dataset
    skipped_types = (IDSDataType.STRUCTURE, IDSDataType.STRUCT_ARRAY)

    variables = {}
    coordinate_names = set()
    for ids_path in tensorizer.filled_data:
        node_meta = ids.metadata[ids_path]
        if node_meta.data_type in skipped_types:
            continue

        dims = tensorizer.ncmeta.get_dimensions(ids_path, tensorizer.homogeneous_time)
        values = tensorizer.tensorize(ids_path, fillvals[node_meta.data_type])

        attrs = {"documentation": node_meta.documentation}
        if node_meta.units:
            attrs["units"] = node_meta.units
        coordinate_attr = tensorizer.filter_coordinates(ids_path)
        if coordinate_attr:
            # The attribute is a space-separated list of variable names
            coordinate_names.update(coordinate_attr.split(" "))
            attrs["coordinates"] = coordinate_attr

        # Dataset variables are named with "." as separator
        variables[ids_path.replace("/", ".")] = (dims, values, attrs)

    # Move every variable that is used as a coordinate into the coordinates mapping
    coords = {name: variables.pop(name) for name in coordinate_names}

    return xarray.Dataset(variables, coords)

imas/backends/netcdf/ids_tensorizer.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# You should have received the IMAS-Python LICENSE file with this project.
33
"""Tensorization logic to convert IDSs to netCDF files and/or xarray Datasets."""
44

5+
from collections import deque
56
from typing import List
67

78
import numpy
@@ -53,6 +54,26 @@ def __init__(self, ids: IDSToplevel, paths_to_tensorize: List[str]) -> None:
5354
self.shapes = {}
5455
"""Map of IDS paths to data shape arrays."""
5556

57+
def include_coordinate_paths(self) -> None:
    """Extend ``self.paths_to_tensorize`` with parent and coordinate paths.

    Every ancestor of a requested path for which the netCDF metadata reports
    dimensions is added, and so are the coordinates of each path, including
    coordinates of coordinates. The resulting list has no duplicates and keeps
    the originally requested paths first.
    """
    requested = self.paths_to_tensorize
    # A queue lets coordinates of coordinates be picked up as well
    pending = deque(requested)

    # Walk up from each requested path and enqueue the ancestors that have
    # dimensions according to the netCDF metadata.
    for requested_path in requested:
        ancestor = requested_path
        while ancestor:
            ancestor = ancestor.rpartition("/")[0]
            if self.ncmeta.get_dimensions(ancestor, self.homogeneous_time):
                pending.append(ancestor)

    self.paths_to_tensorize = ordered = []
    seen = set()
    while pending:
        current = pending.popleft()
        if current in seen:
            continue  # already processed
        seen.add(current)
        ordered.append(current)
        # Coordinate names use "." as separator, paths use "/"
        pending.extend(
            name.replace(".", "/")
            for name in self.ncmeta.get_coordinates(current, self.homogeneous_time)
        )
76+
5677
def collect_filled_data(self) -> None:
5778
"""Collect all filled data in the IDS and determine dimension sizes.
5879

imas/test/test_to_xarray.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import numpy as np
2+
import pytest
3+
4+
import imas
5+
import imas.training
6+
from imas.util import to_xarray
7+
8+
pytest.importorskip("xarray")
9+
10+
11+
@pytest.fixture
def entry(requires_imas, monkeypatch):
    """Return the training database entry with the DD version pinned to 3.39.0."""
    # Use a fixed DD version so the expected values in the tests stay valid
    monkeypatch.setenv("IMAS_VERSION", "3.39.0")
    training_entry = imas.training.get_training_db_entry()
    return training_entry
15+
16+
17+
def test_to_xarray_invalid_argtype():
    """to_xarray raises TypeError when the first argument is not an IDS toplevel."""
    core_profiles = imas.IDSFactory("3.39.0").core_profiles()

    # A plain string and two children of the IDS (``time``, ``ids_properties``)
    rejected_arguments = ("test", core_profiles.time, core_profiles.ids_properties)
    for rejected in rejected_arguments:
        with pytest.raises(TypeError):
            to_xarray(rejected)
26+
27+
28+
def test_to_xarray_invalid_paths():
    """to_xarray raises ValueError naming the path that does not exist."""
    core_profiles = imas.IDSFactory("3.39.0").core_profiles()

    # (paths passed to to_xarray, pattern expected in the error message)
    cases = [
        (("xyz",), "xyz"),
        (("ids_properties/xyz",), "ids_properties/xyz"),
        (("time", "Xtime"), "Xtime"),
    ]
    for requested_paths, expected_match in cases:
        with pytest.raises(ValueError, match=expected_match):
            to_xarray(core_profiles, *requested_paths)
37+
38+
39+
def validate_trainingdb_electron_temperature_dataset(ds):
    """Assert that ``ds`` holds the training-db electron temperature profile.

    Checks the dimension sizes, the names of the data variables and coordinates,
    and a few values of the time axis and of the temperature profile.
    """
    temperature_name = "profiles_1d.electrons.temperature"
    rho_name = "profiles_1d.grid.rho_tor_norm"

    expected_sizes = {"time": 3, "profiles_1d.grid.rho_tor_norm:i": 101}
    assert ds.sizes == expected_sizes
    assert ds.data_vars.keys() == {temperature_name}
    assert ds.coords.keys() == {"time", rho_name}

    # Check that values are loaded as expected
    expected_times = [3.987222, 432.937598, 792.0]
    assert np.allclose(ds["time"], expected_times)

    second_time_slice = ds.isel(time=1)
    expected_temperatures = [17728.81703089, 17440.78020568, 17139.35431082]
    assert np.allclose(second_time_slice[temperature_name][10:13], expected_temperatures)
50+
51+
52+
def test_to_xarray_lazy_loaded(entry):
    """A lazy-loaded IDS converts only when at least one path is requested."""
    lazy_ids = entry.get("core_profiles", lazy=True)

    # Without paths the conversion is refused
    with pytest.raises(RuntimeError):
        to_xarray(lazy_ids)

    dataset = to_xarray(lazy_ids, "profiles_1d.electrons.temperature")
    validate_trainingdb_electron_temperature_dataset(dataset)
60+
61+
62+
def test_to_xarray_from_trainingdb(entry):
    """Convert the training-db core_profiles IDS fully and for selected paths."""
    temperature = "profiles_1d.electrons.temperature"
    core_profiles = entry.get("core_profiles")

    # Full conversion: pick the temperature variable out of the complete dataset
    full_dataset = to_xarray(core_profiles)
    validate_trainingdb_electron_temperature_dataset(
        full_dataset[temperature].to_dataset()
    )

    # Conversion of a single path
    single_dataset = to_xarray(core_profiles, temperature)
    validate_trainingdb_electron_temperature_dataset(single_dataset)

    # Two paths, one given with "." and one with "/" as separator
    pair_dataset = to_xarray(
        core_profiles, temperature, "profiles_1d/electrons/density"
    )
    expected_variables = {
        "profiles_1d.electrons.temperature",
        "profiles_1d.electrons.density",
    }
    assert pair_dataset.data_vars.keys() == expected_variables
79+
80+
81+
def test_to_xarray():
    """Full conversion and path-based conversion give equal datasets here."""
    ids = imas.IDSFactory("3.39.0").core_profiles()

    # Only the first of the two profiles_1d entries gets any data
    ids.profiles_1d.resize(2)
    first_profile = ids.profiles_1d[0]
    first_profile.electrons.temperature = [1.0, 2.0]
    first_profile.grid.rho_tor_norm = [0.0, 1.0]
    first_profile.time = 0.0

    # These should all be identical:
    from_full_ids = to_xarray(ids)
    from_dotted_path = to_xarray(ids, "profiles_1d.electrons.temperature")
    from_slashed_path = to_xarray(ids, "profiles_1d/electrons/temperature")
    assert from_full_ids.equals(from_dotted_path)
    assert from_dotted_path.equals(from_slashed_path)

imas/util.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
# This file is part of IMAS-Python.
22
# You should have received the IMAS-Python LICENSE file with this project.
3-
"""Collection of useful helper methods when working with IMAS-Python.
4-
"""
5-
3+
"""Collection of useful helper methods when working with IMAS-Python."""
64

75
import logging
86
import re
@@ -524,3 +522,55 @@ def get_data_dictionary_version(obj: Union[IDSBase, DBEntry, IDSFactory]) -> str
524522
if isinstance(obj, IDSBase):
525523
return obj._version
526524
raise TypeError(f"Cannot get data dictionary version of '{type(obj)}'")
525+
526+
527+
def to_xarray(ids: IDSToplevel, *paths: str) -> Any:
    """Convert an IDS to an xarray Dataset.

    Args:
        ids: An IDS toplevel element
        paths: Optional list of element paths to convert to xarray. The full IDS will be
            converted to an xarray Dataset if no paths are provided.

            Paths must not contain indices, and may use a ``/`` or a ``.`` as separator.
            For example, ``"profiles_1d(itime)/electrons/density"`` is not allowed as
            path, use ``"profiles_1d/electrons/density"`` or
            ``"profiles_1d.electrons.density"`` instead.

            Coordinates to the quantities in the requested paths will also be included
            in the xarray Dataset.

    Returns:
        An ``xarray.Dataset`` object.

    Raises:
        RuntimeError: When xarray cannot be imported. The original ``ImportError``
            is attached as the cause of this exception.

    Examples:
        .. code-block:: python

            # Convert the whole IDS to an xarray Dataset
            ds = imas.util.to_xarray(ids)

            # Convert only some elements in the IDS (including their coordinates)
            ds = imas.util.to_xarray(
                ids,
                "profiles_1d/electrons/density",
                "profiles_1d/electrons/temperature",
            )

            # Paths can be provided with "/" or "." as separator
            ds = imas.util.to_xarray(
                ids,
                "profiles_1d.electrons.density",
                "profiles_1d.electrons.temperature",
            )

    See Also:
        https://docs.xarray.dev/en/stable/generated/xarray.Dataset.html
    """
    # xarray is an optional dependency: probe for it here so that a missing
    # installation is reported with a clear message.
    try:
        import xarray  # noqa: F401
    except ImportError as exc:
        # Chain explicitly so the underlying import failure stays visible as cause
        raise RuntimeError(
            "xarray is not available, cannot convert the IDS to xarray."
        ) from exc

    # Imported late because imas._to_xarray imports xarray at module level
    from imas._to_xarray import to_xarray as _to_xarray_impl

    return _to_xarray_impl(ids, *paths)

0 commit comments

Comments
 (0)