Skip to content
3 changes: 3 additions & 0 deletions docs/history.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ History

Latest
------
- ENH: Add `convention` option to `set_options()` for future multi-convention support (pull #899)
- ENH: Add read support for Zarr spatial and proj conventions (pull #XXX)
- REF: Extract CF convention logic to `_convention/cf.py` module (pull #899)


0.21.0
Expand Down
9 changes: 9 additions & 0 deletions docs/rioxarray.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,15 @@ rioxarray.show_versions
.. autofunction:: rioxarray.show_versions


rioxarray.enum module
---------------------

.. automodule:: rioxarray.enum
:members:
:undoc-members:
:show-inheritance:


rioxarray `rio` accessors
--------------------------

Expand Down
1 change: 1 addition & 0 deletions rioxarray/_convention/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Convention modules for rioxarray."""
298 changes: 298 additions & 0 deletions rioxarray/_convention/cf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,298 @@
"""
CF (Climate and Forecasts) convention support for rioxarray.

This module provides functions for reading and writing geospatial metadata according to
the CF conventions: https://github.com/cf-convention/cf-conventions
"""
from typing import Optional, Tuple, Union

import numpy
import pyproj
import rasterio.crs
import xarray
from affine import Affine

from rioxarray._options import EXPORT_GRID_MAPPING, get_option
from rioxarray.crs import crs_from_user_input


def _find_grid_mapping(
obj: Union[xarray.Dataset, xarray.DataArray],
grid_mapping: Optional[str] = None,
) -> Optional[str]:
"""
Find the grid_mapping coordinate name.

Parameters
----------
obj : xarray.Dataset or xarray.DataArray
Object to search for grid_mapping
grid_mapping : str, optional
Explicit grid_mapping name to use

Returns
-------
str or None
The grid_mapping name, or None if not found
"""
if grid_mapping is not None:
return grid_mapping

# Try to find grid_mapping attribute on data variables
if hasattr(obj, "data_vars"):
for data_var in obj.data_vars.values():
if "grid_mapping" in data_var.attrs:
return data_var.attrs["grid_mapping"]
if "grid_mapping" in data_var.encoding:
return data_var.encoding["grid_mapping"]

if hasattr(obj, "attrs") and "grid_mapping" in obj.attrs:
return obj.attrs["grid_mapping"]

if hasattr(obj, "encoding") and "grid_mapping" in obj.encoding:
return obj.encoding["grid_mapping"]

return None


def read_crs(
obj: Union[xarray.Dataset, xarray.DataArray], grid_mapping: Optional[str] = None
) -> Optional[rasterio.crs.CRS]:
"""
Read CRS from CF conventions.

Parameters
----------
obj : xarray.Dataset or xarray.DataArray
Object to read CRS from
grid_mapping : str, optional
Name of the grid_mapping coordinate variable

Returns
-------
rasterio.crs.CRS or None
CRS object, or None if not found
"""
grid_mapping = _find_grid_mapping(obj, grid_mapping)

if grid_mapping is not None:
try:
grid_mapping_coord = obj.coords[grid_mapping]

# Look in wkt attributes first for performance
for crs_attr in ("spatial_ref", "crs_wkt"):
try:
return crs_from_user_input(grid_mapping_coord.attrs[crs_attr])
except KeyError:
pass

# Look in grid_mapping CF attributes
try:
return pyproj.CRS.from_cf(grid_mapping_coord.attrs)
except (KeyError, pyproj.exceptions.CRSError):
pass
except KeyError:
# grid_mapping coordinate doesn't exist, fall through to attrs check
pass

# Fallback: look in attrs for 'crs'
try:
return crs_from_user_input(obj.attrs["crs"])
except KeyError:
return None


def read_transform(
obj: Union[xarray.Dataset, xarray.DataArray], grid_mapping: Optional[str] = None
) -> Optional[Affine]:
"""
Read transform from CF conventions (GeoTransform attribute).

Parameters
----------
obj : xarray.Dataset or xarray.DataArray
Object to read transform from
grid_mapping : str, optional
Name of the grid_mapping coordinate variable

Returns
-------
affine.Affine or None
Transform object, or None if not found
"""
grid_mapping = _find_grid_mapping(obj, grid_mapping)

if grid_mapping is not None:
try:
transform = numpy.fromstring(
obj.coords[grid_mapping].attrs["GeoTransform"], sep=" "
)
# Calling .tolist() to assure the arguments are Python float and JSON serializable
return Affine.from_gdal(*transform.tolist())
except KeyError:
pass

# Fallback: look in attrs for 'transform'
try:
return Affine(*obj.attrs["transform"][:6])
except KeyError:
pass

return None


def read_spatial_dimensions(
obj: Union[xarray.Dataset, xarray.DataArray],
) -> Optional[Tuple[str, str]]:
"""
Read spatial dimensions from CF conventions.

This function detects spatial dimensions based on:
1. Standard dimension names ('x'/'y', 'longitude'/'latitude')
2. CF coordinate attributes ('axis', 'standard_name')

Parameters
----------
obj : xarray.Dataset or xarray.DataArray
Object to read spatial dimensions from

Returns
-------
tuple of (y_dim, x_dim) or None
Tuple of dimension names, or None if not found
"""
x_dim = None
y_dim = None

# Check standard dimension names
if "x" in obj.dims and "y" in obj.dims:
return "y", "x"
elif "longitude" in obj.dims and "latitude" in obj.dims:
return "latitude", "longitude"

# Look for coordinates with CF attributes
for coord in obj.coords:
# Make sure to only look in 1D coordinates
# that has the same dimension name as the coordinate
if obj.coords[coord].dims != (coord,):
continue
if (obj.coords[coord].attrs.get("axis", "").upper() == "X") or (
obj.coords[coord].attrs.get("standard_name", "").lower()
in ("longitude", "projection_x_coordinate")
):
x_dim = coord
elif (obj.coords[coord].attrs.get("axis", "").upper() == "Y") or (
obj.coords[coord].attrs.get("standard_name", "").lower()
in ("latitude", "projection_y_coordinate")
):
y_dim = coord

if x_dim is not None and y_dim is not None:
return y_dim, x_dim

return None


def write_crs(
obj: Union[xarray.Dataset, xarray.DataArray],
crs: rasterio.crs.CRS,
grid_mapping_name: str,
inplace: bool = True,
) -> Union[xarray.Dataset, xarray.DataArray]:
"""
Write CRS using CF conventions.

Parameters
----------
obj : xarray.Dataset or xarray.DataArray
Object to write CRS to
crs : rasterio.crs.CRS
CRS to write
grid_mapping_name : str
Name of the grid_mapping coordinate
inplace : bool, default True
If True, modify object in place

Returns
-------
xarray.Dataset or xarray.DataArray
Object with CRS written
"""
obj_out = obj if inplace else obj.copy(deep=True)

# Get original transform before modifying
transform = read_transform(obj)

# Remove old grid mapping coordinate if exists
try:
del obj_out.coords[grid_mapping_name]
except KeyError:
pass

# Add grid mapping coordinate
obj_out.coords[grid_mapping_name] = xarray.Variable((), 0)
grid_map_attrs = {}
if get_option(EXPORT_GRID_MAPPING):
try:
grid_map_attrs = pyproj.CRS.from_user_input(crs).to_cf()
except KeyError:
pass

# spatial_ref is for compatibility with GDAL
crs_wkt = crs.to_wkt()
grid_map_attrs["spatial_ref"] = crs_wkt
grid_map_attrs["crs_wkt"] = crs_wkt
if transform is not None:
grid_map_attrs["GeoTransform"] = " ".join(
[str(item) for item in transform.to_gdal()]
)
obj_out.coords[grid_mapping_name].attrs = grid_map_attrs

# Remove old crs if exists
obj_out.attrs.pop("crs", None)

return obj_out


def write_transform(
obj: Union[xarray.Dataset, xarray.DataArray],
transform: Affine,
grid_mapping_name: str,
inplace: bool = True,
) -> Union[xarray.Dataset, xarray.DataArray]:
"""
Write transform using CF conventions (GeoTransform attribute).

Parameters
----------
obj : xarray.Dataset or xarray.DataArray
Object to write transform to
transform : affine.Affine
Transform to write
grid_mapping_name : str
Name of the grid_mapping coordinate
inplace : bool, default True
If True, modify object in place

Returns
-------
xarray.Dataset or xarray.DataArray
Object with transform written
"""
obj_out = obj if inplace else obj.copy(deep=True)

# Delete the old attribute to prevent confusion
obj_out.attrs.pop("transform", None)

try:
grid_map_attrs = obj_out.coords[grid_mapping_name].attrs.copy()
except KeyError:
obj_out.coords[grid_mapping_name] = xarray.Variable((), 0)
grid_map_attrs = obj_out.coords[grid_mapping_name].attrs.copy()

grid_map_attrs["GeoTransform"] = " ".join(
[str(item) for item in transform.to_gdal()]
)
obj_out.coords[grid_mapping_name].attrs = grid_map_attrs

return obj_out
Loading