Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/releases/development.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,5 @@ Next release (in development)
(:pr:`220`).
* Add :func:`emsarray.utils.estimate_bounds_1d` function
(:pr:`221`).
* Bounds variables are now included in :meth:`Convention.get_all_geometry_names()`
(:pr:`222`).
8 changes: 6 additions & 2 deletions src/emsarray/conventions/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,11 @@ def wind(
axis : int, optional
The axis number that should be wound.
Optional, defaults to the last axis.
Mutually exclusive with the `linear_dimension` parameter.
linear_dimension : Hashable, optional
The axis number that should be wound.
The name of the dimension in the data array that should be wound.
Optional, defaults to the last dimension.
Mutually exclusive with the `axis` parameter.

Returns
-------
Expand Down Expand Up @@ -989,9 +991,11 @@ def wind(
axis : int, optional
The axis number that should be wound.
Optional, defaults to the last axis.
Mutually exclusive with the `linear_dimension` parameter.
linear_dimension : Hashable, optional
The axis number that should be wound.
The name of the dimension in the data array that should be wound.
Optional, defaults to the last dimension.
Mutually exclusive with the `axis` parameter.

Returns
-------
Expand Down
4 changes: 2 additions & 2 deletions src/emsarray/conventions/arakawa_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ def _make_geometry_centroid(self, grid_kind: ArakawaCGridKind) -> numpy.ndarray:
return cast(numpy.ndarray, points)

def get_all_geometry_names(self) -> list[Hashable]:
return [
return utils.coordinates_plus_bounds(self.dataset, [
self.face.longitude.name,
self.face.latitude.name,
self.node.longitude.name,
Expand All @@ -381,7 +381,7 @@ def get_all_geometry_names(self) -> list[Hashable]:
self.left.latitude.name,
self.back.longitude.name,
self.back.latitude.name,
]
])

def make_clip_mask(
self,
Expand Down
14 changes: 2 additions & 12 deletions src/emsarray/conventions/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,20 +270,10 @@ def grid_dimensions(self) -> dict[CFGridKind, Sequence[Hashable]]:
def get_all_geometry_names(self) -> list[Hashable]:
# Grid datasets contain latitude and longitude variables
# plus optional bounds variables.
names = [
return utils.coordinates_plus_bounds(self.dataset, [
self.topology.longitude_name,
self.topology.latitude_name,
]

bounds_names: list[Hashable | None] = [
self.topology.longitude.attrs.get('bounds', None),
self.topology.latitude.attrs.get('bounds', None),
]
for bounds_name in bounds_names:
if bounds_name is not None and bounds_name in self.dataset.variables:
names.append(bounds_name)

return names
])

def drop_geometry(self) -> xarray.Dataset:
dataset = super().drop_geometry()
Expand Down
2 changes: 1 addition & 1 deletion src/emsarray/conventions/ugrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -1410,7 +1410,7 @@ def get_all_geometry_names(self) -> list[Hashable]:
names.append(topology.face_x.name)
if topology.face_y is not None:
names.append(topology.face_y.name)
return names
return utils.coordinates_plus_bounds(self.dataset, names)

def drop_geometry(self) -> xarray.Dataset:
dataset = super().drop_geometry()
Expand Down
32 changes: 32 additions & 0 deletions src/emsarray/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,3 +1040,35 @@ def estimate_bounds_1d(
dataset = dataset.set_coords(bounds_name)
dataset[coordinate.name].attrs['bounds'] = bounds_name
return dataset


def coordinates_plus_bounds(dataset: xarray.Dataset, names: list[Hashable]) -> list[Hashable]:
"""
Given a list of coordinate variable names,
return a list of all these coordinates plus the names of their bounds variables,
if such bounds exist.

Parameters
----------
dataset : xarray.Dataset
The dataset with coordinate variables
names : list of Hashable
A list of coordinate variables

Returns
-------
list of Hashable
All of the coordinates in `names`,
plus any bounds variables named in the attributes of these coordinate variables.
"""
all_names = []
for name in names:
all_names.append(name)
data_array = dataset[name]
if 'bounds' not in data_array.attrs:
continue
bounds_name = data_array.attrs['bounds']
if bounds_name not in dataset.variables.keys():
continue
all_names.append(bounds_name)
return all_names
8 changes: 8 additions & 0 deletions tests/conventions/test_cfgrid1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,14 @@ def test_grid_kinds():
assert convention.default_grid_kind == CFGridKind.face


def test_get_all_geometry_names():
dataset = make_dataset(width=3, height=3, bounds=True)
assert set(dataset.ems.get_all_geometry_names()) == {
'lon', 'lon_bounds',
'lat', 'lat_bounds',
}


def test_drop_geometry(datasets: pathlib.Path):
dataset = xarray.open_dataset(datasets / 'cfgrid1d.nc')

Expand Down
27 changes: 27 additions & 0 deletions tests/conventions/test_ugrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,33 @@ def test_grid_kinds_without_edges():
assert convention.default_grid_kind == UGridKind.face


def test_get_all_geometry_names_with_edges():
dataset = make_dataset(width=3, make_edges=True, make_face_coordinates=True)
topology_names = dataset.ems.get_all_geometry_names()
assert set(topology_names) == {
'Mesh2',
'Mesh2_node_x',
'Mesh2_node_y',
'Mesh2_face_x',
'Mesh2_face_y',
'Mesh2_face_nodes',
'Mesh2_edge_nodes',
}


def test_get_all_geometry_names_without_edges():
dataset = make_dataset(width=3, make_edges=False, make_face_coordinates=True)
topology_names = dataset.ems.get_all_geometry_names()
assert set(topology_names) == {
'Mesh2',
'Mesh2_node_x',
'Mesh2_node_y',
'Mesh2_face_x',
'Mesh2_face_y',
'Mesh2_face_nodes',
}


def test_drop_geometry_minimal():
dataset = make_dataset(width=3, make_edges=False, make_face_coordinates=False)
topology = dataset.ems.topology
Expand Down
27 changes: 27 additions & 0 deletions tests/utils/test_xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,3 +543,30 @@ def test_wind_dimension_renamed():
)
wound = utils.wind_dimension(data_array, ['y', 'x'], [5, 4], linear_dimension='ix')
xarray.testing.assert_equal(wound, expected)


def test_coordinates_with_bounds():
lat_bounds = numpy.arange(5)
lon_bounds = numpy.arange(7)
depth_bounds = numpy.arange(6)
dataset = xarray.Dataset(
{
'lat_bounds': (('lat', 'two'), numpy.c_[lat_bounds[:-1], lat_bounds[1:]]),
'depth_bounds': (('depth', 'two'), numpy.c_[depth_bounds[:-1], depth_bounds[1:]]),
'temp': (('lat', 'lon'), numpy.arange(4 * 6).reshape((4, 6)), {'standard_name': 'temp'}),
},
coords={
# lat has bounds attribute, and lat_bounds exists
'lat': ('lat', (lat_bounds[:-1] + lat_bounds[1:]) / 2, {'bounds': 'lat_bounds'}),
# lon has bounds atttribute, but lon_bounds doesn't exist
'lon': ('lon', (lon_bounds[:-1] + lon_bounds[1:]) / 2, {'bounds': 'lon_bounds'}),
# time doesn't have bounds attribute
'time': ('time', pandas.date_range('2026-03-04', periods=5)),
# depth has bounds but isn't included in the list of coordinates
'depth': ('depth', (depth_bounds[:-1] + depth_bounds[1:]) / 2, {'bounds': 'depth_bounds'}),
}
)

assert utils.coordinates_plus_bounds(dataset, ['lat', 'lon', 'time']) == [
'lat', 'lat_bounds', 'lon', 'time'
]
Loading