Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/spatialdata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def __dir__() -> list[str]:
from spatialdata._core.centroids import get_centroids

# _core.concatenate
from spatialdata._core.concatenate import concatenate
from spatialdata._core.concatenate import concatenate, deconcatenate

# _core.data_extent
from spatialdata._core.data_extent import are_extents_equal, get_extent
Expand Down
75 changes: 74 additions & 1 deletion src/spatialdata/_core/concatenate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
from collections.abc import Callable, Iterable
from copy import copy # Should probably go up at the top
from itertools import chain
from typing import Any
from typing import Any, Literal
from warnings import warn

import numpy as np
from anndata import AnnData
from anndata._core.merge import StrategiesLiteral, resolve_merge_strategy

from spatialdata._core._utils import _find_common_table_keys
from spatialdata._core.query.relational_query import join_spatialelement_table
from spatialdata._core.spatialdata import SpatialData
from spatialdata.models import TableModel, get_table_keys
from spatialdata.transformations import (
Expand All @@ -22,6 +23,7 @@

__all__ = [
"concatenate",
"deconcatenate",
]


Expand Down Expand Up @@ -281,3 +283,74 @@ def _fix_ensure_unique_element_names(
sdata_fixed = SpatialData.init_from_elements(elements | tables)
sdatas_fixed.append(sdata_fixed)
return sdatas_fixed


def deconcatenate(
sdata: SpatialData,
table_name: str | None = None,
split_by: str | None = None,
join: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "right",
) -> list[SpatialData]:
"""
Split a SpatialData object into multiple objects by unique values of a column.

Parameters
----------
sdata
The SpatialData object to split.
table_name
Name of the table to split on. If ``None`` and the object contains exactly one table, that
table is used; otherwise the name must be provided explicitly.
split_by
Column in the table to split by. Defaults to the region_key inferred from the table attributes.
join
Join type for matching spatial elements to each sub-table.

Returns
-------
list[SpatialData]
One SpatialData object per unique value in the ``split_by`` column.
"""
if table_name is None:
if len(sdata.tables) == 1:
table_name = next(iter(sdata.tables))
else:
raise ValueError(
f"sdata contains {len(sdata.tables)} tables; please specify `table_name` explicitly. "
f"Available tables: {list(sdata.tables)}"
)
if table_name not in sdata.tables:
raise KeyError(f"Table '{table_name}' not found in sdata")

table = sdata[table_name]
_, region_key, instance_key = get_table_keys(table)

if split_by is None:
split_by = region_key

# groupby partitions obs_names by split_by in one pass (O(n_obs)); slicing AnnData by the
# resulting index groups is then O(group_size) per group rather than O(n_obs) per group.
groups = table.obs.groupby(split_by, observed=True).groups
sub_tables = {val: table[idx] for val, idx in groups.items()}
annotated_regions = {val: sub.obs[region_key].unique().tolist() for val, sub in sub_tables.items()}

# join_spatialelement_table is called once per split value; future work could perform all
# element joins in a single pass and then partition the result.
sdatas = []
for val in groups:
regions = annotated_regions[val]
elements, filtered_table = join_spatialelement_table(
sdata,
spatial_element_names=regions,
table=sub_tables[val],
how=join,
)
parsed_table = TableModel.parse(
filtered_table,
region=regions,
region_key=region_key,
instance_key=instance_key,
overwrite_metadata=True,
)
sdatas.append(SpatialData.init_from_elements(elements | {table_name: parsed_table}))
return sdatas
50 changes: 49 additions & 1 deletion tests/core/operations/test_spatialdata_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from anndata import AnnData
from geopandas import GeoDataFrame

from spatialdata._core.concatenate import _concatenate_tables, concatenate
from spatialdata._core.concatenate import _concatenate_tables, concatenate, deconcatenate
from spatialdata._core.data_extent import are_extents_equal, get_extent
from spatialdata._core.operations._utils import transform_to_data_extent
from spatialdata._core.spatialdata import SpatialData
Expand Down Expand Up @@ -694,3 +694,51 @@ def test_validate_table_in_spatialdata(full_sdata):
del full_sdata.points["points_0"]
with pytest.warns(UserWarning, match="in the SpatialData object"):
full_sdata.validate_table_in_spatialdata(table)


def test_deconcatenate():
from shapely.geometry import Point

shapes_a = ShapesModel.parse(GeoDataFrame({"geometry": [Point(0, 0), Point(1, 1)], "radius": [1.0, 1.0]}))
shapes_b = ShapesModel.parse(GeoDataFrame({"geometry": [Point(2, 2), Point(3, 3)], "radius": [1.0, 1.0]}))
shapes_c = ShapesModel.parse(GeoDataFrame({"geometry": [Point(4, 4), Point(5, 5)], "radius": [1.0, 1.0]}))

obs = pd.DataFrame(
{
"region": pd.Categorical(["region_a", "region_a", "region_b", "region_b", "region_c", "region_c"]),
"instance_id": [0, 1, 0, 1, 0, 1],
"batch": pd.Categorical(["batch1", "batch2", "batch1", "batch2", "batch1", "batch2"]),
},
index=["obs0", "obs1", "obs2", "obs3", "obs4", "obs5"],
)
table = AnnData(obs=obs)
table = TableModel.parse(
table,
region=["region_a", "region_b", "region_c"],
region_key="region",
instance_key="instance_id",
)
sdata = SpatialData(
shapes={"region_a": shapes_a, "region_b": shapes_b, "region_c": shapes_c},
tables={"table": table},
)

# split by region_key (default)
sdatas = deconcatenate(sdata)
assert isinstance(sdatas, list)
assert len(sdatas) == 3
seen_regions = set()
for sub in sdatas:
assert "table" in sub.tables
sub_table = sub.tables["table"]
unique_regions = sub_table.obs["region"].unique().tolist()
assert len(unique_regions) == 1
seen_regions.add(unique_regions[0])
assert seen_regions == {"region_a", "region_b", "region_c"}

# split by a non-default column
sdatas2 = deconcatenate(sdata, split_by="batch")
assert isinstance(sdatas2, list)
assert len(sdatas2) == 2
seen_batches = {sub.tables["table"].obs["batch"].unique()[0] for sub in sdatas2}
assert seen_batches == {"batch1", "batch2"}
Loading