From 5f74c77a13b72d1235b57bf239a34cc9c973ca93 Mon Sep 17 00:00:00 2001 From: malmans2 Date: Wed, 21 Jan 2026 18:03:03 +0100 Subject: [PATCH] add ignore_spatial_coords argument --- tests/test_20_open_dataset.py | 37 +++++++++++++++++++++++++++++++++++ xarray_esgf/client.py | 12 ++++++++++++ xarray_esgf/engine.py | 3 +++ 3 files changed, 52 insertions(+) diff --git a/tests/test_20_open_dataset.py b/tests/test_20_open_dataset.py index 2c14e0f..b249153 100644 --- a/tests/test_20_open_dataset.py +++ b/tests/test_20_open_dataset.py @@ -1,9 +1,13 @@ +import contextlib from collections.abc import Hashable from pathlib import Path from typing import Any import pytest import xarray as xr +from xarray import AlignmentError + +does_not_raise = contextlib.nullcontext @pytest.mark.parametrize("download", [True, False]) @@ -142,3 +146,36 @@ def test_time_selection( sel=sel, ) assert ds.sizes["time"] == expected_size + + +@pytest.mark.parametrize( + "ignore_spatial_coords, raises", + [ + ("areacella", does_not_raise()), + (None, pytest.raises(AlignmentError)), + ], +) +def test_ignore_spatial_coords( + tmp_path: Path, + index_node: str, + ignore_spatial_coords: str | None, + raises: contextlib.nullcontext, +) -> None: + esgpull_path = tmp_path / "esgpull" + selection = { + "query": [ + '"CMIP6.CMIP.MPI-M.MPI-ESM1-2-HR.historical.r1i1p1f1.fx.areacella.gn.v20190710.areacella_fx_MPI-ESM1-2-HR_historical_r1i1p1f1_gn.nc"', + '"CMIP6.CMIP.MPI-M.MPI-ESM1-2-HR.historical.r1i1p1f1.Amon.tas.gn.v20190710.tas_Amon_MPI-ESM1-2-HR_historical_r1i1p1f1_gn_185001-185412.nc"', + ] + } + + with raises: + ds = xr.open_dataset( + selection, # type: ignore[arg-type] + esgpull_path=esgpull_path, + engine="esgf", + index_node=index_node, + chunks={}, + ignore_spatial_coords=ignore_spatial_coords, + ) + assert {"lat", "lon"} <= set(ds.variables) diff --git a/xarray_esgf/client.py b/xarray_esgf/client.py index c6367ab..fe29d30 100644 --- a/xarray_esgf/client.py +++ b/xarray_esgf/client.py @@ -139,6 +139,7 @@ def _open_datasets( download: bool, show_progress: bool, sel: dict[Hashable, Any], + ignore_spatial_coords: str | Iterable[str], ) -> dict[str, Dataset]: sel = { k: slice(*v["slice"]) if isinstance(v, dict) else v for k, v in sel.items() @@ -148,6 +149,10 @@ def _open_datasets( concat_dims = [concat_dims] concat_dims = concat_dims or [] + if isinstance(ignore_spatial_coords, str): + ignore_spatial_coords = {ignore_spatial_coords} + ignore_spatial_coords = set(ignore_spatial_coords) + if download: self.download() @@ -162,7 +167,12 @@ def _open_datasets( drop_variables=drop_variables, storage_options={"ssl": self.verify_ssl}, ) + ds = ds.sel({k: v for k, v in sel.items() if k in ds.dims}) + + if ignore_spatial_coords.intersection(ds.variables): + ds = ds.drop_vars(set(ds.variables) & {"lat", "lon"}) + if all(ds.sizes.values()): grouped_objects[file.dataset_id].append(ds.drop_encoding()) @@ -192,6 +202,7 @@ def open_dataset( download: bool = False, show_progress: bool = True, sel: dict[Hashable, Any] | None = None, + ignore_spatial_coords: str | Iterable[str] | None = None, ) -> Dataset: combined_datasets = self._open_datasets( concat_dims=concat_dims, @@ -199,6 +210,7 @@ def open_dataset( download=download, show_progress=show_progress, sel=sel or {}, + ignore_spatial_coords=ignore_spatial_coords or {}, ) obj = combine_datasets([ds.reset_coords() for ds in combined_datasets.values()]) diff --git a/xarray_esgf/engine.py b/xarray_esgf/engine.py index 4dab023..560c7a1 100644 --- a/xarray_esgf/engine.py +++ b/xarray_esgf/engine.py @@ -23,6 +23,7 @@ def open_dataset( # type: ignore[override] download: bool = False, show_progress: bool = True, sel: dict[Hashable, Any] | None = None, + ignore_spatial_coords: str | Iterable[str] | None = None, ) -> Dataset: client = Client( selection=filename_or_obj, @@ -38,6 +39,7 @@ def open_dataset( # type: ignore[override] download=download, show_progress=show_progress, sel=sel, + ignore_spatial_coords=ignore_spatial_coords, ) open_dataset_parameters = ( @@ -51,6 +53,7 @@ def open_dataset( # type: ignore[override] "download", "show_progress", "sel", + "ignore_spatial_coords", ) def guess_can_open(self, filename_or_obj: Any) -> bool: