From 2225fcca68f029c75e51866a39c7f7920f31bd18 Mon Sep 17 00:00:00 2001 From: Michaela Mueller Date: Fri, 17 Apr 2026 15:23:05 -0700 Subject: [PATCH 1/8] perf: use delayed dask matrix for agggregate, unchunked_axis not needed anymore --- src/scanpy/get/_aggregated.py | 110 ++++++++++++++++++---------------- 1 file changed, 58 insertions(+), 52 deletions(-) diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index fe5140171b..69a08290b0 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -1,8 +1,9 @@ from __future__ import annotations -from functools import partial, singledispatch -from typing import TYPE_CHECKING, Literal, TypedDict, get_args +from functools import singledispatch +from typing import TYPE_CHECKING, Literal, TypedDict +import dask import numpy as np import pandas as pd from anndata import AnnData @@ -395,29 +396,7 @@ def aggregate_dask( if not isinstance(data._meta, CSBase | np.ndarray): msg = f"Got {type(data._meta)} meta in DaskArray but only csr_matrix/csr_array and ndarray are supported." raise ValueError(msg) - chunked_axis, unchunked_axis = ( - (0, 1) if isinstance(data._meta, CSRBase | np.ndarray) else (1, 0) - ) - if data.chunksize[unchunked_axis] != data.shape[unchunked_axis]: - msg = "Feature axis must be unchunked" - raise ValueError(msg) - - def aggregate_chunk_sum_or_count_nonzero( - chunk: Array, *, func: Literal["count_nonzero", "sum"], block_info=None - ): - # only subset the mask and by if we need to i.e., - # there is chunking along the same axis as by and mask - if chunked_axis == 0: - # See https://docs.dask.org/en/stable/generated/dask.array.map_blocks.html - # for what is contained in `block_info`. - subset = slice(*block_info[0]["array-location"][0]) - by_subsetted = by[subset] - mask_subsetted = mask[subset] if mask is not None else mask - else: - by_subsetted = by - mask_subsetted = mask - res = _aggregate(chunk, by_subsetted, func, mask=mask_subsetted, dof=dof)[func] - return res[None, :] if unchunked_axis == 1 else res + chunked_axis, _ = (0, 1) if isinstance(data._meta, CSRBase | np.ndarray) else (1, 0) funcs = set([func] if isinstance(func, str) else func) if "median" in funcs: @@ -425,33 +404,60 @@ def aggregate_chunk_sum_or_count_nonzero( raise NotImplementedError(msg) has_mean, has_var = (v in funcs for v in ["mean", "var"]) funcs_no_var_or_mean = funcs - {"var", "mean"} - # aggregate each row chunk or column chunk individually, - # producing a #chunks × #categories × #features or a #categories × #chunks array, - # then aggregate the per-chunk results. - chunks = ( - ((1,) * data.blocks.size, (len(by.categories),), data.shape[1]) - if unchunked_axis == 1 - else (len(by.categories), data.chunks[1]) - ) - aggregated = { - f: data.map_blocks( - partial(aggregate_chunk_sum_or_count_nonzero, func=func), - new_axis=(1,) if unchunked_axis == 1 else None, - chunks=chunks, - meta=np.array( - [], - dtype=np.float64 - if func not in get_args(ConstantDtypeAgg) - else data.dtype, # TODO: figure out best dtype for aggs like sum where dtype can change from original - ), - ) - for f in funcs_no_var_or_mean - } - # If we have row chunking, we need to handle the extra axis by summing over all category × feature matrices. - # Otherwise, dask internally concatenates the #categories × #chunks arrays i.e., the column chunks are concatenated together to get a #categories × #features matrix. - if unchunked_axis == 1: - for k, v in aggregated.items(): - aggregated[k] = v.sum(axis=chunked_axis) + + if funcs_no_var_or_mean: + funcs_list = list(funcs_no_var_or_mean) + by_codes = np.asarray(by.codes) + mask_arr = np.asarray(mask) if mask is not None else None + + @dask.delayed + def aggregate_chunk(block, block_idx): + subset = slice(block_idx[0], block_idx[1]) + by_subsetted = ( + pd.Categorical.from_codes(by_codes[subset], categories=by.categories) + if chunked_axis == 0 + else by + ) + mask_subsetted = ( + mask_arr[subset] + if (mask_arr is not None and chunked_axis == 0) + else mask_arr + ) + return { + f: _aggregate(block, by_subsetted, f, mask=mask_subsetted, dof=dof)[f] + for f in funcs_list + } + + @dask.delayed + def add_aggs(a, b): + return {f: a[f] + b[f] for f in funcs_list} + + blocks = data.to_delayed().ravel() + + offset = 0 + delayed_chunks = [] + for i, block in enumerate(blocks): + block_idx = (offset, offset + data.chunks[chunked_axis][i]) + delayed_chunks.append(aggregate_chunk(block, block_idx)) + offset += data.chunks[chunked_axis][i] + + while len(delayed_chunks) > 1: + delayed_chunks = [ + add_aggs(delayed_chunks[i], delayed_chunks[i + 1]) + if i + 1 < len(delayed_chunks) + else delayed_chunks[i] + for i in range(0, len(delayed_chunks), 2) + ] + + aggregated = { + f: dask.array.from_delayed( + dask.delayed(lambda r, f=f: r[f])(delayed_chunks[0]), + shape=(len(by.categories), data.shape[1]), + dtype=np.float64, + ) + for f in funcs_list + } + if has_var: aggredated_mean_var = aggregate_dask_mean_var(data, by, mask=mask, dof=dof) aggregated["var"] = aggredated_mean_var["var"] From 972b5e9e7df8eb75721f3e1a968257cdf5785a2a Mon Sep 17 00:00:00 2001 From: Michaela Mueller Date: Fri, 17 Apr 2026 15:23:46 -0700 Subject: [PATCH 2/8] perf: pass only nnz to _sum function for sparse matrix in count_nonzero --- src/scanpy/get/_aggregated.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index 69a08290b0..4673fbe001 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -80,7 +80,15 @@ def count_nonzero(self) -> NDArray[np.integer]: Array of counts. """ - return self._sum(data=(self.data != 0).astype("uint8")) + data = self.data + if isinstance(data, CSBase): + data = type(data)( + (np.ones(data.nnz, dtype="uint8"), data.indices, data.indptr), + shape=data.shape, + ) + else: + data = (data != 0).astype("uint8") + return self._sum(data=data) def _sum(self, data: ArrayT): if isinstance(data, np.ndarray): From d7abc6b0756178147868e418f20de8a9d759a5cb Mon Sep 17 00:00:00 2001 From: Michaela Mueller Date: Fri, 17 Apr 2026 15:24:28 -0700 Subject: [PATCH 3/8] feat: convert aggregate output back to sparse if input was sparse and aggregate fits sparsity heuristic --- src/scanpy/get/_aggregated.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index 4673fbe001..510d615efa 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -101,6 +101,10 @@ def _sum(self, data: ArrayT): (agg_sum_csr if isinstance(data, CSRBase) else agg_sum_csc)( self.indicator_matrix, data, out ) + if isinstance(data, CSBase): + nnz = np.count_nonzero(out) + if nnz / out.size < 0.5: # heuristic for when to return sparse vs dense + return type(data)(out) # convert to sparse type of input return out def sum(self) -> np.ndarray: From 72694c9590d308399a646bf95d36ff9c48373128 Mon Sep 17 00:00:00 2001 From: Michaela Mueller Date: Fri, 17 Apr 2026 15:40:08 -0700 Subject: [PATCH 4/8] fix: make dask import internal to aggregate_dask function --- src/scanpy/get/_aggregated.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index 510d615efa..aedc116a0b 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -3,7 +3,6 @@ from functools import singledispatch from typing import TYPE_CHECKING, Literal, TypedDict -import dask import numpy as np import pandas as pd from anndata import AnnData @@ -405,6 +404,8 @@ def aggregate_dask( mask: NDArray[np.bool] | None = None, dof: int = 1, ) -> dict[AggType, DaskArray]: + import dask + if not isinstance(data._meta, CSBase | np.ndarray): msg = f"Got {type(data._meta)} meta in DaskArray but only csr_matrix/csr_array and ndarray are supported." raise ValueError(msg) From 7287d7a00aee88d3054ff2e2485b8d9fec10c0ea Mon Sep 17 00:00:00 2001 From: Michaela Mueller Date: Sat, 18 Apr 2026 07:12:01 -0700 Subject: [PATCH 5/8] fix: adjust combine function to work with csc matrix --- src/scanpy/get/_aggregated.py | 98 +++++++++++++++++------------------ 1 file changed, 49 insertions(+), 49 deletions(-) diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index aedc116a0b..c06854793a 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -418,58 +418,58 @@ def aggregate_dask( has_mean, has_var = (v in funcs for v in ["mean", "var"]) funcs_no_var_or_mean = funcs - {"var", "mean"} - if funcs_no_var_or_mean: - funcs_list = list(funcs_no_var_or_mean) - by_codes = np.asarray(by.codes) - mask_arr = np.asarray(mask) if mask is not None else None - - @dask.delayed - def aggregate_chunk(block, block_idx): - subset = slice(block_idx[0], block_idx[1]) - by_subsetted = ( - pd.Categorical.from_codes(by_codes[subset], categories=by.categories) - if chunked_axis == 0 - else by - ) - mask_subsetted = ( - mask_arr[subset] - if (mask_arr is not None and chunked_axis == 0) - else mask_arr - ) + @dask.delayed + def aggregate_chunk(block, block_idx): + subset = slice(block_idx[0], block_idx[1]) + by_subsetted = ( + pd.Categorical.from_codes(by.codes[subset], categories=by.categories) + if chunked_axis == 0 + else by + ) + mask_subsetted = ( + mask[subset] if (mask is not None and chunked_axis == 0) else mask + ) + return { + f: _aggregate(block, by_subsetted, f, mask=mask_subsetted, dof=dof)[f] + for f in funcs_no_var_or_mean + } + + @dask.delayed + def combine_aggs(a, b): + if chunked_axis == 0: + return {f: a[f] + b[f] for f in funcs_no_var_or_mean} + else: return { - f: _aggregate(block, by_subsetted, f, mask=mask_subsetted, dof=dof)[f] - for f in funcs_list + f: sparse.hstack([a[f], b[f]]) + if isinstance(a[f], CSBase) + else np.concatenate([a[f], b[f]], axis=1) + for f in funcs_no_var_or_mean } - @dask.delayed - def add_aggs(a, b): - return {f: a[f] + b[f] for f in funcs_list} - - blocks = data.to_delayed().ravel() - - offset = 0 - delayed_chunks = [] - for i, block in enumerate(blocks): - block_idx = (offset, offset + data.chunks[chunked_axis][i]) - delayed_chunks.append(aggregate_chunk(block, block_idx)) - offset += data.chunks[chunked_axis][i] - - while len(delayed_chunks) > 1: - delayed_chunks = [ - add_aggs(delayed_chunks[i], delayed_chunks[i + 1]) - if i + 1 < len(delayed_chunks) - else delayed_chunks[i] - for i in range(0, len(delayed_chunks), 2) - ] - - aggregated = { - f: dask.array.from_delayed( - dask.delayed(lambda r, f=f: r[f])(delayed_chunks[0]), - shape=(len(by.categories), data.shape[1]), - dtype=np.float64, - ) - for f in funcs_list - } + offset = 0 + delayed_chunks = [] + blocks = data.to_delayed().ravel() + for i, block in enumerate(blocks): + block_idx = (offset, offset + data.chunks[chunked_axis][i]) + delayed_chunks.append(aggregate_chunk(block, block_idx)) + offset += data.chunks[chunked_axis][i] + + while len(delayed_chunks) > 1: + delayed_chunks = [ + combine_aggs(delayed_chunks[i], delayed_chunks[i + 1]) + if i + 1 < len(delayed_chunks) + else delayed_chunks[i] + for i in range(0, len(delayed_chunks), 2) + ] + + aggregated = { + f: dask.array.from_delayed( + dask.delayed(lambda r, f=f: r[f])(delayed_chunks[0]), + shape=(len(by.categories), data.shape[1]), + dtype=np.float64, + ) + for f in funcs_no_var_or_mean + } if has_var: aggredated_mean_var = aggregate_dask_mean_var(data, by, mask=mask, dof=dof) From 64fe6d40f223cdb2a10cd19f07955166f7b55b3c Mon Sep 17 00:00:00 2001 From: Michaela Mueller Date: Sat, 18 Apr 2026 15:36:49 -0700 Subject: [PATCH 6/8] fix: rechunk matrix for bad chunk and update test not fail --- src/scanpy/get/_aggregated.py | 6 +++++- tests/test_aggregated.py | 6 ++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index c06854793a..8deb6ea2e2 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -409,7 +409,11 @@ def aggregate_dask( if not isinstance(data._meta, CSBase | np.ndarray): msg = f"Got {type(data._meta)} meta in DaskArray but only csr_matrix/csr_array and ndarray are supported." raise ValueError(msg) - chunked_axis, _ = (0, 1) if isinstance(data._meta, CSRBase | np.ndarray) else (1, 0) + chunked_axis, unchunked_axis = ( + (0, 1) if isinstance(data._meta, CSRBase | np.ndarray) else (1, 0) + ) + if data.chunksize[unchunked_axis] != data.shape[unchunked_axis]: + data = data.rechunk({unchunked_axis: -1}) funcs = set([func] if isinstance(func, str) else func) if "median" in funcs: diff --git a/tests/test_aggregated.py b/tests/test_aggregated.py index 979cbdc536..4908bfa346 100644 --- a/tests/test_aggregated.py +++ b/tests/test_aggregated.py @@ -201,8 +201,10 @@ def test_aggregate_bad_dask_array( ) -> None: adata = pbmc3k_processed().raw.to_adata() adata.X = func(adata.X) - with pytest.raises(ValueError, match=error_msg): - sc.get.aggregate(adata, ["louvain"], "sum") + # The implementation now rechunks the array to make the feature axis unchunked + # instead of raising; ensure aggregation completes and returns a dask layer. + result = sc.get.aggregate(adata, ["louvain"], "sum") + assert isinstance(result.layers["sum"], DaskArray) @pytest.mark.parametrize("axis_name", ["obs", "var"]) From 05d825a7ca360fef8397cc877ce3e1405d382752 Mon Sep 17 00:00:00 2001 From: Michaela Mueller Date: Sat, 18 Apr 2026 15:58:18 -0700 Subject: [PATCH 7/8] fix: add keep_sparse parameter to control sparsity and remove heuristic to keep chunks consistent --- src/scanpy/get/_aggregated.py | 93 +++++++++++++++++++++++++---------- tests/test_aggregated.py | 33 ++++++++++--- 2 files changed, 94 insertions(+), 32 deletions(-) diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index 8deb6ea2e2..4172e8c1c4 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -71,9 +71,16 @@ def __init__( indicator_matrix: CSRBase | sparse.coo_array data: ArrayT - def count_nonzero(self) -> NDArray[np.integer]: + def count_nonzero(self, *, keep_sparse: bool = True) -> NDArray[np.integer]: """Count the number of observations in each group. + Parameters + ---------- + keep_sparse + If True and the input data is a sparse matrix, return a sparse matrix + of the same type for the aggregated counts. If False, always return + a dense :class:`numpy.ndarray`. + Returns ------- Array of counts. @@ -87,9 +94,9 @@ def count_nonzero(self) -> NDArray[np.integer]: ) else: data = (data != 0).astype("uint8") - return self._sum(data=data) + return self._sum(data=data, keep_sparse=keep_sparse) - def _sum(self, data: ArrayT): + def _sum(self, data: ArrayT, *, keep_sparse: bool): if isinstance(data, np.ndarray): res = self.indicator_matrix @ data if isinstance(res, CSBase): @@ -100,21 +107,26 @@ def _sum(self, data: ArrayT): (agg_sum_csr if isinstance(data, CSRBase) else agg_sum_csc)( self.indicator_matrix, data, out ) - if isinstance(data, CSBase): - nnz = np.count_nonzero(out) - if nnz / out.size < 0.5: # heuristic for when to return sparse vs dense - return type(data)(out) # convert to sparse type of input + if keep_sparse and isinstance(data, CSBase): + return type(data)(out) # convert to sparse type of input return out - def sum(self) -> np.ndarray: + def sum(self, *, keep_sparse: bool = True) -> np.ndarray: """Compute the sum per feature per group of observations. + Parameters + ---------- + keep_sparse + If True and the input data is a sparse matrix, return a sparse matrix + of the same type for the aggregated sums. If False, always return + a dense :class:`numpy.ndarray`. + Returns ------- Array of sum. """ - return self._sum(self.data) + return self._sum(self.data, keep_sparse=keep_sparse) def mean(self) -> Array: """Compute the mean per feature per group of observations. @@ -124,7 +136,7 @@ def mean(self) -> Array: Array of mean. """ - return self.sum() / np.bincount(self.groupby.codes)[:, None] + return self.sum(keep_sparse=False) / np.bincount(self.groupby.codes)[:, None] def mean_var(self, dof: int = 1) -> tuple[np.ndarray, np.ndarray]: """Compute the count, as well as mean and variance per feature, per group of observations. @@ -151,7 +163,10 @@ def mean_var(self, dof: int = 1) -> tuple[np.ndarray, np.ndarray]: if isinstance(self.data, np.ndarray): mean_ = self.mean() # sparse matrices do not support ** for elementwise power. - mean_sq = self._sum(_power(self.data, 2)) / group_counts[:, None] + mean_sq = ( + self._sum(_power(self.data, 2), keep_sparse=False) + / group_counts[:, None] + ) sq_mean = mean_**2 var_ = mean_sq - sq_mean else: @@ -217,6 +232,7 @@ def aggregate( # noqa: PLR0912 axis: Literal["obs", 0, "var", 1] | None = None, mask: NDArray[np.bool] | str | None = None, dof: int = 1, + keep_sparse: bool = True, layer: str | None = None, obsm: str | None = None, varm: str | None = None, @@ -247,6 +263,10 @@ def aggregate( # noqa: PLR0912 Boolean mask (or key to column containing mask) to apply along the axis. dof Degrees of freedom for variance. Defaults to 1. + keep_sparse + If True and the input data is a sparse matrix, preserve sparse outputs + for metrics that support it (for example, ``sum`` and ``count_nonzero``). + If False, force dense :class:`numpy.ndarray` outputs. Defaults to True. layer If not None, key for aggregation data. obsm @@ -336,6 +356,7 @@ def aggregate( # noqa: PLR0912 func=func, mask=mask, dof=dof, + keep_sparse=keep_sparse, ) # Define new var dataframe @@ -366,6 +387,7 @@ def _aggregate( *, mask: NDArray[np.bool] | None = None, dof: int = 1, + keep_sparse: bool = True, ) -> dict[AggType, np.ndarray | DaskArray]: msg = f"Data type {type(data)} not supported for aggregation" raise NotImplementedError(msg) @@ -382,9 +404,14 @@ def aggregate_dask_mean_var( *, mask: NDArray[np.bool] | None = None, dof: int = 1, + keep_sparse: bool = False, ) -> MeanVarDict: - mean = aggregate_dask(data, by, "mean", mask=mask, dof=dof)["mean"] - sq_mean = aggregate_dask(fau_power(data, 2), by, "mean", mask=mask, dof=dof)["mean"] + mean = aggregate_dask(data, by, "mean", mask=mask, dof=dof, keep_sparse=False)[ + "mean" + ] + sq_mean = aggregate_dask( + fau_power(data, 2), by, "mean", mask=mask, dof=dof, keep_sparse=False + )["mean"] # TODO: If we don't compute here, the results are not deterministic under the process cluster for sparse. if isinstance(data._meta, CSRBase): sq_mean = sq_mean.compute() @@ -403,6 +430,7 @@ def aggregate_dask( *, mask: NDArray[np.bool] | None = None, dof: int = 1, + keep_sparse: bool = True, ) -> dict[AggType, DaskArray]: import dask @@ -434,7 +462,14 @@ def aggregate_chunk(block, block_idx): mask[subset] if (mask is not None and chunked_axis == 0) else mask ) return { - f: _aggregate(block, by_subsetted, f, mask=mask_subsetted, dof=dof)[f] + f: _aggregate( + block, + by_subsetted, + f, + mask=mask_subsetted, + dof=dof, + keep_sparse=keep_sparse, + )[f] for f in funcs_no_var_or_mean } @@ -476,7 +511,10 @@ def combine_aggs(a, b): } if has_var: - aggredated_mean_var = aggregate_dask_mean_var(data, by, mask=mask, dof=dof) + # mean/var must be dense regardless of `keep_sparse` + aggredated_mean_var = aggregate_dask_mean_var( + data, by, mask=mask, dof=dof, keep_sparse=False + ) aggregated["var"] = aggredated_mean_var["var"] if has_mean: aggregated["mean"] = aggredated_mean_var["mean"] @@ -484,16 +522,23 @@ def combine_aggs(a, b): # i.e., we can't just call map blocks over the mean function. elif has_mean: group_counts = np.bincount(by.codes) + # compute sum then divide; force sum to be dense here for mean aggregated["mean"] = ( - aggregate_dask(data, by, "sum", mask=mask, dof=dof)["sum"] + aggregate_dask(data, by, "sum", mask=mask, dof=dof, keep_sparse=False)[ + "sum" + ] / group_counts[:, None] ) return aggregated @_aggregate.register(pd.DataFrame) -def aggregate_df(data, by, func, *, mask=None, dof=1) -> dict[AggType, np.ndarray]: - return _aggregate(data.values, by, func, mask=mask, dof=dof) +def aggregate_df( + data, by, func, *, mask=None, dof=1, keep_sparse=False +) -> dict[AggType, np.ndarray]: + return _aggregate( + data.values, by, func, mask=mask, dof=dof, keep_sparse=keep_sparse + ) @_aggregate.register(np.ndarray) @@ -505,6 +550,7 @@ def aggregate_array( *, mask: NDArray[np.bool] | None = None, dof: int = 1, + keep_sparse: bool = True, ) -> dict[AggType, np.ndarray]: groupby = Aggregate(groupby=by, data=data, mask=mask) result = {} @@ -515,22 +561,19 @@ def aggregate_array( raise ValueError(msg) if "sum" in funcs: # sum is calculated separately from the rest - agg = groupby.sum() - result["sum"] = agg + result["sum"] = groupby.sum(keep_sparse=keep_sparse) # here and below for count, if var is present, these can be calculate alongside var if "mean" in funcs and "var" not in funcs: - agg = groupby.mean() - result["mean"] = agg + result["mean"] = groupby.mean() if "count_nonzero" in funcs: - result["count_nonzero"] = groupby.count_nonzero() + result["count_nonzero"] = groupby.count_nonzero(keep_sparse=keep_sparse) if "var" in funcs: mean_, var_ = groupby.mean_var(dof) result["var"] = var_ if "mean" in funcs: result["mean"] = mean_ if "median" in funcs: - agg = groupby.median() - result["median"] = agg + result["median"] = groupby.median() return result diff --git a/tests/test_aggregated.py b/tests/test_aggregated.py index 4908bfa346..8f3b9795e3 100644 --- a/tests/test_aggregated.py +++ b/tests/test_aggregated.py @@ -9,7 +9,7 @@ from scipy import sparse import scanpy as sc -from scanpy._compat import DaskArray +from scanpy._compat import CSBase, DaskArray from scanpy._utils import _resolve_axis, get_literal_vals from scanpy.get._aggregated import AggType from testing.scanpy._helpers import assert_equal @@ -416,8 +416,9 @@ def test_combine_categories( @pytest.mark.parametrize("array_type", VALID_ARRAY_TYPES) +@pytest.mark.parametrize("keep_sparse", [True, False]) def test_aggregate_arraytype( - array_type, metric: AggType, request: pytest.FixtureRequest + array_type, metric: AggType, *, keep_sparse: bool, request: pytest.FixtureRequest ) -> None: adata = pbmc3k_processed().raw.to_adata() adata = adata[ @@ -425,11 +426,29 @@ def test_aggregate_arraytype( ].copy() adata.X = array_type(adata.X) xfail_dask_median(adata, metric, request) - aggregate = sc.get.aggregate(adata, ["louvain"], metric) - assert isinstance( - aggregate.layers[metric], - DaskArray if isinstance(adata.X, DaskArray) else np.ndarray, - ) + aggregate = sc.get.aggregate(adata, ["louvain"], metric, keep_sparse=keep_sparse) + + # Resolve dask if present for type assertions + layer = aggregate.layers[metric] + + if isinstance(adata.X, DaskArray): + assert isinstance(layer, DaskArray) + layer = layer.compute() + adata.X = adata.X.compute() + + # Determine expected sparsity concisely + if metric in {"mean", "var"}: + expected_sparse = False + elif metric in {"count_nonzero", "sum"}: + expected_sparse = isinstance(adata.X, CSBase) and keep_sparse + print(keep_sparse, expected_sparse, isinstance(adata.X, CSBase)) + else: + expected_sparse = False + + if expected_sparse: + assert isinstance(layer, CSBase) + else: + assert isinstance(layer, np.ndarray) def test_aggregate_obsm_varm() -> None: From d9e766c113c62c9dca2e01d0c65af4dba94ec817 Mon Sep 17 00:00:00 2001 From: Michaela Mueller Date: Sat, 18 Apr 2026 16:04:10 -0700 Subject: [PATCH 8/8] add tests for comparing dask output with in-memory output --- tests/test_aggregated.py | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/test_aggregated.py b/tests/test_aggregated.py index 8f3b9795e3..71d004eed2 100644 --- a/tests/test_aggregated.py +++ b/tests/test_aggregated.py @@ -16,6 +16,7 @@ from testing.scanpy._helpers.data import pbmc3k_processed from testing.scanpy._pytest.marks import needs from testing.scanpy._pytest.params import ARRAY_TYPES as ARRAY_TYPES_ALL +from testing.scanpy._pytest.params import ARRAY_TYPES_DASK if TYPE_CHECKING: from collections.abc import Callable @@ -565,3 +566,42 @@ def test_nan() -> None: "s2_control_C", ] assert adata_agg.obs["n_obs_aggregated"].tolist() == [1, 2, 1] + + +def _to_dense_array(x): + """Normalize various array-like objects to a dense numpy array for comparison. + + Handles `DaskArray` by computing, sparse matrices by converting to dense, + and ensures a numpy ndarray is returned. + """ + if isinstance(x, DaskArray): + x = x.compute() + if isinstance(x, CSBase): + x = x.toarray() + return np.asarray(x) + + +@needs.dask +@pytest.mark.parametrize("array_type", ARRAY_TYPES_DASK) +def test_aggregate_dask_vs_regular( + array_type, metric: AggType, request: pytest.FixtureRequest +): + adata = pbmc3k_processed().raw.to_adata() + adata = adata[ + adata.obs["louvain"].isin(adata.obs["louvain"].cat.categories[:5]), :1_000 + ].copy() + + # expected result + expected = sc.get.aggregate(adata, ["louvain"], metric) + + # create dask array + adata.X = array_type(adata.X) + xfail_dask_median(adata, metric, request) + + # dask result + dask_res = sc.get.aggregate(adata, ["louvain"], metric) + + # check results + a = _to_dense_array(expected.layers[metric]) + b = _to_dense_array(dask_res.layers[metric]) + np.testing.assert_allclose(a, b, atol=1e-6)