diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 9c74cd83..d45888a6 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -108,6 +108,37 @@ def _reparse_points( ) +def _warn_missing_groups( + groups: str | list[str], + color_source_vector: pd.Categorical, + col_for_color: str | None = None, +) -> None: + """Warn when ``groups`` contains values absent from the color column's categories.""" + groups_set = {groups} if isinstance(groups, str) else set(groups) + missing = groups_set - set(color_source_vector.categories) + if not missing: + return + col_label = f" '{col_for_color}'" if col_for_color else " the color column" + try: + missing_str = str(sorted(missing)) + except TypeError: + missing_str = str(list(missing)) + if missing == groups_set: + logger.warning( + f"None of the requested groups {missing_str} were found in{col_label}. " + "This usually means `groups` refers to values from a different column than `color`. " + "The `groups` parameter selects categories of the column specified via `color`." + ) + else: + try: + cats_str = str(sorted(color_source_vector.categories)) + except TypeError: + cats_str = str(list(color_source_vector.categories)) + logger.warning( + f"Groups {missing_str} were not found in{col_label} and will be ignored. Available categories: {cats_str}." + ) + + def _filter_groups_transparent_na( groups: str | list[str], color_source_vector: pd.Categorical, @@ -298,10 +329,13 @@ def _render_shapes( values_are_categorical = color_source_vector is not None + if groups is not None and color_source_vector is not None: + _warn_missing_groups(groups, color_source_vector, col_for_color) + # When groups are specified, filter out non-matching elements by default. # Only show non-matching elements if the user explicitly sets na_color. _na = render_params.cmap_params.na_color - if groups is not None and values_are_categorical and (_na.default_color_set or _na.alpha == "00"): + if groups is not None and color_source_vector is not None and (_na.default_color_set or _na.alpha == "00"): keep, color_source_vector, color_vector = _filter_groups_transparent_na( groups, color_source_vector, color_vector ) @@ -750,6 +784,9 @@ def _render_points( if added_color_from_table and col_for_color is not None: _reparse_points(sdata_filt, element, points_pd_with_color, transformation_in_cs, coordinate_system) + if groups is not None and color_source_vector is not None: + _warn_missing_groups(groups, color_source_vector, col_for_color) + # When groups are specified, filter out non-matching elements by default. # Only show non-matching elements if the user explicitly sets na_color. _na = render_params.cmap_params.na_color @@ -1298,6 +1335,9 @@ def _render_labels( else: assert color_source_vector is None + if groups is not None and color_source_vector is not None: + _warn_missing_groups(groups, color_source_vector, col_for_color) + # When groups are specified, zero out non-matching label IDs so they render as background. # Only show non-matching labels if the user explicitly sets na_color. _na = render_params.cmap_params.na_color diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 11787509..3fbb1743 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -553,6 +553,9 @@ def _create_patches( shapes_df, fill_c.tolist(), outline_c.tolist() if hasattr(outline_c, "tolist") else outline_c, s ) + if patches.empty: + return PatchCollection([]) + return PatchCollection( patches["geometry"].values.tolist(), snap=False, diff --git a/tests/pl/test_render_labels.py b/tests/pl/test_render_labels.py index 1dd6af8f..ab7c158d 100644 --- a/tests/pl/test_render_labels.py +++ b/tests/pl/test_render_labels.py @@ -12,6 +12,7 @@ from spatialdata.models import Labels2DModel, TableModel import spatialdata_plot # noqa: F401 +from spatialdata_plot._logging import logger, logger_warns from tests.conftest import DPI, PlotTester, PlotTesterMeta, _viridis_with_under_over, get_standard_RNG sc.pl.set_rcParams_defaults() @@ -428,3 +429,21 @@ def test_raises_when_table_does_not_annotate_element(sdata_blobs: SpatialData): color="channel_0_sum", table_name="other_table", ).pl.show() + + +def test_groups_warns_when_no_groups_match_labels(sdata_blobs: SpatialData, caplog): + """Warning fires when no groups match label color categories.""" + labels_name = "blobs_labels" + instances = get_element_instances(sdata_blobs[labels_name]) + n_obs = len(instances) + adata = AnnData(np.zeros((n_obs, 1))) + adata.obs["instance_id"] = instances.values + adata.obs["cat"] = pd.Categorical(["a", "b"] * (n_obs // 2) + ["a"] * (n_obs % 2)) + adata.obs["region"] = labels_name + sdata_blobs["label_table"] = TableModel.parse( + adata=adata, region_key="region", instance_key="instance_id", region=labels_name + ) + with logger_warns(caplog, logger, match="None of the requested groups"): + sdata_blobs.pl.render_labels( + labels_name, color="cat", groups=["nonexistent"], table_name="label_table", na_color=None + ).pl.show() diff --git a/tests/pl/test_render_points.py b/tests/pl/test_render_points.py index b5b29728..bc321b17 100644 --- a/tests/pl/test_render_points.py +++ b/tests/pl/test_render_points.py @@ -22,6 +22,7 @@ from spatialdata.transformations._utils import _set_transformations import spatialdata_plot # noqa: F401 +from spatialdata_plot._logging import logger, logger_warns from tests.conftest import DPI, PlotTester, PlotTesterMeta, _viridis_with_under_over, get_standard_RNG sc.pl.set_rcParams_defaults() @@ -607,6 +608,26 @@ def test_groups_na_color_none_no_match_points(sdata_blobs: SpatialData): ).pl.show() +@pytest.mark.parametrize("na_color", [None, "red"]) +def test_groups_warns_when_no_groups_match_points(sdata_blobs: SpatialData, caplog, na_color): + """Warning fires regardless of na_color when no groups match.""" + sdata_blobs["blobs_points"]["cat_color"] = pd.Series(["a", "b", "c", "a"] * 50, dtype="category") + with logger_warns(caplog, logger, match="None of the requested groups"): + sdata_blobs.pl.render_points( + "blobs_points", color="cat_color", groups=["nonexistent"], na_color=na_color, size=30 + ).pl.show() + + +@pytest.mark.parametrize("na_color", [None, "red"]) +def test_groups_warns_when_some_groups_missing_points(sdata_blobs: SpatialData, caplog, na_color): + """Warning fires regardless of na_color when some groups are missing.""" + sdata_blobs["blobs_points"]["cat_color"] = pd.Series(["a", "b", "c", "a"] * 50, dtype="category") + with logger_warns(caplog, logger, match="were not found in"): + sdata_blobs.pl.render_points( + "blobs_points", color="cat_color", groups=["a", "nonexistent"], na_color=na_color, size=30 + ).pl.show() + + def test_raises_when_table_does_not_annotate_element(sdata_blobs: SpatialData): # Work on an independent copy since we mutate tables sdata_blobs_local = deepcopy(sdata_blobs) diff --git a/tests/pl/test_render_shapes.py b/tests/pl/test_render_shapes.py index 5701fedf..3b7fb716 100644 --- a/tests/pl/test_render_shapes.py +++ b/tests/pl/test_render_shapes.py @@ -1016,6 +1016,26 @@ def test_groups_na_color_none_no_match_shapes(sdata_blobs: SpatialData): sdata_blobs.pl.render_shapes("blobs_polygons", color="cat_color", groups=["nonexistent"], na_color=None).pl.show() +@pytest.mark.parametrize("na_color", [None, "red"]) +def test_groups_warns_when_no_groups_match(sdata_blobs: SpatialData, caplog, na_color): + """Warning fires regardless of na_color when no groups match.""" + sdata_blobs["blobs_polygons"]["cat_color"] = pd.Series(["a", "b", "a", "b", "a"], dtype="category") + with logger_warns(caplog, logger, match="None of the requested groups"): + sdata_blobs.pl.render_shapes( + "blobs_polygons", color="cat_color", groups=["nonexistent"], na_color=na_color + ).pl.show() + + +@pytest.mark.parametrize("na_color", [None, "red"]) +def test_groups_warns_when_some_groups_missing(sdata_blobs: SpatialData, caplog, na_color): + """Warning fires regardless of na_color when some groups are missing.""" + sdata_blobs["blobs_polygons"]["cat_color"] = pd.Series(["a", "b", "a", "b", "a"], dtype="category") + with logger_warns(caplog, logger, match="were not found in"): + sdata_blobs.pl.render_shapes( + "blobs_polygons", color="cat_color", groups=["a", "nonexistent"], na_color=na_color + ).pl.show() + + def test_plot_can_handle_nan_values_in_color_data(sdata_blobs: SpatialData, caplog): """Test that NaN values in color data are handled gracefully and logged.""" sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_circles"] * sdata_blobs["table"].n_obs)