Skip to content

Commit 1b4189d

Browse files
committed
Implement implicit DD version conversion in NetCDF backend
1 parent 954bdb8 commit 1b4189d

File tree

5 files changed

+117
-44
lines changed

5 files changed

+117
-44
lines changed

imas/backends/netcdf/db_entry_nc.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from imas.backends.netcdf.ids2nc import IDS2NC
1212
from imas.backends.netcdf.nc2ids import NC2IDS
1313
from imas.exception import DataEntryException, InvalidNetCDFEntry
14-
from imas.ids_convert import NBCPathMap, convert_ids
14+
from imas.ids_convert import NBCPathMap, dd_version_map_from_factories
1515
from imas.ids_factory import IDSFactory
1616
from imas.ids_toplevel import IDSToplevel
1717

@@ -123,14 +123,19 @@ def get(
123123

124124
# Load data into the destination IDS
125125
if self._ds_factory.dd_version == destination._dd_version:
126-
NC2IDS(group, destination).run()
126+
NC2IDS(group, destination, destination.metadata, None).run()
127127
else:
128-
# FIXME: implement automatic conversion using nbc_map
129-
# As a work-around: do an explicit conversion, but automatic conversion
130-
# will also be needed to implement lazy loading.
131-
ids = self._ds_factory.new(ids_name)
132-
NC2IDS(group, ids).run()
133-
convert_ids(ids, None, target=destination)
128+
# Construct the relevant NBCPathMap: the one we get from DBEntry has the reverse
129+
# mapping from what we need. The imas_core logic does the mapping from
130+
# in-memory to on-disk, while we take what is on-disk and map it to
131+
# in-memory.
132+
ddmap, source_is_older = dd_version_map_from_factories(
133+
ids_name, self._ds_factory, self._factory
134+
)
135+
nbc_map = ddmap.old_to_new if source_is_older else ddmap.new_to_old
136+
NC2IDS(
137+
group, destination, self._ds_factory.new(ids_name).metadata, nbc_map
138+
).run()
134139

135140
return destination
136141

imas/backends/netcdf/nc2ids.py

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from imas.backends.netcdf.nc_metadata import NCMetadata
99
from imas.exception import InvalidNetCDFEntry
1010
from imas.ids_base import IDSBase
11+
from imas.ids_convert import NBCPathMap
1112
from imas.ids_data_type import IDSDataType
1213
from imas.ids_defs import IDS_TIME_MODE_HOMOGENEOUS
1314
from imas.ids_metadata import IDSMetadata
@@ -70,19 +71,32 @@ def _tree_iter(
7071
class NC2IDS:
7172
"""Class responsible for reading an IDS from a NetCDF group."""
7273

73-
def __init__(self, group: netCDF4.Group, ids: IDSToplevel) -> None:
74+
def __init__(
75+
self,
76+
group: netCDF4.Group,
77+
ids: IDSToplevel,
78+
ids_metadata: IDSMetadata,
79+
nbc_map: Optional[NBCPathMap],
80+
) -> None:
7481
"""Initialize NC2IDS converter.
7582
7683
Args:
7784
group: NetCDF group that stores the IDS data.
7885
ids: Corresponding IDS toplevel to store the data in.
86+
ids_metadata: Metadata corresponding to the DD version that the data is
87+
stored in.
88+
nbc_map: Path map for implicit DD conversions, or None to skip conversion.
7989
"""
8090
self.group = group
8191
"""NetCDF Group that the IDS is stored in."""
8292
self.ids = ids
8393
"""IDS to store the data in."""
94+
self.ids_metadata = ids_metadata
95+
"""Metadata of the IDS in the DD version that the data is stored in."""
96+
self.nbc_map = nbc_map
97+
"""Path map for implicit DD conversions."""
8498

85-
self.ncmeta = NCMetadata(ids.metadata)
99+
self.ncmeta = NCMetadata(ids_metadata)
86100
"""NetCDF related metadata."""
87101
self.variables = list(group.variables)
88102
"""List of variable names stored in the netCDF group."""
@@ -114,16 +128,39 @@ def run(self) -> None:
114128
for var_name in self.variables:
115129
if var_name.endswith(":shape"):
116130
continue
117-
metadata = self.ids.metadata[var_name]
131+
metadata = self.ids_metadata[var_name]
118132

119133
if metadata.data_type is IDSDataType.STRUCTURE:
120134
continue # This only contains DD metadata we already know
121135

136+
# Handle implicit DD version conversion
137+
if self.nbc_map is None:
138+
target_metadata = metadata # no conversion
139+
elif metadata.path_string in self.nbc_map:
140+
new_path = self.nbc_map.path[metadata.path_string]
141+
if new_path is None:
142+
logging.info(
143+
"Not loading data for %s: no equivalent data structure exists "
144+
"in the target Data Dictionary version.",
145+
metadata.path_string,
146+
)
147+
continue
148+
target_metadata = self.ids.metadata[new_path]
149+
elif metadata.path_string in self.nbc_map.type_change:
150+
logging.info(
151+
"Not loading data for %s: cannot handle type changes when "
152+
"implicitly converting data to the target Data Dictionary version.",
153+
metadata.path_string,
154+
)
155+
continue
156+
else:
157+
target_metadata = metadata # no conversion required
158+
122159
var = self.group[var_name]
123160
if metadata.data_type is IDSDataType.STRUCT_ARRAY:
124161
if "sparse" in var.ncattrs():
125162
shapes = self.group[var_name + ":shape"][()]
126-
for index, node in tree_iter(self.ids, metadata):
163+
for index, node in tree_iter(self.ids, target_metadata):
127164
node.resize(shapes[index][0])
128165

129166
else:
@@ -132,7 +169,7 @@ def run(self) -> None:
132169
metadata.path_string, self.homogeneous_time
133170
)[-1]
134171
size = self.group.dimensions[dim].size
135-
for _, node in tree_iter(self.ids, metadata):
172+
for _, node in tree_iter(self.ids, target_metadata):
136173
node.resize(size)
137174

138175
continue
@@ -144,22 +181,22 @@ def run(self) -> None:
144181
if "sparse" in var.ncattrs():
145182
if metadata.ndim:
146183
shapes = self.group[var_name + ":shape"][()]
147-
for index, node in tree_iter(self.ids, metadata):
184+
for index, node in tree_iter(self.ids, target_metadata):
148185
shape = shapes[index]
149186
if shape.all():
150187
node.value = data[index + tuple(map(slice, shapes[index]))]
151188
else:
152-
for index, node in tree_iter(self.ids, metadata):
189+
for index, node in tree_iter(self.ids, target_metadata):
153190
value = data[index]
154191
if value != getattr(var, "_FillValue", None):
155192
node.value = data[index]
156193

157194
elif metadata.path_string not in self.ncmeta.aos:
158195
# Shortcut for assigning untensorized data
159-
self.ids[metadata.path] = data
196+
self.ids[target_metadata.path] = data
160197

161198
else:
162-
for index, node in tree_iter(self.ids, metadata):
199+
for index, node in tree_iter(self.ids, target_metadata):
163200
node.value = data[index]
164201

165202
def validate_variables(self) -> None:
@@ -194,7 +231,7 @@ def validate_variables(self) -> None:
194231
# Check that the DD defines this variable, and validate its metadata
195232
var = self.group[var_name]
196233
try:
197-
metadata = self.ids.metadata[var_name]
234+
metadata = self.ids_metadata[var_name]
198235
except KeyError:
199236
raise InvalidNetCDFEntry(
200237
f"Invalid variable {var_name}: no such variable exists in the "

imas/backends/netcdf/nc_validate.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,9 @@ def validate_netcdf_file(filename: str) -> None:
4747
for ids_name in ids_names:
4848
for occurrence in entry.list_all_occurrences(ids_name):
4949
group = dataset[f"{ids_name}/{occurrence}"]
50+
ids = factory.new(ids_name)
5051
try:
51-
NC2IDS(group, factory.new(ids_name)).validate_variables()
52+
NC2IDS(group, ids, ids.metadata, None).validate_variables()
5253
except InvalidNetCDFEntry as exc:
5354
occ = f":{occurrence}" if occurrence else ""
5455
raise InvalidNetCDFEntry(f"Invalid IDS {ids_name}{occ}: {exc}")

imas/test/test_nbc_change.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,11 @@
99

1010
import numpy as np
1111
import pytest
12-
1312
from imas.db_entry import DBEntry
1413
from imas.ids_convert import convert_ids
1514
from imas.ids_defs import IDS_TIME_MODE_HOMOGENEOUS, MEMORY_BACKEND
1615
from imas.ids_factory import IDSFactory
17-
from imas.test.test_helpers import (
18-
compare_children,
19-
fill_with_random_data,
20-
open_dbentry,
21-
)
16+
from imas.test.test_helpers import compare_children, fill_with_random_data, open_dbentry
2217

2318

2419
@pytest.fixture(autouse=True)
@@ -97,6 +92,23 @@ def test_nbc_0d_to_1d(caplog, requires_imas):
9792
entry_339.close()
9893

9994

95+
def test_nbc_0d_to_1d_netcdf(caplog, tmp_path):
96+
# channel/filter_spectrometer/radiance_calibration in spectrometer visible changed
97+
# from FLT_0D to FLT_1D in DD 3.39.0
98+
ids = IDSFactory("3.32.0").spectrometer_visible()
99+
ids.ids_properties.homogeneous_time = IDS_TIME_MODE_HOMOGENEOUS
100+
ids.channel.resize(1)
101+
ids.channel[0].filter_spectrometer.radiance_calibration = 1.0
102+
103+
# Test implicit conversion during get
104+
with DBEntry(str(tmp_path / "test.nc"), "x", dd_version="3.32.0") as entry_332:
105+
entry_332.put(ids)
106+
with DBEntry(str(tmp_path / "test.nc"), "r", dd_version="3.39.0") as entry_339:
107+
ids_339 = entry_339.get("spectrometer_visible") # implicit conversion
108+
assert not ids_339.channel[0].filter_spectrometer.radiance_calibration.has_value
109+
entry_339.close()
110+
111+
100112
def test_nbc_change_aos_renamed():
101113
"""Test renamed AoS in pulse_schedule: ec/antenna -> ec/launcher.
102114
@@ -272,7 +284,7 @@ def test_pulse_schedule_aos_renamed_autofill_up(backend, worker_id, tmp_path):
272284
dbentry.close()
273285

274286

275-
def test_pulse_schedule_multi_rename():
287+
def test_pulse_schedule_multi_rename(tmp_path):
276288
# Multiple renames of the same element:
277289
# DD >= 3.40+: ec/beam
278290
# DD 3.26-3.40: ec/launcher (but NBC metadata added in 3.28 only)
@@ -294,9 +306,18 @@ def test_pulse_schedule_multi_rename():
294306
ps["3.40.0"].ec.beam[0].name = name
295307

296308
for version1 in ps:
309+
ncfilename = str(tmp_path / f"{version1}.nc")
310+
with DBEntry(ncfilename, "x", dd_version=version1) as entry:
311+
entry.put(ps[version1])
312+
297313
for version2 in ps:
298314
converted = convert_ids(ps[version1], version2)
299-
compare_children(ps[version2], converted)
315+
compare_children(ps[version2].ec, converted.ec)
316+
317+
# Test with netCDF backend
318+
with DBEntry(ncfilename, "r", dd_version=version2) as entry:
319+
converted = entry.get("pulse_schedule")
320+
compare_children(ps[version2].ec, converted.ec)
300321

301322

302323
def test_autofill_save_newer(ids_name, backend, worker_id, tmp_path):

imas/test/test_nc_validation.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import netCDF4
22
import numpy as np
33
import pytest
4-
54
from imas.backends.netcdf.ids2nc import IDS2NC
65
from imas.backends.netcdf.nc2ids import NC2IDS
76
from imas.backends.netcdf.nc_validate import validate_netcdf_file
@@ -32,7 +31,8 @@ def memfile_with_ids(memfile, factory):
3231
ids.profiles_1d[0].zeff = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
3332
IDS2NC(ids, memfile).run()
3433
# This one is valid:
35-
NC2IDS(memfile, factory.core_profiles()).run()
34+
ids = factory.core_profiles()
35+
NC2IDS(memfile, ids, ids.metadata, None).run()
3636
return memfile
3737

3838

@@ -51,66 +51,75 @@ def test_invalid_homogeneous_time(memfile, factory):
5151

5252
ids = factory.core_profiles()
5353
with pytest.raises(InvalidNetCDFEntry):
54-
NC2IDS(empty_group, ids) # ids_properties.homogeneous_time does not exist
54+
# ids_properties.homogeneous_time does not exist
55+
NC2IDS(empty_group, ids, ids.metadata, None)
5556
with pytest.raises(InvalidNetCDFEntry):
56-
NC2IDS(invalid_dtype, ids)
57+
NC2IDS(invalid_dtype, ids, ids.metadata, None)
5758
with pytest.raises(InvalidNetCDFEntry):
58-
NC2IDS(invalid_shape, ids)
59+
NC2IDS(invalid_shape, ids, ids.metadata, None)
5960
with pytest.raises(InvalidNetCDFEntry):
60-
NC2IDS(invalid_value, ids)
61+
NC2IDS(invalid_value, ids, ids.metadata, None)
6162

6263

6364
def test_invalid_units(memfile_with_ids, factory):
6465
memfile_with_ids["time"].units = "hours"
66+
ids = factory.core_profiles()
6567
with pytest.raises(InvalidNetCDFEntry):
66-
NC2IDS(memfile_with_ids, factory.core_profiles()).run()
68+
NC2IDS(memfile_with_ids, ids, ids.metadata, None).run()
6769

6870

6971
def test_invalid_documentation(memfile_with_ids, factory, caplog):
72+
ids = factory.core_profiles()
7073
with caplog.at_level("WARNING"):
71-
NC2IDS(memfile_with_ids, factory.core_profiles()).run()
74+
NC2IDS(memfile_with_ids, ids, ids.metadata, None).run()
7275
assert not caplog.records
7376
# Invalid docstring logs a warning
7477
memfile_with_ids["time"].documentation = "https://en.wikipedia.org/wiki/Time"
7578
with caplog.at_level("WARNING"):
76-
NC2IDS(memfile_with_ids, factory.core_profiles()).run()
79+
NC2IDS(memfile_with_ids, ids, ids.metadata, None).run()
7780
assert len(caplog.records) == 1
7881

7982

8083
def test_invalid_dimension_name(memfile_with_ids, factory):
8184
memfile_with_ids.renameDimension("time", "T")
85+
ids = factory.core_profiles()
8286
with pytest.raises(InvalidNetCDFEntry):
83-
NC2IDS(memfile_with_ids, factory.core_profiles()).run()
87+
NC2IDS(memfile_with_ids, ids, ids.metadata, None).run()
8488

8589

8690
def test_invalid_coordinates(memfile_with_ids, factory):
8791
memfile_with_ids["profiles_1d.grid.rho_tor_norm"].coordinates = "xyz"
92+
ids = factory.core_profiles()
8893
with pytest.raises(InvalidNetCDFEntry):
89-
NC2IDS(memfile_with_ids, factory.core_profiles()).run()
94+
NC2IDS(memfile_with_ids, ids, ids.metadata, None).run()
9095

9196

9297
def test_invalid_ancillary_variables(memfile_with_ids, factory):
9398
memfile_with_ids["time"].ancillary_variables = "xyz"
99+
ids = factory.core_profiles()
94100
with pytest.raises(InvalidNetCDFEntry):
95-
NC2IDS(memfile_with_ids, factory.core_profiles()).run()
101+
NC2IDS(memfile_with_ids, ids, ids.metadata, None).run()
96102

97103

98104
def test_extra_attributes(memfile_with_ids, factory):
99105
memfile_with_ids["time"].new_attribute = [1, 2, 3]
106+
ids = factory.core_profiles()
100107
with pytest.raises(InvalidNetCDFEntry):
101-
NC2IDS(memfile_with_ids, factory.core_profiles()).run()
108+
NC2IDS(memfile_with_ids, ids, ids.metadata, None).run()
102109

103110

104111
def test_shape_array_without_data(memfile_with_ids, factory):
105112
memfile_with_ids.createVariable("profiles_1d.t_i_average:shape", int, ())
113+
ids = factory.core_profiles()
106114
with pytest.raises(InvalidNetCDFEntry):
107-
NC2IDS(memfile_with_ids, factory.core_profiles()).run()
115+
NC2IDS(memfile_with_ids, ids, ids.metadata, None).run()
108116

109117

110118
def test_shape_array_without_sparse_data(memfile_with_ids, factory):
111119
memfile_with_ids.createVariable("profiles_1d.grid.rho_tor_norm:shape", int, ())
120+
ids = factory.core_profiles()
112121
with pytest.raises(InvalidNetCDFEntry):
113-
NC2IDS(memfile_with_ids, factory.core_profiles()).run()
122+
NC2IDS(memfile_with_ids, ids, ids.metadata, None).run()
114123

115124

116125
def test_shape_array_with_invalid_dimensions(memfile_with_ids, factory):
@@ -128,7 +137,7 @@ def test_shape_array_with_invalid_dimensions(memfile_with_ids, factory):
128137
("time", "profiles_1d.grid.rho_tor_norm:i"),
129138
)
130139
with pytest.raises(InvalidNetCDFEntry):
131-
NC2IDS(memfile_with_ids, cp).run()
140+
NC2IDS(memfile_with_ids, cp, cp.metadata, None).run()
132141

133142

134143
def test_shape_array_with_invalid_dtype(memfile_with_ids, factory):
@@ -144,7 +153,7 @@ def test_shape_array_with_invalid_dtype(memfile_with_ids, factory):
144153
"profiles_1d.t_i_average:shape", float, ("time", "1D")
145154
)
146155
with pytest.raises(InvalidNetCDFEntry):
147-
NC2IDS(memfile_with_ids, cp).run()
156+
NC2IDS(memfile_with_ids, cp, cp.metadata, None).run()
148157

149158

150159
def test_validate_nc(tmpdir):

0 commit comments

Comments
 (0)