diff --git a/benchmarks/asv.conf.json b/benchmarks/asv.conf.json
index d19b822178..655b33d9a8 100644
--- a/benchmarks/asv.conf.json
+++ b/benchmarks/asv.conf.json
@@ -83,7 +83,8 @@
         // "psutil": [""]
         "pooch": [""],
         "scikit-image": [""], // https://github.com/conda-forge/scikit-misc-feedstock/pull/29
-        // "scikit-misc": [""],
+        "scikit-misc": [""],
+        "dask": [""],
     },
 
     // Combinations of libraries/python versions can be excluded/included
diff --git a/benchmarks/benchmarks/preprocessing_log.py b/benchmarks/benchmarks/preprocessing_log.py
index 9633c8e208..70e658ffbf 100644
--- a/benchmarks/benchmarks/preprocessing_log.py
+++ b/benchmarks/benchmarks/preprocessing_log.py
@@ -9,12 +9,16 @@
 from typing import TYPE_CHECKING
 
 import anndata as ad
+import numpy as np
+import zarr
 
 import scanpy as sc
 
 from ._utils import get_dataset, param_skipper
 
 if TYPE_CHECKING:
+    from typing import Literal
+
     from ._utils import Dataset, KeyX
 
 
@@ -47,17 +51,6 @@ def time_pca(self, *_) -> None:
     def peakmem_pca(self, *_) -> None:
         sc.pp.pca(self.adata, svd_solver="arpack")
 
-    def time_highly_variable_genes(self, *_) -> None:
-        # the default flavor runs on log-transformed data
-        sc.pp.highly_variable_genes(
-            self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5
-        )
-
-    def peakmem_highly_variable_genes(self, *_) -> None:
-        sc.pp.highly_variable_genes(
-            self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5
-        )
-
     # regress_out is very slow for this dataset
     @skip_when(dataset={"pbmc3k"})
     def time_regress_out(self, *_) -> None:
@@ -72,3 +65,63 @@ def time_scale(self, *_) -> None:
 
     def peakmem_scale(self, *_) -> None:
         sc.pp.scale(self.adata, max_value=10)
+
+
+class HVGSuite:  # noqa: D101
+    params = (["seurat_v3", "cell_ranger", "seurat"], [True, False])
+    param_names = ("flavor", "use_dask")
+
+    def setup_cache(self) -> None:
+        """Write lung93k to zarr once, in original and row-shuffled order; without this caching, asv ran several processes and the data was repeatedly downloaded."""
+        adata, _ = get_dataset("lung93k")
+        adata.write_zarr("lung93k.zarr")
+        obs = np.arange(adata.shape[0])
+        np.random.default_rng().shuffle(obs)  # unseeded: a different row order each time the cache is built
+        adata[obs].write_zarr("lung93k_shuffled.zarr")
+
+    def setup(
+        self,
+        flavor: Literal["seurat_v3", "cell_ranger", "seurat"],
+        use_dask: bool,  # noqa: FBT001
+    ) -> None:
+        if use_dask:
+            z = zarr.open("lung93k_shuffled.zarr")
+            self.adata = ad.AnnData(
+                obs=ad.io.read_elem(z["obs"]),
+                var=ad.io.read_elem(z["var"]),
+                layers={
+                    "counts": ad.experimental.read_elem_lazy(z["layers"]["counts"])
+                },
+                X=ad.experimental.read_elem_lazy(z["X"]),
+            )
+            # Keep only patients "1"-"3": the full dataset times out on the benchmark machine
+            self.adata = self.adata[
+                self.adata.obs["PatientNumber"].isin(["1", "2", "3"])
+            ].copy()
+        else:
+            self.adata = ad.read_zarr("lung93k.zarr")
+        sc.pp.filter_genes(self.adata, min_cells=3)
+        self.flavor = flavor
+
+    def time_highly_variable_genes(self, *_) -> None:
+        # seurat_v3 is pointed at the "counts" layer; the other flavors get no `layer` argument
+        sc.pp.highly_variable_genes(
+            self.adata,
+            min_mean=0.0125,
+            max_mean=3,
+            min_disp=0.5,
+            flavor=self.flavor,
+            batch_key="PatientNumber",
+            **({"layer": "counts"} if self.flavor == "seurat_v3" else {}),
+        )
+
+    def peakmem_highly_variable_genes(self, *_) -> None:
+        sc.pp.highly_variable_genes(
+            self.adata,
+            min_mean=0.0125,
+            max_mean=3,
+            min_disp=0.5,
+            flavor=self.flavor,
+            batch_key="PatientNumber",
+            **({"layer": "counts"} if self.flavor == "seurat_v3" else {}),
+        )
diff --git a/src/scanpy/preprocessing/_highly_variable_genes.py b/src/scanpy/preprocessing/_highly_variable_genes.py
index 21fc5da45a..4f7c38ccf2 100644
--- a/src/scanpy/preprocessing/_highly_variable_genes.py
+++ b/src/scanpy/preprocessing/_highly_variable_genes.py
@@ -20,7 +20,7 @@
     raise_if_dask_feature_axis_chunked,
     sanitize_anndata,
 )
-from ..get import _get_obs_rep
+from ..get import _get_obs_rep, aggregate
 from ._distributed import materialize_as_ndarray
 from ._simple import filter_genes
 
@@ -36,7 +36,7 @@
 @singledispatch
 def clip_square_sum(
     data_batch: np.ndarray, clip_val: np.ndarray
-) -> tuple[np.ndarray, np.ndarray]:
+) -> tuple[np.ndarray, np.ndarray] | tuple[DaskArray, DaskArray]:
     """Clip data_batch by clip_val.
 
     Parameters
@@ -64,24 +64,19 @@ def clip_square_sum(
 
 
 @clip_square_sum.register(DaskArray)
-def _(data_batch: DaskArray, clip_val: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+def _(data_batch: DaskArray, clip_val: np.ndarray) -> tuple[DaskArray, DaskArray]:
     n_blocks = data_batch.blocks.size
 
     def sum_and_sum_squares_clipped_from_block(block):
         return np.vstack(clip_square_sum(block, clip_val))[None, ...]
 
-    squared_batch_counts_sum, batch_counts_sum = (
-        data_batch
-        .map_blocks(
-            sum_and_sum_squares_clipped_from_block,
-            new_axis=(1,),
-            chunks=((1,) * n_blocks, (2,), (data_batch.shape[1],)),
-            meta=np.array([]),
-            dtype=np.float64,
-        )
-        .sum(axis=0)
-        .compute()
-    )
+    squared_batch_counts_sum, batch_counts_sum = data_batch.map_blocks(
+        sum_and_sum_squares_clipped_from_block,
+        new_axis=(1,),
+        chunks=((1,) * n_blocks, (2,), (data_batch.shape[1],)),
+        meta=np.array([]),
+        dtype=np.float64,
+    ).sum(axis=0)  # no .compute() here: both sums are returned lazily
     return squared_batch_counts_sum, batch_counts_sum
 
 
@@ -172,17 +167,42 @@ def _highly_variable_genes_seurat_v3(  # noqa: PLR0912, PLR0915
     batch_info = (
         pd.Categorical(np.zeros(adata.shape[0], dtype=int))
         if batch_key is None
-        else adata.obs[batch_key].to_numpy()
+        else adata.obs[batch_key]
     )
-
     norm_gene_vars = []
+
+    adata_agg = AnnData(
+        X=data,
+        var=pd.DataFrame(index=adata.var_names),
+        obs=pd.DataFrame(
+            index=adata.obs_names, data={"__hvg_v3_batch_info__": batch_info}
+        ),
+    )
+    aggregated_mean_var = aggregate(
+        adata_agg, by="__hvg_v3_batch_info__", func=["mean", "var"]
+    )
+    mean_global, var_global = (aggregated_mean_var.layers[l] for l in ["mean", "var"])
+    if isinstance(mean_global, DaskArray):
+        import dask.array as da
+
+        mean_global, var_global = da.compute(mean_global, var_global)  # one compute call for both layers
+        aggregated_mean_var.layers["mean"] = mean_global
+        aggregated_mean_var.layers["var"] = var_global
+    batch_info = batch_info.to_numpy()
     for b in np.unique(batch_info):
         data_batch = data[batch_info == b]
-
-        mean, var = stats.mean_var(data_batch, axis=0, correction=1)
-        # These get computed anyway for loess
-        if isinstance(mean, DaskArray):
-            mean, var = mean.compute(), var.compute()
+        mean, var = (
+            aggregated_mean_var[
+                aggregated_mean_var.obs["__hvg_v3_batch_info__"] == b
+            ].layers[l]
+            for l in ["mean", "var"]
+        )
+        if isinstance(mean, CSBase):
+            mean = mean.toarray()
+        mean = mean.ravel()
+        if isinstance(var, CSBase):
+            var = var.toarray()
+        var = var.ravel()
         estimat_var = np.zeros(data.shape[1], dtype=np.float64)
         if (not_const := var > 0).any():
             y = np.log10(var[not_const])
@@ -204,8 +224,13 @@ def _highly_variable_genes_seurat_v3(  # noqa: PLR0912, PLR0915
             + squared_batch_counts_sum
             - 2 * batch_counts_sum * mean
         )
-        norm_gene_vars.append(norm_gene_var.reshape(1, -1))
+        norm_gene_vars.append(norm_gene_var)
 
+    if any(isinstance(e, DaskArray) for e in norm_gene_vars):
+        import dask.array as da
+
+        norm_gene_vars = da.compute(*norm_gene_vars)  # one compute call for all batches
+    norm_gene_vars = [ngv.reshape(1, -1) for ngv in norm_gene_vars]
     norm_gene_vars = np.concatenate(norm_gene_vars, axis=0)
     # argsort twice gives ranks, small rank means most variable
     ranked_norm_gene_vars = np.argsort(np.argsort(-norm_gene_vars, axis=1), axis=1)