Skip to content

Commit 772db0d

Browse files
timtreisclaude
andcommitted
Split out groups filtering, fix deprecated dtype checks
Move groups filtering with na_color=None to a follow-up PR (#549). Replace deprecated pd.api.types.is_categorical_dtype with isinstance(dtype, pd.CategoricalDtype). Add TypeError for unsupported types in _coerce_categorical_source. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent d6b2441 commit 772db0d

File tree

6 files changed

+22
-58
lines changed

6 files changed

+22
-58
lines changed

src/spatialdata_plot/pl/basic.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -213,9 +213,8 @@ def render_shapes(
213213
`fill_alpha` will overwrite the value present in the cmap.
214214
groups : list[str] | str | None
215215
When using `color` and the key represents discrete labels, `groups` can be used to show only a subset of
216-
them. Other values are set to NA. When ``na_color=None``, non-matching elements are filtered out entirely
217-
(shapes and points only). If element is None, broadcasting behaviour is attempted (use the same values for
218-
all elements).
216+
them. Other values are set to NA. If element is None, broadcasting behaviour is attempted (use the same
217+
values for all elements).
219218
palette : list[str] | str | None
220219
Palette for discrete annotations. List of valid color names that should be used for the categories. Must
221220
match the number of groups. If element is None, broadcasting behaviour is attempted (use the same values for

src/spatialdata_plot/pl/render.py

Lines changed: 20 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -70,20 +70,32 @@
7070

7171

7272
def _coerce_categorical_source(cat_source: Any) -> pd.Categorical:
73-
"""Return a pandas Categorical from known, concrete sources only."""
73+
"""Return a pandas Categorical from known, concrete sources only.
74+
75+
Raises
76+
------
77+
TypeError
78+
If *cat_source* is not a ``dd.Series``, ``pd.Series``,
79+
``pd.Categorical``, or ``np.ndarray``.
80+
"""
7481
if isinstance(cat_source, dd.Series):
75-
if pd.api.types.is_categorical_dtype(cat_source.dtype) and getattr(cat_source.cat, "known", True) is False:
82+
if isinstance(cat_source.dtype, pd.CategoricalDtype) and getattr(cat_source.cat, "known", True) is False:
7683
cat_source = cat_source.cat.as_known()
7784
cat_source = cat_source.compute()
7885

7986
if isinstance(cat_source, pd.Series):
80-
if pd.api.types.is_categorical_dtype(cat_source.dtype):
87+
if isinstance(cat_source.dtype, pd.CategoricalDtype):
8188
return cat_source.array
8289
return pd.Categorical(cat_source)
8390
if isinstance(cat_source, pd.Categorical):
8491
return cat_source
92+
if isinstance(cat_source, np.ndarray):
93+
return pd.Categorical(cat_source)
8594

86-
return pd.Categorical(pd.Series(cat_source))
95+
raise TypeError(
96+
f"Cannot coerce {type(cat_source).__name__} to pd.Categorical. "
97+
"Expected dd.Series, pd.Series, pd.Categorical, or np.ndarray."
98+
)
8799

88100

89101
def _build_datashader_color_key(
@@ -209,20 +221,6 @@ def _render_shapes(
209221

210222
values_are_categorical = color_source_vector is not None
211223

212-
# When groups are specified and na_color is fully transparent (na_color=None),
213-
# filter out non-matching elements instead of showing them as invisible geometry.
214-
if groups is not None and values_are_categorical and render_params.cmap_params.na_color.alpha == "00":
215-
csv_series = pd.Series(color_source_vector)
216-
keep = csv_series.isin(groups).values
217-
shapes = shapes[keep].reset_index(drop=True)
218-
sdata_filt[element] = shapes
219-
color_source_vector = pd.Categorical(csv_series[keep].reset_index(drop=True))
220-
color_vector = (
221-
np.asarray(color_vector)[keep]
222-
if not hasattr(color_vector, "reset_index")
223-
else (color_vector[keep].reset_index(drop=True))
224-
)
225-
226224
# color_source_vector is None when the values aren't categorical
227225
if values_are_categorical and render_params.transfunc is not None:
228226
color_vector = render_params.transfunc(color_vector)
@@ -352,7 +350,7 @@ def _render_shapes(
352350
color_by_categorical = col_for_color is not None and color_source_vector is not None
353351
if color_by_categorical:
354352
cat_series = transformed_element[col_for_color]
355-
if not pd.api.types.is_categorical_dtype(cat_series):
353+
if not isinstance(cat_series.dtype, pd.CategoricalDtype):
356354
cat_series = cat_series.astype("category")
357355
transformed_element[col_for_color] = cat_series
358356

@@ -842,29 +840,6 @@ def _render_points(
842840
)
843841
points_dd = points_with_color_dd
844842

845-
# When groups are specified and na_color is fully transparent (na_color=None),
846-
# filter out non-matching points instead of rendering invisible geometry.
847-
if groups is not None and color_source_vector is not None and render_params.cmap_params.na_color.alpha == "00":
848-
csv_series = pd.Series(color_source_vector)
849-
keep = csv_series.isin(groups).values
850-
color_source_vector = pd.Categorical(csv_series[keep].reset_index(drop=True))
851-
color_vector = (
852-
np.asarray(color_vector)[keep]
853-
if not hasattr(color_vector, "reset_index")
854-
else (color_vector[keep].reset_index(drop=True))
855-
)
856-
# filter the materialized points, adata, and re-register in sdata_filt
857-
points = points[keep].reset_index(drop=True)
858-
adata = adata[keep].copy()
859-
points_dd = dask.dataframe.from_pandas(points, npartitions=1)
860-
sdata_filt.points[element] = PointsModel.parse(points_dd, coordinates={"x": "x", "y": "y"})
861-
set_transformation(
862-
element=sdata_filt.points[element],
863-
transformation=transformation_in_cs,
864-
to_coordinate_system=coordinate_system,
865-
)
866-
n_points = int(keep.sum())
867-
868843
# color_source_vector is None when the values aren't categorical
869844
if color_source_vector is None and render_params.transfunc is not None:
870845
color_vector = render_params.transfunc(color_vector)
@@ -931,11 +906,11 @@ def _render_points(
931906
color_dtype = transformed_element[col_for_color].dtype if col_for_color is not None else None
932907
color_by_categorical = col_for_color is not None and (
933908
color_source_vector is not None
934-
or pd.api.types.is_categorical_dtype(color_dtype)
909+
or isinstance(color_dtype, pd.CategoricalDtype)
935910
or pd.api.types.is_object_dtype(color_dtype)
936911
or pd.api.types.is_string_dtype(color_dtype)
937912
)
938-
if color_by_categorical and not pd.api.types.is_categorical_dtype(color_dtype):
913+
if color_by_categorical and not isinstance(color_dtype, pd.CategoricalDtype):
939914
transformed_element[col_for_color] = transformed_element[col_for_color].astype("category")
940915

941916
aggregate_with_reduction = None
@@ -944,7 +919,7 @@ def _render_points(
944919
if color_by_categorical:
945920
# add nan as category so that nan points are shown in the nan color
946921
cat_series = transformed_element[col_for_color]
947-
if not pd.api.types.is_categorical_dtype(cat_series):
922+
if not isinstance(cat_series.dtype, pd.CategoricalDtype):
948923
cat_series = cat_series.astype("category")
949924
if hasattr(cat_series.cat, "as_known"):
950925
cat_series = cat_series.cat.as_known()
-37.4 KB
Binary file not shown.
-23.2 KB
Binary file not shown.

tests/pl/test_render_points.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -576,11 +576,6 @@ def test_plot_can_annotate_points_with_nan_in_df_continuous_datashader(self, sda
576576
sdata_blobs["blobs_points"]["cont_color"] = pd.Series([np.nan, 2, 9, 13] * 50)
577577
sdata_blobs.pl.render_points("blobs_points", color="cont_color", size=40, method="datashader").pl.show()
578578

579-
def test_plot_groups_na_color_none_filters_points(self, sdata_blobs: SpatialData):
580-
"""When groups is set and na_color=None, non-matching points are filtered out entirely."""
581-
sdata_blobs["blobs_points"]["cat_color"] = pd.Series(["a", "b", "c", "a"] * 50, dtype="category")
582-
sdata_blobs.pl.render_points("blobs_points", color="cat_color", groups=["a"], na_color=None, size=30).pl.show()
583-
584579

585580
def test_raises_when_table_does_not_annotate_element(sdata_blobs: SpatialData):
586581
# Work on an independent copy since we mutate tables

tests/pl/test_render_shapes.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -983,11 +983,6 @@ def test_plot_can_annotate_shapes_with_nan_in_df_continuous_datashader(self, sda
983983
sdata_blobs["blobs_polygons"]["cont_color"] = [np.nan, 2, 3, 4, 5]
984984
sdata_blobs.pl.render_shapes("blobs_polygons", color="cont_color", method="datashader").pl.show()
985985

986-
def test_plot_groups_na_color_none_filters_shapes(self, sdata_blobs: SpatialData):
987-
"""When groups is set and na_color=None, non-matching shapes are filtered out entirely."""
988-
sdata_blobs["blobs_polygons"]["cat_color"] = pd.Series(["a", "b", "a", "b", "a"], dtype="category")
989-
sdata_blobs.pl.render_shapes("blobs_polygons", color="cat_color", groups=["a"], na_color=None).pl.show()
990-
991986

992987
def test_plot_can_handle_nan_values_in_color_data(sdata_blobs: SpatialData, caplog):
993988
"""Test that NaN values in color data are handled gracefully and logged."""

0 commit comments

Comments
 (0)