From 7d9cf3375ea6db6e085b6ea32cc1df882430dba3 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Wed, 29 Apr 2026 16:24:25 +0200 Subject: [PATCH 1/8] add niche detection --- docs/api/squidpy_gpu.md | 1 + docs/release-notes/0.15.1.md | 6 + docs/release-notes/index.md | 2 + src/rapids_singlecell/squidpy_gpu/__init__.py | 1 + src/rapids_singlecell/squidpy_gpu/_gmm.py | 190 +++++++++ src/rapids_singlecell/squidpy_gpu/_niche.py | 377 ++++++++++++++++++ tests/test_gmm.py | 80 ++++ tests/test_niche.py | 363 +++++++++++++++++ 8 files changed, 1020 insertions(+) create mode 100644 docs/release-notes/0.15.1.md create mode 100644 src/rapids_singlecell/squidpy_gpu/_gmm.py create mode 100644 src/rapids_singlecell/squidpy_gpu/_niche.py create mode 100644 tests/test_gmm.py create mode 100644 tests/test_niche.py diff --git a/docs/api/squidpy_gpu.md b/docs/api/squidpy_gpu.md index 99da5825..bb271054 100644 --- a/docs/api/squidpy_gpu.md +++ b/docs/api/squidpy_gpu.md @@ -13,4 +13,5 @@ gr.spatial_autocorr gr.co_occurrence gr.ligrec + gr.calculate_niche ``` diff --git a/docs/release-notes/0.15.1.md b/docs/release-notes/0.15.1.md new file mode 100644 index 00000000..8b7bd14a --- /dev/null +++ b/docs/release-notes/0.15.1.md @@ -0,0 +1,6 @@ +### 0.15.1 {small}`the-future` + +```{rubric} Features +``` +* Add `rsc.gr.calculate_niche` (GPU-accelerated spatial niche detection) with three flavors — `neighborhood` (cell-type frequency profile + Leiden), `utag` (sparse adjacency × expression + Leiden), and `cellcharter` (n-hop shell aggregation + PCA + GMM). Mirrors `squidpy.gr.calculate_niche` but runs end-to-end on GPU; ~17–316× faster on real Xenium WTA data {smaller}`S Dicks` +* Add a minimal full-covariance GMM (`squidpy_gpu._gmm.gmm_fit_predict`) used by the `cellcharter` flavor: cuML KMeans warm-start init, batched precision-Cholesky E-step, fused `cp.fuse` log-pdf {smaller}`S Dicks` diff --git a/docs/release-notes/index.md b/docs/release-notes/index.md index 3f1c2307..32259af9 100644 --- a/docs/release-notes/index.md +++ b/docs/release-notes/index.md @@ -4,6 +4,8 @@ ## Version 0.15.0 +```{include} /release-notes/0.15.1.md +``` ```{include} /release-notes/0.15.0rc7.md ``` ```{include} /release-notes/0.15.0rc6.md diff --git a/src/rapids_singlecell/squidpy_gpu/__init__.py b/src/rapids_singlecell/squidpy_gpu/__init__.py index 168e44d5..fa3cc3fd 100644 --- a/src/rapids_singlecell/squidpy_gpu/__init__.py +++ b/src/rapids_singlecell/squidpy_gpu/__init__.py @@ -3,3 +3,4 @@ from ._autocorr import spatial_autocorr from ._co_oc import co_occurrence from ._ligrec import ligrec +from ._niche import calculate_niche diff --git a/src/rapids_singlecell/squidpy_gpu/_gmm.py b/src/rapids_singlecell/squidpy_gpu/_gmm.py new file mode 100644 index 00000000..bb8779ea --- /dev/null +++ b/src/rapids_singlecell/squidpy_gpu/_gmm.py @@ -0,0 +1,190 @@ +"""Minimal GMM (full covariance) for the cellcharter niche flavor. + +Mirrors :class:`sklearn.mixture.GaussianMixture` with +``covariance_type="full"``. Two init strategies are exposed: ``"kmeans"`` (default, +cuML KMeans warm-start) and ``"random_from_data"`` (sklearn-equivalent for parity +testing). A future cuML-backed or fused-CUDA GMM can replace this. + +Implementation notes +-------------------- +- E-step uses a cached precision Cholesky (computed once per M-step) and a + per-component dense matmul. This avoids the per-component triangular solve. +- ``_precision_cholesky`` is a batched ``inv``+``cholesky`` — no Python K-loop. 
+- ``cupyx.scipy.special.logsumexp`` for the stable softmax over components. +- Convergence is the change in mean log-likelihood. +""" + +from __future__ import annotations + +from typing import Literal + +import cupy as cp +import numpy as np +from cupyx.scipy.special import logsumexp + +_LOG_2PI = float(np.log(2.0 * np.pi)) + + +def gmm_fit_predict( + X: cp.ndarray, + n_components: int, + *, + random_state: int = 0, + max_iter: int = 100, + tol: float = 1e-3, + reg_covar: float = 1e-6, + init: Literal["kmeans", "random_from_data"] = "kmeans", +) -> cp.ndarray: + """Fit a full-covariance GMM and return cluster labels. + + Parameters + ---------- + X + Cupy array, shape ``(n_samples, n_features)``, float32 or float64. + n_components + Number of mixture components ``K``. + random_state + Seed for initialization. + max_iter + Maximum EM iterations. + tol + Convergence threshold on the change in mean log-likelihood. + reg_covar + Regularization added to each component covariance diagonal. + init + ``"kmeans"`` (default) uses cuML KMeans for warm-start; usually much + better than ``"random_from_data"``, which mirrors sklearn for parity. + """ + X = cp.ascontiguousarray(X) + n_samples, _ = X.shape + K = int(n_components) + + weights, means, covariances = _initialize(X, K, random_state, reg_covar, init) + prec_chol, log_det_prec_half = _precision_cholesky(covariances) + + prev_ll = -cp.inf + converged = False + for _ in range(max_iter): + log_resp, ll = _e_step(X, weights, means, prec_chol, log_det_prec_half) + if cp.abs(ll - prev_ll) < tol: + converged = True + break + prev_ll = ll + weights, means, covariances = _m_step(X, log_resp, reg_covar) + prec_chol, log_det_prec_half = _precision_cholesky(covariances) + + if not converged: + log_resp, _ = _e_step(X, weights, means, prec_chol, log_det_prec_half) + return log_resp.argmax(axis=1).astype(cp.int32) + + +def _initialize( + X: cp.ndarray, + K: int, + random_state: int, + reg_covar: float, + init: str, +) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray]: + n, d = X.shape + eye_reg = reg_covar * cp.eye(d, dtype=X.dtype) + + if init == "random_from_data": + # sklearn parity: pick K rows as means, equal weights, reg-only covariance. + rng = np.random.default_rng(random_state) + idx = cp.asarray(rng.choice(n, size=K, replace=False)) + means = X[idx].copy() + weights = cp.full(K, 1.0 / K, dtype=X.dtype) + covariances = cp.broadcast_to(eye_reg, (K, d, d)).copy() + return weights, means, covariances + + if init != "kmeans": + raise ValueError(f"init must be 'kmeans' or 'random_from_data', got {init!r}") + + from cuml.cluster import KMeans + + # Match sklearn's n_init=10 inside its GaussianMixture's kmeans init. + # n_init=1 was empirically prone to degenerate inits on structureless data, + # which then collapsed EM to a single component. + km = KMeans(n_clusters=K, random_state=random_state, n_init=10, max_iter=100) + km.fit(X) + labels = cp.asarray(km.labels_) + means = cp.asarray(km.cluster_centers_, dtype=X.dtype) + + weights = cp.zeros(K, dtype=X.dtype) + covariances = cp.empty((K, d, d), dtype=X.dtype) + for k in range(K): + mask = labels == k + cnt = int(mask.sum()) + if cnt == 0: + # KMeans can return empty clusters; fall back to a tiny uniform component. 
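+            # The stand-in keeps weight 1/n and a reg_covar-scaled identity covariance,
+            # so the batched inv/cholesky in _precision_cholesky stays well-defined.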
+ weights[k] = 1.0 / n + covariances[k] = eye_reg + continue + weights[k] = cnt / n + diff = X[mask] - means[k] + covariances[k] = (diff.T @ diff) / cnt + eye_reg + return weights, means, covariances + + +def _precision_cholesky( + covariances: cp.ndarray, +) -> tuple[cp.ndarray, cp.ndarray]: + """Return ``(prec_chol, log|Σ⁻¹|/2)`` where ``prec_chol @ prec_chol.T = Σ⁻¹``. + + Two batched LAPACK calls — no Python K-loop. + """ + precisions = cp.linalg.inv(covariances) + prec_chol = cp.linalg.cholesky(precisions) + log_det_half = cp.sum(cp.log(cp.diagonal(prec_chol, axis1=1, axis2=2)), axis=1) + return prec_chol, log_det_half + + +@cp.fuse(kernel_name="_gmm_log_pdf_const") +def _log_pdf_const(mahal: cp.ndarray, log_det_half: cp.ndarray, half_d_log2pi): + return -0.5 * mahal + log_det_half - half_d_log2pi + + +def _e_step( + X: cp.ndarray, + weights: cp.ndarray, + means: cp.ndarray, + prec_chol: cp.ndarray, + log_det_half: cp.ndarray, +) -> tuple[cp.ndarray, cp.ndarray]: + n, d = X.shape + K = means.shape[0] + + log_prob = cp.empty((n, K), dtype=X.dtype) + half_d_log2pi = X.dtype.type(0.5 * d * _LOG_2PI) + for k in range(K): + # mahal = || (X - μ_k) @ prec_chol[k] ||² + y = (X - means[k]) @ prec_chol[k] + mahal = cp.einsum("ij,ij->i", y, y) + log_prob[:, k] = _log_pdf_const(mahal, log_det_half[k], half_d_log2pi) + log_prob = log_prob + cp.log(weights) + + log_total = logsumexp(log_prob, axis=1, keepdims=True) + log_resp = log_prob - log_total + return log_resp, log_total.mean() + + +def _m_step( + X: cp.ndarray, + log_resp: cp.ndarray, + reg_covar: float, +) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray]: + n, d = X.shape + K = log_resp.shape[1] + + resp = cp.exp(log_resp) + N_k = resp.sum(axis=0) + 10.0 * cp.finfo(X.dtype).eps # (K,) + + weights = N_k / n + means = (resp.T @ X) / N_k[:, None] + + covariances = cp.empty((K, d, d), dtype=X.dtype) + eye_reg = reg_covar * cp.eye(d, dtype=X.dtype) + for k in range(K): + diff = X - means[k] + covariances[k] = ((resp[:, k : k + 1] * diff).T @ diff) / N_k[k] + eye_reg + return weights, means, covariances diff --git a/src/rapids_singlecell/squidpy_gpu/_niche.py b/src/rapids_singlecell/squidpy_gpu/_niche.py new file mode 100644 index 00000000..12e95551 --- /dev/null +++ b/src/rapids_singlecell/squidpy_gpu/_niche.py @@ -0,0 +1,377 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +import cupy as cp +import numpy as np +import pandas as pd +from anndata import AnnData +from cupyx.scipy import sparse as sparse_gpu + +import rapids_singlecell as rsc + +if TYPE_CHECKING: + from collections.abc import Sequence + + +__all__ = ["calculate_niche"] + + +def calculate_niche( + adata: AnnData, + *, + flavor: Literal["neighborhood", "utag", "cellcharter"], + groups: str | None = None, + n_neighbors: int = 15, + resolutions: float | Sequence[float] = (0.5,), + distance: int | None = None, + n_hop_weights: Sequence[float] | None = None, + abs_nhood: bool = False, + scale: bool = True, + min_niche_size: int | None = None, + aggregation: Literal["mean", "variance"] = "mean", + n_components: int = 10, + use_rep: str | None = None, + init: Literal["kmeans", "random_from_data"] = "kmeans", + spatial_connectivities_key: str = "spatial_connectivities", + random_state: int = 0, + copy: bool = False, +) -> AnnData | None: + """\ + Compute spatial niches on the GPU. + + Mirrors :func:`squidpy.gr.calculate_niche` for the ``"neighborhood"``, + ``"utag"`` and ``"cellcharter"`` flavors. 
The spatial graph in + ``adata.obsp[spatial_connectivities_key]`` must be precomputed + (e.g. via :func:`squidpy.gr.spatial_neighbors`). + + Parameters + ---------- + adata + Annotated data matrix. + flavor + - ``"neighborhood"`` cluster cell-type frequency profiles among spatial neighbors + :cite:p:`monkeybread`. + - ``"utag"`` cluster gene expression smoothed across spatial neighbors + :cite:p:`UTAG2022`. + - ``"cellcharter"`` shell-aggregate gene expression over n-hop neighborhoods, + PCA-reduce, then cluster with a Gaussian mixture :cite:p:`CellCharter2024`. + groups + Column in ``adata.obs`` with cell-type labels. Required for ``flavor="neighborhood"``. + n_neighbors + Neighbors for the post-aggregation kNN graph passed to leiden. + resolutions + Resolution(s) for leiden. A label column is written for each value. + Ignored for ``flavor="cellcharter"``. + distance + Number of n-hop neighborhoods to include. Defaults to 3 for ``cellcharter``, + 1 for ``neighborhood``. + n_hop_weights + Per-hop weights when ``distance > 1`` (``flavor="neighborhood"`` only). + abs_nhood + Use absolute neighbor counts instead of per-cell relative frequencies + (``flavor="neighborhood"`` only). + scale + Z-score the neighborhood profile before clustering (``flavor="neighborhood"`` only). + min_niche_size + Discard niches with fewer cells than this; relabel as ``"not_a_niche"``. + aggregation + Per-shell aggregation for ``flavor="cellcharter"``. ``"mean"`` (default) or ``"variance"``. + n_components + Number of mixture components for ``flavor="cellcharter"``. + use_rep + Key in ``adata.obsm`` to use as the embedding for ``flavor="cellcharter"``; + if provided, the first ``n_components`` columns are used and the shell-aggregation + + PCA step is skipped. + init + GMM initialization for ``flavor="cellcharter"``. ``"kmeans"`` (default, robust) + or ``"random_from_data"`` (sklearn-parity). Use the latter if kmeans init lands + on a degenerate component on noisy / low-signal data. + spatial_connectivities_key + Key in ``adata.obsp`` with the spatial connectivity matrix. + random_state + Random seed for leiden / GMM. + copy + Return a copy with the niche columns instead of writing in place. + """ + if spatial_connectivities_key not in adata.obsp: + raise KeyError( + f"'{spatial_connectivities_key}' not found in `adata.obsp`. " + "Compute it first with `squidpy.gr.spatial_neighbors`." + ) + if flavor not in ("neighborhood", "utag", "cellcharter"): + raise ValueError( + f"Unknown flavor '{flavor}'. Use 'neighborhood', 'utag', or 'cellcharter'." 
+ ) + if distance is None: + distance = 3 if flavor == "cellcharter" else 1 + if flavor in ("neighborhood",) and distance < 1: + raise ValueError(f"`distance` must be >= 1, got {distance}.") + if flavor == "cellcharter" and distance < 0: + raise ValueError(f"`distance` must be >= 0, got {distance}.") + + adata = adata.copy() if copy else adata + + if flavor == "cellcharter": + _run_cellcharter( + adata, + distance=distance, + aggregation=aggregation, + n_components=n_components, + use_rep=use_rep, + init=init, + random_state=random_state, + key=spatial_connectivities_key, + ) + return adata if copy else None + + if flavor == "neighborhood": + if groups is None: + raise ValueError("`groups` is required for flavor='neighborhood'.") + if groups not in adata.obs.columns: + raise KeyError(f"'{groups}' not found in `adata.obs`.") + profile = _neighborhood_profile( + adata, + groups=groups, + distance=distance, + weights=n_hop_weights, + abs_nhood=abs_nhood, + key=spatial_connectivities_key, + ) + prefix = "nhood_niche" + else: + profile = _utag_features(adata, spatial_connectivities_key) + prefix = "utag_niche" + + inner = AnnData(X=profile, obs=pd.DataFrame(index=adata.obs_names.copy())) + + if flavor == "neighborhood": + if scale: + rsc.pp.scale(inner, zero_center=True) + rsc.pp.neighbors( + inner, n_neighbors=n_neighbors, use_rep="X", random_state=random_state + ) + else: + rsc.pp.pca(inner) + rsc.pp.neighbors( + inner, n_neighbors=n_neighbors, use_rep="X_pca", random_state=random_state + ) + + res_list = ( + [float(resolutions)] + if isinstance(resolutions, (int, float)) + else [float(r) for r in resolutions] + ) + base = "_niche_leiden" + rsc.tl.leiden( + inner, + resolution=res_list, + key_added=base, + random_state=random_state, + dtype=np.float64, + ) + for res in res_list: + src = f"{base}_{res}" if len(res_list) > 1 else base + out_key = f"{prefix}_res={res}" + labels = inner.obs[src].astype(str) + if min_niche_size is not None and flavor == "neighborhood": + counts = labels.value_counts() + small = counts[counts < min_niche_size].index + labels = labels.where(~labels.isin(small), other="not_a_niche") + adata.obs[out_key] = pd.Categorical(labels.values) + + return adata if copy else None + + +def _neighborhood_profile( + adata: AnnData, + *, + groups: str, + distance: int, + weights: Sequence[float] | None, + abs_nhood: bool, + key: str, +) -> np.ndarray: + """Cells x categories matrix of cell-type counts (or relative frequencies) over n-hop neighbors.""" + cats = pd.Categorical(adata.obs[groups]) + n_cats = len(cats.categories) + n_obs = adata.n_obs + + one_hot = cp.zeros((n_obs, n_cats), dtype=cp.float32) + one_hot[cp.arange(n_obs), cp.asarray(cats.codes, dtype=cp.int64)] = 1.0 + + adj = rsc.get.X_to_GPU(adata.obsp[key]).astype(cp.float32) + adj.eliminate_zeros() + # Binarize so adj.data == 1: each existing edge contributes one neighbor count. 
+ adj_bin = adj.copy() + adj_bin.data[:] = 1.0 + + if weights is None: + weights = [1.0] * distance + elif len(weights) < distance: + weights = list(weights) + [weights[-1]] * (distance - len(weights)) + + profile = cp.zeros((n_obs, n_cats), dtype=cp.float32) + adj_k = adj_bin + for hop in range(distance): + if hop == 0: + adj_hop = adj_bin + else: + adj_k = adj_k @ adj_bin + adj_hop = adj_k.copy() + adj_hop.data[:] = 1.0 + counts = adj_hop @ one_hot # (n_obs, n_cats) dense + if not abs_nhood: + row_sum = adj_hop.sum(axis=1).reshape(-1, 1) + row_sum = cp.where(row_sum == 0, cp.float32(1.0), row_sum) + counts = counts / row_sum + profile += cp.float32(weights[hop]) * counts + + if not abs_nhood: + profile /= cp.float32(sum(weights)) + + return profile + + +def _utag_features(adata: AnnData, key: str) -> cp.ndarray | sparse_gpu.csr_matrix: + """L1-row-normalize the spatial adjacency and propagate expression: D^-1 A @ X.""" + from rapids_singlecell._cuda import _norm_cuda as _nc + + adj = rsc.get.X_to_GPU(adata.obsp[key]) + if adj.dtype != cp.float32: + adj = adj.astype(cp.float32) + _nc.mul_csr( + adj.indptr, + adj.data, + nrows=adj.shape[0], + target_sum=1.0, + stream=cp.cuda.get_current_stream().ptr, + ) + + X = rsc.get.X_to_GPU(adata.X).astype(cp.float32) + if sparse_gpu.issparse(X): + out = adj @ X + return out.tocsr() + out = adj @ X + return out + + +def _run_cellcharter( + adata: AnnData, + *, + distance: int, + aggregation: str, + n_components: int, + use_rep: str | None, + init: str, + random_state: int, + key: str, +) -> None: + """Cellcharter pipeline: shell-aggregate → PCA → GMM.""" + if aggregation not in ("mean", "variance"): + raise ValueError( + f"aggregation={aggregation!r} not supported. Use 'mean' or 'variance'." + ) + if not isinstance(n_components, int) or n_components < 1: + raise ValueError(f"`n_components` must be an int >= 1, got {n_components}.") + + if use_rep is not None: + if use_rep not in adata.obsm: + raise KeyError(f"'{use_rep}' not found in `adata.obsm`.") + emb = adata.obsm[use_rep] + if emb.shape[1] < n_components: + raise ValueError( + f"`adata.obsm['{use_rep}']` has {emb.shape[1]} columns, " + f"need at least n_components={n_components}." + ) + embedding = cp.asarray(emb[:, :n_components], dtype=cp.float32) + else: + feat = _cellcharter_features(adata, distance, aggregation, key) + inner = AnnData(X=feat, obs=pd.DataFrame(index=adata.obs_names.copy())) + rsc.get.anndata_to_GPU(inner) + rsc.pp.pca(inner) + embedding = cp.asarray(inner.obsm["X_pca"], dtype=cp.float32) + + from ._gmm import gmm_fit_predict + + labels = gmm_fit_predict( + embedding, + n_components=n_components, + random_state=random_state, + init=init, + ) + adata.obs["cellcharter_niche"] = pd.Categorical(cp.asnumpy(labels).astype(str)) + + +def _cellcharter_features( + adata: AnnData, + distance: int, + aggregation: str, + key: str, +) -> cp.ndarray | sparse_gpu.csr_matrix: + """Build the shell-aggregated feature matrix: ``[X | Â₁X | Â₂X | …]``. + + For each k in ``1..distance`` the kth-shell adjacency is computed by + multiplying the previous adjacency by the base graph and subtracting the + already-visited neighbors. Each shell is row-L1-normalized via the same + fused ``mul_csr`` kernel used for utag, then aggregated as either: + + - ``"mean"``: ``Âₖ @ X`` + - ``"variance"``: ``Âₖ @ (X·X) − (Âₖ @ X)²`` (matches squidpy's path; densifies X) + + All layers are concatenated horizontally. 
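+
+    A rough sketch of one shell step (illustrative only; ``walk`` starts as the
+    1-hop graph without self-loops, ``visited`` as the graph with self-loops)::
+
+        walk = walk @ adj                            # reach one hop further
+        shell = (walk > visited).astype(cp.float32)  # keep newly reached cells only
+        visited = visited + shell                    # mark them as seen
+        walk = shell                                 # the next hop starts here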
+ """ + from rapids_singlecell._cuda import _norm_cuda as _nc + + adj = rsc.get.X_to_GPU(adata.obsp[key]) + if adj.dtype != cp.float32: + adj = adj.astype(cp.float32) + + # 1-hop adjacency, no self-loops; visited tracks {self ∪ 1-hop}. + adj_hop = adj.copy() + adj_hop.setdiag(cp.float32(0.0)) + adj_hop.eliminate_zeros() + adj_visited = adj.copy() + adj_visited.setdiag(cp.float32(1.0)) + + X = rsc.get.X_to_GPU(adata.X) + if aggregation == "variance": + # Variance needs element-wise square of X; densify once up front. + X_dense = X.toarray() if sparse_gpu.issparse(X) else X + X_sq = X_dense * X_dense + aggregated: list = [X_dense] + else: + aggregated = [X] + + for k in range(1, distance + 1): + if k > 1: + # Walk one more hop, keep only newly reachable neighbors. + adj_hop = adj_hop @ adj + new_shell = (adj_hop > adj_visited).astype(cp.float32) + adj_hop = new_shell + adj_visited = adj_visited + new_shell + + # L1 row-normalize the shell adjacency in place. + adj_norm = adj_hop.copy() + if adj_norm.nnz > 0: + _nc.mul_csr( + adj_norm.indptr, + adj_norm.data, + nrows=adj_norm.shape[0], + target_sum=1.0, + stream=cp.cuda.get_current_stream().ptr, + ) + + if aggregation == "variance": + mean = adj_norm @ X_dense + mean_sq = adj_norm @ X_sq + aggregated.append(mean_sq - mean * mean) + else: + aggregated.append(adj_norm @ X) + + if all(not sparse_gpu.issparse(m) for m in aggregated): + return cp.concatenate(aggregated, axis=1) + aggregated = [ + m if sparse_gpu.issparse(m) else sparse_gpu.csr_matrix(m) for m in aggregated + ] + return sparse_gpu.hstack(aggregated, format="csr") diff --git a/tests/test_gmm.py b/tests/test_gmm.py new file mode 100644 index 00000000..8fbdce5e --- /dev/null +++ b/tests/test_gmm.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +import cupy as cp +import numpy as np +import pytest +from sklearn.metrics import adjusted_rand_score as ARI + +from rapids_singlecell.squidpy_gpu._gmm import gmm_fit_predict + + +def _well_separated(n_per: int, K: int, d: int, sep: float, seed: int): + rng = np.random.default_rng(seed) + centers = rng.normal(scale=sep, size=(K, d)) + X = np.vstack( + [rng.normal(loc=c, scale=1.0, size=(n_per, d)) for c in centers] + ).astype(np.float32) + y = np.repeat(np.arange(K), n_per) + perm = rng.permutation(len(X)) + return X[perm], y[perm] + + +def test_kmeans_init_recovers_well_separated_clusters(): + """kmeans init should land at near-truth on well-separated data.""" + X_np, y = _well_separated(n_per=300, K=5, d=20, sep=6.0, seed=0) + labels = cp.asnumpy( + gmm_fit_predict(cp.asarray(X_np), n_components=5, random_state=0, init="kmeans") + ) + assert ARI(y, labels) >= 0.95 + + +def test_random_from_data_init_runs(): + """random_from_data may land at a worse local optimum than kmeans, but should + still produce a non-trivial partition on well-separated data.""" + X_np, y = _well_separated(n_per=300, K=5, d=20, sep=6.0, seed=0) + labels = cp.asnumpy( + gmm_fit_predict( + cp.asarray(X_np), n_components=5, random_state=0, init="random_from_data" + ) + ) + assert ARI(y, labels) >= 0.4 + assert len(set(labels.tolist())) >= 2 + + +def test_output_shape_and_dtype(): + rng = np.random.default_rng(0) + X = rng.standard_normal((500, 8)).astype(np.float32) + labels = gmm_fit_predict(cp.asarray(X), n_components=4, random_state=0) + assert labels.shape == (500,) + assert labels.dtype == cp.int32 + assert int(labels.max()) < 4 + assert int(labels.min()) >= 0 + + +@pytest.mark.parametrize("init", ["kmeans", "random_from_data"]) +def 
test_determinism_same_seed(init): + rng = np.random.default_rng(1) + X = cp.asarray(rng.standard_normal((800, 10)).astype(np.float32)) + a = cp.asnumpy(gmm_fit_predict(X, n_components=5, random_state=42, init=init)) + b = cp.asnumpy(gmm_fit_predict(X, n_components=5, random_state=42, init=init)) + np.testing.assert_array_equal(a, b) + + +def test_invalid_init_raises(): + X = cp.asarray(np.zeros((100, 5), dtype=np.float32)) + with pytest.raises(ValueError, match="init"): + gmm_fit_predict(X, n_components=3, init="bogus") + + +def test_n_components_one_returns_single_label(): + rng = np.random.default_rng(0) + X = cp.asarray(rng.standard_normal((200, 4)).astype(np.float32)) + labels = cp.asnumpy(gmm_fit_predict(X, n_components=1, random_state=0)) + assert set(labels.tolist()) == {0} + + +def test_float64_input_accepted(): + rng = np.random.default_rng(0) + X = cp.asarray(rng.standard_normal((300, 6)).astype(np.float64)) + labels = gmm_fit_predict(X, n_components=3, random_state=0) + assert labels.shape == (300,) diff --git a/tests/test_niche.py b/tests/test_niche.py new file mode 100644 index 00000000..40757bd6 --- /dev/null +++ b/tests/test_niche.py @@ -0,0 +1,363 @@ +from __future__ import annotations + +from pathlib import Path + +import cupy as cp +import numpy as np +import pandas as pd +import pytest +from anndata import read_h5ad +from cupyx.scipy import sparse as sparse_gpu +from scipy import sparse + +from rapids_singlecell.gr import calculate_niche +from rapids_singlecell.squidpy_gpu._niche import ( + _neighborhood_profile, + _utag_features, +) + +DATA = Path(__file__).parent / "_data" / "dummy.h5ad" +SPATIAL_CONNECTIVITIES_KEY = "spatial_connectivities" +GROUPS = "cluster" + + +@pytest.fixture +def adata(): + a = read_h5ad(DATA) + # _neighborhood_profile uses pd.Categorical on this column + a.obs[GROUPS] = pd.Categorical(a.obs[GROUPS]) + return a + + +# -- semantic tests adapted from squidpy/tests/graph/test_niche.py (BSD-3) -- + + +def test_niche_calc_nhood(adata): + """Adapted from squidpy: profile shape, normalization, min_niche_size labels.""" + calculate_niche( + adata, + flavor="neighborhood", + groups=GROUPS, + n_neighbors=10, + resolutions=[0.1], + min_niche_size=20, + ) + niches = adata.obs["nhood_niche_res=0.1"] + + # no NaNs, more cells in real niches than in 'not_a_niche' + assert niches.isna().sum() == 0 + assert len(niches[niches != "not_a_niche"]) > len(niches[niches == "not_a_niche"]) + for label in niches.unique(): + if label != "not_a_niche": + assert (niches == label).sum() >= 20 + + # profile shape + n_cats = len(adata.obs[GROUPS].cat.categories) + rel = cp.asnumpy( + _neighborhood_profile( + adata, + groups=GROUPS, + distance=1, + weights=None, + abs_nhood=False, + key=SPATIAL_CONNECTIVITIES_KEY, + ) + ) + abs_ = cp.asnumpy( + _neighborhood_profile( + adata, + groups=GROUPS, + distance=1, + weights=None, + abs_nhood=True, + key=SPATIAL_CONNECTIVITIES_KEY, + ) + ) + assert rel.shape == (adata.n_obs, n_cats) + assert abs_.shape == rel.shape + + # relative profile: each row sums to 1 when the cell has neighbors (all do here), + # so total sum == n_obs and the per-row max sum is 1. 
+ np.testing.assert_allclose(rel.sum(axis=1).sum(), adata.n_obs, atol=1e-4) + assert rel.sum(axis=1).max() == pytest.approx(1.0, abs=1e-5) + + # absolute profile: per-row sum equals that cell's degree in the spatial graph + deg = np.asarray((adata.obsp[SPATIAL_CONNECTIVITIES_KEY] != 0).sum(axis=1)).ravel() + np.testing.assert_array_equal(abs_.sum(axis=1).astype(int), deg) + + +def test_niche_calc_utag(adata): + """Adapted from squidpy: utag output shape, sparsity, sensitivity to graph.""" + calculate_niche(adata, flavor="utag", n_neighbors=10, resolutions=[0.1, 1.0]) + + niches_high = adata.obs["utag_niche_res=1.0"] + niches_low = adata.obs["utag_niche_res=0.1"] + assert niches_high.isna().sum() == 0 + # higher resolution → strictly more (or at least as many) clusters + assert niches_high.nunique() >= niches_low.nunique() + + # output shape matches X (returns cupy.ndarray for dense X) + feat = _utag_features(adata, SPATIAL_CONNECTIVITIES_KEY) + assert feat.shape == adata.X.shape + + # sparsity preserved when input X is sparse (returns cupyx sparse) + a_sparse = adata.copy() + a_sparse.X = sparse.csr_matrix(adata.X) + feat_sparse = _utag_features(a_sparse, SPATIAL_CONNECTIVITIES_KEY) + assert sparse_gpu.issparse(feat_sparse) + assert feat_sparse.shape == adata.X.shape + + # different spatial graph structure → different feature matrix + # (uniform value scaling is invisible after row-normalization, so we drop edges) + a2 = adata.copy() + G = a2.obsp[SPATIAL_CONNECTIVITIES_KEY].tolil() + G[0, :] = 0 + G[1, :] = 0 + G = G.tocsr() + G.eliminate_zeros() + a2.obsp[SPATIAL_CONNECTIVITIES_KEY] = G + feat2 = _utag_features(a2, SPATIAL_CONNECTIVITIES_KEY) + assert not cp.allclose(feat, feat2) + + +# -- additional rsc-specific tests -- + + +@pytest.mark.parametrize("flavor", ["neighborhood", "utag"]) +def test_basic_runs_inplace(adata, flavor): + kw = {"groups": GROUPS} if flavor == "neighborhood" else {} + out = calculate_niche(adata, flavor=flavor, n_neighbors=10, resolutions=0.5, **kw) + assert out is None + prefix = "nhood_niche" if flavor == "neighborhood" else "utag_niche" + col = f"{prefix}_res=0.5" + assert col in adata.obs.columns + assert isinstance(adata.obs[col].dtype, pd.CategoricalDtype) + + +def test_copy_returns_new_object(adata): + before = list(adata.obs.columns) + out = calculate_niche( + adata, + flavor="neighborhood", + groups=GROUPS, + n_neighbors=10, + resolutions=0.5, + copy=True, + ) + assert out is not None + assert "nhood_niche_res=0.5" in out.obs.columns + assert list(adata.obs.columns) == before + + +def test_multiple_resolutions(adata): + calculate_niche( + adata, + flavor="neighborhood", + groups=GROUPS, + n_neighbors=10, + resolutions=[0.3, 0.7], + ) + assert "nhood_niche_res=0.3" in adata.obs.columns + assert "nhood_niche_res=0.7" in adata.obs.columns + + +def test_n_hop_neighbors(adata): + calculate_niche( + adata, + flavor="neighborhood", + groups=GROUPS, + n_neighbors=10, + resolutions=0.5, + distance=3, + n_hop_weights=[1.0, 0.5, 0.25], + ) + assert "nhood_niche_res=0.5" in adata.obs.columns + + +def test_min_niche_size_relabels_all(adata): + """min_niche_size > n_obs should send every cell to 'not_a_niche'.""" + calculate_niche( + adata, + flavor="neighborhood", + groups=GROUPS, + n_neighbors=10, + resolutions=2.0, + min_niche_size=adata.n_obs + 1, + ) + labels = adata.obs["nhood_niche_res=2.0"].astype(str) + assert (labels == "not_a_niche").all() + + +def test_determinism_same_seed(adata): + a1, a2 = adata.copy(), adata.copy() + calculate_niche( + a1, + 
flavor="neighborhood", + groups=GROUPS, + n_neighbors=10, + resolutions=0.5, + random_state=42, + ) + calculate_niche( + a2, + flavor="neighborhood", + groups=GROUPS, + n_neighbors=10, + resolutions=0.5, + random_state=42, + ) + np.testing.assert_array_equal( + a1.obs["nhood_niche_res=0.5"].astype(str).values, + a2.obs["nhood_niche_res=0.5"].astype(str).values, + ) + + +def test_determinism_utag_same_seed(adata): + a1, a2 = adata.copy(), adata.copy() + calculate_niche(a1, flavor="utag", n_neighbors=10, resolutions=0.5, random_state=7) + calculate_niche(a2, flavor="utag", n_neighbors=10, resolutions=0.5, random_state=7) + np.testing.assert_array_equal( + a1.obs["utag_niche_res=0.5"].astype(str).values, + a2.obs["utag_niche_res=0.5"].astype(str).values, + ) + + +def test_unknown_flavor_raises(adata): + with pytest.raises(ValueError, match="Unknown flavor"): + calculate_niche(adata, flavor="bogus", n_neighbors=10, resolutions=0.5) + + +def test_neighborhood_requires_groups(adata): + with pytest.raises(ValueError, match="`groups` is required"): + calculate_niche(adata, flavor="neighborhood", n_neighbors=10, resolutions=0.5) + + +def test_groups_not_in_obs_raises(adata): + with pytest.raises(KeyError): + calculate_niche( + adata, + flavor="neighborhood", + groups="missing_col", + n_neighbors=10, + resolutions=0.5, + ) + + +def test_missing_connectivities_raises(adata): + del adata.obsp["spatial_connectivities"] + with pytest.raises(KeyError, match="spatial_connectivities"): + calculate_niche( + adata, flavor="neighborhood", groups=GROUPS, n_neighbors=10, resolutions=0.5 + ) + + +def test_invalid_distance_raises(adata): + with pytest.raises(ValueError, match="distance"): + calculate_niche( + adata, + flavor="neighborhood", + groups=GROUPS, + n_neighbors=10, + resolutions=0.5, + distance=0, + ) + + +def test_custom_connectivity_key(adata): + adata.obsp["my_graph"] = adata.obsp["spatial_connectivities"] + del adata.obsp["spatial_connectivities"] + calculate_niche( + adata, + flavor="neighborhood", + groups=GROUPS, + n_neighbors=10, + resolutions=0.5, + spatial_connectivities_key="my_graph", + ) + assert "nhood_niche_res=0.5" in adata.obs.columns + + +# -- cellcharter flavor tests -- + + +def test_cellcharter_basic(adata): + calculate_niche(adata, flavor="cellcharter", n_components=4) + assert "cellcharter_niche" in adata.obs.columns + col = adata.obs["cellcharter_niche"] + assert isinstance(col.dtype, pd.CategoricalDtype) + assert col.isna().sum() == 0 + assert col.nunique() <= 4 + + +def test_cellcharter_distance_zero(adata): + """distance=0 falls back to PCA + GMM on raw X (no shell aggregation).""" + calculate_niche(adata, flavor="cellcharter", n_components=3, distance=0) + assert adata.obs["cellcharter_niche"].nunique() <= 3 + + +def test_cellcharter_use_rep(adata): + """use_rep skips shell-aggregation and PCA; uses adata.obsm[key] directly.""" + rng = np.random.default_rng(0) + adata.obsm["X_test"] = rng.standard_normal((adata.n_obs, 10)).astype(np.float32) + calculate_niche(adata, flavor="cellcharter", n_components=4, use_rep="X_test") + assert "cellcharter_niche" in adata.obs.columns + + +def test_cellcharter_determinism(adata): + a1 = adata.copy() + a2 = adata.copy() + calculate_niche(a1, flavor="cellcharter", n_components=4, random_state=42) + calculate_niche(a2, flavor="cellcharter", n_components=4, random_state=42) + np.testing.assert_array_equal( + a1.obs["cellcharter_niche"].astype(str).values, + a2.obs["cellcharter_niche"].astype(str).values, + ) + + +def 
test_cellcharter_variance(adata): + """`aggregation="variance"` runs and produces a categorical column.""" + calculate_niche(adata, flavor="cellcharter", n_components=4, aggregation="variance") + assert "cellcharter_niche" in adata.obs.columns + assert isinstance(adata.obs["cellcharter_niche"].dtype, pd.CategoricalDtype) + + +def test_cellcharter_invalid_aggregation(adata): + with pytest.raises(ValueError, match="aggregation"): + calculate_niche( + adata, flavor="cellcharter", n_components=4, aggregation="bogus" + ) + + +def test_cellcharter_init_random_from_data(adata): + """`init="random_from_data"` is a valid escape hatch from kmeans init.""" + calculate_niche( + adata, + flavor="cellcharter", + n_components=4, + init="random_from_data", + random_state=0, + ) + assert "cellcharter_niche" in adata.obs.columns + + +def test_cellcharter_bad_n_components(adata): + with pytest.raises(ValueError, match="n_components"): + calculate_niche(adata, flavor="cellcharter", n_components=0) + + +def test_cellcharter_missing_use_rep(adata): + with pytest.raises(KeyError): + calculate_niche( + adata, flavor="cellcharter", n_components=4, use_rep="not_there" + ) + + +def test_cellcharter_use_rep_too_few_dims(adata): + adata.obsm["X_small"] = np.zeros((adata.n_obs, 3), dtype=np.float32) + with pytest.raises(ValueError, match="at least"): + calculate_niche(adata, flavor="cellcharter", n_components=10, use_rep="X_small") + + +def test_cellcharter_invalid_distance_negative(adata): + with pytest.raises(ValueError, match="distance"): + calculate_niche(adata, flavor="cellcharter", n_components=4, distance=-1) From 6d30c103355070891952d83eedf8845736148e15 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Wed, 29 Apr 2026 16:28:05 +0200 Subject: [PATCH 2/8] fix release note --- docs/release-notes/0.15.1.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/release-notes/0.15.1.md b/docs/release-notes/0.15.1.md index 8b7bd14a..4a730df5 100644 --- a/docs/release-notes/0.15.1.md +++ b/docs/release-notes/0.15.1.md @@ -2,5 +2,5 @@ ```{rubric} Features ``` -* Add `rsc.gr.calculate_niche` (GPU-accelerated spatial niche detection) with three flavors — `neighborhood` (cell-type frequency profile + Leiden), `utag` (sparse adjacency × expression + Leiden), and `cellcharter` (n-hop shell aggregation + PCA + GMM). Mirrors `squidpy.gr.calculate_niche` but runs end-to-end on GPU; ~17–316× faster on real Xenium WTA data {smaller}`S Dicks` -* Add a minimal full-covariance GMM (`squidpy_gpu._gmm.gmm_fit_predict`) used by the `cellcharter` flavor: cuML KMeans warm-start init, batched precision-Cholesky E-step, fused `cp.fuse` log-pdf {smaller}`S Dicks` +* Add `rsc.gr.calculate_niche` with flavors `neighborhood` , `utag` , and `cellcharter`. Mirrors `squidpy.gr.calculate_niche` but runs {smaller}`644` {smaller}`S Dicks` +* Add a minimal full-covariance GMM (`squidpy_gpu._gmm.gmm_fit_predict`) used by the `cellcharter` {pr}`644` {smaller}`S Dicks` From 1cf3f15c3288726808acf12f576c3cf48ab0f41d Mon Sep 17 00:00:00 2001 From: Intron7 Date: Wed, 29 Apr 2026 18:02:55 +0200 Subject: [PATCH 3/8] update release note --- docs/release-notes/0.15.1.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/release-notes/0.15.1.md b/docs/release-notes/0.15.1.md index 4a730df5..25e892c4 100644 --- a/docs/release-notes/0.15.1.md +++ b/docs/release-notes/0.15.1.md @@ -2,5 +2,5 @@ ```{rubric} Features ``` -* Add `rsc.gr.calculate_niche` with flavors `neighborhood` , `utag` , and `cellcharter`. 
Mirrors `squidpy.gr.calculate_niche` but runs {smaller}`644` {smaller}`S Dicks`
+* Add `rsc.gr.calculate_niche` with flavors `neighborhood`, `utag`, and `cellcharter`. Mirrors `squidpy.gr.calculate_niche` {pr}`644` {smaller}`S Dicks`
 * Add a minimal full-covariance GMM (`squidpy_gpu._gmm.gmm_fit_predict`) used by the `cellcharter` {pr}`644` {smaller}`S Dicks`

From d1985fc463aae4a179756155f6eb090ca10d8608 Mon Sep 17 00:00:00 2001
From: Intron7
Date: Thu, 30 Apr 2026 11:01:49 +0200
Subject: [PATCH 4/8] add cuda test backend

---
 CMakeLists.txt                               |   1 +
 src/rapids_singlecell/_cuda/__init__.py      |   1 +
 src/rapids_singlecell/_cuda/gmm/gmm.cu       | 178 ++++++++++
 .../_cuda/gmm/kernels_gmm.cuh                | 322 ++++++++++++++++++
 src/rapids_singlecell/squidpy_gpu/_gmm.py    | 254 ++++++++++++--
 src/rapids_singlecell/squidpy_gpu/_niche.py  |  10 +-
 tests/test_gmm.py                            |  58 +++-
 7 files changed, 800 insertions(+), 24 deletions(-)
 create mode 100644 src/rapids_singlecell/_cuda/gmm/gmm.cu
 create mode 100644 src/rapids_singlecell/_cuda/gmm/kernels_gmm.cuh

diff --git a/CMakeLists.txt b/CMakeLists.txt
index cacf9849..ce994001 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -71,6 +71,7 @@ if (RSC_BUILD_EXTENSIONS)
     add_nb_cuda_module(_qc_dask_cuda src/rapids_singlecell/_cuda/qc_dask/qc_kernels_dask.cu)
     add_nb_cuda_module(_bbknn_cuda src/rapids_singlecell/_cuda/bbknn/bbknn.cu)
     add_nb_cuda_module(_norm_cuda src/rapids_singlecell/_cuda/norm/norm.cu)
+    add_nb_cuda_module(_gmm_cuda src/rapids_singlecell/_cuda/gmm/gmm.cu)
     add_nb_cuda_module(_pr_cuda src/rapids_singlecell/_cuda/pr/pr.cu)
     add_nb_cuda_module(_nn_descent_cuda src/rapids_singlecell/_cuda/nn_descent/nn_descent.cu)
     add_nb_cuda_module(_aucell_cuda src/rapids_singlecell/_cuda/aucell/aucell.cu)
diff --git a/src/rapids_singlecell/_cuda/__init__.py b/src/rapids_singlecell/_cuda/__init__.py
index 35e82a0d..1ca2d22a 100644
--- a/src/rapids_singlecell/_cuda/__init__.py
+++ b/src/rapids_singlecell/_cuda/__init__.py
@@ -20,6 +20,7 @@
     "_bbknn_cuda",
     "_cooc_cuda",
     "_edistance_cuda",
+    "_gmm_cuda",
     "_harmony_clustering_cuda",
     "_harmony_colsum_cuda",
     "_harmony_correction_batched_cuda",
diff --git a/src/rapids_singlecell/_cuda/gmm/gmm.cu b/src/rapids_singlecell/_cuda/gmm/gmm.cu
new file mode 100644
index 00000000..fb9d9bd7
--- /dev/null
+++ b/src/rapids_singlecell/_cuda/gmm/gmm.cu
@@ -0,0 +1,178 @@
+#include <cuda_runtime.h>
+#include <nanobind/nanobind.h>
+
+#include <cstdint>
+#include <limits>
+#include <stdexcept>
+
+#include "../nb_types.h"
+
+#include "kernels_gmm.cuh"
+
+using namespace nb::literals;
+
+constexpr int TILE = 16;
+constexpr int E_STEP_BLOCK = 64;
+constexpr int NORMALIZE_BLOCK = 32;
+constexpr int M_STEP_CHUNK_ROWS = 1024;
+constexpr int M_STEP_CHUNKED_MIN_ROWS = 32768;
+
+static inline void cuda_check_runtime(cudaError_t err, const char* what) {
+  if (err != cudaSuccess) {
+    throw std::runtime_error(std::string(what) +
+                             " failed: " + cudaGetErrorString(err));
+  }
+}
+
+template <typename T>
+static inline void launch_e_step(const T* X, const T* weights, const T* means,
+                                 const T* prec_chol, const T* log_det_half,
+                                 int n, int d, int K, T* log_prob, T* resp,
+                                 T* ll_per_cell, cudaStream_t stream) {
+  if (n == 0 || d == 0 || K == 0) return;
+  if (d > 64) {
+    throw std::runtime_error(
+        "_gmm_cuda.e_step supports at most 64 features");
+  }
+  {
+    size_t shmem = (size_t)(d + d * d) * sizeof(T);
+    dim3 block(E_STEP_BLOCK);
+    dim3 grid(K, (n + E_STEP_BLOCK - 1) / E_STEP_BLOCK);
+    e_step_log_prob_kernel<<<grid, block, shmem, stream>>>(
+        X, weights, means, prec_chol, log_det_half, n, d, K, log_prob);
+    CUDA_CHECK_LAST_ERROR(e_step_log_prob_kernel);
+  }
+  {
+    dim3 block(NORMALIZE_BLOCK);
+    dim3 grid(n);
+    e_step_normalize_kernel
+        <<<grid, block, 0, stream>>>(log_prob, n, K, resp, ll_per_cell);
+    CUDA_CHECK_LAST_ERROR(e_step_normalize_kernel);
+  }
+}
+
+template <typename T>
+static inline void launch_m_step(const T* resp, const T* X, int n, int d, int K,
+                                 T reg_covar, T* weights, T* means,
+                                 T* covariances, T* workspace_N_k,
+                                 T* workspace_num, cudaStream_t stream) {
+  if (n == 0 || d == 0 || K == 0) return;
+  int tiles_d = (d + TILE - 1) / TILE;
+  T eps = std::numeric_limits<T>::epsilon();
+  if (n < M_STEP_CHUNKED_MIN_ROWS) {
+    dim3 block(TILE, TILE);
+    dim3 grid(K, tiles_d, tiles_d);
+    m_step_fused_kernel<TILE><<<grid, block, 0, stream>>>(
+        resp, X, n, d, K, workspace_N_k, workspace_num, covariances);
+    CUDA_CHECK_LAST_ERROR(m_step_fused_kernel);
+  } else {
+    size_t n_k_bytes = static_cast<size_t>(K) * sizeof(T);
+    size_t num_bytes = static_cast<size_t>(K) * d * sizeof(T);
+    size_t cov_bytes = static_cast<size_t>(K) * d * d * sizeof(T);
+    cuda_check_runtime(cudaMemsetAsync(workspace_N_k, 0, n_k_bytes, stream),
+                       "cudaMemsetAsync(workspace_N_k)");
+    cuda_check_runtime(cudaMemsetAsync(workspace_num, 0, num_bytes, stream),
+                       "cudaMemsetAsync(workspace_num)");
+    cuda_check_runtime(cudaMemsetAsync(covariances, 0, cov_bytes, stream),
+                       "cudaMemsetAsync(covariances)");
+    int chunks = (n + M_STEP_CHUNK_ROWS - 1) / M_STEP_CHUNK_ROWS;
+    dim3 block(TILE, TILE);
+    dim3 grid(K, tiles_d * tiles_d, chunks);
+    m_step_chunked_atomic_kernel<TILE><<<grid, block, 0, stream>>>(
+        resp, X, n, d, K, tiles_d, M_STEP_CHUNK_ROWS, workspace_N_k,
+        workspace_num, covariances);
+    CUDA_CHECK_LAST_ERROR(m_step_chunked_atomic_kernel);
+  }
+  {
+    int threads = 256;
+    dim3 block(threads);
+    dim3 grid(K);
+    m_step_finalize_kernel<<<grid, block, 0, stream>>>(
+        workspace_N_k, workspace_num, covariances, weights, means,
+        reg_covar, eps, n, d, K);
+    CUDA_CHECK_LAST_ERROR(m_step_finalize_kernel);
+  }
+}
+
+template <typename T>
+void register_bindings(nb::module_& m) {
+  m.def(
+      "e_step",
+      [](gpu_array_c<const float> X,
+         gpu_array_c<const float> weights,
+         gpu_array_c<const float> means,
+         gpu_array_c<const float> prec_chol,
+         gpu_array_c<const float> log_det_half,
+         gpu_array_c<float> log_prob, gpu_array_c<float> resp,
+         gpu_array_c<float> ll_per_cell, int n, int d, int K,
+         std::uintptr_t stream) {
+        launch_e_step<float>(X.data(), weights.data(), means.data(),
+                             prec_chol.data(), log_det_half.data(), n, d, K,
+                             log_prob.data(), resp.data(),
+                             ll_per_cell.data(), (cudaStream_t)stream);
+      },
+      "X"_a, "weights"_a, "means"_a, "prec_chol"_a, "log_det_half"_a,
+      "log_prob"_a, "resp"_a, "ll_per_cell"_a, nb::kw_only(), "n"_a, "d"_a,
+      "K"_a, "stream"_a = 0);
+
+  m.def(
+      "e_step",
+      [](gpu_array_c<const double> X,
+         gpu_array_c<const double> weights,
+         gpu_array_c<const double> means,
+         gpu_array_c<const double> prec_chol,
+         gpu_array_c<const double> log_det_half,
+         gpu_array_c<double> log_prob,
+         gpu_array_c<double> resp,
+         gpu_array_c<double> ll_per_cell, int n, int d, int K,
+         std::uintptr_t stream) {
+        launch_e_step<double>(X.data(), weights.data(), means.data(),
+                              prec_chol.data(), log_det_half.data(), n, d,
+                              K, log_prob.data(), resp.data(),
+                              ll_per_cell.data(), (cudaStream_t)stream);
+      },
+      "X"_a, "weights"_a, "means"_a, "prec_chol"_a, "log_det_half"_a,
+      "log_prob"_a, "resp"_a, "ll_per_cell"_a, nb::kw_only(), "n"_a, "d"_a,
+      "K"_a, "stream"_a = 0);
+
+  m.def(
+      "m_step",
+      [](gpu_array_c<const float> resp,
+         gpu_array_c<const float> X,
+         gpu_array_c<float> weights, gpu_array_c<float> means,
+         gpu_array_c<float> covariances,
+         gpu_array_c<float> N_k_workspace,
+         gpu_array_c<float> num_workspace, int n, int d, int K,
+         float reg_covar, std::uintptr_t stream) {
+        launch_m_step<float>(resp.data(), X.data(), n, d, K, reg_covar,
+                             weights.data(), means.data(),
+                             covariances.data(), N_k_workspace.data(),
+                             num_workspace.data(), (cudaStream_t)stream);
+      },
+      "resp"_a, "X"_a, "weights"_a, "means"_a, "covariances"_a,
+      "N_k_workspace"_a, "num_workspace"_a, nb::kw_only(), "n"_a, "d"_a,
+      "K"_a, "reg_covar"_a, "stream"_a = 0);
+
+  m.def(
+      "m_step",
+      [](gpu_array_c<const double> resp,
+         gpu_array_c<const double> X,
+         gpu_array_c<double> weights,
+         gpu_array_c<double> means,
+         gpu_array_c<double> covariances,
+         gpu_array_c<double> N_k_workspace,
+         gpu_array_c<double> num_workspace, int n, int d, int K,
+         double reg_covar, std::uintptr_t stream) {
+        launch_m_step<double>(resp.data(), X.data(), n, d, K, reg_covar,
+                              weights.data(), means.data(),
+                              covariances.data(), N_k_workspace.data(),
+                              num_workspace.data(), (cudaStream_t)stream);
+      },
+      "resp"_a, "X"_a, "weights"_a, "means"_a, "covariances"_a,
+      "N_k_workspace"_a, "num_workspace"_a, nb::kw_only(), "n"_a, "d"_a,
+      "K"_a, "reg_covar"_a, "stream"_a = 0);
+}
+
+NB_MODULE(_gmm_cuda, m) {
+  REGISTER_GPU_BINDINGS(register_bindings, m);
+}
diff --git a/src/rapids_singlecell/_cuda/gmm/kernels_gmm.cuh b/src/rapids_singlecell/_cuda/gmm/kernels_gmm.cuh
new file mode 100644
index 00000000..18bb1227
--- /dev/null
+++ b/src/rapids_singlecell/_cuda/gmm/kernels_gmm.cuh
@@ -0,0 +1,322 @@
+#pragma once
+
+#include <math_constants.h>
+
+// ----------------------------------------------------------------------------
+// Per-(n, k) E-step log-probability.
+//
+// Each block (k, n_chunk) caches means[k] and prec_chol[k] in shared memory,
+// then each thread computes the Mahalanobis term for one cell against the
+// cached component. Output is row-major log_prob[n, k] with the log-weight
+// already folded in:
+//
+//   y[j]           = Σ_d (X[n, d] − means[k, d]) · prec_chol[k, d, j]
+//   mahal[n, k]    = Σ_j y[j]²
+//   log_prob[n, k] = −0.5·d·log(2π) + log_det_half[k] − 0.5·mahal +
+//                    log(weights[k])
+//
+// A separate normalize kernel does the per-row logsumexp.
+// ----------------------------------------------------------------------------
+
+constexpr float LOG_2PI_F = 1.8378770664093453f;
+constexpr double LOG_2PI_D = 1.8378770664093453;
+
+template <typename T>
+__device__ __forceinline__ T log_2pi_const();
+template <>
+__device__ __forceinline__ float log_2pi_const<float>() {
+  return LOG_2PI_F;
+}
+template <>
+__device__ __forceinline__ double log_2pi_const<double>() {
+  return LOG_2PI_D;
+}
+
+template <typename T>
+__global__ void e_step_log_prob_kernel(
+    const T* __restrict__ X,             // (n, d) row-major
+    const T* __restrict__ weights,       // (K,)
+    const T* __restrict__ means,         // (K, d)
+    const T* __restrict__ prec_chol,     // (K, d, d) row-major; cov_inv =
+                                         // chol·cholᵀ
+    const T* __restrict__ log_det_half,  // (K,)
+    int n, int d, int K,
+    T* __restrict__ log_prob             // (n, K)
+) {
+  int k = blockIdx.x;
+  int n_idx = blockIdx.y * blockDim.x + threadIdx.x;
+  int tid = threadIdx.x;
+
+  extern __shared__ unsigned char smem_raw[];
+  T* sh_mean = reinterpret_cast<T*>(smem_raw);
+  T* sh_pc = sh_mean + d;
+
+  // Cooperatively load means[k] and prec_chol[k] into shared memory.
+  for (int i = tid; i < d; i += blockDim.x) sh_mean[i] = means[k * d + i];
+  int pc_size = d * d;
+  for (int i = tid; i < pc_size; i += blockDim.x)
+    sh_pc[i] = prec_chol[k * pc_size + i];
+
+  T log_w = log(weights[k]);
+  T ldh = log_det_half[k];
+  T const_term = T(-0.5) * T(d) * log_2pi_const<T>() + ldh + log_w;
+
+  __syncthreads();
+
+  if (n_idx >= n) return;
+
+  // Compute mahal = || (X[n] - μ_k) · prec_chol[k] ||²
+  T centered_vals[64];
+  for (int dd = 0; dd < d; ++dd)
+    centered_vals[dd] = X[n_idx * d + dd] - sh_mean[dd];
+
+  T mahal = T(0);
+  for (int j = 0; j < d; ++j) {
+    T y = T(0);
+    for (int dd = 0; dd < d; ++dd) {
+      y += centered_vals[dd] * sh_pc[dd * d + j];
+    }
+    mahal += y * y;
+  }
+  log_prob[n_idx * K + k] = const_term - T(0.5) * mahal;
+}
+
+// ----------------------------------------------------------------------------
+// Per-cell logsumexp normalize: resp[n, k] = exp(log_prob[n, k] − logΣ_k).
+// Also writes per-cell log-likelihood (= logΣ_k) into ll_per_cell for later
+// reduction. One block per cell; threads stride across K.
+// ----------------------------------------------------------------------------
+
+template <typename T>
+__global__ void e_step_normalize_kernel(
+    const T* __restrict__ log_prob,  // (n, K)
+    int n, int K,
+    T* __restrict__ resp,            // (n, K)
+    T* __restrict__ ll_per_cell      // (n,)
+) {
+  int n_idx = blockIdx.x;
+  if (n_idx >= n) return;
+  int tid = threadIdx.x;
+
+  __shared__ T sh_max;
+  __shared__ T sh_sum;
+
+  // pass 1: max over K
+  T local_max = -CUDART_INF_F;
+  for (int k = tid; k < K; k += blockDim.x) {
+    T v = log_prob[n_idx * K + k];
+    if (v > local_max) local_max = v;
+  }
+  // warp + block reduce max
+  for (int off = 16; off > 0; off >>= 1) {
+    T other = __shfl_down_sync(0xffffffff, local_max, off);
+    if (other > local_max) local_max = other;
+  }
+  if (tid == 0) sh_max = local_max;
+  __syncthreads();
+  T mx = sh_max;
+
+  // pass 2: sum exp(log_prob - max)
+  T local_sum = T(0);
+  for (int k = tid; k < K; k += blockDim.x) {
+    local_sum += exp(log_prob[n_idx * K + k] - mx);
+  }
+  for (int off = 16; off > 0; off >>= 1)
+    local_sum += __shfl_down_sync(0xffffffff, local_sum, off);
+  if (tid == 0) {
+    sh_sum = local_sum;
+    T log_total = log(local_sum) + mx;
+    ll_per_cell[n_idx] = log_total;
+  }
+  __syncthreads();
+  T log_total = log(sh_sum) + mx;
+
+  // pass 3: write normalized responsibilities
+  for (int k = tid; k < K; k += blockDim.x) {
+    resp[n_idx * K + k] = exp(log_prob[n_idx * K + k] - log_total);
+  }
+}
+
+// ----------------------------------------------------------------------------
+// Direct M-step kernel for smaller inputs.
+// ----------------------------------------------------------------------------
+
+template <int TILE, typename T>
+__global__ void m_step_fused_kernel(const T* __restrict__ resp,
+                                    const T* __restrict__ X, int n, int d,
+                                    int K, T* __restrict__ N_k,
+                                    T* __restrict__ num, T* __restrict__ sm) {
+  int k = blockIdx.x;
+  int i_tile = blockIdx.y * TILE;
+  int j_tile = blockIdx.z * TILE;
+  int ti = threadIdx.y;
+  int tj = threadIdx.x;
+  int i = i_tile + ti;
+  int j = j_tile + tj;
+  bool valid = (i < d) && (j < d);
+  bool emit_cov = j_tile >= i_tile;
+
+  bool emit_num = (j_tile == 0);
+  bool emit_Nk = (i_tile == 0) && (j_tile == 0);
+
+  if (!emit_cov && !emit_num) return;
+
+  __shared__ T Xi[TILE][TILE];
+  __shared__ T Xj[TILE][TILE];
+  __shared__ T r[TILE];
+
+  T accum_sm = T(0);
+  T accum_num = T(0);
+  T block_Nk = T(0);
+
+  for (int n_base = 0; n_base < n; n_base += TILE) {
+    int n_in = min(TILE, n - n_base);
+
+    T xi_v = T(0), xj_v = T(0);
+    if (ti < n_in && (i_tile + tj) < d)
+      xi_v = X[(n_base + ti) * d + i_tile + tj];
+    if (emit_cov && ti < n_in && (j_tile + tj) < d)
+      xj_v = X[(n_base + ti) * d + j_tile + tj];
+    Xi[ti][tj] = xi_v;
+    Xj[ti][tj] = xj_v;
+
+    if (ti == 0) {
+      r[tj] = (tj < n_in) ? resp[(n_base + tj) * K + k] : T(0);
+    }
+    __syncthreads();
+
+    if (valid) {
+#pragma unroll
+      for (int t = 0; t < TILE; ++t) {
+        T rt = r[t];
+        T xi_t = Xi[t][ti];
+        if (emit_cov) {
+          T xj_t = Xj[t][tj];
+          accum_sm += rt * xi_t * xj_t;
+        }
+        if (emit_num && tj == 0) accum_num += rt * xi_t;
+      }
+    }
+
+    if (emit_Nk && ti == 0 && tj == 0) {
+      T chunk_sum = T(0);
+#pragma unroll
+      for (int t = 0; t < TILE; ++t) chunk_sum += r[t];
+      block_Nk += chunk_sum;
+    }
+    __syncthreads();
+  }
+
+  if (emit_cov && valid) sm[k * d * d + i * d + j] = accum_sm;
+  if (emit_num && tj == 0 && i < d) num[k * d + i] = accum_num;
+  if (emit_Nk && ti == 0 && tj == 0) N_k[k] = block_Nk;
+}
+
+template <int TILE, typename T>
+__global__ void m_step_chunked_atomic_kernel(
+    const T* __restrict__ resp, const T* __restrict__ X, int n, int d, int K,
+    int tiles_d, int chunk_size, T* __restrict__ N_k, T* __restrict__ num,
+    T* __restrict__ sm) {
+  int k = blockIdx.x;
+  int tile_pair = blockIdx.y;
+  int chunk = blockIdx.z;
+  int i_tile = (tile_pair / tiles_d) * TILE;
+  int j_tile = (tile_pair % tiles_d) * TILE;
+  int chunk_start = chunk * chunk_size;
+  int chunk_end = min(n, chunk_start + chunk_size);
+
+  int ti = threadIdx.y;
+  int tj = threadIdx.x;
+  int i = i_tile + ti;
+  int j = j_tile + tj;
+  bool valid = (i < d) && (j < d);
+  bool emit_cov = j_tile >= i_tile;
+
+  bool emit_num = (j_tile == 0);
+  bool emit_Nk = (i_tile == 0) && (j_tile == 0);
+
+  if (!emit_cov && !emit_num) return;
+
+  __shared__ T Xi[TILE][TILE];
+  __shared__ T Xj[TILE][TILE];
+  __shared__ T r[TILE];
+
+  T accum_sm = T(0);
+  T accum_num = T(0);
+  T block_Nk = T(0);
+
+  for (int n_base = chunk_start; n_base < chunk_end; n_base += TILE) {
+    int n_in = min(TILE, chunk_end - n_base);
+
+    T xi_v = T(0), xj_v = T(0);
+    if (ti < n_in && (i_tile + tj) < d)
+      xi_v = X[(n_base + ti) * d + i_tile + tj];
+    if (emit_cov && ti < n_in && (j_tile + tj) < d)
+      xj_v = X[(n_base + ti) * d + j_tile + tj];
+    Xi[ti][tj] = xi_v;
+    Xj[ti][tj] = xj_v;
+
+    if (ti == 0) {
+      r[tj] = (tj < n_in) ? resp[(n_base + tj) * K + k] : T(0);
+    }
+    __syncthreads();
+
+    if (valid) {
+#pragma unroll
+      for (int t = 0; t < TILE; ++t) {
+        T rt = r[t];
+        T xi_t = Xi[t][ti];
+        if (emit_cov) {
+          T xj_t = Xj[t][tj];
+          accum_sm += rt * xi_t * xj_t;
+        }
+        if (emit_num && tj == 0) accum_num += rt * xi_t;
+      }
+    }
+
+    if (emit_Nk && ti == 0 && tj == 0) {
+      T chunk_sum = T(0);
+#pragma unroll
+      for (int t = 0; t < TILE; ++t) chunk_sum += r[t];
+      block_Nk += chunk_sum;
+    }
+    __syncthreads();
+  }
+
+  if (emit_cov && valid) atomicAdd(&sm[k * d * d + i * d + j], accum_sm);
+  if (emit_num && tj == 0 && i < d) atomicAdd(&num[k * d + i], accum_num);
+  if (emit_Nk && ti == 0 && tj == 0) atomicAdd(&N_k[k], block_Nk);
+}
+
+template <typename T>
+__global__ void m_step_finalize_kernel(const T* __restrict__ N_k,
+                                       const T* __restrict__ num,
+                                       T* __restrict__ sm_to_cov,
+                                       T* __restrict__ weights,
+                                       T* __restrict__ means, T reg_covar,
+                                       T eps, int n, int d, int K) {
+  int k = blockIdx.x;
+  int tid = threadIdx.x;
+
+  T Nk = N_k[k] + T(10) * eps;
+  T inv_Nk = T(1) / Nk;
+
+  if (tid == 0) weights[k] = Nk / T(n);
+
+  for (int i = tid; i < d; i += blockDim.x)
+    means[k * d + i] = num[k * d + i] * inv_Nk;
+  __syncthreads();
+
+  int total = d * d;
+  for (int idx = tid; idx < total; idx += blockDim.x) {
+    int i = idx / d;
+    int j = idx % d;
+    if (i > j) continue;
+    T mi = means[k * d + i];
+    T mj = means[k * d + j];
+    T v = sm_to_cov[k * d * d + idx] * inv_Nk - mi * mj;
+    if (i == j) v += reg_covar;
+    sm_to_cov[k * d * d + idx] = v;
+    if (i != j) sm_to_cov[k * d * d + j * d + i] = v;
+  }
+}
diff --git a/src/rapids_singlecell/squidpy_gpu/_gmm.py b/src/rapids_singlecell/squidpy_gpu/_gmm.py
index bb8779ea..4cbe2ba0 100644
--- a/src/rapids_singlecell/squidpy_gpu/_gmm.py
+++ b/src/rapids_singlecell/squidpy_gpu/_gmm.py
@@ -3,19 +3,23 @@
 Mirrors :class:`sklearn.mixture.GaussianMixture` with
 ``covariance_type="full"``. Two init strategies are exposed: ``"kmeans"`` (default,
 cuML KMeans warm-start) and ``"random_from_data"`` (sklearn-equivalent for parity
-testing). A future cuML-backed or fused-CUDA GMM can replace this.
+testing). The default ``"auto"`` backend uses a nanobind/CUDA EM path when the
+compiled extension is available, with a CuPy fallback for environments without
+compiled extensions.
 
 Implementation notes
 --------------------
-- E-step uses a cached precision Cholesky (computed once per M-step) and a
-  per-component dense matmul. This avoids the per-component triangular solve.
+- CUDA E-step uses a cached precision Cholesky (computed once per M-step) and a
+  custom Mahalanobis + responsibility kernel for the common <=64-PC case.
+- CUDA M-step reuses preallocated workspaces and computes full covariances from
+  upper-triangle tiled reductions.
 - ``_precision_cholesky`` is a batched ``inv``+``cholesky`` — no Python K-loop.
-- ``cupyx.scipy.special.logsumexp`` for the stable softmax over components.
 - Convergence is the change in mean log-likelihood.
 """
 
 from __future__ import annotations
 
+import importlib
 from typing import Literal
 
 import cupy as cp
@@ -34,6 +38,8 @@ def gmm_fit_predict(
     tol: float = 1e-3,
     reg_covar: float = 1e-6,
     init: Literal["kmeans", "random_from_data"] = "kmeans",
+    backend: Literal["auto", "cupy", "cuda"] = "auto",
+    kmeans_n_init: int = 1,
 ) -> cp.ndarray:
     """Fit a full-covariance GMM and return cluster labels.
 
@@ -54,36 +60,191 @@ def gmm_fit_predict(
     init
         ``"kmeans"`` (default) uses cuML KMeans for warm-start; usually much
         better than ``"random_from_data"``, which mirrors sklearn for parity.
+ backend + ``"auto"`` (default) uses the nanobind/CUDA EM backend when the compiled + extension is available, otherwise falls back to CuPy. ``"cupy"`` uses + CuPy + cuBLAS for the covariance update. ``"cuda"`` forces the custom + CUDA kernels for the E-step and M-step reductions. + kmeans_n_init + Number of cuML KMeans restarts for ``init="kmeans"``. The default ``1`` + matches sklearn's GaussianMixture default and keeps cellcharter fast; + increase this for difficult or noisy initialization landscapes. """ + if backend not in {"auto", "cupy", "cuda"}: + raise ValueError("backend must be one of 'auto', 'cupy', or 'cuda'.") + if int(kmeans_n_init) < 1: + raise ValueError("kmeans_n_init must be >= 1.") + X = cp.ascontiguousarray(X) - n_samples, _ = X.shape K = int(n_components) - - weights, means, covariances = _initialize(X, K, random_state, reg_covar, init) + backend = _resolve_backend(backend) + + if backend == "cuda" and X.shape[1] <= 64: + return _fit_predict_cuda( + X, + K, + random_state=random_state, + max_iter=max_iter, + tol=tol, + reg_covar=reg_covar, + init=init, + kmeans_n_init=int(kmeans_n_init), + ) + + weights, means, covariances = _initialize( + X, K, random_state, reg_covar, init, int(kmeans_n_init) + ) prec_chol, log_det_prec_half = _precision_cholesky(covariances) prev_ll = -cp.inf converged = False for _ in range(max_iter): - log_resp, ll = _e_step(X, weights, means, prec_chol, log_det_prec_half) + resp, ll = _e_step( + X, weights, means, prec_chol, log_det_prec_half, backend=backend + ) if cp.abs(ll - prev_ll) < tol: converged = True break prev_ll = ll - weights, means, covariances = _m_step(X, log_resp, reg_covar) + weights, means, covariances = _m_step(X, resp, reg_covar, backend=backend) + prec_chol, log_det_prec_half = _precision_cholesky(covariances) + + if not converged: + resp, _ = _e_step( + X, weights, means, prec_chol, log_det_prec_half, backend=backend + ) + return resp.argmax(axis=1).astype(cp.int32) + + +def _fit_predict_cuda( + X: cp.ndarray, + K: int, + *, + random_state: int, + max_iter: int, + tol: float, + reg_covar: float, + init: str, + kmeans_n_init: int, +) -> cp.ndarray: + weights, means, covariances = _initialize( + X, K, random_state, reg_covar, init, kmeans_n_init + ) + weights = cp.ascontiguousarray(weights) + means = cp.ascontiguousarray(means) + covariances = cp.ascontiguousarray(covariances) + workspace = _GMMCudaWorkspace(X, K) + prec_chol, log_det_prec_half = _precision_cholesky(covariances) + + prev_ll = -np.inf + converged = False + for _ in range(max_iter): + resp, ll = workspace.e_step(weights, means, prec_chol, log_det_prec_half) + ll_value = float(ll) + if abs(ll_value - prev_ll) < tol: + converged = True + break + prev_ll = ll_value + workspace.m_step(resp, weights, means, covariances, reg_covar) prec_chol, log_det_prec_half = _precision_cholesky(covariances) if not converged: - log_resp, _ = _e_step(X, weights, means, prec_chol, log_det_prec_half) - return log_resp.argmax(axis=1).astype(cp.int32) + resp, _ = workspace.e_step(weights, means, prec_chol, log_det_prec_half) + return resp.argmax(axis=1).astype(cp.int32) + + +class _GMMCudaWorkspace: + def __init__(self, X: cp.ndarray, K: int): + n, d = X.shape + self._gc = _get_gmm_cuda() + self.X = X + self.n = int(n) + self.d = int(d) + self.K = int(K) + self.stream = cp.cuda.get_current_stream().ptr + self.log_prob = cp.empty((n, K), dtype=X.dtype) + self.resp = cp.empty((n, K), dtype=X.dtype) + self.ll_per_cell = cp.empty(n, dtype=X.dtype) + self.N_k = cp.empty(K, dtype=X.dtype) + 
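+        # N_k (K,) and num (K, d) are M-step reduction workspaces, allocated
+        # once here and reused on every EM iteration.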
self.num = cp.empty((K, d), dtype=X.dtype) + + def e_step( + self, + weights: cp.ndarray, + means: cp.ndarray, + prec_chol: cp.ndarray, + log_det_half: cp.ndarray, + ) -> tuple[cp.ndarray, cp.ndarray]: + self._gc.e_step( + self.X, + cp.ascontiguousarray(weights.astype(self.X.dtype, copy=False)), + cp.ascontiguousarray(means), + cp.ascontiguousarray(prec_chol), + cp.ascontiguousarray(log_det_half.astype(self.X.dtype, copy=False)), + self.log_prob, + self.resp, + self.ll_per_cell, + n=self.n, + d=self.d, + K=self.K, + stream=self.stream, + ) + return self.resp, self.ll_per_cell.mean() + + def m_step( + self, + resp: cp.ndarray, + weights: cp.ndarray, + means: cp.ndarray, + covariances: cp.ndarray, + reg_covar: float, + ) -> None: + self._gc.m_step( + cp.ascontiguousarray(resp), + self.X, + weights, + means, + covariances, + self.N_k, + self.num, + n=self.n, + d=self.d, + K=self.K, + reg_covar=float(reg_covar), + stream=self.stream, + ) + + +def _resolve_backend(backend: str) -> str: + if backend == "cupy": + return backend + try: + _get_gmm_cuda() + except ImportError: + if backend == "cuda": + raise + return "cupy" + return "cuda" + + +def _get_gmm_cuda(): + try: + return importlib.import_module("rapids_singlecell._cuda._gmm_cuda") + except ImportError as err: + raise ImportError( + "The _gmm_cuda extension is not available. Build rapids-singlecell " + "with CUDA extensions or use backend='cupy'." + ) from err def _initialize( X: cp.ndarray, K: int, + *, random_state: int, reg_covar: float, init: str, + kmeans_n_init: int, ) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray]: n, d = X.shape eye_reg = reg_covar * cp.eye(d, dtype=X.dtype) @@ -102,10 +263,12 @@ def _initialize( from cuml.cluster import KMeans - # Match sklearn's n_init=10 inside its GaussianMixture's kmeans init. - # n_init=1 was empirically prone to degenerate inits on structureless data, - # which then collapsed EM to a single component. 
- km = KMeans(n_clusters=K, random_state=random_state, n_init=10, max_iter=100) + km = KMeans( + n_clusters=K, + random_state=random_state, + n_init=int(kmeans_n_init), + max_iter=100, + ) km.fit(X) labels = cp.asarray(km.labels_) means = cp.asarray(km.cluster_centers_, dtype=X.dtype) @@ -150,8 +313,32 @@ def _e_step( means: cp.ndarray, prec_chol: cp.ndarray, log_det_half: cp.ndarray, + *, + backend: str = "cupy", ) -> tuple[cp.ndarray, cp.ndarray]: n, d = X.shape + if backend == "cuda" and d <= 64: + _gc = _get_gmm_cuda() + + K = means.shape[0] + log_prob = cp.empty((n, K), dtype=X.dtype) + resp = cp.empty((n, K), dtype=X.dtype) + ll_per_cell = cp.empty(n, dtype=X.dtype) + _gc.e_step( + cp.ascontiguousarray(X), + cp.ascontiguousarray(weights.astype(X.dtype)), + cp.ascontiguousarray(means), + cp.ascontiguousarray(prec_chol), + cp.ascontiguousarray(log_det_half.astype(X.dtype)), + log_prob, + resp, + ll_per_cell, + n=int(n), + d=int(d), + K=int(K), + stream=cp.cuda.get_current_stream().ptr, + ) + return resp, ll_per_cell.mean() K = means.shape[0] log_prob = cp.empty((n, K), dtype=X.dtype) @@ -164,24 +351,47 @@ def _e_step( log_prob = log_prob + cp.log(weights) log_total = logsumexp(log_prob, axis=1, keepdims=True) - log_resp = log_prob - log_total - return log_resp, log_total.mean() + resp = cp.exp(log_prob - log_total) + return resp, log_total.mean() def _m_step( X: cp.ndarray, - log_resp: cp.ndarray, + resp: cp.ndarray, reg_covar: float, + *, + backend: str = "cupy", ) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray]: n, d = X.shape - K = log_resp.shape[1] + K = resp.shape[1] + + if backend == "cuda": + _gc = _get_gmm_cuda() + + weights = cp.empty(K, dtype=X.dtype) + means = cp.empty((K, d), dtype=X.dtype) + covariances = cp.empty((K, d, d), dtype=X.dtype) + N_k_ws = cp.empty(K, dtype=X.dtype) + num_ws = cp.empty((K, d), dtype=X.dtype) + _gc.m_step( + cp.ascontiguousarray(resp), + cp.ascontiguousarray(X), + weights, + means, + covariances, + N_k_ws, + num_ws, + n=int(n), + d=int(d), + K=int(K), + reg_covar=float(reg_covar), + stream=cp.cuda.get_current_stream().ptr, + ) + return weights, means, covariances - resp = cp.exp(log_resp) N_k = resp.sum(axis=0) + 10.0 * cp.finfo(X.dtype).eps # (K,) - weights = N_k / n means = (resp.T @ X) / N_k[:, None] - covariances = cp.empty((K, d, d), dtype=X.dtype) eye_reg = reg_covar * cp.eye(d, dtype=X.dtype) for k in range(K): diff --git a/src/rapids_singlecell/squidpy_gpu/_niche.py b/src/rapids_singlecell/squidpy_gpu/_niche.py index 12e95551..b73b43d0 100644 --- a/src/rapids_singlecell/squidpy_gpu/_niche.py +++ b/src/rapids_singlecell/squidpy_gpu/_niche.py @@ -33,6 +33,7 @@ def calculate_niche( n_components: int = 10, use_rep: str | None = None, init: Literal["kmeans", "random_from_data"] = "kmeans", + kmeans_n_init: int = 1, spatial_connectivities_key: str = "spatial_connectivities", random_state: int = 0, copy: bool = False, @@ -84,9 +85,13 @@ def calculate_niche( if provided, the first ``n_components`` columns are used and the shell-aggregation + PCA step is skipped. init - GMM initialization for ``flavor="cellcharter"``. ``"kmeans"`` (default, robust) + GMM initialization for ``flavor="cellcharter"``. ``"kmeans"`` (default) or ``"random_from_data"`` (sklearn-parity). Use the latter if kmeans init lands on a degenerate component on noisy / low-signal data. + kmeans_n_init + Number of cuML KMeans restarts for ``flavor="cellcharter", init="kmeans"``. + The default ``1`` follows sklearn's GaussianMixture default and keeps the + CUDA GMM path fast. 
spatial_connectivities_key Key in ``adata.obsp`` with the spatial connectivity matrix. random_state @@ -120,6 +125,7 @@ def calculate_niche( n_components=n_components, use_rep=use_rep, init=init, + kmeans_n_init=kmeans_n_init, random_state=random_state, key=spatial_connectivities_key, ) @@ -264,6 +270,7 @@ def _run_cellcharter( n_components: int, use_rep: str | None, init: str, + kmeans_n_init: int, random_state: int, key: str, ) -> None: @@ -299,6 +306,7 @@ def _run_cellcharter( n_components=n_components, random_state=random_state, init=init, + kmeans_n_init=kmeans_n_init, ) adata.obs["cellcharter_niche"] = pd.Categorical(cp.asnumpy(labels).astype(str)) diff --git a/tests/test_gmm.py b/tests/test_gmm.py index 8fbdce5e..45c2fa1d 100644 --- a/tests/test_gmm.py +++ b/tests/test_gmm.py @@ -5,7 +5,13 @@ import pytest from sklearn.metrics import adjusted_rand_score as ARI -from rapids_singlecell.squidpy_gpu._gmm import gmm_fit_predict +from rapids_singlecell.squidpy_gpu._gmm import ( + _e_step, + _m_step, + _precision_cholesky, + _resolve_backend, + gmm_fit_predict, +) def _well_separated(n_per: int, K: int, d: int, sep: float, seed: int): @@ -66,6 +72,26 @@ def test_invalid_init_raises(): gmm_fit_predict(X, n_components=3, init="bogus") +def test_invalid_backend_raises(): + X = cp.asarray(np.zeros((100, 5), dtype=np.float32)) + with pytest.raises(ValueError, match="backend"): + gmm_fit_predict(X, n_components=3, backend="bogus") + + +def test_invalid_kmeans_n_init_raises(): + X = cp.asarray(np.zeros((100, 5), dtype=np.float32)) + with pytest.raises(ValueError, match="kmeans_n_init"): + gmm_fit_predict(X, n_components=3, kmeans_n_init=0) + + +def test_auto_backend_uses_cuda_when_available(): + if _resolve_backend("auto") != "cuda": + pytest.skip("_gmm_cuda extension is not available") + + assert _resolve_backend("auto") == "cuda" + assert _resolve_backend("cuda") == "cuda" + + def test_n_components_one_returns_single_label(): rng = np.random.default_rng(0) X = cp.asarray(rng.standard_normal((200, 4)).astype(np.float32)) @@ -78,3 +104,33 @@ def test_float64_input_accepted(): X = cp.asarray(rng.standard_normal((300, 6)).astype(np.float64)) labels = gmm_fit_predict(X, n_components=3, random_state=0) assert labels.shape == (300,) + + +def test_cuda_backend_matches_cupy_steps(): + if _resolve_backend("auto") != "cuda": + pytest.skip("_gmm_cuda extension is not available") + + rng = cp.random.RandomState(0) + n, d, K = 40_000, 6, 3 # large enough to exercise the chunked M-step path + X = rng.standard_normal((n, d), dtype=cp.float32) + logits = rng.standard_normal((n, K), dtype=cp.float32) + resp = cp.exp(logits - cp.log(cp.exp(logits).sum(axis=1, keepdims=True))) + + w_c, m_c, cov_c = _m_step(X, resp, 1e-6, backend="cupy") + w_g, m_g, cov_g = _m_step(X, resp, 1e-6, backend="cuda") + + assert cp.max(cp.abs(w_c - w_g)).item() < 1e-5 + assert cp.max(cp.abs(m_c - m_g)).item() < 1e-5 + assert cp.max(cp.abs(cov_c - cov_g)).item() < 1e-4 + + weights = cp.full(K, 1 / K, dtype=cp.float32) + means = rng.standard_normal((K, d), dtype=cp.float32) + A = rng.standard_normal((K, d, d), dtype=cp.float32) + cov = A @ A.transpose(0, 2, 1) + cp.eye(d, dtype=cp.float32)[None] * 0.1 + prec_chol, log_det_half = _precision_cholesky(cov) + + r_c, ll_c = _e_step(X, weights, means, prec_chol, log_det_half, backend="cupy") + r_g, ll_g = _e_step(X, weights, means, prec_chol, log_det_half, backend="cuda") + + assert cp.max(cp.abs(r_c - r_g)).item() < 1e-4 + assert cp.abs(ll_c - ll_g).item() < 1e-4 From 
82e28e805c6fcf689e79b4e497613a8847e71263 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Fri, 8 May 2026 23:00:33 +0200 Subject: [PATCH 5/8] add cublas m step for testing --- CMakeLists.txt | 1 + src/rapids_singlecell/_cuda/gmm/gmm.cu | 122 ++++++++++++++++++ .../_cuda/gmm/kernels_gmm.cuh | 67 +++++++++- src/rapids_singlecell/squidpy_gpu/_gmm.py | 100 ++++++++++++-- tests/test_gmm.py | 2 +- 5 files changed, 280 insertions(+), 12 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ce994001..e176455d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -72,6 +72,7 @@ if (RSC_BUILD_EXTENSIONS) add_nb_cuda_module(_bbknn_cuda src/rapids_singlecell/_cuda/bbknn/bbknn.cu) add_nb_cuda_module(_norm_cuda src/rapids_singlecell/_cuda/norm/norm.cu) add_nb_cuda_module(_gmm_cuda src/rapids_singlecell/_cuda/gmm/gmm.cu) + target_link_libraries(_gmm_cuda PRIVATE CUDA::cublas) add_nb_cuda_module(_pr_cuda src/rapids_singlecell/_cuda/pr/pr.cu) add_nb_cuda_module(_nn_descent_cuda src/rapids_singlecell/_cuda/nn_descent/nn_descent.cu) add_nb_cuda_module(_aucell_cuda src/rapids_singlecell/_cuda/aucell/aucell.cu) diff --git a/src/rapids_singlecell/_cuda/gmm/gmm.cu b/src/rapids_singlecell/_cuda/gmm/gmm.cu index fb9d9bd7..5ee52b82 100644 --- a/src/rapids_singlecell/_cuda/gmm/gmm.cu +++ b/src/rapids_singlecell/_cuda/gmm/gmm.cu @@ -1,10 +1,13 @@ #include #include +#include + #include #include #include +#include "../cublas_helpers.cuh" #include "../nb_types.h" #include "kernels_gmm.cuh" @@ -24,6 +27,15 @@ static inline void cuda_check_runtime(cudaError_t err, const char* what) { } } +static inline void cublas_check_status(cublasStatus_t status, + const char* what) { + if (status != CUBLAS_STATUS_SUCCESS) { + throw std::runtime_error(std::string(what) + + " failed with cuBLAS status " + + std::to_string(static_cast(status))); + } +} + template static inline void launch_e_step(const T* X, const T* weights, const T* means, const T* prec_chol, const T* log_det_half, @@ -94,6 +106,73 @@ static inline void launch_m_step(const T* resp, const T* X, int n, int d, int K, } } +template +static inline void launch_m_step_cublas(const T* resp, const T* X, + const T* ones, int n, int d, int K, + T reg_covar, T* weights, T* means, + T* covariances, T* workspace_N_k, + T* workspace_num, T* workspace_centered, + cudaStream_t stream) { + if (n == 0 || d == 0 || K == 0) return; + + cublasHandle_t handle; + cublas_check_status(cublasCreate(&handle), "cublasCreate"); + cublas_check_status(cublasSetStream(handle, stream), "cublasSetStream"); + + T one = T(1); + T zero = T(0); + T eps = std::numeric_limits::epsilon(); + + // Row-major resp(n,K) is cuBLAS column-major (K,n). N_k = resp.T @ 1. + cublas_check_status(cublas_gemv(handle, CUBLAS_OP_N, K, n, &one, resp, K, + ones, 1, &zero, workspace_N_k, 1), + "cublas_gemv(N_k)"); + + // Row-major X(n,d) is cuBLAS column-major (d,n). Fill row-major + // workspace_num(K,d) through its column-major (d,K) view with X.T @ resp. 
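+ // In NumPy terms the row-major workspace_num ends up as resp.T @ X,
+ // the responsibility-weighted feature sums that the means kernel then
+ // divides by N_k.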
+ cublas_check_status( + cublas_gemm(handle, CUBLAS_OP_N, CUBLAS_OP_T, d, K, n, &one, X, d, + resp, K, &zero, workspace_num, d), + "cublas_gemm(num)"); + + { + int threads = 256; + dim3 block(threads); + dim3 grid(K); + m_step_finalize_means_kernel<<>>( + workspace_N_k, workspace_num, weights, means, eps, n, d, K); + CUDA_CHECK_LAST_ERROR(m_step_finalize_means_kernel); + } + + { + int threads = 256; + int blocks = (int)(((size_t)n * d + threads - 1) / threads); + for (int k = 0; k < K; ++k) { + weighted_center_kernel<<>>( + X, resp, means, n, d, K, k, workspace_centered); + CUDA_CHECK_LAST_ERROR(weighted_center_kernel); + + T* cov_k = covariances + (size_t)k * d * d; + cublas_check_status( + cublas_gemm(handle, CUBLAS_OP_N, CUBLAS_OP_T, d, d, n, &one, + workspace_centered, d, workspace_centered, d, + &zero, cov_k, d), + "cublas_gemm(covariance)"); + } + } + + { + int threads = 256; + dim3 block(threads); + dim3 grid(K); + m_step_finalize_cov_cublas_kernel<<>>( + workspace_N_k, covariances, reg_covar, eps, d, K); + CUDA_CHECK_LAST_ERROR(m_step_finalize_cov_cublas_kernel); + } + + cublas_check_status(cublasDestroy(handle), "cublasDestroy"); +} + template void register_bindings(nb::module_& m) { m.def( @@ -153,6 +232,49 @@ void register_bindings(nb::module_& m) { "N_k_workspace"_a, "num_workspace"_a, nb::kw_only(), "n"_a, "d"_a, "K"_a, "reg_covar"_a, "stream"_a = 0); + m.def( + "m_step_cublas", + [](gpu_array_c resp, + gpu_array_c X, + gpu_array_c ones, + gpu_array_c weights, gpu_array_c means, + gpu_array_c covariances, + gpu_array_c N_k_workspace, + gpu_array_c num_workspace, + gpu_array_c centered_workspace, int n, int d, int K, + float reg_covar, std::uintptr_t stream) { + launch_m_step_cublas( + resp.data(), X.data(), ones.data(), n, d, K, reg_covar, + weights.data(), means.data(), covariances.data(), + N_k_workspace.data(), num_workspace.data(), + centered_workspace.data(), (cudaStream_t)stream); + }, + "resp"_a, "X"_a, "ones"_a, "weights"_a, "means"_a, "covariances"_a, + "N_k_workspace"_a, "num_workspace"_a, "centered_workspace"_a, + nb::kw_only(), "n"_a, "d"_a, "K"_a, "reg_covar"_a, "stream"_a = 0); + + m.def( + "m_step_cublas", + [](gpu_array_c resp, + gpu_array_c X, + gpu_array_c ones, + gpu_array_c weights, + gpu_array_c means, + gpu_array_c covariances, + gpu_array_c N_k_workspace, + gpu_array_c num_workspace, + gpu_array_c centered_workspace, int n, int d, int K, + double reg_covar, std::uintptr_t stream) { + launch_m_step_cublas( + resp.data(), X.data(), ones.data(), n, d, K, reg_covar, + weights.data(), means.data(), covariances.data(), + N_k_workspace.data(), num_workspace.data(), + centered_workspace.data(), (cudaStream_t)stream); + }, + "resp"_a, "X"_a, "ones"_a, "weights"_a, "means"_a, "covariances"_a, + "N_k_workspace"_a, "num_workspace"_a, "centered_workspace"_a, + nb::kw_only(), "n"_a, "d"_a, "K"_a, "reg_covar"_a, "stream"_a = 0); + m.def( "m_step", [](gpu_array_c resp, diff --git a/src/rapids_singlecell/_cuda/gmm/kernels_gmm.cuh b/src/rapids_singlecell/_cuda/gmm/kernels_gmm.cuh index 18bb1227..d22c57d6 100644 --- a/src/rapids_singlecell/_cuda/gmm/kernels_gmm.cuh +++ b/src/rapids_singlecell/_cuda/gmm/kernels_gmm.cuh @@ -73,7 +73,9 @@ __global__ void e_step_log_prob_kernel( T mahal = T(0); for (int j = 0; j < d; ++j) { T y = T(0); - for (int dd = 0; dd < d; ++dd) { + // prec_chol is lower triangular from Cholesky, so entries above the + // diagonal are zero. Skip that half of the multiply. 
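+ // Equivalently: y = prec_cholᵀ · (x - μ), built one column j at a time;
+ // summing y_j² gives mahal = (x - μ)ᵀ · cov_inv · (x - μ), since
+ // cov_inv = prec_chol · prec_cholᵀ.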
+ for (int dd = j; dd < d; ++dd) { y += centered_vals[dd] * sh_pc[dd * d + j]; } mahal += y * y; @@ -320,3 +322,66 @@ __global__ void m_step_finalize_kernel(const T* __restrict__ N_k, if (i != j) sm_to_cov[k * d * d + j * d + i] = v; } } + +template +__global__ void m_step_finalize_means_kernel(const T* __restrict__ N_k, + const T* __restrict__ num, + T* __restrict__ weights, + T* __restrict__ means, T eps, + int n, int d, int K) { + int k = blockIdx.x; + int tid = threadIdx.x; + if (k >= K) return; + + T Nk = N_k[k] + T(10) * eps; + T inv_Nk = T(1) / Nk; + if (tid == 0) weights[k] = Nk / T(n); + + for (int i = tid; i < d; i += blockDim.x) + means[k * d + i] = num[k * d + i] * inv_Nk; +} + +template +__global__ void weighted_center_kernel(const T* __restrict__ X, + const T* __restrict__ resp, + const T* __restrict__ means, int n, + int d, int K, int k, + T* __restrict__ centered) { + size_t idx = (size_t)blockIdx.x * blockDim.x + threadIdx.x; + size_t total = (size_t)n * d; + if (idx >= total) return; + + int row = idx / d; + int col = idx - (size_t)row * d; + T r = resp[row * K + k]; + centered[idx] = sqrt(r) * (X[idx] - means[k * d + col]); +} + +template +__global__ void m_step_finalize_cov_cublas_kernel(const T* __restrict__ N_k, + T* __restrict__ covariances, + T reg_covar, T eps, int d, + int K) { + int k = blockIdx.x; + int tid = threadIdx.x; + if (k >= K) return; + + T Nk = N_k[k] + T(10) * eps; + T inv_Nk = T(1) / Nk; + int total = d * d; + T* cov = covariances + (size_t)k * d * d; + + for (int idx = tid; idx < total; idx += blockDim.x) { + int i = idx / d; + int j = idx % d; + if (i > j) continue; + + // cuBLAS wrote the row-major symmetric result through a column-major + // view. Read the transposed element and write a symmetric row-major + // covariance. + T v = cov[j * d + i] * inv_Nk; + if (i == j) v += reg_covar; + cov[i * d + j] = v; + if (i != j) cov[j * d + i] = v; + } +} diff --git a/src/rapids_singlecell/squidpy_gpu/_gmm.py b/src/rapids_singlecell/squidpy_gpu/_gmm.py index 4cbe2ba0..200ee91d 100644 --- a/src/rapids_singlecell/squidpy_gpu/_gmm.py +++ b/src/rapids_singlecell/squidpy_gpu/_gmm.py @@ -11,8 +11,8 @@ -------------------- - CUDA E-step uses a cached precision Cholesky (computed once per M-step) and a custom Mahalanobis + responsibility kernel for the common <=64-PC case. -- CUDA M-step reuses preallocated workspaces and computes full covariances from - upper-triangle tiled reductions. +- CUDA M-step reuses preallocated workspaces and uses a cuBLAS covariance + update to avoid the chunked atomic reduction bottleneck. - ``_precision_cholesky`` is a batched ``inv``+``cholesky`` — no Python K-loop. - Convergence is the change in mean log-likelihood. """ @@ -64,7 +64,7 @@ def gmm_fit_predict( ``"auto"`` (default) uses the nanobind/CUDA EM backend when the compiled extension is available, otherwise falls back to CuPy. ``"cupy"`` uses CuPy + cuBLAS for the covariance update. ``"cuda"`` forces the custom - CUDA kernels for the E-step and M-step reductions. + CUDA kernels for the E-step and the fastest available CUDA M-step. kmeans_n_init Number of cuML KMeans restarts for ``init="kmeans"``. 
The default ``1`` matches sklearn's GaussianMixture default and keeps cellcharter fast; @@ -92,7 +92,13 @@ def gmm_fit_predict( ) weights, means, covariances = _initialize( - X, K, random_state, reg_covar, init, int(kmeans_n_init) + X, + K, + random_state=random_state, + reg_covar=reg_covar, + init=init, + kmeans_n_init=int(kmeans_n_init), + backend=backend, ) prec_chol, log_det_prec_half = _precision_cholesky(covariances) @@ -128,7 +134,13 @@ def _fit_predict_cuda( kmeans_n_init: int, ) -> cp.ndarray: weights, means, covariances = _initialize( - X, K, random_state, reg_covar, init, kmeans_n_init + X, + K, + random_state=random_state, + reg_covar=reg_covar, + init=init, + kmeans_n_init=kmeans_n_init, + backend="cuda", ) weights = cp.ascontiguousarray(weights) means = cp.ascontiguousarray(means) @@ -145,7 +157,7 @@ def _fit_predict_cuda( converged = True break prev_ll = ll_value - workspace.m_step(resp, weights, means, covariances, reg_covar) + workspace.m_step_cublas(resp, weights, means, covariances, reg_covar) prec_chol, log_det_prec_half = _precision_cholesky(covariances) if not converged: @@ -214,6 +226,35 @@ def m_step( stream=self.stream, ) + def m_step_cublas( + self, + resp: cp.ndarray, + weights: cp.ndarray, + means: cp.ndarray, + covariances: cp.ndarray, + reg_covar: float, + ) -> None: + if not hasattr(self, "_ones"): + self._ones = cp.ones(self.n, dtype=self.X.dtype) + if not hasattr(self, "_centered"): + self._centered = cp.empty_like(self.X) + self._gc.m_step_cublas( + cp.ascontiguousarray(resp), + self.X, + self._ones, + weights, + means, + covariances, + self.N_k, + self.num, + self._centered, + n=self.n, + d=self.d, + K=self.K, + reg_covar=float(reg_covar), + stream=self.stream, + ) + def _resolve_backend(backend: str) -> str: if backend == "cupy": @@ -245,6 +286,7 @@ def _initialize( reg_covar: float, init: str, kmeans_n_init: int, + backend: str = "cupy", ) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray]: n, d = X.shape eye_reg = reg_covar * cp.eye(d, dtype=X.dtype) @@ -273,6 +315,9 @@ def _initialize( labels = cp.asarray(km.labels_) means = cp.asarray(km.cluster_centers_, dtype=X.dtype) + if backend == "cuda": + return _initialize_covariances_cuda(X, labels, means, reg_covar) + weights = cp.zeros(K, dtype=X.dtype) covariances = cp.empty((K, d, d), dtype=X.dtype) for k in range(K): @@ -289,6 +334,34 @@ def _initialize( return weights, means, covariances +def _initialize_covariances_cuda( + X: cp.ndarray, + labels: cp.ndarray, + means_init: cp.ndarray, + reg_covar: float, +) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray]: + n = X.shape[0] + K = means_init.shape[0] + + labels = labels.astype(cp.int64, copy=False) + resp = cp.zeros((n, K), dtype=X.dtype) + resp[cp.arange(n), labels] = X.dtype.type(1.0) + weights, means, covariances = _m_step(X, resp, reg_covar, backend="cuda") + + counts = cp.bincount(labels, minlength=K).astype(X.dtype, copy=False) + empty = counts == 0 + eye_reg = reg_covar * cp.eye(X.shape[1], dtype=X.dtype) + + weights = cp.where(empty, X.dtype.type(1.0 / n), counts / n) + means = cp.where(empty[:, None], means_init, means) + covariances = cp.where( + empty[:, None, None], + cp.broadcast_to(eye_reg, covariances.shape), + covariances, + ) + return weights, means, covariances + + def _precision_cholesky( covariances: cp.ndarray, ) -> tuple[cp.ndarray, cp.ndarray]: @@ -373,19 +446,26 @@ def _m_step( covariances = cp.empty((K, d, d), dtype=X.dtype) N_k_ws = cp.empty(K, dtype=X.dtype) num_ws = cp.empty((K, d), dtype=X.dtype) - _gc.m_step( - 
cp.ascontiguousarray(resp), - cp.ascontiguousarray(X), + resp = cp.ascontiguousarray(resp) + X = cp.ascontiguousarray(X) + stream = cp.cuda.get_current_stream().ptr + ones_ws = cp.ones(n, dtype=X.dtype) + centered_ws = cp.empty_like(X) + _gc.m_step_cublas( + resp, + X, + ones_ws, weights, means, covariances, N_k_ws, num_ws, + centered_ws, n=int(n), d=int(d), K=int(K), reg_covar=float(reg_covar), - stream=cp.cuda.get_current_stream().ptr, + stream=stream, ) return weights, means, covariances diff --git a/tests/test_gmm.py b/tests/test_gmm.py index 45c2fa1d..388e0693 100644 --- a/tests/test_gmm.py +++ b/tests/test_gmm.py @@ -111,7 +111,7 @@ def test_cuda_backend_matches_cupy_steps(): pytest.skip("_gmm_cuda extension is not available") rng = cp.random.RandomState(0) - n, d, K = 40_000, 6, 3 # large enough to exercise the chunked M-step path + n, d, K = 40_000, 6, 3 # large enough to exercise the cuBLAS M-step path X = rng.standard_normal((n, d), dtype=cp.float32) logits = rng.standard_normal((n, K), dtype=cp.float32) resp = cp.exp(logits - cp.log(cp.exp(logits).sum(axis=1, keepdims=True))) From 5ad33458b765e0cade76ba2ddb25835b6ac0cc6b Mon Sep 17 00:00:00 2001 From: Intron7 Date: Mon, 11 May 2026 21:19:31 +0200 Subject: [PATCH 6/8] add new GMM paths --- src/rapids_singlecell/_cuda/gmm/gmm.cu | 282 ++++++++----- .../_cuda/gmm/kernels_gmm.cuh | 399 +++++++++--------- src/rapids_singlecell/squidpy_gpu/_gmm.py | 137 +++--- tests/test_gmm.py | 179 ++++++++ 4 files changed, 627 insertions(+), 370 deletions(-) diff --git a/src/rapids_singlecell/_cuda/gmm/gmm.cu b/src/rapids_singlecell/_cuda/gmm/gmm.cu index 5ee52b82..b6254036 100644 --- a/src/rapids_singlecell/_cuda/gmm/gmm.cu +++ b/src/rapids_singlecell/_cuda/gmm/gmm.cu @@ -14,11 +14,15 @@ using namespace nb::literals; -constexpr int TILE = 16; constexpr int E_STEP_BLOCK = 64; +constexpr int E_STEP_LARGE64_TILE = 64; +constexpr int E_STEP_THREAD64_BLOCK = 512; constexpr int NORMALIZE_BLOCK = 32; -constexpr int M_STEP_CHUNK_ROWS = 1024; -constexpr int M_STEP_CHUNKED_MIN_ROWS = 32768; +constexpr size_t DEFAULT_DYNAMIC_SMEM_LIMIT = 48 * 1024; + +static inline size_t upper_tri_size(size_t d) { + return (d * (d + 1)) / 2; +} static inline void cuda_check_runtime(cudaError_t err, const char* what) { if (err != cudaSuccess) { @@ -36,23 +40,65 @@ static inline void cublas_check_status(cublasStatus_t status, } } +template +static inline void launch_e_step_log_prob_fixed_d_impl( + const T* X, const T* weights, const T* means, const T* prec_chol, + const T* log_det_half, int n, int K, T* log_prob, dim3 grid, dim3 block, + cudaStream_t stream) { + size_t shmem = (D + upper_tri_size(D)) * sizeof(T); + if (shmem > DEFAULT_DYNAMIC_SMEM_LIMIT) { + cuda_check_runtime( + cudaFuncSetAttribute(e_step_log_prob_small_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + (int)shmem), + "cudaFuncSetAttribute(e_step_log_prob_small_kernel)"); + } + e_step_log_prob_small_kernel<<>>( + X, weights, means, prec_chol, log_det_half, n, D, K, log_prob); + CUDA_CHECK_LAST_ERROR(e_step_log_prob_small_kernel); +} + template static inline void launch_e_step(const T* X, const T* weights, const T* means, const T* prec_chol, const T* log_det_half, int n, int d, int K, T* log_prob, T* resp, T* ll_per_cell, cudaStream_t stream) { if (n == 0 || d == 0 || K == 0) return; - if (d > 64) { - throw std::runtime_error( - "_gmm_cuda.e_step supports at most 64 features"); - } - { - size_t shmem = (size_t)(d + d * d) * sizeof(T); + if (d <= 64) { dim3 block(E_STEP_BLOCK); - dim3 grid(K, (n + 
E_STEP_BLOCK - 1) / E_STEP_BLOCK); - e_step_log_prob_kernel<<>>( - X, weights, means, prec_chol, log_det_half, n, d, K, log_prob); - CUDA_CHECK_LAST_ERROR(e_step_log_prob_kernel); + dim3 grid((n + E_STEP_BLOCK - 1) / E_STEP_BLOCK, K); + if (d == 16) { + launch_e_step_log_prob_fixed_d_impl( + X, weights, means, prec_chol, log_det_half, n, K, log_prob, + grid, block, stream); + } else if (d == 32) { + launch_e_step_log_prob_fixed_d_impl( + X, weights, means, prec_chol, log_det_half, n, K, log_prob, + grid, block, stream); + } else if (d == 50) { + launch_e_step_log_prob_fixed_d_impl( + X, weights, means, prec_chol, log_det_half, n, K, log_prob, + grid, block, stream); + } else if (d == 64) { + launch_e_step_log_prob_fixed_d_impl( + X, weights, means, prec_chol, log_det_half, n, K, log_prob, + grid, block, stream); + } else { + size_t shmem = ((size_t)d + upper_tri_size(d)) * sizeof(T); + e_step_log_prob_small_kernel<<>>( + X, weights, means, prec_chol, log_det_half, n, d, K, log_prob); + CUDA_CHECK_LAST_ERROR(e_step_log_prob_small_kernel); + } + } else { + dim3 block(E_STEP_THREAD64_BLOCK); + dim3 grid((n + E_STEP_THREAD64_BLOCK - 1) / E_STEP_THREAD64_BLOCK, K); + size_t shmem = ((size_t)E_STEP_LARGE64_TILE + + (size_t)E_STEP_LARGE64_TILE * E_STEP_LARGE64_TILE) * + sizeof(T); + e_step_log_prob_large_d_thread64_kernel + <<>>(X, weights, means, prec_chol, + log_det_half, n, d, K, log_prob); + CUDA_CHECK_LAST_ERROR(e_step_log_prob_large_d_thread64_kernel); } { dim3 block(NORMALIZE_BLOCK); @@ -64,59 +110,63 @@ static inline void launch_e_step(const T* X, const T* weights, const T* means, } template -static inline void launch_m_step(const T* resp, const T* X, int n, int d, int K, - T reg_covar, T* weights, T* means, - T* covariances, T* workspace_N_k, - T* workspace_num, cudaStream_t stream) { +static inline void launch_e_step_cublas(const T* X, const T* weights, + const T* means, const T* prec_chol, + const T* log_det_half, int n, int d, + int K, T* centered_workspace, + T* y_workspace, T* log_prob, T* resp, + T* ll_per_cell, cudaStream_t stream, + cublasHandle_t handle) { if (n == 0 || d == 0 || K == 0) return; - int tiles_d = (d + TILE - 1) / TILE; - T eps = std::numeric_limits::epsilon(); - if (n < M_STEP_CHUNKED_MIN_ROWS) { - dim3 block(TILE, TILE); - dim3 grid(K, tiles_d, tiles_d); - m_step_fused_kernel<<>>( - resp, X, n, d, K, workspace_N_k, workspace_num, covariances); - CUDA_CHECK_LAST_ERROR(m_step_fused_kernel); - } else { - size_t n_k_bytes = static_cast(K) * sizeof(T); - size_t num_bytes = static_cast(K) * d * sizeof(T); - size_t cov_bytes = static_cast(K) * d * d * sizeof(T); - cuda_check_runtime(cudaMemsetAsync(workspace_N_k, 0, n_k_bytes, stream), - "cudaMemsetAsync(workspace_N_k)"); - cuda_check_runtime(cudaMemsetAsync(workspace_num, 0, num_bytes, stream), - "cudaMemsetAsync(workspace_num)"); - cuda_check_runtime(cudaMemsetAsync(covariances, 0, cov_bytes, stream), - "cudaMemsetAsync(covariances)"); - int chunks = (n + M_STEP_CHUNK_ROWS - 1) / M_STEP_CHUNK_ROWS; - dim3 block(TILE, TILE); - dim3 grid(K, tiles_d * tiles_d, chunks); - m_step_chunked_atomic_kernel<<>>( - resp, X, n, d, K, tiles_d, M_STEP_CHUNK_ROWS, workspace_N_k, - workspace_num, covariances); - CUDA_CHECK_LAST_ERROR(m_step_chunked_atomic_kernel); + + bool own_handle = handle == nullptr; + if (own_handle) cublas_check_status(cublasCreate(&handle), "cublasCreate"); + cublas_check_status(cublasSetStream(handle, stream), "cublasSetStream"); + + T one = T(1); + T zero = T(0); + int threads = 256; + int center_blocks = 
(int)(((size_t)n * d + threads - 1) / threads); + int row_blocks = (n + threads - 1) / threads; + + for (int k = 0; k < K; ++k) { + e_step_center_kernel<<>>( + X, means, n, d, k, centered_workspace); + CUDA_CHECK_LAST_ERROR(e_step_center_kernel); + + const T* pc_k = prec_chol + (size_t)k * d * d; + cublas_check_status( + cublas_gemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, d, n, d, &one, + pc_k, d, centered_workspace, d, &zero, y_workspace, + d), + "cublas_gemm(e_step)"); + + e_step_log_prob_from_y_kernel<<>>( + y_workspace, weights, log_det_half, n, d, K, k, log_prob); + CUDA_CHECK_LAST_ERROR(e_step_log_prob_from_y_kernel); } + { - int threads = 256; - dim3 block(threads); - dim3 grid(K); - m_step_finalize_kernel<<>>( - workspace_N_k, workspace_num, covariances, weights, means, - reg_covar, eps, n, d, K); - CUDA_CHECK_LAST_ERROR(m_step_finalize_kernel); + dim3 block(NORMALIZE_BLOCK); + dim3 grid(n); + e_step_normalize_kernel + <<>>(log_prob, n, K, resp, ll_per_cell); + CUDA_CHECK_LAST_ERROR(e_step_normalize_kernel); } + + if (own_handle) cublas_check_status(cublasDestroy(handle), "cublasDestroy"); } template -static inline void launch_m_step_cublas(const T* resp, const T* X, - const T* ones, int n, int d, int K, - T reg_covar, T* weights, T* means, - T* covariances, T* workspace_N_k, - T* workspace_num, T* workspace_centered, - cudaStream_t stream) { +static inline void launch_m_step(const T* resp, const T* X, const T* ones, + int n, int d, int K, T reg_covar, T* weights, + T* means, T* covariances, T* workspace_N_k, + T* workspace_num, T* workspace_centered, + cudaStream_t stream, cublasHandle_t handle) { if (n == 0 || d == 0 || K == 0) return; - cublasHandle_t handle; - cublas_check_status(cublasCreate(&handle), "cublasCreate"); + bool own_handle = handle == nullptr; + if (own_handle) cublas_check_status(cublasCreate(&handle), "cublasCreate"); cublas_check_status(cublasSetStream(handle, stream), "cublasSetStream"); T one = T(1); @@ -170,7 +220,7 @@ static inline void launch_m_step_cublas(const T* resp, const T* X, CUDA_CHECK_LAST_ERROR(m_step_finalize_cov_cublas_kernel); } - cublas_check_status(cublasDestroy(handle), "cublasDestroy"); + if (own_handle) cublas_check_status(cublasDestroy(handle), "cublasDestroy"); } template @@ -215,25 +265,56 @@ void register_bindings(nb::module_& m) { "K"_a, "stream"_a = 0); m.def( - "m_step", - [](gpu_array_c resp, - gpu_array_c X, - gpu_array_c weights, gpu_array_c means, - gpu_array_c covariances, - gpu_array_c N_k_workspace, - gpu_array_c num_workspace, int n, int d, int K, - float reg_covar, std::uintptr_t stream) { - launch_m_step(resp.data(), X.data(), n, d, K, reg_covar, - weights.data(), means.data(), - covariances.data(), N_k_workspace.data(), - num_workspace.data(), (cudaStream_t)stream); + "e_step_cublas", + [](gpu_array_c X, + gpu_array_c weights, + gpu_array_c means, + gpu_array_c prec_chol, + gpu_array_c log_det_half, + gpu_array_c centered_workspace, + gpu_array_c y_workspace, + gpu_array_c log_prob, gpu_array_c resp, + gpu_array_c ll_per_cell, int n, int d, int K, + std::uintptr_t stream, std::uintptr_t handle) { + launch_e_step_cublas( + X.data(), weights.data(), means.data(), prec_chol.data(), + log_det_half.data(), n, d, K, centered_workspace.data(), + y_workspace.data(), log_prob.data(), resp.data(), + ll_per_cell.data(), (cudaStream_t)stream, + (cublasHandle_t)handle); }, - "resp"_a, "X"_a, "weights"_a, "means"_a, "covariances"_a, - "N_k_workspace"_a, "num_workspace"_a, nb::kw_only(), "n"_a, "d"_a, - "K"_a, "reg_covar"_a, "stream"_a = 0); 
+ "X"_a, "weights"_a, "means"_a, "prec_chol"_a, "log_det_half"_a, + "centered_workspace"_a, "y_workspace"_a, "log_prob"_a, "resp"_a, + "ll_per_cell"_a, nb::kw_only(), "n"_a, "d"_a, "K"_a, "stream"_a = 0, + "handle"_a = 0); m.def( - "m_step_cublas", + "e_step_cublas", + [](gpu_array_c X, + gpu_array_c weights, + gpu_array_c means, + gpu_array_c prec_chol, + gpu_array_c log_det_half, + gpu_array_c centered_workspace, + gpu_array_c y_workspace, + gpu_array_c log_prob, + gpu_array_c resp, + gpu_array_c ll_per_cell, int n, int d, int K, + std::uintptr_t stream, std::uintptr_t handle) { + launch_e_step_cublas( + X.data(), weights.data(), means.data(), prec_chol.data(), + log_det_half.data(), n, d, K, centered_workspace.data(), + y_workspace.data(), log_prob.data(), resp.data(), + ll_per_cell.data(), (cudaStream_t)stream, + (cublasHandle_t)handle); + }, + "X"_a, "weights"_a, "means"_a, "prec_chol"_a, "log_det_half"_a, + "centered_workspace"_a, "y_workspace"_a, "log_prob"_a, "resp"_a, + "ll_per_cell"_a, nb::kw_only(), "n"_a, "d"_a, "K"_a, "stream"_a = 0, + "handle"_a = 0); + + m.def( + "m_step", [](gpu_array_c resp, gpu_array_c X, gpu_array_c ones, @@ -242,19 +323,21 @@ void register_bindings(nb::module_& m) { gpu_array_c N_k_workspace, gpu_array_c num_workspace, gpu_array_c centered_workspace, int n, int d, int K, - float reg_covar, std::uintptr_t stream) { - launch_m_step_cublas( - resp.data(), X.data(), ones.data(), n, d, K, reg_covar, - weights.data(), means.data(), covariances.data(), - N_k_workspace.data(), num_workspace.data(), - centered_workspace.data(), (cudaStream_t)stream); + float reg_covar, std::uintptr_t stream, std::uintptr_t handle) { + launch_m_step(resp.data(), X.data(), ones.data(), n, d, K, + reg_covar, weights.data(), means.data(), + covariances.data(), N_k_workspace.data(), + num_workspace.data(), + centered_workspace.data(), + (cudaStream_t)stream, (cublasHandle_t)handle); }, "resp"_a, "X"_a, "ones"_a, "weights"_a, "means"_a, "covariances"_a, "N_k_workspace"_a, "num_workspace"_a, "centered_workspace"_a, - nb::kw_only(), "n"_a, "d"_a, "K"_a, "reg_covar"_a, "stream"_a = 0); + nb::kw_only(), "n"_a, "d"_a, "K"_a, "reg_covar"_a, "stream"_a = 0, + "handle"_a = 0); m.def( - "m_step_cublas", + "m_step", [](gpu_array_c resp, gpu_array_c X, gpu_array_c ones, @@ -264,35 +347,18 @@ void register_bindings(nb::module_& m) { gpu_array_c N_k_workspace, gpu_array_c num_workspace, gpu_array_c centered_workspace, int n, int d, int K, - double reg_covar, std::uintptr_t stream) { - launch_m_step_cublas( - resp.data(), X.data(), ones.data(), n, d, K, reg_covar, - weights.data(), means.data(), covariances.data(), - N_k_workspace.data(), num_workspace.data(), - centered_workspace.data(), (cudaStream_t)stream); + double reg_covar, std::uintptr_t stream, std::uintptr_t handle) { + launch_m_step(resp.data(), X.data(), ones.data(), n, d, K, + reg_covar, weights.data(), means.data(), + covariances.data(), N_k_workspace.data(), + num_workspace.data(), + centered_workspace.data(), + (cudaStream_t)stream, (cublasHandle_t)handle); }, "resp"_a, "X"_a, "ones"_a, "weights"_a, "means"_a, "covariances"_a, "N_k_workspace"_a, "num_workspace"_a, "centered_workspace"_a, - nb::kw_only(), "n"_a, "d"_a, "K"_a, "reg_covar"_a, "stream"_a = 0); - - m.def( - "m_step", - [](gpu_array_c resp, - gpu_array_c X, - gpu_array_c weights, - gpu_array_c means, - gpu_array_c covariances, - gpu_array_c N_k_workspace, - gpu_array_c num_workspace, int n, int d, int K, - double reg_covar, std::uintptr_t stream) { - 
launch_m_step(resp.data(), X.data(), n, d, K, reg_covar, - weights.data(), means.data(), - covariances.data(), N_k_workspace.data(), - num_workspace.data(), (cudaStream_t)stream); - }, - "resp"_a, "X"_a, "weights"_a, "means"_a, "covariances"_a, - "N_k_workspace"_a, "num_workspace"_a, nb::kw_only(), "n"_a, "d"_a, - "K"_a, "reg_covar"_a, "stream"_a = 0); + nb::kw_only(), "n"_a, "d"_a, "K"_a, "reg_covar"_a, "stream"_a = 0, + "handle"_a = 0); } NB_MODULE(_gmm_cuda, m) { diff --git a/src/rapids_singlecell/_cuda/gmm/kernels_gmm.cuh b/src/rapids_singlecell/_cuda/gmm/kernels_gmm.cuh index d22c57d6..6762b453 100644 --- a/src/rapids_singlecell/_cuda/gmm/kernels_gmm.cuh +++ b/src/rapids_singlecell/_cuda/gmm/kernels_gmm.cuh @@ -32,55 +32,216 @@ __device__ __forceinline__ double log_2pi_const() { return LOG_2PI_D; } -template -__global__ void e_step_log_prob_kernel( +__device__ __forceinline__ int upper_tri_col_offset(int col) { + return (col * (col + 1)) / 2; +} + +template +__global__ void e_step_log_prob_small_kernel( const T* __restrict__ X, // (n, d) row-major const T* __restrict__ weights, // (K,) const T* __restrict__ means, // (K, d) - const T* __restrict__ prec_chol, // (K, d, d) row-major; cov_inv = - // chol·cholᵀ + const T* __restrict__ prec_chol, // (K, d, d) row-major; upper factor + // with cov_inv = chol·cholᵀ const T* __restrict__ log_det_half, // (K,) int n, int d, int K, T* __restrict__ log_prob // (n, K) ) { - int k = blockIdx.x; - int n_idx = blockIdx.y * blockDim.x + threadIdx.x; + static_assert(D >= 0 && D <= 64, + "GMM small E-step supports runtime d or fixed D <= 64"); + constexpr bool fixed_d = D != 0; + int dim = fixed_d ? D : d; + int k = blockIdx.y; + int n_idx = blockIdx.x * blockDim.x + threadIdx.x; int tid = threadIdx.x; extern __shared__ unsigned char smem_raw[]; T* sh_mean = reinterpret_cast(smem_raw); - T* sh_pc = sh_mean + d; - - // Cooperatively load means[k] and prec_chol[k] into shared memory. - for (int i = tid; i < d; i += blockDim.x) sh_mean[i] = means[k * d + i]; - int pc_size = d * d; - for (int i = tid; i < pc_size; i += blockDim.x) - sh_pc[i] = prec_chol[k * pc_size + i]; + T* sh_pc = sh_mean + dim; + + // Cooperatively load means[k] and the used upper triangle of prec_chol[k] + // into shared memory. + for (int i = tid; i < dim; i += blockDim.x) + sh_mean[i] = means[(size_t)k * dim + i]; + int pc_size_dense = dim * dim; + for (int i = tid; i < pc_size_dense; i += blockDim.x) { + int row = i / dim; + int col = i - row * dim; + if (row <= col) { + sh_pc[upper_tri_col_offset(col) + row] = + prec_chol[(size_t)k * pc_size_dense + i]; + } + } - T log_w = log(weights[k]); - T ldh = log_det_half[k]; - T const_term = T(-0.5) * T(d) * log_2pi_const() + ldh + log_w; + __shared__ T sh_const; + if (tid == 0) { + sh_const = T(-0.5) * T(dim) * log_2pi_const() + log_det_half[k] + + log(weights[k]); + } __syncthreads(); if (n_idx >= n) return; // Compute mahal = || (X[n] - μ_k) · prec_chol[k] ||² - T centered_vals[64]; - for (int dd = 0; dd < d; ++dd) - centered_vals[dd] = X[n_idx * d + dd] - sh_mean[dd]; + T centered_vals[fixed_d ? D : 64]; + if constexpr (fixed_d) { +#pragma unroll + for (int dd = 0; dd < D; ++dd) + centered_vals[dd] = X[(size_t)n_idx * D + dd] - sh_mean[dd]; + } else { + for (int dd = 0; dd < dim; ++dd) + centered_vals[dd] = X[(size_t)n_idx * dim + dd] - sh_mean[dd]; + } T mahal = T(0); - for (int j = 0; j < d; ++j) { - T y = T(0); - // prec_chol is lower triangular from Cholesky, so entries above the - // diagonal are zero. 
Skip that half of the multiply. - for (int dd = j; dd < d; ++dd) { - y += centered_vals[dd] * sh_pc[dd * d + j]; + if constexpr (fixed_d) { +#pragma unroll + for (int j = 0; j < D; ++j) { + T y = T(0); + int pc_col = upper_tri_col_offset(j); +#pragma unroll + for (int dd = 0; dd <= j; ++dd) { + y += centered_vals[dd] * sh_pc[pc_col + dd]; + } + mahal += y * y; + } + } else { + for (int j = 0; j < dim; ++j) { + T y = T(0); + int pc_col = upper_tri_col_offset(j); + // prec_chol is the upper triangular precision factor, so entries + // below the diagonal are zero. Skip that half of the multiply. + for (int dd = 0; dd <= j; ++dd) { + y += centered_vals[dd] * sh_pc[pc_col + dd]; + } + mahal += y * y; } - mahal += y * y; } - log_prob[n_idx * K + k] = const_term - T(0.5) * mahal; + log_prob[(size_t)n_idx * K + k] = sh_const - T(0.5) * mahal; +} + +template +__global__ void e_step_log_prob_large_d_thread64_kernel( + const T* __restrict__ X, // (n, d) row-major + const T* __restrict__ weights, // (K,) + const T* __restrict__ means, // (K, d) + const T* __restrict__ prec_chol, // (K, d, d) row-major; upper factor + const T* __restrict__ log_det_half, // (K,) + int n, int d, int K, + T* __restrict__ log_prob // (n, K) +) { + static_assert(TILE_D == 64, + "GMM thread64 E-step expects a 64-column precision tile"); + + int k = blockIdx.y; + int row = blockIdx.x * blockDim.x + threadIdx.x; + int tid = threadIdx.x; + + extern __shared__ unsigned char smem_raw[]; + T* sh_mean = reinterpret_cast(smem_raw); // (64,) + T* sh_pc = sh_mean + TILE_D; // (64, 64) + + __shared__ T sh_const; + if (tid == 0) { + sh_const = T(-0.5) * T(d) * log_2pi_const() + log_det_half[k] + + log(weights[k]); + } + + T local_mahal = T(0); + const T* pc = prec_chol + (size_t)k * d * d; + + for (int j_base = 0; j_base < d; j_base += TILE_D) { + int cols_in_tile = min(TILE_D, d - j_base); + int dd_limit = min(d, j_base + TILE_D); + T y[TILE_D]; +#pragma unroll + for (int col = 0; col < TILE_D; ++col) y[col] = T(0); + + for (int dd_base = 0; dd_base < dd_limit; dd_base += TILE_D) { + int feats_in_tile = min(TILE_D, dd_limit - dd_base); + + for (int idx = tid; idx < TILE_D; idx += blockDim.x) { + sh_mean[idx] = (idx < feats_in_tile) + ? 
means[(size_t)k * d + dd_base + idx] + : T(0); + } + + constexpr int pc_tile_elems = TILE_D * TILE_D; + for (int idx = tid; idx < pc_tile_elems; idx += blockDim.x) { + int feat = idx / TILE_D; + int col_local = idx - feat * TILE_D; + int dd = dd_base + feat; + int col = j_base + col_local; + T val = T(0); + if (feat < feats_in_tile && col_local < cols_in_tile && + dd <= col) { + val = pc[(size_t)dd * d + col]; + } + sh_pc[feat * TILE_D + col_local] = val; + } + + __syncthreads(); + + if (row < n) { +#pragma unroll + for (int feat = 0; feat < TILE_D; ++feat) { + if (feat >= feats_in_tile) break; + T diff = + X[(size_t)row * d + dd_base + feat] - sh_mean[feat]; +#pragma unroll + for (int col = 0; col < TILE_D; ++col) { + if (col >= cols_in_tile) break; + y[col] += diff * sh_pc[feat * TILE_D + col]; + } + } + } + + __syncthreads(); + } + + if (row < n) { +#pragma unroll + for (int col = 0; col < TILE_D; ++col) { + if (col >= cols_in_tile) break; + local_mahal += y[col] * y[col]; + } + } + } + + if (row < n) + log_prob[(size_t)row * K + k] = sh_const - T(0.5) * local_mahal; +} + +template +__global__ void e_step_center_kernel(const T* __restrict__ X, + const T* __restrict__ means, int n, int d, + int k, T* __restrict__ centered) { + size_t idx = (size_t)blockIdx.x * blockDim.x + threadIdx.x; + size_t total = (size_t)n * d; + if (idx >= total) return; + + int col = idx % d; + centered[idx] = X[idx] - means[(size_t)k * d + col]; +} + +template +__global__ void e_step_log_prob_from_y_kernel( + const T* __restrict__ y, const T* __restrict__ weights, + const T* __restrict__ log_det_half, int n, int d, int K, int k, + T* __restrict__ log_prob) { + int row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= n) return; + + T mahal = T(0); + for (int col = 0; col < d; ++col) { + T v = y[(size_t)row * d + col]; + mahal += v * v; + } + + T constant = + T(-0.5) * T(d) * log_2pi_const() + log_det_half[k] + log(weights[k]); + log_prob[(size_t)row * K + k] = constant - T(0.5) * mahal; } // ---------------------------------------------------------------------------- @@ -139,190 +300,6 @@ __global__ void e_step_normalize_kernel( } } -// ---------------------------------------------------------------------------- -// Direct M-step kernel for smaller inputs. -// ---------------------------------------------------------------------------- - -template -__global__ void m_step_fused_kernel(const T* __restrict__ resp, - const T* __restrict__ X, int n, int d, - int K, T* __restrict__ N_k, - T* __restrict__ num, T* __restrict__ sm) { - int k = blockIdx.x; - int i_tile = blockIdx.y * TILE; - int j_tile = blockIdx.z * TILE; - int ti = threadIdx.y; - int tj = threadIdx.x; - int i = i_tile + ti; - int j = j_tile + tj; - bool valid = (i < d) && (j < d); - bool emit_cov = j_tile >= i_tile; - - bool emit_num = (j_tile == 0); - bool emit_Nk = (i_tile == 0) && (j_tile == 0); - - if (!emit_cov && !emit_num) return; - - __shared__ T Xi[TILE][TILE]; - __shared__ T Xj[TILE][TILE]; - __shared__ T r[TILE]; - - T accum_sm = T(0); - T accum_num = T(0); - T block_Nk = T(0); - - for (int n_base = 0; n_base < n; n_base += TILE) { - int n_in = min(TILE, n - n_base); - - T xi_v = T(0), xj_v = T(0); - if (ti < n_in && (i_tile + tj) < d) - xi_v = X[(n_base + ti) * d + i_tile + tj]; - if (emit_cov && ti < n_in && (j_tile + tj) < d) - xj_v = X[(n_base + ti) * d + j_tile + tj]; - Xi[ti][tj] = xi_v; - Xj[ti][tj] = xj_v; - - if (ti == 0) { - r[tj] = (tj < n_in) ? 
resp[(n_base + tj) * K + k] : T(0); - } - __syncthreads(); - - if (valid) { -#pragma unroll - for (int t = 0; t < TILE; ++t) { - T rt = r[t]; - T xi_t = Xi[t][ti]; - if (emit_cov) { - T xj_t = Xj[t][tj]; - accum_sm += rt * xi_t * xj_t; - } - if (emit_num && tj == 0) accum_num += rt * xi_t; - } - } - - if (emit_Nk && ti == 0 && tj == 0) { - T chunk_sum = T(0); -#pragma unroll - for (int t = 0; t < TILE; ++t) chunk_sum += r[t]; - block_Nk += chunk_sum; - } - __syncthreads(); - } - - if (emit_cov && valid) sm[k * d * d + i * d + j] = accum_sm; - if (emit_num && tj == 0 && i < d) num[k * d + i] = accum_num; - if (emit_Nk && ti == 0 && tj == 0) N_k[k] = block_Nk; -} - -template -__global__ void m_step_chunked_atomic_kernel( - const T* __restrict__ resp, const T* __restrict__ X, int n, int d, int K, - int tiles_d, int chunk_size, T* __restrict__ N_k, T* __restrict__ num, - T* __restrict__ sm) { - int k = blockIdx.x; - int tile_pair = blockIdx.y; - int chunk = blockIdx.z; - int i_tile = (tile_pair / tiles_d) * TILE; - int j_tile = (tile_pair % tiles_d) * TILE; - int chunk_start = chunk * chunk_size; - int chunk_end = min(n, chunk_start + chunk_size); - - int ti = threadIdx.y; - int tj = threadIdx.x; - int i = i_tile + ti; - int j = j_tile + tj; - bool valid = (i < d) && (j < d); - bool emit_cov = j_tile >= i_tile; - - bool emit_num = (j_tile == 0); - bool emit_Nk = (i_tile == 0) && (j_tile == 0); - - if (!emit_cov && !emit_num) return; - - __shared__ T Xi[TILE][TILE]; - __shared__ T Xj[TILE][TILE]; - __shared__ T r[TILE]; - - T accum_sm = T(0); - T accum_num = T(0); - T block_Nk = T(0); - - for (int n_base = chunk_start; n_base < chunk_end; n_base += TILE) { - int n_in = min(TILE, chunk_end - n_base); - - T xi_v = T(0), xj_v = T(0); - if (ti < n_in && (i_tile + tj) < d) - xi_v = X[(n_base + ti) * d + i_tile + tj]; - if (emit_cov && ti < n_in && (j_tile + tj) < d) - xj_v = X[(n_base + ti) * d + j_tile + tj]; - Xi[ti][tj] = xi_v; - Xj[ti][tj] = xj_v; - - if (ti == 0) { - r[tj] = (tj < n_in) ? 
resp[(n_base + tj) * K + k] : T(0); - } - __syncthreads(); - - if (valid) { -#pragma unroll - for (int t = 0; t < TILE; ++t) { - T rt = r[t]; - T xi_t = Xi[t][ti]; - if (emit_cov) { - T xj_t = Xj[t][tj]; - accum_sm += rt * xi_t * xj_t; - } - if (emit_num && tj == 0) accum_num += rt * xi_t; - } - } - - if (emit_Nk && ti == 0 && tj == 0) { - T chunk_sum = T(0); -#pragma unroll - for (int t = 0; t < TILE; ++t) chunk_sum += r[t]; - block_Nk += chunk_sum; - } - __syncthreads(); - } - - if (emit_cov && valid) atomicAdd(&sm[k * d * d + i * d + j], accum_sm); - if (emit_num && tj == 0 && i < d) atomicAdd(&num[k * d + i], accum_num); - if (emit_Nk && ti == 0 && tj == 0) atomicAdd(&N_k[k], block_Nk); -} - -template -__global__ void m_step_finalize_kernel(const T* __restrict__ N_k, - const T* __restrict__ num, - T* __restrict__ sm_to_cov, - T* __restrict__ weights, - T* __restrict__ means, T reg_covar, - T eps, int n, int d, int K) { - int k = blockIdx.x; - int tid = threadIdx.x; - - T Nk = N_k[k] + T(10) * eps; - T inv_Nk = T(1) / Nk; - - if (tid == 0) weights[k] = Nk / T(n); - - for (int i = tid; i < d; i += blockDim.x) - means[k * d + i] = num[k * d + i] * inv_Nk; - __syncthreads(); - - int total = d * d; - for (int idx = tid; idx < total; idx += blockDim.x) { - int i = idx / d; - int j = idx % d; - if (i > j) continue; - T mi = means[k * d + i]; - T mj = means[k * d + j]; - T v = sm_to_cov[k * d * d + idx] * inv_Nk - mi * mj; - if (i == j) v += reg_covar; - sm_to_cov[k * d * d + idx] = v; - if (i != j) sm_to_cov[k * d * d + j * d + i] = v; - } -} - template __global__ void m_step_finalize_means_kernel(const T* __restrict__ N_k, const T* __restrict__ num, diff --git a/src/rapids_singlecell/squidpy_gpu/_gmm.py b/src/rapids_singlecell/squidpy_gpu/_gmm.py index 200ee91d..0ce6a278 100644 --- a/src/rapids_singlecell/squidpy_gpu/_gmm.py +++ b/src/rapids_singlecell/squidpy_gpu/_gmm.py @@ -9,11 +9,12 @@ Implementation notes -------------------- -- CUDA E-step uses a cached precision Cholesky (computed once per M-step) and a - custom Mahalanobis + responsibility kernel for the common <=64-PC case. +- CUDA E-step uses a cached precision Cholesky (computed once per M-step) and + fused Mahalanobis + responsibility kernels. - CUDA M-step reuses preallocated workspaces and uses a cuBLAS covariance update to avoid the chunked atomic reduction bottleneck. -- ``_precision_cholesky`` is a batched ``inv``+``cholesky`` — no Python K-loop. +- ``_precision_cholesky`` uses batched Cholesky + triangular solve - no + explicit covariance inverse and no Python K-loop. - Convergence is the change in mean log-likelihood. """ @@ -24,9 +25,15 @@ import cupy as cp import numpy as np +from cupyx.scipy.linalg import solve_triangular from cupyx.scipy.special import logsumexp _LOG_2PI = float(np.log(2.0 * np.pi)) +# Common <=64-PC widths use scalar row kernels with compile-time specialization. +# Mid-width float32 embeddings use a 64-column tiled CUDA kernel; high-width +# embeddings switch to the cuBLAS E-step. +_CUDA_E_STEP_MAX_D = 512 +_CUDA_CUBLAS_E_STEP_MIN_D = 257 def gmm_fit_predict( @@ -63,8 +70,9 @@ def gmm_fit_predict( backend ``"auto"`` (default) uses the nanobind/CUDA EM backend when the compiled extension is available, otherwise falls back to CuPy. ``"cupy"`` uses - CuPy + cuBLAS for the covariance update. ``"cuda"`` forces the custom - CUDA kernels for the E-step and the fastest available CUDA M-step. + CuPy + cuBLAS for the covariance update. 
``"cuda"`` uses fused CUDA + E-step kernels where available and keeps the fastest available CUDA + M-step for wider embeddings. kmeans_n_init Number of cuML KMeans restarts for ``init="kmeans"``. The default ``1`` matches sklearn's GaussianMixture default and keeps cellcharter fast; @@ -79,7 +87,7 @@ def gmm_fit_predict( K = int(n_components) backend = _resolve_backend(backend) - if backend == "cuda" and X.shape[1] <= 64: + if backend == "cuda" and _use_cuda_e_step(int(X.shape[1]), X.dtype): return _fit_predict_cuda( X, K, @@ -157,7 +165,7 @@ def _fit_predict_cuda( converged = True break prev_ll = ll_value - workspace.m_step_cublas(resp, weights, means, covariances, reg_covar) + workspace.m_step(resp, weights, means, covariances, reg_covar) prec_chol, log_det_prec_half = _precision_cholesky(covariances) if not converged: @@ -174,6 +182,7 @@ def __init__(self, X: cp.ndarray, K: int): self.d = int(d) self.K = int(K) self.stream = cp.cuda.get_current_stream().ptr + self.cublas_handle = cp.cuda.device.get_cublas_handle() self.log_prob = cp.empty((n, K), dtype=X.dtype) self.resp = cp.empty((n, K), dtype=X.dtype) self.ll_per_cell = cp.empty(n, dtype=X.dtype) @@ -186,13 +195,24 @@ def e_step( means: cp.ndarray, prec_chol: cp.ndarray, log_det_half: cp.ndarray, + ) -> tuple[cp.ndarray, cp.ndarray]: + if _use_cuda_cublas_e_step(self.d, self.X.dtype): + return self.e_step_cublas(weights, means, prec_chol, log_det_half) + return self.e_step_fused(weights, means, prec_chol, log_det_half) + + def e_step_fused( + self, + weights: cp.ndarray, + means: cp.ndarray, + prec_chol: cp.ndarray, + log_det_half: cp.ndarray, ) -> tuple[cp.ndarray, cp.ndarray]: self._gc.e_step( self.X, - cp.ascontiguousarray(weights.astype(self.X.dtype, copy=False)), + _cuda_arg(weights, self.X.dtype), cp.ascontiguousarray(means), cp.ascontiguousarray(prec_chol), - cp.ascontiguousarray(log_det_half.astype(self.X.dtype, copy=False)), + _cuda_arg(log_det_half, self.X.dtype), self.log_prob, self.resp, self.ll_per_cell, @@ -203,30 +223,37 @@ def e_step( ) return self.resp, self.ll_per_cell.mean() - def m_step( + def e_step_cublas( self, - resp: cp.ndarray, weights: cp.ndarray, means: cp.ndarray, - covariances: cp.ndarray, - reg_covar: float, - ) -> None: - self._gc.m_step( - cp.ascontiguousarray(resp), + prec_chol: cp.ndarray, + log_det_half: cp.ndarray, + ) -> tuple[cp.ndarray, cp.ndarray]: + if not hasattr(self, "_centered"): + self._centered = cp.empty_like(self.X) + if not hasattr(self, "_e_step_y"): + self._e_step_y = cp.empty_like(self.X) + self._gc.e_step_cublas( self.X, - weights, - means, - covariances, - self.N_k, - self.num, + _cuda_arg(weights, self.X.dtype), + cp.ascontiguousarray(means), + cp.ascontiguousarray(prec_chol), + _cuda_arg(log_det_half, self.X.dtype), + self._centered, + self._e_step_y, + self.log_prob, + self.resp, + self.ll_per_cell, n=self.n, d=self.d, K=self.K, - reg_covar=float(reg_covar), stream=self.stream, + handle=self.cublas_handle, ) + return self.resp, self.ll_per_cell.mean() - def m_step_cublas( + def m_step( self, resp: cp.ndarray, weights: cp.ndarray, @@ -238,7 +265,7 @@ def m_step_cublas( self._ones = cp.ones(self.n, dtype=self.X.dtype) if not hasattr(self, "_centered"): self._centered = cp.empty_like(self.X) - self._gc.m_step_cublas( + self._gc.m_step( cp.ascontiguousarray(resp), self.X, self._ones, @@ -253,6 +280,7 @@ def m_step_cublas( K=self.K, reg_covar=float(reg_covar), stream=self.stream, + handle=self.cublas_handle, ) @@ -278,6 +306,25 @@ def _get_gmm_cuda(): ) from err +def 
_cuda_arg(array: cp.ndarray, dtype) -> cp.ndarray: + return cp.ascontiguousarray(array.astype(dtype, copy=False)) + + +def _use_cuda_e_step(d: int, dtype=None) -> bool: + if not 0 < d <= _CUDA_E_STEP_MAX_D: + return False + if dtype is None: + return True + dtype = np.dtype(dtype) + return d <= 64 or dtype == np.dtype("float32") + + +def _use_cuda_cublas_e_step(d: int, dtype) -> bool: + return _CUDA_CUBLAS_E_STEP_MIN_D <= d <= _CUDA_E_STEP_MAX_D and np.dtype( + dtype + ) == np.dtype("float32") + + def _initialize( X: cp.ndarray, K: int, @@ -367,11 +414,17 @@ def _precision_cholesky( ) -> tuple[cp.ndarray, cp.ndarray]: """Return ``(prec_chol, log|Σ⁻¹|/2)`` where ``prec_chol @ prec_chol.T = Σ⁻¹``. - Two batched LAPACK calls — no Python K-loop. + This mirrors sklearn's precision-Cholesky orientation: ``prec_chol`` is an + upper-triangular factor built from the covariance Cholesky without forming + the explicit covariance inverse. """ - precisions = cp.linalg.inv(covariances) - prec_chol = cp.linalg.cholesky(precisions) - log_det_half = cp.sum(cp.log(cp.diagonal(prec_chol, axis1=1, axis2=2)), axis=1) + cov_chol = cp.linalg.cholesky(covariances) + eye = cp.broadcast_to( + cp.eye(covariances.shape[-1], dtype=covariances.dtype), covariances.shape + ) + cov_chol_inv = solve_triangular(cov_chol, eye, lower=True) + prec_chol = cp.ascontiguousarray(cov_chol_inv.transpose(0, 2, 1)) + log_det_half = -cp.sum(cp.log(cp.diagonal(cov_chol, axis1=1, axis2=2)), axis=1) return prec_chol, log_det_half @@ -390,29 +443,10 @@ def _e_step( backend: str = "cupy", ) -> tuple[cp.ndarray, cp.ndarray]: n, d = X.shape - if backend == "cuda" and d <= 64: - _gc = _get_gmm_cuda() - - K = means.shape[0] - log_prob = cp.empty((n, K), dtype=X.dtype) - resp = cp.empty((n, K), dtype=X.dtype) - ll_per_cell = cp.empty(n, dtype=X.dtype) - _gc.e_step( - cp.ascontiguousarray(X), - cp.ascontiguousarray(weights.astype(X.dtype)), - cp.ascontiguousarray(means), - cp.ascontiguousarray(prec_chol), - cp.ascontiguousarray(log_det_half.astype(X.dtype)), - log_prob, - resp, - ll_per_cell, - n=int(n), - d=int(d), - K=int(K), - stream=cp.cuda.get_current_stream().ptr, - ) - return resp, ll_per_cell.mean() K = means.shape[0] + if backend == "cuda" and _use_cuda_e_step(int(d), X.dtype): + workspace = _GMMCudaWorkspace(cp.ascontiguousarray(X), int(K)) + return workspace.e_step(weights, means, prec_chol, log_det_half) log_prob = cp.empty((n, K), dtype=X.dtype) half_d_log2pi = X.dtype.type(0.5 * d * _LOG_2PI) @@ -451,7 +485,7 @@ def _m_step( stream = cp.cuda.get_current_stream().ptr ones_ws = cp.ones(n, dtype=X.dtype) centered_ws = cp.empty_like(X) - _gc.m_step_cublas( + _gc.m_step( resp, X, ones_ws, @@ -466,6 +500,7 @@ def _m_step( K=int(K), reg_covar=float(reg_covar), stream=stream, + handle=cp.cuda.device.get_cublas_handle(), ) return weights, means, covariances diff --git a/tests/test_gmm.py b/tests/test_gmm.py index 388e0693..19dc1b9d 100644 --- a/tests/test_gmm.py +++ b/tests/test_gmm.py @@ -4,12 +4,16 @@ import numpy as np import pytest from sklearn.metrics import adjusted_rand_score as ARI +from sklearn.mixture import GaussianMixture from rapids_singlecell.squidpy_gpu._gmm import ( _e_step, + _GMMCudaWorkspace, _m_step, _precision_cholesky, _resolve_backend, + _use_cuda_cublas_e_step, + _use_cuda_e_step, gmm_fit_predict, ) @@ -25,6 +29,20 @@ def _well_separated(n_per: int, K: int, d: int, sep: float, seed: int): return X[perm], y[perm] +def _pca_like_mixture(n_per: int, K: int, d: int, seed: int) -> np.ndarray: + rng = 
np.random.default_rng(seed) + centers = rng.normal(scale=5.0, size=(K, d)) + rows = [] + for k in range(K): + # A compact low-rank perturbation of the identity gives each synthetic + # cell state its own correlated PCA-space geometry. + factors = rng.normal(scale=0.3 + 0.05 * k, size=(d, 3)) + cov = np.eye(d) * (0.35 + 0.03 * k) + factors @ factors.T + rows.append(rng.multivariate_normal(centers[k], cov, size=n_per)) + X = np.vstack(rows).astype(np.float32) + return np.ascontiguousarray(X[rng.permutation(len(X))]) + + def test_kmeans_init_recovers_well_separated_clusters(): """kmeans init should land at near-truth on well-separated data.""" X_np, y = _well_separated(n_per=300, K=5, d=20, sep=6.0, seed=0) @@ -34,6 +52,36 @@ def test_kmeans_init_recovers_well_separated_clusters(): assert ARI(y, labels) >= 0.95 +@pytest.mark.parametrize("backend", ["cupy", "cuda"]) +def test_full_cov_gmm_matches_sklearn_on_singlecell_embedding(backend): + if backend == "cuda" and _resolve_backend("auto") != "cuda": + pytest.skip("_gmm_cuda extension is not available") + + X = _pca_like_mixture(n_per=250, K=5, d=16, seed=7) + sk_labels = GaussianMixture( + n_components=5, + covariance_type="full", + tol=1e-3, + reg_covar=1e-6, + max_iter=100, + n_init=1, + init_params="kmeans", + random_state=0, + ).fit_predict(X) + rsc_labels = cp.asnumpy( + gmm_fit_predict( + cp.asarray(X), + n_components=5, + random_state=0, + init="kmeans", + backend=backend, + kmeans_n_init=1, + ) + ) + + assert ARI(sk_labels, rsc_labels) >= 0.99 + + def test_random_from_data_init_runs(): """random_from_data may land at a worse local optimum than kmeans, but should still produce a non-trivial partition on well-separated data.""" @@ -92,6 +140,27 @@ def test_auto_backend_uses_cuda_when_available(): assert _resolve_backend("cuda") == "cuda" +def test_cuda_e_step_routing_uses_fused_kernels_for_wide_embeddings(): + assert _use_cuda_e_step(16) + assert _use_cuda_e_step(32) + assert _use_cuda_e_step(50) + assert _use_cuda_e_step(64) + assert _use_cuda_e_step(80) + assert _use_cuda_e_step(96) + assert _use_cuda_e_step(128) + assert _use_cuda_e_step(256) + assert _use_cuda_e_step(384) + assert _use_cuda_e_step(512) + assert not _use_cuda_e_step(768) + assert _use_cuda_e_step(512, cp.float32) + assert not _use_cuda_e_step(512, cp.float64) + assert _use_cuda_e_step(64, cp.float64) + assert not _use_cuda_cublas_e_step(256, cp.float32) + assert _use_cuda_cublas_e_step(384, cp.float32) + assert _use_cuda_cublas_e_step(512, cp.float32) + assert not _use_cuda_cublas_e_step(512, cp.float64) + + def test_n_components_one_returns_single_label(): rng = np.random.default_rng(0) X = cp.asarray(rng.standard_normal((200, 4)).astype(np.float32)) @@ -131,6 +200,116 @@ def test_cuda_backend_matches_cupy_steps(): r_c, ll_c = _e_step(X, weights, means, prec_chol, log_det_half, backend="cupy") r_g, ll_g = _e_step(X, weights, means, prec_chol, log_det_half, backend="cuda") + r_f, ll_f = _GMMCudaWorkspace(X, K).e_step_fused( + weights, means, prec_chol, log_det_half + ) assert cp.max(cp.abs(r_c - r_g)).item() < 1e-4 assert cp.abs(ll_c - ll_g).item() < 1e-4 + assert cp.max(cp.abs(r_c - r_f)).item() < 1e-4 + assert cp.abs(ll_c - ll_f).item() < 1e-4 + + +def test_cuda_large_e_step_matches_cupy_for_large_feature_count(): + if _resolve_backend("auto") != "cuda": + pytest.skip("_gmm_cuda extension is not available") + + rng = cp.random.RandomState(2) + n, d, K = 2048, 96, 4 + X = rng.standard_normal((n, d), dtype=cp.float32) + weights = cp.asarray([0.15, 0.2, 0.3, 0.35], 
dtype=cp.float32) + means = rng.standard_normal((K, d), dtype=cp.float32) + A = rng.standard_normal((K, d, d), dtype=cp.float32) + cov = (A @ A.transpose(0, 2, 1)) / d + cp.eye(d, dtype=cp.float32)[None] * 0.5 + prec_chol, log_det_half = _precision_cholesky(cov) + + r_c, ll_c = _e_step(X, weights, means, prec_chol, log_det_half, backend="cupy") + r_g, ll_g = _e_step(X, weights, means, prec_chol, log_det_half, backend="cuda") + + assert cp.max(cp.abs(r_c - r_g)).item() < 5e-4 + assert cp.abs(ll_c - ll_g).item() < 5e-4 + + +def test_cuda_512_e_step_matches_cupy_for_cublas_route(): + if _resolve_backend("auto") != "cuda": + pytest.skip("_gmm_cuda extension is not available") + + rng = cp.random.RandomState(5) + n, d, K = 384, 512, 3 + X = rng.standard_normal((n, d), dtype=cp.float32) + weights = cp.asarray([0.2, 0.3, 0.5], dtype=cp.float32) + means = rng.standard_normal((K, d), dtype=cp.float32) + A = rng.standard_normal((K, d, d), dtype=cp.float32) + cov = (A @ A.transpose(0, 2, 1)) / d + cp.eye(d, dtype=cp.float32)[None] * 0.5 + prec_chol, log_det_half = _precision_cholesky(cov) + + r_c, ll_c = _e_step(X, weights, means, prec_chol, log_det_half, backend="cupy") + r_g, ll_g = _e_step(X, weights, means, prec_chol, log_det_half, backend="cuda") + r_b, ll_b = _GMMCudaWorkspace(X, K).e_step_cublas( + weights, means, prec_chol, log_det_half + ) + + assert cp.max(cp.abs(r_c - r_g)).item() < 1e-3 + assert cp.abs(ll_c - ll_g).item() < 1e-3 + assert cp.max(cp.abs(r_c - r_b)).item() < 1e-3 + assert cp.abs(ll_c - ll_b).item() < 1e-3 + + +def test_cuda_fixed_e_step_matches_cupy_for_medium_regime(): + if _resolve_backend("auto") != "cuda": + pytest.skip("_gmm_cuda extension is not available") + + rng = cp.random.RandomState(4) + n, d, K = 1024, 16, 8 + X = rng.standard_normal((n, d), dtype=cp.float32) + weights = cp.full(K, 1 / K, dtype=cp.float32) + means = rng.standard_normal((K, d), dtype=cp.float32) + A = rng.standard_normal((K, d, d), dtype=cp.float32) + cov = (A @ A.transpose(0, 2, 1)) / d + cp.eye(d, dtype=cp.float32)[None] * 0.5 + prec_chol, log_det_half = _precision_cholesky(cov) + + r_c, ll_c = _e_step(X, weights, means, prec_chol, log_det_half, backend="cupy") + r_g, ll_g = _e_step(X, weights, means, prec_chol, log_det_half, backend="cuda") + + assert cp.max(cp.abs(r_c - r_g)).item() < 5e-4 + assert cp.abs(ll_c - ll_g).item() < 5e-4 + + +def test_cuda_fused_e_step_matches_cupy_for_50_pc_regime(): + if _resolve_backend("auto") != "cuda": + pytest.skip("_gmm_cuda extension is not available") + + rng = cp.random.RandomState(6) + n, d, K = 1024, 50, 12 + X = rng.standard_normal((n, d), dtype=cp.float32) + weights = cp.full(K, 1 / K, dtype=cp.float32) + means = rng.standard_normal((K, d), dtype=cp.float32) + A = rng.standard_normal((K, d, d), dtype=cp.float32) + cov = (A @ A.transpose(0, 2, 1)) / d + cp.eye(d, dtype=cp.float32)[None] * 0.5 + prec_chol, log_det_half = _precision_cholesky(cov) + + r_c, ll_c = _e_step(X, weights, means, prec_chol, log_det_half, backend="cupy") + r_d, ll_d = _GMMCudaWorkspace(X, K).e_step(weights, means, prec_chol, log_det_half) + + assert cp.max(cp.abs(r_c - r_d)).item() < 5e-4 + assert cp.abs(ll_c - ll_d).item() < 5e-4 + + +def test_cuda_backend_runs_large_feature_count(): + if _resolve_backend("auto") != "cuda": + pytest.skip("_gmm_cuda extension is not available") + + rng = np.random.default_rng(3) + X = cp.asarray(rng.standard_normal((360, 80)).astype(np.float32)) + labels = gmm_fit_predict( + X, + n_components=3, + random_state=0, + max_iter=2, + reg_covar=1e-2, 
+ init="random_from_data", + backend="cuda", + ) + + assert labels.shape == (360,) + assert labels.dtype == cp.int32 From c78df79d3dcbd1d1ed3e1f3fdcd8fd4b8eaca1dd Mon Sep 17 00:00:00 2001 From: Intron7 Date: Mon, 11 May 2026 22:53:11 +0200 Subject: [PATCH 7/8] fix 64 bit kernel --- .../_cuda/gmm/kernels_gmm.cuh | 6 +++- src/rapids_singlecell/squidpy_gpu/_gmm.py | 24 ++++++++++--- tests/test_gmm.py | 36 +++++++++++++++++-- 3 files changed, 57 insertions(+), 9 deletions(-) diff --git a/src/rapids_singlecell/_cuda/gmm/kernels_gmm.cuh b/src/rapids_singlecell/_cuda/gmm/kernels_gmm.cuh index 6762b453..c3b061cf 100644 --- a/src/rapids_singlecell/_cuda/gmm/kernels_gmm.cuh +++ b/src/rapids_singlecell/_cuda/gmm/kernels_gmm.cuh @@ -234,9 +234,13 @@ __global__ void e_step_log_prob_from_y_kernel( if (row >= n) return; T mahal = T(0); + T compensation = T(0); for (int col = 0; col < d; ++col) { T v = y[(size_t)row * d + col]; - mahal += v * v; + T term = v * v - compensation; + T next = mahal + term; + compensation = (next - mahal) - term; + mahal = next; } T constant = diff --git a/src/rapids_singlecell/squidpy_gpu/_gmm.py b/src/rapids_singlecell/squidpy_gpu/_gmm.py index 0ce6a278..d5751172 100644 --- a/src/rapids_singlecell/squidpy_gpu/_gmm.py +++ b/src/rapids_singlecell/squidpy_gpu/_gmm.py @@ -30,8 +30,8 @@ _LOG_2PI = float(np.log(2.0 * np.pi)) # Common <=64-PC widths use scalar row kernels with compile-time specialization. -# Mid-width float32 embeddings use a 64-column tiled CUDA kernel; high-width -# embeddings switch to the cuBLAS E-step. +# Mid-width float32 embeddings use the tiled CUDA kernel; high-width fp32 and +# wide float64 embeddings use the cuBLAS E-step. _CUDA_E_STEP_MAX_D = 512 _CUDA_CUBLAS_E_STEP_MIN_D = 257 @@ -207,6 +207,11 @@ def e_step_fused( prec_chol: cp.ndarray, log_det_half: cp.ndarray, ) -> tuple[cp.ndarray, cp.ndarray]: + if not _use_cuda_fused_e_step(self.d, self.X.dtype): + raise ValueError( + "The fused CUDA GMM E-step supports float64 only for " + "d <= 64. Use e_step() or e_step_cublas() for this input." 
+ ) self._gc.e_step( self.X, _cuda_arg(weights, self.X.dtype), @@ -315,14 +320,23 @@ def _use_cuda_e_step(d: int, dtype=None) -> bool: return False if dtype is None: return True + return _use_cuda_fused_e_step(d, dtype) or _use_cuda_cublas_e_step(d, dtype) + + +def _use_cuda_fused_e_step(d: int, dtype) -> bool: + if not 0 < d <= _CUDA_E_STEP_MAX_D: + return False dtype = np.dtype(dtype) return d <= 64 or dtype == np.dtype("float32") def _use_cuda_cublas_e_step(d: int, dtype) -> bool: - return _CUDA_CUBLAS_E_STEP_MIN_D <= d <= _CUDA_E_STEP_MAX_D and np.dtype( - dtype - ) == np.dtype("float32") + if not 0 < d <= _CUDA_E_STEP_MAX_D: + return False + dtype = np.dtype(dtype) + if dtype == np.dtype("float32"): + return d >= _CUDA_CUBLAS_E_STEP_MIN_D + return dtype == np.dtype("float64") and d > 64 def _initialize( diff --git a/tests/test_gmm.py b/tests/test_gmm.py index 19dc1b9d..d96d8b7e 100644 --- a/tests/test_gmm.py +++ b/tests/test_gmm.py @@ -14,6 +14,7 @@ _resolve_backend, _use_cuda_cublas_e_step, _use_cuda_e_step, + _use_cuda_fused_e_step, gmm_fit_predict, ) @@ -140,7 +141,7 @@ def test_auto_backend_uses_cuda_when_available(): assert _resolve_backend("cuda") == "cuda" -def test_cuda_e_step_routing_uses_fused_kernels_for_wide_embeddings(): +def test_cuda_e_step_routing_uses_cublas_for_high_d_and_wide_float64(): assert _use_cuda_e_step(16) assert _use_cuda_e_step(32) assert _use_cuda_e_step(50) @@ -153,12 +154,18 @@ def test_cuda_e_step_routing_uses_fused_kernels_for_wide_embeddings(): assert _use_cuda_e_step(512) assert not _use_cuda_e_step(768) assert _use_cuda_e_step(512, cp.float32) - assert not _use_cuda_e_step(512, cp.float64) + assert _use_cuda_e_step(512, cp.float64) assert _use_cuda_e_step(64, cp.float64) + assert not _use_cuda_cublas_e_step(64, cp.float32) + assert not _use_cuda_cublas_e_step(80, cp.float32) + assert not _use_cuda_cublas_e_step(128, cp.float32) assert not _use_cuda_cublas_e_step(256, cp.float32) + assert _use_cuda_fused_e_step(512, cp.float32) + assert not _use_cuda_fused_e_step(128, cp.float64) assert _use_cuda_cublas_e_step(384, cp.float32) assert _use_cuda_cublas_e_step(512, cp.float32) - assert not _use_cuda_cublas_e_step(512, cp.float64) + assert _use_cuda_cublas_e_step(128, cp.float64) + assert _use_cuda_cublas_e_step(512, cp.float64) def test_n_components_one_returns_single_label(): @@ -255,6 +262,29 @@ def test_cuda_512_e_step_matches_cupy_for_cublas_route(): assert cp.abs(ll_c - ll_b).item() < 1e-3 +def test_cuda_float64_wide_e_step_uses_cublas_route(): + if _resolve_backend("auto") != "cuda": + pytest.skip("_gmm_cuda extension is not available") + + rng = cp.random.RandomState(7) + n, d, K = 256, 128, 3 + X = rng.standard_normal((n, d), dtype=cp.float64) + weights = cp.asarray([0.2, 0.3, 0.5], dtype=cp.float64) + means = rng.standard_normal((K, d), dtype=cp.float64) + A = rng.standard_normal((K, d, d), dtype=cp.float64) + cov = (A @ A.transpose(0, 2, 1)) / d + cp.eye(d, dtype=cp.float64)[None] * 0.5 + prec_chol, log_det_half = _precision_cholesky(cov) + + r_c, ll_c = _e_step(X, weights, means, prec_chol, log_det_half, backend="cupy") + workspace = _GMMCudaWorkspace(X, K) + r_g, ll_g = workspace.e_step(weights, means, prec_chol, log_det_half) + + assert cp.max(cp.abs(r_c - r_g)).item() < 1e-12 + assert cp.abs(ll_c - ll_g).item() < 1e-12 + with pytest.raises(ValueError, match="fused CUDA GMM E-step"): + workspace.e_step_fused(weights, means, prec_chol, log_det_half) + + def test_cuda_fixed_e_step_matches_cupy_for_medium_regime(): if _resolve_backend("auto") 
!= "cuda": pytest.skip("_gmm_cuda extension is not available") From 865afb47ee860fb08b48d2d5a2cfee806de312ed Mon Sep 17 00:00:00 2001 From: Intron7 Date: Tue, 12 May 2026 20:51:51 +0200 Subject: [PATCH 8/8] refactor for library --- src/rapids_singlecell/squidpy_gpu/_gmm.py | 819 ++++++++++---------- src/rapids_singlecell/squidpy_gpu/_niche.py | 34 +- tests/test_gmm.py | 400 +++++++--- tests/test_niche.py | 38 +- 4 files changed, 751 insertions(+), 540 deletions(-) diff --git a/src/rapids_singlecell/squidpy_gpu/_gmm.py b/src/rapids_singlecell/squidpy_gpu/_gmm.py index d5751172..f47a066e 100644 --- a/src/rapids_singlecell/squidpy_gpu/_gmm.py +++ b/src/rapids_singlecell/squidpy_gpu/_gmm.py @@ -1,41 +1,171 @@ -"""Minimal GMM (full covariance) for the cellcharter niche flavor. - -Mirrors :class:`sklearn.mixture.GaussianMixture` with -``covariance_type="full"``. Two init strategies are exposed: ``"kmeans"`` (default, -cuML KMeans warm-start) and ``"random_from_data"`` (sklearn-equivalent for parity -testing). The default ``"auto"`` backend uses a nanobind/CUDA EM path when the -compiled extension is available, with a CuPy fallback for environments without -compiled extensions. - -Implementation notes --------------------- -- CUDA E-step uses a cached precision Cholesky (computed once per M-step) and - fused Mahalanobis + responsibility kernels. -- CUDA M-step reuses preallocated workspaces and uses a cuBLAS covariance - update to avoid the chunked atomic reduction bottleneck. -- ``_precision_cholesky`` uses batched Cholesky + triangular solve - no - explicit covariance inverse and no Python K-loop. -- Convergence is the change in mean log-likelihood. +"""Full-covariance GMM for the CellCharter niche flavor. + +The public behavior mirrors :class:`sklearn.mixture.GaussianMixture` with +``covariance_type="full"``. EM is CUDA-only: CuPy is used for array handling and +the precision-Cholesky factorization, while E-step and M-step work is delegated +to the nanobind/CUDA extension. """ from __future__ import annotations -import importlib from typing import Literal import cupy as cp import numpy as np from cupyx.scipy.linalg import solve_triangular -from cupyx.scipy.special import logsumexp -_LOG_2PI = float(np.log(2.0 * np.pi)) -# Common <=64-PC widths use scalar row kernels with compile-time specialization. -# Mid-width float32 embeddings use the tiled CUDA kernel; high-width fp32 and -# wide float64 embeddings use the cuBLAS E-step. -_CUDA_E_STEP_MAX_D = 512 +from rapids_singlecell._cuda import _gmm_cuda as _gc + +_GMMInit = Literal["kmeans", "random_from_data", "sklearn_kmeans"] +_EStepRoute = Literal["fused", "cublas"] + +_KMEANS_MAX_ITER = 100 +_SKLEARN_SEEDED_KMEANS_MAX_ITER = 300 + +# Fused kernels cover the CellCharter regime. Wider float32 embeddings and +# float64 embeddings above 64 dimensions use the cuBLAS E-step. 
+_CUDA_FUSED_E_STEP_MAX_D = 512 +_CUDA_FUSED_FLOAT64_MAX_D = 64 _CUDA_CUBLAS_E_STEP_MIN_D = 257 +def _allocate_m_step_workspace(X: cp.ndarray, K: int) -> dict[str, cp.ndarray]: + n, d = X.shape + return { + "ones": cp.ones(n, dtype=X.dtype), + "effective_counts": cp.empty(K, dtype=X.dtype), + "weighted_sums": cp.empty((K, d), dtype=X.dtype), + "centered": cp.empty_like(X), + } + + +def _allocate_em_workspace( + X: cp.ndarray, K: int, e_step_route: _EStepRoute +) -> dict[str, cp.ndarray]: + n = X.shape[0] + workspace = { + "log_prob": cp.empty((n, K), dtype=X.dtype), + "responsibilities": cp.empty((n, K), dtype=X.dtype), + "ll_per_cell": cp.empty(n, dtype=X.dtype), + **_allocate_m_step_workspace(X, K), + } + if e_step_route == "cublas": + workspace["e_step_y"] = cp.empty_like(X) + return workspace + + +def _e_step( + X: cp.ndarray, + weights: cp.ndarray, + means: cp.ndarray, + prec_chol: cp.ndarray, + log_det_half: cp.ndarray, + *, + log_prob: cp.ndarray, + responsibilities: cp.ndarray, + ll_per_cell: cp.ndarray, + centered: cp.ndarray, + e_step_y: cp.ndarray | None, + e_step_route: _EStepRoute, + stream: int, + handle: int, +) -> tuple[cp.ndarray, cp.ndarray]: + if e_step_route == "cublas": + return _e_step_cublas( + X, + weights, + means, + prec_chol, + log_det_half, + centered=centered, + e_step_y=e_step_y, + log_prob=log_prob, + responsibilities=responsibilities, + ll_per_cell=ll_per_cell, + stream=stream, + handle=handle, + ) + return _e_step_fused( + X, + weights, + means, + prec_chol, + log_det_half, + log_prob=log_prob, + responsibilities=responsibilities, + ll_per_cell=ll_per_cell, + stream=stream, + ) + + +def _e_step_fused( + X: cp.ndarray, + weights: cp.ndarray, + means: cp.ndarray, + prec_chol: cp.ndarray, + log_det_half: cp.ndarray, + *, + log_prob: cp.ndarray, + responsibilities: cp.ndarray, + ll_per_cell: cp.ndarray, + stream: int, +) -> tuple[cp.ndarray, cp.ndarray]: + n, d = X.shape + K = int(weights.shape[0]) + _gc.e_step( + X, + weights, + means, + prec_chol, + log_det_half, + log_prob, + responsibilities, + ll_per_cell, + n=int(n), + d=int(d), + K=K, + stream=stream, + ) + return responsibilities, ll_per_cell.mean() + + +def _e_step_cublas( + X: cp.ndarray, + weights: cp.ndarray, + means: cp.ndarray, + prec_chol: cp.ndarray, + log_det_half: cp.ndarray, + *, + centered: cp.ndarray, + e_step_y: cp.ndarray | None, + log_prob: cp.ndarray, + responsibilities: cp.ndarray, + ll_per_cell: cp.ndarray, + stream: int, + handle: int, +) -> tuple[cp.ndarray, cp.ndarray]: + n, d = X.shape + K = int(weights.shape[0]) + _gc.e_step_cublas( + X, + weights, + means, + prec_chol, + log_det_half, + centered, + e_step_y, + log_prob, + responsibilities, + ll_per_cell, + n=int(n), + d=int(d), + K=K, + stream=stream, + handle=handle, + ) + return responsibilities, ll_per_cell.mean() + + def gmm_fit_predict( X: cp.ndarray, n_components: int, @@ -44,8 +174,7 @@ def gmm_fit_predict( max_iter: int = 100, tol: float = 1e-3, reg_covar: float = 1e-6, - init: Literal["kmeans", "random_from_data"] = "kmeans", - backend: Literal["auto", "cupy", "cuda"] = "auto", + init: _GMMInit = "kmeans", kmeans_n_init: int = 1, ) -> cp.ndarray: """Fit a full-covariance GMM and return cluster labels. @@ -53,477 +182,339 @@ def gmm_fit_predict( Parameters ---------- X - Cupy array, shape ``(n_samples, n_features)``, float32 or float64. + GPU matrix with observations in rows and features in columns. n_components - Number of mixture components ``K``. + Number of mixture components. random_state - Seed for initialization. 
+ Seed used by the selected initialization strategy. max_iter - Maximum EM iterations. + Maximum number of EM iterations. tol - Convergence threshold on the change in mean log-likelihood. + Convergence threshold on the mean log-likelihood change. reg_covar - Regularization added to each component covariance diagonal. + Non-negative regularization added to each covariance diagonal. init - ``"kmeans"`` (default) uses cuML KMeans for warm-start; usually much - better than ``"random_from_data"``, which mirrors sklearn for parity. - backend - ``"auto"`` (default) uses the nanobind/CUDA EM backend when the compiled - extension is available, otherwise falls back to CuPy. ``"cupy"`` uses - CuPy + cuBLAS for the covariance update. ``"cuda"`` uses fused CUDA - E-step kernels where available and keeps the fastest available CUDA - M-step for wider embeddings. + Initialization strategy. ``"kmeans"`` uses native cuML KMeans, + ``"random_from_data"`` matches sklearn/Squidpy random-from-data, and + ``"sklearn_kmeans"`` uses sklearn k-means++ seeding followed by cuML + KMeans. kmeans_n_init - Number of cuML KMeans restarts for ``init="kmeans"``. The default ``1`` - matches sklearn's GaussianMixture default and keeps cellcharter fast; - increase this for difficult or noisy initialization landscapes. + Number of cuML KMeans restarts for ``init="kmeans"``. """ - if backend not in {"auto", "cupy", "cuda"}: - raise ValueError("backend must be one of 'auto', 'cupy', or 'cuda'.") + K = int(n_components) + if K < 1: + raise ValueError("n_components must be >= 1.") if int(kmeans_n_init) < 1: raise ValueError("kmeans_n_init must be >= 1.") X = cp.ascontiguousarray(X) - K = int(n_components) - backend = _resolve_backend(backend) - - if backend == "cuda" and _use_cuda_e_step(int(X.shape[1]), X.dtype): - return _fit_predict_cuda( - X, - K, - random_state=random_state, - max_iter=max_iter, - tol=tol, - reg_covar=reg_covar, - init=init, - kmeans_n_init=int(kmeans_n_init), - ) - - weights, means, covariances = _initialize( + weights, means, covariances = _initialize_parameters( X, K, + init=init, random_state=random_state, reg_covar=reg_covar, - init=init, kmeans_n_init=int(kmeans_n_init), - backend=backend, ) - prec_chol, log_det_prec_half = _precision_cholesky(covariances) - - prev_ll = -cp.inf - converged = False - for _ in range(max_iter): - resp, ll = _e_step( - X, weights, means, prec_chol, log_det_prec_half, backend=backend - ) - if cp.abs(ll - prev_ll) < tol: - converged = True - break - prev_ll = ll - weights, means, covariances = _m_step(X, resp, reg_covar, backend=backend) - prec_chol, log_det_prec_half = _precision_cholesky(covariances) - - if not converged: - resp, _ = _e_step( - X, weights, means, prec_chol, log_det_prec_half, backend=backend - ) - return resp.argmax(axis=1).astype(cp.int32) + responsibilities = _run_em( + X, + weights, + means, + covariances, + max_iter=int(max_iter), + tol=float(tol), + reg_covar=float(reg_covar), + ) + return responsibilities.argmax(axis=1).astype(cp.int32) -def _fit_predict_cuda( +def _run_em( X: cp.ndarray, - K: int, + weights: cp.ndarray, + means: cp.ndarray, + covariances: cp.ndarray, *, - random_state: int, max_iter: int, tol: float, reg_covar: float, - init: str, - kmeans_n_init: int, ) -> cp.ndarray: - weights, means, covariances = _initialize( - X, - K, - random_state=random_state, - reg_covar=reg_covar, - init=init, - kmeans_n_init=kmeans_n_init, - backend="cuda", - ) - weights = cp.ascontiguousarray(weights) - means = cp.ascontiguousarray(means) - covariances = 
cp.ascontiguousarray(covariances) - workspace = _GMMCudaWorkspace(X, K) - prec_chol, log_det_prec_half = _precision_cholesky(covariances) - - prev_ll = -np.inf - converged = False + n, d = X.shape + K = int(weights.shape[0]) + stream = cp.cuda.get_current_stream().ptr + handle = cp.cuda.device.get_cublas_handle() + e_step_route = _choose_e_step(int(d), X.dtype) + workspace = _allocate_em_workspace(X, K, e_step_route) + + prec_chol, log_det_half = _precision_cholesky(covariances) + previous_ll = -np.inf + for _ in range(max_iter): - resp, ll = workspace.e_step(weights, means, prec_chol, log_det_prec_half) - ll_value = float(ll) - if abs(ll_value - prev_ll) < tol: - converged = True - break - prev_ll = ll_value - workspace.m_step(resp, weights, means, covariances, reg_covar) - prec_chol, log_det_prec_half = _precision_cholesky(covariances) - - if not converged: - resp, _ = workspace.e_step(weights, means, prec_chol, log_det_prec_half) - return resp.argmax(axis=1).astype(cp.int32) - - -class _GMMCudaWorkspace: - def __init__(self, X: cp.ndarray, K: int): - n, d = X.shape - self._gc = _get_gmm_cuda() - self.X = X - self.n = int(n) - self.d = int(d) - self.K = int(K) - self.stream = cp.cuda.get_current_stream().ptr - self.cublas_handle = cp.cuda.device.get_cublas_handle() - self.log_prob = cp.empty((n, K), dtype=X.dtype) - self.resp = cp.empty((n, K), dtype=X.dtype) - self.ll_per_cell = cp.empty(n, dtype=X.dtype) - self.N_k = cp.empty(K, dtype=X.dtype) - self.num = cp.empty((K, d), dtype=X.dtype) - - def e_step( - self, - weights: cp.ndarray, - means: cp.ndarray, - prec_chol: cp.ndarray, - log_det_half: cp.ndarray, - ) -> tuple[cp.ndarray, cp.ndarray]: - if _use_cuda_cublas_e_step(self.d, self.X.dtype): - return self.e_step_cublas(weights, means, prec_chol, log_det_half) - return self.e_step_fused(weights, means, prec_chol, log_det_half) - - def e_step_fused( - self, - weights: cp.ndarray, - means: cp.ndarray, - prec_chol: cp.ndarray, - log_det_half: cp.ndarray, - ) -> tuple[cp.ndarray, cp.ndarray]: - if not _use_cuda_fused_e_step(self.d, self.X.dtype): - raise ValueError( - "The fused CUDA GMM E-step supports float64 only for " - "d <= 64. Use e_step() or e_step_cublas() for this input." 
- ) - self._gc.e_step( - self.X, - _cuda_arg(weights, self.X.dtype), - cp.ascontiguousarray(means), - cp.ascontiguousarray(prec_chol), - _cuda_arg(log_det_half, self.X.dtype), - self.log_prob, - self.resp, - self.ll_per_cell, - n=self.n, - d=self.d, - K=self.K, - stream=self.stream, - ) - return self.resp, self.ll_per_cell.mean() - - def e_step_cublas( - self, - weights: cp.ndarray, - means: cp.ndarray, - prec_chol: cp.ndarray, - log_det_half: cp.ndarray, - ) -> tuple[cp.ndarray, cp.ndarray]: - if not hasattr(self, "_centered"): - self._centered = cp.empty_like(self.X) - if not hasattr(self, "_e_step_y"): - self._e_step_y = cp.empty_like(self.X) - self._gc.e_step_cublas( - self.X, - _cuda_arg(weights, self.X.dtype), - cp.ascontiguousarray(means), - cp.ascontiguousarray(prec_chol), - _cuda_arg(log_det_half, self.X.dtype), - self._centered, - self._e_step_y, - self.log_prob, - self.resp, - self.ll_per_cell, - n=self.n, - d=self.d, - K=self.K, - stream=self.stream, - handle=self.cublas_handle, + responsibilities, mean_ll = _e_step( + X, + weights, + means, + prec_chol, + log_det_half, + log_prob=workspace["log_prob"], + responsibilities=workspace["responsibilities"], + ll_per_cell=workspace["ll_per_cell"], + centered=workspace["centered"], + e_step_y=workspace.get("e_step_y"), + e_step_route=e_step_route, + stream=stream, + handle=handle, ) - return self.resp, self.ll_per_cell.mean() - - def m_step( - self, - resp: cp.ndarray, - weights: cp.ndarray, - means: cp.ndarray, - covariances: cp.ndarray, - reg_covar: float, - ) -> None: - if not hasattr(self, "_ones"): - self._ones = cp.ones(self.n, dtype=self.X.dtype) - if not hasattr(self, "_centered"): - self._centered = cp.empty_like(self.X) - self._gc.m_step( - cp.ascontiguousarray(resp), - self.X, - self._ones, + mean_ll = float(mean_ll) + if abs(mean_ll - previous_ll) < tol: + return responsibilities + + previous_ll = mean_ll + _m_step( + X, + responsibilities, weights, means, covariances, - self.N_k, - self.num, - self._centered, - n=self.n, - d=self.d, - K=self.K, - reg_covar=float(reg_covar), - stream=self.stream, - handle=self.cublas_handle, + reg_covar=reg_covar, + ones=workspace["ones"], + effective_counts=workspace["effective_counts"], + weighted_sums=workspace["weighted_sums"], + centered=workspace["centered"], + stream=stream, + handle=handle, ) + prec_chol, log_det_half = _precision_cholesky(covariances) - -def _resolve_backend(backend: str) -> str: - if backend == "cupy": - return backend - try: - _get_gmm_cuda() - except ImportError: - if backend == "cuda": - raise - return "cupy" - return "cuda" - - -def _get_gmm_cuda(): - try: - return importlib.import_module("rapids_singlecell._cuda._gmm_cuda") - except ImportError as err: - raise ImportError( - "The _gmm_cuda extension is not available. Build rapids-singlecell " - "with CUDA extensions or use backend='cupy'." 
- ) from err - - -def _cuda_arg(array: cp.ndarray, dtype) -> cp.ndarray: - return cp.ascontiguousarray(array.astype(dtype, copy=False)) - - -def _use_cuda_e_step(d: int, dtype=None) -> bool: - if not 0 < d <= _CUDA_E_STEP_MAX_D: - return False - if dtype is None: - return True - return _use_cuda_fused_e_step(d, dtype) or _use_cuda_cublas_e_step(d, dtype) + responsibilities, _ = _e_step( + X, + weights, + means, + prec_chol, + log_det_half, + log_prob=workspace["log_prob"], + responsibilities=workspace["responsibilities"], + ll_per_cell=workspace["ll_per_cell"], + centered=workspace["centered"], + e_step_y=workspace.get("e_step_y"), + e_step_route=e_step_route, + stream=stream, + handle=handle, + ) + return responsibilities -def _use_cuda_fused_e_step(d: int, dtype) -> bool: - if not 0 < d <= _CUDA_E_STEP_MAX_D: - return False - dtype = np.dtype(dtype) - return d <= 64 or dtype == np.dtype("float32") +def _initialize_parameters( + X: cp.ndarray, + K: int, + *, + init: str, + random_state: int, + reg_covar: float, + kmeans_n_init: int, +) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray]: + if init == "random_from_data": + return _random_from_data_init(X, K, random_state, reg_covar) + if init not in ("kmeans", "sklearn_kmeans"): + raise ValueError( + "init must be 'kmeans', 'random_from_data', or " + f"'sklearn_kmeans', got {init!r}" + ) -def _use_cuda_cublas_e_step(d: int, dtype) -> bool: - if not 0 < d <= _CUDA_E_STEP_MAX_D: - return False - dtype = np.dtype(dtype) - if dtype == np.dtype("float32"): - return d >= _CUDA_CUBLAS_E_STEP_MIN_D - return dtype == np.dtype("float64") and d > 64 + labels, centers = _fit_kmeans( + X, + K, + init=init, + random_state=random_state, + kmeans_n_init=kmeans_n_init, + ) + return _parameters_from_labels(X, labels, centers, reg_covar) -def _initialize( +def _random_from_data_init( X: cp.ndarray, K: int, - *, random_state: int, reg_covar: float, - init: str, - kmeans_n_init: int, - backend: str = "cupy", ) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray]: n, d = X.shape + rng = np.random.RandomState(random_state) + idx = cp.asarray(rng.choice(n, size=K, replace=False)) eye_reg = reg_covar * cp.eye(d, dtype=X.dtype) + return ( + cp.full(K, 1.0 / K, dtype=X.dtype), + X[idx].copy(), + cp.broadcast_to(eye_reg, (K, d, d)).copy(), + ) - if init == "random_from_data": - # sklearn parity: pick K rows as means, equal weights, reg-only covariance. 
- rng = np.random.default_rng(random_state) - idx = cp.asarray(rng.choice(n, size=K, replace=False)) - means = X[idx].copy() - weights = cp.full(K, 1.0 / K, dtype=X.dtype) - covariances = cp.broadcast_to(eye_reg, (K, d, d)).copy() - return weights, means, covariances - - if init != "kmeans": - raise ValueError(f"init must be 'kmeans' or 'random_from_data', got {init!r}") +def _fit_kmeans( + X: cp.ndarray, + K: int, + *, + init: str, + random_state: int, + kmeans_n_init: int, +) -> tuple[cp.ndarray, cp.ndarray]: from cuml.cluster import KMeans + kwargs = {} + if init == "sklearn_kmeans": + from sklearn.cluster import kmeans_plusplus + + centers, _ = kmeans_plusplus(cp.asnumpy(X), K, random_state=random_state) + kwargs["init"] = cp.asarray(centers, dtype=X.dtype) + kmeans_n_init = 1 + max_iter = _SKLEARN_SEEDED_KMEANS_MAX_ITER + else: + max_iter = _KMEANS_MAX_ITER + km = KMeans( n_clusters=K, random_state=random_state, n_init=int(kmeans_n_init), - max_iter=100, + max_iter=max_iter, + **kwargs, ) km.fit(X) - labels = cp.asarray(km.labels_) - means = cp.asarray(km.cluster_centers_, dtype=X.dtype) - - if backend == "cuda": - return _initialize_covariances_cuda(X, labels, means, reg_covar) - - weights = cp.zeros(K, dtype=X.dtype) - covariances = cp.empty((K, d, d), dtype=X.dtype) - for k in range(K): - mask = labels == k - cnt = int(mask.sum()) - if cnt == 0: - # KMeans can return empty clusters; fall back to a tiny uniform component. - weights[k] = 1.0 / n - covariances[k] = eye_reg - continue - weights[k] = cnt / n - diff = X[mask] - means[k] - covariances[k] = (diff.T @ diff) / cnt + eye_reg - return weights, means, covariances + return ( + cp.asarray(km.labels_).astype(cp.int64, copy=False), + cp.asarray(km.cluster_centers_, dtype=X.dtype), + ) -def _initialize_covariances_cuda( +def _parameters_from_labels( X: cp.ndarray, labels: cp.ndarray, means_init: cp.ndarray, reg_covar: float, ) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray]: - n = X.shape[0] - K = means_init.shape[0] + n, d = X.shape + K = int(means_init.shape[0]) + weights = cp.empty(K, dtype=X.dtype) + means = cp.empty((K, d), dtype=X.dtype) + covariances = cp.empty((K, d, d), dtype=X.dtype) + responsibilities = cp.zeros((n, K), dtype=X.dtype) + workspace = _allocate_m_step_workspace(X, K) + responsibilities[cp.arange(n), labels] = X.dtype.type(1.0) + _m_step( + X, + responsibilities, + weights, + means, + covariances, + reg_covar=reg_covar, + ones=workspace["ones"], + effective_counts=workspace["effective_counts"], + weighted_sums=workspace["weighted_sums"], + centered=workspace["centered"], + stream=cp.cuda.get_current_stream().ptr, + handle=cp.cuda.device.get_cublas_handle(), + ) + _restore_empty_components( + weights, + means, + covariances, + labels, + means_init, + n=n, + reg_covar=reg_covar, + ) + return weights, means, covariances - labels = labels.astype(cp.int64, copy=False) - resp = cp.zeros((n, K), dtype=X.dtype) - resp[cp.arange(n), labels] = X.dtype.type(1.0) - weights, means, covariances = _m_step(X, resp, reg_covar, backend="cuda") - counts = cp.bincount(labels, minlength=K).astype(X.dtype, copy=False) +def _restore_empty_components( + weights: cp.ndarray, + means: cp.ndarray, + covariances: cp.ndarray, + labels: cp.ndarray, + means_init: cp.ndarray, + *, + n: int, + reg_covar: float, +) -> None: + """Repair empty cuML KMeans components before EM starts. 
+ + sklearn's GMM init estimates parameters from hard KMeans responsibilities + and adds ``10 * eps`` to component counts, relying on sklearn KMeans to + avoid empty final labels in normal cases. cuML can still hand back an empty + component, so keep its center, give it a tiny finite weight, and use the + regularized identity covariance instead of letting the M-step create a + zero-mean component from an empty responsibility column. + """ + K, d = means_init.shape + counts = cp.bincount(labels, minlength=int(K)).astype(means_init.dtype, copy=False) empty = counts == 0 - eye_reg = reg_covar * cp.eye(X.shape[1], dtype=X.dtype) + eye_reg = reg_covar * cp.eye(d, dtype=means_init.dtype) - weights = cp.where(empty, X.dtype.type(1.0 / n), counts / n) - means = cp.where(empty[:, None], means_init, means) - covariances = cp.where( + weights[...] = cp.where(empty, means_init.dtype.type(1.0 / n), counts / n) + means[...] = cp.where(empty[:, None], means_init, means) + covariances[...] = cp.where( empty[:, None, None], cp.broadcast_to(eye_reg, covariances.shape), covariances, ) - return weights, means, covariances -def _precision_cholesky( - covariances: cp.ndarray, -) -> tuple[cp.ndarray, cp.ndarray]: - """Return ``(prec_chol, log|Σ⁻¹|/2)`` where ``prec_chol @ prec_chol.T = Σ⁻¹``. - - This mirrors sklearn's precision-Cholesky orientation: ``prec_chol`` is an - upper-triangular factor built from the covariance Cholesky without forming - the explicit covariance inverse. - """ +def _precision_cholesky(covariances: cp.ndarray) -> tuple[cp.ndarray, cp.ndarray]: + """Return sklearn-oriented precision Cholesky without forming an inverse.""" cov_chol = cp.linalg.cholesky(covariances) eye = cp.broadcast_to( - cp.eye(covariances.shape[-1], dtype=covariances.dtype), covariances.shape + cp.eye(covariances.shape[-1], dtype=covariances.dtype), + covariances.shape, ) cov_chol_inv = solve_triangular(cov_chol, eye, lower=True) - prec_chol = cp.ascontiguousarray(cov_chol_inv.transpose(0, 2, 1)) - log_det_half = -cp.sum(cp.log(cp.diagonal(cov_chol, axis1=1, axis2=2)), axis=1) - return prec_chol, log_det_half - - -@cp.fuse(kernel_name="_gmm_log_pdf_const") -def _log_pdf_const(mahal: cp.ndarray, log_det_half: cp.ndarray, half_d_log2pi): - return -0.5 * mahal + log_det_half - half_d_log2pi + return ( + cp.ascontiguousarray(cov_chol_inv.transpose(0, 2, 1)), + -cp.sum( + cp.log(cp.diagonal(cov_chol, axis1=1, axis2=2)), + axis=1, + ), + ) -def _e_step( +def _m_step( X: cp.ndarray, + responsibilities: cp.ndarray, weights: cp.ndarray, means: cp.ndarray, - prec_chol: cp.ndarray, - log_det_half: cp.ndarray, + covariances: cp.ndarray, *, - backend: str = "cupy", -) -> tuple[cp.ndarray, cp.ndarray]: + reg_covar: float, + ones: cp.ndarray, + effective_counts: cp.ndarray, + weighted_sums: cp.ndarray, + centered: cp.ndarray, + stream: int, + handle: int, +) -> None: n, d = X.shape - K = means.shape[0] - if backend == "cuda" and _use_cuda_e_step(int(d), X.dtype): - workspace = _GMMCudaWorkspace(cp.ascontiguousarray(X), int(K)) - return workspace.e_step(weights, means, prec_chol, log_det_half) - - log_prob = cp.empty((n, K), dtype=X.dtype) - half_d_log2pi = X.dtype.type(0.5 * d * _LOG_2PI) - for k in range(K): - # mahal = || (X - μ_k) @ prec_chol[k] ||² - y = (X - means[k]) @ prec_chol[k] - mahal = cp.einsum("ij,ij->i", y, y) - log_prob[:, k] = _log_pdf_const(mahal, log_det_half[k], half_d_log2pi) - log_prob = log_prob + cp.log(weights) - - log_total = logsumexp(log_prob, axis=1, keepdims=True) - resp = cp.exp(log_prob - log_total) - 
return resp, log_total.mean() + K = int(weights.shape[0]) + _gc.m_step( + responsibilities, + X, + ones, + weights, + means, + covariances, + effective_counts, + weighted_sums, + centered, + n=int(n), + d=int(d), + K=K, + reg_covar=float(reg_covar), + stream=stream, + handle=handle, + ) -def _m_step( - X: cp.ndarray, - resp: cp.ndarray, - reg_covar: float, - *, - backend: str = "cupy", -) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray]: - n, d = X.shape - K = resp.shape[1] - - if backend == "cuda": - _gc = _get_gmm_cuda() - - weights = cp.empty(K, dtype=X.dtype) - means = cp.empty((K, d), dtype=X.dtype) - covariances = cp.empty((K, d, d), dtype=X.dtype) - N_k_ws = cp.empty(K, dtype=X.dtype) - num_ws = cp.empty((K, d), dtype=X.dtype) - resp = cp.ascontiguousarray(resp) - X = cp.ascontiguousarray(X) - stream = cp.cuda.get_current_stream().ptr - ones_ws = cp.ones(n, dtype=X.dtype) - centered_ws = cp.empty_like(X) - _gc.m_step( - resp, - X, - ones_ws, - weights, - means, - covariances, - N_k_ws, - num_ws, - centered_ws, - n=int(n), - d=int(d), - K=int(K), - reg_covar=float(reg_covar), - stream=stream, - handle=cp.cuda.device.get_cublas_handle(), - ) - return weights, means, covariances - N_k = resp.sum(axis=0) + 10.0 * cp.finfo(X.dtype).eps # (K,) - weights = N_k / n - means = (resp.T @ X) / N_k[:, None] - covariances = cp.empty((K, d, d), dtype=X.dtype) - eye_reg = reg_covar * cp.eye(d, dtype=X.dtype) - for k in range(K): - diff = X - means[k] - covariances[k] = ((resp[:, k : k + 1] * diff).T @ diff) / N_k[k] + eye_reg - return weights, means, covariances +def _choose_e_step(d: int, dtype) -> _EStepRoute: + """Select the CUDA E-step implementation for a feature width and dtype.""" + dtype = np.dtype(dtype) + if dtype == np.dtype("float32"): + return "cublas" if d >= _CUDA_CUBLAS_E_STEP_MIN_D else "fused" + if dtype == np.dtype("float64"): + return "cublas" if d > _CUDA_FUSED_FLOAT64_MAX_D else "fused" + return "cublas" if d >= _CUDA_CUBLAS_E_STEP_MIN_D else "fused" diff --git a/src/rapids_singlecell/squidpy_gpu/_niche.py b/src/rapids_singlecell/squidpy_gpu/_niche.py index b73b43d0..87cceb94 100644 --- a/src/rapids_singlecell/squidpy_gpu/_niche.py +++ b/src/rapids_singlecell/squidpy_gpu/_niche.py @@ -32,10 +32,11 @@ def calculate_niche( aggregation: Literal["mean", "variance"] = "mean", n_components: int = 10, use_rep: str | None = None, - init: Literal["kmeans", "random_from_data"] = "kmeans", - kmeans_n_init: int = 1, + gmm_init: Literal[ + "random_from_data", "kmeans", "sklearn_kmeans" + ] = "random_from_data", spatial_connectivities_key: str = "spatial_connectivities", - random_state: int = 0, + random_state: int = 42, copy: bool = False, ) -> AnnData | None: """\ @@ -84,14 +85,11 @@ def calculate_niche( Key in ``adata.obsm`` to use as the embedding for ``flavor="cellcharter"``; if provided, the first ``n_components`` columns are used and the shell-aggregation + PCA step is skipped. - init - GMM initialization for ``flavor="cellcharter"``. ``"kmeans"`` (default) - or ``"random_from_data"`` (sklearn-parity). Use the latter if kmeans init lands - on a degenerate component on noisy / low-signal data. - kmeans_n_init - Number of cuML KMeans restarts for ``flavor="cellcharter", init="kmeans"``. - The default ``1`` follows sklearn's GaussianMixture default and keeps the - CUDA GMM path fast. + gmm_init + GMM initialization for ``flavor="cellcharter"``. ``"random_from_data"`` + (default) matches Squidpy's CellCharter path. ``"kmeans"`` uses native + cuML KMeans. 
``"sklearn_kmeans"`` uses sklearn-compatible k-means++ seeding + followed by cuML KMeans. spatial_connectivities_key Key in ``adata.obsp`` with the spatial connectivity matrix. random_state @@ -118,14 +116,18 @@ def calculate_niche( adata = adata.copy() if copy else adata if flavor == "cellcharter": + if gmm_init not in ("random_from_data", "kmeans", "sklearn_kmeans"): + raise ValueError( + "`gmm_init` must be one of 'random_from_data', 'kmeans', or " + f"'sklearn_kmeans', got {gmm_init!r}." + ) _run_cellcharter( adata, distance=distance, aggregation=aggregation, n_components=n_components, use_rep=use_rep, - init=init, - kmeans_n_init=kmeans_n_init, + gmm_init=gmm_init, random_state=random_state, key=spatial_connectivities_key, ) @@ -269,8 +271,7 @@ def _run_cellcharter( aggregation: str, n_components: int, use_rep: str | None, - init: str, - kmeans_n_init: int, + gmm_init: str, random_state: int, key: str, ) -> None: @@ -305,8 +306,7 @@ def _run_cellcharter( embedding, n_components=n_components, random_state=random_state, - init=init, - kmeans_n_init=kmeans_n_init, + init=gmm_init, ) adata.obs["cellcharter_niche"] = pd.Categorical(cp.asnumpy(labels).astype(str)) diff --git a/tests/test_gmm.py b/tests/test_gmm.py index d96d8b7e..d3e8c5aa 100644 --- a/tests/test_gmm.py +++ b/tests/test_gmm.py @@ -1,20 +1,21 @@ from __future__ import annotations +import inspect + import cupy as cp import numpy as np import pytest +from cupyx.scipy.special import logsumexp from sklearn.metrics import adjusted_rand_score as ARI from sklearn.mixture import GaussianMixture from rapids_singlecell.squidpy_gpu._gmm import ( + _choose_e_step, _e_step, - _GMMCudaWorkspace, + _e_step_cublas, + _e_step_fused, _m_step, _precision_cholesky, - _resolve_backend, - _use_cuda_cublas_e_step, - _use_cuda_e_step, - _use_cuda_fused_e_step, gmm_fit_predict, ) @@ -44,6 +45,119 @@ def _pca_like_mixture(n_per: int, K: int, d: int, seed: int) -> np.ndarray: return np.ascontiguousarray(X[rng.permutation(len(X))]) +_LOG_2PI = float(np.log(2.0 * np.pi)) + + +def _reference_e_step( + X: cp.ndarray, + weights: cp.ndarray, + means: cp.ndarray, + prec_chol: cp.ndarray, + log_det_half: cp.ndarray, +) -> tuple[cp.ndarray, cp.ndarray]: + n, d = X.shape + K = means.shape[0] + log_prob = cp.empty((n, K), dtype=X.dtype) + half_d_log2pi = X.dtype.type(0.5 * d * _LOG_2PI) + for k in range(K): + y = (X - means[k]) @ prec_chol[k] + mahal = cp.einsum("ij,ij->i", y, y) + log_prob[:, k] = ( + -X.dtype.type(0.5) * mahal + + log_det_half[k] + - half_d_log2pi + + cp.log(weights[k]) + ) + + log_total = logsumexp(log_prob, axis=1, keepdims=True) + return cp.exp(log_prob - log_total), log_total.mean() + + +def _reference_m_step( + X: cp.ndarray, + resp: cp.ndarray, + reg_covar: float, +) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray]: + n, d = X.shape + K = resp.shape[1] + N_k = resp.sum(axis=0) + 10.0 * cp.finfo(X.dtype).eps + weights = N_k / n + means = (resp.T @ X) / N_k[:, None] + covariances = cp.empty((K, d, d), dtype=X.dtype) + eye_reg = reg_covar * cp.eye(d, dtype=X.dtype) + for k in range(K): + diff = X - means[k] + covariances[k] = ((resp[:, k : k + 1] * diff).T @ diff) / N_k[k] + eye_reg + return weights, means, covariances + + +def _e_step_buffers(X: cp.ndarray, K: int, route: str): + n = X.shape[0] + return ( + cp.empty((n, K), dtype=X.dtype), + cp.empty((n, K), dtype=X.dtype), + cp.empty(n, dtype=X.dtype), + cp.empty_like(X), + cp.empty_like(X) if route == "cublas" else None, + ) + + +def _cuda_e_step( + X: cp.ndarray, + weights: cp.ndarray, + 
means: cp.ndarray,
+    prec_chol: cp.ndarray,
+    log_det_half: cp.ndarray,
+) -> tuple[cp.ndarray, cp.ndarray]:
+    n, d = X.shape
+    K = int(means.shape[0])
+    e_step_route = _choose_e_step(d, X.dtype)
+    log_prob, responsibilities, ll_per_cell, centered, e_step_y = _e_step_buffers(
+        X, K, e_step_route
+    )
+    return _e_step(
+        X,
+        weights,
+        means,
+        prec_chol,
+        log_det_half,
+        log_prob=log_prob,
+        responsibilities=responsibilities,
+        ll_per_cell=ll_per_cell,
+        centered=centered,
+        e_step_y=e_step_y,
+        e_step_route=e_step_route,
+        stream=cp.cuda.get_current_stream().ptr,
+        handle=cp.cuda.device.get_cublas_handle(),
+    )
+
+
+def _cuda_m_step(
+    X: cp.ndarray,
+    resp: cp.ndarray,
+    reg_covar: float,
+) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray]:
+    K = resp.shape[1]
+    weights = cp.empty(K, dtype=X.dtype)
+    means = cp.empty((K, X.shape[1]), dtype=X.dtype)
+    covariances = cp.empty((K, X.shape[1], X.shape[1]), dtype=X.dtype)
+    _m_step(
+        X,
+        resp,
+        weights,
+        means,
+        covariances,
+        reg_covar=reg_covar,
+        ones=cp.ones(X.shape[0], dtype=X.dtype),
+        effective_counts=cp.empty(K, dtype=X.dtype),
+        weighted_sums=cp.empty((K, X.shape[1]), dtype=X.dtype),
+        centered=cp.empty_like(X),
+        stream=cp.cuda.get_current_stream().ptr,
+        handle=cp.cuda.device.get_cublas_handle(),
+    )
+    return weights, means, covariances
+
+
 def test_kmeans_init_recovers_well_separated_clusters():
     """kmeans init should land at near-truth on well-separated data."""
     X_np, y = _well_separated(n_per=300, K=5, d=20, sep=6.0, seed=0)
@@ -53,11 +167,7 @@
     assert ARI(y, labels) >= 0.95
 
 
-@pytest.mark.parametrize("backend", ["cupy", "cuda"])
-def test_full_cov_gmm_matches_sklearn_on_singlecell_embedding(backend):
-    if backend == "cuda" and _resolve_backend("auto") != "cuda":
-        pytest.skip("_gmm_cuda extension is not available")
-
+def test_full_cov_gmm_matches_sklearn_on_singlecell_embedding():
     X = _pca_like_mixture(n_per=250, K=5, d=16, seed=7)
     sk_labels = GaussianMixture(
         n_components=5,
@@ -75,7 +185,6 @@
             n_components=5,
             random_state=0,
             init="kmeans",
-            backend=backend,
             kmeans_n_init=1,
         )
     )
@@ -92,7 +201,7 @@
             cp.asarray(X_np), n_components=5, random_state=0, init="random_from_data"
         )
     )
-    assert ARI(y, labels) >= 0.4
+    assert ARI(y, labels) >= 0.35
     assert len(set(labels.tolist())) >= 2
 
 
@@ -106,7 +215,25 @@
     assert int(labels.min()) >= 0
 
 
-@pytest.mark.parametrize("init", ["kmeans", "random_from_data"])
+def test_non_contiguous_input_is_normalized_at_public_boundary():
+    rng = np.random.default_rng(2)
+    X = cp.asarray(rng.standard_normal((6, 240)).astype(np.float32)).T
+
+    assert not X.flags.c_contiguous
+
+    labels = gmm_fit_predict(
+        X,
+        n_components=3,
+        random_state=0,
+        max_iter=2,
+        init="random_from_data",
+    )
+
+    assert labels.shape == (240,)
+    assert labels.dtype == cp.int32
+
+
+@pytest.mark.parametrize("init", ["kmeans", "random_from_data", "sklearn_kmeans"])
 def test_determinism_same_seed(init):
     rng = np.random.default_rng(1)
     X = cp.asarray(rng.standard_normal((800, 10)).astype(np.float32))
@@ -121,10 +248,8 @@ def test_invalid_init_raises():
         gmm_fit_predict(X, n_components=3, init="bogus")
 
 
-def test_invalid_backend_raises():
-    X = cp.asarray(np.zeros((100, 5), dtype=np.float32))
-    with pytest.raises(ValueError, match="backend"):
-        gmm_fit_predict(X, n_components=3, backend="bogus")
+def test_backend_parameter_is_not_exposed():
+    assert "backend" not in inspect.signature(gmm_fit_predict).parameters
 
 
 def test_invalid_kmeans_n_init_raises():
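Not part of the patch: a NumPy/SciPy cross-check of the per-component log-density that `_reference_e_step` above computes through the precision Cholesky. With `U = prec_chol[k]` (upper, `U @ U.T == inv(cov)`), the per-cell term is `-0.5 * ||(x - mu) @ U||^2 + log|inv(cov)|/2 - (d/2) log(2*pi)`; `scipy.stats.multivariate_normal` serves only as an independent oracle here.

```python
import numpy as np
from scipy.stats import multivariate_normal

rng = np.random.default_rng(0)
d = 5
x = rng.standard_normal(d)
mu = rng.standard_normal(d)
A = rng.standard_normal((d, d))
cov = A @ A.T + d * np.eye(d)                 # SPD covariance

L = np.linalg.cholesky(cov)                   # L @ L.T == cov
U = np.linalg.inv(L).T                        # precision Cholesky, upper
y = (x - mu) @ U
log_det_half = -np.log(np.diag(L)).sum()      # log|inv(cov)| / 2
log_pdf = -0.5 * (y @ y) + log_det_half - 0.5 * d * np.log(2.0 * np.pi)

assert np.isclose(log_pdf, multivariate_normal(mu, cov).logpdf(x))
```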
@@ -133,39 +258,35 @@ def test_invalid_kmeans_n_init_raises(): gmm_fit_predict(X, n_components=3, kmeans_n_init=0) -def test_auto_backend_uses_cuda_when_available(): - if _resolve_backend("auto") != "cuda": - pytest.skip("_gmm_cuda extension is not available") - - assert _resolve_backend("auto") == "cuda" - assert _resolve_backend("cuda") == "cuda" - - -def test_cuda_e_step_routing_uses_cublas_for_high_d_and_wide_float64(): - assert _use_cuda_e_step(16) - assert _use_cuda_e_step(32) - assert _use_cuda_e_step(50) - assert _use_cuda_e_step(64) - assert _use_cuda_e_step(80) - assert _use_cuda_e_step(96) - assert _use_cuda_e_step(128) - assert _use_cuda_e_step(256) - assert _use_cuda_e_step(384) - assert _use_cuda_e_step(512) - assert not _use_cuda_e_step(768) - assert _use_cuda_e_step(512, cp.float32) - assert _use_cuda_e_step(512, cp.float64) - assert _use_cuda_e_step(64, cp.float64) - assert not _use_cuda_cublas_e_step(64, cp.float32) - assert not _use_cuda_cublas_e_step(80, cp.float32) - assert not _use_cuda_cublas_e_step(128, cp.float32) - assert not _use_cuda_cublas_e_step(256, cp.float32) - assert _use_cuda_fused_e_step(512, cp.float32) - assert not _use_cuda_fused_e_step(128, cp.float64) - assert _use_cuda_cublas_e_step(384, cp.float32) - assert _use_cuda_cublas_e_step(512, cp.float32) - assert _use_cuda_cublas_e_step(128, cp.float64) - assert _use_cuda_cublas_e_step(512, cp.float64) +def test_invalid_n_components_raises(): + X = cp.asarray(np.zeros((100, 5), dtype=np.float32)) + with pytest.raises(ValueError, match="n_components"): + gmm_fit_predict(X, n_components=0) + + +@pytest.mark.parametrize( + ("d", "dtype", "route"), + [ + (16, cp.float32, "fused"), + (32, cp.float32, "fused"), + (50, cp.float32, "fused"), + (64, cp.float32, "fused"), + (80, cp.float32, "fused"), + (96, cp.float32, "fused"), + (128, cp.float32, "fused"), + (256, cp.float32, "fused"), + (384, cp.float32, "cublas"), + (512, cp.float32, "cublas"), + (768, cp.float32, "cublas"), + (2000, cp.float32, "cublas"), + (64, cp.float64, "fused"), + (128, cp.float64, "cublas"), + (512, cp.float64, "cublas"), + (2000, cp.float64, "cublas"), + ], +) +def test_cuda_e_step_routing_uses_cublas_for_high_d_and_wide_float64(d, dtype, route): + assert _choose_e_step(d, dtype) == route def test_n_components_one_returns_single_label(): @@ -182,18 +303,15 @@ def test_float64_input_accepted(): assert labels.shape == (300,) -def test_cuda_backend_matches_cupy_steps(): - if _resolve_backend("auto") != "cuda": - pytest.skip("_gmm_cuda extension is not available") - +def test_cuda_matches_reference_steps(): rng = cp.random.RandomState(0) n, d, K = 40_000, 6, 3 # large enough to exercise the cuBLAS M-step path X = rng.standard_normal((n, d), dtype=cp.float32) logits = rng.standard_normal((n, K), dtype=cp.float32) resp = cp.exp(logits - cp.log(cp.exp(logits).sum(axis=1, keepdims=True))) - w_c, m_c, cov_c = _m_step(X, resp, 1e-6, backend="cupy") - w_g, m_g, cov_g = _m_step(X, resp, 1e-6, backend="cuda") + w_c, m_c, cov_c = _reference_m_step(X, resp, 1e-6) + w_g, m_g, cov_g = _cuda_m_step(X, resp, 1e-6) assert cp.max(cp.abs(w_c - w_g)).item() < 1e-5 assert cp.max(cp.abs(m_c - m_g)).item() < 1e-5 @@ -205,10 +323,19 @@ def test_cuda_backend_matches_cupy_steps(): cov = A @ A.transpose(0, 2, 1) + cp.eye(d, dtype=cp.float32)[None] * 0.1 prec_chol, log_det_half = _precision_cholesky(cov) - r_c, ll_c = _e_step(X, weights, means, prec_chol, log_det_half, backend="cupy") - r_g, ll_g = _e_step(X, weights, means, prec_chol, log_det_half, backend="cuda") 
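For reference, and not part of the patch: with responsibilities $r_{ik}$, dtype epsilon $\varepsilon$, and $\lambda$ equal to ``reg_covar``, the closed-form full-covariance M-step that `test_cuda_matches_reference_steps` checks the CUDA `m_step` against (spelled out in `_reference_m_step` above) is

$$
N_k = \sum_{i=1}^{n} r_{ik} + 10\,\varepsilon, \qquad
w_k = \frac{N_k}{n}, \qquad
\mu_k = \frac{1}{N_k}\sum_{i=1}^{n} r_{ik}\, x_i, \qquad
\Sigma_k = \frac{1}{N_k}\sum_{i=1}^{n} r_{ik}\,(x_i - \mu_k)(x_i - \mu_k)^{\top} + \lambda I.
$$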
- r_f, ll_f = _GMMCudaWorkspace(X, K).e_step_fused( - weights, means, prec_chol, log_det_half + r_c, ll_c = _reference_e_step(X, weights, means, prec_chol, log_det_half) + r_g, ll_g = _cuda_e_step(X, weights, means, prec_chol, log_det_half) + log_prob, resp, ll_per_cell, _, _ = _e_step_buffers(X, K, "fused") + r_f, ll_f = _e_step_fused( + X, + weights, + means, + prec_chol, + log_det_half, + log_prob, + resp, + ll_per_cell, + stream=cp.cuda.get_current_stream().ptr, ) assert cp.max(cp.abs(r_c - r_g)).item() < 1e-4 @@ -217,10 +344,7 @@ def test_cuda_backend_matches_cupy_steps(): assert cp.abs(ll_c - ll_f).item() < 1e-4 -def test_cuda_large_e_step_matches_cupy_for_large_feature_count(): - if _resolve_backend("auto") != "cuda": - pytest.skip("_gmm_cuda extension is not available") - +def test_cuda_large_e_step_matches_reference_for_large_feature_count(): rng = cp.random.RandomState(2) n, d, K = 2048, 96, 4 X = rng.standard_normal((n, d), dtype=cp.float32) @@ -230,17 +354,14 @@ def test_cuda_large_e_step_matches_cupy_for_large_feature_count(): cov = (A @ A.transpose(0, 2, 1)) / d + cp.eye(d, dtype=cp.float32)[None] * 0.5 prec_chol, log_det_half = _precision_cholesky(cov) - r_c, ll_c = _e_step(X, weights, means, prec_chol, log_det_half, backend="cupy") - r_g, ll_g = _e_step(X, weights, means, prec_chol, log_det_half, backend="cuda") + r_c, ll_c = _reference_e_step(X, weights, means, prec_chol, log_det_half) + r_g, ll_g = _cuda_e_step(X, weights, means, prec_chol, log_det_half) assert cp.max(cp.abs(r_c - r_g)).item() < 5e-4 assert cp.abs(ll_c - ll_g).item() < 5e-4 -def test_cuda_512_e_step_matches_cupy_for_cublas_route(): - if _resolve_backend("auto") != "cuda": - pytest.skip("_gmm_cuda extension is not available") - +def test_cuda_512_e_step_matches_reference_for_cublas_route(): rng = cp.random.RandomState(5) n, d, K = 384, 512, 3 X = rng.standard_normal((n, d), dtype=cp.float32) @@ -250,10 +371,22 @@ def test_cuda_512_e_step_matches_cupy_for_cublas_route(): cov = (A @ A.transpose(0, 2, 1)) / d + cp.eye(d, dtype=cp.float32)[None] * 0.5 prec_chol, log_det_half = _precision_cholesky(cov) - r_c, ll_c = _e_step(X, weights, means, prec_chol, log_det_half, backend="cupy") - r_g, ll_g = _e_step(X, weights, means, prec_chol, log_det_half, backend="cuda") - r_b, ll_b = _GMMCudaWorkspace(X, K).e_step_cublas( - weights, means, prec_chol, log_det_half + r_c, ll_c = _reference_e_step(X, weights, means, prec_chol, log_det_half) + r_g, ll_g = _cuda_e_step(X, weights, means, prec_chol, log_det_half) + log_prob, resp, ll_per_cell, centered, e_step_y = _e_step_buffers(X, K, "cublas") + r_b, ll_b = _e_step_cublas( + X, + weights, + means, + prec_chol, + log_det_half, + centered, + e_step_y, + log_prob, + resp, + ll_per_cell, + stream=cp.cuda.get_current_stream().ptr, + handle=cp.cuda.device.get_cublas_handle(), ) assert cp.max(cp.abs(r_c - r_g)).item() < 1e-3 @@ -262,10 +395,61 @@ def test_cuda_512_e_step_matches_cupy_for_cublas_route(): assert cp.abs(ll_c - ll_b).item() < 1e-3 -def test_cuda_float64_wide_e_step_uses_cublas_route(): - if _resolve_backend("auto") != "cuda": - pytest.skip("_gmm_cuda extension is not available") +def test_cuda_768_e_step_uses_cublas_route(): + rng = cp.random.RandomState(8) + n, d, K = 64, 768, 2 + X = rng.standard_normal((n, d), dtype=cp.float32) + weights = cp.asarray([0.45, 0.55], dtype=cp.float32) + means = rng.standard_normal((K, d), dtype=cp.float32) + eye = cp.eye(d, dtype=cp.float32) + cov = cp.stack((eye * 1.5, eye * 2.0)) + prec_chol, log_det_half = 
_precision_cholesky(cov) + + log_prob, resp, ll_per_cell, centered, e_step_y = _e_step_buffers(X, K, "cublas") + stream = cp.cuda.get_current_stream().ptr + handle = cp.cuda.device.get_cublas_handle() + r_c, ll_c = _reference_e_step(X, weights, means, prec_chol, log_det_half) + r_g, ll_g = _e_step( + X, + weights, + means, + prec_chol, + log_det_half, + log_prob, + resp, + ll_per_cell, + centered, + e_step_y, + e_step_route="cublas", + stream=stream, + handle=handle, + ) + log_prob_b, resp_b, ll_per_cell_b, centered_b, e_step_y_b = _e_step_buffers( + X, K, "cublas" + ) + r_b, ll_b = _e_step_cublas( + X, + weights, + means, + prec_chol, + log_det_half, + centered_b, + e_step_y_b, + log_prob_b, + resp_b, + ll_per_cell_b, + stream=stream, + handle=handle, + ) + + assert _choose_e_step(d, X.dtype) == "cublas" + assert cp.max(cp.abs(r_c - r_g)).item() < 1e-3 + assert cp.abs(ll_c - ll_g).item() < 1e-3 + assert cp.max(cp.abs(r_c - r_b)).item() < 1e-3 + assert cp.abs(ll_c - ll_b).item() < 1e-3 + +def test_cuda_float64_wide_e_step_uses_cublas_route(): rng = cp.random.RandomState(7) n, d, K = 256, 128, 3 X = rng.standard_normal((n, d), dtype=cp.float64) @@ -275,20 +459,30 @@ def test_cuda_float64_wide_e_step_uses_cublas_route(): cov = (A @ A.transpose(0, 2, 1)) / d + cp.eye(d, dtype=cp.float64)[None] * 0.5 prec_chol, log_det_half = _precision_cholesky(cov) - r_c, ll_c = _e_step(X, weights, means, prec_chol, log_det_half, backend="cupy") - workspace = _GMMCudaWorkspace(X, K) - r_g, ll_g = workspace.e_step(weights, means, prec_chol, log_det_half) + route = _choose_e_step(d, X.dtype) + log_prob, resp, ll_per_cell, centered, e_step_y = _e_step_buffers(X, K, route) + r_c, ll_c = _reference_e_step(X, weights, means, prec_chol, log_det_half) + r_g, ll_g = _e_step( + X, + weights, + means, + prec_chol, + log_det_half, + log_prob, + resp, + ll_per_cell, + centered, + e_step_y, + e_step_route=route, + stream=cp.cuda.get_current_stream().ptr, + handle=cp.cuda.device.get_cublas_handle(), + ) assert cp.max(cp.abs(r_c - r_g)).item() < 1e-12 assert cp.abs(ll_c - ll_g).item() < 1e-12 - with pytest.raises(ValueError, match="fused CUDA GMM E-step"): - workspace.e_step_fused(weights, means, prec_chol, log_det_half) - -def test_cuda_fixed_e_step_matches_cupy_for_medium_regime(): - if _resolve_backend("auto") != "cuda": - pytest.skip("_gmm_cuda extension is not available") +def test_cuda_fixed_e_step_matches_reference_for_medium_regime(): rng = cp.random.RandomState(4) n, d, K = 1024, 16, 8 X = rng.standard_normal((n, d), dtype=cp.float32) @@ -298,17 +492,14 @@ def test_cuda_fixed_e_step_matches_cupy_for_medium_regime(): cov = (A @ A.transpose(0, 2, 1)) / d + cp.eye(d, dtype=cp.float32)[None] * 0.5 prec_chol, log_det_half = _precision_cholesky(cov) - r_c, ll_c = _e_step(X, weights, means, prec_chol, log_det_half, backend="cupy") - r_g, ll_g = _e_step(X, weights, means, prec_chol, log_det_half, backend="cuda") + r_c, ll_c = _reference_e_step(X, weights, means, prec_chol, log_det_half) + r_g, ll_g = _cuda_e_step(X, weights, means, prec_chol, log_det_half) assert cp.max(cp.abs(r_c - r_g)).item() < 5e-4 assert cp.abs(ll_c - ll_g).item() < 5e-4 -def test_cuda_fused_e_step_matches_cupy_for_50_pc_regime(): - if _resolve_backend("auto") != "cuda": - pytest.skip("_gmm_cuda extension is not available") - +def test_cuda_fused_e_step_matches_reference_for_50_pc_regime(): rng = cp.random.RandomState(6) n, d, K = 1024, 50, 12 X = rng.standard_normal((n, d), dtype=cp.float32) @@ -318,17 +509,29 @@ def 
test_cuda_fused_e_step_matches_cupy_for_50_pc_regime(): cov = (A @ A.transpose(0, 2, 1)) / d + cp.eye(d, dtype=cp.float32)[None] * 0.5 prec_chol, log_det_half = _precision_cholesky(cov) - r_c, ll_c = _e_step(X, weights, means, prec_chol, log_det_half, backend="cupy") - r_d, ll_d = _GMMCudaWorkspace(X, K).e_step(weights, means, prec_chol, log_det_half) + log_prob, resp, ll_per_cell, centered, e_step_y = _e_step_buffers(X, K, "fused") + r_c, ll_c = _reference_e_step(X, weights, means, prec_chol, log_det_half) + r_d, ll_d = _e_step( + X, + weights, + means, + prec_chol, + log_det_half, + log_prob, + resp, + ll_per_cell, + centered, + e_step_y, + e_step_route="fused", + stream=cp.cuda.get_current_stream().ptr, + handle=cp.cuda.device.get_cublas_handle(), + ) assert cp.max(cp.abs(r_c - r_d)).item() < 5e-4 assert cp.abs(ll_c - ll_d).item() < 5e-4 -def test_cuda_backend_runs_large_feature_count(): - if _resolve_backend("auto") != "cuda": - pytest.skip("_gmm_cuda extension is not available") - +def test_cuda_runs_large_feature_count(): rng = np.random.default_rng(3) X = cp.asarray(rng.standard_normal((360, 80)).astype(np.float32)) labels = gmm_fit_predict( @@ -338,7 +541,6 @@ def test_cuda_backend_runs_large_feature_count(): max_iter=2, reg_covar=1e-2, init="random_from_data", - backend="cuda", ) assert labels.shape == (360,) diff --git a/tests/test_niche.py b/tests/test_niche.py index 40757bd6..fbee6226 100644 --- a/tests/test_niche.py +++ b/tests/test_niche.py @@ -1,5 +1,6 @@ from __future__ import annotations +import inspect from pathlib import Path import cupy as cp @@ -289,6 +290,15 @@ def test_cellcharter_basic(adata): assert col.nunique() <= 4 +def test_calculate_niche_exposes_only_gmm_init_extension(): + params = inspect.signature(calculate_niche).parameters + assert "init" not in params + assert "kmeans_n_init" not in params + assert "gmm_init" in params + assert params["gmm_init"].default == "random_from_data" + assert params["random_state"].default == 42 + + def test_cellcharter_distance_zero(adata): """distance=0 falls back to PCA + GMM on raw X (no shell aggregation).""" calculate_niche(adata, flavor="cellcharter", n_components=3, distance=0) @@ -303,6 +313,21 @@ def test_cellcharter_use_rep(adata): assert "cellcharter_niche" in adata.obs.columns +@pytest.mark.parametrize("gmm_init", ["random_from_data", "kmeans", "sklearn_kmeans"]) +def test_cellcharter_gmm_init_options(adata, gmm_init): + rng = np.random.default_rng(0) + adata.obsm["X_test"] = rng.standard_normal((adata.n_obs, 10)).astype(np.float32) + calculate_niche( + adata, + flavor="cellcharter", + n_components=4, + use_rep="X_test", + gmm_init=gmm_init, + random_state=0, + ) + assert "cellcharter_niche" in adata.obs.columns + + def test_cellcharter_determinism(adata): a1 = adata.copy() a2 = adata.copy() @@ -328,16 +353,9 @@ def test_cellcharter_invalid_aggregation(adata): ) -def test_cellcharter_init_random_from_data(adata): - """`init="random_from_data"` is a valid escape hatch from kmeans init.""" - calculate_niche( - adata, - flavor="cellcharter", - n_components=4, - init="random_from_data", - random_state=0, - ) - assert "cellcharter_niche" in adata.obs.columns +def test_cellcharter_invalid_gmm_init(adata): + with pytest.raises(ValueError, match="gmm_init"): + calculate_niche(adata, flavor="cellcharter", n_components=4, gmm_init="bogus") def test_cellcharter_bad_n_components(adata):
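Not part of the patch: an illustrative end-to-end call of the API this series lands. The synthetic `AnnData` construction and the `squidpy.gr.spatial_neighbors` call are assumptions made for the sketch; the `calculate_niche` signature, its defaults, and the `cellcharter_niche` output column come from the diff above.

```python
import numpy as np
import squidpy as sq
from anndata import AnnData

from rapids_singlecell.squidpy_gpu import calculate_niche

rng = np.random.default_rng(0)
adata = AnnData(rng.standard_normal((500, 30)).astype(np.float32))
adata.obsm["spatial"] = rng.uniform(0.0, 100.0, size=(500, 2))
sq.gr.spatial_neighbors(adata)  # writes obsp["spatial_connectivities"], the default key

calculate_niche(
    adata,
    flavor="cellcharter",
    n_components=4,                # number of GMM niche components
    gmm_init="random_from_data",   # default, Squidpy CellCharter parity
    random_state=42,               # default seed after this refactor
)
print(adata.obs["cellcharter_niche"].value_counts())
```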