diff --git a/docs/releases/development.rst b/docs/releases/development.rst index f8208a0..3710844 100644 --- a/docs/releases/development.rst +++ b/docs/releases/development.rst @@ -27,3 +27,5 @@ Next release (in development) (:pr:`220`). * Add :func:`emsarray.utils.estimate_bounds_1d` function (:pr:`221`). +* Bounds variables are now included in :meth:`Convention.get_all_geometry_names()` + (:pr:`222`). diff --git a/src/emsarray/conventions/_base.py b/src/emsarray/conventions/_base.py index c54e7ab..d38bd93 100644 --- a/src/emsarray/conventions/_base.py +++ b/src/emsarray/conventions/_base.py @@ -251,9 +251,11 @@ def wind( axis : int, optional The axis number that should be wound. Optional, defaults to the last axis. + Mutually exclusive with the `linear_dimension` parameter. linear_dimension : Hashable, optional - The axis number that should be wound. + The name of the dimension in the data array that should be wound. Optional, defaults to the last dimension. + Mutually exclusive with the `axis` parameter. Returns ------- @@ -989,9 +991,11 @@ def wind( axis : int, optional The axis number that should be wound. Optional, defaults to the last axis. + Mutually exclusive with the `linear_dimension` parameter. linear_dimension : Hashable, optional - The axis number that should be wound. + The name of the dimension in the data array that should be wound. Optional, defaults to the last dimension. + Mutually exclusive with the `axis` parameter. 
Returns ------- diff --git a/src/emsarray/conventions/arakawa_c.py b/src/emsarray/conventions/arakawa_c.py index 69a67d9..998c889 100644 --- a/src/emsarray/conventions/arakawa_c.py +++ b/src/emsarray/conventions/arakawa_c.py @@ -372,7 +372,7 @@ def _make_geometry_centroid(self, grid_kind: ArakawaCGridKind) -> numpy.ndarray: return cast(numpy.ndarray, points) def get_all_geometry_names(self) -> list[Hashable]: - return [ + return utils.coordinates_plus_bounds(self.dataset, [ self.face.longitude.name, self.face.latitude.name, self.node.longitude.name, @@ -381,7 +381,7 @@ def get_all_geometry_names(self) -> list[Hashable]: self.left.latitude.name, self.back.longitude.name, self.back.latitude.name, - ] + ]) def make_clip_mask( self, diff --git a/src/emsarray/conventions/grid.py b/src/emsarray/conventions/grid.py index 2599e5b..19718fc 100644 --- a/src/emsarray/conventions/grid.py +++ b/src/emsarray/conventions/grid.py @@ -270,20 +270,10 @@ def grid_dimensions(self) -> dict[CFGridKind, Sequence[Hashable]]: def get_all_geometry_names(self) -> list[Hashable]: # Grid datasets contain latitude and longitude variables # plus optional bounds variables. 
- names = [ + return utils.coordinates_plus_bounds(self.dataset, [ self.topology.longitude_name, self.topology.latitude_name, - ] - - bounds_names: list[Hashable | None] = [ - self.topology.longitude.attrs.get('bounds', None), - self.topology.latitude.attrs.get('bounds', None), - ] - for bounds_name in bounds_names: - if bounds_name is not None and bounds_name in self.dataset.variables: - names.append(bounds_name) - - return names + ]) def drop_geometry(self) -> xarray.Dataset: dataset = super().drop_geometry() diff --git a/src/emsarray/conventions/ugrid.py b/src/emsarray/conventions/ugrid.py index 482b2ef..37cf701 100644 --- a/src/emsarray/conventions/ugrid.py +++ b/src/emsarray/conventions/ugrid.py @@ -1410,7 +1410,7 @@ def get_all_geometry_names(self) -> list[Hashable]: names.append(topology.face_x.name) if topology.face_y is not None: names.append(topology.face_y.name) - return names + return utils.coordinates_plus_bounds(self.dataset, names) def drop_geometry(self) -> xarray.Dataset: dataset = super().drop_geometry() diff --git a/src/emsarray/utils.py b/src/emsarray/utils.py index 13ac352..24664ed 100644 --- a/src/emsarray/utils.py +++ b/src/emsarray/utils.py @@ -1040,3 +1040,35 @@ def estimate_bounds_1d( dataset = dataset.set_coords(bounds_name) dataset[coordinate.name].attrs['bounds'] = bounds_name return dataset + + +def coordinates_plus_bounds(dataset: xarray.Dataset, names: list[Hashable]) -> list[Hashable]: + """ + Given a list of coordinate variable names, + return a list of all these coordinates plus the names of their bounds variables, + if such bounds exist. + + Parameters + ---------- + dataset : xarray.Dataset + The dataset with coordinate variables + names : list of Hashable + A list of coordinate variables + + Returns + ------- + list of Hashable + All of the coordinates in `names`, + plus any bounds variables named in the attributes of these coordinate variables. 
+ """ + all_names = [] + for name in names: + all_names.append(name) + data_array = dataset[name] + if 'bounds' not in data_array.attrs: + continue + bounds_name = data_array.attrs['bounds'] + if bounds_name not in dataset.variables.keys(): + continue + all_names.append(bounds_name) + return all_names diff --git a/tests/conventions/test_cfgrid1d.py b/tests/conventions/test_cfgrid1d.py index 057d070..0ce8210 100644 --- a/tests/conventions/test_cfgrid1d.py +++ b/tests/conventions/test_cfgrid1d.py @@ -375,6 +375,14 @@ def test_grid_kinds(): assert convention.default_grid_kind == CFGridKind.face +def test_get_all_geometry_names(): + dataset = make_dataset(width=3, height=3, bounds=True) + assert set(dataset.ems.get_all_geometry_names()) == { + 'lon', 'lon_bounds', + 'lat', 'lat_bounds', + } + + def test_drop_geometry(datasets: pathlib.Path): dataset = xarray.open_dataset(datasets / 'cfgrid1d.nc') diff --git a/tests/conventions/test_ugrid.py b/tests/conventions/test_ugrid.py index 04edc65..4e2284d 100644 --- a/tests/conventions/test_ugrid.py +++ b/tests/conventions/test_ugrid.py @@ -568,6 +568,33 @@ def test_grid_kinds_without_edges(): assert convention.default_grid_kind == UGridKind.face +def test_get_all_geometry_names_with_edges(): + dataset = make_dataset(width=3, make_edges=True, make_face_coordinates=True) + topology_names = dataset.ems.get_all_geometry_names() + assert set(topology_names) == { + 'Mesh2', + 'Mesh2_node_x', + 'Mesh2_node_y', + 'Mesh2_face_x', + 'Mesh2_face_y', + 'Mesh2_face_nodes', + 'Mesh2_edge_nodes', + } + + +def test_get_all_geometry_names_without_edges(): + dataset = make_dataset(width=3, make_edges=False, make_face_coordinates=True) + topology_names = dataset.ems.get_all_geometry_names() + assert set(topology_names) == { + 'Mesh2', + 'Mesh2_node_x', + 'Mesh2_node_y', + 'Mesh2_face_x', + 'Mesh2_face_y', + 'Mesh2_face_nodes', + } + + def test_drop_geometry_minimal(): dataset = make_dataset(width=3, make_edges=False, make_face_coordinates=False) 
topology = dataset.ems.topology diff --git a/tests/utils/test_xarray.py b/tests/utils/test_xarray.py index 3b7b761..0837435 100644 --- a/tests/utils/test_xarray.py +++ b/tests/utils/test_xarray.py @@ -543,3 +543,30 @@ def test_wind_dimension_renamed(): ) wound = utils.wind_dimension(data_array, ['y', 'x'], [5, 4], linear_dimension='ix') xarray.testing.assert_equal(wound, expected) + + +def test_coordinates_with_bounds(): + lat_bounds = numpy.arange(5) + lon_bounds = numpy.arange(7) + depth_bounds = numpy.arange(6) + dataset = xarray.Dataset( + { + 'lat_bounds': (('lat', 'two'), numpy.c_[lat_bounds[:-1], lat_bounds[1:]]), + 'depth_bounds': (('depth', 'two'), numpy.c_[depth_bounds[:-1], depth_bounds[1:]]), + 'temp': (('lat', 'lon'), numpy.arange(4 * 6).reshape((4, 6)), {'standard_name': 'temp'}), + }, + coords={ + # lat has bounds attribute, and lat_bounds exists + 'lat': ('lat', (lat_bounds[:-1] + lat_bounds[1:]) / 2, {'bounds': 'lat_bounds'}), + # lon has bounds attribute, but lon_bounds doesn't exist + 'lon': ('lon', (lon_bounds[:-1] + lon_bounds[1:]) / 2, {'bounds': 'lon_bounds'}), + # time doesn't have bounds attribute + 'time': ('time', pandas.date_range('2026-03-04', periods=5)), + # depth has bounds but isn't included in the list of coordinates + 'depth': ('depth', (depth_bounds[:-1] + depth_bounds[1:]) / 2, {'bounds': 'depth_bounds'}), + } + ) + + assert utils.coordinates_plus_bounds(dataset, ['lat', 'lon', 'time']) == [ + 'lat', 'lat_bounds', 'lon', 'time' + ]