diff --git a/docs/api/preprocessing.md b/docs/api/preprocessing.md index 0950b9b296..5a3be50767 100644 --- a/docs/api/preprocessing.md +++ b/docs/api/preprocessing.md @@ -47,9 +47,12 @@ For visual quality control, see {func}`~scanpy.pl.highest_expr_genes` and pp.recipe_seurat ``` -## Batch effect correction +(pp-data-integration)= -Also see {ref}`data-integration`. Note that a simple batch correction method is available via {func}`pp.regress_out`. Checkout {mod}`scanpy.external` for more. +## Data integration + +Batch effect correction and other data integration. +Note that a simple batch correction method is available via {func}`pp.regress_out`. ```{eval-rst} .. autosummary:: @@ -57,8 +60,11 @@ Also see {ref}`data-integration`. Note that a simple batch correction method is :toctree: generated/ pp.combat + pp.harmony_integrate ``` +Also see {ref}`data integration tools <data-integration>` and external {ref}`external data integration <external-data-integration>`. + ## Doublet detection ```{eval-rst} diff --git a/docs/external/preprocessing.md b/docs/external/preprocessing.md index 8cbf6449a6..6702e17b22 100644 --- a/docs/external/preprocessing.md +++ b/docs/external/preprocessing.md @@ -5,6 +5,11 @@ ..
currentmodule:: scanpy.external ``` +Previously found here, but now part of scanpy’s main API: +- {func}`scanpy.pp.harmony_integrate` +- {func}`scanpy.pp.scrublet` +- {func}`scanpy.pp.scrublet_simulate_doublets` + (external-data-integration)= ## Data integration @@ -14,10 +19,8 @@ :toctree: ../generated/ pp.bbknn - pp.harmony_integrate pp.mnn_correct pp.scanorama_integrate - ``` ## Sample demultiplexing @@ -39,5 +42,4 @@ Note that the fundamental limitations of imputation are still under [debate](htt pp.dca pp.magic - ``` diff --git a/docs/release-notes/1.10.0.md b/docs/release-notes/1.10.0.md index 969633db0b..b12dd11e06 100644 --- a/docs/release-notes/1.10.0.md +++ b/docs/release-notes/1.10.0.md @@ -25,7 +25,7 @@ Some highlights: * {func}`scanpy.datasets.blobs` now accepts a `random_state` argument {pr}`2683` {smaller}`E Roellin` * {func}`scanpy.pp.pca` and {func}`scanpy.pp.regress_out` now accept a layer argument {pr}`2588` {smaller}`S Dicks` * {func}`scanpy.pp.subsample` with `copy=True` can now be called in backed mode {pr}`2624` {smaller}`E Roellin` -* {func}`scanpy.external.pp.harmony_integrate` now runs with 64 bit floats improving reproducibility {pr}`2655` {smaller}`S Dicks` +* {func}`scanpy.pp.harmony_integrate` now runs with 64 bit floats improving reproducibility {pr}`2655` {smaller}`S Dicks` * {func}`scanpy.tl.rank_genes_groups` no longer warns that it's default was changed from t-test_overestim_var to t-test {pr}`2798` {smaller}`L Heumos` * `scanpy.pp.calculate_qc_metrics` now allows `qc_vars` to be passed as a string {pr}`2859` {smaller}`N Teyssier` * {func}`scanpy.tl.leiden` and {func}`scanpy.tl.louvain` now store clustering parameters in the key provided by the `key_added` parameter instead of always writing to (or overwriting) a default key {pr}`2864` {smaller}`J Fan` diff --git a/docs/release-notes/1.11.0.md b/docs/release-notes/1.11.0.md index 875b32f364..70ceab0cdf 100644 --- a/docs/release-notes/1.11.0.md +++ b/docs/release-notes/1.11.0.md @@ 
-30,7 +30,7 @@ Release candidates: #### Documentation -- {guilabel}`rc1` Improve {func}`~scanpy.external.pp.harmony_integrate` docs {smaller}`D Kühl` ({pr}`3362`) +- {guilabel}`rc1` Improve {func}`~scanpy.pp.harmony_integrate` docs {smaller}`D Kühl` ({pr}`3362`) - {guilabel}`rc1` Raise {exc}`FutureWarning` when calling deprecated {mod}`scanpy.pp` functions {smaller}`P Angerer` ({pr}`3380`) - {guilabel}`rc1` {smaller}`P Angerer` ({pr}`3407`) diff --git a/pyproject.toml b/pyproject.toml index 126098f593..6087912dec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,7 +87,6 @@ bbknn = [ "bbknn" ] dask = [ "anndata[dask]", "dask[array]>=2024.5.1" ] # PCA acceleration dask-ml = [ "dask-ml", "scanpy[dask]" ] -harmony = [ "harmonypy" ] leiden = [ "igraph>=0.10.8", "leidenalg>=0.10.1" ] louvain = [ "igraph", "louvain>=0.8.2", "setuptools" ] magic = [ "magic-impute>=2.0.4" ] @@ -136,6 +135,7 @@ doc = [ ] test-min = [ "dependency-groups", # for CI scripts doctests + "pooch", "pytest", "pytest-cov", # only for use from VS Code "pytest-mock", diff --git a/src/scanpy/external/pp/__init__.py b/src/scanpy/external/pp/__init__.py index a8b09f725d..aa3135f95f 100644 --- a/src/scanpy/external/pp/__init__.py +++ b/src/scanpy/external/pp/__init__.py @@ -3,26 +3,39 @@ from __future__ import annotations from ..._compat import deprecated -from ...preprocessing import _scrublet from ._bbknn import bbknn from ._dca import dca -from ._harmony_integrate import harmony_integrate from ._hashsolo import hashsolo from ._magic import magic from ._mnn_correct import mnn_correct from ._scanorama_integrate import scanorama_integrate -scrublet = deprecated("Import from sc.pp instead")(_scrublet.scrublet) -scrublet_simulate_doublets = deprecated("Import from sc.pp instead")( - _scrublet.scrublet_simulate_doublets -) - __all__ = [ "bbknn", "dca", - "harmony_integrate", "hashsolo", "magic", "mnn_correct", "scanorama_integrate", ] + + +@deprecated("Import from sc.pp instead") +def 
harmony_integrate(*args, **kwargs): + from ...preprocessing import harmony_integrate + + return harmony_integrate(*args, **kwargs) + + +@deprecated("Import from sc.pp instead") +def scrublet(*args, **kwargs): + from ...preprocessing import scrublet + + return scrublet(*args, **kwargs) + + +@deprecated("Import from sc.pp instead") +def scrublet_simulate_doublets(*args, **kwargs): + from ...preprocessing import scrublet_simulate_doublets + + return scrublet_simulate_doublets(*args, **kwargs) diff --git a/src/scanpy/external/pp/_harmony_integrate.py b/src/scanpy/external/pp/_harmony_integrate.py deleted file mode 100644 index d4e9982fef..0000000000 --- a/src/scanpy/external/pp/_harmony_integrate.py +++ /dev/null @@ -1,100 +0,0 @@ -"""Use harmony to integrate cells from different experiments.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -import numpy as np - -from ..._compat import old_positionals -from ..._utils._doctests import doctest_needs - -if TYPE_CHECKING: - from collections.abc import Sequence - - from anndata import AnnData - - -@old_positionals("basis", "adjusted_basis") -@doctest_needs("harmonypy") -def harmony_integrate( - adata: AnnData, - key: str | Sequence[str], - *, - basis: str = "X_pca", - adjusted_basis: str = "X_pca_harmony", - **kwargs, -): - """Use harmonypy :cite:p:`Korsunsky2019` to integrate different experiments. - - Harmony :cite:p:`Korsunsky2019` is an algorithm for integrating single-cell - data from multiple experiments. This function uses the python - port of Harmony, ``harmonypy``, to integrate single-cell data - stored in an AnnData object. As Harmony works by adjusting the - principal components, this function should be run after performing - PCA but before computing the neighbor graph, as illustrated in the - example below. - - Parameters - ---------- - adata - The annotated data matrix. - key - The name of the column in ``adata.obs`` that differentiates - among experiments/batches. 
To integrate over two or more covariates, - you can pass multiple column names as a list. See ``vars_use`` - parameter of the ``harmonypy`` package for more details. - basis - The name of the field in ``adata.obsm`` where the PCA table is - stored. Defaults to ``'X_pca'``, which is the default for - ``sc.pp.pca()``. - adjusted_basis - The name of the field in ``adata.obsm`` where the adjusted PCA - table will be stored after running this function. Defaults to - ``X_pca_harmony``. - kwargs - Any additional arguments will be passed to - ``harmonypy.run_harmony()``. - - Returns - ------- - Updates adata with the field ``adata.obsm[obsm_out_field]``, - containing principal components adjusted by Harmony such that - different experiments are integrated. - - Example - ------- - First, load libraries and example dataset, and preprocess. - - >>> import scanpy as sc - >>> import scanpy.external as sce - >>> adata = sc.datasets.pbmc3k() - >>> sc.pp.recipe_zheng17(adata) - >>> sc.pp.pca(adata) - - We now arbitrarily assign a batch metadata variable to each cell - for the sake of example, but during real usage there would already - be a column in ``adata.obs`` giving the experiment each cell came - from. - - >>> adata.obs["batch"] = 1350 * ["a"] + 1350 * ["b"] - - Finally, run harmony. Afterwards, there will be a new table in - ``adata.obsm`` containing the adjusted PC's. 
- - >>> sce.pp.harmony_integrate(adata, "batch") - >>> "X_pca_harmony" in adata.obsm - True - - """ - try: - import harmonypy - except ImportError as e: - e.add_note("Please install `harmonypy` and try again.") - raise - - x = adata.obsm[basis].astype(np.float64) - - harmony_out = harmonypy.run_harmony(x, adata.obs, key, **kwargs) - - adata.obsm[adjusted_basis] = harmony_out.Z_corr.T diff --git a/src/scanpy/preprocessing/__init__.py b/src/scanpy/preprocessing/__init__.py index d80412860d..95cc13a57b 100644 --- a/src/scanpy/preprocessing/__init__.py +++ b/src/scanpy/preprocessing/__init__.py @@ -6,6 +6,7 @@ from ._combat import combat from ._deprecated.highly_variable_genes import filter_genes_dispersion from ._deprecated.sampling import subsample +from ._harmony_integrate import harmony_integrate from ._highly_variable_genes import highly_variable_genes from ._normalization import normalize_total from ._pca import pca @@ -31,6 +32,7 @@ "filter_cells", "filter_genes", "filter_genes_dispersion", + "harmony_integrate", "highly_variable_genes", "log1p", "neighbors", diff --git a/src/scanpy/preprocessing/_harmony/__init__.py b/src/scanpy/preprocessing/_harmony/__init__.py new file mode 100644 index 0000000000..5a5c0c6b3d --- /dev/null +++ b/src/scanpy/preprocessing/_harmony/__init__.py @@ -0,0 +1,525 @@ +from __future__ import annotations + +from dataclasses import KW_ONLY, InitVar, dataclass, field +from itertools import product +from typing import TYPE_CHECKING + +import numpy as np +from scipy.sparse import csr_matrix # noqa: TID251 +from sklearn.cluster import KMeans +from tqdm.auto import tqdm + +from ... import logging as log +from ..._settings import settings +from ..._settings.verbosity import Verbosity + +if TYPE_CHECKING: + from collections.abc import Sequence + from typing import Literal + + import pandas as pd + + from ..._compat import CSBase + + +@dataclass +class Harmony: + """Harmony batch correction algorithm. 
+ + Parameters + ---------- + x + Data matrix (n_cells x d) - typically PCA embeddings. + batch_df + DataFrame containing batch information. + batch_key + Column name(s) in batch_df containing batch labels. + """ + + batch_df: InitVar[pd.DataFrame] + batch_key: InitVar[str | Sequence[str]] + _: KW_ONLY + theta: float | Sequence[float] | None + sigma: float + n_clusters: int | None + max_iter_harmony: int + max_iter_clustering: int + tol_harmony: float + tol_clustering: float + ridge_lambda: float + correction_method: Literal["fast", "original"] + block_proportion: float + random_state: int | None + sparse: bool + + batch_codes: np.ndarray = field(init=False) + n_batches: int = field(init=False) + + def __post_init__( + self, batch_df: pd.DataFrame, batch_key: str | Sequence[str] + ) -> None: + if self.max_iter_harmony < 1: + msg = "max_iter_harmony must be >= 1" + raise ValueError(msg) + + # Process batch keys + self.batch_codes, self.n_batches = _get_batch_codes(batch_df, batch_key) + + def fit(self, x: np.ndarray) -> np.ndarray: + """Run Harmony. + + Returns + ------- + z_corr + Batch-corrected embedding matrix (n_cells x d). 
+ """ + if self.random_state is not None: + np.random.seed(self.random_state) + + # Ensure input is C-contiguous float array (infer dtype from x) + x = np.ascontiguousarray(x) + n_cells = x.shape[0] + + # Normalize input for clustering + z_norm = _normalize_rows_l2(x) + + # Build phi matrix (one-hot encoding of batches) + if self.sparse: + phi = _one_hot_encode_sparse(self.batch_codes, self.n_batches, x.dtype) + n_b = np.asarray(phi.sum(axis=0)).ravel() + else: + phi = _one_hot_encode(self.batch_codes, self.n_batches, x.dtype) + n_b = phi.sum(axis=0) + pr_b = (n_b / n_cells).reshape(-1, 1) + + # Set default theta + if self.theta is None: + theta_arr = np.ones(self.n_batches, dtype=x.dtype) * 2.0 + elif isinstance(self.theta, (int, float)): + theta_arr = np.ones(self.n_batches, dtype=x.dtype) * float(self.theta) + else: + theta_arr = np.array(self.theta, dtype=x.dtype) + theta_arr = theta_arr.reshape(1, -1) + + # Set default n_clusters + if self.n_clusters is None: + n_clusters = int(min(100, n_cells / 30)) + n_clusters = max(n_clusters, 2) + else: + n_clusters = self.n_clusters + + # Initialize centroids and state arrays + r, e, o, obj_init = _initialize_centroids( + z_norm, + phi, + pr_b, + n_clusters=n_clusters, + sigma=self.sigma, + theta=theta_arr, + random_state=self.random_state, + ) + + # Main Harmony loop + objectives_harmony = [obj_init] + with tqdm( + range(self.max_iter_harmony), disable=settings.verbosity < Verbosity.info + ) as bar: + for i in bar: + r, e, o, obj = self._cluster( + z_norm, pr_b, r=r, e=e, o=o, theta=theta_arr + ) + if obj is not None: + objectives_harmony.append(obj) + z_hat = self._correct(x, r, o) + z_norm = _normalize_rows_l2(z_hat) + if self._is_convergent(objectives_harmony, self.tol_harmony): + log.info(f"Harmony converged in {i + 1} iterations") + break + else: + log.info( + f"Harmony did not converge after {self.max_iter_harmony} iterations." 
+ ) + + return z_hat + + @staticmethod + def _is_convergent(objectives: list[float], tol: float) -> bool: + """Check Harmony convergence.""" + if len(objectives) < 2: + return False + obj_old = objectives[-2] + obj_new = objectives[-1] + return (obj_old - obj_new) < tol * abs(obj_old) + + def _cluster( + self, + z_norm: np.ndarray, + pr_b: np.ndarray, + *, + r: np.ndarray, + e: np.ndarray, + o: np.ndarray, + theta: np.ndarray, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, float | None]: + """Perform clustering step. Modifies r, e, o in-place.""" + return _clustering( + z_norm, + self.batch_codes, + self.n_batches, + pr_b, + r=r, + e=e, + o=o, + theta=theta, + sigma=self.sigma, + max_iter=self.max_iter_clustering, + tol=self.tol_clustering, + block_proportion=self.block_proportion, + ) + + def _correct(self, x: np.ndarray, r: np.ndarray, o: np.ndarray) -> np.ndarray: + """Perform correction step.""" + if self.correction_method == "fast": + return _correction_fast( + x, + self.batch_codes, + self.n_batches, + r, + o, + ridge_lambda=self.ridge_lambda, + ) + else: + return _correction_original( + x, + self.batch_codes, + self.n_batches, + r, + ridge_lambda=self.ridge_lambda, + ) + + +def _get_batch_codes( + batch_df: pd.DataFrame, + batch_key: str | Sequence[str], +) -> tuple[np.ndarray, int]: + """Get batch codes from DataFrame.""" + if isinstance(batch_key, str): + batch_vec = batch_df[batch_key] + elif len(batch_key) == 1: + batch_vec = batch_df[batch_key[0]] + else: + df = batch_df[list(batch_key)].astype("str") + batch_vec = df.apply(",".join, axis=1) + + batch_cat = batch_vec.astype("category") + codes = batch_cat.cat.codes.to_numpy(copy=True) + n_batches = len(batch_cat.cat.categories) + + return codes.astype(np.int32), n_batches + + +def _one_hot_encode( + codes: np.ndarray, + n_categories: int, + dtype: np.dtype, +) -> np.ndarray: + """One-hot encode category codes.""" + n = len(codes) + phi = np.zeros((n, n_categories), dtype=dtype) + phi[np.arange(n), 
codes] = 1.0 + return phi + + +def _one_hot_encode_sparse( + codes: np.ndarray, + n_categories: int, + dtype: np.dtype, +): + """One-hot encode category codes as sparse CSR matrix.""" + n = len(codes) + data = np.ones(n, dtype=dtype) + indices = codes.astype(np.int32) + indptr = np.arange(n + 1, dtype=np.int32) + return csr_matrix((data, indices, indptr), shape=(n, n_categories)) + + +def _normalize_rows_l2(x: np.ndarray) -> np.ndarray: + """L2 normalize each row of x.""" + norms = np.linalg.norm(x, axis=1, keepdims=True) + norms = np.maximum(norms, 1e-12) + return x / norms + + +def _normalize_rows_l1(r: np.ndarray) -> None: + """L1 normalize each row of r in-place (rows sum to 1).""" + row_sums = r.sum(axis=1, keepdims=True) + row_sums = np.maximum(row_sums, 1e-12) + r /= row_sums + + +def _initialize_centroids( + z_norm: np.ndarray, + phi: np.ndarray | CSBase, + pr_b: np.ndarray, + *, + n_clusters: int, + sigma: float, + theta: np.ndarray, + random_state: int | None, +) -> tuple[np.ndarray, np.ndarray, np.ndarray, float]: + """Initialize cluster centroids using K-means.""" + kmeans = KMeans( + n_clusters=n_clusters, random_state=random_state, n_init=10, max_iter=25 + ) + kmeans.fit(z_norm) + + # Centroids + y = kmeans.cluster_centers_.copy() + y_norm = _normalize_rows_l2(y) + + # Compute soft cluster assignments r + term = -2.0 / sigma + r = _compute_r(z_norm, y_norm, term) + _normalize_rows_l1(r) + + # Initialize e (expected) and o (observed) + r_sum = r.sum(axis=0) + e = pr_b @ r_sum.reshape(1, -1) + o = phi.T @ r + + # Compute initial objective + obj = _compute_objective(y_norm, z_norm, r, theta=theta, sigma=sigma, o=o, e=e) + + return r, e, o, obj + + +def _compute_r( + z: np.ndarray, + y: np.ndarray, + term: float, +) -> np.ndarray: + """Compute soft cluster assignments using NumPy dot.""" + dots = z @ y.T + return np.exp(term * (1.0 - dots)) + + +def _clustering( # noqa: PLR0913 + z_norm: np.ndarray, + batch_codes: np.ndarray, + n_batches: int, + pr_b: 
np.ndarray, + *, + r: np.ndarray, + e: np.ndarray, + o: np.ndarray, + theta: np.ndarray, + sigma: float, + max_iter: int, + tol: float, + block_proportion: float, +) -> tuple[np.ndarray, np.ndarray, np.ndarray, float | None]: + """Run clustering iterations (modifies r, e, o in-place).""" + n_cells = z_norm.shape[0] + k = r.shape[1] + n_blocks = min(n_cells, 1 // block_proportion) + term = -2.0 / sigma + + objectives_clustering = [] + + # Pre-allocate work arrays + y = np.empty((k, z_norm.shape[1]), dtype=z_norm.dtype) + y_norm = np.empty_like(y) + + for _ in range(max_iter): + # Compute cluster centroids: y = r.T @ z_norm, then normalize + np.dot(r.T, z_norm, out=y) + norms = np.linalg.norm(y, axis=1, keepdims=True) + norms = np.maximum(norms, 1e-12) + np.divide(y, norms, out=y_norm) + + # Randomly shuffle cell indices + idx_list = np.random.permutation(n_cells) + + # Process blocks + for block_idx, b in product( + np.array_split(idx_list, n_blocks), range(n_batches) + ): + mask = batch_codes[block_idx] == b + if not np.any(mask): + continue + + cell_idx = block_idx[mask] + + # Remove old r contribution from o and e + r_old = r[cell_idx, :] + r_old_sum = r_old.sum(axis=0) + o[b, :] -= r_old_sum + e -= pr_b * r_old_sum + + # Compute new r values + dots = z_norm[cell_idx, :] @ y_norm.T + r_new = np.exp(term * (1.0 - dots)) + + # Apply penalty + penalty = ((e[b, :] + 1.0) / (o[b, :] + 1.0)) ** theta[0, b] + r_new *= penalty + + # Normalize rows to sum to 1 + row_sums = r_new.sum(axis=1, keepdims=True) + row_sums = np.maximum(row_sums, 1e-12) + r_new /= row_sums + + # Store back + r[cell_idx, :] = r_new + + # Add new r contribution to o and e + r_new_sum = r_new.sum(axis=0) + o[b, :] += r_new_sum + e += pr_b * r_new_sum + + # Compute objective + obj = _compute_objective(y_norm, z_norm, r, theta=theta, sigma=sigma, o=o, e=e) + objectives_clustering.append(obj) + + # Check convergence + if _is_convergent_clustering(objectives_clustering, tol): + obj = 
objectives_clustering[-1] + break + else: + obj = None + + return r, e, o, obj + + +def _correction_original( + x: np.ndarray, + batch_codes: np.ndarray, + n_batches: int, + r: np.ndarray, + *, + ridge_lambda: float, +) -> np.ndarray: + """Original correction method - per-cluster ridge regression.""" + _, d = x.shape + + # Ridge regularization matrix (don't penalize intercept) + id_mat = np.eye(n_batches + 1) + id_mat[0, 0] = 0 + lambda_mat = ridge_lambda * id_mat + + z = x.copy() + + for r_k in r.T: + r_sum_total = r_k.sum() + r_sum_per_batch = np.zeros(n_batches, dtype=x.dtype) + for b in range(n_batches): + r_sum_per_batch[b] = r_k[batch_codes == b].sum() + + phi_t_phi = np.zeros((n_batches + 1, n_batches + 1), dtype=x.dtype) + phi_t_phi[0, 0] = r_sum_total + phi_t_phi[0, 1:] = r_sum_per_batch + phi_t_phi[1:, 0] = r_sum_per_batch + phi_t_phi[1:, 1:] = np.diag(r_sum_per_batch) + phi_t_phi += lambda_mat + + phi_t_x = np.zeros((n_batches + 1, d), dtype=x.dtype) + phi_t_x[0, :] = r_k @ x + for b in range(n_batches): + mask = batch_codes == b + phi_t_x[b + 1, :] = r_k[mask] @ x[mask] + + try: + w = np.linalg.solve(phi_t_phi, phi_t_x) + except np.linalg.LinAlgError: + w = np.linalg.lstsq(phi_t_phi, phi_t_x, rcond=None)[0] + + w[0, :] = 0 + w_batch = w[batch_codes + 1, :] + z -= r_k[:, np.newaxis] * w_batch + + return z + + +def _correction_fast( + x: np.ndarray, + batch_codes: np.ndarray, + n_batches: int, + r: np.ndarray, + o: np.ndarray, + *, + ridge_lambda: float, +) -> np.ndarray: + """Fast correction method using precomputed factors.""" + _, d = x.shape + + z = x.copy() + p = np.eye(n_batches + 1) + + for o_k, r_k in zip(o.T, r.T, strict=True): + n_k = np.sum(o_k) + + factor = 1.0 / (o_k + ridge_lambda) + c = n_k + np.sum(-factor * o_k**2) + c_inv = 1.0 / c + + p[0, 1:] = -factor * o_k + + p_t_b_inv = np.zeros((n_batches + 1, n_batches + 1)) + p_t_b_inv[0, 0] = c_inv + p_t_b_inv[1:, 1:] = np.diag(factor) + p_t_b_inv[1:, 0] = p[0, 1:] * c_inv + + inv_mat = 
p_t_b_inv @ p + + phi_t_x = np.zeros((n_batches + 1, d), dtype=x.dtype) + phi_t_x[0, :] = r_k @ x + for b in range(n_batches): + mask = batch_codes == b + phi_t_x[b + 1, :] = r_k[mask] @ x[mask] + + w = inv_mat @ phi_t_x + w[0, :] = 0 + + w_batch = w[batch_codes + 1, :] + z -= r_k[:, np.newaxis] * w_batch + + return z + + +def _compute_objective( + y_norm: np.ndarray, + z_norm: np.ndarray, + r: np.ndarray, + *, + theta: np.ndarray, + sigma: float, + o: np.ndarray, + e: np.ndarray, +) -> float: + """Compute Harmony objective function.""" + zy = z_norm @ y_norm.T + kmeans_error = np.sum(r * 2.0 * (1.0 - zy)) + + r_row_sums = r.sum(axis=1, keepdims=True) + r_normalized = r / np.clip(r_row_sums, 1e-12, None) + entropy = sigma * np.sum(r_normalized * np.log(r_normalized + 1e-12)) + + log_ratio = np.log((o + 1) / (e + 1)) + diversity_penalty = sigma * np.sum(theta @ (o * log_ratio)) + + return kmeans_error + entropy + diversity_penalty + + +def _is_convergent_clustering( + objectives: list, + tol: float, + window_size: int = 3, +) -> bool: + """Check clustering convergence using window.""" + if len(objectives) < window_size + 1: + return False + + obj_old = sum(objectives[-window_size - 1 : -1]) + obj_new = sum(objectives[-window_size:]) + + return (obj_old - obj_new) < tol * abs(obj_old) diff --git a/src/scanpy/preprocessing/_harmony_integrate.py b/src/scanpy/preprocessing/_harmony_integrate.py new file mode 100644 index 0000000000..1d29c7b272 --- /dev/null +++ b/src/scanpy/preprocessing/_harmony_integrate.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +if TYPE_CHECKING: + from collections.abc import Sequence + from typing import Literal + + from anndata import AnnData + from numpy.typing import DTypeLike + + +def harmony_integrate( # noqa: PLR0913 + adata: AnnData, + key: str | Sequence[str], + *, + basis: str = "X_pca", + adjusted_basis: str = "X_pca_harmony", + dtype: DTypeLike = np.float64, + 
theta: float | Sequence[float] | None = None, + sigma: float = 0.1, + n_clusters: int | None = None, + max_iter_harmony: int = 10, + max_iter_clustering: int = 200, + tol_harmony: float = 1e-4, + tol_clustering: float = 1e-5, + ridge_lambda: float = 1.0, + correction_method: Literal["fast", "original"] = "original", + block_proportion: float = 0.05, + random_state: int | None = 0, + sparse: bool = False, +) -> None: + """ + Integrate different experiments using the Harmony algorithm. + + This CPU implementation is based on the harmony-pytorch & rapids_singlecell version, + using NumPy for efficient computation. + + Parameters + ---------- + adata + The annotated data matrix. + key + The key(s) of the column(s) in ``adata.obs`` that differentiates + among experiments/batches. + basis + The name of the field in ``adata.obsm`` where the PCA table is + stored. Defaults to ``'X_pca'``. + adjusted_basis + The name of the field in ``adata.obsm`` where the adjusted PCA + table will be stored. Defaults to ``X_pca_harmony``. + dtype + The data type to use for Harmony computation. + theta + Diversity penalty weight(s). Default is 2 for each batch variable. + sigma + Width of soft clustering kernel. Default 0.1. + n_clusters + Number of clusters. Default is min(100, n_cells/30). + max_iter_harmony + Maximum Harmony iterations. Default 10. + max_iter_clustering + Maximum clustering iterations per Harmony round. Default 200. + tol_harmony + Convergence tolerance for Harmony. Default 1e-4. + tol_clustering + Convergence tolerance for clustering. Default 1e-5. + ridge_lambda + Ridge regression regularization. Default 1.0. + correction_method + Choose which method for the correction step: ``original`` for + original method, ``fast`` for improved method. + block_proportion + Fraction of cells processed per clustering iteration. Default 0.05. + random_state + Random seed for reproducibility. + sparse + Use sparse matrices for batch encoding. Reduces memory for large datasets. 
+ + Returns + ------- + Updates adata with the field ``adata.obsm[adjusted_basis]``, + containing principal components adjusted by Harmony. + """ + from ._harmony import Harmony + + # Ensure the basis exists in adata.obsm + if basis not in adata.obsm: + msg = ( + f"The specified basis {basis!r} is not available in `adata.obsm`. " + f"Available bases: {list(adata.obsm.keys())}" + ) + raise ValueError(msg) + + # Get the input data + input_data = adata.obsm[basis] + + # Convert to numpy array with specified dtype + try: + x = np.ascontiguousarray(input_data, dtype=dtype) + except Exception as e: + msg = ( + f"Could not convert input of type {type(input_data).__name__} " + "to NumPy array." + ) + raise TypeError(msg) from e + + # Check for NaN values + if np.isnan(x).any(): + msg = ( + "Input data contains NaN values. Please handle these before " + "running harmony_integrate." + ) + raise ValueError(msg) + + # Run Harmony + harmony = Harmony( + adata.obs, + key, + theta=theta, + sigma=sigma, + n_clusters=n_clusters, + max_iter_harmony=max_iter_harmony, + max_iter_clustering=max_iter_clustering, + tol_harmony=tol_harmony, + tol_clustering=tol_clustering, + ridge_lambda=ridge_lambda, + correction_method=correction_method, + block_proportion=block_proportion, + random_state=random_state, + sparse=sparse, + ) + harmony_out = harmony.fit(x) + + # Store result + adata.obsm[adjusted_basis] = harmony_out diff --git a/src/testing/scanpy/_pytest/__init__.py b/src/testing/scanpy/_pytest/__init__.py index 04777f6aef..569d47b3f9 100644 --- a/src/testing/scanpy/_pytest/__init__.py +++ b/src/testing/scanpy/_pytest/__init__.py @@ -7,6 +7,7 @@ from types import MappingProxyType from typing import TYPE_CHECKING +import pooch import pytest from packaging.version import Version @@ -26,6 +27,7 @@ def original_settings( request: pytest.FixtureRequest, cache: pytest.Cache, tmp_path_factory: pytest.TempPathFactory, + monkeypatch: pytest.MonkeyPatch, ) -> Generator[Mapping[str, object], None, 
None]: """Switch to agg backend, reset settings, and close all figures at teardown.""" # make sure seaborn is imported and did its thing @@ -51,6 +53,9 @@ def original_settings( cache.mkdir("debug") # reuse data files between test runs (unless overwritten in the test) sc.settings.datasetdir = cache.mkdir("scanpy-data") + pooch.os_cache = pooch.utils.os_cache = pooch.core.os_cache = lambda p: ( + sc.settings.datasetdir / p + ) # create new writedir for each test run sc.settings.writedir = tmp_path_factory.mktemp("scanpy_write") diff --git a/tests/external/test_harmony_integrate.py b/tests/external/test_harmony_integrate.py deleted file mode 100644 index 2844354a2f..0000000000 --- a/tests/external/test_harmony_integrate.py +++ /dev/null @@ -1,23 +0,0 @@ -from __future__ import annotations - -import scanpy as sc -import scanpy.external as sce -from testing.scanpy._helpers.data import pbmc3k -from testing.scanpy._pytest.marks import needs - -pytestmark = [needs.harmonypy] - - -def test_harmony_integrate(): - """Test that Harmony integrate works. - - This is a very simple test that just checks to see if the Harmony - integrate wrapper succesfully added a new field to ``adata.obsm`` - and makes sure it has the same dimensions as the original PCA table. 
- """ - adata = pbmc3k() - sc.pp.recipe_zheng17(adata) - sc.pp.pca(adata) - adata.obs["batch"] = 1350 * ["a"] + 1350 * ["b"] - sce.pp.harmony_integrate(adata, "batch") - assert adata.obsm["X_pca_harmony"].shape == adata.obsm["X_pca"].shape diff --git a/tests/test_harmony.py b/tests/test_harmony.py new file mode 100644 index 0000000000..5d90fa146e --- /dev/null +++ b/tests/test_harmony.py @@ -0,0 +1,166 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +import pandas as pd +import pooch +import pytest +from anndata import AnnData +from scipy.stats import pearsonr + +from scanpy.preprocessing import harmony_integrate +from testing.scanpy._helpers.data import pbmc68k_reduced + +if TYPE_CHECKING: + from typing import Literal + + from numpy.typing import DTypeLike + + +DATA = dict( + pca=("pbmc_3500_pcs.tsv.gz", "md5:27e319b3ddcc0c00d98e70aa8e677b10"), + pca_harmonized=( + "pbmc_3500_pcs_harmonized.tsv.gz", + "md5:a7c4ce4b98c390997c66d63d48e09221", + ), + meta=("pbmc_3500_meta.tsv.gz", "md5:8c7ca20e926513da7cf0def1211baecb"), +) + + +def _get_measure( + x: np.ndarray, base: np.ndarray, norm: Literal["r", "L2"] +) -> np.ndarray: + """Compute correlation or L2 distance between arrays.""" + if norm == "r": # Compute per-column correlation + if x.ndim == 1: + corr, _ = pearsonr(x, base) + return corr + return np.array([pearsonr(x[:, i], base[:, i])[0] for i in range(x.shape[1])]) + if norm == "L2": + # L2 distance normalized by base norm + if x.ndim == 1: + return np.linalg.norm(x - base) / np.linalg.norm(base) + return np.array([ + np.linalg.norm(x[:, i] - base[:, i]) / np.linalg.norm(base[:, i]) + for i in range(x.shape[1]) + ]) + pytest.fail(f"Unknown {norm=!r}") + + +@pytest.fixture +def adata_reference() -> AnnData: + """Load reference data from harmonypy repository.""" + paths = { + f: pooch.retrieve( + f"https://github.com/slowkow/harmonypy/raw/refs/heads/master/data/{name}", + known_hash=hash_, + ) + for f, (name, 
hash_) in DATA.items() + } + dfs = {f: pd.read_csv(path, delimiter="\t") for f, path in paths.items()} + # Create unique index using row number + cell name + dfs["meta"].index = [f"{i}_{cell}" for i, cell in enumerate(dfs["meta"]["cell"])] + adata = AnnData( + X=None, + obs=dfs["meta"], + obsm={"X_pca": dfs["pca"].values, "harmony_org": dfs["pca_harmonized"].values}, + ) + return adata + + +@pytest.mark.parametrize("correction_method", ["fast", "original"]) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_harmony_integrate( + correction_method: Literal["fast", "original"], dtype: DTypeLike +) -> None: + """Test that Harmony integrate works.""" + adata = pbmc68k_reduced() + harmony_integrate( + adata, "bulk_labels", correction_method=correction_method, dtype=dtype + ) + assert adata.obsm["X_pca_harmony"].shape == adata.obsm["X_pca"].shape + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_harmony_integrate_algos(subtests: pytest.Subtests, dtype: DTypeLike) -> None: + """Test that both correction methods produce similar results.""" + adata = pbmc68k_reduced() + + harmony_integrate(adata, "bulk_labels", correction_method="fast", dtype=dtype) + fast = adata.obsm["X_pca_harmony"].copy() + harmony_integrate(adata, "bulk_labels", correction_method="original", dtype=dtype) + slow = adata.obsm["X_pca_harmony"].copy() + + with subtests.test("r"): + assert _get_measure(fast, slow, "r").min() > 0.99 + with subtests.test("L2"): + assert _get_measure(fast, slow, "L2").max() < 0.1 + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +@pytest.mark.parametrize("correction_method", ["fast", "original"]) +def test_harmony_integrate_reference( + *, + subtests: pytest.Subtests, + adata_reference: AnnData, + dtype: DTypeLike, + correction_method: Literal["fast", "original"], +) -> None: + """Test that Harmony produces results similar to the reference implementation.""" + harmony_integrate( + adata_reference, + "donor", + 
correction_method=correction_method, + dtype=dtype, + max_iter_harmony=20, + ) + x, base = adata_reference.obsm["harmony_org"], adata_reference.obsm["X_pca_harmony"] + + with subtests.test("r"): + assert _get_measure(x, base, "r").min() > 0.95 + with subtests.test("L2"): + assert _get_measure(x, base, "L2").max() < 0.05 + + +def test_harmony_multiple_keys() -> None: + """Test Harmony with multiple batch keys.""" + adata = pbmc68k_reduced() + # Create a second batch key + adata.obs["batch2"] = np.random.choice(["A", "B", "C"], size=adata.n_obs) + + harmony_integrate(adata, ["bulk_labels", "batch2"], correction_method="original") + + assert adata.obsm["X_pca_harmony"].shape == adata.obsm["X_pca"].shape + + +def test_harmony_custom_parameters() -> None: + """Test Harmony with custom parameters.""" + adata = pbmc68k_reduced() + harmony_integrate( + adata, + "bulk_labels", + theta=1.5, + sigma=0.15, + n_clusters=50, + max_iter_harmony=5, + ridge_lambda=0.5, + ) + assert adata.obsm["X_pca_harmony"].shape == adata.obsm["X_pca"].shape + + +def test_harmony_no_nan_output() -> None: + """Test that Harmony output contains no NaN values.""" + adata = pbmc68k_reduced() + harmony_integrate(adata, "bulk_labels") + assert not np.isnan(adata.obsm["X_pca_harmony"]).any() + + +def test_harmony_input_validation(subtests) -> None: + """Test that Harmony raises errors for invalid inputs.""" + adata = pbmc68k_reduced() + + with subtests.test("no basis"), pytest.raises(ValueError, match="not available"): + harmony_integrate(adata, "bulk_labels", basis="nonexistent") + with subtests.test("no key"), pytest.raises(KeyError): + harmony_integrate(adata, "nonexistent_key")