Skip to content

Commit 0fc3b35

Browse files
timtreisclaude
andcommitted
Filter non-matching elements when groups is set with na_color=None
When `groups` is specified and `na_color=None`, non-matching shapes and points are now filtered out entirely instead of rendered as invisible geometry. This avoids wasting rendering time on transparent elements. - Add _filter_groups_transparent_na() helper - Apply filtering in _render_shapes and _render_points - Update groups docstring to document na_color=None behavior - Add regression tests (baseline images pending from CI) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 3109c7d commit 0fc3b35

4 files changed

Lines changed: 56 additions & 2 deletions

File tree

src/spatialdata_plot/pl/basic.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,9 @@ def render_shapes(
213213
`fill_alpha` will overwrite the value present in the cmap.
214214
groups : list[str] | str | None
215215
When using `color` and the key represents discrete labels, `groups` can be used to show only a subset of
216-
them. Other values are set to NA. If element is None, broadcasting behaviour is attempted (use the same
217-
values for all elements).
216+
them. Other values are set to NA. When ``na_color=None``, non-matching elements are filtered out entirely
217+
(shapes and points only). If element is None, broadcasting behaviour is attempted (use the same values for
218+
all elements).
218219
palette : list[str] | str | None
219220
Palette for discrete annotations. List of valid color names that should be used for the categories. Must
220221
match the number of groups. If element is None, broadcasting behaviour is attempted (use the same values for

src/spatialdata_plot/pl/render.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,22 @@ def _build_datashader_color_key(
118118
return color_key
119119

120120

121+
def _filter_groups_transparent_na(
122+
groups: str | list[str],
123+
color_source_vector: pd.Categorical,
124+
color_vector: pd.Series | np.ndarray | list[str],
125+
) -> tuple[np.ndarray, pd.Categorical, np.ndarray]:
126+
"""Return a boolean mask and filtered color vectors for groups filtering.
127+
128+
Used when ``na_color=None`` (fully transparent) so that non-matching
129+
elements are removed entirely instead of rendered invisibly.
130+
"""
131+
keep = color_source_vector.isin(groups)
132+
filtered_csv = color_source_vector[keep]
133+
filtered_cv = np.asarray(color_vector)[keep]
134+
return keep, filtered_csv, filtered_cv
135+
136+
121137
def _split_colorbar_params(params: dict[str, object] | None) -> tuple[dict[str, object], dict[str, object], str | None]:
122138
"""Split colorbar params into layout hints, Matplotlib kwargs, and label override."""
123139
layout: dict[str, object] = {}
@@ -221,6 +237,15 @@ def _render_shapes(
221237

222238
values_are_categorical = color_source_vector is not None
223239

240+
# When groups are specified and na_color is fully transparent (na_color=None),
241+
# filter out non-matching elements instead of showing them as invisible geometry.
242+
if groups is not None and values_are_categorical and render_params.cmap_params.na_color.alpha == "00":
243+
keep, color_source_vector, color_vector = _filter_groups_transparent_na(
244+
groups, color_source_vector, color_vector
245+
)
246+
shapes = shapes[keep].reset_index(drop=True)
247+
sdata_filt[element] = shapes
248+
224249
# color_source_vector is None when the values aren't categorical
225250
if values_are_categorical and render_params.transfunc is not None:
226251
color_vector = render_params.transfunc(color_vector)
@@ -840,6 +865,24 @@ def _render_points(
840865
)
841866
points_dd = points_with_color_dd
842867

868+
# When groups are specified and na_color is fully transparent (na_color=None),
869+
# filter out non-matching points instead of rendering invisible geometry.
870+
if groups is not None and color_source_vector is not None and render_params.cmap_params.na_color.alpha == "00":
871+
keep, color_source_vector, color_vector = _filter_groups_transparent_na(
872+
groups, color_source_vector, color_vector
873+
)
874+
# filter the materialized points, adata, and re-register in sdata_filt
875+
points = points[keep].reset_index(drop=True)
876+
adata = adata[keep].copy()
877+
points_dd = dask.dataframe.from_pandas(points, npartitions=1)
878+
sdata_filt.points[element] = PointsModel.parse(points_dd, coordinates={"x": "x", "y": "y"})
879+
set_transformation(
880+
element=sdata_filt.points[element],
881+
transformation=transformation_in_cs,
882+
to_coordinate_system=coordinate_system,
883+
)
884+
n_points = int(keep.sum())
885+
843886
# color_source_vector is None when the values aren't categorical
844887
if color_source_vector is None and render_params.transfunc is not None:
845888
color_vector = render_params.transfunc(color_vector)

tests/pl/test_render_points.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,11 @@ def test_plot_can_annotate_points_with_nan_in_df_continuous_datashader(self, sda
576576
sdata_blobs["blobs_points"]["cont_color"] = pd.Series([np.nan, 2, 9, 13] * 50)
577577
sdata_blobs.pl.render_points("blobs_points", color="cont_color", size=40, method="datashader").pl.show()
578578

579+
def test_plot_groups_na_color_none_filters_points(self, sdata_blobs: SpatialData):
580+
"""When groups is set and na_color=None, non-matching points are filtered out entirely."""
581+
sdata_blobs["blobs_points"]["cat_color"] = pd.Series(["a", "b", "c", "a"] * 50, dtype="category")
582+
sdata_blobs.pl.render_points("blobs_points", color="cat_color", groups=["a"], na_color=None, size=30).pl.show()
583+
579584

580585
def test_raises_when_table_does_not_annotate_element(sdata_blobs: SpatialData):
581586
# Work on an independent copy since we mutate tables

tests/pl/test_render_shapes.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -983,6 +983,11 @@ def test_plot_can_annotate_shapes_with_nan_in_df_continuous_datashader(self, sda
983983
sdata_blobs["blobs_polygons"]["cont_color"] = [np.nan, 2, 3, 4, 5]
984984
sdata_blobs.pl.render_shapes("blobs_polygons", color="cont_color", method="datashader").pl.show()
985985

986+
def test_plot_groups_na_color_none_filters_shapes(self, sdata_blobs: SpatialData):
987+
"""When groups is set and na_color=None, non-matching shapes are filtered out entirely."""
988+
sdata_blobs["blobs_polygons"]["cat_color"] = pd.Series(["a", "b", "a", "b", "a"], dtype="category")
989+
sdata_blobs.pl.render_shapes("blobs_polygons", color="cat_color", groups=["a"], na_color=None).pl.show()
990+
986991

987992
def test_plot_can_handle_nan_values_in_color_data(sdata_blobs: SpatialData, caplog):
988993
"""Test that NaN values in color data are handled gracefully and logged."""

0 commit comments

Comments
 (0)