Skip to content
194 changes: 130 additions & 64 deletions src/scanpy/get/_aggregated.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from functools import partial, singledispatch
from typing import TYPE_CHECKING, Literal, TypedDict, get_args
from functools import singledispatch
from typing import TYPE_CHECKING, Literal, TypedDict

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -71,17 +71,32 @@ def __init__(
indicator_matrix: CSRBase | sparse.coo_array
data: ArrayT

def count_nonzero(self) -> NDArray[np.integer]:
def count_nonzero(self, *, keep_sparse: bool = True) -> NDArray[np.integer]:
"""Count the number of observations in each group.

Parameters
----------
keep_sparse
If True and the input data is a sparse matrix, return a sparse matrix
of the same type for the aggregated counts. If False, always return
a dense :class:`numpy.ndarray`.

Returns
-------
Array of counts.

"""
return self._sum(data=(self.data != 0).astype("uint8"))
data = self.data
if isinstance(data, CSBase):
data = type(data)(
(np.ones(data.nnz, dtype="uint8"), data.indices, data.indptr),
shape=data.shape,
)
else:
data = (data != 0).astype("uint8")
return self._sum(data=data, keep_sparse=keep_sparse)

def _sum(self, data: ArrayT):
def _sum(self, data: ArrayT, *, keep_sparse: bool):
if isinstance(data, np.ndarray):
res = self.indicator_matrix @ data
if isinstance(res, CSBase):
Expand All @@ -92,17 +107,26 @@ def _sum(self, data: ArrayT):
(agg_sum_csr if isinstance(data, CSRBase) else agg_sum_csc)(
self.indicator_matrix, data, out
)
if keep_sparse and isinstance(data, CSBase):
return type(data)(out) # convert to sparse type of input
return out

def sum(self) -> np.ndarray:
def sum(self, *, keep_sparse: bool = True) -> np.ndarray:
"""Compute the sum per feature per group of observations.

Parameters
----------
keep_sparse
If True and the input data is a sparse matrix, return a sparse matrix
of the same type for the aggregated sums. If False, always return
a dense :class:`numpy.ndarray`.

Returns
-------
Array of sum.

"""
return self._sum(self.data)
return self._sum(self.data, keep_sparse=keep_sparse)

def mean(self) -> Array:
"""Compute the mean per feature per group of observations.
Expand All @@ -112,7 +136,7 @@ def mean(self) -> Array:
Array of mean.

"""
return self.sum() / np.bincount(self.groupby.codes)[:, None]
return self.sum(keep_sparse=False) / np.bincount(self.groupby.codes)[:, None]

def mean_var(self, dof: int = 1) -> tuple[np.ndarray, np.ndarray]:
"""Compute the count, as well as mean and variance per feature, per group of observations.
Expand All @@ -139,7 +163,10 @@ def mean_var(self, dof: int = 1) -> tuple[np.ndarray, np.ndarray]:
if isinstance(self.data, np.ndarray):
mean_ = self.mean()
# sparse matrices do not support ** for elementwise power.
mean_sq = self._sum(_power(self.data, 2)) / group_counts[:, None]
mean_sq = (
self._sum(_power(self.data, 2), keep_sparse=False)
/ group_counts[:, None]
)
sq_mean = mean_**2
var_ = mean_sq - sq_mean
else:
Expand Down Expand Up @@ -205,6 +232,7 @@ def aggregate( # noqa: PLR0912
axis: Literal["obs", 0, "var", 1] | None = None,
mask: NDArray[np.bool] | str | None = None,
dof: int = 1,
keep_sparse: bool = True,
layer: str | None = None,
obsm: str | None = None,
varm: str | None = None,
Expand Down Expand Up @@ -235,6 +263,10 @@ def aggregate( # noqa: PLR0912
Boolean mask (or key to column containing mask) to apply along the axis.
dof
Degrees of freedom for variance. Defaults to 1.
keep_sparse
If True and the input data is a sparse matrix, preserve sparse outputs
for metrics that support it (for example, ``sum`` and ``count_nonzero``).
If False, force dense :class:`numpy.ndarray` outputs. Defaults to True.
layer
If not None, key for aggregation data.
obsm
Expand Down Expand Up @@ -324,6 +356,7 @@ def aggregate( # noqa: PLR0912
func=func,
mask=mask,
dof=dof,
keep_sparse=keep_sparse,
)

# Define new var dataframe
Expand Down Expand Up @@ -354,6 +387,7 @@ def _aggregate(
*,
mask: NDArray[np.bool] | None = None,
dof: int = 1,
keep_sparse: bool = True,
) -> dict[AggType, np.ndarray | DaskArray]:
msg = f"Data type {type(data)} not supported for aggregation"
raise NotImplementedError(msg)
Expand All @@ -370,9 +404,14 @@ def aggregate_dask_mean_var(
*,
mask: NDArray[np.bool] | None = None,
dof: int = 1,
keep_sparse: bool = False,
) -> MeanVarDict:
mean = aggregate_dask(data, by, "mean", mask=mask, dof=dof)["mean"]
sq_mean = aggregate_dask(fau_power(data, 2), by, "mean", mask=mask, dof=dof)["mean"]
mean = aggregate_dask(data, by, "mean", mask=mask, dof=dof, keep_sparse=False)[
"mean"
]
sq_mean = aggregate_dask(
fau_power(data, 2), by, "mean", mask=mask, dof=dof, keep_sparse=False
)["mean"]
# TODO: If we don't compute here, the results are not deterministic under the process cluster for sparse.
if isinstance(data._meta, CSRBase):
sq_mean = sq_mean.compute()
Expand All @@ -391,86 +430,115 @@ def aggregate_dask(
*,
mask: NDArray[np.bool] | None = None,
dof: int = 1,
keep_sparse: bool = True,
) -> dict[AggType, DaskArray]:
import dask

if not isinstance(data._meta, CSBase | np.ndarray):
msg = f"Got {type(data._meta)} meta in DaskArray but only csr_matrix/csr_array and ndarray are supported."
raise ValueError(msg)
chunked_axis, unchunked_axis = (
(0, 1) if isinstance(data._meta, CSRBase | np.ndarray) else (1, 0)
)
if data.chunksize[unchunked_axis] != data.shape[unchunked_axis]:
msg = "Feature axis must be unchunked"
raise ValueError(msg)

def aggregate_chunk_sum_or_count_nonzero(
chunk: Array, *, func: Literal["count_nonzero", "sum"], block_info=None
):
# only subset the mask and by if we need to i.e.,
# there is chunking along the same axis as by and mask
if chunked_axis == 0:
# See https://docs.dask.org/en/stable/generated/dask.array.map_blocks.html
# for what is contained in `block_info`.
subset = slice(*block_info[0]["array-location"][0])
by_subsetted = by[subset]
mask_subsetted = mask[subset] if mask is not None else mask
else:
by_subsetted = by
mask_subsetted = mask
res = _aggregate(chunk, by_subsetted, func, mask=mask_subsetted, dof=dof)[func]
return res[None, :] if unchunked_axis == 1 else res
data = data.rechunk({unchunked_axis: -1})

funcs = set([func] if isinstance(func, str) else func)
if "median" in funcs:
msg = "Dask median calculation not supported. If you want a median-of-medians calculation, please open an issue."
raise NotImplementedError(msg)
has_mean, has_var = (v in funcs for v in ["mean", "var"])
funcs_no_var_or_mean = funcs - {"var", "mean"}
# aggregate each row chunk or column chunk individually,
# producing a #chunks × #categories × #features or a #categories × #chunks array,
# then aggregate the per-chunk results.
chunks = (
((1,) * data.blocks.size, (len(by.categories),), data.shape[1])
if unchunked_axis == 1
else (len(by.categories), data.chunks[1])
)

@dask.delayed
def aggregate_chunk(block, block_idx):
    """Aggregate one materialized dask block for every requested function.

    NOTE(review): nested inside ``aggregate_dask`` — ``by``, ``mask``, ``dof``,
    ``chunked_axis``, ``funcs_no_var_or_mean`` and ``keep_sparse`` come from the
    enclosing scope.  ``block_idx`` is a ``(start, stop)`` offset pair along the
    chunked axis.
    """
    # Positional slice of this block along the chunked axis.
    subset = slice(block_idx[0], block_idx[1])
    # When chunking is along the observation axis (0), the grouping codes must
    # be subset to match the rows in this block; rebuild a Categorical so the
    # full category set (and hence output shape) is preserved.
    by_subsetted = (
        pd.Categorical.from_codes(by.codes[subset], categories=by.categories)
        if chunked_axis == 0
        else by
    )
    # The mask is row-aligned, so it is only subset for axis-0 chunking.
    mask_subsetted = (
        mask[subset] if (mask is not None and chunked_axis == 0) else mask
    )
    # Dispatch to the in-memory implementation per function; "var"/"mean" are
    # handled separately by the caller and are excluded from this set.
    return {
        f: _aggregate(
            block,
            by_subsetted,
            f,
            mask=mask_subsetted,
            dof=dof,
            keep_sparse=keep_sparse,
        )[f]
        for f in funcs_no_var_or_mean
    }

@dask.delayed
def combine_aggs(a, b):
    """Merge two per-chunk aggregation dicts into one.

    NOTE(review): nested inside ``aggregate_dask`` — ``chunked_axis`` and
    ``funcs_no_var_or_mean`` come from the enclosing scope.
    """
    if chunked_axis == 0:
        # Row chunks cover disjoint observations of the SAME groups, so the
        # per-group partial results are combined by elementwise addition
        # (valid for sum-like aggregations such as "sum"/"count_nonzero";
        # "median" is rejected earlier in aggregate_dask).
        return {f: a[f] + b[f] for f in funcs_no_var_or_mean}
    else:
        # Column chunks cover disjoint features, so results are concatenated
        # along the feature axis; sparse results use sparse.hstack to stay
        # sparse.
        return {
            f: sparse.hstack([a[f], b[f]])
            if isinstance(a[f], CSBase)
            else np.concatenate([a[f], b[f]], axis=1)
            for f in funcs_no_var_or_mean
        }

# Fan out: one delayed aggregation task per dask block, tracking each block's
# (start, stop) offsets along the chunked axis so group codes / mask can be
# sliced to match inside aggregate_chunk.
offset = 0
delayed_chunks = []
blocks = data.to_delayed().ravel()
for i, block in enumerate(blocks):
    block_idx = (offset, offset + data.chunks[chunked_axis][i])
    delayed_chunks.append(aggregate_chunk(block, block_idx))
    offset += data.chunks[chunked_axis][i]

# Fan in: pairwise (binary-tree) reduction of the per-chunk results via
# combine_aggs; an odd trailing element is carried forward unmerged.  This
# keeps the delayed graph O(log n) deep instead of a linear chain.
while len(delayed_chunks) > 1:
    delayed_chunks = [
        combine_aggs(delayed_chunks[i], delayed_chunks[i + 1])
        if i + 1 < len(delayed_chunks)
        else delayed_chunks[i]
        for i in range(0, len(delayed_chunks), 2)
    ]

aggregated = {
f: data.map_blocks(
partial(aggregate_chunk_sum_or_count_nonzero, func=func),
new_axis=(1,) if unchunked_axis == 1 else None,
chunks=chunks,
meta=np.array(
[],
dtype=np.float64
if func not in get_args(ConstantDtypeAgg)
else data.dtype, # TODO: figure out best dtype for aggs like sum where dtype can change from original
),
f: dask.array.from_delayed(
dask.delayed(lambda r, f=f: r[f])(delayed_chunks[0]),
shape=(len(by.categories), data.shape[1]),
dtype=np.float64,
)
for f in funcs_no_var_or_mean
}
# If we have row chunking, we need to handle the extra axis by summing over all category × feature matrices.
# Otherwise, dask internally concatenates the #categories × #chunks arrays i.e., the column chunks are concatenated together to get a #categories × #features matrix.
if unchunked_axis == 1:
for k, v in aggregated.items():
aggregated[k] = v.sum(axis=chunked_axis)

if has_var:
aggredated_mean_var = aggregate_dask_mean_var(data, by, mask=mask, dof=dof)
# mean/var must be dense regardless of `keep_sparse`
aggredated_mean_var = aggregate_dask_mean_var(
data, by, mask=mask, dof=dof, keep_sparse=False
)
aggregated["var"] = aggredated_mean_var["var"]
if has_mean:
aggregated["mean"] = aggredated_mean_var["mean"]
# division must come after, not before, the summation for numerical precision
# i.e., we can't just call map blocks over the mean function.
elif has_mean:
group_counts = np.bincount(by.codes)
# compute sum then divide; force sum to be dense here for mean
aggregated["mean"] = (
aggregate_dask(data, by, "sum", mask=mask, dof=dof)["sum"]
aggregate_dask(data, by, "sum", mask=mask, dof=dof, keep_sparse=False)[
"sum"
]
/ group_counts[:, None]
)
return aggregated


@_aggregate.register(pd.DataFrame)
def aggregate_df(data, by, func, *, mask=None, dof=1) -> dict[AggType, np.ndarray]:
return _aggregate(data.values, by, func, mask=mask, dof=dof)
def aggregate_df(
    data, by, func, *, mask=None, dof=1, keep_sparse=False
) -> dict[AggType, np.ndarray]:
    """Aggregate a :class:`pandas.DataFrame` by delegating to the ndarray path.

    The frame's ``.values`` (a dense ndarray) is passed to the dispatched
    ``_aggregate`` implementation with all keyword arguments forwarded.

    NOTE(review): ``keep_sparse`` defaults to ``False`` here while other
    registrations in this diff default to ``True`` — presumably harmless since
    ``.values`` is dense and keep_sparse only affects sparse inputs, but
    confirm the inconsistency is intentional.
    """
    return _aggregate(
        data.values, by, func, mask=mask, dof=dof, keep_sparse=keep_sparse
    )


@_aggregate.register(np.ndarray)
Expand All @@ -482,6 +550,7 @@ def aggregate_array(
*,
mask: NDArray[np.bool] | None = None,
dof: int = 1,
keep_sparse: bool = True,
) -> dict[AggType, np.ndarray]:
groupby = Aggregate(groupby=by, data=data, mask=mask)
result = {}
Expand All @@ -492,22 +561,19 @@ def aggregate_array(
raise ValueError(msg)

if "sum" in funcs: # sum is calculated separately from the rest
agg = groupby.sum()
result["sum"] = agg
result["sum"] = groupby.sum(keep_sparse=keep_sparse)
# here and below for count, if var is present, these can be calculate alongside var
if "mean" in funcs and "var" not in funcs:
agg = groupby.mean()
result["mean"] = agg
result["mean"] = groupby.mean()
if "count_nonzero" in funcs:
result["count_nonzero"] = groupby.count_nonzero()
result["count_nonzero"] = groupby.count_nonzero(keep_sparse=keep_sparse)
if "var" in funcs:
mean_, var_ = groupby.mean_var(dof)
result["var"] = var_
if "mean" in funcs:
result["mean"] = mean_
if "median" in funcs:
agg = groupby.median()
result["median"] = agg
result["median"] = groupby.median()
return result


Expand Down
Loading
Loading