diff --git a/docs/release-notes/3983.feat.md b/docs/release-notes/3983.feat.md new file mode 100644 index 0000000000..ca67baba94 --- /dev/null +++ b/docs/release-notes/3983.feat.md @@ -0,0 +1 @@ +Add support for {class}`numpy.random.Generator` to all functions previously accepting a `random_state` parameter {smaller}`P Angerer` diff --git a/pyproject.toml b/pyproject.toml index 134c1f76d7..8f3cc6d91f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -278,6 +278,8 @@ filterwarnings = [ "ignore:The `igraph` implementation of leiden clustering:UserWarning", # everybody uses this zarr 3 feature, including us, XArray, lots of data out there … "ignore:Consolidated metadata is currently not part:UserWarning", + # joblib fallback to serial mode in restricted multiprocessing environments + "ignore:.*joblib will operate in serial mode:UserWarning", ] [tool.coverage] diff --git a/src/scanpy/_utils/random.py b/src/scanpy/_utils/random.py index 98d6ce8a1b..baf0c8fdc3 100644 --- a/src/scanpy/_utils/random.py +++ b/src/scanpy/_utils/random.py @@ -7,12 +7,12 @@ from typing import TYPE_CHECKING import numpy as np -from sklearn.utils import check_random_state from . import ensure_igraph if TYPE_CHECKING: from collections.abc import Generator + from typing import Self from numpy.typing import NDArray @@ -21,8 +21,10 @@ "RNGLike", "SeedLike", "_LegacyRandom", + "_if_legacy_apply_global", + "accepts_legacy_random_state", "ith_k_tuple", - "legacy_numpy_gen", + "legacy_random_state", "random_k_tuples", "random_str", ] @@ -38,34 +40,38 @@ class _RNGIgraph: - """Random number generator for igraph so global seed is not changed. + """Random number generator for igraph so global random state is not changed. See :func:`igraph.set_random_number_generator` for the requirements. 
""" - def __init__(self, random_state: int | np.random.RandomState = 0) -> None: - self._rng = check_random_state(random_state) + def __init__(self, rng: SeedLike | RNGLike | None) -> None: + self._rng = np.random.default_rng(rng) def getrandbits(self, k: int) -> int: - return self._rng.tomaxint() & ((1 << k) - 1) + if isinstance(self._rng, _FakeRandomGen): + i = self._rng._state.tomaxint() + else: + lims = np.iinfo(np.uint64) + i = int(self._rng.integers(0, lims.max, dtype=np.uint64, endpoint=True)) + return i & ((1 << k) - 1) - def randint(self, a: int, b: int) -> int: - return self._rng.randint(a, b + 1) + def randint(self, a: int, b: int) -> np.int64: + """Can’t use `endpoint` here as _FakeRandomGen doesn’t support it.""" + return self._rng.integers(a, b + 1) def __getattr__(self, attr: str): return getattr(self._rng, "normal" if attr == "gauss" else attr) @contextmanager -def set_igraph_random_state( - random_state: int | np.random.RandomState, -) -> Generator[None, None, None]: +def set_igraph_rng(rng: SeedLike | RNGLike | None) -> Generator[None]: ensure_igraph() import igraph - rng = _RNGIgraph(random_state) + ig_rng = _RNGIgraph(rng) try: - igraph.set_random_number_generator(rng) + igraph.set_random_number_generator(ig_rng) yield None finally: igraph.set_random_number_generator(random) @@ -76,26 +82,42 @@ def set_igraph_random_state( ################################### -def legacy_numpy_gen( - random_state: _LegacyRandom | None = None, -) -> np.random.Generator: - """Return a random generator that behaves like the legacy one.""" - if random_state is not None: - if isinstance(random_state, np.random.RandomState): - np.random.set_state(random_state.get_state(legacy=False)) - return _FakeRandomGen(random_state) - np.random.seed(random_state) - return _FakeRandomGen(np.random.RandomState(np.random.get_bit_generator())) - - class _FakeRandomGen(np.random.Generator): + _arg: _LegacyRandom _state: np.random.RandomState - def __init__(self, random_state: 
np.random.RandomState) -> None:
-        self._state = random_state
+    def __init__(
+        self, arg: _LegacyRandom, state: np.random.RandomState | None = None
+    ) -> None:
+        self._arg = arg
+        self._state = state if state is not None else (arg if isinstance(arg, np.random.RandomState) else np.random.RandomState(arg))
+        super().__init__(self._state._bit_generator)
+
+    @classmethod
+    def wrap_global(
+        cls,
+        arg: _LegacyRandom = None,
+        state: np.random.RandomState | None = None,
+    ) -> Self:
+        """Create a generator that wraps the global `RandomState` backing the legacy `np.random` functions."""
+        if arg is not None:
+            if isinstance(arg, np.random.RandomState):
+                np.random.set_state(arg.get_state(legacy=False))
+                return _FakeRandomGen(arg, state)
+            np.random.seed(arg)
+        return _FakeRandomGen(arg, np.random.RandomState(np.random.get_bit_generator()))
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, _FakeRandomGen):
+            return False
+        return self._arg == other._arg
+
+    def __hash__(self) -> int:
+        return hash((type(self), self._arg))

     @classmethod
     def _delegate(cls) -> None:
+        names = dict(integers="randint")
         for name, meth in np.random.Generator.__dict__.items():
             if name.startswith("_") or not callable(meth):
                 continue
@@ -108,12 +130,66 @@
             def wrapper(self: _FakeRandomGen, *args, **kwargs):

                 return wrapper

-            setattr(cls, name, mk_wrapper(name, meth))
+            setattr(cls, names.get(name, name), mk_wrapper(name, meth))


 _FakeRandomGen._delegate()


+def _if_legacy_apply_global(rng: np.random.Generator) -> np.random.Generator:
+    """Re-apply legacy `random_state` semantics when `rng` is a `_FakeRandomGen`.
+
+    This resets the global legacy RNG from the original `_arg` and returns a
+    generator which continues drawing from the same internal state.
+ """ + if not isinstance(rng, _FakeRandomGen): + return rng + + return _FakeRandomGen.wrap_global(rng._arg, rng._state) + + +def legacy_random_state( + rng: SeedLike | RNGLike | None, *, always_state: bool = False +) -> _LegacyRandom: + """Convert a np.random.Generator into a legacy `random_state` argument. + + If `rng` is already a `_FakeRandomGen`, return its original `_arg` attribute. + """ + if isinstance(rng, _FakeRandomGen): + return rng._state if always_state else rng._arg + rng = np.random.default_rng(rng) + return np.random.RandomState(rng.bit_generator.spawn(1)[0]) + + +def accepts_legacy_random_state[**P, R]( + random_state_default: _LegacyRandom, +) -> callable[[callable[P, R]], callable[P, R]]: + """Make a function accept `random_state: _LegacyRandom` and pass it as `rng`. + + If the decorated function is called with a `random_state` argument, + it’ll be wrapped in a :class:`_FakeRandomGen`. + Passing both ``rng`` and ``random_state`` at the same time is an error. + If neither is given, ``random_state_default`` is used. + """ + + def decorator(func: callable[P, R]) -> callable[P, R]: + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + match "random_state" in kwargs, "rng" in kwargs: + case True, True: + msg = "Specify at most one of `rng` and `random_state`." 
+ raise TypeError(msg) + case True, False: + kwargs["rng"] = _FakeRandomGen(kwargs.pop("random_state")) + case False, False: + kwargs["rng"] = _FakeRandomGen(random_state_default) + return func(*args, **kwargs) + + return wrapper + + return decorator + + ################### # Random k-tuples # ################### diff --git a/src/scanpy/datasets/_datasets.py b/src/scanpy/datasets/_datasets.py index 1a81939301..5b56be0fad 100644 --- a/src/scanpy/datasets/_datasets.py +++ b/src/scanpy/datasets/_datasets.py @@ -12,13 +12,14 @@ from .._compat import deprecated from .._settings import settings from .._utils._doctests import doctest_internet, doctest_needs +from .._utils.random import accepts_legacy_random_state, legacy_random_state from ..readwrite import read, read_h5ad, read_visium from ._utils import check_datasetdir_exists if TYPE_CHECKING: from typing import Literal - from .._utils.random import _LegacyRandom + from .._utils.random import RNGLike, SeedLike type VisiumSampleID = Literal[ "V1_Breast_Cancer_Block_A_Section_1", @@ -54,13 +55,14 @@ HERE = Path(__file__).parent +@accepts_legacy_random_state(0) def blobs( *, n_variables: int = 11, n_centers: int = 5, cluster_std: float = 1.0, n_observations: int = 640, - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, ) -> AnnData: """Gaussian Blobs. @@ -75,7 +77,7 @@ def blobs( n_observations Number of observations. By default, this is the same observation number as in :func:`scanpy.datasets.krumsiek11`. - random_state + rng Determines random number generation for dataset creation. 
Returns @@ -98,7 +100,7 @@ def blobs( n_features=n_variables, centers=n_centers, cluster_std=cluster_std, - random_state=random_state, + random_state=legacy_random_state(rng), ) return AnnData(x, obs=dict(blobs=y.astype(str))) diff --git a/src/scanpy/experimental/_docs.py b/src/scanpy/experimental/_docs.py index c6f1bf2f8b..a449c317f1 100644 --- a/src/scanpy/experimental/_docs.py +++ b/src/scanpy/experimental/_docs.py @@ -60,8 +60,8 @@ doc_pca_chunk = """\ n_comps Number of principal components to compute in the PCA step. -random_state - Random seed for setting the initial states for the optimization in the PCA step. +rng + Random number generator for setting the initial states for the optimization in the PCA step. kwargs_pca Dictionary of further keyword arguments passed on to `scanpy.pp.pca()`. """ diff --git a/src/scanpy/experimental/pp/_normalization.py b/src/scanpy/experimental/pp/_normalization.py index cb34b9902b..9cd2d6d092 100644 --- a/src/scanpy/experimental/pp/_normalization.py +++ b/src/scanpy/experimental/pp/_normalization.py @@ -9,6 +9,7 @@ from ... 
import logging as logg from ..._compat import CSBase, warn from ..._utils import _doc_params, _empty, check_nonnegative_integers, view_to_actual +from ..._utils.random import accepts_legacy_random_state from ...experimental._docs import ( doc_adata, doc_check_values, @@ -27,6 +28,7 @@ from typing import Any from ..._utils import Empty + from ..._utils.random import RNGLike, SeedLike def _pearson_residuals( @@ -160,13 +162,14 @@ def normalize_pearson_residuals( check_values=doc_check_values, inplace=doc_inplace, ) +@accepts_legacy_random_state(0) def normalize_pearson_residuals_pca( adata: AnnData, *, theta: float = 100, clip: float | None = None, n_comps: int | None = 50, - random_state: float = 0, + rng: SeedLike | RNGLike | None = None, kwargs_pca: Mapping[str, Any] = MappingProxyType({}), mask_var: np.ndarray | str | None | Empty = _empty, use_highly_variable: bool | None = None, @@ -233,7 +236,7 @@ def normalize_pearson_residuals_pca( normalize_pearson_residuals( adata_pca, theta=theta, clip=clip, check_values=check_values ) - pca(adata_pca, n_comps=n_comps, random_state=random_state, **kwargs_pca) + pca(adata_pca, n_comps=n_comps, rng=rng, **kwargs_pca) n_comps = adata_pca.obsm["X_pca"].shape[1] # might be None if inplace: diff --git a/src/scanpy/experimental/pp/_recipes.py b/src/scanpy/experimental/pp/_recipes.py index 27d272fc4d..b688b8d4d8 100644 --- a/src/scanpy/experimental/pp/_recipes.py +++ b/src/scanpy/experimental/pp/_recipes.py @@ -17,6 +17,8 @@ ) from scanpy.preprocessing import pca +from ..._utils.random import accepts_legacy_random_state + if TYPE_CHECKING: from collections.abc import Mapping from typing import Any @@ -24,6 +26,8 @@ import pandas as pd from anndata import AnnData + from ..._utils.random import RNGLike, SeedLike + @_doc_params( adata=doc_adata, @@ -33,6 +37,7 @@ check_values=doc_check_values, inplace=doc_inplace, ) +@accepts_legacy_random_state(0) def recipe_pearson_residuals( # noqa: PLR0913 adata: AnnData, *, @@ -42,7 +47,7 @@ 
def recipe_pearson_residuals( # noqa: PLR0913 batch_key: str | None = None, chunksize: int = 1000, n_comps: int | None = 50, - random_state: float | None = 0, + rng: SeedLike | RNGLike | None = None, kwargs_pca: Mapping[str, Any] = MappingProxyType({}), check_values: bool = True, inplace: bool = True, @@ -133,7 +138,7 @@ def recipe_pearson_residuals( # noqa: PLR0913 experimental.pp.normalize_pearson_residuals( adata_pca, theta=theta, clip=clip, check_values=check_values ) - pca(adata_pca, n_comps=n_comps, random_state=random_state, **kwargs_pca) + pca(adata_pca, n_comps=n_comps, rng=rng, **kwargs_pca) if inplace: normalization_param = adata_pca.uns["pearson_residuals_normalization"] diff --git a/src/scanpy/neighbors/__init__.py b/src/scanpy/neighbors/__init__.py index 87314140d5..853695c30b 100644 --- a/src/scanpy/neighbors/__init__.py +++ b/src/scanpy/neighbors/__init__.py @@ -11,13 +11,17 @@ import numpy as np import scipy from scipy import sparse -from sklearn.utils import check_random_state from .. import _utils from .. import logging as logg from .._compat import CSBase, CSRBase, SpBase, warn from .._settings import settings from .._utils import NeighborsView, _doc_params, get_literal_vals +from .._utils.random import ( + _FakeRandomGen, + accepts_legacy_random_state, + legacy_random_state, +) from . 
import _connectivity from ._common import ( _get_indices_distances_from_dense_matrix, @@ -36,7 +40,7 @@ from igraph import Graph from numpy.typing import NDArray - from .._utils.random import _LegacyRandom + from .._utils.random import RNGLike, SeedLike, _LegacyRandom from ._types import KnnTransformerLike, _Metric, _MetricFn # TODO: make `type` when https://github.com/sphinx-doc/sphinx/pull/13508 is released @@ -78,6 +82,7 @@ class NeighborsParams(TypedDict): # noqa: D101 @_doc_params(n_pcs=doc_n_pcs, use_rep=doc_use_rep) +@accepts_legacy_random_state(0) def neighbors( # noqa: PLR0913 adata: AnnData, n_neighbors: int = 15, @@ -90,7 +95,7 @@ def neighbors( # noqa: PLR0913 transformer: KnnTransformerLike | _KnownTransformer | None = None, metric: _Metric | _MetricFn | None = None, metric_kwds: Mapping[str, Any] = MappingProxyType({}), - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, key_added: str | None = None, copy: bool = False, ) -> AnnData | None: @@ -158,8 +163,8 @@ def neighbors( # noqa: PLR0913 Options for the metric. *ignored if ``transformer`` is an instance.* - random_state - A numpy random seed. + rng + A numpy random number generator. 
*ignored if ``transformer`` is an instance.* key_added @@ -220,21 +225,24 @@ def neighbors( # noqa: PLR0913 transformer=transformer, metric=metric, metric_kwds=metric_kwds, - random_state=random_state, + rng=rng, ) else: params = locals() - if ignored := { + ignored = { p.name for p in signature(neighbors).parameters.values() - if p.name in {"use_rep", "knn", "n_pcs", "metric_kwds", "random_state"} + if p.name in {"use_rep", "knn", "n_pcs", "metric_kwds"} if params[p.name] != p.default - }: + } + if not isinstance(rng, _FakeRandomGen) or rng._arg != 0: + ignored.add("rng/random_state") + rng = _FakeRandomGen(0) + if ignored: warn( f"Parameter(s) ignored if `distances` is given: {ignored}", UserWarning, ) - random_state = 0 if callable(metric): msg = "`metric` must be a string if `distances` is given." raise TypeError(msg) @@ -262,7 +270,7 @@ def neighbors( # noqa: PLR0913 key_added, n_neighbors=neighbors_.n_neighbors, method=method, - random_state=random_state, + random_state=legacy_random_state(rng), metric=metric, **({} if not metric_kwds else dict(metric_kwds=metric_kwds)), **({} if use_rep is None else dict(use_rep=use_rep)), @@ -571,6 +579,7 @@ def to_igraph(self) -> Graph: return _utils.get_igraph_from_adjacency(self.connectivities) @_doc_params(n_pcs=doc_n_pcs, use_rep=doc_use_rep) + @accepts_legacy_random_state(0) def compute_neighbors( self, n_neighbors: int = 30, @@ -582,7 +591,7 @@ def compute_neighbors( transformer: KnnTransformerLike | _KnownTransformer | None = None, metric: _Metric | _MetricFn = "euclidean", metric_kwds: Mapping[str, Any] = MappingProxyType({}), - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, ) -> None: """Compute distances and connectivities of neighbors. 
@@ -618,7 +627,7 @@ def compute_neighbors( n_neighbors=n_neighbors, metric=metric, metric_params=metric_kwds, # most use _params, not _kwds - random_state=random_state, + random_state=legacy_random_state(rng), ) method, transformer, shortcut = self._handle_transformer( method, transformer, knn=knn, kwds=transformer_kwds_default @@ -840,13 +849,14 @@ def compute_transitions(self, *, density_normalize: bool = True) -> None: self._transitions_sym = self.Z @ conn_norm @ self.Z logg.info(" finished", time=start) + @accepts_legacy_random_state(0) def compute_eigen( self, *, n_comps: int = 15, sym: bool | None = None, sort: Literal["decrease", "increase"] = "decrease", - random_state: _LegacyRandom = 0, + rng: np.random.Generator, ): """Compute eigen decomposition of transition matrix. @@ -859,8 +869,8 @@ def compute_eigen( Instead of computing the eigendecomposition of the assymetric transition matrix, computed the eigendecomposition of the symmetric Ktilde matrix. - random_state - A numpy random seed + rng + A numpy random number generator Returns ------- @@ -893,8 +903,7 @@ def compute_eigen( matrix = matrix.astype(np.float64) # Setting the random initial vector - random_state = check_random_state(random_state) - v0 = random_state.standard_normal(matrix.shape[0]) + v0 = rng.standard_normal(matrix.shape[0]) evals, evecs = sparse.linalg.eigsh( matrix, k=n_comps, which=which, ncv=ncv, v0=v0 ) diff --git a/src/scanpy/preprocessing/_deprecated/sampling.py b/src/scanpy/preprocessing/_deprecated/sampling.py index b153b9489f..2663462a68 100644 --- a/src/scanpy/preprocessing/_deprecated/sampling.py +++ b/src/scanpy/preprocessing/_deprecated/sampling.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING -from ..._utils.random import legacy_numpy_gen +from ..._utils.random import _FakeRandomGen from .._simple import sample if TYPE_CHECKING: @@ -50,7 +50,7 @@ def subsample( returns a subsampled copy of it (`copy == True`). 
""" - rng = legacy_numpy_gen(random_state) + rng = _FakeRandomGen.wrap_global(random_state) return sample( data=data, fraction=fraction, n=n_obs, rng=rng, copy=copy, replace=False, axis=0 ) diff --git a/src/scanpy/preprocessing/_pca/__init__.py b/src/scanpy/preprocessing/_pca/__init__.py index 5c6f3b7bf2..98f27aa030 100644 --- a/src/scanpy/preprocessing/_pca/__init__.py +++ b/src/scanpy/preprocessing/_pca/__init__.py @@ -5,12 +5,16 @@ import numpy as np from anndata import AnnData from packaging.version import Version -from sklearn.utils import check_random_state from ... import logging as logg from ..._compat import CSBase, DaskArray, pkg_version, warn from ..._settings import settings from ..._utils import _doc_params, _empty, get_literal_vals, is_backed_type +from ..._utils.random import ( + _FakeRandomGen, + accepts_legacy_random_state, + legacy_random_state, +) from ...get import _check_mask, _get_obs_rep from .._docs import doc_mask_var_hvg from ._compat import _pca_compat_sparse @@ -25,7 +29,7 @@ from numpy.typing import DTypeLike, NDArray from ..._utils import Empty - from ..._utils.random import _LegacyRandom + from ..._utils.random import RNGLike, SeedLike type MethodDaskML = type[dmld.PCA | dmld.IncrementalPCA | dmld.TruncatedSVD] @@ -54,6 +58,7 @@ @_doc_params( mask_var_hvg=doc_mask_var_hvg, ) +@accepts_legacy_random_state(0) def pca( # noqa: PLR0912, PLR0913, PLR0915 data: AnnData | np.ndarray | CSBase, n_comps: int | None = None, @@ -64,7 +69,7 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 svd_solver: SvdSolver | None = None, chunked: bool = False, chunk_size: int | None = None, - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, return_info: bool = False, mask_var: NDArray[np.bool] | str | None | Empty = _empty, use_highly_variable: bool | None = None, @@ -157,7 +162,7 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 chunk_size Number of observations to include in each chunk. Required if `chunked=True` was passed. 
-    random_state
+    rng
         Change to use different initial states for the optimization.
     return_info
         Only relevant when not passing an :class:`~anndata.AnnData`:
@@ -241,21 +246,23 @@ def pca(  # noqa: PLR0912, PLR0913, PLR0915
         msg = f"PCA is not implemented for matrices of type {type(x)} from layers/obsm"
         raise NotImplementedError(msg)

-    # check_random_state returns a numpy RandomState when passed an int but
     # dask needs an int for random state
-    if not isinstance(x, DaskArray):
-        random_state = check_random_state(random_state)
-    elif not isinstance(random_state, int):
-        msg = f"random_state needs to be an int, not a {type(random_state).__name__} when passing a dask array"
+    rng = np.random.default_rng(rng)
+    if isinstance(x, DaskArray) and not (
+        isinstance(rng, _FakeRandomGen) and isinstance(rng._arg, int | np.random.RandomState)
+    ):
+        # TODO: remove this error and if we don’t have a _FakeRandomGen,
+        # just use rng.integers to make a seed farther down
+        msg = f"rng needs to be an int or a np.random.RandomState, not a {type(rng).__name__} when passing a dask array"
         raise TypeError(msg)

     if chunked:
         if (
             not zero_center
-            or random_state
+            or (not isinstance(rng, _FakeRandomGen) or rng._arg != 0)
             or (svd_solver is not None and svd_solver != "arpack")
         ):
-            logg.debug("Ignoring zero_center, random_state, svd_solver")
+            logg.debug("Ignoring zero_center, rng, svd_solver")

         incremental_pca_kwargs = dict()
         if isinstance(x, DaskArray):
@@ -287,9 +294,7 @@ def pca(  # noqa: PLR0912, PLR0913, PLR0915
                 "Also the lobpcg solver has been observed to be inaccurate. Please use 'arpack' instead."
) warn(msg, FutureWarning) - x_pca, pca_ = _pca_compat_sparse( - x, n_comps, solver=svd_solver, random_state=random_state - ) + x_pca, pca_ = _pca_compat_sparse(x, n_comps, solver=svd_solver, rng=rng) else: if not isinstance(x, DaskArray): from sklearn.decomposition import PCA @@ -300,13 +305,13 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 pca_ = PCA( n_components=n_comps, svd_solver=svd_solver, - random_state=random_state, + random_state=legacy_random_state(rng), ) elif isinstance(x._meta, CSBase) or svd_solver == "covariance_eigh": from ._dask import PCAEighDask - if random_state != 0: - msg = f"Ignoring {random_state=} when using a sparse dask array" + if not isinstance(rng, _FakeRandomGen) or rng._arg != 0: + msg = f"Ignoring rng={legacy_random_state(rng)} when using a sparse dask array" warn(msg, UserWarning) if svd_solver not in {None, "covariance_eigh"}: msg = f"Ignoring {svd_solver=} when using a sparse dask array" @@ -319,7 +324,7 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 pca_ = PCA( n_components=n_comps, svd_solver=svd_solver, - random_state=random_state, + random_state=legacy_random_state(rng), ) x_pca = pca_.fit_transform(x) else: @@ -345,7 +350,9 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 " the following components often resemble the exact PCA very closely" ) pca_ = TruncatedSVD( - n_components=n_comps, random_state=random_state, algorithm=svd_solver + n_components=n_comps, + random_state=legacy_random_state(rng), + algorithm=svd_solver, ) x_pca = pca_.fit_transform(x) diff --git a/src/scanpy/preprocessing/_pca/_compat.py b/src/scanpy/preprocessing/_pca/_compat.py index b1c64e0735..f1f4e33b0d 100644 --- a/src/scanpy/preprocessing/_pca/_compat.py +++ b/src/scanpy/preprocessing/_pca/_compat.py @@ -6,10 +6,11 @@ from fast_array_utils.stats import mean_var from packaging.version import Version from scipy.sparse.linalg import LinearOperator, svds -from sklearn.utils import check_array, check_random_state +from sklearn.utils import check_array from 
sklearn.utils.extmath import svd_flip from ..._compat import pkg_version +from ..._utils.random import accepts_legacy_random_state, legacy_random_state if TYPE_CHECKING: from typing import Literal @@ -18,21 +19,21 @@ from sklearn.decomposition import PCA from ..._compat import CSBase - from ..._utils.random import _LegacyRandom + from ..._utils.random import RNGLike, SeedLike +@accepts_legacy_random_state(None) def _pca_compat_sparse( x: CSBase, n_pcs: int, *, solver: Literal["arpack", "lobpcg"], mu: NDArray[np.floating] | None = None, - random_state: _LegacyRandom = None, + rng: SeedLike | RNGLike | None = None, ) -> tuple[NDArray[np.floating], PCA]: """Sparse PCA for scikit-learn <1.4.""" - random_state = check_random_state(random_state) - np.random.set_state(random_state.get_state()) - random_init = np.random.rand(np.min(x.shape)) + rng = np.random.default_rng(rng) + random_init = rng.uniform(size=np.min(x.shape)) x = check_array(x, accept_sparse=["csr", "csc"]) if mu is None: @@ -70,7 +71,9 @@ def rmat_op(v: NDArray[np.floating]): from sklearn.decomposition import PCA - pca = PCA(n_components=n_pcs, svd_solver=solver, random_state=random_state) + pca = PCA( + n_components=n_pcs, svd_solver=solver, random_state=legacy_random_state(rng) + ) pca.explained_variance_ = ev pca.explained_variance_ratio_ = ev_ratio pca.components_ = v diff --git a/src/scanpy/preprocessing/_recipes.py b/src/scanpy/preprocessing/_recipes.py index 7ac262b166..179a9ebdbe 100644 --- a/src/scanpy/preprocessing/_recipes.py +++ b/src/scanpy/preprocessing/_recipes.py @@ -7,6 +7,7 @@ from .. import logging as logg from .. 
import preprocessing as pp from .._compat import CSBase +from .._utils.random import accepts_legacy_random_state from ._deprecated.highly_variable_genes import ( filter_genes_cv_deprecated, filter_genes_dispersion, @@ -16,9 +17,10 @@ if TYPE_CHECKING: from anndata import AnnData - from .._utils.random import _LegacyRandom + from .._utils.random import RNGLike, SeedLike +@accepts_legacy_random_state(0) def recipe_weinreb17( adata: AnnData, *, @@ -27,7 +29,7 @@ def recipe_weinreb17( cv_threshold: int = 2, n_pcs: int = 50, svd_solver="randomized", - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, copy: bool = False, ) -> AnnData | None: """Normalize and filter as of :cite:p:`Weinreb2017`. @@ -63,7 +65,7 @@ def recipe_weinreb17( zscore_deprecated(adata.X), n_comps=n_pcs, svd_solver=svd_solver, - random_state=random_state, + rng=rng, ) # update adata adata.obsm["X_pca"] = x_pca diff --git a/src/scanpy/preprocessing/_scrublet/__init__.py b/src/scanpy/preprocessing/_scrublet/__init__.py index a5a080e814..b9215ae163 100644 --- a/src/scanpy/preprocessing/_scrublet/__init__.py +++ b/src/scanpy/preprocessing/_scrublet/__init__.py @@ -10,15 +10,17 @@ from ... import logging as logg from ... import preprocessing as pp +from ..._utils.random import accepts_legacy_random_state, legacy_random_state from ...get import _get_obs_rep from . import pipeline from .core import Scrublet if TYPE_CHECKING: - from ..._utils.random import _LegacyRandom + from ..._utils.random import RNGLike, SeedLike from ...neighbors import _Metric, _MetricFn +@accepts_legacy_random_state(0) def scrublet( # noqa: PLR0913 adata: AnnData, adata_sim: AnnData | None = None, @@ -39,7 +41,7 @@ def scrublet( # noqa: PLR0913 threshold: float | None = None, verbose: bool = True, copy: bool = False, - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, ) -> AnnData | None: """Predict doublets using Scrublet :cite:p:`Wolock2019`. 
@@ -127,7 +129,7 @@ def scrublet( # noqa: PLR0913 copy If :data:`True`, return a copy of the input ``adata`` with Scrublet results added. Otherwise, Scrublet results are added in place. - random_state + rng Initial state for doublet simulation and nearest neighbors. Returns @@ -158,6 +160,7 @@ def scrublet( # noqa: PLR0913 scores for observed transcriptomes and simulated doublets. """ + rng = np.random.default_rng(rng) if threshold is None and not find_spec("skimage"): # pragma: no cover # Scrublet.call_doublets requires `skimage` with `threshold=None` but PCA # is called early, which is wasteful if there is not `skimage` @@ -204,7 +207,7 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None): layer="raw", sim_doublet_ratio=sim_doublet_ratio, synthetic_doublet_umi_subsampling=synthetic_doublet_umi_subsampling, - random_seed=random_state, + rng=rng, ) del ad_obs.layers["raw"] if log_transform: @@ -229,7 +232,7 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None): knn_dist_metric=knn_dist_metric, get_doublet_neighbor_parents=get_doublet_neighbor_parents, threshold=threshold, - random_state=random_state, + rng=rng, verbose=verbose, ) @@ -283,22 +286,23 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None): return adata if copy else None +@accepts_legacy_random_state(0) def _scrublet_call_doublets( # noqa: PLR0913 adata_obs: AnnData, adata_sim: AnnData, *, - n_neighbors: int | None = None, - expected_doublet_rate: float = 0.05, - stdev_doublet_rate: float = 0.02, - mean_center: bool = True, - normalize_variance: bool = True, - n_prin_comps: int = 30, - use_approx_neighbors: bool | None = None, - knn_dist_metric: _Metric | _MetricFn = "euclidean", - get_doublet_neighbor_parents: bool = False, - threshold: float | None = None, - random_state: _LegacyRandom = 0, - verbose: bool = True, + n_neighbors: int | None, + expected_doublet_rate: float, + stdev_doublet_rate: float, + mean_center: bool, + normalize_variance: bool, + 
n_prin_comps: int, + use_approx_neighbors: bool | None, + knn_dist_metric: _Metric | _MetricFn, + get_doublet_neighbor_parents: bool, + threshold: float | None, + rng: np.random.Generator, + verbose: bool, ) -> AnnData: """Core function for predicting doublets using Scrublet :cite:p:`Wolock2019`. @@ -356,8 +360,8 @@ def _scrublet_call_doublets( # noqa: PLR0913 practice to check the threshold visually using the `doublet_scores_sim_` histogram and/or based on co-localization of predicted doublets in a 2-D embedding. - random_state - Initial state for doublet simulation and nearest neighbors. + rng + Random number generator for doublet simulation and nearest neighbors. verbose If :data:`True`, log progress updates. @@ -394,7 +398,7 @@ def _scrublet_call_doublets( # noqa: PLR0913 n_neighbors=n_neighbors, expected_doublet_rate=expected_doublet_rate, stdev_doublet_rate=stdev_doublet_rate, - random_state=random_state, + rng=rng, ) # Ensure normalised matrix sparseness as Scrublet does @@ -420,11 +424,13 @@ def _scrublet_call_doublets( # noqa: PLR0913 if mean_center: logg.info("Embedding transcriptomes using PCA...") - pipeline.pca(scrub, n_prin_comps=n_prin_comps, random_state=scrub._random_state) + pipeline.pca( + scrub, n_prin_comps=n_prin_comps, svd_solver="arpack", rng=scrub._rng + ) else: logg.info("Embedding transcriptomes using Truncated SVD...") pipeline.truncated_svd( - scrub, n_prin_comps=n_prin_comps, random_state=scrub._random_state + scrub, n_prin_comps=n_prin_comps, algorithm="arpack", rng=scrub._rng ) # Score the doublets @@ -457,7 +463,7 @@ def _scrublet_call_doublets( # noqa: PLR0913 .get("sim_doublet_ratio", None) ), "n_neighbors": n_neighbors, - "random_state": random_state, + "random_state": legacy_random_state(rng), }, } @@ -482,13 +488,14 @@ def _scrublet_call_doublets( # noqa: PLR0913 return adata_obs +@accepts_legacy_random_state(0) def scrublet_simulate_doublets( adata: AnnData, *, layer: str | None = None, sim_doublet_ratio: float = 2.0, 
synthetic_doublet_umi_subsampling: float = 1.0, - random_seed: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, ) -> AnnData: """Simulate doublets by adding the counts of random observed transcriptome pairs. @@ -533,7 +540,7 @@ def scrublet_simulate_doublets( """ x = _get_obs_rep(adata, layer=layer) - scrub = Scrublet(x, random_state=random_seed) + scrub = Scrublet(x, rng=rng) scrub.simulate_doublets( sim_doublet_ratio=sim_doublet_ratio, diff --git a/src/scanpy/preprocessing/_scrublet/core.py b/src/scanpy/preprocessing/_scrublet/core.py index b7eb1e4785..a11ebf2f98 100644 --- a/src/scanpy/preprocessing/_scrublet/core.py +++ b/src/scanpy/preprocessing/_scrublet/core.py @@ -7,7 +7,6 @@ import pandas as pd from anndata import AnnData, concat from scipy import sparse -from sklearn.utils import check_random_state from ... import logging as logg from ...neighbors import ( @@ -18,11 +17,10 @@ from .sparse_utils import subsample_counts if TYPE_CHECKING: - from numpy.random import RandomState from numpy.typing import NDArray from ..._compat import CSBase, CSCBase - from ..._utils.random import _LegacyRandom + from ..._utils.random import RNGLike, SeedLike from ...neighbors import _Metric, _MetricFn __all__ = ["Scrublet"] @@ -58,8 +56,8 @@ class Scrublet: stdev_doublet_rate Uncertainty in the expected doublet rate. - random_state - Random state for doublet simulation, approximate + rng + Random number generator for doublet simulation, approximate nearest neighbor search, and PCA/TruncatedSVD. 
""" @@ -72,12 +70,12 @@ class Scrublet: n_neighbors: InitVar[int | None] = None expected_doublet_rate: float = 0.1 stdev_doublet_rate: float = 0.02 - random_state: InitVar[_LegacyRandom] = 0 + rng: InitVar[SeedLike | RNGLike | None] = None # private fields _n_neighbors: int = field(init=False, repr=False) - _random_state: RandomState = field(init=False, repr=False) + _rng: np.random.Generator = field(init=False, repr=False) _counts_obs: CSCBase = field(init=False, repr=False) _total_counts_obs: NDArray[np.integer] = field(init=False, repr=False) @@ -169,7 +167,7 @@ def __post_init__( counts_obs: CSBase | NDArray[np.integer], total_counts_obs: NDArray[np.integer] | None, n_neighbors: int | None, - random_state: _LegacyRandom, + rng: SeedLike | RNGLike | None, ) -> None: self._counts_obs = sparse.csc_matrix(counts_obs) # noqa: TID251 self._total_counts_obs = ( @@ -182,7 +180,7 @@ def __post_init__( if n_neighbors is None else n_neighbors ) - self._random_state = check_random_state(random_state) + self._rng = np.random.default_rng(rng) def simulate_doublets( self, @@ -218,7 +216,7 @@ def simulate_doublets( n_obs = self._counts_obs.shape[0] n_sim = int(n_obs * sim_doublet_ratio) - pair_ix = sample_comb((n_obs, n_obs), n_sim, random_state=self._random_state) + pair_ix = sample_comb((n_obs, n_obs), n_sim, rng=self._rng) e1 = cast("CSCBase", self._counts_obs[pair_ix[:, 0], :]) e2 = cast("CSCBase", self._counts_obs[pair_ix[:, 1], :]) @@ -229,7 +227,7 @@ def simulate_doublets( e1 + e2, rate=synthetic_doublet_umi_subsampling, original_totals=tots1 + tots2, - random_seed=self._random_state, + rng=self._rng, ) else: self._counts_sim = e1 + e2 @@ -348,7 +346,7 @@ def _nearest_neighbor_classifier( knn=True, transformer=transformer, method=None, - random_state=self._random_state, + rng=self._rng, ) neighbors, _ = _get_indices_distances_from_sparse_matrix(knn.distances, k_adj) if use_approx_neighbors: diff --git a/src/scanpy/preprocessing/_scrublet/pipeline.py 
b/src/scanpy/preprocessing/_scrublet/pipeline.py index edc3417cd9..53d98ff8ed 100644 --- a/src/scanpy/preprocessing/_scrublet/pipeline.py +++ b/src/scanpy/preprocessing/_scrublet/pipeline.py @@ -6,12 +6,12 @@ from fast_array_utils.stats import mean_var from scipy import sparse +from ..._utils.random import legacy_random_state from .sparse_utils import sparse_multiply, sparse_zscore if TYPE_CHECKING: from typing import Literal - from ..._utils.random import _LegacyRandom from .core import Scrublet @@ -46,10 +46,10 @@ def zscore(self: Scrublet) -> None: def truncated_svd( self: Scrublet, - n_prin_comps: int = 30, + n_prin_comps: int, *, - random_state: _LegacyRandom = 0, - algorithm: Literal["arpack", "randomized"] = "arpack", + rng: np.random.Generator, + algorithm: Literal["arpack", "randomized"], ) -> None: if self._counts_sim_norm is None: msg = "_counts_sim_norm is not set" @@ -57,7 +57,9 @@ def truncated_svd( from sklearn.decomposition import TruncatedSVD svd = TruncatedSVD( - n_components=n_prin_comps, random_state=random_state, algorithm=algorithm + n_components=n_prin_comps, + random_state=legacy_random_state(rng), + algorithm=algorithm, ).fit(self._counts_obs_norm) self.set_manifold( svd.transform(self._counts_obs_norm), svd.transform(self._counts_sim_norm) @@ -66,10 +68,10 @@ def truncated_svd( def pca( self: Scrublet, - n_prin_comps: int = 50, + n_prin_comps: int, *, - random_state: _LegacyRandom = 0, - svd_solver: Literal["auto", "full", "arpack", "randomized"] = "arpack", + rng: np.random.Generator, + svd_solver: Literal["auto", "full", "arpack", "randomized"], ) -> None: if self._counts_sim_norm is None: msg = "_counts_sim_norm is not set" @@ -80,6 +82,8 @@ def pca( x_sim = self._counts_sim_norm.toarray() pca = PCA( - n_components=n_prin_comps, random_state=random_state, svd_solver=svd_solver + n_components=n_prin_comps, + random_state=legacy_random_state(rng), + svd_solver=svd_solver, ).fit(x_obs) self.set_manifold(pca.transform(x_obs), 
pca.transform(x_sim)) diff --git a/src/scanpy/preprocessing/_scrublet/sparse_utils.py b/src/scanpy/preprocessing/_scrublet/sparse_utils.py index 5b7e7aaaaf..611754f91b 100644 --- a/src/scanpy/preprocessing/_scrublet/sparse_utils.py +++ b/src/scanpy/preprocessing/_scrublet/sparse_utils.py @@ -5,13 +5,11 @@ import numpy as np from fast_array_utils.stats import mean_var from scipy import sparse -from sklearn.utils import check_random_state if TYPE_CHECKING: from numpy.typing import NDArray from ..._compat import CSBase - from ..._utils.random import _LegacyRandom def sparse_multiply( @@ -48,11 +46,10 @@ def subsample_counts( *, rate: float, original_totals, - random_seed: _LegacyRandom = 0, + rng: np.random.Generator, ) -> tuple[CSBase, NDArray[np.int64]]: if rate < 1: - random_seed = check_random_state(random_seed) - e.data = random_seed.binomial(np.round(e.data).astype(int), rate) + e.data = rng.binomial(np.round(e.data).astype(int), rate) current_totals = np.asarray(e.sum(1)).squeeze() unsampled_orig_totals = original_totals - current_totals unsampled_downsamp_totals = np.random.binomial( diff --git a/src/scanpy/preprocessing/_simple.py b/src/scanpy/preprocessing/_simple.py index 33a8c8de01..51f8b4a726 100644 --- a/src/scanpy/preprocessing/_simple.py +++ b/src/scanpy/preprocessing/_simple.py @@ -29,6 +29,7 @@ sanitize_anndata, view_to_actual, ) +from .._utils.random import _if_legacy_apply_global, accepts_legacy_random_state from ..get import _check_mask, _get_obs_rep, _set_obs_rep from ._distributed import materialize_as_ndarray @@ -40,7 +41,7 @@ import pandas as pd from numpy.typing import NDArray - from .._utils.random import RNGLike, SeedLike, _LegacyRandom + from .._utils.random import RNGLike, SeedLike def filter_cells( @@ -971,12 +972,13 @@ def sample( # noqa: PLR0912 return subset, indices +@accepts_legacy_random_state(0) def downsample_counts( adata: AnnData, counts_per_cell: int | Collection[int] | None = None, total_counts: int | None = None, *, - 
random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, replace: bool = False, copy: bool = False, ) -> AnnData | None: @@ -1000,7 +1002,7 @@ def downsample_counts( total_counts Target total counts. If the count matrix has more than `total_counts` it will be downsampled to have this number. - random_state + rng Random seed for subsampling. replace Whether to sample the counts with replacement. @@ -1017,6 +1019,7 @@ def downsample_counts( """ raise_not_implemented_error_if_backed_type(adata.X, "downsample_counts") # This logic is all dispatch + rng = np.random.default_rng(rng) total_counts_call = total_counts is not None counts_per_cell_call = counts_per_cell is not None if total_counts_call is counts_per_cell_call: @@ -1026,11 +1029,11 @@ def downsample_counts( adata = adata.copy() if total_counts_call: adata.X = _downsample_total_counts( - adata.X, total_counts, random_state=random_state, replace=replace + adata.X, total_counts, rng=rng, replace=replace ) elif counts_per_cell_call: adata.X = _downsample_per_cell( - adata.X, counts_per_cell, random_state=random_state, replace=replace + adata.X, counts_per_cell, rng=rng, replace=replace ) if copy: return adata @@ -1041,7 +1044,7 @@ def _downsample_per_cell( /, counts_per_cell: int, *, - random_state: _LegacyRandom, + rng: np.random.Generator, replace: bool, ) -> CSBase: n_obs = x.shape[0] @@ -1068,11 +1071,7 @@ def _downsample_per_cell( for rowidx in under_target: row = rows[rowidx] _downsample_array( - row, - counts_per_cell[rowidx], - random_state=random_state, - replace=replace, - inplace=True, + row, counts_per_cell[rowidx], rng=rng, replace=replace, inplace=True ) x.eliminate_zeros() if not issubclass(original_type, CSRBase): # Put it back @@ -1083,11 +1082,7 @@ def _downsample_per_cell( for rowidx in under_target: row = x[rowidx, :] _downsample_array( - row, - counts_per_cell[rowidx], - random_state=random_state, - replace=replace, - inplace=True, + row, counts_per_cell[rowidx], rng=rng, 
replace=replace, inplace=True ) return x @@ -1097,7 +1092,7 @@ def _downsample_total_counts( /, total_counts: int, *, - random_state: _LegacyRandom, + rng: np.random.Generator, replace: bool, ) -> CSBase: total_counts = int(total_counts) @@ -1108,49 +1103,47 @@ def _downsample_total_counts( original_type = type(x) if not isinstance(x, CSRBase): x = x.tocsr() - _downsample_array( - x.data, - total_counts, - random_state=random_state, - replace=replace, - inplace=True, - ) + _downsample_array(x.data, total_counts, rng=rng, replace=replace, inplace=True) x.eliminate_zeros() if not issubclass(original_type, CSRBase): x = original_type(x) else: v = x.reshape(np.multiply(*x.shape)) - _downsample_array( - v, total_counts, random_state=random_state, replace=replace, inplace=True - ) + _downsample_array(v, total_counts, rng=rng, replace=replace, inplace=True) return x -# TODO: can/should this be parallelized? -@numba.njit(cache=True) # noqa: TID251 def _downsample_array( col: np.ndarray, target: int, *, - random_state: _LegacyRandom = 0, + rng: np.random.Generator, replace: bool = True, inplace: bool = False, -): +) -> np.ndarray: """Evenly reduce counts in cell to target amount. This is an internal function and has some restrictions: * total counts in cell must be less than target """ - np.random.seed(random_state) + rng = _if_legacy_apply_global(rng) cumcounts = col.cumsum() + total = np.int_(cumcounts[-1]) + sample = rng.choice(total, target, replace=replace) + sample.sort() + return _downsample_array_inner(col, cumcounts, sample, inplace=inplace) + + +# TODO: can/should this be parallelized? 
+@numba.njit(cache=True) # noqa: TID251 +def _downsample_array_inner( + col: np.ndarray, cumcounts: np.ndarray, sample: np.ndarray, *, inplace: bool +) -> np.ndarray: if inplace: col[:] = 0 else: col = np.zeros_like(col) - total = np.int_(cumcounts[-1]) - sample = np.random.choice(total, target, replace=replace) - sample.sort() geneptr = 0 for count in sample: while count >= cumcounts[geneptr]: diff --git a/src/scanpy/preprocessing/_utils.py b/src/scanpy/preprocessing/_utils.py index 0ba97c7b4e..3da0fed2d5 100644 --- a/src/scanpy/preprocessing/_utils.py +++ b/src/scanpy/preprocessing/_utils.py @@ -5,24 +5,25 @@ import numpy as np from sklearn.random_projection import sample_without_replacement +from .._utils.random import legacy_random_state + if TYPE_CHECKING: from typing import Literal from numpy.typing import NDArray - from .._utils.random import _LegacyRandom - def sample_comb( dims: tuple[int, ...], nsamp: int, *, - random_state: _LegacyRandom = None, + rng: np.random.Generator, method: Literal[ "auto", "tracking_selection", "reservoir_sampling", "pool" ] = "auto", ) -> NDArray[np.int64]: """Randomly sample indices from a grid, without repeating the same tuple.""" + random_state = legacy_random_state(rng) idx = sample_without_replacement( np.prod(dims), nsamp, random_state=random_state, method=method ) diff --git a/src/scanpy/tools/_diffmap.py b/src/scanpy/tools/_diffmap.py index de4538dfc2..878f0d1d1d 100644 --- a/src/scanpy/tools/_diffmap.py +++ b/src/scanpy/tools/_diffmap.py @@ -2,20 +2,24 @@ from typing import TYPE_CHECKING +import numpy as np + +from .._utils.random import accepts_legacy_random_state from ._dpt import _diffmap if TYPE_CHECKING: from anndata import AnnData - from .._utils.random import _LegacyRandom + from .._utils.random import RNGLike, SeedLike +@accepts_legacy_random_state(0) def diffmap( adata: AnnData, n_comps: int = 15, *, neighbors_key: str | None = None, - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, 
copy: bool = False, ) -> AnnData | None: """Diffusion Maps :cite:p:`Coifman2005,Haghverdi2015,Wolf2018`. @@ -48,8 +52,8 @@ def diffmap( .obsp[.uns[neighbors_key]['connectivities_key']] and .obsp[.uns[neighbors_key]['distances_key']] for connectivities and distances, respectively. - random_state - A numpy random seed + rng + A numpy random number generator copy Return a copy instead of writing to adata. @@ -73,6 +77,7 @@ def diffmap( e.g. `adata.obsm["X_diffmap"][:,1]` """ + rng = np.random.default_rng(rng) if neighbors_key is None: neighbors_key = "neighbors" @@ -83,7 +88,5 @@ def diffmap( msg = "Provide any value greater than 2 for `n_comps`. " raise ValueError(msg) adata = adata.copy() if copy else adata - _diffmap( - adata, n_comps=n_comps, neighbors_key=neighbors_key, random_state=random_state - ) + _diffmap(adata, n_comps=n_comps, neighbors_key=neighbors_key, rng=rng) return adata if copy else None diff --git a/src/scanpy/tools/_docs.py b/src/scanpy/tools/_docs.py index f2474efff6..87bab195c1 100644 --- a/src/scanpy/tools/_docs.py +++ b/src/scanpy/tools/_docs.py @@ -7,11 +7,6 @@ The annotated data matrix.\ """ -doc_random_state = """\ -random_state - Change the initialization of the optimization.\ -""" - doc_restrict_to = """\ restrict_to Restrict the clustering to the categories within the key for sample diff --git a/src/scanpy/tools/_dpt.py b/src/scanpy/tools/_dpt.py index 603202b01b..2b80b61bc9 100644 --- a/src/scanpy/tools/_dpt.py +++ b/src/scanpy/tools/_dpt.py @@ -8,6 +8,7 @@ from natsort import natsorted from .. 
import logging as logg +from .._utils.random import _FakeRandomGen from ..neighbors import Neighbors, OnFlySymMatrix if TYPE_CHECKING: @@ -16,11 +17,17 @@ from anndata import AnnData -def _diffmap(adata, n_comps=15, neighbors_key=None, random_state=0): +def _diffmap( + adata: AnnData, + n_comps: int = 15, + *, + neighbors_key: str | None, + rng: np.random.Generator, +) -> None: start = logg.info(f"computing Diffusion Maps using {n_comps=}(=n_dcs)") dpt = DPT(adata, neighbors_key=neighbors_key) dpt.compute_transitions() - dpt.compute_eigen(n_comps=n_comps, random_state=random_state) + dpt.compute_eigen(n_comps=n_comps, rng=rng) adata.obsm["X_diffmap"] = dpt.eigen_basis adata.uns["diffmap_evals"] = dpt.eigen_values logg.info( @@ -140,7 +147,7 @@ def dpt( "Trying to run `tl.dpt` without prior call of `tl.diffmap`. " "Falling back to `tl.diffmap` with default parameters." ) - _diffmap(adata, neighbors_key=neighbors_key) + _diffmap(adata, neighbors_key=neighbors_key, rng=_FakeRandomGen(0)) # start with the actual computation dpt = DPT( adata, diff --git a/src/scanpy/tools/_draw_graph.py b/src/scanpy/tools/_draw_graph.py index ffb6fd1048..4047d63a41 100644 --- a/src/scanpy/tools/_draw_graph.py +++ b/src/scanpy/tools/_draw_graph.py @@ -1,6 +1,5 @@ from __future__ import annotations -import random from importlib.util import find_spec from typing import TYPE_CHECKING, Literal @@ -9,6 +8,11 @@ from .. import _utils from .. 
import logging as logg from .._utils import _choose_graph, get_literal_vals +from .._utils.random import ( + _if_legacy_apply_global, + accepts_legacy_random_state, + set_igraph_rng, +) from ._utils import get_init_pos_from_paga if TYPE_CHECKING: @@ -17,19 +21,20 @@ from anndata import AnnData from .._compat import SpBase - from .._utils.random import _LegacyRandom + from .._utils.random import RNGLike, SeedLike type _Layout = Literal["fr", "drl", "kk", "grid_fr", "lgl", "rt", "rt_circular", "fa"] +@accepts_legacy_random_state(0) def draw_graph( # noqa: PLR0913 adata: AnnData, layout: _Layout = "fa", *, init_pos: str | bool | None = None, root: int | None = None, - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, n_jobs: int | None = None, adjacency: SpBase | None = None, key_added_ext: str | None = None, @@ -71,7 +76,7 @@ def draw_graph( # noqa: PLR0913 'rt' (Reingold Tilford tree layout). root Root for tree layouts. - random_state + rng For layouts with random initialization like 'fr', change this to use different intial states for the optimization. If `None`, no seed is set. adjacency @@ -111,6 +116,8 @@ def draw_graph( # noqa: PLR0913 """ start = logg.info(f"drawing single-cell graph using layout {layout!r}") + rng = np.random.default_rng(rng) + rng = _if_legacy_apply_global(rng) if layout not in (layouts := get_literal_vals(_Layout)): msg = f"Provide a valid layout, one of {layouts}." 
raise ValueError(msg) @@ -124,33 +131,31 @@ init_coords = get_init_pos_from_paga( adata, adjacency, - random_state=random_state, + rng=rng, neighbors_key=neighbors_key, obsp=obsp, ) else: - np.random.seed(random_state) - init_coords = np.random.random((adjacency.shape[0], 2)) + init_coords = rng.random((adjacency.shape[0], 2)) layout = coerce_fa2_layout(layout) # actual drawing if layout == "fa": positions = np.array(fa2_positions(adjacency, init_coords, **kwds)) else: - # igraph doesn't use numpy seed - random.seed(random_state) g = _utils.get_igraph_from_adjacency(adjacency) - if layout in {"fr", "drl", "kk", "grid_fr"}: - ig_layout = g.layout(layout, seed=init_coords.tolist(), **kwds) - elif "rt" in layout: - if root is not None: - root = [root] - ig_layout = g.layout(layout, root=root, **kwds) - else: - ig_layout = g.layout(layout, **kwds) + with set_igraph_rng(rng): + if layout in {"fr", "drl", "kk", "grid_fr"}: + ig_layout = g.layout(layout, seed=init_coords.tolist(), **kwds) + elif "rt" in layout: + if root is not None: + root = [root] + ig_layout = g.layout(layout, root=root, **kwds) + else: + ig_layout = g.layout(layout, **kwds) positions = np.array(ig_layout.coords) adata.uns["draw_graph"] = {} - adata.uns["draw_graph"]["params"] = dict(layout=layout, random_state=random_state) + adata.uns["draw_graph"]["params"] = dict(layout=layout, random_state=rng) key_added = f"X_draw_graph_{key_added_ext or layout}" adata.obsm[key_added] = positions logg.info( diff --git a/src/scanpy/tools/_leiden.py b/src/scanpy/tools/_leiden.py index 11d6140ad1..1f4fd2a7ea 100644 --- a/src/scanpy/tools/_leiden.py +++ b/src/scanpy/tools/_leiden.py @@ -10,13 +10,16 @@ from ..
import logging as logg from .._compat import warn from .._utils import _doc_params -from .._utils.random import set_igraph_random_state +from .._utils.random import ( + accepts_legacy_random_state, + legacy_random_state, + set_igraph_rng, +) from ._docs import ( doc_adata, doc_adjacency, doc_neighbors_key, doc_obsp, - doc_random_state, doc_restrict_to, ) from ._utils_clustering import rename_groups, restrict_adjacency @@ -28,7 +31,7 @@ from anndata import AnnData from .._compat import CSBase - from .._utils.random import _LegacyRandom + from .._utils.random import RNGLike, SeedLike try: # sphinx-autodoc-typehints + optional dependency from leidenalg.VertexPartition import MutableVertexPartition @@ -40,18 +43,18 @@ @_doc_params( doc_adata=doc_adata, - random_state=doc_random_state, restrict_to=doc_restrict_to, adjacency=doc_adjacency, neighbors_key=doc_neighbors_key.format(method="leiden"), obsp=doc_obsp, ) +@accepts_legacy_random_state(0) def leiden( # noqa: PLR0913 adata: AnnData, resolution: float = 1, *, restrict_to: tuple[str, Sequence[str]] | None = None, - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, key_added: str = "leiden", adjacency: CSBase | None = None, directed: bool | None = None, @@ -83,7 +86,8 @@ def leiden( # noqa: PLR0913 Higher values lead to more clusters. Set to `None` if overriding `partition_type` to one that doesn’t accept a `resolution_parameter`. - {random_state} + rng + Change the initialization of the optimization. {restrict_to} key_added `adata.obs` key under which to add the cluster labels. 
@@ -165,7 +169,7 @@ def leiden( # noqa: PLR0913 partition_type = leidenalg.RBConfigurationVertexPartition if use_weights: clustering_args["weights"] = np.array(g.es["weight"]).astype(np.float64) - clustering_args["seed"] = random_state + clustering_args["seed"] = legacy_random_state(rng) part = cast( "MutableVertexPartition", leidenalg.find_partition(g, partition_type, **clustering_args), @@ -177,7 +181,7 @@ def leiden( # noqa: PLR0913 if resolution is not None: clustering_args["resolution"] = resolution clustering_args.setdefault("objective_function", "modularity") - with set_igraph_random_state(random_state): + with set_igraph_rng(rng): part = g.community_leiden(**clustering_args) # store output into adata.obs groups = np.array(part.membership) @@ -200,7 +204,7 @@ def leiden( # noqa: PLR0913 adata.uns[key_added] = {} adata.uns[key_added]["params"] = dict( resolution=resolution, - random_state=random_state, + random_state=rng, n_iterations=n_iterations, ) adata.uns[key_added]["modularity"] = part.modularity diff --git a/src/scanpy/tools/_louvain.py b/src/scanpy/tools/_louvain.py index 2fb1080eea..4e131affd6 100644 --- a/src/scanpy/tools/_louvain.py +++ b/src/scanpy/tools/_louvain.py @@ -17,7 +17,6 @@ doc_adjacency, doc_neighbors_key, doc_obsp, - doc_random_state, doc_restrict_to, ) from ._utils_clustering import rename_groups, restrict_adjacency @@ -42,7 +41,6 @@ @deprecated("Use `scanpy.tl.leiden` instead") @_doc_params( doc_adata=doc_adata, - random_state=doc_random_state, restrict_to=doc_restrict_to, adjacency=doc_adjacency, neighbors_key=doc_neighbors_key.format(method="louvain"), @@ -88,7 +86,8 @@ def louvain( # noqa: PLR0912, PLR0913, PLR0915 resolution (higher resolution means finding more and smaller clusters), which defaults to 1.0. See “Time as a resolution parameter” in :cite:t:`Lambiotte2014`. - {random_state} + random_state + Change the initialization of the optimization. {restrict_to} key_added Key under which to add the cluster labels. 
(default: ``'louvain'``) diff --git a/src/scanpy/tools/_score_genes.py b/src/scanpy/tools/_score_genes.py index f57f50efae..97bd90dc8f 100644 --- a/src/scanpy/tools/_score_genes.py +++ b/src/scanpy/tools/_score_genes.py @@ -10,6 +10,7 @@ from .. import logging as logg from .._compat import CSBase from .._utils import check_use_raw, is_backed_type +from .._utils.random import _if_legacy_apply_global, accepts_legacy_random_state from ..get import _get_obs_rep if TYPE_CHECKING: @@ -19,7 +20,8 @@ from anndata import AnnData from numpy.typing import DTypeLike, NDArray - from .._utils.random import _LegacyRandom + from .._utils.random import RNGLike, SeedLike + type _StrIdx = pd.Index[str] type _GetSubset = Callable[[_StrIdx], np.ndarray | CSBase] @@ -49,6 +51,7 @@ def _sparse_nanmean(x: CSBase, /, axis: Literal[0, 1]) -> NDArray[np.float64]: return m +@accepts_legacy_random_state(0) def score_genes( # noqa: PLR0913 adata: AnnData, gene_list: Sequence[str] | pd.Index[str], @@ -58,7 +61,7 @@ def score_genes( # noqa: PLR0913 gene_pool: Sequence[str] | pd.Index[str] | None = None, n_bins: int = 25, score_name: str = "score", - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, copy: bool = False, use_raw: bool | None = None, layer: str | None = None, @@ -93,8 +96,8 @@ def score_genes( # noqa: PLR0913 Number of expression level bins for sampling. score_name Name of the field to be added in `.obs`. - random_state - The random seed for sampling. + rng + The random number generator for sampling. copy Copy `adata` or modify it inplace. 
use_raw @@ -118,19 +121,20 @@ def score_genes( # noqa: PLR0913 """ start = logg.info(f"computing score {score_name!r}") + rng_was_passed = rng is not None + rng = np.random.default_rng(rng) + if rng_was_passed: # backwards compatibility: call np.random.seed() by default + rng = _if_legacy_apply_global(rng) adata = adata.copy() if copy else adata use_raw = check_use_raw(adata, use_raw, layer=layer) if is_backed_type(adata.X) and not use_raw: msg = f"score_genes is not implemented for matrices of type {type(adata.X)}" raise NotImplementedError(msg) - if random_state is not None: - np.random.seed(random_state) - gene_list, gene_pool, get_subset = _check_score_genes_args( adata, gene_list, gene_pool, use_raw=use_raw, layer=layer ) - del use_raw, layer, random_state + del use_raw, layer # Trying here to match the Seurat approach in scoring cells. # Basically we need to compare genes against random genes in a matched @@ -144,6 +148,7 @@ def score_genes( # noqa: PLR0913 ctrl_size=ctrl_size, n_bins=n_bins, get_subset=get_subset, + rng=rng, ): control_genes = control_genes.union(r_genes) @@ -223,6 +228,7 @@ def _score_genes_bins( ctrl_size: int, n_bins: int, get_subset: _GetSubset, + rng: np.random.Generator, ) -> Generator[pd.Index[str], None, None]: # average expression of genes obs_avg = pd.Series(_nan_means(get_subset(gene_pool), axis=0), index=gene_pool) @@ -243,7 +249,7 @@ def _score_genes_bins( ) logg.warning(msg) if ctrl_size < len(r_genes): - r_genes = r_genes.to_series().sample(ctrl_size).index + r_genes = r_genes.to_series().sample(ctrl_size, random_state=rng).index if ctrl_as_ref: # otherwise `r_genes` is already filtered r_genes = r_genes.difference(gene_list) yield r_genes diff --git a/src/scanpy/tools/_tsne.py b/src/scanpy/tools/_tsne.py index 17b8ce279e..f18e954dc8 100644 --- a/src/scanpy/tools/_tsne.py +++ b/src/scanpy/tools/_tsne.py @@ -6,15 +6,17 @@ from .._compat import warn from .._settings import settings from .._utils import _doc_params, 
raise_not_implemented_error_if_backed_type +from .._utils.random import accepts_legacy_random_state, legacy_random_state from ..neighbors._doc import doc_n_pcs, doc_use_rep from ._utils import _choose_representation if TYPE_CHECKING: from anndata import AnnData - from .._utils.random import _LegacyRandom + from .._utils.random import RNGLike, SeedLike +@accepts_legacy_random_state(0) @_doc_params(doc_n_pcs=doc_n_pcs, use_rep=doc_use_rep) def tsne( # noqa: PLR0913 adata: AnnData, @@ -26,7 +28,7 @@ def tsne( # noqa: PLR0913 metric: str = "euclidean", early_exaggeration: float = 12, learning_rate: float = 1000, - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, use_fast_tsne: bool = False, n_jobs: int | None = None, key_added: str | None = None, @@ -74,7 +76,7 @@ def tsne( # noqa: PLR0913 optimization, the early exaggeration factor or the learning rate might be too high. If the cost function gets stuck in a bad local minimum increasing the learning rate helps sometimes. - random_state + rng Change this to use different intial states for the optimization. If `None`, the initial state is not reproducible. n_jobs @@ -108,7 +110,7 @@ def tsne( # noqa: PLR0913 n_jobs = settings.n_jobs if n_jobs is None else n_jobs params_sklearn = dict( perplexity=perplexity, - random_state=random_state, + random_state=legacy_random_state(rng), verbose=settings.verbosity > 3, early_exaggeration=early_exaggeration, learning_rate=learning_rate, diff --git a/src/scanpy/tools/_umap.py b/src/scanpy/tools/_umap.py index 0e14a0f4ca..8d46298ec5 100644 --- a/src/scanpy/tools/_umap.py +++ b/src/scanpy/tools/_umap.py @@ -4,12 +4,13 @@ from typing import TYPE_CHECKING import numpy as np -from sklearn.utils import check_array, check_random_state +from sklearn.utils import check_array from .. 
import logging as logg from .._compat import warn from .._settings import settings from .._utils import NeighborsView +from .._utils.random import accepts_legacy_random_state, legacy_random_state from ._utils import _choose_representation, get_init_pos_from_paga if TYPE_CHECKING: @@ -17,11 +18,13 @@ from anndata import AnnData - from .._utils.random import _LegacyRandom + from .._utils.random import RNGLike, SeedLike + type _InitPos = Literal["paga", "spectral", "random"] +@accepts_legacy_random_state(0) def umap( # noqa: PLR0913, PLR0915 adata: AnnData, *, @@ -33,7 +36,7 @@ def umap( # noqa: PLR0913, PLR0915 gamma: float = 1.0, negative_sample_rate: int = 5, init_pos: _InitPos | np.ndarray | None = "spectral", - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, a: float | None = None, b: float | None = None, method: Literal["umap", "rapids"] = "umap", @@ -95,11 +98,10 @@ def umap( # noqa: PLR0913, PLR0915 * 'spectral': use a spectral embedding of the graph. * 'random': assign initial embedding positions at random. * A numpy array of initial embedding positions. - random_state - If `int`, `random_state` is the seed used by the random number generator; - If `RandomState` or `Generator`, `random_state` is the random number generator; - If `None`, the random number generator is the `RandomState` instance used - by `np.random`. + rng + If `int`, `rng` is the seed used by the random number generator; + If `Generator`, `rng` is the random number generator; + If `None`, the random number generator is not reproducible. a More specific parameters controlling the embedding. If `None` these values are set automatically as determined by `min_dist` and @@ -142,6 +144,7 @@ def umap( # noqa: PLR0913, PLR0915 UMAP parameters.
""" + rng = np.random.default_rng(rng) adata = adata.copy() if copy else adata key_obsm, key_uns = ("X_umap", "umap") if key_added is None else [key_added] * 2 @@ -175,16 +178,15 @@ def umap( # noqa: PLR0913, PLR0915 init_coords = adata.obsm[init_pos] elif isinstance(init_pos, str) and init_pos == "paga": init_coords = get_init_pos_from_paga( - adata, random_state=random_state, neighbors_key=neighbors_key + adata, rng=rng, neighbors_key=neighbors_key ) else: init_coords = init_pos # Let umap handle it if hasattr(init_coords, "dtype"): init_coords = check_array(init_coords, dtype=np.float32, accept_sparse=False) - if random_state != 0: - adata.uns[key_uns]["params"]["random_state"] = random_state - random_state = check_random_state(random_state) + if rng is not None: + adata.uns[key_uns]["params"]["random_state"] = legacy_random_state(rng) neigh_params = neighbors["params"] x = _choose_representation( @@ -209,7 +211,7 @@ def umap( # noqa: PLR0913, PLR0915 negative_sample_rate=negative_sample_rate, n_epochs=n_epochs, init=init_coords, - random_state=random_state, + random_state=legacy_random_state(rng, always_state=True), metric=neigh_params.get("metric", "euclidean"), metric_kwds=neigh_params.get("metric_kwds", {}), densmap=False, @@ -249,7 +251,7 @@ def umap( # noqa: PLR0913, PLR0915 a=a, b=b, verbose=settings.verbosity > 3, - random_state=random_state, + random_state=legacy_random_state(rng), ) x_umap = umap.fit_transform(x_contiguous) adata.obsm[key_obsm] = x_umap # annotate samples with UMAP coordinates diff --git a/src/scanpy/tools/_utils.py b/src/scanpy/tools/_utils.py index fdcff28720..bdb9ae7f90 100644 --- a/src/scanpy/tools/_utils.py +++ b/src/scanpy/tools/_utils.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: from anndata import AnnData + from numpy.typing import NDArray from .._compat import CSRBase, SpBase @@ -77,12 +78,12 @@ def _get_pca_or_small_x(adata: AnnData, n_pcs: int | None) -> np.ndarray | CSRBa def get_init_pos_from_paga( adata: AnnData, + *, + rng: 
np.random.Generator, adjacency: SpBase | None = None, - random_state=0, neighbors_key: str | None = None, obsp: str | None = None, -): - np.random.seed(random_state) +) -> NDArray[np.float64]: if adjacency is None: adjacency = _choose_graph(adata, obsp, neighbors_key) if "pos" not in adata.uns.get("paga", {}): @@ -99,7 +100,7 @@ def get_init_pos_from_paga( if len(neighbors[1]) > 0: connectivities = connectivities_coarse[i][neighbors] nearest_neighbor = neighbors[1][np.argmax(connectivities)] - noise = np.random.random((len(subset[subset]), 2)) + noise = rng.random((len(subset[subset]), 2)) dist = group_pos - pos[nearest_neighbor] noise = noise * dist init_pos[subset] = group_pos - 0.5 * dist + noise diff --git a/tests/test_pca.py b/tests/test_pca.py index ba996d8ad9..d4cd5e58f3 100644 --- a/tests/test_pca.py +++ b/tests/test_pca.py @@ -244,7 +244,7 @@ def test_pca_transform_randomized(array_type): warnings.filterwarnings("error") if isinstance(adata.X, DaskArray) and isinstance(adata.X._meta, CSBase): patterns = ( - r"Ignoring random_state=14 when using a sparse dask array", + r"Ignoring rng=14 when using a sparse dask array", r"Ignoring svd_solver='randomized' when using a sparse dask array", ) ctx = _helpers.MultiContext( @@ -338,7 +338,7 @@ def test_pca_reproducible(array_type): pbmc.X = array_type(pbmc.X) with ( - pytest.warns(UserWarning, match=r"Ignoring random_state.*sparse dask array") + pytest.warns(UserWarning, match=r"Ignoring rng.*sparse dask array") if isinstance(pbmc.X, DaskArray) and isinstance(pbmc.X._meta, CSBase) else nullcontext() ): diff --git a/tests/test_utils.py b/tests/test_utils.py index 02e37dcabc..40852a972e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -18,8 +18,8 @@ descend_classes_and_funcs, ) from scanpy._utils.random import ( + _FakeRandomGen, ith_k_tuple, - legacy_numpy_gen, random_k_tuples, random_str, ) @@ -206,7 +206,7 @@ def test_legacy_numpy_gen(*, seed: int, pass_seed: bool, func: str): def _mk_random(func: 
str, *, direct: bool, seed: int | None) -> np.ndarray: if direct and seed is not None: np.random.seed(seed) - gen = np.random if direct else legacy_numpy_gen(seed) + gen = np.random if direct else _FakeRandomGen.wrap_global(seed) match func: case "choice": arr = np.arange(1000)