From de9c481a14dde3208830d00e7a0b67c2adceb9ce Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Mon, 23 Feb 2026 16:26:42 +0100 Subject: [PATCH 01/10] feat: support np.random.Generator --- src/scanpy/_utils/random.py | 29 ++++++---- src/scanpy/datasets/_datasets.py | 9 ++-- src/scanpy/experimental/_docs.py | 4 +- src/scanpy/experimental/pp/_normalization.py | 5 +- src/scanpy/experimental/pp/_recipes.py | 6 ++- src/scanpy/neighbors/__init__.py | 29 +++++----- src/scanpy/preprocessing/_pca/__init__.py | 33 ++++++------ src/scanpy/preprocessing/_pca/_compat.py | 16 +++--- src/scanpy/preprocessing/_recipes.py | 6 +-- .../preprocessing/_scrublet/__init__.py | 53 +++++++++---------- src/scanpy/preprocessing/_scrublet/core.py | 22 ++++---- .../preprocessing/_scrublet/pipeline.py | 22 ++++---- .../preprocessing/_scrublet/sparse_utils.py | 7 +-- src/scanpy/preprocessing/_simple.py | 46 ++++++---------- src/scanpy/preprocessing/_utils.py | 7 +-- src/scanpy/tools/_diffmap.py | 15 +++--- src/scanpy/tools/_dpt.py | 13 +++-- src/scanpy/tools/_draw_graph.py | 36 ++++++------- src/scanpy/tools/_leiden.py | 16 +++--- src/scanpy/tools/_score_genes.py | 19 +++---- src/scanpy/tools/_tsne.py | 9 ++-- src/scanpy/tools/_umap.py | 29 +++++----- src/scanpy/tools/_utils.py | 9 ++-- 23 files changed, 223 insertions(+), 217 deletions(-) diff --git a/src/scanpy/_utils/random.py b/src/scanpy/_utils/random.py index 98d6ce8a1b..aa405c0ce8 100644 --- a/src/scanpy/_utils/random.py +++ b/src/scanpy/_utils/random.py @@ -7,7 +7,6 @@ from typing import TYPE_CHECKING import numpy as np -from sklearn.utils import check_random_state from . import ensure_igraph @@ -23,6 +22,7 @@ "_LegacyRandom", "ith_k_tuple", "legacy_numpy_gen", + "legacy_random_state", "random_k_tuples", "random_str", ] @@ -43,29 +43,29 @@ class _RNGIgraph: See :func:`igraph.set_random_number_generator` for the requirements. 
""" - def __init__(self, random_state: int | np.random.RandomState = 0) -> None: - self._rng = check_random_state(random_state) + def __init__(self, rng: SeedLike | RNGLike | None) -> None: + self._rng = np.random.default_rng(rng) def getrandbits(self, k: int) -> int: - return self._rng.tomaxint() & ((1 << k) - 1) + lims = np.iinfo(np.uint64) + i = int(self._rng.integers(0, lims.max, dtype=np.uint64)) + return i & ((1 << k) - 1) - def randint(self, a: int, b: int) -> int: - return self._rng.randint(a, b + 1) + def randint(self, a: int, b: int) -> np.int64: + return self._rng.integers(a, b + 1) def __getattr__(self, attr: str): return getattr(self._rng, "normal" if attr == "gauss" else attr) @contextmanager -def set_igraph_random_state( - random_state: int | np.random.RandomState, -) -> Generator[None, None, None]: +def set_igraph_rng(rng: SeedLike | RNGLike | None) -> Generator[None]: ensure_igraph() import igraph - rng = _RNGIgraph(random_state) + ig_rng = _RNGIgraph(rng) try: - igraph.set_random_number_generator(rng) + igraph.set_random_number_generator(ig_rng) yield None finally: igraph.set_random_number_generator(random) @@ -114,6 +114,13 @@ def wrapper(self: _FakeRandomGen, *args, **kwargs): _FakeRandomGen._delegate() +def legacy_random_state(rng: SeedLike | RNGLike | None) -> np.random.RandomState: + rng = np.random.default_rng(rng) + if isinstance(rng, _FakeRandomGen): + return rng._state + return np.random.RandomState(rng.bit_generator.spawn(1)[0]) + + ################### # Random k-tuples # ################### diff --git a/src/scanpy/datasets/_datasets.py b/src/scanpy/datasets/_datasets.py index d6221fe8d2..116c8a572d 100644 --- a/src/scanpy/datasets/_datasets.py +++ b/src/scanpy/datasets/_datasets.py @@ -12,13 +12,14 @@ from .._compat import deprecated, old_positionals from .._settings import settings from .._utils._doctests import doctest_internet, doctest_needs +from .._utils.random import legacy_random_state from ..readwrite import read, read_h5ad, 
read_visium from ._utils import check_datasetdir_exists if TYPE_CHECKING: from typing import Literal - from .._utils.random import _LegacyRandom + from .._utils.random import RNGLike, SeedLike type VisiumSampleID = Literal[ "V1_Breast_Cancer_Block_A_Section_1", @@ -63,7 +64,7 @@ def blobs( n_centers: int = 5, cluster_std: float = 1.0, n_observations: int = 640, - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, ) -> AnnData: """Gaussian Blobs. @@ -78,7 +79,7 @@ def blobs( n_observations Number of observations. By default, this is the same observation number as in :func:`scanpy.datasets.krumsiek11`. - random_state + rng Determines random number generation for dataset creation. Returns @@ -101,7 +102,7 @@ def blobs( n_features=n_variables, centers=n_centers, cluster_std=cluster_std, - random_state=random_state, + random_state=legacy_random_state(rng), ) return AnnData(x, obs=dict(blobs=y.astype(str))) diff --git a/src/scanpy/experimental/_docs.py b/src/scanpy/experimental/_docs.py index c6f1bf2f8b..a449c317f1 100644 --- a/src/scanpy/experimental/_docs.py +++ b/src/scanpy/experimental/_docs.py @@ -60,8 +60,8 @@ doc_pca_chunk = """\ n_comps Number of principal components to compute in the PCA step. -random_state - Random seed for setting the initial states for the optimization in the PCA step. +rng + Random number generator for setting the initial states for the optimization in the PCA step. kwargs_pca Dictionary of further keyword arguments passed on to `scanpy.pp.pca()`. 
""" diff --git a/src/scanpy/experimental/pp/_normalization.py b/src/scanpy/experimental/pp/_normalization.py index cb34b9902b..eddb899bb5 100644 --- a/src/scanpy/experimental/pp/_normalization.py +++ b/src/scanpy/experimental/pp/_normalization.py @@ -27,6 +27,7 @@ from typing import Any from ..._utils import Empty + from ..._utils.random import RNGLike, SeedLike def _pearson_residuals( @@ -166,7 +167,7 @@ def normalize_pearson_residuals_pca( theta: float = 100, clip: float | None = None, n_comps: int | None = 50, - random_state: float = 0, + rng: SeedLike | RNGLike | None = None, kwargs_pca: Mapping[str, Any] = MappingProxyType({}), mask_var: np.ndarray | str | None | Empty = _empty, use_highly_variable: bool | None = None, @@ -233,7 +234,7 @@ def normalize_pearson_residuals_pca( normalize_pearson_residuals( adata_pca, theta=theta, clip=clip, check_values=check_values ) - pca(adata_pca, n_comps=n_comps, random_state=random_state, **kwargs_pca) + pca(adata_pca, n_comps=n_comps, rng=rng, **kwargs_pca) n_comps = adata_pca.obsm["X_pca"].shape[1] # might be None if inplace: diff --git a/src/scanpy/experimental/pp/_recipes.py b/src/scanpy/experimental/pp/_recipes.py index 27d272fc4d..5b0357dc4d 100644 --- a/src/scanpy/experimental/pp/_recipes.py +++ b/src/scanpy/experimental/pp/_recipes.py @@ -24,6 +24,8 @@ import pandas as pd from anndata import AnnData + from ..._utils.random import RNGLike, SeedLike + @_doc_params( adata=doc_adata, @@ -42,7 +44,7 @@ def recipe_pearson_residuals( # noqa: PLR0913 batch_key: str | None = None, chunksize: int = 1000, n_comps: int | None = 50, - random_state: float | None = 0, + rng: SeedLike | RNGLike | None = None, kwargs_pca: Mapping[str, Any] = MappingProxyType({}), check_values: bool = True, inplace: bool = True, @@ -133,7 +135,7 @@ def recipe_pearson_residuals( # noqa: PLR0913 experimental.pp.normalize_pearson_residuals( adata_pca, theta=theta, clip=clip, check_values=check_values ) - pca(adata_pca, n_comps=n_comps, 
random_state=random_state, **kwargs_pca) + pca(adata_pca, n_comps=n_comps, rng=rng, **kwargs_pca) if inplace: normalization_param = adata_pca.uns["pearson_residuals_normalization"] diff --git a/src/scanpy/neighbors/__init__.py b/src/scanpy/neighbors/__init__.py index 3efb392a1d..4fa2b5d597 100644 --- a/src/scanpy/neighbors/__init__.py +++ b/src/scanpy/neighbors/__init__.py @@ -11,13 +11,13 @@ import numpy as np import scipy from scipy import sparse -from sklearn.utils import check_random_state from .. import _utils from .. import logging as logg from .._compat import CSBase, CSRBase, SpBase, old_positionals, warn from .._settings import settings from .._utils import NeighborsView, _doc_params, get_literal_vals +from .._utils.random import legacy_random_state from . import _connectivity from ._common import ( _get_indices_distances_from_dense_matrix, @@ -36,7 +36,7 @@ from igraph import Graph from numpy.typing import NDArray - from .._utils.random import _LegacyRandom + from .._utils.random import RNGLike, SeedLike, _LegacyRandom from ._types import KnnTransformerLike, _Metric, _MetricFn # TODO: make `type` when https://github.com/sphinx-doc/sphinx/pull/13508 is released @@ -90,7 +90,7 @@ def neighbors( # noqa: PLR0913 transformer: KnnTransformerLike | _KnownTransformer | None = None, metric: _Metric | _MetricFn | None = None, metric_kwds: Mapping[str, Any] = MappingProxyType({}), - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, key_added: str | None = None, copy: bool = False, ) -> AnnData | None: @@ -158,8 +158,8 @@ def neighbors( # noqa: PLR0913 Options for the metric. *ignored if ``transformer`` is an instance.* - random_state - A numpy random seed. + rng + A numpy random number generator. 
*ignored if ``transformer`` is an instance.* key_added @@ -220,14 +220,14 @@ def neighbors( # noqa: PLR0913 transformer=transformer, metric=metric, metric_kwds=metric_kwds, - random_state=random_state, + rng=rng, ) else: params = locals() if ignored := { p.name for p in signature(neighbors).parameters.values() - if p.name in {"use_rep", "knn", "n_pcs", "metric_kwds", "random_state"} + if p.name in {"use_rep", "knn", "n_pcs", "metric_kwds", "rng"} if params[p.name] != p.default }: warn( @@ -262,7 +262,7 @@ def neighbors( # noqa: PLR0913 key_added, n_neighbors=neighbors_.n_neighbors, method=method, - random_state=random_state, + random_state=rng, metric=metric, **({} if not metric_kwds else dict(metric_kwds=metric_kwds)), **({} if use_rep is None else dict(use_rep=use_rep)), @@ -583,7 +583,7 @@ def compute_neighbors( transformer: KnnTransformerLike | _KnownTransformer | None = None, metric: _Metric | _MetricFn = "euclidean", metric_kwds: Mapping[str, Any] = MappingProxyType({}), - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, ) -> None: """Compute distances and connectivities of neighbors. @@ -619,7 +619,7 @@ def compute_neighbors( n_neighbors=n_neighbors, metric=metric, metric_params=metric_kwds, # most use _params, not _kwds - random_state=random_state, + random_state=legacy_random_state(rng), ) method, transformer, shortcut = self._handle_transformer( method, transformer, knn=knn, kwds=transformer_kwds_default @@ -848,7 +848,7 @@ def compute_eigen( n_comps: int = 15, sym: bool | None = None, sort: Literal["decrease", "increase"] = "decrease", - random_state: _LegacyRandom = 0, + rng: np.random.Generator, ): """Compute eigen decomposition of transition matrix. @@ -861,8 +861,8 @@ def compute_eigen( Instead of computing the eigendecomposition of the assymetric transition matrix, computed the eigendecomposition of the symmetric Ktilde matrix. 
- random_state - A numpy random seed + rng + A numpy random number generator Returns ------- @@ -895,8 +895,7 @@ def compute_eigen( matrix = matrix.astype(np.float64) # Setting the random initial vector - random_state = check_random_state(random_state) - v0 = random_state.standard_normal(matrix.shape[0]) + v0 = rng.standard_normal(matrix.shape[0]) evals, evecs = sparse.linalg.eigsh( matrix, k=n_comps, which=which, ncv=ncv, v0=v0 ) diff --git a/src/scanpy/preprocessing/_pca/__init__.py b/src/scanpy/preprocessing/_pca/__init__.py index b370c09d7b..87ad1752ee 100644 --- a/src/scanpy/preprocessing/_pca/__init__.py +++ b/src/scanpy/preprocessing/_pca/__init__.py @@ -5,12 +5,12 @@ import numpy as np from anndata import AnnData from packaging.version import Version -from sklearn.utils import check_random_state from ... import logging as logg from ..._compat import CSBase, DaskArray, pkg_version, warn from ..._settings import settings from ..._utils import _doc_params, _empty, get_literal_vals, is_backed_type +from ..._utils.random import legacy_random_state from ...get import _check_mask, _get_obs_rep from .._docs import doc_mask_var_hvg from ._compat import _pca_compat_sparse @@ -25,7 +25,7 @@ from numpy.typing import DTypeLike, NDArray from ..._utils import Empty - from ..._utils.random import _LegacyRandom + from ..._utils.random import RNGLike, SeedLike type MethodDaskML = type[dmld.PCA | dmld.IncrementalPCA | dmld.TruncatedSVD] @@ -64,7 +64,7 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 svd_solver: SvdSolver | None = None, chunked: bool = False, chunk_size: int | None = None, - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, return_info: bool = False, mask_var: NDArray[np.bool_] | str | None | Empty = _empty, use_highly_variable: bool | None = None, @@ -157,7 +157,7 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 chunk_size Number of observations to include in each chunk. Required if `chunked=True` was passed. 
- random_state + rng Change to use different initial states for the optimization. return_info Only relevant when not passing an :class:`~anndata.AnnData`: @@ -241,18 +241,17 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 msg = f"PCA is not implemented for matrices of type {type(x)} from layers/obsm" raise NotImplementedError(msg) - # check_random_state returns a numpy RandomState when passed an int but # dask needs an int for random state if not isinstance(x, DaskArray): - random_state = check_random_state(random_state) - elif not isinstance(random_state, int): - msg = f"random_state needs to be an int, not a {type(random_state).__name__} when passing a dask array" + rng = np.random.default_rng(rng) + elif not isinstance(rng, int): + msg = f"rng needs to be an int, not a {type(rng).__name__} when passing a dask array" raise TypeError(msg) if chunked: if ( not zero_center - or random_state + or rng is not None or (svd_solver is not None and svd_solver != "arpack") ): logg.debug("Ignoring zero_center, random_state, svd_solver") @@ -287,9 +286,7 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 "Also the lobpcg solver has been observed to be inaccurate. Please use 'arpack' instead." 
) warn(msg, FutureWarning) - x_pca, pca_ = _pca_compat_sparse( - x, n_comps, solver=svd_solver, random_state=random_state - ) + x_pca, pca_ = _pca_compat_sparse(x, n_comps, solver=svd_solver, rng=rng) else: if not isinstance(x, DaskArray): from sklearn.decomposition import PCA @@ -300,13 +297,13 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 pca_ = PCA( n_components=n_comps, svd_solver=svd_solver, - random_state=random_state, + random_state=legacy_random_state(rng), ) elif isinstance(x._meta, CSBase) or svd_solver == "covariance_eigh": from ._dask import PCAEighDask - if random_state != 0: - msg = f"Ignoring {random_state=} when using a sparse dask array" + if rng is not None: + msg = f"Ignoring {rng=} when using a sparse dask array" warn(msg, UserWarning) if svd_solver not in {None, "covariance_eigh"}: msg = f"Ignoring {svd_solver=} when using a sparse dask array" @@ -319,7 +316,7 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 pca_ = PCA( n_components=n_comps, svd_solver=svd_solver, - random_state=random_state, + random_state=legacy_random_state(rng), ) x_pca = pca_.fit_transform(x) else: @@ -345,7 +342,9 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 " the following components often resemble the exact PCA very closely" ) pca_ = TruncatedSVD( - n_components=n_comps, random_state=random_state, algorithm=svd_solver + n_components=n_comps, + random_state=legacy_random_state(rng), + algorithm=svd_solver, ) x_pca = pca_.fit_transform(x) diff --git a/src/scanpy/preprocessing/_pca/_compat.py b/src/scanpy/preprocessing/_pca/_compat.py index b1c64e0735..133be7b0f0 100644 --- a/src/scanpy/preprocessing/_pca/_compat.py +++ b/src/scanpy/preprocessing/_pca/_compat.py @@ -6,10 +6,11 @@ from fast_array_utils.stats import mean_var from packaging.version import Version from scipy.sparse.linalg import LinearOperator, svds -from sklearn.utils import check_array, check_random_state +from sklearn.utils import check_array from sklearn.utils.extmath import svd_flip from ..._compat import 
pkg_version +from ..._utils.random import legacy_random_state if TYPE_CHECKING: from typing import Literal @@ -18,7 +19,7 @@ from sklearn.decomposition import PCA from ..._compat import CSBase - from ..._utils.random import _LegacyRandom + from ..._utils.random import RNGLike, SeedLike def _pca_compat_sparse( @@ -27,12 +28,11 @@ def _pca_compat_sparse( *, solver: Literal["arpack", "lobpcg"], mu: NDArray[np.floating] | None = None, - random_state: _LegacyRandom = None, + rng: SeedLike | RNGLike | None = None, ) -> tuple[NDArray[np.floating], PCA]: """Sparse PCA for scikit-learn <1.4.""" - random_state = check_random_state(random_state) - np.random.set_state(random_state.get_state()) - random_init = np.random.rand(np.min(x.shape)) + rng = np.random.default_rng(rng) + random_init = rng.uniform(size=np.min(x.shape)) x = check_array(x, accept_sparse=["csr", "csc"]) if mu is None: @@ -70,7 +70,9 @@ def rmat_op(v: NDArray[np.floating]): from sklearn.decomposition import PCA - pca = PCA(n_components=n_pcs, svd_solver=solver, random_state=random_state) + pca = PCA( + n_components=n_pcs, svd_solver=solver, random_state=legacy_random_state(rng) + ) pca.explained_variance_ = ev pca.explained_variance_ratio_ = ev_ratio pca.components_ = v diff --git a/src/scanpy/preprocessing/_recipes.py b/src/scanpy/preprocessing/_recipes.py index d223036873..34cf986d57 100644 --- a/src/scanpy/preprocessing/_recipes.py +++ b/src/scanpy/preprocessing/_recipes.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: from anndata import AnnData - from .._utils.random import _LegacyRandom + from .._utils.random import RNGLike, SeedLike @old_positionals( @@ -36,7 +36,7 @@ def recipe_weinreb17( cv_threshold: int = 2, n_pcs: int = 50, svd_solver="randomized", - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, copy: bool = False, ) -> AnnData | None: """Normalize and filter as of :cite:p:`Weinreb2017`. 
@@ -72,7 +72,7 @@ def recipe_weinreb17( zscore_deprecated(adata.X), n_comps=n_pcs, svd_solver=svd_solver, - random_state=random_state, + rng=rng, ) # update adata adata.obsm["X_pca"] = x_pca diff --git a/src/scanpy/preprocessing/_scrublet/__init__.py b/src/scanpy/preprocessing/_scrublet/__init__.py index b8ea1a267c..b5c7d60a55 100644 --- a/src/scanpy/preprocessing/_scrublet/__init__.py +++ b/src/scanpy/preprocessing/_scrublet/__init__.py @@ -16,7 +16,7 @@ from .core import Scrublet if TYPE_CHECKING: - from ..._utils.random import _LegacyRandom + from ..._utils.random import RNGLike, SeedLike from ...neighbors import _Metric, _MetricFn @@ -59,7 +59,7 @@ def scrublet( # noqa: PLR0913 threshold: float | None = None, verbose: bool = True, copy: bool = False, - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, ) -> AnnData | None: """Predict doublets using Scrublet :cite:p:`Wolock2019`. @@ -147,7 +147,7 @@ def scrublet( # noqa: PLR0913 copy If :data:`True`, return a copy of the input ``adata`` with Scrublet results added. Otherwise, Scrublet results are added in place. - random_state + rng Initial state for doublet simulation and nearest neighbors. Returns @@ -178,6 +178,7 @@ def scrublet( # noqa: PLR0913 scores for observed transcriptomes and simulated doublets. 
""" + rng = np.random.default_rng(rng) if threshold is None and not find_spec("skimage"): # pragma: no cover # Scrublet.call_doublets requires `skimage` with `threshold=None` but PCA # is called early, which is wasteful if there is not `skimage` @@ -224,7 +225,7 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None): layer="raw", sim_doublet_ratio=sim_doublet_ratio, synthetic_doublet_umi_subsampling=synthetic_doublet_umi_subsampling, - random_seed=random_state, + rng=rng, ) del ad_obs.layers["raw"] if log_transform: @@ -249,7 +250,7 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None): knn_dist_metric=knn_dist_metric, get_doublet_neighbor_parents=get_doublet_neighbor_parents, threshold=threshold, - random_state=random_state, + rng=rng, verbose=verbose, ) @@ -307,18 +308,18 @@ def _scrublet_call_doublets( # noqa: PLR0913 adata_obs: AnnData, adata_sim: AnnData, *, - n_neighbors: int | None = None, - expected_doublet_rate: float = 0.05, - stdev_doublet_rate: float = 0.02, - mean_center: bool = True, - normalize_variance: bool = True, - n_prin_comps: int = 30, - use_approx_neighbors: bool | None = None, - knn_dist_metric: _Metric | _MetricFn = "euclidean", - get_doublet_neighbor_parents: bool = False, - threshold: float | None = None, - random_state: _LegacyRandom = 0, - verbose: bool = True, + n_neighbors: int | None, + expected_doublet_rate: float, + stdev_doublet_rate: float, + mean_center: bool, + normalize_variance: bool, + n_prin_comps: int, + use_approx_neighbors: bool | None, + knn_dist_metric: _Metric | _MetricFn, + get_doublet_neighbor_parents: bool, + threshold: float | None, + rng: np.random.Generator, + verbose: bool, ) -> AnnData: """Core function for predicting doublets using Scrublet :cite:p:`Wolock2019`. @@ -376,8 +377,8 @@ def _scrublet_call_doublets( # noqa: PLR0913 practice to check the threshold visually using the `doublet_scores_sim_` histogram and/or based on co-localization of predicted doublets in a 2-D embedding. 
- random_state - Initial state for doublet simulation and nearest neighbors. + rng + Random number generator for doublet simulation and nearest neighbors. verbose If :data:`True`, log progress updates. @@ -414,7 +415,7 @@ def _scrublet_call_doublets( # noqa: PLR0913 n_neighbors=n_neighbors, expected_doublet_rate=expected_doublet_rate, stdev_doublet_rate=stdev_doublet_rate, - random_state=random_state, + rng=rng, ) # Ensure normalised matrix sparseness as Scrublet does @@ -440,12 +441,10 @@ def _scrublet_call_doublets( # noqa: PLR0913 if mean_center: logg.info("Embedding transcriptomes using PCA...") - pipeline.pca(scrub, n_prin_comps=n_prin_comps, random_state=scrub._random_state) + pipeline.pca(scrub, n_prin_comps=n_prin_comps, rng=scrub._rng) else: logg.info("Embedding transcriptomes using Truncated SVD...") - pipeline.truncated_svd( - scrub, n_prin_comps=n_prin_comps, random_state=scrub._random_state - ) + pipeline.truncated_svd(scrub, n_prin_comps=n_prin_comps, rng=scrub._rng) # Score the doublets @@ -477,7 +476,7 @@ def _scrublet_call_doublets( # noqa: PLR0913 .get("sim_doublet_ratio", None) ), "n_neighbors": n_neighbors, - "random_state": random_state, + "rng": rng, }, } @@ -511,7 +510,7 @@ def scrublet_simulate_doublets( layer: str | None = None, sim_doublet_ratio: float = 2.0, synthetic_doublet_umi_subsampling: float = 1.0, - random_seed: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, ) -> AnnData: """Simulate doublets by adding the counts of random observed transcriptome pairs. 
@@ -556,7 +555,7 @@ def scrublet_simulate_doublets( """ x = _get_obs_rep(adata, layer=layer) - scrub = Scrublet(x, random_state=random_seed) + scrub = Scrublet(x, rng=rng) scrub.simulate_doublets( sim_doublet_ratio=sim_doublet_ratio, diff --git a/src/scanpy/preprocessing/_scrublet/core.py b/src/scanpy/preprocessing/_scrublet/core.py index a5cb80fd43..511feac334 100644 --- a/src/scanpy/preprocessing/_scrublet/core.py +++ b/src/scanpy/preprocessing/_scrublet/core.py @@ -7,7 +7,6 @@ import pandas as pd from anndata import AnnData, concat from scipy import sparse -from sklearn.utils import check_random_state from ... import logging as logg from ...neighbors import ( @@ -18,11 +17,10 @@ from .sparse_utils import subsample_counts if TYPE_CHECKING: - from numpy.random import RandomState from numpy.typing import NDArray from ..._compat import CSBase, CSCBase - from ..._utils.random import _LegacyRandom + from ..._utils.random import RNGLike, SeedLike from ...neighbors import _Metric, _MetricFn __all__ = ["Scrublet"] @@ -58,8 +56,8 @@ class Scrublet: stdev_doublet_rate Uncertainty in the expected doublet rate. - random_state - Random state for doublet simulation, approximate + rng + Random number generator for doublet simulation, approximate nearest neighbor search, and PCA/TruncatedSVD. 
""" @@ -72,12 +70,12 @@ class Scrublet: n_neighbors: InitVar[int | None] = None expected_doublet_rate: float = 0.1 stdev_doublet_rate: float = 0.02 - random_state: InitVar[_LegacyRandom] = 0 + rng: InitVar[SeedLike | RNGLike | None] = None # private fields _n_neighbors: int = field(init=False, repr=False) - _random_state: RandomState = field(init=False, repr=False) + _rng: np.random.Generator = field(init=False, repr=False) _counts_obs: CSCBase = field(init=False, repr=False) _total_counts_obs: NDArray[np.integer] = field(init=False, repr=False) @@ -169,7 +167,7 @@ def __post_init__( counts_obs: CSBase | NDArray[np.integer], total_counts_obs: NDArray[np.integer] | None, n_neighbors: int | None, - random_state: _LegacyRandom, + rng: SeedLike | RNGLike | None, ) -> None: self._counts_obs = sparse.csc_matrix(counts_obs) # noqa: TID251 self._total_counts_obs = ( @@ -182,7 +180,7 @@ def __post_init__( if n_neighbors is None else n_neighbors ) - self._random_state = check_random_state(random_state) + self._rng = np.random.default_rng(rng) def simulate_doublets( self, @@ -218,7 +216,7 @@ def simulate_doublets( n_obs = self._counts_obs.shape[0] n_sim = int(n_obs * sim_doublet_ratio) - pair_ix = sample_comb((n_obs, n_obs), n_sim, random_state=self._random_state) + pair_ix = sample_comb((n_obs, n_obs), n_sim, rng=self._rng) e1 = cast("CSCBase", self._counts_obs[pair_ix[:, 0], :]) e2 = cast("CSCBase", self._counts_obs[pair_ix[:, 1], :]) @@ -229,7 +227,7 @@ def simulate_doublets( e1 + e2, rate=synthetic_doublet_umi_subsampling, original_totals=tots1 + tots2, - random_seed=self._random_state, + rng=self._rng, ) else: self._counts_sim = e1 + e2 @@ -348,7 +346,7 @@ def _nearest_neighbor_classifier( knn=True, transformer=transformer, method=None, - random_state=self._random_state, + rng=self._rng, ) neighbors, _ = _get_indices_distances_from_sparse_matrix(knn.distances, k_adj) if use_approx_neighbors: diff --git a/src/scanpy/preprocessing/_scrublet/pipeline.py 
b/src/scanpy/preprocessing/_scrublet/pipeline.py index edc3417cd9..53d98ff8ed 100644 --- a/src/scanpy/preprocessing/_scrublet/pipeline.py +++ b/src/scanpy/preprocessing/_scrublet/pipeline.py @@ -6,12 +6,12 @@ from fast_array_utils.stats import mean_var from scipy import sparse +from ..._utils.random import legacy_random_state from .sparse_utils import sparse_multiply, sparse_zscore if TYPE_CHECKING: from typing import Literal - from ..._utils.random import _LegacyRandom from .core import Scrublet @@ -46,10 +46,10 @@ def zscore(self: Scrublet) -> None: def truncated_svd( self: Scrublet, - n_prin_comps: int = 30, + n_prin_comps: int, *, - random_state: _LegacyRandom = 0, - algorithm: Literal["arpack", "randomized"] = "arpack", + rng: np.random.Generator, + algorithm: Literal["arpack", "randomized"], ) -> None: if self._counts_sim_norm is None: msg = "_counts_sim_norm is not set" @@ -57,7 +57,9 @@ def truncated_svd( from sklearn.decomposition import TruncatedSVD svd = TruncatedSVD( - n_components=n_prin_comps, random_state=random_state, algorithm=algorithm + n_components=n_prin_comps, + random_state=legacy_random_state(rng), + algorithm=algorithm, ).fit(self._counts_obs_norm) self.set_manifold( svd.transform(self._counts_obs_norm), svd.transform(self._counts_sim_norm) @@ -66,10 +68,10 @@ def truncated_svd( def pca( self: Scrublet, - n_prin_comps: int = 50, + n_prin_comps: int, *, - random_state: _LegacyRandom = 0, - svd_solver: Literal["auto", "full", "arpack", "randomized"] = "arpack", + rng: np.random.Generator, + svd_solver: Literal["auto", "full", "arpack", "randomized"], ) -> None: if self._counts_sim_norm is None: msg = "_counts_sim_norm is not set" @@ -80,6 +82,8 @@ def pca( x_sim = self._counts_sim_norm.toarray() pca = PCA( - n_components=n_prin_comps, random_state=random_state, svd_solver=svd_solver + n_components=n_prin_comps, + random_state=legacy_random_state(rng), + svd_solver=svd_solver, ).fit(x_obs) self.set_manifold(pca.transform(x_obs), 
pca.transform(x_sim)) diff --git a/src/scanpy/preprocessing/_scrublet/sparse_utils.py b/src/scanpy/preprocessing/_scrublet/sparse_utils.py index 5b7e7aaaaf..611754f91b 100644 --- a/src/scanpy/preprocessing/_scrublet/sparse_utils.py +++ b/src/scanpy/preprocessing/_scrublet/sparse_utils.py @@ -5,13 +5,11 @@ import numpy as np from fast_array_utils.stats import mean_var from scipy import sparse -from sklearn.utils import check_random_state if TYPE_CHECKING: from numpy.typing import NDArray from ..._compat import CSBase - from ..._utils.random import _LegacyRandom def sparse_multiply( @@ -48,11 +46,10 @@ def subsample_counts( *, rate: float, original_totals, - random_seed: _LegacyRandom = 0, + rng: np.random.Generator, ) -> tuple[CSBase, NDArray[np.int64]]: if rate < 1: - random_seed = check_random_state(random_seed) - e.data = random_seed.binomial(np.round(e.data).astype(int), rate) + e.data = rng.binomial(np.round(e.data).astype(int), rate) current_totals = np.asarray(e.sum(1)).squeeze() unsampled_orig_totals = original_totals - current_totals unsampled_downsamp_totals = np.random.binomial( diff --git a/src/scanpy/preprocessing/_simple.py b/src/scanpy/preprocessing/_simple.py index 9eebeb0246..9d6abf1a38 100644 --- a/src/scanpy/preprocessing/_simple.py +++ b/src/scanpy/preprocessing/_simple.py @@ -41,7 +41,7 @@ import pandas as pd from numpy.typing import NDArray - from .._utils.random import RNGLike, SeedLike, _LegacyRandom + from .._utils.random import RNGLike, SeedLike @old_positionals( @@ -853,7 +853,7 @@ def sample( fraction: float | None = None, *, n: int | None = None, - rng: RNGLike | SeedLike | None = 0, + rng: RNGLike | SeedLike | None = None, copy: Literal[False] = False, replace: bool = False, axis: Literal["obs", 0, "var", 1] = "obs", @@ -996,7 +996,7 @@ def downsample_counts( counts_per_cell: int | Collection[int] | None = None, total_counts: int | None = None, *, - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, replace: bool 
= False, copy: bool = False, ) -> AnnData | None: @@ -1020,7 +1020,7 @@ def downsample_counts( total_counts Target total counts. If the count matrix has more than `total_counts` it will be downsampled to have this number. - random_state + rng Random seed for subsampling. replace Whether to sample the counts with replacement. @@ -1037,6 +1037,7 @@ def downsample_counts( """ raise_not_implemented_error_if_backed_type(adata.X, "downsample_counts") # This logic is all dispatch + rng = np.random.default_rng(rng) total_counts_call = total_counts is not None counts_per_cell_call = counts_per_cell is not None if total_counts_call is counts_per_cell_call: @@ -1046,11 +1047,11 @@ def downsample_counts( adata = adata.copy() if total_counts_call: adata.X = _downsample_total_counts( - adata.X, total_counts, random_state=random_state, replace=replace + adata.X, total_counts, rng=rng, replace=replace ) elif counts_per_cell_call: adata.X = _downsample_per_cell( - adata.X, counts_per_cell, random_state=random_state, replace=replace + adata.X, counts_per_cell, rng=rng, replace=replace ) if copy: return adata @@ -1061,7 +1062,7 @@ def _downsample_per_cell( /, counts_per_cell: int, *, - random_state: _LegacyRandom, + rng: np.random.Generator, replace: bool, ) -> CSBase: n_obs = x.shape[0] @@ -1088,11 +1089,7 @@ def _downsample_per_cell( for rowidx in under_target: row = rows[rowidx] _downsample_array( - row, - counts_per_cell[rowidx], - random_state=random_state, - replace=replace, - inplace=True, + row, counts_per_cell[rowidx], rng=rng, replace=replace, inplace=True ) x.eliminate_zeros() if not issubclass(original_type, CSRBase): # Put it back @@ -1103,11 +1100,7 @@ def _downsample_per_cell( for rowidx in under_target: row = x[rowidx, :] _downsample_array( - row, - counts_per_cell[rowidx], - random_state=random_state, - replace=replace, - inplace=True, + row, counts_per_cell[rowidx], rng=rng, replace=replace, inplace=True ) return x @@ -1117,7 +1110,7 @@ def _downsample_total_counts( 
/, total_counts: int, *, - random_state: _LegacyRandom, + rng: np.random.Generator, replace: bool, ) -> CSBase: total_counts = int(total_counts) @@ -1128,21 +1121,13 @@ def _downsample_total_counts( original_type = type(x) if not isinstance(x, CSRBase): x = x.tocsr() - _downsample_array( - x.data, - total_counts, - random_state=random_state, - replace=replace, - inplace=True, - ) + _downsample_array(x.data, total_counts, rng=rng, replace=replace, inplace=True) x.eliminate_zeros() if not issubclass(original_type, CSRBase): x = original_type(x) else: v = x.reshape(np.multiply(*x.shape)) - _downsample_array( - v, total_counts, random_state=random_state, replace=replace, inplace=True - ) + _downsample_array(v, total_counts, rng=rng, replace=replace, inplace=True) return x @@ -1152,7 +1137,7 @@ def _downsample_array( col: np.ndarray, target: int, *, - random_state: _LegacyRandom = 0, + rng: np.random.Generator, replace: bool = True, inplace: bool = False, ): @@ -1162,14 +1147,13 @@ def _downsample_array( * total counts in cell must be less than target """ - np.random.seed(random_state) cumcounts = col.cumsum() if inplace: col[:] = 0 else: col = np.zeros_like(col) total = np.int_(cumcounts[-1]) - sample = np.random.choice(total, target, replace=replace) + sample = rng.choice(total, target, replace=replace) sample.sort() geneptr = 0 for count in sample: diff --git a/src/scanpy/preprocessing/_utils.py b/src/scanpy/preprocessing/_utils.py index 0ba97c7b4e..3da0fed2d5 100644 --- a/src/scanpy/preprocessing/_utils.py +++ b/src/scanpy/preprocessing/_utils.py @@ -5,24 +5,25 @@ import numpy as np from sklearn.random_projection import sample_without_replacement +from .._utils.random import legacy_random_state + if TYPE_CHECKING: from typing import Literal from numpy.typing import NDArray - from .._utils.random import _LegacyRandom - def sample_comb( dims: tuple[int, ...], nsamp: int, *, - random_state: _LegacyRandom = None, + rng: np.random.Generator, method: Literal[ "auto", 
"tracking_selection", "reservoir_sampling", "pool" ] = "auto", ) -> NDArray[np.int64]: """Randomly sample indices from a grid, without repeating the same tuple.""" + random_state = legacy_random_state(rng) idx = sample_without_replacement( np.prod(dims), nsamp, random_state=random_state, method=method ) diff --git a/src/scanpy/tools/_diffmap.py b/src/scanpy/tools/_diffmap.py index 62fa55b92e..eae7050fda 100644 --- a/src/scanpy/tools/_diffmap.py +++ b/src/scanpy/tools/_diffmap.py @@ -2,13 +2,15 @@ from typing import TYPE_CHECKING +import numpy as np + from .._compat import old_positionals from ._dpt import _diffmap if TYPE_CHECKING: from anndata import AnnData - from .._utils.random import _LegacyRandom + from .._utils.random import RNGLike, SeedLike @old_positionals("neighbors_key", "random_state", "copy") @@ -17,7 +19,7 @@ def diffmap( n_comps: int = 15, *, neighbors_key: str | None = None, - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, copy: bool = False, ) -> AnnData | None: """Diffusion Maps :cite:p:`Coifman2005,Haghverdi2015,Wolf2018`. @@ -50,8 +52,8 @@ def diffmap( .obsp[.uns[neighbors_key]['connectivities_key']] and .obsp[.uns[neighbors_key]['distances_key']] for connectivities and distances, respectively. - random_state - A numpy random seed + rng + A numpy random number generator copy Return a copy instead of writing to adata. @@ -75,6 +77,7 @@ def diffmap( e.g. `adata.obsm["X_diffmap"][:,1]` """ + rng = np.random.default_rng(rng) if neighbors_key is None: neighbors_key = "neighbors" @@ -85,7 +88,5 @@ def diffmap( msg = "Provide any value greater than 2 for `n_comps`. 
" raise ValueError(msg) adata = adata.copy() if copy else adata - _diffmap( - adata, n_comps=n_comps, neighbors_key=neighbors_key, random_state=random_state - ) + _diffmap(adata, n_comps=n_comps, neighbors_key=neighbors_key, rng=rng) return adata if copy else None diff --git a/src/scanpy/tools/_dpt.py b/src/scanpy/tools/_dpt.py index 2a52633026..7d8e9dc661 100644 --- a/src/scanpy/tools/_dpt.py +++ b/src/scanpy/tools/_dpt.py @@ -9,6 +9,7 @@ from .. import logging as logg from .._compat import old_positionals +from .._utils.random import legacy_numpy_gen from ..neighbors import Neighbors, OnFlySymMatrix if TYPE_CHECKING: @@ -17,11 +18,17 @@ from anndata import AnnData -def _diffmap(adata, n_comps=15, neighbors_key=None, random_state=0): +def _diffmap( + adata: AnnData, + n_comps: int = 15, + *, + neighbors_key: str | None, + rng: np.random.Generator, +) -> None: start = logg.info(f"computing Diffusion Maps using {n_comps=}(=n_dcs)") dpt = DPT(adata, neighbors_key=neighbors_key) dpt.compute_transitions() - dpt.compute_eigen(n_comps=n_comps, random_state=random_state) + dpt.compute_eigen(n_comps=n_comps, rng=rng) adata.obsm["X_diffmap"] = dpt.eigen_basis adata.uns["diffmap_evals"] = dpt.eigen_values logg.info( @@ -144,7 +151,7 @@ def dpt( "Trying to run `tl.dpt` without prior call of `tl.diffmap`. " "Falling back to `tl.diffmap` with default parameters." ) - _diffmap(adata, neighbors_key=neighbors_key) + _diffmap(adata, neighbors_key=neighbors_key, rng=legacy_numpy_gen(0)) # start with the actual computation dpt = DPT( adata, diff --git a/src/scanpy/tools/_draw_graph.py b/src/scanpy/tools/_draw_graph.py index a76bad1c30..5a6cc6eec3 100644 --- a/src/scanpy/tools/_draw_graph.py +++ b/src/scanpy/tools/_draw_graph.py @@ -1,6 +1,5 @@ from __future__ import annotations -import random from importlib.util import find_spec from typing import TYPE_CHECKING, Literal @@ -10,6 +9,7 @@ from .. 
import logging as logg from .._compat import old_positionals from .._utils import _choose_graph, get_literal_vals +from .._utils.random import set_igraph_rng from ._utils import get_init_pos_from_paga if TYPE_CHECKING: @@ -18,7 +18,7 @@ from anndata import AnnData from .._compat import SpBase - from .._utils.random import _LegacyRandom + from .._utils.random import RNGLike, SeedLike type _Layout = Literal["fr", "drl", "kk", "grid_fr", "lgl", "rt", "rt_circular", "fa"] @@ -41,7 +41,7 @@ def draw_graph( # noqa: PLR0913 *, init_pos: str | bool | None = None, root: int | None = None, - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, n_jobs: int | None = None, adjacency: SpBase | None = None, key_added_ext: str | None = None, @@ -83,7 +83,7 @@ def draw_graph( # noqa: PLR0913 'rt' (Reingold Tilford tree layout). root Root for tree layouts. - random_state + rng For layouts with random initialization like 'fr', change this to use different intial states for the optimization. If `None`, no seed is set. adjacency @@ -123,6 +123,7 @@ def draw_graph( # noqa: PLR0913 """ start = logg.info(f"drawing single-cell graph using layout {layout!r}") + rng = np.random.default_rng(rng) if layout not in (layouts := get_literal_vals(_Layout)): msg = f"Provide a valid layout, one of {layouts}." 
         raise ValueError(msg)
@@ -136,33 +137,31 @@ def draw_graph(  # noqa: PLR0913
         init_coords = get_init_pos_from_paga(
             adata,
             adjacency,
-            random_state=random_state,
+            rng=rng,
             neighbors_key=neighbors_key,
             obsp=obsp,
         )
     else:
-        np.random.seed(random_state)
-        init_coords = np.random.random((adjacency.shape[0], 2))
+        init_coords = rng.random((adjacency.shape[0], 2))
 
     layout = coerce_fa2_layout(layout)
     # actual drawing
     if layout == "fa":
         positions = np.array(fa2_positions(adjacency, init_coords, **kwds))
     else:
-        # igraph doesn't use numpy seed
-        random.seed(random_state)
         g = _utils.get_igraph_from_adjacency(adjacency)
-        if layout in {"fr", "drl", "kk", "grid_fr"}:
-            ig_layout = g.layout(layout, seed=init_coords.tolist(), **kwds)
-        elif "rt" in layout:
-            if root is not None:
-                root = [root]
-            ig_layout = g.layout(layout, root=root, **kwds)
-        else:
-            ig_layout = g.layout(layout, **kwds)
+        with set_igraph_rng(rng):
+            if layout in {"fr", "drl", "kk", "grid_fr"}:
+                ig_layout = g.layout(layout, seed=init_coords.tolist(), **kwds)
+            elif "rt" in layout:
+                if root is not None:
+                    root = [root]
+                ig_layout = g.layout(layout, root=root, **kwds)
+            else:
+                ig_layout = g.layout(layout, **kwds)
         positions = np.array(ig_layout.coords)
     adata.uns["draw_graph"] = {}
-    adata.uns["draw_graph"]["params"] = dict(layout=layout, random_state=random_state)
+    adata.uns["draw_graph"]["params"] = dict(layout=layout, random_state=rng)
     key_added = f"X_draw_graph_{key_added_ext or layout}"
     adata.obsm[key_added] = positions
     logg.info(
diff --git a/src/scanpy/tools/_leiden.py b/src/scanpy/tools/_leiden.py
index 61967f2ce8..58ae9d04f3 100644
--- a/src/scanpy/tools/_leiden.py
+++ b/src/scanpy/tools/_leiden.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from collections.abc import Hashable
 from typing import TYPE_CHECKING, cast
 
 import numpy as np
@@ -9,7 +10,7 @@
 from .. import _utils
 from ..
import logging as logg from .._compat import warn -from .._utils.random import set_igraph_random_state +from .._utils.random import set_igraph_rng from ._utils_clustering import rename_groups, restrict_adjacency if TYPE_CHECKING: @@ -19,7 +20,7 @@ from anndata import AnnData from .._compat import CSBase - from .._utils.random import _LegacyRandom + from .._utils.random import RNGLike, SeedLike try: # sphinx-autodoc-typehints + optional dependency from leidenalg.VertexPartition import MutableVertexPartition @@ -34,7 +35,7 @@ def leiden( # noqa: PLR0913 resolution: float = 1, *, restrict_to: tuple[str, Sequence[str]] | None = None, - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, key_added: str = "leiden", adjacency: CSBase | None = None, directed: bool | None = None, @@ -67,7 +68,7 @@ def leiden( # noqa: PLR0913 Higher values lead to more clusters. Set to `None` if overriding `partition_type` to one that doesn’t accept a `resolution_parameter`. - random_state + rng Change the initialization of the optimization. 
restrict_to Restrict the clustering to the categories within the key for sample @@ -160,7 +161,8 @@ def leiden( # noqa: PLR0913 partition_type = leidenalg.RBConfigurationVertexPartition if use_weights: clustering_args["weights"] = np.array(g.es["weight"]).astype(np.float64) - clustering_args["seed"] = random_state + if isinstance(rng, Hashable): + clustering_args["seed"] = rng part = cast( "MutableVertexPartition", leidenalg.find_partition(g, partition_type, **clustering_args), @@ -172,7 +174,7 @@ def leiden( # noqa: PLR0913 if resolution is not None: clustering_args["resolution"] = resolution clustering_args.setdefault("objective_function", "modularity") - with set_igraph_random_state(random_state): + with set_igraph_rng(rng): part = g.community_leiden(**clustering_args) # store output into adata.obs groups = np.array(part.membership) @@ -195,7 +197,7 @@ def leiden( # noqa: PLR0913 adata.uns[key_added] = {} adata.uns[key_added]["params"] = dict( resolution=resolution, - random_state=random_state, + random_state=rng, n_iterations=n_iterations, ) adata.uns[key_added]["modularity"] = part.modularity diff --git a/src/scanpy/tools/_score_genes.py b/src/scanpy/tools/_score_genes.py index 426497182d..afef115ce0 100644 --- a/src/scanpy/tools/_score_genes.py +++ b/src/scanpy/tools/_score_genes.py @@ -19,7 +19,8 @@ from anndata import AnnData from numpy.typing import DTypeLike, NDArray - from .._utils.random import _LegacyRandom + from .._utils.random import RNGLike, SeedLike + type _StrIdx = pd.Index[str] type _GetSubset = Callable[[_StrIdx], np.ndarray | CSBase] @@ -61,7 +62,7 @@ def score_genes( # noqa: PLR0913 gene_pool: Sequence[str] | pd.Index[str] | None = None, n_bins: int = 25, score_name: str = "score", - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, copy: bool = False, use_raw: bool | None = None, layer: str | None = None, @@ -94,8 +95,8 @@ def score_genes( # noqa: PLR0913 Number of expression level bins for sampling. 
score_name Name of the field to be added in `.obs`. - random_state - The random seed for sampling. + rng + The random number generator for sampling. copy Copy `adata` or modify it inplace. use_raw @@ -119,19 +120,17 @@ def score_genes( # noqa: PLR0913 """ start = logg.info(f"computing score {score_name!r}") + rng = np.random.default_rng(rng) adata = adata.copy() if copy else adata use_raw = check_use_raw(adata, use_raw, layer=layer) if is_backed_type(adata.X) and not use_raw: msg = f"score_genes is not implemented for matrices of type {type(adata.X)}" raise NotImplementedError(msg) - if random_state is not None: - np.random.seed(random_state) - gene_list, gene_pool, get_subset = _check_score_genes_args( adata, gene_list, gene_pool, use_raw=use_raw, layer=layer ) - del use_raw, layer, random_state + del use_raw, layer # Trying here to match the Seurat approach in scoring cells. # Basically we need to compare genes against random genes in a matched @@ -145,6 +144,7 @@ def score_genes( # noqa: PLR0913 ctrl_size=ctrl_size, n_bins=n_bins, get_subset=get_subset, + rng=rng, ): control_genes = control_genes.union(r_genes) @@ -224,6 +224,7 @@ def _score_genes_bins( ctrl_size: int, n_bins: int, get_subset: _GetSubset, + rng: np.random.Generator, ) -> Generator[pd.Index[str], None, None]: # average expression of genes obs_avg = pd.Series(_nan_means(get_subset(gene_pool), axis=0), index=gene_pool) @@ -244,7 +245,7 @@ def _score_genes_bins( ) logg.warning(msg) if ctrl_size < len(r_genes): - r_genes = r_genes.to_series().sample(ctrl_size).index + r_genes = r_genes.to_series().sample(ctrl_size, random_state=rng).index if ctrl_as_ref: # otherwise `r_genes` is already filtered r_genes = r_genes.difference(gene_list) yield r_genes diff --git a/src/scanpy/tools/_tsne.py b/src/scanpy/tools/_tsne.py index de3f3a7300..5aca180291 100644 --- a/src/scanpy/tools/_tsne.py +++ b/src/scanpy/tools/_tsne.py @@ -6,13 +6,14 @@ from .._compat import old_positionals, warn from .._settings import 
settings from .._utils import _doc_params, raise_not_implemented_error_if_backed_type +from .._utils.random import legacy_random_state from ..neighbors._doc import doc_n_pcs, doc_use_rep from ._utils import _choose_representation if TYPE_CHECKING: from anndata import AnnData - from .._utils.random import _LegacyRandom + from .._utils.random import RNGLike, SeedLike @old_positionals( @@ -36,7 +37,7 @@ def tsne( # noqa: PLR0913 metric: str = "euclidean", early_exaggeration: float = 12, learning_rate: float = 1000, - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, use_fast_tsne: bool = False, n_jobs: int | None = None, key_added: str | None = None, @@ -84,7 +85,7 @@ def tsne( # noqa: PLR0913 optimization, the early exaggeration factor or the learning rate might be too high. If the cost function gets stuck in a bad local minimum increasing the learning rate helps sometimes. - random_state + rng Change this to use different intial states for the optimization. If `None`, the initial state is not reproducible. n_jobs @@ -118,7 +119,7 @@ def tsne( # noqa: PLR0913 n_jobs = settings.n_jobs if n_jobs is None else n_jobs params_sklearn = dict( perplexity=perplexity, - random_state=random_state, + random_state=legacy_random_state(rng), verbose=settings.verbosity > 3, early_exaggeration=early_exaggeration, learning_rate=learning_rate, diff --git a/src/scanpy/tools/_umap.py b/src/scanpy/tools/_umap.py index a1848173fa..ba0b908e8d 100644 --- a/src/scanpy/tools/_umap.py +++ b/src/scanpy/tools/_umap.py @@ -4,12 +4,13 @@ from typing import TYPE_CHECKING import numpy as np -from sklearn.utils import check_array, check_random_state +from sklearn.utils import check_array from .. 
import logging as logg from .._compat import old_positionals, warn from .._settings import settings from .._utils import NeighborsView +from .._utils.random import legacy_random_state from ._utils import _choose_representation, get_init_pos_from_paga if TYPE_CHECKING: @@ -17,7 +18,8 @@ from anndata import AnnData - from .._utils.random import _LegacyRandom + from .._utils.random import RNGLike, SeedLike + type _InitPos = Literal["paga", "spectral", "random"] @@ -49,7 +51,7 @@ def umap( # noqa: PLR0913, PLR0915 gamma: float = 1.0, negative_sample_rate: int = 5, init_pos: _InitPos | np.ndarray | None = "spectral", - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, a: float | None = None, b: float | None = None, method: Literal["umap", "rapids"] = "umap", @@ -111,11 +113,10 @@ def umap( # noqa: PLR0913, PLR0915 * 'spectral': use a spectral embedding of the graph. * 'random': assign initial embedding positions at random. * A numpy array of initial embedding positions. - random_state - If `int`, `random_state` is the seed used by the random number generator; - If `RandomState` or `Generator`, `random_state` is the random number generator; - If `None`, the random number generator is the `RandomState` instance used - by `np.random`. + rng + If `int`, `rng` is the seed used by the random number generator; + If `Generator`, `random_state` is the random number generator; + If `None`, the random number generator is not reproducible. a More specific parameters controlling the embedding. If `None` these values are set automatically as determined by `min_dist` and @@ -158,6 +159,7 @@ def umap( # noqa: PLR0913, PLR0915 UMAP parameters. 
""" + rng = np.random.default_rng(rng) adata = adata.copy() if copy else adata key_obsm, key_uns = ("X_umap", "umap") if key_added is None else [key_added] * 2 @@ -191,16 +193,15 @@ def umap( # noqa: PLR0913, PLR0915 init_coords = adata.obsm[init_pos] elif isinstance(init_pos, str) and init_pos == "paga": init_coords = get_init_pos_from_paga( - adata, random_state=random_state, neighbors_key=neighbors_key + adata, rng=rng, neighbors_key=neighbors_key ) else: init_coords = init_pos # Let umap handle it if hasattr(init_coords, "dtype"): init_coords = check_array(init_coords, dtype=np.float32, accept_sparse=False) - if random_state != 0: - adata.uns[key_uns]["params"]["random_state"] = random_state - random_state = check_random_state(random_state) + if rng is not None: + adata.uns[key_uns]["params"]["random_state"] = rng neigh_params = neighbors["params"] x = _choose_representation( @@ -225,7 +226,7 @@ def umap( # noqa: PLR0913, PLR0915 negative_sample_rate=negative_sample_rate, n_epochs=n_epochs, init=init_coords, - random_state=random_state, + random_state=legacy_random_state(rng), metric=neigh_params.get("metric", "euclidean"), metric_kwds=neigh_params.get("metric_kwds", {}), densmap=False, @@ -265,7 +266,7 @@ def umap( # noqa: PLR0913, PLR0915 a=a, b=b, verbose=settings.verbosity > 3, - random_state=random_state, + random_state=legacy_random_state(rng), ) x_umap = umap.fit_transform(x_contiguous) adata.obsm[key_obsm] = x_umap # annotate samples with UMAP coordinates diff --git a/src/scanpy/tools/_utils.py b/src/scanpy/tools/_utils.py index fdcff28720..bdb9ae7f90 100644 --- a/src/scanpy/tools/_utils.py +++ b/src/scanpy/tools/_utils.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: from anndata import AnnData + from numpy.typing import NDArray from .._compat import CSRBase, SpBase @@ -77,12 +78,12 @@ def _get_pca_or_small_x(adata: AnnData, n_pcs: int | None) -> np.ndarray | CSRBa def get_init_pos_from_paga( adata: AnnData, + *, + rng: np.random.Generator, adjacency: SpBase | 
None = None, - random_state=0, neighbors_key: str | None = None, obsp: str | None = None, -): - np.random.seed(random_state) +) -> NDArray[np.float64]: if adjacency is None: adjacency = _choose_graph(adata, obsp, neighbors_key) if "pos" not in adata.uns.get("paga", {}): @@ -99,7 +100,7 @@ def get_init_pos_from_paga( if len(neighbors[1]) > 0: connectivities = connectivities_coarse[i][neighbors] nearest_neighbor = neighbors[1][np.argmax(connectivities)] - noise = np.random.random((len(subset[subset]), 2)) + noise = rng.random((len(subset[subset]), 2)) dist = group_pos - pos[nearest_neighbor] noise = noise * dist init_pos[subset] = group_pos - 0.5 * dist + noise From 8ab6661858c7cab8c03f76909a0a1370a9a3a187 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 24 Feb 2026 13:59:37 +0100 Subject: [PATCH 02/10] add decorator --- src/scanpy/_utils/random.py | 70 ++++++++++++++----- src/scanpy/datasets/_datasets.py | 3 +- src/scanpy/experimental/pp/_normalization.py | 2 + src/scanpy/experimental/pp/_recipes.py | 3 + src/scanpy/neighbors/__init__.py | 5 +- .../preprocessing/_deprecated/sampling.py | 4 +- src/scanpy/preprocessing/_pca/__init__.py | 3 +- src/scanpy/preprocessing/_pca/_compat.py | 3 +- src/scanpy/preprocessing/_recipes.py | 2 + .../preprocessing/_scrublet/__init__.py | 3 + src/scanpy/preprocessing/_simple.py | 2 + src/scanpy/tools/_diffmap.py | 2 + src/scanpy/tools/_dpt.py | 4 +- src/scanpy/tools/_draw_graph.py | 3 +- src/scanpy/tools/_leiden.py | 3 +- src/scanpy/tools/_tsne.py | 3 +- tests/test_utils.py | 4 +- 17 files changed, 88 insertions(+), 31 deletions(-) diff --git a/src/scanpy/_utils/random.py b/src/scanpy/_utils/random.py index aa405c0ce8..49074f22f3 100644 --- a/src/scanpy/_utils/random.py +++ b/src/scanpy/_utils/random.py @@ -12,6 +12,7 @@ if TYPE_CHECKING: from collections.abc import Generator + from typing import Self from numpy.typing import NDArray @@ -20,8 +21,8 @@ "RNGLike", "SeedLike", "_LegacyRandom", + "accepts_legacy_random_state", 
"ith_k_tuple", - "legacy_numpy_gen", "legacy_random_state", "random_k_tuples", "random_str", @@ -76,23 +77,23 @@ def set_igraph_rng(rng: SeedLike | RNGLike | None) -> Generator[None]: ################################### -def legacy_numpy_gen( - random_state: _LegacyRandom | None = None, -) -> np.random.Generator: - """Return a random generator that behaves like the legacy one.""" - if random_state is not None: - if isinstance(random_state, np.random.RandomState): - np.random.set_state(random_state.get_state(legacy=False)) - return _FakeRandomGen(random_state) - np.random.seed(random_state) - return _FakeRandomGen(np.random.RandomState(np.random.get_bit_generator())) - - class _FakeRandomGen(np.random.Generator): + _arg: _LegacyRandom _state: np.random.RandomState - def __init__(self, random_state: np.random.RandomState) -> None: - self._state = random_state + def __init__(self, seed_or_state: _LegacyRandom) -> None: + self._arg = seed_or_state + self._state = np.random.RandomState(seed_or_state) + + @classmethod + def wrap_global(cls, random_state: _LegacyRandom | None = None) -> Self: + """Create a generator that wraps the global `RandomState` backing the legacy `np.random` functions.""" + if random_state is not None: + if isinstance(random_state, np.random.RandomState): + np.random.set_state(random_state.get_state(legacy=False)) + return _FakeRandomGen(random_state) + np.random.seed(random_state) + return _FakeRandomGen(np.random.RandomState(np.random.get_bit_generator())) @classmethod def _delegate(cls) -> None: @@ -114,13 +115,46 @@ def wrapper(self: _FakeRandomGen, *args, **kwargs): _FakeRandomGen._delegate() -def legacy_random_state(rng: SeedLike | RNGLike | None) -> np.random.RandomState: - rng = np.random.default_rng(rng) +def legacy_random_state(rng: SeedLike | RNGLike | None) -> _LegacyRandom: + """Convert a np.random.Generator into a legacy `random_state` argument. + + If `rng` is already a `_FakeRandomGen`, return its original `_arg` attribute. 
+ """ if isinstance(rng, _FakeRandomGen): - return rng._state + return rng._arg + rng = np.random.default_rng(rng) return np.random.RandomState(rng.bit_generator.spawn(1)[0]) +def accepts_legacy_random_state[**P, R]( + random_state_default: _LegacyRandom, +) -> callable[[callable[P, R]], callable[P, R]]: + """Make a function accept `random_state: _LegacyRandom` and pass it as `rng`. + + If the decorated function is called with a `random_state` argument, + it’ll be wrapped in a :class:`_FakeRandomGen`. + Passing both ``rng`` and ``random_state`` at the same time is an error. + If neither is given, ``random_state_default`` is used. + """ + + def decorator(func: callable[P, R]) -> callable[P, R]: + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + match "random_state" in kwargs, "rng" in kwargs: + case True, True: + msg = "Specify at most one of `rng` and `random_state`." + raise TypeError(msg) + case True, False: + kwargs["rng"] = _FakeRandomGen(kwargs.pop("random_state")) + case False, False: + kwargs["rng"] = _FakeRandomGen(random_state_default) + return func(*args, **kwargs) + + return wrapper + + return decorator + + ################### # Random k-tuples # ################### diff --git a/src/scanpy/datasets/_datasets.py b/src/scanpy/datasets/_datasets.py index 116c8a572d..d4b56f055d 100644 --- a/src/scanpy/datasets/_datasets.py +++ b/src/scanpy/datasets/_datasets.py @@ -12,7 +12,7 @@ from .._compat import deprecated, old_positionals from .._settings import settings from .._utils._doctests import doctest_internet, doctest_needs -from .._utils.random import legacy_random_state +from .._utils.random import accepts_legacy_random_state, legacy_random_state from ..readwrite import read, read_h5ad, read_visium from ._utils import check_datasetdir_exists @@ -58,6 +58,7 @@ @old_positionals( "n_variables", "n_centers", "cluster_std", "n_observations", "random_state" ) +@accepts_legacy_random_state(0) def blobs( *, n_variables: int = 11, diff --git 
a/src/scanpy/experimental/pp/_normalization.py b/src/scanpy/experimental/pp/_normalization.py index eddb899bb5..9cd2d6d092 100644 --- a/src/scanpy/experimental/pp/_normalization.py +++ b/src/scanpy/experimental/pp/_normalization.py @@ -9,6 +9,7 @@ from ... import logging as logg from ..._compat import CSBase, warn from ..._utils import _doc_params, _empty, check_nonnegative_integers, view_to_actual +from ..._utils.random import accepts_legacy_random_state from ...experimental._docs import ( doc_adata, doc_check_values, @@ -161,6 +162,7 @@ def normalize_pearson_residuals( check_values=doc_check_values, inplace=doc_inplace, ) +@accepts_legacy_random_state(0) def normalize_pearson_residuals_pca( adata: AnnData, *, diff --git a/src/scanpy/experimental/pp/_recipes.py b/src/scanpy/experimental/pp/_recipes.py index 5b0357dc4d..b688b8d4d8 100644 --- a/src/scanpy/experimental/pp/_recipes.py +++ b/src/scanpy/experimental/pp/_recipes.py @@ -17,6 +17,8 @@ ) from scanpy.preprocessing import pca +from ..._utils.random import accepts_legacy_random_state + if TYPE_CHECKING: from collections.abc import Mapping from typing import Any @@ -35,6 +37,7 @@ check_values=doc_check_values, inplace=doc_inplace, ) +@accepts_legacy_random_state(0) def recipe_pearson_residuals( # noqa: PLR0913 adata: AnnData, *, diff --git a/src/scanpy/neighbors/__init__.py b/src/scanpy/neighbors/__init__.py index 4fa2b5d597..705398c053 100644 --- a/src/scanpy/neighbors/__init__.py +++ b/src/scanpy/neighbors/__init__.py @@ -17,7 +17,7 @@ from .._compat import CSBase, CSRBase, SpBase, old_positionals, warn from .._settings import settings from .._utils import NeighborsView, _doc_params, get_literal_vals -from .._utils.random import legacy_random_state +from .._utils.random import accepts_legacy_random_state, legacy_random_state from . 
import _connectivity from ._common import ( _get_indices_distances_from_dense_matrix, @@ -78,6 +78,7 @@ class NeighborsParams(TypedDict): # noqa: D101 @_doc_params(n_pcs=doc_n_pcs, use_rep=doc_use_rep) +@accepts_legacy_random_state(0) def neighbors( # noqa: PLR0913 adata: AnnData, n_neighbors: int = 15, @@ -572,6 +573,7 @@ def to_igraph(self) -> Graph: return _utils.get_igraph_from_adjacency(self.connectivities) @_doc_params(n_pcs=doc_n_pcs, use_rep=doc_use_rep) + @accepts_legacy_random_state(0) def compute_neighbors( self, n_neighbors: int = 30, @@ -842,6 +844,7 @@ def compute_transitions(self, *, density_normalize: bool = True): self._transitions_sym = self.Z @ conn_norm @ self.Z logg.info(" finished", time=start) + @accepts_legacy_random_state(0) def compute_eigen( self, *, diff --git a/src/scanpy/preprocessing/_deprecated/sampling.py b/src/scanpy/preprocessing/_deprecated/sampling.py index 2280f3c9a0..b2dfeb0e92 100644 --- a/src/scanpy/preprocessing/_deprecated/sampling.py +++ b/src/scanpy/preprocessing/_deprecated/sampling.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING from ..._compat import old_positionals -from ..._utils.random import legacy_numpy_gen +from ..._utils.random import _FakeRandomGen from .._simple import sample if TYPE_CHECKING: @@ -52,7 +52,7 @@ def subsample( returns a subsampled copy of it (`copy == True`). 
""" - rng = legacy_numpy_gen(random_state) + rng = _FakeRandomGen.wrap_global(random_state) return sample( data=data, fraction=fraction, n=n_obs, rng=rng, copy=copy, replace=False, axis=0 ) diff --git a/src/scanpy/preprocessing/_pca/__init__.py b/src/scanpy/preprocessing/_pca/__init__.py index 87ad1752ee..c6d70e7c36 100644 --- a/src/scanpy/preprocessing/_pca/__init__.py +++ b/src/scanpy/preprocessing/_pca/__init__.py @@ -10,7 +10,7 @@ from ..._compat import CSBase, DaskArray, pkg_version, warn from ..._settings import settings from ..._utils import _doc_params, _empty, get_literal_vals, is_backed_type -from ..._utils.random import legacy_random_state +from ..._utils.random import accepts_legacy_random_state, legacy_random_state from ...get import _check_mask, _get_obs_rep from .._docs import doc_mask_var_hvg from ._compat import _pca_compat_sparse @@ -54,6 +54,7 @@ @_doc_params( mask_var_hvg=doc_mask_var_hvg, ) +@accepts_legacy_random_state(0) def pca( # noqa: PLR0912, PLR0913, PLR0915 data: AnnData | np.ndarray | CSBase, n_comps: int | None = None, diff --git a/src/scanpy/preprocessing/_pca/_compat.py b/src/scanpy/preprocessing/_pca/_compat.py index 133be7b0f0..f1f4e33b0d 100644 --- a/src/scanpy/preprocessing/_pca/_compat.py +++ b/src/scanpy/preprocessing/_pca/_compat.py @@ -10,7 +10,7 @@ from sklearn.utils.extmath import svd_flip from ..._compat import pkg_version -from ..._utils.random import legacy_random_state +from ..._utils.random import accepts_legacy_random_state, legacy_random_state if TYPE_CHECKING: from typing import Literal @@ -22,6 +22,7 @@ from ..._utils.random import RNGLike, SeedLike +@accepts_legacy_random_state(None) def _pca_compat_sparse( x: CSBase, n_pcs: int, diff --git a/src/scanpy/preprocessing/_recipes.py b/src/scanpy/preprocessing/_recipes.py index 34cf986d57..504561dcf7 100644 --- a/src/scanpy/preprocessing/_recipes.py +++ b/src/scanpy/preprocessing/_recipes.py @@ -7,6 +7,7 @@ from .. import logging as logg from .. 
import preprocessing as pp from .._compat import CSBase, old_positionals +from .._utils.random import accepts_legacy_random_state from ._deprecated.highly_variable_genes import ( filter_genes_cv_deprecated, filter_genes_dispersion, @@ -28,6 +29,7 @@ "random_state", "copy", ) +@accepts_legacy_random_state(0) def recipe_weinreb17( adata: AnnData, *, diff --git a/src/scanpy/preprocessing/_scrublet/__init__.py b/src/scanpy/preprocessing/_scrublet/__init__.py index b5c7d60a55..7b9f86335b 100644 --- a/src/scanpy/preprocessing/_scrublet/__init__.py +++ b/src/scanpy/preprocessing/_scrublet/__init__.py @@ -11,6 +11,7 @@ from ... import logging as logg from ... import preprocessing as pp from ..._compat import old_positionals +from ..._utils.random import accepts_legacy_random_state from ...get import _get_obs_rep from . import pipeline from .core import Scrublet @@ -39,6 +40,7 @@ "copy", "random_state", ) +@accepts_legacy_random_state() def scrublet( # noqa: PLR0913 adata: AnnData, adata_sim: AnnData | None = None, @@ -304,6 +306,7 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None): return adata if copy else None +@accepts_legacy_random_state() def _scrublet_call_doublets( # noqa: PLR0913 adata_obs: AnnData, adata_sim: AnnData, diff --git a/src/scanpy/preprocessing/_simple.py b/src/scanpy/preprocessing/_simple.py index 9d6abf1a38..89816d4152 100644 --- a/src/scanpy/preprocessing/_simple.py +++ b/src/scanpy/preprocessing/_simple.py @@ -30,6 +30,7 @@ sanitize_anndata, view_to_actual, ) +from .._utils.random import accepts_legacy_random_state from ..get import _check_mask, _get_obs_rep, _set_obs_rep from ._distributed import materialize_as_ndarray @@ -991,6 +992,7 @@ def sample( # noqa: PLR0912 @renamed_arg("target_counts", "counts_per_cell") +@accepts_legacy_random_state(0) def downsample_counts( adata: AnnData, counts_per_cell: int | Collection[int] | None = None, diff --git a/src/scanpy/tools/_diffmap.py b/src/scanpy/tools/_diffmap.py index 
eae7050fda..bf533938dd 100644 --- a/src/scanpy/tools/_diffmap.py +++ b/src/scanpy/tools/_diffmap.py @@ -5,6 +5,7 @@ import numpy as np from .._compat import old_positionals +from .._utils.random import accepts_legacy_random_state from ._dpt import _diffmap if TYPE_CHECKING: @@ -14,6 +15,7 @@ @old_positionals("neighbors_key", "random_state", "copy") +@accepts_legacy_random_state(0) def diffmap( adata: AnnData, n_comps: int = 15, diff --git a/src/scanpy/tools/_dpt.py b/src/scanpy/tools/_dpt.py index 7d8e9dc661..5225ec2164 100644 --- a/src/scanpy/tools/_dpt.py +++ b/src/scanpy/tools/_dpt.py @@ -9,7 +9,7 @@ from .. import logging as logg from .._compat import old_positionals -from .._utils.random import legacy_numpy_gen +from .._utils.random import _FakeRandomGen from ..neighbors import Neighbors, OnFlySymMatrix if TYPE_CHECKING: @@ -151,7 +151,7 @@ def dpt( "Trying to run `tl.dpt` without prior call of `tl.diffmap`. " "Falling back to `tl.diffmap` with default parameters." ) - _diffmap(adata, neighbors_key=neighbors_key, rng=legacy_numpy_gen(0)) + _diffmap(adata, neighbors_key=neighbors_key, rng=_FakeRandomGen(0)) # start with the actual computation dpt = DPT( adata, diff --git a/src/scanpy/tools/_draw_graph.py b/src/scanpy/tools/_draw_graph.py index 5a6cc6eec3..b0e9406575 100644 --- a/src/scanpy/tools/_draw_graph.py +++ b/src/scanpy/tools/_draw_graph.py @@ -9,7 +9,7 @@ from .. 
import logging as logg from .._compat import old_positionals from .._utils import _choose_graph, get_literal_vals -from .._utils.random import set_igraph_rng +from .._utils.random import accepts_legacy_random_state, set_igraph_rng from ._utils import get_init_pos_from_paga if TYPE_CHECKING: @@ -35,6 +35,7 @@ "obsp", "copy", ) +@accepts_legacy_random_state(0) def draw_graph( # noqa: PLR0913 adata: AnnData, layout: _Layout = "fa", diff --git a/src/scanpy/tools/_leiden.py b/src/scanpy/tools/_leiden.py index 58ae9d04f3..80efd4d6b0 100644 --- a/src/scanpy/tools/_leiden.py +++ b/src/scanpy/tools/_leiden.py @@ -10,7 +10,7 @@ from .. import _utils from .. import logging as logg from .._compat import warn -from .._utils.random import set_igraph_rng +from .._utils.random import accepts_legacy_random_state, set_igraph_rng from ._utils_clustering import rename_groups, restrict_adjacency if TYPE_CHECKING: @@ -30,6 +30,7 @@ MutableVertexPartition.__module__ = "leidenalg.VertexPartition" +@accepts_legacy_random_state(0) def leiden( # noqa: PLR0913 adata: AnnData, resolution: float = 1, diff --git a/src/scanpy/tools/_tsne.py b/src/scanpy/tools/_tsne.py index 5aca180291..f49456f505 100644 --- a/src/scanpy/tools/_tsne.py +++ b/src/scanpy/tools/_tsne.py @@ -6,7 +6,7 @@ from .._compat import old_positionals, warn from .._settings import settings from .._utils import _doc_params, raise_not_implemented_error_if_backed_type -from .._utils.random import legacy_random_state +from .._utils.random import accepts_legacy_random_state, legacy_random_state from ..neighbors._doc import doc_n_pcs, doc_use_rep from ._utils import _choose_representation @@ -26,6 +26,7 @@ "n_jobs", "copy", ) +@accepts_legacy_random_state(0) @_doc_params(doc_n_pcs=doc_n_pcs, use_rep=doc_use_rep) def tsne( # noqa: PLR0913 adata: AnnData, diff --git a/tests/test_utils.py b/tests/test_utils.py index b82deea324..6c8709d296 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -18,8 +18,8 @@ 
descend_classes_and_funcs, ) from scanpy._utils.random import ( + _FakeRandomGen, ith_k_tuple, - legacy_numpy_gen, random_k_tuples, random_str, ) @@ -206,7 +206,7 @@ def test_legacy_numpy_gen(*, seed: int, pass_seed: bool, func: str): def _mk_random(func: str, *, direct: bool, seed: int | None) -> np.ndarray: if direct and seed is not None: np.random.seed(seed) - gen = np.random if direct else legacy_numpy_gen(seed) + gen = np.random if direct else _FakeRandomGen.wrap_global(seed) match func: case "choice": arr = np.arange(1000) From 1ef87803bcbb1d9e01f0ca919bcf917beff4320a Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 24 Feb 2026 14:19:23 +0100 Subject: [PATCH 03/10] scrublet --- src/scanpy/preprocessing/_scrublet/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/scanpy/preprocessing/_scrublet/__init__.py b/src/scanpy/preprocessing/_scrublet/__init__.py index 7b9f86335b..6967b4ecd1 100644 --- a/src/scanpy/preprocessing/_scrublet/__init__.py +++ b/src/scanpy/preprocessing/_scrublet/__init__.py @@ -40,7 +40,7 @@ "copy", "random_state", ) -@accepts_legacy_random_state() +@accepts_legacy_random_state(0) def scrublet( # noqa: PLR0913 adata: AnnData, adata_sim: AnnData | None = None, @@ -306,7 +306,7 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None): return adata if copy else None -@accepts_legacy_random_state() +@accepts_legacy_random_state(0) def _scrublet_call_doublets( # noqa: PLR0913 adata_obs: AnnData, adata_sim: AnnData, From 5308a1acb90d6de3c3bf432b3d24d40c83dc187c Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Tue, 24 Feb 2026 18:14:04 +0100 Subject: [PATCH 04/10] almost done --- pyproject.toml | 2 + src/scanpy/_utils/random.py | 57 ++++++++++++++----- src/scanpy/neighbors/__init__.py | 19 +++++-- src/scanpy/preprocessing/_pca/__init__.py | 25 +++++--- .../preprocessing/_scrublet/__init__.py | 18 ++++-- src/scanpy/preprocessing/_simple.py | 19 +++++-- src/scanpy/tools/_draw_graph.py | 7 ++- src/scanpy/tools/_leiden.py | 10 ++-- src/scanpy/tools/_score_genes.py | 5 ++ src/scanpy/tools/_umap.py | 7 ++- tests/test_pca.py | 4 +- 11 files changed, 125 insertions(+), 48 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 126098f593..545c6f8a3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -279,6 +279,8 @@ filterwarnings = [ "ignore:The `igraph` implementation of leiden clustering:UserWarning", # everybody uses this zarr 3 feature, including us, XArray, lots of data out there … "ignore:Consolidated metadata is currently not part:UserWarning", + # joblib fallback to serial mode in restricted multiprocessing environments + "ignore:.*joblib will operate in serial mode:UserWarning", ] [tool.coverage] diff --git a/src/scanpy/_utils/random.py b/src/scanpy/_utils/random.py index 49074f22f3..b2692192b5 100644 --- a/src/scanpy/_utils/random.py +++ b/src/scanpy/_utils/random.py @@ -21,6 +21,7 @@ "RNGLike", "SeedLike", "_LegacyRandom", + "_if_legacy_apply_global", "accepts_legacy_random_state", "ith_k_tuple", "legacy_random_state", @@ -81,22 +82,38 @@ class _FakeRandomGen(np.random.Generator): _arg: _LegacyRandom _state: np.random.RandomState - def __init__(self, seed_or_state: _LegacyRandom) -> None: - self._arg = seed_or_state - self._state = np.random.RandomState(seed_or_state) + def __init__( + self, arg: _LegacyRandom, state: np.random.RandomState | None = None + ) -> None: + self._arg = arg + self._state = np.random.RandomState(arg) if state is None else state + super().__init__(self._state._bit_generator) @classmethod - def wrap_global(cls, random_state: 
_LegacyRandom | None = None) -> Self: + def wrap_global( + cls, + arg: _LegacyRandom = None, + state: np.random.RandomState | None = None, + ) -> Self: """Create a generator that wraps the global `RandomState` backing the legacy `np.random` functions.""" - if random_state is not None: - if isinstance(random_state, np.random.RandomState): - np.random.set_state(random_state.get_state(legacy=False)) - return _FakeRandomGen(random_state) - np.random.seed(random_state) - return _FakeRandomGen(np.random.RandomState(np.random.get_bit_generator())) + if arg is not None: + if isinstance(arg, np.random.RandomState): + np.random.set_state(arg.get_state(legacy=False)) + return _FakeRandomGen(arg, state) + np.random.seed(arg) + return _FakeRandomGen(arg, np.random.RandomState(np.random.get_bit_generator())) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, _FakeRandomGen): + return False + return self._arg == other._arg + + def __hash__(self) -> int: + return hash((type(self), self._arg)) @classmethod def _delegate(cls) -> None: + names = dict(integers="randint") for name, meth in np.random.Generator.__dict__.items(): if name.startswith("_") or not callable(meth): continue @@ -109,19 +126,33 @@ def wrapper(self: _FakeRandomGen, *args, **kwargs): return wrapper - setattr(cls, name, mk_wrapper(name, meth)) + setattr(cls, names.get(name, name), mk_wrapper(name, meth)) _FakeRandomGen._delegate() -def legacy_random_state(rng: SeedLike | RNGLike | None) -> _LegacyRandom: +def _if_legacy_apply_global(rng: np.random.Generator) -> np.random.Generator: + """Re-apply legacy `random_state` semantics when `rng` is a `_FakeRandomGen`. + + This resets the global legacy RNG from the original `_arg` and returns a + generator which continues drawing from the same internal state. 
+ """ + if not isinstance(rng, _FakeRandomGen): + return rng + + return _FakeRandomGen.wrap_global(rng._arg, rng._state) + + +def legacy_random_state( + rng: SeedLike | RNGLike | None, *, always_state: bool = False +) -> _LegacyRandom: """Convert a np.random.Generator into a legacy `random_state` argument. If `rng` is already a `_FakeRandomGen`, return its original `_arg` attribute. """ if isinstance(rng, _FakeRandomGen): - return rng._arg + return rng._state if always_state else rng._arg rng = np.random.default_rng(rng) return np.random.RandomState(rng.bit_generator.spawn(1)[0]) diff --git a/src/scanpy/neighbors/__init__.py b/src/scanpy/neighbors/__init__.py index 705398c053..3e03c60da5 100644 --- a/src/scanpy/neighbors/__init__.py +++ b/src/scanpy/neighbors/__init__.py @@ -17,7 +17,11 @@ from .._compat import CSBase, CSRBase, SpBase, old_positionals, warn from .._settings import settings from .._utils import NeighborsView, _doc_params, get_literal_vals -from .._utils.random import accepts_legacy_random_state, legacy_random_state +from .._utils.random import ( + _FakeRandomGen, + accepts_legacy_random_state, + legacy_random_state, +) from . import _connectivity from ._common import ( _get_indices_distances_from_dense_matrix, @@ -225,17 +229,20 @@ def neighbors( # noqa: PLR0913 ) else: params = locals() - if ignored := { + ignored = { p.name for p in signature(neighbors).parameters.values() - if p.name in {"use_rep", "knn", "n_pcs", "metric_kwds", "rng"} + if p.name in {"use_rep", "knn", "n_pcs", "metric_kwds"} if params[p.name] != p.default - }: + } + if not isinstance(rng, _FakeRandomGen) or rng._arg != 0: + ignored.add("rng/random_state") + rng = _FakeRandomGen(0) + if ignored: warn( f"Parameter(s) ignored if `distances` is given: {ignored}", UserWarning, ) - random_state = 0 if callable(metric): msg = "`metric` must be a string if `distances` is given." 
raise TypeError(msg) @@ -263,7 +270,7 @@ def neighbors( # noqa: PLR0913 key_added, n_neighbors=neighbors_.n_neighbors, method=method, - random_state=rng, + random_state=legacy_random_state(rng), metric=metric, **({} if not metric_kwds else dict(metric_kwds=metric_kwds)), **({} if use_rep is None else dict(use_rep=use_rep)), diff --git a/src/scanpy/preprocessing/_pca/__init__.py b/src/scanpy/preprocessing/_pca/__init__.py index c6d70e7c36..da1888a1cc 100644 --- a/src/scanpy/preprocessing/_pca/__init__.py +++ b/src/scanpy/preprocessing/_pca/__init__.py @@ -10,7 +10,11 @@ from ..._compat import CSBase, DaskArray, pkg_version, warn from ..._settings import settings from ..._utils import _doc_params, _empty, get_literal_vals, is_backed_type -from ..._utils.random import accepts_legacy_random_state, legacy_random_state +from ..._utils.random import ( + _FakeRandomGen, + accepts_legacy_random_state, + legacy_random_state, +) from ...get import _check_mask, _get_obs_rep from .._docs import doc_mask_var_hvg from ._compat import _pca_compat_sparse @@ -243,19 +247,22 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 raise NotImplementedError(msg) # dask needs an int for random state - if not isinstance(x, DaskArray): - rng = np.random.default_rng(rng) - elif not isinstance(rng, int): - msg = f"rng needs to be an int, not a {type(rng).__name__} when passing a dask array" + rng = np.random.default_rng(rng) + if not isinstance(rng, _FakeRandomGen) and not isinstance( + rng._arg, int | np.random.RandomState + ): + # TODO: remove this error and if we don’t have a _FakeRandomGen, + # just use rng.integers to make a seed farther down + msg = f"rng needs to be an int or a np.random.RandomState, not a {type(rng).__name__} when passing a dask array" raise TypeError(msg) if chunked: if ( not zero_center - or rng is not None + or (not isinstance(rng, _FakeRandomGen) or rng._arg != 0) or (svd_solver is not None and svd_solver != "arpack") ): - logg.debug("Ignoring zero_center, random_state, 
svd_solver") + logg.debug("Ignoring zero_center, rng, svd_solver") incremental_pca_kwargs = dict() if isinstance(x, DaskArray): @@ -303,8 +310,8 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 elif isinstance(x._meta, CSBase) or svd_solver == "covariance_eigh": from ._dask import PCAEighDask - if rng is not None: - msg = f"Ignoring {rng=} when using a sparse dask array" + if not isinstance(rng, _FakeRandomGen) or rng._arg != 0: + msg = f"Ignoring rng={legacy_random_state(rng)} when using a sparse dask array" warn(msg, UserWarning) if svd_solver not in {None, "covariance_eigh"}: msg = f"Ignoring {svd_solver=} when using a sparse dask array" diff --git a/src/scanpy/preprocessing/_scrublet/__init__.py b/src/scanpy/preprocessing/_scrublet/__init__.py index 6967b4ecd1..a75f27a33b 100644 --- a/src/scanpy/preprocessing/_scrublet/__init__.py +++ b/src/scanpy/preprocessing/_scrublet/__init__.py @@ -11,7 +11,11 @@ from ... import logging as logg from ... import preprocessing as pp from ..._compat import old_positionals -from ..._utils.random import accepts_legacy_random_state +from ..._utils.random import ( + _if_legacy_apply_global, + accepts_legacy_random_state, + legacy_random_state, +) from ...get import _get_obs_rep from . 
import pipeline from .core import Scrublet @@ -181,6 +185,7 @@ def scrublet( # noqa: PLR0913 """ rng = np.random.default_rng(rng) + rng = _if_legacy_apply_global(rng) if threshold is None and not find_spec("skimage"): # pragma: no cover # Scrublet.call_doublets requires `skimage` with `threshold=None` but PCA # is called early, which is wasteful if there is not `skimage` @@ -444,10 +449,14 @@ def _scrublet_call_doublets( # noqa: PLR0913 if mean_center: logg.info("Embedding transcriptomes using PCA...") - pipeline.pca(scrub, n_prin_comps=n_prin_comps, rng=scrub._rng) + pipeline.pca( + scrub, n_prin_comps=n_prin_comps, svd_solver="arpack", rng=scrub._rng + ) else: logg.info("Embedding transcriptomes using Truncated SVD...") - pipeline.truncated_svd(scrub, n_prin_comps=n_prin_comps, rng=scrub._rng) + pipeline.truncated_svd( + scrub, n_prin_comps=n_prin_comps, algorithm="arpack", rng=scrub._rng + ) # Score the doublets @@ -479,7 +488,7 @@ def _scrublet_call_doublets( # noqa: PLR0913 .get("sim_doublet_ratio", None) ), "n_neighbors": n_neighbors, - "rng": rng, + "random_state": legacy_random_state(rng), }, } @@ -557,6 +566,7 @@ def scrublet_simulate_doublets( scores for observed transcriptomes and simulated doublets. 
""" + rng = _if_legacy_apply_global(rng) x = _get_obs_rep(adata, layer=layer) scrub = Scrublet(x, rng=rng) diff --git a/src/scanpy/preprocessing/_simple.py b/src/scanpy/preprocessing/_simple.py index 89816d4152..83e5e96c58 100644 --- a/src/scanpy/preprocessing/_simple.py +++ b/src/scanpy/preprocessing/_simple.py @@ -30,7 +30,7 @@ sanitize_anndata, view_to_actual, ) -from .._utils.random import accepts_legacy_random_state +from .._utils.random import _if_legacy_apply_global, accepts_legacy_random_state from ..get import _check_mask, _get_obs_rep, _set_obs_rep from ._distributed import materialize_as_ndarray @@ -1040,6 +1040,7 @@ def downsample_counts( raise_not_implemented_error_if_backed_type(adata.X, "downsample_counts") # This logic is all dispatch rng = np.random.default_rng(rng) + rng = _if_legacy_apply_global(rng) total_counts_call = total_counts is not None counts_per_cell_call = counts_per_cell is not None if total_counts_call is counts_per_cell_call: @@ -1134,7 +1135,6 @@ def _downsample_total_counts( # TODO: can/should this be parallelized? -@numba.njit(cache=True) # noqa: TID251 def _downsample_array( col: np.ndarray, target: int, @@ -1142,7 +1142,7 @@ def _downsample_array( rng: np.random.Generator, replace: bool = True, inplace: bool = False, -): +) -> np.ndarray: """Evenly reduce counts in cell to target amount. 
This is an internal function and has some restrictions: @@ -1150,13 +1150,20 @@ def _downsample_array( * total counts in cell must be less than target """ cumcounts = col.cumsum() + total = np.int_(cumcounts[-1]) + sample = rng.choice(total, target, replace=replace) + sample.sort() + return _downsample_array_inner(col, cumcounts, sample, inplace=inplace) + + +@numba.njit(cache=True) # noqa: TID251 +def _downsample_array_inner( + col: np.ndarray, cumcounts: np.ndarray, sample: np.ndarray, *, inplace: bool +) -> np.ndarray: if inplace: col[:] = 0 else: col = np.zeros_like(col) - total = np.int_(cumcounts[-1]) - sample = rng.choice(total, target, replace=replace) - sample.sort() geneptr = 0 for count in sample: while count >= cumcounts[geneptr]: diff --git a/src/scanpy/tools/_draw_graph.py b/src/scanpy/tools/_draw_graph.py index b0e9406575..ce0d5182b3 100644 --- a/src/scanpy/tools/_draw_graph.py +++ b/src/scanpy/tools/_draw_graph.py @@ -9,7 +9,11 @@ from .. import logging as logg from .._compat import old_positionals from .._utils import _choose_graph, get_literal_vals -from .._utils.random import accepts_legacy_random_state, set_igraph_rng +from .._utils.random import ( + _if_legacy_apply_global, + accepts_legacy_random_state, + set_igraph_rng, +) from ._utils import get_init_pos_from_paga if TYPE_CHECKING: @@ -125,6 +129,7 @@ def draw_graph( # noqa: PLR0913 """ start = logg.info(f"drawing single-cell graph using layout {layout!r}") rng = np.random.default_rng(rng) + rng = _if_legacy_apply_global(rng) if layout not in (layouts := get_literal_vals(_Layout)): msg = f"Provide a valid layout, one of {layouts}." 
raise ValueError(msg) diff --git a/src/scanpy/tools/_leiden.py b/src/scanpy/tools/_leiden.py index 80efd4d6b0..e4a50ef949 100644 --- a/src/scanpy/tools/_leiden.py +++ b/src/scanpy/tools/_leiden.py @@ -1,6 +1,5 @@ from __future__ import annotations -from collections.abc import Hashable from typing import TYPE_CHECKING, cast import numpy as np @@ -10,7 +9,11 @@ from .. import _utils from .. import logging as logg from .._compat import warn -from .._utils.random import accepts_legacy_random_state, set_igraph_rng +from .._utils.random import ( + accepts_legacy_random_state, + legacy_random_state, + set_igraph_rng, +) from ._utils_clustering import rename_groups, restrict_adjacency if TYPE_CHECKING: @@ -162,8 +165,7 @@ def leiden( # noqa: PLR0913 partition_type = leidenalg.RBConfigurationVertexPartition if use_weights: clustering_args["weights"] = np.array(g.es["weight"]).astype(np.float64) - if isinstance(rng, Hashable): - clustering_args["seed"] = rng + clustering_args["seed"] = legacy_random_state(rng) part = cast( "MutableVertexPartition", leidenalg.find_partition(g, partition_type, **clustering_args), diff --git a/src/scanpy/tools/_score_genes.py b/src/scanpy/tools/_score_genes.py index afef115ce0..bf977a7531 100644 --- a/src/scanpy/tools/_score_genes.py +++ b/src/scanpy/tools/_score_genes.py @@ -10,6 +10,7 @@ from .. 
import logging as logg from .._compat import CSBase, old_positionals from .._utils import check_use_raw, is_backed_type +from .._utils.random import _if_legacy_apply_global, accepts_legacy_random_state from ..get import _get_obs_rep if TYPE_CHECKING: @@ -53,6 +54,7 @@ def _sparse_nanmean(x: CSBase, /, axis: Literal[0, 1]) -> NDArray[np.float64]: @old_positionals( "ctrl_size", "gene_pool", "n_bins", "score_name", "random_state", "copy", "use_raw" ) +@accepts_legacy_random_state(0) def score_genes( # noqa: PLR0913 adata: AnnData, gene_list: Sequence[str] | pd.Index[str], @@ -120,7 +122,10 @@ def score_genes( # noqa: PLR0913 """ start = logg.info(f"computing score {score_name!r}") + rng_was_passed = rng is not None rng = np.random.default_rng(rng) + if rng_was_passed: # backwards compatibility: call np.random.seed() by default + rng = _if_legacy_apply_global(rng) adata = adata.copy() if copy else adata use_raw = check_use_raw(adata, use_raw, layer=layer) if is_backed_type(adata.X) and not use_raw: diff --git a/src/scanpy/tools/_umap.py b/src/scanpy/tools/_umap.py index ba0b908e8d..1e0cdc0bcc 100644 --- a/src/scanpy/tools/_umap.py +++ b/src/scanpy/tools/_umap.py @@ -10,7 +10,7 @@ from .._compat import old_positionals, warn from .._settings import settings from .._utils import NeighborsView -from .._utils.random import legacy_random_state +from .._utils.random import accepts_legacy_random_state, legacy_random_state from ._utils import _choose_representation, get_init_pos_from_paga if TYPE_CHECKING: @@ -40,6 +40,7 @@ "method", "neighbors_key", ) +@accepts_legacy_random_state(0) def umap( # noqa: PLR0913, PLR0915 adata: AnnData, *, @@ -201,7 +202,7 @@ def umap( # noqa: PLR0913, PLR0915 init_coords = check_array(init_coords, dtype=np.float32, accept_sparse=False) if rng is not None: - adata.uns[key_uns]["params"]["random_state"] = rng + adata.uns[key_uns]["params"]["random_state"] = legacy_random_state(rng) neigh_params = neighbors["params"] x = _choose_representation( @@ 
-226,7 +227,7 @@ def umap( # noqa: PLR0913, PLR0915 negative_sample_rate=negative_sample_rate, n_epochs=n_epochs, init=init_coords, - random_state=legacy_random_state(rng), + random_state=legacy_random_state(rng, always_state=True), metric=neigh_params.get("metric", "euclidean"), metric_kwds=neigh_params.get("metric_kwds", {}), densmap=False, diff --git a/tests/test_pca.py b/tests/test_pca.py index ba996d8ad9..d4cd5e58f3 100644 --- a/tests/test_pca.py +++ b/tests/test_pca.py @@ -244,7 +244,7 @@ def test_pca_transform_randomized(array_type): warnings.filterwarnings("error") if isinstance(adata.X, DaskArray) and isinstance(adata.X._meta, CSBase): patterns = ( - r"Ignoring random_state=14 when using a sparse dask array", + r"Ignoring rng=14 when using a sparse dask array", r"Ignoring svd_solver='randomized' when using a sparse dask array", ) ctx = _helpers.MultiContext( @@ -338,7 +338,7 @@ def test_pca_reproducible(array_type): pbmc.X = array_type(pbmc.X) with ( - pytest.warns(UserWarning, match=r"Ignoring random_state.*sparse dask array") + pytest.warns(UserWarning, match=r"Ignoring rng.*sparse dask array") if isinstance(pbmc.X, DaskArray) and isinstance(pbmc.X._meta, CSBase) else nullcontext() ): From 32b3ddc9617df423930f512a71f2afbd698b5616 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 26 Feb 2026 13:09:57 +0100 Subject: [PATCH 05/10] fix scrublet_simulate_doublets --- src/scanpy/preprocessing/_scrublet/__init__.py | 9 ++------- src/scanpy/preprocessing/_simple.py | 4 ++-- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/src/scanpy/preprocessing/_scrublet/__init__.py b/src/scanpy/preprocessing/_scrublet/__init__.py index 718672df25..b9215ae163 100644 --- a/src/scanpy/preprocessing/_scrublet/__init__.py +++ b/src/scanpy/preprocessing/_scrublet/__init__.py @@ -10,11 +10,7 @@ from ... import logging as logg from ... 
import preprocessing as pp -from ..._utils.random import ( - _if_legacy_apply_global, - accepts_legacy_random_state, - legacy_random_state, -) +from ..._utils.random import accepts_legacy_random_state, legacy_random_state from ...get import _get_obs_rep from . import pipeline from .core import Scrublet @@ -165,7 +161,6 @@ def scrublet( # noqa: PLR0913 """ rng = np.random.default_rng(rng) - rng = _if_legacy_apply_global(rng) if threshold is None and not find_spec("skimage"): # pragma: no cover # Scrublet.call_doublets requires `skimage` with `threshold=None` but PCA # is called early, which is wasteful if there is not `skimage` @@ -493,6 +488,7 @@ def _scrublet_call_doublets( # noqa: PLR0913 return adata_obs +@accepts_legacy_random_state(0) def scrublet_simulate_doublets( adata: AnnData, *, @@ -543,7 +539,6 @@ def scrublet_simulate_doublets( scores for observed transcriptomes and simulated doublets. """ - rng = _if_legacy_apply_global(rng) x = _get_obs_rep(adata, layer=layer) scrub = Scrublet(x, rng=rng) diff --git a/src/scanpy/preprocessing/_simple.py b/src/scanpy/preprocessing/_simple.py index e0c699d7b6..51f8b4a726 100644 --- a/src/scanpy/preprocessing/_simple.py +++ b/src/scanpy/preprocessing/_simple.py @@ -1020,7 +1020,6 @@ def downsample_counts( raise_not_implemented_error_if_backed_type(adata.X, "downsample_counts") # This logic is all dispatch rng = np.random.default_rng(rng) - rng = _if_legacy_apply_global(rng) total_counts_call = total_counts is not None counts_per_cell_call = counts_per_cell is not None if total_counts_call is counts_per_cell_call: @@ -1114,7 +1113,6 @@ def _downsample_total_counts( return x -# TODO: can/should this be parallelized? 
def _downsample_array( col: np.ndarray, target: int, @@ -1129,6 +1127,7 @@ def _downsample_array( * total counts in cell must be less than target """ + rng = _if_legacy_apply_global(rng) cumcounts = col.cumsum() total = np.int_(cumcounts[-1]) sample = rng.choice(total, target, replace=replace) @@ -1136,6 +1135,7 @@ def _downsample_array( return _downsample_array_inner(col, cumcounts, sample, inplace=inplace) +# TODO: can/should this be parallelized? @numba.njit(cache=True) # noqa: TID251 def _downsample_array_inner( col: np.ndarray, cumcounts: np.ndarray, sample: np.ndarray, *, inplace: bool From c3da2bb235cc1410f5a97b341b70812d0d47243e Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 26 Feb 2026 13:25:37 +0100 Subject: [PATCH 06/10] fix _RNGIgraph compat --- src/scanpy/_utils/random.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/scanpy/_utils/random.py b/src/scanpy/_utils/random.py index b2692192b5..22c1c113d2 100644 --- a/src/scanpy/_utils/random.py +++ b/src/scanpy/_utils/random.py @@ -40,7 +40,7 @@ class _RNGIgraph: - """Random number generator for igraph so global seed is not changed. + """Random number generator for igraph so global random state is not changed. See :func:`igraph.set_random_number_generator` for the requirements. 
""" @@ -49,8 +49,11 @@ def __init__(self, rng: SeedLike | RNGLike | None) -> None: self._rng = np.random.default_rng(rng) def getrandbits(self, k: int) -> int: - lims = np.iinfo(np.uint64) - i = int(self._rng.integers(0, lims.max, dtype=np.uint64)) + if isinstance(self._rng, _FakeRandomGen): + i = self._rng._state.tomaxint() + else: + lims = np.iinfo(np.uint64) + i = int(self._rng.integers(0, lims.max, dtype=np.uint64)) return i & ((1 << k) - 1) def randint(self, a: int, b: int) -> np.int64: @@ -128,6 +131,9 @@ def wrapper(self: _FakeRandomGen, *args, **kwargs): setattr(cls, names.get(name, name), mk_wrapper(name, meth)) + def __getattribute__(self, name: str) -> object: + return super().__getattribute__(name) + _FakeRandomGen._delegate() From bd85d959e049ae495150e7a500a9aef33493fb61 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 26 Feb 2026 13:38:31 +0100 Subject: [PATCH 07/10] whoops --- src/scanpy/tools/_louvain.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/scanpy/tools/_louvain.py b/src/scanpy/tools/_louvain.py index b10491c28b..4e131affd6 100644 --- a/src/scanpy/tools/_louvain.py +++ b/src/scanpy/tools/_louvain.py @@ -86,7 +86,8 @@ def louvain( # noqa: PLR0912, PLR0913, PLR0915 resolution (higher resolution means finding more and smaller clusters), which defaults to 1.0. See “Time as a resolution parameter” in :cite:t:`Lambiotte2014`. - {random_state} + random_state + Change the initialization of the optimization. {restrict_to} key_added Key under which to add the cluster labels. (default: ``'louvain'``) From 8247cdbf9cff12c76686c412b84a07f5137cf10b Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Thu, 26 Feb 2026 13:43:31 +0100 Subject: [PATCH 08/10] relnote --- docs/release-notes/3983.feat.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 docs/release-notes/3983.feat.md diff --git a/docs/release-notes/3983.feat.md b/docs/release-notes/3983.feat.md new file mode 100644 index 0000000000..ca67baba94 --- /dev/null +++ b/docs/release-notes/3983.feat.md @@ -0,0 +1 @@ +Add support for {class}`numpy.random.Generator` to all functions previously accepting a `random_state` parameter {smaller}`P Angerer` From 47f3ceba1823d81a383f78cbbde5cf7e180f616f Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 26 Feb 2026 14:37:49 +0100 Subject: [PATCH 09/10] =?UTF-8?q?don=E2=80=99t=20store=20rng=20in=20random?= =?UTF-8?q?=5Fstate=20arg?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/scanpy/_utils/random.py | 9 +-------- src/scanpy/tools/_draw_graph.py | 5 ++++- src/scanpy/tools/_leiden.py | 2 +- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/src/scanpy/_utils/random.py b/src/scanpy/_utils/random.py index baf0c8fdc3..e80eac1613 100644 --- a/src/scanpy/_utils/random.py +++ b/src/scanpy/_utils/random.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING import numpy as np +from numpy.random._generator import Generator from . 
import ensure_igraph @@ -107,14 +108,6 @@ def wrap_global( np.random.seed(arg) return _FakeRandomGen(arg, np.random.RandomState(np.random.get_bit_generator())) - def __eq__(self, other: object) -> bool: - if not isinstance(other, _FakeRandomGen): - return False - return self._arg == other._arg - - def __hash__(self) -> int: - return hash((type(self), self._arg)) - @classmethod def _delegate(cls) -> None: names = dict(integers="randint") diff --git a/src/scanpy/tools/_draw_graph.py b/src/scanpy/tools/_draw_graph.py index 4047d63a41..01c9b151a5 100644 --- a/src/scanpy/tools/_draw_graph.py +++ b/src/scanpy/tools/_draw_graph.py @@ -11,6 +11,7 @@ from .._utils.random import ( _if_legacy_apply_global, accepts_legacy_random_state, + legacy_random_state, set_igraph_rng, ) from ._utils import get_init_pos_from_paga @@ -154,7 +155,9 @@ def draw_graph( # noqa: PLR0913 ig_layout = g.layout(layout, **kwds) positions = np.array(ig_layout.coords) adata.uns["draw_graph"] = {} - adata.uns["draw_graph"]["params"] = dict(layout=layout, random_state=rng) + adata.uns["draw_graph"]["params"] = dict( + layout=layout, random_state=legacy_random_state(rng) + ) key_added = f"X_draw_graph_{key_added_ext or layout}" adata.obsm[key_added] = positions logg.info( diff --git a/src/scanpy/tools/_leiden.py b/src/scanpy/tools/_leiden.py index 1f4fd2a7ea..f43723c6e0 100644 --- a/src/scanpy/tools/_leiden.py +++ b/src/scanpy/tools/_leiden.py @@ -204,7 +204,7 @@ def leiden( # noqa: PLR0913 adata.uns[key_added] = {} adata.uns[key_added]["params"] = dict( resolution=resolution, - random_state=rng, + random_state=legacy_random_state(rng), n_iterations=n_iterations, ) adata.uns[key_added]["modularity"] = part.modularity From 1e43b2ab4eb4be897fd37d8c07e55ebb5f22d0e3 Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Thu, 26 Feb 2026 14:50:08 +0100 Subject: [PATCH 10/10] make consistent --- src/scanpy/_utils/random.py | 18 +++++++++--------- src/scanpy/datasets/_datasets.py | 6 +++--- src/scanpy/experimental/pp/_normalization.py | 4 ++-- src/scanpy/experimental/pp/_recipes.py | 4 ++-- src/scanpy/neighbors/__init__.py | 14 +++++++------- src/scanpy/preprocessing/_pca/__init__.py | 14 +++++++------- src/scanpy/preprocessing/_pca/_compat.py | 6 +++--- src/scanpy/preprocessing/_recipes.py | 4 ++-- src/scanpy/preprocessing/_scrublet/__init__.py | 10 +++++----- src/scanpy/preprocessing/_scrublet/pipeline.py | 6 +++--- src/scanpy/preprocessing/_simple.py | 4 ++-- src/scanpy/preprocessing/_utils.py | 4 ++-- src/scanpy/tools/_diffmap.py | 4 ++-- src/scanpy/tools/_draw_graph.py | 12 ++++++------ src/scanpy/tools/_leiden.py | 14 +++++++------- src/scanpy/tools/_score_genes.py | 4 ++-- src/scanpy/tools/_tsne.py | 6 +++--- src/scanpy/tools/_umap.py | 10 +++++----- 18 files changed, 72 insertions(+), 72 deletions(-) diff --git a/src/scanpy/_utils/random.py b/src/scanpy/_utils/random.py index e80eac1613..f5dc3c6531 100644 --- a/src/scanpy/_utils/random.py +++ b/src/scanpy/_utils/random.py @@ -7,12 +7,11 @@ from typing import TYPE_CHECKING import numpy as np -from numpy.random._generator import Generator from . 
import ensure_igraph if TYPE_CHECKING: - from collections.abc import Generator + from collections.abc import Callable, Generator from typing import Self from numpy.typing import NDArray @@ -22,10 +21,11 @@ "RNGLike", "SeedLike", "_LegacyRandom", + "_accepts_legacy_random_state", "_if_legacy_apply_global", - "accepts_legacy_random_state", + "_legacy_random_state", + "_set_igraph_rng", "ith_k_tuple", - "legacy_random_state", "random_k_tuples", "random_str", ] @@ -66,7 +66,7 @@ def __getattr__(self, attr: str): @contextmanager -def set_igraph_rng(rng: SeedLike | RNGLike | None) -> Generator[None]: +def _set_igraph_rng(rng: SeedLike | RNGLike | None) -> Generator[None]: ensure_igraph() import igraph @@ -141,7 +141,7 @@ def _if_legacy_apply_global(rng: np.random.Generator) -> np.random.Generator: return _FakeRandomGen.wrap_global(rng._arg, rng._state) -def legacy_random_state( +def _legacy_random_state( rng: SeedLike | RNGLike | None, *, always_state: bool = False ) -> _LegacyRandom: """Convert a np.random.Generator into a legacy `random_state` argument. @@ -154,9 +154,9 @@ def legacy_random_state( return np.random.RandomState(rng.bit_generator.spawn(1)[0]) -def accepts_legacy_random_state[**P, R]( +def _accepts_legacy_random_state[**P, R]( random_state_default: _LegacyRandom, -) -> callable[[callable[P, R]], callable[P, R]]: +) -> Callable[[Callable[P, R]], Callable[P, R]]: """Make a function accept `random_state: _LegacyRandom` and pass it as `rng`. If the decorated function is called with a `random_state` argument, @@ -165,7 +165,7 @@ def accepts_legacy_random_state[**P, R]( If neither is given, ``random_state_default`` is used. 
""" - def decorator(func: callable[P, R]) -> callable[P, R]: + def decorator(func: Callable[P, R]) -> Callable[P, R]: @wraps(func) def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: match "random_state" in kwargs, "rng" in kwargs: diff --git a/src/scanpy/datasets/_datasets.py b/src/scanpy/datasets/_datasets.py index 5b56be0fad..a3775fe3a4 100644 --- a/src/scanpy/datasets/_datasets.py +++ b/src/scanpy/datasets/_datasets.py @@ -12,7 +12,7 @@ from .._compat import deprecated from .._settings import settings from .._utils._doctests import doctest_internet, doctest_needs -from .._utils.random import accepts_legacy_random_state, legacy_random_state +from .._utils.random import _accepts_legacy_random_state, _legacy_random_state from ..readwrite import read, read_h5ad, read_visium from ._utils import check_datasetdir_exists @@ -55,7 +55,7 @@ HERE = Path(__file__).parent -@accepts_legacy_random_state(0) +@_accepts_legacy_random_state(0) def blobs( *, n_variables: int = 11, @@ -100,7 +100,7 @@ def blobs( n_features=n_variables, centers=n_centers, cluster_std=cluster_std, - random_state=legacy_random_state(rng), + random_state=_legacy_random_state(rng), ) return AnnData(x, obs=dict(blobs=y.astype(str))) diff --git a/src/scanpy/experimental/pp/_normalization.py b/src/scanpy/experimental/pp/_normalization.py index 9cd2d6d092..74c03d518c 100644 --- a/src/scanpy/experimental/pp/_normalization.py +++ b/src/scanpy/experimental/pp/_normalization.py @@ -9,7 +9,7 @@ from ... 
import logging as logg from ..._compat import CSBase, warn from ..._utils import _doc_params, _empty, check_nonnegative_integers, view_to_actual -from ..._utils.random import accepts_legacy_random_state +from ..._utils.random import _accepts_legacy_random_state from ...experimental._docs import ( doc_adata, doc_check_values, @@ -162,7 +162,7 @@ def normalize_pearson_residuals( check_values=doc_check_values, inplace=doc_inplace, ) -@accepts_legacy_random_state(0) +@_accepts_legacy_random_state(0) def normalize_pearson_residuals_pca( adata: AnnData, *, diff --git a/src/scanpy/experimental/pp/_recipes.py b/src/scanpy/experimental/pp/_recipes.py index b688b8d4d8..ba944350bf 100644 --- a/src/scanpy/experimental/pp/_recipes.py +++ b/src/scanpy/experimental/pp/_recipes.py @@ -17,7 +17,7 @@ ) from scanpy.preprocessing import pca -from ..._utils.random import accepts_legacy_random_state +from ..._utils.random import _accepts_legacy_random_state if TYPE_CHECKING: from collections.abc import Mapping @@ -37,7 +37,7 @@ check_values=doc_check_values, inplace=doc_inplace, ) -@accepts_legacy_random_state(0) +@_accepts_legacy_random_state(0) def recipe_pearson_residuals( # noqa: PLR0913 adata: AnnData, *, diff --git a/src/scanpy/neighbors/__init__.py b/src/scanpy/neighbors/__init__.py index 853695c30b..332dae931a 100644 --- a/src/scanpy/neighbors/__init__.py +++ b/src/scanpy/neighbors/__init__.py @@ -18,9 +18,9 @@ from .._settings import settings from .._utils import NeighborsView, _doc_params, get_literal_vals from .._utils.random import ( + _accepts_legacy_random_state, _FakeRandomGen, - accepts_legacy_random_state, - legacy_random_state, + _legacy_random_state, ) from . 
import _connectivity from ._common import ( @@ -82,7 +82,7 @@ class NeighborsParams(TypedDict): # noqa: D101 @_doc_params(n_pcs=doc_n_pcs, use_rep=doc_use_rep) -@accepts_legacy_random_state(0) +@_accepts_legacy_random_state(0) def neighbors( # noqa: PLR0913 adata: AnnData, n_neighbors: int = 15, @@ -270,7 +270,7 @@ def neighbors( # noqa: PLR0913 key_added, n_neighbors=neighbors_.n_neighbors, method=method, - random_state=legacy_random_state(rng), + random_state=_legacy_random_state(rng), metric=metric, **({} if not metric_kwds else dict(metric_kwds=metric_kwds)), **({} if use_rep is None else dict(use_rep=use_rep)), @@ -579,7 +579,7 @@ def to_igraph(self) -> Graph: return _utils.get_igraph_from_adjacency(self.connectivities) @_doc_params(n_pcs=doc_n_pcs, use_rep=doc_use_rep) - @accepts_legacy_random_state(0) + @_accepts_legacy_random_state(0) def compute_neighbors( self, n_neighbors: int = 30, @@ -627,7 +627,7 @@ def compute_neighbors( n_neighbors=n_neighbors, metric=metric, metric_params=metric_kwds, # most use _params, not _kwds - random_state=legacy_random_state(rng), + random_state=_legacy_random_state(rng), ) method, transformer, shortcut = self._handle_transformer( method, transformer, knn=knn, kwds=transformer_kwds_default @@ -849,7 +849,7 @@ def compute_transitions(self, *, density_normalize: bool = True) -> None: self._transitions_sym = self.Z @ conn_norm @ self.Z logg.info(" finished", time=start) - @accepts_legacy_random_state(0) + @_accepts_legacy_random_state(0) def compute_eigen( self, *, diff --git a/src/scanpy/preprocessing/_pca/__init__.py b/src/scanpy/preprocessing/_pca/__init__.py index 98f27aa030..1863506447 100644 --- a/src/scanpy/preprocessing/_pca/__init__.py +++ b/src/scanpy/preprocessing/_pca/__init__.py @@ -11,9 +11,9 @@ from ..._settings import settings from ..._utils import _doc_params, _empty, get_literal_vals, is_backed_type from ..._utils.random import ( + _accepts_legacy_random_state, _FakeRandomGen, - accepts_legacy_random_state, - 
legacy_random_state, + _legacy_random_state, ) from ...get import _check_mask, _get_obs_rep from .._docs import doc_mask_var_hvg @@ -58,7 +58,7 @@ @_doc_params( mask_var_hvg=doc_mask_var_hvg, ) -@accepts_legacy_random_state(0) +@_accepts_legacy_random_state(0) def pca( # noqa: PLR0912, PLR0913, PLR0915 data: AnnData | np.ndarray | CSBase, n_comps: int | None = None, @@ -305,13 +305,13 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 pca_ = PCA( n_components=n_comps, svd_solver=svd_solver, - random_state=legacy_random_state(rng), + random_state=_legacy_random_state(rng), ) elif isinstance(x._meta, CSBase) or svd_solver == "covariance_eigh": from ._dask import PCAEighDask if not isinstance(rng, _FakeRandomGen) or rng._arg != 0: - msg = f"Ignoring rng={legacy_random_state(rng)} when using a sparse dask array" + msg = f"Ignoring rng={_legacy_random_state(rng)} when using a sparse dask array" warn(msg, UserWarning) if svd_solver not in {None, "covariance_eigh"}: msg = f"Ignoring {svd_solver=} when using a sparse dask array" @@ -324,7 +324,7 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 pca_ = PCA( n_components=n_comps, svd_solver=svd_solver, - random_state=legacy_random_state(rng), + random_state=_legacy_random_state(rng), ) x_pca = pca_.fit_transform(x) else: @@ -351,7 +351,7 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 ) pca_ = TruncatedSVD( n_components=n_comps, - random_state=legacy_random_state(rng), + random_state=_legacy_random_state(rng), algorithm=svd_solver, ) x_pca = pca_.fit_transform(x) diff --git a/src/scanpy/preprocessing/_pca/_compat.py b/src/scanpy/preprocessing/_pca/_compat.py index f1f4e33b0d..7809fff83a 100644 --- a/src/scanpy/preprocessing/_pca/_compat.py +++ b/src/scanpy/preprocessing/_pca/_compat.py @@ -10,7 +10,7 @@ from sklearn.utils.extmath import svd_flip from ..._compat import pkg_version -from ..._utils.random import accepts_legacy_random_state, legacy_random_state +from ..._utils.random import _accepts_legacy_random_state, 
_legacy_random_state if TYPE_CHECKING: from typing import Literal @@ -22,7 +22,7 @@ from ..._utils.random import RNGLike, SeedLike -@accepts_legacy_random_state(None) +@_accepts_legacy_random_state(None) def _pca_compat_sparse( x: CSBase, n_pcs: int, @@ -72,7 +72,7 @@ def rmat_op(v: NDArray[np.floating]): from sklearn.decomposition import PCA pca = PCA( - n_components=n_pcs, svd_solver=solver, random_state=legacy_random_state(rng) + n_components=n_pcs, svd_solver=solver, random_state=_legacy_random_state(rng) ) pca.explained_variance_ = ev pca.explained_variance_ratio_ = ev_ratio diff --git a/src/scanpy/preprocessing/_recipes.py b/src/scanpy/preprocessing/_recipes.py index 179a9ebdbe..4737cbe016 100644 --- a/src/scanpy/preprocessing/_recipes.py +++ b/src/scanpy/preprocessing/_recipes.py @@ -7,7 +7,7 @@ from .. import logging as logg from .. import preprocessing as pp from .._compat import CSBase -from .._utils.random import accepts_legacy_random_state +from .._utils.random import _accepts_legacy_random_state from ._deprecated.highly_variable_genes import ( filter_genes_cv_deprecated, filter_genes_dispersion, @@ -20,7 +20,7 @@ from .._utils.random import RNGLike, SeedLike -@accepts_legacy_random_state(0) +@_accepts_legacy_random_state(0) def recipe_weinreb17( adata: AnnData, *, diff --git a/src/scanpy/preprocessing/_scrublet/__init__.py b/src/scanpy/preprocessing/_scrublet/__init__.py index b9215ae163..502204537b 100644 --- a/src/scanpy/preprocessing/_scrublet/__init__.py +++ b/src/scanpy/preprocessing/_scrublet/__init__.py @@ -10,7 +10,7 @@ from ... import logging as logg from ... import preprocessing as pp -from ..._utils.random import accepts_legacy_random_state, legacy_random_state +from ..._utils.random import _accepts_legacy_random_state, _legacy_random_state from ...get import _get_obs_rep from . 
import pipeline from .core import Scrublet @@ -20,7 +20,7 @@ from ...neighbors import _Metric, _MetricFn -@accepts_legacy_random_state(0) +@_accepts_legacy_random_state(0) def scrublet( # noqa: PLR0913 adata: AnnData, adata_sim: AnnData | None = None, @@ -286,7 +286,7 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None): return adata if copy else None -@accepts_legacy_random_state(0) +@_accepts_legacy_random_state(0) def _scrublet_call_doublets( # noqa: PLR0913 adata_obs: AnnData, adata_sim: AnnData, @@ -463,7 +463,7 @@ def _scrublet_call_doublets( # noqa: PLR0913 .get("sim_doublet_ratio", None) ), "n_neighbors": n_neighbors, - "random_state": legacy_random_state(rng), + "random_state": _legacy_random_state(rng), }, } @@ -488,7 +488,7 @@ def _scrublet_call_doublets( # noqa: PLR0913 return adata_obs -@accepts_legacy_random_state(0) +@_accepts_legacy_random_state(0) def scrublet_simulate_doublets( adata: AnnData, *, diff --git a/src/scanpy/preprocessing/_scrublet/pipeline.py b/src/scanpy/preprocessing/_scrublet/pipeline.py index 53d98ff8ed..78b8d0aa7f 100644 --- a/src/scanpy/preprocessing/_scrublet/pipeline.py +++ b/src/scanpy/preprocessing/_scrublet/pipeline.py @@ -6,7 +6,7 @@ from fast_array_utils.stats import mean_var from scipy import sparse -from ..._utils.random import legacy_random_state +from ..._utils.random import _legacy_random_state from .sparse_utils import sparse_multiply, sparse_zscore if TYPE_CHECKING: @@ -58,7 +58,7 @@ def truncated_svd( svd = TruncatedSVD( n_components=n_prin_comps, - random_state=legacy_random_state(rng), + random_state=_legacy_random_state(rng), algorithm=algorithm, ).fit(self._counts_obs_norm) self.set_manifold( @@ -83,7 +83,7 @@ def pca( pca = PCA( n_components=n_prin_comps, - random_state=legacy_random_state(rng), + random_state=_legacy_random_state(rng), svd_solver=svd_solver, ).fit(x_obs) self.set_manifold(pca.transform(x_obs), pca.transform(x_sim)) diff --git a/src/scanpy/preprocessing/_simple.py 
b/src/scanpy/preprocessing/_simple.py index 51f8b4a726..42f6a70931 100644 --- a/src/scanpy/preprocessing/_simple.py +++ b/src/scanpy/preprocessing/_simple.py @@ -29,7 +29,7 @@ sanitize_anndata, view_to_actual, ) -from .._utils.random import _if_legacy_apply_global, accepts_legacy_random_state +from .._utils.random import _accepts_legacy_random_state, _if_legacy_apply_global from ..get import _check_mask, _get_obs_rep, _set_obs_rep from ._distributed import materialize_as_ndarray @@ -972,7 +972,7 @@ def sample( # noqa: PLR0912 return subset, indices -@accepts_legacy_random_state(0) +@_accepts_legacy_random_state(0) def downsample_counts( adata: AnnData, counts_per_cell: int | Collection[int] | None = None, diff --git a/src/scanpy/preprocessing/_utils.py b/src/scanpy/preprocessing/_utils.py index 3da0fed2d5..87d2402ad1 100644 --- a/src/scanpy/preprocessing/_utils.py +++ b/src/scanpy/preprocessing/_utils.py @@ -5,7 +5,7 @@ import numpy as np from sklearn.random_projection import sample_without_replacement -from .._utils.random import legacy_random_state +from .._utils.random import _legacy_random_state if TYPE_CHECKING: from typing import Literal @@ -23,7 +23,7 @@ def sample_comb( ] = "auto", ) -> NDArray[np.int64]: """Randomly sample indices from a grid, without repeating the same tuple.""" - random_state = legacy_random_state(rng) + random_state = _legacy_random_state(rng) idx = sample_without_replacement( np.prod(dims), nsamp, random_state=random_state, method=method ) diff --git a/src/scanpy/tools/_diffmap.py b/src/scanpy/tools/_diffmap.py index 878f0d1d1d..b07f4e3d20 100644 --- a/src/scanpy/tools/_diffmap.py +++ b/src/scanpy/tools/_diffmap.py @@ -4,7 +4,7 @@ import numpy as np -from .._utils.random import accepts_legacy_random_state +from .._utils.random import _accepts_legacy_random_state from ._dpt import _diffmap if TYPE_CHECKING: @@ -13,7 +13,7 @@ from .._utils.random import RNGLike, SeedLike -@accepts_legacy_random_state(0) +@_accepts_legacy_random_state(0) 
def diffmap( adata: AnnData, n_comps: int = 15, diff --git a/src/scanpy/tools/_draw_graph.py b/src/scanpy/tools/_draw_graph.py index 01c9b151a5..6c019dd6fb 100644 --- a/src/scanpy/tools/_draw_graph.py +++ b/src/scanpy/tools/_draw_graph.py @@ -9,10 +9,10 @@ from .. import logging as logg from .._utils import _choose_graph, get_literal_vals from .._utils.random import ( + _accepts_legacy_random_state, _if_legacy_apply_global, - accepts_legacy_random_state, - legacy_random_state, - set_igraph_rng, + _legacy_random_state, + _set_igraph_rng, ) from ._utils import get_init_pos_from_paga @@ -28,7 +28,7 @@ type _Layout = Literal["fr", "drl", "kk", "grid_fr", "lgl", "rt", "rt_circular", "fa"] -@accepts_legacy_random_state(0) +@_accepts_legacy_random_state(0) def draw_graph( # noqa: PLR0913 adata: AnnData, layout: _Layout = "fa", @@ -144,7 +144,7 @@ def draw_graph( # noqa: PLR0913 positions = np.array(fa2_positions(adjacency, init_coords, **kwds)) else: g = _utils.get_igraph_from_adjacency(adjacency) - with set_igraph_rng(rng): + with _set_igraph_rng(rng): if layout in {"fr", "drl", "kk", "grid_fr"}: ig_layout = g.layout(layout, seed=init_coords.tolist(), **kwds) elif "rt" in layout: @@ -156,7 +156,7 @@ def draw_graph( # noqa: PLR0913 positions = np.array(ig_layout.coords) adata.uns["draw_graph"] = {} adata.uns["draw_graph"]["params"] = dict( - layout=layout, random_state=legacy_random_state(rng) + layout=layout, random_state=_legacy_random_state(rng) ) key_added = f"X_draw_graph_{key_added_ext or layout}" adata.obsm[key_added] = positions diff --git a/src/scanpy/tools/_leiden.py b/src/scanpy/tools/_leiden.py index f43723c6e0..a5c9d493b1 100644 --- a/src/scanpy/tools/_leiden.py +++ b/src/scanpy/tools/_leiden.py @@ -11,9 +11,9 @@ from .._compat import warn from .._utils import _doc_params from .._utils.random import ( - accepts_legacy_random_state, - legacy_random_state, - set_igraph_rng, + _accepts_legacy_random_state, + _legacy_random_state, + _set_igraph_rng, ) from ._docs 
import ( doc_adata, @@ -48,7 +48,7 @@ neighbors_key=doc_neighbors_key.format(method="leiden"), obsp=doc_obsp, ) -@accepts_legacy_random_state(0) +@_accepts_legacy_random_state(0) def leiden( # noqa: PLR0913 adata: AnnData, resolution: float = 1, @@ -169,7 +169,7 @@ def leiden( # noqa: PLR0913 partition_type = leidenalg.RBConfigurationVertexPartition if use_weights: clustering_args["weights"] = np.array(g.es["weight"]).astype(np.float64) - clustering_args["seed"] = legacy_random_state(rng) + clustering_args["seed"] = _legacy_random_state(rng) part = cast( "MutableVertexPartition", leidenalg.find_partition(g, partition_type, **clustering_args), @@ -181,7 +181,7 @@ def leiden( # noqa: PLR0913 if resolution is not None: clustering_args["resolution"] = resolution clustering_args.setdefault("objective_function", "modularity") - with set_igraph_rng(rng): + with _set_igraph_rng(rng): part = g.community_leiden(**clustering_args) # store output into adata.obs groups = np.array(part.membership) @@ -204,7 +204,7 @@ def leiden( # noqa: PLR0913 adata.uns[key_added] = {} adata.uns[key_added]["params"] = dict( resolution=resolution, - random_state=legacy_random_state(rng), + random_state=_legacy_random_state(rng), n_iterations=n_iterations, ) adata.uns[key_added]["modularity"] = part.modularity diff --git a/src/scanpy/tools/_score_genes.py b/src/scanpy/tools/_score_genes.py index 97bd90dc8f..ae27a02904 100644 --- a/src/scanpy/tools/_score_genes.py +++ b/src/scanpy/tools/_score_genes.py @@ -10,7 +10,7 @@ from .. 
import logging as logg from .._compat import CSBase from .._utils import check_use_raw, is_backed_type -from .._utils.random import _if_legacy_apply_global, accepts_legacy_random_state +from .._utils.random import _accepts_legacy_random_state, _if_legacy_apply_global from ..get import _get_obs_rep if TYPE_CHECKING: @@ -51,7 +51,7 @@ def _sparse_nanmean(x: CSBase, /, axis: Literal[0, 1]) -> NDArray[np.float64]: return m -@accepts_legacy_random_state(0) +@_accepts_legacy_random_state(0) def score_genes( # noqa: PLR0913 adata: AnnData, gene_list: Sequence[str] | pd.Index[str], diff --git a/src/scanpy/tools/_tsne.py b/src/scanpy/tools/_tsne.py index f18e954dc8..e8d9b4414f 100644 --- a/src/scanpy/tools/_tsne.py +++ b/src/scanpy/tools/_tsne.py @@ -6,7 +6,7 @@ from .._compat import warn from .._settings import settings from .._utils import _doc_params, raise_not_implemented_error_if_backed_type -from .._utils.random import accepts_legacy_random_state, legacy_random_state +from .._utils.random import _accepts_legacy_random_state, _legacy_random_state from ..neighbors._doc import doc_n_pcs, doc_use_rep from ._utils import _choose_representation @@ -16,7 +16,7 @@ from .._utils.random import RNGLike, SeedLike -@accepts_legacy_random_state(0) +@_accepts_legacy_random_state(0) @_doc_params(doc_n_pcs=doc_n_pcs, use_rep=doc_use_rep) def tsne( # noqa: PLR0913 adata: AnnData, @@ -110,7 +110,7 @@ def tsne( # noqa: PLR0913 n_jobs = settings.n_jobs if n_jobs is None else n_jobs params_sklearn = dict( perplexity=perplexity, - random_state=legacy_random_state(rng), + random_state=_legacy_random_state(rng), verbose=settings.verbosity > 3, early_exaggeration=early_exaggeration, learning_rate=learning_rate, diff --git a/src/scanpy/tools/_umap.py b/src/scanpy/tools/_umap.py index 8d46298ec5..ba59c2712f 100644 --- a/src/scanpy/tools/_umap.py +++ b/src/scanpy/tools/_umap.py @@ -10,7 +10,7 @@ from .._compat import warn from .._settings import settings from .._utils import NeighborsView -from 
.._utils.random import accepts_legacy_random_state, legacy_random_state +from .._utils.random import _accepts_legacy_random_state, _legacy_random_state from ._utils import _choose_representation, get_init_pos_from_paga if TYPE_CHECKING: @@ -24,7 +24,7 @@ type _InitPos = Literal["paga", "spectral", "random"] -@accepts_legacy_random_state(0) +@_accepts_legacy_random_state(0) def umap( # noqa: PLR0913, PLR0915 adata: AnnData, *, @@ -186,7 +186,7 @@ def umap( # noqa: PLR0913, PLR0915 init_coords = check_array(init_coords, dtype=np.float32, accept_sparse=False) if rng is not None: - adata.uns[key_uns]["params"]["random_state"] = legacy_random_state(rng) + adata.uns[key_uns]["params"]["random_state"] = _legacy_random_state(rng) neigh_params = neighbors["params"] x = _choose_representation( @@ -211,7 +211,7 @@ def umap( # noqa: PLR0913, PLR0915 negative_sample_rate=negative_sample_rate, n_epochs=n_epochs, init=init_coords, - random_state=legacy_random_state(rng, always_state=True), + random_state=_legacy_random_state(rng, always_state=True), metric=neigh_params.get("metric", "euclidean"), metric_kwds=neigh_params.get("metric_kwds", {}), densmap=False, @@ -251,7 +251,7 @@ def umap( # noqa: PLR0913, PLR0915 a=a, b=b, verbose=settings.verbosity > 3, - random_state=legacy_random_state(rng), + random_state=_legacy_random_state(rng), ) x_umap = umap.fit_transform(x_contiguous) adata.obsm[key_obsm] = x_umap # annotate samples with UMAP coordinates