diff --git a/.gitignore b/.gitignore index 931c8cae..6a904c74 100644 --- a/.gitignore +++ b/.gitignore @@ -236,3 +236,6 @@ compile_commands.json # Rust lib Cargo.lock + +/examples/results +*.npy \ No newline at end of file diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 00000000..e390344d --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,89 @@ +# TopK Kernel Benchmarking Suite + +Standalone benchmarking for Vortex's topk kernel variants, measuring kernel-level latency isolated from the full SGLang inference pipeline. + +## Kernel Variants + +| Kernel | Description | +|--------|-------------| +| `naive` | CUB radix sort (bf16 only) | +| `sglang_m0` | Two-stage hierarchical radix sort, no mapping | +| `sglang_m1` | + LUT mapping (requires `--lut-path`) | +| `sglang_m2` | + Quantile mapping (requires `--quantiles-path`) | +| `sglang_m3` | + Power mapping (configurable via `--mapping-power`) | +| `sglang_m4` | + Log mapping | + +## Quick Start + +```bash +# Activate environment (adjust the path to your own virtualenv) +source /scr/dataset/yuke/xinrui/uv_env/vortex/bin/activate + +# Quick single-config test +python benchmarks/bench_topk.py \ + --batch-sizes 8 \ + --seq-lens 4096 \ + --topk-vals 30 \ + --num-kv-heads 2 \ + --repeat 200 + +# Sweep with histogram analysis +python benchmarks/bench_topk.py \ + --batch-sizes 4 8 16 \ + --seq-lens 2048 4096 8192 \ + --topk-vals 30 64 \ + --num-kv-heads 2 \ + --repeat 100 \ + --histogram + +# Full sweep with JSON output +python benchmarks/bench_topk.py \ + --output-json benchmarks/results.json \ + --histogram +``` + +## CLI Options + +| Argument | Default | Description | +|----------|---------|-------------| +| `--batch-sizes` | 1 4 8 16 32 64 | Batch sizes to sweep | +| `--seq-lens` | 1024 2048 4096 8192 | Sequence lengths to sweep | +| `--topk-vals` | 16 30 64 | TopK values to sweep | +| `--num-kv-heads` | 2 4 8 | KV head counts to sweep | +| `--page-size` | 16 | Tokens per page | +| `--reserved-bos` | 1 | Reserved 
BOS pages | +| `--reserved-eos` | 2 | Reserved EOS pages | +| `--score-dtype` | bfloat16 | Score tensor dtype (bfloat16 or float32) | +| `--distributions` | normal lognormal uniform | Score distributions to test | +| `--warmup` | 10 | Warmup iterations | +| `--repeat` | 100 | Timed iterations | +| `--mapping-power` | 0.5 | Power parameter for mode=3 | +| `--lut-path` | None | Path to .npy uint8[256] LUT for mode=1 | +| `--quantiles-path` | None | Path to .npy float32[256] quantiles for mode=2 | +| `--output-json` | None | Save results to JSON file | +| `--filter-kernels` | None | Only run specific kernels (e.g., `naive sglang_m0`) | +| `--histogram` | False | Collect bin distribution statistics | + +## Histogram Analysis + +When `--histogram` is passed, each config additionally runs `topk_profile_histogram` and reports: + +- **max/mean ratio**: Peak bin count divided by average (lower = more uniform) +- **Gini coefficient**: Inequality measure of bin distribution (0 = perfectly uniform) +- **nonzero_bins**: How many of the 256 bins received any values + +This shows whether mapping modes improve bin uniformity for a given score distribution. 
+ +## Output Format + +``` +TopK Kernel Benchmark Results +GPU: NVIDIA H100 80GB HBM3 | SM count: 132 + +bs=8 | seq=4096 | topk=30 | heads=2 | pages/seg=256 | dist=normal + naive : 0.0420ms (median) +/- 0.0030ms [min=0.0390, max=0.0510] + sglang mode=0 : 0.0310ms (median) +/- 0.0020ms [min=0.0290, max=0.0380] + sglang mode=3 : 0.0330ms (median) +/- 0.0020ms [min=0.0300, max=0.0400] + sglang mode=4 : 0.0320ms (median) +/- 0.0020ms [min=0.0300, max=0.0390] + histogram stats : max/mean=3.99 gini=0.568 nonzero_bins=70/256 +``` diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/benchmarks/analyze_topk_distribution.py b/benchmarks/analyze_topk_distribution.py new file mode 100644 index 00000000..5531187e --- /dev/null +++ b/benchmarks/analyze_topk_distribution.py @@ -0,0 +1,494 @@ +""" +TopK distribution analysis and visualization. + +Loads profiling data from: + - profile_topk_distribution.py output (.npz): raw histograms, LUT tables + - bench_topk.py output (.json): benchmark results + per-mode histogram data + +Produces visualization plots for evaluating mapping mode effectiveness. 
+ +Usage: + python benchmarks/analyze_topk_distribution.py \ + --bench-json bench_hitrate.json \ + --output-dir plots/ + + python benchmarks/analyze_topk_distribution.py \ + --profile-npz profile_output.npz \ + --bench-json bench_hitrate.json \ + --output-dir plots/ --max-segments 8 +""" + +import argparse +import json +import os +from typing import Optional + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import matplotlib.colors as mcolors +import numpy as np + +# Canonical mapping mode names — shared across all profiling/analysis tools +MAPPING_MODE_NAMES = { + 0: "None", + 1: "LUT CDF", + 2: "Quantile", + 3: "Power", + 4: "Log", + 5: "Index Cache", + 6: "Asinh", + 7: "Log1p", + 8: "Trunc8", + 9: "Erf", + 10: "Tanh", + 11: "Subtract", + 13: "ExpStretch", + 14: "TopkWindow", +} + +MAPPING_MODE_FORMULAS = { + 0: "None (fp16 bucketing)", + 1: "LUT CDF (calibrated)", + 2: "Quantile (calibrated)", + 3: "Power: sign(x)*|x|^p", + 4: "Log: sign(x)*log(|x|+1)", + 5: "Index Cache", + 6: "Asinh: asinh(beta*x)", + 7: "Log1p: sign(x)*log1p(alpha*|x|)", + 8: "Trunc8: bf16 upper-8-bit bucketing", + 9: "Erf: erf(alpha*x)", + 10: "Tanh: tanh(alpha*x)", + 11: "Subtract: x - pivot (RadiK-style)", + 13: "ExpStretch: exp(alpha*x)", + 14: "TopkWindow: k-aware linear windowing", +} + + +def _mode_key_to_display(mode_key: str) -> str: + """Convert a mode key like 'mode_3', 'mode_3_Power', or 'mode_3_Power_noscale' to display name.""" + # Handle noscale suffix + noscale = mode_key.endswith("_noscale") + base_key = mode_key[:-len("_noscale")] if noscale else mode_key + suffix = " noscale" if noscale else "" + + # Handle new format: "mode_3_Power" + parts = base_key.split("_", 2) + if len(parts) >= 3: + return parts[2] + suffix # e.g. 
def compute_per_segment_stats(histograms: np.ndarray) -> dict:
    """Compute a Gini coefficient and max/mean ratio for every histogram row.

    Both statistics are computed over the occupied (count > 0) bins of a row
    only, so empty bins do not contribute to the inequality measure. A row
    without any occupied bin keeps 0.0 for both statistics.

    Args:
        histograms: [num_segments, num_bins] array of bin counts. A single
            1-D histogram is accepted and treated as one segment.

    Returns:
        dict with 'gini' and 'max_mean' float64 arrays of shape [num_segments].
    """
    rows = np.atleast_2d(np.asarray(histograms, dtype=np.float64))
    num_seg = rows.shape[0]
    ginis = np.zeros(num_seg)
    max_means = np.zeros(num_seg)

    for i, row in enumerate(rows):
        # Ascending order is required by the rank-weighted Gini formula below.
        occupied = np.sort(row[row > 0])
        n = occupied.size
        if n == 0:
            continue

        max_means[i] = occupied.max() / occupied.mean()

        ranks = np.arange(1, n + 1, dtype=np.float64)
        gini = 2.0 * (ranks * occupied).sum() / (n * occupied.sum()) - (n + 1) / n
        # Clamp tiny negative values caused by floating-point rounding.
        ginis[i] = max(0.0, gini)

    return {"gini": ginis, "max_mean": max_means}
def plot_summary_table(
    histograms: np.ndarray,
    mode_stats_data: Optional[dict],
    output_dir: str,
):
    """Save a per-segment statistics table (Gini, max/mean) as summary_table.png.

    Args:
        histograms: [num_segments, 256] array of bin counts; one table row is
            rendered per segment.
        mode_stats_data: accepted for call-site compatibility; this function
            does not read it.
        output_dir: directory the PNG is written into (must already exist).
    """
    num_seg = histograms.shape[0]
    if num_seg == 0:
        # Axes.table() needs at least one row of cell text, so there is
        # nothing to render for an empty histogram array.
        print(" No segments to summarize")
        return

    stats = compute_per_segment_stats(histograms)

    col_labels = ["Segment", "Gini", "Max/Mean"]
    cell_data = [
        [str(i), f"{gini:.3f}", f"{max_mean:.2f}"]
        for i, (gini, max_mean) in enumerate(zip(stats["gini"], stats["max_mean"]))
    ]

    # Figure height grows with the number of rows so the table stays legible.
    fig, ax = plt.subplots(figsize=(6, max(2, num_seg * 0.4 + 1)))
    ax.axis("off")
    table = ax.table(cellText=cell_data, colLabels=col_labels, loc="center", cellLoc="center")
    table.auto_set_font_size(False)
    table.set_fontsize(9)
    table.scale(1.0, 1.3)
    ax.set_title("Per-segment distribution stats", fontsize=11, pad=10)
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "summary_table.png"), dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(" Saved summary_table.png")
def save_bucket_table(dist_histograms: dict, output_dir: str, filename: str = "bucket_counts.csv"):
    """Write a CSV table listing the count per bucket for each distribution.

    Columns: bucket, dist1, dist2, ... with one row per bucket. The number of
    rows is taken from the longest histogram (256 for the standard 256-bin
    histograms); shorter histograms are padded with 0 so a histogram of a
    different length no longer raises or gets truncated.

    A compact summary (up to 20 hottest non-empty buckets per distribution)
    is also printed to stdout.

    Args:
        dist_histograms: {distribution_name: sequence of per-bucket counts}.
        output_dir: directory the CSV is written into (must already exist).
        filename: name of the CSV file inside output_dir.
    """
    import csv

    names = list(dist_histograms)
    if not names:
        return

    # Convert once; the same integer columns feed both the CSV and the summary.
    columns = [np.array(dist_histograms[name], dtype=np.int64).ravel() for name in names]
    num_buckets = max(col.size for col in columns)

    path = os.path.join(output_dir, filename)
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["bucket"] + names)
        for b in range(num_buckets):
            writer.writerow([b] + [int(col[b]) if b < col.size else 0 for col in columns])

    # Also print a compact summary to stdout (top-20 hottest buckets per dist)
    print(f" Saved {path}")
    for name, counts in zip(names, columns):
        total = counts.sum()
        top_idx = np.argsort(counts)[::-1][:20]
        print(f" [{name}] total={total} top-20 hottest buckets:")
        for rank, idx in enumerate(top_idx):
            if counts[idx] == 0:
                # Buckets are sorted by count, so everything after this is empty too.
                break
            pct = counts[idx] / total * 100 if total > 0 else 0
            print(f" #{rank+1:2d} bucket {idx:3d}: {counts[idx]:>10d} ({pct:5.1f}%)")
def main():
    """CLI entry point: load profiling/benchmark data and write all plots and tables.

    Inputs (at least one is required):
      --profile-npz      .npz with 'raw_histograms' and optionally
                         'aggregate_lut' / 'lut_tables'
      --bench-json       JSON list of benchmark entries; the per-entry keys read
                         here are 'distribution', 'histogram' and 'histograms'
      --real-histograms  .npy of real-data bucket counts, aggregated over axis 0
                         when it has more than one dimension

    All outputs are written into --output-dir, which is created if missing.
    """
    parser = argparse.ArgumentParser(description="Analyze TopK bucket sort distribution")
    parser.add_argument("--profile-npz", type=str, default=None,
                        help="Path to .npz from profile_topk_distribution.py")
    parser.add_argument("--bench-json", type=str, default=None,
                        help="Path to JSON from bench_topk.py")
    parser.add_argument("--output-dir", type=str, default="plots",
                        help="Directory for output plots")
    parser.add_argument("--max-segments", type=int, default=4,
                        help="Max segments for per-segment plots")
    parser.add_argument("--real-histograms", type=str, default=None,
                        help="Path to .npy raw_histograms from calibrate_topk.py (real-data bucket counts)")
    args = parser.parse_args()

    if args.profile_npz is None and args.bench_json is None and args.real_histograms is None:
        parser.error("At least one of --profile-npz, --bench-json, or --real-histograms is required")

    os.makedirs(args.output_dir, exist_ok=True)
    print(f"Output directory: {args.output_dir}")

    raw_histograms = None
    lut_table = None
    mode_stats_data = None

    # Load profile data
    if args.profile_npz:
        print(f"\nLoading profile data from {args.profile_npz}")
        # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
        # which can execute arbitrary code -- only point this at trusted files.
        data = np.load(args.profile_npz, allow_pickle=True)
        if "raw_histograms" in data:
            raw_histograms = data["raw_histograms"]
            print(f" raw_histograms: {raw_histograms.shape}")
        if "aggregate_lut" in data:
            lut_table = data["aggregate_lut"]
            print(f" aggregate_lut: {lut_table.shape}")
        elif "lut_tables" in data:
            # Use first LUT if aggregate not available
            lut_table = data["lut_tables"]
            if lut_table.ndim > 1:
                lut_table = lut_table[0]
            print(f" lut_table: {lut_table.shape}")

    # Load bench data
    dist_histograms = {}  # {distribution_name: [256] counts} for comparison plot
    mode_histograms = {}  # {mode_key: {dist_name: [256]}} for per-mode plots

    if args.bench_json:
        print(f"\nLoading benchmark data from {args.bench_json}")
        with open(args.bench_json) as f:
            bench_data = json.load(f)

        # Everything below expects a list of dict entries; any other JSON
        # shape would fail on entry.get(), so it is skipped with a warning.
        if not isinstance(bench_data, list):
            print(" Warning: expected a JSON list of benchmark entries; ignoring bench data")
            bench_data = []
        entries = [entry for entry in bench_data if isinstance(entry, dict)]

        # Use first config entry for histogram mode visualization
        if entries and "histograms" in entries[0]:
            mode_stats_data = entries[0]["histograms"]
            print(f" Histogram modes: {list(mode_stats_data.keys())}")

        for entry in entries:
            dist_name = entry.get("distribution", "unknown")

            # Raw (unmapped) counts: keep the first entry seen per distribution.
            hist_data = entry.get("histogram", {})
            if "raw_counts" in hist_data and dist_name not in dist_histograms:
                dist_histograms[dist_name] = hist_data["raw_counts"]
                print(f" Loaded histogram for distribution: {dist_name}")

            # Per-mode counts: likewise first entry wins per (mode, distribution).
            for mode_key, mode_data in entry.get("histograms", {}).items():
                if isinstance(mode_data, dict) and "raw_counts" in mode_data:
                    mode_histograms.setdefault(mode_key, {}).setdefault(
                        dist_name, mode_data["raw_counts"]
                    )
        if mode_histograms:
            print(f" Loaded per-mode histograms for: {sorted(mode_histograms.keys())}")

    # Load real-data histograms from .npy (calibrate_topk.py output)
    if args.real_histograms:
        print(f"\nLoading real-data histograms from {args.real_histograms}")
        real_hists = np.load(args.real_histograms)  # [num_samples, 256] or a single [256]
        # Aggregate across samples; a 1-D file is already one aggregated histogram.
        aggregated = real_hists.sum(axis=0) if real_hists.ndim > 1 else real_hists
        dist_histograms["real"] = aggregated.tolist()
        print(f" real_histograms shape: {real_hists.shape}, aggregated to [256]")

    # Generate plots
    if raw_histograms is not None:
        print("\nGenerating histogram plots...")
        plot_bin_distribution(raw_histograms, args.output_dir, args.max_segments)
        plot_bin_heatmap(raw_histograms, args.output_dir)
        plot_summary_table(raw_histograms, mode_stats_data, args.output_dir)

        # The before/after plot remaps raw_histograms, so it needs both inputs.
        if lut_table is not None:
            print("\nGenerating before/after mapping comparison...")
            plot_before_after_mapping(raw_histograms, lut_table, args.output_dir, args.max_segments)

    if mode_stats_data is not None:
        print("\nGenerating mode comparison plot...")
        plot_mapping_mode_comparison(mode_stats_data, args.output_dir)

    if dist_histograms:
        print("\nGenerating distribution comparison plot (raw/unmapped)...")
        plot_distribution_comparison(dist_histograms, args.output_dir)
        print("\nSaving bucket count table (raw/unmapped)...")
        save_bucket_table(dist_histograms, args.output_dir)

    # Per-mode distribution plots and tables
    if mode_histograms:
        print("\nGenerating per-mode distribution plots and tables...")
        for mode_key in sorted(mode_histograms):
            mname = _mode_key_to_display(mode_key)
            mode_num = _mode_key_to_number(mode_key)
            mformula = MAPPING_MODE_FORMULAS.get(mode_num, mname)
            # Include hyperparameter value in title if available. Compare
            # against None so a legitimate value of 0 / 0.0 is still shown.
            param_str = ""
            if mode_stats_data and mode_key in mode_stats_data:
                param = mode_stats_data[mode_key].get("param")
                if param is not None:
                    param_str = f" [{param}]"
            mode_suffix = mname.lower().replace(" ", "_")
            plot_distribution_comparison(
                mode_histograms[mode_key], args.output_dir,
                suffix=f"_{mode_suffix}",
                title=f"Bucket Distribution — {mname}{param_str} ({mformula})",
            )
            save_bucket_table(
                mode_histograms[mode_key], args.output_dir,
                filename=f"bucket_counts_{mode_suffix}.csv",
            )

    print(f"\nDone. All outputs saved to {args.output_dir}/")
+ +Usage: + python benchmarks/autotune_topk_mapping.py \\ + --topk-val 2048 --batch-size 4 --seq-len 65536 --num-kv-heads 8 \\ + --real-histograms calibration/raw_histograms.npy \\ + --output-json autotune_results.json +""" + +import argparse +import json +import math +from typing import Dict, List, Optional + +import numpy as np +import torch + +from bench_topk import ( + make_topk_inputs, + bench_kernel, + compute_histogram_stats, + scores_from_histogram, +) +from vortex_torch_C import ( + topk_output_sglang, + topk_output_sglang_fused, + topk_remap_only, + topk_profile_histogram, + topk_profile_counters, +) + + +# Modes where topk_mapping.cuh::apply_transform is a genuine value-space +# transform (power / asinh / log / log1p / erf / tanh / subtract / exp_stretch, +# plus the top-spreading shift_pow2 / shift_pow3 / linear_steep family) and +# also mode 0 (identity). For these the split-phase `remap_only + unfused +# topk` is correct. Modes 1/2/8 (LUT_CDF / QUANTILE / TRUNC8) apply their +# mapping inside compute_stage1_bin, so split-phase is a no-op. +ARITHMETIC_MODES = {0, 3, 4, 6, 7, 9, 10, 11, 13, 15, 16, 17, 18, 19, 20} + + +# Only parametric modes need auto-tuning. Mode 0 (none) and mode 4 (log) +# have no knob; mode 0 is always the baseline. Sweep grids widened so the +# autotune actually explores the tails of each transform. 
+SWEEP_GRID: Dict[int, List[float]] = { + 3: [0.1, 0.5, 1.0, 2.0, 4.0, 5.0, 9.0], # power: p + 6: [0.1, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0], # asinh: beta + 7: [0.1, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0], # log1p: alpha + 9: [0.1, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0], # erf: alpha + 10: [0.1, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0], # tanh: alpha + 11: [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0], # subtract: pivot + 13: [0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0], # exp_stretch: alpha + 15: [-1.0, -0.5, -0.25, 0.0, 0.25, 0.5, 1.0], # shift_pow2: pivot + 16: [-4.0, -2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0], # shift_pow3: pivot (widened) + 17: [0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0], # linear_steep: k + 18: [-1.0, -0.5, -0.25, 0.0, 0.25, 0.5, 1.0], # half_square: pivot + 19: [-1.0, -0.5, -0.25, 0.0, 0.25, 0.5, 1.0], # half_cube: pivot + # dense_mant clamp: sweep a wide range because real attention scores + # can span [-400, +200] on some models (raw logits), not just [0, 1]. + 20: [0.0, 1.0, 5.0, 10.0, 20.0, 50.0, 100.0], # dense_mant: clamp pivot +} + +PARAM_NAME = { + 3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha", + 11: "pivot", 13: "alpha", + 15: "pivot", 16: "pivot", 17: "k", + 18: "pivot", 19: "pivot", + 20: "clamp", +} +MODE_NAMES = { + 0: "none", 1: "lut_cdf", 2: "quantile", + 3: "power", 4: "log", 6: "asinh", 7: "log1p", + 8: "trunc8", 9: "erf", 10: "tanh", 11: "subtract", 13: "exp_stretch", + 15: "shift_pow2", 16: "shift_pow3", 17: "linear_steep", + 18: "half_square", 19: "half_cube", + 20: "dense_mant", +} + +# Non-parametric modes — no knob to sweep; timed once as a reference point. +# LUT_CDF (1) and QUANTILE (2) are added here at runtime when the caller +# passes --lut-path / --quantiles-path. +BASELINES = [(0, 0.5), (4, 0.5), (8, 0.5)] + + +# ---------- Real-distribution score generation ---------- +# _build_bin_range_table / scores_from_histogram now live in bench_topk.py +# so both autotune and bench_topk draw scores from the same sampler. 
def _ensure_remapped_buffer(inputs: dict) -> torch.Tensor:
    """Return the float32 scratch buffer used by the split-phase remap.

    The buffer is cached under inputs["remapped"] and allocated on first use.
    A cached buffer is reused only while it is float32 and still matches
    inputs["x"] in shape and device; otherwise a fresh one is allocated and
    stored, so a stale buffer is never handed to the remap kernel.

    Args:
        inputs: kernel-input dict; must contain the score tensor under "x".

    Returns:
        An uninitialised float32 tensor with the same shape and device as x.
    """
    x = inputs["x"]
    buf = inputs.get("remapped")
    stale = (
        buf is None
        or buf.shape != x.shape
        or buf.dtype != torch.float32
        or buf.device != x.device
    )
    if stale:
        buf = torch.empty(x.shape, dtype=torch.float32, device=x.device)
        inputs["remapped"] = buf
    return buf
def _time_unfused_on_remapped(inputs, args, mode: int, power: float) -> dict:
    """Benchmark topk_output_sglang on scores that were remapped beforehand.

    Mode 0 feeds the original scores straight to the kernel. Any other mode
    first runs topk_remap_only once, outside the timed region, writing into
    the float32 scratch buffer, and then times topk_output_sglang on that
    buffer through bench_kernel's warmup + repeat loop. Only the topk kernel
    is therefore measured, not the remap pass.

    Returns:
        The timing dict produced by bench_kernel.
    """
    batch = inputs["eff_batch_size"]
    seg_pages = inputs["num_pages_per_seg"]

    scores = inputs["x"]
    if mode != 0:
        scratch = _ensure_remapped_buffer(inputs)
        # Untimed remap pass; synchronize so it cannot leak into the timing.
        topk_remap_only(
            scores,
            inputs["dense_kv_indptr"],
            scratch,
            batch,
            args.reserved_bos,
            args.reserved_eos,
            mode,
            float(power),
        )
        torch.cuda.synchronize()
        scores = scratch

    # Start every measurement from a cleared output buffer.
    inputs["sparse_kv_indices"].zero_()
    kernel_args = (
        scores,
        inputs["dense_kv_indptr"],
        inputs["sparse_kv_indptr"],
        inputs["dense_kv_indices"],
        inputs["sparse_kv_indices"],
        batch,
        args.topk_val,
        args.reserved_bos,
        args.reserved_eos,
        seg_pages,
    )
    return bench_kernel(topk_output_sglang, kernel_args,
                        warmup=args.warmup, repeat=args.repeat)
dict: + """Optional distribution/counter stats for reporting only (post-timing).""" + eff_bs = inputs["eff_batch_size"] + pages_per_seg = inputs["num_pages_per_seg"] + diag = {} + lut_t = getattr(args, "_mapping_lut", None) if mode == 1 else None + q_t = getattr(args, "_mapping_quantiles", None) if mode == 2 else None + + if args.collect_stats: + hist = torch.zeros(eff_bs, 256, dtype=torch.int32, device="cuda") + topk_profile_histogram( + inputs["x"], inputs["dense_kv_indptr"], hist, + eff_bs, args.reserved_bos, args.reserved_eos, + mode, power, lut_t, q_t, + ) + torch.cuda.synchronize() + diag.update(compute_histogram_stats(hist)) + + counter_buf = torch.zeros(eff_bs, 6, dtype=torch.int32, device="cuda") + inputs["sparse_kv_indices"].zero_() + topk_profile_counters( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + counter_buf, + eff_bs, + args.topk_val, + args.reserved_bos, + args.reserved_eos, + pages_per_seg, + mode, + power, + lut_t, + q_t, + ) + torch.cuda.synchronize() + c = counter_buf.float() + diag["threshold_bin_mean"] = c[:, 0].mean().item() + diag["num_equal_mean"] = c[:, 2].mean().item() + diag["refine_rounds_mean"] = c[:, 4].mean().item() + # selected_from_thr = topk_val - num_above (clamped >= 0). Used as + # a tie-breaker by bench_topk._load_autotune_hparams when several + # modes have indistinguishable latency. + sel_from_thr = (float(args.topk_val) - c[:, 1]).clamp(min=0.0) + diag["selected_from_thr_mean"] = sel_from_thr.mean().item() + + return diag + + +def _run_sweep(args, inputs, dist_label: str) -> List[dict]: + results = [] + + # Baselines: time them but their param is fixed. 
def _print_ranked(results: List[dict]) -> None:
    """Print all sweep results ranked by latency, then the best entry per mode.

    The "best per mode" section is computed separately for each distribution
    present in results, because latencies measured on different score
    distributions are not comparable. When results cover more than one
    distribution, each line is tagged with its distribution; with a single
    distribution the output is the same as before.

    Args:
        results: result dicts as produced by _run_sweep; the keys read here are
            'mode', 'mode_name', 'param_name', 'param', 'distribution' and
            'latency_ms'.
    """

    def _param_label(r: dict) -> str:
        # Baselines have a fixed placeholder parameter that is not worth showing.
        if r["param_name"] == "(baseline)":
            return "(baseline)"
        return f"{r['param_name']}={r['param']}"

    ranked = sorted(results, key=lambda r: r["latency_ms"])
    header = (
        f"{'Rank':>4s} {'Mode':<12s} {'Param':<14s} {'Dist':<10s} {'Latency (ms)':>14s}"
    )
    rule = "=" * len(header)
    print("\n" + rule)
    print("TopK auto-tune results (ranked by measured kernel latency, lower is better)")
    print(rule)
    print(header)
    print("-" * len(header))
    for i, r in enumerate(ranked):
        print(
            f"{i + 1:4d} {r['mode_name']:<12s} {_param_label(r):<14s} "
            f"{r['distribution']:<10s} {r['latency_ms']:14.4f}"
        )
    print(rule)

    # Best per (distribution, mode); distributions keep first-seen order.
    distributions = list(dict.fromkeys(r["distribution"] for r in results))
    best: Dict[tuple, dict] = {}
    for r in results:
        key = (r["distribution"], r["mode"])
        if key not in best or r["latency_ms"] < best[key]["latency_ms"]:
            best[key] = r

    tag_dist = len(distributions) > 1
    print("\nBest per mode (by latency):")
    for dist in distributions:
        for m in sorted(mode for d, mode in best if d == dist):
            r = best[(dist, m)]
            dist_tag = f" [{dist}]" if tag_dist else ""
            print(
                f" mode {m:>2d} ({r['mode_name']:>5s}){dist_tag}: {_param_label(r):<16s} "
                f"latency={r['latency_ms']:.4f} ms"
            )
use topk_mapping::apply_transform (their mapping is done inside + # compute_stage1_bin) and are kept out of the comparison entirely. + args._mapping_lut = None + args._mapping_quantiles = None + + real_histogram: Optional[np.ndarray] = None + if args.real_histograms: + raw = np.load(args.real_histograms) + real_histogram = raw.sum(axis=0) if raw.ndim > 1 else raw + + all_results: List[dict] = [] + + if real_histogram is not None: + inputs = _make_real_inputs(args, real_histogram) + print("\n=== Latency sweep on REAL distribution " + f"(batch={args.batch_size} heads={args.num_kv_heads} seq={args.seq_len} topk={args.topk_val}) ===") + all_results += _run_sweep(args, inputs, "real") + else: + for dist in args.distributions: + inputs = make_topk_inputs( + batch_size=args.batch_size, + num_kv_heads=args.num_kv_heads, + seq_len=args.seq_len, + page_size=args.page_size, + topk_val=args.topk_val, + reserved_bos=args.reserved_bos, + reserved_eos=args.reserved_eos, + score_dtype=torch.bfloat16, + distribution=dist, + ) + print(f"\n=== Latency sweep on synthetic dist={dist} ===") + all_results += _run_sweep(args, inputs, dist) + + _print_ranked(all_results) + + if args.output_json: + with open(args.output_json, "w") as f: + json.dump(all_results, f, indent=2) + print(f"\nResults saved to {args.output_json}") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/bench_ablation.py b/benchmarks/bench_ablation.py new file mode 100644 index 00000000..53dbcbb4 --- /dev/null +++ b/benchmarks/bench_ablation.py @@ -0,0 +1,341 @@ +"""Phase + merge ablation for the K=30 random-split parallel kernel. + +Splits production latency into per-phase pieces and compares merge variants +on identical pre-filled workspaces. The fixture lives in +csrc/topk_adaptive_profile.cu (NOT in topk_sglang_merge.cu); the production +kernel still uses the SPLITS-specialised merge described in +csrc/topk_sglang_merge.cu's file header. 
+ +Ablation modes (must match the kAblMode_* constants in topk_adaptive_profile.cu): + + 0 full_parallel (re-enters the production workspace API) + 1 local_only (Stage 1 sort + workspace write only) + 2 local_no_workspace (Stage 1 sort, scratch sink — no ws write) + 3 workspace_write_only (write 32 dummy entries / split) + 4 atomic_only (done_counter atomic + last-CTA test only) + 5 merge_prod_default (legacy per-SPLITS dispatch: 2-way/pairwise/k-way) + 6 merge_only_cub_warp (cub::WarpMergeSort — current production merge) + 7 merge_only_cub_block (cub::BlockMergeSort benchmark) + 8 memset_only (host cudaMemsetAsync of done_counter) + 9 merge_only_2way_manual (SPLITS=2 only) + 10 merge_only_pairwise_tree_4(SPLITS=4 only) + 11 merge_kway_all (force k-way for all SPLITS) + +Benchmark matrix (default; override on the CLI): + B ∈ {1, 2, 4, 8, 16, 32, 128} + pages ∈ {8192, 16384, 32768} + topk_val = 30 + partition = contiguous + forced_splits ∈ {2, 4, 8, 16, 32} + +Outputs `bench_results/k30_ablation.csv` (long-form per-row records) and a +wide table `…_summary.csv` with the columns the spec asks for: + B, pages, split, merge_mode, full_adaptive_us, local_only_us, + workspace_write_us, atomic_only_us, merge_only_us, fused_us, + speedup_vs_fused. 
def make_inputs(eff_bs: int, pages: int, topk_val: int = 30,
                bos: int = 0, eos: int = 0, seed: int = 0,
                dtype: torch.dtype = torch.bfloat16,
                device: str = "cuda"):
    """Build the CSR-style tensors passed to the top-k kernels.

    Every row owns ``pages`` consecutive dense entries and
    ``bos + eos + topk_val`` output slots.

    Args:
        eff_bs: Number of rows (segments).
        pages: Dense entries per row.
        topk_val: Number of selected entries per row.
        bos: Extra output slots reserved at the front of each row.
        eos: Extra output slots reserved at the back of each row.
        seed: Value passed to ``torch.manual_seed`` before sampling scores.
        dtype: Dtype of the random score tensor.
        device: Device on which all tensors are allocated. Defaults to
            ``"cuda"``, which is what the benchmark uses.

    Returns:
        Dict with keys ``x`` (random scores, ``eff_bs * pages`` elements),
        ``dense_kv_indptr`` / ``sparse_kv_indptr`` (int32 row offsets with
        ``eff_bs + 1`` entries), ``dense_kv_indices`` (0..eff_bs*pages-1) and
        ``sparse_kv_indices`` (output buffer pre-filled with -1).
    """
    torch.manual_seed(seed)
    int_opts = dict(dtype=torch.int32, device=device)
    total_dense = eff_bs * pages
    out_per_row = bos + eos + topk_val

    x = torch.randn(total_dense, dtype=dtype, device=device)
    row_ids = torch.arange(eff_bs + 1, **int_opts)
    return {
        "x": x,
        "dense_kv_indptr": row_ids * pages,
        "sparse_kv_indptr": row_ids * out_per_row,
        "dense_kv_indices": torch.arange(total_dense, **int_opts),
        # -1 marks "not written yet" so stale slots are easy to spot.
        "sparse_kv_indices": torch.full((eff_bs * out_per_row,), -1, **int_opts),
    }
def bench(fn, *args, warmup=20, repeat=200):
    """Measure the wall-clock latency of ``fn(*args)``.

    ``fn`` is first called ``warmup`` times untimed, then ``repeat`` times
    with ``time.perf_counter`` around each call. When CUDA is available the
    device is synchronized before and after every timed call so each sample
    covers the complete execution of the work ``fn`` launches; without CUDA
    the synchronization is skipped so host-only callables can be timed too.

    Args:
        fn: Callable to time.
        *args: Positional arguments forwarded to ``fn``.
        warmup: Number of untimed calls made first.
        repeat: Number of timed calls.

    Returns:
        Dict with ``mean``, ``p50``, ``p90``, ``min`` and ``max`` in
        milliseconds.
    """
    # Calling torch.cuda.synchronize() on a CUDA-less host raises, so fall
    # back to a no-op there.
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)

    for _ in range(warmup):
        fn(*args)
    sync()

    samples = []
    for _ in range(repeat):
        sync()
        t0 = time.perf_counter()
        fn(*args)
        sync()
        samples.append((time.perf_counter() - t0) * 1e3)  # ms

    samples.sort()
    count = len(samples)
    return {
        "mean": statistics.mean(samples),
        "p50": samples[count // 2],
        "p90": samples[int(count * 0.9)],
        "min": samples[0],
        "max": samples[-1],
    }
def main():
    """Run the merge correctness gate, the ablation sweep, and write both CSVs.

    Unless ``--skip-correctness`` is given, ``verify_merge_only`` is run over
    a fixed grid of split counts and K values first; any failure aborts the
    benchmark. The sweep then times the fused baseline and every applicable
    entry of ``MODES`` for each (pages, B, splits) cell, writes one long-form
    row per measurement to ``--out`` and one row per (cell, merge variant) to
    ``--summary-out``.

    Returns:
        0 on success, 1 when the correctness gate fails.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--out", default="bench_results/k30_ablation.csv")
    ap.add_argument("--summary-out", default="bench_results/k30_ablation_summary.csv")
    ap.add_argument("--warmup", type=int, default=20)
    ap.add_argument("--repeat", type=int, default=200)
    ap.add_argument("--pages", type=int, nargs="+", default=[8192, 16384, 32768])
    ap.add_argument("--bs", type=int, nargs="+", default=[1, 2, 4, 8, 16, 32, 128])
    ap.add_argument("--splits", type=int, nargs="+", default=[2, 4, 8, 16, 32])
    ap.add_argument("--skip-correctness", action="store_true")
    args = ap.parse_args()

    out_path = Path(args.out)
    summary_path = Path(args.summary_out)
    # Create both output directories before any GPU work so a bad path fails
    # immediately rather than after the whole sweep has been measured.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    summary_path.parent.mkdir(parents=True, exist_ok=True)

    # -- correctness gate first --
    if not args.skip_correctness:
        print("=== Merge correctness check (mode=11 merge_kway_all) ===")
        # The gate always verifies with this many rows, so the workspace has
        # to hold at least that many even when every --bs value is smaller.
        check_bs = 4
        ws_check = make_workspace(max(check_bs, max(args.bs)))
        all_ok = True
        for splits in (2, 4, 8, 16, 32):
            for K in (1, 4, 8, 16, 30, 32):
                ok = verify_merge_only(ws_check, eff_bs=check_bs, splits=splits, topk_val=K)
                tag = "OK " if ok else "FAIL"
                print(f" splits={splits:2d} K={K:2d} : {tag}")
                all_ok &= ok
            # reserved_bos cover
            ok = verify_merge_only(ws_check, eff_bs=check_bs, splits=splits, topk_val=30, bos=2)
            print(f" splits={splits:2d} K=30 bos=2: {'OK ' if ok else 'FAIL'}")
            all_ok &= ok
        if not all_ok:
            print("CORRECTNESS FAILURES — aborting bench")
            return 1
        print("All merge-only correctness checks passed.\n")

    long_rows = []
    # cell -> {ablation_name: mean_ms}
    cells = defaultdict(dict)

    for pages in args.pages:
        for B in args.bs:
            inputs = make_inputs(B, pages)
            ws = make_workspace(B)

            # Reference: fused.
            call_fused(inputs, B, pages, 30)
            torch.cuda.synchronize()
            s_fused = bench(call_fused, inputs, B, pages, 30,
                            warmup=args.warmup, repeat=args.repeat)
            long_rows.append({
                "pages": pages, "B": B, "splits": 0,
                "ablation": "fused_baseline",
                **{k: f"{v:.4f}" for k, v in s_fused.items()},
            })
            print(f"\n=== pages={pages} B={B} === fused = {s_fused['mean']*1000:.2f} us")

            for splits in args.splits:
                # Pre-fill workspace ahead of the merge-only modes.
                fill_workspace_for_merge(ws, B, splits)
                torch.cuda.synchronize()

                for mode_id, mode_name in MODES:
                    # Modes 9 and 10 only exist for one split count each.
                    if mode_id == 9 and splits != 2:
                        continue
                    if mode_id == 10 and splits != 4:
                        continue
                    # Re-prefill before merge-only calls so input layout is fresh.
                    if mode_id in (5, 6, 7, 9, 10, 11):
                        fill_workspace_for_merge(ws, B, splits)
                        torch.cuda.synchronize()
                    # One untimed call first: a mode the kernel rejects for
                    # this configuration is reported and skipped.
                    try:
                        call_ablation(inputs, ws, B, pages, 30, mode_id, splits)
                        torch.cuda.synchronize()
                    except RuntimeError as e:
                        print(f" split={splits} {mode_name}: SKIP ({e})")
                        continue
                    stats = bench(call_ablation, inputs, ws, B, pages, 30,
                                  mode_id, splits,
                                  warmup=args.warmup, repeat=args.repeat)
                    long_rows.append({
                        "pages": pages, "B": B, "splits": splits,
                        "ablation": mode_name,
                        **{k: f"{v:.4f}" for k, v in stats.items()},
                    })
                    cell = cells[(pages, B, splits)]
                    cell[mode_name] = stats["mean"]
                    cell["__fused"] = s_fused["mean"]
                    pct = stats["mean"] / s_fused["mean"] * 100
                    print(f" split={splits:2d} {mode_name:<28s}"
                          f" mean={stats['mean']*1000:7.2f} us"
                          f" ({pct:5.1f}% of fused)")

    # ---------- write long-form CSV ----------
    long_cols = ["pages", "B", "splits", "ablation", "mean", "p50", "p90", "min", "max"]
    with open(out_path, "w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=long_cols)
        w.writeheader()
        w.writerows(long_rows)
    print(f"\nlong-form rows → {out_path} ({len(long_rows)} rows)")

    # ---------- write spec-shaped summary ----------
    summary_cols = ["B", "pages", "split", "merge_mode",
                    "full_adaptive_us", "local_only_us", "workspace_write_us",
                    "atomic_only_us", "merge_only_us", "fused_us",
                    "speedup_vs_fused"]
    merge_names = ("merge_prod_default", "merge_only_cub_warp",
                   "merge_only_cub_block", "merge_kway_all",
                   "merge_only_2way_manual",
                   "merge_only_pairwise_tree_4")

    def _us(ms):
        # Means are stored in ms; the summary reports microseconds.
        return f"{ms * 1000:.2f}" if ms is not None else ""

    with open(summary_path, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow(summary_cols)
        for (pages, B, splits), data in sorted(cells.items()):
            full = data.get("full_parallel")
            local = data.get("local_only")
            ws_write = data.get("workspace_write_only")
            atomic = data.get("atomic_only")
            fused = data.get("__fused")
            # Speedup depends only on the cell, not on the merge variant,
            # so compute it once; blank when either timing is unavailable.
            speedup = f"{fused / full:.3f}" if (full and fused) else ""
            for merge_name in merge_names:
                merge_t = data.get(merge_name)
                if merge_t is None:
                    continue
                w.writerow([B, pages, splits, merge_name,
                            _us(full), _us(local), _us(ws_write),
                            _us(atomic), _us(merge_t), _us(fused),
                            speedup])
    print(f"summary table → {summary_path}")

    return 0
def main():
    """Sweep the fused kernel over (K, pages, B, mapping) and write a CSV.

    Each configuration is first run once untimed; if that call raises, the
    configuration is recorded with ``status="failed"`` and skipped.
    Otherwise it is timed with ``time_kernel_us`` and recorded with
    ``status="ok"``. A relative ``--out`` path is interpreted relative to
    the parent of this script's directory.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--out", default="bench_results/midk_fused_baseline.csv")
    ap.add_argument("--pages", nargs="+", type=int, default=[16384, 32768, 65536, 131072])
    ap.add_argument("--ks", nargs="+", type=int, default=[64, 128, 256, 512])
    ap.add_argument("--batches", nargs="+", type=int, default=[1, 2, 4, 8, 16])
    ap.add_argument("--mappings", nargs="+", type=int, default=[0, 8])  # NONE, TRUNC8
    ap.add_argument("--warmup", type=int, default=10)
    ap.add_argument("--repeat", type=int, default=100)
    args = ap.parse_args()

    # Resolve the final location BEFORE creating its directory: the file is
    # written under the repo root, so that is where the directory must exist.
    out_path = Path(args.out)
    if not out_path.is_absolute():
        repo_root = Path(__file__).resolve().parents[1]
        out_path = repo_root / out_path
    out_path.parent.mkdir(parents=True, exist_ok=True)

    props = torch.cuda.get_device_properties(0)
    print(f"# GPU: {props.name}, SMs={props.multi_processor_count}")
    print(f"# pages: {args.pages}")
    print(f"# Ks: {args.ks}")
    print(f"# Bs: {args.batches}")
    print(f"# maps: {args.mappings} (0=NONE, 8=TRUNC8)")
    print()

    # Fixed column order so the header never depends on the first row.
    fieldnames = ["K", "pages", "B", "mapping",
                  "mean_us", "p50_us", "p90_us", "min_us", "max_us", "std_us",
                  "status", "error"]
    mapping_labels = {0: "NONE", 8: "TRUNC8"}

    rows = []
    for K in args.ks:
        print(f"=== K={K} ===")
        print(f"{'pages':>8s} {'B':>3s} {'map':>5s} {'mean_us':>10s} {'p50_us':>10s} "
              f"{'min_us':>10s} {'std_us':>8s} status")
        for pages in args.pages:
            for B in args.batches:
                for mapping in args.mappings:
                    map_name = mapping_labels.get(mapping, str(mapping))
                    try:
                        ins = make_inputs(B, pages, K)
                        # warmup correctness check
                        call_fused(*ins, B, K, 1, 2, pages, mapping)
                        torch.cuda.synchronize()
                    except Exception as e:
                        # Best effort: record the failure and keep sweeping.
                        print(f"{pages:>8d} {B:>3d} {map_name:>5s} {'-':>10s} {'-':>10s} "
                              f"{'-':>10s} {'-':>8s} FAILED: {str(e)[:80]}")
                        rows.append(dict(K=K, pages=pages, B=B, mapping=map_name,
                                         mean_us=None, p50_us=None, p90_us=None,
                                         min_us=None, max_us=None, std_us=None,
                                         status="failed", error=str(e)[:200]))
                        continue
                    t = time_kernel_us(
                        lambda: call_fused(*ins, B, K, 1, 2, pages, mapping),
                        warmup=args.warmup, repeat=args.repeat,
                    )
                    print(f"{pages:>8d} {B:>3d} {map_name:>5s} {t['mean']:>10.3f} "
                          f"{t['p50']:>10.3f} {t['min']:>10.3f} {t['std']:>8.3f} ok")
                    rows.append(dict(K=K, pages=pages, B=B, mapping=map_name,
                                     mean_us=t['mean'], p50_us=t['p50'], p90_us=t['p90'],
                                     min_us=t['min'], max_us=t['max'], std_us=t['std'],
                                     status="ok", error=""))
                    # Drop the input tensors before the next configuration allocates.
                    del ins
        print()

    with open(out_path, "w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        w.writerows(rows)
    print(f"# wrote {out_path}")
+ - --remap-bench: time baseline vs fused vs split-phase (remap-only + + unmapped-topk-on-remapped) and report threshold stats + from topk_profile_counters. +""" + +import argparse +import json +import math +import statistics +from typing import Dict, List + +import numpy as np +import torch + +from vortex_torch_C import ( + topk_output, # full CUB BlockRadixSort topk (max 4096 pages/seg) + topk_output_sglang, # 2-stage radix approximate topk (unmapped baseline) + topk_output_sglang_fused, # fused remap + 2-stage radix topk + topk_output_sglang_ori, # original SGLang reference kernel + topk_output_adaptive, # adaptive split-2 last-CTA-wins (hybrid radix/CUB) + topk_remap_only, # standalone value-space remap + topk_profile_histogram, + topk_profile_counters, +) + +# topk_output's template ladder tops out at 8192 pages per segment +# (see topk.cu::topk_output, branches up to <= 8192). Runs larger than +# that hit TORCH_CHECK(false). +TOPK_OUTPUT_MAX_PAGES = 8192 + +# The ori kernel has TopK baked in at compile time. If setup.py was built +# with a different value, calls will fail; this is the topk_val that +# matches the current build of topk_sglang_ori.cu. +TOPK_ORI_BAKED_IN = 30 + + +MAPPING_MODE_NAMES = { + 0: "None", + 1: "LUT_CDF", + 2: "Quantile", + 3: "Power", + 4: "Log", + 6: "Asinh", + 7: "Log1p", + 8: "Trunc8", + 9: "Erf", + 10: "Tanh", + 11: "Subtract", + 13: "ExpStretch", + 15: "ShiftPow2", + 16: "ShiftPow3", + 17: "LinearSteep", + 18: "HalfSquare", + 19: "HalfCube", + 20: "DenseMant", +} + +# Modes whose value-space transform is a real apply_transform() pass. Modes +# 1 (LUT_CDF), 2 (QUANTILE) and 8 (TRUNC8) apply their mapping inside +# compute_stage1_bin, not apply_transform — so `topk_remap_only` cannot +# reproduce them (the fp32 buffer would just contain the raw values). For +# those modes the split-phase numbers are N/A; only the fused kernel is a +# meaningful reference. 
def _auto_num_splits(eff_batch_size: int, pages_per_seg: int, topk_val: int) -> int:
    """Choose how many splits the parallel kernel should use per segment.

    The candidate is ``round(sqrt(pages_per_seg / topk_val))`` — the value
    that minimises ``pages/splits + splits*topk``, i.e. the sum of the
    per-CTA Phase-1 work and the per-segment Phase-2 merge work. It is then
    limited by two caps:

    * the per-segment SM budget, ``SM count // eff_batch_size``;
    * ``pages_per_seg // topk_val``, beyond which a partition would hold
      fewer than ``topk_val`` pages.

    If the device properties cannot be queried, an SM count of 132 is
    assumed.

    Returns 1 when splitting cannot help (non-positive batch, or fewer than
    ``2 * topk_val`` pages per segment); otherwise at least 2.
    """
    safe_topk = max(1, topk_val)
    split_cap = max(1, pages_per_seg // safe_topk)
    if eff_batch_size <= 0 or split_cap <= 1:
        return 1

    try:
        num_sms = torch.cuda.get_device_properties(0).multi_processor_count
    except Exception:
        num_sms = 132

    pages_per_topk = pages_per_seg / safe_topk
    sqrt_optimum = max(1, int(round(pages_per_topk ** 0.5)))
    sms_per_segment = max(1, num_sms // max(1, eff_batch_size))

    picked = min(sqrt_optimum, sms_per_segment, split_cap)
    # A result of 1 would mean "don't split", which is pointless for the
    # parallel kernel and would make Phase-1 cache the whole sequence in
    # shared memory. Use 2 whenever the cap permits; a caller that really
    # does not want to split skips the parallel kernel instead.
    if picked < 2 and split_cap >= 2:
        picked = 2
    return max(1, picked)
def _key_to_fp16(key: int) -> np.float16:
    """Map a 16-bit sortable key back to the fp16 value it encodes.

    Keys with the top bit set came from non-negative values (the sign bit
    was set on encode), so clearing it restores the bit pattern. All other
    keys came from negative values whose bits were inverted, so inverting
    again restores them.
    """
    if key >= 0x8000:
        raw_bits = key & 0x7FFF
    else:
        raw_bits = ~key & 0xFFFF
    packed = np.empty(1, dtype=np.uint16)
    packed[0] = raw_bits
    # Reinterpret the 16 bits as fp16 without any numeric conversion.
    return packed.view(np.float16)[0]
def scores_from_histogram(
    histogram: np.ndarray,
    total_pages: int,
    device: str = "cuda",
    score_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Draw `total_pages` scores that follow a 256-bin Stage-1 histogram.

    A bin is picked for every page with probability proportional to its
    count; the score is then drawn uniformly between that bin's lowest and
    highest value as reported by `build_bin_range_table`. An all-zero
    histogram yields an all-zero score tensor.

    Returns a tensor of shape ``(total_pages, 1, 1)`` with dtype
    `score_dtype` on `device`.
    """
    range_lo, range_hi = build_bin_range_table()
    weights = histogram.astype(np.float64)
    mass = weights.sum()
    if mass == 0:
        return torch.zeros(total_pages, 1, 1, dtype=score_dtype, device=device)

    # Bin choice first, then the position inside the bin: the order of the
    # two draws from the global NumPy RNG is part of the behaviour.
    chosen = np.random.choice(256, size=total_pages, p=weights / mass)
    frac = np.random.uniform(0, 1, size=total_pages).astype(np.float32)

    low = range_lo[chosen]
    width = range_hi[chosen] - low
    sampled = low + frac * width

    scores = torch.from_numpy(sampled).to(score_dtype)
    return scores.reshape(total_pages, 1, 1).to(device)
+ """ + eff_batch_size = batch_size * num_kv_heads + num_pages_per_seg = math.ceil(seq_len / page_size) + total_dense_pages = eff_batch_size * num_pages_per_seg + sparse_per_seg = min(topk_val + reserved_bos + reserved_eos, num_pages_per_seg) + total_sparse_pages = eff_batch_size * sparse_per_seg + + dense_kv_indptr = torch.arange( + 0, (eff_batch_size + 1) * num_pages_per_seg, num_pages_per_seg, + dtype=torch.int32, device=device, + ) + sparse_kv_indptr = torch.arange( + 0, (eff_batch_size + 1) * sparse_per_seg, sparse_per_seg, + dtype=torch.int32, device=device, + ) + dense_kv_indices = torch.arange(total_dense_pages, dtype=torch.int32, device=device) + sparse_kv_indices = torch.zeros(total_sparse_pages, dtype=torch.int32, device=device) + + if real_histogram is not None: + x = scores_from_histogram(real_histogram, total_dense_pages, device=device, + score_dtype=score_dtype) + elif distribution == "normal": + x = torch.randn(total_dense_pages, 1, 1, device=device).to(score_dtype) + elif distribution == "lognormal": + x = torch.randn(total_dense_pages, 1, 1, device=device).exp().to(score_dtype) + elif distribution == "uniform": + x = torch.rand(total_dense_pages, 1, 1, device=device).to(score_dtype) + elif distribution == "bucket_uniform": + # Uniform across all 256 fp16 radix buckets. Random uint16 bit + # patterns → interpret as fp16. NaN/Inf patterns collapse to ±0. 
def compute_histogram_stats(histograms: torch.Tensor) -> dict:
    """Summarise how evenly counts are spread over the populated bins.

    The input is summed over dim 0 and only bins with a positive total are
    considered. Returned keys:

    * ``max_mean_ratio`` – largest bin total divided by the mean bin total
    * ``std`` – standard deviation of the bin totals (0.0 for a single bin)
    * ``gini`` – Gini coefficient of the bin totals, floored at 0.0
    * ``num_nonzero_bins`` – number of populated bins
    * ``entropy`` – Shannon entropy (bits) of the normalised bin totals
    * ``effective_bins`` – ``2 ** entropy``

    All values are zero when no bin is populated.
    """
    per_bin = histograms.float().sum(dim=0)  # [256]
    occupied = per_bin[per_bin > 0]
    n = len(occupied)
    if n == 0:
        return {
            "max_mean_ratio": 0.0, "std": 0.0, "gini": 0.0,
            "num_nonzero_bins": 0, "entropy": 0.0, "effective_bins": 0.0,
        }

    mean_val = occupied.mean().item()
    max_val = occupied.max().item()
    if n > 1:
        std_val = occupied.std().item()
    else:
        std_val = 0.0

    # Gini via the rank-weighted sum over the ascending totals.
    ascending = occupied.sort().values
    ranks = torch.arange(1, n + 1, device=ascending.device, dtype=torch.float32)
    gini = (2.0 * (ranks * ascending).sum() / (n * ascending.sum()) - (n + 1) / n).item()

    share = occupied / occupied.sum()
    entropy = -(share * share.log2()).sum().item()

    if mean_val > 0:
        max_mean_ratio = max_val / mean_val
    else:
        max_mean_ratio = 0.0

    return {
        "max_mean_ratio": max_mean_ratio,
        "std": std_val,
        "gini": max(0.0, gini),
        "num_nonzero_bins": int(n),
        "entropy": entropy,
        "effective_bins": 2 ** entropy,
    }
+ hist_buf = torch.zeros(eff_bs, 256, dtype=torch.int32, device="cuda") + topk_profile_histogram( + inputs["x"], + inputs["dense_kv_indptr"], + hist_buf, + eff_bs, + args.reserved_bos, + args.reserved_eos, + mode, + power, + lut_t, + q_t, + ) + torch.cuda.synchronize() + + thr_idx = counter_buf[:, 0].to(torch.int64) # [eff_bs] + hist = hist_buf.to(torch.int64) # [eff_bs, 256] + bin_ids = torch.arange(256, device="cuda", dtype=torch.int64).unsqueeze(0) # [1, 256] + above_mask = bin_ids > thr_idx.unsqueeze(1) # [eff_bs, 256] + above_populated = ((hist > 0) & above_mask).sum(dim=1).float() # bins >thr with any pages + pages_above = (hist * above_mask.to(torch.int64)).sum(dim=1).float() # total pages in those bins + # Mean pages per populated above-threshold bin (per-segment, then + # averaged). Guard against divide-by-zero. + pages_per_bin = torch.where( + above_populated > 0, + pages_above / above_populated, + torch.zeros_like(above_populated), + ) + + # Selected from threshold bin = topk_val - num_above (clamped >= 0). + sel_from_thr = (float(topk_val) - c[:, 1]).clamp(min=0.0) + return { + "threshold_bin_mean": c[:, 0].mean().item(), + "threshold_bin_max": c[:, 0].max().item(), + "num_above_mean": c[:, 1].mean().item(), + "threshold_bin_size_mean": c[:, 2].mean().item(), # NUM_EQUAL + "threshold_bin_size_max": c[:, 2].max().item(), + "selected_from_thr_mean": sel_from_thr.mean().item(), + "selected_from_thr_max": sel_from_thr.max().item(), + "refine_rounds_mean": c[:, 4].mean().item(), + # Rank-target metrics: how the top pages are actually spread. 
def _resolve_hparam(args, mode: int) -> float:
    """Return the hyperparameter to use for ``mode``.

    Precedence: a per-mode value in ``args._autotune_hparams`` (when that
    attribute exists and is non-empty) wins; otherwise ``args.mapping_hparam``
    is used.  Mode 0 always gets 0.5 and never touches ``args``.
    """
    if mode != 0:
        tuned = getattr(args, "_autotune_hparams", {}) or {}
        # Read args.mapping_hparam only when there is no tuned entry.
        return tuned[mode] if mode in tuned else args.mapping_hparam
    return 0.5  # unused for MAPPING_NONE
+ baseline_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, + ) + inputs["sparse_kv_indices"].zero_() + baseline = bench_kernel(topk_output_sglang, baseline_args, args.warmup, args.repeat) + + # Optional extra row: the full CUB BlockRadixSort topk from topk.cu. + # This is a "true naive" — exact sort, no bucketing tricks — for A/B + # against the 2-stage approximate baseline. Only runs when pages_per_seg + # fits the kernel's template ladder (<= TOPK_OUTPUT_MAX_PAGES = 4096). + naive_ms = None + if pages_per_seg <= TOPK_OUTPUT_MAX_PAGES: + naive_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["dense_kv_indices"], # NOTE: topk_output arg order differs + inputs["sparse_kv_indptr"], # from topk_output_sglang + inputs["sparse_kv_indices"], + eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, + ) + inputs["sparse_kv_indices"].zero_() + naive_ms = bench_kernel( + topk_output, naive_args, args.warmup, args.repeat + )["mean_ms"] + + # Optional extra row: the original SGLang kernel from topk_sglang_ori.cu, + # compiled with TopK=TOPK_ORI_BAKED_IN. Only runs when topk_val matches + # that constant; otherwise the row is skipped with a warning. It is NOT + # used as the baseline — this is a separate A/B point so you can see the + # ori-vs-naive gap at a glance. + sglang_ori_ms = None + if topk_val == TOPK_ORI_BAKED_IN: + ori_indices = torch.empty(eff_bs, TOPK_ORI_BAKED_IN, + dtype=torch.int32, device="cuda") + ori_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + ori_indices, + eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, + ) + sglang_ori_ms = bench_kernel( + topk_output_sglang_ori, ori_args, args.warmup, args.repeat + )["mean_ms"] + + # Pre-allocate the float32 buffer used for the split-phase (remap → baseline). 
+ # Split-phase remapped buffer is **float32** to preserve Stage-2 + # refinement precision. The fused kernel computes transforms in + # fp32 internally (so its Stage-2 sub-bin keys carry transform- + # dependent bits in positions [15:0]); a narrower remapped buffer + # (bf16 or fp16) would zero those bits on round-trip and change + # the Stage-2 tie-break ordering vs the fused path. fp32 is the + # only lossless choice. The kernel supports bf16 output too (see + # topk_remap_only's dispatch table) for experimental paths, but we + # don't use it here because correctness matters more than the + # small memory-bandwidth win. + remapped = torch.empty(total_dense, dtype=torch.float32, device="cuda").reshape(inputs["x"].shape) + + config = { + "batch_size": batch_size, + "num_kv_heads": num_kv_heads, + "seq_len": seq_len, + "topk_val": topk_val, + "distribution": distribution, + "pages_per_seg": pages_per_seg, + "head": head_label, + "baseline_ms": baseline["mean_ms"], + "naive_ms": naive_ms, + "sglang_ori_ms": sglang_ori_ms, + "modes": [], + } + + # Naive row — full CUB BlockRadixSort from topk.cu. No mapping, no + # remap, no fused. Only populated when pages_per_seg fits the kernel. + if naive_ms is not None: + config["modes"].append({ + "mode": -2, # sentinel so ranking/autotune skip it + "mode_name": "Naive", + "power": 0.5, + "remap_ms": None, + "topk_after_remap_ms": naive_ms, + "split_total_ms": None, + "fused_ms": None, + "parallel_ms": None, + "parallel_splits": None, + "threshold_bin_mean": 0.0, + "threshold_bin_max": 0.0, + "num_above_mean": 0.0, + "threshold_bin_size_mean": 0.0, + "threshold_bin_size_max": 0.0, + "selected_from_thr_mean": 0.0, + "selected_from_thr_max": 0.0, + "refine_rounds_mean": 0.0, + "above_bins_mean": 0.0, + "pages_per_above_bin_mean": 0.0, + }) + + # The None row is a pass-through to the naive baseline: no remap, no + # fused, and topk_us == base_us by construction. 
Distribution metrics + # are populated by running the profile kernels with mode=0 so the user + # can see the unmapped Stage-1 bucket layout as a reference. + none_stats = _collect_threshold_stats( + inputs, topk_val, pages_per_seg, args, mode=0, power=0.5 + ) + + # Adaptive split-2 kernel (last-CTA-wins merge). Enabled via --bench-parallel. + # For the None row we run it with mapping_mode=0 (identity transform). + none_parallel_ms = None + none_parallel_splits = None + none_cluster_ms = None + if getattr(args, "bench_parallel", False): + par_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, + 0, # mapping_mode = NONE + 0.5, # mapping_power (unused for mode 0) + ) + inputs["sparse_kv_indices"].zero_() + par_none = bench_kernel(topk_output_adaptive, par_args, args.warmup, args.repeat) + none_parallel_ms = par_none["mean_ms"] + none_parallel_splits = 2 + + config["modes"].append({ + "mode": 0, + "mode_name": "None", + "power": 0.5, + "remap_ms": None, + "topk_after_remap_ms": baseline["mean_ms"], + "split_total_ms": None, + "fused_ms": None, + "parallel_ms": none_parallel_ms, + "parallel_splits": none_parallel_splits, + "cluster_ms": none_cluster_ms, + **none_stats, + }) + + # Extra row for the original SGLang kernel — only populated when the + # build's baked-in TopK matches topk_val. Also a pass-through (no + # remap, no fused); topk_us is the ori kernel latency. 
+ if sglang_ori_ms is not None: + config["modes"].append({ + "mode": -1, # sentinel so ranking/autotune skip it + "mode_name": "sglang_ori", + "power": 0.5, + "remap_ms": None, + "topk_after_remap_ms": sglang_ori_ms, + "split_total_ms": None, + "fused_ms": None, + "parallel_ms": None, + "parallel_splits": None, + "threshold_bin_mean": 0.0, + "threshold_bin_max": 0.0, + "num_above_mean": 0.0, + "threshold_bin_size_mean": 0.0, + "threshold_bin_size_max": 0.0, + "selected_from_thr_mean": 0.0, + "selected_from_thr_max": 0.0, + "refine_rounds_mean": 0.0, + "above_bins_mean": 0.0, + "pages_per_above_bin_mean": 0.0, + }) + else: + print(f"[bench-remap] sglang_ori row SKIPPED: topk_val={topk_val} != " + f"TOPK_ORI_BAKED_IN ({TOPK_ORI_BAKED_IN}). Rebuild topk_sglang_ori.cu " + f"with a matching TopK to enable the row.") + + for mode in modes: + # Mode 0 is already emitted as the `None` row above (pass-through + # to the ori baseline with no remap/fused). Skip to avoid a + # duplicate row and a spurious fused-mode-0 measurement. + if mode == 0: + continue + + power = _resolve_hparam(args, mode) + + lut_t = getattr(args, "_mapping_lut", None) if mode == 1 else None + q_t = getattr(args, "_mapping_quantiles", None) if mode == 2 else None + fused_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, + mode, power, lut_t, q_t, + ) + inputs["sparse_kv_indices"].zero_() + fused = bench_kernel(topk_output_sglang_fused, fused_args, args.warmup, args.repeat) + + # Adaptive split-2 kernel (last-CTA-wins merge) with remap mode. + # Only ARITHMETIC_MODES are supported by topk_output_adaptive — the + # LUT/quantile/trunc8 modes have no apply_transform arithmetic path. 
+ parallel_ms = None + parallel_splits_used = None + cluster_ms = None + if getattr(args, "bench_parallel", False) and mode in ARITHMETIC_MODES: + par_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, + mode, power, + ) + inputs["sparse_kv_indices"].zero_() + par_bench = bench_kernel( + topk_output_adaptive, par_args, args.warmup, args.repeat + ) + parallel_ms = par_bench["mean_ms"] + parallel_splits_used = 2 + + # Split-phase timing is only meaningful for arithmetic modes. + # MAPPING_LUT_CDF / QUANTILE / TRUNC8 apply their mapping inside + # compute_stage1_bin, which topk_remap_only cannot reproduce, so we + # report N/A for the split-phase fields and rely on the fused kernel + # as the only valid reference latency. + if mode in ARITHMETIC_MODES: + remap_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + remapped, + eff_bs, args.reserved_bos, args.reserved_eos, + mode, power, + ) + remap_only = bench_kernel(topk_remap_only, remap_args, args.warmup, args.repeat) + + # Populate the remapped buffer once so the unfused-topk warmup + # iterations don't read stale data. + topk_remap_only(*remap_args) + torch.cuda.synchronize() + split_topk_args = ( + remapped, + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, + ) + inputs["sparse_kv_indices"].zero_() + split_topk = bench_kernel(topk_output_sglang, split_topk_args, args.warmup, args.repeat) + + remap_ms = remap_only["mean_ms"] + topk_after_remap_ms = split_topk["mean_ms"] + split_total_ms = remap_ms + topk_after_remap_ms + else: + remap_ms = None + topk_after_remap_ms = None + split_total_ms = None + + # Counter collection is run AFTER all timing measurements for this mode + # so it cannot affect the timings. 
+ stats = _collect_threshold_stats(inputs, topk_val, pages_per_seg, args, mode, power) + + row = { + "mode": mode, + "mode_name": MAPPING_MODE_NAMES.get(mode, f"m{mode}"), + "power": power, + "remap_ms": remap_ms, + "topk_after_remap_ms": topk_after_remap_ms, + "split_total_ms": split_total_ms, + "fused_ms": fused["mean_ms"], + "parallel_ms": parallel_ms, + "parallel_splits": parallel_splits_used, + "cluster_ms": cluster_ms, + **stats, + } + config["modes"].append(row) + + return config + + +# Stage-2 working-set cap, matches SMEM_INPUT_SIZE in fast_topk_clean_fused +# (32 KB dynamic smem / 2 ping-pong buffers / 4 bytes per int = 4096). +_STAGE2_SMEM_CAP = 4096 + + +def _print_remap_table(results: List[dict]) -> None: + # The printed table only carries metrics that participate in the + # fused-kernel cost model. All purely-informational columns + # (thr_bin / sel_thr / abv_bins / pg/bin) were dropped — they're + # still in the JSON for downstream tools, just not in the table. + header = ( + f"{'mode':<14s} {'remap_ms':>9s} {'topk_ms':>9s} {'split_ms':>9s} " + f"{'fused_ms':>9s} {'par_ms':>9s} {'cluster_ms':>10s} {'splits':>6s} " + f"{'base_ms':>9s} {'s1p2_load':>9s} {'eff_thr':>7s} {'rounds':>6s} " + f"{'s2_work':>8s}" + ) + for cfg in results: + banner = ( + f"\n[batch={cfg['batch_size']} heads={cfg['num_kv_heads']} " + f"seq_len={cfg['seq_len']} topk={cfg['topk_val']} " + f"dist={cfg['distribution']} pages_per_seg={cfg['pages_per_seg']} " + f"head={cfg.get('head', 'all')}]" + ) + print(banner) + extra_notes = [] + if cfg.get("naive_ms") is not None: + extra_notes.append("Naive row = topk.cu (CUB full sort)") + if cfg.get("sglang_ori_ms") is not None: + extra_notes.append("sglang_ori row = topk_sglang_ori.cu") + notes_str = "" + if extra_notes: + notes_str = " | " + " | ".join(extra_notes) + print(f" Baseline: topk_sglang.cu (CUB two-stage){notes_str}") + print( + f" s1p2_load = thr_size (uncapped global re-reads in Stage-1 pass 2) " + f"eff_thr = min(thr_size, 
def _combine_per_head_cfgs(per_head_cfgs: List[dict]) -> dict:
    """Fold per-head cfg dicts into one aggregated cfg tagged ``head='all'``.

    Every numeric field is averaged across heads, ignoring ``None`` entries;
    a field that is ``None`` (or absent) for every head stays ``None``.
    Identity fields (``mode``, ``mode_name``, ``power`` and the config
    header) are copied from the first cfg.  ``parallel_splits`` is carried
    over from the first head that reports it, so the aggregated row keeps
    its split count next to the averaged ``parallel_ms``.

    Args:
        per_head_cfgs: Non-empty list of cfg dicts whose ``modes`` lists have
            the same length and order.

    Returns:
        A new cfg dict with the same layout as the inputs.

    Raises:
        AssertionError: if the list is empty or the mode counts disagree.
    """
    assert per_head_cfgs, "_combine_per_head_cfgs called with empty list"
    base = per_head_cfgs[0]
    n_modes = len(base["modes"])
    # Sanity: same shape.
    for c in per_head_cfgs[1:]:
        assert len(c["modes"]) == n_modes, (
            f"per-head cfgs disagree on mode count: {n_modes} vs {len(c['modes'])}"
        )

    def _mean_or_none(vals):
        vs = [v for v in vals if v is not None]
        return (sum(vs) / len(vs)) if vs else None

    def _first_or_none(vals):
        return next((v for v in vals if v is not None), None)

    combined: Dict = {
        "batch_size": base["batch_size"],
        "num_kv_heads": base["num_kv_heads"],
        "seq_len": base["seq_len"],
        "topk_val": base["topk_val"],
        "distribution": base["distribution"],
        "pages_per_seg": base["pages_per_seg"],
        "head": "all",
        "baseline_ms": _mean_or_none([c.get("baseline_ms") for c in per_head_cfgs]),
        "naive_ms": _mean_or_none([c.get("naive_ms") for c in per_head_cfgs]),
        "sglang_ori_ms": _mean_or_none([c.get("sglang_ori_ms") for c in per_head_cfgs]),
        "modes": [],
    }

    # Per-row fields that are averaged across heads.
    NUMERIC_KEYS = (
        "remap_ms", "topk_after_remap_ms", "split_total_ms", "fused_ms",
        "parallel_ms", "cluster_ms",
        "threshold_bin_mean", "threshold_bin_max",
        "num_above_mean",
        "threshold_bin_size_mean", "threshold_bin_size_max",
        "selected_from_thr_mean", "selected_from_thr_max",
        "refine_rounds_mean",
        "above_bins_mean", "pages_per_above_bin_mean",
    )
    for mi, sample in enumerate(base["modes"]):
        rows = [c["modes"][mi] for c in per_head_cfgs]
        merged = {
            "mode": sample["mode"],
            "mode_name": sample["mode_name"],
            "power": sample["power"],
        }
        for key in NUMERIC_KEYS:
            merged[key] = _mean_or_none([r.get(key) for r in rows])
        # A split count is a label, not a measurement: copy it, don't average.
        merged["parallel_splits"] = _first_or_none(
            [r.get("parallel_splits") for r in rows]
        )
        combined["modes"].append(merged)
    return combined
+ per_head_count = int(args.num_kv_heads[0]) + + results = [] + # When --per-head-bench is on, each "real"-distribution aggregate is + # built by averaging the 8 per-head measurements (NOT by running an + # extra kernel on an averaged histogram). This grouping keeps the + # per-head cfgs that should fold into each (bs, heads, seq, topk) + # aggregate point. + per_head_groups: dict = {} + + # ---- Per-head tables (printed first) ---- + if getattr(args, "per_head_bench", False): + raw = args._real_histograms_raw + saved_agg = args._real_histogram + try: + for h in range(per_head_count): + # Slice rows belonging to head `h`. Rows are interleaved as + # row_idx % num_kv_heads = head_idx, so this strided slice + # collects all (call, batch, h) triples across the file. + args._real_histogram = raw[h::per_head_count].sum(axis=0) + for bs in args.batch_sizes: + for heads in args.num_kv_heads: + for seq_len in args.seq_lens: + for topk_val in args.topk_vals: + cfg = _remap_bench_one_config( + args, bs, heads, seq_len, topk_val, "real", modes, + head_label=str(h), + ) + results.append(cfg) + per_head_groups.setdefault( + (bs, heads, seq_len, topk_val), [] + ).append(cfg) + finally: + args._real_histogram = saved_agg + + # ---- Aggregated tables (printed last) ---- + for bs in args.batch_sizes: + for heads in args.num_kv_heads: + for seq_len in args.seq_lens: + for topk_val in args.topk_vals: + for dist in distributions: + if dist == "real" and getattr(args, "per_head_bench", False): + cfgs = per_head_groups.get((bs, heads, seq_len, topk_val), []) + if cfgs: + # Combine the per-head cfgs into a single + # aggregated row — no extra kernel launch. 
+ cfg = _combine_per_head_cfgs(cfgs) + results.append(cfg) + continue + cfg = _remap_bench_one_config( + args, bs, heads, seq_len, topk_val, dist, modes, + head_label="all", + ) + results.append(cfg) + + _print_remap_table(results) + + if args.output_json: + with open(args.output_json, "w") as f: + json.dump(results, f, indent=2) + print(f"\nResults saved to {args.output_json}") + + +def _run_latency_sweep(args) -> None: + """Simple baseline-vs-fused latency sweep (no split-phase, no counters).""" + modes = [int(m) for m in args.mapping_modes] + distributions = list(args.distributions) + if getattr(args, "_real_histogram", None) is not None and "real" not in distributions: + distributions.append("real") + results = [] + for bs in args.batch_sizes: + for heads in args.num_kv_heads: + for seq_len in args.seq_lens: + for topk_val in args.topk_vals: + for dist in distributions: + real_hist = args._real_histogram if dist == "real" else None + inputs = make_topk_inputs( + batch_size=bs, num_kv_heads=heads, seq_len=seq_len, + page_size=args.page_size, topk_val=topk_val, + reserved_bos=args.reserved_bos, reserved_eos=args.reserved_eos, + score_dtype=torch.bfloat16, + distribution=dist if dist != "real" else "normal", + real_histogram=real_hist, + ) + eff_bs = inputs["eff_batch_size"] + pages_per_seg = inputs["num_pages_per_seg"] + row_modes = [] + for mode in modes: + power = _resolve_hparam(args, mode) + inputs["sparse_kv_indices"].zero_() + if mode == 0: + call = topk_output_sglang + call_args = ( + inputs["x"], inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, topk_val, + args.reserved_bos, args.reserved_eos, pages_per_seg, + ) + else: + call = topk_output_sglang_fused + lut_t = getattr(args, "_mapping_lut", None) if mode == 1 else None + q_t = getattr(args, "_mapping_quantiles", None) if mode == 2 else None + call_args = ( + inputs["x"], inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], 
def main():
    """CLI entry point: parse options, load optional tables, run one benchmark.

    Parsed options are extended with private attributes read by the
    benchmark drivers: ``_autotune_hparams``, ``_real_histogram``,
    ``_real_histograms_raw``, ``_mapping_lut`` and ``_mapping_quantiles``.
    ``--remap-bench`` selects ``_run_remap_bench``; otherwise
    ``_run_latency_sweep`` runs.

    Invalid invocations exit through ``ArgumentParser.error`` (status 2):
    ``--distributions real`` without ``--real-histograms``, ``--repeat < 1``,
    or a LUT / quantile table whose shape is not ``(256,)``.

    NOTE(review): ``--num-splits`` is parsed here, but no consumer is visible
    in this part of the file — confirm it is still wired up.
    """
    p = argparse.ArgumentParser("TopK kernel benchmarks")
    p.add_argument("--batch-sizes", type=int, nargs="+", default=[4])
    p.add_argument("--num-kv-heads", type=int, nargs="+", default=[8])
    p.add_argument("--seq-lens", type=int, nargs="+", default=[8192])
    p.add_argument("--topk-vals", type=int, nargs="+", default=[30])
    p.add_argument("--distributions", type=str, nargs="+",
                   default=["normal"],
                   choices=["normal", "lognormal", "uniform", "bucket_uniform", "real"],
                   help="Synthetic distributions. 'real' samples scores from a "
                        "calibrated raw_histograms.npy and therefore requires "
                        "--real-histograms (which also adds 'real' automatically).")
    p.add_argument("--real-histograms", type=str, default=None,
                   help="Path to raw_histograms.npy from calibrate_topk.py. When set, a "
                        "'real' distribution is appended to the sweep so every "
                        "(mode, hparam) combo is also timed on the calibrated score "
                        "distribution.")
    p.add_argument("--mapping-modes", type=int, nargs="+",
                   default=[0, 3, 6, 7],
                   help="Mapping modes to sweep (0=None, 3=Power, 6=Asinh, 7=Log1p, etc.)")
    p.add_argument("--mapping-hparam", "--mapping-power", type=float, default=0.5,
                   dest="mapping_hparam",
                   help="Fallback hyperparameter for every non-zero mapping mode when "
                        "no --autotune-json is provided: p for mode 3 (power), beta for "
                        "mode 6 (asinh), alpha for modes 7/9/10/13 (log1p/erf/tanh/exp_stretch).")
    p.add_argument("--autotune-json", type=str, default=None,
                   help="Path to autotune_results.json produced by autotune_topk_mapping.py. "
                        "When set, the per-mode hyperparameter with the lowest measured "
                        "latency in that file is used instead of --mapping-hparam.")
    p.add_argument("--lut-path", type=str, default=None,
                   help="Path to .npy uint8[256] LUT for MAPPING_LUT_CDF (mode 1).")
    p.add_argument("--quantiles-path", type=str, default=None,
                   help="Path to .npy float32[256] quantile table for MAPPING_QUANTILE (mode 2).")
    p.add_argument("--page-size", type=int, default=16)
    p.add_argument("--reserved-bos", type=int, default=1)
    p.add_argument("--reserved-eos", type=int, default=2)
    p.add_argument("--warmup", type=int, default=10)
    p.add_argument("--repeat", type=int, default=100)
    p.add_argument("--output-json", type=str, default=None)
    p.add_argument("--remap-bench", action="store_true",
                   help="Run the split-phase remap/topk/fused/baseline benchmark.")
    p.add_argument("--bench-parallel", action="store_true",
                   help="Time the adaptive split-2 last-CTA-wins kernel "
                        "(topk_output_adaptive) and fill the parallel_ms column.")
    p.add_argument("--num-splits", type=int, default=-1,
                   help="Partitions per batch for the parallel kernel. -1 = auto "
                        "(sm_count / eff_batch_size, clamped to pages_per_seg/topk_val).")
    p.add_argument("--per-head-bench", action="store_true",
                   help="In addition to the aggregated 'real'-distribution table, also "
                        "run the remap-bench once per KV head: slice the calibrated "
                        "histogram into one sub-histogram per head (using "
                        "row_idx %% num_kv_heads = head_idx), bench each, and print one "
                        "table per head followed by the aggregated table. Requires "
                        "--real-histograms (with a 2D raw file) and --num-kv-heads.")
    args = p.parse_args()

    # Without a histogram, the drivers substitute the "normal" generator for
    # "real", so the rows would be reported under a misleading label.
    if "real" in args.distributions and not args.real_histograms:
        p.error("--distributions real requires --real-histograms")
    # Zero timed launches leaves statistics.mean with no samples.
    if args.repeat < 1:
        p.error(f"--repeat must be >= 1, got {args.repeat}")

    args._autotune_hparams = {}
    if args.autotune_json:
        args._autotune_hparams = _load_autotune_hparams(args.autotune_json)
        print(f"[autotune] using best-latency hyperparameters from {args.autotune_json}:")
        for m, v in sorted(args._autotune_hparams.items()):
            print(f"  mode {m:>2d} -> {v}")

    args._real_histogram = None
    args._real_histograms_raw = None
    if args.real_histograms:
        # mmap_mode='r' keeps the (potentially 20+ GB) raw file off-heap; we
        # only materialise per-head sums when --per-head-bench is set.
        raw = np.load(args.real_histograms, mmap_mode='r')
        args._real_histogram = raw.sum(axis=0) if raw.ndim > 1 else raw
        if raw.ndim > 1:
            args._real_histograms_raw = raw
        print(f"[real] loaded calibrated histogram from {args.real_histograms} "
              f"(shape={raw.shape} → [256] aggregate)")

    args._mapping_lut = None
    args._mapping_quantiles = None
    if args.lut_path:
        lut_np = np.load(args.lut_path).astype(np.uint8)
        # Explicit check instead of assert: asserts vanish under `python -O`.
        if lut_np.shape != (256,):
            p.error(f"LUT must be [256], got {lut_np.shape}")
        args._mapping_lut = torch.from_numpy(lut_np).cuda()
        print(f"[mapping] loaded LUT from {args.lut_path}")
    if args.quantiles_path:
        q_np = np.load(args.quantiles_path).astype(np.float32)
        if q_np.shape != (256,):
            p.error(f"quantiles must be [256], got {q_np.shape}")
        args._mapping_quantiles = torch.from_numpy(q_np).cuda()
        print(f"[mapping] loaded quantiles from {args.quantiles_path}")

    if args.remap_bench:
        _run_remap_bench(args)
    else:
        _run_latency_sweep(args)
actual_path. +""" +from __future__ import annotations + +import argparse +import csv +import math +import statistics +import sys +import time +from dataclasses import dataclass, asdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import torch +import vortex_torch_C as V + +# --------------------------------------------------------------------------- # +# Mapping mode constants — must match csrc/topk_mapping.cuh. +# --------------------------------------------------------------------------- # +MAPPING_NONE = 0 +MAPPING_POWER = 3 +MAPPING_LOG = 4 +MAPPING_ASINH = 6 +MAPPING_LOG1P = 7 +MAPPING_TRUNC8 = 8 +MAPPING_ERF = 9 +MAPPING_TANH = 10 + +MAPPING_NAMES = { + MAPPING_NONE: "NONE", + MAPPING_POWER: "POWER", + MAPPING_LOG: "LOG", + MAPPING_ASINH: "ASINH", + MAPPING_LOG1P: "LOG1P", + MAPPING_TRUNC8: "TRUNC8", + MAPPING_ERF: "ERF", + MAPPING_TANH: "TANH", +} +MAPPING_BY_NAME = {v: k for k, v in MAPPING_NAMES.items()} + +# Mirrors the dispatcher in csrc/topk_sglang_merge.cu. +K_MAX_ADAPTIVE = 32 # K <= 32 stays on the adaptive K=30 path +K_FUSED_FALLBACK = 1024 # K >= 1024 routes to fused, even from adaptive entry + +LOCAL_BLOCK_FULL_SORT = 0 +LOCAL_SELECT32_SORT32 = 1 + +# topk.cu template ladder caps at 8192 pages. +TOPK_CU_MAX_PAGES = 8192 + +# kCfg* capacity table from csrc/topk_sglang_merge.cu (BLOCK_FULL_SORT only). +BLOCK_FULL_SORT_CAPACITY = {1: 8192, 2: 8192, 4: 4096, 8: 4096, 16: 2048, 32: 1024} + +DEFAULT_PAGES = [4096, 8192, 16384, 32768, 65536] +DEFAULT_KS = [30, 64, 128, 256, 512, 1024, 2048] +DEFAULT_BATCHES = [1, 2, 4, 8, 16] +DEFAULT_SPLITS = [1, 2, 4, 8, 16, 32] +DEFAULT_MAPPING_NAMES = ["NONE"] +DEFAULT_DTYPES = ["bfloat16"] + +WIN_THRESHOLD = 1.03 # adaptive "wins" if speedup_vs_sglang >= this + +# Production merge mode wired into TopK30_RandomSplit_Select32_Kernel / +# TopK30_RandomSplit_Parallel_Kernel — cub::WarpMergeSort. 
def _dtype_str_to_torch(s: str) -> torch.dtype:
    """Translate a CLI dtype name into a torch dtype.

    Accepts ``"bfloat16"``, ``"float"`` and ``"float32"``; any other name
    raises ``KeyError``.
    """
    by_name = {"bfloat16": torch.bfloat16}
    for alias in ("float", "float32"):
        by_name[alias] = torch.float32
    return by_name[s]


def apply_remap_torch(x: torch.Tensor, mode: int, p: float) -> torch.Tensor:
    """Reference-side remap, kept in sync with apply_transform_tmpl in topk_mapping.cuh.

    Only modes used by this sweep are implemented. Adding more requires editing
    topk_mapping.cuh and propagating to this function.

    ``MAPPING_NONE`` and ``MAPPING_TRUNC8`` return ``x`` itself.  POWER, LOG
    and LOG1P transform ``|x|`` and re-apply the sign of ``x``; ASINH, ERF and
    TANH are applied to ``p * x``.  ``p`` is ignored by NONE, TRUNC8 and LOG.

    Raises:
        ValueError: for any mode without a reference implementation here.
    """
    if mode == MAPPING_NONE or mode == MAPPING_TRUNC8:
        return x

    # Modes that act on the magnitude and then restore the sign.
    if mode == MAPPING_POWER:
        magnitude = torch.abs(x).pow(p)
        return torch.copysign(magnitude, x)
    if mode == MAPPING_LOG:
        magnitude = torch.log(torch.abs(x) + 1.0)
        return torch.copysign(magnitude, x)
    if mode == MAPPING_LOG1P:
        magnitude = torch.log1p(p * torch.abs(x))
        return torch.copysign(magnitude, x)

    # Modes that act directly on the scaled input.
    if mode == MAPPING_ASINH:
        return torch.asinh(p * x)
    if mode == MAPPING_ERF:
        return torch.erf(p * x)
    if mode == MAPPING_TANH:
        return torch.tanh(p * x)

    raise ValueError(f"reference remap not implemented for mapping_mode={mode}")
# --------------------------------------------------------------------------- #
@dataclass
class Inputs:
    """Tensors and shape parameters shared by every kernel launch of one cell."""
    scores: torch.Tensor
    dense_kv_indptr: torch.Tensor
    sparse_kv_indptr: torch.Tensor
    dense_kv_indices: torch.Tensor
    out: torch.Tensor   # int32 sparse_kv_indices
    B: int
    pages: int
    K: int
    reserved_bos: int
    reserved_eos: int


def make_inputs(B: int, pages: int, K: int, dtype: torch.dtype,
                reserved_bos: int = 1, reserved_eos: int = 2,
                seed: int = 0) -> Inputs:
    """Allocate one benchmark cell's tensors on the CUDA device.

    Every batch row owns ``pages`` consecutive score entries and
    ``K + reserved_bos + reserved_eos`` consecutive output slots. Scores are
    drawn from a standard normal after seeding torch's global RNG with
    ``seed``; the output buffer starts filled with -1.
    """
    torch.manual_seed(seed)
    cuda = torch.device("cuda")
    slots_per_row = K + reserved_bos + reserved_eos
    n_scores = B * pages

    row_ids = torch.arange(B + 1, device=cuda, dtype=torch.int32)
    dense_kv_indptr = row_ids * pages
    sparse_kv_indptr = torch.arange(B + 1, device=cuda, dtype=torch.int32) * slots_per_row
    scores = torch.randn(n_scores, device=cuda, dtype=dtype)
    dense_kv_indices = torch.arange(n_scores, device=cuda, dtype=torch.int32)
    out = torch.full((B * slots_per_row,), -1, device=cuda, dtype=torch.int32)

    return Inputs(
        scores=scores,
        dense_kv_indptr=dense_kv_indptr,
        sparse_kv_indptr=sparse_kv_indptr,
        dense_kv_indices=dense_kv_indices,
        out=out,
        B=B,
        pages=pages,
        K=K,
        reserved_bos=reserved_bos,
        reserved_eos=reserved_eos,
    )


def make_workspace(B_max: int, max_split: int = 32, K_local: int = 32):
    """Allocate the zeroed int32 scratch buffers handed to the adaptive kernels.

    Returns a dict with ``partial_keys`` and ``partial_indices`` (each
    ``max(B_max * max_split * K_local, 64)`` elements) and ``done_counter``
    (``max(B_max, 1)`` elements), all on the CUDA device.
    """
    cuda = torch.device("cuda")
    n_partial = max(B_max * max_split * K_local, 64)
    n_counters = max(B_max, 1)
    workspace = {}
    workspace["partial_keys"] = torch.zeros(n_partial, device=cuda, dtype=torch.int32)
    workspace["partial_indices"] = torch.zeros(n_partial, device=cuda, dtype=torch.int32)
    workspace["done_counter"] = torch.zeros(n_counters, device=cuda, dtype=torch.int32)
    return workspace


# --------------------------------------------------------------------------- #
# Reference top-K and correctness.
# --------------------------------------------------------------------------- #
def reference_topk(inp: Inputs, mapping_mode: int, mapping_power: float):
    """Compute the host-side top-K reference for every batch row.

    For each row the reserved leading (``reserved_bos``) and trailing
    (``reserved_eos``) pages are excluded, the remaining scores are cast to
    fp32, remapped with ``apply_remap_torch`` and passed to ``torch.topk``.

    Returns:
        (ref_sets, ref_remapped, thresholds), each a list of length ``inp.B``:
        ref_sets[b] is the set of selected *global* score indices,
        ref_remapped[b] is the remapped fp32 row on the CPU (indexed relative
        to the first non-reserved page), thresholds[b] is the K-th largest
        remapped score.
    """
    ref_sets = []
    ref_remapped = []
    thresholds = []
    for b in range(inp.B):
        row = inp.scores[b * inp.pages + inp.reserved_bos
                         : (b + 1) * inp.pages - inp.reserved_eos].float()
        remapped = apply_remap_torch(row, mapping_mode, mapping_power)
        vals, idx_within = torch.topk(remapped, inp.K)
        # Shift row-relative positions back to global score indices.
        global_idx = (idx_within + b * inp.pages + inp.reserved_bos).cpu().tolist()
        ref_sets.append(set(global_idx))
        ref_remapped.append(remapped.cpu())
        thresholds.append(vals.min().item())
    return ref_sets, ref_remapped, thresholds


def check_correctness(inp: Inputs, ref_sets, ref_remapped, thresholds,
                      mapping_mode: int, mapping_power: float) -> Tuple[bool, str]:
    """Set equality with tie tolerance.

    A row passes when the kernel's K output indices equal the reference set,
    or when they are K *distinct* in-range indices whose remapped scores all
    reach the reference K-th score within a small tolerance (ties around the
    threshold may legitimately be resolved differently).

    ``mapping_mode`` / ``mapping_power`` are accepted for call-site symmetry
    but unused: the remap is already baked into ``ref_remapped``.

    Returns (ok, note). On failure, `note` describes the failure.
    """
    out = inp.out.cpu()
    slots_per_row = inp.K + inp.reserved_bos + inp.reserved_eos
    npages_eff = inp.pages - inp.reserved_bos - inp.reserved_eos
    for b in range(inp.B):
        slot_start = b * slots_per_row + inp.reserved_bos
        out_row = out[slot_start : slot_start + inp.K].tolist()
        out_set = set(out_row)
        if -1 in out_set:
            return False, f"row {b}: -1 in output (count={out_row.count(-1)})"
        if out_set == ref_sets[b]:
            continue
        # The threshold test below only proves a valid top-K when the K
        # selected indices are distinct; a kernel that emits the same strong
        # page several times would otherwise be accepted.
        if len(out_set) != len(out_row):
            return False, (f"row {b}: duplicate indices in output "
                           f"({len(out_row) - len(out_set)} repeated)")
        # Tie tolerance: every kernel-selected score must reach the threshold,
        # within fp tolerance.
        row_offset = b * inp.pages
        out_within = [g - row_offset - inp.reserved_bos for g in out_set]
        if any(i < 0 or i >= npages_eff for i in out_within):
            return False, f"row {b}: out-of-range global idx"
        out_scores = ref_remapped[b][out_within]
        thresh = thresholds[b]
        tol = max(1e-6, 1e-3 * abs(thresh))
        min_out = out_scores.min().item()
        if min_out < thresh - tol:
            return False, (f"row {b}: min selected score={min_out:.4f} < "
                           f"K-th ref score={thresh:.4f} (tol={tol:.2e})")
    return True, ""


# --------------------------------------------------------------------------- #
# Timing.
# --------------------------------------------------------------------------- #
def time_kernel_us(fn, warmup: int, repeat: int) -> Optional[Dict[str, float]]:
    """Per-call event timing. Returns dict with mean/p50/p90/min/max/std (us)
    or None if the kernel raised.

    ``fn`` is called ``warmup`` times untimed, then ``repeat`` times bracketed
    by CUDA events. ``std`` is the population standard deviation; ``p50`` and
    ``p90`` are read from the sorted samples at index ``n // 2`` and
    ``round(0.9 * n)`` (clamped to the last sample).

    Raises:
        ValueError: if ``repeat`` is not positive (there would be no samples
            to summarise).
    """
    if repeat <= 0:
        raise ValueError(f"repeat must be >= 1, got {repeat}")
    try:
        for _ in range(warmup):
            fn()
        torch.cuda.synchronize()
    except Exception:
        # Deliberate best-effort: callers turn None into an "error" row.
        return None
    starts = [torch.cuda.Event(enable_timing=True) for _ in range(repeat)]
    ends = [torch.cuda.Event(enable_timing=True) for _ in range(repeat)]
    try:
        for start, end in zip(starts, ends):
            start.record()
            fn()
            end.record()
        torch.cuda.synchronize()
    except Exception:
        return None
    # elapsed_time() reports milliseconds; convert to microseconds.
    times_us = sorted(s.elapsed_time(e) * 1000.0 for s, e in zip(starts, ends))
    n = len(times_us)
    mean = sum(times_us) / n
    var = sum((t - mean) ** 2 for t in times_us) / n
    return dict(
        mean=mean,
        p50=times_us[n // 2],
        p90=times_us[min(n - 1, int(round(n * 0.9)))],
        min=times_us[0],
        max=times_us[-1],
        std=math.sqrt(var),
    )


# --------------------------------------------------------------------------- #
# Method launchers.
# --------------------------------------------------------------------------- #
def call_fused(inp: Inputs, mapping_mode: int, mapping_power: float):
    """Launch the fused sglang top-K binding once on ``inp`` (no sync, no fill)."""
    # Caller is responsible for inp.out.fill_(-1) BEFORE the timed loop if it
    # cares about a clean baseline; the fill is its own kernel launch and would
    # otherwise pollute kernel-only timing.
    V.topk_output_sglang_fused(
        inp.scores, inp.dense_kv_indptr, inp.sparse_kv_indptr, inp.dense_kv_indices,
        inp.out, inp.B, inp.K, inp.reserved_bos, inp.reserved_eos, inp.pages,
        mapping_mode, mapping_power, None, None,
    )


def call_topk_cu(inp: Inputs):
    """Launch the topk.cu baseline binding once on ``inp`` (no remap arguments)."""
    # NOTE: arg order differs from sglang variants -
    # (x, dense_kv_indptr, dense_kv_indices, sparse_kv_indptr, ...)
    V.topk_output(
        inp.scores, inp.dense_kv_indptr, inp.dense_kv_indices, inp.sparse_kv_indptr,
        inp.out, inp.B, inp.K, inp.reserved_bos, inp.reserved_eos, inp.pages,
    )


def call_adaptive(inp: Inputs, ws: dict, mapping_mode: int, mapping_power: float,
                  forced_split: int, forced_partition: int, local_mode: int):
    """Launch the adaptive-workspace binding once on ``inp``.

    ``ws`` must provide the "partial_keys", "partial_indices" and
    "done_counter" tensors (see make_workspace). ``forced_split``,
    ``forced_partition`` and ``local_mode`` are forwarded unchanged.
    """
    V.topk_output_adaptive_workspace(
        inp.scores, inp.dense_kv_indptr, inp.sparse_kv_indptr, inp.dense_kv_indices,
        inp.out, ws["partial_keys"], ws["partial_indices"], ws["done_counter"],
        inp.B, inp.K, inp.reserved_bos, inp.reserved_eos, inp.pages,
        mapping_mode, mapping_power,
        forced_split, forced_partition, local_mode,
    )


# --------------------------------------------------------------------------- #
# actual_path classification — mirrors the C++ dispatcher contract.
# --------------------------------------------------------------------------- #
def classify_adaptive_actual_path(K: int, split: int, pages: int,
                                  local_mode: int) -> Tuple[str, Optional[int], bool]:
    """Returns (actual_path, actual_split, is_supported).

    actual_split is the effective split count (None for fused fallback).
    is_supported is False when the call would TORCH_CHECK fail.

    This is a pure host-side prediction; nothing is launched here.
    """
    # K thresholds are checked first: large and mid K never reach the
    # adaptive kernels regardless of split / local_mode.
    if K >= K_FUSED_FALLBACK:
        return ("fused_fallback_large_k", None, True)
    if K > K_MAX_ADAPTIVE:
        return ("fused_fallback_mid_k", None, True)
    if local_mode == LOCAL_BLOCK_FULL_SORT:
        # Largest per-split chunk (ceil division) must fit the capacity table;
        # a split missing from the table gets capacity 0, i.e. unsupported.
        chunk_max = (pages + split - 1) // split
        cap = BLOCK_FULL_SORT_CAPACITY.get(split, 0)
        if cap < chunk_max:
            return ("unsupported_capacity", split, False)
        return ("adaptive_block_full_sort", split, True)
    return ("adaptive_select32_sort32", split, True)


# --------------------------------------------------------------------------- #
# Sweep driver.
# --------------------------------------------------------------------------- #
@dataclass
class Row:
    """One measurement (or one recorded failure) in the raw sweep output.

    Timing fields stay None when the method was not timed; ``correctness`` is
    None when no check was run for the row.
    """
    device_name: str
    sm_count: int
    dtype: str
    mapping_mode: int
    mapping_name: str
    mapping_power: float
    pages: int
    topk: int
    batch: int
    method: str                       # "topk_sglang_fused", "topk_cu", "adaptive_merge"
    requested_split: Optional[int]
    actual_split: Optional[int]
    local_mode: str                   # "BLOCK_FULL_SORT" / "SELECT32_SORT32" / "n/a"
    merge_mode: str                   # "warp_cub" (production); "n/a" if no merge
    candidate_count: Optional[int]    # split * local_k; None if no merge
    actual_path: str
    mean_us: Optional[float]
    p50_us: Optional[float]
    p90_us: Optional[float]
    min_us: Optional[float]
    max_us: Optional[float]
    std_us: Optional[float]
    correctness: Optional[bool]
    speedup_vs_sglang_fused: Optional[float]
    speedup_vs_topk_cu: Optional[float]
    notes: str


def run_one_setting(pages: int, K: int, B: int, mapping_mode: int, mapping_power: float,
                    dtype_str: str, splits: List[int], local_mode: int,
                    warmup: int, repeat: int, ws: dict,
                    device_name: str, sm_count: int) -> List[Row]:
    """Bench every method at one (pages, K, B, mapping, dtype) cell.

    Order of work: fused baseline, topk.cu baseline, then one adaptive launch
    per entry of ``splits``. Each method is first run once on a -1-filled
    output and checked against the host reference, and only then timed. The
    baselines' mean latencies feed the speedup columns of later rows, so a
    baseline that failed or could not be timed leaves those speedups as None.

    Returns one Row per method / split, including rows for unsupported or
    failing configurations.
    """
    rows: List[Row] = []
    inp = make_inputs(B, pages, K, _dtype_str_to_torch(dtype_str))
    ref_sets, ref_remapped, thresholds = reference_topk(inp, mapping_mode, mapping_power)
    map_name = MAPPING_NAMES.get(mapping_mode, str(mapping_mode))

    def _row(**kw):
        # Row factory: cell-wide columns are pre-filled, per-method columns
        # default to "not applicable", and keyword overrides win.
        defaults = dict(
            device_name=device_name, sm_count=sm_count, dtype=dtype_str,
            mapping_mode=mapping_mode, mapping_name=map_name, mapping_power=mapping_power,
            pages=pages, topk=K, batch=B,
            requested_split=None, actual_split=None,
            local_mode="n/a", merge_mode="n/a", candidate_count=None,
            mean_us=None, p50_us=None, p90_us=None, min_us=None, max_us=None, std_us=None,
            correctness=None, speedup_vs_sglang_fused=None, speedup_vs_topk_cu=None,
            notes="",
        )
        defaults.update(kw)
        return Row(**defaults)

    # ---------- topk_sglang.cu fused baseline ----------
    fused_us: Optional[float] = None
    try:
        inp.out.fill_(-1)
        call_fused(inp, mapping_mode, mapping_power)
        torch.cuda.synchronize()
        ok, note = check_correctness(inp, ref_sets, ref_remapped, thresholds,
                                     mapping_mode, mapping_power)
    except RuntimeError as e:
        # Most common: pages > fused dynamic-smem ceiling (~96KB → ~96k pages).
        msg = str(e)
        path = "fused_unavailable_smem" if "exceeds" in msg or "smem" in msg.lower() else "error"
        rows.append(_row(method="topk_sglang_fused", actual_path=path,
                         correctness=False, notes=f"raised: {msg[:160]}"))
    except Exception as e:
        rows.append(_row(method="topk_sglang_fused", actual_path="error",
                         correctness=False, notes=f"raised: {e}"))
    else:
        # Timing reuses the output buffer from the correctness run; no refill,
        # so the fill launch never lands inside the timed region.
        t = time_kernel_us(lambda: call_fused(inp, mapping_mode, mapping_power),
                           warmup, repeat)
        if t is None:
            rows.append(_row(method="topk_sglang_fused", actual_path="error",
                             correctness=ok, notes="time_kernel_us returned None"))
        else:
            fused_us = t["mean"]
            rows.append(_row(
                method="topk_sglang_fused", actual_path="fused",
                mean_us=t["mean"], p50_us=t["p50"], p90_us=t["p90"],
                min_us=t["min"], max_us=t["max"], std_us=t["std"],
                correctness=ok,
                speedup_vs_sglang_fused=1.0,
                notes=note,
            ))

    # ---------- topk.cu baseline (CUB full sort) ----------
    cub_us: Optional[float] = None
    if pages > TOPK_CU_MAX_PAGES:
        rows.append(_row(method="topk_cu", actual_path="topk_cu_unsupported",
                         notes=f"pages={pages} > template ladder cap {TOPK_CU_MAX_PAGES}"))
    else:
        try:
            inp.out.fill_(-1)
            call_topk_cu(inp)
            torch.cuda.synchronize()
            # topk.cu doesn't apply remap, so its output is for raw scores —
            # ALWAYS check against the unmapped reference for fairness.
            # NOTE(review): for remapping modes the check is skipped and the
            # row is still recorded with correctness=True (see note text);
            # downstream CSVs treat it as a verified baseline — confirm this
            # is intended.
            if mapping_mode in (MAPPING_NONE, MAPPING_TRUNC8):
                ok, note = check_correctness(inp, ref_sets, ref_remapped, thresholds,
                                             mapping_mode, mapping_power)
            else:
                ok, note = True, "remap unsupported by topk.cu; correctness skipped"
        except Exception as e:
            rows.append(_row(method="topk_cu", actual_path="error",
                             correctness=False, notes=f"raised: {e}"))
        else:
            t = time_kernel_us(lambda: call_topk_cu(inp), warmup, repeat)
            if t is None:
                rows.append(_row(method="topk_cu", actual_path="error",
                                 correctness=ok, notes="time_kernel_us returned None"))
            else:
                cub_us = t["mean"]
                rows.append(_row(
                    method="topk_cu", actual_path="cub_full_sort",
                    mean_us=t["mean"], p50_us=t["p50"], p90_us=t["p90"],
                    min_us=t["min"], max_us=t["max"], std_us=t["std"],
                    correctness=ok,
                    speedup_vs_sglang_fused=(fused_us / t["mean"]) if fused_us else None,
                    speedup_vs_topk_cu=1.0,
                    notes=note,
                ))

    # ---------- topk_sglang_merge.cu adaptive (one row per requested split) ----------
    local_mode_str = LOCAL_MODE_NAMES.get(local_mode, "unknown")
    for split in splits:
        actual_path, actual_split, is_supported = classify_adaptive_actual_path(
            K, split, pages, local_mode)
        # Adaptive paths use cub::WarpMergeSort over (split * 32) candidates;
        # split=1 has no merge stage at all.
        on_adaptive_path = actual_path.startswith("adaptive_")
        merge_mode = PROD_MERGE_NAME if (on_adaptive_path and split > 1) else "n/a"
        candidate_count = (split * 32) if (on_adaptive_path and split > 1) else None
        local_mode_for_row = local_mode_str if on_adaptive_path else "n/a"

        if not is_supported:
            # Predicted to fail a capacity check: record the reason, launch nothing.
            rows.append(_row(
                method="adaptive_merge", requested_split=split, actual_split=actual_split,
                local_mode=local_mode_for_row, merge_mode=merge_mode,
                candidate_count=candidate_count, actual_path=actual_path,
                notes=(f"BLOCK_FULL_SORT cap={BLOCK_FULL_SORT_CAPACITY.get(split,0)} "
                       f"< chunk_max={(pages + split - 1)//split}"),
            ))
            continue

        try:
            inp.out.fill_(-1)
            call_adaptive(inp, ws, mapping_mode, mapping_power,
                          forced_split=split, forced_partition=1,   # CONTIGUOUS
                          local_mode=local_mode)
            torch.cuda.synchronize()
            ok, note = check_correctness(inp, ref_sets, ref_remapped, thresholds,
                                         mapping_mode, mapping_power)
        except RuntimeError as e:
            # K>32 and pages too large for fused fallback's smem.
            msg = str(e)
            err_path = ("fused_fallback_unavailable_smem"
                        if (not on_adaptive_path and ("exceeds" in msg or "smem" in msg.lower()))
                        else actual_path + "_error")
            rows.append(_row(method="adaptive_merge",
                             requested_split=split, actual_split=actual_split,
                             local_mode=local_mode_for_row, merge_mode=merge_mode,
                             candidate_count=candidate_count,
                             actual_path=err_path,
                             correctness=False, notes=f"raised: {msg[:160]}"))
            continue
        except Exception as e:
            rows.append(_row(method="adaptive_merge",
                             requested_split=split, actual_split=actual_split,
                             local_mode=local_mode_for_row, merge_mode=merge_mode,
                             candidate_count=candidate_count, actual_path=actual_path,
                             correctness=False, notes=f"raised: {e}"))
            continue

        # For fused-fallback paths, all forced_split values produce identical
        # timings (same fused kernel called); we still time each entry to
        # quantify dispatcher overhead.
        # The lambda reads the loop variable `split` late, which is safe only
        # because time_kernel_us runs it to completion before the next iteration.
        t = time_kernel_us(
            lambda: call_adaptive(inp, ws, mapping_mode, mapping_power,
                                  forced_split=split, forced_partition=1,
                                  local_mode=local_mode),
            warmup, repeat,
        )
        if t is None:
            rows.append(_row(method="adaptive_merge",
                             requested_split=split, actual_split=actual_split,
                             local_mode=local_mode_for_row, merge_mode=merge_mode,
                             candidate_count=candidate_count, actual_path=actual_path,
                             correctness=ok, notes="time_kernel_us returned None"))
            continue

        rows.append(_row(
            method="adaptive_merge",
            requested_split=split, actual_split=actual_split,
            local_mode=local_mode_for_row, merge_mode=merge_mode,
            candidate_count=candidate_count, actual_path=actual_path,
            mean_us=t["mean"], p50_us=t["p50"], p90_us=t["p90"],
            min_us=t["min"], max_us=t["max"], std_us=t["std"],
            correctness=ok,
            speedup_vs_sglang_fused=(fused_us / t["mean"]) if fused_us else None,
            speedup_vs_topk_cu=(cub_us / t["mean"]) if cub_us else None,
            notes=note,
        ))

    return rows


# --------------------------------------------------------------------------- #
# Merge-mode ablation (K=30 only — the ablation kernels in
# topk_adaptive_profile.cu are hardcoded to kLocalK_Top30 = 32).
# --------------------------------------------------------------------------- #
def call_ablation(inp: Inputs, ws: dict, scratch: torch.Tensor,
                  ablation_mode: int, forced_split: int):
    """Launch the ablation binding once with the given ABL_* mode and split.

    Takes the same workspace tensors as call_adaptive plus a ``scratch``
    tensor; no remap arguments are passed.
    """
    V.topk_output_adaptive_workspace_ablation(
        inp.scores, inp.dense_kv_indptr, inp.sparse_kv_indptr, inp.dense_kv_indices,
        inp.out, ws["partial_keys"], ws["partial_indices"], ws["done_counter"],
        scratch,
        inp.B, inp.K, inp.reserved_bos, inp.reserved_eos, inp.pages,
        ablation_mode, forced_split,
    )


def run_merge_ablation(pages_list, batches, splits, warmup, repeat,
                       ws, device_name, sm_count) -> List[dict]:
    """Per (pages, B, split, merge_mode), measure merge-only latency.

    Workflow per cell:
      1. Populate the workspace via ablation_mode = LocalWithWorkspace (mode 1).
      2. For each merge variant, time merge-only kernel (modes 5/6/7/11).

    Inputs are always bfloat16 with K=30; splits <= 1 are skipped.

    Returns list of dicts (CSV-ready). Failure dicts carry only a subset of
    the keys present on successful ones (no device / percentile columns).
    """
    rows = []
    device = torch.device("cuda")
    # NOTE(review): max() raises ValueError if `batches` or `splits` is empty.
    scratch = torch.zeros(max(1, max(batches) * max(splits)),
                          device=device, dtype=torch.int32)
    K = 30  # ablation harness is K<=32 only
    for pages in pages_list:
        for B in batches:
            inp = make_inputs(B, pages, K, torch.bfloat16)
            for split in splits:
                if split <= 1:
                    continue  # nothing to merge
                # Step 1: populate the workspace (ablation_mode=1).
                try:
                    call_ablation(inp, ws, scratch,
                                  ABL_LOCAL_WITH_WORKSPACE, split)
                    torch.cuda.synchronize()
                except Exception as e:
                    rows.append(dict(pages=pages, batch=B, split=split,
                                     merge_mode="setup_failed",
                                     mean_us=None, notes=f"populate raised: {e}"))
                    continue
                # Step 2: merge variants. Some require specific splits.
                # NOTE(review): the workspace is populated once and then reused
                # by every variant and every timed repetition — presumably the
                # merge kernels leave it intact; confirm against the .cu source.
                variants = [ABL_MERGE_PROD_DEFAULT, ABL_MERGE_CUB_WARP,
                            ABL_MERGE_CUB_BLOCK, ABL_MERGE_KWAY]
                for ablv in variants:
                    name = MERGE_ABL_NAMES[ablv]
                    # Untimed probe launch: an unsupported variant/split
                    # combination is recorded instead of aborting the sweep.
                    try:
                        call_ablation(inp, ws, scratch, ablv, split)
                        torch.cuda.synchronize()
                    except Exception as e:
                        rows.append(dict(pages=pages, batch=B, split=split,
                                         merge_mode=name, candidate_count=split * 32,
                                         mean_us=None,
                                         notes=f"raised: {repr(e)[:120]}"))
                        continue
                    # Default arguments pin the current variant and split.
                    t = time_kernel_us(
                        lambda av=ablv, sp=split: call_ablation(inp, ws, scratch, av, sp),
                        warmup, repeat,
                    )
                    if t is None:
                        rows.append(dict(pages=pages, batch=B, split=split,
                                         merge_mode=name, candidate_count=split * 32,
                                         mean_us=None, notes="time_kernel_us failed"))
                        continue
                    rows.append(dict(
                        device_name=device_name, sm_count=sm_count,
                        pages=pages, batch=B, split=split,
                        merge_mode=name, candidate_count=split * 32,
                        mean_us=t["mean"], p50_us=t["p50"], p90_us=t["p90"],
                        min_us=t["min"], max_us=t["max"], std_us=t["std"],
                        notes="",
                    ))
    return rows


def write_merge_mode_csv(merge_rows: List[dict], path: Path):
    """Write the merge ablation results to ``path`` in wide form.

    One CSV row per (pages, batch, split) with a mean-latency column per merge
    mode, the fastest mode, and warp_cub latency divided by the fastest
    latency. With no rows, the file contains a single '#' placeholder line.
    """
    if not merge_rows:
        with path.open("w") as f:
            f.write("# merge ablation skipped (use --merge-ablation to enable)\n")
        return
    # Pivot to wide form: one row per (pages, batch, split) with columns per merge mode.
    by_key = {}
    for r in merge_rows:
        key = (r["pages"], r["batch"], r["split"])
        # candidate_count is taken from the first row seen for the key.
        by_key.setdefault(key, {"candidate_count": r.get("candidate_count")})
        by_key[key][r["merge_mode"]] = r.get("mean_us")
    cols = ["pages", "batch", "split", "candidate_count",
            "warp_cub_us", "block_cub_us", "kway_us", "prod_default_us",
            "best_merge_mode", "best_merge_us", "speedup_best_vs_warp"]
    with path.open("w", newline="") as f:
        w = csv.writer(f)
        w.writerow(cols)
        for key in sorted(by_key):
            pages, B, split = key
            d = by_key[key]
            warp = d.get("warp_cub")
            block = d.get("block_cub")
            kway = d.get("kway")
            # Stored under the display name from MERGE_ABL_NAMES.
            prod = d.get("prod_default(legacy)")
            choices = [(name, t) for name, t in
                       (("warp_cub", warp), ("block_cub", block),
                        ("kway", kway), ("prod_default", prod))
                       if t is not None]
            if choices:
                best_name, best_us = min(choices, key=lambda x: x[1])
                sp = (warp / best_us) if (warp and best_us) else None
            else:
                best_name, best_us, sp = "n/a", None, None
            # Truthiness checks below also blank out a latency of exactly 0.0.
            w.writerow([pages, B, split, d.get("candidate_count"),
                        f"{warp:.3f}" if warp else "",
                        f"{block:.3f}" if block else "",
                        f"{kway:.3f}" if kway else "",
                        f"{prod:.3f}" if prod else "",
                        best_name,
                        f"{best_us:.3f}" if best_us else "",
                        f"{sp:.3f}" if sp else ""])
    print(f"wrote {path}")


# --------------------------------------------------------------------------- #
# Adversarial correctness — additional unit-test-style cases.
+# --------------------------------------------------------------------------- # +def adversarial_correctness_test(local_mode: int) -> List[dict]: + """Return list of dicts describing each adversarial case + per-method outcome.""" + device = torch.device("cuda") + K, RES_BOS, RES_EOS, B, PAGES = 30, 1, 2, 2, 4096 + cases = [] + + def build_scores(kind: str, dtype) -> torch.Tensor: + n = B * PAGES + if kind == "all_equal": + return torch.full((n,), 1.5, device=device, dtype=dtype) + if kind == "tie_heavy_high8": + x = torch.randn(n, device=device, dtype=torch.float32) + mask = torch.rand(n, device=device) < 0.05 + x[mask] = 100.0 # identical large values - ties for top-K + return x.to(dtype) + if kind == "mixed_sign": + x = torch.randn(n, device=device, dtype=torch.float32) * 10 + return x.to(dtype) + if kind == "threshold_overflow": + x = torch.zeros(n, device=device, dtype=torch.float32) + mask = torch.rand(n, device=device) < 0.10 + x[mask] = 1.0 # > K items at one bin + return x.to(dtype) + raise ValueError(kind) + + ws = make_workspace(B, max_split=32, K_local=32) + for kind in ("all_equal", "tie_heavy_high8", "mixed_sign", "threshold_overflow"): + scores = build_scores(kind, torch.bfloat16) + inp = make_inputs(B, PAGES, K, torch.bfloat16) + inp.scores.copy_(scores) + ref_sets, ref_remapped, thresholds = reference_topk(inp, MAPPING_NONE, 0.5) + # Fused + try: + inp.out.fill_(-1) + call_fused(inp, MAPPING_NONE, 0.5); torch.cuda.synchronize() + ok_f, note_f = check_correctness(inp, ref_sets, ref_remapped, thresholds, MAPPING_NONE, 0.5) + except Exception as e: + ok_f, note_f = False, f"raised: {e}" + # Adaptive split=1 (production path). 
+ try: + inp.out.fill_(-1) + call_adaptive(inp, ws, MAPPING_NONE, 0.5, + forced_split=1, forced_partition=1, local_mode=local_mode) + torch.cuda.synchronize() + ok_a1, note_a1 = check_correctness(inp, ref_sets, ref_remapped, thresholds, MAPPING_NONE, 0.5) + except Exception as e: + ok_a1, note_a1 = False, f"raised: {e}" + # Adaptive split=4 (merge path). + try: + inp.out.fill_(-1) + call_adaptive(inp, ws, MAPPING_NONE, 0.5, + forced_split=4, forced_partition=1, local_mode=local_mode) + torch.cuda.synchronize() + ok_a4, note_a4 = check_correctness(inp, ref_sets, ref_remapped, thresholds, MAPPING_NONE, 0.5) + except Exception as e: + ok_a4, note_a4 = False, f"raised: {e}" + cases.append(dict(case=kind, fused_ok=ok_f, fused_note=note_f, + adapt_split1_ok=ok_a1, adapt_split1_note=note_a1, + adapt_split4_ok=ok_a4, adapt_split4_note=note_a4)) + return cases + + +# --------------------------------------------------------------------------- # +# CSV / report writers. +# --------------------------------------------------------------------------- # +RAW_COLUMNS = [ + "device_name", "sm_count", "dtype", "mapping_mode", "mapping_name", "mapping_power", + "pages", "topk", "batch", "method", + "requested_split", "actual_split", + "local_mode", "merge_mode", "candidate_count", + "actual_path", + "mean_us", "p50_us", "p90_us", "min_us", "max_us", "std_us", + "correctness", "speedup_vs_sglang_fused", "speedup_vs_topk_cu", "notes", +] + + +def write_raw_csv(rows: List[Row], path: Path): + with path.open("w", newline="") as f: + w = csv.writer(f) + w.writerow(RAW_COLUMNS) + for r in rows: + d = asdict(r) + w.writerow([d[c] for c in RAW_COLUMNS]) + print(f"wrote {path} ({len(rows)} rows)") + + +def write_best_adaptive_csv(rows: List[Row], path: Path): + """One row per (pages,K,B,mapping,dtype). Choose best adaptive split among + rows whose actual_path starts with 'adaptive_'. 
Also emit fused/cub for ref.""" + cols = [ + "pages", "topk", "batch", "mapping_name", "dtype", + "best_adaptive_split", "best_adaptive_local_mode", "best_adaptive_merge_mode", + "best_adaptive_latency_us", "best_adaptive_actual_path", + "sglang_fused_latency_us", "topk_cu_latency_us", + "speedup_best_adaptive_vs_sglang", "speedup_best_adaptive_vs_topk_cu", + "adaptive_wins_vs_sglang", "adaptive_wins_vs_topk_cu", + ] + by_key: Dict[Tuple, Dict[str, object]] = {} + for r in rows: + key = (r.pages, r.topk, r.batch, r.mapping_name, r.dtype) + rec = by_key.setdefault(key, dict(adaptive=[], fused_us=None, cub_us=None)) + if r.method == "topk_sglang_fused" and r.correctness and r.mean_us is not None: + rec["fused_us"] = r.mean_us + elif r.method == "topk_cu" and r.correctness and r.mean_us is not None: + rec["cub_us"] = r.mean_us + elif (r.method == "adaptive_merge" and r.correctness + and r.actual_path.startswith("adaptive_") + and r.mean_us is not None): + rec["adaptive"].append((r.mean_us, r.requested_split, r.local_mode, + r.merge_mode, r.actual_path)) + + with path.open("w", newline="") as f: + w = csv.writer(f) + w.writerow(cols) + for key, rec in sorted(by_key.items()): + pages, K, B, mapping_name, dtype = key + best = min(rec["adaptive"], default=None) + if best is None: + best_us, best_split, best_local, best_merge, best_path = None, None, "n/a", "n/a", "n/a" + else: + best_us, best_split, best_local, best_merge, best_path = best + fused_us = rec["fused_us"] + cub_us = rec["cub_us"] + sp_f = (fused_us / best_us) if (best_us and fused_us) else None + sp_c = (cub_us / best_us) if (best_us and cub_us) else None + wins_f = (sp_f is not None and sp_f >= WIN_THRESHOLD) + wins_c = (sp_c is not None and sp_c >= WIN_THRESHOLD) + w.writerow([ + pages, K, B, mapping_name, dtype, + best_split, best_local, best_merge, + f"{best_us:.3f}" if best_us is not None else "", + best_path, + f"{fused_us:.3f}" if fused_us is not None else "", + f"{cub_us:.3f}" if cub_us is not None 
else "", + f"{sp_f:.3f}" if sp_f is not None else "", + f"{sp_c:.3f}" if sp_c is not None else "", + wins_f, wins_c, + ]) + print(f"wrote {path}") + + +def k_bucket(K: int) -> str: + if K <= 32: return "small_K(<=32)" + if K <= 512: return "mid_K(64-512)" + return "large_K(>=1024)" + + +def write_advantage_summary_csv(rows: List[Row], path: Path): + """Group by (k_bucket, pages, batch). Count adaptive wins, mean/best speedup, + best split distribution, common actual_path.""" + by_key: Dict[Tuple, Dict[str, list]] = {} + # Build per-cell best_adaptive entries (one per setting). + setting_best: Dict[Tuple, Dict] = {} + for r in rows: + if not (r.method == "adaptive_merge" and r.correctness and r.mean_us is not None): + continue + if not r.actual_path.startswith("adaptive_"): + continue + key = (r.pages, r.topk, r.batch, r.mapping_name, r.dtype) + rec = setting_best.setdefault(key, {"best_us": float("inf"), "split": None, "path": None}) + if r.mean_us < rec["best_us"]: + rec["best_us"] = r.mean_us; rec["split"] = r.requested_split; rec["path"] = r.actual_path + fused_lookup = {(r.pages, r.topk, r.batch, r.mapping_name, r.dtype): r.mean_us + for r in rows + if r.method == "topk_sglang_fused" and r.correctness and r.mean_us is not None} + + for setting, best in setting_best.items(): + pages, K, B, mapping_name, dtype = setting + bucket = k_bucket(K) + gkey = (bucket, pages, B, mapping_name, dtype) + agg = by_key.setdefault(gkey, dict(speedups=[], splits=[], paths=[], total=0, wins=0)) + agg["total"] += 1 + fused_us = fused_lookup.get(setting) + if fused_us: + sp = fused_us / best["best_us"] + agg["speedups"].append(sp) + if sp >= WIN_THRESHOLD: agg["wins"] += 1 + agg["splits"].append(best["split"]) + agg["paths"].append(best["path"]) + + cols = ["k_bucket", "pages", "batch", "mapping_name", "dtype", + "n_settings", "n_adaptive_wins", "win_rate", + "best_speedup", "mean_speedup", "median_speedup", + "best_split_mode", "best_split_distribution", "common_actual_path"] + 
with path.open("w", newline="") as f: + w = csv.writer(f) + w.writerow(cols) + for gkey, agg in sorted(by_key.items()): + bucket, pages, B, mapping_name, dtype = gkey + sps = agg["speedups"] + splits = [s for s in agg["splits"] if s is not None] + mode_split = (statistics.mode(splits) if splits else None) + split_dist = "|".join(f"{s}:{splits.count(s)}" for s in sorted(set(splits))) if splits else "" + paths = [p for p in agg["paths"] if p] + common_path = statistics.mode(paths) if paths else "" + w.writerow([ + bucket, pages, B, mapping_name, dtype, + agg["total"], agg["wins"], + f"{agg['wins']/agg['total']:.2f}" if agg["total"] else "", + f"{max(sps):.3f}" if sps else "", + f"{statistics.mean(sps):.3f}" if sps else "", + f"{statistics.median(sps):.3f}" if sps else "", + mode_split, split_dist, common_path, + ]) + print(f"wrote {path}") + + +def _fmt(v, nd=2, default=" -"): + if v is None: return default + return f"{v:>{nd+5}.{nd}f}" + + +def print_per_K_tables(rows: List[Row], splits: List[int], mapping_filter: Optional[str] = None): + """Compact per-K terminal tables.""" + keys = sorted({(r.topk, r.mapping_name, r.dtype, r.pages, r.batch) for r in rows}) + by_k_map = {} + for r in rows: + if mapping_filter is not None and r.mapping_name != mapping_filter: + continue + by_k_map.setdefault((r.topk, r.mapping_name, r.dtype), []).append(r) + for (K, mapping_name, dtype), kr in sorted(by_k_map.items()): + print() + print(f"=== K={K} mapping={mapping_name} dtype={dtype} ===") + # Column header + hdr = (f"{'pages':>6} {'B':>3} {'fused_us':>9} {'cub_us':>8} " + + " ".join(f"{'s='+str(s):>9}" for s in splits) + + f" {'best_us':>8} {'split':>5} {'sp_vs_fused':>11}") + print(hdr) + print("-" * len(hdr)) + cells = {} + for r in kr: + cells.setdefault((r.pages, r.batch), {}) + kk = (r.pages, r.batch) + if r.method == "topk_sglang_fused": + cells[kk]["fused"] = r.mean_us + elif r.method == "topk_cu": + cells[kk]["cub"] = r.mean_us + elif r.method == "adaptive_merge": + 
cells[kk].setdefault("adapt", {})[r.requested_split] = (r.mean_us, r.actual_path) + for (pages, B), c in sorted(cells.items()): + adapt = c.get("adapt", {}) + adapt_us = {s: (adapt.get(s, (None, ""))[0]) for s in splits} + valid = [(s, adapt[s][0]) for s in splits if s in adapt + and adapt[s][0] is not None + and adapt[s][1].startswith("adaptive_")] + if valid: + best_split, best_us = min(valid, key=lambda kv: kv[1]) + else: + best_split, best_us = None, None + fused_us = c.get("fused") + sp = (fused_us / best_us) if (fused_us and best_us) else None + print( + f"{pages:>6d} {B:>3d} {_fmt(fused_us):>9} {_fmt(c.get('cub')):>8} " + + " ".join(f"{_fmt(adapt_us[s]):>9}" for s in splits) + + f" {_fmt(best_us):>8} {str(best_split) if best_split else '-':>5} " + + (f"{sp:>10.3f}x" if sp else f"{'-':>11}") + ) + + +def write_markdown_report(rows: List[Row], device_info: dict, args, path: Path, + splits: List[int], best_csv: Path, advantage_csv: Path, + raw_csv: Path, merge_csv: Optional[Path] = None, + merge_rows: Optional[List[dict]] = None, + adversarial_csv: Optional[Path] = None, + adversarial_rows: Optional[List[dict]] = None): + failed = [r for r in rows if r.correctness is False] + n_total = len(rows) + with path.open("w") as f: + f.write(f"# TopK Setting Sweep Report\n\n") + f.write(f"- Device: **{device_info['name']}** (SMs: {device_info['sm_count']})\n") + f.write(f"- torch: {torch.__version__} CUDA: {torch.version.cuda}\n") + f.write(f"- Pages: {args.pages}\n") + f.write(f"- K: {args.ks}\n") + f.write(f"- Batches: {args.batches}\n") + f.write(f"- Adaptive splits: {splits}\n") + f.write(f"- Mappings: {args.mappings}\n") + f.write(f"- Local mode: {'BLOCK_FULL_SORT' if args.local_mode == LOCAL_BLOCK_FULL_SORT else 'SELECT32_SORT32'}\n") + f.write(f"- warmup={args.warmup}, repeat={args.repeat}\n") + f.write(f"- Total measurements: {n_total}, correctness failures: {len(failed)}\n\n") + + # Per-K compact tables. 
+ f.write("## Per-K latency tables (us)\n\n") + by_k = {} + for r in rows: + by_k.setdefault((r.topk, r.mapping_name, r.dtype), []).append(r) + for (K, mapping_name, dtype), kr in sorted(by_k.items()): + f.write(f"### K={K}, mapping={mapping_name}, dtype={dtype}\n\n") + cells = {} + for r in kr: + kk = (r.pages, r.batch) + cells.setdefault(kk, {}) + if r.method == "topk_sglang_fused": cells[kk]["fused"] = r.mean_us + elif r.method == "topk_cu": cells[kk]["cub"] = r.mean_us + elif r.method == "adaptive_merge": + cells[kk].setdefault("adapt", {})[r.requested_split] = (r.mean_us, r.actual_path) + head = (["pages", "B", "fused_us", "cub_us"] + + [f"adapt_s{s}_us" for s in splits] + + ["best_us", "best_split", "actual_path", "speedup_vs_fused"]) + f.write("| " + " | ".join(head) + " |\n") + f.write("|" + "|".join("---:" for _ in head) + "|\n") + for (pages, B), c in sorted(cells.items()): + adapt = c.get("adapt", {}) + row = [str(pages), str(B), + f"{c.get('fused'):.2f}" if c.get('fused') else "-", + f"{c.get('cub'):.2f}" if c.get('cub') else "-"] + for s in splits: + val = adapt.get(s, (None, ""))[0] + row.append(f"{val:.2f}" if val else "-") + valid = [(s, adapt[s][0], adapt[s][1]) for s in splits + if s in adapt and adapt[s][0] is not None + and adapt[s][1].startswith("adaptive_")] + if valid: + best_split, best_us, best_path = min(valid, key=lambda x: x[1]) + else: + best_split, best_us, best_path = None, None, "-" + fused_us = c.get('fused') + sp = (fused_us / best_us) if (fused_us and best_us) else None + # If everything was fused-fallback, fall back to noting that. + if best_us is None: + fb = next((v for v in adapt.values() if v[0] is not None), (None, "-")) + row += ["-", "-", fb[1], "-"] + else: + row += [f"{best_us:.2f}", str(best_split), best_path, + f"{sp:.3f}x" if sp else "-"] + f.write("| " + " | ".join(row) + " |\n") + f.write("\n") + + # Merge-mode ablation (K=30 only). 
+ if merge_csv is not None and merge_rows: + f.write("## Merge-mode ablation (K=30, merge stage in isolation)\n\n") + f.write("Source: `topk_output_adaptive_workspace_ablation` modes 5/6/7/11.\n\n") + with merge_csv.open() as g: + f.write("```\n" + g.read() + "```\n\n") + + # Region analysis. + f.write("## Parallel-advantage region analysis\n\n") + f.write(f"Win threshold: speedup_vs_sglang >= {WIN_THRESHOLD}.\n\n") + with advantage_csv.open() as g: + f.write("```\n" + g.read() + "```\n\n") + + # Adversarial correctness. + if adversarial_rows is not None: + f.write("## Adversarial correctness cases\n\n") + f.write("| case | fused | adapt s=1 | adapt s=4 |\n") + f.write("|---|:-:|:-:|:-:|\n") + for r in adversarial_rows: + f.write(f"| {r['case']} | {'PASS' if r['fused_ok'] else 'FAIL'} | " + f"{'PASS' if r['adapt_split1_ok'] else 'FAIL'} | " + f"{'PASS' if r['adapt_split4_ok'] else 'FAIL'} |\n") + f.write("\n") + + # Recommended dispatch policy + f.write("## Recommended production dispatch policy\n\n") + # Pick the fastest correct adaptive split (lowest mean_us) per (K bucket, pages, batch, mapping).
+ best_by_bucket = {} + for r in rows: + if not (r.method == "adaptive_merge" and r.correctness + and r.mean_us is not None + and r.actual_path.startswith("adaptive_")): + continue + key = (k_bucket(r.topk), r.pages, r.batch, r.mapping_name) + ent = best_by_bucket.setdefault(key, {"best": (float('inf'), None)}) + if r.mean_us < ent["best"][0]: + ent["best"] = (r.mean_us, r.requested_split) + bucket_splits = {} + for (bucket, pages, B, mapping), v in best_by_bucket.items(): + bucket_splits.setdefault((bucket, pages, B, mapping), []).append(v["best"][1]) + f.write("| K_bucket | pages | B | mapping | recommended_split |\n") + f.write("|---|---:|---:|---|---:|\n") + for key, splits_list in sorted(bucket_splits.items()): + bucket, pages, B, mapping = key + try: + rec = statistics.mode(splits_list) + except statistics.StatisticsError: + rec = sorted(splits_list)[0] + f.write(f"| {bucket} | {pages} | {B} | {mapping} | {rec} |\n") + f.write("\n- For `large_K(>=1024)` adaptive entry routes to fused (zero-overhead).\n") + f.write("- For `mid_K(64-512)` adaptive entry currently routes to fused; ") + f.write("a dedicated mid-K kernel is future work.\n") + + # Failures. + if failed: + f.write("\n## Correctness failures\n\n") + f.write("| pages | K | B | mapping | method | split | actual_path | notes |\n") + f.write("|---:|---:|---:|---|---|---:|---|---|\n") + for r in failed: + f.write(f"| {r.pages} | {r.topk} | {r.batch} | {r.mapping_name} | " + f"{r.method} | {r.requested_split} | {r.actual_path} | " + f"{r.notes} |\n") + f.write(f"\n## Files\n- raw csv: `{raw_csv}`\n") + f.write(f"- best adaptive csv: `{best_csv}`\n") + f.write(f"- advantage summary csv: `{advantage_csv}`\n") + print(f"wrote {path}") + + +# --------------------------------------------------------------------------- # +# CLI. 
+# --------------------------------------------------------------------------- # +def parse_args(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--pages", type=int, nargs="+", default=DEFAULT_PAGES) + p.add_argument("--ks", type=int, nargs="+", default=DEFAULT_KS) + p.add_argument("--batches", type=int, nargs="+", default=DEFAULT_BATCHES) + p.add_argument("--splits", type=int, nargs="+", default=DEFAULT_SPLITS) + p.add_argument("--mappings", type=str, nargs="+", default=DEFAULT_MAPPING_NAMES, + help="Mapping mode names (NONE, TRUNC8, POWER, LOG, ASINH, LOG1P, ERF, TANH).") + p.add_argument("--mapping-power", type=float, default=0.5) + p.add_argument("--dtypes", type=str, nargs="+", default=DEFAULT_DTYPES) + p.add_argument("--local-mode", type=int, default=LOCAL_SELECT32_SORT32, + choices=[LOCAL_BLOCK_FULL_SORT, LOCAL_SELECT32_SORT32], + help="0=BLOCK_FULL_SORT, 1=SELECT32_SORT32 (default).") + p.add_argument("--warmup", type=int, default=20) + p.add_argument("--repeat", type=int, default=200) + p.add_argument("--output-dir", type=Path, + default=Path("bench_results") / time.strftime("setting_sweep_%Y%m%d_%H%M%S")) + p.add_argument("--print-tables", action="store_true", + help="Also print per-K latency tables to stdout (large output).") + p.add_argument("--merge-ablation", action="store_true", default=True, + help="Run merge-mode ablation sub-sweep (K=30 only). 
Default: on.") + p.add_argument("--no-merge-ablation", dest="merge_ablation", action="store_false", + help="Skip the merge-mode ablation sub-sweep.") + p.add_argument("--adversarial", action="store_true", default=True, + help="Run adversarial correctness cases (default: on).") + p.add_argument("--no-adversarial", dest="adversarial", action="store_false", + help="Skip adversarial correctness cases.") + return p.parse_args() + + +def main(): + args = parse_args() + if not torch.cuda.is_available(): + sys.exit("CUDA is required.") + args.output_dir.mkdir(parents=True, exist_ok=True) + + device_info = dict( + name=torch.cuda.get_device_name(0), + sm_count=torch.cuda.get_device_properties(0).multi_processor_count, + ) + print(f"Device: {device_info['name']} SMs={device_info['sm_count']}") + print(f"Output dir: {args.output_dir.resolve()}") + + # Validate mappings. + mapping_modes = [] + for name in args.mappings: + if name not in MAPPING_BY_NAME: + sys.exit(f"unknown mapping name: {name} (valid: {list(MAPPING_BY_NAME)})") + mapping_modes.append(MAPPING_BY_NAME[name]) + + # Pre-allocate workspace large enough for the largest configuration. 
+ B_max = max(args.batches) + ws = make_workspace(B_max=B_max, max_split=max(args.splits), K_local=32) + + configs = [(pages, K, B, mode, dtype) + for pages in args.pages + for K in args.ks + for B in args.batches + for mode in mapping_modes + for dtype in args.dtypes] + print(f"Configs: {len(configs)} (each runs fused + cub + {len(args.splits)} adaptive)") + print(f"Splits: {args.splits} warmup={args.warmup} repeat={args.repeat}") + + rows: List[Row] = [] + t0 = time.time() + for i, cfg in enumerate(configs, 1): + pages, K, B, mode, dtype = cfg + if i % 5 == 0 or i == 1: + print(f"[{i:3d}/{len(configs)}] pages={pages} K={K} B={B} " + f"mapping={MAPPING_NAMES[mode]} dtype={dtype} " + f"(elapsed {time.time()-t0:.1f}s)") + rows.extend(run_one_setting( + pages, K, B, mode, args.mapping_power, dtype, + args.splits, args.local_mode, + args.warmup, args.repeat, ws, + device_info["name"], device_info["sm_count"], + )) + print(f"\nSweep complete in {time.time()-t0:.1f}s. rows={len(rows)}") + + # Output files. + raw_csv = args.output_dir / "topk_setting_sweep_raw.csv" + best_csv = args.output_dir / "topk_setting_sweep_best_adaptive.csv" + advantage_csv = args.output_dir / "topk_parallel_advantage_summary.csv" + merge_csv = args.output_dir / "topk_merge_mode_summary.csv" + adversarial_csv = args.output_dir / "topk_adversarial_correctness.csv" + report_md = args.output_dir / "topk_setting_sweep_report.md" + + write_raw_csv(rows, raw_csv) + write_best_adaptive_csv(rows, best_csv) + write_advantage_summary_csv(rows, advantage_csv) + + # Merge-mode ablation (K=30 only). 
+ merge_rows = [] + if args.merge_ablation: + print("\nMerge-mode ablation (K=30, ablation harness):") + merge_rows = run_merge_ablation( + pages_list=args.pages, batches=args.batches, + splits=args.splits, warmup=args.warmup, repeat=args.repeat, + ws=ws, device_name=device_info["name"], sm_count=device_info["sm_count"], + ) + print(f" collected {len(merge_rows)} merge-only timings") + write_merge_mode_csv(merge_rows, merge_csv) + + # Adversarial correctness check. + adv_rows = [] + if args.adversarial: + print("\nAdversarial correctness check (K=30, MAPPING_NONE, B=2, pages=4096):") + adv_rows = adversarial_correctness_test(args.local_mode) + for r in adv_rows: + print(f" case={r['case']:>22} fused={r['fused_ok']} " + f"adapt_s1={r['adapt_split1_ok']} adapt_s4={r['adapt_split4_ok']}") + with adversarial_csv.open("w", newline="") as f: + w = csv.writer(f) + w.writerow(["case", "fused_ok", "fused_note", + "adapt_split1_ok", "adapt_split1_note", + "adapt_split4_ok", "adapt_split4_note"]) + for r in adv_rows: + w.writerow([r["case"], r["fused_ok"], r["fused_note"], + r["adapt_split1_ok"], r["adapt_split1_note"], + r["adapt_split4_ok"], r["adapt_split4_note"]]) + print(f"wrote {adversarial_csv}") + + write_markdown_report(rows, device_info, args, report_md, + args.splits, best_csv, advantage_csv, raw_csv, + merge_csv=merge_csv, merge_rows=merge_rows, + adversarial_csv=adversarial_csv, adversarial_rows=adv_rows) + + # Optional terminal tables. + if args.print_tables: + print_per_K_tables(rows, args.splits) + + # Short summary. 
+ print("\n" + "=" * 70) + print(f"Files written under: {args.output_dir.resolve()}") + print(f" raw : {raw_csv.name}") + print(f" best_adapt: {best_csv.name}") + print(f" advantage : {advantage_csv.name}") + print(f" report : {report_md.name}") + failed = [r for r in rows if r.correctness is False] + print(f"Correctness failures: {len(failed)}") + if failed: + for r in failed[:10]: + print(f" - pages={r.pages} K={r.topk} B={r.batch} " + f"map={r.mapping_name} method={r.method} split={r.requested_split} " + f"path={r.actual_path}: {r.notes}") + if len(failed) > 10: + print(f" ... and {len(failed) - 10} more (see raw csv)") + # Quick win-region rollup. + n_adaptive = sum(1 for r in rows if r.method == "adaptive_merge" + and r.actual_path.startswith("adaptive_") and r.correctness) + n_wins = sum(1 for r in rows if r.method == "adaptive_merge" + and r.actual_path.startswith("adaptive_") and r.correctness + and r.speedup_vs_sglang_fused is not None + and r.speedup_vs_sglang_fused >= WIN_THRESHOLD) + print(f"Adaptive-rows (correct, real adaptive path): {n_adaptive}") + print(f" -> wins vs fused (>= {WIN_THRESHOLD}x): {n_wins} " + f"({100.0 * n_wins / max(n_adaptive,1):.1f}%)") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/calibrate_topk.py b/benchmarks/calibrate_topk.py new file mode 100644 index 00000000..f3343aaa --- /dev/null +++ b/benchmarks/calibrate_topk.py @@ -0,0 +1,231 @@ +#!/usr/bin/env python3 +""" +Offline calibration for TopK mapping modes 1 (LUT CDF) and 2 (quantile). 
+ +Runs the model on real data with hit-rate profiling enabled, collects score +histograms from the topk_sglang kernel, and generates: + - lut.npy : uint8[256] CDF-equalized LUT for mapping mode 1 + - quantiles.npy: float32[256] quantile breakpoints for mapping mode 2 + +Usage: + python benchmarks/calibrate_topk.py \ + --model-name Qwen/Qwen3-1.7B \ + --topk-val 30 --mem 0.7 \ + --output-dir calibration_output/ +""" + +import argparse +import json +import os +import shutil +import sys + +import numpy as np + +# Add project root to path so we can import from benchmarks/ +sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) + +from benchmarks.profile_topk_distribution import ( + compute_lut_from_histogram, + generate_tables_from_histograms, +) + + +def main(): + parser = argparse.ArgumentParser( + description="Offline calibration for TopK mapping modes 1 & 2" + ) + parser.add_argument("--model-name", type=str, default="Qwen/Qwen3-1.7B") + parser.add_argument("--topk-val", type=int, default=30) + parser.add_argument("--page-size", type=int, default=16) + parser.add_argument("--mem", type=float, default=0.7) + parser.add_argument( + "--max-total-tokens", + type=int, + default=1048576, + help="Hard cap on KV pool token slots (ServerArgs.max_total_tokens). " + "Block-sparse profiling uses a small bytes/token estimate, so the auto " + "budget can be huge on large GPUs; VTXGraphAttnBackend then allocates " + "dense bf16 sparse_prefill K/V buffers proportional to this cap (~4 KiB per " + "token per buffer). For offline calibration, a few hundred K to 1M tokens " + "is usually enough.", + ) + parser.add_argument( + "--min-free-disk-gb", + type=float, + default=20.0, + help="Abort if the filesystem for --output-dir (and HF cache, typically the same) " + "has less than this many GiB free. First-time model downloads need many GiB. 
" + "Set to 0 to disable.", + ) + parser.add_argument("--kv-cache-dtype", type=str, default="auto") + parser.add_argument("--topk-type", type=str, default="sglang") + parser.add_argument("--num-prompts", type=int, default=16, + help="Number of calibration prompts to use (default: 16)") + parser.add_argument("--output-dir", type=str, default="calibration_output/") + parser.add_argument("--vortex-module-name", type=str, default="block_sparse_attention") + parser.add_argument( + "--watchdog-timeout", + type=float, + default=None, + metavar="SEC", + help="SGLang scheduler watchdog (seconds). Forward batches must complete within this time. " + "Default: engine default (300). Use 0 to disable when using this repo's SGLang fork.", + ) + args = parser.parse_args() + + # Classic HTTP downloads avoid XET chunk reconstruction ("Background writer channel + # closed") that often surfaces when the disk is full or nearly full. + if "HF_HUB_DISABLE_XET" not in os.environ: + os.environ["HF_HUB_DISABLE_XET"] = "1" + + if args.min_free_disk_gb > 0: + check_path = os.path.abspath(args.output_dir) + while check_path and not os.path.isdir(check_path): + parent = os.path.dirname(check_path) + if parent == check_path: + check_path = os.getcwd() + break + check_path = parent + usage = shutil.disk_usage(check_path) + free_gb = usage.free / (1024.0**3) + if free_gb < args.min_free_disk_gb: + raise SystemExit( + f"[calibrate] ERROR: Only {free_gb:.1f} GiB free on filesystem containing " + f"{args.output_dir!r} (checked from {check_path!r}). " + f"Need at least ~{args.min_free_disk_gb} GiB for Hugging Face weights, hub cache, " + f"and logs. Free disk space or point HF_HOME at a larger disk. 
" + f"To skip this check: --min-free-disk-gb 0" + ) + + # Lazy imports to avoid slow startup when just checking --help + import sglang as sgl + import torch + import vortex_torch + + os.makedirs(args.output_dir, exist_ok=True) + + print(f"[calibrate] Launching engine with hit-rate profiling enabled...") + engine_kwargs = dict( + model_path=args.model_name, + disable_cuda_graph=True, + page_size=args.page_size, + vortex_topk_val=args.topk_val, + disable_overlap_schedule=True, + attention_backend="flashinfer", + enable_vortex_sparsity=True, + vortex_page_reserved_bos=1, + vortex_page_reserved_eos=2, + vortex_layers_skip=list(range(1)), + vortex_module_name=args.vortex_module_name, + vortex_max_seq_lens=12288, + mem_fraction_static=args.mem, + max_total_tokens=args.max_total_tokens, + kv_cache_dtype=args.kv_cache_dtype, + vortex_topk_type=args.topk_type, + vortex_topk_mapping_mode=0, # Use mode 0 during calibration + vortex_topk_histogram=True, # Enable histogram collection + ) + if args.watchdog_timeout is not None: + engine_kwargs["watchdog_timeout"] = args.watchdog_timeout + llm = sgl.Engine(**engine_kwargs) + + # Clear any residual histograms in the worker process + llm.clear_topk_histograms() + + # Load calibration prompts + prompts_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "..", "examples", "amc23.jsonl" + ) + with open(prompts_path, "r", encoding="utf-8") as f: + all_requests = [json.loads(line) for line in f] + + # Use up to num_prompts + requests = all_requests[:args.num_prompts] + prompts = [req["prompt"] for req in requests] + + print(f"[calibrate] Running {len(prompts)} calibration prompts...") + sampling_params = { + "temperature": 0.6, + "top_p": 0.95, + "top_k": 20, + "max_new_tokens": 8192, + } + llm.generate(prompts, sampling_params) + + # Collect histograms via RPC from worker process + histograms = llm.get_topk_histograms() + print(f"[calibrate] Collected {len(histograms)} histogram batches") + + if len(histograms) == 0: + 
print("[calibrate] ERROR: No histograms collected. " + "Ensure topk_type='sglang' and vortex_topk_histogram=True.", + file=sys.stderr) + llm.shutdown() + sys.exit(1) + + # Stack all histograms: each is [eff_bs, 256], concatenate along batch dim + all_hists = torch.cat(histograms, dim=0).numpy() # [total_samples, 256] + print(f"[calibrate] Total histogram samples: {all_hists.shape[0]}") + + # Regression guard: refuse to save a collapsed histogram. A healthy + # calibration touches tens to hundreds of bins; if almost everything lands + # in a single bin, the scoring pipeline silently produced zero scores + # (see the Sgl_Decode_Plan_Workload_Kernel `w > topk_val` bug fixed in + # csrc/utils_sglang.cu). Saving 20+ GB of all-zeros wastes disk and poisons + # downstream benches, so fail loudly here. + _pooled = all_hists.sum(axis=0).astype(np.float64) + _total = float(_pooled.sum()) + if _total > 0: + _top_frac = float(_pooled.max()) / _total + _nz_bins = int((_pooled > 0).sum()) + if _top_frac > 0.95 or _nz_bins < 5: + llm.shutdown() + raise SystemExit( + f"[calibrate] ERROR: degenerate histogram — top bin holds " + f"{_top_frac:.2%} of mass, only {_nz_bins}/256 bins nonzero. " + f"The scoring pipeline is likely not running (check " + f"winfo_num_workloads in plan_decode, or `w > topk_val` in " + f"Sgl_Decode_Plan_Workload_Kernel). Refusing to save to avoid " + f"writing a useless multi-GB file." 
+ ) + + # --- Generate LUT (mode 1) --- + # Aggregate histogram across all samples + avg_histogram = all_hists.mean(axis=0) + lut = compute_lut_from_histogram(avg_histogram) + lut_path = os.path.join(args.output_dir, "lut.npy") + np.save(lut_path, lut) + print(f"[calibrate] Saved LUT to {lut_path} (shape={lut.shape}, dtype={lut.dtype})") + + # --- Generate quantiles (mode 2) --- + # Use bin centers as proxy scores weighted by histogram counts + bin_centers = np.arange(256, dtype=np.float32) + # Expand histogram counts into a weighted score distribution + total_counts = avg_histogram.astype(np.float64) + total = total_counts.sum() + if total > 0: + cdf = np.cumsum(total_counts) / total + # Invert CDF to get quantile breakpoints in [0, 255] space + percentiles = np.linspace(0, 1, 256) + quantiles = np.interp(percentiles, cdf, bin_centers).astype(np.float32) + else: + quantiles = bin_centers.copy() + + quantiles_path = os.path.join(args.output_dir, "quantiles.npy") + np.save(quantiles_path, quantiles) + print(f"[calibrate] Saved quantiles to {quantiles_path} (shape={quantiles.shape}, dtype={quantiles.dtype})") + + # Save raw histograms for debugging + raw_path = os.path.join(args.output_dir, "raw_histograms.npy") + np.save(raw_path, all_hists) + print(f"[calibrate] Saved raw histograms to {raw_path} (shape={all_hists.shape})") + + # Cleanup + llm.clear_topk_histograms() + llm.shutdown() + print(f"[calibrate] Done. Output files in {args.output_dir}/") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/profile_adaptive_overhead.py b/benchmarks/profile_adaptive_overhead.py new file mode 100644 index 00000000..5c0b2118 --- /dev/null +++ b/benchmarks/profile_adaptive_overhead.py @@ -0,0 +1,212 @@ +"""Decompose the adaptive split-2 TopK kernel's latency into Phase-1, +Phase-2, and barrier+launch overhead — and compare against the naive +CUB sort (topk.cu) and the single-CTA radix baseline (topk_sglang.cu). 
+ +No remap (mode=0), bfloat16 scores only, to keep the comparison clean. + +Usage: + python benchmarks/profile_adaptive_overhead.py [--gpu 4] +""" +from __future__ import annotations + +import argparse +import json +import math +from typing import Dict, List + +import torch + +from vortex_torch_C import ( + topk_output, + topk_output_sglang, + topk_output_adaptive, + topk_adaptive_phase1_only, + topk_adaptive_phase2_only, +) + + +def make_inputs(bs: int, pages: int, K: int, reserved_bos: int = 1, reserved_eos: int = 1, + device: str = "cuda") -> Dict[str, torch.Tensor]: + per_row = pages + reserved_bos + reserved_eos + dense_kv_indptr = torch.arange( + 0, (bs + 1) * per_row, per_row, device=device, dtype=torch.int32) + dense_kv_indices = torch.arange(bs * per_row, device=device, dtype=torch.int32) + per_sparse = K + reserved_bos + reserved_eos + sparse_kv_indptr = torch.arange( + 0, (bs + 1) * per_sparse, per_sparse, device=device, dtype=torch.int32) + sparse_kv_indices = torch.zeros(bs * per_sparse, device=device, dtype=torch.int32) + x = torch.randn(bs * per_row, device=device, dtype=torch.bfloat16) + partial_scores = torch.empty(bs * 2 * K, device=device, dtype=torch.float32) + partial_indices = torch.empty(bs * 2 * K, device=device, dtype=torch.int32) + return dict( + x=x, + dense_kv_indptr=dense_kv_indptr, + dense_kv_indices=dense_kv_indices, + sparse_kv_indptr=sparse_kv_indptr, + sparse_kv_indices=sparse_kv_indices, + partial_scores=partial_scores, + partial_indices=partial_indices, + ) + + +def time_kernel(fn, args, warmup: int = 20, repeat: int = 200) -> float: + """Return mean ms.""" + for _ in range(warmup): + fn(*args) + torch.cuda.synchronize() + starts = [torch.cuda.Event(enable_timing=True) for _ in range(repeat)] + ends = [torch.cuda.Event(enable_timing=True) for _ in range(repeat)] + for i in range(repeat): + starts[i].record() + fn(*args) + ends[i].record() + torch.cuda.synchronize() + times = [s.elapsed_time(e) for s, e in zip(starts, ends)] + 
return sum(times) / len(times) + + +def run_config(bs: int, pages: int, K: int, reserved_bos: int = 1, reserved_eos: int = 1, + warmup: int = 20, repeat: int = 200) -> Dict[str, float]: + inp = make_inputs(bs, pages, K, reserved_bos, reserved_eos) + + # --- baseline: topk_output_sglang (single-CTA radix-select, mode=0) --- + sglang_args = ( + inp["x"], inp["dense_kv_indptr"], inp["sparse_kv_indptr"], + inp["dense_kv_indices"], inp["sparse_kv_indices"], + bs, K, reserved_bos, reserved_eos, pages, + ) + sglang_ms = time_kernel(topk_output_sglang, sglang_args, warmup, repeat) + + # --- naive CUB sort: topk_output (only if pages <= 8192 — template ladder limit) --- + naive_ms = float("nan") + if pages <= 8192: + naive_args = ( + inp["x"], inp["dense_kv_indptr"], inp["dense_kv_indices"], + inp["sparse_kv_indptr"], inp["sparse_kv_indices"], + bs, K, reserved_bos, reserved_eos, pages, + ) + try: + naive_ms = time_kernel(topk_output, naive_args, warmup, repeat) + except RuntimeError as e: + print(f"[naive skip] bs={bs} pages={pages} K={K}: {e}") + + # --- adaptive full --- + adaptive_args = ( + inp["x"], inp["dense_kv_indptr"], inp["sparse_kv_indptr"], + inp["dense_kv_indices"], inp["sparse_kv_indices"], + bs, K, reserved_bos, reserved_eos, pages, + 0, # mapping_mode = NONE + 0.5, # mapping_power (unused) + ) + adaptive_ms = time_kernel(topk_output_adaptive, adaptive_args, warmup, repeat) + + # --- adaptive Phase 1 only --- + p1_args = ( + inp["x"], inp["dense_kv_indptr"], inp["dense_kv_indices"], + inp["partial_scores"], inp["partial_indices"], + bs, K, reserved_bos, reserved_eos, pages, + ) + p1_ms = time_kernel(topk_adaptive_phase1_only, p1_args, warmup, repeat) + + # --- adaptive Phase 2 only (workspace pre-populated by the last p1 call) --- + p2_args = ( + inp["partial_scores"], inp["partial_indices"], + inp["sparse_kv_indptr"], inp["sparse_kv_indices"], + bs, K, reserved_bos, + ) + p2_ms = time_kernel(topk_adaptive_phase2_only, p2_args, warmup, repeat) + + 
overhead_ms = adaptive_ms - (p1_ms + p2_ms) + + return { + "bs": bs, "pages": pages, "K": K, + "naive_ms": naive_ms, + "sglang_ms": sglang_ms, + "adaptive_ms": adaptive_ms, + "phase1_ms": p1_ms, + "phase2_ms": p2_ms, + "p1_plus_p2_ms": p1_ms + p2_ms, + "overhead_ms": overhead_ms, + "overhead_frac": overhead_ms / adaptive_ms if adaptive_ms else 0.0, + "adaptive_vs_sglang": adaptive_ms / sglang_ms if sglang_ms else float("nan"), + } + + +def _fmt(v, w=9): + if isinstance(v, float) and math.isnan(v): + return f"{'—':>{w}s}" + if isinstance(v, float): + return f"{v:>{w}.4f}" + return f"{str(v):>{w}s}" + + +def print_table(rows: List[dict]) -> None: + hdr = (f"{'bs':>3s} {'pages':>6s} {'K':>5s} {'naive':>9s} {'sglang':>9s} " + f"{'adaptive':>9s} {'phase1':>9s} {'phase2':>9s} {'p1+p2':>9s} " + f"{'overhead':>9s} {'ovh%':>6s} {'a/sglang':>9s}") + sep = "-" * len(hdr) + print(sep) + print(hdr) + print(sep) + for r in rows: + ovh_pct = 100.0 * r["overhead_frac"] + print(f"{r['bs']:>3d} {r['pages']:>6d} {r['K']:>5d} " + f"{_fmt(r['naive_ms'])} {_fmt(r['sglang_ms'])} " + f"{_fmt(r['adaptive_ms'])} {_fmt(r['phase1_ms'])} {_fmt(r['phase2_ms'])} " + f"{_fmt(r['p1_plus_p2_ms'])} {_fmt(r['overhead_ms'])} " + f"{ovh_pct:>5.1f}% {r['adaptive_vs_sglang']:>8.3f}×") + print(sep) + + +def main(): + p = argparse.ArgumentParser() + p.add_argument("--gpu", type=int, default=4) + p.add_argument("--warmup", type=int, default=20) + p.add_argument("--repeat", type=int, default=200) + p.add_argument("--output-json", type=str, default=None) + args = p.parse_args() + + torch.cuda.set_device(args.gpu) + + # Sweep: small/medium/large bs × pages × K matrix exercising both + # the light path (K=30) and heavy path (K=2048). 
+ configs = [ + # bs, pages, K + (1, 4096, 30), + (1, 16384, 30), + (1, 32768, 30), + (4, 4096, 30), + (4, 16384, 30), + (4, 32768, 30), + (16, 4096, 30), + (16, 32768, 30), + # heavy + (1, 4096, 2048), + (1, 16384, 2048), + (1, 32768, 2048), + (4, 4096, 2048), + (4, 16384, 2048), + (4, 32768, 2048), + (16, 4096, 2048), + (16, 32768, 2048), + ] + + rows = [] + for (bs, pages, K) in configs: + try: + row = run_config(bs, pages, K, warmup=args.warmup, repeat=args.repeat) + rows.append(row) + print(f"[done] bs={bs} pages={pages} K={K}") + except RuntimeError as e: + print(f"[skip] bs={bs} pages={pages} K={K}: {e}") + + print_table(rows) + + if args.output_json: + with open(args.output_json, "w") as f: + json.dump(rows, f, indent=2) + print(f"Saved: {args.output_json}") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/profile_parallel_vs_fused.py b/benchmarks/profile_parallel_vs_fused.py new file mode 100644 index 00000000..ecfd8723 --- /dev/null +++ b/benchmarks/profile_parallel_vs_fused.py @@ -0,0 +1,99 @@ +""" +Driver for Nsight Compute profiling of the parallel vs fused TopK +kernels. Designed to be launched under `ncu` with --launch-skip and +--launch-count to isolate a specific kernel launch from warmup. + +The script does exactly: + args.warmup matching-kernel launches (skipped by ncu --launch-skip) + args.iters matching-kernel launches (captured by ncu --launch-count) + +Pair --launch-skip/--launch-count with --kernel-name so unrelated +launches (torch initializers, cublas, etc.) don't pollute the counts. 
+""" +import argparse +import torch +from vortex_torch_C import ( + topk_output_sglang_fused, + topk_output_sglang_parallel, +) + + +def make_inputs(eff_bs: int, pages: int, topk: int): + reserved = 0 + dense_indptr = torch.arange( + 0, (eff_bs + 1) * pages, pages, dtype=torch.int32, device="cuda" + ) + sparse_indptr = torch.arange( + 0, (eff_bs + 1) * topk, topk, dtype=torch.int32, device="cuda" + ) + dense_indices = torch.arange(eff_bs * pages, dtype=torch.int32, device="cuda") + torch.manual_seed(0) + x = torch.randn(eff_bs * pages, 1, 1, dtype=torch.bfloat16, device="cuda") + out = torch.zeros(eff_bs * topk, dtype=torch.int32, device="cuda") + return x, dense_indptr, sparse_indptr, dense_indices, out, reserved + + +def main(): + p = argparse.ArgumentParser() + p.add_argument( + "--config", + choices=["A", "B"], + required=True, + help="A: topk=2048 pages=32K ; B: topk=30 pages=2K", + ) + p.add_argument("--eff-bs", type=int, default=1) + p.add_argument( + "--mode", type=int, choices=[15, 16], required=True, + help="15=MAPPING_SHIFT_POW2, 16=MAPPING_SHIFT_POW3", + ) + p.add_argument( + "--power", type=float, default=0.5, + help="Pivot (p) for the shift_pow transforms. 
0.5 matches the " + "autotune default for Qwen3-1.7B softmax scores.", + ) + p.add_argument("--num-splits", type=int, default=4) + p.add_argument("--kernel", choices=["fused", "parallel"], required=True) + p.add_argument("--warmup", type=int, default=20) + p.add_argument("--iters", type=int, default=1) + args = p.parse_args() + + pages, topk = (32768, 2048) if args.config == "A" else (2048, 30) + x, dense_indptr, sparse_indptr, dense_indices, out, reserved = make_inputs( + args.eff_bs, pages, topk + ) + + if args.kernel == "fused": + def call(): + topk_output_sglang_fused( + x, dense_indptr, sparse_indptr, dense_indices, out, + args.eff_bs, topk, reserved, reserved, pages, + args.mode, args.power, None, None, + ) + else: + def call(): + topk_output_sglang_parallel( + x, dense_indptr, sparse_indptr, dense_indices, out, + args.eff_bs, topk, reserved, reserved, pages, + args.num_splits, args.mode, args.power, None, None, + ) + + # Warmup: specialised kernel is JIT-instantiated and cudaFuncSetAttribute + # is cached; these launches dominate the first-call overhead and we want + # ncu to skip past them. + for _ in range(args.warmup): + call() + torch.cuda.synchronize() + + # Profiled region. Wrap in NVTX so the same script is also useful under + # Nsight Systems (nsys) if you prefer a timeline view. + torch.cuda.nvtx.range_push( + f"profile-{args.kernel}-mode{args.mode}-cfg{args.config}-eff{args.eff_bs}" + ) + for _ in range(args.iters): + call() + torch.cuda.synchronize() + torch.cuda.nvtx.range_pop() + + +if __name__ == "__main__": + main() diff --git a/benchmarks/profile_topk_distribution.py b/benchmarks/profile_topk_distribution.py new file mode 100644 index 00000000..bea911b0 --- /dev/null +++ b/benchmarks/profile_topk_distribution.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python3 +""" +Profile TopK bin distribution and generate mapping tables. 
+ +This script takes pre-collected Stage 1 (8-bit coarse histogram) distributions from +the topk_sglang kernel and generates LUT/quantile mapping tables that +can be used to equalize the bin distribution for improved sorting efficiency. + +Usage: + python benchmarks/profile_topk_distribution.py \ + --histograms-input histograms.npy \ + --scores-input scores.npy \ + --output mapping_tables.npz + +Output (.npz): + raw_histograms: [N, 256] - histograms as loaded (with --histograms-input) + lut_tables: [N, 256] uint8 - CDF-equalized LUT per histogram row + aggregate_lut: [256] uint8 - LUT of the mean histogram + quantile_table: [256] float32 - quantile breakpoints (with --scores-input) +""" + +import argparse +import numpy as np +import torch + + +def compute_lut_from_histogram(histogram: np.ndarray) -> np.ndarray: + """Compute CDF-equalized LUT from a 256-bin histogram. + + Args: + histogram: [256] int array of bin counts + + Returns: + lut: [256] uint8 array where lut[i] = floor(CDF(i) * 255) + """ + cdf = np.cumsum(histogram).astype(np.float64) + total = cdf[-1] + if total == 0: + return np.arange(256, dtype=np.uint8) + cdf_normalized = cdf / total + lut = np.floor(cdf_normalized * 255).astype(np.uint8) + return lut + + +def compute_quantiles_from_scores(scores: np.ndarray, num_quantiles: int = 256) -> np.ndarray: + """Compute quantile breakpoints from raw float scores. + + Args: + scores: 1D array of float scores + num_quantiles: number of quantile bins (default 256) + + Returns: + quantiles: [num_quantiles] float32 array of sorted breakpoints + """ + if len(scores) == 0: + return np.zeros(num_quantiles, dtype=np.float32) + percentiles = np.linspace(0, 100, num_quantiles) + quantiles = np.percentile(scores, percentiles).astype(np.float32) + return quantiles + + +def generate_tables_from_histograms(histograms: np.ndarray) -> dict: + """Generate per-sample and aggregate LUT tables from collected histograms.
+ + Args: + histograms: [N, 256] int32 array of bin histograms + + Returns: + dict with 'lut_tables' and 'aggregate_lut' + """ + N = histograms.shape[0] + lut_tables = np.zeros((N, 256), dtype=np.uint8) + + for i in range(N): + lut_tables[i] = compute_lut_from_histogram(histograms[i]) + + # Aggregate: average histogram across all samples + avg_histogram = histograms.mean(axis=0) + aggregate_lut = compute_lut_from_histogram(avg_histogram) + + return { + 'lut_tables': lut_tables, + 'aggregate_lut': aggregate_lut, + } + + +def main(): + parser = argparse.ArgumentParser( + description="Profile TopK bin distribution and generate mapping tables") + parser.add_argument("--output", type=str, default="mapping_tables.npz", + help="Output .npz file path") + parser.add_argument("--histograms-input", type=str, default=None, + help="Load pre-collected histograms from .npy file instead of running inference") + parser.add_argument("--scores-input", type=str, default=None, + help="Load pre-collected raw scores from .npy for quantile computation") + args = parser.parse_args() + + results = {} + + if args.histograms_input: + print(f"Loading histograms from {args.histograms_input}") + histograms = np.load(args.histograms_input) + if histograms.ndim == 1: + histograms = histograms.reshape(1, -1) + results['raw_histograms'] = histograms + + tables = generate_tables_from_histograms(histograms) + results.update(tables) + + if args.scores_input: + print(f"Loading scores from {args.scores_input}") + scores = np.load(args.scores_input) + quantiles = compute_quantiles_from_scores(scores.flatten()) + results['quantile_table'] = quantiles + + if not results: + print("No input provided. 
Use --histograms-input or --scores-input.") + print("\nTo collect histograms, use the topk_profile_histogram() function from vortex_torch_C:") + print(" from vortex_torch_C import topk_profile_histogram") + print(" histograms = torch.zeros(eff_batch_size, 256, dtype=torch.int32, device='cuda')") + print(" topk_profile_histogram(scores, dense_kv_indptr, histograms, eff_batch_size, bos, eos)") + print(" np.save('histograms.npy', histograms.cpu().numpy())") + return + + np.savez(args.output, **results) + print(f"Saved mapping tables to {args.output}") + for key, val in results.items(): + print(f" {key}: shape={val.shape}, dtype={val.dtype}") + + +if __name__ == "__main__": + main() diff --git a/csrc/archived/README.md b/csrc/archived/README.md new file mode 100644 index 00000000..6e08a1dc --- /dev/null +++ b/csrc/archived/README.md @@ -0,0 +1,19 @@ +# Archived TopK kernels + +These files are **not compiled** (not listed in `setup.py`) and are kept only +for historical reference. + +- `topk_slgang_ori.cu` — the original SGLang TopK reference kernel (typo in + the filename is intentional, matches the upstream commit it was adapted + from). Superseded by the fused `fast_topk_vortex` path in + `../topk_sglang.cu`. +- `topk_sglang_ori_fastpath.cu` — the `fast_topk_ori` / + `TopKOutput_Ori_Kernel` / `launch_ori_kernel` code extracted out of + `../topk_sglang.cu`. It was the "zero mapping overhead" fast path with + flexible `radix_bits` (4–10). We no longer test it — mode 0 now goes + through the standard fused kernel with `MAPPING_NONE`, which pays no + mapping overhead because `mapped_convert_to_uint8` degenerates to + `convert_to_uint8` in that branch. + +If you need to resurrect any of this, add the `.cu` to `setup.py` and +re-export its entry points from `../register.cc` / `../register.h`. 
diff --git a/csrc/archived/fast_topk_vortex_prepass.cu b/csrc/archived/fast_topk_vortex_prepass.cu new file mode 100644 index 00000000..5b743f19 --- /dev/null +++ b/csrc/archived/fast_topk_vortex_prepass.cu @@ -0,0 +1,525 @@ +// Archived: not compiled. See csrc/archived/README.md +// +// fast_topk_vortex — the heavy fused remap+topk kernel with auto-range, +// pivot, tail-window, topk-window pre-passes and LUT/quantile support. +// Extracted from csrc/topk_sglang.cu as part of the remap-benchmark refactor. +// Replaced by a lean fast_topk_clean_fused that applies a simple element-wise +// transform (from topk_mapping.cuh apply_transform) in Stage-1 bucketing — +// no pre-pass, no LUT, no auto-range. +// +// References types/constants from its former translation unit (TopKMappingParams, +// needs_*, mapped_convert_to_uint8, kSmem, kThreadsPerBlock, COUNTER_*). This +// file will not compile standalone; kept for history only. + +// ====================================================================== +// Templated version of fast_topk_cuda_tl with mapping support: +// - ScoreT: float or __nv_bfloat16 +// - StopAfterStage1: return after Stage 1 route/filter (for profiling) +// - WriteCounters: write diagnostic counters to global memory + +// - mapping: configurable value-remapping for Stage 1 bin assignment +template +__device__ void fast_topk_vortex( + const ScoreT* __restrict__ input, + int* __restrict__ index, + int row_start, + int length, + int target_k, + const TopKMappingParams& mapping, + int* counters = nullptr) +{ + int topk = target_k; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int vh_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int vh_counter; + alignas(128) __shared__ int vh_threshold_bin_id; + alignas(128) __shared__ int vh_num_input[2]; + + // Shared memory for mapping LUT / quantiles (loaded once per block) + __shared__ uint8_t 
s_mapping_lut[256]; + __shared__ float s_mapping_quantiles[256]; + + // Auto-range for transform modes (3/4/6/7) + __shared__ float s_range_min, s_range_inv_range; + + auto& vh_histogram = vh_histogram_buf[0]; + extern __shared__ int vh_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // Load mapping tables into shared memory if needed + if (mapping.mode == MAPPING_LUT_CDF && mapping.lut != nullptr) { + if (tx < 256) s_mapping_lut[tx] = mapping.lut[tx]; + __syncthreads(); + } + if (mapping.mode == MAPPING_QUANTILE && mapping.quantiles != nullptr) { + if (tx < 256) s_mapping_quantiles[tx] = mapping.quantiles[tx]; + __syncthreads(); + } + + // Pre-pass: compute per-block min/max of transformed values for linear bucketing. + // sample_stride > 1 reduces pre-pass cost by scanning every Nth element; + // the approximated range may miss extreme outliers but Stage 2 uses raw + // float bits for exact ordering, so correctness is preserved. + if (needs_auto_range(mapping.mode) && !mapping.noscale) { + const int stride = (mapping.sample_stride > 1) ? 
mapping.sample_stride : 1; + float local_min = __FLT_MAX__, local_max = -__FLT_MAX__; + for (int idx = tx * stride; idx < length; idx += BLOCK_SIZE * stride) { + float val = apply_transform(vortex_to_float(input[idx + row_start]), mapping); + local_min = fminf(local_min, val); + local_max = fmaxf(local_max, val); + } + // Warp-level reduction + for (int offset = 16; offset > 0; offset >>= 1) { + local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + } + // Cross-warp reduction via shared memory + __shared__ float s_warp_mins[32], s_warp_maxs[32]; + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) { s_warp_mins[warp_id] = local_min; s_warp_maxs[warp_id] = local_max; } + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_min = s_warp_mins[tx]; local_max = s_warp_maxs[tx]; + for (int offset = 16; offset > 0; offset >>= 1) { + local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + } + if (tx == 0) { + s_range_min = local_min; + float range = local_max - local_min; + s_range_inv_range = (range > 0.0f) ? 255.0f / range : 0.0f; + } + } + __syncthreads(); + } else if (needs_pivot(mapping.mode)) { + // Pivot pre-pass: compute mean of all elements, store in s_range_min. + // MAPPING_SUBTRACT uses convert_to_uint8(x - range_min), so centering + // around the mean helps distribute values more evenly across bins. 
+ float local_sum = 0.0f; + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + local_sum += vortex_to_float(input[idx + row_start]); + } + // Warp-level reduction + for (int offset = 16; offset > 0; offset >>= 1) { + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + } + __shared__ float s_warp_sums[32]; + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) s_warp_sums[warp_id] = local_sum; + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_sum = s_warp_sums[tx]; + for (int offset = 16; offset > 0; offset >>= 1) { + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + } + if (tx == 0) { + s_range_min = local_sum / float(length); // mean as pivot + s_range_inv_range = 0.0f; + } + } + __syncthreads(); + } else if (needs_tail_window(mapping.mode)) { + // Adaptive tail-window pre-pass: estimate tau_low = Q(1 - rho*k/n) + // and local_max via a sampled quantile estimator. All 256 coarse bins + // are then allocated to [tau_low, local_max]; scores below tau_low + // collapse into bin 0 via linear_map_to_uint8 clamping. 
+ constexpr int MAX_SAMPLES = 1024; + __shared__ float s_samples[MAX_SAMPLES]; + __shared__ int s_sample_count; + + if (tx == 0) s_sample_count = 0; + __syncthreads(); + + // Compute sampling stride so we collect ~MAX_SAMPLES from the segment + const int desired_stride = (length + MAX_SAMPLES - 1) / MAX_SAMPLES; + const int sample_stride = max(desired_stride, 1); + + // Each thread samples elements and finds local_max simultaneously + float local_max = -__FLT_MAX__; + for (int idx = tx * sample_stride; idx < length; idx += BLOCK_SIZE * sample_stride) { + float val = vortex_to_float(input[idx + row_start]); + local_max = fmaxf(local_max, val); + int slot = ::atomicAdd(&s_sample_count, 1); + if (slot < MAX_SAMPLES) { + s_samples[slot] = val; + } + } + + // Reduce local_max across block + for (int offset = 16; offset > 0; offset >>= 1) + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + __shared__ float s_warp_maxs_tw[32]; + { + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) s_warp_maxs_tw[warp_id] = local_max; + } + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_max = s_warp_maxs_tw[tx]; + for (int offset = 16; offset > 0; offset >>= 1) + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + if (tx == 0) s_warp_maxs_tw[0] = local_max; + } + __syncthreads(); + local_max = s_warp_maxs_tw[0]; + + int nsamp = min(s_sample_count, MAX_SAMPLES); + + // Simple odd-even transposition sort on the sample buffer. + // nsamp <= 1024, and we have 1024 threads, so each thread + // handles one element. O(nsamp) parallel rounds suffice. + __syncthreads(); + if (nsamp >= 2) { + for (int pass = 0; pass < nsamp; ++pass) { + // Even phase: compare (0,1), (2,3), ... + if (tx * 2 + 1 < nsamp) { + int i = tx * 2; + if (s_samples[i] > s_samples[i + 1]) { + float tmp = s_samples[i]; + s_samples[i] = s_samples[i + 1]; + s_samples[i + 1] = tmp; + } + } + __syncthreads(); + // Odd phase: compare (1,2), (3,4), ... 
+ if (tx * 2 + 2 < nsamp) { + int i = tx * 2 + 1; + if (s_samples[i] > s_samples[i + 1]) { + float tmp = s_samples[i]; + s_samples[i] = s_samples[i + 1]; + s_samples[i + 1] = tmp; + } + } + __syncthreads(); + } + } + + // Estimate tau_low = Q(1 - rho * k / n) + if (tx == 0) { + float rho = mapping.power_exp; // reused as tail expansion factor + if (rho <= 0.0f) rho = 4.0f; + int k = (mapping.target_k > 0) ? mapping.target_k : target_k; + float frac = 1.0f - rho * float(k) / float(length); + frac = fmaxf(frac, 0.0f); // clamp: never go below rank 0 + + float tau_low; + if (nsamp < 4 || frac <= 0.0f) { + // Too few samples or the tail covers everything: full range + tau_low = -__FLT_MAX__; + } else { + float fidx = frac * float(nsamp - 1); + int lo = __float2int_rd(fidx); + lo = min(max(lo, 0), nsamp - 2); + float t = fidx - float(lo); + tau_low = s_samples[lo] * (1.0f - t) + s_samples[lo + 1] * t; + } + + // Fallback: if tau_low >= local_max, use full-range linear mapping + if (tau_low >= local_max) { + // Find the actual minimum from sorted samples + tau_low = (nsamp > 0) ? s_samples[0] : local_max; + } + + float range = local_max - tau_low; + s_range_min = tau_low; + s_range_inv_range = (range > 1e-10f) ? 255.0f / range : 0.0f; + } + __syncthreads(); + } else if (needs_topk_window(mapping.mode)) { + // Topk-window pre-pass with streaming variance heuristic. 
+ // tau_low = max - rho * sigma * sqrt(2 * log(n/k)) + float local_max = -__FLT_MAX__; + float local_sum = 0.0f, local_sum_sq = 0.0f; + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + float val = vortex_to_float(input[idx + row_start]); + local_max = fmaxf(local_max, val); + local_sum += val; + local_sum_sq += val * val; + } + for (int offset = 16; offset > 0; offset >>= 1) { + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + local_sum_sq += __shfl_xor_sync(0xFFFFFFFF, local_sum_sq, offset); + } + __shared__ float s_warp_maxs_tw2[32], s_warp_sums_tw2[32], s_warp_sq_tw2[32]; + { + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) { + s_warp_maxs_tw2[warp_id] = local_max; + s_warp_sums_tw2[warp_id] = local_sum; + s_warp_sq_tw2[warp_id] = local_sum_sq; + } + } + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_max = s_warp_maxs_tw2[tx]; + local_sum = s_warp_sums_tw2[tx]; + local_sum_sq = s_warp_sq_tw2[tx]; + for (int offset = 16; offset > 0; offset >>= 1) { + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + local_sum_sq += __shfl_xor_sync(0xFFFFFFFF, local_sum_sq, offset); + } + if (tx == 0) { + float rho = mapping.power_exp; + if (rho <= 0.0f) rho = 4.0f; + int k = (mapping.target_k > 0) ? mapping.target_k : target_k; + float n = float(length); + float mean = local_sum / n; + float var = local_sum_sq / n - mean * mean; + float sigma = (var > 0.0f) ? sqrtf(var) : 0.0f; + float ratio = n / fmaxf(float(k), 1.0f); + float z = sqrtf(2.0f * __logf(fmaxf(ratio, 1.0f))); + float tau_low = local_max - rho * sigma * z; + if (tau_low >= local_max) tau_low = local_max - 1.0f; + float range = local_max - tau_low; + s_range_min = tau_low; + s_range_inv_range = (range > 1e-10f) ? 
255.0f / range : 0.0f; + } + } + __syncthreads(); + } else { + if (tx == 0) { s_range_min = 0.0f; s_range_inv_range = 0.0f; } + __syncthreads(); + } + + // Stage 1: 8-bit coarse histogram (with optional mapping) + // Bin cache: store computed bins in vh_input_idx[1] (reinterpreted as uint8_t*) + // to avoid recomputing mapped_convert_to_uint8 in the route/filter pass. + // vh_input_idx[1] is unused until Stage 2 double-buffering starts after route. + constexpr int BIN_CACHE_CAPACITY = SMEM_INPUT_SIZE * static_cast(sizeof(int)); // uint8 entries + uint8_t* bin_cache = reinterpret_cast(vh_input_idx[1]); + const bool use_bin_cache = (length <= BIN_CACHE_CAPACITY); + + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = mapped_convert_to_uint8( + vortex_to_float(input[idx + row_start]), + mapping, s_mapping_lut, s_mapping_quantiles, + s_range_min, s_range_inv_range); + ::atomicAdd(&vh_histogram[bin], 1); + if (use_bin_cache) { + bin_cache[idx] = bin; + } + } + __syncthreads(); + + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = vh_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += vh_histogram_buf[k][tx + j]; + } + vh_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { + vh_threshold_bin_id = tx; + vh_num_input[0] = 0; + vh_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = vh_threshold_bin_id; + topk -= vh_histogram[threshold_bin + 1]; + + if (WriteCounters && tx == 0 && counters) { + counters[COUNTER_THRESHOLD_BIN] = threshold_bin; + counters[COUNTER_REMAINING_K] = topk; + } + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + int bin; + if (use_bin_cache) { + bin = 
static_cast(bin_cache[idx]); + } else { + bin = static_cast( + mapped_convert_to_uint8( + vortex_to_float(input[idx + row_start]), + mapping, s_mapping_lut, s_mapping_quantiles, + s_range_min, s_range_inv_range)); + } + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + if (WriteCounters && tx == 0 && counters) { + counters[COUNTER_NUM_ABOVE] = vh_counter; + counters[COUNTER_NUM_EQUAL] = 0; + counters[COUNTER_REFINE_ROUNDS] = 0; + counters[COUNTER_STAGE2_INPUT] = 0; + } + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = vortex_to_float(input[idx + row_start]); + int bin; + if (use_bin_cache) { + bin = static_cast(bin_cache[idx]); + } else { + bin = static_cast( + mapped_convert_to_uint8(raw_input, mapping, + s_mapping_lut, s_mapping_quantiles, + s_range_min, s_range_inv_range)); + } + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&vh_num_input[0], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + vh_input_idx[0][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> 24) & 0xFF; + ::atomicAdd(&vh_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + if (WriteCounters && tx == 0 && counters) { + counters[COUNTER_NUM_ABOVE] = vh_counter; + counters[COUNTER_NUM_EQUAL] = vh_num_input[0]; + counters[COUNTER_STAGE2_INPUT] = vh_num_input[0]; + } + if (StopAfterStage1) return; + } + + // Stage 2: refine with 8-bit radix passes (unchanged — uses raw float bits) + if constexpr (WriteCounters) { + // Default: all 4 rounds used; overwritten at break if resolved early + if (tx == 0 && counters) counters[COUNTER_REFINE_ROUNDS] = 4; + } +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int 
vh_last_remain; + const auto r_idx = round % 2; + + const auto _raw_num_input = vh_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) + ? _raw_num_input + : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { + vh_threshold_bin_id = tx; + vh_num_input[r_idx ^ 1] = 0; + vh_last_remain = topk - vh_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = vh_threshold_bin_id; + topk -= vh_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = vh_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32( + vortex_to_float(input[idx + row_start])) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + if constexpr (WriteCounters) { + if (tx == 0 && counters) { + counters[COUNTER_REFINE_ROUNDS] = round + 1; + } + } + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = vh_input_idx[r_idx][i]; + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&vh_last_remain, -1); + if (pos > 0) { + index[target_k - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&vh_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + vh_input_idx[r_idx ^ 1][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; + ::atomicAdd(&vh_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } 
+} + +// Wrapper kernel: one CUDA block per batch*head segment +template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKOutput_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ dense_kv_indices, + int* __restrict__ sparse_kv_indices, + const int topk_val, + const int page_reserved_bos, + const int page_reserved_eos, + const TopKMappingParams mapping) +{ + const int bx = blockIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; + + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_vortex(score_blk, s_indices, 0, nblk, topk_val, mapping); + __syncthreads(); + + // Remap position indices -> page indices via dense_kv_indices + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } +} + + diff --git a/csrc/archived/topk_mapping_full.cuh b/csrc/archived/topk_mapping_full.cuh new file mode 100644 index 00000000..f85204ec --- /dev/null +++ b/csrc/archived/topk_mapping_full.cuh @@ -0,0 +1,217 @@ +// Archived: not included by any compiled TU. See csrc/archived/README.md. +// The full mapping header supporting LUT_CDF, QUANTILE, TRUNC8, SUBTRACT, +// ADAPTIVE_TAIL_WINDOW, TOPK_WINDOW and the auto-range/pivot/tail-window +// pre-pass infrastructure. Replaced by the lean element-wise-only header +// at csrc/topk_mapping.cuh for the remap-benchmark refactor. 
+#pragma once +#include +#include +#include + +// ============================================================ +// TopK bucket-sort Stage-1 remapping strategies +// +// These transforms remap float scores before Stage 1's 8-bit +// histogram binning. The primary goal is to maximize coarse-bin +// resolution in the score region that determines the top-k +// cutoff, thereby: +// - shrinking the Stage-1 threshold bin (fewer collisions) +// - reducing COUNTER_NUM_EQUAL / COUNTER_STAGE2_INPUT +// - reducing the number of Stage-2 refine rounds +// +// Stage 2 refinement still uses convert_to_uint32() on raw +// floats, so final ordering correctness is always preserved. +// +// Modes 3/4/6/7/9/10/13 apply a nonlinear transform then linearly +// map the result to [0,255]. Mode 12 (ADAPTIVE_TAIL_WINDOW) +// directly focuses all 256 bins on the competitive upper tail +// estimated from the top-k ratio, collapsing irrelevant +// low-score mass into bin 0. +// ============================================================ + +enum TopKMappingMode { + MAPPING_NONE = 0, // Original convert_to_uint8 behavior + MAPPING_LUT_CDF = 1, // LUT-based CDF equalization + MAPPING_QUANTILE = 2, // Piecewise-linear quantile mapping + MAPPING_POWER = 3, // Monotonic power transform + MAPPING_LOG = 4, // Log transform + // Mode 5 reserved (previously INDEX_CACHE, removed) + MAPPING_ASINH = 6, // asinh(beta * x), beta via power_exp + MAPPING_LOG1P = 7, // sign(x) * log1p(alpha * |x|), alpha via power_exp + MAPPING_TRUNC8 = 8, // BF16 upper-8-bit bucketing + MAPPING_ERF = 9, // erf(alpha * x) + MAPPING_TANH = 10, // tanh(alpha * x) + MAPPING_SUBTRACT = 11, // subtract pivot, then fp16 bucketing + MAPPING_ADAPTIVE_TAIL_WINDOW = 12, // focus bins on upper tail via sampled quantile + MAPPING_EXP_STRETCH = 13, // exp(alpha * x), concentrates bin resolution on upper tail + MAPPING_TOPK_WINDOW = 14, // k-aware linear windowing: focus bins on [tau_low, max] +}; + +struct TopKMappingParams { + int mode; //
TopKMappingMode + float power_exp; // For MAPPING_POWER (default 0.5) + // For MAPPING_ADAPTIVE_TAIL_WINDOW: tail expansion + // factor rho (default 4.0). tau_low = Q(1 - rho*k/n). + const uint8_t* __restrict__ lut; // [256] byte LUT, or nullptr + const float* __restrict__ quantiles; // [256] float quantile breakpoints, or nullptr + bool noscale; // Skip auto-range linear scaling, use fp16 bucketing on f(x) + int sample_stride; // Auto-range pre-pass sampling stride (<=1 = every element, 8 = every 8th element) + int target_k; // Top-k override (<=0 = use the kernel's k); used by ADAPTIVE_TAIL_WINDOW and TOPK_WINDOW +}; + +// NOTE: convert_to_uint8() must be defined before including this header. +// It is defined in topk_sglang.cu within the anonymous namespace. + +// ---- Individual transform functions (return float, no bucketing) ---- + +__device__ __forceinline__ float transform_power(float x, float p) { + return copysignf(__powf(fabsf(x), p), x); +} + +__device__ __forceinline__ float transform_log(float x) { + return copysignf(__logf(fabsf(x) + 1.0f), x); +} + +__device__ __forceinline__ float transform_asinh(float x, float beta) { + return asinhf(beta * x); +} + +__device__ __forceinline__ float transform_log1p(float x, float alpha) { + return copysignf(log1pf(alpha * fabsf(x)), x); +} + +__device__ __forceinline__ float transform_erf(float x, float alpha) { + return erff(alpha * x); +} + +__device__ __forceinline__ float transform_tanh(float x, float alpha) { + return tanhf(alpha * x); +} + +__device__ __forceinline__ float transform_exp_stretch(float x, float alpha) { + float z = alpha * x; + z = fminf(z, 80.0f); // prevent float32 overflow (exp(80) ~ 5.5e34) + return expf(z); +} + +// ---- Transform dispatcher (returns float, no bucketing) ---- + +__device__ __forceinline__ float apply_transform(float x, const TopKMappingParams& params) { + switch (params.mode) { + case MAPPING_POWER: return transform_power(x, params.power_exp); + case MAPPING_LOG: return transform_log(x); + case MAPPING_ASINH: return transform_asinh(x,
params.power_exp); + case MAPPING_LOG1P: return transform_log1p(x, params.power_exp); + case MAPPING_ERF: return transform_erf(x, params.power_exp); + case MAPPING_TANH: return transform_tanh(x, params.power_exp); + case MAPPING_EXP_STRETCH: return transform_exp_stretch(x, params.power_exp); + default: return x; + } +} + +// ---- Linear bucketing for transform modes ---- + +__device__ __forceinline__ uint8_t linear_map_to_uint8(float val, float range_min, float inv_range) { + int bin = __float2int_rd((val - range_min) * inv_range); + return static_cast(min(max(bin, 0), 255)); +} + +// ---- BF16-aware bucketing (mode 8) ---- +// BF16 has 8 exponent + 7 mantissa bits. Taking the upper 8 bits of the +// sign-flipped bf16 bit-pattern yields only ~20 distinct bins for typical +// data (the byte is almost entirely exponent). Instead, convert through +// fp16 (5 exp + 10 mantissa) which puts 5 exp + 2 mantissa bits in the +// upper byte, giving ~135+ distinct bins — equivalent to mode 0 but +// explicitly available as a named mode for documentation/benchmarking. 
+ +__device__ __forceinline__ uint8_t convert_to_uint8_bf16(float x) { + return convert_to_uint8(x); // fp16 sign-flip bucketing +} + +// ---- Non-transform mapping functions (unchanged) ---- + +// LUT-based CDF equalization: lut[original_bin] -> equalized_bin +__device__ __forceinline__ uint8_t map_lut_cdf(float x, const uint8_t* __restrict__ s_lut) { + return s_lut[convert_to_uint8(x)]; +} + +// Quantile mapping: binary search over 256 sorted thresholds +__device__ __forceinline__ uint8_t map_quantile(float x, const float* __restrict__ s_quantiles) { + // Binary search: find largest index i such that x >= s_quantiles[i] + // s_quantiles is sorted ascending, length 256 + int lo = 0, hi = 255; +#pragma unroll 8 + for (int iter = 0; iter < 8; ++iter) { + int mid = (lo + hi + 1) >> 1; + if (x >= s_quantiles[mid]) { + lo = mid; + } else { + hi = mid - 1; + } + } + return static_cast(lo); +} + +// ---- Unified dispatcher ---- +// For modes 3/4/6/7, range_min and inv_range come from a per-block pre-pass. 
+ +__device__ __forceinline__ uint8_t mapped_convert_to_uint8( + float x, + const TopKMappingParams& params, + const uint8_t* __restrict__ s_lut, + const float* __restrict__ s_quantiles, + float range_min, + float inv_range) +{ + switch (params.mode) { + case MAPPING_LUT_CDF: + if (params.lut != nullptr) return map_lut_cdf(x, s_lut); + return convert_to_uint8(x); // fallback to mode 0 when LUT not calibrated + case MAPPING_QUANTILE: + if (params.quantiles != nullptr) return map_quantile(x, s_quantiles); + return convert_to_uint8(x); // fallback to mode 0 when quantiles not calibrated + case MAPPING_POWER: + case MAPPING_LOG: + case MAPPING_ASINH: + case MAPPING_LOG1P: + case MAPPING_ERF: + case MAPPING_TANH: + case MAPPING_EXP_STRETCH: { + float val = apply_transform(x, params); + if (params.noscale) return convert_to_uint8(val); + return linear_map_to_uint8(val, range_min, inv_range); + } + case MAPPING_TRUNC8: + return convert_to_uint8_bf16(x); + case MAPPING_SUBTRACT: + return convert_to_uint8(x - range_min); // range_min repurposed as pivot + case MAPPING_ADAPTIVE_TAIL_WINDOW: + case MAPPING_TOPK_WINDOW: + return linear_map_to_uint8(x, range_min, inv_range); + default: // MAPPING_NONE + return convert_to_uint8(x); + } +} + +// Helper: check if a mapping mode needs the auto-range pre-pass +__device__ __forceinline__ bool needs_auto_range(int mode) { + return (mode == MAPPING_POWER || mode == MAPPING_LOG || + mode == MAPPING_ASINH || mode == MAPPING_LOG1P || + mode == MAPPING_ERF || mode == MAPPING_TANH || + mode == MAPPING_EXP_STRETCH); +} + +// Helper: check if a mapping mode needs the pivot pre-pass +__device__ __forceinline__ bool needs_pivot(int mode) { + return (mode == MAPPING_SUBTRACT); +} + +// Helper: check if mode is the adaptive tail-window pre-pass +__device__ __forceinline__ bool needs_tail_window(int mode) { + return (mode == MAPPING_ADAPTIVE_TAIL_WINDOW); +} + +// Helper: check if mode is the lightweight topk-window pre-pass +__device__ 
__forceinline__ bool needs_topk_window(int mode) { + return (mode == MAPPING_TOPK_WINDOW); +} diff --git a/csrc/archived/topk_sglang_cluster.cu b/csrc/archived/topk_sglang_cluster.cu new file mode 100644 index 00000000..453f7bf8 --- /dev/null +++ b/csrc/archived/topk_sglang_cluster.cu @@ -0,0 +1,684 @@ +/** + * Vortex TopK — Hopper Thread Block Cluster + Distributed Shared Memory + * single-kernel fused top-K merge. + * + * Grid = Batch * N CTAs. + * Cluster dim = N (runtime, set via cudaLaunchAttributeClusterDimension). + * Each cluster = one batch. cluster.block_rank() identifies the chunk. + * + * Stage 1 (every CTA): 8-bit radix + 8-bit refinement over its chunk, + * writing the local top-K (fp32 remapped score + int32 index) into THIS + * CTA's shared memory — never through global memory. + * + * Stage 2 (CTA 0 only): after cluster.sync(), read every CTA's + * s_export_scores / s_export_indices directly via + * cg::cluster_group::map_shared_rank() — the reads compile to + * `ld.shared::cluster`. Build a merged 8-bit histogram, find the + * coarse threshold, run the standard 8-bit refinement, and emit K + * indices to global memory using warp-popc compaction. + * + * A second cluster.sync() at the end guarantees no CTA exits while + * CTA 0 is still issuing DSMEM reads into its exported SMEM. + * + * sm_90+ only (Hopper, Blackwell). The kernel body is guarded by + * __CUDA_ARCH__ >= 900 so the file compiles cleanly against the + * sm_86/sm_89 gencode targets in setup.py — the host entrypoint + * TORCH_CHECKs the runtime device compute capability. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include "register.h" + +namespace { + +constexpr int kThreadsPerBlock = 1024; +constexpr int kWarpSize = 32; +constexpr int RADIX = 256; +constexpr size_t kMaxDynSmem = 96 * 1024; +constexpr int VORTEX_MAX_TOPK = 2048; +constexpr int kMaxClusterDim = 8; // portable TBC cap + +__device__ __forceinline__ uint32_t convert_to_uint32(float x) { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); +} + +// Required by topk_mapping.cuh's forward decl (even though the cluster +// kernel never calls compute_stage1_bin directly). +__device__ __forceinline__ uint8_t convert_to_uint8(float x) { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) + : static_cast(bits | 0x8000); + return static_cast(key >> 8); +} + +template +__device__ __forceinline__ float vortex_to_float(T x); +template <> +__device__ __forceinline__ float vortex_to_float(float x) { return x; } +template <> +__device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) { + return __bfloat162float(x); +} + +#include "topk_mapping.cuh" + +// 8-step suffix cumsum: after the call s_hist[0][i] = count of items +// with bin >= i (monotone non-increasing). Same routine as +// topk_sglang_parallel.cu. +__device__ __forceinline__ void run_cumsum_256(int s_hist[2][RADIX + 128]) { + const int tx = threadIdx.x; +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const int j = 1 << i; + const int k = i & 1; + int value = s_hist[k][tx]; + if (tx < RADIX - j) value += s_hist[k][tx + j]; + s_hist[k ^ 1][tx] = value; + } + __syncthreads(); + } +} + +// Warp-level ballot+popc compaction. Exactly one atomicAdd per warp, +// issued by the first active lane. 
Safe from a divergent region. +__device__ __forceinline__ int warp_compact_slot(bool selected, int* s_counter) { + const uint32_t mask = __activemask(); + const uint32_t ballot = __ballot_sync(mask, selected); + const int lane = threadIdx.x & (kWarpSize - 1); + const int warp_count = __popc(ballot); + const int rank_in_warp = __popc(ballot & ((1u << lane) - 1u)); + + const int first_lane = __ffs(mask) - 1; + int base = 0; + if (lane == first_lane) { + base = (warp_count > 0) ? ::atomicAdd(s_counter, warp_count) : 0; + } + base = __shfl_sync(mask, base, first_lane); + return selected ? (base + rank_in_warp) : -1; +} + +namespace cg = cooperative_groups; + +template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopK_Cluster_Kernel( + const ScoreT* __restrict__ score, // [Batch, N, chunk_size] + int32_t* __restrict__ global_idx, // [Batch, K] + int N, + int chunk_size, + int K, + float mapping_power) +{ +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + cg::cluster_group cluster = cg::this_cluster(); + const int rank = static_cast(cluster.block_rank()); + // Grid layout: dim3(Batch * N). blockIdx.x = b * N + rank. + const int b = (blockIdx.x - rank) / N; + const int tx = threadIdx.x; + + const ScoreT* chunk_in = score + (static_cast(b) * N + rank) * chunk_size; + const int32_t idx_base = rank * chunk_size; + + // Static SMEM ------------------------------------------------------------ + alignas(128) __shared__ int s_hist_buf[2][RADIX + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin; + alignas(128) __shared__ int s_sub_threshold_bin; + alignas(128) __shared__ int s_last_remain; + alignas(128) __shared__ int s_export_count; + // Rank 0 only: contiguous staging buffer for the final K indices before + // the coalesced int4 write to global memory. Sized to VORTEX_MAX_TOPK so + // we don't need to carve it out of the dynamic smem layout (which must + // keep the exports at offset 0 for DSMEM visibility). 
+ alignas(16) __shared__ int32_t s_final_indices[VORTEX_MAX_TOPK]; + auto& s_hist = s_hist_buf[0]; + + // Dynamic SMEM ------------------------------------------------------------ + // [0, K*4) s_export_scores (fp32) <- DSMEM-visible + // [K*4, K*8) s_export_indices (int32) <- DSMEM-visible + // [K*8, K*8 + overlay) Stage-1 cache on ALL ranks; reused by + // rank 0 in Stage 2 as the N*K merge pool. + // Stage 1 : s_remapped[chunk] (fp32) + s_bins[chunk] (uint8 padded) + // Stage 2 : s_merge_scores[N*K] (fp32) + s_merge_indices[N*K] (int32) + // + // Exports sit at the start of the SMEM pool so the base offset is the + // same on every cluster CTA — cg::map_shared_rank uses that offset + // modulo the cluster stride to read a remote CTA. + extern __shared__ char smem_raw[]; + float* s_export_scores = reinterpret_cast (smem_raw); + int32_t* s_export_indices = reinterpret_cast(smem_raw + K * sizeof(float)); + float* s_remapped = reinterpret_cast (smem_raw + K * (sizeof(float) + sizeof(int32_t))); + uint8_t* s_bins = reinterpret_cast(s_remapped + chunk_size); + + // Initialize counters + pad export indices to -1 (so the degenerate + // chunk_size < K case leaves recognisable empty slots). + for (int i = tx; i < K; i += blockDim.x) { + s_export_indices[i] = -1; + s_export_scores [i] = -CUDART_INF_F; + } + if (tx == 0) { + s_counter = 0; + s_threshold_bin = -1; + s_sub_threshold_bin = -1; + s_last_remain = 0; + s_export_count = 0; + } + if (tx < RADIX + 1) s_hist[tx] = 0; + __syncthreads(); + + // ========================================================================= + // Stage 1 — local top-K for this chunk. + // ========================================================================= + if (chunk_size <= K) { + // Degenerate: emit every valid element. 
+ for (int idx = tx; idx < chunk_size; idx += blockDim.x) { + const float raw = vortex_to_float(chunk_in[idx]); + const float remapped = apply_transform_tmpl(raw, mapping_power); + const int slot = warp_compact_slot(true, &s_counter); + if (slot >= 0 && slot < K) { + s_export_scores [slot] = remapped; + s_export_indices[slot] = idx + idx_base; + } + } + __syncthreads(); + if (tx == 0) s_export_count = min(s_counter, K); + } else { + // Histogram pass 1 ------------------------------------------------------ + for (int idx = tx; idx < chunk_size; idx += blockDim.x) { + const float raw = vortex_to_float(chunk_in[idx]); + const float remapped = apply_transform_tmpl(raw, mapping_power); + const uint32_t b32 = convert_to_uint32(remapped); + const int bin = (b32 >> 24) & 0xFF; + s_remapped[idx] = remapped; + s_bins [idx] = static_cast(bin); + ::atomicAdd(&s_hist[bin], 1); + } + __syncthreads(); + + run_cumsum_256(s_hist_buf); + + if (tx < RADIX && s_hist[tx] > K && s_hist[tx + 1] <= K) { + s_threshold_bin = tx; + s_last_remain = K - s_hist[tx + 1]; + } + __syncthreads(); + const int threshold_bin = s_threshold_bin; + + // Emit bin > threshold; build sub-bin histogram on the tie bin ----- + if (tx < RADIX + 1) s_hist[tx] = 0; + __syncthreads(); + + const int num_iters = (chunk_size + blockDim.x - 1) / blockDim.x; + for (int it = 0; it < num_iters; ++it) { + const int idx = it * blockDim.x + tx; + const bool in_range = (idx < chunk_size); + int bin = -1; + if (in_range) bin = static_cast(s_bins[idx]); + const bool take_above = in_range && (bin > threshold_bin); + + const int slot = warp_compact_slot(take_above, &s_counter); + if (take_above) { + s_export_scores [slot] = s_remapped[idx]; + s_export_indices[slot] = idx + idx_base; + } else if (in_range && bin == threshold_bin) { + const uint32_t b32 = convert_to_uint32(s_remapped[idx]); + const int sub_bin = (b32 >> 16) & 0xFF; + ::atomicAdd(&s_hist[sub_bin], 1); + } + } + __syncthreads(); + + // Refinement cumsum → 
sub-threshold bin -------------------------------- + run_cumsum_256(s_hist_buf); + if (tx < RADIX && s_hist[tx] > s_last_remain + && s_hist[tx + 1] <= s_last_remain) { + s_sub_threshold_bin = tx; + s_last_remain = s_last_remain - s_hist[tx + 1]; + } + if (tx == 0 && s_sub_threshold_bin == -1) { + s_sub_threshold_bin = RADIX; // no tie refinement needed + } + __syncthreads(); + const int sub_threshold_bin = s_sub_threshold_bin; + + // Emit tie-bin items --------------------------------------------------- + for (int it = 0; it < num_iters; ++it) { + const int idx = it * blockDim.x + tx; + const bool in_range = (idx < chunk_size); + int bin = -1; + if (in_range) bin = static_cast(s_bins[idx]); + int sub_bin = -1; + if (in_range && bin == threshold_bin) { + const uint32_t b32 = convert_to_uint32(s_remapped[idx]); + sub_bin = (b32 >> 16) & 0xFF; + } + + const bool take_sub_above = (sub_bin > sub_threshold_bin); + const int slot = warp_compact_slot(take_sub_above, &s_counter); + if (take_sub_above) { + s_export_scores [slot] = s_remapped[idx]; + s_export_indices[slot] = idx + idx_base; + } else if (sub_bin == sub_threshold_bin) { + const int pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + s_export_scores [K - pos] = s_remapped[idx]; + s_export_indices[K - pos] = idx + idx_base; + } + } + } + __syncthreads(); + if (tx == 0) s_export_count = K; + } + + // ========================================================================= + // Stage 2 — CENTRALIZED RANK-0 PULL. + // + // Control flow: + // barrier #1 (all CTAs) : release all Stage-1 exports cluster-wide. + // rank != 0 : wait at barrier #2 so their exported SMEM + // stays alive while rank 0 pulls, then exit. + // rank 0 Step A : vectorised DSMEM pull of every rank's + // s_export_* into a local N*K merge pool. + // rank 0 Step B+C : single-block 8-bit radix select over the + // merge pool, staging winners into + // s_final_indices via warp-popc + LOCAL + // atomicAdd on &s_counter / &s_last_remain. 
+ // (No DSMEM atomics anywhere.) + // rank 0 Step D : int4-coalesced global store of + // s_final_indices[K] → global_idx[b, :K]. + // barrier #2 (all CTAs) : release idle ranks. + // ========================================================================= + + // cluster.sync() is both a cross-CTA barrier AND a cluster-wide release + // fence on shared memory, so rank 0's upcoming DSMEM reads of remote + // s_export_* observe the Stage-1 writes above. + cluster.sync(); + + if (rank != 0) { + cluster.sync(); // final barrier — keeps SMEM alive during rank 0's pull + return; + } + + // ---- rank 0 only from here on ------------------------------------------ + + // Pre-fill the staging buffer with -1 so that if fewer than K valid + // candidates exist, the unused tail emits as -1 sentinels rather than + // stale static-SMEM data. + for (int i = tx; i < K; i += blockDim.x) s_final_indices[i] = -1; + + // Reset histogram + counters for Stage 2's radix select. + if (tx < RADIX + 1) s_hist[tx] = 0; + if (tx == 0) { + s_counter = 0; + s_threshold_bin = -1; + s_sub_threshold_bin = -1; + s_last_remain = 0; + } + + // Merge pool: overlays the now-dead Stage-1 cache region. Layout: + // [K*8, K*8 + N*K*4) s_merge_scores (fp32) + // [K*8 + N*K*4, K*8 + N*K*8) s_merge_indices (int32) + const int total = N * K; + float* s_merge_scores = reinterpret_cast (smem_raw + K * sizeof(float) + + K * sizeof(int32_t)); + int32_t* s_merge_indices = reinterpret_cast(s_merge_scores + total); + __syncthreads(); + + // ========================================================================= + // Step A — vectorised DSMEM pull. + // + // map_shared_rank(ptr, 0) degenerates to a local load, so we can sweep + // r=0..N-1 uniformly without special-casing the self-copy. 
+ // ========================================================================= + #pragma unroll + for (int r = 0; r < kMaxClusterDim; ++r) { + if (r >= N) break; + const float* rem_scores = cluster.map_shared_rank(s_export_scores, r); + const int32_t* rem_indices = cluster.map_shared_rank(s_export_indices, r); + float* dst_scores = s_merge_scores + r * K; + int32_t* dst_indices = s_merge_indices + r * K; + + if ((K & 3) == 0) { + const float4* src_s4 = reinterpret_cast(rem_scores); + const int4* src_i4 = reinterpret_cast(rem_indices); + float4* dst_s4 = reinterpret_cast (dst_scores); + int4* dst_i4 = reinterpret_cast (dst_indices); + const int K4 = K >> 2; + for (int i = tx; i < K4; i += blockDim.x) { + dst_s4[i] = src_s4[i]; + dst_i4[i] = src_i4[i]; + } + } else { + for (int i = tx; i < K; i += blockDim.x) { + dst_scores [i] = rem_scores [i]; + dst_indices[i] = rem_indices[i]; + } + } + } + __syncthreads(); + + // ========================================================================= + // Step B+C — local 8-bit radix select over the N*K merge pool, with + // warp-popc compaction into s_final_indices. Ported from + // topk_sglang_parallel.cu Phase 2. + // ========================================================================= + + // (1) Coarse 8-bit histogram on bits [31:24] of the sign-flipped score. + const int num_iters_m = (total + blockDim.x - 1) / blockDim.x; + for (int it = 0; it < num_iters_m; ++it) { + const int i = it * blockDim.x + tx; + if (i < total && s_merge_indices[i] >= 0) { + const uint32_t b32 = convert_to_uint32(s_merge_scores[i]); + const int bin = (b32 >> 24) & 0xFF; + ::atomicAdd(&s_hist[bin], 1); + } + } + __syncthreads(); + + run_cumsum_256(s_hist_buf); + + // Fast path: fewer valid candidates than K — emit them all, skip refinement. 
+ const int valid_count = s_hist[0]; + if (valid_count <= K) { + for (int it = 0; it < num_iters_m; ++it) { + const int i = it * blockDim.x + tx; + const bool take = (i < total) && (s_merge_indices[i] >= 0); + const int slot = warp_compact_slot(take, &s_counter); + if (take && slot < K) s_final_indices[slot] = s_merge_indices[i]; + } + } else { + if (tx < RADIX && s_hist[tx] > K && s_hist[tx + 1] <= K) { + s_threshold_bin = tx; + s_last_remain = K - s_hist[tx + 1]; + } + __syncthreads(); + const int threshold_bin_m = s_threshold_bin; + + // (2) Emit above-threshold winners; build sub-bin histogram on tie-bin. + if (tx < RADIX + 1) s_hist[tx] = 0; + __syncthreads(); + + for (int it = 0; it < num_iters_m; ++it) { + const int i = it * blockDim.x + tx; + bool in_valid = false; + int bin = -1; + uint32_t b32 = 0; + if (i < total) { + const int32_t idx = s_merge_indices[i]; + if (idx >= 0) { + in_valid = true; + b32 = convert_to_uint32(s_merge_scores[i]); + bin = (b32 >> 24) & 0xFF; + } + } + const bool take_above = in_valid && (bin > threshold_bin_m); + const int slot = warp_compact_slot(take_above, &s_counter); + if (take_above) { + s_final_indices[slot] = s_merge_indices[i]; + } else if (in_valid && bin == threshold_bin_m) { + const int sub_bin = (b32 >> 16) & 0xFF; + ::atomicAdd(&s_hist[sub_bin], 1); + } + } + __syncthreads(); + + // (3) Refinement cumsum → sub-threshold bin. + run_cumsum_256(s_hist_buf); + if (tx < RADIX && s_hist[tx] > s_last_remain + && s_hist[tx + 1] <= s_last_remain) { + s_sub_threshold_bin = tx; + s_last_remain = s_last_remain - s_hist[tx + 1]; + } + if (tx == 0 && s_sub_threshold_bin == -1) { + s_sub_threshold_bin = RADIX; // no tie-bin refinement needed + } + __syncthreads(); + const int sub_threshold_bin_m = s_sub_threshold_bin; + + // (4) Emit tie-bin items: hard wins via warp-popc, remainder via local + // atomic budget. Both atomics hit rank-0's native SMEM only. 
+ for (int it = 0; it < num_iters_m; ++it) { + const int i = it * blockDim.x + tx; + bool in_threshold = false; + int sub_bin = -1; + if (i < total) { + const int32_t idx = s_merge_indices[i]; + if (idx >= 0) { + const uint32_t b32 = convert_to_uint32(s_merge_scores[i]); + const int bin = (b32 >> 24) & 0xFF; + if (bin == threshold_bin_m) { + in_threshold = true; + sub_bin = (b32 >> 16) & 0xFF; + } + } + } + const bool take_sub_above = in_threshold && (sub_bin > sub_threshold_bin_m); + const int slot = warp_compact_slot(take_sub_above, &s_counter); + if (take_sub_above) { + s_final_indices[slot] = s_merge_indices[i]; + } else if (in_threshold && sub_bin == sub_threshold_bin_m) { + const int pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) s_final_indices[K - pos] = s_merge_indices[i]; + } + } + } + + __syncthreads(); + + // ========================================================================= + // Step D — coalesced int4 store of s_final_indices[K] → global_idx[b, :K]. + // ========================================================================= + int32_t* out_idx = global_idx + static_cast(b) * K; + if ((K & 3) == 0) { + const int4* src = reinterpret_cast(s_final_indices); + int4* dst = reinterpret_cast (out_idx); + const int K4 = K >> 2; + for (int i = tx; i < K4; i += blockDim.x) dst[i] = src[i]; + } else { + for (int i = tx; i < K; i += blockDim.x) out_idx[i] = s_final_indices[i]; + } + + // Final barrier: releases ranks 1..N-1 that were holding their SMEM + // alive while rank 0 was pulling in Step A. + cluster.sync(); +#else + // sm_86/sm_89 fallback: host dispatcher TORCH_CHECKs compute + // capability, so this stub is never actually invoked. The empty + // body still needs to reference the params so nvcc doesn't warn. + (void)score; (void)global_idx; + (void)N; (void)chunk_size; (void)K; (void)mapping_power; +#endif +} + +// One-shot cudaFuncSetAttribute for dynamic smem ceiling. 
+template +void setup_kernel_smem_once() { + [[maybe_unused]] + static const auto result = [] { + return ::cudaFuncSetAttribute( + f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); + }(); + TORCH_CHECK(result == cudaSuccess, + "fast_cluster_topk_merge setup failed: ", + ::cudaGetErrorString(result)); +} + +} // namespace + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + +// ============================================================================ +// Host entry point — fast_cluster_topk_merge. +// +// score [batch_size, num_chunks, chunk_size] bf16 or f32 +// global_topk_indices [batch_size, topk_val] int32 (out) +// +// No workspace tensors — Stage-1 partial top-K lives in shared memory, +// consumed by CTA 0 of the cluster via DSMEM. +// ============================================================================ +void fast_cluster_topk_merge( + const at::Tensor& score, + at::Tensor& global_topk_indices, + const int64_t batch_size, + const int64_t num_chunks, + const int64_t chunk_size, + const int64_t topk_val, + const int64_t mapping_mode, + const double mapping_power) +{ + CHECK_CUDA(score); + CHECK_CUDA(global_topk_indices); + + TORCH_CHECK(topk_val > 0 && topk_val <= VORTEX_MAX_TOPK, + "fast_cluster_topk_merge: topk_val=", topk_val, + " must be in (0, ", VORTEX_MAX_TOPK, "]"); + TORCH_CHECK(num_chunks >= 1 && num_chunks <= kMaxClusterDim, + "fast_cluster_topk_merge: num_chunks=", num_chunks, + " must be in [1, ", kMaxClusterDim, "] (portable TBC cap)"); + TORCH_CHECK(batch_size >= 1, "batch_size must be >= 1"); + TORCH_CHECK(chunk_size >= 1, "chunk_size must be >= 1"); + TORCH_CHECK(global_topk_indices.scalar_type() == at::kInt, + "global_topk_indices must be int32"); + TORCH_CHECK(global_topk_indices.numel() >= batch_size * topk_val, + "global_topk_indices is too small for batch_size * topk_val"); + + TORCH_CHECK( + mapping_mode == MAPPING_NONE || + mapping_mode == MAPPING_POWER || + mapping_mode == 
MAPPING_ASINH || + mapping_mode == MAPPING_LOG1P || + mapping_mode == MAPPING_ERF || + mapping_mode == MAPPING_TANH || + mapping_mode == MAPPING_SUBTRACT || + mapping_mode == MAPPING_EXP_STRETCH || + mapping_mode == MAPPING_SHIFT_POW2 || + mapping_mode == MAPPING_SHIFT_POW3 || + mapping_mode == MAPPING_LINEAR_STEEP, + "fast_cluster_topk_merge: mapping_mode=", mapping_mode, + " not supported. Valid: NONE(0), POWER(3), ASINH(6), LOG1P(7), " + "ERF(9), TANH(10), SUBTRACT(11), EXP_STRETCH(13), SHIFT_POW2(15), " + "SHIFT_POW3(16), LINEAR_STEEP(17)."); + + // Hardware capability gate — Thread Block Clusters require sm_90+. + int dev; + TORCH_CHECK(::cudaGetDevice(&dev) == cudaSuccess, "cudaGetDevice failed"); + cudaDeviceProp prop{}; + TORCH_CHECK(::cudaGetDeviceProperties(&prop, dev) == cudaSuccess, + "cudaGetDeviceProperties failed"); + TORCH_CHECK(prop.major >= 9, + "fast_cluster_topk_merge requires sm_90+ (Hopper/Blackwell). " + "Detected compute capability ", prop.major, ".", prop.minor, "."); + + // Dynamic smem layout (per CTA): + // exports : topk_val * (float + int32) = topk_val * 8 B (DSMEM-visible) + // overlay : used by Stage 1 as the remap/bin cache (all ranks), reused + // by Stage 2 on rank 0 as the N*K merge pool. Sized to the + // larger of the two so either fits. + // cache_bytes = chunk_size * (float + uint8), uint8 region padded. + // merge_bytes = num_chunks * topk_val * (float + int32). + const size_t export_bytes = static_cast(topk_val) * + (sizeof(float) + sizeof(int32_t)); + const size_t cache_bytes = static_cast(chunk_size) * sizeof(float) + + ((static_cast(chunk_size) + 15) & ~size_t(15)); + const size_t merge_bytes = static_cast(num_chunks) * + static_cast(topk_val) * + (sizeof(float) + sizeof(int32_t)); + const size_t overlay_bytes = (cache_bytes > merge_bytes) ? 
cache_bytes : merge_bytes; + const size_t smem_bytes = export_bytes + overlay_bytes; + TORCH_CHECK(smem_bytes <= kMaxDynSmem, + "fast_cluster_topk_merge: smem ", smem_bytes, + " > ceiling ", kMaxDynSmem, + " (topk_val=", topk_val, ", num_chunks=", num_chunks, + ", chunk_size=", chunk_size, ")"); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + const float mp = static_cast(mapping_power); + + const dim3 grid(static_cast(batch_size * num_chunks), 1, 1); + const dim3 block(kThreadsPerBlock, 1, 1); + + cudaLaunchAttribute attrs[1]{}; + attrs[0].id = cudaLaunchAttributeClusterDimension; + attrs[0].val.clusterDim.x = static_cast(num_chunks); + attrs[0].val.clusterDim.y = 1; + attrs[0].val.clusterDim.z = 1; + + cudaLaunchConfig_t cfg{}; + cfg.gridDim = grid; + cfg.blockDim = block; + cfg.dynamicSmemBytes = smem_bytes; + cfg.stream = stream; + cfg.attrs = attrs; + cfg.numAttrs = 1; + + #define LAUNCH(DTYPE, PTR_EXPR, MODE_VAL) \ + do { \ + setup_kernel_smem_once, \ + kMaxDynSmem>(); \ + const auto rc_launch = ::cudaLaunchKernelEx( \ + &cfg, TopK_Cluster_Kernel, \ + PTR_EXPR, \ + global_topk_indices.data_ptr(), \ + static_cast(num_chunks), \ + static_cast(chunk_size), \ + static_cast(topk_val), \ + mp); \ + TORCH_CHECK(rc_launch == cudaSuccess, \ + "fast_cluster_topk_merge launch failed: ", \ + ::cudaGetErrorString(rc_launch)); \ + } while (0) + + #define DISPATCH_MODE(DTYPE, PTR_EXPR) \ + do { \ + switch (mapping_mode) { \ + case MAPPING_NONE: LAUNCH(DTYPE, PTR_EXPR, MAPPING_NONE); break; \ + case MAPPING_POWER: LAUNCH(DTYPE, PTR_EXPR, MAPPING_POWER); break; \ + case MAPPING_ASINH: LAUNCH(DTYPE, PTR_EXPR, MAPPING_ASINH); break; \ + case MAPPING_LOG1P: LAUNCH(DTYPE, PTR_EXPR, MAPPING_LOG1P); break; \ + case MAPPING_ERF: LAUNCH(DTYPE, PTR_EXPR, MAPPING_ERF); break; \ + case MAPPING_TANH: LAUNCH(DTYPE, PTR_EXPR, MAPPING_TANH); break; \ + case MAPPING_SUBTRACT: LAUNCH(DTYPE, PTR_EXPR, MAPPING_SUBTRACT); break; \ + case MAPPING_EXP_STRETCH: LAUNCH(DTYPE, 
PTR_EXPR, MAPPING_EXP_STRETCH); break; \ + case MAPPING_SHIFT_POW2: LAUNCH(DTYPE, PTR_EXPR, MAPPING_SHIFT_POW2); break; \ + case MAPPING_SHIFT_POW3: LAUNCH(DTYPE, PTR_EXPR, MAPPING_SHIFT_POW3); break; \ + case MAPPING_LINEAR_STEEP: LAUNCH(DTYPE, PTR_EXPR, MAPPING_LINEAR_STEEP); break; \ + default: TORCH_CHECK(false, "unreachable mode"); \ + } \ + } while (0) + + if (score.scalar_type() == at::ScalarType::BFloat16) { + DISPATCH_MODE(__nv_bfloat16, + reinterpret_cast<__nv_bfloat16*>(score.data_ptr())); + } else if (score.scalar_type() == at::ScalarType::Float) { + DISPATCH_MODE(float, score.data_ptr()); + } else { + TORCH_CHECK(false, "fast_cluster_topk_merge: unsupported dtype ", + score.scalar_type()); + } + + #undef DISPATCH_MODE + #undef LAUNCH + + const auto rc = cudaGetLastError(); + TORCH_CHECK(rc == cudaSuccess, + "fast_cluster_topk_merge kernel failed: ", ::cudaGetErrorString(rc)); +} diff --git a/csrc/archived/topk_sglang_ori_fastpath.cu b/csrc/archived/topk_sglang_ori_fastpath.cu new file mode 100644 index 00000000..29970ecd --- /dev/null +++ b/csrc/archived/topk_sglang_ori_fastpath.cu @@ -0,0 +1,319 @@ +// Archived: not compiled. See csrc/archived/README.md +// +// Flexible-radix (RADIX_BITS 4..10) "ori fast path" for TopK. It was the +// zero-mapping-overhead fast path used when mapping_mode == MAPPING_NONE. +// No longer tested — mode 0 now routes through the fused TopKOutput_Kernel +// with mapping.mode == MAPPING_NONE, which pays no extra cost because +// mapped_convert_to_uint8 collapses to convert_to_uint8 in that branch. +// +// The code below was extracted verbatim from csrc/topk_sglang.cu as of the +// fused-kernel refactor. It references helpers (kSmem, convert_to_uint32, +// vortex_to_float, VORTEX_MAX_TOPK, kThreadsPerBlock, setup_kernel_smem_once, +// CHECK_CUDA, topk_mapping.cuh types) from the surrounding translation unit. +// Dropping this file into a build as-is will not compile; it is reference +// only. 
+ +template +__device__ __forceinline__ uint16_t convert_to_uintN(float x) { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); + return key >> (16 - BITS); +} + +// ====================================================================== +// Ori fast path: zero-overhead topk with no mapping infrastructure. +// Template on RADIX_BITS: 4-10 (16 to 1024 bins). +// ====================================================================== +template +__device__ void fast_topk_ori( + const ScoreT* __restrict__ input, + int* __restrict__ index, + int row_start, + int length, + int target_k) +{ + int topk = target_k; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 1 << RADIX_BITS; + constexpr auto RADIX_PAD = RADIX / 2; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + static_assert(RADIX_BITS >= 4 && RADIX_BITS <= 10, "RADIX_BITS must be 4-10"); + static_assert(RADIX <= BLOCK_SIZE, "RADIX must not exceed BLOCK_SIZE"); + + alignas(128) __shared__ int s_histogram_buf[2][RADIX + RADIX_PAD]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin_id; + alignas(128) __shared__ int s_num_input[2]; + + auto& s_histogram = s_histogram_buf[0]; + extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // Stage 1: coarse histogram with RADIX bins + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uintN(vortex_to_float(input[idx + row_start])); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { + for (int i = 0; i < RADIX_BITS; ++i) { + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += s_histogram_buf[k][tx + j]; + } + s_histogram_buf[k ^ 1][tx] = value; 
+ } + __syncthreads(); + } + }; + // Stage 2 cumsum: always 256 sub-bins (8-bit radix on raw float bits) + const auto run_cumsum_s2 = [&] { + for (int i = 0; i < 8; ++i) { + if (C10_LIKELY(tx < 256)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < 256 - j) { + value += s_histogram_buf[k][tx + j]; + } + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = static_cast(convert_to_uintN(vortex_to_float(input[idx + row_start]))); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < 257) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto bin = static_cast(convert_to_uintN(raw_input)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[0][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> 24) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // Stage 2: refine with 8-bit radix passes +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int s_last_remain; + const auto r_idx = round % 2; + + const auto _raw_num_input = s_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? 
_raw_num_input : int(SMEM_INPUT_SIZE); + + run_cumsum_s2(); + if (tx < 256 && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(vortex_to_float(input[idx + row_start])) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < 257) s_histogram[tx] = 0; + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + index[target_k - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[r_idx ^ 1][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + +// Ori fast-path wrapper: zero mapping overhead, flexible radix +template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKOutput_Ori_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ 
dense_kv_indices, + int* __restrict__ sparse_kv_indices, + const int topk_val, + const int page_reserved_bos, + const int page_reserved_eos) +{ + const int bx = blockIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; + + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_ori(score_blk, s_indices, 0, nblk, topk_val); + __syncthreads(); + + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } +} + +// Helper: launch TopKOutput_Ori_Kernel with radix_bits dispatch +template +void launch_ori_kernel( + const ScoreT* score, const int* dense_kv_indptr, const int* sparse_kv_indptr, + const int* dense_kv_indices, int* sparse_kv_indices, + int topk_val, int reserved_bos, int reserved_eos, + int radix_bits, dim3 nblks, dim3 nthreads, cudaStream_t stream) +{ + #define LAUNCH_ORI(BITS) \ + setup_kernel_smem_once, kSmem>(); \ + TopKOutput_Ori_Kernel<<>>( \ + score, dense_kv_indptr, sparse_kv_indptr, dense_kv_indices, sparse_kv_indices, \ + topk_val, reserved_bos, reserved_eos) + switch (radix_bits) { + case 4: LAUNCH_ORI(4); break; + case 5: LAUNCH_ORI(5); break; + case 6: LAUNCH_ORI(6); break; + case 7: LAUNCH_ORI(7); break; + case 9: LAUNCH_ORI(9); break; + case 10: LAUNCH_ORI(10); break; + default: LAUNCH_ORI(8); break; + } + #undef LAUNCH_ORI +} + +// ====================================================================== +// Explicit ori baseline entry point — always uses the ori fast path +// ====================================================================== +void topk_output_sglang_ori( + const at::Tensor& x, + const at::Tensor& 
dense_kv_indptr, + const at::Tensor& sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages, + const int64_t radix_bits) +{ + TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, + "topk_output_sglang_ori: topk_val (", topk_val, + ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + TORCH_CHECK(radix_bits >= 4 && radix_bits <= 10, + "topk_output_sglang_ori: radix_bits must be 4-10, got ", radix_bits); + + CHECK_CUDA(x); + CHECK_CUDA(dense_kv_indptr); + CHECK_CUDA(sparse_kv_indptr); + CHECK_CUDA(dense_kv_indices); + CHECK_CUDA(sparse_kv_indices); + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + if (x.scalar_type() == at::ScalarType::BFloat16) { + launch_ori_kernel<__nv_bfloat16>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), sparse_kv_indices.data_ptr(), + topk_val, reserved_bos, reserved_eos, + radix_bits, nblks, nthreads, stream); + } else if (x.scalar_type() == at::ScalarType::Float) { + launch_ori_kernel( + x.data_ptr(), + dense_kv_indptr.data_ptr(), sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), sparse_kv_indices.data_ptr(), + topk_val, reserved_bos, reserved_eos, + radix_bits, nblks, nthreads, stream); + } else { + TORCH_CHECK(false, "topk_output_sglang_ori: unsupported dtype ", x.scalar_type()); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_output_sglang_ori kernel failed: ", ::cudaGetErrorString(result)); +} diff --git a/csrc/archived/topk_sglang_parallel.cu b/csrc/archived/topk_sglang_parallel.cu new file mode 100644 index 00000000..f11e59d3 --- /dev/null +++ b/csrc/archived/topk_sglang_parallel.cu @@ -0,0 +1,639 @@ +/** + * Vortex TopK — 
single-kernel parallel+merge pipeline.
+ *
+ * ONE kernel launch.  Per-chunk selection and cross-chunk merge both run
+ * inside the same grid-(Batch, N) launch.  The last-arriving CTA for
+ * each batch (detected by a program-lifetime __device__ done-counter +
+ * atomicInc wrap-around) carries out the merge — no second launch, no
+ * per-call cudaMemset for barrier state.
+ *
+ * Correctness:
+ *   Stage 1 per-chunk uses ONE 8-bit radix histogram + ONE 8-bit
+ *   refinement round on the threshold bin (16 bits of selection
+ *   precision).  For bf16 input (8 mantissa bits effective), this is
+ *   lossless as long as the remapped fp32 key is still bf16-exact — two
+ *   such items with the same 16-bit key are bit-identical.  Keys that
+ *   share their top 16 bits are tie-broken by atomic arrival order.
+ *
+ *   Stage 2 merge operates on N*K pre-remapped keys in shared memory
+ *   and uses the same 8-bit-hist + 8-bit-refine pattern, which is
+ *   strictly sufficient to pick the correct top-K from the union.
+ *
+ * Low-overhead primitives:
+ *   - Warp-level ballot+popc compaction on the "bin > threshold" path
+ *     so each warp issues ONE atomicAdd on the block counter instead
+ *     of one per thread.
+ *   - Program-lifetime __device__ done-counter sized for realistic
+ *     batch×head counts; atomicInc wraps back to 0 at num_chunks so
+ *     there's no memset on the hot path.
+ *   - Vectorised float4/int4 loads from global → smem in the merge.
+ *
+ * Supported mapping modes (IDs from csrc/topk_mapping.cuh):
+ *   3=POWER, 6=ASINH, 7=LOG1P, 9=ERF, 10=TANH, 11=SUBTRACT,
+ *   13=EXP_STRETCH, 15=SHIFT_POW2, 16=SHIFT_POW3, 17=LINEAR_STEEP.
+ */ + + #include + #include + #include + #include + #include + #include + #include + #include + #include + #include + + #include + #include + #include + + #include "register.h" + + namespace { + + // ---- Launch constants ------------------------------------------------------ + + constexpr int kThreadsPerBlock = 1024; + constexpr int kWarpSize = 32; + constexpr int RADIX = 256; + constexpr size_t kMaxDynSmem = 96 * 1024; + constexpr int VORTEX_MAX_TOPK = 2048; + + // Stage-2 holds N*K (key, idx) pairs in smem = 8 B/item. + constexpr int kMergeCap = 8192; + + // Max batch the single kernel can sequence. Sized for realistic + // bs×heads (decode). __device__ globals are zero-initialised at + // program start; atomicInc wrap-around keeps each entry at 0 between + // launches, so no host-side memset on the hot path. + constexpr int kMaxBatch = 8192; + __device__ unsigned int g_done_counter[kMaxBatch]; + + // ---- Device helpers -------------------------------------------------------- + + __device__ __forceinline__ uint32_t convert_to_uint32(float x) { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); + } + + // Required symbol for topk_mapping.cuh's compute_stage1_bin. Not used + // directly by the kernel body here, but the header includes a forward + // declaration that resolves against this definition at link time. + __device__ __forceinline__ uint8_t convert_to_uint8(float x) { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? 
static_cast(~bits) + : static_cast(bits | 0x8000); + return static_cast(key >> 8); + } + + template + __device__ __forceinline__ float vortex_to_float(T x); + template <> + __device__ __forceinline__ float vortex_to_float(float x) { return x; } + template <> + __device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) { + return __bfloat162float(x); + } + + #include "topk_mapping.cuh" + + // ============================================================================ + // 8-step suffix cumsum over 256 bins. After the call s_hist[0][i] is + // the count of items with bin >= i (monotone non-increasing). + // ============================================================================ + __device__ __forceinline__ void run_cumsum_256(int s_hist[2][RADIX + 128]) { + const int tx = threadIdx.x; + #pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const int j = 1 << i; + const int k = i & 1; + int value = s_hist[k][tx]; + if (tx < RADIX - j) value += s_hist[k][tx + j]; + s_hist[k ^ 1][tx] = value; + } + __syncthreads(); + } + } + + // ============================================================================ + // Warp-level ballot+popc compaction. + // + // Every participating thread offers a boolean `selected`. Exactly ONE + // atomicAdd per warp — issued by the first active lane — reserves + // `warp_count` slots; other selected lanes derive their slot via a + // popc prefix sum. Safe when called from inside a divergent region + // (uses __activemask(), not a fixed all-ones mask). 
+ // ============================================================================
+ // selected  : whether the calling thread wants an output slot.
+ // s_counter : counter that is advanced by the number of selected lanes;
+ //             callers in this file pass the address of a __shared__ int.
+ // Returns the reserved slot (counter value before this warp's reservation
+ // plus the number of selected lower lanes) for a selected thread, and -1
+ // for a thread that did not select.  The atomicAdd is skipped entirely
+ // when no lane in the warp is selected.
+ __device__ __forceinline__ int warp_compact_slot(bool selected, int* s_counter) {
+     // Lanes that are executing this call together.
+     const uint32_t mask   = __activemask();
+     // One bit per participating lane that asked for a slot.
+     const uint32_t ballot = __ballot_sync(mask, selected);
+     const int      lane   = threadIdx.x & (kWarpSize - 1);
+     const int      warp_count   = __popc(ballot);
+     // Number of selected lanes with a lower lane id than this one.
+     const int      rank_in_warp = __popc(ballot & ((1u << lane) - 1u));
+
+     // __ffs is 1-based, so this is the lowest lane id present in `mask`.
+     const int first_lane = __ffs(mask) - 1;
+     int base = 0;
+     if (lane == first_lane) {
+         base = (warp_count > 0) ? ::atomicAdd(s_counter, warp_count) : 0;
+     }
+     // Broadcast the reserved base from the lane that performed the atomic.
+     base = __shfl_sync(mask, base, first_lane);
+     return selected ? (base + rank_in_warp) : -1;
+ }
+
+ // ============================================================================
+ // Combined kernel — Stage 1 (per-chunk) + barrier + Stage 2 (merge).
+ //
+ //   Grid  = (Batch, N).   One CTA per (batch, chunk).
+ //   Block = kThreadsPerBlock = 1024.
+ //
+ // Shared-memory layout (reused across phases):
+ //   Phase 1 needs:
+ //     s_remapped[chunk_size]  (float)   — cached apply_transform output.
+ //     s_bins[chunk_size]      (uint8)   — cached coarse bin.
+ //   Merge needs:
+ //     s_scores[N*K]           (float)   — pair buffer, loaded vectorised.
+ //     s_indices[N*K]          (int32)   — pair buffer.
+ //   The dynamic smem size (smem_bytes on the host) is the max of both.
+ //
+ // Sync between phases:
+ //   After Phase 1's workspace writes, __threadfence() publishes them,
+ //   then thread 0 does `atomicInc(&g_done_counter[b], N-1)` which
+ //   cycles 0→1→…→N-1→0 so no reset is needed between calls.  The CTA
+ //   whose returned `old == N-1` is the last one — it falls through
+ //   into the merge; other CTAs return.
+ // ============================================================================ + template + __global__ __launch_bounds__(kThreadsPerBlock) + void TopK_Parallel_Kernel( + const ScoreT* __restrict__ score, // [Batch, N, chunk_size] + int32_t* __restrict__ global_idx, // [Batch, K] + float* __restrict__ partial_keys, // [Batch, N, K] workspace + int32_t* __restrict__ partial_idx, // [Batch, N, K] workspace + int N, + int chunk_size, + int K, + float mapping_power) + { + const int b = blockIdx.x; + const int n = blockIdx.y; + const int tx = threadIdx.x; + + // Addresses for this CTA's chunk slice and its slot in the workspace. + const ScoreT* chunk_in = score + (static_cast(b) * N + n) * chunk_size; + float* chunk_keys_out = partial_keys + (static_cast(b) * N + n) * K; + int32_t* chunk_idx_out = partial_idx + (static_cast(b) * N + n) * K; + const int32_t idx_base = n * chunk_size; // batch-local offset + + // ---------------------------------------------------------------- smem + extern __shared__ char smem_raw[]; + + // Shared-memory counters / histogram live in static smem so the + // Phase-1 and merge phases can share the same dynamic pool. + alignas(128) __shared__ int s_hist_buf[2][RADIX + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin; + alignas(128) __shared__ int s_sub_threshold_bin; + alignas(128) __shared__ int s_last_remain; + alignas(128) __shared__ int s_is_last; + auto& s_hist = s_hist_buf[0]; + + // ========================================================================= + // Phase 1: per-chunk TopK via 8-bit radix + 8-bit refinement. + // ========================================================================= + // + // Dynamic smem region used as: + // s_remapped : chunk_size * 4 B (cached apply_transform output) + // s_bins : chunk_size * 1 B (cached Stage-1 bin) + // + // Refinement is a second 8-bit bucket on bits [23:16] of the + // sign-flipped u32 key, used to refine the threshold bin. 
8 + 8 = + // 16 bits of selection precision → lossless for bf16. + float* s_remapped = reinterpret_cast(smem_raw); + uint8_t* s_bins = reinterpret_cast(s_remapped + chunk_size); + + // ---- Degenerate chunk_size <= K : emit everything as-is. ------------- + if (chunk_size <= K) { + for (int i = tx; i < K; i += blockDim.x) { + if (i < chunk_size) { + const float raw = vortex_to_float(chunk_in[i]); + chunk_keys_out[i] = apply_transform_tmpl(raw, mapping_power); + chunk_idx_out [i] = i + idx_base; + } else { + chunk_keys_out[i] = -CUDART_INF_F; + chunk_idx_out [i] = -1; + } + } + } else { + // ---- Histogram pass 1: transform + bucket; cache both to smem. ---- + if (tx < RADIX + 1) s_hist[tx] = 0; + if (tx == 0) { s_counter = 0; s_threshold_bin = -1; s_last_remain = 0; } + __syncthreads(); + + for (int idx = tx; idx < chunk_size; idx += blockDim.x) { + const float raw = vortex_to_float(chunk_in[idx]); + const float remapped = apply_transform_tmpl(raw, mapping_power); + const uint32_t b32 = convert_to_uint32(remapped); + const int bin = (b32 >> 24) & 0xFF; + s_remapped[idx] = remapped; + s_bins [idx] = static_cast(bin); + ::atomicAdd(&s_hist[bin], 1); + } + __syncthreads(); + + run_cumsum_256(s_hist_buf); + + if (tx < RADIX && s_hist[tx] > K && s_hist[tx + 1] <= K) { + s_threshold_bin = tx; + s_last_remain = K - s_hist[tx + 1]; + } + __syncthreads(); + const int threshold_bin = s_threshold_bin; + + // ---- Emit bin > threshold (warp-popc) and build refinement hist. 
----
+         if (tx < RADIX + 1) s_hist[tx] = 0;
+         __syncthreads();
+
+         // Every thread runs the same number of iterations so that the whole
+         // warp reaches warp_compact_slot() together, even for idx >= chunk_size.
+         const int num_iters = (chunk_size + blockDim.x - 1) / blockDim.x;
+         for (int it = 0; it < num_iters; ++it) {
+             const int idx = it * blockDim.x + tx;
+             const bool in_range = (idx < chunk_size);
+             int bin = -1;
+             if (in_range) bin = static_cast<int>(s_bins[idx]);
+             const bool take_above = in_range && (bin > threshold_bin);
+
+             const int slot = warp_compact_slot(take_above, &s_counter);
+             if (take_above) {
+                 // Strictly above the threshold bin: always part of the top-K.
+                 chunk_keys_out[slot] = s_remapped[idx];
+                 chunk_idx_out [slot] = idx + idx_base;
+             } else if (in_range && bin == threshold_bin) {
+                 // Tie bin: histogram bits [23:16] of the key for the refinement.
+                 const uint32_t b32 = convert_to_uint32(s_remapped[idx]);
+                 const int sub_bin = (b32 >> 16) & 0xFF;
+                 ::atomicAdd(&s_hist[sub_bin], 1);
+             }
+         }
+         __syncthreads();
+
+         // ---- Refinement cumsum → sub-threshold bin. ------------------------
+         // Snapshot the tie-bin budget into a register and install the
+         // "no sub-threshold bin" sentinel BEFORE the cumsum.  The barriers
+         // inside run_cumsum_256 order both ahead of the store below, so no
+         // thread evaluates the predicate against a half-updated
+         // s_last_remain and the sentinel can never overwrite the real bin.
+         const int tie_budget = s_last_remain;
+         if (tx == 0) s_sub_threshold_bin = RADIX;
+         run_cumsum_256(s_hist_buf);
+         if (tx < RADIX && s_hist[tx] > tie_budget
+                        && s_hist[tx + 1] <= tie_budget) {
+             s_sub_threshold_bin = tx;
+             // budget for items at the sub-threshold bin
+             s_last_remain = tie_budget - s_hist[tx + 1];
+         }
+         __syncthreads();
+         const int sub_threshold_bin = s_sub_threshold_bin;
+
+         // ---- Emit threshold-bin items using sub-threshold logic.
---------- + for (int it = 0; it < num_iters; ++it) { + const int idx = it * blockDim.x + tx; + const bool in_range = (idx < chunk_size); + int bin = -1; + if (in_range) bin = static_cast(s_bins[idx]); + int sub_bin = -1; + if (in_range && bin == threshold_bin) { + const uint32_t b32 = convert_to_uint32(s_remapped[idx]); + sub_bin = (b32 >> 16) & 0xFF; + } + + const bool take_sub_above = (sub_bin > sub_threshold_bin); + const int slot = warp_compact_slot(take_sub_above, &s_counter); + if (take_sub_above) { + chunk_keys_out[slot] = s_remapped[idx]; + chunk_idx_out [slot] = idx + idx_base; + } else if (sub_bin == sub_threshold_bin) { + const int pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + chunk_keys_out[K - pos] = s_remapped[idx]; + chunk_idx_out [K - pos] = idx + idx_base; + } + } + } + __syncthreads(); + } + + // ========================================================================= + // Barrier: publish this CTA's workspace writes and atomicInc the + // per-batch done-counter. The CTA that sees old == N-1 is the last + // arriving one; every other CTA returns here. + // ========================================================================= + __threadfence(); + __syncthreads(); + if (tx == 0) { + const unsigned int old = ::atomicInc( + &g_done_counter[b], static_cast(N - 1)); + s_is_last = (old == static_cast(N - 1)) ? 1 : 0; + } + __syncthreads(); + if (s_is_last == 0) return; + + // ========================================================================= + // Phase 2 (merge, only in last-arriving CTA): + // load N*K candidates into smem (vectorised) → + // 8-bit histogram in smem → + // threshold → warp-popc emit above + tie-bin refinement. 
+ // ========================================================================= + const int total = N * K; + const float* keys_in = partial_keys + static_cast(b) * total; + const int32_t* idx_in = partial_idx + static_cast(b) * total; + int32_t* out_idx = global_idx + static_cast(b) * K; + + // Reuse the same dynamic smem region as Phase 1 — Phase 1's caches + // are dead now. Layout: [ s_scores : total floats | s_indices : total int32 ]. + float* s_scores = reinterpret_cast(smem_raw); + int32_t* s_indices = reinterpret_cast(s_scores + total); + + // Vectorised 128-bit loads when `total` is a multiple of 4. + if ((total & 3) == 0) { + const float4* keys_v = reinterpret_cast(keys_in); + const int4* idx_v = reinterpret_cast (idx_in); + float4* ss_v = reinterpret_cast (s_scores); + int4* si_v = reinterpret_cast (s_indices); + const int total4 = total >> 2; + for (int i = tx; i < total4; i += blockDim.x) { + ss_v[i] = keys_v[i]; + si_v[i] = idx_v [i]; + } + } else { + for (int i = tx; i < total; i += blockDim.x) { + s_scores [i] = keys_in[i]; + s_indices[i] = idx_in [i]; + } + } + + if (tx < RADIX + 1) s_hist[tx] = 0; + if (tx == 0) { + s_counter = 0; + s_threshold_bin = -1; + s_sub_threshold_bin = -1; + s_last_remain = 0; + } + __syncthreads(); + + // (2) 8-bit histogram in smem. + const int num_iters_m = (total + blockDim.x - 1) / blockDim.x; + for (int it = 0; it < num_iters_m; ++it) { + const int i = it * blockDim.x + tx; + if (i < total && s_indices[i] >= 0) { + const uint32_t b32 = convert_to_uint32(s_scores[i]); + const int bin = (b32 >> 24) & 0xFF; + ::atomicAdd(&s_hist[bin], 1); + } + } + __syncthreads(); + + run_cumsum_256(s_hist_buf); + + // Fast path: no threshold search needed when valid_count ≤ K. 
+ const int valid_count = s_hist[0]; + if (valid_count <= K) { + for (int it = 0; it < num_iters_m; ++it) { + const int i = it * blockDim.x + tx; + const bool take = (i < total) && (s_indices[i] >= 0); + const int slot = warp_compact_slot(take, &s_counter); + if (take) out_idx[slot] = s_indices[i]; + } + return; + } + + if (tx < RADIX && s_hist[tx] > K && s_hist[tx + 1] <= K) { + s_threshold_bin = tx; + s_last_remain = K - s_hist[tx + 1]; + } + __syncthreads(); + const int threshold_bin_m = s_threshold_bin; + + // (3) Emit above threshold via warp-popc; build sub-bin histogram on + // bits [23:16] for the tie-bin refinement. + if (tx < RADIX + 1) s_hist[tx] = 0; + __syncthreads(); + + for (int it = 0; it < num_iters_m; ++it) { + const int i = it * blockDim.x + tx; + bool in_valid = false; + int bin = -1; + uint32_t b32 = 0; + if (i < total) { + const int32_t idx = s_indices[i]; + if (idx >= 0) { + in_valid = true; + b32 = convert_to_uint32(s_scores[i]); + bin = (b32 >> 24) & 0xFF; + } + } + const bool take_above = in_valid && (bin > threshold_bin_m); + const int slot = warp_compact_slot(take_above, &s_counter); + if (take_above) { + out_idx[slot] = s_indices[i]; + } else if (in_valid && bin == threshold_bin_m) { + const int sub_bin = (b32 >> 16) & 0xFF; + ::atomicAdd(&s_hist[sub_bin], 1); + } + } + __syncthreads(); + + // (4) Refinement cumsum → sub-threshold bin. + run_cumsum_256(s_hist_buf); + if (tx < RADIX && s_hist[tx] > s_last_remain + && s_hist[tx + 1] <= s_last_remain) { + s_sub_threshold_bin = tx; + s_last_remain = s_last_remain - s_hist[tx + 1]; + } + if (tx == 0 && s_sub_threshold_bin == -1) { + s_sub_threshold_bin = RADIX; // no tie-bin refinement needed + } + __syncthreads(); + const int sub_threshold_bin_m = s_sub_threshold_bin; + + // (5) Emit tie-bin items via warp-popc + sub-threshold budget. 
+ for (int it = 0; it < num_iters_m; ++it) { + const int i = it * blockDim.x + tx; + bool in_threshold = false; + int sub_bin = -1; + if (i < total) { + const int32_t idx = s_indices[i]; + if (idx >= 0) { + const uint32_t b32 = convert_to_uint32(s_scores[i]); + const int bin = (b32 >> 24) & 0xFF; + if (bin == threshold_bin_m) { + in_threshold = true; + sub_bin = (b32 >> 16) & 0xFF; + } + } + } + const bool take_sub_above = in_threshold && (sub_bin > sub_threshold_bin_m); + const int slot = warp_compact_slot(take_sub_above, &s_counter); + if (take_sub_above) { + out_idx[slot] = s_indices[i]; + } else if (in_threshold && sub_bin == sub_threshold_bin_m) { + const int pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) out_idx[K - pos] = s_indices[i]; + } + } + } + + // ---- setup_kernel_smem_once ------------------------------------------------ + + template + void setup_kernel_smem_once() { + [[maybe_unused]] + static const auto result = [] { + return ::cudaFuncSetAttribute( + f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); + }(); + TORCH_CHECK(result == cudaSuccess, + "fast_fused_topk_merge setup failed: ", + ::cudaGetErrorString(result)); + } + + } // namespace + + #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + + // ============================================================================ + // Host entry point. + // + // score [batch_size, num_chunks, chunk_size] bf16 or f32 + // global_topk_indices [batch_size, topk_val] int32 (output) + // + // ONE kernel launch. The per-chunk selection (Phase 1) and the + // cross-chunk merge (Phase 2) are fused in TopK_Parallel_Kernel via a + // last-CTA-wins atomicInc barrier. A per-call workspace holds the + // [batch, N, K] partial top-K that the last CTA reads from; the + // done-counter is a program-lifetime __device__ global so nothing + // needs memsetting on the hot path. 
+ // ============================================================================ + void fast_fused_topk_merge( + const at::Tensor& score, + at::Tensor& global_topk_indices, + const int64_t batch_size, + const int64_t num_chunks, + const int64_t chunk_size, + const int64_t topk_val, + const int64_t mapping_mode, + const double mapping_power) + { + CHECK_CUDA(score); + CHECK_CUDA(global_topk_indices); + + TORCH_CHECK(topk_val > 0 && topk_val <= VORTEX_MAX_TOPK, + "fast_fused_topk_merge: topk_val=", topk_val, + " must be in (0, ", VORTEX_MAX_TOPK, "]"); + TORCH_CHECK(num_chunks >= 1, "num_chunks must be >= 1"); + TORCH_CHECK(batch_size >= 1, "batch_size must be >= 1"); + TORCH_CHECK(batch_size <= kMaxBatch, + "fast_fused_topk_merge: batch_size ", batch_size, + " exceeds the __device__ done-counter cap (", kMaxBatch, ")"); + TORCH_CHECK(chunk_size >= 1, "chunk_size must be >= 1"); + TORCH_CHECK(num_chunks * topk_val <= kMergeCap, + "fast_fused_topk_merge: num_chunks*topk_val (", + num_chunks * topk_val, ") exceeds merge cap (", kMergeCap, + "). Reduce num_chunks or topk_val."); + TORCH_CHECK(global_topk_indices.scalar_type() == at::kInt, + "global_topk_indices must be int32"); + TORCH_CHECK(global_topk_indices.numel() >= batch_size * topk_val, + "global_topk_indices is too small for batch_size * topk_val"); + + TORCH_CHECK( + mapping_mode == MAPPING_POWER || + mapping_mode == MAPPING_ASINH || + mapping_mode == MAPPING_LOG1P || + mapping_mode == MAPPING_ERF || + mapping_mode == MAPPING_TANH || + mapping_mode == MAPPING_SUBTRACT || + mapping_mode == MAPPING_EXP_STRETCH || + mapping_mode == MAPPING_SHIFT_POW2 || + mapping_mode == MAPPING_SHIFT_POW3 || + mapping_mode == MAPPING_LINEAR_STEEP, + "fast_fused_topk_merge: mapping_mode=", mapping_mode, + " not supported. 
Valid: POWER(3), ASINH(6), LOG1P(7), ERF(9), " + "TANH(10), SUBTRACT(11), EXP_STRETCH(13), SHIFT_POW2(15), " + "SHIFT_POW3(16), LINEAR_STEEP(17)."); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + const float mp = static_cast(mapping_power); + + // Dynamic smem must fit whichever phase is larger: + // Phase 1: chunk_size floats + chunk_size bytes. + // Phase 2: num_chunks*topk_val * (float + int32). + const size_t p1_bytes = static_cast(chunk_size) * sizeof(float) + + ((static_cast(chunk_size) + 15) & ~size_t(15)); + const size_t p2_bytes = static_cast(num_chunks) * + static_cast(topk_val) * + (sizeof(float) + sizeof(int32_t)); + const size_t smem_bytes = p1_bytes > p2_bytes ? p1_bytes : p2_bytes; + TORCH_CHECK(smem_bytes <= kMaxDynSmem, + "fast_fused_topk_merge: smem ", smem_bytes, + " > ceiling ", kMaxDynSmem); + + // Per-call workspace for the [batch, N, K] partial top-K. at::empty + // hits the caching allocator (no cudaMalloc in the hot path after + // warmup). The done-counter lives in __device__ memory — no memset. 
+ auto opts_f32 = at::TensorOptions().device(score.device()).dtype(at::kFloat); + auto opts_i32 = at::TensorOptions().device(score.device()).dtype(at::kInt); + const int64_t ws_elems = batch_size * num_chunks * topk_val; + at::Tensor partial_keys = at::empty({ws_elems}, opts_f32); + at::Tensor partial_idx = at::empty({ws_elems}, opts_i32); + + dim3 grid(static_cast(batch_size), + static_cast(num_chunks)); + dim3 block(kThreadsPerBlock); + + #define LAUNCH(DTYPE, PTR_EXPR, MODE_VAL) \ + do { \ + setup_kernel_smem_once, \ + kMaxDynSmem>(); \ + TopK_Parallel_Kernel \ + <<>>( \ + PTR_EXPR, \ + global_topk_indices.data_ptr(), \ + partial_keys.data_ptr(), \ + partial_idx.data_ptr(), \ + static_cast(num_chunks), \ + static_cast(chunk_size), \ + static_cast(topk_val), \ + mp); \ + } while (0) + + #define DISPATCH_MODE(DTYPE, PTR_EXPR) \ + do { \ + switch (mapping_mode) { \ + case MAPPING_POWER: LAUNCH(DTYPE, PTR_EXPR, MAPPING_POWER); break; \ + case MAPPING_ASINH: LAUNCH(DTYPE, PTR_EXPR, MAPPING_ASINH); break; \ + case MAPPING_LOG1P: LAUNCH(DTYPE, PTR_EXPR, MAPPING_LOG1P); break; \ + case MAPPING_ERF: LAUNCH(DTYPE, PTR_EXPR, MAPPING_ERF); break; \ + case MAPPING_TANH: LAUNCH(DTYPE, PTR_EXPR, MAPPING_TANH); break; \ + case MAPPING_SUBTRACT: LAUNCH(DTYPE, PTR_EXPR, MAPPING_SUBTRACT); break; \ + case MAPPING_EXP_STRETCH: LAUNCH(DTYPE, PTR_EXPR, MAPPING_EXP_STRETCH); break; \ + case MAPPING_SHIFT_POW2: LAUNCH(DTYPE, PTR_EXPR, MAPPING_SHIFT_POW2); break; \ + case MAPPING_SHIFT_POW3: LAUNCH(DTYPE, PTR_EXPR, MAPPING_SHIFT_POW3); break; \ + case MAPPING_LINEAR_STEEP: LAUNCH(DTYPE, PTR_EXPR, MAPPING_LINEAR_STEEP); break; \ + default: TORCH_CHECK(false, "unreachable mode"); \ + } \ + } while (0) + + if (score.scalar_type() == at::ScalarType::BFloat16) { + DISPATCH_MODE(__nv_bfloat16, + reinterpret_cast<__nv_bfloat16*>(score.data_ptr())); + } else if (score.scalar_type() == at::ScalarType::Float) { + DISPATCH_MODE(float, score.data_ptr()); + } else { + TORCH_CHECK(false, 
"fast_fused_topk_merge: unsupported dtype ", + score.scalar_type()); + } + + #undef DISPATCH_MODE + #undef LAUNCH + + const auto rc = cudaGetLastError(); + TORCH_CHECK(rc == cudaSuccess, + "fast_fused_topk_merge kernel failed: ", ::cudaGetErrorString(rc)); + } \ No newline at end of file diff --git a/csrc/archived/topk_slgang_ori.cu b/csrc/archived/topk_slgang_ori.cu new file mode 100644 index 00000000..04a2b73b --- /dev/null +++ b/csrc/archived/topk_slgang_ori.cu @@ -0,0 +1,546 @@ +/** + * @NOTE: This file is adapted from + * https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_v32/topk_selector.py + * We: + * 1. adapt from tilelang to pure cuda + * 2. optimize the performance a little + * 3. fix the potential illegal memory access + */ +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace { + +constexpr int TopK = 2048; +constexpr int kThreadsPerBlock = 1024; + +#ifdef USE_ROCM +// On ROCm, the per-workgroup LDS budget depends on the target arch, so we inject a +// per-arch value from `setup_rocm.py` via `-DSGL_TOPK_DYNAMIC_SMEM_BYTES=...`. +#ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES +constexpr size_t kSmem = static_cast(SGL_TOPK_DYNAMIC_SMEM_BYTES); +#else +constexpr size_t kSmem = 48 * 1024; // bytes +#endif +#else +// Reduced from 128KB to 32KB to improve occupancy. +// Each radix pass needs at most ~TopK candidates in the threshold bin, +// so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient. 
+constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes) +#endif + +struct FastTopKParams { + const float* __restrict__ input; // [B, input_stride] + const int32_t* __restrict__ row_starts; // [B] + int32_t* __restrict__ indices; // [B, TopK] + int32_t* __restrict__ lengths; // [B] + int64_t input_stride; +}; + +// when length <= TopK, we can directly write the indices +__device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) { + const auto tid = threadIdx.x; + for (int i = tid; i < TopK; i += kThreadsPerBlock) { + indice[i] = (i < length) ? i : -1; + } +} + +// keep the first `length` entries, set others to -1 +__device__ void naive_topk_transform( + const float* __restrict__ score, + int32_t length, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + dst_page_table[i] = (i < length) ? src_page_table[i] : -1; + } +} + +// keep the first `length` entries, set others to -1 +__device__ void naive_topk_transform_ragged( + const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + topk_indices_ragged[i] = (i < length) ? static_cast(i) + offset : -1; + } +} + +__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); + return static_cast(key >> 8); +} + +__device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? 
~bits : (bits | 0x80000000u); +} + +__device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int row_start, int length) { + // An optimized topk kernel copied from tilelang kernel + // We assume length > TopK here, or it will crash + int topk = TopK; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin_id; + alignas(128) __shared__ int s_num_input[2]; + + auto& s_histogram = s_histogram_buf[0]; + // allocate for two rounds + extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // stage 1: 8bit coarse histogram + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(input[idx + row_start]); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += s_histogram_buf[k][tx + j]; + } + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = static_cast(convert_to_uint8(input[idx + row_start])); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } 
else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = input[idx + row_start]; + const auto bin = static_cast(convert_to_uint8(raw_input)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + /// NOTE: (dark) fuse the histogram computation here + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[0][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> 24) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // stage 2: refine with 8bit radix passes +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int s_last_remain; + const auto r_idx = round % 2; + + // clip here to prevent overflow + const auto _raw_num_input = s_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? 
_raw_num_input : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(input[idx + row_start]) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = input[idx + row_start]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + index[TopK - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + /// NOTE: (dark) fuse the histogram computation here + s_input_idx[r_idx ^ 1][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // topk + void topk_kernel(const FastTopKParams params) { + const auto& [input, row_starts, indices, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto row_start = row_starts == nullptr ? 
0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto indice = indices + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_cuda(score, indice, length); + } else { + return fast_topk_cuda_tl(score, indice, row_start, length); + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // decode + void topk_transform_decode_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride) { + const auto& [input, _1, _2, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = 0; + const auto length = lengths[bid]; + const auto src_page_entry = src_page_table + bid * src_stride; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_page_entry[idx_0] = src_page_entry[pos_0]; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_page_entry[idx_1] = src_page_entry[pos_1]; + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // prefill + void topk_transform_prefill_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride, + const int32_t* __restrict__ cu_seqlens_q, + const int64_t prefill_bs) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + 
const auto length = lengths[bid]; + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + + /// NOTE: prefill bs is usually small, we can just use a simple loop here + /// We ensure that last cu_seqlens is equal to number of blocks launched + __shared__ const int32_t* s_src_page_entry; + if (C10_LIKELY(prefill_bs <= kThreadsPerBlock)) { + if (tid < prefill_bs) { + if (bid >= cu_seqlens_q[tid] && bid < cu_seqlens_q[tid + 1]) { + s_src_page_entry = src_page_table + tid * src_stride; + } + } + } else { + for (int64_t i = tid; i < prefill_bs; i += kThreadsPerBlock) { + if (bid >= cu_seqlens_q[i] && bid < cu_seqlens_q[i + 1]) { + s_src_page_entry = src_page_table + i * src_stride; + } + } + } + __syncthreads(); + const auto src_page_entry = s_src_page_entry; + + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_page_entry[idx_0] = src_page_entry[pos_0]; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_page_entry[idx_1] = src_page_entry[pos_1]; + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // prefill, ragged kv + void topk_transform_prefill_ragged_kernel( + const FastTopKParams params, + int32_t* __restrict__ topk_indices_ragged, + const int32_t* __restrict__ topk_indices_offset) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = row_starts == nullptr ? 
0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto dst_indices_entry = topk_indices_ragged + bid * TopK; + const auto score = input + bid * input_stride; + const auto offset = topk_indices_offset[bid]; + + if (length <= TopK) { + return naive_topk_transform_ragged(score, length, dst_indices_entry, offset); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_indices_entry[idx_0] = pos_0 + offset; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_indices_entry[idx_1] = pos_1 + offset; + } +} + +auto get_params( + const at::Tensor& score, + const at::Tensor& lengths, + std::optional row_starts_opt = std::nullopt, + std::optional indices_opt = std::nullopt) -> FastTopKParams { + const auto B = score.size(0); + TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1); + if (row_starts_opt.has_value()) { + const auto& row_starts = row_starts_opt.value(); + TORCH_CHECK(row_starts.dim() == 1); + TORCH_CHECK(row_starts.size(0) == B); + } + TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous()); + TORCH_CHECK(lengths.size(0) == B); + int32_t* indices_data_ptr = nullptr; + if (indices_opt.has_value()) { + const auto& indices = indices_opt.value(); + TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous()); + TORCH_CHECK(indices.size(0) == B); + TORCH_CHECK(indices.size(1) == TopK); + indices_data_ptr = indices.data_ptr(); + } + + return FastTopKParams{ + .input = score.data_ptr(), + .row_starts = row_starts_opt.has_value() ? 
row_starts_opt->data_ptr() : nullptr, + .indices = indices_data_ptr, + .lengths = lengths.data_ptr(), + .input_stride = score.stride(0), + }; +} + +template +void setup_kernel_smem_once() { + [[maybe_unused]] + static const auto result = [] { +#ifdef USE_ROCM + // hipify will turn cudaFuncSetAttribute -> hipFuncSetAttribute. On ROCm, + // hipFuncSetAttribute expects `const void*` and hipcc does not accept passing + // a function pointer directly, so cast explicitly. + return ::cudaFuncSetAttribute( + reinterpret_cast(f), ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); +#else + // CUDA: keep original behavior (no cast needed). + return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); +#endif + }(); + TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); +} + +} // namespace + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + +void fast_topk_interface( + const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths, std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(indices); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + CHECK_CUDA(lengths); + const auto params = get_params(score, lengths, row_starts_opt, indices); + const auto B = score.size(0); + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + setup_kernel_smem_once(); + topk_kernel<<>>(params); + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); +} + +void fast_topk_transform_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& dst_page_table, + const at::Tensor& src_page_table, + const at::Tensor& cu_seqlens_q, + std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(dst_page_table); 
+ CHECK_CUDA(src_page_table); + CHECK_CUDA(cu_seqlens_q); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous()); + TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1); + TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_q.is_contiguous()); + const auto prefill_bs = cu_seqlens_q.size(0) - 1; + TORCH_CHECK(dst_page_table.size(0) == B); + TORCH_CHECK(dst_page_table.size(1) == TopK); + TORCH_CHECK(src_page_table.size(0) == prefill_bs); + TORCH_CHECK(prefill_bs <= B); // prefill_bs should be smaller than expanded bs + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + const auto src_stride = src_page_table.stride(0); + + // dispatch to decode or prefill + // extend and draft extend: row_starts_opt is not null, invokes the prefill kernel + // decode: row_starts_opt is null, invokes the decode kernel + // target verify: row_starts_opt is null, invokes the prefill kernel + const auto is_decode = !row_starts_opt.has_value() && prefill_bs == B; + if (is_decode) { + setup_kernel_smem_once(); + topk_transform_decode_kernel<<>>( + params, dst_page_table.data_ptr(), src_page_table.data_ptr(), src_stride); + } else { + setup_kernel_smem_once(); + topk_transform_prefill_kernel<<>>( + params, + dst_page_table.data_ptr(), + src_page_table.data_ptr(), + src_stride, + cu_seqlens_q.data_ptr(), + prefill_bs); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); +} + +void fast_topk_transform_ragged_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& topk_indices_ragged, + const at::Tensor& topk_indices_offset, + std::optional row_starts_opt) 
{ + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(topk_indices_ragged); + CHECK_CUDA(topk_indices_offset); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(topk_indices_ragged.dim() == 2 && topk_indices_ragged.is_contiguous()); + TORCH_CHECK(topk_indices_offset.dim() == 1); + + TORCH_CHECK(topk_indices_ragged.size(0) == B); + TORCH_CHECK(topk_indices_ragged.size(1) == TopK); + TORCH_CHECK(topk_indices_offset.size(0) == B); + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + + setup_kernel_smem_once(); + topk_transform_prefill_ragged_kernel<<>>( + params, topk_indices_ragged.data_ptr(), topk_indices_offset.data_ptr()); + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); +} diff --git a/csrc/clean.py b/csrc/clean.py new file mode 100644 index 00000000..8d258bb0 --- /dev/null +++ b/csrc/clean.py @@ -0,0 +1,21 @@ +from pathlib import Path +import sys + +def clean_one_leading_space(path: str): + p = Path(path) + text = p.read_text(encoding="utf-8") + + cleaned = "".join( + line[1:] if line.startswith(" ") else line + for line in text.splitlines(keepends=True) + ) + + p.write_text(cleaned, encoding="utf-8") + print(f"Cleaned: {p}") + +if __name__ == "__main__": + if len(sys.argv) != 2: + print("Usage: python clean_indent.py ") + sys.exit(1) + + clean_one_leading_space(sys.argv[1]) \ No newline at end of file diff --git a/csrc/register.cc b/csrc/register.cc index fd9d4eb2..9771bbd9 100644 --- a/csrc/register.cc +++ b/csrc/register.cc @@ -8,6 +8,109 @@ PYBIND11_MODULE(vortex_torch_C, m){ m.def("Chunkwise_NH2HN_Transpose", &Chunkwise_NH2HN_Transpose); m.def("Chunkwise_HN2NH_Transpose", &Chunkwise_HN2NH_Transpose); 
m.def("topk_output", &topk_output); + m.def("topk_output_sglang", &topk_output_sglang, + py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), + py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), + py::arg("eff_batch_size"), py::arg("topk_val"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("max_num_pages")); + m.def("topk_output_sglang_ori", &topk_output_sglang_ori, + py::arg("x"), py::arg("dense_kv_indptr"), + py::arg("indices_out"), + py::arg("eff_batch_size"), py::arg("topk_val"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("max_num_pages")); + m.def("topk_output_sglang_fused", &topk_output_sglang_fused, + py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), + py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), + py::arg("eff_batch_size"), py::arg("topk_val"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("max_num_pages"), + py::arg("mapping_mode"), + py::arg("mapping_power"), + py::arg("mapping_lut") = py::none(), + py::arg("mapping_quantiles") = py::none()); + m.def("topk_output_adaptive", &topk_output_adaptive, + py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), + py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), + py::arg("eff_batch_size"), py::arg("topk_val"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("max_num_pages"), + py::arg("mapping_mode"), + py::arg("mapping_power")); + m.def("topk_output_adaptive_workspace", &topk_output_adaptive_workspace, + py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), + py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), + py::arg("partial_keys"), py::arg("partial_indices"), + py::arg("done_counter"), + py::arg("eff_batch_size"), py::arg("topk_val"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("max_num_pages"), + py::arg("mapping_mode"), + py::arg("mapping_power"), + py::arg("forced_splits") = -1, + py::arg("forced_partition") = -1, 
+ py::arg("local_mode") = 0); + m.def("topk_output_adaptive_workspace_midk", + &topk_output_adaptive_workspace_midk, + py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), + py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), + py::arg("partial_keys"), py::arg("partial_indices"), + py::arg("done_counter"), + py::arg("eff_batch_size"), py::arg("topk_val"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("max_num_pages"), + py::arg("mapping_mode"), + py::arg("mapping_power"), + py::arg("forced_splits") = -1); + m.def("topk_output_adaptive_workspace_ablation", + &topk_output_adaptive_workspace_ablation, + py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), + py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), + py::arg("partial_keys"), py::arg("partial_indices"), + py::arg("done_counter"), py::arg("scratch"), + py::arg("eff_batch_size"), py::arg("topk_val"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("max_num_pages"), + py::arg("ablation_mode"), + py::arg("forced_splits") = 8); + m.def("topk_adaptive_phase1_only", &topk_adaptive_phase1_only, + py::arg("x"), py::arg("dense_kv_indptr"), py::arg("dense_kv_indices"), + py::arg("partial_scores"), py::arg("partial_indices"), + py::arg("eff_batch_size"), py::arg("topk_val"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("max_num_pages")); + m.def("topk_adaptive_phase2_only", &topk_adaptive_phase2_only, + py::arg("partial_scores"), py::arg("partial_indices"), + py::arg("sparse_kv_indptr"), py::arg("sparse_kv_indices"), + py::arg("eff_batch_size"), py::arg("topk_val"), + py::arg("reserved_bos")); + m.def("topk_remap_only", &topk_remap_only, + py::arg("x"), py::arg("dense_kv_indptr"), + py::arg("remapped"), + py::arg("eff_batch_size"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("mapping_mode"), + py::arg("mapping_power")); + m.def("topk_profile_histogram", &topk_profile_histogram, + py::arg("x"), 
py::arg("dense_kv_indptr"), + py::arg("histograms"), py::arg("eff_batch_size"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("mapping_mode") = 0, + py::arg("mapping_power") = 0.5, + py::arg("mapping_lut") = py::none(), + py::arg("mapping_quantiles") = py::none()); + m.def("topk_profile_counters", &topk_profile_counters, + py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), + py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), + py::arg("counters"), + py::arg("eff_batch_size"), py::arg("topk_val"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("max_num_pages"), + py::arg("mapping_mode") = 0, + py::arg("mapping_power") = 0.5, + py::arg("mapping_lut") = py::none(), + py::arg("mapping_quantiles") = py::none()); m.def("sglang_plan_decode_fa3", &sglang_plan_decode_fa3); m.def("sglang_plan_prefill_fa3", &sglang_plan_prefill_fa3); m.def("Chunkwise_HN2NH_Transpose_FA3", &Chunkwise_HN2NH_Transpose_FA3); diff --git a/csrc/register.h b/csrc/register.h index 92499ed6..b565c3ef 100644 --- a/csrc/register.h +++ b/csrc/register.h @@ -85,6 +85,185 @@ const int64_t reserved_eos, const int64_t max_seq_lengths ); +void topk_output_sglang( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +const at::Tensor& sparse_kv_indptr, +const at::Tensor& dense_kv_indices, +at::Tensor& sparse_kv_indices, +const int64_t eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_seq_lengths +); + +void topk_output_sglang_ori( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +at::Tensor& indices_out, +const int64_t eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_num_pages +); + +void topk_output_sglang_fused( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +const at::Tensor& sparse_kv_indptr, +const at::Tensor& dense_kv_indices, +at::Tensor& sparse_kv_indices, +const int64_t 
eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_num_pages, +const int64_t mapping_mode, +const double mapping_power, +std::optional mapping_lut = std::nullopt, +std::optional mapping_quantiles = std::nullopt +); + +void topk_output_adaptive( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +const at::Tensor& sparse_kv_indptr, +const at::Tensor& dense_kv_indices, +at::Tensor& sparse_kv_indices, +const int64_t eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_num_pages, +const int64_t mapping_mode, +const double mapping_power +); + +void topk_output_adaptive_workspace( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +const at::Tensor& sparse_kv_indptr, +const at::Tensor& dense_kv_indices, +at::Tensor& sparse_kv_indices, +at::Tensor& partial_keys, +at::Tensor& partial_indices, +at::Tensor& done_counter, +const int64_t eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_num_pages, +const int64_t mapping_mode, +const double mapping_power, +const int64_t forced_splits, +const int64_t forced_partition, +const int64_t local_mode +); + +void topk_output_adaptive_workspace_midk( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +const at::Tensor& sparse_kv_indptr, +const at::Tensor& dense_kv_indices, +at::Tensor& sparse_kv_indices, +at::Tensor& partial_keys, +at::Tensor& partial_indices, +at::Tensor& done_counter, +const int64_t eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_num_pages, +const int64_t mapping_mode, +const double mapping_power, +const int64_t forced_splits +); + +void topk_output_adaptive_workspace_ablation( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +const at::Tensor& sparse_kv_indptr, +const at::Tensor& dense_kv_indices, +at::Tensor& 
sparse_kv_indices, +at::Tensor& partial_keys, +at::Tensor& partial_indices, +at::Tensor& done_counter, +at::Tensor& scratch, +const int64_t eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_num_pages, +const int64_t ablation_mode, +const int64_t forced_splits +); + +void topk_adaptive_phase1_only( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +const at::Tensor& dense_kv_indices, +at::Tensor& partial_scores, +at::Tensor& partial_indices, +const int64_t eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_num_pages +); + +void topk_adaptive_phase2_only( +const at::Tensor& partial_scores, +const at::Tensor& partial_indices, +const at::Tensor& sparse_kv_indptr, +at::Tensor& sparse_kv_indices, +const int64_t eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos +); + +void topk_remap_only( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +at::Tensor& remapped, +const int64_t eff_batch_size, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t mapping_mode, +const double mapping_power +); + +void topk_profile_histogram( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +at::Tensor& histograms, +const int64_t eff_batch_size, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t mapping_mode = 0, +const double mapping_power = 0.5, +std::optional mapping_lut = std::nullopt, +std::optional mapping_quantiles = std::nullopt +); + +void topk_profile_counters( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +const at::Tensor& sparse_kv_indptr, +const at::Tensor& dense_kv_indices, +at::Tensor& sparse_kv_indices, +at::Tensor& counters, +const int64_t eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_num_pages, +const int64_t mapping_mode = 0, +const double mapping_power = 0.5, +std::optional 
mapping_lut = std::nullopt, +std::optional mapping_quantiles = std::nullopt +); void sglang_plan_decode_fa3( const at::Tensor& cached_seq_lens, diff --git a/csrc/topk.cu b/csrc/topk.cu index 62d747eb..081bddf4 100644 --- a/csrc/topk.cu +++ b/csrc/topk.cu @@ -117,8 +117,8 @@ const int page_reserved_eos) void topk_output( const at::Tensor& x, const at::Tensor& dense_kv_indptr, -const at::Tensor& sparse_kv_indptr, const at::Tensor& dense_kv_indices, +const at::Tensor& sparse_kv_indptr, at::Tensor& sparse_kv_indices, const int64_t eff_batch_size, const int64_t topk_val, @@ -196,8 +196,20 @@ const int64_t max_num_pages reserved_bos, reserved_eos ); + } else if (max_num_pages <= 8192){ + TopKOutput_BF16_Kernel<512, 16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos + ); } else { - TORCH_CHECK(false); + TORCH_CHECK(false, "topk_output: max_num_pages=", max_num_pages, + " exceeds the supported template ladder (8192)."); } -} +} \ No newline at end of file diff --git a/csrc/topk_adaptive_profile.cu b/csrc/topk_adaptive_profile.cu new file mode 100644 index 00000000..fc706da9 --- /dev/null +++ b/csrc/topk_adaptive_profile.cu @@ -0,0 +1,1145 @@ +/** + * Profile-only fixtures for the adaptive split TopK. Two distinct fixtures: + * + * [1] LEGACY split-2 histogram fixture (TopK_Phase1_Only_Kernel, + * TopK_Phase2_Only_Kernel, topk_adaptive_phase1_only, + * topk_adaptive_phase2_only). + * Implements an 8-bit coarse histogram + 8-bit refinement on bits [23:16] + * with a fixed split count of 2 (kNumSplits=2, kThreads=1024). + * Kept for historical comparison only; this is NOT the K=30 production + * path and is NOT representative of the current split kernel in + * topk_sglang_merge.cu. barrier = full_adaptive - (phase1 + phase2). 
+ * + * [2] K=30 ablation fixture (Ablation_*_Kernel, + * topk_output_adaptive_workspace_ablation). + * Measures isolated costs: local sort, workspace write, atomic/fence, + * merge-only (multiple variants), memset-only, and the full adaptive path. + * Split configs exactly match production (kAblCfg1..kAblCfg32). + * ScoreT is bf16 only; mode is hardcoded to MAPPING_NONE. + * Production kernel lives in topk_sglang_merge.cu. + * Current production merge = MERGE_CUB_WARP (kAblMode_MergeCubWarp = 6). + * MERGE_PROD_DEFAULT (mode 5) is the legacy per-SPLITS dispatch kept for + * ablation comparison (not the current production merge). + * + * ablation_mode constants and merge variant constants are defined below in + * the K=30 ablation namespace section. + */ + + #include + #include + #include + #include + #include + #include + #include + #include + #include + + #include + #include + #include + + #include + #include + + #include "register.h" + + namespace { + + constexpr int kRadix = 256; + constexpr int kThreads = 1024; + constexpr int kWarpSize = 32; + constexpr int kNumSplits = 2; + constexpr size_t kMaxDynSmem = 96 * 1024; + + __device__ __forceinline__ uint32_t convert_to_uint32(float x) { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); + } + __device__ __forceinline__ uint8_t convert_to_uint8(float x) { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? 
static_cast(~bits) + : static_cast(bits | 0x8000); + return static_cast(key >> 8); + } + + __device__ __forceinline__ void run_cumsum_256(int s_hist[2][kRadix + 128]) { + const int tx = threadIdx.x; + #pragma unroll 8 + for (int i = 0; i < 8; ++i) { + if (tx < kRadix) { + const int j = 1 << i; + const int k = i & 1; + int v = s_hist[k][tx]; + if (tx < kRadix - j) v += s_hist[k][tx + j]; + s_hist[k ^ 1][tx] = v; + } + __syncthreads(); + } + } + + __device__ __forceinline__ int warp_compact_slot(bool selected, int* s_counter) { + const uint32_t mask = __activemask(); + const uint32_t ballot = __ballot_sync(mask, selected); + const int lane = threadIdx.x & (kWarpSize - 1); + const int warp_count = __popc(ballot); + const int rank = __popc(ballot & ((1u << lane) - 1u)); + const int first = __ffs(mask) - 1; + int base = 0; + if (lane == first) { + base = (warp_count > 0) ? ::atomicAdd(s_counter, warp_count) : 0; + } + base = __shfl_sync(mask, base, first); + return selected ? (base + rank) : -1; + } + + // ============================================================================ + // Phase 1 ONLY: per-chunk radix select, writes unordered (score, idx) pairs + // into partial_scores/partial_indices. No barrier, no merge. 
+ // ============================================================================ + __global__ __launch_bounds__(kThreads) + void TopK_Phase1_Only_Kernel( + const __nv_bfloat16* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ dense_kv_indices, + float* __restrict__ partial_scores, + int32_t* __restrict__ partial_indices, + const int topk_val, + const int reserved_bos, + const int reserved_eos) + { + const int b = blockIdx.x; + const int n = blockIdx.y; + const int tx = threadIdx.x; + + const int row_start = dense_kv_indptr[b] + reserved_bos; + const int row_end = dense_kv_indptr[b + 1] - reserved_eos; + const int row_len = row_end - row_start; + const int half = (row_len + 1) / 2; + const int ck_begin = (n == 0) ? 0 : half; + const int ck_end = (n == 0) ? half : row_len; + const int ck_len = ck_end - ck_begin; + + const __nv_bfloat16* chunk_in = score + row_start + ck_begin; + const int* idx_map = dense_kv_indices + row_start + ck_begin; + float* part_keys = partial_scores + (static_cast(b) * kNumSplits + n) * topk_val; + int32_t* part_idx = partial_indices + (static_cast(b) * kNumSplits + n) * topk_val; + + extern __shared__ char smem_raw[]; + alignas(128) __shared__ int s_hist_buf[2][kRadix + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin; + alignas(128) __shared__ int s_sub_threshold_bin; + alignas(128) __shared__ int s_last_remain; + auto& s_hist = s_hist_buf[0]; + + float* s_remapped = reinterpret_cast(smem_raw); + uint8_t* s_bins = reinterpret_cast(s_remapped + ck_len); + + if (ck_len <= topk_val) { + for (int i = tx; i < topk_val; i += blockDim.x) { + if (i < ck_len) { + part_keys[i] = __bfloat162float(chunk_in[i]); + part_idx [i] = idx_map[i]; + } else { + part_keys[i] = -CUDART_INF_F; + part_idx [i] = -1; + } + } + return; + } + + if (tx < kRadix + 1) s_hist[tx] = 0; + if (tx == 0) { + s_counter = 0; + s_threshold_bin = -1; + s_sub_threshold_bin = -1; + s_last_remain = 
0; + } + __syncthreads(); + + for (int i = tx; i < ck_len; i += blockDim.x) { + const float v = __bfloat162float(chunk_in[i]); + const uint8_t bin = convert_to_uint8(v); + s_remapped[i] = v; + s_bins [i] = bin; + ::atomicAdd(&s_hist[bin], 1); + } + __syncthreads(); + run_cumsum_256(s_hist_buf); + + if (tx < kRadix && s_hist[tx] > topk_val && s_hist[tx + 1] <= topk_val) { + s_threshold_bin = tx; + s_last_remain = topk_val - s_hist[tx + 1]; + } + __syncthreads(); + const int threshold_bin = s_threshold_bin; + + if (tx < kRadix + 1) s_hist[tx] = 0; + __syncthreads(); + + const int num_iters = (ck_len + blockDim.x - 1) / blockDim.x; + for (int it = 0; it < num_iters; ++it) { + const int i = it * blockDim.x + tx; + const bool in_range = (i < ck_len); + int bin = -1; + if (in_range) bin = s_bins[i]; + const bool take_above = in_range && (bin > threshold_bin); + const int slot = warp_compact_slot(take_above, &s_counter); + if (take_above) { + part_keys[slot] = s_remapped[i]; + part_idx [slot] = idx_map[i]; + } else if (in_range && bin == threshold_bin) { + const uint32_t b32 = convert_to_uint32(s_remapped[i]); + const int sub_bin = (b32 >> 16) & 0xFF; + ::atomicAdd(&s_hist[sub_bin], 1); + } + } + __syncthreads(); + run_cumsum_256(s_hist_buf); + if (tx < kRadix && s_hist[tx] > s_last_remain + && s_hist[tx + 1] <= s_last_remain) { + s_sub_threshold_bin = tx; + s_last_remain = s_last_remain - s_hist[tx + 1]; + } + if (tx == 0 && s_sub_threshold_bin == -1) s_sub_threshold_bin = kRadix; + __syncthreads(); + const int sub_threshold_bin = s_sub_threshold_bin; + + for (int it = 0; it < num_iters; ++it) { + const int i = it * blockDim.x + tx; + const bool in_range = (i < ck_len); + int bin = -1; + if (in_range) bin = s_bins[i]; + int sub_bin = -1; + if (in_range && bin == threshold_bin) { + const uint32_t b32 = convert_to_uint32(s_remapped[i]); + sub_bin = (b32 >> 16) & 0xFF; + } + const bool take_sub = (sub_bin > sub_threshold_bin); + const int slot = warp_compact_slot(take_sub, 
&s_counter); + if (take_sub) { + part_keys[slot] = s_remapped[i]; + part_idx [slot] = idx_map[i]; + } else if (sub_bin == sub_threshold_bin) { + const int pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + part_keys[topk_val - pos] = s_remapped[i]; + part_idx [topk_val - pos] = idx_map[i]; + } + } + } + } + + // ============================================================================ + // Phase 2 ONLY: read pre-populated (kNumSplits * K) candidates, radix-select + // top-K, write final indices. Grid = (batch,). + // ============================================================================ + __global__ __launch_bounds__(kThreads) + void TopK_Phase2_Only_Kernel( + const float* __restrict__ partial_scores, + const int32_t* __restrict__ partial_indices, + const int* __restrict__ sparse_kv_indptr, + int32_t* __restrict__ sparse_kv_indices, + const int topk_val, + const int reserved_bos) + { + const int b = blockIdx.x; + const int tx = threadIdx.x; + const int total = kNumSplits * topk_val; + + const float* keys_in = partial_scores + static_cast(b) * total; + const int32_t* idx_in = partial_indices + static_cast(b) * total; + int32_t* out_idx = sparse_kv_indices + sparse_kv_indptr[b] + reserved_bos; + + extern __shared__ char smem_raw[]; + alignas(128) __shared__ int s_hist_buf[2][kRadix + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin; + alignas(128) __shared__ int s_sub_threshold_bin; + alignas(128) __shared__ int s_last_remain; + auto& s_hist = s_hist_buf[0]; + + float* s_scores = reinterpret_cast(smem_raw); + int32_t* s_indices = reinterpret_cast(s_scores + total); + + if ((total & 3) == 0) { + const float4* kv = reinterpret_cast(keys_in); + const int4* iv = reinterpret_cast (idx_in); + float4* sv = reinterpret_cast(s_scores); + int4* iiv = reinterpret_cast (s_indices); + const int total4 = total >> 2; + for (int i = tx; i < total4; i += blockDim.x) { + sv[i] = kv[i]; + iiv[i] = iv[i]; + } + } else { + for 
(int i = tx; i < total; i += blockDim.x) { + s_scores[i] = keys_in[i]; + s_indices[i] = idx_in[i]; + } + } + + if (tx < kRadix + 1) s_hist[tx] = 0; + if (tx == 0) { + s_counter = 0; + s_threshold_bin = -1; + s_sub_threshold_bin = -1; + s_last_remain = 0; + } + __syncthreads(); + + const int num_iters = (total + blockDim.x - 1) / blockDim.x; + for (int it = 0; it < num_iters; ++it) { + const int i = it * blockDim.x + tx; + if (i < total && s_indices[i] >= 0) { + const uint32_t b32 = convert_to_uint32(s_scores[i]); + const int bin = (b32 >> 24) & 0xFF; + ::atomicAdd(&s_hist[bin], 1); + } + } + __syncthreads(); + run_cumsum_256(s_hist_buf); + + const int valid_count = s_hist[0]; + if (valid_count <= topk_val) { + for (int it = 0; it < num_iters; ++it) { + const int i = it * blockDim.x + tx; + const bool take = (i < total) && (s_indices[i] >= 0); + const int slot = warp_compact_slot(take, &s_counter); + if (take && slot < topk_val) out_idx[slot] = s_indices[i]; + } + return; + } + + if (tx < kRadix && s_hist[tx] > topk_val && s_hist[tx + 1] <= topk_val) { + s_threshold_bin = tx; + s_last_remain = topk_val - s_hist[tx + 1]; + } + __syncthreads(); + const int threshold_bin_m = s_threshold_bin; + + if (tx < kRadix + 1) s_hist[tx] = 0; + __syncthreads(); + + for (int it = 0; it < num_iters; ++it) { + const int i = it * blockDim.x + tx; + bool in_valid = false; + int bin = -1; + uint32_t b32 = 0; + if (i < total) { + const int32_t ii = s_indices[i]; + if (ii >= 0) { + in_valid = true; + b32 = convert_to_uint32(s_scores[i]); + bin = (b32 >> 24) & 0xFF; + } + } + const bool take_above = in_valid && (bin > threshold_bin_m); + const int slot = warp_compact_slot(take_above, &s_counter); + if (take_above) { + out_idx[slot] = s_indices[i]; + } else if (in_valid && bin == threshold_bin_m) { + const int sub_bin = (b32 >> 16) & 0xFF; + ::atomicAdd(&s_hist[sub_bin], 1); + } + } + __syncthreads(); + + run_cumsum_256(s_hist_buf); + if (tx < kRadix && s_hist[tx] > s_last_remain + && 
s_hist[tx + 1] <= s_last_remain) { + s_sub_threshold_bin = tx; + s_last_remain = s_last_remain - s_hist[tx + 1]; + } + if (tx == 0 && s_sub_threshold_bin == -1) s_sub_threshold_bin = kRadix; + __syncthreads(); + const int sub_threshold_bin_m = s_sub_threshold_bin; + + for (int it = 0; it < num_iters; ++it) { + const int i = it * blockDim.x + tx; + bool in_thr = false; + int sub_bin = -1; + if (i < total) { + const int32_t ii = s_indices[i]; + if (ii >= 0) { + const uint32_t b32 = convert_to_uint32(s_scores[i]); + const int bin = (b32 >> 24) & 0xFF; + if (bin == threshold_bin_m) { + in_thr = true; + sub_bin = (b32 >> 16) & 0xFF; + } + } + } + const bool take = in_thr && (sub_bin > sub_threshold_bin_m); + const int slot = warp_compact_slot(take, &s_counter); + if (take) { + out_idx[slot] = s_indices[i]; + } else if (in_thr && sub_bin == sub_threshold_bin_m) { + const int pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) out_idx[topk_val - pos] = s_indices[i]; + } + } + } + + template + void setup_smem_once() { + [[maybe_unused]] static const auto r = [] { + return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dyn); + }(); + TORCH_CHECK(r == cudaSuccess, "profile kernel setup failed: ", + ::cudaGetErrorString(r)); + } + + } // namespace + + #define CHECK_CUDA_T(x) TORCH_CHECK(x.is_cuda(), #x " must be CUDA") + + // ============================================================================ + // Host entry — Phase 1 only. 
+ // ============================================================================ + void topk_adaptive_phase1_only( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& partial_scores, + at::Tensor& partial_indices, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages) + { + CHECK_CUDA_T(x); CHECK_CUDA_T(dense_kv_indptr); + CHECK_CUDA_T(dense_kv_indices); CHECK_CUDA_T(partial_scores); CHECK_CUDA_T(partial_indices); + TORCH_CHECK(x.scalar_type() == at::ScalarType::BFloat16, + "profile kernels require bfloat16 input"); + + const int chunk_max = (static_cast(max_num_pages) + 1) / 2; + const size_t smem = static_cast(chunk_max) * sizeof(float) + + ((static_cast(chunk_max) + 15) & ~size_t(15)); + TORCH_CHECK(smem <= kMaxDynSmem, "phase1 smem too large"); + + setup_smem_once(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + dim3 grid(static_cast(eff_batch_size), + static_cast(kNumSplits)); + TopK_Phase1_Only_Kernel<<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + partial_scores.data_ptr(), + partial_indices.data_ptr(), + static_cast(topk_val), + static_cast(reserved_bos), + static_cast(reserved_eos)); + TORCH_CHECK(cudaGetLastError() == cudaSuccess, "phase1 launch failed"); + } + + // ============================================================================ + // Host entry — Phase 2 only (expects partial_* pre-populated). 
+ // ============================================================================ + void topk_adaptive_phase2_only( + const at::Tensor& partial_scores, + const at::Tensor& partial_indices, + const at::Tensor& sparse_kv_indptr, + at::Tensor& sparse_kv_indices, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos) + { + CHECK_CUDA_T(partial_scores); CHECK_CUDA_T(partial_indices); + CHECK_CUDA_T(sparse_kv_indptr); CHECK_CUDA_T(sparse_kv_indices); + + const size_t smem = static_cast(kNumSplits) * + static_cast(topk_val) * + (sizeof(float) + sizeof(int32_t)); + TORCH_CHECK(smem <= kMaxDynSmem, "phase2 smem too large"); + + setup_smem_once(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + dim3 grid(static_cast(eff_batch_size)); + TopK_Phase2_Only_Kernel<<>>( + partial_scores.data_ptr(), + partial_indices.data_ptr(), + sparse_kv_indptr.data_ptr(), + sparse_kv_indices.data_ptr(), + static_cast(topk_val), + static_cast(reserved_bos)); + TORCH_CHECK(cudaGetLastError() == cudaSuccess, "phase2 launch failed"); + } + + // ============================================================================= + // K=30 phase ablation kernels and host entry. Bench-only fixture for + // `bench_ablation.py`. NOT a production code path. The production K=30 + // random-split parallel kernel and dispatcher live in topk_sglang_merge.cu. + // + // All kernels here are hardcoded to ScoreT=bf16, MAPPING_NONE, partition + // = CONTIGUOUS to keep the template instantiation count small. They share + // no code with the production path beyond the function declarations in + // register.h; helpers are duplicated below in the anonymous namespace. 
+ // ============================================================================= + namespace { + + constexpr int kLocalK_Top30 = 32; + constexpr int kMaxFinalK_Top30 = 32; + constexpr int kPartContiguous = 1; // PART_CONTIGUOUS in topk_sglang_merge.cu + + // Per-split (NUM_THREADS, ITEMS_PER_THREAD) — must match the production + // configurations in topk_sglang_merge.cu (kCfg1..kCfg32) so the ablation + // numbers reflect the production launch parameters. + struct AblSplitCfg { int num_threads, items_per_thread; }; + constexpr AblSplitCfg kAblCfg1 = { 1024, 8 }; + constexpr AblSplitCfg kAblCfg2 = { 1024, 8 }; + constexpr AblSplitCfg kAblCfg4 = { 512, 8 }; + constexpr AblSplitCfg kAblCfg8 = { 256, 16 }; + constexpr AblSplitCfg kAblCfg16 = { 128, 16 }; + constexpr AblSplitCfg kAblCfg32 = { 64, 16 }; + + // --------------------------------------------------------------------------- + // Ablation mode constants (ablation_mode argument to + // topk_output_adaptive_workspace_ablation). + // --------------------------------------------------------------------------- + constexpr int kAblMode_FullAdaptive = 0; // full production path (reference) + constexpr int kAblMode_LocalWithWorkspace = 1; // local sort + workspace write, no merge + constexpr int kAblMode_LocalNoWorkspace = 2; // local sort only, no write, no merge + constexpr int kAblMode_WorkspaceWriteOnly = 3; // synthetic write to workspace + constexpr int kAblMode_AtomicOnly = 4; // atomic counter cost only + constexpr int kAblMode_MergeProdDefault = 5; // merge: legacy per-SPLITS dispatch + // (2-way for SPLITS=2, pairwise for SPLITS=4, + // k-way for SPLITS>=8). NOT current production. 
+ constexpr int kAblMode_MergeCubWarp = 6; // merge: cub::WarpMergeSort — current production + constexpr int kAblMode_MergeCubBlock = 7; // merge: cub::BlockMergeSort benchmark + constexpr int kAblMode_MemsetOnly = 8; // counter memset cost only + constexpr int kAblMode_MergeManual2Way = 9; // merge: manual 2-way (requires split=2) + constexpr int kAblMode_MergePairwise4 = 10; // merge: pairwise tree (requires split=4) + constexpr int kAblMode_MergeKwayAll = 11; // merge: force k-way for all split counts + + // --------------------------------------------------------------------------- + // Merge variant indices for Ablation_MergeOnly_Kernel. + // --------------------------------------------------------------------------- + constexpr int MERGE_PROD_DEFAULT = 0; // legacy: 2-way(SPLITS=2)/pairwise(SPLITS=4)/k-way(>=8) + constexpr int MERGE_CUB_WARP = 1; // cub::WarpMergeSort — matches current production merge + // kMergeIPT=SPLITS; register pressure grows with SPLITS + constexpr int MERGE_CUB_BLOCK = 2; // cub::BlockMergeSort (benchmark; 64 threads) + constexpr int MERGE_MANUAL_2WAY = 3; // manual 2-way merge (requires SPLITS=2) + constexpr int MERGE_PAIRWISE_4 = 4; // pairwise tree (requires SPLITS=4) + constexpr int MERGE_KWAY = 5; // force k-way for all SPLITS (explicit baseline) + + template + __device__ __forceinline__ float vortex_to_float_p(T x); + template <> + __device__ __forceinline__ float vortex_to_float_p(float x) { return x; } + template <> + __device__ __forceinline__ float vortex_to_float_p<__nv_bfloat16>(__nv_bfloat16 x) { + return __bfloat162float(x); + } + + struct AblGreaterUint32 { + __device__ __forceinline__ bool operator()(uint32_t a, uint32_t b) const { + return a > b; + } + }; + + // k-way merge — same algorithm as topk_sglang_merge.cu's merge_sorted_kway. 
+ template + __device__ __forceinline__ void abl_merge_sorted_kway( + const uint32_t* __restrict__ keys_in, + const int32_t* __restrict__ idx_in, + int32_t* __restrict__ out_idx, + int final_k) + { + const int lane = threadIdx.x & 31; + const bool is_my_list = (lane < SPLITS); + const uint32_t full = 0xFFFFFFFFu; + + int ptr = 0; + uint32_t cur_key = is_my_list ? keys_in[lane * LOCAL_K] : 0u; + int32_t cur_idx = is_my_list ? idx_in [lane * LOCAL_K] : -1; + + #pragma unroll + for (int t = 0; t < MAX_FINAL_K; ++t) { + uint32_t best_key = cur_key; + int best_lane = lane; + #pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + uint32_t okey = __shfl_xor_sync(full, best_key, offset); + int olane = __shfl_xor_sync(full, best_lane, offset); + bool take = (okey > best_key) || (okey == best_key && olane < best_lane); + best_key = take ? okey : best_key; + best_lane = take ? olane : best_lane; + } + int32_t win_idx = __shfl_sync(full, cur_idx, best_lane); + if (lane == 0 && t < final_k && win_idx >= 0) out_idx[t] = win_idx; + if (lane == best_lane) { + ++ptr; + if (is_my_list && ptr < LOCAL_K) { + cur_key = keys_in[lane * LOCAL_K + ptr]; + cur_idx = idx_in [lane * LOCAL_K + ptr]; + } else { + cur_key = 0u; + cur_idx = -1; + } + } + } + } + + __device__ __forceinline__ void abl_merge_2way_manual_lane0( + const uint32_t* __restrict__ l0_keys, const int32_t* __restrict__ l0_idx, int n0, + const uint32_t* __restrict__ l1_keys, const int32_t* __restrict__ l1_idx, int n1, + int32_t* __restrict__ out_idx, + int final_k) + { + if (threadIdx.x != 0) return; + int p0 = 0, p1 = 0; + for (int t = 0; t < final_k; ++t) { + const uint32_t k0 = (p0 < n0) ? l0_keys[p0] : 0u; + const uint32_t k1 = (p1 < n1) ? 
l1_keys[p1] : 0u; + if (k0 >= k1 && p0 < n0) { + out_idx[t] = l0_idx[p0]; ++p0; + } else if (p1 < n1) { + out_idx[t] = l1_idx[p1]; ++p1; + } else { + out_idx[t] = -1; + } + } + } + + __device__ __forceinline__ void abl_merge_pairwise_4( + const uint32_t* __restrict__ keys_in, + const int32_t* __restrict__ idx_in, + int32_t* __restrict__ out_idx, int final_k, + uint32_t* __restrict__ tmp01_keys, int32_t* __restrict__ tmp01_idx, + uint32_t* __restrict__ tmp23_keys, int32_t* __restrict__ tmp23_idx) + { + constexpr int LK = kLocalK_Top30; + const int lane = threadIdx.x & 31; + if (lane == 0) { + int p0 = 0, p1 = 0; + #pragma unroll + for (int t = 0; t < LK; ++t) { + const uint32_t k0 = (p0 < LK) ? keys_in[0 * LK + p0] : 0u; + const uint32_t k1 = (p1 < LK) ? keys_in[1 * LK + p1] : 0u; + if (k0 >= k1 && p0 < LK) { tmp01_keys[t] = k0; tmp01_idx[t] = idx_in[0*LK+p0]; ++p0; } + else if (p1 < LK) { tmp01_keys[t] = k1; tmp01_idx[t] = idx_in[1*LK+p1]; ++p1; } + else { tmp01_keys[t] = 0u; tmp01_idx[t] = -1; } + } + } else if (lane == 1) { + int p2 = 0, p3 = 0; + #pragma unroll + for (int t = 0; t < LK; ++t) { + const uint32_t k2 = (p2 < LK) ? keys_in[2 * LK + p2] : 0u; + const uint32_t k3 = (p3 < LK) ? keys_in[3 * LK + p3] : 0u; + if (k2 >= k3 && p2 < LK) { tmp23_keys[t] = k2; tmp23_idx[t] = idx_in[2*LK+p2]; ++p2; } + else if (p3 < LK) { tmp23_keys[t] = k3; tmp23_idx[t] = idx_in[3*LK+p3]; ++p3; } + else { tmp23_keys[t] = 0u; tmp23_idx[t] = -1; } + } + } + __syncwarp(); + if (lane == 0) { + int p0 = 0, p1 = 0; + for (int t = 0; t < final_k; ++t) { + const uint32_t k0 = (p0 < LK) ? tmp01_keys[p0] : 0u; + const uint32_t k1 = (p1 < LK) ? tmp23_keys[p1] : 0u; + if (k0 >= k1 && p0 < LK) { out_idx[t] = tmp01_idx[p0]; ++p0; } + else if (p1 < LK) { out_idx[t] = tmp23_idx[p1]; ++p1; } + else { out_idx[t] = -1; } + } + } + } + + // uint32 sort key for an fp32 value. 
+ __device__ __forceinline__ uint32_t abl_to_uint32(float x) { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); + } + + // ---- Ablation kernels -------------------------------------------------------- + + template + __global__ __launch_bounds__(NUM_THREADS) + void Ablation_LocalOnly_Kernel( + const __nv_bfloat16* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ dense_kv_indices, + uint32_t* __restrict__ partial_keys, + int32_t* __restrict__ partial_indices, + const int reserved_bos, + const int reserved_eos) + { + // Stage 1 + workspace write. No atomic. No merge. + using KeyT = uint32_t; using ValueT = int32_t; + using BlockSortT = cub::BlockRadixSort; + __shared__ typename BlockSortT::TempStorage sort_smem; + const int b = blockIdx.x; + const int n = blockIdx.y; + const int tx = threadIdx.x; + const int row_start = dense_kv_indptr[b] + reserved_bos; + const int row_end = dense_kv_indptr[b + 1] - reserved_eos; + const int row_len = max(0, row_end - row_start); + if (row_len <= 0) return; + const int group_begin = (row_len * n) / SPLITS; + const int group_end = (row_len * (n + 1)) / SPLITS; + const int group_len = group_end - group_begin; + const __nv_bfloat16* row_scores = score + row_start; + const int* row_idxmap = dense_kv_indices + row_start; + KeyT keys[ITEMS_PER_THREAD]; ValueT values[ITEMS_PER_THREAD]; + #pragma unroll + for (int k = 0; k < ITEMS_PER_THREAD; ++k) { + const int local_rank = tx + k * NUM_THREADS; + if (local_rank < group_len) { + const int pos = group_begin + local_rank; + const float raw = vortex_to_float_p(row_scores[pos]); + keys [k] = abl_to_uint32(raw); + values[k] = row_idxmap[pos]; + } else { keys[k] = 0u; values[k] = -1; } + } + BlockSortT(sort_smem).SortDescending(keys, values); + __syncthreads(); + constexpr int LK = kLocalK_Top30; + const int64_t part_off = (static_cast(b) * SPLITS + n) * LK; + uint32_t* part_keys = partial_keys + part_off; 
+ int32_t* part_idx = partial_indices + part_off; + #pragma unroll + for (int k = 0; k < ITEMS_PER_THREAD; ++k) { + const int rank = tx * ITEMS_PER_THREAD + k; + if (rank < LK) { part_keys[rank] = keys[k]; part_idx[rank] = values[k]; } + } + } + + template + __global__ __launch_bounds__(NUM_THREADS) + void Ablation_LocalNoWorkspace_Kernel( + const __nv_bfloat16* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ dense_kv_indices, + int32_t* __restrict__ scratch, + const int reserved_bos, + const int reserved_eos) + { + using KeyT = uint32_t; using ValueT = int32_t; + using BlockSortT = cub::BlockRadixSort; + __shared__ typename BlockSortT::TempStorage sort_smem; + const int b = blockIdx.x; + const int n = blockIdx.y; + const int tx = threadIdx.x; + const int row_start = dense_kv_indptr[b] + reserved_bos; + const int row_end = dense_kv_indptr[b + 1] - reserved_eos; + const int row_len = max(0, row_end - row_start); + if (row_len <= 0) return; + const int group_begin = (row_len * n) / SPLITS; + const int group_end = (row_len * (n + 1)) / SPLITS; + const int group_len = group_end - group_begin; + const __nv_bfloat16* row_scores = score + row_start; + const int* row_idxmap = dense_kv_indices + row_start; + KeyT keys[ITEMS_PER_THREAD]; ValueT values[ITEMS_PER_THREAD]; + #pragma unroll + for (int k = 0; k < ITEMS_PER_THREAD; ++k) { + const int local_rank = tx + k * NUM_THREADS; + if (local_rank < group_len) { + const int pos = group_begin + local_rank; + const float raw = vortex_to_float_p(row_scores[pos]); + keys [k] = abl_to_uint32(raw); + values[k] = row_idxmap[pos]; + } else { keys[k] = 0u; values[k] = -1; } + } + BlockSortT(sort_smem).SortDescending(keys, values); + if (tx == 0) scratch[blockIdx.x * gridDim.y + blockIdx.y] = values[0]; + } + + template + __global__ __launch_bounds__(32) + void Ablation_WorkspaceWriteOnly_Kernel( + uint32_t* __restrict__ partial_keys, + int32_t* __restrict__ partial_indices) + { + constexpr int 
LK = kLocalK_Top30; + const int b = blockIdx.x; + const int n = blockIdx.y; + const int lane = threadIdx.x; + const int64_t part_off = (static_cast(b) * SPLITS + n) * LK; + if (lane < LK) { + partial_keys [part_off + lane] = static_cast(b * 31 + n * 7 + lane); + partial_indices[part_off + lane] = b * 1009 + n * 17 + lane; + } + } + + template + __global__ __launch_bounds__(32) + void Ablation_AtomicOnly_Kernel( + int32_t* __restrict__ done_counter, + int32_t* __restrict__ scratch) + { + const int b = blockIdx.x; + const int tx = threadIdx.x; + __shared__ int s_is_last; + __threadfence(); + __syncthreads(); + if (tx == 0) { + const int old = ::atomicAdd(&done_counter[b], 1); + s_is_last = (old == SPLITS - 1) ? 1 : 0; + } + __syncthreads(); + if (s_is_last && tx == 0) scratch[b] = 1; + } + + // Correctness notes for Ablation_MergeOnly_Kernel: + // - MERGE_PROD_DEFAULT exactly mirrors topk_sglang_merge.cu Stage 2 tie + // preference: lower list index wins on equal key (k-way and pairwise); + // lane 0 favors list 0 on equal key (2-way manual). + // - CUB variants sort by uint32 key only. Tie-breaking for duplicate keys + // is implementation-defined and will NOT match production index order. + // Use unique keys for exact index comparison in correctness tests. + // - For throughput benchmarking, duplicate keys are acceptable since only + // latency is measured. 
+ template + __global__ __launch_bounds__(64) + void Ablation_MergeOnly_Kernel( + const uint32_t* __restrict__ partial_keys, + const int32_t* __restrict__ partial_indices, + const int* __restrict__ sparse_kv_indptr, + int* __restrict__ sparse_kv_indices, + const int topk_val, + const int reserved_bos) + { + constexpr int LK = kLocalK_Top30; + constexpr int kCandidates = SPLITS * LK; + const int b = blockIdx.x; + const int tx = threadIdx.x; + const int64_t row_off = static_cast(b) * kCandidates; + const uint32_t* keys_in = partial_keys + row_off; + const int32_t* idx_in = partial_indices + row_off; + int32_t* out_idx = sparse_kv_indices + sparse_kv_indptr[b] + reserved_bos; + + if constexpr (MERGE_VARIANT == MERGE_PROD_DEFAULT) { + // Mirrors topk_sglang_merge.cu Stage 2 exactly: different strategy per SPLITS. + if constexpr (SPLITS == 2) { + if (tx < 32) abl_merge_2way_manual_lane0( + keys_in, idx_in, LK, + keys_in + LK, idx_in + LK, LK, + out_idx, topk_val); + } else if constexpr (SPLITS == 4) { + __shared__ uint32_t s_pd01k[LK]; __shared__ int32_t s_pd01i[LK]; + __shared__ uint32_t s_pd23k[LK]; __shared__ int32_t s_pd23i[LK]; + if (tx < 32) abl_merge_pairwise_4(keys_in, idx_in, out_idx, topk_val, + s_pd01k, s_pd01i, s_pd23k, s_pd23i); + } else { + if (tx < 32) abl_merge_sorted_kway( + keys_in, idx_in, out_idx, topk_val); + } + } else if constexpr (MERGE_VARIANT == MERGE_CUB_WARP) { + // kMergeIPT grows with SPLITS: SPLITS=16 → kMergeIPT=16, SPLITS=32 → 32. + // Large IPT increases register pressure and may cause spilling on sm_90+. + constexpr int kMergeIPT = (kCandidates + 31) / 32; + using WarpMergeT = cub::WarpMergeSort; + __shared__ typename WarpMergeT::TempStorage warp_merge; + if (tx < 32) { + uint32_t wkeys[kMergeIPT]; int32_t wvalues[kMergeIPT]; + #pragma unroll + for (int k = 0; k < kMergeIPT; ++k) { + const int rank = tx * kMergeIPT + k; + wkeys [k] = (rank < kCandidates) ? keys_in[rank] : 0u; + wvalues[k] = (rank < kCandidates) ? 
idx_in [rank] : -1; + } + WarpMergeT(warp_merge).Sort(wkeys, wvalues, AblGreaterUint32{}); + #pragma unroll + for (int k = 0; k < kMergeIPT; ++k) { + const int rank = tx * kMergeIPT + k; + if (rank < topk_val) out_idx[rank] = wvalues[k]; + } + } + } else if constexpr (MERGE_VARIANT == MERGE_CUB_BLOCK) { + constexpr int kBlockThreads = 64; + constexpr int kMergeIPT = (kCandidates + kBlockThreads - 1) / kBlockThreads; + using BlockMergeT = cub::BlockMergeSort; + __shared__ typename BlockMergeT::TempStorage block_merge; + if (tx < kBlockThreads) { + uint32_t wkeys[kMergeIPT]; int32_t wvalues[kMergeIPT]; + #pragma unroll + for (int k = 0; k < kMergeIPT; ++k) { + const int rank = tx * kMergeIPT + k; + wkeys [k] = (rank < kCandidates) ? keys_in[rank] : 0u; + wvalues[k] = (rank < kCandidates) ? idx_in [rank] : -1; + } + BlockMergeT(block_merge).Sort(wkeys, wvalues, AblGreaterUint32{}); + #pragma unroll + for (int k = 0; k < kMergeIPT; ++k) { + const int rank = tx * kMergeIPT + k; + if (rank < topk_val) out_idx[rank] = wvalues[k]; + } + } + } else if constexpr (MERGE_VARIANT == MERGE_MANUAL_2WAY) { + static_assert(SPLITS == 2, "manual_2way merge requires SPLITS=2"); + if (tx < 32) abl_merge_2way_manual_lane0( + keys_in, idx_in, LK, + keys_in + LK, idx_in + LK, LK, + out_idx, topk_val); + } else if constexpr (MERGE_VARIANT == MERGE_PAIRWISE_4) { + static_assert(SPLITS == 4, "pairwise_tree merge requires SPLITS=4"); + __shared__ uint32_t s_t01k[LK]; + __shared__ int32_t s_t01i[LK]; + __shared__ uint32_t s_t23k[LK]; + __shared__ int32_t s_t23i[LK]; + if (tx < 32) abl_merge_pairwise_4( + keys_in, idx_in, out_idx, topk_val, + s_t01k, s_t01i, s_t23k, s_t23i); + } else if constexpr (MERGE_VARIANT == MERGE_KWAY) { + // Force k-way for all SPLITS — explicit baseline to isolate k-way cost. 
+ if (tx < 32) abl_merge_sorted_kway( + keys_in, idx_in, out_idx, topk_val); + } + } + + } // namespace + + void topk_output_adaptive_workspace_ablation( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const at::Tensor& sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + at::Tensor& partial_keys, + at::Tensor& partial_indices, + at::Tensor& done_counter, + at::Tensor& scratch, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages, + const int64_t ablation_mode, + const int64_t forced_splits) + { + TORCH_CHECK(x.scalar_type() == at::ScalarType::BFloat16, + "ablation kernels are bf16-only"); + TORCH_CHECK(topk_val > 0 && topk_val <= kMaxFinalK_Top30, + "ablation kernels are K<=32 only"); + + int split = forced_splits > 0 ? static_cast(forced_splits) : 8; + TORCH_CHECK(split == 1 || split == 2 || split == 4 || split == 8 || + split == 16 || split == 32, + "forced_splits must be {1,2,4,8,16,32}, got ", split); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + // memset_only: just clear the counter. + if (ablation_mode == 8) { + if (split > 1) { + ::cudaMemsetAsync(done_counter.data_ptr(), 0, + sizeof(int32_t) * static_cast(eff_batch_size), + stream); + } + return; + } + // atomic_only needs the counter pre-cleared so each call sees a fresh state. 
+ if (ablation_mode == 4 && split > 1) { + ::cudaMemsetAsync(done_counter.data_ptr(), 0, + sizeof(int32_t) * static_cast(eff_batch_size), + stream); + } + + uint32_t* part_keys_ptr = reinterpret_cast(partial_keys.data_ptr()); + int32_t* part_idx_ptr = partial_indices.data_ptr(); + int32_t* done_ptr = done_counter.data_ptr(); + int32_t* scratch_ptr = scratch.data_ptr(); + + dim3 grid_full (static_cast(eff_batch_size), + static_cast(split)); + dim3 grid_merge(static_cast(eff_batch_size), 1u); + + const __nv_bfloat16* x_ptr = + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()); + + #define LAUNCH_ABL(KERNEL, GRID, NT, ...) \ + do { KERNEL<<>>(__VA_ARGS__); } while (0) + + switch (ablation_mode) { + case 1: { + switch (split) { + case 1: LAUNCH_ABL((Ablation_LocalOnly_Kernel<1, kAblCfg1.num_threads, kAblCfg1.items_per_thread>), grid_full, kAblCfg1.num_threads, + x_ptr, dense_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + part_keys_ptr, part_idx_ptr, + static_cast(reserved_bos), static_cast(reserved_eos)); break; + case 2: LAUNCH_ABL((Ablation_LocalOnly_Kernel<2, kAblCfg2.num_threads, kAblCfg2.items_per_thread>), grid_full, kAblCfg2.num_threads, + x_ptr, dense_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + part_keys_ptr, part_idx_ptr, + static_cast(reserved_bos), static_cast(reserved_eos)); break; + case 4: LAUNCH_ABL((Ablation_LocalOnly_Kernel<4, kAblCfg4.num_threads, kAblCfg4.items_per_thread>), grid_full, kAblCfg4.num_threads, + x_ptr, dense_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + part_keys_ptr, part_idx_ptr, + static_cast(reserved_bos), static_cast(reserved_eos)); break; + case 8: LAUNCH_ABL((Ablation_LocalOnly_Kernel<8, kAblCfg8.num_threads, kAblCfg8.items_per_thread>), grid_full, kAblCfg8.num_threads, + x_ptr, dense_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + part_keys_ptr, part_idx_ptr, + static_cast(reserved_bos), static_cast(reserved_eos)); break; + case 16: LAUNCH_ABL((Ablation_LocalOnly_Kernel<16, kAblCfg16.num_threads, 
kAblCfg16.items_per_thread>), grid_full, kAblCfg16.num_threads, + x_ptr, dense_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + part_keys_ptr, part_idx_ptr, + static_cast(reserved_bos), static_cast(reserved_eos)); break; + case 32: LAUNCH_ABL((Ablation_LocalOnly_Kernel<32, kAblCfg32.num_threads, kAblCfg32.items_per_thread>), grid_full, kAblCfg32.num_threads, + x_ptr, dense_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + part_keys_ptr, part_idx_ptr, + static_cast(reserved_bos), static_cast(reserved_eos)); break; + } + break; + } + case 2: { + switch (split) { + case 1: LAUNCH_ABL((Ablation_LocalNoWorkspace_Kernel<1, kAblCfg1.num_threads, kAblCfg1.items_per_thread>), grid_full, kAblCfg1.num_threads, + x_ptr, dense_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), scratch_ptr, + static_cast(reserved_bos), static_cast(reserved_eos)); break; + case 2: LAUNCH_ABL((Ablation_LocalNoWorkspace_Kernel<2, kAblCfg2.num_threads, kAblCfg2.items_per_thread>), grid_full, kAblCfg2.num_threads, + x_ptr, dense_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), scratch_ptr, + static_cast(reserved_bos), static_cast(reserved_eos)); break; + case 4: LAUNCH_ABL((Ablation_LocalNoWorkspace_Kernel<4, kAblCfg4.num_threads, kAblCfg4.items_per_thread>), grid_full, kAblCfg4.num_threads, + x_ptr, dense_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), scratch_ptr, + static_cast(reserved_bos), static_cast(reserved_eos)); break; + case 8: LAUNCH_ABL((Ablation_LocalNoWorkspace_Kernel<8, kAblCfg8.num_threads, kAblCfg8.items_per_thread>), grid_full, kAblCfg8.num_threads, + x_ptr, dense_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), scratch_ptr, + static_cast(reserved_bos), static_cast(reserved_eos)); break; + case 16: LAUNCH_ABL((Ablation_LocalNoWorkspace_Kernel<16, kAblCfg16.num_threads, kAblCfg16.items_per_thread>), grid_full, kAblCfg16.num_threads, + x_ptr, dense_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), scratch_ptr, + static_cast(reserved_bos), 
static_cast(reserved_eos)); break; + case 32: LAUNCH_ABL((Ablation_LocalNoWorkspace_Kernel<32, kAblCfg32.num_threads, kAblCfg32.items_per_thread>), grid_full, kAblCfg32.num_threads, + x_ptr, dense_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), scratch_ptr, + static_cast(reserved_bos), static_cast(reserved_eos)); break; + } + break; + } + case 3: { + switch (split) { + case 1: LAUNCH_ABL((Ablation_WorkspaceWriteOnly_Kernel<1>), grid_full, 32, part_keys_ptr, part_idx_ptr); break; + case 2: LAUNCH_ABL((Ablation_WorkspaceWriteOnly_Kernel<2>), grid_full, 32, part_keys_ptr, part_idx_ptr); break; + case 4: LAUNCH_ABL((Ablation_WorkspaceWriteOnly_Kernel<4>), grid_full, 32, part_keys_ptr, part_idx_ptr); break; + case 8: LAUNCH_ABL((Ablation_WorkspaceWriteOnly_Kernel<8>), grid_full, 32, part_keys_ptr, part_idx_ptr); break; + case 16: LAUNCH_ABL((Ablation_WorkspaceWriteOnly_Kernel<16>), grid_full, 32, part_keys_ptr, part_idx_ptr); break; + case 32: LAUNCH_ABL((Ablation_WorkspaceWriteOnly_Kernel<32>), grid_full, 32, part_keys_ptr, part_idx_ptr); break; + } + break; + } + case 4: { + switch (split) { + case 1: LAUNCH_ABL((Ablation_AtomicOnly_Kernel<1>), grid_full, 32, done_ptr, scratch_ptr); break; + case 2: LAUNCH_ABL((Ablation_AtomicOnly_Kernel<2>), grid_full, 32, done_ptr, scratch_ptr); break; + case 4: LAUNCH_ABL((Ablation_AtomicOnly_Kernel<4>), grid_full, 32, done_ptr, scratch_ptr); break; + case 8: LAUNCH_ABL((Ablation_AtomicOnly_Kernel<8>), grid_full, 32, done_ptr, scratch_ptr); break; + case 16: LAUNCH_ABL((Ablation_AtomicOnly_Kernel<16>), grid_full, 32, done_ptr, scratch_ptr); break; + case 32: LAUNCH_ABL((Ablation_AtomicOnly_Kernel<32>), grid_full, 32, done_ptr, scratch_ptr); break; + } + break; + } + // kAblMode_MergeProdDefault=5, kAblMode_MergeCubWarp=6, kAblMode_MergeCubBlock=7 + // map to MERGE_PROD_DEFAULT=0, MERGE_CUB_WARP=1, MERGE_CUB_BLOCK=2 respectively. 
+ case 5: case 6: case 7: { + const int variant = static_cast(ablation_mode - 5); + auto launch_merge = [&](auto split_const_var, int v) { + constexpr int S = decltype(split_const_var)::value; + switch (v) { + case MERGE_PROD_DEFAULT: + LAUNCH_ABL((Ablation_MergeOnly_Kernel), grid_merge, 32, + part_keys_ptr, part_idx_ptr, + sparse_kv_indptr.data_ptr(), + sparse_kv_indices.data_ptr(), + static_cast(topk_val), + static_cast(reserved_bos)); break; + case MERGE_CUB_WARP: + LAUNCH_ABL((Ablation_MergeOnly_Kernel), grid_merge, 32, + part_keys_ptr, part_idx_ptr, + sparse_kv_indptr.data_ptr(), + sparse_kv_indices.data_ptr(), + static_cast(topk_val), + static_cast(reserved_bos)); break; + case MERGE_CUB_BLOCK: + LAUNCH_ABL((Ablation_MergeOnly_Kernel), grid_merge, 64, + part_keys_ptr, part_idx_ptr, + sparse_kv_indptr.data_ptr(), + sparse_kv_indices.data_ptr(), + static_cast(topk_val), + static_cast(reserved_bos)); break; + } + }; + switch (split) { + case 1: launch_merge(std::integral_constant{}, variant); break; + case 2: launch_merge(std::integral_constant{}, variant); break; + case 4: launch_merge(std::integral_constant{}, variant); break; + case 8: launch_merge(std::integral_constant{}, variant); break; + case 16: launch_merge(std::integral_constant{}, variant); break; + case 32: launch_merge(std::integral_constant{}, variant); break; + } + break; + } + case 9: { // kAblMode_MergeManual2Way + TORCH_CHECK(split == 2, "ablation 9 (merge_manual_2way) requires forced_splits=2"); + LAUNCH_ABL((Ablation_MergeOnly_Kernel<2, MERGE_MANUAL_2WAY>), grid_merge, 32, + part_keys_ptr, part_idx_ptr, + sparse_kv_indptr.data_ptr(), + sparse_kv_indices.data_ptr(), + static_cast(topk_val), + static_cast(reserved_bos)); + break; + } + case 10: { // kAblMode_MergePairwise4 + TORCH_CHECK(split == 4, "ablation 10 (merge_pairwise_4) requires forced_splits=4"); + LAUNCH_ABL((Ablation_MergeOnly_Kernel<4, MERGE_PAIRWISE_4>), grid_merge, 32, + part_keys_ptr, part_idx_ptr, + 
sparse_kv_indptr.data_ptr(), + sparse_kv_indices.data_ptr(), + static_cast(topk_val), + static_cast(reserved_bos)); + break; + } + case 11: { // kAblMode_MergeKwayAll: force k-way regardless of SPLITS + auto launch_kway = [&](auto split_tag) { + constexpr int S = decltype(split_tag)::value; + LAUNCH_ABL((Ablation_MergeOnly_Kernel), grid_merge, 32, + part_keys_ptr, part_idx_ptr, + sparse_kv_indptr.data_ptr(), + sparse_kv_indices.data_ptr(), + static_cast(topk_val), + static_cast(reserved_bos)); + }; + switch (split) { + case 1: launch_kway(std::integral_constant{}); break; + case 2: launch_kway(std::integral_constant{}); break; + case 4: launch_kway(std::integral_constant{}); break; + case 8: launch_kway(std::integral_constant{}); break; + case 16: launch_kway(std::integral_constant{}); break; + case 32: launch_kway(std::integral_constant{}); break; + } + break; + } + case 0: { + // full_parallel — re-enter the production workspace API with forced + // split + CONTIGUOUS partition. This makes the "0" mode useful as the + // 100% reference for the other ablations. 
+ topk_output_adaptive_workspace( + x, dense_kv_indptr, sparse_kv_indptr, dense_kv_indices, + sparse_kv_indices, partial_keys, partial_indices, done_counter, + eff_batch_size, topk_val, reserved_bos, reserved_eos, + max_num_pages, /*mapping_mode=*/0, /*mapping_power=*/0.0, + forced_splits, /*forced_partition=*/kPartContiguous, /*local_mode=*/0); + break; + } + default: + TORCH_CHECK(false, "unknown ablation_mode=", ablation_mode, + " (valid range: 0–11)"); + } + #undef LAUNCH_ABL + + const auto rc = cudaGetLastError(); + TORCH_CHECK(rc == cudaSuccess, + "ablation launch failed: ", ::cudaGetErrorString(rc)); + } diff --git a/csrc/topk_mapping.cuh b/csrc/topk_mapping.cuh new file mode 100644 index 00000000..0b6474bc --- /dev/null +++ b/csrc/topk_mapping.cuh @@ -0,0 +1,213 @@ +#pragma once +#include +#include +#include + +// ============================================================ +// TopK bucket-sort Stage-1 remap transforms (lean version). +// +// These are element-wise transforms applied to scores before +// the Stage-1 8-bit histogram bucketing. The goal is to spread +// a skewed raw distribution more uniformly across the 256 bins +// so the threshold bin shrinks and Stage-2 refinement does less +// work. Stage 2 still uses convert_to_uint32() on the remapped +// value's raw bits for tie-breaking. +// +// There is no pre-pass, no auto-range, no LUT, no quantile +// table, and no shared-memory state — each transform is a +// pure function of one float. The heavy pre-pass machinery +// (auto-range, pivot, tail-window, topk-window, LUT_CDF, +// QUANTILE, SUBTRACT, TRUNC8) lives in +// csrc/archived/fast_topk_vortex_prepass.cu. 
+// ============================================================
+
+// Selects the element-wise remap applied to a score before Stage-1
+// bucketing. The numeric values are not contiguous: 1, 2 and 20 are
+// retired modes kept below as commented-out entries, and 5, 12 and 14
+// are not defined in this header.
+// NOTE(review): TopKMappingParams::mode stores this as a plain int, so
+// the values presumably come from host-side callers — confirm before
+// renumbering anything.
+enum TopKMappingMode {
+    MAPPING_NONE = 0,          // identity (no remap)
+    // MAPPING_LUT_CDF = 1,    // bin lookup: new_bin = lut[convert_to_uint8(x)]
+    // MAPPING_QUANTILE = 2,   // binary search over 256 calibrated quantile thresholds
+    MAPPING_POWER = 3,         // sign(x) * |x|^p
+    MAPPING_LOG = 4,           // sign(x) * log(|x| + 1)
+    MAPPING_ASINH = 6,         // asinh(beta * x)
+    MAPPING_LOG1P = 7,         // sign(x) * log1p(alpha * |x|)
+    MAPPING_TRUNC8 = 8,        // identity bucketing (historical name; handled like MAPPING_NONE)
+    MAPPING_ERF = 9,           // erf(alpha * x)
+    MAPPING_TANH = 10,         // tanh(alpha * x)
+    MAPPING_SUBTRACT = 11,     // x - pivot, with pivot = power_exp (free hyperparameter)
+    MAPPING_EXP_STRETCH = 13,  // exp(alpha * x)
+    // Top-spreading transforms (see CLAUDE.md / remap bench plan):
+    // amplify differences in the high-score region so the top-K values
+    // occupy multiple Stage-1 bins instead of collapsing into one.
+    MAPPING_SHIFT_POW2 = 15,   // sign(x - p) * (x - p)^2 [p = power_exp]
+    MAPPING_SHIFT_POW3 = 16,   // (x - p)^3 [p = power_exp]
+    MAPPING_LINEAR_STEEP = 17, // x + k * max(x, 0) [k = power_exp]
+    // One-sided spread: collapse below-pivot values into a single bin so
+    // every above-pivot page gets its own slice of the 256-bin histogram.
+    MAPPING_HALF_SQUARE = 18,  // max(x - p, 0)^2 [p = power_exp]
+    MAPPING_HALF_CUBE = 19,    // max(x - p, 0)^3 [p = power_exp]
+    // Bit-level remap: identity value transform, but the Stage-1 bucket
+    // function in fast_topk_clean_fused switches to a mantissa-heavy bit
+    // slice (bits [23:16] of convert_to_uint32) that gives 128 sub-bins
+    // per exponent slot instead of 4. Zero per-element compute overhead;
+    // the "remap" is the bucket change. Monotonic within 2 adjacent
+    // fp32 exponent slots.
+    // MAPPING_DENSE_MANT = 20, // identity; bucketing handled in fused kernel
+};
+
+// Parameter bundle handed to the transform dispatchers.
+struct TopKMappingParams {
+    int mode;                            // TopKMappingMode
+    float power_exp;                     // Free hyperparameter: p / alpha / beta / pivot depending on mode
+    // The two table pointers below belong to MAPPING_LUT_CDF and
+    // MAPPING_QUANTILE, both of which are commented out in the enum above.
+    const uint8_t* __restrict__ lut;     // [256] uint8 LUT, MAPPING_LUT_CDF only
+    const float* __restrict__ quantiles; // [256] float quantile breakpoints, MAPPING_QUANTILE only
+};
+
+// ---- Element-wise transforms ----
+
+// Sign-preserving power: |x|^p with the sign of x copied back on.
+__device__ __forceinline__ float transform_power(float x, float p) {
+    return copysignf(__powf(fabsf(x), p), x);
+}
+
+// Sign-preserving log: log(|x| + 1) with the sign of x copied back on.
+__device__ __forceinline__ float transform_log(float x) {
+    return copysignf(__logf(fabsf(x) + 1.0f), x);
+}
+
+// asinh(beta * x); odd in x, so no explicit sign handling is needed.
+__device__ __forceinline__ float transform_asinh(float x, float beta) {
+    return asinhf(beta * x);
+}
+
+// Sign-preserving log1p(alpha * |x|).
+// NOTE(review): presumably alpha >= 0; a negative alpha can push the
+// log1pf argument to -1 or below — confirm against callers.
+__device__ __forceinline__ float transform_log1p(float x, float alpha) {
+    return copysignf(log1pf(alpha * fabsf(x)), x);
+}
+
+// erf(alpha * x).
+__device__ __forceinline__ float transform_erf(float x, float alpha) {
+    return erff(alpha * x);
+}
+
+// tanh(alpha * x).
+__device__ __forceinline__ float transform_tanh(float x, float alpha) {
+    return tanhf(alpha * x);
+}
+
+// exp(alpha * x) with the exponent capped at 80. Every input whose
+// alpha * x is 80 or more therefore maps to the same output value.
+__device__ __forceinline__ float transform_exp_stretch(float x, float alpha) {
+    float z = alpha * x;
+    z = fminf(z, 80.0f); // prevent float32 overflow (exp(80) ~ 5.5e34)
+    return expf(z);
+}
+
+// Signed squared distance from a pivot. ~3 ops (1 sub, 1 mul, 1 copysign).
+// Quadratically amplifies differences between values far from pivot so the
+// top-K region gets spread across multiple Stage-1 bins.
+__device__ __forceinline__ float transform_shift_pow2(float x, float pivot) {
+    const float d = x - pivot;
+    return copysignf(d * d, d);
+}
+
+// Signed cubic of distance from pivot. ~3 ops (1 sub, 2 mul; odd function so
+// no copysign). Steeper growth than pow2 for even tighter top-K clusters.
+__device__ __forceinline__ float transform_shift_pow3(float x, float pivot) {
+    // Cube of the distance from the pivot; the sign of the distance is
+    // preserved because the exponent is odd.
+    const float delta = x - pivot;
+    const float delta_sq = delta * delta;
+    return delta_sq * delta;
+}
+
+// Steepens the positive half-line only: returns x + k * max(x, 0), so a
+// positive x comes back as (1 + k) * x and a non-positive x comes back
+// unchanged. One fmax followed by one fused multiply-add.
+__device__ __forceinline__ float transform_linear_steep(float x, float k) {
+    const float positive_part = fmaxf(x, 0.0f);
+    return fmaf(k, positive_part, x);
+}
+
+// Squares the amount by which x exceeds the pivot. Anything at or below
+// the pivot is clamped to a zero excess first and therefore returns 0.
+__device__ __forceinline__ float transform_half_square(float x, float pivot) {
+    const float excess = fmaxf(x - pivot, 0.0f);
+    const float squared = excess * excess;
+    return squared;
+}
+
+// Cubic counterpart of transform_half_square: max(x - pivot, 0) raised
+// to the third power, again returning 0 at or below the pivot.
+__device__ __forceinline__ float transform_half_cube(float x, float pivot) {
+    const float excess = fmaxf(x - pivot, 0.0f);
+    const float squared = excess * excess;
+    return squared * excess;
+}
+
+// Compile-time templated dispatcher. When the caller knows the mapping mode
+// at template-instantiation time, this lets the compiler fully inline the
+// transform into the Stage-1 inner loop and eliminate the runtime switch
+// that `apply_transform` would otherwise perform per element. Used by the
+// per-mode specializations of `fast_topk_clean_fused` in topk_sglang.cu.
+//
+//   x : score to transform.
+//   p : the mode's free hyperparameter (exponent / alpha / beta / pivot /
+//       slope). It is ignored by MAPPING_LOG and by the identity fallback.
+//
+// Returns the transformed float. Any MODE without a branch below falls
+// through to the final `else` and returns x unchanged.
+// NOTE(review): the template parameter list is not visible in this copy
+// of the patch. MODE is only used as a compile-time constant compared
+// against TopKMappingMode values — confirm its declared type.
+template
+__device__ __forceinline__ float apply_transform_tmpl(float x, float p) {
+    if constexpr (MODE == MAPPING_POWER) return transform_power(x, p);
+    else if constexpr (MODE == MAPPING_LOG) return transform_log(x);
+    else if constexpr (MODE == MAPPING_ASINH) return transform_asinh(x, p);
+    else if constexpr (MODE == MAPPING_LOG1P) return transform_log1p(x, p);
+    else if constexpr (MODE == MAPPING_ERF) return transform_erf(x, p);
+    else if constexpr (MODE == MAPPING_TANH) return transform_tanh(x, p);
+    // SUBTRACT has no helper of its own; the shift is written inline.
+    else if constexpr (MODE == MAPPING_SUBTRACT) return x - p;
+    else if constexpr (MODE == MAPPING_EXP_STRETCH) return transform_exp_stretch(x, p);
+    else if constexpr (MODE == MAPPING_SHIFT_POW2) return transform_shift_pow2(x, p);
+    else if constexpr (MODE == MAPPING_SHIFT_POW3) return transform_shift_pow3(x, p);
+    else if constexpr (MODE == MAPPING_LINEAR_STEEP) return transform_linear_steep(x, p);
+    else if constexpr (MODE == MAPPING_HALF_SQUARE) return transform_half_square(x, p);
+    else if constexpr (MODE == MAPPING_HALF_CUBE) return transform_half_cube(x, p);
+    else return x; // NONE / TRUNC8
+}
+
+// Pure element-wise dispatcher. Returns the *float value* after the transform.
+// Runtime counterpart of apply_transform_tmpl: switches on params.mode and
+// passes params.power_exp as the hyperparameter of the chosen transform.
+// Modes without a case (including MAPPING_NONE) take the default branch
+// and return x unchanged.
+__device__ __forceinline__ float apply_transform(float x, const TopKMappingParams& params) {
+    switch (params.mode) {
+        case MAPPING_POWER: return transform_power(x, params.power_exp);
+        case MAPPING_LOG: return transform_log(x);
+        case MAPPING_ASINH: return transform_asinh(x, params.power_exp);
+        case MAPPING_LOG1P: return transform_log1p(x, params.power_exp);
+        case MAPPING_ERF: return transform_erf(x, params.power_exp);
+        case MAPPING_TANH: return transform_tanh(x, params.power_exp);
+        case MAPPING_SUBTRACT: return x - params.power_exp;
+        case MAPPING_EXP_STRETCH: return transform_exp_stretch(x, params.power_exp);
+        case MAPPING_SHIFT_POW2: return transform_shift_pow2(x, params.power_exp);
+        case MAPPING_SHIFT_POW3: return transform_shift_pow3(x, params.power_exp);
+        case MAPPING_LINEAR_STEEP: return transform_linear_steep(x, params.power_exp);
+        case MAPPING_HALF_SQUARE: return transform_half_square(x, params.power_exp);
+        case MAPPING_HALF_CUBE: return transform_half_cube(x, params.power_exp);
+        case MAPPING_TRUNC8:
+        default: return x; // NONE / TRUNC8
+    }
+}
+
+// Bin-selection table modes (LUT_CDF / QUANTILE) have been retired.
+// This helper is kept for ABI compat with callers that still invoke it.
+// It ignores its argument and always reports false.
+__device__ __forceinline__ bool mapping_uses_table(int /*mode*/) {
+    return false;
+}
+
+// Binary search over a sorted [256] quantile table. Returns the largest
+// index i in [1, 255] such that x >= s_quantiles[i], or 0 when no such
+// index exists (entry 0 is never compared, and a NaN x also yields 0).
+// Eight halving steps are enough to narrow [0, 255] to one index.
+__device__ __forceinline__ uint8_t quantile_bin_lookup(
+    float x, const float* __restrict__ s_quantiles)
+{
+    int lo = 0, hi = 255;
+#pragma unroll 8
+    for (int iter = 0; iter < 8; ++iter) {
+        // Upper-middle probe so that `lo = mid` always makes progress.
+        int mid = (lo + hi + 1) >> 1;
+        if (x >= s_quantiles[mid]) lo = mid;
+        else hi = mid - 1;
+    }
+    return static_cast(lo);
+}
+
+// Forward decl so compute_stage1_bin can call it. Defined in the enclosing TU.
+__device__ __forceinline__ uint8_t convert_to_uint8(float x);
+
+// Compute the Stage-1 bin for a raw score. LUT_CDF / QUANTILE modes
+// have been removed; every mode now goes through the element-wise
+// apply_transform + convert_to_uint8. The two table pointers are
+// accepted but unnamed and unused.
+__device__ __forceinline__ uint8_t compute_stage1_bin(
+    float raw,
+    const TopKMappingParams& params,
+    const uint8_t* __restrict__ /*s_lut*/,
+    const float* __restrict__ /*s_quantiles*/)
+{
+    return convert_to_uint8(apply_transform(raw, params));
+}
diff --git a/csrc/topk_sglang.cu b/csrc/topk_sglang.cu
new file mode 100644
index 00000000..cc8c6b37
--- /dev/null
+++ b/csrc/topk_sglang.cu
@@ -0,0 +1,1377 @@
+/**
+ * Vortex TopK kernels.
+ *
+ * Three production kernels:
+ * - fast_topk_clean : unmapped baseline (two-stage radix).
+ * - fast_topk_clean_fused : remap + topk fused (apply_transform
+ * applied inline in Stage-1 bucketing).
+ * - TopKRemapOnly_Kernel : standalone element-wise remap pass
+ * used by the split-phase benchmark.
+ *
+ * Profiling kernels (counter collection, histogram collection) live in
+ * topk_sglang_profile.cu and MUST NOT be used for latency measurements —
+ * they intentionally write extra diagnostic state to global memory.
+ *
+ * Archived / historical kernels: csrc/archived/ (fast_topk_vortex with
+ * pre-pass modes, TopKOutput_Ori_Kernel with flexible radix_bits, the
+ * original SGLang reference kernel).
+ */
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+namespace {
+
+// Fixed K used by the non-templated kernels in this file.
+constexpr int TopK = 2048;
+// Block size; the kernels below use it in __launch_bounds__ and as the
+// per-thread loop stride.
+constexpr int kThreadsPerBlock = 1024;
+
+#ifdef USE_ROCM
+// On ROCm, the per-workgroup LDS budget depends on the target arch, so we inject a
+// per-arch value from `setup_rocm.py` via `-DSGL_TOPK_DYNAMIC_SMEM_BYTES=...`.
+#ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES
+constexpr size_t kSmem = static_cast(SGL_TOPK_DYNAMIC_SMEM_BYTES);
+#else
+constexpr size_t kSmem = 48 * 1024; // bytes
+#endif
+#else
+// Reduced from 128KB to 32KB to improve occupancy.
+// Each radix pass needs at most ~TopK candidates in the threshold bin,
+// so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient.
+constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes)
+#endif
+
+// Fused-kernel dynamic smem ceiling. The fused kernel uses `kSmem` bytes for
+// f_input_idx (2 × SMEM_INPUT_SIZE ints) AND an extra `max_num_pages` bytes
+// for s_bins (one uint8_t per page). Ceiling of 96 KB covers max_num_pages up
+// to 65536 and fits the opt-in dynamic-smem limits on every target in
+// setup.py (sm_86 ≥99KB, sm_89 100KB, sm_90 228KB, sm_100a/120 ≥100KB).
+// Only `topk_output_sglang_fused` uses this ceiling; the other kernels keep
+// kSmem as their dynamic-smem budget.
+constexpr size_t kFusedSmemMax = 96 * 1024;
+
+// Pointer bundle passed by value to the kernels below; built by get_params.
+struct FastTopKParams {
+    const float* __restrict__ input;        // [B, input_stride]
+    const int32_t* __restrict__ row_starts; // [B]
+    int32_t* __restrict__ indices;          // [B, TopK]
+    int32_t* __restrict__ lengths;          // [B]
+    int64_t input_stride;
+};
+
+// when length <= TopK, we can directly write the indices
+// Writes i for i < length and -1 for the rest of the TopK slots; `score`
+// is not read.
+__device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) {
+    const auto tid = threadIdx.x;
+    for (int i = tid; i < TopK; i += kThreadsPerBlock) {
+        indice[i] = (i < length) ? i : -1;
+    }
+}
+
+// keep the first `length` entries, set others to -1
+// Copies src_page_table[i] for i < length; `score` is not read.
+__device__ void naive_topk_transform(
+    const float* __restrict__ score,
+    int32_t length,
+    int32_t* __restrict__ dst_page_table,
+    const int32_t* __restrict__ src_page_table) {
+    const auto tid = threadIdx.x;
+    for (auto i = tid; i < TopK; i += kThreadsPerBlock) {
+        dst_page_table[i] = (i < length) ? src_page_table[i] : -1;
+    }
+}
+
+// Ragged variant: writes i + offset for the first `length` slots and -1
+// for the others; `score` is not read.
+__device__ void naive_topk_transform_ragged(
+    const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) {
+    const auto tid = threadIdx.x;
+    for (auto i = tid; i < TopK; i += kThreadsPerBlock) {
+        topk_indices_ragged[i] = (i < length) ? static_cast(i) + offset : -1;
+    }
+}
+
+// Stage-1 bucket: round x to fp16, turn the bit pattern into an
+// order-preserving unsigned key (negatives are bit-inverted, non-negatives
+// get the top bit set) and keep the high 8 bits of that key.
+__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t {
+    __half h = __float2half_rn(x);
+    uint16_t bits = __half_as_ushort(h);
+    uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000);
+    return static_cast(key >> 8);
+}
+
+// Same sign-adjusted key construction on the full fp32 bit pattern; Stage 2
+// consumes this key one byte at a time.
+__device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t {
+    uint32_t bits = __float_as_uint(x);
+    return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
+}
+
+// Mantissa-heavy Stage-1 bucket for MAPPING_DENSE_MANT. Returns bits
+// [23:16] of the sign-adjusted float32 key = 1 exp LSB + 7 top
+// mantissa bits. This yields 128 mantissa sub-bins per exp slot (vs
+// 4 in the current fp16 scheme — 32× more resolution) and is strictly
+// monotonic across 2 adjacent fp32 exponent slots (factor-of-4 value
+// range). Designed for the common case where the top-K scores cluster
+// tightly: softmax-attention outputs on Qwen / Llama typically live
+// in ~1 exp slot of magnitude near the top. Values with exponents
+// outside the 2-slot monotonic window collide with lower bins, which
+// only causes a correctness issue if top-K elements span more than
+// 2 exp slots — verified empirically before shipping.
+// (MAPPING_DENSE_MANT itself is commented out in TopKMappingMode.)
+__device__ __forceinline__ auto convert_to_uint8_dense(float x) -> uint8_t {
+    const uint32_t bits = __float_as_uint(x);
+    const uint32_t key = (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
+    return static_cast((key >> 16) & 0xFFu);
+}
+
+// ---- Vortex additions ----
+
+// Converts a score element to float. Specialised for float (pass-through)
+// and __nv_bfloat16 (__bfloat162float).
+template
+__device__ __forceinline__ float vortex_to_float(T x);
+template <>
+__device__ __forceinline__ float vortex_to_float(float x) { return x; }
+template <>
+__device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) {
+    return __bfloat162float(x);
+}
+
+constexpr int VORTEX_MAX_TOPK = 2048;
+
+// Included here, after convert_to_uint8 is defined, because the header
+// only forward-declares it.
+#include "topk_mapping.cuh"
+
+
+// Block-cooperative two-stage radix select over input[row_start + idx] for
+// idx in [0, length). The positions idx (relative to row_start) of the
+// selected elements are written to `index`.
+//   Stage 1 histograms convert_to_uint8(value) into 256 bins.
+//   Stage 2 refines the threshold bin with up to four 8-bit passes over
+//   convert_to_uint32(value), highest byte first.
+// Indices are appended through atomic counters, so the output order is not
+// fixed. Strides are hard-coded to BLOCK_SIZE (1024) threads.
+__device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int row_start, int length) {
+    // An optimized topk kernel copied from tilelang kernel
+    // We assume length > TopK here, or it will crash
+    int topk = TopK;
+    constexpr auto BLOCK_SIZE = 1024;
+    constexpr auto RADIX = 256;
+    constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int));
+
+    alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128];
+    alignas(128) __shared__ int s_counter;
+    alignas(128) __shared__ int s_threshold_bin_id;
+    alignas(128) __shared__ int s_num_input[2];
+
+    auto& s_histogram = s_histogram_buf[0];
+    // allocate for two rounds
+    extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE];
+
+    const int tx = threadIdx.x;
+
+    // stage 1: 8bit coarse histogram
+    // RADIX + 1 entries are cleared: entry RADIX stays 0 and is what
+    // s_histogram[tx + 1] reads for the top bin.
+    if (tx < RADIX + 1) s_histogram[tx] = 0;
+    __syncthreads();
+
+    for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
+        const auto bin = convert_to_uint8(input[idx + row_start]);
+        ::atomicAdd(&s_histogram[bin], 1);
+    }
+    __syncthreads();
+
+    // Suffix sum by doubling, ping-ponging between the two histogram
+    // buffers. After the 8 steps the result is back in buffer 0, where
+    // s_histogram[b] is the number of elements whose bin is >= b.
+    const auto run_cumsum = [&] {
+#pragma unroll 8
+        for (int i = 0; i < 8; ++i) {
+            static_assert(1 << 8 == RADIX);
+            if (C10_LIKELY(tx < RADIX)) {
+                const auto j = 1 << i;
+                const auto k = i & 1;
+                auto value = s_histogram_buf[k][tx];
+                if (tx < RADIX - j) {
+                    value += s_histogram_buf[k][tx + j];
+                }
+                s_histogram_buf[k ^ 1][tx] = value;
+            }
+            __syncthreads();
+        }
+    };
+
+    run_cumsum();
+    // Threshold bin: more than topk elements at or above it, at most topk
+    // strictly above it. The thread that owns it also resets the counters.
+    if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) {
+        s_threshold_bin_id = tx;
+        s_num_input[0] = 0;
+        s_counter = 0;
+    }
+    __syncthreads();
+
+    const auto threshold_bin = s_threshold_bin_id;
+    // topk is now the number of elements still needed from the threshold bin.
+    topk -= s_histogram[threshold_bin + 1];
+
+    if (topk == 0) {
+        // The bins above the threshold hold exactly the requested count.
+        for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
+            const auto bin = static_cast(convert_to_uint8(input[idx + row_start]));
+            if (bin > threshold_bin) {
+                const auto pos = ::atomicAdd(&s_counter, 1);
+                index[pos] = idx;
+            }
+        }
+        __syncthreads();
+        return;
+    } else {
+        __syncthreads();
+        if (tx < RADIX + 1) {
+            s_histogram[tx] = 0;
+        }
+        __syncthreads();
+
+        for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
+            const auto raw_input = input[idx + row_start];
+            const auto bin = static_cast(convert_to_uint8(raw_input));
+            if (bin > threshold_bin) {
+                const auto pos = ::atomicAdd(&s_counter, 1);
+                index[pos] = idx;
+            } else if (bin == threshold_bin) {
+                const auto pos = ::atomicAdd(&s_num_input[0], 1);
+                /// NOTE: (dark) fuse the histogram computation here
+                // Candidates beyond SMEM_INPUT_SIZE are counted in
+                // s_num_input but neither stored nor histogrammed.
+                if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) {
+                    s_input_idx[0][pos] = idx;
+                    const auto bin = convert_to_uint32(raw_input);
+                    const auto sub_bin = (bin >> 24) & 0xFF;
+                    ::atomicAdd(&s_histogram[sub_bin], 1);
+                }
+            }
+        }
+        __syncthreads();
+    }
+
+    // stage 2: refine with 8bit radix passes
+    // Round r looks at byte (24 - 8 * r) of the 32-bit key and alternates
+    // between the two candidate buffers.
+#pragma unroll 4
+    for (int round = 0; round < 4; ++round) {
+        __shared__ int s_last_remain;
+        const auto r_idx = round % 2;
+
+        // clip here to prevent overflow
+        const auto _raw_num_input = s_num_input[r_idx];
+        const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input : int(SMEM_INPUT_SIZE);
+
+        run_cumsum();
+        if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) {
+            s_threshold_bin_id = tx;
+            s_num_input[r_idx ^ 1] = 0;
+            s_last_remain = topk - s_histogram[tx + 1];
+        }
+        __syncthreads();
+
+        const auto threshold_bin = s_threshold_bin_id;
+        topk -= s_histogram[threshold_bin + 1];
+
+        if (topk == 0) {
+            for (int i = tx; i < num_input; i += BLOCK_SIZE) {
+                const auto idx = s_input_idx[r_idx][i];
+                const auto offset = 24 - round * 8;
+                const auto bin = (convert_to_uint32(input[idx + row_start]) >> offset) & 0xFF;
+                if (bin > threshold_bin) {
+                    const auto pos = ::atomicAdd(&s_counter, 1);
+                    index[pos] = idx;
+                }
+            }
+            __syncthreads();
+            break;
+        } else {
+            __syncthreads();
+            if (tx < RADIX + 1) {
+                s_histogram[tx] = 0;
+            }
+            __syncthreads();
+            for (int i = tx; i < num_input; i += BLOCK_SIZE) {
+                const auto idx = s_input_idx[r_idx][i];
+                const auto raw_input = input[idx + row_start];
+                const auto offset = 24 - round * 8;
+                const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF;
+                if (bin > threshold_bin) {
+                    const auto pos = ::atomicAdd(&s_counter, 1);
+                    index[pos] = idx;
+                } else if (bin == threshold_bin) {
+                    if (round == 3) {
+                        // Last byte: whatever is still tied fills the tail
+                        // of the output; s_last_remain counts down the
+                        // free slots.
+                        const auto pos = ::atomicAdd(&s_last_remain, -1);
+                        if (pos > 0) {
+                            index[TopK - pos] = idx;
+                        }
+                    } else {
+                        const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1);
+                        if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) {
+                            /// NOTE: (dark) fuse the histogram computation here
+                            s_input_idx[r_idx ^ 1][pos] = idx;
+                            const auto bin = convert_to_uint32(raw_input);
+                            const auto sub_bin = (bin >> (offset - 8)) & 0xFF;
+                            ::atomicAdd(&s_histogram[sub_bin], 1);
+                        }
+                    }
+                }
+            }
+            __syncthreads();
+        }
+    }
+}
+
+// One block per row (row = blockIdx.x). Writes TopK indices for that row
+// into params.indices, taking the trivial path when the row has at most
+// TopK entries.
+__global__ __launch_bounds__(kThreadsPerBlock) // topk
+    void topk_kernel(const FastTopKParams params) {
+    const auto& [input, row_starts, indices, lengths, input_stride] = params;
+    const auto bid = static_cast(blockIdx.x);
+    const auto row_start = row_starts == nullptr ? 0 : row_starts[bid];
+    const auto length = lengths[bid];
+    const auto indice = indices + bid * TopK;
+    const auto score = input + bid * input_stride;
+    if (length <= TopK) {
+        return naive_topk_cuda(score, indice, length);
+    } else {
+        return fast_topk_cuda_tl(score, indice, row_start, length);
+    }
+}
+
+// Decode variant: selects TopK positions for row blockIdx.x and writes the
+// source page-table entries at those positions to dst_page_table. The
+// row_starts member of params is ignored (row_start is fixed at 0).
+__global__ __launch_bounds__(kThreadsPerBlock) // decode
+    void topk_transform_decode_kernel(
+    const FastTopKParams params,
+    int32_t* __restrict__ dst_page_table,
+    const int32_t* __restrict__ src_page_table,
+    const int64_t src_stride) {
+    const auto& [input, _1, _2, lengths, input_stride] = params;
+    const auto bid = static_cast(blockIdx.x);
+    const auto tid = threadIdx.x;
+    const auto row_start = 0;
+    const auto length = lengths[bid];
+    const auto src_page_entry = src_page_table + bid * src_stride;
+    const auto dst_page_entry = dst_page_table + bid * TopK;
+    const auto score = input + bid * input_stride;
+    if (length <= TopK) {
+        return naive_topk_transform(score, length, dst_page_entry, src_page_entry);
+    } else {
+        __shared__ int s_indices[TopK];
+        fast_topk_cuda_tl(score, s_indices, row_start, length);
+        // copy src[s_indices] to dst, we manually unroll here
+        // Each thread handles slots tid and tid + kThreadsPerBlock; the
+        // static_asserts pin TopK to exactly two slots per thread.
+        static_assert(TopK % kThreadsPerBlock == 0);
+        static_assert(TopK / kThreadsPerBlock == 2);
+        const auto idx_0 = tid;
+        const auto pos_0 = s_indices[idx_0];
+        dst_page_entry[idx_0] = src_page_entry[pos_0];
+        const auto idx_1 = tid + kThreadsPerBlock;
+        const auto pos_1 = s_indices[idx_1];
+        dst_page_entry[idx_1] = src_page_entry[pos_1];
+    }
+}
+
+// Prefill variant: like the decode kernel, but the source page-table row
+// is the request i whose range [cu_seqlens_q[i], cu_seqlens_q[i + 1])
+// contains this block's index.
+__global__ __launch_bounds__(kThreadsPerBlock) // prefill
+    void topk_transform_prefill_kernel(
+    const FastTopKParams params,
+    int32_t* __restrict__ dst_page_table,
+    const int32_t* __restrict__ src_page_table,
+    const int64_t src_stride,
+    const int32_t* __restrict__ cu_seqlens_q,
+    const int64_t prefill_bs) {
+    const auto& [input, row_starts, _, lengths, input_stride] = params;
+    const auto bid = static_cast(blockIdx.x);
+    const auto tid = threadIdx.x;
+    const auto length = lengths[bid];
+    const auto row_start = row_starts == nullptr ? 0 : row_starts[bid];
+    const auto dst_page_entry = dst_page_table + bid * TopK;
+    const auto score = input + bid * input_stride;
+
+    /// NOTE: prefill bs is usually small, we can just use a simple loop here
+    /// We ensure that last cu_seqlens is equal to number of blocks launched
+    // NOTE(review): s_src_page_entry is written only by a thread that finds
+    // a matching range. If no range contained bid it would be read
+    // uninitialized below — this relies on the guarantee stated above.
+    __shared__ const int32_t* s_src_page_entry;
+    if (C10_LIKELY(prefill_bs <= kThreadsPerBlock)) {
+        if (tid < prefill_bs) {
+            if (bid >= cu_seqlens_q[tid] && bid < cu_seqlens_q[tid + 1]) {
+                s_src_page_entry = src_page_table + tid * src_stride;
+            }
+        }
+    } else {
+        for (int64_t i = tid; i < prefill_bs; i += kThreadsPerBlock) {
+            if (bid >= cu_seqlens_q[i] && bid < cu_seqlens_q[i + 1]) {
+                s_src_page_entry = src_page_table + i * src_stride;
+            }
+        }
+    }
+    __syncthreads();
+    const auto src_page_entry = s_src_page_entry;
+
+    if (length <= TopK) {
+        return naive_topk_transform(score, length, dst_page_entry, src_page_entry);
+    } else {
+        __shared__ int s_indices[TopK];
+        fast_topk_cuda_tl(score, s_indices, row_start, length);
+        // copy src[s_indices] to dst, we manually unroll here
+        static_assert(TopK % kThreadsPerBlock == 0);
+        static_assert(TopK / kThreadsPerBlock == 2);
+        const auto idx_0 = tid;
+        const auto pos_0 = s_indices[idx_0];
+        dst_page_entry[idx_0] = src_page_entry[pos_0];
+        const auto idx_1 = tid + kThreadsPerBlock;
+        const auto pos_1 = s_indices[idx_1];
+        dst_page_entry[idx_1] = src_page_entry[pos_1];
+    }
+}
+
+// Ragged prefill variant: no page table. Each selected position is written
+// out as position + topk_indices_offset[row].
+__global__ __launch_bounds__(kThreadsPerBlock) // prefill, ragged kv
+    void topk_transform_prefill_ragged_kernel(
+    const FastTopKParams params,
+    int32_t* __restrict__ topk_indices_ragged,
+    const int32_t* __restrict__ topk_indices_offset) {
+    const auto& [input, row_starts, _, lengths, input_stride] = params;
+    const auto bid = static_cast(blockIdx.x);
+    const auto tid = threadIdx.x;
+    const auto row_start = row_starts == nullptr ? 0 : row_starts[bid];
+    const auto length = lengths[bid];
+    const auto dst_indices_entry = topk_indices_ragged + bid * TopK;
+    const auto score = input + bid * input_stride;
+    const auto offset = topk_indices_offset[bid];
+
+    if (length <= TopK) {
+        return naive_topk_transform_ragged(score, length, dst_indices_entry, offset);
+    } else {
+        __shared__ int s_indices[TopK];
+        fast_topk_cuda_tl(score, s_indices, row_start, length);
+        // copy src[s_indices] to dst, we manually unroll here
+        static_assert(TopK % kThreadsPerBlock == 0);
+        static_assert(TopK / kThreadsPerBlock == 2);
+        const auto idx_0 = tid;
+        const auto pos_0 = s_indices[idx_0];
+        dst_indices_entry[idx_0] = pos_0 + offset;
+        const auto idx_1 = tid + kThreadsPerBlock;
+        const auto pos_1 = s_indices[idx_1];
+        dst_indices_entry[idx_1] = pos_1 + offset;
+    }
+}
+
+// Host helper: checks tensor shapes and packs raw pointers into a
+// FastTopKParams.
+//   score       : 2-D, unit stride in the last dimension; B = score.size(0).
+//   lengths     : 1-D contiguous with B entries.
+//   row_starts  : optional, 1-D with B entries; nullptr in the result when absent.
+//   indices     : optional, contiguous [B, TopK]; nullptr in the result when absent.
+// Any failed TORCH_CHECK raises.
+// NOTE(review): dtypes and device placement are not checked here (the
+// data_ptr element types are not visible in this copy) — confirm callers
+// validate them.
+auto get_params(
+    const at::Tensor& score,
+    const at::Tensor& lengths,
+    std::optional row_starts_opt = std::nullopt,
+    std::optional indices_opt = std::nullopt) -> FastTopKParams {
+    const auto B = score.size(0);
+    TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1);
+    if (row_starts_opt.has_value()) {
+        const auto& row_starts = row_starts_opt.value();
+        TORCH_CHECK(row_starts.dim() == 1);
+        TORCH_CHECK(row_starts.size(0) == B);
+    }
+    TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous());
+    TORCH_CHECK(lengths.size(0) == B);
+    int32_t* indices_data_ptr = nullptr;
+    if (indices_opt.has_value()) {
+        const auto& indices = indices_opt.value();
+        TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous());
+        TORCH_CHECK(indices.size(0) == B);
+        TORCH_CHECK(indices.size(1) == TopK);
+        indices_data_ptr = indices.data_ptr();
+    }
+
+    return FastTopKParams{
+        .input = score.data_ptr(),
+        .row_starts = row_starts_opt.has_value() ? row_starts_opt->data_ptr() : nullptr,
+        .indices = indices_data_ptr,
+        .lengths = lengths.data_ptr(),
+        .input_stride = score.stride(0),
+    };
+}
+
+// Raises the dynamic shared-memory limit of kernel `f` to
+// `max_dynamic_smem`. The call sits in the initialiser of a function-local
+// static, so it runs once per instantiation; the stored result is checked
+// on every call.
+// NOTE(review): the template parameter list is not visible in this copy;
+// `f` and `max_dynamic_smem` are presumably its non-type parameters.
+template
+void setup_kernel_smem_once() {
+    [[maybe_unused]]
+    static const auto result = [] {
+#ifdef USE_ROCM
+        // hipify will turn cudaFuncSetAttribute -> hipFuncSetAttribute. On ROCm,
+        // hipFuncSetAttribute expects `const void*` and hipcc does not accept passing
+        // a function pointer directly, so cast explicitly.
+        return ::cudaFuncSetAttribute(
+            reinterpret_cast(f), ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem);
+#else
+        // CUDA: keep original behavior (no cast needed).
+        return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem);
+#endif
+    }();
+    TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result));
+}
+
+// ======================================================================
+// Templated clean baseline: identical algorithm to fast_topk_cuda_tl but
+// parameterised on ScoreT (float or __nv_bfloat16) for the GQA / paged
+// call paths that operate on bf16 attention scores. No mapping, no
+// pre-pass — pure two-stage radix topk on fp16 bit-pattern bins.
+// ======================================================================
+//   input    : scores; element idx is read as input[idx + row_start] and
+//              converted with vortex_to_float.
+//   index    : output; receives positions idx in [0, length).
+//   target_k : number of elements to select (replaces the fixed TopK).
+// Candidate positions are appended through atomic counters, so the output
+// order is not fixed. Strides are hard-coded to BLOCK_SIZE (1024) threads.
+// NOTE(review): presumably requires length > target_k, like
+// fast_topk_cuda_tl — confirm at the call sites.
+template
+__device__ void fast_topk_clean(
+    const ScoreT* __restrict__ input,
+    int* __restrict__ index,
+    int row_start,
+    int length,
+    int target_k)
+{
+    int topk = target_k;
+    constexpr auto BLOCK_SIZE = 1024;
+    constexpr auto RADIX = 256;
+    constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int));
+
+    alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128];
+    alignas(128) __shared__ int s_counter;
+    alignas(128) __shared__ int s_threshold_bin_id;
+    alignas(128) __shared__ int s_num_input[2];
+
+    auto& s_histogram = s_histogram_buf[0];
+    extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE];
+
+    const int tx = threadIdx.x;
+
+    // stage 1: 8-bit coarse histogram
+    // Entry RADIX is cleared too; it stays 0 and backs s_histogram[tx + 1]
+    // for the top bin.
+    if (tx < RADIX + 1) s_histogram[tx] = 0;
+    __syncthreads();
+
+    for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
+        const auto bin = convert_to_uint8(vortex_to_float(input[idx + row_start]));
+        ::atomicAdd(&s_histogram[bin], 1);
+    }
+    __syncthreads();
+
+    // Suffix sum by doubling across the two histogram buffers. After the
+    // 8 steps the result is in buffer 0: s_histogram[b] is the number of
+    // elements whose bin is >= b.
+    const auto run_cumsum = [&] {
+#pragma unroll 8
+        for (int i = 0; i < 8; ++i) {
+            static_assert(1 << 8 == RADIX);
+            if (C10_LIKELY(tx < RADIX)) {
+                const auto j = 1 << i;
+                const auto k = i & 1;
+                auto value = s_histogram_buf[k][tx];
+                if (tx < RADIX - j) {
+                    value += s_histogram_buf[k][tx + j];
+                }
+                s_histogram_buf[k ^ 1][tx] = value;
+            }
+            __syncthreads();
+        }
+    };
+
+    run_cumsum();
+    // Threshold bin: more than topk elements at or above it, at most topk
+    // strictly above it.
+    if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) {
+        s_threshold_bin_id = tx;
+        s_num_input[0] = 0;
+        s_counter = 0;
+    }
+    __syncthreads();
+
+    const auto threshold_bin = s_threshold_bin_id;
+    // topk is now the number of elements still needed from the threshold bin.
+    topk -= s_histogram[threshold_bin + 1];
+
+    if (topk == 0) {
+        for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
+            const auto bin = static_cast(convert_to_uint8(vortex_to_float(input[idx + row_start])));
+            if (bin > threshold_bin) {
+                const auto pos = ::atomicAdd(&s_counter, 1);
+                index[pos] = idx;
+            }
+        }
+        __syncthreads();
+        return;
+    } else {
+        __syncthreads();
+        if (tx < RADIX + 1) s_histogram[tx] = 0;
+        __syncthreads();
+
+        for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
+            const auto raw_input = vortex_to_float(input[idx + row_start]);
+            const auto bin = static_cast(convert_to_uint8(raw_input));
+            if (bin > threshold_bin) {
+                const auto pos = ::atomicAdd(&s_counter, 1);
+                index[pos] = idx;
+            } else if (bin == threshold_bin) {
+                const auto pos = ::atomicAdd(&s_num_input[0], 1);
+                // Candidates beyond SMEM_INPUT_SIZE are counted but neither
+                // stored nor histogrammed.
+                if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) {
+                    s_input_idx[0][pos] = idx;
+                    const auto b32 = convert_to_uint32(raw_input);
+                    const auto sub_bin = (b32 >> 24) & 0xFF;
+                    ::atomicAdd(&s_histogram[sub_bin], 1);
+                }
+            }
+        }
+        __syncthreads();
+    }
+
+    // stage 2: refine with 8-bit radix passes on raw fp32 bits
+    // Round r looks at byte (24 - 8 * r) of the 32-bit key and alternates
+    // between the two candidate buffers.
+#pragma unroll 4
+    for (int round = 0; round < 4; ++round) {
+        __shared__ int s_last_remain;
+        const auto r_idx = round % 2;
+
+        // Clamp to the buffer capacity; s_num_input may have counted more.
+        const auto _raw_num_input = s_num_input[r_idx];
+        const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input : int(SMEM_INPUT_SIZE);
+
+        run_cumsum();
+        if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) {
+            s_threshold_bin_id = tx;
+            s_num_input[r_idx ^ 1] = 0;
+            s_last_remain = topk - s_histogram[tx + 1];
+        }
+        __syncthreads();
+
+        const auto threshold_bin = s_threshold_bin_id;
+        topk -= s_histogram[threshold_bin + 1];
+
+        if (topk == 0) {
+            for (int i = tx; i < num_input; i += BLOCK_SIZE) {
+                const auto idx = s_input_idx[r_idx][i];
+                const auto offset = 24 - round * 8;
+                const auto bin = (convert_to_uint32(vortex_to_float(input[idx + row_start])) >> offset) & 0xFF;
+                if (bin > threshold_bin) {
+                    const auto pos = ::atomicAdd(&s_counter, 1);
+                    index[pos] = idx;
+                }
+            }
+            __syncthreads();
+            break;
+        } else {
+            __syncthreads();
+            if (tx < RADIX + 1) s_histogram[tx] = 0;
+            __syncthreads();
+            for (int i = tx; i < num_input; i += BLOCK_SIZE) {
+                const auto idx = s_input_idx[r_idx][i];
+                const auto raw_input = vortex_to_float(input[idx + row_start]);
+                const auto offset = 24 - round * 8;
+                const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF;
+                if (bin > threshold_bin) {
+                    const auto pos = ::atomicAdd(&s_counter, 1);
+                    index[pos] = idx;
+                } else if (bin == threshold_bin) {
+                    if (round == 3) {
+                        // Last byte: remaining ties fill the tail of the
+                        // output; s_last_remain counts down the free slots.
+                        const auto pos = ::atomicAdd(&s_last_remain, -1);
+                        if (pos > 0) {
+                            index[target_k - pos] = idx;
+                        }
+                    } else {
+                        const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1);
+                        if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) {
+                            s_input_idx[r_idx ^ 1][pos] = idx;
+                            const auto b32 = convert_to_uint32(raw_input);
+                            const auto sub_bin = (b32 >> (offset - 8)) & 0xFF;
+                            ::atomicAdd(&s_histogram[sub_bin], 1);
+                        }
+                    }
+                }
+            }
+            __syncthreads();
+        }
+    }
+}
+
+// ======================================================================
+// Templated fused kernel: apply_transform(score) -> convert_to_uint8
+// is fused into Stage 1. Stage 2 still uses raw bits for tie-breaking
+// (on the *remapped* value, not the original score) — this is a
+// benchmarking kernel, the remapped Stage-2 ordering is acceptable.
+// No pre-pass, no LUT, no shared-memory mapping state.
+// ======================================================================
+template
+__device__ void fast_topk_clean_fused(
+    const ScoreT* __restrict__ input,
+    int* __restrict__ index,
+    int row_start,
+    int length,
+    int target_k,
+    const TopKMappingParams mapping)
+{
+    int topk = target_k;
+    constexpr auto BLOCK_SIZE = 1024;
+    constexpr auto RADIX = 256;
+    constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int));
+
+    alignas(128) __shared__ int f_histogram_buf[2][RADIX + 128];
+    alignas(128) __shared__ int f_counter;
+    alignas(128) __shared__ int f_threshold_bin_id;
+    alignas(128) __shared__ int f_num_input[2];
+
+    // Per-element Stage-1 bin cache. Pass 1 of Stage 1 writes one byte per
+    // element; pass 2 reads it back so each element only pays a single
+    // apply_transform + global score read instead of two.
+ // + // s_bins lives in DYNAMIC shared memory, placed immediately after the + // f_input_idx[2][SMEM_INPUT_SIZE] 2D array in the same extern __shared__ + // region. The host launch reserves `kSmem + max_num_pages` dynamic bytes + // (see `topk_output_sglang_fused`) so every block has `max_num_pages` + // bytes available past f_input_idx's 32 KB span. Per-block `length` + // (from dense_kv_indptr) is ≤ max_num_pages, so indexing stays in bounds. + // + // This layout keeps smem usage at kSmem + 4 KB for the existing + // pages_per_seg ≤ 4096 regimes (identical to the old 32 KB dynamic + + // 4 KB static) and only grows when the caller asks for a larger + // pages_per_seg — no occupancy regression on small configs. + + auto& f_histogram = f_histogram_buf[0]; + extern __shared__ int f_input_idx[][SMEM_INPUT_SIZE]; + uint8_t* const s_bins = reinterpret_cast(&f_input_idx[2][0]); + + const int tx = threadIdx.x; + + // MODE is a compile-time template parameter, so the transform applied + // below is selected at compile time. The dense bucket path + // (MAPPING_DENSE_MANT) has been retired: use_dense_bucket is hard-wired + // to false, so every `if constexpr (use_dense_bucket)` branch below is + // elided and all modes use the fp16 bucket. LUT_CDF / QUANTILE are not + // supported by this templated kernel (they were dropped from the bench + // comparison earlier). + // With use_dense_bucket == false, Stage 1 always bins via convert_to_uint8. + constexpr bool use_dense_bucket = false; + + if (tx < RADIX + 1) f_histogram[tx] = 0; + __syncthreads(); + + // Stage 1 pass 1: read each score from global, compute the Stage-1 + // bin via the compile-time-dispatched transform, cache it in s_bins so + // pass 2 can skip the second global read. With MODE known at compile + // time, apply_transform_tmpl inlines to just the chosen + // transform's instructions — no runtime switch overhead. 
+ for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform_tmpl(raw, mapping.power_exp); + int bin; + if constexpr (use_dense_bucket) { + bin = static_cast(convert_to_uint8_dense(remapped)); + } else { + bin = static_cast(convert_to_uint8(remapped)); + } + s_bins[idx] = static_cast(bin); + ::atomicAdd(&f_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = f_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += f_histogram_buf[k][tx + j]; + } + f_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && f_histogram[tx] > topk && f_histogram[tx + 1] <= topk) { + f_threshold_bin_id = tx; + f_num_input[0] = 0; + f_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = f_threshold_bin_id; + topk -= f_histogram[threshold_bin + 1]; + + if (topk == 0) { + // Shortcut: every page above threshold gets selected. Read the bin + // from the cache so we don't re-touch global memory or recompute + // apply_transform. + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const int bin = static_cast(s_bins[idx]); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&f_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) f_histogram[tx] = 0; + __syncthreads(); + + // Stage 1 pass 2: read the cached bin from SMEM. For elements + // outside the threshold bin we skip the global-memory load AND the + // apply_transform call entirely. Only the ~thr_size threshold-bin + // candidates re-read raw and re-apply the templated transform to + // compute the sub-bin needed for Stage-2 refinement. 
+ // + // Sub-bin shift selection (compile-time constant): + // - standard modes: Stage-1 used fp16 top-8-bit bucketing, so + // Stage-2 round 0 refines on uint32 bits [31:24] (the most + // significant bits not captured by the fp16 bucket). + // - MAPPING_DENSE_MANT: Stage-1 used bits [23:16], so the next + // useful discriminator is bits [15:8]. Skipping to offset 8 + // directly avoids two wasted Stage-2 rounds. + constexpr int sub_bin_offset_start = use_dense_bucket ? 8 : 24; + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const int bin = static_cast(s_bins[idx]); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&f_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform_tmpl(raw, mapping.power_exp); + const auto pos = ::atomicAdd(&f_num_input[0], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + f_input_idx[0][pos] = idx; + const auto b32 = convert_to_uint32(remapped); + const auto sub_bin = (b32 >> sub_bin_offset_start) & 0xFF; + ::atomicAdd(&f_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // stage 2: refine on raw bits of the remapped value. The per-round + // bit offset matches the sub_bin shift chosen above: standard modes + // start at offset 24 (bits [31:24]) and step down by 8 per round; + // MAPPING_DENSE_MANT starts at offset 8 (bits [15:8]) because Stage 1 + // already consumed bits [23:16] in the dense bucket. Both values are + // compile-time constants since MODE is a template parameter. + constexpr int stage2_offset_start = use_dense_bucket ? 8 : 24; + constexpr int stage2_max_rounds = use_dense_bucket ? 2 : 4; +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + if (round >= stage2_max_rounds) break; + __shared__ int f_last_remain; + const auto r_idx = round % 2; + + const auto _raw_num_input = f_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? 
_raw_num_input : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && f_histogram[tx] > topk && f_histogram[tx + 1] <= topk) { + f_threshold_bin_id = tx; + f_num_input[r_idx ^ 1] = 0; + f_last_remain = topk - f_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = f_threshold_bin_id; + topk -= f_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = f_input_idx[r_idx][i]; + const auto offset = stage2_offset_start - round * 8; + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform_tmpl(raw, mapping.power_exp); + const auto bin = (convert_to_uint32(remapped) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&f_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) f_histogram[tx] = 0; + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = f_input_idx[r_idx][i]; + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform_tmpl(raw, mapping.power_exp); + const auto offset = stage2_offset_start - round * 8; + const auto bin = (convert_to_uint32(remapped) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&f_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + // Last refinement round: we have no more discriminator bits + // below the current offset, so emit any remaining elements as + // "tie-break fallback" via f_last_remain (ensures topk is met + // even when thr_size > sel_thr at the finest granularity). 
+ if (round == stage2_max_rounds - 1) { + const auto pos = ::atomicAdd(&f_last_remain, -1); + if (pos > 0) { + index[target_k - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&f_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + f_input_idx[r_idx ^ 1][pos] = idx; + const auto b32 = convert_to_uint32(remapped); + const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; + ::atomicAdd(&f_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + +// Wrapper kernels: one CUDA block per (batch*head) segment. + +template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKOutput_Clean_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ dense_kv_indices, + int* __restrict__ sparse_kv_indices, + const int topk_val, + const int page_reserved_bos, + const int page_reserved_eos) +{ + const int bx = blockIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; + + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_clean(score_blk, s_indices, 0, nblk, topk_val); + __syncthreads(); + + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } +} + +template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKOutput_Fused_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ dense_kv_indices, + int* __restrict__ sparse_kv_indices, + const int topk_val, + const int page_reserved_bos, + const int 
page_reserved_eos, + const TopKMappingParams mapping) +{ + const int bx = blockIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; + + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_clean_fused(score_blk, s_indices, 0, nblk, topk_val, mapping); + __syncthreads(); + + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } +} + +// Inverse of vortex_to_float: narrow a float back to ScoreT for the +// bf16-output remap path so the subsequent topk kernel can read half +// the bytes of a fp32 remapped buffer. +template +__device__ __forceinline__ T float_to_vortex(float x); +template <> +__device__ __forceinline__ float float_to_vortex(float x) { return x; } +template <> +__device__ __forceinline__ __nv_bfloat16 float_to_vortex<__nv_bfloat16>(float x) { + return __float2bfloat16(x); +} + +// Remap-only kernel: applies the element-wise transform to each score +// in the [dense_kv_indptr[b] + reserved_bos, dense_kv_indptr[b+1] - reserved_eos) +// range and writes the result into an output tensor (OutT = float or +// bf16). Used by the split-phase benchmark (remap → unmapped topk). +// Writing bf16 halves memory bandwidth on the output and on the +// subsequent topk read; precision-wise the Stage-1 8-bit bucket is +// unaffected (up to bf16 rounding) because bf16 retains more mantissa +// bits than the bucket uses. 
+template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKRemapOnly_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + OutT* __restrict__ remapped, + const int page_reserved_bos, + const int page_reserved_eos, + const TopKMappingParams mapping) +{ + const int bx = blockIdx.x; + const int tx = threadIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= 0) return; + + const ScoreT* __restrict__ score_blk = score + start; + OutT* __restrict__ remap_blk = remapped + start; + + for (int i = tx; i < nblk; i += kThreadsPerBlock) { + const float y = apply_transform(vortex_to_float(score_blk[i]), mapping); + remap_blk[i] = float_to_vortex(y); + } +} + +} // namespace + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + +void fast_topk_interface( + const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths, std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(indices); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + CHECK_CUDA(lengths); + const auto params = get_params(score, lengths, row_starts_opt, indices); + const auto B = score.size(0); + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + setup_kernel_smem_once(); + topk_kernel<<>>(params); + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); +} + +void fast_topk_transform_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& dst_page_table, + const at::Tensor& src_page_table, + const at::Tensor& cu_seqlens_q, + std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(dst_page_table); + CHECK_CUDA(src_page_table); + 
CHECK_CUDA(cu_seqlens_q); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous()); + TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1); + TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_q.is_contiguous()); + const auto prefill_bs = cu_seqlens_q.size(0) - 1; + TORCH_CHECK(dst_page_table.size(0) == B); + TORCH_CHECK(dst_page_table.size(1) == TopK); + TORCH_CHECK(src_page_table.size(0) == prefill_bs); + TORCH_CHECK(prefill_bs <= B); // prefill_bs should be smaller than expanded bs + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + const auto src_stride = src_page_table.stride(0); + + // dispatch to decode or prefill + // extend and draft extend: row_starts_opt is not null, invokes the prefill kernel + // decode: row_starts_opt is null, invokes the decode kernel + // target verify: row_starts_opt is null, invokes the prefill kernel + const auto is_decode = !row_starts_opt.has_value() && prefill_bs == B; + if (is_decode) { + setup_kernel_smem_once(); + topk_transform_decode_kernel<<>>( + params, dst_page_table.data_ptr(), src_page_table.data_ptr(), src_stride); + } else { + setup_kernel_smem_once(); + topk_transform_prefill_kernel<<>>( + params, + dst_page_table.data_ptr(), + src_page_table.data_ptr(), + src_stride, + cu_seqlens_q.data_ptr(), + prefill_bs); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); +} + +void fast_topk_transform_ragged_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& topk_indices_ragged, + const at::Tensor& topk_indices_offset, + std::optional row_starts_opt) { + CHECK_CUDA(score); + 
CHECK_CUDA(lengths); + CHECK_CUDA(topk_indices_ragged); + CHECK_CUDA(topk_indices_offset); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(topk_indices_ragged.dim() == 2 && topk_indices_ragged.is_contiguous()); + TORCH_CHECK(topk_indices_offset.dim() == 1); + + TORCH_CHECK(topk_indices_ragged.size(0) == B); + TORCH_CHECK(topk_indices_ragged.size(1) == TopK); + TORCH_CHECK(topk_indices_offset.size(0) == B); + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + + setup_kernel_smem_once(); + topk_transform_prefill_ragged_kernel<<>>( + params, topk_indices_ragged.data_ptr(), topk_indices_offset.data_ptr()); + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); +} + +// ====================================================================== +// Vortex host entry point — unmapped baseline topk (no remap). +// This is the "original topk kernel" used as the benchmarking baseline. 
+// ====================================================================== +void topk_output_sglang( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const at::Tensor& sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages) +{ + TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, + "topk_output: topk_val (", topk_val, + ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + + CHECK_CUDA(x); + CHECK_CUDA(dense_kv_indptr); + CHECK_CUDA(sparse_kv_indptr); + CHECK_CUDA(dense_kv_indices); + CHECK_CUDA(sparse_kv_indices); + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + if (x.scalar_type() == at::ScalarType::BFloat16) { + setup_kernel_smem_once, kSmem>(); + TopKOutput_Clean_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, reserved_bos, reserved_eos); + } else if (x.scalar_type() == at::ScalarType::Float) { + setup_kernel_smem_once, kSmem>(); + TopKOutput_Clean_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, reserved_bos, reserved_eos); + } else { + TORCH_CHECK(false, "topk_output: unsupported dtype ", x.scalar_type()); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_output kernel failed: ", ::cudaGetErrorString(result)); +} + +// ====================================================================== +// Fused remap + topk host entry. Applies apply_transform(score, mapping) +// inline inside the Stage-1 histogram build — single kernel launch, +// single pass over the score tensor. 
+// ====================================================================== +void topk_output_sglang_fused( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const at::Tensor& sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages, + const int64_t mapping_mode, + const double mapping_power, + std::optional mapping_lut, + std::optional mapping_quantiles) +{ + TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, + "topk_output_sglang_fused: topk_val (", topk_val, + ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + + // Dynamic-smem layout for the fused kernel: + // [ f_input_idx (2 × SMEM_INPUT_SIZE × sizeof(int) = kSmem bytes) + // s_bins (bins_bytes = align_up(max_num_pages, 16)) ] + // The per-launch smem request equals the total of both. It must fit + // under kFusedSmemMax, which setup_kernel_smem_once opted this kernel + // into via cudaFuncSetAttribute(MaxDynamicSharedMemorySize, ...). + const size_t bins_bytes = (static_cast(max_num_pages) + size_t(15)) & ~size_t(15); + const size_t smem_bytes = kSmem + bins_bytes; + TORCH_CHECK(smem_bytes <= kFusedSmemMax, + "topk_output_sglang_fused: max_num_pages (", max_num_pages, + ") exceeds the fused kernel's dynamic smem ceiling. " + "Requested smem=", smem_bytes, " bytes, ceiling=", kFusedSmemMax, + " bytes. Raise kFusedSmemMax (and verify GPU opt-in limits) or " + "reduce pages_per_seg."); + + // The `mapping_lut` / `mapping_quantiles` optional tensors are + // retained in the pybind signature for API backward compatibility + // but are ignored: the templated fused kernel drops the LUT_CDF / + // QUANTILE code paths entirely. 
+ (void)mapping_lut; + (void)mapping_quantiles; + + CHECK_CUDA(x); + CHECK_CUDA(dense_kv_indptr); + CHECK_CUDA(sparse_kv_indptr); + CHECK_CUDA(dense_kv_indices); + CHECK_CUDA(sparse_kv_indices); + + TopKMappingParams mapping{}; + mapping.mode = static_cast(mapping_mode); + mapping.power_exp = static_cast(mapping_power); + mapping.lut = nullptr; + mapping.quantiles = nullptr; + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + // Each mapping mode compiles to its own kernel specialization so + // apply_transform_tmpl is fully inlined (no runtime switch on + // mode in the inner loop). The wrapper's outer dispatch is a one- + // time per-call cost, negligible relative to the kernel runtime. + #define VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MODE_VAL) \ + do { \ + setup_kernel_smem_once, kFusedSmemMax>(); \ + TopKOutput_Fused_Kernel<<>>( \ + PTR_EXPR, \ + dense_kv_indptr.data_ptr(), \ + sparse_kv_indptr.data_ptr(), \ + dense_kv_indices.data_ptr(), \ + sparse_kv_indices.data_ptr(), \ + topk_val, reserved_bos, reserved_eos, mapping); \ + } while (0) + + #define VORTEX_DISPATCH_MODE(DTYPE, PTR_EXPR) \ + do { \ + switch (mapping.mode) { \ + case MAPPING_NONE: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_NONE); break; \ + case MAPPING_POWER: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_POWER); break; \ + case MAPPING_LOG: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_LOG); break; \ + case MAPPING_ASINH: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_ASINH); break; \ + case MAPPING_LOG1P: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_LOG1P); break; \ + case MAPPING_TRUNC8: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_TRUNC8); break; \ + case MAPPING_ERF: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_ERF); break; \ + case MAPPING_TANH: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_TANH); break; \ + case MAPPING_SUBTRACT: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_SUBTRACT); 
break; \ + case MAPPING_EXP_STRETCH: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_EXP_STRETCH); break; \ + case MAPPING_SHIFT_POW2: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_SHIFT_POW2); break; \ + case MAPPING_SHIFT_POW3: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_SHIFT_POW3); break; \ + case MAPPING_LINEAR_STEEP:VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_LINEAR_STEEP); break; \ + case MAPPING_HALF_SQUARE: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_HALF_SQUARE); break; \ + case MAPPING_HALF_CUBE: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_HALF_CUBE); break; \ + default: \ + TORCH_CHECK(false, "topk_output_sglang_fused: unsupported mapping_mode ", mapping.mode); \ + } \ + } while (0) + + if (x.scalar_type() == at::ScalarType::BFloat16) { + VORTEX_DISPATCH_MODE(__nv_bfloat16, reinterpret_cast<__nv_bfloat16*>(x.data_ptr())); + } else if (x.scalar_type() == at::ScalarType::Float) { + VORTEX_DISPATCH_MODE(float, x.data_ptr()); + } else { + TORCH_CHECK(false, "topk_output_sglang_fused: unsupported dtype ", x.scalar_type()); + } + + #undef VORTEX_DISPATCH_MODE + #undef VORTEX_DISPATCH_FUSED + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_output_sglang_fused kernel failed: ", ::cudaGetErrorString(result)); +} + +// ====================================================================== +// Standalone remap kernel. Writes apply_transform(score) into a +// float32 or bfloat16 output buffer without running topk. Used by the split-phase +// benchmark (remap → unmapped topk) to measure each phase independently. 
+// ====================================================================== +void topk_remap_only( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + at::Tensor& remapped, // float32 or bfloat16, same numel as x + const int64_t eff_batch_size, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t mapping_mode, + const double mapping_power) +{ + CHECK_CUDA(x); + CHECK_CUDA(dense_kv_indptr); + CHECK_CUDA(remapped); + TORCH_CHECK(remapped.scalar_type() == at::ScalarType::Float + || remapped.scalar_type() == at::ScalarType::BFloat16, + "remapped output must be float32 or bfloat16"); + + TopKMappingParams mapping{}; + mapping.mode = static_cast(mapping_mode); + mapping.power_exp = static_cast(mapping_power); + mapping.lut = nullptr; + mapping.quantiles = nullptr; + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + // Four-way dispatch on (input dtype, output dtype). bf16→bf16 is the + // new "batch pre-transform" path that halves memory bandwidth vs the + // fp32 output: the remap writes half the bytes and the subsequent + // topk_output_sglang reads half the bytes. Precision is preserved + // because Stage-1 bucketing only uses the top 8 bits of an fp16 key + // which both fp32 and bf16 capture. 
+ #define VORTEX_DISPATCH_REMAP(IN_CPP, OUT_CPP, IN_PTR_EXPR, OUT_PTR_EXPR) \ + TopKRemapOnly_Kernel<<>>( \ + IN_PTR_EXPR, dense_kv_indptr.data_ptr(), OUT_PTR_EXPR, \ + reserved_bos, reserved_eos, mapping) + + const bool in_bf16 = (x.scalar_type() == at::ScalarType::BFloat16); + const bool in_fp32 = (x.scalar_type() == at::ScalarType::Float); + const bool out_bf16 = (remapped.scalar_type() == at::ScalarType::BFloat16); + + if (in_bf16 && out_bf16) { + VORTEX_DISPATCH_REMAP(__nv_bfloat16, __nv_bfloat16, + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + reinterpret_cast<__nv_bfloat16*>(remapped.data_ptr())); + } else if (in_bf16 && !out_bf16) { + VORTEX_DISPATCH_REMAP(__nv_bfloat16, float, + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + remapped.data_ptr()); + } else if (in_fp32 && out_bf16) { + VORTEX_DISPATCH_REMAP(float, __nv_bfloat16, + x.data_ptr(), + reinterpret_cast<__nv_bfloat16*>(remapped.data_ptr())); + } else if (in_fp32 && !out_bf16) { + VORTEX_DISPATCH_REMAP(float, float, + x.data_ptr(), + remapped.data_ptr()); + } else { + TORCH_CHECK(false, "topk_remap_only: unsupported dtype ", x.scalar_type()); + } + + #undef VORTEX_DISPATCH_REMAP + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_remap_only kernel failed: ", ::cudaGetErrorString(result)); +} diff --git a/csrc/topk_sglang_merge.cu b/csrc/topk_sglang_merge.cu new file mode 100644 index 00000000..078b6d75 --- /dev/null +++ b/csrc/topk_sglang_merge.cu @@ -0,0 +1,1939 @@ +/** + * Vortex adaptive split TopK — random-split parallel K=30 path + fused + * fallback. Lives in topk_sglang_merge.cu (NOT a new file). + * + * Dispatch summary (host-side topk_output_adaptive_workspace): + * + * topk_val >= 1024 → immediate call to topk_output_sglang_fused. + * No workspace touched, no done_counter memset, no + * split kernel launched. Required for 32k → 2048. + * + * topk_val > 32 (and < 1024) → also forwards to fused (no specialised path). 
+ * + * topk_val <= 32: + * forced_splits > 0 → use that split count (1, 2, 4, 8, 16, 32). + * forced_splits <= 0 → use heuristic pick_split_top30(). + * split == 1 (heuristic only) → fall back to fused. + * split == 1 (forced) → run the SPLITS=1 single-CTA path + * for benchmarking (one CUDA block sorts + * the whole row with cub::BlockRadixSort). + * + * Random split semantics: each split processes ONLY its slice of the row, + * not the whole row filtered by predicate, so total work = O(n) not O(n*S). + * + * group_begin = (n * split_id) / SPLITS + * group_end = (n * (split_id+1)) / SPLITS + * For each logical rank r in [group_begin, group_end), the physical + * page-table position is `permute(r, b_offset, n)`. For pow2 n we use + * the affine bijection + * pos = (r * a + b_offset) & (n - 1) + * with a = golden-ratio constant 2654435769 (odd → bijective mod 2^k). + * Per-row b_offset = b * 1013904223 + r0, where r0 is a fixed seed + * for reproducibility. For non-pow2 n we fall back to the contiguous + * mapping pos = r (the chunks then become consecutive slices). + * + * Local stage: cub::BlockRadixSort::SortDescending. Writes the top kLocalK=32 (key, idx) pairs to + * partial workspace per (row, split). + * + * Merge stage (last CTA, SPLITS > 1): cub::WarpMergeSort over SPLITS*32 + * candidates. kMergeIPT = SPLITS items per thread; the 32 warp lanes + * together hold all SPLITS*32 candidates. Each thread's SPLITS items are + * a contiguous descending-sorted slice of the workspace (one split per + * kMergeIPT items), so the WarpMergeSort precondition is satisfied. + * After the sort, threads write their items to out_idx at the correct + * global rank; only ranks < topk_val are written. + * Final top-topk_val global page IDs land in + * sparse_kv_indices[sparse_kv_indptr[b] + reserved_bos + rank]. + * + * done_counter is the external workspace; the host clears it with + * cudaMemsetAsync before each parallel-path launch. 
The fused-fallback + * branches do not touch it. + */ + + #include + #include + #include + #include + #include + #include + #include + #include + #include + #include + + #include + #include + #include + + #include + #include + #include + #include + + #include "register.h" + + namespace { + + constexpr int kLocalK_Top30 = 32; // local top-K per chunk + constexpr int kMaxFinalK_Top30 = 32; // accept topk_val up to this + constexpr int64_t kFusedFallbackTopK = 1024; // K >= this routes to fused + + // ============================================================================= + // Local-stage policy for the K=30 split kernel. + // + // BLOCK_FULL_SORT : per-CTA cub::BlockRadixSort over the whole split group, + // capped by the NT*IPT capacity ladder in kCfg* below. + // Original baseline kernel. + // + // SELECT32_SORT32 : per-CTA sglang-style 8-bit radix-select that emits + // exactly LOCAL_K=32 candidates without sorting the + // whole group, followed by a 32-element warp bitonic + // sort (cub::WarpMergeSort with IPT=1). Inner loops + // are strided over the group, so there is no NT*IPT + // ceiling and arbitrary chunk_len is supported. + // + // Both modes share the merge stage: each CTA writes a sorted local top-32 + // to partial workspace, the last CTA per row runs merge_cub_warp_topk. + // ============================================================================= + enum TopK30LocalMode : int { + LOCAL_BLOCK_FULL_SORT = 0, + LOCAL_SELECT32_SORT32 = 1, + }; + + // Affine permutation constants (LCG-style). a is odd → bijective mod 2^k. + constexpr uint32_t kPermuteA = 2654435769u; // golden ratio fractional bits + constexpr uint32_t kPermuteSeedB = 1013904223u; + constexpr uint32_t kPermuteOffset = 0x9E3779B9u; // additional offset + + // ---- bit-level helpers ------------------------------------------------------ + + // Sortable uint32 key for an fp32 value: ascending uint32 == ascending fp32. 
+ __device__ __forceinline__ uint32_t convert_to_uint32(float x) { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); + } + + template + __device__ __forceinline__ float vortex_to_float(T x); + template <> + __device__ __forceinline__ float vortex_to_float(float x) { return x; } + template <> + __device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) { + return __bfloat162float(x); + } + + // Stage-1 8-bit bin used by topk_mapping.cuh's compute_stage1_bin. Defined + // here so the header pulls in cleanly, even though we don't otherwise use it. + __device__ __forceinline__ uint8_t convert_to_uint8(float x) { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) + : static_cast(bits | 0x8000); + return static_cast(key >> 8); + } + + #include "topk_mapping.cuh" + + // Affine permutation modulo 2^k, bijective when n is a power of two. + __device__ __forceinline__ int permute_pow2(uint32_t r, uint32_t b_off, uint32_t n_mask) { + return static_cast((r * kPermuteA + b_off) & n_mask); + } + + // True iff n is a strictly positive power of two. + __device__ __host__ __forceinline__ bool is_pow2(int n) { + return n > 0 && ((n & (n - 1)) == 0); + } + + // ============================================================================= + // Partition modes — control how a row's logical-rank space [0, n) is mapped + // to physical positions per split CTA. Goal: keep total work O(n) (no + // per-split full scan) while controlling memory access locality. + // + // AFFINE_RANDOM : pos = (a*r + b_off) & (n-1) [random gather] + // CONTIGUOUS : pos = group_begin + local_rank [coalesced] + // STRIDED : pos = split_id + local_rank * SPLITS [interleaved] + // TILE_RANDOM_128 : tile-permute then read TILE=128 contiguous positions + // within each tile. + // TILE_RANDOM_256 : same as 128 but with TILE=256. 
// =============================================================================
enum PartitionMode : int {
    PART_AFFINE_RANDOM   = 0,
    PART_CONTIGUOUS      = 1,
    PART_STRIDED         = 2,
    PART_TILE_RANDOM_128 = 3,
    PART_TILE_RANDOM_256 = 4,
};
constexpr int kTileSize128 = 128;
constexpr int kTileSize256 = 256;

// Tile-random: divide row into TILE-sized contiguous tiles, permute the tile
// IDs across the row using the affine bijection, and assign tiles_per_split
// = chunk_len / TILE tiles to each split. Within a tile, reads are
// contiguous → coalesced 128B / 256B sectors.
//
// TILE must be a power of two (the in-tile offset is taken with a bit mask),
// and the tile permutation is only a bijection when row_len / TILE is a power
// of two as well.
//
// NOTE(review): the template parameter list was reconstructed from how the
// body uses SPLITS and TILE — confirm the parameter order against call sites.
template <int SPLITS, int TILE>
__device__ __forceinline__ int tile_random_pos(
    int local_rank, int row_len, int split_id,
    uint32_t b_off, uint32_t n_mask)
{
    const int chunk_len = row_len / SPLITS;
    if (chunk_len < TILE) {
        // Fallback to affine when tiles don't fit.
        const int group_begin = (row_len * split_id) / SPLITS;
        const int r           = group_begin + local_rank;
        return permute_pow2(static_cast<uint32_t>(r), b_off, n_mask);
    }
    const int tiles_per_split  = chunk_len / TILE;
    const int tile_in_split    = local_rank / TILE;
    const int offset_in_tile   = local_rank & (TILE - 1);
    const int global_tile_rank = split_id * tiles_per_split + tile_in_split;
    const int tile_count       = row_len / TILE;
    const uint32_t tile_mask   = static_cast<uint32_t>(tile_count - 1);
    const uint32_t tile_id =
        (static_cast<uint32_t>(global_tile_rank) * kPermuteA + b_off) & tile_mask;
    return static_cast<int>(tile_id) * TILE + offset_in_tile;
}

// Map a split-local rank to a physical position inside the row, according to
// the compile-time PARTITION mode.
//
//   local_rank : rank inside this split's group, in [0, group_len)
//   row_len    : number of usable positions in the row
//   split_id   : which split (CTA) is asking, in [0, SPLITS)
//   b_off      : per-row additive constant of the affine permutation
//   n_mask     : row_len - 1; only meaningful when row_len is a power of two
//
// NOTE(review): the template parameter list was reconstructed from how the
// body uses SPLITS and PARTITION — confirm the parameter order against call
// sites.
template <int SPLITS, int PARTITION>
__device__ __forceinline__ int compute_pos(
    int local_rank, int row_len, int split_id,
    uint32_t b_off, uint32_t n_mask)
{
    if constexpr (PARTITION == PART_CONTIGUOUS) {
        const int group_begin = (row_len * split_id) / SPLITS;
        return group_begin + local_rank;
    } else if constexpr (PARTITION == PART_STRIDED) {
        // Each split owns lanes [split_id, split_id+SPLITS, split_id+2*SPLITS, ...].
        // Across all splits, the union covers every position in [0, row_len)
        // exactly once when row_len is divisible by SPLITS.
        //
        // When it is NOT divisible (e.g. a power-of-two row shorter than
        // SPLITS), the interleaved formula can land at or beyond row_len:
        // row_len = 2, SPLITS = 4, split_id = 3, local_rank = 0 gives pos = 3.
        // Use the contiguous slice instead, which matches the group
        // boundaries the kernels derive from (row_len * n) / SPLITS and is
        // always in range.
        if (row_len % SPLITS != 0) {
            const int group_begin = (row_len * split_id) / SPLITS;
            return group_begin + local_rank;
        }
        return split_id + local_rank * SPLITS;
    } else if constexpr (PARTITION == PART_TILE_RANDOM_128) {
        return tile_random_pos<SPLITS, kTileSize128>(
            local_rank, row_len, split_id, b_off, n_mask);
    } else if constexpr (PARTITION == PART_TILE_RANDOM_256) {
        return tile_random_pos<SPLITS, kTileSize256>(
            local_rank, row_len, split_id, b_off, n_mask);
    } else {
        // AFFINE_RANDOM (default).
        const int group_begin = (row_len * split_id) / SPLITS;
        const int r           = group_begin + local_rank;
        return permute_pow2(static_cast<uint32_t>(r), b_off, n_mask);
    }
}

// =============================================================================
// Descending comparator for cub::WarpMergeSort (sorts largest key first).
// =============================================================================
struct DescendingUint32 {
    __device__ __forceinline__ bool operator()(uint32_t a, uint32_t b) const {
        return a > b;
    }
};

// =============================================================================
// Single-warp CUB merge of SPLITS sorted top-LOCAL_K lists.
//
// Workspace layout (keys_in / idx_in): SPLITS * LOCAL_K elements, split-major
// with each split's LOCAL_K entries sorted descending. With LOCAL_K=32 and
// kMergeIPT = SPLITS (= SPLITS*32 / 32), thread tx holds exactly SPLITS
// consecutive items starting at tx*SPLITS — always a contiguous sorted slice
// within a single split's list. cub::WarpMergeSort precondition is satisfied.
//
// After the sort the global top-final_k indices are written to out_idx[0..final_k-1]
// by the threads that own those ranks; no lane conflicts.
//
// Register pressure: kMergeIPT = SPLITS. For SPLITS=32 each lane holds 32
// key+value pairs (32 x (4 B key + 4 B index) = 256 B of registers).
// Acceptable for sm_90+.
+ // ============================================================================= + template + __device__ __forceinline__ void merge_cub_warp_topk( + const uint32_t* __restrict__ keys_in, + const int32_t* __restrict__ idx_in, + int32_t* __restrict__ out_idx, + int final_k) + { + constexpr int kCandidates = SPLITS * LOCAL_K; + constexpr int kMergeIPT = (kCandidates + 31) / 32; + using WarpMergeT = cub::WarpMergeSort; + __shared__ typename WarpMergeT::TempStorage warp_merge_smem; + const int tx = threadIdx.x; + if (tx < 32) { + uint32_t wkeys[kMergeIPT]; + int32_t wvals[kMergeIPT]; + #pragma unroll + for (int k = 0; k < kMergeIPT; ++k) { + const int rank = tx * kMergeIPT + k; + wkeys[k] = (rank < kCandidates) ? keys_in[rank] : 0u; + wvals[k] = (rank < kCandidates) ? idx_in [rank] : -1; + } + WarpMergeT(warp_merge_smem).Sort(wkeys, wvals, DescendingUint32{}); + #pragma unroll + for (int k = 0; k < kMergeIPT; ++k) { + const int rank = tx * kMergeIPT + k; + if (rank < final_k && wvals[k] >= 0) out_idx[rank] = wvals[k]; + } + } + } + + template + inline void setup_kernel_smem_once() { + [[maybe_unused]] + static const auto result = [] { + return ::cudaFuncSetAttribute( + f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); + }(); + TORCH_CHECK(result == cudaSuccess, + "topk_output_adaptive setup failed: ", + ::cudaGetErrorString(result)); + } + + // ============================================================================= + // K=30 random-split parallel kernel. + // + // Grid: (eff_batch_size, SPLITS). + // blockIdx.x = effective row id (0..eff_batch_size-1) + // blockIdx.y = split id (0..SPLITS-1) + // + // Stage 1 (every CTA): + // - Compute group_begin/group_end for this split. + // - For each local rank in [0, group_len), compute physical pos via + // permute_pow2 (or contiguous fallback for non-pow2 n). + // - Apply apply_transform_tmpl, build (uint32_key, int32_global_idx). + // - cub::BlockRadixSort.SortDescending. Top items at start of array. 
+ // + // For SPLITS == 1 the kernel writes the top topk_val directly to + // sparse_kv_indices and returns — no merge. + // + // For SPLITS > 1, write the top kLocalK=32 (key, idx) pairs to the + // partial workspace at offset (b*SPLITS + n)*kLocalK. + // + // Last-CTA-wins barrier (SPLITS > 1): + // __threadfence (release) → atomicAdd → if old == SPLITS-1, last CTA → + // __threadfence (acquire) → __syncthreads. + // + // Stage 2 (last CTA, SPLITS > 1): + // - Load SPLITS*32 candidates into one warp / one block. + // - Sort descending by uint32 key. + // - Lanes 0..topk_val-1 (or threads) write their item to sparse_kv_indices. + // ============================================================================= + template + __global__ __launch_bounds__(NUM_THREADS) + void TopK30_RandomSplit_Parallel_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ dense_kv_indices, + int* __restrict__ sparse_kv_indices, + uint32_t* __restrict__ partial_keys, + int32_t* __restrict__ partial_indices, + int32_t* __restrict__ done_counter, + const int topk_val, + const int reserved_bos, + const int reserved_eos, + const float mapping_power) + { + using KeyT = uint32_t; + using ValueT = int32_t; + using BlockSortT = cub::BlockRadixSort; + + constexpr int kLocalK = kLocalK_Top30; + + __shared__ typename BlockSortT::TempStorage sort_smem; + __shared__ int s_is_last; + + const int b = blockIdx.x; + const int n = blockIdx.y; + const int tx = threadIdx.x; + + const int row_start = dense_kv_indptr[b] + reserved_bos; + const int row_end = dense_kv_indptr[b + 1] - reserved_eos; + const int row_len = max(0, row_end - row_start); + + if (row_len <= 0) return; + + // --- Group boundaries (no overlap, no gaps across splits). 
--- + const int group_begin = (static_cast(row_len) * n) / SPLITS; + const int group_end = (static_cast(row_len) * (n + 1)) / SPLITS; + const int group_len = group_end - group_begin; + + // --- Permutation parameters. --- + // For pow2 row_len, use affine bijection mod row_len. For non-pow2, fall + // back to identity (chunks become consecutive slices). + const bool row_is_pow2 = is_pow2(row_len); + const uint32_t n_mask = row_is_pow2 ? static_cast(row_len - 1) : 0u; + const uint32_t b_off = + static_cast(b) * kPermuteSeedB + kPermuteOffset; + + const ScoreT* row_scores = score + row_start; + const int* row_idxmap = dense_kv_indices + row_start; + + // ------------------------------------------------------------------ Stage 1 + KeyT keys[ITEMS_PER_THREAD]; + ValueT values[ITEMS_PER_THREAD]; + + #pragma unroll + for (int k = 0; k < ITEMS_PER_THREAD; ++k) { + const int local_rank = tx + k * NUM_THREADS; + if (local_rank < group_len) { + int pos; + if (row_is_pow2) { + pos = compute_pos( + local_rank, row_len, n, b_off, n_mask); + } else { + // Non-pow2 fallback: contiguous slice (also semantically valid for + // partition=CONTIGUOUS). + pos = group_begin + local_rank; + } + const float raw = vortex_to_float(row_scores[pos]); + const float remapped = apply_transform_tmpl(raw, mapping_power); + keys [k] = convert_to_uint32(remapped); + values[k] = row_idxmap[pos]; + } else { + keys [k] = 0u; + values[k] = -1; + } + } + + BlockSortT(sort_smem).SortDescending(keys, values); + __syncthreads(); + + // SPLITS == 1 special case: write final output directly. No atomic, no merge. + if constexpr (SPLITS == 1) { + int32_t* out_idx = sparse_kv_indices + sparse_kv_indptr[b] + reserved_bos; + #pragma unroll + for (int k = 0; k < ITEMS_PER_THREAD; ++k) { + const int rank = tx * ITEMS_PER_THREAD + k; + if (rank < topk_val) out_idx[rank] = values[k]; + } + return; + } + + // SPLITS > 1: write local top kLocalK to partial workspace. 
+ const int64_t part_off = (static_cast(b) * SPLITS + n) * kLocalK; + uint32_t* part_keys = partial_keys + part_off; + int32_t* part_idx = partial_indices + part_off; + + #pragma unroll + for (int k = 0; k < ITEMS_PER_THREAD; ++k) { + const int rank = tx * ITEMS_PER_THREAD + k; + if (rank < kLocalK) { + part_keys[rank] = keys[k]; + part_idx [rank] = values[k]; + } + } + + // -------------------------------------------------------- Last-CTA barrier + __threadfence(); + __syncthreads(); + if (tx == 0) { + const int old = ::atomicAdd(&done_counter[b], 1); + s_is_last = (old == SPLITS - 1) ? 1 : 0; + // Self-reset: the last CTA clears its slot for the next launch. + // Eliminates the need for cudaMemsetAsync(done_counter) on the host — + // saves ~1-2 µs of CPU launch overhead per call. Same-stream kernels are + // sequenced, so the next launch sees done_counter[b] == 0. + if (s_is_last) done_counter[b] = 0; + } + __syncthreads(); + if (s_is_last == 0) return; + // Acquire fence: ensure the merging CTA observes other CTAs' partial writes. + __threadfence(); + __syncthreads(); + + // ------------------------------------------------------------------ Stage 2 + // cub::WarpMergeSort over all SPLITS*kLocalK candidates (warp 0 only). + // kMergeIPT = SPLITS items per lane; each lane's items are a contiguous + // sorted slice within a single split's list, satisfying WarpMergeSort's + // pre-sorted-per-thread precondition. + const int64_t row_off = static_cast(b) * SPLITS * kLocalK; + const uint32_t* keys_in = partial_keys + row_off; + const int32_t* idx_in = partial_indices + row_off; + int32_t* out_idx = sparse_kv_indices + sparse_kv_indptr[b] + reserved_bos; + + merge_cub_warp_topk( + keys_in, idx_in, out_idx, topk_val); + } + + // ============================================================================= + // K=30 SELECT32_SORT32 local-stage kernel (Plan C). + // + // Grid: (eff_batch_size, SPLITS). + // blockIdx.x = effective row id, blockIdx.y = split id. 
+ // + // Per-CTA pipeline: + // Pass 1 - top-byte (bits [31:24]) histogram + suffix-sum-descending, + // find the threshold bin where cumulative count crosses + // LOCAL_K=32 (unique by monotonicity). + // Pass 2 - re-scan the split group: items strictly above the threshold + // bin go straight into the candidate buffer (count is + // guaranteed < LOCAL_K). Items at the threshold bin contribute + // to a sub-bin (bits [23:16]) histogram. + // Pass 3 - find the sub-threshold bin in the sub-hist, then re-scan the + // threshold bin and gather (sub > sub_threshold) and + // (sub == sub_threshold) candidates into the remaining slots. + // Stage D - 32-lane warp bitonic sort over the LOCAL_K candidates, + // descending by uint32 key. Implemented via cub::WarpMergeSort + // with IPT=1 (sort precondition is trivially satisfied). + // + // SPLITS == 1: write top topk_val directly to sparse_kv_indices, no + // workspace, no atomic, no merge. + // SPLITS > 1: write sorted local top-LOCAL_K to partial workspace, the + // last CTA per row runs merge_cub_warp_topk. + // + // No cub::BlockRadixSort smem and no NT*IPT capacity ceiling. Each pass + // is a strided loop over [0, group_len) so the kernel handles any + // chunk length the splits produce, including the SPLITS=1 / 32k case. 
+ // ============================================================================= + template + __global__ __launch_bounds__(NUM_THREADS) + void TopK30_RandomSplit_Select32_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ dense_kv_indices, + int* __restrict__ sparse_kv_indices, + uint32_t* __restrict__ partial_keys, + int32_t* __restrict__ partial_indices, + int32_t* __restrict__ done_counter, + const int topk_val, + const int reserved_bos, + const int reserved_eos, + const float mapping_power) + { + constexpr int LOCAL_K = kLocalK_Top30; + constexpr int kRadix = 256; + + alignas(128) __shared__ int s_hist_buf[2][kRadix + 128]; + __shared__ int s_above_count; // count strictly above threshold_bin (pass 2) + __shared__ int s_thresh_above_count; // count (bin==t && sub>sub_t) (pass 3) + __shared__ int s_thresh_at_count; // count (bin==t && sub==sub_t) (pass 3, capped) + __shared__ int s_threshold_bin; + __shared__ int s_last_remain; + __shared__ int s_sub_threshold_bin; + __shared__ int s_sub_last_remain; + __shared__ int s_strictly_above_sub; + __shared__ uint32_t s_top_keys[LOCAL_K]; + __shared__ int32_t s_top_idx [LOCAL_K]; + __shared__ int s_is_last; + + using LocalSortT = cub::WarpMergeSort; + __shared__ typename LocalSortT::TempStorage local_sort_smem; + + const int b = blockIdx.x; + const int n = blockIdx.y; + const int tx = threadIdx.x; + + const int row_start = dense_kv_indptr[b] + reserved_bos; + const int row_end = dense_kv_indptr[b + 1] - reserved_eos; + const int row_len = max(0, row_end - row_start); + + const int group_begin = (static_cast(row_len) * n) / SPLITS; + const int group_end = (static_cast(row_len) * (n + 1)) / SPLITS; + const int group_len = group_end - group_begin; + + const bool row_is_pow2 = is_pow2(row_len); + const uint32_t n_mask = row_is_pow2 ? 
static_cast(row_len - 1) : 0u; + const uint32_t b_off = + static_cast(b) * kPermuteSeedB + kPermuteOffset; + const ScoreT* row_scores = score + row_start; + const int* row_idxmap = dense_kv_indices + row_start; + + // ---- Init shared state. Strided over the +128 padding so any NT works. ---- + for (int i = tx; i < kRadix + 128; i += NUM_THREADS) { + s_hist_buf[0][i] = 0; + s_hist_buf[1][i] = 0; + } + if (tx == 0) { + s_above_count = 0; + s_thresh_above_count = 0; + s_thresh_at_count = 0; + s_threshold_bin = -1; + s_last_remain = 0; + s_sub_threshold_bin = -1; + s_sub_last_remain = 0; + s_strictly_above_sub = 0; + s_is_last = 0; + } + if (tx < LOCAL_K) { + s_top_keys[tx] = 0u; + s_top_idx [tx] = -1; + } + __syncthreads(); + + // Empty-row early exit. SPLITS>1 must still participate in the merge + // barrier so the last-CTA flag fires; padding is already (0u, -1). + if (row_len <= 0 || group_len <= 0) { + if constexpr (SPLITS > 1) { + const int64_t part_off = + (static_cast(b) * SPLITS + n) * LOCAL_K; + if (tx < LOCAL_K) { + partial_keys [part_off + tx] = 0u; + partial_indices[part_off + tx] = -1; + } + __threadfence(); + __syncthreads(); + if (tx == 0) { + const int old = ::atomicAdd(&done_counter[b], 1); + s_is_last = (old == SPLITS - 1) ? 1 : 0; + if (s_is_last) done_counter[b] = 0; // self-reset for next launch + } + __syncthreads(); + if (s_is_last == 0) return; + __threadfence(); + __syncthreads(); + const int64_t row_off = static_cast(b) * SPLITS * LOCAL_K; + merge_cub_warp_topk( + partial_keys + row_off, partial_indices + row_off, + sparse_kv_indices + sparse_kv_indptr[b] + reserved_bos, + topk_val); + } + return; + } + + // Strided suffix-sum-descending over s_hist_buf, ping-pong; result in [0]. + // Works for any NUM_THREADS (uses a strided inner loop over kRadix). 
+ auto run_cumsum_strided = [&]() { + #pragma unroll + for (int i = 0; i < 8; ++i) { + const int j = 1 << i; + const int k = i & 1; + for (int idx = tx; idx < kRadix; idx += NUM_THREADS) { + int v = s_hist_buf[k][idx]; + if (idx + j < kRadix) v += s_hist_buf[k][idx + j]; + s_hist_buf[k ^ 1][idx] = v; + } + __syncthreads(); + } + }; + + // ============================================================ + // Pass 1: top-byte histogram. + // ============================================================ + for (int local_rank = tx; local_rank < group_len; local_rank += NUM_THREADS) { + int pos; + if (row_is_pow2) { + pos = compute_pos(local_rank, row_len, n, b_off, n_mask); + } else { + pos = group_begin + local_rank; + } + const float raw = vortex_to_float(row_scores[pos]); + const float remapped = apply_transform_tmpl(raw, mapping_power); + const uint32_t key = convert_to_uint32(remapped); + const int bin = static_cast(key >> 24); + ::atomicAdd(&s_hist_buf[0][bin], 1); + } + __syncthreads(); + + run_cumsum_strided(); + // s_hist_buf[0][bin] = count of items with key>>24 >= bin. + + const int total_items = s_hist_buf[0][0]; + + // Find threshold bin: the unique bin t where total_at_or_above[t] >= K + // and strictly_above[t] < K. Strided so any NT works. + for (int bin = tx; bin < kRadix; bin += NUM_THREADS) { + const int total_at_or_above = s_hist_buf[0][bin]; + const int strictly_above = (bin + 1 < kRadix) ? s_hist_buf[0][bin + 1] : 0; + if (total_at_or_above >= LOCAL_K && strictly_above < LOCAL_K) { + s_threshold_bin = bin; + s_last_remain = LOCAL_K - strictly_above; + } + } + __syncthreads(); + + if (total_items <= LOCAL_K) { + // Few-elements path: collect everything in arbitrary order, pad rest. 
+ for (int local_rank = tx; local_rank < group_len; local_rank += NUM_THREADS) { + int pos; + if (row_is_pow2) pos = compute_pos(local_rank, row_len, n, b_off, n_mask); + else pos = group_begin + local_rank; + const float raw = vortex_to_float(row_scores[pos]); + const float remapped = apply_transform_tmpl(raw, mapping_power); + const uint32_t key = convert_to_uint32(remapped); + const int slot = ::atomicAdd(&s_above_count, 1); + if (slot < LOCAL_K) { + s_top_keys[slot] = key; + s_top_idx [slot] = row_idxmap[pos]; + } + } + __syncthreads(); + // s_top_keys/idx already pre-padded to (0u, -1) at init. + } else { + const int threshold_bin = s_threshold_bin; + + // Reset both hist buffers for the sub-bin pass. + for (int i = tx; i < kRadix + 128; i += NUM_THREADS) { + s_hist_buf[0][i] = 0; + s_hist_buf[1][i] = 0; + } + __syncthreads(); + + // ============================================================ + // Pass 2: gather strictly-above-threshold items + build sub-hist. + // ============================================================ + for (int local_rank = tx; local_rank < group_len; local_rank += NUM_THREADS) { + int pos; + if (row_is_pow2) pos = compute_pos(local_rank, row_len, n, b_off, n_mask); + else pos = group_begin + local_rank; + const float raw = vortex_to_float(row_scores[pos]); + const float remapped = apply_transform_tmpl(raw, mapping_power); + const uint32_t key = convert_to_uint32(remapped); + const int bin = static_cast(key >> 24); + if (bin > threshold_bin) { + const int slot = ::atomicAdd(&s_above_count, 1); + if (slot < LOCAL_K) { + s_top_keys[slot] = key; + s_top_idx [slot] = row_idxmap[pos]; + } + } else if (bin == threshold_bin) { + const int sub_bin = static_cast((key >> 16) & 0xFF); + ::atomicAdd(&s_hist_buf[0][sub_bin], 1); + } + } + __syncthreads(); + + run_cumsum_strided(); + + const int last_remain = s_last_remain; + for (int bin = tx; bin < kRadix; bin += NUM_THREADS) { + const int total_at_or_above = s_hist_buf[0][bin]; + const int 
strictly_above = (bin + 1 < kRadix) ? s_hist_buf[0][bin + 1] : 0; + if (total_at_or_above >= last_remain && strictly_above < last_remain) { + s_sub_threshold_bin = bin; + s_sub_last_remain = last_remain - strictly_above; + s_strictly_above_sub = strictly_above; + } + } + __syncthreads(); + + const int sub_threshold_bin = s_sub_threshold_bin; + const int sub_last_remain = s_sub_last_remain; + const int strictly_above_sub_bn = s_strictly_above_sub; + const int above_base = s_above_count; // = strictly_above_threshold + + // ============================================================ + // Pass 3: gather threshold-bin sub-above + sub-at items. + // ============================================================ + for (int local_rank = tx; local_rank < group_len; local_rank += NUM_THREADS) { + int pos; + if (row_is_pow2) pos = compute_pos(local_rank, row_len, n, b_off, n_mask); + else pos = group_begin + local_rank; + const float raw = vortex_to_float(row_scores[pos]); + const float remapped = apply_transform_tmpl(raw, mapping_power); + const uint32_t key = convert_to_uint32(remapped); + const int bin = static_cast(key >> 24); + if (bin == threshold_bin) { + const int sub_bin = static_cast((key >> 16) & 0xFF); + if (sub_bin > sub_threshold_bin) { + const int rel = ::atomicAdd(&s_thresh_above_count, 1); + const int slot = above_base + rel; + if (slot < LOCAL_K) { + s_top_keys[slot] = key; + s_top_idx [slot] = row_idxmap[pos]; + } + } else if (sub_bin == sub_threshold_bin) { + const int rel = ::atomicAdd(&s_thresh_at_count, 1); + if (rel < sub_last_remain) { + const int slot = above_base + strictly_above_sub_bn + rel; + if (slot < LOCAL_K) { + s_top_keys[slot] = key; + s_top_idx [slot] = row_idxmap[pos]; + } + } + } + } + } + __syncthreads(); + } + + // ============================================================ + // Stage D: 32-lane warp bitonic sort over the LOCAL_K candidates. 
+ // cub::WarpMergeSort with IPT=1 has trivial pre-sorted-per-thread + // precondition (each lane owns exactly 1 item). + // ============================================================ + if (tx < 32) { + uint32_t kk[1] = { s_top_keys[tx] }; + int32_t vv[1] = { s_top_idx [tx] }; + LocalSortT(local_sort_smem).Sort(kk, vv, DescendingUint32{}); + s_top_keys[tx] = kk[0]; + s_top_idx [tx] = vv[0]; + } + __syncthreads(); + + // ============================================================ + // SPLITS == 1: direct write to sparse_kv_indices. + // ============================================================ + if constexpr (SPLITS == 1) { + int32_t* out_idx = sparse_kv_indices + sparse_kv_indptr[b] + reserved_bos; + if (tx < topk_val) out_idx[tx] = s_top_idx[tx]; + return; + } + + // ============================================================ + // SPLITS > 1: write workspace, last-CTA-wins barrier, merge. + // ============================================================ + const int64_t part_off = (static_cast(b) * SPLITS + n) * LOCAL_K; + if (tx < LOCAL_K) { + partial_keys [part_off + tx] = s_top_keys[tx]; + partial_indices[part_off + tx] = s_top_idx [tx]; + } + + __threadfence(); + __syncthreads(); + if (tx == 0) { + const int old = ::atomicAdd(&done_counter[b], 1); + s_is_last = (old == SPLITS - 1) ? 1 : 0; + if (s_is_last) done_counter[b] = 0; // self-reset for next launch + } + __syncthreads(); + if (s_is_last == 0) return; + __threadfence(); + __syncthreads(); + + const int64_t row_off = static_cast(b) * SPLITS * LOCAL_K; + merge_cub_warp_topk( + partial_keys + row_off, partial_indices + row_off, + sparse_kv_indices + sparse_kv_indptr[b] + reserved_bos, + topk_val); + } + + // ============================================================================= + // Per-split (NUM_THREADS, ITEMS_PER_THREAD) configuration. + // + // NUM_THREADS * ITEMS_PER_THREAD must cover the per-split chunk length + // (= ceil(max_num_pages / SPLITS)). 
// Picked once per SPLITS rather than per
// (SPLITS, max_num_pages) to keep the template instantiation count
// manageable.
//
// cub::BlockRadixSort uses ~NT*IPT*sizeof(KeyT) bytes of static shared
// memory; ptxas rejects kernels exceeding 48 KB static smem on sm_100a
// without opt-in (which static smem can't easily use). Using uint32 keys
// + 8-byte (key,value) effective footprint, we keep NT*IPT*4 <= ~32 KB →
// NT*IPT <= 8192. Coverage:
//
//   SPLITS | chunk_max | covers max_num_pages
//   ------------------------------------------
//      1   |   8192    |    8192
//      2   |   8192    |   16384
//      4   |   4096    |   16384
//      8   |   4096    |   32768
//     16   |   2048    |   32768
//     32   |   1024    |   32768
//
// Configs above the coverage row fall back to the fused single-CTA kernel
// in the dispatcher (capacity check below).
// =============================================================================
struct SplitCfg { int splits, num_threads, items_per_thread; };

constexpr SplitCfg kCfg1  = { 1, 1024,  8 };  // cap 8192
constexpr SplitCfg kCfg2  = { 2, 1024,  8 };  // cap 8192
constexpr SplitCfg kCfg4  = { 4,  512,  8 };  // cap 4096
constexpr SplitCfg kCfg8  = { 8,  256, 16 };  // cap 4096
constexpr SplitCfg kCfg16 = {16,  128, 16 };  // cap 2048
constexpr SplitCfg kCfg32 = {32,   64, 16 };  // cap 1024

// Per-split capacity lookup: NT*IPT of the config whose split count equals
// `split`, or 0 when no config exists for that split count.
inline int split_capacity(int split) {
    constexpr SplitCfg kLadder[] = { kCfg1, kCfg2, kCfg4, kCfg8, kCfg16, kCfg32 };
    for (const SplitCfg& cfg : kLadder) {
        if (cfg.splits == split) {
            return cfg.num_threads * cfg.items_per_thread;
        }
    }
    return 0;
}

// Round `required` up to the nearest supported split count (1, 2, 4, 8, 16,
// 32). Values <= 1 map to 1; anything above 16 saturates at 32.
inline int next_supported_split(int required) {
    for (int candidate = 1; candidate < 32; candidate *= 2) {
        if (required <= candidate) return candidate;
    }
    return 32;
}

// =============================================================================
// Per-split NUM_THREADS for the SELECT32_SORT32 kernel.
//
// No NT*IPT capacity ladder: the kernel scans the split group with strided
// loops, so any group_len works at any NT. Picked here only to balance
// memory throughput vs occupancy. NT=128 is fine for high splits because
// chunk_len shrinks proportionally (max_pages=32k / SPLITS=32 -> 1024).
// =============================================================================
struct SelectCfg { int splits, num_threads; };
constexpr SelectCfg kSelCfg1  = { 1, 1024 };
constexpr SelectCfg kSelCfg2  = { 2, 1024 };
constexpr SelectCfg kSelCfg4  = { 4,  512 };
constexpr SelectCfg kSelCfg8  = { 8,  256 };
constexpr SelectCfg kSelCfg16 = {16,  128 };
constexpr SelectCfg kSelCfg32 = {32,  128 };

// SM-cover policy: pick the smallest supported split such that
// total_ctas = eff_batch_size * split >= sm_count. This prioritises
// filling the device. Capacity / merge cost are NOT considered here —
// the dispatcher's capacity check below catches infeasible configs.
+ inline int choose_split_k30_b200(int64_t eff_bs, int64_t /*max_pages*/, + int forced, int sm_count) + { + if (forced > 0) return forced; + constexpr int kSMCoverDefault = 180; // B200 multiprocessorCount + const int target_blocks = sm_count > 0 ? sm_count : kSMCoverDefault; + const int required = static_cast( + (target_blocks + eff_bs - 1) / eff_bs); + return next_supported_split(required); + } + + // Default partition mode picker. The B200 sweep at K=30 shows CONTIGUOUS + // dominates affine and tile-random by 10-15% at high splits (8/16/32) and + // is within noise at low splits — coalesced loads are the bottleneck once + // each split has more than a handful of threads. Random vs contiguous is + // correctness-equivalent here (each split's local top-32 is merged via + // CUB WarpMergeSort into the global top-30, regardless of partition layout). + // Override via forced_partition for ablation. + inline int default_partition(int /*split*/, int64_t /*max_num_pages*/) { + return PART_CONTIGUOUS; + } + + // Heuristic split picker for K<=32. + // + // ALWAYS returns an adaptive split count in {1,2,4,8,16,32}. Never falls + // back to fused — for K=30 the dispatcher in topk_output_adaptive_workspace + // is required to stay on the adaptive path. split=1 means "single-CTA + // adaptive kernel", NOT "use fused sglang baseline". + // + // Table from B200 sweep (benchmarks/bench_topk_setting_sweep.py, + // SELECT32_SORT32 local mode, CONTIGUOUS partition, CUB WarpMergeSort merge): + // + // max_pages <= 32768 : split=1 wins or ties at every B in {1..16}; + // e.g. 4k/B=4 -> 17.2us @s=1 vs 23.4us @s=2. + // max_pages == 65536 : split=4 beats split=1 by ~18-19% within adaptive + // (s=1 41.8us vs s=4 33.7us); 4 CTAs * 16k chunk + // keeps the per-CTA radix select small enough that + // the merge cost is amortised by the parallel scan. + // + // forced_splits overrides this for benchmarking. 
inline int pick_split_top30(int64_t /*eff_bs*/, int64_t max_pages) {
    // Rows longer than this many pages use four splits; everything else
    // stays on a single CTA. The batch-size argument is not consulted.
    constexpr int64_t kSingleCtaMaxPages = 32768;
    return (max_pages > kSingleCtaMaxPages) ? 4 : 1;
}

// =============================================================================
// Mid-K (K in {64, 128, 256, 512}) generalized SELECTK_SORTK kernel.
//
// Same structure as TopK30_RandomSplit_Select32_Kernel, with LOCAL_K
// templated up to 512 and the local sort + final merge replaced with
// cub::BlockMergeSort variants sized by LOCAL_K and SPLITS*LOCAL_K
// respectively.
//
// Per-CTA pipeline (mirrors K=30 path; only sizes change):
//   Pass 1 — top-byte (bits [31:24]) histogram + suffix-sum-descending,
//            find threshold bin where cumulative count crosses LOCAL_K.
//   Pass 2 — strictly-above-threshold goes straight to candidate buffer;
//            equal-to-threshold contributes to sub-bin (bits [23:16]) histogram.
//   Pass 3 — sub-threshold then sub-equal candidates.
//   Sort   — cub::BlockMergeSort over LOCAL_K candidates with NT_SORT=128
//            and IPT_SORT = (LOCAL_K rounded up to a multiple of 128) / 128
//            (LOCAL_K=64 padded to 128).
//
// Final merge (last CTA, SPLITS > 1):
//   cub::BlockMergeSort over SPLITS * LOCAL_K candidates. Capped at 4096
//   candidates total (NT=256, IPT=16) for register pressure.
//
// Capacity policy (max SPLITS per LOCAL_K, candidates capped at 4096):
//   LOCAL_K=64   -> SPLITS in {1, 2, 4, 8, 16, 32}  (max C=2048)
//   LOCAL_K=128  -> SPLITS in {1, 2, 4, 8, 16, 32}  (max C=4096)
//   LOCAL_K=256  -> SPLITS in {1, 2, 4, 8, 16}      (max C=4096)
//   LOCAL_K=512  -> SPLITS in {1, 2, 4, 8}          (max C=4096)
//
// For NT_SORT=128 we need LOCAL_K to be a multiple of 128 (the slot
// buffer is padded with (key=0, idx=-1) sentinels otherwise; a few
// threads sort dummy items, but the descending sort drops them past
// LOCAL_K and they are never read).
// =============================================================================
// SortNTConfig sizes the slot buffer + IPT for the local-stage sort.
+ // We sort SLOTS_PADDED items with NT threads, IPT_SORT items per thread.
+ // SLOTS_PADDED = max(LOCAL_K, NT) so all NT threads have at least one slot.
+ // When LOCAL_K >= NT, SLOTS_PADDED is LOCAL_K rounded up to a multiple of NT.
+ //
+ // NOTE(review): throughout this section of the patch the angle-bracket
+ // argument lists (template headers, static_cast, SortNTConfig /
+ // MergeNTConfig / cub::BlockMergeSort instantiations, helper template
+ // calls) appear to have been lost in extraction. They are left untouched
+ // here; restore them from the original file before compiling.
+ template
+ struct SortNTConfig {
+     static constexpr int SLOTS_PADDED =
+         (LOCAL_K >= NT) ? ((LOCAL_K + NT - 1) / NT) * NT : NT;
+     static constexpr int NT_SORT = NT;
+     static constexpr int IPT_SORT = SLOTS_PADDED / NT;
+ };
+
+ // Sizes the final-merge sort: CANDIDATES is rounded up to a multiple of NT
+ // and split evenly across NT threads (IPT_MERGE items each).
+ template
+ struct MergeNTConfig {
+     // Final merge runs in the same kernel block (last CTA), so NT_MERGE must
+     // equal the kernel's NUM_THREADS or BlockMergeSort would deadlock.
+     static constexpr int PADDED = (CANDIDATES + NT - 1) / NT * NT;
+     static constexpr int NT_MERGE = NT;
+     static constexpr int IPT_MERGE = PADDED / NT;
+ };
+
+ // Block-wide final merge of the per-split partial results.
+ //
+ // Loads SPLITS * LOCAL_K (key, index) pairs from keys_in / idx_in, padding
+ // any slot past that count with (0u, -1), sorts them descending by key with
+ // cub::BlockMergeSort, and writes the index of every rank < final_k to
+ // out_idx[rank]. Entries whose index is negative (padding / empty slots)
+ // are skipped, so the corresponding out_idx element is left unmodified.
+ //
+ // Must be called by every thread of the block: the sort synchronizes
+ // internally and uses a function-local __shared__ TempStorage.
+ template
+ __device__ __forceinline__ void merge_block_sort_topk_midk(
+     const uint32_t* __restrict__ keys_in,
+     const int32_t* __restrict__ idx_in,
+     int32_t* __restrict__ out_idx,
+     int final_k)
+ {
+     constexpr int kCandidates = SPLITS * LOCAL_K;
+     constexpr int NT = MergeNTConfig::NT_MERGE;
+     constexpr int IPT = MergeNTConfig::IPT_MERGE;
+     using BlockSortT = cub::BlockMergeSort;
+     __shared__ typename BlockSortT::TempStorage block_merge_smem;
+
+     const int tx = threadIdx.x;
+     uint32_t bkeys[IPT];
+     int32_t bvals[IPT];
+     // Blocked arrangement: thread tx owns ranks [tx*IPT, tx*IPT + IPT).
+     #pragma unroll
+     for (int k = 0; k < IPT; ++k) {
+         const int rank = tx * IPT + k;
+         bkeys[k] = (rank < kCandidates) ? keys_in[rank] : 0u;
+         bvals[k] = (rank < kCandidates) ? idx_in [rank] : -1;
+     }
+     BlockSortT(block_merge_smem).Sort(bkeys, bvals, DescendingUint32{});
+     #pragma unroll
+     for (int k = 0; k < IPT; ++k) {
+         const int rank = tx * IPT + k;
+         if (rank < final_k && bvals[k] >= 0) out_idx[rank] = bvals[k];
+     }
+ }
+
+ // Mid-K top-k selection kernel. One CTA per (row, split): blockIdx.x selects
+ // the row b, blockIdx.y selects the split n.
+ //
+ // Each CTA scans its share of row b (scores at dense_kv_indptr[b] +
+ // reserved_bos up to dense_kv_indptr[b + 1] - reserved_eos), selects up to
+ // LOCAL_K candidates with a two-level 8-bit histogram on the converted
+ // uint32 key (bits [31:24], then bits [23:16]), and sorts them descending.
+ //   SPLITS == 1 : the first topk_val sorted indices are written straight to
+ //                 sparse_kv_indices + sparse_kv_indptr[b] + reserved_bos.
+ //   SPLITS  > 1 : the LOCAL_K sorted pairs are staged at offset
+ //                 (b * SPLITS + n) * LOCAL_K in partial_keys /
+ //                 partial_indices; the last CTA of the row to bump
+ //                 done_counter[b] merges all SPLITS partial lists.
+ //
+ // done_counter[b] must be 0 on entry (last-CTA detection tests
+ // old == SPLITS - 1); the last CTA writes it back to 0 before merging.
+ // mapping_power is forwarded unchanged to apply_transform_tmpl.
+ template
+ __global__ __launch_bounds__(NUM_THREADS)
+ void TopKMidK_RandomSplit_SelectK_Kernel(
+     const ScoreT* __restrict__ score,
+     const int* __restrict__ dense_kv_indptr,
+     const int* __restrict__ sparse_kv_indptr,
+     const int* __restrict__ dense_kv_indices,
+     int* __restrict__ sparse_kv_indices,
+     uint32_t* __restrict__ partial_keys,
+     int32_t* __restrict__ partial_indices,
+     int32_t* __restrict__ done_counter,
+     const int topk_val,
+     const int reserved_bos,
+     const int reserved_eos,
+     const float mapping_power)
+ {
+     constexpr int kRadix = 256;
+     constexpr int SLOTS_PADDED = SortNTConfig::SLOTS_PADDED;
+     constexpr int NT_SORT = SortNTConfig::NT_SORT;
+     constexpr int IPT_SORT = SortNTConfig::IPT_SORT;
+     // cub::BlockMergeSort calls __syncthreads() internally — every thread in
+     // the block must enter the sort branch, so NT_SORT must equal NUM_THREADS.
+     static_assert(NT_SORT == NUM_THREADS,
+                   "NT_SORT must equal NUM_THREADS or BlockMergeSort deadlocks");
+
+     // Two ping-pong histogram buffers used by the suffix-sum below.
+     alignas(128) __shared__ int s_hist_buf[2][kRadix + 128];
+     __shared__ int s_above_count;         // slots claimed by pass 2 / few-elements path
+     __shared__ int s_thresh_above_count;  // pass 3: items above the sub-threshold bin
+     __shared__ int s_thresh_at_count;     // pass 3: items in the sub-threshold bin
+     __shared__ int s_threshold_bin;       // top-byte bin where the count crosses LOCAL_K
+     __shared__ int s_last_remain;         // slots still to fill from the threshold bin
+     __shared__ int s_sub_threshold_bin;   // same, for the second byte
+     __shared__ int s_sub_last_remain;
+     __shared__ int s_strictly_above_sub;
+     __shared__ uint32_t s_top_keys[SLOTS_PADDED];
+     __shared__ int32_t s_top_idx [SLOTS_PADDED];
+     __shared__ int s_is_last;             // 1 iff this CTA arrived last for row b
+
+     using LocalSortT = cub::BlockMergeSort;
+     __shared__ typename LocalSortT::TempStorage local_sort_smem;
+
+     const int b = blockIdx.x;
+     const int n = blockIdx.y;
+     const int tx = threadIdx.x;
+
+     // Usable part of the row: reserved_bos / reserved_eos entries are
+     // trimmed from the front / back.
+     const int row_start = dense_kv_indptr[b] + reserved_bos;
+     const int row_end = dense_kv_indptr[b + 1] - reserved_eos;
+     const int row_len = max(0, row_end - row_start);
+
+     // This split's share of the row: [group_begin, group_end).
+     const int group_begin = (static_cast(row_len) * n) / SPLITS;
+     const int group_end = (static_cast(row_len) * (n + 1)) / SPLITS;
+     const int group_len = group_end - group_begin;
+
+     // Power-of-two rows go through compute_pos (defined elsewhere in the
+     // file); all other rows are scanned contiguously from group_begin.
+     const bool row_is_pow2 = is_pow2(row_len);
+     const uint32_t n_mask = row_is_pow2 ? static_cast(row_len - 1) : 0u;
+     const uint32_t b_off =
+         static_cast(b) * kPermuteSeedB + kPermuteOffset;
+     const ScoreT* row_scores = score + row_start;
+     const int* row_idxmap = dense_kv_indices + row_start;
+
+     // ---- Init shared state. Strided over padding so any NT works. ----
+     for (int i = tx; i < kRadix + 128; i += NUM_THREADS) {
+         s_hist_buf[0][i] = 0;
+         s_hist_buf[1][i] = 0;
+     }
+     if (tx == 0) {
+         s_above_count = 0;
+         s_thresh_above_count = 0;
+         s_thresh_at_count = 0;
+         s_threshold_bin = -1;
+         s_last_remain = 0;
+         s_sub_threshold_bin = -1;
+         s_sub_last_remain = 0;
+         s_strictly_above_sub = 0;
+         s_is_last = 0;
+     }
+     for (int i = tx; i < SLOTS_PADDED; i += NUM_THREADS) {
+         s_top_keys[i] = 0u;
+         s_top_idx [i] = -1;
+     }
+     __syncthreads();
+
+     // Empty-row early exit (preserves merge barrier for SPLITS>1).
+     // With SPLITS > 1 the CTA still publishes an all-sentinel partial list
+     // and takes part in the done_counter handshake, so the row's last CTA
+     // is always identified. With SPLITS == 1 nothing is written.
+     if (row_len <= 0 || group_len <= 0) {
+         if constexpr (SPLITS > 1) {
+             const int64_t part_off =
+                 (static_cast(b) * SPLITS + n) * LOCAL_K;
+             for (int i = tx; i < LOCAL_K; i += NUM_THREADS) {
+                 partial_keys [part_off + i] = 0u;
+                 partial_indices[part_off + i] = -1;
+             }
+             __threadfence();
+             __syncthreads();
+             if (tx == 0) {
+                 const int old = ::atomicAdd(&done_counter[b], 1);
+                 s_is_last = (old == SPLITS - 1) ? 1 : 0;
+                 if (s_is_last) done_counter[b] = 0; // self-reset for next launch
+             }
+             __syncthreads();
+             if (s_is_last == 0) return;
+             __threadfence();
+             __syncthreads();
+             const int64_t row_off = static_cast(b) * SPLITS * LOCAL_K;
+             merge_block_sort_topk_midk(
+                 partial_keys + row_off, partial_indices + row_off,
+                 sparse_kv_indices + sparse_kv_indptr[b] + reserved_bos,
+                 topk_val);
+         }
+         return;
+     }
+
+     // Suffix sum over the 256 bins in 8 doubling steps, ping-ponging between
+     // the two buffers. Step i reads buffer (i & 1) and writes the other one,
+     // so after the 8th step the result is in s_hist_buf[0]:
+     // s_hist_buf[0][bin] = number of items whose bin index is >= bin.
+     auto run_cumsum_strided = [&]() {
+         #pragma unroll
+         for (int i = 0; i < 8; ++i) {
+             const int j = 1 << i;
+             const int k = i & 1;
+             for (int idx = tx; idx < kRadix; idx += NUM_THREADS) {
+                 int v = s_hist_buf[k][idx];
+                 if (idx + j < kRadix) v += s_hist_buf[k][idx + j];
+                 s_hist_buf[k ^ 1][idx] = v;
+             }
+             __syncthreads();
+         }
+     };
+
+     // ============== Pass 1: top-byte histogram. ==============
+     for (int local_rank = tx; local_rank < group_len; local_rank += NUM_THREADS) {
+         int pos;
+         if (row_is_pow2) {
+             pos = compute_pos(local_rank, row_len, n, b_off, n_mask);
+         } else {
+             pos = group_begin + local_rank;
+         }
+         const float raw = vortex_to_float(row_scores[pos]);
+         const float remapped = apply_transform_tmpl(raw, mapping_power);
+         const uint32_t key = convert_to_uint32(remapped);
+         const int bin = static_cast(key >> 24);
+         ::atomicAdd(&s_hist_buf[0][bin], 1);
+     }
+     __syncthreads();
+
+     run_cumsum_strided();
+     // Bin 0 of the suffix sum is the number of items this CTA histogrammed.
+     const int total_items = s_hist_buf[0][0];
+
+     // At most one bin satisfies the crossing condition, so at most one
+     // thread writes the threshold pair.
+     for (int bin = tx; bin < kRadix; bin += NUM_THREADS) {
+         const int total_at_or_above = s_hist_buf[0][bin];
+         const int strictly_above = (bin + 1 < kRadix) ? s_hist_buf[0][bin + 1] : 0;
+         if (total_at_or_above >= LOCAL_K && strictly_above < LOCAL_K) {
+             s_threshold_bin = bin;
+             s_last_remain = LOCAL_K - strictly_above;
+         }
+     }
+     __syncthreads();
+
+     if (total_items <= LOCAL_K) {
+         // Few-elements path: collect everything, pad rest.
+         for (int local_rank = tx; local_rank < group_len; local_rank += NUM_THREADS) {
+             int pos;
+             if (row_is_pow2) pos = compute_pos(local_rank, row_len, n, b_off, n_mask);
+             else pos = group_begin + local_rank;
+             const float raw = vortex_to_float(row_scores[pos]);
+             const float remapped = apply_transform_tmpl(raw, mapping_power);
+             const uint32_t key = convert_to_uint32(remapped);
+             const int slot = ::atomicAdd(&s_above_count, 1);
+             if (slot < LOCAL_K) {
+                 s_top_keys[slot] = key;
+                 s_top_idx [slot] = row_idxmap[pos];
+             }
+         }
+         __syncthreads();
+     } else {
+         const int threshold_bin = s_threshold_bin;
+
+         // Clear both histogram buffers for reuse as the second-byte histogram.
+         for (int i = tx; i < kRadix + 128; i += NUM_THREADS) {
+             s_hist_buf[0][i] = 0;
+             s_hist_buf[1][i] = 0;
+         }
+         __syncthreads();
+
+         // ============== Pass 2: gather above + sub-hist. ==============
+         for (int local_rank = tx; local_rank < group_len; local_rank += NUM_THREADS) {
+             int pos;
+             if (row_is_pow2) pos = compute_pos(local_rank, row_len, n, b_off, n_mask);
+             else pos = group_begin + local_rank;
+             const float raw = vortex_to_float(row_scores[pos]);
+             const float remapped = apply_transform_tmpl(raw, mapping_power);
+             const uint32_t key = convert_to_uint32(remapped);
+             const int bin = static_cast(key >> 24);
+             if (bin > threshold_bin) {
+                 const int slot = ::atomicAdd(&s_above_count, 1);
+                 if (slot < LOCAL_K) {
+                     s_top_keys[slot] = key;
+                     s_top_idx [slot] = row_idxmap[pos];
+                 }
+             } else if (bin == threshold_bin) {
+                 const int sub_bin = static_cast((key >> 16) & 0xFF);
+                 ::atomicAdd(&s_hist_buf[0][sub_bin], 1);
+             }
+         }
+         __syncthreads();
+
+         run_cumsum_strided();
+
+         // Locate the second-byte bin where the count crosses the number of
+         // slots still open after pass 2.
+         const int last_remain = s_last_remain;
+         for (int bin = tx; bin < kRadix; bin += NUM_THREADS) {
+             const int total_at_or_above = s_hist_buf[0][bin];
+             const int strictly_above = (bin + 1 < kRadix) ? s_hist_buf[0][bin + 1] : 0;
+             if (total_at_or_above >= last_remain && strictly_above < last_remain) {
+                 s_sub_threshold_bin = bin;
+                 s_sub_last_remain = last_remain - strictly_above;
+                 s_strictly_above_sub = strictly_above;
+             }
+         }
+         __syncthreads();
+
+         const int sub_threshold_bin = s_sub_threshold_bin;
+         const int sub_last_remain = s_sub_last_remain;
+         const int strictly_above_sub_bn = s_strictly_above_sub;
+         const int above_base = s_above_count;
+
+         // ============== Pass 3: sub-above + sub-at. ==============
+         // Slot layout after this pass:
+         //   [0, above_base)                     top byte above threshold (pass 2)
+         //   [above_base, +strictly_above_sub)   second byte above sub-threshold
+         //   next sub_last_remain slots          second byte equal to sub-threshold
+         for (int local_rank = tx; local_rank < group_len; local_rank += NUM_THREADS) {
+             int pos;
+             if (row_is_pow2) pos = compute_pos(local_rank, row_len, n, b_off, n_mask);
+             else pos = group_begin + local_rank;
+             const float raw = vortex_to_float(row_scores[pos]);
+             const float remapped = apply_transform_tmpl(raw, mapping_power);
+             const uint32_t key = convert_to_uint32(remapped);
+             const int bin = static_cast(key >> 24);
+             if (bin == threshold_bin) {
+                 const int sub_bin = static_cast((key >> 16) & 0xFF);
+                 if (sub_bin > sub_threshold_bin) {
+                     const int rel = ::atomicAdd(&s_thresh_above_count, 1);
+                     const int slot = above_base + rel;
+                     if (slot < LOCAL_K) {
+                         s_top_keys[slot] = key;
+                         s_top_idx [slot] = row_idxmap[pos];
+                     }
+                 } else if (sub_bin == sub_threshold_bin) {
+                     const int rel = ::atomicAdd(&s_thresh_at_count, 1);
+                     if (rel < sub_last_remain) {
+                         const int slot = above_base + strictly_above_sub_bn + rel;
+                         if (slot < LOCAL_K) {
+                             s_top_keys[slot] = key;
+                             s_top_idx [slot] = row_idxmap[pos];
+                         }
+                     }
+                 }
+             }
+         }
+         __syncthreads();
+     }
+
+     // ============== Sort SLOTS_PADDED candidates with cub::BlockMergeSort. ==============
+     // The first LOCAL_K slots may have real data; padded slots have (0u, -1).
+     // Sort uses NT_SORT threads; only those threads load/store sort items.
+     // (NT_SORT == NUM_THREADS per the static_assert above, so every thread
+     // of the block takes this branch.)
+     if (tx < NT_SORT) {
+         uint32_t kk[IPT_SORT];
+         int32_t vv[IPT_SORT];
+         #pragma unroll
+         for (int k = 0; k < IPT_SORT; ++k) {
+             const int slot = tx * IPT_SORT + k;
+             kk[k] = s_top_keys[slot];
+             vv[k] = s_top_idx [slot];
+         }
+         LocalSortT(local_sort_smem).Sort(kk, vv, DescendingUint32{});
+         #pragma unroll
+         for (int k = 0; k < IPT_SORT; ++k) {
+             const int slot = tx * IPT_SORT + k;
+             s_top_keys[slot] = kk[k];
+             s_top_idx [slot] = vv[k];
+         }
+     }
+     __syncthreads();
+
+     // ============== SPLITS == 1: direct write to sparse_kv_indices. ==============
+     // NOTE(review): unlike merge_block_sort_topk_midk, this path does not
+     // skip -1 sentinels, so a row with fewer than topk_val candidates gets
+     // -1 written into the tail of its output — confirm callers expect that.
+     if constexpr (SPLITS == 1) {
+         int32_t* out_idx = sparse_kv_indices + sparse_kv_indptr[b] + reserved_bos;
+         for (int rank = tx; rank < topk_val; rank += NUM_THREADS) {
+             out_idx[rank] = s_top_idx[rank];
+         }
+         return;
+     }
+
+     // ============== SPLITS > 1: workspace, last-CTA barrier, merge. ==============
+     const int64_t part_off = (static_cast(b) * SPLITS + n) * LOCAL_K;
+     for (int i = tx; i < LOCAL_K; i += NUM_THREADS) {
+         partial_keys [part_off + i] = s_top_keys[i];
+         partial_indices[part_off + i] = s_top_idx [i];
+     }
+
+     // Publish the partial list, then let thread 0 count this CTA in. The CTA
+     // that observes SPLITS - 1 earlier arrivals is the last one for row b.
+     __threadfence();
+     __syncthreads();
+     if (tx == 0) {
+         const int old = ::atomicAdd(&done_counter[b], 1);
+         s_is_last = (old == SPLITS - 1) ? 1 : 0;
+         if (s_is_last) done_counter[b] = 0; // self-reset for next launch
+     }
+     __syncthreads();
+     if (s_is_last == 0) return;
+     __threadfence();
+     __syncthreads();
+
+     const int64_t row_off = static_cast(b) * SPLITS * LOCAL_K;
+     merge_block_sort_topk_midk(
+         partial_keys + row_off, partial_indices + row_off,
+         sparse_kv_indices + sparse_kv_indptr[b] + reserved_bos,
+         topk_val);
+ }
+
+ // Mid-K capacity policy. Returns true iff (LOCAL_K, SPLITS) is supported.
+ // Supported means: splits is one of {1, 2, 4, 8, 16, 32} and
+ // local_k * splits does not exceed 4096 merge candidates.
+ // NOTE(review): the product is formed in int before splits is validated, so
+ // an absurdly large splits value could overflow — consider checking the
+ // allowed set first.
+ inline bool midk_split_supported(int local_k, int splits) {
+     const int candidates = local_k * splits;
+     if (candidates > 4096) return false;
+     if (splits != 1 && splits != 2 && splits != 4 && splits != 8 &&
+         splits != 16 && splits != 32) return false;
+     return true;
+ }
+
+ // Pick LOCAL_K from K. We use the smallest power-of-two LOCAL_K >= K.
+ // The result is clamped below at 64 (any topk_val <= 64 maps to 64) and
+ // -1 is returned for topk_val > 512.
+ inline int midk_local_k_from_topk(int topk_val) {
+     if (topk_val <= 64) return 64;
+     if (topk_val <= 128) return 128;
+     if (topk_val <= 256) return 256;
+     if (topk_val <= 512) return 512;
+     return -1; // unsupported
+ }
+
+ } // namespace
+
+ // Fails with "<expr> must be a CUDA tensor" when x is not on a CUDA device.
+ #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
+
+ // =============================================================================
+ // Workspace API: zero hot-path at::empty allocations.
+ // + // topk_val >= 1024 → forwards to topk_output_sglang_fused without + // touching workspace tensors or done_counter. + // topk_val <= 32 → uses the K=30 random-split parallel path with + // forced_splits (if > 0) or pick_split_top30(). + // else → also forwards to fused (no specialised path here). + // + // partial_keys / partial_indices must each have at least + // eff_batch_size * SPLITS * kLocalK_Top30 = eff_batch_size * SPLITS * 32 + // int32 elements. done_counter must have at least eff_batch_size int32 + // elements; it is cleared with cudaMemsetAsync inside this call before the + // parallel kernel launches (and is NOT touched on the fused-fallback path). + // + // forced_splits encoding: + // <= 0 : use heuristic pick_split_top30(). + // 1 : single-CTA local sort path (for benchmarking). + // 2/4/8/16/32 : forced parallel split. + // anything else : TORCH_CHECK failure. + // ============================================================================= + void topk_output_adaptive_workspace( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const at::Tensor& sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + at::Tensor& partial_keys, + at::Tensor& partial_indices, + at::Tensor& done_counter, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages, + const int64_t mapping_mode, + const double mapping_power, + const int64_t forced_splits, + const int64_t forced_partition, + const int64_t local_mode) + { + // ============== Fused fallback (no workspace touch) ============== + // K >= 1024: 32k -> 2048 lives here. Direct delegate. NO workspace check, + // NO memset, NO split kernel launch — this is the near-zero-overhead + // hot fast-path required for the K=2048 workload. + // + // K in (32, 1024) also routes here: those Ks have no specialised + // adaptive kernel and the fused baseline is the right path. 
NOTE: for + // K <= 32 (the K=30 path) we never come back to fused below — every + // adaptive sub-path stays on the split kernel. + if (topk_val >= kFusedFallbackTopK || topk_val > kMaxFinalK_Top30) { + topk_output_sglang_fused( + x, dense_kv_indptr, sparse_kv_indptr, + dense_kv_indices, sparse_kv_indices, + eff_batch_size, topk_val, + reserved_bos, reserved_eos, max_num_pages, + mapping_mode, mapping_power, std::nullopt, std::nullopt); + return; + } + + // ============== K <= 32 adaptive path (no fused fallback below) ============== + + CHECK_CUDA(x); + CHECK_CUDA(dense_kv_indptr); + CHECK_CUDA(sparse_kv_indptr); + CHECK_CUDA(dense_kv_indices); + CHECK_CUDA(sparse_kv_indices); + + TORCH_CHECK(topk_val > 0, "topk_val must be > 0"); + TORCH_CHECK(eff_batch_size >= 1, "eff_batch_size must be >= 1"); + TORCH_CHECK(max_num_pages >= 1, "max_num_pages must be >= 1"); + + // local_mode validation. -1 (or any negative) defaults to SELECT32_SORT32, + // which is the production mode (no NT*IPT capacity ceiling, supports the + // full pages={4096,8192,16384,32768} x splits={1..32} matrix). 
+ int local_mode_int = static_cast(local_mode); + if (local_mode_int < 0) local_mode_int = LOCAL_SELECT32_SORT32; + TORCH_CHECK(local_mode_int == LOCAL_BLOCK_FULL_SORT || + local_mode_int == LOCAL_SELECT32_SORT32, + "local_mode must be 0 (BLOCK_FULL_SORT) or 1 (SELECT32_SORT32), got ", + local_mode_int); + + TORCH_CHECK( + mapping_mode == MAPPING_NONE || + mapping_mode == MAPPING_POWER || + mapping_mode == MAPPING_LOG || + mapping_mode == MAPPING_ASINH || + mapping_mode == MAPPING_LOG1P || + mapping_mode == MAPPING_TRUNC8 || + mapping_mode == MAPPING_ERF || + mapping_mode == MAPPING_TANH || + mapping_mode == MAPPING_SUBTRACT || + mapping_mode == MAPPING_EXP_STRETCH || + mapping_mode == MAPPING_SHIFT_POW2 || + mapping_mode == MAPPING_SHIFT_POW3 || + mapping_mode == MAPPING_LINEAR_STEEP || + mapping_mode == MAPPING_HALF_SQUARE || + mapping_mode == MAPPING_HALF_CUBE, + "topk_output_adaptive_workspace: mapping_mode=", mapping_mode, + " not supported."); + + // Resolve split count. K=30 NEVER falls back to fused: split=1 means + // single-CTA adaptive kernel, not the fused baseline. + int split; + if (forced_splits > 0) { + split = static_cast(forced_splits); + TORCH_CHECK(split == 1 || split == 2 || split == 4 || split == 8 || + split == 16 || split == 32, + "forced_splits must be one of {1,2,4,8,16,32}, got ", split); + } else { + split = pick_split_top30(eff_batch_size, max_num_pages); + } + + // Resolve partition mode. 
+ int partition; + if (forced_partition >= 0) { + partition = static_cast(forced_partition); + TORCH_CHECK(partition == PART_AFFINE_RANDOM || + partition == PART_CONTIGUOUS || + partition == PART_STRIDED || + partition == PART_TILE_RANDOM_128 || + partition == PART_TILE_RANDOM_256, + "forced_partition must be 0=affine,1=contiguous,2=strided," + "3=tile_random_128,4=tile_random_256"); + } else { + partition = default_partition(split, max_num_pages); + } + + // Capacity check applies ONLY to BLOCK_FULL_SORT, which uses + // cub::BlockRadixSort and is bounded by NT*IPT static-smem footprint. + // SELECT32_SORT32 has no such ceiling (its inner loops are strided). + // + // K=30 must NEVER silently fall back to fused — if BLOCK_FULL_SORT can't + // fit the chunk, we fail loudly so the caller picks a finer split or + // switches to SELECT32_SORT32. + if (local_mode_int == LOCAL_BLOCK_FULL_SORT) { + const int chunk_max = static_cast((max_num_pages + split - 1) / split); + const int cap = split_capacity(split); + TORCH_CHECK(cap >= chunk_max, + "topk_output_adaptive_workspace: BLOCK_FULL_SORT split=", split, + " has NT*IPT=", cap, + " < required chunk_max=", chunk_max, + " (max_num_pages=", max_num_pages, + "). Use SELECT32_SORT32 (local_mode=1) or a finer split."); + } + + // From here we enter the parallel path. The split=1 forced case still + // reads partial_keys/partial_indices/done_counter args but does NOT + // touch them — we accept any tensor of the right dtype. 
+ CHECK_CUDA(partial_keys); + CHECK_CUDA(partial_indices); + CHECK_CUDA(done_counter); + TORCH_CHECK(partial_keys.dtype() == at::kInt, + "partial_keys must be int32 (uint32 reinterpreted)"); + TORCH_CHECK(partial_indices.dtype() == at::kInt, "partial_indices must be int32"); + TORCH_CHECK(done_counter.dtype() == at::kInt, "done_counter must be int32"); + + if (split > 1) { + TORCH_CHECK(done_counter.numel() >= eff_batch_size, + "done_counter[", done_counter.numel(), + "] too small for eff_batch_size=", eff_batch_size); + const int64_t need = eff_batch_size * static_cast(split) * kLocalK_Top30; + TORCH_CHECK(partial_keys.numel() >= need, + "partial_keys too small: ", partial_keys.numel(), " < ", need); + TORCH_CHECK(partial_indices.numel() >= need, + "partial_indices too small: ", partial_indices.numel(), " < ", need); + } + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + const float mp = static_cast(mapping_power); + + // No cudaMemsetAsync(done_counter) — the kernel self-resets done_counter[b] + // = 0 from the last CTA's tx==0 thread, so subsequent launches see it + // already zero. Saves ~1-2 µs of CPU launch overhead per call. Caller + // contract: done_counter must be zero-initialized once at workspace + // allocation (at::zeros) and not touched by anyone else on this stream. 
+ + uint32_t* part_keys_ptr = + reinterpret_cast(partial_keys.data_ptr()); + int32_t* part_idx_ptr = partial_indices.data_ptr(); + int32_t* done_ptr = done_counter.data_ptr(); + + dim3 grid(static_cast(eff_batch_size), + static_cast(split)); + + // ---- BLOCK_FULL_SORT macro chain (TopK30_RandomSplit_Parallel_Kernel) ---- + #define LAUNCH_TOP30_BLOCK(DTYPE, PTR_EXPR, MODE_VAL, SPLITS_VAL, NT, IPT, PART) \ + do { \ + auto* fn = &TopK30_RandomSplit_Parallel_Kernel< \ + DTYPE, MODE_VAL, SPLITS_VAL, NT, IPT, PART>; \ + fn<<>>( \ + PTR_EXPR, \ + dense_kv_indptr.data_ptr(), \ + sparse_kv_indptr.data_ptr(), \ + dense_kv_indices.data_ptr(), \ + sparse_kv_indices.data_ptr(), \ + part_keys_ptr, part_idx_ptr, done_ptr, \ + static_cast(topk_val), \ + static_cast(reserved_bos), \ + static_cast(reserved_eos), \ + mp); \ + } while (0) + + #define DISPATCH_PART_BLOCK(DTYPE, PTR_EXPR, MODE_VAL, SPLITS_VAL, NT, IPT) \ + do { \ + switch (partition) { \ + case PART_AFFINE_RANDOM: \ + LAUNCH_TOP30_BLOCK(DTYPE, PTR_EXPR, MODE_VAL, SPLITS_VAL, NT, IPT, PART_AFFINE_RANDOM); break; \ + case PART_CONTIGUOUS: \ + LAUNCH_TOP30_BLOCK(DTYPE, PTR_EXPR, MODE_VAL, SPLITS_VAL, NT, IPT, PART_CONTIGUOUS); break; \ + case PART_STRIDED: \ + LAUNCH_TOP30_BLOCK(DTYPE, PTR_EXPR, MODE_VAL, SPLITS_VAL, NT, IPT, PART_STRIDED); break; \ + case PART_TILE_RANDOM_128: \ + LAUNCH_TOP30_BLOCK(DTYPE, PTR_EXPR, MODE_VAL, SPLITS_VAL, NT, IPT, PART_TILE_RANDOM_128); break; \ + case PART_TILE_RANDOM_256: \ + LAUNCH_TOP30_BLOCK(DTYPE, PTR_EXPR, MODE_VAL, SPLITS_VAL, NT, IPT, PART_TILE_RANDOM_256); break; \ + default: TORCH_CHECK(false, "unreachable partition mode"); \ + } \ + } while (0) + + #define DISPATCH_SPLIT_BLOCK(DTYPE, PTR_EXPR, MODE_VAL) \ + do { \ + switch (split) { \ + case 1: DISPATCH_PART_BLOCK(DTYPE, PTR_EXPR, MODE_VAL, 1, kCfg1.num_threads, kCfg1.items_per_thread); break; \ + case 2: DISPATCH_PART_BLOCK(DTYPE, PTR_EXPR, MODE_VAL, 2, kCfg2.num_threads, kCfg2.items_per_thread); break; \ + case 4: 
DISPATCH_PART_BLOCK(DTYPE, PTR_EXPR, MODE_VAL, 4, kCfg4.num_threads, kCfg4.items_per_thread); break; \ + case 8: DISPATCH_PART_BLOCK(DTYPE, PTR_EXPR, MODE_VAL, 8, kCfg8.num_threads, kCfg8.items_per_thread); break; \ + case 16: DISPATCH_PART_BLOCK(DTYPE, PTR_EXPR, MODE_VAL, 16, kCfg16.num_threads, kCfg16.items_per_thread); break; \ + case 32: DISPATCH_PART_BLOCK(DTYPE, PTR_EXPR, MODE_VAL, 32, kCfg32.num_threads, kCfg32.items_per_thread); break; \ + default: TORCH_CHECK(false, "unsupported split=", split); \ + } \ + } while (0) + + // ---- SELECT32_SORT32 macro chain (TopK30_RandomSplit_Select32_Kernel) ---- + #define LAUNCH_TOP30_SELECT(DTYPE, PTR_EXPR, MODE_VAL, SPLITS_VAL, NT, PART) \ + do { \ + auto* fn = &TopK30_RandomSplit_Select32_Kernel< \ + DTYPE, MODE_VAL, SPLITS_VAL, NT, PART>; \ + fn<<>>( \ + PTR_EXPR, \ + dense_kv_indptr.data_ptr(), \ + sparse_kv_indptr.data_ptr(), \ + dense_kv_indices.data_ptr(), \ + sparse_kv_indices.data_ptr(), \ + part_keys_ptr, part_idx_ptr, done_ptr, \ + static_cast(topk_val), \ + static_cast(reserved_bos), \ + static_cast(reserved_eos), \ + mp); \ + } while (0) + + #define DISPATCH_PART_SELECT(DTYPE, PTR_EXPR, MODE_VAL, SPLITS_VAL, NT) \ + do { \ + switch (partition) { \ + case PART_AFFINE_RANDOM: \ + LAUNCH_TOP30_SELECT(DTYPE, PTR_EXPR, MODE_VAL, SPLITS_VAL, NT, PART_AFFINE_RANDOM); break; \ + case PART_CONTIGUOUS: \ + LAUNCH_TOP30_SELECT(DTYPE, PTR_EXPR, MODE_VAL, SPLITS_VAL, NT, PART_CONTIGUOUS); break; \ + case PART_STRIDED: \ + LAUNCH_TOP30_SELECT(DTYPE, PTR_EXPR, MODE_VAL, SPLITS_VAL, NT, PART_STRIDED); break; \ + case PART_TILE_RANDOM_128: \ + LAUNCH_TOP30_SELECT(DTYPE, PTR_EXPR, MODE_VAL, SPLITS_VAL, NT, PART_TILE_RANDOM_128); break; \ + case PART_TILE_RANDOM_256: \ + LAUNCH_TOP30_SELECT(DTYPE, PTR_EXPR, MODE_VAL, SPLITS_VAL, NT, PART_TILE_RANDOM_256); break; \ + default: TORCH_CHECK(false, "unreachable partition mode"); \ + } \ + } while (0) + + #define DISPATCH_SPLIT_SELECT(DTYPE, PTR_EXPR, MODE_VAL) \ + do { \ + switch 
(split) { \ + case 1: DISPATCH_PART_SELECT(DTYPE, PTR_EXPR, MODE_VAL, 1, kSelCfg1.num_threads); break; \ + case 2: DISPATCH_PART_SELECT(DTYPE, PTR_EXPR, MODE_VAL, 2, kSelCfg2.num_threads); break; \ + case 4: DISPATCH_PART_SELECT(DTYPE, PTR_EXPR, MODE_VAL, 4, kSelCfg4.num_threads); break; \ + case 8: DISPATCH_PART_SELECT(DTYPE, PTR_EXPR, MODE_VAL, 8, kSelCfg8.num_threads); break; \ + case 16: DISPATCH_PART_SELECT(DTYPE, PTR_EXPR, MODE_VAL, 16, kSelCfg16.num_threads); break; \ + case 32: DISPATCH_PART_SELECT(DTYPE, PTR_EXPR, MODE_VAL, 32, kSelCfg32.num_threads); break; \ + default: TORCH_CHECK(false, "unsupported split=", split); \ + } \ + } while (0) + + // Top-level: choose the local-mode chain, then mapping_mode → split → partition. + // MAPPING_TRUNC8 shares its semantics with MAPPING_NONE (identity transform). + // Routing both to MAPPING_NONE saves one template instantiation per chain. + #define DISPATCH_MODE(DTYPE, PTR_EXPR) \ + do { \ + if (local_mode_int == LOCAL_SELECT32_SORT32) { \ + switch (mapping_mode) { \ + case MAPPING_NONE: \ + case MAPPING_TRUNC8: DISPATCH_SPLIT_SELECT(DTYPE, PTR_EXPR, MAPPING_NONE); break; \ + case MAPPING_POWER: DISPATCH_SPLIT_SELECT(DTYPE, PTR_EXPR, MAPPING_POWER); break; \ + case MAPPING_LOG: DISPATCH_SPLIT_SELECT(DTYPE, PTR_EXPR, MAPPING_LOG); break; \ + case MAPPING_ASINH: DISPATCH_SPLIT_SELECT(DTYPE, PTR_EXPR, MAPPING_ASINH); break; \ + case MAPPING_LOG1P: DISPATCH_SPLIT_SELECT(DTYPE, PTR_EXPR, MAPPING_LOG1P); break; \ + case MAPPING_ERF: DISPATCH_SPLIT_SELECT(DTYPE, PTR_EXPR, MAPPING_ERF); break; \ + case MAPPING_TANH: DISPATCH_SPLIT_SELECT(DTYPE, PTR_EXPR, MAPPING_TANH); break; \ + case MAPPING_SUBTRACT: DISPATCH_SPLIT_SELECT(DTYPE, PTR_EXPR, MAPPING_SUBTRACT); break; \ + case MAPPING_EXP_STRETCH: DISPATCH_SPLIT_SELECT(DTYPE, PTR_EXPR, MAPPING_EXP_STRETCH); break; \ + case MAPPING_SHIFT_POW2: DISPATCH_SPLIT_SELECT(DTYPE, PTR_EXPR, MAPPING_SHIFT_POW2); break; \ + case MAPPING_SHIFT_POW3: DISPATCH_SPLIT_SELECT(DTYPE, 
PTR_EXPR, MAPPING_SHIFT_POW3); break; \ + case MAPPING_LINEAR_STEEP: DISPATCH_SPLIT_SELECT(DTYPE, PTR_EXPR, MAPPING_LINEAR_STEEP); break; \ + case MAPPING_HALF_SQUARE: DISPATCH_SPLIT_SELECT(DTYPE, PTR_EXPR, MAPPING_HALF_SQUARE); break; \ + case MAPPING_HALF_CUBE: DISPATCH_SPLIT_SELECT(DTYPE, PTR_EXPR, MAPPING_HALF_CUBE); break; \ + default: TORCH_CHECK(false, "unreachable mapping_mode"); \ + } \ + } else { \ + switch (mapping_mode) { \ + case MAPPING_NONE: \ + case MAPPING_TRUNC8: DISPATCH_SPLIT_BLOCK(DTYPE, PTR_EXPR, MAPPING_NONE); break; \ + case MAPPING_POWER: DISPATCH_SPLIT_BLOCK(DTYPE, PTR_EXPR, MAPPING_POWER); break; \ + case MAPPING_LOG: DISPATCH_SPLIT_BLOCK(DTYPE, PTR_EXPR, MAPPING_LOG); break; \ + case MAPPING_ASINH: DISPATCH_SPLIT_BLOCK(DTYPE, PTR_EXPR, MAPPING_ASINH); break; \ + case MAPPING_LOG1P: DISPATCH_SPLIT_BLOCK(DTYPE, PTR_EXPR, MAPPING_LOG1P); break; \ + case MAPPING_ERF: DISPATCH_SPLIT_BLOCK(DTYPE, PTR_EXPR, MAPPING_ERF); break; \ + case MAPPING_TANH: DISPATCH_SPLIT_BLOCK(DTYPE, PTR_EXPR, MAPPING_TANH); break; \ + case MAPPING_SUBTRACT: DISPATCH_SPLIT_BLOCK(DTYPE, PTR_EXPR, MAPPING_SUBTRACT); break; \ + case MAPPING_EXP_STRETCH: DISPATCH_SPLIT_BLOCK(DTYPE, PTR_EXPR, MAPPING_EXP_STRETCH); break; \ + case MAPPING_SHIFT_POW2: DISPATCH_SPLIT_BLOCK(DTYPE, PTR_EXPR, MAPPING_SHIFT_POW2); break; \ + case MAPPING_SHIFT_POW3: DISPATCH_SPLIT_BLOCK(DTYPE, PTR_EXPR, MAPPING_SHIFT_POW3); break; \ + case MAPPING_LINEAR_STEEP: DISPATCH_SPLIT_BLOCK(DTYPE, PTR_EXPR, MAPPING_LINEAR_STEEP); break; \ + case MAPPING_HALF_SQUARE: DISPATCH_SPLIT_BLOCK(DTYPE, PTR_EXPR, MAPPING_HALF_SQUARE); break; \ + case MAPPING_HALF_CUBE: DISPATCH_SPLIT_BLOCK(DTYPE, PTR_EXPR, MAPPING_HALF_CUBE); break; \ + default: TORCH_CHECK(false, "unreachable mapping_mode"); \ + } \ + } \ + } while (0) + + if (x.scalar_type() == at::ScalarType::BFloat16) { + DISPATCH_MODE(__nv_bfloat16, + reinterpret_cast<__nv_bfloat16*>(x.data_ptr())); + } else if (x.scalar_type() == at::ScalarType::Float) { + 
DISPATCH_MODE(float, x.data_ptr()); + } else { + TORCH_CHECK(false, "topk_output_adaptive_workspace: unsupported dtype ", + x.scalar_type()); + } + + #undef DISPATCH_MODE + #undef DISPATCH_SPLIT_SELECT + #undef DISPATCH_PART_SELECT + #undef LAUNCH_TOP30_SELECT + #undef DISPATCH_SPLIT_BLOCK + #undef DISPATCH_PART_BLOCK + #undef LAUNCH_TOP30_BLOCK + + const auto rc = cudaGetLastError(); + TORCH_CHECK(rc == cudaSuccess, + "topk_output_adaptive_workspace launch failed: ", + ::cudaGetErrorString(rc)); + } + + // ============================================================================= + // Legacy entry point — allocates workspace internally and forwards. + // + // NOTE: this path performs at::empty allocations and is therefore NOT a + // reference for latency benchmarks. New callers should use + // topk_output_adaptive_workspace with preallocated workspace. + // ============================================================================= + void topk_output_adaptive( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const at::Tensor& sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages, + const int64_t mapping_mode, + const double mapping_power) + { + // Workspace big enough for the largest split this kernel may pick (32). 
+ constexpr int64_t kMaxSplit = 32; + const int64_t ws_elems = eff_batch_size * kMaxSplit * kLocalK_Top30; + + auto opts_i32 = at::TensorOptions().device(x.device()).dtype(at::kInt); + at::Tensor partial_keys = at::empty({ws_elems}, opts_i32); + at::Tensor partial_indices = at::empty({ws_elems}, opts_i32); + at::Tensor done_counter = at::empty({eff_batch_size}, opts_i32); + + topk_output_adaptive_workspace( + x, dense_kv_indptr, sparse_kv_indptr, dense_kv_indices, + sparse_kv_indices, partial_keys, partial_indices, done_counter, + eff_batch_size, topk_val, reserved_bos, reserved_eos, + max_num_pages, mapping_mode, mapping_power, + /*forced_splits=*/-1, + /*forced_partition=*/-1, + /*local_mode=*/LOCAL_SELECT32_SORT32); + } + + +// ============================================================================= +// Mid-K (K in {64, 128, 256, 512}) adaptive split entry point. +// +// Separate from topk_output_adaptive_workspace so the K=30 production path +// stays untouched. workspace tensors must be sized for +// eff_batch_size * SPLITS * LOCAL_K +// where LOCAL_K is the smallest power of two >= topk_val (max 512), and +// SPLITS is forced_splits if > 0, else 1. +// +// Dispatch contract: +// topk_val < 64 or > 512 → TORCH_CHECK failure (use the K=30 path or fused). +// forced_splits encoding: +// <= 0 : default policy (currently split=1; sweep will inform a heuristic). +// 1/2/4/8/16/32 : forced split, must satisfy midk_split_supported(). 
+// =============================================================================
+// Host-side launcher for TopKMidK_RandomSplit_SelectK_Kernel.
+//
+// Validates the inputs, resolves LOCAL_K from topk_val and the split count
+// from forced_splits (or the default policy in the body), checks the
+// workspace sizes when split > 1, and launches one CTA per (row, split) on
+// the current CUDA stream. Only MAPPING_NONE / MAPPING_TRUNC8 are accepted
+// and both run the MAPPING_NONE instantiation. done_counter is not cleared
+// here; the in-body comment states the kernel resets it itself.
+//
+// NOTE(review): static_cast / data_ptr / reinterpret_cast uses and the
+// <<<...>>> launch configuration in this copy of the patch appear without
+// their angle-bracket contents (extraction damage); left untouched here.
+void topk_output_adaptive_workspace_midk(
+    const at::Tensor& x,
+    const at::Tensor& dense_kv_indptr,
+    const at::Tensor& sparse_kv_indptr,
+    const at::Tensor& dense_kv_indices,
+    at::Tensor& sparse_kv_indices,
+    at::Tensor& partial_keys,
+    at::Tensor& partial_indices,
+    at::Tensor& done_counter,
+    const int64_t eff_batch_size,
+    const int64_t topk_val,
+    const int64_t reserved_bos,
+    const int64_t reserved_eos,
+    const int64_t max_num_pages,
+    const int64_t mapping_mode,
+    const double mapping_power,
+    const int64_t forced_splits)
+{
+    CHECK_CUDA(x);
+    CHECK_CUDA(dense_kv_indptr);
+    CHECK_CUDA(sparse_kv_indptr);
+    CHECK_CUDA(dense_kv_indices);
+    CHECK_CUDA(sparse_kv_indices);
+
+    TORCH_CHECK(topk_val >= 64 && topk_val <= 512,
+                "topk_output_adaptive_workspace_midk: topk_val=", topk_val,
+                " out of range [64, 512]. Use topk_output_adaptive_workspace "
+                "for K<=32 or topk_output_sglang_fused for K>512.");
+    TORCH_CHECK(eff_batch_size >= 1, "eff_batch_size must be >= 1");
+    TORCH_CHECK(max_num_pages >= 1, "max_num_pages must be >= 1");
+
+    // LOCAL_K is one of 64 / 128 / 256 / 512 after the range check above.
+    const int local_k = midk_local_k_from_topk(static_cast(topk_val));
+    TORCH_CHECK(local_k > 0, "unreachable: midk_local_k_from_topk failed for K=", topk_val);
+
+    // Mid-K mappings: NONE / TRUNC8 only for now (kept template count low).
+    // POWER/LOG/etc. are easy to add later once we measure their value.
+    TORCH_CHECK(mapping_mode == MAPPING_NONE || mapping_mode == MAPPING_TRUNC8,
+                "topk_output_adaptive_workspace_midk: mapping_mode=",
+                mapping_mode, " not yet supported (use NONE or TRUNC8).");
+
+    int split;
+    if (forced_splits > 0) {
+        split = static_cast(forced_splits);
+        TORCH_CHECK(midk_split_supported(local_k, split),
+                    "topk_output_adaptive_workspace_midk: split=", split,
+                    " not supported for LOCAL_K=", local_k,
+                    " (would need ", split * local_k, " merge candidates, max 4096).");
+    } else {
+        // Sweep-driven default. From bench_results/midk_best_adaptive_p50.csv:
+        //
+        // pages <= 65536 : adaptive loses every cell vs fused on p50 — but a
+        // user calling this entry point explicitly is asking
+        // for adaptive anyway, so use split=1 (smallest gap).
+        // pages > 65536 : fused unsupported (smem ceiling). Best splits:
+        // K=64 → 16
+        // K=128 → 16
+        // K=256 → 2
+        // K=512 → 4
+        //
+        // forced_splits > 0 still overrides this, e.g. for benchmarking.
+        if (max_num_pages > 65536) {
+            switch (local_k) {
+                case 64: split = 16; break;
+                case 128: split = 16; break;
+                case 256: split = 2; break;
+                case 512: split = 4; break;
+                default: split = 1;
+            }
+        } else {
+            split = 1;
+        }
+    }
+
+    CHECK_CUDA(partial_keys);
+    CHECK_CUDA(partial_indices);
+    CHECK_CUDA(done_counter);
+    TORCH_CHECK(partial_keys.dtype() == at::kInt, "partial_keys must be int32");
+    TORCH_CHECK(partial_indices.dtype() == at::kInt, "partial_indices must be int32");
+    TORCH_CHECK(done_counter.dtype() == at::kInt, "done_counter must be int32");
+
+    // Workspace sizes are only enforced when more than one CTA works on a
+    // row; the SPLITS == 1 kernel path writes its output directly.
+    if (split > 1) {
+        TORCH_CHECK(done_counter.numel() >= eff_batch_size,
+                    "done_counter[", done_counter.numel(),
+                    "] too small for eff_batch_size=", eff_batch_size);
+        const int64_t need = eff_batch_size * static_cast(split) * local_k;
+        TORCH_CHECK(partial_keys.numel() >= need,
+                    "partial_keys too small: ", partial_keys.numel(), " < ", need);
+        TORCH_CHECK(partial_indices.numel() >= need,
+                    "partial_indices too small: ", partial_indices.numel(), " < ", need);
+    }
+
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+    const float mp = static_cast(mapping_power);
+
+    // No cudaMemsetAsync — kernel self-resets done_counter (see midk kernel
+    // and dispatcher comment for topk_output_adaptive_workspace).
+    // Consequence: done_counter must already be all-zero when this is called.
+
+    uint32_t* part_keys_ptr =
+        reinterpret_cast(partial_keys.data_ptr());
+    int32_t* part_idx_ptr = partial_indices.data_ptr();
+    int32_t* done_ptr = done_counter.data_ptr();
+
+    // One CTA per (row, split): grid.x = rows, grid.y = splits.
+    dim3 grid(static_cast(eff_batch_size),
+              static_cast(split));
+
+    // NT scales inversely with SPLITS so the per-CTA scan loop has roughly
+    // the same iteration count regardless of split count. Mirror the K=30
+    // kSelCfg ladder. With NT=128 at SPLITS=1, a 65k-page row would force
+    // 512 iters/thread/pass — way slower than the ~64 iters fused achieves
+    // with NT=1024 single-CTA. Match fused throughput at split=1.
+    //
+    // SPLITS=1 : NT=1024 (chunk = full row)
+    // SPLITS=2 : NT=512
+    // SPLITS=4 : NT=256
+    // SPLITS=8 : NT=128
+    // SPLITS=16 : NT=128
+    // SPLITS=32 : NT=128
+
+    // Launches one fully specialised kernel instance; the partition mode is
+    // fixed to PART_CONTIGUOUS for the Mid-K path.
+    #define LAUNCH_MIDK(DTYPE, PTR_EXPR, MODE_VAL, LOCAL_K_VAL, SPLITS_VAL, NT_VAL) \
+        do { \
+            auto* fn = &TopKMidK_RandomSplit_SelectK_Kernel< \
+                DTYPE, MODE_VAL, LOCAL_K_VAL, SPLITS_VAL, NT_VAL, PART_CONTIGUOUS>; \
+            fn<<>>( \
+                PTR_EXPR, \
+                dense_kv_indptr.data_ptr(), \
+                sparse_kv_indptr.data_ptr(), \
+                dense_kv_indices.data_ptr(), \
+                sparse_kv_indices.data_ptr(), \
+                part_keys_ptr, part_idx_ptr, done_ptr, \
+                static_cast(topk_val), \
+                static_cast(reserved_bos), \
+                static_cast(reserved_eos), \
+                mp); \
+        } while (0)
+
+    // Maps the runtime split to a compile-time (SPLITS, NT) pair. The 16- and
+    // 32-way cases are only launched when LOCAL_K * SPLITS stays within the
+    // 4096-candidate merge cap; otherwise they fail with TORCH_CHECK.
+    #define DISPATCH_SPLIT_MIDK(DTYPE, PTR_EXPR, MODE_VAL, LOCAL_K_VAL) \
+        do { \
+            switch (split) { \
+                case 1: LAUNCH_MIDK(DTYPE, PTR_EXPR, MODE_VAL, LOCAL_K_VAL, 1, 1024); break; \
+                case 2: LAUNCH_MIDK(DTYPE, PTR_EXPR, MODE_VAL, LOCAL_K_VAL, 2, 512); break; \
+                case 4: LAUNCH_MIDK(DTYPE, PTR_EXPR, MODE_VAL, LOCAL_K_VAL, 4, 256); break; \
+                case 8: LAUNCH_MIDK(DTYPE, PTR_EXPR, MODE_VAL, LOCAL_K_VAL, 8, 128); break; \
+                case 16: \
+                    if constexpr ((LOCAL_K_VAL) * 16 <= 4096) \
+                        LAUNCH_MIDK(DTYPE, PTR_EXPR, MODE_VAL, LOCAL_K_VAL, 16, 128); \
+                    else \
+                        TORCH_CHECK(false, "midk: split=16 unsupported for LOCAL_K=", LOCAL_K_VAL);\
+                    break; \
+                case 32: \
+                    if constexpr ((LOCAL_K_VAL) * 32 <= 4096) \
+                        LAUNCH_MIDK(DTYPE, PTR_EXPR, MODE_VAL, LOCAL_K_VAL, 32, 128); \
+                    else \
+                        TORCH_CHECK(false, "midk: split=32 unsupported for LOCAL_K=", LOCAL_K_VAL);\
+                    break; \
+                default: TORCH_CHECK(false, "midk: unsupported split=", split); \
+            } \
+        } while (0)
+
+    // Maps the runtime local_k to a compile-time LOCAL_K.
+    #define DISPATCH_LK_MIDK(DTYPE, PTR_EXPR, MODE_VAL) \
+        do { \
+            switch (local_k) { \
+                case 64: DISPATCH_SPLIT_MIDK(DTYPE, PTR_EXPR, MODE_VAL, 64); break; \
+                case 128: DISPATCH_SPLIT_MIDK(DTYPE, PTR_EXPR, MODE_VAL, 128); break; \
+                case 256: DISPATCH_SPLIT_MIDK(DTYPE, PTR_EXPR, MODE_VAL, 256); break; \
+                case 512: DISPATCH_SPLIT_MIDK(DTYPE, PTR_EXPR, MODE_VAL, 512); break; \
+                default: TORCH_CHECK(false, "midk: unreachable LOCAL_K=", local_k); \
+            } \
+        } while (0)
+
+    #define DISPATCH_MIDK(DTYPE, PTR_EXPR) \
+        do { \
+            /* MAPPING_TRUNC8 aliases MAPPING_NONE; same template instantiation. */ \
+            DISPATCH_LK_MIDK(DTYPE, PTR_EXPR, MAPPING_NONE); \
+        } while (0)
+
+    // Only bf16 and fp32 score tensors are supported.
+    if (x.scalar_type() == at::ScalarType::BFloat16) {
+        DISPATCH_MIDK(__nv_bfloat16,
+                      reinterpret_cast<__nv_bfloat16*>(x.data_ptr()));
+    } else if (x.scalar_type() == at::ScalarType::Float) {
+        DISPATCH_MIDK(float, x.data_ptr());
+    } else {
+        TORCH_CHECK(false, "topk_output_adaptive_workspace_midk: unsupported dtype ",
+                    x.scalar_type());
+    }
+
+    #undef DISPATCH_MIDK
+    #undef DISPATCH_LK_MIDK
+    #undef DISPATCH_SPLIT_MIDK
+    #undef LAUNCH_MIDK
+
+    // Surface launch-configuration errors from the kernel launch above.
+    const auto rc = cudaGetLastError();
+    TORCH_CHECK(rc == cudaSuccess,
+                "topk_output_adaptive_workspace_midk launch failed: ",
+                ::cudaGetErrorString(rc));
+}
diff --git a/csrc/topk_sglang_ori.cu b/csrc/topk_sglang_ori.cu
new file mode 100644
index 00000000..55a99b21
--- /dev/null
+++ b/csrc/topk_sglang_ori.cu
@@ -0,0 +1,619 @@
+/**
+ * @NOTE: This file is adapted from
+ * https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_v32/topk_selector.py
+ * We:
+ * 1. adapt from tilelang to pure cuda
+ * 2. optimize the performance a little
+ * 3.
fix the potential illegal memory access + */ + #include + #include + #include + #include + #include + #include + #include + #include + + #include + #include + #include + + namespace { + + // NOTE: TopK is a compile-time constant here because shared-memory + // allocations inside the transform kernels depend on it. We drop it to + // 30 to match the vortex benchmark's --topk-val 30 configuration. The + // transform kernels (decode/prefill/prefill_ragged) still carry a manual + // unroll that assumes TopK==2048; that code path is unreachable from the + // bench (we only invoke fast_topk_interface), so the corresponding + // static_asserts have been removed below. + constexpr int TopK = 30; + constexpr int kThreadsPerBlock = 1024; + + #ifdef USE_ROCM + // On ROCm, the per-workgroup LDS budget depends on the target arch, so we inject a + // per-arch value from `setup_rocm.py` via `-DSGL_TOPK_DYNAMIC_SMEM_BYTES=...`. + #ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES + constexpr size_t kSmem = static_cast(SGL_TOPK_DYNAMIC_SMEM_BYTES); + #else + constexpr size_t kSmem = 48 * 1024; // bytes + #endif + #else + // Reduced from 128KB to 32KB to improve occupancy. + // Each radix pass needs at most ~TopK candidates in the threshold bin, + // so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient. + constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes) + #endif + + struct FastTopKParams { + const float* __restrict__ input; // [B, input_stride] + const int32_t* __restrict__ row_starts; // [B] + int32_t* __restrict__ indices; // [B, TopK] + int32_t* __restrict__ lengths; // [B] + int64_t input_stride; + }; + + // when length <= TopK, we can directly write the indices + __device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) { + const auto tid = threadIdx.x; + for (int i = tid; i < TopK; i += kThreadsPerBlock) { + indice[i] = (i < length) ? 
i : -1; + } + } + + // keep the first `length` entries, set others to -1 + __device__ void naive_topk_transform( + const float* __restrict__ score, + int32_t length, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + dst_page_table[i] = (i < length) ? src_page_table[i] : -1; + } + } + + // keep the first `length` entries, set others to -1 + __device__ void naive_topk_transform_ragged( + const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + topk_indices_ragged[i] = (i < length) ? static_cast(i) + offset : -1; + } + } + + __device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); + return static_cast(key >> 8); + } + + __device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? 
~bits : (bits | 0x80000000u); + } + + __device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int row_start, int length) { + // An optimized topk kernel copied from tilelang kernel + // We assume length > TopK here, or it will crash + int topk = TopK; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin_id; + alignas(128) __shared__ int s_num_input[2]; + + auto& s_histogram = s_histogram_buf[0]; + // allocate for two rounds + extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // stage 1: 8bit coarse histogram + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(input[idx + row_start]); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { + #pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += s_histogram_buf[k][tx + j]; + } + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = static_cast(convert_to_uint8(input[idx + row_start])); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } 
else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = input[idx + row_start]; + const auto bin = static_cast(convert_to_uint8(raw_input)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + /// NOTE: (dark) fuse the histogram computation here + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[0][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> 24) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // stage 2: refine with 8bit radix passes + #pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int s_last_remain; + const auto r_idx = round % 2; + + // clip here to prevent overflow + const auto _raw_num_input = s_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? 
_raw_num_input : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(input[idx + row_start]) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = input[idx + row_start]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + index[TopK - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + /// NOTE: (dark) fuse the histogram computation here + s_input_idx[r_idx ^ 1][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } + } + + __global__ __launch_bounds__(kThreadsPerBlock) // topk + void topk_kernel(const FastTopKParams params) { + const auto& [input, row_starts, indices, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto row_start = row_starts == nullptr ? 
0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto indice = indices + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_cuda(score, indice, length); + } else { + return fast_topk_cuda_tl(score, indice, row_start, length); + } + } + + // Decode variant: one thread block per row (bid = blockIdx.x), row_start fixed at 0. + // Rows with length <= TopK copy the leading entries of the row's source page + // table (padding with -1); longer rows select indices with fast_topk_cuda_tl + // and map them through src_page_table into dst_page_table. + __global__ __launch_bounds__(kThreadsPerBlock) // decode + void topk_transform_decode_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride) { + const auto& [input, _1, _2, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = 0; + const auto length = lengths[bid]; + const auto src_page_entry = src_page_table + bid * src_stride; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // Gather src[s_indices] into dst with a block-strided loop bounded by + // TopK. The upstream two-step manual unroll assumed + // TopK == 2 * kThreadsPerBlock (2048); with TopK = 30 it indexed + // s_indices and dst_page_entry out of bounds. 
+ for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + dst_page_entry[i] = src_page_entry[s_indices[i]]; + } + } + } + + // Prefill variant: each block first finds the cu_seqlens_q interval that + // contains its bid to pick the source page-table row, then behaves like the + // decode variant but honours the optional per-row row_starts. + __global__ __launch_bounds__(kThreadsPerBlock) // prefill + void topk_transform_prefill_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride, + const int32_t* __restrict__ cu_seqlens_q, + const int64_t prefill_bs) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto length = lengths[bid]; + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + + /// NOTE: prefill bs is usually small, we can just use a simple loop here + /// We ensure that last cu_seqlens is equal to number of blocks launched + __shared__ const int32_t* s_src_page_entry; + if (C10_LIKELY(prefill_bs <= kThreadsPerBlock)) { + if (tid < prefill_bs) { + if (bid >= cu_seqlens_q[tid] && bid < cu_seqlens_q[tid + 1]) { + s_src_page_entry = src_page_table + tid * src_stride; + } + } + } else { + for (int64_t i = tid; i < prefill_bs; i += kThreadsPerBlock) { + if (bid >= cu_seqlens_q[i] && bid < cu_seqlens_q[i + 1]) { + s_src_page_entry = src_page_table + i * src_stride; + } + } + } + __syncthreads(); + const auto src_page_entry = s_src_page_entry; + + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // Gather src[s_indices] into dst with a block-strided loop bounded by + // TopK. The upstream two-step manual unroll assumed + // TopK == 2 * kThreadsPerBlock (2048); with TopK = 30 it indexed + // s_indices and dst_page_entry out of bounds. 
+ for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + dst_page_entry[i] = src_page_entry[s_indices[i]]; + } + } + } + + // Ragged prefill variant: no page table; each selected index is written as + // index + topk_indices_offset[bid]. + __global__ __launch_bounds__(kThreadsPerBlock) // prefill, ragged kv + void topk_transform_prefill_ragged_kernel( + const FastTopKParams params, + int32_t* __restrict__ topk_indices_ragged, + const int32_t* __restrict__ topk_indices_offset) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto dst_indices_entry = topk_indices_ragged + bid * TopK; + const auto score = input + bid * input_stride; + const auto offset = topk_indices_offset[bid]; + + if (length <= TopK) { + return naive_topk_transform_ragged(score, length, dst_indices_entry, offset); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // Write s_indices + offset with a block-strided loop bounded by TopK. + // The upstream two-step manual unroll assumed + // TopK == 2 * kThreadsPerBlock (2048); with TopK = 30 it indexed + // s_indices and dst_indices_entry out of bounds. 
+ for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + dst_indices_entry[i] = s_indices[i] + offset; + } + } + } + + auto get_params( + const at::Tensor& score, + const at::Tensor& lengths, + std::optional row_starts_opt = std::nullopt, + std::optional indices_opt = std::nullopt) -> FastTopKParams { + const auto B = score.size(0); + TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1); + if (row_starts_opt.has_value()) { + const auto& row_starts = row_starts_opt.value(); + TORCH_CHECK(row_starts.dim() == 1); + TORCH_CHECK(row_starts.size(0) == B); + } + TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous()); + TORCH_CHECK(lengths.size(0) == B); + int32_t* indices_data_ptr = nullptr; + if (indices_opt.has_value()) { + const auto& indices = indices_opt.value(); + TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous()); + TORCH_CHECK(indices.size(0) == B); + TORCH_CHECK(indices.size(1) == TopK); + indices_data_ptr = indices.data_ptr(); + } + + return FastTopKParams{ + .input = score.data_ptr(), + .row_starts = row_starts_opt.has_value() ? row_starts_opt->data_ptr() : nullptr, + .indices = indices_data_ptr, + .lengths = lengths.data_ptr(), + .input_stride = score.stride(0), + }; + } + + template + void setup_kernel_smem_once() { + [[maybe_unused]] + static const auto result = [] { + #ifdef USE_ROCM + // hipify will turn cudaFuncSetAttribute -> hipFuncSetAttribute. On ROCm, + // hipFuncSetAttribute expects `const void*` and hipcc does not accept passing + // a function pointer directly, so cast explicitly. + return ::cudaFuncSetAttribute( + reinterpret_cast(f), ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); + #else + // CUDA: keep original behavior (no cast needed). 
+ return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); + #endif + }(); + TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); + } + + } // namespace + + // The public interface functions below collide by name with identically + // named symbols in topk_sglang.cu. Wrap them in `sglang_ori` so both + // translation units can be linked into the same vortex_torch_C extension. + namespace sglang_ori { + + #ifndef CHECK_CUDA + #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + #endif + + void fast_topk_interface( + const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths, std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(indices); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + CHECK_CUDA(lengths); + const auto params = get_params(score, lengths, row_starts_opt, indices); + const auto B = score.size(0); + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + setup_kernel_smem_once(); + topk_kernel<<>>(params); + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); + } + + void fast_topk_transform_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& dst_page_table, + const at::Tensor& src_page_table, + const at::Tensor& cu_seqlens_q, + std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(dst_page_table); + CHECK_CUDA(src_page_table); + CHECK_CUDA(cu_seqlens_q); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous()); + TORCH_CHECK(src_page_table.dim() == 2 && 
src_page_table.stride(1) == 1); + TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_q.is_contiguous()); + const auto prefill_bs = cu_seqlens_q.size(0) - 1; + TORCH_CHECK(dst_page_table.size(0) == B); + TORCH_CHECK(dst_page_table.size(1) == TopK); + TORCH_CHECK(src_page_table.size(0) == prefill_bs); + TORCH_CHECK(prefill_bs <= B); // prefill_bs should be smaller than expanded bs + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + const auto src_stride = src_page_table.stride(0); + + // dispatch to decode or prefill + // extend and draft extend: row_starts_opt is not null, invokes the prefill kernel + // decode: row_starts_opt is null, invokes the decode kernel + // target verify: row_starts_opt is null, invokes the prefill kernel + const auto is_decode = !row_starts_opt.has_value() && prefill_bs == B; + if (is_decode) { + setup_kernel_smem_once(); + topk_transform_decode_kernel<<>>( + params, dst_page_table.data_ptr(), src_page_table.data_ptr(), src_stride); + } else { + setup_kernel_smem_once(); + topk_transform_prefill_kernel<<>>( + params, + dst_page_table.data_ptr(), + src_page_table.data_ptr(), + src_stride, + cu_seqlens_q.data_ptr(), + prefill_bs); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); + } + + void fast_topk_transform_ragged_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& topk_indices_ragged, + const at::Tensor& topk_indices_offset, + std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(topk_indices_ragged); + CHECK_CUDA(topk_indices_offset); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(topk_indices_ragged.dim() == 2 && 
topk_indices_ragged.is_contiguous()); + TORCH_CHECK(topk_indices_offset.dim() == 1); + + TORCH_CHECK(topk_indices_ragged.size(0) == B); + TORCH_CHECK(topk_indices_ragged.size(1) == TopK); + TORCH_CHECK(topk_indices_offset.size(0) == B); + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + + setup_kernel_smem_once(); + topk_transform_prefill_ragged_kernel<<>>( + params, topk_indices_ragged.data_ptr(), topk_indices_offset.data_ptr()); + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); + } + + } // namespace sglang_ori + +// ====================================================================== +// Thin vortex_torch_C adapter: accepts the same CSR-ish inputs as +// topk_output_sglang so bench_topk.py can treat the original SGLang kernel +// as an alternate baseline. The ori kernel has TopK baked in as a compile- +// time constant; this build sets it to 30 to match --topk-val 30. 
+// ====================================================================== +void topk_output_sglang_ori( + const at::Tensor& x, // [total_dense, 1, 1] or [total_dense], bf16/fp32 + const at::Tensor& dense_kv_indptr, // int32 [eff_bs + 1] (unused — synthetic bench rows are uniform) + at::Tensor& indices_out, // int32 [eff_bs, TopK] + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages) +{ + TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor"); + TORCH_CHECK(dense_kv_indptr.is_cuda(), "dense_kv_indptr must be a CUDA tensor"); + TORCH_CHECK(indices_out.is_cuda(), "indices_out must be a CUDA tensor"); + TORCH_CHECK(indices_out.scalar_type() == at::ScalarType::Int, + "indices_out must be int32"); + TORCH_CHECK(topk_val == static_cast(30), + "topk_output_sglang_ori: this build of the ori kernel hard-codes TopK=30; " + "rebuild topk_sglang_ori.cu with a different TopK if you need another value. " + "Got topk_val=", topk_val); + TORCH_CHECK(indices_out.dim() == 2 + && indices_out.size(0) == eff_batch_size + && indices_out.size(1) == 30, + "indices_out must be [eff_batch_size, 30]"); + + // ori kernel requires fp32 [B, stride] scores. Caller typically passes + // the bf16 score tensor; we materialize an fp32 view once per call. 
+ at::Tensor score_f32; + if (x.scalar_type() == at::ScalarType::Float) { + score_f32 = x.contiguous().view({eff_batch_size, max_num_pages}); + } else if (x.scalar_type() == at::ScalarType::BFloat16) { + score_f32 = x.to(at::kFloat).contiguous().view({eff_batch_size, max_num_pages}); + } else { + TORCH_CHECK(false, "topk_output_sglang_ori: unsupported dtype ", x.scalar_type()); + } + + auto opts_i32 = at::TensorOptions().dtype(at::kInt).device(x.device()); + const int32_t usable_len = + static_cast(max_num_pages - reserved_bos - reserved_eos); + at::Tensor lengths = at::full({eff_batch_size}, usable_len, opts_i32); + at::Tensor row_starts = at::full({eff_batch_size}, + static_cast(reserved_bos), opts_i32); + + sglang_ori::fast_topk_interface( + score_f32, indices_out, lengths, + std::optional(row_starts)); +} \ No newline at end of file diff --git a/csrc/topk_sglang_profile.cu b/csrc/topk_sglang_profile.cu new file mode 100644 index 00000000..af6763a2 --- /dev/null +++ b/csrc/topk_sglang_profile.cu @@ -0,0 +1,605 @@ +/** + * TopK profiling kernels: histogram collection, stage-1-only timing, + * and diagnostic counter collection. + * + * Separated from topk_sglang.cu to reduce template instantiation + * pressure on CUDA shared memory resources. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace { + + +constexpr int TopK = 2048; +constexpr int kThreadsPerBlock = 1024; + +#ifdef USE_ROCM +// On ROCm, the per-workgroup LDS budget depends on the target arch, so we inject a +// per-arch value from `setup_rocm.py` via `-DSGL_TOPK_DYNAMIC_SMEM_BYTES=...`. +#ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES +constexpr size_t kSmem = static_cast(SGL_TOPK_DYNAMIC_SMEM_BYTES); +#else +constexpr size_t kSmem = 48 * 1024; // bytes +#endif +#else +// Reduced from 128KB to 32KB to improve occupancy. 
+// Each radix pass needs at most ~TopK candidates in the threshold bin, +// so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient. +constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes) +#endif + +struct FastTopKParams { + const float* __restrict__ input; // [B, input_stride] + const int32_t* __restrict__ row_starts; // [B] + int32_t* __restrict__ indices; // [B, TopK] + int32_t* __restrict__ lengths; // [B] + int64_t input_stride; +}; + +// when length <= TopK, we can directly write the indices +__device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) { + const auto tid = threadIdx.x; + for (int i = tid; i < TopK; i += kThreadsPerBlock) { + indice[i] = (i < length) ? i : -1; + } +} + +// keep the first `length` entries, set others to -1 +__device__ void naive_topk_transform( + const float* __restrict__ score, + int32_t length, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + dst_page_table[i] = (i < length) ? src_page_table[i] : -1; + } +} + +// keep the first `length` entries, set others to -1 +__device__ void naive_topk_transform_ragged( + const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + topk_indices_ragged[i] = (i < length) ? static_cast(i) + offset : -1; + } +} + +__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); + return static_cast(key >> 8); +} + +__device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? 
~bits : (bits | 0x80000000u); +} + +// Mirror of convert_to_uint8_dense in topk_sglang.cu so that the +// profile kernel (topk_profile_histogram / topk_profile_counters) +// reports accurate thr_bin / thr_size / abv_bins / pg/bin for +// MAPPING_DENSE_MANT. Keep in sync with the production kernel. +__device__ __forceinline__ auto convert_to_uint8_dense(float x) -> uint8_t { + const uint32_t bits = __float_as_uint(x); + const uint32_t key = (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); + return static_cast((key >> 16) & 0xFFu); +} + +template +__device__ __forceinline__ float vortex_to_float(T x); +template <> +__device__ __forceinline__ float vortex_to_float(float x) { return x; } +template <> +__device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) { + return __bfloat162float(x); +} + + +constexpr int VORTEX_MAX_TOPK = 2048; + +// Diagnostic counters written by the profiling kernel. These kernels are +// NOT used for latency measurements — they intentionally add global-memory +// writes that distort timings. Latency is measured against the clean +// production kernels in topk_sglang.cu. 
+constexpr int COUNTER_THRESHOLD_BIN = 0; +constexpr int COUNTER_NUM_ABOVE = 1; +constexpr int COUNTER_NUM_EQUAL = 2; +constexpr int COUNTER_REMAINING_K = 3; +constexpr int COUNTER_REFINE_ROUNDS = 4; +constexpr int COUNTER_STAGE2_INPUT = 5; +constexpr int NUM_TOPK_COUNTERS = 6; + +#include "topk_mapping.cuh" + +template +void setup_kernel_smem_once() { + [[maybe_unused]] + static const auto result = [] { +#ifdef USE_ROCM + return ::cudaFuncSetAttribute( + reinterpret_cast(f), ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); +#else + return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); +#endif + }(); + TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); +} + +// ====================================================================== +// Profiling variant of fast_topk_clean_fused that writes diagnostic +// counters at the end of Stage 1 and at each Stage 2 early-exit. +// Shape / semantics identical to the production kernel, with one extra +// global-memory write pass at the end of each stage. Do not use for +// latency measurements. 
+// ====================================================================== +// Block-cooperative two-stage radix-histogram selection over `length` +// scores read from input[row_start + idx]. Selected positions (values of +// idx, not offset by row_start) are written to index[]; output slots are +// claimed with atomics, so their order is not deterministic. +// `counters` may be null; when set, thread 0 records the COUNTER_* slots. +// NOTE(review): the stride loops assume blockDim.x == BLOCK_SIZE (1024) +// and the threshold search assumes length > target_k (the wrapper kernel +// returns early otherwise) — confirm for any new caller. +template +__device__ void fast_topk_profile( + const ScoreT* __restrict__ input, + int* __restrict__ index, + int row_start, + int length, + int target_k, + const TopKMappingParams mapping, + int* __restrict__ counters) // [NUM_TOPK_COUNTERS] +{ + int topk = target_k; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int p_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int p_counter; + alignas(128) __shared__ int p_threshold_bin_id; + alignas(128) __shared__ int p_num_input[2]; + + // NOTE(review): these two tables are never written in this function + // before being handed to compute_stage1_bin — presumably unused now + // that the LUT / quantile modes are retired; confirm. + __shared__ uint8_t s_mapping_lut[256]; + __shared__ float s_mapping_quantiles[256]; + + auto& p_histogram = p_histogram_buf[0]; + extern __shared__ int p_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // Historical note: in the production kernel MAPPING_DENSE_MANT used to + // bypass apply_transform and take a mantissa-heavy fp32 bit slice for + // the Stage-1 bucket. + // MAPPING_DENSE_MANT / MAPPING_LUT_CDF / MAPPING_QUANTILE have been + // retired; every mode uses the standard fp16 bucket, so + // use_dense_bucket is hard-wired to false and the dense branches + // below are never taken. 
+ const bool use_dense_bucket = false; + + if (tx < RADIX + 1) p_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const float raw = vortex_to_float(input[idx + row_start]); + int bin; + if (use_dense_bucket) { + const float clamped = apply_transform(raw, mapping); // fmaxf(x, pivot) + bin = static_cast(convert_to_uint8_dense(clamped)); + } else { + bin = static_cast(compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles)); + } + ::atomicAdd(&p_histogram[bin], 1); + } + __syncthreads(); + + // Suffix sum over the 256 bins, ping-ponging between the two buffers; + // after the 8 passes p_histogram[b] = number of elements with bin >= b. + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = p_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += p_histogram_buf[k][tx + j]; + } + p_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && p_histogram[tx] > topk && p_histogram[tx + 1] <= topk) { + p_threshold_bin_id = tx; + p_num_input[0] = 0; + p_counter = 0; + } + __syncthreads(); + + const int threshold_bin_0 = p_threshold_bin_id; + // p_histogram now holds suffix sums, so the threshold bin's own + // population is the difference of two adjacent entries (read here, + // before the histogram is reset for Stage 2). 
+ const int threshold_bin_size = p_histogram[threshold_bin_0] - p_histogram[threshold_bin_0 + 1]; + topk -= p_histogram[threshold_bin_0 + 1]; + + if (tx == 0 && counters) { + counters[COUNTER_THRESHOLD_BIN] = threshold_bin_0; + counters[COUNTER_NUM_EQUAL] = threshold_bin_size; + counters[COUNTER_REMAINING_K] = topk; + } + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const float raw = vortex_to_float(input[idx + row_start]); + int bin; + if (use_dense_bucket) { + const float clamped = apply_transform(raw, mapping); + bin = static_cast(convert_to_uint8_dense(clamped)); + } else { + bin = static_cast(compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles)); + } + if (bin > threshold_bin_0) { + const auto pos = ::atomicAdd(&p_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + if (tx == 0 
+ && counters) { + counters[COUNTER_NUM_ABOVE] = p_counter; + counters[COUNTER_REFINE_ROUNDS] = 0; + counters[COUNTER_STAGE2_INPUT] = 0; + } + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) p_histogram[tx] = 0; + __syncthreads(); + + const int sub_bin_offset_start = use_dense_bucket ? 8 : 24; + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform(raw, mapping); + const auto bin = use_dense_bucket + ? static_cast(convert_to_uint8_dense(remapped)) + : static_cast(compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles)); + if (bin > threshold_bin_0) { + const auto pos = ::atomicAdd(&p_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin_0) { + const auto pos = ::atomicAdd(&p_num_input[0], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + p_input_idx[0][pos] = idx; + const auto b32 = convert_to_uint32(remapped); + const auto sub_bin = (b32 >> sub_bin_offset_start) & 0xFF; + ::atomicAdd(&p_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + if (tx == 0 && counters) { + counters[COUNTER_NUM_ABOVE] = p_counter; + counters[COUNTER_STAGE2_INPUT] = p_num_input[0]; + } + } + + // Stage 2 refinement. With use_dense_bucket hard-wired to false this + // runs up to 4 rounds (offsets 24/16/8/0). The retired + // MAPPING_DENSE_MANT path ran up to 2 rounds (offsets 8/0) because its + // Stage 1 already consumed bits [23:16] of the fp32 key. + const int stage2_offset_start = use_dense_bucket ? 8 : 24; + const int stage2_max_rounds = use_dense_bucket ? 2 : 4; + if (tx == 0 && counters) counters[COUNTER_REFINE_ROUNDS] = stage2_max_rounds; +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + if (round >= stage2_max_rounds) break; + __shared__ int p_last_remain; + const auto r_idx = round % 2; + const auto _raw_num_input = p_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? 
_raw_num_input : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && p_histogram[tx] > topk && p_histogram[tx + 1] <= topk) { + p_threshold_bin_id = tx; + p_num_input[r_idx ^ 1] = 0; + p_last_remain = topk - p_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = p_threshold_bin_id; + topk -= p_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = p_input_idx[r_idx][i]; + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform(raw, mapping); + const auto offset = stage2_offset_start - round * 8; + const auto bin = (convert_to_uint32(remapped) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&p_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + if (tx == 0 && counters) counters[COUNTER_REFINE_ROUNDS] = round + 1; + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) p_histogram[tx] = 0; + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = p_input_idx[r_idx][i]; + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform(raw, mapping); + const auto offset = stage2_offset_start - round * 8; + const auto bin = (convert_to_uint32(remapped) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&p_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == stage2_max_rounds - 1) { + const auto pos = ::atomicAdd(&p_last_remain, -1); + if (pos > 0) { + index[target_k - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&p_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + p_input_idx[r_idx ^ 1][pos] = idx; + const auto b32 = convert_to_uint32(remapped); + const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; + ::atomicAdd(&p_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + +// Wrapper: one block per (batch*head) 
segment. Writes counters per +// segment into a [eff_batch_size, NUM_TOPK_COUNTERS] int32 tensor. +// NOTE(review): a segment with nblk <= topk_val returns before any +// write, so its counters row and its sparse_kv_indices slice keep +// whatever they held on entry — presumably the caller pre-initialises +// both; confirm. +template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKProfileCounters_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ dense_kv_indices, + int* __restrict__ sparse_kv_indices, + int* __restrict__ counters, + const int topk_val, + const int page_reserved_bos, + const int page_reserved_eos, + const TopKMappingParams mapping) +{ + const int bx = blockIdx.x; + + // Skip page_reserved_bos entries at the front and page_reserved_eos + // at the back of this segment's indptr range. + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; + + // s_indices receives positions relative to `start` (row_start is 0). + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_profile( + score_blk, s_indices, 0, nblk, topk_val, mapping, + counters + bx * NUM_TOPK_COUNTERS); + __syncthreads(); + + // Translate the selected positions through dense_kv_indices. + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } +} + +// Histogram-only profiling kernel: builds a 256-bin histogram of the +// remapped bins for each segment. Purely diagnostic — never timed. 
+template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKProfileHistogram_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + int* __restrict__ histograms, // [eff_batch_size, 256] + const int page_reserved_bos, + const int page_reserved_eos, + const TopKMappingParams mapping) +{ + constexpr auto RADIX = 256; + constexpr auto BLOCK_SIZE = kThreadsPerBlock; + __shared__ int s_histogram[RADIX]; + // NOTE(review): never written in this kernel before being passed to + // compute_stage1_bin — presumably unused now that the LUT / quantile + // modes are retired; confirm. + __shared__ uint8_t s_mapping_lut[256]; + __shared__ float s_mapping_quantiles[256]; + + const int bx = blockIdx.x; + const int tx = threadIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + + if (tx < RADIX) s_histogram[tx] = 0; + __syncthreads(); + + // MAPPING_DENSE_MANT / MAPPING_LUT_CDF / MAPPING_QUANTILE retired. + const bool use_dense_bucket = false; + if (nblk > 0) { + const ScoreT* __restrict__ score_blk = score + start; + for (int i = tx; i < nblk; i += BLOCK_SIZE) { + const float raw = vortex_to_float(score_blk[i]); + int bin; + if (use_dense_bucket) { + const float clamped = apply_transform(raw, mapping); + bin = static_cast(convert_to_uint8_dense(clamped)); + } else { + bin = static_cast(compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles)); + } + ::atomicAdd(&s_histogram[bin], 1); + } + } + __syncthreads(); + + // An empty segment (nblk <= 0) still writes its all-zero row. + int* __restrict__ out = histograms + bx * RADIX; + if (tx < RADIX) out[tx] = s_histogram[tx]; +} + +} // namespace + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + +// Packs the host-side mapping options into TopKMappingParams. The LUT / +// quantile tensors are optional; when supplied they must be contiguous +// 256-entry CUDA tensors because only their raw device pointer is kept. +static TopKMappingParams build_mapping_params( + int64_t mapping_mode, double mapping_power, + std::optional& mapping_lut, + std::optional& mapping_quantiles) +{ + TopKMappingParams m{}; + m.mode = static_cast(mapping_mode); + m.power_exp = static_cast(mapping_power); + m.lut = nullptr; + m.quantiles = nullptr; + if (mapping_lut.has_value()) { + const auto& lut = mapping_lut.value(); + 
TORCH_CHECK(lut.is_cuda(), "mapping_lut must be a CUDA tensor"); + TORCH_CHECK(lut.dim() == 1 && lut.size(0) == 256 && lut.scalar_type() == at::ScalarType::Byte, + "mapping_lut must be a 1D uint8 tensor of size 256"); + // A strided 1-D view would be read with the wrong element spacing. + TORCH_CHECK(lut.is_contiguous(), "mapping_lut must be contiguous"); + m.lut = lut.data_ptr(); + } + if (mapping_quantiles.has_value()) { + const auto& q = mapping_quantiles.value(); + TORCH_CHECK(q.is_cuda(), "mapping_quantiles must be a CUDA tensor"); + TORCH_CHECK(q.dim() == 1 && q.size(0) == 256 && q.scalar_type() == at::ScalarType::Float, + "mapping_quantiles must be a 1D float32 tensor of size 256"); + TORCH_CHECK(q.is_contiguous(), "mapping_quantiles must be contiguous"); + m.quantiles = q.data_ptr(); + } + return m; +} + +// ====================================================================== +// Profiling: per-segment 256-bin histograms of Stage 1 remapped bins. +// Launches one block per segment on the current CUDA stream; x must be +// bfloat16 or float32. +// ====================================================================== +void topk_profile_histogram( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + at::Tensor& histograms, + const int64_t eff_batch_size, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t mapping_mode, + const double mapping_power, + std::optional mapping_lut, + std::optional mapping_quantiles) +{ + CHECK_CUDA(x); + CHECK_CUDA(dense_kv_indptr); + CHECK_CUDA(histograms); + TORCH_CHECK(histograms.dim() == 2 && histograms.size(0) == eff_batch_size + && histograms.size(1) == 256, + "histograms must be [eff_batch_size, 256]"); + TORCH_CHECK(histograms.scalar_type() == at::ScalarType::Int, + "histograms must be int32"); + + auto mapping = build_mapping_params(mapping_mode, mapping_power, mapping_lut, mapping_quantiles); + + // A zero-block grid is an invalid launch configuration; with no + // segments there is nothing to profile. + if (eff_batch_size == 0) return; + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + if (x.scalar_type() == at::ScalarType::BFloat16) { + TopKProfileHistogram_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + histograms.data_ptr(), + reserved_bos, reserved_eos, mapping); + } else if 
(x.scalar_type() == at::ScalarType::Float) { + TopKProfileHistogram_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + histograms.data_ptr(), + reserved_bos, reserved_eos, mapping); + } else { + TORCH_CHECK(false, "topk_profile_histogram: unsupported dtype ", x.scalar_type()); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_profile_histogram kernel failed: ", ::cudaGetErrorString(result)); +} + +// ====================================================================== +// Profiling: full pipeline + per-segment diagnostic counters. +// Adds extra global-memory writes — never use for latency measurement. +// max_num_pages is accepted but not read by this function. +// ====================================================================== +void topk_profile_counters( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const at::Tensor& sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + at::Tensor& counters, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages, + const int64_t mapping_mode, + const double mapping_power, + std::optional mapping_lut, + std::optional mapping_quantiles) +{ + TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, + "topk_profile_counters: topk_val (", topk_val, + ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + CHECK_CUDA(x); + CHECK_CUDA(dense_kv_indptr); + CHECK_CUDA(sparse_kv_indptr); + CHECK_CUDA(dense_kv_indices); + CHECK_CUDA(sparse_kv_indices); + CHECK_CUDA(counters); + TORCH_CHECK(counters.dim() == 2 && counters.size(0) == eff_batch_size + && counters.size(1) == NUM_TOPK_COUNTERS, + "counters must be [eff_batch_size, ", NUM_TOPK_COUNTERS, "]"); + TORCH_CHECK(counters.scalar_type() == at::ScalarType::Int, "counters must be int32"); + + auto mapping = build_mapping_params(mapping_mode, mapping_power, mapping_lut, mapping_quantiles); + + // A zero-block grid is an invalid launch configuration; with no + // segments there is nothing to profile. + if (eff_batch_size == 0) return; + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + 
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + if (x.scalar_type() == at::ScalarType::BFloat16) { + setup_kernel_smem_once, kSmem>(); + TopKProfileCounters_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + counters.data_ptr(), + topk_val, reserved_bos, reserved_eos, mapping); + } else if (x.scalar_type() == at::ScalarType::Float) { + setup_kernel_smem_once, kSmem>(); + TopKProfileCounters_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + counters.data_ptr(), + topk_val, reserved_bos, reserved_eos, mapping); + } else { + TORCH_CHECK(false, "topk_profile_counters: unsupported dtype ", x.scalar_type()); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_profile_counters kernel failed: ", ::cudaGetErrorString(result)); +} diff --git a/csrc/utils_sglang.cu b/csrc/utils_sglang.cu index 1420e9ec..a7ddf42f 100644 --- a/csrc/utils_sglang.cu +++ b/csrc/utils_sglang.cu @@ -82,16 +82,20 @@ const int page_reserved_eos #pragma unroll for (int i = 0; i < ITEM_PER_THREAD; ++i){ - int16_t w = ((tx_offset + i) < eff_batch_size) ? - (dense_kv_indptr[tx_offset+i+1] - dense_kv_indptr[tx_offset+i] + int16_t w = ((tx_offset + i) < eff_batch_size) ? + (dense_kv_indptr[tx_offset+i+1] - dense_kv_indptr[tx_offset+i] - page_reserved_bos - page_reserved_eos): 0; - - page_count[i] = (w > topk_val) ? w : 0; + + // See note in Sgl_Decode_Plan_Workload_Kernel: we used to skip slots + // where w ≤ topk_val, but downstream (GeMV / topK / histogram) has no + // matching skip, so it read uninitialised scores and silently + // produced all-zero results. Emit workloads for every slot with w > 0. + page_count[i] = (w > 0) ? 
w : 0; chunked_page_count_prefix_sum[i + 1] = int((page_count[i] + max_chunk_size - 1) / max_chunk_size); } BlockScanInt(temp.scan_int).InclusiveSum(chunked_page_count_prefix_sum, chunked_page_count_prefix_sum); - + if (tx == 1023){ *winfo_num_workload = chunked_page_count_prefix_sum[ITEM_PER_THREAD]; *winfo_chunk_size = max_chunk_size; @@ -218,16 +222,22 @@ const int page_reserved_eos #pragma unroll for (int i = 0; i < ITEM_PER_THREAD; ++i){ - int16_t w = ((tx_offset + i) < eff_batch_size) ? - (dense_kv_indptr[tx_offset+i+1] - dense_kv_indptr[tx_offset+i] + int16_t w = ((tx_offset + i) < eff_batch_size) ? + (dense_kv_indptr[tx_offset+i+1] - dense_kv_indptr[tx_offset+i] - page_reserved_bos - page_reserved_eos): 0; - - page_count[i] = (w > topk_val) ? w : 0; + + // Previously: (w > topk_val) ? w : 0, which skipped scoring on slots + // where the dense page count is already ≤ topk_val. Downstream (GeMV, + // topK, histogram profiling) does NOT have a matching skip, so it + // would read uninitialised scores and silently return garbage (all + // zero). Emit workloads for every slot with w > 0 so scoring always + // runs; when w ≤ topk_val the topK degenerates to "select all w". + page_count[i] = (w > 0) ? w : 0; chunked_page_count_prefix_sum[i + 1] = int((page_count[i] + max_chunk_size - 1) / max_chunk_size); } BlockScanInt(temp.scan_int).InclusiveSum(chunked_page_count_prefix_sum, chunked_page_count_prefix_sum); - + if (tx == 1023){ *winfo_num_workloads = chunked_page_count_prefix_sum[ITEM_PER_THREAD]; *winfo_chunk_size = max_chunk_size; diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 00000000..d14650fd --- /dev/null +++ b/examples/README.md @@ -0,0 +1,399 @@ +# Vortex Torch Examples + +End-to-end accuracy evaluation and profiling pipelines for Vortex sparse attention on top of the SGLang inference engine. 
The scripts in this directory evaluate different TopK kernel variants, mapping functions, KV-cache quantization settings, and external sparse-attention backends on math reasoning benchmarks. + +--- + +## Mapping Functions Reference + +The TopK Stage-1 radix histogram uses 256 uint8 bins. A **mapping function** transforms raw attention scores before binning to improve bucket uniformity and reduce tail latency. Set via `--topk-mapping-mode`. + +| Mode | Name | Formula | Requires Calibration | Hyperparameter (`--topk-mapping-power`) | +|------|------|---------|---------------------|-----------------------------------------| +| 0 | None | FP16 bit-pattern bucketing | No | — | +| 1 | LUT CDF | `lut[original_bin]` (CDF equalization) | Yes (`--topk-mapping-lut-path`) | — | +| 2 | Quantile | Binary search over 256 float thresholds | Yes (`--topk-mapping-quantiles-path`) | — | +| 3 | Power | `sign(x) * \|x\|^p` | No | `p` (exponent, default 0.5) | +| 4 | Log | `sign(x) * log(\|x\| + 1)` | No | — | +| 5 | Index Cache | Reuse top-k indices from a preceding layer | No | — (see `--index-cache-shared-layers`) | +| 6 | Asinh | `asinh(beta * x)` | No | `beta` (default 0.5) | +| 7 | Log1p | `sign(x) * log1p(alpha * \|x\|)` | No | `alpha` (default 0.5) | +| 8 | Trunc8 | BF16 upper-8-bit bucketing | No | — | + +Modes 1 and 2 require an offline calibration step (see `calibrate_topk.py` in `benchmarks/`). Modes 3, 6, and 7 accept a tunable hyperparameter via `--topk-mapping-power`. + +--- + +## Python Scripts + +### `verify_algo.py` — End-to-End Accuracy Benchmark + +The primary evaluation script. Loads AMC 2023 math problems from `amc23.jsonl`, runs inference via the SGLang engine with Vortex sparse attention, and scores answers using `lighteval`'s extractive-match metric. Reports `mean@N`, `pass@N`, throughput, and memory access cost. 
+ +**Usage:** + +```bash +python verify_algo.py [OPTIONS] +``` + +**CLI Arguments:** + +| Argument | Default | Description | +|----------|---------|-------------| +| `--trials` | 2 | Number of trials (each prompt repeated N times) | +| `--topk-val` | 30 | Number of top-k pages to select per segment | +| `--page-size` | 16 | Tokens per KV-cache page | +| `--vortex-module-name` | `gqa_block_sparse_attention` | Sparse attention algorithm module | +| `--model-name` | `Qwen/Qwen3-1.7B` | HuggingFace model identifier | +| `-f`, `--full-attention` | off | Disable sparse attention (full-attention baseline) | +| `--mem` | 0.8 | Static GPU memory fraction for SGLang | +| `--kv-cache-dtype` | `auto` | KV cache dtype: `auto`, `fp8_e5m2`, `fp8_e4m3`, `int8` | +| `--topk-type` | `naive` | TopK kernel: `naive` (CUB radix sort) or `sglang` (fast two-stage radix) | +| `--topk-mapping-mode` | 0 | Mapping function for Stage-1 binning (see table above) | +| `--topk-mapping-power` | 0.5 | Hyperparameter for modes 3/6/7 | +| `--topk-mapping-lut-path` | None | `.npy` uint8[256] LUT for mode 1 | +| `--topk-mapping-quantiles-path` | None | `.npy` float32[256] quantiles for mode 2 | +| `--index-cache-shared-layers` | None | Layer IDs that skip the indexer and reuse a previous layer's indices | + +**Fixed engine settings:** `attention_backend=flashinfer`, `vortex_max_seq_lens=12288`, layer 0 skipped, `reserved_bos=1`, `reserved_eos=2`. Sampling: `temperature=0.6`, `top_p=0.95`, `top_k=20`, `max_new_tokens=8192`. + +**Index cache note (mode 5):** When `--topk-mapping-mode 5` is set without `--index-cache-shared-layers`, the script defaults to even layers `[2, 4, 6, ..., 26]` and internally resets the mapping mode to 0 while passing the shared-layer list to the engine. 
+ +**Example — full-attention baseline:** + +```bash +python verify_algo.py --full-attention --trials 8 --mem 0.7 +``` + +**Example — sglang TopK with power mapping:** + +```bash +python verify_algo.py \ + --topk-type sglang \ + --topk-mapping-mode 3 \ + --topk-mapping-power 0.25 \ + --trials 8 --topk-val 30 --mem 0.7 +``` + +**Example — sglang TopK with calibrated LUT:** + +```bash +python verify_algo.py \ + --topk-type sglang \ + --topk-mapping-mode 1 \ + --topk-mapping-lut-path calibration/lut.npy \ + --trials 8 --topk-val 30 --mem 0.7 +``` + +--- + +### `verify_aim24.py` — AIME 2024 Throughput Test (Legacy) + +A standalone throughput script that loads AIME 2024 from HuggingFace (`HuggingFaceH4/aime_2024`), builds chat prompts using the Qwen3 tokenizer with `enable_thinking=True`, and repeats each prompt 8 times. Outputs a JSONL file with generation results and timing metadata. Does **not** compute accuracy metrics. + +**Usage:** + +```bash +python verify_aim24.py +``` + +All settings are hard-coded (no CLI arguments): + +| Setting | Value | +|---------|-------| +| Model | `Qwen/Qwen3-0.6B` | +| Page size | 16 | +| Selected pages | 29 | +| Max sequence length | 20480 | +| Module | `block_sparse_attention` | +| Memory fraction | 0.9 | +| Max new tokens | 16384 | +| CUDA graph | Enabled | + +--- + +## Shell Scripts + +All shell scripts set `CUDA_VISIBLE_DEVICES` and save timestamped logs to `results/`. + +### `verify_algo.sh` — Baseline TopK Comparison (Naive vs SGLang) + +Runs `verify_algo.py` with `block_sparse_attention` comparing the `naive` and `sglang` TopK kernels. Each configuration is repeated `REPEAT_COUNT` times (default 3, overridable via environment variable). + +```bash +REPEAT_COUNT=5 bash verify_algo.sh +``` + +### `verify_algo_topk.sh` — Naive vs SGLang Comparison + +Similar to `verify_algo.sh` but simpler: runs `naive` TopK and `sglang` TopK back-to-back for `block_sparse_attention`, each with 8 trials. 
+ +### `verify_algo_quant.sh` — INT8 KV-Cache Quantization + +Tests sparse attention with `--kv-cache-dtype int8` to measure accuracy under quantized KV caches. + +```bash +bash verify_algo_quant.sh +``` + +### `verify_sparse_backends.sh` — External Sparse Attention Backends + +Evaluates three external sparse-attention algorithms integrated via the Vortex flow interface: + +- `nsa` (Native Sparse Attention) +- `fsa` (Flash Sparse Attention) +- `flash_moba` (Flash MoBA) + +```bash +bash verify_sparse_backends.sh +``` + +### `verify_algo_topk_mapping.sh` — Full Mapping Mode Sweep + +Comprehensive sweep across all mapping modes: + +1. **Baseline:** `naive` TopK, mode 0 +2. **Calibration:** runs `calibrate_topk.py` to generate `lut.npy` and `quantiles.npy` (skipped if files exist) +3. **Mode 1** (LUT CDF) and **Mode 2** (Quantile) with calibrated tables +4. **Modes 0, 3, 4** (no calibration needed) — Power mode uses `--topk-mapping-power 0.5` +5. **Mode 6** (Asinh) — sweeps `beta` in `[0.5, 1.0, 2.0]` +6. **Mode 7** (Log1p) — sweeps `alpha` in `[0.5, 1.0, 2.0]` + +```bash +export CUDA_VISIBLE_DEVICES=0 +bash verify_algo_topk_mapping.sh +``` + +### `verify_algo_topk_mapping_new.sh` — Parametric Mapping Sweep (Modes 3, 6, 7) + +Focused hyperparameter sweep for the three parametric modes, preceded by an auto-tuning step: + +| Mode | Parameter | Sweep Values | +|------|-----------|-------------| +| 3 (Power) | `p` | 0.1, 0.25, 0.75, 0.9 | +| 6 (Asinh) | `beta` | 0.1, 0.5, 1.0, 2.0, 4.0 | +| 7 (Log1p) | `alpha` | 0.1, 0.5, 0.75, 1.0, 2.0, 4.0, 8.0 | + +Requires `calibration/raw_histograms.npy` for the auto-tune step. + +```bash +export CUDA_VISIBLE_DEVICES=5 +bash verify_algo_topk_mapping_new.sh +``` + +### `verify_algo_topk_mapping_indexcache.sh` — Index Cache (Mode 5) + +Tests the index-cache optimization where even-numbered layers `[2, 4, 6, ..., 26]` reuse top-k indices from the nearest preceding full layer, skipping their indexer entirely. 
+ +```bash +bash verify_algo_topk_mapping_indexcache.sh +``` + +### `run_topk_benchmark.sh` — Unified TopK Benchmark Pipeline + +The most comprehensive benchmarking script. Three-step pipeline: + +1. **Calibrate** — collect real-data histograms + LUT/quantile tables +2. **Kernel bench** — latency + histogram profiling across batch sizes, sequence lengths, and distributions, followed by distribution analysis plots and auto-tuning +3. **E2E accuracy** — full-attention baseline plus every mapping mode + +```bash +bash run_topk_benchmark.sh --gpu 5 --trials 8 --model-name Qwen/Qwen3-1.7B +``` + +| Option | Default | Description | +|--------|---------|-------------| +| `--model-name` | `Qwen/Qwen3-1.7B` | HuggingFace model | +| `--topk-val` | 30 | Top-k pages | +| `--trials` | 8 | E2E trial count | +| `--mem` | 0.7 | GPU memory fraction | +| `--gpu` | 5 | CUDA device | +| `--algo` | `block_sparse_attention` | Sparse attention algorithm | +| `--skip-calibrate` | off | Reuse existing calibration | +| `--skip-kernel` | off | Skip kernel-level latency step | +| `--skip-e2e` | off | Skip E2E accuracy step | + +### `run_distribution_analysis.sh` — Bucket Distribution Profiling (All Modes) + +Three-step pipeline to analyze how each mapping mode affects the 256-bin bucket distribution: + +1. **Calibrate** — collect real-data histograms (skippable with `--real-histograms`) +2. **Bench** — histogram profiling with modes 0–8 on `bucket_uniform` and `normal` distributions +3. **Analyze** — generate comparison plots and CSV bucket count tables + +```bash +bash run_distribution_analysis.sh --gpu 5 +bash run_distribution_analysis.sh --gpu 5 --real-histograms /path/to/raw_histograms.npy +``` + +### `run_distribution_analysis_new.sh` — Bucket Distribution Profiling (Modes 3, 6, 7) + +Same pipeline as above but focused on parametric modes only, with an additional auto-tune step: + +1. **Calibrate** (or skip with existing histograms) +2. 
**Auto-tune** — sweep hyperparameters on synthetic data +3. **Bench** — histogram profiling for modes 3, 6, 7, 8 +4. **Analyze** — comparison plots + tables + +```bash +bash run_distribution_analysis_new.sh --gpu 5 +``` + +--- + +## Benchmarks Directory Scripts + +The `benchmarks/` directory contains standalone profiling and analysis tools used by the shell pipelines above. These can also be run independently. + +### `calibrate_topk.py` — Offline Calibration + +Runs the SGLang engine on real prompts from `amc23.jsonl` with histogram collection enabled. Produces three files: + +- `lut.npy` — uint8[256] CDF-equalized LUT for mode 1 +- `quantiles.npy` — float32[256] quantile breakpoints for mode 2 +- `raw_histograms.npy` — raw per-sample 256-bin histograms + +```bash +python benchmarks/calibrate_topk.py \ + --model-name Qwen/Qwen3-1.7B \ + --topk-val 30 --mem 0.7 \ + --output-dir calibration/ +``` + +### `bench_topk.py` — Kernel-Level Latency Benchmark + +Benchmarks `topk_output` (naive/CUB) and `topk_output_sglang` (fast radix) across configurable sweeps of batch size, sequence length, TopK value, KV heads, and score distributions. Optionally collects 256-bin histogram statistics. + +```bash +python benchmarks/bench_topk.py \ + --batch-sizes 4 8 16 \ + --seq-lens 2048 4096 8192 \ + --topk-vals 30 \ + --num-kv-heads 2 \ + --distributions normal lognormal uniform bucket_uniform \ + --histogram \ + --repeat 100 \ + --output-json results.json +``` + +### `autotune_topk_mapping.py` — Hyperparameter Auto-Tuning + +Sweeps hyperparameters for parametric mapping modes (3, 6, 7) using the `topk_profile_histogram` kernel on synthetic data. Ranks configurations by resolution rate, Gini coefficient, max/mean ratio, and nonzero bins. 
+ +```bash +python benchmarks/autotune_topk_mapping.py \ + --topk-val 30 --batch-size 4 --seq-len 4096 --num-kv-heads 2 \ + --real-histograms calibration/raw_histograms.npy \ + --output-json autotune_results.json +``` + +### `analyze_topk_distribution.py` — Visualization and Analysis + +Loads profiling data and generates: +- Per-segment 256-bin bar charts +- Heatmaps (segments x bins, log-scale) +- Before/after LUT mapping comparisons +- Mode comparison grouped bar charts (Gini + max/mean) +- Distribution comparison plots across data sources +- CSV bucket count tables + +```bash +python benchmarks/analyze_topk_distribution.py \ + --bench-json bench_distribution.json \ + --real-histograms calibration/raw_histograms.npy \ + --output-dir plots/ +``` + +### `profile_topk_distribution.py` — Offline Table Generation + +Computes LUT and quantile tables from pre-collected histograms or raw scores without running a model. Outputs a single `.npz` archive. + +```bash +python benchmarks/profile_topk_distribution.py \ + --histograms-input raw_histograms.npy \ + --output mapping_tables.npz +``` + +### `greedy_layer_search.py` — Index Cache Layer Selection + +Greedy forward-selection of layers whose indexer can be skipped (index cache). Iteratively adds layers to the shared set as long as accuracy stays above `--threshold` times the baseline. 
+ +```bash +cd examples && python ../benchmarks/greedy_layer_search.py \ + --model-name Qwen/Qwen3-1.7B \ + --topk-val 30 \ + --threshold 0.95 \ + --trials 1 \ + --num-layers 28 \ + --mem 0.7 +``` + +--- + +## Data Files + +| File | Description | +|------|-------------| +| `amc23.jsonl` | AMC 2023 math problems with `prompt` and `answer` fields, used by `verify_algo.py` and `calibrate_topk.py` | + +--- + +## Output Structure + +Results are saved under `results/` in timestamped directories: + +``` +results/ +├── dist_analysis_YYYYMMDD_HHMMSS/ +│ ├── step1_calibrate.log +│ ├── step2_autotune.log / step2_bench.log +│ ├── step3_bench.log / step3_analyze.log +│ ├── step4_analyze.log +│ ├── autotune_results.json +│ ├── bench_distribution.json +│ ├── distribution_comparison_*.png +│ ├── bucket_counts_*.csv +│ └── calibration/ +│ ├── lut.npy +│ ├── quantiles.npy +│ └── raw_histograms.npy +├── topk_benchmark_YYYYMMDD_HHMMSS/ +│ ├── kernel_latency.json +│ ├── e2e/ +│ │ ├── full_attention_baseline.log +│ │ ├── sglang_mode0_none.log +│ │ └── ... +│ └── calibration/ +└── *.log (individual run logs) +``` + +--- + +## Quick Start: Typical Workflow + +```bash +export CUDA_VISIBLE_DEVICES=0 + +# 1. Calibrate to generate LUT + quantile tables +python benchmarks/calibrate_topk.py \ + --model-name Qwen/Qwen3-1.7B --topk-val 30 --mem 0.7 \ + --output-dir examples/calibration/ + +# 2. Run full-attention baseline +python examples/verify_algo.py --full-attention --trials 8 --mem 0.7 + +# 3. Evaluate sparse attention with different mapping modes +python examples/verify_algo.py \ + --topk-type sglang --topk-mapping-mode 0 --trials 8 --mem 0.7 + +python examples/verify_algo.py \ + --topk-type sglang --topk-mapping-mode 3 --topk-mapping-power 0.25 \ + --trials 8 --mem 0.7 + +python examples/verify_algo.py \ + --topk-type sglang --topk-mapping-mode 6 --topk-mapping-power 1.0 \ + --trials 8 --mem 0.7 + +# 4. 
Or run the full pipeline in one shot +bash examples/run_topk_benchmark.sh --gpu 0 --trials 8 +``` diff --git a/examples/ablation_remap_function_block_size.sh b/examples/ablation_remap_function_block_size.sh new file mode 100644 index 00000000..4bf5bba5 --- /dev/null +++ b/examples/ablation_remap_function_block_size.sh @@ -0,0 +1,279 @@ +#!/usr/bin/env bash +# ============================================================ +# Ablation: Remap function vs. block (page) size +# +# Sweeps BLOCK_SIZE and, for every cell, runs the full +# calibrate -> autotune -> remap-bench +# pipeline so the per-mode hyperparameter is freshly chosen by +# autotune for that block size (NOT hardcoded). +# +# Mapping modes under test (matches the screenshot): +# 0 none — unmapped baseline +# 3 power — p +# 6 asinh — beta +# 7 log1p — alpha +# 9 erf — alpha +# 10 tanh — alpha +# 11 subtract — pivot +# 13 exp_stretch — alpha +# 15 shift_pow2 — pivot +# 16 shift_pow3 — pivot +# 17 linear_steep — k +# +# Output: +# results/ablation_remap_block_size_/ +# bs/{autotune_results.json, remap_bench.json, step{1,2,3}_*.log} +# sweep_index.json +# selected_hparams.txt — per-cell screenshot-style summary +# ============================================================ +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +# ── Defaults ────────────────────────────────────────────────── +GPU_ID=4 +MODEL_NAME="Qwen/Qwen3-1.7B" +TOPK_VAL=2048 +MEM=0.7 +MAX_TOTAL_TOKENS=64768 +MIN_FREE_DISK_GB=22 +ALGO="block_sparse_attention" +BATCH_SIZE=4 +NUM_KV_HEADS=8 +DISTRIBUTIONS="normal bucket_uniform" +MAPPING_MODES="0 3 6 7 9 10 11 13 15 16 17" +MAPPING_HPARAM=0.5 +REPEAT=100 +WARMUP=20 +BLOCK_SIZES="1 2 4 8 16 32 64" +REAL_HISTOGRAMS="" + +# ── Parse arguments ─────────────────────────────────────────── +while [[ $# -gt 0 ]]; do + case "$1" in + --gpu) GPU_ID="$2"; shift 2 ;; + --model-name) MODEL_NAME="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; 
shift 2 ;; + --mem) MEM="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; + --min-free-disk-gb) MIN_FREE_DISK_GB="$2"; shift 2 ;; + --algo) ALGO="$2"; shift 2 ;; + --batch-size) BATCH_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --distributions) DISTRIBUTIONS="$2"; shift 2 ;; + --modes) MAPPING_MODES="$2"; shift 2 ;; + --mapping-hparam) MAPPING_HPARAM="$2"; shift 2 ;; + --repeat) REPEAT="$2"; shift 2 ;; + --warmup) WARMUP="$2"; shift 2 ;; + --block-sizes) BLOCK_SIZES="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" +export SGL_ENABLE_JIT_DEEPGEMM="${SGL_ENABLE_JIT_DEEPGEMM:-true}" +if [ -z "${DG_JIT_NVCC_COMPILER:-}" ]; then + if [ -x /usr/local/cuda/bin/nvcc ]; then + export CUDA_HOME="${CUDA_HOME:-/usr/local/cuda}" + export PATH="${CUDA_HOME}/bin:${PATH}" + export DG_JIT_NVCC_COMPILER="${CUDA_HOME}/bin/nvcc" + elif command -v nvcc >/dev/null 2>&1; then + export DG_JIT_NVCC_COMPILER="$(command -v nvcc)" + fi +fi + +RESULTS_DIR="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" +SWEEP_DIR="${RESULTS_DIR}/ablation_remap_block_size_${MODEL_SLUG}_${TIMESTAMP}" +mkdir -p "${SWEEP_DIR}" + +# Per-model calibration cache (reused across block_size cells: page size +# does not change the per-segment score distribution). 
+CALIBRATION_BASE="/var/tmp/zhuominc/vortex_torch/calibration" +MODEL_TAG="$(echo "${MODEL_NAME##*/}" | sed 's/^Q/q/')" +DEFAULT_REAL_HIST="${CALIBRATION_BASE}/raw_histograms_${MODEL_TAG}.npy" +mkdir -p "${CALIBRATION_BASE}" + +if [ -z "${REAL_HISTOGRAMS}" ] && [ -f "${DEFAULT_REAL_HIST}" ]; then + REAL_HISTOGRAMS="${DEFAULT_REAL_HIST}" +fi + +echo "============================================================" +echo "Ablation: remap function vs block_size" +echo " Model: ${MODEL_NAME}" +echo " Algorithm: ${ALGO}" +echo " TopK: ${TOPK_VAL}" +echo " Block sizes: ${BLOCK_SIZES}" +echo " Batch size: ${BATCH_SIZE}" +echo " KV heads: ${NUM_KV_HEADS}" +echo " Distributions: ${DISTRIBUTIONS}" +echo " Mapping modes: ${MAPPING_MODES}" +echo " GPU: ${GPU_ID}" +echo " Sweep dir: ${SWEEP_DIR}" +echo "============================================================" + +# ── Step 0: Calibrate once for this model ────────────────────── +if [ -n "${REAL_HISTOGRAMS}" ]; then + echo ">>> Step 0: SKIPPED calibration (using ${REAL_HISTOGRAMS})" + REAL_HIST_PATH="${REAL_HISTOGRAMS}" +else + echo ">>> Step 0: Calibrating ${MODEL_NAME} for raw_histograms.npy" + STAGING_DIR="${CALIBRATION_BASE}/staging_${MODEL_TAG}_${TIMESTAMP}" + mkdir -p "${STAGING_DIR}" + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --page-size 1 \ + --mem "${MEM}" \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ + --min-free-disk-gb "${MIN_FREE_DISK_GB}" \ + --vortex-module-name "${ALGO}" \ + --output-dir "${STAGING_DIR}" \ + 2>&1 | tee "${SWEEP_DIR}/step0_calibrate.log" + mv -f "${STAGING_DIR}/raw_histograms.npy" "${DEFAULT_REAL_HIST}" + REAL_HIST_PATH="${DEFAULT_REAL_HIST}" + echo ">>> Step 0: Done. 
raw_histograms -> ${REAL_HIST_PATH}" +fi + +# ── Sweep ────────────────────────────────────────────────────── +SWEEP_INDEX="${SWEEP_DIR}/sweep_index.json" +echo "{" > "${SWEEP_INDEX}" +echo " \"axis_name\": \"block_size\"," >> "${SWEEP_INDEX}" +echo " \"axis_type\": \"kernel\"," >> "${SWEEP_INDEX}" +echo " \"model_name\": \"${MODEL_NAME}\"," >> "${SWEEP_INDEX}" +echo " \"topk_val\": ${TOPK_VAL}," >> "${SWEEP_INDEX}" +echo " \"mapping_modes\": [${MAPPING_MODES// /, }]," >> "${SWEEP_INDEX}" +echo " \"cells\": [" >> "${SWEEP_INDEX}" + +FIRST_CELL=1 +for BLOCK_SIZE in ${BLOCK_SIZES}; do + # Pick a seq_len that satisfies pages/seg > topk_val + 3 reserved. + MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * BLOCK_SIZE )) + SEQ_LEN=${MIN_SEQ_LEN} + # Clamp seq_len to a minimum of 8192 (no rounding is performed) for stable timing. + if [ "${SEQ_LEN}" -lt 8192 ]; then SEQ_LEN=8192; fi + + CELL_DIR="${SWEEP_DIR}/bs${BLOCK_SIZE}" + mkdir -p "${CELL_DIR}" + AUTOTUNE_JSON="${CELL_DIR}/autotune_results.json" + REMAP_JSON="${CELL_DIR}/remap_bench.json" + + echo "" + echo "============================================================" + echo ">>> Cell: block_size=${BLOCK_SIZE} seq_len=${SEQ_LEN}" + echo "============================================================" + + echo ">>> Autotuning hparams for block_size=${BLOCK_SIZE}" + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --batch-size "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-len "${SEQ_LEN}" \ + --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --real-histograms "${REAL_HIST_PATH}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --collect-stats \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${CELL_DIR}/step2_autotune.log" + + echo ">>> Remap bench for block_size=${BLOCK_SIZE}" + PYTHONPATH="${SCRIPT_DIR}/.."
python "${BENCH_DIR}/bench_topk.py" \ + --remap-bench \ + --per-head-bench \ + --batch-sizes "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-lens "${SEQ_LEN}" \ + --topk-vals "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --distributions ${DISTRIBUTIONS} \ + --mapping-modes ${MAPPING_MODES} \ + --mapping-hparam "${MAPPING_HPARAM}" \ + --autotune-json "${AUTOTUNE_JSON}" \ + --real-histograms "${REAL_HIST_PATH}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --output-json "${REMAP_JSON}" \ + 2>&1 | tee "${CELL_DIR}/step3_remap_bench.log" + + if [ "${FIRST_CELL}" -eq 1 ]; then + FIRST_CELL=0 + else + echo " ," >> "${SWEEP_INDEX}" + fi + cat >> "${SWEEP_INDEX}" <> "${SWEEP_INDEX}" +echo "}" >> "${SWEEP_INDEX}" + +# ── Per-cell screenshot-style hparam summary ────────────────── +SELECTED_TXT="${SWEEP_DIR}/selected_hparams.txt" +PYTHONPATH="${SCRIPT_DIR}/.." python3 - "${SWEEP_INDEX}" "${SELECTED_TXT}" <<'PY' +import json, sys, os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "benchmarks")) +try: + from autotune_topk_mapping import MODE_NAMES, PARAM_NAME +except Exception: + MODE_NAMES = {0: "none", 3: "power", 6: "asinh", 7: "log1p", 9: "erf", + 10: "tanh", 11: "subtract", 13: "exp_stretch", + 15: "shift_pow2", 16: "shift_pow3", 17: "linear_steep"} + PARAM_NAME = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha", + 11: "pivot", 13: "alpha", 15: "pivot", 16: "pivot", 17: "k"} + +DISPLAY = {3: "Power", 6: "Asinh", 7: "Log1p", 9: "Erf", 10: "Tanh", + 11: "Subtract", 13: "ExpStretch", 15: "ShiftPow2", + 16: "ShiftPow3", 17: "LinearSteep"} + +idx_path, out_path = sys.argv[1], sys.argv[2] +with open(idx_path) as f: + idx = json.load(f) + +lines = ["== Selected mapping functions (autotuned, block_size sweep) =="] +for cell in idx["cells"]: + with open(cell["autotune_json"]) as f: + results = json.load(f) + best = {} + for r in results: + m = r["mode"] + if m not in best or r["latency_ms"] < best[m]["latency_ms"]: + best[m] = r 
+ parts = [] + for m in sorted(DISPLAY): + if m not in best: + continue + pname = PARAM_NAME.get(m, "p") + pval = best[m].get("param", 0.0) + parts.append(f"{DISPLAY[m]}({pname}={pval})") + lines.append(f"[block_size={cell['axis_value']}] " + " ".join(parts)) + +txt = "\n".join(lines) + "\n" +print(txt) +with open(out_path, "w") as f: + f.write(txt) +PY + +echo "" +echo "============================================================" +echo "Block-size ablation complete." +echo " Sweep dir: ${SWEEP_DIR}" +echo " Per-cell results: ${SWEEP_DIR}/bs/" +echo " Sweep index: ${SWEEP_INDEX}" +echo " Selected hparams: ${SELECTED_TXT}" +echo "Run analyze with:" +echo " python examples/analyze_ablation_remap.py --sweep-dir ${SWEEP_DIR}" +echo "============================================================" diff --git a/examples/ablation_remap_function_model.sh b/examples/ablation_remap_function_model.sh new file mode 100644 index 00000000..0212b83a --- /dev/null +++ b/examples/ablation_remap_function_model.sh @@ -0,0 +1,262 @@ +#!/usr/bin/env bash +# ============================================================ +# Ablation: Remap function vs. model +# +# Sweeps MODEL_NAME across the Qwen3 family. For every model: +# 1. Calibrate (or reuse cached raw_histograms_.npy) +# 2. Autotune the per-mode hparam on that model's histogram +# (NOT hardcoded; freshly tuned per model) +# 3. 
Remap-bench across the autotuned hparams +# +# Mapping modes under test (matches the screenshot): +# 0 none, 3 power, 6 asinh, 7 log1p, 9 erf, 10 tanh, +# 11 subtract, 13 exp_stretch, 15 shift_pow2, 16 shift_pow3, +# 17 linear_steep +# ============================================================ +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +# ── Defaults ────────────────────────────────────────────────── +GPU_ID=4 +MODELS="Qwen/Qwen3-0.6B Qwen/Qwen3-1.7B Qwen/Qwen3-4B Qwen/Qwen3-8B" +TOPK_VAL=2048 +BLOCK_SIZE=1 +MEM=0.7 +MIN_FREE_DISK_GB=22 +ALGO="block_sparse_attention" +BATCH_SIZE=4 +NUM_KV_HEADS=8 +DISTRIBUTIONS="normal bucket_uniform" +MAPPING_MODES="0 3 6 7 9 10 11 13 15 16 17" +MAPPING_HPARAM=0.5 +REPEAT=100 +WARMUP=20 + +# ── Parse arguments ─────────────────────────────────────────── +while [[ $# -gt 0 ]]; do + case "$1" in + --gpu) GPU_ID="$2"; shift 2 ;; + --models) MODELS="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;; + --mem) MEM="$2"; shift 2 ;; + --min-free-disk-gb) MIN_FREE_DISK_GB="$2"; shift 2 ;; + --algo) ALGO="$2"; shift 2 ;; + --batch-size) BATCH_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --distributions) DISTRIBUTIONS="$2"; shift 2 ;; + --modes) MAPPING_MODES="$2"; shift 2 ;; + --mapping-hparam) MAPPING_HPARAM="$2"; shift 2 ;; + --repeat) REPEAT="$2"; shift 2 ;; + --warmup) WARMUP="$2"; shift 2 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" +export SGL_ENABLE_JIT_DEEPGEMM="${SGL_ENABLE_JIT_DEEPGEMM:-true}" +if [ -z "${DG_JIT_NVCC_COMPILER:-}" ]; then + if [ -x /usr/local/cuda/bin/nvcc ]; then + export CUDA_HOME="${CUDA_HOME:-/usr/local/cuda}" + export PATH="${CUDA_HOME}/bin:${PATH}" + export DG_JIT_NVCC_COMPILER="${CUDA_HOME}/bin/nvcc" 
+ elif command -v nvcc >/dev/null 2>&1; then + export DG_JIT_NVCC_COMPILER="$(command -v nvcc)" + fi +fi + +# Per-model max-total-tokens (KV pool cap for calibration). Larger models +# need a smaller cap so they fit at MEM=0.7. Override by passing the env +# var MAX_TOTAL_TOKENS_{tag}=N ({tag} = model tag with '.' and '-' replaced by '_', e.g. MAX_TOTAL_TOKENS_qwen3_1_7B) before invocation. +declare -A MAX_TOTAL_TOKENS_LUT +MAX_TOTAL_TOKENS_LUT["qwen3-0.6B"]=131072 +MAX_TOTAL_TOKENS_LUT["qwen3-1.7B"]=64768 +MAX_TOTAL_TOKENS_LUT["qwen3-4B"]=32768 +MAX_TOTAL_TOKENS_LUT["qwen3-8B"]=16384 + +RESULTS_DIR="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +SWEEP_DIR="${RESULTS_DIR}/ablation_remap_model_${TIMESTAMP}" +mkdir -p "${SWEEP_DIR}" + +CALIBRATION_BASE="/var/tmp/zhuominc/vortex_torch/calibration" +mkdir -p "${CALIBRATION_BASE}" + +echo "============================================================" +echo "Ablation: remap function vs model" +echo " Models: ${MODELS}" +echo " TopK: ${TOPK_VAL}" +echo " Block size: ${BLOCK_SIZE}" +echo " Mapping modes: ${MAPPING_MODES}" +echo " GPU: ${GPU_ID}" +echo " Sweep dir: ${SWEEP_DIR}" +echo "============================================================" + +# ── Sweep ────────────────────────────────────────────────────── +SWEEP_INDEX="${SWEEP_DIR}/sweep_index.json" +{ + echo "{" + echo " \"axis_name\": \"model\"," + echo " \"axis_type\": \"kernel\"," + echo " \"topk_val\": ${TOPK_VAL}," + echo " \"block_size\": ${BLOCK_SIZE}," + echo " \"mapping_modes\": [${MAPPING_MODES// /, }]," + echo " \"cells\": [" +} > "${SWEEP_INDEX}" + +# Pick a single seq_len that satisfies pages/seg > topk_val for all models.
+MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * BLOCK_SIZE )) +SEQ_LEN=${MIN_SEQ_LEN} +if [ "${SEQ_LEN}" -lt 8192 ]; then SEQ_LEN=8192; fi + +FIRST_CELL=1 +for MODEL_NAME in ${MODELS}; do + MODEL_TAG="$(echo "${MODEL_NAME##*/}" | sed 's/^Q/q/')" + MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" + DEFAULT_REAL_HIST="${CALIBRATION_BASE}/raw_histograms_${MODEL_TAG}.npy" + + # Per-model max-total-tokens (override-able via env). + MTT_DEFAULT="${MAX_TOTAL_TOKENS_LUT[${MODEL_TAG}]:-32768}" + ENV_KEY="MAX_TOTAL_TOKENS_$(echo "${MODEL_TAG}" | tr '.-' '__')" + MAX_TOTAL_TOKENS="${!ENV_KEY:-${MTT_DEFAULT}}" + + CELL_DIR="${SWEEP_DIR}/${MODEL_SLUG}" + mkdir -p "${CELL_DIR}" + AUTOTUNE_JSON="${CELL_DIR}/autotune_results.json" + REMAP_JSON="${CELL_DIR}/remap_bench.json" + + echo "" + echo "============================================================" + echo ">>> Cell: model=${MODEL_NAME} (max_total_tokens=${MAX_TOTAL_TOKENS})" + echo "============================================================" + + # Step 1: calibrate (cached per-model) + if [ -f "${DEFAULT_REAL_HIST}" ]; then + echo ">>> Calibration cache hit: ${DEFAULT_REAL_HIST}" + REAL_HIST_PATH="${DEFAULT_REAL_HIST}" + else + echo ">>> Calibrating ${MODEL_NAME}" + STAGING_DIR="${CALIBRATION_BASE}/staging_${MODEL_TAG}_${TIMESTAMP}" + mkdir -p "${STAGING_DIR}" + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --mem "${MEM}" \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ + --min-free-disk-gb "${MIN_FREE_DISK_GB}" \ + --vortex-module-name "${ALGO}" \ + --output-dir "${STAGING_DIR}" \ + 2>&1 | tee "${CELL_DIR}/step1_calibrate.log" + mv -f "${STAGING_DIR}/raw_histograms.npy" "${DEFAULT_REAL_HIST}" + REAL_HIST_PATH="${DEFAULT_REAL_HIST}" + fi + + # Step 2: autotune + echo ">>> Autotuning hparams for ${MODEL_NAME}" + PYTHONPATH="${SCRIPT_DIR}/.." 
python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --batch-size "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-len "${SEQ_LEN}" \ + --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --real-histograms "${REAL_HIST_PATH}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --collect-stats \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${CELL_DIR}/step2_autotune.log" + + # Step 3: remap bench + echo ">>> Remap bench for ${MODEL_NAME}" + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ + --remap-bench \ + --per-head-bench \ + --batch-sizes "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-lens "${SEQ_LEN}" \ + --topk-vals "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --distributions ${DISTRIBUTIONS} \ + --mapping-modes ${MAPPING_MODES} \ + --mapping-hparam "${MAPPING_HPARAM}" \ + --autotune-json "${AUTOTUNE_JSON}" \ + --real-histograms "${REAL_HIST_PATH}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --output-json "${REMAP_JSON}" \ + 2>&1 | tee "${CELL_DIR}/step3_remap_bench.log" + + if [ "${FIRST_CELL}" -eq 1 ]; then + FIRST_CELL=0 + else + echo " ," >> "${SWEEP_INDEX}" + fi + cat >> "${SWEEP_INDEX}" <> "${SWEEP_INDEX}" +echo "}" >> "${SWEEP_INDEX}" + +# ── Per-cell screenshot-style hparam summary ────────────────── +SELECTED_TXT="${SWEEP_DIR}/selected_hparams.txt" +PYTHONPATH="${SCRIPT_DIR}/.." 
python3 - "${SWEEP_INDEX}" "${SELECTED_TXT}" "model" <<'PY' +import json, sys, os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "benchmarks")) +try: + from autotune_topk_mapping import PARAM_NAME +except Exception: + PARAM_NAME = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha", + 11: "pivot", 13: "alpha", 15: "pivot", 16: "pivot", 17: "k"} +DISPLAY = {3: "Power", 6: "Asinh", 7: "Log1p", 9: "Erf", 10: "Tanh", + 11: "Subtract", 13: "ExpStretch", 15: "ShiftPow2", + 16: "ShiftPow3", 17: "LinearSteep"} + +idx_path, out_path, axis_name = sys.argv[1], sys.argv[2], sys.argv[3] +with open(idx_path) as f: + idx = json.load(f) + +lines = [f"== Selected mapping functions (autotuned, {axis_name} sweep) =="] +for cell in idx["cells"]: + with open(cell["autotune_json"]) as f: + results = json.load(f) + best = {} + for r in results: + m = r["mode"] + if m not in best or r["latency_ms"] < best[m]["latency_ms"]: + best[m] = r + parts = [] + for m in sorted(DISPLAY): + if m in best: + parts.append(f"{DISPLAY[m]}({PARAM_NAME.get(m,'p')}={best[m].get('param',0.0)})") + lines.append(f"[{axis_name}={cell['axis_value']}] " + " ".join(parts)) + +txt = "\n".join(lines) + "\n" +print(txt) +with open(out_path, "w") as f: + f.write(txt) +PY + +echo "" +echo "============================================================" +echo "Model ablation complete." 
+echo " Sweep dir: ${SWEEP_DIR}" +echo " Sweep index: ${SWEEP_INDEX}" +echo " Selected hparams: ${SELECTED_TXT}" +echo "Run analyze with:" +echo " python examples/analyze_ablation_remap.py --sweep-dir ${SWEEP_DIR}" +echo "============================================================" diff --git a/examples/ablation_remap_function_topk_benchmark.sh b/examples/ablation_remap_function_topk_benchmark.sh new file mode 100644 index 00000000..7952bd20 --- /dev/null +++ b/examples/ablation_remap_function_topk_benchmark.sh @@ -0,0 +1,277 @@ +#!/usr/bin/env bash +# ============================================================ +# Ablation: Remap function vs. topk-kernel benchmark workload +# +# Sweeps the kernel-bench INPUT distribution (the workload that +# stresses the TopK kernel) and, per cell, runs autotune + +# remap-bench so the per-mode hparam is freshly chosen for that +# distribution. This is the robustness ablation: do the +# autotuned remap functions still beat the unmapped baseline +# when the input score distribution shifts? 
+# +# Distributions available in bench_topk.py: +# normal — N(0,1) per-page scores +# lognormal — heavy-tailed positive scores +# uniform — U[0,1) +# bucket_uniform — per-bucket uniform (worst case for radix) +# real — sampled from raw_histograms_${MODEL_TAG}.npy +# +# Mapping modes under test (matches the screenshot): +# 0 none, 3 power, 6 asinh, 7 log1p, 9 erf, 10 tanh, +# 11 subtract, 13 exp_stretch, 15 shift_pow2, 16 shift_pow3, +# 17 linear_steep +# ============================================================ +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +# ── Defaults ────────────────────────────────────────────────── +GPU_ID=4 +MODEL_NAME="Qwen/Qwen3-1.7B" +TOPK_VAL=2048 +BLOCK_SIZE=1 +SEQ_LEN=32768 +MEM=0.7 +MAX_TOTAL_TOKENS=64768 +MIN_FREE_DISK_GB=22 +ALGO="block_sparse_attention" +BATCH_SIZE=4 +NUM_KV_HEADS=8 +MAPPING_MODES="0 3 6 7 9 10 11 13 15 16 17" +MAPPING_HPARAM=0.5 +REPEAT=100 +WARMUP=20 +# Distributions to sweep (one per cell). "real" requires raw_histograms.npy.
+DISTRIBUTION_LIST="normal lognormal uniform bucket_uniform real" +REAL_HISTOGRAMS="" + +# ── Parse arguments ─────────────────────────────────────────── +while [[ $# -gt 0 ]]; do + case "$1" in + --gpu) GPU_ID="$2"; shift 2 ;; + --model-name) MODEL_NAME="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; + --mem) MEM="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; + --min-free-disk-gb) MIN_FREE_DISK_GB="$2"; shift 2 ;; + --algo) ALGO="$2"; shift 2 ;; + --batch-size) BATCH_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --modes) MAPPING_MODES="$2"; shift 2 ;; + --mapping-hparam) MAPPING_HPARAM="$2"; shift 2 ;; + --repeat) REPEAT="$2"; shift 2 ;; + --warmup) WARMUP="$2"; shift 2 ;; + --dist-list|--distributions) DISTRIBUTION_LIST="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" +export SGL_ENABLE_JIT_DEEPGEMM="${SGL_ENABLE_JIT_DEEPGEMM:-true}" +if [ -z "${DG_JIT_NVCC_COMPILER:-}" ]; then + if [ -x /usr/local/cuda/bin/nvcc ]; then + export CUDA_HOME="${CUDA_HOME:-/usr/local/cuda}" + export PATH="${CUDA_HOME}/bin:${PATH}" + export DG_JIT_NVCC_COMPILER="${CUDA_HOME}/bin/nvcc" + elif command -v nvcc >/dev/null 2>&1; then + export DG_JIT_NVCC_COMPILER="$(command -v nvcc)" + fi +fi + +MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * BLOCK_SIZE )) +if [ "${SEQ_LEN}" -lt "${MIN_SEQ_LEN}" ]; then + SEQ_LEN=${MIN_SEQ_LEN} +fi + +RESULTS_DIR="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" +SWEEP_DIR="${RESULTS_DIR}/ablation_remap_topk_benchmark_${MODEL_SLUG}_${TIMESTAMP}" +mkdir -p "${SWEEP_DIR}" + 
+CALIBRATION_BASE="/var/tmp/zhuominc/vortex_torch/calibration" +MODEL_TAG="$(echo "${MODEL_NAME##*/}" | sed 's/^Q/q/')" +DEFAULT_REAL_HIST="${CALIBRATION_BASE}/raw_histograms_${MODEL_TAG}.npy" +mkdir -p "${CALIBRATION_BASE}" + +if [ -z "${REAL_HISTOGRAMS}" ] && [ -f "${DEFAULT_REAL_HIST}" ]; then + REAL_HISTOGRAMS="${DEFAULT_REAL_HIST}" +fi + +# Need raw_histograms only if "real" is in the distribution list. +NEED_REAL=0 +for d in ${DISTRIBUTION_LIST}; do + if [ "$d" = "real" ]; then NEED_REAL=1; fi +done + +if [ "${NEED_REAL}" -eq 1 ] && [ -z "${REAL_HISTOGRAMS}" ]; then + echo ">>> Step 0: Calibrating ${MODEL_NAME} (needed for distribution=real)" + STAGING_DIR="${CALIBRATION_BASE}/staging_${MODEL_TAG}_${TIMESTAMP}" + mkdir -p "${STAGING_DIR}" + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --mem "${MEM}" \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ + --min-free-disk-gb "${MIN_FREE_DISK_GB}" \ + --vortex-module-name "${ALGO}" \ + --output-dir "${STAGING_DIR}" \ + 2>&1 | tee "${SWEEP_DIR}/step0_calibrate.log" + mv -f "${STAGING_DIR}/raw_histograms.npy" "${DEFAULT_REAL_HIST}" + REAL_HISTOGRAMS="${DEFAULT_REAL_HIST}" +fi + +echo "============================================================" +echo "Ablation: remap function vs topk-kernel benchmark workload" +echo " Model: ${MODEL_NAME}" +echo " Block size: ${BLOCK_SIZE}" +echo " TopK: ${TOPK_VAL}" +echo " Seq len: ${SEQ_LEN}" +echo " Distributions: ${DISTRIBUTION_LIST}" +echo " Mapping modes: ${MAPPING_MODES}" +echo " GPU: ${GPU_ID}" +echo " Sweep dir: ${SWEEP_DIR}" +echo "============================================================" + +# ── Sweep ────────────────────────────────────────────────────── +SWEEP_INDEX="${SWEEP_DIR}/sweep_index.json" +{ + echo "{" + echo " \"axis_name\": \"distribution\"," + echo " \"axis_type\": \"kernel\"," + echo " \"model_name\": \"${MODEL_NAME}\"," + echo " \"topk_val\": ${TOPK_VAL}," + echo 
" \"block_size\": ${BLOCK_SIZE}," + echo " \"seq_len\": ${SEQ_LEN}," + echo " \"mapping_modes\": [${MAPPING_MODES// /, }]," + echo " \"cells\": [" +} > "${SWEEP_INDEX}" + +FIRST_CELL=1 +for DIST in ${DISTRIBUTION_LIST}; do + CELL_DIR="${SWEEP_DIR}/dist_${DIST}" + mkdir -p "${CELL_DIR}" + AUTOTUNE_JSON="${CELL_DIR}/autotune_results.json" + REMAP_JSON="${CELL_DIR}/remap_bench.json" + + echo "" + echo "============================================================" + echo ">>> Cell: distribution=${DIST}" + echo "============================================================" + + AUTOTUNE_DIST_ARGS=() + BENCH_DIST_ARGS=() + if [ "${DIST}" = "real" ]; then + AUTOTUNE_DIST_ARGS=(--real-histograms "${REAL_HISTOGRAMS}") + BENCH_DIST_ARGS=(--real-histograms "${REAL_HISTOGRAMS}" --distributions real) + else + AUTOTUNE_DIST_ARGS=(--distributions "${DIST}") + BENCH_DIST_ARGS=(--distributions "${DIST}") + fi + + echo ">>> Autotuning hparams on dist=${DIST}" + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --batch-size "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-len "${SEQ_LEN}" \ + --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + "${AUTOTUNE_DIST_ARGS[@]}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --collect-stats \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${CELL_DIR}/step2_autotune.log" + + echo ">>> Remap bench on dist=${DIST}" + PYTHONPATH="${SCRIPT_DIR}/.." 
python "${BENCH_DIR}/bench_topk.py" \ + --remap-bench \ + --per-head-bench \ + --batch-sizes "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-lens "${SEQ_LEN}" \ + --topk-vals "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + "${BENCH_DIST_ARGS[@]}" \ + --mapping-modes ${MAPPING_MODES} \ + --mapping-hparam "${MAPPING_HPARAM}" \ + --autotune-json "${AUTOTUNE_JSON}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --output-json "${REMAP_JSON}" \ + 2>&1 | tee "${CELL_DIR}/step3_remap_bench.log" + + if [ "${FIRST_CELL}" -eq 1 ]; then + FIRST_CELL=0 + else + echo " ," >> "${SWEEP_INDEX}" + fi + cat >> "${SWEEP_INDEX}" <> "${SWEEP_INDEX}" +echo "}" >> "${SWEEP_INDEX}" + +# ── Per-cell screenshot-style hparam summary ────────────────── +SELECTED_TXT="${SWEEP_DIR}/selected_hparams.txt" +PYTHONPATH="${SCRIPT_DIR}/.." python3 - "${SWEEP_INDEX}" "${SELECTED_TXT}" "distribution" <<'PY' +import json, sys, os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "benchmarks")) +try: + from autotune_topk_mapping import PARAM_NAME +except Exception: + PARAM_NAME = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha", + 11: "pivot", 13: "alpha", 15: "pivot", 16: "pivot", 17: "k"} +DISPLAY = {3: "Power", 6: "Asinh", 7: "Log1p", 9: "Erf", 10: "Tanh", + 11: "Subtract", 13: "ExpStretch", 15: "ShiftPow2", + 16: "ShiftPow3", 17: "LinearSteep"} + +idx_path, out_path, axis_name = sys.argv[1], sys.argv[2], sys.argv[3] +with open(idx_path) as f: + idx = json.load(f) + +lines = [f"== Selected mapping functions (autotuned, {axis_name} sweep) =="] +for cell in idx["cells"]: + with open(cell["autotune_json"]) as f: + results = json.load(f) + best = {} + for r in results: + m = r["mode"] + if m not in best or r["latency_ms"] < best[m]["latency_ms"]: + best[m] = r + parts = [] + for m in sorted(DISPLAY): + if m in best: + parts.append(f"{DISPLAY[m]}({PARAM_NAME.get(m,'p')}={best[m].get('param',0.0)})") + lines.append(f"[{axis_name}={cell['axis_value']}] " + " 
".join(parts)) + +txt = "\n".join(lines) + "\n" +print(txt) +with open(out_path, "w") as f: + f.write(txt) +PY + +echo "" +echo "============================================================" +echo "topk_benchmark (kernel workload) ablation complete." +echo " Sweep dir: ${SWEEP_DIR}" +echo " Sweep index: ${SWEEP_INDEX}" +echo " Selected hparams: ${SELECTED_TXT}" +echo "Run analyze with:" +echo " python examples/analyze_ablation_remap.py --sweep-dir ${SWEEP_DIR}" +echo "============================================================" diff --git a/examples/ablation_remap_function_topk_val.sh b/examples/ablation_remap_function_topk_val.sh new file mode 100644 index 00000000..4e60440a --- /dev/null +++ b/examples/ablation_remap_function_topk_val.sh @@ -0,0 +1,255 @@ +#!/usr/bin/env bash +# ============================================================ +# Ablation: Remap function vs. topk_val +# +# Sweeps TOPK_VAL and, for every cell, runs +# autotune -> remap-bench +# so the per-mode hyperparameter is freshly chosen by autotune +# for that topk_val (NOT hardcoded). Calibration runs once for +# the model. 
+# +# Mapping modes under test (matches the screenshot): +# 0 none, 3 power, 6 asinh, 7 log1p, 9 erf, 10 tanh, +# 11 subtract, 13 exp_stretch, 15 shift_pow2, 16 shift_pow3, +# 17 linear_steep +# ============================================================ +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +# ── Defaults ────────────────────────────────────────────────── +GPU_ID=4 +MODEL_NAME="Qwen/Qwen3-1.7B" +BLOCK_SIZE=1 +MEM=0.7 +MAX_TOTAL_TOKENS=64768 +MIN_FREE_DISK_GB=22 +ALGO="block_sparse_attention" +BATCH_SIZE=4 +NUM_KV_HEADS=8 +DISTRIBUTIONS="normal bucket_uniform" +MAPPING_MODES="0 3 6 7 9 10 11 13 15 16 17" +MAPPING_HPARAM=0.5 +REPEAT=100 +WARMUP=20 +TOPK_VALS="512 1024 2048 4096" +REAL_HISTOGRAMS="" + +# ── Parse arguments ─────────────────────────────────────────── +while [[ $# -gt 0 ]]; do + case "$1" in + --gpu) GPU_ID="$2"; shift 2 ;; + --model-name) MODEL_NAME="$2"; shift 2 ;; + --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;; + --mem) MEM="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; + --min-free-disk-gb) MIN_FREE_DISK_GB="$2"; shift 2 ;; + --algo) ALGO="$2"; shift 2 ;; + --batch-size) BATCH_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --distributions) DISTRIBUTIONS="$2"; shift 2 ;; + --modes) MAPPING_MODES="$2"; shift 2 ;; + --mapping-hparam) MAPPING_HPARAM="$2"; shift 2 ;; + --repeat) REPEAT="$2"; shift 2 ;; + --warmup) WARMUP="$2"; shift 2 ;; + --topk-vals) TOPK_VALS="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" +export SGL_ENABLE_JIT_DEEPGEMM="${SGL_ENABLE_JIT_DEEPGEMM:-true}" +if [ -z "${DG_JIT_NVCC_COMPILER:-}" ]; then + if [ -x /usr/local/cuda/bin/nvcc ]; then + export 
CUDA_HOME="${CUDA_HOME:-/usr/local/cuda}" + export PATH="${CUDA_HOME}/bin:${PATH}" + export DG_JIT_NVCC_COMPILER="${CUDA_HOME}/bin/nvcc" + elif command -v nvcc >/dev/null 2>&1; then + export DG_JIT_NVCC_COMPILER="$(command -v nvcc)" + fi +fi + +RESULTS_DIR="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" +SWEEP_DIR="${RESULTS_DIR}/ablation_remap_topk_val_${MODEL_SLUG}_${TIMESTAMP}" +mkdir -p "${SWEEP_DIR}" + +CALIBRATION_BASE="/var/tmp/zhuominc/vortex_torch/calibration" +MODEL_TAG="$(echo "${MODEL_NAME##*/}" | sed 's/^Q/q/')" +DEFAULT_REAL_HIST="${CALIBRATION_BASE}/raw_histograms_${MODEL_TAG}.npy" +mkdir -p "${CALIBRATION_BASE}" + +if [ -z "${REAL_HISTOGRAMS}" ] && [ -f "${DEFAULT_REAL_HIST}" ]; then + REAL_HISTOGRAMS="${DEFAULT_REAL_HIST}" +fi + +echo "============================================================" +echo "Ablation: remap function vs topk_val" +echo " Model: ${MODEL_NAME}" +echo " Block size: ${BLOCK_SIZE}" +echo " Topk vals: ${TOPK_VALS}" +echo " Mapping modes: ${MAPPING_MODES}" +echo " GPU: ${GPU_ID}" +echo " Sweep dir: ${SWEEP_DIR}" +echo "============================================================" + +# ── Step 0: Calibrate once for this model ────────────────────── +if [ -n "${REAL_HISTOGRAMS}" ]; then + echo ">>> Step 0: SKIPPED calibration (using ${REAL_HISTOGRAMS})" + REAL_HIST_PATH="${REAL_HISTOGRAMS}" +else + echo ">>> Step 0: Calibrating ${MODEL_NAME}" + STAGING_DIR="${CALIBRATION_BASE}/staging_${MODEL_TAG}_${TIMESTAMP}" + mkdir -p "${STAGING_DIR}" + CAL_TOPK_VAL=$(echo "${TOPK_VALS}" | tr ' ' '\n' | sort -n | tail -n 1) + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${CAL_TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --mem "${MEM}" \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ + --min-free-disk-gb "${MIN_FREE_DISK_GB}" \ + --vortex-module-name "${ALGO}" \ + --output-dir "${STAGING_DIR}" \ + 2>&1 | tee 
"${SWEEP_DIR}/step0_calibrate.log" + mv -f "${STAGING_DIR}/raw_histograms.npy" "${DEFAULT_REAL_HIST}" + REAL_HIST_PATH="${DEFAULT_REAL_HIST}" +fi + +# ── Sweep ────────────────────────────────────────────────────── +SWEEP_INDEX="${SWEEP_DIR}/sweep_index.json" +{ + echo "{" + echo " \"axis_name\": \"topk_val\"," + echo " \"axis_type\": \"kernel\"," + echo " \"model_name\": \"${MODEL_NAME}\"," + echo " \"block_size\": ${BLOCK_SIZE}," + echo " \"mapping_modes\": [${MAPPING_MODES// /, }]," + echo " \"cells\": [" +} > "${SWEEP_INDEX}" + +FIRST_CELL=1 +for TOPK_VAL in ${TOPK_VALS}; do + MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * BLOCK_SIZE )) + SEQ_LEN=${MIN_SEQ_LEN} + if [ "${SEQ_LEN}" -lt 8192 ]; then SEQ_LEN=8192; fi + if [ "${SEQ_LEN}" -lt $(( TOPK_VAL * BLOCK_SIZE * 4 )) ]; then + SEQ_LEN=$(( TOPK_VAL * BLOCK_SIZE * 4 )) + fi + + CELL_DIR="${SWEEP_DIR}/topk${TOPK_VAL}" + mkdir -p "${CELL_DIR}" + AUTOTUNE_JSON="${CELL_DIR}/autotune_results.json" + REMAP_JSON="${CELL_DIR}/remap_bench.json" + + echo "" + echo "============================================================" + echo ">>> Cell: topk_val=${TOPK_VAL} seq_len=${SEQ_LEN}" + echo "============================================================" + + echo ">>> Autotuning hparams for topk_val=${TOPK_VAL}" + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --batch-size "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-len "${SEQ_LEN}" \ + --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --real-histograms "${REAL_HIST_PATH}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --collect-stats \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${CELL_DIR}/step2_autotune.log" + + echo ">>> Remap bench for topk_val=${TOPK_VAL}" + PYTHONPATH="${SCRIPT_DIR}/.." 
python "${BENCH_DIR}/bench_topk.py" \ + --remap-bench \ + --per-head-bench \ + --batch-sizes "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-lens "${SEQ_LEN}" \ + --topk-vals "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --distributions ${DISTRIBUTIONS} \ + --mapping-modes ${MAPPING_MODES} \ + --mapping-hparam "${MAPPING_HPARAM}" \ + --autotune-json "${AUTOTUNE_JSON}" \ + --real-histograms "${REAL_HIST_PATH}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --output-json "${REMAP_JSON}" \ + 2>&1 | tee "${CELL_DIR}/step3_remap_bench.log" + + if [ "${FIRST_CELL}" -eq 1 ]; then + FIRST_CELL=0 + else + echo " ," >> "${SWEEP_INDEX}" + fi + cat >> "${SWEEP_INDEX}" <> "${SWEEP_INDEX}" +echo "}" >> "${SWEEP_INDEX}" + +# ── Per-cell screenshot-style hparam summary ────────────────── +SELECTED_TXT="${SWEEP_DIR}/selected_hparams.txt" +PYTHONPATH="${SCRIPT_DIR}/.." python3 - "${SWEEP_INDEX}" "${SELECTED_TXT}" "topk_val" <<'PY' +import json, sys, os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "benchmarks")) +try: + from autotune_topk_mapping import PARAM_NAME +except Exception: + PARAM_NAME = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha", + 11: "pivot", 13: "alpha", 15: "pivot", 16: "pivot", 17: "k"} +DISPLAY = {3: "Power", 6: "Asinh", 7: "Log1p", 9: "Erf", 10: "Tanh", + 11: "Subtract", 13: "ExpStretch", 15: "ShiftPow2", + 16: "ShiftPow3", 17: "LinearSteep"} + +idx_path, out_path, axis_name = sys.argv[1], sys.argv[2], sys.argv[3] +with open(idx_path) as f: + idx = json.load(f) + +lines = [f"== Selected mapping functions (autotuned, {axis_name} sweep) =="] +for cell in idx["cells"]: + with open(cell["autotune_json"]) as f: + results = json.load(f) + best = {} + for r in results: + m = r["mode"] + if m not in best or r["latency_ms"] < best[m]["latency_ms"]: + best[m] = r + parts = [] + for m in sorted(DISPLAY): + if m in best: + parts.append(f"{DISPLAY[m]}({PARAM_NAME.get(m,'p')}={best[m].get('param',0.0)})") + 
lines.append(f"[{axis_name}={cell['axis_value']}] " + " ".join(parts)) + +txt = "\n".join(lines) + "\n" +print(txt) +with open(out_path, "w") as f: + f.write(txt) +PY + +echo "" +echo "============================================================" +echo "topk_val ablation complete." +echo " Sweep dir: ${SWEEP_DIR}" +echo " Sweep index: ${SWEEP_INDEX}" +echo " Selected hparams: ${SELECTED_TXT}" +echo "Run analyze with:" +echo " python examples/analyze_ablation_remap.py --sweep-dir ${SWEEP_DIR}" +echo "============================================================" diff --git a/examples/analyze_ablation_remap.py b/examples/analyze_ablation_remap.py new file mode 100644 index 00000000..18b0a0b5 --- /dev/null +++ b/examples/analyze_ablation_remap.py @@ -0,0 +1,416 @@ +#!/usr/bin/env python3 +""" +Analyze remap-function ablation sweeps. + +Reads one or more sweep directories produced by + ablation_remap_function_block_size.sh + ablation_remap_function_topk_val.sh + ablation_remap_function_model.sh + ablation_remap_function_topk_benchmark.sh + +and emits, for each sweep: + - tidy CSV of every (axis_value, mapping_mode, distribution, head) row + - wide CSV tables: latency, speedup vs baseline, chosen hparam + - LaTeX version of the chosen-hparam table + - markdown summary including the screenshot-style "Selected mapping + functions" line per axis value + - matplotlib PDF plots: latency vs axis, speedup vs axis, threshold + bin size vs axis (one curve per mapping mode) + +Usage: + python examples/analyze_ablation_remap.py \ + --sweep-dir results/ablation_remap_block_size_ \ + [--sweep-dir results/ablation_remap_model_ ...] 
\ + --output-dir results/ablation_remap_analysis_ +""" + +from __future__ import annotations + +import argparse +import json +import math +import os +import sys +from pathlib import Path +from typing import Any, Dict, List, Optional + +import numpy as np +import pandas as pd +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt + +# Pull mode metadata from the autotune script so we don't duplicate it. +SCRIPT_DIR = Path(__file__).resolve().parent +BENCH_DIR = SCRIPT_DIR.parent / "benchmarks" +sys.path.insert(0, str(BENCH_DIR)) +try: + from autotune_topk_mapping import MODE_NAMES, PARAM_NAME # type: ignore +except Exception: + MODE_NAMES = {0: "none", 3: "power", 4: "log", 6: "asinh", 7: "log1p", + 8: "trunc8", 9: "erf", 10: "tanh", 11: "subtract", + 13: "exp_stretch", 15: "shift_pow2", 16: "shift_pow3", + 17: "linear_steep"} + PARAM_NAME = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha", + 11: "pivot", 13: "alpha", 15: "pivot", 16: "pivot", 17: "k"} + +DISPLAY_NAME = { + 0: "None", 3: "Power", 6: "Asinh", 7: "Log1p", 9: "Erf", + 10: "Tanh", 11: "Subtract", 13: "ExpStretch", + 15: "ShiftPow2", 16: "ShiftPow3", 17: "LinearSteep", +} + + +# ---------- Loading ---------- + +def _load_json(path: str) -> Any: + with open(path) as f: + return json.load(f) + + +def _best_per_mode_from_autotune(autotune_results: List[dict]) -> Dict[int, dict]: + best: Dict[int, dict] = {} + for r in autotune_results: + m = int(r["mode"]) + if m not in best or r["latency_ms"] < best[m]["latency_ms"]: + best[m] = r + return best + + +def _flatten_remap_bench(remap_results: List[dict]) -> pd.DataFrame: + """Flatten bench_topk.py --remap-bench output into one row per + (cfg, mode_row). 
Drops per-head sub-rows; keeps head='all' so each + cell contributes a single point per (mapping_mode, distribution).""" + rows = [] + for cfg in remap_results: + if cfg.get("head", "all") != "all": + continue + baseline = cfg.get("baseline_ms") + for mr in cfg.get("modes", []): + mode = int(mr["mode"]) + rows.append({ + "distribution": cfg.get("distribution"), + "batch_size": cfg.get("batch_size"), + "num_kv_heads": cfg.get("num_kv_heads"), + "seq_len": cfg.get("seq_len"), + "topk_val": cfg.get("topk_val"), + "pages_per_seg": cfg.get("pages_per_seg"), + "mode": mode, + "mode_name": mr.get("mode_name", MODE_NAMES.get(mode, str(mode))), + "param_value": mr.get("power"), + "fused_ms": mr.get("fused_ms"), + "remap_ms": mr.get("remap_ms"), + "topk_after_remap_ms": mr.get("topk_after_remap_ms"), + "split_total_ms": mr.get("split_total_ms"), + "baseline_ms": baseline, + "threshold_bin_size_mean": mr.get("threshold_bin_size_mean"), + "threshold_bin_size_max": mr.get("threshold_bin_size_max"), + "refine_rounds_mean": mr.get("refine_rounds_mean"), + }) + return pd.DataFrame(rows) + + +def load_sweep(sweep_dir: Path) -> Dict[str, Any]: + idx_path = sweep_dir / "sweep_index.json" + if not idx_path.exists(): + raise FileNotFoundError(f"missing sweep_index.json in {sweep_dir}") + idx = _load_json(idx_path) + axis_name = idx["axis_name"] + + rows: List[pd.DataFrame] = [] + chosen_hparams: List[dict] = [] + for cell in idx["cells"]: + axis_value = cell["axis_value"] + autotune_results = _load_json(cell["autotune_json"]) + best = _best_per_mode_from_autotune(autotune_results) + for mode, r in best.items(): + chosen_hparams.append({ + "axis_value": axis_value, + "mode": int(mode), + "mode_name": r.get("mode_name", MODE_NAMES.get(int(mode), str(mode))), + "param_name": r.get("param_name") or PARAM_NAME.get(int(mode), "p"), + "param_value": r.get("param"), + "autotune_latency_ms": r.get("latency_ms"), + }) + + remap_results = _load_json(cell["remap_bench_json"]) + df = 
_flatten_remap_bench(remap_results) + df.insert(0, "axis_value", axis_value) + df.insert(0, "axis_name", axis_name) + rows.append(df) + + tidy = pd.concat(rows, ignore_index=True) if rows else pd.DataFrame() + chosen = pd.DataFrame(chosen_hparams) + return { + "axis_name": axis_name, + "axis_type": idx.get("axis_type", "kernel"), + "index": idx, + "tidy": tidy, + "chosen": chosen, + } + + +# ---------- Tables ---------- + +def _wide_latency(tidy: pd.DataFrame, axis_name: str, distribution: Optional[str] = None) -> pd.DataFrame: + df = tidy.copy() + if distribution is not None and "distribution" in df.columns: + df = df[df["distribution"] == distribution] + # Best fused latency per (axis_value, mode) — collapse over distribution + # if no filter was applied. + g = df.groupby(["axis_value", "mode", "mode_name"], dropna=False)["fused_ms"].min().reset_index() + wide = g.pivot(index="axis_value", columns="mode", values="fused_ms") + # Also pivot mode_name → label for column header. + return wide.rename(columns=lambda m: f"{m}:{MODE_NAMES.get(int(m), '?')}") + + +def _wide_baseline(tidy: pd.DataFrame, distribution: Optional[str] = None) -> pd.Series: + df = tidy.copy() + if distribution is not None and "distribution" in df.columns: + df = df[df["distribution"] == distribution] + return df.groupby("axis_value")["baseline_ms"].min() + + +def _wide_speedup(tidy: pd.DataFrame, axis_name: str, distribution: Optional[str] = None) -> pd.DataFrame: + lat = _wide_latency(tidy, axis_name, distribution=distribution) + base = _wide_baseline(tidy, distribution=distribution) + return lat.rdiv(base, axis=0) # baseline / fused + + +def _wide_chosen_hparam(chosen: pd.DataFrame) -> pd.DataFrame: + if chosen.empty: + return pd.DataFrame() + chosen = chosen.copy() + chosen["label"] = chosen.apply( + lambda r: f"{DISPLAY_NAME.get(int(r['mode']), r['mode_name'])}({r['param_name']}={r['param_value']})", + axis=1, + ) + wide = chosen.pivot(index="axis_value", columns="mode", values="label") + 
return wide.rename(columns=lambda m: f"{m}:{MODE_NAMES.get(int(m), '?')}") + + +def _df_to_latex(df: pd.DataFrame, caption: str, label: str) -> str: + if df.empty: + return f"% empty table for {label}\n" + try: + return df.to_latex( + float_format=lambda v: "" if pd.isna(v) else f"{v:.4f}", + na_rep="", + caption=caption, + label=label, + ) + except Exception: + return df.to_string() + + +# ---------- Plots ---------- + +def _axis_x(values: List[Any]) -> List[float]: + """Convert axis values (which may be strings or ints) to numeric x + coordinates. Strings are mapped to 0..N-1; numerics keep their value.""" + out = [] + for i, v in enumerate(values): + if isinstance(v, (int, float)): + out.append(float(v)) + else: + out.append(float(i)) + return out + + +def _plot_metric_vs_axis(tidy: pd.DataFrame, axis_name: str, metric: str, + out_path: Path, ylabel: str, title: str, + baseline_series: Optional[pd.Series] = None, + logy: bool = False) -> None: + if tidy.empty: + return + g = tidy.groupby(["axis_value", "mode", "mode_name"], dropna=False)[metric].min().reset_index() + axis_values = sorted(g["axis_value"].unique(), + key=lambda v: (not isinstance(v, (int, float)), v)) + x = _axis_x(axis_values) + + fig, ax = plt.subplots(figsize=(7, 4.5)) + cmap = plt.cm.get_cmap("tab10") + for i, mode in enumerate(sorted(g["mode"].unique())): + sub = g[g["mode"] == mode].set_index("axis_value").reindex(axis_values) + ax.plot(x, sub[metric].values, + marker="o", color=cmap(i % 10), + label=f"{mode}:{MODE_NAMES.get(int(mode), '?')}") + + if baseline_series is not None and not baseline_series.empty: + bx = baseline_series.reindex(axis_values).values + ax.plot(x, bx, "k--", linewidth=2, label="baseline (unmapped)") + + ax.set_xlabel(axis_name) + ax.set_ylabel(ylabel) + ax.set_title(title) + if logy: + ax.set_yscale("log") + if all(isinstance(v, (int, float)) for v in axis_values): + ax.set_xticks(x) + ax.set_xticklabels([str(v) for v in axis_values]) + else: + ax.set_xticks(x) + 
ax.set_xticklabels([str(v) for v in axis_values], rotation=20, ha="right") + ax.grid(True, alpha=0.3) + ax.legend(fontsize=7, ncol=2, loc="best") + fig.tight_layout() + fig.savefig(out_path) + plt.close(fig) + + +# ---------- Per-sweep emitters ---------- + +def emit_sweep(sweep: Dict[str, Any], out_root: Path) -> None: + axis_name = sweep["axis_name"] + out_dir = out_root / axis_name + out_dir.mkdir(parents=True, exist_ok=True) + + tidy: pd.DataFrame = sweep["tidy"] + chosen: pd.DataFrame = sweep["chosen"] + + if tidy.empty: + print(f"[{axis_name}] no data, skipping") + return + + tidy.to_csv(out_dir / "tidy.csv", index=False) + chosen.to_csv(out_dir / "chosen_hparams_long.csv", index=False) + + distributions = sorted([d for d in tidy["distribution"].dropna().unique()]) + + # Per-distribution wide tables + plots. + for dist in distributions + [None]: + suffix = f"_{dist}" if dist else "_all" + lat_wide = _wide_latency(tidy, axis_name, distribution=dist) + spd_wide = _wide_speedup(tidy, axis_name, distribution=dist) + base = _wide_baseline(tidy, distribution=dist) + + lat_wide.to_csv(out_dir / f"table_latency_ms{suffix}.csv") + spd_wide.to_csv(out_dir / f"table_speedup_vs_baseline{suffix}.csv") + base.to_frame("baseline_ms").to_csv(out_dir / f"table_baseline_ms{suffix}.csv") + + with open(out_dir / f"table_latency_ms{suffix}.tex", "w") as f: + f.write(_df_to_latex(lat_wide, + caption=f"Best fused-kernel latency (ms) on {axis_name} sweep ({dist or 'all dists'})", + label=f"tab:lat-{axis_name}{suffix}")) + with open(out_dir / f"table_speedup_vs_baseline{suffix}.tex", "w") as f: + f.write(_df_to_latex(spd_wide, + caption=f"Speedup over unmapped baseline on {axis_name} sweep ({dist or 'all dists'})", + label=f"tab:spd-{axis_name}{suffix}")) + + _plot_metric_vs_axis( + tidy if dist is None else tidy[tidy["distribution"] == dist], + axis_name, "fused_ms", + out_dir / f"plot_latency_vs_{axis_name}{suffix}.pdf", + ylabel="fused TopK kernel latency (ms)", + title=f"TopK 
kernel latency vs {axis_name} ({dist or 'all dists'})", + baseline_series=base, + ) + # Speedup plot. + spd_long = tidy.copy() + if dist: + spd_long = spd_long[spd_long["distribution"] == dist] + spd_long = spd_long.assign( + speedup=spd_long["baseline_ms"] / spd_long["fused_ms"] + ) + _plot_metric_vs_axis( + spd_long, axis_name, "speedup", + out_dir / f"plot_speedup_vs_{axis_name}{suffix}.pdf", + ylabel="speedup over unmapped baseline", + title=f"Speedup vs {axis_name} ({dist or 'all dists'})", + ) + # Threshold bin size diagnostic. + _plot_metric_vs_axis( + tidy if dist is None else tidy[tidy["distribution"] == dist], + axis_name, "threshold_bin_size_mean", + out_dir / f"plot_threshold_bin_size_vs_{axis_name}{suffix}.pdf", + ylabel="mean threshold-bin size (entries)", + title=f"Stage-1 threshold bin size vs {axis_name} ({dist or 'all dists'})", + ) + + # Chosen-hparam wide table (axis-independent of distribution: autotune + # picks one hparam per mode per axis cell). + chosen_wide = _wide_chosen_hparam(chosen) + chosen_wide.to_csv(out_dir / "table_chosen_hparams.csv") + with open(out_dir / "table_chosen_hparams.tex", "w") as f: + f.write(_df_to_latex(chosen_wide, + caption=f"Autotuned remap-function hyperparameters per {axis_name} cell", + label=f"tab:hparam-{axis_name}")) + + # Markdown summary. 
+ md_lines: List[str] = [] + md_lines.append(f"# Ablation: remap function vs `{axis_name}`\n") + md_lines.append(f"Source: `{sweep['index'].get('cells', [{}])[0].get('cell_dir', '')}/...`\n") + + md_lines.append("\n## Selected mapping functions (autotuned)\n") + md_lines.append("```") + for v in chosen_wide.index.tolist(): + parts = [] + for col in chosen_wide.columns: + label = chosen_wide.loc[v, col] + if isinstance(label, str) and label: + parts.append(label) + md_lines.append(f"[{axis_name}={v}] " + " ".join(parts)) + md_lines.append("```\n") + + md_lines.append("\n## Latency (ms) — best fused, all distributions\n") + md_lines.append(_wide_latency(tidy, axis_name).to_markdown()) + md_lines.append("\n\n## Speedup over unmapped baseline\n") + md_lines.append(_wide_speedup(tidy, axis_name).to_markdown()) + md_lines.append("\n\n## Chosen hyperparameters\n") + md_lines.append(chosen_wide.to_markdown()) + md_lines.append("\n\n## Plots\n") + for p in sorted(out_dir.glob("plot_*.pdf")): + md_lines.append(f"- `{p.name}`") + + with open(out_dir / "summary.md", "w") as f: + f.write("\n".join(md_lines) + "\n") + + print(f"[{axis_name}] wrote artifacts to {out_dir}") + + +# ---------- Top-level ---------- + +def main() -> None: + ap = argparse.ArgumentParser(description="Aggregate ablation_remap_function_*.sh sweep outputs.") + ap.add_argument("--sweep-dir", action="append", required=True, + help="A sweep directory containing sweep_index.json. 
Repeat for multiple sweeps.") + ap.add_argument("--output-dir", type=str, required=True, + help="Where to write tables, plots, and summary.") + args = ap.parse_args() + + out_root = Path(args.output_dir) + out_root.mkdir(parents=True, exist_ok=True) + + sweeps: List[Dict[str, Any]] = [] + for sd in args.sweep_dir: + sweep = load_sweep(Path(sd)) + emit_sweep(sweep, out_root) + sweeps.append(sweep) + + # Cross-axis recommended hparams: for every mode, pick the param value + # that was selected most often across all axis cells of all sweeps. + all_chosen = pd.concat([s["chosen"] for s in sweeps if not s["chosen"].empty], + ignore_index=True) if sweeps else pd.DataFrame() + rec_lines: List[str] = [] + if not all_chosen.empty: + rec = (all_chosen.groupby(["mode", "mode_name", "param_name"])["param_value"] + .agg(lambda s: s.value_counts().idxmax()) + .reset_index().rename(columns={"param_value": "recommended"})) + rec.to_csv(out_root / "recommended_hparams.csv", index=False) + rec_lines.append("## Cross-axis recommended hparams (mode of selections)\n") + rec_lines.append(rec.to_markdown(index=False)) + + index_lines = ["# Remap-function ablation summary\n"] + for s in sweeps: + axis = s["axis_name"] + index_lines.append(f"- [`{axis}`]({axis}/summary.md)") + if rec_lines: + index_lines.append("") + index_lines.extend(rec_lines) + with open(out_root / "index.md", "w") as f: + f.write("\n".join(index_lines) + "\n") + print(f"[index] {out_root}/index.md") + + +if __name__ == "__main__": + main() diff --git a/examples/archived/run_distribution_analysis.sh b/examples/archived/run_distribution_analysis.sh new file mode 100755 index 00000000..25150153 --- /dev/null +++ b/examples/archived/run_distribution_analysis.sh @@ -0,0 +1,236 @@ +#!/usr/bin/env bash +# ============================================================ +# Bucket Distribution Profiling Pipeline +# +# Profiles the SGLang TopK kernel's first-pass bucket distribution +# to identify hotspot buckets causing tail 
latency. +# +# Four steps: +# 1. Calibrate — collect real-data histograms +# (skippable via --real-histograms PATH) +# 2. Auto-tune — sweep hyperparameters to find best per-mode power +# 3. Bench — histogram profiling (bucket_uniform + normal) +# noscale kernels use the same autotuned power +# 4. Analyze — comparison plots + bucket count tables +# +# All outputs (JSON, plots, CSV tables, logs) are written to a +# single timestamped folder under examples/results/dist_analysis_*. +# +# Usage: +# bash run_distribution_analysis.sh --gpu 5 +# bash run_distribution_analysis.sh --gpu 5 \ +# --real-histograms /path/to/calibration_dir/raw_histograms.npy +# bash run_distribution_analysis.sh --gpu 5 --block-size 16 +# bash run_distribution_analysis.sh --watchdog-timeout 0 # disable calibrate watchdog (fork) +# bash run_distribution_analysis.sh --max-total-tokens 1048576 # cap KV / VTX buffers during calibrate +# Models (default: 1.7B + 4B). Override with repeated --model-name: +# bash run_distribution_analysis.sh --model-name Qwen/Qwen3-1.7B --model-name Qwen/Qwen3-4B +# ============================================================ + +# Mapping functions: +# 0: None — original fp16 bit-pattern bucketing +# 1: LUT CDF — LUT-based CDF equalization (calibrated) +# 2: Quantile — piecewise-linear quantile mapping (calibrated) +# 3: Power — y = sign(x) * |x|^p +# 4: Log — y = sign(x) * log(|x| + 1) +# 5: Index Cache — reuse previous layer's indices +# 6: Asinh — y = asinh(beta * x) +# 7: Log1p — y = sign(x) * log1p(alpha * |x|) +# 8: Trunc8 — bf16 upper-8-bit bucketing +# 9: Erf — y = erf(alpha * x) +# 10: Tanh — y = tanh(alpha * x) +# 11: Subtract — x - pivot (RadiK-style scatter) + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +# ── Defaults ────────────────────────────────────────────────── +GPU_ID=7 +# Models to run (full pipeline per model). Override with one or more --model-name. 
+MODEL_NAMES=( "Qwen/Qwen3-1.7B" "Qwen/Qwen3-4B" ) +MODEL_NAMES_USER_SET=0 +TOPK_VAL=30 +MEM=0.7 +MAX_TOTAL_TOKENS=1048576 +ALGO="block_sparse_attention" +RADIX_BITS=8 +SAMPLE_STRIDE=1 +SEQ_LEN=32768 +# KV page / block size (passed to benchmarks as --page-size) +BLOCK_SIZE=16 +# The path to the raw_histograms.npy file (set to skip calibration) +REAL_HISTOGRAMS="/data/datasets/xinrui/My_Projects/vortex_torch/examples/calibration/raw_histograms.npy" +REAL_HISTOGRAMS="" +HAS_WATCHDOG_TIMEOUT=0 +WATCHDOG_TIMEOUT="" +# ── Parse arguments ─────────────────────────────────────────── +while [[ $# -gt 0 ]]; do + case "$1" in + --model-name) + if [ "${MODEL_NAMES_USER_SET}" -eq 0 ]; then + MODEL_NAMES=() + MODEL_NAMES_USER_SET=1 + fi + MODEL_NAMES+=("$2") + shift 2 + ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --mem) MEM="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; + --gpu) GPU_ID="$2"; shift 2 ;; + --algo) ALGO="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + --radix-bits) RADIX_BITS="$2"; shift 2 ;; + --sample-stride) SAMPLE_STRIDE="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; + --block-size) BLOCK_SIZE="$2"; shift 2 ;; + --watchdog-timeout) HAS_WATCHDOG_TIMEOUT=1; WATCHDOG_TIMEOUT="$2"; shift 2 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" + +if [ "${#MODEL_NAMES[@]}" -eq 0 ]; then + echo "ERROR: No models in MODEL_NAMES; pass at least one --model-name." + exit 1 +fi + +# Validate seq_len: need pages/seg > topk_val (reserved=3 pages + slack) +MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * BLOCK_SIZE )) +if [ "${SEQ_LEN}" -lt "${MIN_SEQ_LEN}" ]; then + echo "ERROR: --seq-len ${SEQ_LEN} too small for --topk-val ${TOPK_VAL}." 
+ echo " Minimum: ${MIN_SEQ_LEN} (pages/seg must exceed topk_val + 3 reserved pages)" + exit 1 +fi + +RESULTS_DIR="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + +echo "============================================================" +echo "Bucket Distribution Profiling Pipeline" +echo " Models (${#MODEL_NAMES[@]}): ${MODEL_NAMES[*]}" +echo " Algorithm: ${ALGO}" +echo " TopK: ${TOPK_VAL}" +echo " Seq len: ${SEQ_LEN} ($(( SEQ_LEN / BLOCK_SIZE )) pages/seg)" +echo " Block size: ${BLOCK_SIZE} (--page-size in benchmarks)" +echo " GPU: ${GPU_ID}" +echo " Radix bits: ${RADIX_BITS} ($(( 1 << RADIX_BITS )) bins)" +echo " Sample stride: ${SAMPLE_STRIDE}" +echo " Max total tokens: ${MAX_TOTAL_TOKENS} (calibration KV / VTX buffer cap)" +if [ "${HAS_WATCHDOG_TIMEOUT}" -eq 1 ]; then + echo " Watchdog (cal): ${WATCHDOG_TIMEOUT}s (0 = off, vortex SGLang fork)" +else + echo " Watchdog (cal): " +fi +echo " Real histograms: ${REAL_HISTOGRAMS:-}" +echo " Run id: ${TIMESTAMP}" +echo " Output root: ${RESULTS_DIR}/dist_analysis__${TIMESTAMP}/" +echo "============================================================" + +for MODEL_NAME in "${MODEL_NAMES[@]}"; do + MODEL_SLUG="${MODEL_NAME//\//_}" + RUN_DIR="${RESULTS_DIR}/dist_analysis_${MODEL_SLUG}_${TIMESTAMP}" + mkdir -p "${RUN_DIR}" + + echo "" + echo "############################ MODEL: ${MODEL_NAME} ############################" + echo " Output: ${RUN_DIR}" + + # ── Step 1: Calibrate — collect real-data histograms + LUT/quantiles ── + if [ -n "${REAL_HISTOGRAMS}" ]; then + echo "" + echo ">>> Step 1: SKIPPED (using provided --real-histograms ${REAL_HISTOGRAMS})" + REAL_HIST_PATH="${REAL_HISTOGRAMS}" + else + echo "" + echo ">>> Step 1: Calibrating — collecting real-inference histograms" + CALIBRATION_DIR="${RUN_DIR}/calibration" + mkdir -p "${CALIBRATION_DIR}" + CALIB_EXTRA_ARGS=() + if [ "${HAS_WATCHDOG_TIMEOUT}" -eq 1 ]; then + CALIB_EXTRA_ARGS+=(--watchdog-timeout "${WATCHDOG_TIMEOUT}") + fi + 
python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --mem "${MEM}" \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ + --vortex-module-name "${ALGO}" \ + --page-size "${BLOCK_SIZE}" \ + --output-dir "${CALIBRATION_DIR}" \ + "${CALIB_EXTRA_ARGS[@]}" \ + 2>&1 | tee "${RUN_DIR}/step1_calibrate.log" + REAL_HIST_PATH="${CALIBRATION_DIR}/raw_histograms.npy" + echo ">>> Step 1: Done. Calibration saved to ${CALIBRATION_DIR}" + fi + + # Pick up lut.npy / quantiles.npy if calibration produced them. + CALIB_DIR="$(dirname "${REAL_HIST_PATH}")" + LUT_PATH="" + Q_PATH="" + [ -f "${CALIB_DIR}/lut.npy" ] && LUT_PATH="${CALIB_DIR}/lut.npy" + [ -f "${CALIB_DIR}/quantiles.npy" ] && Q_PATH="${CALIB_DIR}/quantiles.npy" + [ -n "${LUT_PATH}" ] && echo " Calibration LUT: ${LUT_PATH}" + [ -n "${Q_PATH}" ] && echo " Calibration quantile: ${Q_PATH}" + + # ── Step 2: Auto-tune — rank by fused-topk kernel latency ────── + echo "" + echo ">>> Step 2: Auto-tuning hyperparameters by fused-topk kernel latency" + + AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" + AUTOTUNE_EXTRA=(--real-histograms "${REAL_HIST_PATH}") + [ -n "${LUT_PATH}" ] && AUTOTUNE_EXTRA+=(--lut-path "${LUT_PATH}") + [ -n "${Q_PATH}" ] && AUTOTUNE_EXTRA+=(--quantiles-path "${Q_PATH}") + + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --topk-val "${TOPK_VAL}" \ + --batch-size 4 \ + --seq-len ${SEQ_LEN} \ + --page-size "${BLOCK_SIZE}" \ + --num-kv-heads 2 \ + --collect-stats \ + "${AUTOTUNE_EXTRA[@]}" \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step2_autotune.log" + + echo ">>> Step 2: Done. 
Autotune results saved to ${AUTOTUNE_JSON}" + + # ── Step 3: Remap benchmark with autotuned hparams ────────────── + echo "" + echo ">>> Step 3: Remap benchmark (baseline / fused / remap / split) with autotuned hparams" + + BENCH_JSON="${RUN_DIR}/remap_bench.json" + BENCH_EXTRA=() + [ -n "${LUT_PATH}" ] && BENCH_EXTRA+=(--lut-path "${LUT_PATH}") + [ -n "${Q_PATH}" ] && BENCH_EXTRA+=(--quantiles-path "${Q_PATH}") + + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ + --remap-bench \ + --batch-sizes 4 \ + --num-kv-heads 8 \ + --seq-lens ${SEQ_LEN} \ + --topk-vals "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --distributions bucket_uniform normal \ + --mapping-modes 0 1 2 3 6 7 8 9 10 11 13 \ + --autotune-json "${AUTOTUNE_JSON}" \ + "${BENCH_EXTRA[@]}" \ + --repeat 20 \ + --output-json "${BENCH_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step3_bench.log" + + echo ">>> Step 3: Done. Remap bench saved to ${BENCH_JSON}" +done + +# ── Summary ─────────────────────────────────────────────────── +echo "" +echo "============================================================" +echo "Bucket Distribution Profiling Complete" +echo " Per-model outputs under ${RESULTS_DIR}/ (run id ${TIMESTAMP}):" +echo " dist_analysis__${TIMESTAMP}/" +echo " autotune_results.json, bench_distribution.json, plots, CSV, logs" +echo "============================================================" diff --git a/examples/archived/verify_algo_topk_mapping.sh b/examples/archived/verify_algo_topk_mapping.sh new file mode 100644 index 00000000..f361e594 --- /dev/null +++ b/examples/archived/verify_algo_topk_mapping.sh @@ -0,0 +1,204 @@ +#!/usr/bin/env bash +# ============================================================ +# E2E accuracy comparison: naive baseline + unmapped sglang + +# every surviving parametric mapping mode (3, 4, 6, 7, 9, 10, 13) +# with per-mode hyperparameters picked by autotune_topk_mapping.py +# (ranked by measured fused-topk kernel latency, lowest wins). 
+# +# Surviving mapping modes after the lean refactor: +# 0: None — unmapped baseline +# 3: Power — y = sign(x) * |x|^p +# 4: Log — y = sign(x) * log(|x| + 1) +# 6: Asinh — y = asinh(beta * x) +# 7: Log1p — y = sign(x) * log1p(alpha * |x|) +# 9: Erf — y = erf(alpha * x) +# 10: Tanh — y = tanh(alpha * x) +# 13: ExpStretch — y = exp(alpha * x) +# ============================================================ +set -e + +# ── Defaults ────────────────────────────────────────────────── +GPU_ID=0 +BENCHMARKS="amc23" +MODEL_NAME="Qwen/Qwen3-1.7B" +TOPK_VAL=30 +BLOCK_SIZE=16 +BATCH_SIZE=4 +NUM_KV_HEADS=2 +SEQ_LEN=32768 +MAX_TOTAL_TOKENS=1048576 +REAL_HISTOGRAMS="" +SKIP_AUTOTUNE=0 + +while [[ $# -gt 0 ]]; do + case "$1" in + --gpu) GPU_ID="$2"; shift 2 ;; + --benchmark) BENCHMARKS="$2"; shift 2 ;; + --model-name) MODEL_NAME="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;; + --batch-size) BATCH_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; + --skip-autotune) SKIP_AUTOTUNE=1; shift 1 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +sparse_algos=( "block_sparse_attention" ) + +BENCH_LABEL=$(echo "${BENCHMARKS}" | tr ' ' '_') +MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" +RESULTS_DIR="results/${MODEL_SLUG}_${BENCH_LABEL}" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + +# ============================================================ +# Baseline: naive topk +# ============================================================ +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_naive_${TIMESTAMP}.log" + echo ">>> naive topk algo=${algo}" + { time 
python verify_algo.py \ + --trials 8 \ + --topk-val "${TOPK_VAL}" \ + --vortex-module-name "${algo}" \ + --model-name "${MODEL_NAME}" \ + --topk-type naive \ + --benchmark ${BENCHMARKS} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done + +# ============================================================ +# Calibrate (optional) — real-distribution histograms +# ============================================================ +if [ -z "${REAL_HISTOGRAMS}" ]; then + CALIBRATION_DIR="${RESULTS_DIR}/calibration_${TIMESTAMP}" + echo ">>> Max total tokens (KV / VTX cap): ${MAX_TOTAL_TOKENS}" + for algo in "${sparse_algos[@]}"; do + echo ">>> Calibrating ${MODEL_NAME} for ${algo}..." + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --mem 0.7 \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ + --vortex-module-name "${algo}" \ + --page-size "${BLOCK_SIZE}" \ + --output-dir "${CALIBRATION_DIR}" \ + 2>&1 | tee "${RESULTS_DIR}/calibration_${algo}_${TIMESTAMP}.log" + done + REAL_HISTOGRAMS="${CALIBRATION_DIR}/raw_histograms.npy" +fi + +# Pick up lut.npy / quantiles.npy if calibration produced them. 
+CALIB_DIR="$(dirname "${REAL_HISTOGRAMS}")" +LUT_PATH="" +Q_PATH="" +[ -f "${CALIB_DIR}/lut.npy" ] && LUT_PATH="${CALIB_DIR}/lut.npy" +[ -f "${CALIB_DIR}/quantiles.npy" ] && Q_PATH="${CALIB_DIR}/quantiles.npy" + +# ============================================================ +# Auto-tune — rank by fused-topk kernel latency +# ============================================================ +AUTOTUNE_JSON="${RESULTS_DIR}/autotune_${TIMESTAMP}.json" +if [ "${SKIP_AUTOTUNE}" -eq 0 ]; then + AUTOTUNE_EXTRA=() + [ -n "${LUT_PATH}" ] && AUTOTUNE_EXTRA+=(--lut-path "${LUT_PATH}") + [ -n "${Q_PATH}" ] && AUTOTUNE_EXTRA+=(--quantiles-path "${Q_PATH}") + if [ -f "${REAL_HISTOGRAMS}" ]; then + echo "============================================================" + echo "Auto-tuning hyperparameters (real distribution, latency-ranked)" + echo "============================================================" + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --topk-val "${TOPK_VAL}" \ + --batch-size "${BATCH_SIZE}" \ + --seq-len "${SEQ_LEN}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --page-size "${BLOCK_SIZE}" \ + --real-histograms "${REAL_HISTOGRAMS}" \ + "${AUTOTUNE_EXTRA[@]}" \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RESULTS_DIR}/autotune_${TIMESTAMP}.log" + echo ">>> Auto-tune results saved to ${AUTOTUNE_JSON}" + else + echo ">>> WARNING: ${REAL_HISTOGRAMS} not found — autotune will use synthetic data" + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --topk-val "${TOPK_VAL}" \ + --batch-size "${BATCH_SIZE}" \ + --seq-len "${SEQ_LEN}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --page-size "${BLOCK_SIZE}" \ + "${AUTOTUNE_EXTRA[@]}" \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RESULTS_DIR}/autotune_${TIMESTAMP}.log" + fi +fi + +# Extract best per-mode hparam (ranked by kernel latency, lowest wins). 
# Run verify_algo.py once per sparse algorithm for one top-k mapping setup.
#   $1  mapping mode (0 selects the plain sglang top-k, anything else the
#       fused kernel with that mapping mode)
#   $2  mapping hyperparameter (ignored when mode is 0)
#   $3  label used in the per-algorithm log file name
# Reads sparse_algos, RESULTS_DIR, TIMESTAMP, TOPK_VAL, MODEL_NAME and
# BENCHMARKS from the enclosing script.
run_mapped() {
    local mapping_mode="$1"
    local hparam="$2"
    local tag="$3"

    # The top-k flags depend only on the mode, so build them once up front.
    local topk_args=(--topk-type sglang)
    if [ "${mapping_mode}" -ne 0 ]; then
        topk_args=(--topk-type sglang_fused
                   --topk-mapping-mode "${mapping_mode}"
                   --topk-mapping-hparam "${hparam}")
    fi

    for algo in "${sparse_algos[@]}"; do
        local log_file="${RESULTS_DIR}/topk_mapping_${algo}_${tag}_${TIMESTAMP}.log"
        echo ">>> ${tag} algo=${algo}"
        # BENCHMARKS is left unquoted on purpose: it may hold several names.
        { time python verify_algo.py \
            --trials 8 \
            --topk-val "${TOPK_VAL}" \
            --vortex-module-name "${algo}" \
            --model-name "${MODEL_NAME}" \
            --benchmark ${BENCHMARKS} \
            --mem 0.7 \
            "${topk_args[@]}" ; } \
            2>&1 | tee "${log_file}"
    done
}
Results in ${RESULTS_DIR}/" +echo " Model: ${MODEL_NAME}" +echo " Block size: ${BLOCK_SIZE}" +echo " Auto-tune: ${AUTOTUNE_JSON}" +echo "============================================================" diff --git a/examples/plot_parallel_comparison.py b/examples/plot_parallel_comparison.py new file mode 100644 index 00000000..3d3f4883 --- /dev/null +++ b/examples/plot_parallel_comparison.py @@ -0,0 +1,384 @@ +#!/usr/bin/env python3 +"""Aggregate baseline / fused / adaptive (split) TopK latencies across remap +functions and emit CSV tables + matplotlib bar plots. + +Reads one remap_bench_*.json per (topk_val, num_splits) tag from the +directories produced by remap_function_bench_topk_parallel.sh and writes: + + results.csv long-form per-(K, splits, batch, mode, dist) rows + summary_topk.csv wide table per K (averaged across batch sizes) + summary_all.csv single combined wide table covering every K + comparison_topk.png bar plot per K (one bar group per mode) + comparison_all.png side-by-side per-K plots + +Input format: + --input "K=2048,splits=ns2=path/to/remap_bench_ns2.json" + --input "K=30,splits=auto=path/to/remap_bench_auto.json" + +The legacy "tag=path" input is also accepted; it lands in the all-K combined +plot but won't fill in the K/splits columns of results.csv. 
+ +Usage: + python plot_parallel_comparison.py \ + --input "K=2048,splits=ns2=.../remap_bench_ns2.json" \ + --input "K=30,splits=auto=.../remap_bench_auto.json" \ + --output-dir /analysis [--emit-csv] [--emit-png] +""" +from __future__ import annotations + +import argparse +import csv +import json +import math +import re +from pathlib import Path +from typing import Dict, List, Tuple + +import numpy as np +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt + + +MODE_DISPLAY = { + 0: "None", + 3: "Power", + 4: "Log", + 6: "Asinh", + 7: "Log1p", + 9: "Erf", + 10: "Tanh", + 11: "Subtract", + 13: "ExpStretch", + 15: "ShiftPow2", + 16: "ShiftPow3", + 17: "LinearSteep", + 18: "HalfSquare", + 19: "HalfCube", +} + + +# --- input parsing ----------------------------------------------------------- + + +def _parse_input_spec(spec: str) -> Tuple[str, dict, Path]: + """Parse "K=2048,splits=ns2=path" -> ("K=2048,splits=ns2", {"K": "2048", + "splits": "ns2"}, Path("path")). + + Falls back to the legacy "tag=path" format if no comma-separated + key=value attrs precede the trailing "=path" segment. + """ + if "=" not in spec: + raise SystemExit(f"--input expects tag=path, got {spec!r}") + # Split on the LAST '=' that doesn't follow a comma (the path delim). + # Handle by scanning from the right. 
+ eq_positions = [i for i, ch in enumerate(spec) if ch == "="] + path: str = "" + label: str = spec + for idx in reversed(eq_positions): + candidate_path = spec[idx + 1:] + if "/" in candidate_path or candidate_path.endswith(".json"): + label = spec[:idx] + path = candidate_path + break + if not path: + # last-resort: split on the rightmost '=' + label, path = spec.rsplit("=", 1) + + p = Path(path) + if not p.exists(): + raise SystemExit(f"{p} not found (input spec: {spec!r})") + + attrs: Dict[str, str] = {} + for segment in label.split(","): + segment = segment.strip() + if not segment or "=" not in segment: + continue + k, v = segment.split("=", 1) + attrs[k.strip()] = v.strip() + return label, attrs, p + + +def _load_rows(json_path: Path) -> List[dict]: + with open(json_path) as f: + data = json.load(f) + return data if isinstance(data, list) else data.get("results", []) + + +# --- aggregation ------------------------------------------------------------ + + +def _per_mode_rows(rows: List[dict], distribution: str | None = None): + """Yield one dict per (config, mode) pair so we can flatten to CSV.""" + for cfg in rows: + if distribution is not None and cfg.get("distribution") != distribution: + continue + cfg_keys = { + "batch_size": cfg.get("batch_size"), + "num_kv_heads": cfg.get("num_kv_heads"), + "seq_len": cfg.get("seq_len"), + "topk_val": cfg.get("topk_val"), + "distribution": cfg.get("distribution"), + "pages_per_seg": cfg.get("pages_per_seg"), + "head": cfg.get("head", "all"), + "baseline_ms": cfg.get("baseline_ms"), + } + for m in cfg.get("modes", []): + mode_id = m.get("mode") + if mode_id is None or mode_id < 0: + continue + yield { + **cfg_keys, + "mode": mode_id, + "mode_name": MODE_DISPLAY.get(mode_id, m.get("mode_name", f"m{mode_id}")), + "power": m.get("power"), + "fused_ms": m.get("fused_ms") + if m.get("fused_ms") is not None + else (m.get("topk_after_remap_ms") + if mode_id == 0 else None), + "parallel_ms": m.get("parallel_ms"), + 
def _aggregate_per_mode(rows: List[dict], distribution: str = "real"):
    """Average baseline / fused / parallel latency per mapping mode.

    Rows are first restricted to ``distribution``; when that selects nothing,
    every distribution is used instead.

    Returns:
        ``(summary, used_dist)`` where ``summary`` maps each mode id to
        ``{"baseline_ms", "fused_ms", "parallel_ms"}`` means (NaN for a
        metric with no samples) and ``used_dist`` is the distribution that
        was actually aggregated, or ``None`` after the fallback.
    """
    metrics = ("baseline_ms", "fused_ms", "parallel_ms")

    used_dist = distribution
    records = list(_per_mode_rows(rows, distribution))
    if not records:
        used_dist = None
        records = list(_per_mode_rows(rows, None))

    # mode id -> metric name -> collected samples
    samples: Dict[int, Dict[str, List[float]]] = {}
    for rec in records:
        per_metric = samples.setdefault(rec["mode"], {name: [] for name in metrics})
        for name in metrics:
            value = rec.get(name)
            if value is not None:
                per_metric[name].append(value)

    summary = {}
    for mode_id, per_metric in samples.items():
        means = {}
        for name, values in per_metric.items():
            means[name] = sum(values) / len(values) if values else float("nan")
        summary[mode_id] = means
    return summary, used_dist
def _write_summary_csv(tag: str, summary: Dict[int, Dict[str, float]],
                       out_path: Path, *, attrs: Dict[str, str] | None = None) -> None:
    """Write the per-mode summary of a single input tag as a wide CSV.

    One row is written per mode id, in ascending order, carrying the mean
    latencies and the baseline-relative speedups of the fused and parallel
    variants. Missing or NaN numbers become empty cells.

    Args:
        tag: Label of the input the summary was computed from.
        summary: Mode id -> mean ``baseline_ms`` / ``fused_ms`` / ``parallel_ms``.
        out_path: Destination CSV file (overwritten).
        attrs: Optional attributes of the input; ``K`` and ``splits`` fill
            the first two columns.
    """
    nan = float("nan")
    attrs = attrs or {}

    def _speedup(baseline, latency):
        # Undefined when the latency is missing, zero or NaN.
        if latency and not math.isnan(latency):
            return baseline / latency
        return nan

    header = ["K", "splits", "tag", "mode", "mode_name",
              "baseline_ms", "fused_ms", "parallel_ms",
              "fused_speedup", "parallel_speedup"]

    with open(out_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(header)
        for mode_id in sorted(summary):
            stats = summary[mode_id]
            base = stats.get("baseline_ms", nan)
            fused = stats.get("fused_ms", nan)
            par = stats.get("parallel_ms", nan)
            numbers = (base, fused, par, _speedup(base, fused), _speedup(base, par))
            writer.writerow([
                attrs.get("K", ""),
                attrs.get("splits", ""),
                tag,
                mode_id,
                MODE_DISPLAY.get(mode_id, f"m{mode_id}"),
                *(_csv_num(n) for n in numbers),
            ])
    print(f"  wrote {out_path}")
summary[mode_id] + base = s.get("baseline_ms", float("nan")) + fused = s.get("fused_ms", float("nan")) + par = s.get("parallel_ms", float("nan")) + fs = (base / fused) if fused and not math.isnan(fused) else float("nan") + ps = (base / par) if par and not math.isnan(par) else float("nan") + w.writerow([ + tag, + attrs.get("K", ""), + attrs.get("splits", ""), + mode_id, + MODE_DISPLAY.get(mode_id, f"m{mode_id}"), + _csv_num(base), _csv_num(fused), _csv_num(par), + _csv_num(fs), _csv_num(ps), + ]) + print(f" wrote {out_path}") + + +def _csv_num(x): + if x is None or (isinstance(x, float) and math.isnan(x)): + return "" + return f"{x:.6f}" + + +# --- plotting --------------------------------------------------------------- + + +def _plot_bars(tag: str, summary: Dict[int, Dict[str, float]], out_path: Path) -> None: + modes = sorted(summary.keys()) + labels = [MODE_DISPLAY.get(m, f"m{m}") for m in modes] + base = [summary[m].get("baseline_ms", float("nan")) for m in modes] + fused = [summary[m].get("fused_ms", float("nan")) for m in modes] + par = [summary[m].get("parallel_ms", float("nan")) for m in modes] + + x = np.arange(len(modes)) + w = 0.27 + fig, ax = plt.subplots(figsize=(max(8, 0.85 * len(modes)), 5)) + ax.bar(x - w, base, w, label="Baseline (sglang)", color="#888888") + ax.bar(x, fused, w, label="Fused (sglang_fused)", color="#4C72B0") + ax.bar(x + w, par, w, label="Adaptive (output_adaptive)", color="#C44E52") + ax.set_xticks(x) + ax.set_xticklabels(labels, rotation=30, ha="right") + ax.set_ylabel("Latency (ms, lower is better)") + ax.set_title(f"TopK kernel latency — {tag}") + ax.grid(True, axis="y", linestyle="--", alpha=0.4) + ax.legend(loc="upper right") + fig.tight_layout() + fig.savefig(out_path, dpi=150) + plt.close(fig) + print(f" wrote {out_path}") + + +def _plot_combined(summaries: Dict[str, Dict[int, Dict[str, float]]], out_path: Path) -> None: + if not summaries: + return + tags = list(summaries.keys()) + fig, axes = plt.subplots(1, len(tags), 
def main():
    """Aggregate every ``--input`` JSON and emit the CSV tables and PNG plots.

    For each input spec this writes ``summary_<tag>.csv`` and
    ``comparison_<tag>.png``; after all inputs it writes ``results.csv`` and
    ``summary_all.csv``, plus ``comparison_all.png`` when more than one input
    was given. All files land in ``--output-dir``, which is created on demand.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--input", action="append", required=True,
                   help="=path/to/remap_bench_*.json (repeatable).")
    p.add_argument("--output-dir", required=True)
    p.add_argument("--distribution", default="real",
                   help="Distribution column to aggregate (falls back to all).")
    p.add_argument("--emit-csv", action="store_true",
                   help="Write CSV tables (always on; flag kept for explicitness).")
    p.add_argument("--emit-png", action="store_true",
                   help="Write PNG plots (always on; flag kept for explicitness).")
    args = p.parse_args()
    # Always emit both — flags are kept so the shell wrapper can document intent.
    args.emit_csv = True
    args.emit_png = True

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    summaries: Dict[str, Dict[int, Dict[str, float]]] = {}
    attrs_by_tag: Dict[str, Dict[str, str]] = {}
    long_form: List[dict] = []

    for spec in args.input:
        tag, attrs, path = _parse_input_spec(spec)
        rows = _load_rows(path)

        # Long-form rows for results.csv. _per_mode_rows is a generator and a
        # generator object is always truthy, so `gen_a or gen_b` can never
        # reach gen_b; materialise the filtered rows first and fall back to
        # all distributions only when that list is really empty (the same
        # fallback the per-mode summary uses).
        per_mode = list(_per_mode_rows(rows, args.distribution))
        if not per_mode:
            per_mode = list(_per_mode_rows(rows, None))
        for r in per_mode:
            long_form.append({
                "label": tag,
                "K": attrs.get("K", ""),
                "splits": attrs.get("splits", ""),
                **r,
            })

        summary, _used_dist = _aggregate_per_mode(rows, distribution=args.distribution)
        summaries[tag] = summary
        attrs_by_tag[tag] = attrs

        if args.emit_csv:
            _write_summary_csv(tag, summary,
                               out_dir / f"summary_{_safe(tag)}.csv",
                               attrs=attrs)
        if args.emit_png:
            _plot_bars(tag, summary, out_dir / f"comparison_{_safe(tag)}.png")

    if args.emit_csv:
        _write_results_csv(long_form, out_dir / "results.csv")
        _write_summary_all_csv(summaries, attrs_by_tag, out_dir / "summary_all.csv")

    # A combined figure only makes sense with at least two inputs.
    if args.emit_png and len(summaries) > 1:
        _plot_combined(summaries, out_dir / "comparison_all.png")
unmapped topk on remapped — topk_output_sglang on the output +# buffer of step 3 +# +# Per configuration the script also reports the threshold-bin +# position, the threshold-bin size, and how many values are +# selected from the threshold bin (derived from +# topk_profile_counters — collected after all timing measurements, +# never interleaved with latency measurements). +# +# Pipeline: +# 1. Calibrate — run `calibrate_topk.py` on the chosen model to +# collect the REAL per-segment topk distribution +# (raw_histograms.npy). Skippable via +# --real-histograms /path/to/raw_histograms.npy. +# 2. Autotune — run `autotune_topk_mapping.py` on those real +# histograms and pick the per-mode hyperparameter +# with the LOWEST measured topk kernel latency. +# 3. Remap bench— run `bench_topk.py --remap-bench` with the +# autotune-selected per-mode hyperparameters. +# +# Argument layout mirrors run_distribution_analysis_new.sh. +# +# Usage: +# # Default (Qwen/Qwen3-1.7B, block_size=16): +# bash remap_function_bench.sh --gpu 5 +# +# # Larger model + larger page/block size: +# bash remap_function_bench.sh --gpu 0 \ +# --model-name Qwen/Qwen3-8B \ +# --block-size 32 \ +# --seq-len 16384 --topk-val 512 \ +# --modes "0 3 6 7" +# +# # Reuse an existing calibration: +# bash remap_function_bench.sh --gpu 0 \ +# --model-name Qwen/Qwen3-8B \ +# --real-histograms /path/to/calibration/raw_histograms.npy +# ============================================================ +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +# ── Defaults ────────────────────────────────────────────────── +GPU_ID=4 +MODEL_NAME="Qwen/Qwen3-1.7B" +TOPK_VAL=2048 +MEM=0.7 +# Cap KV / VTX sparse prefill buffer sizing during Step 1 (see calibrate_topk.py --help). +MAX_TOTAL_TOKENS=64768 +# Min free GiB on the output-dir filesystem before Step 1 (HF weights + cache + logs). 
+MIN_FREE_DISK_GB=22 +ALGO="block_sparse_attention" +SAMPLE_STRIDE=1 +SEQ_LEN=32768 +BLOCK_SIZE=1 +BATCH_SIZE=4 +NUM_KV_HEADS=8 +DISTRIBUTIONS="normal bucket_uniform" +# Modes 1 (LUT_CDF) and 2 (Quantile) are no longer benchmarked — their +# mapping happens inside compute_stage1_bin, not apply_transform, so +# split-phase timing isn't meaningful for them. +MAPPING_MODES="0 3 6 7 9 10 11 13 15 16 17 18 19" +# Fallback hparam used only if autotune is explicitly skipped. +MAPPING_HPARAM=0.5 +REPEAT=100 +WARMUP=20 +# Defaults to a pre-collected Qwen3-1.7B calibration, so Step 1 is skipped unless this is overridden. +# Pass --real-histograms "" to fall back to the per-model cache or, if none exists, to recalibrate. +# REAL_HISTOGRAMS="/var/tmp/zhuominc/vortex_torch/calibration/raw_histograms.npy" +#REAL_HISTOGRAMS="/var/tmp/zhuominc/vortex_torch/calibration/raw_histograms_qwen3-4B.npy" +REAL_HISTOGRAMS="/var/tmp/zhuominc/vortex_torch/calibration/raw_histograms_qwen3-1.7B.npy" +SKIP_AUTOTUNE=0 +# Optional: pre-built autotune JSON to bypass Step 2 entirely. When set, +# Step 2 is skipped and Step 3 reads its per-mode hparams from this file +# instead. Useful for verification runs where we want to pin the exact +# (mode, hparam) pairs without re-running the latency sweep. 
+PINNED_AUTOTUNE_JSON="" + +# ── Parse arguments ─────────────────────────────────────────── +while [[ $# -gt 0 ]]; do + case "$1" in + --model-name) MODEL_NAME="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --mem) MEM="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; + --min-free-disk-gb) MIN_FREE_DISK_GB="$2"; shift 2 ;; + --gpu) GPU_ID="$2"; shift 2 ;; + --algo) ALGO="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + --sample-stride) SAMPLE_STRIDE="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; + --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;; + --batch-size) BATCH_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --distributions) DISTRIBUTIONS="$2"; shift 2 ;; + --modes) MAPPING_MODES="$2"; shift 2 ;; + --mapping-hparam) MAPPING_HPARAM="$2"; shift 2 ;; + --repeat) REPEAT="$2"; shift 2 ;; + --warmup) WARMUP="$2"; shift 2 ;; + --skip-autotune) SKIP_AUTOTUNE=1; shift 1 ;; + --pinned-autotune-json) PINNED_AUTOTUNE_JSON="$2"; SKIP_AUTOTUNE=1; shift 2 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" + +# Qwen3-1.7B does not use DeepGEMM (no FP8/MoE path). +# Disable its JIT to silence "NVCC Compiler not found ... use NVRTC" on Blackwell. +export SGL_ENABLE_JIT_DEEPGEMM="${SGL_ENABLE_JIT_DEEPGEMM:-true}" + +# If DeepGEMM JIT is ever re-enabled, make sure it can find nvcc. 
+if [ -z "${DG_JIT_NVCC_COMPILER:-}" ]; then + if [ -x /usr/local/cuda/bin/nvcc ]; then + export CUDA_HOME="${CUDA_HOME:-/usr/local/cuda}" + export PATH="${CUDA_HOME}/bin:${PATH}" + export DG_JIT_NVCC_COMPILER="${CUDA_HOME}/bin/nvcc" + elif command -v nvcc >/dev/null 2>&1; then + export DG_JIT_NVCC_COMPILER="$(command -v nvcc)" + fi +fi + +# Validate seq_len: need pages/seg > topk_val (3 reserved pages) +MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * BLOCK_SIZE )) +if [ "${SEQ_LEN}" -lt "${MIN_SEQ_LEN}" ]; then + echo "ERROR: --seq-len ${SEQ_LEN} too small for --topk-val ${TOPK_VAL} @ --block-size ${BLOCK_SIZE}." + echo " Minimum: ${MIN_SEQ_LEN} (pages/seg must exceed topk_val + 3 reserved pages)" + exit 1 +fi + +RESULTS_DIR="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" +RUN_DIR="${RESULTS_DIR}/remap_bench_${MODEL_SLUG}_topk${TOPK_VAL}_bs${BLOCK_SIZE}_${TIMESTAMP}" +mkdir -p "${RUN_DIR}" + +# Calibration artifacts live on /var/tmp (large disk), keyed by model. +# Example: /var/tmp/zhuominc/vortex_torch/calibration/raw_histograms_qwen3-8B.npy +CALIBRATION_BASE="/var/tmp/zhuominc/vortex_torch/calibration" +MODEL_TAG="$(echo "${MODEL_NAME##*/}" | sed 's/^Q/q/')" +DEFAULT_REAL_HIST="${CALIBRATION_BASE}/raw_histograms_${MODEL_TAG}.npy" +mkdir -p "${CALIBRATION_BASE}" + +# If no explicit --real-histograms and a cached file exists, reuse it. 
+if [ -z "${REAL_HISTOGRAMS}" ] && [ -f "${DEFAULT_REAL_HIST}" ]; then + REAL_HISTOGRAMS="${DEFAULT_REAL_HIST}" +fi + +echo "============================================================" +echo "Remap Function Benchmark" +echo " Model: ${MODEL_NAME}" +echo " Algorithm: ${ALGO}" +echo " TopK: ${TOPK_VAL}" +echo " Block size: ${BLOCK_SIZE}" +echo " Seq len: ${SEQ_LEN} ($(( SEQ_LEN / BLOCK_SIZE )) pages/seg)" +echo " Batch size: ${BATCH_SIZE}" +echo " KV heads: ${NUM_KV_HEADS}" +echo " Distributions: ${DISTRIBUTIONS}" +echo " Mapping modes: ${MAPPING_MODES}" +echo " Fallback hparam: ${MAPPING_HPARAM} (used only when --skip-autotune)" +echo " Max total tokens: ${MAX_TOTAL_TOKENS} (calibration KV / VTX buffer cap)" +echo " Min free disk: ${MIN_FREE_DISK_GB} GiB (Step 1 preflight; 0 = skip)" +echo " GPU: ${GPU_ID}" +echo " Sample stride: ${SAMPLE_STRIDE}" +echo " Real histograms: ${REAL_HISTOGRAMS:-}" +echo " Output: ${RUN_DIR}" +echo "============================================================" + +# ── Step 1: Calibrate — collect real-distribution topk histograms ── +# calibrate_topk.py runs the model end-to-end with histogram profiling +# enabled and writes per-segment raw_histograms.npy. The histograms are +# aggregated over every layer and every decode/prefill step so the +# autotune in Step 2 sees the true attention-score distribution. 
+if [ -n "${REAL_HISTOGRAMS}" ]; then + echo "" + echo ">>> Step 1: SKIPPED (using provided --real-histograms ${REAL_HISTOGRAMS})" + REAL_HIST_PATH="${REAL_HISTOGRAMS}" +else + echo "" + echo ">>> Step 1: Calibrating ${MODEL_NAME} — collecting real topk histograms" + CALIBRATION_DIR="${CALIBRATION_BASE}/staging_${MODEL_TAG}_topk${TOPK_VAL}_bs${BLOCK_SIZE}_${TIMESTAMP}" + mkdir -p "${CALIBRATION_DIR}" + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --mem "${MEM}" \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ + --min-free-disk-gb "${MIN_FREE_DISK_GB}" \ + --vortex-module-name "${ALGO}" \ + --output-dir "${CALIBRATION_DIR}" \ + 2>&1 | tee "${RUN_DIR}/step1_calibrate.log" + # Promote raw_histograms.npy to the shared per-model cache path. + mv -f "${CALIBRATION_DIR}/raw_histograms.npy" "${DEFAULT_REAL_HIST}" + REAL_HIST_PATH="${DEFAULT_REAL_HIST}" + echo ">>> Step 1: Done. raw_histograms -> ${REAL_HIST_PATH}" + echo ">>> Step 1: Staging dir (lut/quantiles/logs): ${CALIBRATION_DIR}" +fi + +# Modes 1 (LUT_CDF) and 2 (Quantile) are dropped from the comparison, so +# lut.npy / quantiles.npy produced by calibration are no longer consumed. + +# ── Step 2: Auto-tune hyperparameters by profiled fused-topk latency ── +# For every (mode, hparam) combo in the sweep grid, the autotune runs the +# fused remap+topk kernel on the real histogram and measures end-to-end +# kernel latency with CUDA events. The per-mode hparam with the lowest +# measured topk kernel latency wins. 
+AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" +if [ "${SKIP_AUTOTUNE}" -eq 1 ]; then + echo "" + if [ -n "${PINNED_AUTOTUNE_JSON}" ]; then + echo ">>> Step 2: SKIPPED (pinned hparams from ${PINNED_AUTOTUNE_JSON})" + AUTOTUNE_ARGS="--autotune-json ${PINNED_AUTOTUNE_JSON}" + else + echo ">>> Step 2: SKIPPED (using fallback --mapping-hparam ${MAPPING_HPARAM})" + AUTOTUNE_ARGS="" + fi +else + echo "" + echo ">>> Step 2: Auto-tuning hyperparameters by profiled topk kernel latency" + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --batch-size "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-len "${SEQ_LEN}" \ + --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --real-histograms "${REAL_HIST_PATH}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --collect-stats \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step2_autotune.log" + echo ">>> Step 2: Done. Autotune results saved to ${AUTOTUNE_JSON}" + AUTOTUNE_ARGS="--autotune-json ${AUTOTUNE_JSON}" +fi + +# ── Step 3: Remap benchmark (baseline / fused / remap / split) ── +echo "" +echo ">>> Step 3: Timing remap / topk / fused / baseline with autotuned hparams" +REMAP_JSON="${RUN_DIR}/remap_bench.json" +BENCH_EXTRA=() +[ -n "${REAL_HIST_PATH}" ] && BENCH_EXTRA+=(--real-histograms "${REAL_HIST_PATH}") +PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ + --remap-bench \ + --per-head-bench \ + --batch-sizes "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-lens "${SEQ_LEN}" \ + --topk-vals "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --distributions ${DISTRIBUTIONS} \ + --mapping-modes ${MAPPING_MODES} \ + --mapping-hparam "${MAPPING_HPARAM}" \ + ${AUTOTUNE_ARGS} \ + "${BENCH_EXTRA[@]}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --output-json "${REMAP_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step3_remap_bench.log" +echo ">>> Step 3: Done. 
Remap bench saved to ${REMAP_JSON}" + +# ── Summary ─────────────────────────────────────────────────── +echo "" +echo "============================================================" +echo "Remap Function Benchmark Complete" +echo " Model: ${MODEL_NAME}" +echo " Block size: ${BLOCK_SIZE}" +echo " All outputs in: ${RUN_DIR}/" +echo " calibration/raw_histograms.npy — real topk distribution (per layer)" +echo " autotune_results.json — latency-ranked mapping hparams" +echo " remap_bench.json — per-config remap/topk/fused/baseline latencies" +echo " step{1,2,3}_*.log — pipeline logs" +echo "============================================================" diff --git a/examples/remap_function_bench_topk30.sh b/examples/remap_function_bench_topk30.sh new file mode 100755 index 00000000..3843906c --- /dev/null +++ b/examples/remap_function_bench_topk30.sh @@ -0,0 +1,267 @@ +#!/usr/bin/env bash +# ============================================================ +# Remap Function Benchmark +# +# Compares four kernel configurations for TopK page selection: +# 1. baseline — unmapped topk (topk_output_sglang) +# 2. fused remap + topk — topk_output_sglang_fused +# 3. remap only — topk_remap_only (standalone kernel) +# 4. unmapped topk on remapped — topk_output_sglang on the output +# buffer of step 3 +# +# Per configuration the script also reports the threshold-bin +# position, the threshold-bin size, and how many values are +# selected from the threshold bin (derived from +# topk_profile_counters — collected after all timing measurements, +# never interleaved with latency measurements). +# +# Pipeline: +# 1. Calibrate — run `calibrate_topk.py` on the chosen model to +# collect the REAL per-segment topk distribution +# (raw_histograms.npy). Skippable via +# --real-histograms /path/to/raw_histograms.npy. +# 2. Autotune — run `autotune_topk_mapping.py` on those real +# histograms and pick the per-mode hyperparameter +# with the LOWEST measured topk kernel latency. +# 3. 
Remap bench— run `bench_topk.py --remap-bench` with the +# autotune-selected per-mode hyperparameters. +# +# Argument layout mirrors run_distribution_analysis_new.sh. +# +# Usage: +# # Default (Qwen/Qwen3-1.7B, block_size=16): +# bash remap_function_bench.sh --gpu 5 +# +# # Larger model + larger page/block size: +# bash remap_function_bench.sh --gpu 0 \ +# --model-name Qwen/Qwen3-8B \ +# --block-size 32 \ +# --seq-len 16384 --topk-val 512 \ +# --modes "0 3 6 7" +# +# # Reuse an existing calibration: +# bash remap_function_bench.sh --gpu 0 \ +# --model-name Qwen/Qwen3-8B \ +# --real-histograms /path/to/calibration/raw_histograms.npy +# # Tight GPU: lower calibration KV cap (default 1048576): +# bash remap_function_bench_topk30.sh --gpu 0 --max-total-tokens 524288 +# ============================================================ +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +# ── Defaults ────────────────────────────────────────────────── +GPU_ID=1 +MODEL_NAME="Qwen/Qwen3-1.7B" +TOPK_VAL=30 +MEM=0.7 +MAX_TOTAL_TOKENS=1048576 +ALGO="block_sparse_attention" +SAMPLE_STRIDE=1 +SEQ_LEN=32768 +BLOCK_SIZE=16 +BATCH_SIZE=4 +NUM_KV_HEADS=8 +DISTRIBUTIONS="normal bucket_uniform" +# Modes 1 (LUT_CDF) and 2 (Quantile) are no longer benchmarked — their +# mapping happens inside compute_stage1_bin, not apply_transform, so +# split-phase timing isn't meaningful for them. +MAPPING_MODES="0 3 6 7 9 10 11 13 15 16 17 18 19 20" +# Fallback hparam used only if autotune is explicitly skipped. +MAPPING_HPARAM=0.5 +REPEAT=100 +WARMUP=20 +# Empty by default — Step 1 will calibrate on the selected model. +# Pass --real-histograms /path/to/raw_histograms.npy to skip calibration. 
+REAL_HISTOGRAMS="/var/tmp/zhuominc/vortex_torch/calibration/raw_histograms_qwen3-4B.npy" +SKIP_AUTOTUNE=0 + +# ── Parse arguments ─────────────────────────────────────────── +while [[ $# -gt 0 ]]; do + case "$1" in + --model-name) MODEL_NAME="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --mem) MEM="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; + --gpu) GPU_ID="$2"; shift 2 ;; + --algo) ALGO="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + --sample-stride) SAMPLE_STRIDE="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; + --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;; + --batch-size) BATCH_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --distributions) DISTRIBUTIONS="$2"; shift 2 ;; + --modes) MAPPING_MODES="$2"; shift 2 ;; + --mapping-hparam) MAPPING_HPARAM="$2"; shift 2 ;; + --repeat) REPEAT="$2"; shift 2 ;; + --warmup) WARMUP="$2"; shift 2 ;; + --skip-autotune) SKIP_AUTOTUNE=1; shift 1 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" + +# Qwen3-1.7B does not use DeepGEMM (no FP8/MoE path). +# Disable its JIT to silence "NVCC Compiler not found ... use NVRTC" on Blackwell. +export SGL_ENABLE_JIT_DEEPGEMM="${SGL_ENABLE_JIT_DEEPGEMM:-true}" + +# If DeepGEMM JIT is ever re-enabled, make sure it can find nvcc. 
+if [ -z "${DG_JIT_NVCC_COMPILER:-}" ]; then + if [ -x /usr/local/cuda/bin/nvcc ]; then + export CUDA_HOME="${CUDA_HOME:-/usr/local/cuda}" + export PATH="${CUDA_HOME}/bin:${PATH}" + export DG_JIT_NVCC_COMPILER="${CUDA_HOME}/bin/nvcc" + elif command -v nvcc >/dev/null 2>&1; then + export DG_JIT_NVCC_COMPILER="$(command -v nvcc)" + fi +fi + +# Validate seq_len: need pages/seg > topk_val (3 reserved pages) +MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * BLOCK_SIZE )) +if [ "${SEQ_LEN}" -lt "${MIN_SEQ_LEN}" ]; then + echo "ERROR: --seq-len ${SEQ_LEN} too small for --topk-val ${TOPK_VAL} @ --block-size ${BLOCK_SIZE}." + echo " Minimum: ${MIN_SEQ_LEN} (pages/seg must exceed topk_val + 3 reserved pages)" + exit 1 +fi + +RESULTS_DIR="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" +RUN_DIR="${RESULTS_DIR}/remap_bench_${MODEL_SLUG}_topk${TOPK_VAL}_bs${BLOCK_SIZE}_${TIMESTAMP}" +mkdir -p "${RUN_DIR}" + +# Calibration artifacts live on /var/tmp (large disk), keyed by model. +# Example: /var/tmp/zhuominc/vortex_torch/calibration/raw_histograms_qwen3-1.7B.npy +CALIBRATION_BASE="/var/tmp/zhuominc/vortex_torch/calibration" +MODEL_TAG="$(echo "${MODEL_NAME##*/}" | sed 's/^Q/q/')" +DEFAULT_REAL_HIST="${CALIBRATION_BASE}/raw_histograms_${MODEL_TAG}.npy" +mkdir -p "${CALIBRATION_BASE}" + +# If no explicit --real-histograms and a cached file exists, reuse it. 
+if [ -z "${REAL_HISTOGRAMS}" ]; then
+    if [ -f "${DEFAULT_REAL_HIST}" ]; then
+        REAL_HISTOGRAMS="${DEFAULT_REAL_HIST}"
+    fi
+fi
+
+# Run banner: one line per effective setting, printed before any work starts.
+printf '%s\n' \
+    "============================================================" \
+    "Remap Function Benchmark" \
+    " Model: ${MODEL_NAME}" \
+    " Algorithm: ${ALGO}" \
+    " TopK: ${TOPK_VAL}" \
+    " Block size: ${BLOCK_SIZE}" \
+    " Seq len: ${SEQ_LEN} ($(( SEQ_LEN / BLOCK_SIZE )) pages/seg)" \
+    " Batch size: ${BATCH_SIZE}" \
+    " KV heads: ${NUM_KV_HEADS}" \
+    " Distributions: ${DISTRIBUTIONS}" \
+    " Mapping modes: ${MAPPING_MODES}" \
+    " Fallback hparam: ${MAPPING_HPARAM} (used only when --skip-autotune)" \
+    " Max total tokens: ${MAX_TOTAL_TOKENS} (calibration KV / VTX buffer cap)" \
+    " GPU: ${GPU_ID}" \
+    " Sample stride: ${SAMPLE_STRIDE}" \
+    " Real histograms: ${REAL_HISTOGRAMS:-}" \
+    " Output: ${RUN_DIR}" \
+    "============================================================"
+
+# ── Step 1: Calibrate — collect real-distribution topk histograms ──
+# calibrate_topk.py runs the model end-to-end with histogram profiling
+# enabled and writes per-segment raw_histograms.npy. The histograms are
+# aggregated over every layer and every decode/prefill step so the
+# autotune in Step 2 sees the true attention-score distribution.
+if [ -n "${REAL_HISTOGRAMS}" ]; then
+    echo ""
+    echo ">>> Step 1: SKIPPED (using provided --real-histograms ${REAL_HISTOGRAMS})"
+    REAL_HIST_PATH="${REAL_HISTOGRAMS}"
+else
+    echo ""
+    echo ">>> Step 1: Calibrating ${MODEL_NAME} — collecting real topk histograms"
+    # Calibrate into a per-run staging dir, then publish only the histograms.
+    CALIBRATION_DIR="${CALIBRATION_BASE}/staging_${MODEL_TAG}_topk${TOPK_VAL}_bs${BLOCK_SIZE}_${TIMESTAMP}"
+    mkdir -p "${CALIBRATION_DIR}"
+    python "${BENCH_DIR}/calibrate_topk.py" \
+        --model-name "${MODEL_NAME}" \
+        --topk-val "${TOPK_VAL}" \
+        --page-size "${BLOCK_SIZE}" \
+        --mem "${MEM}" \
+        --max-total-tokens "${MAX_TOTAL_TOKENS}" \
+        --vortex-module-name "${ALGO}" \
+        --output-dir "${CALIBRATION_DIR}" \
+        2>&1 | tee "${RUN_DIR}/step1_calibrate.log"
+    # Promote raw_histograms.npy to the shared per-model cache path.
+    mv -f "${CALIBRATION_DIR}/raw_histograms.npy" "${DEFAULT_REAL_HIST}"
+    REAL_HIST_PATH="${DEFAULT_REAL_HIST}"
+    echo ">>> Step 1: Done. raw_histograms -> ${REAL_HIST_PATH}"
+    echo ">>> Step 1: Staging dir (lut/quantiles/logs): ${CALIBRATION_DIR}"
+fi
+
+# Modes 1 (LUT_CDF) and 2 (Quantile) are dropped from the comparison, so
+# lut.npy / quantiles.npy produced by calibration are no longer consumed.
+
+# ── Step 2: Auto-tune hyperparameters by profiled fused-topk latency ──
+# For every (mode, hparam) combo in the sweep grid, the autotune runs the
+# fused remap+topk kernel on the real histogram and measures end-to-end
+# kernel latency with CUDA events. The per-mode hparam with the lowest
+# measured topk kernel latency wins.
+AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json"
+# Kept as an array so a path containing whitespace stays a single argument
+# when forwarded to bench_topk.py (a flat string was word-split on expansion).
+AUTOTUNE_ARGS=()
+if [ "${SKIP_AUTOTUNE}" -eq 1 ]; then
+    echo ""
+    echo ">>> Step 2: SKIPPED (using fallback --mapping-hparam ${MAPPING_HPARAM})"
+else
+    echo ""
+    echo ">>> Step 2: Auto-tuning hyperparameters by profiled topk kernel latency"
+    PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \
+        --batch-size "${BATCH_SIZE}" \
+        --num-kv-heads "${NUM_KV_HEADS}" \
+        --seq-len "${SEQ_LEN}" \
+        --topk-val "${TOPK_VAL}" \
+        --page-size "${BLOCK_SIZE}" \
+        --real-histograms "${REAL_HIST_PATH}" \
+        --warmup "${WARMUP}" \
+        --repeat "${REPEAT}" \
+        --collect-stats \
+        --output-json "${AUTOTUNE_JSON}" \
+        2>&1 | tee "${RUN_DIR}/step2_autotune.log"
+    echo ">>> Step 2: Done. Autotune results saved to ${AUTOTUNE_JSON}"
+    AUTOTUNE_ARGS=(--autotune-json "${AUTOTUNE_JSON}")
+fi
+
+# ── Step 3: Remap benchmark (baseline / fused / remap / split) ──
+echo ""
+echo ">>> Step 3: Timing remap / topk / fused / baseline with autotuned hparams"
+REMAP_JSON="${RUN_DIR}/remap_bench.json"
+BENCH_EXTRA=()
+[ -n "${REAL_HIST_PATH}" ] && BENCH_EXTRA+=(--real-histograms "${REAL_HIST_PATH}")
+# DISTRIBUTIONS / MAPPING_MODES are space-separated lists and are left
+# unquoted on purpose so each entry becomes its own argument.
+# The ${arr[@]+"${arr[@]}"} form expands to nothing for an empty array
+# instead of tripping "unbound variable" under `set -u` on bash < 4.4.
+PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \
+    --remap-bench \
+    --per-head-bench \
+    --batch-sizes "${BATCH_SIZE}" \
+    --num-kv-heads "${NUM_KV_HEADS}" \
+    --seq-lens "${SEQ_LEN}" \
+    --topk-vals "${TOPK_VAL}" \
+    --page-size "${BLOCK_SIZE}" \
+    --distributions ${DISTRIBUTIONS} \
+    --mapping-modes ${MAPPING_MODES} \
+    --mapping-hparam "${MAPPING_HPARAM}" \
+    ${AUTOTUNE_ARGS[@]+"${AUTOTUNE_ARGS[@]}"} \
+    ${BENCH_EXTRA[@]+"${BENCH_EXTRA[@]}"} \
+    --warmup "${WARMUP}" \
+    --repeat "${REPEAT}" \
+    --output-json "${REMAP_JSON}" \
+    2>&1 | tee "${RUN_DIR}/step3_remap_bench.log"
+echo ">>> Step 3: Done. Remap bench saved to ${REMAP_JSON}"
+
+# ── Summary ───────────────────────────────────────────────────
+echo ""
+echo "============================================================"
+echo "Remap Function Benchmark Complete"
+echo " Model: ${MODEL_NAME}"
+echo " Block size: ${BLOCK_SIZE}"
+echo " All outputs in: ${RUN_DIR}/"
+# The histograms are never written under RUN_DIR: they are either the
+# caller's file or the shared per-model cache, so print the real location.
+echo " raw histograms: ${REAL_HIST_PATH} — real topk distribution (per layer)"
+echo " autotune_results.json — latency-ranked mapping hparams"
+echo " remap_bench.json — per-config remap/topk/fused/baseline latencies"
+echo " step{1,2,3}_*.log — pipeline logs"
+echo "============================================================"
diff --git a/examples/remap_function_bench_topk_parallel.sh b/examples/remap_function_bench_topk_parallel.sh
new file mode 100755
index 00000000..7be48b88
--- /dev/null
+++ b/examples/remap_function_bench_topk_parallel.sh
@@ -0,0 +1,156 @@
+#!/usr/bin/env bash
+# ============================================================
+# Three-way TopK kernel latency comparison for K=30.
+#
+# Compares (per (batch_size, pages)):
+#   topk.cu -> topk_output (CUB BlockRadixSort full sort)
+#   topk_sglang.cu -> topk_output_sglang +
+#                     topk_output_sglang_fused (2-stage radix select)
+#   topk_sglang_merge.cu -> topk_output_adaptive (adaptive split SELECT32_SORT32)
+#
+# Pages are varied by --seq-lens (with --page-size 1: pages == seq_len).
+# Default sweep is the matrix the user requested:
+#   batch_sizes = {1, 2, 4, 8, 16}
+#   pages = {4096, 8192, 16384}
+#   topk = 30
+#
+# No calibration, no remap autotune, no model download — purely synthetic
+# scores so the only variable is the kernel itself.
+#
+# Usage:
+#   bash examples/remap_function_bench_topk_parallel.sh --gpu 0
+#   bash examples/remap_function_bench_topk_parallel.sh --gpu 0 \
+#       --batch-sizes "1 2 4 8 16" \
+#       --seq-lens "4096 8192 16384 32768"
+# ============================================================
+set -euo pipefail
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+BENCH_DIR="${SCRIPT_DIR}/../benchmarks"
+
+# ── Defaults (matrix the brief calls out) ─────────────────────
+GPU_ID=0
+TOPK_VALS="30"
+BATCH_SIZES="1 2 4 8 16"
+SEQ_LENS="4096 8192 16384" # pages-per-seg when page-size=1
+NUM_KV_HEADS=8
+PAGE_SIZE=1
+RESERVED_BOS=1
+RESERVED_EOS=2
+DISTRIBUTIONS="normal"
+WARMUP=20
+REPEAT=200
+
+# ── Arg parsing ───────────────────────────────────────────────
+# Every option takes exactly one value; list-valued options are passed as a
+# single quoted, space-separated string.
+while [[ $# -gt 0 ]]; do
+    case "$1" in
+        --gpu) GPU_ID="$2"; shift 2 ;;
+        --topk-vals) TOPK_VALS="$2"; shift 2 ;;
+        --batch-sizes) BATCH_SIZES="$2"; shift 2 ;;
+        --seq-lens) SEQ_LENS="$2"; shift 2 ;;
+        --page-size) PAGE_SIZE="$2"; shift 2 ;;
+        --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;;
+        --distributions) DISTRIBUTIONS="$2"; shift 2 ;;
+        --reserved-bos) RESERVED_BOS="$2"; shift 2 ;;
+        --reserved-eos) RESERVED_EOS="$2"; shift 2 ;;
+        --warmup) WARMUP="$2"; shift 2 ;;
+        --repeat) REPEAT="$2"; shift 2 ;;
+        *) echo "Unknown option: $1" >&2; exit 1 ;;
+    esac
+done
+
+export CUDA_VISIBLE_DEVICES="${GPU_ID}"
+export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
+
+# Per-run output directory: results/three_way_topk<K list>_bs<page size>_<time>.
+RESULTS_DIR="${SCRIPT_DIR}/results"
+mkdir -p "${RESULTS_DIR}"
+TIMESTAMP=$(date +%Y%m%d_%H%M%S)
+TOPK_TAG="$(echo ${TOPK_VALS} | tr ' ' '-')"
+RUN_DIR="${RESULTS_DIR}/three_way_topk${TOPK_TAG}_bs${PAGE_SIZE}_${TIMESTAMP}"
+mkdir -p "${RUN_DIR}"
+JSON_PATH="${RUN_DIR}/three_way.json"
+CSV_PATH="${RUN_DIR}/summary.csv"
+
+printf '%s\n' \
+    "============================================================" \
+    "Three-way TopK kernel comparison" \
+    " TopK sweep: ${TOPK_VALS}" \
+    " Batch sizes: ${BATCH_SIZES}" \
+    " Seq lengths: ${SEQ_LENS} (page_size=${PAGE_SIZE})" \
+    " KV heads: ${NUM_KV_HEADS}" \
+    " Distributions: ${DISTRIBUTIONS}" \
+    " GPU: ${GPU_ID}" \
+    " Warmup/repeat: ${WARMUP}/${REPEAT}" \
+    " Output dir: ${RUN_DIR}" \
+    "============================================================"
+
+# ── Run bench_topk.py with all (B, seq_len, K) combos in one shot ──
+# --mapping-modes 0 = MAPPING_NONE → no remap, no autotune needed.
+# --remap-bench = drives the per-config table that includes baseline
+#                 (topk_sglang) + naive (topk.cu) + sglang_ori rows.
+# --bench-parallel = adds the topk_sglang_merge adaptive measurement
+#                    into each row (under "parallel_ms").
+# The list variables are expanded unquoted so each entry is its own argument.
+PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \
+    --remap-bench \
+    --bench-parallel \
+    --batch-sizes ${BATCH_SIZES} \
+    --num-kv-heads "${NUM_KV_HEADS}" \
+    --seq-lens ${SEQ_LENS} \
+    --topk-vals ${TOPK_VALS} \
+    --page-size "${PAGE_SIZE}" \
+    --reserved-bos "${RESERVED_BOS}" \
+    --reserved-eos "${RESERVED_EOS}" \
+    --distributions ${DISTRIBUTIONS} \
+    --mapping-modes 0 \
+    --warmup "${WARMUP}" \
+    --repeat "${REPEAT}" \
+    --output-json "${JSON_PATH}" \
+    2>&1 | tee "${RUN_DIR}/bench_topk.log"
+
+# ── Aggregate to a clean CSV: one row per (B, pages, K, dist) ─────
+python - "${JSON_PATH}" "${CSV_PATH}" <<'PY'
+import csv, json, sys
+
+json_path, csv_path = sys.argv[1], sys.argv[2]
+with open(json_path) as fh:
+    payload = json.load(fh)
+# Accept either a bare list of rows or a {"results": [...]} wrapper.
+if isinstance(payload, dict) and "results" in payload:
+    records = payload["results"]
+else:
+    records = payload
+
+# Header. Latencies in microseconds.
+columns = [
+    "topk", "batch_size", "pages", "distribution",
+    "cub_topk_us",               # topk.cu / topk_output (None when pages > 8192)
+    "sglang_baseline_us",        # topk_sglang.cu / topk_output_sglang
+    "sglang_fused_us",           # topk_sglang.cu / topk_output_sglang_fused (==baseline @ MAPPING_NONE)
+    "adaptive_us",               # topk_sglang_merge.cu / topk_output_adaptive
+    "speedup_adaptive_vs_fused",
+    "speedup_adaptive_vs_cub",
+]
+
+
+def to_us(ms):
+    """Format a millisecond latency as microseconds; a missing value is blank."""
+    return f"{ms*1000:.3f}" if ms is not None else ""
+
+
+def speedup(reference, adaptive):
+    """Return reference/adaptive, or blank when either latency is missing or zero."""
+    return f"{reference/adaptive:.3f}" if (adaptive and reference) else ""
+
+
+with open(csv_path, "w", newline="") as fh:
+    writer = csv.writer(fh)
+    writer.writerow(columns)
+    for rec in records:
+        batch = rec["batch_size"]
+        pages = rec["pages_per_seg"]
+        topk = rec["topk_val"]
+        dist = rec["distribution"]
+        cub = rec.get("naive_ms")
+        baseline = rec.get("baseline_ms")
+        # The adaptive timing is stored on the mode entry named "None".
+        none_mode = next((m for m in rec["modes"] if m.get("mode_name") == "None"), None)
+        adaptive = none_mode.get("parallel_ms") if none_mode else None
+        # At MAPPING_NONE the fused kernel == baseline kernel (no remap branch),
+        # so report baseline as the fused number too for clarity.
+        fused = baseline
+        writer.writerow([
+            topk, batch, pages, dist,
+            to_us(cub), to_us(baseline), to_us(fused), to_us(adaptive),
+            speedup(baseline, adaptive), speedup(cub, adaptive),
+        ])
+print(f"wrote {csv_path}")
+PY
+
+# ── Print human-readable summary table ──────────────────────────
+printf '%s\n' \
+    "" \
+    "============================================================" \
+    "Summary (us per kernel call; speedup = fused_us / adaptive_us)" \
+    "============================================================"
+# Fall back to the raw CSV when `column` is missing or fails.
+column -t -s, "${CSV_PATH}" || cat "${CSV_PATH}"
+
+printf '%s\n' \
+    "" \
+    "Done. Results:" \
+    " raw JSON: ${JSON_PATH}" \
+    " summary: ${CSV_PATH}" \
+    " log: ${RUN_DIR}/bench_topk.log"
diff --git a/examples/run_distribution_analysis_new.sh b/examples/run_distribution_analysis_new.sh
new file mode 100755
index 00000000..38438bde
--- /dev/null
+++ b/examples/run_distribution_analysis_new.sh
@@ -0,0 +1,196 @@
+#!/usr/bin/env bash
+# ============================================================
+# Bucket Distribution / Remap Latency Pipeline (parametric modes)
+#
+# Tests the surviving parametric mapping modes after the lean
+# refactor:
+#   Mode 3 (Power): y = sign(x) * |x|^p
+#   Mode 6 (Asinh): y = asinh(beta * x)
+#   Mode 7 (Log1p): y = sign(x) * log1p(alpha * |x|)
+#   Mode 9 (Erf): y = erf(alpha * x)
+#   Mode 10 (Tanh): y = tanh(alpha * x)
+#   Mode 13 (ExpStretch): y = exp(alpha * x)
+#
+# Pipeline:
+#   1. Calibrate — collect real-distribution histograms from the
+#                  chosen model (skippable via --real-histograms).
+#   2. Autotune — rank per-mode hparams by measured fused-topk
+#                 kernel latency (lowest wins).
+#   3. Remap bench— bench_topk.py --remap-bench fed with the
+#                   autotune JSON. Reports per-mode remap / topk /
+#                   fused / baseline latencies and threshold stats.
+#
+# NOTE: the default --modes list below also includes 1, 2, 8 and 11; modes
+# 1/2 only take effect when calibration produced lut.npy / quantiles.npy.
+#
+# Usage:
+#   bash run_distribution_analysis_new.sh --gpu 5
+#   bash run_distribution_analysis_new.sh --gpu 5 \
+#       --model-name Qwen/Qwen3-8B --block-size 32
+#   bash run_distribution_analysis_new.sh --gpu 5 \
+#       --real-histograms /path/to/raw_histograms.npy
+#   bash run_distribution_analysis_new.sh --gpu 5 --max-total-tokens 524288
+# ============================================================
+set -euo pipefail
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+BENCH_DIR="${SCRIPT_DIR}/../benchmarks"
+
+# ── Defaults ──────────────────────────────────────────────────
+GPU_ID=4
+MODEL_NAME="Qwen/Qwen3-1.7B"
+TOPK_VAL=2048
+MEM=0.7
+MAX_TOTAL_TOKENS=1048576
+ALGO="block_sparse_attention"
+SEQ_LEN=65536
+BLOCK_SIZE=16
+BATCH_SIZE=4
+NUM_KV_HEADS=8
+DISTRIBUTIONS="bucket_uniform normal"
+# LUT_CDF (1) / QUANTILE (2) are evaluated only when calibration produces
+# lut.npy / quantiles.npy. 0 baseline is always included by --remap-bench.
+MAPPING_MODES="1 2 3 6 7 8 9 10 11 13"
+REPEAT=100
+WARMUP=20
+REAL_HISTOGRAMS=""
+
+# Every option takes exactly one value.
+while [[ $# -gt 0 ]]; do
+    case "$1" in
+        --model-name) MODEL_NAME="$2"; shift 2 ;;
+        --topk-val) TOPK_VAL="$2"; shift 2 ;;
+        --mem) MEM="$2"; shift 2 ;;
+        --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;;
+        --gpu) GPU_ID="$2"; shift 2 ;;
+        --algo) ALGO="$2"; shift 2 ;;
+        --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;;
+        --seq-len) SEQ_LEN="$2"; shift 2 ;;
+        --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;;
+        --batch-size) BATCH_SIZE="$2"; shift 2 ;;
+        --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;;
+        --distributions) DISTRIBUTIONS="$2"; shift 2 ;;
+        --modes) MAPPING_MODES="$2"; shift 2 ;;
+        --repeat) REPEAT="$2"; shift 2 ;;
+        --warmup) WARMUP="$2"; shift 2 ;;
+        *) echo "Unknown option: $1" >&2; exit 1 ;;
+    esac
+done
+
+export CUDA_VISIBLE_DEVICES="${GPU_ID}"
+
+# Reject sequence lengths whose page count is below TOPK_VAL + 4.
+MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * BLOCK_SIZE ))
+if [ "${SEQ_LEN}" -lt "${MIN_SEQ_LEN}" ]; then
+    echo "ERROR: --seq-len ${SEQ_LEN} too small for --topk-val ${TOPK_VAL} @ --block-size ${BLOCK_SIZE}." >&2
+    echo " Minimum: ${MIN_SEQ_LEN}" >&2
+    exit 1
+fi
+
+RESULTS_DIR="${SCRIPT_DIR}/results"
+mkdir -p "${RESULTS_DIR}"
+TIMESTAMP=$(date +%Y%m%d_%H%M%S)
+MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')"
+RUN_DIR="${RESULTS_DIR}/dist_analysis_${MODEL_SLUG}_topk${TOPK_VAL}_bs${BLOCK_SIZE}_${TIMESTAMP}"
+mkdir -p "${RUN_DIR}"
+
+echo "============================================================"
+echo "Bucket Distribution / Remap Latency Pipeline (parametric modes)"
+echo " Model: ${MODEL_NAME}"
+echo " Algorithm: ${ALGO}"
+echo " TopK: ${TOPK_VAL}"
+echo " Block size: ${BLOCK_SIZE}"
+echo " Seq len: ${SEQ_LEN} ($(( SEQ_LEN / BLOCK_SIZE )) pages/seg)"
+echo " Batch size: ${BATCH_SIZE}"
+echo " KV heads: ${NUM_KV_HEADS}"
+echo " Distributions: ${DISTRIBUTIONS}"
+echo " Mapping modes: ${MAPPING_MODES}"
+echo " Max total tokens: ${MAX_TOTAL_TOKENS} (calibration KV / VTX buffer cap)"
+echo " GPU: ${GPU_ID}"
+echo " Real histograms: ${REAL_HISTOGRAMS:-}"
+echo " Output: ${RUN_DIR}"
+echo "============================================================"
+
+# ── Step 1: Calibrate ───────────────────────────────────────────
+if [ -n "${REAL_HISTOGRAMS}" ]; then
+    echo ""
+    echo ">>> Step 1: SKIPPED (using provided --real-histograms ${REAL_HISTOGRAMS})"
+    REAL_HIST_PATH="${REAL_HISTOGRAMS}"
+else
+    echo ""
+    echo ">>> Step 1: Calibrating ${MODEL_NAME} — collecting real topk histograms"
+    CALIBRATION_DIR="${RUN_DIR}/calibration"
+    mkdir -p "${CALIBRATION_DIR}"
+    python "${BENCH_DIR}/calibrate_topk.py" \
+        --model-name "${MODEL_NAME}" \
+        --topk-val "${TOPK_VAL}" \
+        --mem "${MEM}" \
+        --max-total-tokens "${MAX_TOTAL_TOKENS}" \
+        --vortex-module-name "${ALGO}" \
+        --page-size "${BLOCK_SIZE}" \
+        --output-dir "${CALIBRATION_DIR}" \
+        2>&1 | tee "${RUN_DIR}/step1_calibrate.log"
+    REAL_HIST_PATH="${CALIBRATION_DIR}/raw_histograms.npy"
+    echo ">>> Step 1: Done. Calibration saved to ${CALIBRATION_DIR}"
+fi
+
+# Pick up lut.npy / quantiles.npy if calibration produced them.
+# They are looked for next to whichever histogram file is in use.
+CALIB_DIR="$(dirname "${REAL_HIST_PATH}")"
+LUT_PATH=""
+Q_PATH=""
+[ -f "${CALIB_DIR}/lut.npy" ] && LUT_PATH="${CALIB_DIR}/lut.npy"
+[ -f "${CALIB_DIR}/quantiles.npy" ] && Q_PATH="${CALIB_DIR}/quantiles.npy"
+
+# ── Step 2: Autotune (latency-ranked) ───────────────────────────
+echo ""
+echo ">>> Step 2: Auto-tuning hyperparameters by fused-topk kernel latency"
+AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json"
+AUTOTUNE_EXTRA=()
+[ -n "${LUT_PATH}" ] && AUTOTUNE_EXTRA+=(--lut-path "${LUT_PATH}")
+[ -n "${Q_PATH}" ] && AUTOTUNE_EXTRA+=(--quantiles-path "${Q_PATH}")
+# The ${arr[@]+"${arr[@]}"} form expands to nothing for an empty array.
+# A plain "${arr[@]}" aborts with "unbound variable" under `set -u` on
+# bash < 4.4, which is exactly the case when no lut/quantiles file exists.
+PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \
+    --topk-val "${TOPK_VAL}" \
+    --batch-size "${BATCH_SIZE}" \
+    --num-kv-heads "${NUM_KV_HEADS}" \
+    --seq-len "${SEQ_LEN}" \
+    --page-size "${BLOCK_SIZE}" \
+    --real-histograms "${REAL_HIST_PATH}" \
+    --warmup "${WARMUP}" \
+    --repeat "${REPEAT}" \
+    --collect-stats \
+    ${AUTOTUNE_EXTRA[@]+"${AUTOTUNE_EXTRA[@]}"} \
+    --output-json "${AUTOTUNE_JSON}" \
+    2>&1 | tee "${RUN_DIR}/step2_autotune.log"
+echo ">>> Step 2: Done. Autotune results saved to ${AUTOTUNE_JSON}"
+
+# ── Step 3: Remap bench with autotuned hparams ──────────────────
+echo ""
+echo ">>> Step 3: Remap benchmark (baseline / fused / remap / split) with autotuned hparams"
+BENCH_JSON="${RUN_DIR}/remap_bench.json"
+BENCH_EXTRA=()
+[ -n "${LUT_PATH}" ] && BENCH_EXTRA+=(--lut-path "${LUT_PATH}")
+[ -n "${Q_PATH}" ] && BENCH_EXTRA+=(--quantiles-path "${Q_PATH}")
+# DISTRIBUTIONS / MAPPING_MODES are space-separated lists, intentionally
+# unquoted so each entry becomes its own argument.
+PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \
+    --remap-bench \
+    --batch-sizes "${BATCH_SIZE}" \
+    --num-kv-heads "${NUM_KV_HEADS}" \
+    --seq-lens "${SEQ_LEN}" \
+    --topk-vals "${TOPK_VAL}" \
+    --page-size "${BLOCK_SIZE}" \
+    --distributions ${DISTRIBUTIONS} \
+    --mapping-modes ${MAPPING_MODES} \
+    --autotune-json "${AUTOTUNE_JSON}" \
+    ${BENCH_EXTRA[@]+"${BENCH_EXTRA[@]}"} \
+    --warmup "${WARMUP}" \
+    --repeat "${REPEAT}" \
+    --output-json "${BENCH_JSON}" \
+    2>&1 | tee "${RUN_DIR}/step3_bench.log"
+echo ">>> Step 3: Done. Remap bench saved to ${BENCH_JSON}"
+
+# ── Summary ─────────────────────────────────────────────────────
+echo ""
+echo "============================================================"
+echo "Bucket Distribution / Remap Latency Pipeline Complete"
+echo " Model: ${MODEL_NAME}"
+echo " Block size: ${BLOCK_SIZE}"
+echo " All outputs in: ${RUN_DIR}/"
+# With --real-histograms the file is not under RUN_DIR, so print its real path.
+echo " raw histograms: ${REAL_HIST_PATH} — real topk distribution"
+echo " autotune_results.json — latency-ranked hparams"
+echo " remap_bench.json — remap/topk/fused/baseline latencies"
+echo " step{1,2,3}_*.log — pipeline logs"
+echo "============================================================"
diff --git a/examples/run_topk_benchmark.sh b/examples/run_topk_benchmark.sh
new file mode 100755
index 00000000..f3eabff9
--- /dev/null
+++ b/examples/run_topk_benchmark.sh
@@ -0,0 +1,253 @@
+#!/usr/bin/env bash
+# ============================================================
+# Unified TopK Benchmark
+#
+# Three-step pipeline on a single configurable model:
+#   Step 1: Calibrate — run the model to collect
+#                       real-distribution histograms
+#                       (raw_histograms.npy, lut.npy,
+#                       quantiles.npy).
+#   Step 2: Latency autotune + bench — rank per-mode hparams by
+#                       measured fused-topk kernel
+#                       latency, then run the
+#                       remap / topk / fused / baseline
+#                       comparison.
+#   Step 3: E2E accuracy — verify_algo.py on the same
+#                       model for the unmapped baseline
+#                       plus each mapping mode, with
+#                       autotuned hparams.
+#
+# Usage:
+#   bash run_topk_benchmark.sh --gpu 0
+#   bash run_topk_benchmark.sh --gpu 0 --model-name Qwen/Qwen3-8B \
+#       --block-size 32 --topk-val 512
+#   bash run_topk_benchmark.sh --gpu 0 --max-total-tokens 1048576
+#   bash run_topk_benchmark.sh --gpu 0 --skip-calibrate \
+#       --calibration-dir /path/to/previous/run/calibration
+#   bash run_topk_benchmark.sh --gpu 0 --run-e2e
+# ============================================================
+set -euo pipefail
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+BENCH_DIR="${SCRIPT_DIR}/../benchmarks"
+
+# ── Defaults ──────────────────────────────────────────────────
+GPU_ID=4
+MODEL_NAME="Qwen/Qwen3-1.7B"
+TOPK_VAL=30
+TRIALS=8
+MEM=0.7
+MAX_TOTAL_TOKENS=1048576
+ALGO="block_sparse_attention"
+BLOCK_SIZE=16
+BATCH_SIZE=4
+NUM_KV_HEADS=8
+SEQ_LEN=32768
+BENCHMARKS="amc23"
+SKIP_CALIBRATE=false
+SKIP_KERNEL=false
+SKIP_E2E=true
+# Optional existing calibration directory (see --calibration-dir).
+CALIBRATION_DIR_OVERRIDE=""
+
+# ── Parse arguments ───────────────────────────────────────────
+while [[ $# -gt 0 ]]; do
+    case "$1" in
+        --model-name) MODEL_NAME="$2"; shift 2 ;;
+        --topk-val) TOPK_VAL="$2"; shift 2 ;;
+        --trials) TRIALS="$2"; shift 2 ;;
+        --mem) MEM="$2"; shift 2 ;;
+        --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;;
+        --gpu) GPU_ID="$2"; shift 2 ;;
+        --algo) ALGO="$2"; shift 2 ;;
+        --benchmark) BENCHMARKS="$2"; shift 2 ;;
+        --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;;
+        --batch-size) BATCH_SIZE="$2"; shift 2 ;;
+        --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;;
+        --seq-len) SEQ_LEN="$2"; shift 2 ;;
+        --calibration-dir) CALIBRATION_DIR_OVERRIDE="$2"; shift 2 ;;
+        --skip-calibrate) SKIP_CALIBRATE=true; shift ;;
+        --skip-kernel) SKIP_KERNEL=true; shift ;;
+        # E2E is off by default. --run-e2e turns it on; the legacy spelling
+        # --skip-e2e is kept with its historical (inverted) meaning so
+        # existing invocations keep working.
+        --run-e2e|--skip-e2e) SKIP_E2E=false; shift ;;
+        *) echo "Unknown option: $1" >&2; exit 1 ;;
+    esac
+done
+
+export CUDA_VISIBLE_DEVICES="${GPU_ID}"
+
+RESULTS_DIR="${SCRIPT_DIR}/results"
+mkdir -p "${RESULTS_DIR}"
+TIMESTAMP=$(date +%Y%m%d_%H%M%S)
+BENCH_LABEL=$(echo "${BENCHMARKS}" | tr ' ' '_')
+MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')"
+RUN_DIR="${RESULTS_DIR}/topk_benchmark_${MODEL_SLUG}_${BENCH_LABEL}_${TIMESTAMP}"
+mkdir -p "${RUN_DIR}"
+
+echo "============================================================"
+echo "Unified TopK Benchmark"
+echo " Model: ${MODEL_NAME}"
+echo " Algorithm: ${ALGO}"
+echo " TopK: ${TOPK_VAL}"
+echo " Block size: ${BLOCK_SIZE}"
+echo " Seq len: ${SEQ_LEN}"
+echo " Batch size: ${BATCH_SIZE}"
+echo " KV heads: ${NUM_KV_HEADS}"
+echo " Trials: ${TRIALS}"
+echo " Max total tokens: ${MAX_TOTAL_TOKENS} (calibration KV / VTX buffer cap)"
+echo " GPU: ${GPU_ID}"
+echo " Output: ${RUN_DIR}"
+echo "============================================================"
+
+# ── Step 1: Calibrate ────────────────────────────────────────
+# RUN_DIR is freshly timestamped, so "${RUN_DIR}/calibration" can never exist
+# yet; without --calibration-dir the --skip-calibrate flag could never take
+# effect. When the directory to reuse is missing, calibration still runs.
+CALIBRATION_DIR="${CALIBRATION_DIR_OVERRIDE:-${RUN_DIR}/calibration}"
+if [ "${SKIP_CALIBRATE}" = true ] && [ -d "${CALIBRATION_DIR}" ]; then
+    echo ""
+    echo ">>> Step 1: SKIPPED (--skip-calibrate)"
+else
+    echo ""
+    echo ">>> Step 1: Calibrating ${MODEL_NAME} — real topk histograms + LUT/quantiles"
+    mkdir -p "${CALIBRATION_DIR}"
+    python "${BENCH_DIR}/calibrate_topk.py" \
+        --model-name "${MODEL_NAME}" \
+        --topk-val "${TOPK_VAL}" \
+        --mem "${MEM}" \
+        --max-total-tokens "${MAX_TOTAL_TOKENS}" \
+        --vortex-module-name "${ALGO}" \
+        --page-size "${BLOCK_SIZE}" \
+        --output-dir "${CALIBRATION_DIR}" \
+        2>&1 | tee "${RUN_DIR}/step1_calibrate.log"
+    echo ">>> Step 1: Done."
+fi
+
+# Calibration artifacts are optional inputs for the later steps.
+REAL_HIST_PATH="${CALIBRATION_DIR}/raw_histograms.npy"
+LUT_PATH=""
+Q_PATH=""
+[ -f "${CALIBRATION_DIR}/lut.npy" ] && LUT_PATH="${CALIBRATION_DIR}/lut.npy"
+[ -f "${CALIBRATION_DIR}/quantiles.npy" ] && Q_PATH="${CALIBRATION_DIR}/quantiles.npy"
+[ -n "${LUT_PATH}" ] && echo " Calibration LUT: ${LUT_PATH}"
+[ -n "${Q_PATH}" ] && echo " Calibration quantile: ${Q_PATH}"
+
+# ── Step 2: Latency autotune + remap bench ───────────────────
+AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json"
+if [ "${SKIP_KERNEL}" = true ]; then
+    echo ""
+    echo ">>> Step 2: SKIPPED (--skip-kernel)"
+else
+    echo ""
+    echo ">>> Step 2a: Auto-tuning per-mode hparams by fused-topk kernel latency"
+    AUTOTUNE_EXTRA=()
+    [ -f "${REAL_HIST_PATH}" ] && AUTOTUNE_EXTRA+=(--real-histograms "${REAL_HIST_PATH}")
+    [ -n "${LUT_PATH}" ] && AUTOTUNE_EXTRA+=(--lut-path "${LUT_PATH}")
+    [ -n "${Q_PATH}" ] && AUTOTUNE_EXTRA+=(--quantiles-path "${Q_PATH}")
+    # ${arr[@]+"${arr[@]}"} expands to nothing for an empty array; a plain
+    # "${arr[@]}" aborts with "unbound variable" under `set -u` on bash < 4.4.
+    PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \
+        --topk-val "${TOPK_VAL}" \
+        --batch-size "${BATCH_SIZE}" \
+        --num-kv-heads "${NUM_KV_HEADS}" \
+        --seq-len "${SEQ_LEN}" \
+        --page-size "${BLOCK_SIZE}" \
+        --warmup 20 --repeat 100 \
+        --collect-stats \
+        ${AUTOTUNE_EXTRA[@]+"${AUTOTUNE_EXTRA[@]}"} \
+        --output-json "${AUTOTUNE_JSON}" \
+        2>&1 | tee "${RUN_DIR}/step2a_autotune.log"
+    echo ">>> Step 2a: Done. Autotune saved to ${AUTOTUNE_JSON}"
+
+    echo ""
+    echo ">>> Step 2b: Remap benchmark (baseline / fused / remap / split) with autotuned hparams"
+    BENCH_JSON="${RUN_DIR}/kernel_latency.json"
+    BENCH_EXTRA=()
+    [ -n "${LUT_PATH}" ] && BENCH_EXTRA+=(--lut-path "${LUT_PATH}")
+    [ -n "${Q_PATH}" ] && BENCH_EXTRA+=(--quantiles-path "${Q_PATH}")
+    PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \
+        --remap-bench \
+        --batch-sizes "${BATCH_SIZE}" \
+        --num-kv-heads "${NUM_KV_HEADS}" \
+        --seq-lens "${SEQ_LEN}" \
+        --topk-vals "${TOPK_VAL}" \
+        --page-size "${BLOCK_SIZE}" \
+        --distributions normal bucket_uniform \
+        --mapping-modes 0 1 2 3 6 7 8 9 10 11 13 \
+        --autotune-json "${AUTOTUNE_JSON}" \
+        ${BENCH_EXTRA[@]+"${BENCH_EXTRA[@]}"} \
+        --warmup 20 --repeat 100 \
+        --output-json "${BENCH_JSON}" \
+        2>&1 | tee "${RUN_DIR}/step2b_kernel_bench.log"
+    echo ">>> Step 2b: Done. Results saved to ${BENCH_JSON}"
+fi
+
+# ── Step 3: E2E accuracy ─────────────────────────────────────
+if [ "${SKIP_E2E}" = true ]; then
+    echo ""
+    echo ">>> Step 3: SKIPPED (default). Pass --run-e2e (or the legacy --skip-e2e) to turn it ON."
+else
+    echo ""
+    echo ">>> Step 3: E2E accuracy comparison"
+
+    # Extract autotuned hparams per mode: for each mode keep the entry with
+    # the lowest latency_ms and emit BEST_HPARAM_<mode>=<param>. When Step 2
+    # was skipped the JSON does not exist; every mode then falls back to 0.5
+    # instead of crashing on the missing file.
+    eval "$(python3 -c "
+import json, os, sys
+path = sys.argv[1]
+data = []
+if os.path.isfile(path):
+    with open(path) as fh:
+        data = json.load(fh)
+best = {}
+for r in data:
+    m = r.get('mode'); lat = r.get('latency_ms')
+    if m is None or lat is None: continue
+    if m not in best or lat < best[m]['latency_ms']:
+        best[m] = r
+for m in (3, 6, 7, 9, 10, 11, 13):
+    print(f'BEST_HPARAM_{m}={best.get(m, {}).get(\"param\", 0.5)}')
+" "${AUTOTUNE_JSON}")"
+
+    E2E_DIR="${RUN_DIR}/e2e"
+    mkdir -p "${E2E_DIR}"
+
+    run_e2e() {
+        # $1=label, remaining args passed to verify_algo.py
+        # Output (including the `time` report) is mirrored to ${E2E_DIR}/<label>.log.
+        local label="$1"; shift
+        local logfile="${E2E_DIR}/${label}.log"
+        echo ""
+        echo " --- ${label} ---"
+        { time python "${SCRIPT_DIR}/verify_algo.py" \
+            --trials "${TRIALS}" \
+            --topk-val "${TOPK_VAL}" \
+            --model-name "${MODEL_NAME}" \
+            --benchmark ${BENCHMARKS} \
+            --mem "${MEM}" \
+            "$@" ; } \
+            2>&1 | tee "${logfile}"
+    }
+
+    run_mapped() {
+        # $1=mode $2=hparam $3=label
+        # Mode 0 runs the plain sglang top-k; any other mode runs the fused
+        # kernel with that mapping mode and hparam.
+        local mode="$1"; local hp="$2"; local label="$3"
+        local extra=(--vortex-module-name "${ALGO}")
+        if [ "${mode}" -eq 0 ]; then
+            extra+=(--topk-type sglang)
+        else
+            extra+=(--topk-type sglang_fused --topk-mapping-mode "${mode}" --topk-mapping-hparam "${hp}")
+        fi
+        run_e2e "${label}" "${extra[@]}"
+    }
+
+    run_e2e "full_attention_baseline" --full-attention
+    run_e2e "naive_topk" --vortex-module-name "${ALGO}" --topk-type naive
+    run_mapped 0 0.5 "sglang_m0_none"
+    run_mapped 3 "${BEST_HPARAM_3}" "sglang_m3_power_p${BEST_HPARAM_3}"
+    run_mapped 4 0.5 "sglang_m4_log"
+    run_mapped 6 "${BEST_HPARAM_6}" "sglang_m6_asinh_beta${BEST_HPARAM_6}"
+    run_mapped 7 "${BEST_HPARAM_7}" "sglang_m7_log1p_alpha${BEST_HPARAM_7}"
+    run_mapped 8 0.5 "sglang_m8_trunc8"
+    run_mapped 9 "${BEST_HPARAM_9}" "sglang_m9_erf_alpha${BEST_HPARAM_9}"
+    run_mapped 10 "${BEST_HPARAM_10}" "sglang_m10_tanh_alpha${BEST_HPARAM_10}"
+    run_mapped 11 "${BEST_HPARAM_11}" "sglang_m11_subtract_pivot${BEST_HPARAM_11}"
+    run_mapped 13 "${BEST_HPARAM_13}" "sglang_m13_expstretch_alpha${BEST_HPARAM_13}"
+
+    echo ""
+    echo ">>> Step 3: Done. E2E logs saved to ${E2E_DIR}/"
+fi
+
+# ── Final Summary ─────────────────────────────────────────────
+echo ""
+echo "============================================================"
+echo "TopK Benchmark Complete"
+echo " Model: ${MODEL_NAME}"
+echo " Block size: ${BLOCK_SIZE}"
+echo " All results: ${RUN_DIR}"
+echo " Calibration: ${CALIBRATION_DIR}"
+[ "${SKIP_KERNEL}" != true ] && echo " Autotune: ${AUTOTUNE_JSON}"
+[ "${SKIP_KERNEL}" != true ] && echo " Kernel JSON: ${RUN_DIR}/kernel_latency.json"
+[ "${SKIP_E2E}" != true ] && echo " E2E logs: ${RUN_DIR}/e2e/"
+echo "============================================================"
diff --git a/examples/verify_aim24.py b/examples/verify_aim24.py
new file mode 100644
index 00000000..9e54a967
--- /dev/null
+++ b/examples/verify_aim24.py
@@ -0,0 +1,106 @@
+import json
+import sys
+sys.path.append("../")
+import python.sglang as sgl
+from transformers import AutoTokenizer
+import os
+from tqdm import tqdm
+import time
+import torch
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+MATH_QUERY_TEMPLATE = """
+Solve the following math problem efficiently and clearly. The last line of your response should be of the following format: 'Therefore, the final answer is: $\\boxed{{ANSWER}}$. I hope it is correct' (without quotes) where ANSWER is just the final number or expression that solves the problem. Think step by step before answering.
+
+{Question}
+""".strip()
+
+from datasets import load_dataset, Dataset, concatenate_datasets
+def generate_requests(dataset: Dataset, field_name: str, data_format: str, trial: int = 1, rank: int = 0, world_size: int = 1):
+    """Build chat-style requests for this rank's share of the dataset.
+
+    The dataset is repeated ``trial`` times, split into ``world_size``
+    contiguous shards (the first ``total % world_size`` shards get one extra
+    row), and every row of shard ``rank`` gets a "conversations" entry holding
+    ``data_format`` filled in with the row's ``field_name`` value.
+    """
+    requests = []
+
+    # Step 1: Expand dataset trial times
+    if trial > 1:
+        # The spaces before this inline comment were non-breaking (U+00A0),
+        # which Python rejects outside strings and comments.
+        dataset = Dataset.from_dict(dataset.to_dict().copy())  # ensure copy
+        datasets = [dataset] * trial
+        dataset = concatenate_datasets(datasets)
+
+    total = len(dataset)
+
+    # Step 2: Partition across ranks
+    per_proc = total // world_size
+    remainder = total % world_size
+    start = rank * per_proc + min(rank, remainder)
+    end = start + per_proc + (1 if rank < remainder else 0)
+    subset = dataset.select(list(range(start, end)))
+
+    # Step 3: Format requests
+    # Iterate this rank's shard; looping over the full dataset ignored the
+    # partition and made every rank process every row.
+    for data in subset:
+        conversations = [
+            {"role": "user", "content": data_format.format(Question=data[field_name])}
+        ]
+        data["conversations"] = conversations
+        requests.append(data)
+
+    return requests
+
+
+def main():
+    """Generate AIME-2024 completions with Vortex sparsity and dump them as JSONL."""
+    model_name = "Qwen/Qwen3-0.6B"
+    llm = sgl.Engine(model_path=model_name,
+                     disable_cuda_graph=False,
+                     page_size=16,
+                     vortex_num_selected_pages=29,
+                     disable_overlap_schedule=True,
+                     attention_backend="flashinfer",
+                     enable_vortex_sparsity=True,
+                     vortex_page_reserved_bos=1,
+                     vortex_page_reserved_eos=2,
+                     vortex_layers_skip=list(range(1)),
+                     mem_fraction_static=0.9,
+                     vortex_cg=True,
+                     vortex_graph=True,
+                     vortex_module_name="block_sparse_attention",
+                     vortex_max_seq_lens=20480
+                     )
+
+    dataset = load_dataset("HuggingFaceH4/aime_2024", split="train")
+
+    requests = generate_requests(dataset, "problem", MATH_QUERY_TEMPLATE)
+
+    texts = [
+        x["conversations"] for x in requests
+    ]
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    # Each prompt is submitted 8 times.
+    prompts = [
+        tokenizer.apply_chat_template(
+            text,
+            tokenize=False,
+            add_generation_prompt=True,
+            enable_thinking=True
+        ) for text in texts
+    ] * 8
+
+    sampling_params = {"temperature": 0.6, "top_p": 0.95, "top_k": 20, "max_new_tokens": 16384}
+    total_tokens = 0
+    total_time = 0.0
+    start = time.perf_counter()
+    o = llm.generate(prompts, sampling_params)
+    elapsed = time.perf_counter() - start
+    total_time += elapsed
+    e2e_time = 0
+    with open(f"0.6B_VTX_CG_TP1_16K.jsonl", "w", encoding="utf-8") as f:
+        for item in o:
+            total_tokens += item["meta_info"]["completion_tokens"]
+            e2e_time = max(e2e_time, item["meta_info"]["e2e_latency"])
+            json.dump(item, f, ensure_ascii=False)
+            f.write("\n")
+
+        # The run-level summary is the last line of the same file, so it has
+        # to be written while the file is still open.
+        meta_data = {"e2e_time": e2e_time, "total_time": total_time, "total_tokens": total_tokens, "throughput": total_tokens / total_time}
+        json.dump(meta_data, f, ensure_ascii=False)
+        f.write("\n")
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/examples/verify_algo.py b/examples/verify_algo.py index e290a81b..dacba655 100644 --- a/examples/verify_algo.py +++ b/examples/verify_algo.py @@ -11,7 +11,11 @@ from lighteval.models.model_output import ModelResponse from datasets import load_dataset, Dataset, concatenate_datasets import argparse +import ast import json +import os +import subprocess +import sys MATH_QUERY_TEMPLATE = """ Solve the following math problem efficiently and clearly. The last line of your response should be of the following format: 'Therefore, the final answer is: $\\boxed{{ANSWER}}$. I hope it is correct' (without quotes) where ANSWER is just the final number or expression that solves the problem. Think step by step before answering.
@@ -47,6 +51,63 @@ def generate_requests(dataset: Dataset, field_name: str, data_format: str, trial return requests +BENCHMARK_REGISTRY = { + "amc23": { + "type": "jsonl", + "path": "amc23.jsonl", + "prompt_key": "prompt", + "answer_key": "answer", + "question_key": "question", + }, + "aime24": { + "type": "huggingface", + "path": "HuggingFaceH4/aime_2024", + "split": "train", + "field_name": "problem", + "answer_key": "answer", + }, +} + +def _load_benchmark(benchmark_name: str, trials: int, tokenizer=None): + """Load benchmark data and return (prompts, requests) tuple.""" + cfg = BENCHMARK_REGISTRY[benchmark_name] + + if cfg["type"] == "jsonl": + script_dir = os.path.dirname(os.path.abspath(__file__)) + jsonl_path = os.path.join(script_dir, cfg["path"]) + with open(jsonl_path, "r", encoding="utf-8") as f: + requests = [json.loads(line) for line in f] + requests = requests * trials + prompts = [req[cfg["prompt_key"]] for req in requests] + return prompts, requests + + elif cfg["type"] == "huggingface": + dataset = load_dataset(cfg["path"], split=cfg["split"]) + hf_requests = generate_requests(dataset, cfg["field_name"], MATH_QUERY_TEMPLATE) + # Normalize keys: ensure "question" and "answer" exist + for req in hf_requests: + if "question" not in req and cfg["field_name"] in req: + req["question"] = req[cfg["field_name"]] + # Build chat-template prompts if tokenizer is provided + if tokenizer is not None: + texts = [x["conversations"] for x in hf_requests] + prompts = [ + tokenizer.apply_chat_template( + text, tokenize=False, add_generation_prompt=True, enable_thinking=True + ) for text in texts + ] * trials + hf_requests = hf_requests * trials + else: + prompts = [ + MATH_QUERY_TEMPLATE.format(Question=x[cfg["field_name"]]) for x in hf_requests + ] * trials + hf_requests = hf_requests * trials + return prompts, hf_requests + + else: + raise ValueError(f"Unknown benchmark type: {cfg['type']}") + + def verify_algos( trials: int = 2, topk_val: int = 30, @@ -54,13 
+115,19 @@ def verify_algos( vortex_module_name: str = "gqa_block_sparse_attention", model_name: str = "Qwen/Qwen3-1.7B", sparse_attention: bool = True, -mem: float = 0.8 -): +mem: float = 0.8, +kv_cache_dtype: str = "auto", +topk_type: str = "naive", +topk_mapping_mode: int = 0, +topk_mapping_hparam: float = 0.5, +disable_cuda_graph: bool = False, +benchmark: str = "amc23", +): - llm = sgl.Engine(model_path=model_name, - disable_cuda_graph=False, + llm = sgl.Engine(model_path=model_name, + disable_cuda_graph=disable_cuda_graph, page_size=page_size, - vortex_topk_val=topk_val, + vortex_topk_val=topk_val, disable_overlap_schedule=True, attention_backend="flashinfer", enable_vortex_sparsity=sparse_attention, @@ -69,17 +136,17 @@ def verify_algos( vortex_layers_skip=list(range(1)), vortex_module_name=vortex_module_name, vortex_max_seq_lens=12288, - mem_fraction_static=mem + mem_fraction_static=mem, + kv_cache_dtype=kv_cache_dtype, + vortex_topk_type=topk_type, + vortex_topk_mapping_mode=topk_mapping_mode, + vortex_topk_mapping_hparam=topk_mapping_hparam, ) - - with open("examples/amc23.jsonl", "r", encoding="utf-8") as f: - requests = [json.loads(line) for line in f] - - requests = requests * trials - prompts = [req["prompt"] for req in requests] + tokenizer = AutoTokenizer.from_pretrained(model_name) if benchmark != "amc23" else None + prompts, requests = _load_benchmark(benchmark, trials, tokenizer=tokenizer) sampling_params = {"temperature": 0.6, "top_p": 0.95, "top_k": 20, "max_new_tokens": 8192} - + o = llm.generate(prompts, sampling_params) gold_metric = MultilingualExtractiveMatchMetric( language=Language.ENGLISH, @@ -89,7 +156,7 @@ def verify_algos( pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig(boxed_match_priority=0)), aggregation_function=max, ) - + results = [] for data, item in zip(requests, o): golds = [data["answer"]] @@ -99,7 +166,7 @@ def verify_algos( result = 
gold_metric.compute(model_response=ModelResponse(text=[predictions]), doc=target) except: result = 0.0 - + results.append( { "score": float(result), @@ -110,7 +177,15 @@ def verify_algos( "num_tokens": item["meta_info"]["completion_tokens"] } ) - + # --- Per-question debug output --- + # print(f"[Q{len(results):03d}] score={float(result):.1f} " + # f"tokens={item['meta_info']['completion_tokens']} " + # f"latency={item['meta_info']['e2e_latency']:.2f}s " + # f"gold={golds[0]}") + # print(f" question: {data['question'][:120]}...") + # print(f" prediction: {predictions[:200]}...") + # print() + total_accuracy = 0.0 total_tokens = 0 @@ -130,12 +205,17 @@ def verify_algos( if sparse_attention: llm_cfg = AutoConfig.from_pretrained(model_name) - flow = vortex_torch.flow.build_vflow(vortex_module_name) - memory_access_runtime = flow.run_indexer_virtual( - group_size=llm_cfg.num_attention_heads // llm_cfg.num_key_value_heads, - page_size=page_size, - head_dim=llm_cfg.head_dim, - ) + flow = vortex_torch.flow.build_vflow(vortex_module_name) + try: + memory_access_runtime = flow.run_indexer_virtual( + group_size=llm_cfg.num_attention_heads // llm_cfg.num_key_value_heads, + page_size=page_size, + head_dim=llm_cfg.head_dim, + ) + except Exception: + # External algorithms (nsa, fsa, flash_moba) override run_indexer_virtual + # to return 0 since their vendored kernels don't participate in vortex profiling + memory_access_runtime = 0.0 else: memory_access_runtime = 0.0 @@ -203,20 +283,101 @@ def parse_args(): default=0.8, help="memory fraction in sglang", ) + + parser.add_argument( + "--kv-cache-dtype", + type=str, + default="auto", + choices=["auto", "fp8_e5m2", "fp8_e4m3", "int8"], + help='KV cache dtype (default: "auto").', + ) + + parser.add_argument( + "--topk-type", + type=str, + default="naive", + choices=["naive", "sglang", "sglang_fused"], + help='TopK kernel type: "naive" (CUB radix), "sglang" (unmapped baseline), "sglang_fused" (fused remap + topk). 
Default: "naive".', + ) + parser.add_argument( + "--topk-mapping-mode", + type=int, + default=0, + choices=[0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 13, 15, 16, 17, 18, 19, 20], + help='TopK mapping mode for sglang_fused: 0=none, 1=lut_cdf (calibrated), ' + '2=quantile (calibrated), 3=power, 4=log, 6=asinh, 7=log1p, 8=trunc8, ' + '9=erf, 10=tanh, 11=subtract, 13=exp_stretch, 15=shift_pow2, ' + '16=shift_pow3, 17=linear_steep, 18=half_square, 19=half_cube, ' + '20=dense_mant (default: 0).', + ) + + parser.add_argument( + "--topk-mapping-hparam", "--topk-mapping-power", + type=float, + default=0.5, + dest="topk_mapping_hparam", + help='Hyperparameter for parametric modes: power exponent (mode 3), beta (mode 6), alpha (mode 7/9/10/13), rho (mode 12/14). Default: 0.5.', + ) + + parser.add_argument( + "--benchmark", + type=str, + nargs="+", + default=["amc23"], + help="Benchmark(s) to run. Available: amc23, aime24. " + "Use multiple values to run several benchmarks sequentially (default: amc23).", + ) + + parser.add_argument( + "--output-json", + type=str, + default=None, + help="Optional path. When set, a JSON list of per-benchmark summary dicts is " + "dumped here after all benchmarks finish. Used by the ablation wrappers.", + ) + return parser.parse_args() if __name__ == "__main__": args = parse_args() - summary = verify_algos( - trials=args.trials, - topk_val=args.topk_val, - page_size=args.page_size, - vortex_module_name=args.vortex_module_name, - model_name=args.model_name, - sparse_attention=not(args.full_attention), - mem=args.mem - ) - print(summary) + all_summaries = [] + for bench_name in args.benchmark: + if bench_name not in BENCHMARK_REGISTRY: + print(f"WARNING: Unknown benchmark '{bench_name}', skipping. 
Available: {list(BENCHMARK_REGISTRY.keys())}") + continue + print(f"\n{'='*60}") + print(f"Benchmark: {bench_name}") + print(f"{'='*60}") + summary = verify_algos( + trials=args.trials, + topk_val=args.topk_val, + page_size=args.page_size, + vortex_module_name=args.vortex_module_name, + model_name=args.model_name, + sparse_attention=not(args.full_attention), + mem=args.mem, + kv_cache_dtype=args.kv_cache_dtype, + topk_type=args.topk_type, + topk_mapping_mode=args.topk_mapping_mode, + topk_mapping_hparam=args.topk_mapping_hparam, + benchmark=bench_name, + ) + summary["benchmark"] = bench_name + summary["model_name"] = args.model_name + summary["topk_val"] = args.topk_val + summary["page_size"] = args.page_size + summary["topk_type"] = args.topk_type + summary["topk_mapping_mode"] = args.topk_mapping_mode + summary["topk_mapping_hparam"] = args.topk_mapping_hparam + summary["full_attention"] = bool(args.full_attention) + print(summary) + all_summaries.append(summary) + + if args.output_json: + os.makedirs(os.path.dirname(os.path.abspath(args.output_json)) or ".", exist_ok=True) + with open(args.output_json, "w") as f: + json.dump(all_summaries, f, indent=2) + print(f"\n[verify_algo] summary JSON written to {args.output_json}") exit(0) \ No newline at end of file diff --git a/examples/verify_algo.sh b/examples/verify_algo.sh index 17c2a5ed..7a96d1e7 100644 --- a/examples/verify_algo.sh +++ b/examples/verify_algo.sh @@ -1,17 +1,26 @@ #!/usr/bin/env bash set -e +# use CUDA_VISIBLE_DEVICES to set the GPU id you want to use +export CUDA_VISIBLE_DEVICES=5 sparse_algos=( - + "block_sparse_attention" ) -for algo in "${sparse_algos[@]}"; do - echo ">>> Running verify_algo.py with --vortex-module-name ${algo}" - python examples/verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --mem 0.7 -done +RESULTS_DIR="results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + for algo in 
"${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/${algo}_bf16_${TIMESTAMP}.log" + echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --kv-cache-dtype bf16" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type naive \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done \ No newline at end of file diff --git a/examples/verify_algo_quant.sh b/examples/verify_algo_quant.sh new file mode 100644 index 00000000..a2663e97 --- /dev/null +++ b/examples/verify_algo_quant.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash +set -e +# use CUDA_VISIBLE_DEVICES to set the GPU id you want to use +export CUDA_VISIBLE_DEVICES=5 + +sparse_algos=( + "block_sparse_attention" +) + +RESULTS_DIR="results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + + for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/${algo}_int8_${TIMESTAMP}.log" + echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --kv-cache-dtype int8" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --kv-cache-dtype int8 \ + --topk-type naive \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done \ No newline at end of file diff --git a/examples/verify_algo_topk.sh b/examples/verify_algo_topk.sh new file mode 100644 index 00000000..6b2744ae --- /dev/null +++ b/examples/verify_algo_topk.sh @@ -0,0 +1,44 @@ +#!/usr/bin/env bash +set -e +export CUDA_VISIBLE_DEVICES=5 + +sparse_algos=( + "block_sparse_attention" +) + +RESULTS_DIR="results" +REPEAT_COUNT="${REPEAT_COUNT:-3}" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + +for repeat_idx in $(seq 1 "${REPEAT_COUNT}"); do + for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/${algo}_naive_${TIMESTAMP}_run${repeat_idx}.log" + echo ">>> Run 
${repeat_idx}/${REPEAT_COUNT}: verify_algo.py with --vortex-module-name ${algo} --topk-type naive" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type naive \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done +done + +for repeat_idx in $(seq 1 "${REPEAT_COUNT}"); do + for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/${algo}_sglang_${TIMESTAMP}_run${repeat_idx}.log" + echo ">>> Run ${repeat_idx}/${REPEAT_COUNT}: verify_algo.py with --vortex-module-name ${algo} --topk-type sglang" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done +done \ No newline at end of file diff --git a/examples/verify_algo_topk_mapping_new.sh b/examples/verify_algo_topk_mapping_new.sh new file mode 100644 index 00000000..2cdc5265 --- /dev/null +++ b/examples/verify_algo_topk_mapping_new.sh @@ -0,0 +1,221 @@ +#!/usr/bin/env bash +# ============================================================ +# E2E accuracy sweep over the surviving parametric mapping modes. +# Each mode runs verify_algo.py with the per-mode hyperparameter +# that autotune_topk_mapping.py picked as having the lowest +# measured fused-topk-kernel latency. 
+# +# Mapping modes (after the lean refactor): +# 0: None — unmapped baseline (no remap) +# 3: Power — y = sign(x) * |x|^p +# 4: Log — y = sign(x) * log(|x| + 1) [no knob] +# 6: Asinh — y = asinh(beta * x) +# 7: Log1p — y = sign(x) * log1p(alpha * |x|) +# 9: Erf — y = erf(alpha * x) +# 10: Tanh — y = tanh(alpha * x) +# 13: ExpStretch — y = exp(alpha * x) +# ============================================================ +set -e +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +# ── Defaults ────────────────────────────────────────────────── +GPU_ID=5 +TOPK_VAL=30 +BENCHMARKS="amc23" +MODEL_NAME="Qwen/Qwen3-1.7B" +BLOCK_SIZE=16 +BATCH_SIZE=4 +NUM_KV_HEADS=2 +SEQ_LEN=32768 +MAX_TOTAL_TOKENS=1048576 +REAL_HISTOGRAMS="" +SKIP_AUTOTUNE=0 + +while [[ $# -gt 0 ]]; do + case "$1" in + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --gpu) GPU_ID="$2"; shift 2 ;; + --benchmark) BENCHMARKS="$2"; shift 2 ;; + --model-name) MODEL_NAME="$2"; shift 2 ;; + --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;; + --batch-size) BATCH_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; + --skip-autotune) SKIP_AUTOTUNE=1; shift 1 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" + +sparse_algos=( "block_sparse_attention" ) + +BENCH_LABEL=$(echo "${BENCHMARKS}" | tr ' ' '_') +MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" +RESULTS_DIR="results/topk_mapping_${MODEL_SLUG}_topk${TOPK_VAL}_${BENCH_LABEL}" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + +# ============================================================ +# Step 0: Calibrate (optional) — real-distribution histograms +# ============================================================ +if [ -z "${REAL_HISTOGRAMS}" ]; then + echo 
"============================================================" + echo "Step 0: Calibrating ${MODEL_NAME} for real-distribution histograms" + echo " Max total tokens (KV / VTX cap): ${MAX_TOTAL_TOKENS}" + echo "============================================================" + CAL_DIR="${RESULTS_DIR}/calibration_${TIMESTAMP}" + mkdir -p "${CAL_DIR}" + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --mem 0.7 \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ + --vortex-module-name "${sparse_algos[0]}" \ + --page-size "${BLOCK_SIZE}" \ + --output-dir "${CAL_DIR}" \ + 2>&1 | tee "${RESULTS_DIR}/calibrate_${TIMESTAMP}.log" + REAL_HISTOGRAMS="${CAL_DIR}/raw_histograms.npy" +fi + +# Pick up lut.npy / quantiles.npy if calibration produced them. +CALIB_DIR="$(dirname "${REAL_HISTOGRAMS}")" +LUT_PATH="" +Q_PATH="" +[ -f "${CALIB_DIR}/lut.npy" ] && LUT_PATH="${CALIB_DIR}/lut.npy" +[ -f "${CALIB_DIR}/quantiles.npy" ] && Q_PATH="${CALIB_DIR}/quantiles.npy" + +# ============================================================ +# Step 1: Auto-tune — rank by profiled fused-topk kernel latency +# ============================================================ +AUTOTUNE_JSON="${RESULTS_DIR}/autotune_${TIMESTAMP}.json" +if [ "${SKIP_AUTOTUNE}" -eq 0 ]; then + echo "============================================================" + echo "Step 1: Auto-tuning hyperparameters by fused-topk kernel latency" + echo "============================================================" + AUTOTUNE_EXTRA=() + [ -n "${LUT_PATH}" ] && AUTOTUNE_EXTRA+=(--lut-path "${LUT_PATH}") + [ -n "${Q_PATH}" ] && AUTOTUNE_EXTRA+=(--quantiles-path "${Q_PATH}") + PYTHONPATH="${SCRIPT_DIR}/.." 
python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --topk-val ${TOPK_VAL} \ + --batch-size ${BATCH_SIZE} \ + --seq-len ${SEQ_LEN} \ + --num-kv-heads ${NUM_KV_HEADS} \ + --page-size ${BLOCK_SIZE} \ + --real-histograms "${REAL_HISTOGRAMS}" \ + "${AUTOTUNE_EXTRA[@]}" \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RESULTS_DIR}/autotune_${TIMESTAMP}.log" + echo ">>> Auto-tune results saved to ${AUTOTUNE_JSON}" +fi + +# Extract best per-mode hparam (ranked by measured kernel latency, lowest wins) +eval "$(python3 -c " +import json, sys +data = json.load(open(sys.argv[1])) +best = {} +for r in data: + m = r.get('mode') + lat = r.get('latency_ms') + if m is None or lat is None: continue + if m not in best or lat < best[m]['latency_ms']: + best[m] = r +for m in (3, 6, 7, 9, 10, 11, 13): + v = best.get(m, {}).get('param', 0.5) + print(f'BEST_HPARAM_{m}={v}') +" "${AUTOTUNE_JSON}")" +echo ">>> Autotuned hparams (lowest topk kernel latency):" +echo " mode3=${BEST_HPARAM_3} mode6=${BEST_HPARAM_6} mode7=${BEST_HPARAM_7}" +echo " mode9=${BEST_HPARAM_9} mode10=${BEST_HPARAM_10} mode11=${BEST_HPARAM_11} mode13=${BEST_HPARAM_13}" +echo "" + +run_verify() { + # $1=mode $2=hparam $3=label + local mode="$1"; local hp="$2"; local label="$3" + for algo in "${sparse_algos[@]}"; do + local out="${RESULTS_DIR}/topk_mapping_${algo}_${label}_${TIMESTAMP}.log" + echo ">>> ${label} algo=${algo}" + local extra_args=() + if [ "${mode}" -eq 0 ]; then + extra_args+=(--topk-type sglang) + else + extra_args+=(--topk-type sglang_fused --topk-mapping-mode "${mode}" --topk-mapping-hparam "${hp}") + fi + { time python verify_algo.py \ + --trials 8 \ + --topk-val "${TOPK_VAL}" \ + --vortex-module-name "${algo}" \ + --model-name "${MODEL_NAME}" \ + --benchmark ${BENCHMARKS} \ + --mem 0.7 \ + "${extra_args[@]}" ; } \ + 2>&1 | tee "${out}" + done +} + +echo "============================================================" +echo "Baseline: sglang (no remap)" +echo 
"============================================================" +run_verify 0 0.5 "sglang_m0" + +echo "============================================================" +echo "Mode 3 (power) — p=${BEST_HPARAM_3} (autotuned)" +echo "============================================================" +run_verify 3 "${BEST_HPARAM_3}" "sglang_m3_p${BEST_HPARAM_3}" + +echo "============================================================" +echo "Mode 4 (log)" +echo "============================================================" +run_verify 4 0.5 "sglang_m4" + +echo "============================================================" +echo "Mode 6 (asinh) — beta=${BEST_HPARAM_6} (autotuned)" +echo "============================================================" +run_verify 6 "${BEST_HPARAM_6}" "sglang_m6_beta${BEST_HPARAM_6}" + +echo "============================================================" +echo "Mode 7 (log1p) — alpha=${BEST_HPARAM_7} (autotuned)" +echo "============================================================" +run_verify 7 "${BEST_HPARAM_7}" "sglang_m7_alpha${BEST_HPARAM_7}" + +echo "============================================================" +echo "Mode 9 (erf) — alpha=${BEST_HPARAM_9} (autotuned)" +echo "============================================================" +run_verify 9 "${BEST_HPARAM_9}" "sglang_m9_alpha${BEST_HPARAM_9}" + +echo "============================================================" +echo "Mode 10 (tanh) — alpha=${BEST_HPARAM_10} (autotuned)" +echo "============================================================" +run_verify 10 "${BEST_HPARAM_10}" "sglang_m10_alpha${BEST_HPARAM_10}" + +echo "============================================================" +echo "Mode 8 (trunc8)" +echo "============================================================" +run_verify 8 0.5 "sglang_m8" + +echo "============================================================" +echo "Mode 11 (subtract) — pivot=${BEST_HPARAM_11} (autotuned)" +echo 
"============================================================" +run_verify 11 "${BEST_HPARAM_11}" "sglang_m11_pivot${BEST_HPARAM_11}" + +echo "============================================================" +echo "Mode 13 (exp_stretch) — alpha=${BEST_HPARAM_13} (autotuned)" +echo "============================================================" +run_verify 13 "${BEST_HPARAM_13}" "sglang_m13_alpha${BEST_HPARAM_13}" + +echo "" +echo "============================================================" +echo "All runs complete. Results in ${RESULTS_DIR}/" +echo " Model: ${MODEL_NAME}" +echo " Block size: ${BLOCK_SIZE}" +echo " Auto-tune: ${AUTOTUNE_JSON}" +echo " Mode 3 (power): p = ${BEST_HPARAM_3} (autotuned)" +echo " Mode 6 (asinh): beta = ${BEST_HPARAM_6} (autotuned)" +echo " Mode 7 (log1p): alpha = ${BEST_HPARAM_7} (autotuned)" +echo " Mode 9 (erf): alpha = ${BEST_HPARAM_9} (autotuned)" +echo " Mode 10 (tanh): alpha = ${BEST_HPARAM_10} (autotuned)" +echo " Mode 13 (exp_stretch):alpha = ${BEST_HPARAM_13} (autotuned)" +echo "============================================================" diff --git a/examples/verify_external_backends.sh b/examples/verify_external_backends.sh new file mode 100755 index 00000000..12600d08 --- /dev/null +++ b/examples/verify_external_backends.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash +set -e +export CUDA_VISIBLE_DEVICES=6 + +sparse_algos=( + "nsa" + "fsa" + "flash_moba" +) + +RESULTS_DIR="results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/${algo}_bf16_${TIMESTAMP}.log" + echo ">>> Running verify_algo.py with --vortex-module-name ${algo}" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type naive \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done diff --git a/setup.py b/setup.py index e2723268..e886eaca 100644 --- a/setup.py +++ 
b/setup.py @@ -16,15 +16,24 @@ sources=[ 'csrc/register.cc', 'csrc/utils_sglang.cu', - 'csrc/topk.cu' + 'csrc/topk.cu', + 'csrc/topk_sglang.cu', + 'csrc/topk_sglang_profile.cu', + 'csrc/topk_sglang_ori.cu', + 'csrc/topk_sglang_merge.cu', + 'csrc/topk_adaptive_profile.cu', ], include_dirs=['csrc'], extra_compile_args={ 'cxx': ['-O3'], 'nvcc': [ '-O3', + '-gencode=arch=compute_86,code=sm_86', '-gencode=arch=compute_89,code=sm_89', - '-gencode=arch=compute_90,code=sm_90' + '-gencode=arch=compute_90,code=sm_90', + '-gencode=arch=compute_100a,code=sm_100a', + '-gencode=arch=compute_120,code=sm_120' + ], }, ), diff --git a/third_party/sglang b/third_party/sglang index e383c0fd..b7825d08 160000 --- a/third_party/sglang +++ b/third_party/sglang @@ -1 +1 @@ -Subproject commit e383c0fdd551f74f24d247e8a7cc8013861949ad +Subproject commit b7825d08399fccdf1f29a5380d6601fcef59aca1 diff --git a/vortex_torch/attention_backend/__init__.py b/vortex_torch/attention_backend/__init__.py new file mode 100644 index 00000000..9ca7855b --- /dev/null +++ b/vortex_torch/attention_backend/__init__.py @@ -0,0 +1,3 @@ +# Vendored sparse attention backends for Vortex forward_extend. +# NSA and FSA are pure Triton kernels. +# FlashMoBA requires flash_moba_cuda C++ extension (pip install flash_moba). 
diff --git a/vortex_torch/attention_backend/flashmoba/__init__.py b/vortex_torch/attention_backend/flashmoba/__init__.py new file mode 100644 index 00000000..aa912b91 --- /dev/null +++ b/vortex_torch/attention_backend/flashmoba/__init__.py @@ -0,0 +1,13 @@ +from .flash_moba_interface import ( + flash_moba_varlen_func, + flash_moba_attn_varlen_func, + flash_topk_varlen_func, + decide_lg_block_m, +) + +__all__ = [ + "flash_moba_varlen_func", + "flash_moba_attn_varlen_func", + "flash_topk_varlen_func", + "decide_lg_block_m", +] diff --git a/vortex_torch/attention_backend/flashmoba/flash_moba_interface.py b/vortex_torch/attention_backend/flashmoba/flash_moba_interface.py new file mode 100644 index 00000000..c196c21d --- /dev/null +++ b/vortex_torch/attention_backend/flashmoba/flash_moba_interface.py @@ -0,0 +1,730 @@ +from typing import Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +import os + +try: + import flash_moba_cuda as flash_moba_gpu +except ImportError: + flash_moba_gpu = None +from .triton_mean_pool import flash_topk_mean_pool + +########################################################################################################################## +# Helper functions +########################################################################################################################## + +def round_multiple(x: int, m: int) -> int: + """Round x up to the nearest multiple of m.""" + return ((x + m - 1) // m) * m + +########################################################################################################################## + +def decide_lg_block_m(top_k: int, chunk_size: int, seqlen: int, causal: bool = False) -> int: + sparsity = 0.0 + budget = top_k * chunk_size + if causal: + density = (2*(budget * seqlen) - budget**2) / (seqlen**2) + else: + density = budget / seqlen + + sparsity = 1 - density + + if sparsity <= 0.5: + lg_block_m = 128 + elif sparsity <= 0.7: + lg_block_m = 256 + elif sparsity <= 0.8: + 
lg_block_m = 512 + elif sparsity <= 0.9: + lg_block_m = 768 + else: + lg_block_m = 1024 + + # [Optimization] Hardware-aware cap for A6000/3090/4090 to avoid Shared Memory OOM + if torch.cuda.is_available(): + major, minor = torch.cuda.get_device_capability() + # sm86 (A6000, 3090) and sm89 (4090, L40) have smaller shared memory than A100 (sm80) + if major == 8 and minor > 0: + lg_block_m = min(lg_block_m, 512) + + return lg_block_m + +########################################################################################################################## + +# torch.compile() support is only enabled for pytorch >= 2.4 +# The reason for this is that we are using the new custom_op and register_fake +# APIs, which support inplace modification of inputs in the function itself +if torch.__version__ >= "2.4.0": + _torch_custom_op_wrapper = torch.library.custom_op + _torch_register_fake_wrapper = torch.library.register_fake +else: + def noop_custom_op_wrapper(name, fn=None, /, *, mutates_args, device_types=None, schema=None): + def wrap(func): + return func + if fn is None: + return wrap + return fn + def noop_register_fake_wrapper(op, fn=None, /, *, lib=None, _stacklevel=1): + def wrap(func): + return func + if fn is None: + return wrap + return fn + _torch_custom_op_wrapper = noop_custom_op_wrapper + _torch_register_fake_wrapper = noop_register_fake_wrapper + + +def maybe_contiguous(x): + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + + +########################################################################################################################## +# Custom ops +########################################################################################################################## + +@_torch_custom_op_wrapper("flash_moba::_moba_fused_topk", mutates_args=(), device_types="cuda") +def _moba_fused_topk( + q: torch.Tensor, + km: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + cu_seqlens_km: torch.Tensor, + 
max_seqlen_q: int, + max_seqlen_k: int, + # MOBA sparse pattern parameters + moba_topk: int, + moba_chunk_size: int, + causal: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q, km = [maybe_contiguous(x) for x in (q, km)] + + col_offsets, col_nnz, indices, _, _ = flash_moba_gpu.moba_fused_topk( + q, + km, + cu_seqlens_q, + cu_seqlens_k, + cu_seqlens_km, + max_seqlen_q, + max_seqlen_k, + moba_topk, + moba_chunk_size, + causal, + ) + return col_offsets, col_nnz, indices + +@_torch_register_fake_wrapper("flash_moba::_moba_fused_topk") +def _moba_fused_topk_fake( + q: torch.Tensor, + km: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + cu_seqlens_km: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + # MOBA sparse pattern parameters + moba_topk: int, + moba_chunk_size: int, + causal: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q, km = [maybe_contiguous(x) for x in (q, km)] + batch_size = cu_seqlens_q.numel() - 1 + total_q, num_heads, _ = q.shape + + max_lg_col_num = (max_seqlen_k + moba_chunk_size - 1) // moba_chunk_size + + col_offsets = torch.empty((batch_size, num_heads, max_lg_col_num), device=q.device, dtype=torch.int64) + col_nnz = torch.empty((batch_size, num_heads, max_lg_col_num), device=q.device, dtype=torch.int32) + indices = torch.empty((total_q * num_heads * moba_topk), device=q.device, dtype=torch.int32) + + return col_offsets, col_nnz, indices + +if torch.__version__ >= "2.4.0": + _wrapped_moba_fused_topk = torch.ops.flash_moba._moba_fused_topk +else: + _wrapped_moba_fused_topk = _moba_fused_topk + +########################################################################################################################## + +@_torch_custom_op_wrapper("flash_moba::_varlen_sort", mutates_args=(), device_types="cuda") +def _varlen_sort( + col_offsets: torch.Tensor, + col_nnz: torch.Tensor, + indices: torch.Tensor, +) -> torch.Tensor: + col_offset_ends = col_offsets.view(-1) + 
col_nnz.view(-1) + return flash_moba_gpu.varlen_sort( + col_offsets.view(-1), col_offset_ends, indices + ) + +@_torch_register_fake_wrapper("flash_moba::_varlen_sort") +def _varlen_sort_fake( + col_offsets: torch.Tensor, + col_nnz: torch.Tensor, + indices: torch.Tensor, +) -> torch.Tensor: + # varlen_sort is out-of-place + col_offset_ends = col_offsets.view(-1) + col_nnz.view(-1) + return torch.empty_like(indices) + +if torch.__version__ >= "2.4.0": + _wrapped_varlen_sort = torch.ops.flash_moba._varlen_sort +else: + _wrapped_varlen_sort = _varlen_sort + +########################################################################################################################## + +@_torch_custom_op_wrapper("flash_moba::_flash_moba_attn_varlen_forward", mutates_args=(), device_types="cuda") +def _flash_moba_attn_varlen_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + # MOBA sparse pattern parameters + moba_col_offsets: torch.Tensor, + moba_col_nnz: torch.Tensor, + moba_row_indices: torch.Tensor, + lg_block_m: int, + lg_block_n: int, + dropout_p: float, + softmax_scale: float, + causal: bool, + softcap: float = 0.0, + alibi_slopes: Optional[torch.Tensor] = None, + return_softmax: bool = False, + leftpad_k: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + zero_tensors: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + moba_col_offsets = maybe_contiguous(moba_col_offsets) + moba_col_nnz = maybe_contiguous(moba_col_nnz) + moba_row_indices = maybe_contiguous(moba_row_indices) + + out, softmax_lse, S_dmask, rng_state = flash_moba_gpu.moba_varlen_fwd( + q, + k, + v, + None, + moba_col_offsets, + moba_col_nnz, + moba_row_indices, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + leftpad_k, + alibi_slopes, + max_seqlen_q, + max_seqlen_k, + 
dropout_p, + softmax_scale, + zero_tensors, + causal, + softcap, + return_softmax, + lg_block_m, + lg_block_n, + None, + ) + # if out.isnan().any() or softmax_lse.isnan().any(): + # breakpoint() + return out, softmax_lse, S_dmask, rng_state + +@_torch_register_fake_wrapper("flash_moba::_flash_moba_attn_varlen_forward") +def _flash_moba_attn_varlen_forward_fake( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + # MOBA sparse pattern parameters + moba_col_offsets: torch.Tensor, + moba_col_nnz: torch.Tensor, + moba_row_indices: torch.Tensor, + lg_block_m: int, + lg_block_n: int, + dropout_p: float, + softmax_scale: float, + causal: bool, + softcap: float = 0.0, + alibi_slopes: Optional[torch.Tensor] = None, + return_softmax: bool = False, + leftpad_k: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + zero_tensors: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + batch_size = cu_seqlens_q.numel() - 1 + total_q, num_heads, _ = q.shape + + out = torch.empty_like(q) + softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device, layout=q.layout) + p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout) + seqlen_q_rounded = round_multiple(max_seqlen_q, 128) + seqlen_k_rounded = round_multiple(max_seqlen_k, 128) + if return_softmax: + p = torch.empty((batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), dtype=q.dtype, device=q.device, layout=q.layout) + rng_state = torch.empty((2,), dtype=torch.int64, device=q.device) + return out, softmax_lse, p, rng_state + +if torch.__version__ >= "2.4.0": + _wrapped_flash_moba_attn_varlen_forward = torch.ops.flash_moba._flash_moba_attn_varlen_forward +else: + _wrapped_flash_moba_attn_varlen_forward = _flash_moba_attn_varlen_forward + 
+########################################################################################################################## + +@_torch_custom_op_wrapper("flash_moba::_flash_moba_attn_varlen_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda") +def _flash_moba_attn_varlen_backward( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + dq: Optional[torch.Tensor], + dk: Optional[torch.Tensor], + dv: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + # MOBA sparse pattern parameters + moba_col_offsets: torch.Tensor, + moba_col_nnz: torch.Tensor, + moba_row_indices: torch.Tensor, + lg_block_m: int, + lg_block_n: int, + dropout_p: float, + softmax_scale: float, + causal: bool, + softcap: float, + alibi_slopes: Optional[torch.Tensor], + deterministic: bool, + rng_state: Optional[torch.Tensor] = None, + zero_tensors: bool = False, +) -> torch.Tensor: + # dq, dk, dv are allocated by us so they should already be contiguous + dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] + ( + dq, + dk, + dv, + softmax_d, + ) = flash_moba_gpu.moba_varlen_bwd( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + moba_col_offsets, + moba_col_nnz, + moba_row_indices, + cu_seqlens_q, + cu_seqlens_k, + alibi_slopes, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + zero_tensors, + causal, + softcap, + deterministic, + lg_block_m, + lg_block_n, + None, + rng_state, + ) + # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any(): + # breakpoint() + return softmax_d + +@_torch_register_fake_wrapper("flash_moba::_flash_moba_attn_varlen_backward") +def _flash_moba_attn_varlen_backward_fake( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + dq: Optional[torch.Tensor], + dk: 
Optional[torch.Tensor], + dv: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + # MOBA sparse pattern parameters + moba_col_offsets: torch.Tensor, + moba_col_nnz: torch.Tensor, + moba_row_indices: torch.Tensor, + lg_block_m: int, + lg_block_n: int, + dropout_p: float, + softmax_scale: float, + causal: bool, + softcap: float, + alibi_slopes: Optional[torch.Tensor], + deterministic: bool, + rng_state: Optional[torch.Tensor] = None, + zero_tensors: bool = False, +) -> torch.Tensor: + dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] + batch_size = cu_seqlens_q.numel() - 1 + total_q, num_heads, _ = q.shape + + if dq is None: + dq = torch.empty_like(q) + if dk is None: + dk = torch.empty_like(k) + if dv is None: + dv = torch.empty_like(v) + softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32) + + return softmax_d + +if torch.__version__ >= "2.4.0": + _wrapped_flash_moba_attn_varlen_backward = torch.ops.flash_moba._flash_moba_attn_varlen_backward +else: + _wrapped_flash_moba_attn_varlen_backward = _flash_moba_attn_varlen_backward + +########################################################################################################################## + +class FlashMobaAttnVarlenFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + # MOBA sparse pattern parameters + moba_col_offsets, + moba_col_nnz, + moba_row_indices, + lg_block_m, + lg_block_n, + dropout_p, + softmax_scale, + causal, + softcap, + alibi_slopes, + deterministic, + return_softmax, + is_grad_enabled, + ): + is_grad = is_grad_enabled and any( + x.requires_grad for x in [q, k, v] + ) + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + head_size_og = q.size(2) + if head_size_og % 8 != 0: + q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) 
+ k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) + v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) + out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_moba_attn_varlen_forward( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + moba_col_offsets, + moba_col_nnz, + moba_row_indices, + lg_block_m, + lg_block_n, + dropout_p, + softmax_scale, + causal=causal, + softcap=softcap, + alibi_slopes=alibi_slopes, + return_softmax=return_softmax and dropout_p > 0, + ) + if is_grad: + ctx.save_for_backward( + q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, + moba_col_offsets, moba_col_nnz, moba_row_indices + ) + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.lg_block_m = lg_block_m + ctx.lg_block_n = lg_block_n + + out = out_padded[..., :head_size_og] + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, moba_col_offsets, moba_col_nnz, moba_row_indices = ctx.saved_tensors + dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) + head_size_og = dout.size(2) + dout_padded = dout + if head_size_og % 8 != 0: + dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) + _wrapped_flash_moba_attn_varlen_backward( + dout_padded, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + ctx.max_seqlen_q, + ctx.max_seqlen_k, + moba_col_offsets, + moba_col_nnz, + moba_row_indices, + ctx.lg_block_m, + ctx.lg_block_n, + ctx.dropout_p, + ctx.softmax_scale, + ctx.causal, + ctx.softcap, + ctx.alibi_slopes, + ctx.deterministic, + rng_state=rng_state, + ) + dq = dq[..., : dout.shape[-1]] # We could have padded the 
head dimension + dk = dk[..., : dout.shape[-1]] + dv = dv[..., : dout.shape[-1]] + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None + +########################################################################################################################## + +def flash_topk_varlen_func( + q, + k, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + # MOBA sparse pattern parameters + moba_topk, + moba_chunk_size, + causal=False, +): + """ + Computes the top-k indices for Mixture-of-Blocks Attention (MOBA). + This function handles variable length sequences. + + Args: + q (torch.Tensor): Query tensor of shape (total_q, num_heads, head_size). + k (torch.Tensor): Key tensor of shape (total_k, num_heads, head_size). + cu_seqlens_q (torch.Tensor): Cumulative sequence lengths for queries, shape (batch_size + 1,). + cu_seqlens_k (torch.Tensor): Cumulative sequence lengths for keys, shape (batch_size + 1,). + max_seqlen_q (int): Maximum sequence length for queries. + max_seqlen_k (int): Maximum sequence length for keys. + moba_topk (int): The number of top-k elements to select. + moba_chunk_size (int): The chunk size for MOBA. + causal (bool): Whether to apply causal masking. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing: + - col_offsets (torch.Tensor): Column offsets for the sparse matrix. + - col_nnz (torch.Tensor): Number of non-zero elements per column block. + - indices (torch.Tensor): The top-k indices. 
+ """ + head_size_og = q.size(2) + if head_size_og % 8 != 0: + q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) + k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) + + km, cu_seqlens_km, _ = flash_topk_mean_pool(k, cu_seqlens_k, max_seqlen_k, moba_chunk_size) + + col_offsets, col_nnz, indices = _wrapped_moba_fused_topk( + q, + km, + cu_seqlens_q, + cu_seqlens_k, + cu_seqlens_km, + max_seqlen_q, + max_seqlen_k, + moba_topk, + moba_chunk_size, + causal=causal + ) + + indices = _wrapped_varlen_sort( + col_offsets, col_nnz, indices + ) + + return col_offsets, col_nnz, indices + +########################################################################################################################## + +def flash_moba_attn_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + # MOBA sparse pattern parameters + moba_col_offsets, + moba_col_nnz, + moba_row_indices, + lg_block_m=64, + lg_block_n=64, + dropout_p=0.0, + softmax_scale=None, + causal=False, + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, +): + """dropout_p should be set to 0.0 during evaluation + Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + + If window_size != (-1, -1), implements sliding window local attention. 
Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + + Arguments: + q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. + k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into q. + cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into kv. + max_seqlen_q: int. Maximum query sequence length in the batch. + max_seqlen_k: int. Maximum key sequence length in the batch. + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + softcap: float. Anything > 0 activates softcapping attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + moba_col_offsets: Optional[torch.Tensor]. Column offsets for MOBA sparse pattern. + Shape: (batch_size, num_heads, max_lg_col_num), dtype: int64 + moba_col_nnz: Optional[torch.Tensor]. 
Non-zero counts per column for MOBA sparse pattern. + Shape: (batch_size, num_heads, max_lg_col_num), dtype: int32 + moba_row_indices: Optional[torch.Tensor]. Row indices for MOBA sparse pattern (flattened). + dtype: int32 + lg_block_m: int. Logical block size in M dimension (query). Default: 64 + lg_block_n: int. Logical block size in N dimension (key). Default: 64 + Return: + out: (total, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). + """ + return FlashMobaAttnVarlenFunc.apply( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + moba_col_offsets, + moba_col_nnz, + moba_row_indices, + lg_block_m, + lg_block_n, + dropout_p, + softmax_scale, + causal, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + torch.is_grad_enabled(), + ) + +########################################################################################################################## + +def flash_moba_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + moba_chunk_size, + moba_topk, + causal=True, +): + + col_offsets, col_nnz, indices = flash_topk_varlen_func( + q, + k, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + # MOBA sparse pattern parameters + moba_topk, + moba_chunk_size, + causal=causal, + ) + + lg_block_m = decide_lg_block_m(moba_topk, moba_chunk_size, max_seqlen_k, causal) + + return flash_moba_attn_varlen_func( + q, k, v, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + col_offsets, + col_nnz, + indices, + lg_block_m, + moba_chunk_size, + dropout_p=0.0, + 
causal=causal, + ) diff --git a/vortex_torch/attention_backend/flashmoba/triton_mean_pool.py b/vortex_torch/attention_backend/flashmoba/triton_mean_pool.py new file mode 100644 index 00000000..6fbd59f1 --- /dev/null +++ b/vortex_torch/attention_backend/flashmoba/triton_mean_pool.py @@ -0,0 +1,158 @@ +# Copyright (c) 2025, FlashMoBA Team. +import torch +import torch.nn.functional as F + +import triton +import triton.language as tl + + +# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes: +# - A list of `triton.Config` objects that define different configurations of +# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try +# - An auto-tuning *key* whose change in values will trigger evaluation of all the +# provided configs +@triton.autotune( + configs=[ + # triton.Config({'kBlockN': 16}, num_warps=2, num_stages=3), + triton.Config({'kBlockN': 32}, num_warps=2, num_stages=3), + triton.Config({'kBlockN': 32}, num_warps=4, num_stages=3), + triton.Config({'kBlockN': 32}, num_warps=4, num_stages=4), + triton.Config({'kBlockN': 64}, num_warps=2, num_stages=3), + triton.Config({'kBlockN': 64}, num_warps=4, num_stages=3), + triton.Config({'kBlockN': 64}, num_warps=4, num_stages=4), + triton.Config({'kBlockN': 64}, num_warps=8, num_stages=3), + triton.Config({'kBlockN': 128}, num_warps=2, num_stages=3), + triton.Config({'kBlockN': 128}, num_warps=4, num_stages=3), + triton.Config({'kBlockN': 128}, num_warps=4, num_stages=4), + triton.Config({'kBlockN': 128}, num_warps=8, num_stages=3), + triton.Config({'kBlockN': 128}, num_warps=8, num_stages=4), + # triton.Config({'kBlockN': 256}, num_warps=4, num_stages=3), + # triton.Config({'kBlockN': 256}, num_warps=8, num_stages=3), + # triton.Config({'kBlockN': 256}, num_warps=8, num_stages=4), + # triton.Config({'kBlockN': 256}, num_warps=16, num_stages=2), + # triton.Config({'kBlockN': 512}, num_warps=8, num_stages=2), + # 
triton.Config({'kBlockN': 512}, num_warps=16, num_stages=2), + # triton.Config({'kBlockN': 512}, num_warps=16, num_stages=3), + # triton.Config({'kBlockN': 1024}, num_warps=16, num_stages=2), + ], + key=['HEAD_DIM', 'POOL_BLOCK_SIZE'], +) +@triton.jit +def mean_pool_kernel( + # Pointers to matrices + input_ptr, + output_ptr, + # Matrix dimensions + HEAD_DIM: tl.constexpr, + POOL_BLOCK_SIZE: tl.constexpr, + cu_seqlens_input, + cu_seqlens_output, + input_stride_row, input_stride_head, + output_stride_row, output_stride_head, + # Meta-parameters + kBlockN: tl.constexpr, +): + """ + Triton kernel for mean pooling over variable-length sequences. + + This kernel computes the mean of non-overlapping blocks of size `POOL_BLOCK_SIZE` + for each sequence in a batch. It is designed to handle variable sequence lengths. + + Args: + input_ptr: Pointer to the input tensor of shape (total_seqlen, num_heads, head_dim). + output_ptr: Pointer to the output tensor of shape (total_blocks, num_heads, head_dim). + HEAD_DIM: The dimension of each head. + POOL_BLOCK_SIZE: The size of the pooling window. + cu_seqlens_input: Cumulative sequence lengths of the input tensor, shape (batch_size + 1,). + cu_seqlens_output: Cumulative sequence lengths of the output tensor, shape (batch_size + 1,). + input_stride_row: Stride of the input tensor along the sequence dimension. + input_stride_head: Stride of the input tensor along the head dimension. + output_stride_row: Stride of the output tensor along the sequence dimension. + output_stride_head: Stride of the output tensor along the head dimension. + kBlockN: Block size for the sequence dimension, a meta-parameter for tuning. 
+ """ + n_block = tl.program_id(0) + bidb = tl.program_id(1) + bidh = tl.program_id(2) + + seq_start = tl.load(cu_seqlens_input + bidb) + seq_end = tl.load(cu_seqlens_input + bidb + 1) + + block_start_row = seq_start + n_block * POOL_BLOCK_SIZE + + if seq_end <= block_start_row: + return + + actual_block_size = tl.minimum(POOL_BLOCK_SIZE, seq_end - block_start_row) + + offsets_d = tl.arange(0, HEAD_DIM) + # mask_d = offsets_d < HEAD_DIM + + acc = tl.zeros([HEAD_DIM], dtype=tl.float32) + + for block_k_start in range(0, actual_block_size, kBlockN): + offsets_k = block_k_start + tl.arange(0, kBlockN) + mask_k = offsets_k < actual_block_size + + row_indices = block_start_row + offsets_k + + input_offset = row_indices[:, None] * input_stride_row.to(tl.int64) + bidh * input_stride_head.to(tl.int64) + offsets_d[None, :] + + inp = tl.load(input_ptr + input_offset, mask=mask_k[:, None], other=0.0) + acc += tl.sum(inp, axis=0) + + # safe division + mean_val = acc / actual_block_size + + output_start = tl.load(cu_seqlens_output + bidb) + output_offset = (output_start + n_block) * output_stride_row.to(tl.int64) + bidh * output_stride_head.to(tl.int64) + offsets_d + tl.store(output_ptr + output_offset, mean_val) + + +def flash_topk_mean_pool(input, cu_seqlens_input, max_seqlen_input, pool_block_size): + """ + Performs mean pooling on variable-length sequences using a Triton kernel. + + This function takes a tensor of packed sequences and applies mean pooling over + fixed-size blocks. + + Args: + input (torch.Tensor): The input tensor of shape (total_seqlen, num_heads, head_dim). + cu_seqlens_input (torch.Tensor): Cumulative sequence lengths for the input, shape (batch_size + 1,). + max_seqlen_input (int): The maximum sequence length in the input batch. + pool_block_size (int): The size of the pooling window. + + Returns: + Tuple[torch.Tensor, torch.Tensor, int]: A tuple containing: + - output (torch.Tensor): The pooled output tensor of shape (total_blocks, num_heads, head_dim). 
+ - cu_seqlens_output (torch.Tensor): Cumulative sequence lengths for the output. + - max_seqlen_output (int): The maximum number of blocks for any sequence in the batch. + """ + total_seqlen, head_num, head_dim = input.shape + batch_size = cu_seqlens_input.shape[0] - 1 + + max_seqlen_output = (max_seqlen_input + pool_block_size - 1) // pool_block_size + + actual_input_seqlens = cu_seqlens_input[1:] - cu_seqlens_input[:-1] + actual_output_seqlens = (actual_input_seqlens + pool_block_size - 1) // pool_block_size + cu_seqlens_output = F.pad(torch.cumsum(actual_output_seqlens, dim=0), (1, 0)).to(torch.int32) + + total_blocks = cu_seqlens_output[-1].item() + + output = torch.zeros((total_blocks, head_num, head_dim), dtype=input.dtype, device=input.device) + + grid = (max_seqlen_output, batch_size, head_num) + + mean_pool_kernel[grid]( + input, + output, + head_dim, + pool_block_size, + cu_seqlens_input, + cu_seqlens_output, + input.stride(0), input.stride(1), + output.stride(0), output.stride(1), + ) + + return output, cu_seqlens_output, max_seqlen_output + \ No newline at end of file diff --git a/vortex_torch/attention_backend/fsa/FSA_topk_sparse_attention.py b/vortex_torch/attention_backend/fsa/FSA_topk_sparse_attention.py new file mode 100644 index 00000000..acca2ac8 --- /dev/null +++ b/vortex_torch/attention_backend/fsa/FSA_topk_sparse_attention.py @@ -0,0 +1,2040 @@ +# Copyright 2025 Ran Yan. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific +import math +from typing import Any, Optional + +import torch +import triton +import triton.language as tl + +from ..nsa.topk_sparse_attention import (backward_sum_o_do, + reorder_topk_idx, + get_num_warps_stages, + is_hopper_gpu) + +IS_HOPPER_GPU = is_hopper_gpu() + + +@triton.jit +def fused_fill_kernel(ptr_tile, ptr_m_i_cur_tiles, N, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < N + + tl.store(ptr_tile + offsets, -1, mask=mask) # fill int32 with -1 + tl.store(ptr_m_i_cur_tiles + offsets, float("-inf"), mask=mask) + + +def fused_fill(topk_idx_permuted_tile: torch.Tensor, m_i_cur_tiles): + + numel = topk_idx_permuted_tile.numel() + BLOCK_SIZE = 1024 + + # Flatten for pointer access + tile_flat = topk_idx_permuted_tile.view(-1) + + m_i_cur_tiles_flat = m_i_cur_tiles.view(-1) + + grid = lambda meta: (triton.cdiv(numel, meta['BLOCK_SIZE']),) + + fused_fill_kernel[grid]( + tile_flat, + m_i_cur_tiles_flat, + numel, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=1, + num_stages=3, + ) + + +@triton.jit +def block_to_token_kernel( + topk_idx_ptr, + result_ptr, + N_token, + K, + min_block_id, + max_block_id, + padding_value, + ts_h, + ts_b, + ts_n, + rs_h, + rs_b, + rs_n, + num_q_loops: tl.constexpr, + BLOCK_K: tl.constexpr, +): + pid = tl.program_id(0) # token index i + pid_h = 0 + offs = tl.arange(0, BLOCK_K) # [0, 1, ..., K-1] + + offs_q = tl.arange(0, num_q_loops) + + pid_j = pid * num_q_loops + offs_q + + topk_idx_offset = pid_h * ts_h + pid_j[None, :] * K + offs[:, None] + block_ids = tl.load( + topk_idx_ptr + topk_idx_offset, mask=(pid_j < N_token)[None, :] & (offs < K)[:, None], other=padding_value + ) + + result_ptrs = result_ptr + pid_h * rs_h + block_ids * N_token + pid_j[None, :] + + mask = (block_ids >= 0) & (block_ids != padding_value) & (pid_j < N_token)[None, :] + tl.store(result_ptrs, pid_j[None, :], mask=mask) + + +def build_block_to_token_triton( + 
result: torch.Tensor, topk_idx: torch.Tensor, min_block_id: int, max_block_id: int, padding_value: int = -1 +): + """ + Args: + topk_idx: [num_heads, N_token, TopK], block indices per token, padded with padding_value for invalid blocks + num_blocks: int + padding_value: int + + Returns: + result: [num_blocks, N_token], token indices per block, padded by padding_value + """ + assert topk_idx.ndim == 3 + assert padding_value == -1 + num_heads, N_token, TopK = topk_idx.shape + + # 每个 token,每个head 一个 program + num_q_loops = 4 + grid = (triton.cdiv(N_token, num_q_loops),) + BLOCK_K = triton.next_power_of_2(TopK) + block_to_token_kernel[grid]( + topk_idx, + result, + N_token, + TopK, + min_block_id, + max_block_id, + padding_value, + topk_idx.stride(0), + topk_idx.stride(1), + topk_idx.stride(2), + result.stride(0), + result.stride(1), + result.stride(2), + num_q_loops, + BLOCK_K=BLOCK_K, + num_warps=2, + num_stages=3, + ) + return result + + +@triton.jit +def reduce_kernel( + lse_ptr, # float32 [H, N] + m_ij_ptr, # float32 [H, B, N] + l_ij_first_ptr, # float32 [H, 1, N] + l_ij_rest_ptr, # float32 [H, B, N] + m_ij_last_ptr, # float32 [H, N] + o_ptr, # o: n x h x d + o_tiles_first_ptr, # o_tiles: n x h x 1 x d + o_tiles_rest_ptr, # o_tiles: n x h x b x d + acc_o_scales_first_ptr, # acc_o_scales: n x h x 1 + acc_o_scales_rest_ptr, # acc_o_scales: n x h x b + t_ptr, # topk_idx: h x n x k + token_index_mapping_ptr, + start_head_id, + num_qz_loop, + TOPK, + total_len, + # stride + stride_lse_h, + stride_lse_n, + stride_m_ij_h, + stride_m_ij_b, + stride_m_ij_n, + stride_l_ij_fh, + stride_l_ij_fb, + stride_l_ij_fn, + stride_l_ij_rh, + stride_l_ij_rb, + stride_l_ij_rn, + stride_on, + stride_oh, + stride_od, + stride_otfh, + stride_otfb, + stride_otfn, + stride_otfd, + stride_otrh, + stride_otrb, + stride_otrn, + stride_otrd, + stride_acc_fh, + stride_acc_fb, + stride_acc_fn, + stride_acc_rh, + stride_acc_rb, + stride_acc_rn, + stride_th, + stride_tn, + stride_tk, + stride_tim_h, 
+ stride_tim_b, + stride_tim_n, + # META parameters + BLOCK_SIZE_T: tl.constexpr, + BLOCK_SIZE_D: tl.constexpr, +): + pid_qy = tl.program_id(0) + pid_q = tl.program_id(1) # token + + pid_q_j = pid_q + pid_qy * num_qz_loop + if pid_q_j < total_len: + t_ptr_j = t_ptr + pid_q_j * stride_tn + + off_d = tl.arange(0, BLOCK_SIZE_D) + o_ptrs = o_ptr + pid_q_j * stride_on + off_d + last_acc_o = tl.load(o_ptrs, mask=off_d < BLOCK_SIZE_D, other=0.0) + acc_o = tl.zeros((BLOCK_SIZE_D,), dtype=tl.float32) + acc_o += last_acc_o + + lse_ptrs = lse_ptr + pid_q_j * stride_lse_n + # Load lse + lse = tl.load(lse_ptrs, mask=pid_q_j < total_len, other=float("-inf")) + + # the stride is 1 for m_ij_last + m_ij_last = tl.load(m_ij_last_ptr + pid_q_j) + + for block_id in range(TOPK): + t = tl.load(t_ptr_j + block_id * stride_tk, mask=block_id < TOPK, other=-1) + if t != -1: + if t == 0: + real_block_pos = 0 + l_ij_ptr = l_ij_first_ptr + o_tiles_ptr = o_tiles_first_ptr + acc_o_scales_ptr = acc_o_scales_first_ptr + stride_l_ij_b = stride_l_ij_fb + stride_l_ij_n = stride_l_ij_fn + stride_acc_b = stride_acc_fb + stride_acc_n = stride_acc_fn + stride_otb = stride_otfb + stride_otn = stride_otfn + else: + real_block_pos = t - 1 + l_ij_ptr = l_ij_rest_ptr + o_tiles_ptr = o_tiles_rest_ptr + acc_o_scales_ptr = acc_o_scales_rest_ptr + stride_l_ij_b = stride_l_ij_rb + stride_l_ij_n = stride_l_ij_rn + stride_acc_b = stride_acc_rb + stride_acc_n = stride_acc_rn + stride_otb = stride_otrb + stride_otn = stride_otrn + + # init pointers + token_index_mapping_ptrs = ( + token_index_mapping_ptr + t.to(tl.int64) * stride_tim_b + (pid_q_j) * stride_tim_n + ) + real_token_index = tl.load(token_index_mapping_ptrs) + + m_ij = tl.load( + m_ij_ptr + t * stride_m_ij_b + pid_q_j * stride_m_ij_n, mask=pid_q_j < total_len, other=float("-inf") + ) + l_ij = tl.load( + l_ij_ptr + real_block_pos * stride_l_ij_b + real_token_index * stride_l_ij_n, + mask=real_token_index < total_len, + other=0.0, + ) + delta = lse - m_ij + 
+ log_delta = tl.exp2(delta) + l_ij + + # Update lse + lse = m_ij + tl.log2(log_delta) + + o_tiles_ptrs = ( + o_tiles_ptr + real_block_pos.to(tl.int64) * stride_otb + (real_token_index) * stride_otn + off_d + ) + acc_o_scales_ptrs = acc_o_scales_ptr + real_block_pos * stride_acc_b + (real_token_index) * stride_acc_n + + o_tiles = tl.load(o_tiles_ptrs) + acc_o_scales_tiles = tl.load(acc_o_scales_ptrs) + acc_o = o_tiles + acc_o * acc_o_scales_tiles + + # final scale + acc_o = acc_o * tl.exp2(m_ij_last - lse) + tl.store(o_ptrs, acc_o, mask=off_d < BLOCK_SIZE_D) + + # Store back + tl.store( + lse_ptrs, + lse, + mask=pid_q_j < total_len, + ) + + +@triton.jit +def qk_kernel( + q_ptr, # Q: n x h x d + k_ptr, # K: n x h x d + m_i_tiles_ptr, # m_i: h x b x n + selected_tokens_ptr, # selected_tokens: sum(valid_lens), + valid_lens_ptr, # valid_lens: (h x b), + valid_start_indices_ptr, # valid_start_indices: (h x b), + num_heads, + num_blocks, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + HEAD_DIM, + # sm_scale + sm_scale, + num_q_blocks, + num_b_blocks, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_m_i_tiles_h, + stride_m_i_tiles_b, + stride_m_i_tiles_n, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_block_grid = tl.program_id(0) // num_heads # block id + head_id = tl.program_id(0) % num_heads + pid_q = tl.program_id(1) # token + + # get q k start and len after rmpad + k_len = tl.load(cu_seqlens_k + 1) + k_ptrs = tl.make_block_ptr( + base=k_ptr + head_id * stride_kh, + shape=(HEAD_DIM, k_len), + strides=(stride_kd, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + + for bb in range(num_b_blocks): + pid_block = bb + pid_block_grid * num_b_blocks + + start_id = tl.load(valid_start_indices_ptr + head_id * 
# Per-KV-block partial attention forward pass.
#
# Each program handles one (KV block, head) pair along grid axis 0 and a group
# of `num_q_blocks` chunks of BLOCK_SIZE_Q selected query tokens along grid
# axis 1. For every chunk it writes three partial results, indexed by the
# per-block compact token index read from `token_index_mapping_ptr`:
#   * l_ij         - row sums of exp2(qk - m_ij)
#   * acc_o_scales - exp2(m_prev - m_cur) rescale factor (1.0 for global block 0)
#   * o_tiles      - p @ v for this KV block
#
# NOTE(review): `cur_max_valid_tokens`, `cu_seqlens_q` and `stride_tim_h` are
# accepted but never read in the body below.
@triton.jit
def forward_kernel_opt(
    q_ptr,
    k_ptr,
    v_ptr,  # V: n x h x d
    o_tiles_ptr,  # O: n x h x b x d
    acc_o_scales_ptr,  # acc_o_scales: h x b x n
    m_ij_tiles_ptr,
    l_ij_ptr,  # h x b x n
    token_index_mapping_ptr,
    selected_tokens_ptr,  # selected_tokens: sum(valid_lens),
    valid_lens_ptr,  # valid_lens: (h x b),
    valid_start_indices_ptr,  # valid_start_indices: (h x b),
    min_block_id,
    cur_max_valid_tokens,
    num_heads,
    num_blocks,
    # seqlens
    cu_seqlens_q,
    cu_seqlens_k,
    # shape
    HEAD_DIM,
    # sm_scale
    sm_scale,
    num_q_blocks,
    # stride
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_vn,
    stride_vh,
    stride_vd,
    stride_oth,
    stride_otb,
    stride_otn,
    stride_otd,
    stride_acc_oh,
    stride_acc_ob,
    stride_acc_on,
    stride_m_ij_tiles_h,
    stride_m_ij_tiles_b,
    stride_m_ij_tiles_n,
    stride_l_ij_h,
    stride_l_ij_b,
    stride_l_ij_n,
    stride_tim_h,
    stride_tim_b,
    stride_tim_n,
    # META parameters
    BLOCK_SIZE_Q: tl.constexpr,  # q block size
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_D: tl.constexpr,
):
    # grid axis 0 packs (block, head): block is the quotient, head the remainder
    pid_block = tl.program_id(0) // num_heads  # block id (relative to min_block_id)
    head_id = tl.program_id(0) % num_heads
    pid_q = tl.program_id(1)  # group of num_q_blocks query chunks
    # seq packing is not supported yet, so both sequences start at offset 0
    q_start = 0
    k_start = 0

    k_len = tl.load(cu_seqlens_k + 1) - k_start

    # start offset and length of this (head, block)'s slice of selected_tokens
    start_id = tl.load(valid_start_indices_ptr + head_id * num_blocks + pid_block)
    valid_tokens = tl.load(valid_lens_ptr + head_id * num_blocks + pid_block)
    # the whole group of chunks lies past the end of the token list: nothing to do
    if num_q_blocks * pid_q * BLOCK_SIZE_Q >= valid_tokens:
        return

    # first key position covered by this program's KV block
    c = (min_block_id + pid_block) * BLOCK_SIZE_K
    # K is addressed transposed, as [HEAD_DIM, k_len], so tl.dot(q, k) needs no tl.trans
    k_ptrs = tl.make_block_ptr(
        base=k_ptr + k_start * stride_kn + head_id * stride_kh,
        shape=(HEAD_DIM, k_len),
        strides=(stride_kd, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
        order=(0, 1),
    )
    # load k once; it is reused for every query chunk in the loop below
    k = tl.load(tl.advance(k_ptrs, (0, c)), boundary_check=(1, 0), padding_option="zero")

    v_ptrs = tl.make_block_ptr(
        base=v_ptr + k_start * stride_vn + head_id * stride_vh,
        shape=(k_len, HEAD_DIM),
        strides=(stride_vn, stride_vd),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )

    # load v once; it is reused for every query chunk in the loop below
    v = tl.load(tl.advance(v_ptrs, (c, 0)), boundary_check=(0, 1), padding_option="zero")

    off_k = tl.arange(0, BLOCK_SIZE_K)
    off_d = tl.arange(0, BLOCK_SIZE_D)
    for j in range(num_q_blocks):
        pid_q_j = pid_q * num_q_blocks + j
        # skip chunks that start past the end of the token list
        if pid_q_j * BLOCK_SIZE_Q < valid_tokens:
            # one thread block for one KV block, a subset of selected tokens
            st_offs = start_id + (q_start + pid_q_j * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q))
            # st should be in shape [BLOCK_SIZE_Q]
            st_mask = (pid_q_j * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)) < valid_tokens

            # query token ids for this chunk; lanes past the list end read as -1
            st = tl.load(selected_tokens_ptr + st_offs, mask=st_mask, other=-1)

            # otherwise, st selects a set of q tokens, selected_tokens_ptr should be sorted
            q_ptrs_off = st[:, None] * stride_qn + off_d[None, :] * stride_qd

            # lanes holding a real token id (st == -1 marks padding)
            mask = st != -1

            # running max for this token at the current KV block
            m_ij_tiles_ptrs = (
                m_ij_tiles_ptr
                + head_id * stride_m_ij_tiles_h
                + (q_start + st) * stride_m_ij_tiles_n
                + (pid_block + min_block_id) * stride_m_ij_tiles_b
            )
            m_ij = tl.load(m_ij_tiles_ptrs, mask=mask, other=float("-inf"))

            # running max at the previous KV block; -inf when there is no previous block
            m_ij_tiles_prev_ptrs = (
                m_ij_tiles_ptr
                + head_id * stride_m_ij_tiles_h
                + (q_start + st) * stride_m_ij_tiles_n
                + (pid_block + min_block_id - 1) * stride_m_ij_tiles_b
            )
            m_ij_prev = tl.load(m_ij_tiles_prev_ptrs, mask=mask & (pid_block + min_block_id > 0), other=float("-inf"))

            # exponent of the rescale factor stored in acc_o_scales below
            m_i_minus_m_ij = m_ij_prev - m_ij

            q_ptrs = q_ptr + q_start * stride_qn + head_id * stride_qh + q_ptrs_off
            # load q rows gathered by token id; padded lanes and d >= HEAD_DIM read as 0
            q_mask = mask[:, None] & (off_d < HEAD_DIM)[None, :]
            q = tl.load(q_ptrs, mask=q_mask, other=0)

            # compute qk, starting from the causal mask: a key is visible only if
            # its position (c + off_k) does not exceed the query token id
            qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
            qk += tl.where((st[:, None] >= c + off_k[None, :]), 0, float("-inf"))

            # [BLOCK_SIZE_Q, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_K] -> [BLOCK_SIZE_Q, BLOCK_SIZE_K]
            # 1.44269504 ~= log2(e): scores are kept in base 2 so tl.exp2 can be used
            qk_scale = sm_scale * 1.44269504
            qk += tl.dot(q, k) * qk_scale

            # NOTE(review): this zero initialiser is dead; acc_o_buffer is
            # reassigned by tl.dot(p, v) before it is read.
            acc_o_buffer = tl.full((BLOCK_SIZE_Q, BLOCK_SIZE_D), 0, dtype=tl.float32)

            # unnormalised probabilities relative to the running max, and their row sums
            p = tl.exp2(qk - m_ij[:, None])
            l_ij = tl.sum(p, axis=1)

            # load token index mapping: token id -> compact row used by the tile buffers
            # (no head offset is applied to this pointer)
            token_index_mapping_ptrs = (
                token_index_mapping_ptr + (st) * stride_tim_n + (pid_block + min_block_id) * stride_tim_b
            )
            token_index_mapping = tl.load(token_index_mapping_ptrs, mask=mask, other=-1)

            # output buffers are indexed by the relative block id (pid_block),
            # not the global one (pid_block + min_block_id)
            l_ij_ptrs = (
                l_ij_ptr
                + head_id * stride_l_ij_h
                + (q_start + token_index_mapping) * stride_l_ij_n
                + (pid_block) * stride_l_ij_b
            )
            tl.store(l_ij_ptrs, l_ij, mask=mask)
            # scale acc_o: the first global block has no history, so its factor is 1
            if pid_block + min_block_id == 0:
                acc_o_scale = tl.full((BLOCK_SIZE_Q,), 1.0, dtype=tl.float32)
            else:
                acc_o_scale = tl.exp2(m_i_minus_m_ij)

            tl.store(
                acc_o_scales_ptr
                + head_id * stride_acc_oh
                + (pid_block) * stride_acc_ob
                + (q_start + token_index_mapping) * stride_acc_on,
                acc_o_scale,
                mask=(st != -1),
            )

            # cast p to V's dtype for the second matmul
            p = p.to(v.dtype)
            acc_o_buffer = tl.dot(p, v)

            # write this block's partial output; the block offset is widened to int64
            o_ptrs_off = token_index_mapping[:, None] * stride_otn + off_d[None, :] * stride_otd
            o_ptrs = o_tiles_ptr + head_id * stride_oth + o_ptrs_off + (pid_block).to(tl.int64) * stride_otb
            tl.store(o_ptrs, acc_o_buffer.to(o_tiles_ptr.dtype.element_ty), mask=q_mask)
@triton.jit
def index_mapping_kernel(
    token_index_mapping_ptr,
    selected_tokens_ptr,
    valid_lens_ptr,
    valid_start_indices_ptr,
    stride_im_h,
    stride_im_b,
    stride_im_n,
    BLOCK_SIZE_K: tl.constexpr,
):
    # Build the inverse of the per-block token list: for KV block `block_idx`,
    # the token stored at list position `slot` gets `slot` written at
    # token_index_mapping[block_idx, token]. Grid axis 0 walks blocks, axis 1
    # walks BLOCK_SIZE_K-sized chunks of each block's list.
    # (stride_im_h is accepted but not used: no head offset is applied.)
    block_idx = tl.program_id(0)
    chunk_idx = tl.program_id(1)

    # length of this block's token list and where it starts in selected_tokens
    n_valid = tl.load(valid_lens_ptr + block_idx)
    list_base = tl.load(valid_start_indices_ptr + block_idx)

    # positions inside the block's list handled by this program, shape [BLOCK_SIZE_K]
    slot = chunk_idx * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    in_range = slot < n_valid

    # token ids at those positions; lanes past the end of the list read as -1
    token_ids = tl.load(selected_tokens_ptr + (list_base + slot), mask=in_range, other=-1)

    # scatter: mapping[block, token_id] = position in the list
    dst_ptrs = token_index_mapping_ptr + block_idx * stride_im_b + token_ids * stride_im_n
    tl.store(dst_ptrs, slot, mask=in_range)
def online_softmax(
    q_tile,
    k_tile,
    m_i_cur_tiles,
    valid_topk_idx_permuted_tile,
    valid_lens,
    valid_start_indices,
    compute_min_block_id,
    cur_max_valid_tokens,
    block_size,
    num_blocks,
    head_tile,
    head_dim,
    sm_scale,
    cu_seqlens_q,
    cu_seqlens_k,
):
    """Launch ``qk_kernel`` and turn its per-block maxima into running maxima.

    ``qk_kernel`` fills ``m_i_cur_tiles`` in place; afterwards a cumulative
    maximum is taken along dim 1 of that tensor.

    Returns:
        A tuple ``(m_ij_tiles, m_ij_last)``: the cumulative maximum of
        ``m_i_cur_tiles`` along dim 1, and its last slice along that dim.

    Note:
        ``compute_min_block_id`` is accepted for signature compatibility but
        is not used here.
    """
    # tile sizes handed to the kernel as constexpr meta-parameters
    q_rows = 128
    k_cols = triton.next_power_of_2(block_size)
    d_cols = triton.next_power_of_2(head_dim)
    # each program walks this many query chunks / KV blocks
    q_chunks_per_program = 8
    kv_blocks_per_program = 1

    def launch_grid(META):
        return (
            triton.cdiv(num_blocks, kv_blocks_per_program),
            triton.cdiv(cur_max_valid_tokens, q_rows * q_chunks_per_program),
        )

    # strides are passed in the order q(0..2), k(0..2), m_i_cur_tiles(0..2)
    stride_args = [
        tensor.stride(axis)
        for tensor in (q_tile, k_tile, m_i_cur_tiles)
        for axis in range(3)
    ]

    qk_kernel[launch_grid](
        q_tile,
        k_tile,
        m_i_cur_tiles,
        valid_topk_idx_permuted_tile,
        valid_lens,
        valid_start_indices,
        head_tile,
        num_blocks,
        cu_seqlens_q,
        cu_seqlens_k,
        head_dim,
        sm_scale,
        q_chunks_per_program,
        kv_blocks_per_program,
        *stride_args,
        BLOCK_SIZE_Q=q_rows,
        BLOCK_SIZE_K=k_cols,
        BLOCK_SIZE_D=d_cols,
        num_warps=8,
        num_stages=3,
    )

    # running maximum over the block axis, plus its final value
    running_max = m_i_cur_tiles.cummax(dim=1).values
    final_max = running_max[:, -1]
    return running_max, final_max
def reduce_output(
    lse,
    o,
    o_tiles_first,
    o_tiles_rest,
    m_ij_tiles,
    l_ij_first,
    l_ij_rest,
    m_ij_last,
    acc_o_scales_first,
    acc_o_scales_rest,
    topk_idx_tile,
    token_index_mapping,
    h,
    head_tile,
    total_len,
    TOPK,
    head_dim,
):
    """Launch ``reduce_kernel`` to merge the per-block partial results.

    The kernel receives the "first block" and "rest" tile buffers together
    with their scales and row sums, and writes into ``o`` and ``lse``.

    The launch grid is ``(num_qy_loop, ceil(total_len / num_qy_loop))``, so
    the grid always holds at least ``total_len`` programs. The previous grid,
    ``(num_qy_loop + (total_len % num_qy_loop != 0), total_len // num_qy_loop)``,
    held fewer than ``total_len`` programs whenever
    ``total_len % num_qy_loop > total_len // num_qy_loop`` (for example
    ``total_len`` of 7 or 11), and had a zero-sized axis for
    ``total_len < num_qy_loop``.

    NOTE(review): this assumes ``reduce_kernel`` maps a program to the token
    ``program_id(1) + program_id(0) * num_qz_loop`` and skips ids that are
    ``>= total_len``, as the sibling ``dq_reduce_kernel`` does. The previous
    grid could already exceed ``total_len``, so the bounds check must exist,
    but confirm against the kernel.
    """
    num_qy_loop = 4
    # ceil-division so that num_qy_loop * num_qz_loop >= total_len
    num_qz_loop = triton.cdiv(total_len, num_qy_loop)
    grid_reduce = (num_qy_loop, num_qz_loop)

    reduce_kernel[grid_reduce](
        lse,
        m_ij_tiles,
        l_ij_first,
        l_ij_rest,
        m_ij_last,
        o,
        o_tiles_first,
        o_tiles_rest,
        acc_o_scales_first,
        acc_o_scales_rest,
        topk_idx_tile,
        token_index_mapping,
        h * head_tile,
        num_qz_loop,
        TOPK,
        total_len,
        lse.stride(0),
        lse.stride(1),
        m_ij_tiles.stride(0),
        m_ij_tiles.stride(1),
        m_ij_tiles.stride(2),
        l_ij_first.stride(0),
        l_ij_first.stride(1),
        l_ij_first.stride(2),
        l_ij_rest.stride(0),
        l_ij_rest.stride(1),
        l_ij_rest.stride(2),
        o.stride(0),
        o.stride(1),
        o.stride(2),
        o_tiles_first.stride(0),
        o_tiles_first.stride(1),
        o_tiles_first.stride(2),
        o_tiles_first.stride(3),
        o_tiles_rest.stride(0),
        o_tiles_rest.stride(1),
        o_tiles_rest.stride(2),
        o_tiles_rest.stride(3),
        acc_o_scales_first.stride(0),
        acc_o_scales_first.stride(1),
        acc_o_scales_first.stride(2),
        acc_o_scales_rest.stride(0),
        acc_o_scales_rest.stride(1),
        acc_o_scales_rest.stride(2),
        topk_idx_tile.stride(0),
        topk_idx_tile.stride(1),
        topk_idx_tile.stride(2),
        token_index_mapping.stride(0),
        token_index_mapping.stride(1),
        token_index_mapping.stride(2),
        BLOCK_SIZE_T=triton.next_power_of_2(TOPK),
        BLOCK_SIZE_D=triton.next_power_of_2(head_dim),
        num_warps=1,
        num_stages=2,
    )
max(real_num_blocks, TOPK) + + head_tile = 1 + reduce_tile_size = num_blocks - 1 + + valid_lens_all = torch.zeros( + ( + num_kv_heads, + num_blocks, + ), + dtype=torch.int32, + device=q.device, + ) + for h in range(num_kv_heads): + topk_idx_tile = topk_idx[h * head_tile: (h + 1) * head_tile] + topk_idx_nonneg = topk_idx_tile[topk_idx_tile >= 0] + valid_lens = torch.bincount(topk_idx_nonneg.view(-1), minlength=num_blocks) + valid_lens_all[h * head_tile: (h + 1) * head_tile] = valid_lens + + global_max_valid_tokens = valid_lens_all[:, 1:].max() if num_blocks > 1 else valid_lens_all.max() + + o_full = torch.zeros_like(q) + lse_full = torch.full((num_heads, total_len), float("-inf"), dtype=torch.float32, device=q.device) + + # New introduced buffers + topk_idx_permuted_tile = torch.full((head_tile, num_blocks, total_len), -1, dtype=torch.int32, device=q.device) + + token_index_mapping = torch.full((head_tile, num_blocks, total_len), 0, dtype=torch.int32, device=q.device) + # first KV block is computed seaprately + o_tiles_first = torch.zeros((head_tile, 1, total_len, head_dim), dtype=torch.bfloat16, device=q.device) + o_tiles_rest = torch.zeros( + (head_tile, reduce_tile_size, global_max_valid_tokens, head_dim), dtype=torch.bfloat16, device=q.device + ) + + # Statistics buffers + # m_i_tiles: 历史最大, m_diff_tiles: 历史最大和当前最大的差值 + # m_i_cur_tiles: 当前最大, # m_ij_tiles: 考虑当前和历史后的最大 + m_i_cur_tiles: torch.Tensor = torch.full( + (head_tile, num_blocks, total_len), float("-inf"), dtype=torch.float32, device=q.device + ) + + # first KV block is reduced separately + l_ij_first = torch.full((head_tile, 1, total_len), 0, dtype=torch.float32, device=q.device) + acc_o_scales_first = torch.full((head_tile, 1, total_len), 1, dtype=torch.float32, device=q.device) + + l_ij_rest = torch.full( + (head_tile, reduce_tile_size, global_max_valid_tokens), 0, dtype=torch.float32, device=q.device + ) + acc_o_scales_rest = torch.full( + (head_tile, reduce_tile_size, global_max_valid_tokens), 1, 
dtype=torch.float32, device=q.device + ) + + permute_results = {} + permute_results['global_max_valid_tokens'] = global_max_valid_tokens + permute_results['num_blocks'] = num_blocks + permute_results['real_num_blocks'] = real_num_blocks + permute_results['valid_topk_idx_permuted_tile'] = [] + permute_results['valid_lens_all'] = valid_lens_all + permute_results['valid_lens'] = [] + permute_results['valid_start_indices'] = [] + + for h in range(num_heads // head_tile): + q_tile = q[:, h * head_tile: (h + 1) * head_tile] + k_tile = k[:, (h // gqa_deg) * head_tile: ((h // gqa_deg + 1)) * head_tile] + v_tile = v[:, (h // gqa_deg) * head_tile: ((h // gqa_deg + 1)) * head_tile] + o = o_full[:, h * head_tile: (h + 1) * head_tile] + lse = lse_full[h * head_tile: (h + 1) * head_tile] + + permute_min_block_id = 0 + permute_max_block_id = min(permute_min_block_id + num_blocks, num_blocks) + + topk_idx_tile = topk_idx[(h // gqa_deg) * head_tile: ((h // gqa_deg + 1)) * head_tile] + + if h % gqa_deg == 0: + topk_idx_permuted_tile = build_block_to_token_triton( + topk_idx_permuted_tile, topk_idx_tile, permute_min_block_id, permute_max_block_id, padding_value=-1 + ) + + valid_topk_idx_permuted_tile = topk_idx_permuted_tile[topk_idx_permuted_tile != -1] + valid_lens = valid_lens_all[(h // gqa_deg) * head_tile, :] + valid_start_indices = torch.nn.functional.pad(valid_lens.cumsum(0)[:-1], (1, 0), value=0) + + index_mapping( + token_index_mapping, valid_topk_idx_permuted_tile, valid_lens, valid_start_indices, num_blocks + ) + + permute_results['valid_topk_idx_permuted_tile'].append(valid_topk_idx_permuted_tile) + permute_results['valid_lens'].append(valid_lens) + permute_results['valid_start_indices'].append(valid_start_indices) + + m_ij_tiles, m_ij_last = online_softmax( + q_tile, + k_tile, + m_i_cur_tiles, + valid_topk_idx_permuted_tile, + valid_lens, + valid_start_indices, + 0, + total_len, + block_size, + num_blocks, + head_tile, + head_dim, + sm_scale, + cu_seqlens_q, + 
cu_seqlens_k, + ) + + m_ij_tiles[:, :, :] = m_ij_tiles[:, :, 0][:, :, None] + m_ij_last[:, :] = m_ij_last[:, 0] + for compute_min_block_id in range(min(2, num_blocks)): + if compute_min_block_id == 0: + cur_max_valid_tokens = total_len + cur_valid_lens = valid_lens[0] + cur_valid_start_indices = valid_start_indices[0] + o_tiles = o_tiles_first + l_ij = l_ij_first + acc_o_scales = acc_o_scales_first + compute_tile_size = 1 + else: + cur_max_valid_tokens = valid_lens[compute_min_block_id:].max() + cur_valid_lens = valid_lens[compute_min_block_id:] + cur_valid_start_indices = valid_start_indices[compute_min_block_id:] + o_tiles = o_tiles_rest + l_ij = l_ij_rest + acc_o_scales = acc_o_scales_rest + compute_tile_size = num_blocks - 1 + + # launch kernel + qkv_kernel( + q_tile, + k_tile, + v_tile, + o_tiles, + acc_o_scales, + m_ij_tiles, + l_ij, + token_index_mapping, + valid_topk_idx_permuted_tile, + cur_valid_lens, + cur_valid_start_indices, + compute_min_block_id, + cur_max_valid_tokens, + head_tile, + compute_tile_size, + cu_seqlens_q, + cu_seqlens_k, + head_dim, + sm_scale, + block_size, + ) + + reduce_output( + lse, + o, + o_tiles_first, + o_tiles_rest, + m_ij_tiles, + l_ij_first, + l_ij_rest, + m_ij_last, + acc_o_scales_first, + acc_o_scales_rest, + topk_idx_tile, + token_index_mapping, + h, + head_tile, + total_len, + TOPK, + head_dim, + ) + + o_full[:, h * head_tile: (h + 1) * head_tile] = o + lse_full[h * head_tile: (h + 1) * head_tile] = lse + + if h % gqa_deg == 0: + fused_fill(topk_idx_permuted_tile, m_i_cur_tiles) + + return o_full, lse_full, permute_results + + +@triton.jit +def dq_compute_kernel( + q_ptr, + k_ptr, + v_ptr, + lse_ptr, + delta_ptr, + do_ptr, + dq_tiles_ptr, + token_index_mapping_ptr, + selected_tokens_ptr, + valid_lens_ptr, + valid_start_indices_ptr, + cur_max_valid_tokens, + compute_min_block_id, + head_tile, + num_blocks, + HEAD_DIM, + cu_seqlens_k, + num_dq_blocks, + sm_scale, + debug_ptr, + stride_qn, + stride_qh, + stride_qd, + 
stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_tim_h, + stride_tim_b, + stride_tim_n, + stride_dqth, + stride_dqtb, + stride_dqtn, + stride_dqtd, + BLOCK_SIZE_Q: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_D: tl.constexpr, +): + + pid_block = tl.program_id(0) + pid_q = tl.program_id(1) # token + # seq packing is not supported yet + q_start = 0 + k_start = 0 + + k_len = tl.load(cu_seqlens_k + 1) - k_start + + start_id = tl.load(valid_start_indices_ptr + pid_block) + valid_tokens = tl.load(valid_lens_ptr + pid_block) + if num_dq_blocks * pid_q * BLOCK_SIZE_Q >= valid_tokens: + return + + c = (pid_block + compute_min_block_id) * BLOCK_SIZE_K + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn, + shape=(k_len, HEAD_DIM), + strides=(stride_kn, stride_kd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + + # load k + k = tl.load(tl.advance(k_ptrs, (c, 0)), boundary_check=(1, 0), padding_option="zero") + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn, + shape=(HEAD_DIM, k_len), + strides=(stride_vd, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + + # load v + v = tl.load(tl.advance(v_ptrs, (0, c)), boundary_check=(0, 1), padding_option="zero") + + qk_scale = sm_scale * 1.44269504 + + off_k = tl.arange(0, BLOCK_SIZE_K) + off_d = tl.arange(0, BLOCK_SIZE_D) + for j in range(num_dq_blocks): + pid_q_j = pid_q * num_dq_blocks + j + if pid_q_j * BLOCK_SIZE_Q < valid_tokens: + # one thread block for one KV block, a subset of selected tokens + st_offs = start_id + (q_start + pid_q_j * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)) + # st should be in shape [BLOCK_SIZE_Q] + st_mask = (pid_q_j * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)) < valid_tokens + + st = tl.load(selected_tokens_ptr + st_offs, mask=st_mask, other=-1) + tl.store(debug_ptr + tl.arange(0, BLOCK_SIZE_Q), st_offs) + # otherwise, st selects a set of q tokens, 
# Gather-and-sum reduction of per-block dq partials into dq.
#
# One program handles one query token. For each of the token's TOPK selected
# block ids it looks up the token's compact row in that block via
# token_index_mapping, reads the matching row from the "first" buffer
# (block id 0) or the "rest" buffer (block id t, stored at position t - 1),
# and accumulates in float32 before storing to dq.
#
# NOTE(review): no *_h stride (stride_dqtfh, stride_dqtrh, stride_dqh,
# stride_th, stride_tim_h) is applied, so every pointer addresses head 0 of
# whatever tensor it is given. The last-dim strides (stride_dqtfd,
# stride_dqtrd, stride_dqd) and BLOCK_SIZE_T are also unused.
@triton.jit
def dq_reduce_kernel(
    dq_buffer_first_ptr,  # [H, 1, N, D]
    dq_buffer_rest_ptr,  # [H, B, N, D]
    dq_ptr,  # dq: n x h x d
    t_ptr,  # topk_idx: h x n x k
    token_index_mapping_ptr,
    num_qz_loop,
    TOPK,
    total_len,
    # stride
    stride_dqtfh,
    stride_dqtfb,
    stride_dqtfn,
    stride_dqtfd,
    stride_dqtrh,
    stride_dqtrb,
    stride_dqtrn,
    stride_dqtrd,
    stride_dqn,
    stride_dqh,
    stride_dqd,
    stride_th,
    stride_tn,
    stride_tk,
    stride_tim_h,
    stride_tim_b,
    stride_tim_n,
    # META parameters
    BLOCK_SIZE_T: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    pid_qy = tl.program_id(0)
    pid_q = tl.program_id(1)  # token

    # flatten the 2-D grid into a token id: grid row pid_qy covers the ids
    # [pid_qy * num_qz_loop, (pid_qy + 1) * num_qz_loop)
    pid_q_j = pid_q + pid_qy * num_qz_loop
    # the grid may hold more programs than tokens
    if pid_q_j < total_len:
        t_ptr_j = t_ptr + pid_q_j * stride_tn

        off_d = tl.arange(0, BLOCK_SIZE_D)
        # NOTE(review): off_d is added without stride_dqd, i.e. the last dim of
        # dq is assumed to have unit stride.
        dq_ptrs = dq_ptr + pid_q_j * stride_dqn + off_d
        acc_dq = tl.zeros((BLOCK_SIZE_D,), dtype=tl.float32)

        for block_id in range(TOPK):
            # block id chosen by this token in top-k slot `block_id`; the mask is
            # always true inside range(TOPK), so `other=-1` never applies here
            t = tl.load(t_ptr_j + block_id * stride_tk, mask=block_id < TOPK, other=-1)
            # -1 marks an unused top-k slot
            if t != -1:
                if t == 0:
                    # block 0 lives alone in the "first" buffer
                    dq_buffer_ptr = dq_buffer_first_ptr
                    stride_dqtb = stride_dqtfb
                    stride_dqtn = stride_dqtfn
                    real_block_pos = 0
                else:
                    # blocks 1.. live in the "rest" buffer, shifted down by one
                    dq_buffer_ptr = dq_buffer_rest_ptr
                    stride_dqtb = stride_dqtrb
                    stride_dqtn = stride_dqtrn
                    real_block_pos = t - 1

                # compact row of this token inside block t
                token_index_mapping_ptrs = (
                    token_index_mapping_ptr + t.to(tl.int64) * stride_tim_b + (pid_q_j) * stride_tim_n
                )
                real_token_index = tl.load(token_index_mapping_ptrs)

                # NOTE(review): off_d is again added with an implicit unit stride,
                # and the load is unmasked over all BLOCK_SIZE_D lanes - presumably
                # safe only when head_dim equals BLOCK_SIZE_D; confirm.
                dq_buffer_ptrs = (
                    dq_buffer_ptr + real_block_pos.to(tl.int64) * stride_dqtb + (real_token_index) * stride_dqtn + off_d
                )

                dq_buffers = tl.load(dq_buffer_ptrs)
                acc_dq = dq_buffers + acc_dq

        # NOTE(review): off_d = arange(0, BLOCK_SIZE_D), so this mask is always
        # true and does not guard lanes beyond the real head_dim.
        tl.store(dq_ptrs, acc_dq, mask=off_d < BLOCK_SIZE_D)
def backward_dq_opt_per_seq(
    q,  # [total_len, num_heads, head_dim]
    k,  # [total_len, num_k_heads, head_dim]
    v,  # [total_len, num_k_heads, head_dim]
    topk_idx,  # [num_k_heads, total_len, topk]
    lse,  # [num_heads, total_len]
    delta,  # [num_heads, total_len]
    do,  # [total_len, num_heads, head_dim]
    dq,  # [total_len, num_heads, head_dim]
    cu_seqlens_q,
    cu_seqlens_k,
    num_k_heads,
    num_share_q_heads,
    head_dim,
    topk,
    sm_scale,
    block_size,
    permute_results,
):
    """Compute dq for a single sequence, one query head at a time.

    For every query head the work is done in two stages:

    1. ``dq_compute_kernel`` writes per-KV-block partial dq rows into two
       scratch buffers: one for KV block 0 (every token may attend to it) and
       one for all remaining blocks.
    2. ``dq_reduce_kernel`` sums, for every token, the partial rows of its
       selected blocks and writes the result into ``dq``.

    Args:
        q, do, dq: per-query-head tensors, sliced by query head index.
        k, v, topk_idx: per-KV-head tensors, sliced by ``h // num_share_q_heads``.
        lse, delta: per-query-head statistics from the forward pass.
        permute_results: dict produced by the forward pass for this sequence;
            the keys ``'global_max_valid_tokens'``, ``'num_blocks'``,
            ``'valid_topk_idx_permuted_tile'``, ``'valid_lens'`` and
            ``'valid_start_indices'`` are read, the list-valued ones being
            indexed by KV head.

    Returns:
        ``dq``, written in place.
    """
    head_tile = 1
    total_len = topk_idx.shape[1]
    global_max_valid_tokens = permute_results['global_max_valid_tokens']
    num_blocks = permute_results['num_blocks']
    reduce_tile_size = num_blocks - 1
    dq_buffer_first = torch.zeros((head_tile, 1, total_len, head_dim), dtype=torch.bfloat16, device=dq.device)
    dq_buffer_rest = torch.zeros(
        (head_tile, reduce_tile_size, global_max_valid_tokens, head_dim), dtype=torch.bfloat16, device=dq.device
    )

    num_heads = num_share_q_heads * num_k_heads

    # Launch configuration does not depend on the head or the block group, so
    # it is computed once instead of on every inner iteration.
    BLOCK_SIZE_Q = 128
    num_dq_blocks = 8
    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
    BLOCK_SIZE_K = triton.next_power_of_2(block_size)
    num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU)
    # scratch target for the kernel's debug store; its content is never read here
    debug = torch.zeros((BLOCK_SIZE_Q,), dtype=torch.int32, device=dq.device)

    # Reduction grid. dq_reduce_kernel maps a program to the token
    # program_id(1) + program_id(0) * num_qz_loop and skips ids >= total_len,
    # so rows of ceil(total_len / num_qy_loop) programs cover every token.
    # The former floor-division grid with one optional extra row missed the
    # tail tokens whenever total_len % 4 > total_len // 4 (e.g. 7 or 11) and
    # was empty for total_len < 4.
    num_qy_loop = 4
    num_qz_loop = triton.cdiv(total_len, num_qy_loop)
    grid_reduce = (num_qy_loop, num_qz_loop)

    token_index_mapping = torch.full((head_tile, num_blocks, total_len), 0, dtype=torch.int32, device=q.device)
    for h in range(num_heads // head_tile):
        # query heads that share a KV head also share its permutation results
        kv_h = h // num_share_q_heads
        valid_topk_idx_permuted_tile = permute_results['valid_topk_idx_permuted_tile'][kv_h]
        valid_lens = permute_results['valid_lens'][kv_h]
        valid_start_indices = permute_results['valid_start_indices'][kv_h]

        index_mapping(token_index_mapping, valid_topk_idx_permuted_tile, valid_lens, valid_start_indices, num_blocks)

        q_heads = slice(h * head_tile, (h + 1) * head_tile)
        kv_heads = slice(kv_h * head_tile, (kv_h + 1) * head_tile)
        q_tile = q[:, q_heads]
        k_tile = k[:, kv_heads]
        v_tile = v[:, kv_heads]
        do_tile = do[:, q_heads]
        lse_tile = lse[q_heads]
        topk_idx_tile = topk_idx[kv_heads]
        delta_tile = delta[q_heads]
        # a view of dq: the reduce kernel writes straight into dq through it
        dq_tile = dq[:, q_heads]

        # pass 0 handles KV block 0 alone; pass 1 handles all remaining blocks
        for compute_min_block_id in range(min(2, num_blocks)):
            if compute_min_block_id == 0:
                compute_tile_size = 1
                cur_max_valid_tokens = total_len
                cur_valid_lens = valid_lens[0]
                cur_valid_start_indices = valid_start_indices[0]
                dq_buffer = dq_buffer_first
            else:
                compute_tile_size = num_blocks - 1
                cur_max_valid_tokens = valid_lens[compute_min_block_id:].max()
                cur_valid_lens = valid_lens[compute_min_block_id:]
                cur_valid_start_indices = valid_start_indices[compute_min_block_id:]
                dq_buffer = dq_buffer_rest

            grid_dq = (
                compute_tile_size,
                triton.cdiv(cur_max_valid_tokens, BLOCK_SIZE_Q * num_dq_blocks),
            )
            dq_compute_kernel[grid_dq](
                q_tile,
                k_tile,
                v_tile,
                lse_tile,
                delta_tile,
                do_tile,
                dq_buffer,
                token_index_mapping,
                valid_topk_idx_permuted_tile,
                cur_valid_lens,
                cur_valid_start_indices,
                cur_max_valid_tokens,
                compute_min_block_id,
                head_tile,
                num_blocks,
                head_dim,
                cu_seqlens_k,
                num_dq_blocks,
                sm_scale,
                debug,
                q_tile.stride(0),
                q_tile.stride(1),
                q_tile.stride(2),
                k_tile.stride(0),
                k_tile.stride(1),
                k_tile.stride(2),
                v_tile.stride(0),
                v_tile.stride(1),
                v_tile.stride(2),
                token_index_mapping.stride(0),
                token_index_mapping.stride(1),
                token_index_mapping.stride(2),
                dq_buffer.stride(0),
                dq_buffer.stride(1),
                dq_buffer.stride(2),
                dq_buffer.stride(3),
                BLOCK_SIZE_Q=BLOCK_SIZE_Q,
                BLOCK_SIZE_K=BLOCK_SIZE_K,
                BLOCK_SIZE_D=BLOCK_SIZE_D,
                num_warps=num_warps,
                num_stages=num_stages,
            )

        dq_reduce_kernel[grid_reduce](
            dq_buffer_first,
            dq_buffer_rest,
            dq_tile,
            topk_idx_tile,
            token_index_mapping,
            num_qz_loop,
            topk,
            total_len,
            dq_buffer_first.stride(0),
            dq_buffer_first.stride(1),
            dq_buffer_first.stride(2),
            dq_buffer_first.stride(3),
            dq_buffer_rest.stride(0),
            dq_buffer_rest.stride(1),
            dq_buffer_rest.stride(2),
            dq_buffer_rest.stride(3),
            dq_tile.stride(0),
            dq_tile.stride(1),
            dq_tile.stride(2),
            topk_idx_tile.stride(0),
            topk_idx_tile.stride(1),
            topk_idx_tile.stride(2),
            token_index_mapping.stride(0),
            token_index_mapping.stride(1),
            token_index_mapping.stride(2),
            BLOCK_SIZE_T=triton.next_power_of_2(topk),
            BLOCK_SIZE_D=BLOCK_SIZE_D,
            num_warps=1,
            num_stages=2,
        )
        # No copy back is needed: dq_tile aliases dq[:, q_heads].

    return dq
d_ptr, # Delta: qh x n + do_ptr, + dk_ptr, # DK: sh x n x kh x d + dv_ptr, # DK: sh x n x kh x d + # seqlens + cu_seqlens_q, # [batch_size + 1] + cu_seqlens_k, # [batch_size + 1] + cu_seqblocks, # [batch_size + 1] + cu_topk_q_count, # [kh, total_blocks] + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + TOPK, + # sm_scale + sm_scale, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_tqh, + stride_tqn, + stride_ctqh, + stride_ctqn, + stride_lh, + stride_ln, + stride_dh, + stride_dn, + stride_don, + stride_doh, + stride_dod, + stride_dks, + stride_dkn, + stride_dkh, + stride_dkd, + stride_dvs, + stride_dvn, + stride_dvh, + stride_dvd, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_kh = pid_h // NUM_SHARE_Q_HEADS + pid_sh = pid_h % NUM_SHARE_Q_HEADS + pid_k = tl.program_id(2) + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + if BLOCK_SIZE_K * pid_k >= k_len: + return + # get topk_q_idx + b_start = tl.load(cu_seqblocks + pid_b) # how many blocks before current sequence + act_q_start = tl.load(cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k) * stride_ctqn) + act_q_end = tl.load(cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k + 1) * stride_ctqn) + act_q_len = act_q_end - act_q_start + tq_ptr = tq_ptr + pid_kh * stride_tqh + act_q_start * stride_tqn + # init pointers + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(k_len, HEAD_DIM), + strides=(stride_kn, stride_kd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + 
block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dk_ptrs = tl.make_block_ptr( + base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks, + shape=(k_len, HEAD_DIM), + strides=(stride_dkn, stride_dkd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dv_ptrs = tl.make_block_ptr( + base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs, + shape=(k_len, HEAD_DIM), + strides=(stride_dvn, stride_dvd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + # offsets + off_q = tl.arange(0, BLOCK_SIZE_Q) + off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K + off_d = tl.arange(0, BLOCK_SIZE_D) + # load k v and keep in SRAM + k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") + v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") + # init dk dv + dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + # init ptrs + q_ptrs = q_ptr + q_start * stride_qn + pid_h * stride_qh + off_d[None, :] * stride_qd + do_ptrs = do_ptr + q_start * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod + d_ptrs = d_ptr + q_start * stride_dn + pid_h * stride_dh + lse_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh + # loop for q blocks + for i in range(0, act_q_len, BLOCK_SIZE_Q): + # load + idx_q = tl.load(tq_ptr + i + off_q, mask=off_q < act_q_len - i, other=0).to(tl.int32) + q = tl.load( + q_ptrs + idx_q[:, None] * stride_qn, + mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :], + other=0, + ) + do = tl.load( + do_ptrs + idx_q[:, None] * stride_don, 
+ mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :], + other=0, + ) + lse = tl.load( + lse_ptrs + idx_q[:, None] * stride_ln, + mask=(off_q < act_q_len - i)[:, None], + other=0, + ) + d = tl.load( + d_ptrs + idx_q[:, None] * stride_dn, + mask=(off_q < act_q_len - i)[:, None], + other=0, + ) + # compute qk + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where(idx_q[:, None] >= off_k[None, :], float(0.0), float("-inf")) + qk += tl.dot(q, k.T) * qk_scale + # compute p, ds + p = tl.exp2(qk - lse) + dp = tl.dot(do, v.T) + ds = sm_scale * p * (dp - d) + # cast dtype + p = p.to(do.dtype) + ds = ds.to(q.dtype) + # update dk and dv + dk += tl.dot(ds.T, q) + dv += tl.dot(p.T, do) + # save dk dv + tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1)) + tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1)) + + +def _topk_sparse_attention_bwd_opt( + o: torch.Tensor, + do: torch.Tensor, + lse: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + topk_idx: torch.Tensor, + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float, + permute_results, +): + + assert block_size in {16, 32, 64, 128, 256} + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + v_len, num_v_heads, head_dim = v.shape + o_len, num_o_heads, head_dim = o.shape + num_share_q_heads = num_q_heads // num_k_heads + topk = topk_idx.shape[-1] + # compute D + delta = torch.zeros([num_o_heads, o_len], device=o.device, dtype=torch.float32) + BLOCK_SIZE_O = 256 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_O, IS_HOPPER_GPU) + grid = (triton.cdiv(o_len, BLOCK_SIZE_O), num_o_heads) + backward_sum_o_do[grid]( + o, + do, + delta, + o_len, + head_dim, + o.stride(0), + o.stride(1), + o.stride(2), + do.stride(0), + do.stride(1), + do.stride(2), + 
delta.stride(0), + delta.stride(1), + BLOCK_SIZE_O=BLOCK_SIZE_O, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + # count active querys for each key block, shape: (num_k_heads, total_k_blocks) + seqlens = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + seqblocks = torch.ceil(seqlens / block_size).to(torch.int32) + cu_seqblocks = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=topk_idx.device), + torch.cumsum(seqblocks, dim=0), + ] + ).to(torch.int32) + + topk_q_count = torch.cat( + [ + permute_results[i]['valid_lens_all'][:, : permute_results[i]['real_num_blocks']] + for i in range(len(permute_results)) + ], + dim=1, + ) + + cu_topk_q_count = torch.cat( + [ + torch.zeros(topk_q_count.shape[0], 1, dtype=torch.int32, device=topk_idx.device), + torch.cumsum(topk_q_count, dim=-1), + ], + dim=-1, + ).to(torch.int32) + # active query idx for each key block + # how to get active query idx for sequence b, head h, kv block i? + topk_q_idx = reorder_topk_idx(topk_idx, cu_topk_q_count, cu_seqlens_q, cu_seqblocks, block_size) + # compute dk dv + dk = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype) + dv = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype) + batch_size = cu_seqlens_q.shape[0] - 1 + BLOCK_SIZE_K = triton.next_power_of_2(block_size) + BLOCK_SIZE_Q = 64 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU) + grid = (batch_size, num_q_heads, triton.cdiv(max_seqlen_k, BLOCK_SIZE_K)) + backward_dkdv[grid]( + q, + k, + v, + topk_q_idx, + lse, + delta, + do, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + cu_seqblocks, + cu_topk_q_count, + num_k_heads, + num_share_q_heads, + head_dim, + topk, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + topk_q_idx.stride(0), + 
class FSATopkSparseAttention(torch.autograd.Function):
    """Autograd entry point for FSA top-k sparse attention.

    ``forward`` validates dtypes, calls ``_topk_sparse_attention_fwd_opt`` and
    stores on ``ctx`` everything ``backward`` needs; ``backward`` delegates to
    ``_topk_sparse_attention_bwd_opt``.
    """

    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,  # [total_len, num_q_heads, head_dim]
        k: torch.Tensor,  # [total_len, num_k_heads, head_dim]
        v: torch.Tensor,  # [total_len, num_k_heads, head_dim]
        topk_idx: torch.Tensor,  # [num_kv_heads, total_len, topk]
        block_size: int,
        cu_seqlens_q: torch.Tensor,
        cu_seqlens_k: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        sm_scale: Optional[float] = None,
    ):
        """Run the forward pass and record state for the backward pass.

        Args:
            q, k, v: fp16 or bf16 tensors sharing one dtype.
            topk_idx: int32 selected block indices.
            block_size: key/value block size.
            cu_seqlens_q, cu_seqlens_k: int32 cumulative sequence lengths.
            max_seqlen_q, max_seqlen_k: longest query / key sequence length
                (the public wrappers pass Python ints obtained via ``.item()``).
            sm_scale: softmax scale; ``None`` means ``1 / sqrt(head_dim)``.

        Returns:
            The attention output ``o`` produced by the forward kernel.
        """
        # dtype check
        assert q.dtype == torch.bfloat16 or q.dtype == torch.float16
        assert q.dtype == k.dtype and k.dtype == v.dtype
        assert topk_idx.dtype == torch.int32
        assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
        # softmax scale
        if sm_scale is None:
            sm_scale = 1 / math.sqrt(q.shape[-1])

        o, lse, permute_results = _topk_sparse_attention_fwd_opt(
            q,
            k,
            v,
            topk_idx,
            block_size,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            sm_scale,
        )

        # Tensors go through save_for_backward; plain Python values and the
        # auxiliary forward result are attached to ctx as attributes.
        ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, topk_idx)
        ctx.permute_results = permute_results
        ctx.sm_scale = sm_scale
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.block_size = block_size
        return o

    @staticmethod
    def backward(ctx, do: torch.Tensor, *args) -> Any:
        """Compute gradients for ``q``, ``k`` and ``v``.

        Returns one entry per ``forward`` input (10 in total): ``dq``, ``dk``,
        ``dv`` followed by ``None`` for the seven non-differentiable inputs.
        """
        q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, topk_idx = ctx.saved_tensors
        block_size = ctx.block_size
        assert block_size in {16, 32, 64, 128, 256}

        dq, dk, dv = _topk_sparse_attention_bwd_opt(
            o,
            do,
            lse,
            q,
            k,
            v,
            topk_idx,
            block_size,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.sm_scale,
            ctx.permute_results,
        )
        # topk_idx, block_size, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
        # max_seqlen_k and sm_scale receive no gradient.
        return dq, dk, dv, None, None, None, None, None, None, None
def FSA_topk_sparse_attention_varlen(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    topk_idx: torch.Tensor,
    block_size: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    softmax_scale: Optional[float] = None,
) -> torch.Tensor:
    """Top-k sparse attention where queries and keys have their own lengths.

    Intended for extend/prefill, where the Q and K sequences of a batch entry
    may differ in length.

    Args:
        q (torch.Tensor): shape [total_q, num_q_heads, head_dim]
        k (torch.Tensor): shape [total_k, num_kv_heads, head_dim]
        v (torch.Tensor): shape [total_k, num_kv_heads, head_dim]
        topk_idx (torch.Tensor): selected block index per query, shape
            [num_kv_heads, total_q, topk]; -1 marks padding.
        block_size (int): key/value block size.
        cu_seqlens_q (torch.Tensor): shape [batch_size + 1], cumulative Q lengths.
        cu_seqlens_k (torch.Tensor): shape [batch_size + 1], cumulative K lengths.
        softmax_scale (Optional[float], optional): ``None`` means 1/sqrt(head_dim).

    Returns:
        torch.Tensor: attention output, shape [total_q, num_q_heads, head_dim]
    """
    # Adjacent differences of the cumulative lengths are the per-sequence
    # lengths; .item() turns the maximum into a Python scalar.
    longest_q = torch.diff(cu_seqlens_q).max().item()
    longest_k = torch.diff(cu_seqlens_k).max().item()
    # autograd.Function.apply takes positional arguments only.
    attention_args = (
        q,
        k,
        v,
        topk_idx,
        block_size,
        cu_seqlens_q,
        cu_seqlens_k,
        longest_q,
        longest_k,
        softmax_scale,
    )
    return FSATopkSparseAttention.apply(*attention_args)
def get_num_warps_stages(head_dim, block_size, is_hopper_gpu):
    """Pick ``(num_warps, num_stages)`` for a kernel launch.

    A dimension counts as "large" when it exceeds 64. With no large dimension
    the result is ``(2, 2)``. On Hopper exactly one large dimension gives
    ``(4, 3)``; every other case gives ``(8, 3)``.
    """
    large_dims = int(head_dim > 64) + int(block_size > 64)
    if large_dims == 0:
        return 2, 2
    if is_hopper_gpu and large_dims == 1:
        return 4, 3
    return 8, 3
k block size + BLOCK_SIZE_D: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, + BLOCK_SIZE_T: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid = tl.program_id(0) + + Q = MAX_SEQ_LEN // num_q_loop + HK = NUM_KV_HEADS // num_k_loop + + # 第几个 (b, kh_chunk, q_chunk) + pid_b = pid // (HK * Q) + pid_kh_chunk = (pid % (HK * Q)) // Q # 每个block处理num_k_loop个KV head + pid_q = pid % Q + + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + + if pid_q * num_q_loop >= q_len: + return + real_q_loop = min(num_q_loop, q_len - pid_q * num_q_loop) + + for kh_offset in range(num_k_loop): + pid_kh = pid_kh_chunk * num_k_loop + kh_offset + pid_h = pid_kh * NUM_SHARE_Q_HEADS + + for j in range(real_q_loop): + pid_q_j = pid_q * num_q_loop + j + # init topk idx pointer + off_t = tl.arange(0, BLOCK_SIZE_T) + t_ptr_j = t_ptr + (q_start + pid_q_j) * stride_tn + pid_kh * stride_th + topk_idx = tl.load(t_ptr_j + off_t * stride_tk, mask=off_t < TOPK, other=-1) + + """Removed causal attention, which should be: + real_topk = tl.sum( + tl.where((topk_idx >= 0) & (topk_idx <= pid_q_j // block_size), 1, 0), + axis=0, + ) + """ + # real_topk = tl.sum( + # tl.where((topk_idx >= 0), 1, 0), + # axis=0, + # ) + real_topk = tl.sum( + tl.where((topk_idx >= 0) & (topk_idx <= pid_q_j // block_size), 1, 0), + axis=0, + ) + # init qkv pointer + q_ptrs = tl.make_block_ptr( + base=q_ptr + (q_start + pid_q_j) * stride_qn + pid_h * stride_qh, + shape=(NUM_SHARE_Q_HEADS, HEAD_DIM), + strides=(stride_qh, stride_qd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), + order=(1, 0), + ) + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(HEAD_DIM, k_len), + strides=(stride_kd, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + 
order=(0, 1), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + # load q + q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") + # init statistics + off_h = tl.arange(0, BLOCK_SIZE_H) + off_k = tl.arange(0, BLOCK_SIZE_K) + m_i = tl.full((BLOCK_SIZE_H,), float("-inf"), dtype=tl.float32) + lse_i = tl.full((BLOCK_SIZE_H,), float("-inf"), dtype=tl.float32) + acc_o = tl.full((BLOCK_SIZE_H, BLOCK_SIZE_D), 0, dtype=tl.float32) + # sparse attention + for i in range(real_topk): + # get current block start index + c = tl.load(t_ptr_j).to(tl.int32) * BLOCK_SIZE_K + t_ptr_j = t_ptr_j + stride_tk + # load k + k = tl.load(tl.advance(k_ptrs, (0, c)), boundary_check=(1, 0), padding_option="zero") + # compute qk + qk = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where((pid_q_j >= c + off_k)[None, :], 0, float("-inf")) + # [BLOCK_SIZE_H, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_K] -> [BLOCK_SIZE_H, BLOCK_SIZE_K] + qk += tl.dot(q, k) * qk_scale + # compute m_ij and l_ij + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + p = tl.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, axis=1) + # scale acc_o + acc_o_scale = tl.exp2(m_i - m_ij) + acc_o = acc_o * acc_o_scale[:, None] + # load v and update acc_o + v = tl.load(tl.advance(v_ptrs, (c, 0)), boundary_check=(0, 1), padding_option="zero") + p = p.to(v.dtype) + acc_o += tl.dot(p, v) + # update statistics + m_i = m_ij + lse_i = m_ij + tl.math.log2(tl.exp2(lse_i - m_ij) + l_ij) + + # final scale + acc_o = acc_o * tl.exp2(m_i - lse_i)[:, None] + # save output + o_ptrs = tl.make_block_ptr( + base=o_ptr + (q_start + pid_q_j) * stride_on + pid_h * stride_oh, + shape=(NUM_SHARE_Q_HEADS, HEAD_DIM), + strides=(stride_oh, stride_od), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), + order=(1, 0), + ) + tl.store(o_ptrs, 
@triton.jit
def backward_sum_o_do(
    o_ptr,  # O: n x h x d
    do_ptr,  # dO: n x h x d
    delta_ptr,  # D: h x n
    o_len,
    HEAD_DIM,
    stride_on,
    stride_oh,
    stride_od,
    stride_don,
    stride_doh,
    stride_dod,
    stride_dh,
    stride_dn,
    BLOCK_SIZE_O: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    """Write delta[h, n] = sum over d of O[n, h, d] * dO[n, h, d].

    Program axis 0 selects a tile of BLOCK_SIZE_O rows, axis 1 selects the head.
    Rows >= o_len and columns >= HEAD_DIM are masked out (loaded as 0, not stored).
    """
    head = tl.program_id(1)
    rows = tl.program_id(0) * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O)
    cols = tl.arange(0, BLOCK_SIZE_D)
    # validity masks shared by both loads and the final store
    row_ok = rows < o_len
    tile_ok = row_ok[:, None] & (cols[None, :] < HEAD_DIM)
    o_tile = tl.load(
        o_ptr + head * stride_oh + rows[:, None] * stride_on + cols[None, :] * stride_od,
        mask=tile_ok,
        other=0,
    )
    do_tile = tl.load(
        do_ptr + head * stride_doh + rows[:, None] * stride_don + cols[None, :] * stride_dod,
        mask=tile_ok,
        other=0,
    )
    # accumulate the row-wise dot product in float32
    row_dot = tl.sum(o_tile.to(tl.float32) * do_tile.to(tl.float32), axis=1)
    tl.store(delta_ptr + head * stride_dh + rows * stride_dn, row_dot, mask=row_ok)
def count_query(
    topk_idx: torch.Tensor,
    cu_seqlens: torch.Tensor,
    cu_seqblocks: torch.Tensor,
    block_size: int,
):
    """Launch ``count_kernel`` and return its int32 output buffer.

    Args:
        topk_idx: 3-D tensor unpacked as (num_kv_heads, total_len, topk).
        cu_seqlens: cumulative sequence lengths; one grid column per sequence.
        cu_seqblocks: cumulative block counts; its last entry is the width of
            the returned tensor.
        block_size: not used here; kept for interface compatibility.

    Returns:
        Tensor of shape (num_kv_heads, cu_seqblocks[-1]), dtype int32, on
        ``topk_idx.device``, zero-initialised and then filled by the kernel.
    """
    num_kv_heads, _, topk = topk_idx.shape
    batch_size = cu_seqlens[1:].shape[0]
    blocks_per_seq = cu_seqblocks[1:] - cu_seqblocks[:-1]
    # Tile sizes are rounded up to powers of two; the histogram width is the
    # largest per-sequence block count plus 2 before rounding.
    tile_k = triton.next_power_of_2(topk)
    tile_n = triton.next_power_of_2(4096 // tile_k)
    tile_r = triton.next_power_of_2(blocks_per_seq.max().item() + 2)
    counts = torch.zeros(num_kv_heads, cu_seqblocks[-1], dtype=torch.int32, device=topk_idx.device)
    # one program per (kv head, sequence)
    count_kernel[(num_kv_heads, batch_size)](
        topk_idx,
        counts,
        cu_seqlens,
        cu_seqblocks,
        topk,
        *topk_idx.stride(),
        *counts.stride(),
        BLOCK_SIZE_N=tile_n,
        BLOCK_SIZE_K=tile_k,
        BLOCK_SIZE_R=tile_r,
        num_warps=4,
        num_stages=3,
    )
    return counts
@triton.jit
def save_topk_idx_kernel(
    p_ptr,
    t_ptr,
    cu_seqblocks,
    cu_topk_q_count,
    n_len,
    stride_pb,
    stride_ph,
    stride_pn,
    stride_th,
    stride_tn,
    stride_ch,
    stride_cn,
    BLOCK_SIZE_N: tl.constexpr,
):
    """Copy the last ``c_len`` entries of a padded row into the packed output.

    Grid: (batch, head, chunk). For sequence ``pid_b`` and head ``pid_h`` the
    kernel reads the cumulative counts at the sequence's first and
    one-past-last block to get ``c_start``/``c_end``. It then copies
    ``c_len = c_end - c_start`` values from row ``(pid_b, pid_h)`` of ``p_ptr``,
    starting at element ``n_len - c_len``, into row ``pid_h`` of ``t_ptr``
    starting at element ``c_start``. Each program handles one chunk of
    BLOCK_SIZE_N elements.

    NOTE(review): reorder_topk_idx launches this with the argsorted padded
    query indices as ``p_ptr``; reading from the tail presumably skips the
    -1 padding entries that sort first. Confirm against the caller.
    """
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    pid_n = tl.program_id(2)
    # block range of this sequence, then the cumulative counts at both ends
    q_block_start = tl.load(cu_seqblocks + pid_b)
    q_block_end = tl.load(cu_seqblocks + pid_b + 1)
    c_start = tl.load(cu_topk_q_count + pid_h * stride_ch + q_block_start * stride_cn)
    c_end = tl.load(cu_topk_q_count + pid_h * stride_ch + q_block_end * stride_cn)
    c_len = c_end - c_start
    # nothing to copy for this (sequence, head)
    if c_len <= 0:
        return
    # this chunk starts past the end of the data
    if pid_n * BLOCK_SIZE_N >= c_len:
        return
    # init ptrs
    # source: 1-D view over the last c_len elements of the padded row
    p_ptrs = tl.make_block_ptr(
        base=p_ptr + pid_b * stride_pb + pid_h * stride_ph + (n_len - c_len) * stride_pn,
        shape=(c_len,),
        strides=(stride_pn,),
        offsets=(pid_n * BLOCK_SIZE_N,),
        block_shape=(BLOCK_SIZE_N,),
        order=(0,),
    )
    # destination: 1-D view of c_len elements starting at c_start in row pid_h
    t_ptrs = tl.make_block_ptr(
        base=t_ptr + pid_h * stride_th + c_start * stride_tn,
        shape=(c_len,),
        strides=(stride_tn,),
        offsets=(pid_n * BLOCK_SIZE_N,),
        block_shape=(BLOCK_SIZE_N,),
        order=(0,),
    )
    # load and save; boundary_check clips the final partial chunk
    idxs = tl.load(p_ptrs, boundary_check=(0,))
    tl.store(t_ptrs, idxs, boundary_check=(0,))
def reorder_topk_idx(
    topk_idx: torch.Tensor,
    cu_topk_q_count: torch.Tensor,
    cu_seqlens: torch.Tensor,
    cu_seqblocks: torch.Tensor,
    block_size: int,
):
    """Build, per kv head, the packed list of query indices grouped by key block.

    Steps:
      1. Scatter ``topk_idx`` ([num_kv_heads, total_len, topk]) into a padded
         [batch_size, num_kv_heads, max_seqlen, topk] buffer filled with -1.
      2. Argsort each flattened (max_seqlen * topk) row; integer-dividing the
         sorted positions by ``topk`` turns a flat position into a query index.
      3. Copy the tail of each sorted row into the packed output at the offsets
         given by ``cu_topk_q_count`` (done by ``save_topk_idx_kernel``).

    Args:
        topk_idx: selected block indices, unpacked as (num_kv_heads, total_len, topk).
        cu_topk_q_count: 2-D cumulative counts, indexed as [head, block].
        cu_seqlens: cumulative sequence lengths, shape [batch_size + 1].
        cu_seqblocks: cumulative block counts, shape [batch_size + 1].
        block_size: not used here; kept for interface compatibility.

    Returns:
        int32 tensor of shape (num_kv_heads, max over heads of
        ``cu_topk_q_count[:, -1]``); positions never written stay -1.
    """
    num_kv_heads, _, topk = topk_idx.shape
    batch_size = cu_seqlens.shape[0] - 1
    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
    # packed (pad-free) output
    topk_q_idx = torch.full(
        (num_kv_heads, cu_topk_q_count[:, -1].max().item()),
        fill_value=-1,
        device=topk_idx.device,
        dtype=torch.int32,
    )
    # Cumulative counts at sequence boundaries, gathered once; adjacent
    # differences give the number of entries per (head, sequence).
    cu_count_at_seq = cu_topk_q_count[:, cu_seqblocks]
    max_len = (cu_count_at_seq[:, 1:] - cu_count_at_seq[:, :-1]).max().item()
    if max_seqlen == 0 or max_len == 0:
        # Nothing to reorder. Without this guard triton.next_power_of_2(0)
        # yields a block size of 0 and triton.cdiv divides by zero.
        return topk_q_idx
    # pad shape [num_kv_heads, total_seqlen, topk] to [batch_size, num_kv_heads, max_seqlen, topk]
    pad_topk_idx = torch.full(
        (batch_size, num_kv_heads, max_seqlen, topk),
        fill_value=-1,
        device=topk_idx.device,
        dtype=torch.int32,
    )
    BLOCK_SIZE_T = triton.next_power_of_2(topk)
    BLOCK_SIZE_N = min(triton.next_power_of_2(max_seqlen), triton.next_power_of_2(8192 // BLOCK_SIZE_T))
    grid = (batch_size, num_kv_heads, triton.cdiv(max_seqlen, BLOCK_SIZE_N))
    pad_topk_idx_kernel[grid](
        topk_idx,
        pad_topk_idx,
        cu_seqlens,
        topk,
        topk_idx.stride(0),
        topk_idx.stride(1),
        topk_idx.stride(2),
        pad_topk_idx.stride(0),
        pad_topk_idx.stride(1),
        pad_topk_idx.stride(2),
        pad_topk_idx.stride(3),
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        BLOCK_SIZE_T=BLOCK_SIZE_T,
    )
    # argsort: ascending order puts the -1 entries first, so the entries with
    # real block indices end up at the tail of each row, grouped by block
    pad_topk_q_idx = pad_topk_idx.view(batch_size, num_kv_heads, -1).argsort(-1) // topk
    pad_topk_q_idx = pad_topk_q_idx.to(torch.int32)
    # save as remove pad version
    BLOCK_SIZE_N = min(triton.next_power_of_2(max_len), 8192)
    grid = (batch_size, num_kv_heads, triton.cdiv(max_len, BLOCK_SIZE_N))
    save_topk_idx_kernel[grid](
        pad_topk_q_idx,
        topk_q_idx,
        cu_seqblocks,
        cu_topk_q_count,
        pad_topk_q_idx.shape[-1],
        pad_topk_q_idx.stride(0),
        pad_topk_q_idx.stride(1),
        pad_topk_q_idx.stride(2),
        topk_q_idx.stride(0),
        topk_q_idx.stride(1),
        cu_topk_q_count.stride(0),
        cu_topk_q_count.stride(1),
        BLOCK_SIZE_N=BLOCK_SIZE_N,
    )
    return topk_q_idx
init pointers + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(k_len, HEAD_DIM), + strides=(stride_kn, stride_kd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dk_ptrs = tl.make_block_ptr( + base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks, + shape=(k_len, HEAD_DIM), + strides=(stride_dkn, stride_dkd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dv_ptrs = tl.make_block_ptr( + base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs, + shape=(k_len, HEAD_DIM), + strides=(stride_dvn, stride_dvd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + # offsets + off_q = tl.arange(0, BLOCK_SIZE_Q) + off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K + off_d = tl.arange(0, BLOCK_SIZE_D) + # load k v and keep in SRAM + k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") + v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") + # init dk dv + dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + # init ptrs + q_ptrs = q_ptr + q_start * stride_qn + pid_h * stride_qh + off_d[None, :] * stride_qd + do_ptrs = do_ptr + q_start * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod + d_ptrs = d_ptr + q_start * stride_dn + pid_h * stride_dh + lse_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh + # loop for q blocks + for i in range(0, act_q_len, BLOCK_SIZE_Q): + # load + idx_q = tl.load(tq_ptr + i + off_q, mask=off_q < act_q_len - i, 
@triton.jit
def backward_dq(
    q_ptr,  # Q: n x qh x d
    k_ptr,  # K: n x kh x d
    v_ptr,  # V: n x kh x d
    t_ptr,  # topk_idx: kh x n x k
    lse_ptr,  # LSE: qh x n
    d_ptr,  # Delta: qh x n
    do_ptr,
    dq_ptr,
    # seqlens
    cu_seqlens_q,
    cu_seqlens_k,
    # shape
    NUM_KV_HEADS,
    NUM_SHARE_Q_HEADS,
    HEAD_DIM,
    TOPK,
    # q loop num
    num_q_loop,
    # sm_scale
    sm_scale,
    # stride
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_vn,
    stride_vh,
    stride_vd,
    stride_th,
    stride_tn,
    stride_tk,
    stride_lh,
    stride_ln,
    stride_dh,
    stride_dn,
    stride_don,
    stride_doh,
    stride_dod,
    stride_dqn,
    stride_dqh,
    stride_dqd,
    # META parameters
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_D: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr,
    BLOCK_SIZE_T: tl.constexpr,
):
    """Compute dQ for top-k sparse attention.

    Grid: (batch, kv head, query chunk). Each program handles up to
    ``num_q_loop`` consecutive query positions of one sequence and, for each
    position, all NUM_SHARE_Q_HEADS query heads that share kv head ``pid_kh``.
    For every selected key block it recomputes the probabilities from the
    stored LSE and accumulates ``dq += ds @ k`` in float32.

    NOTE(review): the inner loop reads the first ``real_topk`` entries of the
    query's top-k row in order, which presumably relies on the valid (>= 0,
    causally reachable) block indices being stored first. Confirm with the
    producer of ``topk_idx``.
    NOTE(review): the causal tests compare the query's position inside its own
    sequence with key positions inside the key sequence, which looks like it
    assumes aligned Q/K positions. Confirm for q_len != k_len.
    """
    # 1.44269504 ~= log2(e): scores are scaled so tl.exp2 can be used below
    qk_scale = sm_scale * 1.44269504
    # get batch id and head id
    pid_b = tl.program_id(0)
    pid_kh = tl.program_id(1)
    pid_q = tl.program_id(2)
    # first query head of the group that shares kv head pid_kh
    pid_h = pid_kh * NUM_SHARE_Q_HEADS
    # get q k start and len after rmpad
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    k_start = tl.load(cu_seqlens_k + pid_b)
    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
    # this chunk lies entirely past the end of the sequence
    if pid_q * num_q_loop >= q_len:
        return
    # the last chunk of a sequence may hold fewer than num_q_loop queries
    real_q_loop = min(num_q_loop, q_len - pid_q * num_q_loop)
    for j in range(real_q_loop):
        # position of this query inside its sequence
        pid_q_j = pid_q * num_q_loop + j
        # init topk idx pointer
        off_t = tl.arange(0, BLOCK_SIZE_T)
        t_ptr_j = t_ptr + (q_start + pid_q_j) * stride_tn + pid_kh * stride_th
        topk_idx = tl.load(t_ptr_j + off_t * stride_tk, mask=off_t < TOPK, other=-1)

        # number of entries that are non-negative and not beyond the block
        # containing the query position
        real_topk = tl.sum(
            tl.where((topk_idx >= 0) & (topk_idx <= pid_q_j // BLOCK_SIZE_K), 1, 0),
            axis=0,
        )
        # init pointers
        # q / dq / do tiles: [shared heads, head_dim] for this query position
        q_ptrs = tl.make_block_ptr(
            base=q_ptr + (q_start + pid_q_j) * stride_qn + pid_h * stride_qh,
            shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),
            strides=(stride_qh, stride_qd),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),
            order=(1, 0),
        )
        dq_ptrs = tl.make_block_ptr(
            base=dq_ptr + (q_start + pid_q_j) * stride_dqn + pid_h * stride_dqh,
            shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),
            strides=(stride_dqh, stride_dqd),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),
            order=(1, 0),
        )
        # k is addressed as [k_len, head_dim]
        k_ptrs = tl.make_block_ptr(
            base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
            shape=(k_len, HEAD_DIM),
            strides=(stride_kn, stride_kd),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
            order=(1, 0),
        )
        # v is addressed transposed, as [head_dim, k_len], so do @ v needs no tl.trans
        v_ptrs = tl.make_block_ptr(
            base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
            shape=(HEAD_DIM, k_len),
            strides=(stride_vd, stride_vn),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
            order=(0, 1),
        )
        do_ptrs = tl.make_block_ptr(
            base=do_ptr + (q_start + pid_q_j) * stride_don + pid_h * stride_doh,
            shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),
            strides=(stride_doh, stride_dod),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),
            order=(1, 0),
        )
        # delta and lse are loaded as [shared heads, 1] columns so they
        # broadcast against the [heads, BLOCK_SIZE_K] score tile
        d_ptrs = tl.make_block_ptr(
            base=d_ptr + (q_start + pid_q_j) * stride_dn + pid_h * stride_dh,
            shape=(NUM_SHARE_Q_HEADS, 1),
            strides=(stride_dh, stride_dn),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_H, 1),
            order=(1, 0),
        )
        lse_ptrs = tl.make_block_ptr(
            base=lse_ptr + (q_start + pid_q_j) * stride_ln + pid_h * stride_lh,
            shape=(NUM_SHARE_Q_HEADS, 1),
            strides=(stride_lh, stride_ln),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_H, 1),
            order=(1, 0),
        )
        # offsets
        off_k = tl.arange(0, BLOCK_SIZE_K)
        # load q, do, lse, delta once per query position; they are reused for
        # every selected key block
        q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option="zero")
        do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero")
        lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
        d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero")
        # init dq
        dq = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_D), dtype=tl.float32)
        # sparse: visit only the selected key blocks
        for i in range(real_topk):
            # get current block start index (block id * block size) and
            # step the index pointer to the next top-k entry
            c = tl.load(t_ptr_j).to(tl.int32) * BLOCK_SIZE_K
            t_ptr_j = t_ptr_j + stride_tk
            # load the key/value block that starts at key position c
            k = tl.load(tl.advance(k_ptrs, (c, 0)), boundary_check=(1, 0), padding_option="zero")
            v = tl.load(tl.advance(v_ptrs, (0, c)), boundary_check=(0, 1), padding_option="zero")
            # compute qk; key positions after the query position get -inf
            qk = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_K), dtype=tl.float32)
            qk += tl.where((pid_q_j >= c + off_k)[None, :], 0, float("-inf"))
            # [BLOCK_SIZE_H, HEAD_DIM] @ [BLOCK_SIZE_K, HEAD_DIM].T -> [BLOCK_SIZE_H, BLOCK_SIZE_K]
            qk += tl.dot(q, tl.trans(k)) * qk_scale
            # compute p, ds; lse is used in the base-2 domain (p = 2 ** (qk - lse))
            p = tl.exp2(qk - lse)
            dp = tl.dot(do, v)
            ds = sm_scale * p * (dp - d)
            # cast dtype back to q's dtype before the next tl.dot
            ds = ds.to(q.dtype)
            # update dq
            dq += tl.dot(ds, k)
        # save dq for this query position (all shared heads)
        tl.store(dq_ptrs, dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1))
head_dim] + k: torch.Tensor, # [total_len, num_k_heads, head_dim] + v: torch.Tensor, # [total_len, num_k_heads, head_dim] + topk_idx: torch.Tensor, # [num_kv_heads, total_len, topk] + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float, +): + # dtype check + assert k.dtype == q.dtype and v.dtype == q.dtype + assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + assert block_size in {16, 32, 64, 128, 256} + # shape + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + v_len, num_v_heads, head_dim = v.shape + batch_size = cu_seqlens_q.shape[0] - 1 + # assert q_len == k_len and k_len == v_len + topk = topk_idx.shape[-1] + assert topk_idx.shape[0] == num_k_heads + assert topk_idx.shape[1] == q_len + # gqa + assert num_k_heads == num_v_heads + assert num_q_heads % num_k_heads == 0 + num_share_q_heads = num_q_heads // num_k_heads + # output tensor + o = torch.zeros_like(q) + + lse = torch.zeros(num_q_heads, q_len, dtype=torch.float32, device=q.device) + + # launch kernel + num_q_loop = num_k_loop = 1 + BLOCK_SIZE_K = triton.next_power_of_2(block_size) + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + BLOCK_SIZE_H = max(16, triton.next_power_of_2(num_share_q_heads)) + BLOCK_SIZE_T = triton.next_power_of_2(topk) + + def grid(meta): + grid = ( + batch_size * triton.cdiv(num_k_heads, num_k_loop) * triton.cdiv(max_seqlen_q, num_q_loop), + ) + return grid + + num_warps, num_stages = get_num_warps_stages(head_dim, block_size, IS_HOPPER_GPU) + forward_kernel_orig[grid]( + q, + k, + v, + topk_idx, + o, + lse, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + topk, + block_size, + # num_q_loop, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + topk_idx.stride(0), + topk_idx.stride(1), + topk_idx.stride(2), + 
o.stride(0), + o.stride(1), + o.stride(2), + lse.stride(0), + lse.stride(1), + num_q_loop=num_q_loop, + num_k_loop=num_k_loop, + MAX_SEQ_LEN=max_seqlen_q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + BLOCK_SIZE_H=BLOCK_SIZE_H, + BLOCK_SIZE_T=BLOCK_SIZE_T, + num_warps=num_warps, + num_stages=num_stages, + ) + return o, lse + + +def _topk_sparse_attention_bwd( + o: torch.Tensor, + do: torch.Tensor, + lse: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + topk_idx: torch.Tensor, + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float, +): + + assert block_size in {16, 32, 64, 128, 256} + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + v_len, num_v_heads, head_dim = v.shape + o_len, num_o_heads, head_dim = o.shape + num_share_q_heads = num_q_heads // num_k_heads + topk = topk_idx.shape[-1] + # compute D + delta = torch.zeros([num_o_heads, o_len], device=o.device, dtype=torch.float32) + BLOCK_SIZE_O = 256 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_O, IS_HOPPER_GPU) + grid = (triton.cdiv(o_len, BLOCK_SIZE_O), num_o_heads) + + backward_sum_o_do[grid]( + o, + do, + delta, + o_len, + head_dim, + o.stride(0), + o.stride(1), + o.stride(2), + do.stride(0), + do.stride(1), + do.stride(2), + delta.stride(0), + delta.stride(1), + BLOCK_SIZE_O=BLOCK_SIZE_O, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + # count active querys for each key block, shape: (num_k_heads, total_k_blocks) + seqlens = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + seqblocks = torch.ceil(seqlens / block_size).to(torch.int32) + cu_seqblocks = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=topk_idx.device), + torch.cumsum(seqblocks, dim=0), + ] + ).to(torch.int32) + + topk_q_count = count_query(topk_idx, cu_seqlens_q, cu_seqblocks, block_size) + 
+ cu_topk_q_count = torch.cat( + [ + torch.zeros(topk_q_count.shape[0], 1, dtype=torch.int32, device=topk_idx.device), + torch.cumsum(topk_q_count, dim=-1), + ], + dim=-1, + ).to(torch.int32) + # active query idx for each key block + # how to get active query idx for sequence b, head h, kv block i? + topk_q_idx = reorder_topk_idx(topk_idx, cu_topk_q_count, cu_seqlens_q, cu_seqblocks, block_size) + # compute dk dv + dk = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype) + dv = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype) + batch_size = cu_seqlens_q.shape[0] - 1 + BLOCK_SIZE_K = triton.next_power_of_2(block_size) + BLOCK_SIZE_Q = 64 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU) + grid = (batch_size, num_q_heads, triton.cdiv(max_seqlen_k, BLOCK_SIZE_K)) + backward_dkdv[grid]( + q, + k, + v, + topk_q_idx, + lse, + delta, + do, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + cu_seqblocks, + cu_topk_q_count, + num_k_heads, + num_share_q_heads, + head_dim, + topk, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + topk_q_idx.stride(0), + topk_q_idx.stride(1), + cu_topk_q_count.stride(0), + cu_topk_q_count.stride(1), + lse.stride(0), + lse.stride(1), + delta.stride(0), + delta.stride(1), + do.stride(0), + do.stride(1), + do.stride(2), + dk.stride(0), + dk.stride(1), + dk.stride(2), + dk.stride(3), + dv.stride(0), + dv.stride(1), + dv.stride(2), + dv.stride(3), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + dk = dk.sum(0) + dv = dv.sum(0) + # compute dq + dq = torch.zeros_like(q) + num_q_loop = max_seqlen_q // 32768 + 1 # calculate multiple querys in one kernel if seqlence length is too long + grid = 
(batch_size, num_k_heads, triton.cdiv(max_seqlen_q, num_q_loop)) + BLOCK_SIZE_K = block_size + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + BLOCK_SIZE_H = max(16, triton.next_power_of_2(num_share_q_heads)) + BLOCK_SIZE_T = triton.next_power_of_2(topk) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_K, IS_HOPPER_GPU) + + backward_dq[grid]( + q, + k, + v, + topk_idx, + lse, + delta, + do, + dq, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + topk, + num_q_loop, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + topk_idx.stride(0), + topk_idx.stride(1), + topk_idx.stride(2), + lse.stride(0), + lse.stride(1), + delta.stride(0), + delta.stride(1), + do.stride(0), + do.stride(1), + do.stride(2), + dq.stride(0), + dq.stride(1), + dq.stride(2), + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + BLOCK_SIZE_H=BLOCK_SIZE_H, + BLOCK_SIZE_T=BLOCK_SIZE_T, + num_warps=num_warps, + num_stages=num_stages, + ) + return dq, dk, dv + + +class TopkSparseAttention(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, # [total_len, num_q_heads, head_dim] + k: torch.Tensor, # [total_len, num_k_heads, head_dim] + v: torch.Tensor, # [total_len, num_k_heads, head_dim] + topk_idx: torch.Tensor, # [num_kv_heads, total_len, topk] + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: torch.Tensor, + max_seqlen_k: torch.Tensor, + sm_scale=None, + ): + # dtype check + assert q.dtype == torch.bfloat16 or q.dtype == torch.float16 + assert q.dtype == k.dtype and k.dtype == v.dtype + assert topk_idx.dtype == torch.int32 + assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + # softmax scale + if sm_scale is None: + sm_scale = 1 / math.sqrt(q.shape[-1]) + + o, lse = _topk_sparse_attention_fwd( + q, + k, + v, + topk_idx, + block_size, + 
cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + ) + + ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, topk_idx) + ctx.sm_scale = sm_scale + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.block_size = block_size + return o + + @staticmethod + def backward(ctx, do: torch.Tensor, *args) -> Any: + q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, topk_idx = ctx.saved_tensors + + max_seqlen_q = ctx.max_seqlen_q + max_seqlen_k = ctx.max_seqlen_k + sm_scale = ctx.sm_scale + block_size = ctx.block_size + assert block_size in {16, 32, 64, 128, 256} + + dq, dk, dv = _topk_sparse_attention_bwd( + o, + do, + lse, + q, + k, + v, + topk_idx, + block_size, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + ) + return dq, dk, dv, None, None, None, None, None, None, None, None + + +def topk_sparse_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + topk_idx: torch.Tensor, + block_size: int, + cu_seqlens: torch.Tensor, + softmax_scale: Optional[float] = None, +) -> torch.Tensor: + """Topk sparse attention varlen version implemented in triton. + + Args: + q (torch.Tensor): shape [total_len, num_q_heads, head_dim] + k (torch.Tensor): shape [total_len, num_kv_heads, head_dim] + v (torch.Tensor): shape [total_len, num_kv_heads, head_dim] + topk_idx (torch.Tensor): topk block idx for each query, shape [num_kv_heads, total_len, topk]. -1 means padding. + block_size (int): key value block size. + cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens in flash_attn_func_varlen. + softmax_scale (Optional[float], optional): Defaults to None, means 1/sqrt(head_dim). 
+ + Returns: + torch.Tensor: attention output, shape [total_len, num_q_heads, head_dim] + """ + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + return TopkSparseAttention.apply( + q, + k, + v, + topk_idx, + block_size, + cu_seqlens, + cu_seqlens, + max_seqlen, + max_seqlen, + softmax_scale, + ) + + +def topk_sparse_attention_varlen( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + topk_idx: torch.Tensor, + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + softmax_scale: Optional[float] = None, +) -> torch.Tensor: + """Topk sparse attention with separate Q and K sequence lengths (for extend/prefill). + + Same as topk_sparse_attention but accepts separate cu_seqlens for Q and K. + Useful when Q only covers new tokens while K covers all tokens (prefix + new). + + Args: + q (torch.Tensor): shape [total_q, num_q_heads, head_dim] + k (torch.Tensor): shape [total_k, num_kv_heads, head_dim] + v (torch.Tensor): shape [total_k, num_kv_heads, head_dim] + topk_idx (torch.Tensor): topk block idx for each query, shape [num_kv_heads, total_q, topk]. -1 means padding. + block_size (int): key value block size. + cu_seqlens_q (torch.Tensor): shape [batch_size + 1], cumulative Q sequence lengths. + cu_seqlens_k (torch.Tensor): shape [batch_size + 1], cumulative K sequence lengths. + softmax_scale (Optional[float], optional): Defaults to None, means 1/sqrt(head_dim). 
+ + Returns: + torch.Tensor: attention output, shape [total_q, num_q_heads, head_dim] + """ + max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item() + max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item() + return TopkSparseAttention.apply( + q, + k, + v, + topk_idx, + block_size, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + ) diff --git a/vortex_torch/cache/__init__.py b/vortex_torch/cache/__init__.py index eddfa464..6b549054 100644 --- a/vortex_torch/cache/__init__.py +++ b/vortex_torch/cache/__init__.py @@ -29,11 +29,14 @@ from .matmul import GeMM from .elementwise import Relu, Silu, Sigmoid, Abs, Add_Mul from .elementwise_binary import Maximum, Minimum, Multiply, Add -from .triton_kernels import set_kv_buffer_launcher +from .triton_kernels import set_kv_buffer_launcher, set_kv_buffer_int8_launcher, set_kv_buffer_fp8_launcher, dequant_pages_to_bf16_inplace __all__ = [ "set_kv_buffer_launcher", + "set_kv_buffer_int8_launcher", + "set_kv_buffer_fp8_launcher", + "dequant_pages_to_bf16_inplace", "Mean", "Max", "Min", "L2Norm", "GeMM", "Relu", "Silu", "Sigmoid", "Abs", "Add_Mul", diff --git a/vortex_torch/cache/context.py b/vortex_torch/cache/context.py index ae2dd5c2..3cdf0953 100644 --- a/vortex_torch/cache/context.py +++ b/vortex_torch/cache/context.py @@ -10,17 +10,25 @@ class Context(ContextBase): """ __slots__ = ContextBase.__slots__ + ( - + #page infomation "max_new_tokens_per_batch", "page_size", "total_num_pages", - + #model infomation "head_dim", "head_num", - + # auxilary memory in graph "_aux_total_bytes", - - "_aux_total_flops" + + "_aux_total_flops", + + # Quantization: quant_type (0=none, 1=int8, 2=e4m3, 3=e5m2), + # kv_scale (per-tensor fp8 scale), kv_scale_ptr (per-token int8 scale tensor) + # fp8_type: 0=none, 1=e4m3, 2=e5m2 (encoding for Triton kernels) + "quant_type", + "kv_scale", + "kv_scale_ptr", + "fp8_type", ) @@ -36,7 +44,15 @@ def __init__(self) -> None: elif name == 
"_aux_total_flops": object.__setattr__(self, name, 0) # start from 0 flops elif name == "mode": - object.__setattr__(self, name, Mode.profile) + object.__setattr__(self, name, Mode.profile) + elif name == "quant_type": + object.__setattr__(self, name, 0) # 0 = none (bf16 default) + elif name == "kv_scale": + object.__setattr__(self, name, 1.0) # identity scale for bf16 + elif name == "kv_scale_ptr": + object.__setattr__(self, name, None) # per-token scale tensor (int8 only) + elif name == "fp8_type": + object.__setattr__(self, name, 0) # 0 = none (bf16 default) else: object.__setattr__(self, name, UNSET) diff --git a/vortex_torch/cache/reduce.py b/vortex_torch/cache/reduce.py index 3c4edf2f..eb94795e 100644 --- a/vortex_torch/cache/reduce.py +++ b/vortex_torch/cache/reduce.py @@ -345,8 +345,11 @@ def execute( ) output = self.output_buffer - # Launch the kernel/implementation: impl(x, output, loc, ctx, dim, reduce_type) - self.impl(x, output, loc, ctx, self.dim, self.reduce_type) + # Launch the kernel/implementation: impl(x, output, loc, ctx, dim, reduce_type, quant_type, scale, kv_scale_ptr) + quant_type = getattr(ctx, 'quant_type', 0) + scale = getattr(ctx, 'kv_scale', 1.0) + kv_scale_ptr = getattr(ctx, 'kv_scale_ptr', None) + self.impl(x, output, loc, ctx, self.dim, self.reduce_type, quant_type, scale, kv_scale_ptr) return output diff --git a/vortex_torch/cache/triton_kernels/__init__.py b/vortex_torch/cache/triton_kernels/__init__.py index 6bf6dfcd..de4fcbdc 100644 --- a/vortex_torch/cache/triton_kernels/__init__.py +++ b/vortex_torch/cache/triton_kernels/__init__.py @@ -1,4 +1,17 @@ -from .set_kv import set_kv_buffer_launcher - -__all__ = ["set_kv_buffer_launcher"] +from .set_kv import ( + set_kv_buffer_launcher, + set_kv_buffer_int8_launcher, + set_kv_buffer_fp8_launcher, + paged_decode, + dequant_pages_to_bf16, + dequant_pages_to_bf16_inplace, +) +__all__ = [ + "set_kv_buffer_launcher", + "set_kv_buffer_int8_launcher", + "set_kv_buffer_fp8_launcher", + 
"paged_decode", + "dequant_pages_to_bf16", + "dequant_pages_to_bf16_inplace", +] diff --git a/vortex_torch/cache/triton_kernels/reduce_impl.py b/vortex_torch/cache/triton_kernels/reduce_impl.py index 9921e082..0146af7b 100644 --- a/vortex_torch/cache/triton_kernels/reduce_impl.py +++ b/vortex_torch/cache/triton_kernels/reduce_impl.py @@ -4,6 +4,17 @@ from ..context import Context from ...utils import ReduceType + +# --------------------------------------------------------------------------- +# Helper: Load a page block from src_ptr, handling bf16 / int8 / fp8-stored-as-uint8. +# QUANT_TYPE == 0 -> bf16 pointer, load normally +# QUANT_TYPE == 1 -> int8 pointer, dequant with per-row scale from kv_scale_ptr +# QUANT_TYPE == 2 -> uint8 pointer, bitcast to float8e4nv, dequant with per-tensor scale +# QUANT_TYPE == 3 -> uint8 pointer, bitcast to float8e5, dequant with per-tensor scale +# All quantised paths return a float32 tensor ready for reduction. +# --------------------------------------------------------------------------- + + @triton.jit def reduce_pp_kernel( x, output, loc, @@ -12,9 +23,12 @@ def reduce_pp_kernel( NUM_KV_HEAD: tl.constexpr, PAGE_SIZE: tl.constexpr, REDUCE_TYPE: tl.constexpr, # 0:Mean, 1:Max, 2:Min, 3:L2Norm -DIM: tl.constexpr # 1: over rows (axis=0) -> len x_D1; 2: over cols (axis=1) -> len x_D0 -): - +DIM: tl.constexpr, # 1: over rows (axis=0) -> len x_D1; 2: over cols (axis=1) -> len x_D0 +QUANT_TYPE: tl.constexpr, # 0: bf16, 1: int8, 2: e4m3, 3: e5m2 +scale, # float: 1.0 for bf16, kv_scale for fp8 +kv_scale_ptr, # pointer to per-token int8 scales (unused when QUANT_TYPE != 1) +): + token_id = tl.program_id(0) head_id = tl.program_id(1) @@ -29,7 +43,22 @@ def reduce_pp_kernel( rows = tl.arange(0, x_D0)[:, None] # [x_D0, 1] cols = tl.arange(0, x_D1)[None, :] # [1, x_D1] src_ptr = x + x_offset + rows * x_D1 + cols - page_block = tl.load(src_ptr) + + if QUANT_TYPE == 1: + # int8: load int8 values, dequant with per-row scale + raw = 
tl.load(src_ptr).to(tl.float32) + # Per-row scales stored at kv_scale_ptr[page_id * x_D0 + row] + scale_offset = page_id * x_D0 + tl.arange(0, x_D0) + row_scales = tl.load(kv_scale_ptr + scale_offset).to(tl.float32) # [x_D0] + page_block = raw * row_scales[:, None] # broadcast [x_D0, 1] + elif QUANT_TYPE == 2: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale + elif QUANT_TYPE == 3: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale + else: + page_block = tl.load(src_ptr) if DIM == 1: # reduce over rows -> axis=0 -> length x_D1 @@ -71,11 +100,14 @@ def reduce_pp( ctx: Context, dim: int, reduce_type: ReduceType, +quant_type: int = 0, +scale: float = 1.0, +kv_scale_ptr=None, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = ctx.head_num - + reduce_pp_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -85,7 +117,10 @@ def reduce_pp( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + QUANT_TYPE=quant_type, + scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) @@ -97,11 +132,14 @@ def _reduce_pp( page_size: int, dim: int, reduce_type: ReduceType, +quant_type: int = 0, +scale: float = 1.0, +kv_scale_ptr=None, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = num_kv_heads - + reduce_pp_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -111,7 +149,10 @@ def _reduce_pp( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + QUANT_TYPE=quant_type, + scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) @@ -124,9 +165,12 @@ def reduce_rp_kernel( NUM_KV_HEAD: tl.constexpr, PAGE_SIZE: tl.constexpr, REDUCE_TYPE: tl.constexpr, # 0: Mean, 1: Max, 2: Min, 3: L2Norm (not RMS) - DIM: tl.constexpr # 1: reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 + DIM: tl.constexpr, # 1: reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 + QUANT_TYPE: 
tl.constexpr, # 0: bf16, 1: int8, 2: e4m3, 3: e5m2 + scale, # float: 1.0 for bf16, kv_scale for fp8 + kv_scale_ptr, # pointer to per-token int8 scales (unused when QUANT_TYPE != 1) ): - + # Program IDs: # pid0 = token index (0 .. num_tokens-1) # pid1 = head index (0 .. NUM_KV_HEAD-1) @@ -156,7 +200,20 @@ def reduce_rp_kernel( # Load the full page block for this (token_id, head_id). # Assumes the page is full; add masks here if you have partial tiles. - page_block = tl.load(src_ptr) + if QUANT_TYPE == 1: + # int8: load int8 values, dequant with per-row scale + raw = tl.load(src_ptr).to(tl.float32) + scale_offset = page_id * x_D0 + tl.arange(0, x_D0) + row_scales = tl.load(kv_scale_ptr + scale_offset).to(tl.float32) + page_block = raw * row_scales[:, None] + elif QUANT_TYPE == 2: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale + elif QUANT_TYPE == 3: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale + else: + page_block = tl.load(src_ptr) # Reduction: if DIM == 1: @@ -196,7 +253,7 @@ def reduce_rp_kernel( # Write to output: layout [num_pages, x_D0] for DIM==2. 
dst_ptr = output + page_id * x_D0 + tl.arange(0, x_D0) tl.store(dst_ptr, reduce_vec) - + def reduce_rp( @@ -206,11 +263,14 @@ def reduce_rp( ctx: Context, dim: int, reduce_type: ReduceType, +quant_type: int = 0, +scale: float = 1.0, +kv_scale_ptr=None, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = ctx.head_num - + reduce_rp_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -220,7 +280,10 @@ def reduce_rp( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + QUANT_TYPE=quant_type, + scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) @@ -232,11 +295,14 @@ def _reduce_rp( page_size: int, dim: int, reduce_type: ReduceType, +quant_type: int = 0, +scale: float = 1.0, +kv_scale_ptr=None, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = num_kv_heads - + reduce_rp_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -246,7 +312,10 @@ def _reduce_rp( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + QUANT_TYPE=quant_type, + scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) @@ -258,7 +327,10 @@ def reduce_pr_kernel( NUM_KV_HEAD: tl.constexpr, PAGE_SIZE: tl.constexpr, REDUCE_TYPE: tl.constexpr, # 0: Mean, 1: Max, 2: Min, 3: L2Norm (not RMS) -DIM: tl.constexpr # 1: reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 +DIM: tl.constexpr, # 1: reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 +QUANT_TYPE: tl.constexpr, # 0: bf16, 1: int8, 2: e4m3, 3: e5m2 +scale, # float: 1.0 for bf16, kv_scale for fp8 +kv_scale_ptr, # pointer to per-token int8 scales (unused when QUANT_TYPE != 1) ): """ Layouts: @@ -297,7 +369,20 @@ def reduce_pr_kernel( src_ptr = x + x_offset + rows * x_D1 + cols # Load the full page block. Assumes full tiles; add masks if needed. 
- page_block = tl.load(src_ptr) + if QUANT_TYPE == 1: + # int8: load int8 values, dequant with per-row scale + raw = tl.load(src_ptr).to(tl.float32) + scale_offset = page_id * x_D0 + tl.arange(0, x_D0) + row_scales = tl.load(kv_scale_ptr + scale_offset).to(tl.float32) + page_block = raw * row_scales[:, None] + elif QUANT_TYPE == 2: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale + elif QUANT_TYPE == 3: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale + else: + page_block = tl.load(src_ptr) # --- Reduction & write-out --- if DIM == 1: @@ -344,11 +429,14 @@ def reduce_pr( ctx: Context, dim: int, reduce_type: ReduceType, +quant_type: int = 0, +scale: float = 1.0, +kv_scale_ptr=None, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = ctx.head_num - + reduce_pr_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -358,9 +446,12 @@ def reduce_pr( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + QUANT_TYPE=quant_type, + scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) - + def _reduce_pr( x: torch.Tensor, output: torch.Tensor, @@ -369,11 +460,14 @@ def _reduce_pr( page_size: int, dim: int, reduce_type: ReduceType, +quant_type: int = 0, +scale: float = 1.0, +kv_scale_ptr=None, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = num_kv_heads - + reduce_pr_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -383,7 +477,10 @@ def _reduce_pr( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + QUANT_TYPE=quant_type, + scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) @@ -395,7 +492,10 @@ def reduce_rr_kernel( NUM_KV_HEAD: tl.constexpr, PAGE_SIZE: tl.constexpr, REDUCE_TYPE: tl.constexpr, # 0: Mean, 1: Max, 2: Min, 3: L2Norm (not RMS) -DIM: tl.constexpr # 1: reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 +DIM: tl.constexpr, # 1: 
reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 +QUANT_TYPE: tl.constexpr, # 0: bf16, 1: int8, 2: e4m3, 3: e5m2 +scale, # float: 1.0 for bf16, kv_scale for fp8 +kv_scale_ptr, # pointer to per-token int8 scales (unused when QUANT_TYPE != 1) ): """ Layouts: @@ -420,7 +520,22 @@ def reduce_rr_kernel( rows = tl.arange(0, x_D0)[:, None] # [x_D0, 1] cols = tl.arange(0, x_D1)[None, :] # [1, x_D1] src_ptr = x + x_base + rows * x_D1 + cols - page_blk = tl.load(src_ptr) # assumes full page; add masks if needed + + if QUANT_TYPE == 1: + # int8: load int8 values, dequant with per-row scale + raw = tl.load(src_ptr).to(tl.float32) + page_id = (token_position // PAGE_SIZE) * NUM_KV_HEAD + head_id + scale_offset = page_id * x_D0 + tl.arange(0, x_D0) + row_scales = tl.load(kv_scale_ptr + scale_offset).to(tl.float32) + page_blk = raw * row_scales[:, None] + elif QUANT_TYPE == 2: + raw = tl.load(src_ptr) + page_blk = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale + elif QUANT_TYPE == 3: + raw = tl.load(src_ptr) + page_blk = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale + else: + page_blk = tl.load(src_ptr) # assumes full page; add masks if needed # ---- reduce ---- if DIM == 1: @@ -464,11 +579,14 @@ def reduce_rr( ctx: Context, dim: int, reduce_type: ReduceType, +quant_type: int = 0, +scale: float = 1.0, +kv_scale_ptr=None, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = ctx.head_num - + reduce_rr_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -478,9 +596,12 @@ def reduce_rr( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + QUANT_TYPE=quant_type, + scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) - + def _reduce_rr( x: torch.Tensor, @@ -490,11 +611,14 @@ def _reduce_rr( page_size: int, dim: int, reduce_type: ReduceType, +quant_type: int = 0, +scale: float = 1.0, +kv_scale_ptr=None, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = num_kv_heads - + reduce_rr_kernel[(NNZ, 
NUM_KV_HEAD)]( x=x, output=output, @@ -504,5 +628,8 @@ def _reduce_rr( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim - ) \ No newline at end of file + DIM=dim, + QUANT_TYPE=quant_type, + scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, + ) diff --git a/vortex_torch/cache/triton_kernels/set_kv.py b/vortex_torch/cache/triton_kernels/set_kv.py index cfa3cab2..58468cc0 100644 --- a/vortex_torch/cache/triton_kernels/set_kv.py +++ b/vortex_torch/cache/triton_kernels/set_kv.py @@ -36,6 +36,93 @@ def set_kv_buffer_kernel( tl.store(dst_v_ptr, src_v) +@triton.jit +def set_kv_buffer_int8_kernel( + k_cache, # int8 paged K cache + v_cache, # int8 paged V cache + k_scale_cache, # fp16 per-token K scale [num_pages, page_size, 1] + v_scale_cache, # fp16 per-token V scale [num_pages, page_size, 1] + new_k, # bf16 input K [NNZ, NUM_KV_HEAD, HEAD_DIM] + new_v, # bf16 input V [NNZ, NUM_KV_HEAD, HEAD_DIM] + loc, # int64 token positions + NUM_KV_HEAD: tl.constexpr, + NNZ: tl.constexpr, + HEAD_DIM: tl.constexpr, + PAGE_SIZE: tl.constexpr +): + """Quantize bf16 K/V to int8 with per-token absmax scaling and write to paged buffers.""" + token_id = tl.program_id(0) + if token_id >= NNZ: + return + head_id = tl.program_id(1) + dim = tl.arange(0, HEAD_DIM) + + # Load bf16 source values + src_ptr = token_id * NUM_KV_HEAD * HEAD_DIM + head_id * HEAD_DIM + dim + src_k = tl.load(new_k + src_ptr).to(tl.float32) + src_v = tl.load(new_v + src_ptr).to(tl.float32) + + # Compute per-token absmax scale: scale = absmax / 127 + absmax_k = tl.max(tl.abs(src_k), axis=0) + absmax_v = tl.max(tl.abs(src_v), axis=0) + # Avoid division by zero + scale_k = absmax_k / 127.0 + 1e-10 + scale_v = absmax_v / 127.0 + 1e-10 + + # Quantize to int8: round(x / scale), clamp to [-128, 127] + q_k = tl.extra.cuda.libdevice.rint(src_k / scale_k) + q_k = tl.minimum(tl.maximum(q_k, -128.0), 127.0).to(tl.int8) + q_v = tl.extra.cuda.libdevice.rint(src_v / scale_v) 
+ q_v = tl.minimum(tl.maximum(q_v, -128.0), 127.0).to(tl.int8) + + # Compute paged destination offset (same layout as bf16 kernel) + token_position = tl.load(loc + token_id) + page_id = token_position // PAGE_SIZE + in_page_offset = token_position % PAGE_SIZE + position_trans = page_id * (PAGE_SIZE * NUM_KV_HEAD) + head_id * PAGE_SIZE + in_page_offset + + # Write int8 values + dst_k_ptr = k_cache + position_trans * HEAD_DIM + dim + dst_v_ptr = v_cache + position_trans * HEAD_DIM + dim + tl.store(dst_k_ptr, q_k) + tl.store(dst_v_ptr, q_v) + + # Write per-token scales (fp16): shape [num_pages, page_size, 1] + # Layout: page_id * PAGE_SIZE + in_page_offset (flat per-head, one scale per token per head) + scale_offset = (page_id * NUM_KV_HEAD + head_id) * PAGE_SIZE + in_page_offset + tl.store(k_scale_cache + scale_offset, scale_k.to(tl.float16)) + tl.store(v_scale_cache + scale_offset, scale_v.to(tl.float16)) + + +def set_kv_buffer_int8_launcher( + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k_scale_cache: torch.Tensor, + v_scale_cache: torch.Tensor, + new_k: torch.Tensor, + new_v: torch.Tensor, + loc: torch.LongTensor, + page_size: int +): + NNZ = loc.shape[0] + NUM_KV_HEAD = new_k.shape[1] + HEAD_DIM = new_k.shape[2] + + set_kv_buffer_int8_kernel[(NNZ, NUM_KV_HEAD)]( + k_cache, + v_cache, + k_scale_cache, + v_scale_cache, + new_k, + new_v, + loc, + NUM_KV_HEAD, + NNZ, + HEAD_DIM, + page_size + ) + + def set_kv_buffer_launcher( k_cache: torch.Tensor, v_cache: torch.Tensor, @@ -44,11 +131,11 @@ def set_kv_buffer_launcher( loc: torch.LongTensor, page_size: int ): - + NNZ = loc.shape[0] NUM_KV_HEAD = new_k.shape[1] HEAD_DIM = new_k.shape[2] - + set_kv_buffer_kernel[(NNZ, NUM_KV_HEAD)]( k_cache, v_cache, @@ -61,3 +148,591 @@ def set_kv_buffer_launcher( page_size ) + +@triton.jit +def set_kv_buffer_fp8_kernel( + k_cache, # uint8 paged K cache + v_cache, # uint8 paged V cache + new_k, # bf16 input K [NNZ, NUM_KV_HEAD, HEAD_DIM] + new_v, # bf16 input V [NNZ, 
NUM_KV_HEAD, HEAD_DIM] + loc, # int64 token positions + NUM_KV_HEAD: tl.constexpr, + NNZ: tl.constexpr, + HEAD_DIM: tl.constexpr, + PAGE_SIZE: tl.constexpr, + FP8_TYPE: tl.constexpr, # 1: e4m3 (max=448), 2: e5m2 (max=57344) + k_scale, # float: per-tensor scale for K quantization + v_scale, # float: per-tensor scale for V quantization +): + """Quantize bf16 K/V to fp8, bitcast to uint8, and scatter into paged cache.""" + token_id = tl.program_id(0) + if token_id >= NNZ: + return + head_id = tl.program_id(1) + dim = tl.arange(0, HEAD_DIM) + + # Load bf16 source values + src_ptr = token_id * NUM_KV_HEAD * HEAD_DIM + head_id * HEAD_DIM + dim + src_k = tl.load(new_k + src_ptr).to(tl.float32) + src_v = tl.load(new_v + src_ptr).to(tl.float32) + + # Scale down: quantized = real_value / scale + inv_k_scale = 1.0 / k_scale + inv_v_scale = 1.0 / v_scale + scaled_k = src_k * inv_k_scale + scaled_v = src_v * inv_v_scale + + # Clamp and cast to fp8, then bitcast to uint8 for storage + if FP8_TYPE == 1: + # e4m3: max = 448.0 + clamped_k = tl.minimum(tl.maximum(scaled_k, -448.0), 448.0) + clamped_v = tl.minimum(tl.maximum(scaled_v, -448.0), 448.0) + q_k = clamped_k.to(tl.float8e4nv).to(tl.uint8, bitcast=True) + q_v = clamped_v.to(tl.float8e4nv).to(tl.uint8, bitcast=True) + else: + # e5m2: max = 57344.0 + clamped_k = tl.minimum(tl.maximum(scaled_k, -57344.0), 57344.0) + clamped_v = tl.minimum(tl.maximum(scaled_v, -57344.0), 57344.0) + q_k = clamped_k.to(tl.float8e5).to(tl.uint8, bitcast=True) + q_v = clamped_v.to(tl.float8e5).to(tl.uint8, bitcast=True) + + # Compute paged destination offset + token_position = tl.load(loc + token_id) + page_id = token_position // PAGE_SIZE + in_page_offset = token_position % PAGE_SIZE + position_trans = page_id * (PAGE_SIZE * NUM_KV_HEAD) + head_id * PAGE_SIZE + in_page_offset + + # Write uint8 values + dst_k_ptr = k_cache + position_trans * HEAD_DIM + dim + dst_v_ptr = v_cache + position_trans * HEAD_DIM + dim + tl.store(dst_k_ptr, q_k) + 
tl.store(dst_v_ptr, q_v) + + +def set_kv_buffer_fp8_launcher( + k_cache: torch.Tensor, + v_cache: torch.Tensor, + new_k: torch.Tensor, + new_v: torch.Tensor, + loc: torch.LongTensor, + page_size: int, + k_scale: float, + v_scale: float, + fp8_type: int = 1, +): + """Quantize bf16 K/V to fp8, bitcast to uint8, and scatter into paged cache. + + Args: + fp8_type: 1 for e4m3 (default), 2 for e5m2. + k_scale: per-tensor scale used for K quantization. + v_scale: per-tensor scale used for V quantization. + """ + NNZ = loc.shape[0] + NUM_KV_HEAD = new_k.shape[1] + HEAD_DIM = new_k.shape[2] + + set_kv_buffer_fp8_kernel[(NNZ, NUM_KV_HEAD)]( + k_cache, v_cache, + new_k, new_v, + loc, + NUM_KV_HEAD, NNZ, HEAD_DIM, page_size, + FP8_TYPE=fp8_type, + k_scale=k_scale, + v_scale=v_scale, + ) + + +# --------------------------------------------------------------------------- +# Dequantization kernels (read direction: quantized paged cache → bf16) +# --------------------------------------------------------------------------- + +@triton.jit +def _dequant_pages_kernel( + src, # quantized paged buffer flat + src_scale, # per-token scale buffer flat (int8 only) + dst, # bf16 destination buffer flat + page_indices, # int32 page indices to dequant + NUM_PAGES, + PAGE_SIZE: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_DIM: tl.constexpr, + QUANT_TYPE: tl.constexpr, # 1: int8, 2: e4m3, 3: e5m2 + tensor_scale, # float: per-tensor scale (fp8 only) + COMPACT: tl.constexpr, # True: compact dst; False: in-place dst +): + """Unified dequant kernel for selected pages → bf16. + + QUANT_TYPE==1: load int8, multiply by per-token scale from src_scale. + QUANT_TYPE==2: load uint8, bitcast to float8e4nv, multiply by tensor_scale. + QUANT_TYPE==3: load uint8, bitcast to float8e5, multiply by tensor_scale. + COMPACT==True: dst offset uses page_idx (compact buffer). + COMPACT==False: dst offset uses global_page_id (in-place). 
+ """ + page_idx = tl.program_id(0) + token_idx = tl.program_id(1) + + if page_idx >= NUM_PAGES: + return + + global_page_id = tl.load(page_indices + page_idx) + dims = tl.arange(0, BLOCK_DIM) + mask_dim = dims < HEAD_DIM + + src_offset = (global_page_id * PAGE_SIZE + token_idx) * HEAD_DIM + dims + scale_offset = global_page_id * PAGE_SIZE + token_idx + + if QUANT_TYPE == 1: + val = tl.load(src + src_offset, mask=mask_dim, other=0).to(tl.float32) + scale = tl.load(src_scale + scale_offset).to(tl.float32) + result = (val * scale).to(tl.bfloat16) + elif QUANT_TYPE == 2: + raw = tl.load(src + src_offset, mask=mask_dim, other=0) + val = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) + result = (val * tensor_scale).to(tl.bfloat16) + else: # QUANT_TYPE == 3 + raw = tl.load(src + src_offset, mask=mask_dim, other=0) + val = raw.to(tl.float8e5, bitcast=True).to(tl.float32) + result = (val * tensor_scale).to(tl.bfloat16) + + if COMPACT: + dst_offset = (page_idx * PAGE_SIZE + token_idx) * HEAD_DIM + dims + else: + dst_offset = src_offset # same position as source + + tl.store(dst + dst_offset, result, mask=mask_dim) + + +def dequant_pages_to_bf16( + src: torch.Tensor, + src_scale: torch.Tensor, + page_indices: torch.Tensor, + page_size: int, + head_dim: int, + quant_type: int = 1, + tensor_scale: float = 1.0, + out: torch.Tensor = None, +) -> torch.Tensor: + """Dequant selected pages to compact bf16 buffer. + + Args: + quant_type: 1=int8 (per-token scale), 2=fp8 e4m3, 3=fp8 e5m2. + tensor_scale: per-tensor scale (fp8 only, ignored for int8). + out: optional pre-allocated bf16 buffer. 
+ """ + num_accessed_pages = page_indices.shape[0] + if num_accessed_pages == 0: + if out is not None: + return out[:0] + return torch.empty((0, page_size, head_dim), dtype=torch.bfloat16, device=src.device) + + if out is not None: + dst = out[:num_accessed_pages] + else: + dst = torch.empty( + (num_accessed_pages, page_size, head_dim), + dtype=torch.bfloat16, + device=src.device, + ) + + BLOCK_DIM = triton.next_power_of_2(head_dim) + + grid = (num_accessed_pages, page_size) + _dequant_pages_kernel[grid]( + src, src_scale, dst, page_indices, + NUM_PAGES=num_accessed_pages, + PAGE_SIZE=page_size, + HEAD_DIM=head_dim, + BLOCK_DIM=BLOCK_DIM, + QUANT_TYPE=quant_type, + tensor_scale=tensor_scale, + COMPACT=True, + ) + + return dst + + +def dequant_pages_to_bf16_inplace( + src: torch.Tensor, + src_scale: torch.Tensor, + dst: torch.Tensor, + page_indices: torch.Tensor, + page_size: int, + head_dim: int, + quant_type: int = 1, + tensor_scale: float = 1.0, +) -> None: + """Dequant selected pages in-place (same page positions in dst). + + Args: + quant_type: 1=int8 (per-token scale), 2=fp8 e4m3, 3=fp8 e5m2. + tensor_scale: per-tensor scale (fp8 only, ignored for int8). 
+ """ + num_pages = page_indices.shape[0] + if num_pages == 0: + return + + BLOCK_DIM = triton.next_power_of_2(head_dim) + + grid = (num_pages, page_size) + _dequant_pages_kernel[grid]( + src, src_scale, dst, page_indices, + NUM_PAGES=num_pages, + PAGE_SIZE=page_size, + HEAD_DIM=head_dim, + BLOCK_DIM=BLOCK_DIM, + QUANT_TYPE=quant_type, + tensor_scale=tensor_scale, + COMPACT=False, + ) + + +# --------------------------------------------------------------------------- +# Paged decode attention (unified quant_type-parameterized) +# --------------------------------------------------------------------------- + +_MIN_BLOCK_KV = 32 + + +@triton.jit +def _tanh(x): + return 2 * tl.sigmoid(2 * x) - 1 + + +@triton.jit +def _fwd_kernel_paged_decode_stage1( + Q, + K_Buffer, + V_Buffer, + K_Scale_Buffer, + V_Scale_Buffer, + sm_scale, + kv_indptr, + kv_indices, + last_page_len, + Att_Out, + Att_Lse, + num_kv_splits, + stride_qbs, + stride_qh, + stride_buf_kbs, + stride_buf_vbs, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + kv_group_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_N: tl.constexpr, + MIN_BLOCK_KV: tl.constexpr, + logit_cap: tl.constexpr, + Lk: tl.constexpr, + Lv: tl.constexpr, + PAGE_SIZE: tl.constexpr, + QUANT_TYPE: tl.constexpr, # 0: bf16, 1: int8, 2: e4m3, 3: e5m2 + tensor_scale, # per-tensor scale for fp8 +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + split_kv_id = tl.program_id(2) + + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv + + cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch) + cur_batch_num_pages = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx + cur_last_page_len = tl.load(last_page_len + cur_batch) + cur_batch_seq_len = (cur_batch_num_pages - 1) * PAGE_SIZE + cur_last_page_len + kv_splits = tl.load(num_kv_splits + cur_batch) + + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d + + 
kv_len_per_split = ( + tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV + ) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + e_max = -float("inf") + e_sum = 0.0 + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + if split_kv_end > split_kv_start: + q = tl.load(Q + off_q, mask=mask_d, other=0.0).to(tl.float32) + + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + mask_n = offs_n < split_kv_end + + page_indices_in_seq = offs_n // PAGE_SIZE + in_page_offsets = offs_n % PAGE_SIZE + page_ids = tl.load( + kv_indices + cur_batch_kv_start_idx + page_indices_in_seq, + mask=mask_n, other=0, + ) + kv_loc = page_ids * PAGE_SIZE + in_page_offsets + + # Load K with quant-type-dependent dequantization + offs_buf_k = kv_loc[:, None] * stride_buf_kbs + offs_d[None, :] + if QUANT_TYPE == 0: + k = tl.load( + K_Buffer + offs_buf_k, + mask=mask_n[:, None] & mask_d[None, :], other=0, + ).to(tl.float32) + elif QUANT_TYPE == 1: + k_int8 = tl.load( + K_Buffer + offs_buf_k, + mask=mask_n[:, None] & mask_d[None, :], other=0, + ).to(tl.float32) + k_scale = tl.load( + K_Scale_Buffer + kv_loc, mask=mask_n, other=1.0, + ).to(tl.float32) + k = k_int8 * k_scale[:, None] + elif QUANT_TYPE == 2: + raw = tl.load( + K_Buffer + offs_buf_k, + mask=mask_n[:, None] & mask_d[None, :], other=0, + ) + k = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * tensor_scale + else: # QUANT_TYPE == 3 + raw = tl.load( + K_Buffer + offs_buf_k, + mask=mask_n[:, None] & mask_d[None, :], other=0, + ) + k = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * tensor_scale + + qk = tl.sum(q[None, :] * k, 1) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * _tanh(qk / logit_cap) + + qk = tl.where(mask_n, qk, float("-inf")) + + # Load V with quant-type-dependent dequantization + offs_buf_v = kv_loc[:, None] * stride_buf_vbs + offs_dv[None, :] + if 
QUANT_TYPE == 0: + v = tl.load( + V_Buffer + offs_buf_v, + mask=mask_n[:, None] & mask_dv[None, :], other=0, + ).to(tl.float32) + elif QUANT_TYPE == 1: + v_int8 = tl.load( + V_Buffer + offs_buf_v, + mask=mask_n[:, None] & mask_dv[None, :], other=0, + ).to(tl.float32) + v_scale = tl.load( + V_Scale_Buffer + kv_loc, mask=mask_n, other=1.0, + ).to(tl.float32) + v = v_int8 * v_scale[:, None] + elif QUANT_TYPE == 2: + raw = tl.load( + V_Buffer + offs_buf_v, + mask=mask_n[:, None] & mask_dv[None, :], other=0, + ) + v = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * tensor_scale + else: # QUANT_TYPE == 3 + raw = tl.load( + V_Buffer + offs_buf_v, + mask=mask_n[:, None] & mask_dv[None, :], other=0, + ) + v = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * tensor_scale + + # Online softmax accumulation + n_e_max = tl.maximum(tl.max(qk, 0), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max) + acc *= re_scale + acc += tl.sum(p[:, None] * v, 0) + + e_sum = e_sum * re_scale + tl.sum(p, 0) + e_max = n_e_max + + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + + offs_dv + ) + + tl.store(Att_Out + offs_mid_o, acc / e_sum, mask=mask_dv) + + offs_mid_o_1 = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + ) // Lv + + tl.store(Att_Lse + offs_mid_o_1, e_max + tl.log(e_sum)) + + +@triton.jit +def _fwd_kernel_paged_decode_stage2( + Mid_O, + Mid_O_1, + O, + kv_indptr, + last_page_len, + num_kv_splits, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_obs, + stride_oh, + MAX_KV_SPLITS: tl.constexpr, + MIN_BLOCK_KV: tl.constexpr, + BLOCK_DV: tl.constexpr, + Lv: tl.constexpr, + PAGE_SIZE: tl.constexpr, +): + """Stage 2: Reduce split outputs via log-sum-exp merge.""" + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_batch_num_pages = tl.load(kv_indptr + cur_batch + 1) - tl.load(kv_indptr + cur_batch) + cur_last_page_len = 
tl.load(last_page_len + cur_batch) + cur_batch_seq_len = (cur_batch_num_pages - 1) * PAGE_SIZE + cur_last_page_len + kv_splits = tl.load(num_kv_splits + cur_batch) + + offs_d = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lv + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d + offs_logic = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh) // Lv + kv_len_per_split = ( + tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV + ) + + for split_kv_id in range(0, MAX_KV_SPLITS): + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + if split_kv_end > split_kv_start: + tv = tl.load( + Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0 + ) + tlogic = tl.load(Mid_O_1 + offs_logic + split_kv_id * stride_mid_os // Lv) + n_e_max = tl.maximum(tlogic, e_max) + + old_scale = tl.exp(e_max - n_e_max) + acc *= old_scale + exp_logic = tl.exp(tlogic - n_e_max) + acc += exp_logic * tv + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + tl.store( + O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, + acc / e_sum, + mask=mask_d, + ) + + +def paged_decode( + q: torch.Tensor, + k_buffer: torch.Tensor, + v_buffer: torch.Tensor, + o: torch.Tensor, + kv_indptr: torch.Tensor, + kv_indices: torch.Tensor, + last_page_len: torch.Tensor, + num_kv_splits: torch.Tensor, + max_kv_splits: int, + sm_scale: float, + page_size: int, + quant_type: int = 0, + k_scale_buffer: torch.Tensor = None, + v_scale_buffer: torch.Tensor = None, + tensor_scale: float = 1.0, + logit_cap: float = 0.0, + att_out: torch.Tensor = None, + att_lse: torch.Tensor = None, +): + """Unified paged decode attention. 
+ + Args: + quant_type: Controls K/V loading: + 0: bf16 (k_scale_buffer/v_scale_buffer unused) + 1: int8 with per-token scales (k_scale_buffer/v_scale_buffer required) + 2: fp8 e4m3 with per-tensor scale (tensor_scale required) + 3: fp8 e5m2 with per-tensor scale (tensor_scale required) + """ + batch = q.shape[0] + head_num = q.shape[1] + Lk = q.shape[2] + Lv = Lk + + BLOCK_DMODEL = triton.next_power_of_2(Lk) + BLOCK_DV = triton.next_power_of_2(Lv) + BLOCK_N = 128 + MAX_KV_SPLITS = max_kv_splits + + kv_group_num = head_num + num_warps = 4 + + if att_out is None: + att_out = torch.empty( + (batch, head_num, MAX_KV_SPLITS, Lv), + dtype=torch.float32, device=q.device, + ) + else: + att_out = att_out[:batch] + if att_lse is None: + att_lse = torch.empty( + (batch, head_num, MAX_KV_SPLITS), + dtype=torch.float32, device=q.device, + ) + else: + att_lse = att_lse[:batch] + + stride_buf_kbs = k_buffer.shape[-1] + stride_buf_vbs = v_buffer.shape[-1] + + # Use dummy tensors for scale buffers when not needed + _k_scale = k_scale_buffer if k_scale_buffer is not None else k_buffer + _v_scale = v_scale_buffer if v_scale_buffer is not None else v_buffer + + grid_stage1 = (batch, head_num, MAX_KV_SPLITS) + _fwd_kernel_paged_decode_stage1[grid_stage1]( + q, k_buffer, v_buffer, + _k_scale, _v_scale, + sm_scale, kv_indptr, kv_indices, last_page_len, + att_out, att_lse, num_kv_splits, + q.stride(0), q.stride(1), + stride_buf_kbs, stride_buf_vbs, + att_out.stride(0), att_out.stride(1), att_out.stride(2), + kv_group_num=kv_group_num, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DV=BLOCK_DV, + BLOCK_N=BLOCK_N, + MIN_BLOCK_KV=_MIN_BLOCK_KV, + logit_cap=logit_cap, + num_warps=num_warps, + num_stages=2, + Lk=Lk, Lv=Lv, + PAGE_SIZE=page_size, + QUANT_TYPE=quant_type, + tensor_scale=tensor_scale, + ) + + grid_stage2 = (batch, head_num) + _fwd_kernel_paged_decode_stage2[grid_stage2]( + att_out, att_lse, o, + kv_indptr, last_page_len, num_kv_splits, + att_out.stride(0), att_out.stride(1), 
att_out.stride(2), + o.stride(0), o.stride(1), + MAX_KV_SPLITS=MAX_KV_SPLITS, + MIN_BLOCK_KV=_MIN_BLOCK_KV, + BLOCK_DV=BLOCK_DV, + Lv=Lv, + PAGE_SIZE=page_size, + num_warps=4, + num_stages=2, + ) + diff --git a/vortex_torch/flow/__init__.py b/vortex_torch/flow/__init__.py index b2fcadc2..bb60b895 100644 --- a/vortex_torch/flow/__init__.py +++ b/vortex_torch/flow/__init__.py @@ -34,9 +34,11 @@ class BlockSparseAttention(vFlow): from .registry import register from .loader import build_vflow from . import algorithms +from . import external_algorithms __all__ = [ "vFlow", "register", "build_vflow", - "algorithms" + "algorithms", + "external_algorithms", ] \ No newline at end of file diff --git a/vortex_torch/flow/external_algorithms.py b/vortex_torch/flow/external_algorithms.py new file mode 100644 index 00000000..5f8935fa --- /dev/null +++ b/vortex_torch/flow/external_algorithms.py @@ -0,0 +1,76 @@ +""" +External sparse attention algorithm registrations for NSA, FSA, and FlashMoBA. + +These vFlow subclasses use simple centroid-based routing for the DECODE path +(forward_indexer + forward_cache), identical to BlockSparseAttention. + +The EXTEND path (forward_extend) is handled directly in vtx_graph_backend.py +using each algorithm's own sparse attention kernel — these vFlow classes are +not involved in extend. +""" + +import torch +from typing import Dict, Tuple + +from .flow import vFlow +from ..indexer import topK, GeMV +from ..cache import Mean as CMean +from ..abs import ContextBase +from .registry import register + + +class _ExternalAlgoBase(vFlow): + """ + Base vFlow for external sparse attention algorithms (NSA, FSA, FlashMoBA). + + Decode routing: centroid-based (same as BlockSparseAttention). + Extend: bypassed — vtx_graph_backend dispatches to algorithm-specific kernels. 
+ """ + + def __init__(self): + super().__init__() + self.gemv = GeMV() + self.output_func = topK() + self.reduction = CMean(dim=1) + + def forward_indexer( + self, + q: torch.Tensor, + o: torch.Tensor, + cache: Dict[str, torch.Tensor], + ctx: ContextBase, + ): + q_mean = q.mean(dim=1, keepdim=True) + score = self.gemv(q_mean, cache["centroids"], ctx=ctx) + self.output_func(score, o, ctx=ctx) + + def forward_cache( + self, + cache: Dict[str, torch.Tensor], + loc: torch.Tensor, + ctx: ContextBase, + ): + self.reduction(cache["k"], cache["centroids"], loc=loc, ctx=ctx) + + def create_cache(self, page_size: int, head_dim: int) -> Dict[str, Tuple[int, int]]: + return { + "centroids": (1, head_dim), + } + + +@register("nsa") +class NSASparseAttention(_ExternalAlgoBase): + """Naive Sparse Attention — decode uses centroid routing, extend uses NSA kernels.""" + pass + + +@register("fsa") +class FSASparseAttention(_ExternalAlgoBase): + """Flash Sparse Attention — decode uses centroid routing, extend uses FSA kernels.""" + pass + + +@register("flash_moba") +class FlashMoBASparseAttention(_ExternalAlgoBase): + """FlashMoBA — decode uses centroid routing, extend uses FlashMoBA kernels.""" + pass diff --git a/vortex_torch/flow/flow.py b/vortex_torch/flow/flow.py index 7efc80e9..7da5c72c 100644 --- a/vortex_torch/flow/flow.py +++ b/vortex_torch/flow/flow.py @@ -431,6 +431,7 @@ def run_indexer_virtual(self, group_size: int, page_size: int, head_dim: int): ctx.page_size = page_size ctx.max_num_pages = 0 ctx.max_num_pages_per_request = 0 + ctx.topk_type = "naive" device = "cuda" dtype = torch.bfloat16 diff --git a/vortex_torch/indexer/context.py b/vortex_torch/indexer/context.py index 6d3c586c..17dea66c 100644 --- a/vortex_torch/indexer/context.py +++ b/vortex_torch/indexer/context.py @@ -22,7 +22,9 @@ class Context(ContextBase): # hardware / paging "num_sms", "page_size", "max_num_pages", "max_num_pages_per_request", # misc - "indexer_dtype", "topk_val", "page_reserved_bos", 
"page_reserved_eos", + "indexer_dtype", "topk_val", "page_reserved_bos", "page_reserved_eos", "topk_type", + "topk_mapping_mode", "topk_mapping_power", + "topk_histogram_enabled", # auxilary memory in graph "_aux_total_bytes", @@ -68,6 +70,10 @@ class Context(ContextBase): topk_val: int #: Top-K value used in pruning or selection. page_reserved_bos: int #: Reserved page count for BOS (begin-of-sequence). page_reserved_eos: int #: Reserved page count for EOS (end-of-sequence). + topk_type: str #: TopK kernel type: "naive", "sglang" (unmapped) or "sglang_fused" (remap+topk). + topk_mapping_mode: int #: TopK mapping mode for sglang_fused (0=none, 3=power, 4=log, 6=asinh, 7=log1p, 9=erf, 10=tanh, 13=exp_stretch). + topk_mapping_power: float #: Hyperparameter (p / alpha / beta) for the active mapping mode. + topk_histogram_enabled: bool #: Enable histogram profiling during inference (default False). # --- auxiliary --- _aux_total_bytes: int #: Accumulated auxiliary memory in bytes. @@ -144,12 +150,16 @@ def create(self, parent: Any, model_runner: Any, *, overwrite: bool = False) -> self.page_reserved_bos = sa.vortex_page_reserved_bos self.page_reserved_eos = sa.vortex_page_reserved_eos + self.topk_type = getattr(sa, "vortex_topk_type", "naive") + self.topk_mapping_mode = getattr(sa, "vortex_topk_mapping_mode", 0) + self.topk_mapping_power = getattr(sa, "vortex_topk_mapping_power", 0.5) + self.topk_histogram_enabled = getattr(sa, "vortex_topk_histogram", False) + + device = getattr(model_runner, "device", "cpu") self.max_num_workloads = ( (self.max_num_pages // max(1, sa.vortex_lb_min_chunk_size)) + max_bs * self.num_kv_heads ) - - device = getattr(model_runner, "device", "cpu") self.winfo_q_indices = torch.zeros((self.max_num_workloads,), dtype=torch.int32, device=device) self.winfo_kv_offsets = torch.zeros((self.max_num_workloads,), dtype=torch.int32, device=device) self.winfo_kv_lens = torch.zeros((self.max_num_workloads,), dtype=torch.int32, device=device) diff --git 
a/vortex_torch/indexer/output_func.py b/vortex_torch/indexer/output_func.py index 5df795b6..889e0682 100644 --- a/vortex_torch/indexer/output_func.py +++ b/vortex_torch/indexer/output_func.py @@ -1,10 +1,21 @@ import torch -from typing import Dict, Callable, Optional +from typing import Dict, Callable, List, Optional from ..abs import vOp -from vortex_torch_C import topk_output +from vortex_torch_C import topk_output, topk_output_sglang, topk_output_sglang_fused, topk_profile_histogram from .context import Context from ..abs import vTensor, FORMAT +# --- Module-level histogram accumulator for offline calibration --- +_calibration_histograms: List[torch.Tensor] = [] + +def get_calibration_histograms() -> List[torch.Tensor]: + """Return collected histogram tensors (each [eff_bs, 256] int32 on CPU).""" + return _calibration_histograms + +def clear_calibration_histograms() -> None: + """Clear all collected calibration histograms.""" + _calibration_histograms.clear() + class topK(vOp): r""" Piecewise top-k dispatcher for packed sequences with reserved pages. @@ -75,13 +86,19 @@ class topK(vOp): """ # Dispatch by input format; only RAGGED is supported for now. - _impl_map: Dict[FORMAT, Callable] = { - FORMAT.RAGGED: topk_output, + _impl_map: Dict[FORMAT, Dict[str, Callable]] = { + FORMAT.RAGGED: { + "naive": topk_output, + "sglang": topk_output_sglang, + "sglang_fused": topk_output_sglang_fused, + }, } def __init__(self): super().__init__() self.impl: Optional[Callable] = None + self.topk_type: str = "naive" + self.last_histograms: Optional[torch.Tensor] = None # ---------------- profile ---------------- def profile(self, x: vTensor, o: vTensor, ctx: Context) -> None: @@ -152,7 +169,13 @@ def profile(self, x: vTensor, o: vTensor, ctx: Context) -> None: f"{prefix}no implementation for x._format={x_fmt}. 
" f"Available: {list(self._impl_map.keys())}" ) - self.impl = self._impl_map[x_fmt] + self.topk_type = getattr(ctx, "topk_type", "naive") + impl_variants = self._impl_map[x_fmt] + assert self.topk_type in impl_variants, ( + f"{prefix}no topk implementation for topk_type='{self.topk_type}'. " + f"Available: {list(impl_variants.keys())}" + ) + self.impl = impl_variants[self.topk_type] # ---- optional sanity checks on `o` ---- # We only assert device consistency and leave exact (S_pack, D0, D1) @@ -220,16 +243,85 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso prefix = self._prefix() assert self.impl is not None, f"{prefix}execute called before profile() (impl is None)" - self.impl( - x, - ctx.dense_kv_indptr, - ctx.sparse_kv_indptr, - ctx.dense_kv_indices, - o, - ctx.batch_size * ctx.num_kv_heads, - ctx.topk_val, - ctx.page_reserved_bos, - ctx.page_reserved_eos, - ctx.max_num_pages_per_request, - ) + if self.topk_type == "sglang": + # topk_output_sglang: unmapped baseline (no remap). + self.impl( + x, + ctx.dense_kv_indptr, + ctx.sparse_kv_indptr, + ctx.dense_kv_indices, + o, + ctx.batch_size * ctx.num_kv_heads, + ctx.topk_val, + ctx.page_reserved_bos, + ctx.page_reserved_eos, + ctx.max_num_pages_per_request, + ) + elif self.topk_type == "sglang_fused": + # topk_output_sglang_fused: single-launch fused remap + topk. + mapping_mode = getattr(ctx, 'topk_mapping_mode', 0) + mapping_power = getattr( + ctx, 'topk_mapping_hparam', + getattr(ctx, 'topk_mapping_power', 0.5), + ) + self.impl( + x, + ctx.dense_kv_indptr, + ctx.sparse_kv_indptr, + ctx.dense_kv_indices, + o, + ctx.batch_size * ctx.num_kv_heads, + ctx.topk_val, + ctx.page_reserved_bos, + ctx.page_reserved_eos, + ctx.max_num_pages_per_request, + int(mapping_mode), + float(mapping_power), + ) + else: + # topk_output (naive): (x, dense_kv_indptr, dense_kv_indices, sparse_kv_indptr, sparse_kv_indices, ...) 
+ self.impl( + x, + ctx.dense_kv_indptr, + ctx.dense_kv_indices, + ctx.sparse_kv_indptr, + o, + ctx.batch_size * ctx.num_kv_heads, + ctx.topk_val, + ctx.page_reserved_bos, + ctx.page_reserved_eos, + ctx.max_num_pages_per_request, + ) + + # Optional histogram profiling (default disabled, no overhead when off). + # Skip entirely during CUDA graph capture — allocations and D2H copies + # are not permitted while a stream is being captured. + if ( + getattr(ctx, 'topk_histogram_enabled', False) + and self.topk_type in ("sglang", "sglang_fused") + and not torch.cuda.is_current_stream_capturing() + ): + eff_bs = ctx.batch_size * ctx.num_kv_heads + self.last_histograms = torch.zeros(eff_bs, 256, dtype=torch.int32, device=x.device) + hist_mode = 0 + hist_power = 0.5 + if self.topk_type == "sglang_fused": + hist_mode = int(getattr(ctx, 'topk_mapping_mode', 0)) + hist_power = float(getattr( + ctx, 'topk_mapping_hparam', + getattr(ctx, 'topk_mapping_power', 0.5), + )) + topk_profile_histogram( + x, + ctx.dense_kv_indptr, + self.last_histograms, + eff_bs, + ctx.page_reserved_bos, + ctx.page_reserved_eos, + hist_mode, + hist_power, + ) + # Accumulate histograms for offline calibration + _calibration_histograms.append(self.last_histograms.cpu().clone()) + return o diff --git a/vortex_torch/indexer/utils_sglang.py b/vortex_torch/indexer/utils_sglang.py index 74b8cfe6..343207fc 100644 --- a/vortex_torch/indexer/utils_sglang.py +++ b/vortex_torch/indexer/utils_sglang.py @@ -40,7 +40,7 @@ def plan_decode( ctx.max_chunk_size, ctx.min_chunk_size ) - + ctx.set_batch_size(cached_seq_lens.shape[0]) diff --git a/vortex_torch/kernels/__init__.py b/vortex_torch/kernels/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vortex_torch/kernels/fsa/__init__.py b/vortex_torch/kernels/fsa/__init__.py new file mode 100644 index 00000000..25d5b3eb --- /dev/null +++ b/vortex_torch/kernels/fsa/__init__.py @@ -0,0 +1,5 @@ +from .fused_score_kernels import 
# Fused "attention score + block transform" kernel.
#
# For one (sequence, head, query tile, block tile) program this kernel
# recomputes the softmax probabilities p = exp2(q.k * scale - lse) between the
# queries and the compressed keys, and accumulates them (weighted by the
# overlap counts in `offs_ptr`) into per-selection-block scores. Blocks that
# are always selected (the first `init_blocks` blocks and the `local_blocks`
# blocks ending at the query's own block) are overwritten with +inf.
#
# Notes for maintainers:
#   * `lse` must be in base 2 (log2-sum-exp2), matching `qk_scale` below.
#   * q is indexed with the kv-head id, i.e. the kernel assumes the caller
#     passes one query head per kv head (see NUM_KV_HEADS comment).
#   * Compressed-key indices derived from a block index can be negative
#     (because of `pad_len`) or >= k_len; such keys are masked out of both the
#     K load and the score, so they contribute exactly zero.
@triton.jit
def fused_score_kernel(
    q_ptr,  # q_len x h x d
    k_ptr,  # k_len x h x d
    lse_ptr,  # h x n
    bs_ptr,  # h x n x nb
    offs_ptr,  # BO
    kernel_size,
    kernel_stride,
    num_offs,  # BO
    num_k_blocks,
    # seqlens
    cu_seqlens_q,
    cu_seqlens_k,
    # shape
    NUM_KV_HEADS,  # which is also num_q_heads
    HEAD_DIM,
    # sm_scale
    sm_scale,
    max_blocks,
    pad_len,
    block_size,
    block_stride,
    init_blocks,
    local_blocks,
    # stride
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_lh,
    stride_ln,
    stride_bsh,
    stride_bsq,
    stride_bsnb,
    # META parameters
    BLOCK_SIZE_Q: tl.constexpr,  # q block size
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_D: tl.constexpr,
):
    # 1.44269504 = log2(e): scores are evaluated with exp2 instead of exp
    qk_scale = sm_scale * 1.44269504
    # get batch id and head id
    pid_bkh = tl.program_id(0)
    pid_b = pid_bkh // NUM_KV_HEADS
    pid_kh = pid_bkh % NUM_KV_HEADS
    pid_q = tl.program_id(1)
    pid_k = tl.program_id(2)  # the blocks id of k
    # get q k start and len after rmpad
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    k_seq_start = tl.load(cu_seqlens_k + pid_b)
    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_seq_start

    # First selection-block index handled by this program. It is kept separate
    # from the sequence start so the block index is never mixed with a key
    # offset (for a single sequence, k_seq_start == 0 and nothing changes).
    blk_start = pid_k * BLOCK_SIZE_K * num_k_blocks
    if pid_q * BLOCK_SIZE_Q >= q_len or pid_k * BLOCK_SIZE_K >= k_len:
        return

    q_ptrs = tl.make_block_ptr(
        base=q_ptr + q_start * stride_qn + pid_kh * stride_qh,
        shape=(q_len, HEAD_DIM),
        strides=(stride_qn, stride_qd),
        offsets=(pid_q * BLOCK_SIZE_Q, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
        order=(1, 0),
    )
    lse_ptrs = tl.make_block_ptr(
        base=lse_ptr + q_start * stride_ln + pid_kh * stride_lh,
        shape=(q_len, 1),
        strides=(stride_ln, stride_lh),
        offsets=(pid_q * BLOCK_SIZE_Q, 0),
        block_shape=(BLOCK_SIZE_Q, 1),
        order=(0, 1),
    )
    # load q and lse
    q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
    lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")

    # offsets that do not depend on the block tile
    off_d = tl.arange(0, BLOCK_SIZE_D)
    off_q = pid_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)
    off_bq = off_q // block_size
    off_bk = tl.arange(0, BLOCK_SIZE_K)
    d_mask = off_d < HEAD_DIM
    # base of this sequence's / head's compressed keys
    k_base = k_ptr + k_seq_start * stride_kn + pid_kh * stride_kh

    for j in range(num_k_blocks):
        blk_start_j = blk_start + j * BLOCK_SIZE_K
        if blk_start_j < k_len:
            # index of the first compressed key overlapping each block
            off_k = (blk_start_j + off_bk) * block_stride - pad_len

            # init block score
            bs = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
            for i in range(num_offs):
                # keys outside [0, k_len) do not exist: never touch their memory
                k_valid = (off_k >= 0) & (off_k < k_len)
                k_ptrs = k_base + off_k[None, :] * stride_kn + off_d[:, None] * stride_kd
                # mask has the same [BLOCK_SIZE_D, BLOCK_SIZE_K] shape as k_ptrs
                k = tl.load(k_ptrs, mask=d_mask[:, None] & k_valid[None, :], other=0)
                # weight of the i-th overlapping compressed key
                w = tl.load(offs_ptr + i)
                # a query may only see keys whose window ends at or before it
                causal_mask = off_q[:, None] >= (off_k * kernel_stride + kernel_size - 1)[None, :]
                # compute qk
                qk = tl.dot(q, k) * qk_scale
                # compute score and apply weight; invalid keys contribute 0
                bs += w * tl.where(causal_mask & k_valid[None, :], tl.exp2(qk - lse), 0)
                # move to the next overlapping compressed key
                off_k += 1

            # init mask and local mask: force-selected blocks get +inf
            bs = tl.where(
                (
                    (off_bq[:, None] >= blk_start_j + off_bk[None, :])
                    & (off_bq[:, None] < blk_start_j + off_bk[None, :] + local_blocks)
                )
                | (off_bk[None, :] < init_blocks - blk_start_j),
                float("inf"),
                bs,
            )

            # save output
            bs_ptrs = (
                bs_ptr
                + pid_kh.to(tl.int64) * stride_bsh
                + q_start * stride_bsq
                + blk_start_j * stride_bsnb
                + off_q[:, None] * stride_bsq
                + off_bk[None, :] * stride_bsnb
            )

            tl.store(
                bs_ptrs,
                bs.to(bs_ptr.dtype.element_ty),
                mask=(off_q < q_len)[:, None] & (off_bk < max_blocks - blk_start_j)[None, :],
            )
@torch.inference_mode()
def _fused_attention_score_and_transform_per_seq(
    q: torch.Tensor,  # [total_query_len, num_q_heads, head_dim]
    k: torch.Tensor,  # [total_key_len, num_k_heads, head_dim]
    lse: torch.Tensor,  # [num_q_heads, total_query_len]
    block_score: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    block_size: int,
    offs,
    num_offs,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    sm_scale: float,
    init_blocks: int = 1,
    local_blocks: int = 2,
) -> None:
    """Launch ``fused_score_kernel`` and write block scores into ``block_score``.

    The result is written in place; nothing is returned.

    Args:
        q: queries, ``[total_query_len, num_q_heads, head_dim]``, fp16/bf16.
        k: compressed keys, ``[total_key_len, num_k_heads, head_dim]``,
            same dtype as ``q``.
        lse: float32 log2-sum-exp2 of the attention logits,
            ``[num_q_heads, total_query_len]``.
        block_score: output buffer indexed as ``[head, query, block]``.
        kernel_size: compression window size.
        kernel_stride: compression window stride.
        block_size: selection block size in tokens.
        offs: per-offset weights read by the kernel.
        num_offs: number of entries of ``offs``.
        cu_seqlens_q: int32 cumulative query lengths.
        cu_seqlens_k: int32 cumulative compressed-key lengths.
        max_seqlen_q: longest query sequence; an ``int`` or a 0-dim tensor.
        max_seqlen_k: unused; kept for signature compatibility.
        sm_scale: softmax scale; ``None`` selects ``1 / sqrt(head_dim)``.
        init_blocks: number of leading blocks that are always selected.
        local_blocks: number of trailing (local) blocks always selected.
    """
    # dtype check
    assert q.dtype == torch.bfloat16 or q.dtype == torch.float16
    assert q.dtype == k.dtype
    assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
    assert lse.dtype == torch.float32  # lse here is log2(sum(exp(qk*scale))), not log(sum(exp(qk*scale)))
    # shape
    # NOTE(review): the kernel indexes q by kv-head id, so it presumably
    # expects num_q_heads == num_k_heads -- confirm against callers.
    q_len, _, head_dim = q.shape
    k_len, num_k_heads, head_dim = k.shape
    batch_size = cu_seqlens_q.shape[0] - 1
    assert q_len > k_len
    if sm_scale is None:
        sm_scale = 1 / math.sqrt(head_dim)

    # The caller may hand over a 0-dim tensor; the launch grid and the block
    # count must be plain Python ints.
    max_seqlen_q = int(max_seqlen_q)

    pad_len = kernel_size // kernel_stride - 1
    max_blocks = math.ceil(max_seqlen_q / block_size)

    BLOCK_SIZE_K = min(128, triton.next_power_of_2(max_blocks))
    # ensure qk is valid on triton
    BLOCK_SIZE_K = max(BLOCK_SIZE_K, 16)
    BLOCK_SIZE_Q = 128
    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)

    # launch kernel: one program per (sequence * head, query tile, block tile)
    num_k_blocks = 1
    grid = (
        batch_size * num_k_heads,
        triton.cdiv(max_seqlen_q, BLOCK_SIZE_Q),
        triton.cdiv(max_blocks, BLOCK_SIZE_K * num_k_blocks),
    )

    fused_score_kernel[grid](
        q,
        k,
        lse,
        block_score,
        offs,
        kernel_size,
        kernel_stride,
        num_offs,
        num_k_blocks,
        cu_seqlens_q,
        cu_seqlens_k,
        num_k_heads,
        head_dim,
        sm_scale,
        max_blocks,
        pad_len,
        block_size,
        block_size // kernel_stride,
        init_blocks,
        local_blocks,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        lse.stride(0),
        lse.stride(1),
        block_score.stride(0),
        block_score.stride(1),
        block_score.stride(2),
        BLOCK_SIZE_Q=BLOCK_SIZE_Q,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        BLOCK_SIZE_D=BLOCK_SIZE_D,
        num_warps=8,
        num_stages=3,
    )
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +import warnings +from typing import Any, Tuple, Union + +import torch +import triton +import triton.language as tl + +from .utils import get_num_warps_stages, is_hopper_gpu + +IS_HOPPER_GPU = is_hopper_gpu() + + +@triton.jit +def forward_kernel( + q_ptr, # Q: n x h x d + k_ptr, # K: n x h x d + v_ptr, # V: n x h x d + o_ptr, # O: n x h x d + lse_ptr, # LSE: h x n + # size and stride at compresstion + kernel_size, + kernel_stride, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + # sm_scale + sm_scale, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_on, + stride_oh, + stride_od, + stride_lh, + stride_ln, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_q = tl.program_id(2) + pid_kh = pid_h // NUM_SHARE_Q_HEADS + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + # skip first kernel_size query block, because they do no attend to any keys + q_start_in_seq = pid_q * BLOCK_SIZE_Q + kernel_size - 1 + if q_start_in_seq >= q_len: + return + # init qkv pointer + q_ptrs = tl.make_block_ptr( + base=q_ptr + q_start * 
stride_qn + pid_h * stride_qh, + shape=(q_len, HEAD_DIM), + strides=(stride_qn, stride_qd), + offsets=(q_start_in_seq, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(HEAD_DIM, k_len), + strides=(stride_kd, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + # load q + q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") + # init statistics + off_q = tl.arange(0, BLOCK_SIZE_Q) + q_start_in_seq + off_k = tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1 + m_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32) + lse_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32) + acc_o = tl.full((BLOCK_SIZE_Q, BLOCK_SIZE_D), 0, dtype=tl.float32) + # attention + lo = 0 + hi = min(k_len, (q_start_in_seq + BLOCK_SIZE_Q - kernel_size) // kernel_stride + 1) + for i in range(lo, hi, BLOCK_SIZE_K): + i = tl.multiple_of(i, BLOCK_SIZE_K) + # load k + k = tl.load(k_ptrs, boundary_check=(1, 0), padding_option="zero") + # compute qk + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where(off_q[:, None] >= (i * kernel_stride + off_k)[None, :], 0, float("-inf")) + qk += tl.dot(q, k) * qk_scale + # compute m_ij and l_ij + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + p = tl.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, axis=1) + # scale acc_o + acc_o_scale = tl.exp2(m_i - m_ij) + acc_o = acc_o * acc_o_scale[:, None] + # load v and update acc_o + v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") + p = p.to(v.dtype) + acc_o += tl.dot(p, v) + # update statistics + m_i = m_ij + lse_i = m_ij + tl.math.log2(tl.exp2(lse_i - m_ij) + l_ij) 
# Computes delta[h, n] = sum_d O[n, h, d] * dO[n, h, d] in float32 for one
# tile of BLOCK_SIZE_O rows of a single head. Rows >= o_len and columns
# >= HEAD_DIM are read as zero and never written.
@triton.jit
def backward_sum_o_do(
    o_ptr,  # O: n x h x d
    do_ptr,  # dO: n x h x d
    delta_ptr,  # D: h x n
    o_len,
    HEAD_DIM,
    stride_on,
    stride_oh,
    stride_od,
    stride_don,
    stride_doh,
    stride_dod,
    stride_dh,
    stride_dn,
    BLOCK_SIZE_O: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    row_tile = tl.program_id(0)
    head = tl.program_id(1)
    rows = row_tile * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O)
    cols = tl.arange(0, BLOCK_SIZE_D)
    # validity of each row, and of each (row, col) element of the tile
    row_ok = rows < o_len
    tile_ok = (rows[:, None] < o_len) & (cols[None, :] < HEAD_DIM)
    # element addresses of the O and dO tiles
    o_addr = o_ptr + rows[:, None] * stride_on + head * stride_oh + cols[None, :] * stride_od
    do_addr = do_ptr + rows[:, None] * stride_don + head * stride_doh + cols[None, :] * stride_dod
    out_tile = tl.load(o_addr, mask=tile_ok, other=0).to(tl.float32)
    grad_tile = tl.load(do_addr, mask=tile_ok, other=0).to(tl.float32)
    # row-wise dot product
    row_sum = tl.sum(out_tile * grad_tile, axis=1)
    tl.store(delta_ptr + head * stride_dh + rows * stride_dn, row_sum, mask=row_ok)
# sm_scale + sm_scale, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_lh, + stride_ln, + stride_dh, + stride_dn, + stride_don, + stride_doh, + stride_dod, + stride_dks, + stride_dkn, + stride_dkh, + stride_dkd, + stride_dvs, + stride_dvn, + stride_dvh, + stride_dvd, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_kh = pid_h // NUM_SHARE_Q_HEADS + pid_sh = pid_h % NUM_SHARE_Q_HEADS + pid_k = tl.program_id(2) + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + if BLOCK_SIZE_K * pid_k >= k_len: + return + # init pointers + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(k_len, HEAD_DIM), + strides=(stride_kn, stride_kd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dk_ptrs = tl.make_block_ptr( + base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks, + shape=(k_len, HEAD_DIM), + strides=(stride_dkn, stride_dkd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dv_ptrs = tl.make_block_ptr( + base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs, + shape=(k_len, HEAD_DIM), + strides=(stride_dvn, stride_dvd), + offsets=(pid_k * BLOCK_SIZE_K, 
0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + # offsets + off_q = tl.arange(0, BLOCK_SIZE_Q) + off_k = pid_k * BLOCK_SIZE_K * kernel_stride + tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1 + # load k v and keep in SRAM + k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") + v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") + # init dk dv + dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + q_lo = pid_k * BLOCK_SIZE_K * kernel_stride + kernel_size - 1 + q_ptrs = tl.make_block_ptr( + base=q_ptr + q_start * stride_qn + pid_h * stride_qh, + shape=(HEAD_DIM, q_len), + strides=(stride_qd, stride_qn), + offsets=(0, q_lo), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_Q), + order=(0, 1), + ) + do_ptrs = tl.make_block_ptr( + base=do_ptr + q_start * stride_don + pid_h * stride_doh, + shape=(HEAD_DIM, q_len), + strides=(stride_dod, stride_don), + offsets=(0, q_lo), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_Q), + order=(0, 1), + ) + d_ptrs = tl.make_block_ptr( + base=d_ptr + q_start * stride_dn + pid_h * stride_dh, + shape=(1, q_len), + strides=(0, stride_dn), + offsets=(0, q_lo), + block_shape=(1, BLOCK_SIZE_Q), + order=(1, 0), + ) + lse_ptrs = tl.make_block_ptr( + base=lse_ptr + q_start * stride_ln + pid_h * stride_lh, + shape=(1, q_len), + strides=(0, stride_ln), + offsets=(0, q_lo), + block_shape=(1, BLOCK_SIZE_Q), + order=(0, 1), + ) + # loop for q blocks + for i in range(q_lo, q_len, BLOCK_SIZE_Q): + # load + q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") + do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero") + lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero") + d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero") + # compute qk + # [BLOCK_SIZE_K, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIE_Q] -> [BLOCK_SIZE_K, BLOCK_SIE_Q] + qk = tl.where(off_k[:, None] <= (off_q + i)[None, :], 
float(0.0), float("-inf")) + qk += tl.dot(k, q) * qk_scale + # compute p, ds + # [BLOCK_SIZE_K, BLOCK_SIE_Q] - [1, BLOCK_SIZE_Q] -> [BLOCK_SIZE_K, BLOCK_SIE_Q] + p = tl.exp2(qk - lse) + # [BLOCK_SIZE_K, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIE_Q] -> [BLOCK_SIZE_K, BLOCK_SIE_Q] + dp = tl.dot(v, do) + ds = sm_scale * p * (dp - d) + # cast dtype + p = p.to(do.dtype) + ds = ds.to(q.dtype) + # update dk and dv + # [BLOCK_SIZE_K, BLOCK_SIE_Q] @ [BLOCK_SIE_Q, HEAD_DIM] -> [BLOCK_SIZE_K, HEAD_DIM] + dk += tl.dot(ds, tl.trans(q)) + dv += tl.dot(p, tl.trans(do)) + # increment pointers + q_ptrs = tl.advance(q_ptrs, (0, BLOCK_SIZE_Q)) + do_ptrs = tl.advance(do_ptrs, (0, BLOCK_SIZE_Q)) + lse_ptrs = tl.advance(lse_ptrs, (0, BLOCK_SIZE_Q)) + d_ptrs = tl.advance(d_ptrs, (0, BLOCK_SIZE_Q)) + # save dk dv + tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1)) + tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def backward_dq( + q_ptr, # Q: n x qh x d + k_ptr, # K: n x kh x d + v_ptr, # V: n x kh x d + lse_ptr, # LSE: qh x n + d_ptr, # Delta: qh x n + do_ptr, + dq_ptr, + kernel_size, + kernel_stride, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + # sm_scale + sm_scale, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_lh, + stride_ln, + stride_dh, + stride_dn, + stride_don, + stride_doh, + stride_dod, + stride_dqn, + stride_dqh, + stride_dqd, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_q = tl.program_id(2) + pid_kh = pid_h // NUM_SHARE_Q_HEADS + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - 
q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + # skip first kernel_size query block, because they do no attend to any keys + q_start_in_seq = pid_q * BLOCK_SIZE_Q + kernel_size - 1 + if q_start_in_seq >= q_len: + return + # init pointers + q_ptrs = tl.make_block_ptr( + base=q_ptr + q_start * stride_qn + pid_h * stride_qh, + shape=(q_len, HEAD_DIM), + strides=(stride_qn, stride_qd), + offsets=(q_start_in_seq, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + dq_ptrs = tl.make_block_ptr( + base=dq_ptr + q_start * stride_dqn + pid_h * stride_dqh, + shape=(q_len, HEAD_DIM), + strides=(stride_dqn, stride_dqd), + offsets=(q_start_in_seq, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(k_len, HEAD_DIM), + strides=(stride_kn, stride_kd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(HEAD_DIM, k_len), + strides=(stride_vd, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + do_ptrs = tl.make_block_ptr( + base=do_ptr + q_start * stride_don + pid_h * stride_doh, + shape=(q_len, HEAD_DIM), + strides=(stride_don, stride_dod), + offsets=(q_start_in_seq, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + d_ptrs = tl.make_block_ptr( + base=d_ptr + q_start * stride_dn + pid_h * stride_dh, + shape=(q_len, 1), + strides=(stride_dn, stride_dh), + offsets=(q_start_in_seq, 0), + block_shape=(BLOCK_SIZE_Q, 1), + order=(0, 1), + ) + lse_ptrs = tl.make_block_ptr( + base=lse_ptr + q_start * stride_ln + pid_h * stride_lh, + shape=(q_len, 1), + strides=(stride_ln, stride_lh), + offsets=(q_start_in_seq, 0), + block_shape=(BLOCK_SIZE_Q, 1), + order=(0, 1), + ) + # offsets + off_q = tl.arange(0, BLOCK_SIZE_Q) + 
def _compressed_attention_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: torch.Tensor,
    max_seqlen_k: torch.Tensor,
    sm_scale: float,
):
    """Run ``forward_kernel`` and return ``(o, lse)``.

    ``o`` has the shape and dtype of ``q`` and starts out as zeros; ``lse`` is
    a float32 ``[num_q_heads, q_len]`` tensor that starts out as ``-inf``.
    Whatever the kernel does not write keeps those initial values.
    """
    # dtype check
    assert k.dtype == q.dtype and v.dtype == q.dtype
    assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
    # shape (head_dim is the one read from v, as the last unpack wins)
    q_len, num_q_heads, _ = q.shape
    k_len, num_k_heads, _ = k.shape
    v_len, num_v_heads, head_dim = v.shape
    batch_size = cu_seqlens_q.shape[0] - 1
    assert k_len == v_len and q_len > k_len
    # gqa: every kv head serves a whole number of query heads
    assert num_k_heads == num_v_heads
    assert num_q_heads % num_k_heads == 0
    num_share_q_heads = num_q_heads // num_k_heads

    # output tensors
    o = torch.zeros_like(q)
    lse = torch.full(
        (num_q_heads, q_len),
        fill_value=-torch.inf,
        dtype=torch.float32,
        device=q.device,
    )

    # tile sizes and launch configuration
    BLOCK_SIZE_Q = 128
    BLOCK_SIZE_K = 128
    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
    num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU)

    def grid(META):
        # one program per (sequence, query head, query tile)
        return (
            batch_size,
            num_q_heads,
            triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]),
        )

    forward_kernel[grid](
        q,
        k,
        v,
        o,
        lse,
        kernel_size,
        kernel_stride,
        cu_seqlens_q,
        cu_seqlens_k,
        num_k_heads,
        num_share_q_heads,
        head_dim,
        sm_scale,
        *q.stride(),
        *k.stride(),
        *v.stride(),
        *o.stride(),
        *lse.stride(),
        BLOCK_SIZE_Q=BLOCK_SIZE_Q,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        BLOCK_SIZE_D=BLOCK_SIZE_D,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return o, lse
BLOCK_SIZE_O, IS_HOPPER_GPU) + backward_sum_o_do[grid]( + o, + do, + delta, + o_len, + head_dim, + o.stride(0), + o.stride(1), + o.stride(2), + do.stride(0), + do.stride(1), + do.stride(2), + delta.stride(0), + delta.stride(1), + BLOCK_SIZE_O=BLOCK_SIZE_O, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + # compute dk dv + dk = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype) + dv = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype) + batch_size = cu_seqlens_q.shape[0] - 1 + grid = lambda META: ( + batch_size, + num_q_heads, + triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]), + ) + BLOCK_SIZE_Q = 64 + BLOCK_SIZE_K = 128 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_K, IS_HOPPER_GPU) + backward_dkdv[grid]( + q, + k, + v, + lse, + delta, + do, + dk, + dv, + kernel_size, + kernel_stride, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + lse.stride(0), + lse.stride(1), + delta.stride(0), + delta.stride(1), + do.stride(0), + do.stride(1), + do.stride(2), + dk.stride(0), + dk.stride(1), + dk.stride(2), + dk.stride(3), + dv.stride(0), + dv.stride(1), + dv.stride(2), + dv.stride(3), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + dk = dk.sum(0) + dv = dv.sum(0) + # compute dq + dq = torch.zeros_like(q) + grid = lambda META: ( + batch_size, + num_q_heads, + triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]), + ) + BLOCK_SIZE_Q = 128 + BLOCK_SIZE_K = 64 + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU) + backward_dq[grid]( + q, + k, + v, + lse, + delta, + do, + dq, + kernel_size, + 
class CompressedAttention(torch.autograd.Function):
    """Autograd wrapper around the compressed-attention forward/backward launchers."""

    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        kernel_size: int,
        kernel_stride: int,
        cu_seqlens_q: torch.Tensor,
        cu_seqlens_k: torch.Tensor,
        max_seqlen_q: torch.Tensor,
        max_seqlen_k: torch.Tensor,
        sm_scale=None,
    ):
        """Validate dtypes, run the forward pass and stash state for backward.

        Returns the ``(o, lse)`` pair produced by ``_compressed_attention_fwd``.
        """
        # dtype check
        assert q.dtype == torch.bfloat16 or q.dtype == torch.float16
        assert q.dtype == k.dtype and k.dtype == v.dtype
        assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
        # softmax scale defaults to 1 / sqrt(head_dim)
        if sm_scale is None:
            sm_scale = 1 / math.sqrt(q.shape[-1])

        outputs = _compressed_attention_fwd(
            q,
            k,
            v,
            kernel_size,
            kernel_stride,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            sm_scale,
        )
        o, lse = outputs
        # tensors go through save_for_backward; plain values live on ctx
        ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k)
        ctx.sm_scale = sm_scale
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.kernel_size = kernel_size
        ctx.kernel_stride = kernel_stride
        return o, lse

    @staticmethod
    def backward(ctx, do: torch.Tensor, *args) -> Any:
        """Return gradients for ``q``, ``k``, ``v`` and ``None`` for the other inputs.

        Only ``do`` (the gradient of ``o``) is used; extra incoming gradients
        collected in ``*args`` are ignored.
        """
        q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors

        dq, dk, dv = _compressed_attention_bwd(
            o,
            do,
            lse,
            q,
            k,
            v,
            ctx.kernel_size,
            ctx.kernel_stride,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.sm_scale,
        )
        # forward took ten inputs after ctx; only the first three are differentiable
        return (dq, dk, dv) + (None,) * 7
def _get_attention_score(
    q: torch.Tensor,  # [total_query_len, num_q_heads, head_dim]
    k: torch.Tensor,  # [total_key_len, num_k_heads, head_dim]
    lse: torch.Tensor,  # [num_q_heads, total_query_len]
    kernel_size: int,
    kernel_stride: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    sm_scale: float,
) -> torch.Tensor:
    """Launch ``score_kernel`` and return its output buffer.

    The result is a float32 tensor of shape
    ``[num_k_heads, total_query_len, max_seqlen_k]`` that starts out as zeros;
    entries the kernel does not write stay zero.
    """
    # dtype check
    assert q.dtype == torch.bfloat16 or q.dtype == torch.float16
    assert q.dtype == k.dtype
    assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
    assert lse.dtype == torch.float32  # lse here is log2(sum(exp(qk*scale))), not log(sum(exp(qk*scale)))
    # shape (head_dim is the one read from k, as the last unpack wins)
    q_len, num_q_heads, _ = q.shape
    k_len, num_k_heads, head_dim = k.shape
    batch_size = cu_seqlens_q.shape[0] - 1
    assert q_len > k_len
    if sm_scale is None:
        sm_scale = 1 / math.sqrt(head_dim)
    # gqa
    assert num_q_heads % num_k_heads == 0
    num_share_q_heads = num_q_heads // num_k_heads

    # output buffer
    score = torch.zeros(num_k_heads, q_len, max_seqlen_k, dtype=torch.float32, device=q.device)

    # tile sizes
    BLOCK_SIZE_Q = 128
    BLOCK_SIZE_K = 128
    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)

    def grid(META):
        # one program per (sequence * kv head, query tile, key tile)
        return (
            batch_size * num_k_heads,
            triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]),
            triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]),
        )

    score_kernel[grid](
        q,
        k,
        lse,
        score,
        kernel_size,
        kernel_stride,
        cu_seqlens_q,
        cu_seqlens_k,
        num_k_heads,
        num_share_q_heads,
        head_dim,
        sm_scale,
        *q.stride(),
        *k.stride(),
        *lse.stride(),
        *score.stride(),
        BLOCK_SIZE_Q=BLOCK_SIZE_Q,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        BLOCK_SIZE_D=BLOCK_SIZE_D,
        num_warps=8,
        num_stages=3,
    )
    return score
BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q) + off_k = (k_start + tl.arange(0, BLOCK_SIZE_K)) * block_stride - pad_len + off_k = off_k[None, :] + off_o[:, None] + s_ptrs = ( + s_ptr + + q_start * stride_sq + + pid_h * stride_sh + + off_q[:, None, None] * stride_sq + + off_k[None, :, :] * stride_sk + ) + # weighted sum, [BQ, BO, BK] * [1, BO, 1] -> [BQ, BO, BK] -> [BQ, BK] + s = tl.load( + s_ptrs, + mask=(off_q < q_len)[:, None, None] & (off_k >= 0) & (off_k < max_k_len), + other=0, + ) + s = s * w[None, :, None] + s = tl.sum(s, axis=1) + # init mask and local mask + off_bq = off_q // block_size + off_bk = k_start + tl.arange(0, BLOCK_SIZE_K) + s = tl.where( + ((off_bq[:, None] >= off_bk[None, :]) & (off_bq[:, None] < off_bk[None, :] + local_blocks)) + | (off_bk[None, :] < init_blocks - k_start), + float("inf"), + s, + ) + # store block wise score + bs_ptrs = ( + bs_ptr + q_start * stride_bsq + pid_h * stride_bsh + off_q[:, None] * stride_bsq + off_bk[None, :] * stride_bsk + ) + tl.store( + bs_ptrs, + s, + mask=(off_q < q_len)[:, None] & (off_bk < max_blocks)[None, :], + ) + + +def transform_score( + score: torch.Tensor, + kernel_size: int, + kernel_stride: int, + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + init_blocks: int = 1, + local_blocks: int = 2, +) -> torch.Tensor: + num_k_heads, total_query_len, max_key_len = score.shape + batch_size = cu_seqlens_q.shape[0] - 1 + pad_len = kernel_size // kernel_stride - 1 + max_blocks = math.ceil(max_seqlen_q / block_size) + block_score = torch.zeros( + num_k_heads, + total_query_len, + max_blocks, + dtype=torch.float32, + device=score.device, + ) + offs = ( + torch.arange(kernel_size // kernel_stride, device=score.device)[:, None] + + torch.arange(block_size // kernel_stride, device=score.device)[None, :] + ).view(-1) + + offs = torch.histc(offs, bins=offs.max() + 1, min=0, max=offs.max()) + + num_offs = int(offs.shape[0]) + + BLOCK_SIZE_Q = 16 + 
BLOCK_SIZE_K = min(128, triton.next_power_of_2(max_blocks)) + BLOCK_SIZE_O = triton.next_power_of_2(num_offs) + + def grid(meta): + grid = ( + num_k_heads * batch_size, + triton.cdiv(total_query_len, BLOCK_SIZE_Q), + triton.cdiv(max_blocks, BLOCK_SIZE_K), + ) + return grid + + _transform_score_kernel[grid]( + score, + block_score, + offs, + cu_seqlens_q, + num_k_heads, + offs.shape[0], + max_key_len, + max_blocks, + pad_len, + block_size, + block_size // kernel_stride, + init_blocks, + local_blocks, + score.stride(0), + score.stride(1), + score.stride(2), + block_score.stride(0), + block_score.stride(1), + block_score.stride(2), + TOTAL_QUERY_LEN=total_query_len, + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_O=BLOCK_SIZE_O, + num_warps=4, + num_stages=3, + ) + return block_score + + +def compressed_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + kernel_size: int, + kernel_stride: int, + block_size: int, + topk: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float = None, + init_blocks: int = 1, + local_blocks: int = 2, + parallel_topk_compute: Union[str, bool] = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Attention between query and compressed key and value. Compute attention output and topk block idx used in topk_sparse_attention. + + Args: + q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim] + k (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim] + v (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim] + kernel_size (int): kernel size in compress_key_value + kernel_stride (int): stride of compress_key_value + block_size (int): key value block size for topk sparse attention. + topk (int): number of blocks for each query. + cu_seqlens_q (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen. 
+ cu_seqlens_k (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_k in flash_attn_func_varlen. + max_seqlen_q (int): max q len of the batch. If None, it is computed from cu_seqlens_q. + max_seqlen_k (int): max k len of the batch. If None, it is computed from cu_seqlens_k. + sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim). + init_blocks (int, optional): Number of init blocks for each query. Defaults to 1. + local_blocks (int, optional): Number of local blocks for each query. Defaults to 2. + parallel_topk_compute (Union[str, bool], optional): Compute topk for all kv heads in one pass (True) or one kv head at a time (False); False avoids a current bug on very long sequences. + We'll fix this issue later. Pass "auto" to use the one-pass path only when cu_seqlens_q[-1] <= 32768. Defaults to False. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention (topk_idx is None when topk <= 0) + """ + + if max_seqlen_q is None: + max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item() + if max_seqlen_k is None: + max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item() + + attn_output, lse = CompressedAttention.apply( + q, + k, + v, + kernel_size, + kernel_stride, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + ) + + # do not select topk index + if topk <= 0: + warnings.warn("topk <= 0, returned topk_idx will be None") + return attn_output, None + + assert topk >= init_blocks + local_blocks + with torch.no_grad(): + num_k_heads, num_q_heads = k.shape[1], q.shape[1] + num_shared_q_heads = num_q_heads // num_k_heads + batch_size = cu_seqlens_q.shape[0] - 1 + q_idx = torch.cat( + [torch.arange(cu_seqlens_q[i + 1] - cu_seqlens_q[i], device=q.device) for i in range(batch_size)], + dim=0, + ) + q_idx = q_idx // block_size + + # whether to use parallel version + if parallel_topk_compute == "auto": + parallel_topk_compute = cu_seqlens_q[-1] <= 32768 + # parallel version + if parallel_topk_compute: + # recompute score + score = _get_attention_score( + q, + k, + lse, + kernel_size, + kernel_stride,
+ cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + ) + # transform score to block-wise score + score = transform_score( + score, + kernel_size, + kernel_stride, + block_size, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + init_blocks, + local_blocks, + ) + # get topk + topk = min(topk, score.shape[-1]) + topk_idx = score.topk(topk, dim=-1).indices.sort(-1).values + topk_idx[topk_idx > q_idx[None, :, None]] = -1 + topk_idx = topk_idx.to(torch.int32) + # non parallel version, avoid some current bugs when sequence length is too long + # FIXME: need to fix later + else: + topk_idx_list = [] + head_tile = 1 + assert num_k_heads % head_tile == 0, f"Num kv heads: {num_k_heads}, head_tile: {head_tile}" + for h in range(num_k_heads // head_tile): + # recompute score + score = _get_attention_score( + q[:, h * num_shared_q_heads * head_tile: (h + 1) * num_shared_q_heads * head_tile], + k[:, h * head_tile: (h + 1) * head_tile], + lse[h * num_shared_q_heads * head_tile: (h + 1) * num_shared_q_heads * head_tile], + kernel_size, + kernel_stride, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + ) + # transform score to block-wise score + score = transform_score( + score, + kernel_size, + kernel_stride, + block_size, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + init_blocks, + local_blocks, + ) + # get topk + topk = min(topk, score.shape[-1]) + if score.dtype == torch.float32: + score = score.to(torch.bfloat16) + topk_idx = score.topk(topk, dim=-1, sorted=False).indices + topk_idx = topk_idx.sort(-1).values + + topk_idx[topk_idx > q_idx[None, :, None]] = -1 + topk_idx = topk_idx.to(torch.int32) + topk_idx_list.append(topk_idx) + topk_idx = torch.cat(topk_idx_list, dim=0) + + return attn_output, topk_idx diff --git a/vortex_torch/kernels/nsa/flash_attention.py b/vortex_torch/kernels/nsa/flash_attention.py new file mode 100644 index 00000000..c556a4c4 --- /dev/null +++ 
b/vortex_torch/kernels/nsa/flash_attention.py @@ -0,0 +1,886 @@ +# Copyright 2025 Xunhao Lai & Jianqiao Lu. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import Any, Optional + +import torch +import triton +import triton.language as tl + +from .utils import get_num_warps_stages, is_hopper_gpu + +IS_HOPPER_GPU = is_hopper_gpu() + + +@triton.jit +def forward_kernel( + q_ptr, # Q: n x h x d + k_ptr, # K: n x h x d + v_ptr, # V: n x h x d + o_ptr, # O: n x h x d + lse_ptr, # LSE: h x n + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + # sm_scale + sm_scale, + # causal + causal, + # gqa + gqa_interleave, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_on, + stride_oh, + stride_od, + stride_lh, + stride_ln, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_q = tl.program_id(2) + if gqa_interleave: + pid_kh = pid_h % NUM_KV_HEADS + else: + pid_kh = pid_h // NUM_SHARE_Q_HEADS + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - 
k_start + if BLOCK_SIZE_Q * pid_q >= q_len: + return + # init qkv pointer + q_ptrs = tl.make_block_ptr( + base=q_ptr + q_start * stride_qn + pid_h * stride_qh, + shape=(q_len, HEAD_DIM), + strides=(stride_qn, stride_qd), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(HEAD_DIM, k_len), + strides=(stride_kd, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + # load q + q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") + # init statistics + off_q = tl.arange(0, BLOCK_SIZE_Q) + pid_q * BLOCK_SIZE_Q + off_k = tl.arange(0, BLOCK_SIZE_K) + m_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32) + lse_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32) + acc_o = tl.full((BLOCK_SIZE_Q, BLOCK_SIZE_D), 0, dtype=tl.float32) + # full attention or causal attention + lo = 0 + if causal: + hi = min(k_len, (pid_q + 1) * BLOCK_SIZE_Q) + else: + hi = k_len + for i in range(lo, hi, BLOCK_SIZE_K): + i = tl.multiple_of(i, BLOCK_SIZE_K) + # load k + k = tl.load(k_ptrs, boundary_check=(1, 0), padding_option="zero") + # compute qk + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + if causal: + qk += tl.where(off_q[:, None] >= (i + off_k)[None, :], 0, float("-inf")) + else: + qk += tl.where((off_k < k_len - i)[None, :], 0, float("-inf")) + qk += tl.dot(q, k) * qk_scale + # compute m_ij and l_ij + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + p = tl.math.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, axis=1) + # scale acc_o + acc_o_scale = tl.math.exp2(m_i - m_ij) + acc_o = acc_o * acc_o_scale[:, None] + # load v and update acc_o 
+ v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") + p = p.to(v.dtype) + acc_o += tl.dot(p, v) + # update statistics + m_i = m_ij + lse_i = m_ij + tl.math.log2(tl.math.exp2(lse_i - m_ij) + l_ij) + # update ptrs + k_ptrs = tl.advance(k_ptrs, (0, BLOCK_SIZE_K)) + v_ptrs = tl.advance(v_ptrs, (BLOCK_SIZE_K, 0)) + # final scale + acc_o = acc_o * tl.math.exp2(m_i - lse_i)[:, None] + # save output + o_ptrs = tl.make_block_ptr( + base=o_ptr + q_start * stride_on + pid_h * stride_oh, + shape=(q_len, HEAD_DIM), + strides=(stride_on, stride_od), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1)) + # save lse + l_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh + off_q * stride_ln + tl.store(l_ptrs, lse_i, mask=off_q < q_len) + + +@triton.jit +def backward_sum_o_do( + o_ptr, # O: n x h x d + do_ptr, # dO: n x h x d + delta_ptr, # D: h x n + o_len, + HEAD_DIM, + stride_on, + stride_oh, + stride_od, + stride_don, + stride_doh, + stride_dod, + stride_dh, + stride_dn, + BLOCK_SIZE_O: tl.constexpr, + BLOCK_SIZE_D: tl.constexpr, +): + pid_n = tl.program_id(0) + pid_h = tl.program_id(1) + off_n = pid_n * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O) + off_d = tl.arange(0, BLOCK_SIZE_D) + o = tl.load( + o_ptr + off_n[:, None] * stride_on + pid_h * stride_oh + off_d[None, :] * stride_od, + mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM), + other=0, + ).to(tl.float32) + do = tl.load( + do_ptr + off_n[:, None] * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod, + mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM), + other=0, + ).to(tl.float32) + delta = tl.sum(o * do, axis=1) + tl.store(delta_ptr + pid_h * stride_dh + off_n * stride_dn, delta, mask=off_n < o_len) + + +@triton.jit +def backward_dkdv( + q_ptr, # Q: n x qh x d + k_ptr, # K: n x kh x d + v_ptr, # V: n x kh x d + lse_ptr, # LSE: qh x n + d_ptr, # 
Delta: qh x n + do_ptr, + dk_ptr, # DK: sh x n x kh x d + dv_ptr, # DV: sh x n x kh x d + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + # sm_scale + sm_scale, + # causal + causal, + # gqa + gqa_interleave, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_lh, + stride_ln, + stride_dh, + stride_dn, + stride_don, + stride_doh, + stride_dod, + stride_dks, + stride_dkn, + stride_dkh, + stride_dkd, + stride_dvs, + stride_dvn, + stride_dvh, + stride_dvd, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + if gqa_interleave: + pid_kh = pid_h % NUM_SHARE_Q_HEADS + pid_sh = pid_h // NUM_SHARE_Q_HEADS + else: + pid_kh = pid_h // NUM_SHARE_Q_HEADS + pid_sh = pid_h % NUM_SHARE_Q_HEADS + pid_k = tl.program_id(2) + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + if BLOCK_SIZE_K * pid_k >= k_len: + return + # init pointers + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(k_len, HEAD_DIM), + strides=(stride_kn, stride_kd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dk_ptrs = tl.make_block_ptr( + base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks, + shape=(k_len, HEAD_DIM), + strides=(stride_dkn, stride_dkd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + 
strides=(stride_vn, stride_vd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dv_ptrs = tl.make_block_ptr( + base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs, + shape=(k_len, HEAD_DIM), + strides=(stride_dvn, stride_dvd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + # offsets + off_q = tl.arange(0, BLOCK_SIZE_Q) + off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K + # load k v and keep in SRAM + k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") + v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") + # init dk dv + dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + # causal + if causal: + q_lo = pid_k * BLOCK_SIZE_K + else: + q_lo = 0 + q_ptrs = tl.make_block_ptr( + base=q_ptr + q_start * stride_qn + pid_h * stride_qh, + shape=(q_len, HEAD_DIM), + strides=(stride_qn, stride_qd), + offsets=(q_lo, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + do_ptrs = tl.make_block_ptr( + base=do_ptr + q_start * stride_don + pid_h * stride_doh, + shape=(q_len, HEAD_DIM), + strides=(stride_don, stride_dod), + offsets=(q_lo, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + d_ptrs = tl.make_block_ptr( + base=d_ptr + q_start * stride_dn + pid_h * stride_dh, + shape=(q_len, 1), + strides=(stride_dn, stride_dh), + offsets=(q_lo, 0), + block_shape=(BLOCK_SIZE_Q, 1), + order=(0, 1), + ) + lse_ptrs = tl.make_block_ptr( + base=lse_ptr + q_start * stride_ln + pid_h * stride_lh, + shape=(q_len, 1), + strides=(stride_ln, stride_lh), + offsets=(q_lo, 0), + block_shape=(BLOCK_SIZE_Q, 1), + order=(0, 1), + ) + # loop for q blocks + for i in range(q_lo, q_len, BLOCK_SIZE_Q): + # load + q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") + do = tl.load(do_ptrs, boundary_check=(0, 1), 
padding_option="zero") + lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero") + d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero") + # compute qk + if causal: + qk = tl.where((off_q + i)[:, None] >= off_k[None, :], float(0.0), float("-inf")) + else: + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.dot(q, k.T) * qk_scale + # compute p, ds + p = tl.math.exp2(qk - lse) + dp = tl.dot(do, v.T) + ds = sm_scale * p * (dp - d) + # cast dtype + p = p.to(do.dtype) + ds = ds.to(q.dtype) + # update dk and dv + dk += tl.dot(ds.T, q) + dv += tl.dot(p.T, do) + # increment pointers + q_ptrs = tl.advance(q_ptrs, (BLOCK_SIZE_Q, 0)) + do_ptrs = tl.advance(do_ptrs, (BLOCK_SIZE_Q, 0)) + lse_ptrs = tl.advance(lse_ptrs, (BLOCK_SIZE_Q, 0)) + d_ptrs = tl.advance(d_ptrs, (BLOCK_SIZE_Q, 0)) + # save dk dv + tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1)) + tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def backward_dq( + q_ptr, # Q: n x qh x d + k_ptr, # K: n x kh x d + v_ptr, # V: n x kh x d + lse_ptr, # LSE: qh x n + d_ptr, # Delta: qh x n + do_ptr, + dq_ptr, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + # sm_scale + sm_scale, + # causal + causal, + # gqa + gqa_interleave, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_lh, + stride_ln, + stride_dh, + stride_dn, + stride_don, + stride_doh, + stride_dod, + stride_dqn, + stride_dqh, + stride_dqd, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_q = tl.program_id(2) + if gqa_interleave: + pid_kh = pid_h % NUM_KV_HEADS + else: + pid_kh = pid_h // 
NUM_SHARE_Q_HEADS + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + if BLOCK_SIZE_Q * pid_q >= q_len: + return + # init pointers + q_ptrs = tl.make_block_ptr( + base=q_ptr + q_start * stride_qn + pid_h * stride_qh, + shape=(q_len, HEAD_DIM), + strides=(stride_qn, stride_qd), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + dq_ptrs = tl.make_block_ptr( + base=dq_ptr + q_start * stride_dqn + pid_h * stride_dqh, + shape=(q_len, HEAD_DIM), + strides=(stride_dqn, stride_dqd), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(k_len, HEAD_DIM), + strides=(stride_kn, stride_kd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + do_ptrs = tl.make_block_ptr( + base=do_ptr + q_start * stride_don + pid_h * stride_doh, + shape=(q_len, HEAD_DIM), + strides=(stride_don, stride_dod), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + d_ptrs = tl.make_block_ptr( + base=d_ptr + q_start * stride_dn + pid_h * stride_dh, + shape=(q_len, 1), + strides=(stride_dn, stride_dh), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, 1), + order=(0, 1), + ) + lse_ptrs = tl.make_block_ptr( + base=lse_ptr + q_start * stride_ln + pid_h * stride_lh, + shape=(q_len, 1), + strides=(stride_ln, stride_lh), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, 1), + order=(0, 1), + ) + # 
offsets + off_q = tl.arange(0, BLOCK_SIZE_Q) + pid_q * BLOCK_SIZE_Q + off_k = tl.arange(0, BLOCK_SIZE_K) + # load q, do, lse, delta, and keep in SRAM + q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option="zero") + do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero") + lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero") + d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero") + # init dq + dq = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_D), dtype=tl.float32) + # causal + if causal: + k_hi = (pid_q + 1) * BLOCK_SIZE_Q + else: + k_hi = k_len + for j in range(0, k_hi, BLOCK_SIZE_K): + # load + k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") + v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") + # compute qk + if causal: + qk = tl.where(off_q[:, None] >= (off_k + j)[None, :], float(0.0), float("-inf")) + else: + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.dot(q, k.T) * qk_scale + # compute p, ds + p = tl.math.exp2(qk - lse) + dp = tl.dot(do, v.T) + ds = sm_scale * p * (dp - d) + # cast dtype + ds = ds.to(q.dtype) + # update dq + dq += tl.dot(ds, k) + # increment pointers + k_ptrs = tl.advance(k_ptrs, (BLOCK_SIZE_K, 0)) + v_ptrs = tl.advance(v_ptrs, (BLOCK_SIZE_K, 0)) + # save dq + tl.store(dq_ptrs, dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1)) + + +def _flash_attention_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: torch.Tensor, + max_seqlen_k: torch.Tensor, + causal: bool, + sm_scale: float, + gqa_interleave: bool = False, +): + # dtype check + assert q.dtype == torch.bfloat16 or q.dtype == torch.float16 + assert k.dtype == q.dtype and v.dtype == q.dtype + assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + # shape + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + v_len, num_v_heads, head_dim = v.shape + 
batch_size = cu_seqlens_q.shape[0] - 1 + # assert q_len == k_len and k_len == v_len + # gqa + assert num_k_heads == num_v_heads + assert num_q_heads % num_k_heads == 0 + num_share_q_heads = num_q_heads // num_k_heads + # output tensor + o = torch.empty_like(q) + lse = torch.empty(num_q_heads, q_len, dtype=torch.float32, device=q.device) + # launch kernel + grid = lambda META: ( + batch_size, + num_q_heads, + triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]), + ) + BLOCK_SIZE_Q = 128 + BLOCK_SIZE_K = 64 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU) + forward_kernel[grid]( + q, + k, + v, + o, + lse, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + sm_scale, + causal, + gqa_interleave, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + lse.stride(0), + lse.stride(1), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + return o, lse + + +def _flash_attention_bwd( + o: torch.Tensor, + do: torch.Tensor, + lse: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: torch.Tensor, + max_seqlen_k: torch.Tensor, + causal: bool, + sm_scale: float, + gqa_interleave: bool = False, +): + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + v_len, num_v_heads, head_dim = v.shape + o_len, num_o_heads, head_dim = o.shape + num_share_q_heads = num_q_heads // num_k_heads + # compute D + delta = torch.empty([num_o_heads, o_len], device=o.device, dtype=torch.float32) + grid = lambda META: (triton.cdiv(o_len, META["BLOCK_SIZE_O"]), num_o_heads) + BLOCK_SIZE_O = 256 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, 
num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_O, IS_HOPPER_GPU) + backward_sum_o_do[grid]( + o, + do, + delta, + o_len, + head_dim, + o.stride(0), + o.stride(1), + o.stride(2), + do.stride(0), + do.stride(1), + do.stride(2), + delta.stride(0), + delta.stride(1), + BLOCK_SIZE_O=BLOCK_SIZE_O, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + # compute dk dv + dk = torch.empty(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype) + dv = torch.empty(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype) + batch_size = cu_seqlens_q.shape[0] - 1 + grid = lambda META: ( + batch_size, + num_q_heads, + triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]), + ) + BLOCK_SIZE_Q = 64 + BLOCK_SIZE_K = 64 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_K, IS_HOPPER_GPU) + backward_dkdv[grid]( + q, + k, + v, + lse, + delta, + do, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + sm_scale, + causal, + gqa_interleave, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + lse.stride(0), + lse.stride(1), + delta.stride(0), + delta.stride(1), + do.stride(0), + do.stride(1), + do.stride(2), + dk.stride(0), + dk.stride(1), + dk.stride(2), + dk.stride(3), + dv.stride(0), + dv.stride(1), + dv.stride(2), + dv.stride(3), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + dk = dk.sum(0) + dv = dv.sum(0) + # compute dq + dq = torch.empty_like(q) + grid = lambda META: ( + batch_size, + num_q_heads, + triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]), + ) + BLOCK_SIZE_Q = 128 + BLOCK_SIZE_K = 64 + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU) + backward_dq[grid]( + q, + k, + v, + lse, + 
delta, + do, + dq, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + sm_scale, + causal, + gqa_interleave, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + lse.stride(0), + lse.stride(1), + delta.stride(0), + delta.stride(1), + do.stride(0), + do.stride(1), + do.stride(2), + dq.stride(0), + dq.stride(1), + dq.stride(2), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + return dq, dk, dv + + +class FlashAttention(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: torch.Tensor, + max_seqlen_k: torch.Tensor, + causal=True, + sm_scale=None, + gqa_interleave=False, + ): + # softmax scale + if sm_scale is None: + sm_scale = 1 / math.sqrt(q.shape[-1]) + o, lse = _flash_attention_fwd( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + causal, + sm_scale, + gqa_interleave, + ) + ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k) + ctx.sm_scale = sm_scale + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.causal = causal + ctx.gqa_interleave = gqa_interleave + return o + + @staticmethod + def backward(ctx, do: torch.Tensor, *args) -> Any: + q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors + max_seqlen_q = ctx.max_seqlen_q + max_seqlen_k = ctx.max_seqlen_k + sm_scale = ctx.sm_scale + causal = ctx.causal + gqa_interleave = ctx.gqa_interleave + dq, dk, dv = _flash_attention_bwd( + o, + do, + lse, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + causal, + sm_scale, + gqa_interleave, + ) + return dq, dk, dv, None, None, None, None, None, None, None + + +def flash_attention_varlen( + q: torch.Tensor, + k: 
torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: torch.Tensor, + max_seqlen_k: torch.Tensor, + causal: bool = False, + sm_scale: Optional[float] = None, + gqa_interleave: bool = False, +) -> torch.Tensor: + """Flash attention with variable length based on triton. + + Args: + q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim] + k (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]; num_q_heads must be a multiple of num_kv_heads + v (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim] + cu_seqlens_q (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen. + cu_seqlens_k (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_k in flash_attn_func_varlen. + max_seqlen_q (torch.Tensor): max q len of the batch. + max_seqlen_k (torch.Tensor): max k len of the batch. + causal (bool, optional): Causal mask. Defaults to False. + sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim). + gqa_interleave (bool, optional): GQA pattern. If True, q head h uses kv head h % num_kv_heads; if False (default, Llama style), kv head h // (num_q_heads // num_kv_heads). + + Returns: + torch.Tensor: attention output with shape [total_q_len, num_q_heads, head_dim] + """ + return FlashAttention.apply( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + causal, + sm_scale, + gqa_interleave, + ) diff --git a/vortex_torch/kernels/nsa/utils.py b/vortex_torch/kernels/nsa/utils.py new file mode 100644 index 00000000..1f158a17 --- /dev/null +++ b/vortex_torch/kernels/nsa/utils.py @@ -0,0 +1,50 @@ +import torch + + +def is_hopper_gpu(): + if torch.cuda.is_available(): + device_capability = torch.cuda.get_device_capability(0) + major, minor = device_capability + return major == 9 + return False + + +def get_num_warps_stages(head_dim, block_size, is_hopper_gpu): + """ + Returns recommended num_warps and num_stages for a Sparse Attention kernel in Triton. + + Args: + head_dim (int): Size of the head dimension.
+ block_size (int): Size of the block in the attention matrix. + is_hopper_gpu (bool): True if Hopper GPU, False if Ampere GPU. + + Returns: + tuple: (num_warps, num_stages) recommended values. + """ + # Determine if head_dim and block_size exceed 64 + head_large = head_dim > 64 + block_large = block_size > 64 + + if is_hopper_gpu: + # Hopper GPU recommendations + if head_large and block_large: + num_warps = 8 + num_stages = 3 + elif head_large or block_large: + num_warps = 4 + num_stages = 3 + else: + num_warps = 2 + num_stages = 2 + else: + # Ampere GPU recommendations + if head_large and block_large: + num_warps = 8 + num_stages = 3 + elif head_large or block_large: + num_warps = 8 + num_stages = 3 + else: + num_warps = 2 + num_stages = 2 + return num_warps, num_stages diff --git a/vortex_torch/kernels/nsa/weighted_pool.py b/vortex_torch/kernels/nsa/weighted_pool.py new file mode 100644 index 00000000..abfe9d30 --- /dev/null +++ b/vortex_torch/kernels/nsa/weighted_pool.py @@ -0,0 +1,341 @@ +# Copyright 2025 Xunhao Lai. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from typing import Optional

import torch
import triton
import triton.language as tl
from einops import einsum


@triton.jit
def sliding_pool_fwd_kernel(
    x_ptr,
    y_ptr,
    w_ptr,
    cu_seqlens,
    y_cu_seqlens,
    head_dim,
    kernel_size,
    kernel_stride,
    stride_xn,
    stride_xh,
    stride_xd,
    stride_yn,
    stride_yh,
    stride_yd,
    stride_wh,
    stride_wk,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    """Forward pass of sliding-window pooling over variable-length sequences.

    Launch grid is (batch, head, output position). Each program reads one
    window of ``BLOCK_SIZE_K`` rows of ``x`` starting at
    ``pid_k * kernel_stride`` inside its sequence and writes one output row.
    When ``w_ptr`` is None the window is averaged (sum / ``kernel_size``);
    otherwise it is a weighted sum using the per-head weights in ``w``.
    """
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    pid_k = tl.program_id(2)
    # get start and len after rmpad
    x_start = tl.load(cu_seqlens + pid_b)
    x_len = tl.load(cu_seqlens + pid_b + 1) - x_start
    y_start = tl.load(y_cu_seqlens + pid_b)
    y_len = tl.load(y_cu_seqlens + pid_b + 1) - y_start
    # the grid is sized by the longest output sequence; shorter ones exit early
    if pid_k >= y_len:
        return
    if w_ptr is not None:
        # load w as a [BLOCK_SIZE_K, 1] column so it broadcasts over head_dim
        w_ptrs = tl.make_block_ptr(
            base=w_ptr + pid_h * stride_wh,
            shape=(kernel_size, 1),
            strides=(stride_wk, 0),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_K, 1),
            order=(0, 1),
        )
        w = tl.load(w_ptrs, boundary_check=(0, 1), padding_option="zero")
    # load x window, [BLOCK_SIZE_K, BLOCK_SIZE_D]; out-of-range entries read as zero
    x_ptrs = tl.make_block_ptr(
        base=x_ptr + x_start * stride_xn + pid_h * stride_xh,
        shape=(x_len, head_dim),
        strides=(stride_xn, stride_xd),
        offsets=(pid_k * kernel_stride, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    x = tl.load(x_ptrs, boundary_check=(0, 1), padding_option="zero")
    # compute y by reducing over the window axis
    if w_ptr is not None:
        y = tl.sum(x * w, axis=0)
    else:
        y = tl.sum(x, axis=0) / kernel_size
    off_d = tl.arange(0, BLOCK_SIZE_D)
    tl.store(
        y_ptr + (y_start + pid_k) * stride_yn + pid_h * stride_yh + off_d * stride_yd,
        y.to(y_ptr.dtype.element_ty),
        mask=off_d < head_dim,
    )


@triton.jit
def sliding_pool_dxdw_kernel(
    x_ptr,
    dx_ptr,
    dy_ptr,
    w_ptr,
    dw_ptr,
    cu_seqlens,
    y_cu_seqlens,
    head_dim,
    kernel_size,
    kernel_stride,
    stride_xn,
    stride_xh,
    stride_xd,
    stride_dxn,
    stride_dxh,
    stride_dxd,
    stride_dyn,
    stride_dyh,
    stride_dyd,
    stride_wh,
    stride_wk,
    stride_dwh,
    stride_dwn,
    stride_dwk,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    """Backward pass of sliding-window pooling: gradients w.r.t. ``x`` and ``w``.

    Uses the same (batch, head, output position) grid as the forward kernel.
    Each program scatters the gradient of one output row back onto the input
    rows of its window with ``tl.atomic_add`` (windows of neighbouring output
    positions can cover the same input rows). When ``w_ptr`` is not None it
    also stores a per-output-position ``dw`` row; the caller reduces those
    rows over positions afterwards.
    """
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    pid_k = tl.program_id(2)
    # get start and len after rmpad
    x_start = tl.load(cu_seqlens + pid_b)
    x_len = tl.load(cu_seqlens + pid_b + 1) - x_start
    y_start = tl.load(y_cu_seqlens + pid_b)
    y_len = tl.load(y_cu_seqlens + pid_b + 1) - y_start
    if pid_k >= y_len:
        return
    # offsets
    off_d = tl.arange(0, BLOCK_SIZE_D)
    off_k = tl.arange(0, BLOCK_SIZE_K)
    if w_ptr is not None:
        # load w
        w_ptrs = w_ptr + pid_h * stride_wh + off_k * stride_wk
        w = tl.load(w_ptrs, mask=off_k < kernel_size, other=0)
        # load x transposed, [BLOCK_SIZE_D, BLOCK_SIZE_K]; only needed for dw
        x_ptrs = tl.make_block_ptr(
            base=x_ptr + x_start * stride_xn + pid_h * stride_xh,
            shape=(head_dim, x_len),
            strides=(stride_xd, stride_xn),
            offsets=(0, pid_k * kernel_stride),
            block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
            order=(0, 1),
        )
        x = tl.load(x_ptrs, boundary_check=(0, 1), padding_option="zero")
    # load dy
    dy_ptrs = dy_ptr + pid_h * stride_dyh + (y_start + pid_k) * stride_dyn + off_d * stride_dyd
    dy = tl.load(dy_ptrs, mask=off_d < head_dim, other=0)
    if w_ptr is not None:
        # compute dx, [1, D] x [K, 1] -> [K, D]
        dx = dy[None, :] * w[:, None]
        # compute dw, [D, 1] x [D, K] -> [D, K] -> [K]
        dw = tl.sum(dy[:, None] * x, axis=0)
        # store dw
        dw_ptrs = dw_ptr + pid_h * stride_dwh + (y_start + pid_k) * stride_dwn + off_k * stride_dwk
        tl.store(dw_ptrs, dw.to(dw_ptr.dtype.element_ty), mask=off_k < kernel_size)
    else:
        # average pooling: every row of the window receives dy / kernel_size
        dx = dy[None, :] / kernel_size
    # store dx
    dx_ptrs = (
        dx_ptr
        + pid_h * stride_dxh
        + (x_start + pid_k * kernel_stride + off_k[:, None]) * stride_dxn
        + off_d[None, :] * stride_dxd
    )
    tl.atomic_add(
        dx_ptrs,
        dx.to(dx_ptr.dtype.element_ty),
        mask=(off_k < x_len - pid_k * kernel_stride)[:, None] & (off_d < head_dim)[None, :],
    )


class SlidingWindowWeightedPool(torch.autograd.Function):
    """Autograd wrapper around the sliding-window pooling Triton kernels.

    Pools packed variable-length sequences ``x`` of shape
    ``[total_len, num_heads, head_dim]`` with windows of ``kernel_size`` rows
    taken every ``kernel_stride`` rows. With ``w`` of shape
    ``[num_heads, kernel_size]`` the pooling is a weighted sum; with
    ``w=None`` it is an average.
    """

    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,  # [total_len, num_heads, head_dim]
        w: torch.Tensor,  # [num_heads, kernel_size]
        cu_seqlens: torch.Tensor,
        kernel_size: int,
        kernel_stride: int,
    ):
        """Pool ``x`` and return ``(y, y_cu_seqlens)``.

        Args:
            x: fp16/bf16 tensor, ``[total_len, num_heads, head_dim]``.
            w: optional weights ``[num_heads, kernel_size]`` with the same
                dtype as ``x``; ``None`` selects average pooling.
            cu_seqlens: int32 cumulative sequence lengths, ``[batch_size + 1]``.
            kernel_size: window length; must be one of 16, 32, 64, 128 and a
                multiple of ``kernel_stride``.
            kernel_stride: step between consecutive windows.

        Returns:
            tuple: ``y`` of shape ``[total_y_len, num_heads, head_dim]`` and
            the int32 cumulative output lengths ``y_cu_seqlens``. Sequences
            shorter than ``kernel_size`` contribute no output rows.

        Raises:
            AssertionError: if a dtype or shape requirement above is violated.
        """
        # dtype check
        assert x.dtype == torch.float16 or x.dtype == torch.bfloat16
        if w is not None:
            assert x.dtype == w.dtype
        assert cu_seqlens.dtype == torch.int32
        # shape check
        _, num_heads, head_dim = x.shape
        batch_size = cu_seqlens.shape[0] - 1
        if w is not None:
            assert w.shape[0] == num_heads
            assert w.shape[1] == kernel_size
        assert kernel_size % kernel_stride == 0
        assert kernel_size in {16, 32, 64, 128}
        # compute seqlens after compression
        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
        y_seqlens = torch.floor((seqlens - kernel_size) / kernel_stride).to(torch.int32) + 1
        # corner case, if sequence_length < kernel_size, no compression for this sequence
        y_seqlens[seqlens < kernel_size] = 0
        y_cu_seqlens = torch.cat(
            [
                # Build the leading zero on the inputs' own device; a literal
                # "cuda" resolves to the current device and makes torch.cat
                # fail when the inputs live on a different GPU.
                torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device),
                torch.cumsum(y_seqlens, dim=0),
            ],
            dim=0,
        ).to(torch.int32)
        # output buffer
        y = torch.zeros(y_cu_seqlens[-1], num_heads, head_dim, dtype=x.dtype, device=x.device)
        # launch kernel
        BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
        BLOCK_SIZE_K = triton.next_power_of_2(kernel_size)
        # max() raises on an empty tensor, so treat an empty batch as "no work"
        max_y_len = y_seqlens.max().item() if batch_size > 0 else 0
        if max_y_len > 0:
            grid = (batch_size, num_heads, max_y_len)
            sliding_pool_fwd_kernel[grid](
                x,
                y,
                w,
                cu_seqlens,
                y_cu_seqlens,
                head_dim,
                kernel_size,
                kernel_stride,
                x.stride(0),
                x.stride(1),
                x.stride(2),
                y.stride(0),
                y.stride(1),
                y.stride(2),
                w.stride(0) if w is not None else None,
                w.stride(1) if w is not None else None,
                BLOCK_SIZE_K=BLOCK_SIZE_K,
                BLOCK_SIZE_D=BLOCK_SIZE_D,
            )
        ctx.save_for_backward(x, w, seqlens, cu_seqlens, y_seqlens, y_cu_seqlens)
        ctx.kernel_size = kernel_size
        ctx.kernel_stride = kernel_stride
        ctx.head_dim = head_dim
        return y, y_cu_seqlens

    @staticmethod
    def backward(ctx, dy, _):
        """Return gradients for ``(x, w, cu_seqlens, kernel_size, kernel_stride)``.

        Only ``x`` and ``w`` receive gradients; ``dw`` is ``None`` when the
        forward pass ran without weights. The gradient of ``y_cu_seqlens``
        (second argument) is ignored.
        """
        x, w, seqlens, cu_seqlens, y_seqlens, y_cu_seqlens = ctx.saved_tensors
        kernel_size = ctx.kernel_size
        kernel_stride = ctx.kernel_stride
        head_dim = ctx.head_dim
        batch_size = cu_seqlens.shape[0] - 1
        num_heads = x.shape[1]
        # compute dx; accumulated in fp32 by the kernel, cast back to x.dtype below
        dx = torch.zeros_like(x, dtype=torch.float32)
        dw = None
        if w is not None:
            # one dw row per output position; reduced over positions after the launch
            dw = torch.zeros(
                num_heads,
                y_cu_seqlens[-1],
                kernel_size,
                dtype=torch.float32,
                device=w.device,
            )
        BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
        BLOCK_SIZE_K = triton.next_power_of_2(kernel_size)
        # same empty-batch / no-output guard as in forward
        max_y_len = y_seqlens.max().item() if batch_size > 0 else 0
        if max_y_len > 0:
            grid = (batch_size, num_heads, max_y_len)
            sliding_pool_dxdw_kernel[grid](
                x,
                dx,
                dy,
                w,
                dw,
                cu_seqlens,
                y_cu_seqlens,
                head_dim,
                kernel_size,
                kernel_stride,
                x.stride(0),
                x.stride(1),
                x.stride(2),
                dx.stride(0),
                dx.stride(1),
                dx.stride(2),
                dy.stride(0),
                dy.stride(1),
                dy.stride(2),
                w.stride(0) if w is not None else None,
                w.stride(1) if w is not None else None,
                dw.stride(0) if w is not None else None,
                dw.stride(1) if w is not None else None,
                dw.stride(2) if w is not None else None,
                BLOCK_SIZE_K=BLOCK_SIZE_K,
                BLOCK_SIZE_D=BLOCK_SIZE_D,
            )
        dx = dx.to(x.dtype)
        if w is not None:
            dw = dw.sum(1).to(w.dtype)
        return dx, dw, None, None, None


def weightedpool_compress(
    x: torch.Tensor,  # [total_len, num_heads, head_dim]
    w: torch.Tensor,  # [num_heads, kernel_size]
    cu_seqlens: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    pe: Optional[torch.Tensor] = None,
):
    """Compress ``x`` with a weighted sliding-window pool.

    When ``pe`` is given (indexed as ``[head, kernel position, head_dim]`` by
    the einsum below), its ``w``-weighted sum over the kernel positions is
    added to every output row.

    Returns:
        tuple: ``(y, y_cu_seqlens)`` as produced by ``SlidingWindowWeightedPool``.
    """
    y, y_cu_seqlens = SlidingWindowWeightedPool.apply(x, w, cu_seqlens, kernel_size, kernel_stride)
    if pe is not None:
        assert pe.dtype == x.dtype and pe.device == x.device
        bias = einsum(pe, w, "h k d, h k -> h d")
        y = y + bias.unsqueeze(0)
    return y, y_cu_seqlens


def avgpool_compress(
    x: torch.Tensor,  # [total_len, num_heads, head_dim]
    w: torch.Tensor,  # don't need weight
    cu_seqlens: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    pe: Optional[torch.Tensor] = None,
):
    """Compress ``x`` with an average sliding-window pool.

    ``w`` exists only for signature parity with the other compress functions
    and must be ``None``. When ``pe`` is given, its mean over dim 1 is added
    to every output row.

    Returns:
        tuple: ``(y, y_cu_seqlens)`` as produced by ``SlidingWindowWeightedPool``.
    """
    assert w is None, "don't need additional weight for avgpool"
    y, y_cu_seqlens = SlidingWindowWeightedPool.apply(x, w, cu_seqlens, kernel_size, kernel_stride)
    if pe is not None:
        assert pe.dtype == x.dtype and pe.device == x.device
        bias = torch.mean(pe, dim=1)
        y = y + bias.unsqueeze(0)
    return y, y_cu_seqlens


def softmaxpool_compress(
    x: torch.Tensor,
    w: torch.Tensor,
    cu_seqlens: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    pe: Optional[torch.Tensor] = None,
):
    """Compress ``x`` with a sliding-window pool weighted by ``softmax(w)``.

    The weights are normalised with a softmax over the last dimension before
    pooling. When ``pe`` is given, it is combined with those same normalised
    weights, mirroring ``weightedpool_compress``.

    Returns:
        tuple: ``(y, y_cu_seqlens)`` as produced by ``SlidingWindowWeightedPool``.
    """
    probs = w.softmax(-1)
    y, y_cu_seqlens = SlidingWindowWeightedPool.apply(x, probs, cu_seqlens, kernel_size, kernel_stride)
    if pe is not None:
        assert pe.dtype == x.dtype and pe.device == x.device
        # NOTE(review): this used to be torch.mean(pe, dim=1), i.e. uniform
        # weights, although x is pooled with softmax(w). Using the same
        # weights for pe keeps pool(x + pe) == pool(x) + bias; confirm that no
        # existing checkpoint depends on the old uniform bias.
        bias = einsum(pe, probs, "h k d, h k -> h d")
        y = y + bias.unsqueeze(0)
    return y, y_cu_seqlens