From ef032bdbb989b48671c8836e8f1c192eee2edd56 Mon Sep 17 00:00:00 2001 From: gboscagli Date: Mon, 9 Mar 2026 16:25:38 +0100 Subject: [PATCH 1/4] adding deconcatenate --- src/spatialdata/_core/deconcatenate.py | 40 +++++++++++++++++++ .../operations/test_spatialdata_operations.py | 37 +++++++++++++++++ 2 files changed, 77 insertions(+) create mode 100644 src/spatialdata/_core/deconcatenate.py diff --git a/src/spatialdata/_core/deconcatenate.py b/src/spatialdata/_core/deconcatenate.py new file mode 100644 index 00000000..2511b702 --- /dev/null +++ b/src/spatialdata/_core/deconcatenate.py @@ -0,0 +1,40 @@ +from collections.abc import Iterable + +from spatialdata._core.query.relational_query import match_sdata_to_table +from spatialdata._core.spatialdata import SpatialData + +def deconcatenate( + full_sdata: SpatialData, + by: str | Iterable[str], + target_coordinate_system: str, + full_sdata_table_name: str = "table", + sdatas_table_names: str | Iterable[str] = "table", + region_key: str = "region", + join: str = "right" +) -> SpatialData | list[SpatialData]: + """ + From a `SpatialData` object containing multiple regions, returns `SpatialData` objects specific to each region identified in `by`. + """ + + if full_sdata_table_name not in full_sdata.tables: + raise KeyError('Missing table') + + sdata_table = full_sdata[full_sdata_table_name] + + is_single_region = isinstance(by, str) + deconcat_regions = [by] if is_single_region else list(by) + sdatas_table_names = [sdatas_table_names] * len(deconcat_regions) if isinstance(sdatas_table_names, str) else list(sdatas_table_names) + + sdatas = [] + for region, table_name in zip(deconcat_regions, sdatas_table_names): + + deconcat_table = sdata_table[sdata_table.obs[region_key] == region] + deconcat_sdata = match_sdata_to_table( + full_sdata, + table=deconcat_table, + table_name=table_name, + how=join) + + sdatas.append(deconcat_sdata) + + return sdatas[0] if is_single_region else sdatas \ No newline at end of file diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index 68b538e0..75e191b3 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -10,6 +10,7 @@ from spatialdata._core.concatenate import _concatenate_tables, concatenate from spatialdata._core.data_extent import are_extents_equal, get_extent +from spatialdata._core.deconcatenate import deconcatenate from spatialdata._core.operations._utils import transform_to_data_extent from spatialdata._core.spatialdata import SpatialData from spatialdata._types import ArrayLike @@ -693,3 +694,39 @@ def test_validate_table_in_spatialdata(full_sdata): del full_sdata.points["points_0"] with pytest.warns(UserWarning, match="in the SpatialData object"): full_sdata.validate_table_in_spatialdata(table) + + +def test_deconcatenate(full_sdata): + + regions = ["region1", "region2"] + table_names = ["table1", "table2"] + assert len(regions) == len(table_names) + + # MULTIPLE REGIONS === + sdatas = deconcatenate( + full_sdata, + by=regions, + target_coordinate_system="global", + sdatas_table_names=table_names + ) + + assert isinstance(sdatas, list) + assert len(sdatas) == len(regions) + + for sdata, region, table_name in zip(sdatas, regions, table_names): + + assert table_name in sdata.tables + table = sdata.tables[table_name] + assert (table.obs["region"] == region).all() + + + + # SINGLE REGION === + single = deconcatenate( + full_sdata, + by=regions[0], + target_coordinate_system="global" + ) + + assert not isinstance(single, list) + assert "table" in single.tables \ No newline at end of file From c9a9866e10e6711e4e63740685424a7aa755b9ca Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Mar 2026 15:46:46 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spatialdata/_core/deconcatenate.py | 35 ++++++++++--------- .../operations/test_spatialdata_operations.py | 18 ++-------- 2 files changed, 21 insertions(+), 32 deletions(-) diff --git a/src/spatialdata/_core/deconcatenate.py b/src/spatialdata/_core/deconcatenate.py index 2511b702..87c60a9e 100644 --- a/src/spatialdata/_core/deconcatenate.py +++ b/src/spatialdata/_core/deconcatenate.py @@ -1,40 +1,41 @@ +from __future__ import annotations + from collections.abc import Iterable from spatialdata._core.query.relational_query import match_sdata_to_table from spatialdata._core.spatialdata import SpatialData + def deconcatenate( - full_sdata: SpatialData, - by: str | Iterable[str], - target_coordinate_system: str, - full_sdata_table_name: str = "table", - sdatas_table_names: str | Iterable[str] = "table", - region_key: str = "region", - join: str = "right" + full_sdata: SpatialData, + by: str | Iterable[str], + target_coordinate_system: str, + full_sdata_table_name: str = "table", + sdatas_table_names: str | Iterable[str] = "table", + region_key: str = "region", + join: str = "right", ) -> SpatialData | list[SpatialData]: """ From a `SpatialData` object containing multiple regions, returns `SpatialData` objects specific to each region identified in `by`. """ - if full_sdata_table_name not in full_sdata.tables: - raise KeyError('Missing table') + raise KeyError("Missing table") sdata_table = full_sdata[full_sdata_table_name] is_single_region = isinstance(by, str) deconcat_regions = [by] if is_single_region else list(by) - sdatas_table_names = [sdatas_table_names] * len(deconcat_regions) if isinstance(sdatas_table_names, str) else list(sdatas_table_names) + sdatas_table_names = ( + [sdatas_table_names] * len(deconcat_regions) + if isinstance(sdatas_table_names, str) + else list(sdatas_table_names) + ) sdatas = [] for region, table_name in zip(deconcat_regions, sdatas_table_names): - deconcat_table = sdata_table[sdata_table.obs[region_key] == region] - deconcat_sdata = match_sdata_to_table( - full_sdata, - table=deconcat_table, - table_name=table_name, - how=join) + deconcat_sdata = match_sdata_to_table(full_sdata, table=deconcat_table, table_name=table_name, how=join) sdatas.append(deconcat_sdata) - return sdatas[0] if is_single_region else sdatas \ No newline at end of file + return sdatas[0] if is_single_region else sdatas diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index 75e191b3..0d9ddab0 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -703,30 +703,18 @@ def test_deconcatenate(full_sdata): assert len(regions) == len(table_names) # MULTIPLE REGIONS === - sdatas = deconcatenate( - full_sdata, - by=regions, - target_coordinate_system="global", - sdatas_table_names=table_names - ) + sdatas = deconcatenate(full_sdata, by=regions, target_coordinate_system="global", sdatas_table_names=table_names) assert isinstance(sdatas, list) assert len(sdatas) == len(regions) for sdata, region, table_name in zip(sdatas, regions, table_names): - assert table_name in sdata.tables table = sdata.tables[table_name] assert (table.obs["region"] == region).all() - - # SINGLE REGION === - single = deconcatenate( - full_sdata, - by=regions[0], - target_coordinate_system="global" - ) + single = deconcatenate(full_sdata, by=regions[0], target_coordinate_system="global") assert not isinstance(single, list) - assert "table" in single.tables \ No newline at end of file + assert "table" in single.tables From 8054a83dd368e586519a3b2e4d3bf8b2af06ca4e Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Mon, 11 May 2026 18:01:38 +0200 Subject: [PATCH 3/4] Move and refactor deconcatenate() into concatenate.py - Move deconcatenate() from deconcatenate.py into concatenate.py and delete the old module - Rename full_sdata -> sdata, by -> split_by (now a column name, not a value); split_by defaults to region_key inferred from table attrs - Remove region_key, instance_key, target_coordinate_system and sdatas_table_names parameters; infer region_key/instance_key from attrs via get_table_keys(); table_name auto-detected when sdata has exactly one table - Replace match_sdata_to_table loop with direct join_spatialelement_table calls; use groupby to partition obs in one pass (O(n_obs)) instead of per-value boolean filtering (O(n_values * n_obs)) - Always return list[SpatialData] (no more single/list duality) - Export deconcatenate from spatialdata.__init__ - Rewrite test to use 3 self-contained shape elements and cover the split_by non-default column path Co-Authored-By: Claude Sonnet 4.6 --- src/spatialdata/__init__.py | 2 +- src/spatialdata/_core/concatenate.py | 75 ++++++++++++++++++- src/spatialdata/_core/deconcatenate.py | 41 ---------- .../operations/test_spatialdata_operations.py | 64 +++++++++++----- 4 files changed, 120 insertions(+), 62 deletions(-) delete mode 100644 src/spatialdata/_core/deconcatenate.py diff --git a/src/spatialdata/__init__.py b/src/spatialdata/__init__.py index 7ba66e71..9254ece9 100644 --- a/src/spatialdata/__init__.py +++ b/src/spatialdata/__init__.py @@ -164,7 +164,7 @@ def __dir__() -> list[str]: from spatialdata._core.centroids import get_centroids # _core.concatenate - from spatialdata._core.concatenate import concatenate + from spatialdata._core.concatenate import concatenate, deconcatenate # _core.data_extent from spatialdata._core.data_extent import are_extents_equal, get_extent diff --git a/src/spatialdata/_core/concatenate.py b/src/spatialdata/_core/concatenate.py index b0639eac..314226b1 100644 --- a/src/spatialdata/_core/concatenate.py +++ b/src/spatialdata/_core/concatenate.py @@ -4,7 +4,7 @@ from collections.abc import Callable, Iterable from copy import copy # Should probably go up at the top from itertools import chain -from typing import Any +from typing import Any, Literal from warnings import warn import numpy as np @@ -12,6 +12,7 @@ from anndata._core.merge import StrategiesLiteral, resolve_merge_strategy from spatialdata._core._utils import _find_common_table_keys +from spatialdata._core.query.relational_query import join_spatialelement_table from spatialdata._core.spatialdata import SpatialData from spatialdata.models import TableModel, get_table_keys from spatialdata.transformations import ( @@ -22,6 +23,7 @@ __all__ = [ "concatenate", + "deconcatenate", ] @@ -281,3 +283,74 @@ def _fix_ensure_unique_element_names( sdata_fixed = SpatialData.init_from_elements(elements | tables) sdatas_fixed.append(sdata_fixed) return sdatas_fixed + + +def deconcatenate( + sdata: SpatialData, + table_name: str | None = None, + split_by: str | None = None, + join: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "right", +) -> list[SpatialData]: + """ + Split a SpatialData object into multiple objects by unique values of a column. + + Parameters + ---------- + sdata + The SpatialData object to split. + table_name + Name of the table to split on. If ``None`` and the object contains exactly one table, that + table is used; otherwise the name must be provided explicitly. + split_by + Column in the table to split by. Defaults to the region_key inferred from the table attributes. + join + Join type for matching spatial elements to each sub-table. + + Returns + ------- + list[SpatialData] + One SpatialData object per unique value in the ``split_by`` column. + """ + if table_name is None: + if len(sdata.tables) == 1: + table_name = next(iter(sdata.tables)) + else: + raise ValueError( + f"sdata contains {len(sdata.tables)} tables; please specify `table_name` explicitly. " + f"Available tables: {list(sdata.tables)}" + ) + if table_name not in sdata.tables: + raise KeyError(f"Table '{table_name}' not found in sdata") + + table = sdata[table_name] + _, region_key, instance_key = get_table_keys(table) + + if split_by is None: + split_by = region_key + + # groupby partitions obs_names by split_by in one pass (O(n_obs)); slicing AnnData by the + # resulting index groups is then O(group_size) per group rather than O(n_obs) per group. + groups = table.obs.groupby(split_by, observed=True).groups + sub_tables = {val: table[idx] for val, idx in groups.items()} + annotated_regions = {val: sub.obs[region_key].unique().tolist() for val, sub in sub_tables.items()} + + # join_spatialelement_table is called once per split value; future work could perform all + # element joins in a single pass and then partition the result. + sdatas = [] + for val in groups: + regions = annotated_regions[val] + elements, filtered_table = join_spatialelement_table( + sdata, + spatial_element_names=regions, + table=sub_tables[val], + how=join, + ) + parsed_table = TableModel.parse( + filtered_table, + region=regions, + region_key=region_key, + instance_key=instance_key, + overwrite_metadata=True, + ) + sdatas.append(SpatialData.init_from_elements(elements | {table_name: parsed_table})) + return sdatas diff --git a/src/spatialdata/_core/deconcatenate.py b/src/spatialdata/_core/deconcatenate.py deleted file mode 100644 index 87c60a9e..00000000 --- a/src/spatialdata/_core/deconcatenate.py +++ /dev/null @@ -1,41 +0,0 @@ -from __future__ import annotations - -from collections.abc import Iterable - -from spatialdata._core.query.relational_query import match_sdata_to_table -from spatialdata._core.spatialdata import SpatialData - - -def deconcatenate( - full_sdata: SpatialData, - by: str | Iterable[str], - target_coordinate_system: str, - full_sdata_table_name: str = "table", - sdatas_table_names: str | Iterable[str] = "table", - region_key: str = "region", - join: str = "right", -) -> SpatialData | list[SpatialData]: - """ - From a `SpatialData` object containing multiple regions, returns `SpatialData` objects specific to each region identified in `by`. - """ - if full_sdata_table_name not in full_sdata.tables: - raise KeyError("Missing table") - - sdata_table = full_sdata[full_sdata_table_name] - - is_single_region = isinstance(by, str) - deconcat_regions = [by] if is_single_region else list(by) - sdatas_table_names = ( - [sdatas_table_names] * len(deconcat_regions) - if isinstance(sdatas_table_names, str) - else list(sdatas_table_names) - ) - - sdatas = [] - for region, table_name in zip(deconcat_regions, sdatas_table_names): - deconcat_table = sdata_table[sdata_table.obs[region_key] == region] - deconcat_sdata = match_sdata_to_table(full_sdata, table=deconcat_table, table_name=table_name, how=join) - - sdatas.append(deconcat_sdata) - - return sdatas[0] if is_single_region else sdatas diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index 023f8ea9..4cdd0734 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -10,7 +10,7 @@ from spatialdata._core.concatenate import _concatenate_tables, concatenate from spatialdata._core.data_extent import are_extents_equal, get_extent -from spatialdata._core.deconcatenate import deconcatenate +from spatialdata._core.concatenate import deconcatenate from spatialdata._core.operations._utils import transform_to_data_extent from spatialdata._core.spatialdata import SpatialData from spatialdata._types import ArrayLike @@ -697,25 +697,51 @@ def test_validate_table_in_spatialdata(full_sdata): full_sdata.validate_table_in_spatialdata(table) -def test_deconcatenate(full_sdata): +def test_deconcatenate(): + from shapely.geometry import Point - regions = ["region1", "region2"] - table_names = ["table1", "table2"] - assert len(regions) == len(table_names) + shapes_a = ShapesModel.parse(GeoDataFrame({"geometry": [Point(0, 0), Point(1, 1)], "radius": [1.0, 1.0]})) + shapes_b = ShapesModel.parse(GeoDataFrame({"geometry": [Point(2, 2), Point(3, 3)], "radius": [1.0, 1.0]})) + shapes_c = ShapesModel.parse(GeoDataFrame({"geometry": [Point(4, 4), Point(5, 5)], "radius": [1.0, 1.0]})) - # MULTIPLE REGIONS === - sdatas = deconcatenate(full_sdata, by=regions, target_coordinate_system="global", sdatas_table_names=table_names) + obs = pd.DataFrame( + { + "region": pd.Categorical( + ["region_a", "region_a", "region_b", "region_b", "region_c", "region_c"] + ), + "instance_id": [0, 1, 0, 1, 0, 1], + "batch": pd.Categorical(["batch1", "batch2", "batch1", "batch2", "batch1", "batch2"]), + }, + index=["obs0", "obs1", "obs2", "obs3", "obs4", "obs5"], + ) + table = AnnData(obs=obs) + table = TableModel.parse( + table, + region=["region_a", "region_b", "region_c"], + region_key="region", + instance_key="instance_id", + ) + sdata = SpatialData( + shapes={"region_a": shapes_a, "region_b": shapes_b, "region_c": shapes_c}, + tables={"table": table}, + ) + # split by region_key (default) + sdatas = deconcatenate(sdata) assert isinstance(sdatas, list) - assert len(sdatas) == len(regions) - - for sdata, region, table_name in zip(sdatas, regions, table_names): - assert table_name in sdata.tables - table = sdata.tables[table_name] - assert (table.obs["region"] == region).all() - - # SINGLE REGION === - single = deconcatenate(full_sdata, by=regions[0], target_coordinate_system="global") - - assert not isinstance(single, list) - assert "table" in single.tables + assert len(sdatas) == 3 + seen_regions = set() + for sub in sdatas: + assert "table" in sub.tables + sub_table = sub.tables["table"] + unique_regions = sub_table.obs["region"].unique().tolist() + assert len(unique_regions) == 1 + seen_regions.add(unique_regions[0]) + assert seen_regions == {"region_a", "region_b", "region_c"} + + # split by a non-default column + sdatas2 = deconcatenate(sdata, split_by="batch") + assert isinstance(sdatas2, list) + assert len(sdatas2) == 2 + seen_batches = {sub.tables["table"].obs["batch"].unique()[0] for sub in sdatas2} + assert seen_batches == {"batch1", "batch2"} From 203e3174c12c5ae3be72323b4e0b90d0a1b584fc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 11 May 2026 16:08:23 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/core/operations/test_spatialdata_operations.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index 4cdd0734..0cb7c54c 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -8,9 +8,8 @@ from anndata import AnnData from geopandas import GeoDataFrame -from spatialdata._core.concatenate import _concatenate_tables, concatenate +from spatialdata._core.concatenate import _concatenate_tables, concatenate, deconcatenate from spatialdata._core.data_extent import are_extents_equal, get_extent -from spatialdata._core.concatenate import deconcatenate from spatialdata._core.operations._utils import transform_to_data_extent from spatialdata._core.spatialdata import SpatialData from spatialdata._types import ArrayLike @@ -706,9 +705,7 @@ def test_deconcatenate(): obs = pd.DataFrame( { - "region": pd.Categorical( - ["region_a", "region_a", "region_b", "region_b", "region_c", "region_c"] - ), + "region": pd.Categorical(["region_a", "region_a", "region_b", "region_b", "region_c", "region_c"]), "instance_id": [0, 1, 0, 1, 0, 1], "batch": pd.Categorical(["batch1", "batch2", "batch1", "batch2", "batch1", "batch2"]), },