Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,37 @@ def _reparse_points(
)


def _warn_missing_groups(
groups: str | list[str],
color_source_vector: pd.Categorical,
col_for_color: str | None = None,
) -> None:
"""Warn when ``groups`` contains values absent from the color column's categories."""
groups_set = {groups} if isinstance(groups, str) else set(groups)
missing = groups_set - set(color_source_vector.categories)
if not missing:
return
col_label = f" '{col_for_color}'" if col_for_color else " the color column"
try:
missing_str = str(sorted(missing))
except TypeError:
missing_str = str(list(missing))
if missing == groups_set:
logger.warning(
f"None of the requested groups {missing_str} were found in{col_label}. "
"This usually means `groups` refers to values from a different column than `color`. "
"The `groups` parameter selects categories of the column specified via `color`."
)
else:
try:
cats_str = str(sorted(color_source_vector.categories))
except TypeError:
cats_str = str(list(color_source_vector.categories))
logger.warning(
f"Groups {missing_str} were not found in{col_label} and will be ignored. Available categories: {cats_str}."
)


def _filter_groups_transparent_na(
groups: str | list[str],
color_source_vector: pd.Categorical,
Expand Down Expand Up @@ -298,10 +329,13 @@ def _render_shapes(

values_are_categorical = color_source_vector is not None

if groups is not None and color_source_vector is not None:
_warn_missing_groups(groups, color_source_vector, col_for_color)

# When groups are specified, filter out non-matching elements by default.
# Only show non-matching elements if the user explicitly sets na_color.
_na = render_params.cmap_params.na_color
if groups is not None and values_are_categorical and (_na.default_color_set or _na.alpha == "00"):
if groups is not None and color_source_vector is not None and (_na.default_color_set or _na.alpha == "00"):
keep, color_source_vector, color_vector = _filter_groups_transparent_na(
groups, color_source_vector, color_vector
)
Expand Down Expand Up @@ -750,6 +784,9 @@ def _render_points(
if added_color_from_table and col_for_color is not None:
_reparse_points(sdata_filt, element, points_pd_with_color, transformation_in_cs, coordinate_system)

if groups is not None and color_source_vector is not None:
_warn_missing_groups(groups, color_source_vector, col_for_color)

# When groups are specified, filter out non-matching elements by default.
# Only show non-matching elements if the user explicitly sets na_color.
_na = render_params.cmap_params.na_color
Expand Down Expand Up @@ -1298,6 +1335,9 @@ def _render_labels(
else:
assert color_source_vector is None

if groups is not None and color_source_vector is not None:
_warn_missing_groups(groups, color_source_vector, col_for_color)

# When groups are specified, zero out non-matching label IDs so they render as background.
# Only show non-matching labels if the user explicitly sets na_color.
_na = render_params.cmap_params.na_color
Expand Down
3 changes: 3 additions & 0 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,9 @@ def _create_patches(
shapes_df, fill_c.tolist(), outline_c.tolist() if hasattr(outline_c, "tolist") else outline_c, s
)

if patches.empty:
return PatchCollection([])

return PatchCollection(
patches["geometry"].values.tolist(),
snap=False,
Expand Down
19 changes: 19 additions & 0 deletions tests/pl/test_render_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from spatialdata.models import Labels2DModel, TableModel

import spatialdata_plot # noqa: F401
from spatialdata_plot._logging import logger, logger_warns
from tests.conftest import DPI, PlotTester, PlotTesterMeta, _viridis_with_under_over, get_standard_RNG

sc.pl.set_rcParams_defaults()
Expand Down Expand Up @@ -428,3 +429,21 @@ def test_raises_when_table_does_not_annotate_element(sdata_blobs: SpatialData):
color="channel_0_sum",
table_name="other_table",
).pl.show()


def test_groups_warns_when_no_groups_match_labels(sdata_blobs: SpatialData, caplog):
"""Warning fires when no groups match label color categories."""
labels_name = "blobs_labels"
instances = get_element_instances(sdata_blobs[labels_name])
n_obs = len(instances)
adata = AnnData(np.zeros((n_obs, 1)))
adata.obs["instance_id"] = instances.values
adata.obs["cat"] = pd.Categorical(["a", "b"] * (n_obs // 2) + ["a"] * (n_obs % 2))
adata.obs["region"] = labels_name
sdata_blobs["label_table"] = TableModel.parse(
adata=adata, region_key="region", instance_key="instance_id", region=labels_name
)
with logger_warns(caplog, logger, match="None of the requested groups"):
sdata_blobs.pl.render_labels(
labels_name, color="cat", groups=["nonexistent"], table_name="label_table", na_color=None
).pl.show()
21 changes: 21 additions & 0 deletions tests/pl/test_render_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from spatialdata.transformations._utils import _set_transformations

import spatialdata_plot # noqa: F401
from spatialdata_plot._logging import logger, logger_warns
from tests.conftest import DPI, PlotTester, PlotTesterMeta, _viridis_with_under_over, get_standard_RNG

sc.pl.set_rcParams_defaults()
Expand Down Expand Up @@ -607,6 +608,26 @@ def test_groups_na_color_none_no_match_points(sdata_blobs: SpatialData):
).pl.show()


@pytest.mark.parametrize("na_color", [None, "red"])
def test_groups_warns_when_no_groups_match_points(sdata_blobs: SpatialData, caplog, na_color):
"""Warning fires regardless of na_color when no groups match."""
sdata_blobs["blobs_points"]["cat_color"] = pd.Series(["a", "b", "c", "a"] * 50, dtype="category")
with logger_warns(caplog, logger, match="None of the requested groups"):
sdata_blobs.pl.render_points(
"blobs_points", color="cat_color", groups=["nonexistent"], na_color=na_color, size=30
).pl.show()


@pytest.mark.parametrize("na_color", [None, "red"])
def test_groups_warns_when_some_groups_missing_points(sdata_blobs: SpatialData, caplog, na_color):
"""Warning fires regardless of na_color when some groups are missing."""
sdata_blobs["blobs_points"]["cat_color"] = pd.Series(["a", "b", "c", "a"] * 50, dtype="category")
with logger_warns(caplog, logger, match="were not found in"):
sdata_blobs.pl.render_points(
"blobs_points", color="cat_color", groups=["a", "nonexistent"], na_color=na_color, size=30
).pl.show()


def test_raises_when_table_does_not_annotate_element(sdata_blobs: SpatialData):
# Work on an independent copy since we mutate tables
sdata_blobs_local = deepcopy(sdata_blobs)
Expand Down
20 changes: 20 additions & 0 deletions tests/pl/test_render_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1016,6 +1016,26 @@ def test_groups_na_color_none_no_match_shapes(sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes("blobs_polygons", color="cat_color", groups=["nonexistent"], na_color=None).pl.show()


@pytest.mark.parametrize("na_color", [None, "red"])
def test_groups_warns_when_no_groups_match(sdata_blobs: SpatialData, caplog, na_color):
"""Warning fires regardless of na_color when no groups match."""
sdata_blobs["blobs_polygons"]["cat_color"] = pd.Series(["a", "b", "a", "b", "a"], dtype="category")
with logger_warns(caplog, logger, match="None of the requested groups"):
sdata_blobs.pl.render_shapes(
"blobs_polygons", color="cat_color", groups=["nonexistent"], na_color=na_color
).pl.show()


@pytest.mark.parametrize("na_color", [None, "red"])
def test_groups_warns_when_some_groups_missing(sdata_blobs: SpatialData, caplog, na_color):
"""Warning fires regardless of na_color when some groups are missing."""
sdata_blobs["blobs_polygons"]["cat_color"] = pd.Series(["a", "b", "a", "b", "a"], dtype="category")
with logger_warns(caplog, logger, match="were not found in"):
sdata_blobs.pl.render_shapes(
"blobs_polygons", color="cat_color", groups=["a", "nonexistent"], na_color=na_color
).pl.show()


def test_plot_can_handle_nan_values_in_color_data(sdata_blobs: SpatialData, caplog):
"""Test that NaN values in color data are handled gracefully and logged."""
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_circles"] * sdata_blobs["table"].n_obs)
Expand Down
Loading