Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ dependencies = [
"pynndescent>=0.5.13",
"scikit-learn>=1.6",
"scipy>=1.13",
"scverse-backends>=0.0.2,<0.1",
"scverse-misc[settings]>=0.0.5",
"seaborn>=0.13.2",
"session-info2",
Expand Down
42 changes: 42 additions & 0 deletions src/scanpy/_backends.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Backend dispatch integration for optional Scanpy accelerators."""

from __future__ import annotations

from scverse_backends import BackendDispatcher

dispatcher = BackendDispatcher(
entrypoint_group="scanpy.backends",
host_name="scanpy",
trusted_backends={
"rapids_singlecell": {
"aliases": ["cuda", "rapids", "rapids-singlecell"],
"distributions": [
"rapids-singlecell",
"rapids-singlecell-cu12",
"rapids-singlecell-cu13",
],
"package": "rapids-singlecell",
},
},
reserved_backends={
"gpu": (
"Use 'cuda' for the RAPIDS backend. The generic 'gpu' selector is "
"reserved so future GPU backends can coexist without ambiguity."
),
},
)

backend_dispatch = dispatcher.backend_dispatch
settings = dispatcher.settings
get_backend = dispatcher.get_backend
available_backend_names = dispatcher.available_backend_names
discover = dispatcher.discover

__all__ = [
"available_backend_names",
"backend_dispatch",
"discover",
"dispatcher",
"get_backend",
"settings",
]
34 changes: 33 additions & 1 deletion src/scanpy/_settings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,39 @@ def model_post_init(self, context: object) -> None:

@computed_field
@property
def logpath(self) -> Path | None:
def backend(cls) -> str:
"""Active computational backend (default ``'cpu'``)."""
from .._backends import settings as backend_settings

return backend_settings.backend

@backend.setter
@_type_check_arg2(str)
def backend(cls, backend: str) -> None:
from .._backends import settings as backend_settings

backend_settings.backend = backend

def use_backend(cls, backend: str):
"""Temporarily set the active computational backend."""
from .._backends import settings as backend_settings

return backend_settings.use_backend(backend)

def available_backends(cls) -> list[str]:
"""Return canonical names of installed computational backends."""
from .._backends import settings as backend_settings

return backend_settings.available_backends()

def get_backend(cls, name: str):
"""Look up an installed computational backend by name or alias."""
from .._backends import settings as backend_settings

return backend_settings.get_backend(name)

@property
def logpath(cls) -> Path | None:
"""The file path `logfile` was set to."""
if self.logfile is _default_logfile():
return None
Expand Down
2 changes: 2 additions & 0 deletions src/scanpy/get/_aggregated.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from scipy import sparse
from sklearn.utils.sparsefuncs import csc_median_axis_0

from scanpy._backends import backend_dispatch
from scanpy._compat import CSBase, CSRBase, DaskArray

from .._utils import _resolve_axis, get_literal_vals
Expand Down Expand Up @@ -197,6 +198,7 @@ def _power(x: Array, power: float) -> Array:
return x**power if isinstance(x, np.ndarray) else x.power(power)


@backend_dispatch
def aggregate( # noqa: PLR0912
adata: AnnData,
by: str | Collection[str],
Expand Down
4 changes: 3 additions & 1 deletion src/scanpy/neighbors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from .. import _utils
from .. import logging as logg
from .._backends import backend_dispatch
from .._compat import CSBase, CSRBase, SpBase, pkg_version, warn
from .._docs import doc_rng
from .._settings import settings
Expand Down Expand Up @@ -82,8 +83,9 @@ class NeighborsParams(TypedDict): # noqa: D101
n_pcs: NotRequired[int]


@_doc_params(n_pcs=doc_n_pcs, use_rep=doc_use_rep, rng=doc_rng)
@_accepts_legacy_random_state(_DEFAULT_SEED := 0)
@backend_dispatch
@_doc_params(n_pcs=doc_n_pcs, use_rep=doc_use_rep, rng=doc_rng)
def neighbors( # noqa: PLR0913
adata: AnnData,
n_neighbors: int = 15,
Expand Down
2 changes: 2 additions & 0 deletions src/scanpy/preprocessing/_harmony/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np

from ..._backends import backend_dispatch
from ..._compat import warn

if TYPE_CHECKING:
Expand All @@ -16,6 +17,7 @@
from ..._utils.random import RNGLike, SeedLike


@backend_dispatch
def harmony_integrate( # noqa: PLR0913
adata: AnnData,
key: str | Sequence[str],
Expand Down
2 changes: 2 additions & 0 deletions src/scanpy/preprocessing/_highly_variable_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from fast_array_utils import stats

from .. import logging as logg
from .._backends import backend_dispatch
from .._compat import CSBase, CSRBase, DaskArray, warn
from .._settings import Default, Verbosity, settings
from .._utils import (
Expand Down Expand Up @@ -591,6 +592,7 @@ def _highly_variable_genes_batched(
return df


@backend_dispatch
def highly_variable_genes( # noqa: PLR0913
adata: AnnData,
*,
Expand Down
2 changes: 2 additions & 0 deletions src/scanpy/preprocessing/_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from fast_array_utils.numba import njit

from .. import logging as logg
from .._backends import backend_dispatch
from .._compat import CSBase, CSCBase, CSRBase, DaskArray, warn
from .._utils import axis_mul_or_truediv, dematrix, view_to_actual
from ..get import _get_obs_rep, _set_obs_rep
Expand Down Expand Up @@ -124,6 +125,7 @@ def _normalize_total_helper(
return x, counts_per_cell, gene_subset


@backend_dispatch
def normalize_total( # noqa: PLR0912
adata: AnnData,
*,
Expand Down
4 changes: 3 additions & 1 deletion src/scanpy/preprocessing/_pca/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from packaging.version import Version

from ... import logging as logg
from ..._backends import backend_dispatch
from ..._compat import CSBase, DaskArray, pkg_version, warn
from ..._docs import doc_rng
from ..._settings import Default, settings
Expand Down Expand Up @@ -51,8 +52,9 @@
type SvdSolver = SvdSolvDaskML | SvdSolvSkearn | SvdSolvPCACustom


@_doc_params(mask_var=doc_mask_var, rng=doc_rng)
@_accepts_legacy_random_state(0)
@backend_dispatch
@_doc_params(mask_var=doc_mask_var, rng=doc_rng)
def pca( # noqa: PLR0912, PLR0913, PLR0915
data: AnnData | np.ndarray | CSBase,
n_comps: int | None = None,
Expand Down
2 changes: 2 additions & 0 deletions src/scanpy/preprocessing/_qc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from scanpy.get import _get_obs_rep
from scanpy.preprocessing._distributed import materialize_as_ndarray

from .._backends import backend_dispatch
from .._compat import CSBase, CSRBase, DaskArray, warn
from .._utils import _doc_params, axis_nnz
from ._docs import (
Expand Down Expand Up @@ -204,6 +205,7 @@ def describe_var(
doc_obs_qc_returns=doc_obs_qc_returns,
doc_var_qc_returns=doc_var_qc_returns,
)
@backend_dispatch
def calculate_qc_metrics(
adata: AnnData,
*,
Expand Down
2 changes: 2 additions & 0 deletions src/scanpy/preprocessing/_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from fast_array_utils.stats import mean_var

from .. import logging as logg
from .._backends import backend_dispatch
from .._compat import CSBase, CSCBase, CSRBase, DaskArray, warn
from .._settings import Default, settings
from .._utils import (
Expand Down Expand Up @@ -68,6 +69,7 @@ def clip_array(
return x


@backend_dispatch
@singledispatch
def scale[A: _Array](
data: AnnData | A,
Expand Down
3 changes: 3 additions & 0 deletions src/scanpy/preprocessing/_scrublet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from ... import logging as logg
from ... import preprocessing as pp
from ..._backends import backend_dispatch
from ..._docs import doc_rng
from ..._utils import _doc_params
from ..._utils.random import _accepts_legacy_random_state, _LegacyRng
Expand All @@ -23,6 +24,7 @@


@_accepts_legacy_random_state(0)
@backend_dispatch
@_doc_params(rng=doc_rng)
def scrublet( # noqa: PLR0913
adata: AnnData,
Expand Down Expand Up @@ -497,6 +499,7 @@ def _scrublet_call_doublets( # noqa: PLR0913


@_accepts_legacy_random_state(0)
@backend_dispatch
def scrublet_simulate_doublets(
adata: AnnData,
*,
Expand Down
5 changes: 5 additions & 0 deletions src/scanpy/preprocessing/_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from sklearn.utils import check_array

from .. import logging as logg
from .._backends import backend_dispatch
from .._compat import CSBase, CSRBase, DaskArray
from .._docs import doc_rng
from .._settings import settings
Expand Down Expand Up @@ -50,6 +51,7 @@
from .._utils.random import RNGLike, SeedLike


@backend_dispatch
def filter_cells(
data: AnnData | CSBase | np.ndarray | DaskArray,
*,
Expand Down Expand Up @@ -196,6 +198,7 @@ def filter_cells(
return cell_subset, number_per_cell


@backend_dispatch
def filter_genes(
data: AnnData | CSBase | np.ndarray | DaskArray,
*,
Expand Down Expand Up @@ -306,6 +309,7 @@ def filter_genes(
return gene_subset, number_per_gene


@backend_dispatch
@singledispatch
def log1p(
data: AnnData | np.ndarray | CSBase,
Expand Down Expand Up @@ -507,6 +511,7 @@ def numpy_regress_out(
return data


@backend_dispatch
def regress_out(
adata: AnnData,
keys: str | Sequence[str],
Expand Down
7 changes: 7 additions & 0 deletions src/scanpy/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Testing helpers for Scanpy extensions."""

from __future__ import annotations

from testing.scanpy._backends import validate_backend

__all__ = ["validate_backend"]
4 changes: 3 additions & 1 deletion src/scanpy/tools/_diffmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np

from .._backends import backend_dispatch
from .._docs import doc_rng
from .._utils import _doc_params
from .._utils.random import _accepts_legacy_random_state
Expand All @@ -15,8 +16,9 @@
from .._utils.random import RNGLike, SeedLike


@_doc_params(rng=doc_rng)
@_accepts_legacy_random_state(0)
@backend_dispatch
@_doc_params(rng=doc_rng)
def diffmap(
adata: AnnData,
n_comps: int = 15,
Expand Down
4 changes: 3 additions & 1 deletion src/scanpy/tools/_draw_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from .. import _utils
from .. import logging as logg
from .._backends import backend_dispatch
from .._docs import doc_rng
from .._utils import _choose_graph, _doc_params, get_literal_vals
from .._utils.random import (
Expand All @@ -31,8 +32,9 @@
type _Layout = Literal["fr", "drl", "kk", "grid_fr", "lgl", "rt", "rt_circular", "fa"]


@_doc_params(rng=doc_rng)
@_accepts_legacy_random_state(0)
@backend_dispatch
@_doc_params(rng=doc_rng)
def draw_graph( # noqa: PLR0913
adata: AnnData,
layout: _Layout = "fa",
Expand Down
2 changes: 2 additions & 0 deletions src/scanpy/tools/_embedding_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np

from .. import logging as logg
from .._backends import backend_dispatch
from .._utils import sanitize_anndata

if TYPE_CHECKING:
Expand All @@ -32,6 +33,7 @@ def _calc_density(x: np.ndarray, y: np.ndarray):
return scaled_z


@backend_dispatch
def embedding_density( # noqa: PLR0912
adata: AnnData,
basis: str = "umap",
Expand Down
4 changes: 3 additions & 1 deletion src/scanpy/tools/_leiden.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from .. import _utils
from .. import logging as logg
from .._backends import backend_dispatch
from .._compat import warn
from .._docs import doc_rng
from .._settings import Default
Expand Down Expand Up @@ -43,6 +44,8 @@
)


@_accepts_legacy_random_state(0)
@backend_dispatch
@_doc_params(
doc_adata=doc_adata,
restrict_to=doc_restrict_to,
Expand All @@ -51,7 +54,6 @@
obsp=doc_obsp,
rng=doc_rng,
)
@_accepts_legacy_random_state(0)
def leiden( # noqa: PLR0913
adata: AnnData,
resolution: float = 1,
Expand Down
4 changes: 3 additions & 1 deletion src/scanpy/tools/_louvain.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@

from .. import _utils
from .. import logging as logg
from .._compat import pkg_version
from .._backends import backend_dispatch
from .._compat import deprecated, pkg_version
from .._utils import _choose_graph, _doc_params
from ._docs import (
doc_adata,
Expand Down Expand Up @@ -40,6 +41,7 @@
)


@backend_dispatch
@deprecated(Deprecation("1.12.0", "Use :func:`scanpy.tl.leiden` instead."))
@_doc_params(
doc_adata=doc_adata,
Expand Down
Loading
Loading