diff --git a/pyproject.toml b/pyproject.toml index a6af021886..4cf55602f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,6 +68,7 @@ dependencies = [ "pynndescent>=0.5.13", "scikit-learn>=1.6", "scipy>=1.13", + "scverse-backends>=0.0.2,<0.1", "scverse-misc[settings]>=0.0.5", "seaborn>=0.13.2", "session-info2", diff --git a/src/scanpy/_backends.py b/src/scanpy/_backends.py new file mode 100644 index 0000000000..f9ac431056 --- /dev/null +++ b/src/scanpy/_backends.py @@ -0,0 +1,42 @@ +"""Backend dispatch integration for optional Scanpy accelerators.""" + +from __future__ import annotations + +from scverse_backends import BackendDispatcher + +dispatcher = BackendDispatcher( + entrypoint_group="scanpy.backends", + host_name="scanpy", + trusted_backends={ + "rapids_singlecell": { + "aliases": ["cuda", "rapids", "rapids-singlecell"], + "distributions": [ + "rapids-singlecell", + "rapids-singlecell-cu12", + "rapids-singlecell-cu13", + ], + "package": "rapids-singlecell", + }, + }, + reserved_backends={ + "gpu": ( + "Use 'cuda' for the RAPIDS backend. The generic 'gpu' selector is " + "reserved so future GPU backends can coexist without ambiguity." + ), + }, +) + +backend_dispatch = dispatcher.backend_dispatch +settings = dispatcher.settings +get_backend = dispatcher.get_backend +available_backend_names = dispatcher.available_backend_names +discover = dispatcher.discover + +__all__ = [ + "available_backend_names", + "backend_dispatch", + "discover", + "dispatcher", + "get_backend", + "settings", +] diff --git a/src/scanpy/_settings/__init__.py b/src/scanpy/_settings/__init__.py index 1fff61f409..0e5a1734e5 100644 --- a/src/scanpy/_settings/__init__.py +++ b/src/scanpy/_settings/__init__.py @@ -173,7 +173,39 @@ def model_post_init(self, context: object) -> None: @computed_field @property - def logpath(self) -> Path | None: + def backend(cls) -> str: + """Active computational backend (default ``'cpu'``).""" + from .._backends import settings as backend_settings + + return backend_settings.backend + + @backend.setter + @_type_check_arg2(str) + def backend(cls, backend: str) -> None: + from .._backends import settings as backend_settings + + backend_settings.backend = backend + + def use_backend(cls, backend: str): + """Temporarily set the active computational backend.""" + from .._backends import settings as backend_settings + + return backend_settings.use_backend(backend) + + def available_backends(cls) -> list[str]: + """Return canonical names of installed computational backends.""" + from .._backends import settings as backend_settings + + return backend_settings.available_backends() + + def get_backend(cls, name: str): + """Look up an installed computational backend by name or alias.""" + from .._backends import settings as backend_settings + + return backend_settings.get_backend(name) + + @property + def logpath(cls) -> Path | None: """The file path `logfile` was set to.""" if self.logfile is _default_logfile(): return None diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index fe5140171b..0390d18b3f 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -10,6 +10,7 @@ from scipy import sparse from sklearn.utils.sparsefuncs import csc_median_axis_0 +from scanpy._backends import backend_dispatch from scanpy._compat import CSBase, CSRBase, DaskArray from .._utils import _resolve_axis, get_literal_vals @@ -197,6 +198,7 @@ def _power(x: Array, power: float) -> Array: return x**power if isinstance(x, np.ndarray) else x.power(power) +@backend_dispatch def aggregate( # noqa: PLR0912 adata: AnnData, by: str | Collection[str], diff --git a/src/scanpy/neighbors/__init__.py b/src/scanpy/neighbors/__init__.py index 3622c02576..faadb5c966 100644 --- a/src/scanpy/neighbors/__init__.py +++ b/src/scanpy/neighbors/__init__.py @@ -15,6 +15,7 @@ from .. import _utils from .. import logging as logg +from .._backends import backend_dispatch from .._compat import CSBase, CSRBase, SpBase, pkg_version, warn from .._docs import doc_rng from .._settings import settings @@ -82,8 +83,9 @@ class NeighborsParams(TypedDict): # noqa: D101 n_pcs: NotRequired[int] -@_doc_params(n_pcs=doc_n_pcs, use_rep=doc_use_rep, rng=doc_rng) @_accepts_legacy_random_state(_DEFAULT_SEED := 0) +@backend_dispatch +@_doc_params(n_pcs=doc_n_pcs, use_rep=doc_use_rep, rng=doc_rng) def neighbors( # noqa: PLR0913 adata: AnnData, n_neighbors: int = 15, diff --git a/src/scanpy/preprocessing/_harmony/__init__.py b/src/scanpy/preprocessing/_harmony/__init__.py index 98b0c6f5fe..98d8e7dbdc 100644 --- a/src/scanpy/preprocessing/_harmony/__init__.py +++ b/src/scanpy/preprocessing/_harmony/__init__.py @@ -4,6 +4,7 @@ import numpy as np +from ..._backends import backend_dispatch from ..._compat import warn if TYPE_CHECKING: @@ -16,6 +17,7 @@ from ..._utils.random import RNGLike, SeedLike +@backend_dispatch def harmony_integrate( # noqa: PLR0913 adata: AnnData, key: str | Sequence[str], diff --git a/src/scanpy/preprocessing/_highly_variable_genes.py b/src/scanpy/preprocessing/_highly_variable_genes.py index d5f3d2cc79..1ff712299e 100644 --- a/src/scanpy/preprocessing/_highly_variable_genes.py +++ b/src/scanpy/preprocessing/_highly_variable_genes.py @@ -13,6 +13,7 @@ from fast_array_utils import stats from .. import logging as logg +from .._backends import backend_dispatch from .._compat import CSBase, CSRBase, DaskArray, warn from .._settings import Default, Verbosity, settings from .._utils import ( @@ -591,6 +592,7 @@ def _highly_variable_genes_batched( return df +@backend_dispatch def highly_variable_genes( # noqa: PLR0913 adata: AnnData, *, diff --git a/src/scanpy/preprocessing/_normalization.py b/src/scanpy/preprocessing/_normalization.py index 28d7c20293..7b92f8ac71 100644 --- a/src/scanpy/preprocessing/_normalization.py +++ b/src/scanpy/preprocessing/_normalization.py @@ -9,6 +9,7 @@ from fast_array_utils.numba import njit from .. import logging as logg +from .._backends import backend_dispatch from .._compat import CSBase, CSCBase, CSRBase, DaskArray, warn from .._utils import axis_mul_or_truediv, dematrix, view_to_actual from ..get import _get_obs_rep, _set_obs_rep @@ -124,6 +125,7 @@ def _normalize_total_helper( return x, counts_per_cell, gene_subset +@backend_dispatch def normalize_total( # noqa: PLR0912 adata: AnnData, *, diff --git a/src/scanpy/preprocessing/_pca/__init__.py b/src/scanpy/preprocessing/_pca/__init__.py index 4cbcf4c7e6..9ed296a1f8 100644 --- a/src/scanpy/preprocessing/_pca/__init__.py +++ b/src/scanpy/preprocessing/_pca/__init__.py @@ -7,6 +7,7 @@ from packaging.version import Version from ... import logging as logg +from ..._backends import backend_dispatch from ..._compat import CSBase, DaskArray, pkg_version, warn from ..._docs import doc_rng from ..._settings import Default, settings @@ -51,8 +52,9 @@ type SvdSolver = SvdSolvDaskML | SvdSolvSkearn | SvdSolvPCACustom -@_doc_params(mask_var=doc_mask_var, rng=doc_rng) @_accepts_legacy_random_state(0) +@backend_dispatch +@_doc_params(mask_var=doc_mask_var, rng=doc_rng) def pca( # noqa: PLR0912, PLR0913, PLR0915 data: AnnData | np.ndarray | CSBase, n_comps: int | None = None, diff --git a/src/scanpy/preprocessing/_qc.py b/src/scanpy/preprocessing/_qc.py index 04340ec7e9..5710322ec3 100644 --- a/src/scanpy/preprocessing/_qc.py +++ b/src/scanpy/preprocessing/_qc.py @@ -13,6 +13,7 @@ from scanpy.get import _get_obs_rep from scanpy.preprocessing._distributed import materialize_as_ndarray +from .._backends import backend_dispatch from .._compat import CSBase, CSRBase, DaskArray, warn from .._utils import _doc_params, axis_nnz from ._docs import ( @@ -204,6 +205,7 @@ def describe_var( doc_obs_qc_returns=doc_obs_qc_returns, doc_var_qc_returns=doc_var_qc_returns, ) +@backend_dispatch def calculate_qc_metrics( adata: AnnData, *, diff --git a/src/scanpy/preprocessing/_scale.py b/src/scanpy/preprocessing/_scale.py index 1a1a047544..451b3376a1 100644 --- a/src/scanpy/preprocessing/_scale.py +++ b/src/scanpy/preprocessing/_scale.py @@ -11,6 +11,7 @@ from fast_array_utils.stats import mean_var from .. import logging as logg +from .._backends import backend_dispatch from .._compat import CSBase, CSCBase, CSRBase, DaskArray, warn from .._settings import Default, settings from .._utils import ( @@ -68,6 +69,7 @@ def clip_array( return x +@backend_dispatch @singledispatch def scale[A: _Array]( data: AnnData | A, diff --git a/src/scanpy/preprocessing/_scrublet/__init__.py b/src/scanpy/preprocessing/_scrublet/__init__.py index 6aae43c608..6036927e45 100644 --- a/src/scanpy/preprocessing/_scrublet/__init__.py +++ b/src/scanpy/preprocessing/_scrublet/__init__.py @@ -10,6 +10,7 @@ from ... import logging as logg from ... import preprocessing as pp +from ..._backends import backend_dispatch from ..._docs import doc_rng from ..._utils import _doc_params from ..._utils.random import _accepts_legacy_random_state, _LegacyRng @@ -23,6 +24,7 @@ @_accepts_legacy_random_state(0) +@backend_dispatch @_doc_params(rng=doc_rng) def scrublet( # noqa: PLR0913 adata: AnnData, @@ -497,6 +499,7 @@ def _scrublet_call_doublets( # noqa: PLR0913 @_accepts_legacy_random_state(0) +@backend_dispatch def scrublet_simulate_doublets( adata: AnnData, *, diff --git a/src/scanpy/preprocessing/_simple.py b/src/scanpy/preprocessing/_simple.py index fb325dec35..a46cd942cd 100644 --- a/src/scanpy/preprocessing/_simple.py +++ b/src/scanpy/preprocessing/_simple.py @@ -23,6 +23,7 @@ from sklearn.utils import check_array from .. import logging as logg +from .._backends import backend_dispatch from .._compat import CSBase, CSRBase, DaskArray from .._docs import doc_rng from .._settings import settings @@ -50,6 +51,7 @@ from .._utils.random import RNGLike, SeedLike +@backend_dispatch def filter_cells( data: AnnData | CSBase | np.ndarray | DaskArray, *, @@ -196,6 +198,7 @@ def filter_cells( return cell_subset, number_per_cell +@backend_dispatch def filter_genes( data: AnnData | CSBase | np.ndarray | DaskArray, *, @@ -306,6 +309,7 @@ def filter_genes( return gene_subset, number_per_gene +@backend_dispatch @singledispatch def log1p( data: AnnData | np.ndarray | CSBase, @@ -507,6 +511,7 @@ def numpy_regress_out( return data +@backend_dispatch def regress_out( adata: AnnData, keys: str | Sequence[str], diff --git a/src/scanpy/testing/__init__.py b/src/scanpy/testing/__init__.py new file mode 100644 index 0000000000..42a4700282 --- /dev/null +++ b/src/scanpy/testing/__init__.py @@ -0,0 +1,7 @@ +"""Testing helpers for Scanpy extensions.""" + +from __future__ import annotations + +from testing.scanpy._backends import validate_backend + +__all__ = ["validate_backend"] diff --git a/src/scanpy/tools/_diffmap.py b/src/scanpy/tools/_diffmap.py index 90f2976837..0817448396 100644 --- a/src/scanpy/tools/_diffmap.py +++ b/src/scanpy/tools/_diffmap.py @@ -4,6 +4,7 @@ import numpy as np +from .._backends import backend_dispatch from .._docs import doc_rng from .._utils import _doc_params from .._utils.random import _accepts_legacy_random_state @@ -15,8 +16,9 @@ from .._utils.random import RNGLike, SeedLike -@_doc_params(rng=doc_rng) @_accepts_legacy_random_state(0) +@backend_dispatch +@_doc_params(rng=doc_rng) def diffmap( adata: AnnData, n_comps: int = 15, diff --git a/src/scanpy/tools/_draw_graph.py b/src/scanpy/tools/_draw_graph.py index bdbb37d0be..7036309f62 100644 --- a/src/scanpy/tools/_draw_graph.py +++ b/src/scanpy/tools/_draw_graph.py @@ -7,6 +7,7 @@ from .. import _utils from .. import logging as logg +from .._backends import backend_dispatch from .._docs import doc_rng from .._utils import _choose_graph, _doc_params, get_literal_vals from .._utils.random import ( @@ -31,8 +32,9 @@ type _Layout = Literal["fr", "drl", "kk", "grid_fr", "lgl", "rt", "rt_circular", "fa"] -@_doc_params(rng=doc_rng) @_accepts_legacy_random_state(0) +@backend_dispatch +@_doc_params(rng=doc_rng) def draw_graph( # noqa: PLR0913 adata: AnnData, layout: _Layout = "fa", diff --git a/src/scanpy/tools/_embedding_density.py b/src/scanpy/tools/_embedding_density.py index 37e2bcf606..b0f1d09d14 100644 --- a/src/scanpy/tools/_embedding_density.py +++ b/src/scanpy/tools/_embedding_density.py @@ -7,6 +7,7 @@ import numpy as np from .. import logging as logg +from .._backends import backend_dispatch from .._utils import sanitize_anndata if TYPE_CHECKING: @@ -32,6 +33,7 @@ def _calc_density(x: np.ndarray, y: np.ndarray): return scaled_z +@backend_dispatch def embedding_density( # noqa: PLR0912 adata: AnnData, basis: str = "umap", diff --git a/src/scanpy/tools/_leiden.py b/src/scanpy/tools/_leiden.py index 2f789acbaf..09ea969f21 100644 --- a/src/scanpy/tools/_leiden.py +++ b/src/scanpy/tools/_leiden.py @@ -8,6 +8,7 @@ from .. import _utils from .. import logging as logg +from .._backends import backend_dispatch from .._compat import warn from .._docs import doc_rng from .._settings import Default @@ -43,6 +44,8 @@ ) +@_accepts_legacy_random_state(0) +@backend_dispatch @_doc_params( doc_adata=doc_adata, restrict_to=doc_restrict_to, @@ -51,7 +54,6 @@ obsp=doc_obsp, rng=doc_rng, ) -@_accepts_legacy_random_state(0) def leiden( # noqa: PLR0913 adata: AnnData, resolution: float = 1, diff --git a/src/scanpy/tools/_louvain.py b/src/scanpy/tools/_louvain.py index 19fb50bd47..f7138f15cf 100644 --- a/src/scanpy/tools/_louvain.py +++ b/src/scanpy/tools/_louvain.py @@ -11,7 +11,8 @@ from .. import _utils from .. import logging as logg -from .._compat import pkg_version +from .._backends import backend_dispatch +from .._compat import deprecated, pkg_version from .._utils import _choose_graph, _doc_params from ._docs import ( doc_adata, @@ -40,6 +41,7 @@ ) +@backend_dispatch @deprecated(Deprecation("1.12.0", "Use :func:`scanpy.tl.leiden` instead.")) @_doc_params( doc_adata=doc_adata, diff --git a/src/scanpy/tools/_rank_genes_groups.py b/src/scanpy/tools/_rank_genes_groups.py index 0e6bb73386..85240dc186 100644 --- a/src/scanpy/tools/_rank_genes_groups.py +++ b/src/scanpy/tools/_rank_genes_groups.py @@ -13,6 +13,7 @@ from .. import _utils from .. import logging as logg +from .._backends import backend_dispatch from .._compat import CSBase from .._settings import Default from .._settings.presets import DETest @@ -493,6 +494,7 @@ def compute_statistics( # noqa: PLR0912 self.stats.index = self.var_names +@backend_dispatch def rank_genes_groups( # noqa: PLR0912, PLR0913, PLR0915 adata: AnnData, groupby: str, diff --git a/src/scanpy/tools/_score_genes.py b/src/scanpy/tools/_score_genes.py index 6ba7f51a4c..69ba104bd5 100644 --- a/src/scanpy/tools/_score_genes.py +++ b/src/scanpy/tools/_score_genes.py @@ -8,6 +8,7 @@ import pandas as pd from .. import logging as logg +from .._backends import backend_dispatch from .._compat import CSBase from .._docs import doc_rng from .._settings import Default, settings @@ -53,8 +54,9 @@ def _sparse_nanmean(x: CSBase, /, axis: Literal[0, 1]) -> NDArray[np.float64]: return m -@_doc_params(rng=doc_rng) @_accepts_legacy_random_state(0) +@backend_dispatch +@_doc_params(rng=doc_rng) def score_genes( # noqa: PLR0913 adata: AnnData, gene_list: Sequence[str] | pd.Index[str], @@ -267,6 +269,7 @@ def _nan_means( return np.nanmean(x, axis=axis, dtype=dtype) +@backend_dispatch def score_genes_cell_cycle( adata: AnnData, *, diff --git a/src/scanpy/tools/_tsne.py b/src/scanpy/tools/_tsne.py index 54fb28f436..c8d93d04b5 100644 --- a/src/scanpy/tools/_tsne.py +++ b/src/scanpy/tools/_tsne.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING from .. import logging as logg +from .._backends import backend_dispatch from .._compat import warn from .._docs import doc_rng from .._settings import settings @@ -18,6 +19,7 @@ @_accepts_legacy_random_state(0) +@backend_dispatch @_doc_params(doc_n_pcs=doc_n_pcs, use_rep=doc_use_rep, rng=doc_rng) def tsne( # noqa: PLR0913 adata: AnnData, diff --git a/src/scanpy/tools/_umap.py b/src/scanpy/tools/_umap.py index 316c6c8aed..9bc95bce90 100644 --- a/src/scanpy/tools/_umap.py +++ b/src/scanpy/tools/_umap.py @@ -6,6 +6,7 @@ from sklearn.utils import check_array from .. import logging as logg +from .._backends import backend_dispatch from .._docs import doc_rng from .._settings import settings from .._utils import NeighborsView, _doc_params @@ -28,6 +29,7 @@ @_accepts_legacy_random_state(0) +@backend_dispatch @_doc_params(rng=doc_rng) def umap( # noqa: PLR0913 adata: AnnData, diff --git a/src/testing/scanpy/__init__.py b/src/testing/scanpy/__init__.py index 08571d88a9..3995ccc90e 100644 --- a/src/testing/scanpy/__init__.py +++ b/src/testing/scanpy/__init__.py @@ -1,3 +1,18 @@ """Scanpy testing utilities.""" -# This file is empty until we design its public API. +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ._backends import validate_backend + +__all__ = ["validate_backend"] + + +def __getattr__(name: str): + if name == "validate_backend": + from ._backends import validate_backend + + return validate_backend + raise AttributeError(name) diff --git a/src/testing/scanpy/_backends.py b/src/testing/scanpy/_backends.py new file mode 100644 index 0000000000..87b9438c6b --- /dev/null +++ b/src/testing/scanpy/_backends.py @@ -0,0 +1,102 @@ +"""Backend conformance helpers for Scanpy-compatible accelerators.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +from anndata import AnnData +from scverse_backends.testing import run_conformance + +import scanpy as sc +from scanpy._backends import get_backend +from scanpy._compat import SpBase + +if TYPE_CHECKING: + from collections.abc import Sequence + + +def _as_numpy(x): + if hasattr(x, "get"): + x = x.get() + if isinstance(x, SpBase): + return x.toarray() + return np.asarray(x) + + +def _conformance_adata() -> AnnData: + x = np.array( + [ + [10, 0, 0, 8], + [9, 1, 0, 7], + [0, 8, 9, 0], + [1, 7, 8, 0], + [5, 5, 1, 1], + [4, 4, 2, 2], + ], + dtype=np.float32, + ) + return AnnData(x) + + +def _test_normalize_total(backend_name: str) -> None: + expected = _conformance_adata() + actual = expected.copy() + + sc.pp.normalize_total(expected, target_sum=10, backend="cpu") + sc.pp.normalize_total(actual, target_sum=10, backend=backend_name) + + np.testing.assert_allclose(_as_numpy(actual.X), _as_numpy(expected.X)) + + +def _test_neighbors(backend_name: str) -> None: + expected = _conformance_adata() + actual = expected.copy() + + sc.pp.neighbors(expected, n_neighbors=3, use_rep="X", backend="cpu") + sc.pp.neighbors(actual, n_neighbors=3, use_rep="X", backend=backend_name) + + for key in ["distances", "connectivities"]: + np.testing.assert_allclose( + _as_numpy(actual.obsp[key]), + _as_numpy(expected.obsp[key]), + rtol=1e-5, + atol=1e-6, + ) + + +def _test_umap(backend_name: str) -> None: + adata = _conformance_adata() + sc.pp.neighbors(adata, n_neighbors=3, use_rep="X", backend="cpu") + + sc.tl.umap(adata, rng=0, backend=backend_name) + + assert "X_umap" in adata.obsm + assert adata.obsm["X_umap"].shape == (adata.n_obs, 2) + assert np.isfinite(_as_numpy(adata.obsm["X_umap"])).all() + + +_TESTS = { + "normalize_total": _test_normalize_total, + "neighbors": _test_neighbors, + "umap": _test_umap, +} + + +def validate_backend( + backend_name: str, + *, + functions: Sequence[str] | None = None, + raise_on_failure: bool = True, +) -> dict[str, str]: + """Run Scanpy's backend conformance checks against an installed backend.""" + return run_conformance( + backend_name=backend_name, + tests=_TESTS, + get_backend=get_backend, + functions=functions, + raise_on_failure=raise_on_failure, + ) + + +__all__ = ["validate_backend"] diff --git a/tests/test_backends.py b/tests/test_backends.py new file mode 100644 index 0000000000..5f24c0edcd --- /dev/null +++ b/tests/test_backends.py @@ -0,0 +1,180 @@ +from __future__ import annotations + +from copy import copy +from inspect import Parameter, signature + +import numpy as np +import pytest +from anndata import AnnData + +import scanpy as sc +from scanpy._backends import dispatcher +from scanpy._backends import settings as backend_settings +from scanpy.testing import validate_backend + + +class FakeRapidsBackend: + name = "rapids_singlecell" + aliases = ("cuda", "rapids", "rapids-singlecell") + + def normalize_total( + self, + adata: AnnData, + *, + target_sum: float | None = None, + fake_param: str | None = None, + ) -> None: + """Fake normalize_total backend implementation. + + Parameters + ---------- + adata + Annotated data matrix. + target_sum + Target total counts. + fake_param + Backend-only parameter used by tests. + """ + sc.pp.normalize_total(adata, target_sum=target_sum, backend="cpu") + adata.uns["fake_backend_called"] = { + "function": "normalize_total", + "fake_param": fake_param, + "target_sum": target_sum, + } + + +DISPATCHED_FUNCTIONS = [ + sc.pp.calculate_qc_metrics, + sc.pp.filter_cells, + sc.pp.filter_genes, + sc.pp.log1p, + sc.pp.highly_variable_genes, + sc.pp.normalize_total, + sc.pp.regress_out, + sc.pp.scale, + sc.pp.pca, + sc.pp.harmony_integrate, + sc.pp.scrublet, + sc.pp.scrublet_simulate_doublets, + sc.pp.neighbors, + sc.tl.umap, + sc.tl.tsne, + sc.tl.diffmap, + sc.tl.draw_graph, + sc.tl.embedding_density, + sc.tl.louvain, + sc.tl.leiden, + sc.tl.score_genes, + sc.tl.score_genes_cell_cycle, + sc.tl.rank_genes_groups, + sc.get.aggregate, +] + + +@pytest.fixture +def fake_rapids_backend(): + registry = dispatcher._registry + dispatch_impl = dispatcher._dispatch_impl + old_backend = backend_settings.backend + old_state = { + "_backends": copy(registry._backends), + "_alias_map": copy(registry._alias_map), + "_load_errors": copy(registry._load_errors), + "_registration_errors": copy(registry._registration_errors), + "_warned_untrusted": copy(registry._warned_untrusted), + "_discovered": registry._discovered, + "_sig_cache": copy(dispatch_impl._sig_cache), + } + + backend_settings._backend_var.set("cpu") + registry._backends.clear() + registry._alias_map.clear() + registry._load_errors.clear() + registry._registration_errors.clear() + registry._warned_untrusted.clear() + registry._discovered = True + registry._register_backend( + FakeRapidsBackend(), + entrypoint_name="rapids_singlecell", + distribution_name="rapids-singlecell", + object_ref="rapids_singlecell.backends.scanpy:ScanpyBackend", + ) + dispatch_impl._sig_cache.clear() + dispatch_impl._update_signatures() + + yield + + backend_settings._backend_var.set(old_backend) + registry._backends.clear() + registry._backends.update(old_state["_backends"]) + registry._alias_map.clear() + registry._alias_map.update(old_state["_alias_map"]) + registry._load_errors.clear() + registry._load_errors.update(old_state["_load_errors"]) + registry._registration_errors.clear() + registry._registration_errors.update(old_state["_registration_errors"]) + registry._warned_untrusted.clear() + registry._warned_untrusted.update(old_state["_warned_untrusted"]) + registry._discovered = old_state["_discovered"] + dispatch_impl._sig_cache.clear() + dispatch_impl._sig_cache.update(old_state["_sig_cache"]) + dispatch_impl._update_signatures() + + +@pytest.mark.parametrize("func", DISPATCHED_FUNCTIONS) +def test_dispatched_functions_have_backend_keyword(func): + backend = signature(func).parameters["backend"] + + assert backend.kind is Parameter.KEYWORD_ONLY + assert backend.default is None + + +def test_settings_resolve_rapids_alias(fake_rapids_backend): + sc.settings.backend = "cuda" + + assert sc.settings.backend == "rapids_singlecell" + assert sc.settings.get_backend("rapids") is not None + assert sc.settings.available_backends() == ["rapids_singlecell"] + + +def test_use_backend_dispatches_and_restores(fake_rapids_backend): + adata = AnnData(np.ones((2, 2), dtype=np.float32)) + + with sc.settings.use_backend("rapids"): + assert sc.settings.backend == "rapids_singlecell" + sc.pp.normalize_total(adata, target_sum=4, fake_param="from-backend") + + assert sc.settings.backend == "cpu" + assert adata.uns["fake_backend_called"] == { + "function": "normalize_total", + "fake_param": "from-backend", + "target_sum": 4, + } + + +def test_call_backend_overrides_settings(fake_rapids_backend): + adata = AnnData(np.array([[1, 1], [2, 2]], dtype=np.float32)) + + with sc.settings.use_backend("cuda"): + sc.pp.normalize_total(adata, target_sum=1, backend="cpu") + + assert "fake_backend_called" not in adata.uns + np.testing.assert_allclose(adata.X.sum(axis=1), [1, 1]) + + +def test_backend_only_parameters_are_injected(fake_rapids_backend): + fake_param = signature(sc.pp.normalize_total).parameters["fake_param"] + + assert fake_param.kind is Parameter.KEYWORD_ONLY + assert fake_param.default is None + + +def test_backend_conformance_harness(fake_rapids_backend): + assert validate_backend("cuda", functions=["normalize_total"]) == { + "normalize_total": "PASSED", + } + + +def test_reserved_gpu_backend_name(): + with pytest.raises(ValueError, match="reserved by scanpy"): + sc.settings.backend = "gpu"