diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py
index 8232f74c..3b12347e 100644
--- a/src/spatialdata/_core/query/relational_query.py
+++ b/src/spatialdata/_core/query/relational_query.py
@@ -11,6 +11,7 @@
 import dask.array as da
 import numpy as np
 import pandas as pd
+import xarray as xr
 from anndata import AnnData
 from annsel.core.typing import Predicates
 from dask.dataframe import DataFrame as DaskDataFrame
@@ -311,7 +312,10 @@ def _get_masked_element(
 
 
 def _right_exclusive_join_spatialelement_table(
-    element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"]
+    element_dict: dict[str, dict[str, Any]],
+    table: AnnData,
+    match_rows: Literal["left", "no", "right"],
+    filter_label_pixels: bool | None = None,
 ) -> tuple[dict[str, Any], AnnData | None]:
     regions, region_column_name, instance_key = get_table_keys(table)
     if isinstance(regions, str):
@@ -349,7 +353,10 @@ def _right_exclusive_join_spatialelement_table(
 
 
 def _right_join_spatialelement_table(
-    element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"]
+    element_dict: dict[str, dict[str, Any]],
+    table: AnnData,
+    match_rows: Literal["left", "no", "right"],
+    filter_label_pixels: bool | None = None,
 ) -> tuple[dict[str, Any], AnnData]:
     if match_rows == "left":
         warnings.warn("Matching rows 'left' is not supported for 'right' join.", UserWarning, stacklevel=2)
@@ -365,11 +372,18 @@ def _right_join_spatialelement_table(
                 if element_type in ["points", "shapes"]:
                     element_indices = element.index
                 else:
-                    warnings.warn(
-                        f"Element type `labels` not supported for 'right' join. Skipping `{name}`",
-                        UserWarning,
-                        stacklevel=2,
-                    )
+                    if filter_label_pixels is True:
+                        element_dict[element_type][name] = _filter_labels_element(
+                            element, table_instance_key_column.tolist()
+                        )
+                    elif filter_label_pixels is None:
+                        warnings.warn(
+                            f"Element type `labels` not supported for 'right' join, pixels are not filtered;"
+                            f" pass `filter_label_pixels=True` to filter or `filter_label_pixels=False` to silence"
+                            f" this warning. Skipping `{name}`",
+                            UserWarning,
+                            stacklevel=2,
+                        )
                     continue
 
                 masked_element = _get_masked_element(element_indices, element, table_instance_key_column, match_rows)
@@ -383,7 +397,10 @@ def _right_join_spatialelement_table(
 
 
 def _inner_join_spatialelement_table(
-    element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"]
+    element_dict: dict[str, dict[str, Any]],
+    table: AnnData,
+    match_rows: Literal["left", "no", "right"],
+    filter_label_pixels: bool | None = None,
 ) -> tuple[dict[str, Any], AnnData]:
     regions, region_column_name, instance_key = get_table_keys(table)
     if isinstance(regions, str):
@@ -399,11 +416,18 @@ def _inner_join_spatialelement_table(
                 if element_type in ["points", "shapes"]:
                     element_indices = element.index
                 else:
-                    warnings.warn(
-                        f"Element type `labels` not supported for 'inner' join. Skipping `{name}`",
-                        UserWarning,
-                        stacklevel=2,
-                    )
+                    if filter_label_pixels is True:
+                        element_dict[element_type][name] = _filter_labels_element(
+                            element, table_instance_key_column.tolist()
+                        )
+                    elif filter_label_pixels is None:
+                        warnings.warn(
+                            f"Element type `labels` not supported for 'inner' join, pixels are not filtered;"
+                            f" pass `filter_label_pixels=True` to filter or `filter_label_pixels=False` to silence"
+                            f" this warning. Skipping `{name}`",
+                            UserWarning,
+                            stacklevel=2,
+                        )
                     continue
 
                 masked_element = _get_masked_element(element_indices, element, table_instance_key_column, match_rows)
@@ -429,7 +453,10 @@ def _inner_join_spatialelement_table(
 
 
 def _left_exclusive_join_spatialelement_table(
-    element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"]
+    element_dict: dict[str, dict[str, Any]],
+    table: AnnData,
+    match_rows: Literal["left", "no", "right"],
+    filter_label_pixels: bool | None = None,
 ) -> tuple[dict[str, Any], AnnData | None]:
     regions, region_column_name, instance_key = get_table_keys(table)
     if isinstance(regions, str):
@@ -462,7 +489,10 @@ def _left_exclusive_join_spatialelement_table(
 
 
 def _left_join_spatialelement_table(
-    element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"]
+    element_dict: dict[str, dict[str, Any]],
+    table: AnnData,
+    match_rows: Literal["left", "no", "right"],
+    filter_label_pixels: bool | None = None,
 ) -> tuple[dict[str, Any], AnnData]:
     if match_rows == "right":
         warnings.warn("Matching rows 'right' is not supported for 'left' join.", UserWarning, stacklevel=2)
@@ -586,6 +616,7 @@ def join_spatialelement_table(
     table: AnnData | None = None,
     how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "left",
     match_rows: Literal["no", "left", "right"] = "no",
+    filter_label_pixels: bool | None = None,
 ) -> tuple[dict[str, Any], AnnData]:
     """
     Join SpatialElement(s) and table together in SQL like manner.
@@ -629,6 +660,11 @@ def join_spatialelement_table(
     match_rows
         Whether to match the indices of the element and table and if so how. If ``'left'``, element_indices take
         priority and if ``'right'`` table instance ids take priority.
+    filter_label_pixels
+        Controls pixel-level filtering of label elements for ``'right'`` and ``'inner'`` joins.
+        If ``True``, pixels whose instance id is not present in the table are set to zero.
+        If ``None`` (default), label elements are returned unfiltered and a warning is issued.
+        If ``False``, label elements are returned unfiltered silently (no warning).
 
     Returns
     -------
@@ -694,12 +730,16 @@ def join_spatialelement_table(
             for name, element in getattr(derived_sdata, element_type).items():
                 elements_dict[element_type][name] = element
 
-    elements_dict_joined, table = _call_join(elements_dict, table, how, match_rows)
+    elements_dict_joined, table = _call_join(elements_dict, table, how, match_rows, filter_label_pixels)
     return elements_dict_joined, table
 
 
 def _call_join(
-    elements_dict: dict[str, dict[str, Any]], table: AnnData, how: str, match_rows: Literal["no", "left", "right"]
+    elements_dict: dict[str, dict[str, Any]],
+    table: AnnData,
+    how: str,
+    match_rows: Literal["no", "left", "right"],
+    filter_label_pixels: bool | None = None,
 ) -> tuple[dict[str, Any], AnnData]:
     assert any(key in elements_dict for key in ["labels", "shapes", "points"]), (
         "No valid element to join in spatial_element_name. Must provide at least one of either `labels`, `points` or "
@@ -714,7 +754,7 @@ def _call_join(
     # if how in JoinTypes.__dict__["_member_names_"]:
     # hotfix for bug with Python 3.13:
     if how in JoinTypes.__dict__:
-        elements_dict, table = getattr(JoinTypes, how)(elements_dict, table, match_rows)
+        elements_dict, table = getattr(JoinTypes, how)(elements_dict, table, match_rows, filter_label_pixels)
     else:
         raise TypeError(f"`{how}` is not a valid type of join.")
 
@@ -797,6 +837,7 @@ def match_sdata_to_table(
     table_name: str,
     table: AnnData | None = None,
     how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "right",
+    filter_label_pixels: bool | None = None,
 ) -> SpatialData:
     """
     Filter the elements of a SpatialData object to match only the rows present in the table.
@@ -812,6 +853,10 @@ def match_sdata_to_table(
         `table_name` is used to name the table in the returned `SpatialData` object.
     how
         The type of join to perform. See :func:`spatialdata.join_spatialelement_table`. Default is "right".
+    filter_label_pixels
+        Controls pixel-level filtering of label elements. ``True`` filters pixels, ``None`` (default) leaves them
+        unfiltered and warns, ``False`` leaves them unfiltered silently. See
+        :func:`spatialdata.join_spatialelement_table` for details.
 
     Notes
     -----
@@ -823,7 +868,7 @@ def match_sdata_to_table(
     _, region_key, instance_key = get_table_keys(table)
     annotated_regions = SpatialData.get_annotated_regions(table)
     filtered_elements, filtered_table = join_spatialelement_table(
-        sdata, spatial_element_names=annotated_regions, table=table, how=how
+        sdata, spatial_element_names=annotated_regions, table=table, how=how, filter_label_pixels=filter_label_pixels
     )
     filtered_table = TableModel.parse(
         filtered_table,
@@ -847,6 +892,7 @@ def filter_by_table_query(
     var_names_expr: Predicates | None = None,
     layer: str | None = None,
     how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "right",
+    filter_label_pixels: bool | None = None,
 ) -> SpatialData:
     """Filter the SpatialData object based on a set of table queries.
 
@@ -875,6 +921,10 @@ def filter_by_table_query(
         The layer of the :class:`anndata.AnnData` to filter the SpatialData object by, only used with `x_expr`.
     how
         The type of join to perform. See :func:`spatialdata.join_spatialelement_table`. Default is "right".
+    filter_label_pixels
+        Controls pixel-level filtering of label elements. ``True`` filters pixels, ``None`` (default) leaves them
+        unfiltered and warns, ``False`` leaves them unfiltered silently. See
+        :func:`spatialdata.join_spatialelement_table` for details.
 
     Returns
     -------
@@ -899,7 +949,13 @@ def filter_by_table_query(
         obs=obs_expr, var=var_expr, x=x_expr, obs_names=obs_names_expr, var_names=var_names_expr, layer=layer
     )
 
-    return match_sdata_to_table(sdata=sdata_subset, table_name=table_name, table=filtered_table, how=how)
+    return match_sdata_to_table(
+        sdata=sdata_subset,
+        table_name=table_name,
+        table=filtered_table,
+        how=how,
+        filter_label_pixels=filter_label_pixels,
+    )
 
 
 @dataclass
@@ -1099,3 +1155,68 @@ def get_values(
         return df
 
     raise ValueError(f"Unknown origin {origin}")
+
+
+def _mask_block(block: np.ndarray, ids_to_remove: list[int]) -> np.ndarray:
+    """Return a copy of ``block`` in which every value listed in ``ids_to_remove`` is replaced by 0."""
+    # Work on a copy to avoid modifying a read-only array.
+    result = block.copy()
+    result[np.isin(result, ids_to_remove)] = 0
+    return result
+
+
+def _set_instance_ids_in_labels_to_zero(image: xr.DataArray, ids_to_remove: list[int]) -> xr.DataArray:
+    """
+    Zero out the pixels of ``image`` whose value is in ``ids_to_remove``.
+
+    The masking is applied through :func:`xarray.apply_ufunc` over the ``("y", "x")`` core dims, and the result
+    is re-wrapped with the coords, the dims and a copy of the attrs of ``image``.
+    """
+    processed = xr.apply_ufunc(
+        partial(_mask_block, ids_to_remove=ids_to_remove),
+        image,
+        input_core_dims=[["y", "x"]],
+        output_core_dims=[["y", "x"]],
+        vectorize=True,
+        dask="parallelized",
+        output_dtypes=[image.dtype],
+        dataset_fill_value=0,
+        dask_gufunc_kwargs={"allow_rechunk": True},
+    )
+
+    # Create a new DataArray to ensure persistence
+    return xr.DataArray(
+        data=processed.data,
+        coords=image.coords,
+        dims=image.dims,
+        attrs=image.attrs.copy(),  # Preserve all attributes
+    )
+
+
+def _filter_labels_element(element: DataArray | DataTree, ids_to_keep: list[int]) -> DataArray | DataTree:
+    """
+    Return a 2D labels element in which the pixels of every instance not in ``ids_to_keep`` are set to 0.
+
+    For a multiscale element only the first scale is masked; it is then re-parsed with ``scale_factors``
+    computed from the first-axis shape ratio of consecutive scales.
+
+    Raises
+    ------
+    NotImplementedError
+        If ``element`` is a 3D labels element.
+    """
+    if get_model(element) is Labels3DModel:
+        raise NotImplementedError("Pixel-level filtering of 3D labels is not supported.")
+    # Build the lookup set once instead of once per instance id.
+    ids_to_keep_set = set(ids_to_keep)
+    ids_to_remove = [i for i in get_element_instances(element) if i not in ids_to_keep_set]
+    if isinstance(element, DataArray):
+        return Labels2DModel.parse(_set_instance_ids_in_labels_to_zero(element, ids_to_remove))
+    scales = list(element.keys())
+    scale_factors = [
+        round(element[scales[i]].image.shape[0] / element[scales[i + 1]].image.shape[0]) for i in range(len(scales) - 1)
+    ]
+    return Labels2DModel.parse(
+        _set_instance_ids_in_labels_to_zero(element[scales[0]].image, ids_to_remove),
+        scale_factors=scale_factors,
+    )
diff --git a/tests/core/query/test_relational_query_match_sdata_to_table.py b/tests/core/query/test_relational_query_match_sdata_to_table.py
index c2468999..a2af0251 100644
--- a/tests/core/query/test_relational_query_match_sdata_to_table.py
+++ b/tests/core/query/test_relational_query_match_sdata_to_table.py
@@ -1,9 +1,18 @@
 from __future__ import annotations
 
+import contextlib
+
+import annsel as an
+import numpy as np
+import pandas as pd
 import pytest
+from anndata import AnnData
+from xarray import DataArray
 
 from spatialdata import SpatialData, concatenate, match_sdata_to_table
+from spatialdata._core.query.relational_query import filter_by_table_query
 from spatialdata.datasets import blobs_annotating_element
+from spatialdata.models import Labels3DModel, TableModel
 
 
 def _make_test_data() -> SpatialData:
@@ -105,35 +114,48 @@ def test_match_sdata_to_table_shapes_and_points():
     assert "blobs_polygons-sdata1" not in matched
 
 
-def test_match_sdata_to_table_match_labels_error():
+@pytest.mark.parametrize("subset_func_name", ["match_sdata_to_table", "filter_by_table_query"])
+@pytest.mark.parametrize("element_name", ["blobs_labels", "blobs_multiscale_labels"])
+def test_filter_out_instances(subset_func_name: str, element_name: str) -> None:
     """
-    match_sdata_to_table() uses the join operations; so when trying to match labels, the error will be raised by the
-    join.
+    By default a warning is issued when labels are encountered in a right join and pixels are not filtered.
+    Passing filter_label_pixels=True filters the label pixels to match the table.
     """
-    sdata = _make_test_data()
-    sdata["table"].obs["region"] = sdata["table"].obs["region"].apply(lambda x: x.replace("polygons", "labels"))
-    sdata["table"].obs["region"] = sdata["table"].obs["region"].astype("category")
-    sdata.set_table_annotates_spatialelement(
-        table_name="table",
-        region=["blobs_labels-sdata1", "blobs_labels-sdata2"],
-        region_key="region",
-        instance_key="instance_id",
-    )
-
-    with pytest.warns(
-        UserWarning,
-        match="Element type `labels` not supported for 'right' join. Skipping ",
-    ):
-        matched = match_sdata_to_table(
-            sdata,
-            table=sdata["table"],
-            table_name="table",
+    sdata = blobs_annotating_element(element_name)
+    keep_id = 3
+    table = sdata.tables["table"]
+    subset_table = table[table.obs["instance_id"] == keep_id]
+
+    # None → warning issued; False → silenced, pixels still unfiltered
+    for flp, ctx in [
+        (None, pytest.warns(UserWarning, match="pixels are not filtered")),
+        (False, contextlib.nullcontext()),
+    ]:
+        with ctx:
+            if subset_func_name == "match_sdata_to_table":
+                match_sdata_to_table(sdata, "table", table=subset_table, filter_label_pixels=flp)
+            else:
+                filter_by_table_query(
+                    sdata, "table", obs_expr=an.col("instance_id") == keep_id, filter_label_pixels=flp
+                )
+
+    # filter_label_pixels=True: pixels are zeroed for removed instances
+    if subset_func_name == "match_sdata_to_table":
+        subset_sdata = match_sdata_to_table(sdata, "table", table=subset_table, filter_label_pixels=True)
+    else:
+        subset_sdata = filter_by_table_query(
+            sdata, "table", obs_expr=an.col("instance_id") == keep_id, filter_label_pixels=True
         )
 
-    assert len(matched["table"]) == 10
-    assert "blobs_labels-sdata1" in matched
-    assert "blobs_labels-sdata2" in matched
-    assert "blobs_points-sdata1" not in matched
+    elem = subset_sdata[element_name]
+    if isinstance(elem, DataArray):
+        remaining_ids = set(np.unique(elem.data.compute())) - {0}
+        assert remaining_ids == {keep_id}
+    else:  # DataTree (multiscale)
+        for scale in elem:
+            # at coarser scales an instance may vanish due to downsampling, but no other instance should appear
+            remaining_ids = set(np.unique(elem[scale].image.compute())) - {0}
+            assert remaining_ids <= {keep_id}
 
 
 def test_match_sdata_to_table_no_table_argument(sdata):
@@ -145,3 +167,69 @@ def test_match_sdata_to_table_no_table_argument(sdata):
     assert len(matched["table"]) == 10
     assert "blobs_polygons-sdata1" in matched
     assert "blobs_polygons-sdata2" in matched
+
+
+@pytest.mark.parametrize("subset_func_name", ["match_sdata_to_table", "filter_by_table_query"])
+def test_subset_sdata_by_table_mask(subset_func_name: str) -> None:
+    """
+    Subsetting an sdata with mixed element types (labels, shapes, points) keeps only the requested instances.
+    Labels are pixel-filtered when filter_label_pixels=True.
+    """
+    sdata = concatenate(
+        {
+            "labels": blobs_annotating_element("blobs_labels"),
+            "shapes": blobs_annotating_element("blobs_circles"),
+            "points": blobs_annotating_element("blobs_points"),
+            "multiscale_labels": blobs_annotating_element("blobs_multiscale_labels"),
+        },
+        concatenate_tables=True,
+    )
+    table = sdata.tables["table"]
+    subset_table = table[table.obs["instance_id"] == 3]
+
+    if subset_func_name == "match_sdata_to_table":
+        subset_sdata = match_sdata_to_table(sdata, "table", table=subset_table, filter_label_pixels=True)
+    else:
+        subset_sdata = filter_by_table_query(
+            sdata, "table", obs_expr=an.col("instance_id") == 3, filter_label_pixels=True
+        )
+
+    assert set(subset_sdata.labels.keys()) == {"blobs_labels-labels", "blobs_multiscale_labels-multiscale_labels"}
+    assert set(subset_sdata.points.keys()) == {"blobs_points-points"}
+    assert set(subset_sdata.shapes.keys()) == {"blobs_circles-shapes"}
+
+    labels_remaining_ids = set(np.unique(subset_sdata.labels["blobs_labels-labels"].data.compute())) - {0}
+    assert labels_remaining_ids == {3}
+
+    for scale in subset_sdata.labels["blobs_multiscale_labels-multiscale_labels"]:
+        ms_labels_remaining_ids = set(
+            np.unique(subset_sdata.labels["blobs_multiscale_labels-multiscale_labels"][scale].image.compute())
+        ) - {0}
+        assert ms_labels_remaining_ids == {3}
+
+    points_remaining_ids = set(np.unique(subset_sdata.points["blobs_points-points"].index)) - {0}
+    assert points_remaining_ids == {3}
+
+    shapes_remaining_ids = set(np.unique(subset_sdata.shapes["blobs_circles-shapes"].index)) - {0}
+    assert shapes_remaining_ids == {3}
+
+
+@pytest.mark.parametrize("subset_func_name", ["match_sdata_to_table", "filter_by_table_query"])
+def test_filter_out_instances_3d_labels_not_supported(subset_func_name: str) -> None:
+    """Pixel-level filtering of 3D labels raises NotImplementedError."""
+    data = np.zeros((5, 5, 5), dtype=np.int32)
+    data[1:3, 1:3, 1:3] = 1
+    data[3:5, 3:5, 3:5] = 2
+    labels_3d = Labels3DModel.parse(data, dims=["z", "y", "x"])
+
+    obs_df = pd.DataFrame({"region": pd.Categorical(["labels_3d"]), "instance_id": [1]}, index=["0"])
+    table = TableModel.parse(
+        AnnData(shape=(1, 0), obs=obs_df), region="labels_3d", region_key="region", instance_key="instance_id"
+    )
+    sdata = SpatialData(labels={"labels_3d": labels_3d}, tables={"table": table})
+
+    with pytest.raises(NotImplementedError, match="3D labels"):
+        if subset_func_name == "match_sdata_to_table":
+            match_sdata_to_table(sdata, "table", filter_label_pixels=True)
+        else:
+            filter_by_table_query(sdata, "table", obs_expr=an.col("instance_id") == 1, filter_label_pixels=True)