From c8f22a77122d296251c70d7d483f0ff3910a48ff Mon Sep 17 00:00:00 2001 From: Maximilian Roos Date: Thu, 1 Jan 2026 11:38:15 -0800 Subject: [PATCH 1/6] Fix mypy errors with newer pandas-stubs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Raise the pandas-stubs version pin from <=2.2.3.241126 to <=2.3.3.251219 and fix the mypy errors that appear with newer versions. This resolves #10110. Changes by category: 1. Added type: ignore comments with explanatory notes for cases where pandas-stubs is stricter than actual pandas behavior: - CFTimeIndex.__add__/__radd__ return Self instead of overloaded types - Index.get_indexer accepts ndarray/list, not just Index - CategoricalIndex.remove_unused_categories missing from stubs - Series.where accepts broader argument types - ExtensionArray.astype accepts ExtensionDtype - Series[datetime].__setitem__ accepts np.nan (converts to NaT) - MultiIndex.rename accepts list of names 2. Added explicit type annotations to help mypy infer correct types: - coordinates.py: codes as list[np.ndarray] - dataset.py: arrays and extension_arrays list types - indexes.py: xr_index variable with PandasIndex | PandasMultiIndex 3. 
Removed redundant casts and fixed variable shadowing: - Removed unnecessary cast in remove_unused_levels_categories - Converted dict_keys to list for reorder_levels - Fixed test_backends.py variable naming (tdf -> tdf_series) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- pixi.toml | 2 +- xarray/coding/cftimeindex.py | 6 ++++-- xarray/core/coordinates.py | 3 ++- xarray/core/dataset.py | 4 ++-- xarray/core/extension_array.py | 9 ++++++--- xarray/core/indexes.py | 23 +++++++++++++---------- xarray/core/missing.py | 3 ++- xarray/tests/test_backends.py | 6 +++--- xarray/tests/test_dataarray.py | 3 ++- xarray/tests/test_dataset.py | 6 ++++-- xarray/tests/test_duck_array_ops.py | 3 ++- xarray/tests/test_plot.py | 2 +- 12 files changed, 42 insertions(+), 28 deletions(-) diff --git a/pixi.toml b/pixi.toml index a8d165e868f..3589e5abf9f 100644 --- a/pixi.toml +++ b/pixi.toml @@ -255,7 +255,7 @@ mypy = "==1.18.1" pyright = "*" hypothesis = "*" lxml = "*" -pandas-stubs = "<=2.2.3.241126" # https://github.com/pydata/xarray/issues/10110 +pandas-stubs = "<=2.3.3.251219" types-colorama = "*" types-docutils = "*" types-psutil = "*" diff --git a/xarray/coding/cftimeindex.py b/xarray/coding/cftimeindex.py index b4bde1458b8..0ad8166f2de 100644 --- a/xarray/coding/cftimeindex.py +++ b/xarray/coding/cftimeindex.py @@ -514,12 +514,14 @@ def shift( # type: ignore[override,unused-ignore] f"'freq' must be of type str or datetime.timedelta, got {type(freq)}." 
) - def __add__(self, other) -> Self: + # pandas-stubs defines many overloads for Index.__add__/__radd__ with specific + # return types, but CFTimeIndex legitimately returns Self for all cases + def __add__(self, other) -> Self: # type: ignore[override] if isinstance(other, pd.TimedeltaIndex): other = other.to_pytimedelta() return type(self)(np.array(self) + other) - def __radd__(self, other) -> Self: + def __radd__(self, other) -> Self: # type: ignore[override] if isinstance(other, pd.TimedeltaIndex): other = other.to_pytimedelta() return type(self)(other + np.array(self)) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 9aa64a57ff2..501c58657cf 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -169,7 +169,8 @@ def to_index(self, ordered_dims: Sequence[Hashable] | None = None) -> pd.Index: for i, index in enumerate(indexes): if isinstance(index, pd.MultiIndex): - codes, levels = index.codes, index.levels + codes: list[np.ndarray] = list(index.codes) + levels = index.levels else: code, level = pd.factorize(index) codes = [code] diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 01baa9aed3d..84a67d95412 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -7402,8 +7402,8 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self: "cannot convert a DataFrame with a non-unique MultiIndex into xarray" ) - arrays = [] - extension_arrays = [] + arrays: list[tuple[Hashable, np.ndarray]] = [] + extension_arrays: list[tuple[Hashable, pd.Series]] = [] for k, v in dataframe.items(): if not is_allowed_extension_array(v) or isinstance( v.array, UNSUPPORTED_EXTENSION_ARRAY_TYPES diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index 5841a696f15..d078d61e58a 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -104,7 +104,9 @@ def as_extension_array( [array_or_scalar], dtype=dtype ) else: - return 
array_or_scalar.astype(dtype, copy=copy) # type: ignore[union-attr] + # pandas-stubs is overly strict about astype's dtype parameter and return type; + # ExtensionArray.astype accepts ExtensionDtype and returns ExtensionArray + return array_or_scalar.astype(dtype, copy=copy) # type: ignore[union-attr,return-value,arg-type] @implements(np.result_type) @@ -192,10 +194,11 @@ def __extension_duck_array__where( # pd.where won't broadcast 0-dim arrays across a scalar-like series; scalar y's must be preserved if hasattr(y, "shape") and len(y.shape) == 1 and y.shape[0] == 1: y = y[0] # type: ignore[index] - return cast(T_ExtensionArray, pd.Series(x).where(condition, y).array) # type: ignore[arg-type] + # pandas-stubs has strict overloads for Series.where that don't cover all valid arg types + return cast(T_ExtensionArray, pd.Series(x).where(condition, y).array) # type: ignore[call-overload] -def _replace_duck(args, replacer: Callable[[PandasExtensionArray], list]) -> list: +def _replace_duck(args, replacer: Callable[[PandasExtensionArray], Any]) -> list: args_as_list = list(args) for index, value in enumerate(args_as_list): if isinstance(value, PandasExtensionArray): diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 078fb7b7d79..1381110cbba 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -628,7 +628,8 @@ def get_indexer_nd(index: pd.Index, labels, method=None, tolerance=None) -> np.n flat_labels = np.ravel(labels) if flat_labels.dtype == "float16": flat_labels = flat_labels.astype("float64") - flat_indexer = index.get_indexer(flat_labels, method=method, tolerance=tolerance) + # pandas-stubs expects Index for get_indexer, but ndarray works at runtime + flat_indexer = index.get_indexer(flat_labels, method=method, tolerance=tolerance) # type: ignore[arg-type] indexer = flat_indexer.reshape(labels.shape) return indexer @@ -978,14 +979,15 @@ def remove_unused_levels_categories(index: T_PDIndex) -> T_PDIndex: Remove unused levels from 
MultiIndex and unused categories from CategoricalIndex """ if isinstance(index, pd.MultiIndex): - new_index = cast(pd.MultiIndex, index.remove_unused_levels()) + new_index = index.remove_unused_levels() # if it contains CategoricalIndex, we need to remove unused categories # manually. See https://github.com/pandas-dev/pandas/issues/30846 if any(isinstance(lev, pd.CategoricalIndex) for lev in new_index.levels): levels = [] for i, level in enumerate(new_index.levels): if isinstance(level, pd.CategoricalIndex): - level = level[new_index.codes[i]].remove_unused_categories() + # pandas-stubs is missing remove_unused_categories on CategoricalIndex + level = level[new_index.codes[i]].remove_unused_categories() # type: ignore[attr-defined] else: level = level[new_index.codes[i]] levels.append(level) @@ -1229,7 +1231,7 @@ def reorder_levels( its corresponding coordinates. """ - index = cast(pd.MultiIndex, self.index.reorder_levels(level_variables.keys())) + index = self.index.reorder_levels(list(level_variables.keys())) level_coords_dtype = {k: self.level_coords_dtype[k] for k in index.names} return self._replace(index, level_coords_dtype=level_coords_dtype) @@ -1378,28 +1380,29 @@ def sel(self, labels, method=None, tolerance=None) -> IndexSelResult: indexer = DataArray(indexer, coords=coords, dims=label.dims) if new_index is not None: + xr_index: PandasIndex | PandasMultiIndex if isinstance(new_index, pd.MultiIndex): level_coords_dtype = { k: self.level_coords_dtype[k] for k in new_index.names } - new_index = self._replace( + xr_index = self._replace( new_index, level_coords_dtype=level_coords_dtype ) dims_dict = {} - drop_coords = [] + drop_coords: list[Hashable] = [] else: - new_index = PandasIndex( + xr_index = PandasIndex( new_index, new_index.name, coord_dtype=self.level_coords_dtype[new_index.name], ) - dims_dict = {self.dim: new_index.index.name} + dims_dict = {self.dim: xr_index.index.name} drop_coords = [self.dim] # variable(s) attrs and encoding metadata are 
propagated # when replacing the indexes in the resulting xarray object - new_vars = new_index.create_variables() - indexes = cast(dict[Any, Index], dict.fromkeys(new_vars, new_index)) + new_vars = xr_index.create_variables() + indexes = cast(dict[Any, Index], dict.fromkeys(new_vars, xr_index)) # add scalar variable for each dropped level variables = new_vars diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 3a41f558700..5c9ba396faf 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -588,7 +588,8 @@ def _localize(obj: T, indexes_coords: SourceDest) -> tuple[T, SourceDest]: minval = np.nanmin(new_x_loaded) maxval = np.nanmax(new_x_loaded) index = x.to_index() - imin, imax = index.get_indexer([minval, maxval], method="nearest") + # pandas-stubs expects Index for get_indexer, but list works at runtime + imin, imax = index.get_indexer([minval, maxval], method="nearest") # type: ignore[arg-type] indexes[dim] = slice(max(imin - 2, 0), imax + 2) indexes_coords[dim] = (x[indexes[dim]], new_x) return obj.isel(indexes), indexes_coords # type: ignore[attr-defined] diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index fdc7fdc8edb..9300685dc77 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -7500,10 +7500,10 @@ def test_write_file_from_np_str(str_type: type[str | np.str_], tmpdir: str) -> N ) tdf.index.name = "scenario" tdf.columns.name = "year" - tdf = cast(pd.DataFrame, tdf.stack()) - tdf.name = "tas" + tdf_series = cast(pd.Series, tdf.stack()) + tdf_series.name = "tas" - txr = tdf.to_xarray() + txr = tdf_series.to_xarray() txr.to_netcdf(tmpdir.join("test.nc")) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index e61ea9e7fe8..d1aacba6aaa 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -3931,7 +3931,8 @@ def test_to_and_from_dict_with_nan_nat(self) -> None: y = np.random.randn(10, 3) y[2] = np.nan t = 
pd.Series(pd.date_range("20130101", periods=10)) - t[2] = np.nan + # pandas-stubs doesn't allow np.nan for datetime Series, but it converts to NaT + t[2] = np.nan # type: ignore[call-overload] lat = [77.7, 83.2, 76] da = DataArray(y, {"t": t, "lat": lat}, dims=["t", "lat"]) roundtripped = DataArray.from_dict(da.to_dict()) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index d25ef5a2771..4835c562d5a 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -3497,7 +3497,8 @@ def test_rename_multiindex(self) -> None: midx_coords = Coordinates.from_pandas_multiindex(midx, "x") original = Dataset({}, midx_coords) - midx_renamed = midx.rename(["a", "c"]) + # pandas-stubs expects Hashable for rename, but list of names works for MultiIndex + midx_renamed = midx.rename(["a", "c"]) # type: ignore[call-overload] midx_coords_renamed = Coordinates.from_pandas_multiindex(midx_renamed, "x") expected = Dataset({}, midx_coords_renamed) @@ -5602,7 +5603,8 @@ def test_to_and_from_dict_with_nan_nat( y = np.random.randn(10, 3) y[2] = np.nan t = pd.Series(pd.date_range("20130101", periods=10)) - t[2] = np.nan + # pandas-stubs doesn't allow np.nan for datetime Series, but it converts to NaT + t[2] = np.nan # type: ignore[call-overload] lat = [77.7, 83.2, 76] ds = Dataset( diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index 8fa3641c796..645a975a15a 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -1150,6 +1150,7 @@ def test_extension_array_attr(): assert (roundtripped == wrapped).all() interval_array = pd.arrays.IntervalArray.from_breaks([0, 1, 2, 3], closed="right") - wrapped = PandasExtensionArray(interval_array) + # pandas-stubs types PandasExtensionArray too narrowly; IntervalArray is valid + wrapped = PandasExtensionArray(interval_array) # type: ignore[arg-type] assert_array_equal(wrapped.left, interval_array.left, strict=True) assert wrapped.closed 
== interval_array.closed diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 3d8c3f5de2a..bb0b809448c 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -528,7 +528,7 @@ def test__infer_interval_breaks(self) -> None: [-0.5, 0.5, 5.0, 9.5, 10.5], _infer_interval_breaks([0, 1, 9, 10]) ) assert_array_equal( - pd.date_range("20000101", periods=4) - np.timedelta64(12, "h"), # type: ignore[operator] + pd.date_range("20000101", periods=4) - np.timedelta64(12, "h"), _infer_interval_breaks(pd.date_range("20000101", periods=3)), ) From f3831f1ba401fcb87071cafed55178895343efef Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 5 Jan 2026 10:07:11 -0700 Subject: [PATCH 2/6] Apply suggestion from @dcherian --- xarray/tests/test_dataarray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index d1aacba6aaa..49e28f68fea 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -3932,7 +3932,7 @@ def test_to_and_from_dict_with_nan_nat(self) -> None: y[2] = np.nan t = pd.Series(pd.date_range("20130101", periods=10)) # pandas-stubs doesn't allow np.nan for datetime Series, but it converts to NaT - t[2] = np.nan # type: ignore[call-overload] + t[2] = pd.NaT lat = [77.7, 83.2, 76] da = DataArray(y, {"t": t, "lat": lat}, dims=["t", "lat"]) roundtripped = DataArray.from_dict(da.to_dict()) From 30033d0ce44ad656b6b0319d47521cb1dbd96a5c Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 5 Jan 2026 10:17:26 -0700 Subject: [PATCH 3/6] Revert "Apply suggestion from @dcherian" This reverts commit f3831f1ba401fcb87071cafed55178895343efef. 
--- xarray/tests/test_dataarray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 49e28f68fea..d1aacba6aaa 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -3932,7 +3932,7 @@ def test_to_and_from_dict_with_nan_nat(self) -> None: y[2] = np.nan t = pd.Series(pd.date_range("20130101", periods=10)) # pandas-stubs doesn't allow np.nan for datetime Series, but it converts to NaT - t[2] = pd.NaT + t[2] = np.nan # type: ignore[call-overload] lat = [77.7, 83.2, 76] da = DataArray(y, {"t": t, "lat": lat}, dims=["t", "lat"]) roundtripped = DataArray.from_dict(da.to_dict()) From 4ab9b2af1699e55d58068cabb9ea997e83e37ba7 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 5 Jan 2026 10:18:25 -0700 Subject: [PATCH 4/6] Update get_indexer calls --- xarray/core/indexes.py | 5 +++-- xarray/core/missing.py | 3 +-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 1381110cbba..74729d8e105 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -628,8 +628,9 @@ def get_indexer_nd(index: pd.Index, labels, method=None, tolerance=None) -> np.n flat_labels = np.ravel(labels) if flat_labels.dtype == "float16": flat_labels = flat_labels.astype("float64") - # pandas-stubs expects Index for get_indexer, but ndarray works at runtime - flat_indexer = index.get_indexer(flat_labels, method=method, tolerance=tolerance) # type: ignore[arg-type] + flat_indexer = index.get_indexer( + pd.Index(flat_labels), method=method, tolerance=tolerance + ) indexer = flat_indexer.reshape(labels.shape) return indexer diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 5c9ba396faf..7f22d2df45c 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -588,8 +588,7 @@ def _localize(obj: T, indexes_coords: SourceDest) -> tuple[T, SourceDest]: minval = np.nanmin(new_x_loaded) maxval = 
np.nanmax(new_x_loaded) index = x.to_index() - # pandas-stubs expects Index for get_indexer, but list works at runtime - imin, imax = index.get_indexer([minval, maxval], method="nearest") # type: ignore[arg-type] + imin, imax = index.get_indexer(pd.Index([minval, maxval]), method="nearest") indexes[dim] = slice(max(imin - 2, 0), imax + 2) indexes_coords[dim] = (x[indexes[dim]], new_x) return obj.isel(indexes), indexes_coords # type: ignore[attr-defined] From ca72801c42a5ac43da377b04238f1c191e2fd70e Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 5 Jan 2026 10:20:06 -0700 Subject: [PATCH 5/6] remove cast to Series --- xarray/tests/test_backends.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 9300685dc77..a6163950968 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -7500,7 +7500,7 @@ def test_write_file_from_np_str(str_type: type[str | np.str_], tmpdir: str) -> N ) tdf.index.name = "scenario" tdf.columns.name = "year" - tdf_series = cast(pd.Series, tdf.stack()) + tdf_series = tdf.stack() tdf_series.name = "tas" txr = tdf_series.to_xarray() From 7381536f36b0f9cbcda2f9a7365379d39b9fd046 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 5 Jan 2026 10:30:26 -0700 Subject: [PATCH 6/6] one more fix --- xarray/tests/test_backends.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index a6163950968..51c9827f33a 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -7501,7 +7501,7 @@ def test_write_file_from_np_str(str_type: type[str | np.str_], tmpdir: str) -> N tdf.index.name = "scenario" tdf.columns.name = "year" tdf_series = tdf.stack() - tdf_series.name = "tas" + tdf_series.name = "tas" # type: ignore[union-attr] txr = tdf_series.to_xarray()