From eecce50ba68e6274f90b7ad07e586fd125bf5ea7 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Mon, 26 Jan 2026 10:48:16 +0100 Subject: [PATCH 1/9] first draft --- src/scanpy/external/pp/__init__.py | 2 +- src/scanpy/external/pp/_harmony_integrate.py | 100 ---- src/scanpy/preprocessing/__init__.py | 2 + src/scanpy/preprocessing/_harmony/__init__.py | 520 ++++++++++++++++++ .../preprocessing/_harmony_integrate.py | 100 ++++ tests/external/test_harmony_integrate.py | 23 - tests/test_harmony.py | 162 ++++++ 7 files changed, 785 insertions(+), 124 deletions(-) delete mode 100644 src/scanpy/external/pp/_harmony_integrate.py create mode 100644 src/scanpy/preprocessing/_harmony/__init__.py create mode 100644 src/scanpy/preprocessing/_harmony_integrate.py delete mode 100644 tests/external/test_harmony_integrate.py create mode 100644 tests/test_harmony.py diff --git a/src/scanpy/external/pp/__init__.py b/src/scanpy/external/pp/__init__.py index a8b09f725d..4367a6fdfb 100644 --- a/src/scanpy/external/pp/__init__.py +++ b/src/scanpy/external/pp/__init__.py @@ -4,9 +4,9 @@ from ..._compat import deprecated from ...preprocessing import _scrublet +from ...preprocessing._harmony_integrate import harmony_integrate from ._bbknn import bbknn from ._dca import dca -from ._harmony_integrate import harmony_integrate from ._hashsolo import hashsolo from ._magic import magic from ._mnn_correct import mnn_correct diff --git a/src/scanpy/external/pp/_harmony_integrate.py b/src/scanpy/external/pp/_harmony_integrate.py deleted file mode 100644 index 4fd908d955..0000000000 --- a/src/scanpy/external/pp/_harmony_integrate.py +++ /dev/null @@ -1,100 +0,0 @@ -"""Use harmony to integrate cells from different experiments.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -import numpy as np - -from ..._compat import old_positionals -from ..._utils._doctests import doctest_needs - -if TYPE_CHECKING: - from collections.abc import Sequence - - from anndata import AnnData - - 
-@old_positionals("basis", "adjusted_basis") -@doctest_needs("harmonypy") -def harmony_integrate( - adata: AnnData, - key: str | Sequence[str], - *, - basis: str = "X_pca", - adjusted_basis: str = "X_pca_harmony", - **kwargs, -): - """Use harmonypy :cite:p:`Korsunsky2019` to integrate different experiments. - - Harmony :cite:p:`Korsunsky2019` is an algorithm for integrating single-cell - data from multiple experiments. This function uses the python - port of Harmony, ``harmonypy``, to integrate single-cell data - stored in an AnnData object. As Harmony works by adjusting the - principal components, this function should be run after performing - PCA but before computing the neighbor graph, as illustrated in the - example below. - - Parameters - ---------- - adata - The annotated data matrix. - key - The name of the column in ``adata.obs`` that differentiates - among experiments/batches. To integrate over two or more covariates, - you can pass multiple column names as a list. See ``vars_use`` - parameter of the ``harmonypy`` package for more details. - basis - The name of the field in ``adata.obsm`` where the PCA table is - stored. Defaults to ``'X_pca'``, which is the default for - ``sc.pp.pca()``. - adjusted_basis - The name of the field in ``adata.obsm`` where the adjusted PCA - table will be stored after running this function. Defaults to - ``X_pca_harmony``. - kwargs - Any additional arguments will be passed to - ``harmonypy.run_harmony()``. - - Returns - ------- - Updates adata with the field ``adata.obsm[obsm_out_field]``, - containing principal components adjusted by Harmony such that - different experiments are integrated. - - Example - ------- - First, load libraries and example dataset, and preprocess. 
- - >>> import scanpy as sc - >>> import scanpy.external as sce - >>> adata = sc.datasets.pbmc3k() - >>> sc.pp.recipe_zheng17(adata) - >>> sc.pp.pca(adata) - - We now arbitrarily assign a batch metadata variable to each cell - for the sake of example, but during real usage there would already - be a column in ``adata.obs`` giving the experiment each cell came - from. - - >>> adata.obs["batch"] = 1350 * ["a"] + 1350 * ["b"] - - Finally, run harmony. Afterwards, there will be a new table in - ``adata.obsm`` containing the adjusted PC's. - - >>> sce.pp.harmony_integrate(adata, "batch") - >>> "X_pca_harmony" in adata.obsm - True - - """ - try: - import harmonypy - except ImportError as e: - msg = "\nplease install harmonypy:\n\n\tpip install harmonypy" - raise ImportError(msg) from e - - x = adata.obsm[basis].astype(np.float64) - - harmony_out = harmonypy.run_harmony(x, adata.obs, key, **kwargs) - - adata.obsm[adjusted_basis] = harmony_out.Z_corr.T diff --git a/src/scanpy/preprocessing/__init__.py b/src/scanpy/preprocessing/__init__.py index d80412860d..95cc13a57b 100644 --- a/src/scanpy/preprocessing/__init__.py +++ b/src/scanpy/preprocessing/__init__.py @@ -6,6 +6,7 @@ from ._combat import combat from ._deprecated.highly_variable_genes import filter_genes_dispersion from ._deprecated.sampling import subsample +from ._harmony_integrate import harmony_integrate from ._highly_variable_genes import highly_variable_genes from ._normalization import normalize_total from ._pca import pca @@ -31,6 +32,7 @@ "filter_cells", "filter_genes", "filter_genes_dispersion", + "harmony_integrate", "highly_variable_genes", "log1p", "neighbors", diff --git a/src/scanpy/preprocessing/_harmony/__init__.py b/src/scanpy/preprocessing/_harmony/__init__.py new file mode 100644 index 0000000000..282f6589a2 --- /dev/null +++ b/src/scanpy/preprocessing/_harmony/__init__.py @@ -0,0 +1,520 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +from 
scipy.sparse import csr_matrix # noqa: TID251 +from sklearn.cluster import KMeans +from tqdm.auto import tqdm + +if TYPE_CHECKING: + import pandas as pd + + +def harmonize( # noqa: PLR0913, PLR0912 + x: np.ndarray, + batch_df: pd.DataFrame, + batch_key: str | list[str], + *, + theta: float | list[float] | None = None, + sigma: float = 0.1, + n_clusters: int | None = None, + max_iter_harmony: int = 10, + max_iter_clustering: int = 200, + tol_harmony: float = 1e-4, + tol_clustering: float = 1e-5, + ridge_lambda: float = 1.0, + correction_method: str = "original", + block_proportion: float = 0.05, + random_state: int | None = 0, + verbose: bool = False, + sparse: bool = False, +) -> np.ndarray: + """ + Run Harmony batch correction algorithm. + + Parameters + ---------- + x + Data matrix (n_cells x d) - typically PCA embeddings. + batch_df + DataFrame containing batch information. + batch_key + Column name(s) in batch_df containing batch labels. + theta + Diversity penalty weight(s). Default is 2 for each batch variable. + sigma + Width of soft clustering kernel. Default 0.1. + n_clusters + Number of clusters. Default is min(100, n_cells/30). + max_iter_harmony + Maximum Harmony iterations. Default 10. + max_iter_clustering + Maximum clustering iterations per Harmony round. Default 200. + tol_harmony + Convergence tolerance for Harmony. Default 1e-4. + tol_clustering + Convergence tolerance for clustering. Default 1e-5. + ridge_lambda + Ridge regression regularization. Default 1.0. + correction_method + 'original' or 'fast'. Default 'original'. + block_proportion + Fraction of cells processed per clustering iteration. Default 0.05. + random_state + Random seed for reproducibility. + verbose + Print progress information. + sparse + Use sparse matrices for phi. Reduces memory for large datasets. + + Returns + ------- + z_corr + Batch-corrected embedding matrix (n_cells x d). 
+ """ + if random_state is not None: + np.random.seed(random_state) + + # Ensure input is C-contiguous float array (infer dtype from x) + x = np.ascontiguousarray(x) + dtype = x.dtype + n_cells = x.shape[0] + + # Normalize input for clustering + z_norm = _normalize_rows_l2(x) + + # Process batch keys + batch_codes, n_batches = _get_batch_codes(batch_df, batch_key) + + # Build phi matrix (one-hot encoding of batches) + if sparse: + phi = _one_hot_encode_sparse(batch_codes, n_batches, dtype) + n_b = np.asarray(phi.sum(axis=0)).ravel() + else: + phi = _one_hot_encode(batch_codes, n_batches, dtype) + n_b = phi.sum(axis=0) + pr_b = (n_b / n_cells).reshape(-1, 1) + + # Set default theta + if theta is None: + theta_arr = np.ones(n_batches, dtype=dtype) * 2.0 + elif isinstance(theta, (int, float)): + theta_arr = np.ones(n_batches, dtype=dtype) * float(theta) + else: + theta_arr = np.array(theta, dtype=dtype) + theta_arr = theta_arr.reshape(1, -1) + + # Set default n_clusters + if n_clusters is None: + n_clusters = int(min(100, n_cells / 30)) + n_clusters = max(n_clusters, 2) + + # Initialize centroids and state arrays + r, e, o, objectives_harmony = _initialize_centroids( + z_norm, + phi, + pr_b, + n_clusters=n_clusters, + sigma=sigma, + theta=theta_arr, + random_state=random_state, + ) + + # Main Harmony loop + converged = False + z_hat = x.copy() + + for i in tqdm(range(max_iter_harmony), disable=not verbose): + # Clustering step + _clustering( + z_norm, + batch_codes, + n_batches, + pr_b, + r=r, + e=e, + o=o, + theta=theta_arr, + sigma=sigma, + max_iter=max_iter_clustering, + tol=tol_clustering, + block_proportion=block_proportion, + objectives_harmony=objectives_harmony, + ) + + # Correction step + if correction_method == "fast": + z_hat = _correction_fast( + x, batch_codes, n_batches, r, o, ridge_lambda=ridge_lambda + ) + else: + z_hat = _correction_original( + x, batch_codes, n_batches, r, ridge_lambda=ridge_lambda + ) + + # Normalize corrected data for next 
iteration + z_norm = _normalize_rows_l2(z_hat) + + # Check convergence + if _is_convergent_harmony(objectives_harmony, tol_harmony): + converged = True + if verbose: + print(f"Harmony converged in {i + 1} iterations") + break + + if not converged and verbose: + print(f"Harmony did not converge after {max_iter_harmony} iterations.") + + return z_hat + + +def _get_batch_codes( + batch_df: pd.DataFrame, + batch_key: str | list[str], +) -> tuple[np.ndarray, int]: + """Get batch codes from DataFrame.""" + if isinstance(batch_key, str): + batch_vec = batch_df[batch_key] + elif len(batch_key) == 1: + batch_vec = batch_df[batch_key[0]] + else: + df = batch_df[batch_key].astype("str") + batch_vec = df.apply(lambda row: ",".join(row), axis=1) + + batch_cat = batch_vec.astype("category") + codes = batch_cat.cat.codes.values.copy() + n_batches = len(batch_cat.cat.categories) + + return codes.astype(np.int32), n_batches + + +def _one_hot_encode( + codes: np.ndarray, + n_categories: int, + dtype: np.dtype, +) -> np.ndarray: + """One-hot encode category codes.""" + n = len(codes) + phi = np.zeros((n, n_categories), dtype=dtype) + phi[np.arange(n), codes] = 1.0 + return phi + + +def _one_hot_encode_sparse( + codes: np.ndarray, + n_categories: int, + dtype: np.dtype, +): + """One-hot encode category codes as sparse CSR matrix.""" + n = len(codes) + data = np.ones(n, dtype=dtype) + indices = codes.astype(np.int32) + indptr = np.arange(n + 1, dtype=np.int32) + return csr_matrix((data, indices, indptr), shape=(n, n_categories)) + + +def _normalize_rows_l2(x: np.ndarray) -> np.ndarray: + """L2 normalize each row of x.""" + norms = np.linalg.norm(x, axis=1, keepdims=True) + norms = np.maximum(norms, 1e-12) + return x / norms + + +def _normalize_rows_l1(r: np.ndarray) -> None: + """L1 normalize each row of r in-place (rows sum to 1).""" + row_sums = r.sum(axis=1, keepdims=True) + row_sums = np.maximum(row_sums, 1e-12) + r /= row_sums + + +def _initialize_centroids( + z_norm: np.ndarray, 
+ phi: np.ndarray, + pr_b: np.ndarray, + *, + n_clusters: int, + sigma: float, + theta: np.ndarray, + random_state: int | None, +) -> tuple[np.ndarray, np.ndarray, np.ndarray, list]: + """Initialize cluster centroids using K-means.""" + kmeans = KMeans( + n_clusters=n_clusters, random_state=random_state, n_init=10, max_iter=25 + ) + kmeans.fit(z_norm) + + # Centroids + y = kmeans.cluster_centers_.copy() + y_norm = _normalize_rows_l2(y) + + # Compute soft cluster assignments r + term = -2.0 / sigma + r = _compute_r(z_norm, y_norm, term) + _normalize_rows_l1(r) + + # Initialize e (expected) and o (observed) + r_sum = r.sum(axis=0) + e = pr_b @ r_sum.reshape(1, -1) + o = phi.T @ r + + # Compute initial objective + objectives_harmony: list = [] + obj = _compute_objective(y_norm, z_norm, r, theta=theta, sigma=sigma, o=o, e=e) + objectives_harmony.append(obj) + + return r, e, o, objectives_harmony + + +def _compute_r( + z: np.ndarray, + y: np.ndarray, + term: float, +) -> np.ndarray: + """Compute soft cluster assignments using NumPy dot.""" + dots = z @ y.T + return np.exp(term * (1.0 - dots)) + + +def _clustering( # noqa: PLR0913 + z_norm: np.ndarray, + batch_codes: np.ndarray, + n_batches: int, + pr_b: np.ndarray, + *, + r: np.ndarray, + e: np.ndarray, + o: np.ndarray, + theta: np.ndarray, + sigma: float, + max_iter: int, + tol: float, + block_proportion: float, + objectives_harmony: list, +) -> None: + """Run clustering iterations (modifies r, e, o in-place).""" + n_cells = z_norm.shape[0] + k = r.shape[1] + block_size = max(1, int(n_cells * block_proportion)) + term = -2.0 / sigma + + objectives_clustering = [] + + # Pre-allocate work arrays + y = np.empty((k, z_norm.shape[1]), dtype=z_norm.dtype) + y_norm = np.empty_like(y) + + for _ in range(max_iter): + # Compute cluster centroids: y = r.T @ z_norm, then normalize + np.dot(r.T, z_norm, out=y) + norms = np.linalg.norm(y, axis=1, keepdims=True) + norms = np.maximum(norms, 1e-12) + np.divide(y, norms, out=y_norm) + + 
# Randomly shuffle cell indices + idx_list = np.random.permutation(n_cells) + + # Process blocks + pos = 0 + while pos < n_cells: + end_pos = min(pos + block_size, n_cells) + block_idx = idx_list[pos:end_pos] + + for b in range(n_batches): + mask = batch_codes[block_idx] == b + if not np.any(mask): + continue + + cell_idx = block_idx[mask] + + # Remove old r contribution from o and e + r_old = r[cell_idx, :] + r_old_sum = r_old.sum(axis=0) + o[b, :] -= r_old_sum + e -= pr_b * r_old_sum + + # Compute new r values + dots = z_norm[cell_idx, :] @ y_norm.T + r_new = np.exp(term * (1.0 - dots)) + + # Apply penalty + penalty = ((e[b, :] + 1.0) / (o[b, :] + 1.0)) ** theta[0, b] + r_new *= penalty + + # Normalize rows to sum to 1 + row_sums = r_new.sum(axis=1, keepdims=True) + row_sums = np.maximum(row_sums, 1e-12) + r_new /= row_sums + + # Store back + r[cell_idx, :] = r_new + + # Add new r contribution to o and e + r_new_sum = r_new.sum(axis=0) + o[b, :] += r_new_sum + e += pr_b * r_new_sum + + pos = end_pos + + # Compute objective + obj = _compute_objective(y_norm, z_norm, r, theta=theta, sigma=sigma, o=o, e=e) + objectives_clustering.append(obj) + + # Check convergence + if _is_convergent_clustering(objectives_clustering, tol): + objectives_harmony.append(objectives_clustering[-1]) + break + + +def _correction_original( + x: np.ndarray, + batch_codes: np.ndarray, + n_batches: int, + r: np.ndarray, + *, + ridge_lambda: float, +) -> np.ndarray: + """Original correction method - per-cluster ridge regression.""" + _, d = x.shape + k = r.shape[1] + + # Ridge regularization matrix (don't penalize intercept) + id_mat = np.eye(n_batches + 1) + id_mat[0, 0] = 0 + lambda_mat = ridge_lambda * id_mat + + z = x.copy() + + for k_idx in range(k): + r_k = r[:, k_idx] + + r_sum_total = r_k.sum() + r_sum_per_batch = np.zeros(n_batches, dtype=x.dtype) + for b in range(n_batches): + r_sum_per_batch[b] = r_k[batch_codes == b].sum() + + phi_t_phi = np.zeros((n_batches + 1, n_batches + 1), 
dtype=x.dtype) + phi_t_phi[0, 0] = r_sum_total + phi_t_phi[0, 1:] = r_sum_per_batch + phi_t_phi[1:, 0] = r_sum_per_batch + phi_t_phi[1:, 1:] = np.diag(r_sum_per_batch) + phi_t_phi += lambda_mat + + phi_t_x = np.zeros((n_batches + 1, d), dtype=x.dtype) + phi_t_x[0, :] = r_k @ x + for b in range(n_batches): + mask = batch_codes == b + phi_t_x[b + 1, :] = r_k[mask] @ x[mask] + + try: + w = np.linalg.solve(phi_t_phi, phi_t_x) + except np.linalg.LinAlgError: + w = np.linalg.lstsq(phi_t_phi, phi_t_x, rcond=None)[0] + + w[0, :] = 0 + w_batch = w[batch_codes + 1, :] + z -= r_k[:, np.newaxis] * w_batch + + return z + + +def _correction_fast( + x: np.ndarray, + batch_codes: np.ndarray, + n_batches: int, + r: np.ndarray, + o: np.ndarray, + *, + ridge_lambda: float, +) -> np.ndarray: + """Fast correction method using precomputed factors.""" + _, d = x.shape + k = r.shape[1] + + z = x.copy() + p = np.eye(n_batches + 1) + + for k_idx in range(k): + o_k = o[:, k_idx] + n_k = np.sum(o_k) + + factor = 1.0 / (o_k + ridge_lambda) + c = n_k + np.sum(-factor * o_k**2) + c_inv = 1.0 / c + + p[0, 1:] = -factor * o_k + + p_t_b_inv = np.zeros((n_batches + 1, n_batches + 1)) + p_t_b_inv[0, 0] = c_inv + p_t_b_inv[1:, 1:] = np.diag(factor) + p_t_b_inv[1:, 0] = p[0, 1:] * c_inv + + inv_mat = p_t_b_inv @ p + + r_k = r[:, k_idx] + phi_t_x = np.zeros((n_batches + 1, d), dtype=x.dtype) + phi_t_x[0, :] = r_k @ x + for b in range(n_batches): + mask = batch_codes == b + phi_t_x[b + 1, :] = r_k[mask] @ x[mask] + + w = inv_mat @ phi_t_x + w[0, :] = 0 + + w_batch = w[batch_codes + 1, :] + z -= r_k[:, np.newaxis] * w_batch + + return z + + +def _compute_objective( + y_norm: np.ndarray, + z_norm: np.ndarray, + r: np.ndarray, + *, + theta: np.ndarray, + sigma: float, + o: np.ndarray, + e: np.ndarray, +) -> float: + """Compute Harmony objective function.""" + zy = z_norm @ y_norm.T + kmeans_error = np.sum(r * 2.0 * (1.0 - zy)) + + r_row_sums = r.sum(axis=1, keepdims=True) + r_normalized = r / 
np.clip(r_row_sums, 1e-12, None) + entropy = sigma * np.sum(r_normalized * np.log(r_normalized + 1e-12)) + + log_ratio = np.log((o + 1) / (e + 1)) + diversity_penalty = sigma * np.sum(theta @ (o * log_ratio)) + + return kmeans_error + entropy + diversity_penalty + + +def _is_convergent_harmony( + objectives: list, + tol: float, +) -> bool: + """Check Harmony convergence.""" + if len(objectives) < 2: + return False + + obj_old = objectives[-2] + obj_new = objectives[-1] + + return (obj_old - obj_new) < tol * abs(obj_old) + + +def _is_convergent_clustering( + objectives: list, + tol: float, + window_size: int = 3, +) -> bool: + """Check clustering convergence using window.""" + if len(objectives) < window_size + 1: + return False + + obj_old = sum(objectives[-window_size - 1 : -1]) + obj_new = sum(objectives[-window_size:]) + + return (obj_old - obj_new) < tol * abs(obj_old) diff --git a/src/scanpy/preprocessing/_harmony_integrate.py b/src/scanpy/preprocessing/_harmony_integrate.py new file mode 100644 index 0000000000..78249e645e --- /dev/null +++ b/src/scanpy/preprocessing/_harmony_integrate.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +if TYPE_CHECKING: + from typing import Literal + + from anndata import AnnData + + +def harmony_integrate( + adata: AnnData, + key: str | list[str], + *, + basis: str = "X_pca", + adjusted_basis: str = "X_pca_harmony", + dtype: type = np.float64, + correction_method: Literal["fast", "original"] = "original", + sparse: bool = False, + **kwargs, +) -> None: + """ + Integrate different experiments using the Harmony algorithm. + + This CPU implementation is based on the harmony-pytorch & rapids_singlecell version, + using NumPy for efficient computation. + + Parameters + ---------- + adata + The annotated data matrix. + key + The key(s) of the column(s) in ``adata.obs`` that differentiates + among experiments/batches. 
+ basis + The name of the field in ``adata.obsm`` where the PCA table is + stored. Defaults to ``'X_pca'``. + adjusted_basis + The name of the field in ``adata.obsm`` where the adjusted PCA + table will be stored. Defaults to ``X_pca_harmony``. + dtype + The data type to use for Harmony computation. + correction_method + Choose which method for the correction step: ``original`` for + original method, ``fast`` for improved method. + sparse + Use sparse matrices for batch encoding. Reduces memory for large datasets. + **kwargs + Additional arguments passed to ``harmonize()``. + + Returns + ------- + Updates adata with the field ``adata.obsm[adjusted_basis]``, + containing principal components adjusted by Harmony. + """ + from ._harmony import harmonize + + # Ensure the basis exists in adata.obsm + if basis not in adata.obsm: + msg = ( + f"The specified basis '{basis}' is not available in adata.obsm. " + f"Available bases: {list(adata.obsm.keys())}" + ) + raise ValueError(msg) + + # Get the input data + input_data = adata.obsm[basis] + + # Convert to numpy array with specified dtype + try: + x = np.ascontiguousarray(input_data, dtype=dtype) + except Exception as e: + msg = ( + f"Could not convert input of type {type(input_data).__name__} " + "to NumPy array." + ) + raise TypeError(msg) from e + + # Check for NaN values + if np.isnan(x).any(): + msg = ( + "Input data contains NaN values. Please handle these before " + "running harmony_integrate." 
+ ) + raise ValueError(msg) + + # Run Harmony + harmony_out = harmonize( + x, + adata.obs, + key, + correction_method=correction_method, + sparse=sparse, + **kwargs, + ) + + # Store result + adata.obsm[adjusted_basis] = harmony_out diff --git a/tests/external/test_harmony_integrate.py b/tests/external/test_harmony_integrate.py deleted file mode 100644 index 2844354a2f..0000000000 --- a/tests/external/test_harmony_integrate.py +++ /dev/null @@ -1,23 +0,0 @@ -from __future__ import annotations - -import scanpy as sc -import scanpy.external as sce -from testing.scanpy._helpers.data import pbmc3k -from testing.scanpy._pytest.marks import needs - -pytestmark = [needs.harmonypy] - - -def test_harmony_integrate(): - """Test that Harmony integrate works. - - This is a very simple test that just checks to see if the Harmony - integrate wrapper succesfully added a new field to ``adata.obsm`` - and makes sure it has the same dimensions as the original PCA table. - """ - adata = pbmc3k() - sc.pp.recipe_zheng17(adata) - sc.pp.pca(adata) - adata.obs["batch"] = 1350 * ["a"] + 1350 * ["b"] - sce.pp.harmony_integrate(adata, "batch") - assert adata.obsm["X_pca_harmony"].shape == adata.obsm["X_pca"].shape diff --git a/tests/test_harmony.py b/tests/test_harmony.py new file mode 100644 index 0000000000..fcecacd0c5 --- /dev/null +++ b/tests/test_harmony.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +import anndata as ad +import numpy as np +import pandas as pd +import pooch +import pytest +from scipy.stats import pearsonr + +import scanpy as sc +from scanpy.preprocessing import harmony_integrate + + +def _get_measure(x, base, norm): + """Compute correlation or L2 distance between arrays.""" + assert norm in ["r", "L2"] + + if norm == "r": + # Compute per-column correlation + if x.ndim == 1: + corr, _ = pearsonr(x, base) + return corr + else: + corrs = [] + for i in range(x.shape[1]): + corr, _ = pearsonr(x[:, i], base[:, i]) + corrs.append(corr) + return np.array(corrs) + # 
L2 distance normalized by base norm + elif x.ndim == 1: + return np.linalg.norm(x - base) / np.linalg.norm(base) + else: + dists = [] + for i in range(x.shape[1]): + dist = np.linalg.norm(x[:, i] - base[:, i]) / np.linalg.norm(base[:, i]) + dists.append(dist) + return np.array(dists) + + +@pytest.fixture +def adata_reference(): + """Load reference data from harmonypy repository.""" + x_pca_file = pooch.retrieve( + "https://github.com/slowkow/harmonypy/raw/refs/heads/master/data/pbmc_3500_pcs.tsv.gz", + known_hash="md5:27e319b3ddcc0c00d98e70aa8e677b10", + ) + x_pca = pd.read_csv(x_pca_file, delimiter="\t") + x_pca_harmony_file = pooch.retrieve( + "https://github.com/slowkow/harmonypy/raw/refs/heads/master/data/pbmc_3500_pcs_harmonized.tsv.gz", + known_hash="md5:a7c4ce4b98c390997c66d63d48e09221", + ) + x_pca_harmony = pd.read_csv(x_pca_harmony_file, delimiter="\t") + meta_file = pooch.retrieve( + "https://github.com/slowkow/harmonypy/raw/refs/heads/master/data/pbmc_3500_meta.tsv.gz", + known_hash="md5:8c7ca20e926513da7cf0def1211baecb", + ) + meta = pd.read_csv(meta_file, delimiter="\t") + # Create unique index using row number + cell name + meta.index = [f"{i}_{cell}" for i, cell in enumerate(meta["cell"])] + adata = ad.AnnData( + X=None, + obs=meta, + obsm={"X_pca": x_pca.values, "harmony_org": x_pca_harmony.values}, + ) + return adata + + +@pytest.mark.parametrize("correction_method", ["fast", "original"]) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_harmony_integrate(correction_method, dtype): + """Test that Harmony integrate works.""" + adata = sc.datasets.pbmc68k_reduced() + harmony_integrate( + adata, "bulk_labels", correction_method=correction_method, dtype=dtype + ) + assert adata.obsm["X_pca_harmony"].shape == adata.obsm["X_pca"].shape + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_harmony_integrate_algos(dtype): + """Test that both correction methods produce similar results.""" + adata = 
sc.datasets.pbmc68k_reduced() + harmony_integrate(adata, "bulk_labels", correction_method="fast", dtype=dtype) + fast = adata.obsm["X_pca_harmony"].copy() + harmony_integrate(adata, "bulk_labels", correction_method="original", dtype=dtype) + slow = adata.obsm["X_pca_harmony"].copy() + assert _get_measure(fast, slow, "r").min() > 0.99 + assert _get_measure(fast, slow, "L2").max() < 0.1 + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +@pytest.mark.parametrize("correction_method", ["fast", "original"]) +def test_harmony_integrate_reference(adata_reference, *, dtype, correction_method): + """Test that Harmony produces results similar to the reference implementation.""" + harmony_integrate( + adata_reference, + "donor", + correction_method=correction_method, + dtype=dtype, + max_iter_harmony=20, + ) + + assert ( + _get_measure( + adata_reference.obsm["harmony_org"], + adata_reference.obsm["X_pca_harmony"], + "L2", + ).max() + < 0.05 + ) + assert ( + _get_measure( + adata_reference.obsm["harmony_org"], + adata_reference.obsm["X_pca_harmony"], + "r", + ).min() + > 0.95 + ) + + +def test_harmony_multiple_keys(): + """Test Harmony with multiple batch keys.""" + adata = sc.datasets.pbmc68k_reduced() + # Create a second batch key + adata.obs["batch2"] = np.random.choice(["A", "B", "C"], size=adata.n_obs) + harmony_integrate(adata, ["bulk_labels", "batch2"], correction_method="original") + assert adata.obsm["X_pca_harmony"].shape == adata.obsm["X_pca"].shape + + +def test_harmony_custom_parameters(): + """Test Harmony with custom parameters.""" + adata = sc.datasets.pbmc68k_reduced() + harmony_integrate( + adata, + "bulk_labels", + theta=1.5, + sigma=0.15, + n_clusters=50, + max_iter_harmony=5, + ridge_lambda=0.5, + ) + assert adata.obsm["X_pca_harmony"].shape == adata.obsm["X_pca"].shape + + +def test_harmony_no_nan_output(): + """Test that Harmony output contains no NaN values.""" + adata = sc.datasets.pbmc68k_reduced() + harmony_integrate(adata, 
"bulk_labels") + assert not np.isnan(adata.obsm["X_pca_harmony"]).any() + + +def test_harmony_input_validation(): + """Test that Harmony raises errors for invalid inputs.""" + adata = sc.datasets.pbmc68k_reduced() + + # Test missing basis + with pytest.raises(ValueError, match="not available"): + harmony_integrate(adata, "bulk_labels", basis="nonexistent") + + # Test missing key + with pytest.raises(KeyError): + harmony_integrate(adata, "nonexistent_key") From f27954354d4d0850d8fff4f83949c282cdf51b17 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Mon, 26 Jan 2026 11:00:26 +0100 Subject: [PATCH 2/9] add pooch --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c61961606c..8e52640212 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,6 +96,7 @@ test-min = [ "pytest-randomly", "pytest-rerunfailures", "tuna", + "pooch", "dependency-groups", # for CI scripts doctests ] test = [ @@ -142,7 +143,6 @@ leiden = [ "igraph>=0.10.8", "leidenalg>=0.10.1" ] # Leiden community detect bbknn = [ "bbknn" ] # Batch balanced KNN (batch correction) magic = [ "magic-impute>=2.0.4" ] # MAGIC imputation method skmisc = [ "scikit-misc>=0.5.1" ] # highly_variable_genes method 'seurat_v3' -harmony = [ "harmonypy" ] # Harmony dataset integration scanorama = [ "scanorama" ] # Scanorama dataset integration scrublet = [ "scikit-image>=0.23.1" ] # Doublet detection with automatic thresholds # Plotting From b5527796ffe054a6078f7bd8267b56b5dae2da64 Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Fri, 20 Feb 2026 10:38:24 +0100 Subject: [PATCH 3/9] improve deprecation --- src/scanpy/external/pp/__init__.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/src/scanpy/external/pp/__init__.py b/src/scanpy/external/pp/__init__.py index 4367a6fdfb..aa3135f95f 100644 --- a/src/scanpy/external/pp/__init__.py +++ b/src/scanpy/external/pp/__init__.py @@ -3,8 +3,6 @@ from __future__ import annotations from ..._compat import deprecated -from ...preprocessing import _scrublet -from ...preprocessing._harmony_integrate import harmony_integrate from ._bbknn import bbknn from ._dca import dca from ._hashsolo import hashsolo @@ -12,17 +10,32 @@ from ._mnn_correct import mnn_correct from ._scanorama_integrate import scanorama_integrate -scrublet = deprecated("Import from sc.pp instead")(_scrublet.scrublet) -scrublet_simulate_doublets = deprecated("Import from sc.pp instead")( - _scrublet.scrublet_simulate_doublets -) - __all__ = [ "bbknn", "dca", - "harmony_integrate", "hashsolo", "magic", "mnn_correct", "scanorama_integrate", ] + + +@deprecated("Import from sc.pp instead") +def harmony_integrate(*args, **kwargs): + from ...preprocessing import harmony_integrate + + return harmony_integrate(*args, **kwargs) + + +@deprecated("Import from sc.pp instead") +def scrublet(*args, **kwargs): + from ...preprocessing import scrublet + + return scrublet(*args, **kwargs) + + +@deprecated("Import from sc.pp instead") +def scrublet_simulate_doublets(*args, **kwargs): + from ...preprocessing import scrublet_simulate_doublets + + return scrublet_simulate_doublets(*args, **kwargs) From a59be404b04c3d33bf839e1a1b3ed914566da008 Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Fri, 20 Feb 2026 10:45:23 +0100 Subject: [PATCH 4/9] docs --- docs/external/preprocessing.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/external/preprocessing.md b/docs/external/preprocessing.md index 8cbf6449a6..6702e17b22 100644 --- a/docs/external/preprocessing.md +++ b/docs/external/preprocessing.md @@ -5,6 +5,11 @@ .. currentmodule:: scanpy.external ``` +Previously found here, but now part of scanpy’s main API: +- {func}`scanpy.pp.harmony_integrate` +- {func}`scanpy.pp.scrublet` +- {func}`scanpy.pp.scrublet_simulate_doublets` + (external-data-integration)= ## Data integration @@ -14,10 +19,8 @@ :toctree: ../generated/ pp.bbknn - pp.harmony_integrate pp.mnn_correct pp.scanorama_integrate - ``` ## Sample demultiplexing @@ -39,5 +42,4 @@ Note that the fundamental limitations of imputation are still under [debate](htt pp.dca pp.magic - ``` From 94710d56c18a9b9ee194a885d8fa888989b506ce Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 20 Feb 2026 12:16:30 +0100 Subject: [PATCH 5/9] docs --- docs/api/preprocessing.md | 10 ++++++++-- docs/release-notes/1.10.0.md | 2 +- docs/release-notes/1.11.0.md | 2 +- src/scanpy/preprocessing/_harmony/__init__.py | 20 +++++++++---------- .../preprocessing/_harmony_integrate.py | 2 +- 5 files changed, 21 insertions(+), 15 deletions(-) diff --git a/docs/api/preprocessing.md b/docs/api/preprocessing.md index 0950b9b296..5a3be50767 100644 --- a/docs/api/preprocessing.md +++ b/docs/api/preprocessing.md @@ -47,9 +47,12 @@ For visual quality control, see {func}`~scanpy.pl.highest_expr_genes` and pp.recipe_seurat ``` -## Batch effect correction +(pp-data-integration)= -Also see {ref}`data-integration`. Note that a simple batch correction method is available via {func}`pp.regress_out`. Checkout {mod}`scanpy.external` for more. +## Data integration + +Batch effect correction and other data integration. +Note that a simple batch correction method is available via {func}`pp.regress_out`. 
```{eval-rst} .. autosummary:: @@ -57,8 +60,11 @@ Also see {ref}`data-integration`. Note that a simple batch correction method is :toctree: generated/ pp.combat + pp.harmony_integrate ``` +Also see {ref}`data integration tools <data-integration>` and external {ref}`external data integration <external-data-integration>`. + ## Doublet detection ```{eval-rst} diff --git a/docs/release-notes/1.10.0.md b/docs/release-notes/1.10.0.md index 969633db0b..b12dd11e06 100644 --- a/docs/release-notes/1.10.0.md +++ b/docs/release-notes/1.10.0.md @@ -25,7 +25,7 @@ Some highlights: * {func}`scanpy.datasets.blobs` now accepts a `random_state` argument {pr}`2683` {smaller}`E Roellin` * {func}`scanpy.pp.pca` and {func}`scanpy.pp.regress_out` now accept a layer argument {pr}`2588` {smaller}`S Dicks` * {func}`scanpy.pp.subsample` with `copy=True` can now be called in backed mode {pr}`2624` {smaller}`E Roellin` -* {func}`scanpy.external.pp.harmony_integrate` now runs with 64 bit floats improving reproducibility {pr}`2655` {smaller}`S Dicks` +* {func}`scanpy.pp.harmony_integrate` now runs with 64 bit floats improving reproducibility {pr}`2655` {smaller}`S Dicks` * {func}`scanpy.tl.rank_genes_groups` no longer warns that it's default was changed from t-test_overestim_var to t-test {pr}`2798` {smaller}`L Heumos` * `scanpy.pp.calculate_qc_metrics` now allows `qc_vars` to be passed as a string {pr}`2859` {smaller}`N Teyssier` * {func}`scanpy.tl.leiden` and {func}`scanpy.tl.louvain` now store clustering parameters in the key provided by the `key_added` parameter instead of always writing to (or overwriting) a default key {pr}`2864` {smaller}`J Fan` diff --git a/docs/release-notes/1.11.0.md b/docs/release-notes/1.11.0.md index 875b32f364..70ceab0cdf 100644 --- a/docs/release-notes/1.11.0.md +++ b/docs/release-notes/1.11.0.md @@ -30,7 +30,7 @@ Release candidates: #### Documentation -- {guilabel}`rc1` Improve {func}`~scanpy.external.pp.harmony_integrate` docs {smaller}`D Kühl` ({pr}`3362`) +- {guilabel}`rc1` Improve 
{func}`~scanpy.pp.harmony_integrate` docs {smaller}`D Kühl` ({pr}`3362`) - {guilabel}`rc1` Raise {exc}`FutureWarning` when calling deprecated {mod}`scanpy.pp` functions {smaller}`P Angerer` ({pr}`3380`) - {guilabel}`rc1` {smaller}`P Angerer` ({pr}`3407`) diff --git a/src/scanpy/preprocessing/_harmony/__init__.py b/src/scanpy/preprocessing/_harmony/__init__.py index 282f6589a2..4a01df6b5a 100644 --- a/src/scanpy/preprocessing/_harmony/__init__.py +++ b/src/scanpy/preprocessing/_harmony/__init__.py @@ -7,11 +7,15 @@ from sklearn.cluster import KMeans from tqdm.auto import tqdm +from ... import logging as log +from ..._settings import settings +from ..._settings.verbosity import Verbosity + if TYPE_CHECKING: import pandas as pd -def harmonize( # noqa: PLR0913, PLR0912 +def harmonize( # noqa: PLR0913 x: np.ndarray, batch_df: pd.DataFrame, batch_key: str | list[str], @@ -27,7 +31,6 @@ def harmonize( # noqa: PLR0913, PLR0912 correction_method: str = "original", block_proportion: float = 0.05, random_state: int | None = 0, - verbose: bool = False, sparse: bool = False, ) -> np.ndarray: """ @@ -63,8 +66,6 @@ def harmonize( # noqa: PLR0913, PLR0912 Fraction of cells processed per clustering iteration. Default 0.05. random_state Random seed for reproducibility. - verbose - Print progress information. sparse Use sparse matrices for phi. Reduces memory for large datasets. 
@@ -125,7 +126,7 @@ def harmonize( # noqa: PLR0913, PLR0912 converged = False z_hat = x.copy() - for i in tqdm(range(max_iter_harmony), disable=not verbose): + for i in tqdm(range(max_iter_harmony), disable=settings.verbosity < Verbosity.info): # Clustering step _clustering( z_norm, @@ -159,12 +160,11 @@ def harmonize( # noqa: PLR0913, PLR0912 # Check convergence if _is_convergent_harmony(objectives_harmony, tol_harmony): converged = True - if verbose: - print(f"Harmony converged in {i + 1} iterations") + log.info(f"Harmony converged in {i + 1} iterations") break - if not converged and verbose: - print(f"Harmony did not converge after {max_iter_harmony} iterations.") + if not converged: + log.info(f"Harmony did not converge after {max_iter_harmony} iterations.") return z_hat @@ -180,7 +180,7 @@ def _get_batch_codes( batch_vec = batch_df[batch_key[0]] else: df = batch_df[batch_key].astype("str") - batch_vec = df.apply(lambda row: ",".join(row), axis=1) + batch_vec = df.apply(",".join, axis=1) batch_cat = batch_vec.astype("category") codes = batch_cat.cat.codes.values.copy() diff --git a/src/scanpy/preprocessing/_harmony_integrate.py b/src/scanpy/preprocessing/_harmony_integrate.py index 78249e645e..081264ec6e 100644 --- a/src/scanpy/preprocessing/_harmony_integrate.py +++ b/src/scanpy/preprocessing/_harmony_integrate.py @@ -60,7 +60,7 @@ def harmony_integrate( # Ensure the basis exists in adata.obsm if basis not in adata.obsm: msg = ( - f"The specified basis '{basis}' is not available in adata.obsm. " + f"The specified basis {basis!r} is not available in `adata.obsm`. " f"Available bases: {list(adata.obsm.keys())}" ) raise ValueError(msg) From 4b4d827b0484eec031439d3288ef275d6309c90e Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Fri, 20 Feb 2026 13:01:58 +0100 Subject: [PATCH 6/9] move API over --- src/scanpy/preprocessing/_harmony/__init__.py | 50 +++----- .../preprocessing/_harmony_integrate.py | 49 +++++++- src/testing/scanpy/_pytest/__init__.py | 5 + tests/test_harmony.py | 107 ++++++++++-------- 4 files changed, 120 insertions(+), 91 deletions(-) diff --git a/src/scanpy/preprocessing/_harmony/__init__.py b/src/scanpy/preprocessing/_harmony/__init__.py index 4a01df6b5a..62f8885a80 100644 --- a/src/scanpy/preprocessing/_harmony/__init__.py +++ b/src/scanpy/preprocessing/_harmony/__init__.py @@ -12,6 +12,8 @@ from ..._settings.verbosity import Verbosity if TYPE_CHECKING: + from typing import Literal + import pandas as pd @@ -20,18 +22,18 @@ def harmonize( # noqa: PLR0913 batch_df: pd.DataFrame, batch_key: str | list[str], *, - theta: float | list[float] | None = None, - sigma: float = 0.1, - n_clusters: int | None = None, - max_iter_harmony: int = 10, - max_iter_clustering: int = 200, - tol_harmony: float = 1e-4, - tol_clustering: float = 1e-5, - ridge_lambda: float = 1.0, - correction_method: str = "original", - block_proportion: float = 0.05, - random_state: int | None = 0, - sparse: bool = False, + theta: float | list[float] | None, + sigma: float, + n_clusters: int | None, + max_iter_harmony: int, + max_iter_clustering: int, + tol_harmony: float, + tol_clustering: float, + ridge_lambda: float, + correction_method: Literal["fast", "original"], + block_proportion: float, + random_state: int | None, + sparse: bool, ) -> np.ndarray: """ Run Harmony batch correction algorithm. @@ -44,30 +46,6 @@ def harmonize( # noqa: PLR0913 DataFrame containing batch information. batch_key Column name(s) in batch_df containing batch labels. - theta - Diversity penalty weight(s). Default is 2 for each batch variable. - sigma - Width of soft clustering kernel. Default 0.1. - n_clusters - Number of clusters. Default is min(100, n_cells/30). - max_iter_harmony - Maximum Harmony iterations. Default 10. 
- max_iter_clustering - Maximum clustering iterations per Harmony round. Default 200. - tol_harmony - Convergence tolerance for Harmony. Default 1e-4. - tol_clustering - Convergence tolerance for clustering. Default 1e-5. - ridge_lambda - Ridge regression regularization. Default 1.0. - correction_method - 'original' or 'fast'. Default 'original'. - block_proportion - Fraction of cells processed per clustering iteration. Default 0.05. - random_state - Random seed for reproducibility. - sparse - Use sparse matrices for phi. Reduces memory for large datasets. Returns ------- diff --git a/src/scanpy/preprocessing/_harmony_integrate.py b/src/scanpy/preprocessing/_harmony_integrate.py index 081264ec6e..ca4945195f 100644 --- a/src/scanpy/preprocessing/_harmony_integrate.py +++ b/src/scanpy/preprocessing/_harmony_integrate.py @@ -8,18 +8,28 @@ from typing import Literal from anndata import AnnData + from numpy.typing import DTypeLike -def harmony_integrate( +def harmony_integrate( # noqa: PLR0913 adata: AnnData, key: str | list[str], *, basis: str = "X_pca", adjusted_basis: str = "X_pca_harmony", - dtype: type = np.float64, + dtype: DTypeLike = np.float64, + theta: float | list[float] | None = None, + sigma: float = 0.1, + n_clusters: int | None = None, + max_iter_harmony: int = 10, + max_iter_clustering: int = 200, + tol_harmony: float = 1e-4, + tol_clustering: float = 1e-5, + ridge_lambda: float = 1.0, correction_method: Literal["fast", "original"] = "original", + block_proportion: float = 0.05, + random_state: int | None = 0, sparse: bool = False, - **kwargs, ) -> None: """ Integrate different experiments using the Harmony algorithm. @@ -42,13 +52,31 @@ def harmony_integrate( table will be stored. Defaults to ``X_pca_harmony``. dtype The data type to use for Harmony computation. + theta + Diversity penalty weight(s). Default is 2 for each batch variable. + sigma + Width of soft clustering kernel. Default 0.1. + n_clusters + Number of clusters. 
Default is min(100, n_cells/30). + max_iter_harmony + Maximum Harmony iterations. Default 10. + max_iter_clustering + Maximum clustering iterations per Harmony round. Default 200. + tol_harmony + Convergence tolerance for Harmony. Default 1e-4. + tol_clustering + Convergence tolerance for clustering. Default 1e-5. + ridge_lambda + Ridge regression regularization. Default 1.0. correction_method Choose which method for the correction step: ``original`` for original method, ``fast`` for improved method. + block_proportion + Fraction of cells processed per clustering iteration. Default 0.05. + random_state + Random seed for reproducibility. sparse Use sparse matrices for batch encoding. Reduces memory for large datasets. - **kwargs - Additional arguments passed to ``harmonize()``. Returns ------- @@ -91,9 +119,18 @@ def harmony_integrate( x, adata.obs, key, + theta=theta, + sigma=sigma, + n_clusters=n_clusters, + max_iter_harmony=max_iter_harmony, + max_iter_clustering=max_iter_clustering, + tol_harmony=tol_harmony, + tol_clustering=tol_clustering, + ridge_lambda=ridge_lambda, correction_method=correction_method, + block_proportion=block_proportion, + random_state=random_state, sparse=sparse, - **kwargs, ) # Store result diff --git a/src/testing/scanpy/_pytest/__init__.py b/src/testing/scanpy/_pytest/__init__.py index 04777f6aef..569d47b3f9 100644 --- a/src/testing/scanpy/_pytest/__init__.py +++ b/src/testing/scanpy/_pytest/__init__.py @@ -7,6 +7,7 @@ from types import MappingProxyType from typing import TYPE_CHECKING +import pooch import pytest from packaging.version import Version @@ -26,6 +27,7 @@ def original_settings( request: pytest.FixtureRequest, cache: pytest.Cache, tmp_path_factory: pytest.TempPathFactory, + monkeypatch: pytest.MonkeyPatch, ) -> Generator[Mapping[str, object], None, None]: """Switch to agg backend, reset settings, and close all figures at teardown.""" # make sure seaborn is imported and did its thing @@ -51,6 +53,9 @@ def original_settings( 
cache.mkdir("debug") # reuse data files between test runs (unless overwritten in the test) sc.settings.datasetdir = cache.mkdir("scanpy-data") + pooch.os_cache = pooch.utils.os_cache = pooch.core.os_cache = lambda p: ( + sc.settings.datasetdir / p + ) # create new writedir for each test run sc.settings.writedir = tmp_path_factory.mktemp("scanpy_write") diff --git a/tests/test_harmony.py b/tests/test_harmony.py index fcecacd0c5..b8e2960581 100644 --- a/tests/test_harmony.py +++ b/tests/test_harmony.py @@ -1,44 +1,51 @@ from __future__ import annotations -import anndata as ad +from typing import TYPE_CHECKING + import numpy as np import pandas as pd import pooch import pytest +from anndata import AnnData from scipy.stats import pearsonr import scanpy as sc from scanpy.preprocessing import harmony_integrate +if TYPE_CHECKING: + from typing import Literal -def _get_measure(x, base, norm): - """Compute correlation or L2 distance between arrays.""" - assert norm in ["r", "L2"] + from numpy.typing import DTypeLike + +def _get_measure( + x: np.ndarray, base: np.ndarray, norm: Literal["r", "L2"] +) -> np.ndarray: + """Compute correlation or L2 distance between arrays.""" if norm == "r": # Compute per-column correlation if x.ndim == 1: corr, _ = pearsonr(x, base) return corr - else: - corrs = [] - for i in range(x.shape[1]): - corr, _ = pearsonr(x[:, i], base[:, i]) - corrs.append(corr) - return np.array(corrs) + corrs = [] + for i in range(x.shape[1]): + corr, _ = pearsonr(x[:, i], base[:, i]) + corrs.append(corr) + return np.array(corrs) + + assert norm == "L2" # L2 distance normalized by base norm - elif x.ndim == 1: + if x.ndim == 1: return np.linalg.norm(x - base) / np.linalg.norm(base) - else: - dists = [] - for i in range(x.shape[1]): - dist = np.linalg.norm(x[:, i] - base[:, i]) / np.linalg.norm(base[:, i]) - dists.append(dist) - return np.array(dists) + dists = [] + for i in range(x.shape[1]): + dist = np.linalg.norm(x[:, i] - base[:, i]) / np.linalg.norm(base[:, 
i]) + dists.append(dist) + return np.array(dists) @pytest.fixture -def adata_reference(): +def adata_reference() -> AnnData: """Load reference data from harmonypy repository.""" x_pca_file = pooch.retrieve( "https://github.com/slowkow/harmonypy/raw/refs/heads/master/data/pbmc_3500_pcs.tsv.gz", @@ -57,7 +64,7 @@ def adata_reference(): meta = pd.read_csv(meta_file, delimiter="\t") # Create unique index using row number + cell name meta.index = [f"{i}_{cell}" for i, cell in enumerate(meta["cell"])] - adata = ad.AnnData( + adata = AnnData( X=None, obs=meta, obsm={"X_pca": x_pca.values, "harmony_org": x_pca_harmony.values}, @@ -67,30 +74,44 @@ def adata_reference(): @pytest.mark.parametrize("correction_method", ["fast", "original"]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) -def test_harmony_integrate(correction_method, dtype): +def test_harmony_integrate( + correction_method: Literal["fast", "original"], dtype: DTypeLike +) -> None: """Test that Harmony integrate works.""" adata = sc.datasets.pbmc68k_reduced() + harmony_integrate( adata, "bulk_labels", correction_method=correction_method, dtype=dtype ) + assert adata.obsm["X_pca_harmony"].shape == adata.obsm["X_pca"].shape @pytest.mark.parametrize("dtype", [np.float32, np.float64]) -def test_harmony_integrate_algos(dtype): +def test_harmony_integrate_algos(subtests: pytest.Subtests, dtype: DTypeLike) -> None: """Test that both correction methods produce similar results.""" adata = sc.datasets.pbmc68k_reduced() + harmony_integrate(adata, "bulk_labels", correction_method="fast", dtype=dtype) fast = adata.obsm["X_pca_harmony"].copy() harmony_integrate(adata, "bulk_labels", correction_method="original", dtype=dtype) slow = adata.obsm["X_pca_harmony"].copy() - assert _get_measure(fast, slow, "r").min() > 0.99 - assert _get_measure(fast, slow, "L2").max() < 0.1 + + with subtests.test("r"): + assert _get_measure(fast, slow, "r").min() > 0.99 + with subtests.test("L2"): + assert _get_measure(fast, slow, 
"L2").max() < 0.1 @pytest.mark.parametrize("dtype", [np.float32, np.float64]) @pytest.mark.parametrize("correction_method", ["fast", "original"]) -def test_harmony_integrate_reference(adata_reference, *, dtype, correction_method): +def test_harmony_integrate_reference( + *, + subtests: pytest.Subtests, + adata_reference: AnnData, + dtype: DTypeLike, + correction_method: Literal["fast", "original"], +) -> None: """Test that Harmony produces results similar to the reference implementation.""" harmony_integrate( adata_reference, @@ -99,35 +120,26 @@ def test_harmony_integrate_reference(adata_reference, *, dtype, correction_metho dtype=dtype, max_iter_harmony=20, ) + x, base = adata_reference.obsm["harmony_org"], adata_reference.obsm["X_pca_harmony"] - assert ( - _get_measure( - adata_reference.obsm["harmony_org"], - adata_reference.obsm["X_pca_harmony"], - "L2", - ).max() - < 0.05 - ) - assert ( - _get_measure( - adata_reference.obsm["harmony_org"], - adata_reference.obsm["X_pca_harmony"], - "r", - ).min() - > 0.95 - ) + with subtests.test("r"): + assert _get_measure(x, base, "r").min() > 0.95 + with subtests.test("L2"): + assert _get_measure(x, base, "L2").max() < 0.05 -def test_harmony_multiple_keys(): +def test_harmony_multiple_keys() -> None: """Test Harmony with multiple batch keys.""" adata = sc.datasets.pbmc68k_reduced() # Create a second batch key adata.obs["batch2"] = np.random.choice(["A", "B", "C"], size=adata.n_obs) + harmony_integrate(adata, ["bulk_labels", "batch2"], correction_method="original") + assert adata.obsm["X_pca_harmony"].shape == adata.obsm["X_pca"].shape -def test_harmony_custom_parameters(): +def test_harmony_custom_parameters() -> None: """Test Harmony with custom parameters.""" adata = sc.datasets.pbmc68k_reduced() harmony_integrate( @@ -142,21 +154,18 @@ def test_harmony_custom_parameters(): assert adata.obsm["X_pca_harmony"].shape == adata.obsm["X_pca"].shape -def test_harmony_no_nan_output(): +def test_harmony_no_nan_output() -> None: 
"""Test that Harmony output contains no NaN values.""" adata = sc.datasets.pbmc68k_reduced() harmony_integrate(adata, "bulk_labels") assert not np.isnan(adata.obsm["X_pca_harmony"]).any() -def test_harmony_input_validation(): +def test_harmony_input_validation(subtests) -> None: """Test that Harmony raises errors for invalid inputs.""" adata = sc.datasets.pbmc68k_reduced() - # Test missing basis - with pytest.raises(ValueError, match="not available"): + with subtests.test("no basis"), pytest.raises(ValueError, match="not available"): harmony_integrate(adata, "bulk_labels", basis="nonexistent") - - # Test missing key - with pytest.raises(KeyError): + with subtests.test("no key"), pytest.raises(KeyError): harmony_integrate(adata, "nonexistent_key") From 4627d62c4eb111df65a5e1c46a7a0120ec941050 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 20 Feb 2026 14:27:26 +0100 Subject: [PATCH 7/9] faster --- tests/test_harmony.py | 83 ++++++++++++++++++++----------------------- 1 file changed, 39 insertions(+), 44 deletions(-) diff --git a/tests/test_harmony.py b/tests/test_harmony.py index b8e2960581..5d90fa146e 100644 --- a/tests/test_harmony.py +++ b/tests/test_harmony.py @@ -9,8 +9,8 @@ from anndata import AnnData from scipy.stats import pearsonr -import scanpy as sc from scanpy.preprocessing import harmony_integrate +from testing.scanpy._helpers.data import pbmc68k_reduced if TYPE_CHECKING: from typing import Literal @@ -18,56 +18,53 @@ from numpy.typing import DTypeLike +DATA = dict( + pca=("pbmc_3500_pcs.tsv.gz", "md5:27e319b3ddcc0c00d98e70aa8e677b10"), + pca_harmonized=( + "pbmc_3500_pcs_harmonized.tsv.gz", + "md5:a7c4ce4b98c390997c66d63d48e09221", + ), + meta=("pbmc_3500_meta.tsv.gz", "md5:8c7ca20e926513da7cf0def1211baecb"), +) + + def _get_measure( x: np.ndarray, base: np.ndarray, norm: Literal["r", "L2"] ) -> np.ndarray: """Compute correlation or L2 distance between arrays.""" - if norm == "r": - # Compute per-column correlation + if norm == "r": # Compute 
per-column correlation if x.ndim == 1: corr, _ = pearsonr(x, base) return corr - corrs = [] - for i in range(x.shape[1]): - corr, _ = pearsonr(x[:, i], base[:, i]) - corrs.append(corr) - return np.array(corrs) - - assert norm == "L2" - # L2 distance normalized by base norm - if x.ndim == 1: - return np.linalg.norm(x - base) / np.linalg.norm(base) - dists = [] - for i in range(x.shape[1]): - dist = np.linalg.norm(x[:, i] - base[:, i]) / np.linalg.norm(base[:, i]) - dists.append(dist) - return np.array(dists) + return np.array([pearsonr(x[:, i], base[:, i])[0] for i in range(x.shape[1])]) + if norm == "L2": + # L2 distance normalized by base norm + if x.ndim == 1: + return np.linalg.norm(x - base) / np.linalg.norm(base) + return np.array([ + np.linalg.norm(x[:, i] - base[:, i]) / np.linalg.norm(base[:, i]) + for i in range(x.shape[1]) + ]) + pytest.fail(f"Unknown {norm=!r}") @pytest.fixture def adata_reference() -> AnnData: """Load reference data from harmonypy repository.""" - x_pca_file = pooch.retrieve( - "https://github.com/slowkow/harmonypy/raw/refs/heads/master/data/pbmc_3500_pcs.tsv.gz", - known_hash="md5:27e319b3ddcc0c00d98e70aa8e677b10", - ) - x_pca = pd.read_csv(x_pca_file, delimiter="\t") - x_pca_harmony_file = pooch.retrieve( - "https://github.com/slowkow/harmonypy/raw/refs/heads/master/data/pbmc_3500_pcs_harmonized.tsv.gz", - known_hash="md5:a7c4ce4b98c390997c66d63d48e09221", - ) - x_pca_harmony = pd.read_csv(x_pca_harmony_file, delimiter="\t") - meta_file = pooch.retrieve( - "https://github.com/slowkow/harmonypy/raw/refs/heads/master/data/pbmc_3500_meta.tsv.gz", - known_hash="md5:8c7ca20e926513da7cf0def1211baecb", - ) - meta = pd.read_csv(meta_file, delimiter="\t") + paths = { + f: pooch.retrieve( + f"https://github.com/slowkow/harmonypy/raw/refs/heads/master/data/{name}", + known_hash=hash_, + ) + for f, (name, hash_) in DATA.items() + } + dfs = {f: pd.read_csv(path, delimiter="\t") for f, path in paths.items()} # Create unique index using row number + 
cell name - meta.index = [f"{i}_{cell}" for i, cell in enumerate(meta["cell"])] + dfs["meta"].index = [f"{i}_{cell}" for i, cell in enumerate(dfs["meta"]["cell"])] adata = AnnData( X=None, - obs=meta, - obsm={"X_pca": x_pca.values, "harmony_org": x_pca_harmony.values}, + obs=dfs["meta"], + obsm={"X_pca": dfs["pca"].values, "harmony_org": dfs["pca_harmonized"].values}, ) return adata @@ -78,19 +75,17 @@ def test_harmony_integrate( correction_method: Literal["fast", "original"], dtype: DTypeLike ) -> None: """Test that Harmony integrate works.""" - adata = sc.datasets.pbmc68k_reduced() - + adata = pbmc68k_reduced() harmony_integrate( adata, "bulk_labels", correction_method=correction_method, dtype=dtype ) - assert adata.obsm["X_pca_harmony"].shape == adata.obsm["X_pca"].shape @pytest.mark.parametrize("dtype", [np.float32, np.float64]) def test_harmony_integrate_algos(subtests: pytest.Subtests, dtype: DTypeLike) -> None: """Test that both correction methods produce similar results.""" - adata = sc.datasets.pbmc68k_reduced() + adata = pbmc68k_reduced() harmony_integrate(adata, "bulk_labels", correction_method="fast", dtype=dtype) fast = adata.obsm["X_pca_harmony"].copy() @@ -130,7 +125,7 @@ def test_harmony_integrate_reference( def test_harmony_multiple_keys() -> None: """Test Harmony with multiple batch keys.""" - adata = sc.datasets.pbmc68k_reduced() + adata = pbmc68k_reduced() # Create a second batch key adata.obs["batch2"] = np.random.choice(["A", "B", "C"], size=adata.n_obs) @@ -141,7 +136,7 @@ def test_harmony_multiple_keys() -> None: def test_harmony_custom_parameters() -> None: """Test Harmony with custom parameters.""" - adata = sc.datasets.pbmc68k_reduced() + adata = pbmc68k_reduced() harmony_integrate( adata, "bulk_labels", @@ -156,14 +151,14 @@ def test_harmony_custom_parameters() -> None: def test_harmony_no_nan_output() -> None: """Test that Harmony output contains no NaN values.""" - adata = sc.datasets.pbmc68k_reduced() + adata = pbmc68k_reduced() 
harmony_integrate(adata, "bulk_labels") assert not np.isnan(adata.obsm["X_pca_harmony"]).any() def test_harmony_input_validation(subtests) -> None: """Test that Harmony raises errors for invalid inputs.""" - adata = sc.datasets.pbmc68k_reduced() + adata = pbmc68k_reduced() with subtests.test("no basis"), pytest.raises(ValueError, match="not available"): harmony_integrate(adata, "bulk_labels", basis="nonexistent") From 3d53e3796310d133c0e8ce7e8e331c4c4cec4c09 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 20 Feb 2026 15:45:55 +0100 Subject: [PATCH 8/9] refactor --- src/scanpy/preprocessing/_harmony/__init__.py | 295 ++++++++++-------- .../preprocessing/_harmony_integrate.py | 11 +- 2 files changed, 172 insertions(+), 134 deletions(-) diff --git a/src/scanpy/preprocessing/_harmony/__init__.py b/src/scanpy/preprocessing/_harmony/__init__.py index 62f8885a80..ceb6818a76 100644 --- a/src/scanpy/preprocessing/_harmony/__init__.py +++ b/src/scanpy/preprocessing/_harmony/__init__.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import KW_ONLY, InitVar, dataclass, field from typing import TYPE_CHECKING import numpy as np @@ -12,31 +13,17 @@ from ..._settings.verbosity import Verbosity if TYPE_CHECKING: + from collections.abc import Sequence from typing import Literal import pandas as pd + from ..._compat import CSBase -def harmonize( # noqa: PLR0913 - x: np.ndarray, - batch_df: pd.DataFrame, - batch_key: str | list[str], - *, - theta: float | list[float] | None, - sigma: float, - n_clusters: int | None, - max_iter_harmony: int, - max_iter_clustering: int, - tol_harmony: float, - tol_clustering: float, - ridge_lambda: float, - correction_method: Literal["fast", "original"], - block_proportion: float, - random_state: int | None, - sparse: bool, -) -> np.ndarray: - """ - Run Harmony batch correction algorithm. + +@dataclass +class Harmony: + """Harmony batch correction algorithm. 
Parameters ---------- @@ -46,110 +33,173 @@ def harmonize( # noqa: PLR0913 DataFrame containing batch information. batch_key Column name(s) in batch_df containing batch labels. - - Returns - ------- - z_corr - Batch-corrected embedding matrix (n_cells x d). """ - if random_state is not None: - np.random.seed(random_state) - # Ensure input is C-contiguous float array (infer dtype from x) - x = np.ascontiguousarray(x) - dtype = x.dtype - n_cells = x.shape[0] - - # Normalize input for clustering - z_norm = _normalize_rows_l2(x) - - # Process batch keys - batch_codes, n_batches = _get_batch_codes(batch_df, batch_key) + batch_df: InitVar[pd.DataFrame] + batch_key: InitVar[str | Sequence[str]] + _: KW_ONLY + theta: float | Sequence[float] | None + sigma: float + n_clusters: int | None + max_iter_harmony: int + max_iter_clustering: int + tol_harmony: float + tol_clustering: float + ridge_lambda: float + correction_method: Literal["fast", "original"] + block_proportion: float + random_state: int | None + sparse: bool + + batch_codes: np.ndarray = field(init=False) + n_batches: int = field(init=False) + + def __post_init__( + self, batch_df: pd.DataFrame, batch_key: str | Sequence[str] + ) -> None: + if self.max_iter_harmony < 1: + msg = "max_iter_harmony must be >= 1" + raise ValueError(msg) + + # Process batch keys + self.batch_codes, self.n_batches = _get_batch_codes(batch_df, batch_key) + + def fit(self, x: np.ndarray) -> np.ndarray: + """Run Harmony. + + Returns + ------- + z_corr + Batch-corrected embedding matrix (n_cells x d). 
+ """ + if self.random_state is not None: + np.random.seed(self.random_state) + + # Ensure input is C-contiguous float array (infer dtype from x) + x = np.ascontiguousarray(x) + n_cells = x.shape[0] + + # Normalize input for clustering + z_norm = _normalize_rows_l2(x) + + # Build phi matrix (one-hot encoding of batches) + if self.sparse: + phi = _one_hot_encode_sparse(self.batch_codes, self.n_batches, x.dtype) + n_b = np.asarray(phi.sum(axis=0)).ravel() + else: + phi = _one_hot_encode(self.batch_codes, self.n_batches, x.dtype) + n_b = phi.sum(axis=0) + pr_b = (n_b / n_cells).reshape(-1, 1) + + # Set default theta + if self.theta is None: + theta_arr = np.ones(self.n_batches, dtype=x.dtype) * 2.0 + elif isinstance(self.theta, (int, float)): + theta_arr = np.ones(self.n_batches, dtype=x.dtype) * float(self.theta) + else: + theta_arr = np.array(self.theta, dtype=x.dtype) + theta_arr = theta_arr.reshape(1, -1) - # Build phi matrix (one-hot encoding of batches) - if sparse: - phi = _one_hot_encode_sparse(batch_codes, n_batches, dtype) - n_b = np.asarray(phi.sum(axis=0)).ravel() - else: - phi = _one_hot_encode(batch_codes, n_batches, dtype) - n_b = phi.sum(axis=0) - pr_b = (n_b / n_cells).reshape(-1, 1) - - # Set default theta - if theta is None: - theta_arr = np.ones(n_batches, dtype=dtype) * 2.0 - elif isinstance(theta, (int, float)): - theta_arr = np.ones(n_batches, dtype=dtype) * float(theta) - else: - theta_arr = np.array(theta, dtype=dtype) - theta_arr = theta_arr.reshape(1, -1) - - # Set default n_clusters - if n_clusters is None: - n_clusters = int(min(100, n_cells / 30)) - n_clusters = max(n_clusters, 2) - - # Initialize centroids and state arrays - r, e, o, objectives_harmony = _initialize_centroids( - z_norm, - phi, - pr_b, - n_clusters=n_clusters, - sigma=sigma, - theta=theta_arr, - random_state=random_state, - ) + # Set default n_clusters + if self.n_clusters is None: + n_clusters = int(min(100, n_cells / 30)) + n_clusters = max(n_clusters, 2) + else: + 
n_clusters = self.n_clusters - # Main Harmony loop - converged = False - z_hat = x.copy() + # Initialize centroids and state arrays + r, e, o, obj_init = _initialize_centroids( + z_norm, + phi, + pr_b, + n_clusters=n_clusters, + sigma=self.sigma, + theta=theta_arr, + random_state=self.random_state, + ) - for i in tqdm(range(max_iter_harmony), disable=settings.verbosity < Verbosity.info): - # Clustering step - _clustering( + # Main Harmony loop + objectives_harmony = [obj_init] + with tqdm( + range(self.max_iter_harmony), disable=settings.verbosity < Verbosity.info + ) as bar: + for i in bar: + r, e, o, obj = self._cluster( + z_norm, pr_b, r=r, e=e, o=o, theta=theta_arr + ) + if obj is not None: + objectives_harmony.append(obj) + z_hat = self._correct(x, r, o) + z_norm = _normalize_rows_l2(z_hat) + if self._is_convergent(objectives_harmony, self.tol_harmony): + log.info(f"Harmony converged in {i + 1} iterations") + break + else: + log.info( + f"Harmony did not converge after {self.max_iter_harmony} iterations." 
+ ) + + return z_hat + + @staticmethod + def _is_convergent(objectives: list[float], tol: float) -> bool: + """Check Harmony convergence.""" + if len(objectives) < 2: + return False + obj_old = objectives[-2] + obj_new = objectives[-1] + return (obj_old - obj_new) < tol * abs(obj_old) + + def _cluster( + self, + z_norm: np.ndarray, + pr_b: np.ndarray, + *, + r: np.ndarray, + e: np.ndarray, + o: np.ndarray, + theta: np.ndarray, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, float | None]: + """Perform clustering step.""" + return _clustering( z_norm, - batch_codes, - n_batches, + self.batch_codes, + self.n_batches, pr_b, r=r, e=e, o=o, - theta=theta_arr, - sigma=sigma, - max_iter=max_iter_clustering, - tol=tol_clustering, - block_proportion=block_proportion, - objectives_harmony=objectives_harmony, + theta=theta, + sigma=self.sigma, + max_iter=self.max_iter_clustering, + tol=self.tol_clustering, + block_proportion=self.block_proportion, ) - # Correction step - if correction_method == "fast": - z_hat = _correction_fast( - x, batch_codes, n_batches, r, o, ridge_lambda=ridge_lambda + def _correct(self, x: np.ndarray, r: np.ndarray, o: np.ndarray) -> np.ndarray: + """Perform correction step.""" + if self.correction_method == "fast": + return _correction_fast( + x, + self.batch_codes, + self.n_batches, + r, + o, + ridge_lambda=self.ridge_lambda, ) else: - z_hat = _correction_original( - x, batch_codes, n_batches, r, ridge_lambda=ridge_lambda + return _correction_original( + x, + self.batch_codes, + self.n_batches, + r, + ridge_lambda=self.ridge_lambda, ) - # Normalize corrected data for next iteration - z_norm = _normalize_rows_l2(z_hat) - - # Check convergence - if _is_convergent_harmony(objectives_harmony, tol_harmony): - converged = True - log.info(f"Harmony converged in {i + 1} iterations") - break - - if not converged: - log.info(f"Harmony did not converge after {max_iter_harmony} iterations.") - - return z_hat - def _get_batch_codes( batch_df: pd.DataFrame, - 
batch_key: str | list[str], + batch_key: str | Sequence[str], ) -> tuple[np.ndarray, int]: """Get batch codes from DataFrame.""" if isinstance(batch_key, str): @@ -157,11 +207,11 @@ def _get_batch_codes( elif len(batch_key) == 1: batch_vec = batch_df[batch_key[0]] else: - df = batch_df[batch_key].astype("str") + df = batch_df[list(batch_key)].astype("str") batch_vec = df.apply(",".join, axis=1) batch_cat = batch_vec.astype("category") - codes = batch_cat.cat.codes.values.copy() + codes = batch_cat.cat.codes.to_numpy(copy=True) n_batches = len(batch_cat.cat.categories) return codes.astype(np.int32), n_batches @@ -208,14 +258,14 @@ def _normalize_rows_l1(r: np.ndarray) -> None: def _initialize_centroids( z_norm: np.ndarray, - phi: np.ndarray, + phi: np.ndarray | CSBase, pr_b: np.ndarray, *, n_clusters: int, sigma: float, theta: np.ndarray, random_state: int | None, -) -> tuple[np.ndarray, np.ndarray, np.ndarray, list]: +) -> tuple[np.ndarray, np.ndarray, np.ndarray, float]: """Initialize cluster centroids using K-means.""" kmeans = KMeans( n_clusters=n_clusters, random_state=random_state, n_init=10, max_iter=25 @@ -237,11 +287,9 @@ def _initialize_centroids( o = phi.T @ r # Compute initial objective - objectives_harmony: list = [] obj = _compute_objective(y_norm, z_norm, r, theta=theta, sigma=sigma, o=o, e=e) - objectives_harmony.append(obj) - return r, e, o, objectives_harmony + return r, e, o, obj def _compute_r( @@ -268,8 +316,7 @@ def _clustering( # noqa: PLR0913 max_iter: int, tol: float, block_proportion: float, - objectives_harmony: list, -) -> None: +) -> tuple[np.ndarray, np.ndarray, np.ndarray, float | None]: """Run clustering iterations (modifies r, e, o in-place).""" n_cells = z_norm.shape[0] k = r.shape[1] @@ -340,8 +387,12 @@ def _clustering( # noqa: PLR0913 # Check convergence if _is_convergent_clustering(objectives_clustering, tol): - objectives_harmony.append(objectives_clustering[-1]) + obj = objectives_clustering[-1] break + else: + obj = None + + 
return r, e, o, obj def _correction_original( @@ -469,20 +520,6 @@ def _compute_objective( return kmeans_error + entropy + diversity_penalty -def _is_convergent_harmony( - objectives: list, - tol: float, -) -> bool: - """Check Harmony convergence.""" - if len(objectives) < 2: - return False - - obj_old = objectives[-2] - obj_new = objectives[-1] - - return (obj_old - obj_new) < tol * abs(obj_old) - - def _is_convergent_clustering( objectives: list, tol: float, diff --git a/src/scanpy/preprocessing/_harmony_integrate.py b/src/scanpy/preprocessing/_harmony_integrate.py index ca4945195f..1d29c7b272 100644 --- a/src/scanpy/preprocessing/_harmony_integrate.py +++ b/src/scanpy/preprocessing/_harmony_integrate.py @@ -5,6 +5,7 @@ import numpy as np if TYPE_CHECKING: + from collections.abc import Sequence from typing import Literal from anndata import AnnData @@ -13,12 +14,12 @@ def harmony_integrate( # noqa: PLR0913 adata: AnnData, - key: str | list[str], + key: str | Sequence[str], *, basis: str = "X_pca", adjusted_basis: str = "X_pca_harmony", dtype: DTypeLike = np.float64, - theta: float | list[float] | None = None, + theta: float | Sequence[float] | None = None, sigma: float = 0.1, n_clusters: int | None = None, max_iter_harmony: int = 10, @@ -83,7 +84,7 @@ def harmony_integrate( # noqa: PLR0913 Updates adata with the field ``adata.obsm[adjusted_basis]``, containing principal components adjusted by Harmony. 
""" - from ._harmony import harmonize + from ._harmony import Harmony # Ensure the basis exists in adata.obsm if basis not in adata.obsm: @@ -115,8 +116,7 @@ def harmony_integrate( # noqa: PLR0913 raise ValueError(msg) # Run Harmony - harmony_out = harmonize( - x, + harmony = Harmony( adata.obs, key, theta=theta, @@ -132,6 +132,7 @@ def harmony_integrate( # noqa: PLR0913 random_state=random_state, sparse=sparse, ) + harmony_out = harmony.fit(x) # Store result adata.obsm[adjusted_basis] = harmony_out From 96fca7ab77439f28dda4592bf60e561daccea7da Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 20 Feb 2026 16:08:02 +0100 Subject: [PATCH 9/9] more itertools --- src/scanpy/preprocessing/_harmony/__init__.py | 90 +++++++++---------- 1 file changed, 40 insertions(+), 50 deletions(-) diff --git a/src/scanpy/preprocessing/_harmony/__init__.py b/src/scanpy/preprocessing/_harmony/__init__.py index ceb6818a76..5a5c0c6b3d 100644 --- a/src/scanpy/preprocessing/_harmony/__init__.py +++ b/src/scanpy/preprocessing/_harmony/__init__.py @@ -1,6 +1,7 @@ from __future__ import annotations from dataclasses import KW_ONLY, InitVar, dataclass, field +from itertools import product from typing import TYPE_CHECKING import numpy as np @@ -160,7 +161,7 @@ def _cluster( o: np.ndarray, theta: np.ndarray, ) -> tuple[np.ndarray, np.ndarray, np.ndarray, float | None]: - """Perform clustering step.""" + """Perform clustering step. 
Modifies r, e, o in-place.""" return _clustering( z_norm, self.batch_codes, @@ -320,7 +321,7 @@ def _clustering( # noqa: PLR0913 """Run clustering iterations (modifies r, e, o in-place).""" n_cells = z_norm.shape[0] k = r.shape[1] - block_size = max(1, int(n_cells * block_proportion)) + n_blocks = min(n_cells, 1 // block_proportion) term = -2.0 / sigma objectives_clustering = [] @@ -340,46 +341,41 @@ def _clustering( # noqa: PLR0913 idx_list = np.random.permutation(n_cells) # Process blocks - pos = 0 - while pos < n_cells: - end_pos = min(pos + block_size, n_cells) - block_idx = idx_list[pos:end_pos] - - for b in range(n_batches): - mask = batch_codes[block_idx] == b - if not np.any(mask): - continue - - cell_idx = block_idx[mask] - - # Remove old r contribution from o and e - r_old = r[cell_idx, :] - r_old_sum = r_old.sum(axis=0) - o[b, :] -= r_old_sum - e -= pr_b * r_old_sum - - # Compute new r values - dots = z_norm[cell_idx, :] @ y_norm.T - r_new = np.exp(term * (1.0 - dots)) - - # Apply penalty - penalty = ((e[b, :] + 1.0) / (o[b, :] + 1.0)) ** theta[0, b] - r_new *= penalty - - # Normalize rows to sum to 1 - row_sums = r_new.sum(axis=1, keepdims=True) - row_sums = np.maximum(row_sums, 1e-12) - r_new /= row_sums - - # Store back - r[cell_idx, :] = r_new - - # Add new r contribution to o and e - r_new_sum = r_new.sum(axis=0) - o[b, :] += r_new_sum - e += pr_b * r_new_sum - - pos = end_pos + for block_idx, b in product( + np.array_split(idx_list, n_blocks), range(n_batches) + ): + mask = batch_codes[block_idx] == b + if not np.any(mask): + continue + + cell_idx = block_idx[mask] + + # Remove old r contribution from o and e + r_old = r[cell_idx, :] + r_old_sum = r_old.sum(axis=0) + o[b, :] -= r_old_sum + e -= pr_b * r_old_sum + + # Compute new r values + dots = z_norm[cell_idx, :] @ y_norm.T + r_new = np.exp(term * (1.0 - dots)) + + # Apply penalty + penalty = ((e[b, :] + 1.0) / (o[b, :] + 1.0)) ** theta[0, b] + r_new *= penalty + + # Normalize rows to sum to 1 + 
row_sums = r_new.sum(axis=1, keepdims=True) + row_sums = np.maximum(row_sums, 1e-12) + r_new /= row_sums + + # Store back + r[cell_idx, :] = r_new + + # Add new r contribution to o and e + r_new_sum = r_new.sum(axis=0) + o[b, :] += r_new_sum + e += pr_b * r_new_sum # Compute objective obj = _compute_objective(y_norm, z_norm, r, theta=theta, sigma=sigma, o=o, e=e) @@ -405,7 +401,6 @@ def _correction_original( ) -> np.ndarray: """Original correction method - per-cluster ridge regression.""" _, d = x.shape - k = r.shape[1] # Ridge regularization matrix (don't penalize intercept) id_mat = np.eye(n_batches + 1) @@ -414,9 +409,7 @@ def _correction_original( z = x.copy() - for k_idx in range(k): - r_k = r[:, k_idx] - + for r_k in r.T: r_sum_total = r_k.sum() r_sum_per_batch = np.zeros(n_batches, dtype=x.dtype) for b in range(n_batches): @@ -458,13 +451,11 @@ def _correction_fast( ) -> np.ndarray: """Fast correction method using precomputed factors.""" _, d = x.shape - k = r.shape[1] z = x.copy() p = np.eye(n_batches + 1) - for k_idx in range(k): - o_k = o[:, k_idx] + for o_k, r_k in zip(o.T, r.T, strict=True): n_k = np.sum(o_k) factor = 1.0 / (o_k + ridge_lambda) @@ -480,7 +471,6 @@ def _correction_fast( inv_mat = p_t_b_inv @ p - r_k = r[:, k_idx] phi_t_x = np.zeros((n_batches + 1, d), dtype=x.dtype) phi_t_x[0, :] = r_k @ x for b in range(n_batches):