From 848dd4aa35899dcb9b84b4e24fee5ba624be30c7 Mon Sep 17 00:00:00 2001 From: malmans2 Date: Fri, 23 Jan 2026 13:44:00 +0100 Subject: [PATCH] add drop_attributes --- tests/test_20_open_dataset.py | 22 ++++++++++++++++++++++ xarray_esgf/client.py | 14 +++++++++++++- xarray_esgf/engine.py | 3 +++ 3 files changed, 38 insertions(+), 1 deletion(-) diff --git a/tests/test_20_open_dataset.py b/tests/test_20_open_dataset.py index 1a59b62..ec5f5d6 100644 --- a/tests/test_20_open_dataset.py +++ b/tests/test_20_open_dataset.py @@ -180,3 +180,25 @@ def test_ignore_spatial_coords( ignore_spatial_coords=ignore_spatial_coords, ) assert {"lat", "lon"} <= set(ds.variables) + + +def test_drop_attributes( + tmp_path: Path, + index_node: str, +) -> None: + esgpull_path = tmp_path / "esgpull" + selection = { + "query": [ + '"tas_Amon_EC-Earth3-CC_ssp245_r1i1p1f1_gr_201901-201912.nc"', + ] + } + ds = xr.open_dataset( + selection, # type: ignore[arg-type] + esgpull_path=esgpull_path, + engine="esgf", + index_node=index_node, + chunks={}, + drop_attributes=["contact", "standard_name"], + ) + assert "contact" not in ds.attrs + assert "standard_name" not in ds["tas"].attrs diff --git a/xarray_esgf/client.py b/xarray_esgf/client.py index 016b1b7..2c72638 100644 --- a/xarray_esgf/client.py +++ b/xarray_esgf/client.py @@ -11,7 +11,7 @@ import xarray as xr from esgpull import Esgpull, File, Query from esgpull.fs import FileCheck -from xarray import DataArray, Dataset +from xarray import DataArray, Dataset, Variable DATASET_ID_KEYS = Literal[ "project", @@ -65,6 +65,16 @@ def move_dimensionless_coords_to_attrs(ds: Dataset) -> Dataset: return ds +def pop_attrs(obj: Dataset | DataArray | Variable, keys: str | Iterable[str]) -> None: + if isinstance(keys, str): + keys = [keys] + for key in keys: + obj.attrs.pop(key, None) + if isinstance(obj, Dataset): + for var in obj.variables.values(): + pop_attrs(var, keys) + + @dataclasses.dataclass class Client: selection: dict[str, str | list[str]] @@ -211,6 +221,7 @@ def open_dataset( show_progress: bool = True, sel: dict[Hashable, Any] | None = None, ignore_spatial_coords: str | Iterable[str] | None = None, + drop_attributes: str | Iterable[str] | None = None, ) -> Dataset: combined_datasets = self._open_datasets( concat_dims=concat_dims, @@ -234,4 +245,5 @@ def open_dataset( obj.attrs["coordinates"] = " ".join(sorted(str(coord) for coord in obj.coords)) obj.attrs["dataset_ids"] = sorted(combined_datasets) + pop_attrs(obj, drop_attributes or []) return obj diff --git a/xarray_esgf/engine.py b/xarray_esgf/engine.py index 560c7a1..5b18799 100644 --- a/xarray_esgf/engine.py +++ b/xarray_esgf/engine.py @@ -24,6 +24,7 @@ def open_dataset( # type: ignore[override] show_progress: bool = True, sel: dict[Hashable, Any] | None = None, ignore_spatial_coords: str | Iterable[str] | None = None, + drop_attributes: str | Iterable[str] | None = None, ) -> Dataset: client = Client( selection=filename_or_obj, @@ -40,6 +41,7 @@ def open_dataset( # type: ignore[override] show_progress=show_progress, sel=sel, ignore_spatial_coords=ignore_spatial_coords, + drop_attributes=drop_attributes, ) open_dataset_parameters = ( @@ -54,6 +56,7 @@ def open_dataset( # type: ignore[override] "show_progress", "sel", "ignore_spatial_coords", + "drop_attributes", ) def guess_can_open(self, filename_or_obj: Any) -> bool: