From a625c5594e13ff368fb857f40cdd14920096c0d9 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 26 Mar 2026 14:44:35 +0100 Subject: [PATCH 01/15] perf: "two-pass" seurat hvg3 via `scanpy.get.aggregate` --- .../preprocessing/_highly_variable_genes.py | 76 +++++++++++++------ 1 file changed, 52 insertions(+), 24 deletions(-) diff --git a/src/scanpy/preprocessing/_highly_variable_genes.py b/src/scanpy/preprocessing/_highly_variable_genes.py index 10dbec1117..7224e87ebd 100644 --- a/src/scanpy/preprocessing/_highly_variable_genes.py +++ b/src/scanpy/preprocessing/_highly_variable_genes.py @@ -20,7 +20,7 @@ raise_if_dask_feature_axis_chunked, sanitize_anndata, ) -from ..get import _get_obs_rep +from ..get import _get_obs_rep, aggregate from ._distributed import materialize_as_ndarray from ._simple import filter_genes @@ -36,7 +36,7 @@ @singledispatch def clip_square_sum( data_batch: np.ndarray, clip_val: np.ndarray -) -> tuple[np.ndarray, np.ndarray]: +) -> tuple[np.ndarray, np.ndarray] | tuple[DaskArray, DaskArray]: """Clip data_batch by clip_val. Parameters @@ -64,24 +64,19 @@ def clip_square_sum( @clip_square_sum.register(DaskArray) -def _(data_batch: DaskArray, clip_val: np.ndarray) -> tuple[np.ndarray, np.ndarray]: +def _(data_batch: DaskArray, clip_val: np.ndarray) -> tuple[DaskArray, DaskArray]: n_blocks = data_batch.blocks.size def sum_and_sum_squares_clipped_from_block(block): return np.vstack(clip_square_sum(block, clip_val))[None, ...] 
- squared_batch_counts_sum, batch_counts_sum = ( - data_batch - .map_blocks( - sum_and_sum_squares_clipped_from_block, - new_axis=(1,), - chunks=((1,) * n_blocks, (2,), (data_batch.shape[1],)), - meta=np.array([]), - dtype=np.float64, - ) - .sum(axis=0) - .compute() - ) + squared_batch_counts_sum, batch_counts_sum = data_batch.map_blocks( + sum_and_sum_squares_clipped_from_block, + new_axis=(1,), + chunks=((1,) * n_blocks, (2,), (data_batch.shape[1],)), + meta=np.array([]), + dtype=np.float64, + ).sum(axis=0) return squared_batch_counts_sum, batch_counts_sum @@ -172,17 +167,43 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915 batch_info = ( pd.Categorical(np.zeros(adata.shape[0], dtype=int)) if batch_key is None - else adata.obs[batch_key].to_numpy() + else adata.obs[batch_key] ) - norm_gene_vars = [] - for b in np.unique(batch_info): - data_batch = data[batch_info == b] - mean, var = stats.mean_var(data_batch, axis=0, correction=1) - # These get computed anyway for loess - if isinstance(mean, DaskArray): - mean, var = mean.compute(), var.compute() + if can_aggregate := (inplace or not adata.is_view): + adata.obs["__hvg_v3_batch_info__"] = batch_info + aggregated_mean_var = aggregate( + adata, by="__hvg_v3_batch_info__", func=["mean", "var"] + ) + mean_global, var_global = ( + aggregated_mean_var.layers[l] for l in ["mean", "var"] + ) + if isinstance(mean_global, DaskArray): + mean_global, var_global = mean_global.compute(), var_global.compute() + aggregated_mean_var.layers["mean"] = mean_global + aggregated_mean_var.layers["var"] = var_global + batch_info = batch_info.to_numpy() + for b in batch_info: + data_batch = data[batch_info == b] + if can_aggregate: + mean, var = ( + aggregated_mean_var[ + aggregated_mean_var.obs["__hvg_v3_batch_info__"] == b + ].layers[l] + for l in ["mean", "var"] + ) + if isinstance(mean, CSBase): + mean = mean.toarray() + mean = mean.ravel() + if isinstance(var, CSBase): + var = var.toarray() + var = var.ravel() + else: 
+ mean, var = stats.mean_var(data_batch, axis=0, correction=1) + # These get computed anyway for loess + if isinstance(mean, DaskArray): + mean, var = mean.compute(), var.compute() estimat_var = np.zeros(data.shape[1], dtype=np.float64) if (not_const := var > 0).any(): y = np.log10(var[not_const]) @@ -204,8 +225,15 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915 + squared_batch_counts_sum - 2 * batch_counts_sum * mean ) - norm_gene_vars.append(norm_gene_var.reshape(1, -1)) + norm_gene_vars.append(norm_gene_var) + if can_aggregate: + del adata.obs["__hvg_v3_batch_info__"] + if any(isinstance(e, DaskArray) for e in norm_gene_vars): + import dask.array as da + + norm_gene_vars = da.compute(*norm_gene_vars) + norm_gene_vars = [ngv.reshape(1, -1) for ngv in norm_gene_vars] norm_gene_vars = np.concatenate(norm_gene_vars, axis=0) # argsort twice gives ranks, small rank means most variable ranked_norm_gene_vars = np.argsort(np.argsort(-norm_gene_vars, axis=1), axis=1) From d839e98dd66733f63c647cea747c5fdbf5ac4e0a Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 26 Mar 2026 14:48:51 +0100 Subject: [PATCH 02/15] chore: hvg v3 benchmark --- benchmarks/asv.conf.json | 2 +- benchmarks/benchmarks/preprocessing_log.py | 37 +++++++++++++++------- 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/benchmarks/asv.conf.json b/benchmarks/asv.conf.json index d19b822178..2921d82f40 100644 --- a/benchmarks/asv.conf.json +++ b/benchmarks/asv.conf.json @@ -83,7 +83,7 @@ // "psutil": [""] "pooch": [""], "scikit-image": [""], // https://github.com/conda-forge/scikit-misc-feedstock/pull/29 - // "scikit-misc": [""], + "scikit-misc": [""], }, // Combinations of libraries/python versions can be excluded/included diff --git a/benchmarks/benchmarks/preprocessing_log.py b/benchmarks/benchmarks/preprocessing_log.py index 9633c8e208..84a356548c 100644 --- a/benchmarks/benchmarks/preprocessing_log.py +++ b/benchmarks/benchmarks/preprocessing_log.py @@ -47,17 +47,6 @@ 
def time_pca(self, *_) -> None: def peakmem_pca(self, *_) -> None: sc.pp.pca(self.adata, svd_solver="arpack") - def time_highly_variable_genes(self, *_) -> None: - # the default flavor runs on log-transformed data - sc.pp.highly_variable_genes( - self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5 - ) - - def peakmem_highly_variable_genes(self, *_) -> None: - sc.pp.highly_variable_genes( - self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5 - ) - # regress_out is very slow for this dataset @skip_when(dataset={"pbmc3k"}) def time_regress_out(self, *_) -> None: @@ -72,3 +61,29 @@ def time_scale(self, *_) -> None: def peakmem_scale(self, *_) -> None: sc.pp.scale(self.adata, max_value=10) + + +class HVGSuite: # noqa: D101 + params = (["seurat_v3", "cell_ranger", "seurat"],) + param_names = ("flavor",) + + def setup_cache(self) -> None: + """Without this caching, asv was running several processes which meant the data was repeatedly downloaded.""" + adata, _ = get_dataset("pbmc3k") + adata.write_h5ad("pbmc3k.h5ad") + + def setup(self, flavor) -> None: + self.adata = ad.read_h5ad("pbmc3k.h5ad") + sc.pp.filter_genes(self.adata, min_cells=3) + self.flavor = flavor + + def time_highly_variable_genes(self, *_) -> None: + # the default flavor runs on log-transformed data + sc.pp.highly_variable_genes( + self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5, flavor=self.flavor + ) + + def peakmem_highly_variable_genes(self, *_) -> None: + sc.pp.highly_variable_genes( + self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5, flavor=self.flavor + ) From 86db4990934de4e1b1751512e05c6faca159b376 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 26 Mar 2026 14:56:52 +0100 Subject: [PATCH 03/15] fix: use counts --- benchmarks/benchmarks/preprocessing_log.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/benchmarks/benchmarks/preprocessing_log.py b/benchmarks/benchmarks/preprocessing_log.py index 84a356548c..cea301dc1d 100644 --- 
a/benchmarks/benchmarks/preprocessing_log.py +++ b/benchmarks/benchmarks/preprocessing_log.py @@ -80,10 +80,20 @@ def setup(self, flavor) -> None: def time_highly_variable_genes(self, *_) -> None: # the default flavor runs on log-transformed data sc.pp.highly_variable_genes( - self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5, flavor=self.flavor + self.adata, + min_mean=0.0125, + max_mean=3, + min_disp=0.5, + flavor=self.flavor, + **({"layer": "counts"} if self.self.flavor == "seurat_v3" else {}), ) def peakmem_highly_variable_genes(self, *_) -> None: sc.pp.highly_variable_genes( - self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5, flavor=self.flavor + self.adata, + min_mean=0.0125, + max_mean=3, + min_disp=0.5, + flavor=self.flavor, + **({"layer": "counts"} if self.self.flavor == "seurat_v3" else {}), ) From d5a6a7833f738ab1500f31a0d601eba51991487f Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 26 Mar 2026 15:07:56 +0100 Subject: [PATCH 04/15] fix: use a batch key --- benchmarks/benchmarks/preprocessing_log.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/benchmarks/benchmarks/preprocessing_log.py b/benchmarks/benchmarks/preprocessing_log.py index cea301dc1d..3ffc5560c1 100644 --- a/benchmarks/benchmarks/preprocessing_log.py +++ b/benchmarks/benchmarks/preprocessing_log.py @@ -69,11 +69,11 @@ class HVGSuite: # noqa: D101 def setup_cache(self) -> None: """Without this caching, asv was running several processes which meant the data was repeatedly downloaded.""" - adata, _ = get_dataset("pbmc3k") - adata.write_h5ad("pbmc3k.h5ad") + adata, _ = get_dataset("lung93k") + adata.write_h5ad("lung93k.h5ad") def setup(self, flavor) -> None: - self.adata = ad.read_h5ad("pbmc3k.h5ad") + self.adata = ad.read_h5ad("lung93k.h5ad") sc.pp.filter_genes(self.adata, min_cells=3) self.flavor = flavor @@ -85,7 +85,8 @@ def time_highly_variable_genes(self, *_) -> None: max_mean=3, min_disp=0.5, flavor=self.flavor, - **({"layer": "counts"} if 
self.self.flavor == "seurat_v3" else {}), + batch_key="PatientNumber", + **({"layer": "counts"} if self.flavor == "seurat_v3" else {}), ) def peakmem_highly_variable_genes(self, *_) -> None: @@ -95,5 +96,6 @@ def peakmem_highly_variable_genes(self, *_) -> None: max_mean=3, min_disp=0.5, flavor=self.flavor, + batch_key="PatientNumber", **({"layer": "counts"} if self.self.flavor == "seurat_v3" else {}), ) From fdc5653bb78715b8ba085a43c51ef102e9e66a4f Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 26 Mar 2026 15:12:32 +0100 Subject: [PATCH 05/15] fix: not again --- benchmarks/benchmarks/preprocessing_log.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/benchmarks/preprocessing_log.py b/benchmarks/benchmarks/preprocessing_log.py index 3ffc5560c1..6f597e7a66 100644 --- a/benchmarks/benchmarks/preprocessing_log.py +++ b/benchmarks/benchmarks/preprocessing_log.py @@ -97,5 +97,5 @@ def peakmem_highly_variable_genes(self, *_) -> None: min_disp=0.5, flavor=self.flavor, batch_key="PatientNumber", - **({"layer": "counts"} if self.self.flavor == "seurat_v3" else {}), + **({"layer": "counts"} if self.flavor == "seurat_v3" else {}), ) From 8f0e426cc95c6168c3c4704c36c6012580ee8dc5 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 8 Apr 2026 15:54:29 +0200 Subject: [PATCH 06/15] fix: `compute` single pass! 
--- src/scanpy/preprocessing/_highly_variable_genes.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/scanpy/preprocessing/_highly_variable_genes.py b/src/scanpy/preprocessing/_highly_variable_genes.py index 7224e87ebd..8a8ef4cee4 100644 --- a/src/scanpy/preprocessing/_highly_variable_genes.py +++ b/src/scanpy/preprocessing/_highly_variable_genes.py @@ -180,7 +180,9 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915 aggregated_mean_var.layers[l] for l in ["mean", "var"] ) if isinstance(mean_global, DaskArray): - mean_global, var_global = mean_global.compute(), var_global.compute() + import dask.array as da + + mean_global, var_global = da.compute(mean_global, var_global) aggregated_mean_var.layers["mean"] = mean_global aggregated_mean_var.layers["var"] = var_global batch_info = batch_info.to_numpy() From 7e0390ee10fe2c38d382b97ad2ff0bf8e6280e6b Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 9 Apr 2026 09:31:03 +0200 Subject: [PATCH 07/15] fix: unique --- src/scanpy/preprocessing/_highly_variable_genes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scanpy/preprocessing/_highly_variable_genes.py b/src/scanpy/preprocessing/_highly_variable_genes.py index 10d59e1af9..9d06ef11dc 100644 --- a/src/scanpy/preprocessing/_highly_variable_genes.py +++ b/src/scanpy/preprocessing/_highly_variable_genes.py @@ -186,7 +186,7 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915 aggregated_mean_var.layers["mean"] = mean_global aggregated_mean_var.layers["var"] = var_global batch_info = batch_info.to_numpy() - for b in batch_info: + for b in np.unique(batch_info): data_batch = data[batch_info == b] if can_aggregate: mean, var = ( From 96c16e91841bb80cda49564c5927a77a8351293c Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 4 May 2026 12:04:57 +0200 Subject: [PATCH 08/15] chore: add new `dask` benchmark --- benchmarks/benchmarks/preprocessing_log.py | 22 +++++++++++++++++----- 1 file changed, 17 
insertions(+), 5 deletions(-) diff --git a/benchmarks/benchmarks/preprocessing_log.py b/benchmarks/benchmarks/preprocessing_log.py index 6f597e7a66..3f10728cee 100644 --- a/benchmarks/benchmarks/preprocessing_log.py +++ b/benchmarks/benchmarks/preprocessing_log.py @@ -9,12 +9,15 @@ from typing import TYPE_CHECKING import anndata as ad +import numpy as np import scanpy as sc from ._utils import get_dataset, param_skipper if TYPE_CHECKING: + from typing import Literal + from ._utils import Dataset, KeyX @@ -64,16 +67,25 @@ def peakmem_scale(self, *_) -> None: class HVGSuite: # noqa: D101 - params = (["seurat_v3", "cell_ranger", "seurat"],) - param_names = ("flavor",) + params = (["seurat_v3", "cell_ranger", "seurat"], [True, False]) + param_names = ("flavor", "use_dask") def setup_cache(self) -> None: """Without this caching, asv was running several processes which meant the data was repeatedly downloaded.""" adata, _ = get_dataset("lung93k") adata.write_h5ad("lung93k.h5ad") - - def setup(self, flavor) -> None: - self.adata = ad.read_h5ad("lung93k.h5ad") + obs = np.arange(adata.shape[0]) + np.random.default_rng().shuffle(obs) + adata[obs].write_h5ad("lung93k_shuffled.h5ad") + + def setup( + self, + flavor: Literal["seurat_v3", "cell_ranger", "seurat"], + use_dask: bool, # noqa: FBT001 + ) -> None: + self.adata = ad.read_h5ad( + "lung93k_shuffled.h5ad" if use_dask else "lung93k.h5ad" + ) sc.pp.filter_genes(self.adata, min_cells=3) self.flavor = flavor From 478af4af39b46fb2e94ee7e9a0fbf10900681dc8 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 4 May 2026 13:08:31 +0200 Subject: [PATCH 09/15] fix: actually use dask lol --- benchmarks/benchmarks/preprocessing_log.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/benchmarks/benchmarks/preprocessing_log.py b/benchmarks/benchmarks/preprocessing_log.py index 3f10728cee..e36aecfbdc 100644 --- a/benchmarks/benchmarks/preprocessing_log.py +++ b/benchmarks/benchmarks/preprocessing_log.py @@ 
-83,9 +83,14 @@ def setup( flavor: Literal["seurat_v3", "cell_ranger", "seurat"], use_dask: bool, # noqa: FBT001 ) -> None: - self.adata = ad.read_h5ad( - "lung93k_shuffled.h5ad" if use_dask else "lung93k.h5ad" - ) + if use_dask: + self.adata = ad.experimental.read_lazy("lung93k_shuffled.h5ad") + self.adata.obs = self.adata.obs.to_memory() + self.adata.var = self.adata.var.to_memory() + else: + self.adata = ad.read_h5ad( + "lung93k_shuffled.h5ad" if use_dask else "lung93k.h5ad" + ) sc.pp.filter_genes(self.adata, min_cells=3) self.flavor = flavor From 54db31b20ad894c7d7cab8d4585b3b3d9cac6bfa Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 4 May 2026 13:46:24 +0200 Subject: [PATCH 10/15] chore: really do dask --- benchmarks/asv.conf.json | 1 + benchmarks/benchmarks/preprocessing_log.py | 21 ++++++++++++++------- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/benchmarks/asv.conf.json b/benchmarks/asv.conf.json index 2921d82f40..655b33d9a8 100644 --- a/benchmarks/asv.conf.json +++ b/benchmarks/asv.conf.json @@ -84,6 +84,7 @@ "pooch": [""], "scikit-image": [""], // https://github.com/conda-forge/scikit-misc-feedstock/pull/29 "scikit-misc": [""], + "dask": [""], }, // Combinations of libraries/python versions can be excluded/included diff --git a/benchmarks/benchmarks/preprocessing_log.py b/benchmarks/benchmarks/preprocessing_log.py index e36aecfbdc..4a13febb12 100644 --- a/benchmarks/benchmarks/preprocessing_log.py +++ b/benchmarks/benchmarks/preprocessing_log.py @@ -10,6 +10,7 @@ import anndata as ad import numpy as np +import zarr import scanpy as sc @@ -73,10 +74,10 @@ class HVGSuite: # noqa: D101 def setup_cache(self) -> None: """Without this caching, asv was running several processes which meant the data was repeatedly downloaded.""" adata, _ = get_dataset("lung93k") - adata.write_h5ad("lung93k.h5ad") + adata.write_zarr("lung93k.zarr") obs = np.arange(adata.shape[0]) np.random.default_rng().shuffle(obs) - 
adata[obs].write_h5ad("lung93k_shuffled.h5ad") + adata[obs].write_zarr("lung93k_shuffled.zarr") def setup( self, @@ -84,12 +85,18 @@ def setup( use_dask: bool, # noqa: FBT001 ) -> None: if use_dask: - self.adata = ad.experimental.read_lazy("lung93k_shuffled.h5ad") - self.adata.obs = self.adata.obs.to_memory() - self.adata.var = self.adata.var.to_memory() + z = zarr.open("lung93k_shuffled.zarr") + self.adata = ad.AnnData( + obs=ad.io.read_elem(z["obs"]), + var=ad.io.read_elem(z["var"]), + layers={ + "counts": ad.experimental.read_elem_lazy(z["layers"]["counts"]) + }, + X=ad.experimental.read_elem_lazy(z["X"]), + ) else: - self.adata = ad.read_h5ad( - "lung93k_shuffled.h5ad" if use_dask else "lung93k.h5ad" + self.adata = ad.read_zarr( + "lung93k_shuffled.zarr" if use_dask else "lung93k.zarr" ) sc.pp.filter_genes(self.adata, min_cells=3) self.flavor = flavor From 4fe84c524380b15670bdcab8c7501d113beba9cc Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 4 May 2026 14:13:20 +0200 Subject: [PATCH 11/15] fix: layers support --- src/scanpy/preprocessing/_highly_variable_genes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scanpy/preprocessing/_highly_variable_genes.py b/src/scanpy/preprocessing/_highly_variable_genes.py index d79e1c2f43..375e2cdacf 100644 --- a/src/scanpy/preprocessing/_highly_variable_genes.py +++ b/src/scanpy/preprocessing/_highly_variable_genes.py @@ -174,7 +174,7 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915 if can_aggregate := (inplace or not adata.is_view): adata.obs["__hvg_v3_batch_info__"] = batch_info aggregated_mean_var = aggregate( - adata, by="__hvg_v3_batch_info__", func=["mean", "var"] + adata, by="__hvg_v3_batch_info__", func=["mean", "var"], layer=layer ) mean_global, var_global = ( aggregated_mean_var.layers[l] for l in ["mean", "var"] From 35590a4a90e7dfab222c937b211238975b1519cc Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 4 May 2026 15:17:13 +0200 Subject: [PATCH 12/15] fix: no view 
check needed --- .../preprocessing/_highly_variable_genes.py | 61 +++++++++---------- 1 file changed, 28 insertions(+), 33 deletions(-) diff --git a/src/scanpy/preprocessing/_highly_variable_genes.py b/src/scanpy/preprocessing/_highly_variable_genes.py index 375e2cdacf..fc156e5ed3 100644 --- a/src/scanpy/preprocessing/_highly_variable_genes.py +++ b/src/scanpy/preprocessing/_highly_variable_genes.py @@ -171,41 +171,38 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915 ) norm_gene_vars = [] - if can_aggregate := (inplace or not adata.is_view): - adata.obs["__hvg_v3_batch_info__"] = batch_info - aggregated_mean_var = aggregate( - adata, by="__hvg_v3_batch_info__", func=["mean", "var"], layer=layer - ) - mean_global, var_global = ( - aggregated_mean_var.layers[l] for l in ["mean", "var"] - ) - if isinstance(mean_global, DaskArray): - import dask.array as da + adata_agg = AnnData( + X=data, + var=pd.DataFrame(index=adata.var_names), + obs=pd.DataFrame( + index=adata.obs_names, data={"__hvg_v3_batch_info__": batch_info} + ), + ) + aggregated_mean_var = aggregate( + adata_agg, by="__hvg_v3_batch_info__", func=["mean", "var"], layer=layer + ) + mean_global, var_global = (aggregated_mean_var.layers[l] for l in ["mean", "var"]) + if isinstance(mean_global, DaskArray): + import dask.array as da - mean_global, var_global = da.compute(mean_global, var_global) - aggregated_mean_var.layers["mean"] = mean_global - aggregated_mean_var.layers["var"] = var_global + mean_global, var_global = da.compute(mean_global, var_global) + aggregated_mean_var.layers["mean"] = mean_global + aggregated_mean_var.layers["var"] = var_global batch_info = batch_info.to_numpy() for b in np.unique(batch_info): data_batch = data[batch_info == b] - if can_aggregate: - mean, var = ( - aggregated_mean_var[ - aggregated_mean_var.obs["__hvg_v3_batch_info__"] == b - ].layers[l] - for l in ["mean", "var"] - ) - if isinstance(mean, CSBase): - mean = mean.toarray() - mean = mean.ravel() - if 
isinstance(var, CSBase): + var = var.toarray() + var = var.ravel() estimat_var = np.zeros(data.shape[1], dtype=np.float64) if (not_const := var > 0).any(): y = np.log10(var[not_const]) @@ -228,8 +225,6 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915 - 2 * batch_counts_sum * mean ) norm_gene_vars.append(norm_gene_var) - if can_aggregate: - del adata.obs["__hvg_v3_batch_info__"] if any(isinstance(e, DaskArray) for e in norm_gene_vars): import dask.array as da From db81d6eb60e7172743728aa613ea6fb1a75687cc Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 4 May 2026 16:23:22 +0200 Subject: [PATCH 13/15] fix: no layers needed --- benchmarks/benchmarks/preprocessing_log.py | 4 +--- src/scanpy/preprocessing/_highly_variable_genes.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/benchmarks/benchmarks/preprocessing_log.py b/benchmarks/benchmarks/preprocessing_log.py index 4a13febb12..6d6f6eef55 100644 --- a/benchmarks/benchmarks/preprocessing_log.py +++ b/benchmarks/benchmarks/preprocessing_log.py @@ -95,9 +95,7 @@ def setup( X=ad.experimental.read_elem_lazy(z["X"]), ) else: - self.adata = ad.read_zarr( - "lung93k_shuffled.zarr" if use_dask else "lung93k.zarr" - ) + self.adata = ad.read_zarr("lung93k.zarr") sc.pp.filter_genes(self.adata, min_cells=3) self.flavor = flavor diff --git a/src/scanpy/preprocessing/_highly_variable_genes.py b/src/scanpy/preprocessing/_highly_variable_genes.py index fc156e5ed3..4f7c38ccf2 100644 ---
a/src/scanpy/preprocessing/_highly_variable_genes.py +++ b/src/scanpy/preprocessing/_highly_variable_genes.py @@ -179,7 +179,7 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915 ), ) aggregated_mean_var = aggregate( - adata_agg, by="__hvg_v3_batch_info__", func=["mean", "var"], layer=layer + adata_agg, by="__hvg_v3_batch_info__", func=["mean", "var"] ) mean_global, var_global = (aggregated_mean_var.layers[l] for l in ["mean", "var"]) if isinstance(mean_global, DaskArray): From b37444e388367ea482f0f30b3d2a24c84f71d3e7 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Tue, 5 May 2026 12:06:32 +0200 Subject: [PATCH 14/15] fix: reduce number of batches --- benchmarks/benchmarks/preprocessing_log.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/benchmarks/benchmarks/preprocessing_log.py b/benchmarks/benchmarks/preprocessing_log.py index 6d6f6eef55..5ea2fa58fd 100644 --- a/benchmarks/benchmarks/preprocessing_log.py +++ b/benchmarks/benchmarks/preprocessing_log.py @@ -94,6 +94,10 @@ def setup( }, X=ad.experimental.read_elem_lazy(z["X"]), ) + # Times out on the benchmark machine with full dataset + self.adata = self.adata[ + self.adata.obs["PatientNumber"].isin(["1", "2"]) + ].copy() else: self.adata = ad.read_zarr("lung93k.zarr") sc.pp.filter_genes(self.adata, min_cells=3) From cf65665a6591de5158c94013474ac051358f065b Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Tue, 5 May 2026 14:53:15 +0200 Subject: [PATCH 15/15] fix: a little bit more --- benchmarks/benchmarks/preprocessing_log.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/benchmarks/preprocessing_log.py b/benchmarks/benchmarks/preprocessing_log.py index 5ea2fa58fd..70e658ffbf 100644 --- a/benchmarks/benchmarks/preprocessing_log.py +++ b/benchmarks/benchmarks/preprocessing_log.py @@ -96,7 +96,7 @@ def setup( ) # Times out on the benchmark machine with full dataset self.adata = self.adata[ - self.adata.obs["PatientNumber"].isin(["1", "2"]) + 
self.adata.obs["PatientNumber"].isin(["1", "2", "3"]) ].copy() else: self.adata = ad.read_zarr("lung93k.zarr")