Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
a28c7c9
init
selmanozleyen Jun 30, 2025
225d593
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 30, 2025
e549b4b
fix mypy linterrors
selmanozleyen Jun 30, 2025
2aad72b
update the location and the design
selmanozleyen Jul 3, 2025
ef74057
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 3, 2025
d6e22cb
update docs
selmanozleyen Jul 3, 2025
46c41db
Merge branch 'feature/filter_operations_on_label' of https://github.c…
selmanozleyen Jul 3, 2025
80d95a2
make coverage 100/100 because why not
selmanozleyen Jul 3, 2025
4438605
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 3, 2025
4c927ee
fixed type annotation
selmanozleyen Jul 10, 2025
e9e0da2
dont compute eagerly use. delete other instance key for consistency
selmanozleyen Jul 10, 2025
7534c91
update the tests and make sure we use match_element_to_table
selmanozleyen Jul 14, 2025
b4901cb
Merge branch 'main' into feature/filter_operations_on_label
selmanozleyen Aug 21, 2025
b21a0a1
Merge branch 'main' into feature/filter_operations_on_label
selmanozleyen Oct 20, 2025
631fe2a
Merge branch 'main' into feature/filter_operations_on_label
selmanozleyen Nov 3, 2025
908f7b4
wip rewrite tests using existing APIs
LucaMarconato May 11, 2026
6f7e468
Merge branch 'main' into feature/filter_operations_on_label
LucaMarconato May 11, 2026
359d553
test passing without using subset_sdata_by_table_mask()
LucaMarconato May 11, 2026
71e27ef
Remove _filter_by_instance_ids and _get_scale_factors; refactor tests…
LucaMarconato May 11, 2026
264ad2a
Add filter_label_pixels flag to match_sdata_to_table and filter_by_ta…
LucaMarconato May 11, 2026
7f81db0
Change filter_label_pixels default to None; False silences the warning
LucaMarconato May 11, 2026
6cab962
Move and consolidate label-filtering tests into test_relational_query…
LucaMarconato May 11, 2026
8d6552e
Add 3D labels guard to _filter_labels_element; fix annsel list predic…
LucaMarconato May 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 123 additions & 20 deletions src/spatialdata/_core/query/relational_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import dask.array as da
import numpy as np
import pandas as pd
import xarray as xr
from anndata import AnnData
from annsel.core.typing import Predicates
from dask.dataframe import DataFrame as DaskDataFrame
Expand Down Expand Up @@ -311,7 +312,10 @@ def _get_masked_element(


def _right_exclusive_join_spatialelement_table(
element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"]
element_dict: dict[str, dict[str, Any]],
table: AnnData,
match_rows: Literal["left", "no", "right"],
filter_label_pixels: bool | None = None,
) -> tuple[dict[str, Any], AnnData | None]:
regions, region_column_name, instance_key = get_table_keys(table)
if isinstance(regions, str):
Expand Down Expand Up @@ -349,7 +353,10 @@ def _right_exclusive_join_spatialelement_table(


def _right_join_spatialelement_table(
element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"]
element_dict: dict[str, dict[str, Any]],
table: AnnData,
match_rows: Literal["left", "no", "right"],
filter_label_pixels: bool | None = None,
) -> tuple[dict[str, Any], AnnData]:
if match_rows == "left":
warnings.warn("Matching rows 'left' is not supported for 'right' join.", UserWarning, stacklevel=2)
Expand All @@ -365,11 +372,18 @@ def _right_join_spatialelement_table(
if element_type in ["points", "shapes"]:
element_indices = element.index
else:
warnings.warn(
f"Element type `labels` not supported for 'right' join. Skipping `{name}`",
UserWarning,
stacklevel=2,
)
if filter_label_pixels is True:
element_dict[element_type][name] = _filter_labels_element(
element, table_instance_key_column.tolist()
)
elif filter_label_pixels is None:
warnings.warn(
f"Element type `labels` not supported for 'right' join, pixels are not filtered;"
f" pass `filter_label_pixels=True` to filter or `filter_label_pixels=False` to silence"
f" this warning. Skipping `{name}`",
UserWarning,
stacklevel=2,
)
continue

masked_element = _get_masked_element(element_indices, element, table_instance_key_column, match_rows)
Expand All @@ -383,7 +397,10 @@ def _right_join_spatialelement_table(


def _inner_join_spatialelement_table(
element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"]
element_dict: dict[str, dict[str, Any]],
table: AnnData,
match_rows: Literal["left", "no", "right"],
filter_label_pixels: bool | None = None,
) -> tuple[dict[str, Any], AnnData]:
regions, region_column_name, instance_key = get_table_keys(table)
if isinstance(regions, str):
Expand All @@ -399,11 +416,18 @@ def _inner_join_spatialelement_table(
if element_type in ["points", "shapes"]:
element_indices = element.index
else:
warnings.warn(
f"Element type `labels` not supported for 'inner' join. Skipping `{name}`",
UserWarning,
stacklevel=2,
)
if filter_label_pixels is True:
element_dict[element_type][name] = _filter_labels_element(
element, table_instance_key_column.tolist()
)
elif filter_label_pixels is None:
warnings.warn(
f"Element type `labels` not supported for 'inner' join, pixels are not filtered;"
f" pass `filter_label_pixels=True` to filter or `filter_label_pixels=False` to silence"
f" this warning. Skipping `{name}`",
UserWarning,
stacklevel=2,
)
continue

masked_element = _get_masked_element(element_indices, element, table_instance_key_column, match_rows)
Expand All @@ -429,7 +453,10 @@ def _inner_join_spatialelement_table(


def _left_exclusive_join_spatialelement_table(
element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"]
element_dict: dict[str, dict[str, Any]],
table: AnnData,
match_rows: Literal["left", "no", "right"],
filter_label_pixels: bool | None = None,
) -> tuple[dict[str, Any], AnnData | None]:
regions, region_column_name, instance_key = get_table_keys(table)
if isinstance(regions, str):
Expand Down Expand Up @@ -462,7 +489,10 @@ def _left_exclusive_join_spatialelement_table(


def _left_join_spatialelement_table(
element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"]
element_dict: dict[str, dict[str, Any]],
table: AnnData,
match_rows: Literal["left", "no", "right"],
filter_label_pixels: bool | None = None,
) -> tuple[dict[str, Any], AnnData]:
if match_rows == "right":
warnings.warn("Matching rows 'right' is not supported for 'left' join.", UserWarning, stacklevel=2)
Expand Down Expand Up @@ -586,6 +616,7 @@ def join_spatialelement_table(
table: AnnData | None = None,
how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "left",
match_rows: Literal["no", "left", "right"] = "no",
filter_label_pixels: bool | None = None,
) -> tuple[dict[str, Any], AnnData]:
"""
Join SpatialElement(s) and table together in SQL like manner.
Expand Down Expand Up @@ -629,6 +660,11 @@ def join_spatialelement_table(
match_rows
Whether to match the indices of the element and table and if so how. If ``'left'``, element_indices take
priority and if ``'right'`` table instance ids take priority.
filter_label_pixels
Controls pixel-level filtering of label elements for ``'right'`` and ``'inner'`` joins.
If ``True``, pixels whose instance id is not present in the table are set to zero.
If ``None`` (default), label elements are returned unfiltered and a warning is issued.
If ``False``, label elements are returned unfiltered silently (no warning).

Returns
-------
Expand Down Expand Up @@ -694,12 +730,16 @@ def join_spatialelement_table(
for name, element in getattr(derived_sdata, element_type).items():
elements_dict[element_type][name] = element

elements_dict_joined, table = _call_join(elements_dict, table, how, match_rows)
elements_dict_joined, table = _call_join(elements_dict, table, how, match_rows, filter_label_pixels)
return elements_dict_joined, table


def _call_join(
elements_dict: dict[str, dict[str, Any]], table: AnnData, how: str, match_rows: Literal["no", "left", "right"]
elements_dict: dict[str, dict[str, Any]],
table: AnnData,
how: str,
match_rows: Literal["no", "left", "right"],
filter_label_pixels: bool | None = None,
) -> tuple[dict[str, Any], AnnData]:
assert any(key in elements_dict for key in ["labels", "shapes", "points"]), (
"No valid element to join in spatial_element_name. Must provide at least one of either `labels`, `points` or "
Expand All @@ -714,7 +754,7 @@ def _call_join(
# if how in JoinTypes.__dict__["_member_names_"]:
# hotfix for bug with Python 3.13:
if how in JoinTypes.__dict__:
elements_dict, table = getattr(JoinTypes, how)(elements_dict, table, match_rows)
elements_dict, table = getattr(JoinTypes, how)(elements_dict, table, match_rows, filter_label_pixels)
else:
raise TypeError(f"`{how}` is not a valid type of join.")

Expand Down Expand Up @@ -797,6 +837,7 @@ def match_sdata_to_table(
table_name: str,
table: AnnData | None = None,
how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "right",
filter_label_pixels: bool | None = None,
) -> SpatialData:
"""
Filter the elements of a SpatialData object to match only the rows present in the table.
Expand All @@ -812,6 +853,10 @@ def match_sdata_to_table(
`table_name` is used to name the table in the returned `SpatialData` object.
how
The type of join to perform. See :func:`spatialdata.join_spatialelement_table`. Default is "right".
filter_label_pixels
Controls pixel-level filtering of label elements. ``True`` filters pixels, ``None`` (default) leaves them
unfiltered and warns, ``False`` leaves them unfiltered silently. See
:func:`spatialdata.join_spatialelement_table` for details.

Notes
-----
Expand All @@ -823,7 +868,7 @@ def match_sdata_to_table(
_, region_key, instance_key = get_table_keys(table)
annotated_regions = SpatialData.get_annotated_regions(table)
filtered_elements, filtered_table = join_spatialelement_table(
sdata, spatial_element_names=annotated_regions, table=table, how=how
sdata, spatial_element_names=annotated_regions, table=table, how=how, filter_label_pixels=filter_label_pixels
)
filtered_table = TableModel.parse(
filtered_table,
Expand All @@ -847,6 +892,7 @@ def filter_by_table_query(
var_names_expr: Predicates | None = None,
layer: str | None = None,
how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "right",
filter_label_pixels: bool | None = None,
) -> SpatialData:
"""Filter the SpatialData object based on a set of table queries.

Expand Down Expand Up @@ -875,6 +921,10 @@ def filter_by_table_query(
The layer of the :class:`anndata.AnnData` to filter the SpatialData object by, only used with `x_expr`.
how
The type of join to perform. See :func:`spatialdata.join_spatialelement_table`. Default is "right".
filter_label_pixels
Controls pixel-level filtering of label elements. ``True`` filters pixels, ``None`` (default) leaves them
unfiltered and warns, ``False`` leaves them unfiltered silently. See
:func:`spatialdata.join_spatialelement_table` for details.

Returns
-------
Expand All @@ -899,7 +949,13 @@ def filter_by_table_query(
obs=obs_expr, var=var_expr, x=x_expr, obs_names=obs_names_expr, var_names=var_names_expr, layer=layer
)

return match_sdata_to_table(sdata=sdata_subset, table_name=table_name, table=filtered_table, how=how)
return match_sdata_to_table(
sdata=sdata_subset,
table_name=table_name,
table=filtered_table,
how=how,
filter_label_pixels=filter_label_pixels,
)


@dataclass
Expand Down Expand Up @@ -1099,3 +1155,50 @@ def get_values(
return df

raise ValueError(f"Unknown origin {origin}")


def _mask_block(block: xr.DataArray, ids_to_remove: list[int]) -> xr.DataArray:
# Use apply_ufunc for efficient processing
# Create a copy to avoid modifying read-only array
result = block.copy()
result[np.isin(result, ids_to_remove)] = 0
return result


def _set_instance_ids_in_labels_to_zero(image: xr.DataArray, ids_to_remove: list[int]) -> xr.DataArray:
processed = xr.apply_ufunc(
partial(_mask_block, ids_to_remove=ids_to_remove),
image,
input_core_dims=[["y", "x"]],
output_core_dims=[["y", "x"]],
vectorize=True,
dask="parallelized",
output_dtypes=[image.dtype],
dataset_fill_value=0,
dask_gufunc_kwargs={"allow_rechunk": True},
)

# Create a new DataArray to ensure persistence
return xr.DataArray(
data=processed.data,
coords=image.coords,
dims=image.dims,
attrs=image.attrs.copy(), # Preserve all attributes
)


def _filter_labels_element(element: DataArray | DataTree, ids_to_keep: list[int]) -> DataArray | DataTree:
if get_model(element) is Labels3DModel:
raise NotImplementedError("Pixel-level filtering of 3D labels is not supported.")
element_instances = get_element_instances(element)
ids_to_remove = [i for i in element_instances if i not in set(ids_to_keep)]
if isinstance(element, DataArray):
return Labels2DModel.parse(_set_instance_ids_in_labels_to_zero(element, ids_to_remove))
scales = list(element.keys())
scale_factors = [
round(element[scales[i]].image.shape[0] / element[scales[i + 1]].image.shape[0]) for i in range(len(scales) - 1)
]
return Labels2DModel.parse(
_set_instance_ids_in_labels_to_zero(element[scales[0]].image, ids_to_remove),
scale_factors=scale_factors,
)
Loading
Loading