From 9cbdb55da9a87b2d29423fe878c63708aa0d3052 Mon Sep 17 00:00:00 2001 From: snowman2 Date: Tue, 3 Feb 2026 09:59:03 -0600 Subject: [PATCH] REF:convention: Reorganize convention modules & use rio accessor grid_mapping methods --- rioxarray/_convention/__init__.py | 216 +------------ rioxarray/_convention/_base.py | 52 ++++ rioxarray/_convention/_core.py | 174 +++++++++++ rioxarray/_convention/cf.py | 502 ++++++++++++------------------ rioxarray/rioxarray.py | 30 +- test/unit/test_convention_cf.py | 36 +-- 6 files changed, 468 insertions(+), 542 deletions(-) create mode 100644 rioxarray/_convention/_base.py create mode 100644 rioxarray/_convention/_core.py diff --git a/rioxarray/_convention/__init__.py b/rioxarray/_convention/__init__.py index 7818ff7b..adb035af 100644 --- a/rioxarray/_convention/__init__.py +++ b/rioxarray/_convention/__init__.py @@ -1,214 +1,4 @@ -"""Convention modules for rioxarray. - -This module defines the common interface for convention implementations -and provides helpers for selecting conventions. """ - -from typing import Dict, Optional, Protocol, Tuple, Union - -import rasterio.crs -import xarray -from affine import Affine - -from rioxarray._convention.cf import CFConvention -from rioxarray._options import CONVENTION, get_option -from rioxarray.crs import crs_from_user_input -from rioxarray.enum import Convention - - -class ConventionProtocol(Protocol): - """Protocol defining the interface for convention modules.""" - - @staticmethod - def read_crs( - obj: Union[xarray.Dataset, xarray.DataArray], **kwargs - ) -> Optional[rasterio.crs.CRS]: - """Read CRS from the object using this convention.""" - - @staticmethod - def read_transform( - obj: Union[xarray.Dataset, xarray.DataArray], **kwargs - ) -> Optional[Affine]: - """Read transform from the object using this convention.""" - - @staticmethod - def read_spatial_dimensions( - obj: Union[xarray.Dataset, xarray.DataArray], - ) -> Optional[Tuple[str, str]]: - """Read spatial dimensions (y_dim, x_dim) from the object using this convention.""" - - @staticmethod - def write_crs( - obj: Union[xarray.Dataset, xarray.DataArray], - crs: rasterio.crs.CRS, - **kwargs, - ) -> Union[xarray.Dataset, xarray.DataArray]: - """Write CRS to the object using this convention.""" - - @staticmethod - def write_transform( - obj: Union[xarray.Dataset, xarray.DataArray], - transform: Affine, - **kwargs, - ) -> Union[xarray.Dataset, xarray.DataArray]: - """Write transform to the object using this convention.""" - - -# Convention classes mapped by Convention enum -_CONVENTION_MODULES: Dict[Convention, ConventionProtocol] = { - Convention.CF: CFConvention # type: ignore[dict-item] -} - - -def _get_convention(convention: Convention | None) -> ConventionProtocol: - """ - Get the convention module for writing. - - Parameters - ---------- - convention : Convention enum value or None - The convention to use. If None, uses the global default. - - Returns - ------- - ConventionProtocol - The module implementing the convention - """ - if convention is None: - convention = get_option(CONVENTION) or Convention.CF - convention = Convention(convention) - return _CONVENTION_MODULES[convention] - - -def read_crs_auto( - obj: Union[xarray.Dataset, xarray.DataArray], - **kwargs, -) -> Optional[rasterio.crs.CRS]: - """ - Auto-detect and read CRS by trying convention readers. - - If a convention is set globally via set_options(), that convention - is tried first for better performance. Then other conventions are - tried as fallback. - - Parameters - ---------- - obj : xarray.Dataset or xarray.DataArray - Object to read CRS from - **kwargs - Convention-specific parameters (e.g., grid_mapping for CF) - - Returns - ------- - rasterio.crs.CRS or None - CRS object, or None if not found in any convention - """ - # Try the configured convention first (if set) - configured_convention = get_option(CONVENTION) - if configured_convention is not None: - result = _CONVENTION_MODULES[configured_convention].read_crs(obj, **kwargs) - if result is not None: - return result - - # Try all other conventions - for conv_enum, convention in _CONVENTION_MODULES.items(): - if conv_enum == configured_convention: - continue # Already tried this one - result = convention.read_crs(obj, **kwargs) - if result is not None: - return result - - # Legacy fallback: look in attrs for 'crs' (not part of any convention) - try: - return crs_from_user_input(obj.attrs["crs"]) - except KeyError: - pass - - return None - - -def read_transform_auto( - obj: Union[xarray.Dataset, xarray.DataArray], - **kwargs, -) -> Optional[Affine]: - """ - Auto-detect and read transform by trying convention readers. - - If a convention is set globally via set_options(), that convention - is tried first for better performance. Then other conventions are - tried as fallback. - - Parameters - ---------- - obj : xarray.Dataset or xarray.DataArray - Object to read transform from - **kwargs - Convention-specific parameters (e.g., grid_mapping for CF) - - Returns - ------- - affine.Affine or None - Transform object, or None if not found in any convention - """ - # Try the configured convention first (if set) - configured_convention = get_option(CONVENTION) - if configured_convention is not None: - result = _CONVENTION_MODULES[configured_convention].read_transform( - obj, **kwargs - ) - if result is not None: - return result - - # Try all other conventions - for conv_enum, convention in _CONVENTION_MODULES.items(): - if conv_enum == configured_convention: - continue # Already tried this one - result = convention.read_transform(obj, **kwargs) - if result is not None: - return result - - # Legacy fallback: look in attrs for 'transform' (not part of any convention) - try: - return Affine(*obj.attrs["transform"][:6]) - except KeyError: - pass - - return None - - -def read_spatial_dimensions_auto( - obj: Union[xarray.Dataset, xarray.DataArray], -) -> Optional[Tuple[str, str]]: - """ - Auto-detect and read spatial dimensions by trying convention readers. - - If a convention is set globally via set_options(), that convention - is tried first for better performance. Then other conventions are - tried as fallback. - - Parameters - ---------- - obj : xarray.Dataset or xarray.DataArray - Object to read spatial dimensions from - - Returns - ------- - tuple of (y_dim, x_dim) or None - Tuple of dimension names, or None if not found in any convention - """ - # Try the configured convention first (if set) - configured_convention = get_option(CONVENTION) - if configured_convention is not None: - result = _CONVENTION_MODULES[configured_convention].read_spatial_dimensions(obj) - if result is not None: - return result - - # Try all other conventions - for conv_enum, convention in _CONVENTION_MODULES.items(): - if conv_enum == configured_convention: - continue # Already tried this one - result = convention.read_spatial_dimensions(obj) - if result is not None: - return result - - return None +Convention modules for rioxarray. +Each convention module implements the ConventionProtocol interface. +""" diff --git a/rioxarray/_convention/_base.py b/rioxarray/_convention/_base.py new file mode 100644 index 00000000..18db2d2d --- /dev/null +++ b/rioxarray/_convention/_base.py @@ -0,0 +1,52 @@ +"""This module defines the common interface for convention implementations +and provides helpers for selecting conventions. +""" + +from typing import Optional, Protocol, Union + +import rasterio.crs +import xarray +from affine import Affine + + +class ConventionProtocol(Protocol): + """Protocol defining the interface for convention modules.""" + + @classmethod + def read_crs( + cls, obj: Union[xarray.Dataset, xarray.DataArray], **kwargs + ) -> Optional[rasterio.crs.CRS]: + """Read CRS from the object using this convention.""" + + @classmethod + def read_transform( + cls, obj: Union[xarray.Dataset, xarray.DataArray], **kwargs + ) -> Optional[Affine]: + """Read transform from the object using this convention.""" + + @classmethod + def read_spatial_dimensions( + cls, + obj: Union[xarray.Dataset, xarray.DataArray], + ) -> Optional[tuple[str, str]]: + """Read spatial dimensions (y_dim, x_dim) from the object using this convention.""" + + @classmethod + def write_crs( + cls, + obj: Union[xarray.Dataset, xarray.DataArray], + *, + crs: rasterio.crs.CRS, + **kwargs, + ) -> Union[xarray.Dataset, xarray.DataArray]: + """Write CRS to the object using this convention.""" + + @classmethod + def write_transform( + cls, + obj: Union[xarray.Dataset, xarray.DataArray], + *, + transform: Affine, + **kwargs, + ) -> Union[xarray.Dataset, xarray.DataArray]: + """Write transform to the object using this convention.""" diff --git a/rioxarray/_convention/_core.py b/rioxarray/_convention/_core.py new file mode 100644 index 00000000..3169eceb --- /dev/null +++ b/rioxarray/_convention/_core.py @@ -0,0 +1,174 @@ +""" +Core convention methods for rioxarray. +""" + +from typing import Optional, Union + +import rasterio.crs +import xarray +from affine import Affine + +from rioxarray._convention._base import ConventionProtocol +from rioxarray._convention.cf import CFConvention +from rioxarray._options import CONVENTION, get_option +from rioxarray.crs import crs_from_user_input +from rioxarray.enum import Convention + +# Convention classes mapped by Convention enum +_CONVENTION_MODULES: dict[Convention, ConventionProtocol] = { + Convention.CF: CFConvention # type: ignore[dict-item] +} + + +def _get_convention(convention: Convention | None) -> ConventionProtocol: + """ + Get the convention module for writing. + + Parameters + ---------- + convention : Convention enum value or None + The convention to use. If None, uses the global default. + + Returns + ------- + ConventionProtocol + The module implementing the convention + """ + if convention is None: + convention = get_option(CONVENTION) or Convention.CF + convention = Convention(convention) + return _CONVENTION_MODULES[convention] + + +def read_crs_auto( + obj: Union[xarray.Dataset, xarray.DataArray], + **kwargs, +) -> Optional[rasterio.crs.CRS]: + """ + Auto-detect and read CRS by trying convention readers. + + If a convention is set globally via set_options(), that convention + is tried first for better performance. Then other conventions are + tried as fallback. + + Parameters + ---------- + obj : xarray.Dataset or xarray.DataArray + Object to read CRS from + **kwargs + Convention-specific parameters (e.g., grid_mapping for CF) + + Returns + ------- + rasterio.crs.CRS or None + CRS object, or None if not found in any convention + """ + # Try the configured convention first (if set) + configured_convention = get_option(CONVENTION) + if configured_convention is not None: + result = _CONVENTION_MODULES[configured_convention].read_crs(obj, **kwargs) + if result is not None: + return result + + # Try all other conventions + for conv_enum, convention in _CONVENTION_MODULES.items(): + if conv_enum == configured_convention: + continue # Already tried this one + result = convention.read_crs(obj, **kwargs) + if result is not None: + return result + + # Legacy fallback: look in attrs for 'crs' (not part of any convention) + try: + return crs_from_user_input(obj.attrs["crs"]) + except KeyError: + pass + + return None + + +def read_transform_auto( + obj: Union[xarray.Dataset, xarray.DataArray], + **kwargs, +) -> Optional[Affine]: + """ + Auto-detect and read transform by trying convention readers. + + If a convention is set globally via set_options(), that convention + is tried first for better performance. Then other conventions are + tried as fallback. + + Parameters + ---------- + obj : xarray.Dataset or xarray.DataArray + Object to read transform from + **kwargs + Convention-specific parameters (e.g., grid_mapping for CF) + + Returns + ------- + affine.Affine or None + Transform object, or None if not found in any convention + """ + # Try the configured convention first (if set) + configured_convention = get_option(CONVENTION) + if configured_convention is not None: + result = _CONVENTION_MODULES[configured_convention].read_transform( + obj, **kwargs + ) + if result is not None: + return result + + # Try all other conventions + for conv_enum, convention in _CONVENTION_MODULES.items(): + if conv_enum == configured_convention: + continue # Already tried this one + result = convention.read_transform(obj, **kwargs) + if result is not None: + return result + + # Legacy fallback: look in attrs for 'transform' (not part of any convention) + try: + return Affine(*obj.attrs["transform"][:6]) + except KeyError: + pass + + return None + + +def read_spatial_dimensions_auto( + obj: Union[xarray.Dataset, xarray.DataArray], +) -> Optional[tuple[str, str]]: + """ + Auto-detect and read spatial dimensions by trying convention readers. + + If a convention is set globally via set_options(), that convention + is tried first for better performance. Then other conventions are + tried as fallback. + + Parameters + ---------- + obj : xarray.Dataset or xarray.DataArray + Object to read spatial dimensions from + + Returns + ------- + tuple of (y_dim, x_dim) or None + Tuple of dimension names, or None if not found in any convention + """ + # Try the configured convention first (if set) + configured_convention = get_option(CONVENTION) + if configured_convention is not None: + result = _CONVENTION_MODULES[configured_convention].read_spatial_dimensions(obj) + if result is not None: + return result + + # Try all other conventions + for conv_enum, convention in _CONVENTION_MODULES.items(): + if conv_enum == configured_convention: + continue # Already tried this one + result = convention.read_spatial_dimensions(obj) + if result is not None: + return result + + return None diff --git a/rioxarray/_convention/cf.py b/rioxarray/_convention/cf.py index 37ee278f..cb08c8a8 100644 --- a/rioxarray/_convention/cf.py +++ b/rioxarray/_convention/cf.py @@ -13,78 +13,33 @@ from affine import Affine from rioxarray._options import EXPORT_GRID_MAPPING, get_option -from rioxarray._spatial_utils import ( - DEFAULT_GRID_MAP, - _get_spatial_dims, - _has_spatial_dims, -) from rioxarray.crs import crs_from_user_input -from rioxarray.exceptions import MissingSpatialDimensionError -def _find_grid_mapping( - obj: Union[xarray.Dataset, xarray.DataArray], - *, - grid_mapping: Optional[str] = None, -) -> Optional[str]: - """ - Find the grid_mapping coordinate name. - - Parameters - ---------- - obj : xarray.Dataset or xarray.DataArray - Object to search for grid_mapping - grid_mapping : str, optional - Explicit grid_mapping name to use - - Returns - ------- - str or None - The grid_mapping name, or None if not found - """ - if grid_mapping is not None: - return grid_mapping - - # Try to find grid_mapping attribute on data variables - if hasattr(obj, "data_vars"): - for data_var in obj.data_vars.values(): - if "grid_mapping" in data_var.attrs: - return data_var.attrs["grid_mapping"] - if "grid_mapping" in data_var.encoding: - return data_var.encoding["grid_mapping"] - - if hasattr(obj, "attrs") and "grid_mapping" in obj.attrs: - return obj.attrs["grid_mapping"] - - if hasattr(obj, "encoding") and "grid_mapping" in obj.encoding: - return obj.encoding["grid_mapping"] - - return None - - -def read_crs( - obj: Union[xarray.Dataset, xarray.DataArray], *, grid_mapping: Optional[str] = None -) -> Optional[rasterio.crs.CRS]: - """ - Read CRS from CF conventions. - - Parameters - ---------- - obj : xarray.Dataset or xarray.DataArray - Object to read CRS from - grid_mapping : str, optional - Name of the grid_mapping coordinate variable - - Returns - ------- - rasterio.crs.CRS or None - CRS object, or None if not found - """ - grid_mapping = _find_grid_mapping(obj, grid_mapping=grid_mapping) +class CFConvention: + """CF convention class implementing ConventionProtocol.""" - if grid_mapping is not None: + @classmethod + def read_crs( + cls, obj: Union[xarray.Dataset, xarray.DataArray] + ) -> Optional[rasterio.crs.CRS]: + """ + Read CRS from CF conventions. + + Parameters + ---------- + obj : xarray.Dataset or xarray.DataArray + Object to read CRS from + grid_mapping_name : str, optional + Name of the grid_mapping coordinate variable + + Returns + ------- + rasterio.crs.CRS or None + CRS object, or None if not found + """ try: - grid_mapping_coord = obj.coords[grid_mapping] + grid_mapping_coord = obj.coords[obj.rio.grid_mapping] # Look in wkt attributes first for performance for crs_attr in ("spatial_ref", "crs_wkt"): @@ -101,259 +56,194 @@ def read_crs( except KeyError: # grid_mapping coordinate doesn't exist pass - - return None - - -def read_transform( - obj: Union[xarray.Dataset, xarray.DataArray], *, grid_mapping: Optional[str] = None -) -> Optional[Affine]: - """ - Read transform from CF conventions (GeoTransform attribute). - - Parameters - ---------- - obj : xarray.Dataset or xarray.DataArray - Object to read transform from - grid_mapping : str, optional - Name of the grid_mapping coordinate variable - - Returns - ------- - affine.Affine or None - Transform object, or None if not found - """ - grid_mapping = _find_grid_mapping(obj, grid_mapping=grid_mapping) - - if grid_mapping is not None: + return None + + @classmethod + def read_transform( + cls, obj: Union[xarray.Dataset, xarray.DataArray] + ) -> Optional[Affine]: + """ + Read transform from CF conventions (GeoTransform attribute). + + Parameters + ---------- + obj : xarray.Dataset or xarray.DataArray + Object to read transform from + + Returns + ------- + affine.Affine or None + Transform object, or None if not found + """ try: transform = numpy.fromstring( - obj.coords[grid_mapping].attrs["GeoTransform"], sep=" " + obj.coords[obj.rio.grid_mapping].attrs["GeoTransform"], sep=" " ) # Calling .tolist() to assure the arguments are Python float and JSON serializable return Affine.from_gdal(*transform.tolist()) except KeyError: pass - return None - - -def read_spatial_dimensions( - obj: Union[xarray.Dataset, xarray.DataArray], -) -> Optional[Tuple[str, str]]: - """ - Read spatial dimensions from CF conventions. - - This function detects spatial dimensions based on: - 1. Standard dimension names ('x'/'y', 'longitude'/'latitude') - 2. CF coordinate attributes ('axis', 'standard_name') - - Parameters - ---------- - obj : xarray.Dataset or xarray.DataArray - Object to read spatial dimensions from - - Returns - ------- - tuple of (y_dim, x_dim) or None - Tuple of dimension names, or None if not found - """ - x_dim = None - y_dim = None - - # Check standard dimension names - if "x" in obj.dims and "y" in obj.dims: - return "y", "x" - if "longitude" in obj.dims and "latitude" in obj.dims: - return "latitude", "longitude" - - # Look for coordinates with CF attributes - for coord in obj.coords: - # Make sure to only look in 1D coordinates - # that has the same dimension name as the coordinate - if obj.coords[coord].dims != (coord,): - continue - if (obj.coords[coord].attrs.get("axis", "").upper() == "X") or ( - obj.coords[coord].attrs.get("standard_name", "").lower() - in ("longitude", "projection_x_coordinate") - ): - x_dim = coord - elif (obj.coords[coord].attrs.get("axis", "").upper() == "Y") or ( - obj.coords[coord].attrs.get("standard_name", "").lower() - in ("latitude", "projection_y_coordinate") - ): - y_dim = coord - - if x_dim is not None and y_dim is not None: - return str(y_dim), str(x_dim) - - return None - - -def write_crs( - obj: Union[xarray.Dataset, xarray.DataArray], - *, - crs: rasterio.crs.CRS, - **kwargs, -) -> Union[xarray.Dataset, xarray.DataArray]: - """ - Write CRS using CF conventions. - - This also writes the grid_mapping attribute to encoding for CF compliance. - - Parameters - ---------- - obj : xarray.Dataset or xarray.DataArray - Object to write CRS to - crs : rasterio.crs.CRS - CRS to write - **kwargs + return None + + @classmethod + def read_spatial_dimensions( + cls, + obj: Union[xarray.Dataset, xarray.DataArray], + ) -> Optional[Tuple[str, str]]: + """ + Read spatial dimensions from CF conventions. + + This function detects spatial dimensions based on: + 1. Standard dimension names ('x'/'y', 'longitude'/'latitude') + 2. CF coordinate attributes ('axis', 'standard_name') + + Parameters + ---------- + obj : xarray.Dataset or xarray.DataArray + Object to read spatial dimensions from + + Returns + ------- + tuple of (y_dim, x_dim) or None + Tuple of dimension names, or None if not found + """ + x_dim = None + y_dim = None + + # Check standard dimension names + if "x" in obj.dims and "y" in obj.dims: + return "y", "x" + if "longitude" in obj.dims and "latitude" in obj.dims: + return "latitude", "longitude" + + # Look for coordinates with CF attributes + for coord in obj.coords: + # Make sure to only look in 1D coordinates + # that has the same dimension name as the coordinate + if obj.coords[coord].dims != (coord,): + continue + if (obj.coords[coord].attrs.get("axis", "").upper() == "X") or ( + obj.coords[coord].attrs.get("standard_name", "").lower() + in ("longitude", "projection_x_coordinate") + ): + x_dim = coord + elif (obj.coords[coord].attrs.get("axis", "").upper() == "Y") or ( + obj.coords[coord].attrs.get("standard_name", "").lower() + in ("latitude", "projection_y_coordinate") + ): + y_dim = coord + + if x_dim is not None and y_dim is not None: + return str(y_dim), str(x_dim) + return None + + @classmethod + def write_crs( + cls, + obj: Union[xarray.Dataset, xarray.DataArray], + crs: rasterio.crs.CRS, + *, + grid_mapping_name: Optional[str] = None, + **kwargs, # pylint: disable=unused-argument + ) -> Union[xarray.Dataset, xarray.DataArray]: + """ + Write CRS using CF conventions. + + This also writes the grid_mapping attribute to encoding for CF compliance. + + Parameters + ---------- + obj : xarray.Dataset or xarray.DataArray + Object to write CRS to + crs : rasterio.crs.CRS + CRS to write grid_mapping_name : str - Name of the grid_mapping coordinate (required for CF) - - Returns - ------- - xarray.Dataset or xarray.DataArray - Object with CRS written - """ - grid_mapping_name = kwargs.get("grid_mapping_name") - if grid_mapping_name is None: - # Get grid_mapping from encoding/attrs or use default - grid_mapping_name = _find_grid_mapping(obj) or DEFAULT_GRID_MAP + Name of the grid_mapping coordinate + **kwargs + Additional convention-specific parameters + + Returns + ------- + xarray.Dataset or xarray.DataArray + Object with CRS written + """ + # get original transform + transform = obj.rio._cached_transform() + # remove old grid mapping coordinate if exists + grid_mapping_name = ( + obj.rio.grid_mapping if grid_mapping_name is None else grid_mapping_name + ) + try: + del obj.coords[grid_mapping_name] + except KeyError: + pass - # Get original transform before modifying (pass grid_mapping_name to find it) - transform = read_transform(obj, grid_mapping=grid_mapping_name) + # add grid mapping coordinate + obj.coords[grid_mapping_name] = xarray.Variable((), 0) + grid_map_attrs = {} + if get_option(EXPORT_GRID_MAPPING): + try: + grid_map_attrs = pyproj.CRS.from_user_input(crs).to_cf() + except KeyError: + pass + # spatial_ref is for compatibility with GDAL + crs_wkt = crs.to_wkt() + grid_map_attrs["spatial_ref"] = crs_wkt + grid_map_attrs["crs_wkt"] = crs_wkt + if transform is not None: + grid_map_attrs["GeoTransform"] = " ".join( + [str(item) for item in transform.to_gdal()] + ) + obj.coords[grid_mapping_name].rio.set_attrs(grid_map_attrs, inplace=True) - # Remove old grid mapping coordinate if exists - try: - del obj.coords[grid_mapping_name] - except KeyError: - pass + return obj.rio.write_grid_mapping( + grid_mapping_name=grid_mapping_name, inplace=True + ) - # Add grid mapping coordinate - obj.coords[grid_mapping_name] = xarray.Variable((), 0) - grid_map_attrs = {} - if get_option(EXPORT_GRID_MAPPING): + @classmethod + def write_transform( + cls, + obj: Union[xarray.Dataset, xarray.DataArray], + *, + transform: Affine, + grid_mapping_name: Optional[str] = None, + **kwargs, # pylint: disable=unused-argument + ) -> Union[xarray.Dataset, xarray.DataArray]: + """ + Write transform using CF conventions (GeoTransform attribute). + + This also writes the grid_mapping attribute to encoding for CF compliance. + + Parameters + ---------- + obj : xarray.Dataset or xarray.DataArray + Object to write transform to + transform : affine.Affine + Transform to write + grid_mapping_name : Optional[str] + Name of the grid_mapping coordinate + **kwargs + Additional convention-specific parameters + + Returns + ------- + xarray.Dataset or xarray.DataArray + Object with transform written + """ + transform = transform or obj.rio.transform(recalc=True) + grid_mapping_name = ( + obj.rio.grid_mapping if grid_mapping_name is None else grid_mapping_name + ) try: - grid_map_attrs = pyproj.CRS.from_user_input(crs).to_cf() + grid_map_attrs = obj.coords[grid_mapping_name].attrs.copy() except KeyError: - pass - - # spatial_ref is for compatibility with GDAL - crs_wkt = crs.to_wkt() - grid_map_attrs["spatial_ref"] = crs_wkt - grid_map_attrs["crs_wkt"] = crs_wkt - if transform is not None: + obj.coords[grid_mapping_name] = xarray.Variable((), 0) + grid_map_attrs = obj.coords[grid_mapping_name].attrs.copy() grid_map_attrs["GeoTransform"] = " ".join( [str(item) for item in transform.to_gdal()] ) - obj.coords[grid_mapping_name].attrs = grid_map_attrs - - # Write grid_mapping to encoding (CF specific) - obj = _write_grid_mapping(obj, grid_mapping_name=grid_mapping_name) - - return obj - - -def _write_grid_mapping( - obj: Union[xarray.Dataset, xarray.DataArray], - *, - grid_mapping_name: str, -) -> Union[xarray.Dataset, xarray.DataArray]: - """ - Write the CF grid_mapping attribute to the encoding. - - Parameters - ---------- - obj : xarray.Dataset or xarray.DataArray - Object to write grid_mapping to - grid_mapping_name : str - Name of the grid_mapping coordinate - - Returns - ------- - xarray.Dataset or xarray.DataArray - Object with grid_mapping written to encoding - """ - if hasattr(obj, "data_vars"): - for var in obj.data_vars: - if not _has_spatial_dims(obj, var=var): - continue - try: - x_dim, y_dim = _get_spatial_dims(obj, var=var) - except MissingSpatialDimensionError: - continue - # remove grid_mapping from attributes if it exists - # and update the grid_mapping in encoding - new_attrs = dict(obj[var].attrs) - new_attrs.pop("grid_mapping", None) - obj[var].attrs = new_attrs - obj[var].encoding["grid_mapping"] = grid_mapping_name - obj[var].rio.set_spatial_dims(x_dim=x_dim, y_dim=y_dim, inplace=True) - - # remove grid_mapping from attributes if it exists - # and update the grid_mapping in encoding - new_attrs = dict(obj.attrs) - new_attrs.pop("grid_mapping", None) - obj.attrs = new_attrs - obj.encoding["grid_mapping"] = grid_mapping_name - - return obj - - -def write_transform( - obj: Union[xarray.Dataset, xarray.DataArray], - *, - transform: Affine, - **kwargs, -) -> Union[xarray.Dataset, xarray.DataArray]: - """ - Write transform using CF conventions (GeoTransform attribute). - - This also writes the grid_mapping attribute to encoding for CF compliance. - - Parameters - ---------- - obj : xarray.Dataset or xarray.DataArray - Object to write transform to - transform : affine.Affine - Transform to write - **kwargs - grid_mapping_name : str - Name of the grid_mapping coordinate (required for CF) - - Returns - ------- - xarray.Dataset or xarray.DataArray - Object with transform written - """ - grid_mapping_name = kwargs.get("grid_mapping_name") - if grid_mapping_name is None: - # Get grid_mapping from encoding/attrs or use default - grid_mapping_name = _find_grid_mapping(obj) or DEFAULT_GRID_MAP - - try: - grid_map_attrs = obj.coords[grid_mapping_name].attrs.copy() - except KeyError: - obj.coords[grid_mapping_name] = xarray.Variable((), 0) - grid_map_attrs = obj.coords[grid_mapping_name].attrs.copy() - - grid_map_attrs["GeoTransform"] = " ".join( - [str(item) for item in transform.to_gdal()] - ) - obj.coords[grid_mapping_name].attrs = grid_map_attrs - - # Write grid_mapping to encoding (CF specific) - obj = _write_grid_mapping(obj, grid_mapping_name=grid_mapping_name) - - return obj - - -class CFConvention: - """CF convention class implementing ConventionProtocol.""" - - read_crs = staticmethod(read_crs) - read_transform = staticmethod(read_transform) - read_spatial_dimensions = staticmethod(read_spatial_dimensions) - write_crs = staticmethod(write_crs) - write_transform = staticmethod(write_transform) + obj.coords[grid_mapping_name].rio.set_attrs(grid_map_attrs, inplace=True) + return obj.rio.write_grid_mapping( + grid_mapping_name=grid_mapping_name, inplace=True + ) diff --git a/rioxarray/rioxarray.py b/rioxarray/rioxarray.py index f1d6e05c..368d26a1 100644 --- a/rioxarray/rioxarray.py +++ b/rioxarray/rioxarray.py @@ -21,9 +21,8 @@ from rasterio.crs import CRS from rasterio.rpc import RPC -from rioxarray._convention import ( +from rioxarray._convention._core import ( _get_convention, - cf, read_crs_auto, read_spatial_dimensions_auto, read_transform_auto, @@ -86,7 +85,7 @@ def crs(self) -> Optional[rasterio.crs.CRS]: return None if self._crs is False else self._crs # Auto-detect CRS from any supported convention - parsed_crs = read_crs_auto(self._obj, grid_mapping=self.grid_mapping) + parsed_crs = read_crs_auto(self._obj) if parsed_crs is not None: self._set_crs(parsed_crs, inplace=True) @@ -230,7 +229,28 @@ def write_grid_mapping( :meth:`rioxarray.rioxarray.XRasterBase.write_crs` """ data_obj = self._get_obj(inplace=inplace) - return cf._write_grid_mapping(data_obj, grid_mapping_name=grid_mapping_name) + if hasattr(data_obj, "data_vars"): + for var in data_obj.data_vars: + try: + x_dim, y_dim = _get_spatial_dims(data_obj, var=var) + except MissingSpatialDimensionError: + continue + # remove grid_mapping from attributes if it exists + # and update the grid_mapping in encoding + new_attrs = dict(data_obj[var].attrs) + new_attrs.pop("grid_mapping", None) + data_obj[var].rio.update_encoding( + {"grid_mapping": grid_mapping_name}, inplace=True + ).rio.set_attrs(new_attrs, inplace=True).rio.set_spatial_dims( + x_dim=x_dim, y_dim=y_dim, inplace=True + ) + # remove grid_mapping from attributes if it exists + # and update the grid_mapping in encoding + new_attrs = dict(data_obj.attrs) + new_attrs.pop("grid_mapping", None) + return data_obj.rio.update_encoding( + {"grid_mapping": grid_mapping_name}, inplace=True + ).rio.set_attrs(new_attrs, inplace=True) def write_crs( self, @@ -340,7 +360,7 @@ def _cached_transform(self) -> Optional[Affine]: """ Get the transform by auto-detecting from any supported convention. """ - return read_transform_auto(self._obj, grid_mapping=self.grid_mapping) + return read_transform_auto(self._obj) def write_transform( self, diff --git a/test/unit/test_convention_cf.py b/test/unit/test_convention_cf.py index 7d9eec60..18f32818 100644 --- a/test/unit/test_convention_cf.py +++ b/test/unit/test_convention_cf.py @@ -4,7 +4,8 @@ from affine import Affine from rasterio.crs import CRS -from rioxarray._convention import cf, read_crs_auto, read_transform_auto +from rioxarray._convention._core import read_crs_auto, read_transform_auto +from rioxarray._convention.cf import CFConvention def test_read_crs__from_grid_mapping_spatial_ref(): @@ -13,7 +14,7 @@ def test_read_crs__from_grid_mapping_spatial_ref(): data.coords["spatial_ref"] = xarray.Variable((), 0) data.coords["spatial_ref"].attrs["spatial_ref"] = "EPSG:4326" - crs = cf.read_crs(data, grid_mapping="spatial_ref") + crs = CFConvention.read_crs(data) assert crs is not None assert crs == CRS.from_epsg(4326) @@ -24,7 +25,7 @@ def test_read_crs__from_grid_mapping_crs_wkt(): data.coords["spatial_ref"] = xarray.Variable((), 0) data.coords["spatial_ref"].attrs["crs_wkt"] = CRS.from_epsg(4326).to_wkt() - crs = cf.read_crs(data, grid_mapping="spatial_ref") + crs = CFConvention.read_crs(data) assert crs is not None assert crs == CRS.from_epsg(4326) @@ -39,7 +40,7 @@ def test_read_crs__from_legacy_attrs(): data.attrs["crs"] = "EPSG:4326" # CF convention should NOT find this - crs = cf.read_crs(data) + crs = CFConvention.read_crs(data) assert crs is None # Auto-detect should find it @@ -59,11 +60,11 @@ def test_read_crs__from_legacy_attrs_with_missing_grid_mapping(): data.attrs["crs"] = "EPSG:4326" # CF convention should NOT find this - crs = cf.read_crs(data, grid_mapping="spatial_ref") + crs = CFConvention.read_crs(data) assert crs is None # Auto-detect should find it - crs = read_crs_auto(data, grid_mapping="spatial_ref") + crs = read_crs_auto(data) assert crs is not None assert crs == CRS.from_epsg(4326) @@ -72,7 +73,7 @@ def test_read_crs__not_found(): """Test that None is returned when no CRS is found.""" data = xarray.DataArray(numpy.random.rand(10, 10), dims=["y", "x"]) - crs = cf.read_crs(data) + crs = CFConvention.read_crs(data) assert crs is None @@ -83,7 +84,7 @@ def test_read_transform__from_geotransform(): # GeoTransform format: [c, a, b, f, d, e] (GDAL format) data.coords["spatial_ref"].attrs["GeoTransform"] = "0.0 1.0 0.0 10.0 0.0 -1.0" - transform = cf.read_transform(data, grid_mapping="spatial_ref") + transform = CFConvention.read_transform(data) assert transform is not None assert transform == Affine(1.0, 0.0, 0.0, 0.0, -1.0, 10.0) @@ -99,7 +100,7 @@ def test_read_transform__from_legacy_attrs(): data.attrs["transform"] = [1.0, 0.0, 0.0, 0.0, -1.0, 10.0] # CF convention should NOT find this - transform = cf.read_transform(data) + transform = CFConvention.read_transform(data) assert transform is None # Auto-detect should find it @@ -112,7 +113,7 @@ def test_read_transform__not_found(): """Test that None is returned when no transform is found.""" data = xarray.DataArray(numpy.random.rand(10, 10), dims=["y", "x"]) - transform = cf.read_transform(data) + transform = CFConvention.read_transform(data) assert transform is None @@ -120,7 +121,7 @@ def test_read_spatial_dimensions__xy(): """Test detecting x/y dimension names.""" data = xarray.DataArray(numpy.random.rand(10, 10), dims=["y", "x"]) - dims = cf.read_spatial_dimensions(data) + dims = CFConvention.read_spatial_dimensions(data) assert dims == ("y", "x") @@ -128,7 +129,7 @@ def test_read_spatial_dimensions__lonlat(): """Test detecting longitude/latitude dimension names.""" data = xarray.DataArray(numpy.random.rand(10, 10), dims=["latitude", "longitude"]) - dims = cf.read_spatial_dimensions(data) + dims = CFConvention.read_spatial_dimensions(data) assert dims == ("latitude", "longitude") @@ -145,7 +146,7 @@ def test_read_spatial_dimensions__cf_axis(): data.coords["row"].attrs["axis"] = "Y" data.coords["col"].attrs["axis"] = "X" - dims = cf.read_spatial_dimensions(data) + dims = CFConvention.read_spatial_dimensions(data) assert dims == ("row", "col") @@ -162,7 +163,7 @@ def test_read_spatial_dimensions__cf_standard_name(): data.coords["lat"].attrs["standard_name"] = "latitude" data.coords["lon"].attrs["standard_name"] = "longitude" - dims = cf.read_spatial_dimensions(data) + dims = CFConvention.read_spatial_dimensions(data) assert dims == ("lat", "lon") @@ -170,7 +171,7 @@ def test_read_spatial_dimensions__not_found(): """Test that None is returned when spatial dimensions are not found.""" data = xarray.DataArray(numpy.random.rand(10, 10), dims=["a", "b"]) - dims = cf.read_spatial_dimensions(data) + dims = CFConvention.read_spatial_dimensions(data) assert dims is None @@ -178,8 +179,7 @@ def test_write_crs(): """Test writing CRS to a DataArray.""" data = xarray.DataArray(numpy.random.rand(10, 10), dims=["y", "x"]) crs = CRS.from_epsg(4326) - - result = cf.write_crs(data, crs=crs, grid_mapping_name="spatial_ref") + result = CFConvention.write_crs(data, crs=crs, grid_mapping_name="spatial_ref") assert "spatial_ref" in result.coords assert result.coords["spatial_ref"].attrs["spatial_ref"] == crs.to_wkt() @@ -191,7 +191,7 @@ def test_write_transform(): data = xarray.DataArray(numpy.random.rand(10, 10), dims=["y", "x"]) transform = Affine(1.0, 0.0, 0.0, 0.0, -1.0, 10.0) - result = cf.write_transform( + result = CFConvention.write_transform( data, transform=transform, grid_mapping_name="spatial_ref" )