diff --git a/benchmarks/asv.conf.json b/benchmarks/asv.conf.json index d19b822178..32b9ca1a8f 100644 --- a/benchmarks/asv.conf.json +++ b/benchmarks/asv.conf.json @@ -83,7 +83,7 @@ // "psutil": [""] "pooch": [""], "scikit-image": [""], // https://github.com/conda-forge/scikit-misc-feedstock/pull/29 - // "scikit-misc": [""], + "scikit-misc": [""], }, // Combinations of libraries/python versions can be excluded/included @@ -104,7 +104,7 @@ // - environment_type // Environment type, as above. // - sys_platform - // Platform, as in sys.platform. Possible values for the common + // Platform, as in sys.platform. Possible values for the commonπ // cases: 'linux2', 'win32', 'cygwin', 'darwin'. // // "exclude": [ diff --git a/benchmarks/benchmarks/preprocessing_log.py b/benchmarks/benchmarks/preprocessing_log.py index 9633c8e208..84a356548c 100644 --- a/benchmarks/benchmarks/preprocessing_log.py +++ b/benchmarks/benchmarks/preprocessing_log.py @@ -47,17 +47,6 @@ def time_pca(self, *_) -> None: def peakmem_pca(self, *_) -> None: sc.pp.pca(self.adata, svd_solver="arpack") - def time_highly_variable_genes(self, *_) -> None: - # the default flavor runs on log-transformed data - sc.pp.highly_variable_genes( - self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5 - ) - - def peakmem_highly_variable_genes(self, *_) -> None: - sc.pp.highly_variable_genes( - self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5 - ) - # regress_out is very slow for this dataset @skip_when(dataset={"pbmc3k"}) def time_regress_out(self, *_) -> None: @@ -72,3 +61,29 @@ def time_scale(self, *_) -> None: def peakmem_scale(self, *_) -> None: sc.pp.scale(self.adata, max_value=10) + + +class HVGSuite: # noqa: D101 + params = (["seurat_v3", "cell_ranger", "seurat"],) + param_names = ("flavor",) + + def setup_cache(self) -> None: + """Without this caching, asv was running several processes which meant the data was repeatedly downloaded.""" + adata, _ = get_dataset("pbmc3k") + adata.write_h5ad("pbmc3k.h5ad") + + def setup(self, flavor) -> None: + self.adata = ad.read_h5ad("pbmc3k.h5ad") + sc.pp.filter_genes(self.adata, min_cells=3) + self.flavor = flavor + + def time_highly_variable_genes(self, *_) -> None: + # the default flavor runs on log-transformed data + sc.pp.highly_variable_genes( + self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5, flavor=self.flavor + ) + + def peakmem_highly_variable_genes(self, *_) -> None: + sc.pp.highly_variable_genes( + self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5, flavor=self.flavor + ) diff --git a/pyproject.toml b/pyproject.toml index 933ebd2538..cee5b01599 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -198,6 +198,8 @@ filterwarnings = [ "ignore:.*'(parseAll)'.*'(parse_all)':DeprecationWarning", # igraph vs leidenalg warning "ignore:The `igraph` implementation of leiden clustering:UserWarning", + "ignore:Detected unsupported threading environment:UserWarning", + "ignore:Cannot cache compiled function", # everybody uses this zarr 3 feature, including us, XArray, lots of data out there … "ignore:Consolidated metadata is currently not part:UserWarning", ] diff --git a/src/scanpy/preprocessing/_highly_variable_genes.py b/src/scanpy/preprocessing/_highly_variable_genes.py index 26d5a7763e..dd62093cfd 100644 --- a/src/scanpy/preprocessing/_highly_variable_genes.py +++ b/src/scanpy/preprocessing/_highly_variable_genes.py @@ -13,7 +13,7 @@ from fast_array_utils import stats from .. import logging as logg -from .._compat import CSBase, CSRBase, DaskArray, old_positionals, warn +from .._compat import CSBase, CSRBase, DaskArray, njit, old_positionals, warn from .._settings import Verbosity, settings from .._utils import ( check_nonnegative_integers, @@ -92,31 +92,50 @@ def _(data_batch: CSBase, clip_val: np.ndarray) -> tuple[np.ndarray, np.ndarray] return _sum_and_sum_squares_clipped( batch_counts.indices, batch_counts.data, + batch_counts.indptr, + n_rows=batch_counts.shape[0], n_cols=batch_counts.shape[1], clip_val=clip_val, - nnz=batch_counts.nnz, ) -# parallel=False needed for accuracy -@numba.njit(cache=True, parallel=False) # noqa: TID251 +@njit def _sum_and_sum_squares_clipped( - indices: NDArray[np.integer], - data: NDArray[np.floating], + indices: np.ndarray, + data: np.ndarray, + indptr: np.ndarray, *, + n_rows: int, n_cols: int, - clip_val: NDArray[np.float64], - nnz: int, -) -> tuple[NDArray[np.float64], NDArray[np.float64]]: - squared_batch_counts_sum = np.zeros(n_cols, dtype=np.float64) - batch_counts_sum = np.zeros(n_cols, dtype=np.float64) - for i in numba.prange(nnz): - idx = indices[i] - element = min(np.float64(data[i]), clip_val[idx]) - squared_batch_counts_sum[idx] += element**2 - batch_counts_sum[idx] += element + clip_val: np.ndarray, +): + n_threads = numba.get_num_threads() - return squared_batch_counts_sum, batch_counts_sum + # Each thread gets its own private buffer to avoid race conditions + sum_local = np.zeros((n_threads, n_cols), dtype=np.float64) + squared_local = np.zeros((n_threads, n_cols), dtype=np.float64) + + # We parallelize over the rows of the sparse matrix + for tid in numba.prange(n_threads): + for r in range(tid, n_rows, n_threads): + for i in range(indptr[r], indptr[r + 1]): + col_idx = indices[i] + val = np.float64(data[i]) + element = min(val, clip_val[col_idx]) + # Use the thread's private buffer slice + sum_local[tid, col_idx] += element + squared_local[tid, col_idx] += element**2 + + # Reduction phase (merging the thread buffers) + final_sum = np.zeros(n_cols, dtype=np.float64) + final_squared = np.zeros(n_cols, dtype=np.float64) + + for t in range(n_threads): + for c in range(n_cols): + final_sum[c] += sum_local[t, c] + final_squared[c] += squared_local[t, c] + + return final_squared, final_sum def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915 @@ -176,13 +195,19 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915 ) norm_gene_vars = [] - for b in np.unique(batch_info): + unique_batches = np.unique(batch_info) + n_batches = len(unique_batches) + + for b in unique_batches: data_batch = data[batch_info == b] - mean, var = stats.mean_var(data_batch, axis=0, correction=1) - # These get computed anyway for loess - if isinstance(mean, DaskArray): - mean, var = mean.compute(), var.compute() + if n_batches > 1: + mean, var = stats.mean_var(data_batch, axis=0, correction=1) + # Compute Dask arrays since loess requires in-memory data + if isinstance(mean, DaskArray): + mean, var = mean.compute(), var.compute() + else: + mean, var = df["means"].to_numpy(), df["variances"].to_numpy() not_const = var > 0 estimat_var = np.zeros(data.shape[1], dtype=np.float64)