From a28c7c905e40bd748edd6b6f772988c72f8154dd Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 30 Jun 2025 15:17:25 +0200 Subject: [PATCH 01/18] init --- src/spatialdata/_core/query/masking.py | 113 +++++++++++++++++++++++++ tests/core/query/test_masking.py | 39 +++++++++ 2 files changed, 152 insertions(+) create mode 100644 src/spatialdata/_core/query/masking.py create mode 100644 tests/core/query/test_masking.py diff --git a/src/spatialdata/_core/query/masking.py b/src/spatialdata/_core/query/masking.py new file mode 100644 index 00000000..451e7d4e --- /dev/null +++ b/src/spatialdata/_core/query/masking.py @@ -0,0 +1,113 @@ +import numpy as np +import xarray as xr +from functools import partial + +from spatialdata.models import Labels2DModel, ShapesModel +from spatialdata.models.models import DataTree + + + +def _mask_block(block: xr.DataArray, ids_to_remove: list[int]) -> xr.DataArray: + # Use apply_ufunc for efficient processing + # Create a copy to avoid modifying read-only array + result = block.copy() + result[np.isin(result, ids_to_remove)] = 0 + return result + + +def _set_instance_ids_in_labels_to_zero(image: xr.DataArray, ids_to_remove: list[int]) -> xr.DataArray: + processed = xr.apply_ufunc( + partial(_mask_block, ids_to_remove=ids_to_remove), + image, + input_core_dims=[["y", "x"]], + output_core_dims=[["y", "x"]], + vectorize=True, + dask="parallelized", + output_dtypes=[image.dtype], + dataset_fill_value=0, + dask_gufunc_kwargs={"allow_rechunk": True}, + ) + + # Force computation to ensure the changes are materialized + computed_result = processed.compute() + + # Create a new DataArray to ensure persistence + result = xr.DataArray( + data=computed_result.data, + coords=image.coords, + dims=image.dims, + attrs=image.attrs.copy(), # Preserve all attributes + ) + + return result + + +def _get_scale_factors(labels_element: Labels2DModel) -> list[tuple[float, float]]: + scales = list(labels_element.keys()) + + # Calculate relative scale factors between consecutive scales + scale_factors = [] + for i in range(len(scales) - 1): + y_size_current = labels_element[scales[i]].image.shape[0] + x_size_current = labels_element[scales[i]].image.shape[1] + y_size_next = labels_element[scales[i + 1]].image.shape[0] + x_size_next = labels_element[scales[i + 1]].image.shape[1] + y_factor = y_size_current / y_size_next + x_factor = x_size_current / x_size_next + + scale_factors.append((y_factor, x_factor)) + + return scale_factors + + +def filter_shapesmodel_by_instance_ids(element: ShapesModel, ids_to_remove: list[str]) -> ShapesModel: + """ + Filter a ShapesModel by instance ids. + + Parameters + ---------- + element + The ShapesModel to filter. + ids_to_remove + The instance ids to remove. + + Returns + ------- + The filtered ShapesModel. + """ + return element[~element.index.isin(ids_to_remove)] + + +def filter_labels2dmodel_by_instance_ids(element: Labels2DModel, ids_to_remove: list[int]) -> Labels2DModel: + """ + Filter a Labels2DModel by instance ids. + + This function works for both DataArray and DataTree and sets the + instance ids to zero. + + Parameters + ---------- + element + The Labels2DModel to filter. + ids_to_remove + The instance ids to remove. + + Returns + ------- + The filtered Labels2DModel. + """ + if isinstance(element, xr.DataArray): + return Labels2DModel.parse(_set_instance_ids_in_labels_to_zero(element, ids_to_remove)) + + if isinstance(element, DataTree): + # we extract the info to just reconstruct + # the DataTree after filtering the max scale + max_scale = list(element.keys())[0] + scale_factors = _get_scale_factors(element) + scale_factors = [int(sf[0]) for sf in scale_factors] + + return Labels2DModel.parse( + data=_set_instance_ids_in_labels_to_zero(element[max_scale].image, ids_to_remove), + scale_factors=scale_factors, + ) + raise ValueError(f"Unknown element type: {type(element)}") diff --git a/tests/core/query/test_masking.py b/tests/core/query/test_masking.py new file mode 100644 index 00000000..89f4db52 --- /dev/null +++ b/tests/core/query/test_masking.py @@ -0,0 +1,39 @@ +import numpy as np +import anndata as ad + +from spatialdata._core.query.masking import filter_labels2dmodel_by_instance_ids, filter_shapesmodel_by_instance_ids +from spatialdata.datasets import blobs_annotating_element + + +def test_filter_labels2dmodel_by_instance_ids(): + sdata = blobs_annotating_element("blobs_labels") + labels_element = sdata["blobs_labels"] + all_instance_ids = sdata.tables["table"].obs["instance_id"].unique() + filtered_labels_element = filter_labels2dmodel_by_instance_ids(labels_element, [2, 3]) + + # because 0 is the background, we expect the filtered ids to be the instance ids that are not 0 + filtered_ids = set(np.unique(filtered_labels_element.data.compute())) - {0,} + preserved_ids = np.unique(labels_element.data.compute()) + assert filtered_ids == (set(all_instance_ids) - {2, 3}) + # check if there is modification of the original labels + assert set(preserved_ids) == set(all_instance_ids) | {0} + + sdata.tables["table"].uns["spatialdata_attrs"]["region"] = "blobs_multiscale_labels" + sdata.tables["table"].obs.region = "blobs_multiscale_labels" + labels_element = sdata["blobs_multiscale_labels"] + filtered_labels_element = filter_labels2dmodel_by_instance_ids(labels_element, [2, 3]) + + for scale in labels_element: + filtered_ids = set(np.unique(filtered_labels_element[scale].image.compute())) - {0,} + preserved_ids = np.unique(labels_element[scale].image.compute()) + assert filtered_ids == (set(all_instance_ids) - {2, 3}) + # check if there is modification of the original labels + assert set(preserved_ids) == set(all_instance_ids) | {0} + + +def test_filter_shapesmodel_by_instance_ids(): + sdata = blobs_annotating_element("blobs_circles") + shapes_element = sdata["blobs_circles"] + filtered_shapes_element = filter_shapesmodel_by_instance_ids(shapes_element, [2, 3]) + + assert set(filtered_shapes_element.index.tolist()) == {0, 1, 4} From 225d593146efd81be25e505231f6450c47bd254c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 30 Jun 2025 13:22:41 +0000 Subject: [PATCH 02/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spatialdata/_core/query/masking.py | 4 ++-- tests/core/query/test_masking.py | 9 ++++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/spatialdata/_core/query/masking.py b/src/spatialdata/_core/query/masking.py index 451e7d4e..32510b03 100644 --- a/src/spatialdata/_core/query/masking.py +++ b/src/spatialdata/_core/query/masking.py @@ -1,12 +1,12 @@ +from functools import partial + import numpy as np import xarray as xr -from functools import partial from spatialdata.models import Labels2DModel, ShapesModel from spatialdata.models.models import DataTree - def _mask_block(block: xr.DataArray, ids_to_remove: list[int]) -> xr.DataArray: # Use apply_ufunc for efficient processing # Create a copy to avoid modifying read-only array diff --git a/tests/core/query/test_masking.py b/tests/core/query/test_masking.py index 89f4db52..370c153c 100644 --- a/tests/core/query/test_masking.py +++ b/tests/core/query/test_masking.py @@ -1,5 +1,4 @@ import numpy as np -import anndata as ad from spatialdata._core.query.masking import filter_labels2dmodel_by_instance_ids, filter_shapesmodel_by_instance_ids from spatialdata.datasets import blobs_annotating_element @@ -12,7 +11,9 @@ def test_filter_labels2dmodel_by_instance_ids(): filtered_labels_element = filter_labels2dmodel_by_instance_ids(labels_element, [2, 3]) # because 0 is the background, we expect the filtered ids to be the instance ids that are not 0 - filtered_ids = set(np.unique(filtered_labels_element.data.compute())) - {0,} + filtered_ids = set(np.unique(filtered_labels_element.data.compute())) - { + 0, + } preserved_ids = np.unique(labels_element.data.compute()) assert filtered_ids == (set(all_instance_ids) - {2, 3}) # check if there is modification of the original labels @@ -24,7 +25,9 @@ def test_filter_labels2dmodel_by_instance_ids(): filtered_labels_element = filter_labels2dmodel_by_instance_ids(labels_element, [2, 3]) for scale in labels_element: - filtered_ids = set(np.unique(filtered_labels_element[scale].image.compute())) - {0,} + filtered_ids = set(np.unique(filtered_labels_element[scale].image.compute())) - { + 0, + } preserved_ids = np.unique(labels_element[scale].image.compute()) assert filtered_ids == (set(all_instance_ids) - {2, 3}) # check if there is modification of the original labels From e549b4b4725f9ab73eaead533e27ca320491b9c4 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 30 Jun 2025 15:45:47 +0200 Subject: [PATCH 03/18] fix mypy linterrors --- src/spatialdata/_core/query/masking.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/spatialdata/_core/query/masking.py b/src/spatialdata/_core/query/masking.py index 32510b03..0605ad2e 100644 --- a/src/spatialdata/_core/query/masking.py +++ b/src/spatialdata/_core/query/masking.py @@ -2,9 +2,11 @@ import numpy as np import xarray as xr +from geopandas import GeoDataFrame +from xarray.core.dataarray import DataArray +from xarray.core.datatree import DataTree from spatialdata.models import Labels2DModel, ShapesModel -from spatialdata.models.models import DataTree def _mask_block(block: xr.DataArray, ids_to_remove: list[int]) -> xr.DataArray: @@ -32,15 +34,13 @@ def _set_instance_ids_in_labels_to_zero(image: xr.DataArray, ids_to_remove: list computed_result = processed.compute() # Create a new DataArray to ensure persistence - result = xr.DataArray( + return xr.DataArray( data=computed_result.data, coords=image.coords, dims=image.dims, attrs=image.attrs.copy(), # Preserve all attributes ) - return result - def _get_scale_factors(labels_element: Labels2DModel) -> list[tuple[float, float]]: scales = list(labels_element.keys()) @@ -60,7 +60,7 @@ def _get_scale_factors(labels_element: Labels2DModel) -> list[tuple[float, float return scale_factors -def filter_shapesmodel_by_instance_ids(element: ShapesModel, ids_to_remove: list[str]) -> ShapesModel: +def filter_shapesmodel_by_instance_ids(element: ShapesModel, ids_to_remove: list[str]) -> GeoDataFrame: """ Filter a ShapesModel by instance ids. @@ -75,10 +75,11 @@ def filter_shapesmodel_by_instance_ids(element: ShapesModel, ids_to_remove: list ------- The filtered ShapesModel. """ - return element[~element.index.isin(ids_to_remove)] + element2: GeoDataFrame = element[~element.index.isin(ids_to_remove)] # type: ignore[index, attr-defined] + return ShapesModel.parse(element2) -def filter_labels2dmodel_by_instance_ids(element: Labels2DModel, ids_to_remove: list[int]) -> Labels2DModel: +def filter_labels2dmodel_by_instance_ids(element: Labels2DModel, ids_to_remove: list[int]) -> DataArray | DataTree: """ Filter a Labels2DModel by instance ids. @@ -103,8 +104,8 @@ def filter_labels2dmodel_by_instance_ids(element: Labels2DModel, ids_to_remove: # we extract the info to just reconstruct # the DataTree after filtering the max scale max_scale = list(element.keys())[0] - scale_factors = _get_scale_factors(element) - scale_factors = [int(sf[0]) for sf in scale_factors] + scale_factors_temp = _get_scale_factors(element) + scale_factors = [int(sf[0]) for sf in scale_factors_temp] return Labels2DModel.parse( data=_set_instance_ids_in_labels_to_zero(element[max_scale].image, ids_to_remove), From 2aad72b23c3ccbd4bcc852b28bdf866032d8d1d1 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Thu, 3 Jul 2025 16:33:21 +0200 Subject: [PATCH 04/18] update the location and the design --- src/spatialdata/__init__.py | 2 + src/spatialdata/_core/query/masking.py | 114 ---------------- .../_core/query/relational_query.py | 124 ++++++++++++++++++ tests/core/query/test_masking.py | 42 ------ ...tional_query_subset_sdata_by_table_mask.py | 67 ++++++++++ 5 files changed, 193 insertions(+), 156 deletions(-) delete mode 100644 src/spatialdata/_core/query/masking.py delete mode 100644 tests/core/query/test_masking.py create mode 100644 tests/core/query/test_relational_query_subset_sdata_by_table_mask.py diff --git a/src/spatialdata/__init__.py b/src/spatialdata/__init__.py index 0b68391a..724975bf 100644 --- a/src/spatialdata/__init__.py +++ b/src/spatialdata/__init__.py @@ -55,6 +55,7 @@ "deepcopy", "sanitize_table", "sanitize_name", + "subset_sdata_by_table_mask", ] from spatialdata import dataloader, datasets, models, transformations @@ -78,6 +79,7 @@ match_element_to_table, match_sdata_to_table, match_table_to_element, + subset_sdata_by_table_mask, ) from spatialdata._core.query.spatial_query import bounding_box_query, polygon_query from spatialdata._core.spatialdata import SpatialData diff --git a/src/spatialdata/_core/query/masking.py b/src/spatialdata/_core/query/masking.py deleted file mode 100644 index 0605ad2e..00000000 --- a/src/spatialdata/_core/query/masking.py +++ /dev/null @@ -1,114 +0,0 @@ -from functools import partial - -import numpy as np -import xarray as xr -from geopandas import GeoDataFrame -from xarray.core.dataarray import DataArray -from xarray.core.datatree import DataTree - -from spatialdata.models import Labels2DModel, ShapesModel - - -def _mask_block(block: xr.DataArray, ids_to_remove: list[int]) -> xr.DataArray: - # Use apply_ufunc for efficient processing - # Create a copy to avoid modifying read-only array - result = block.copy() - result[np.isin(result, ids_to_remove)] = 0 - return result - - -def _set_instance_ids_in_labels_to_zero(image: xr.DataArray, ids_to_remove: list[int]) -> xr.DataArray: - processed = xr.apply_ufunc( - partial(_mask_block, ids_to_remove=ids_to_remove), - image, - input_core_dims=[["y", "x"]], - output_core_dims=[["y", "x"]], - vectorize=True, - dask="parallelized", - output_dtypes=[image.dtype], - dataset_fill_value=0, - dask_gufunc_kwargs={"allow_rechunk": True}, - ) - - # Force computation to ensure the changes are materialized - computed_result = processed.compute() - - # Create a new DataArray to ensure persistence - return xr.DataArray( - data=computed_result.data, - coords=image.coords, - dims=image.dims, - attrs=image.attrs.copy(), # Preserve all attributes - ) - - -def _get_scale_factors(labels_element: Labels2DModel) -> list[tuple[float, float]]: - scales = list(labels_element.keys()) - - # Calculate relative scale factors between consecutive scales - scale_factors = [] - for i in range(len(scales) - 1): - y_size_current = labels_element[scales[i]].image.shape[0] - x_size_current = labels_element[scales[i]].image.shape[1] - y_size_next = labels_element[scales[i + 1]].image.shape[0] - x_size_next = labels_element[scales[i + 1]].image.shape[1] - y_factor = y_size_current / y_size_next - x_factor = x_size_current / x_size_next - - scale_factors.append((y_factor, x_factor)) - - return scale_factors - - -def filter_shapesmodel_by_instance_ids(element: ShapesModel, ids_to_remove: list[str]) -> GeoDataFrame: - """ - Filter a ShapesModel by instance ids. - - Parameters - ---------- - element - The ShapesModel to filter. - ids_to_remove - The instance ids to remove. - - Returns - ------- - The filtered ShapesModel. - """ - element2: GeoDataFrame = element[~element.index.isin(ids_to_remove)] # type: ignore[index, attr-defined] - return ShapesModel.parse(element2) - - -def filter_labels2dmodel_by_instance_ids(element: Labels2DModel, ids_to_remove: list[int]) -> DataArray | DataTree: - """ - Filter a Labels2DModel by instance ids. - - This function works for both DataArray and DataTree and sets the - instance ids to zero. - - Parameters - ---------- - element - The Labels2DModel to filter. - ids_to_remove - The instance ids to remove. - - Returns - ------- - The filtered Labels2DModel. - """ - if isinstance(element, xr.DataArray): - return Labels2DModel.parse(_set_instance_ids_in_labels_to_zero(element, ids_to_remove)) - - if isinstance(element, DataTree): - # we extract the info to just reconstruct - # the DataTree after filtering the max scale - max_scale = list(element.keys())[0] - scale_factors_temp = _get_scale_factors(element) - scale_factors = [int(sf[0]) for sf in scale_factors_temp] - - return Labels2DModel.parse( - data=_set_instance_ids_in_labels_to_zero(element[max_scale].image, ids_to_remove), - scale_factors=scale_factors, - ) - raise ValueError(f"Unknown element type: {type(element)}") diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index b84d43c1..f82572d3 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -11,9 +11,11 @@ import dask.array as da import numpy as np import pandas as pd +import xarray as xr from anndata import AnnData from dask.dataframe import DataFrame as DaskDataFrame from geopandas import GeoDataFrame +from numpy.typing import NDArray from xarray import DataArray, DataTree from spatialdata._core.spatialdata import SpatialData @@ -1019,3 +1021,125 @@ def get_values( return df raise ValueError(f"Unknown origin {origin}") + + +def _mask_block(block: xr.DataArray, ids_to_remove: list[int]) -> xr.DataArray: + # Use apply_ufunc for efficient processing + # Create a copy to avoid modifying read-only array + result = block.copy() + result[np.isin(result, ids_to_remove)] = 0 + return result + + +def _set_instance_ids_in_labels_to_zero(image: xr.DataArray, ids_to_remove: list[int]) -> xr.DataArray: + processed = xr.apply_ufunc( + partial(_mask_block, ids_to_remove=ids_to_remove), + image, + input_core_dims=[["y", "x"]], + output_core_dims=[["y", "x"]], + vectorize=True, + dask="parallelized", + output_dtypes=[image.dtype], + dataset_fill_value=0, + dask_gufunc_kwargs={"allow_rechunk": True}, + ) + + # Force computation to ensure the changes are materialized + computed_result = processed.compute() + + # Create a new DataArray to ensure persistence + return xr.DataArray( + data=computed_result.data, + coords=image.coords, + dims=image.dims, + attrs=image.attrs.copy(), # Preserve all attributes + ) + + +def _get_scale_factors(labels_element: DataTree) -> list[tuple[float, float]]: + scales = list(labels_element.keys()) + + # Calculate relative scale factors between consecutive scales + scale_factors = [] + for i in range(len(scales) - 1): + y_size_current = labels_element[scales[i]].image.shape[0] + x_size_current = labels_element[scales[i]].image.shape[1] + y_size_next = labels_element[scales[i + 1]].image.shape[0] + x_size_next = labels_element[scales[i + 1]].image.shape[1] + y_factor = y_size_current / y_size_next + x_factor = x_size_current / x_size_next + + scale_factors.append((y_factor, x_factor)) + + return scale_factors + + +@singledispatch +def _filter_by_instance_ids(element: Any, ids_to_remove: list[str], instance_key: str) -> Any: + raise NotImplementedError(f"Filtering by instance ids is not implemented for {element}") + + +@_filter_by_instance_ids.register(GeoDataFrame) +def _(element: GeoDataFrame, ids_to_remove: list[str], instance_key: str) -> GeoDataFrame: + return element[~element.index.isin(ids_to_remove)] + + +@_filter_by_instance_ids.register(DaskDataFrame) +def _(element: DaskDataFrame, ids_to_remove: list[str], instance_key: str) -> DaskDataFrame: + return element[~element[instance_key].isin(ids_to_remove)] + + +@_filter_by_instance_ids.register(DataArray) +def _(element: DataArray, ids_to_remove: list[int], instance_key: str) -> DataArray: + del instance_key + return Labels2DModel.parse(_set_instance_ids_in_labels_to_zero(element, ids_to_remove)) + + +@_filter_by_instance_ids.register(DataTree) +def _(element: DataArray | DataTree, ids_to_remove: list[int], instance_key: str) -> xr.DataArray | xr.DataTree: + # we extract the info to just reconstruct + # the DataTree after filtering the max scale + max_scale = list(element.keys())[0] + scale_factors_temp = _get_scale_factors(element) + scale_factors = [int(sf[0]) for sf in scale_factors_temp] + + return Labels2DModel.parse( + data=_set_instance_ids_in_labels_to_zero(element[max_scale].image, ids_to_remove), + scale_factors=scale_factors, + ) + + +def subset_sdata_by_table_mask(sdata: SpatialData, table_name: str, mask: NDArray[np.bool_]) -> SpatialData: + """ + Subset a SpatialData object by a table and a mask. + + Parameters + ---------- + sdata + The SpatialData object to subset. + table_name + The name of the table to apply the mask to. + mask + Boolean mask to apply to the table. + + Returns + ------- + The subsetted SpatialData object. + """ + table = sdata.tables.get(table_name) + if table is None: + raise ValueError(f"Table {table_name} not found in SpatialData object.") + + subset_table = table[mask] + _, _, instance_key = get_table_keys(subset_table) + annotated_regions = SpatialData.get_annotated_regions(table) + removed_instance_ids = list(np.unique(table.obs[instance_key][~mask])) + + filtered_elements = {} + for reg in annotated_regions: + elem = sdata.get(reg) + model = get_model(elem) + if model in [Labels2DModel, PointsModel, ShapesModel]: + filtered_elements[reg] = _filter_by_instance_ids(elem, removed_instance_ids, instance_key) + + return SpatialData.init_from_elements(filtered_elements | {table_name: subset_table}) diff --git a/tests/core/query/test_masking.py b/tests/core/query/test_masking.py deleted file mode 100644 index 370c153c..00000000 --- a/tests/core/query/test_masking.py +++ /dev/null @@ -1,42 +0,0 @@ -import numpy as np - -from spatialdata._core.query.masking import filter_labels2dmodel_by_instance_ids, filter_shapesmodel_by_instance_ids -from spatialdata.datasets import blobs_annotating_element - - -def test_filter_labels2dmodel_by_instance_ids(): - sdata = blobs_annotating_element("blobs_labels") - labels_element = sdata["blobs_labels"] - all_instance_ids = sdata.tables["table"].obs["instance_id"].unique() - filtered_labels_element = filter_labels2dmodel_by_instance_ids(labels_element, [2, 3]) - - # because 0 is the background, we expect the filtered ids to be the instance ids that are not 0 - filtered_ids = set(np.unique(filtered_labels_element.data.compute())) - { - 0, - } - preserved_ids = np.unique(labels_element.data.compute()) - assert filtered_ids == (set(all_instance_ids) - {2, 3}) - # check if there is modification of the original labels - assert set(preserved_ids) == set(all_instance_ids) | {0} - - sdata.tables["table"].uns["spatialdata_attrs"]["region"] = "blobs_multiscale_labels" - sdata.tables["table"].obs.region = "blobs_multiscale_labels" - labels_element = sdata["blobs_multiscale_labels"] - filtered_labels_element = filter_labels2dmodel_by_instance_ids(labels_element, [2, 3]) - - for scale in labels_element: - filtered_ids = set(np.unique(filtered_labels_element[scale].image.compute())) - { - 0, - } - preserved_ids = np.unique(labels_element[scale].image.compute()) - assert filtered_ids == (set(all_instance_ids) - {2, 3}) - # check if there is modification of the original labels - assert set(preserved_ids) == set(all_instance_ids) | {0} - - -def test_filter_shapesmodel_by_instance_ids(): - sdata = blobs_annotating_element("blobs_circles") - shapes_element = sdata["blobs_circles"] - filtered_shapes_element = filter_shapesmodel_by_instance_ids(shapes_element, [2, 3]) - - assert set(filtered_shapes_element.index.tolist()) == {0, 1, 4} diff --git a/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py new file mode 100644 index 00000000..bc7e0967 --- /dev/null +++ b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py @@ -0,0 +1,67 @@ +import numpy as np + +from spatialdata._core.query.relational_query import _filter_by_instance_ids +from spatialdata.datasets import blobs_annotating_element +from spatialdata import concatenate, subset_sdata_by_table_mask + + +def test_filter_labels2dmodel_by_instance_ids(): + sdata = blobs_annotating_element("blobs_labels") + labels_element = sdata["blobs_labels"] + all_instance_ids = sdata.tables["table"].obs["instance_id"].unique() + filtered_labels_element = _filter_by_instance_ids(labels_element, [2, 3], "instance_id") + + # because 0 is the background, we expect the filtered ids to be the instance ids that are not 0 + filtered_ids = set(np.unique(filtered_labels_element.data.compute())) - { + 0, + } + preserved_ids = np.unique(labels_element.data.compute()) + assert filtered_ids == (set(all_instance_ids) - {2, 3}) + # check if there is modification of the original labels + assert set(preserved_ids) == set(all_instance_ids) | {0} + + sdata.tables["table"].uns["spatialdata_attrs"]["region"] = "blobs_multiscale_labels" + sdata.tables["table"].obs.region = "blobs_multiscale_labels" + labels_element = sdata["blobs_multiscale_labels"] + filtered_labels_element = _filter_by_instance_ids(labels_element, [2, 3], "instance_id") + + for scale in labels_element: + filtered_ids = set(np.unique(filtered_labels_element[scale].image.compute())) - { + 0, + } + preserved_ids = np.unique(labels_element[scale].image.compute()) + assert filtered_ids == (set(all_instance_ids) - {2, 3}) + # check if there is modification of the original labels + assert set(preserved_ids) == set(all_instance_ids) | {0} + + +def test_subset_sdata_by_table_mask(): + sdata = concatenate( + { + "labels": blobs_annotating_element("blobs_labels"), + "shapes": blobs_annotating_element("blobs_circles"), + "points": blobs_annotating_element("blobs_points"), + "multiscale_labels": blobs_annotating_element("blobs_multiscale_labels"), + }, + concatenate_tables=True, + ) + third_elems = sdata.tables["table"].obs["instance_id"] == 3 + subset_sdata = subset_sdata_by_table_mask(sdata, "table", third_elems) + + assert set(subset_sdata.labels.keys()) == {"blobs_labels-labels", "blobs_multiscale_labels-multiscale_labels"} + assert set(subset_sdata.points.keys()) == {"blobs_points-points"} + assert set(subset_sdata.shapes.keys()) == {"blobs_circles-shapes"} + + labels_remaining_ids = set(np.unique(subset_sdata.labels["blobs_labels-labels"].data.compute())) - {0} + assert labels_remaining_ids == {3} + + for scale in subset_sdata.labels["blobs_multiscale_labels-multiscale_labels"]: + ms_labels_remaining_ids = set(np.unique(subset_sdata.labels["blobs_multiscale_labels-multiscale_labels"][scale].image.compute())) - {0} + assert ms_labels_remaining_ids == {3} + + points_remaining_ids = set(np.unique(subset_sdata.points["blobs_points-points"]['instance_id'].compute())) - {0} + assert points_remaining_ids == {3} + + shapes_remaining_ids = set(np.unique(subset_sdata.shapes["blobs_circles-shapes"].index)) - {0} + assert shapes_remaining_ids == {3} + From ef7405790cca44e7b95f479661290ce13d9b62ff Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 3 Jul 2025 14:33:39 +0000 Subject: [PATCH 05/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../test_relational_query_subset_sdata_by_table_mask.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py index bc7e0967..810e2f73 100644 --- a/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py +++ b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py @@ -1,8 +1,8 @@ import numpy as np +from spatialdata import concatenate, subset_sdata_by_table_mask from spatialdata._core.query.relational_query import _filter_by_instance_ids from spatialdata.datasets import blobs_annotating_element -from spatialdata import concatenate, subset_sdata_by_table_mask def test_filter_labels2dmodel_by_instance_ids(): @@ -56,12 +56,13 @@ def test_subset_sdata_by_table_mask(): assert labels_remaining_ids == {3} for scale in subset_sdata.labels["blobs_multiscale_labels-multiscale_labels"]: - ms_labels_remaining_ids = set(np.unique(subset_sdata.labels["blobs_multiscale_labels-multiscale_labels"][scale].image.compute())) - {0} + ms_labels_remaining_ids = set( + np.unique(subset_sdata.labels["blobs_multiscale_labels-multiscale_labels"][scale].image.compute()) + ) - {0} assert ms_labels_remaining_ids == {3} - points_remaining_ids = set(np.unique(subset_sdata.points["blobs_points-points"]['instance_id'].compute())) - {0} + points_remaining_ids = set(np.unique(subset_sdata.points["blobs_points-points"]["instance_id"].compute())) - {0} assert points_remaining_ids == {3} shapes_remaining_ids = set(np.unique(subset_sdata.shapes["blobs_circles-shapes"].index)) - {0} assert shapes_remaining_ids == {3} - From d6e22cbe9ce1eff0f752b744c5be814b4b724b57 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Thu, 3 Jul 2025 16:41:30 +0200 Subject: [PATCH 06/18] update docs --- src/spatialdata/_core/query/relational_query.py | 13 ++++++++++--- ...t_relational_query_subset_sdata_by_table_mask.py | 9 +++++---- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index f82572d3..8914d1da 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -1110,8 +1110,15 @@ def _(element: DataArray | DataTree, ids_to_remove: list[int], instance_key: str def subset_sdata_by_table_mask(sdata: SpatialData, table_name: str, mask: NDArray[np.bool_]) -> SpatialData: - """ - Subset a SpatialData object by a table and a mask. + """Subset the annotated elements of a SpatialData object by a table and a mask. + + The mask is applied to the table and the annotated elements are subsetted + by the instance ids in the table. + This function returns a new SpatialData object with the subsetted elements. + Elements that are not annotated by the table are not included in the returned SpatialData object. + The element models that are + supported are :class:`spatialdata.models.Labels2DModel`, + :class:`spatialdata.models.PointsModel`, and :class:`spatialdata.models.ShapesModel`. Parameters ---------- @@ -1120,7 +1127,7 @@ def subset_sdata_by_table_mask(sdata: SpatialData, table_name: str, mask: NDArra table_name The name of the table to apply the mask to. mask - Boolean mask to apply to the table. + Boolean mask to apply to the table which is the same length as the number of rows in the table. Returns ------- diff --git a/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py index bc7e0967..810e2f73 100644 --- a/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py +++ b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py @@ -1,8 +1,8 @@ import numpy as np +from spatialdata import concatenate, subset_sdata_by_table_mask from spatialdata._core.query.relational_query import _filter_by_instance_ids from spatialdata.datasets import blobs_annotating_element -from spatialdata import concatenate, subset_sdata_by_table_mask def test_filter_labels2dmodel_by_instance_ids(): @@ -56,12 +56,13 @@ def test_subset_sdata_by_table_mask(): assert labels_remaining_ids == {3} for scale in subset_sdata.labels["blobs_multiscale_labels-multiscale_labels"]: - ms_labels_remaining_ids = set(np.unique(subset_sdata.labels["blobs_multiscale_labels-multiscale_labels"][scale].image.compute())) - {0} + ms_labels_remaining_ids = set( + np.unique(subset_sdata.labels["blobs_multiscale_labels-multiscale_labels"][scale].image.compute()) + ) - {0} assert ms_labels_remaining_ids == {3} - points_remaining_ids = set(np.unique(subset_sdata.points["blobs_points-points"]['instance_id'].compute())) - {0} + points_remaining_ids = set(np.unique(subset_sdata.points["blobs_points-points"]["instance_id"].compute())) - {0} assert points_remaining_ids == {3} shapes_remaining_ids = set(np.unique(subset_sdata.shapes["blobs_circles-shapes"].index)) - {0} assert shapes_remaining_ids == {3} - From 80d95a2710b7bdd5e6c020048864e64802810b0b Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Thu, 3 Jul 2025 17:06:35 +0200 Subject: [PATCH 07/18] make coverage 100/100 because why not --- ...st_relational_query_subset_sdata_by_table_mask.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py index 810e2f73..46d76eae 100644 --- a/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py +++ b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from spatialdata import concatenate, subset_sdata_by_table_mask from spatialdata._core.query.relational_query import _filter_by_instance_ids @@ -66,3 +67,14 @@ def test_subset_sdata_by_table_mask(): shapes_remaining_ids = set(np.unique(subset_sdata.shapes["blobs_circles-shapes"].index)) - {0} assert shapes_remaining_ids == {3} + + +def test_subset_sdata_by_table_mask_with_no_annotated_elements(): + with pytest.raises(ValueError, match="Table table_not_found not found in SpatialData object."): + sdata = blobs_annotating_element("blobs_labels") + _ = subset_sdata_by_table_mask(sdata, "table_not_found", sdata.tables["table"].obs["instance_id"] == 3) + + +def test_filter_by_instance_ids_fails_for_unsupported_element_models(): + with pytest.raises(NotImplementedError, match="Filtering by instance ids is not implemented for"): + _filter_by_instance_ids([1, 1, 1, 2], [1], "instance_id") \ No newline at end of file From 44386052a31ec1d612db8b5ada3b3188f155947e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 3 Jul 2025 15:06:53 +0000 Subject: [PATCH 08/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../query/test_relational_query_subset_sdata_by_table_mask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py index 46d76eae..d634f645 100644 --- a/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py +++ b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py @@ -77,4 +77,4 @@ def test_subset_sdata_by_table_mask_with_no_annotated_elements(): def test_filter_by_instance_ids_fails_for_unsupported_element_models(): with pytest.raises(NotImplementedError, match="Filtering by instance ids is not implemented for"): - _filter_by_instance_ids([1, 1, 1, 2], [1], "instance_id") \ No newline at end of file + _filter_by_instance_ids([1, 1, 1, 2], [1], "instance_id") From 4c927ee86ca7e60e86262e9050d57ac40575401f Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Thu, 10 Jul 2025 17:48:36 +0200 Subject: [PATCH 09/18] fixed type annotation --- src/spatialdata/_core/query/relational_query.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index 8914d1da..50332955 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -1096,9 +1096,10 @@ def _(element: DataArray, ids_to_remove: list[int], instance_key: str) -> DataAr @_filter_by_instance_ids.register(DataTree) -def _(element: DataArray | DataTree, ids_to_remove: list[int], instance_key: str) -> xr.DataArray | xr.DataTree: +def _(element: DataTree, ids_to_remove: list[int], instance_key: str) -> DataTree: # we extract the info to just reconstruct # the DataTree after filtering the max scale + del instance_key max_scale = list(element.keys())[0] scale_factors_temp = _get_scale_factors(element) scale_factors = [int(sf[0]) for sf in scale_factors_temp] From e9e0da2c62bad3e1dd59688765427aa03595bca9 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Thu, 10 Jul 2025 17:54:56 +0200 Subject: [PATCH 10/18] dont compute eagerly use. delete other instance key for consistency --- src/spatialdata/_core/query/relational_query.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index 50332955..ba6c1eb2 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -1044,12 +1044,9 @@ def _set_instance_ids_in_labels_to_zero(image: xr.DataArray, ids_to_remove: list dask_gufunc_kwargs={"allow_rechunk": True}, ) - # Force computation to ensure the changes are materialized - computed_result = processed.compute() - # Create a new DataArray to ensure persistence return xr.DataArray( - data=computed_result.data, + data=processed.data, coords=image.coords, dims=image.dims, attrs=image.attrs.copy(), # Preserve all attributes @@ -1081,6 +1078,7 @@ def _filter_by_instance_ids(element: Any, ids_to_remove: list[str], instance_key @_filter_by_instance_ids.register(GeoDataFrame) def _(element: GeoDataFrame, ids_to_remove: list[str], instance_key: str) -> GeoDataFrame: + del instance_key return element[~element.index.isin(ids_to_remove)] From 7534c91b4181f6c5e0feb8f0cf3c1dfa3ed5bd1f Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 14 Jul 2025 15:40:11 +0200 Subject: [PATCH 11/18] update the tests and make sure we use match_element_to_table --- src/spatialdata/_core/query/relational_query.py | 17 +++++------------ ...lational_query_subset_sdata_by_table_mask.py | 10 +++++----- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index ba6c1eb2..e570caec 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -1076,17 +1076,6 @@ def _filter_by_instance_ids(element: Any, ids_to_remove: list[str], instance_key raise NotImplementedError(f"Filtering by instance ids is not implemented for {element}") -@_filter_by_instance_ids.register(GeoDataFrame) -def _(element: GeoDataFrame, ids_to_remove: list[str], instance_key: str) -> GeoDataFrame: - del instance_key - return element[~element.index.isin(ids_to_remove)] - - -@_filter_by_instance_ids.register(DaskDataFrame) -def _(element: DaskDataFrame, ids_to_remove: list[str], instance_key: str) -> DaskDataFrame: - return element[~element[instance_key].isin(ids_to_remove)] - - @_filter_by_instance_ids.register(DataArray) def _(element: DataArray, ids_to_remove: list[int], instance_key: str) -> DataArray: del instance_key @@ -1137,6 +1126,7 @@ def subset_sdata_by_table_mask(sdata: SpatialData, table_name: str, mask: NDArra raise ValueError(f"Table {table_name} not found in SpatialData object.") subset_table = table[mask] + sdata.tables[table_name] = subset_table _, _, instance_key = get_table_keys(subset_table) annotated_regions = SpatialData.get_annotated_regions(table) removed_instance_ids = list(np.unique(table.obs[instance_key][~mask])) @@ -1145,7 +1135,10 @@ def subset_sdata_by_table_mask(sdata: SpatialData, table_name: str, mask: NDArra for reg in annotated_regions: elem = sdata.get(reg) model = get_model(elem) - if model in [Labels2DModel, PointsModel, ShapesModel]: + if model is Labels2DModel: filtered_elements[reg] = _filter_by_instance_ids(elem, removed_instance_ids, instance_key) + elif model in [PointsModel, ShapesModel]: + element_dict, _ = match_element_to_table(sdata, element_name=reg, table_name=table_name) + filtered_elements[reg] = element_dict[reg] return SpatialData.init_from_elements(filtered_elements | {table_name: subset_table}) diff --git a/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py index d634f645..629f58bb 100644 --- a/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py +++ b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py @@ -6,7 +6,7 @@ from spatialdata.datasets import blobs_annotating_element -def test_filter_labels2dmodel_by_instance_ids(): +def test_filter_labels2dmodel_by_instance_ids() -> None: sdata = blobs_annotating_element("blobs_labels") labels_element = sdata["blobs_labels"] all_instance_ids = sdata.tables["table"].obs["instance_id"].unique() @@ -36,7 +36,7 @@ def test_filter_labels2dmodel_by_instance_ids(): assert set(preserved_ids) == set(all_instance_ids) | {0} -def test_subset_sdata_by_table_mask(): +def test_subset_sdata_by_table_mask() -> None: sdata = concatenate( { "labels": blobs_annotating_element("blobs_labels"), @@ -62,19 +62,19 @@ def test_subset_sdata_by_table_mask(): ) - {0} assert ms_labels_remaining_ids == {3} - points_remaining_ids = set(np.unique(subset_sdata.points["blobs_points-points"]["instance_id"].compute())) - {0} + points_remaining_ids = set(np.unique(subset_sdata.points["blobs_points-points"].index)) - {0} assert points_remaining_ids == {3} shapes_remaining_ids = set(np.unique(subset_sdata.shapes["blobs_circles-shapes"].index)) - {0} assert shapes_remaining_ids == {3} -def test_subset_sdata_by_table_mask_with_no_annotated_elements(): +def test_subset_sdata_by_table_mask_with_no_annotated_elements() -> None: with pytest.raises(ValueError, match="Table table_not_found not found in SpatialData object."): sdata = blobs_annotating_element("blobs_labels") _ = subset_sdata_by_table_mask(sdata, "table_not_found", sdata.tables["table"].obs["instance_id"] == 3) -def test_filter_by_instance_ids_fails_for_unsupported_element_models(): +def test_filter_by_instance_ids_fails_for_unsupported_element_models() -> None: with pytest.raises(NotImplementedError, match="Filtering by instance ids is not implemented for"): _filter_by_instance_ids([1, 1, 1, 2], [1], "instance_id") From 908f7b4797d8555e955b5195b1d459d6702c44b9 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Mon, 11 May 2026 13:41:18 +0200 Subject: [PATCH 12/18] wip rewrite tests using existing APIs --- ...tional_query_subset_sdata_by_table_mask.py | 46 ++++++++++++++++--- 1 file changed, 40 insertions(+), 6 deletions(-) diff --git a/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py index 629f58bb..9be127a4 100644 --- a/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py +++ b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py @@ -1,16 +1,25 @@ +import warnings + import numpy as np import pytest +from xarray import DataArray -from spatialdata import concatenate, subset_sdata_by_table_mask -from spatialdata._core.query.relational_query import _filter_by_instance_ids +from spatialdata import concatenate, match_sdata_to_table, subset_sdata_by_table_mask +from spatialdata._core.query.relational_query import ( + _filter_by_instance_ids, + _set_instance_ids_in_labels_to_zero, +) from spatialdata.datasets import blobs_annotating_element +from spatialdata.models import Labels2DModel def test_filter_labels2dmodel_by_instance_ids() -> None: sdata = blobs_annotating_element("blobs_labels") labels_element = sdata["blobs_labels"] all_instance_ids = sdata.tables["table"].obs["instance_id"].unique() - filtered_labels_element = _filter_by_instance_ids(labels_element, [2, 3], "instance_id") + filtered_labels_element = Labels2DModel.parse( + _set_instance_ids_in_labels_to_zero(labels_element, [2, 3]) + ) # because 0 is the background, we expect the filtered ids to be the instance ids that are not 0 filtered_ids = set(np.unique(filtered_labels_element.data.compute())) - { @@ -24,7 +33,11 @@ def test_filter_labels2dmodel_by_instance_ids() -> None: sdata.tables["table"].uns["spatialdata_attrs"]["region"] = "blobs_multiscale_labels" sdata.tables["table"].obs.region = "blobs_multiscale_labels" labels_element = sdata["blobs_multiscale_labels"] - filtered_labels_element = _filter_by_instance_ids(labels_element, [2, 3], "instance_id") + max_scale = list(labels_element.keys())[0] + filtered_labels_element = Labels2DModel.parse( + _set_instance_ids_in_labels_to_zero(labels_element[max_scale].image, [2, 3]), + scale_factors=[2, 2], # blobs uses scale_factors=[2, 2], see datasets.py + ) for scale in labels_element: filtered_ids = set(np.unique(filtered_labels_element[scale].image.compute())) - { @@ -46,8 +59,29 @@ def test_subset_sdata_by_table_mask() -> None: }, concatenate_tables=True, ) - third_elems = sdata.tables["table"].obs["instance_id"] == 3 - subset_sdata = subset_sdata_by_table_mask(sdata, "table", third_elems) + table = sdata.tables["table"] + third_elems = table.obs["instance_id"] == 3 + ids_to_remove = list(np.unique(table.obs["instance_id"][~third_elems])) + subset_table = table[third_elems] + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) # labels not supported for right join + subset_sdata = match_sdata_to_table(sdata, "table", table=subset_table) + + for label_name in list(subset_sdata.labels.keys()): + elem = subset_sdata[label_name] + del subset_sdata[label_name] + if isinstance(elem, DataArray): + filtered = Labels2DModel.parse( + _set_instance_ids_in_labels_to_zero(elem, ids_to_remove) + ) + else: # DataTree (multiscale) + max_scale = list(elem.keys())[0] + filtered = Labels2DModel.parse( + _set_instance_ids_in_labels_to_zero(elem[max_scale].image, ids_to_remove), + scale_factors=[2, 2], + ) + subset_sdata[label_name] = filtered assert set(subset_sdata.labels.keys()) == {"blobs_labels-labels", "blobs_multiscale_labels-multiscale_labels"} assert set(subset_sdata.points.keys()) == {"blobs_points-points"} From 359d553c15ba3fbfcc53cd143f31c4501854a6a2 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Mon, 11 May 2026 15:34:54 +0200 Subject: [PATCH 13/18] test passing without using subset_sdata_by_table_mask() --- ...tional_query_subset_sdata_by_table_mask.py | 38 +++++++------------ 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py index 9be127a4..d2d3b81f 100644 --- a/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py +++ b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import warnings import numpy as np import pytest from xarray import DataArray -from spatialdata import concatenate, match_sdata_to_table, subset_sdata_by_table_mask +from spatialdata import concatenate, match_sdata_to_table from spatialdata._core.query.relational_query import ( _filter_by_instance_ids, _set_instance_ids_in_labels_to_zero, @@ -17,9 +19,7 @@ def test_filter_labels2dmodel_by_instance_ids() -> None: sdata = blobs_annotating_element("blobs_labels") labels_element = sdata["blobs_labels"] all_instance_ids = sdata.tables["table"].obs["instance_id"].unique() - filtered_labels_element = Labels2DModel.parse( - _set_instance_ids_in_labels_to_zero(labels_element, [2, 3]) - ) + filtered_labels_element = Labels2DModel.parse(_set_instance_ids_in_labels_to_zero(labels_element, [2, 3])) # because 0 is the background, we expect the filtered ids to be the instance ids that are not 0 filtered_ids = set(np.unique(filtered_labels_element.data.compute())) - { @@ -33,20 +33,18 @@ def test_filter_labels2dmodel_by_instance_ids() -> None: sdata.tables["table"].uns["spatialdata_attrs"]["region"] = "blobs_multiscale_labels" sdata.tables["table"].obs.region = "blobs_multiscale_labels" labels_element = sdata["blobs_multiscale_labels"] - max_scale = list(labels_element.keys())[0] - filtered_labels_element = Labels2DModel.parse( - _set_instance_ids_in_labels_to_zero(labels_element[max_scale].image, [2, 3]), - scale_factors=[2, 2], # blobs uses scale_factors=[2, 2], see datasets.py - ) + # Apply the filter independently at each scale: coarser scales may already be missing small + # instances (e.g. instance 5 disappears at scale2 in the original multiscale due to downsampling), + # so we compare against the IDs actually present at each scale rather than all_instance_ids. for scale in labels_element: - filtered_ids = set(np.unique(filtered_labels_element[scale].image.compute())) - { - 0, - } - preserved_ids = np.unique(labels_element[scale].image.compute()) - assert filtered_ids == (set(all_instance_ids) - {2, 3}) + scale_image = labels_element[scale].image + ids_at_scale = set(np.unique(scale_image.compute())) + filtered_image = _set_instance_ids_in_labels_to_zero(scale_image, [2, 3]) + filtered_ids = set(np.unique(filtered_image.compute())) - {0} + assert filtered_ids == (ids_at_scale - {0, 2, 3}) # check if there is modification of the original labels - assert set(preserved_ids) == set(all_instance_ids) | {0} + assert set(np.unique(scale_image.compute())) == ids_at_scale def test_subset_sdata_by_table_mask() -> None: @@ -72,9 +70,7 @@ def test_subset_sdata_by_table_mask() -> None: elem = subset_sdata[label_name] del subset_sdata[label_name] if isinstance(elem, DataArray): - filtered = Labels2DModel.parse( - _set_instance_ids_in_labels_to_zero(elem, ids_to_remove) - ) + filtered = Labels2DModel.parse(_set_instance_ids_in_labels_to_zero(elem, ids_to_remove)) else: # DataTree (multiscale) max_scale = list(elem.keys())[0] filtered = Labels2DModel.parse( @@ -103,12 +99,6 @@ def test_subset_sdata_by_table_mask() -> None: assert shapes_remaining_ids == {3} -def test_subset_sdata_by_table_mask_with_no_annotated_elements() -> None: - with pytest.raises(ValueError, match="Table table_not_found not found in SpatialData object."): - sdata = blobs_annotating_element("blobs_labels") - _ = subset_sdata_by_table_mask(sdata, "table_not_found", sdata.tables["table"].obs["instance_id"] == 3) - - def test_filter_by_instance_ids_fails_for_unsupported_element_models() -> None: with pytest.raises(NotImplementedError, match="Filtering by instance ids is not implemented for"): _filter_by_instance_ids([1, 1, 1, 2], [1], "instance_id") From 71e27efa76f285042962cf3a182a1b559b0799df Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Mon, 11 May 2026 16:50:51 +0200 Subject: [PATCH 14/18] Remove _filter_by_instance_ids and _get_scale_factors; refactor tests to use existing API - Remove _get_scale_factors (duplicated logic already in transformations/_utils.py) - Remove _filter_by_instance_ids and subset_sdata_by_table_mask (superseded by match_sdata_to_table / filter_by_table_query) - Parametrize test_subset_sdata_by_table_mask over both API functions - Replace test_filter_2d_labels_by_instance_ids with test_filter_out_instances, parametrized over both API functions and element types (2D / multiscale labels) Co-Authored-By: Claude Sonnet 4.6 --- .../_core/query/relational_query.py | 92 ------------------- ...tional_query_subset_sdata_by_table_mask.py | 92 ++++++++++--------- 2 files changed, 50 insertions(+), 134 deletions(-) diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index 52613701..f05c41d5 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -16,7 +16,6 @@ from annsel.core.typing import Predicates from dask.dataframe import DataFrame as DaskDataFrame from geopandas import GeoDataFrame -from numpy.typing import NDArray from xarray import DataArray, DataTree from spatialdata._core.spatialdata import SpatialData @@ -1131,94 +1130,3 @@ def _set_instance_ids_in_labels_to_zero(image: xr.DataArray, ids_to_remove: list dims=image.dims, attrs=image.attrs.copy(), # Preserve all attributes ) - - -def _get_scale_factors(labels_element: DataTree) -> list[tuple[float, float]]: - scales = list(labels_element.keys()) - - # Calculate relative scale factors between consecutive scales - scale_factors = [] - for i in range(len(scales) - 1): - y_size_current = labels_element[scales[i]].image.shape[0] - x_size_current = labels_element[scales[i]].image.shape[1] - y_size_next = labels_element[scales[i + 1]].image.shape[0] - x_size_next = labels_element[scales[i + 1]].image.shape[1] - y_factor = y_size_current / y_size_next - x_factor = x_size_current / x_size_next - - scale_factors.append((y_factor, x_factor)) - - return scale_factors - - -@singledispatch -def _filter_by_instance_ids(element: Any, ids_to_remove: list[str], instance_key: str) -> Any: - raise NotImplementedError(f"Filtering by instance ids is not implemented for {element}") - - -@_filter_by_instance_ids.register(DataArray) -def _(element: DataArray, ids_to_remove: list[int], instance_key: str) -> DataArray: - del instance_key - return Labels2DModel.parse(_set_instance_ids_in_labels_to_zero(element, ids_to_remove)) - - -@_filter_by_instance_ids.register(DataTree) -def _(element: DataTree, ids_to_remove: list[int], instance_key: str) -> DataTree: - # we extract the info to just reconstruct - # the DataTree after filtering the max scale - del instance_key - max_scale = list(element.keys())[0] - scale_factors_temp = _get_scale_factors(element) - scale_factors = [int(sf[0]) for sf in scale_factors_temp] - - return Labels2DModel.parse( - data=_set_instance_ids_in_labels_to_zero(element[max_scale].image, ids_to_remove), - scale_factors=scale_factors, - ) - - -def subset_sdata_by_table_mask(sdata: SpatialData, table_name: str, mask: NDArray[np.bool_]) -> SpatialData: - """Subset the annotated elements of a SpatialData object by a table and a mask. - - The mask is applied to the table and the annotated elements are subsetted - by the instance ids in the table. - This function returns a new SpatialData object with the subsetted elements. - Elements that are not annotated by the table are not included in the returned SpatialData object. - The element models that are - supported are :class:`spatialdata.models.Labels2DModel`, - :class:`spatialdata.models.PointsModel`, and :class:`spatialdata.models.ShapesModel`. - - Parameters - ---------- - sdata - The SpatialData object to subset. - table_name - The name of the table to apply the mask to. - mask - Boolean mask to apply to the table which is the same length as the number of rows in the table. - - Returns - ------- - The subsetted SpatialData object. - """ - table = sdata.tables.get(table_name) - if table is None: - raise ValueError(f"Table {table_name} not found in SpatialData object.") - - subset_table = table[mask] - sdata.tables[table_name] = subset_table - _, _, instance_key = get_table_keys(subset_table) - annotated_regions = SpatialData.get_annotated_regions(table) - removed_instance_ids = list(np.unique(table.obs[instance_key][~mask])) - - filtered_elements = {} - for reg in annotated_regions: - elem = sdata.get(reg) - model = get_model(elem) - if model is Labels2DModel: - filtered_elements[reg] = _filter_by_instance_ids(elem, removed_instance_ids, instance_key) - elif model in [PointsModel, ShapesModel]: - element_dict, _ = match_element_to_table(sdata, element_name=reg, table_name=table_name) - filtered_elements[reg] = element_dict[reg] - - return SpatialData.init_from_elements(filtered_elements | {table_name: subset_table}) diff --git a/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py index d2d3b81f..83a47591 100644 --- a/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py +++ b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py @@ -2,52 +2,22 @@ import warnings +import annsel as an import numpy as np import pytest from xarray import DataArray from spatialdata import concatenate, match_sdata_to_table from spatialdata._core.query.relational_query import ( - _filter_by_instance_ids, _set_instance_ids_in_labels_to_zero, + filter_by_table_query, ) from spatialdata.datasets import blobs_annotating_element from spatialdata.models import Labels2DModel -def test_filter_labels2dmodel_by_instance_ids() -> None: - sdata = blobs_annotating_element("blobs_labels") - labels_element = sdata["blobs_labels"] - all_instance_ids = sdata.tables["table"].obs["instance_id"].unique() - filtered_labels_element = Labels2DModel.parse(_set_instance_ids_in_labels_to_zero(labels_element, [2, 3])) - - # because 0 is the background, we expect the filtered ids to be the instance ids that are not 0 - filtered_ids = set(np.unique(filtered_labels_element.data.compute())) - { - 0, - } - preserved_ids = np.unique(labels_element.data.compute()) - assert filtered_ids == (set(all_instance_ids) - {2, 3}) - # check if there is modification of the original labels - assert set(preserved_ids) == set(all_instance_ids) | {0} - - sdata.tables["table"].uns["spatialdata_attrs"]["region"] = "blobs_multiscale_labels" - sdata.tables["table"].obs.region = "blobs_multiscale_labels" - labels_element = sdata["blobs_multiscale_labels"] - - # Apply the filter independently at each scale: coarser scales may already be missing small - # instances (e.g. instance 5 disappears at scale2 in the original multiscale due to downsampling), - # so we compare against the IDs actually present at each scale rather than all_instance_ids. - for scale in labels_element: - scale_image = labels_element[scale].image - ids_at_scale = set(np.unique(scale_image.compute())) - filtered_image = _set_instance_ids_in_labels_to_zero(scale_image, [2, 3]) - filtered_ids = set(np.unique(filtered_image.compute())) - {0} - assert filtered_ids == (ids_at_scale - {0, 2, 3}) - # check if there is modification of the original labels - assert set(np.unique(scale_image.compute())) == ids_at_scale - - -def test_subset_sdata_by_table_mask() -> None: +@pytest.mark.parametrize("subset_func_name", ["match_sdata_to_table", "filter_by_table_query"]) +def test_subset_sdata_by_table_mask(subset_func_name: str) -> None: sdata = concatenate( { "labels": blobs_annotating_element("blobs_labels"), @@ -63,8 +33,11 @@ def test_subset_sdata_by_table_mask() -> None: subset_table = table[third_elems] with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) # labels not supported for right join - subset_sdata = match_sdata_to_table(sdata, "table", table=subset_table) + warnings.filterwarnings("ignore", message="labels not supported for right join", category=UserWarning) + if subset_func_name == "match_sdata_to_table": + subset_sdata = match_sdata_to_table(sdata, "table", table=subset_table) + else: + subset_sdata = filter_by_table_query(sdata, "table", obs_expr=an.col("instance_id") == 3) for label_name in list(subset_sdata.labels.keys()): elem = subset_sdata[label_name] @@ -72,10 +45,14 @@ def test_subset_sdata_by_table_mask() -> None: if isinstance(elem, DataArray): filtered = Labels2DModel.parse(_set_instance_ids_in_labels_to_zero(elem, ids_to_remove)) else: # DataTree (multiscale) - max_scale = list(elem.keys())[0] + scales = list(elem.keys()) + scale_factors = [ + round(elem[scales[i]].image.shape[0] / elem[scales[i + 1]].image.shape[0]) + for i in range(len(scales) - 1) + ] filtered = Labels2DModel.parse( - _set_instance_ids_in_labels_to_zero(elem[max_scale].image, ids_to_remove), - scale_factors=[2, 2], + _set_instance_ids_in_labels_to_zero(elem[scales[0]].image, ids_to_remove), + scale_factors=scale_factors, ) subset_sdata[label_name] = filtered @@ -99,6 +76,37 @@ def test_subset_sdata_by_table_mask() -> None: assert shapes_remaining_ids == {3} -def test_filter_by_instance_ids_fails_for_unsupported_element_models() -> None: - with pytest.raises(NotImplementedError, match="Filtering by instance ids is not implemented for"): - _filter_by_instance_ids([1, 1, 1, 2], [1], "instance_id") +@pytest.mark.parametrize("subset_func_name", ["match_sdata_to_table", "filter_by_table_query"]) +@pytest.mark.parametrize("element_name", ["blobs_labels", "blobs_multiscale_labels"]) +def test_filter_out_instances(subset_func_name: str, element_name: str) -> None: + sdata = blobs_annotating_element(element_name) + table = sdata.tables["table"] + keep_id = 3 + ids_to_remove = [i for i in table.obs["instance_id"].unique() if i != keep_id] + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="labels not supported for right join", category=UserWarning) + if subset_func_name == "match_sdata_to_table": + subset_table = table[table.obs["instance_id"] == keep_id] + subset_sdata = match_sdata_to_table(sdata, "table", table=subset_table) + else: + subset_sdata = filter_by_table_query(sdata, "table", obs_expr=an.col("instance_id") == keep_id) + + elem = subset_sdata[element_name] + if isinstance(elem, DataArray): + filtered = Labels2DModel.parse(_set_instance_ids_in_labels_to_zero(elem, ids_to_remove)) + remaining_ids = set(np.unique(filtered.data.compute())) - {0} + assert remaining_ids == {keep_id} + else: # DataTree (multiscale) + scales = list(elem.keys()) + scale_factors = [ + round(elem[scales[i]].image.shape[0] / elem[scales[i + 1]].image.shape[0]) for i in range(len(scales) - 1) + ] + filtered = Labels2DModel.parse( + _set_instance_ids_in_labels_to_zero(elem[scales[0]].image, ids_to_remove), + scale_factors=scale_factors, + ) + for scale in filtered: + # at coarser scales an instance may vanish due to downsampling, but no other instance should appear + remaining_ids = set(np.unique(filtered[scale].image.compute())) - {0} + assert remaining_ids <= {keep_id} From 264ad2ab99689a2bc88b96691c86bd75369a7269 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Mon, 11 May 2026 16:58:11 +0200 Subject: [PATCH 15/18] Add filter_label_pixels flag to match_sdata_to_table and filter_by_table_query MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Threads a filter_label_pixels: bool = False parameter through the full join stack (filter_by_table_query → match_sdata_to_table → join_spatialelement_table → _call_join → _right/_inner_join_spatialelement_table). When True, label pixels for removed instances are zeroed via a new _filter_labels_element helper (handles both DataArray and multiscale DataTree). When False (default), the existing warning is preserved but now also hints at the new flag. Tests no longer need manual _set_instance_ids_in_labels_to_zero calls or warnings suppression. Co-Authored-By: Claude Sonnet 4.6 --- .../_core/query/relational_query.py | 105 ++++++++++++++---- ...tional_query_subset_sdata_by_table_mask.py | 68 +++--------- 2 files changed, 102 insertions(+), 71 deletions(-) diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index f05c41d5..498a389e 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -312,7 +312,10 @@ def _get_masked_element( def _right_exclusive_join_spatialelement_table( - element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"] + element_dict: dict[str, dict[str, Any]], + table: AnnData, + match_rows: Literal["left", "no", "right"], + filter_label_pixels: bool = False, ) -> tuple[dict[str, Any], AnnData | None]: regions, region_column_name, instance_key = get_table_keys(table) if isinstance(regions, str): @@ -350,7 +353,10 @@ def _right_exclusive_join_spatialelement_table( def _right_join_spatialelement_table( - element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"] + element_dict: dict[str, dict[str, Any]], + table: AnnData, + match_rows: Literal["left", "no", "right"], + filter_label_pixels: bool = False, ) -> tuple[dict[str, Any], AnnData]: if match_rows == "left": warnings.warn("Matching rows 'left' is not supported for 'right' join.", UserWarning, stacklevel=2) @@ -366,11 +372,17 @@ def _right_join_spatialelement_table( if element_type in ["points", "shapes"]: element_indices = element.index else: - warnings.warn( - f"Element type `labels` not supported for 'right' join. Skipping `{name}`", - UserWarning, - stacklevel=2, - ) + if filter_label_pixels: + element_dict[element_type][name] = _filter_labels_element( + element, table_instance_key_column.tolist() + ) + else: + warnings.warn( + f"Element type `labels` not supported for 'right' join, pixels are not filtered;" + f" to filter label pixels pass `filter_label_pixels=True`. Skipping `{name}`", + UserWarning, + stacklevel=2, + ) continue masked_element = _get_masked_element(element_indices, element, table_instance_key_column, match_rows) @@ -384,7 +396,10 @@ def _right_join_spatialelement_table( def _inner_join_spatialelement_table( - element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"] + element_dict: dict[str, dict[str, Any]], + table: AnnData, + match_rows: Literal["left", "no", "right"], + filter_label_pixels: bool = False, ) -> tuple[dict[str, Any], AnnData]: regions, region_column_name, instance_key = get_table_keys(table) if isinstance(regions, str): @@ -400,11 +415,17 @@ def _inner_join_spatialelement_table( if element_type in ["points", "shapes"]: element_indices = element.index else: - warnings.warn( - f"Element type `labels` not supported for 'inner' join. Skipping `{name}`", - UserWarning, - stacklevel=2, - ) + if filter_label_pixels: + element_dict[element_type][name] = _filter_labels_element( + element, table_instance_key_column.tolist() + ) + else: + warnings.warn( + f"Element type `labels` not supported for 'inner' join, pixels are not filtered;" + f" to filter label pixels pass `filter_label_pixels=True`. Skipping `{name}`", + UserWarning, + stacklevel=2, + ) continue masked_element = _get_masked_element(element_indices, element, table_instance_key_column, match_rows) @@ -430,7 +451,10 @@ def _inner_join_spatialelement_table( def _left_exclusive_join_spatialelement_table( - element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"] + element_dict: dict[str, dict[str, Any]], + table: AnnData, + match_rows: Literal["left", "no", "right"], + filter_label_pixels: bool = False, ) -> tuple[dict[str, Any], AnnData | None]: regions, region_column_name, instance_key = get_table_keys(table) if isinstance(regions, str): @@ -463,7 +487,10 @@ def _left_exclusive_join_spatialelement_table( def _left_join_spatialelement_table( - element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"] + element_dict: dict[str, dict[str, Any]], + table: AnnData, + match_rows: Literal["left", "no", "right"], + filter_label_pixels: bool = False, ) -> tuple[dict[str, Any], AnnData]: if match_rows == "right": warnings.warn("Matching rows 'right' is not supported for 'left' join.", UserWarning, stacklevel=2) @@ -587,6 +614,7 @@ def join_spatialelement_table( table: AnnData | None = None, how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "left", match_rows: Literal["no", "left", "right"] = "no", + filter_label_pixels: bool = False, ) -> tuple[dict[str, Any], AnnData]: """ Join SpatialElement(s) and table together in SQL like manner. @@ -630,6 +658,10 @@ def join_spatialelement_table( match_rows Whether to match the indices of the element and table and if so how. If ``'left'``, element_indices take priority and if ``'right'`` table instance ids take priority. + filter_label_pixels + If ``True``, label pixels whose instance id is not present in the table are set to zero. Only applies to + ``'right'`` and ``'inner'`` joins. If ``False`` (default), a warning is issued and label elements are + returned unfiltered. Returns ------- @@ -695,12 +727,16 @@ def join_spatialelement_table( for name, element in getattr(derived_sdata, element_type).items(): elements_dict[element_type][name] = element - elements_dict_joined, table = _call_join(elements_dict, table, how, match_rows) + elements_dict_joined, table = _call_join(elements_dict, table, how, match_rows, filter_label_pixels) return elements_dict_joined, table def _call_join( - elements_dict: dict[str, dict[str, Any]], table: AnnData, how: str, match_rows: Literal["no", "left", "right"] + elements_dict: dict[str, dict[str, Any]], + table: AnnData, + how: str, + match_rows: Literal["no", "left", "right"], + filter_label_pixels: bool = False, ) -> tuple[dict[str, Any], AnnData]: assert any(key in elements_dict for key in ["labels", "shapes", "points"]), ( "No valid element to join in spatial_element_name. Must provide at least one of either `labels`, `points` or " @@ -715,7 +751,7 @@ def _call_join( # if how in JoinTypes.__dict__["_member_names_"]: # hotfix for bug with Python 3.13: if how in JoinTypes.__dict__: - elements_dict, table = getattr(JoinTypes, how)(elements_dict, table, match_rows) + elements_dict, table = getattr(JoinTypes, how)(elements_dict, table, match_rows, filter_label_pixels) else: raise TypeError(f"`{how}` is not a valid type of join.") @@ -798,6 +834,7 @@ def match_sdata_to_table( table_name: str, table: AnnData | None = None, how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "right", + filter_label_pixels: bool = False, ) -> SpatialData: """ Filter the elements of a SpatialData object to match only the rows present in the table. @@ -813,6 +850,9 @@ def match_sdata_to_table( `table_name` is used to name the table in the returned `SpatialData` object. how The type of join to perform. See :func:`spatialdata.join_spatialelement_table`. Default is "right". + filter_label_pixels + If ``True``, label pixels whose instance id is not present in the table are set to zero. See + :func:`spatialdata.join_spatialelement_table` for details. Default is ``False``. Notes ----- @@ -824,7 +864,7 @@ def match_sdata_to_table( _, region_key, instance_key = get_table_keys(table) annotated_regions = SpatialData.get_annotated_regions(table) filtered_elements, filtered_table = join_spatialelement_table( - sdata, spatial_element_names=annotated_regions, table=table, how=how + sdata, spatial_element_names=annotated_regions, table=table, how=how, filter_label_pixels=filter_label_pixels ) filtered_table = TableModel.parse( filtered_table, @@ -848,6 +888,7 @@ def filter_by_table_query( var_names_expr: Predicates | None = None, layer: str | None = None, how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "right", + filter_label_pixels: bool = False, ) -> SpatialData: """Filter the SpatialData object based on a set of table queries. @@ -876,6 +917,9 @@ def filter_by_table_query( The layer of the :class:`anndata.AnnData` to filter the SpatialData object by, only used with `x_expr`. how The type of join to perform. See :func:`spatialdata.join_spatialelement_table`. Default is "right". + filter_label_pixels + If ``True``, label pixels whose instance id is not present in the table are set to zero. See + :func:`spatialdata.join_spatialelement_table` for details. Default is ``False``. Returns ------- @@ -900,7 +944,13 @@ def filter_by_table_query( obs=obs_expr, var=var_expr, x=x_expr, obs_names=obs_names_expr, var_names=var_names_expr, layer=layer ) - return match_sdata_to_table(sdata=sdata_subset, table_name=table_name, table=filtered_table, how=how) + return match_sdata_to_table( + sdata=sdata_subset, + table_name=table_name, + table=filtered_table, + how=how, + filter_label_pixels=filter_label_pixels, + ) @dataclass @@ -1130,3 +1180,18 @@ def _set_instance_ids_in_labels_to_zero(image: xr.DataArray, ids_to_remove: list dims=image.dims, attrs=image.attrs.copy(), # Preserve all attributes ) + + +def _filter_labels_element(element: DataArray | DataTree, ids_to_keep: list[int]) -> DataArray | DataTree: + element_instances = get_element_instances(element) + ids_to_remove = [i for i in element_instances if i not in set(ids_to_keep)] + if isinstance(element, DataArray): + return Labels2DModel.parse(_set_instance_ids_in_labels_to_zero(element, ids_to_remove)) + scales = list(element.keys()) + scale_factors = [ + round(element[scales[i]].image.shape[0] / element[scales[i + 1]].image.shape[0]) for i in range(len(scales) - 1) + ] + return Labels2DModel.parse( + _set_instance_ids_in_labels_to_zero(element[scales[0]].image, ids_to_remove), + scale_factors=scale_factors, + ) diff --git a/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py index 83a47591..620169d4 100644 --- a/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py +++ b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py @@ -1,19 +1,13 @@ from __future__ import annotations -import warnings - import annsel as an import numpy as np import pytest from xarray import DataArray from spatialdata import concatenate, match_sdata_to_table -from spatialdata._core.query.relational_query import ( - _set_instance_ids_in_labels_to_zero, - filter_by_table_query, -) +from spatialdata._core.query.relational_query import filter_by_table_query from spatialdata.datasets import blobs_annotating_element -from spatialdata.models import Labels2DModel @pytest.mark.parametrize("subset_func_name", ["match_sdata_to_table", "filter_by_table_query"]) @@ -29,32 +23,14 @@ def test_subset_sdata_by_table_mask(subset_func_name: str) -> None: ) table = sdata.tables["table"] third_elems = table.obs["instance_id"] == 3 - ids_to_remove = list(np.unique(table.obs["instance_id"][~third_elems])) subset_table = table[third_elems] - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="labels not supported for right join", category=UserWarning) - if subset_func_name == "match_sdata_to_table": - subset_sdata = match_sdata_to_table(sdata, "table", table=subset_table) - else: - subset_sdata = filter_by_table_query(sdata, "table", obs_expr=an.col("instance_id") == 3) - - for label_name in list(subset_sdata.labels.keys()): - elem = subset_sdata[label_name] - del subset_sdata[label_name] - if isinstance(elem, DataArray): - filtered = Labels2DModel.parse(_set_instance_ids_in_labels_to_zero(elem, ids_to_remove)) - else: # DataTree (multiscale) - scales = list(elem.keys()) - scale_factors = [ - round(elem[scales[i]].image.shape[0] / elem[scales[i + 1]].image.shape[0]) - for i in range(len(scales) - 1) - ] - filtered = Labels2DModel.parse( - _set_instance_ids_in_labels_to_zero(elem[scales[0]].image, ids_to_remove), - scale_factors=scale_factors, - ) - subset_sdata[label_name] = filtered + if subset_func_name == "match_sdata_to_table": + subset_sdata = match_sdata_to_table(sdata, "table", table=subset_table, filter_label_pixels=True) + else: + subset_sdata = filter_by_table_query( + sdata, "table", obs_expr=an.col("instance_id") == 3, filter_label_pixels=True + ) assert set(subset_sdata.labels.keys()) == {"blobs_labels-labels", "blobs_multiscale_labels-multiscale_labels"} assert set(subset_sdata.points.keys()) == {"blobs_points-points"} @@ -82,31 +58,21 @@ def test_filter_out_instances(subset_func_name: str, element_name: str) -> None: sdata = blobs_annotating_element(element_name) table = sdata.tables["table"] keep_id = 3 - ids_to_remove = [i for i in table.obs["instance_id"].unique() if i != keep_id] + subset_table = table[table.obs["instance_id"] == keep_id] - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="labels not supported for right join", category=UserWarning) - if subset_func_name == "match_sdata_to_table": - subset_table = table[table.obs["instance_id"] == keep_id] - subset_sdata = match_sdata_to_table(sdata, "table", table=subset_table) - else: - subset_sdata = filter_by_table_query(sdata, "table", obs_expr=an.col("instance_id") == keep_id) + if subset_func_name == "match_sdata_to_table": + subset_sdata = match_sdata_to_table(sdata, "table", table=subset_table, filter_label_pixels=True) + else: + subset_sdata = filter_by_table_query( + sdata, "table", obs_expr=an.col("instance_id") == keep_id, filter_label_pixels=True + ) elem = subset_sdata[element_name] if isinstance(elem, DataArray): - filtered = Labels2DModel.parse(_set_instance_ids_in_labels_to_zero(elem, ids_to_remove)) - remaining_ids = set(np.unique(filtered.data.compute())) - {0} + remaining_ids = set(np.unique(elem.data.compute())) - {0} assert remaining_ids == {keep_id} else: # DataTree (multiscale) - scales = list(elem.keys()) - scale_factors = [ - round(elem[scales[i]].image.shape[0] / elem[scales[i + 1]].image.shape[0]) for i in range(len(scales) - 1) - ] - filtered = Labels2DModel.parse( - _set_instance_ids_in_labels_to_zero(elem[scales[0]].image, ids_to_remove), - scale_factors=scale_factors, - ) - for scale in filtered: + for scale in elem: # at coarser scales an instance may vanish due to downsampling, but no other instance should appear - remaining_ids = set(np.unique(filtered[scale].image.compute())) - {0} + remaining_ids = set(np.unique(elem[scale].image.compute())) - {0} assert remaining_ids <= {keep_id} From 7f81db04b8fa2d9d9729f53e86a334adaa20fee8 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Mon, 11 May 2026 17:02:28 +0200 Subject: [PATCH 16/18] Change filter_label_pixels default to None; False silences the warning - None (default): warn that label pixels are not filtered, hint at the flag - True: filter label pixels (set removed instance pixels to zero) - False: skip silently, no warning Updated docstrings in join_spatialelement_table, match_sdata_to_table, and filter_by_table_query to document all three states. Co-Authored-By: Claude Sonnet 4.6 --- .../_core/query/relational_query.py | 49 ++++++++++--------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index 498a389e..54a9af61 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -315,7 +315,7 @@ def _right_exclusive_join_spatialelement_table( element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"], - filter_label_pixels: bool = False, + filter_label_pixels: bool | None = None, ) -> tuple[dict[str, Any], AnnData | None]: regions, region_column_name, instance_key = get_table_keys(table) if isinstance(regions, str): @@ -356,7 +356,7 @@ def _right_join_spatialelement_table( element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"], - filter_label_pixels: bool = False, + filter_label_pixels: bool | None = None, ) -> tuple[dict[str, Any], AnnData]: if match_rows == "left": warnings.warn("Matching rows 'left' is not supported for 'right' join.", UserWarning, stacklevel=2) @@ -372,14 +372,15 @@ def _right_join_spatialelement_table( if element_type in ["points", "shapes"]: element_indices = element.index else: - if filter_label_pixels: + if filter_label_pixels is True: element_dict[element_type][name] = _filter_labels_element( element, table_instance_key_column.tolist() ) - else: + elif filter_label_pixels is None: warnings.warn( f"Element type `labels` not supported for 'right' join, pixels are not filtered;" - f" to filter label pixels pass `filter_label_pixels=True`. Skipping `{name}`", + f" pass `filter_label_pixels=True` to filter or `filter_label_pixels=False` to silence" + f" this warning. Skipping `{name}`", UserWarning, stacklevel=2, ) @@ -399,7 +400,7 @@ def _inner_join_spatialelement_table( element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"], - filter_label_pixels: bool = False, + filter_label_pixels: bool | None = None, ) -> tuple[dict[str, Any], AnnData]: regions, region_column_name, instance_key = get_table_keys(table) if isinstance(regions, str): @@ -415,14 +416,15 @@ def _inner_join_spatialelement_table( if element_type in ["points", "shapes"]: element_indices = element.index else: - if filter_label_pixels: + if filter_label_pixels is True: element_dict[element_type][name] = _filter_labels_element( element, table_instance_key_column.tolist() ) - else: + elif filter_label_pixels is None: warnings.warn( f"Element type `labels` not supported for 'inner' join, pixels are not filtered;" - f" to filter label pixels pass `filter_label_pixels=True`. Skipping `{name}`", + f" pass `filter_label_pixels=True` to filter or `filter_label_pixels=False` to silence" + f" this warning. Skipping `{name}`", UserWarning, stacklevel=2, ) @@ -454,7 +456,7 @@ def _left_exclusive_join_spatialelement_table( element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"], - filter_label_pixels: bool = False, + filter_label_pixels: bool | None = None, ) -> tuple[dict[str, Any], AnnData | None]: regions, region_column_name, instance_key = get_table_keys(table) if isinstance(regions, str): @@ -490,7 +492,7 @@ def _left_join_spatialelement_table( element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"], - filter_label_pixels: bool = False, + filter_label_pixels: bool | None = None, ) -> tuple[dict[str, Any], AnnData]: if match_rows == "right": warnings.warn("Matching rows 'right' is not supported for 'left' join.", UserWarning, stacklevel=2) @@ -614,7 +616,7 @@ def join_spatialelement_table( table: AnnData | None = None, how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "left", match_rows: Literal["no", "left", "right"] = "no", - filter_label_pixels: bool = False, + filter_label_pixels: bool | None = None, ) -> tuple[dict[str, Any], AnnData]: """ Join SpatialElement(s) and table together in SQL like manner. @@ -659,9 +661,10 @@ def join_spatialelement_table( Whether to match the indices of the element and table and if so how. If ``'left'``, element_indices take priority and if ``'right'`` table instance ids take priority. filter_label_pixels - If ``True``, label pixels whose instance id is not present in the table are set to zero. Only applies to - ``'right'`` and ``'inner'`` joins. If ``False`` (default), a warning is issued and label elements are - returned unfiltered. + Controls pixel-level filtering of label elements for ``'right'`` and ``'inner'`` joins. + If ``True``, pixels whose instance id is not present in the table are set to zero. + If ``None`` (default), label elements are returned unfiltered and a warning is issued. + If ``False``, label elements are returned unfiltered silently (no warning). Returns ------- @@ -736,7 +739,7 @@ def _call_join( table: AnnData, how: str, match_rows: Literal["no", "left", "right"], - filter_label_pixels: bool = False, + filter_label_pixels: bool | None = None, ) -> tuple[dict[str, Any], AnnData]: assert any(key in elements_dict for key in ["labels", "shapes", "points"]), ( "No valid element to join in spatial_element_name. Must provide at least one of either `labels`, `points` or " @@ -834,7 +837,7 @@ def match_sdata_to_table( table_name: str, table: AnnData | None = None, how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "right", - filter_label_pixels: bool = False, + filter_label_pixels: bool | None = None, ) -> SpatialData: """ Filter the elements of a SpatialData object to match only the rows present in the table. @@ -851,8 +854,9 @@ def match_sdata_to_table( how The type of join to perform. See :func:`spatialdata.join_spatialelement_table`. Default is "right". filter_label_pixels - If ``True``, label pixels whose instance id is not present in the table are set to zero. See - :func:`spatialdata.join_spatialelement_table` for details. Default is ``False``. + Controls pixel-level filtering of label elements. ``True`` filters pixels, ``None`` (default) leaves them + unfiltered and warns, ``False`` leaves them unfiltered silently. See + :func:`spatialdata.join_spatialelement_table` for details. Notes ----- @@ -888,7 +892,7 @@ def filter_by_table_query( var_names_expr: Predicates | None = None, layer: str | None = None, how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "right", - filter_label_pixels: bool = False, + filter_label_pixels: bool | None = None, ) -> SpatialData: """Filter the SpatialData object based on a set of table queries. @@ -918,8 +922,9 @@ def filter_by_table_query( how The type of join to perform. See :func:`spatialdata.join_spatialelement_table`. Default is "right". filter_label_pixels - If ``True``, label pixels whose instance id is not present in the table are set to zero. See - :func:`spatialdata.join_spatialelement_table` for details. Default is ``False``. + Controls pixel-level filtering of label elements. ``True`` filters pixels, ``None`` (default) leaves them + unfiltered and warns, ``False`` leaves them unfiltered silently. See + :func:`spatialdata.join_spatialelement_table` for details. Returns ------- From 6cab962138ba27bed935d23a9bb819ba93fb61e3 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Mon, 11 May 2026 17:12:02 +0200 Subject: [PATCH 17/18] Move and consolidate label-filtering tests into test_relational_query_match_sdata_to_table MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace test_match_sdata_to_table_match_labels_error with test_filter_out_instances: parametrized over both API functions and element types; tests all three filter_label_pixels states (None→warn, False→nullcontext noop, True→pixels filtered) - Add test_subset_sdata_by_table_mask for mixed-element subsetting - Delete test_relational_query_subset_sdata_by_table_mask.py Co-Authored-By: Claude Sonnet 4.6 --- ...t_relational_query_match_sdata_to_table.py | 114 ++++++++++++++---- ...tional_query_subset_sdata_by_table_mask.py | 78 ------------ 2 files changed, 89 insertions(+), 103 deletions(-) delete mode 100644 tests/core/query/test_relational_query_subset_sdata_by_table_mask.py diff --git a/tests/core/query/test_relational_query_match_sdata_to_table.py b/tests/core/query/test_relational_query_match_sdata_to_table.py index c2468999..c237da1b 100644 --- a/tests/core/query/test_relational_query_match_sdata_to_table.py +++ b/tests/core/query/test_relational_query_match_sdata_to_table.py @@ -1,8 +1,14 @@ from __future__ import annotations +import contextlib + +import annsel as an +import numpy as np import pytest +from xarray import DataArray from spatialdata import SpatialData, concatenate, match_sdata_to_table +from spatialdata._core.query.relational_query import filter_by_table_query from spatialdata.datasets import blobs_annotating_element @@ -105,35 +111,48 @@ def test_match_sdata_to_table_shapes_and_points(): assert "blobs_polygons-sdata1" not in matched -def test_match_sdata_to_table_match_labels_error(): +@pytest.mark.parametrize("subset_func_name", ["match_sdata_to_table", "filter_by_table_query"]) +@pytest.mark.parametrize("element_name", ["blobs_labels", "blobs_multiscale_labels"]) +def test_filter_out_instances(subset_func_name: str, element_name: str) -> None: """ - match_sdata_to_table() uses the join operations; so when trying to match labels, the error will be raised by the - join. + By default a warning is issued when labels are encountered in a right join and pixels are not filtered. + Passing filter_label_pixels=True filters the label pixels to match the table. """ - sdata = _make_test_data() - sdata["table"].obs["region"] = sdata["table"].obs["region"].apply(lambda x: x.replace("polygons", "labels")) - sdata["table"].obs["region"] = sdata["table"].obs["region"].astype("category") - sdata.set_table_annotates_spatialelement( - table_name="table", - region=["blobs_labels-sdata1", "blobs_labels-sdata2"], - region_key="region", - instance_key="instance_id", - ) - - with pytest.warns( - UserWarning, - match="Element type `labels` not supported for 'right' join. Skipping ", - ): - matched = match_sdata_to_table( - sdata, - table=sdata["table"], - table_name="table", + sdata = blobs_annotating_element(element_name) + keep_id = 3 + table = sdata.tables["table"] + subset_table = table[table.obs["instance_id"] == keep_id] + + # None → warning issued; False → silenced, pixels still unfiltered + for flp, ctx in [ + (None, pytest.warns(UserWarning, match="pixels are not filtered")), + (False, contextlib.nullcontext()), + ]: + with ctx: + if subset_func_name == "match_sdata_to_table": + match_sdata_to_table(sdata, "table", table=subset_table, filter_label_pixels=flp) + else: + filter_by_table_query( + sdata, "table", obs_expr=an.col("instance_id") == keep_id, filter_label_pixels=flp + ) + + # filter_label_pixels=True: pixels are zeroed for removed instances + if subset_func_name == "match_sdata_to_table": + subset_sdata = match_sdata_to_table(sdata, "table", table=subset_table, filter_label_pixels=True) + else: + subset_sdata = filter_by_table_query( + sdata, "table", obs_expr=an.col("instance_id") == keep_id, filter_label_pixels=True ) - assert len(matched["table"]) == 10 - assert "blobs_labels-sdata1" in matched - assert "blobs_labels-sdata2" in matched - assert "blobs_points-sdata1" not in matched + elem = subset_sdata[element_name] + if isinstance(elem, DataArray): + remaining_ids = set(np.unique(elem.data.compute())) - {0} + assert remaining_ids == {keep_id} + else: # DataTree (multiscale) + for scale in elem: + # at coarser scales an instance may vanish due to downsampling, but no other instance should appear + remaining_ids = set(np.unique(elem[scale].image.compute())) - {0} + assert remaining_ids <= {keep_id} def test_match_sdata_to_table_no_table_argument(sdata): @@ -145,3 +164,48 @@ def test_match_sdata_to_table_no_table_argument(sdata): assert len(matched["table"]) == 10 assert "blobs_polygons-sdata1" in matched assert "blobs_polygons-sdata2" in matched + + +@pytest.mark.parametrize("subset_func_name", ["match_sdata_to_table", "filter_by_table_query"]) +def test_subset_sdata_by_table_mask(subset_func_name: str) -> None: + """ + Subsetting an sdata with mixed element types (labels, shapes, points) keeps only the requested instances. + Labels are pixel-filtered when filter_label_pixels=True. + """ + sdata = concatenate( + { + "labels": blobs_annotating_element("blobs_labels"), + "shapes": blobs_annotating_element("blobs_circles"), + "points": blobs_annotating_element("blobs_points"), + "multiscale_labels": blobs_annotating_element("blobs_multiscale_labels"), + }, + concatenate_tables=True, + ) + table = sdata.tables["table"] + subset_table = table[table.obs["instance_id"] == 3] + + if subset_func_name == "match_sdata_to_table": + subset_sdata = match_sdata_to_table(sdata, "table", table=subset_table, filter_label_pixels=True) + else: + subset_sdata = filter_by_table_query( + sdata, "table", obs_expr=an.col("instance_id") == 3, filter_label_pixels=True + ) + + assert set(subset_sdata.labels.keys()) == {"blobs_labels-labels", "blobs_multiscale_labels-multiscale_labels"} + assert set(subset_sdata.points.keys()) == {"blobs_points-points"} + assert set(subset_sdata.shapes.keys()) == {"blobs_circles-shapes"} + + labels_remaining_ids = set(np.unique(subset_sdata.labels["blobs_labels-labels"].data.compute())) - {0} + assert labels_remaining_ids == {3} + + for scale in subset_sdata.labels["blobs_multiscale_labels-multiscale_labels"]: + ms_labels_remaining_ids = set( + np.unique(subset_sdata.labels["blobs_multiscale_labels-multiscale_labels"][scale].image.compute()) + ) - {0} + assert ms_labels_remaining_ids == {3} + + points_remaining_ids = set(np.unique(subset_sdata.points["blobs_points-points"].index)) - {0} + assert points_remaining_ids == {3} + + shapes_remaining_ids = set(np.unique(subset_sdata.shapes["blobs_circles-shapes"].index)) - {0} + assert shapes_remaining_ids == {3} diff --git a/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py deleted file mode 100644 index 620169d4..00000000 --- a/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py +++ /dev/null @@ -1,78 +0,0 @@ -from __future__ import annotations - -import annsel as an -import numpy as np -import pytest -from xarray import DataArray - -from spatialdata import concatenate, match_sdata_to_table -from spatialdata._core.query.relational_query import filter_by_table_query -from spatialdata.datasets import blobs_annotating_element - - -@pytest.mark.parametrize("subset_func_name", ["match_sdata_to_table", "filter_by_table_query"]) -def test_subset_sdata_by_table_mask(subset_func_name: str) -> None: - sdata = concatenate( - { - "labels": blobs_annotating_element("blobs_labels"), - "shapes": blobs_annotating_element("blobs_circles"), - "points": blobs_annotating_element("blobs_points"), - "multiscale_labels": blobs_annotating_element("blobs_multiscale_labels"), - }, - concatenate_tables=True, - ) - table = sdata.tables["table"] - third_elems = table.obs["instance_id"] == 3 - subset_table = table[third_elems] - - if subset_func_name == "match_sdata_to_table": - subset_sdata = match_sdata_to_table(sdata, "table", table=subset_table, filter_label_pixels=True) - else: - subset_sdata = filter_by_table_query( - sdata, "table", obs_expr=an.col("instance_id") == 3, filter_label_pixels=True - ) - - assert set(subset_sdata.labels.keys()) == {"blobs_labels-labels", "blobs_multiscale_labels-multiscale_labels"} - assert set(subset_sdata.points.keys()) == {"blobs_points-points"} - assert set(subset_sdata.shapes.keys()) == {"blobs_circles-shapes"} - - labels_remaining_ids = set(np.unique(subset_sdata.labels["blobs_labels-labels"].data.compute())) - {0} - assert labels_remaining_ids == {3} - - for scale in subset_sdata.labels["blobs_multiscale_labels-multiscale_labels"]: - ms_labels_remaining_ids = set( - np.unique(subset_sdata.labels["blobs_multiscale_labels-multiscale_labels"][scale].image.compute()) - ) - {0} - assert ms_labels_remaining_ids == {3} - - points_remaining_ids = set(np.unique(subset_sdata.points["blobs_points-points"].index)) - {0} - assert points_remaining_ids == {3} - - shapes_remaining_ids = set(np.unique(subset_sdata.shapes["blobs_circles-shapes"].index)) - {0} - assert shapes_remaining_ids == {3} - - -@pytest.mark.parametrize("subset_func_name", ["match_sdata_to_table", "filter_by_table_query"]) -@pytest.mark.parametrize("element_name", ["blobs_labels", "blobs_multiscale_labels"]) -def test_filter_out_instances(subset_func_name: str, element_name: str) -> None: - sdata = blobs_annotating_element(element_name) - table = sdata.tables["table"] - keep_id = 3 - subset_table = table[table.obs["instance_id"] == keep_id] - - if subset_func_name == "match_sdata_to_table": - subset_sdata = match_sdata_to_table(sdata, "table", table=subset_table, filter_label_pixels=True) - else: - subset_sdata = filter_by_table_query( - sdata, "table", obs_expr=an.col("instance_id") == keep_id, filter_label_pixels=True - ) - - elem = subset_sdata[element_name] - if isinstance(elem, DataArray): - remaining_ids = set(np.unique(elem.data.compute())) - {0} - assert remaining_ids == {keep_id} - else: # DataTree (multiscale) - for scale in elem: - # at coarser scales an instance may vanish due to downsampling, but no other instance should appear - remaining_ids = set(np.unique(elem[scale].image.compute())) - {0} - assert remaining_ids <= {keep_id} From 8d6552ecf24a5b4e46b0c81267b18ab95c7e5660 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Mon, 11 May 2026 17:31:23 +0200 Subject: [PATCH 18/18] Add 3D labels guard to _filter_labels_element; fix annsel list predicate in tests - Raise NotImplementedError in _filter_labels_element when element is Labels3DModel - Add test_filter_out_instances_3d_labels_not_supported parametrized over both API functions - Use an.col().is_in() instead of == [list] in 3D test (narwhals does not support nested literals) Co-Authored-By: Claude Sonnet 4.6 --- .../_core/query/relational_query.py | 2 ++ ...t_relational_query_match_sdata_to_table.py | 24 +++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index 54a9af61..3b12347e 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -1188,6 +1188,8 @@ def _set_instance_ids_in_labels_to_zero(image: xr.DataArray, ids_to_remove: list def _filter_labels_element(element: DataArray | DataTree, ids_to_keep: list[int]) -> DataArray | DataTree: + if get_model(element) is Labels3DModel: + raise NotImplementedError("Pixel-level filtering of 3D labels is not supported.") element_instances = get_element_instances(element) ids_to_remove = [i for i in element_instances if i not in set(ids_to_keep)] if isinstance(element, DataArray): diff --git a/tests/core/query/test_relational_query_match_sdata_to_table.py b/tests/core/query/test_relational_query_match_sdata_to_table.py index c237da1b..a2af0251 100644 --- a/tests/core/query/test_relational_query_match_sdata_to_table.py +++ b/tests/core/query/test_relational_query_match_sdata_to_table.py @@ -4,12 +4,15 @@ import annsel as an import numpy as np +import pandas as pd import pytest +from anndata import AnnData from xarray import DataArray from spatialdata import SpatialData, concatenate, match_sdata_to_table from spatialdata._core.query.relational_query import filter_by_table_query from spatialdata.datasets import blobs_annotating_element +from spatialdata.models import Labels3DModel, TableModel def _make_test_data() -> SpatialData: @@ -209,3 +212,24 @@ def test_subset_sdata_by_table_mask(subset_func_name: str) -> None: shapes_remaining_ids = set(np.unique(subset_sdata.shapes["blobs_circles-shapes"].index)) - {0} assert shapes_remaining_ids == {3} + + +@pytest.mark.parametrize("subset_func_name", ["match_sdata_to_table", "filter_by_table_query"]) +def test_filter_out_instances_3d_labels_not_supported(subset_func_name: str) -> None: + """Pixel-level filtering of 3D labels raises NotImplementedError.""" + data = np.zeros((5, 5, 5), dtype=np.int32) + data[1:3, 1:3, 1:3] = 1 + data[3:5, 3:5, 3:5] = 2 + labels_3d = Labels3DModel.parse(data, dims=["z", "y", "x"]) + + obs_df = pd.DataFrame({"region": pd.Categorical(["labels_3d"]), "instance_id": [1]}, index=["0"]) + table = TableModel.parse( + AnnData(shape=(1, 0), obs=obs_df), region="labels_3d", region_key="region", instance_key="instance_id" + ) + sdata = SpatialData(labels={"labels_3d": labels_3d}, tables={"table": table}) + + with pytest.raises(NotImplementedError, match="3D labels"): + if subset_func_name == "match_sdata_to_table": + match_sdata_to_table(sdata, "table", filter_label_pixels=True) + else: + filter_by_table_query(sdata, "table", obs_expr=an.col("instance_id") == 1, filter_label_pixels=True)