
Commit 2429595

maarten-ic authored and olivhoenen committed
Refactor ids2nc, extract common tensorization logic in IDSTensorizer
Allows reuse of functionality to (partially) convert IDSs to xarray Datasets.
1 parent 36cebdf commit 2429595

2 files changed: +177 -134 lines

imas/backends/netcdf/ids2nc.py

Lines changed: 4 additions & 134 deletions

```diff
@@ -3,14 +3,11 @@
 """NetCDF IO support for IMAS-Python. Requires [netcdf] extra dependencies."""
 
 import netCDF4
-import numpy
 from packaging import version
 
-from imas.backends.netcdf.nc_metadata import NCMetadata
-from imas.backends.netcdf.iterators import indexed_tree_iter
+from imas.backends.netcdf.ids_tensorizer import SHAPE_DTYPE, IDSTensorizer, dtypes
 from imas.exception import InvalidNetCDFEntry
 from imas.ids_data_type import IDSDataType
-from imas.ids_defs import IDS_TIME_MODE_HOMOGENEOUS
 from imas.ids_toplevel import IDSToplevel
 
 default_fillvals = {
@@ -19,16 +16,9 @@
     IDSDataType.FLT: netCDF4.default_fillvals["f8"],
     IDSDataType.CPX: netCDF4.default_fillvals["f8"] * (1 + 1j),
 }
-dtypes = {
-    IDSDataType.INT: numpy.dtype(numpy.int32),
-    IDSDataType.STR: str,
-    IDSDataType.FLT: numpy.dtype(numpy.float64),
-    IDSDataType.CPX: numpy.dtype(numpy.complex128),
-}
-SHAPE_DTYPE = numpy.int32
 
 
-class IDS2NC:
+class IDS2NC(IDSTensorizer):
     """Class responsible for storing an IDS to a NetCDF file."""
 
     def __init__(self, ids: IDSToplevel, group: netCDF4.Group) -> None:
@@ -38,112 +28,18 @@ def __init__(self, ids: IDSToplevel, group: netCDF4.Group) -> None:
             ids: IDSToplevel to store in the netCDF group
             group: Empty netCDF group to store the IDS in.
         """
-        self.ids = ids
-        """IDS to store."""
+        super().__init__(ids, [])  # pass empty list: tensorize full IDS
         self.group = group
         """NetCDF Group to store the IDS in."""
 
-        self.ncmeta = NCMetadata(ids.metadata)
-        """NetCDF related metadata."""
-        self.dimension_size = {}
-        """Map dimension name to its size."""
-        self.filled_data = {}
-        """Map of IDS paths to filled data nodes."""
-        self.filled_variables = set()
-        """Set of filled IDS variables"""
-        self.homogeneous_time = (
-            ids.ids_properties.homogeneous_time == IDS_TIME_MODE_HOMOGENEOUS
-        )
-        """True iff the IDS time mode is homogeneous."""
-        self.shapes = {}
-        """Map of IDS paths to data shape arrays."""
-
     def run(self) -> None:
         """Store the IDS in the NetCDF group."""
         self.collect_filled_data()
         self.determine_data_shapes()
         self.create_dimensions()
         self.create_variables()
-        # Synchronize variables to disk
-        # This is not strictly required (automatically done by netCDF4 when needed), but
-        # by separating it we get more meaningful profiling statistics
-        self.group.sync()
         self.store_data()
 
-    def collect_filled_data(self) -> None:
-        """Collect all filled data in the IDS and determine dimension sizes.
-
-        Results are stored in :attr:`filled_data` and :attr:`dimension_size`.
-        """
-        # Initialize dictionary with all paths that could exist in this IDS
-        filled_data = {path: {} for path in self.ncmeta.paths}
-        dimension_size = {}
-        get_dimensions = self.ncmeta.get_dimensions
-
-        for aos_index, node in indexed_tree_iter(self.ids):
-            path = node.metadata.path_string
-            filled_data[path][aos_index] = node
-            ndim = node.metadata.ndim
-            if not ndim:
-                continue
-            dimensions = get_dimensions(path, self.homogeneous_time)
-            # We're only interested in the non-tensorized dimensions: [-ndim:]
-            for dim_name, size in zip(dimensions[-ndim:], node.shape):
-                dimension_size[dim_name] = max(dimension_size.get(dim_name, 0), size)
-
-        # Remove paths without data
-        self.filled_data = {path: data for path, data in filled_data.items() if data}
-        self.filled_variables = {path.replace("/", ".") for path in self.filled_data}
-        # Store dimension sizes
-        self.dimension_size = dimension_size
-
-    def determine_data_shapes(self) -> None:
-        """Determine tensorized data shapes and sparsity, save in :attr:`shapes`."""
-        get_dimensions = self.ncmeta.get_dimensions
-
-        for path, nodes_dict in self.filled_data.items():
-            metadata = self.ids.metadata[path]
-            # Structures don't have a size
-            if metadata.data_type is IDSDataType.STRUCTURE:
-                continue
-            ndim = metadata.ndim
-            dimensions = get_dimensions(path, self.homogeneous_time)
-
-            # node shape if it is completely filled
-            full_shape = tuple(self.dimension_size[dim] for dim in dimensions[-ndim:])
-
-            if len(dimensions) == ndim:
-                # Data at this path is not tensorized
-                node = nodes_dict[()]
-                sparse = node.shape != full_shape
-                if sparse:
-                    shapes = numpy.array(node.shape, dtype=SHAPE_DTYPE)
-
-            else:
-                # Data is tensorized, determine if it is homogeneously shaped
-                aos_dims = get_dimensions(self.ncmeta.aos[path], self.homogeneous_time)
-                shapes_shape = [self.dimension_size[dim] for dim in aos_dims]
-                if ndim:
-                    shapes_shape.append(ndim)
-                shapes = numpy.zeros(shapes_shape, dtype=SHAPE_DTYPE)
-
-                if ndim:  # ND types have a shape
-                    for aos_coords, node in nodes_dict.items():
-                        shapes[aos_coords] = node.shape
-                    sparse = not numpy.array_equiv(shapes, full_shape)
-
-                else:  # 0D types don't have a shape
-                    for aos_coords in nodes_dict.keys():
-                        shapes[aos_coords] = 1
-                    sparse = not shapes.all()
-                    shapes = None
-
-            if sparse:
-                self.shapes[path] = shapes
-                if ndim:
-                    # Ensure there is a pseudo-dimension f"{ndim}D" for shapes variable
-                    self.dimension_size[f"{ndim}D"] = ndim
-
     def create_dimensions(self) -> None:
         """Create netCDF dimensions."""
         for dimension, size in self.dimension_size.items():
@@ -228,14 +124,6 @@ def create_variables(self) -> None:
                     "shape is unset (i.e. filled with _Fillvalue)."
                 )
 
-    def filter_coordinates(self, path: str) -> str:
-        """Filter the coordinates list from NCMetadata to filled variables only."""
-        return " ".join(
-            coordinate
-            for coordinate in self.ncmeta.get_coordinates(path, self.homogeneous_time)
-            if coordinate in self.filled_variables
-        )
-
     def store_data(self) -> None:
         """Store data in the netCDF variables"""
         for path, nodes_dict in self.filled_data.items():
@@ -273,22 +161,4 @@ def store_data(self) -> None:
 
             else:
                 # Data is tensorized: tensorize in-memory
-                # TODO: depending on the data, tmp_var may be HUGE, we may need a more
-                # efficient assignment algorithm for large and/or irregular data
-                tmp_var = numpy.full(var.shape, default_fillvals[metadata.data_type])
-                if metadata.data_type is IDSDataType.STR:
-                    tmp_var = numpy.asarray(tmp_var, dtype=object)
-
-                # Fill tmp_var
-                if shapes is None:
-                    # Data is not sparse, so we can assign to the aos_coords
-                    for aos_coords, node in nodes_dict.items():
-                        tmp_var[aos_coords] = node.value
-                else:
-                    # Data is sparse, so we must select a slice
-                    for aos_coords, node in nodes_dict.items():
-                        tmp_var[aos_coords + tuple(map(slice, node.shape))] = node.value
-
-                # Assign data to variable
-                var[()] = tmp_var
-                del tmp_var
+                var[()] = self.tensorize(path, default_fillvals[metadata.data_type])
```
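The calling convention of IDS2NC is unchanged by this refactor: construct it with an IDS and an empty netCDF group, then call run(); collect_filled_data(), determine_data_shapes() and tensorize() are now inherited from IDSTensorizer. Below is a minimal sketch of that usage, assuming IDSFactory().new() and the [netcdf] extra as used elsewhere in IMAS-Python; the IDS contents and file name are illustrative only, and in normal use the netCDF backend creates the group and file-level metadata itself rather than IDS2NC being called directly.

```python
import netCDF4

import imas
from imas.backends.netcdf.ids2nc import IDS2NC
from imas.ids_defs import IDS_TIME_MODE_HOMOGENEOUS

# Build a small IDS to store (illustrative content only).
ids = imas.IDSFactory().new("core_profiles")
ids.ids_properties.homogeneous_time = IDS_TIME_MODE_HOMOGENEOUS
ids.time = [0.0, 1.0]

with netCDF4.Dataset("core_profiles.nc", "w") as dataset:
    group = dataset.createGroup("core_profiles")
    # run() uses the inherited collect_filled_data() and determine_data_shapes()
    # before creating netCDF dimensions/variables and storing the tensorized data.
    IDS2NC(ids, group).run()
```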
imas/backends/netcdf/ids_tensorizer.py (new file)

Lines changed: 173 additions & 0 deletions

```diff
@@ -0,0 +1,173 @@
+# This file is part of IMAS-Python.
+# You should have received the IMAS-Python LICENSE file with this project.
+"""Tensorization logic to convert IDSs to netCDF files and/or xarray Datasets."""
+
+from typing import List
+
+import numpy
+
+from imas.backends.netcdf.iterators import indexed_tree_iter
+from imas.backends.netcdf.nc_metadata import NCMetadata
+from imas.ids_data_type import IDSDataType
+from imas.ids_defs import IDS_TIME_MODE_HOMOGENEOUS
+from imas.ids_toplevel import IDSToplevel
+
+dtypes = {
+    IDSDataType.INT: numpy.dtype(numpy.int32),
+    IDSDataType.STR: str,
+    IDSDataType.FLT: numpy.dtype(numpy.float64),
+    IDSDataType.CPX: numpy.dtype(numpy.complex128),
+}
+SHAPE_DTYPE = numpy.int32
+
+
+class IDSTensorizer:
+    """Common functionality for tensorizing IDSs. Used in IDS2NC and util.to_xarray."""
+
+    def __init__(self, ids: IDSToplevel, paths_to_tensorize: List[str]) -> None:
+        """Initialize IDSTensorizer.
+
+        Args:
+            ids: IDSToplevel to store in the netCDF group
+            paths_to_tensorize: Restrict tensorization to the provided paths. If an
+                empty list is provided, all filled quantities in the IDS will be
+                tensorized.
+        """
+        self.ids = ids
+        """IDS to tensorize."""
+        self.paths_to_tensorize = paths_to_tensorize
+        """List of paths to tensorize"""
+
+        self.ncmeta = NCMetadata(ids.metadata)
+        """NetCDF related metadata."""
+        self.dimension_size = {}
+        """Map dimension name to its size."""
+        self.filled_data = {}
+        """Map of IDS paths to filled data nodes."""
+        self.filled_variables = set()
+        """Set of filled IDS variables"""
+        self.homogeneous_time = (
+            ids.ids_properties.homogeneous_time == IDS_TIME_MODE_HOMOGENEOUS
+        )
+        """True iff the IDS time mode is homogeneous."""
+        self.shapes = {}
+        """Map of IDS paths to data shape arrays."""
+
+    def collect_filled_data(self) -> None:
+        """Collect all filled data in the IDS and determine dimension sizes.
+
+        Results are stored in :attr:`filled_data` and :attr:`dimension_size`.
+        """
+        # Initialize dictionary with all paths that could exist in this IDS
+        filled_data = {path: {} for path in self.ncmeta.paths}
+        dimension_size = {}
+        get_dimensions = self.ncmeta.get_dimensions
+
+        if self.paths_to_tensorize:
+            # Restrict tensorization to provided paths
+            iterator = (
+                item
+                for path in self.paths_to_tensorize
+                for item in indexed_tree_iter(self.ids, self.ids.metadata[path])
+                if item[1].has_value  # Skip nodes without value set
+            )
+        else:
+            # Tensorize all non-empty nodes
+            iterator = indexed_tree_iter(self.ids)
+
+        for aos_index, node in iterator:
+            path = node.metadata.path_string
+            filled_data[path][aos_index] = node
+            ndim = node.metadata.ndim
+            if not ndim:
+                continue
+            dimensions = get_dimensions(path, self.homogeneous_time)
+            # We're only interested in the non-tensorized dimensions: [-ndim:]
+            for dim_name, size in zip(dimensions[-ndim:], node.shape):
+                dimension_size[dim_name] = max(dimension_size.get(dim_name, 0), size)
+
+        # Remove paths without data
+        self.filled_data = {path: data for path, data in filled_data.items() if data}
+        self.filled_variables = {path.replace("/", ".") for path in self.filled_data}
+        # Store dimension sizes
+        self.dimension_size = dimension_size
+
+    def determine_data_shapes(self) -> None:
+        """Determine tensorized data shapes and sparsity, save in :attr:`shapes`."""
+        get_dimensions = self.ncmeta.get_dimensions
+
+        for path, nodes_dict in self.filled_data.items():
+            metadata = self.ids.metadata[path]
+            # Structures don't have a size
+            if metadata.data_type is IDSDataType.STRUCTURE:
+                continue
+            ndim = metadata.ndim
+            dimensions = get_dimensions(path, self.homogeneous_time)
+
+            # node shape if it is completely filled
+            full_shape = tuple(self.dimension_size[dim] for dim in dimensions[-ndim:])
+
+            if len(dimensions) == ndim:
+                # Data at this path is not tensorized
+                node = nodes_dict[()]
+                sparse = node.shape != full_shape
+                if sparse:
+                    shapes = numpy.array(node.shape, dtype=SHAPE_DTYPE)
+
+            else:
+                # Data is tensorized, determine if it is homogeneously shaped
+                aos_dims = get_dimensions(self.ncmeta.aos[path], self.homogeneous_time)
+                shapes_shape = [self.dimension_size[dim] for dim in aos_dims]
+                if ndim:
+                    shapes_shape.append(ndim)
+                shapes = numpy.zeros(shapes_shape, dtype=SHAPE_DTYPE)
+
+                if ndim:  # ND types have a shape
+                    for aos_coords, node in nodes_dict.items():
+                        shapes[aos_coords] = node.shape
+                    sparse = not numpy.array_equiv(shapes, full_shape)
+
+                else:  # 0D types don't have a shape
+                    for aos_coords in nodes_dict.keys():
+                        shapes[aos_coords] = 1
+                    sparse = not shapes.all()
+                    shapes = None
+
+            if sparse:
+                self.shapes[path] = shapes
+                if ndim:
+                    # Ensure there is a pseudo-dimension f"{ndim}D" for shapes variable
+                    self.dimension_size[f"{ndim}D"] = ndim
+
+    def filter_coordinates(self, path: str) -> str:
+        """Filter the coordinates list from NCMetadata to filled variables only."""
+        return " ".join(
+            coordinate
+            for coordinate in self.ncmeta.get_coordinates(path, self.homogeneous_time)
+            if coordinate in self.filled_variables
+        )
+
+    def tensorize(self, path, fillvalue):
+        dimensions = self.ncmeta.get_dimensions(path, self.homogeneous_time)
+        shape = tuple(self.dimension_size[dim] for dim in dimensions)
+
+        # TODO: depending on the data, tmp_var may be HUGE, we may need a more
+        # efficient assignment algorithm for large and/or irregular data
+        tmp_var = numpy.full(shape, fillvalue)
+        if isinstance(fillvalue, str):
+            tmp_var = numpy.asarray(tmp_var, dtype=object)
+
+        shapes = self.shapes.get(path)
+        nodes_dict = self.filled_data[path]
+
+        # Fill tmp_var
+        if shapes is None:
+            # Data is not sparse, so we can assign to the aos_coords
+            for aos_coords, node in nodes_dict.items():
+                tmp_var[aos_coords] = node.value
+        else:
+            # Data is sparse, so we must select a slice
+            for aos_coords, node in nodes_dict.items():
+                tmp_var[aos_coords + tuple(map(slice, node.shape))] = node.value
+
+        return tmp_var
```
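The class docstring points at reuse in util.to_xarray, which is not part of this commit. Below is a hypothetical sketch of how the tensorizer could back a partial IDS-to-xarray conversion; the helper name paths_to_dataset and the NaN fill value are assumptions for illustration, and only the IDSTensorizer calls themselves appear in the new module.

```python
import xarray

from imas.backends.netcdf.ids_tensorizer import IDSTensorizer


def paths_to_dataset(ids, paths):
    """Tensorize the requested IDS paths into an xarray Dataset (sketch only)."""
    tensorizer = IDSTensorizer(ids, paths)  # an empty list would tensorize everything
    tensorizer.collect_filled_data()        # fills filled_data and dimension_size
    tensorizer.determine_data_shapes()      # fills shapes (sparsity information)

    data_vars = {}
    for path in paths:
        dims = tensorizer.ncmeta.get_dimensions(path, tensorizer.homogeneous_time)
        # NaN as fill value is a simplification; ids2nc.py picks a per-type fill
        # value from default_fillvals instead.
        data = tensorizer.tensorize(path, float("nan"))
        data_vars[path.replace("/", ".")] = (dims, data)
    return xarray.Dataset(data_vars)
```

A full conversion would also attach coordinate variables, for which filter_coordinates() provides the names of the filled coordinates.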
