def deconcatenate(
    sdata: SpatialData,
    table_name: str | None = None,
    split_by: str | None = None,
    join: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "right",
) -> list[SpatialData]:
    """
    Split a SpatialData object into multiple objects by unique values of a column.

    The chosen table is partitioned by the values of ``split_by``. For every partition, the spatial
    elements named in that partition's region column are joined against the sub-table, and the
    result is wrapped in a new SpatialData object.

    Parameters
    ----------
    sdata
        The SpatialData object to split.
    table_name
        Name of the table to split on. If ``None`` and the object contains exactly one table, that
        table is used; otherwise the name must be provided explicitly.
    split_by
        Column in the table to split by. Defaults to the region_key inferred from the table attributes.
    join
        Join type for matching spatial elements to each sub-table. It is forwarded as ``how`` to
        ``join_spatialelement_table``.

    Returns
    -------
    list[SpatialData]
        One SpatialData object per observed value in the ``split_by`` column. If the join yields a
        table, it is stored under ``table_name`` in the corresponding object.

    Raises
    ------
    ValueError
        If ``table_name`` is ``None`` and ``sdata`` does not contain exactly one table.
    KeyError
        If ``table_name`` is not the name of a table in ``sdata``.
    """
    tables = sdata.tables
    if table_name is None:
        n_tables = len(tables)
        if n_tables == 0:
            # Asking the caller to "specify table_name" would be misleading: there is nothing to pick.
            raise ValueError("sdata contains no tables; there is nothing to split.")
        if n_tables > 1:
            raise ValueError(
                f"sdata contains {n_tables} tables; please specify `table_name` explicitly. "
                f"Available tables: {list(tables)}"
            )
        table_name = next(iter(tables))
    if table_name not in tables:
        raise KeyError(f"Table '{table_name}' not found in sdata")

    table = tables[table_name]
    _, region_key, instance_key = get_table_keys(table)

    if split_by is None:
        split_by = region_key

    # A single groupby partitions the obs index by `split_by`; `observed=True` keeps unused
    # categories from producing empty groups.
    groups = table.obs.groupby(split_by, observed=True).groups

    # join_spatialelement_table is called once per split value; future work could perform all
    # element joins in a single pass and then partition the result.
    sdatas = []
    for obs_names in groups.values():
        sub_table = table[obs_names]
        # Only the elements actually annotated by this partition take part in the join.
        regions = sub_table.obs[region_key].unique().tolist()
        elements, filtered_table = join_spatialelement_table(
            sdata,
            spatial_element_names=regions,
            table=sub_table,
            how=join,
        )
        # NOTE(review): depending on `join`, either side of the result is expected to be None
        # (e.g. the exclusive modes) - confirm against the installed spatialdata version. Guarding
        # here avoids `None | dict` and `TableModel.parse(None, ...)` failures.
        parts = {name: element for name, element in (elements or {}).items() if element is not None}
        if filtered_table is not None:
            parts[table_name] = TableModel.parse(
                filtered_table,
                region=regions,
                region_key=region_key,
                instance_key=instance_key,
                overwrite_metadata=True,
            )
        sdatas.append(SpatialData.init_from_elements(parts))
    return sdatas
def test_deconcatenate():
    """Split a three-region SpatialData object by its region column and by a custom ``batch`` column."""
    from shapely.geometry import Point

    def _two_circles(start):
        # Two unit-radius circles centred at (start, start) and (start + 1, start + 1).
        centres = [Point(start, start), Point(start + 1, start + 1)]
        return ShapesModel.parse(GeoDataFrame({"geometry": centres, "radius": [1.0, 1.0]}))

    shapes = {
        "region_a": _two_circles(0),
        "region_b": _two_circles(2),
        "region_c": _two_circles(4),
    }

    # Two observations per region; within each region the two observations sit in different batches.
    obs = pd.DataFrame(
        {
            "region": pd.Categorical(["region_a", "region_a", "region_b", "region_b", "region_c", "region_c"]),
            "instance_id": [0, 1, 0, 1, 0, 1],
            "batch": pd.Categorical(["batch1", "batch2", "batch1", "batch2", "batch1", "batch2"]),
        },
        index=["obs0", "obs1", "obs2", "obs3", "obs4", "obs5"],
    )
    table = TableModel.parse(
        AnnData(obs=obs),
        region=["region_a", "region_b", "region_c"],
        region_key="region",
        instance_key="instance_id",
    )
    sdata = SpatialData(shapes=shapes, tables={"table": table})

    # split by region_key (default)
    by_region = deconcatenate(sdata)
    assert isinstance(by_region, list)
    assert len(by_region) == 3
    regions_per_part = []
    for part in by_region:
        assert "table" in part.tables
        regions_per_part.append(part.tables["table"].obs["region"].unique().tolist())
    assert all(len(found) == 1 for found in regions_per_part)
    assert {found[0] for found in regions_per_part} == {"region_a", "region_b", "region_c"}

    # split by a non-default column
    by_batch = deconcatenate(sdata, split_by="batch")
    assert isinstance(by_batch, list)
    assert len(by_batch) == 2
    batches = set()
    for part in by_batch:
        batches.add(part.tables["table"].obs["batch"].unique()[0])
    assert batches == {"batch1", "batch2"}